In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import json
from tqdm.auto import tqdm
import random


import os
import sys
sys.path.append('..')

from relations import estimate
from util import model_utils
from dsets.counterfact import CounterFactDataset
from util import nethook
from operator import itemgetter

In [3]:
# Model selection: gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B
MODEL_NAME = "EleutherAI/gpt-j-6B"
mt = model_utils.ModelAndTokenizer(
    MODEL_NAME,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,  # load the weights in half precision
)
model, tokenizer = mt.model, mt.tokenizer
# NOTE(review): presumably the tokenizer has no pad token of its own; reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token

In [4]:
# Output directory for the averaged Jacobian archives (one `<relation_id>.npz` per relation).
jb_save_path = "gpt-j/jacobians_averaged"
os.makedirs(name=jb_save_path, exist_ok=True)

In [19]:
cut_off = 50  # minimum number of correct predictions a relation needs to be processed

###########################################################################
# Relation id -> prompt template ("{}" is the subject slot). Every entry also
# gets fresh `correct_predict` / `cached_JB` slots, both initialised to None.
relation_dct = {
    pid: {'relation': template, 'correct_predict': None, 'cached_JB': None}
    for pid, template in (
        ('P17',   '{} is located in the country of'),
        ('P641',  '{} plays the sport of'),
        ('P103',  'The mother tongue of {} is'),
        ('P176',  '{} is produced by'),
        ('P140',  'The official religion of {} is'),
        ('P1303', '{} plays the instrument'),
        ('P190',  'What is the twin city of {}? It is'),
        ('P740',  '{} was founded in'),
        ('P178',  '{} was developed by'),
        ('P495',  '{}, that originated in the country of'),
        ('P127',  '{} is owned by'),
        ('P413',  '{} plays in the position of'),
        ('P39',   '{}, who holds the position of'),
        ('P159',  'The headquarter of {} is located in'),
        ('P20',   '{} died in the city of'),
        ('P136',  'What does {} play? They play'),
        ('P106',  'The profession of {} is'),
        ('P30',   '{} is located in the continent of'),
        ('P937',  '{} worked in the city of'),
        ('P449',  '{} was released on'),
        ('P27',   '{} is a citizen of'),
        ('P101',  '{} works in the field of'),
        ('P19',   '{} was born in'),
        ('P37',   'In {}, an official language is'),
        ('P138',  '{}, named after'),
        ('P131',  '{} is located in'),
        ('P407',  '{} was written in'),
        ('P108',  '{}, who is employed by'),
        ('P36',   'The capital of {} is'),
    )
}
###########################################################################

In [20]:
root_path = "gpt-j"

# Report which relations have fewer than `cut_off` correct predictions.
# The loop variable is deliberately NOT called `relation`: that name is later
# bound to the averaged RelationOperator, and a leaked id string here caused
# confusing hidden state between cells.
for rel_id in relation_dct:
    prediction_file = f"{root_path}/{rel_id}/correct_prediction_{rel_id}.json"
    try:
        with open(prediction_file) as f:
            correct_predictions = json.load(f)
    except FileNotFoundError:
        # The scan may not have produced a file for every relation; report and move on.
        print("missing ", rel_id)
        continue
    if len(correct_predictions) < cut_off:
        print("skipped ", rel_id)

skipped  P1303
skipped  P190
skipped  P740
skipped  P413
skipped  P39
skipped  P136
skipped  P449
skipped  P138
skipped  P131
skipped  P407
skipped  P108


In [21]:
# Relation analysed by the single-relation cells that follow.
relation_id = "P17"
print(relation_dct[relation_id]["relation"])

{} is located in the country of


