In [36]:
import pickle

import numpy as np
import plotly.express as px
from plotly.offline import init_notebook_mode
from importlib import reload
import visualise
import constants
reload(visualise)
reload(constants)
from constants import *

init_notebook_mode(connected=True)

In [37]:
import pandas as pd
from visualise import get_tsne_points


def clean_nans(embeddings_obj):
    """Drop every record whose embedding contains a NaN, keeping labels aligned.

    Args:
        embeddings_obj: dict with an ``"embeddings"`` array whose first axis
            indexes records (each record may be a vector or a higher-rank
            block, e.g. per-token vectors), and a ``"labels"`` sequence indexed
            the same way as that first axis.

    Returns:
        dict with keys ``"embeddings"`` (array of only the NaN-free records,
        same dtype and trailing shape as the input) and ``"labels"`` (list of
        the matching labels). If every record contains a NaN, the result holds
        an empty array of shape ``(0, ...)`` instead of raising.

    Raises:
        IndexError: if ``"labels"`` is shorter than a kept record's index.
    """
    records = np.asarray(embeddings_obj["embeddings"])
    labels = embeddings_obj["labels"]
    print(f"Got {records.shape[0]} records before cleaning")

    # Collapse every axis except the first so each record yields a single
    # "contains NaN" flag, whatever the record's own rank is.
    has_nan = np.isnan(records).any(axis=tuple(range(1, records.ndim)))
    keep_idx = np.flatnonzero(~has_nan).tolist()

    # Fancy indexing (unlike np.stack on an empty list) handles the
    # "nothing survived" case by returning a (0, ...) array.
    clean_records = records[keep_idx]
    clean_labels = [labels[i] for i in keep_idx]

    print(f"Got {clean_records.shape[0]} records after cleaning")
    return {
        "embeddings": clean_records,
        "labels": clean_labels
    }

def points2dataframe(points, labels):
    """Wrap 2-D points and their labels in a DataFrame for plotting.

    Args:
        points: array-like of shape (n, 2); becomes columns ``"x"`` and ``"y"``.
        labels: length-n sequence. Each entry is either a single string (kept
            as-is) or an iterable of label items, which are joined with
            ``", "`` into one display string.

    Returns:
        pandas.DataFrame with columns ``"x"``, ``"y"`` and ``"labels"``.
    """
    points_df = pd.DataFrame(points, columns=["x", "y"])
    # A bare string is itself iterable, so joining it would split it into
    # characters ("cat" -> "c, a, t"); keep it whole instead. Items are
    # stringified so non-str labels (e.g. ints) don't make join() raise.
    points_df["labels"] = [
        label if isinstance(label, str) else ", ".join(map(str, label))
        for label in labels
    ]
    return points_df

def plot_embeddings(embeddings_cleaned, model_name, width=800, height=800, template='seaborn'):
    """Project embeddings to 2-D with ``get_tsne_points`` and show a scatter plot.

    Args:
        embeddings_cleaned: dict with an ``"embeddings"`` array and a
            ``"labels"`` sequence (e.g. the output of ``clean_nans``). A rank-3
            array is mean-pooled over axis 1 before projection.
        model_name: string used as the figure title.
        width: figure width in pixels (default 800).
        height: figure height in pixels (default 800).
        template: plotly layout template name (default ``'seaborn'``).

    Returns:
        None; the figure is displayed via ``fig.show()``.
    """
    embeddings = embeddings_cleaned["embeddings"]
    if embeddings.ndim == 3:
        # Pool axis 1 down to one vector per record.
        # NOTE(review): presumably (records, tokens, dim) — confirm axis meaning.
        embeddings = embeddings.mean(axis=1)

    points2d = get_tsne_points(embeddings)
    embeddings_df = points2dataframe(points2d, embeddings_cleaned["labels"])

    # Pass hover_data as a list of column names: that form is accepted by
    # plotly.express across versions, whereas a bare string is not everywhere.
    fig = px.scatter(embeddings_df, x="x", y="y", hover_data=["labels"])
    fig.update_layout(
        autosize=False,
        width=width,
        height=height,
        template=template,
        title=model_name
    )
    fig.show()



In [38]:
# Load the saved embeddings for this checkpoint, drop NaN records, then plot.
chkpt_name = "experiment_35epochs/mbt_student_val_loss_2917.26574"
pkl_path = f"saved_models/{chkpt_name}.pkl"
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open(pkl_path, mode="rb") as pkl_file:
    embeddings_obj = pickle.load(pkl_file)
embeddings_cleaned = clean_nans(embeddings_obj)
plot_embeddings(embeddings_cleaned, model_name=chkpt_name)

Got 324 records before cleaning
Got 321 records after cleaning


In [39]:
# Load the saved embeddings for this checkpoint, drop NaN records, then plot.
chkpt_name = "experiment2_110_epochs_cls-token/mbt_student_val_loss_0.00000"
pkl_path = f"saved_models/{chkpt_name}.pkl"
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open(pkl_path, mode="rb") as pkl_file:
    embeddings_obj = pickle.load(pkl_file)
embeddings_cleaned = clean_nans(embeddings_obj)
plot_embeddings(embeddings_cleaned, model_name=chkpt_name)

Got 324 records before cleaning
Got 312 records after cleaning


