In [1]:
import os, glob, re, sys
import socket
import datetime
import torch
import yaml
import numpy as np

import matplotlib.pyplot as plt
import matplotlib.cm as cm
import seaborn as sb
import pandas as pd
import umap
from sklearn.manifold import TSNE
from tqdm import tqdm_notebook
from collections import OrderedDict

sys.path.append('../')
from src.vae_models_test import *
from src.datasets import Astro_lightcurves
from src.utils import plot_wall_time_series

from ipywidgets import interact, widgets
from bokeh.io import push_notebook, show, output_notebook
from bokeh.plotting import figure
from bokeh.transform import linear_cmap

import wandb

output_notebook()
%matplotlib inline

In [2]:
ID = 'b68z1hwo'
gpu = False

In [3]:
if socket.gethostname() == 'exalearn':
    path = '/home/jorgemarpa/Astro/AstroLCs_DeepGen'
else:
    path = '/Users/jorgetil/Astro/AstroLCs_DeepGen'
    api = wandb.Api()
    run = api.run('jorgemarpa/Phy-VAE/%s' % (ID))
    run.file('VAE_model_None.pt').download(replace=True, 
                                           root='%s/wandb/run--%s/' % (path, ID))
    run.file('config.yaml').download(replace=True, 
                                     root='%s/wandb/run--%s/' % (path, ID))

device = torch.device("cuda:0" if torch.cuda.is_available() and gpu else "cpu")
if device.type == 'cuda':
    print('%.2f GB' % (torch.cuda.memory_allocated(device) / 1073741824))
    torch.cuda.empty_cache()
print(device)

cpu


## Load model variables into VAE model

In [5]:
vae, config = load_model_list(ID=ID)
config

Loading from... 
 /Users/jorgetil/Astro/AstroLCs_DeepGen/wandb/run--b68z1hwo/VAE_model_None.pt
Is model in cuda?  False


{'architecture': 'tcn',
 'batch_size': 128,
 'beta_vae': '0.75',
 'classes': 'all',
 'data': 'OGLE3',
 'dropout': 0.2,
 'epochs': 150,
 'feed_pp': 'F',
 'hidden_size': 48,
 'kernel_size': 5,
 'label_dim': 8,
 'latent_dim': 4,
 'latent_mode': 'repeat',
 'learning_rate': 0.001,
 'learning_rate_scheduler': 'cos',
 'n_feats': 3,
 'n_train_params': 300897,
 'num_layers': 9,
 'phys_params': '',
 'physics_dim': 0,
 'sequence_lenght': 600,
 'transpose': False,
 'normed': True,
 'folded': True,
 'date': ''}

In [11]:
dataset = Astro_lightcurves(survey=config['data'],
                            band='I' if config['data'] else 'B',
                            use_time=True,
                            use_err=True,
                            norm=config['normed'],
                            folded=config['folded'],
                            machine=socket.gethostname(),
                            seq_len=config['sequence_lenght'],
                            phy_params=config['phys_params'])

if config['classes'].split('_')[0] == 'drop':
    dataset.drop_class(config['classes'].split('_')[1])
elif config['classes'].split('_')[0] == 'only':
    dataset.only_class(config['classes'].split('_')[1])
print('Using physical parameters: ', dataset.phy_names)
dataset.remove_nan()
print(dataset.class_value_counts())
print('Total: ', len(dataset))
num_cls = dataset.labels_onehot.shape[1]

dataloader, _ = dataset.get_dataloader(batch_size=100, 
                                       test_split=0., shuffle=False)

Loading from:
 /Users/jorgetil/Google Drive/Colab_Notebooks/data/time_series/real/OGLE3_lcs_I_meta_snr5_augmented_folded_trim600.npy.gz
Using physical parameters:  []
ELL      10365
RRLYR    10169
CEP      10045
LPV      10044
ECL      10000
DSCT      5090
T2CEP     5047
ACEP      5000
Name: Type, dtype: int64
None
Total:  65760


# Bokeh interactive

## Scatter

## Sliders

In [270]:
cls = 'RRLYR'
idx = dataset.meta.reset_index().query('Type == "%s"' % cls).sample(1).index
print(idx)
lc, label, onehot, pp_n = dataset[idx]
pp_inv = dataset.mm_scaler.inverse_transform(pp_n)
lc = torch.from_numpy(lc)
onehot = torch.from_numpy(onehot)
pp = torch.from_numpy(pp_n)
cc = torch.cat([onehot, pp], dim=1)


