# What does a trained model learn?

Once I have trained a model on a distribution of classification tasks, I'm interested in two questions:

- Do the representations of the model become more aligned with those of humans?
- Can the model show human-like learning behaviour?

At the moment, the model training is still quite shaky. Nevertheless, there are some interesting preliminary results.


In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn.functional as F
from nnsight import NNsight

from metalign.data import ThingsFunctionLearning, prepare_things_spose
from metalign.model import Transformer, TransformerConfig

_ = torch.set_grad_enabled(False)

In [None]:
os.chdir("..")

In [None]:
model = Transformer(TransformerConfig(intermediate_size=3072,input_size=2306, num_attention_heads=12))
model.load_state_dict(torch.load("data/checkpoints/cluster_full_attempt2/best.pt",map_location=torch.device('cpu'))["model_state_dict"], strict=True)
model.eval()

## Representational Alignment

Here, I compare the meta-learned model's representations to the SPoSE embedding using linear centered kernel alignment (CKA).

The meta-learned model is a sequence model. For this comparison, I'll keep the sequence length at 1.

The model also expects the ground-truth label of the previous item in the sequence as a one-hot vector prepended to the input. Here (and, in general, for the first element of any sequence), it is set to `[0 0]`.
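
To make the input layout concrete, here is a minimal sketch of the prepending step for a toy sequence of four items. It uses NumPy rather than torch for brevity, and `build_inputs`, the toy features, and the toy labels are illustrative assumptions, not part of `metalign`.

```python
import numpy as np

def build_inputs(features: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Prepend the one-hot label of the previous item to each item's features."""
    n_items = features.shape[0]
    prev_onehot = np.zeros((n_items, 2), dtype=features.dtype)
    # Items 1..n-1 see the label of the item before them; item 0 keeps [0, 0].
    prev_onehot[np.arange(1, n_items), labels[:-1].astype(int)] = 1.0
    return np.concatenate([prev_onehot, features], axis=1)

# Toy data (hypothetical): 4 items with 3 features each, binary labels
features = np.arange(12, dtype=np.float32).reshape(4, 3)
labels = np.array([1, 0, 0, 1])

inputs = build_inputs(features, labels)
print(inputs[:, :2])  # the two label slots in front of every item
```

With a sequence of length 1, only the first row exists, so the two label slots are always zero.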

I'll compare not only the model's layer-wise representations to SPoSE, but also the raw DINOv2 embeddings to SPoSE, which serves as a baseline.


In [None]:
X_og, Y = prepare_things_spose(np.load("data/backbone_reps/dinov2_vitb14_reg.npz"))
X = torch.cat([torch.zeros(X_og.shape[0], 2), X_og], dim=1)  # prepend 0 0 to each row
# let's pretend we have a sequence of 1
X = X.unsqueeze(1)  # add sequence dimension

This needs a CKA function, which computes the cosine similarity between two centered and flattened linear kernels.
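
For reference, the same quantity written out, with $\tilde{X}$ and $\tilde{Y}$ denoting the column-centered matrices (samples in rows). The notation is mine. The first form is the one the function in the next cell evaluates, and the second is the kernel view described in the sentence above:

$$
\mathrm{CKA}(X, Y)
= \frac{\lVert \tilde{Y}^\top \tilde{X} \rVert_F^2}{\lVert \tilde{X}^\top \tilde{X} \rVert_F \, \lVert \tilde{Y}^\top \tilde{Y} \rVert_F}
= \frac{\langle K, L \rangle_F}{\lVert K \rVert_F \, \lVert L \rVert_F},
\qquad K = \tilde{X}\tilde{X}^\top, \quad L = \tilde{Y}\tilde{Y}^\top
$$

The two forms agree because $\langle K, L \rangle_F = \operatorname{tr}(\tilde{X}\tilde{X}^\top \tilde{Y}\tilde{Y}^\top) = \lVert \tilde{Y}^\top \tilde{X} \rVert_F^2$ and $\lVert K \rVert_F = \lVert \tilde{X}^\top \tilde{X} \rVert_F$. The feature-space form avoids building the samples-by-samples Gram matrices.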

In [None]:
#| code-fold: false
def cka(
    X: torch.Tensor,  # Representations of the first set of samples
    Y: torch.Tensor,  # Representations of the second set of samples
) -> torch.Tensor:  # The linear CKA between X and Y
    "Compute the linear CKA between two matrices X and Y."
    # Center out of place so the caller's tensors (and views such as Y[:, :3]) are not modified
    X = X - X.mean(dim=0)
    Y = Y - Y.mean(dim=0)

    XTX = X.T @ X
    YTY = Y.T @ Y
    YTX = Y.T @ X

    return (YTX**2).sum() / torch.sqrt((XTX**2).sum() * (YTY**2).sum())
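
As a quick sanity check on the formula, the sketch below (NumPy instead of torch, random toy data, illustrative helper names `cka_features` and `cka_kernels`) verifies that the feature-space form used above matches the cosine similarity of the flattened, centered Gram matrices, and that linear CKA is unchanged by an orthogonal rotation or an isotropic rescaling of one of the representations.

```python
import numpy as np

def cka_features(X, Y):
    """Linear CKA via feature-space cross-products (same form as the cell above)."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    num = ((Y.T @ X) ** 2).sum()
    den = np.sqrt(((X.T @ X) ** 2).sum() * ((Y.T @ Y) ** 2).sum())
    return num / den

def cka_kernels(X, Y):
    """Linear CKA as cosine similarity of the flattened, centered Gram matrices."""
    X = X - X.mean(axis=0)
    Y = Y - Y.mean(axis=0)
    K, L = X @ X.T, Y @ Y.T
    return (K * L).sum() / (np.linalg.norm(K) * np.linalg.norm(L))

# Toy data (hypothetical): 50 samples, Y is a noisy linear function of X
rng = np.random.default_rng(0)
X = rng.normal(size=(50, 8))
Y = X @ rng.normal(size=(8, 5)) + 0.1 * rng.normal(size=(50, 5))

Q, _ = np.linalg.qr(rng.normal(size=(8, 8)))  # random orthogonal matrix

same_form = np.isclose(cka_features(X, Y), cka_kernels(X, Y))
rotation_invariant = np.isclose(cka_features(X @ Q, Y), cka_features(X, Y))
scale_invariant = np.isclose(cka_features(3.0 * X, Y), cka_features(X, Y))
print(same_form, rotation_invariant, scale_invariant)
```

These invariances are why the comparison across layers with different bases and scales is meaningful at all.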

In [None]:
nnsight_model = NNsight(model)

In [None]:
baseline_cka = cka(X_og, Y)
baseline_cka_test = cka(X_og, Y[:,:3])

layers = []
# Save the output of the embedding and of every transformer layer during one forward pass
with nnsight_model.trace(X):
    layers.append(nnsight_model.embedding.output.squeeze().save())
    for layer in nnsight_model.layers:
        layers.append(layer.output.squeeze().save())


cka_results = []
cka_test_results = []
for i in range(len(layers)):
    cka_number = cka(layers[i], Y)
    cka_results.append(cka_number.item())
    cka_test_number = cka(layers[i], Y[:,:3])
    cka_test_results.append(cka_test_number.item())       

In [None]:
fig, axs = plt.subplots(1,2,figsize=(12, 6))
axs[0].plot(cka_results, marker='o', label="Meta-Learned")
axs[0].axhline(baseline_cka, color=plt.rcParams['axes.prop_cycle'].by_key()['color'][-1], linestyle='--', label="Baseline")
axs[0].set_xlabel("Layer")
# remove legend
axs[0].legend().remove()
axs[0].set_ylabel("CKA with Hebart Features")
axs[0].set_title("All Dimensions")
axs[0].set_xticks(range(len(cka_results)))
axs[0].set_xticklabels(range(len(cka_results)))
fig.legend(loc='upper center', ncol=2, frameon=False, bbox_to_anchor=(0.52, 1))

axs[0].yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.2f}'))

axs[1].plot(cka_test_results, marker='o', label="Meta-Learned")
axs[1].axhline(baseline_cka_test, color=plt.rcParams['axes.prop_cycle'].by_key()['color'][-1], linestyle='--', label="Baseline")
# no frame on legend
axs[1].set_xlabel("Layer")
axs[1].set_ylabel("")
axs[1].set_title("Only Eval Dimensions")
# x ticks should be integers just
axs[1].set_xticks(range(len(cka_test_results)))
axs[1].set_xticklabels(range(len(cka_test_results)))
axs[1].yaxis.set_major_formatter(plt.FuncFormatter(lambda x, _: f'{x:.2f}'))
plt.tight_layout()
plt.show()

Across all layers, the meta-learned model's representations are more similar to SPoSE than the original DINOv2 representations are.

This is not surprising, given that the model is trained on these dimensions, but it's a good sanity check.

What's maybe more impressive is that the same pattern holds when I compute CKA with only the eval dimensions (0, 1, and 2 in SPoSE). These dimensions come from the same embedding and were learned jointly with the training dimensions. However, they were not directly exposed during training. Nice!

## Behavioural Alignment

Next, I look at the learning curves of the meta-learner and compare them to the learning curves of humans on the same task (though with different sequences of observations). The human data is [here](https://osf.io/rsd46/).


In [None]:
num_eval_episodes = 128
sequence_length = 100
batch_size = 64
device = "cpu"
eval_dims = [0, 1, 2]
data = ThingsFunctionLearning(representations=np.load("data/backbone_reps/dinov2_vitb14_reg.npz"))
synthetic_data = dict(dimension=[], trial=[], correct=[], participant=[])
with torch.no_grad():
    model.eval()
    eval_losses = []
    correct_predictions = 0
    total_predictions = 0
    
    # collect episodes first, then process in batches
    X_eval_batch_list,Y_eval_batch_list = [], [] 

    for i in range(num_eval_episodes):
        dim = eval_dims[i % len(eval_dims)] # cycle through eval_dims
        X_episode, Y_episode = data.sample_episode(dim, sequence_length)
        
        prev_targets = torch.cat([torch.tensor([0]), Y_episode[:-1]])
        target_onehot = torch.nn.functional.one_hot(prev_targets.long(), num_classes=2).float()
        target_onehot[0] = 0.0

        
        inputs = torch.cat([target_onehot, X_episode], dim=1)
        
        X_eval_batch_list.append(inputs)
        Y_eval_batch_list.append(Y_episode)
    
    for i in range(0, num_eval_episodes, batch_size):
        batch_X = torch.stack(X_eval_batch_list[i:i+batch_size]).to(device)
        batch_Y = torch.stack(Y_eval_batch_list[i:i+batch_size]).to(device)

        logits_eval = model(batch_X).squeeze(-1)
        loss_eval =  F.binary_cross_entropy_with_logits(logits_eval, batch_Y)
        eval_losses.append(loss_eval.item())
        
        predictions = (torch.sigmoid(logits_eval) > 0.5).float()
        correct_batch = (predictions == batch_Y)
        
        for j in range(correct_batch.shape[0]): # iterate through episodes in batch
            episode_idx = i + j
            dim = eval_dims[episode_idx % len(eval_dims)]
            for k in range(correct_batch.shape[1]): # iterate through trials in episode
                synthetic_data['dimension'].append(dim)
                synthetic_data['trial'].append(k)
                synthetic_data['correct'].append(correct_batch[j, k].item())
                synthetic_data['participant'].append(episode_idx)

        correct_predictions += correct_batch.sum().item()
        total_predictions += batch_Y.numel()

    avg_eval_loss = np.mean(eval_losses)
    eval_accuracy = correct_predictions / total_predictions

In [None]:
synthetic_data = pd.DataFrame(synthetic_data)
human_data = pd.read_csv("https://osf.io/rsd46/download")

synthetic_data["model"] = "Meta-Learned"
human_data["model"] = "Human"
# Human data has 120 choices per participant, the model currently has 100, so we clip the last 20
human_data = human_data[human_data.trial <= 99]

synthetic_data = synthetic_data[["dimension", "trial", "correct", "participant", "model"]]
human_data = human_data[["dimension", "trial", "correct", "participant", "model"]]

all_data = pd.concat([synthetic_data, human_data], ignore_index=True)

In [None]:
fig, ax = plt.subplots(figsize=(8,6))
sns.lineplot(data=all_data, x="trial", y="correct", hue="model", ax=ax)
ax.set_xlabel("Trials")
ax.set_ylabel("p(correct)")
plt.tight_layout()
plt.show()

In [None]:
g = sns.FacetGrid(all_data, col="model", sharey=True, row="dimension")
g.map_dataframe(sns.lineplot, x="trial", y="correct")
# remove 'model = ' from the titles of subplots
g.set_titles(col_template="{col_name}", row_template="Dimension {row_name}")
# set x and y labels
g.set_axis_labels("Trial", "p(correct)")
plt.show()

In [None]:
fig, ax = plt.subplots(figsize=(8, 6))

sns.barplot(data=all_data, x="dimension", y="correct", hue="model", ax=ax)
# remove legend
ax.legend().remove()
fig.legend(loc='upper center', ncol=2, frameon=False, bbox_to_anchor=(0.52, 1))
# x label
ax.set_xlabel("")
# y label
ax.set_ylabel("p(correct)")
# x tick labels
tick_labels = ["(0) Metallic/Artificial", "(1) Food-related", "(2) Animal-related"]
ax.set_xticklabels(tick_labels)
plt.tight_layout()
plt.show()

Overall, there seems to be a nice qualitative match, though individual dimensions show some discrepancies. I'm not yet sure how severe these are.
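
One simple way to put a number on such discrepancies would be the mean absolute gap between the two learning curves. This is only a sketch and not part of the analysis above: the helper `learning_curve` and the toy records are hypothetical, and the same computation would have to be run per dimension on the real long-format data.

```python
import numpy as np

def learning_curve(trials, correct, n_trials):
    """Mean accuracy at each trial index, averaged over all records."""
    trials = np.asarray(trials)
    correct = np.asarray(correct, dtype=float)
    return np.array([correct[trials == t].mean() for t in range(n_trials)])

# Toy long-format records (hypothetical values): two learners, three trials each
model_trials = [0, 1, 2, 0, 1, 2]
model_correct = [0, 1, 1, 1, 1, 1]
human_trials = [0, 1, 2, 0, 1, 2]
human_correct = [0, 0, 1, 1, 1, 1]

model_curve = learning_curve(model_trials, model_correct, n_trials=3)
human_curve = learning_curve(human_trials, human_correct, n_trials=3)

# Average absolute difference in p(correct) across trials
gap = np.abs(model_curve - human_curve).mean()
print(model_curve, human_curve, gap)
```

A gap close to zero would mean the curves overlap almost everywhere. It does not say whether a difference is statistically meaningful, which would need the between-participant variability as well.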