# Unembedding Space Analysis and MLP's Directions


## Set Up

In [1]:
import sys 
sys.path.append('..')
import torch 
import json 
from src.decision_transformer.utils import (
    load_decision_transformer,
    # get_max_len_from_model_type,
)
from src.environments.registration import register_envs
from src.environments.environments import make_env

register_envs()


Registering DynamicObstaclesMultiEnv-v0
Registering CrossingMultiEnv-v0
Registering Probe Envs


In [2]:
from src.config import EnvironmentConfig

model_path = "../models/MiniGrid-MemoryS7FixedStart-v0/WorkingModel.pt"
state_dict = torch.load(model_path)

env_config = state_dict["environment_config"]
env_config = EnvironmentConfig(**json.loads(env_config))

env = make_env(env_config, seed=4200, idx=0, run_name="dev")
env = env()

dt = load_decision_transformer(
    model_path, env, tlens_weight_processing=True
)



In [3]:
from sklearn.metrics.pairwise import cosine_similarity
import pandas as pd
import plotly.express as px 
from scipy.cluster import hierarchy
import numpy as np 

def plot_vector_norms(vectors, labels, visible_labels = False, type = "bar"):
    norms = torch.norm(vectors, dim=1)
    if type == "bar":
        fig = px.bar(y=norms, labels={"x": "L2 Norm"}, 
                        color = labels,
                        hover_name=labels,
                        title="L2 Norm of State Embedding Vectors",
                        orientation="v",
                        text_auto=True,
                        x = labels,
                        template="plotly_dark")
    elif type == "strip":
        fig = px.strip(y=norms, labels={"x": "L2 Norm"}, 
                        color = labels,
                        hover_name=labels,
                        title="L2 Norm of State Embedding Vectors",
                        orientation="v",
                        x = labels,
                        template="plotly_dark")        
    else:
        raise ValueError("type must be either 'bar' or 'strip'")
    
    fig.update_xaxes(
        visible=visible_labels,
    )
    return fig 


def plot_cosine_similarity_matrix(vectors, labels, cluster = False, visible_labels = False):

    cosine_similarity_matrix = cosine_similarity(vectors)
    df = pd.DataFrame(cosine_similarity_matrix, columns=labels, index=labels)

    if cluster:
        data_array = df.to_numpy()
        linkage = hierarchy.linkage(data_array)
        dendrogram = hierarchy.dendrogram(
            linkage, no_plot=True, color_threshold=-np.inf
        )
        reordered_ind = dendrogram["leaves"]
        # reorder df by ind
        df = df.iloc[reordered_ind, reordered_ind]
        # reorder labels
        labels = [labels[i] for i in reordered_ind]


    # plot the cosine similarity matrix
    fig = fig = px.imshow(
            df,
            color_continuous_scale="RdBu",
            title="Pairwise Cosine Similarity Heatmap",
            color_continuous_midpoint=0.0,
            labels={"color": "Cosine Similarity"},
        )
    fig.update_xaxes(
        tickmode="array",
        tickvals=list(range(len(labels))),
        ticktext=labels,
        showgrid=False,
    )
    fig.update_yaxes(
        tickmode="array",
        tickvals=list(range(len(labels))),
        ticktext=labels,
        showgrid=False,
    )
    fig.update_xaxes(
        visible=visible_labels,
    )
    fig.update_yaxes(
        visible=visible_labels,
    )
    return fig 


def get_cosine_similarity_table(vectors, labels):

    norms = torch.norm(vectors, dim=1)
    df_norms = pd.DataFrame({"L2 Norm": norms, "Index": labels})

    cosine_similarity_matrix = cosine_similarity(vectors)
    df = pd.DataFrame(cosine_similarity_matrix, columns=labels, index=labels)

    # flatten the cosine similarity matrix and plot the distribution
    # it's a pandas dataframe so we can go wide to long
    cosine_similarity_matrix = pd.melt(df, ignore_index=False).reset_index()
    # rename the columns
    cosine_similarity_matrix.columns = ["label_1", "label_2", "cosine_similarity"]
    # remove the diagonal
    cosine_similarity_matrix = cosine_similarity_matrix[cosine_similarity_matrix["label_1"] != cosine_similarity_matrix["label_2"]]
    # remove any values less than 0.05
    cosine_similarity_matrix = cosine_similarity_matrix[cosine_similarity_matrix["cosine_similarity"].abs() > 0.05]
    # remove any values equal to 1
    # cosine_similarity_matrix = cosine_similarity_matrix[cosine_similarity_matrix["cosine_similarity"] != 1]
    # cosine_similarity_matrix = cosine_similarity_matrix[cosine_similarity_matrix != 0]

    # merge df_norms to get l2 norm of either vector
    cosine_similarity_matrix = cosine_similarity_matrix.merge(df_norms, left_on="label_1", right_on="Index")
    cosine_similarity_matrix = cosine_similarity_matrix.merge(df_norms, left_on="label_2", right_on="Index")
    cosine_similarity_matrix = cosine_similarity_matrix.drop(columns=["Index_x", "Index_y"])
    # rename the columns
    cosine_similarity_matrix.columns = ["label_1", "label_2", "cosine_similarity", "l2_norm_1", "l2_norm_2"]

    return cosine_similarity_matrix


