In [1]:
import os
import sys

import numpy as np
import torch
from torch import nn

# The `src.*` imports below are resolved relative to the working directory, which
# the notebook reaches by moving one directory up. Only do so when `src` is not
# already visible: re-running this cell must not keep climbing the directory tree.
if not os.path.isdir("src"):
    os.chdir("..")

from src.models.plModel import LightningModel
from src.utils.IndexDataloader import DataModule
from src.solvers.ode import simulation_forecast, SB_forecast

import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint

import seaborn as sns
import matplotlib
from matplotlib import pyplot as plt
from matplotlib.pyplot import cm
from matplotlib import ticker


In [12]:
# NOTE(review): `%matplotlib widget` (ipympl backend) is immediately followed by
# `%matplotlib inline`; the later magic wins, so figures in this notebook are
# rendered as static inline images. Drop one of the two lines to make the intent clear.
%matplotlib widget
# ipywidgets provides the `interact` decorator used by the dropdown explorer below.
import ipywidgets as widgets
%matplotlib inline

In [3]:
# Build the data module and run its setup step. Later cells read
# data_module.inputs_mean / inputs_std / updates_mean / updates_std, so setup()
# is called here first (presumably it computes those statistics — confirm).
data_module = DataModule(avg_dataloader=True, step_size=2)
data_module.setup()

Train Test Val Split Done


In [4]:
# Load the complete dataset as a masked array.
# The location can be overridden through the MC_SNOW_BIG_BOX environment variable
# so the notebook does not depend on one cluster's absolute path.
BIG_BOX_PATH = os.environ.get(
    "MC_SNOW_BIG_BOX", "/gpfs/work/sharmas/mc-snow-data/big_box.npz"
)
# Number of leading entries along the last axis that are averaged into `arr`.
# NOTE(review): 98 was hard-coded in the original; what the last axis indexes
# (e.g. realisations) is not visible here — confirm.
N_AVG_MEMBERS = 98

# NOTE(review): MaskedArray(**npz) assumes the .npz stores MaskedArray keyword
# arguments (e.g. `data` and `mask`) — confirm against whatever wrote the file.
with np.load(BIG_BOX_PATH) as npz:
    all_arr = np.ma.MaskedArray(**npz)

all_arr = all_arr.astype(np.float32)
# Mean over the first N_AVG_MEMBERS entries of the last axis.
arr = np.mean(all_arr[:, :, :, :N_AVG_MEMBERS], axis=-1)

In [5]:
# Model instance built with the current data-module statistics (kept because other
# cells may refer to `pl_model`).
pl_model = LightningModel(inputs_mean=data_module.inputs_mean, inputs_std=data_module.inputs_std,
                          updates_mean=data_module.updates_mean, updates_std=data_module.updates_std)

# Load the saved model here.
# `load_from_checkpoint` is a classmethod: calling it through `pl_model` returns a
# brand-new model restored from the checkpoint and ignores the constructor arguments
# above. Calling it on the class makes that explicit; newer Lightning releases also
# refuse calls made through an instance.
CHECKPOINT_PATH = (
    "/gpfs/work/sharmas/mc-snow-data/lightning_logs/version_595250/"
    "checkpoints/epoch=68-step=5892944.ckpt"
)
new_model = LightningModel.load_from_checkpoint(CHECKPOINT_PATH)



not using hard constraints on updates while choosing to conserve mass can lead to negative moments


## Interactive comparison: neural network vs. simulations vs. SB2001

In [6]:
# Candidate values for the three initial-condition dropdowns, taken from the
# first time index of `arr`: rows -4, -3 and -2 are the rows pos_calc matches
# lo, rm and nu against.
lo_, rm_, nu_ = (np.unique(arr[0, row, :]) for row in (-4, -3, -2))

In [10]:
#For extracting simulation based on initial conditions
def pos_calc(lo_, rm_, nu_, data=None):
    """Return the indices of the simulations whose initial conditions match.

    Parameters
    ----------
    lo_, rm_ : float
        Values compared with rows -4 and -3 of ``data[0]`` after those rows are
        rounded to 7 decimals (the dropdown options are rounded the same way).
    nu_ : float
        Value compared exactly with row -2 of ``data[0]``.
    data : array-like, optional
        3-D array indexed as ``data[t, row, simulation]`` (presumably axis 0 is
        time — confirm). Defaults to the notebook-level ``arr`` so existing calls
        keep working.

    Returns
    -------
    numpy.ndarray
        Sorted, unique simulation indices satisfying all three conditions;
        empty when nothing matches.
    """
    if data is None:
        data = arr  # notebook global created by the data-loading cell
    initial = data[0]
    lo_hits = np.where(np.around(initial[-4, :], 7) == np.array([lo_]))
    rm_hits = np.where(np.around(initial[-3, :], 7) == np.array([rm_]))
    nu_hits = np.where(initial[-2, :] == nu_)
    return np.intersect1d(np.intersect1d(lo_hits, rm_hits), nu_hits)
    
    

