In [None]:
import pickle 
import numpy as np
import pandas as pd
import h5py
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
from scipy import interpolate # import interp1d, UnivariateSpline
# Font sizes shared by every figure produced in this notebook.
SMALL_SIZE = 20
MEDIUM_SIZE = 22
BIGGER_SIZE = 24

# Push all sizes into matplotlib's global rcParams in a single update.
plt.rcParams.update({
    'font.size': SMALL_SIZE,            # default text size
    'axes.titlesize': SMALL_SIZE,       # axes title
    'axes.labelsize': MEDIUM_SIZE,      # x and y axis labels
    'xtick.labelsize': SMALL_SIZE,      # x tick labels
    'ytick.labelsize': SMALL_SIZE,      # y tick labels
    'legend.fontsize': SMALL_SIZE,      # legend entries
    'figure.titlesize': BIGGER_SIZE,    # figure-level title
})


# Update PSI Dataset

We want to include time-evolving machine parameters in the dataset. 


## Requirements of Dataset 

- Two methods of splitting the data
    - À la Diego: the same pulses are used across every split, i.e., the data of each pulse is split 70-20-10 
    - By pulse: whole pulses are split 70-20-10 (train-valid-test)

- Pulses and time windows come from JET PDB (flat top H-mode)
- Machine parameters vary in time and are taken from the data 




# Grab validated pulses from Jet PDB

The relevant pulses and their time windows come from the JET PDB. 
More info found in Lorenzo's paper, but all are H-modes. 

Two filters applied: 

- Pulse number ≥ 79000 (the filter below uses `shot >= 79000`)
    - This comes from the fact that the HRTS equipment was updated for pulses > 79000. 
- HRTS Flag > 0
    - All pulses with an HRTS validation flag of 0 are unvalidated and not publication-worthy 

In [None]:
# Read the JET pedestal database and keep only the columns used in this notebook.
jet_pdb = pd.read_csv('/home/kitadam/ENR_Sven/moxie/data/raw/pedestal-database.csv')
important_info = jet_pdb[[
    'shot',
    't1',
    't2',
    'neseparatrixfromexpdata10^19(m^-3)',
    'error_neseparatrixfromexpdata10^19(m^-3)',
    'neseparatrixfromfit10^19(m^-3)',
    'error_neseparatrixfromfit10^19(m^-3)',
    'FLAG:HRTSdatavalidated',
]]

# Keep rows whose shot number is at least 79000 and whose HRTS validation flag is positive.
final_pulse_list = important_info.loc[
    (important_info['shot'] >= 79000)
    & (important_info['FLAG:HRTSdatavalidated'] > 0)
]

# Making a training (or valid or test) dataset

We have to define some utility functions to grab the relevant data. 

### Machine Parameters 

- `average_machine_with_times()` & `sample_input()`
- Go through the time slice indices in the time window
- The machine parameter value for a time slice is the mean of the values in the window just before it, whose length is 10× the temporal resolution of the HRTS diagnostic
- This is to make sure that there are always values being grabbed, as each machine parameter has a different temporal resolution
- Sometimes a diagnostic is not used during a pulse or profile time window, e.g., ICRH; its sampled value is then left at 0


### $n_e$ and $T_e$ Profiles

- `get_ne_and_te_profiles()` 
- Profiles between the time ranges given 
- Interpolated between $R_{mid} - R_{mid, sep} \in [-0.89, 0.1]$ (note: `shift_profiles()` currently uses a 65-point grid on $[-0.88, 0.061]$)
    - Currently using Akima1DInterpolator from scipy 
- Density values are clipped at 0
- Temperature values clipped when ne is close to 0 in the pedestal (NOT FINISHED YET) 

In [None]:
def sample_input(mp_loc, key, t1, t2, window_times):
    """Sample one machine parameter at the given profile time stamps.

    Reads ``mp_loc[key]['values']`` and ``mp_loc[key]['time']`` in full and
    hands them to ``average_machine_with_times`` together with
    ``window_times``.

    ``t1`` and ``t2`` are accepted but not used by this function.

    Returns a new ``np.ndarray`` built from the averaged samples.
    """
    signal = mp_loc[key]
    signal_values = signal['values'][:]
    signal_times = signal['time'][:]
    averaged = average_machine_with_times(window_times, signal_values, signal_times)
    return np.array(averaged)

