In [1]:
# Automatically re-import edited Python modules (e.g. the local `src` package)
# before each cell runs, so library changes show up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')

import torch
from src import models, data, lens, functional
from src.utils import experiment_utils

import logging
from src.utils import logging_utils
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)
# DEBUG on the root logger also enables urllib3's per-request HTTP logging
# ("Starting new HTTPS connection", "HEAD ..."), which floods the model-loading
# output. Keep the project's DEBUG logs but raise urllib3 to INFO.
logging.getLogger("urllib3").setLevel(logging.INFO)

In [3]:
# Model under study. "mamba-3b" is resolved by src.models (it loads
# state-spaces/mamba-2.8b-slimpj); fp16 is disabled so weights stay float32.
device = "cuda:0"
mt = models.load_model(
    "mamba-3b",
    device=device,
    fp16=False,
)

2024-01-12 16:46:26 src.models INFO     loading state-spaces/mamba-2.8b-slimpj (device=cuda:0, fp16=False)
2024-01-12 16:46:26 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443


2024-01-12 16:46:26 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b-slimpj/resolve/main/config.json HTTP/1.1" 200 0
2024-01-12 16:46:37 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b-slimpj/resolve/main/pytorch_model.bin HTTP/1.1" 302 0


  return self.fget.__get__(instance, owner)()


2024-01-12 16:46:39 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /EleutherAI/gpt-neox-20b/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2024-01-12 16:46:40 src.models INFO     dtype: torch.float32, device: cuda:0, memory: 10.31 GB


In [4]:
# Load every relation file; with no paths given, src.data globs the JSON files
# in its default data directory. The relation studied below is selected in code
# (not through an interactive widget) so the notebook reproduces on Run All.
dataset = data.load_dataset()

2024-01-12 16:46:47 src.data DEBUG    no paths provided, using default data dir: /home/local_arnab/Codes/relations/notebooks/../data
2024-01-12 16:46:47 src.data DEBUG    /home/local_arnab/Codes/relations/notebooks/../data is directory, globbing for json files...
2024-01-12 16:46:47 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/characteristic_gender.json
2024-01-12 16:46:47 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/degree_gender.json
2024-01-12 16:46:47 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/name_birthplace.json
2024-01-12 16:46:47 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/name_gender.json
2024-01-12 16:46:47 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/name_religion.json
2024-01-12 16:46:47 src.data DEBUG    found relation 

In [5]:
relation_name = "country capital city"
relation = dataset.filter(relation_names=[relation_name])[0]
print(f"{relation.name} -- {len(relation.samples)} samples")
print("------------------------------------------------------")

# Seed before splitting so the same 5 training demonstrations are drawn every run.
experiment_utils.set_seed(12345)
train, test = relation.split(5)
print("\n".join(str(train_sample) for train_sample in train.samples))

2024-01-12 16:46:50 src.data DEBUG    filtering to only relations: ['country capital city']
country capital city -- 24 samples
------------------------------------------------------
2024-01-12 16:46:50 src.utils.experiment_utils INFO     setting all seeds to 12345
China -> Beijing
Japan -> Tokyo
Italy -> Rome
Brazil -> Bras\u00edlia
Turkey -> Ankara


In [6]:
# --- hyperparameters -----------------------------------------------------
layer = 5   # layer whose hidden state serves as the subject representation s (passed as h_layer)
beta = 2.5  # scale on W_r s in the approximation F(s, c_r) ≈ beta * W_r s + b_r
# -------------------------------------------------------------------------

In [7]:
from src.operators import JacobianIclMeanEstimator

# Fit the affine operator (weight W_r, bias b_r) for this relation, using the
# few-shot training samples as the in-context demonstrations.
# NOTE(review): the saved output of this cell ends in
#   AttributeError: 'Mamba' object has no attribute 'device'
# and the following cells have execution counts 10+, so the `operator` used
# below was produced by cells not shown here. Re-run top-to-bottom on a fresh
# kernel to confirm this cell works as written.
estimator = JacobianIclMeanEstimator(mt=mt, h_layer=layer, beta=beta)
operator = estimator(relation.set(samples=train.samples))

2024-01-12 16:47:04 src.operators DEBUG    estimating J for prompt:
<|endoftext|> The capital city of Japan is Tokyo
The capital city of Italy is Rome
The capital city of Brazil is Bras\u00edlia
The capital city of Turkey is Ankara
The capital city of China is


AttributeError: 'Mamba' object has no attribute 'device'

# Checking $faithfulness$

