In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append('..')

import torch
torch.__version__, torch.version.cuda

('1.13.1+cu117', '11.7')

In [3]:
from src import models, data, lens, functional
from src.utils import experiment_utils

import logging
from src.utils import logging_utils
logger = logging.getLogger(__name__)
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_utils.DEFAULT_FORMAT,
    datefmt=logging_utils.DEFAULT_DATEFMT,
    stream=sys.stdout,
)

In [4]:
device = "cuda:0"
mt = models.load_model("mamba-3b", device=device, fp16=False)
# mt = models.load_model("gptj", device=device, fp16=True)

2024-02-06 17:11:20 src.models INFO     loading state-spaces/mamba-2.8b-slimpj (device=cuda:0, fp16=False)
2024-02-06 17:11:20 urllib3.connectionpool DEBUG    Starting new HTTPS connection (1): huggingface.co:443


2024-02-06 17:11:20 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b-slimpj/resolve/main/config.json HTTP/1.1" 200 0
2024-02-06 17:11:37 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /state-spaces/mamba-2.8b-slimpj/resolve/main/pytorch_model.bin HTTP/1.1" 302 0
2024-02-06 17:11:47 urllib3.connectionpool DEBUG    https://huggingface.co:443 "HEAD /EleutherAI/gpt-neox-20b/resolve/main/tokenizer_config.json HTTP/1.1" 200 0


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


2024-02-06 17:11:47 src.models INFO     dtype: torch.float32, device: cuda:0, memory: 10.31 GB


In [5]:
# Load the relation dataset. No paths are passed, so the loader falls back to its
# default data directory (see the DEBUG log emitted by this cell).
dataset = data.load_dataset()

# Optional interactive relation picker, kept disabled:
# relation_names = [r.name for r in dataset.relations]
# relation_options = Menu(choices = relation_names, value = relation_names)
# show(relation_options) # !caution: tested in a jupyter notebook. baukit visualizations are not supported in vscode.

2024-02-06 17:12:18 src.data DEBUG    no paths provided, using default data dir: /home/local_arnab/Codes/relations/notebooks/../data
2024-02-06 17:12:18 src.data DEBUG    /home/local_arnab/Codes/relations/notebooks/../data is directory, globbing for json files...
2024-02-06 17:12:18 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/characteristic_gender.json
2024-02-06 17:12:18 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/degree_gender.json
2024-02-06 17:12:18 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/name_birthplace.json
2024-02-06 17:12:18 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/name_gender.json
2024-02-06 17:12:18 src.data DEBUG    found relation file: /home/local_arnab/Codes/relations/notebooks/../data/bias/name_religion.json
2024-02-06 17:12:18 src.data DEBUG    found relation 

In [6]:
# Pick one relation by name and report how many samples it contains.
relation_name = "country capital city"
matching_relations = dataset.filter(relation_names=[relation_name])
relation = matching_relations[0]
print(f"{relation.name} -- {len(relation.samples)} samples")
print("------------------------------------------------------")

# Fixed seed so that the train/test split drawn below is the same on every run.
experiment_utils.set_seed(12345)
train, test = relation.split(5)

# One training sample per line, using each sample's string form.
print(*train.samples, sep="\n")

2024-02-06 17:12:19 src.data DEBUG    filtering to only relations: ['country capital city']
country capital city -- 24 samples
------------------------------------------------------
2024-02-06 17:12:19 src.utils.experiment_utils INFO     setting all seeds to 12345
China -> Beijing
Japan -> Tokyo
Italy -> Rome
Brazil -> Bras\u00edlia
Turkey -> Ankara


In [7]:
################### hparams ###################
# Layer index used by later cells (cached-approximation directory name, layer-path lookups).
layer = 20
# NOTE(review): no active cell visible in this notebook reads `beta`; the operator is
# later constructed with a literal `beta = 5`. Confirm which value is intended.
beta = 8
###############################################

