## Creating the training dataset from the equilibria dataset

In [None]:
# Setup cell: standard imports for the dataset-preparation notebook.
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1" # disable GPU
import numpy as np
from time import time, sleep
from os.path import join, exists
# NOTE(review): the star import is assumed to supply every name used below that this notebook
# never defines (e.g. LOCAL, DS_DIR, NIN, NGR, NGZ, NLCFS, RRD, ZZD, USE_CURRENTS, USE_MAGNETIC,
# USE_PROFILES, loadmat, plt, interp_fun, sample_random_subgrid, grid2box, plot_vessel,
# calc_laplace_df_dr_ker, TRAIN_DS_PATH, EVAL_DS_PATH, JOBID) — confirm against utils.
from utils import *
from tqdm import tqdm
print("Preparing data...")

In [None]:
# ---- data source -----------------------------------------------------------
DTYPE = 'float32'  # dtype of every array built and saved by this notebook
# archives: ds0 = small test set (25557 samples), ds6 = full set (895208 samples)
# TMP_DIR is scratch space for the extracted archive; it is deleted once the data is loaded
if LOCAL:
    TAR_GZ_FILE, TMP_DIR = 'dss/ds0.tar.gz', 'tmp'
else:
    TAR_GZ_FILE, TMP_DIR = 'dss/ds6.tar.gz', '/ext/tmp'

# ---- dataset size ----------------------------------------------------------
# N_SAMPLES: equilibria drawn in total (split into train + eval); earlier full runs used 850_000
# SM (sample multiplier): random sub-grids generated per equilibrium; earlier runs used 20
N_SAMPLES, SM = (1000, 1) if LOCAL else (500_000, 5)
TRAIN_EVAL_SPLIT = 0.8  # fraction of the samples used for training

_n_tot = N_SAMPLES*SM
print(f'Total samples: {_n_tot:.0f}, train samples: {_n_tot*TRAIN_EVAL_SPLIT:.0f}, eval samples: {_n_tot*(1-TRAIN_EVAL_SPLIT):.0f}')

In [None]:
# Extract the tar.gz archive into a fresh TMP_DIR (any previous content of TMP_DIR is wiped).
import shutil
import subprocess

# check the archive first, so a missing file does not cost us the existing TMP_DIR
if not exists(TAR_GZ_FILE):
    raise FileNotFoundError(f"File {TAR_GZ_FILE} does not exist!")
if exists(TMP_DIR):
    print(f"Removing {TMP_DIR}...")
    shutil.rmtree(TMP_DIR)
os.makedirs(TMP_DIR)
os.makedirs(DS_DIR, exist_ok=True)
print(f"Extracting {TAR_GZ_FILE} into {TMP_DIR}...")
# list-form argv: no shell involved, paths are passed verbatim
if subprocess.run(['tar', '-xzf', TAR_GZ_FILE, '-C', TMP_DIR]).returncode != 0:
    raise RuntimeError(f"Error extracting {TAR_GZ_FILE} into {TMP_DIR}!")

In [None]:
# Load every .mat equilibrium file from TMP_DIR/ds, validate it, and stack the time slices of all
# files into sample-first arrays. Files that fail to load or validate are reported and skipped
# (best effort), so one corrupt file does not abort the whole run.
import shutil

print("Loading data...")

