In [1]:
# Re-import edited modules before each cell executes, so changes to project
# code (relations, util, ...) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import numpy as np
import json
from tqdm.auto import tqdm
import random
import transformers

import os
import sys
sys.path.append('..')

from relations import estimate
from util import model_utils
from baukit import nethook
from operator import itemgetter

In [3]:
cut_off = 50 # minimum number of correct predictions

###########################################################################
# Prompt template for each CounterFact relation id; `{}` is filled with the
# subject. Every entry gets the same bookkeeping fields, initialised to None.
_relation_templates = {
    'P17'   : '{} is located in the country of',
    'P641'  : '{} plays the sport of',
    'P103'  : 'The mother tongue of {} is',
    'P176'  : '{} is produced by',
    'P140'  : 'The official religion of {} is',
    'P1303' : '{} plays the instrument',
    'P190'  : 'What is the twin city of {}? It is',
    'P740'  : '{} was founded in',
    'P178'  : '{} was developed by',
    'P495'  : '{}, that originated in the country of',
    'P127'  : '{} is owned by',
    'P413'  : '{} plays in the position of',
    'P39'   : '{}, who holds the position of',
    'P159'  : 'The headquarter of {} is located in',
    'P20'   : '{} died in the city of',
    'P136'  : 'What does {} play? They play',
    'P106'  : 'The profession of {} is',
    'P30'   : '{} is located in the continent of',
    'P937'  : '{} worked in the city of',
    'P449'  : '{} was released on',
    'P27'   : '{} is a citizen of',
    'P101'  : '{} works in the field of',
    'P19'   : '{} was born in',
    'P37'   : 'In {}, an official language is',
    'P138'  : '{}, named after',
    'P131'  : '{} is located in',
    'P407'  : '{} was written in',
    'P108'  : '{}, who is employed by',
    'P36'   : 'The capital of {} is',
}
relation_dct = {
    rel_id: {'relation': template, 'correct_predict': None, 'cached_JB': None}
    for rel_id, template in _relation_templates.items()
}
###########################################################################


root_path = "gpt-j"

# Drop every relation whose cached list of correctly predicted requests
# (<root_path>/<relation>/correct_prediction_<relation>.json) has fewer than
# `cut_off` entries. Deletion happens after the scan so the dict is not
# mutated while it is being iterated.
pop_track = []
for relation in relation_dct:
    prediction_file = os.path.join(root_path, relation, f"correct_prediction_{relation}.json")
    # JSON is UTF-8 by specification; don't depend on the locale's default encoding.
    with open(prediction_file, encoding="utf-8") as f:
        correct_predictions = json.load(f)
    if len(correct_predictions) < cut_off:
        print("skipped ", relation)
        pop_track.append(relation)

for r in pop_track:
    relation_dct.pop(r)

skipped  P1303
skipped  P190
skipped  P740
skipped  P413
skipped  P39
skipped  P136
skipped  P449
skipped  P138
skipped  P131
skipped  P407
skipped  P108


In [4]:
# Alternatives that can be plugged in here: gpt2-{medium,large,xl}.
MODEL_NAME = "EleutherAI/gpt-j-6B"
mt = model_utils.ModelAndTokenizer(
    MODEL_NAME,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float32,
)
model, tokenizer = mt.model, mt.tokenizer

# Pad with the EOS token (presumably because the tokenizer defines no pad
# token of its own -- confirm if switching models).
tokenizer.pad_token = tokenizer.eos_token

footprint = model.get_memory_footprint()
print(f"{MODEL_NAME} ==> device: {model.device}, memory: {footprint}")

EleutherAI/gpt-j-6B ==> device: cuda:0, memory: 24320971760


In [5]:
from dsets.counterfact import CounterFactDataset

# CounterFact rewrite requests; cells below read each record's
# `requested_rewrite` (`relation_id`, `target_true`) fields.
COUNTERFACT_DIR = "../data/"
counterfact = CounterFactDataset(COUNTERFACT_DIR)

Loaded dataset with 21919 elements


In [6]:
# Template for the per-relation cache archive (one .npz per relation id, filled
# in via .format(relation_id)). The default is the original machine-local
# location; set the JACOBIAN_CACHE_PATH environment variable (containing one
# `{}` placeholder) to run this notebook elsewhere.
jacobian_cache_path = os.environ.get(
    "JACOBIAN_CACHE_PATH",
    "/mnt/39a89eb4-27b7-4fce-a6ab-a9d203443a7c/relation_cached/gpt-j/jacobians_averaged_collection/before__ln_f/{}.npz",
)

