In [1]:
# Pick up edits to project modules (e.g. `src`) live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import os
import json
import sys
import numpy as np
sys.path.append("..")
import copy
from tqdm.auto import tqdm

In [3]:
from src import models, data, operators, utils, functional, metrics, lens
from src.utils import logging_utils, experiment_utils
import logging
import torch
import baukit

# Seed the project's RNGs once, before any stochastic step runs
# (presumably this covers the train/test split and estimator trials — TODO confirm).
experiment_utils.set_seed(123456)

# Send INFO-level log records (e.g. the `src.models` loading message) to the
# notebook's stdout, formatted with the project's shared format strings.
logging.basicConfig(
    level=logging.INFO,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)

In [15]:
# ---- experiment configuration ------------------------------------------------
# n_training : number of samples split off as few-shot (ICL) demonstrations.
# h_layer    : passed to the Jacobian estimators as `h_layer`; presumably the
#              hidden-layer index the operator is read from — TODO confirm.
# beta       : passed to `JacobianIclMeanEstimator` as `beta`; presumably a
#              scaling factor applied by the operator — TODO confirm.
h_layer, beta, n_training = 7, 0.4, 8
# ------------------------------------------------------------------------------

In [5]:
# Load GPT-J in half precision onto the GPU. `mt` is the handle every `src`
# helper below receives (presumably model + tokenizer together — TODO confirm).
mt = models.load_model(
    name="gptj",
    fp16=True,
    device="cuda",
)

2023-06-21 15:38:34 src.models INFO     loading EleutherAI/gpt-j-6B (device=cuda, fp16=True)


In [16]:
# Select the country -> capital-city relation and override its prompt with the
# bare " {}:" template ("{}" is where the subject goes).
relation = (
    data.load_dataset()
    .filter(relation_names=["country capital city"])[0]
    .set(prompt_templates=[" {}:"])
)
# `n_training` samples go to `train` (used as ICL demonstrations); `test`
# presumably holds the remaining samples.
train, test = relation.split(n_training)

In [17]:
# Build the few-shot prompt from the training samples. The subject slot is
# filled with the literal "{}", so the result is itself a template that can be
# re-filled with each test subject.
icl_prompt = functional.make_prompt(
    mt=mt,
    prompt_template=train.prompt_templates[0],
    examples=train.samples,
    subject="{}",
)
# print() rather than a bare expression so the newlines render literally
print(icl_prompt)

<|endoftext|> United States: Washington D.C.
 Japan: Tokyo
 Germany: Berlin
 Pakistan: Islamabad
 Argentina: Buenos Aires
 Colombia: Bogot\u00e1
 Peru: Lima
 Mexico: Mexico City
 {}:


In [18]:
# Filter the held-out samples against the few-shot prompt built above
# (presumably keeping only those the model already answers correctly — TODO confirm).
# NOTE(review): this rebinds `test` in place, so re-running the cell filters an
# already-filtered set; binding a new name (e.g. `test_filtered`) would keep the
# cell idempotent, but later cells read `test`.
test = functional.filter_relation_samples_based_on_provided_fewshots(
    mt=mt,
    test_relation=test,
    prompt_template=icl_prompt,
    batch_size=4,
)
len(test.samples)

15

In [19]:
# Fit the baseline Jacobian ICL-mean operator on the training samples, using the
# `h_layer` / `beta` values from the configuration cell.
estimator = operators.JacobianIclMeanEstimator(mt=mt, h_layer=h_layer, beta=beta)
operator = estimator(train)

In [20]:
# For each held-out subject, show the operator's top-3 predicted tokens with
# their probabilities.
for sample in test.samples:
    pred = operator(sample.subject).predictions[:3]
    shown = [f"{p.token} ({p.prob:.2f})" for p in pred]
    print(f"{sample} | pred: {shown}")

