In [None]:
import os

# Switch to the parent directory (presumably the project root, so relative paths
# such as ./lightning_logs/ resolve). The target is remembered on first run so
# re-executing this cell is idempotent instead of climbing one level each time.
if '_PROJECT_ROOT' not in globals():
    _PROJECT_ROOT = os.path.abspath('..')
os.chdir(_PROJECT_ROOT)

In [None]:
# Widen the notebook's `.container` element to the full browser width
# (presumably only effective in the classic Notebook UI).
# IPython.display is the public home of display/HTML; importing them from
# IPython.core.display is deprecated.
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [None]:
from vae_ts_test.dataset import SimpleRandomCurvesDataset
import constants as const
from torch.utils.data import DataLoader
import yaml
import pandas as pd
import torch
import numpy as np
import matplotlib.pyplot as plt
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from ipywidgets import interact
import ipywidgets as widgets
%load_ext autoreload
%autoreload 2

In [None]:
# Dataset built from the data file and hidden-state file configured in constants.
dataset = SimpleRandomCurvesDataset(
    const.DATA_PATH,
    const.HIDDEN_STATE_PATH,
)
dataset.df_data.head()

In [None]:
batch_size = 100
# Cap the worker processes at the number of CPUs actually available
# (os.cpu_count() can return None); a fixed 24 over-subscribes smaller machines.
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=min(24, os.cpu_count() or 1))

In [None]:
from vae_ts_test.vae import VAE

# Alternative run: MODEL_VERSION = 'freq_and_phase'
MODEL_VERSION = 'version_134'
run_dir = os.path.join('.', 'lightning_logs', MODEL_VERSION)

hparams_path = os.path.join(run_dir, 'hparams.yaml')
with open(hparams_path, 'r') as stream:
    hparam_dct = yaml.safe_load(stream)

# Pick the most recently written checkpoint. os.listdir() returns entries in
# arbitrary order, and lexicographic order is also wrong ('epoch=9' > 'epoch=10').
ckpt_dir = os.path.join(run_dir, 'checkpoints')
ckpt_paths = [
    os.path.join(ckpt_dir, fname)
    for fname in os.listdir(ckpt_dir)
    if fname.endswith('.ckpt')
]
if not ckpt_paths:
    raise FileNotFoundError(f'no .ckpt files found in {ckpt_dir}')
ckpt_file_path = max(ckpt_paths, key=os.path.getmtime)
ckpt_file_name = os.path.basename(ckpt_file_path)
model = VAE.load_from_checkpoint(ckpt_file_path)
model

In [None]:
# Keep a single iterator over the DataLoader so that repeatedly drawing from it
# in the following cell yields successive batches rather than restarting at the first.
batches = iter(dataloader)

In [None]:
# DataLoader iterators follow the Python 3 iterator protocol: use the builtin
# next(). The `.next()` method was a Python 2 shim and is absent in recent PyTorch.
x_batch, idxs = next(batches)
x_batch.shape

In [None]:
# Inference only: eval() puts the model in evaluation mode, and no_grad() avoids
# building an autograd graph for outputs that are only inspected and plotted.
with torch.no_grad():
    mu_z, std_z, z_sample, mu_x, std_x = model.eval()(x_batch)

In [None]:
# Shape of the encoder mean tensor.
mu_z.size()

In [None]:
# Shape of the encoder standard-deviation tensor.
std_z.size()

In [None]:
def recon_plot(index):
    """Plot sample ``index`` of the current batch against its reconstruction.

    One subplot row per sensor channel, each showing the input signal, the
    reconstructed mean, and a +/- 2 sigma band where sigma = exp(model.log_scale_diag).
    Reads the notebook globals ``x_batch``, ``mu_x`` and ``model``.

    Args:
        index: position of the sample within the current batch.
    """
    n_std = 2  # half-width of the shaded band, in standard deviations
    x_np = x_batch.detach().numpy()
    mu_np = mu_x.detach().numpy()
    n_sensors = mu_np.shape[-1]
    t = list(range(mu_np.shape[1]))
    # NOTE(review): assumes log_scale_diag is laid out timestep-major, i.e. it
    # reshapes to (timesteps, n_sensors); the original hard-coded 2 columns.
    std = np.exp(model.log_scale_diag.detach().numpy()).reshape(-1, n_sensors)

    fig = make_subplots(rows=n_sensors, cols=1, shared_xaxes=True)
    for s in range(n_sensors):
        row = s + 1
        signal = x_np[index, :, s]
        recon = mu_np[index, :, s]
        # Input and reconstruction; each label refers to its own sensor.
        for sig, name, colour in (
            (signal, f'x_s{row}', 'rgb(0,0,100)'),
            (recon, f'p(x_s{row}|z)', 'rgba(192,58,58)'),
        ):
            fig.add_trace(
                go.Scatter(x=t, y=sig, name=name, line=dict(color=colour),
                           mode="lines", opacity=.5),
                row=row, col=1,
            )
        upper = list(recon + n_std * std[:, s])
        lower = list(recon - n_std * std[:, s])
        # Closed polygon: upper edge left-to-right, then lower edge right-to-left.
        fig.add_trace(
            go.Scatter(
                x=t + t[::-1],
                y=upper + lower[::-1],
                fill='toself',
                fillcolor='rgba(192,58,58,0.1)',
                line=dict(color='rgba(255,255,255,0)'),
                hoverinfo="skip",
                showlegend=False,
            ),
            row=row, col=1,
        )
    fig.show()


