In [1]:
%load_ext autoreload
%autoreload 2
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error,mean_absolute_error
from tensorflow.keras.utils import plot_model
from tensorflow import constant_initializer
import matplotlib.pyplot as plt
import json
import shutil
from ipywidgets import interact, interactive, fixed, interact_manual
import ipywidgets as widgets

In [2]:
import sys
# Make the project's ../code/ directory importable so `model` below resolves.
sys.path.append("../code/")
from model import PhasedSNForecastProbabilisticIntervalModel
# Forecast horizon: the last `out_steps` timesteps of every sequence are split
# off as targets and the remaining timesteps are the model input.
out_steps = 3

In [3]:
def normalize(data):
    """Min-max scale every sample/feature of a padded batch along the time axis.

    Negative entries of ``data`` are treated as padding: they are excluded from
    the statistics and written as ``-1`` in the output.

    Parameters
    ----------
    data : np.ndarray
        Array indexed ``[sample, step, feature]``; negative values mark padding.
        It is not modified.

    Returns
    -------
    normalized : np.ndarray
        Same shape as ``data``. Valid entries are scaled with the per-sample,
        per-feature min/max over axis 1; padding is ``-1``. Float inputs keep
        their dtype, integer inputs are promoted to float64.
    min_val, max_val : np.ma.MaskedArray
        Per-sample, per-feature statistics (``data`` with axis 1 reduced),
        needed by ``denormalize`` to invert the scaling.

    Notes
    -----
    The statistics span the whole time axis, so trailing steps that a caller
    later slices off as forecast targets also influence the scaling of the
    inputs. NOTE(review): confirm this is intended.
    """
    masked_data = np.ma.masked_where(data < 0, data)
    min_val = masked_data.min(axis=1)
    max_val = masked_data.max(axis=1)

    # Re-insert the reduced time axis so the statistics broadcast over all steps.
    lo = min_val[:, np.newaxis]
    span = (max_val - min_val)[:, np.newaxis]
    # Masked division: where span == 0 (constant feature) numpy.ma keeps the
    # numerator, which is 0 for valid entries, instead of producing nan/inf.
    scaled = (masked_data.data - lo) / span

    # Writing float results back into an integer array would truncate them.
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float64
    return_data = np.ma.getdata(scaled).astype(out_dtype)
    return_data[np.ma.getmaskarray(masked_data)] = -1
    return return_data, min_val, max_val
    
def denormalize(data, min_val, max_val, pad_value=-1):
    """Invert ``normalize``: map min-max scaled values back to original units.

    Parameters
    ----------
    data : np.ndarray
        Scaled values indexed ``[sample, step, ...]``. It is not modified.
    min_val, max_val : array-like
        Statistics with the step axis removed (as returned by ``normalize``,
        or a single feature column of them); broadcast across axis 1 of
        ``data``.
    pad_value : float or None, default -1
        Padding sentinel. Entries exactly equal to it are left as
        ``pad_value`` instead of being rescaled. ``None`` rescales every entry.

    Returns
    -------
    np.ndarray
        Array of ``data``'s shape in original units. Float inputs keep their
        dtype, integer inputs are promoted to float64.

    Notes
    -----
    Only the exact sentinel counts as padding. Treating every negative value
    as padding would overwrite model outputs that fall below the per-sample
    minimum (negative in scaled space) with the sentinel.
    """
    data = np.asarray(data)
    min_val = np.ma.asarray(min_val)
    max_val = np.ma.asarray(max_val)

    # Re-insert the step axis so the statistics broadcast over all steps.
    lo = min_val[:, np.newaxis]
    span = (max_val - min_val)[:, np.newaxis]
    restored = data * span + lo

    # Writing float results back into an integer array would truncate them.
    out_dtype = data.dtype if np.issubdtype(data.dtype, np.floating) else np.float64
    return_data = np.ma.getdata(restored).astype(out_dtype)
    if pad_value is not None:
        return_data[data == pad_value] = pad_value
    return return_data

In [4]:
# Training split: scale each sample with its own statistics, then hold out the
# last `out_steps` timesteps as forecast targets.
data = np.load("../data/padded_x_train.npy")
len_data = data.shape[1]
data, data_min_val, data_max_val = normalize(data)
X_train = data[:, :-out_steps, :]
y_train = data[:, -out_steps:, :]

