In [None]:
%matplotlib widget
import time, numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.callbacks import ModelCheckpoint
from tensorflow.keras.optimizers import Adam
tf.config.set_visible_devices([], 'GPU')             # disable GPU for this notebook

# --- Experiment selection ---------------------------------------------------
signal_index=2    # row selected from the signals / times / binary_times arrays
expt_index=2      # only used here to build the checkpoint file name

# Fraction of the samples (taken from the start of the series) assigned to the
# training partition; the remainder becomes the validation partition.
TRAIN_FRACTION = 0.8

# --- Paths ------------------------------------------------------------------
data_dir =  "/explore/nobackup/projects/ilab/data/astrotime/demo"
ckp_file = f"{data_dir}/embed_time_predict.e{expt_index}.s{signal_index}.weights.h5"

# --- Load the arrays --------------------------------------------------------
# np.load on an .npz returns a lazy archive object that keeps the file open, so
# it is used as a context manager and every needed array is read inside the
# block; the handle is closed on exit.
# NOTE(review): allow_pickle=True permits unpickling of object arrays, which can
# execute arbitrary code. Keep it only while this file comes from a trusted
# source (presumably required because the arrays are stored as object arrays --
# TODO confirm).
with np.load( f'{data_dir}/jordan_data.npz',allow_pickle=True ) as data:
    signals = data['signals']
    times = data['times']
    binary_times = data['binary_times']

# --- Select one signal ------------------------------------------------------
X = binary_times[signal_index].astype(np.float32)   # model input
Y = signals[signal_index]                           # regression target
T = times[signal_index]                             # used as the x axis when plotting

# The split index below is derived from X but is also applied to Y and T, so
# the three arrays must agree on their first dimension.
n_samples = X.shape[0]
assert len(Y) == n_samples and len(T) == n_samples, (
    f"length mismatch for signal {signal_index}: "
    f"X={n_samples}, Y={len(Y)}, T={len(T)}"
)

# --- Train / validation partition -------------------------------------------
# Despite its name, validation_split is the index of the first validation
# sample, not a fraction.
validation_split = int(TRAIN_FRACTION*n_samples)

Xtrain=X[:validation_split]
Xval=X[validation_split:]
Ytrain=Y[:validation_split]
Yval=Y[validation_split:]

In [None]:
def create_small_model(dropout_frac):
    """Build the fully connected network used in this notebook.

    Layer order:
      * input named "binary_times_input" with 64 features;
      * three repetitions of Dense(512, tanh) -> Dropout(dropout_frac)
        -> BatchNormalization;
      * one more Dense(512, tanh) -> BatchNormalization, without dropout;
      * a single-unit Dense output with linear activation.

    Args:
        dropout_frac: rate handed to each of the three Dropout layers.

    Returns:
        An uncompiled ``tf.keras.Model`` mapping the input tensor to the
        single-unit output.
    """
    layers = tf.keras.layers
    hidden_width = 512
    hidden_activation = 'tanh'

    binary_times_input = tf.keras.Input(shape=(64,), name="binary_times_input")

    # Layers are instantiated in the same order as a hand-unrolled stack.
    hidden = binary_times_input
    for _ in range(3):
        hidden = layers.Dense(hidden_width, activation=hidden_activation)(hidden)
        hidden = layers.Dropout(dropout_frac)(hidden)
        hidden = layers.BatchNormalization()(hidden)

    # Final hidden stage: normalized, but no dropout.
    hidden = layers.Dense(hidden_width, activation=hidden_activation)(hidden)
    hidden = layers.BatchNormalization()(hidden)

    prediction = layers.Dense(1, activation='linear')(hidden)
    return tf.keras.Model(inputs=binary_times_input, outputs=prediction)

In [None]:
# Rebuild the network, compile it with an MAE loss, and restore the stored
# parameters from ckp_file.
optimizer = Adam(name='adam', learning_rate=0.001)
model = create_small_model(dropout_frac=0.5)
model.compile(loss='mae', optimizer=optimizer)
model.load_weights(ckp_file)

# Run inference on the training partition first (p0), then on the validation
# partition (p1), both with a batch size of 256.
p0, p1 = (model.predict(inputs, batch_size=256) for inputs in (Xtrain, Xval))

In [None]:
# Overlay the model output on the reference signal. The train and validation
# predictions are drawn over their own slices of T so the split point is visible.
fig, ax = plt.subplots(figsize=(15,5))

ax.plot(T, Y, label='truth')
ax.plot(T[:validation_split], p0[:,0], label='train prediction')   # column 0 = the single model output
ax.plot(T[validation_split:], p1[:,0], label='val prediction')

# Axis labels so the figure stands on its own; units are not given in this
# notebook, so the labels stay generic.
ax.set_xlabel('time')
ax.set_ylabel('signal')
ax.set_title(f'Signal {signal_index}')
ax.legend()

fig.tight_layout()
plt.show()