In [None]:
import torch
import pandas as pd
import astropy as ap
import numpy as np
import matplotlib.pyplot as plt 
import seaborn as sns
# from astropy.io import fits
import pdb
from scipy.ndimage.filters import maximum_filter1d
import glob
import fitsio as fits
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.distributions.normal import Normal
from tqdm import tqdm
from loading import TessDataset

In [None]:
import sys
sys.path.insert(0, '../')
sys.path.insert(1, '../latent_ode/')
import latent_ode.lib as ode
import latent_ode.lib.utils as utils
from latent_ode.lib.latent_ode import LatentODE
from latent_ode.lib.ode_rnn import ODE_RNN
from latent_ode.lib.encoder_decoder import Encoder_z0_ODE_RNN, Decoder
from latent_ode.lib.diffeq_solver import DiffeqSolver
from latent_ode.lib.ode_func import ODEFunc

### Loading

In [None]:
# Compute device for the model and data. Flip USE_CUDA to True to use a GPU when
# one is available; by default everything runs on the CPU.
USE_CUDA = False
device = torch.device("cuda" if USE_CUDA and torch.cuda.is_available() else "cpu")

In [None]:
from ode_rnn import create_ODERNN_model

In [None]:
# Options. NOTE(review): none of these are passed to create_ODERNN_model(), which is
# called with no arguments — confirm whether anything downstream still reads them.
latent_dim = 40               # presumably the latent state size
input_dim = 1                 # presumably one feature (flux) per time point
n_labels = 1
classif_per_tp = False        # no per-time-point classification
niters = 1
status_properties = ["loss"]

In [None]:
# Build the ODE-RNN and place it on the same device the input tensors are moved to.
# NOTE(review): if create_ODERNN_model() fixes a device internally, make sure it matches `device`.
model = create_ODERNN_model().to(device)

In [None]:
# Checkpoint whose 'state_dict' entry is loaded into the model below.
model_file = "ode_rnn_state_tess_new.pth.tar"

In [None]:
# Read the checkpoint with every tensor mapped to CPU memory; load_state_dict copies
# the values into the model's own parameters afterwards.
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True; if this
# checkpoint stores non-tensor objects (e.g. training args), loading will fail there.
state = torch.load(model_file, map_location="cpu")

In [None]:
# Restore the trained weights (strict key matching, the PyTorch default).
model.load_state_dict(state_dict=state["state_dict"])

In [None]:
# Switch to inference mode (equivalent to model.eval()).
model.train(False)

In [None]:
import inspect  # only needed for the torch.load compatibility check below

# 'tess_cv.pt' is a pickled Python object (presumably the validation DataLoader built
# from TessDataset), not a plain tensor file. PyTorch >= 2.6 defaults torch.load to
# weights_only=True, which rejects such objects, so opt out explicitly whenever the
# installed torch has the flag (older versions already behaved this way).
# SECURITY: unpickling can execute arbitrary code — only load files you created.
_load_kwargs = {}
if "weights_only" in inspect.signature(torch.load).parameters:
    _load_kwargs["weights_only"] = False
loader = torch.load('tess_cv.pt', **_load_kwargs)

### Plot Multiple

In [None]:
# Take the first batch. Seeding torch's global RNG first means that, if the loader
# shuffles with the default generator, the same batch (and so the same light curve
# `ix`) is returned on every run.
SEED = 0
torch.manual_seed(SEED)
batch = next(iter(loader), None)
if batch is None:
    raise RuntimeError("The loader in 'tess_cv.pt' yielded no batches.")

In [None]:
# Position of the light curve within `batch` to reconstruct and plot.
ix: int = 18

In [None]:
def sample(z, t):
    """Decode a single trajectory drawn from the global ``model``'s prior dynamics.

    Args:
        z: starting latent point handed to ``model.diffeq_solver``.
        t: time points at which the trajectory is requested.

    Returns:
        Whatever ``model.decoder`` produces for the latent trajectory.
    """
    # Exactly one trajectory sample is requested from the solver.
    latent_traj = model.diffeq_solver.sample_traj_from_prior(z, t, n_traj_samples=1)
    return model.decoder(latent_traj)

