In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import json
from tqdm.auto import tqdm
import random
import transformers

import os
import sys
sys.path.append('..')

from relations import estimate
from util import model_utils
from baukit import nethook
from operator import itemgetter

In [3]:
MODEL_NAME = "EleutherAI/gpt-j-6B"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B
mt = model_utils.ModelAndTokenizer(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float16)

model = mt.model
tokenizer = mt.tokenizer
tokenizer.pad_token = tokenizer.eos_token

print(f"{MODEL_NAME} ==> device: {model.device}, memory: {model.get_memory_footprint()}")

EleutherAI/gpt-j-6B ==> device: cuda:0, memory: 12219206136


In [4]:
import pandas as pd

df = pd.read_csv("numbers_spelled_out.tsv", delimiter='\t')
objects = list(df['words'])

In [5]:
prompt = """three comes after two
six comes after five
{} comes after"""

# prompt = """grape ends with E
# monitor ends with R
# glass ends with"""

words = [ 'eleven', 'twelve', 'thirteen', 'fourteen']

for w in words:
    txt, ret_dict = model_utils.generate_fast(
        model, tokenizer, 
        prompts=[prompt.format(w)], max_new_tokens=10, 
        get_answer_tokens=True, argmax_greedy=True
    )
    print(f"{w} ===> {ret_dict['answer'][0]['top_token']}")


eleven ===>  ten
twelve ===>  eleven
thirteen ===>  twelve
fourteen ===>  thirteen


In [6]:
objects = [" " + o for o in df['words']]

from relations.corner import CornerEstimator
corner_estimator = CornerEstimator(model=model, tokenizer=tokenizer)

In [7]:
simple_corner = corner_estimator.estimate_simple_corner(objects, scale_up=70)
print(simple_corner.norm().item(), corner_estimator.get_vocab_representation(simple_corner))

62.5 [' twelve', ' seven', ' fourteen', ' fifteen', ' six']


In [8]:
# lin_inv_corner = corner_estimator.estimate_lin_inv_corner(objects, target_logit_value=50)
# print(lin_inv_corner.norm().item(), corner_estimator.get_vocab_representation(lin_inv_corner))

In [9]:
# avg_corner = corner_estimator.estimate_average_corner_with_gradient_descent(objects, average_on=5, target_logit_value=50, verbose=False)
# print(avg_corner.norm().item(), corner_estimator.get_vocab_representation(avg_corner))

In [10]:
def check_with_test_cases(relation_operator):

    test_cases = [
        (objects[i+1], -1, objects[i]) for i in range(len(objects) - 1)
    ]

    for subject, subject_token_index, target in test_cases:
        answer = relation_operator(
            subject,
            subject_token_index=subject_token_index,
            device=model.device,
            return_top_k=5,
        )
        print(f"{subject}, target: {target}   ==>   predicted: {answer}")

