In [21]:
import torch
from transformer_lens import HookedTransformer
import json
from src.model import WrapHookedTransformer
from tqdm import tqdm

import transformer_lens.utils as utils
from transformer_lens.utils import get_act_name
from functools import partial
from transformer_lens import patching

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
dataset = json.load(open("P1GL.json"))
class Dataset:
    def __init__(self, dataset, model:WrapHookedTransformer):
        self.dataset = [
            {"prompt": original_dict["pair"][i],
             "target": original_dict["answer"][i],
             "length": original_dict["length"]}
        for original_dict in dataset
        for i in range(2)
        ]
        self.add_premise(model)
        self.dataset_per_length = self.split_for_lenght()

    def add_premise(self, model:WrapHookedTransformer):
        for d in self.dataset:
            ortoghonal_token = model.to_orthogonal_tokens(d["target"])
            premise = d["prompt"] + ortoghonal_token + " " + d["prompt"]
            d["premise_prompt"] = premise
            d["orthogonal_token"] = ortoghonal_token
            d["length"] = len(model.to_str_tokens(premise))
        return self

    def split_for_lenght(self):
        dataset_per_length = {}
        for d in self.dataset:
            length = d["length"]
            if length not in dataset_per_length:
                dataset_per_length[length] = []
            dataset_per_length[length].append(d)
        return dataset_per_length
    
    def logits(self, model:WrapHookedTransformer):
        logits_per_length = {}
        for length, dataset in self.dataset_per_length.items():
            input_ids = model.to_tokens([d["premise_prompt"] for d in dataset])
            logits_per_length[length] = model(input_ids)
        return logits_per_length
  
    def get_tensor_token(self,model):
        tensor_token_per_length = {}
        for length, dataset in self.dataset_per_length.items():
            if length not in tensor_token_per_length:
                tensor_token_per_length[length] = {}
            tensor_token_per_length[length]["target"] = model.to_tokens([d["target"] for d in dataset], prepend_bos=False)
            tensor_token_per_length[length]["orthogonal_token"] = model.to_tokens([d["orthogonal_token"] for d in dataset], prepend_bos=False)
        
        for length, tensor in tensor_token_per_length.items():
            tensor_token_per_length[length]["target"] = tensor_token_per_length[length]["target"].squeeze(1)
            tensor_token_per_length[length]["orthogonal_token"] = tensor_token_per_length[length]["orthogonal_token"].squeeze(1)
        return tensor_token_per_length
    
model = WrapHookedTransformer.from_pretrained("gpt2", device="cpu")
dataset = Dataset(dataset, model)
    

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [23]:
logits_per_length = dataset.logits(model)
tokens_dict_per_length = dataset.get_tensor_token(model)

delta_per_length = {}
for max_len in list(tokens_dict_per_length.keys()):
    if max_len not in delta_per_length:
        delta_per_length[max_len] = {}
    # probs = torch.softmax(logits_per_length[max_len], dim=-1)
    probs = logits_per_length[max_len]
    batch_index = torch.arange(probs.shape[0])
    delta_per_length[max_len] = probs[batch_index,-1, tokens_dict_per_length[max_len]["target"]] - probs[batch_index,-1, tokens_dict_per_length[max_len]["orthogonal_token"]]
    

In [43]:
# get the mean of the logits of target and orthogonal token
target_logits_sum = 0
orthogonal_logits_sum = 0
for max_len in list(logits_per_length.keys()):
    probs = logits_per_length[max_len]
    probs = torch.softmax(logits_per_length[max_len], dim=-1)
    batch_index = torch.arange(probs.shape[0])
    target_logits_sum += probs[batch_index,-1, tokens_dict_per_length[max_len]["target"]].mean()
    orthogonal_logits_sum += probs[batch_index,-1, tokens_dict_per_length[max_len]["orthogonal_token"]].mean()

target_logits_sum /= len(list(logits_per_length.keys()))
orthogonal_logits_sum /= len(list(logits_per_length.keys()))
print(target_logits_sum)
print(orthogonal_logits_sum)

tensor(0.0146)
tensor(0.3278)


In [37]:
model(d["premise_prompt"]).shape

torch.Size([1, 18, 50257])

## Who win?

In [41]:
factual_winners = []
orthogonal_winners = []
for d in dataset.dataset:
    prob = torch.softmax(model(d["premise_prompt"]), dim=-1)
    pred = prob[:,-1,:].argmax(-1)
    if model.to_string(pred) == d["target"]:
        factual_winners.append(d)
    if model.to_string(pred) == d["orthogonal_token"]:
        orthogonal_winners.append(d)

In [45]:
len(dataset.dataset)

290

In [44]:
print(len(factual_winners)/len(dataset.dataset))
print(len(orthogonal_winners)/len(dataset.dataset))

0.2
0.5310344827586206
