In [None]:
import glob
import os
import pdb
import time

import astropy as ap
import fitsio as fits
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data_utils
# from astropy.io import fits
from scipy.ndimage.filters import maximum_filter1d
from torch.distributions.normal import Normal
from torch.nn.utils import clip_grad_norm_
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence, pad_sequence
from tqdm import tqdm

from utils import collate_interp_sparse_sectors

In [None]:
import sys
sys.path.insert(0, '../')
sys.path.insert(1, '../latent_ode/')
import latent_ode.lib as ode
import latent_ode.lib.utils as utils
from latent_ode.lib.latent_ode import LatentODE
from latent_ode.lib.ode_rnn import ODE_RNN
from latent_ode.lib.encoder_decoder import Encoder_z0_ODE_RNN, Decoder
from latent_ode.lib.diffeq_solver import DiffeqSolver
from latent_ode.lib.ode_func import ODEFunc

In [None]:
from ode_rnn_tess_sectors import create_ODERNN_model

In [None]:
# Inference is pinned to the CPU; the commented line is the GPU-aware alternative.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

In [None]:
# Instantiate the network through the project's ODE-RNN factory; the trained
# weights are loaded into it from the checkpoint further down.
model = create_ODERNN_model()
# model = create_LatentODE_model(input_dim, z0_prior, obsrv_std)

In [None]:
# Checkpoint path, resolved relative to the notebook's working directory.
model_file = 'ode_rnn_state_tess_sectors.pth.tar'

In [None]:
# Map every stored tensor onto the CPU, wherever the checkpoint was saved.
# NOTE(review): torch.load unpickles the file — only load checkpoints from a trusted source.
state = torch.load(model_file, map_location=torch.device('cpu'))

In [None]:
# The checkpoint is a dict; the model weights are stored under its 'state_dict' key.
model.load_state_dict(state['state_dict'])

In [None]:
# Switch to evaluation mode (turns off training-only behaviour such as dropout).
model.eval()

### Compute latent vectors

In [None]:
def sample(z, t, n_traj_samples=100):
    """Decode trajectories sampled from the prior, starting at latent state ``z``.

    Uses the notebook-level ``model``: its ODE solver draws the trajectories
    and its decoder maps them back to data space.

    Args:
        z: starting latent state passed to ``sample_traj_from_prior``.
        t: time points at which the trajectories are evaluated.
        n_traj_samples: number of trajectories to draw (default 100, the value
            that was previously hard-coded).

    Returns:
        The decoder output for the sampled latent trajectories.
    """
    latent_traj = model.diffeq_solver.sample_traj_from_prior(
        z, t, n_traj_samples=n_traj_samples
    )
    return model.decoder(latent_traj)

In [None]:
# Root of the light-curve data; the cells below expect one sub-directory per
# sector, each holding .npy files.
indir = 'tess/all_data/z_normalized/'

In [None]:
# Every entry directly under `indir` is treated as one sector.
sectors = os.listdir(indir)
sectors

In [None]:
# Upper bound on the number of files encoded together in one batch.
batch_size = 1000

In [None]:
# Index every .npy file under indir/<sector>/ and record, from row 0 of the
# stored array (presumably the time axis — TODO confirm), its min, max and length.
# NOTE(review): entries are keyed by file basename, so two sectors containing a
# file with the same name would overwrite each other — confirm names are unique.
res = {}
for sector in map(str, sectors):
    # glob returns paths in arbitrary order; sort so batches are reproducible.
    for f in sorted(glob.glob(os.path.join(indir, sector, '*.npy'))):
        first_row = np.load(f)[0]
        res[os.path.basename(f)] = {
            'path': f,
            't1': first_row.min(),
            't2': first_row.max(),
            'l': len(first_row),
            'sector': sector,
        }

if not res:
    # Without this the groupby below fails with an opaque KeyError.
    raise FileNotFoundError('No .npy files found under {0}'.format(indir))

df = pd.DataFrame(res).T

# Files can only be stacked into one tensor when they have the same length, so
# group by (sector, length) and cut every group into chunks of at most `bs` files.
groups = df.groupby(['sector', 'l'])['path'].apply(list).tolist()
print("Total groups: {0}".format(len(groups)))
bs = batch_size
batches = [
    group[start:start + bs]
    for group in groups
    for start in range(0, len(group), bs)
]

In [None]:
# Shortest series length among all indexed files.
df['l'].min()

In [None]:
# Total number of batches available for encoding.
len(batches)

In [None]:
def proc_batch(batch):
    """Encode one batch of light-curve files into latent vectors.

    Args:
        batch: list of paths to ``.npy`` files. All arrays must have the same
            shape so that they can be stacked into a single tensor.

    Returns:
        A ``(z_u, t)`` tuple: ``z_u`` is the first output of the model's
        encoder (the latent mean) and ``t`` is the ``'observed_tp'`` tensor
        produced by ``collate_interp_sparse_sectors``, moved to ``device``.

    Raises:
        AttributeError: if ``model`` exposes neither ``encoder_z0`` nor
            ``ode_gru``.
    """
    # Stack in NumPy first: building a tensor from a Python list of arrays is
    # slow and triggers a PyTorch warning.
    arrays = [np.load(f).astype(np.float32) for f in batch]
    data = torch.from_numpy(np.stack(arrays))

    inp = collate_interp_sparse_sectors(data)
    observed = inp['observed_data'].to(device)
    mask = inp['observed_mask'].to(device)
    t = inp['observed_tp'].to(device)
    # The encoder input is the observed values with their mask appended on the last axis.
    x = torch.cat((observed, mask), -1)

    # Pick the encoder explicitly instead of a bare `except:`, which would hide
    # genuine encoder failures by silently retrying with the other attribute.
    encoder = getattr(model, 'encoder_z0', None)
    if encoder is None:
        encoder = model.ode_gru

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        z_u, _ = encoder.forward(x, t)
    return z_u, t

