In [None]:
# Re-import edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import os
import json
import sys
import numpy as np
# Make the parent directory importable so `src` and `scripts` resolve in the
# cells below (presumably the repository root — TODO confirm).
sys.path.append("..")
import copy

In [None]:
from src import models, data, operators, utils, functional, metrics, lens
from src.utils import logging_utils
import logging
import torch
import baukit

logger = logging.getLogger(__name__)

# Route INFO-and-above log records to stdout, using the project's shared
# format / date-format strings from src.utils.logging_utils.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
)

In [None]:
# GPT-J with fp16 weights on the CUDA device.
mt = models.load_model(name="gptj", fp16=True, device="cuda")

In [None]:
# One relation, re-templated as " {}:"; split(5) presumably yields 5 training
# samples and puts the remainder in the test split (TODO confirm).
relation = (
    data.load_dataset()
    .filter(relation_names=["country capital city"])[0]
    .set(prompt_templates=[" {}:"])
)
train, test = relation.split(5)

In [None]:
# Few-shot prompt built from the training samples, leaving a literal "{}"
# placeholder for the query subject (filled later with icl_prompt.format(...)).
icl_prompt = functional.make_prompt(
    mt=mt,
    prompt_template=train.prompt_templates[0],
    subject="{}",
    examples=train.samples,
)
print(icl_prompt)

In [None]:
# Filter the test relation using the few-shot prompt; presumably keeps only the
# samples the model already answers correctly (TODO confirm in src.functional).
test = functional.filter_relation_samples_based_on_provided_fewshots(
    mt=mt,
    test_relation=test,
    prompt_template=icl_prompt,
    batch_size=4,
)
len(test.samples)

### Current method: compute $W_r$ and $b_r$ for each training sample individually, then average them

In [None]:
# Jacobian-based operator read from layer 7, with beta = 0.2.
estimator = operators.JacobianIclMeanEstimator(mt=mt, h_layer=7, beta=0.2)
operator = estimator(train)

In [None]:
# calculate faithfulness
def evaluate_operator(operator, test_samples):
    """Apply `operator` to each test subject and score the predictions.

    For every sample, the operator is queried with the subject (k=3). The
    first prediction is printed next to the gold object. The tokens of all
    returned predictions are collected for scoring.

    Args:
        operator: Callable as ``operator(subject, k=3)``. It returns an
            object whose ``.predictions`` is a sequence of items with a
            ``.token`` attribute.
        test_samples: Iterable of samples with ``.subject`` and ``.object``.

    Returns:
        ``metrics.recall(predicted_tokens, gold_objects)``.
    """
    gold_objects, predicted_tokens = [], []
    for sample in test_samples:
        gold_objects.append(sample.object)
        result = operator(sample.subject, k=3)
        pred = str(result.predictions[0])
        print(f"{sample.subject=} -> {sample.object=} | {pred=}")
        predicted_tokens.append([candidate.token for candidate in result.predictions])

    return metrics.recall(predicted_tokens, gold_objects)

# Recall of the layer-7 Jacobian operator on the filtered test set.
evaluate_operator(operator, test.samples)

In [None]:
# Decode the operator's bias through the logit lens (k=10, with probabilities).
lens.logit_lens(mt=mt, h=operator.bias, get_proba=True, k=10)

In [None]:
# Show which module paths the project resolves for the "emb" and "ln_f" aliases.
models.determine_layer_paths(mt, layers=["emb", "ln_f"])

### Read $\mathbf{z}$ at `ln_f` (after the final layer norm)

In [None]:
# Same h layer as before, but z is read at "ln_f" and beta is 1.
estimator_lnf = operators.JacobianIclMeanEstimator(mt=mt, h_layer=7, z_layer="ln_f", beta=1)
operator_lnf = estimator_lnf(train)

In [None]:
# Recall of the ln_f operator on the filtered test set.
evaluate_operator(operator_lnf, test.samples)

In [None]:
# The ln_f bias lives after the final layer norm, hence after_layer_norm=True
# (presumably tells the lens not to re-apply ln_f — TODO confirm in src.lens).
lens.logit_lens(mt=mt, h=operator_lnf.bias, get_proba=True, k=10, after_layer_norm=True)