interact(recon_plot, index=range(batch_size))

## What about the latent space? Encoder outputs vs. hidden states

In [None]:
# Same (batch, latent) shape tuple the NumPy conversion would report.
tuple(mu_z.size())

In [None]:
# Take the column count from the encoder output itself rather than
# const.HPARAMS['latent_dim'], so the frame matches whichever checkpoint was loaded.
mu_z_np = mu_z.detach().numpy()
df_latent_mu = pd.DataFrame(mu_z_np, columns=[f'mu_{i}' for i in range(mu_z_np.shape[1])])
df_latent_mu.head()

In [None]:
# Take the column count from the encoder output itself rather than
# const.HPARAMS['latent_dim'], so the frame matches whichever checkpoint was loaded.
std_z_np = std_z.detach().numpy()
df_latent_std = pd.DataFrame(std_z_np, columns=[f'std_{i}' for i in range(std_z_np.shape[1])])
df_latent_std.head()

In [None]:
df_hidden_states_all = pd.read_csv(const.HIDDEN_STATE_PATH)
# Keep only the rows of the current batch, in batch order, so they line up
# row-for-row with df_latent_mu in the scatter grid below.
# NOTE(review): assumes the dataset indices in `idxs` are labels of the CSV's
# default RangeIndex. .loc raises KeyError on a missing index instead of
# silently dropping it.
batch_idxs = idxs.detach().numpy()
df_hidden_states = df_hidden_states_all.loc[batch_idxs].reset_index(drop=True)
df_hidden_states.head()

In [None]:
hidden_cols = list(df_hidden_states.columns)
latent_cols = list(df_latent_mu.columns)

# Grid: one row per hidden state (x-axis), one column per latent mean (y-axis).
# Sized from the data instead of a hard-coded 4x5 so every panel gets its titles.
fig = make_subplots(rows=len(hidden_cols), cols=len(latent_cols))
for row, hs in enumerate(hidden_cols, start=1):
    for col, z_col in enumerate(latent_cols, start=1):
        fig.add_trace(
            go.Scatter(x=df_hidden_states[hs], y=df_latent_mu[z_col],
                       mode='markers', name=f'activation {z_col} over {hs}',
                       marker_color='#1f77b4'),
            row=row, col=col,
        )
        fig.update_xaxes(title_text=hs, row=row, col=col)
        fig.update_yaxes(title_text=z_col, row=row, col=col)

fig.update_layout(title_text="Latent neuron activations vs. hidden states", showlegend=False)
fig.show()

In [None]:
# Left panel: per-dimension histograms of mu_z; right panel: of std_z.
fig = make_subplots(rows=1, cols=2)
for col_name in df_latent_mu.columns:
    fig.add_trace(go.Histogram(x=df_latent_mu[col_name], name=col_name), row=1, col=1)
for col_name in df_latent_std.columns:
    fig.add_trace(go.Histogram(x=df_latent_std[col_name], name=col_name), row=1, col=2)

# Overlay the histograms within each panel and make them translucent so all remain visible.
fig.update_layout(
    barmode='overlay',
    title_text="Distribution of distribution parameters for z (Gaussian mu and std)",
    showlegend=True,
)
fig.update_traces(opacity=0.4)
fig.update_xaxes(title_text='mu', row=1, col=1)
fig.update_xaxes(title_text='std', row=1, col=2)
# The grid is 1 row x 2 columns, so label the y-axis of both *columns*
# (iterating rows 1-2 left the std panel unlabeled).
for col in (1, 2):
    fig.update_yaxes(title_text='frequency', row=1, col=col)
fig.show()


