In [1]:
import os
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.utils import to_categorical
from keras.layers.merge import concatenate as concat
import numpy as np
from model.vae.branched_classifier_vae import BranchedClassifierVAE
from model.vae.conditional_vae import ConditionalVAE
from util.experiment import Experiment, load_experiments
from util.plotting import plot_label_clusters
import matplotlib.pyplot as plt
import seaborn as sb
import pandas as pd
from collections import OrderedDict
import copy

# Interactive tools
from pathlib import Path
from PIL import Image
import pickle
import numpy as np

import ipywidgets as widgets
from tqdm import tqdm

In [2]:
def from_categorical(y):
    """Invert a categorical (one-hot or per-class score) encoding to class indices.

    Args:
        y: Array-like whose last axis enumerates the classes, e.g. the output
            of ``to_categorical`` or a matrix of class probabilities.

    Returns:
        Integer class indices: the position of the largest entry along the
        last axis of ``y``.
    """
    # Compare the entries as floats. Casting to int first (as was done before)
    # truncates probabilities such as 0.9 to 0, which makes every row look
    # identical and always yields class 0.
    y = np.asarray(y, dtype=float)
    return np.argmax(y, axis=-1)

def dict_pretty_print(dictionary):
    """Format a mapping as text with one ``key: value`` entry per line.

    Every entry, including the last one, is terminated by a newline; an empty
    mapping yields an empty string.
    """
    return "".join(f"{name}: {entry}\n" for name, entry in dictionary.items())

# Setup

In [3]:
# Number of component sliders created for the dashboard. It is the default
# `latent_dims` of generate_dim_sliders, the number of `slider-pc<N>` inputs
# registered on the run_and_plot callback, and the fallback maximum of the
# component-count slider when no model is selected.
MAX_DIMS = 10

In [4]:
# Dataset registry, keyed by dataset name. Each entry holds:
#   "loader":      object whose `.load_data()` is called to fetch the images and labels
#   "class_names": display names, indexed by integer class label
datasets = {
    "fashion_mnist": {
        "loader": keras.datasets.fashion_mnist,
        "class_names": ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
               'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
    },
    "mnist": {
        "loader": keras.datasets.mnist,
        "class_names": ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
    },
}


# Models offered in the dashboard, nested as dataset -> run -> model name.
# The run key is appended to "experiments/" to form the experiment base path and
# also selects the model class (a run containing 'conditional' is loaded as a
# ConditionalVAE, anything else as a BranchedClassifierVAE).
# The empty dicts are filled by the loading loop with the keys
# 'experiment', 'model', 'z', 'labels' and 'class_names'.
selectors = {
    'fashion_mnist': {
        'conditionalVAE': {
            '1630_help-mean-hard-world': {},
            '1632_speak-cultural-young-girl': {},
            '1635_remain-nice-head-area': {},
            '1638_add-love-dead-business': {},
        },
        'branchedClassifier': {
            '1333_live-similar-big-child': {},
            '1347_will-late-year-business': {},
            '1416_speak-leave-easy-money': {},
            '1402_die-continue-entire-room': {}
        }
    }, 
    'mnist': {
        'conditionalVAEMNIST': {
            '1137_meet-large-line-back': {},
            '1141_would-green-medical-morning': {},
            '1143_buy-left-place-head': {},
            '1146_stand-open-cold-month': {},
        },
        'branchedVAEMNIST': {
            '1219_love-early-better-service': {},
            '1222_stop-available-office-education': {},
            '1225_know-happy-minute-force': {},
            '1227_grow-hot-book-water': {},
        }
    },
}



In [5]:
# Number of validation samples that are embedded for every model.
n_valid_samples = 3000