In [None]:
# Rebuild the ln_f operator with an explicit beta and its Jacobian scaled by omega.
# NOTE: beta/omega is NOT tied to a target ratio c here (the former
# `omega = beta / c` line was disabled), so beta=0.1, omega=1 gives a ratio of 0.1.
beta = 0.1
omega = 1

print(f"{beta=} | {omega=}")

operator_dct = operator_lnf.__dict__.copy()
operator_dct["beta"] = beta
# Scale the ln_f operator's own Jacobian. The previous version used
# `operator.weight`, which was estimated for a different z_layer than the ln_f
# bias it gets paired with in this dict.
operator_dct["weight"] = operator_lnf.weight * omega
operator_no_beta_lnf = operators.LinearRelationOperator(**operator_dct)

evaluate_operator(operator_no_beta_lnf, test.samples)

### Get rid of $\beta$ by setting $\beta = 1$ and scaling $W_r$ by $\omega$ instead

In [None]:
# Fold beta into the Jacobian: scale W_r by omega and choose beta = omega * c,
# so that beta / omega == c (omega = 5, c = 0.2 -> beta = 1.0).
c = 0.2
omega = 5
beta = omega * c

print(f"{beta=} | {omega=}")

operator_dct = operator.__dict__.copy()
operator_dct["beta"] = beta
operator_dct["weight"] = operator.weight * omega
operator_no_beta = operators.LinearRelationOperator(**operator_dct)

evaluate_operator(operator_no_beta, test.samples)

In [None]:
# 1 / beta for the operator_dct built in the previous cell.
1 / operator_dct["beta"]

In [None]:
# Cache h (at operator.h_layer) and z (z_layer=-1, presumably the last layer)
# for every subject of the relation, under the same few-shot prompt.
hs_and_zs = functional.compute_hs_and_zs(
    mt=mt,
    prompt_template=train.prompt_templates[0],
    subjects=[sample.subject for sample in relation.samples],
    h_layer=operator.h_layer,
    z_layer=-1,
    batch_size=4,
    examples=train.samples,
)

In [None]:
# For every training subject: the model's next-token prediction under the
# few-shot prompt, plus the norms of its cached h and z.
for sample in train.samples:
    subj, obj = sample.subject, sample.object
    pred = functional.predict_next_token(
        mt=mt,
        prompt=functional.make_prompt(
            mt=mt,
            prompt_template=train.prompt_templates[0],
            subject=subj,
            examples=train.samples,
        ),
    )[0][0]
    h_norm = hs_and_zs.h_by_subj[subj].norm().item()
    z_norm = hs_and_zs.z_by_subj[subj].norm().item()
    print(f"{subj=} -> {obj=} | {h_norm=} | {z_norm=} || {pred=}")

In [None]:
# Compare |h| with |J h| for every cached subject, then summarize as mean +/- std.
h_norms, jh_norms = [], []
for subj, h in hs_and_zs.h_by_subj.items():
    jh = operator.weight @ h
    h_norms.append(h.norm().item())
    jh_norms.append(jh.norm().item())
    print(f"{subj=} | {h.norm()=} | {jh.norm()=} | {h.mean()=} | {h.std()=}")

print(f"h_norms: {np.mean(h_norms):.2f} +/- {np.std(h_norms):.2f}")
print(f"jh_norms: {np.mean(jh_norms):.2f} +/- {np.std(jh_norms):.2f}")

### $b_r = \mathbf{o}_{mean} - J\mathbf{s}_{mean}$

In [None]:
# b_r = mean(z) - J mean(h) over the training subjects; compare with the estimator's bias.
z_mean = torch.stack([hs_and_zs.z_by_subj[s.subject] for s in train.samples]).mean(dim=0)
h_mean = torch.stack([hs_and_zs.h_by_subj[s.subject] for s in train.samples]).mean(dim=0)

bias_mean = z_mean - operator.weight @ h_mean
print(torch.dist(bias_mean, operator.bias))

In [None]:
# Swap in the mean-based bias and evaluate with beta = 1 (the estimator used 0.2).
operator_dct = operator.__dict__.copy()
operator_dct.update(beta=1, bias=bias_mean)
operator_bias_J = operators.LinearRelationOperator(**operator_dct)

evaluate_operator(operator_bias_J, test.samples)

