In [5]:
import sys
# Put the parent directory and its src/ and data/ subfolders on the module
# search path, in that order.
sys.path.extend(['..', '../src', '../data'])

In [2]:
import torch
from transformer_lens import HookedTransformer
import json
from src.model import WrapHookedTransformer
from tqdm import tqdm

import transformer_lens.utils as utils
from transformer_lens.utils import get_act_name
from functools import partial
from transformer_lens import patching

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Parse the raw records. The context manager closes the file handle as soon
# as parsing finishes instead of leaving it open until garbage collection.
with open("known_1000.json") as known_file:
    data = json.load(known_file)

In [3]:
# Build the project's WrapHookedTransformer (imported from src.model) from the
# pretrained "gpt2" checkpoint; every later cell runs inference through `model`.
model = WrapHookedTransformer.from_pretrained("gpt2")

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [4]:
# Keep only the two fields used downstream. The target is the record's
# attribute with a single leading space prepended.
dataset = [
    {"prompt": record["prompt"], "target": " " + record["attribute"]}
    for record in tqdm(data, total=len(data))
]

100%|██████████| 1209/1209 [00:00<00:00, 1903496.07it/s]


In [5]:
# Enrich every example in place and bucket it by the token length of its
# premise. The premise is: prompt + orthogonal token + " " + prompt, where the
# orthogonal token is whatever model.to_orthogonal_tokens returns for the target.
dataset_per_length = {}
for example in tqdm(dataset, total=len(dataset)):
    distractor = model.to_orthogonal_tokens(example["target"])
    example.update(
        premise=example["prompt"] + distractor + " " + example["prompt"],
        orthogonal_token=distractor,
    )
    example["length"] = len(model.to_str_tokens(example["premise"]))
    # setdefault creates the bucket the first time a length is seen.
    dataset_per_length.setdefault(example["length"], []).append(example)

100%|██████████| 1209/1209 [00:38<00:00, 31.70it/s]


In [7]:
# One shuffled DataLoader per premise length (batch size 100), keyed by length
# in ascending order.
dataloaders = {
    length: torch.utils.data.DataLoader(
        dataset_per_length[length], batch_size=100, shuffle=True
    )
    for length in sorted(dataset_per_length.keys())
}

In [8]:
# For every premise length, run the model on each batch and record, at the
# last position:
#   * the probability assigned to the target and to the orthogonal token,
#   * which of the two (if either) is the model's top prediction,
#   * how often the target is more probable than the orthogonal token.
target_probs_mean = {}
orthogonal_probs_mean = {}
target_win = {}
orthogonal_win = {}
other_win = {}
target_win_over_orthogonal = {}
target_win_over_orthogonal_dataset = []