for dataset in selectors.keys():
    class_names = datasets[dataset]['class_names']

    # The validation data depends only on the dataset, so load and preprocess it
    # once here instead of once per model.
    (_, _), (valid_images, valid_labels) = datasets[dataset]['loader'].load_data()
    valid_images = valid_images[:n_valid_samples].astype("float32") / 255.0
    valid_images = tf.expand_dims(valid_images, axis=-1)
    valid_labels = valid_labels[:n_valid_samples]

    # One-hot labels for the conditional models. num_classes is passed explicitly
    # so the width does not depend on which classes occur in the subset.
    valid_labels_onehot = keras.utils.to_categorical(valid_labels, num_classes=len(class_names))
    # Human-readable class name of every validation sample.
    label_names = [class_names[i] for i in valid_labels]

    for run in selectors[dataset].keys():
        for model_name in selectors[dataset][run].keys():
            # Load experiment and model
            experiment = Experiment(name=model_name, base_path="experiments/"+run)
            base_model = experiment.load_model()
            params = {
                'input_dim': (28, 28, 1),
                'z_dim': experiment.params['latent_dim'],
                'label_dim': len(class_names),
                'beta': experiment.params['beta'],
            }

            # Create model and latent vector space from validation set
            if 'conditional' in run:
                model = ConditionalVAE.from_saved_model(base_model, params)
                model.compile()

                # Named `encoder_input` so the `concat` function imported at the
                # top of the notebook is not shadowed.
                encoder_input = model.concat_image_label([valid_images, valid_labels_onehot])
                z_mean, _ = model.encoder(encoder_input)
            else:
                model = BranchedClassifierVAE.from_saved_model(base_model, params)
                model.compile()

                z_mean, _, _ = model.encoder(valid_images)

            labels_df = pd.DataFrame(label_names, columns=['classname'])
            cols = ['z'+str(i) for i in range(0, z_mean.shape[1])]
            z_df = pd.DataFrame(z_mean, columns=cols)

            print(dataset, run, model_name)

            # Store everything the dashboard callbacks need for this model.
            selectors[dataset][run][model_name]['experiment'] = experiment
            selectors[dataset][run][model_name]['model'] = model
            selectors[dataset][run][model_name]['z'] = z_df
            selectors[dataset][run][model_name]['labels'] = labels_df
            selectors[dataset][run][model_name]['class_names'] = class_names



2022-10-05 10:27:33.297408: I tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


fashion_mnist conditionalVAE 1630_help-mean-hard-world
fashion_mnist conditionalVAE 1632_speak-cultural-young-girl
fashion_mnist conditionalVAE 1635_remain-nice-head-area
fashion_mnist conditionalVAE 1638_add-love-dead-business
fashion_mnist branchedClassifier 1333_live-similar-big-child
fashion_mnist branchedClassifier 1347_will-late-year-business
fashion_mnist branchedClassifier 1416_speak-leave-easy-money
fashion_mnist branchedClassifier 1402_die-continue-entire-room
mnist conditionalVAEMNIST 1137_meet-large-line-back
mnist conditionalVAEMNIST 1141_would-green-medical-morning
mnist conditionalVAEMNIST 1143_buy-left-place-head
mnist conditionalVAEMNIST 1146_stand-open-cold-month
mnist branchedVAEMNIST 1219_love-early-better-service
mnist branchedVAEMNIST 1222_stop-available-office-education
mnist branchedVAEMNIST 1225_know-happy-minute-force
mnist branchedVAEMNIST 1227_grow-hot-book-water


# Dashboard

## Dash

In [6]:
import plotly.express as px
import plotly.graph_objects as go
import base64
import dash
import dash_core_components as dcc
import dash_html_components as html
from PIL import Image
from io import BytesIO
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap import UMAP
from dash.dependencies import Input, Output, State


def NamedSlider(name, short, style, min, max, step, val):
    """Build a labelled slider: the text ``"<name>:"`` followed by a dcc.Slider.

    Args:
        name: Label text shown in front of the slider.
        short: Suffix of the slider's component id (``slider-<short>``).
        style: Style dict applied to the outer Div.
        min, max, step: Range and step size passed to the slider.
        val: Initial slider value.

    Returns:
        An html.Div containing the label and the wrapped slider.
    """
    slider = dcc.Slider(
        id=f"slider-{short}",
        min=min,
        max=max,
        step=step,
        value=val,
        marks=None,
        tooltip={"placement": "bottom", "always_visible": True},
    )
    slider_box = html.Div(style={"margin-left": "5px"}, children=[slider])

    return html.Div(style=style, children=[f"{name}:", slider_box])


