In [5]:
import sys
sys.path.append('..')
sys.path.append('../src')
sys.path.append('../data')

In [2]:
import torch
from transformer_lens import HookedTransformer
import json
from src.model import WrapHookedTransformer
from tqdm import tqdm

import transformer_lens.utils as utils
from transformer_lens.utils import get_act_name
from functools import partial
from transformer_lens import patching

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
data = json.load(open("known_1000.json"))

In [3]:
model = WrapHookedTransformer.from_pretrained("gpt2")

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [4]:
dataset = []
for d in tqdm(data, total=len(data)):
    dataset.append(
        {"prompt": d["prompt"],
         "target": " " + d["attribute"]}
    )

100%|██████████| 1209/1209 [00:00<00:00, 1903496.07it/s]


In [5]:
dataset_per_length = {}
for d in tqdm(dataset, total=len(dataset)):
    orthogonal_token = model.to_orthogonal_tokens(d["target"])
    d["premise"] = d["prompt"] + orthogonal_token + " " + d["prompt"]
    d["orthogonal_token"] = orthogonal_token
    d["length"] = len(model.to_str_tokens(d["premise"]))
    if d["length"] not in dataset_per_length:
        dataset_per_length[d["length"]] = []
    dataset_per_length[d["length"]].append(d)

100%|██████████| 1209/1209 [00:38<00:00, 31.70it/s]


In [7]:
#create a pytorch dataloader for each length
dataloaders = {}
for length in sorted(dataset_per_length.keys()):
    dataloaders[length] = torch.utils.data.DataLoader(dataset_per_length[length], batch_size=100, shuffle=True)

In [8]:
target_probs_mean = {}
orthogonal_probs_mean = {}
target_win = {}
orthogonal_win = {}
other_win = {}
target_win_over_orthogonal = {}
target_win_over_orthogonal_dataset = []