# list all the .mat files inside TMP_DIR/ds
files = sorted([f for f in os.listdir(f'{TMP_DIR}/ds') if f.endswith('.mat')])
print(f'Found {len(files)} files.')
Fxs, Iys, Ias, Bms, Ufs, rqs, zqs = [], [], [], [], [], [], []
files_iter = tqdm(files, desc="Loading files", unit="file") if LOCAL else files
for f in files_iter:
    try:
        # keys read here: 't', 'Ip', 'Fx', 'Iy', 'Ia', 'Bm', 'Uf', 'rq', 'zq'
        d = loadmat(join(TMP_DIR, 'ds', f))
        t, Ip = d['t'].flatten(), d['Ip'].flatten()  # time and plasma current
        sip = np.sign(np.mean(Ip))  # sign of the plasma current (not used below)
        Fx = d['Fx']  # flux map
        Iy = d['Iy']  # current density map
        Ia = d['Ia']  # coil currents
        Bm = d['Bm']  # magnetic probe measurements
        Uf = d['Uf']  # flux loop poloidal flux
        rq = d['rq']  # LCFS r coordinates
        zq = d['zq']  # LCFS z coordinates

        nt = t.shape[0]  # number of time points; every array has time on its last axis
        arrays = {'Fx': Fx, 'Iy': Iy, 'Ia': Ia, 'Bm': Bm, 'Uf': Uf, 'rq': rq, 'zq': zq}
        expected = {'Fx': (28, 65, nt), 'Iy': (28, 65, nt), 'Ia': (19, nt), 'Bm': (38, nt),
                    'Uf': (38, nt), 'rq': (129, nt), 'zq': (129, nt)}
        # explicit raises (not asserts) so the validation still runs under `python -O`
        for name, arr in arrays.items():
            if arr.shape != expected[name]:
                raise ValueError(f'{name} shape mismatch: {arr.shape} != {expected[name]}')
        for name, arr in arrays.items():
            if np.isnan(arr).any():
                raise ValueError(f'{name} contains NaN values: {f}')
        for name, arr in arrays.items():
            if not np.isfinite(arr).all():
                raise ValueError(f'{name} contains infinite values: {f}')
    except Exception as e:
        print(f'Error loading {f}: {e}')
        continue

    # store DTYPE copies, so the per-file lists do not keep the loaded (possibly wider) arrays alive
    Fxs.append(Fx.astype(DTYPE))
    Iys.append(Iy.astype(DTYPE))
    Ias.append(Ia.astype(DTYPE))
    Bms.append(Bm.astype(DTYPE))
    Ufs.append(Uf.astype(DTYPE))
    rqs.append(rq.astype(DTYPE))
    zqs.append(zq.astype(DTYPE))

print(f'Loaded {len(Fxs)} files.')
assert len(Fxs) > 0, f'No samples: {len(Fxs)}'

# concatenate along time, then move the sample axis first
Fx = np.concatenate(Fxs, axis=-1).transpose(2,1,0)  # flux map, (N, 65, 28)
Iy = np.concatenate(Iys, axis=-1).transpose(2,1,0)  # current density map, (N, 65, 28)
Ia = np.concatenate(Ias, axis=-1).transpose(1,0)  # coil currents, (N, 19)
Bm = np.concatenate(Bms, axis=-1).transpose(1,0)  # magnetic probe measurements, (N, 38)
Uf = np.concatenate(Ufs, axis=-1).transpose(1,0)  # flux loop poloidal flux, (N, 38)
rq = np.concatenate(rqs, axis=-1).transpose(1,0)  # LCFS r coordinates, (N, 129)
zq = np.concatenate(zqs, axis=-1).transpose(1,0)  # LCFS z coordinates, (N, 129)
del Fxs, Iys, Ias, Bms, Ufs, rqs, zqs  # release the per-file copies (largest memory consumer)

# concatenate rq and zq
rzq = np.concatenate([rq, zq], axis=1)  # LCFS r and z coordinates, (N, 2*129)
print(f'rzq shape: {rzq.shape}')

assert Fx.shape[0] > 0, f'No samples: {Fx.shape}'

N_OR = Fx.shape[0]  # number of original samples
print(f'Loaded {N_OR} samples.')

# assemble the network inputs from the enabled signal groups
X = []
if USE_CURRENTS: X.append(Ia)  # coil currents
if USE_MAGNETIC: X.append(Bm)  # magnetic probe measurements
if USE_PROFILES: X.append(Uf)  # flux loop poloidal flux
X = np.concatenate(X, axis=1)  # inputs

