## Creating the training dataset from the equilibria dataset

In [None]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"]="-1" # disable GPU
import numpy as np
from time import time, sleep
from os.path import join, exists
from utils import *
from tqdm import tqdm
print("Preparing data...")

In [None]:
# Run configuration: a small setup for local experiments, the full-size one otherwise.
if LOCAL:
    RAW_DIR = 'raw_ds'  # where the raw data is stored
    N_SAMPLES = 1000  # number of samples to use for training
    SM = 1000  # number of points per sample (SM = SAMPLE MULTIPLIER)
else:
    RAW_DIR = '/nfsd/automatica/grandinmat/raw_ds'  # where the raw data is stored
    N_SAMPLES = 200_000  # number of samples to use for training
    SM = 3_000  # number of points per sample (SM = SAMPLE MULTIPLIER)

# fraction of the dataset to use for training (the rest is used for evaluation)
TRAIN_EVAL_SPLIT = 0.8

print(f'Total samples: {N_SAMPLES:.0f}, train samples: {N_SAMPLES*TRAIN_EVAL_SPLIT:.0f}, eval samples: {N_SAMPLES*(1-TRAIN_EVAL_SPLIT):.0f}')

In [None]:
# Fail early when the raw-data folder is missing or contains nothing.
assert os.path.exists(RAW_DIR), f'RAW_DIR {RAW_DIR} does not exist'
raw_dir_entries = os.listdir(RAW_DIR)
assert raw_dir_entries, f'RAW_DIR {RAW_DIR} is empty'

In [None]:
# Load every .mat file in RAW_DIR, validate its variables, and build:
#   X -> the network inputs (all INPUT_NAMES variables concatenated per sample)
#   Y -> a dict with the target maps (FX, IY, BR, BZ) and the RQ / ZQ curves
print("Loading data...")

# list all the .mat files inside RAW_DIR (sorted, so the load order is reproducible)
files = sorted([f for f in os.listdir(RAW_DIR) if f.endswith('.mat')])
print(f'Found {len(files)} files.')
vals = {n:[] for n in DS_NAMES}  # dictionary to hold the numpy arrays
files_iter = tqdm(files, desc="Loading files", unit="file") if LOCAL else files
for f in files_iter:
    try:
        d = loadmat(join(RAW_DIR, f))
        for n in DS_NAMES: assert n in d.keys(), f'Key {n} not found in {f}'
        t = d['t'].flatten()  # time
        nt = t.shape[0]  # number of time points
        # Stage this file's arrays first: if any variable fails validation the
        # whole file must be dropped, otherwise the per-variable lists would end
        # up with different lengths and samples would be misaligned.
        file_vals = {}
        for n in DS_NAMES:
            if n in [BR, BZ]: continue  # skip BR and BZ for now, they will be calculated later
            v = d[n]  # get the variable
            v = v.reshape((*DS_SIZES[n], nt))  # raises if the element count does not fit the expected shape
            assert v.shape == (*DS_SIZES[n], nt), f'Variable {n} shape mismatch: {v.shape} != {((*DS_SIZES[n], nt))}'
            assert not np.isnan(v).any(), f'Variable {n} contains NaN values in {f}'
            assert np.isfinite(v).all(), f'Variable {n} contains infinite values in {f}'
            file_vals[n] = v
    except Exception as e:
        # best effort: report the bad file and move on to the next one
        print(f'Error loading {f}: {e}')
        continue
    # the file is fully valid: commit all of its variables together
    for n, v in file_vals.items():
        vals[n].append(v)

print(f'Loaded {len(vals[FX])} files.')
assert len(vals[FX]) > 0, f'No samples: {len(vals[FX])}'

# convert lists to numpy arrays
print('---------------------------------------------------')
for n in DS_NAMES:
    if n in [BR, BZ]: continue  # skip BR and BZ for now, they will be calculated later
    vals[n] = np.concatenate(vals[n], axis=-1).astype(DTYPE)  # concatenate along the last (time) axis and convert to DTYPE
    vals[n] = np.moveaxis(vals[n], -1, 0)  # move the last (time/sample) axis to the front
    # print(f'{n} -> shape: {vals[n].shape}, dtype: {vals[n].dtype}')  # print the shape and dtype of the variable
