In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import json
from tqdm.auto import tqdm
import random
import transformers

import os
import sys
sys.path.append('..')

from relations import estimate
from util import model_utils
from baukit import nethook
from operator import itemgetter

In [3]:
# Later cells address modules as "model.decoder.layers.{}" / "model.decoder.final_layer_norm",
# so MODEL_NAME must be a checkpoint with that module layout (galactica loads with it).
MODEL_NAME = "facebook/galactica-6.7b"
n_embd_field = "hidden_size"  # config attribute holding the residual-stream width

mt = model_utils.ModelAndTokenizer(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float16)

model = mt.model
tokenizer = mt.tokenizer

# Make sure the tokenizer can pad. Prefer the pad id the model config already declares:
# adding a brand-new "[PAD]" token gives it an id past the end of the embedding matrix,
# which breaks any padded batch unless the embeddings are resized (and resizing also
# changes the unembedding matrix that the corner estimators below operate on).
if tokenizer.pad_token is None:
    config_pad_id = getattr(model.config, "pad_token_id", None)
    if config_pad_id is not None:
        tokenizer.pad_token = tokenizer.convert_ids_to_tokens(config_pad_id)
    else:
        tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        if len(tokenizer) > model.get_input_embeddings().weight.shape[0]:
            model.resize_token_embeddings(len(tokenizer))

print(f"{MODEL_NAME} ==> device: {model.device}, memory: {model.get_memory_footprint()}")

facebook/galactica-6.7b ==> device: cuda:0, memory: 13314719744


In [7]:
# Parse hypernym.txt into the parallel lists `base` and `hypernym`.
# Each non-blank line must have at least three whitespace-separated fields; only the
# 1st (base word) and 3rd (hypernym) are used, the 2nd is ignored.
with open("hypernym.txt", encoding="utf-8") as f:
    lines = f.readlines()

base, hypernym = [], []
for line_no, line in enumerate(lines, start=1):
    words = line.split()
    if not words:
        continue  # tolerate blank lines instead of crashing with IndexError
    if len(words) < 3:
        raise ValueError(f"hypernym.txt:{line_no}: expected at least 3 fields, got {line!r}")
    base.append(words[0])
    hypernym.append(words[2])

list(zip(base, hypernym))

