In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell executes, so edits to local packages are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [103]:
import os
import jax
import timecast as tc
from timecast.utils.experiment import experiment
import pandas as pd
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from keras.models import load_model, Model
from tqdm.notebook import tqdm
import binpacking
import pickle
import sklearn
from timecast.utils.ar import compute_gram, fit_gram, historify

In [6]:
# SECURITY NOTE: allow_pickle=True un-pickles arbitrary Python objects, which can
# execute code while loading. Only use it on archives from a trusted source.
# The archive is opened in a `with` block so the underlying file handle is closed
# as soon as the payload has been read (np.load on an .npz is lazy and otherwise
# keeps the file open for the lifetime of the returned object).
with np.load('data/fusion/FRNN_1d_sample/shot_data.npz', allow_pickle=True) as data:
    # The 'shot_data' entry wraps a single Python object; .item() unwraps it.
    # Later cells index it as shot_data[shot]["X"] -- presumably a dict keyed by shot id.
    shot_data = data['shot_data'].item()

In [7]:
# Shot identifiers for the train / validation / test splits, each stored as a
# .npy array alongside the shot data archive.
train_keys, val_keys, test_keys = (
    np.load(path)
    for path in (
        'data/fusion/FRNN_1d_sample/train_list.npy',
        'data/fusion/FRNN_1d_sample/validation_list.npy',
        'data/fusion/FRNN_1d_sample/test_list.npy',
    )
)

# Number of shots in each split.
print(*(len(keys) for keys in (train_keys, val_keys, test_keys)))

# Names of the 14 zero-dimensional (0D) signals. The position of a name in this
# list is what featurize() uses as the column index into a shot's "X" array.
col_names = [
    'q95',                       # q95 safety factor
    'internal_inductance',       # internal inductance
    'plasma_current',            # plasma current
    'locked_mode_amplitude',     # locked mode amplitude
    'norm_beta',                 # normalized beta
    'stored_energy',             # stored energy
    'plasma_density',            # plasma density
    'radiated_power_core',       # radiated power, core
    'radiated_power_edge',       # radiated power, edge
    'input_power',               # input power (beam for d3d)
    'input_beam_torque',         # input beam torque
    'plasma_current_direction',  # plasma current direction
    'plasma_current_target',     # plasma current target
    'plasma_current_error',      # plasma current error
]

1733 853 862


In [160]:
def featurize(shot_data, keys, col="locked_mode_amplitude", history_len=5, delay=30, ar=False, historify=historify):
    """Yield per-shot ``(x, y, None)`` triples for predicting one signal ``delay`` steps ahead.

    Parameters
    ----------
    shot_data : mapping
        Indexed as ``shot_data[shot]["X"]``; each entry is used as a 2-D array
        (rows are time steps, columns follow the order of ``col_names``).
    keys : iterable
        Shot identifiers to process, in order.
    col : str
        Name of the target signal; must appear in the module-level ``col_names``.
    history_len : int
        Number of consecutive rows handed to ``historify`` per window.
    delay : int
        Non-negative prediction horizon in rows: features at row ``t`` are paired
        with the target at row ``t + delay``.
    ar : bool
        If False (default) the target column is removed from the features;
        if True it is kept (autoregressive features).
    historify : callable or None
        Called as ``historify(x, history_len)``; its result is flattened to
        ``(-1, n_features * history_len)``. Pass None to skip windowing.

    Yields
    ------
    (x, y, None)
        Feature array, target array, and a ``None`` placeholder.

    Raises
    ------
    ValueError
        If ``delay`` is negative (raised when the generator is first advanced).
    IndexError
        If ``col`` is not found in ``col_names``.
    """
    if delay < 0:
        raise ValueError(f"delay must be non-negative, got {delay}")

    # Resolve the column name to its integer position (kept in a separate
    # variable instead of rebinding the string argument `col`).
    col_idx = np.where(col == np.array(col_names))[0][0]

    for shot in keys:
        data = shot_data[shot]["X"]
        # Rows that still have a target `delay` steps later. An explicit count is
        # used instead of data[:-delay] because -0 == 0, so data[:-0] would be
        # empty for delay=0; max() keeps the "shot shorter than delay" case empty.
        n_inputs = max(len(data) - delay, 0)
        x = data[:n_inputs]
        if not ar:
            x = np.delete(x, col_idx, axis=1)
        y = data[delay:, col_idx]
        if historify is not None:
            x = historify(x, history_len).reshape(-1, x.shape[1] * history_len)
            # Drop the first history_len - 1 targets so y lines up with the
            # windows -- assumes historify emits one window per row starting at
            # row history_len - 1 (TODO confirm against timecast.utils.ar).
            y = y[history_len - 1:]
        yield x, y, None

In [161]:
from sklearn.linear_model import Ridge

In [162]:
# Collect the per-shot feature and target arrays, then stack them along the
# first axis into one training matrix and one target vector.
X, Y = [], []
for x, y, _ in tqdm(featurize(shot_data, train_keys), total=len(train_keys)):
    X += [x]
    Y += [y]
X, Y = np.concatenate(X, axis=0), np.concatenate(Y, axis=0)

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [163]:
# Sanity check: X and Y should have the same number of rows (one target per feature row).
X.shape, Y.shape

((8286786, 705), (8286786,))

In [164]:
# Baseline ridge regression fitted on the features as-is (no normalization).
clf = Ridge(alpha=1.0)

In [165]:
# Same ridge model, but asking the estimator to normalize the features itself.
# NOTE(review): the `normalize` argument was deprecated in scikit-learn 1.0 and
# removed in 1.2, where this line raises TypeError -- confirm the pinned
# scikit-learn version, or move the scaling into an explicit preprocessing step.
clf_norm = Ridge(alpha=1.0, normalize=True)

In [166]:
# Fit the un-normalized ridge model on the stacked training data.
clf.fit(X, Y)

Ridge()

In [167]:
# Fit the normalized variant on the same training data.
# NOTE(review): the evaluation loop in this notebook only calls clf.predict;
# presumably clf_norm was swapped in by hand to obtain the normalize=True numbers.
clf_norm.fit(X, Y)

Ridge(normalize=True)

In [170]:
def MSE(true, pred):
    """Mean squared error between two arrays, ignoring singleton dimensions."""
    residual = true.squeeze() - pred.squeeze()
    return np.square(residual).mean()


# Score the un-normalized model shot by shot; one MSE value is stored per test shot.
mses = []
for x, y, _ in tqdm(featurize(shot_data, test_keys), total=len(test_keys)):
    pred = clf.predict(x)
    mses.append(MSE(y, pred))

HBox(children=(FloatProgress(value=0.0, max=862.0), HTML(value='')))




In [171]:
# Results log: mean of the per-shot test MSEs recorded for each configuration tried.
# The value printed below comes from whichever model/history_len the cells above
# were last run with.
# history_len=5, normalize=False: 32.433346
# history_len=5, normalize=True: 32.3551
# history_len=10, normalize=False: 32.50882
# history_len=10, normalize=True: 32.419926
print(np.mean(mses))

32.433346
