In [None]:
import torch
import pandas as pd
import astropy as ap
import numpy as np
import matplotlib.pyplot as plt 
import seaborn as sns
# from astropy.io import fits
import pdb
from scipy.ndimage.filters import maximum_filter1d
import glob
import os
import fitsio as fits
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from torch.distributions.normal import Normal
from tqdm import tqdm
from utils import collate_interp_sparse
import time

In [None]:
import sys
sys.path.insert(0, '../')
sys.path.insert(1, '../latent_ode/')
import latent_ode.lib as ode
import latent_ode.lib.utils as utils
from latent_ode.lib.latent_ode import LatentODE
from latent_ode.lib.ode_rnn import ODE_RNN
from latent_ode.lib.encoder_decoder import Encoder_z0_ODE_RNN, Decoder
from latent_ode.lib.diffeq_solver import DiffeqSolver
from latent_ode.lib.ode_func import ODEFunc

In [None]:
# from latent_rnn import create_LatentODE_model
# from ode_rnn_tess import create_ODERNN_model
# from latent_rnn_tess_interp import create_LatentODE_model
from ode_rnn import create_ODERNN_model

In [None]:
# Device that input tensors are moved to. It is pinned to the CPU; the commented
# line below is the "GPU if available" alternative.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

In [None]:
# Settings for the commented-out create_LatentODE_model(...) call in the next
# cell. The active ODE-RNN path shown in this notebook does not read them.
input_dim = 1
latent_dim = 40
obsrv_std = torch.Tensor([0.1]).to(device)
z0_prior = Normal(
    loc=torch.Tensor([0.0]).to(device),
    scale=torch.Tensor([1.0]).to(device),
)

In [None]:
model = create_ODERNN_model()
# model = create_LatentODE_model(input_dim, z0_prior, obsrv_std)

In [None]:
# model_file = 'models/ode_rnn_state_tess.pth.tar'
model_file = 'ode_rnn_state_tess_new.pth.tar'
# model_file = 'latent_ode_state.pth.tar'
# model_file = 'ode_r'

In [None]:
state = torch.load(model_file, map_location=torch.device('cpu'))

In [None]:
model.load_state_dict(state['state_dict'])

In [None]:
model.eval()

### Calc: encode light curves into latent vectors

In [None]:
def sample(z, t, n_traj_samples=100):
    """Decode a latent starting point into a trajectory using the global ``model``.

    Calls ``model.diffeq_solver.sample_traj_from_prior`` on ``z`` and ``t`` and
    passes the resulting latent trajectory through ``model.decoder``.

    Args:
        z: latent starting point. The sampling loop in this notebook passes one
            row of ``all_vecs`` reshaped to (1, 1, len(row)); other shapes are
            untested here -- TODO confirm.
        t: time points to evaluate at (in this notebook, the ``observed_tp``
            returned by ``proc_batch``).
        n_traj_samples: forwarded to ``sample_traj_from_prior``. The default of
            100 keeps the previous hard-coded behaviour.

    Returns:
        The decoder output for the sampled trajectory.
    """
    latent_traj = model.diffeq_solver.sample_traj_from_prior(
        z, t, n_traj_samples=n_traj_samples)
    return model.decoder(latent_traj)

In [None]:
indir = 'tess/16_17/z_normalized/'

In [None]:
files = glob.glob(indir+'*.npy')

In [None]:
len(files)

In [None]:
def proc_batch(batch):
    """Encode a batch of light-curve files with the global ``model``'s encoder.

    Each path in ``batch`` is loaded with ``np.load`` and stacked (so every file
    must have the same array shape). The stack is packed by
    ``collate_interp_sparse``. Its observed values and observation mask are
    concatenated along the last axis and fed to the encoder together with the
    observed time points.

    Args:
        batch: non-empty list of ``.npy`` file paths.

    Returns:
        ``(z_u, t)``: the encoder's first output (the second, ``z_std``, is
        discarded) and the ``observed_tp`` time points used for the batch.

    Raises:
        ValueError: if ``batch`` is empty (nothing to stack).
    """
    if len(batch) == 0:
        raise ValueError("proc_batch needs at least one file")
    data = torch.FloatTensor(np.stack([np.load(f) for f in batch]))
    inp = collate_interp_sparse(data)
    observed = inp['observed_data'].to(device)
    mask = inp['observed_mask'].to(device)
    t = inp['observed_tp'].to(device)
    x = torch.cat((observed, mask), -1)
    # Pick the encoder by attribute: `encoder_z0` (presumably LatentODE-style
    # models) or `ode_gru` (presumably ODE_RNN-style). A bare `except:` fallback
    # would hide genuine encoder errors and even KeyboardInterrupt.
    encoder = model.encoder_z0 if hasattr(model, 'encoder_z0') else model.ode_gru
    # Inference only: do not build autograd graphs for the whole batch.
    with torch.no_grad():
        z_u, z_std = encoder(x, t)
    return z_u, t

In [None]:
batch_size = 1000

In [None]:
batches = []
for i in range(0, len(files)+1, batch_size):
    batch = files[i:i+batch_size]
    batches.append(batch)

In [None]:
len(batches)

In [None]:
# proc_batch(batches[-1])[0]

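If using multiprocessing, comment out the serial loop in the next cell and run the cells under **Multiprocessing** instead. Note that the multiprocessing path collects its per-batch results in `res`, not in `all_vecs`.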

In [None]:
# Encode every batch serially and stack the latent vectors in `files` order.
all_vecs = []
for batch in tqdm(batches):
    batch_vecs, t = proc_batch(batch)
    batch_vecs = batch_vecs.detach().cpu().numpy()
    # Collapse leading axes into one row per file. squeeze() would also drop the
    # file axis of a one-file batch, and concatenating that 1-D vector would add
    # len(vector) bogus rows. Assumes the encoder output is (1, n_files, dim)
    # or (n_files, dim) -- TODO confirm.
    all_vecs.append(batch_vecs.reshape(-1, batch_vecs.shape[-1]))
all_vecs = np.concatenate(all_vecs)
all_vecs.shape

### Output

In [None]:
files[:3]

In [None]:
fnames = [f.split(indir)[-1][:-4] for f in files]

In [None]:
assert(len(fnames)==len(all_vecs))

In [None]:
outp = dict(zip(fnames, all_vecs))

In [None]:
outp['tess2019253231442-s0016-0000000359584313-0152-s_lc']

In [None]:
pd.to_pickle(outp, 'tess_ode.pkl')

### Multiprocessing

In [None]:
from multiprocessing.pool import Pool

In [None]:
def func(batch):
    """``Pool.map`` worker: encode one batch and return its latent vectors as NumPy.

    The time points that ``proc_batch`` also returns are discarded.

    Returns:
        A 2-D array intended to have one row per file in ``batch`` (assumes the
        encoder output is (1, n_files, dim) or (n_files, dim) -- TODO confirm).
    """
    batch_vecs, _ = proc_batch(batch)
    batch_vecs = batch_vecs.detach().cpu().numpy()
    # reshape instead of squeeze(): a one-file batch keeps its row axis, so the
    # per-batch results can be concatenated along axis 0.
    return batch_vecs.reshape(-1, batch_vecs.shape[-1])

In [None]:
pool = Pool(6)

In [None]:
t0 = time.time()
res = pool.map(func, batches)
t1 = time.time()

In [None]:
pool.close()
pool.join()

In [None]:
print(t1-t0)

In [None]:
# Decode the first 200 latent vectors back into trajectories. `t` is whichever
# time-point tensor was last bound globally (the serial encoding loop leaves
# the final batch's `observed_tp`).
# no_grad: otherwise every stored sample keeps the autograd graph of its ODE
# solve alive, which grows memory with each iteration.
samples = []
with torch.no_grad():
    for x in tqdm(all_vecs[:200]):
        x = torch.FloatTensor(x).unsqueeze(0).unsqueeze(0).to(device)  # -> (1, 1, len(row))
        y = sample(x, t)
        samples.append(y)

In [None]:
# Overlay the first 20 decoded trajectories on one labelled axis. The loop uses
# its own variable names so it does not overwrite the global `y` (the last sample
# tensor from the sampling loop) with a NumPy array.
fig, ax = plt.subplots()
for traj in samples[:20]:
    curve = traj.squeeze().detach().cpu().numpy()
    sns.lineplot(x=np.arange(len(curve)), y=curve, ax=ax)
ax.set(title='Decoded trajectories (first 20 samples)',
       xlabel='time-step index', ylabel='decoder output');

In [None]:
out = y.squeeze().detach().numpy()

In [None]:
sns.lineplot(x=np.arange(len(out)), y=out)

In [None]:
yy = true.squeeze().detach().numpy()

In [None]:
sns.scatterplot(x=np.arange(len(yy)), y=yy)