assert X.shape == (N_OR, NIN), f'X shape mismatch: {X.shape} != ({N_OR}, {NIN})'
Y1 = Fx
Y2 = Iy  # outputs
Y3 = rzq

# remove the tmp directory (best effort, like `rm -rf`)
print(f"Removing {TMP_DIR}...")
shutil.rmtree(TMP_DIR, ignore_errors=True)
print("Data loaded.")

# check the shapes
print(f'Fx shape: {Fx.shape}, Iy shape: {Iy.shape}, Ia shape: {Ia.shape}, Bm shape: {Bm.shape}, Uf shape: {Uf.shape}, rq shape: {rq.shape}, zq shape: {zq.shape}, rzqs shape: {rzq.shape}')
print(f'X shape: {X.shape}, Y1 shape: {Y1.shape}, Y2 shape: {Y2.shape}, Y3 shape: {Y3.shape}')

# print sizes in MB
print(f'Fx size: {Fx.nbytes / 1024**2:.2f} MB, Iy size: {Iy.nbytes / 1024**2:.2f} MB, Ia size: {Ia.nbytes / 1024**2:.2f} MB, Bm size: {Bm.nbytes / 1024**2:.2f} MB, Uf size: {Uf.nbytes / 1024**2:.2f} MB')
    

In [None]:
# Plot a few random original equilibria: Fx and Iy on the full grid (with the LCFS overlaid),
# plus bar charts of the three input signals Ia, Bm and Uf.
n_plot = 3 if LOCAL else 15
rand_idxs = np.random.randint(0, N_OR, n_plot)


def _plot_original_field(pos, values, lcfs_r, lcfs_z, title):
    """Scatter a full-grid field into panel `pos` of a 1x5 row, with LCFS and vessel outline."""
    plt.subplot(1, 5, pos)
    plt.scatter(RRD, ZZD, c=values, s=4)
    plt.plot(lcfs_r, lcfs_z, 'gray', lw=2)
    plot_vessel()
    plt.axis('equal'), plt.axis('off')
    plt.title(title)
    plt.colorbar()


for i, ri in enumerate(rand_idxs):
    plt.figure(figsize=(16, 3))
    _plot_original_field(1, Y1[ri], rq[ri], zq[ri], 'Fx')
    _plot_original_field(2, Y2[ri], rq[ri], zq[ri], 'Iy')
    for pos, signal, title in ((3, Ia, 'Ia'), (4, Bm, 'Bm'), (5, Uf, 'Uf')):
        plt.subplot(1, 5, pos)
        plt.bar(np.arange(signal.shape[1]), signal[ri])
        plt.title(title)
    plt.tight_layout()
    plt.suptitle(f'SHOT {ri}')
    if LOCAL:
        plt.show()
    else:
        plt.savefig(f'{DS_DIR}/original_{i}.png')
    plt.close()

In [None]:
# Sanity-check the interpolation: sample one random sub-grid for a random equilibrium and show
# the original Fx / Iy fields next to their versions interpolated onto the sub-grid.
idx = np.random.randint(0, N_OR)
f, rhs = Y1[idx,:,:], Y2[idx,:,:]
rrg, zzg = sample_random_subgrid(RRD,ZZD, NGZ, NGR)
print(f.shape, rhs.shape, rrg.shape, zzg.shape)
box = grid2box(rrg, zzg)
f_grid = interp_fun(f, RRD, ZZD, rrg, zzg)  # Y1 is Fx, so f holds the same data as Fx[idx]
rhs_grid = interp_fun(rhs, RRD, ZZD, rrg, zzg)

fig, ax = plt.subplots(1, 5, figsize=(20, 5))
# panel 0: full-grid points and the sampled sub-grid points
ax[0].scatter(RRD, ZZD, marker='.')
ax[0].scatter(rrg, zzg, marker='.')
plot_vessel(ax[0])
ax[0].set_aspect('equal')