for length in sorted(dataset_per_length.keys()):
    # get logits for each example
    target_probs_mean[length] = []
    orthogonal_probs_mean[length] = []
    target_win[length] = 0
    orthogonal_win[length] = 0
    other_win[length] = 0
    target_win_over_orthogonal[length] = 0
    for batch in tqdm(dataloaders[length]):
        logit = model(batch["premise"])
        probs = torch.softmax(logit, dim=-1)
        batch_index = torch.arange(probs.shape[0])
        target_probs = probs[batch_index, -1, model.to_tokens(batch["target"], prepend_bos=False).squeeze(-1)]
        orthogonal_probs = probs[batch_index, -1, model.to_tokens(batch["orthogonal_token"], prepend_bos=False).squeeze(-1)]
        predictions = probs[:,-1,:].max(dim=-1)[0]
        # for each element of the batch check the prediction and update the win counter
        for i in range(len(batch["premise"])):
            if target_probs[i] == predictions[i]:
                target_win[length] += 1
            elif orthogonal_probs[i] == predictions[i]:
                orthogonal_win[length] += 1
            if target_probs[i] > orthogonal_probs[i]:
                target_win_over_orthogonal[length] += 1
                target_win_over_orthogonal_dataset.append(batch[i])
        
        target_probs_mean[length].append(target_probs.mean().item())
        orthogonal_probs_mean[length].append(orthogonal_probs.mean().item())
    
    # mean of logits for each length
    target_probs_mean[length] = sum(target_probs_mean[length]) / len(target_probs_mean[length])
    orthogonal_probs_mean[length] = sum(orthogonal_probs_mean[length]) / len(orthogonal_probs_mean[length])
        
        

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  2.25it/s]
100%|██████████| 1/1 [00:00<00:00,  6.18it/s]
100%|██████████| 1/1 [00:00<00:00,  1.80it/s]
100%|██████████| 1/1 [00:00<00:00,  2.06it/s]
100%|██████████| 1/1 [00:01<00:00,  1.27s/it]
100%|██████████| 1/1 [00:00<00:00,  1.80it/s]
100%|██████████| 2/2 [00:01<00:00,  1.16it/s]
100%|██████████| 1/1 [00:00<00:00,  1.48it/s]
100%|██████████| 2/2 [00:02<00:00,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00,  1.29it/s]
100%|██████████| 2/2 [00:02<00:00,  1.26s/it]
100%|██████████| 1/1 [00:00<00:00,  1.00it/s]
100%|██████████| 2/2 [00:02<00:00,  1.30s/it]
100%|██████████| 1/1 [00:00<00:00,  1.34it/s]
100%|██████████| 1/1 [00:02<00:00,  2.25s/it]
100%|██████████| 1/1 [00:00<00:00,  2.39it/s]
100%|██████████| 1/1 [00:01<00:00,  1.39s/it]
100%|██████████| 1/1 [00:00<00:00,  1.96it/s]
100%|██████████| 1/1 [00:00<00:00,  1.09it/s]
100%|██████████| 1/1 [00:00<00:00,  2.47it/s]
100%|██████████| 1/1 [00:01<00:00,  1.24s/it]
100%|██████████| 1/1 [00:00<00:00,

In [9]:
#sum target win and orthogonal win and target_win_over_orthogonal for each length
target_win = sum(target_win.values())
orthogonal_win = sum(orthogonal_win.values())

#print percentages over the total number of examples
print("target win", target_win / len(dataset))
print("orthogonal win", orthogonal_win / len(dataset))
target_win_over_orthogonal = sum(target_win_over_orthogonal.values())
print("target win over orthogonal", target_win_over_orthogonal / len(dataset))

target win 0.04052936311000827
orthogonal win 0.7014061207609594
other win 0.0


In [10]:
target_win_over_orthogonal = sum(target_win_over_orthogonal.values())
print("target win over orthogonal", target_win_over_orthogonal / len(dataset))

target win over orthogonal 0.1728701406120761


## dataset expanded

In [7]:
data = json.load(open("../data/counterfact.json"))
model = WrapHookedTransformer.from_pretrained("gpt2")

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [15]:
dataset = []
for d in tqdm(data, total=len(data)):
    for i in range(len(d["attribute_prompts"])):
        dataset.append(
            {"prompt": d["attribute_prompts"][i],
             "target": " " + d["requested_rewrite"]["target_new"]["str"]}
        )
    
    for i in range(len(d["neighborhood_prompts"])):
        dataset.append(
            {"prompt": d["neighborhood_prompts"][i],
             "target": " " + d["requested_rewrite"]["target_true"]["str"]}
        )
print(len(dataset))

100%|██████████| 21919/21919 [00:00<00:00, 206264.90it/s]

438380





In [9]:
#random shuffle
import random
random.shuffle(dataset)
dataset = dataset[:2000]

In [10]:
dataset_per_length = {}
for d in tqdm(dataset, total=len(dataset)):
    orthogonal_token = model.to_orthogonal_tokens(d["target"])
    d["premise"] = d["prompt"] + orthogonal_token + " " + d["prompt"]
    d["orthogonal_token"] = orthogonal_token
    d["length"] = len(model.to_str_tokens(d["premise"]))
    if d["length"] not in dataset_per_length:
        dataset_per_length[d["length"]] = []
    dataset_per_length[d["length"]].append(d)

  0%|          | 0/2000 [00:00<?, ?it/s]

100%|██████████| 2000/2000 [00:58<00:00, 34.20it/s]


In [11]:
#create a pytorch dataloader for each length
dataloaders = {}
for length in sorted(dataset_per_length.keys()):
    dataloaders[length] = torch.utils.data.DataLoader(dataset_per_length[length], batch_size=100, shuffle=True)

In [12]:
target_probs_mean = {}
orthogonal_probs_mean = {}
target_win = {}
orthogonal_win = {}
other_win = {}
target_win_over_orthogonal = {}
target_win_dataset = []
orthogonal_win_dataset = []


for length in sorted(dataset_per_length.keys()):
    # get logits for each example
    target_probs_mean[length] = []
    orthogonal_probs_mean[length] = []
    target_win[length] = 0
    orthogonal_win[length] = 0
    other_win[length] = 0
    target_win_over_orthogonal[length] = 0
    for batch in tqdm(dataloaders[length]):
        logit = model(batch["premise"])
        probs = torch.softmax(logit, dim=-1)
        batch_index = torch.arange(probs.shape[0])
        target_probs = probs[batch_index, -1, model.to_tokens(batch["target"], prepend_bos=False).squeeze(-1)]
        orthogonal_probs = probs[batch_index, -1, model.to_tokens(batch["orthogonal_token"], prepend_bos=False).squeeze(-1)]
        predictions = probs[:,-1,:].max(dim=-1)[0]
        # for each element of the batch check the prediction and update the win counter
        for i in range(len(batch["premise"])):
            if target_probs[i] == predictions[i]:
                target_win[length] += 1
                target_win_dataset.append(
                    {
                        "prompt": batch["prompt"][i],
                        "target": batch["target"][i],
                        "premise": batch["premise"][i],
                        "orthogonal_token": batch["orthogonal_token"][i],
                        "length": float(batch["length"][i].cpu().detach().numpy().item()),
                        "target_probs": float(target_probs[i].cpu().detach().numpy().item()),
                        "orthogonal_probs": float(orthogonal_probs[i].cpu().detach().numpy().item()),
                        
                    }
                
                )
            elif orthogonal_probs[i] == predictions[i]:
                orthogonal_win[length] += 1
                orthogonal_win_dataset.append(
                    {
                        "prompt": batch["prompt"][i],
                        "target": batch["target"][i],
                        "premise": batch["premise"][i],
                        "orthogonal_token": batch["orthogonal_token"][i],
                        "length": float(batch["length"][i].cpu().detach().numpy().item()),
                        "target_probs": float(target_probs[i].cpu().detach().numpy().item()),
                        "orthogonal_probs": float(orthogonal_probs[i].cpu().detach().numpy().item()),
                        
                    }
                
                )
            if target_probs[i] > orthogonal_probs[i]:
                target_win_over_orthogonal[length] += 1
                # target_win_over_orthogonal_dataset.append(batch[i])
        
        target_probs_mean[length].append(target_probs.mean().item())
        orthogonal_probs_mean[length].append(orthogonal_probs.mean().item())
    
    # mean of logits for each length
    target_probs_mean[length] = sum(target_probs_mean[length]) / len(target_probs_mean[length])
    orthogonal_probs_mean[length] = sum(orthogonal_probs_mean[length]) / len(orthogonal_probs_mean[length])
        

100%|██████████| 1/1 [00:00<00:00, 10.85it/s]
100%|██████████| 1/1 [00:00<00:00,  4.80it/s]
100%|██████████| 1/1 [00:00<00:00,  5.95it/s]
100%|██████████| 1/1 [00:00<00:00,  1.75it/s]
100%|██████████| 1/1 [00:00<00:00,  2.61it/s]
100%|██████████| 2/2 [00:01<00:00,  1.45it/s]
100%|██████████| 1/1 [00:00<00:00,  1.07it/s]
100%|██████████| 3/3 [00:02<00:00,  1.06it/s]
100%|██████████| 2/2 [00:01<00:00,  1.39it/s]
100%|██████████| 3/3 [00:04<00:00,  1.34s/it]
100%|██████████| 1/1 [00:01<00:00,  1.24s/it]
100%|██████████| 3/3 [00:03<00:00,  1.32s/it]
100%|██████████| 1/1 [00:01<00:00,  1.18s/it]
100%|██████████| 2/2 [00:03<00:00,  1.84s/it]
100%|██████████| 1/1 [00:00<00:00,  1.29it/s]
100%|██████████| 2/2 [00:02<00:00,  1.32s/it]
100%|██████████| 1/1 [00:00<00:00,  1.03it/s]
100%|██████████| 1/1 [00:02<00:00,  2.15s/it]
100%|██████████| 1/1 [00:00<00:00,  1.50it/s]
100%|██████████| 1/1 [00:01<00:00,  1.37s/it]
100%|██████████| 1/1 [00:00<00:00,  4.28it/s]
100%|██████████| 1/1 [00:01<00:00,

In [13]:
#sum target win and orthogonal win and target_win_over_orthogonal for each length
target_win = sum(target_win.values())
orthogonal_win = sum(orthogonal_win.values())

#print percentages over the total number of examples
print("target win", target_win / len(dataset))
print("orthogonal win", orthogonal_win / len(dataset))
target_win_over_orthogonal = sum(target_win_over_orthogonal.values())
print("target win over orthogonal", target_win_over_orthogonal / len(dataset))

target win 0.021
orthogonal win 0.741
target win over orthogonal 0.0845


In [14]:
json.dump(target_win_dataset, open("../data/target_win_dataset.json", "w"), indent=4)
json.dump(orthogonal_win_dataset, open("../data/orthogonal_win_dataset.json", "w"), indent=4)