In [1]:
# Reload imported modules (e.g. the project's `src` package) automatically before
# each cell executes, so edits to library code are picked up without a restart.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')

import torch
import matplotlib.pyplot as plt
from src import models, data
from tqdm.auto import tqdm
import json
import os
import numpy as np
import copy


In [3]:
# Model handle used by every later cell: GPT-J, loaded in half precision on the GPU.
mt = models.load_model("gptj", device="cuda", fp16=True)

In [4]:
#####################################
# Relation under study. Every saved output below (train samples Brazil -> Real,
# ..., Thailand -> Baht; the Japan -> Yen => Hungary -> Forint edit; the
# hard-coded subject "Hungary") belongs to the country -> currency relation.
# The previous value "plays pro sport" did not match them, so Restart & Run All
# would load the wrong sweep results and break the Hungary walkthrough.
relation_name = "country currency"
#####################################

In [5]:
# Load the full dataset, then narrow it to the single relation under study.
dataset = data.load_dataset()
relation = dataset.filter(relation_names=[relation_name])[0]

In [6]:
from src.utils.sweep_utils import read_sweep_results, relation_from_dict

In [7]:
# Results of the 24-trial GPT-J sweep, restricted to the relation under study.
SWEEP_RESULTS_DIR = "../results/sweep-24-trials/gptj"
sweep_dict = read_sweep_results(SWEEP_RESULTS_DIR, relation_names=[relation_name])

In [8]:
# Convert this relation's raw sweep entry into structured result objects.
relation_sweep_raw = sweep_dict[relation_name]
relation_result = relation_from_dict(relation_sweep_raw)

In [9]:
# Enumerate the choices available in the sweep. Layers and ranks are read off
# the first trial (and its first layer); presumably they are identical across
# trials and layers -- TODO confirm.
first_trial = relation_result.trials[0]

trial_options = [*range(len(relation_result.trials))]
print(f"{trial_options=}")

layer_options = [entry.layer for entry in first_trial.layers]
print(f"{layer_options=}")

rank_options = [entry.rank for entry in first_trial.layers[0].result.ranks]
print(f"{rank_options=}")

trial_options=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23]
layer_options=['emb', 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27]
rank_options=[0, 8, 16, 24, 32, 40, 48, 56, 64, 72, 80, 88, 96, 104, 112, 120, 128, 136, 144, 152, 160, 168, 176, 184, 192, 200, 208, 216, 224, 232, 240, 248, 256, 264, 272, 280, 288, 296, 304, 312]


In [10]:
# Attribute dict of the best-by-efficacy summary (layer, beta, recall, rank, ...).
vars(relation_result.best_by_efficacy())