for length in sorted(dataset_per_length.keys()):
    target_win[length] = 0
    orthogonal_win[length] = 0
    other_win[length] = 0
    target_win_over_orthogonal[length] = 0
    # Running sums: averaging per-batch means would over-weight a smaller
    # final batch, so the mean is taken over examples instead.
    target_prob_sum = 0.0
    orthogonal_prob_sum = 0.0
    n_examples = 0

    for batch in tqdm(dataloaders[length]):
        # Inference only: no autograd graph is needed.
        with torch.no_grad():
            logit = model(batch["premise"])
            # Only the next-token distribution at the last position is used,
            # so softmax just that slice.
            probs = torch.softmax(logit[:, -1, :], dim=-1)
        batch_index = torch.arange(probs.shape[0])
        # Score each string by its first token so multi-token strings do not
        # break the indexing below.
        # NOTE(review): assumes to_tokens returns a (batch, position) tensor
        # with real tokens first (right padding) -- confirm for
        # WrapHookedTransformer.
        target_tokens = model.to_tokens(batch["target"], prepend_bos=False)[:, 0]
        orthogonal_tokens = model.to_tokens(batch["orthogonal_token"], prepend_bos=False)[:, 0]
        target_probs = probs[batch_index, target_tokens]
        orthogonal_probs = probs[batch_index, orthogonal_tokens]
        predictions = probs.max(dim=-1)[0]

        # for each element of the batch check the prediction and update the win counter
        for i in range(len(batch["premise"])):
            if target_probs[i] == predictions[i]:
                target_win[length] += 1
            elif orthogonal_probs[i] == predictions[i]:
                orthogonal_win[length] += 1
            else:
                # Neither candidate is the top prediction.
                other_win[length] += 1
            if target_probs[i] > orthogonal_probs[i]:
                target_win_over_orthogonal[length] += 1
                # `batch` is indexed by field name, so `batch[i]` would raise
                # KeyError; rebuild the i-th example field by field.
                target_win_over_orthogonal_dataset.append(
                    {
                        key: (column[i].item() if isinstance(column, torch.Tensor) else column[i])
                        for key, column in batch.items()
                    }
                )

        target_prob_sum += target_probs.sum().item()
        orthogonal_prob_sum += orthogonal_probs.sum().item()
        n_examples += len(batch["premise"])

    # Example-weighted mean probability for this length.
    target_probs_mean[length] = target_prob_sum / n_examples
    orthogonal_probs_mean[length] = orthogonal_prob_sum / n_examples
        
        

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  2.25it/s]
100%|██████████| 1/1 [00:00<00:00,  6.18it/s]
100%|██████████| 1/1 [00:00<00:00,  1.80it/s]
100%|██████████| 1/1 [00:00<00:00,  2.06it/s]
100%|██████████| 1/1 [00:01<00:00,  1.27s/it]
100%|██████████| 1/1 [00:00<00:00,  1.80it/s]
100%|██████████| 2/2 [00:01<00:00,  1.16it/s]
100%|██████████| 1/1 [00:00<00:00,  1.48it/s]
100%|██████████| 2/2 [00:02<00:00,  1.21s/it]
100%|██████████| 1/1 [00:00<00:00,  1.29it/s]
100%|██████████| 2/2 [00:02<00:00,  1.26s/it]
100%|██████████| 1/1 [00:00<00:00,  1.00it/s]
100%|██████████| 2/2 [00:02<00:00,  1.30s/it]
100%|██████████| 1/1 [00:00<00:00,  1.34it/s]
100%|██████████| 1/1 [00:02<00:00,  2.25s/it]
100%|██████████| 1/1 [00:00<00:00,  2.39it/s]
100%|██████████| 1/1 [00:01<00:00,  1.39s/it]
100%|██████████| 1/1 [00:00<00:00,  1.96it/s]
100%|██████████| 1/1 [00:00<00:00,  1.09it/s]
100%|██████████| 1/1 [00:00<00:00,  2.47it/s]
100%|██████████| 1/1 [00:01<00:00,  1.24s/it]
100%|██████████| 1/1 [00:00<00:00,

In [9]:
# Sum the per-length counters into *new* names. Rebinding target_win /
# orthogonal_win / target_win_over_orthogonal to integers would make this cell
# (and any later cell calling .values() on them) fail when re-run.
target_win_total = sum(target_win.values())
orthogonal_win_total = sum(orthogonal_win.values())
target_win_over_orthogonal_total = sum(target_win_over_orthogonal.values())

#print percentages over the total number of examples
print("target win", target_win_total / len(dataset))
print("orthogonal win", orthogonal_win_total / len(dataset))
print("target win over orthogonal", target_win_over_orthogonal_total / len(dataset))

target win 0.04052936311000827
orthogonal win 0.7014061207609594
other win 0.0


In [10]:
# Reduce the per-length counter to a single total only if it is still a
# dictionary. If the name was already reduced to a number (for example by
# re-running this cell), reuse it instead of crashing on `.values()`.
if isinstance(target_win_over_orthogonal, dict):
    target_win_over_orthogonal = sum(target_win_over_orthogonal.values())
print("target win over orthogonal", target_win_over_orthogonal / len(dataset))

target win over orthogonal 0.1728701406120761


## Expanded dataset (CounterFact)

In [7]:
# Parse the CounterFact records; the context manager closes the file handle
# right after parsing.
with open("../data/counterfact.json") as counterfact_file:
    data = json.load(counterfact_file)
# Reload the pretrained "gpt2" checkpoint through the project wrapper.
model = WrapHookedTransformer.from_pretrained("gpt2")

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [15]:
# Flatten every record into (prompt, target) examples:
#   * attribute prompts are paired with the rewrite's *new* target,
#   * neighborhood prompts are paired with the rewrite's *true* target.
# Targets get a single leading space prepended.
dataset = []
for record in tqdm(data, total=len(data)):
    dataset.extend(
        {"prompt": prompt,
         "target": " " + record["requested_rewrite"]["target_new"]["str"]}
        for prompt in record["attribute_prompts"]
    )
    dataset.extend(
        {"prompt": prompt,
         "target": " " + record["requested_rewrite"]["target_true"]["str"]}
        for prompt in record["neighborhood_prompts"]
    )
print(len(dataset))

100%|██████████| 21919/21919 [00:00<00:00, 206264.90it/s]

438380





In [9]:
#random shuffle
import random

# Shuffle with a dedicated, fixed-seed generator so a fresh kernel always
# selects the same 2000 examples; the global `random` state is left untouched.
random.Random(42).shuffle(dataset)
dataset = dataset[:2000]

In [16]:
# Add premise / orthogonal_token / length to each example (in place) and group
# the examples by premise token length.
dataset_per_length = {}
for sample in tqdm(dataset, total=len(dataset)):
    orthogonal = model.to_orthogonal_tokens(sample["target"])
    sample["premise"] = "".join([sample["prompt"], orthogonal, " ", sample["prompt"]])
    sample["orthogonal_token"] = orthogonal
    n_tokens = len(model.to_str_tokens(sample["premise"]))
    sample["length"] = n_tokens
    bucket = dataset_per_length.get(n_tokens)
    if bucket is None:
        # First example of this length: start a new bucket.
        bucket = dataset_per_length[n_tokens] = []
    bucket.append(sample)

  0%|          | 0/438380 [00:00<?, ?it/s]

100%|██████████| 438380/438380 [3:47:59<00:00, 32.05it/s]  


In [18]:
# Build a shuffled DataLoader (batch size 100) for each premise length.
dataloaders = {
    n_tokens: torch.utils.data.DataLoader(
        dataset_per_length[n_tokens], batch_size=100, shuffle=True
    )
    for n_tokens in sorted(dataset_per_length.keys())
}

In [59]:
def _win_record(batch, i, target_probs, orthogonal_probs):
    """Return a JSON-serialisable dict describing example ``i`` of ``batch``.

    ``batch`` is the dict of per-field columns produced by the DataLoader;
    ``target_probs`` / ``orthogonal_probs`` are the per-example probabilities
    computed for that batch. Tensor entries are converted to Python floats.
    """
    return {
        "prompt": batch["prompt"][i],
        "target": batch["target"][i],
        "premise": batch["premise"][i],
        "orthogonal_token": batch["orthogonal_token"][i],
        "length": float(batch["length"][i].item()),
        "target_probs": float(target_probs[i].item()),
        "orthogonal_probs": float(orthogonal_probs[i].item()),
    }


# Per-length statistics of the model's last-position prediction, plus the
# examples won by the target and by the orthogonal token.
target_probs_mean = {}
orthogonal_probs_mean = {}
target_win = {}
orthogonal_win = {}
other_win = {}
target_win_over_orthogonal = {}
target_win_dataset = []
orthogonal_win_dataset = []


for length in sorted(dataset_per_length.keys()):
    target_win[length] = 0
    orthogonal_win[length] = 0
    other_win[length] = 0
    target_win_over_orthogonal[length] = 0
    # Running sums so the per-length mean is weighted by example rather than
    # by batch (the final batch may be smaller than the others).
    target_prob_sum = 0.0
    orthogonal_prob_sum = 0.0
    n_examples = 0

    for batch in tqdm(dataloaders[length]):
        # Inference only: skip building the autograd graph.
        with torch.no_grad():
            logit = model(batch["premise"])
            # Only the last position is inspected, so softmax just that slice.
            probs = torch.softmax(logit[:, -1, :], dim=-1)
        batch_index = torch.arange(probs.shape[0])
        # Tokenise each field once and keep the first token of every string.
        # NOTE(review): assumes to_tokens returns a (batch, position) tensor
        # with real tokens first (right padding) -- confirm for
        # WrapHookedTransformer.
        target_tokens = model.to_tokens(batch["target"], prepend_bos=False)[:, 0]
        orthogonal_tokens = model.to_tokens(batch["orthogonal_token"], prepend_bos=False)[:, 0]
        target_probs = probs[batch_index, target_tokens]
        orthogonal_probs = probs[batch_index, orthogonal_tokens]
        predictions = probs.max(dim=-1)[0]

        # for each element of the batch check the prediction and update the win counter
        for i in range(len(batch["premise"])):
            if target_probs[i] == predictions[i]:
                target_win[length] += 1
                target_win_dataset.append(
                    _win_record(batch, i, target_probs, orthogonal_probs)
                )
            elif orthogonal_probs[i] == predictions[i]:
                orthogonal_win[length] += 1
                orthogonal_win_dataset.append(
                    _win_record(batch, i, target_probs, orthogonal_probs)
                )
            else:
                # Neither candidate is the top prediction.
                other_win[length] += 1
            if target_probs[i] > orthogonal_probs[i]:
                target_win_over_orthogonal[length] += 1

        target_prob_sum += target_probs.sum().item()
        orthogonal_prob_sum += orthogonal_probs.sum().item()
        n_examples += len(batch["premise"])

    # Example-weighted mean probability for this length.
    target_probs_mean[length] = target_prob_sum / n_examples
    orthogonal_probs_mean[length] = orthogonal_prob_sum / n_examples
        

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  2.03it/s]
100%|██████████| 2/2 [00:01<00:00,  1.67it/s]
100%|██████████| 30/30 [00:20<00:00,  1.44it/s]
100%|██████████| 27/27 [00:21<00:00,  1.25it/s]
100%|██████████| 145/145 [02:25<00:00,  1.00s/it]
100%|██████████| 83/83 [01:33<00:00,  1.13s/it]
 85%|████████▍ | 257/304 [32:06<05:52,  7.50s/it]