In [40]:
# Load the saved embeddings for this checkpoint, drop NaN records, then plot.
chkpt_name = "experiment3_110_epochs_cls-token/mbt_student_val_loss_0.00088_less"
pkl_path = f"saved_models/{chkpt_name}.pkl"
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open(pkl_path, mode="rb") as pkl_file:
    embeddings_obj = pickle.load(pkl_file)
embeddings_cleaned = clean_nans(embeddings_obj)
plot_embeddings(embeddings_cleaned, model_name=chkpt_name)

Got 324 records before cleaning
Got 324 records after cleaning


In [41]:
# Load the saved embeddings for this checkpoint, drop NaN records, then plot.
chkpt_name = "experiment3_110_epochs_cls-token/mbt_student_val_loss_0.00088_v1"
pkl_path = f"saved_models/{chkpt_name}.pkl"
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open(pkl_path, mode="rb") as pkl_file:
    embeddings_obj = pickle.load(pkl_file)
embeddings_cleaned = clean_nans(embeddings_obj)
plot_embeddings(embeddings_cleaned, model_name=chkpt_name)

Got 2586 records before cleaning
Got 2586 records after cleaning


In [42]:
# Load the saved embeddings for this checkpoint, drop NaN records, then plot.
chkpt_name = "experiment3_110_epochs_cls-token/mbt_student_val_loss_0.00088_v2"
pkl_path = f"saved_models/{chkpt_name}.pkl"
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open(pkl_path, mode="rb") as pkl_file:
    embeddings_obj = pickle.load(pkl_file)
embeddings_cleaned = clean_nans(embeddings_obj)
plot_embeddings(embeddings_cleaned, model_name=chkpt_name)

Got 2586 records before cleaning
Got 2580 records after cleaning


In [43]:
# Load the saved embeddings for this checkpoint, drop NaN records, then plot.
chkpt_name = "experiment3_110_epochs_cls-token/mbt_student_val_loss_0.00093"
pkl_path = f"saved_models/{chkpt_name}.pkl"
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open(pkl_path, mode="rb") as pkl_file:
    embeddings_obj = pickle.load(pkl_file)
embeddings_cleaned = clean_nans(embeddings_obj)
plot_embeddings(embeddings_cleaned, model_name=chkpt_name)

Got 2586 records before cleaning
Got 2586 records after cleaning


In [44]:
# Load the saved embeddings for this checkpoint, drop NaN records, then plot.
chkpt_name = "experiment_7_epochs_cls-token/mbt_student_val_loss_0.00000"
pkl_path = f"saved_models/{chkpt_name}.pkl"
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open(pkl_path, mode="rb") as pkl_file:
    embeddings_obj = pickle.load(pkl_file)
embeddings_cleaned = clean_nans(embeddings_obj)
plot_embeddings(embeddings_cleaned, model_name=chkpt_name)

Got 648 records before cleaning
Got 648 records after cleaning


In [45]:
# Load the saved embeddings for this checkpoint, drop NaN records, then plot.
chkpt_name = "experiment2_7_epochs_cls-token/mbt_student_val_loss_0.01134"
pkl_path = f"saved_models/{chkpt_name}.pkl"
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open(pkl_path, mode="rb") as pkl_file:
    embeddings_obj = pickle.load(pkl_file)
embeddings_cleaned = clean_nans(embeddings_obj)
plot_embeddings(embeddings_cleaned, model_name=chkpt_name)

Got 2695 records before cleaning
Got 2695 records after cleaning


In [46]:
# Load the saved embeddings for this checkpoint, drop NaN records, then plot.
chkpt_name = "experiment3_7_epochs_cls-token/mbt_student_val_loss_0.00011"
pkl_path = f"saved_models/{chkpt_name}.pkl"
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open(pkl_path, mode="rb") as pkl_file:
    embeddings_obj = pickle.load(pkl_file)
embeddings_cleaned = clean_nans(embeddings_obj)
plot_embeddings(embeddings_cleaned, model_name=chkpt_name)

Got 2156 records before cleaning
Got 2156 records after cleaning


In [49]:
# Load the saved embeddings for this checkpoint, drop NaN records, then plot.
chkpt_name = "experiment3_7_epochs_cls-token/mbt_student_val_loss_0.00012"
pkl_path = f"saved_models/{chkpt_name}.pkl"
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open(pkl_path, mode="rb") as pkl_file:
    embeddings_obj = pickle.load(pkl_file)
embeddings_cleaned = clean_nans(embeddings_obj)
plot_embeddings(embeddings_cleaned, model_name=chkpt_name)

Got 2156 records before cleaning
Got 2156 records after cleaning


In [None]:
# Load the saved embeddings for this checkpoint, drop NaN records, then plot.
chkpt_name = "experiment_7_epochs_cls-token/mbt_student_val_loss_0.04922"
pkl_path = f"saved_models/{chkpt_name}.pkl"
# NOTE(review): pickle.load can execute arbitrary code — only open trusted files.
with open(pkl_path, mode="rb") as pkl_file:
    embeddings_obj = pickle.load(pkl_file)
embeddings_cleaned = clean_nans(embeddings_obj)
plot_embeddings(embeddings_cleaned, model_name=chkpt_name)