# panels 1-4: (r coords, z coords, values, draw the sub-grid box?)
panels = (
    (RRD, ZZD, f, True),
    (rrg, zzg, f_grid, False),
    (RRD, ZZD, rhs, True),
    (rrg, zzg, rhs_grid, False),
)
images = []
for axis, (pr, pz, vals, with_box) in zip(ax[1:], panels):
    images.append(axis.scatter(pr, pz, c=vals.flatten(), s=4))
    if with_box:
        axis.plot(box[:,0], box[:,1])
    plot_vessel(axis)
    axis.set_aspect('equal')
for axis, im in zip(ax[1:], images):
    plt.colorbar(im, ax=axis)

if LOCAL:
    plt.show()
else:
    plt.savefig(join(DS_DIR, 'interpolation_example.png'))
plt.close()

In [None]:
# Split the N_OR original equilibria into disjoint train/eval pools (via one permutation), then
# draw NT training and NE evaluation equilibria from their pools without replacement.
NT = int(N_SAMPLES*TRAIN_EVAL_SPLIT)    # training
NE = N_SAMPLES - NT                     # evaluation
NTM, NEM = NT*SM, NE*SM  # training and evaluation sizes with SM sub-grids per equilibrium
print(f"Train -> NT:{NT} NTM:{NTM}")
print(f"Eval  -> NE:{NE} NEM:{NEM}")
orig_idxs = np.random.permutation(N_OR)
n_orig_train = int(N_OR*TRAIN_EVAL_SPLIT)
orig_idxs_train = orig_idxs[:n_orig_train]  # original indices for training
orig_idxs_eval = orig_idxs[n_orig_train:]  # original indices for evaluation
# drawing without replacement only needs the pool to be at least as large as the draw
assert len(orig_idxs_train) >= NT, f"Training set is too small, {len(orig_idxs_train)} < {NT}"
idxs_t = np.random.choice(orig_idxs_train, NT, replace=False)
assert len(orig_idxs_eval) >= NE, f"Evaluation set is too small, {len(orig_idxs_eval)} < {NE}"
idxs_e = np.random.choice(orig_idxs_eval, NE, replace=False)

In [None]:
# Preallocate the train/eval arrays, then fill them: every selected equilibrium is interpolated
# onto SM random sub-grids, giving SM consecutive rows per equilibrium.
print("Preallocating arrays for the dataset...")

xt =   np.zeros((NTM, NIN), dtype=DTYPE)
y1t =  np.zeros((NTM, NGZ, NGR), dtype=DTYPE)
y2t =  np.zeros((NTM, NGZ, NGR), dtype=DTYPE)
y3t =  np.zeros((NTM, 2*NLCFS), dtype=DTYPE)
rt =   np.zeros((NTM, NGR), dtype=DTYPE)
zt =   np.zeros((NTM, NGZ), dtype=DTYPE)

xe =   np.zeros((NEM, NIN), dtype=DTYPE)
y1e =  np.zeros((NEM, NGZ, NGR), dtype=DTYPE)
y2e =  np.zeros((NEM, NGZ, NGR), dtype=DTYPE)
y3e =  np.zeros((NEM, 2*NLCFS), dtype=DTYPE)
re =   np.zeros((NEM, NGR), dtype=DTYPE)
ze =   np.zeros((NEM, NGZ), dtype=DTYPE)

# estimate RAM usage
ram_usage = sum(arr.nbytes for arr in [xt, y1t, y2t, y3t, rt, zt, xe, y1e, y2e, y3e, re, ze]) / 1024**3
print(f"Estimated RAM usage: {ram_usage:.2f} GB\nFilling arrays...")

print_every = 2000  # progress-report period, in equilibria