[('oak', 'tree'),
 ('dog', 'animal'),
 ('opera', 'music'),
 ('tea', 'drink'),
 ('diamond', 'gem'),
 ('happiness', 'feeling'),
 ('family', 'group'),
 ('apple', 'fruit'),
 ('comfort', 'satisfaction'),
 ('grandma', 'grandparent'),
 ('thesaurus', 'dictionary'),
 ('crow', 'bird'),
 ('tennis', 'sport'),
 ('salmon', 'fish'),
 ('health', 'condition'),
 ('flower', 'plant'),
 ('jacket', 'garment'),
 ('tiger', 'cat'),
 ('green', 'colour'),
 ('unicorn', 'creature'),
 ('hat', 'headdress'),
 ('go', 'move'),
 ('writer', 'literate'),
 ('ultramarine', 'blue'),
 ('rosemary', 'herb'),
 ('prose', 'genre'),
 ('cucumber', 'vegetable'),
 ('computer', 'machine'),
 ('roulette', 'game'),
 ('emerald', 'gemstone'),
 ('cow', 'cattle'),
 ('physics', 'science'),
 ('earth', 'planet'),
 ('grocery', 'market'),
 ('fork', 'utensil'),
 ('sun', 'star'),
 ('room', 'area'),
 ('roof', 'cover'),
 ('nut', 'seed'),
 ('necklace', 'jewellery'),
 ('flirt', 'play'),
 ('romance', 'relationship'),
 ('communication', 'act'),
 ('coffee'

In [8]:
# prompt = """superman is Clark Kent
# flash is Barry Allen
# Wolverine is"""

# txt, ret_dict = model_utils.generate_fast(
#     model, tokenizer, 
#     prompts=[prompt], max_new_tokens=10, 
#     get_answer_tokens=True, argmax_greedy=True
# )
# txt

In [9]:
prompt = """oak is a tree
apple is a fruit
{} is a"""

# Keep only the (base, hypernym) pairs for which the model's greedy top answer token,
# under the two-shot `prompt` above, is exactly the gold hypernym. `prompt` is reused
# later as the `relation` template of the RelationOperator cells.
filter_by_model_knowledge = []
for subject, gold in zip(base, hypernym):
    txt, ret_dict = model_utils.generate_fast(
        model,
        tokenizer,
        prompts=[prompt.format(subject)],
        max_new_tokens=10,
        get_answer_tokens=True,
        argmax_greedy=True,
    )
    first_answer = ret_dict['answer'][0]
    is_known = first_answer['top_token'].strip() == gold
    candidates = [(cand['token'], cand['p']) for cand in first_answer['candidates']]
    print(f"{subject} >> {gold} ===> {candidates} :: {is_known}")
    if is_known:
        filter_by_model_knowledge.append((subject, gold))


oak >> tree ===> [(' tree', 0.3376), (' hard', 0.1385), (' fruit', 0.0916), (' kind', 0.064), (' type', 0.062)] :: True
dog >> animal ===> [(' mammal', 0.6353), (' dog', 0.0746), (' pet', 0.0439), (' carniv', 0.0283), (' animal', 0.0275)] :: False
opera >> music ===> [(' musical', 0.457), (' type', 0.0879), (' music', 0.0478), (' form', 0.0375), (' play', 0.0355)] :: False
tea >> drink ===> [(' beverage', 0.4783), (' drink', 0.4287), (' liquid', 0.0613), (' plant', 0.0038), (' fluid', 0.0032)] :: False
diamond >> gem ===> [(' gem', 0.4329), (' mineral', 0.152), (' precious', 0.0669), (' stone', 0.0471), (' material', 0.0186)] :: True
happiness >> feeling ===> [(' feeling', 0.7178), (' mood', 0.1692), (' state', 0.0246), (' emotion', 0.0156), (' word', 0.0102)] :: True
family >> group ===> [(' group', 0.7422), (' social', 0.0302), (' set', 0.0181), (' type', 0.0125), (' grouping', 0.0124)] :: True
apple >> fruit ===> [(' tree', 0.1466), (' fruit', 0.1421), (' kind', 0.0869), (' plant', 

In [10]:
n_known_pairs = len(filter_by_model_knowledge)  # pairs that survived the knowledge filter
n_known_pairs

34

In [12]:
from relations.corner import CornerEstimator

# Hypernyms the model already predicts, space-prefixed to match how answer tokens
# appear in the generation outputs above (e.g. ' tree').
objects = [f" {hyper}" for _base, hyper in filter_by_model_knowledge]

corner_estimator = CornerEstimator(
    model=model,
    tokenizer=tokenizer,
    ln_f_name="model.decoder.final_layer_norm",
    unembedder_module_name="lm_head",
)

In [15]:
# Simple corner estimate (scale_up chosen by hand); report its norm and top decoded tokens/logits.
simple_corner = corner_estimator.estimate_simple_corner(objects, scale_up=70)
simple_corner_norm = simple_corner.norm().item()
simple_corner_top = corner_estimator.get_vocab_representation(simple_corner, get_logits=True)
print(simple_corner_norm, simple_corner_top)

28.171875 [(' bird', 39.188), (' plant', 37.25), (' fish', 35.812), (' food', 33.125), (' game', 32.438)]


In [16]:
# Corner via the linear-inverse estimator (target_logit_value chosen by hand); norm + top tokens/logits.
lin_inv_corner = corner_estimator.estimate_lin_inv_corner(objects, target_logit_value=50)
lin_inv_corner_norm = lin_inv_corner.norm().item()
lin_inv_corner_top = corner_estimator.get_vocab_representation(lin_inv_corner, get_logits=True)
print(lin_inv_corner_norm, lin_inv_corner_top)

calculating inverse of unbedding weights . . .
18.265625 [(' plant', 23.969), (' bird', 22.859), (' game', 21.344), (' person', 20.828), (' drug', 20.594)]


In [17]:
# Corner via the least-squares estimator. Note this API spells the kwarg `target_logit`
# (not `target_logit_value` as above). This corner is reused as the operator bias below.
lst_sq_corner = corner_estimator.estimate_corner_lstsq_solve(objects, target_logit=50)
lst_sq_corner_norm = lst_sq_corner.norm().item()
lst_sq_corner_top = corner_estimator.get_vocab_representation(lst_sq_corner, get_logits=True)
print(lst_sq_corner_norm, lst_sq_corner_top)

131.875 [(' galaxy', 24.406), (' material', 24.266), (' group', 24.266), (' star', 24.266), (' human', 24.266)]


In [18]:
# avg_corner = corner_estimator.estimate_average_corner_with_gradient_descent(objects, average_on=5, target_logit_value=50, verbose=False)
# print(avg_corner.norm().item(), corner_estimator.get_vocab_representation(avg_corner))

In [19]:
def check_with_test_cases(relation_operator, test_pairs=None, device=None,
                          top_k=5, subject_token_index=-1):
    """Print a relation operator's top-k predictions for held-out (subject, target) pairs.

    Args:
        relation_operator: callable invoked as
            ``relation_operator(subject, subject_token_index=..., device=..., return_top_k=...)``.
            In this notebook it returns a list of decoded tokens such as ``[' season', ' color']``.
        test_pairs: iterable of ``(subject, target)`` pairs. Defaults to
            ``filter_by_model_knowledge[20:]``, i.e. the known pairs after the first 20
            (the samples cell uses the first 20 to estimate the Jacobians).
        device: device forwarded to the operator; defaults to ``model.device``.
        top_k: number of predictions requested per subject.
        subject_token_index: forwarded unchanged (default -1, presumably the last
            subject token — confirm in the operator).

    Prints one line per pair, then a ``top-1 / top-k`` hit summary comparing the
    whitespace-stripped predictions to the target. Returns None.
    """
    if test_pairs is None:
        test_pairs = filter_by_model_knowledge[20:]
    if device is None:
        device = model.device

    n_total = top1_hits = topk_hits = 0
    for subject, target in test_pairs:
        answer = relation_operator(
            subject,
            subject_token_index=subject_token_index,
            device=device,
            return_top_k=top_k,
        )
        print(f"{subject}, target: {target}   ==>   predicted: {answer}")

        # Predicted tokens carry a leading space (' season'); targets do not.
        predictions = [str(token).strip() for token in answer]
        n_total += 1
        if predictions and predictions[0] == target:
            top1_hits += 1
        if target in predictions:
            topk_hits += 1

    if n_total:
        print(f"top-1: {top1_hits}/{n_total}, top-{top_k}: {topk_hits}/{n_total}")

In [21]:
# Baseline operator: identity weight with the least-squares corner as bias
# (presumably the operator then reduces to "add the corner" — confirm in estimate.RelationOperator).
hidden_size = getattr(model.config, n_embd_field)
identity_weight = torch.eye(hidden_size).to(model.dtype).to(model.device)

relation = estimate.RelationOperator(
    model=model,
    tokenizer=tokenizer,
    relation=prompt,
    layer=15,
    weight=identity_weight,
    bias=lst_sq_corner,
    layer_name_format="model.decoder.layers.{}",
    ln_f_name="model.decoder.final_layer_norm",
)
check_with_test_cases(relation)

summer, target: season   ==>   predicted: [' season', ' group', ' color', ' wind', ' plant']
meat, target: food   ==>   predicted: [' science', ' group', ' color', ' plant', ' food']
doll, target: toy   ==>   predicted: [' tree', ' shape', ' fish', ' star', ' group']
gold, target: metal   ==>   predicted: [' metal', ' tree', ' wind', ' star', ' color']
round, target: shape   ==>   predicted: [' shape', ' plant', ' wind', ' food', ' color']
breeze, target: wind   ==>   predicted: [' color', ' tree', ' metal', ' plant', ' season']
man, target: human   ==>   predicted: [' group', ' person', ' plant', ' color', ' food']
hologram, target: picture   ==>   predicted: [' metal', ' color', ' device', ' tree', ' plant']
paper, target: material   ==>   predicted: [' science', ' plant', ' wind', ' material', ' group']
photographer, target: person   ==>   predicted: [' group', ' fish', ' tree', ' game', ' drug']
documentary, target: film   ==>   predicted: [' film', ' material', ' science', ' group

In [22]:
def _average_jacobians(top_performers, relation_prompt, num_icl, layer):
    """One attempt: estimate a Jacobian operator per sample with `num_icl` ICL examples and average them."""
    jbs = []
    for s, s_idx, o in tqdm(top_performers):
        current = (s, s_idx, o)
        # Order-preserving dedup (instead of a set) so `random.sample` sees a deterministic
        # sequence and seeding `random` makes the chosen ICL examples reproducible.
        others = list(dict.fromkeys(t for t in top_performers if t != current))
        others = random.sample(others, k=min(num_icl, len(others)))

        prompt = "\n".join(
            relation_prompt.format(s_other) + f" {o_other}." for s_other, _idx_other, o_other in others
        ) + "\n"
        prompt += relation_prompt
        print("subject: ", s)
        print(prompt)

        jb, _ = estimate.relation_operator_from_sample(
            model, tokenizer,
            s, prompt,
            subject_token_index=s_idx,
            layer=layer,
            device=model.device,
            layer_name_format="model.decoder.layers.{}",
            ln_f_name="model.decoder.final_layer_norm",
            n_layer_field="num_hidden_layers",
        )
        print(jb.weight.norm(), jb.bias.norm())
        print()
        jbs.append(jb)

    weight = torch.stack([jb.weight for jb in jbs]).mean(dim=0)
    bias = torch.stack([jb.bias for jb in jbs]).mean(dim=0)
    return weight, bias


def get_averaged_JB(top_performers, relation_prompt, num_icl=3, calculate_at_lnf=False, layer=15):
    """Average per-sample Jacobian operators, shrinking the ICL context on RuntimeError.

    For every ``(subject, subject_token_index, object)`` triple in `top_performers`, a
    prompt is built from up to `num_icl` *other* triples followed by `relation_prompt`.
    A Jacobian operator is then estimated at `layer` via
    ``estimate.relation_operator_from_sample``.

    Args:
        top_performers: non-empty list of ``(subject, subject_token_index, object)`` triples.
        relation_prompt: template with one ``{}`` slot for the subject.
        num_icl: maximum number of in-context examples per prompt.
        calculate_at_lnf: currently unused (kept for signature compatibility).
        layer: layer index at which the operator is estimated (previously hard-coded 15).

    Returns:
        ``(weight, bias)``: element-wise means of the estimated operators.

    Raises:
        ValueError: if `top_performers` is empty.
        Exception: if a RuntimeError (e.g. CUDA OOM) persists down to one ICL example;
            the original error is chained as ``__cause__``.
    """
    if not top_performers:
        raise ValueError("top_performers must contain at least one (subject, index, object) triple")

    while True:
        try:
            return _average_jacobians(top_performers, relation_prompt, num_icl, layer)
        except RuntimeError as e:
            if str(e).startswith("CUDA out of memory"):
                print("CUDA out of memory")
            if num_icl <= 1:
                raise Exception(
                    "RuntimeError >> can't calculate Jacobian with minimum number of icl examples"
                ) from e
        # Retry *outside* the except block: once it exits, the exception and the failed
        # attempt's frame (holding partial Jacobians) are released, and their GPU memory
        # can be returned to the allocator before the next attempt.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        num_icl -= 1
        print("trying with smaller icl >> ", num_icl)


def get_multiple_averaged_JB(top_performers, relation_prompt, N=3, num_icl=2, calculate_at_lnf=False, layer=15):
    """Run `get_averaged_JB` on `N` random subsets of `top_performers`.

    Each subset has ``min(len(top_performers), num_icl + 2)`` triples, so every
    subject has `num_icl` other triples available as in-context examples.

    Returns:
        A list of `N` dicts ``{'weight': ..., 'bias': ...}``.
    """
    sample_size = min(len(top_performers), num_icl + 2)
    weights_and_biases = []
    for _ in tqdm(range(N)):
        cur_sample = random.sample(top_performers, k=sample_size)
        weight, bias = get_averaged_JB(cur_sample, relation_prompt, num_icl, calculate_at_lnf, layer=layer)
        weights_and_biases.append({
            'weight': weight,
            'bias'  : bias
        })
    return weights_and_biases

In [23]:
# The first N_TRAIN_PAIRS known pairs are used to estimate the Jacobian operator;
# check_with_test_cases evaluates on the pairs after the first 20 by default.
N_TRAIN_PAIRS = 20
JACOBIAN_SEED = 0  # seeds the random subset / ICL-example draws below

samples = [(b, -1, h) for b, h in filter_by_model_knowledge[:N_TRAIN_PAIRS]]
print(samples)

# Seed Python's `random` so the subsets drawn by get_multiple_averaged_JB are repeatable.
random.seed(JACOBIAN_SEED)
weights_and_biases = get_multiple_averaged_JB(
    samples,
    relation_prompt=" {} is a",
    N=3,
    calculate_at_lnf=False,
)

[('oak', -1, 'tree'), ('diamond', -1, 'gem'), ('happiness', -1, 'feeling'), ('family', -1, 'group'), ('thesaurus', -1, 'dictionary'), ('crow', -1, 'bird'), ('tennis', -1, 'sport'), ('salmon', -1, 'fish'), ('flower', -1, 'plant'), ('rosemary', -1, 'herb'), ('cucumber', -1, 'vegetable'), ('roulette', -1, 'game'), ('physics', -1, 'science'), ('earth', -1, 'planet'), ('sun', -1, 'star'), ('coffee', -1, 'beverage'), ('car', -1, 'vehicle'), ('yellow', -1, 'color'), ('fan', -1, 'device'), ('judaism', -1, 'religion')]


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

subject:  cucumber
 family is a group.
 roulette is a game.
 {} is a
tensor(43.0938, device='cuda:0', dtype=torch.float16) tensor(248.8750, device='cuda:0', dtype=torch.float16)

subject:  roulette
 coffee is a beverage.
 cucumber is a vegetable.
 {} is a
tensor(57.5938, device='cuda:0', dtype=torch.float16) tensor(274.2500, device='cuda:0', dtype=torch.float16)

subject:  family
 roulette is a game.
 coffee is a beverage.
 {} is a
tensor(46.7500, device='cuda:0', dtype=torch.float16) tensor(253.7500, device='cuda:0', dtype=torch.float16)

subject:  coffee
 roulette is a game.
 cucumber is a vegetable.
 {} is a
tensor(40.3438, device='cuda:0', dtype=torch.float16) tensor(264., device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:  happiness
 roulette is a game.
 yellow is a color.
 {} is a
tensor(38., device='cuda:0', dtype=torch.float16) tensor(259.7500, device='cuda:0', dtype=torch.float16)

subject:  roulette
 happiness is a feeling.
 yellow is a color.
 {} is a
tensor(53.2812, device='cuda:0', dtype=torch.float16) tensor(281.2500, device='cuda:0', dtype=torch.float16)

subject:  yellow
 fan is a device.
 roulette is a game.
 {} is a
tensor(43.8750, device='cuda:0', dtype=torch.float16) tensor(227.2500, device='cuda:0', dtype=torch.float16)

subject:  fan
 roulette is a game.
 happiness is a feeling.
 {} is a
tensor(52.8750, device='cuda:0', dtype=torch.float16) tensor(256.2500, device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:  happiness
 cucumber is a vegetable.
 diamond is a gem.
 {} is a
tensor(40.4375, device='cuda:0', dtype=torch.float16) tensor(250.3750, device='cuda:0', dtype=torch.float16)

subject:  diamond
 earth is a planet.
 happiness is a feeling.
 {} is a
tensor(45.1875, device='cuda:0', dtype=torch.float16) tensor(254.3750, device='cuda:0', dtype=torch.float16)

subject:  cucumber
 earth is a planet.
 happiness is a feeling.
 {} is a
tensor(40.7812, device='cuda:0', dtype=torch.float16) tensor(249.6250, device='cuda:0', dtype=torch.float16)

subject:  earth
 cucumber is a vegetable.
 diamond is a gem.
 {} is a
tensor(49.3125, device='cuda:0', dtype=torch.float16) tensor(235.6250, device='cuda:0', dtype=torch.float16)



In [24]:
# Operator weight = mean of the N averaged Jacobian weights. The bias is the
# least-squares corner rather than the averaged Jacobian bias (alternative left commented).
mean_jacobian_weight = torch.stack([wb['weight'] for wb in weights_and_biases]).mean(dim=0)
# mean_jacobian_bias = torch.stack([wb['bias'] for wb in weights_and_biases]).mean(dim=0)

relation_operator = estimate.RelationOperator(
    model=model,
    tokenizer=tokenizer,
    relation=prompt,
    layer=15,
    weight=mean_jacobian_weight,
    bias=lst_sq_corner,
    layer_name_format="model.decoder.layers.{}",
    ln_f_name="model.decoder.final_layer_norm",
)

check_with_test_cases(relation_operator)

summer, target: season   ==>   predicted: [' season', ' color', ' wind', ' herb', ' plant']
meat, target: food   ==>   predicted: [' food', ' dish', ' vegetable', ' material', ' meat']
doll, target: toy   ==>   predicted: [' toy', ' shape', ' picture', ' bird', ' person']
gold, target: metal   ==>   predicted: [' metal', ' material', ' gem', ' color', ' religion']
round, target: shape   ==>   predicted: [' shape', ' sport', ' season', ' wind', ' galaxy']
breeze, target: wind   ==>   predicted: [' wind', ' season', ' feeling', ' herb', ' color']
man, target: human   ==>   predicted: [' human', ' person', ' vehicle', ' fish', ' sport']
hologram, target: picture   ==>   predicted: [' device', ' picture', ' material', ' science', ' film']
paper, target: material   ==>   predicted: [' material', ' wind', ' dish', ' dictionary', ' vehicle']
photographer, target: person   ==>   predicted: [' person', ' science', ' star', ' human', ' sport']
documentary, target: film   ==>   predicted: [' film

In [25]:
# Decode the mean of the averaged Jacobian biases into vocabulary space for comparison with the corners.
averaged_bias = torch.stack([wb['bias'] for wb in weights_and_biases]).mean(dim=0)
corner_estimator.get_vocab_representation(averaged_bias)

[' type', ' kind', ' good', ' ', ' very']