In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import json
from tqdm.auto import tqdm
import random


import os
import sys
sys.path.append('..')

from relations import estimate
from util import model_utils
from dsets.counterfact import CounterFactDataset
from util import nethook
from operator import itemgetter


In [3]:
MODEL_NAME = "EleutherAI/gpt-j-6B"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B
mt = model_utils.ModelAndTokenizer(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float16)

model = mt.model
tokenizer = mt.tokenizer
tokenizer.pad_token = tokenizer.eos_token

In [1]:
###########################################################################
relation_dct = {
    'P17'   : {'relation': '{} is located in the country of', 'correct_predict': None, 'cached_JB': None},
    'P641'  : {'relation': '{} plays the sport of', 'correct_predict': None, 'cached_JB': None},
    'P103'  : {'relation': 'The mother tongue of {} is', 'correct_predict': None, 'cached_JB': None},
    'P176'  : {'relation': '{} is produced by', 'correct_predict': None, 'cached_JB': None},
    'P140'  : {'relation': 'The official religion of {} is', 'correct_predict': None, 'cached_JB': None},
    # 'P1303' : {'relation': '{} plays the instrument', 'correct_predict': None, 'cached_JB': None},
    'P190'  : {'relation': 'What is the twin city of {}? It is', 'correct_predict': None, 'cached_JB': None},
    'P740'  : {'relation': '{} was founded in', 'correct_predict': None, 'cached_JB': None},
    'P178'  : {'relation': '{} was developed by', 'correct_predict': None, 'cached_JB': None},
    'P495'  : {'relation': '{}, that originated in the country of', 'correct_predict': None, 'cached_JB': None},
    'P127'  : {'relation': '{} is owned by', 'correct_predict': None, 'cached_JB': None},
    'P413'  : {'relation': '{} plays in the position of', 'correct_predict': None, 'cached_JB': None},
    'P39'   : {'relation': '{}, who holds the position of', 'correct_predict': None, 'cached_JB': None},
    'P159'  : {'relation': 'The headquarter of {} is located in', 'correct_predict': None, 'cached_JB': None},
    'P20'   : {'relation': '{} died in the city of', 'correct_predict': None, 'cached_JB': None},
    'P136'  : {'relation': 'What does {} play? They play', 'correct_predict': None, 'cached_JB': None},
    'P106'  : {'relation': 'The profession of {} is', 'correct_predict': None, 'cached_JB': None},
    'P30'   : {'relation': '{} is located in the continent of', 'correct_predict': None, 'cached_JB': None},
    'P937'  : {'relation': '{} worked in the city of', 'correct_predict': None, 'cached_JB': None},
    'P449'  : {'relation': '{} was released on', 'correct_predict': None, 'cached_JB': None},
    'P27'   : {'relation': '{} is a citizen of', 'correct_predict': None, 'cached_JB': None},
    'P101'  : {'relation': '{} works in the field of', 'correct_predict': None, 'cached_JB': None},
    'P19'   : {'relation': '{} was born in', 'correct_predict': None, 'cached_JB': None},
    'P37'   : {'relation': 'In {}, an official language is', 'correct_predict': None, 'cached_JB': None},
    'P138'  : {'relation': '{}, named after', 'correct_predict': None, 'cached_JB': None},
    'P131'  : {'relation': '{} is located in', 'correct_predict': None, 'cached_JB': None},
    'P407'  : {'relation': '{} was written in', 'correct_predict': None, 'cached_JB': None},
    'P108'  : {'relation': '{}, who is employed by', 'correct_predict': None, 'cached_JB': None},
    'P36'   : {'relation': 'The capital of {} is', 'correct_predict': None, 'cached_JB': None},
}
###########################################################################

In [2]:
root_path = "gpt-j"

for relation in relation_dct:
    path = f"{root_path}/{relation}"
    if "performance" not in os.listdir(path):
        print("skipped ", relation)

skipped  P131
skipped  P407
skipped  P108


In [23]:
relation_id = "P101"

print(relation_dct[relation_id]['relation'])

with open(f"gpt-j/{relation_id}/performance") as f:
    performance = json.load(f)

performance.sort(key = itemgetter('p@3'), reverse=True)

{} works in the field of


In [26]:
consider_top = 5
subject__top_performers = []
object__top_performers = []
top_performers = []

for candidate in performance:
    subject = candidate['subject']
    sub_idx = candidate['misc']['h_info']['h_index']
    object = candidate['object']
    if(subject in subject__top_performers or object in object__top_performers):
        continue
    if(len(tokenizer(subject).input_ids) > 3):
        continue
    
    subject__top_performers.append(subject)
    object__top_performers.append(object)
    top_performers.append((
        subject, sub_idx, object  #, candidate['p@3']
    ))

    if(len(top_performers) == consider_top):
        break 

top_performers

[('Hypatia', 2, 'mathematics'),
 ('Sima Qian', 2, 'history'),
 ('Carl Menger', 2, 'economics'),
 ('Euclid', 0, 'geometry'),
 ('Michael Jackson', 1, 'musician')]

In [28]:
r = relation_dct[relation_id]['relation']

In [29]:
jbs = []
for s, s_idx, o in tqdm(top_performers):
    others = set(top_performers) - {(s, s_idx, o)}
    others = random.sample(list(others), k = 3) 
    prompt = ""
    prompt += "\n".join(r.format(s_other) + f" {o_other}." for s_other, idx_other, o_other in others) + "\n"
    prompt += r
    print("subject: ", s)
    print(prompt)

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    jb = estimate.estimate_relation_operator(
        model, tokenizer,
        s, prompt,
        subject_token_index= s_idx,
        layer = 15,
        device = model.device,
    )
    print(jb.weight.norm(), jb.bias.norm())
    print()
    jbs.append(jb)

