In [None]:
# Data handling
import pickle 
import numpy as np
import pandas as pd
from operator import itemgetter

# Plotting
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt 
from matplotlib.gridspec import GridSpec
from moxie.data.utils_ import load_data

# SciPy utilities
from scipy import interpolate
from scipy.signal import savgol_filter
from scipy.stats import truncnorm
# Progress bars
from tqdm.notebook import tqdm  

# Font-size presets shared by every figure in this notebook
SMALL_SIZE = 14
MEDIUM_SIZE = 16
BIGGER_SIZE = 22

# Apply all presets in one go; each key is the '<group>.<setting>' form of the
# corresponding plt.rc(group, setting=value) call.
plt.rcParams.update({
    'font.size': BIGGER_SIZE,         # default text size
    'axes.titlesize': BIGGER_SIZE,    # axes title
    'axes.labelsize': BIGGER_SIZE,    # x and y axis labels
    'xtick.labelsize': SMALL_SIZE,    # x tick labels
    'ytick.labelsize': SMALL_SIZE,    # y tick labels
    'legend.fontsize': MEDIUM_SIZE,   # legend entries
    'figure.titlesize': BIGGER_SIZE,  # figure-level title
})
%matplotlib widget

In [None]:
# Load the train / validation / test splits
train_data, val_data, test_data = load_data(
    dataset_choice='SANDBOX_NO_VARIATIONS',
    file_loc='../../../moxie/data/processed/pedestal_profiles_ML_READY_ak_5052022_uncerts_mask.pickle',
    elm_timings=False,
)
# Each split is ordered as: profiles, mps, masks, psis, rmids, trainids, uncerts,
# so the ids are the second-to-last entry.
train_ids = np.array(train_data[-2])
val_ids = np.array(val_data[-2])
test_ids = np.array(test_data[-2])

# ELM timings per dataset
# NOTE(review): pickle.load can execute arbitrary code; only open trusted files.
with open('../../data/processed/new_elm_timings_idxs.pickle', 'rb') as file:
    GLOBAL_ELM_TIMING_DICT = pickle.load(file)

# The pulse idxs are a list of lists corresponding to the indexes in the dataset to take.
_elm_fields_of = itemgetter('pulse_idx', 'elm_percentages')
train_new_pulse_idxs, train_elm_percentages = _elm_fields_of(GLOBAL_ELM_TIMING_DICT['train'])
val_new_pulse_idxs, val_elm_percentages = _elm_fields_of(GLOBAL_ELM_TIMING_DICT['val'])
test_new_pulse_idxs, test_elm_percentages = _elm_fields_of(GLOBAL_ELM_TIMING_DICT['test'])

# NOTE(review): same pickle caveat as above.
with open('../../data/processed/neseps.pickle', 'rb') as file:
    GLOBAL_NESEPS_DICT = pickle.load(file)

# These go with the indexes of the pulse idxs loaded above
train_neseps = GLOBAL_NESEPS_DICT['train']['nesep']
val_neseps = GLOBAL_NESEPS_DICT['val']['nesep']
test_neseps = GLOBAL_NESEPS_DICT['test']['nesep']

def unique(sequence):
    """Return the items of `sequence` as a list with duplicates removed.

    The first occurrence of each item is kept and the original order is
    preserved. Items must be hashable because membership is tracked in a set.
    """
    already_seen = set()
    deduped = []
    for item in sequence:
        if item in already_seen:
            continue
        already_seen.add(item)
        deduped.append(item)
    return deduped

# Pulse number of every entry (the integer before the first '/' of its id), in
# dataset order, plus the distinct pulses in first-appearance order.
# NOTE(review): presumably paired with the per-pulse index lists loaded above — confirm.
train_pulse_order = [int(slice_id.split('/', 1)[0]) for slice_id in train_ids]
train_pulses_ordered_set = unique(train_pulse_order)

val_pulse_order = [int(slice_id.split('/', 1)[0]) for slice_id in val_ids]
val_pulses_ordered_set = unique(val_pulse_order)

