# Creating the dataset

We are using data from JET. 

### Starting data 

- HRTS profiles 
    - Electron density and temperature, $n_e$, $T_e$
    - HRTS Line of Sight coordinates, $R_{HRTS}$
- Machine parameters
    - $R$ [m], plasma major radius
    - $a$ [m], plasma minor radius 
    - $V_P$ [m$^3$], total plasma volume enclosed by LCFS
    - $\delta_u,\delta_l$ [-], upper and lower triangularities  
    - $\kappa$ [-], elongation, or ratio of height to width of plasma
    - $P_{OHM}$ [W], ohmic power of the plasma
    - $I_P$ [A], total plasma current enclosed by LCFS 
    - $B_T$ [T], total toroidal magnetic field
    - $P_{NBI}$ [W], total neutral beam injected power into the plasma
    - $P_{ICRH}$ [W], total ion cyclotron resonance heating (ICRH) power applied to the plasma
    - $\Gamma$ [e/s], fuelling rate of the main isotope (gas puff)
    - $q_{cycl}$ [-], cylindrical safety factor, calculated as $q_{cycl} = \frac{1+2\kappa^2}{2}\,\frac{2\pi a^2 B_T}{\mu_0 R I_P}$
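For reference, here is a minimal vectorized NumPy sketch of the $q_{cycl}$ formula above, with SI units throughout. The function name `q_cycl` is ours, not the notebook's. It keeps the $(1+2\kappa^2)/2$ shaping factor exactly as written above. The more commonly quoted cylindrical safety factor uses $(1+\kappa^2)/2$, so it is worth checking which one is intended.

```python
import numpy as np

MU_0 = 1.25663706e-6  # vacuum permeability [H/m], same value as used in the notebook


def q_cycl(R, a, kappa, B_T, I_P):
    """Cylindrical safety factor following the formula above (hypothetical helper).

    R [m], a [m], kappa [-], B_T [T], I_P [A]; accepts scalars or NumPy arrays,
    so it can be applied to whole machine-parameter columns at once.
    """
    R, a, kappa, B_T, I_P = (np.asarray(v, dtype=float) for v in (R, a, kappa, B_T, I_P))
    shaping = (1.0 + 2.0 * kappa**2) / 2.0  # shaping factor as written in the document
    return shaping * (2.0 * np.pi * a**2 * B_T) / (mu_0_times(R, I_P))


def mu_0_times(R, I_P):
    """Denominator R * I_P * mu_0 of the cylindrical q."""
    return R * I_P * MU_0
```

Because the inputs are broadcast, passing arrays of $R$, $a$, $\kappa$, $B_T$ and $I_P$ returns one $q_{cycl}$ per slice.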
    
#### Profile Information 

- Pulses/time windows found in the JET pedestal database
- Subset of pulses: 
    - ILW (shot number > 80000)
    - HRTS data validated
    - No impurity seeding, or nitrogen seeding only
    - Deuterium fuelling
    - No kicks, RMPs, or pellets
    


After the filtering above: 
- Total number of pulses/time windows:
    - 1248
   
### Further slice filtering 

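#### ELM percentage

Keep only slices for which an ELM percentage could be calculated, using ELM timings from the JET pedestal database. The ELM percentage is the fraction of the inter-ELM period that has elapsed at the slice time.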
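Here is a minimal sketch of the ELM percentage for a single slice. It assumes the ELM times of a pulse are available as a 1-D array, and the helper name `elm_phase` is hypothetical. It follows the same definition as `find_elm_percent` below, but it uses `np.searchsorted` instead of boolean masks. It also differs at one edge case: a slice that coincides exactly with an ELM gets 0% here, whereas `find_elm_percent` skips over that ELM.

```python
import numpy as np


def elm_phase(elm_times, t):
    """Fraction of the inter-ELM period elapsed at time t (hypothetical helper).

    Returns NaN when t is not bracketed by two ELMs.
    """
    elm_times = np.sort(np.asarray(elm_times, dtype=float))
    i = np.searchsorted(elm_times, t, side="right")  # index of the first ELM strictly after t
    if i == 0 or i == len(elm_times):
        return np.nan  # no ELM before t, or no ELM after t
    t_last, t_next = elm_times[i - 1], elm_times[i]
    return (t - t_last) / (t_next - t_last)


# elm_phase([0.0, 0.5, 1.5], 1.0) -> 0.5
```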


After filtering: 
- Total number of slices from 608 pulses: 
    - 23499
    
#### $n_{e, sep}$

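- $n_{e,sep}$ is estimated for each slice as the electron density where the HRTS $T_e$ profile crosses 100 eV (`find_separatrix`):
    - A pedestal top/bottom search on the smoothed $n_e$ profile gives an initial separatrix position.
    - That position is nudged toward the 100 eV point.
    - $n_e$ and $R$ are then linearly interpolated between the two HRTS points that bracket $T_e = 100$ eV.
    - Slices whose $T_e$ never drops below 100 eV are skipped.
- The estimate is not perfect, so within each pulse we remove slices whose $n_{e,sep}$ estimate lies more than 1 standard deviation from that pulse's mean $n_{e,sep}$.

- Total number of slices remaining: 
    - 16111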
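The last step of the $n_{e,sep}$ estimate and the per-pulse outlier cut can be sketched as follows. This is a simplified stand-in for `get_idx_for_linear_interp`/`get_weights` and `clean_used_dict`, not the notebook's code, and the function names are hypothetical. It assumes the $T_e$ profile has already been made non-increasing, as `make_dataset` does with `np.minimum.accumulate`. Under that assumption, taking the first point below 100 eV together with its inner neighbour should give the same bracketing pair.

```python
import numpy as np

TE_SEP = 100.0  # eV, the T_e value treated as the separatrix


def nesep_from_crossing(ne, te, x, te_sep=TE_SEP):
    """Linearly interpolate n_e and R where T_e crosses te_sep (hypothetical helper).

    Assumes te is non-increasing along x and spans te_sep.
    """
    ne, te, x = (np.asarray(v, dtype=float) for v in (ne, te, x))
    below = np.flatnonzero(te < te_sep)
    if below.size == 0 or below[0] == 0:
        raise ValueError("T_e profile does not cross te_sep")
    i_r = below[0]  # first point below te_sep
    i_l = i_r - 1   # last point at or above te_sep
    w_r = (te[i_l] - te_sep) / (te[i_l] - te[i_r])  # weight of the outer (right) point
    w_l = 1.0 - w_r
    return w_l * ne[i_l] + w_r * ne[i_r], w_l * x[i_l] + w_r * x[i_r]


def keep_within_one_sigma(neseps):
    """Mask of slices whose n_e,sep lies strictly within 1 std of the pulse mean."""
    arr = np.asarray(neseps, dtype=float)
    return np.abs(arr - arr.mean()) < arr.std()
```

Note the strict `<` in the 1σ cut. A pulse with a single slice has a standard deviation of 0, so it loses that slice. This is also what `clean_used_dict` does.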
    
## Profile padding 

The longest profile has 19 points, so all profiles are padded to that length at the start of each profile.
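`pad_lists` below pads on the left, at the start of each profile. Profiles repeat their first value (`mode='edge'`), while radii are padded with 0 and masks with `False`. The minimal, hypothetical helper `left_pad` below shows the same left-padding with `np.pad` for arrays of any dimensionality:

```python
import numpy as np


def left_pad(arr, pad_length=19, mode="edge"):
    """Left-pad the last axis of `arr` to `pad_length` points (hypothetical helper)."""
    arr = np.asarray(arr)
    num_missing = pad_length - arr.shape[-1]
    if num_missing < 0:
        raise ValueError("profile is longer than pad_length")
    # pad nothing on the leading axes, num_missing points before the last axis
    pad_width = [(0, 0)] * (arr.ndim - 1) + [(num_missing, 0)]
    return np.pad(arr, pad_width, mode=mode)


# left_pad([[3., 2., 1.]], 5) -> [[3., 3., 3., 2., 1.]]
```

With `mode="constant"`, `np.pad` fills with 0, which becomes `False` for boolean masks.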


In [1]:
from tqdm.notebook import tqdm 
import matplotlib.pyplot as plt  # used by the optional plots in pedestal_top / find_separatrix
import numpy as np
import pickle
from scipy import interpolate
from scipy.signal import savgol_filter
from scipy.stats import norm
import pandas as pd
import math
from collections.abc import Iterable
import random 

In [2]:
dataset_choice = 'SANDBOX_NO_VARIATIONS'
file_loc = '../../../moxie/data/processed/pedestal_profiles_ML_READY_ak_5052022_uncerts_mask.pickle'
with open(file_loc, 'rb') as file:
    massive_dict = pickle.load(file)
full_dict = massive_dict[dataset_choice]
massive_dict = {}  # release the other dataset variants from memory
# ELM timings per pulse, used to compute the ELM percentage of each slice
with open('../../data/raw/new_elm_timings_catch.pickle', 'rb') as file:
    JET_ELM_TIMINGS = pickle.load(file)

# JET pedestal database, restricted to the pulse subset listed above
JET_PDB = pd.read_csv('../../../moxie/data/processed/jet-pedestal-database.csv')
PULSE_DF_SANDBOX = JET_PDB[
    (JET_PDB['FLAG:HRTSdatavalidated'] > 0)
    & (JET_PDB['shot'] > 80000)  # ILW pulses
    & (JET_PDB['Atomicnumberofseededimpurity'].isin([0, 7]))  # no seeding (0) or nitrogen (7)
    & (JET_PDB['FLAG:DEUTERIUM'] == 1.0)
    & (JET_PDB['FLAG:Kicks'] == 0.0)
    & (JET_PDB['FLAG:RMP'] == 0.0)
    & (JET_PDB['FLAG:pellets'] == 0.0)
]

raw = full_dict['all_dict']['raw']
dataset = raw['profiles'], raw['real_space_radii'], raw['pulse_time_ids'], raw['uncerts'], raw['controls']
profiles, rmids, ids, uncerts, mps = dataset


In [3]:
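def replace_q95_with_qcly(mp_set, mu, var):
    """Overwrite column 0 of the standardized machine-parameter array with q_cycl.

    Column indices used: 1 = R, 2 = a, 6 = kappa, 8 = I_P, 9 = B_T.
    de_standardize / standardize must be defined elsewhere.
    """
    mp_set = de_standardize(mp_set, mu, var)
    mu_0 = 1.25663706e-6  # magnetic constant [H/m]
    mp_set[:, 0] = ((1 + 2*mp_set[:, 6]**2) / 2.0) * (2*mp_set[:, 9]*np.pi*mp_set[:, 2]**2) / (mp_set[:, 1] * mp_set[:, 8] * mu_0)
    mp_set = standardize(mp_set, mu, var)
    return mp_set

def extend_used_dict(pulse_num, dict_to_update, initialize=False, **kwargs):
    """Append each keyword value to dict_to_update[pulse_num][key], creating the lists on first use."""
    if pulse_num not in dict_to_update:
        initialize = True
        dict_to_update[pulse_num] = {}  # create the pulse entry once, not once per key
    for key, item in kwargs.items():
        if initialize:
            dict_to_update[pulse_num][key] = []
        dict_to_update[pulse_num][key].append(item)
    return dict_to_update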

In [4]:
# boilerplate stuff
import math
from collections.abc import Iterable

class RunningStats:
    def __init__(self):
        self.n = 0
        self.old_m = 0
        self.new_m = 0
        self.old_s = 0
        self.new_s = 0

    def clear(self):
        self.n = 0

    def push(self, x):
        if isinstance(x, Iterable):
            for v in x:
                self.push(v)
            return

        self.n += 1

        if self.n == 1:
            self.old_m = self.new_m = x
            self.old_s = 0
        else:
            self.new_m = self.old_m + (x - self.old_m) / self.n
            self.new_s = self.old_s + (x - self.old_m) * (x - self.new_m)

            self.old_m = self.new_m
            self.old_s = self.new_s

    def mean(self):
        return self.new_m if self.n else 0.0

    def variance(self):
        return self.new_s / (self.n - 1) if self.n > 1 else 0.0

    def standard_deviation(self):
        return math.sqrt(self.variance())

    def __repr__(self):
        return f'n: {self.n}, mean: {self.mean()}, var: {self.variance()}, sd: {self.standard_deviation()}'
    
def standardize_signal(x, trim_zeros=True):
    if trim_zeros:
        x_in = np.trim_zeros(x)
    else:
        x_in = x
    rs = RunningStats() # numpy.std goes to inf so do it by hand
    rs.push(x_in)
    return (x - rs.mean()) / rs.standard_deviation()

def find_nearest(query, data, idx_must_be='any'):
    if idx_must_be not in ['any', 'smaller', 'greater']:
        raise ValueError('"idx_must_be" must be in [any, smaller, greater]')
    # Assumes `data` is sorted in increasing order.
    # If `query` falls between two elements, `idx_must_be` selects which neighbour is returned.
    if query < data[0]:
        if idx_must_be == 'smaller':
            raise ValueError('Cannot return smaller value as "query" is smaller than any datapoint')
        return 0
    elif query > data[-1]:
        if idx_must_be == 'greater':
            raise ValueError('Cannot return greater value as "query" is greater than any datapoint')
        return len(data)-1
    else:
        idx = np.searchsorted(data, query, side='left')
        if data[idx] == query:
            return idx
        elif idx_must_be == 'any':
            if abs(query - data[idx-1]) < abs(query - data[idx]):
                return idx-1
            else:
                return idx
        elif idx_must_be == 'greater':
            return idx
        elif idx_must_be == 'smaller':
            return idx - 1
        else:
            raise ValueError('invalid argument for "idx_must_be"')
            
def get_nearest_weighted_idx(query, data, sort=True):
    if sort:
        unsorted = np.array(data)
        sortidx = np.argsort(data)
        sort_reverse = {i: sortidx[i] for i in range(len(data))}

        data = np.array(data)[sortidx]
    data = list(data)
    if query in data: # exact match
        idx = [(1, data.index(query))] # weight, index
    elif query < data[0]: # before first
        idx = [(1, 0)] 
    elif query > data[-1]: # after last
        idx = [(1, len(data)-1)]
    else:
        # get nearest two elements
        i_r = np.searchsorted(data, query)
        i_l = i_r - 1
        dist = data[i_r] - query + query - data[i_l]
        idx = [(1-(data[i_r] - query)/dist, i_r), (1-(query - data[i_l])/dist, i_l)]

    # convert back to unsorted idx
    if sort:
        idx = [(w, sort_reverse[i]) for (w, i) in idx]
    return idx
def pedestal_top(p, x, plot_result=False):
    # standardize signal
    p = standardize_signal(p, trim_zeros=True)

    # interp signal to Nx=50
    f_interp = interpolate.interp1d(x, p)
    x_h = np.linspace(x[0], x[-1], 50)  # interpolate to 50 (evenly spaced) points
    p_h = f_interp(x_h)

    # smooth with savgol filter
    p_s = savgol_filter(p_h, window_length=11, polyorder=3)

    # steepest (most negative) gradient, so we're inside the pedestal
    p_s_grad = np.gradient(p_s)
    min_i = np.argmin(p_s_grad)


    # search from pedestal region outward in 2nd derivatives
    p_s_grad2 = np.gradient(p_s_grad)
    p_s_grad2 = savgol_filter(p_s_grad2, window_length=11, polyorder=3)  # aggressively smooth as well
    p_s_grad2 = standardize_signal(p_s_grad2, trim_zeros=True)
    # standardized so that the +/-0.5 cutoffs below are in units of standard deviations

    # go to the left from the steepest-gradient point
    sd_cutoff = -.5
    start_cut_early = False  # set once x'' drops below sd_cutoff; stop searching when it comes back above
    min_val = p_s_grad2[min_i]
    top_i = min_i
    for i in reversed(range(0, min_i)):
        if p_s_grad2[i] < min_val:
            min_val = p_s_grad2[i]
            top_i = i
            if min_val < sd_cutoff:
                start_cut_early = True
        elif start_cut_early and p_s_grad2[i] > sd_cutoff:
            break

    # found our x for top
    # same procedure for bottom
    sd_cutoff = .5
    start_cut_early = False  # set once x'' rises above sd_cutoff; stop searching when it falls back below
    max_val = p_s_grad2[min_i]
    bottom_i = min_i

    for i in range(min_i+1, x_h.shape[0]):
        if p_s_grad2[i] > max_val:
            max_val = p_s_grad2[i]
            bottom_i = i
            if max_val > sd_cutoff:
                start_cut_early = True
        elif start_cut_early and p_s_grad2[i] < sd_cutoff:
            break

    top, bottom = x_h[top_i], x_h[bottom_i]

    if plot_result:
        plt.plot(x_h, p_s)
        plt.axvline(top, color='red', label='estimated top')
        plt.axvline(bottom, color='blue', label='estimated bottom')
        plt.legend()
        plt.show()
    return top, bottom

def find_separatrix(ne, te, x, plot_result=False):
    top_x, bottom_x = pedestal_top(ne, x, plot_result=False)
    
    # initial estimate
    x_sep = (1/4) * top_x + (3/4) * bottom_x
    x_initial = x_sep
    # scan area around x_sep
    # first, interp to higher space
    x_h = np.linspace(x[0], x[-1], 200)  # interpolate to 200 (evenly spaced) points
    
    f_ne_interp = interpolate.interp1d(x, ne)
    ne_h = f_ne_interp(x_h)
    ne_s = savgol_filter(ne_h, window_length=11, polyorder=3)
    f_te_interp = interpolate.interp1d(x, te)
    te_h = f_te_interp(x_h)
    te_s = savgol_filter(te_h, window_length=11, polyorder=3)
    
    # closest x
    s_i = find_nearest(x_sep, x_h)
    # scan inwards/outwards
    n_step = 10
    te_target = 100 # 100 eV
    te_distr = norm(te_target, 10)  # distribution to weigh new points against
    weight_func = lambda x: x ** 0.05
    
    def te_sep():
        idx = get_nearest_weighted_idx(x_sep, x_h, sort=False)
        val_out = 0
        for w, w_i in idx:
            val_out += w * te_s[w_i]
        return val_out
    
    for i in range(1, n_step+1):
        i_left = s_i - i
        i_right = s_i + i
        if i_left < 0 or i_right > x_h.shape[0] - 1:
            break  # reached outside the grid on 1 side
        weight_pos = i / n_step  # [1/n, 1] adjustment based on how far the value is 'good'
        pdf_l = te_distr.pdf(te_s[i_left])
        adjust_l = weight_pos * pdf_l
        adjust_l = weight_func(adjust_l)
        if te_s[i_left] > te_target and te_sep() < te_target:  # on the left is above 100 --> we need to move left
            x_sep = (x_sep + x_h[i_left] * adjust_l) / (1 + adjust_l)

        pdf_r = te_distr.pdf(te_s[i_right])
        adjust_r = weight_pos * pdf_r
        adjust_r = weight_func(adjust_r)
        if te_s[i_right] < te_target and te_sep() > te_target:  # on the right is below 100 --> we need to move right
            x_sep = (x_sep + x_h[i_right] * adjust_r) / (1 + adjust_r)
    
    # closest_x_index = np.argmin(abs(x_h - x_sep))
    # from_fit = x_h[closest_x_index], te_h[closest_x_index], ne_h[closest_x_index]
    interp_idx_l, interp_idx_r = get_idx_for_linear_interp(x_sep, x, te)
    weights_r, weights_l = get_weights(te, interp_idx_l, interp_idx_r)
    tesep_estimation = weights_l*te[interp_idx_l] + weights_r*te[interp_idx_r]
    nesep_estimation = weights_l*ne[interp_idx_l] + weights_r*ne[interp_idx_r]
    rsep_estimation = weights_l*x[interp_idx_l] + weights_r*x[interp_idx_r]
    if nesep_estimation < 0: 
        plot_result = True
    if plot_result:
        fig, axs=plt.subplots(1, 2, figsize=(6,3))
        axs[0].plot(x_h, te_s)
        axs[0].scatter(x, te)
        axs[0].scatter(x[interp_idx_l], te[interp_idx_l], color='orange')
        axs[0].scatter(x[interp_idx_r], te[interp_idx_r], color='orange')
        
        axs[0].axvline(x_initial, color='red', label='initial estimate')
        axs[0].axvline(x_sep, color='green', label='final estimate')
        axs[0].axhline(100, color='black', label='100eV')
        axs[0].scatter(rsep_estimation, tesep_estimation, color='black',marker='*', s=250)

        axs[1].plot(x_h, ne_s)
        axs[1].scatter(x, ne)
        axs[1].scatter(rsep_estimation, nesep_estimation, color='black',marker='*', s=250)
        axs[1].axvline(x_initial, color='red', label='initial prior')
        axs[1].axvline(x_sep, color='green', label='final prior')
        axs[1].axvline(rsep_estimation, color='black', ls='--', label='final est')
        axs[1].scatter(x[interp_idx_l], ne[interp_idx_l], color='orange')
        axs[1].scatter(x[interp_idx_r], ne[interp_idx_r], color='orange')
        axs[1].legend()
        
        plt.show()
    
    return tesep_estimation, nesep_estimation, rsep_estimation

def get_idx_for_linear_interp(x_sep, x, te): 
    closest_point = np.argmin(abs(x - x_sep))
    if (te < 100).sum() == 0: 
        closest_point = np.argmin(abs(te - 100))
        idx_r = closest_point
        idx_l = idx_r - 1
   
    elif te[closest_point] < 100: 
        idx_r = closest_point
        idx_l = closest_point - 1
        while te[idx_l] < 100: 
            idx_r = idx_l
            idx_l = idx_l - 1
    else: 
        idx_r = closest_point + 1
        idx_l = closest_point
        while te[idx_r] > 100:
            idx_l = idx_r
            idx_r = idx_l + 1
            
    return idx_l, idx_r 
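def get_weights(te, idx_l, idx_r, query=100): 
    # Linear-interpolation weights at te == query, returned as (weight_r, weight_l):
    # the first entry multiplies te[idx_r], the second multiplies te[idx_l].
    dist = te[idx_r] - query + query - te[idx_l]
    weights = (1-(te[idx_r] - query)/dist, 1-(query - te[idx_l])/dist) 
    return weights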
def find_elm_percent(pulse_num, time): 
    try: 
        pulse_elm_timings_frass = np.array(JET_ELM_TIMINGS[pulse_num])
    except KeyError as e:
        return np.nan
    diff = pulse_elm_timings_frass - time
    try:
        time_last_elm = pulse_elm_timings_frass[diff < 0][-1]
        time_next_elm = pulse_elm_timings_frass[diff > 0][0]
        elm_percent = (time - time_last_elm) / (time_next_elm - time_last_elm)
    except IndexError as e:
        elm_percent = np.nan
    return elm_percent

def get_lorenzo_pred(pulse_num, pulse_time):
    """Return (n_e,sep from experimental data, n_e,sep from fit) [10^19 m^-3] from the pedestal
    database entry of pulse_num whose time window [t1, t2] contains pulse_time."""
    cols = ['neseparatrixfromexpdata10^19(m^-3)', 'neseparatrixfromfit10^19(m^-3)']
    pulse_time = float(pulse_time)
    JPDB_pulse = PULSE_DF_SANDBOX[PULSE_DF_SANDBOX['shot'] == pulse_num]
    if len(JPDB_pulse) > 1:
        # several time windows for this pulse: keep the one containing pulse_time
        in_window = (JPDB_pulse['t1'] < pulse_time) & (JPDB_pulse['t2'] > pulse_time)
        JPDB_pulse = JPDB_pulse[in_window]
    nesep_exp, nesep_lor = JPDB_pulse[cols].values[0]
    return nesep_exp, nesep_lor

In [5]:
def make_dataset(all_dataset): 
    """
    This organizes all the data by pulse.  
    We can then split by different pulses for training/val/test later. 
    
    Parameters
    ==========
    the whole dataset! Found earlier, but a 
    
    Returns
    =======
    
    used_dict: a dictionary containing all data by pulse number: 
        used_dict[123456]['profiles'] corresponds to the profiles of pulse
        with keys: 'profiles', 'neseps', 'ids', 'mps', 'elm_perc', 'lor_val', 'mask'
    """
    
    profiles, rmids, ids, uncerts, mps = all_dataset 
    iterator = tqdm(range(len(profiles)))
    used_dict = {}
    
    for idx in iterator: 
        name = ids[idx]
        pulse_num, time = name.split('/')
        elm_perc = find_elm_percent(int(pulse_num), float(time))
        if np.isnan(elm_perc): 
            continue
        else: 
            # used_elm_percents.append(elm_perc)
            pass
        _, lor_val = get_lorenzo_pred(int(pulse_num), float(time))

        original = profiles[idx][0], profiles[idx][1], uncerts[idx][0], uncerts[idx][1], rmids[idx]
        ne, te, dne, dte, x = original
        # Enforce monotonically non-increasing profiles: n_e is replaced by its running minimum,
        # and only points where T_e already equals its running minimum are kept.
        new_ne, new_te = np.minimum.accumulate(ne), np.minimum.accumulate(te)
        keep_idx = np.where(new_te == te)

        ne, te, dne, dte, x = new_ne[keep_idx], new_te[keep_idx], dne[keep_idx], dte[keep_idx], x[keep_idx]
        logical_bool_mask = np.logical_and(ne > 0, te > 0)  # drop non-physical (non-positive) values
        logical_bool_mask = np.logical_and(logical_bool_mask, dte > 0)  # uncertainties must be positive too
        logical_bool_mask = np.logical_and(logical_bool_mask, dne > 0)
        logical_bool_mask = np.logical_and(logical_bool_mask, dte < 3000)  # drop points with unrealistically large T_e uncertainty
        ne, te, dne, dte, x = ne[logical_bool_mask], te[logical_bool_mask], dne[logical_bool_mask], dte[logical_bool_mask], x[logical_bool_mask]
        if (te < 100).sum() == 0:
            continue
        else: 
            try: 
                estimations = find_separatrix(ne, te, x, plot_result=False)
            except IndexError as e:
                print(e)
                continue
            else: 
                used_dict = extend_used_dict(pulse_num, used_dict, profiles=profiles[idx], neseps=estimations[1], ids=name, mps=mps[idx], elm_perc=elm_perc, lor_val=lor_val, rmids=rmids[idx], rseps=estimations[-1])
    cleaned_dict = clean_used_dict(used_dict)
    return cleaned_dict     
    
    

In [9]:
def extend_used_dict(pulse_num, dict_to_update, initialize=False, **kwargs): 
    if pulse_num not in dict_to_update.keys(): 
        initialize = True
        dict_to_update[pulse_num] = {}
    for key, item in kwargs.items(): 
        if initialize: 
            dict_to_update[pulse_num][key] = []
        dict_to_update[pulse_num][key].append(item)
    return dict_to_update

def clean_used_dict(dict_to_update):
    updated_dict = dict_to_update.copy()
    for pulse_num, pulse_items in updated_dict.items(): 
        # Find a bool array list of indexes that should be kept. 
        nesep_array = np.array(pulse_items['neseps'])
        nesep_mean, nesep_std = nesep_array.mean(), nesep_array.std()
        slices_to_keep = abs(nesep_array - nesep_mean) < nesep_std
        for key, item in pulse_items.items():
            updated_dict[pulse_num][key] = [it for n, it in enumerate(item) if slices_to_keep[n]] # np.array(item)[slices_to_keep]
    return updated_dict
def get_pulse_train_val_split(sandbox_dict):
    """
    Takes the pulse dict gathered above and returns a train/val/test split of the pulses.

    About 15% of pulses go to test; 20% of the remaining pulses go to val, the rest to train.
    The split is by pulse, so slices from one pulse never appear in two subsets.
    Uses the global `random` state (call random.seed(...) first for a reproducible split).
    """
    train_size = 0.85  # fraction of pulses used for train + val
    val_size = 0.2     # fraction of the train + val pulses used for val
    pulse_nums = list(sandbox_dict.keys())

    k_train = int(len(pulse_nums) * train_size)
    indices = random.sample(range(len(pulse_nums)), k_train)
    train_val_pulses = [pulse_nums[i] for i in indices]
    test_pulses = list(set(pulse_nums) - set(train_val_pulses))

    k_val = int(len(train_val_pulses) * val_size)
    indices = random.sample(range(len(train_val_pulses)), k_val)
    val_pulses = [train_val_pulses[i] for i in indices]
    train_pulses = list(set(train_val_pulses) - set(val_pulses))

    assert len(pulse_nums) == (len(train_pulses) + len(val_pulses) + len(test_pulses))
    return train_pulses, val_pulses, test_pulses

def make_dataset(all_dataset):
    profiles, rmids, ids, uncerts, mps = all_dataset 
    iterator = tqdm(range(len(profiles)))
    used_dict = {}
    
    for idx in iterator: 
        name = ids[idx]
        pulse_num, time = name.split('/')
        elm_perc = find_elm_percent(int(pulse_num), float(time))
        if np.isnan(elm_perc): 
            continue
        else: 
            # used_elm_percents.append(elm_perc)
            pass
        _, lor_val = get_lorenzo_pred(int(pulse_num), float(time))

        original = profiles[idx][0], profiles[idx][1], uncerts[idx][0], uncerts[idx][1], rmids[idx]
        ne, te, dne, dte, x = original
        # Enforce monotonically non-increasing profiles: n_e is replaced by its running minimum,
        # and only points where T_e already equals its running minimum are kept.
        new_ne, new_te = np.minimum.accumulate(ne), np.minimum.accumulate(te)
        keep_idx = np.where(new_te == te)

        ne, te, dne, dte, x = new_ne[keep_idx], new_te[keep_idx], dne[keep_idx], dte[keep_idx], x[keep_idx]
        logical_bool_mask = np.logical_and(ne > 0, te > 0)  # drop non-physical (non-positive) values
        logical_bool_mask = np.logical_and(logical_bool_mask, dte > 0)  # uncertainties must be positive too
        logical_bool_mask = np.logical_and(logical_bool_mask, dne > 0)
        logical_bool_mask = np.logical_and(logical_bool_mask, dte < 3000)  # drop points with unrealistically large T_e uncertainty
        ne, te, dne, dte, x = ne[logical_bool_mask], te[logical_bool_mask], dne[logical_bool_mask], dte[logical_bool_mask], x[logical_bool_mask]
        if (te < 100).sum() == 0:
            continue
        else: 
            try: 
                estimations = find_separatrix(ne, te, x, plot_result=False)
            except IndexError as e:
                print(e)
                continue
            else: 
                used_dict = extend_used_dict(pulse_num, used_dict, profiles=profiles[idx], neseps=estimations[1], ids=name, mps=mps[idx], elm_perc=elm_perc, lor_val=lor_val, rmids=rmids[idx], masks=logical_bool_mask, rseps=estimations[-1])
    cleaned_dict = clean_used_dict(used_dict)
    return cleaned_dict
def make_train_val_test_sets(sandbox_dict, train_pulses, val_pulses, test_pulses): 
    train_dict = make_subset(sandbox_dict, train_pulses)
    val_dict = make_subset(sandbox_dict, val_pulses)
    test_dict = make_subset(sandbox_dict, test_pulses)
    return {'train': train_dict, 'val': val_dict, 'test': test_dict}       

def pad_lists(list_to_pad, key, pad_length=19, **kwargs):
    """
    Left-pads every entry of list_to_pad to pad_length points.
    Profiles repeat their first value ('edge'), radii are padded with 0 and masks with False.
    """
    if key == 'profiles':
        padded_profiles = np.zeros((len(list_to_pad), 2, pad_length))
        for n, prof in enumerate(list_to_pad):
            num_missing = pad_length - len(prof[0])
            padded_prof = np.pad(prof, ((0, 0), (num_missing, 0)), mode='edge')
            padded_profiles[n] = padded_prof
        return padded_profiles
    elif key == 'rmids':
        padded_radii = np.zeros((len(list_to_pad), pad_length))
        for n, rmid in enumerate(list_to_pad):
            num_missing = pad_length - len(rmid)
            padded_rmid = np.pad(rmid, (num_missing, 0), mode='constant')
            padded_radii[n] = padded_rmid
        return padded_radii
    elif key == 'masks':
        padded_masks = np.zeros((len(list_to_pad), pad_length))
        for n, mask in enumerate(list_to_pad):
            num_missing = pad_length - len(mask)
            padded_mask = np.pad(mask, (num_missing, 0), mode='constant', constant_values=False)
            padded_masks[n] = padded_mask
        return padded_masks
    
def make_subset(sandbox_dict, pulse_list): 
    """
    Makes a subset (train-val-test) dictionary from the dictionary of pulses given the pulses to use. 
    
    Returns
    ======
    subset_dict: 
        keys: 'profiles', 'neseps', 'ids', 'mps', 'elm_perc', 'lor_val', 'rmids'
    
    """
    subset_dict = {}
    for pulse_num in pulse_list: 
        pulse_dict = sandbox_dict[pulse_num]
        for key, item in pulse_dict.items(): 
            if key not in subset_dict.keys(): 
                subset_dict[key] = []
            if key in ['profiles', 'rmids', 'masks']: 
                padded_item = pad_lists(item, key)
            else: 
                padded_item = np.array(item)
            subset_dict[key].append(padded_item)
    for key in subset_dict.keys(): 
        if key in ['profiles', 'rmids', 'masks', 'mps']: 
            subset_dict[key] = np.vstack(subset_dict[key])
        else: 
            subset_dict[key] = np.concatenate(subset_dict[key])
    return subset_dict

In [10]:
cleaned_dict = make_dataset(dataset)
train_pulses, val_pulses, test_pulses = get_pulse_train_val_split(cleaned_dict)
vae_ready_dict = make_train_val_test_sets(cleaned_dict, train_pulses, val_pulses, test_pulses)

In [11]:
with open('../../../moxie/data/processed/cleaned_pulse_dict_230522.pickle', 'wb') as file: 
    pickle.dump(cleaned_dict, file)
with open('../../../moxie/data/processed/cleaned_ml_ready_dict_230522.pickle', 'wb') as file: 
    pickle.dump(vae_ready_dict, file)
    
# To use cleaned_dict directly, first split it by pulse (e.g. with get_pulse_train_val_split)

In [13]:
print(cleaned_dict['81794'].keys())

dict_keys(['profiles', 'neseps', 'ids', 'mps', 'elm_perc', 'lor_val', 'rmids', 'masks', 'rseps'])


In [27]:
print(vae_ready_dict['train']['profiles'].shape, vae_ready_dict['train']['mps'].shape, vae_ready_dict['train']['elm_perc'].shape)

(11157, 2, 19) (11157, 13) (11157,)


In [29]:
vae_ready_dict['train'].keys(), vae_ready_dict.keys()

(dict_keys(['profiles', 'neseps', 'ids', 'mps', 'elm_perc', 'lor_val', 'rmids', 'masks']),
 dict_keys(['train', 'val', 'test']))

In [61]:
profiles_train, elm_train, mps_train = vae_ready_dict['train']['profiles'],vae_ready_dict['train']['elm_perc'], vae_ready_dict['train']['mps']

In [32]:
import torch 

In [33]:
X_train, ELM_train = torch.from_numpy(profiles_train), torch.from_numpy(elm_train)

In [35]:
X_train.shape, ELM_train.shape

(torch.Size([11157, 2, 19]), torch.Size([11157]))

In [51]:
torch.repeat_interleave(ELM_train.unsqueeze(1), 19, 1).unsqueeze(1).shape, ELM_train.shape

(torch.Size([11157, 1, 19]), torch.Size([11157]))

In [52]:
ELM_train_ready = torch.repeat_interleave(ELM_train.unsqueeze(1), 19, 1).unsqueeze(1)

In [58]:
torch.concat((X_train, ELM_train_ready), 1)[0][2]

tensor([0.2426, 0.2426, 0.2426, 0.2426, 0.2426, 0.2426, 0.2426, 0.2426, 0.2426,
        0.2426, 0.2426, 0.2426, 0.2426, 0.2426, 0.2426, 0.2426, 0.2426, 0.2426,
        0.2426], dtype=torch.float64)

In [57]:
X_train[0][2]

IndexError: index 2 is out of bounds for dimension 0 with size 2

In [67]:
print(mps_train[:, -1:].shape)

(11157, 1)


In [68]:
mps_train[:, -1]

array([1.0188041e+22, 1.0038003e+22, 9.9620153e+21, ..., 1.2249083e+22,
       1.2246535e+22, 1.2242428e+22], dtype=float32)

In [69]:
mps_train[:, :-1]

array([[ 4.1123648e+00,  2.9231858e+00,  9.2052871e-01, ...,
        -2.2657740e+00,  1.9579650e+07,  0.0000000e+00],
       [ 4.1831694e+00,  2.9166126e+00,  9.2619830e-01, ...,
        -2.2654254e+00,  1.9127518e+07,  0.0000000e+00],
       [ 4.1876640e+00,  2.9145408e+00,  9.2521960e-01, ...,
        -2.2652464e+00,  1.9769564e+07,  0.0000000e+00],
       ...,
       [ 3.1329529e+00,  2.9051933e+00,  9.2382020e-01, ...,
        -2.1113687e+00,  7.4103455e+06,  0.0000000e+00],
       [ 3.1186874e+00,  2.9084048e+00,  9.2181474e-01, ...,
        -2.1115985e+00,  7.4067820e+06,  0.0000000e+00],
       [ 3.1158543e+00,  2.9091759e+00,  9.2034882e-01, ...,
        -2.1115129e+00,  7.7074730e+06,  0.0000000e+00]], dtype=float32)