In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import os
import json
import sys
import numpy as np
sys.path.append("..")
import copy

In [3]:
from src import models, data, operators, editors, functional, metrics, lens
from src.utils import logging_utils, experiment_utils
import logging
import torch
import baukit

# Configure logging through the project helper at INFO level, so that the
# `src.*` INFO records (model loading, seeding, ...) show up in cell outputs.
logging_utils.configure(level=logging.INFO)

In [4]:
# Load the "gptj" model in half precision on the GPU. `mt` is the handle that
# every later cell passes to the project helpers (and whose `.model` is called
# directly for the intervened forward pass).
mt = models.load_model("gptj", fp16=True, device="cuda")

2023-07-25 22:03:47 src.models INFO     loading EleutherAI/gpt-j-6B (device=cuda, fp16=True)
2023-07-25 22:03:56 src.models INFO     dtype: torch.float16, device: cuda:0, memory: 12219206136


In [5]:
# Load the dataset with the loader's default arguments. Later cells iterate
# over it and select entries by their `.name`.
dataset = data.load_dataset()

In [6]:
# ---------------------- experiment configuration ----------------------
layer = 3      # forwarded as `h_layer` to the estimator and to compute_hs_and_zs
rank = 70      # forwarded as `rank` to the estimator and to low_rank_pinv
beta = 2.25    # forwarded as `beta` to the estimator
n_train = 5    # argument given to `relation.split`
selected_relations = list(
    filter(
        lambda r: r.name
        in ("country capital city", "country currency", "country language"),
        dataset,
    )
)

# Seed before any `relation.split` call so the train/test split is repeatable.
experiment_utils.set_seed(123456)
# ----------------------------------------------------------------------

# relation name -> {"train", "prompt_template", "prompt", "operator"}
relation_properties = {}

for relation in selected_relations:
    train, test = relation.split(n_train)
    prompt_template = relation.prompt_templates[0]

    # The subject slot is left as a literal "{}" so that later cells can fill
    # it in with `prompt.format(subject)`.
    relation_prompt = functional.make_prompt(
        mt=mt,
        prompt_template=prompt_template,
        subject="{}",
        examples=train.samples,
    )

    # Estimate the relation operator from the training samples.
    estimator = operators.JacobianIclMeanEstimator(
        mt=mt,
        h_layer=layer,
        beta=beta,
        rank=rank,
    )
    operator = estimator(train)

    relation_properties[relation.name] = dict(
        train=train,
        prompt_template=prompt_template,
        prompt=relation_prompt,
        operator=operator,
    )

2023-07-25 22:03:56 src.utils.experiment_utils INFO     setting all seeds to 123456


In [7]:
# Show, for every selected relation, its training samples and the few-shot
# prompt (with the subject slot still unfilled).
for relation_name, props in relation_properties.items():
    print(
        "-----------------------------------",
        relation_name,
        "-----------------------------------",
        sep="\n",
    )
    print([str(sample) for sample in props["train"].samples])
    print(props["prompt"])

-----------------------------------
country capital city
-----------------------------------
['Pakistan -> Islamabad', 'Argentina -> Buenos Aires', 'Peru -> Lima', 'Australia -> Canberra', 'Germany -> Berlin']
<|endoftext|>The capital city of Pakistan is Islamabad
The capital city of Argentina is Buenos Aires
The capital city of Peru is Lima
The capital city of Australia is Canberra
The capital city of Germany is Berlin
The capital city of {} is
-----------------------------------
country currency
-----------------------------------
['Norway -> Krone', 'Russia -> Ruble', 'Argentina -> Peso', 'New Zealand -> Dollar', 'Czech Republic -> Koruna']
<|endoftext|>The official currency of Norway is the Krone
The official currency of Russia is the Ruble
The official currency of Argentina is the Peso
The official currency of New Zealand is the Dollar
The official currency of Czech Republic is the Koruna
The official currency of {} is the
-----------------------------------
country language
-----

In [27]:
# ------------------------- subjects under study -------------------------
source_subject = "Chile"
# relation name -> subject used as the *target* for that relation in the
# cells below (its latent is compared against the source subject's).
targ_prop_for_subj = {
    "country capital city": "Italy",
    "country currency": "Japan",
    "country language": "Brazil",
}
# -------------------------------------------------------------------------

