In [1]:
%load_ext autoreload
%autoreload 2

In [8]:
import torch
import numpy as np
import json
from tqdm.auto import tqdm
import random
import transformers

import os
import sys
sys.path.append('..')

from relations import estimate
from util import model_utils
from baukit import nethook
from operator import itemgetter
from evaluate import evaluate
from relations.corner import CornerEstimator
from dsets.counterfact import CounterFactDataset

In [11]:
# counterfact = CounterFactDataset("../data/")

In [3]:
MODEL_NAME = "EleutherAI/gpt-j-6B"  # gpt2-{medium,large,xl} or EleutherAI/gpt-j-6B
mt = model_utils.ModelAndTokenizer(MODEL_NAME, low_cpu_mem_usage=True, torch_dtype=torch.float32)

model = mt.model
tokenizer = mt.tokenizer
tokenizer.pad_token = tokenizer.eos_token

print(f"{MODEL_NAME} ==> device: {model.device}, memory: {model.get_memory_footprint()}")

EleutherAI/gpt-j-6B ==> device: cuda:0, memory: 24320971760


In [34]:
#################################################
relation_id = "P101"
precision_at = 3
#################################################

with open("../data/counterfact.json") as f:
    counterfact = json.load(f)

objects = [c['requested_rewrite'] for c in counterfact if c["requested_rewrite"]['relation_id'] == relation_id]
objects = [" "+ o['target_true']['str'] for o in objects]
objects = list(set(objects))
print("unique objects: ", len(objects), objects[0:5])

unique objects:  83 [' economist', ' evolution', ' anatomy', ' Sanskrit', ' art']


In [31]:
corner_estimator = CornerEstimator(model=model, tokenizer=tokenizer)

In [35]:
simple_corner = corner_estimator.estimate_simple_corner(objects, scale_up=70)
print(simple_corner.norm().item(), corner_estimator.get_vocab_representation(simple_corner))

relation_operator = estimate.RelationOperator(
    model = model,
    tokenizer = tokenizer,
    relation = '{} works in the field of',
    layer = 15,
    weight = torch.eye(model.config.n_embd).to(model.dtype).to(model.device),
    bias = simple_corner
)

32.76474380493164 [' science', ' physics', ' history', ' biology', ' music']


In [37]:
precision, ret_dict = evaluate(
    relation_id= relation_id,
    relation_operator= relation_operator,
    precision_at=3
)

P101 >> number of requests in counterfact = 545
Checking correct prediction with normal calculation ...


100%|██████████| 545/545 [02:17<00:00,  3.95it/s]


Number of correctly predicted requests = 84
validating on 84 subject --> object associations


100%|██████████| 84/84 [00:04<00:00, 17.77it/s]


In [38]:
precision

0.4166666666666667

In [39]:
lin_inv_corner = corner_estimator.estimate_lin_inv_corner(objects, target_logit_value=50)
print(lin_inv_corner.norm().item(), corner_estimator.get_vocab_representation(lin_inv_corner))

relation_lin_inv = estimate.RelationOperator(
    model = model,
    tokenizer = tokenizer,
    relation = '{} works in the field of',
    layer = 15,
    weight = torch.eye(model.config.n_embd).to(model.dtype).to(model.device),
    bias = lin_inv_corner
)

calculating inverse of unbedding weights . . .
41.379276275634766 [' physics', ' biology', ' economics', ' chemistry', ' geography']


In [40]:
precision, ret_dict_2 = evaluate(
    relation_id="P17",
    relation_operator= relation_lin_inv,
    precision_at=3,
    validation_set= ret_dict["validation_set"]
)

precision

validating on 84 subject --> object associations


100%|██████████| 84/84 [00:04<00:00, 18.04it/s]


0.5952380952380952