In [None]:
import sys
sys.path.append("../..")

In [None]:
from src import models

# Device string used both to load the model here and, later in the notebook,
# to move tokenizer outputs onto the same device.
# NOTE(review): the GPU index is hard-coded; adjust it for the machine at hand.
device = "cuda:5"
# `mt` is the handle passed as `mt=` to the estimators/helpers below; its
# `.tokenizer` attribute is used directly by later cells.
mt = models.load_model("gptj", device=device)

In [None]:
from src import data

# Global dataset handle; later cells select relations from it by name with
# `dataset.filter(relation_names=[...])`.
dataset = data.load_dataset()

# Specificity

In [None]:
from functools import cache

from src import editors, functional, hparams, operators
from src.utils import experiment_utils

import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import torch

# Global plotting theme for every figure in this notebook.
# NOTE(review): in recent seaborn releases `sns.set` is documented as an alias
# of `sns.set_theme`, so this first call looks redundant with the next line —
# confirm against the installed seaborn version before removing it.
sns.set(font="Serif")
sns.set_theme(style="white", palette="bright", font="Serif")


# Number of samples requested from `relation.split(...)` to form the training
# set the estimator is fitted on.
N_TRAIN = 5


def require_sample(subject, relation):
    """Return the first sample of `relation` whose subject equals `subject`.

    Args:
        subject: Subject value to look for (compared with `==`).
        relation: Object exposing an iterable `samples` attribute whose items
            have a `subject` attribute.

    Returns:
        The first matching sample, in `relation.samples` order.

    Raises:
        AssertionError: If no sample has the requested subject. The message
            names the subject and the relation so the failure is actionable
            (the previous message was always an empty list).
    """
    for sample in relation.samples:
        if sample.subject == subject:
            return sample
    # Raised explicitly (rather than via `assert`) so the check also holds
    # when Python runs with assertions disabled.
    raise AssertionError(
        f"no sample with subject {subject!r} found in relation "
        f"{getattr(relation, 'name', relation)!r}"
    )


@torch.inference_mode()
def compute_zs(prompt_template, subj, targ, layer=27):
    """Compute last-position hidden states for two fillings of one prompt.

    Both prompts are tokenized together as a left-padded batch, so index -1
    along the sequence axis is the final real token of each prompt.

    Args:
        prompt_template: Format string with one positional `{}` placeholder.
        subj: Value substituted to build the first prompt.
        targ: Value substituted to build the second prompt.
        layer: Layer whose hidden states are read. Defaults to 27, the value
            that was previously hard-coded.

    Returns:
        Tuple `(z_subj, z_targ)`: the hidden state at the last position for
        the first and the second prompt respectively.
    """
    prompts = [prompt_template.format(subj), prompt_template.format(targ)]
    with models.set_padding_side(mt, padding_side="left"):
        inputs = mt.tokenizer(
            prompts,
            return_tensors="pt",
            padding="longest",
        ).to(device)
    # One layer requested -> the first returned element is a one-item list.
    [[hs], _] = functional.compute_hidden_states(
        mt=mt,
        layers=[layer],
        inputs=inputs,
    )
    return hs[0, -1], hs[1, -1]


def _resolve_sample_pair(subject, relation_edit, relation_ref):
    """Return `(subject, sample_edit, sample_ref)` for one subject of a sweep.

    `subject` is either a subject value, which is looked up in both relations
    with `require_sample`, or an explicit `(sample_edit, sample_ref)` tuple
    whose two samples must share the same subject.
    """
    if isinstance(subject, tuple):
        sample_edit, sample_ref = subject
        assert sample_edit.subject == sample_ref.subject, (
            f"paired samples disagree on subject: "
            f"{sample_edit.subject!r} != {sample_ref.subject!r}"
        )
        return sample_edit.subject, sample_edit, sample_ref
    return (
        subject,
        require_sample(subject, relation_edit),
        require_sample(subject, relation_ref),
    )


