In [22]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [23]:
import torch
import numpy as np
import json
from tqdm.auto import tqdm
import random
import transformers

import os
import sys
sys.path.append('..')

from relations import estimate
from util import model_utils
from baukit import nethook
from operator import itemgetter

In [24]:
# Checkpoint to load, in half precision.
MODEL_NAME = "EleutherAI/gpt-j-6B"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B
mt = model_utils.ModelAndTokenizer(
    MODEL_NAME,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)

# Unpack the wrapper into the two globals the rest of the notebook relies on.
model, tokenizer = mt.model, mt.tokenizer
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token

print(f"{MODEL_NAME} ==> device: {model.device}, memory: {model.get_memory_footprint()}")

EleutherAI/gpt-j-6B ==> device: cuda:0, memory: 12219206136


In [25]:
# Few-shot template for the "first letter" relation; `{}` is the subject slot.
prompt = (
    "grape starts with G\n"
    "monitor starts with M\n"
    "{} starts with"
)

# Subjects used for this sanity check and by the later evaluation cells.
words = [
    'month', 'major', 'star', 'areas', 'future',
    'space', 'committee', 'london', 'washington', 'meeting',
]

# prompt = """grape ends with E
# monitor ends with R
# glass ends with"""

# Greedy-decode each filled-in prompt and show the model's top answer token.
for w in words:
    txt, ret_dict = model_utils.generate_fast(
        model,
        tokenizer,
        prompts=[prompt.format(w)],
        max_new_tokens=10,
        get_answer_tokens=True,
        argmax_greedy=True,
    )
    print(f"{w} ===> {ret_dict['answer'][0]['top_token']}")


month ===>  M
major ===>  A
star ===>  S
areas ===>  A
future ===>  F
space ===>  S
committee ===>  C
london ===>  L
washington ===>  W
meeting ===>  M


In [26]:
# Build the project-local CornerEstimator on top of the loaded model and
# tokenizer; later cells use it to estimate a "corner" vector and to read
# vectors back out as vocabulary tokens.
from relations.corner import CornerEstimator
corner_estimator = CornerEstimator(model=model, tokenizer=tokenizer)

In [27]:
import string
# Candidate object tokens: every uppercase ASCII letter with a leading space
# (" A", " B", ..., " Z").
objects = [f" {letter}" for letter in string.ascii_uppercase]

In [28]:
# Estimate the "simple corner" vector for the candidate letter tokens, then
# report its norm together with the vocabulary tokens it decodes to.
simple_corner = corner_estimator.estimate_simple_corner(objects, scale_up=70)
simple_corner_norm = simple_corner.norm().item()
simple_corner_tokens = corner_estimator.get_vocab_representation(simple_corner)
print(simple_corner_norm, simple_corner_tokens)

66.875 [' C', ' S', ' M', ' L', ' P']


In [29]:
# lin_inv_corner = corner_estimator.estimate_lin_inv_corner(objects, target_logit_value=50)
# print(lin_inv_corner.norm().item(), corner_estimator.get_vocab_representation(lin_inv_corner))

In [31]:
# avg_corner = corner_estimator.estimate_average_corner_with_gradient_descent(objects, average_on=5, target_logit_value=50, verbose=False)
# print(avg_corner.norm().item(), corner_estimator.get_vocab_representation(avg_corner))

In [32]:
def check_with_test_cases(relation_operator):
    """Print the operator's top-5 predictions for every word in the global `words`.

    Each word is passed as the subject with `subject_token_index=-1`, and the
    printed target is the word's upper-cased first character.

    Args:
        relation_operator: callable invoked as
            `relation_operator(subject, subject_token_index=..., device=..., return_top_k=5)`.

    Relies on the notebook globals `words` and `model`. Returns nothing; the
    results are only printed.
    """
    # Build all (subject, expected letter) pairs up front, before any call.
    expected_letters = [(word, word[0].upper()) for word in words]
    call_options = dict(subject_token_index=-1, device=model.device, return_top_k=5)

    for subject, target in expected_letters:
        objects = relation_operator(subject, **call_options)
        print(f"{subject}, target: {target}   ==>   predicted: {objects}")

