In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# project modules (e.g. under `src`) are reloaded before each cell runs and
# the kernel does not need restarting after a code change.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append('..')

import torch
import copy
import numpy as np
import matplotlib.pyplot as plt
from src import models, data
from src.metrics import recall

In [None]:
# Prefer the first CUDA device, but fall back to CPU when CUDA is unavailable
# instead of requesting a device that does not exist on this machine.
# NOTE(review): confirm that `models.load_model` picks a CPU-compatible dtype
# when device="cpu".
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Load the GPT-J model bundle; `mt.model` is the underlying model object.
mt = models.load_model("gptj", device=device)

# Report how the model was loaded (dtype, placement, memory footprint).
print(f"dtype: {mt.model.dtype}, device: {mt.model.device}, memory: {mt.model.get_memory_footprint()}")

In [None]:
# Load every relation in the dataset and list the available relation names.
dataset = data.load_dataset()
relation_names = [relation.name for relation in dataset]
print('\n'.join(relation_names))

# Keep the first relation whose name matches; an IndexError here means the
# dataset contains no relation with that name.
matching_relations = [
    relation for relation in dataset if relation.name == "country capital city"
]
datums = matching_relations[0]
print(datums)

# Number of samples available in the selected relation.
len(datums.samples)

In [None]:
# Draw the few-shot training examples from a seeded generator so the same
# three samples (and therefore the same operator and recall) are reproduced
# on every Restart & Run All. Change SEED to try a different training set.
SEED = 0
rng = np.random.default_rng(SEED)
indices = rng.choice(len(datums.samples), size=3, replace=False)
samples = [datums.samples[i] for i in indices]

# Copy the selected relation's attributes and swap in only the sampled
# examples, leaving `datums` itself untouched. The sampled objects are the
# originals from `datums.samples`, not copies.
relation_fields = copy.deepcopy(datums.__dict__)
relation_fields["samples"] = samples
training_samples = data.Relation(**relation_fields)

# Display the chosen training examples.
training_samples.samples

In [None]:
from src.operators import JacobianIclMeanEstimator

# Estimator hyper-parameters. The bias is scaled down (0.2) so that it does
# not knock the prediction too far in the direction of the training examples.
estimator_options = dict(
    mt=mt,
    h_layer=12,
    bias_scale_factor=0.2,
)
mean_estimator = JacobianIclMeanEstimator(**estimator_options)

# Calling the estimator on the training relation yields the relation operator
# used by the cells below.
operator = mean_estimator(training_samples)

In [None]:
# Sanity check: show the operator's top-10 predictions for a single subject.
united_states_output = operator("United States", k=10)
united_states_output.predictions

In [None]:
# Hold out every sample that was not used to estimate the operator.
held_out = set(datums.samples) - set(training_samples.samples)
test_samples = list(held_out)

# For each held-out subject, collect the tokens of the operator's top-5
# predictions; `target` holds the matching ground-truth objects in the same
# order.
predictions = [
    [p.token for p in operator(sample.subject, k=5).predictions]
    for sample in test_samples
]
target = [sample.object for sample in test_samples]

# Score the predictions against the expected objects.
recall(predictions, target)

In [None]:
from src.attributelens.attributelens import Attribute_Lens
import src.attributelens.utils as lens_utils

In [None]:
# Apply the attribute lens to one prompt, using the relation operator
# estimated above.
lens_prompt = "Germany, Germany, Germany. Canberra is the capital of the country of"
lens = Attribute_Lens(mt=mt, top_k=10)
att_info = lens.apply_attribute_lens(
    prompt=lens_prompt,
    relation_operator=operator,
)

# The last entry of 'nextwords' is reported as the prediction.
final_nextword = att_info['nextwords'][-1]
print('prediction:', final_nextword)

In [None]:
import plotly.graph_objs as go
    
# Plot options for the attribute-lens figure. The author left an optional
# `expected_answers=[' Beijing']` argument disabled; it is not passed here.
lens_plot_options = dict(layer_skip=2, must_have_layers=[15, 25])
f = lens_utils.visualize_attribute_lens(att_info, **lens_plot_options)

# Display the figure as the cell output.
f