In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import json
from tqdm.auto import tqdm
import random
import transformers

import os
import sys
sys.path.append('..')

from relations import estimate
from util import model_utils
from baukit import nethook
from operator import itemgetter

In [3]:
MODEL_NAME = "EleutherAI/gpt-j-6B"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B
mt = model_utils.ModelAndTokenizer(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float16)

model = mt.model
tokenizer = mt.tokenizer
tokenizer.pad_token = tokenizer.eos_token

print(f"{MODEL_NAME} ==> device: {model.device}, memory: {model.get_memory_footprint()}")

EleutherAI/gpt-j-6B ==> device: cuda:0, memory: 12219206136


In [11]:
with open("hypernym.txt") as f:
    lines = f.readlines()
    base, hypernym = [], []
    for line in lines:
        words = line.strip().split()
        base.append(words[0])
        hypernym.append(words[2])

# list(zip(base, hypernym))

In [15]:
# prompt = """superman is Clark Kent
# flash is Barry Allen
# Wolverine is"""

# txt, ret_dict = model_utils.generate_fast(
#     model, tokenizer, 
#     prompts=[prompt], max_new_tokens=10, 
#     get_answer_tokens=True, argmax_greedy=True
# )
# txt

In [16]:
ret_dict['answer']

[{'top_token': ' James',
  'candidates': [{'token': ' James', 'token_id': 3700, 'p': 0.1595},
   {'token': ' Logan', 'token_id': 22221, 'p': 0.084},
   {'token': ' Charles', 'token_id': 7516, 'p': 0.049},
   {'token': ' Hank', 'token_id': 24386, 'p': 0.0443},
   {'token': ' a', 'token_id': 257, 'p': 0.0413}]}]

In [26]:
prompt = """oak is a tree
apple is a fruit
{} is a"""

filter_by_model_knowledge = []
for bs, hyper in zip(base, hypernym):
    txt, ret_dict = model_utils.generate_fast(
        model, tokenizer, 
        prompts=[prompt.format(bs)], max_new_tokens=10, 
        get_answer_tokens=True, argmax_greedy=True
    )
    # tick = hyper.startswith(ret_dict['answer'][0]['top_token'].strip())
    tick = hyper == ret_dict['answer'][0]['top_token'].strip()
    print(f"{bs} >> {hyper} ===> {[(ans['token'], ans['p']) for ans in ret_dict['answer'][0]['candidates']]} :: {tick}")
    if(tick):
        filter_by_model_knowledge.append((bs, hyper))