In [33]:
# Baseline operator: identity weight (so the hidden state passes through
# unchanged) plus the simple corner vector as bias, applied at layer 15.
identity_weight = torch.eye(model.config.n_embd).to(dtype=model.dtype, device=model.device)

relation = estimate.RelationOperator(
    model=model,
    tokenizer=tokenizer,
    relation=prompt,
    layer=15,
    weight=identity_weight,
    bias=simple_corner,
)
check_with_test_cases(relation)

month, target: M   ==>   predicted: [' C', ' S', ' M', ' P', ' B']
major, target: M   ==>   predicted: [' M', ' C', ' D', ' P', ' S']
star, target: S   ==>   predicted: [' S', ' M', ' T', ' P', ' C']
areas, target: A   ==>   predicted: [' C', ' B', ' S', ' P', ' F']
future, target: F   ==>   predicted: [' T', ' C', ' E', ' D', ' B']
space, target: S   ==>   predicted: [' S', ' M', ' P', ' B', ' R']
committee, target: C   ==>   predicted: [' C', ' B', ' T', ' D', ' S']
london, target: L   ==>   predicted: [' C', ' B', ' S', ' D', ' P']
washington, target: W   ==>   predicted: [' D', ' C', ' L', ' S', ' B']
meeting, target: M   ==>   predicted: [' C', ' T', ' S', ' M', ' B']


In [34]:
def get_averaged_JB(top_performers, relation_prompt, num_icl = 3, calculate_at_lnf = False):
    """Average the per-sample relation operators (weight, bias) over `top_performers`.

    For every `(subject, subject_token_index, object)` sample, a few-shot prompt
    is built from up to `num_icl` *other* samples followed by `relation_prompt`,
    and `estimate.relation_operator_from_sample` is evaluated at layer 15. The
    returned operators' `weight` and `bias` are stacked and averaged.

    If a `RuntimeError` occurs (e.g. CUDA out of memory), the whole computation
    is retried with one fewer in-context example until `num_icl` reaches 1.

    Args:
        top_performers: non-empty iterable of hashable
            `(subject, subject_token_index, object)` tuples.
        relation_prompt: template with one `{}` slot for the subject.
        num_icl: maximum number of in-context examples per prompt.
        calculate_at_lnf: accepted for interface compatibility; it is currently
            not forwarded to the estimator (the forwarding line is commented out).

    Returns:
        `(weight, bias)`: the element-wise means over all samples.

    Raises:
        ValueError: if `top_performers` is empty.
        Exception: if a `RuntimeError` still occurs with `num_icl <= 1`; the
            original error is attached as `__cause__`.

    Relies on the notebook globals `model`, `tokenizer`, `estimate` and `tqdm`.
    """
    samples = list(top_performers)
    if not samples:
        # Without this guard torch.stack([]) raises a RuntimeError that would be
        # mistaken for a memory problem and trigger pointless retries.
        raise ValueError(
            "top_performers must contain at least one (subject, subject_token_index, object) sample"
        )

    # De-duplicate while keeping input order. A plain set() iterates in
    # hash-dependent order, which makes random.sample non-reproducible across
    # kernel sessions even with a fixed seed.
    unique_samples = list(dict.fromkeys(samples))

    while True:
        try:
            jbs = []
            for s, s_idx, o in tqdm(samples):
                others = [sample for sample in unique_samples if sample != (s, s_idx, o)]
                others = random.sample(others, k = min(num_icl, len(others)))
                demonstrations = [
                    relation_prompt.format(s_other) + f" {o_other}."
                    for s_other, idx_other, o_other in others
                ]
                # Joining this way avoids a stray leading newline when there are
                # no in-context examples (single sample or num_icl == 0).
                prompt = "\n".join(demonstrations + [relation_prompt])
                print("subject: ", s)
                print(prompt)

                jb, _ = estimate.relation_operator_from_sample(
                    model, tokenizer,
                    s, prompt,
                    subject_token_index= s_idx,
                    layer = 15,
                    device = model.device,
                    # calculate_at_lnf = calculate_at_lnf
                )
                print(jb.weight.norm(), jb.bias.norm())
                print()
                jbs.append(jb)

            weight = torch.stack([jb.weight for jb in jbs]).mean(dim=0)
            bias  = torch.stack([jb.bias for jb in jbs]).mean(dim=0)

            return weight, bias
        except RuntimeError as e:
            if(str(e).startswith("CUDA out of memory")):
                print("CUDA out of memory")
            else:
                # Surface the real error instead of retrying silently.
                print(f"RuntimeError: {e}")
            if(num_icl <= 1):
                raise Exception(
                    "RuntimeError >> can't calculate Jacobian with minimum number of icl examples"
                ) from e

        # The retry deliberately happens OUTSIDE the except clause: while the
        # handler is active the exception's traceback keeps the failed attempt's
        # frames (and their tensors) alive, so the memory could not be reclaimed.
        num_icl -= 1
        print("trying with smaller icl >> ", num_icl)