# allow_pickle=True is needed because 'weights_and_biases' is an object array of
# {'weight', 'bias'} dicts. Unpickling can execute arbitrary code, so only point
# this at caches you produced yourself.
cached_jbs = np.load(
    jacobian_cache_path.format('P17'),
    allow_pickle=True,
)

In [7]:
# Mean of the cached weights and biases for P17.
# Every key access on an NpzFile re-reads (and, for object arrays, re-unpickles)
# the member from the archive, so fetch 'weights_and_biases' once and reuse it.
weights_and_biases = cached_jbs['weights_and_biases']
w = torch.stack([wb['weight'] for wb in weights_and_biases]).mean(dim=0)
b = torch.stack([wb['bias'] for wb in weights_and_biases]).mean(dim=0)

w.shape, b.shape

(torch.Size([4096, 4096]), torch.Size([1, 4096]))

In [8]:
relation_id = "P17"
# Distinct true-target strings for this relation, each prefixed with a space.
# dict.fromkeys keeps first-seen dataset order. list(set(...)) does not: its
# order depends on string hashing and changes between interpreter runs.
objects = list(dict.fromkeys(
    " " + c['requested_rewrite']['target_true']['str']
    for c in counterfact
    if c['requested_rewrite']['relation_id'] == relation_id
))
print("unique objects: ", len(objects), objects[:5])

unique objects:  95 [' Ireland', ' Pakistan', ' Australia', ' Ethiopia', ' Turkey']


In [26]:
from relations import corner, evaluate

# Sanity check on P17: estimate the average corner for the relation's objects
# (gradient-descent estimator) and look at which vocabulary items it decodes to.
corner_estimator = corner.CornerEstimator(model=model, tokenizer=tokenizer)
current_corner = corner_estimator.estimate_average_corner_with_gradient_descent(objects)

# NOTE(review): the sweep below passes get_logits=True. 50 behaves identically if
# the argument is only used as a flag; confirm whether a top-k of 50 was intended.
corner_estimator.get_vocab_representation(current_corner, get_logits=50)

[(' Lithuania', 57.932),
 (' Estonia', 57.909),
 (' Nicaragua', 57.906),
 (' Tanzania', 57.899),
 (' Bahamas', 57.82)]

In [27]:
from relations import corner, evaluate

jacobians_calculated_after_layer = 15
precision_at = 3
# Rewritten after every relation, so an interrupted sweep keeps finished results.
results_path = "corner_target_logit_sweep/new_loss.json"

corner_estimator = corner.CornerEstimator(model = model, tokenizer = tokenizer)
performance_track = {}


def _unique_objects(relation_id):
    """Return the distinct true-target strings for `relation_id` in CounterFact.

    Each string is prefixed with a space. Order is first occurrence in the
    dataset, so it is reproducible across runs (list(set(...)) order depends
    on string hashing).
    """
    return list(dict.fromkeys(
        " " + c["requested_rewrite"]["target_true"]["str"]
        for c in counterfact
        if c["requested_rewrite"]["relation_id"] == relation_id
    ))


def _mean_cached_jacobian(relation_id):
    """Return the mean of the cached 'weight' tensors for `relation_id`.

    The .npz archive is closed as soon as the object array has been read.
    NOTE: allow_pickle=True unpickles the cache, which can execute arbitrary
    code -- only load caches you produced yourself.
    """
    with np.load(jacobian_cache_path.format(relation_id), allow_pickle=True) as cached:
        weights_and_biases = cached["weights_and_biases"]
    return torch.stack([wb["weight"] for wb in weights_and_biases]).mean(dim=0)


def _evaluate_operator(relation_id, weight, bias, validation_set):
    """Build a RelationOperator from (weight, bias) and score it.

    Returns evaluate.evaluate's (precision, ret_dict) unchanged; ret_dict still
    contains the 'validation_set' and 'predictions' entries.
    """
    relation_operator = estimate.RelationOperator(
        model = model,
        tokenizer = tokenizer,
        relation = relation_dct[relation_id]['relation'],
        layer = jacobians_calculated_after_layer,
        weight = weight,
        bias = bias,
    )
    return evaluate.evaluate(
        relation_id = relation_id,
        relation_operator = relation_operator,
        precision_at = precision_at,
        validation_set = validation_set,
    )