In [None]:
# Pull light curve `ix` out of the batch, keeping a leading batch axis of size 1,
# and move it to the compute device. Time points are shared across the batch.
observed, true, mask = (
    batch[field][ix].unsqueeze(0).to(device)
    for field in ('observed_data', 'data_to_predict', 'observed_mask')
)
t = batch['observed_tp'].to(device)

In [None]:
# a = t.min().item()
# b = t.max().item()
# step = (b-a)/1000
# t=torch.arange(a,b,step)

In [None]:
# Inspect the shape of the single-sample input.
observed.size()

In [None]:
# Reconstruct light curve `ix` at its observed time points. Gradients are not needed
# for plotting, so run under no_grad to avoid building an autograd graph.
# `t` may already be the NumPy copy made in the conversion cell, so coerce it back to
# a tensor on `device`; this keeps the cell safe to re-run.
t_pred = torch.as_tensor(t, device=device)
with torch.no_grad():
    rec = model.get_reconstruction(
        time_steps_to_predict=t_pred,
        data=observed,
        truth_time_steps=t_pred,
        mask=mask,
        n_traj_samples=20)[0].detach().cpu().squeeze()

In [None]:
def _as_numpy(x):
    """Return ``x`` as a NumPy array, detaching and moving torch tensors to the CPU first.

    Accepting arrays as well as tensors makes this cell idempotent: re-running it
    after ``rec``/``t`` have already been converted no longer crashes on ``.detach()``.
    """
    return x.detach().cpu().numpy() if torch.is_tensor(x) else np.asarray(x)

rec = _as_numpy(rec)
truth = _as_numpy(true).squeeze()
t = _as_numpy(t)
obs_all = _as_numpy(observed).squeeze()
# Select observed points with the observation mask instead of `obs != 0`, so a
# genuine measurement that happens to equal 0 is not dropped.
# NOTE(review): assumes observed_mask has the same shape as observed_data — confirm.
obs_mask = _as_numpy(mask).squeeze() != 0
t_obs = t[obs_mask]
obs = obs_all[obs_mask]

In [None]:
# matplotlib 3.6 renamed the bundled seaborn styles to 'seaborn-v0_8-*' and 3.8
# removed the old names, so use whichever name this installation provides.
_darkgrid = ('seaborn-v0_8-darkgrid'
             if 'seaborn-v0_8-darkgrid' in plt.style.available
             else 'seaborn-darkgrid')
plt.style.use(_darkgrid)

# Text properties shared by the axis labels and title.
font = {'family': 'serif',
        'color':  'grey',
        'weight': 'light',
        'size': 12,
        }

In [None]:
# Plot the model reconstruction against the observed (masked-in) points of light curve `ix`.
PRED_START = 100   # first time index of the prediction curve to draw
                   # NOTE(review): presumably skips an early warm-up stretch — confirm
HIDE_TICKS = False  # True -> remove all ticks and tick labels (e.g. for a clean figure)

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(t[PRED_START:], rec[PRED_START:], marker='', color='orange', linewidth=2,
        alpha=1, label='Model prediction')
ax.plot(t_obs, obs, marker='+', color='black', linestyle='None', alpha=0.7,
        markersize=7, label='Observations')
ax.set_xlabel("Time (MJD)", fontdict=font)
ax.set_ylabel("Normalized Flux", fontdict=font)
ax.set_title(f"Light curve {ix}: ODE-RNN reconstruction", fontdict=font)
ax.legend(loc='center left', bbox_to_anchor=(1.0, 0.9),
          ncol=1, fancybox=True, shadow=True)
if HIDE_TICKS:
    # Applies to both axes: hides ticks on every edge and the bottom/left tick labels.
    ax.tick_params(axis='both', which='both',
                   bottom=False, top=False, right=False, left=False,
                   labelbottom=False, labelleft=False)
plt.show()