# Shape Analysis

In [13]:
import os
import json
import itertools
import numpy as np
import plotly.graph_objects as go
from sklearn.decomposition import PCA
import plotly.express as px

# Qualitative plotly palette used for the plot traces.
MAIN_COLORS = px.colors.qualitative.Vivid
# Same palette with an alpha channel of 0.5 appended to every colour string.
# NOTE(review): assumes each entry is an "rgb(r, g, b)" string -- confirm.
TRANSLUCENT_COLORS = [
    vivid_color.replace("rgb", "rgba").replace(")", ", 0.5)")
    for vivid_color in MAIN_COLORS
]

# Directory that is scanned for saved model sub-directories.
MODEL_DIR = "models/shape_models"

In [14]:
def get_data(model_save_path):
    """Load the two JSON artefacts stored in a model directory.

    Args:
        model_save_path: Directory containing ``plot_data.json`` and
            ``input_outputs.json``.

    Returns:
        A ``(plot_data_json, input_outputs_json)`` tuple holding the parsed
        contents of the two files, in that order.
    """
    parsed = []
    # Files are read in the same order as they are returned.
    for file_name in ("plot_data.json", "input_outputs.json"):
        with open(os.path.join(model_save_path, file_name), "r") as handle:
            parsed.append(json.load(handle))

    plot_data_json, input_outputs_json = parsed
    return plot_data_json, input_outputs_json


# Collect every model directory whose signature declares plotting_type == "mu".
# The signature is the directory name: a leading token followed by
# '-'-separated "key=value" pairs.
mu_dirs = []
with os.scandir(MODEL_DIR) as model_entries:
    for model_entry in model_entries:
        # Stray files (e.g. editor/OS metadata) are not model directories.
        if not model_entry.is_dir():
            continue
        # DirEntry.name is the last path component on every platform.
        model_signature = model_entry.name
        model_params = {}
        for model_arg in model_signature.split('-')[1:]:
            param_name, separator, param_value = model_arg.partition('=')
            # Tokens without '=' (e.g. the tail of a value that itself
            # contains '-') are skipped instead of raising IndexError.
            if separator:
                model_params[param_name] = param_value
        # .get(): a signature without plotting_type is skipped, not a KeyError.
        if model_params.get("plotting_type") == "mu":
            mu_dirs.append(model_entry.path)

mu_dirs.sort()
# Fixed-width slice of each path, shown as a short label.
# NOTE(review): the slice bounds assume a particular path length; a shorter
# signature appears with a stray leading '=' -- confirm the intended bounds.
print([mu_dir[-56:-20] for mu_dir in mu_dirs])

['bert-hid_dim=16-mlp=0-layer=2-head=2', '=bert-hid_dim=4-mlp=0-layer=1-head=1', 'bert-hid_dim=64-mlp=0-layer=4-head=4']


## Analyze inputs and outputs

In [15]:
# Load plot_data.json and input_outputs.json for the second entry of the sorted mu_dirs list.
plot_data, input_outputs = get_data(mu_dirs[1])

In [16]:
# Keep only the entries whose (string) key is a multiple of 1000;
# everything else is removed from the dictionary in place.
off_grid_keys = [
    epoch_key for epoch_key in input_outputs if int(epoch_key) % 1000 != 0
]
for off_grid_key in off_grid_keys:
    del input_outputs[off_grid_key]
print(input_outputs.keys())

dict_keys(['0', '1000', '2000', '3000', '4000', '5000', '6000', '7000', '8000', '9000', '10000', '11000', '12000', '13000', '14000', '15000', '16000', '17000', '18000', '19000', '20000', '21000', '22000', '23000', '24000', '25000', '26000', '27000', '28000', '29000', '30000', '31000', '32000', '33000', '34000', '35000', '36000', '37000', '38000', '39000', '40000', '41000', '42000', '43000', '44000', '45000', '46000', '47000', '48000', '49000', '50000', '51000', '52000', '53000'])


In [17]:
print(input_outputs['0'].keys())
print(input_outputs['1000'].keys())

# Replace every stored value with a numpy array; arrays whose first dimension
# is 4096 are additionally reshaped to (32, 128, 64).
for epoch_record in input_outputs.values():
    for field_name in list(epoch_record):
        field_array = np.array(epoch_record[field_name])
        if field_array.shape[0] == 4096:
            field_array = field_array.reshape(32, 128, 64)
        epoch_record[field_name] = field_array

dict_keys(['d_p', 'd_x', 'd_y', 'c_p', 'c_x', 'c_y'])
dict_keys(['d_p', 'd_x', 'd_y', 'c_p', 'c_x', 'c_y'])


In [18]:
def plot_mses_one_epoch(epoch):
    """Show a bar chart of the mean squared error between d_p and d_y.

    One bar is drawn for each of the first 32 indices along the leading axis
    of the two arrays.

    Args:
        epoch: Key into the module-level ``input_outputs`` dictionary.
    """
    epoch_record = input_outputs[epoch]
    d_p, d_y = epoch_record['d_p'], epoch_record['d_y']
    mse = [np.mean((d_p[idx] - d_y[idx]) ** 2) for idx in range(32)]

    fig = go.Figure(data=[go.Bar(x=np.arange(32), y=mse)])
    fig.update_layout(title_text=f"Mean Squared Error between d_p and d_y epoch {epoch}")
    fig.update_xaxes(title_text="Index")
    fig.update_yaxes(title_text="Mean Squared Error")
    fig.show()

