In [1]:
import pickle

import numpy as np
# import plotly.express as px
from plotly.offline import init_notebook_mode
from importlib import reload
import visualise
import constants
import datasets
# Re-execute the local project modules so edits to their source are picked up
# without restarting the kernel. The `from ... import` lines below must stay
# AFTER these calls so they re-bind names from the freshly reloaded modules.
# NOTE(review): if `datasets.enterface` is a submodule, reloading `datasets`
# does not reload it — reload it explicitly if its contents change.
reload(visualise)
reload(constants)
reload(datasets)
# NOTE(review): wildcard import hides which names come from `constants`; no
# visible cell obviously uses them — consider importing names explicitly.
from constants import *
from datasets import enterface
import plotly.io as pio
# Render plotly figures with the classic-notebook renderer.
pio.renderers.default = "notebook"

# Enable plotly's offline notebook mode; with connected=True plotly.js is
# fetched online rather than embedded in the notebook (see plotly.offline docs).
init_notebook_mode(connected=True)

In [11]:
import pandas as pd
import plotly.graph_objects as go
from visualise import get_tsne_points
from ipywidgets import HBox
from IPython.display import display
from collections import OrderedDict


def clean_nans(embeddings_obj):
    """Drop every record that contains at least one NaN, together with its label.

    Args:
        embeddings_obj: dict with an ``"embeddings"`` array whose first axis
            indexes records (any number of trailing axes) and a ``"labels"``
            sequence indexed in parallel with that first axis.

    Returns:
        dict with ``"embeddings"`` (the NaN-free records, same trailing shape and
        dtype; shape ``(0, ...)`` if nothing survives) and ``"labels"`` (a list of
        the labels of the kept records, in original order).
    """
    records = np.asarray(embeddings_obj["embeddings"])
    labels = embeddings_obj["labels"]
    print(f"Got {records.shape[0]} records before cleaning")

    nan_mask = np.isnan(records)
    if records.ndim > 1:
        # Collapse all per-record axes so there is one flag per record.
        nan_mask = nan_mask.any(axis=tuple(range(1, records.ndim)))
    keep = ~nan_mask

    # Boolean indexing (unlike np.stack on a list) also works when no record
    # survives, yielding an empty (0, ...) array instead of raising.
    clean_records = records[keep]
    clean_labels = [labels[i] for i in np.flatnonzero(keep).tolist()]

    print(f"Got {clean_records.shape[0]} records after cleaning")
    return {
        "embeddings": clean_records,
        "labels": clean_labels
    }

def points2dataframe(points, labels):
    """Wrap 2-D points and their labels in a DataFrame.

    Args:
        points: array-like with two columns; they become ``x`` and ``y``.
        labels: one iterable of strings per point; each is joined with
            ``", "`` into a single string for the ``labels`` column.

    Returns:
        A DataFrame with columns ``x``, ``y`` and ``labels``.
    """
    joined_labels = list(map(", ".join, labels))
    return pd.DataFrame(points, columns=["x", "y"]).assign(labels=joined_labels)

def assign_colors_to_labels(labels):
    """Return the display colour of every label, in input order.

    A label's colour is the entry of ``enterface.LABEL_COLORS`` at the same
    position as the label's first occurrence in ``enterface.LABEL_MAPPINGS``.

    Raises:
        ValueError: (from ``.index``) if a label is not in ``LABEL_MAPPINGS``.
    """
    return [
        enterface.LABEL_COLORS[enterface.LABEL_MAPPINGS.index(label)]
        for label in list(labels)
    ]


