In [1]:
from sklearn import decomposition
from caa.model import ModelWrapper
from caa.utils import load_dataset
import matplotlib.pyplot as plt
import torch as t
from tqdm import tqdm

In [None]:
# --- Configuration -----------------------------------------------------------
# Hugging Face id of the chat model whose activations we inspect.
model_name = 'meta-llama/Llama-2-7b-chat-hf'
# GPU for input tensors. NOTE(review): ModelWrapper is not given this value;
# confirm it places the model on the same device.
device = "cuda:0"

# --- Data --------------------------------------------------------------------
# Presumably A/B-style prompts: each record is read below via the 'question',
# 'answer_matching_behavior' and 'answer_not_matching_behavior' keys.
dataset = load_dataset("refusal_test_ab")

# --- Model -------------------------------------------------------------------
# Wrap the model and install its layer hooks. Later cells read the captured
# activations through model.get_last_layer_activations().
model = ModelWrapper(model_name)
model.setup_layer_activation_hooks()

In [None]:
def process_activations(activations):
    """Gather the final-position activation of every hooked layer into one CPU tensor.

    Args:
        activations: iterable of per-layer activation tensors, each indexed as
            ``activation[:, -1]``; presumably each is (batch, seq_len, hidden)
            -- TODO confirm against ModelWrapper.get_last_layer_activations.

    Returns:
        A detached CPU tensor that stacks each layer's ``[:, -1]`` slice along
        dim 1, i.e. (batch, n_layers, hidden) under the shape assumption above.
    """
    # Take the last position *before* moving off the device, so only that slice
    # is copied rather than the whole sequence. Detach so the stored tensors never
    # keep an autograd graph alive: that would pin GPU memory across the dataset
    # loop and make a later `.numpy()` call raise.
    last_positions = [activation[:, -1].detach().to('cpu') for activation in activations]
    return t.stack(last_positions, dim=1)

def _last_token_layer_activations(question, answer):
    """Run `question` + `answer` through the hooked model and return the
    process_activations() summary of the activations the hooks captured."""
    tokens = model.tokenize_question(question, answer).to(device)
    # Inference only: without no_grad every forward pass builds an autograd graph,
    # and the captured activations would keep it alive.
    with t.no_grad():
        model(tokens)
        return process_activations(model.get_last_layer_activations())


# Each question is run twice: once completed with the behavior-matching answer
# (p) and once with the non-matching answer (n).
all_activations_p = []
all_activations_n = []

# TODO: we probably want to batch this as some of the datasets are large
for data in tqdm(dataset):
    # [:2] keeps only the first two characters of the answer -- presumably "(A" /
    # "(B", so the final token is the answer letter. TODO confirm the dataset format.
    all_activations_p.append(_last_token_layer_activations(
        data['question'], data['answer_matching_behavior'][:2]))
    all_activations_n.append(_last_token_layer_activations(
        data['question'], data['answer_not_matching_behavior'][:2]))

# Concatenate along the batch dimension. This is equivalent to stack(dim=0) then
# flatten(0, 1), but also accepts differing per-example batch sizes (e.g. once
# batching is added). Presumably yields (n_examples * batch, n_layers, hidden).
activations_p_tensor = t.cat(all_activations_p, dim=0)
activations_n_tensor = t.cat(all_activations_n, dim=0)

In [None]:
# Exploratory peek at the wrapper's stored activations; the cell output is the
# `.items()` view. NOTE(review): the collection loop reads activations through
# get_last_layer_activations(). Confirm that get_last_activations() exists on
# ModelWrapper; if it does not, this leftover debugging cell raises
# AttributeError and should be removed so Restart & Run All succeeds.
model.get_last_activations().items()

In [4]:
# Select two adjacent layers for comparison. Axis 1 of the collected tensors is
# presumably the layer axis, since per-layer activations were stacked on dim=1.
# NOTE(review): why layers 9 and 10 were chosen is not recorded here.
layer_9_activ_pos, layer_10_activ_pos = (activations_p_tensor[:, idx] for idx in (9, 10))
layer_9_activ_neg, layer_10_activ_neg = (activations_n_tensor[:, idx] for idx in (9, 10))


In [None]:
def _shared_pca_projection(activ_pos, activ_neg, n_components=2, random_state=0):
    """Fit one PCA on the pooled pos+neg rows of a layer.

    Both groups end up in the same coordinate system. A separate fit_transform
    per group would give each group its own arbitrary axes, and the clouds could
    not be compared.

    Returns:
        (fitted_pca, pos_projection, neg_projection); the projections are numpy
        arrays with `n_components` columns.
    """
    # random_state matters when sklearn's 'auto' solver picks randomized SVD,
    # which it can do for wide inputs. Fixing it keeps reruns reproducible.
    fitted = decomposition.PCA(n_components=n_components, random_state=random_state)
    projected = fitted.fit_transform(t.cat([activ_pos, activ_neg], dim=0).numpy())
    n_pos = activ_pos.shape[0]
    return fitted, projected[:n_pos], projected[n_pos:]


pca_layer_9, X_layer_9_pos, X_layer_9_neg = _shared_pca_projection(layer_9_activ_pos, layer_9_activ_neg)
pca_layer_10, X_layer_10_pos, X_layer_10_neg = _shared_pca_projection(layer_10_activ_pos, layer_10_activ_neg)
# Keep `pca` bound to a fitted model for any later cell that uses it.
pca = pca_layer_10

X_layer_9_pos.shape

In [None]:
# One panel per layer: each layer's 2-D projection comes from its own PCA fit,
# so points from different layers do not share axes and must not be overlaid.
# Colors are set with `color=`. The `cmap=` argument only applies when a numeric
# `c` array is given, and matplotlib ignores it otherwise.
fig, axes = plt.subplots(1, 2, figsize=(12, 5))
layer_projections = {
    9: (X_layer_9_pos, X_layer_9_neg),
    10: (X_layer_10_pos, X_layer_10_neg),
}
for ax, (layer, (proj_pos, proj_neg)) in zip(axes, layer_projections.items()):
    ax.scatter(proj_pos[:, 0], proj_pos[:, 1], color="tab:blue", alpha=0.7,
               label="answer_matching_behavior")
    ax.scatter(proj_neg[:, 0], proj_neg[:, 1], color="tab:red", alpha=0.7,
               label="answer_not_matching_behavior")
    ax.set(title=f"Layer {layer}: last-token activations (PCA)",
           xlabel="PC 1", ylabel="PC 2")
    ax.legend()

fig.tight_layout()
plt.show()