def sweep_specificity(
    relation_edit,
    relation_ref,
    subject_orig,
    subject_targ,
    pt_zs_edit=None,
    pt_zs_ref=None,
    ranks=None,
):
    """Sweep the edit rank and record object probabilities for two relations.

    An operator is estimated on `N_TRAIN` samples of the edited relation (with
    the original and target samples held out). For every rank, a
    `LowRankPInvEditor` is applied once with the edited relation's prompt
    template and once with the reference relation's prompt template, and the
    probability of the first token of each sample's object is recorded.

    Args:
        relation_edit: Name of the relation being edited.
        relation_ref: Name of the reference relation.
        subject_orig: Original subject, or an `(edit_sample, ref_sample)`
            tuple sharing one subject.
        subject_targ: Target subject, in the same two forms as `subject_orig`.
        pt_zs_edit: Prompt template for the edited relation; defaults to
            `relation_edit.prompt_templates_zs[0]`.
        pt_zs_ref: Prompt template for the reference relation; defaults to
            `relation_ref.prompt_templates_zs[0]`.
        ranks: Iterable of ranks to sweep; defaults to `range(0, 100, 2)`.

    Returns:
        A 10-tuple `(ranks, ys_orig_edit, ys_orig_ref, ys_targ_edit,
        ys_targ_ref, relation_edit, sample_orig_edit, sample_orig_ref,
        sample_targ_edit, sample_targ_ref)`, where each `ys_*` list holds one
        probability per rank and `relation_edit` is the resolved relation
        object rather than its name.
    """
    experiment_utils.set_seed(12345)
    if ranks is None:
        ranks = range(0, 100, 2)

    # From here on the two names refer to relation objects, not name strings.
    relation_edit = dataset.filter(relation_names=[relation_edit])[0]
    relation_ref = dataset.filter(relation_names=[relation_ref])[0]

    subject_orig, sample_orig_edit, sample_orig_ref = _resolve_sample_pair(
        subject_orig, relation_edit, relation_ref
    )
    subject_targ, sample_targ_edit, sample_targ_ref = _resolve_sample_pair(
        subject_targ, relation_edit, relation_ref
    )

    print(f"{sample_orig_edit=}, {sample_orig_ref=}, {sample_targ_edit=}, {sample_targ_ref=}")

    # Hold out both evaluated samples before drawing the training split.
    train, _ = relation_edit.without(sample_orig_edit).without(sample_targ_edit).split(N_TRAIN)

    relation_hparams = hparams.get(mt, relation_edit)
    estimator = operators.JacobianIclMeanEstimator(
        mt=mt,
        h_layer=relation_hparams.h_layer,
        z_layer=relation_hparams.z_layer,
        beta=relation_hparams.beta,
    )

    print("estimating LRE...")
    operator = estimator(train)

    if pt_zs_edit is None:
        pt_zs_edit = relation_edit.prompt_templates_zs[0]
    print("prompt template being edited: " + pt_zs_edit)

    if pt_zs_ref is None:
        pt_zs_ref = relation_ref.prompt_templates_zs[0]
    print("prompt template being referenced: " + pt_zs_ref)

    # The z vectors do not depend on the rank, so compute them once up front.
    print("precomputing zs...")
    z_subj, z_targ = compute_zs(
        pt_zs_edit,
        subject_orig,
        subject_targ,
    )

    print("begin sweep over rank...")
    ys_orig_edit = []
    ys_orig_ref = []
    ys_targ_edit = []
    ys_targ_ref = []
    for rank in ranks:
        editor = editors.LowRankPInvEditor(lre=operator, rank=rank, n_new_tokens=3)

        # Hack to overwrite the prompt template: `object.__setattr__` bypasses
        # the operator's own attribute assignment.
        # NOTE(review): presumably needed because the operator is immutable
        # (e.g. a frozen dataclass) — confirm.
        object.__setattr__(editor.lre, "prompt_template", pt_zs_edit)
        results_edit = editor(subject_orig, subject_targ, z_original=z_subj, z_target=z_targ)
        logits_edit = results_edit.model_logits
        best_edit = [str(o) for o in results_edit.predicted_tokens][0]

        # Same edit, evaluated under the reference relation's prompt.
        object.__setattr__(editor.lre, "prompt_template", pt_zs_ref)
        results_ref = editor(subject_orig, subject_targ, z_original=z_subj, z_target=z_targ)
        logits_ref = results_ref.model_logits
        best_ref = [str(p) for p in results_ref.predicted_tokens][0]

        print(
            rank,
            best_edit,
            best_ref,
            sep="\t",
        )

        for tok, logits, ys in (
            (sample_orig_edit.object, logits_edit, ys_orig_edit),
            (sample_targ_edit.object, logits_edit, ys_targ_edit),
            (sample_orig_ref.object, logits_ref, ys_orig_ref),
            (sample_targ_ref.object, logits_ref, ys_targ_ref),
        ):
            probs = logits.float().softmax(dim=0)
            # Only the first token of " <object>" is scored.
            tok_id = mt.tokenizer.encode(" " + tok)[0]
            prob = probs[tok_id].item()
            ys.append(prob)

    return (
        ranks,
        ys_orig_edit,
        ys_orig_ref,
        ys_targ_edit,
        ys_targ_ref,
        relation_edit,
        sample_orig_edit,
        sample_orig_ref,
        sample_targ_edit,
        sample_targ_ref,
    )


