In [None]:
%matplotlib widget
import math, numpy as np
from typing import List, Tuple, Mapping
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from argparse import Namespace
import tmodel
tf.config.set_visible_devices([], 'GPU')             # disable GPU for this notebook

# Run configuration, as returned by the project helper tmodel.load_args().
args: Namespace = tmodel.load_args()
signal_index = args.signal
expt_index = args.experiment
nepochs = args.nepochs
nfeatures = args.nfeatures
batch_size = args.batch_size
learning_rate = args.learning_rate

# Fraction of samples, taken from the start of the series, used for training;
# the remainder is held out for validation (no shuffling: the split is by position).
train_fraction: float = 0.8

data = tmodel.get_demo_data()
signals = data['signals']
times = data['times']
ckp_file = tmodel.get_ckp_file( expt_index, signal_index )

T: np.ndarray = times[signal_index]
X: np.ndarray = tmodel.get_features( T, expt_index, nfeatures )
Y: np.ndarray = signals[signal_index]
# Later cells slice X, Y and T with the same index and plot them against each other,
# so they must all have the same number of samples along axis 0.
assert len(X) == len(Y) == len(T), f"sample count mismatch: X={len(X)}, Y={len(Y)}, T={len(T)}"

# Index of the first validation sample (the name is kept for the cells below).
validation_split: int = int(train_fraction*X.shape[0])

Xtrain=X[:validation_split]
Xval=X[validation_split:]
Ytrain=Y[:validation_split]
Yval=Y[validation_split:]

In [None]:
# Build the small model sized to the feature count (second argument 0.5 is passed
# through to tmodel.create_small_model as-is — its meaning is defined there),
# then compile it with Adam and a mean-absolute-error loss.
small_model = tmodel.create_small_model(
    X.shape[1],
    0.5,
)
small_model.compile(
    loss='mae',
    optimizer=Adam(learning_rate=learning_rate),
)

In [None]:
# Restore previously saved weights, then predict separately on the training
# and validation portions of the feature matrix.
small_model.load_weights(ckp_file)
print(f"Loaded checkpoint from '{ckp_file}'")
p0: np.ndarray
p1: np.ndarray
p0, p1 = [small_model.predict(split_inputs, batch_size=batch_size) for split_inputs in (Xtrain, Xval)]

In [None]:
# Overlay the ground-truth signal with the model's predictions on the training
# and validation portions (split at index `validation_split`).
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(T, Y, label='truth')
# p0/p1 are indexed as 2-D arrays here; column 0 is the plotted prediction.
ax.plot(T[:validation_split], p0[:, 0], label='train prediction')
ax.plot(T[validation_split:], p1[:, 0], label='val prediction')
ax.set(
    title=f'Signal {signal_index} (e-{expt_index}): nfeatures={nfeatures}',
    xlabel='time',
    ylabel='signal',
)
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# Interactive view of every feature column of X against time; the slider picks
# which feature tmodel.select_feature should emphasise.
fig = plt.figure( figsize=(22,8) )
# Axes rectangles are [left, bottom, width, height] in figure fractions. Leave
# margins so the y tick labels, the title and the slider's left-hand label are
# not clipped at the figure edge, keep the slider clear of the plot's x tick
# labels, and reserve a strip on the right for the figure legend.
pax = fig.add_axes([0.04, 0.2, 0.84, 0.72])
sax = fig.add_axes([0.1, 0.05, 0.7, 0.04])
pax.set_title( f'Features (type {expt_index}): Nf={nfeatures}' )

initial_feature = 0   # must agree with the slider's valinit below

plots: List[plt.Line2D] = []
for iF in range( X.shape[1] ):
    line, = pax.plot( T, X[:,iF], label=f'Feature {iF}', alpha=tmodel.alpha(iF,initial_feature) )
    plots.append( line )

# Feature indices run 0 .. len(plots)-1, so that is the slider's inclusive range.
# Keep a reference to the Slider (`widget`) so it stays responsive.
widget = Slider( sax, "Feature", 0, len(plots) - 1, valinit=initial_feature, valstep=1 )
widget.on_changed( lambda sval: tmodel.select_feature(plots,fig,sval) )

fig.legend( loc='center right' );