In [None]:
import os
# Work from the project root so `constants`, `vae_ts_test` and `./lightning_logs` resolve.
# Remember the notebook's own directory on first execution so that re-running this cell
# does not climb one more directory level each time (the cell is now idempotent).
if 'NOTEBOOK_DIR' not in globals():
    NOTEBOOK_DIR = os.getcwd()
os.chdir(os.path.join(NOTEBOOK_DIR, os.pardir))

In [None]:
# `display` lives in `IPython.display`; importing it from `IPython.core.display` is deprecated.
from IPython.display import display, HTML

# Cosmetic only: stretch the notebook container to the full browser width.
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import ipywidgets as widgets
import torch
import yaml
from ipywidgets import interact
from plotly.subplots import make_subplots
from torch.utils.data import DataLoader

import constants as const
from vae_ts_test.dataset import SimpleRandomCurvesDataset
from vae_ts_test.vae import VAE
%load_ext autoreload
%autoreload 2

In [None]:
# Build the dataset from the path configured in `constants` and preview its
# `df_scaled` frame (presumably the scaled series values — confirm in the dataset class).
dataset = SimpleRandomCurvesDataset(const.DATA_PATH)
dataset.df_scaled.head()

In [None]:
batch_size = 10
# Cap worker processes at the number of available CPUs: a hard-coded 24 oversubscribes
# smaller machines, and every fresh `iter(dataloader)` spawns that many processes.
num_workers = min(24, os.cpu_count() or 1)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

In [None]:
MODEL_VERSION = 23
version_dir = f'./lightning_logs/version_{MODEL_VERSION}'

# Hyper-parameters recorded for this run (kept for inspection).
hparams_path = f'{version_dir}/hparams.yaml'
with open(hparams_path, 'r') as stream:
    hparam_dct = yaml.safe_load(stream)

# `os.listdir` returns entries in arbitrary order, so taking `[-1]` did not reliably pick
# the latest checkpoint. Choose the most recently written `.ckpt` file explicitly.
ckpt_dir = f'{version_dir}/checkpoints'
ckpt_names = [name for name in os.listdir(ckpt_dir) if name.endswith('.ckpt')]
if not ckpt_names:
    raise FileNotFoundError(f'No .ckpt files found in {ckpt_dir}')
ckpt_file_name = max(ckpt_names, key=lambda name: os.path.getmtime(os.path.join(ckpt_dir, name)))
ckpt_file_path = f'{ckpt_dir}/{ckpt_file_name}'

model = VAE.load_from_checkpoint(ckpt_file_path)
model

In [None]:
# Newer PyTorch DataLoader iterators have no Python-2-style `.next()` alias; use the builtin.
x_batch = next(iter(dataloader))
x_batch.shape

In [None]:
# Inference only: switch the model to eval mode and skip building an autograd graph
# (every later use of these outputs detaches them anyway).
with torch.no_grad():
    mu_z, std_z, z_sample, mu_x, std_x = model.eval()(x_batch)

In [None]:
# Fifth output of the model's forward pass (named std_x; presumably the per-step std of
# p(x|z) — confirm against VAE.forward).
std_x

In [None]:
# One row per series in the batch. Derive the shape from the tensor itself rather than
# hard-coding `batch_size` and 30, which breaks for a short final batch or another length.
df_x = pd.DataFrame(x_batch.detach().numpy().reshape(x_batch.shape[0], -1))
df_x.head()

In [None]:
# Reconstruction means, one row per series of the same batch as `df_x` (row count taken
# from `x_batch`, series length inferred) instead of the hard-coded `batch_size` x 30.
df_mu_x = pd.DataFrame(mu_x.detach().numpy().reshape(x_batch.shape[0], -1))
df_mu_x.head()

In [None]:
index = 1
std = [.1] * 30  # one relative half-width (10% of the mean) per column
# Band proportional to the reconstructed mean of row `index`.
y_upper, y_lower = (
    df_mu_x.loc[index, :].values + df_mu_x.loc[index, :].values * std,
    df_mu_x.loc[index, :].values - df_mu_x.loc[index, :].values * std,
)