# Unembedding Analysis

- [ ] Norm of unembed
- [ ] Cosine Similarity of Unembed
- [ ] Cluster Unembed

In [4]:
from src.streamlit_app.constants import ACTION_NAMES

unembedding = dt.action_predictor.weight.detach()
unembedding_bias = dt.action_predictor.bias.detach()

In [5]:
plot_vector_norms(unembedding, ACTION_NAMES, visible_labels=True).show()

In [6]:
unembedding_bias # bias for each action is pretty small. 

tensor([ 0.0139, -0.0006,  0.0768, -0.0444,  0.0373, -0.0675, -0.0156])

In [7]:


plot_cosine_similarity_matrix(unembedding, ACTION_NAMES)

# MLP Analysis

In [8]:
mlp0_in = dt.transformer.blocks[0].mlp.W_in.T.detach()
mlp1_in = dt.transformer.blocks[1].mlp.W_in.T.detach()
mlp2_in = dt.transformer.blocks[2].mlp.W_in.T.detach()

mlp0_out = dt.transformer.blocks[0].mlp.W_out.detach()
mlp1_out = dt.transformer.blocks[1].mlp.W_out.detach()
mlp2_out = dt.transformer.blocks[2].mlp.W_out.detach()

# apply layernorm to mlp out vectors (ln_final is ln pre, nor pars so we can use it)
mlp0_out_ln = dt.transformer.ln_final(mlp0_out)
mlp1_out_ln = dt.transformer.ln_final(mlp1_out)
mlp2_out_ln = dt.transformer.ln_final(mlp2_out)

neuron_names = [f"N{i}" for i in range(mlp2_out.shape[0])]

In [9]:
# px.imshow(mlp0_out - mlp0_out_ln, 
#           color_continuous_scale="RdBu", 
#           title="MLP0 Out - MLP0 Out LN", 
#           labels={"color": "Difference"}).show()

# px.imshow(mlp1_out - mlp1_out_ln,
#             color_continuous_scale="RdBu",
#             title="MLP1 Out - MLP1 Out LN",
#             labels={"color": "Difference"}).show()

# px.imshow(mlp2_out - mlp2_out_ln,
#             color_continuous_scale="RdBu",
#             title="MLP2 Out - MLP2 Out LN",
#             labels={"color": "Difference"}).show()


In [10]:
# plot_vector_norms(mlp0_in, neuron_names, visible_labels=False,type="strip").show()
# plot_vector_norms(mlp1_in, neuron_names, visible_labels=False,type="strip").show()
# plot_vector_norms(mlp2_in, neuron_names, visible_labels=False,type="strip").show()
# plot_cosine_similarity_matrix(mlp0_in, neuron_names, cluster=True).show()
# plot_cosine_similarity_matrix(mlp1_in, neuron_names, cluster=True).show()
plot_cosine_similarity_matrix(mlp2_in, neuron_names, cluster=True).show()

In [11]:
# plot_vector_norms(mlp0_out, neuron_names, visible_labels=False,type="strip").show()
# plot_vector_norms(mlp1_out, neuron_names, visible_labels=False,type="strip").show()
# plot_vector_norms(mlp2_out, neuron_names, visible_labels=False,type="strip").show()
# plot_cosine_similarity_matrix(mlp0_out, neuron_names, cluster=True).show()
# plot_cosine_similarity_matrix(mlp1_out, neuron_names, cluster=True).show()
plot_cosine_similarity_matrix(mlp2_out, neuron_names, cluster=True).show()
# plot_cosine_similarity_matrix(mlp2_out_ln, neuron_names, cluster=True).show() 