In [None]:
def plot_generated(z1, z2, z3, z4, z5):
    """Decode the latent vector (z1, ..., z5) with ``model.decoder`` and plot every output channel.

    Reads the notebook global ``model``. The y-axis range is fixed to (-2, 2).

    Args:
        z1..z5: latent coordinates (floats, typically driven by sliders).
    """
    z = torch.tensor(np.array([z1, z2, z3, z4, z5]).astype(np.float32))
    # Module call (not .forward) so any registered hooks run; no_grad because
    # the output is only plotted.
    with torch.no_grad():
        out = model.decoder(z).numpy()
    # NOTE(review): out is indexed as (batch, timestep, channel), following the
    # original plotting code; confirm against the decoder.
    t = list(range(out.shape[1]))
    fig = make_subplots(rows=1, cols=1, shared_xaxes=True)
    for ch in range(out.shape[-1]):
        fig.add_trace(
            go.Scatter(x=t, y=out[0, :, ch], name=f'x_s{ch + 1}',
                       mode="lines", opacity=.5),
            row=1, col=1,
        )
    fig.update_yaxes(range=(-2, 2))
    fig.show()


def _latent_slider(label):
    """Build one latent-coordinate slider: range [-3, 3], step 0.1, updates on release."""
    return widgets.FloatSlider(
        value=0, min=-3, max=3, step=0.1,
        description=label,
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True, readout_format='.1f',
    )


ranges = [_latent_slider(f'mu_z_{k}') for k in range(5)]

In [None]:
# Bind slider k (0-based) to keyword z{k+1} of plot_generated.
interact(plot_generated, **{f'z{k}': slider for k, slider in enumerate(ranges, start=1)})

In [None]:
import seaborn as sns

# Histogram of 1000 draws from N(0, 0.01^2) using seaborn's default theme.
sns.set_theme()
rng = np.random.default_rng(0)  # local, reproducible generator instead of global seeding
x = rng.normal(0, .01, 1000)
# displot is figure-level and returns a FacetGrid, not an Axes.
grid = sns.displot(x)

In [None]:
class Gaussian:
    """Univariate normal distribution N(mu, sigma**2) with NumPy-vectorised density helpers.

    Args:
        mu: mean.
        sigma: standard deviation; must be strictly positive.

    Raises:
        ValueError: if any element of ``sigma`` is <= 0.
    """

    def __init__(self, mu, sigma):
        # A non-positive sigma would otherwise surface later as inf/nan densities.
        if np.any(np.asarray(sigma) <= 0):
            raise ValueError(f"sigma must be positive, got {sigma!r}")
        self.mu = mu
        self.sigma = sigma

    def get_prob(self, x):
        """Probability density at ``x`` (scalar or array, evaluated elementwise)."""
        return 1/(self.sigma * np.sqrt(2 * np.pi)) * np.exp( - (x - self.mu)**2 / (2 * self.sigma**2) )

    def log_prob(self, x):
        """Log-density at ``x``, computed directly so it stays finite where get_prob underflows to 0."""
        return -np.log(self.sigma * np.sqrt(2 * np.pi)) - (x - self.mu)**2 / (2 * self.sigma**2)

In [None]:
g_dist = Gaussian(mu=0, sigma=1)  # standard normal

In [None]:
g_dist.get_prob(x=0)  # density at the mean

In [None]:
# 1000 evenly spaced evaluation points on [-10, 10], both endpoints included.
x = np.linspace(-10, 10, 1000)

In [None]:
# Density of N(2, 0.1^2) over the evaluation grid `x`.
fig_pdf, ax_pdf = plt.subplots()
ax_pdf.plot(x, Gaussian(2, .1).get_prob(x))
ax_pdf.set(title='Gaussian density, mu=2, sigma=0.1', xlabel='x', ylabel='p(x)');

In [None]:
# Log-density of N(2, 1) over the evaluation grid `x` (a parabola in x).
fig_logpdf, ax_logpdf = plt.subplots()
ax_logpdf.plot(x, np.log(Gaussian(2, 1).get_prob(x)))
ax_logpdf.set(title='Gaussian log-density, mu=2, sigma=1', xlabel='x', ylabel='log p(x)');

In [None]:
# This cell previously held a bare `s`. No variable `s` is defined anywhere in
# the notebook, so it raised NameError and aborted Restart & Run All.
pass

In [None]:
def func(a, b=2):
    """Echo ``a`` and ``b`` to stdout on one line, space-separated (scratch helper)."""
    pair = (a, b)
    print(*pair)

In [None]:
func(a=1, b=2)