In [None]:
import pandas as pd
import astropy as ap
import numpy as np
import matplotlib.pyplot as plt 
import seaborn as sns
# from astropy.io import fits
import pdb
from scipy.ndimage.filters import maximum_filter1d
import glob
import fitsio as fits
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.distributions.normal import Normal
from tqdm import tqdm

In [None]:
import sys
# Put the parent directory first on the import path so `latent_ode` (and the
# other modules imported below) resolve relative to this notebook's location.
sys.path.insert(0, '../')
# NOTE(review): '../latent_ode/' is added as well, presumably because the
# latent_ode sources import their own `lib` package by its top-level name — confirm.
sys.path.insert(1, '../latent_ode/')
import latent_ode.lib as ode
import latent_ode.lib.utils as utils
from latent_ode.lib.latent_ode import LatentODE
from latent_ode.lib.encoder_decoder import Encoder_z0_ODE_RNN, Decoder
from latent_ode.lib.diffeq_solver import DiffeqSolver
from latent_ode.lib.ode_func import ODEFunc

In [None]:
from latent_rnn import create_LatentODE_model

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Loading

In [None]:
# Build the (still untrained) Latent ODE: one input channel, a unit-variance
# zero-mean Normal prior on z0, and a fixed observation std of 0.01.
input_dim = 1
obsrv_std = torch.Tensor([0.01]).to(device)
z0_prior = Normal(
    loc=torch.Tensor([0.0]).to(device),
    scale=torch.Tensor([1.]).to(device),
)
model = create_LatentODE_model(input_dim, z0_prior, obsrv_std)

In [None]:
state = torch.load('latent_ode_state.pth.tar')

In [None]:
model.load_state_dict(state['state_dict'])

In [None]:
model.eval()

In [None]:
# `loader` is iterated below and yields dict-like batches.
# NOTE(review): presumably this file holds a pickled DataLoader — confirm.
# torch.load unpickles arbitrary objects; only open files from a trusted source.
loader = torch.load('tess_train.pt')

## Inference

In [None]:
# Encode every batch and collect the encoder's mean output for z0.
res = []
# Inference only: without no_grad() autograd would record a graph per batch.
with torch.no_grad():
    for batch in tqdm(loader):
        observed = batch['observed_data']
        mask = batch['observed_mask']
        # The encoder input is the data with its mask appended on the last axis.
        x = torch.cat((observed, mask), -1)
        t = batch['observed_tp']
        z_u, z_std = model.encoder_z0(x, t)
        # Flatten to (n_series, latent_dim). A bare .squeeze() would also drop
        # the batch axis for a batch holding a single series.
        # NOTE(review): assumes the latent dimension is the last axis — confirm.
        z_u = z_u.reshape(-1, z_u.shape[-1])
        res.append(z_u.cpu().numpy())

In [None]:
res = np.concatenate(res)

In [None]:
res.shape

### TSNE

In [None]:
from sklearn.manifold import TSNE

In [None]:
tsne = TSNE(perplexity=10)

In [None]:
x = tsne.fit_transform(res)

In [None]:
sns.scatterplot(x=x[:,0], y=x[:,1])

### Reconstruct

In [None]:
batch_dict = next(iter(loader))

In [None]:
batch_dict.keys()

In [None]:
t_obs = batch_dict['observed_tp']
y_obs = batch_dict['observed_data']

In [None]:
sns.lineplot(x=t_obs.squeeze().detach().cpu().numpy(), y=y_obs[0].squeeze().detach().cpu().numpy())

In [None]:
# Encoder-style input for this batch: data with its mask appended on the last axis.
observed, mask = batch_dict['observed_data'], batch_dict['observed_mask']
x = torch.cat([observed, mask], dim=-1)

In [None]:
n_traj_samples = 5

In [None]:
pred_x, info = model.get_reconstruction(batch_dict["tp_to_predict"], 
			batch_dict["observed_data"], batch_dict["observed_tp"], 
			mask = batch_dict["observed_mask"], n_traj_samples = n_traj_samples,
			mode = batch_dict["mode"])

In [None]:
# Shape of the time grid the reconstruction was requested on.
batch_dict["tp_to_predict"].shape

In [None]:
# Shape of the raw reconstruction returned by the model.
pred_x.shape

In [None]:
# Keep the first entry along pred_x's leading axis and drop singleton dims.
# NOTE(review): presumably the leading axis indexes the n_traj_samples draws — confirm.
y_pred = pred_x[0].squeeze()

In [None]:
y_pred.shape

In [None]:
# model.decoder.forward(pred_x)

In [None]:
t_pred = batch_dict['tp_to_predict'].detach().cpu().numpy()

In [None]:
sns.lineplot(x=t_pred, y=y_pred[2].squeeze().detach().cpu().numpy())

In [None]:
y_true = batch['data_to_predict']

In [None]:
y_true.shape

In [None]:
sns.lineplot(x=t_pred, y=y_true[2].squeeze().detach().cpu().numpy())

In [None]:
tt = torch.arange(0, 30, 0.01)

In [None]:
res2 = model.sample_traj_from_prior(tt, 10)

In [None]:
res2.shape

In [None]:
# ts2 = batch['observed_tp'].detach().cpu().numpy()

In [None]:
res2 = res2.squeeze()

In [None]:
res2 = res2.detach().cpu().numpy()

In [None]:
res2.shape

In [None]:
sns.lineplot(x=tt, y=res2[9])