def get_embeddings_plot(embeddings_cleaned, model_name):
    """Project embeddings to 2-D with t-SNE and return them as an interactive scatter.

    Args:
        embeddings_cleaned: dict with an ``"embeddings"`` array and a parallel
            ``"labels"`` sequence (e.g. the output of ``clean_nans``).
        model_name: text used as the figure title.

    Returns:
        An 800x800 ``go.FigureWidget`` with a single marker trace, coloured per
        label, whose hover text is the joined label string of each point.
    """
    raw = embeddings_cleaned["embeddings"]
    # 3-D input is averaged over axis 1 so each record becomes one vector
    # (presumably a per-timestep/per-token axis — TODO confirm).
    vectors = raw.mean(axis=1) if raw.ndim == 3 else raw

    points_df = points2dataframe(get_tsne_points(vectors), embeddings_cleaned["labels"])
    point_colors = assign_colors_to_labels(points_df["labels"])
    scatter_trace = go.Scatter(
        x=points_df["x"],
        y=points_df["y"],
        mode="markers",
        hovertext=points_df["labels"],
        marker_color=point_colors,
    )

    widget = go.FigureWidget([scatter_trace])
    widget.update_layout(
        title=model_name,
        template='seaborn',
        autosize=False,
        width=800,
        height=800,
    )
    return widget

def tally_emotions(emotions_list):
    """Count how often each known emotion occurs in ``emotions_list``.

    Returns:
        ``(labels, values)``: two parallel lists in ``enterface.LABEL_MAPPINGS``
        order; emotions that never occur get a count of 0.

    Raises:
        KeyError: if an entry is not one of ``enterface.LABEL_MAPPINGS``.
    """
    # dict keeps insertion order, so keys follow LABEL_MAPPINGS order.
    tallies = dict.fromkeys(enterface.LABEL_MAPPINGS, 0)
    for emotion in emotions_list:
        tallies[emotion] += 1
    return list(tallies.keys()), list(tallies.values())

def display_interactive(emb_obj, checkpoint_name):
    """Show a t-SNE scatter of ``emb_obj`` beside a live emotion bar chart.

    Records containing NaNs are dropped first (via ``clean_nans``). Lasso- or
    box-selecting points on the scatter re-tallies the emotions of the selected
    points into the bar chart.

    Args:
        emb_obj: dict with ``"embeddings"`` and parallel ``"labels"`` entries,
            as loaded from a checkpoint pickle.
        checkpoint_name: used as the scatter plot's title.
    """
    # Seed the bars with zero counts for every emotion so the chart shows the
    # full category axis (and its colours) before the first selection.
    initial_labels, initial_counts = tally_emotions([])
    barchartw = go.FigureWidget([go.Bar(
        x=initial_labels,
        y=initial_counts,
        marker={
            "color": enterface.LABEL_COLORS
        }
    )])
    barchartw.update_layout(title="Emotion breakdown in selection")
    bar_trace = barchartw.data[0]

    def selection_fn(trace, points, selector):
        # hovertext holds each scatter point's joined label string.
        labels = trace.hovertext[points.point_inds]
        emotion_labels, emotion_counts = tally_emotions(labels)
        # batch_update pushes both property changes to the front end together.
        with barchartw.batch_update():
            bar_trace.x = emotion_labels
            bar_trace.y = emotion_counts

    embeddings_cleaned = clean_nans(emb_obj)
    scatter_fig = get_embeddings_plot(embeddings_cleaned, checkpoint_name)
    # The default drag mode is zoom; start in lasso so selecting works at once.
    scatter_fig.update_layout(dragmode="lasso")
    scatter_fig.data[0].on_selection(selection_fn)

    display(HBox((scatter_fig, barchartw)))
        

In [12]:
chkpt_name = "enterface_test4/mbt_student_train_loss_0.03864"
pkl_path = f"saved_models/{chkpt_name}.pkl"

# pickle.load can execute arbitrary code: only load checkpoint files you trust.
with open(pkl_path, mode="rb") as pkl_file:
    embeddings_obj = pickle.load(pkl_file)

display_interactive(embeddings_obj, chkpt_name)

Got 2568 records before cleaning
Got 2568 records after cleaning


HBox(children=(FigureWidget({
    'data': [{'hovertext': array(['sadness', 'happiness', 'fear', ..., 'surprise…

In [None]:
chkpt_name = "enterface_test5/mbt_student_train_loss_0.19820"
pkl_path = f"saved_models/{chkpt_name}.pkl"

# pickle.load can execute arbitrary code: only load checkpoint files you trust.
with open(pkl_path, mode="rb") as pkl_file:
    embeddings_obj = pickle.load(pkl_file)

display_interactive(embeddings_obj, chkpt_name)