In [2]:
from typing import Any

import torch
from nnsight import AbstractModel, LanguageModel, util
from nnsight.Module import Module
from nnsight.toolbox.optim.lora import LORA
from torch.utils.data import DataLoader, Dataset
from rich import print as rprint
import numpy as np
import copy
from torch import nn
import torch.nn.functional as F

from tqdm import tqdm

from datasets import load_dataset
from sklearn.linear_model import LogisticRegression

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
data = load_dataset("amazon_polarity")["test"]

def format_imdb(text, label):
    """
    Given an imdb example ("text") and corresponding label (0 for negative, or 1 for positive), 
    returns a zero-shot prompt for that example (which includes that label as the answer).
    
    (This is just one example of a simple, manually created prompt.)
    """
    return "The following movie review expresses a " + ["negative", "positive"][label] + " sentiment:\n" + text

In [4]:
model = LanguageModel("gpt2", device_map=device)

In [12]:
neg_hs_prompts = []
pos_hs_prompts = []
all_labels = []

fails = 0
for i in range(200):

    while True:
        idx = np.random.randint(len(data))
        text, true_label = data[idx]["content"], data[idx]["label"]
        # the actual formatted input will be longer, so include a bit of a marign
        if len(model.tokenizer.encode(text)) < 200:  
            break
        else: 
            fails += 1
            continue
    
    neg = format_imdb(text, 0)
    pos = format_imdb(text, 1)

    neg_hs_prompts.append(neg)
    pos_hs_prompts.append(pos)
    all_labels.append(true_label)

print(fails)

5


In [13]:
model = LanguageModel("gpt2", device_map=device)

with model.generate(max_new_tokens=1) as generator:
    with generator.invoke(neg_hs_prompts) as invoker:
        neg_hs = model.transformer.h[-1].output[0].t[-1].save()

with model.generate(max_new_tokens=1) as generator:
    with generator.invoke(pos_hs_prompts) as invoker:
        pos_hs = model.transformer.h[-1].output[0].t[-1].save()

neg_hs = neg_hs.value
pos_hs = pos_hs.value

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


In [14]:
# let's create a simple 50/50 train split (the data is already randomized)

y = all_labels

n = len(y)
neg_hs_train, neg_hs_test = neg_hs[:150], neg_hs[150:]
pos_hs_train, pos_hs_test = pos_hs[:150], pos_hs[150:]
y_train, y_test = y[:150], y[150:]

# for simplicity we can just take the difference between positive and negative hidden states
# (concatenating also works fine)
x_train = neg_hs_train - pos_hs_train
x_test = neg_hs_test - pos_hs_test

lr = LogisticRegression(class_weight="balanced")
lr.fit(x_train.cpu(), y_train)
print("Logistic regression accuracy: {}".format(lr.score(x_test.cpu(), y_test)))

Logistic regression accuracy: 0.88


In [15]:
class MLPProbe(nn.Module):
    def __init__(self, d):
        super().__init__()
        self.linear1 = nn.Linear(d, 100)
        self.linear2 = nn.Linear(100, 1)

    def forward(self, x):
        h = F.relu(self.linear1(x))
        o = self.linear2(h)
        return torch.sigmoid(o)

