In [None]:
import glob
import os
import random
import warnings

import numpy as np
import pandas as pd
import pylab as plt
import seaborn as sns
from spacepy import pycdf

import moms_fast
import nnet_evaluate
import utils

%matplotlib inline

sns.set_style('darkgrid')

N_EN = 32
N_EN_SHELLS = 2
N_PHI = 32
N_THETA = 16

In [None]:
# Hidden layer width of the per-shell models; also drives the compression
# percentage shown in the plot title.
hidden_layer_size = 50

# Held-out 4D_tail data, used here only to derive f1ct, which multiplies the
# counts before the moments are computed.
test_data = nnet_evaluate.load_test_data('4D_tail')
f1ct = utils.get_f1ct(
    {'4D_tail': test_data},
    ['4D_tail'],
)

In [None]:
# Keep only the rows of the split table that belong to the 4D_tail test set.
df = (
    pd.read_csv('/mnt/efs/dasilva/compression-cfha/data/test_train_split.csv')
    .loc[lambda frame: frame['test_train'] == 'test']
    .loc[lambda frame: frame['phase'] == '4D_tail']
)
df.head()

In [None]:
# MMS1 files from the filtered 4D_tail test split above.
files = [f for f in df.file_path if 'mms1' in f]

# The analysis is pinned to one file so plots are reproducible. Set this to
# None to inspect a random test file instead. Previously a random file was
# printed and then silently replaced by this path, so the printout was wrong.
CDF_FILENAME = '/mnt/efs/dasilva/compression-cfha/data/mms_data/4D_tail/mms1_fpi_brst_l2_dis-dist_20190720043943_v3.3.0.cdf'
cdf_filename = CDF_FILENAME if CDF_FILENAME is not None else random.choice(files)
print(cdf_filename)

In [None]:
# Read the ion distribution, its error and the timestamps. The context manager
# closes the file even if a read fails; the [:] slices are plain arrays that
# stay usable after the file is closed.
with pycdf.CDF(cdf_filename) as cdf:
    dist = cdf['mms1_dis_dist_brst'][:]
    dist_err = cdf['mms1_dis_disterr_brst'][:]
    epoch = cdf['Epoch'][:]
ntime = epoch.size

# Recover counts as (dist / dist_err)**2 for all time steps at once.
# Bins where both are zero give 0/0 = NaN and are set to zero. Only numpy's
# divide/invalid floating-point warnings are silenced, not every warning.
with np.errstate(divide='ignore', invalid='ignore'):
    counts = np.square(dist / dist_err)
counts[np.isnan(counts)] = 0
# float64 matches the dtype of the former np.zeros(...) buffer.
counts = np.rint(counts).astype(np.float64)

# Downstream code slices energy on the last axis: [time, phi, theta, energy].
# (The old buffer used N_PHI for the energy axis, which only worked because N_PHI == N_EN.)
assert counts.shape == (ntime, N_PHI, N_THETA, N_EN), counts.shape

In [None]:
runs = ['4D_tail']
models = {}

# Load one model per (run, energy-shell start index) for the chosen hidden
# layer size. The run name must be passed to load_model as well as used in
# outpath. It was previously hardcoded to '4D_tail', so any other run would
# have loaded the wrong model.
for run in runs:
    for en_index in range(0, N_EN, N_EN_SHELLS):
        models[run, en_index] = nnet_evaluate.load_model(
            run, hidden_layer_size, en_index,
            outpath=(f'/mnt/efs/dasilva/compression-cfha/data/nnet_models'
                     f'/hidden_layer_exp/{run}/')
        )