# Top next-token prediction for the source subject under every relation.
for prop in targ_prop_for_subj.keys():
    prompt = relation_properties[prop]["prompt"].format(source_subject)
    obj = functional.predict_next_token(mt=mt, prompt=prompt, k=3)[0]
    print(source_subject, prop, obj[0], sep=" -- ")
print("=================================")

# Top next-token prediction for each relation's own target subject.
for prop, subj in targ_prop_for_subj.items():
    prompt = relation_properties[prop]["prompt"].format(subj)
    obj = functional.predict_next_token(mt=mt, prompt=prompt, k=3)[0]
    print(subj, prop, obj[0], sep=" -- ")

Chile -- country capital city --  Santiago (p=0.996)
Chile -- country currency --  Pes (p=0.757)
Chile -- country language --  Spanish (p=0.937)
Italy -- country capital city --  Rome (p=0.965)
Japan -- country currency --  Yen (p=0.980)
Brazil -- country language --  Portuguese (p=0.877)


In [28]:
# Attach a rank-`rank` pseudo-inverse of each relation operator's weight to
# that relation's property dict under the "W_inv" key.
for relation_name, props in relation_properties.items():
    props["W_inv"] = functional.low_rank_pinv(
        matrix=props["operator"].weight, rank=rank
    )

In [29]:
def get_delta_s(
    prop,
    source_subject,
    target_subject,
    fix_latent_norm = None,
):
    """Compute the subject-space edit direction from source to target subject.

    The direction is ``W_inv @ (z_target - z_source)``, where ``W_inv`` is the
    low-rank pseudo-inverse stored in ``relation_properties[prop]["W_inv"]`` and
    the ``z`` vectors come from ``functional.compute_hs_and_zs`` run on both
    subjects with the relation's prompt template and training examples.

    Args:
        prop: Relation name; key into the notebook-level ``relation_properties``.
        source_subject: Subject whose representation is to be edited.
        target_subject: Subject whose ``z`` is the edit target.
        fix_latent_norm: If not ``None``, ``z_source`` is rescaled to this norm
            and ``z_target`` is rescaled to the same norm before differencing.

    Returns:
        ``(delta_s, hs_and_zs)``: the edit direction and the object returned by
        ``functional.compute_hs_and_zs``.

    Note:
        Rescaling is done in place, so when ``fix_latent_norm`` is given the
        ``z`` tensors inside the returned ``hs_and_zs`` are the rescaled ones.
    """
    w_p_inv = relation_properties[prop]["W_inv"]
    hs_and_zs = functional.compute_hs_and_zs(
        mt = mt,
        prompt_template = relation_properties[prop]["prompt_template"],
        subjects = [source_subject, target_subject],
        h_layer= layer,
        z_layer=-1,
        examples= relation_properties[prop]["train"].samples,
    )

    z_source = hs_and_zs.z_by_subj[source_subject]
    # Use the `target_subject` argument. The previous lookup went through the
    # notebook global `targ_prop_for_subj[prop]`, silently ignoring the
    # parameter (and failing whenever that subject was not in `subjects`).
    z_target = hs_and_zs.z_by_subj[target_subject]

    if fix_latent_norm is not None:
        # Bring both latents to a common norm so the difference reflects
        # direction rather than magnitude; z_source is rescaled first so
        # z_target matches its new norm.
        z_source *= fix_latent_norm / z_source.norm()
        z_target *= z_source.norm() / z_target.norm()

    delta_s = w_p_inv @ (z_target.squeeze() - z_source.squeeze())

    return delta_s, hs_and_zs

In [30]:
# relation name -> {"delta_s": edit direction, "hs_and_zs": activations used}
# for moving `source_subject` towards that relation's target subject.
delta_s_by_prop = {}
for relation_name, target_subject in targ_prop_for_subj.items():
    delta_s, hs_and_zs = get_delta_s(
        prop=relation_name,
        source_subject=source_subject,
        target_subject=target_subject,
        fix_latent_norm=250,
    )
    delta_s_by_prop[relation_name] = dict(delta_s=delta_s, hs_and_zs=hs_and_zs)

In [31]:
# Pairwise inner products and cosine similarities between the per-relation
# edit directions (in the insertion order of `delta_s_by_prop`).
drr = [entry["delta_s"] for entry in delta_s_by_prop.values()]
for i in range(len(drr)):
    for j in range(len(drr)):
        # The directions are 1-D, so `a @ b` is already the inner product.
        # The former `.T` was a no-op that PyTorch deprecates on non-2D tensors.
        inner = drr[i] @ drr[j]
        cosine = torch.cosine_similarity(drr[i], drr[j], dim=0).item()
        print(f"{i} -- {j} -- {inner}, {cosine}")

