In [None]:
%matplotlib widget
import numpy as np, os, time
import tensorflow as tf
from tensorflow import keras
from argparse import Namespace
from typing import List
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import Adam
import tmodel
# Parse the run configuration for this notebook (via tmodel.load_args).
args: Namespace = tmodel.load_args()
# Hide every GPU from TensorFlow so this notebook runs on CPU only.
tf.config.set_visible_devices([], 'GPU')

# Shorthands for the two options used throughout the rest of the notebook.
signal_index, feature_type = args.signal, args.feature_type

In [None]:
# Load the demo data, build the feature matrix for the selected signal, split it
# chronologically into a training head and a validation tail, and build/compile the model.
TRAIN_FRACTION = 0.8   # fraction of samples (from the start) used for training

data=tmodel.get_demo_data()
signals = data['signals']
T: np.ndarray = data['times'][signal_index]
X: np.ndarray = tmodel.get_features( T, feature_type, args )
Y: np.ndarray = signals[signal_index]
# Later cells plot T against X and Y and slice all three with the same index,
# so they must agree in length along the first axis.
assert len(T) == X.shape[0] == len(Y), (
    f"length mismatch: times={len(T)}, features={X.shape[0]}, targets={len(Y)}" )

# NOTE: despite its name, `validation_split` is the row index where the validation part begins.
validation_split = int(TRAIN_FRACTION*X.shape[0])
Xtrain=X[:validation_split]
Xval=X[validation_split:]
Ytrain=Y[:validation_split]
Yval=Y[validation_split:]

strategy = tf.distribute.MirroredStrategy()
print(f"Number of devices: {strategy.num_replicas_in_sync}")
with strategy.scope():
    model = tmodel.create_streams_model( X.shape[1], dropout_frac=args.dropout_frac, n_streams=args.nstreams )
    # `Adam` is imported at the top of the notebook; use it directly for consistency.
    model.compile( optimizer=Adam( learning_rate=args.learning_rate ), loss=args.loss )

In [None]:
# Restore the trained weights from the checkpoint file for this configuration,
# then predict on the training split first and the validation split second.
ckp_file = tmodel.get_ckp_file( args )
model.load_weights(ckp_file)
p0, p1 = ( model.predict( split, batch_size=args.batch_size ) for split in (Xtrain, Xval) )

In [None]:
# Overlay the true signal with the model's predictions: the training prediction covers
# samples [0, validation_split), the validation prediction covers the remainder.
rfig, rax = plt.subplots(figsize=(15,5))
rax.plot(T, Y, label='truth')
rax.plot(T[:validation_split], p0[:,0], label='train prediction')
rax.plot(T[validation_split:], p1[:,0], label='val prediction')
rax.set_title(f'Signal {signal_index} (ftype={feature_type}): nfeatures={args.nfeatures}')
rax.set_xlabel('time')
rax.set_ylabel('signal')
# Axes-level legend keeps the legend attached to (and inside) the plotted axes.
rax.legend()
plt.show()

In [None]:
# Interactive feature browser: a slider selects which column of the feature matrix X
# is drawn against the time axis T.
ffig, pax = plt.subplots( figsize=(22,8) )
nf = X.shape[-1]
pax.set_title( f'Features (type {feature_type}): Nf={nf} ' )
pax.set_xlabel( 'time' )
plot: plt.Line2D = pax.plot( T, X[:,0], label='Feature 0' )[0]
pax.legend()

ffig.subplots_adjust(bottom=0.25)
sax = ffig.add_axes([0.25, 0.1, 0.65, 0.03])
# Valid column indices are 0 .. nf-1, so the slider must stop at nf-1
# (valmax=nf would allow selecting a non-existent column and raise IndexError).
widget = Slider( ax=sax, label="Feature", valmin=0, valmax=max(nf-1, 0), valinit=0, valstep=1 )

def update(val):
    """Redraw the line with the feature column selected by the slider.

    val: the new slider value passed by Slider.on_changed; it is truncated to an
    int and clamped into [0, nf-1] before indexing X.
    """
    iF = min(max(int(val), 0), nf-1)
    plot.set_ydata(X[:,iF])
    plot.set_label(f'Feature {iF}')
    # Features can have very different ranges; recompute the data limits so the
    # new curve is not clipped or flattened by the previous feature's y-range.
    pax.relim()
    pax.autoscale_view()
    # Axes.legend replaces the previous legend, so the label stays in sync.
    pax.legend()
    ffig.canvas.draw_idle()

widget.on_changed(update)
plt.show()