In [None]:
def reconstruct(counts, models, run=None, shell_width=None):
    """Reconstruct counts shell-by-shell with per-energy-shell models.

    For every index ``j`` of the first axis and every energy shell (a slice of
    ``shell_width`` consecutive entries along the last axis), the model stored
    under ``models[run, shell_start]`` is applied to that shell. Each
    reconstructed shell is rescaled so its mean matches the original shell's
    mean. Shells whose original mean is zero are left at zero.

    Parameters
    ----------
    counts : array_like
        4-D array indexed as ``counts[time, :, :, energy]``.
    models : dict
        Maps ``(run, shell_start)`` to a callable. It is invoked as
        ``model([shell])`` and must return an object with a ``.numpy()``
        method whose first element has the shell's shape.
    run : str, optional
        Run key used for model lookup. Defaults to the single run present in
        ``models``' keys (previously read from the notebook global ``runs``).
    shell_width : int, optional
        Energy bins per shell. Defaults to ``N_EN_SHELLS``.

    Returns
    -------
    np.ndarray
        Floating-point array with the same shape as ``counts``. Float input
        keeps its dtype; integer input is promoted to float64.

    Raises
    ------
    ValueError
        If ``run`` is omitted and ``models`` does not hold exactly one run.
    """
    if shell_width is None:
        shell_width = N_EN_SHELLS
    if run is None:
        run_names = {key[0] for key in models}
        if len(run_names) != 1:
            raise ValueError(
                f'Cannot infer run from models (found {sorted(run_names)}); '
                'pass run= explicitly.'
            )
        run = run_names.pop()

    counts = np.asarray(counts)
    # An integer buffer would truncate the model output and make the in-place
    # float rescaling below raise, so always reconstruct into a float array.
    out_dtype = counts.dtype if np.issubdtype(counts.dtype, np.floating) else np.float64
    counts_recon = np.zeros(counts.shape, dtype=out_dtype)
    n_energy = counts.shape[-1]

    for j in range(counts.shape[0]):
        for start in range(0, n_energy, shell_width):
            stop = start + shell_width
            model = models[run, start]
            orig_shell = counts[j, :, :, start:stop]

            avg_orig = orig_shell.mean()
            if avg_orig == 0:
                # Empty shell: the output is already zero, so skip the model call.
                continue

            # Basic slicing gives a view, so writes go straight into counts_recon.
            target = counts_recon[j, :, :, start:stop]
            target[...] = model([orig_shell]).numpy()[0]

            avg_recon = target.mean()
            if avg_recon > 0:
                # Preserve the shell's mean count.
                target *= avg_orig / avg_recon

    return counts_recon

In [None]:
# Apply the per-shell models to every time step of the loaded counts.
counts_recon = reconstruct(counts=counts, models=models)

In [None]:
# Moments for the original and the reconstructed counts, one result per time
# step (the plotting cell indexes each result by moment name). Both are
# scaled by f1ct first.
moms_true, moms_recon = (
    [moms_fast.fast_moments(f1ct * frame) for frame in stack]
    for stack in (counts, counts_recon)
)

In [None]:
# Named plot_vars rather than vars, which would shadow the builtin vars().
plot_vars = ['n', 'vx', 'vy', 'vz', 'txx', 'tyy', 'tzz']
fig, axes = plt.subplots(len(plot_vars), 1, figsize=(15, 4 * len(plot_vars)))

for ax, var in zip(axes, plot_vars):
    true_vals = [d[var] for d in moms_true]
    recon_vals = [d[var] for d in moms_recon]
    ax.plot(epoch, true_vals, label=f'{var} True')
    ax.plot(epoch, recon_vals, label=f'{var} Reconstructed')
    ax.legend()
    if var == 'n':
        ax.set_ylim([0, 1.1 * np.max(true_vals)])
        ax.set_ylabel('n ($cm^{-3}$)', fontsize=16)
    elif var.startswith('v'):
        ax.set_ylabel(f'{var} (km/s)', fontsize=16)
    elif var.startswith('t'):
        ax.set_ylabel(f'{var} (eV)', fontsize=16)

# Hidden layer size relative to one model's input shell (phi x theta x shell width),
# computed from the grid constants instead of a hardcoded 32*16*2.
compression_pct = 100 * hidden_layer_size / (N_PHI * N_THETA * N_EN_SHELLS)
fig.suptitle(f'Dimensionality Reduction => {compression_pct:.1f}%')
fig.tight_layout()
os.makedirs('plots', exist_ok=True)
# splitext keeps the full version suffix (e.g. "_v3.3.0"). split('.')[0] cut it
# to "_v3", so different file versions overwrote each other's plot.
stem = os.path.splitext(os.path.basename(cdf_filename))[0]
fig.savefig(os.path.join('plots', stem + '.png'))