In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')

import torch
from src import models, data, lens, functional
from src.utils import experiment_utils
from baukit import Menu, show

In [3]:
# Load GPT-J in half precision on the first GPU and report dtype/device/memory footprint.
device = "cuda:0"
mt = models.load_model(
    "gptj",
    device=device,
    fp16=True,
)
print(
    f"dtype: {mt.model.dtype}, "
    f"device: {mt.model.device}, "
    f"memory: {mt.model.get_memory_footprint()}"
)

dtype: torch.float16, device: cuda:0, memory: 12219206136


In [4]:
dataset = data.load_dataset()

# Interactive picker over every relation in the dataset; the next cell reads the
# choice back through `relation_options.value`.
relation_names = [rel.name for rel in dataset.relations]
relation_options = Menu(
    choices=relation_names,
    # NOTE(review): the default `value` is the whole list of names, but the next cell
    # treats `relation_options.value` as a single relation name. Results therefore depend
    # on a manual widget selection. For Restart-&-Run-All reproducibility consider a
    # single default such as relation_names[0]; confirm against baukit's Menu API first.
    value=relation_names,
)
show(relation_options)

In [6]:
# Resolve the relation chosen in the menu above and summarize it.
relation_name = relation_options.value
relation = dataset.filter(relation_names=[relation_name])[0]
print(f"{relation.name} -- {len(relation.samples)} samples")
print("------------------------------------------------------")

# Fixed seed so the train/test split (and therefore the few-shot examples) is reproducible.
experiment_utils.set_seed(12345)
# With argument 5, the train split printed below holds 5 samples.
train, test = relation.split(5)
print("\n".join(str(train_sample) for train_sample in train.samples))

food from country -- 30 samples
------------------------------------------------------
Miso Soup -> Japan
Pierogi -> Poland
Kimchi -> South Korea
Hummus -> Lebanon
Fondue -> Switzerland


In [7]:
# ---------------- hyperparameters ----------------
layer = 5   # passed to the estimator as `h_layer`: layer at which the subject state h is read
beta = 2.5  # scale on W_r s in the affine approximation F(s, c_r) ≈ β W_r s + b_r
# -------------------------------------------------

In [8]:
from src.operators import JacobianIclMeanEstimator

# Fit the operator (exposing weight W_r, bias b_r, beta, h_layer, prompt_template) from
# the train split only. `relation.set(samples=...)` presumably returns a copy of the
# relation whose samples are replaced by the train samples — confirm in src.data.
estimator = JacobianIclMeanEstimator(mt=mt, h_layer=layer, beta=beta)
operator = estimator(relation.set(samples=train.samples))

relation has > 1 prompt_templates, will use first ({} originates from)


# Checking *faithfulness*

In [9]:
# NOTE(review): this rebinds `test` to a filtered copy of itself, so re-running the cell
# filters an already-filtered set. Judging by its name, the helper presumably keeps only
# the samples the LM answers correctly with the operator's few-shot prompt — confirm.
test = functional.filter_relation_samples_based_on_provided_fewshots(
    mt=mt,
    test_relation=test,
    prompt_template=operator.prompt_template,
    batch_size=4,
)

In [10]:
# Inspect the operator's top predictions for the first (filtered) test sample.
sample = test.samples[0]
print(sample)
operator(subject=sample.subject).predictions

Baguette -> France


[PredictedToken(token=' France', prob=0.8985161781311035),
 PredictedToken(token=' Switzerland', prob=0.014076330699026585),
 PredictedToken(token=' the', prob=0.01281664427369833),
 PredictedToken(token=' Austria', prob=0.009017635136842728),
 PredictedToken(token=' Italy', prob=0.007302704732865095)]

In [12]:
# Subject hidden state h at the operator's layer, for the sample chosen above.
# NOTE(review): depends on the notebook-global `sample`. If a later cell rebinds that
# name (e.g. as a loop variable), re-running this cell will silently use another subject.
hs_and_zs = functional.compute_hs_and_zs(
    mt=mt,
    prompt_template=operator.prompt_template,
    subjects=[sample.subject],
    h_layer=operator.h_layer,
)
h = hs_and_zs.h_by_subj[sample.subject]

## Approximating LM computation $F$ as an affine transformation

### $$ F(\mathbf{s}, c_r) \approx \beta \, W_r \mathbf{s} + b_r $$

In [13]:
# Evaluate the affine approximation by hand, z = β (W_r h) + b_r, then decode z with the
# logit lens so it can be compared with operator(...).predictions above.
z = operator.bias + operator.beta * operator.weight.matmul(h)

lens.logit_lens(mt=mt, h=z, get_proba=True)

([(' France', 0.898),
  (' Switzerland', 0.014),
  (' the', 0.013),
  (' Austria', 0.009),
  (' Italy', 0.007),
  (' Europe', 0.007),
  ('\n', 0.004),
  (' Germany', 0.003),
  (' Paris', 0.003),
  (' Spain', 0.002)],
 {})

In [14]:
# Faithfulness@1: fraction of test samples for which the operator's top-1 prediction is
# a non-trivial prefix match of the gold object (as judged by is_nontrivial_prefix).
correct = 0
wrong = 0
# Use a dedicated loop variable so the notebook-global `sample` (test.samples[0], used by
# the cells above) is not rebound. That keeps the earlier cells safe to re-run.
for test_sample in test.samples:
    top_prediction = operator(subject=test_sample.subject).predictions[0]
    known_flag = functional.is_nontrivial_prefix(
        prediction=top_prediction.token, target=test_sample.object
    )
    # Same text as before (the `{x=}` form prints repr), just without depending on the
    # loop variable being literally named `sample`.
    print(
        f"sample.subject={test_sample.subject!r}, sample.object={test_sample.object!r}, "
        f'predicted="{functional.format_whitespace(top_prediction.token)}", '
        f"(p={top_prediction.prob}), known=({functional.get_tick_marker(known_flag)})"
    )

    correct += known_flag
    wrong += not known_flag

n_evaluated = correct + wrong
# Guard against an empty test set (e.g. if the few-shot filter removed every sample).
faithfulness = correct / n_evaluated if n_evaluated else float("nan")

print("------------------------------------------------------------")
print(f"Faithfulness (@1) = {faithfulness}")
print("------------------------------------------------------------")

sample.subject='Baguette', sample.object='France', predicted=" France", (p=0.8985161781311035), known=(✓)
sample.subject='Biryani', sample.object='India', predicted=" India", (p=0.22970259189605713), known=(✓)
sample.subject='Ceviche', sample.object='Peru', predicted=" South", (p=0.34727180004119873), known=(✗)
sample.subject='Chimichurri', sample.object='Argentina', predicted=" Argentina", (p=0.4291362464427948), known=(✓)
sample.subject='Dim Sum', sample.object='China', predicted=" China", (p=0.9262574315071106), known=(✓)
sample.subject='Feijoada', sample.object='Brazil', predicted=" Spain", (p=0.26894351840019226), known=(✗)
sample.subject='Goulash', sample.object='Hungary', predicted=" Hungary", (p=0.7775859236717224), known=(✓)
sample.subject='Gyro', sample.object='Greece', predicted=" Hungary", (p=0.35067880153656006), known=(✗)
sample.subject='Masala Dosa', sample.object='India', predicted=" South", (p=0.6780027747154236), known=(✗)
sample.subject='Moussaka', sample.object='Gre

# Checking *causality*

In [15]:
# ---------------- hyperparameters ----------------
rank = 100  # rank used for the low-rank pseudo-inverse of W_r (get_delta_s and the editor below)
# -------------------------------------------------

In [16]:
# Re-seed so the random assignment of edit targets is reproducible across runs.
experiment_utils.set_seed(12345)
# Mapping from each test sample to another sample. Its subject/object serve as the edit target.
test_targets = functional.random_edit_targets(test.samples)

## setup

In [17]:
# Take the first test sample as the edit source and look up its randomly assigned target.
source = test.samples[0]
target = test_targets[source]

f"Changing the mapping ({source}) to ({source.subject} -> {target.object})"

'Changing the mapping (Baguette -> France) to (Baguette -> Canada)'

### Calculate $\Delta \mathbf{s}$ such that $\mathbf{s} + \Delta \mathbf{s} \approx \mathbf{s}'$

<!-- ![](causality-crop.png) -->
<img src="causality-crop.png" style="width:50%;"/>

Under the relation $r =\, $*plays the instrument*, and given the subject $s =\, $*Miles Davis*, the model will predict $o =\, $*trumpet* **(a)**. Under the same relation, given the subject $s' =\, $*Cat Stevens*, the model will instead predict $o' =\, $*guitar* **(b)**.

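If the computation from $\mathbf{s}$ to $\mathbf{o}$ is well approximated by the linear relational operator $\beta \, W_r \mathbf{s} + b_r$, then adding the vector $\Delta{\mathbf{s}}$ to $\mathbf{s}$ **(c)** should be equivalent to replacing $\mathbf{s}$ with $\mathbf{s}'$, and it should change the model's prediction to $o' =\, $*guitar* **(d)**. In the code below, $\Delta\mathbf{s}$ is obtained by applying a low-rank pseudo-inverse of $W_r$ to the difference of the two object representations, $\mathbf{o}' - \mathbf{o}$.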

In [18]:
def get_delta_s(
    operator,
    source_subject,
    target_subject,
    rank=100,
    fix_latent_norm=None,  # if set, will fix the norms of z_source and z_target
):
    """Estimate the subject-space edit Δs that moves the source's object latent to the target's.

    Computes ``Δs = W⁺ (z_target - z_source)``. Here ``W⁺`` is
    ``functional.low_rank_pinv(operator.weight, rank)``, presumably a rank-``rank``
    pseudo-inverse of the operator weight. ``z_*`` are the last-layer (``z_layer=-1``)
    latents of the two subjects under ``operator.prompt_template``.

    NOTE(review): reads the notebook-global ``mt``; it is not a parameter.
    NOTE(review): ``operator.beta`` is not applied here. Under F(s) ≈ β W s + b, an exact
    inverse of the affine map would also divide by β — confirm this is intended.

    Args:
        operator: fitted operator exposing ``weight``, ``prompt_template`` and ``h_layer``.
        source_subject: subject whose prediction should be changed.
        target_subject: subject whose object the source should be mapped to.
        rank: rank passed to ``functional.low_rank_pinv``.
        fix_latent_norm: if not None, both latents are rescaled to this L2 norm before
            their difference is taken.

    Returns:
        ``(delta_s, hs_and_zs)``: the edit vector, and the unmodified result of
        ``functional.compute_hs_and_zs`` for both subjects.
    """
    w_p_inv = functional.low_rank_pinv(matrix=operator.weight, rank=rank)
    hs_and_zs = functional.compute_hs_and_zs(
        mt=mt,
        prompt_template=operator.prompt_template,
        subjects=[source_subject, target_subject],
        h_layer=operator.h_layer,
        z_layer=-1,
    )

    z_source = hs_and_zs.z_by_subj[source_subject]
    z_target = hs_and_zs.z_by_subj[target_subject]

    if fix_latent_norm is not None:
        # Rescale out of place. An in-place `*=` would modify the tensors stored in
        # `hs_and_zs.z_by_subj`, which are returned to (and reused by) the caller.
        z_source = z_source * (fix_latent_norm / z_source.norm())
        z_target = z_target * (fix_latent_norm / z_target.norm())

    delta_s = w_p_inv @ (z_target.squeeze() - z_source.squeeze())

    return delta_s, hs_and_zs

# Δs for the chosen source -> target edit. hs_and_zs holds both subjects' h and z.
delta_s, hs_and_zs = get_delta_s(
    operator=operator,
    source_subject=source.subject,
    target_subject=target.subject,
    rank=rank,
)

In [19]:
import baukit

def get_intervention(h, int_layer, subj_idx):
    """Return an ``edit_output(output, layer)`` hook for ``baukit.TraceDict``.

    The hook returns every layer's output object unchanged, except at ``int_layer``.
    There it first writes ``h`` in place into position ``subj_idx`` along axis 1 of
    ``functional.untuple(output)``. Axis 1 is presumably the token axis — confirm.
    """
    def edit_output(output, layer):
        if layer == int_layer:
            functional.untuple(output)[:, subj_idx] = h
        return output

    return edit_output

# Run the source prompt while overwriting the subject hidden state at the operator's layer
# with s + Δs, then decode the model's next-token distribution at the last position.
prompt = operator.prompt_template.format(source.subject)

h_index, inputs = functional.find_subject_token_index(
    mt=mt,
    prompt=prompt,
    subject=source.subject,
)

# Use the operator's own layer (not the global `layer` hparam). This guarantees the edit is
# applied at the same layer where `h` and `delta_s` were computed.
h_layer, z_layer = models.determine_layer_paths(model=mt, layers=[operator.h_layer, -1])

# Sanity check: replacing this with the unedited hs_and_zs.h_by_subj[source.subject]
# lets the computation proceed as usual.
h_edited = hs_and_zs.h_by_subj[source.subject] + delta_s  # replace s with s + delta_s

# Inference only: no_grad avoids building an (unused) autograd graph for GPT-J.
with torch.no_grad(), baukit.TraceDict(
    mt.model,
    layers=[h_layer, z_layer],
    edit_output=get_intervention(
        h=h_edited,
        int_layer=h_layer,
        subj_idx=h_index,
    ),
) as traces:
    outputs = mt.model(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
    )

lens.interpret_logits(
    mt=mt,
    logits=outputs.logits[0][-1],
    get_proba=True,
)

[(' Canada', 0.795),
 (' Quebec', 0.098),
 (' the', 0.017),
 (' Qué', 0.011),
 (' North', 0.008),
 (' Ontario', 0.005),
 (' France', 0.004),
 (' New', 0.004),
 (' Montreal', 0.004),
 (' Canadian', 0.003)]

## Measuring causality

In [20]:
from src.editors import LowRankPInvEditor

# Hand the editor a precomputed SVD of the operator weight, upcast to float32 first.
# NOTE(review): torch.svd is deprecated in favour of torch.linalg.svd, which returns
# V^H rather than V. Only switch once the format LowRankPInvEditor expects is confirmed.
svd = torch.svd(operator.weight.float())
editor = LowRankPInvEditor(lre=operator, rank=rank, svd=svd)

In [21]:
# precomputing latents to speed things up
# NOTE(review): `hs_and_zs` is not passed to `editor` below, and this reassigns the global
# used by the intervention cell above. Confirm whether LowRankPInvEditor can consume
# precomputed latents; otherwise this batched pass is unused work.
hs_and_zs = functional.compute_hs_and_zs(
    mt=mt,
    prompt_template=operator.prompt_template,
    subjects=[test_sample.subject for test_sample in test.samples],
    h_layer=operator.h_layer,
    z_layer=-1,
    batch_size=2,
)

# Causality@1: fraction of edits whose top-1 prediction after editing is a non-trivial
# prefix match of the target object.
success = 0
fails = 0

# Dedicated loop variables so the notebook-global `sample` and `target` (used by the setup
# and get_delta_s cells) are not rebound. That keeps those cells safe to re-run.
for test_sample in test.samples:
    edit_target = test_targets.get(test_sample)
    assert edit_target is not None, f"no edit target for {test_sample}"
    edit_result = editor(
        subject=test_sample.subject,
        target=edit_target.subject,
    )
    top_token = edit_result.predicted_tokens[0]

    success_flag = functional.is_nontrivial_prefix(
        prediction=top_token.token, target=edit_target.object
    )

    print(f"Mapping {test_sample.subject} -> {edit_target.object} | edit result={top_token} | success=({functional.get_tick_marker(success_flag)})")

    success += success_flag
    fails += not success_flag

n_edits = success + fails
# Guard against an empty test set.
causality = success / n_edits if n_edits else float("nan")

print("------------------------------------------------------------")
print(f"Causality (@1) = {causality}")
print("------------------------------------------------------------")

Mapping Baguette -> Canada | edit result= Canada (p=0.812) | success=(✓)
Mapping Biryani -> France | edit result= France (p=0.792) | success=(✓)
Mapping Ceviche -> Thailand | edit result= Thailand (p=0.817) | success=(✓)
Mapping Chimichurri -> Vietnam | edit result= Vietnam (p=0.815) | success=(✓)
Mapping Dim Sum -> Greece | edit result= Greece (p=0.791) | success=(✓)
Mapping Feijoada -> Greece | edit result= Greece (p=0.782) | success=(✓)
Mapping Goulash -> Canada | edit result= Canada (p=0.736) | success=(✓)
Mapping Gyro -> Brazil | edit result= Brazil (p=0.793) | success=(✓)
Mapping Masala Dosa -> Italy | edit result= Italy (p=0.810) | success=(✓)
Mapping Moussaka -> Argentina | edit result= Argentina (p=0.909) | success=(✓)
Mapping Pad Thai -> Canada | edit result= Canada (p=0.850) | success=(✓)
Mapping Paella -> India | edit result= India (p=0.828) | success=(✓)
Mapping Pho -> Austria | edit result= Austria (p=0.580) | success=(✓)
Mapping Pizza -> Brazil | edit result= Brazil (p=0