In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import json
from tqdm.auto import tqdm
import random
import transformers

import os
import sys
sys.path.append('..')

from relations import estimate
from util import model_utils
from baukit import nethook
from operator import itemgetter

In [3]:
MODEL_NAME = "EleutherAI/gpt-j-6B"  # alternatives: gpt2-{medium,large,xl}

# Load the model and tokenizer in float16 through the project's wrapper.
mt = model_utils.ModelAndTokenizer(
    MODEL_NAME,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
model, tokenizer = mt.model, mt.tokenizer

# Pad with the EOS token.
tokenizer.pad_token = tokenizer.eos_token

print(f"{MODEL_NAME} ==> device: {model.device}, memory: {model.get_memory_footprint()}")

EleutherAI/gpt-j-6B ==> device: cuda:0, memory: 12219206136


In [12]:
# Two worked examples, then the probe word substituted into `{}`.
prompt = """grape starts with G
monitor starts with M
{} starts with"""

words = ['month', 'major', 'star', 'areas', 'future', 'space', 'committee', 'london', 'washington', 'meeting']

# prompt = """grape ends with E
# monitor ends with R
# glass ends with"""

# Greedy-decode each probe and show the top answer token the helper reports.
for word in words:
    _generated_text, generation = model_utils.generate_fast(
        model,
        tokenizer,
        prompts=[prompt.format(word)],
        max_new_tokens=10,
        get_answer_tokens=True,
        argmax_greedy=True,
    )
    top_token = generation['answer'][0]['top_token']
    print(f"{word} ===> {top_token}")


month ===>  M
major ===>  A
star ===>  S
areas ===>  A
future ===>  F
space ===>  S
committee ===>  C
london ===>  L
washington ===>  W
meeting ===>  M


In [13]:
# prompt = """grape is spelled G-R-A-P-E
# monitor is spelled M-O-N-I-T-O-R
# {} is spelled"""
    
# txt, ret_dict = model_utils.generate_fast(
#     model, tokenizer, 
#     prompts=[prompt.format('arnab')], max_new_tokens=20, 
#     get_answer_tokens=True, argmax_greedy=True
# )
# txt

In [14]:
from relations.corner import CornerEstimator

# Corner estimator bound to the loaded model/tokenizer; used by the corner cells that follow.
corner_estimator = CornerEstimator(tokenizer=tokenizer, model=model)

In [15]:
import string

# Candidate answers: each uppercase letter with a leading space (" A" ... " Z"),
# presumably to match the space-prefixed letter tokens the model predicts.
objects = [f" {letter}" for letter in string.ascii_uppercase]

In [25]:
# Simple corner over the letter tokens; scale_up=70 is the value chosen here (TODO: document why).
simple_corner = corner_estimator.estimate_simple_corner(objects, scale_up=70)
simple_corner_norm = simple_corner.norm().item()
print(
    simple_corner_norm,
    corner_estimator.get_vocab_representation(simple_corner, get_logits=True),
)

66.875 [(' C', 117.062), (' S', 117.0), (' M', 116.812), (' L', 116.688), (' P', 116.438)]


In [30]:
# Corner from the estimator's linear-inverse method, asked for a target logit of 50.
lin_inv_corner = corner_estimator.estimate_lin_inv_corner(objects, target_logit_value=50)
lin_inv_corner_norm = lin_inv_corner.norm().item()
print(
    lin_inv_corner_norm,
    corner_estimator.get_vocab_representation(lin_inv_corner, get_logits=True),
)

52.5 [(' S', 46.125), (' C', 46.0), (' D', 45.781), (' P', 45.062), (' R', 45.031)]


In [34]:
# Corner from the estimator's least-squares solve, asked for a target logit of 40.
# NOTE(review): the recorded top logits for this corner are ~109, far above target_logit=40,
# and favour rare letters (Q, X, I, ...) -- worth checking the solve before relying on it.
lst_sq_corner = corner_estimator.estimate_corner_lstsq_solve(objects, target_logit=40)
lst_sq_corner_norm = lst_sq_corner.norm().item()
print(
    lst_sq_corner_norm,
    corner_estimator.get_vocab_representation(lst_sq_corner, get_logits=True),
)

42.5625 [(' Q', 109.312), (' X', 109.188), (' I', 109.188), (' Y', 109.125), (' O', 109.125)]


In [47]:
# Display the probe-word list defined alongside the first-letter prompt.
words

['month',
 'major',
 'star',
 'areas',
 'future',
 'space',
 'committee',
 'london',
 'washington',
 'meeting']

In [48]:
def check_with_test_cases(relation_operator, word_list = None, device = None):
    """Print the operator's top-5 predictions for each word next to its first-letter target.

    Each word ``w`` becomes the case ``(w, -1, w[0].upper())``. The operator is called as
    ``relation_operator(w, subject_token_index=-1, device=device, return_top_k=5)`` and its
    result is printed as the prediction.

    Args:
        relation_operator: callable with the call signature above.
        word_list: subjects to probe. ``None`` (default) means the notebook-level ``words``
            list, looked up at call time. A default bound when this cell ran would go stale
            once the ``words`` cell is re-run.
        device: forwarded to the operator; ``None`` means ``model.device``.

    Returns:
        Fraction of probed words whose first prediction (whitespace-stripped) equals the
        target letter; 0.0 when nothing was probed. Empty strings are skipped, since they
        have no first letter.
    """
    if word_list is None:
        word_list = words
    if device is None:
        device = model.device

    test_cases = [(w, -1, w[0].upper()) for w in word_list if w]

    hits = 0
    for subject, subject_token_index, target in test_cases:
        # Named `predicted` so it does not shadow the notebook-level `objects` candidate list.
        predicted = relation_operator(
            subject,
            subject_token_index=subject_token_index,
            device=device,
            return_top_k=5,
        )
        print(f"{subject}, target: {target}   ==>   predicted: {predicted}")
        # Predictions carry a leading space (e.g. ' M'); strip before comparing to 'M'.
        if predicted and str(predicted[0]).strip() == target:
            hits += 1

    return hits / len(test_cases) if test_cases else 0.0

In [36]:
# Baseline operator: identity weight with the simple corner as bias, presumably so
# that only the bias moves the layer-15 representation -- confirm against RelationOperator.
identity_weight = torch.eye(
    model.config.n_embd, dtype=model.dtype, device=model.device
)
relation = estimate.RelationOperator(
    model=model,
    tokenizer=tokenizer,
    relation=prompt,
    layer=15,
    weight=identity_weight,
    bias=simple_corner,
)
check_with_test_cases(relation)

month, target: M   ==>   predicted: [' C', ' S', ' M', ' P', ' B']
major, target: M   ==>   predicted: [' M', ' C', ' D', ' P', ' S']
star, target: S   ==>   predicted: [' S', ' M', ' T', ' P', ' C']
areas, target: A   ==>   predicted: [' C', ' B', ' S', ' P', ' F']
future, target: F   ==>   predicted: [' T', ' C', ' D', ' E', ' B']
space, target: S   ==>   predicted: [' S', ' M', ' P', ' B', ' R']
committee, target: C   ==>   predicted: [' C', ' B', ' T', ' D', ' S']
london, target: L   ==>   predicted: [' C', ' B', ' S', ' D', ' P']
washington, target: W   ==>   predicted: [' D', ' C', ' L', ' S', ' B']
meeting, target: M   ==>   predicted: [' C', ' T', ' S', ' M', ' B']


In [42]:
def _is_cuda_oom(err):
    """Return True if a RuntimeError message looks like a GPU allocation failure."""
    message = str(err)
    return "out of memory" in message or "CUBLAS_STATUS_ALLOC_FAILED" in message


def _build_icl_prompt(others, relation_prompt):
    """Return `relation_prompt` preceded by one completed line per in-context example.

    Each example ``(s, _, o)`` becomes ``relation_prompt.format(s) + f"{o}."``. The last
    line is the raw template, whose ``{}`` slot is left for the query subject.
    """
    examples = "\n".join(
        relation_prompt.format(s_other) + f"{o_other}."
        for s_other, _idx_other, o_other in others
    )
    return examples + "\n" + relation_prompt


def _estimate_averaged_jb(top_performers, relation_prompt, num_icl, layer):
    """One attempt: estimate a leave-one-out ICL operator per sample and average them."""
    jbs = []
    for s, s_idx, o in tqdm(top_performers):
        current = (s, s_idx, o)
        # Order-preserving de-duplication instead of a set. The draw below then depends
        # only on the `random` state, not on per-process string-hash randomization.
        others = [t for t in dict.fromkeys(top_performers) if t != current]
        others = random.sample(others, k=min(num_icl, len(others)))
        prompt = _build_icl_prompt(others, relation_prompt)
        print("subject: ", s)
        print(prompt)

        jb, _ = estimate.relation_operator_from_sample(
            model, tokenizer,
            s, prompt,
            subject_token_index=s_idx,
            layer=layer,
            device=model.device,
            # calculate_at_lnf = calculate_at_lnf
        )
        print(jb.weight.norm(), jb.bias.norm())
        print()
        jbs.append(jb)

    weight = torch.stack([jb.weight for jb in jbs]).mean(dim=0)
    bias = torch.stack([jb.bias for jb in jbs]).mean(dim=0)
    return weight, bias


def get_averaged_JB(top_performers, relation_prompt, num_icl = 3, calculate_at_lnf = False, layer = 15):
    """Estimate one relation operator per sample and average their weights and biases.

    For each ``(subject, subject_token_index, object)`` triple, up to ``num_icl`` *other*
    triples are drawn with ``random.sample``. They are written as in-context examples
    ahead of the bare ``relation_prompt``, and ``estimate.relation_operator_from_sample``
    is run on that prompt for the held-out subject.

    Args:
        top_performers: non-empty list of hashable (subject, subject_token_index, object) tuples.
        relation_prompt: template with one ``{}`` slot for the subject.
        num_icl: maximum number of in-context examples per prompt.
        calculate_at_lnf: currently unused. It is kept for call compatibility, since the
            matching keyword to relation_operator_from_sample is commented out.
        layer: layer passed to relation_operator_from_sample (default 15, the value that
            was previously hard-coded).

    Returns:
        (weight, bias): element-wise means of the per-sample ``jb.weight`` / ``jb.bias``.

    Raises:
        ValueError: if ``top_performers`` is empty.
        RuntimeError: a non-OOM RuntimeError is re-raised unchanged, because fewer ICL
            examples cannot fix it. An OOM that persists at ``num_icl <= 1`` raises a
            RuntimeError chained to the original error.
    """
    if not top_performers:
        raise ValueError("get_averaged_JB needs at least one (subject, token_index, object) sample")

    while True:
        try:
            return _estimate_averaged_jb(top_performers, relation_prompt, num_icl, layer)
        except RuntimeError as e:
            if not _is_cuda_oom(e):
                raise
            print("CUDA out of memory")
            if num_icl <= 1:
                raise RuntimeError(
                    "RuntimeError >> can't calculate Jacobian with minimum number of icl examples"
                ) from e
        # Retry outside the `except` block. By this point the exception and its traceback,
        # which pin the failed attempt's GPU tensors, have been released, so the cache can
        # actually be returned before the smaller attempt.
        torch.cuda.empty_cache()
        num_icl -= 1
        print("trying with smaller icl >> ", num_icl)

def get_multiple_averaged_JB(top_performers, relation_prompt, N = 3, num_icl = 2, calculate_at_lnf = False):
    """Run `get_averaged_JB` on N random subsets and collect each round's result.

    Every round draws ``min(len(top_performers), num_icl + 2)`` samples without
    replacement via ``random.sample`` and averages operators over that subset.

    Returns:
        A list of N dicts, each ``{'weight': ..., 'bias': ...}`` holding one round's
        averaged weight and bias.
    """
    subset_size = min(len(top_performers), num_icl + 2)
    results = []
    for _round in tqdm(range(N)):
        subset = random.sample(top_performers, k=subset_size)
        mean_weight, mean_bias = get_averaged_JB(subset, relation_prompt, num_icl, calculate_at_lnf)
        results.append(dict(weight=mean_weight, bias=mean_bias))
    return results

In [44]:
# (subject, subject_token_index, object) triples. The object is the space-prefixed capital
# first letter, e.g. ('month', -1, ' M'); -1 presumably selects the last subject token.
samples = [(w, -1, " " + w[0].upper()) for w in words]
print(samples)

# The subset draws below (random.sample) were previously unseeded, so each run averaged
# Jacobians over different examples. Seed Python's `random` to make those draws repeatable.
JB_RANDOM_SEED = 42
random.seed(JB_RANDOM_SEED)

weights_and_biases = get_multiple_averaged_JB(
    samples,
    relation_prompt=" {} starts with",
    N=3,
    calculate_at_lnf=False,
)

[('month', -1, ' M'), ('major', -1, ' M'), ('star', -1, ' S'), ('areas', -1, ' A'), ('future', -1, ' F'), ('space', -1, ' S'), ('committee', -1, ' C'), ('london', -1, ' L'), ('washington', -1, ' W'), ('meeting', -1, ' M')]


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

subject:  future
 space starts with S.
 london starts with L.
 {} starts with
tensor(29., device='cuda:0', dtype=torch.float16) tensor(358.2500, device='cuda:0', dtype=torch.float16)

subject:  london
 areas starts with A.
 future starts with F.
 {} starts with
tensor(34.9375, device='cuda:0', dtype=torch.float16) tensor(308.5000, device='cuda:0', dtype=torch.float16)

subject:  areas
 space starts with S.
 future starts with F.
 {} starts with
tensor(24.3594, device='cuda:0', dtype=torch.float16) tensor(382.2500, device='cuda:0', dtype=torch.float16)

subject:  space
 london starts with L.
 future starts with F.
 {} starts with
tensor(22.4844, device='cuda:0', dtype=torch.float16) tensor(296.2500, device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:  major
 space starts with S.
 areas starts with A.
 {} starts with
tensor(31.6562, device='cuda:0', dtype=torch.float16) tensor(369.5000, device='cuda:0', dtype=torch.float16)

subject:  space
 future starts with F.
 areas starts with A.
 {} starts with
tensor(19.2500, device='cuda:0', dtype=torch.float16) tensor(351.5000, device='cuda:0', dtype=torch.float16)

subject:  areas
 major starts with M.
 future starts with F.
 {} starts with
tensor(20.4688, device='cuda:0', dtype=torch.float16) tensor(354., device='cuda:0', dtype=torch.float16)

subject:  future
 major starts with M.
 areas starts with A.
 {} starts with
tensor(20.2656, device='cuda:0', dtype=torch.float16) tensor(365., device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:  future
 london starts with L.
 star starts with S.
 {} starts with
tensor(27.4375, device='cuda:0', dtype=torch.float16) tensor(342.2500, device='cuda:0', dtype=torch.float16)

subject:  london
 meeting starts with M.
 future starts with F.
 {} starts with
tensor(37.5000, device='cuda:0', dtype=torch.float16) tensor(320.5000, device='cuda:0', dtype=torch.float16)

subject:  meeting
 london starts with L.
 future starts with F.
 {} starts with
tensor(35.2500, device='cuda:0', dtype=torch.float16) tensor(365.5000, device='cuda:0', dtype=torch.float16)

subject:  star
 future starts with F.
 london starts with L.
 {} starts with
tensor(1.7354, device='cuda:0', dtype=torch.float16) tensor(387.5000, device='cuda:0', dtype=torch.float16)



In [66]:
# Operator = mean of the per-round averaged Jacobian weights, with the least-squares letter
# corner as bias. (Alternative that was tried:
#   bias = torch.stack([wb['bias'] for wb in weights_and_biases]).mean(dim=0))
mean_jb_weight = torch.stack([wb['weight'] for wb in weights_and_biases]).mean(dim=0)

relation_operator = estimate.RelationOperator(
    model=model,
    tokenizer=tokenizer,
    relation=prompt,
    layer=15,
    weight=mean_jb_weight,
    bias=lst_sq_corner,
)

# Probe with words that are not in `words`, i.e. not used when estimating the Jacobians.
held_out_words = [
    "prediction", "induction", "antonym", "xerox", "geometry",
    "activation", "jerusalem", "blueberry", "understand", "granite",
    "hygiene", "antic",
]
check_with_test_cases(relation_operator, word_list=held_out_words)

prediction, target: P   ==>   predicted: [' P', ' Q', ' X', ' Y', ' G']
induction, target: I   ==>   predicted: [' I', ' U', ' X', ' N', ' H']
antonym, target: A   ==>   predicted: [' N', ' X', ' A', ' Y', ' T']
xerox, target: X   ==>   predicted: [' X', ' Q', ' Z', ' R', ' O']
geometry, target: G   ==>   predicted: [' G', ' Q', ' E', ' Y', ' R']
activation, target: A   ==>   predicted: [' A', ' V', ' X', ' C', ' Y']
jerusalem, target: J   ==>   predicted: [' J', ' Y', ' Z', ' U', ' G']
blueberry, target: B   ==>   predicted: [' B', ' Q', ' R', ' W', ' G']
understand, target: U   ==>   predicted: [' U', ' H', ' N', ' R', ' B']
granite, target: G   ==>   predicted: [' G', ' R', ' Q', ' K', ' J']
hygiene, target: H   ==>   predicted: [' H', ' G', ' Y', ' Q', ' K']
antic, target: A   ==>   predicted: [' A', ' T', ' C', ' N', ' O']


In [53]:
# Decode the mean Jacobian bias (averaged over the estimation rounds) into vocabulary tokens.
mean_jb_bias = torch.stack([wb["bias"] for wb in weights_and_biases]).mean(dim=0)
corner_estimator.get_vocab_representation(mean_jb_bias)

[' S', ' A', ' a', ' E', ' M']

In [55]:
# Check how a space-prefixed word is split by the tokenizer.
tokenizer.tokenize(" fifteen")

['Ġfifteen']