In [22]:
def get_top_performers(relation_id, consider_top=5, data_root="gpt-j",
                       max_subject_tokens=3, tok=None):
    """Pick high-scoring, mutually distinct (subject, object) examples for a relation.

    Reads the JSON list stored at ``{data_root}/{relation_id}/performance``,
    sorts it by ``'p@3'`` (highest first) and greedily keeps candidates whose
    subject and object have not been used by an earlier pick and whose subject
    tokenizes to at most ``max_subject_tokens`` tokens.

    Args:
        relation_id: relation key, e.g. ``"P17"``.
        consider_top: maximum number of examples to return (``0`` -> ``[]``).
        data_root: directory containing one sub-directory per relation.
        max_subject_tokens: subjects with more tokens than this are skipped.
        tok: callable tokenizer whose result has ``.input_ids``; defaults to
            the notebook-global ``tokenizer``.

    Returns:
        List of ``(subject, subject_token_index, object)`` tuples, at most
        ``consider_top`` long, in descending ``'p@3'`` order.

    Raises:
        KeyError: if a selected record has neither ``'sub_index'`` nor
            ``'h_index'`` under ``['misc']['h_info']``.
    """
    if tok is None:
        tok = tokenizer
    with open(f"{data_root}/{relation_id}/performance") as f:
        performance = json.load(f)
    performance.sort(key=itemgetter('p@3'), reverse=True)

    # Sets give O(1) duplicate checks (the original scanned lists).
    seen_subjects = set()
    seen_objects = set()
    top_performers = []

    for candidate in performance:
        if len(top_performers) >= consider_top:
            break
        subject = candidate['subject']
        obj = candidate['object']  # not `object`: avoid shadowing the builtin
        if subject in seen_subjects or obj in seen_objects:
            continue
        if len(tok(subject).input_ids) > max_subject_tokens:
            continue

        # Records store the subject token index under either 'sub_index' or 'h_index'.
        h_info = candidate['misc']['h_info']
        sub_idx = h_info['sub_index'] if 'sub_index' in h_info else h_info['h_index']

        seen_subjects.add(subject)
        seen_objects.add(obj)
        top_performers.append((subject, sub_idx, obj))

    return top_performers

# Best-scoring examples (up to the default five) for the selected relation.
top_performers = get_top_performers(relation_id=relation_id)
top_performers

[('Umarex', 2, 'Germany'),
 ('Harnaut', 2, 'India'),
 ('Haut Atlas', 2, 'Morocco'),
 ('Ufa', 1, 'Russia'),
 ('Canada Live', 1, 'Canada')]

In [23]:
# Prompt template of the selected relation; "{}" marks the subject slot.
r = relation_dct[relation_id]["relation"]
r

'{} is located in the country of'

In [24]:
def get_averaged_JB(top_performers, relation_prompt, num_icl=3, layer=15, rng=None):
    """Estimate one relation operator per example and average them element-wise.

    For every ``(subject, subject_token_index, object)`` in ``top_performers``,
    an in-context prompt is built from up to ``num_icl`` of the *other*
    examples (one ``"<filled template> <object>."`` line each), followed by the
    bare ``relation_prompt`` template. ``estimate.estimate_relation_operator``
    is run on that prompt at ``layer``. The subject, prompt and the norms of
    each estimate's weight and bias are printed.

    Args:
        top_performers: sequence of ``(subject, subject_token_index, object)``.
        relation_prompt: template containing a single ``{}`` subject slot.
        num_icl: maximum number of in-context examples per prompt.
        layer: layer passed to ``estimate_relation_operator``.
        rng: object with a ``sample`` method (e.g. ``random.Random(seed)``)
            used to pick in-context examples; defaults to the ``random`` module.

    Returns:
        ``(weight, bias)``: element-wise means over all per-example estimates.

    Raises:
        ValueError: if ``top_performers`` is empty.
    """
    if not top_performers:
        raise ValueError("top_performers is empty; nothing to average")
    sampler = random if rng is None else rng

    jbs = []
    for s, s_idx, o in tqdm(top_performers):
        # Keep list order (a set's iteration order depends on string hashing),
        # so sampling is reproducible under a fixed seed.
        others = [t for t in top_performers if t != (s, s_idx, o)]
        others = sampler.sample(others, k=min(num_icl, len(others)))
        icl_lines = [
            relation_prompt.format(s_other) + f" {o_other}."
            for s_other, _, o_other in others
        ]
        # Each example line ends with "\n"; no stray leading newline when there are none.
        prompt = "".join(line + "\n" for line in icl_lines) + relation_prompt
        print("subject: ", s)
        print(prompt)

        jb = estimate.estimate_relation_operator(
            model, tokenizer,
            s, prompt,
            subject_token_index=s_idx,
            layer=layer,
            device=model.device,
        )
        print(jb.weight.norm(), jb.bias.norm())
        print()
        jbs.append(jb)

    weight = torch.stack([jb.weight for jb in jbs]).mean(dim=0)
    bias = torch.stack([jb.bias for jb in jbs]).mean(dim=0)
    return weight, bias

weight, bias = get_averaged_JB(top_performers, r)

# Wrap the averaged Jacobian (weight, bias) in a RelationOperator built with
# layer=15, the same layer get_averaged_JB passes to the estimator by default.
relation = estimate.RelationOperator(
    model=model,
    tokenizer=tokenizer,
    weight=weight,
    bias=bias,
    relation=relation_dct[relation_id]['relation'],
    layer=15,
)

  0%|          | 0/5 [00:00<?, ?it/s]

