In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import json
from tqdm.auto import tqdm
import random
import transformers

import os
import sys
sys.path.append('..')

from relations import estimate
from util import model_utils
from baukit import nethook
from operator import itemgetter

In [3]:
MODEL_NAME = "EleutherAI/gpt-j-6B"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B
mt = model_utils.ModelAndTokenizer(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float16)

model = mt.model
tokenizer = mt.tokenizer
tokenizer.pad_token = tokenizer.eos_token

print(f"{MODEL_NAME} ==> device: {model.device}, memory: {model.get_memory_footprint()}")

EleutherAI/gpt-j-6B ==> device: cuda:0, memory: 12219206136


In [4]:
with open("comperative-superlative.txt") as f:
    lines = f.readlines()
    words = []
    for line in lines:
        w = line.strip()
        if(len(w) == 0):
            continue
        words.append(w)

base = words[0 : len(words) : 3]
comparative = words[1 : len(words) : 3]
superlative = words[2 : len(words) : 3]

list(zip(base, comparative, superlative))

[('angry', 'angrier', 'angriest'),
 ('bad', 'worse', 'worst'),
 ('big', 'bigger', 'biggest'),
 ('bitter', 'bitterer', 'bitterest'),
 ('black', 'blacker', 'blackest'),
 ('bland', 'blander', 'blandest'),
 ('bloody', 'bloodier', 'bloodiest'),
 ('blue', 'bluer', 'bluest'),
 ('bold', 'bolder', 'boldest'),
 ('bossy', 'bossier', 'bossiest'),
 ('brave', 'braver', 'bravest'),
 ('brief', 'briefer', 'briefest'),
 ('bright', 'brighter', 'brightest'),
 ('broad', 'broader', 'broadest'),
 ('busy', 'busier', 'busiest'),
 ('calm', 'calmer', 'calmest'),
 ('cheap', 'cheaper', 'cheapest'),
 ('chewy', 'chewier', 'chewiest'),
 ('chubby', 'chubbier', 'chubbiest'),
 ('classy', 'classier', 'classiest'),
 ('clean', 'cleaner', 'cleanest'),
 ('clear', 'clear', 'clearest'),
 ('clever', 'cleverer', 'cleverest'),
 ('close', 'closer', 'closest'),
 ('cloudy', 'cloudier', 'cloudiest'),
 ('clumsy', 'clumsier', 'clumsiest'),
 ('coarse', 'coarser', 'coarsest'),
 ('cold', 'colder', 'coldest'),
 ('cool', 'cooler', 'coolest'

In [5]:
tokenizer([' classy', ' classier', ' classiest'], padding=True, return_tensors='pt')

{'input_ids': tensor([[48486, 50256],
        [ 1398,   959],
        [ 1398,  6386]]), 'attention_mask': tensor([[1, 0],
        [1, 1],
        [1, 1]])}

In [6]:
filter_single_token = []
for b, c, s in zip(base, comparative, superlative):
    tokenized = tokenizer([" " + b, " " + c, " " + s], padding=True, return_tensors='pt').input_ids
    if(tokenized.shape[1] > 1):
        continue
    filter_single_token.append((b, c, s))
filter_single_token

[('bad', 'worse', 'worst'),
 ('big', 'bigger', 'biggest'),
 ('bright', 'brighter', 'brightest'),
 ('cheap', 'cheaper', 'cheapest'),
 ('close', 'closer', 'closest'),
 ('cool', 'cooler', 'coolest'),
 ('dark', 'darker', 'darkest'),
 ('deep', 'deeper', 'deepest'),
 ('early', 'earlier', 'earliest'),
 ('easy', 'easier', 'easiest'),
 ('fast', 'faster', 'fastest'),
 ('fine', 'finer', 'finest'),
 ('full', 'fuller', 'fullest'),
 ('good', 'better', 'best'),
 ('great', 'greater', 'greatest'),
 ('happy', 'happier', 'happiest'),
 ('hard', 'harder', 'hardest'),
 ('heavy', 'heavier', 'heaviest'),
 ('high', 'higher', 'highest'),
 ('hot', 'hotter', 'hottest'),
 ('large', 'larger', 'largest'),
 ('late', 'later', 'latest'),
 ('long', 'longer', 'longest'),
 ('low', 'lower', 'lowest'),
 ('near', 'nearer', 'nearest'),
 ('new', 'newer', 'newest'),
 ('poor', 'poorer', 'poorest'),
 ('quick', 'quicker', 'quickest'),
 ('rich', 'richer', 'richest'),
 ('safe', 'safer', 'safest'),
 ('short', 'shorter', 'shortest'),


In [7]:
prompt = """superlative of late is latest
superlative of strong is strongest
superlaitve of {} is"""

# prompt = """grape ends with E
# monitor ends with R
# glass ends with"""

words = ['safe', 'low', 'weak', 'tough']

for w in words:
    txt, ret_dict = model_utils.generate_fast(
        model, tokenizer, 
        prompts=[prompt.format(w)], max_new_tokens=10, 
        get_answer_tokens=True, argmax_greedy=True
    )
    print(f"{w} ===> {ret_dict['answer'][0]['top_token']}")


safe ===>  safest
low ===>  lowest
weak ===>  weakest
tough ===>  toughest


In [91]:
objects = [" " + o[2] for o in filter_single_token]

from relations.corner import CornerEstimator
corner_estimator = CornerEstimator(model=model, tokenizer=tokenizer)

In [92]:
simple_corner = corner_estimator.estimate_simple_corner(objects, scale_up=70)
print(simple_corner.norm().item(), corner_estimator.get_vocab_representation(simple_corner, get_logits=True))

51.53125 [(' largest', 104.312), (' highest', 103.062), (' longest', 99.688), (' strongest', 99.375), (' smallest', 99.125)]


In [93]:
lin_inv_corner = corner_estimator.estimate_lin_inv_corner(objects, target_logit_value=50)
print(lin_inv_corner.norm().item(), corner_estimator.get_vocab_representation(lin_inv_corner, get_logits=True))

calculating inverse of unbedding weights . . .
48.71875 [(' strongest', 74.875), (' fastest', 73.938), (' longest', 73.375), (' largest', 72.0), (' smallest', 70.562)]


In [96]:
lst_sq_corner = corner_estimator.estimate_corner_lstsq_solve(objects[:30], target_logit=30)
print(lst_sq_corner.norm().item(), corner_estimator.get_vocab_representation(lst_sq_corner, get_logits=True))

42.34375 [(' strongest', 85.0), (' best', 83.562), (' nearest', 83.375), (' closest', 83.312), (' highest', 83.188)]


In [97]:
# avg_corner = corner_estimator.estimate_average_corner_with_gradient_descent(objects, average_on=5, target_logit_value=50, verbose=False)
# print(avg_corner.norm().item(), corner_estimator.get_vocab_representation(avg_corner))

In [98]:
def check_with_test_cases(relation_operator):
    test_cases = [
        (b, -1, s) for b, c, s in filter_single_token[30:]
    ]
    for subject, subject_token_index, target in test_cases:
        answer = relation_operator(
            subject,
            subject_token_index=subject_token_index,
            device=model.device,
            return_top_k=5,
        )
        print(f"{subject}, target: {target}   ==>   predicted: {answer}")

In [99]:
relation = estimate.RelationOperator(
    model = model,
    tokenizer = tokenizer,
    relation = prompt,
    layer = 15,
    weight = torch.eye(model.config.n_embd).to(model.dtype).to(model.device),
    bias = lst_sq_corner
)
check_with_test_cases(relation)

short, target: shortest   ==>   predicted: [' highest', ' largest', ' greatest', ' smallest', ' lowest']
simple, target: simplest   ==>   predicted: [' lowest', ' largest', ' highest', ' greatest', ' smallest']
small, target: smallest   ==>   predicted: [' highest', ' best', ' largest', ' strongest', ' lowest']
smart, target: smartest   ==>   predicted: [' highest', ' best', ' lowest', ' closest', ' smallest']
strong, target: strongest   ==>   predicted: [' nearest', ' lowest', ' highest', ' best', ' greatest']
tall, target: tallest   ==>   predicted: [' smallest', ' largest', ' highest', ' best', ' greatest']
tough, target: toughest   ==>   predicted: [' lowest', ' highest', ' best', ' worst', ' greatest']
weak, target: weakest   ==>   predicted: [' lowest', ' largest', ' smallest', ' highest', ' greatest']
wealthy, target: wealthiest   ==>   predicted: [' largest', ' greatest', ' smallest', ' highest', ' best']
wide, target: widest   ==>   predicted: [' greatest', ' smallest', ' larg

In [72]:
def get_averaged_JB(top_performers, relation_prompt, num_icl = 3, calculate_at_lnf = False):
    try:
        jbs = []
        for s, s_idx, o in tqdm(top_performers):
            others = set(top_performers) - {(s, s_idx, o)}
            others = random.sample(list(others), k = min(num_icl, len(list(others)))) 
            prompt = ""
            prompt += "\n".join(relation_prompt.format(s_other) + f" {o_other}." for s_other, idx_other, o_other in others) + "\n"
            prompt += relation_prompt
            print("subject: ", s)
            print(prompt)

            jb, _ = estimate.relation_operator_from_sample(
                model, tokenizer,
                s, prompt,
                subject_token_index= s_idx,
                layer = 15,
                device = model.device,
                # calculate_at_lnf = calculate_at_lnf
            )
            print(jb.weight.norm(), jb.bias.norm())
            print()
            jbs.append(jb)
        
        weight = torch.stack([jb.weight for jb in jbs]).mean(dim=0)
        bias  = torch.stack([jb.bias for jb in jbs]).mean(dim=0)

        return weight, bias
    except RuntimeError as e:
        if(str(e).startswith("CUDA out of memory")):
            print("CUDA out of memory")
        if(num_icl > 1):
            num_icl -= 1
            print("trying with smaller icl >> ", num_icl)
            return get_averaged_JB(top_performers, relation_prompt, num_icl, calculate_at_lnf)
        else:
            raise Exception("RuntimeError >> can't calculate Jacobian with minimum number of icl examples")

def get_multiple_averaged_JB(top_performers, relation_prompt, N = 3, num_icl = 2, calculate_at_lnf = False):
    weights_and_biases = []
    sample_size = min(len(top_performers), num_icl + 2)
    for _ in tqdm(range(N)):
        cur_sample = random.sample(top_performers, k = sample_size)
        weight, bias = get_averaged_JB(cur_sample, relation_prompt, num_icl, calculate_at_lnf)
        weights_and_biases.append({
            'weight': weight,
            'bias'  : bias
        })
    return weights_and_biases

In [74]:
samples = [
        (b, -1, s) for b, c, s in filter_single_token[:10]
    ]
print(samples)

weights_and_biases = get_multiple_averaged_JB(
    samples, 
    relation_prompt="superlative of {} is", 
    N = 3, 
    calculate_at_lnf=False
)

[('bad', -1, 'worst'), ('big', -1, 'biggest'), ('bright', -1, 'brightest'), ('cheap', -1, 'cheapest'), ('close', -1, 'closest'), ('cool', -1, 'coolest'), ('dark', -1, 'darkest'), ('deep', -1, 'deepest'), ('early', -1, 'earliest'), ('easy', -1, 'easiest')]


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

subject:  easy
superlative of bad is worst.
superlative of dark is darkest.
superlative of {} is
tensor(31.9062, device='cuda:0', dtype=torch.float16) tensor(298., device='cuda:0', dtype=torch.float16)

subject:  close
superlative of dark is darkest.
superlative of easy is easiest.
superlative of {} is
tensor(29.2656, device='cuda:0', dtype=torch.float16) tensor(304., device='cuda:0', dtype=torch.float16)

subject:  bad
superlative of close is closest.
superlative of dark is darkest.
superlative of {} is
tensor(31.6250, device='cuda:0', dtype=torch.float16) tensor(283., device='cuda:0', dtype=torch.float16)

subject:  dark
superlative of close is closest.
superlative of easy is easiest.
superlative of {} is
tensor(33.5312, device='cuda:0', dtype=torch.float16) tensor(272., device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:  early
superlative of big is biggest.
superlative of easy is easiest.
superlative of {} is
tensor(30.7031, device='cuda:0', dtype=torch.float16) tensor(272., device='cuda:0', dtype=torch.float16)

subject:  cheap
superlative of easy is easiest.
superlative of big is biggest.
superlative of {} is
tensor(28.9688, device='cuda:0', dtype=torch.float16) tensor(284.2500, device='cuda:0', dtype=torch.float16)

subject:  easy
superlative of big is biggest.
superlative of cheap is cheapest.
superlative of {} is
tensor(30.8750, device='cuda:0', dtype=torch.float16) tensor(292.5000, device='cuda:0', dtype=torch.float16)

subject:  big
superlative of easy is easiest.
superlative of early is earliest.
superlative of {} is
tensor(27.0156, device='cuda:0', dtype=torch.float16) tensor(278.2500, device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:  bad
superlative of big is biggest.
superlative of cheap is cheapest.
superlative of {} is
tensor(28.6250, device='cuda:0', dtype=torch.float16) tensor(289., device='cuda:0', dtype=torch.float16)

subject:  easy
superlative of cheap is cheapest.
superlative of big is biggest.
superlative of {} is
tensor(29.8594, device='cuda:0', dtype=torch.float16) tensor(308.5000, device='cuda:0', dtype=torch.float16)

subject:  big
superlative of bad is worst.
superlative of cheap is cheapest.
superlative of {} is
tensor(25.4844, device='cuda:0', dtype=torch.float16) tensor(307., device='cuda:0', dtype=torch.float16)

subject:  cheap
superlative of easy is easiest.
superlative of big is biggest.
superlative of {} is
tensor(28.9688, device='cuda:0', dtype=torch.float16) tensor(284.2500, device='cuda:0', dtype=torch.float16)



In [75]:
'shorter' in objects[:30]

False

In [100]:
relation_operator = estimate.RelationOperator(
    model = model,
    tokenizer= tokenizer,
    relation = prompt,
    layer = 15,
    weight = torch.stack(
        [wb['weight'] for wb in weights_and_biases]
    ).mean(dim=0),
    # bias = torch.stack(
    #     [wb['bias'] for wb in weights_and_biases]
    # ).mean(dim=0),
    bias = lst_sq_corner
)

check_with_test_cases(relation_operator)

short, target: shortest   ==>   predicted: [' shortest', ' longest', ' quickest', ' smallest', ' earliest']
simple, target: simplest   ==>   predicted: [' simplest', ' easiest', ' cheapest', ' smallest', ' best']
small, target: smallest   ==>   predicted: [' smallest', ' lowest', ' cheapest', ' poorest', ' largest']
smart, target: smartest   ==>   predicted: [' smartest', ' best', ' brightest', ' easiest', ' quickest']
strong, target: strongest   ==>   predicted: [' strongest', ' weakest', ' heaviest', ' hardest', ' best']
tall, target: tallest   ==>   predicted: [' tallest', ' longest', ' shortest', ' highest', ' best']
tough, target: toughest   ==>   predicted: [' toughest', ' hardest', ' strongest', ' worst', ' best']
weak, target: weakest   ==>   predicted: [' weakest', ' strongest', ' poorest', ' lowest', ' worst']
wealthy, target: wealthiest   ==>   predicted: [' richest', ' cheapest', ' poorest', ' best', ' wealthiest']
wide, target: widest   ==>   predicted: [' widest', ' large

In [77]:
corner_estimator.get_vocab_representation(
    torch.stack(
        [wb['bias'] for wb in weights_and_biases]
    ).mean(dim=0)
)

[' most', ' the', '\n', ' least', ' ']

In [90]:
from typing import Any, Sequence, TypeAlias, List

unembedder = nethook.get_module(model, "lm_head")
ln_f = nethook.get_module(model, "transformer.ln_f")

def estimate_corner_lstsq_solve(
    target_words: List[str],
    target_logit: int = 50,
):
    target_tokenized = tokenizer(target_words, padding=True, return_tensors="pt").to(model.device)
    # print(target_tokenized)
    W = torch.stack([unembedder.weight[r[0].item()] for r in target_tokenized.input_ids])
    # print(target_tokenized.input_ids.shape, W.shape)
    b = unembedder.bias[target_tokenized.input_ids]
    b = b.reshape(b.shape[0])
    y = (torch.ones(len(target_words)) * target_logit).to(model.dtype).to(model.device) - b
    # print(b.shape, y.shape)
    if(model.dtype == torch.float16):
        W = W.to(torch.float32)
        y = y.to(torch.float32)
    x = torch.linalg.lstsq(W, y).solution
    print(W@x + b)
    return x.to(model.dtype)

corner = estimate_corner_lstsq_solve(objects)
corner.shape, corner.norm()

tensor([50.0028, 49.9846, 50.0143, 49.9919, 50.0005, 50.0105, 49.9887, 50.0128,
        50.0124, 49.9906, 49.9848, 49.9848, 49.9904, 50.0138, 49.9965, 49.9854,
        49.9973, 50.0067, 50.0075, 50.0016, 49.9978, 49.9894, 50.0059, 50.0106,
        50.0045, 50.0000, 50.0047, 50.0028, 49.9897, 50.0096, 50.0153, 50.0059,
        50.0144, 50.0063, 50.0045, 49.9968, 50.0085, 49.9871, 49.9907, 50.0133,
        50.0077], device='cuda:0')


(torch.Size([4096]), tensor(74., device='cuda:0', dtype=torch.float16))

In [28]:
def get_vocab_representation(
    h, 
    perform_layer_norm = True, return_top_k = 5, get_logits = False
):
    """
    get representation of vector `h` in the vocabulary space. basically applied logit lens
    """
    z = h.clone()
    if(perform_layer_norm == True):
        z = ln_f(z)
    logits = unembedder(z)
    token_ids = logits.topk(dim=-1, k=return_top_k).indices.squeeze().tolist()
    logit_values = logits.topk(dim=-1, k=return_top_k).values.squeeze().tolist()
    return [
        tokenizer.decode(t) if get_logits == False else (tokenizer.decode(t), np.round(v, 3))
        for t, v in zip(token_ids, logit_values)
    ]

get_vocab_representation(corner, get_logits=True)

[(' more', 75.688),
 (' lighter', 75.5),
 (' later', 74.438),
 (' lower', 74.312),
 (' better', 74.312)]