In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
import copy
import numpy as np
import matplotlib.pyplot as plt
from src import models, data
from src.metrics import recall

In [30]:
device = "cuda:0"
mt = models.load_model("gptj", device=device)
print(f"dtype: {mt.model.dtype}, device: {mt.model.device}, memory: {mt.model.get_memory_footprint()}")

dtype: torch.float16, device: cuda:0, memory: 12219206136


In [4]:
dataset = data.load_dataset()
capital_cities = dataset[0]
print(capital_cities)
# capital_cities.__dict__.keys()
len(capital_cities.samples)

Relation(name='capital city', prompt_templates=['The capital city of {} is', 'The political capital of {} is', 'The seat of government for {} is', 'The government of {} is centered in'], samples=[RelationSample(subject='United States', object='Washington D.C.'), RelationSample(subject='Canada', object='Ottawa'), RelationSample(subject='Mexico', object='Mexico City'), RelationSample(subject='Brazil', object='Bras\\u00edlia'), RelationSample(subject='Argentina', object='Buenos Aires'), RelationSample(subject='Chile', object='Santiago'), RelationSample(subject='Peru', object='Lima'), RelationSample(subject='Colombia', object='Bogot\\u00e1'), RelationSample(subject='Venezuela', object='Caracas'), RelationSample(subject='Spain', object='Madrid'), RelationSample(subject='France', object='Paris'), RelationSample(subject='Germany', object='Berlin'), RelationSample(subject='Italy', object='Rome'), RelationSample(subject='Russia', object='Moscow'), RelationSample(subject='China', object='Beijing

24

In [5]:
indices = np.random.choice(range(len(capital_cities.samples)), 3, replace=False)
samples = [capital_cities.samples[i] for i in indices]

training_samples = copy.deepcopy(capital_cities.__dict__)
training_samples["samples"] = samples
training_samples = data.Relation(**training_samples)

training_samples.samples

[RelationSample(subject='Chile', object='Santiago'),
 RelationSample(subject='Peru', object='Lima'),
 RelationSample(subject='Japan', object='Tokyo')]

In [6]:
from src.operators import JacobianIclMeanEstimator

mean_estimator = JacobianIclMeanEstimator(
    mt=mt,
    h_layer=12,
    bias_scale_factor=0.2       # so that the bias doesn't knock out the prediction too much in the direction of training examples
) 

operator = mean_estimator(training_samples)

In [7]:
operator("United States", k = 10).predictions

[PredictedObject(token=' Washington', prob=0.38452938199043274),
 PredictedObject(token=' Bog', prob=0.18163883686065674),
 PredictedObject(token=' Lima', prob=0.07002800703048706),
 PredictedObject(token=' Mexico', prob=0.05539671331644058),
 PredictedObject(token=' Suc', prob=0.03867314010858536),
 PredictedObject(token=' Bras', prob=0.03521229326725006),
 PredictedObject(token='Washington', prob=0.0156253594905138),
 PredictedObject(token=' Buenos', prob=0.013055476360023022),
 PredictedObject(token='Bu', prob=0.01216904167085886),
 PredictedObject(token=' Argentina', prob=0.00804324634373188)]

In [8]:
test_samples = list(set(capital_cities.samples) - set(training_samples.samples))

predictions = []
target = []

for sample in test_samples:
    cur_predictions = operator(sample.subject, k = 5).predictions
    predictions.append([
        p.token for p in cur_predictions
    ])
    target.append(sample.object)

recall(predictions, target)

[0.8095238095238095,
 0.9047619047619048,
 0.9523809523809523,
 0.9523809523809523,
 0.9523809523809523]

In [54]:
from attributelens.attributelens import Attribute_Lens
import attributelens.utils as lens_utils

In [90]:
lens = Attribute_Lens(mt=mt, top_k=10)
att_info = lens.apply_attribute_lens(
    subject="People's Republic of China",
    relation_operator=operator
)

subject  :  People's Republic of China
relation :  The capital city of {} is
subject mapping:  4 9  >>  [' People', "'s", ' Republic', ' of', ' China']


In [92]:
lens_utils.visualize_attribute_lens(
    att_info, layer_skip=2, must_have_layers=[15, 25],
    expected_answers=[' Seattle', " Washington"]
)

must_have_layers:  [15, 25]
expected_answers:  [' Seattle', ' Washington']