Brazil -> Bras\u00edlia | pred: [' Bras (0.72)', ' Rio (0.10)', ' Brazil (0.09)']
Canada -> Ottawa | pred: [' Toronto (0.29)', ' Ottawa (0.21)', ' Montreal (0.13)']
Chile -> Santiago | pred: [' Santiago (0.99)', ' Chile (0.00)', '  (0.00)']
China -> Beijing | pred: [' Beijing (0.97)', ' Shanghai (0.02)', ' China (0.01)']
Egypt -> Cairo | pred: [' Cairo (0.96)', '  (0.01)', ' Egypt (0.01)']
France -> Paris | pred: [' Paris (0.90)', ' Buenos (0.02)', ' France (0.01)']
India -> New Delhi | pred: [' Delhi (0.45)', ' New (0.39)', ' Mumbai (0.09)']
Italy -> Rome | pred: [' Rome (0.81)', ' Milan (0.07)', ' New (0.04)']
Nigeria -> Abuja | pred: [' Lag (0.20)', ' Abu (0.19)', '  (0.11)']
Russia -> Moscow | pred: [' Moscow (0.98)', ' Russia (0.01)', ' Kremlin (0.00)']
Saudi Arabia -> Riyadh | pred: ['  (0.28)', '\n (0.10)', ' Riyadh (0.08)']
South Korea -> Seoul | pred: [' Seoul (0.98)', ' Los (0.00)', ' Tokyo (0.00)']
Spain -> Madrid | pred: [' Madrid (0.87)', ' Barcelona (0.05)', ' Spain (0.01

In [25]:
# Fit the "imaginary" variant of the Jacobian ICL estimator on the same training
# samples and layer as the baseline `operator`. Its hyper-parameters are named
# here rather than inlined so they are visible when comparing the two operators.
# NOTE(review): `beta_i` (1) differs from the config-cell `beta` used for the
# baseline `operator`. The weight/bias-norm and prediction comparisons that
# follow therefore mix the effect of the estimator variant with the effect of
# beta — confirm this is intended.
beta_i = 1
interpolate_on_i = 4  # forwarded as `interpolate_on`; meaning defined in src.operators — TODO confirm
n_trials_i = 8  # forwarded as `n_trials`; presumably the number of trials averaged — TODO confirm
estimator_i = operators.JacobianIclMeanEstimator_Imaginary(
    mt=mt,
    h_layer=h_layer,
    beta=beta_i,
    interpolate_on=interpolate_on_i,
    n_trials=n_trials_i,
)
operator_i = estimator_i(train)

In [26]:
# (weight norm, bias norm) of the imaginary-estimator operator, as Python floats
tuple(param.norm().item() for param in (operator_i.weight, operator_i.bias))

(75.8125, 274.25)

In [27]:
# (weight norm, bias norm) of the baseline operator, for comparison
tuple(param.norm().item() for param in (operator.weight, operator.bias))

(24.125, 245.375)

In [28]:
# Same top-3 readout, this time from the imaginary-estimator operator.
for sample in test.samples:
    pred = operator_i(sample.subject).predictions[:3]
    labels = [
        "{} ({:.2f})".format(p.token, p.prob)
        for p in pred
    ]
    print(f"{sample} | pred: {labels}")

Brazil -> Bras\u00edlia | pred: [' Bras (0.83)', ' Brazil (0.10)', ' Sao (0.02)']
Canada -> Ottawa | pred: [' Canada (0.48)', ' Toronto (0.13)', ' Ottawa (0.11)']
Chile -> Santiago | pred: [' Santiago (0.98)', ' Chile (0.02)', ' Lima (0.00)']
China -> Beijing | pred: [' Beijing (0.97)', ' Shanghai (0.01)', ' China (0.01)']
Egypt -> Cairo | pred: [' Cairo (0.97)', ' Egypt (0.03)', ' Islamabad (0.00)']
France -> Paris | pred: [' Paris (0.77)', ' France (0.22)', ' French (0.01)']
India -> New Delhi | pred: [' Mumbai (0.45)', ' Delhi (0.44)', ' India (0.07)']
Italy -> Rome | pred: [' Rome (0.43)', ' Italy (0.31)', ' Milan (0.25)']
Nigeria -> Abuja | pred: [' Nigeria (0.75)', ' Abu (0.08)', ' Tokyo (0.06)']
Russia -> Moscow | pred: [' Moscow (0.98)', ' Russia (0.02)', ' Russian (0.00)']
Saudi Arabia -> Riyadh | pred: [' Riyadh (0.58)', ' Saudi (0.09)', ' Osaka (0.07)']
South Korea -> Seoul | pred: [' Seoul (0.98)', ' Osaka (0.01)', ' Tokyo (0.01)']
Spain -> Madrid | pred: [' Madrid (0.92)',