In [12]:
def get_cluster_from_exemplar(cosine_similarity_table, example_label, abs_threshold = 0.8):
    '''
    Filters cosine similarity table by high cosine similarity with the example.
    Then returns the cosine similarity table sorted by cosine similarity.

    '''
    criteria_one = cosine_similarity_table["label_1"].str.contains(example_label, regex=True) | \
        cosine_similarity_table["label_2"].str.contains(example_label, regex=True)
    criteria_two = cosine_similarity_table["cosine_similarity"].abs() > abs_threshold
    
    mask = criteria_one & criteria_two
    masked_matrix = cosine_similarity_table[mask].sort_values(by="cosine_similarity", ascending=False)
    vocab_items = list(set(list(masked_matrix["label_1"].unique()) + list(masked_matrix["label_2"].unique())))

    return masked_matrix, mask, vocab_items



Let's see if we can backup Lucy's dynamic analysis with a static one. She said:


Left:
- L2N79, 
- L2N235, 
- L2N255.

Right:
- L2N132, 
- L2N204,
- L2N1,
- L2N108, 
- L2N158,
- L2N169 


In [13]:
# so we expect that maybe always go right sub-updates will cluster.
cosine_similarity_table = get_cosine_similarity_table(mlp2_out, neuron_names)
masked_matrix, mask, vocab_items = get_cluster_from_exemplar(cosine_similarity_table, "N1$", 0.50)
vector_mask = [True if item in vocab_items else False for item in neuron_names]
subset_names = [item for item in neuron_names if item in vocab_items]
# len(vector_mask)
plot_cosine_similarity_matrix(mlp2_out[vector_mask], subset_names, cluster=True, visible_labels=True).show()
# px.violin(rows,
#           x= "cosine_similarity", 
#           orientation="h",
#           box=True, 
#           hover_data=["label_1", "label_2", "l2_norm_1", "l2_norm_2"],
#           points="all", title="Cosine Similarity Distribution for N1").show()

Definitely we're able to pull out these vectors based on cosine similarity:
- L2N1, L2132, L2N108, L2N204 are all listed by Lucy for right.
- L2N79, L2N235 also come out fairly quickly on the other side.
- However, we start getting many others before we get L2N255 for left and L2N158. 

Let's do the same analysis for the in vectors.

In [14]:
cosine_similarity_table = get_cosine_similarity_table(mlp2_in, neuron_names)
masked_matrix, mask, vocab_items = get_cluster_from_exemplar(cosine_similarity_table, "N1$", 0.50)
vector_mask_1 = [True if item in vocab_items else False for item in neuron_names]
subset_names_1 = [item for item in neuron_names if item in vocab_items]
plot_cosine_similarity_matrix(mlp2_in[vector_mask_1], subset_names_1, cluster=True, visible_labels=True).show()

cosine_similarity_table = get_cosine_similarity_table(mlp2_in, neuron_names)
masked_matrix, mask, vocab_items = get_cluster_from_exemplar(cosine_similarity_table, "N204$", 0.50)
vector_mask_2 = [True if item in vocab_items else False for item in neuron_names]
subset_names_2 = [item for item in neuron_names if item in vocab_items]
plot_cosine_similarity_matrix(mlp2_in[vector_mask_2], subset_names_1, cluster=True, visible_labels=True).show()

# combine the two masks
combined_mask = np.logical_or(vector_mask_1, vector_mask_2)
subset_names_combined = [item for item, mask in zip(neuron_names, combined_mask) if mask]
plot_cosine_similarity_matrix(mlp2_in[combined_mask], subset_names_combined, cluster=True, visible_labels=True).show()


In [15]:
cosine_similarity_table = get_cosine_similarity_table(mlp2_in, neuron_names)
masked_matrix, mask, vocab_items = get_cluster_from_exemplar(cosine_similarity_table, "N169$", 0.40)
vector_mask_1 = [True if item in vocab_items else False for item in neuron_names]
subset_names_1 = [item for item in neuron_names if item in vocab_items]
plot_cosine_similarity_matrix(mlp2_in[vector_mask_1], subset_names_1, cluster=True, visible_labels=True).show()

# MLP Out Congruence

In [26]:
print(unembedding.shape)
print(mlp2_out.shape)

torch.Size([7, 256])
torch.Size([256, 256])


In [31]:
mlp2_out_congruence = mlp2_out_ln @ unembedding.T
mlp2_out_congruence_df = pd.DataFrame(mlp2_out_congruence, index=neuron_names, columns=ACTION_NAMES)
mlp2_out_congruence_df["right_minus_left"] = mlp2_out_congruence_df["right"] - mlp2_out_congruence_df["left"]

px.scatter(mlp2_out_congruence_df, x="left", y="right", hover_name=mlp2_out_congruence_df.index).show()