oak >> tree ===> [(' tree', 0.6543), (' type', 0.0712), (' fruit', 0.047), (' wood', 0.0467), (' plant', 0.026)] :: True
dog >> animal ===> [(' mammal', 0.2701), (' domest', 0.1446), (' four', 0.0877), (' pet', 0.075), (' type', 0.0524)] :: False
opera >> music ===> [(' form', 0.2663), (' musical', 0.2173), (' type', 0.1238), (' music', 0.0489), (' performance', 0.0418)] :: False
tea >> drink ===> [(' drink', 0.5667), (' beverage', 0.2986), (' plant', 0.0179), (' liquid', 0.0171), (' herb', 0.0166)] :: True
diamond >> gem ===> [(' mineral', 0.2899), (' gem', 0.281), (' stone', 0.0941), (' precious', 0.0857), (' rock', 0.0608)] :: False
happiness >> feeling ===> [(' feeling', 0.4196), (' state', 0.2428), (' emotion', 0.0453), (' good', 0.0409), (' quality', 0.0168)] :: True
family >> group ===> [(' group', 0.6809), (' category', 0.0361), (' collection', 0.029), (' large', 0.0195), (' grouping', 0.0189)] :: True
apple >> fruit ===> [(' fruit', 0.296), (' tree', 0.1972), (' type', 0.0903)

In [27]:
len(filter_by_model_knowledge)

34

In [28]:
objects = [" " + o[1] for o in filter_by_model_knowledge]

from relations.corner import CornerEstimator
corner_estimator = CornerEstimator(model=model, tokenizer=tokenizer)

In [35]:
sorted(objects)

[' bird',
 ' candy',
 ' cat',
 ' color',
 ' dance',
 ' dictionary',
 ' dish',
 ' drink',
 ' drug',
 ' feeling',
 ' film',
 ' fish',
 ' food',
 ' fruit',
 ' game',
 ' gender',
 ' group',
 ' herb',
 ' machine',
 ' metal',
 ' person',
 ' planet',
 ' plant',
 ' religion',
 ' science',
 ' season',
 ' seed',
 ' shape',
 ' star',
 ' toy',
 ' tree',
 ' vegetable',
 ' vehicle',
 ' wind']

In [29]:
simple_corner = corner_estimator.estimate_simple_corner(objects, scale_up=70)
print(simple_corner.norm().item(), corner_estimator.get_vocab_representation(simple_corner, get_logits=True))

43.40625 [(' plant', 92.125), (' food', 85.438), (' tree', 84.812), (' group', 82.688), (' game', 82.688)]


In [30]:
lin_inv_corner = corner_estimator.estimate_lin_inv_corner(objects, target_logit_value=50)
print(lin_inv_corner.norm().item(), corner_estimator.get_vocab_representation(lin_inv_corner, get_logits=True))

calculating inverse of unbedding weights . . .
25.609375 [(' plant', 37.125), (' tree', 30.438), (' bird', 30.344), (' game', 29.578), (' fruit', 29.328)]


In [39]:
lst_sq_corner = corner_estimator.estimate_corner_lstsq_solve(objects, target_logit=50)
print(lst_sq_corner.norm().item(), corner_estimator.get_vocab_representation(lst_sq_corner, get_logits=True))

87.5 [(' water', 72.438), (' car', 72.312), (' model', 69.75), (' non', 69.438), (' g', 69.25)]


In [40]:
# avg_corner = corner_estimator.estimate_average_corner_with_gradient_descent(objects, average_on=5, target_logit_value=50, verbose=False)
# print(avg_corner.norm().item(), corner_estimator.get_vocab_representation(avg_corner))

In [41]:
def check_with_test_cases(relation_operator):
    test_cases = [
        (b, -1, h) for b, h in filter_by_model_knowledge[20:]
    ]
    for subject, subject_token_index, target in test_cases:
        answer = relation_operator(
            subject,
            subject_token_index=subject_token_index,
            device=model.device,
            return_top_k=5,
        )
        print(f"{subject}, target: {target}   ==>   predicted: {answer}")

In [47]:
relation = estimate.RelationOperator(
    model = model,
    tokenizer = tokenizer,
    relation = prompt,
    layer = 15,
    weight = torch.eye(model.config.n_embd).to(model.dtype).to(model.device),
    bias = lst_sq_corner
)
check_with_test_cases(relation)

judaism, target: religion   ==>   predicted: [' religion', ' car', ' culture', ' a', ' religious']
summer, target: season   ==>   predicted: [' season', ' time', ' water', ' car', ' fruit']
meat, target: food   ==>   predicted: [' car', ' food', ' water', ' cat', ' fish']
doll, target: toy   ==>   predicted: [' car', ' child', ' toy', ' water', ' body']
gold, target: metal   ==>   predicted: [' car', ' fish', ' water', ' metal', ' star']
rumba, target: dance   ==>   predicted: [' dance', ' music', ' car', ' song', ' cat']
round, target: shape   ==>   predicted: [' shape', ' water', ' car', ' fruit', ' ball']
breeze, target: wind   ==>   predicted: [' water', ' wind', ' car', ' a', ' drink']
lollipop, target: candy   ==>   predicted: [' candy', ' water', ' a', ' car', ' drink']
photographer, target: person   ==>   predicted: [' car', ' person', ' man', ' water', ' a']
documentary, target: film   ==>   predicted: [' film', ' science', ' car', ' water', ' movie']
anesthetic, target: drug 

In [48]:
def get_averaged_JB(top_performers, relation_prompt, num_icl = 3, calculate_at_lnf = False):
    try:
        jbs = []
        for s, s_idx, o in tqdm(top_performers):
            others = set(top_performers) - {(s, s_idx, o)}
            others = random.sample(list(others), k = min(num_icl, len(list(others)))) 
            prompt = ""
            prompt += "\n".join(relation_prompt.format(s_other) + f" {o_other}." for s_other, idx_other, o_other in others) + "\n"
            prompt += relation_prompt
            print("subject: ", s)
            print(prompt)

            jb, _ = estimate.relation_operator_from_sample(
                model, tokenizer,
                s, prompt,
                subject_token_index= s_idx,
                layer = 15,
                device = model.device,
                # calculate_at_lnf = calculate_at_lnf
            )
            print(jb.weight.norm(), jb.bias.norm())
            print()
            jbs.append(jb)
        
        weight = torch.stack([jb.weight for jb in jbs]).mean(dim=0)
        bias  = torch.stack([jb.bias for jb in jbs]).mean(dim=0)

        return weight, bias
    except RuntimeError as e:
        if(str(e).startswith("CUDA out of memory")):
            print("CUDA out of memory")
        if(num_icl > 1):
            num_icl -= 1
            print("trying with smaller icl >> ", num_icl)
            return get_averaged_JB(top_performers, relation_prompt, num_icl, calculate_at_lnf)
        else:
            raise Exception("RuntimeError >> can't calculate Jacobian with minimum number of icl examples")

def get_multiple_averaged_JB(top_performers, relation_prompt, N = 3, num_icl = 2, calculate_at_lnf = False):
    weights_and_biases = []
    sample_size = min(len(top_performers), num_icl + 2)
    for _ in tqdm(range(N)):
        cur_sample = random.sample(top_performers, k = sample_size)
        weight, bias = get_averaged_JB(cur_sample, relation_prompt, num_icl, calculate_at_lnf)
        weights_and_biases.append({
            'weight': weight,
            'bias'  : bias
        })
    return weights_and_biases

In [51]:
samples = [
        (b, -1, h) for b, h in filter_by_model_knowledge[:20]
    ]
print(samples)

weights_and_biases = get_multiple_averaged_JB(
    samples, 
    relation_prompt=" {} is a", 
    N = 3, 
    calculate_at_lnf=False
)

[('oak', -1, 'tree'), ('tea', -1, 'drink'), ('happiness', -1, 'feeling'), ('family', -1, 'group'), ('apple', -1, 'fruit'), ('thesaurus', -1, 'dictionary'), ('crow', -1, 'bird'), ('salmon', -1, 'fish'), ('flower', -1, 'plant'), ('tiger', -1, 'cat'), ('rosemary', -1, 'herb'), ('cucumber', -1, 'vegetable'), ('computer', -1, 'machine'), ('roulette', -1, 'game'), ('physics', -1, 'science'), ('earth', -1, 'planet'), ('sun', -1, 'star'), ('nut', -1, 'seed'), ('car', -1, 'vehicle'), ('yellow', -1, 'color')]


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

subject:  happiness
 salmon is a fish.
 apple is a fruit.
 {} is a
tensor(16.3594, device='cuda:0', dtype=torch.float16) tensor(272.2500, device='cuda:0', dtype=torch.float16)

subject:  salmon
 nut is a seed.
 happiness is a feeling.
 {} is a
tensor(27.6562, device='cuda:0', dtype=torch.float16) tensor(281., device='cuda:0', dtype=torch.float16)

subject:  nut
 salmon is a fish.
 apple is a fruit.
 {} is a
tensor(27.2344, device='cuda:0', dtype=torch.float16) tensor(282.5000, device='cuda:0', dtype=torch.float16)

subject:  apple
 happiness is a feeling.
 salmon is a fish.
 {} is a
tensor(25.3281, device='cuda:0', dtype=torch.float16) tensor(310.7500, device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:  crow
 salmon is a fish.
 car is a vehicle.
 {} is a
tensor(27.7812, device='cuda:0', dtype=torch.float16) tensor(277.7500, device='cuda:0', dtype=torch.float16)

subject:  nut
 car is a vehicle.
 salmon is a fish.
 {} is a
tensor(31.5781, device='cuda:0', dtype=torch.float16) tensor(286., device='cuda:0', dtype=torch.float16)

subject:  salmon
 nut is a seed.
 car is a vehicle.
 {} is a
tensor(28.4688, device='cuda:0', dtype=torch.float16) tensor(267.7500, device='cuda:0', dtype=torch.float16)

subject:  car
 nut is a seed.
 salmon is a fish.
 {} is a
tensor(18.2812, device='cuda:0', dtype=torch.float16) tensor(279.5000, device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:  sun
 computer is a machine.
 cucumber is a vegetable.
 {} is a
tensor(20.9688, device='cuda:0', dtype=torch.float16) tensor(280.7500, device='cuda:0', dtype=torch.float16)

subject:  computer
 sun is a star.
 cucumber is a vegetable.
 {} is a
tensor(22.6406, device='cuda:0', dtype=torch.float16) tensor(272.5000, device='cuda:0', dtype=torch.float16)

subject:  cucumber
 happiness is a feeling.
 sun is a star.
 {} is a
tensor(26.5000, device='cuda:0', dtype=torch.float16) tensor(275.7500, device='cuda:0', dtype=torch.float16)

subject:  happiness
 sun is a star.
 computer is a machine.
 {} is a
tensor(15.9375, device='cuda:0', dtype=torch.float16) tensor(237.8750, device='cuda:0', dtype=torch.float16)



In [56]:
relation_operator = estimate.RelationOperator(
    model = model,
    tokenizer= tokenizer,
    relation = prompt,
    layer = 15,
    weight = torch.stack(
        [wb['weight'] for wb in weights_and_biases]
    ).mean(dim=0),
    # bias = torch.stack(
    #     [wb['bias'] for wb in weights_and_biases]
    # ).mean(dim=0),
    bias = lst_sq_corner
)

check_with_test_cases(relation_operator)

judaism, target: religion   ==>   predicted: [' religion', ' car', ' water', ' d', ' g']
summer, target: season   ==>   predicted: [' season', ' water', ' car', ' g', ' time']
meat, target: food   ==>   predicted: [' high', ' car', ' game', ' red', ' g']
doll, target: toy   ==>   predicted: [' toy', ' child', ' model', ' car', ' water']
gold, target: metal   ==>   predicted: [' metal', ' car', ' g', ' water', ' color']
rumba, target: dance   ==>   predicted: [' dance', ' car', ' g', ' game', ' d']
round, target: shape   ==>   predicted: [' car', ' water', ' shape', ' game', ' ball']
breeze, target: wind   ==>   predicted: [' wind', ' water', ' drink', ' season', ' car']
lollipop, target: candy   ==>   predicted: [' candy', ' toy', ' child', ' car', ' water']
photographer, target: person   ==>   predicted: [' car', ' water', ' high', ' film', ' d']
documentary, target: film   ==>   predicted: [' film', ' car', ' water', ' high', ' non']
anesthetic, target: drug   ==>   predicted: [' dru

In [77]:
corner_estimator.get_vocab_representation(
    torch.stack(
        [wb['bias'] for wb in weights_and_biases]
    ).mean(dim=0)
)

[' most', ' the', '\n', ' least', ' ']

In [90]:
from typing import Any, Sequence, TypeAlias, List

unembedder = nethook.get_module(model, "lm_head")
ln_f = nethook.get_module(model, "transformer.ln_f")

def estimate_corner_lstsq_solve(
    target_words: List[str],
    target_logit: int = 50,
):
    target_tokenized = tokenizer(target_words, padding=True, return_tensors="pt").to(model.device)
    # print(target_tokenized)
    W = torch.stack([unembedder.weight[r[0].item()] for r in target_tokenized.input_ids])
    # print(target_tokenized.input_ids.shape, W.shape)
    b = unembedder.bias[target_tokenized.input_ids]
    b = b.reshape(b.shape[0])
    y = (torch.ones(len(target_words)) * target_logit).to(model.dtype).to(model.device) - b
    # print(b.shape, y.shape)
    if(model.dtype == torch.float16):
        W = W.to(torch.float32)
        y = y.to(torch.float32)
    x = torch.linalg.lstsq(W, y).solution
    print(W@x + b)
    return x.to(model.dtype)

corner = estimate_corner_lstsq_solve(objects)
corner.shape, corner.norm()

tensor([50.0028, 49.9846, 50.0143, 49.9919, 50.0005, 50.0105, 49.9887, 50.0128,
        50.0124, 49.9906, 49.9848, 49.9848, 49.9904, 50.0138, 49.9965, 49.9854,
        49.9973, 50.0067, 50.0075, 50.0016, 49.9978, 49.9894, 50.0059, 50.0106,
        50.0045, 50.0000, 50.0047, 50.0028, 49.9897, 50.0096, 50.0153, 50.0059,
        50.0144, 50.0063, 50.0045, 49.9968, 50.0085, 49.9871, 49.9907, 50.0133,
        50.0077], device='cuda:0')


(torch.Size([4096]), tensor(74., device='cuda:0', dtype=torch.float16))

In [28]:
def get_vocab_representation(
    h, 
    perform_layer_norm = True, return_top_k = 5, get_logits = False
):
    """
    get representation of vector `h` in the vocabulary space. basically applied logit lens
    """
    z = h.clone()
    if(perform_layer_norm == True):
        z = ln_f(z)
    logits = unembedder(z)
    token_ids = logits.topk(dim=-1, k=return_top_k).indices.squeeze().tolist()
    logit_values = logits.topk(dim=-1, k=return_top_k).values.squeeze().tolist()
    return [
        tokenizer.decode(t) if get_logits == False else (tokenizer.decode(t), np.round(v, 3))
        for t, v in zip(token_ids, logit_values)
    ]

get_vocab_representation(corner, get_logits=True)

[(' more', 75.688),
 (' lighter', 75.5),
 (' later', 74.438),
 (' lower', 74.312),
 (' better', 74.312)]