### $b_r = F(\mathbf{s}_{mean}) - J\mathbf{s}_{mean}$

In [None]:
# Module paths of the operator's h and z layers (used as baukit trace keys below).
h_layer_name, z_layer_name = models.determine_layer_paths(mt, layers=[operator.h_layer, operator.z_layer])

In [None]:
# Subjects for which h was cached.
hs_and_zs.h_by_subj.keys()

In [None]:
def get_intervention(h, int_layer, subj_idx):
    """Build an ``edit_output`` hook that writes ``h`` into one layer's output.

    The returned hook leaves outputs of every layer other than ``int_layer``
    untouched. For ``int_layer``, it assigns ``h`` in place to position(s)
    ``subj_idx`` of ``functional.untuple(output)`` (indexed ``[:, subj_idx]``)
    and returns the same ``output`` object.
    """
    def edit_output(output, layer):
        if layer == int_layer:
            functional.untuple(output)[:, subj_idx] = h
        return output

    return edit_output

# Run the few-shot prompt for "Russia" while the hook patches h_mean into the
# subject-token position at h_layer; trace both h and z layers and decode the
# last-position logits.
subject = "Russia"
prompt = icl_prompt.format(subject)

h_index, inputs = functional.find_subject_token_index(mt=mt, prompt=prompt, subject=subject)

with baukit.TraceDict(
    mt.model,
    layers=[h_layer_name, z_layer_name],
    edit_output=get_intervention(h_mean, h_layer_name, h_index),
) as traces:
    outputs = mt.model(input_ids=inputs.input_ids, attention_mask=inputs.attention_mask)

lens.interpret_logits(mt=mt, logits=outputs.logits[0][-1], get_proba=True)

In [None]:
# Traced h_layer output at the subject token, first batch element. Presumably
# this reflects the patched value (h_mean) — TODO confirm that baukit retains the
# edited output.
s = functional.untuple(traces[h_layer_name].output)[0][h_index]
s.shape

In [None]:
# Norm of the traced subject state (compare with h_mean's norm in the next cell).
s.norm().item()

In [None]:
# Norm of the mean training h that was patched in.
h_mean.norm().item()

In [None]:
# F(s_mean): z-layer output at the last token of the (single-prompt) patched run.
# Unwrap with functional.untuple, as done for the h layer above. Indexing
# `.output[0][-1][-1]` only works when the layer returns a tuple; for a layer
# returning a bare tensor it would pick one scalar instead of the hidden vector.
z_mean_F = functional.untuple(traces[z_layer_name].output)[0][-1]
bias_F = z_mean_F - operator.weight @ h_mean
print(torch.dist(bias_F, operator.bias))
lens.logit_lens(mt=mt, h=bias_F, get_proba=True)

In [None]:
# Evaluate with the F(s_mean)-based bias and beta = 1. Named operator_bias_F so
# it no longer overwrites operator_bias_J from the mean-of-z experiment.
operator_dct = operator.__dict__.copy()
operator_dct["beta"] = 1
operator_dct["bias"] = bias_F
operator_bias_F = operators.LinearRelationOperator(**operator_dct)

evaluate_operator(operator_bias_F, test.samples)

### Rescale every $\mathbf{o}$ to the same norm $|\mathbf{o}|$ (the smallest training norm) before averaging

In [None]:
# Rescale every training z to the smallest training norm, then recompute
# b_r = mean(z) - J mean(h) and decode it.
z_s = torch.stack([hs_and_zs.z_by_subj[s.subject] for s in train.samples])
min_norm = z_s.norm(dim=1).min()
z_s = torch.stack([z * min_norm / z.norm() for z in z_s])

z_mean = z_s.mean(dim=0)
bias_mean = z_mean - operator.weight @ h_mean

print(torch.dist(bias_mean, operator.bias))
lens.logit_lens(mt=mt, h=bias_mean, get_proba=True)

In [None]:
# Evaluate with the norm-equalized mean bias and beta = 1.
operator_dct = operator.__dict__.copy()
operator_dct.update(beta=1, bias=bias_mean)
operator_similar_o = operators.LinearRelationOperator(**operator_dct)

evaluate_operator(operator_similar_o, test.samples)