# Loop-invariant identity weight used as the baseline operator.
identity_weight = torch.eye(model.config.n_embd, dtype=model.dtype, device=model.device)
os.makedirs(os.path.dirname(results_path), exist_ok=True)

for relation_id in tqdm(relation_dct):
    print(f"relation_id >> {relation_id}")
    print("------------------------------------------------------------------------------------------------------")
    objects = _unique_objects(relation_id)
    print("unique objects: ", len(objects), objects[:5])

    current_corner = corner_estimator.estimate_average_corner_with_gradient_descent(objects)
    corner_norm = current_corner.norm().item()
    vocab_repr = corner_estimator.get_vocab_representation(current_corner, get_logits=True)
    print(f"{corner_norm} >> {vocab_repr}")

    performance_track[relation_id] = {
        'jacobian': -1, 'identity': -1,
        'vocab_repr': vocab_repr,
        'corner_norm': corner_norm,
    }

    # Baseline (identity weight). Its validation set is reused for the Jacobian
    # run so both operators are scored on the same subject --> object pairs.
    precision, ret_dict = _evaluate_operator(relation_id, identity_weight, current_corner, validation_set=None)
    print("w = identity >> ", precision)
    validation_set = ret_dict.pop('validation_set')
    ret_dict.pop('predictions')
    performance_track[relation_id]['identity'] = (precision, ret_dict)

    mean_jacobian = _mean_cached_jacobian(relation_id).to(dtype=model.dtype, device=model.device)
    precision, ret_dict = _evaluate_operator(relation_id, mean_jacobian, current_corner, validation_set)
    print("w = jacobian >> ", precision)
    ret_dict.pop('validation_set')
    ret_dict.pop('predictions')
    performance_track[relation_id]['jacobian'] = (precision, ret_dict)
    print()

    print("saving results")
    with open(results_path, "w", encoding="utf-8") as f:
        json.dump(performance_track, f)
    print("############################################################################################################")

  0%|          | 0/18 [00:00<?, ?it/s]

relation_id >> P17
------------------------------------------------------------------------------------------------------
unique objects:  95 [' Ireland', ' Pakistan', ' Australia', ' Ethiopia', ' Turkey']
22.859743118286133 >> [(' China', 58.46), (' Albania', 57.86), (' Tunisia', 57.827), (' Estonia', 57.787), (' Bahrain', 57.77)]
P17 >> number of requests in counterfact = 875
Checking correct prediction with normal calculation ...


  0%|          | 0/875 [00:00<?, ?it/s]

Number of correctly predicted requests = 407
validating on 407 subject --> object associations


  0%|          | 0/407 [00:00<?, ?it/s]

w = identity >>  0.7518427518427518
validating on 407 subject --> object associations


  0%|          | 0/407 [00:00<?, ?it/s]

w = jacobian >>  0.9066339066339066

saving results
############################################################################################################
relation_id >> P641
------------------------------------------------------------------------------------------------------
unique objects:  5 [' baseball', ' hockey', ' football', ' basketball', ' soccer']
21.381254196166992 >> [(' soccer', 116.172), (' baseball', 115.898), (' football', 115.609), (' hockey', 115.401), (' basketball', 115.207)]
P641 >> number of requests in counterfact = 318
Checking correct prediction with normal calculation ...


  0%|          | 0/318 [00:00<?, ?it/s]

Number of correctly predicted requests = 196
validating on 196 subject --> object associations


  0%|          | 0/196 [00:00<?, ?it/s]

w = identity >>  0.9846938775510204
validating on 196 subject --> object associations


  0%|          | 0/196 [00:00<?, ?it/s]

w = jacobian >>  0.7346938775510204

saving results
############################################################################################################
relation_id >> P103
------------------------------------------------------------------------------------------------------
unique objects:  30 [' Danish', ' Polish', ' Welsh', ' Chinese', ' Latin']
29.54885482788086 >> [(' Croatian', 81.27), (' Turkish', 81.268), (' Swedish', 81.267), (' Italian', 81.206), (' Finnish', 81.194)]
P103 >> number of requests in counterfact = 919
Checking correct prediction with normal calculation ...


  0%|          | 0/919 [00:00<?, ?it/s]