In [None]:
def recon_plot(index):
    """Plot one input series against its reconstruction with a +/- 1 sigma band.

    Draws row ``index`` of ``df_x`` (input) and ``df_mu_x`` (decoder mean), plus a
    shaded band ``mean +/- exp(model.log_scale_diag)``.

    Args:
        index: Row label in ``df_x`` / ``df_mu_x`` selecting one series of the batch.
    """
    fig = go.Figure()
    series = zip([df_x, df_mu_x], ['x', 'p(x|z)'], ['rgb(0,0,100)', 'rgb(192,58,58)'])
    for df, name, colour in series:
        fig.add_trace(
            go.Scatter(x=df.columns, y=df.loc[index, :].values, name=name,
                       line=dict(color=colour), mode="lines", opacity=.5)
        )

    x = list(df_mu_x.columns)
    y = df_mu_x.loc[index, :].values
    # The model stores the log of the scale; flatten so a leading singleton dimension
    # cannot turn `y + std` into a 2-D array (plotly needs one value per column).
    std = np.exp(model.log_scale_diag.detach().numpy()).reshape(-1)
    y_upper = list(y + std)
    y_lower = list(y - std)

    # Closed polygon: upper edge left-to-right, then lower edge right-to-left.
    fig.add_trace(go.Scatter(
        x=x + x[::-1],
        y=y_upper + y_lower[::-1],
        fill='toself',
        fillcolor='rgba(192,58,58,0.1)',
        line=dict(color='rgba(255,255,255,0)'),
        hoverinfo="skip",
        showlegend=False,
    ))
    fig.show()
# One dropdown entry per series in the batch; trailing `;` hides the returned function's repr.
interact(recon_plot, index=df_x.index);

In [None]:
# First output of the forward pass (named mu_z; presumably the posterior mean of z).
mu_z

In [None]:
# Second output of the forward pass (named std_z; presumably the posterior std of z).
std_z

In [None]:
# Learned log-scale parameter; its exponential is used as the reconstruction band half-width.
model.log_scale_diag

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

In [None]:
# Peek at the first ten column labels of the reconstruction frame.
list(df_mu_x.columns[0:10])

In [None]:
# Series length. The original `len(x)` only worked if a later cell had already defined `x`
# (inside `recon_plot` it is a local), so it raised NameError on Restart & Run All.
len(df_mu_x.columns)

In [None]:
# `index` and `std` come from earlier cells (in a top-to-bottom run `std` is the constant
# 0.1-per-column list), so this is a fixed-width band around the reconstructed mean.
x = list(df_mu_x.columns)
y = df_mu_x.loc[index, :].values
y_upper, y_lower = list(y + std), list(y - std)

fig = go.Figure()
# Mean curve.
fig.add_trace(go.Scatter(x=x, y=y, mode='lines', line=dict(color='rgb(0,100,80)')))
# Shaded band: upper edge left-to-right, then lower edge reversed, closed via 'toself'.
fig.add_trace(go.Scatter(
    x=x + x[::-1],
    y=y_upper + y_lower[::-1],
    fill='toself',
    fillcolor='rgba(0,100,80,0.2)',
    line=dict(color='rgba(255,255,255,0)'),
    hoverinfo="skip",
    showlegend=False,
))
fig.show()

In [None]:
# Stand-alone toy example of a line with a shaded band, using hard-coded data.
x = list(range(1, 11))
y = [1, 2, 7, 4, 5, 6, 7, 8, 9, 10]
y_upper = [value + 1 for value in y]
y_lower = [0, 1, 5, 3, 4, 5, 6, 7, 8, 9]

fig = go.Figure()
fig.add_trace(go.Scatter(x=x, y=y, mode='lines', line=dict(color='rgb(0,100,80)')))
# Polygon traced along the upper edge, then back along the reversed lower edge.
fig.add_trace(go.Scatter(
    x=x + x[::-1],
    y=y_upper + y_lower[::-1],
    fill='toself',
    fillcolor='rgba(0,100,80,0.2)',
    line=dict(color='rgba(255,255,255,0)'),
    hoverinfo="skip",
    showlegend=False,
))
fig.show()