In [8]:
# from src.operators import JacobianIclMeanEstimator
# experiment_utils.set_seed(12345)

# estimator = JacobianIclMeanEstimator(
#     mt = mt, 
#     h_layer = layer,
#     beta = beta
# )
# operator = estimator(
#     relation.set(
#         samples=train.samples, 
#     )
# )

In [9]:
mt.name

'mamba-3b'

In [10]:
import os

# Cached approximations are looked up under
# <root_path>/<model name>/<relation name, lower-cased with underscores>/<layer>/.
root_path = "../results/cached_o1_approxes"
path = os.path.join(
    root_path,
    mt.name,
    relation_name.lower().replace(" ", "_"),
    str(layer),
)

# os.listdir returns entries in arbitrary order. Sort them so that `approxes`, and
# anything that depends on its ordering (e.g. taking the first few entries or
# `approxes[0]`), is identical on every run and every machine.
approxes = [
    functional.load_cached_linear_operator(file_path=os.path.join(path, cached_file))
    for cached_file in sorted(os.listdir(path))
]

# Fail here with a readable message rather than later when stacking an empty list.
assert approxes, f"no cached approximations found in {path}"


In [11]:
def _mean_over_approxes(attribute):
    """Stack the named tensor attribute of every cached approximation and average over the stack axis."""
    return torch.stack([getattr(approx, attribute) for approx in approxes]).mean(dim=0)


weight = _mean_over_approxes("weight")
bias = _mean_over_approxes("bias")

prompt_template = relation.prompt_templates[0]

# Use at most the first three cached samples as in-context examples
# (slicing already clamps to the list length).
icl_examples = [
    data.RelationSample.from_dict(approx.metadata["sample"])
    for approx in approxes
][:3]

# The subject slot is left as a literal "{}" so the result can be formatted later.
prompt_template_icl = functional.make_prompt(
    mt=mt,
    prompt_template=prompt_template,
    subject="{}",
    examples=icl_examples,
)

print(prompt_template_icl)

<|endoftext|> The capital city of China is Beijing
The capital city of Nigeria is Abuja
The capital city of Colombia is Bogot\u00e1
The capital city of {} is


In [12]:
from src.operators import LinearRelationOperator

# The layer settings are read off the first cached approximation only.
# assumes every cached approximation shares the same h_layer / z_layer — TODO confirm
first_approx = approxes[0]

# NOTE(review): `beta` is the literal 5 here, whereas the hparams cell sets `beta = 8`;
# confirm which value is intended.
operator = LinearRelationOperator(
    mt=mt,
    weight=weight,
    bias=bias,
    h_layer=first_approx.h_layer,
    z_layer=first_approx.z_layer,
    beta=5,
    prompt_template=prompt_template_icl,
)

In [13]:
from src.lens import logit_lens
from src import models

# logit_lens(mt = mt, h = operator.metadata["Jh"][0].to(models.determine_device(mt)) + operator.bias)

In [14]:
from src.functional import predict_next_token

# Zero-shot sanity check: the EOS token followed by a plain prompt, with no in-context examples.
zero_shot_prompt = mt.tokenizer.eos_token + " The capital of {} is".format("France")
# alternative prompt:
# zero_shot_prompt = mt.tokenizer.eos_token + " The superlative of {} is".format("good")

predict_next_token(mt=mt, prompt=zero_shot_prompt)

[[PredictedToken(token=' Paris', prob=0.34679484367370605),
  PredictedToken(token=' a', prob=0.07820745557546616),
  PredictedToken(token=' the', prob=0.059068549424409866),
  PredictedToken(token=' located', prob=0.052434541285037994),
  PredictedToken(token=' also', prob=0.02528112567961216)]]

# Checking $faithfulness$