relation = estimate.RelationOperator(
    weight=torch.stack([jb.weight for jb in jbs]).mean(dim=0),
    bias=torch.stack([jb.bias for jb in jbs]).mean(dim=0),
    model=model,
    tokenizer=tokenizer,
    layer= 15 ,
    relation= relation_dct[relation_id]['relation'],
)

  0%|          | 0/5 [00:00<?, ?it/s]

subject:  Hypatia
Euclid works in the field of geometry.
Michael Jackson works in the field of musician.
Carl Menger works in the field of economics.
{} works in the field of
tensor(29.5469, device='cuda:0', dtype=torch.float16) tensor(274.2500, device='cuda:0', dtype=torch.float16)

subject:  Sima Qian
Michael Jackson works in the field of musician.
Carl Menger works in the field of economics.
Hypatia works in the field of mathematics.
{} works in the field of
tensor(4.3320, device='cuda:0', dtype=torch.float16) tensor(279.5000, device='cuda:0', dtype=torch.float16)

subject:  Carl Menger
Hypatia works in the field of mathematics.
Euclid works in the field of geometry.
Sima Qian works in the field of history.
{} works in the field of
tensor(23.1094, device='cuda:0', dtype=torch.float16) tensor(253.2500, device='cuda:0', dtype=torch.float16)

subject:  Euclid
Sima Qian works in the field of history.
Carl Menger works in the field of economics.
Michael Jackson works in the field of musi

In [30]:
test_subjects = [
    "Hugh Jackman",
    "Michael Phelps",
    "Agatha Christie",
    "Barack Obama",
    "Sherlock Holmes",
    "Alan Turing",
    "Bill Gates",
    "Michelangelo"
]

for sub in test_subjects:
    print(f"{sub} >> ", relation(sub, device= model.device))

Hugh Jackman >>  [' mathematics', ' philosophy', ' history', ' the', ' science']
Michael Phelps >>  [' mathematics', ' philosophy', ' history', ' the', ' science']
Agatha Christie >>  [' mathematics', ' philosophy', ' history', ' the', '\n']
Barack Obama >>  [' mathematics', ' philosophy', ' history', ' the', ' science']
Sherlock Holmes >>  [' mathematics', ' philosophy', ' history', ' the', ' ancient']
Alan Turing >>  [' mathematics', ' philosophy', ' history', ' the', ' ancient']
Bill Gates >>  [' mathematics', ' philosophy', ' history', ' the', ' science']
Michelangelo >>  [' mathematics', ' philosophy', ' history', ' the', ' science']


In [13]:
test_cases = [
    # ("Statue of Liberty", -1, "United States"),
    ("The Great Wall", -1, "China"),
    ("Niagara Falls", -2, "Canada"),
    ("Valdemarsvik", -1, "Sweden"),
    ("Kyoto University", -2, "Japan"),
    ("Hattfjelldal", -1, "Norway"),
    ("Ginza", -1, "Japan"),
    ("Sydney Hospital", -2, "Australia"),
    ("Mahalangur Himal", -1, "Nepal"),
    ("Higashikagawa", -1, "Japan"),
    ("Trento", -1, "Italy"),
    ("Taj Mahal", -1, "India"),
    ("Hagia Sophia", -1, "Turkey"),
    ("Colosseum", -1, "Italy"),
    ("Mount Everest", -1, "Nepal"),
    ("Valencia", -1, "Spain"),
    ("Lake Baikal", -1, "Russia"),
    ("Merlion Park", -1, "Singapore"),
    ("Cologne Cathedral", -1, "Germany"),
    ("Buda Castle", -1, "Hungary")
]

def check_with_test_cases(relation_operator):
    for subject, subject_token_index, target in test_cases:
        objects = relation_operator(
            subject,
            subject_token_index=subject_token_index,
            device=model.device,
            return_top_k=5,
        )
        print(f"{subject}, target: {target}   ==>   predicted: {objects}")

In [14]:
check_with_test_cases(relation)

The Great Wall, target: China   ==>   predicted: [' Italy', ' France', ' Russia', ' Germany', ' United']
Niagara Falls, target: Canada   ==>   predicted: [' Italy', ' France', ' Germany', ' United', ' Canada']
Valdemarsvik, target: Sweden   ==>   predicted: [' Italy', ' France', ' Finland', ' Russia', ' Germany']
Kyoto University, target: Japan   ==>   predicted: [' Japan', ' Italy', ' France', ' United', ' Germany']
Hattfjelldal, target: Norway   ==>   predicted: [' Italy', ' France', ' Germany', ' Finland', ' Russia']
Ginza, target: Japan   ==>   predicted: [' Italy', ' Japan', ' France', ' Russia', ' Germany']
Sydney Hospital, target: Australia   ==>   predicted: [' Italy', ' France', ' Australia', ' United', ' Germany']
Mahalangur Himal, target: Nepal   ==>   predicted: [' Italy', ' France', ' India', ' Germany', ' United']
Higashikagawa, target: Japan   ==>   predicted: [' Italy', ' Japan', ' France', ' Germany', ' United']
Trento, target: Italy   ==>   predicted: [' Italy', ' Fra

In [12]:
relation("Clonlara GAA", device = model.device)

[' Italy', ' France', ' United', ' Germany', ' Ireland']