In [1]:
# Enable IPython's autoreload extension in mode 2 so imported modules are
# reloaded before each cell executes (edits to the local `src` code are
# picked up without restarting the kernel).
%load_ext autoreload
%autoreload 2

In [2]:
import os
import sys
sys.path.append('..')

import torch
from src import models, data
from src.attributelens.attributelens import Attribute_Lens
import src.attributelens.utils as lens_utils
from src.operators import JacobianIclMeanEstimator
import numpy as np
from src.utils import experiment_utils, logging_utils
from baukit import Menu, show

In [3]:
# Load the "gptj" model bundle onto the first CUDA device with fp16 enabled,
# then print its dtype, device and memory footprint as a sanity check.
device = "cuda:0"
mt = models.load_model("gptj", device=device, fp16=True)
print(f"dtype: {mt.model.dtype}, device: {mt.model.device}, memory: {mt.model.get_memory_footprint()}")

dtype: torch.float16, device: cuda:0, memory: 12219206136


In [4]:
# Load the relation dataset and display a baukit Menu listing every relation
# name; the next cell reads the menu's `.value` to pick the relation.
# NOTE(review): everything downstream depends on this widget's state, so the
# notebook is not reproducible from code alone -- consider a plain constant.
dataset = data.load_dataset()

relation_names = [r.name for r in dataset]
# NOTE(review): the full list is passed as the initial `value`; confirm that
# Menu treats this as the intended default selection.
relation_options = Menu(choices = relation_names, value = relation_names)
show(relation_options)

In [5]:
# Take the relation name currently held by the menu widget, keep the first
# relation returned by the dataset's name filter, and report its sample count.
relation_name = relation_options.value
relation = dataset.filter(relation_names=[relation_name])[0]
print(f"{relation.name} -- {len(relation.samples)} samples")

country capital city -- 24 samples


In [21]:
# Free-form text prompt, prefixed with the tokenizer's EOS token and a space.
# The bare `prompt` on the last line echoes the final string as cell output.
prompt = mt.tokenizer.eos_token + " " + "The region that is now China has been inhabited since the Paleolithic era"
prompt

'<|endoftext|> The region that is now China has been inhabited since the Paleolithic era'

In [22]:
# Baseline run: apply the attribute lens to the prompt without a relation
# operator, keeping the top 10 candidates (top_k=10).
lens = Attribute_Lens(mt=mt, top_k=10)
att_info = lens.apply_attribute_lens(
    prompt=prompt,
    relation_operator=None # Will use Identity if set to None. Basically Logit Lens
)
# Start the subject range at index 1 instead of 0; the end index is kept.
att_info['subject_range']= (1, att_info['subject_range'][-1]) # ignore the first EOS token
# Print the last entry of the predicted next words.
print('prediction:', att_info['nextwords'][-1])
# Build the lens figure (layer_skip=3, no forced layers) and render it in an iframe.
p = lens_utils.visualize_attribute_lens(
    att_info, layer_skip=3, must_have_layers=[],
)
p.show(renderer='iframe')

prediction: .
must_have_layers:  []
expected_answers:  []


In [23]:
from src.operators import LinearRelationOperator

def load_cached_lre(relation_name, path = "../results/LRE_cached", device = None):
    """Build a ``LinearRelationOperator`` from a cached ``.npz`` archive.

    The archive is looked up as ``<path>/<name>.npz`` where ``<name>`` is
    ``relation_name`` lower-cased with spaces replaced by underscores.

    Args:
        relation_name: Relation name, e.g. ``"country capital city"``.
        path: Directory that holds the cached archives.
        device: Torch device for the ``weight`` and ``bias`` tensors. When
            ``None`` (default) the device of the notebook-level model
            (``mt.model.device``) is used, so the tensors always live next
            to the model instead of on whatever the current CUDA device is.

    Returns:
        A ``LinearRelationOperator`` bound to the notebook-level ``mt``.

    Raises:
        FileNotFoundError: If no archive exists for ``relation_name``.
        KeyError: If the archive lacks one of the required entries.
    """
    if device is None:
        device = mt.model.device

    cache_file = os.path.join(path, relation_name.lower().replace(" ", "_") + ".npz")

    # NOTE(security): allow_pickle=True can execute arbitrary code while
    # loading; only point this at archives produced by this project.
    # The context manager closes the underlying zip file handle once the
    # needed entries have been read.
    with np.load(cache_file, allow_pickle=True) as approx:
        # Only the entries the operator consumes are read; other cached
        # arrays (e.g. "h", "z") are left on disk rather than copied to the
        # accelerator and thrown away.
        tensors = {
            key: torch.from_numpy(approx[key]).to(device)
            for key in ("weight", "bias")
        }
        # Remaining entries are stored as 0-d arrays; unwrap to Python values.
        scalars = {
            key: approx[key].item()
            for key in ("h_layer", "z_layer", "prompt_template", "beta")
        }

    return LinearRelationOperator(
        mt = mt,
        weight = tensors["weight"],
        bias = tensors["bias"],
        h_layer = scalars["h_layer"],
        z_layer = scalars["z_layer"],
        prompt_template = scalars["prompt_template"],
        beta = scalars["beta"]
    )

In [24]:
# Build the linear relation operator for the currently selected relation
# from its cached archive (default cache directory).
lre = load_cached_lre(relation_name = relation_name)

In [26]:
# Same lens run as the baseline cell, but with the cached linear relation
# operator plugged in and a red colour scale for the plot.
lens = Attribute_Lens(mt=mt, top_k=10)
att_info = lens.apply_attribute_lens(
    prompt=prompt,
    relation_operator=lre
)
# Start the subject range at index 1 instead of 0; the end index is kept.
att_info['subject_range']= (1, att_info['subject_range'][-1]) # ignore the first EOS token
# Print the last entry of the predicted next words.
print('prediction:', att_info['nextwords'][-1])
# Build the lens figure (layer_skip=3, no forced layers) and render it in an iframe.
p = lens_utils.visualize_attribute_lens(
    att_info, layer_skip=3, must_have_layers=[],
    colorscale="reds"
)
p.show(renderer='iframe')

prediction: .
must_have_layers:  []
expected_answers:  []


In [12]:
# Scratch cell: open one cached archive directly for inspection.
# NOTE(review): execution count 12 is lower than the cells above, the path is
# hard-coded to a single relation, and `approx` is not used by any visible
# cell -- consider deleting this cell or re-running in order.
approx = np.load("../results/LRE_cached/country_capital_city.npz", allow_pickle=True)

In [15]:
# Display the `beta` attribute of the loaded operator.
# NOTE(review): this cell's execution count (15) predates the cell that
# currently defines `lre` (24), so the recorded output may come from an older
# `lre`; re-run on a fresh kernel to confirm.
lre.beta

array(2.25)