def get_multiple_averaged_JB(top_performers, relation_prompt, N = 3, num_icl = 2, calculate_at_lnf = False):
    """Run `get_averaged_JB` on `N` random subsets of `top_performers`.

    Each subset holds `min(len(top_performers), num_icl + 2)` samples drawn
    without replacement using the global `random` state.

    Args:
        top_performers: sequence of `(subject, subject_token_index, object)` tuples.
        relation_prompt: template with one `{}` slot for the subject.
        N: number of subsets to draw.
        num_icl: forwarded to `get_averaged_JB`; also sets the subset size.
        calculate_at_lnf: forwarded to `get_averaged_JB`.

    Returns:
        A list of `N` dicts, each with the keys `'weight'` and `'bias'`.
    """
    draw_size = min(len(top_performers), num_icl + 2)

    estimates = []
    for _ in tqdm(range(N)):
        subset = random.sample(top_performers, k = draw_size)
        averaged = get_averaged_JB(subset, relation_prompt, num_icl, calculate_at_lnf)
        estimates.append(dict(zip(('weight', 'bias'), averaged)))
    return estimates

In [35]:
# Turn every word into a (subject, subject_token_index, object) triple, then
# keep entries 1..4 as the samples for Jacobian estimation.
all_samples = [(w, -1, f" {w[0].upper()}") for w in words]
samples = all_samples[1:5]
print(samples)

# Three rounds of averaged weight/bias estimates over random subsets.
weights_and_biases = get_multiple_averaged_JB(
    samples,
    relation_prompt="{} starts with",
    N=3,
    calculate_at_lnf=False,
)

[('major', -1, ' M'), ('star', -1, ' S'), ('areas', -1, ' A'), ('future', -1, ' F')]


  0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/4 [00:00<?, ?it/s]

subject:  future
star starts with  S.
major starts with  M.
{} starts with
tensor(22.5000, device='cuda:0', dtype=torch.float16) tensor(387.7500, device='cuda:0', dtype=torch.float16)

subject:  areas
star starts with  S.
future starts with  F.
{} starts with
tensor(17.0156, device='cuda:0', dtype=torch.float16) tensor(412.5000, device='cuda:0', dtype=torch.float16)

subject:  major
future starts with  F.
areas starts with  A.
{} starts with
tensor(19.5781, device='cuda:0', dtype=torch.float16) tensor(399.7500, device='cuda:0', dtype=torch.float16)

