In [1]:
import os
os.chdir('../../..')
%load_ext autoreload
%autoreload 2

In [2]:
from examples.three_tank.dataset import ThreeTankDataSet
import examples.three_tank.constants as const
from torch.utils.data import DataLoader
import yaml
import pandas as pd
import torch
import numpy as np
import matplotlib.pyplot as plt
from plotly.subplots import make_subplots
import plotly.graph_objects as go
from ipywidgets import interact
import ipywidgets as widgets
from sklearn import preprocessing
from examples.three_tank.data_module import ThreeTankDataModule

In [3]:
dataset = ThreeTankDataSet()

In [4]:
batch_size = 500
dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=24)

In [5]:
from seq2seq_vae.vae import VAE
import yaml
# MODEL_VERSION = 'freq_and_phase'
MODEL_VERSION = 'version_6'

hparams_path = f'./lightning_logs/{MODEL_VERSION}/hparams.yaml'
with open(hparams_path, 'r') as stream:
        hparam_dct = yaml.safe_load(stream)
ckpt_file_name = os.listdir(f'./lightning_logs/{MODEL_VERSION}/checkpoints/')[-1]
ckpt_file_path = f'./lightning_logs/{MODEL_VERSION}/checkpoints/{ckpt_file_name}'
model = VAE.load_from_checkpoint(ckpt_file_path)
model

Global seed set to 16548


VAE(
  (encoder): RNNEncoder(
    (rnn): GRU(3, 100, num_layers=5, batch_first=True)
  )
  (decoder): RNNDecoder(
    (rnn1): GRU(5, 100, num_layers=5, batch_first=True)
    (fc): Linear(in_features=100, out_features=3, bias=True)
  )
  (fc_mu): Linear(in_features=100, out_features=5, bias=True)
  (fc_var): Linear(in_features=100, out_features=5, bias=True)
)

In [6]:
batches = iter(dataloader)
x_batch, labels_batch, idxs_batch = batches.next()

In [7]:
mu_z, std_z, z_sample, mu_x, std_x = model.eval()(x_batch)

In [8]:
# from examples.three_tank.data_module import ThreeTankDataModule
# hparams = dict(
#     validation_split=.1,
#     batch_size=100,
#     dl_num_workers=0
# )
# ttdm = ThreeTankDataModule(**hparams)
# ttdm.setup()
# dl = ttdm.train_dataloader()
# tdl = dl.val_dataloader()