fig= px.strip(mlp2_out_congruence_df,
              x="right_minus_left", 
              hover_name=mlp2_out_congruence_df.index)
fig.show()

In [18]:
# project unembedding for right [1] into an othrthogonal space to unembedding [0]
# then project mlp2_out into that space

right_not_left = unembedding[1] - cosine_similarity(
    unembedding[1].reshape(1,-1), 
    unembedding[0].reshape(1,-1)) * unembedding[0].detach().numpy()

mlp2_out_congruence_right_not_left = mlp2_out @ right_not_left.T
mlp2_out_congruence_right_not_left = pd.DataFrame(mlp2_out_congruence_right_not_left, index=neuron_names, columns=["right_not_left"])

fig= px.strip(mlp2_out_congruence_right_not_left,
              x="right_not_left", 
              hover_name=mlp2_out_congruence_right_not_left.index)
fig.show()

In [19]:
left_not_right = unembedding[0] - cosine_similarity(
    unembedding[0].reshape(1,-1),
    unembedding[1].reshape(1,-1)) * unembedding[1].detach().numpy()

mlp2_out_congruence_left_not_right = mlp2_out @ left_not_right.T
mlp2_out_congruence_left_not_right = pd.DataFrame(mlp2_out_congruence_left_not_right, index=neuron_names, columns=["left_not_right"])

fig= px.strip(mlp2_out_congruence_left_not_right,
                x="left_not_right",
                hover_name=mlp2_out_congruence_left_not_right.index)

fig.show()

In [24]:
(left_not_right * right_not_left).norm()

tensor(0.0250)

In [20]:
import plotly.graph_objects as go 
fig = px.scatter(x= mlp2_out_congruence_right_not_left["right_not_left"],
              y= mlp2_out_congruence_left_not_right["left_not_right"],
                hover_name=mlp2_out_congruence_right_not_left.index,
                title="Congruence of MLP2 with Right and Left",
                labels={"x": "Congruence with Right", "y": "Congruence with Left"})

# add y = x line
fig.add_trace(go.Scatter(x=[-0.5, 0.5], y=[-0.5, 0.5], mode="lines", name="y=x"))

fig.show()



In [21]:
mlp2_out_congruence = mlp2_out_ln @ unembedding.T
plot_cosine_similarity_matrix(mlp2_out_congruence, neuron_names, cluster=True, visible_labels=False).show()

In [22]:
cosine_similarity_table = get_cosine_similarity_table(mlp2_out_congruence, neuron_names)
masked_matrix, mask, vocab_items = get_cluster_from_exemplar(cosine_similarity_table, "N1$", 0.70)
vector_mask_1 = [True if item in vocab_items else False for item in neuron_names]
subset_names_1 = [item for item in neuron_names if item in vocab_items]
plot_cosine_similarity_matrix(mlp2_out_congruence[vector_mask_1], subset_names_1, cluster=True, visible_labels=True).show()

## Do same analysis for MLP1

In [35]:
mlp1_out_congruence = mlp1_out_ln @ unembedding.T
mlp1_out_congruence_df = pd.DataFrame(mlp1_out_congruence, index=neuron_names, columns=ACTION_NAMES)
mlp1_out_congruence_df["right_minus_left"] = mlp1_out_congruence_df["right"] - mlp1_out_congruence_df["left"]




px.scatter(mlp1_out_congruence_df, x="left", y="right", hover_name=mlp1_out_congruence_df.index).show()


fig= px.strip(mlp1_out_congruence_df,
              x="right_minus_left", 
              hover_name=mlp1_out_congruence_df.index)
fig.show()

mlp1_out_congruence_left_not_right = mlp1_out @ left_not_right.T
mlp1_out_congruence_right_not_left = mlp1_out @ right_not_left.T

mlp1_out_congruence_left_not_right = pd.DataFrame(mlp1_out_congruence_left_not_right, index=neuron_names, columns=["left_not_right"])
mlp1_out_congruence_right_not_left = pd.DataFrame(mlp1_out_congruence_right_not_left, index=neuron_names, columns=["right_not_left"])


import plotly.graph_objects as go 
fig = px.scatter(x= mlp1_out_congruence_right_not_left["right_not_left"],
                y= mlp1_out_congruence_left_not_right["left_not_right"],
                hover_name=mlp1_out_congruence_right_not_left.index,
                title="Congruence of MLP2 with Right and Left",
                labels={"x": "Congruence with Right", "y": "Congruence with Left"})

# add y = x line
fig.add_trace(go.Scatter(x=[-0.5, 0.5], y=[-0.5, 0.5], mode="lines", name="y=x"))