subject:  star
areas starts with  A.
major starts with  M.
{} starts with
tensor(1.1416, device='cuda:0', dtype=torch.float16) tensor(413.5000, device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:  major
star starts with  S.
areas starts with  A.
{} starts with
tensor(19.8438, device='cuda:0', dtype=torch.float16) tensor(407.7500, device='cuda:0', dtype=torch.float16)

subject:  future
major starts with  M.
star starts with  S.
{} starts with
tensor(21.2031, device='cuda:0', dtype=torch.float16) tensor(386.2500, device='cuda:0', dtype=torch.float16)

subject:  star
major starts with  M.
areas starts with  A.
{} starts with
tensor(3.6602, device='cuda:0', dtype=torch.float16) tensor(399.5000, device='cuda:0', dtype=torch.float16)

subject:  areas
future starts with  F.
major starts with  M.
{} starts with
tensor(21.2656, device='cuda:0', dtype=torch.float16) tensor(392.7500, device='cuda:0', dtype=torch.float16)



  0%|          | 0/4 [00:00<?, ?it/s]

subject:  star
major starts with  M.
areas starts with  A.
{} starts with
tensor(3.6602, device='cuda:0', dtype=torch.float16) tensor(399.5000, device='cuda:0', dtype=torch.float16)

subject:  areas
major starts with  M.
star starts with  S.
{} starts with
tensor(21.4688, device='cuda:0', dtype=torch.float16) tensor(396.7500, device='cuda:0', dtype=torch.float16)

subject:  future
star starts with  S.
major starts with  M.
{} starts with
tensor(22.5000, device='cuda:0', dtype=torch.float16) tensor(387.7500, device='cuda:0', dtype=torch.float16)

subject:  major
star starts with  S.
future starts with  F.
{} starts with
tensor(26.1875, device='cuda:0', dtype=torch.float16) tensor(378.2500, device='cuda:0', dtype=torch.float16)



In [47]:
# Show the few-shot prompt template that the operators below are evaluated with.
prompt

'grape starts with G\nmonitor starts with M\n{} starts with'

In [49]:
# Average the estimated weights across all rounds; the bias stays the simple
# corner vector rather than the averaged estimated bias.
mean_weight = torch.stack([wb['weight'] for wb in weights_and_biases]).mean(dim=0)

relation_operator = estimate.RelationOperator(
    model=model,
    tokenizer=tokenizer,
    relation=prompt,
    layer=15,
    weight=mean_weight,
    # bias = torch.stack(
    #     [wb['bias'] for wb in weights_and_biases]
    # ).mean(dim=0),
    bias=simple_corner,
)

check_with_test_cases(relation_operator)

month, target: M   ==>   predicted: [' M', ' C', ' T', ' L', ' B']
major, target: M   ==>   predicted: [' M', ' D', ' C', ' T', ' S']
star, target: S   ==>   predicted: [' T', ' L', ' C', ' P', ' S']
areas, target: A   ==>   predicted: [' T', ' M', ' S', ' C', ' P']
future, target: F   ==>   predicted: [' M', ' D', ' C', ' S', ' P']
space, target: S   ==>   predicted: [' S', ' T', ' M', ' C', ' P']
committee, target: C   ==>   predicted: [' C', ' T', ' M', ' D', ' P']
london, target: L   ==>   predicted: [' L', ' M', ' C', ' D', ' T']
washington, target: W   ==>   predicted: [' D', ' W', ' M', ' C', ' S']
meeting, target: M   ==>   predicted: [' M', ' T', ' C', ' P', ' S']


In [46]:
# Decode the round-averaged estimated bias into vocabulary tokens.
mean_bias = torch.stack([wb['bias'] for wb in weights_and_biases]).mean(dim=0)
corner_estimator.get_vocab_representation(mean_bias)

[' ', ' a', ' the', ' A', ' an']

In [55]:
# Check how the tokenizer splits a space-prefixed word.
tokenizer.tokenize(" fifteen")

['Ġfifteen']