In [1]:
# Show the GPUs visible to this machine (and their memory use) before the model is loaded.
!nvidia-smi

Sat Jun 25 16:32:12 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 510.47.03    Driver Version: 510.47.03    CUDA Version: 11.6     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA A100 80G...  On   | 00000000:52:00.0 Off |                    0 |
| N/A   52C    P0    87W / 300W |      0MiB / 81920MiB |      0%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA A100 80G...  On   | 00000000:56:00.0 Off |                    0 |
| N/A   32C    P0    53W / 300W |      0MiB / 81920MiB |      0%      Default |
|       

In [2]:
!pip install unseal -q

In [3]:
!pip install transformers -q

In [4]:
from typing import Union, Tuple, Optional

import matplotlib.pyplot as plt
from matplotlib import cm
from statistics import stdev
from matplotlib.colors import BoundaryNorm
import numpy as np
import random
import json
import torch
from tqdm.notebook import tqdm
import unseal
from unseal.hooks import Hook, HookedModel, common_hooks
from unseal.transformers_util import load_from_pretrained, get_num_layers

import sys
sys.path.append("../")
import utility
import hooks

In [5]:
# Load the pretrained model together with its tokenizer and config.
pretrained_name = 'EleutherAI/gpt-j-6B'
unhooked_model, tokenizer, config = load_from_pretrained(pretrained_name)

In [6]:
# Pick the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Wrap the raw model so hooks can be attached, then move it to the chosen device.
model = HookedModel(unhooked_model)
model.to(device)

print(f'Model loaded and on device {device}!')

Model loaded and on device cuda!


In [7]:
# Names under which the hooked model exposes its layers (these are the
# strings used as `layer_name` when building hooks further down).
layer_names = list(model.layers.keys())
print(layer_names)

