In [1]:
# Show the visible GPU(s) and their memory; the recorded run used an NVIDIA A100 80GB.
!nvidia-smi

Sun Jun 26 17:34:50 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000000:4F:00.0 Off |                    0 |
| N/A   35C    P0    82W / 300W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip install unseal -q

In [3]:
!pip install transformers -q

In [4]:
from typing import Union, Tuple, Optional

import matplotlib.pyplot as plt
from matplotlib import cm
from statistics import stdev
from matplotlib.colors import BoundaryNorm
import numpy as np
import random
import json
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import unseal
from unseal.hooks import Hook, HookedModel, common_hooks
from unseal.transformers_util import load_from_pretrained, get_num_layers

import sys
sys.path.append("../../../lib")
import utility
import hooks

In [5]:
# Pretrained checkpoint to trace; load_from_pretrained yields (model, tokenizer, config).
MODEL_NAME = 'gpt2-xl'
unhooked_model, tokenizer, config = load_from_pretrained(MODEL_NAME)

In [6]:
model = HookedModel(unhooked_model)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
# Inference only: make sure dropout is disabled so the traced probabilities vary only
# with the injected embedding noise. transformers' `from_pretrained` already returns a
# model in eval mode; this keeps it explicit in case load_from_pretrained does not.
model.eval()

print(f'Model loaded and on device {device}!')

Model loaded and on device cuda!


In [7]:
# Keys of the hookable layers; Hook(layer_name=...) below refers to modules by these
# names, e.g. 'transformer->wte' and 'transformer->h->0'.
layer_names = list(model.layers)
print(layer_names)