In [5]:
# Validation split: same preprocessing as the training split, using the
# validation samples' own statistics.
data_val = np.load("../data/padded_x_val.npy")
len_data = data_val.shape[1]
data_val, data_val_min_val, data_val_max_val = normalize(data_val)
X_val = data_val[:, :-out_steps, :]
y_val = data_val[:, -out_steps:, :]

In [6]:
inputs = X_train
inputs_val = X_val

# Multi-output targets keyed like the loss dict passed to compile():
# "prediction" gets the full forecast window, while the "upper"/"lower"
# interval heads only target feature 1 (kept as a trailing axis of size 1).
outputs = {"prediction": y_train}
outputs_val = {"prediction": y_val}
for interval in ["upper", "lower"]:
    outputs[interval] = y_train[..., 1:2]
    outputs_val[interval] = y_val[..., 1:2]

In [7]:
class SaveData(tf.keras.callbacks.Callback):
    """Keras callback that writes selected epoch-level metrics to TensorBoard.

    Each name in ``keys`` is looked up in the ``logs`` dict Keras passes to
    ``on_epoch_end`` and written as a scalar summary under ``<logdir>/metrics``.

    Args:
        logdir: Base log directory; summaries go to ``logdir + "/metrics"``.
        keys: Iterable of metric names to record each epoch.
        **kwargs: Forwarded to ``tf.keras.callbacks.Callback``.
    """

    def __init__(self, logdir, keys, **kwargs):
        super().__init__(**kwargs)
        self.file_writer = tf.summary.create_file_writer(logdir + "/metrics")
        # Also registers this writer as the global default for any other
        # tf.summary calls made after construction.
        self.file_writer.set_as_default()
        self.keys = keys

    def on_epoch_end(self, epoch, logs=None):
        logs = logs or {}
        # Write through our own writer explicitly, so a different writer made
        # default later cannot redirect these scalars.
        with self.file_writer.as_default():
            for key in self.keys:
                value = logs.get(key)
                if value is None:
                    # A missing metric would make tf.summary.scalar fail and
                    # abort fit(); report it and keep training instead.
                    tf.get_logger().warning(
                        "SaveData: metric %r not found in epoch logs; skipping.", key
                    )
                    continue
                tf.summary.scalar(key, data=value, step=epoch)
        self.file_writer.flush()

In [8]:
import datetime
import os

# Early stopping on val_loss. restore_best_weights rolls the model back to its
# best epoch, so the weights saved and used for inference later are the
# best-val_loss ones, not those from the final epoch.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss', min_delta=1e-10, patience=10, restore_best_weights=True
)