def average_machine_with_times(wind_times, mp_values, mp_times, delta_T=0.05002594 * 10):
    """Average a machine-parameter signal over a trailing window per time slice.

    For each time stamp ``t`` in ``wind_times`` the sample is the mean of the
    ``mp_values`` whose ``mp_times`` lie strictly inside ``(t - delta_T, t)``.

    Sampling stops at the first time stamp that cannot be sampled; that slice
    and every later one keep the value 0. A time stamp cannot be sampled when

    * the signal is empty,
    * ``t`` does not occur in ``mp_times`` (exact equality), or
    * the window mean is not finite (empty window, NaN or +/-inf); in this
      case a diagnostic is printed first.

    Parameters
    ----------
    wind_times : array-like
        Time stamps at which to sample.
    mp_values, mp_times : array-like
        Signal values and their time stamps (same length).
    delta_T : float, optional
        Length of the trailing averaging window. The default is
        ``0.05002594 * 10``, the value previously hard-coded here.

    Returns
    -------
    np.ndarray
        Float array with the shape of ``wind_times``.
    """
    wind_times = np.asarray(wind_times)
    mp_values = np.asarray(mp_values)
    mp_times = np.asarray(mp_times)

    # Always a float buffer: inheriting an integer dtype from the time stamps
    # would silently truncate the averaged values.
    sampled_vals = np.zeros(wind_times.shape, dtype=float)

    # Diagnostic not in use at all: every slice stays 0.
    if mp_values.size == 0:
        return sampled_vals

    for slice_num, time in enumerate(wind_times):
        # NOTE(review): this requires an exact float match between the profile
        # time stamp and a signal time stamp -- confirm that is intended.
        if time not in mp_times:
            break

        aggregation_idx = np.logical_and(mp_times < time, mp_times > time - delta_T)
        aggregation_vals = mp_values[aggregation_idx]
        # An empty window is reported as NaN without numpy's "mean of empty
        # slice" RuntimeWarning.
        mean_of_window = np.mean(aggregation_vals) if aggregation_vals.size else np.nan

        # isfinite rejects NaN, +inf and -inf alike.
        if not np.isfinite(mean_of_window):
            print('NAN OCCURED')
            print(aggregation_vals, mean_of_window, mp_values, aggregation_idx.sum(), mp_times)
            break

        sampled_vals[slice_num] = mean_of_window
    return sampled_vals


def get_ne_and_te_profiles(raw_profiles, raw_times, t1, t2):
    """Select the density and temperature profiles inside a time window.

    Loads ``raw_profiles['NE']`` and ``raw_profiles['TE']`` in full and keeps
    the rows whose entry in ``raw_times`` lies in the closed interval
    ``[t1, t2]``.

    Returns ``(ne, te)``; both are selected with the same boolean mask. The
    original notes expect 63 points per profile, but nothing here pads or
    checks the profile length.
    """
    in_window = (raw_times >= t1) & (raw_times <= t2)
    ne_all = raw_profiles['NE'][:]
    te_all = raw_profiles['TE'][:]
    return ne_all[in_window], te_all[in_window]

