In [8]:
import sys
sys.path.append('..')
sys.path.append('../src')
sys.path.append('../data')

In [9]:
import torch
from transformer_lens import HookedTransformer
import json
from src.model import WrapHookedTransformer
from tqdm import tqdm

import transformer_lens.utils as utils
from transformer_lens.utils import get_act_name
from functools import partial
from transformer_lens import patching

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [10]:
data = json.load(open("../data/known_1000.json"))

In [11]:
model = WrapHookedTransformer.from_pretrained("gpt2")

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [12]:
dataset = []
for d in tqdm(data, total=len(data)):
    dataset.append(
        {"prompt": d["prompt"],
         "target": " " + d["attribute"]}
    )

100%|██████████| 1209/1209 [00:00<00:00, 1871876.54it/s]


In [13]:
dataset_per_length = {}
for d in tqdm(dataset, total=len(dataset)):
    orthogonal_token = model.to_orthogonal_tokens(d["target"], alpha=0.5)
    d["premise"] = d["prompt"] + orthogonal_token + " " + d["prompt"]
    d["orthogonal_token"] = orthogonal_token
    d["length"] = len(model.to_str_tokens(d["premise"]))
    if d["length"] not in dataset_per_length:
        dataset_per_length[d["length"]] = []
    dataset_per_length[d["length"]].append(d)

  0%|          | 0/1209 [00:00<?, ?it/s]

100%|██████████| 1209/1209 [00:19<00:00, 61.95it/s]


In [14]:
#create a pytorch dataloader for each length
dataloaders = {}
for length in sorted(dataset_per_length.keys()):
    dataloaders[length] = torch.utils.data.DataLoader(dataset_per_length[length], batch_size=100, shuffle=True)

In [15]:
target_probs_mean = {}
orthogonal_probs_mean = {}
target_win = {}
orthogonal_win = {}
other_win = {}
target_win_over_orthogonal = {}
target_win_over_orthogonal_dataset = []


for length in sorted(dataset_per_length.keys()):
    # get logits for each example
    target_probs_mean[length] = []
    orthogonal_probs_mean[length] = []
    target_win[length] = 0
    orthogonal_win[length] = 0
    other_win[length] = 0
    target_win_over_orthogonal[length] = 0
    for batch in tqdm(dataloaders[length]):
        logit = model(batch["premise"])
        probs = torch.softmax(logit, dim=-1)
        batch_index = torch.arange(probs.shape[0])
        target_probs = probs[batch_index, -1, model.to_tokens(batch["target"], prepend_bos=False).squeeze(-1)]
        orthogonal_probs = probs[batch_index, -1, model.to_tokens(batch["orthogonal_token"], prepend_bos=False).squeeze(-1)]
        predictions = probs[:,-1,:].max(dim=-1)[0]
        # for each element of the batch check the prediction and update the win counter
        for i in range(len(batch["premise"])):
            if target_probs[i] == predictions[i]:
                target_win[length] += 1
            elif orthogonal_probs[i] == predictions[i]:
                orthogonal_win[length] += 1
            if target_probs[i] > orthogonal_probs[i]:
                target_win_over_orthogonal[length] += 1

        
        target_probs_mean[length].append(target_probs.mean().item())
        orthogonal_probs_mean[length].append(orthogonal_probs.mean().item())
    
    # mean of logits for each length
    target_probs_mean[length] = sum(target_probs_mean[length]) / len(target_probs_mean[length])
    orthogonal_probs_mean[length] = sum(orthogonal_probs_mean[length]) / len(orthogonal_probs_mean[length])
        
        

  0%|          | 0/1 [00:00<?, ?it/s]