def generate_dim_sliders(visible, values=None, classname=None, dist=None, latent_dims=MAX_DIMS):
    """Create ``latent_dims`` component sliders, of which the first ``visible`` are shown.

    Args:
        visible: Number of leading sliders to display; the rest are created
            with ``display: none`` and value 0.
        values: Optional sequence of initial values for the visible sliders.
        classname: Optional class whose statistics define the slider range.
        dist: Optional table indexed as ``dist['pc<N>']['mean'|'std'][classname]``.
        latent_dims: Total number of sliders to create.

    Returns:
        List of NamedSlider Divs named ``pc1`` .. ``pc<latent_dims>``.
    """
    class_range_known = dist is not None and classname is not None

    def make_slider(dim):
        label = f"pc{dim+1}"
        shown = dim <= visible-1

        # Default range; narrowed to mean +/- one std of the class when known.
        lower, upper = -3, 3
        if shown and class_range_known:
            stats = dist[label]
            lower = stats['mean'][classname]-stats['std'][classname]
            upper = stats['mean'][classname]+stats['std'][classname]

        if shown and values is not None:
            start = values[dim]
        else:
            start = 0

        return NamedSlider(
            name=label,
            short=label,
            style={"margin": "25px 5px 30px 0px"} if shown else {"display": "none"},
            min=lower,
            max=upper,
            step=0.01,
            val=start,
        )

    return [make_slider(dim) for dim in range(latent_dims)]


def numpy_to_b64(array, upscale=True, scalar=True):
    """Encode an image array as a base64 PNG string.

    Args:
        array: Image data. With ``scalar=True`` it is treated as intensities in
            the range 0-1; otherwise it is handed to PIL unchanged.
        upscale: If True, resize the image to 250x250 pixels (Lanczos filter).
        scalar: If True, scale ``array`` from 0-1 to 0-255 ``uint8`` first.

    Returns:
        The PNG bytes, base64-encoded and decoded to a UTF-8 ``str``.
    """
    # Convert from 0-1 to 0-255
    if scalar:
        # Clip before casting: values slightly outside [0, 1] would otherwise
        # overflow the uint8 cast and wrap around (e.g. bright pixels turn dark).
        array = np.uint8(255 * np.clip(array, 0.0, 1.0))

    im_pil = Image.fromarray(array)
    if upscale:
        im_pil = im_pil.resize((250, 250), Image.Resampling.LANCZOS)

    buff = BytesIO()
    im_pil.save(buff, format="png")
    im_b64 = base64.b64encode(buff.getvalue()).decode("utf-8")

    return im_b64

The dash_core_components package is deprecated. Please replace
`import dash_core_components as dcc` with `from dash import dcc`
  import dash_core_components as dcc
The dash_html_components package is deprecated. Please replace
`import dash_html_components as html` with `from dash import html`
  import dash_html_components as html


In [7]:
title = "Latent Space Visualization"

# Initial dropdown state: first dataset, its first run, and that run's first model.
dataset_choices = list(selectors.keys())
init_dataset_option = dataset_choices[0]
run_choices = list(selectors[init_dataset_option].keys())
init_run_option = run_choices[0]
model_choices = list(selectors[init_dataset_option][init_run_option].keys())
init_model_option = model_choices[0]

# Styles shared by the layout containers.
CELL_STYLE = {'padding': 10, 'flex': 1}
ROW_STYLE = {'display': 'flex', 'flex-direction': 'row'}

