In [1]:
!nvidia-smi

Mon Jun 27 17:50:51 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000000:4F:00.0 Off |                    0 |
| N/A   33C    P0    82W / 300W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip install unseal -q

In [3]:
!pip install transformers -q

In [4]:
from typing import Union, Tuple, Optional

import matplotlib.pyplot as plt
from matplotlib import cm
from statistics import stdev
from matplotlib.colors import BoundaryNorm
import numpy as np
import random
import json
import torch
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm
import unseal
from unseal.hooks import Hook, HookedModel, common_hooks
from unseal.transformers_util import load_from_pretrained, get_num_layers

import sys
sys.path.append("../../../lib")
import utility
import hooks

In [5]:
unhooked_model, tokenizer, config = load_from_pretrained('EleutherAI/gpt-j-6B')

In [6]:
model = HookedModel(unhooked_model)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)

print(f'Model loaded and on device {device}!')

Model loaded and on device cuda!


In [7]:
# layer names
print(list(model.layers.keys()))

['transformer', 'transformer->wte', 'transformer->drop', 'transformer->h', 'transformer->h->0', 'transformer->h->0->ln_1', 'transformer->h->0->attn', 'transformer->h->0->attn->attn_dropout', 'transformer->h->0->attn->resid_dropout', 'transformer->h->0->attn->k_proj', 'transformer->h->0->attn->v_proj', 'transformer->h->0->attn->q_proj', 'transformer->h->0->attn->out_proj', 'transformer->h->0->mlp', 'transformer->h->0->mlp->fc_in', 'transformer->h->0->mlp->fc_out', 'transformer->h->0->mlp->act', 'transformer->h->0->mlp->dropout', 'transformer->h->1', 'transformer->h->1->ln_1', 'transformer->h->1->attn', 'transformer->h->1->attn->attn_dropout', 'transformer->h->1->attn->resid_dropout', 'transformer->h->1->attn->k_proj', 'transformer->h->1->attn->v_proj', 'transformer->h->1->attn->q_proj', 'transformer->h->1->attn->out_proj', 'transformer->h->1->mlp', 'transformer->h->1->mlp->fc_in', 'transformer->h->1->mlp->fc_out', 'transformer->h->1->mlp->act', 'transformer->h->1->mlp->dropout', 'transf

In [8]:
# load data
f = open("../../../datasets/known_1000.json")
prompts = json.load(f)

In [9]:
# prepare data
texts = [(prompts[i]['subject'], prompts[i]['prompt'] + ' ' + prompts[i]['attribute']) for i in range(len(prompts))]
print(texts[0])

('Vinson Massif', 'Vinson Massif is located in the continent of Antarctica')


In [10]:
### Batching requires inputs to be of the same size so we put each prompt in 
### a bucket depending on the number of tokens it contains.

inputs = [[] for z in range(30)] ## use 30 as an upper bound on the number of tokens
for subject, text in random.sample(texts, 1000):
    encoded_text, target_id = utility.prepare_input(text, tokenizer, device)
    num_tokens = encoded_text['input_ids'].shape[1]
    
    print("num tokens: ", num_tokens, "text: ", text)
    
    # get tokens
    input_ids = encoded_text['input_ids'][0]
    tkens = [tokenizer.decode(input_ids[i]) for i in range(len(input_ids))]
    
    # get position of subject tokens
    subject_positions = utility.get_subject_positions(tkens, subject)
    indx_str = str(subject_positions[0]) + ":" + str(subject_positions[-1]+1)
    
    # add star to subject tokens to later do analysis
    start, end = subject_positions[0], subject_positions[1]
    while start < end:
        tkens[start] = tkens[start] + "*"
        start += 1
    
    # save uncorrupted states for patching
    num_layers = get_num_layers(model, layer_key_prefix='transformer->h')
    output_hooks_names = [f'transformer->h->{layer}' for layer in range(num_layers)]
    output_hooks = [Hook(name, utility.save_output, name) for name in output_hooks_names]
    model(**encoded_text, hooks=output_hooks)
    
    hidden_states = [model.save_ctx[f'transformer->h->{layer}']['output'][0].detach() for layer in range(num_layers)] # need detach
    
    # get probability with uncorrupted subject
    p = model(**encoded_text, hooks=[])['logits'].softmax(dim=-1)[0,-1,target_id].item()
    
    NUM_RUNS = 10 # ROME runs each input 10 times
    for i in range(NUM_RUNS):
        # add noise hook
        seed = np.random.randint(100, size=1)[0]
        noise_hook = Hook(
            layer_name='transformer->wte',
            func=hooks.additive_output_noise(indices=indx_str, mean=0, std=0.025, seed=seed),
            key='embedding_noise',
        )
    
        # get probability with corrupted subject
        p_star = model(**encoded_text, hooks=[noise_hook])['logits'].softmax(dim=-1)[0,-1,target_id].item()
        
        inputs[num_tokens].append((num_tokens, num_layers, encoded_text, hidden_states, tkens, text, target_id, indx_str, seed, p, p_star))


num tokens:  5 text:  Windows Me is developed by Microsoft
num tokens:  6 text:  Bing Maps's owner, Microsoft
num tokens:  6 text:  Patrick Henry College is located in Virginia
num tokens:  7 text:  Gary Lineker is employed by the BBC
num tokens:  11 text:  Nonchan Noriben was created in the country of Japan
num tokens:  5 text:  Hawaii's capital, Honolulu
num tokens:  10 text:  Paul IV holds the title of "the most powerful pope
num tokens:  8 text:  RIA Novosti is written in Russian
num tokens:  9 text:  The capital city of Abbasid Caliphate is Baghdad
num tokens:  10 text:  Le Moniteur Universel is written in French
num tokens:  10 text:  Jan Swammerdam died in the city of Amsterdam
num tokens:  7 text:  Joseph Locke worked in the city of London
num tokens:  9 text:  Galen's expertise is in the field of medicine
num tokens:  7 text:  Namor is affiliated with the Avengers
num tokens:  8 text:  Victoria Derbyshire is employed by the BBC
num tokens:  7 text:  Eavan Boland was born in Du

num tokens:  10 text:  Matias Kupiainen was born in Helsinki
num tokens:  9 text:  Masaccio, who has a citizenship of Italy
num tokens:  5 text:  Denmark's capital, Copenhagen
num tokens:  6 text:  iPhone 3GS is developed by Apple
num tokens:  8 text:  Ray Kurzweil is employed by Google
num tokens:  6 text:  Internet Explorer 10 is developed by Microsoft
num tokens:  13 text:  Gregory of Nazianzus, who has the position of bishop
num tokens:  4 text:  Teen Mom debuted on MTV
num tokens:  10 text:  King David Hotel bombing is located in the city of Jerusalem
num tokens:  11 text:  Partick Thistle F.C. is based in Glasgow
num tokens:  8 text:  Tokyo Mew Mew, that originated in Japan
num tokens:  7 text:  Datsun Sports, developed by Nissan
num tokens:  14 text:  The original language of La Fontaine's Fables is a mixture of French
num tokens:  10 text:  The native language of Anne-Marie Idrac is French
num tokens:  6 text:  Aretha Franklin, playing the piano
num tokens:  9 text:  The offici

num tokens:  8 text:  Pius III died in the city of Rome
num tokens:  14 text:  The location of Concordia University is in the heart of the city of Montreal
num tokens:  14 text:  Leo XIII, whose position is that of a "superior" pope
num tokens:  9 text:  Craig Federighi, who is employed by Apple
num tokens:  11 text:  The native language of Louis-Nicolas Davout is French
num tokens:  4 text:  Ukraine's capital, Kiev
num tokens:  15 text:  The native language of Nathalie Kosciusko-Morizet is French
num tokens:  11 text:  The original language of Face Dances is a mixture of English
num tokens:  6 text:  Jonathan Pearce is employed by the BBC
num tokens:  4 text:  Givenchy originated in Paris
num tokens:  8 text:  Angela Ahrendts is employed by Apple
num tokens:  8 text:  Mount Foster is located in the continent of Antarctica
num tokens:  9 text:  Charles Nodier died in the city of Paris
num tokens:  6 text:  Acura RL is developed by Honda
num tokens:  14 text:  Co-operative Commonwealth 

num tokens:  10 text:  The language of Dwynwen is a mixture of Welsh
num tokens:  8 text:  Delmarva Peninsula was named for the Delaware
num tokens:  7 text:  The Thin Blue Line debuted on the BBC
num tokens:  11 text:  Hidamari Sketch was created in the country of Japan
num tokens:  11 text:  The profession of Martha Nussbaum is to be a philosopher
num tokens:  9 text:  James Northcote died in the city of London
num tokens:  6 text:  Joe Louis Arena, from the Detroit
num tokens:  9 text:  Trento is located in the country of Italy
num tokens:  5 text:  Windows Mobile is developed by Microsoft
num tokens:  9 text:  Sakichi Toyoda is a citizen of Japan
num tokens:  6 text:  Game Boy Advance is developed by Nintendo
num tokens:  11 text:  Georges Ernest Boulanger worked in the city of Paris
num tokens:  8 text:  Mount Discovery is located in the continent of Antarctica
num tokens:  11 text:  Philippus van Limborch worked in the city of Amsterdam
num tokens:  15 text:  Il Gazzettino was wr

num tokens:  14 text:  Marie François Oscar Bardy de Fourtou worked in the city of Paris
num tokens:  16 text:  Jennie Lee, Baroness Lee of Asheridge worked in the city of London
num tokens:  14 text:  The location of Hualapai people is in the northern part of Arizona
num tokens:  8 text:  In Chiasso, the language spoken is Italian
num tokens:  10 text:  Mohsen Mirdamadi is a citizen of Iran
num tokens:  6 text:  Melodiya is headquartered in Moscow
num tokens:  11 text:  Vita Semerenko, who has a citizenship of Ukraine
num tokens:  13 text:  Hope and Anchor, Islington is located in the heart of London
num tokens:  11 text:  Gianni Ferrio, who has a citizenship of Italy
num tokens:  7 text:  tatami, that originated in Japan
num tokens:  11 text:  In Ruokolahti, the language spoken is Finnish
num tokens:  7 text:  Steve Claridge is employed by the BBC
num tokens:  9 text:  Scooby Doo was originally aired on CBS
num tokens:  9 text:  Kaliakra Glacier belongs to the continent of Antarctica

num tokens:  8 text:  Daiki Arioka's profession is an actor
num tokens:  12 text:  Caterina Davinio, who has a citizenship of Italy
num tokens:  15 text:  The headquarter of Minnesota Strikers is located in the heart of downtown Minneapolis
num tokens:  16 text:  Mitsubishi Electric started in the early 1900s as a small company in Tokyo
num tokens:  7 text:  Acura ILX is developed by Honda
num tokens:  11 text:  The native language of Irina Khakamada is Russian
num tokens:  12 text:  Ylvis was created in the country of their birth, Norway
num tokens:  7 text:  Acura ZDX is developed by Honda
num tokens:  12 text:  Willem Wilmink speaks to the media after the Dutch
num tokens:  8 text:  Alfa Romeo 155, produced by Fiat
num tokens:  10 text:  The native language of Roger Garaudy is French
num tokens:  9 text:  Semyon Vorontsov was born in Moscow
num tokens:  12 text:  The headquarter of BC Hydro is in the heart of downtown Vancouver
num tokens:  6 text:  College Football Live premieres o

num tokens:  10 text:  Al-Mutawakkil follows the religion of Islam
num tokens:  10 text:  O'Hare International Airport's owner, the Chicago
num tokens:  18 text:  Birobidzhaner Shtern was written in the late 19th century by a Russian
num tokens:  10 text:  Pervez Musharraf follows the religion of Islam
num tokens:  14 text:  Pius VIII, whose position is that of a "superior" pope
num tokens:  5 text:  SpeedWeek debuted on the ESPN
num tokens:  11 text:  Gilles Marie Oppenord died in the city of Paris
num tokens:  12 text:  Acca of Hexham holds the position of the first female bishop
num tokens:  8 text:  Daniele Franceschini was born in Rome
num tokens:  7 text:  Hyder Ali follows the religion of Islam
num tokens:  9 text:  Kirkpatrick Glacier belongs to the continent of Antarctica
num tokens:  8 text:  Van Cliburn, performing on the piano
num tokens:  20 text:  The original language of Kadhalil Sodhappuvadhu Yeppadi was written in the Tamil
num tokens:  6 text:  Iron Man is affiliated 

num tokens:  13 text:  Kyōto Prefecture, which was named after the city of Kyoto
num tokens:  9 text:  EA-18G Growler, developed by Boeing
num tokens:  10 text:  Galileo Galilei works in the area of astronomy
num tokens:  17 text:  The language used by Pierre de Marca is not the same as that used by the French
num tokens:  8 text:  George Newnes worked in the city of London
num tokens:  12 text:  Juninho Pernambucano professionally plays the sport of soccer
num tokens:  8 text:  Masoretic Text is written in the Hebrew
num tokens:  8 text:  Alexandre Mercereau was born in Paris
num tokens:  6 text:  Gwen Ifill, of PBS
num tokens:  11 text:  Sialkot district is located in the country of Pakistan
num tokens:  11 text:  In Nyon, the language spoken is a mixture of French
num tokens:  10 text:  In Indiana, the language spoken is a mixture of English
num tokens:  11 text:  Canada men's national soccer team is a part of the FIFA
num tokens:  12 text:  North Berwick can be found in the south-w

num tokens:  13 text:  Ritt Bjerregaard, who has a citizenship of Denmark
num tokens:  5 text:  Meet the Press debuted on NBC
num tokens:  13 text:  Akira Toriyama's domain of work is the world of manga
num tokens:  7 text:  In Nokia, the language spoken is Finnish
num tokens:  10 text:  The native language of Valentin Rasputin is Russian
num tokens:  8 text:  Eiko Shimamiya is a citizen of Japan
num tokens:  7 text:  Hungarian Soviet Republic's capital, Budapest
num tokens:  13 text:  Al-Hasakah Governorate is located in the country of Syria
num tokens:  16 text:  Abdalqadir as-Sufi is affiliated with the religion of Islam
num tokens:  7 text:  Peter Fincham is employed by the BBC
num tokens:  8 text:  Gangnam Station's owner, the Seoul
num tokens:  11 text:  Amol Palekar originates from the city of Mumbai
num tokens:  8 text:  Commonwealth of the Philippines's capital, Manila
num tokens:  11 text:  The original language of Il Posto is a mixture of Italian
num tokens:  11 text:  The o

In [11]:
# statistics on number of prompts for each length tokens
for i in range(len(inputs)):
    length = len(inputs[i])
    if length != 0:
        print("num tokens: ", i, "number of prompts: ", int(length/NUM_RUNS))

num tokens:  4 number of prompts:  11
num tokens:  5 number of prompts:  41
num tokens:  6 number of prompts:  94
num tokens:  7 number of prompts:  105
num tokens:  8 number of prompts:  141
num tokens:  9 number of prompts:  138
num tokens:  10 number of prompts:  136
num tokens:  11 number of prompts:  107
num tokens:  12 number of prompts:  59
num tokens:  13 number of prompts:  49
num tokens:  14 number of prompts:  43
num tokens:  15 number of prompts:  22
num tokens:  16 number of prompts:  15
num tokens:  17 number of prompts:  17
num tokens:  18 number of prompts:  10
num tokens:  19 number of prompts:  4
num tokens:  20 number of prompts:  2
num tokens:  21 number of prompts:  2
num tokens:  22 number of prompts:  1
num tokens:  23 number of prompts:  2
num tokens:  25 number of prompts:  1


In [12]:
# remove input buckets with 0 prompts
inputs = [inp for inp in inputs if inp != []]

In [13]:
# deal with longest prompts first :)
inputs.reverse()

In [14]:
results = {}
results_perc_impr = {}
batch_schedule = {i : int(1000/i) for i in range(1, 30)}
for inp in inputs:
    print(f"Number of prompts in this bucket: {len(inp)}.")
    
    batch_size = batch_schedule[inp[0][0]]
    batched_data = DataLoader(inp, batch_size=batch_size)
    for data in batched_data:
        curr_batch_size = data[0].size()[0]
        print("This batch has", curr_batch_size, "examples.")
        
        num_tokens = data[0][0].item()
        num_layers = data[1][0].item()
        encoded_text = data[2]
        hidden_states = data[3]
        tkens = [[t[i] for t in data[4]] for i in range(curr_batch_size)]
        text = [t for t in data[5]]
        target_ids = [t.item() for t in data[6]]
        indx_strs = [s for s in data[7]]
        seeds = [seed.item() for seed in data[8]]
        p = [p.item() for p in data[9]]
        p_stars = [p_s.item() for p_s in data[10]]

        noise_hooks = [] # can probably be batched [will take a look later on]
        for i in range(curr_batch_size):
            noise_hook = Hook(
                layer_name='transformer->wte',
                func=hooks.additive_output_noise(indices=indx_strs[i], mean=0, std=0.025, index=i, seed=seeds[i]),
                key='embedding_noise',
            )
            noise_hooks.append(noise_hook)
            
        for layer in tqdm(range(num_layers)):
            for position in range(num_tokens):
                hook = Hook(
                    layer_name=f'transformer->h->{layer}',
                    func= hooks.hidden_patch_hook_fn(layer, position, hidden_states),
                    key=f'patch_{layer}_pos{position}'
                )
                output = model(**encoded_text, hooks=noise_hooks+[hook])
                for i in range(curr_batch_size):
                    if text[i] not in results:
                        accessory = (text[i], tkens[i], num_layers, sum(p) / len(p), sum(p_stars) / len(p_stars))
                        results[text[i]] = (torch.zeros((num_tokens, num_layers)), accessory)
                        results_perc_impr[text[i]] = (torch.zeros((num_tokens, num_layers)), accessory)
                    p_star_h = torch.softmax(output["logits"][i][0,-1,:], 0)[target_ids[i]].item()
                    results[text[i]][0][position, layer] += p_star_h - p_stars[i]
                    if (p[i] - p_stars[i]) != 0:
                        results_perc_impr[text[i]][0][position, layer] += (p_star_h - p_stars[i]) / abs(p[i] - p_stars[i])

Number of prompts in this bucket: 10.
This batch has 10 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 20.
This batch has 20 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 10.
This batch has 10 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 20.
This batch has 20 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 20.
This batch has 20 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 40.
This batch has 40 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 100.
This batch has 55 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 45 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 170.
This batch has 58 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 58 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 54 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 150.
This batch has 62 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 62 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 26 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 220.
This batch has 66 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 66 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 66 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 22 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 430.
This batch has 71 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 71 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 71 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 71 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 71 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 71 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 4 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 490.
This batch has 76 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 76 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 76 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 76 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 76 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 76 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 34 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 590.
This batch has 83 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 83 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 83 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 83 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 83 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 83 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 83 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 9 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 1070.
This batch has 90 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 90 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 80 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 1360.
This batch has 100 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 100 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 60 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 1380.
This batch has 111 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 111 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 48 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 1410.
This batch has 125 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 125 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 35 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 1050.
This batch has 142 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 142 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 142 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 142 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 142 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 142 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 142 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 56 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 940.
This batch has 166 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 166 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 166 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 166 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 166 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 110 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 410.
This batch has 200 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 200 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

This batch has 10 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

Number of prompts in this bucket: 110.
This batch has 110 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

In [17]:
# Need to divide by NUM_RUNS
res = []
for _, value in results.items():
    result, accessory = value
    result = result / NUM_RUNS
    res.append((result, accessory))
    
res_perc_impr = []
for _, value in results_perc_impr.items():
    result, accessory = value
    result = result / NUM_RUNS
    res_perc_impr.append((result, accessory))

In [18]:
torch.save(res, "../../data/gpt-j/indirect_effect_1000_examples.pt")
torch.save(res_perc_impr, "../../data/gpt-j/percent_improvement_1000_examples.pt")