In [None]:
%matplotlib widget
import math, numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import Adam
from argparse import Namespace
import tmodel
# Hide every GPU from TensorFlow so this notebook runs on the CPU only.
tf.config.set_visible_devices([], 'GPU')

# --- Run configuration -------------------------------------------------------
# Options come from tmodel.load_args(); batch_size is fixed for this notebook.
args: Namespace = tmodel.load_args()
signal_index = args.signal
expt_index = args.experiment
nepochs = args.nepochs
nfeatures = args.nfeatures
batch_size = 256

# --- Demo data and checkpoint location ---------------------------------------
data = tmodel.get_demo_data()
signals, times = data['signals'], data['times']
ckp_file = tmodel.get_ckp_file(expt_index, signal_index)

# --- Time axis (T), model inputs (X) and targets (Y) for the chosen signal ---
T: np.ndarray = times[signal_index]
X: np.ndarray = tmodel.get_features(T, expt_index, nfeatures)
Y: np.ndarray = signals[signal_index]
# Target normalisation is currently switched off:
# Y = tmodel.tnorm(Y)

# Row index where the training part ends: the first 80% of the rows of X are
# used for training and the remainder for validation (no shuffling).
validation_split: int = int(0.8 * X.shape[0])

Xtrain, Xval = X[:validation_split], X[validation_split:]
Ytrain, Yval = Y[:validation_split], Y[validation_split:]

In [None]:
tmodel = tmodel.create_small_model(X.shape[1],0.5)
tmodel.compile(optimizer='rmsprop', loss='mae')

In [None]:
tmodel.load_weights( ckp_file )
print( f"Loaded checkpoint from '{ckp_file}'")
p0=tmodel.predict(Xtrain,batch_size=batch_size)
p1=tmodel.predict(Xval,batch_size=batch_size)

In [None]:
# Compare the true signal with the model output over the whole time axis.
# The training and validation predictions are drawn as separate curves, each
# over its own slice of T; column 0 of each prediction array is plotted.
fig, ax = plt.subplots(figsize=(15, 5))
ax.plot(T, Y, label='truth')
ax.plot(T[:validation_split], p0[:, 0], label='train prediction')
ax.plot(T[validation_split:], p1[:, 0], label='val prediction')
ax.set_title(f'Signal {signal_index} (e-{expt_index}): nfeatures={nfeatures}')
# Label the axes so the figure can be read without the surrounding code.
# NOTE(review): units are not visible in this notebook -- add them if known.
ax.set_xlabel('time')
ax.set_ylabel('signal')
ax.legend()
fig.tight_layout()
plt.show()