In [None]:
# Add the directory two levels up to the module search path so the `src`
# package imported in the following cells can be found (presumably the
# repository root -- confirm against the repo layout). The path is relative,
# so it resolves against the kernel's current working directory.
import sys
sys.path.append("../..")

In [None]:
from src import models

# NOTE(review): the GPU index is hard-coded; on a machine without a sixth
# visible CUDA device this presumably fails -- adjust before running.
device = "cuda:5"
# `mt` and `device` are read as module-level globals by compute_zs and
# sweep_specificity further down.
mt = models.load_model("gptj", device=device)

In [None]:
from src import data

# Loaded once here; the final cell filters it by relation name to pick the
# relation to edit and the reference relation.
dataset = data.load_dataset()

# Specificity

In [None]:
from functools import cache

from src import editors, functional, hparams, operators
from src.utils import experiment_utils

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

# Apply seaborn's default theme to the matplotlib figures drawn below.
sns.set()


# Size passed to `.split(...)` in sweep_specificity; the first part of that
# split is the training set handed to the estimator.
N_TRAIN = 5


def require_sample(subject, relation):
    """Look up the one sample of `relation` whose subject equals `subject`.

    Args:
        subject: Subject value to search for (compared with `==`).
        relation: Object exposing an iterable `samples`, each with a
            `.subject` attribute.

    Returns:
        The matching sample.

    Raises:
        AssertionError: If zero or more than one sample matches; the list of
            matches is attached as the assertion message.
    """
    found = list(filter(lambda sample: sample.subject == subject, relation.samples))
    assert len(found) == 1, found
    return found[0]


@torch.inference_mode()
def compute_zs(prompt_template, subj, targ, z_layer=27):
    """Compute last-position hidden states for two subjects under one prompt.

    `prompt_template` is formatted once with `subj` and once with `targ`; both
    prompts are tokenized as a single padded batch and run through the
    module-level model `mt`.

    Args:
        prompt_template: Format string with one positional `{}` slot.
        subj: Value substituted to build the first prompt.
        targ: Value substituted to build the second prompt.
        z_layer: Layer whose hidden states are read. Defaults to 27, the value
            that was previously hard-coded, so existing callers are unaffected.

    Returns:
        Tuple `(z_subj, z_targ)`: the hidden state at the last sequence
        position for the first and second prompt respectively.
    """
    prompts = [prompt_template.format(subj), prompt_template.format(targ)]

    # Tokenize with left padding: the two prompts can differ in length, and
    # padding on the left keeps index -1 on each prompt's final real token
    # (assuming set_padding_side configures the tokenizer as its name says).
    with models.set_padding_side(mt, padding_side="left"):
        inputs = mt.tokenizer(
            prompts,
            return_tensors="pt",
            padding="longest",
        ).to(device)

    # One layer is requested, so the first returned element unpacks to a
    # single hidden-state tensor; the second returned element is unused here.
    [[hs], _] = functional.compute_hidden_states(
        mt=mt,
        layers=[z_layer],
        inputs=inputs,
    )

    z_subj = hs[0, -1]
    z_targ = hs[1, -1]
    return z_subj, z_targ