def find_rsep(raw_psi, raw_rho, radius):
    """Estimate the separatrix major radius for every time slice.

    Each slice of ``raw_psi`` and ``raw_rho`` is linearly interpolated from
    ``radius`` onto a 1000-point grid running from 3.2 to ``radius.max()``.
    The separatrix radius of the slice is the grid point where the
    interpolated PSI and RHO are closest to each other.

    NOTE(review): the original comment also mentions locating PSI == 1.0; only
    the PSI == RHO criterion is actually used.

    Parameters
    ----------
    raw_psi, raw_rho : sequence of 1-D arrays
        One PSI / RHO profile per time slice, sampled on ``radius``.
    radius : np.ndarray
        Major radius of the profile points (plotted as metres).

    Returns
    -------
    np.ndarray
        One separatrix radius per slice, shape ``(len(raw_psi),)``.

    Raises
    ------
    ValueError
        From ``interp1d`` if the fine grid reaches outside the range of
        ``radius``.
    """
    # The fine grid does not depend on the slice, so build it once.
    radius_new = np.linspace(3.2, radius.max(), 1000)
    last_idx = len(radius_new) - 1

    rseps = np.zeros(len(raw_psi))
    for n, (slice_psi, slice_rho) in enumerate(zip(raw_psi, raw_rho)):
        f_psi = interpolate.interp1d(radius, slice_psi)
        f_rho = interpolate.interp1d(radius, slice_rho)

        new_psi = f_psi(radius_new)
        new_rho = f_rho(radius_new)

        closest_to_rho = np.argmin(np.abs(new_psi - new_rho))
        rsep = radius_new[closest_to_rho]
        rseps[n] = rsep

        # Diagnostic figure for suspiciously small separatrix radii.
        if rsep < 3:
            # Clamp the marker span so an index near either end of the grid
            # neither wraps around nor raises IndexError.
            span_lo = radius_new[max(closest_to_rho - 10, 0)]
            span_hi = radius_new[min(closest_to_rho + 10, last_idx)]

            plt.figure(figsize=(10, 10))
            plt.scatter(radius, slice_psi, label=r'$\Psi$')
            plt.scatter(radius, slice_rho, label=r'$\rho$')
            plt.plot(radius_new, new_psi, label='Interpolation', ls='--', c='black')
            plt.plot(radius_new, new_rho, ls='--', c='black')
            plt.hlines(1.0, span_lo, span_hi, color='darkgreen')
            plt.vlines(rsep, 0.9, 1.1, color='darkgreen', label='$R_{sep} = $ ' + '{:.5}'.format(rsep))
            plt.xlabel('Major Radius [m]')
            plt.ylabel('Mapped Cords [arb.]')
            plt.title(r'Determining $R_{sep}$ via interpolation (n=1000) of $\Psi$ and $\rho$')
            plt.legend()
            plt.show()

    return rseps

def find_minmax_rmid_rsep(density_profs, temperature_profs, radius, rseps):
    """Return the smallest and largest ``radius - rsep`` over all time slices.

    The shifted axes are computed directly from ``radius`` and ``rseps`` in
    floating point. The previous implementation filled a scratch buffer shaped
    and typed like ``density_profs``, which truncated the shifts for integer
    densities and left spurious zero rows when there were more profiles than
    ``rseps``.

    Parameters
    ----------
    density_profs, temperature_profs
        Unused; kept so existing callers keep working.
    radius : 1-D array-like
        Radial positions of the profile points.
    rseps : 1-D array-like
        One separatrix radius per time slice.

    Returns
    -------
    tuple
        ``(minimum, maximum)`` of ``radius - rsep`` over every slice.

    Raises
    ------
    ValueError
        If ``radius`` or ``rseps`` is empty (nothing to reduce over).
    """
    radius = np.asarray(radius, dtype=float)
    rseps = np.asarray(rseps, dtype=float)
    # One row per slice: the radial axis shifted by that slice's separatrix.
    shifted_radii = radius[np.newaxis, :] - rseps[:, np.newaxis]
    return np.min(shifted_radii), np.max(shifted_radii)
    