In [10]:
# Presumably keeps only the test samples the model already gets right under the
# operator's few-shot prompt (per the helper's name — see src.functional).
# Note: this rebinds `test`.
test = functional.filter_relation_samples_based_on_provided_fewshots(
    mt=mt,
    test_relation=test,
    prompt_template=operator.prompt_template,
    batch_size=4,
)

In [11]:
# Running example: the first remaining test sample and the operator's top predictions for it.
sample = test.samples[0]
print(sample)
operator(subject=sample.subject).predictions

Argentina -> Buenos Aires


[PredictedToken(token=' Buenos', prob=0.8899907469749451),
 PredictedToken(token='\n', prob=0.02772851102054119),
 PredictedToken(token=' ', prob=0.013726607896387577),
 PredictedToken(token=' Argentina', prob=0.008456717245280743),
 PredictedToken(token=' Bras', prob=0.0059037404134869576)]

In [12]:
# Subject hidden state at the operator's layer: the `s` in the equation below.
hs_and_zs = functional.compute_hs_and_zs(
    mt=mt,
    prompt_template=operator.prompt_template,
    subjects=[sample.subject],
    h_layer=operator.h_layer,
)
h = hs_and_zs.h_by_subj[sample.subject]

## Approximating LM computation $F$ as an affine transformation

### $$ F(\mathbf{s}, c_r) \approx \beta \, W_r \mathbf{s} + b_r $$

In [13]:
# Apply beta * W_r s + b_r by hand, then decode z with the logit lens.
z = operator.bias + operator.beta * (operator.weight @ h)
lens.logit_lens(mt=mt, h=z, get_proba=True)

([(' Buenos', 0.89),
  ('\n', 0.028),
  (' ', 0.014),
  (' Argentina', 0.008),
  (' Bras', 0.006),
  ('...', 0.006),
  (' Rome', 0.004),
  (' {', 0.003),
  (' the', 0.002),
  ('...', 0.002)],
 {})

In [14]:
# Faithfulness@1: the fraction of test samples for which the operator's top-1
# prediction is a non-trivial prefix of the true object.
correct = 0
wrong = 0
# Use a distinct loop name so the `sample` chosen earlier is not rebound to the
# last test sample (re-running the cells above would otherwise silently change).
for test_sample in test.samples:
    predictions = operator(subject = test_sample.subject).predictions
    known_flag = functional.is_nontrivial_prefix(
        prediction=predictions[0].token, target=test_sample.object
    )
    # `!r` reproduces what the self-documenting `{sample.subject=}` form printed.
    print(f"sample.subject={test_sample.subject!r}, sample.object={test_sample.object!r}, ", end="")
    print(f'predicted="{functional.format_whitespace(predictions[0].token)}", (p={predictions[0].prob}), known=({functional.get_tick_marker(known_flag)})')

    correct += known_flag
    wrong += not known_flag

total = correct + wrong
if total == 0:
    # `test` can be empty (e.g. if the filtering step dropped every sample);
    # report an undefined score instead of raising ZeroDivisionError.
    logger.warning("no test samples to evaluate; faithfulness is undefined")
    faithfulness = float("nan")
else:
    faithfulness = correct / total

print("------------------------------------------------------------")
print(f"Faithfulness (@1) = {faithfulness}")
print("------------------------------------------------------------")

sample.subject='Argentina', sample.object='Buenos Aires', predicted=" Buenos", (p=0.8899907469749451), known=(✓)
sample.subject='Australia', sample.object='Canberra', predicted=" Canberra", (p=0.6983034610748291), known=(✓)
sample.subject='Canada', sample.object='Ottawa', predicted=" Ottawa", (p=0.7997354865074158), known=(✓)
sample.subject='Chile', sample.object='Santiago', predicted=" Santiago", (p=0.6504030823707581), known=(✓)
sample.subject='Colombia', sample.object='Bogot\\u00e1', predicted=" Bog", (p=0.38615524768829346), known=(✓)
sample.subject='Egypt', sample.object='Cairo', predicted=" Cairo", (p=0.9333562850952148), known=(✓)
sample.subject='France', sample.object='Paris', predicted=" Paris", (p=0.9924296736717224), known=(✓)
sample.subject='Germany', sample.object='Berlin', predicted=" Berlin", (p=0.9821451902389526), known=(✓)
sample.subject='India', sample.object='New Delhi', predicted=" Delhi", (p=0.6313555836677551), known=(✗)
sample.subject='Mexico', sample.object='Me

# $causality$

In [15]:
# --- hyperparameters -----------------------------------------------------
rank = 100  # number of directions kept in the low-rank pseudo-inverse of W_r
# -------------------------------------------------------------------------

In [16]:
# Re-seed so the random assignment of edit targets is reproducible.
# test_targets maps each test sample to the sample whose subject/object it should be edited toward.
experiment_utils.set_seed(12345)
test_targets = functional.random_edit_targets(test.samples)

