In [None]:
import gc
import h5py
import joblib
import json
from matplotlib.colors import LogNorm
import numpy as np
import os
import pandas as pd
import progressbar
import pylab as plt
import warnings

%matplotlib inline

from spacepy import pycdf

warnings.filterwarnings("ignore", category=DeprecationWarning) 

In [None]:
# Read the run configuration from the JSON file next to this notebook.
# Later cells index into it (e.g. hyperparams['prep']['sample_sizes']).
with open('hyperparams.json', 'r') as fh:
    hyperparams = json.loads(fh.read())

In [None]:
# Load the table describing the train/test split. The columns used by this
# notebook are: file_path, test_train, phase and lengths.
df_train_and_test = pd.read_csv('/mnt/efs/dasilva/compression-cfha/data/test_train_split.csv')

# NOTE(review): a filter on file paths containing 'mms1' used to sit here, but
# its result was never assigned or displayed, so it had no effect and has been
# dropped. If the split is meant to be restricted to MMS1 files, reinstate it
# as an assignment back onto df_train_and_test.

# Partition the rows by their 'test_train' label.
df_train = df_train_and_test[df_train_and_test.test_train == 'train']
df_test = df_train_and_test[df_train_and_test.test_train == 'test']
df_train.head()

In [None]:
# Preview the first 20 file paths in the training split.
df_train['file_path'].head(20).tolist()

In [None]:
# Preview the first 20 file paths in the test split.
df_test['file_path'].head(20).tolist()

In [None]:
# Distinct phase labels present in the training split, in sorted order.
# The same list drives the per-phase loops further down.
phases = list(set(df_train['phase']))
phases.sort()
phases

In [None]:
# Report how many frames the training split holds for each phase
# (sum of the 'lengths' column over that phase's rows).
for phase in phases:
    n_frames = df_train.loc[df_train['phase'] == phase, 'lengths'].sum()
    print(f'{phase.ljust(15)} {n_frames} frames')

In [None]:
def get_data(approx_count, df, phase, n_jobs=50):
    """Randomly sample roughly `approx_count` frames of one phase across files.

    Every frame of every file belonging to `phase` is selected independently
    with probability ``approx_count / total_frames``, so the number of frames
    returned is only approximately `approx_count`. It can also be lower
    because `get_from_file` drops some of the selected frames.

    Args:
        approx_count: Target number of frames to draw.
        df: DataFrame with at least the columns `phase`, `lengths` (number of
            frames in the file) and `file_path`.
        phase: Value of the `phase` column to restrict the sampling to.
        n_jobs: Number of joblib workers used to read the files in parallel.

    Returns:
        dict mapping each key produced by `get_from_file` to a list with the
        per-frame values concatenated over all sampled files. An empty dict
        is returned (with a warning) when there is nothing to sample, i.e.
        the phase has no frames or no file ended up with a selected frame.
    """
    df_phase = df[df.phase == phase]

    total_frames = df_phase.lengths.sum()
    if total_frames <= 0:
        # Avoids a division by zero and an empty task list further down.
        warnings.warn(f'No frames available for phase {phase!r}; returning no data')
        return {}

    # Per-frame selection probability
    p = approx_count / total_frames

    # Collect list of tasks
    tasks = []

    for _, row in df_phase.iterrows():
        # int() guards against the frame count arriving as a float
        bool_mask = np.random.rand(int(row.lengths)) < p
        if not bool_mask.any():
            continue  # nothing selected from this file, so skip reading it
        tasks.append(joblib.delayed(get_from_file)(row.file_path, row.lengths, bool_mask))

    if not tasks:
        # Without this guard, results[0] below would raise IndexError.
        warnings.warn(f'No frames were selected for phase {phase!r}; returning no data')
        return {}

    # Run tasks in parallel and aggregate results
    results = joblib.Parallel(n_jobs=n_jobs, verbose=10)(tasks)
    output = {key: [] for key in results[0]}

    for result in results:
        for key, values in result.items():
            output[key].extend(values)

    return output