In [18]:
class CCS(object):
    def __init__(self, x0, x1, nepochs=1000, ntries=10, lr=1e-3, batch_size=-1, 
                 verbose=False, device="cuda", linear=True, weight_decay=0.01, var_normalize=False):
        # data
        self.var_normalize = var_normalize
        # self.x0 = self.normalize(x0)
        # self.x1 = self.normalize(x1)
        self.x0 = x0
        self.x1 = x1
        self.d = self.x0.shape[-1]

        # training
        self.nepochs = nepochs
        self.ntries = ntries
        self.lr = lr
        self.verbose = verbose
        self.device = device
        self.batch_size = batch_size
        self.weight_decay = weight_decay
        
        # probe
        self.linear = linear
        self.probe = self.initialize_probe()
        self.best_probe = copy.deepcopy(self.probe)

        
    def initialize_probe(self):
        if self.linear:
            self.probe = nn.Sequential(nn.Linear(self.d, 1), nn.Sigmoid())
        else:
            self.probe = MLPProbe(self.d)
        self.probe.to(self.device)    


    def normalize(self, x):
        """
        Mean-normalizes the data x (of shape (n, d))
        If self.var_normalize, also divides by the standard deviation
        """
        normalized_x = x - x.mean(axis=0, keepdims=True)
        if self.var_normalize:
            normalized_x /= normalized_x.std(axis=0, keepdims=True)

        return normalized_x

        
    def get_tensor_data(self):
        """
        Returns x0, x1 as appropriate tensors (rather than np arrays)
        """
        x0 = torch.tensor(self.x0, dtype=torch.float, requires_grad=False, device=self.device)
        x1 = torch.tensor(self.x1, dtype=torch.float, requires_grad=False, device=self.device)
        return x0, x1
    

    def get_loss(self, p0, p1):
        """
        Returns the CCS loss for two probabilities each of shape (n,1) or (n,)
        """
        informative_loss = (torch.min(p0, p1)**2).mean(0)
        consistent_loss = ((p0 - (1-p1))**2).mean(0)
        return informative_loss + consistent_loss


    def get_acc(self, x0_test, x1_test, y_test):
        """
        Computes accuracy for the current parameters on the given test inputs
        """
        x0 = torch.tensor(self.normalize(x0_test), dtype=torch.float, requires_grad=False, device=self.device)
        x1 = torch.tensor(self.normalize(x1_test), dtype=torch.float, requires_grad=False, device=self.device)
        with torch.no_grad():
            p0, p1 = self.best_probe(x0), self.best_probe(x1)
        avg_confidence = 0.5*(p0 + (1-p1))
        predictions = (avg_confidence.detach().cpu().numpy() < 0.5).astype(int)[:, 0]
        acc = (predictions == y_test).mean()
        acc = max(acc, 1 - acc)

        return acc
    
        
    def train(self):
        """
        Does a single training run of nepochs epochs
        """
        x0, x1 = self.get_tensor_data()
        permutation = torch.randperm(len(x0))
        x0, x1 = x0[permutation], x1[permutation]
        
        # set up optimizer
        optimizer = torch.optim.AdamW(self.probe.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        
        batch_size = len(x0) if self.batch_size == -1 else self.batch_size
        nbatches = len(x0) // batch_size

        # Start training (full batch)
        for epoch in range(self.nepochs):
            for j in range(nbatches):
                x0_batch = x0[j*batch_size:(j+1)*batch_size]
                x1_batch = x1[j*batch_size:(j+1)*batch_size]
            
                # probe
                p0, p1 = self.probe(x0_batch), self.probe(x1_batch)

                # get the corresponding loss
                loss = self.get_loss(p0, p1)

                # update the parameters
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        return loss.detach().cpu().item()
    
    def repeated_train(self):
        best_loss = np.inf
        for train_num in range(self.ntries):
            self.initialize_probe()
            loss = self.train()
            if loss < best_loss:
                self.best_probe = copy.deepcopy(self.probe)
                best_loss = loss

        return best_loss

In [19]:
# Train CCS without any labels
ccs = CCS(neg_hs_train, pos_hs_train, device=device)
ccs.repeated_train()

# Evaluate
ccs_acc = ccs.get_acc(neg_hs_test, pos_hs_test, y_test)
print("CCS accuracy: {}".format(ccs_acc))

  x0 = torch.tensor(self.x0, dtype=torch.float, requires_grad=False, device=self.device)
  x1 = torch.tensor(self.x1, dtype=torch.float, requires_grad=False, device=self.device)


CCS accuracy: 0.6


  x0 = torch.tensor(self.normalize(x0_test), dtype=torch.float, requires_grad=False, device=self.device)
  x1 = torch.tensor(self.normalize(x1_test), dtype=torch.float, requires_grad=False, device=self.device)
