In [149]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [150]:
import matplotlib.pyplot as plt
import os
import json
import sys
import numpy as np
sys.path.append("..")
import copy

In [151]:
from src import models, data, operators, editors, functional, metrics, lens
from src.utils import logging_utils, experiment_utils
import logging
import torch
import baukit

logging_utils.configure(level=logging.INFO)

In [152]:
# Load the "gptj" model through the project loader: half precision, on the GPU.
mt = models.load_model(
    "gptj",
    device="cuda",
    fp16=True,
)

2023-07-28 17:12:39 src.models INFO     loading EleutherAI/gpt-j-6B (device=cuda, fp16=True)


2023-07-28 17:12:47 src.models INFO     dtype: torch.float16, device: cuda:0, memory: 12219206136


In [153]:
# Load the relation dataset with the loader's default arguments. Later cells
# iterate over it and read each relation's `name` and `prompt_templates`.
dataset = data.load_dataset()

In [181]:
##################################
# Experiment configuration. These names are read again by later cells.
layer = 3      # passed as `h_layer` to the estimator (and to compute_hs_and_zs later)
rank = 170     # rank given to the estimator and to the low-rank pseudo-inverse
beta = 2.25    # passed as `beta` to the estimator
n_train = 5    # argument to `relation.split`; the train part supplies the prompt examples
selected_relation_names = [
    "person occupation",
    "name birthplace",
    "person university",
]
selected_relations = [r for r in dataset if r.name in selected_relation_names]

# Fail fast when a requested relation is absent from the dataset; otherwise the
# omission only surfaces later as a KeyError on `relation_properties`.
missing_relation_names = sorted(
    set(selected_relation_names) - {r.name for r in selected_relations}
)
if missing_relation_names:
    raise ValueError(f"relations not found in dataset: {missing_relation_names}")

experiment_utils.set_seed(123456)
##################################


# For every selected relation keep the train split, the prompt template, the
# few-shot prompt (with a "{}" placeholder where the subject goes) and the
# operator estimated from the train split.
relation_properties = {}

for relation in selected_relations:
    train, test = relation.split(n_train)
    prompt_template = relation.prompt_templates[0]

    # subject="{}" leaves a format slot so the prompt can be reused for any subject.
    relation_prompt = functional.make_prompt(
        mt=mt,
        prompt_template=prompt_template,
        subject="{}",
        examples=train.samples,
    )

    estimator = operators.JacobianIclMeanEstimator(
        mt=mt, h_layer=layer, beta=beta, rank=rank
    )
    operator = estimator(train)

    relation_properties[relation.name] = {
        "train": train,
        "prompt_template": prompt_template,
        "prompt": relation_prompt,
        "operator": operator,
    }

2023-07-28 17:31:27 src.utils.experiment_utils INFO     setting all seeds to 123456




In [182]:
# Show, for each relation, its training pairs followed by the assembled prompt.
for relation_name, properties in relation_properties.items():
    print("-----------------------------------")
    print(relation_name)
    print("-----------------------------------")
    train_pairs = [str(sample) for sample in properties["train"].samples]
    print(train_pairs)
    print(properties["prompt"])

