In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import json
from tqdm.auto import tqdm
import random
import transformers

import os
import sys
sys.path.append('..')

from relations import estimate
from util import model_utils
from baukit import nethook
from operator import itemgetter

In [3]:
MODEL_NAME = "EleutherAI/gpt-j-6B"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B
mt = model_utils.ModelAndTokenizer(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float16)

model = mt.model
tokenizer = mt.tokenizer
tokenizer.pad_token = tokenizer.eos_token

print(f"{MODEL_NAME} ==> device: {model.device}, memory: {model.get_memory_footprint()}")

EleutherAI/gpt-j-6B ==> device: cuda:0, memory: 12219206136


In [4]:
import pandas as pd

df = pd.read_csv("numbers_spelled_out.tsv", delimiter='\t')
objects = list(df['words'])

In [5]:
prompt = """three comes after two
six comes after five
{} comes after"""

# prompt = """grape ends with E
# monitor ends with R
# glass ends with"""

words = [ 'eleven', 'twelve', 'thirteen', 'fourteen']

for w in words:
    txt, ret_dict = model_utils.generate_fast(
        model, tokenizer, 
        prompts=[prompt.format(w)], max_new_tokens=10, 
        get_answer_tokens=True, argmax_greedy=True
    )
    print(f"{w} ===> {ret_dict['answer'][0]['top_token']}")


eleven ===>  ten
twelve ===>  eleven
thirteen ===>  twelve
fourteen ===>  thirteen


In [6]:
objects = [" " + o for o in df['words']]

from relations.corner import CornerEstimator
corner_estimator = CornerEstimator(model=model, tokenizer=tokenizer)

In [10]:
simple_corner = corner_estimator.estimate_simple_corner(objects, scale_up=60)
print(simple_corner.norm().item(), corner_estimator.get_vocab_representation(simple_corner, get_logits=True))

53.5625 [(' twelve', 113.188), (' seven', 110.25), (' fourteen', 110.125), (' fifteen', 109.375), (' six', 109.312)]


In [18]:
lst_sq_corner = corner_estimator.estimate_corner_lstsq_solve(objects, target_logit=40)
print(lst_sq_corner.norm().item(), corner_estimator.get_vocab_representation(lst_sq_corner, get_logits=True))

46.75 [(' ten', 100.562), (' eight', 100.5), (' four', 100.438), (' two', 100.438), (' one', 100.438)]


In [19]:
# avg_corner = corner_estimator.estimate_average_corner_with_gradient_descent(objects, average_on=5, target_logit_value=50, verbose=False)
# print(avg_corner.norm().item(), corner_estimator.get_vocab_representation(avg_corner))

In [20]:
def check_with_test_cases(relation_operator):

    test_cases = [
        (objects[i+1], -1, objects[i]) for i in range(len(objects) - 1)
    ]

    for subject, subject_token_index, target in test_cases:
        answer = relation_operator(
            subject,
            subject_token_index=subject_token_index,
            device=model.device,
            return_top_k=5,
        )
        print(f"{subject}, target: {target}   ==>   predicted: {answer}")