{'layer': 6,
 'beta': AggregateMetric(mean=2.46875, stdev=0.5016249636597712, stderr=0.1023937669340495, values=[2.5, 2.5, 2.5, 3.0, 3.0, 1.75, 2.25, 2.5, 3.5, 2.0, 3.0, 1.75, 2.5, 2.25, 1.25, 2.75, 3.0, 2.5, 2.75, 2.5, 2.5, 3.0, 2.25, 1.75]),
 'recall': AggregateMetric(mean=0.7718524188493854, stdev=0.059807805365964375, stderr=0.012208217148525206, values=[0.7760617760617761, 0.8024691358024691, 0.694560669456067, 0.8114754098360656, 0.6769911504424779, 0.7647058823529411, 0.7239819004524887, 0.7528089887640449, 0.7375, 0.7176470588235294, 0.6666666666666666, 0.8326693227091634, 0.7958333333333333, 0.8, 0.8, 0.8143459915611815, 0.7717842323651453, 0.8613445378151261, 0.9, 0.6587301587301587, 0.73828125, 0.8217391304347826, 0.8089430894308943, 0.7959183673469388]),
 'rank': AggregateMetric(mean=117.66666666666667, stdev=66.19080164359865, stderr=13.511140807715902, values=[72, 112, 72, 216, 216, 112, 112, 224, 208, 80, 136, 72, 144, 256, 64, 72, 48, 152, 32, 72, 192, 56, 48, 56]),
 'e

In [13]:
#########################################################################################################
# Which sweep result to inspect below.
TRIAL_NO = 10  # index into relation_result.trials (see trial_options)
LAYER = 27     # layer the LRE is estimated at (must appear in layer_options)
RANK = 192     # rank of the low-rank editor (must appear in rank_options)
#########################################################################################################

In [14]:
# Select the sweep results for the chosen trial / layer / rank.
# `next(..., None)` stops at the first match and lets us raise a descriptive
# error listing the valid options, instead of the bare IndexError that
# `[...][0]` produced when LAYER or RANK was not part of the sweep.
trial_layers = relation_result.trials[TRIAL_NO].layers
layer_result = next((entry for entry in trial_layers if entry.layer == LAYER), None)
if layer_result is None:
    raise ValueError(
        f"LAYER={LAYER!r} not found in trial {TRIAL_NO}; "
        f"options: {[entry.layer for entry in trial_layers]}"
    )

layer_ranks = layer_result.result.ranks
rank_result = next((entry for entry in layer_ranks if entry.rank == RANK), None)
if rank_result is None:
    raise ValueError(
        f"RANK={RANK!r} not found at layer {LAYER!r}; "
        f"options: {[entry.rank for entry in layer_ranks]}"
    )
rank_result

SweepRankResults(rank=192, efficacy=[0.0, 0.14583333333333334, 0.3458333333333333], efficacy_successes=[])

In [15]:
# Efficacy successes at this rank, keyed by the subject the edit targets.
efficacy_successes = {
    success.target.subject: success for success in rank_result.efficacy_successes
}

# Drop every edit whose target subject also appears among some beta's
# faithfulness successes, reporting the first beta where it was found.
# Whatever remains in `efficacy_successes` has no faithfulness match.
for beta_result in layer_result.result.betas:
    for faithful_sample in beta_result.faithfulness_successes:
        matched = efficacy_successes.pop(faithful_sample.subject, None)
        if matched is None:
            continue
        print(f"Edit: {matched.source} <to> {matched.target} -- found in beta: {beta_result.beta}")
        

In [16]:
# List the edits left over after the cross-match: efficacy successes whose
# target subject never showed up as a faithfulness success.
print("No target match found in faithfulness successes for the following:")
for unmatched in efficacy_successes.values():
    print(f"Edit: {unmatched.source} <to> {unmatched.target}")

No target match found in faithfulness successes for the following:


In [34]:
# Training samples recorded for this trial/layer; bind and display them in one step.
(train_samples := layer_result.result.samples)

[RelationSample(subject='Brazil', object='Real'),
 RelationSample(subject='Argentina', object='Peso'),
 RelationSample(subject='Russia', object='Ruble'),
 RelationSample(subject='Poland', object='Zloty'),
 RelationSample(subject='India', object='Rupee'),
 RelationSample(subject='China', object='Yuan'),
 RelationSample(subject='South Korea', object='Won'),
 RelationSample(subject='Canada', object='Dollar'),
 RelationSample(subject='Turkey', object='Lira'),
 RelationSample(subject='Thailand', object='Baht')]

In [35]:
# Prompt template used by the selected trial; bind and display it in one step.
(prompt_template := relation_result.trials[TRIAL_NO].prompt_template)

' {} :'

In [36]:
from src import functional, operators, editors

In [37]:
# Estimate the linear relation operator (LRE) at hidden layer LAYER from the
# sweep's training samples rendered with the trial's prompt template.
estimator = operators.JacobianIclMeanEstimator(mt=mt, h_layer=LAYER)
train_relation = relation.set(samples=train_samples, prompt_templates=[prompt_template])
operator = estimator(train_relation)

In [38]:
############################################
# Subject to inspect; it is later looked up in `efficacy_successes`.
subject = "Hungary"
############################################

In [41]:
# Baseline: the unedited model's top next-token predictions for the subject,
# using the operator's prompt template (model predicts correctly).
subject_prompt = operator.prompt_template.format(subject)
functional.predict_next_token(mt=mt, prompt=subject_prompt)

[[PredictedToken(token=' For', prob=0.970029890537262),
  PredictedToken(token=' Ft', prob=0.0120215630158782),
  PredictedToken(token=' for', prob=0.003609527600929141),
  PredictedToken(token=' F', prob=0.0023671858943998814),
  PredictedToken(token=' Kor', prob=0.0018725981935858727)]]

In [42]:
# Same subject through the estimated LRE instead of the full model
# (LRE fails here: low faithfulness).
lre_output = operator(subject=subject)
lre_output

LinearRelationOutput(predictions=[PredictedToken(token=' New', prob=0.4707508981227875), PredictedToken(token=' Dollar', prob=0.1228027194738388), PredictedToken(token=' Euro', prob=0.08984438329935074), PredictedToken(token=' ', prob=0.03573762625455856), PredictedToken(token=' Pes', prob=0.034101083874702454)], h=tensor([[-0.3643, -0.0974, -0.7529,  ...,  1.1895, -0.2173,  0.6777]],
       device='cuda:0', dtype=torch.float16), z=tensor([[ 0.5054, -0.1365, -2.2422,  ..., -0.3159,  2.8867,  1.0400]],
       device='cuda:0', dtype=torch.float16, grad_fn=<AddBackward0>))

In [43]:
# SVD of the LRE weight, cast to float32 first (the model was loaded with fp16=True).
# NOTE(review): torch.svd is deprecated in favour of torch.linalg.svd, which
# returns Vh rather than V; keep torch.svd unless LowRankPInvEditor is confirmed
# to accept the linalg convention.
svd = torch.svd(operator.weight.float())
editor = editors.LowRankPInvEditor(
    lre=operator,
    svd=svd,
    rank=rank_result.rank,
    n_samples=1,
    n_new_tokens=1,
)

In [44]:
# `efficacy_successes` only keeps edits with no faithfulness match (the
# cross-matching cell pops the others). A subject that the LRE already handles
# faithfully, or one that is not an efficacy success at this rank at all, is
# therefore absent. Fail with a message listing the valid choices instead of a
# bare KeyError.
if subject not in efficacy_successes:
    raise KeyError(
        f"{subject!r} is not among the remaining efficacy successes at rank "
        f"{rank_result.rank}; choose one of {list(efficacy_successes)}"
    )
efficacy_test_pair = efficacy_successes[subject]
f"Editing: {efficacy_test_pair.source} <to> {efficacy_test_pair.target}"

'Editing: Japan -> Yen <to> Hungary -> Forint'

In [45]:
# Apply the low-rank edit to the source subject, steering it toward the target
# subject of the chosen efficacy pair (editing succeeds: high efficacy).
edit_source_subject = efficacy_test_pair.source.subject
edit_target_subject = efficacy_test_pair.target.subject
editor(subject=edit_source_subject, target=edit_target_subject)

LinearRelationEditResult(predicted_tokens=[PredictedToken(token=' For', prob=0.8955066204071045), PredictedToken(token=' H', prob=0.0679834708571434), PredictedToken(token=' Kor', prob=0.009948180057108402)], model_logits=tensor([-inf, -inf, -inf,  ..., -inf, -inf, -inf], device='cuda:0'), model_generations=[' Brazil : Real\n Argentina : Peso\n Russia : Ruble\n Poland : Zloty\n India : Rupee\n China : Yuan\n South Korea : Won\n Canada : Dollar\n Turkey : Lira\n Thailand : Baht\n Japan : For'])