In [None]:
import pytorch_lightning as pl
from data.profile_dataset import DS, DataModuleClass
import torch
from matplotlib.pyplot import cycler

import matplotlib.pyplot as plt

from sklearn.model_selection import train_test_split

from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.callbacks.early_stopping import EarlyStopping

from experiments import DIVA_EXP
from models import DIVA_v1, DIVA_v2
from pathlib import Path
import os
import numpy as np
SMALL_SIZE = 20
MEDIUM_SIZE = 22
BIGGER_SIZE = 24

# Global matplotlib font sizes applied to every figure produced by this notebook.
plt.rcParams.update({
    'font.size': SMALL_SIZE,          # default text size
    'axes.titlesize': SMALL_SIZE,     # axes titles
    'axes.labelsize': MEDIUM_SIZE,    # x and y axis labels
    'xtick.labelsize': SMALL_SIZE,    # x tick labels
    'ytick.labelsize': SMALL_SIZE,    # y tick labels
    'legend.fontsize': SMALL_SIZE,    # legend text
    'figure.titlesize': BIGGER_SIZE,  # figure suptitle
})

def de_standardize(x, mu, var):
    """Map standardized values back to their original units.

    Computes ``x * var + mu`` element-wise (broadcasting applies for NumPy
    arrays and torch tensors).

    Note: ``var`` is applied directly as the multiplicative scale; no square
    root is taken, so it must be whatever scale the data was divided by.
    """
    scaled = x * var
    return scaled + mu


In [None]:
# --- Run configuration ------------------------------------------------------
pin_memory = False
cpus_per_trial = 4
# Processed profile database. The default is the original author's local path;
# set the MOXIE_DATA_DIR environment variable to use a copy stored elsewhere.
data_dir = os.environ.get(
    "MOXIE_DATA_DIR",
    "/home/local/kitadam/ENR_Sven/moxie/data/processed/profile_database_v1_psi22.hdf5",
)
num_epochs = 50
gpus_per_trial = 0

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Seed Python's `random`, NumPy and torch in one call, so any of these RNGs used by the
# data module or model is reproducible (previously only torch was seeded).
pl.seed_everything(42)
# NOTE(review): this generator is not passed to anything in this notebook.
generator = torch.Generator().manual_seed(42)
torch.manual_seed(42)
logger = TensorBoardLogger("tb_logs", name='ModelVis')

# Data-loading settings
STATIC_PARAMS = {'data_dir': data_dir,
                 'num_workers': cpus_per_trial,
                 'pin_memory': pin_memory}

# Optimisation hyper-parameters
HYPERPARAMS = {'LR': 0.001, 'weight_decay': 0.0, 'batch_size': 512}

# Arguments forwarded to the DIVA_v2 constructor (architecture and loss weights)
model_hyperparams = {'in_ch': 2, 'out_dim': 63,
                     'mach_latent_dim': 10, 'beta_stoch': 0.01, 'beta_mach': 100.,
                     'alpha_mach': 25.0, 'alpha_prof': 1.0,
                     'loss_type': 'semi-supervised'}

# One flat dict shared by the experiment wrapper and the data module
params = {**STATIC_PARAMS, **HYPERPARAMS, **model_hyperparams}
model = DIVA_v2(**model_hyperparams)

print(model)
trainer_params = {'max_epochs': num_epochs,
                  'gpus': gpus_per_trial if str(device).startswith('cuda') else 0,
                  'gradient_clip_val': 0.5, 'gradient_clip_algorithm': "value",
                  'profiler': "simple"}

experiment = DIVA_EXP(model, params)
runner = pl.Trainer(logger=logger, **trainer_params)

datacls = DataModuleClass(**params)

runner.fit(experiment, datamodule=datacls)
runner.test(experiment, datamodule=datacls)

# Prior Reg Studies

In [None]:
LABEL = ['Q95', 'RGEO', 'CR0', 'VOLM', 'TRIU', 'TRIL', 'XIP', 'ELON', 'POHM', 'BT', 'ELER', 'P_NBI', 'P_ICRH']

data_point_index = 100

train_dataset = runner.datamodule.train_set
train_prof, train_mp = train_dataset.X, train_dataset.y

model = runner.model.model
# Evaluate in inference mode, as the sweep cell does. The model object is shared, so without
# this the result would depend on whether that cell (which calls model.eval()) ran first.
model.eval()

mu_T, var_T = runner.datamodule.get_temperature_norms()
mu_MP, var_MP = runner.datamodule.get_mp_norms()

# Grab the real profile and machine params at data_point_index. Slicing (i:i+1) rather than
# indexing keeps the leading batch dimension for the model calls below.
sample_rows = slice(data_point_index, data_point_index + 1)
real_prof = train_prof[sample_rows]
real_density_profile, real_temperature_profile = real_prof[:, 0], real_prof[:, 1]
real_mp_vals = train_mp[sample_rows]

print(real_mp_vals)

with torch.no_grad():
    # Conditional prior over z_mach given the real machine parameters
    conditional_mu, conditional_var = model.p_zmachx(real_mp_vals)
    # Encoder output for the real profile
    mu_stoch, log_var_stoch, mu_mach, log_var_mach = model.q_zy(real_prof)
    # NOTE(review): reparameterize() is fed log-variances elsewhere (log_var_*); confirm that the
    # second output of p_zmachx is also a log-variance despite being named `conditional_var` here.
    z_stoch = model.reparameterize(mu_stoch, log_var_stoch)
    z_mach_conditional = model.reparameterize(conditional_mu, conditional_var)
    z_mach_original = model.reparameterize(mu_mach, log_var_mach)

    # Machine parameters decoded from the prior sample vs. from the encoder sample
    out_mp_conditional = model.q_hatxzmach(z_mach_conditional)
    out_mp_original = model.q_hatxzmach(z_mach_original)

print(conditional_mu)
print(conditional_var)

print(out_mp_conditional)
print(out_mp_original)



# Sweeping Studies

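Let's look at the current $I_P$: we sweep selected $z_{mach}$ latent dimensions of one training sample and see how the generated profiles and the predicted $I_P$ change.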

In [None]:

LABEL = ['Q95', 'RGEO', 'CR0', 'VOLM', 'TRIU', 'TRIL', 'XIP', 'ELON', 'POHM', 'BT', 'ELER', 'P_NBI', 'P_ICRH']

param = 'XIP'           # machine parameter whose predicted value colours the sweep
param_idx = LABEL.index(param)
dim_idx = 7             # main z_mach latent dimension being swept
resolution = 100        # number of points along the sweep
data_point_index = 100  # training sample used as the anchor of the sweep

# Get the full training dataset and the normalizing constants for the data
train_dataset = runner.datamodule.train_set
train_prof, train_mp = train_dataset.X, train_dataset.y

model = runner.model.model
model.eval()

mu_T, var_T = runner.datamodule.get_temperature_norms()
mu_MP, var_MP = runner.datamodule.get_mp_norms()

# From the training dataset, grab the real profile and machine params at data_point_index
real_density_profile, real_temperature_profile = train_prof[data_point_index, 0], train_prof[data_point_index, 1]

real_mp_vals = de_standardize(train_mp[data_point_index], mu_MP, var_MP)
real_val = real_mp_vals[param_idx]

# Encode every training profile into the latent space
with torch.no_grad():
    mu_stoch, log_var_stoch, mu_mach, log_var_mach = model.q_zy(train_prof)
    z_stoch, z_mach = model.reparameterize(mu_stoch, log_var_stoch), model.reparameterize(mu_mach, log_var_mach)


def latent_sweep(z_column, steps, descending=False):
    """Return `steps` evenly spaced values spanning the observed range of `z_column`.

    The values run min -> max, or max -> min when `descending` is True.
    """
    lo, hi = z_column.min(), z_column.max()  # tensor reductions, not Python-level iteration
    return torch.linspace(hi, lo, steps) if descending else torch.linspace(lo, hi, steps)


# Sweep dimension dim_idx (and 6) upwards and dimension 8 downwards over their training range
sweep = latent_sweep(z_mach[:, dim_idx], resolution)
sweep_6 = latent_sweep(z_mach[:, 6], resolution)
sweep_8 = latent_sweep(z_mach[:, 8], resolution, descending=True)

# Latent code of the anchor sample; z_stoch precedes z_mach in the concatenated vector,
# so z_mach dimension d sits at column stoch_dim + d.
sample_z_stoch, sample_z_mach = z_stoch[data_point_index], z_mach[data_point_index]
sample_z = torch.cat([sample_z_stoch, sample_z_mach])
stoch_dim = sample_z_stoch.shape[0]

# `resolution` copies of the anchor code, with the swept dimensions overwritten
# (assignment order kept: a later assignment wins if the dimensions coincide)
all_z = sample_z.repeat(resolution, 1)
all_z[:, stoch_dim + dim_idx] = sweep
all_z[:, stoch_dim + 8] = sweep_8
all_z[:, stoch_dim + 6] = sweep_6

# Decode profiles and machine parameters for the sweep and for the unmodified anchor
with torch.no_grad():
    out_profs_all = model.p_yhatz(all_z)
    out_profs_sample = model.p_yhatz(sample_z.repeat(2, 1))
    out_mp_sample = model.q_hatxzmach(sample_z[stoch_dim:].repeat(2, 1))
    out_mp_all = model.q_hatxzmach(all_z[:, stoch_dim:])

out_profs_all_d = out_profs_all[:, 0]  # Density
out_profs_all_t = out_profs_all[:, 1]  # Temperature

sample_density_profile = out_profs_sample[0, 0]
sample_temperature_profile = out_profs_sample[0, 1]

# Denormalize the predicted machine parameters
out_all_mp_vals = de_standardize(out_mp_all, mu_MP, var_MP)
predictied_mp_vals = out_all_mp_vals[:, param_idx]

sample_mp_vals = de_standardize(out_mp_sample, mu_MP, var_MP)
sample_mp_val = sample_mp_vals[0, param_idx]

# Denormalize the temperature profiles (density is plotted as stored)
sample_temperature_profile = de_standardize(sample_temperature_profile, mu_T, var_T)
real_temperature_profile = de_standardize(real_temperature_profile, mu_T, var_T)


In [None]:
# Normalize is otherwise only imported in a later cell; import it here so this cell runs on a fresh kernel.
from matplotlib.colors import Normalize

fig, axs = plt.subplots(1, 2, figsize=(18, 8), dpi=310, constrained_layout=True)

# Real machine parameters of the anchor sample, shown as legend text
label_dict = {key: float(val) for key, val in zip(LABEL, real_mp_vals)}
axs[0].set(title='Density')
axs[1].set(title='Temperature')

# Colour each swept profile by its predicted value of `param`
pred_vals = predictied_mp_vals.detach().numpy()
cmap = plt.get_cmap('coolwarm_r')  # plt.cm.get_cmap is deprecated / removed in recent matplotlib
norm = Normalize(vmin=pred_vals.min(), vmax=pred_vals.max())
colors = cmap(norm(pred_vals))
"""
# std_profiles = np.array([392.3171, 424.2952, 368.7801, 328.1372, 345.7570, 320.7389, 365.1653,
        464.9704, 295.2764, 321.5903, 349.4917, 257.1602, 384.2598, 322.0619,
        239.7611, 380.7346, 296.2313, 196.7390, 196.3488, 232.7588, 212.4754,
        166.3945, 281.7231, 179.5035, 148.0473, 159.0347, 163.4847, 214.3491,
        214.2839, 155.4773, 166.1069, 204.7064, 122.0801, 170.2776, 161.2505,
        126.1403, 189.4895, 169.6824, 134.5641, 209.9317, 133.9798, 117.8455,
        127.0341, 128.5185, 131.2482, 273.9699, 149.8306, 146.5484, 171.6822,
        148.0955, 121.5661, 192.1792, 150.7070, 141.6074, 160.8060, 101.2769,
         51.5434, 148.8202,  44.0291, 112.2074, 120.0807,  66.8595,  44.7244])

# std_dens = np.array([0.0170, 0.0189, 0.0141, 0.0159, 0.0138, 0.0153, 0.0177, 0.0194, 0.0143,
        0.0139, 0.0144, 0.0173, 0.0159, 0.0168, 0.0139, 0.0156, 0.0170, 0.0155,
        0.0182, 0.0153, 0.0146, 0.0178, 0.0187, 0.0200, 0.0198, 0.0199, 0.0175,
        0.0188, 0.0126, 0.0105, 0.0174, 0.0182, 0.0157, 0.0161, 0.0167, 0.0160,
        0.0163, 0.0189, 0.0167, 0.0193, 0.0190, 0.0171, 0.0178, 0.0167, 0.0160,
        0.0205, 0.0159, 0.0150, 0.0182, 0.0182, 0.0223, 0.0217, 0.0201, 0.0145,
        0.0142, 0.0235, 0.0176, 0.0157, 0.0096, 0.0102, 0.0026, 0.0036, 0.0034])

"""
axs[0].set_prop_cycle(color=colors)
axs[1].set_prop_cycle(color=colors)
for density_prof, temperature_prof in zip(out_profs_all_d, out_profs_all_t):
    axs[0].plot(density_prof, lw=2, alpha=0.7)
    axs[1].plot(de_standardize(temperature_prof, mu_T, var_T))

axs[0].set_xlabel('Radius from core [arb.]')
axs[1].set_xlabel('Radius from core [arb.]')
# Raw strings: '\;' is not a valid Python escape sequence
axs[0].set_ylabel(r'$n_e \; \; [10^{20}$ m$^{-3}]$', size='xx-large')
axs[1].set_ylabel(r'$T_e \; \;$ [eV]', size='xx-large')

axs[0].plot(real_density_profile, c='black', label='Real', linestyle='--', lw=2, alpha=0.8)
axs[1].plot(real_temperature_profile, c='black', label='\n'.join([f'{key}: {value:.4}' for key, value in label_dict.items()]), linestyle='--', lw=2, alpha=0.8)
# axs[1].plot(real_temperature_profile + std_profiles, c='grey', lw=2.5, alpha=0.7)
# axs[1].plot(real_temperature_profile - std_profiles, c='grey', lw=2.5, alpha=0.7)
# axs[0].plot(real_density_profile + std_dens, label='Variation in Real', c='yellow', lw=5, alpha=0.7)
# axs[0].plot(real_density_profile - std_dens, c='yellow', lw=5, alpha=0.7)

axs[0].plot(sample_density_profile, c='forestgreen', label='Generated', linestyle='--', lw=2)
axs[1].plot(sample_temperature_profile, c='forestgreen', linestyle='--', lw=2, alpha=0.8)

axs[0].legend(fontsize=10)
axs[1].legend(fontsize=10)

mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
mappable.set_array(pred_vals)

cb = fig.colorbar(mappable, shrink=0.7, aspect=5, label='Predicted $I_P$ [A]', pad=0.01)
# Mark the real (black) and generated (green) values on the vertical colorbar. axhline spans the
# bar's width in axes coordinates; the previous plot([min, max], ...) gave x-coordinates in I_P
# units, outside the colorbar's short axis, so the markers were not visible.
cb.ax.axhline(float(real_val), color='black')
cb.ax.axhline(float(sample_mp_val), color='forestgreen')

fig.suptitle('Sweeping the Current: DIVA: {}-D Stoch, {}-D Mach \n Latent Dim 6&7&8'.format(model.stoch_latent_dim, model.mach_latent_dim))
plt.show()


In [None]:
from matplotlib.collections import PolyCollection  # was only imported in the next cell
from matplotlib.colors import Normalize
from matplotlib import colors as mcolors
from matplotlib import rcParams
rcParams['axes.labelpad'] = 20

fig = plt.figure(figsize=(25, 12), dpi=400)

# Density surface: x = radial index, y = predicted I_P along the sweep, z = generated density
ax = fig.add_subplot(1, 2, 1, projection='3d')

x = np.arange(out_profs_all_d.shape[1])
y = predictied_mp_vals.detach().numpy()

X, Y = np.meshgrid(x, y)

Z = out_profs_all_d.detach().numpy()

norm = Normalize()
colors = norm(Y)

cmap = plt.get_cmap("coolwarm_r")  # plt.cm.get_cmap is deprecated / removed in recent matplotlib
print(X.shape, Y.shape, Z.shape)
surface = ax.plot_surface(X, Y, Z, facecolors=cmap(colors), shade=True, alpha=0.75)
ax.set_zlim(0, 0.5)

real_label = '\n'.join([f'{key}: {value:.4}' for key, value in label_dict.items()])
# Real and generated profiles, each drawn at its own I_P value
ax.plot(x, np.full(x.shape, float(real_val)), real_density_profile.detach().numpy(), c='black', label=real_label, linestyle='--', lw=5)
ax.plot(x, np.full(x.shape, float(sample_mp_val)), sample_density_profile.detach().numpy(), c='forestgreen', label='Generated', linestyle='--', lw=5)
# ax.legend(fontsize=5)

mappable = plt.cm.ScalarMappable(cmap=cmap)
mappable.set_array(Y)
cp = fig.colorbar(mappable, shrink=0.5, aspect=5, label='$I_P$ [A]', orientation='horizontal', pad=0.01)

# Mark real / generated I_P on THIS figure's horizontal colorbar (previously drawn on `cb`,
# the colorbar of the earlier 2-D figure, with a vertical-bar layout).
cp.ax.axvline(float(real_val), color='black')
cp.ax.axvline(float(sample_mp_val), color='forestgreen')

# Filled outline of the generated profile; the first vertex is replaced by (0, 0)
verts = [list(zip(x, sample_density_profile.detach().numpy()))]
verts[0][0] = (0, 0)
poly = PolyCollection(verts, facecolors=[mcolors.to_rgba('forestgreen', alpha=0.4)], closed=True)

ax.add_collection3d(poly, zs=[float(sample_mp_val)], zdir='y')

verts = [list(zip(x, real_density_profile.detach().numpy()))]
verts[0][0] = (0, 0)

poly = PolyCollection(verts, facecolors=[mcolors.to_rgba('black', alpha=0.9)], closed=False)

ax.add_collection3d(poly, zs=[float(real_val)], zdir='y')

ax.set(xlabel='R [arb.]', ylabel='$I_P$ [A]')
ax.zaxis.set_rotate_label(False)  # disable automatic rotation
ax.set_zlabel(r'$n_e \; \; [10^{20}$ m$^{-3}]$', rotation=90)
ax.view_init(30, -115)
plt.show()

In [None]:

"""
Temperature plot
"""
from matplotlib.collections import PolyCollection
from mpl_toolkits.mplot3d import Axes3D

fig = plt.figure(figsize=(25, 12), dpi=400)
ax = fig.add_subplot(1, 2, 2, projection='3d')


x = np.arange(0, 63)# np.repeat([np.arange(0, 63), ], resolution, 0)
y = predictied_mp_vals

X, Y = np.meshgrid(x, y)


Z_old = out_profs_all_t.detach().numpy()
Z = np.zeros_like(Z)
k = 0
for t_prof in Z_old: 
    Z[k] = (de_standardize(t_prof, mu_T.detach().numpy(), var_T.detach().numpy()))
    k += 1
norm = Normalize()
colors = norm(Y)

cmap = plt.cm.get_cmap("coolwarm_r")


surface = ax.plot_surface(X, Y, Z, facecolors=cmap(colors), shade=True, alpha=0.75)

mappable = plt.cm.ScalarMappable(cmap=cmap)
mappable.set_array(Y)
fig.colorbar(mappable, shrink=0.5, aspect=5, label='$I_P$ [A]', orientation='horizontal',  pad=0.01)

verts = [list(zip(x, real_temperature_profile.detach().numpy()))]
# verts[0].extend(list(zip(x, np.zeros_like(x))))
verts[0][0] = (0, 0)

poly = PolyCollection(verts, facecolors=[mcolors.to_rgba('black', alpha=0.6)], closed=False)

ax.add_collection3d(poly, zs=[real_val], zdir='y')

verts = [list(zip(x, sample_temperature_profile.detach().numpy()))]
verts[0][0] = (0, 0)
poly = PolyCollection(verts, facecolors=[mcolors.to_rgba('forestgreen', alpha=0.6)], closed=False)

ax.add_collection3d(poly, zs=[sample_mp_val], zdir='y')

ax.set(xlabel='R [arb.]', ylabel='$I_P$ [A]', zlabel='$T_e \; \; [eV]$')
ax.zaxis.set_rotate_label(False)  # disable automatic rotation
ax.plot(x, np.ones_like(x)*real_val.detach().numpy(), real_temperature_profile.detach().numpy(), c='black', label='\n'.join([f'{key}: {value:.4}' for key, value in label_dict.items()]), linestyle='--', lw=5)
ax.plot(x, np.ones_like(x)*real_val.detach().numpy(), sample_temperature_profile.detach().numpy(), c='forestgreen', label='\n'.join([f'{key}: {value:.4}' for key, value in label_dict.items()]), linestyle='--', lw=5)
ax.view_init(19, -105)
plt.show()

In [None]:
# Mean absolute percentage error (%) between the real and the generated temperature profile.
# Radial points where the real value is exactly zero have no defined relative error
# (x/0 -> inf, 0/0 -> nan, and a nan would make the mean nan), so they are excluded
# instead of being replaced by an arbitrary placeholder.
nonzero = real_temperature_profile != 0
abs_err = torch.abs(real_temperature_profile - sample_temperature_profile)
recon = abs_err[nonzero] / torch.abs(real_temperature_profile[nonzero]) * 100
recon.mean()