def shift_profiles(density_profs, temperature_profs, radius, rseps, global_min=None, global_max=None):
    """Re-grid density and temperature profiles onto a common ``R - R_sep`` axis.

    The radial axis of each slice is shifted by that slice's ``rsep`` and both
    profiles are re-sampled with an Akima interpolator on a fixed 65-point
    grid spanning ``[-0.88, 0.061]``. Interpolated densities are clipped at 0;
    temperatures are returned as interpolated. A 2x2 diagnostic figure with
    the raw, shifted and interpolated points is shown on every call.

    NOTE(review): grid points outside a slice's shifted radius range get no
    special handling here; check the interpolator's out-of-range behaviour.

    Parameters
    ----------
    density_profs, temperature_profs : array-like
        One profile per time slice, sampled on ``radius``.
    radius : 1-D array-like
        Radial positions of the profile points.
    rseps : 1-D array-like
        One separatrix radius per time slice.
    global_min, global_max : optional
        Unused. They now default to ``None`` so callers may omit them.

    Returns
    -------
    tuple of np.ndarray
        ``(interped_d, interped_t)``, each of shape ``(len(density_profs), 65)``.
    """
    radius = np.asarray(radius, dtype=float)
    rseps = np.asarray(rseps, dtype=float)

    # Float buffers built from radius/rseps; inheriting the dtype of the
    # density array could truncate the shifted axis.
    shifted_radii = radius[np.newaxis, :] - rseps[:, np.newaxis]
    original_radii = np.broadcast_to(radius, shifted_radii.shape)

    new_radii = np.linspace(-0.88, 0.061, 65)
    num_slices = len(density_profs)
    interped_radii = np.tile(new_radii, (num_slices, 1))
    interped_d = np.zeros_like(interped_radii)
    interped_t = np.zeros_like(interped_radii)

    for n, (shif_rad, prof_d, prof_t) in enumerate(zip(shifted_radii, density_profs, temperature_profs)):
        df_new_spline = interpolate.Akima1DInterpolator(shif_rad, prof_d)
        tf_new_spline = interpolate.Akima1DInterpolator(shif_rad, prof_t)
        # Negative interpolated densities are clipped to 0.
        interped_d[n] = df_new_spline(new_radii).clip(min=0)
        interped_t[n] = tf_new_spline(new_radii)

    # Diagnostic figure: top row on the shifted axis (green = interpolated),
    # bottom row on the original major-radius axis.
    fig, axs = plt.subplots(2, 2, figsize=(25, 12), constrained_layout=True)
    axs = axs.ravel()
    axs[0].scatter(interped_radii, interped_d, c='green')
    axs[1].scatter(interped_radii, interped_t, c='green')
    axs[0].scatter(shifted_radii, density_profs)
    axs[1].scatter(shifted_radii, temperature_profs)
    axs[0].set_xlabel('$R - R_{sep}$ [m]')
    axs[2].scatter(original_radii, density_profs)
    axs[3].scatter(original_radii, temperature_profs)
    axs[2].set_xlabel('Major Radius [m]')
    fig.supylabel('Density')
    plt.show()

    return interped_d, interped_t


def clamp_temperature(density_profs, temperature_profs):
    """Zero edge temperatures that sit on top of a very low density.

    Works on a copy, so ``temperature_profs`` itself is not modified. For each
    slice the window ``[-15:-1]`` is inspected and every temperature above 100
    whose matching density is below 1e19 is set to 0. Slices that get changed
    are reported with ``print`` (showing the values before the change).

    Returns a new array with the shape and dtype of ``temperature_profs``.
    """
    # NOTE(review): this window stops before the final point of the profile,
    # so the outermost point is never clamped -- confirm that is intended.
    edge = slice(-15, -1)

    working = temperature_profs.copy()
    clamped = np.zeros_like(temperature_profs)

    for row, (dens_row, temp_row) in enumerate(zip(density_profs, working)):
        edge_t = temp_row[edge]  # view into `working`, edited in place below
        edge_n = dens_row[edge]

        suspicious = (edge_t > 100) & (edge_n < 1e19)
        if suspicious.any():
            print('Changes to the temperature profile')
            print(suspicious, edge_t, edge_n)

        edge_t[suspicious] = 0
        clamped[row] = temp_row

    return clamped

In [None]:
# Single-pulse walkthrough of the per-pulse pipeline for one hard-coded pulse.
# The running extremes are initialised here so this cell can run on its own;
# previously they were read before any cell had defined them.
global_min, global_max = np.inf, -np.inf