In [9]:
def recon_plot(index):
    x=list(range(model.hparams.seq_len))
    x_sensor_1 = x_batch.detach().numpy()[index, :, 0]
    x_sensor_2 = x_batch.detach().numpy()[index, :, 1]
    x_sensor_3 = x_batch.detach().numpy()[index, :, 2]
    mu_rec_sensor_1 = mu_x.detach().numpy()[index, :, 0]
    mu_rec_sensor_2 = mu_x.detach().numpy()[index, :, 1]
    mu_rec_sensor_3 = mu_x.detach().numpy()[index, :, 2]

    log_scale = model.log_scale_diag.detach().numpy()
    std = np.exp(log_scale)
    std = std.reshape(-1, 3)
    sensor_1_upper = list(mu_rec_sensor_1 + 2*std[:, 0])
    sensor_2_upper = list(mu_rec_sensor_2 + 2*std[:, 1])
    sensor_3_upper = list(mu_rec_sensor_3 + 2*std[:, 2])
    sensor_1_lower = list(mu_rec_sensor_1 - 2*std[:, 0])
    sensor_2_lower = list(mu_rec_sensor_2 - 2*std[:, 1])
    sensor_3_lower = list(mu_rec_sensor_3 - 2*std[:, 2])

    


    fig = make_subplots(rows=1, cols=1, shared_xaxes=True)
    # signal 1
    for sig, name, colour in zip([x_sensor_1, mu_rec_sensor_1], ['x_s1', 'p(x_s2|z)'], ['rgb(0,0,100)', 'rgba(192,58,58)']):
        fig.add_trace(
            go.Scatter(x=x,
                       y=sig, name=name,
                        line=dict(color=colour),

                      mode="lines", opacity=.5),
            row=1, col=1,
        )

    fig.add_trace(go.Scatter(
        x=x+x[::-1], # x, then x reversed
        y=sensor_1_upper + sensor_1_lower[::-1], # upper, then lower reversed
        fill='toself',
        fillcolor='rgba(192,58,58,0.1)',
        line=dict(color='rgba(255,255,255,0)'),
        hoverinfo="skip",
        showlegend=False
        ), row=1, col=1
    )
    
    # signal 2
    for sig, name, colour in zip([x_sensor_2, mu_rec_sensor_2],
                                 ['x_s2', 'p(x_s2|z)'], ['rgb(0,0,100)', 'rgba(192,58,58)']):
        fig.add_trace(
            go.Scatter(x=x,
                       y=sig, name=name,
                        line=dict(color=colour),

                      mode="lines", opacity=.5),
            row=1, col=1,
        )

    fig.add_trace(go.Scatter(
        x=x+x[::-1], # x, then x reversed
        y=sensor_2_upper + sensor_2_lower[::-1], # upper, then lower reversed
        fill='toself',
        fillcolor='rgba(192,58,58,0.1)',
        line=dict(color='rgba(255,255,255,0)'),
        hoverinfo="skip",
        showlegend=False
        ),
        row=1, col=1,
    )
    # signal 3
    for sig, name, colour in zip([x_sensor_3, mu_rec_sensor_3],
                                 ['x_s3', 'p(x_s3|z)'], ['rgb(0,0,100)', 'rgba(192,58,58)']):
        fig.add_trace(
            go.Scatter(x=x,
                       y=sig, name=name,
                        line=dict(color=colour),

                      mode="lines", opacity=.5),
            row=1, col=1,
        )

    fig.add_trace(go.Scatter(
        x=x+x[::-1], # x, then x reversed
        y=sensor_3_upper + sensor_3_lower[::-1], # upper, then lower reversed
        fill='toself',
        fillcolor='rgba(192,58,58,0.1)',
        line=dict(color='rgba(255,255,255,0)'),
        hoverinfo="skip",
        showlegend=False
        ),
        row=1, col=1,
    )
#     fig.update_yaxes(title='stress', row=1, col=3)
#     fig.update_yaxes(title='strain', row=1, col=2)
#     fig.update_yaxes(title='time', row=1, col=1)
#     fig.update_xaxes(title='data point')

#     fig.update_yaxes(range=(-1,1))

    fig.show()
interact(recon_plot, index=range(x_batch.shape[0]))

