In [None]:
import os
os.chdir('../../..')
%load_ext autoreload
%autoreload 2

In [None]:
from examples.three_tank.dataset import ThreeTankDataSet
import examples.three_tank.constants as const
from torch.utils.data import DataLoader
import yaml
import pandas as pd
import torch
import numpy as np
import matplotlib.pyplot as plt
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from ipywidgets import interact
import ipywidgets as widgets
from sklearn import preprocessing
from examples.three_tank.data_module import ThreeTankDataModule

In [None]:
# Build the three-tank dataset with its default constructor arguments.
dataset = ThreeTankDataSet()

In [None]:
# Number of samples per batch; only the first batch is used by the plots below.
batch_size = 500
# Cap the worker count at the number of available CPUs: a fixed 24 workers
# oversubscribes smaller machines. The batches themselves do not depend on
# the number of workers.
num_workers = min(24, os.cpu_count() or 1)
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=num_workers)

In [None]:
from seq2seq_vae.vae import VAE
import yaml
# MODEL_VERSION = 'freq_and_phase'
MODEL_VERSION = 'version_7'

# Hyper-parameters that were logged alongside this run's checkpoints.
hparams_path = f'./lightning_logs/{MODEL_VERSION}/hparams.yaml'
with open(hparams_path, 'r') as stream:
    hparam_dct = yaml.safe_load(stream)

# os.listdir returns entries in arbitrary order, so sort the names to make the
# checkpoint choice deterministic (the lexicographically last file is used).
ckpt_dir = f'./lightning_logs/{MODEL_VERSION}/checkpoints/'
ckpt_file_names = sorted(os.listdir(ckpt_dir))
if not ckpt_file_names:
    raise FileNotFoundError(f'No checkpoint files found in {ckpt_dir}')
ckpt_file_name = ckpt_file_names[-1]
ckpt_file_path = f'./lightning_logs/{MODEL_VERSION}/checkpoints/{ckpt_file_name}'

# Restore the trained VAE and display its module summary as the cell output.
model = VAE.load_from_checkpoint(ckpt_file_path)
model

In [None]:
# Pull the first batch from the loader. The built-in next() is used because
# Python 3 iterators are only required to implement __next__, not .next().
batches = iter(dataloader)
x_batch, labels_batch, idxs_batch = next(batches)

In [None]:
# Switch to evaluation mode and run the forward pass without recording an
# autograd graph; the outputs are only inspected and plotted below.
model.eval()
with torch.no_grad():
    mu_z, std_z, z_sample, mu_x, std_x = model(x_batch)

In [None]:
# from examples.three_tank.data_module import ThreeTankDataModule
# hparams = dict(
#     validation_split=.1,
#     batch_size=100,
#     dl_num_workers=0
# )
# ttdm = ThreeTankDataModule(**hparams)
# ttdm.setup()
# dl = ttdm.train_dataloader()
# tdl = dl.val_dataloader()

In [None]:
def recon_plot(index):
    """Plot one batch sample against its VAE reconstruction for all three sensors.

    For every sensor the figure contains the input signal, the reconstructed
    mean and a shaded band of +/- 2 standard deviations around that mean.
    The standard deviation is ``exp(model.log_scale_diag)`` and therefore does
    not depend on the selected sample.

    Reads the notebook globals ``model``, ``x_batch`` and ``mu_x``.

    Args:
        index: Position of the sample inside the current batch.
    """
    n_sensors = 3
    x = list(range(model.hparams.seq_len))

    # Convert the tensors once instead of once per sensor.
    x_sample = x_batch.detach().numpy()[index]
    mu_sample = mu_x.detach().numpy()[index]

    # NOTE(review): assumes the flattened log_scale_diag is ordered as
    # (time step, sensor) -- confirm against the model definition.
    std = np.exp(model.log_scale_diag.detach().numpy()).reshape(-1, n_sensors)

    fig = make_subplots(rows=1, cols=1, shared_xaxes=True)
    for sensor in range(n_sensors):
        signal = x_sample[:, sensor]
        mu_rec = mu_sample[:, sensor]
        upper = list(mu_rec + 2*std[:, sensor])
        lower = list(mu_rec - 2*std[:, sensor])
        label = f's{sensor + 1}'

        # Input signal and reconstructed mean of this sensor.
        for sig, name, colour in zip([signal, mu_rec],
                                     [f'x_{label}', f'p(x_{label}|z)'],
                                     ['rgb(0,0,100)', 'rgba(192,58,58)']):
            fig.add_trace(
                go.Scatter(x=x, y=sig, name=name,
                           line=dict(color=colour),
                           mode="lines", opacity=.5),
                row=1, col=1,
            )

        # Uncertainty band drawn as one closed polygon:
        # upper bound left-to-right, then lower bound right-to-left.
        fig.add_trace(
            go.Scatter(
                x=x + x[::-1],
                y=upper + lower[::-1],
                fill='toself',
                fillcolor='rgba(192,58,58,0.1)',
                line=dict(color='rgba(255,255,255,0)'),
                hoverinfo="skip",
                showlegend=False,
            ),
            row=1, col=1,
        )

    fig.show()