Number of correctly predicted requests = 745
validating on 745 subject --> object associations


  0%|          | 0/745 [00:00<?, ?it/s]

w = identity >>  0.9718120805369127
validating on 745 subject --> object associations


  0%|          | 0/745 [00:00<?, ?it/s]

w = jacobian >>  0.9543624161073826

saving results
############################################################################################################
relation_id >> P176
------------------------------------------------------------------------------------------------------
unique objects:  37 [' Google', ' Iran', ' Airbus', ' Sony', ' Chrysler']
25.52078628540039 >> [(' Ford', 68.284), (' Mercedes', 63.664), (' Volkswagen', 59.98), (' Chrysler', 59.2), (' Renault', 59.177)]
P176 >> number of requests in counterfact = 911
Checking correct prediction with normal calculation ...


  0%|          | 0/911 [00:00<?, ?it/s]

Number of correctly predicted requests = 544
validating on 544 subject --> object associations


  0%|          | 0/544 [00:00<?, ?it/s]

w = identity >>  0.7169117647058824
validating on 544 subject --> object associations


  0%|          | 0/544 [00:00<?, ?it/s]

w = jacobian >>  0.4099264705882353

saving results
############################################################################################################
relation_id >> P140
------------------------------------------------------------------------------------------------------
unique objects:  9 [' Muslim', ' Islam', ' Catholicism', ' Christianity', ' Judaism']
27.48402214050293 >> [(' Christianity', 85.081), (' Christian', 84.815), (' Buddhism', 84.729), (' Islam', 84.633), (' Judaism', 84.398)]
P140 >> number of requests in counterfact = 430
Checking correct prediction with normal calculation ...


  0%|          | 0/430 [00:00<?, ?it/s]

Number of correctly predicted requests = 183
validating on 183 subject --> object associations


  0%|          | 0/183 [00:00<?, ?it/s]

w = identity >>  0.9180327868852459
validating on 183 subject --> object associations


  0%|          | 0/183 [00:00<?, ?it/s]

w = jacobian >>  0.9726775956284153

saving results
############################################################################################################
relation_id >> P178
------------------------------------------------------------------------------------------------------
unique objects:  19 [' Google', ' Airbus', ' Sony', ' BBC', ' Apple']
28.995100021362305 >> [(' Microsoft', 71.773), (' Samsung', 71.604), (' IBM', 71.57), (' Boeing', 71.558), (' Sega', 71.524)]
P178 >> number of requests in counterfact = 579
Checking correct prediction with normal calculation ...


  0%|          | 0/579 [00:00<?, ?it/s]

Number of correctly predicted requests = 401
validating on 401 subject --> object associations


  0%|          | 0/401 [00:00<?, ?it/s]

w = identity >>  0.8503740648379052
validating on 401 subject --> object associations


  0%|          | 0/401 [00:00<?, ?it/s]

w = jacobian >>  0.7231920199501247

saving results
############################################################################################################
relation_id >> P495
------------------------------------------------------------------------------------------------------
unique objects:  67 [' Ireland', ' Lebanon', ' Italy', ' Finland', ' Hungary']
24.96476936340332 >> [(' China', 66.813), (' Nigeria', 65.59), (' Belarus', 64.718), (' Lithuania', 64.696), (' Norway', 64.672)]
P495 >> number of requests in counterfact = 904
Checking correct prediction with normal calculation ...


  0%|          | 0/904 [00:00<?, ?it/s]

Number of correctly predicted requests = 160
validating on 160 subject --> object associations


  0%|          | 0/160 [00:00<?, ?it/s]

w = identity >>  0.81875
validating on 160 subject --> object associations


  0%|          | 0/160 [00:00<?, ?it/s]

w = jacobian >>  0.88125

saving results
############################################################################################################
relation_id >> P127
------------------------------------------------------------------------------------------------------
unique objects:  109 [' Atlanta', ' Berlin', ' Winnipeg', ' Turkey', ' Jakarta']
19.64012908935547 >> [(' Tampa', 35.487), (' Bristol', 35.473), (' Brazil', 34.591), (' Lebanon', 34.484), (' Columbia', 34.295)]
P127 >> number of requests in counterfact = 433
Checking correct prediction with normal calculation ...


  0%|          | 0/433 [00:00<?, ?it/s]

Number of correctly predicted requests = 105
validating on 105 subject --> object associations


  0%|          | 0/105 [00:00<?, ?it/s]