In [22]:
relation = estimate.RelationOperator(
    model = model,
    tokenizer = tokenizer,
    relation = prompt,
    layer = 15,
    weight = torch.eye(model.config.n_embd).to(model.dtype).to(model.device),
    bias = simple_corner
)
check_with_test_cases(relation)

 two, target:  one   ==>   predicted: [' seven', ' twelve', ' eight', ' six', ' nine']
 three, target:  two   ==>   predicted: [' twelve', ' ten', ' twenty', ' seven', ' fourteen']
 four, target:  three   ==>   predicted: [' twelve', ' ten', ' seven', ' five', ' twenty']
 five, target:  four   ==>   predicted: [' twelve', ' nine', ' twenty', ' eight', ' eleven']
 six, target:  five   ==>   predicted: [' seven', ' eight', ' five', ' ten', ' twelve']
 seven, target:  six   ==>   predicted: [' eight', ' twelve', ' nine', ' eleven', ' twenty']
 eight, target:  seven   ==>   predicted: [' twelve', ' seven', ' nine', ' ten', ' five']
 nine, target:  eight   ==>   predicted: [' twelve', ' seven', ' eight', ' ten', ' twenty']
 ten, target:  nine   ==>   predicted: [' twelve', ' eight', ' twenty', ' seven', ' eleven']
 eleven, target:  ten   ==>   predicted: [' twelve', ' eight', ' seven', ' twenty', ' six']
 twelve, target:  eleven   ==>   predicted: [' seven', ' five', ' eight', ' twelve', ' 

In [30]:
def get_averaged_JB(top_performers, relation_prompt, num_icl = 3, calculate_at_lnf = False):
    try:
        jbs = []
        for s, s_idx, o in tqdm(top_performers):
            others = set(top_performers) - {(s, s_idx, o)}
            others = random.sample(list(others), k = min(num_icl, len(list(others)))) 
            prompt = ""
            prompt += "\n".join(relation_prompt.format(s_other) + f"{o_other}." for s_other, idx_other, o_other in others) + "\n"
            prompt += " " + relation_prompt
            print("subject: ", s)
            print(prompt)

            jb, _ = estimate.relation_operator_from_sample(
                model, tokenizer,
                s, prompt,
                subject_token_index= s_idx,
                layer = 15,
                device = model.device,
                # calculate_at_lnf = calculate_at_lnf
            )
            print(jb.weight.norm(), jb.bias.norm())
            print()
            jbs.append(jb)
        
        weight = torch.stack([jb.weight for jb in jbs]).mean(dim=0)
        bias  = torch.stack([jb.bias for jb in jbs]).mean(dim=0)

        return weight, bias
    except RuntimeError as e:
        if(str(e).startswith("CUDA out of memory")):
            print("CUDA out of memory")
        if(num_icl > 1):
            num_icl -= 1
            print("trying with smaller icl >> ", num_icl)
            return get_averaged_JB(top_performers, relation_prompt, num_icl, calculate_at_lnf)
        else:
            raise Exception("RuntimeError >> can't calculate Jacobian with minimum number of icl examples")

def get_multiple_averaged_JB(top_performers, relation_prompt, N = 3, num_icl = 2, calculate_at_lnf = False):
    weights_and_biases = []
    sample_size = min(len(top_performers), num_icl + 2)
    for _ in tqdm(range(N)):
        cur_sample = random.sample(top_performers, k = sample_size)
        weight, bias = get_averaged_JB(cur_sample, relation_prompt, num_icl, calculate_at_lnf)
        weights_and_biases.append({
            'weight': weight,
            'bias'  : bias
        })
    return weights_and_biases

In [31]:
samples = [
        (objects[i+1], -1, objects[i]) for i in range(len(objects) - 1)
    ]
print(samples)

weights_and_biases = get_multiple_averaged_JB(
    samples, 
    relation_prompt="{} comes after", 
    N = 3, 
    calculate_at_lnf=False
)

[(' two', -1, ' one'), (' three', -1, ' two'), (' four', -1, ' three'), (' five', -1, ' four'), (' six', -1, ' five'), (' seven', -1, ' six'), (' eight', -1, ' seven'), (' nine', -1, ' eight'), (' ten', -1, ' nine'), (' eleven', -1, ' ten'), (' twelve', -1, ' eleven'), (' thirteen', -1, ' twelve'), (' fourteen', -1, ' thirteen'), (' fifteen', -1, ' fourteen'), (' sixteen', -1, ' fifteen'), (' seventeen', -1, ' sixteen'), (' eighteen', -1, ' seventeen'), (' nineteen', -1, ' eighteen'), (' twenty', -1, ' nineteen')]


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

subject:   twelve
 thirteen comes after twelve.
 ten comes after nine.
 {} comes after
tensor(7.0117, device='cuda:0', dtype=torch.float16) tensor(340.7500, device='cuda:0', dtype=torch.float16)

subject:   thirteen
 ten comes after nine.
 twelve comes after eleven.
 {} comes after
tensor(15.9766, device='cuda:0', dtype=torch.float16) tensor(303.7500, device='cuda:0', dtype=torch.float16)

subject:   ten
 thirteen comes after twelve.
 twelve comes after eleven.
 {} comes after
tensor(16.6719, device='cuda:0', dtype=torch.float16) tensor(356.7500, device='cuda:0', dtype=torch.float16)

subject:   sixteen
 ten comes after nine.
 thirteen comes after twelve.
 {} comes after
tensor(16.0469, device='cuda:0', dtype=torch.float16) tensor(331., device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:   seven
 fifteen comes after fourteen.
 four comes after three.
 {} comes after
tensor(19.5156, device='cuda:0', dtype=torch.float16) tensor(383., device='cuda:0', dtype=torch.float16)

subject:   four
 five comes after four.
 fifteen comes after fourteen.
 {} comes after
tensor(4.6523, device='cuda:0', dtype=torch.float16) tensor(401.5000, device='cuda:0', dtype=torch.float16)

subject:   five
 four comes after three.
 seven comes after six.
 {} comes after
tensor(15.0312, device='cuda:0', dtype=torch.float16) tensor(363.5000, device='cuda:0', dtype=torch.float16)

subject:   fifteen
 four comes after three.
 five comes after four.
 {} comes after
tensor(17.9062, device='cuda:0', dtype=torch.float16) tensor(369.2500, device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:   sixteen
 six comes after five.
 nine comes after eight.
 {} comes after
tensor(17.9062, device='cuda:0', dtype=torch.float16) tensor(312.5000, device='cuda:0', dtype=torch.float16)

subject:   nine
 seventeen comes after sixteen.
 sixteen comes after fifteen.
 {} comes after
tensor(16.5938, device='cuda:0', dtype=torch.float16) tensor(399.7500, device='cuda:0', dtype=torch.float16)

subject:   seventeen
 six comes after five.
 nine comes after eight.
 {} comes after
tensor(20.6719, device='cuda:0', dtype=torch.float16) tensor(354.7500, device='cuda:0', dtype=torch.float16)

subject:   six
 seventeen comes after sixteen.
 sixteen comes after fifteen.
 {} comes after
tensor(6.2773, device='cuda:0', dtype=torch.float16) tensor(386., device='cuda:0', dtype=torch.float16)



In [32]:
relation_operator = estimate.RelationOperator(
    model = model,
    tokenizer= tokenizer,
    relation = prompt,
    layer = 15,
    weight = torch.stack(
        [wb['weight'] for wb in weights_and_biases]
    ).mean(dim=0),
    # bias = torch.stack(
    #     [wb['bias'] for wb in weights_and_biases]
    # ).mean(dim=0),
    bias = lst_sq_corner
)

check_with_test_cases(relation_operator)

 two, target:  one   ==>   predicted: [' seven', ' eight', ' nine', ' five', ' eleven']
 three, target:  two   ==>   predicted: [' nine', ' eight', ' seven', ' twelve', ' eleven']
 four, target:  three   ==>   predicted: [' four', ' nine', ' eight', ' five', ' seven']
 five, target:  four   ==>   predicted: [' five', ' ten', ' seven', ' eight', ' six']
 six, target:  five   ==>   predicted: [' six', ' seven', ' nine', ' five', ' twelve']
 seven, target:  six   ==>   predicted: [' seven', ' eight', ' six', ' nine', ' three']
 eight, target:  seven   ==>   predicted: [' nine', ' eight', ' seven', ' four', ' three']
 nine, target:  eight   ==>   predicted: [' nine', ' eleven', ' eight', ' four', ' seven']
 ten, target:  nine   ==>   predicted: [' eleven', ' ten', ' nine', ' five', ' one']
 eleven, target:  ten   ==>   predicted: [' eleven', ' twelve', ' thirteen', ' ten', ' one']
 twelve, target:  eleven   ==>   predicted: [' eleven', ' twelve', ' thirteen', ' four', ' seven']
 thirteen, 

In [32]:
corner_estimator.get_vocab_representation(
    torch.stack(
        [wb['bias'] for wb in weights_and_biases]
    ).mean(dim=0)
)

[' ', '.', ' one', ' three', ' two']