def _fill_split(tag, idxs, x, y1, y2, y3, r, z):
    """Fill one split's preallocated arrays in place from the equilibria listed in `idxs`.

    Rows [i*SM, (i+1)*SM) belong to the i-th equilibrium: inputs and LCFS are repeated, Fx/Iy
    are interpolated onto each of SM random sub-grids, and only the first row of r and the
    first column of z of every sub-grid are kept. (Assumes rectilinear sub-grids, consistent
    with the np.meshgrid(r, z) reconstruction used by the checks below.)
    """
    n_eq, n_rows = len(idxs), len(x)
    start = time()
    for i, idx in enumerate(idxs):
        rrs = np.zeros((SM, NGZ, NGR), dtype=DTYPE)
        zzs = np.zeros((SM, NGZ, NGR), dtype=DTYPE)
        for j in range(SM):
            rrs[j], zzs[j] = sample_random_subgrid(RRD, ZZD, NGZ, NGR)
        si, ei = i*SM, (i+1)*SM  # start and end rows of this equilibrium
        x[si:ei] = X[idx]
        y1[si:ei] = interp_fun(Y1[idx], RRD, ZZD, rrs, zzs)
        y2[si:ei] = interp_fun(Y2[idx], RRD, ZZD, rrs, zzs)
        y3[si:ei] = Y3[idx]
        r[si:ei], z[si:ei] = rrs[:,0,:], zzs[:,:,0]  # first row of r / first column of z
        if (i+1) % print_every == 0:
            eta_min = (time()-start)/(i+1)*(n_eq-i-1)/60
            print(f"{tag} -> {100*(i+1)*SM/n_rows:.1f}%, eta: {eta_min:.1f} min")


_fill_split("Train", idxs_t, xt, y1t, y2t, y3t, rt, zt)
_fill_split("Eval", idxs_e, xe, y1e, y2e, y3e, re, ze)

print(f"xt: {xt.shape}, y1t: {y1t.shape}, y2t: {y2t.shape}, y3t: {y3t.shape}, rt: {rt.shape}, zt: {zt.shape}")
print(f"xe: {xe.shape}, y1e: {y1e.shape}, y2e: {y2e.shape}, y3e: {y3e.shape}, re: {re.shape}, ze: {ze.shape}")

In [None]:
# Compute the 3x3 Laplacian and d/dr finite-difference kernels of every sample from its grid
# spacings. NOTE: not needed for the saved dataset (the kernels are not written to disk); kept
# to exercise calc_laplace_df_dr_ker on the generated grids. This is a Python loop over every
# sample, so it can take a while on the full dataset.
print("Calculating kernels...")
# grid spacings from the stored 1-D r and z coordinates of each sample
hrs_e, hzs_e = re[:,2]-re[:,1], ze[:,2]-ze[:,1]
hrs_t, hzs_t = rt[:,2]-rt[:,1], zt[:,2]-zt[:,1]


def _calc_kernels(hrs, hzs, r, z, tag, fig_name):
    """Return (laplace_ker, df_dr_ker) with one 3x3 kernel pair per sample.

    On the first failure the offending sub-grid is plotted and the loop stops; the remaining
    kernels stay zero.
    """
    n = len(hrs)  # one entry per sample (len(xt[0]) would be NIN, the number of inputs)
    laplace_ker = np.zeros((n, 3, 3), dtype=DTYPE)
    df_dr_ker = np.zeros((n, 3, 3), dtype=DTYPE)
    for i in range(n):
        try:
            laplace_ker[i], df_dr_ker[i] = calc_laplace_df_dr_ker(hrs[i], hzs[i])
        except Exception as e:
            print(f"Error calculating laplace_ker_{tag} for index {i}: {e}")
            # r and z may have different lengths (NGR vs NGZ): scatter the full grid instead
            rr, zz = np.meshgrid(r[i], z[i])
            plt.figure()
            plt.scatter(rr, zz, marker='.')
            plt.title(f"r{tag}[{i}]")
            plt.axis('equal')
            plt.show() if LOCAL else plt.savefig(f'{DS_DIR}/{fig_name}_{i}.png')
            plt.close()
            break
    return laplace_ker, df_dr_ker


laplace_ker_t, df_dr_ker_t = _calc_kernels(hrs_t, hzs_t, rt, zt, 't', 'rr_train')
laplace_ker_e, df_dr_ker_e = _calc_kernels(hrs_e, hzs_e, re, ze, 'e', 'rr_eval')