KeyboardInterrupt: 

In [64]:
def append_to_dataset(dataset, batch, i, target_probs, orthogonal_probs):
    """Append a plain-Python record for example ``i`` of ``batch`` to ``dataset``.

    Args:
        dataset: list that receives the new record (mutated in place).
        batch: mapping with "prompt", "target", "premise", "orthogonal_token"
            and "length" columns, each indexable by ``i``.
        i: position of the example inside the batch.
        target_probs: per-example tensor of target probabilities.
        orthogonal_probs: per-example tensor of orthogonal-token probabilities.

    The "length" entry and both probabilities are stored as Python floats.
    """
    record = {
        field: batch[field][i]
        for field in ("prompt", "target", "premise", "orthogonal_token")
    }
    record["length"] = float(batch["length"][i].item())
    record["target_probs"] = float(target_probs[i].item())
    record["orthogonal_probs"] = float(orthogonal_probs[i].item())
    dataset.append(record)

# Per-length statistics of the model's last-position prediction. Examples won
# by the target / orthogonal token are stored through append_to_dataset.
target_probs_mean = {}
orthogonal_probs_mean = {}
target_win = {}
orthogonal_win = {}
other_win = {}
target_win_over_orthogonal = {}
target_win_dataset = []
orthogonal_win_dataset = []

