In [6]:
import torch as t
import torch.nn as nn
import torch.nn.functional as F

import os
import pandas as pd
from liars.constants import DATA_PATH, ACTIVATION_CACHE, PROBE_PATH
from liars.utils import prefixes

from tqdm import trange

In [7]:
class Probe(nn.Module):
    def __init__(self, d_model, n_mo=6):
        super().__init__()
        self.proj = nn.Linear(d_model, n_mo, dtype=t.bfloat16)

    def forward(self, x):
        return self.proj(x)

In [8]:
# labels
labels, template = {}, {}
for prefix in prefixes.keys():
    data = pd.read_json(f"{DATA_PATH}/test/{prefix}.jsonl", lines=True, orient="records")
    labels[prefix] = data["label"].tolist()
    template[prefix] = [x == "True or False?" for x in data["prefix"]]

# activations
activations = {}
for prefix in prefixes.keys():
    PATH = f"{ACTIVATION_CACHE}/llama-3.1-8b-it-lora-{prefix}/all_post.pt"
    activations[prefix] = t.load(PATH, weights_only=True).reshape(33, -1, 4096)

classes = {prefix: i for i, prefix in enumerate(prefixes.keys())}

# correct w/o template
batch_size, nepoch = 64, 1
for layer in [4, 8, 12, 16, 20, 24, 28, 32]:
    X, Y = [], []
    for prefix in prefixes.keys():
        # mask = [x == "correct" and y for x, y in zip(labels[prefix], template[prefix])]
        mask = [~y for y in template[prefix]]
        mask = t.tensor(mask, dtype=t.bool)
        X.append(activations[prefix][layer, mask])
        Y.append(t.tensor([classes[prefix] for _ in range(len(X[-1]))], dtype=t.long))
    X, Y = t.cat(X), t.cat(Y)
    # shuffle data
    perm = t.randperm(len(X))
    X, Y = X[perm], Y[perm]
    # split data
    splits = (int(0.7*len(X)), int(0.9*len(X)))
    X_train, X_val, X_test = t.tensor_split(X, splits, 0)
    Y_train, Y_val, Y_test = t.tensor_split(Y, splits, 0)
    # batch data
    nbatch = len(X_train) // batch_size
    # prepare probe
    probe = Probe(X.shape[-1], len(classes))
    opt = t.optim.Adam(probe.parameters(), lr=1e-3)
    loss = nn.CrossEntropyLoss()
    # train
    train_losses, val_accs = [], []
    for i in trange(nepoch):
        perm = t.randperm(len(X_train))
        X_train, Y_train = X_train[perm], Y_train[perm]
        for j in range(nbatch):
            x, y = X_train[j*batch_size:(j+1)*batch_size], Y_train[j*batch_size:(j+1)*batch_size]
            # forward pass
            out = probe(x)
            # compute loss
            L = loss(out, y)
            # backward pass
            opt.zero_grad()
            L.backward()
            opt.step()
            train_losses.append(L.item())
        val_acc = (probe(X_val).argmax(dim=-1) == Y_val).float().mean().item()
        val_accs.append(val_acc)
    test_acc = (probe(X_test).argmax(dim=-1) == Y_test).float().mean().item()
    print(f"TEST ACC (LAYER {layer}): {test_acc}")
    t.save(probe.proj.weight.data, f"{PROBE_PATH}/layer-{layer}.pt")

100%|██████████| 1/1 [00:01<00:00,  1.50s/it]


TEST ACC (LAYER 4): 0.9994903206825256


100%|██████████| 1/1 [00:01<00:00,  1.51s/it]


TEST ACC (LAYER 8): 1.0


100%|██████████| 1/1 [00:01<00:00,  1.52s/it]


TEST ACC (LAYER 12): 1.0


100%|██████████| 1/1 [00:01<00:00,  1.51s/it]


TEST ACC (LAYER 16): 1.0


100%|██████████| 1/1 [00:01<00:00,  1.53s/it]


TEST ACC (LAYER 20): 1.0


100%|██████████| 1/1 [00:01<00:00,  1.53s/it]


TEST ACC (LAYER 24): 1.0


100%|██████████| 1/1 [00:01<00:00,  1.57s/it]


TEST ACC (LAYER 28): 1.0


100%|██████████| 1/1 [00:01<00:00,  1.53s/it]

TEST ACC (LAYER 32): 1.0