In [7]:
#The main plotting function
def plot_simulation(predictions_orig,targets_orig,sb_preds,var_all,model_params,num):
    """Plot the NN forecast, reference simulation and SB2001 forecast for one run.

    Draws one panel per column 0..3 (titled Lc, Nc, Lr, Nr). The gray band is
    ``targets_orig ± var_all``; when called from the interactive explorer,
    ``var_all`` is the transposed output of ``calc_errors`` (a standard
    deviation). The band is labelled in the legend so the shading is explained.

    Parameters
    ----------
    predictions_orig : array-like, shape (T, >=4)
        Neural-network predictions; T sets the time axis.
    targets_orig : array-like, shape (T, >=4)
        Reference simulation values on the same time axis.
    sb_preds : array-like, shape (T - 1, >=4)
        SB2001 predictions, drawn against the first T - 1 timesteps.
    var_all : array-like, shape (>=T, >=4)
        Half-width of the shaded band; only the first T rows are used.
    model_params : sequence
        ``(Lo, rm, Nu)`` values shown in the figure title.
    num : int
        Unused; kept so existing callers keep working.
    """
    sns.set_style("darkgrid", {"grid.color": ".6", "grid.linestyle": ":"})
    var = ['Lc', 'Nc', 'Lr', 'Nr']
    color = ["#26235b", "#bc473a", "#812878", "#f69824"]

    n_steps = len(predictions_orig)
    time = np.arange(n_steps)

    fig = plt.figure(figsize=(13, 10))
    for i in range(4):
        ax = fig.add_subplot(2, 2, i + 1)
        ax.plot(time, predictions_orig[:, i], c=color[i], label='Neural Network')
        ax.plot(time, targets_orig[:, i], c='black', label='Simulations')
        # SB2001 output is one step shorter than the simulation.
        ax.plot(time[:-1], sb_preds[:, i], c=color[i], linestyle='dashed', label='SB2001')
        ax.fill_between(time,
                        targets_orig[:, i] - var_all[:n_steps, i],
                        targets_orig[:, i] + var_all[:n_steps, i],
                        facecolor="gray", label='Simulations ± 1 std')
        ax.set_title(var[i])
        ax.set_xlabel('Timestep')
        ax.legend()

    fig.suptitle("Lo: %.4f; rm:%.6f ; Nu: %.1f "%((model_params[0]),(model_params[1]),(model_params[2])), fontsize="x-large")
    plt.show()
    

In [8]:
#For calculating errors
def calc_errors(all_arr,n):
    """Per-timestep spread of four channels of simulation ``n``.

    For each channel c in (1, 2, 4, 5), in that order, takes the standard
    deviation over the last axis of ``all_arr[:, c, n, :]``. The results are
    stacked row-wise, so the output has shape ``(4, all_arr.shape[0])``.
    plot_simulation draws its transposed columns as Lc, Nc, Lr, Nr.
    """
    rows = [
        np.std(all_arr[:, channel, n, :], axis=-1).reshape(1, -1)
        for channel in (1, 2, 4, 5)
    ]
    return np.concatenate(rows, axis=0)

In [13]:
@widgets.interact(lo=np.around(lo_,7),rm=np.around(rm_,7), nu=nu_)
def update(lo=lo_[0],rm=rm_[0],nu=nu_[0]):
    """Forecast and plot the simulation selected by the (lo, rm, nu) dropdowns.

    ipywidgets calls this on every dropdown change. The three dropdowns are
    independent, so a chosen combination is not guaranteed to identify exactly
    one simulation. In that case a message is printed instead of letting
    ``.item()`` raise a ValueError traceback into the widget output.
    """
    matches = pos_calc(lo, rm, nu)
    if matches.size != 1:
        print(f"Expected exactly one simulation for lo={lo}, rm={rm}, nu={nu}; "
              f"found {matches.size}.")
        return
    sim_num = matches.item()

    # Neural-network rollout for the selected simulation.
    new_forecast = simulation_forecast(
        arr, new_model, sim_num, data_module.inputs_mean, data_module.inputs_std,
        data_module.updates_mean, data_module.updates_std,
    )
    new_forecast.test()

    # SB2001 baseline for the same simulation.
    sb_forecast = SB_forecast(arr, sim_num)
    sb_forecast.SB_calc()
    predictions_sb = np.asarray(sb_forecast.predictions).reshape(-1, 4)

    # Spread of the un-averaged data, transposed to (time, channel) for plotting.
    var_all = np.transpose(calc_errors(all_arr, sim_num))
    plot_simulation(np.asarray(new_forecast.moment_preds).reshape(-1, 4), new_forecast.orig,
                    predictions_sb, var_all, new_forecast.model_params, num=100)
   

interactive(children=(Dropdown(description='lo', options=(0.0002, 0.0003, 0.0004, 0.0005, 0.0006, 0.0007, 0.00…