for length in sorted(dataset_per_length.keys()):
    target_win[length] = 0
    orthogonal_win[length] = 0
    other_win[length] = 0
    target_win_over_orthogonal[length] = 0
    # Running sums: a mean of per-batch means would over-weight a smaller
    # final batch, so accumulate over examples instead.
    target_prob_sum = 0.0
    orthogonal_prob_sum = 0.0
    n_examples = 0

    for batch in tqdm(dataloaders[length]):
        # Inference only: no gradients are needed.
        with torch.no_grad():
            logit = model(batch["premise"])
            # Softmax only the last position, the only one read below.
            probs = torch.softmax(logit[:, -1, :], dim=-1)
        batch_index = torch.arange(probs.shape[0])

        # Keep the first token of every string so multi-token targets index
        # `probs` the same way multi-token orthogonal strings already did.
        # NOTE(review): assumes to_tokens returns a (batch, position) tensor
        # with real tokens first (right padding) -- confirm for
        # WrapHookedTransformer.
        target_tokens = model.to_tokens(batch["target"], prepend_bos=False)[:, 0]
        target_probs = probs[batch_index, target_tokens]

        orthogonal_tokens = model.to_tokens(batch["orthogonal_token"], prepend_bos=False)[:, 0]
        orthogonal_probs = probs[batch_index, orthogonal_tokens]

        predictions = probs.max(dim=-1)[0]

        for i in range(len(batch["premise"])):
            if target_probs[i] == predictions[i]:
                target_win[length] += 1
                append_to_dataset(target_win_dataset, batch, i, target_probs, orthogonal_probs)
            elif orthogonal_probs[i] == predictions[i]:
                orthogonal_win[length] += 1
                append_to_dataset(orthogonal_win_dataset, batch, i, target_probs, orthogonal_probs)
            else:
                # Neither candidate is the top prediction.
                other_win[length] += 1
            if target_probs[i] > orthogonal_probs[i]:
                target_win_over_orthogonal[length] += 1

        target_prob_sum += target_probs.sum().item()
        orthogonal_prob_sum += orthogonal_probs.sum().item()
        n_examples += len(batch["premise"])

    # Example-weighted mean probability for this length.
    target_probs_mean[length] = target_prob_sum / n_examples
    orthogonal_probs_mean[length] = orthogonal_prob_sum / n_examples