# One MSE bar chart per epoch key remaining in input_outputs.
for epoch in input_outputs:
    plot_mses_one_epoch(epoch)

In [19]:
# Bar chart of the L2 norm of the first row of each of the 32 d_y groups (epoch "0").
d_y_first_rows = [input_outputs["0"]['d_y'][group][0] for group in range(32)]
l2_norms = [np.linalg.norm(first_row) for first_row in d_y_first_rows]

fig = go.Figure(data=[go.Bar(x=np.arange(32), y=l2_norms)])
fig.update_layout(title_text=f"L2 Norm of d_y")
fig.update_xaxes(title_text="Index")
fig.update_yaxes(title_text="L2 Norm")
fig.show()

In [20]:
# Bar chart of the mean of the first row of each of the 32 d_y groups (epoch "0").
d_y_first_rows = [input_outputs["0"]['d_y'][group][0] for group in range(32)]
means = [np.mean(first_row) for first_row in d_y_first_rows]

fig = go.Figure(data=[go.Bar(x=np.arange(32), y=means)])
fig.update_layout(title_text=f"Mean of d_y")
fig.update_xaxes(title_text="Index")
fig.update_yaxes(title_text="Mean")
fig.show()

In [21]:
# Bar chart of max - min over the first row of each of the 32 d_y groups (epoch "0").
d_y_first_rows = [input_outputs["0"]['d_y'][group][0] for group in range(32)]
ranges = [np.max(first_row) - np.min(first_row) for first_row in d_y_first_rows]

fig = go.Figure(data=[go.Bar(x=np.arange(32), y=ranges)])
fig.update_layout(title_text=f"Range of d_y")
fig.update_xaxes(title_text="Index")
fig.update_yaxes(title_text="Range")
fig.show()

In [22]:
# Bar chart of the L1 norm of the first row of each of the 32 d_y groups (epoch "0").
d_y_first_rows = [input_outputs["0"]['d_y'][group][0] for group in range(32)]
l1_norms = [np.linalg.norm(first_row, ord=1) for first_row in d_y_first_rows]

fig = go.Figure(data=[go.Bar(x=np.arange(32), y=l1_norms)])
fig.update_layout(title_text=f"L1 Norm of d_y")
fig.update_xaxes(title_text="Index")
fig.update_yaxes(title_text="L1 Norm")
fig.show()

## Dimensionality Reduction

In [23]:
def plot_pca_epoch(epoch):
    """Scatter-plot a 2-component PCA projection of d_p and d_y for one epoch.

    The PCA is fitted on ``d_p`` reshaped to (4096, 64). ``d_p`` is projected
    in full; ``d_y`` is reshaped the same way and every 128th row is projected.
    After showing the figure, the explained variance ratios and singular
    values of the fitted PCA are printed.

    Args:
        epoch: Epoch identifier; converted with ``str()`` and used as the key
            into the module-level ``input_outputs`` dictionary.
    """
    epoch = str(epoch)
    epoch_record = input_outputs[epoch]
    flat_d_p = epoch_record['d_p'].reshape(4096, 64)

    pca = PCA(n_components=2)
    pca.fit(flat_d_p)

    transformed_d_p = pca.transform(flat_d_p)
    transformed_d_y = pca.transform(epoch_record['d_y'].reshape(4096, 64)[::128])

    fig = go.Figure()
    # One trace per block of 128 consecutive d_p rows; block i shares its
    # palette colour with the i-th d_y marker drawn below.
    for i in range(32):
        block = transformed_d_p[i * 128:(i + 1) * 128]
        block_color = MAIN_COLORS[i % len(MAIN_COLORS)]
        fig.add_trace(go.Scatter(
            x=block[:, 0],
            y=block[:, 1],
            mode='markers',
            marker=dict(color=block_color),
            name=f"d_p {i}",
        ))

    # Each projected d_y row becomes a black-outlined "x" marker plus a label.
    for i, projected_row in enumerate(transformed_d_y):
        x, y = projected_row[0], projected_row[1]
        marker_style = dict(
            color=MAIN_COLORS[i % len(MAIN_COLORS)],
            symbol="x",
            size=10,
            line=dict(color='black', width=2),
        )
        fig.add_trace(go.Scatter(x=[x], y=[y], mode='markers', marker=marker_style, name=f"d_y {i}"))
        fig.add_annotation(
            x=x,
            y=y,
            text=f"<b>{i}</b>",
            showarrow=False,
            yshift=10,
            font=dict(color="white", family="Courier New, monospace"),
            bgcolor="black",
            opacity=0.58,
        )

    fig.update_layout(title_text=f"PCA for epoch {epoch}")
    fig.update_xaxes(title_text="PC1")
    fig.update_yaxes(title_text="PC2")
    fig.show()

    print(pca.explained_variance_ratio_)
    print(pca.singular_values_)

In [24]:
# PCA projection for a handful of selected epochs.
for epoch in (0, 2000, 10000, 30000, 53000):
    plot_pca_epoch(epoch)

[0.04168454 0.04038689]
[38.42952577 37.82663811]


[0.0412393  0.03854771]
[42.33132439 40.92658537]


[0.05614385 0.04957016]
[81.16710139 76.26742944]


[0.0595798  0.04986704]
[85.69654539 78.40082371]


[0.06050328 0.0493386 ]
[86.71566257 78.307166  ]