print('---------------------------------------------------')

# calculate BR and BZ from Fx
print('Calculating BR and BZ...')
vals[BR], vals[BZ] = meqBrBz(vals[FX])

for n in DS_NAMES: print(f'{n} -> shape: {vals[n].shape}, dtype: {vals[n].dtype}')  # print the shape and dtype of the variable

for n in INPUT_NAMES: assert vals[n].ndim == 2, f'Variable {n} is not 2D: {vals[n].ndim}'

# build the input matrix: one row per sample, input variables side by side
X = np.concatenate([vals[n] for n in INPUT_NAMES], axis=1)  # concatenate along the second (feature) axis
print(f'-----------------------------------------\n X -> shape: {X.shape}, dtype: {X.dtype}')

# make IY the same shape as FX: IY is one cell smaller on every side of the grid,
# so replicate its border values outwards (sides and corners) by one cell
padded_iy = np.pad(vals[IY], ((0, 0), (1, 1), (1, 1)), mode='edge')
assert padded_iy.shape == vals[FX].shape, f'Padded IY shape mismatch: {padded_iy.shape} != {vals[FX].shape}'
vals[IY] = padded_iy  # replace IY with the padded array

# collect the targets
Y = {
    FX:vals[FX],
    IY:vals[IY],
    BR:vals[BR],
    BZ:vals[BZ],
    RQ:vals[RQ],
    ZQ:vals[ZQ],
}

assert Y[FX].shape[0] > 0, f'No samples: {Y[FX].shape}'
N_OR = Y[FX].shape[0]  # number of original samples
print(f'Loaded {N_OR} samples.')

assert X.dtype == DTYPE, f'X dtype mismatch: {X.dtype} != {DTYPE}'
assert all(y.dtype == DTYPE for y in Y.values()), f'Y dtype mismatch: {[y.dtype for y in Y.values()]} != {DTYPE}'

assert X.shape == (N_OR, NIN), f'X shape mismatch: {X.shape} != ({N_OR}, {NIN})'

# check the shapes
print(f'X -> {X.shape}')
print(f'Y -> {[y.shape for y in Y.values()]}')

# print sizes in MB
print(f'X size: {X.nbytes / 1024**2:.2f} MB, Y size: {sum([y.nbytes for y in Y.values()]) / 1024**2:.2f} MB')

In [None]:
# Visual sanity check: draw a few randomly chosen original samples.
# Top row: the four 2D maps with the RQ/ZQ curve overlaid; bottom row: bar charts of BM, FF, IA, IU.
n_plot = 3 if LOCAL else 15
rand_idxs = np.random.randint(0, N_OR, n_plot)
map_panels = [(FX, 'FX'), (IY, 'IY'), (BR, 'BR'), (BZ, 'BZ')]
bar_panels = [(BM, 'BM'), (FF, 'FF'), (IA, 'IA'), (IU, 'IU')]
for i, ri in enumerate(rand_idxs):
    plt.figure(figsize=(15, 8))
    # subplots 1..4: grid maps
    for panel_pos, (field_key, panel_title) in enumerate(map_panels, start=1):
        plt.subplot(2, 4, panel_pos)
        plt.scatter(RRD, ZZD, c=vals[field_key][ri], s=4)
        plt.title(panel_title)
        plt.plot(vals[RQ][ri], vals[ZQ][ri], 'gray', lw=2)
        plot_vessel()
        plt.axis('equal')
        plt.axis('off')
        plt.colorbar()
    # subplots 5..8: one bar per entry of the vector
    for panel_pos, (field_key, panel_title) in enumerate(bar_panels, start=5):
        plt.subplot(2, 4, panel_pos)
        plt.bar(np.arange(vals[field_key].shape[1]), vals[field_key][ri])
        plt.title(panel_title)
    plt.tight_layout()
    plt.suptitle(f'SHOT {ri}')
    # show interactively when running locally, write to disk otherwise
    if LOCAL:
        plt.show()
    else:
        plt.savefig(f'{DS_DIR}/imgs/original_{i}.png')
    plt.close()