100%|██████████| 1/1 [00:23<00:00, 23.63s/it]
  0%|          | 0/2 [00:09<?, ?it/s]


KeyboardInterrupt: 

In [60]:
# Sum target win, orthogonal win and target_win_over_orthogonal over all
# lengths. The totals go into new names so the per-length dictionaries survive
# and this cell can be re-run without hitting `.values()` on an integer.
target_win_total = sum(target_win.values())
orthogonal_win_total = sum(orthogonal_win.values())
target_win_over_orthogonal_total = sum(target_win_over_orthogonal.values())

#print percentages over the total number of examples
# NOTE(review): if the evaluation loop was interrupted, the counters cover only
# part of `dataset`, so these ratios under-count -- confirm before reporting.
print("target win", target_win_total / len(dataset))
print("orthogonal win", orthogonal_win_total / len(dataset))
print("target win over orthogonal", target_win_over_orthogonal_total / len(dataset))

target win 0.0011588119895980656
orthogonal win 0.09394361056617546
target win over orthogonal 0.008038687896345636


In [63]:
# Write both example lists to disk. The context managers flush and close each
# file deterministically instead of relying on garbage collection.
with open("../data/target_win_dataset_partial.json", "w") as target_file:
    json.dump(target_win_dataset, target_file, indent=4)
with open("../data/orthogonal_win_dataset_partial.json", "w") as orthogonal_file:
    json.dump(orthogonal_win_dataset, orthogonal_file, indent=4)

In [62]:
# Number of examples collected in target_win_dataset, i.e. examples whose
# target probability equalled the top last-position probability.
len(target_win_dataset)

508

## Dataset statistics and check

In [65]:
# Load the saved target-win examples; the context manager closes the file
# handle right after parsing. Note that this rebinds `dataset`.
with open("../data/target_win_dataset.json") as target_win_file:
    dataset = json.load(target_win_file)

In [84]:
# Spot-check one stored example: show the model's top next-token prediction
# for its premise, followed by the stored target and the premise itself.
index = 567
example = dataset[index]
last_position_logits = model(example["premise"])[:, -1, :]
print(model.to_string(last_position_logits.argmax(dim=-1)))
print(example["target"])
print(example["premise"])

 Microsoft
 Microsoft
Windows 2000, developed bypi Windows 2000, developed by


In [101]:
# Same check on a hand-written premise: print the model's top next-token
# prediction at the last position.
manual_premise = "Windows 2000, developed byp0 Windows 2000, developed by"
manual_logits = model(manual_premise)[:, -1, :]
print(model.to_string(manual_logits.argmax(dim=-1)))

 Microsoft


In [104]:
import numpy as np
# Load the win statistics saved as .npy files. This rebinds target_win and
# orthogonal_win to NumPy arrays.
target_win_path = "../script/data/target_win.npy"
orthogonal_win_path = "../script/data/orthogonal_win.npy"
target_win = np.load(target_win_path)
orthogonal_win = np.load(orthogonal_win_path)

In [106]:
# Display the first row of the loaded orthogonal_win array.
orthogonal_win[0,:]

array([0.0082713 , 0.00744417, 0.00992556, 0.01157982, 0.00744417,
       0.00909843, 0.00744417, 0.0082713 , 0.0082713 , 0.00909843])