### Automatically *tune* $\beta$ for each training sample (the first $\beta \in [0, 1]$ whose logit-lens top token matches the object)

In [None]:
# For each training subject, sweep beta over 10 evenly spaced values in [0, 1]
# and keep the first beta * b (b = z - J h) for which the logit-lens top token
# is a non-trivial prefix of the gold object.
training_hs = torch.stack([hs_and_zs.h_by_subj[s.subject] for s in train.samples])
training_zs = torch.stack([hs_and_zs.z_by_subj[s.subject] for s in train.samples])

biases = []
for sample in train.samples:
    subj, obj = sample.subject, sample.object
    h, z = hs_and_zs.h_by_subj[subj], hs_and_zs.z_by_subj[subj]
    b_sample = z - operator.weight @ h
    print(f"{subj=} | h_norm={h.norm()} | z_norm={z.norm()} || b_norm={b_sample.norm()}")
    for beta in np.linspace(0, 1, 10):
        z_est = operator.weight @ h + b_sample * beta
        pred, _ = lens.logit_lens(mt=mt, h=z_est, get_proba=True, k=3)
        print(f"{obj=} | {beta=} | z_est={z_est.norm()} | {pred=}")
        top_token = pred[0][0]
        if not functional.is_nontrivial_prefix(prediction=top_token, target=sample.object):
            continue
        biases.append(b_sample * beta)
        break
    print()

In [None]:
# Re-run the estimator; the next cells read its per-sample fits from
# operator.metadata["approxes"].
operator = estimator(train)

In [None]:
# For each training sample's own linear approximation (weight, bias, h, z),
# sweep beta over 10 evenly spaced values in [0, 1]. Keep the first
# beta * approx.bias for which the logit-lens top token is a non-trivial prefix
# of the gold object.
biases = []
for sample, approx in zip(train.samples, operator.metadata["approxes"]):
    subj = sample.subject
    obj = sample.object
    h = approx.h
    z = approx.z
    print(f"{subj=} | h_norm={h.norm()} | z_norm={z.norm()} || b_norm={approx.bias.norm()}")
    for beta in np.linspace(0, 1, 10):
        z_est = approx.weight @ h + approx.bias * beta
        pred, _ = lens.logit_lens(mt = mt, h = z_est, get_proba=True, k = 3)
        print(f"{obj=} | {beta=} | z_est={z_est.norm()} | {pred=}")
        top_token = pred[0][0]
        if functional.is_nontrivial_prefix(prediction=top_token, target=sample.object):
            # Store the bias that was actually tested. This previously appended
            # `b_sample * beta`, where `b_sample` was a stale global left over
            # from the last iteration of the previous cell.
            biases.append(approx.bias * beta)
            break
    else:
        # No beta in the sweep recovered the object; this sample contributes no bias.
        print(f"no beta in [0, 1] recovers {obj!r}")
    print()

In [None]:
# Average the per-sample tuned biases and evaluate with beta = 1.
b_mean = torch.stack(biases).mean(dim=0)

operator_dct = operator.__dict__.copy()
operator_dct.update(beta=1, bias=b_mean)
operator_auto_beta = operators.LinearRelationOperator(**operator_dct)

evaluate_operator(operator_auto_beta, test.samples)

### Drop $J\mathbf{s}$ entirely from bias estimation (essentially the corner method?)

In [None]:
# Take each sample's approx.bias directly, rescale all of them to the smallest
# norm, average, then compare with / decode against the estimator's bias.
biases = torch.stack([a.bias[0] for a in operator.metadata["approxes"]])
min_norm = biases.norm(dim=1).min()

biases = torch.stack([b * min_norm / b.norm() for b in biases])
b_mean = biases.mean(dim=0)

print(torch.dist(b_mean, operator.bias))
lens.logit_lens(mt=mt, h=b_mean, get_proba=True)

In [None]:
# Norm of the rescaled mean bias vs. norm of the estimator's bias.
b_mean.norm().item(), operator.bias.norm().item()

In [None]:
# Evaluate with the rescaled mean of per-sample biases, keeping beta = 0.2.
operator_dct = operator.__dict__.copy()
operator_dct.update(beta=0.2, bias=b_mean)
operator_dropped_js = operators.LinearRelationOperator(**operator_dct)

evaluate_operator(operator_dropped_js, test.samples)