fig.show()

# MLP In Congruence

In [None]:
embedding = dt.state_embedding.weight.detach().T
embedding_ln = dt.transformer.ln_final(dt.state_embedding.weight.detach()).T
print(embedding.shape)
print(mlp2_in.shape)

In [None]:
mlp2_in_congruence = mlp2_in @ embedding_ln.T
print(mlp2_in_congruence.shape)
plot_cosine_similarity_matrix(mlp2_in_congruence, neuron_names, cluster=True, visible_labels=False).show()

In [None]:
cosine_similarity_table = get_cosine_similarity_table(mlp2_in_congruence, neuron_names)
masked_matrix, mask, vocab_items = get_cluster_from_exemplar(cosine_similarity_table, "N1$", 0.45)
vector_mask_1 = [True if item in vocab_items else False for item in neuron_names]
subset_names_1 = [item for item in neuron_names if item in vocab_items]
plot_cosine_similarity_matrix(mlp2_in_congruence[vector_mask_1], subset_names_1, cluster=True, visible_labels=True).show()

In [None]:
cosine_similarity_table = get_cosine_similarity_table(mlp2_in_congruence, neuron_names)
masked_matrix, mask, vocab_items = get_cluster_from_exemplar(cosine_similarity_table, "N132$", 0.60)
vector_mask_1 = [True if item in vocab_items else False for item in neuron_names]
subset_names_1 = [item for item in neuron_names if item in vocab_items]
plot_cosine_similarity_matrix(mlp2_in_congruence[vector_mask_1], subset_names_1, cluster=True, visible_labels=True).show()

In [None]:
cosine_similarity_table = get_cosine_similarity_table(mlp2_in_congruence, neuron_names)
masked_matrix, mask, vocab_items = get_cluster_from_exemplar(cosine_similarity_table, "N235$", 0.75)
vector_mask_1 = [True if item in vocab_items else False for item in neuron_names]
subset_names_1 = [item for item in neuron_names if item in vocab_items]
plot_cosine_similarity_matrix(mlp2_in_congruence[vector_mask_1], subset_names_1, cluster=True, visible_labels=True).show()

In [None]:
cosine_similarity_table = get_cosine_similarity_table(mlp2_in_congruence, neuron_names)
masked_matrix, mask, vocab_items = get_cluster_from_exemplar(cosine_similarity_table, "N255$", 0.55)
vector_mask_1 = [True if item in vocab_items else False for item in neuron_names]
subset_names_1 = [item for item in neuron_names if item in vocab_items]
plot_cosine_similarity_matrix(mlp2_in_congruence[vector_mask_1], subset_names_1, cluster=True, visible_labels=True).show()

## Top K

In [None]:
# get the labels. 
from src.streamlit_app.constants import SPARSE_CHANNEL_NAMES
import itertools 

all_index_labels = [
    SPARSE_CHANNEL_NAMES,
    list(range(7)),
    list(range(7)),
]
indices = list(itertools.product(*all_index_labels))
index_labels = ["{0}, ({1},{2})".format(*index) for index in indices]
print(index_labels[:4])


mlp2_in_congruence_df = pd.DataFrame(mlp2_in_congruence.T, index=index_labels, columns=neuron_names)
mlp2_in_congruence_df.head()

In [None]:
mlp2_in_congruence_df.mean(axis=1).abs().sort_values(ascending=True).head(10)

In [None]:
from collections import Counter
neuron_list = ["N1", "N132", "N204", "N108", "N158"]
df_list = []
listed_observations = []
for neuron in neuron_list:
    # get top 5
    df = mlp2_in_congruence_df[neuron].sort_values(ascending=False).head(10).reset_index(drop=False)
    # get bottom 5
    df = pd.concat([df, mlp2_in_congruence_df[neuron].sort_values(ascending=False).tail(10).reset_index(drop=False)], axis=0)
   
    listed_observations = listed_observations + list(df["index"])
    # rename index 
    df = df.rename(columns={"index": "index_{0}".format(neuron)})
    df_list.append(df)

counted_observations = Counter(listed_observations)
px.bar(
    x=list(counted_observations.keys()),
    y=list(counted_observations.values())).show()

df = pd.concat(df_list, axis=1)
df