In [None]:
# batches = []
# for i in tqdm(range(0, len(files)+1, batch_size)):
#     batch = files[i:i+batch_size]
#     batches.append(batch)

In [None]:
# Re-check the batch count before running the encoder.
len(batches)

In [None]:
# proc_batch(batches[-1])[0]

In [None]:
# Only the first few batches are encoded (quick run). The Output section
# flattens `batches[:4]`, so keep that slice in sync if this limit changes.
MAX_BATCHES = 4

all_vecs = []
for batch in tqdm(batches[:MAX_BATCHES]):
    batch_vecs, t = proc_batch(batch)
    # Keep exactly one row per file. A bare squeeze() would also drop the batch
    # axis when a batch holds a single file, which breaks the later flattening.
    batch_vecs = batch_vecs.detach().cpu().numpy().reshape(len(batch), -1)
    all_vecs.append(batch_vecs)
# all_vecs = np.concatenate([x for x in all_vecs])
# all_vecs.shape

In [None]:
# Peek at the first file of the last batch the loop above processed
# (relies on the loop variable `batch` still being in scope).
os.path.basename(batch[0])

### Output

In [None]:
# Flatten the per-batch arrays into one list with a single vector per file.
# vstack (rather than a nested comprehension) keeps this cell safe to re-run:
# an already-flat list of vectors comes out unchanged instead of being
# exploded into scalars.
all_vecs = list(np.vstack(all_vecs))

In [None]:
# Paths in the same order as the encoded vectors. Only the first four batches
# were encoded above, hence the batches[:4] slice.
files = [x for a in batches[:4] for x in a]

In [None]:
# Show a short preview only; dumping every path floods the notebook output.
files[:5]

In [None]:
# Every encoded vector must correspond to exactly one file path.
assert len(files) == len(all_vecs)

In [None]:
# One record per file, keyed by its position in `files`: the path, the file name
# without extension, the latent vector, and row 1 of the stored array
# (presumably the measured values — TODO confirm).
outp = {}
for i, f in enumerate(files):
    series = np.load(f)[1]
    outp[i] = {
        'filename': f,
        # splitext removes only the trailing extension; str.replace would also
        # strip '.npy' occurring anywhere else in the name.
        'basename': os.path.splitext(os.path.basename(f))[0],
        'vec': all_vecs[i],
        'series': series,
    }

In [None]:
# Persist the records; 'results.pkl' is written to the current working
# directory and replaces any existing file of that name.
pd.to_pickle(outp, 'results.pkl')

### Test: encode a single synthetic series

In [None]:
def calc_vec(z, dim=50):
    """Encode a single 1-D series into a latent vector.

    The series is paired with integer time stamps ``0 .. len(z) - 1``, wrapped
    as a batch of one, and passed through the same collate/encode path that is
    used for the stored light curves.

    Args:
        z: 1-D sequence of values.
        dim: number of elements in the encoder output for one series; the
            result is reshaped to ``(1, dim)``.

    Returns:
        NumPy array of shape ``(1, dim)`` holding the encoder's first output
        (the latent mean).

    Raises:
        AttributeError: if ``model`` exposes neither ``encoder_z0`` nor
            ``ode_gru``.
        RuntimeError: if the encoder output cannot be viewed as ``(1, dim)``.
    """
    timestamps = np.arange(0, len(z))
    stacked = np.stack([timestamps, z])
    # Add a leading batch axis so the single series looks like a batch of one.
    data = torch.FloatTensor(stacked.astype(np.float32)).unsqueeze(0)

    inp = collate_interp_sparse_sectors(data)
    observed = inp['observed_data'].to(device)
    mask = inp['observed_mask'].to(device)
    t = inp['observed_tp'].to(device)
    # The encoder input is the observed values with their mask appended on the last axis.
    x = torch.cat((observed, mask), -1)

    # Pick the encoder explicitly instead of a bare `except:`, which would hide
    # genuine encoder failures by silently retrying with the other attribute.
    encoder = getattr(model, 'encoder_z0', None)
    if encoder is None:
        encoder = model.ode_gru

    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        z_u, _ = encoder.forward(x, t)
    return z_u.view(1, dim).detach().cpu().numpy()

In [None]:
# Synthetic standard-normal series of length 300. The generator is seeded so
# the test input, and therefore the vector computed from it, is reproducible.
z = np.random.default_rng(0).standard_normal(300)

In [None]:
# Encode the synthetic series using the default latent size (dim=50).
v=calc_vec(z)

In [None]:
# Display the resulting latent vector.
v