del vals  # free memory

In [None]:
# Interpolation check on one random sample: each grid map (top row) is shown
# above its values interpolated at randomly sampled points (bottom row).
idx = np.random.randint(0, N_OR)
pts = sample_random_points(2000)
interp_panels = [(FX, 'FX'), (IY, 'IY'), (BR, 'BR'), (BZ, 'BZ')]
# interpolate every map at the same points before plotting
interp_maps = {field_key: interp_pts(Y[field_key][idx], pts) for field_key, _ in interp_panels}

plt.figure(figsize=(15, 8))
ms = 2  # marker size
for panel_col, (field_key, panel_title) in enumerate(interp_panels, start=1):
    # top row: original values on the grid, with the RQ/ZQ curve overlaid
    plt.subplot(2, 4, panel_col)
    plt.scatter(RRD, ZZD, c=Y[field_key][idx], s=ms)
    plt.title(panel_title)
    plt.plot(Y[RQ][idx], Y[ZQ][idx], 'gray', lw=2)
    plot_vessel()
    plt.axis('equal')
    plt.axis('off')
    plt.colorbar()
    # bottom row: interpolated values at the sampled points
    plt.subplot(2, 4, panel_col + 4)
    plt.scatter(pts[:,0], pts[:,1], c=interp_maps[field_key], s=ms)
    plt.title(f'Interpolated {panel_title}')
    plot_vessel()
    plt.axis('equal')
    plt.axis('off')
    plt.colorbar()
plt.suptitle(f'SHOT {idx} - Interpolation Example')
plt.tight_layout()
if LOCAL:
    plt.show()
else:
    plt.savefig(join(DS_DIR, 'imgs', 'interpolation_example.png'))
plt.close()

In [None]:
# dataset splitting (N_OR = number of original samples)
# The original samples are first partitioned into disjoint train / eval pools,
# then NT and NE indices are drawn (without repetition) from the respective pool,
# so no original sample can appear in both sets.
NT = int(N_SAMPLES*TRAIN_EVAL_SPLIT)    # training
NE = N_SAMPLES - NT                     # evaluation
print(f"Train -> NT:{NT}")
print(f"Eval  -> NE:{NE}")
orig_idxs = np.random.permutation(N_OR)
n_train_pool = int(N_OR*TRAIN_EVAL_SPLIT)  # size of the training pool
orig_idxs_train = orig_idxs[:n_train_pool] # original indices for training
orig_idxs_eval = orig_idxs[n_train_pool:] # original indices for evaluation
# splitting the idxs
# sampling without replacement only needs the pool to be at least as large as
# the requested number of samples, so an exactly matching pool is accepted
assert len(orig_idxs_train) >= NT, f"Training set is too small, {len(orig_idxs_train)} < {NT}"
idxs_t = np.random.choice(orig_idxs_train, NT, replace=False)
assert len(orig_idxs_eval) >= NE, f"Evaluation set is too small, {len(orig_idxs_eval)} < {NE}"
idxs_e = np.random.choice(orig_idxs_eval, NE, replace=False)

In [None]:
# Build the train (vt) and eval (ve) datasets: for every selected original sample
# store its physical inputs, the RQ/ZQ curve, SM randomly sampled points and the
# four maps interpolated at those points. The eval set also keeps the full grid maps.
print("Preallocating arrays for the dataset...")

vt = {
    PHYS:   np.zeros((NT, NIN), dtype=DTYPE),
    PTS:    np.zeros((NT, SM, 2), dtype=DTYPE),
    FX:     np.zeros((NT, SM), dtype=DTYPE),
    IY:     np.zeros((NT, SM), dtype=DTYPE),
    BR:     np.zeros((NT, SM), dtype=DTYPE),
    BZ:     np.zeros((NT, SM), dtype=DTYPE),
    SEP:    np.zeros((NT, 2*NLCFS), dtype=DTYPE),
}
ve = {
    PHYS:   np.zeros((NE, NIN), dtype=DTYPE),
    PTS:    np.zeros((NE, SM, 2), dtype=DTYPE),
    FX:     np.zeros((NE, SM), dtype=DTYPE),
    IY:     np.zeros((NE, SM), dtype=DTYPE),
    BR:     np.zeros((NE, SM), dtype=DTYPE),
    BZ:     np.zeros((NE, SM), dtype=DTYPE),
    SEP:    np.zeros((NE, 2*NLCFS), dtype=DTYPE),
    FG:     np.zeros((NE, 4, 65, 28), dtype=DTYPE), # store the full grid maps for evaluation
}