In [None]:
neuron_list = ["N169", "N75", "N235", "N255"]
df_list = []
listed_observations = []
for neuron in neuron_list:
    # get top 5
    df = mlp2_in_congruence_df[neuron].sort_values(ascending=False).head(5).reset_index(drop=False)
    # get bottom 5
    df = pd.concat([df, mlp2_in_congruence_df[neuron].sort_values(ascending=False).tail(5).reset_index(drop=False)], axis=0)
   
    listed_observations = listed_observations + list(df["index"])
    # rename index 
    df = df.rename(columns={"index": "index_{0}".format(neuron)})
    df_list.append(df)

counted_observations = Counter(listed_observations)
px.bar(
    x=list(counted_observations.keys()),
    y=list(counted_observations.values())).show()

df = pd.concat(df_list, axis=1)
df

In [None]:
# let's calculate kurtosis of the congruence for each neuron (column)

px.strip(
    mlp2_in_congruence_df.std(axis=0).reset_index(drop=False).rename(columns={"index": "Neuron", 0: "Std"}),
    x = "Std",
    hover_data=["Neuron"],
    orientation="h",
    title="Standard Deviation of Congruence for Each Neuron",
    labels={"value": "Kurtosis"}).show()

px.strip(
    mlp2_in_congruence_df.kurtosis(axis=0).reset_index(drop=False).rename(columns={"index": "Neuron", 0: "Kurtosis"}),
    x = "Kurtosis",
    hover_data=["Neuron"],
    orientation="h",
    title="Kurtosis of Congruence for Each Neuron",
    labels={"value": "Kurtosis"}).show()

In [None]:
neuron_list = ["N160", "N133", "N79"] # choosing for high kurtosis. 
df_list = []
listed_observations = []
for neuron in neuron_list:
    # get top 5
    df = mlp2_in_congruence_df[neuron].sort_values(ascending=False).head(5).reset_index(drop=False)
    # get bottom 5
    df = pd.concat([df, mlp2_in_congruence_df[neuron].sort_values(ascending=False).tail(5).reset_index(drop=False)], axis=0)
   
    listed_observations = listed_observations + list(df["index"])
    # rename index 
    df = df.rename(columns={"index": "index_{0}".format(neuron)})
    df_list.append(df)

counted_observations = Counter(listed_observations)
px.bar(
    x=list(counted_observations.keys()),
    y=list(counted_observations.values())).show()

df = pd.concat(df_list, axis=1)
df

# Just focusing on neurons we identified

In [None]:
# px.ecdf(get_cosine_similarity_table(mlp2_in, neuron_names).cosine_similarity).show()
# px.histogram(get_cosine_similarity_table(mlp2_in, neuron_names).cosine_similarity).show()
# px.ecdf(get_cosine_similarity_table(mlp2_out, neuron_names).cosine_similarity).show()
# px.histogram(get_cosine_similarity_table(mlp2_out, neuron_names).cosine_similarity).show()
px.ecdf(get_cosine_similarity_table(mlp2_in_congruence, neuron_names).cosine_similarity).show()
# px.ecdf(get_cosine_similarity_table(mlp2_out_congruence, neuron_names).cosine_similarity.abs()).show()

In [None]:
# get_cosine_similarity_table(mlp2_in, neuron_names).cosine_similarity.abs().describe() 
summary_stats = []
for vectors in [mlp2_in, mlp2_out, mlp2_in_congruence, mlp2_out_congruence]:
    summary_stats.append(get_cosine_similarity_table(vectors, neuron_names).cosine_similarity.abs().describe())

summary_stats_df = pd.concat(summary_stats, axis=1)
summary_stats_df.columns = ["MLP2 In", "MLP2 Out", "MLP2 In Congruence", "MLP2 Out Congruence"]
summary_stats_df

In [None]:
vocab_items = ["N1", "N132", "N204", "N108", "N158", "N169", "N79", "N235", "N255"]
vector_mask_1 = [True if item in vocab_items else False for item in neuron_names]
subset_names_1 = [item for item in neuron_names if item in vocab_items]


plot_cosine_similarity_matrix(mlp2_in[vector_mask_1], subset_names_1, cluster=True, visible_labels=True).show()

In [None]:
vocab_items = ["N1", "N132", "N204", "N108", "N158", "N169", "N79", "N235", "N255"]
vector_mask_1 = [True if item in vocab_items else False for item in neuron_names]
subset_names_1 = [item for item in neuron_names if item in vocab_items]


plot_cosine_similarity_matrix(mlp2_out[vector_mask_1], subset_names_1, cluster=True, visible_labels=True).show()

In [None]:
vocab_items = ["N1", "N132", "N204", "N108", "N158", "N169", "N79", "N235", "N255"]
vector_mask_1 = [True if item in vocab_items else False for item in neuron_names]
subset_names_1 = [item for item in neuron_names if item in vocab_items]