# Interactive selector over every sample index of the current batch.
interact(recon_plot, index=range(x_batch.shape[0]))

In [None]:
# Parameters of the latent distribution for the current batch, one column per
# latent dimension (dimension count taken from the logged hyper-parameters,
# as in the cells below).
df_latent_mu = pd.DataFrame(mu_z.detach().numpy(),
                            columns=[f'mu_{i}' for i in range(hparam_dct['latent_dim'])])
df_latent_std = pd.DataFrame(std_z.detach().numpy(),
                             columns=[f'std_{i}' for i in range(hparam_dct['latent_dim'])])

# Left panel: histograms of the means; right panel: of the standard deviations.
fig = make_subplots(rows=1, cols=2)
for panel_col, df_params in enumerate((df_latent_mu, df_latent_std), start=1):
    for param_name in df_params.columns:
        fig.add_trace(go.Histogram(x=df_params[param_name], name=param_name),
                      row=1, col=panel_col)

# Overlay the histograms within a panel instead of grouping them side by side.
fig.update_layout(barmode='overlay')
# N is the number of samples that are actually plotted.
fig.update_layout(title_text=f"Distribution of distribution parameters for z (Gaussian mu and sigma)   N={len(df_latent_mu)}", showlegend=True)
fig.update_xaxes(title_text='mu', row=1, col=1)
fig.update_xaxes(title_text='std', row=1, col=2)

# The grid has one row and two columns, so the y label is set per column.
for panel_col in (1, 2):
    fig.update_yaxes(title_text='frequency', row=1, col=panel_col)

# Reduce opacity so that overlapping histograms stay visible.
fig.update_traces(opacity=0.4)
fig.show()

In [None]:
# Means (mu) of the latent distribution for the current batch,
# one column per latent dimension; show the first rows.
df_latent_mu = pd.DataFrame(
    data=mu_z.detach().numpy(),
    columns=[f'mu_{i}' for i in range(hparam_dct['latent_dim'])],
)
df_latent_mu.head()

In [None]:
# Standard deviations of the latent distribution for the current batch,
# one column per latent dimension; show the first rows.
df_latent_std = pd.DataFrame(
    data=std_z.detach().numpy(),
    columns=[f'std_{i}' for i in range(hparam_dct['latent_dim'])],
)
df_latent_std.head()

In [None]:
# Inspect the labels of the current batch (second element yielded by the dataloader).
labels_batch

In [None]:
# Inspect the sample indices of the current batch (third element yielded by the dataloader).
idxs_batch

In [None]:
# df_real_params = pd.DataFrame({k:v for k,v in zip(const.LABEL_COLS+['model'], labels_batch)})
# Label values of the current batch, one column per entry of const.LABEL_COLS.
df_real_params = pd.DataFrame(labels_batch.numpy(), columns=const.LABEL_COLS)
# Convert the indices explicitly to a NumPy array (as is done for the labels)
# instead of relying on pandas' implicit handling of the batch object.
df_real_params['sample_idx'] = np.asarray(idxs_batch)
df_real_params.head()

In [None]:
# Grid of scatter plots: one row per label column, one column per latent mean.
# The grid size follows the data instead of being hard-coded.
n_label_rows = len(const.LABEL_COLS)
n_latent_cols = len(df_latent_mu.columns)
fig = make_subplots(rows=n_label_rows, cols=n_latent_cols)

