In [2]:
import torch
import pandas as pd
import astropy as ap
import numpy as np
import matplotlib.pyplot as plt 
import seaborn as sns
# from astropy.io import fits
import pdb
from scipy.ndimage.filters import maximum_filter1d
import glob
import fitsio as fits
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.distributions.normal import Normal
from tqdm import tqdm
from utils import collate_interp_sparse

In [3]:
import sys
sys.path.insert(0, '../')
sys.path.insert(1, '../latent_ode/')
import latent_ode.lib as ode
import latent_ode.lib.utils as utils
from latent_ode.lib.latent_ode import LatentODE
from latent_ode.lib.ode_rnn import ODE_RNN
from latent_ode.lib.encoder_decoder import Encoder_z0_ODE_RNN, Decoder
from latent_ode.lib.diffeq_solver import DiffeqSolver
from latent_ode.lib.ode_func import ODEFunc

In [1]:
# from latent_rnn import create_LatentODE_model
# from ode_rnn_tess import create_ODERNN_model
# from latent_rnn_tess_interp import create_LatentODE_model
from ode_rnn import create_ODERNN_model

In [4]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

In [5]:
obsrv_std = torch.Tensor([0.2]).to(device)
z0_prior = Normal(torch.Tensor([0.0]).to(device), torch.Tensor([1.]).to(device))
input_dim = 1
latent_dim = 40

In [6]:
model = create_ODERNN_model()
# model = create_LatentODE_model(input_dim, z0_prior, obsrv_std)

In [7]:
# model_file = 'models/ode_rnn_state_tess.pth.tar'
model_file = 'ode_rnn_state_tess_new.pth.tar'
# model_file = 'latent_ode_state.pth.tar'
# model_file = 'ode_r'

In [8]:
state = torch.load(model_file, map_location=torch.device('cpu'))

In [9]:
model.load_state_dict(state['state_dict'])

<All keys matched successfully>

In [10]:
model.eval()

ODE_RNN(
  (ode_gru): Encoder_z0_ODE_RNN(
    (GRU_update): GRU_unit(
      (update_gate): Sequential(
        (0): Linear(in_features=82, out_features=50, bias=False)
        (1): Tanh()
        (2): Linear(in_features=50, out_features=40, bias=False)
        (3): Sigmoid()
      )
      (reset_gate): Sequential(
        (0): Linear(in_features=82, out_features=50, bias=False)
        (1): Tanh()
        (2): Linear(in_features=50, out_features=40, bias=False)
        (3): Sigmoid()
      )
      (new_state_net): Sequential(
        (0): Linear(in_features=82, out_features=50, bias=False)
        (1): Tanh()
        (2): Linear(in_features=50, out_features=80, bias=False)
      )
    )
    (z0_diffeq_solver): DiffeqSolver(
      (ode_func): ODEFunc(
        (gradient_net): Sequential(
          (0): Linear(in_features=40, out_features=100, bias=False)
          (1): Tanh()
          (2): Linear(in_features=100, out_features=100, bias=False)
          (3): Tanh()
          (4): Linear(

### Calc

In [11]:
def sample(z, t):
    sol_y = model.diffeq_solver.sample_traj_from_prior(z, t, n_traj_samples = 100)
    res = model.decoder(sol_y)
    return res

In [12]:
files = glob.glob('tess/16_processed/z_normalized/*.npy')

In [13]:
len(files)

19996

In [14]:
def proc_batch(batch):
    data = np.stack([np.load(f) for f in batch])
    data = torch.FloatTensor(data)
    inp = collate_interp_sparse(data)
    observed = inp['observed_data'].to(device)
    true = inp['data_to_predict'].to(device)
    mask = inp['observed_mask'].to(device)
    t = inp['observed_tp'].to(device)
    x = torch.cat((observed, mask), -1)
    try:
        z_u, z_std = model.encoder_z0.forward(x, t)
    except:
        z_u, z_std = model.ode_gru.forward(x, t)
    return z_u, t

In [15]:
batches = []
for i in range(0, len(files)+1, 100):
    batch = files[i:i+100]
    batches.append(batch)

In [21]:
proc_batch(batches[-1])[0]

tensor([[[ 0.0262, -0.1007,  0.1064,  ..., -0.1735, -0.0662, -0.0240],
         [ 0.0283, -0.0784,  0.1105,  ..., -0.1655, -0.0611, -0.0233],
         [ 0.0272, -0.0962,  0.1064,  ..., -0.1714, -0.0641, -0.0228],
         ...,
         [ 0.0202, -0.1094,  0.1047,  ..., -0.1759, -0.0700, -0.0308],
         [ 0.0400, -0.0389,  0.1145,  ..., -0.1496, -0.0458, -0.0119],
         [ 0.0249, -0.0963,  0.1090,  ..., -0.1724, -0.0667, -0.0278]]],
       grad_fn=<SliceBackward>)

In [17]:
all_vecs = []
for batch in tqdm(batches[:2]):
    batch_vecs, t = proc_batch(batch)
    batch_vecs = batch_vecs.squeeze().detach().cpu().numpy()
    all_vecs.append(batch_vecs)

100%|██████████| 2/2 [00:00<00:00,  2.76it/s]


In [18]:
all_vecs = np.concatenate([x for x in all_vecs])

In [19]:
all_vecs.shape

(200, 40)

In [27]:
all_vecs.std(0)

array([0.0284374 , 0.05228927, 0.01052977, 0.03690663, 0.01602304,
       0.01664556, 0.03114907, 0.02575341, 0.02132449, 0.03467505,
       0.0504834 , 0.03346518, 0.02468674, 0.05077022, 0.02560313,
       0.01559914, 0.02327095, 0.0328625 , 0.03939604, 0.01432921,
       0.01921072, 0.02570727, 0.0706727 , 0.03311468, 0.02375094,
       0.01617527, 0.01882573, 0.01198714, 0.04316881, 0.02995796,
       0.02319291, 0.02120319, 0.04068038, 0.03053875, 0.01740295,
       0.08976527, 0.02174713, 0.01686997, 0.03178869, 0.01670964],
      dtype=float32)

In [20]:
samples = []
for x in tqdm(all_vecs[:200]):
    x = torch.FloatTensor(x).unsqueeze(0).unsqueeze(0)
    y = sample(x, t)
    samples.append(y)

  0%|          | 0/200 [00:00<?, ?it/s]


AttributeError: 'ODE_RNN' object has no attribute 'diffeq_solver'

In [None]:
for s in samples[:20]:
    y = s.squeeze().detach().cpu().numpy()
    sns.lineplot(x=np.arange(len(y)), y=y)

In [None]:
out = y.squeeze().detach().numpy()

In [None]:
sns.lineplot(x=np.arange(len(out)), y=out)

In [None]:
yy = true.squeeze().detach().numpy()

In [None]:
sns.scatterplot(x=np.arange(len(yy)), y=yy)