### <span style="color:red">!caution</span>
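Attribute Lens visualizations are implemented with `plotly`. GitHub currently cannot render `plotly` figures, so the plots will not appear when this notebook is viewed on GitHub; run the notebook locally to see them.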

In [1]:
import os
import sys
sys.path.append('..')

import torch
from src import models, data
from src.attributelens.attributelens import Attribute_Lens
import src.attributelens.utils as lens_utils
import numpy as np

In [2]:
# The cached LREs used below were computed for GPT-J, so load that model.
device = "cuda:0"
# `mt` is used by every later cell (tokenizer via `mt.tokenizer`, model via `mt.model`).
mt = models.load_model("gptj", device=device, fp16=True)
# Sanity check: report the dtype, device and memory footprint of the loaded model.
print(f"dtype: {mt.model.dtype}, device: {mt.model.device}, memory: {mt.model.get_memory_footprint()}")

dtype: torch.float16, device: cuda:0, memory: 12219206136


In [3]:
# Download the cached LREs (uncomment the two lines below to run them).
# ! pip install gdown
# ! gdown --no-check-certificate --folder https://drive.google.com/drive/u/0/folders/1jAxqpACq5-gDbHG3cFhrL8eC65UtPcUM

In [4]:
# The prompt starts with the tokenizer's EOS token followed by a space; the lens
# cells below set the subject range to start at index 1 so that this first EOS
# token is ignored in the visualizations.
# Alternative prompt (swap the comment to try it):
# prompt = mt.tokenizer.eos_token + " " + "present-day Turkey was home to important Neolithic sites like"
prompt =  mt.tokenizer.eos_token + " " + "The United States of America (U.S.A. or USA), commonly known as the United States"
prompt

'<|endoftext|> The United States of America (U.S.A. or USA), commonly known as the United States'

## Attribute Lens

In [5]:
from src.operators import LinearRelationOperator

def load_cached_lre(relation_name, path = "../LRE_cached"):
    """Build a `LinearRelationOperator` from a cached `.npz` approximation.

    The cache file is looked up at `<path>/<relation_name>.npz`, with spaces in
    `relation_name` replaced by underscores.

    Args:
        relation_name: Name of the relation, e.g. "country capital city".
        path: Directory that contains the cached `.npz` files.

    Returns:
        A `LinearRelationOperator` bound to the notebook-global `mt` and
        configured with the cached `weight`, `bias`, `h_layer`, `z_layer`,
        `prompt_template` and `beta` entries.

    Raises:
        FileNotFoundError: If there is no cache file for `relation_name`.
        KeyError: If the cache file lacks one of the required entries.
    """
    cache_file = os.path.join(path, relation_name.replace(" ", "_") + ".npz")
    if not os.path.isfile(cache_file):
        raise FileNotFoundError(
            f"No cached LRE for relation {relation_name!r} at {cache_file!r}; "
            "download the cached LREs first."
        )

    # Put the tensors on the model's own device instead of the current default
    # CUDA device, so the operator and the model always agree.
    device = mt.model.device

    # SECURITY NOTE: allow_pickle=True can execute arbitrary code while
    # unpickling; only load cache files obtained from a trusted source.
    # The context manager closes the underlying archive once the entries are read.
    with np.load(cache_file, allow_pickle=True) as approx:
        # Only the entries the operator needs are read; any other arrays in the
        # archive (e.g. "h" / "z") are not used here and are left on disk.
        weight = torch.from_numpy(approx["weight"]).to(device)
        bias = torch.from_numpy(approx["bias"]).to(device)
        # The remaining entries are unwrapped into plain Python objects.
        metadata = {
            key: approx[key].item()
            for key in ("h_layer", "z_layer", "prompt_template", "beta")
        }

    return LinearRelationOperator(
        mt = mt,
        weight = weight,
        bias = bias,
        h_layer = metadata["h_layer"],
        z_layer = metadata["z_layer"],
        prompt_template = metadata["prompt_template"],
        beta = metadata["beta"]
    )

In [6]:
# Uncomment the block and print `relation_names` to see all the options
# dataset = data.load_dataset()
# relation_names = [r.name for r in dataset.relations]
# relation_names

In [7]:
# Relations to visualize. Each name is turned into a cache file name by
# `load_cached_lre` (spaces replaced with underscores, ".npz" appended), so a
# matching file must exist in the cache directory.
relation_names = [
    "country capital city",
    "country largest city",
    "country currency",
    "country language"
]

In [8]:
# Map every selected relation name to its cached linear relation operator.
lres = dict(
    (name, load_cached_lre(relation_name=name))
    for name in relation_names
)

In [9]:
import time
from itertools import cycle

lens = Attribute_Lens(mt=mt, top_k=10)

colorscales = ["oranges", "purples", "greens", "reds"]
separator = "----------------------------------------"

# Cycle through the colorscales so that no relation is silently dropped when
# `relation_names` has more entries than `colorscales` (a plain zip would stop
# at the shorter of the two).
for relation_name, colorscale in zip(relation_names, cycle(colorscales)):
    print(separator)
    print(relation_name, " -- ", colorscale)
    print(separator)
    att_info = lens.apply_attribute_lens(
        prompt=prompt,
        relation_operator=lres[relation_name]
    )
    # Start the subject range at index 1 to ignore the first EOS token.
    att_info['subject_range'] = (1, att_info['subject_range'][-1])
    p = lens_utils.visualize_attribute_lens(
        att_info, layer_skip=2, must_have_layers=[],
        colorscale=colorscale
    )
    # Remove the outer margins of the figure.
    p.layout.margin = dict(l=0, r=0, t=0, b=0)
    p.show()

    # Short pause between figures, kept from the original; presumably it gives
    # each figure time to render -- TODO confirm whether it is still needed.
    time.sleep(1)

----------------------------------------
country capital city  --  oranges
----------------------------------------


----------------------------------------
country largest city  --  purples
----------------------------------------


----------------------------------------
country currency  --  greens
----------------------------------------


----------------------------------------
country language  --  reds
----------------------------------------


## Logit Lens

In [11]:
# Logit Lens: same pipeline as the attribute lens above, but with no relation
# operator. Per the lens API, passing None makes it use the identity.
lens = Attribute_Lens(mt=mt, top_k=10)
att_info = lens.apply_attribute_lens(prompt=prompt, relation_operator=None)

# Start the subject range at index 1 to ignore the first EOS token.
att_info['subject_range'] = (1, att_info['subject_range'][-1])

p = lens_utils.visualize_attribute_lens(att_info, layer_skip=2, must_have_layers=[])
p.layout.margin = {"l": 0, "r": 0, "t": 0, "b": 0}
p.show()