for i, hs in enumerate(const.LABEL_COLS):
    for j, hs_pred in enumerate(df_latent_mu.columns):
        fig.add_trace(go.Scatter(y=df_latent_mu[hs_pred], x=df_real_params[hs],
                                 mode='markers', name=f'activation {hs_pred} over {hs}',
                                 marker_color='#1f77b4'),
                      row=i+1, col=j+1)
        # Label each subplot with the quantities it actually shows:
        # the label on the x axis, the latent mean on the y axis.
        fig.update_xaxes(title_text=hs, row=i+1, col=j+1)
        fig.update_yaxes(title_text=hs_pred, row=i+1, col=j+1)

# Same y range for every subplot; one call updates all y axes.
fig.update_yaxes(range=[-5, 5])

fig.update_layout(height=1000, width=1500, title_text="Latent neuron activations vs. hidden states", showlegend=False)
fig.show()


In [None]:
# Inspect the column names of the latent-mean frame.
df_latent_mu.columns

In [None]:
# Compact version of the label-vs-latent-mean scatter grid, sized for export.
# One row per label column, one column per latent mean; axis titles are left
# out in this version.
n_label_rows = len(const.LABEL_COLS)
n_latent_cols = len(df_latent_mu.columns)
fig = make_subplots(rows=n_label_rows, cols=n_latent_cols)

for i, hs in enumerate(const.LABEL_COLS):
    for j, hs_pred in enumerate(df_latent_mu.columns):
        fig.add_trace(go.Scatter(y=df_latent_mu[hs_pred], x=df_real_params[hs],
                                 mode='markers', name=f'activation {hs_pred} over {hs}',
                                 marker_color='#1f77b4',
                                 opacity=1,
                                 marker=dict(size=3),
                                 ),
                      row=i+1, col=j+1)

# Same y range for every subplot; one call updates all y axes.
fig.update_yaxes(range=[-5, 5])

fig.update_layout(title_text="Estimated mean parameters in the latent space over true concepts", showlegend=False,
                  width=500, height=300,
                  font_family="Serif", font_size=11,
                  margin_l=5, margin_t=50, margin_b=5, margin_r=5,
)
fig.show()

In [None]:
import plotly.io as pio
# Write the current `fig` to vae-results.pdf with width=500 and height=300,
# the same values passed to update_layout in the compact scatter-grid cell.
# NOTE(review): static export presumably needs plotly's image-export backend
# (e.g. kaleido) to be installed -- confirm for the target environment.
pio.write_image(fig, "vae-results.pdf", width=500, height=300)

In [None]:
def plot_generated(z1, z2, z3, z4, z5):
    """Decode a hand-picked latent vector and plot the generated signals.

    The five values are stacked into one float32 vector, passed through
    ``model.decoder`` and the output is reshaped to
    ``(const.NUMBER_TIMESTEPS, 3)`` with one column per entry of
    ``const.STATE_COL_NAMES``.

    Reads the notebook global ``model``.

    Args:
        z1, z2, z3, z4, z5: Values of the five latent coordinates.
    """
    z = torch.tensor([z1, z2, z3, z4, z5], dtype=torch.float32)
    # Call the module itself rather than .forward() so that registered hooks
    # run, and skip autograd bookkeeping because the output is only plotted.
    with torch.no_grad():
        out = model.decoder(z)
    df_plot = pd.DataFrame(out.cpu().detach().numpy().reshape(const.NUMBER_TIMESTEPS, 3),
                           columns=const.STATE_COL_NAMES)
    fig = make_subplots(rows=1, cols=1, shared_xaxes=True)

    for col, name in zip(const.STATE_COL_NAMES, ['h1(t)', 'h2(t)', 'h3(t)']):
        fig.add_trace(go.Scatter(x=df_plot.index, y=df_plot[col], name=name,
                                 mode="lines", opacity=1),
                      row=1, col=1)

    fig.update_xaxes(title_text='time')
    fig.update_layout(title_text="Generated samples using the decoder network (mean values only)", showlegend=True)
    fig.show()


# One slider per latent coordinate; the plot is redrawn when a slider is
# released (continuous_update=False), not while it is being dragged.
ranges = [widgets.FloatSlider(
    value=0,
    min=-3,
    max=3,
    step=0.1,
    description=f'z_{i}',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
) for i in range(5)]
interact(plot_generated, z1=ranges[0], z2=ranges[1], z3=ranges[2], z4=ranges[3], z5=ranges[4])