In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
import jax
import timecast as tc
from timecast.utils.experiment import experiment
import pandas as pd
import jax.numpy as jnp
import numpy as np
import matplotlib.pyplot as plt
from keras.models import load_model, Model
from tqdm.notebook import tqdm
import binpacking
import pickle

Using TensorFlow backend.


In [40]:
test_keys = np.load('data/fusion/FRNN_1d_sample/test_list.npy')

In [41]:
@experiment("shot", test_keys)
@experiment("history_len", [10, 100, 200])
@experiment("learning_rate", [1e-7, 1e-5, 1e-4])
def runner(shot, history_len, learning_rate):
    import jax
    import pickle
    import jax.numpy as jnp
    
    from timecast.modules import AR
    from timecast.optim import SGD, NormThreshold
    
    MSE = lambda true, pred: jnp.square(jnp.asarray(true) - jnp.asarray(pred)).mean()
    
    ar = AR(history_len, 1, 1)
    
    data = pickle.load(open(f"data/fusion/original/{shot}.pkl", "rb"))
    
    # Locked mode amplitude
    xs, ts, s = data[:-30, 3], data[1:-29, 3], data[30:, 3]
    
    sgd = SGD(learning_rate=learning_rate)
    nl_k = NormThreshold(0.03, filter=lambda x: "kernel" in x)
    nl_b = NormThreshold(1e-4, filter=lambda x: "bias" in x)
    def loop(module, xy):
        x, y = xy
        pred = module(x)
        module = sgd(module, x, y)
        module = nl_k(module)
        module = nl_b(module)

        return module, pred

    ar, ys = jax.lax.scan(loop, ar, (xs, ts))
    
    return {
        "shot": shot,
        "history_len": history_len,
        "learning_rate": learning_rate,
        "mse": MSE(s.squeeze(), ys.squeeze()),
    }

In [42]:
ar_results = runner.run(processes=25, tqdm=tqdm)

HBox(children=(FloatProgress(value=0.0, max=7758.0), HTML(value='')))




In [43]:
pickle.dump(ar_results, open("data/fusion/baseline/ar_results.pkl", "wb"))

In [44]:
ar_df = pd.DataFrame.from_dict(ar_results)
ar_df = ar_df.astype(float)

In [45]:
ar_df.pivot_table(values="mse", index=["history_len"], columns=["learning_rate"])

learning_rate,1.000000e-07,1.000000e-05,1.000000e-04
history_len,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
10.0,31.206599,30.672225,30.660575
100.0,19.655595,18.879948,18.865493
200.0,14.296727,13.469981,13.477321


# All data

In [46]:
@experiment("shot", test_keys)
@experiment("history_len", [200])
@experiment("learning_rate", [1e-6])
def runner(shot, history_len, learning_rate):
    import jax
    import pickle
    import jax.numpy as jnp
    
    from timecast.modules import AR
    from timecast.optim import SGD, NormThreshold
    
    MSE = lambda true, pred: jnp.square(jnp.asarray(true) - jnp.asarray(pred)).mean()
    
    ar = AR(history_len, 142, 1)
    
    data = pickle.load(open(f"data/fusion/original/{shot}.pkl", "rb"))
    
    # Locked mode amplitude
    xs, ts, s = data[:-30], data[1:-29, 3], data[30:, 3]
    
    sgd = SGD(learning_rate=learning_rate)
    nl_k = NormThreshold(0.03, filter=lambda x: "kernel" in x)
    nl_b = NormThreshold(1e-4, filter=lambda x: "bias" in x)
    def loop(module, xy):
        x, y = xy
        pred = module(x)
        module = sgd(module, x, y)
        module = nl_k(module)
        module = nl_b(module)

        return module, pred

    ar, ys = jax.lax.scan(loop, ar, (xs, ts))
    
    return {
        "shot": shot,
        "history_len": history_len,
        "learning_rate": learning_rate,
        "mse": MSE(s.squeeze(), ys.squeeze()),
    }

In [47]:
ar_all_results = runner.run(processes=25, tqdm=tqdm)

HBox(children=(FloatProgress(value=0.0, max=862.0), HTML(value='')))




In [48]:
pickle.dump(ar_all_results, open("data/fusion/baseline/ar_results.pkl", "wb"))
ar_all_df = pd.DataFrame.from_dict(ar_all_results)
ar_all_df = ar_all_df.astype(float)
ar_all_df.pivot_table(values="mse", index=["history_len"], columns=["learning_rate"])

learning_rate,0.000001
history_len,Unnamed: 1_level_1
200.0,5420820.0


In [49]:
np.sum(np.array([shot["mse"] for shot in ar_all_results]) < 10000)

859

In [50]:
[shot for shot in ar_all_results if shot["mse"] > 1000]

[{'shot': 150010,
  'history_len': 200,
  'learning_rate': 1e-06,
  'mse': array(1065.1556, dtype=float32)},
 {'shot': 145049,
  'history_len': 200,
  'learning_rate': 1e-06,
  'mse': array(3754.2212, dtype=float32)},
 {'shot': 147426,
  'history_len': 200,
  'learning_rate': 1e-06,
  'mse': array(31648550., dtype=float32)},
 {'shot': 149064,
  'history_len': 200,
  'learning_rate': 1e-06,
  'mse': array(3.9041275e+09, dtype=float32)},
 {'shot': 149011,
  'history_len': 200,
  'learning_rate': 1e-06,
  'mse': array(7.369613e+08, dtype=float32)},
 {'shot': 150554,
  'history_len': 200,
  'learning_rate': 1e-06,
  'mse': array(1027.2838, dtype=float32)}]