0 -- 0 -- 1546.0, 1.0
0 -- 1 -- 1039.0, 0.48486328125
0 -- 2 -- 295.25, 0.2366943359375
1 -- 0 -- 1039.0, 0.48486328125
1 -- 1 -- 2972.0, 0.9990234375
1 -- 2 -- 484.0, 0.27978515625
2 -- 0 -- 295.25, 0.2366943359375
2 -- 1 -- 484.0, 0.27978515625
2 -- 2 -- 1006.5, 1.0


In [32]:
# Sum the per-relation edit directions into one combined direction, printing
# each individual norm along the way.
# The accumulator's shape/dtype/device is taken from the first stored
# direction instead of `delta_s_by_prop[prop]`: `prop` was only a loop
# variable leaked from an earlier cell, so the cell depended on execution order.
cumulative_delta_s = torch.zeros_like(
    next(iter(delta_s_by_prop.values()))["delta_s"]
)
for relation_name in delta_s_by_prop:
    cumulative_delta_s += delta_s_by_prop[relation_name]["delta_s"]
    print(delta_s_by_prop[relation_name]["delta_s"].norm())
cumulative_delta_s.shape, cumulative_delta_s.norm().item()

tensor(39.3125, device='cuda:0', dtype=torch.float16)
tensor(54.5000, device='cuda:0', dtype=torch.float16)
tensor(31.7188, device='cuda:0', dtype=torch.float16)


(torch.Size([4096]), 95.6875)

In [36]:
# Relation whose attribute is edited in this cell. Swap in
# "country capital city" or "country currency" to try the other relations.
prop = "country language"

# Recompute the edit direction for just this relation.
delta_s, hs_and_zs = get_delta_s(
    prop,
    source_subject,
    targ_prop_for_subj[prop],
    fix_latent_norm=250,
)

def get_intervention(h, int_layer, subj_idx):
    """Build an ``edit_output`` callback that patches one layer's output.

    The returned callback takes ``(output, layer)``. When ``layer`` equals
    ``int_layer`` it writes ``h`` into ``functional.untuple(output)[:, subj_idx]``
    in place; for any other layer it leaves ``output`` untouched. In both cases
    the (possibly modified) ``output`` object is returned.
    """
    def edit_output(output, layer):
        # Parameter names are kept as `output` / `layer` in case the tracing
        # library passes them by keyword.
        if layer == int_layer:
            functional.untuple(output)[:, subj_idx] = h
        return output

    return edit_output

# Fill the subject slot of the few-shot prompt with the source subject.
prompt = relation_properties[prop]["prompt"].format(source_subject)

# Position of the subject token in the tokenized prompt, plus the model inputs.
h_index, inputs = functional.find_subject_token_index(
    mt=mt,
    prompt=prompt,
    subject=source_subject,
)

# Module paths for the intervention layer and the last layer.
h_layer, z_layer = models.determine_layer_paths(model=mt, layers=[layer, -1])

# Representation written at the subject position. Alternatives to try:
#   hs_and_zs.h_by_subj[source_subject]                        (no edit)
#   hs_and_zs.h_by_subj[source_subject] + cumulative_delta_s   (all relations)
edited_h = hs_and_zs.h_by_subj[source_subject] + delta_s

# This is inference only, so run it under `torch.no_grad()` to avoid building
# an autograd graph for the whole model.
with torch.no_grad(), baukit.TraceDict(
    mt.model,
    layers=[h_layer, z_layer],
    edit_output=get_intervention(
        h=edited_h,
        int_layer=h_layer,
        subj_idx=h_index,
    ),
) as traces:
    outputs = mt.model(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
    )

# Decode the next-token distribution at the final position of the prompt.
lens.interpret_logits(
    mt=mt,
    logits=outputs.logits[0][-1],
    get_proba=True,
)

[(' Portuguese', 0.829),
 (' Portug', 0.09),
 (' Brazilian', 0.025),
 (' port', 0.024),
 (' Spanish', 0.007),
 (' English', 0.004),
 (' Port', 0.003),
 ('\n', 0.002),
 (' Brazil', 0.002),
 (' French', 0.001)]