# Dropdowns next to the parameter dump of the selected experiment.
selector_row = html.Div(
    style=ROW_STYLE,
    children=[
        html.Div(
            style=CELL_STYLE,
            children=[
                dcc.Dropdown(dataset_choices, init_dataset_option, id='dataset-dropdown'),
                dcc.Dropdown(run_choices, init_run_option, id='run-dropdown'),
                dcc.Dropdown(model_choices, init_model_option, id='model-dropdown'),
            ],
        ),
        html.Div(
            style=CELL_STYLE,
            children=[html.Pre(id="experiment-info", children="No data", className="dash-pre")],
        ),
    ],
)

# Scatter Plot Matrix, left column
left_column = html.Div(
    style=CELL_STYLE,
    children=[
        html.P("Select dataset, experiment and model:"),
        selector_row,
        dcc.Graph(id="spm-graph"),
        html.P("Number of components:"),
        dcc.Slider(id='pca-slider', min=2, max=2, value=2, step=1),
    ],
)

# Hint message, reconstructed image and the component sliders.
reconstruction_panel = html.Div(
    style=CELL_STYLE,
    children=[
        html.Div(
            id="reconstruction-message-div",
            style={
                "text-align": "center",
                "margin-bottom": "7px",
                "font-weight": "bold",
            },
            children=[html.H4("Select a sample in the SPM to modify latent components")],
        ),
        html.Div(id="reconstruction-image-div"),
        html.Div(id="z-slider-div", children=generate_dim_sliders(0, latent_dims=MAX_DIMS)),
    ],
)

# Raw text details about the selected sample.
raw_info_panel = html.Div(
    style={**CELL_STYLE, 'width': '300px'},
    children=[html.Pre(id="raw-information-pre", children="No data", className="dash-pre")],
)

# Reconstructed image, right column
right_column = html.Div(
    style=CELL_STYLE,
    children=[html.Div(style=ROW_STYLE, children=[reconstruction_panel, raw_info_panel])],
)

app = dash.Dash(__name__)
app.title = title
app.layout = html.Div(children=[
    html.H2(title),
    html.Div(style=ROW_STYLE, children=[left_column, right_column]),
    # Remembers the click that was handled last (read and written by run_and_plot).
    dcc.Store(id='sharedClickData'),
])


def gen_figure_placeholder():
    """Return an empty figure with hidden axes and a "Select dataset and model." note."""
    hint = {
        "text": "Select dataset and model.",
        "xref": "paper",
        "yref": "paper",
        "showarrow": False,
        "font": {"size": 28},
    }

    placeholder = go.Figure()
    placeholder.update_layout(
        xaxis={"visible": False},
        yaxis={"visible": False},
        annotations=[hint],
    )
    return placeholder


def get_dim_reducer(n_components, z_df, random_state=0):
    """Fit a PCA on the latent vectors and project them.

    Args:
        n_components: Number of principal components to keep.
        z_df: Table of latent vectors, one row per sample.
        random_state: Seed handed to PCA. The dashboard callback refits the
            reducer on every invocation and inverse-transforms coordinates that
            were plotted from an earlier fit, so repeated fits must agree. PCA
            may choose a randomized solver depending on the data shape; a fixed
            seed keeps the result reproducible in that case.

    Returns:
        Tuple ``(reducer, projections)``: the fitted PCA (callers use its
        ``explained_variance_ratio_`` and ``inverse_transform``) and the
        projected samples.
    """
    reducer = PCA(n_components=n_components, random_state=random_state)
    projections = reducer.fit_transform(z_df)

    return reducer, projections


@app.callback(
    Output('run-dropdown', 'options'),
    Input('dataset-dropdown', 'value')
)
def init_run_dropdown(dataset):
    """List the runs available for ``dataset``; empty if none is selected or it is unknown."""
    # A missing key (including None) falls back to an empty mapping, i.e. no options.
    runs_for_dataset = selectors.get(dataset, {})
    return list(runs_for_dataset.keys())