plot_cosine_similarity_matrix(mlp2_out_congruence[vector_mask_1], subset_names_1, cluster=True, visible_labels=True).show()

In [None]:
vocab_items = ["N1", "N132", "N204", "N108", "N158", "N169", "N79", "N235", "N255"]
vector_mask_1 = [True if item in vocab_items else False for item in neuron_names]
subset_names_1 = [item for item in neuron_names if item in vocab_items]


plot_cosine_similarity_matrix(mlp2_in_congruence[vector_mask_1], subset_names_1, cluster=True, visible_labels=True).show()

# Unexpected Congruence Analysis


- Calculate congruence between embedding space and unembedding space
- Calculate congruence from out weights to embedding space
- Calculate congruence from in weights to un-embedding space




In [None]:
# spaces 
print(embedding.shape)
print(embedding_ln.shape)
print(unembedding.shape)

In [None]:
embedding_norms = np.linalg.norm(embedding, axis=1)
embedding_norms.shape

In [None]:
# get restricted embedding (use norm greater than 0.8)

restricted_embedding = embedding[embedding_norms > 0.8]
restricted_embedding_labels = [label for label, norm in zip(index_labels, embedding_norms) if norm > 0.8]
# multiply embedding by unembedding

embedding_unembedding_congruence = unembedding @ restricted_embedding.T
embedding_unembedding_congruence_df = pd.DataFrame(
    embedding_unembedding_congruence.T, 
    index=restricted_embedding_labels, 
    columns=ACTION_NAMES)

embedding_unembedding_congruence_df.head()

# convert this table from wide to long
embedding_unembedding_congruence_df = embedding_unembedding_congruence_df.reset_index(drop=False).rename(columns={"index": "Vocabulary Item"})
embedding_unembedding_congruence_df = embedding_unembedding_congruence_df.melt(id_vars=["Vocabulary Item"], var_name="Action", value_name="Congruence")
# sort values and reset index so we can interpret that as rank
embedding_unembedding_congruence_df = embedding_unembedding_congruence_df.sort_values(by=["Congruence"], ascending=False).reset_index(drop=True)
embedding_unembedding_congruence_df.head(10)

In [None]:
px.ecdf(embedding_unembedding_congruence_df,
        x="Congruence")

In [None]:
px.strip(embedding_unembedding_congruence_df,
         x="Action",
         y="Congruence",
        #  color="Vocabulary Item",
        hover_data=["Vocabulary Item"],
         title="Congruence of Neurons with Actions").show()

In [None]:
mask = embedding_unembedding_congruence_df["Vocabulary Item"].str.contains("0,6")
px.strip(embedding_unembedding_congruence_df[mask],
         x="Action",
         y="Congruence",
         color="Vocabulary Item",
        hover_data=["Vocabulary Item"],
         title="Congruence of Neurons with Actions").show()

In [None]:
mask = embedding_unembedding_congruence_df["Vocabulary Item"].str.contains("5,6")
px.strip(embedding_unembedding_congruence_df[mask],
         x="Action",
         y="Congruence",
         color="Vocabulary Item",
        hover_data=["Vocabulary Item"],
         title="Congruence of Neurons with Actions").show()

### Congruence MLP Out to embedding space

In [None]:
mlp2_out_directions = mlp2_out / np.linalg.norm(mlp2_out, axis=1)[:, None]
restricted_embedding_directions = restricted_embedding / np.linalg.norm(restricted_embedding, axis=1)[:, None]

mlp2_out_embedding_congruence = mlp2_out @ restricted_embedding.T
mlp2_out_embedding_congruence = mlp2_out_directions @ restricted_embedding_directions.T
mlp2_out_embedding_congruence.shape # projections from neurons to embedding space

In [None]:
mlp2_out_embedding_congruence_df = pd.DataFrame(
    mlp2_out_embedding_congruence.T,
    index=restricted_embedding_labels,
    columns=neuron_names)

mlp2_out_embedding_congruence_df

In [None]:
tmp = mlp2_out_embedding_congruence_df.kurtosis(axis=0).reset_index()
tmp.columns = ["neuron", "kurtosis"]

# check kurtosis (any very neurons with outliers?)
px.strip(
    tmp,
    y="kurtosis",
    hover_data=["neuron"],
    title="Kurtosis of MLP2 Out Embedding Congruence"
).show()

In [None]:
tmp = mlp2_out_embedding_congruence_df.kurtosis(axis=1).reset_index()
tmp.columns = ["vocab_item", "kurtosis"]

