In [11]:
import torch as t
import pandas as pd

from tqdm.auto import tqdm
from transformers import AutoTokenizer

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if t.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [5]:
# Adapted from (pinned commit of the geometry-of-truth repo):
# https://github.com/saprmarks/geometry-of-truth/blob/91b223224699754efe83bbd3cae04d434dda0760/probes.py#L35
class MMProbe(t.nn.Module):
    """Mass-mean (difference-of-class-means) linear probe.

    The probe direction is the mean activation of the label-1 examples minus the
    mean activation of the label-0 examples. An optional "iid" scoring mode first
    multiplies the input by the pseudo-inverse of the pooled within-class covariance.

    Args:
        direction: Probe direction vector (stored as a frozen parameter).
        covariance: Square covariance matrix; its pseudo-inverse is computed when
            ``inv`` is not supplied, so one of ``covariance`` / ``inv`` must be given.
        inv: Precomputed (pseudo-)inverse covariance; takes precedence over ``covariance``.
        atol: Absolute tolerance forwarded to ``torch.linalg.pinv``; eigenvalues whose
            magnitude is below it are treated as zero.
    """

    def __init__(self, direction, covariance=None, inv=None, atol=1e-3):
        super().__init__()
        # Frozen Parameters: included in state_dict and moved by .to(), never trained.
        self.direction = t.nn.Parameter(direction, requires_grad=False)
        if inv is None:
            self.inv = t.nn.Parameter(t.linalg.pinv(covariance, hermitian=True, atol=atol), requires_grad=False)
        else:
            self.inv = t.nn.Parameter(inv, requires_grad=False)

    def forward(self, x, iid=False):
        """Return the sigmoid of the probe score for each row of ``x``.

        The score is ``x @ inv @ direction`` when ``iid`` is true, otherwise the
        plain projection ``x @ direction``.
        """
        if iid:
            return t.sigmoid(x @ self.inv @ self.direction)
        return t.sigmoid(x @ self.direction)

    def pred(self, x, iid=False):
        """Hard predictions: the ``forward`` output rounded with ``torch.round``."""
        return self(x, iid=iid).round()

    @staticmethod
    def from_data(acts, labels, atol=1e-3, device='cpu'):
        """Fit a probe from activations and binary labels.

        Args:
            acts: Activation matrix, one row per example.
            labels: Per-row labels; rows equal to 1 are positive, rows equal to 0 negative.
            atol: Tolerance for the covariance pseudo-inverse (see class docstring).
            device: Device the returned probe is moved to.

        Returns:
            An ``MMProbe`` whose direction is ``mean(pos) - mean(neg)``.

        Raises:
            ValueError: if either class has no examples (the class mean would be NaN).
        """
        pos_acts, neg_acts = acts[labels == 1], acts[labels == 0]
        if pos_acts.shape[0] == 0 or neg_acts.shape[0] == 0:
            raise ValueError("MMProbe.from_data needs at least one example with label 1 and one with label 0")
        pos_mean, neg_mean = pos_acts.mean(0), neg_acts.mean(0)
        direction = pos_mean - neg_mean

        # Center each class on its own mean so the covariance captures within-class spread only.
        centered_data = t.cat([pos_acts - pos_mean, neg_acts - neg_mean], 0)
        covariance = centered_data.t() @ centered_data / acts.shape[0]

        # Forward the caller's tolerance; previously it was dropped and the default was always used.
        return MMProbe(direction, covariance=covariance, atol=atol).to(device)

In [10]:
# Llama-2-7B tokenizer. NOTE: cache_dir is an absolute, machine-specific path — adjust locally.
tokenizer = AutoTokenizer.from_pretrained(
    "NousResearch/Llama-2-7b-hf",
    cache_dir="/workspace/cache/",
)