In [None]:
# Visual check of the generated dataset: for random train/eval samples show the sub-grid box
# inside the full grid, Fx with the LCFS overlaid, and Iy on the sub-grid.
print("Checking the dataset...")
rows = 5
# sample indices over all rows (len(xt[0]) / len(xe[0]) would be NIN, the number of inputs)
idxs_train = np.random.randint(0, len(xt), rows)
idxs_eval = np.random.randint(0, len(xe), rows)
fig, ax = plt.subplots(rows, 6, figsize=(20, 4*rows))
box0 = grid2box(RRD, ZZD)


def _plot_sample_row(axs, k, r, z, y1, y2, y3, label):
    """Draw sample `k` into three consecutive axes: sub-grid box, Fx (+LCFS), Iy."""
    boxi = grid2box(r[k], z[k])
    axs[0].plot(box0[:,0], box0[:,1])
    axs[0].plot(boxi[:,0], boxi[:,1])
    axs[0].set_aspect('equal')
    axs[0].set_title(f"{label} {k}")
    a1 = axs[1].contourf(r[k], z[k], y1[k], 20)
    axs[1].plot(box0[:,0], box0[:,1])
    axs[1].plot(y3[k,:NLCFS], y3[k,NLCFS:], 'gray', lw=2)
    axs[1].set_aspect('equal')
    plt.colorbar(a1, ax=axs[1])
    a2 = axs[2].contourf(r[k], z[k], y2[k], 20)
    axs[2].plot(box0[:,0], box0[:,1])
    axs[2].set_aspect('equal')
    plt.colorbar(a2, ax=axs[2])


for i, (it, ie) in enumerate(zip(idxs_train, idxs_eval)):
    _plot_sample_row(ax[i,0:3], it, rt, zt, y1t, y2t, y3t, "Train")
    _plot_sample_row(ax[i,3:6], ie, re, ze, y1e, y2e, y3e, "Eval")
plt.show() if LOCAL else plt.savefig(join(DS_DIR, 'dataset_check.png'))
plt.close()