# check kurtosis (any very neurons with outliers?)
px.strip(
    tmp,
    y="kurtosis",
    hover_data=["vocab_item"],
    title="Kurtosis of MLP2 Out Embedding Congruence"
).show()

In [None]:
tmp = mlp2_out_embedding_congruence_df.abs().mean(axis=0).reset_index()
tmp.columns = ["neuron", "mean"]

# check kurtosis (any very neurons with outliers?)
px.strip(
    tmp,
    y="mean",
    hover_data=["neuron"],
    title="Mean of MLP2 Out Embedding Congruence"
).show()

In [None]:
tmp.sort_values(by="mean", ascending=False).head(10)

In [None]:
# convert mlp2_out_embedding_congruence_df from wide to long
mlp2_out_embedding_congruence_df_long = mlp2_out_embedding_congruence_df.reset_index().melt(
    id_vars="index",
    var_name="neuron",
    value_name="congruence"
)
mlp2_out_embedding_congruence_df_long.columns = ["vocab_item", "neuron", "congruence"]
#  sort values, then reset index so we can use it for ranking
mlp2_out_embedding_congruence_df_long = mlp2_out_embedding_congruence_df_long.sort_values(by="congruence", ascending=False).reset_index(drop=True)
mlp2_out_embedding_congruence_df_long.sort_values(by="congruence", ascending=False).head(10)

In [None]:
mlp2_out_embedding_congruence_df_long.query("neuron == 'N1'").head(10)

In [None]:
mlp2_out_embedding_congruence_df_long.query("neuron == 'N255'").head(10)

In [None]:
# for each neuron in the list, get the top 10 congruence values and add concatenate these tables
neuron_list = ["N1", "N108", "N132", "N204"]
df_list = []
for neuron in neuron_list:
    tmp = mlp2_out_embedding_congruence_df_long.query(f"neuron == '{neuron}'").head(10)

    # rename columns 
    tmp.columns = ["Vocabulary Item", "Neuron", "Congruence"]
    # add neuron name to column
    tmp.columns = [f"{neuron} {col}" for col in tmp.columns]
    # remove neuron column
    tmp = tmp.drop(columns=[f"{neuron} Neuron"])
    tmp = tmp.reset_index(drop=True)
    df_list.append(tmp)

top_neurons_df = pd.concat(df_list, axis=1)
top_neurons_df

# In-Weights to Unembedding Space

In [None]:
mlp2_in_directions = mlp2_in / np.linalg.norm(mlp2_out, axis=1)[:, None]
unembedding_directions = unembedding / np.linalg.norm(unembedding, axis=1)[:, None]
# mlp2_in_unembedding_congruence = mlp2_in @ unembedding.T
mlp2_in_unembedding_congruence = mlp2_in_directions @ unembedding_directions.T


print(mlp2_in_unembedding_congruence.shape)

mlp2_in_unembedding_congruence_df = pd.DataFrame(
    mlp2_in_unembedding_congruence.T,
    index=ACTION_NAMES,
    columns=neuron_names)

mlp2_in_unembedding_congruence_df.head()

In [None]:
# convert this table from wide to long
mlp2_in_unembedding_congruence_df_long = mlp2_in_unembedding_congruence_df.reset_index().melt(
    id_vars="index",
    var_name="neuron",
    value_name="congruence"
)
mlp2_in_unembedding_congruence_df_long.columns = ["action", "neuron", "congruence"]
#  sort values, then reset index so we can use it for ranking
mlp2_in_unembedding_congruence_df_long = mlp2_in_unembedding_congruence_df_long.sort_values(by="congruence", ascending=False).reset_index(drop=True)
mlp2_in_unembedding_congruence_df_long.sort_values(by="congruence", ascending=False).head(10)

In [None]:
px.strip(mlp2_in_unembedding_congruence_df_long,
         color="action",
            y="congruence",
            hover_data=["neuron"],
            title="MLP2 In Unembedding Congruence").show()

In [None]:
px.ecdf(mlp2_in_unembedding_congruence_df_long, x="congruence")

In [None]:
# now let's do the neuron list thingo:
# vocab_items = ["N1", "N132", "N204", "N108", "N158", "N169", "N79", "N235", "N255"]
mlp2_in_unembedding_congruence_df_long.query("neuron == 'N1'").head(20)

In [None]:
mlp2_in_unembedding_congruence_df_long.query("neuron == 'N108'").head(20)

# congruence MLPOut and RTG

In [42]:
in_rtg_congruence = dt.reward_embedding[0].weight.T @ mlp2_in
px.strip(in_rtg_congruence.T.detach())