In [None]:
import torch
import pandas as pd
import astropy as ap
import numpy as np
import matplotlib.pyplot as plt 
import seaborn as sns
# from astropy.io import fits
import pdb
from scipy.ndimage.filters import maximum_filter1d
import glob
import fitsio as fits
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.distributions.normal import Normal
from tqdm import tqdm

In [None]:
import sys
sys.path.insert(0, '../')
sys.path.insert(1, '../latent_ode/')
import latent_ode.lib as ode
import latent_ode.lib.utils as utils
from latent_ode.lib.latent_ode import LatentODE
from latent_ode.lib.encoder_decoder import Encoder_z0_ODE_RNN, Decoder
from latent_ode.lib.diffeq_solver import DiffeqSolver
from latent_ode.lib.ode_func import ODEFunc

In [None]:
from latent_rnn import create_LatentODE_model

In [None]:
# Compute device for the model and data: the GPU when one is visible, unless FORCE_CPU
# is set (this replaces the commented-out `device = torch.device("cpu")` toggle).
FORCE_CPU = False
device = torch.device("cuda" if torch.cuda.is_available() and not FORCE_CPU else "cpu")

## Loading

In [None]:
obsrv_std = torch.Tensor([0.01]).to(device)
z0_prior = Normal(torch.Tensor([0.0]).to(device), torch.Tensor([1.]).to(device))
input_dim = 1
model = create_LatentODE_model(input_dim, z0_prior, obsrv_std)

In [None]:
state = torch.load('latent_ode_state.pth.tar')
# state = torch.load('latent_ode_state.pth.tar', map_location=torch.device('cpu'))

In [None]:
model.load_state_dict(state['state_dict'])

In [None]:
model.eval()

In [None]:
# The saved file holds a pickled Python object (presumably a DataLoader), not just weights:
# torch.load unpickles it, which can execute arbitrary code, so only load trusted files.
# map_location remaps any tensors stored inside it onto `device`.
# NOTE(review): PyTorch 2.6+ defaults to weights_only=True, which refuses arbitrary pickled
# objects; on those versions pass weights_only=False here (trusted file only).
loader = torch.load('gaia_train.pt', map_location=device)

## Inference

In [None]:
# Per-batch accumulators (CPU tensors), concatenated in a later cell.
z_us = []
z_stds = []
truths = []
ts = []
recs = []
# Pure inference: no_grad stops autograd from recording a graph for every batch.
with torch.no_grad():
    for batch in tqdm(loader):
        # Put the batch next to the model (a no-op if it is already there).
        observed = batch['observed_data'].to(device)
        true = batch['data_to_predict'].to(device)
        mask = batch['observed_mask'].to(device)
        t = batch['observed_tp'].to(device)

        # Encoder input: observed values with their mask appended on the last axis.
        x = torch.cat((observed, mask), -1)

        # Encoder outputs for z0 (named z_u / z_std: mean and std).
        z_u, z_std = model.encoder_z0(x, t)
        # Reconstruct on the observed time grid; element 0 holds the predictions.
        rec = model.get_reconstruction(time_steps_to_predict=t, truth=x, truth_time_steps=t)

        # Keep CPU copies only, so device memory does not grow with the number of batches.
        truths.append(true.cpu())
        ts.append(t.cpu())
        recs.append(rec[0].cpu())
        z_us.append(z_u.cpu())
        z_stds.append(z_std.cpu())

In [None]:
z_us = torch.cat(z_us,1).squeeze()
z_stds = torch.cat(z_stds,1).squeeze()
truths = torch.cat(truths,0).squeeze()

In [None]:
ts[0]

In [None]:
truths.shape

In [None]:
recs = torch.cat(recs,1).squeeze()

In [None]:
recs.shape

In [None]:
z_us.shape, z_stds.shape, truths.shape

In [None]:
ix = 88

In [None]:
u = z_us[ix]
std = z_stds[ix]

In [None]:
def sample(z, t, n_traj_samples=1):
    """Decode latent trajectories started from ``z`` at time points ``t``.

    Runs ``model.diffeq_solver.sample_traj_from_prior`` from the starting point ``z`` and
    passes the resulting latent trajectory through ``model.decoder``.

    Args:
        z: starting latent point(s) for the solver; shape requirements are those of
           ``sample_traj_from_prior`` (TODO confirm against the solver).
        t: time points at which the trajectory is evaluated.
        n_traj_samples: forwarded to the solver; defaults to 1 as before.

    Returns:
        The decoder output for the sampled latent trajectories.
    """
    latent_traj = model.diffeq_solver.sample_traj_from_prior(z, t, n_traj_samples=n_traj_samples)
    return model.decoder(latent_traj)

In [None]:
recs = recs.squeeze()

In [None]:
recs[ix].squeeze().shape

In [None]:
recs[ix]

In [None]:
# sns.lineplot(x=ts[0][:100], y=recs[ix][:100].squeeze().detach().cpu().numpy())
sns.lineplot(x=ts[0], y=recs[ix].squeeze().detach().cpu().numpy())

In [None]:
truths.shape

In [None]:
# sns.lineplot(x=ts[0][:100], y=truths[ix][:100].squeeze().detach().cpu().numpy())
sns.lineplot(x=ts[0], y=truths[ix].squeeze().detach().cpu().numpy())

### TSNE

In [None]:
from sklearn.manifold import TSNE

In [None]:
z_us.shape

In [None]:
latents = z_us.squeeze().detach().cpu().numpy()

In [None]:
latents.shape

In [None]:
tsne = TSNE(perplexity=10)

In [None]:
xx = tsne.fit_transform(latents)

In [None]:
sns.scatterplot(x=xx[:,0], y=xx[:,1])

In [None]:
df = pd.DataFrame(xx, columns=['x1','x2'])
df['minimums'] = [np.round(x.min().item(),2) for x in truths]

In [None]:
sns.scatterplot(x='x1', y='x2', data=df, hue='minimums')

In [None]:
g=df.loc[(df['x1']<-100) & (df['x2']<-100)]
g=df.loc[(df['x1']>0) & (df['x2']>100)]

In [None]:
gx = g.index

In [None]:
len(truths)

In [None]:
ls=truths[gx]

In [None]:
# Plot every target series against its own batch's time grid. `ts` holds one grid per
# batch, while `truths` holds one row per series, so `ts[i]` was the wrong grid (and an
# IndexError past the number of batches). Assuming `loader` is a torch DataLoader, batches
# are consecutive and full except possibly the last, so series i is in batch
# i // batch_size. Without a fixed batch_size, fall back to the first batch's grid.
loader_batch_size = getattr(loader, 'batch_size', None) or len(truths)
for i, series in enumerate(truths):
    values = series.detach().cpu().squeeze().numpy()
    grid = ts[min(i // loader_batch_size, len(ts) - 1)]
    fig, ax = plt.subplots(figsize=(5, 3))
    sns.lineplot(x=grid, y=values, ax=ax)
    ax.set(title=f'Series {i}', xlabel='time', ylabel='value')
    plt.show()
    plt.close(fig)  # one figure per series: release each once displayed