def sweep_specificity(
    relation_edit,
    relation_ref,
    subject_orig,
    subject_targ,
    ranks=None,
):
    """Sweep the editor rank and record target-token probabilities.

    An operator is estimated on `relation_edit` (with the two subjects of
    interest held out of the training split). For every rank, a
    `LowRankPInvEditor` is built on that operator and applied twice with the
    same precomputed `z` vectors: once with the edited relation's prompt
    template and once with the reference relation's. The probability of the
    original/target subject's object is recorded for each relation.

    Args:
        relation_edit: Relation used to estimate the operator and whose
            prompt is used to compute the `z` vectors.
        relation_ref: Reference relation that is only queried.
        subject_orig: Subject being edited; must appear exactly once in both
            relations.
        subject_targ: Subject to edit towards; must appear exactly once in
            both relations.
        ranks: Iterable of ranks to sweep. Defaults to `range(0, 100, 2)`.

    Returns:
        A 10-tuple `(ranks, ys_orig_edit, ys_orig_ref, ys_targ_edit,
        ys_targ_ref, relation_edit, sample_orig_edit, sample_orig_ref,
        sample_targ_edit, sample_targ_ref)`; each `ys_*` list holds one
        probability per rank.

    Raises:
        AssertionError: From `require_sample` when a subject is missing from
            or duplicated in one of the relations.
    """
    experiment_utils.set_seed(12345)
    if ranks is None:
        ranks = range(0, 100, 2)

    sample_orig_edit = require_sample(subject_orig, relation_edit)
    sample_orig_ref = require_sample(subject_orig, relation_ref)
    sample_targ_edit = require_sample(subject_targ, relation_edit)
    sample_targ_ref = require_sample(subject_targ, relation_ref)

    # Hold both subjects of interest out of the data the operator is fit on.
    train, _ = relation_edit.without(sample_orig_edit).without(sample_targ_edit).split(N_TRAIN)

    relation_hparams = hparams.get(mt, relation_edit)
    estimator = operators.JacobianIclMeanEstimator(
        mt=mt,
        h_layer=relation_hparams.h_layer,
        z_layer=relation_hparams.z_layer,
        beta=relation_hparams.beta,
    )

    print("estimating LRE...")
    operator = estimator(train)

    pt_zs_edit = relation_edit.prompt_templates_zs[0]
    pt_zs_ref = relation_ref.prompt_templates_zs[0]

    print("precomputing zs...")
    z_subj, z_targ = compute_zs(
        pt_zs_edit,
        subject_orig,
        subject_targ,
    )

    # The object strings do not depend on the rank, so look up their token ids
    # once instead of re-tokenizing on every iteration. Only the first token of
    # the space-prefixed object is scored.
    def first_token_id(text):
        return mt.tokenizer.encode(" " + text)[0]

    tok_orig_edit = first_token_id(sample_orig_edit.object)
    tok_targ_edit = first_token_id(sample_targ_edit.object)
    tok_orig_ref = first_token_id(sample_orig_ref.object)
    tok_targ_ref = first_token_id(sample_targ_ref.object)

    print("begin sweep over rank...")
    ys_orig_edit = []
    ys_orig_ref = []
    ys_targ_edit = []
    ys_targ_ref = []
    for rank in ranks:
        editor = editors.LowRankPInvEditor(lre=operator, rank=rank)

        # Hack to overwrite the prompt template on the operator. Going through
        # object.__setattr__ bypasses the instance's own attribute setting
        # (presumably the operator is a frozen dataclass -- confirm).
        object.__setattr__(editor.lre, "prompt_template", pt_zs_edit)
        results_edit = editor(subject_orig, subject_targ, z_original=z_subj, z_target=z_targ)
        logits_edit = results_edit.model_logits
        best_edit = [str(o) for o in results_edit.predicted_tokens][0]

        # Same edit (same z vectors), but queried through the reference prompt.
        object.__setattr__(editor.lre, "prompt_template", pt_zs_ref)
        results_ref = editor(subject_orig, subject_targ, z_original=z_subj, z_target=z_targ)
        logits_ref = results_ref.model_logits
        best_ref = [str(p) for p in results_ref.predicted_tokens][0]

        print(
            rank,
            best_edit,
            best_ref,
            sep="\t",
        )

        # One softmax per logits tensor; each distribution is read twice.
        probs_edit = logits_edit.float().softmax(dim=0)
        probs_ref = logits_ref.float().softmax(dim=0)

        ys_orig_edit.append(probs_edit[tok_orig_edit].item())
        ys_targ_edit.append(probs_edit[tok_targ_edit].item())
        ys_orig_ref.append(probs_ref[tok_orig_ref].item())
        ys_targ_ref.append(probs_ref[tok_targ_ref].item())

    return (
        ranks,
        ys_orig_edit,
        ys_orig_ref,
        ys_targ_edit,
        ys_targ_ref,
        relation_edit,
        sample_orig_edit,
        sample_orig_ref,
        sample_targ_edit,
        sample_targ_ref,
    )


def plot(results):
    """Draw the probability-vs-rank curves from a `sweep_specificity` result.

    Args:
        results: The 10-tuple returned by `sweep_specificity`.

    Four lines are drawn on the current matplotlib axes: solid lines for the
    edited relation, dash-dot lines for the reference relation; blue for the
    original subject's object, red for the target subject's object.
    """
    (
        ranks,
        ys_orig_edit,
        ys_orig_ref,
        ys_targ_edit,
        ys_targ_ref,
        relation_edit,
        sample_orig_edit,
        sample_orig_ref,
        sample_targ_edit,
        sample_targ_ref,
    ) = results

    heading = f'Change "{relation_edit.name}" of "{sample_orig_edit.subject}" -> "{sample_targ_edit.subject}"'
    plt.title(heading)

    # (y-values, sample that names the curve, line styling), in legend order.
    curves = (
        (ys_orig_edit, sample_orig_edit, {"color": "b"}),
        (ys_targ_edit, sample_targ_edit, {"color": "r"}),
        (ys_orig_ref, sample_orig_ref, {"color": "b", "linestyle": "-."}),
        (ys_targ_ref, sample_targ_ref, {"color": "r", "linestyle": "-."}),
    )
    for ys, sample, line_style in curves:
        plt.plot(ranks, ys, label=f"p({sample.object})", **line_style)

    plt.xlabel("Rank")
    plt.ylabel("LM Probability")
    # Fixed 0.0-1.0 ticks in steps of 0.1.
    plt.yticks(np.linspace(0, 1, 11))
    plt.legend()

In [None]:
# Edit the "country capital city" relation for one subject towards another,
# and track the "country language" relation as the reference.
relation_ref = dataset.filter(relation_names=["country language"]).relations[0]
relation_edit = dataset.filter(relation_names=["country capital city"]).relations[0]
subj_orig = "United States"
subj_targ = "Germany"
# The sweep must run in this cell: it is the only place `results` is defined,
# so leaving it commented out makes plot(results) fail with NameError on a
# fresh kernel.
results = sweep_specificity(
    relation_edit,
    relation_ref,
    subj_orig,
    subj_targ,
)
plot(results)