w = identity >>  0.0380952380952381
validating on 105 subject --> object associations


  0%|          | 0/105 [00:00<?, ?it/s]

w = jacobian >>  0.08571428571428572

saving results
############################################################################################################
relation_id >> P159
------------------------------------------------------------------------------------------------------
unique objects:  178 [' Ireland', ' Dresden', ' Central', ' Georgetown', ' Atlanta']
19.516788482666016 >> [(' Bel', 33.981), (' Cal', 33.492), (' Mal', 33.105), (' Buch', 32.949), (' St', 32.88)]
P159 >> number of requests in counterfact = 756
Checking correct prediction with normal calculation ...


  0%|          | 0/756 [00:00<?, ?it/s]

Number of correctly predicted requests = 159
validating on 159 subject --> object associations


  0%|          | 0/159 [00:00<?, ?it/s]

w = identity >>  0.06918238993710692
validating on 159 subject --> object associations


  0%|          | 0/159 [00:00<?, ?it/s]

w = jacobian >>  0.660377358490566

saving results
############################################################################################################
relation_id >> P20
------------------------------------------------------------------------------------------------------
unique objects:  169 [' Ireland', ' Dresden', ' Greenland', ' Georgetown', ' Atlanta']
19.468341827392578 >> [(' Wood', 32.388), (' Richmond', 31.794), (' Normandy', 31.413), (' Kansas', 31.314), (' Sierra', 31.294)]
P20 >> number of requests in counterfact = 816
Checking correct prediction with normal calculation ...


  0%|          | 0/816 [00:00<?, ?it/s]

Number of correctly predicted requests = 145
validating on 145 subject --> object associations


  0%|          | 0/145 [00:00<?, ?it/s]

w = identity >>  0.06206896551724138
validating on 145 subject --> object associations


  0%|          | 0/145 [00:00<?, ?it/s]

w = jacobian >>  0.27586206896551724

saving results
############################################################################################################
relation_id >> P106
------------------------------------------------------------------------------------------------------
unique objects:  31 [' missionary', ' novelist', ' surgeon', ' mathematician', ' journalist']
27.038917541503906 >> [(' scientist', 70.657), (' writer', 68.307), (' doctor', 67.04), (' professor', 65.019), (' director', 64.054)]
P106 >> number of requests in counterfact = 821
Checking correct prediction with normal calculation ...


  0%|          | 0/821 [00:00<?, ?it/s]

Number of correctly predicted requests = 232
validating on 232 subject --> object associations


  0%|          | 0/232 [00:00<?, ?it/s]

w = identity >>  0.12931034482758622
validating on 232 subject --> object associations


  0%|          | 0/232 [00:00<?, ?it/s]

w = jacobian >>  0.0

saving results
############################################################################################################
relation_id >> P30
------------------------------------------------------------------------------------------------------
unique objects:  5 [' Americas', ' Antarctica', ' Africa', ' Asia', ' Europe']
24.537059783935547 >> [(' Asia', 102.715), (' Africa', 102.576), (' Antarctica', 102.247), (' Americas', 101.717), (' Europe', 101.357)]
P30 >> number of requests in counterfact = 959
Checking correct prediction with normal calculation ...


  0%|          | 0/959 [00:00<?, ?it/s]

Number of correctly predicted requests = 457
validating on 457 subject --> object associations


  0%|          | 0/457 [00:00<?, ?it/s]

w = identity >>  0.9649890590809628
validating on 457 subject --> object associations


  0%|          | 0/457 [00:00<?, ?it/s]

w = jacobian >>  0.9518599562363238

saving results
############################################################################################################
relation_id >> P937
------------------------------------------------------------------------------------------------------
unique objects:  90 [' Dresden', ' Atlanta', ' Berlin', ' Brisbane', ' Boston']
21.75009536743164 >> [(' Albany', 50.201), (' San', 50.129), (' Rochester', 49.191), (' California', 49.117), (' Zurich', 48.539)]
P937 >> number of requests in counterfact = 846
Checking correct prediction with normal calculation ...


  0%|          | 0/846 [00:00<?, ?it/s]

Number of correctly predicted requests = 333
validating on 333 subject --> object associations


  0%|          | 0/333 [00:00<?, ?it/s]

w = identity >>  0.3153153153153153
validating on 333 subject --> object associations


  0%|          | 0/333 [00:00<?, ?it/s]