def plot(results):
    """Plot the four probability-vs-rank curves returned by `sweep_specificity`.

    Curves for the edited relation are solid and curves for the reference
    relation are dashed; light blue marks the original subject's object and
    dark blue the target subject's object.
    """
    (
        xs,
        p_orig_edit,
        p_orig_ref,
        p_targ_edit,
        p_targ_ref,
        edited_relation,
        orig_edit,
        orig_ref,
        targ_edit,
        targ_ref,
    ) = results

    plt.title(f'Change "{edited_relation.name}" of "{orig_edit.subject}" -> "{targ_edit.subject}"')

    # (y values, sample providing the legend label, colour, extra line kwargs)
    curves = (
        (p_orig_edit, orig_edit, "deepskyblue", {}),
        (p_targ_edit, targ_edit, "darkblue", {}),
        (p_orig_ref, orig_ref, "deepskyblue", {"linestyle": "dashed"}),
        (p_targ_ref, targ_ref, "darkblue", {"linestyle": "dashed"}),
    )
    for ys, sample, color, extra in curves:
        plt.plot(xs, ys, label=f"p({sample.object})", color=color, linewidth=2, **extra)

    plt.xlabel("Rank")
    plt.ylabel("LM Probability")
    plt.yticks(np.linspace(0, 1, 11))
    plt.legend()

In [None]:
# Edit the capital-city relation (United States -> China) and use the
# largest-city relation as the reference; sweep ranks 100..195 in steps of 5.
relation_edit, relation_ref = "country capital city", "country largest city"
subj_orig, subj_targ = "United States", "China"
results = sweep_specificity(
    relation_edit=relation_edit,
    relation_ref=relation_ref,
    subject_orig=subj_orig,
    subject_targ=subj_targ,
    pt_zs_edit="{}'s capital city,",
    pt_zs_ref="{}'s largest city,",
    ranks=range(100, 200, 5),
)
plot(results)

In [None]:
# Edit the first-letter relation (Horror -> Joy) and use word sentiment as the
# reference. Samples are given explicitly as (edit sample, reference sample)
# pairs instead of being looked up in the dataset.
relation_edit, relation_ref = "word first letter", "word sentiment"
subj_orig = (
    data.RelationSample("Horror", "H"),
    data.RelationSample("Horror", "negative"),
)
subj_targ = (
    data.RelationSample("Joy", "J"),
    data.RelationSample("Joy", "positive"),
)
results = sweep_specificity(
    relation_edit=relation_edit,
    relation_ref=relation_ref,
    subject_orig=subj_orig,
    subject_targ=subj_targ,
    ranks=range(0, 100, 5),
)
plot(results)

In [None]:
# Edit the plays-instrument relation (Eric Clapton -> Lionel Messi) and use
# native language as the reference. `sweep_specificity` asserts that the two
# samples of each pair share the same subject, so both samples of a pair must
# spell the subject identically (previously "Eric Claptop" and "" broke this).
relation_edit = "plays instrument"
relation_ref = "person native language"
subj_orig = (
    data.RelationSample("Eric Clapton", "guitar"),
    data.RelationSample("Eric Clapton", "English"),
)
subj_targ = (
    # NOTE(review): "soccer" is not an instrument — presumably a deliberate
    # probe of what the edit produces for a non-musician; confirm.
    data.RelationSample("Lionel Messi", "soccer"),
    data.RelationSample("Lionel Messi", "Spanish"),
)
results = sweep_specificity(
    relation_edit,
    relation_ref,
    subj_orig,
    subj_targ,
    pt_zs_ref="{}, whose first language was",
    ranks=range(0, 150, 10),
)
plot(results)

In [None]:
def determine_subject_overlap(r1_name, r2_name):
    """Print the set of subjects that appear in both named relations."""
    subject_sets = []
    for relation_name in (r1_name, r2_name):
        rel = dataset.filter(relation_names=[relation_name])[0]
        subject_sets.append({sample.subject for sample in rel.samples})

    first, second = subject_sets
    print(first & second)

determine_subject_overlap("word first letter", "word sentiment")

In [None]:
# Re-draw the figure for whichever sweep most recently assigned `results`
# (this cell depends on state left behind by an earlier cell).
plot(results)

In [None]:
# Estimate an operator for the "company CEO" relation on N_TRAIN samples and
# print its top two predictions for every held-out sample.
relation = dataset.filter(relation_names=["company CEO"])[0]

train, test = relation.split(N_TRAIN)

# Look up hyperparameters for the relation loaded in this cell. The previous
# code passed `relation_edit`, a name string left over from an earlier cell.
relation_hparams = hparams.get(mt, relation)
estimator = operators.JacobianIclMeanEstimator(
    mt=mt,
    h_layer=relation_hparams.h_layer,
    z_layer=relation_hparams.z_layer,
    beta=relation_hparams.beta,
)

print("estimating LRE...")
operator = estimator(train)

for sample in test.samples:
    predictions = operator(sample.subject).predictions
    # Columns: subject, expected object, top prediction, runner-up prediction.
    print(sample.subject, sample.object, predictions[0], predictions[1])