with h5py.File('/home/kitadam/ENR_Sven/moxie/data/processed/profile_database_only_shots.hdf5', 'r') as f:
    pulse_idx = 79635

    # Time window (t1 = start, t2 = end) of this pulse in the filtered pulse list.
    t1, t2 = final_pulse_list[final_pulse_list['shot'] == (pulse_idx)][['t1', 't2']].values[0]

    pulse_sample = f[str(pulse_idx)]
    pulse_raw_mps = pulse_sample['machine_parameters']  # rel. mps
    pulse_raw_profiles = pulse_sample['profiles']  # rel. profiles

    raw_pulse_times = pulse_raw_profiles['time'][:]
    window_profile_time_mask = np.logical_and(raw_pulse_times >= t1, raw_pulse_times <= t2)  # Mask of the relevant profiles
    window_time_stamps = raw_pulse_times[window_profile_time_mask]
    print(pulse_idx, 'Time Range {:.3} -> {:.3}'.format(t1, t2), 'Steps: {}'.format(len(window_time_stamps)))

    # Density and temperature profiles inside the time window.
    pulse_ne_profiles, pulse_te_profiles = get_ne_and_te_profiles(pulse_raw_profiles, raw_pulse_times, t1, t2)

    # Gather ingredients to shift the profiles: radius, psi, and rho.
    raw_pulse_radii = pulse_raw_profiles['radius'][:]
    raw_pulse_psi = pulse_raw_profiles['PSI'][:][window_profile_time_mask]  # Mask
    raw_pulse_rho = pulse_raw_profiles['RHO'][:][window_profile_time_mask]

    print(raw_pulse_psi.shape, pulse_ne_profiles.shape)

    # Find the rseps for each time slice.
    rseps = find_rsep(raw_pulse_psi, raw_pulse_rho, raw_pulse_radii)

    # shift_profiles declares global_min / global_max as parameters, so they
    # are passed explicitly; omitting them raised TypeError.
    shift_profiles(pulse_ne_profiles, pulse_te_profiles, raw_pulse_radii, rseps, global_min, global_max)

    # Find minimum and maximum Rmid - Rmid,sep for this pulse.
    local_min, local_max = find_minmax_rmid_rsep(pulse_ne_profiles, pulse_te_profiles, raw_pulse_radii, rseps)
    if local_min < global_min:
        print('Found new global min', local_min)
        global_min = local_min

    if local_max > global_max:
        print('Found new global max', local_max)
        global_max = local_max
    

In [None]:
# Machine-parameter keys mapped to their mathtext axis labels.
# Raw strings keep the LaTeX backslashes literal: sequences such as "\G",
# "\;", "\k" and "\d" are invalid escapes in ordinary string literals and
# trigger a warning on recent Python versions. The string values are unchanged.
# NOTE(review): the VOLM label uses m^{-3}; if VOLM is a volume the unit would
# be m^3 -- confirm before changing the label.
label_dict = {
    'BT': r'$B_T$ [T]',
    'CR0': r'a [m]',
    'ELER': r'$\Gamma \; (10^{22}$ e/s)',
    'ELON': r'$\kappa$ [-]',
    'POHM': r'$P_{OHM}$ [MW]',
    'P_ICRH': r'$P_{ICRH}$ [MW]',
    'P_NBI': r'$P_{NBI}$ [MW]',
    'Q95': r'$q_{95}$ [-]',
    'RGEO': r'$R_{geo}$ [m]',
    'TRIL': r'$\delta_L$',
    'TRIU': r'$\delta_U$',
    'VOLM': r'$V_P$ [m$^{-3}$]',
    'XIP': r'$I_P$ [MA]',
}
# Dict view of the keys; its order fixes the column order of the sampled
# machine parameters built from it.
mp_keys = label_dict.keys()
print(mp_keys)

# Creating the database 


In [None]:
# Array of all the (unique) pulse numbers included in the database.
all_pulse_numbers = np.array(list(set(final_pulse_list['shot'])))

# Running extremes of R - R_sep; handed to shift_profiles below.
global_min, global_max = np.inf, -np.inf