100%|██████████| 1/1 [00:00<00:00,  5.29it/s]
100%|██████████| 1/1 [00:00<00:00,  7.18it/s]
100%|██████████| 1/1 [00:00<00:00,  2.35it/s]
100%|██████████| 1/1 [00:00<00:00,  2.45it/s]
100%|██████████| 1/1 [00:01<00:00,  1.15s/it]
100%|██████████| 1/1 [00:00<00:00,  2.18it/s]
100%|██████████| 2/2 [00:01<00:00,  1.30it/s]
100%|██████████| 1/1 [00:00<00:00,  1.73it/s]
100%|██████████| 2/2 [00:03<00:00,  1.81s/it]
100%|██████████| 1/1 [00:00<00:00,  1.39it/s]
100%|██████████| 2/2 [00:02<00:00,  1.12s/it]
100%|██████████| 1/1 [00:00<00:00,  1.20it/s]
100%|██████████| 2/2 [00:02<00:00,  1.14s/it]
100%|██████████| 1/1 [00:00<00:00,  1.50it/s]
100%|██████████| 1/1 [00:02<00:00,  2.17s/it]
100%|██████████| 1/1 [00:00<00:00,  2.50it/s]
100%|██████████| 1/1 [00:01<00:00,  1.23s/it]
100%|██████████| 1/1 [00:00<00:00,  2.29it/s]
100%|██████████| 1/1 [00:00<00:00,  1.26it/s]
100%|██████████| 1/1 [00:00<00:00,  2.36it/s]
100%|██████████| 1/1 [00:01<00:00,  1.14s/it]
100%|██████████| 1/1 [00:00<00:00,

In [16]:
#sum target win and orthogonal win and target_win_over_orthogonal for each length
target_win = sum(target_win.values())
orthogonal_win = sum(orthogonal_win.values())
target_win_over_orthogonal = sum(target_win_over_orthogonal.values())

#print percentages over the total number of examples
print("target win", target_win / len(dataset))
print("orthogonal win", orthogonal_win / len(dataset))
print("target win over orthogonal", target_win_over_orthogonal / len(dataset))

target win 0.06617038875103391
orthogonal win 0.6947890818858561
target win over orthogonal 0.15467328370554176


In [10]:
target_win_over_orthogonal = sum(target_win_over_orthogonal.values())
print("target win over orthogonal", target_win_over_orthogonal / len(dataset))

target win over orthogonal 0.1728701406120761


## dataset expanded

In [7]:
data = json.load(open("../data/counterfact.json"))
model = WrapHookedTransformer.from_pretrained("gpt2")

Using pad_token, but it is not set yet.


Loaded pretrained model gpt2 into HookedTransformer


In [15]:
dataset = []
for d in tqdm(data, total=len(data)):
    for i in range(len(d["attribute_prompts"])):
        dataset.append(
            {"prompt": d["attribute_prompts"][i],
             "target": " " + d["requested_rewrite"]["target_new"]["str"]}
        )
    
    for i in range(len(d["neighborhood_prompts"])):
        dataset.append(
            {"prompt": d["neighborhood_prompts"][i],
             "target": " " + d["requested_rewrite"]["target_true"]["str"]}
        )
print(len(dataset))

100%|██████████| 21919/21919 [00:00<00:00, 206264.90it/s]

438380





In [9]:
#random shuffle
import random
random.shuffle(dataset)
dataset = dataset[:2000]

In [16]:
dataset_per_length = {}
for d in tqdm(dataset, total=len(dataset)):
    orthogonal_token = model.to_orthogonal_tokens(d["target"])
    d["premise"] = d["prompt"] + orthogonal_token + " " + d["prompt"]
    d["orthogonal_token"] = orthogonal_token
    d["length"] = len(model.to_str_tokens(d["premise"]))
    if d["length"] not in dataset_per_length:
        dataset_per_length[d["length"]] = []
    dataset_per_length[d["length"]].append(d)

  0%|          | 0/438380 [00:00<?, ?it/s]

100%|██████████| 438380/438380 [3:47:59<00:00, 32.05it/s]  


In [3]:
dataset_per_length = json.load(open("../data/dataset_per_length.json"))

In [9]:
dataset_per_length.keys()

dict_keys(['30', '16', '19', '13', '12', '22', '20', '18', '25', '24', '23', '14', '15', '11', '10', '17', '28', '26', '29', '9', '8', '27', '21', '32', '37', '31', '34', '47', '51', '33', '35', '39', '36', '38', '6', '7', '40', '44', '42', '50', '41', '48', '46', '55', '56', '98', '58', '43', '52', '54', '45', '57', '62', '64', '60', '59', '80', '49', '74', '63', '102', '66', '70', '72', '100', '76', '96', '89', '68', '94', '108', '132', '82', '116', '130', '78', '106', '110', '92'])

In [20]:
dataset_per_length["14"] = dataset_per_length["14"][:5000]

In [21]:
#create a pytorch dataloader for each length
dataloaders = {}
for length in sorted(dataset_per_length.keys()):
    dataloaders[length] = torch.utils.data.DataLoader(dataset_per_length[length], batch_size=100, shuffle=True)

In [22]:
target_probs_mean = {}
orthogonal_probs_mean = {}
target_win = {}
orthogonal_win = {}
other_win = {}
target_win_over_orthogonal = {}
target_win_dataset = []
orthogonal_win_dataset = []


length = "14"
# get logits for each example
target_probs_mean[length] = []
orthogonal_probs_mean[length] = []
target_win[length] = 0
orthogonal_win[length] = 0
other_win[length] = 0
target_win_over_orthogonal[length] = 0
for batch in tqdm(dataloaders[length]):
    logit = model(batch["premise"])
    probs = torch.softmax(logit, dim=-1)
    batch_index = torch.arange(probs.shape[0])
    target_probs = probs[batch_index, -1, model.to_tokens(batch["target"], prepend_bos=False).squeeze(-1)]
    if len(model.to_tokens(batch["orthogonal_token"], prepend_bos=False).squeeze(-1).shape) == 2:
        orthogonal_tokens = model.to_tokens(batch["orthogonal_token"], prepend_bos=False)[:,0].squeeze(-1)
    else:
        orthogonal_tokens = model.to_tokens(batch["orthogonal_token"], prepend_bos=False).squeeze(-1)
    orthogonal_probs = probs[batch_index, -1, orthogonal_tokens]
    predictions = probs[:,-1,:].max(dim=-1)[0]
    # for each element of the batch check the prediction and update the win counter
    for i in range(len(batch["premise"])):
        if target_probs[i] == predictions[i]:
            target_win[length] += 1
            target_win_dataset.append(
                {
                    "prompt": batch["prompt"][i],
                    "target": batch["target"][i],
                    "premise": batch["premise"][i],
                    "orthogonal_token": batch["orthogonal_token"][i],
                    "length": float(batch["length"][i].cpu().detach().numpy().item()),
                    "target_probs": float(target_probs[i].cpu().detach().numpy().item()),
                    "orthogonal_probs": float(orthogonal_probs[i].cpu().detach().numpy().item()),
                    
                }
            
            )
        elif orthogonal_probs[i] == predictions[i]:
            orthogonal_win[length] += 1
            orthogonal_win_dataset.append(
                {
                    "prompt": batch["prompt"][i],
                    "target": batch["target"][i],
                    "premise": batch["premise"][i],
                    "orthogonal_token": batch["orthogonal_token"][i],
                    "length": float(batch["length"][i].cpu().detach().numpy().item()),
                    "target_probs": float(target_probs[i].cpu().detach().numpy().item()),
                    "orthogonal_probs": float(orthogonal_probs[i].cpu().detach().numpy().item()),
                    
                }
            
            )
        if target_probs[i] > orthogonal_probs[i]:
            target_win_over_orthogonal[length] += 1
            # target_win_over_orthogonal_dataset.append(batch[i])
    
    target_probs_mean[length].append(target_probs.mean().item())
    orthogonal_probs_mean[length].append(orthogonal_probs.mean().item())

# mean of logits for each length
target_probs_mean[length] = sum(target_probs_mean[length]) / len(target_probs_mean[length])
orthogonal_probs_mean[length] = sum(orthogonal_probs_mean[length]) / len(orthogonal_probs_mean[length])
        

  0%|          | 0/50 [00:00<?, ?it/s]

 46%|████▌     | 23/50 [07:27<08:40, 19.29s/it]

In [32]:

for batch in dataloaders[9]:
    if model.to_tokens(batch["orthogonal_token"], prepend_bos=False).shape[1] == 2:
        print("yes")
        errorbatch = batch

yes


In [58]:
len(tensor.shape)

2

In [51]:
tensor = model.to_tokens(errorbatch["orthogonal_token"], prepend_bos=False)

In [38]:
model.to_tokens(" user")

tensor([[50256,  2836]])

In [None]:
#sum target win and orthogonal win and target_win_over_orthogonal for each length
target_win = sum(target_win.values())
orthogonal_win = sum(orthogonal_win.values())

#print percentages over the total number of examples
print("target win", target_win / len(dataset))
print("orthogonal win", orthogonal_win / len(dataset))
target_win_over_orthogonal = sum(target_win_over_orthogonal.values())
print("target win over orthogonal", target_win_over_orthogonal / len(dataset))

target win 0.021
orthogonal win 0.741
target win over orthogonal 0.0845


In [None]:
json.dump(target_win_dataset, open("../data/target_win_dataset.json", "w"), indent=4)
json.dump(orthogonal_win_dataset, open("../data/orthogonal_win_dataset.json", "w"), indent=4)

In [24]:
dataset_per_length[9]

[{'prompt': 'Traffic started in',
  'target': ' Birmingham',
  'premise': 'Traffic started in�� Traffic started in',
  'orthogonal_token': '��',
  'length': 9},
 {'prompt': 'Editors formed in',
  'target': ' Birmingham',
  'premise': 'Editors formed in outline Editors formed in',
  'orthogonal_token': ' outline',
  'length': 9},
 {'prompt': 'Prison originated in',
  'target': ' Sweden',
  'premise': 'Prison originated in Miner Prison originated in',
  'orthogonal_token': ' Miner',
  'length': 9},
 {'prompt': 'Sarkar from',
  'target': ' India',
  'premise': 'Sarkar fromcha Sarkar from',
  'orthogonal_token': 'cha',
  'length': 9},
 {'prompt': 'Casey Abrams performs',
  'target': ' jazz',
  'premise': 'Casey Abrams performs Hou Casey Abrams performs',
  'orthogonal_token': ' Hou',
  'length': 9},
 {'prompt': 'Foolproof from',
  'target': ' Canada',
  'premise': 'Foolproof from right Foolproof from',
  'orthogonal_token': ' right',
  'length': 9},
 {'prompt': 'Mitt Romney speaks',
  'tar