if config['conditional_dim'] > num_cls:
    lchat, mu, logvar, z = vae(lc, c=cc)
elif config['conditional_dim'] == num_cls:
    lchat, mu, logvar, z = vae(lc, c=onehot)
elif config['conditional_dim'] == 0:
    lchat, mu, logvar, z = vae(lc)

z_n = mu.detach().numpy().flatten()
lchat = lchat.detach().numpy()
lc = lc.detach().numpy()
print(z_n)

p = figure(title="VAE", plot_height=300, 
           plot_width=600, y_range=(1.1,-0.1),
           background_fill_color='#efefef',
           x_axis_label='Phase',
           y_axis_label='Normalized Mag')
#r = p.errorbar(x, y, color="royalblue", line_width=1.5, alpha=0.8)
real = p.circle(lc[0,:,0], lc[0,:,1], color='black', size=5, line_alpha=0)
pred = p.circle(lc[0,:,0], lchat[0,:,0], color='royalblue', size=5, line_alpha=0)

#create the coordinates for the errorbars
err_xs = []
err_ys = []
for x, y, yerr in zip(lc[0,:,0], lc[0,:,1], lc[0,:,2]):
    err_xs.append((x, x))
    err_ys.append((y - yerr, y + yerr))

# plot them
p.multi_line(err_xs, err_ys, color='black')
# show(p)

Int64Index([54204], dtype='int64')
[ 3.3176795e-01  2.3465063e-03 -3.0221080e-03  9.7502601e-01
  2.4862661e+00 -7.4028014e-04]




In [271]:
#dt = torch.from_numpy(np.sort(np.random.choice(np.linspace(0,1,500, 
#                                                               dtype=np.float32), 
#                                           size=300))).unsqueeze(0)
dt = torch.from_numpy(lc[0,:,0]).unsqueeze(0)

In [272]:
def update(label=label, 
           z0=z_n[0], z1=z_n[1], 
           z2=z_n[2], z3=z_n[3],
           z4=z_n[4], z5=z_n[5], 
           period=pp_inv[0,0]):
    
    # idx = dataset.meta.reset_index().query('Type == "%s"' % label).sample(1).index
    # lc, _, onehot, pp_n = dataset[idx]
    # real.data_source.data['x'] = lc[0,:,0]
    # real.data_source.data['y'] = lc[0,:,1]
    
    z_ = z.clone()
    z_[0,0] = z0
    z_[0,1] = z1
    z_[0,2] = z2
    z_[0,3] = z3
    z_[0,4] = z4
    z_[0,5] = z5
    print('Latent vector: ', z_)
    
    new_onehot = torch.from_numpy(dataset.label_onehot_enc.transform(np.array(label).reshape(-1,1)))
    print('Label: ', label, new_onehot)
    
    new_pp = torch.from_numpy(dataset.mm_scaler.transform(np.array([[period]],
                                                                   dtype=np.float32)))
    print('Phys params, enc(P, T_eff/met): ', new_pp)
    cc = torch.cat([new_onehot, new_pp], dim=1)
    
        
    if config['conditional_dim'] > num_cls:
        lchat = vae.decoder(z_, dt, c=cc)
    elif config['conditional_dim'] == num_cls:
        lchat = vae.decoder(z_, dt, c=new_onehot)
    elif config['conditional_dim'] == 0:
        lchat = vae.decoder(z_)
        
    lchat = lchat.detach().numpy()
    
    pred.data_source.data['x'] = dt[0].detach().numpy()
    pred.data_source.data['y'] = lchat[0,:,0]
    
    push_notebook()

In [269]:
show(p, notebook_handle=True)
if type(config['beta_vae']) == str:
    lim = 5
    
else:
    lim = 10
interact(update, label=dataset.label_onehot_enc.categories_[0].tolist(), 
         z0=(-lim,lim,.01), z1=(-lim,lim,.01), z2=(-lim,lim,.01), z3=(-lim,lim,.01),
         z4=(-lim,lim,.01), z5=(-lim,lim,.01), period=(0.01, 20, 0.01), teff=(3000, 9000, 10))

interactive(children=(Dropdown(description='label', index=6, options=('ACEP', 'CEP', 'DSCT', 'ECL', 'ELL', 'LP…

<function __main__.update(label=array(['RRLYR'], dtype=object), z0=-1.186738, z1=-0.00039400172, z2=0.00015353863, z3=-0.01102455, z4=0.22371186, z5=0.0010238732, period=0.5120501)>