@app.callback(
    Output('model-dropdown', 'options'),
    [
        Input('dataset-dropdown', 'value'),
        Input('run-dropdown', 'value')
    ]
)
def init_model_dropdown(dataset, run):
    """List the models of ``run`` in ``dataset``; empty if either is unset or unknown."""
    # Each lookup falls back to an empty mapping, so an unset (None) or stale
    # dataset/run combination simply produces no options.
    models_for_run = selectors.get(dataset, {}).get(run, {})
    return list(models_for_run.keys())


@app.callback(
    [
        Output("pca-slider", "max"),
        Output("pca-slider", "value"),
        Output("experiment-info", "children"),
    ],
    [
        Input('dataset-dropdown', 'value'),
        Input('run-dropdown', 'value'),
        Input('model-dropdown', 'value')
    ]
)
def init_pca_slider(dataset, run, model_name):
    """Set the component slider to the model's latent dimension and show its parameters.

    Returns:
        ``(slider max, slider value, info text)``. Without a complete, known
        selection this is ``(MAX_DIMS, MAX_DIMS, "No data")``.
    """
    fallback = (MAX_DIMS, MAX_DIMS, "No data")

    if any(choice is None for choice in (dataset, run, model_name)):
        return fallback

    try:
        params = selectors[dataset][run][model_name]['experiment'].params
        latent_dim = params['latent_dim']
        return latent_dim, latent_dim, dict_pretty_print(params)
    except KeyError:
        # Unknown selection or an entry that has not been loaded.
        return fallback


# Dim Red Slider callback

# Ids of the inputs that change which data is plotted or how it is projected.
# When one of them triggers the callback, the scatter matrix must be redrawn.
_SELECTION_INPUT_IDS = frozenset({"pca-slider", "dataset-dropdown", "run-dropdown", "model-dropdown"})


def _blank_dashboard(figure, message, click_data):
    """Build the callback outputs for a state without a selected sample (no image, hidden sliders)."""
    return figure, None, generate_dim_sliders(0, latent_dims=MAX_DIMS), message, None, click_data


def _decode_to_img(model, z):
    """Decode the latent vector ``z`` with ``model`` and wrap the result in an html.Img."""
    reco_z = model.decode(tf.expand_dims(z, axis=0))
    reco_z = plt.cm.Greys(reco_z.numpy().squeeze(), bytes=True)

    image_b64 = numpy_to_b64(reco_z, upscale=True, scalar=False)
    return html.Img(
        src="data:image/png;base64, " + image_b64,
        style={"height": "25vh", "display": "block", "margin": "auto"},
    )


def _render_selection(figure, point, trace_name, pca, components, labels_df, model, n_components, click_data):
    """Build the callback outputs for a selected sample located at ``point`` in component space."""
    # Map the component coordinates back to latent space and decode them.
    z = pca.inverse_transform(point)
    img_el = _decode_to_img(model, z)

    # Per-class mean/std of every component; used for the slider ranges and the info text.
    comp_df = pd.DataFrame(components, columns=[f"pc{d+1}" for d in range(0, n_components)])
    comp_agg_df = pd.concat([labels_df, comp_df], axis=1).groupby("classname").agg(["mean", "std"])

    raw = "Class: " + trace_name + "\n\n"
    raw += "Original components: \n" + str(point) + "\n\n"
    raw += "Latent vector z: \n" + str(z) + "\n\n"
    raw += "Inner Class Distribution:\n" + str(comp_agg_df.loc[trace_name,:]) + "\n\n"

    sliders = generate_dim_sliders(n_components, values=point, classname=trace_name, dist=comp_agg_df, latent_dims=MAX_DIMS)
    return figure, img_el, sliders, None, raw, click_data