# estimate RAM usage (sum train and eval bytes first, then convert the total to GB)
total_bytes = sum(arr.nbytes for arr in vt.values()) + sum(arr.nbytes for arr in ve.values())
print(f"Estimated RAM usage: {total_bytes / (1024**3):.2f} GB")
print("Filling arrays...")

## fill the arrays
print_every = 2000  # progress is reported every print_every samples
interp_fields = (FX, IY, BR, BZ)  # maps that are interpolated at the sampled points
start_time = time()
for i, idx in enumerate(idxs_t):
    vt[PHYS][i] = X[idx] # physical inputs
    vt[SEP][i] = np.concatenate([Y[RQ][idx], Y[ZQ][idx]], axis=0)  # LCFS/separatrix points
    pts = sample_random_points(SM)  # sample random points
    vt[PTS][i] = pts
    for field_key in interp_fields:
        vt[field_key][i] = interp_pts(Y[field_key][idx], pts)
    if (i+1) % print_every == 0: print(f"Train -> {100*(i+1)/NT:.1f}%, eta: {((time()-start_time)/(i+1)*(NT-i-1))/60:.1f} min")

start_time = time()
for i, idx in enumerate(idxs_e):
    ve[PHYS][i] = X[idx] # physical inputs
    ve[SEP][i] = np.concatenate([Y[RQ][idx], Y[ZQ][idx]], axis=0)  # LCFS/separatrix points
    pts = sample_random_points(SM)  # sample random points
    ve[PTS][i] = pts
    for field_key in interp_fields:
        ve[field_key][i] = interp_pts(Y[field_key][idx], pts)
    ve[FG][i] = (Y[FX][idx], Y[IY][idx], Y[BR][idx], Y[BZ][idx])  # store the full maps for evaluation
    if (i+1) % print_every == 0: print(f"Eval -> {100*(i+1)/NE:.1f}%, eta: {((time()-start_time)/(i+1)*(NE-i-1))/60:.1f} min")
# print shapes
print(f"Train -> PHYS:{vt[PHYS].shape}, PTS:{vt[PTS].shape}, FX:{vt[FX].shape}, IY:{vt[IY].shape}, BR:{vt[BR].shape}, BZ:{vt[BZ].shape}, SEP:{vt[SEP].shape}")
print(f"Eval  -> PHYS:{ve[PHYS].shape}, PTS:{ve[PTS].shape}, FX:{ve[FX].shape}, IY:{ve[IY].shape}, BR:{ve[BR].shape}, BZ:{ve[BZ].shape}, SEP:{ve[SEP].shape}, FG:{ve[FG].shape}")

In [None]:
# Compute per-feature mean and standard deviation of the physical inputs,
# using the training set only. The inputs themselves are left untouched here:
# the statistics are only packed into x_mean_std (row 0 = mean, row 1 = std).
print("Getting normalization constants for the dataset...")
phys_train = vt[PHYS]
μx = phys_train.mean(axis=0)  # mean of the physical inputs
Σx = phys_train.std(axis=0)   # std of the physical inputs
# normalize (NOTE: both with the training means and stds)
# xt, xe = (xt - μx) / Σx, (xe - μx) / Σx # not normalizing the inputs bc we added a normalization layer in the network architecture
print(f'μx: {μx.shape}, Σx: {Σx.shape}')
x_mean_std = np.stack([μx, Σx])
print(f'x_mean_std: {x_mean_std.shape}')

In [None]:
# Write the train and eval datasets as compressed .npz archives; each archive
# also carries the normalization constants under the 'x_mean_std' key.
print(f"Saving datasets to {TRAIN_DS_PATH} and {EVAL_DS_PATH}...")
try:
    for ds_path, ds_arrays in ((TRAIN_DS_PATH, vt), (EVAL_DS_PATH, ve)):
        np.savez_compressed(ds_path, **ds_arrays, x_mean_std=x_mean_std)
    print(f"Datasets saved to {TRAIN_DS_PATH} and {EVAL_DS_PATH}")