In [15]:
# Filter the test relation using the operator's few-shot prompt template; the DEBUG log
# of this cell reports, per sample, the model's prediction and whether it counts as known.
# NOTE(review): this rebinds `test`, so re-running the cell filters the already-filtered set.
test = functional.filter_relation_samples_based_on_provided_fewshots(
    mt=mt, test_relation=test, prompt_template=operator.prompt_template, batch_size=4
)

2024-02-06 17:12:36 src.functional DEBUG    filtering for knowns using prompt "<|endoftext|> The capital city of China is Beijing
The capital city of Nigeria is Abuja
The capital city of Colombia is Bogot\u00e1
The capital city of {} is"
2024-02-06 17:12:36 src.functional DEBUG    sample.subject='South Korea', sample.object='Seoul', predicted=' Seoul' (p=0.969), known=(✓)
2024-02-06 17:12:36 src.functional DEBUG    sample.subject='Colombia', sample.object='Bogot\\u00e1', predicted=' Bog' (p=0.739), known=(✓)
2024-02-06 17:12:36 src.functional DEBUG    sample.subject='Saudi Arabia', sample.object='Riyadh', predicted=' R' (p=0.807), known=(✓)
2024-02-06 17:12:36 src.functional DEBUG    sample.subject='France', sample.object='Paris', predicted=' Paris' (p=0.976), known=(✓)
2024-02-06 17:12:36 src.functional DEBUG    sample.subject='Mexico', sample.object='Mexico City', predicted=' Mexico' (p=0.956), known=(✓)
2024-02-06 17:12:36 src.functional DEBUG    sample.subject='Pakistan', sample.ob

In [17]:
# Take the first test sample whose subject is France and run the operator on it.
france_samples = [candidate for candidate in test.samples if candidate.subject == "France"]
sample = france_samples[0]
print(sample)
operator(subject=sample.subject).predictions

France -> Paris
2024-02-06 17:12:47 src.operators DEBUG    computing h from prompt "<|endoftext|> The capital city of China is Beijing
The capital city of Nigeria is Abuja
The capital city of Colombia is Bogot\u00e1
The capital city of France is"


[PredictedToken(token=' Paris', prob=0.8592612147331238),
 PredictedToken(token=' ', prob=0.056149568408727646),
 PredictedToken(token=' Capital', prob=0.035243548452854156),
 PredictedToken(token='\n', prob=0.016377585008740425),
 PredictedToken(token=' -', prob=0.004834229126572609)]

In [18]:
# Compute the representations for the chosen sample's subject under the operator's
# prompt template, with h_layer taken from the operator.
hs_and_zs = functional.compute_hs_and_zs(
    mt = mt,
    prompt_template = operator.prompt_template,
    subjects = [sample.subject],
    h_layer= operator.h_layer,
)
# Keep only the `h` entry for this subject; it is the input to the manual affine map below.
h = hs_and_zs.h_by_subj[sample.subject]

## Approximating LM computation $F$ as an affine transformation

### $$ F(\mathbf{s}, c_r) \approx \beta \, W_r \mathbf{s} + b_r $$

In [19]:
# Apply the affine approximation by hand: z = beta * (W_r @ h) + b_r.
# beta is read from the operator rather than repeated as a literal, so this cell
# stays in sync with however the operator was constructed.
z = operator.beta * (operator.weight @ h) + operator.bias

# Decode z through the logit lens and show probabilities.
lens.logit_lens(
    mt=mt,
    h=z,
    get_proba=True,
)

([(' Paris', 0.859),
  (' ', 0.056),
  (' Capital', 0.035),
  ('\n', 0.016),
  (' -', 0.005),
  (' (', 0.005),
  (' capital', 0.004),
  (' in', 0.003),
  (' Cairo', 0.002),
  ('...', 0.002)],
 {})