@app.callback(
    [
        Output("spm-graph", "figure"),
        Output("reconstruction-image-div", "children"),
        Output("z-slider-div", "children"),
        Output("reconstruction-message-div", "children"),
        Output("raw-information-pre", "children"),
        Output("sharedClickData", "data"),
    ],
    [State("spm-graph", "figure")],
    [
        Input("pca-slider", "value"),
        Input('dataset-dropdown', 'value'),
        Input('run-dropdown', 'value'),
        Input('model-dropdown', 'value'),
        Input("spm-graph", "clickData"), 
        Input("sharedClickData", "data"),
    ] + [Input(f"slider-pc{dim+1}", "value") for dim in range(0, MAX_DIMS)]
)
def run_and_plot(figureState, n_components, dataset, run, model_name, clickData, sharedClickData, *args):
    """Update the scatter matrix, the reconstructed image, the sliders and the info panel.

    Args:
        figureState: Currently displayed scatter-matrix figure (State).
        n_components: Number of principal components to project onto.
        dataset, run, model_name: Current dropdown selection.
        clickData: Last click on the scatter matrix, or None.
        sharedClickData: The click that was handled last; a differing
            ``clickData`` therefore is a new click, an equal one means that a
            slider was moved for the already selected sample.
        *args: Current values of the ``slider-pc<N>`` sliders, in order.

    Returns:
        Tuple ``(figure, image element, sliders, message, raw text, click data
        to remember)``, matching the callback's outputs.
    """
    if dataset is None or run is None or model_name is None:
        return _blank_dashboard(gen_figure_placeholder(), html.H4(""), None)

    # While the dropdowns are being changed, dataset/run/model can briefly form a
    # combination that does not exist. Show the placeholder instead of raising.
    try:
        entry = selectors[dataset][run][model_name]
        z_df, labels_df, model = entry['z'], entry['labels'], entry['model']
    except KeyError:
        return _blank_dashboard(gen_figure_placeholder(), html.H4(""), None)

    pca, components = get_dim_reducer(n_components, z_df)

    # clickData stays set after a click. If the dataset, run, model or component
    # count triggered this call, the displayed figure is outdated and has to be
    # rebuilt, rather than interpreting the old click against the new projection.
    triggered_ids = {item["prop_id"].split(".")[0] for item in dash.callback_context.triggered}
    selection_changed = not triggered_ids.isdisjoint(_SELECTION_INPUT_IDS)

    if clickData and not selection_changed:
        figure = figureState
        curve_number = clickData["points"][0]["curveNumber"]
        trace_name = figure["data"][curve_number]["name"]

        if clickData != sharedClickData:
            print("# Fill Sliders and Decode image with a new clicked image")
            # Coordinates of the clicked sample in component space.
            point = np.array([clickData["points"][0][f"dimensions[{i}].values"] for i in range(0, n_components)]).astype(np.float32)
        else:
            print("# Fill Sliders with updated values on already selected sample")
            # Coordinates taken from the visible component sliders.
            point = np.array(args[:n_components]).astype(np.float32)

        return _render_selection(figure, point, trace_name, pca, components, labels_df, model, n_components, clickData)

    if clickData:
        print("# Selection changed, redraw SPLOM and hide sliders")
    else:
        print("# No point selected, show empty boxes and no sliders")

    # Render SPLOM
    var = pca.explained_variance_ratio_.sum() * 100
    title = f'SPLOM: Total Explained Variance: {var:.2f}%'
    labels = [f"PC {i+1}" for i in range(n_components)]

    components_df = pd.DataFrame(components, columns=labels)
    components_labeled_df = pd.concat([labels_df, components_df], axis=1)

    figure = px.scatter_matrix(
        components_labeled_df, 
        dimensions=labels, 
        color="classname", 
        title=title
    )
    figure.update_layout(
        width=900,
        height=900,
        clickmode='event+select'
    )
    figure.update_traces(diagonal_visible=False)
    # Remember the current click (possibly None) as handled, so that only a new
    # click selects a sample again.
    return _blank_dashboard(figure, html.H4("Select a sample in the SPM to modify latent components"), clickData)


In [None]:
# Start the dashboard on port 8080 with debug mode on and the reloader disabled.
app.run_server(debug=True, use_reloader=False, port="8080")

Dash is running on http://127.0.0.1:8080/

 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: on
# No point selected, show empty boxes and no sliders