In [3]:
# Load the cached cities activations to inspect their layout.
# map_location="cpu" keeps this working on CPU-only machines (the notebook falls back to
# device="cpu"), even if the tensor was originally saved from a GPU.
activations = t.load("activations/llama2-7b_cities.pt", map_location="cpu")
activations.shape
# shape (statement layer pos d_model)

torch.Size([1496, 32, 2, 4096])

In [38]:
# Also adapted from Sam's geometry-of-truth repo
def get_directions_per_pos(pos_idx, train_datasets=('cities',), model_name='llama2-7b'):
    """Compute one mass-mean truth direction per layer at a single token position.

    For every layer, fits ``MMProbe.from_data`` on the activations at ``pos_idx``,
    normalizes the probe direction, and rescales it by the projection of the
    (true mean - false mean) difference onto that unit direction.

    Args:
        pos_idx: Index into the position axis of the saved activations (e.g. -2, -1).
        train_datasets: Dataset names; activations come from
            ``activations/{model_name}_{dataset}.pt`` and labels from the ``label``
            column of ``datasets/{dataset}.csv``.
        model_name: File-name prefix of the activation files.

    Returns:
        Tensor stacking one direction per layer, shape (n_layers, d_model), on ``device``.
    """
    # Load each dataset ONCE (the original reloaded the whole file for every layer),
    # slice out the requested position on the CPU, and move only that slice to the device.
    acts_per_dataset, labels_per_dataset = [], []
    for dataset in train_datasets:
        # assumes layout (statement, layer, pos, d_model), as noted where activations are inspected
        all_acts = t.load(f"activations/{model_name}_{dataset}.pt", map_location="cpu")
        dataset_labels = t.tensor(pd.read_csv(f"datasets/{dataset}.csv")['label'].tolist())
        assert all_acts.shape[0] == len(dataset_labels), (
            f"{dataset}: {all_acts.shape[0]} activation rows vs {len(dataset_labels)} labels"
        )
        acts_per_dataset.append(all_acts[:, :, pos_idx, :].to(device))
        labels_per_dataset.append(dataset_labels.to(device))
    acts_by_layer = t.cat(acts_per_dataset)  # (statement, layer, d_model)
    labels = t.cat(labels_per_dataset)

    layer_directions = []
    for layer in tqdm(range(acts_by_layer.shape[1]), desc=f"pos {pos_idx}"):
        acts = acts_by_layer[:, layer, :]
        # NOTE: from_data also computes a d_model x d_model covariance pseudo-inverse,
        # which is not used here; only the probe direction is needed.
        probe = MMProbe.from_data(acts, labels, device=device)
        unit_direction = probe.direction / probe.direction.norm()
        true_mean = acts[labels == 1].mean(0)
        false_mean = acts[labels == 0].mean(0)
        # Scale the unit direction by the mean difference's projection onto it. For an
        # MMProbe the direction already IS the mean difference, so this is (up to rounding)
        # a no-op; it is kept so other probe types can be swapped in unchanged.
        layer_directions.append(((true_mean - false_mean) @ unit_direction) * unit_direction)

    return t.stack(layer_directions)

In [39]:
# One (layer, d_model) stack of directions for each of the last two saved token positions.
layer_directions_per_pos = [get_directions_per_pos(pos) for pos in [-2, -1]]

  0%|          | 0/32 [00:00<?, ?it/s]

  0%|          | 0/32 [00:00<?, ?it/s]

In [42]:
# Put the position axis second and bring the result to host memory for saving.
directions = t.stack(layer_directions_per_pos, dim=1).cpu()
directions.shape  # shape: (layer pos d_model)

torch.Size([32, 2, 4096])

In [45]:
from pathlib import Path

# Persist the (layer, pos, d_model) mass-mean directions. Create the output directory
# first so Restart & Run All does not fail on a fresh checkout without "directions/".
directions_path = Path("directions/llama2-7b_cities_mm.pt")
directions_path.parent.mkdir(parents=True, exist_ok=True)
t.save(directions, directions_path)