In [None]:
# Compare the Grad-Shafranov operator from calc_gso (one sample at a time) with calc_gso_batch
# (batched, torch) on random training samples, and against the stored y2t target
# (plotted as "GSO from dataset").
from utils import calc_gso, calc_gso_batch
import torch
print("Checking the Grad-Shafranov operator...")
n_plots = 7
idxs = np.random.randint(0, len(xt), n_plots)  # sample indices (len(xt[0]) would be NIN)
psis, rhss = y1t[idxs], y2t[idxs]
rs, zs = rt[idxs], zt[idxs]
big_box = grid2box(RRD, ZZD)
# batched version
psist = torch.tensor(psis, dtype=torch.float32).view(n_plots, 1, NGZ, NGR)
rst = torch.tensor(rs, dtype=torch.float32).view(n_plots, NGR)
zst = torch.tensor(zs, dtype=torch.float32).view(n_plots, NGZ)
print(f'psi: {psist.shape}, r: {rst.shape}, z: {zst.shape}')
gsos = calc_gso_batch(psist, rst, zst)
print(f'gsos: {gsos.shape}')
gsos = gsos.view(n_plots, NGZ, NGR).numpy()
# single version
for i in range(n_plots):
    psi, r, z, rhs = psis[i], rs[i], zs[i], rhss[i]
    box = grid2box(r, z)
    gso = calc_gso(psi, r, z)  # calculate the Grad-Shafranov operator
    gso2 = gsos[i]
    # plot error gso vs gso2
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    rr, zz = np.meshgrid(r, z)
    im = ax.contourf(rr, zz, np.abs(gso-gso2), 20)
    ax.plot(big_box[:,0], big_box[:,1])
    ax.set_aspect('equal')
    ax.set_title(f"Error batch/no batch {i}")
    plt.colorbar(im, ax=ax)
    plt.show() if LOCAL else plt.savefig(join(DS_DIR, f'gso_error_{i}.png'))
    plt.close()
    # NOTE: the error between the batched and non-batched version can be non-zero due to
    # different implementations in gpu
    print(f"max error batch/no batch: {np.abs(gso-gso2).max()}")
    fig, ax = plt.subplots(1, 5, figsize=(20, 5))
    ax[0].plot(big_box[:,0], big_box[:,1])
    ax[0].plot(box[:,0], box[:,1])
    ax[0].set_aspect('equal')
    ax[0].set_xticks([]), ax[0].set_yticks([])
    ax[0].set_title(f"Train {idxs[i]}")  # index of this sample only, not the whole batch
    im1 = ax[1].contourf(rr, zz, psi, 20)
    ax[1].plot(big_box[:,0], big_box[:,1])
    ax[1].set_aspect('equal')
    ax[1].set_xticks([]), ax[1].set_yticks([])
    ax[1].set_title("Ψ")
    im2 = ax[2].contourf(rr, zz, -gso, 20)
    ax[2].plot(big_box[:,0], big_box[:,1])
    ax[2].set_aspect('equal')
    ax[2].set_xticks([]), ax[2].set_yticks([])
    ax[2].set_title("GSO recalculated")
    im3 = ax[3].contourf(rr, zz, -rhs, 20)
    ax[3].plot(big_box[:,0], big_box[:,1])
    ax[3].set_aspect('equal')
    ax[3].set_xticks([]), ax[3].set_yticks([])
    ax[3].set_title("GSO from dataset")
    im4 = ax[4].contourf(rr, zz, np.abs(gso-rhs), 20)
    ax[4].plot(big_box[:,0], big_box[:,1])
    ax[4].set_aspect('equal')
    ax[4].set_xticks([]), ax[4].set_yticks([])
    ax[4].set_title("Absolute error")
    plt.colorbar(im1, ax=ax[1])
    plt.colorbar(im2, ax=ax[2])
    plt.colorbar(im3, ax=ax[3])
    plt.colorbar(im4, ax=ax[4])
    plt.show() if LOCAL else plt.savefig(join(DS_DIR, f'gso_check_{i}.png'))
    plt.close()

In [None]:
# The target arrays were preallocated as (N, NGZ, NGR) (z rows, r columns): verify all of them.
assert y1t.shape[1:] == (NGZ, NGR), f"y1t shape mismatch: {y1t.shape[1:]} != ({NGZ}, {NGR})"
assert y1e.shape[1:] == (NGZ, NGR), f"y1e shape mismatch: {y1e.shape[1:]} != ({NGZ}, {NGR})"
assert y2t.shape[1:] == (NGZ, NGR), f"y2t shape mismatch: {y2t.shape[1:]} != ({NGZ}, {NGR})"
assert y2e.shape[1:] == (NGZ, NGR), f"y2e shape mismatch: {y2e.shape[1:]} != ({NGZ}, {NGR})"

In [None]:
# Normalize inputs per feature, and each field target (Y1, Y2) with one global mean/std, all
# computed on the training split only. NOTE: y3 (LCFS coordinates) is not normalized.
print("Normalizing the dataset...")
μx, Σx = np.mean(xt, axis=0), np.std(xt, axis=0)
# an exactly-constant input feature has zero std: use 1 so it normalizes to 0 instead of NaN/inf
Σx[Σx == 0] = 1
μy1, Σy1 = np.mean(y1t), np.std(y1t)
μy2, Σy2 = np.mean(y2t), np.std(y2t)

# normalize in place (NOTE: both splits with the training means and stds); in-place ops avoid
# allocating full-size temporaries of the largest arrays
for arr in (xt, xe):
    arr -= μx
    arr /= Σx
for arr in (y1t, y1e):
    arr -= μy1
    arr /= Σy1
for arr in (y2t, y2e):
    arr -= μy2
    arr /= Σy2