with h5py.File('/home/kitadam/ENR_Sven/moxie/data/processed/profile_database_only_shots.hdf5', 'r') as f:

    for n, (index, row) in enumerate(final_pulse_list.iterrows()):
        # Pulse id (as the string key used in the HDF5 file) and its time
        # window: t1 = start, t2 = end.
        pulse_idx = str(int(row['shot']))
        t1 = row['t1']
        t2 = row['t2']

        # Two pulses are skipped explicitly.
        if pulse_idx in ('89151', '89147'):
            print('Voldemort arrive, not grabbing')
            continue

        pulse_sample = f[pulse_idx]  # relevant pulse
        pulse_raw_mps = pulse_sample['machine_parameters']  # rel. mps
        pulse_raw_profiles = pulse_sample['profiles']  # rel. profiles

        # Times of the profiles occur later in the pulse than the machine
        # parameters do, so take the actual time stamps of the profile slices
        # and coordinate the machine-parameter times with those.
        raw_pulse_times = pulse_raw_profiles['time'][:]
        window_profile_time_mask = np.logical_and(raw_pulse_times >= t1, raw_pulse_times <= t2)
        window_time_stamps = raw_pulse_times[window_profile_time_mask]
        print(pulse_idx, f'Time Range {t1:.3} -> {t2:.3}', f'Steps: {len(window_time_stamps)}')

        # Density and temperature profiles inside the time window.
        pulse_ne_profiles, pulse_te_profiles = get_ne_and_te_profiles(pulse_raw_profiles, raw_pulse_times, t1, t2)

        # Ingredients for shifting the profiles: radius, psi and rho. Pulses
        # missing any of these datasets are reported and skipped.
        try:
            raw_pulse_radii = pulse_raw_profiles['radius'][:]
            raw_pulse_psi = pulse_raw_profiles['PSI'][:][window_profile_time_mask]
            raw_pulse_rho = pulse_raw_profiles['RHO'][:][window_profile_time_mask]
        except KeyError as e:
            print(e)
            continue

        # Separatrix radius for each time slice.
        rseps = find_rsep(raw_pulse_psi, raw_pulse_rho, raw_pulse_radii)

        # Extremes of Rmid - Rmid,sep for this pulse. They are computed but
        # not folded into global_min / global_max (that update is disabled).
        local_min, local_max = find_minmax_rmid_rsep(pulse_ne_profiles, pulse_te_profiles, raw_pulse_radii, rseps)

        # Clamp the temperatures first, then shift and interpolate.
        pulse_te_profiles = clamp_temperature(pulse_ne_profiles, pulse_te_profiles)
        new_ne_profile, new_te_profile = shift_profiles(
            pulse_ne_profiles, pulse_te_profiles, raw_pulse_radii, rseps, global_min, global_max
        )

        # Stack density and temperature along a new axis 1.
        # NOTE(review): this stacks the un-shifted profiles, not
        # new_ne_profile / new_te_profile -- confirm which is intended.
        pulse_ne_and_te_profile = np.stack([pulse_ne_profiles, pulse_te_profiles], axis=1)

        # One column per machine parameter, one row per time slice.
        sampled_machine_params = np.array(
            [sample_input(pulse_raw_mps, key, t1, t2, window_time_stamps) for key in mp_keys]
        ).T

        # Sanity checks: profiles and sampled machine parameters must cover
        # the same number of time slices, and no sample may be NaN.
        assert len(pulse_ne_profiles) == len(window_time_stamps)
        assert len(sampled_machine_params) == len(window_time_stamps)
        if np.isnan(sampled_machine_params).any():
            # Dump everything needed to track down the NaN before failing.
            for blah in sampled_machine_params:
                print(blah)
            print(window_time_stamps)
            print(pulse_raw_mps['P_ICRH/values'][:], pulse_raw_mps['P_ICRH/time'][:])
        assert not np.isnan(sampled_machine_params).any()

        # Stop once the row with enumeration index 20 has been processed.
        if n == 20:
            break
        

In [None]:
# NOTE(review): bare literal left in a scratch cell; presumably a pulse number
# that was being inspected -- confirm, or remove the cell.
84459