['transformer', 'transformer->wte', 'transformer->drop', 'transformer->h', 'transformer->h->0', 'transformer->h->0->ln_1', 'transformer->h->0->attn', 'transformer->h->0->attn->attn_dropout', 'transformer->h->0->attn->resid_dropout', 'transformer->h->0->attn->k_proj', 'transformer->h->0->attn->v_proj', 'transformer->h->0->attn->q_proj', 'transformer->h->0->attn->out_proj', 'transformer->h->0->mlp', 'transformer->h->0->mlp->fc_in', 'transformer->h->0->mlp->fc_out', 'transformer->h->0->mlp->act', 'transformer->h->0->mlp->dropout', 'transformer->h->1', 'transformer->h->1->ln_1', 'transformer->h->1->attn', 'transformer->h->1->attn->attn_dropout', 'transformer->h->1->attn->resid_dropout', 'transformer->h->1->attn->k_proj', 'transformer->h->1->attn->v_proj', 'transformer->h->1->attn->q_proj', 'transformer->h->1->attn->out_proj', 'transformer->h->1->mlp', 'transformer->h->1->mlp->fc_in', 'transformer->h->1->mlp->fc_out', 'transformer->h->1->mlp->act', 'transformer->h->1->mlp->dropout', 'transf

In [8]:
# load data
# Open the file in a context manager so the handle is closed as soon as the
# JSON is parsed (the bare open() previously left it open for the kernel's
# lifetime), and read it as UTF-8, the standard encoding for JSON.
with open("../../data/known_1000.json", encoding="utf-8") as f:
    prompts = json.load(f)

In [9]:
# prepare data
# Each entry becomes (subject, full statement), where the full statement is the
# prompt followed by a space and the attribute.
texts = []
for entry in prompts:
    statement = entry['prompt'] + ' ' + entry['attribute']
    texts.append((entry['subject'], statement))
print(texts[0])

('Vinson Massif', 'Vinson Massif is located in the continent of Antarctica')


In [19]:
### Batching requires all inputs in a batch to have the same size, so each
### prompt is put in a bucket indexed by the number of tokens it contains.
### Every bucket entry is one (prompt, noise seed) run stored as the tuple
### (num_tokens, num_layers, encoded_text, hidden_states, tkens, text,
###  target_id, indx_str, seed, p, p_star).

NUM_RUNS = 10 # ROME runs each input 10 times

# The layer count does not depend on the prompt, so look it up once.
num_layers = get_num_layers(model, layer_key_prefix='transformer->h')
output_hooks_names = [f'transformer->h->{layer}' for layer in range(num_layers)]

inputs = [[] for _ in range(30)] ## use 30 as an upper bound on the number of tokens
for subject, text in random.sample(texts, 1):
    encoded_text, target_id = utility.prepare_input(text, tokenizer, device)
    num_tokens = encoded_text['input_ids'].shape[1]

    print("num tokens: ", num_tokens, "text: ", text)

    # get tokens
    input_ids = encoded_text['input_ids'][0]
    tkens = [tokenizer.decode(input_ids[i]) for i in range(len(input_ids))]

    # get position of subject tokens
    # NOTE(review): assumes get_subject_positions returns the subject's token
    # indices in ascending order, as the slice built below implies -- confirm.
    subject_positions = utility.get_subject_positions(tkens, subject)
    subject_start, subject_end = subject_positions[0], subject_positions[-1] + 1
    indx_str = str(subject_start) + ":" + str(subject_end)

    # add star to subject tokens to later do analysis
    # Mark exactly the [subject_start, subject_end) span that is handed to the
    # noise hook. The previous bound, subject_positions[1], raised IndexError
    # for single-token subjects and starred too few tokens for longer ones.
    for pos in range(subject_start, subject_end):
        tkens[pos] = tkens[pos] + "*"

    # save uncorrupted states for patching
    output_hooks = [Hook(name, utility.save_output, name) for name in output_hooks_names]
    model(**encoded_text, hooks=output_hooks)

    # Read back what the hooks stored in model.save_ctx, one entry per layer.
    hidden_states = [model.save_ctx[f'transformer->h->{layer}']['output'][0].detach() for layer in range(num_layers)] # need detach

    # get probability with uncorrupted subject
    p = model(**encoded_text, hooks=[])['logits'].softmax(dim=-1)[0,-1,target_id].item()

    for i in range(NUM_RUNS):
        # add noise hook
        # NOTE(review): seeds are drawn from only 100 values, so two runs of the
        # same prompt can share a seed -- confirm whether that is acceptable.
        seed = np.random.randint(100, size=1)[0]
        noise_hook = Hook(
            layer_name='transformer->wte',
            func=hooks.additive_output_noise(indices=indx_str, mean=0, std=0.1, seed=seed),
            key='embedding_noise',
        )

        # get probability with corrupted subject
        p_star = model(**encoded_text, hooks=[noise_hook])['logits'].softmax(dim=-1)[0,-1,target_id].item()

        inputs[num_tokens].append((num_tokens, num_layers, encoded_text, hidden_states, tkens, text, target_id, indx_str, seed, p, p_star))


num tokens:  6 text:  iPod Touch is developed by Apple


In [20]:
# statistics on number of prompts for each length tokens
# A bucket holds NUM_RUNS entries per prompt, so dividing its size by NUM_RUNS
# gives the number of distinct prompts of that token length.
for token_count, bucket in enumerate(inputs):
    bucket_size = len(bucket)
    if bucket_size == 0:
        continue
    print("num tokens: ", token_count, "number of prompts: ", int(bucket_size/NUM_RUNS))

num tokens:  6 number of prompts:  1


In [21]:
# remove input buckets with 0 prompts
# Keep only the buckets that actually received at least one entry.
inputs = [bucket for bucket in inputs if len(bucket) > 0]

In [13]:
# deal with longest prompts first :)
# Sort by each bucket's token count (field 0 of its first entry) in descending
# order. Unlike an in-place reverse(), re-running this cell leaves the order
# unchanged instead of flipping it back to shortest-first.
inputs = sorted(
    inputs,
    key=lambda bucket: bucket[0][0] if bucket else 0,  # empty buckets (if any remain) go last
    reverse=True,
)

In [39]:
from torch.utils.data import DataLoader

# Average the clean probability p and the corrupted probability p_star per
# prompt over all of that prompt's runs. Computing this up front, rather than
# averaging over whichever examples share a batch, keeps the stored values
# correct when a batch mixes several prompts or a prompt's runs span batches.
# Entry layout: index 5 = text, 9 = p, 10 = p_star.
prompt_probs = {}
for bucket in inputs:
    for item in bucket:
        clean_probs, corrupted_probs = prompt_probs.setdefault(item[5], ([], []))
        clean_probs.append(item[9])
        corrupted_probs.append(item[10])
prompt_means = {
    prompt: (sum(clean) / len(clean), sum(corrupted) / len(corrupted))
    for prompt, (clean, corrupted) in prompt_probs.items()
}

# results maps prompt text -> (effect tensor of shape (num_tokens, num_layers),
# (tokens, num_layers, mean p, mean p_star)). The tensor accumulates a SUM over
# runs of (patched probability - p_star); it is divided by NUM_RUNS afterwards.
results = {}
# Batch size shrinks as prompts get longer (1250 // token count).
batch_schedule = {i : int(1250/i) for i in range(1, 30)}
for inp in inputs:
    print(f"Number of prompts in this bucket: {len(inp)}.")

    batch_size = batch_schedule[inp[0][0]]
    batched_data = DataLoader(inp, batch_size=batch_size)
    for data in batched_data:
        curr_batch_size = data[0].size()[0]
        print("This batch has", curr_batch_size, "examples.")

        # Unpack the collated tuple fields back into per-example values.
        num_tokens = data[0][0].item()
        num_layers = data[1][0].item()
        encoded_text = data[2]
        hidden_states = data[3]
        # Tokens arrive grouped by position; regroup them per example.
        tkens = [[t[i] for t in data[4]] for i in range(curr_batch_size)]
        text = [t for t in data[5]]
        target_ids = [t.item() for t in data[6]]
        indx_strs = [s for s in data[7]]
        seeds = [seed.item() for seed in data[8]]
        p_stars = [p_s.item() for p_s in data[10]]

        # One noise hook per example, each re-using the seed of the run that
        # produced that example's p_star.
        noise_hooks = [] # can probably be batched [will take a look later on]
        for i in range(curr_batch_size):
            noise_hook = Hook(
                layer_name='transformer->wte',
                func=hooks.additive_output_noise(indices=indx_strs[i], mean=0, std=0.1, index=i, seed=seeds[i]),
                key='embedding_noise',
            )
            noise_hooks.append(noise_hook)

        # Patch the saved clean hidden state into every (layer, position) in turn.
        for layer in tqdm(range(num_layers)):
            for position in range(num_tokens):
                hook = Hook(
                    layer_name=f'transformer->h->{layer}',
                    func= hooks.hidden_patch_hook_fn(layer, position, hidden_states),
                    key=f'patch_{layer}_pos{position}'
                )
                # Only forward values are needed, so skip autograd bookkeeping.
                with torch.no_grad():
                    output = model(**encoded_text, hooks=noise_hooks+[hook])
                for i in range(curr_batch_size):
                    if text[i] not in results:
                        mean_p, mean_p_star = prompt_means[text[i]]
                        results[text[i]] = (torch.zeros((num_tokens, num_layers)), (tkens[i], num_layers, mean_p, mean_p_star))
                    # Probability of the target at the last position, relative
                    # to this run's corrupted baseline.
                    results[text[i]][0][position, layer] += torch.softmax(output["logits"][i][0,-1,:], 0)[target_ids[i]].item() - p_stars[i]

Number of prompts in this bucket: 10.
This batch has 10 examples.


  0%|          | 0/28 [00:00<?, ?it/s]

In [40]:
# Need to divide by NUM_RUNS
# The per-prompt tensors hold sums over runs; turn them into means and keep
# the accompanying metadata tuple untouched.
res = []
for summed_effect, prompt_info in results.values():
    res.append((summed_effect / NUM_RUNS, prompt_info))

In [41]:
# Persist the averaged results to "data.pt" in the current working directory.
torch.save(res, "data.pt")

In [42]:
# Reload the results from "data.pt".
# NOTE(review): torch.load relies on pickle; only load files produced by this
# notebook or another trusted source.
res = torch.load("data.pt")

In [43]:
# Display the reloaded results.
# NOTE(review): this prints every tensor in full; consider showing only shapes or a slice.
res

[(tensor([[ 1.4021e-03,  1.8721e-03,  1.0469e-03,  3.5102e-04,  1.0320e-04,
            7.8951e-05,  1.2838e-06, -1.9918e-05,  4.8063e-05,  9.1360e-05,
            7.8987e-05,  7.9560e-05,  4.5273e-05,  1.9115e-05,  1.0897e-04,
            4.7719e-05,  3.4097e-05,  8.0088e-05,  4.1766e-05,  2.6896e-05,
           -8.9898e-07, -1.1589e-05,  1.1699e-05,  1.8563e-05,  6.5224e-06,
            2.9394e-05,  1.2499e-06,  4.4009e-08],
          [ 6.6839e-05,  1.8172e-02,  1.8861e-02,  3.3730e-02,  2.0518e-02,
            2.0796e-02,  4.2322e-02,  4.3422e-02,  3.6366e-02,  1.8828e-02,
            2.1778e-02,  1.9032e-02,  1.6895e-02,  1.7516e-02,  1.3520e-02,
            1.4064e-02,  5.9119e-03,  5.0641e-03,  5.1391e-03,  4.9349e-03,
            2.2478e-03,  1.9256e-03,  1.4715e-03,  1.8158e-05, -4.7110e-06,
            1.5572e-05,  4.9942e-06,  4.4009e-08],
          [ 5.4772e-03,  4.2587e-02,  5.2087e-02,  1.3120e-01,  1.3283e-01,
            1.9141e-01,  1.6652e-01,  1.8327e-01,  1.6097e-01,