In [None]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import matplotlib.pyplot as plt
import os
import json
import sys
import numpy as np
sys.path.append("..")
import copy
from tqdm.auto import tqdm

In [None]:
from src import models, data, operators, utils, functional, metrics, lens
from src.utils import logging_utils, experiment_utils
import logging
import torch
import baukit
import random
import numpy as np
import torch
# Seed through the project helper before any stochastic step runs.
experiment_utils.set_seed(123456)

# Send INFO-and-above log records to stdout, using the format and date-format
# constants defined in the project's logging utilities.
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    format=logging_utils.DEFAULT_FORMAT,
)

# Logger for messages emitted by this notebook itself.
logger = logging.getLogger(__name__)

In [None]:
# ------------------------------------------------------------------
# Notebook-wide settings; the cells below read these names.
# ------------------------------------------------------------------
h_layer = 7       # passed as `h_layer` to the estimators
beta = 2.5        # passed as `beta` to the main estimator
n_training = 10   # argument to `relation.split(...)`
# ------------------------------------------------------------------

In [None]:
# Load the "gptj" model with fp16 enabled on the "cuda" device; `mt` is the
# handle passed to every estimator and helper below.
mt = models.load_model(
    name="gptj",
    fp16=True,
    device="cuda",
)

In [None]:
# Take the first relation matching "country capital city" and set its prompt
# templates to the single template " {}:".
capital_relations = data.load_dataset().filter(relation_names=["country capital city"])
relation = capital_relations[0].set(prompt_templates=[" {}:"])

# Split using `n_training` from the config cell; `train` feeds the estimators
# and the few-shot prompt, `test` is evaluated further down.
train, test = relation.split(n_training)

In [None]:
# Build the few-shot prompt from the training samples. The subject is passed
# as the literal "{}" so the resulting string can itself serve as a template.
icl_prompt = functional.make_prompt(
    mt=mt,
    prompt_template=train.prompt_templates[0],
    subject="{}",
    examples=train.samples,
)
print(icl_prompt)

In [None]:
# Re-bind `test` to the output of the project's few-shot filter, run with the
# ICL prompt as the template.
# NOTE(review): presumably this keeps only the samples the model answers
# correctly given the few-shot examples — confirm against the helper.
test = functional.filter_relation_samples_based_on_provided_fewshots(
    mt=mt,
    test_relation=test,
    prompt_template=icl_prompt,
    batch_size=4,
)

# Number of test samples left after filtering.
len(test.samples)

In [None]:
# Configure the estimator from the notebook-level settings ...
estimator = operators.JacobianIclMeanEstimator(
    mt=mt,
    h_layer=h_layer,
    beta=beta,
)

# ... and apply it to the training split to obtain the operator that the
# rest of the notebook inspects.
operator = estimator(train)

In [None]:
# Rebuild the operator from a copy of its attributes with `beta` replaced by
# 1.0. The attribute values themselves are shared with `operator`, not copied.
operator_dict = {**vars(operator), "beta": 1.0}
no_beta = operators.LinearRelationOperator(**operator_dict)

In [None]:
# For each test sample, print the top-3 predictions of the estimated operator
# next to those of the beta=1.0 variant.
for sample in test.samples:
    pred, no_beta_pred = (
        op(sample.subject).predictions[:3] for op in (operator, no_beta)
    )
    print(f"{sample} | pred: {[f'{p.token} ({p.prob:.2f})' for p in pred]} | no_beta: {[f'{p.token} ({p.prob:.2f})' for p in no_beta_pred]}")

In [None]:
# Print the norms of h, weight and bias for every approximation stored in the
# operator's metadata. The local names are kept as-is because the `=` f-string
# specifier echoes the expression text into the output.
for approx in operator.metadata["approxes"]:
    h, weight, bias = approx.h, approx.weight, approx.bias
    print(f"{h.norm()=:.3f} | {weight.norm()=:.3f} | {bias.norm()=:.3f}")

In [None]:
# Same layer as the main estimator, but using the "_Imaginary" estimator with
# beta fixed at 1.0 and magnitude_h at 65.0.
# NOTE(review): 65.0 is presumably chosen to match the h norms printed for the
# real operator — confirm.
mythical_estimator = operators.JacobianIclMeanEstimator_Imaginary(
    mt=mt,
    h_layer=h_layer,
    beta=1.0,
    magnitude_h=65.0,
)
mythical_operator = mythical_estimator(train)

In [None]:
# Norms of h, weight and bias for each approximation behind the imaginary
# operator. The names `h`, `weight` and `bias` appear verbatim in the printed
# text via the `=` specifier, so they are kept unchanged.
for approx in mythical_operator.metadata["approxes"]:
    h, weight, bias = (getattr(approx, field) for field in ("h", "weight", "bias"))
    print(f"{h.norm()=:.3f} | {weight.norm()=:.3f} | {bias.norm()=:.3f}")