# TensorBoard: clear earlier PI runs *before* any writer is created, then log
# this run into its own timestamped subdirectory of LOG_ROOT, so the cleanup
# can never delete the directory this run is writing to.
LOG_ROOT = "../data/training/logs/PI"
shutil.rmtree(LOG_ROOT, ignore_errors=True)
logdir = os.path.join(LOG_ROOT, datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard = tf.keras.callbacks.TensorBoard(logdir)
saver = SaveData(logdir, ["PICW"])

# Checkpoint: keep only the weights with the best val_loss seen so far.
checkpoint = tf.keras.callbacks.ModelCheckpoint(
    "../data/training_PI/model_checkpoints/checkpoint",
    monitor='val_loss',
    verbose=0,
    save_best_only=True,
)

callbacks = [tensorboard, checkpoint, early_stop, saver]

In [9]:
# Rebuild the pretrained point-forecast model and load its saved weights.
from model import PhasedSNForecastModel

base_model = PhasedSNForecastModel(units=150, out_steps=out_steps, features=3)
base_model.compile(optimizer="rmsprop", loss="mse")
# One short fit on two samples, presumably so the model's variables exist
# before load_weights; whatever it learns is overwritten by the load below.
warmup_inputs, warmup_targets = X_train[:2], y_train[:2]
_ = base_model.fit(warmup_inputs, warmup_targets)

base_model.load_weights("../data/sn_model.h5")



In [10]:
# Interval model wrapping the pretrained base model.
model = PhasedSNForecastProbabilisticIntervalModel(
    model=base_model,
    units=300,
    out_steps=out_steps,
    dropout=0.0,
)

In [11]:
# Freeze these sub-modules (presumably the parts shared with the pretrained
# base model) before compile(), so fit() only updates the remaining layers.
for frozen_part in (model.rnn, model.denses, model.cells):
    frozen_part.trainable = False

In [12]:
from tensorflow_addons.utils.keras_utils import LossFunctionWrapper
from tensorflow_addons.utils.types import TensorLike, FloatTensorLike
from typeguard import typechecked

@tf.function
def custom_pinball_loss(y_true: TensorLike, y_pred: TensorLike, tau: FloatTensorLike = 0.5) -> tf.Tensor:
    """Mirrored pinball (quantile) loss, averaged over the last axis.

    Over-prediction (``y_pred > y_true``) is weighted by ``tau`` and
    under-prediction by ``1 - tau``. This is the reverse of the standard
    pinball loss, so minimizing it pushes ``y_pred`` toward the ``1 - tau``
    quantile of ``y_true`` rather than the ``tau`` quantile.
    NOTE(review): confirm the mirroring is intended.

    Args:
        y_true: Ground-truth values; cast to ``y_pred``'s dtype.
        y_pred: Predicted values.
        tau: Weight on over-prediction, presumably in [0, 1].

    Returns:
        Per-element losses averaged over the last axis.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)

    # Give tau a leading axis so it broadcasts along the batch dimension.
    tau = tf.expand_dims(tf.cast(tau, y_pred.dtype), 0)

    pinball = tf.where(y_pred > y_true, tau * (y_pred - y_true), (1 - tau) * (y_true - y_pred))
    return tf.reduce_mean(pinball, axis=-1)

class CustomPinballLoss(LossFunctionWrapper):
    """Keras loss object applying ``custom_pinball_loss`` with a fixed ``tau``.

    Args:
        tau: Weight on over-prediction, passed through to ``custom_pinball_loss``.
        reduction: Keras reduction applied to the per-sample losses.
        name: Name of this loss instance.
    """

    @typechecked
    def __init__(
        self,
        tau: FloatTensorLike = 0.5,
        reduction: str = tf.keras.losses.Reduction.AUTO,
        name: str = "custom_pinball_loss",
    ):
        fn_kwargs = dict(tau=tau)
        super().__init__(
            custom_pinball_loss,
            name=name,
            reduction=reduction,
            **fn_kwargs,
        )
        
        
@tf.function
def inverse_pinball_loss(y_true: TensorLike, y_pred: TensorLike, tau: FloatTensorLike = 0.5) -> tf.Tensor:
    """Standard pinball (quantile) loss, averaged over the last axis.

    Under-prediction (``y_pred < y_true``) is weighted by ``tau`` and
    over-prediction by ``1 - tau``, so minimizing it pushes ``y_pred`` toward
    the ``tau`` quantile of ``y_true``. It is the mirror image of
    ``custom_pinball_loss``.

    Args:
        y_true: Ground-truth values; cast to ``y_pred``'s dtype.
        y_pred: Predicted values.
        tau: Target quantile, presumably in [0, 1].

    Returns:
        Per-element losses averaged over the last axis.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    y_true = tf.cast(y_true, y_pred.dtype)

    # Give tau a leading axis so it broadcasts along the batch dimension.
    tau = tf.expand_dims(tf.cast(tau, y_pred.dtype), 0)

    pinball = tf.where(y_pred > y_true, (1 - tau) * (y_pred - y_true), tau * (y_true - y_pred))
    return tf.reduce_mean(pinball, axis=-1)

class InversePinballLoss(LossFunctionWrapper):
    """Keras loss object applying ``inverse_pinball_loss`` with a fixed ``tau``.

    Args:
        tau: Target quantile, passed through to ``inverse_pinball_loss``.
        reduction: Keras reduction applied to the per-sample losses.
        name: Name of this loss instance.
    """

    @typechecked
    def __init__(
        self,
        tau: FloatTensorLike = 0.5,
        reduction: str = tf.keras.losses.Reduction.AUTO,
        name: str = "inverse_pinball_loss",
    ):
        fn_kwargs = dict(tau=tau)
        super().__init__(
            inverse_pinball_loss,
            name=name,
            reduction=reduction,
            **fn_kwargs,
        )
        
        

In [13]:
# Interval taus derived from alpha: alpha/2 and 1 - alpha/2.
alpha = 0.30
lower_tau = alpha / 2
upper_tau = 1 - (alpha / 2)
# NOTE(review): CustomPinballLoss weights over-prediction by tau, so its
# minimizer is the (1 - tau) quantile. The "lower" head is therefore pushed
# toward the 1 - alpha/2 quantile of feature 1 and "upper" toward alpha/2.
# The plots invert the y axis, which may be the reason; confirm intent.
# The "prediction" output contributes no loss term.
losses = {
    "prediction": None,
    "lower": CustomPinballLoss(tau=lower_tau, reduction=tf.keras.losses.Reduction.NONE),
    "upper": CustomPinballLoss(tau=upper_tau, reduction=tf.keras.losses.Reduction.NONE),
}
model.compile(optimizer="rmsprop", loss=losses)

In [14]:
# Upper bound on epochs; the EarlyStopping callback is expected to stop first.
MAX_EPOCHS = 1000
history = model.fit(
    inputs,
    outputs,
    validation_data=(inputs_val, outputs_val),
    batch_size=300,
    epochs=MAX_EPOCHS,
    callbacks=callbacks,
)

Epoch 1/1000
Instructions for updating:
use `tf.profiler.experimental.stop` instead.
Epoch 2/1000
Epoch 3/1000
Epoch 4/1000
Epoch 5/1000
Epoch 6/1000
Epoch 7/1000
Epoch 8/1000
Epoch 9/1000
Epoch 10/1000
Epoch 11/1000
Epoch 12/1000
Epoch 13/1000
Epoch 14/1000
Epoch 15/1000
Epoch 16/1000
Epoch 17/1000
Epoch 18/1000
Epoch 19/1000
Epoch 20/1000
Epoch 21/1000
Epoch 22/1000
Epoch 23/1000
Epoch 24/1000
Epoch 25/1000
Epoch 26/1000
Epoch 27/1000
Epoch 28/1000
Epoch 29/1000
Epoch 30/1000
Epoch 31/1000
Epoch 32/1000
Epoch 33/1000
Epoch 34/1000
Epoch 35/1000
Epoch 36/1000
Epoch 37/1000
Epoch 38/1000
Epoch 39/1000
Epoch 40/1000
Epoch 41/1000
Epoch 42/1000
Epoch 43/1000
Epoch 44/1000
Epoch 45/1000
Epoch 46/1000
Epoch 47/1000
Epoch 48/1000
Epoch 49/1000
Epoch 50/1000
Epoch 51/1000
Epoch 52/1000
Epoch 53/1000
Epoch 54/1000
Epoch 55/1000
Epoch 56/1000
Epoch 57/1000
Epoch 58/1000
Epoch 59/1000
Epoch 60/1000
Epoch 61/1000
Epoch 62/1000
Epoch 63/1000
Epoch 64/1000
Epoch 65/1000
Epoch 66/1000
Epoch 67/1000

In [15]:
import os

# Persist the per-epoch metrics. The `with` block closes (and so flushes) the
# file right away instead of leaving the handle open until garbage collection.
history_dict = history.history
os.makedirs("../data/training_PI", exist_ok=True)
with open("../data/training_PI/history_model.json", 'w') as history_file:
    json.dump(history_dict, history_file)

In [16]:
# Final weights of the interval model.
pi_weights_path = "../data/sn_model_PI.h5"
model.save_weights(pi_weights_path)

In [17]:
# NOTE(review): the "test" split is read from the validation file, which also
# drives EarlyStopping and ModelCheckpoint during training, so the results
# below are not on held-out data. Point this at a separate test file if one
# exists.
data_test = np.load("../data/padded_x_val.npy")[:, :, :]
data_test, data_test_min_val, data_test_max_val = normalize(data_test)
X_test = data_test[:, :-out_steps, :]
y_test = data_test[:, -out_steps:, :]


def denormalize_outputs(pred, min_val, max_val):
    """Map a dict of model outputs back to original units.

    ``"prediction"`` is rescaled with the full per-feature statistics; the
    ``"upper"``/``"lower"`` heads only carry feature 1 (magnitude), so they use
    that feature's statistics column.
    """
    mag_min = min_val[:, 1][:, np.newaxis]
    mag_max = max_val[:, 1][:, np.newaxis]
    restored = {"prediction": denormalize(pred["prediction"], min_val, max_val)}
    for bound in ("upper", "lower"):
        restored[bound] = denormalize(pred[bound], mag_min, mag_max)
    return restored


# Train split: predict, then bring inputs, outputs and targets back to raw units.
y_hat_train = model.predict(X_train)
dX_train = denormalize(X_train, data_min_val, data_max_val)
dy_hat_train = denormalize_outputs(y_hat_train, data_min_val, data_max_val)
dy_train = denormalize(y_train, data_min_val, data_max_val)

# Test split: same steps with the test split's own statistics.
y_hat = model.predict(X_test)
dX_test = denormalize(X_test, data_test_min_val, data_test_max_val)
dy_hat = denormalize_outputs(y_hat, data_test_min_val, data_test_max_val)
dy_test = denormalize(y_test, data_test_min_val, data_test_max_val)

In [18]:
def plot_data(x, y_real, y_hat, sample=0):
    """Plot one sample's history, true and predicted forecast, and interval band.

    Args:
        x: Denormalized inputs indexed ``[sample, step, feature]``; feature 0 is
            drawn on the horizontal (time) axis and feature 1 on the vertical
            (magnitude) axis. Negative entries are treated as padding and hidden.
        y_real: Denormalized true forecast steps, indexed like ``x``; negative
            entries are hidden the same way.
        y_hat: Dict of denormalized model outputs: ``"prediction"`` indexed like
            ``x``, ``"lower"``/``"upper"`` indexed ``[sample, step, 0]``.
        sample: Index of the sample to draw.
    """
    _, ax = plt.subplots(figsize=(12, 6))
    ax.invert_yaxis()
    # Mask padding of the selected sample only, not of the whole batch.
    history = np.ma.masked_where(x[sample] < 0, x[sample])
    real = np.ma.masked_where(y_real[sample] < 0, y_real[sample])
    pred = y_hat["prediction"][sample]
    ax.scatter(history[:, 0], history[:, 1], label="History")
    ax.scatter(real[:, 0], real[:, 1], label="Real")
    ax.scatter(pred[:, 0], pred[:, 1], label="Prediction")
    ax.fill_between(pred[:, 0], y_hat["lower"][sample, :, 0], y_hat["upper"][sample, :, 0], alpha=0.2)
    # Raw string: "\m" is not a valid Python escape sequence.
    ax.set_xlabel(r"Time $mjd-\min(mjd)$")
    ax.set_ylabel("Mag")
    # Without this the scatter labels are never shown.
    ax.legend()


def f(sample):
    """Widget callback: draw the test-split sample picked with the slider."""
    plot_data(dX_test, dy_test, dy_hat, sample=sample)


interact(f, sample=(0, len(dX_test) - 1));

interactive(children=(IntSlider(value=397, description='sample', max=795), Output()), _dom_classes=('widget-in…

<function __main__.<lambda>(sample)>

In [19]:
import os
import progressbar

# Save one figure per test sample, reusing plot_data so the saved plots match
# the interactive ones. Uses dy_hat directly rather than rebinding y_hat, which
# still holds the normalized test predictions.
PLOT_DIR = "../data/plots_test_PI/"
os.makedirs(PLOT_DIR, exist_ok=True)

bar = progressbar.ProgressBar(max_value=len(X_test))
bar.start()
for sample in range(len(dX_test)):
    plot_data(dX_test, dy_test, dy_hat, sample=sample)
    plt.savefig(f"{PLOT_DIR}{sample:05d}")
    # Close every figure so hundreds of them are not kept alive in memory.
    plt.close()
    bar.update(sample + 1)
bar.finish()

100% (796 of 796) |######################| Elapsed Time: 0:02:17 ETA:  00:00:00