-----------------------------------
name birthplace
-----------------------------------
['Rohit -> India', 'Sakura -> Japan', 'Marco -> Italy', 'Hong -> China', 'Kraipob -> Thailand']
<|endoftext|>Rohit was born in the country of India
Sakura was born in the country of Japan
Marco was born in the country of Italy
Hong was born in the country of China
Kraipob was born in the country of Thailand
{} was born in the country of
-----------------------------------
person occupation
-----------------------------------
['Yakubu Gowon -> politician', 'Geraldine McNulty -> actor', 'Andrew Salkey -> poet', 'Wilhelm Magnus -> mathematician', 'Samuel Medary -> journalist']
<|endoftext|>Yakubu Gowon works as a politician
Geraldine McNulty works as a actor
Andrew Salkey works as a poet
Wilhelm Magnus works as a mathematician
Samuel Medary works as a journalist
{} works as a
-----------------------------------
person university
-----------------------------------
['Ursula K. Le Guin -> Columbia Univer

In [None]:
# Attach a rank-limited pseudo-inverse of each operator's weight under "W_inv".
for relation_name, properties in relation_properties.items():
    operator_weight = properties["operator"].weight
    properties["W_inv"] = functional.low_rank_pinv(matrix=operator_weight, rank=rank)

In [242]:
##################################################
# Source subject to be edited and, per relation, the subject whose attribute
# should be transferred onto it.
source_subject = "X"
targ_prop_for_subj = {
    "person occupation": "Sherlock Holmes",
    "name birthplace": "Jackie Chan",
    "person university": "Bill Gates",
}
##################################################

# Top next-token prediction for the source subject under every relation prompt.
for prop in targ_prop_for_subj:
    prompt = relation_properties[prop]["prompt"].format(source_subject)
    predictions = functional.predict_next_token(mt=mt, prompt=prompt, k=3)
    obj = predictions[0]
    print(f"{source_subject} -- {prop} -- {obj[0]!s}")
print("=================================")

# Same check for each relation's own target subject.
for prop, subj in targ_prop_for_subj.items():
    prompt = relation_properties[prop]["prompt"].format(subj)
    predictions = functional.predict_next_token(mt=mt, prompt=prompt, k=3)
    obj = predictions[0]
    print(f"{subj} -- {prop} -- {obj[0]!s}")

X -- person occupation --  politician (p=0.067)
X -- name birthplace --  Russia (p=0.076)
X -- person university --  University (p=0.348)
Sherlock Holmes -- person occupation --  detective (p=0.523)
Jackie Chan -- name birthplace --  China (p=0.578)
Bill Gates -- person university --  Harvard (p=0.893)


In [243]:
def get_delta_s(
    prop,
    source_subject,
    target_subject,
    fix_latent_norm = None,
):
    """Compute the hidden-state offset that moves `source_subject` toward `target_subject`.

    The offset is ``W_inv @ (z_target - z_source)``, where ``W_inv`` is the
    pseudo-inverse stored for relation `prop` and the ``z`` vectors come from
    ``functional.compute_hs_and_zs`` run on both subjects.

    Args:
        prop: key into ``relation_properties`` selecting the relation.
        source_subject: subject whose representation is to be edited.
        target_subject: subject supplying the desired attribute.
        fix_latent_norm: if given, ``z_source`` is rescaled to this norm and
            ``z_target`` to the norm of the rescaled ``z_source`` before the
            difference is taken. ``None`` leaves both untouched.

    Returns:
        ``(delta_s, hs_and_zs)``: the offset and the object returned by
        ``compute_hs_and_zs``. Rescaling is done out of place, so the tensors
        held by ``hs_and_zs`` keep their original norms.

    Side effects:
        Prints the z norms, the h norms, the norms of ``W_inv @ z`` for both
        subjects, and the norm of ``delta_s``.
    """
    properties = relation_properties[prop]
    w_p_inv = properties["W_inv"]
    hs_and_zs = functional.compute_hs_and_zs(
        mt=mt,
        prompt_template=properties["prompt_template"],
        subjects=[source_subject, target_subject],
        h_layer=layer,
        z_layer=-1,
        examples=properties["train"].samples,
    )

    # Index by the `target_subject` argument. The previous code looked the
    # target up in the global `targ_prop_for_subj`, silently ignoring the parameter.
    z_source = hs_and_zs.z_by_subj[source_subject]
    z_target = hs_and_zs.z_by_subj[target_subject]

    h_source = hs_and_zs.h_by_subj[source_subject]
    h_target = hs_and_zs.h_by_subj[target_subject]

    if fix_latent_norm is not None:
        # Out-of-place scaling: do not modify the tensors stored in hs_and_zs.
        z_source = z_source * (fix_latent_norm / z_source.norm())
        z_target = z_target * (z_source.norm() / z_target.norm())
    print(z_target.norm().item(), z_source.norm().item())

    delta_s = w_p_inv @ (z_target.squeeze() - z_source.squeeze())

    print(f"h_source: {h_source.norm().item()} | h_target: {h_target.norm().item()}")
    print(f"inv_h_source: {(w_p_inv @ z_source).norm().item()} | inv_h_target: {(w_p_inv @ z_target).norm().item()}")
    print(delta_s.norm().item())

    return delta_s, hs_and_zs

In [244]:
# Offset (and the underlying hidden states) for every relation/target pair.
# The z norms are left as computed here; no fix_latent_norm is passed.
delta_s_by_prop = {}
for relation_name, target in targ_prop_for_subj.items():
    delta_s, hs_and_zs = get_delta_s(relation_name, source_subject, target)
    delta_s_by_prop[relation_name] = dict(delta_s=delta_s, hs_and_zs=hs_and_zs)

312.75 363.5
h_source: 46.5 | h_target: 53.28125
inv_h_source: 149.5 | inv_h_target: 143.0
85.3125
241.125 319.0
h_source: 46.25 | h_target: 51.75
inv_h_source: 72.5625 | inv_h_target: 72.0625
53.96875
159.125 350.25
h_source: 43.875 | h_target: 50.03125
inv_h_source: 47.40625 | inv_h_target: 44.8125
47.78125


In [245]:
# Pairwise inner product and cosine similarity between the per-relation offsets.
drr = [entry["delta_s"] for entry in delta_s_by_prop.values()]
for i, delta_i in enumerate(drr):
    for j, delta_j in enumerate(drr):
        inner_product = (delta_i[None] @ delta_j[None].T).squeeze().item()
        cosine = torch.cosine_similarity(delta_i, delta_j, dim=0).item()
        print(f"{i} -- {j} -- {inner_product}, {cosine}")

0 -- 0 -- 7284.0, 0.9990234375
0 -- 1 -- 346.75, 0.07525634765625
0 -- 2 -- -11.109375, -0.0027179718017578125
1 -- 0 -- 346.75, 0.07525634765625
1 -- 1 -- 2912.0, 0.99951171875
1 -- 2 -- -113.5625, -0.044036865234375
2 -- 0 -- -11.109375, -0.0027179718017578125
2 -- 1 -- -113.5625, -0.044036865234375
2 -- 2 -- 2282.0, 0.99951171875


In [249]:
# Rescale every per-relation offset to the largest offset norm, then average them.
if not delta_s_by_prop:
    raise ValueError("delta_s_by_prop is empty; compute the per-relation delta_s first")

max_norm = max(entry["delta_s"].norm().item() for entry in delta_s_by_prop.values())

# Take shape/dtype/device from the first stored offset instead of relying on a
# `prop` variable left over from another cell.
first_delta_s = next(iter(delta_s_by_prop.values()))["delta_s"]
cumulative_delta_s = torch.zeros_like(first_delta_s)
for relation_name in delta_s_by_prop:
    ds = delta_s_by_prop[relation_name]["delta_s"]
    ds = ds*max_norm / ds.norm()
    # The stored offset is overwritten with its rescaled version.
    delta_s_by_prop[relation_name]["delta_s"] = ds
    cumulative_delta_s += ds
# Average over however many relations are present (previously hard-coded as 3).
cumulative_delta_s /= len(delta_s_by_prop)
cumulative_delta_s.shape, cumulative_delta_s.norm().item()

(torch.Size([4096]), 49.71875)

In [250]:
# relation_names = [
#         "person occupation",
#         "name birthplace",
#         "person university"
# ]

# cumulative_delta_s = (
#     delta_s_by_prop[relation_names[0]]["delta_s"] + 
#     delta_s_by_prop[relation_names[1]]["delta_s"] + 
#     delta_s_by_prop[relation_names[2]]["delta_s"]
# )

In [251]:
# Relation whose prompt is used for the patched run below.
# Other options: "person occupation", "name birthplace".
prop = "person university"

# Recompute the offset for this relation with the latent norms pinned to 250.
delta_s, hs_and_zs = get_delta_s(
    prop,
    source_subject,
    targ_prop_for_subj[prop],
    fix_latent_norm=250,
)

def get_intervention(h, int_layer, subj_idx):
    """Build an output-editing hook that overwrites one position at one layer.

    The returned callable takes ``(output, layer)``. When ``layer`` equals
    `int_layer`, it assigns `h` to ``functional.untuple(output)[:, subj_idx]``
    in place. In every case it returns ``output``.
    """
    # Parameter names `output` and `layer` are kept as-is in case the hook
    # caller supplies them by keyword.
    def _overwrite_at_layer(output, layer):
        if layer == int_layer:
            functional.untuple(output)[:, subj_idx] = h
        return output

    return _overwrite_at_layer

# Prompt for the chosen relation with the source subject filled in.
prompt = relation_properties[prop]["prompt"].format(source_subject)

# Position of the subject token in the prompt, plus the tokenized inputs.
h_index, inputs = functional.find_subject_token_index(
    mt=mt,
    prompt=prompt,
    subject=source_subject,
)

h_layer, z_layer = models.determine_layer_paths(model = mt, layers = [layer, -1])

# Hidden state written at the subject position. Alternatives that were tried:
#   hs_and_zs.h_by_subj[source_subject]              (no edit)
#   hs_and_zs.h_by_subj[source_subject] + delta_s    (single-relation edit)
edited_h = hs_and_zs.h_by_subj[source_subject] + cumulative_delta_s

# Inference only: disable autograd tracking for the forward pass.
with torch.no_grad():
    with baukit.TraceDict(
        mt.model, layers = [h_layer, z_layer],
        edit_output=get_intervention(
            h = edited_h,
            int_layer = h_layer,
            subj_idx = h_index
        )
    ) as traces:
        outputs = mt.model(
            input_ids = inputs.input_ids,
            attention_mask = inputs.attention_mask,
        )

# Interpret the logits at the last position of the first (only) sequence.
lens.interpret_logits(
    mt = mt,
    logits = outputs.logits[0][-1],
    # get_proba=True
)

250.125 250.0
h_source: 43.875 | h_target: 50.03125
inv_h_source: 33.84375 | inv_h_target: 70.4375
62.375


[(' University', 20.953),
 (' the', 19.109),
 (' St', 18.812),
 (' Duke', 18.062),
 (' Oxford', 18.016),
 (' Stanford', 17.75),
 (' Brown', 17.703),
 (' Harvard', 17.672),
 (' City', 17.531),
 (' Sun', 17.469)]