In [None]:
# Weight norms on the first line and bias norms on the second, real operator
# on the left and imaginary operator on the right.
print(
    f"{operator.weight.norm()=:.3f} | {mythical_operator.weight.norm()=:.3f}",
    f"{operator.bias.norm()=:.3f} | {mythical_operator.bias.norm()=:.3f}",
    sep="\n",
)


In [None]:
# Take the h vectors of the first two approximations of the real operator.
approxes = operator.metadata["approxes"]
s1, s2 = approxes[0].h, approxes[1].h

# Map the same difference vector through both weight matrices so the two
# results can be compared directly.
delta_h = s1 - s2
j_delta_h = operator.weight @ delta_h
myth_j_delta_h = mythical_operator.weight @ delta_h

In [None]:
# Cosine similarity between the two operators' responses to the same delta.
torch.cosine_similarity(
    j_delta_h,
    myth_j_delta_h,
    dim=-1,
)

In [None]:
# Cosine similarity between the bias of the real and the imaginary operator.
torch.cosine_similarity(
    operator.bias,
    mythical_operator.bias,
    dim=-1,
)

### Fixing the hyperparameters

In [None]:
# Sweep `interpolate_on` over 2..7 and keep one imaginary operator per value,
# in sweep order (index 0 <-> interpolate_on=2).
#
# The layer comes from the notebook-level `h_layer` setting rather than a
# literal 7, so these operators are always estimated at the same layer as the
# real `operator` they are compared against below.
imaginary_operators = []
for interpolate_on in tqdm(range(2, 8)):
    estimator_i = operators.JacobianIclMeanEstimator_Imaginary(
        mt=mt,
        h_layer=h_layer,
        beta=1,
        interpolate_on=interpolate_on,
        n_trials=8,
        magnitude_h=65.0,
    )
    operator_i = estimator_i(train)
    imaginary_operators.append(operator_i)
    # Visual separator between the output of consecutive sweep steps.
    print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")

In [None]:
# (weight norm, bias norm) of the real operator, as plain Python floats.
(
    operator.weight.norm().item(),
    operator.bias.norm().item(),
)

In [None]:
# x positions mirror the `interpolate_on` values used in the sweep cell.
n_points = range(2, 8)
real_w_norm = operator.weight.norm().item()
w_norms = [imag_op.weight.norm().item() for imag_op in imaginary_operators]

# Imaginary-operator weight norms as a curve, with the real operator's weight
# norm drawn as a red horizontal reference line across the same x range.
plt.plot(n_points, w_norms, label="|| J_imaginary ||")
plt.hlines(real_w_norm, n_points[0], n_points[-1], color="red", label="|| J_real ||")
plt.ylim(bottom=0)
plt.legend()
plt.ylabel("|| J ||")
plt.xlabel("n_points")

In [None]:
# Index 2 is the third operator of the sweep (interpolate_on=4 for a sweep
# over range(2, 8)). Rebuild it from a copy of its attributes with beta = 1.
operator_dict = {**vars(imaginary_operators[2]), "beta": 1}
img_operator = operators.LinearRelationOperator(**operator_dict)

# Weight norm and bias norm of the rebuilt operator.
print(img_operator.weight.norm().item(), img_operator.bias.norm().item())

# Top-3 predictions of the rebuilt operator for every test sample.
for sample in test.samples:
    pred = img_operator(sample.subject).predictions[:3]
    print(f"{sample} | pred: {[f'{p.token} ({p.prob:.2f})' for p in pred]}")

In [None]:
# x positions mirror the `interpolate_on` values used in the sweep cell.
n_points = range(2, 8)
real_b_norm = operator.bias.norm().item()
b_norms = [imag_op.bias.norm().item() for imag_op in imaginary_operators]

# Imaginary-operator bias norms as a curve, with the real operator's bias norm
# drawn as a red horizontal reference line across the same x range.
plt.plot(n_points, b_norms, label="|| bias_imaginary ||")
plt.hlines(real_b_norm, n_points[0], n_points[-1], color="red", label="|| bias_real ||")
# NOTE(review): the y-axis starts at 200, so any norm below that is cut off —
# confirm this is intended.
plt.ylim(bottom=200)
plt.legend()
plt.ylabel("|| bias ||")
plt.xlabel("n_points")

In [None]:
# Evaluate the final operator of the sweep. It is selected explicitly from
# `imaginary_operators` instead of through `operator_i`, the loop variable
# left over from the sweep cell, so this cell no longer depends on leaked
# loop state.
last_imaginary_operator = imaginary_operators[-1]

# Top-3 predictions of that operator for every test sample.
for sample in test.samples:
    pred = last_imaginary_operator(sample.subject).predictions[:3]
    print(f"{sample} | pred: {[f'{p.token} ({p.prob:.2f})' for p in pred]}")