In [11]:
relation = estimate.RelationOperator(
    model = model,
    tokenizer = tokenizer,
    relation = prompt,
    layer = 15,
    weight = torch.eye(model.config.n_embd).to(model.dtype).to(model.device),
    bias = simple_corner
)
check_with_test_cases(relation)

 two, target:  one   ==>   predicted: [' seven', ' twelve', ' eight', ' six', ' nine']
 three, target:  two   ==>   predicted: [' twelve', ' twenty', ' ten', ' seven', ' fourteen']
 four, target:  three   ==>   predicted: [' twelve', ' seven', ' ten', ' five', ' twenty']
 five, target:  four   ==>   predicted: [' twelve', ' nine', ' twenty', ' eight', ' seven']
 six, target:  five   ==>   predicted: [' seven', ' twelve', ' eight', ' five', ' ten']
 seven, target:  six   ==>   predicted: [' eight', ' twelve', ' nine', ' eleven', ' twenty']
 eight, target:  seven   ==>   predicted: [' twelve', ' seven', ' nine', ' ten', ' five']
 nine, target:  eight   ==>   predicted: [' twelve', ' seven', ' eight', ' ten', ' twenty']
 ten, target:  nine   ==>   predicted: [' twelve', ' eight', ' seven', ' twenty', ' eleven']
 eleven, target:  ten   ==>   predicted: [' twelve', ' eight', ' seven', ' twenty', ' six']
 twelve, target:  eleven   ==>   predicted: [' seven', ' five', ' eight', ' twelve', ' t

In [12]:
def get_averaged_JB(top_performers, relation_prompt, num_icl = 3, calculate_at_lnf = False):
    try:
        jbs = []
        for s, s_idx, o in tqdm(top_performers):
            others = set(top_performers) - {(s, s_idx, o)}
            others = random.sample(list(others), k = min(num_icl, len(list(others)))) 
            prompt = ""
            prompt += "\n".join(relation_prompt.format(s_other) + f" {o_other}." for s_other, idx_other, o_other in others) + "\n"
            prompt += relation_prompt
            print("subject: ", s)
            print(prompt)

            jb, _ = estimate.relation_operator_from_sample(
                model, tokenizer,
                s, prompt,
                subject_token_index= s_idx,
                layer = 15,
                device = model.device,
                # calculate_at_lnf = calculate_at_lnf
            )
            print(jb.weight.norm(), jb.bias.norm())
            print()
            jbs.append(jb)
        
        weight = torch.stack([jb.weight for jb in jbs]).mean(dim=0)
        bias  = torch.stack([jb.bias for jb in jbs]).mean(dim=0)

        return weight, bias
    except RuntimeError as e:
        if(str(e).startswith("CUDA out of memory")):
            print("CUDA out of memory")
        if(num_icl > 1):
            num_icl -= 1
            print("trying with smaller icl >> ", num_icl)
            return get_averaged_JB(top_performers, relation_prompt, num_icl, calculate_at_lnf)
        else:
            raise Exception("RuntimeError >> can't calculate Jacobian with minimum number of icl examples")

def get_multiple_averaged_JB(top_performers, relation_prompt, N = 3, num_icl = 2, calculate_at_lnf = False):
    weights_and_biases = []
    sample_size = min(len(top_performers), num_icl + 2)
    for _ in tqdm(range(N)):
        cur_sample = random.sample(top_performers, k = sample_size)
        weight, bias = get_averaged_JB(cur_sample, relation_prompt, num_icl, calculate_at_lnf)
        weights_and_biases.append({
            'weight': weight,
            'bias'  : bias
        })
    return weights_and_biases

In [13]:
samples = [
        (objects[i+1], -1, objects[i]) for i in range(len(objects) - 1)
    ]
print(samples)

weights_and_biases = get_multiple_averaged_JB(
    samples, 
    relation_prompt="{} starts with", 
    N = 3, 
    calculate_at_lnf=False
)

[(' two', -1, ' one'), (' three', -1, ' two'), (' four', -1, ' three'), (' five', -1, ' four'), (' six', -1, ' five'), (' seven', -1, ' six'), (' eight', -1, ' seven'), (' nine', -1, ' eight'), (' ten', -1, ' nine'), (' eleven', -1, ' ten'), (' twelve', -1, ' eleven'), (' thirteen', -1, ' twelve'), (' fourteen', -1, ' thirteen'), (' fifteen', -1, ' fourteen'), (' sixteen', -1, ' fifteen'), (' seventeen', -1, ' sixteen'), (' eighteen', -1, ' seventeen'), (' nineteen', -1, ' eighteen'), (' twenty', -1, ' nineteen')]


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

subject:   eighteen
 ten starts with  nine.
 eleven starts with  ten.
{} starts with
tensor(12.2578, device='cuda:0', dtype=torch.float16) tensor(330.5000, device='cuda:0', dtype=torch.float16)

subject:   ten
 eleven starts with  ten.
 two starts with  one.
{} starts with
tensor(5.1523, device='cuda:0', dtype=torch.float16) tensor(371.5000, device='cuda:0', dtype=torch.float16)

subject:   two
 eleven starts with  ten.
 ten starts with  nine.
{} starts with
tensor(6.8242, device='cuda:0', dtype=torch.float16) tensor(363.2500, device='cuda:0', dtype=torch.float16)

subject:   eleven
 ten starts with  nine.
 two starts with  one.
{} starts with
tensor(10.9375, device='cuda:0', dtype=torch.float16) tensor(352.2500, device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:   five
 twelve starts with  eleven.
 seven starts with  six.
{} starts with
tensor(6.3320, device='cuda:0', dtype=torch.float16) tensor(401.7500, device='cuda:0', dtype=torch.float16)

subject:   twelve
 five starts with  four.
 seven starts with  six.
{} starts with
tensor(7.6875, device='cuda:0', dtype=torch.float16) tensor(396.5000, device='cuda:0', dtype=torch.float16)

subject:   four
 twelve starts with  eleven.
 seven starts with  six.
{} starts with
tensor(6.6211, device='cuda:0', dtype=torch.float16) tensor(398.2500, device='cuda:0', dtype=torch.float16)

subject:   seven
 twelve starts with  eleven.
 five starts with  four.
{} starts with
tensor(7.7461, device='cuda:0', dtype=torch.float16) tensor(381., device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:   twenty
 four starts with  three.
 sixteen starts with  fifteen.
{} starts with
tensor(9.1016, device='cuda:0', dtype=torch.float16) tensor(400.7500, device='cuda:0', dtype=torch.float16)

subject:   eighteen
 four starts with  three.
 twenty starts with  nineteen.
{} starts with
tensor(10.4141, device='cuda:0', dtype=torch.float16) tensor(382.7500, device='cuda:0', dtype=torch.float16)

subject:   four
 eighteen starts with  seventeen.
 twenty starts with  nineteen.
{} starts with
tensor(10.9531, device='cuda:0', dtype=torch.float16) tensor(382., device='cuda:0', dtype=torch.float16)

subject:   sixteen
 eighteen starts with  seventeen.
 four starts with  three.
{} starts with
tensor(13.7188, device='cuda:0', dtype=torch.float16) tensor(379.7500, device='cuda:0', dtype=torch.float16)



In [14]:
prompt

'three comes after two\nsix comes after five\n{} comes after'

In [17]:
relation_operator = estimate.RelationOperator(
    model = model,
    tokenizer= tokenizer,
    relation = prompt,
    layer = 15,
    weight = torch.stack(
        [wb['weight'] for wb in weights_and_biases]
    ).mean(dim=0),
    # bias = torch.stack(
    #     [wb['bias'] for wb in weights_and_biases]
    # ).mean(dim=0),
    bias = simple_corner
)

check_with_test_cases(relation_operator)

 two, target:  one   ==>   predicted: [' twelve', ' seven', ' six', ' eight', ' five']
 three, target:  two   ==>   predicted: [' twelve', ' seven', ' six', ' eight', ' five']
 four, target:  three   ==>   predicted: [' twelve', ' seven', ' five', ' six', ' eight']
 five, target:  four   ==>   predicted: [' twelve', ' five', ' seven', ' six', ' eight']
 six, target:  five   ==>   predicted: [' twelve', ' six', ' seven', ' eight', ' five']
 seven, target:  six   ==>   predicted: [' seven', ' twelve', ' six', ' eight', ' five']
 eight, target:  seven   ==>   predicted: [' twelve', ' eight', ' seven', ' nine', ' six']
 nine, target:  eight   ==>   predicted: [' twelve', ' nine', ' seven', ' eight', ' ten']
 ten, target:  nine   ==>   predicted: [' twelve', ' ten', ' eleven', ' fourteen', ' nine']
 eleven, target:  ten   ==>   predicted: [' twelve', ' eleven', ' fourteen', ' thirteen', ' fifteen']
 twelve, target:  eleven   ==>   predicted: [' twelve', ' thirteen', ' fourteen', ' eleven', 

In [16]:
corner_estimator.get_vocab_representation(
    torch.stack(
        [wb['bias'] for wb in weights_and_biases]
    ).mean(dim=0)
)

[' ', '.', '\n', ' three', ' seven']