In [None]:
# Enable IPython's autoreload extension in mode 2 so edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append('..')

import torch
import matplotlib.pyplot as plt
from src import models, data
from tqdm.auto import tqdm
import json
import os
import numpy as np
import copy


In [None]:
# Load the "gptj" model with fp16 enabled on the CUDA device; `mt` is passed
# to the estimator and to predict_next_token in later cells.
mt = models.load_model("gptj", fp16=True, device="cuda")

In [None]:
# Relation to inspect. The same name is used to filter the dataset and to
# select the entry from the sweep results.
#####################################
relation_name = "plays pro sport"
#####################################

In [None]:
# Load the dataset and keep the first relation whose name matches
# `relation_name`.
dataset = data.load_dataset()
relation = dataset.filter(relation_names=[relation_name])[0]

In [None]:
from src.utils.sweep_utils import read_sweep_results, relation_from_dict

In [None]:
# Read the stored sweep results, limited to the relation under inspection.
sweep_dict = read_sweep_results(
    "../results/sweep-24-trials/gptj",
    relation_names=[relation_name],
)

In [None]:
# Build the relation-level result object from this relation's sweep entry.
relation_result = relation_from_dict(sweep_dict[relation_name])

In [None]:
# Print the available choices: trial indices, then the layers of the first
# trial and the ranks of that trial's first layer.
trial_options = [*range(len(relation_result.trials))]
print(f"{trial_options=}")

layer_options = [entry.layer for entry in relation_result.trials[0].layers]
print(f"{layer_options=}")

rank_options = [
    entry.rank for entry in relation_result.trials[0].layers[0].result.ranks
]
print(f"{rank_options=}")

In [None]:
# Display the attributes of the result returned by best_by_efficacy().
vars(relation_result.best_by_efficacy())

In [None]:
# Sweep configuration to inspect: TRIAL_NO indexes relation_result.trials,
# LAYER and RANK are matched against the `.layer` / `.rank` fields of that
# trial's results in the next cell.
#########################################################################################################
TRIAL_NO = 10
RANK = 192
LAYER = 27
#########################################################################################################

In [None]:
# Look up the sweep results for the configured trial, layer and rank.
# A missing layer or rank raises a descriptive ValueError rather than the
# opaque IndexError produced by indexing an empty list.
layer_result = next(
    (
        layer
        for layer in relation_result.trials[TRIAL_NO].layers
        if layer.layer == LAYER
    ),
    None,
)
if layer_result is None:
    raise ValueError(
        f"LAYER={LAYER} not found in trial {TRIAL_NO} of the sweep results"
    )

rank_result = next(
    (rank for rank in layer_result.result.ranks if rank.rank == RANK),
    None,
)
if rank_result is None:
    raise ValueError(
        f"RANK={RANK} not found for LAYER={LAYER} in trial {TRIAL_NO} of the sweep results"
    )

rank_result

In [None]:
# Index the efficacy successes of the selected rank by their target subject.
efficacy_successes = {
    success.target.subject: success for success in rank_result.efficacy_successes
}

# For every beta, report each faithfulness success whose subject is also an
# efficacy target, and remove it from the index so that only unmatched edits
# remain afterwards.
for beta_result in layer_result.result.betas:
    faithfulness_successes = beta_result.faithfulness_successes
    for sample in faithfulness_successes:
        matched = efficacy_successes.pop(sample.subject, None)
        if matched is None:
            continue
        print(f"Edit: {matched.source} <to> {matched.target} -- found in beta: {beta_result.beta}")
        

In [None]:
# List the efficacy successes still left in `efficacy_successes`.
print("No target match found in faithfulness successes for the following:")
for sample in efficacy_successes.values():
    edit_source, edit_target = sample.source, sample.target
    print(f"Edit: {edit_source} <to> {edit_target}")

In [None]:
# Samples stored with the selected layer's sweep result; they are handed to
# the estimator below as the relation's samples.
train_samples = layer_result.result.samples
train_samples

In [None]:
# Prompt template recorded for the selected trial.
prompt_template = relation_result.trials[TRIAL_NO].prompt_template
prompt_template

In [None]:
from src import functional, operators, editors

In [None]:
# Estimate the operator at LAYER from the relation configured with this
# trial's samples and prompt template.
estimator = operators.JacobianIclMeanEstimator(mt=mt, h_layer=LAYER)
operator = estimator(
    relation.set(samples=train_samples, prompt_templates=[prompt_template])
)

In [None]:
# Subject to probe. It is formatted into the operator's prompt template and
# later used as a key into `efficacy_successes` (keyed by target subject).
# NOTE(review): "Hungary" looks out of place for the relation
# "plays pro sport" — confirm it is still a key of `efficacy_successes`,
# otherwise the later `efficacy_successes[subject]` lookup raises KeyError.
############################################
subject = "Hungary"
############################################

In [None]:
# model predicts correctly
# (next-token prediction on the operator's prompt template filled with `subject`)
functional.predict_next_token(mt=mt, prompt=operator.prompt_template.format(subject))

In [None]:
# LRE fails (low faithfulness)
# Apply the estimated operator to the same subject for comparison with the
# model's own prediction.
operator(subject=subject)

In [None]:
# Decompose the operator weight (cast to float) once and pass the result to
# the low-rank editor together with the rank selected from the sweep.
# NOTE(review): torch.svd is deprecated in favour of torch.linalg.svd, which
# returns Vh instead of V — confirm what LowRankPInvEditor expects before
# migrating.
svd = torch.svd(operator.weight.float())
editor = editors.LowRankPInvEditor(
    lre=operator,
    rank=rank_result.rank,
    n_samples=1,
    n_new_tokens=1,
    svd=svd,
)

In [None]:
# Fetch the efficacy success whose target is `subject`.
# `efficacy_successes` only holds entries that were NOT removed by the
# faithfulness-matching cell, so a missing key means the subject either was
# matched there or never was an efficacy success for the selected rank.
# Raise a KeyError that says so rather than a bare one.
if subject not in efficacy_successes:
    raise KeyError(
        f"{subject!r} is not among the unmatched efficacy successes; "
        f"available subjects: {sorted(map(str, efficacy_successes))}"
    )
efficacy_test_pair = efficacy_successes[subject]
f"Editing: {efficacy_test_pair.source} <to> {efficacy_test_pair.target}"

In [None]:
# editing succeeds (high efficacy)
# (edit the pair's source subject towards its target subject)
editor(subject=efficacy_test_pair.source.subject, target=efficacy_test_pair.target.subject)