subject:  Umarex
Haut Atlas is located in the country of Morocco.
Canada Live is located in the country of Canada.
Harnaut is located in the country of India.
{} is located in the country of
tensor(21.9531, device='cuda:0', dtype=torch.float16) tensor(342.7500, device='cuda:0', dtype=torch.float16)

subject:  Harnaut
Ufa is located in the country of Russia.
Haut Atlas is located in the country of Morocco.
Canada Live is located in the country of Canada.
{} is located in the country of
tensor(23.3281, device='cuda:0', dtype=torch.float16) tensor(306., device='cuda:0', dtype=torch.float16)

subject:  Haut Atlas
Umarex is located in the country of Germany.
Canada Live is located in the country of Canada.
Ufa is located in the country of Russia.
{} is located in the country of
tensor(10.4453, device='cuda:0', dtype=torch.float16) tensor(323., device='cuda:0', dtype=torch.float16)

subject:  Ufa
Haut Atlas is located in the country of Morocco.
Umarex is located in the country of Germany.
Ha

In [25]:
# Disabled sanity check: uncomment to print the averaged operator's predictions
# for a few well-known subjects (the operator is called as `relation(sub, device=...)`).
# test_subjects = [
#     "Hugh Jackman",
#     "Michael Phelps",
#     "Agatha Christie",
#     "Barack Obama",
#     "Sherlock Holmes",
#     "Alan Turing",
#     "Bill Gates",
#     "Michelangelo"
# ]

# for sub in test_subjects:
#     print(f"{sub} >> ", relation(sub, device= model.device))

In [26]:
# Hand-picked cases for the country relation: (subject, subject_token_index, target).
# NOTE(review): negative indices presumably count from the end of the subject's
# tokens (-2 is used for some two-word subjects) — confirm against RelationOperator.
test_cases = [
    # ("Statue of Liberty", -1, "United States"),
    ("The Great Wall",    -1, "China"),
    ("Niagara Falls",     -2, "Canada"),
    ("Valdemarsvik",      -1, "Sweden"),
    ("Kyoto University",  -2, "Japan"),
    ("Hattfjelldal",      -1, "Norway"),
    ("Ginza",             -1, "Japan"),
    ("Sydney Hospital",   -2, "Australia"),
    ("Mahalangur Himal",  -1, "Nepal"),
    ("Higashikagawa",     -1, "Japan"),
    ("Trento",            -1, "Italy"),
    ("Taj Mahal",         -1, "India"),
    ("Hagia Sophia",      -1, "Turkey"),
    ("Colosseum",         -1, "Italy"),
    ("Mount Everest",     -1, "Nepal"),
    ("Valencia",          -1, "Spain"),
    ("Lake Baikal",       -1, "Russia"),
    ("Merlion Park",      -1, "Singapore"),
    ("Cologne Cathedral", -1, "Germany"),
    ("Buda Castle",       -1, "Hungary"),
]

def check_with_test_cases(relation_operator, cases=None, device=None, top_k=5):
    """Print the operator's top-k predictions next to the expected target.

    Args:
        relation_operator: callable invoked as ``relation_operator(subject,
            subject_token_index=..., device=..., return_top_k=...)``.
        cases: iterable of ``(subject, subject_token_index, target)``;
            defaults to the notebook-level ``test_cases``.
        device: forwarded to the operator; defaults to ``model.device``.
        top_k: number of predictions requested per subject.

    Prints one ``"<subject>, target: <target>   ==>   predicted: <preds>"``
    line per case; returns None.
    """
    if cases is None:
        cases = test_cases
    if device is None:
        device = model.device  # loop-invariant; resolved once
    for subject, subject_token_index, target in cases:
        objects = relation_operator(
            subject,
            subject_token_index=subject_token_index,
            device=device,
            return_top_k=top_k,
        )
        print(f"{subject}, target: {target}   ==>   predicted: {objects}")

# Print the averaged operator's predictions for every hand-picked case in `test_cases`.
check_with_test_cases(relation)

The Great Wall, target: China   ==>   predicted: [' China', ' Russia', ' Canada', ' Germany', ' India']
Niagara Falls, target: Canada   ==>   predicted: [' Canada', ' Germany', ' United', ' Russia', ' Switzerland']
Valdemarsvik, target: Sweden   ==>   predicted: [' Canada', ' Germany', ' Russia', ' Sweden', ' France']
Kyoto University, target: Japan   ==>   predicted: [' Japan', ' Canada', ' Russia', ' India', ' Germany']
Hattfjelldal, target: Norway   ==>   predicted: [' Canada', ' Germany', ' Russia', ' Switzerland', ' Norway']
Ginza, target: Japan   ==>   predicted: [' Japan', ' Russia', ' Canada', ' Germany', ' France']
Sydney Hospital, target: Australia   ==>   predicted: [' Australia', ' Canada', ' France', ' Germany', ' India']
Mahalangur Himal, target: Nepal   ==>   predicted: [' India', ' Canada', ' Russia', ' Germany', ' Switzerland']
Higashikagawa, target: Japan   ==>   predicted: [' Japan', ' Canada', ' Germany', ' Russia', ' India']
Trento, target: Italy   ==>   predicted:

In [16]:
npz_path = f"{jb_save_path}/{relation_id}.npz"
print(npz_path)

# Weight and bias are stored together as a dict under the key 'JB'. A dict is
# saved as a pickled object array, so loading needs np.load(..., allow_pickle=True).
# `allow_pickle` is deliberately NOT passed to np.savez: on NumPy releases where
# savez has no such parameter, every extra keyword is written into the archive
# as an array (here a spurious 'allow_pickle' entry).
np.savez(
    npz_path,
    JB={
        'weight': weight.cpu().numpy(),
        'bias': bias.cpu().numpy(),
    },
)

gpt-j/jacobians_averaged/P17.npz


NameError: name 'weight' is not defined

In [17]:
# Debug: data directory of the currently selected relation. Uses `relation_id`
# because `relation` is rebound to the averaged RelationOperator earlier on.
f"{root_path}/{relation_id}"

'gpt-j/P36'

In [18]:
# For every relation with at least `cut_off` correct predictions, average the
# per-example Jacobians of its top performers and save them to `jb_save_path`.
for relation_id in relation_dct:
    print("zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz")
    print(relation_id, relation_dct[relation_id]['relation'])
    print("zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz")
    path = f"{root_path}/{relation_id}"
    try:
        with open(f"{path}/correct_prediction_{relation_id}.json") as f:
            correct_predictions = json.load(f)
    except (OSError, json.JSONDecodeError):
        # Only a missing/unreadable/malformed file is expected here. Anything
        # else (including KeyboardInterrupt) now propagates instead of being
        # silently swallowed by a bare `except:`.
        print(f"Error opening correct prediction {relation_id} (maybe the scan skipped this relation?)")
        print("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
        continue
    print(len(correct_predictions))

    if len(correct_predictions) < cut_off:
        print(f"skipped {relation_id} >> ", len(correct_predictions))
        print("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
        continue

    top_performers = get_top_performers(relation_id)

    relation_prompt = relation_dct[relation_id]['relation']
    print(relation_prompt)

    weight, bias = get_averaged_JB(top_performers, relation_prompt)

    save_path = f"{jb_save_path}/{relation_id}.npz"
    print("Saving weights and biases >> ", save_path)
    # Stored as a pickled dict under 'JB' (load with allow_pickle=True).
    # `allow_pickle` is not passed to np.savez: on NumPy releases without that
    # parameter it would be written into the archive as an extra array.
    np.savez(
        save_path,
        JB={
            'weight': weight.cpu().numpy(),
            'bias': bias.cpu().numpy(),
        },
    )
    print("----------------------------------------------------")


zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz
P407 {} was written in
zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz
29
skipped P407 >>  29
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz
P108 {}, who is employed by
zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz
43
skipped P108 >>  43
XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX
zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz
P36 The capital of {} is
zzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzzz
68
The capital of {} is


  0%|          | 0/5 [00:00<?, ?it/s]

subject:  United Arab Republic
The capital of Lazio is Rome.
The capital of Demerara is Georgetown.
The capital of Saxony is Dresden.
The capital of {} is
tensor(33.5938, device='cuda:0', dtype=torch.float16) tensor(205.1250, device='cuda:0', dtype=torch.float16)

subject:  Saxony
The capital of Demerara is Georgetown.
The capital of South Yemen is Aden.
The capital of United Arab Republic is Cairo.
The capital of {} is
tensor(30.6875, device='cuda:0', dtype=torch.float16) tensor(186.8750, device='cuda:0', dtype=torch.float16)

subject:  Demerara
The capital of United Arab Republic is Cairo.
The capital of Lazio is Rome.
The capital of Saxony is Dresden.
The capital of {} is
tensor(44.0312, device='cuda:0', dtype=torch.float16) tensor(183.8750, device='cuda:0', dtype=torch.float16)

subject:  South Yemen
The capital of Demerara is Georgetown.
The capital of United Arab Republic is Cairo.
The capital of Lazio is Rome.
The capital of {} is
tensor(38.3125, device='cuda:0', dtype=torch.flo