interactive(children=(Dropdown(description='index', options=(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14,…

<function __main__.recon_plot(index)>

In [10]:
df_latent_mu = pd.DataFrame(mu_z.detach().numpy(), columns=[f'mu_{i}' for i in range(5)])
df_latent_std = pd.DataFrame(std_z.detach().numpy(), columns=[f'std_{i}' for i in range(5)])

fig = make_subplots(rows=1, cols=2)
for col in df_latent_mu.columns:
    fig.add_trace(go.Histogram(x=df_latent_mu[col], name=col), row=1, col=1)
    
for col in df_latent_std.columns:
    fig.add_trace(go.Histogram(x=df_latent_std[col], name=col), row=1, col=2)

# Overlay both histograms
fig.update_layout(barmode='overlay')
fig.update_layout(title_text=f"Distribution of distribution parameters for z (Gaussian mu and sigma)   N={hparam_dct['batch_size']}", showlegend=True)
fig.update_xaxes(title_text='mu', row=1, col=1)
fig.update_xaxes(title_text='std', row=1, col=2)

for row in (1, 2):
    fig.update_yaxes(title_text='frequency', row=row, col=1)


# Reduce opacity to see both histograms
fig.update_traces(opacity=0.4)
fig.show()

In [11]:
df_latent_mu = pd.DataFrame(mu_z.detach().numpy(),
                            columns=[f'mu_{i}' for i in range(hparam_dct['latent_dim'])])
df_latent_mu.head()

Unnamed: 0,mu_0,mu_1,mu_2,mu_3,mu_4
0,-0.005887,1.542855,1.442096,0.465989,-0.023207
1,0.00373,-0.812797,-2.41439,-1.020747,-0.005163
2,0.010637,-0.501245,-0.733772,1.238727,-0.001794
3,-0.000138,-0.76312,0.580792,2.013077,-0.001652
4,0.001115,-0.604523,1.580155,1.79157,-0.009759


In [12]:
df_latent_std = pd.DataFrame(std_z.detach().numpy(), columns=[f'std_{i}' for i in range(hparam_dct['latent_dim'])])
df_latent_std.head()

Unnamed: 0,std_0,std_1,std_2,std_3,std_4
0,0.111222,0.001764,0.000904,0.001235,0.128115
1,0.097325,0.001274,0.00124,0.001694,0.138483
2,0.100638,0.001238,0.000918,0.001181,0.126752
3,0.122135,0.001559,0.000965,0.00135,0.134991
4,0.130426,0.00168,0.000946,0.001281,0.136559


In [13]:
labels_batch

tensor([[6.8253, 4.1578, 0.2358, 0.8419],
        [2.1634, 8.3104, 0.9370, 0.4301],
        [0.9681, 8.1195, 0.3921, 0.3720],
        ...,
        [9.9934, 2.2509, 0.1491, 0.7672],
        [8.3213, 3.5134, 0.5370, 0.2172],
        [7.6248, 5.4792, 0.3392, 0.4948]])

In [14]:
idxs_batch

tensor([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,  27,
         28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,  40,  41,
         42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,  53,  54,  55,
         56,  57,  58,  59,  60,  61,  62,  63,  64,  65,  66,  67,  68,  69,
         70,  71,  72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83,
         84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95,  96,  97,
         98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
        112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139,
        140, 141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153,
        154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167,
        168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 1

In [15]:
# df_real_params = pd.DataFrame({k:v for k,v in zip(const.LABEL_COLS+['model'], labels_batch)})
df_real_params = pd.DataFrame(labels_batch.numpy(), columns=const.LABEL_COLS)
df_real_params['sample_idx'] = idxs_batch
df_real_params.head()

Unnamed: 0,q1,q3,kv12,kv23,sample_idx
0,6.825291,4.157763,0.235768,0.841853,0
1,2.163383,8.310431,0.937021,0.430145,1
2,0.968112,8.119468,0.392128,0.371987,2
3,8.669929,6.615997,0.243892,0.299338,3
4,0.863125,4.602079,0.248618,0.341133,4


In [16]:
fig = make_subplots(rows=4, cols=5)

for i, hs in enumerate(const.LABEL_COLS):
    for j, hs_pred in enumerate(df_latent_mu.columns):
        fig.add_trace(go.Scatter(y=df_latent_mu[hs_pred], x=df_real_params[hs], 
                            mode='markers', name=f'activation {hs_pred} over box_x',
                                marker_color='#1f77b4'),
                     row=i+1, col=j+1)
        fig.update_yaxes(range=[-5, 5])

# Update xaxis properties
for i in range(hparam_dct['latent_dim']):
    fig.update_xaxes(title_text=df_real_params.columns[0], row=1, col=i+1)
    fig.update_xaxes(title_text=df_real_params.columns[1], row=2, col=i+1)
    fig.update_xaxes(title_text=df_real_params.columns[2], row=3, col=i+1)
    fig.update_xaxes(title_text=df_real_params.columns[3], row=4, col=i+1)

# Update xaxis properties
for j in range(len(df_real_params)):
    fig.update_yaxes(title_text=df_latent_mu.columns[0], row=j+1, col=1)
    fig.update_yaxes(title_text=df_latent_mu.columns[1], row=j+1, col=2)
    fig.update_yaxes(title_text=df_latent_mu.columns[2], row=j+1, col=3)
    fig.update_yaxes(title_text=df_latent_mu.columns[3], row=j+1, col=4)
    fig.update_yaxes(title_text=df_latent_mu.columns[4], row=j+1, col=5)


fig.update_layout(height=1000, width=1500, title_text="Latent neuron activations vs. hidden states", showlegend=False)
fig.show()


In [17]:
df_latent_mu.columns

Index(['mu_0', 'mu_1', 'mu_2', 'mu_3', 'mu_4'], dtype='object')

In [27]:
fig = make_subplots(rows=4, cols=5)

for i, hs in enumerate(const.LABEL_COLS):
    for j, hs_pred in enumerate(df_latent_mu.columns):
        fig.add_trace(go.Scatter(y=df_latent_mu[hs_pred], x=df_real_params[hs], 
                            mode='markers', name=f'activation {hs_pred} over box_x',
                                marker_color='#1f77b4',
                                 opacity=1
                                ),
                     row=i+1, col=j+1)
        fig.update_yaxes(range=[-5, 5])

# # Update xaxis properties
# for i in range(hparam_dct['latent_dim']):
#     fig.update_xaxes(title_text=df_real_params.columns[0], row=1, col=i+1)
#     fig.update_xaxes(title_text=df_real_params.columns[1], row=2, col=i+1)
#     fig.update_xaxes(title_text=df_real_params.columns[2], row=3, col=i+1)
#     fig.update_xaxes(title_text=df_real_params.columns[3], row=4, col=i+1)

# # Update xaxis properties
# for j in range(len(df_real_params)):
#     fig.update_yaxes(title_text=df_latent_mu.columns[0], row=j+1, col=1)
#     fig.update_yaxes(title_text=df_latent_mu.columns[1], row=j+1, col=2)
#     fig.update_yaxes(title_text=df_latent_mu.columns[2], row=j+1, col=3)
#     fig.update_yaxes(title_text=df_latent_mu.columns[3], row=j+1, col=4)
#     fig.update_yaxes(title_text=df_latent_mu.columns[4], row=j+1, col=5)


fig.update_layout(title_text=r"Estimantd mean parameters in the latent space over true concepts", showlegend=False,
                width =500, height=400, 
                  font_family="Serif", font_size=11, 
                  margin_l=5, margin_t=50, margin_b=5, margin_r=5
)
fig.show()

In [28]:
import plotly.io as pio
#save a figure of 300dpi, width 1.5 inches, height 0.75inches
pio.write_image(fig, "vae-results.pdf", width=1.5*300, height=1.1*300)

In [20]:
pip install -U kaleido

Note: you may need to restart the kernel to use updated packages.


In [21]:
def plot_generated(z1, z2, z3, z4, z5):
    out = model.decoder.forward(torch.tensor(np.array([z1, z2, z3, z4, z5]).astype(np.float32)))
    print(out.shape)
    fig = make_subplots(rows=1, cols=1, shared_xaxes=True)

    fig.add_trace(
        go.Scatter(x=out.detach().numpy()[0, :, 1],
                   y=out.detach().numpy()[0, :, 2],
                  mode="markers+lines", opacity=.9, name='mu_signal_1'),
        row=1, col=1)

#     fig.update_yaxes(range=(-0.01,.5))
    fig.update_xaxes(title_text='time')
    fig.update_layout(title_text=f"Mean parameters for samples generated from z1={z1},  z2={z2},  z3={z3},  z4={z4},  z5={z5}")
    fig.show()


ranges = [widgets.FloatSlider(
    value=0,
    min=-3,
    max=3,
    step=0.1,
    description=f'z_{i}',
    disabled=False,
    continuous_update=False,
    orientation='horizontal',
    readout=True,
    readout_format='.1f',
) for i in range(5)]

In [22]:
interact(plot_generated, z1=ranges[0], z2=ranges[1], z3=ranges[2], z4=ranges[3], z5=ranges[4])

interactive(children=(FloatSlider(value=0.0, continuous_update=False, description='z_0', max=3.0, min=-3.0, re…

<function __main__.plot_generated(z1, z2, z3, z4, z5)>