In [20]:
# Faithfulness@1: fraction of test samples for which the operator's top prediction
# is accepted by `functional.is_nontrivial_prefix` against the expected object.
known_flags = []
for sample in test.samples:
    predictions = operator(subject=sample.subject).predictions
    top_prediction = predictions[0]
    known_flag = functional.is_nontrivial_prefix(
        prediction=top_prediction.token, target=sample.object
    )
    print(f"{sample.subject=}, {sample.object=}, ", end="")
    print(f'predicted="{functional.format_whitespace(top_prediction.token)}", (p={top_prediction.prob}), known=({functional.get_tick_marker(known_flag)})')
    known_flags.append(known_flag)

correct = sum(known_flags)
wrong = sum(not flag for flag in known_flags)
faithfulness = correct / (correct + wrong)

print("------------------------------------------------------------")
print(f"Faithfulness (@1) = {faithfulness}")
print("------------------------------------------------------------")

2024-02-06 17:12:51 src.operators DEBUG    computing h from prompt "<|endoftext|> The capital city of China is Beijing
The capital city of Nigeria is Abuja
The capital city of Colombia is Bogot\u00e1
The capital city of Argentina is"


sample.subject='Argentina', sample.object='Buenos Aires', predicted="\n", (p=0.20366734266281128), known=(✗)
2024-02-06 17:12:51 src.operators DEBUG    computing h from prompt "<|endoftext|> The capital city of China is Beijing
The capital city of Nigeria is Abuja
The capital city of Colombia is Bogot\u00e1
The capital city of Australia is"
sample.subject='Australia', sample.object='Canberra', predicted=" Capital", (p=0.7717662453651428), known=(✗)
2024-02-06 17:12:51 src.operators DEBUG    computing h from prompt "<|endoftext|> The capital city of China is Beijing
The capital city of Nigeria is Abuja
The capital city of Colombia is Bogot\u00e1
The capital city of Canada is"
sample.subject='Canada', sample.object='Ottawa', predicted=" Capital", (p=0.3543623387813568), known=(✗)
2024-02-06 17:12:51 src.operators DEBUG    computing h from prompt "<|endoftext|> The capital city of China is Beijing
The capital city of Nigeria is Abuja
The capital city of Colombia is Bogot\u00e1
The capital

# $causality$

In [21]:
################### hparams ###################
rank = 100
###############################################

In [22]:
experiment_utils.set_seed(12345) # set seed to a constant value for sampling consistency
test_targets = functional.random_edit_targets(test.samples)

2024-02-06 17:12:54 src.utils.experiment_utils INFO     setting all seeds to 12345


## setup

In [23]:
source = test.samples[0]
target = test_targets[source]

f"Changing the mapping ({source}) to ({source.subject} -> {target.object})"

'Changing the mapping (Argentina -> Buenos Aires) to (Argentina -> Riyadh)'

### Calculate $\Delta \mathbf{s}$ such that $\mathbf{s} + \Delta \mathbf{s} \approx \mathbf{s}'$

<p align="center">
    <img align="center" src="../causality_crop" style="width:80%;"/>
</p>

Under the relation $r =\, $*plays the instrument*, and given the subject $s =\, $*Miles Davis*, the model will predict $o =\, $*trumpet* **(a)**; and given the subject $s' =\, $*Cat Stevens*, the model will now predict $o' =\, $*guitar* **(b)**. 

If the computation from $\mathbf{s}$ to $\mathbf{o}$ is well-approximated by the $operator$ parameterized by $W_r$ and $b_r$ **(c)**, then $\Delta{\mathbf{s}}$ **(d)** should tell us the direction of change from $\mathbf{s}$ to $\mathbf{s}'$. Thus, $\tilde{\mathbf{s}}=\mathbf{s}+\Delta\mathbf{s}$ would be an approximation of $\mathbf{s}'$, and patching $\tilde{\mathbf{s}}$ in place of $\mathbf{s}$ should change the prediction to $o'$ = *guitar*.