In [None]:
# Exponentiate the learned log-scale parameter to view the scale values themselves.
np.exp(model.log_scale_diag.detach().numpy())

In [None]:
def plot_generated(z1, z2, z3):
    """Decode a hand-picked 3-D latent point and plot its mean with a +/- 1 sigma band.

    Args:
        z1, z2, z3: Coordinates of the latent vector passed to ``model.decoder``.
    """
    z = torch.tensor(np.array([z1, z2, z3], dtype=np.float32))
    # Pure inference: no autograd graph needed for plotting.
    with torch.no_grad():
        mean = model.decoder(z).numpy().reshape(-1)
    sig = np.exp(model.log_scale_diag.detach().numpy()).reshape(-1)
    # Additive band mean +/- sigma. The original computed `upper` and `lower` with the same
    # sign, so the band collapsed to zero width.
    upper = mean + sig
    lower = mean - sig
    x = list(range(len(mean)))

    fig = go.Figure()
    fig.add_trace(go.Scatter(x=x, y=mean, name='p(x|z) mean', mode='lines',
                             line=dict(color='rgb(0,100,80)')))
    # Use Python lists so `+` concatenates into a closed polygon (upper edge forward, lower
    # edge reversed); on numpy arrays `+` would add elementwise instead.
    fig.add_trace(go.Scatter(
        x=x + x[::-1],
        y=list(upper) + list(lower[::-1]),
        fill='toself',
        fillcolor='rgba(0,100,80,0.2)',
        line=dict(color='rgba(255,255,255,0)'),
        hoverinfo="skip",
        showlegend=False,
    ))
    fig.show()

# One slider per latent coordinate over the same [-10, 10] range as before. The decoder input
# is cast to float32, so use float sliders instead of integer-only steps.
ranges = [widgets.FloatSlider(min=-10, max=10, step=0.1, value=0.0) for _ in range(3)]
interact(plot_generated, z1=ranges[0], z2=ranges[1], z3=ranges[2]);

In [None]:
# Quick check: np.diag of a 1-D list builds a 2-D diagonal matrix (here the 3x3 identity).
np.diag([1, 1, 1])

In [None]:
mu = torch.Tensor([10, 20, 30])
sig_diag_log = torch.Tensor([1, 2, 3])
# Exponentiate the log-scales and put them on a diagonal. A diagonal matrix is its own
# lower-triangular Cholesky factor, so it can be passed directly as `scale_tril`.
sig_diag = sig_diag_log.exp()
sig = sig_diag.diag()
dist = torch.distributions.MultivariateNormal(loc=mu, scale_tril=sig)

In [None]:
num_samples = 10_000
# Fixed seed so the histogram is reproducible on Restart & Run All.
torch.manual_seed(0)
# Draw every sample in one vectorised call, shape (num_samples, event_dim), instead of
# 10k separate `dist.sample()` calls in a Python loop.
df_dist = pd.DataFrame(dist.sample((num_samples,)).numpy())

fig = go.Figure()
for c in df_dist.columns:
    fig.add_trace(go.Histogram(x=df_dist[c], name=f'dim {c}'))

# Overlay the per-dimension histograms
fig.update_layout(barmode='overlay')
# Reduce opacity so overlapping histograms stay visible
fig.update_traces(opacity=0.75)
fig.show()

In [None]:
loc, scale = torch.zeros(3), torch.ones(3)
# Note: the scale factor is `sig` from an earlier cell, not `scale`;
# `scale_tril=torch.diag(scale)` would give a standard normal instead.
mvn = torch.distributions.MultivariateNormal(loc=loc, scale_tril=sig)


In [None]:
# Diagonal matrix built from `scale` (all ones, so the 3x3 identity).
torch.diag(scale)

In [None]:
# Display the distribution object's repr, which summarises its parameters.
mvn

In [None]:
# e = exp(1), roughly 2.718.
np.exp(1)