w = jacobian >>  0.3663663663663664

saving results
############################################################################################################
relation_id >> P27
------------------------------------------------------------------------------------------------------
unique objects:  97 [' Ireland', ' Pakistan', ' Australia', ' Ecuador', ' Ethiopia']
22.210935592651367 >> [(' Kuwait', 54.844), (' Korea', 54.557), (' Africa', 54.261), (' Portugal', 52.875), (' Slovenia', 52.436)]
P27 >> number of requests in counterfact = 958
Checking correct prediction with normal calculation ...


  0%|          | 0/958 [00:00<?, ?it/s]

Number of correctly predicted requests = 308
validating on 308 subject --> object associations


  0%|          | 0/308 [00:00<?, ?it/s]

w = identity >>  0.7792207792207793
validating on 308 subject --> object associations


  0%|          | 0/308 [00:00<?, ?it/s]

w = jacobian >>  0.9805194805194806

saving results
############################################################################################################
relation_id >> P101
------------------------------------------------------------------------------------------------------
unique objects:  83 [' science', ' theology', ' sociology', ' drawing', ' algebra']
18.895122528076172 >> [(' military', 32.762), (' nuclear', 32.546), (' Hindu', 32.369), (' the', 32.307), (' water', 32.172)]
P101 >> number of requests in counterfact = 545
Checking correct prediction with normal calculation ...


  0%|          | 0/545 [00:00<?, ?it/s]

Number of correctly predicted requests = 84
validating on 84 subject --> object associations


  0%|          | 0/84 [00:00<?, ?it/s]

w = identity >>  0.11904761904761904
validating on 84 subject --> object associations


  0%|          | 0/84 [00:00<?, ?it/s]

w = jacobian >>  0.047619047619047616

saving results
############################################################################################################
relation_id >> P19
------------------------------------------------------------------------------------------------------
unique objects:  229 [' Ireland', ' Dresden', ' Georgetown', ' Atlanta', ' Berlin']
19.185012817382812 >> [(' Vil', 25.314), (' Cord', 24.758), (' Sug', 24.613), (' Ram', 24.533), (' Cas', 24.528)]
P19 >> number of requests in counterfact = 779
Checking correct prediction with normal calculation ...


  0%|          | 0/779 [00:00<?, ?it/s]

Number of correctly predicted requests = 129
validating on 129 subject --> object associations


  0%|          | 0/129 [00:00<?, ?it/s]

w = identity >>  0.031007751937984496
validating on 129 subject --> object associations


  0%|          | 0/129 [00:00<?, ?it/s]

w = jacobian >>  0.4573643410852713

saving results
############################################################################################################
relation_id >> P37
------------------------------------------------------------------------------------------------------
unique objects:  44 [' Somali', ' Filipino', ' Danish', ' Polish', ' Bulgarian']
26.53314971923828 >> [(' Scottish', 70.446), (' Mari', 67.963), (' Somali', 67.924), (' Filipino', 67.904), (' Czech', 67.814)]
P37 >> number of requests in counterfact = 891
Checking correct prediction with normal calculation ...


  0%|          | 0/891 [00:00<?, ?it/s]

Number of correctly predicted requests = 155
validating on 155 subject --> object associations


  0%|          | 0/155 [00:00<?, ?it/s]

w = identity >>  0.7870967741935484
validating on 155 subject --> object associations


  0%|          | 0/155 [00:00<?, ?it/s]

w = jacobian >>  0.8064516129032258

saving results
############################################################################################################
relation_id >> P36
------------------------------------------------------------------------------------------------------
unique objects:  62 [' Delhi', ' Kingston', ' Lyon', ' Dresden', ' Georgetown']
23.43290138244629 >> [(' Istanbul', 60.317), (' Moscow', 57.738), (' Melbourne', 55.901), (' Alexandria', 55.848), (' Hamburg', 55.506)]
P36 >> number of requests in counterfact = 139
Checking correct prediction with normal calculation ...


  0%|          | 0/139 [00:00<?, ?it/s]

Number of correctly predicted requests = 68
validating on 68 subject --> object associations


  0%|          | 0/68 [00:00<?, ?it/s]

w = identity >>  0.38235294117647056
validating on 68 subject --> object associations


  0%|          | 0/68 [00:00<?, ?it/s]

w = jacobian >>  0.9411764705882353

saving results
############################################################################################################


: 