In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import json
from tqdm.auto import tqdm
import random
import transformers

import os
import sys
sys.path.append('..')

from relations import estimate
from util import model_utils
from baukit import nethook
from operator import itemgetter

In [3]:
MODEL_NAME = "facebook/galactica-6.7b"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B
n_embd_field = "hidden_size"
mt = model_utils.ModelAndTokenizer(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float16)

model = mt.model
tokenizer = mt.tokenizer
# tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_special_tokens({'pad_token': '[PAD]'})

print(f"{MODEL_NAME} ==> device: {model.device}, memory: {model.get_memory_footprint()}")

facebook/galactica-6.7b ==> device: cuda:0, memory: 13314719744


In [4]:
with open("comperative-superlative.txt") as f:
    lines = f.readlines()
    words = []
    for line in lines:
        w = line.strip()
        if(len(w) == 0):
            continue
        words.append(w)

base = words[0 : len(words) : 3]
comparative = words[1 : len(words) : 3]
superlative = words[2 : len(words) : 3]

list(zip(base, comparative, superlative))

[('angry', 'angrier', 'angriest'),
 ('bad', 'worse', 'worst'),
 ('big', 'bigger', 'biggest'),
 ('bitter', 'bitterer', 'bitterest'),
 ('black', 'blacker', 'blackest'),
 ('bland', 'blander', 'blandest'),
 ('bloody', 'bloodier', 'bloodiest'),
 ('blue', 'bluer', 'bluest'),
 ('bold', 'bolder', 'boldest'),
 ('bossy', 'bossier', 'bossiest'),
 ('brave', 'braver', 'bravest'),
 ('brief', 'briefer', 'briefest'),
 ('bright', 'brighter', 'brightest'),
 ('broad', 'broader', 'broadest'),
 ('busy', 'busier', 'busiest'),
 ('calm', 'calmer', 'calmest'),
 ('cheap', 'cheaper', 'cheapest'),
 ('chewy', 'chewier', 'chewiest'),
 ('chubby', 'chubbier', 'chubbiest'),
 ('classy', 'classier', 'classiest'),
 ('clean', 'cleaner', 'cleanest'),
 ('clear', 'clear', 'clearest'),
 ('clever', 'cleverer', 'cleverest'),
 ('close', 'closer', 'closest'),
 ('cloudy', 'cloudier', 'cloudiest'),
 ('clumsy', 'clumsier', 'clumsiest'),
 ('coarse', 'coarser', 'coarsest'),
 ('cold', 'colder', 'coldest'),
 ('cool', 'cooler', 'coolest'

In [5]:
tokenizer([' class', ' classy', ' classier', ' classiest'], padding=True, return_tensors='pt')

{'input_ids': tensor([[ 1067, 50000],
        [ 1067,   111],
        [ 1067,  1331],
        [ 1067, 24072]]), 'token_type_ids': tensor([[0, 0],
        [0, 0],
        [0, 0],
        [0, 0]]), 'attention_mask': tensor([[1, 0],
        [1, 1],
        [1, 1],
        [1, 1]])}

In [6]:
filter_single_token = []
for b, c, s in zip(base, comparative, superlative):
    tokenized = tokenizer([" " + b, " " + c, " " + s], padding=True, return_tensors='pt').input_ids
    if(tokenized.shape[1] > 1):
        continue
    filter_single_token.append((b, c, s))
filter_single_token

[('bad', 'worse', 'worst'),
 ('big', 'bigger', 'biggest'),
 ('bright', 'brighter', 'brightest'),
 ('close', 'closer', 'closest'),
 ('deep', 'deeper', 'deepest'),
 ('early', 'earlier', 'earliest'),
 ('easy', 'easier', 'easiest'),
 ('fast', 'faster', 'fastest'),
 ('fine', 'finer', 'finest'),
 ('good', 'better', 'best'),
 ('great', 'greater', 'greatest'),
 ('high', 'higher', 'highest'),
 ('large', 'larger', 'largest'),
 ('late', 'later', 'latest'),
 ('light', 'lighter', 'lightest'),
 ('long', 'longer', 'longest'),
 ('low', 'lower', 'lowest'),
 ('short', 'shorter', 'shortest'),
 ('simple', 'simpler', 'simplest'),
 ('small', 'smaller', 'smallest'),
 ('steep', 'steeper', 'steepest'),
 ('strong', 'stronger', 'strongest'),
 ('weak', 'weaker', 'weakest'),
 ('young', 'younger', 'youngest')]

In [7]:
prompt = """superlative of late is latest
superlative of strong is strongest
superlaitve of {} is"""

# prompt = """grape ends with E
# monitor ends with R
# glass ends with"""

words = ['strong', 'big', 'deep', 'young']

for w in words:
    txt, ret_dict = model_utils.generate_fast(
        model, tokenizer, 
        prompts=[prompt.format(w)], max_new_tokens=10, 
        get_answer_tokens=True, argmax_greedy=True
    )
    print(f"{w} ===> {ret_dict['answer'][0]['top_token']}")


strong ===>  strongest
big ===>  biggest
deep ===>  deepest
young ===>  youngest


In [8]:
objects = [" " + o[2] for o in filter_single_token]

from relations.corner import CornerEstimator
corner_estimator = CornerEstimator(
    model=model, tokenizer=tokenizer,
    ln_f_name= "model.decoder.final_layer_norm", 
    unembedder_module_name="lm_head"
)

In [9]:
simple_corner = corner_estimator.estimate_simple_corner(objects, scale_up=70)
print(simple_corner.norm().item(), corner_estimator.get_vocab_representation(simple_corner, get_logits=True))

50.5625 [(' longest', 54.344), (' fastest', 53.938), (' smallest', 53.312), (' strongest', 52.75), (' weakest', 52.562)]


In [10]:
lin_inv_corner = corner_estimator.estimate_lin_inv_corner(objects, target_logit_value=50)
print(lin_inv_corner.norm().item(), corner_estimator.get_vocab_representation(lin_inv_corner, get_logits=True))

calculating inverse of unbedding weights . . .
28.59375 [(' largest', 46.594), (' longest', 46.5), (' smallest', 45.125), (' fastest', 44.969), (' strongest', 44.656)]


In [11]:
lst_sq_corner = corner_estimator.estimate_corner_lstsq_solve(objects[:15], target_logit=30)
print(lst_sq_corner.norm().item(), corner_estimator.get_vocab_representation(lst_sq_corner, get_logits=True))

41.875 [(' greatest', 45.875), (' largest', 45.844), (' best', 45.844), (' worst', 45.844), (' highest', 45.844)]


In [12]:
# avg_corner = corner_estimator.estimate_average_corner_with_gradient_descent(objects, average_on=5, target_logit_value=50, verbose=False)
# print(avg_corner.norm().item(), corner_estimator.get_vocab_representation(avg_corner))

In [13]:
def check_with_test_cases(relation_operator):
    test_cases = [
        (b, -1, s) for b, c, s in filter_single_token[15:]
    ]
    # print(test_cases)
    for subject, subject_token_index, target in test_cases:
        answer = relation_operator(
            subject,
            subject_token_index=subject_token_index,
            device=model.device,
            return_top_k=5,
        )
        print(f"{subject}, target: {target}   ==>   predicted: {answer}")

In [14]:
relation = estimate.RelationOperator(
    model = model,
    tokenizer = tokenizer,
    relation = prompt,
    layer = 15, layer_name_format = "model.decoder.layers.{}",
    weight = torch.eye(getattr(model.config, n_embd_field)).to(model.dtype).to(model.device),
    bias = lst_sq_corner,

    ln_f_name = "model.decoder.final_layer_norm"
)

check_with_test_cases(relation)

long, target: longest   ==>   predicted: [' best', ' greatest', ' highest', ' largest', ' worst']
low, target: lowest   ==>   predicted: [' best', ' greatest', ' worst', ' highest', ' strongest']
short, target: shortest   ==>   predicted: [' best', ' greatest', ' highest', ' worst', ' largest']
simple, target: simplest   ==>   predicted: [' best', ' greatest', ' highest', ' largest', ' worst']
small, target: smallest   ==>   predicted: [' greatest', ' highest', ' best', ' worst', ' latest']
steep, target: steepest   ==>   predicted: [' best', ' worst', ' greatest', ' highest', 'est']
strong, target: strongest   ==>   predicted: [' best', ' worst', ' greatest', ' highest', ' largest']
weak, target: weakest   ==>   predicted: [' best', ' greatest', ' highest', ' worst', ' largest']
young, target: youngest   ==>   predicted: [' best', ' highest', ' greatest', ' latest', ' worst']


In [26]:
def get_averaged_JB(top_performers, relation_prompt, num_icl = 3, calculate_at_lnf = False):
    try:
        jbs = []
        for s, s_idx, o in tqdm(top_performers):
            others = set(top_performers) - {(s, s_idx, o)}
            others = random.sample(list(others), k = min(num_icl, len(list(others)))) 
            prompt = ""
            prompt += "\n".join(relation_prompt.format(s_other) + f" {o_other}." for s_other, idx_other, o_other in others) + "\n"
            prompt += relation_prompt
            print("subject: ", s)
            print(prompt)

            jb, _ = estimate.relation_operator_from_sample(
                model, tokenizer,
                s, prompt,
                subject_token_index= s_idx,
                layer = 15,
                device = model.device,
                # calculate_at_lnf = calculate_at_lnf

                layer_name_format = "model.decoder.layers.{}",
                ln_f_name = "model.decoder.final_layer_norm",
                n_layer_field = "num_hidden_layers"
            )
            print(jb.weight.norm(), jb.bias.norm())
            print()
            jbs.append(jb)
        
        weight = torch.stack([jb.weight for jb in jbs]).mean(dim=0)
        bias  = torch.stack([jb.bias for jb in jbs]).mean(dim=0)

        return weight, bias
    except RuntimeError as e:
        if(str(e).startswith("CUDA out of memory")):
            print("CUDA out of memory")
        if(num_icl > 1):
            num_icl -= 1
            print("trying with smaller icl >> ", num_icl)
            return get_averaged_JB(top_performers, relation_prompt, num_icl, calculate_at_lnf)
        else:
            raise Exception("RuntimeError >> can't calculate Jacobian with minimum number of icl examples")

def get_multiple_averaged_JB(top_performers, relation_prompt, N = 3, num_icl = 2, calculate_at_lnf = False):
    weights_and_biases = []
    sample_size = min(len(top_performers), num_icl + 2)
    for _ in tqdm(range(N)):
        cur_sample = random.sample(top_performers, k = sample_size)
        weight, bias = get_averaged_JB(cur_sample, relation_prompt, num_icl, calculate_at_lnf)
        weights_and_biases.append({
            'weight': weight,
            'bias'  : bias
        })
    return weights_and_biases

In [27]:
samples = [
        (b, -1, s) for b, c, s in filter_single_token[:10]
    ]
print(samples)

weights_and_biases = get_multiple_averaged_JB(
    samples, 
    relation_prompt="superlative of {} is", 
    N = 3, 
    calculate_at_lnf=False
)

[('bad', -1, 'worst'), ('big', -1, 'biggest'), ('bright', -1, 'brightest'), ('close', -1, 'closest'), ('deep', -1, 'deepest'), ('early', -1, 'earliest'), ('easy', -1, 'easiest'), ('fast', -1, 'fastest'), ('fine', -1, 'finest'), ('good', -1, 'best')]


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

subject:  bad
superlative of big is biggest.
superlative of bright is brightest.
superlative of {} is
tensor(52.3438, device='cuda:0', dtype=torch.float16) tensor(254.1250, device='cuda:0', dtype=torch.float16)

subject:  bright
superlative of easy is easiest.
superlative of big is biggest.
superlative of {} is
tensor(44.6875, device='cuda:0', dtype=torch.float16) tensor(258.7500, device='cuda:0', dtype=torch.float16)

subject:  easy
superlative of bright is brightest.
superlative of bad is worst.
superlative of {} is
tensor(56.1250, device='cuda:0', dtype=torch.float16) tensor(319.2500, device='cuda:0', dtype=torch.float16)

subject:  big
superlative of bright is brightest.
superlative of easy is easiest.
superlative of {} is
tensor(49.0938, device='cuda:0', dtype=torch.float16) tensor(278.7500, device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:  fast
superlative of close is closest.
superlative of easy is easiest.
superlative of {} is
tensor(46.3750, device='cuda:0', dtype=torch.float16) tensor(260.5000, device='cuda:0', dtype=torch.float16)

subject:  bad
superlative of close is closest.
superlative of fast is fastest.
superlative of {} is
tensor(55.7500, device='cuda:0', dtype=torch.float16) tensor(254.3750, device='cuda:0', dtype=torch.float16)

subject:  close
superlative of fast is fastest.
superlative of bad is worst.
superlative of {} is
tensor(57.1250, device='cuda:0', dtype=torch.float16) tensor(301.5000, device='cuda:0', dtype=torch.float16)

subject:  easy
superlative of close is closest.
superlative of bad is worst.
superlative of {} is
tensor(56.6562, device='cuda:0', dtype=torch.float16) tensor(256.7500, device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:  early
superlative of fine is finest.
superlative of fast is fastest.
superlative of {} is
tensor(55.5625, device='cuda:0', dtype=torch.float16) tensor(268., device='cuda:0', dtype=torch.float16)

subject:  fine
superlative of fast is fastest.
superlative of early is earliest.
superlative of {} is
tensor(56.5625, device='cuda:0', dtype=torch.float16) tensor(265.5000, device='cuda:0', dtype=torch.float16)

subject:  close
superlative of fine is finest.
superlative of early is earliest.
superlative of {} is
tensor(56.3438, device='cuda:0', dtype=torch.float16) tensor(259.5000, device='cuda:0', dtype=torch.float16)

subject:  fast
superlative of fine is finest.
superlative of close is closest.
superlative of {} is
tensor(49.7500, device='cuda:0', dtype=torch.float16) tensor(265.2500, device='cuda:0', dtype=torch.float16)



In [28]:
'shortest' in objects[:30]

False

In [30]:
relation_operator = estimate.RelationOperator(
    model = model,
    tokenizer= tokenizer,
    relation = prompt,
    layer = 15,
    weight = torch.stack(
        [wb['weight'] for wb in weights_and_biases]
    ).mean(dim=0),
    # bias = torch.stack(
    #     [wb['bias'] for wb in weights_and_biases]
    # ).mean(dim=0),
    bias = lst_sq_corner,

    layer_name_format = "model.decoder.layers.{}",
    ln_f_name = "model.decoder.final_layer_norm",
)

check_with_test_cases(relation_operator)

long, target: longest   ==>   predicted: [' longest', ' deepest', ' longer', ' largest', ' latest']
low, target: lowest   ==>   predicted: [' lowest', 'lowest', ' lightest', ' weakest', ' smallest']
short, target: shortest   ==>   predicted: [' shortest', ' longest', ' smallest', ' lightest', ' earliest']
simple, target: simplest   ==>   predicted: [' simplest', ' easiest', ' lightest', ' smallest', ' finest']
small, target: smallest   ==>   predicted: [' smallest', ' lowest', ' lightest', ' weakest', ' finest']
steep, target: steepest   ==>   predicted: [' steepest', ' deepest', ' highest', ' steep', ' strongest']
strong, target: strongest   ==>   predicted: [' strongest', ' brightest', ' best', ' greatest', ' highest']
weak, target: weakest   ==>   predicted: [' weakest', ' strongest', ' lightest', ' worst', ' lowest']
young, target: youngest   ==>   predicted: [' youngest', ' oldest', ' young', ' earliest', ' youth']


In [31]:
corner_estimator.get_vocab_representation(
    torch.stack(
        [wb['bias'] for wb in weights_and_biases]
    ).mean(dim=0), get_logits=True
)

[(' most', 16.906),
 (' best', 16.422),
 (' the', 15.648),
 (' greatest', 13.75),
 (' ', 13.617)]