In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import matplotlib.pyplot as plt
import os
import json
import sys
import numpy as np
sys.path.append("..")
import copy
from tqdm.auto import tqdm

In [3]:
from src import models, data, operators, utils, functional, metrics, lens
from src.utils import logging_utils, experiment_utils
import logging
import torch
import baukit

experiment_utils.set_seed(123456)

logger = logging.getLogger(__name__)

logging.basicConfig(
    level=logging.INFO,
    format = logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout
)

In [4]:
#################################################
h_layer = 7
beta = 0.4
n_training = 10
#################################################

In [5]:
mt = models.load_model(name = "gptj", fp16 = True, device = "cuda")

2023-06-21 17:10:13 src.models INFO     loading EleutherAI/gpt-j-6B (device=cuda, fp16=True)


In [6]:
relation = data.load_dataset().filter(relation_names=["country capital city"])[0].set(prompt_templates=[" {}:"])
train, test = relation.split(n_training)

In [7]:
icl_prompt = functional.make_prompt(
    prompt_template = train.prompt_templates[0],
    subject = "{}",
    examples = train.samples,
    mt = mt
)
print(icl_prompt)

<|endoftext|> Pakistan: Islamabad
 Argentina: Buenos Aires
 Peru: Lima
 Australia: Canberra
 Germany: Berlin
 Saudi Arabia: Riyadh
 Russia: Moscow
 Italy: Rome
 India: New Delhi
 South Korea: Seoul
 {}:


In [8]:
test = functional.filter_relation_samples_based_on_provided_fewshots(
    mt = mt, 
    test_relation=test,
    prompt_template = icl_prompt,
    batch_size=4
)
len(test.samples)

14

In [9]:
estimator = operators.JacobianIclMeanEstimator(
    mt = mt, h_layer=h_layer, beta=beta
)
operator = estimator(train)

In [10]:
for sample in test.samples:
    pred = operator(sample.subject).predictions[:3]
    print(f"{sample} | pred: {[f'{p.token} ({p.prob:.2f})' for p in pred]}")

Brazil -> Bras\u00edlia | pred: [' Bras (0.74)', ' Rio (0.11)', ' Sao (0.04)']
Canada -> Ottawa | pred: [' Toronto (0.28)', ' Ottawa (0.23)', ' Vancouver (0.10)']
Chile -> Santiago | pred: [' Santiago (0.98)', '  (0.01)', ' Chile (0.01)']
China -> Beijing | pred: [' Beijing (0.96)', ' Shanghai (0.03)', ' China (0.00)']
Colombia -> Bogot\u00e1 | pred: [' Bog (0.76)', '  (0.12)', '\n (0.02)']
Egypt -> Cairo | pred: [' Cairo (0.98)', '  (0.00)', ' Egypt (0.00)']
France -> Paris | pred: [' Paris (0.99)', ' France (0.00)', '  (0.00)']
Japan -> Tokyo | pred: [' Tokyo (1.00)', ' Osaka (0.00)', ' Japan (0.00)']
Mexico -> Mexico City | pred: [' Mexico (0.90)', '  (0.02)', ' New (0.02)']
Nigeria -> Abuja | pred: [' New (0.18)', '  (0.15)', ' Abu (0.09)']
Spain -> Madrid | pred: [' Madrid (0.91)', ' Barcelona (0.03)', '  (0.01)']
Turkey -> Ankara | pred: [' Istanbul (0.52)', ' Ankara (0.46)', ' Turkey (0.01)']
United States -> Washington D.C. | pred: [' Washington (0.82)', ' New (0.11)', ' Los (0

In [36]:
estimator_i = operators.JacobianIclMeanEstimator_Imaginary(
    mt = mt, h_layer=7, beta=1, interpolate_on=4, n_trials=8
)
operator_i = estimator_i(train)

In [37]:
operator_i.weight.norm().item(), operator_i.bias.norm().item()

(59.65625, 223.25)

In [35]:
operator.weight.norm().item(), operator.bias.norm().item()

(22.515625, 213.0)

In [48]:
for sample in test.samples:
    pred = operator_i(sample.subject).predictions[:3]
    print(f"{sample} | pred: {[f'{p.token} ({p.prob:.2f})' for p in pred]}")

Russia -> Moscow | pred: [' Moscow (0.94)', ' Buenos (0.02)', ' Russia (0.01)']
United States -> Washington D.C. | pred: [' Mexico (0.39)', ' Buenos (0.27)', ' Washington (0.12)']
Japan -> Tokyo | pred: [' Tokyo (0.99)', ' Japan (0.00)', ' Manila (0.00)']
Australia -> Canberra | pred: [' Sydney (0.60)', ' Buenos (0.13)', ' Canberra (0.12)']
Saudi Arabia -> Riyadh | pred: [' Riyadh (0.61)', ' Madrid (0.12)', ' Buenos (0.07)']
Peru -> Lima | pred: [' Lima (0.97)', ' Peru (0.02)', ' Bog (0.00)']
Egypt -> Cairo | pred: [' Cairo (0.62)', ' Buenos (0.14)', ' Madrid (0.07)']
Mexico -> Mexico City | pred: [' Mexico (0.99)', ' Mé (0.00)', 'Mexico (0.00)']
Nigeria -> Abuja | pred: [' Mexico (0.63)', ' Lag (0.08)', ' Nigeria (0.06)']
Turkey -> Ankara | pred: [' Ankara (0.33)', ' Istanbul (0.25)', ' Madrid (0.15)']
Brazil -> Bras\u00edlia | pred: [' Bras (0.93)', ' Brazil (0.02)', ' Sao (0.02)']
Colombia -> Bogot\u00e1 | pred: [' Bog (0.98)', ' Colombia (0.01)', ' Lima (0.00)']
Venezuela -> Caraca