['transformer', 'transformer->wte', 'transformer->wpe', 'transformer->drop', 'transformer->h', 'transformer->h->0', 'transformer->h->0->ln_1', 'transformer->h->0->attn', 'transformer->h->0->attn->c_attn', 'transformer->h->0->attn->c_proj', 'transformer->h->0->attn->attn_dropout', 'transformer->h->0->attn->resid_dropout', 'transformer->h->0->ln_2', 'transformer->h->0->mlp', 'transformer->h->0->mlp->c_fc', 'transformer->h->0->mlp->c_proj', 'transformer->h->0->mlp->act', 'transformer->h->0->mlp->dropout', 'transformer->h->1', 'transformer->h->1->ln_1', 'transformer->h->1->attn', 'transformer->h->1->attn->c_attn', 'transformer->h->1->attn->c_proj', 'transformer->h->1->attn->attn_dropout', 'transformer->h->1->attn->resid_dropout', 'transformer->h->1->ln_2', 'transformer->h->1->mlp', 'transformer->h->1->mlp->c_fc', 'transformer->h->1->mlp->c_proj', 'transformer->h->1->mlp->act', 'transformer->h->1->mlp->dropout', 'transformer->h->2', 'transformer->h->2->ln_1', 'transformer->h->2->attn', 'tra

In [8]:
# load data: a JSON list of facts (each with 'subject', 'prompt', 'attribute', ...).
# The context manager closes the file; utf-8 is explicit because the data contains
# non-ASCII names and the default encoding is platform dependent.
with open("../../../datasets/known_1000.json", encoding="utf-8") as f:
    prompts = json.load(f)

In [9]:
# prepare data: (subject, full factual statement) pairs, statement = prompt + ' ' + attribute
texts = [(entry['subject'], entry['prompt'] + ' ' + entry['attribute']) for entry in prompts]
print(texts[0])

('Vinson Massif', 'Vinson Massif is located in the continent of Antarctica')


In [10]:
### Batching requires inputs to be of the same size so we put each prompt in
### a bucket depending on the number of tokens it contains.
### Each bucket entry is
### (num_tokens, num_layers, encoded_text, hidden_states, tkens, text, target_id, indx_str, seed, p, p_star)
### and every prompt contributes NUM_RUNS entries (one per noise sample).

NUM_RUNS = 10  # ROME runs each input 10 times

# Dedicated, seeded RNG: the sampled prompts and the noise seeds are reproducible
# after a kernel restart, and re-running this cell rebuilds the same buckets.
sampling_rng = random.Random(0)

# Loop-invariant: the hooked block names do not depend on the prompt.
num_layers = get_num_layers(model, layer_key_prefix='transformer->h')
output_hooks_names = [f'transformer->h->{layer}' for layer in range(num_layers)]

inputs = [[] for _ in range(30)]  # 30 initial buckets; extended below if a prompt is longer
for subject, text in sampling_rng.sample(texts, 1000):
    encoded_text, target_id = utility.prepare_input(text, tokenizer, device)
    num_tokens = encoded_text['input_ids'].shape[1]

    print("num tokens: ", num_tokens, "text: ", text)

    # decoded string of each input token
    input_ids = encoded_text['input_ids'][0]
    tkens = [tokenizer.decode(token_id) for token_id in input_ids]

    # Subject span [subj_start, subj_end) that gets corrupted with noise.
    # NOTE(review): subject_positions[-1] is treated as the inclusive last subject token;
    # if utility.get_subject_positions returns an exclusive end, this span holds one
    # extra token -- confirm against the helper.
    subject_positions = utility.get_subject_positions(tkens, subject)
    subj_start, subj_end = subject_positions[0], subject_positions[-1] + 1
    indx_str = str(subj_start) + ":" + str(subj_end)

    # Star exactly the noised tokens (the same span as indx_str) so later analysis
    # highlights the corrupted span.
    for pos in range(subj_start, subj_end):
        tkens[pos] = tkens[pos] + "*"

    # Grow the bucket list on demand instead of failing on prompts with >= 30 tokens.
    if num_tokens >= len(inputs):
        inputs.extend([] for _ in range(num_tokens + 1 - len(inputs)))

    with torch.no_grad():  # inference only; no autograd graph needed
        # save uncorrupted states for patching
        output_hooks = [Hook(name, utility.save_output, name) for name in output_hooks_names]
        model(**encoded_text, hooks=output_hooks)
        hidden_states = [model.save_ctx[name]['output'][0].detach() for name in output_hooks_names]

        # get probability with uncorrupted subject
        p = model(**encoded_text, hooks=[])['logits'].softmax(dim=-1)[0, -1, target_id].item()

        # Distinct seeds per run. Drawing from randint(100) repeated a seed (i.e. reused
        # the same noise sample) among the 10 runs for roughly a third of prompts.
        for seed in sampling_rng.sample(range(2**31 - 1), NUM_RUNS):
            noise_hook = Hook(
                layer_name='transformer->wte',
                func=hooks.additive_output_noise(indices=indx_str, mean=0, std=0.1, seed=seed),
                key='embedding_noise',
            )

            # get probability with corrupted subject
            p_star = model(**encoded_text, hooks=[noise_hook])['logits'].softmax(dim=-1)[0, -1, target_id].item()

            inputs[num_tokens].append((num_tokens, num_layers, encoded_text, hidden_states, tkens, text, target_id, indx_str, seed, p, p_star))


num tokens:  11 text:  John XXIII holds the title of "the most popular pope
num tokens:  11 text:  The capital of Central Bohemian Region is the city of Prague
num tokens:  7 text:  Kingdom of Egypt's capital, Cairo
num tokens:  7 text:  Dominic Behan writes in the Irish
num tokens:  9 text:  Brandon Jenkins (football player) plays as a linebacker
num tokens:  13 text:  Kyōto Prefecture, which was named after the city of Kyoto
num tokens:  6 text:  Lexus ES is developed by Toyota
num tokens:  10 text:  Sound Transit was formed in 1999 by a group of Seattle
num tokens:  9 text:  Hajime Mizoguchi was born in Tokyo
num tokens:  13 text:  Ritt Bjerregaard, who has a citizenship of Denmark
num tokens:  13 text:  Akira Toriyama's domain of work is the world of manga
num tokens:  7 text:  Final Fantasy Legend III is developed by Square
num tokens:  14 text:  Henry Somerset, 7th Duke of Beaufort worked in the city of London
num tokens:  10 text:  Jung Ryeo-won's profession is an actor
num toke

num tokens:  9 text:  Debenham Glacier belongs to the continent of Antarctica
num tokens:  8 text:  Windows Phone 8.1 is developed by Microsoft
num tokens:  7 text:  Audible.com is owned by Amazon
num tokens:  11 text:  Etobicoke North is located in the country of Canada
num tokens:  6 text:  Jeremy Vine is employed by the BBC
num tokens:  8 text:  Cosimo Fancelli was born in Rome
num tokens:  13 text:  Dildar Ali Naseerabadi follows the religion of Islam
num tokens:  9 text:  Mount Cocks is located in the continent of Antarctica
num tokens:  10 text:  Kluuvi is located in the country of Finland
num tokens:  9 text:  Hohenlohe, in the south of Germany
num tokens:  9 text:  Craig Federighi, who is employed by Apple
num tokens:  6 text:  Leonard Bernstein performs on the piano
num tokens:  12 text:  The native language of Edward Bulwer-Lytton is English
num tokens:  8 text:  Charlie Conacher died in the city of Toronto
num tokens:  13 text:  Nishi-Matsuura District is located in the coun

num tokens:  9 text:  The Young and the Restless premieres on CBS
num tokens:  15 text:  Tang Empire follows the religion of the same name, which is a mix of Buddhism
num tokens:  11 text:  In Palau, the language spoken is a mix of English
num tokens:  8 text:  Andheri, in the south of Mumbai
num tokens:  9 text:  Kirkpatrick Glacier belongs to the continent of Antarctica
num tokens:  9 text:  In State of Brazil, the language spoken is Portuguese
num tokens:  7 text:  MacApp, a product created by Apple
num tokens:  9 text:  Sakichi Toyoda is a citizen of Japan
num tokens:  7 text:  Outlook.com's owner, Microsoft
num tokens:  6 text:  Chris Stringer was born in London
num tokens:  7 text:  Russell Wilson professionally plays the sport of football
num tokens:  14 text:  The original language of Tusculanae Disputationes is a Latin
num tokens:  5 text:  The Good Wife debuted on CBS
num tokens:  9 text:  The native language of Gabrielle Fontan is French
num tokens:  11 text:  Faygo's headqu

num tokens:  8 text:  De-Phazz plays a lot of jazz
num tokens:  10 text:  Le Moniteur Universel is written in French
num tokens:  11 text:  The native language of Louis-Nicolas Davout is French
num tokens:  10 text:  The native language of Nikolay Strakhov is Russian
num tokens:  9 text:  The genre played by RuneScape is a fantasy
num tokens:  6 text:  Jeremy Clarkson is employed by the BBC
num tokens:  15 text:  The original language of Celia en el colegio is a mixture of Spanish
num tokens:  21 text:  The language used by Juan Bautista de Anza is a bit different from the language used by the Spanish
num tokens:  7 text:  Alice Coltrane performs on the piano
num tokens:  9 text:  Scooby Doo was originally aired on CBS
num tokens:  18 text:  Rheinmetall MAN Military Vehicles that was founded in 1885 and is headquartered in Munich
num tokens:  5 text:  Windows Me is developed by Microsoft
num tokens:  5 text:  Mac Pro is developed by Apple
num tokens:  7 text:  samurai cinema, that orig

num tokens:  11 text:  Choi Myeong-gil's profession is an actor
num tokens:  10 text:  Julian Huxley died in the city of London
num tokens:  12 text:  In Papua New Guinea, the language spoken is a mixture of English
num tokens:  12 text:  Charles Louis Alphonse Laveran died in the city of Paris
num tokens:  12 text:  Siemiatycze is located in the country of Poland
num tokens:  6 text:  Beijing, named after the capital
num tokens:  12 text:  Juninho Pernambucano professionally plays the sport of soccer
num tokens:  6 text:  The capital of West Pakistan is Karachi
num tokens:  23 text:  The headquarter of All India Anna Dravida Munnetra Kazhagam is located in the city of Chennai
num tokens:  7 text:  Melilla belongs to the continent of Africa
num tokens:  6 text:  Wednesday Night Baseball premieres on ESPN
num tokens:  6 text:  Showtime Networks's owner, CBS
num tokens:  14 text:  Quirinus of Sescia, who has the position of a bishop
num tokens:  9 text:  Sachimi Iwao is a citizen of Japa

num tokens:  9 text:  Roger Staubach professionally plays the sport of football
num tokens:  8 text:  Halle Berry, who works as a model
num tokens:  8 text:  David Dimbleby is employed by the BBC
num tokens:  10 text:  Louis XVII of France died in the city of Paris
num tokens:  10 text:  Lebedev Physical Institute's headquarters are in Moscow
num tokens:  5 text:  macOS is developed by Apple
num tokens:  7 text:  The native language of Claude Bernard is French
num tokens:  10 text:  Simcoe Composite School is located in the country of Canada
num tokens:  9 text:  Clifford Curzon, performing on the piano
num tokens:  13 text:  Sultan Satuq Bughra Khan follows the religion of Islam
num tokens:  14 text:  Ulrika Eleonora, Queen of Sweden died in the city of Stockholm
num tokens:  13 text:  The genre played by The Enchanter Reborn is a mix of fantasy
num tokens:  12 text:  Isola Dovarese is located in the country of Italy
num tokens:  10 text:  The language of La Hora was a mixture of Span

num tokens:  15 text:  Santo Stefano d'Aveto is located in the country of Italy
num tokens:  13 text:  Germanus of Auxerre, who has the position of a bishop
num tokens:  12 text:  King Chulalongkorn Memorial Hospital's headquarters are in Bangkok
num tokens:  7 text:  Acura ZDX is developed by Honda
num tokens:  8 text:  Platform Controller Hub is a product of the Intel
num tokens:  4 text:  YouTube's owner, Google
num tokens:  10 text:  William McGillivray, who has a citizenship of Canada
num tokens:  7 text:  Michael Foot worked in the city of London
num tokens:  7 text:  Wii Balance Board is produced by Nintendo
num tokens:  11 text:  The native language of Aleksey Khomyakov is Russian
num tokens:  13 text:  Fluminense F.C. is located in the country of Brazil
num tokens:  10 text:  Ansovinus holds the position of the first bishop
num tokens:  9 text:  The native language of Nicolaas Pierson is Dutch
num tokens:  10 text:  Fifth Avenue can be found in the heart of Manhattan
num token

num tokens:  8 text:  John Broadwood died in the city of London
num tokens:  5 text:  Charlie Hebdo is written in French
num tokens:  5 text:  Android TV is developed by Google
num tokens:  11 text:  José Canseco professionally plays the sport of baseball
num tokens:  7 text:  Jon Ronson is employed by the BBC
num tokens:  8 text:  La Voz del Interior is written in Spanish
num tokens:  4 text:  Egypt's capital, Cairo
num tokens:  10 text:  The native language of Paul Klebnikov is Russian
num tokens:  6 text:  Killer Mike is native to Atlanta
num tokens:  9 text:  Mawson Glacier belongs to the continent of Antarctica
num tokens:  18 text:  Bundesautobahn 4, by the way, is the most popular route in Germany
num tokens:  8 text:  Masako Natsume was born in Tokyo
num tokens:  10 text:  iTunes Radio was created by the folks at Apple
num tokens:  10 text:  The native language of Sergey Aksyonov is Russian
num tokens:  6 text:  Simon McCoy is employed by the BBC
num tokens:  9 text:  Lufkin, i

num tokens:  7 text:  Gita Sahgal was born in Mumbai
num tokens:  14 text:  East Japan Railway Company, that was formed in 1894, and the Tokyo
num tokens:  9 text:  The native language of Pierre Blanchar is French
num tokens:  10 text:  La Fontaine's Fables, that originated in France
num tokens:  12 text:  United Launch Alliance, by contrast, is a joint venture between Boeing
num tokens:  11 text:  Paanch was created in the country of his birth, India
num tokens:  10 text:  The capital city of People's Republic of Poland is Warsaw
num tokens:  8 text:  Gaetano Moroni passed away in Rome
num tokens:  9 text:  Camp Academia is located in the continent of Antarctica
num tokens:  18 text:  Giuseppe Meazza Stadium (San Siro) is owned by the city of Milan
num tokens:  14 text:  St Patrick's Athletic F.C. is headquartered in the heart of Dublin
num tokens:  5 text:  Final Fantasy is developed by Square
num tokens:  14 text:  The location of Concordia University is in the heart of the city of 

In [11]:
# Per prompt length: how many distinct prompts landed in the bucket
# (each prompt was appended once per noise run, hence the division by NUM_RUNS).
for num_tok, bucket in enumerate(inputs):
    if bucket:
        print("num tokens: ", num_tok, "number of prompts: ", len(bucket) // NUM_RUNS)

num tokens:  4 number of prompts:  9
num tokens:  5 number of prompts:  39
num tokens:  6 number of prompts:  98
num tokens:  7 number of prompts:  121
num tokens:  8 number of prompts:  133
num tokens:  9 number of prompts:  129
num tokens:  10 number of prompts:  135
num tokens:  11 number of prompts:  107
num tokens:  12 number of prompts:  63
num tokens:  13 number of prompts:  43
num tokens:  14 number of prompts:  47
num tokens:  15 number of prompts:  21
num tokens:  16 number of prompts:  17
num tokens:  17 number of prompts:  14
num tokens:  18 number of prompts:  15
num tokens:  19 number of prompts:  4
num tokens:  20 number of prompts:  1
num tokens:  21 number of prompts:  1
num tokens:  23 number of prompts:  2
num tokens:  27 number of prompts:  1


In [12]:
# Keep only the buckets that received at least one prompt.
inputs = [bucket for bucket in inputs if bucket]

In [13]:
# Process the longest prompts first so the batch size can be tuned on the most
# memory-hungry buckets. Sorting by token count (element 0 of every entry) is
# idempotent, unlike list.reverse(), which flips the order again on every re-run.
inputs.sort(key=lambda bucket: bucket[0][0], reverse=True)

In [14]:
# Causal tracing. For every (layer, position) the clean hidden state saved earlier is
# patched into the noised forward pass, and the recovered target probability p_star_h
# is compared with the noised baseline p_star:
#   results[text]            accumulates  p_star_h - p_star
#   results_perc_impr[text]  accumulates (p_star_h - p_star) / |p - p_star|
# summed over the prompt's noise runs. Each value is (tensor[num_tokens, num_layers], accessory).
BATCH_TOKEN_BUDGET = 1000  # batch size is BATCH_TOKEN_BUDGET // num_tokens
batch_schedule = {i: BATCH_TOKEN_BUDGET // i for i in range(1, 30)}

results = {}
results_perc_impr = {}
prompt_meta = {}      # text -> (tokens, num_layers)
clean_probs = {}      # text -> [p of each run]
corrupted_probs = {}  # text -> [p_star of each run]

with torch.no_grad():  # pure inference: no autograd graphs for the many patched passes
    for inp in inputs:
        print(f"Number of prompts in this bucket: {len(inp)}.")

        bucket_tokens = inp[0][0]
        batch_size = batch_schedule.get(bucket_tokens, max(1, BATCH_TOKEN_BUDGET // bucket_tokens))
        batched_data = DataLoader(inp, batch_size=batch_size)
        for data in batched_data:
            curr_batch_size = data[0].size()[0]
            print("This batch has", curr_batch_size, "examples.")

            num_tokens = data[0][0].item()
            num_layers = data[1][0].item()
            encoded_text = data[2]
            hidden_states = data[3]
            tkens = [[t[i] for t in data[4]] for i in range(curr_batch_size)]
            text = list(data[5])
            target_ids = [t.item() for t in data[6]]
            indx_strs = list(data[7])
            seeds = [seed.item() for seed in data[8]]
            p = [p_i.item() for p_i in data[9]]
            p_stars = [p_s.item() for p_s in data[10]]

            # Record per-prompt statistics once per example. The runs of one prompt can be
            # split across batches, so the accessory tuples are assembled after the loop.
            for i in range(curr_batch_size):
                if text[i] not in results:
                    results[text[i]] = (torch.zeros((num_tokens, num_layers)), None)
                    results_perc_impr[text[i]] = (torch.zeros((num_tokens, num_layers)), None)
                    prompt_meta[text[i]] = (tkens[i], num_layers)
                    clean_probs[text[i]] = []
                    corrupted_probs[text[i]] = []
                clean_probs[text[i]].append(p[i])
                corrupted_probs[text[i]].append(p_stars[i])

            noise_hooks = []  # can probably be batched [will take a look later on]
            for i in range(curr_batch_size):
                noise_hook = Hook(
                    layer_name='transformer->wte',
                    func=hooks.additive_output_noise(indices=indx_strs[i], mean=0, std=0.1, index=i, seed=seeds[i]),
                    key='embedding_noise',
                )
                noise_hooks.append(noise_hook)

            for layer in tqdm(range(num_layers)):
                for position in range(num_tokens):
                    hook = Hook(
                        layer_name=f'transformer->h->{layer}',
                        func=hooks.hidden_patch_hook_fn(layer, position, hidden_states),
                        key=f'patch_{layer}_pos{position}'
                    )
                    output = model(**encoded_text, hooks=noise_hooks + [hook])
                    # Logits are indexed [example, 0, token, vocab]. Score the last token of
                    # every example at once and copy the target probabilities to the host
                    # in a single transfer instead of one .item() sync per example.
                    last_probs = torch.softmax(output["logits"][:, 0, -1, :], dim=-1)
                    rows = torch.arange(curr_batch_size, device=last_probs.device)
                    cols = torch.tensor(target_ids, device=last_probs.device)
                    p_star_h = last_probs[rows, cols].tolist()
                    for i in range(curr_batch_size):
                        gain = p_star_h[i] - p_stars[i]
                        results[text[i]][0][position, layer] += gain
                        if (p[i] - p_stars[i]) != 0:
                            results_perc_impr[text[i]][0][position, layer] += gain / abs(p[i] - p_stars[i])

# Accessory = (text, tokens, num_layers, mean clean p, mean noised p_star) computed over
# this prompt's own runs. Averaging over the whole batch, as before, mixed unrelated prompts.
for key, (prompt_tokens, prompt_layers) in prompt_meta.items():
    accessory = (
        key,
        prompt_tokens,
        prompt_layers,
        sum(clean_probs[key]) / len(clean_probs[key]),
        sum(corrupted_probs[key]) / len(corrupted_probs[key]),
    )
    results[key] = (results[key][0], accessory)
    results_perc_impr[key] = (results_perc_impr[key][0], accessory)

Number of prompts in this bucket: 10.
This batch has 10 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 20.
This batch has 20 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 10.
This batch has 10 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 10.
This batch has 10 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 40.
This batch has 40 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 150.
This batch has 55 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 55 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 40 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 140.
This batch has 58 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 58 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 24 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 170.
This batch has 62 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 62 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 46 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 210.
This batch has 66 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 66 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 66 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 12 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 470.
This batch has 71 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 71 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 71 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 71 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 71 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 71 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 44 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 430.
This batch has 76 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 76 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 76 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 76 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 76 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 50 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 630.
This batch has 83 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 83 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 83 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 83 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 83 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 83 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 83 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 49 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 1070.
This batch has 90 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 80 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 1350.
This batch has 100 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 50 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 1290.
This batch has 111 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 69 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 1330.
This batch has 125 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 80 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 1210.
This batch has 142 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 142 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 142 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 142 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 142 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 142 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 142 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 142 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 74 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 980.
This batch has 166 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 166 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 166 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 166 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 166 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 150 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 390.
This batch has 200 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

This batch has 190 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

Number of prompts in this bucket: 90.
This batch has 90 examples.


  0%|          | 0/48 [00:00<?, ?it/s]

In [15]:
# The tracing cell summed each prompt's effect over its NUM_RUNS noise samples;
# turn those sums into means while keeping the (tensor, accessory) pairing and order.
def _average_over_runs(accumulated):
    """Return [(summed_tensor / NUM_RUNS, accessory), ...] in the dict's insertion order."""
    return [(summed / NUM_RUNS, accessory) for summed, accessory in accumulated.values()]


res = _average_over_runs(results)
res_perc_impr = _average_over_runs(results_perc_impr)

In [16]:
# Persist the averaged indirect effects and their percent-improvement variant.
for payload, out_path in (
    (res, "../data/indirect_effect_1000_examples.pt"),
    (res_perc_impr, "../data/percent_improvement_1000_examples.pt"),
):
    torch.save(payload, out_path)