## setup

In [17]:
# Running example for the edit: first test sample and its randomly assigned target.
source = test.samples[0]
target = test_targets[source]

f"Changing the mapping ({source}) to ({source.subject} -> {target.object})"

'Changing the mapping (Argentina -> Buenos Aires) to (Argentina -> Riyadh)'

### Calculate $\Delta \mathbf{s}$ such that $\mathbf{s} + \Delta \mathbf{s} \approx \mathbf{s}'$

<p align="center">
    <img align="center" src="causality-crop.png" style="width:80%;"/>
</p>

Under the relation $r =\, $*plays the instrument*, given the subject $s =\, $*Miles Davis* the model predicts $o =\, $*trumpet* **(a)**, and given the subject $s' =\, $*Cat Stevens* it predicts $o' =\, $*guitar* **(b)**.

If the computation from $\mathbf{s}$ to $\mathbf{o}$ is well approximated by the affine operator $F(\mathbf{s}) \approx \beta \, W_r \mathbf{s} + b_r$ **(c)**, then $\Delta\mathbf{s}$ **(d)** tells us the direction of change from $\mathbf{s}$ to $\mathbf{s}'$. It is obtained by mapping the difference between the target and source latents back through a low-rank pseudo-inverse of $W_r$. Thus $\tilde{\mathbf{s}} = \mathbf{s} + \Delta\mathbf{s}$ approximates $\mathbf{s}'$, and patching $\tilde{\mathbf{s}}$ in place of $\mathbf{s}$ should change the prediction to $o' =\, $*guitar*.

In [18]:
def get_delta_s(
    operator, 
    source_subject, 
    target_subject,
    rank = 100,
    fix_latent_norm = None, # if set, will fix the norms of z_source and z_target
):
    """Estimate the edit Δs that moves the source subject's state toward the target's.

    Computes Δs = W_r^+ (z_target - z_source), where W_r^+ is a rank-`rank`
    pseudo-inverse of `operator.weight` and z_* are the latents computed with
    `z_layer=-1` for each subject under the operator's prompt template.

    Args:
        operator: fitted operator exposing `weight`, `prompt_template` and `h_layer`.
        source_subject: subject whose representation s will be edited.
        target_subject: subject whose latent the edit should reproduce.
        rank: number of directions kept in the low-rank pseudo-inverse.
        fix_latent_norm: if not None, rescale z_source to this norm and z_target
            to the same norm before taking their difference.

    Returns:
        (delta_s, hs_and_zs): the edit vector, and the hidden states / latents
        computed for both subjects. The returned latents are never rescaled.

    Note: reads the notebook-global `mt` for the model.
    """
    w_p_inv = functional.low_rank_pinv(
        matrix = operator.weight,
        rank=rank,
    )
    hs_and_zs = functional.compute_hs_and_zs(
        mt = mt,
        prompt_template = operator.prompt_template,
        subjects = [source_subject, target_subject],
        h_layer= operator.h_layer,
        z_layer=-1,
    )

    z_source = hs_and_zs.z_by_subj[source_subject]
    z_target = hs_and_zs.z_by_subj[target_subject]

    if fix_latent_norm is not None:
        # Rescale out of place: `*=` would modify the tensors stored in
        # hs_and_zs.z_by_subj, silently altering the latents returned to the caller.
        z_source = z_source * (fix_latent_norm / z_source.norm())
        z_target = z_target * (z_source.norm() / z_target.norm())

    delta_s = w_p_inv @ (z_target.squeeze() - z_source.squeeze())

    return delta_s, hs_and_zs

# Δs for the running example (source subject -> its assigned target subject).
delta_s, hs_and_zs = get_delta_s(
    operator=operator,
    source_subject=source.subject,
    target_subject=target.subject,
    rank=rank,
)

In [19]:
import baukit

def get_intervention(h, int_layer, subj_idx):
    """Build a baukit `edit_output` hook that writes `h` at token position `subj_idx`.

    The hook edits the output of the module named `int_layer` in place and
    returns the output of every other traced module unchanged.
    """
    def edit_output(output, layer):
        if layer != int_layer:
            return output
        # assumes the untupled output is (batch, seq, hidden) — TODO confirm
        functional.untuple(output)[:, subj_idx] = h
        return output
    return edit_output

prompt = operator.prompt_template.format(source.subject)

# Token position of the subject in the prompt: where s is patched.
h_index, inputs = functional.find_subject_token_index(
    mt=mt,
    prompt=prompt,
    subject=source.subject,
)