test_pulse_order = [int(slice_id.split('/', 1)[0]) for slice_id in test_ids]
test_pulses_ordered_set = unique(test_pulse_order)

# JET PDB
machine_param_order = [
    'Q95', 'RGEO', 'CR0', 'VOLM', 'TRIU', 'TRIL', 'ELON',
    'POHM', 'IPLA', 'BVAC', 'NBI', 'ICRH', 'ELER',
]
JET_PDB = pd.read_csv('../../../moxie/data/processed/jet-pedestal-database.csv')
# Keep only rows that satisfy every condition below.
PULSE_DF_SANDBOX = JET_PDB[
    (JET_PDB['FLAG:HRTSdatavalidated'] > 0)
    & (JET_PDB['shot'] > 80000)
    & (JET_PDB['Atomicnumberofseededimpurity'].isin([0, 7]))
    & (JET_PDB['FLAG:DEUTERIUM'] == 1.0)
    & (JET_PDB['FLAG:Kicks'] == 0.0)
    & (JET_PDB['FLAG:RMP'] == 0.0)
    & (JET_PDB['FLAG:pellets'] == 0.0)
]

In [None]:
def collect_set(dataset, set_idxs, set_pulse_order, set_elm_percentages, set_neseps,
                skip_pulses=(83294,), n_std=1.0):
    """Collect per-slice inputs and nesep targets for one dataset split.

    For every pulse, the slices listed in its index list are selected, slices
    with a NaN nesep are dropped, and only slices whose nesep lies within
    `n_std` standard deviations of that pulse's mean nesep are kept.

    Parameters
    ----------
    dataset : sequence
        The split as returned by `load_data`; only entry 1 (`mps`) is read.
    set_idxs : sequence of sequences of int
        For each pulse, the indexes into `mps` / `set_neseps` to use.
    set_pulse_order : sequence of int
        Pulse number for each entry of `set_idxs` (same length and order).
    set_elm_percentages :
        Unused; kept so existing callers do not break.
    set_neseps : array-like
        nesep value per slice, indexable by the index lists in `set_idxs`.
        Assumed to be a NumPy array aligned with `mps` — TODO confirm.
    skip_pulses : container of int, optional
        Pulse numbers to leave out entirely. Defaults to the previously
        hard-coded pulse 83294.
    n_std : float, optional
        Half-width of the acceptance band in units of the per-pulse standard
        deviation. Defaults to 1.

    Returns
    -------
    x_data : numpy.ndarray
        One row per kept slice, taken from `mps`.
    target_data : list
        nesep of each kept slice, in the same order as the rows of `x_data`.
    """
    mps = dataset[1]

    iterator = tqdm(zip(set_idxs, set_pulse_order), total=len(set_pulse_order))

    x_data, target_data = [], []

    for pulse_idxs, pulse_number in iterator:
        iterator.set_description_str('{} #SLICES {}'.format(pulse_number, len(pulse_idxs)))
        if len(pulse_idxs) == 0 or pulse_number in skip_pulses:
            continue

        pulse_neseps, pulse_mps = set_neseps[pulse_idxs], mps[pulse_idxs]

        not_nans = ~np.isnan(pulse_neseps)
        pulse_neseps, pulse_mps = pulse_neseps[not_nans], pulse_mps[not_nans]
        if len(pulse_neseps) == 0:
            # Every nesep of this pulse was NaN; mean/std would be NaN too.
            continue

        mean, std = pulse_neseps.mean(), pulse_neseps.std()
        # Keep the slices INSIDE the band. `<=` (not `<`) so that a pulse with a
        # single slice or identical values (std == 0) is not discarded.
        within_cutoff = abs(pulse_neseps - mean) <= n_std * std

        # Extend both lists slice-by-slice so inputs and targets stay aligned
        # and the final array is rectangular rather than ragged.
        x_data.extend(pulse_mps[within_cutoff])
        target_data.extend(pulse_neseps[within_cutoff])

    return np.array(x_data), target_data
    