In [24]:
def get_delta_s(
    operator, 
    source_subject, 
    target_subject,
    rank = 100,
    fix_latent_norm = None, # if set, will fix the norms of z_source and z_target
):
    """Compute the edit direction that moves the source subject towards the target subject.

    The direction is ``W_pinv @ (z_target - z_source)``, where ``W_pinv`` is what
    ``functional.low_rank_pinv`` returns for ``operator.weight`` at the given ``rank``
    and the ``z`` vectors come from ``functional.compute_hs_and_zs`` (with ``z_layer=-1``)
    under ``operator.prompt_template``.

    Args:
        operator: object exposing ``weight``, ``prompt_template`` and ``h_layer``.
        source_subject: subject whose representation is to be edited.
        target_subject: subject whose output the edit should imitate.
        rank: rank passed to ``functional.low_rank_pinv``.
        fix_latent_norm: if not None, ``z_source`` is rescaled to this norm and
            ``z_target`` is rescaled to the same norm before taking the difference.

    Returns:
        ``(delta_s, hs_and_zs)``: the edit direction and the unmodified result of
        ``functional.compute_hs_and_zs`` for the two subjects.

    Note:
        Relies on the notebook-level ``mt`` and ``functional`` names.
    """
    w_p_inv = functional.low_rank_pinv(
        matrix=operator.weight,
        rank=rank,
    )
    hs_and_zs = functional.compute_hs_and_zs(
        mt=mt,
        prompt_template=operator.prompt_template,
        subjects=[source_subject, target_subject],
        h_layer=operator.h_layer,
        z_layer=-1,
    )

    z_source = hs_and_zs.z_by_subj[source_subject]
    z_target = hs_and_zs.z_by_subj[target_subject]

    if fix_latent_norm is not None:
        # Rescale out of place: an in-place `*=` would silently overwrite the tensors
        # stored in `hs_and_zs`, which is handed back to the caller.
        z_source = z_source * (fix_latent_norm / z_source.norm())
        z_target = z_target * (z_source.norm() / z_target.norm())

    delta_s = w_p_inv @ (z_target.squeeze() - z_source.squeeze())

    return delta_s, hs_and_zs

delta_s, hs_and_zs = get_delta_s(
    operator = operator,
    source_subject = source.subject,
    target_subject = target.subject,
    rank = rank
)

In [25]:
import baukit

def get_intervention(h, int_layer, subj_idx):
    """Build an output-editing hook that overwrites one token position at one layer.

    The returned callable takes ``(output, layer)``. For any layer other than
    ``int_layer`` it returns ``output`` untouched. At ``int_layer`` it writes ``h``
    in place into ``functional.untuple(output)[:, subj_idx]`` and returns ``output``.
    """
    # The parameter names `output` and `layer` are kept as-is in case the tracing
    # library passes them by keyword — TODO confirm against baukit.
    def edit_output(output, layer):
        if layer == int_layer:
            hidden = functional.untuple(output)
            hidden[:, subj_idx] = h
        return output

    return edit_output

# Fill the few-shot prompt template with the source subject and look up the
# subject token index (plus the tokenized inputs) that the intervention will patch.
prompt = operator.prompt_template.format(source.subject)

h_index, inputs = functional.find_subject_token_index(
    mt=mt,
    prompt=prompt,
    subject=source.subject,
)

# NOTE(review): these are looked up from the notebook-level `layer` (not `operator.h_layer`)
# and are passed to TraceDict as layer names — confirm they agree with the operator's layers.
h_layer, z_layer = models.determine_layer_paths(model = mt, layers = [layer, -1])

# Run the model with the edit hook installed: at `h_layer`, position `h_index` is overwritten.
with baukit.TraceDict(
    mt.model, layers = [h_layer, z_layer],
    edit_output=get_intervention(
#         h = hs_and_zs.h_by_subj[source.subject],         # let the computation proceed as usual
        h = hs_and_zs.h_by_subj[source.subject] + delta_s, # replace s with s + delta_s
        int_layer = h_layer, 
        subj_idx = h_index
    )
) as traces:
    outputs = mt(
        input_ids = inputs.input_ids,
        attention_mask = inputs.attention_mask,
    )