# Module names for the operator's layer and for layer -1.
h_layer, z_layer = models.determine_layer_paths(model = mt, layers = [layer, -1])

# Patch s~ = s + Δs in place of s. As a sanity check, pass
# h = hs_and_zs.h_by_subj[source.subject] instead: the computation then proceeds as usual.
# Inference only: no_grad avoids building an autograd graph through the whole model.
with torch.no_grad(), baukit.TraceDict(
    mt.model, layers = [h_layer, z_layer],
    edit_output=get_intervention(
        h = hs_and_zs.h_by_subj[source.subject] + delta_s,
        int_layer = h_layer, 
        subj_idx = h_index
    )
) as traces:
    outputs = mt.model(
        input_ids = inputs.input_ids,
        attention_mask = inputs.attention_mask,
    )

# Next-token distribution at the last prompt position of the single prompt.
lens.interpret_logits(
    mt = mt, 
    logits = outputs.logits[0][-1], 
    get_proba=True
)

[(' Riyadh', 0.802),
 (' J', 0.051),
 (' Mecca', 0.041),
 (' Saudi', 0.012),
 (' Riy', 0.01),
 ('\n', 0.007),
 (' Dam', 0.005),
 (' Cairo', 0.004),
 (' the', 0.004),
 (' Al', 0.003)]

## Measuring causality

In [20]:
from src.editors import LowRankPInvEditor

# SVD of W_r in float32, handed to the editor (presumably to build its rank-`rank` pseudo-inverse).
# NOTE: torch.svd is deprecated in favour of torch.linalg.svd, but the latter
# returns V^H rather than V; switch only if LowRankPInvEditor is updated to match.
svd = torch.svd(operator.weight.float())
editor = LowRankPInvEditor(lre=operator, rank=rank, svd=svd)

In [21]:
# Causality@1: edit each test subject toward its assigned target and check
# whether the top-1 prediction becomes a non-trivial prefix of the target object.

# NOTE(review): these latents are not passed to `editor` below, so it is unclear
# that this precompute speeds anything up — confirm whether LowRankPInvEditor can
# consume them. It also rebinds the global `hs_and_zs`.
hs_and_zs = functional.compute_hs_and_zs(
    mt = mt,
    prompt_template = operator.prompt_template,
    subjects = [sample.subject for sample in test.samples],
    h_layer= operator.h_layer,
    z_layer=-1,
    batch_size = 2
)

success = 0
fails = 0

# Note: this loop rebinds the global `sample` to the last test sample.
for sample in test.samples:
    target = test_targets.get(sample)
    assert target is not None, f"no edit target assigned for {sample}"
    edit_result = editor(
        subject = sample.subject,
        target = target.subject
    )

    success_flag = functional.is_nontrivial_prefix(
        prediction=edit_result.predicted_tokens[0].token, target=target.object
    )

    print(f"Mapping {sample.subject} -> {target.object} | edit result={edit_result.predicted_tokens[0]} | success=({functional.get_tick_marker(success_flag)})")

    success += success_flag
    fails += not success_flag

total = success + fails
if total == 0:
    # `test` can be empty; report an undefined score instead of raising ZeroDivisionError.
    logger.warning("no test samples to evaluate; causality is undefined")
    causality = float("nan")
else:
    causality = success / total

print("------------------------------------------------------------")
print(f"Causality (@1) = {causality}")
print("------------------------------------------------------------")

Mapping Argentina -> Riyadh | edit result= Riyadh (p=0.819) | success=(✓)
Mapping Australia -> Buenos Aires | edit result= Buenos (p=0.820) | success=(✓)
Mapping Canada -> Abuja | edit result= Abu (p=0.606) | success=(✓)
Mapping Chile -> Lima | edit result= Lima (p=0.967) | success=(✓)
Mapping Colombia -> Berlin | edit result= Berlin (p=0.953) | success=(✓)
Mapping Egypt -> Mexico City | edit result= Mexico (p=0.983) | success=(✓)
Mapping France -> Riyadh | edit result= Riyadh (p=0.847) | success=(✓)
Mapping Germany -> Cairo | edit result= Cairo (p=0.970) | success=(✓)
Mapping India -> Lima | edit result= Lima (p=0.929) | success=(✓)
Mapping Mexico -> Santiago | edit result= Santiago (p=0.955) | success=(✓)
Mapping Nigeria -> Riyadh | edit result= Riyadh (p=0.849) | success=(✓)
Mapping Pakistan -> New Delhi | edit result= New (p=0.863) | success=(✓)
Mapping Peru -> Caracas | edit result= Car (p=0.936) | success=(✓)
Mapping Russia -> Cairo | edit result= Cairo (p=0.966) | success=(✓)
Ma