except Exception as e:
    # report the failure, then let it propagate
    print(f"Error saving datasets: {e}")
    raise e

In [None]:
# Export a small demo subset of the training set (100 randomly drawn rows,
# possibly with repetitions) in both .mat and .npz format, for testing.
print(f"Saving demo file to {DS_DIR}/demo.mat and {DS_DIR}/demo.npz")
rand_idxs = np.random.randint(0, len(vt[PHYS]), 100)  # random indices for the demo
demo_keys = (PHYS, PTS, FX, IY, BR, BZ, SEP)
demo = {demo_key: vt[demo_key][rand_idxs] for demo_key in demo_keys}
try:
    savemat(join(DS_DIR, 'demo.mat'), demo)
    np.savez_compressed(join(DS_DIR, 'demo.npz'), **demo)
except Exception as e:
    # report the failure, then let it propagate
    print(f"Error saving demo file: {e}")
    raise e

In [None]:
# Reload the saved archives and plot the first few samples of each, as a check
# that what was written to disk can be read back and looks sensible.
print("Testing the dataset...")
tds, eds = np.load(TRAIN_DS_PATH), np.load(EVAL_DS_PATH)
print(f'train_ds: {tds.keys()}')
print(f'eval_ds: {eds.keys()}')
# plot some examples
n_plot = 3 if LOCAL else 15
plot_panels = [(FX, 'FX'), (IY, 'IY'), (BR, 'BR'), (BZ, 'BZ')]
for ds, ds_name in zip([tds, eds], ['Train', 'Eval']):
    # An .npz archive is read lazily: every ds[key] access loads and decompresses
    # the whole array again. Read each array once per dataset and keep only the
    # rows that get plotted (copied, so the full array can be released).
    head = {k: ds[k][:n_plot].copy() for k in (PTS, PHYS, SEP, FX, IY, BR, BZ)}
    x_mean, x_std = ds['x_mean_std']  # row 0 = mean, row 1 = std
    for i in range(n_plot):
        r, z = head[PTS][i, :, 0], head[PTS][i, :, 1]  # get the points
        x = (head[PHYS][i]-x_mean) / x_std  # normalize the physical inputs
        sep_r, sep_z = head[SEP][i, :NLCFS], head[SEP][i, NLCFS:]  # first half / second half of SEP
        plt.figure(figsize=(15, 4))
        for panel_pos, (field_key, panel_title) in enumerate(plot_panels, start=1):
            plt.subplot(1, 5, panel_pos)
            plt.scatter(r, z, c=head[field_key][i], s=4)
            plt.title(panel_title)
            plt.plot(sep_r, sep_z, 'gray', lw=2)
            plot_vessel()
            plt.axis('equal')
            plt.axis('off')
            plt.colorbar()
        plt.subplot(1, 5, 5)
        plt.bar(np.arange(len(x)), x)
        plt.title('Physical Inputs')
        plt.suptitle(f'SHOT {i} - {ds_name} Example')
        plt.tight_layout()
        # show interactively when running locally, write to disk otherwise
        if LOCAL:
            plt.show()
        else:
            plt.savefig(join(DS_DIR, 'imgs', f'{ds_name.lower()}_example_{i}.png'))
        plt.close()

In [None]:
# Final report: print the on-disk size of both datasets, make sure both files
# exist, and announce that the job has finished.
print('Done! Space used:')
os.system(f'du -h {TRAIN_DS_PATH}')
os.system(f'du -h {EVAL_DS_PATH}')
train_saved = exists(TRAIN_DS_PATH)
assert train_saved, f"Dataset not saved: {TRAIN_DS_PATH}"
eval_saved = exists(EVAL_DS_PATH)
assert eval_saved, f"Dataset not saved: {EVAL_DS_PATH}"
print(f"{JOBID} done")
if not LOCAL:
    sleep(30)  # wait for files to update (for cluster)