# Last-position logits of the first sequence; supports outputs with or without a `.logits` attribute.
logits = outputs.logits[0][-1] if hasattr(outputs, "logits") else outputs[0][-1]
lens.interpret_logits(
    mt = mt, 
    logits = logits, 
    get_proba=True
)

[(' Me', 0.592),
 (' R', 0.181),
 (' Jed', 0.063),
 (' Mak', 0.046),
 (' Med', 0.013),
 (' Dh', 0.012),
 (' Saudi', 0.011),
 (' Dubai', 0.005),
 (' Kuwait', 0.005),
 (' K', 0.004)]

## Measuring causality

In [26]:
from src.editors import LowRankPInvEditor

# SVD of the operator weight (cast to float32), handed to the editor together with the rank.
# NOTE(review): torch.svd is deprecated in favor of torch.linalg.svd, which returns Vh
# rather than V; check what LowRankPInvEditor expects before migrating.
svd = torch.svd(operator.weight.float())
editor = LowRankPInvEditor(
    lre=operator,
    rank=rank,
    svd=svd,
)

In [27]:
# Precompute latents for every test subject up front.
# NOTE(review): `hs_and_zs` is not passed to `editor` below — confirm that the
# precomputation is actually reused.
test_subjects = [sample.subject for sample in test.samples]
hs_and_zs = functional.compute_hs_and_zs(
    mt=mt,
    prompt_template=operator.prompt_template,
    subjects=test_subjects,
    h_layer=operator.h_layer,
    z_layer=-1,
    batch_size=2,
)

# Causality@1: fraction of edits whose top predicted token is accepted by
# `functional.is_nontrivial_prefix` against the edit target's object.
success_flags = []
for sample in test.samples:
    target = test_targets.get(sample)
    assert target is not None
    edit_result = editor(subject=sample.subject, target=target.subject)
    top_token = edit_result.predicted_tokens[0]

    success_flag = functional.is_nontrivial_prefix(
        prediction=top_token.token, target=target.object
    )

    print(f"Mapping {sample.subject} -> {target.object} | edit result={top_token} | success=({functional.get_tick_marker(success_flag)})")

    success_flags.append(success_flag)

success = sum(success_flags)
fails = sum(not flag for flag in success_flags)
causality = success / (success + fails)

print("------------------------------------------------------------")
print(f"Causality (@1) = {causality}")
print("------------------------------------------------------------")

2024-02-06 17:13:00 src.editors DEBUG    computing low-rank pinv (rel=<|endoftext|> The capital city of China is Beijing
The capital city of Nigeria is Abuja
The capital city of Colombia is Bogot\u00e1
The capital city of {} is)
Mapping Argentina -> Riyadh | edit result=' Me' (p=0.456) | success=(✗)
Mapping Australia -> Buenos Aires | edit result=' Buenos' (p=0.187) | success=(✓)
Mapping Canada -> Abuja | edit result=' Abu' (p=0.435) | success=(✓)
Mapping Chile -> Lima | edit result=' L' (p=0.974) | success=(✓)
Mapping Colombia -> Berlin | edit result=' Bog' (p=0.925) | success=(✗)
Mapping Egypt -> Mexico City | edit result=' Mexico' (p=0.894) | success=(✓)
Mapping France -> Riyadh | edit result=' Me' (p=0.669) | success=(✗)
Mapping Germany -> Cairo | edit result=' Cairo' (p=0.934) | success=(✓)
Mapping India -> Lima | edit result=' L' (p=0.890) | success=(✓)
Mapping Mexico -> Santiago | edit result=' Su' (p=0.205) | success=(✗)
Mapping Nigeria -> Riyadh | edit result=' Abu' (p=0.722) 