def get_from_file(file_path, length, bool_mask):
    """Read the selected, non-lossy frames from a single CDF file.

    Intended to run as a joblib task dispatched by `get_data`.

    Args:
        file_path: Path of the CDF file to read. The spacecraft prefix used
            to build variable names is derived from it via `get_mms_prefix`.
        length: Number of frames in the file. Currently unused; kept so the
            call signature stays compatible with existing callers.
        bool_mask: Boolean numpy array; frames at the True positions are
            read.

    Returns:
        dict with keys 'dist', 'counts', 'phi', 'theta' and 'E', each a list
        holding one entry per kept frame. All lists are empty when the file's
        'Energy_table_name' attribute contains '12-14' (treated as solar wind
        data and skipped). Frames flagged by the compression-loss variable
        are left out.
    """
    gc.collect()
    # Filters are set here as well because this function runs in worker processes.
    warnings.filterwarnings("ignore", category=DeprecationWarning)

    mms_prefix = get_mms_prefix(file_path)
    result = {'dist': [], 'counts': [], 'phi': [], 'theta': [], 'E': []}

    # The context manager closes the file on every exit path, including the
    # early return below and any exception raised while reading.
    with pycdf.CDF(file_path) as cdf:
        energy_table = str(cdf.attrs['Energy_table_name'])
        if '12-14' in energy_table:
            # Skip solar wind data
            return result

        for i in bool_mask.nonzero()[0]:
            lossy = bool(cdf[f'{mms_prefix}_dis_compressionloss_brst'][i])
            if lossy:
                continue  # lossy frame-- don't include it

            dist = cdf[f'{mms_prefix}_dis_dist_brst'][i, :, :, :]

            # counts = (dist / disterr)^2; divisions that yield NaN
            # (e.g. 0/0) are mapped to zero counts.
            with warnings.catch_warnings():
                warnings.filterwarnings('ignore')
                counts = np.square(dist / cdf[f'{mms_prefix}_dis_disterr_brst'][i, :, :, :])
            counts[np.isnan(counts)] = 0
            counts = np.rint(counts)  # round to nearest int

            result['dist'].append(dist)
            result['counts'].append(counts)
            result['phi'].append(cdf[f'{mms_prefix}_dis_phi_brst'][i, :])
            # theta is read in full rather than indexed by frame
            result['theta'].append(cdf[f'{mms_prefix}_dis_theta_brst'][:])
            result['E'].append(cdf[f'{mms_prefix}_dis_energy_brst'][i, :])

    return result


def get_mms_prefix(file_path):
    """Return the spacecraft tag ('mms1' .. 'mms4') found in `file_path`.

    Tags are checked in ascending order and the first one that occurs
    anywhere in the path is returned. None is returned when no tag is
    present.
    """
    candidates = (f'mms{number}' for number in range(1, 5))
    return next((tag for tag in candidates if tag in file_path), None)

In [None]:
# One entry per output data set: (label, source table, target sample size).
tasks = [
    ('train', df_train, hyperparams['prep']['sample_sizes']['train']),
    ('test', df_test, hyperparams['prep']['sample_sizes']['test'])
]

for label, df, sample_size in tasks:
    # All phases of one data set go into the same HDF5 file, one group per phase.
    hdf_path = f'/mnt/efs/dasilva/compression-cfha/data/samples_{label}_n={sample_size}_nosw.hdf'

    for phase in phases:
        print(f'{label.ljust(10)} - {phase.ljust(20)}', end='')

        output = get_data(sample_size, df, phase)

        # Append mode keeps groups written for other phases. The context
        # manager closes the file even if a write fails.
        with h5py.File(hdf_path, 'a') as hdf:
            # Replace a group left over from a previous run instead of
            # relying on create_group() raising for an existing name.
            if phase in hdf:
                del hdf[phase]
            group = hdf.create_group(phase)

            for key, values in output.items():
                group[key] = values

        print(f'...wrote to {hdf_path}')

        # Release the sampled arrays before the next phase is loaded.
        del output
        gc.collect()