In [None]:
# Offset-only ("corner") baseline at the same h / z layers as `operator`.
translation_estimator = operators.OffsetEstimatorBaseline(
    mt=mt,
    h_layer=operator.h_layer,
    z_layer=operator.z_layer,
    mode="icl",
)

# Fit on the training samples only. The offset is evaluated on `test.samples`
# below, so fitting on train + test (as before) leaked the evaluation set into
# the estimate.
translation = translation_estimator(
    relation.set(samples=train.samples)
)
corner = translation.bias

print(torch.dist(corner, operator.bias))
lens.logit_lens(mt=mt, h=corner, get_proba=True)

In [None]:
# Recall of the offset-only baseline on the filtered test set.
evaluate_operator(translation, test.samples)

In [None]:
# Pair the Jacobian with the offset baseline's bias (beta = 1) and evaluate.
operator_dct = operator.__dict__.copy()
operator_dct.update(beta=1, bias=corner)
operator_corner = operators.LinearRelationOperator(**operator_dct)

evaluate_operator(operator_corner, test.samples)

In [None]:
from scripts.explain_beta import TrialResult, AllTrialResults

In [None]:
# Load explain_beta results into results[relation_name][n_train] -> AllTrialResults.
# Trials from files that share a relation and training-set size are merged.
# Directory listings are sorted so the merge order (and thus relations[0]
# below) is deterministic across filesystems.
beta_path = "../results/explain_beta/gptj"

results = {}
for relation_folder in sorted(os.listdir(beta_path)):
    relation_path = os.path.join(beta_path, relation_folder)
    for n_train in sorted(os.listdir(relation_path)):
        n_train_path = os.path.join(relation_path, n_train)
        for file_name in sorted(os.listdir(n_train_path)):
            file_path = os.path.join(n_train_path, file_name)
            with open(file_path, encoding="utf-8") as f:
                # Deliberately not named `data`: that would shadow the `src.data`
                # module used earlier to load the dataset.
                trial_results = AllTrialResults.from_dict(json.load(f))
            if not trial_results.trials:
                logger.warning(f"skipping {file_path}: no trials")
                continue
            _n_train = len(trial_results.trials[0].train_samples)
            by_n_train = results.setdefault(trial_results.relation_name, {})
            if _n_train in by_n_train:
                by_n_train[_n_train].trials.extend(trial_results.trials)
            else:
                by_n_train[_n_train] = trial_results

In [None]:
# All relation names found on disk; analyze the first one.
relations = list(results)
relation = relations[0]

In [None]:
# Training-set sizes available for this relation, ascending.
train_options = sorted(results[relation])
train_options

In [None]:
# Per training-set size: array of per-trial bias norms, then their mean and std.
bias_norms = [
    np.array([t.bias_norm for t in results[relation][size].trials])
    for size in train_options
]

means = np.array([np.mean(norms) for norms in bias_norms])
stds = np.array([np.std(norms) for norms in bias_norms])

In [None]:
# Plot mean +/- one std of the bias norm across trials vs. training-set size.
plt.rcdefaults()
#####################################################################################
plt.rcParams["figure.dpi"] = 200
plt.rcParams["font.family"] = "Times New Roman"

SMALL_SIZE = 18
MEDIUM_SIZE = 20
BIGGER_SIZE = 24

plt.rc("font", size=SMALL_SIZE)  # default text size
plt.rc("axes", titlesize=BIGGER_SIZE)  # axes title
plt.rc("axes", labelsize=MEDIUM_SIZE + 5)  # x and y labels
plt.rc("xtick", labelsize=SMALL_SIZE)  # x tick labels
plt.rc("ytick", labelsize=MEDIUM_SIZE)  # y tick labels
plt.rc("legend", fontsize=SMALL_SIZE)  # legend
plt.rc("figure", titlesize=BIGGER_SIZE)  # figure suptitle
#####################################################################################

fig, ax = plt.subplots()
ax.plot(train_options, means)
# Shaded band: mean +/- one standard deviation.
ax.fill_between(train_options, means - stds, means + stds, alpha=0.2)
ax.set_xticks(train_options)
ax.set_xlabel("n_train")
ax.set_ylabel("bias norm")
ax.set_title(str(relation))
plt.show()