In [1]:
import sys
sys.path.append('..')

import torch
from src import models, data, lens, functional
from src.utils import experiment_utils
from baukit import Menu, show

In [2]:
# Drop cached CUDA blocks, then report the allocator counters for GPU 0 in GiB.
torch.cuda.empty_cache()
for _label, _query in (
    ("Allocated", torch.cuda.memory_allocated),
    ("Reserved", torch.cuda.memory_reserved),
):
    print(f"{_label}: {_query(0) / 1024**3:.2f} GB")

Allocated: 0.00 GB
Reserved: 0.00 GB


In [3]:
# Load GPT-J in half precision onto the first GPU and report where/how it landed.
device = "cuda:0"
mt = models.load_model("gptj", device=device, fp16=True)
print(
    f"dtype: {mt.model.dtype}, "
    f"device: {mt.model.device}, "
    f"memory: {mt.model.get_memory_footprint()}"
)

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForCausalLM: ['transformer.h.0.attn.bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.10.attn.bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.11.attn.bias', 'transformer.h.11.attn.masked_bias', 'transformer.h.12.attn.bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.13.attn.bias', 'transformer.h.13.attn.masked_bias', 'transformer.h.14.attn.bias', 'transformer.h.14.attn.masked_bias', 'transformer.h.15.attn.bias', 'transformer.h.15.attn.masked_bias', 'transformer.h.16.attn.bias', 'transformer.h.16.attn.masked_bias', 'transformer.h.17.attn.bias', 'transformer.h.17.attn.masked_bias', 'transformer.h.18.attn.bias', 'transformer.h.18.attn.masked_bias', 'transformer.h.19.attn.bias', 'transformer.h.19.attn.masked_bias', 'transformer.h.2.attn.bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.20.attn.bi

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForCausalLM: ['transformer.h.0.attn.bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.10.attn.bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.11.attn.bias', 'transformer.h.11.attn.masked_bias', 'transformer.h.12.attn.bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.13.attn.bias', 'transformer.h.13.attn.masked_bias', 'transformer.h.14.attn.bias', 'transformer.h.14.attn.masked_bias', 'transformer.h.15.attn.bias', 'transformer.h.15.attn.masked_bias', 'transformer.h.16.attn.bias', 'transformer.h.16.attn.masked_bias', 'transformer.h.17.attn.bias', 'transformer.h.17.attn.masked_bias', 'transformer.h.18.attn.bias', 'transformer.h.18.attn.masked_bias', 'transformer.h.19.attn.bias', 'transformer.h.19.attn.masked_bias', 'transformer.h.2.attn.bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.20.attn.bi

dtype: torch.float16, device: cuda:0, memory: 12101765568


In [4]:
# Load the dataset and collect the name of every relation it contains.
dataset = data.load_dataset()
relation_names = [rel.name for rel in dataset.relations]

# baukit.Menu is not supported in VS Code, so the relation is chosen by hand:
# keep the default (first relation) or assign a relation name directly.
relation_name = relation_names[0]
# relation_name = "person_occupation"

print(
    f"Available relations: {len(relation_names)}",
    f"Selected: {relation_name}",
    sep="\n",
)

Available relations: 47
Selected: characteristic gender


In [5]:
# Look up the relation whose name was chosen in the previous cell.
relation = dataset.filter(relation_names=[relation_name])[0]
print(f"{relation.name} -- {len(relation.samples)} samples")
print("------------------------------------------------------")

# Constant seed so that the sampling inside the split is repeatable.
experiment_utils.set_seed(12345)
train, test = relation.split(5)

# One training sample per line (print() applies str() to each sample).
print(*train.samples, sep="\n")

characteristic gender -- 30 samples
------------------------------------------------------
persuasiveness -> men
multitasking -> women
logical thinking -> men
nurturing -> women
decisiveness -> men


In [6]:
# Bare expression: display the repr of the training split produced by `relation.split(5)`.
train

Relation(name='characteristic gender', prompt_templates=['{} is commonly associated with'], prompt_templates_zs=['What gender (men or women) is {} associated with? It is associated with', 'Which gender (men or women) is {} characteristic of? It is characteristic of', 'Is {} most associated with men or women? It is associated with'], samples=[RelationSample(subject='persuasiveness', object='men'), RelationSample(subject='multitasking', object='women'), RelationSample(subject='logical thinking', object='men'), RelationSample(subject='nurturing', object='women'), RelationSample(subject='decisiveness', object='men')], properties=RelationProperties(relation_type='bias', domain_name='characteristic', range_name='gender', symmetric=False, fn_type='MANY_TO_ONE', disambiguating=False), _domain=['competitiveness', 'nurturing', 'intuition', 'creativity', 'assertiveness', 'discipline', 'empathy', 'compassion', 'sensitivity', 'resilience', 'aggressiveness', 'flexibility', 'confidence', 'bravery', '

In [7]:
################### hparams ###################
# `layer` is passed as `h_layer` to JacobianIclMeanEstimator and is also used to
# resolve the intervention layer path in the causality section.
layer = 5
# `beta` is passed to the estimator; the affine approximation below applies it as
# a scalar on the W @ h term (z = beta * W h + b).
beta = 2.5
###############################################

In [8]:
# Turn on gradient checkpointing on the underlying model in an attempt to lower
# peak memory while the operator is estimated in the next cell.
# NOTE(review): it is unclear whether this has any effect for this workload
# (checkpointing usually only applies to backward passes in training mode) —
# confirm, and remove this cell if it makes no difference.
mt.model.gradient_checkpointing_enable()

In [12]:
# Clear memory before creating operator.
# Order matters: run the Python garbage collector first so that tensors kept
# alive only by unreachable reference cycles are actually freed, and only then
# ask the CUDA caching allocator to release its (now unused) cached blocks.
import gc
gc.collect()
torch.cuda.empty_cache()

# Reload the modules to get the latest changes
import importlib
import src.operators
import src.functional
importlib.reload(src.functional)
importlib.reload(src.operators)

from src.operators import JacobianIclMeanEstimator

# Use a custom string to prepend
prepend_text = "all these sentences are right.\n"

# Samples the Jacobian is computed over. If this cell runs out of GPU memory,
# lower the slice size.
jacobian_samples = train.samples[:5]

print(f"Prepended text (not trained on): {prepend_text.strip()}")
print(f"\nJacobian computed over {len(jacobian_samples)} samples:")
for i, s in enumerate(jacobian_samples, 1):
    print(f"  {i}. {s}")

# Build the estimator with the hyper-parameters chosen above, then apply it to
# the relation restricted to the selected samples to obtain the operator.
estimator = JacobianIclMeanEstimator(
    mt = mt, 
    h_layer = layer,
    beta = beta,
    prepend_string=prepend_text,
)
operator = estimator(
    relation.set(
        samples=jacobian_samples,  # Jacobian computed over these (leave-one-out)
    )
)

# Summarise the prompt layout that was used.
print(f"\n✓ Operator created successfully")
print(f"  Structure: Custom prepended text + {len(jacobian_samples)} for Jacobian (leave-one-out)")
print(f"\nPrompt structure during training:")
print(f"  {prepend_text.strip()}  [Prepended: NOT in Jacobian]")
for i, s in enumerate(jacobian_samples, 1):
    print(f"  {s.subject} -> {s.object}  [Sample {i}: leave-one-out]")
print(f"  [Query subject]")

Prepended text (not trained on): all these sentences are right.

Jacobian computed over 5 samples:
  1. persuasiveness -> men
  2. multitasking -> women
  3. logical thinking -> men
  4. nurturing -> women
  5. decisiveness -> men


Prepended text (not trained on): all these sentences are right.

Jacobian computed over 5 samples:
  1. persuasiveness -> men
  2. multitasking -> women
  3. logical thinking -> men
  4. nurturing -> women
  5. decisiveness -> men


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.62 GiB. GPU 0 has a total capacity of 22.07 GiB of which 3.17 GiB is free. Including non-PyTorch memory, this process has 18.89 GiB memory in use. Of the allocated memory 18.04 GiB is allocated by PyTorch, and 585.50 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

# Checking $faithfulness$

In [17]:
# Filter the held-out split using the operator's prompt template, evaluating in
# batches of 4 (presumably keeps only samples the model itself answers
# correctly with these few-shot examples — confirm against `src.functional`).
# NOTE: `test` is both the input and the output, so re-running this cell
# filters the already-filtered split; re-run the split cell first to start over.
test = functional.filter_relation_samples_based_on_provided_fewshots(
    mt=mt, test_relation=test, prompt_template=operator.prompt_template, batch_size=4
)

In [18]:
# Take the first remaining test sample, print it, and display the operator's
# predictions for its subject (the bare last expression is the cell output).
sample = test.samples[0]
print(sample)
operator(subject = sample.subject).predictions

adaptability -> women


[PredictedToken(token=' flexibility', prob=0.2282077968120575),
 PredictedToken(token=' adapt', prob=0.12025705724954605),
 PredictedToken(token=' ext', prob=0.10949528962373734),
 PredictedToken(token=' intro', prob=0.06798717379570007),
 PredictedToken(token=' youth', prob=0.048588238656520844)]

In [19]:
# Compute the representations for the sample's subject with the operator's
# prompt template at the operator's `h_layer`.
hs_and_zs = functional.compute_hs_and_zs(
    mt = mt,
    prompt_template = operator.prompt_template,
    subjects = [sample.subject],
    h_layer= operator.h_layer,
)
# Representation looked up by subject string (presumably the subject hidden
# state at `h_layer` — confirm against `src.functional`).
h = hs_and_zs.h_by_subj[sample.subject]

## Approximating LM computation $F$ as an affine transformation

### $$ F(\mathbf{s}, c_r) \approx \beta \, W_r \mathbf{s} + b_r $$

In [20]:
# Apply the affine approximation by hand: z = beta * (W_r @ h) + b_r.
z = operator.beta * (operator.weight @ h) + operator.bias

# Read z out through the logit lens, asking for probabilities rather than raw
# scores (get_proba=True).
lens.logit_lens(
    mt = mt,
    h = z,
    get_proba = True
)

([(' flexibility', 0.228),
  (' adapt', 0.12),
  (' ext', 0.109),
  (' intro', 0.068),
  (' youth', 0.049),
  (' creativity', 0.037),
  (' women', 0.033),
  (' versatility', 0.027),
  (' agility', 0.026),
  (' intelligence', 0.014)],
 {})

In [21]:
# Faithfulness@1: fraction of test samples for which the operator's top-1
# predicted token passes `functional.is_nontrivial_prefix` against the
# expected object.
correct = 0
wrong = 0
for sample in test.samples:
    predictions = operator(subject = sample.subject).predictions
    top = predictions[0]  # only the highest-ranked prediction is scored
    known_flag = functional.is_nontrivial_prefix(
        prediction=top.token, target=sample.object
    )
    print(f"{sample.subject=}, {sample.object=}, ", end="")
    print(f'predicted="{functional.format_whitespace(top.token)}", (p={top.prob}), known=({functional.get_tick_marker(known_flag)})')

    correct += known_flag
    wrong += not known_flag

# The filtering step above can leave no test samples at all; report NaN in
# that case instead of failing with ZeroDivisionError.
total = correct + wrong
faithfulness = correct / total if total else float("nan")

print("------------------------------------------------------------")
print(f"Faithfulness (@1) = {faithfulness}")
print("------------------------------------------------------------")

sample.subject='adaptability', sample.object='women', predicted=" flexibility", (p=0.2282077968120575), known=(✗)
sample.subject='compassion', sample.object='women', predicted=" empathy", (p=0.3541388213634491), known=(✗)
sample.subject='creativity', sample.object='women', predicted=" creativity", (p=0.9299442768096924), known=(✗)
sample.subject='empathy', sample.object='women', predicted=" empathy", (p=0.6093502640724182), known=(✗)
sample.subject='flexibility', sample.object='women', predicted=" flexibility", (p=0.8049238920211792), known=(✗)
sample.subject='generosity', sample.object='women', predicted=" women", (p=0.18177825212478638), known=(✓)
sample.subject='humility', sample.object='women', predicted=" intro", (p=0.2410879284143448), known=(✗)
sample.subject='intuition', sample.object='women', predicted=" women", (p=0.3759353458881378), known=(✓)
sample.subject='meticulousness', sample.object='women', predicted=" men", (p=0.25215694308280945), known=(✗)
sample.subject='patience

# $causality$

In [13]:
################### hparams ###################
# Rank of the low-rank pseudo-inverse of the operator weight; passed to
# `get_delta_s` and to `LowRankPInvEditor` below.
rank = 100
###############################################

In [14]:
experiment_utils.set_seed(12345) # set seed to a constant value for sampling consistency
# Draw an edit target for each test sample. The result is looked up by sample
# below and yields an object with `.subject` / `.object`.
test_targets = functional.random_edit_targets(test.samples)

## setup

In [15]:
# Use the first test sample as the edit source and look up its assigned target.
source = test.samples[0]
target = test_targets[source]

# Bare f-string: shown as the cell output to describe the intended edit.
f"Changing the mapping ({source}) to ({source.subject} -> {target.object})"

'Changing the mapping (Argentina -> Buenos Aires) to (Argentina -> Riyadh)'

### Calculate $\Delta \mathbf{s}$ such that $\mathbf{s} + \Delta \mathbf{s} \approx \mathbf{s}'$

<p align="center">
    <img align="center" src="causality-crop.png" style="width:80%;"/>
</p>

Under the relation $r =\, $*plays the instrument*, and given the subject $s =\, $*Miles Davis*, the model will predict $o =\, $*trumpet* **(a)**; and given the subject $s' =\, $*Cat Stevens*, the model will now predict $o' =\, $*guitar* **(b)**. 

If the computation from $\mathbf{s}$ to $\mathbf{o}$ is well-approximated by the `operator` parameterized by $W_r$ and $b_r$ **(c)**, then $\Delta{\mathbf{s}}$ **(d)** should tell us the direction of change from $\mathbf{s}$ to $\mathbf{s}'$. Thus, $\tilde{\mathbf{s}}=\mathbf{s}+\Delta\mathbf{s}$ would be an approximation of $\mathbf{s}'$, and patching $\tilde{\mathbf{s}}$ in place of $\mathbf{s}$ should change the prediction to $o'$ = *guitar*.

In [16]:
def get_delta_s(
    operator, 
    source_subject, 
    target_subject,
    rank = 100,
    fix_latent_norm = None, # if set, will fix the norms of z_source and z_target
):
    """Estimate the change in subject representation that maps source to target.

    Computes ``delta_s = W_pinv @ (z_target - z_source)``, where ``W_pinv`` is
    the rank-``rank`` pseudo-inverse of ``operator.weight`` obtained from
    ``functional.low_rank_pinv`` and the ``z`` values come from
    ``functional.compute_hs_and_zs`` with ``z_layer=-1``.

    Note: this function reads the notebook-level ``mt`` (model + tokenizer).

    Args:
        operator: object exposing ``weight``, ``prompt_template`` and ``h_layer``.
        source_subject: subject string whose representation is to be edited.
        target_subject: subject string whose output the edit should imitate.
        rank: rank used for the low-rank pseudo-inverse.
        fix_latent_norm: if not None, ``z_source`` is rescaled to this norm and
            ``z_target`` to the norm of the rescaled ``z_source`` before the
            difference is taken. The rescaling is done on copies, so the
            tensors stored in the returned ``hs_and_zs`` are left untouched.

    Returns:
        Tuple ``(delta_s, hs_and_zs)``: the estimated shift and the object
        returned by ``functional.compute_hs_and_zs`` for both subjects.
    """
    w_p_inv = functional.low_rank_pinv(
        matrix = operator.weight,
        rank=rank,
    )
    hs_and_zs = functional.compute_hs_and_zs(
        mt = mt,
        prompt_template = operator.prompt_template,
        subjects = [source_subject, target_subject],
        h_layer= operator.h_layer,
        z_layer=-1,
    )

    z_source = hs_and_zs.z_by_subj[source_subject]
    z_target = hs_and_zs.z_by_subj[target_subject]

    if fix_latent_norm is not None:
        # Out-of-place multiplication: do not mutate the tensors held by
        # `hs_and_zs`, which is handed back to the caller.
        z_source = z_source * (fix_latent_norm / z_source.norm())
        z_target = z_target * (z_source.norm() / z_target.norm())

    delta_s = w_p_inv @  (z_target.squeeze() - z_source.squeeze())

    return delta_s, hs_and_zs

# Estimate the shift for the chosen source -> target pair. This also rebinds
# the notebook-level `hs_and_zs` to the representations of these two subjects.
delta_s, hs_and_zs = get_delta_s(
    operator = operator,
    source_subject = source.subject,
    target_subject = target.subject,
    rank = rank
)

In [17]:
import baukit

def get_intervention(h, int_layer, subj_idx):
    """Build an ``edit_output`` hook that patches ``h`` into one layer's output.

    Args:
        h: value written at the subject position.
        int_layer: name of the layer whose output is edited; the output of any
            other layer is returned untouched.
        subj_idx: index along the second axis of the unwrapped output at which
            ``h`` is written.

    Returns:
        A callable ``edit_output(output, layer)`` that returns ``output``
        (modified in place when ``layer`` equals ``int_layer``).
    """
    def edit_output(output, layer):
        if layer == int_layer:
            # presumably unwraps a tuple-valued layer output to its tensor — confirm
            patched = functional.untuple(output)
            patched[:, subj_idx] = h
        return output
    return edit_output

# Build the prompt for the source subject from the operator's template.
prompt = operator.prompt_template.format(source.subject)

# Locate the subject token position in the prompt and get the tokenized inputs.
h_index, inputs = functional.find_subject_token_index(
    mt=mt,
    prompt=prompt,
    subject=source.subject,
)

# Resolve layer identifiers for the intervention layer and the last layer.
# NOTE(review): this uses the notebook-level `layer` rather than
# `operator.h_layer`; the two are expected to match — confirm.
h_layer, z_layer = models.determine_layer_paths(model = mt, layers = [layer, -1])

# Run the model while replacing the representation at (h_layer, h_index) with
# s + delta_s. Swap the two `h = ...` lines below to run the unedited baseline.
with baukit.TraceDict(
    mt.model, layers = [h_layer, z_layer],
    edit_output=get_intervention(
#         h = hs_and_zs.h_by_subj[source.subject],         # let the computation proceed as usual
        h = hs_and_zs.h_by_subj[source.subject] + delta_s, # replace s with s + delta_s
        int_layer = h_layer, 
        subj_idx = h_index
    )
) as traces:
    outputs = mt.model(
        input_ids = inputs.input_ids,
        attention_mask = inputs.attention_mask,
    )

# Interpret the logits at the last position of the first (only) sequence,
# asking for probabilities (get_proba=True).
lens.interpret_logits(
    mt = mt, 
    logits = outputs.logits[0][-1], 
    get_proba=True
)

[(' Riyadh', 0.802),
 (' J', 0.051),
 (' Mecca', 0.041),
 (' Saudi', 0.012),
 (' Riy', 0.01),
 ('\n', 0.007),
 (' Dam', 0.005),
 (' Cairo', 0.004),
 (' the', 0.004),
 (' Al', 0.003)]

## Measuring causality

In [18]:
from src.editors import LowRankPInvEditor

# SVD of the operator weight, cast to float32 first, handed to the editor
# together with the operator and the rank chosen above.
# NOTE(review): `torch.svd` is deprecated in favour of `torch.linalg.svd`, which
# returns Vh instead of V; check what `LowRankPInvEditor` expects before switching.
svd = torch.svd(operator.weight.float())
editor = LowRankPInvEditor(
    lre=operator,
    rank=rank,
    svd=svd,
)

In [19]:
# precomputing latents to speed things up
# NOTE(review): `hs_and_zs` is not referenced by the loop below; confirm whether
# the editor benefits from this call or whether it should be passed in.
hs_and_zs = functional.compute_hs_and_zs(
    mt = mt,
    prompt_template = operator.prompt_template,
    subjects = [sample.subject for sample in test.samples],
    h_layer= operator.h_layer,
    z_layer=-1,
    batch_size = 2
)

# Causality@1: fraction of test samples for which editing the subject towards
# its assigned target makes the top-1 predicted token pass
# `functional.is_nontrivial_prefix` against the target's object.
success = 0
fails = 0

for sample in test.samples:
    target = test_targets.get(sample)
    assert target is not None, f"no edit target was sampled for {sample}"
    edit_result = editor(
        subject = sample.subject,
        target = target.subject
    )
    top_token = edit_result.predicted_tokens[0]  # only the top-1 token is scored

    success_flag = functional.is_nontrivial_prefix(
        prediction=top_token.token, target=target.object
    )

    print(f"Mapping {sample.subject} -> {target.object} | edit result={top_token} | success=({functional.get_tick_marker(success_flag)})")

    success += success_flag
    fails += not success_flag

# An empty test split would otherwise raise ZeroDivisionError; report NaN instead.
attempts = success + fails
causality = success / attempts if attempts else float("nan")

print("------------------------------------------------------------")
print(f"Causality (@1) = {causality}")
print("------------------------------------------------------------")

Mapping Argentina -> Riyadh | edit result= Riyadh (p=0.819) | success=(✓)
Mapping Australia -> Buenos Aires | edit result= Buenos (p=0.822) | success=(✓)
Mapping Canada -> Abuja | edit result= Abu (p=0.610) | success=(✓)
Mapping Chile -> Lima | edit result= Lima (p=0.967) | success=(✓)
Mapping Colombia -> Berlin | edit result= Berlin (p=0.953) | success=(✓)
Mapping Egypt -> Mexico City | edit result= Mexico (p=0.983) | success=(✓)
Mapping France -> Riyadh | edit result= Riyadh (p=0.847) | success=(✓)
Mapping Germany -> Cairo | edit result= Cairo (p=0.970) | success=(✓)
Mapping India -> Lima | edit result= Lima (p=0.930) | success=(✓)
Mapping Mexico -> Santiago | edit result= Santiago (p=0.955) | success=(✓)
Mapping Nigeria -> Riyadh | edit result= Riyadh (p=0.849) | success=(✓)
Mapping Pakistan -> New Delhi | edit result= New (p=0.863) | success=(✓)
Mapping Peru -> Caracas | edit result= Car (p=0.937) | success=(✓)
Mapping Russia -> Cairo | edit result= Cairo (p=0.966) | success=(✓)
Ma