# Creating the dataset

We use data from the JET (Joint European Torus) tokamak.

### Starting data 

- HRTS (High Resolution Thomson Scattering) profiles
    - Electron density and temperature, $n_e$, $T_e$
    - HRTS line-of-sight coordinates, $R_{HRTS}$
- Machine parameters
    - $R$ [m], plasma major radius
    - $a$ [m], plasma minor radius 
    - $V_P$ [m$^3$], total plasma volume enclosed by the last closed flux surface (LCFS)
    - $\delta_u,\delta_l$ [-], upper and lower triangularities
    - $\kappa$ [-], elongation, i.e. the ratio of plasma height to width
    - $P_{OHM}$ [W], ohmic heating power
    - $I_P$ [A], total plasma current enclosed by the LCFS
    - $B_T$ [T], total toroidal magnetic field
    - $P_{NBI}$ [W], total neutral beam injection (NBI) power into the plasma
    - $P_{ICRH}$ [W], total ion cyclotron resonance heating (ICRH) power into the plasma
    - $\Gamma$ [e/s], fuelling rate of the main isotope (gas puff)
    - $q_{cycl}$ [-], cylindrical safety factor, calculated as $q_{cycl} = \frac{1+2\kappa^2}{2}\,\frac{2\pi a^2 B_T}{\mu_0 R I_P}$, where $\mu_0$ is the vacuum permeability
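The $q_{cycl}$ expression above can be evaluated directly from the machine parameters. Below is a minimal NumPy sketch, using the same value of $\mu_0$ as the notebook code further down; the function name `q_cycl` and the example values are illustrative only and are not taken from the dataset.

```python
import numpy as np

# Vacuum permeability [H/m], same value as used in replace_q95_with_qcly
MU_0 = 1.25663706e-6


def q_cycl(R, a, kappa, I_p, B_t):
    """Cylindrical safety factor, in the form written above.

    R, a in m; I_p in A; B_t in T; kappa dimensionless.
    Works element-wise on scalars or NumPy arrays.
    """
    shape_factor = (1.0 + 2.0 * kappa**2) / 2.0
    return shape_factor * (2.0 * np.pi * a**2 * B_t) / (MU_0 * R * I_p)


# Illustrative values only (not taken from the dataset)
print(q_cycl(R=2.96, a=0.9, kappa=1.7, I_p=2.5e6, B_t=2.7))
```

Because every operation is element-wise, the same function can be applied to whole columns of machine parameters at once.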
    
#### Profile Information 

- Pulses/time windows found in the JET pedestal database
- Subset of pulses:
    - ILW (ITER-Like Wall)
    - HRTS validated
    - Either no impurity seeding or nitrogen seeding
    - Deuterium fuelling
    - No kicks, RMPs (resonant magnetic perturbations), or pellets
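The subset criteria above amount to a row filter on the pedestal database. The sketch below assumes the database has been loaded into a pandas `DataFrame`; every column name (`wall`, `hrts_validated`, `seeding`, `fuel`, `kicks`, `rmps`, `pellets`) and the seeding labels are hypothetical placeholders, not the real database schema, and the toy table is made up.

```python
import pandas as pd


def select_pulses(db: pd.DataFrame) -> pd.DataFrame:
    """Apply the subset criteria listed above to a pedestal-database table.

    All column names and category labels here are hypothetical placeholders.
    """
    keep = (
        (db["wall"] == "ILW")                    # ITER-Like Wall only
        & db["hrts_validated"]                   # boolean column
        & db["seeding"].isin(["none", "N"])      # no seeding or nitrogen seeding
        & (db["fuel"] == "D")                    # deuterium fuelling
        & ~db["kicks"]
        & ~db["rmps"]
        & ~db["pellets"]
    )
    return db[keep]


# Toy table (made up) to show the filter
toy = pd.DataFrame({
    "pulse": [1, 2, 3],
    "wall": ["ILW", "ILW", "C"],
    "hrts_validated": [True, True, True],
    "seeding": ["none", "Ne", "N"],
    "fuel": ["D", "D", "D"],
    "kicks": [False, False, False],
    "rmps": [False, False, False],
    "pellets": [False, False, False],
})
print(select_pulses(toy)["pulse"].tolist())  # [1]
```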
    


After the filtering above:
- Total number of pulses/time windows:
    - 1248
   
### Further slice filtering 

#### ELM percentage

We keep only slices for which an ELM percentage could be calculated from the ELM timings in the JET pedestal database.
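The notebook calls `find_elm_percent(pulse, time)`, which is not shown here. A common definition of ELM percentage is the fraction of the current inter-ELM period that has elapsed at the slice time; the sketch below assumes that definition and takes an array of ELM times as a hypothetical stand-in for the database timings. It returns NaN when the slice time is not bracketed by two ELMs, which is exactly the case `make_dataset` skips.

```python
import numpy as np


def elm_percent(elm_times, t):
    """Percent of the current inter-ELM period elapsed at time t (assumed definition).

    Returns NaN if t is before the first ELM or at/after the last one.
    """
    elm_times = np.sort(np.asarray(elm_times, dtype=float))
    # idx is the position of the first ELM strictly after t
    idx = np.searchsorted(elm_times, t, side="right")
    if idx == 0 or idx == len(elm_times):
        return np.nan
    t_prev, t_next = elm_times[idx - 1], elm_times[idx]
    return 100.0 * (t - t_prev) / (t_next - t_prev)


print(elm_percent([0.0, 1.0, 2.0], 1.25))  # 25.0
```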


After filtering: 
- Total number of slices from 608 pulses: 
    - 23499
    
#### $n_{e, sep}$

- TODO: Describe process of gathering $n_{e, sep}$
- The estimate is imperfect, so we remove slices whose $n_{e, sep}$ estimate lies more than one standard deviation from the mean $n_{e, sep}$ of its pulse.

After filtering:
- Total number of slices:
    - 16111
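A minimal sketch of the per-pulse one-standard-deviation cut, assuming the $n_{e,sep}$ estimates are grouped by pulse (as in the `neseps` lists stored in `used_dict`). The name `one_sigma_masks`, the use of the population standard deviation (`ddof=0`), and the inclusive boundary are assumptions, not the notebook's confirmed choices.

```python
import numpy as np


def one_sigma_masks(neseps_by_pulse):
    """For each pulse, flag slices whose n_e,sep lies within one std of the pulse mean."""
    masks = {}
    for pulse, values in neseps_by_pulse.items():
        values = np.asarray(values, dtype=float)
        mean, std = values.mean(), values.std()  # population std (ddof=0), an assumption
        masks[pulse] = np.abs(values - mean) <= std
    return masks


print(one_sigma_masks({"1": [1.0, 2.0, 3.0]}))
```

With the dictionary from `make_dataset`, the input could be built as `{p: d["neseps"] for p, d in used_dict.items()}`. Note that a pulse with a single slice always survives, since its standard deviation is zero.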
    
## Profile padding 

The longest profile has 19 points, so all profiles are padded to length 19.
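A minimal padding sketch, assuming right-padding with zeros and a boolean mask marking the real points; neither the padding side nor the fill value is specified above, and `pad_profile` is a hypothetical helper. It pads the last axis, so it handles a single profile or a stack such as `(n_e, T_e)`.

```python
import numpy as np

MAX_LEN = 19  # longest profile in the dataset, per the text above


def pad_profile(values, max_len=MAX_LEN, pad_value=0.0):
    """Right-pad the last axis of `values` to `max_len`.

    Returns the padded array and a boolean mask that is True for real points.
    """
    values = np.asarray(values, dtype=float)
    n = values.shape[-1]
    if n > max_len:
        raise ValueError(f"profile has {n} points, more than max_len={max_len}")
    # Pad only the last axis, at its end
    pad_width = [(0, 0)] * (values.ndim - 1) + [(0, max_len - n)]
    padded = np.pad(values, pad_width, mode="constant", constant_values=pad_value)
    mask = np.arange(max_len) < n
    return padded, mask


padded, mask = pad_profile([1.0, 2.0, 3.0])
print(padded.shape, mask.sum())  # (19,) 3
```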


```python
from tqdm.notebook import tqdm
import numpy as np
import pickle
from scipy import interpolate
from scipy.signal import savgol_filter
from scipy.stats import norm
import pandas as pd
import math
from collections.abc import Iterable
```

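```python
def replace_q95_with_qcly(mp_set, mu, var):
    """
    Overwrite column 0 of the machine-parameter array (q95) with q_cycl.

    de_standardize and standardize are defined elsewhere in the notebook.
    Column layout used here (matching the machine-parameter list above, with
    q95 in column 0): 1 = R, 2 = a, 6 = kappa, 8 = I_P, 9 = B_T.
    """
    mp_set = de_standardize(mp_set, mu, var)
    mu_0 = 1.25663706e-6  # vacuum permeability [H/m]

    R, a, kappa = mp_set[:, 1], mp_set[:, 2], mp_set[:, 6]
    I_p, B_t = mp_set[:, 8], mp_set[:, 9]
    # np.pi is a plain float, so this works for NumPy arrays and torch tensors alike
    mp_set[:, 0] = ((1 + 2 * kappa**2) / 2.0) * (2 * B_t * np.pi * a**2) / (R * I_p * mu_0)
    mp_set = standardize(mp_set, mu, var)
    return mp_set


def extend_used_dict(pulse_num, dict_to_update, initialize=False, **kwargs):
    """
    Append each keyword value to dict_to_update[pulse_num][key].

    A fresh entry is created for a pulse that is not yet present, or when
    initialize=True.
    """
    if initialize or pulse_num not in dict_to_update:
        dict_to_update[pulse_num] = {}
    pulse_entry = dict_to_update[pulse_num]
    for key, value in kwargs.items():
        # Create the list for this key on first use, then append
        pulse_entry.setdefault(key, []).append(value)
    return dict_to_update
```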
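The per-pulse layout built by `extend_used_dict` (pulse number → field name → list of per-slice values) can also be expressed with `collections.defaultdict`, which creates missing levels automatically. This is an alternative design sketch, not the notebook's code; `new_pulse_store`, `add_slice` and `to_plain_dict` are hypothetical names, and the pulse number is the placeholder from the `make_dataset` docstring.

```python
from collections import defaultdict


def new_pulse_store():
    """pulse number -> field name -> list of per-slice values."""
    return defaultdict(lambda: defaultdict(list))


def add_slice(store, pulse_num, **fields):
    """Append one slice's fields to the lists for its pulse."""
    for key, value in fields.items():
        store[pulse_num][key].append(value)
    return store


def to_plain_dict(store):
    """Convert to nested plain dicts, e.g. before pickling (a lambda factory cannot be pickled)."""
    return {pulse: dict(fields) for pulse, fields in store.items()}


store = new_pulse_store()
add_slice(store, "123456", ids="123456/50.0", neseps=2.0)
add_slice(store, "123456", ids="123456/50.1", neseps=2.1)
print(to_plain_dict(store))
```

The conversion step matters here because the notebook imports `pickle`: the outer `defaultdict` holds a lambda as its default factory, which `pickle` cannot serialize.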

```python
def make_dataset(all_dataset):
    """
    Organize all slices by pulse number, so that the train/val/test split can
    later be made by pulse rather than by slice.

    Relies on find_elm_percent, find_separatrix and extend_used_dict, which are
    defined elsewhere in the notebook.

    Parameters
    ==========
    all_dataset : tuple
        (profiles, rmids, ids, uncerts, mps) for every slice, loaded earlier.
        Each id is a string of the form '<pulse>/<time>'.

    Returns
    =======
    used_dict : dict
        All retained slices keyed by pulse number, e.g.
        used_dict['123456']['profiles'] is the list of profiles for that pulse.
        Keys per pulse: 'profiles', 'neseps', 'ids', 'mps', 'elm_perc', 'mask'.
    """
    profiles, rmids, ids, uncerts, mps = all_dataset

    used_dict = {}
    for idx in tqdm(range(len(profiles))):
        name = ids[idx]
        pulse_num, time = name.split('/')
        elm_perc = find_elm_percent(int(pulse_num), float(time))
        # Skip slices without an ELM percentage
        if np.isnan(elm_perc):
            continue

        ne, te = profiles[idx][0], profiles[idx][1]
        dne, dte = uncerts[idx][0], uncerts[idx][1]
        x = rmids[idx]

        # Make the profiles monotonically decreasing (points assumed ordered from
        # the core outwards): keep points where T_e is already at its running minimum
        new_ne, new_te = np.minimum.accumulate(ne), np.minimum.accumulate(te)
        keep_idx = np.flatnonzero(new_te == te)
        ne, te, dne, dte, x = new_ne[keep_idx], new_te[keep_idx], dne[keep_idx], dte[keep_idx], x[keep_idx]

        # Valid points for the VAE and for finding n_e,sep: positive values and
        # uncertainties, and a T_e uncertainty below 3000 (same units as T_e)
        valid = (ne > 0) & (te > 0) & (dne > 0) & (dte > 0) & (dte < 3000)
        ne, te, dne, dte, x = ne[valid], te[valid], dne[valid], dte[valid], x[valid]

        # Express the mask over the original profile points so it lines up with profiles[idx]
        mask = np.zeros(len(profiles[idx][0]), dtype=bool)
        mask[keep_idx[valid]] = True

        # Without any points below T_e = 100, n_e,sep cannot be found reliably
        if not np.any(te < 100):
            continue
        try:
            estimations = find_separatrix(ne, te, x, plot_result=False)
        except IndexError as e:
            print(f'{name}: {e}')
            continue

        # TODO: also store 'lor_val' once it is computed; it is not defined in this cell
        used_dict = extend_used_dict(
            pulse_num, used_dict,
            profiles=profiles[idx], neseps=estimations[1], ids=name,
            mps=mps[idx], elm_perc=elm_perc, mask=mask,
        )

    return used_dict
```
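The monotonicity and validity cuts in `make_dataset` are easiest to check on a small toy profile. This sketch isolates just that masking logic; `valid_point_mask` is a hypothetical helper, the toy values are made up, and, as in the notebook code, the points are assumed to run from the core outwards so that physical profiles decrease along the array.

```python
import numpy as np


def valid_point_mask(ne, te, dne, dte, dte_max=3000.0):
    """Full-length mask of the points kept by the monotonicity and validity cuts.

    Takes running minima of n_e and T_e, keeps points where T_e was already at
    its running minimum, then requires positive values/uncertainties and dte < dte_max.
    """
    ne, te, dne, dte = (np.asarray(v, dtype=float) for v in (ne, te, dne, dte))
    ne_min = np.minimum.accumulate(ne)
    te_min = np.minimum.accumulate(te)
    keep_idx = np.flatnonzero(te_min == te)
    valid = (
        (ne_min[keep_idx] > 0)
        & (te[keep_idx] > 0)
        & (dne[keep_idx] > 0)
        & (dte[keep_idx] > 0)
        & (dte[keep_idx] < dte_max)
    )
    mask = np.zeros(te.shape, dtype=bool)
    mask[keep_idx[valid]] = True
    return mask


# Toy profile: point 2 breaks monotonicity in T_e, point 4 has a huge T_e uncertainty
print(valid_point_mask(
    ne=[5.0, 4.0, 4.5, 3.0, 1.0],
    te=[1000.0, 800.0, 900.0, 300.0, 50.0],
    dne=[0.1] * 5,
    dte=[10.0, 10.0, 10.0, 10.0, 5000.0],
))  # [ True  True False  True False]
```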
    
    

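```python
import random


def get_pulse_train_val_split(sandbox_dict, train_val_frac=0.85, val_frac=0.2, seed=None):
    """
    Split the pulses in the dict built by make_dataset into train, validation and
    test pulses, so that all slices of a pulse end up in the same split.

    train_val_frac : fraction of pulses used for train + validation; the rest is test
    val_frac       : fraction of the train + validation pulses held out for validation
    seed           : optional seed for a reproducible split; if None, the global
                     `random` state is used

    With the defaults this gives roughly 68% train, 17% validation and 15% test.
    """
    rng = random.Random(seed) if seed is not None else random
    pulse_nums = list(sandbox_dict.keys())

    # Draw the train + validation pulses; everything else is test
    k_train_val = int(len(pulse_nums) * train_val_frac)
    train_val_pulses = rng.sample(pulse_nums, k_train_val)
    train_val_set = set(train_val_pulses)
    test_pulses = [p for p in pulse_nums if p not in train_val_set]

    # Hold out a fraction of the train + validation pulses for validation
    k_val = int(len(train_val_pulses) * val_frac)
    val_pulses = rng.sample(train_val_pulses, k_val)
    val_set = set(val_pulses)
    train_pulses = [p for p in train_val_pulses if p not in val_set]

    assert len(pulse_nums) == len(train_pulses) + len(val_pulses) + len(test_pulses)
    return train_pulses, val_pulses, test_pulses
```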
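Because both fractions pass through `int()`, the split sizes can be checked by hand. The hypothetical helper `split_sizes` below reproduces only that arithmetic (0.85 of the pulses for train + validation, 0.2 of those for validation). The value 608 is the pulse count quoted after ELM filtering and serves purely as an example, since the pulse count after the $n_{e,sep}$ cut is not stated.

```python
def split_sizes(n_pulses, train_val_frac=0.85, val_frac=0.2):
    """Number of (train, val, test) pulses produced by the split arithmetic above."""
    n_train_val = int(n_pulses * train_val_frac)
    n_val = int(n_train_val * val_frac)
    return n_train_val - n_val, n_val, n_pulses - n_train_val


print(split_sizes(608))  # (413, 103, 92)
```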
    