print(f'μx: {μx.shape}, Σx: {Σx.shape}')
print(f'μy1: {μy1.shape}, Σy1: {Σy1.shape}')
print(f'μy2: {μy2.shape}, Σy2: {Σy2.shape}')

# stats saved alongside the data: per-feature [means..., stds...] for x, [mean, std] for y1/y2
x_mean_std = np.concatenate([μx, Σx], axis=0)
y1_mean_std = np.array([μy1, Σy1])
y2_mean_std = np.array([μy2, Σy2])

In [None]:
# Write both splits as compressed .npz archives; each archive also carries the (training)
# normalization statistics.
print(f"Saving datasets to {TRAIN_DS_PATH} and {EVAL_DS_PATH}...")
_norm_stats = dict(x_mean_std=x_mean_std, y1_mean_std=y1_mean_std, y2_mean_std=y2_mean_std)
try:
    np.savez_compressed(TRAIN_DS_PATH, X=xt, Y1=y1t, Y2=y2t, Y3=y3t, r=rt, z=zt, **_norm_stats)
    np.savez_compressed(EVAL_DS_PATH, X=xe, Y1=y1e, Y2=y2e, Y3=y3e, r=re, z=ze, **_norm_stats)
except Exception as e:
    print(f"Error saving datasets: {e}")
    raise
print("Datasets saved.")

In [None]:
# Round-trip check: reload the saved archives and plot random training samples from disk.
print("Testing the dataset...")
with np.load(TRAIN_DS_PATH) as tds, np.load(EVAL_DS_PATH) as eds:
    print(f'train_ds: {tds.keys()}')
    print(f'eval_ds: {eds.keys()}')
    # indexing an NpzFile reads the array into memory, so these stay valid after the files close
    rs, zs, xs, y1s, y2s, y3s = tds['r'], tds['z'], tds['X'], tds['Y1'], tds['Y2'], tds['Y3']
print(f'rs shape: {rs.shape}, zs shape: {zs.shape}, xs shape: {xs.shape}, y1s shape: {y1s.shape}, y2s shape: {y2s.shape}, y3s shape: {y3s.shape}')

n_plot = 3 if LOCAL else 100
rand_idxs = np.random.randint(0, len(xs), n_plot)  # any saved row (NTM of them), not just the first NT
for i, ri in enumerate(rand_idxs):
    r, z, x, y1, y2, y3 = rs[ri], zs[ri], xs[ri], y1s[ri], y2s[ri], y3s[ri]
    rr, zz = np.meshgrid(r, z)
    plt.figure(figsize=(16, 4))
    plt.subplot(1, 5, 1)
    plt.scatter(rr, zz, c=y1.flatten(), marker='.')
    plt.plot(y3[:NLCFS], y3[NLCFS:], 'gray', lw=2)
    plot_vessel()
    plt.axis('equal'), plt.axis('off')
    plt.title('Y1')
    plt.colorbar()
    plt.subplot(1, 5, 2)
    plt.scatter(rr, zz, c=y2.flatten(), marker='.')
    plt.plot(y3[:NLCFS], y3[NLCFS:], 'gray', lw=2)
    plot_vessel()
    plt.axis('equal'), plt.axis('off')
    plt.title('Y2')
    plt.colorbar()
    plt.subplot(1, 5, (3,5))
    plt.bar(np.arange(x.shape[0]), x)
    plt.title('X')
    plt.tight_layout()
    plt.suptitle(f'SHOT {ri}')
    plt.show() if LOCAL else plt.savefig(f'{DS_DIR}/ds_{i}.png')
    plt.close()

In [None]:
# Report the on-disk size of both archives, confirm they exist, then signal completion.
print('Done! Space used:')
_ds_paths = (TRAIN_DS_PATH, EVAL_DS_PATH)
for _p in _ds_paths:
    os.system(f'du -h {_p}')
for _p in _ds_paths:
    assert os.path.exists(_p), f"Dataset not saved: {_p}"
print(f"{JOBID} done")
if not LOCAL:
    sleep(30)  # give the cluster filesystem time to update the files