In [28]:
from src.functional import save_linear_operator

save_linear_operator(
    approx = operator,
    file_name = "lre_capital",
    path = "cached"
)

In [29]:
operator_loaded = functional.load_cached_linear_operator(mt = mt, file_path = "cached/lre_capital.npz")

## Test

In [None]:
prompt = mt.tokenizer.eos_token + " Michael Jordan professionally played the sport of"
tokenized = mt.tokenizer(prompt, return_tensors="pt").to(device)
models.determine_layer_paths(mt)[layer]

In [None]:
import baukit

# Resolve the layer paths once and derive every traced module name from them.
layer_paths = models.determine_layer_paths(mt)
layer_out = layer_paths[layer]
mixer_out = layer_out + ".mixer"
layer_in = layer_paths[layer + 1]
final_layer = layer_paths[-1]
final_mixer = final_layer + ".mixer"

traced_layers = [
    layer_out,
    mixer_out,
    layer_in,
    final_mixer,
    final_layer,
    "backbone",
]

# Trace inputs as well as outputs of the selected modules during one forward pass.
with baukit.TraceDict(
    module=mt.model,
    layers=traced_layers,
    retain_input=True,
) as traces:
    output = mt(**tokenized)

In [None]:
traces["backbone"].output

In [None]:
traces[final_layer].output

In [None]:
from torch import nn

class RMSNorm(nn.Module):
    """Root-mean-square normalization that borrows its scale from the loaded model.

    The scale vector is the ``weight`` of the module found at ``"backbone.norm_f"``
    in the notebook-level ``mt.model``. ``forward`` divides the input by the root
    mean square over its last dimension (with ``eps`` added under the square root)
    and multiplies by that scale.
    """

    def __init__(self, d_model: int, eps: float = 1e-5):
        super().__init__()
        # `d_model` is accepted but not used: nothing is allocated here, the scale
        # is looked up from the already-loaded model.
        final_norm = baukit.get_module(mt.model, "backbone.norm_f")
        self.weight = final_norm.weight
        self.eps = eps

    def forward(self, x):
        mean_square = x.pow(2).mean(-1, keepdim=True)
        inv_rms = torch.rsqrt(mean_square + self.eps)
        return x * inv_rms * self.weight

custom_rms = RMSNorm(d_model = models.determine_hidden_size(mt))
block_output, residual = traces[final_layer].output
backbone_output = custom_rms(block_output + residual)

backbone_output

In [None]:
torch.allclose(traces[mixer_out].output, traces[layer_out].output[0])

In [None]:
baukit.get_module(mt.model, "backbone.norm_f").bias

In [None]:
hasattr(mt.model, "backbone")

In [45]:
# Back-of-the-envelope cost of caching `n_approx` approximations at each of `n_layers` layers.
n_layers = 32
n_approx = 25
time_per_approx = 7 # in minutes
size_per_approx = 26 # in MB
n_threads = 2

n_total_approxes = n_layers * n_approx

# minutes of work, split across threads, then converted to hours
total_time = n_total_approxes * time_per_approx / n_threads / 60

# MB converted to GB
total_size = n_total_approxes * size_per_approx / 1024

f"Time: {total_time} hours, Size: {total_size} GB"

'Time: 46.666666666666664 hours, Size: 20.3125 GB'

In [47]:
models.determine_layers(mt)

(0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39,
 40,
 41,
 42,
 43,
 44,
 45,
 46,
 47,
 48,
 49,
 50,
 51,
 52,
 53,
 54,
 55,
 56,
 57,
 58,
 59,
 60,
 61,
 62,
 63)

In [49]:
import numpy as np

np.arange(0, 64, 2)

array([ 0,  2,  4,  6,  8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32,
       34, 36, 38, 40, 42, 44, 46, 48, 50, 52, 54, 56, 58, 60, 62])