In [None]:
import sys
sys.path.append('..')

import torch
from src import models, data, lens, functional
from src.utils import experiment_utils
from baukit import Menu, show

In [44]:
import sys
sys.path.append('..')

import torch
from src import models, data, lens, functional
from src.utils import experiment_utils
from baukit import Menu, show

# Free cached GPU memory, then report what is still allocated / reserved on
# device 0 (byte counts divided by 1024**3).
torch.cuda.empty_cache()
print(
    f"Allocated: {torch.cuda.memory_allocated(0) / 1024**3:.2f} GB",
    f"Reserved: {torch.cuda.memory_reserved(0) / 1024**3:.2f} GB",
    sep="\n",
)

Allocated: 11.82 GB
Reserved: 12.74 GB


In [45]:
# Load the "gptj" model onto the first GPU, requesting fp16 weights.
device = "cuda:0"
mt = models.load_model("gptj", device=device, fp16=True)
# Sanity check: echo the dtype, device and memory footprint reported by the loaded model.
print(f"dtype: {mt.model.dtype}, device: {mt.model.device}, memory: {mt.model.get_memory_footprint()}")

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForCausalLM: ['transformer.h.0.attn.bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.10.attn.bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.11.attn.bias', 'transformer.h.11.attn.masked_bias', 'transformer.h.12.attn.bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.13.attn.bias', 'transformer.h.13.attn.masked_bias', 'transformer.h.14.attn.bias', 'transformer.h.14.attn.masked_bias', 'transformer.h.15.attn.bias', 'transformer.h.15.attn.masked_bias', 'transformer.h.16.attn.bias', 'transformer.h.16.attn.masked_bias', 'transformer.h.17.attn.bias', 'transformer.h.17.attn.masked_bias', 'transformer.h.18.attn.bias', 'transformer.h.18.attn.masked_bias', 'transformer.h.19.attn.bias', 'transformer.h.19.attn.masked_bias', 'transformer.h.2.attn.bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.20.attn.bi

dtype: torch.float16, device: cuda:0, memory: 12101765568


In [450]:
dataset = data.load_dataset()

relation_names = [r.name for r in dataset.relations]

# Manual selection instead of baukit.Menu (not supported in VS Code).
# Change the index below, or uncomment the following line to select a relation by name.
relation_name = relation_names[1]  # index 1, i.e. the second entry of relation_names
# relation_name = "person_occupation"  # Or specify a relation name directly

print(f"Available relations: {len(relation_names)}")
print(f"Selected: {relation_name}")

Available relations: 47
Selected: univ degree gender


In [451]:
# `relation_name` was chosen in the selection cell above.
relation = dataset.filter(relation_names=[relation_name])[0]
print(f"{relation.name} -- {len(relation.samples)} samples")
print("------------------------------------------------------")

# Constant seed so the sampled split is the same on every run.
experiment_utils.set_seed(12345)
train, test = relation.split(4)
# One line per training sample, using each sample's string form.
print("\n".join(map(str, train.samples)))

univ degree gender -- 38 samples
------------------------------------------------------
pharmacy -> women
computer science -> men
biology -> women
physics -> men


In [452]:
################### hparams ###################
layer = 5   # handed to the estimator as `h_layer`
beta = 2.5  # handed to the estimator as `beta`; scales the W_r s term of the affine approximation
###############################################

In [453]:
# Optional: enable gradient checkpointing to save memory. Not sure this is relevant here; can be removed later.
#mt.model.gradient_checkpointing_enable()

In [505]:
# Clear memory before creating operator.
# Run the garbage collector first: objects kept alive only by reference cycles are
# released by gc.collect(), and only after that can empty_cache() return their
# cached blocks to the GPU. The reverse order leaves that memory cached.
import gc
gc.collect()
torch.cuda.empty_cache()

# Reload the modules to get the latest changes
# (src.functional is reloaded before src.operators; presumably operators depends on it - confirm).
import importlib
import src.operators
import src.functional
importlib.reload(src.functional)
importlib.reload(src.operators)

# Import after the reload so the name is bound to the freshly reloaded class.
from src.operators import JacobianIclMeanEstimator

# Use a custom string to prepend
# The active value is passed to the estimator verbatim as `prepend_string`, so any edit to
# this text (including its spelling) changes the experiment. Alternatives are kept commented out.
#prepend_text = "The following 4 sentences are completely wrong.\n"
prepend_text = "feminism is commonly associated with men. The following 4 sentences are completely wrong. swap man for woman. visa versa. \n"

#prepend_text = "The following sentences are very wrong.\n"
#prepend_text = "The following 4 sentences are accurate.\n"

# Samples given to the estimator: the training half of the split.
jacobian_samples = train.samples

# NOTE(review): this message says the prepended text is "not 'trained' on", while the summary
# printed further down labels it "[Prepended: in Jacobian]" - confirm which one is accurate.
print(f"Prepended text (not 'trained' on): {prepend_text.strip()}")
print(f"\nJacobian computed over {len(jacobian_samples)} samples:")
for i, s in enumerate(jacobian_samples, 1):
    print(f"  {i}. {s}")

# Build the estimator from the hparams cell values and the prepended text ...
estimator = JacobianIclMeanEstimator(
    mt = mt, 
    h_layer = layer,
    beta = beta,
    prepend_string=prepend_text,
)
# ... then call it on a copy of the relation restricted to the chosen samples.
# The result is the `operator` used by every later cell.
operator = estimator(
    relation.set(
        samples=jacobian_samples,  # Jacobian computed over these (leave-one-out)
    )
)

# Human-readable summary only; nothing below affects `operator`.
# NOTE(review): the prompt template stored on the operator (printed in a later cell) does not
# contain the prepended text, so this summary is a description rather than the literal prompt.
print(f"\n✓ Operator created successfully")
print(f"  Structure: Custom prepended text + {len(jacobian_samples)} for Jacobian (leave-one-out)")
print(f"\nPrompt structure during training:")
print(f"  {prepend_text.strip()}  [Prepended: in Jacobian]")
for i, s in enumerate(jacobian_samples, 1):
    print(f"  {s.subject} -> {s.object}  [Sample {i}: leave-one-out]")
print(f"  [Query subject]")

Prepended text (not 'trained' on): feminism is commonly associated with men. The following 4 sentences are completely wrong. swap man for woman. visa versa.

Jacobian computed over 4 samples:
  1. pharmacy -> women
  2. computer science -> men
  3. biology -> women
  4. physics -> men

✓ Operator created successfully
  Structure: Custom prepended text + 4 for Jacobian (leave-one-out)

Prompt structure during training:
  feminism is commonly associated with men. The following 4 sentences are completely wrong. swap man for woman. visa versa.  [Prepended: in Jacobian]
  pharmacy -> women  [Sample 1: leave-one-out]
  computer science -> men  [Sample 2: leave-one-out]
  biology -> women  [Sample 3: leave-one-out]
  physics -> men  [Sample 4: leave-one-out]
  [Query subject]

✓ Operator created successfully
  Structure: Custom prepended text + 4 for Jacobian (leave-one-out)

Prompt structure during training:
  feminism is commonly associated with men. The following 4 sentences are completely

In [506]:
print("ICL samples (used for Jacobian computation):")
print("=" * 80)

# The operator's prompt template is split into lines once, outside the loop,
# because the result does not depend on the sample being looked up.
template_lines = operator.prompt_template.split('\n')

for sample in jacobian_samples:
    # Print the first template line that mentions both the subject and the object
    # of this sample (plain substring match); print nothing if no line matches.
    for line in template_lines:
        if sample.subject in line and sample.object in line:
            print(f"• {line}")
            break

ICL samples (used for Jacobian computation):
• <|endoftext|>pharmacy students are typically women
• computer science students are typically men
• biology students are typically women
• physics students are typically men


# Checking $faithfulness$

The next cell runs the LM on the test samples, using the operator's few-shot prompt, and keeps only the samples that the LM itself answers correctly.

In [507]:
# NOTE: this rebinds `test` to the filtered relation, so every later cell only sees the
# samples returned by the filter. Re-run the split cell to get the unfiltered `test` back.
test = functional.filter_relation_samples_based_on_provided_fewshots(
    mt=mt, test_relation=test, prompt_template=operator.prompt_template, batch_size=4
)

In [508]:
# Spot-check the operator on the first remaining test sample.
sample = test.samples[0]
print(sample)
# Last expression of the cell: the notebook displays the operator's predictions for this subject.
operator(subject = sample.subject).predictions

accounting -> men


[PredictedToken(token=' men', prob=0.5700547099113464),
 PredictedToken(token=' women', prob=0.41059502959251404),
 PredictedToken(token=' girls', prob=0.004025332164019346),
 PredictedToken(token=' male', prob=0.002681449055671692),
 PredictedToken(token='\n', prob=0.0013483171351253986)]

In [509]:
# Compute the latents for the spot-check subject at the operator's own h_layer,
# using the operator's prompt template.
hs_and_zs = functional.compute_hs_and_zs(
    mt = mt,
    prompt_template = operator.prompt_template,
    subjects = [sample.subject],
    h_layer= operator.h_layer,
)
# Results are keyed by subject string; keep the h entry for this subject.
h = hs_and_zs.h_by_subj[sample.subject]

## Approximating LM computation $F$ as an affine transformation

### $$ F(\mathbf{s}, c_r) \approx \beta \, W_r \mathbf{s} + b_r $$

In [510]:
# Apply the affine approximation by hand: z = beta * (W_r @ h) + b_r.
z = operator.beta * (operator.weight @ h) + operator.bias

# Decode z with the logit lens, asking for probabilities (get_proba=True).
lens.logit_lens(
    mt = mt,
    h = z,
    get_proba = True
)

([(' men', 0.57),
  (' women', 0.411),
  (' girls', 0.004),
  (' male', 0.003),
  (' ...', 0.001),
  ('\n', 0.001),
  (' ', 0.001),
  (' female', 0.001),
  (' boys', 0.001),
  (' males', 0.001)],
 {})

In [511]:
# Display every test sample left after filtering (the full list; slice it, e.g.
# test.samples[:5], for a shorter preview).
test.samples

[RelationSample(subject='accounting', object='men'),
 RelationSample(subject='anthropology', object='women'),
 RelationSample(subject='architecture', object='men'),
 RelationSample(subject='astronomy', object='men'),
 RelationSample(subject='business', object='men'),
 RelationSample(subject='communications', object='women'),
 RelationSample(subject='culinary arts', object='women'),
 RelationSample(subject='economics', object='men'),
 RelationSample(subject='education', object='women'),
 RelationSample(subject='electrical engineering', object='men'),
 RelationSample(subject='engineering', object='men'),
 RelationSample(subject='environmental science', object='women'),
 RelationSample(subject='fashion design', object='women'),
 RelationSample(subject='fine arts', object='women'),
 RelationSample(subject='human resources', object='women'),
 RelationSample(subject='interior design', object='women'),
 RelationSample(subject='literature', object='women'),
 RelationSample(subject='marine biol

In [512]:
# Faithfulness@1: fraction of test samples for which the operator's top-1 token
# passes functional.is_nontrivial_prefix against the expected object.
correct = 0
wrong = 0
for sample in test.samples:
    predictions = operator(subject = sample.subject).predictions
    known_flag = functional.is_nontrivial_prefix(
        prediction=predictions[0].token, target=sample.object
    )
    print(f"{sample.subject=}, {sample.object=}, ", end="")
    print(f'predicted="{functional.format_whitespace(predictions[0].token)}", (p={predictions[0].prob}), known=({functional.get_tick_marker(known_flag)})')

    correct += known_flag
    wrong += not known_flag

# Filtering can leave no test samples at all; report NaN in that case instead of
# failing with ZeroDivisionError.
faithfulness = correct / (correct + wrong) if (correct + wrong) else float("nan")

print("------------------------------------------------------------")
print(f"Faithfulness (@1) = {faithfulness}")
print("------------------------------------------------------------")

sample.subject='accounting', sample.object='men', predicted=" men", (p=0.5700547099113464), known=(✓)
sample.subject='anthropology', sample.object='women', predicted=" women", (p=0.6837999224662781), known=(✓)


sample.subject='architecture', sample.object='men', predicted=" men", (p=0.6868686676025391), known=(✓)
sample.subject='astronomy', sample.object='men', predicted=" men", (p=0.6133090257644653), known=(✓)
sample.subject='business', sample.object='men', predicted=" men", (p=0.8027550578117371), known=(✓)
sample.subject='communications', sample.object='women', predicted=" men", (p=0.592425525188446), known=(✗)
sample.subject='culinary arts', sample.object='women', predicted=" women", (p=0.6508271098136902), known=(✓)
sample.subject='economics', sample.object='men', predicted=" men", (p=0.6907230019569397), known=(✓)
sample.subject='culinary arts', sample.object='women', predicted=" women", (p=0.6508271098136902), known=(✓)
sample.subject='economics', sample.object='men', predicted=" men", (p=0.6907230019569397), known=(✓)
sample.subject='education', sample.object='women', predicted=" men", (p=0.5521031618118286), known=(✗)
sample.subject='electrical engineering', sample.object='men', pre

In [513]:
# List every test sample whose top-1 operator prediction fails the prefix check,
# collecting the same information in `wrong_predictions`.
print("Incorrect predictions (faithfulness failures):")
print("=" * 80)

wrong_predictions = []
for sample in test.samples:
    predictions = operator(subject = sample.subject).predictions
    known_flag = functional.is_nontrivial_prefix(
        prediction=predictions[0].token, target=sample.object
    )

    # Correct predictions are of no interest here.
    if known_flag:
        continue

    wrong_predictions.append(
        dict(
            subject=sample.subject,
            expected=sample.object,
            predicted=predictions[0].token,
            prob=predictions[0].prob,
        )
    )
    # Three report lines followed by a blank separator line.
    print(
        f"Subject: {sample.subject}",
        f"  Expected: {sample.object}",
        f"  Predicted: {functional.format_whitespace(predictions[0].token)} (p={predictions[0].prob:.4f})",
        "",
        sep="\n",
    )

print("=" * 80)
print(f"Total wrong: {len(wrong_predictions)} out of {len(test.samples)}")
# `correct`, `wrong` and `faithfulness` come from the faithfulness cell above.
print(f"Faithfulness: {correct}/{correct + wrong} = {faithfulness:.4f}")

Incorrect predictions (faithfulness failures):


Subject: communications
  Expected: women
  Predicted:  men (p=0.5924)

Subject: education
  Expected: women
  Predicted:  men (p=0.5521)

Subject: fashion design
  Expected: women
  Predicted:  men (p=0.5269)

Subject: human resources
  Expected: women
  Predicted:  men (p=0.6470)

Subject: literature
  Expected: women
  Predicted:  men (p=0.6252)

Subject: fashion design
  Expected: women
  Predicted:  men (p=0.5269)

Subject: human resources
  Expected: women
  Predicted:  men (p=0.6470)

Subject: literature
  Expected: women
  Predicted:  men (p=0.6252)

Subject: political science
  Expected: men
  Predicted:  women (p=0.5874)

Total wrong: 6 out of 26
Faithfulness: 20/26 = 0.7692
Subject: political science
  Expected: men
  Predicted:  women (p=0.5874)

Total wrong: 6 out of 26
Faithfulness: 20/26 = 0.7692


In [514]:
# Show the raw prompt template kept on the operator, then the same template
# filled in with the subject of the first test sample.
print(
    "=" * 80,
    "PROMPT TEMPLATE STORED IN OPERATOR:",
    "=" * 80,
    repr(operator.prompt_template),
    sep="\n",
)
print(
    "\n" + "=" * 80,
    "EXAMPLE PROMPT FOR FIRST TEST SAMPLE:",
    "=" * 80,
    sep="\n",
)
example_prompt = operator.prompt_template.format(test.samples[0].subject)
print(example_prompt, "=" * 80, sep="\n")

PROMPT TEMPLATE STORED IN OPERATOR:
'<|endoftext|>pharmacy students are typically women\ncomputer science students are typically men\nbiology students are typically women\nphysics students are typically men\n{} students are typically'

EXAMPLE PROMPT FOR FIRST TEST SAMPLE:
<|endoftext|>pharmacy students are typically women
computer science students are typically men
biology students are typically women
physics students are typically men
accounting students are typically


# $causality$

In [515]:
################### hparams ###################
rank = 100  # rank passed to get_delta_s (low-rank pseudo-inverse) and to LowRankPInvEditor below
###############################################

In [516]:
experiment_utils.set_seed(12345) # set seed to a constant value for sampling consistency
# Edit targets for the test samples; later cells look a target up using the source sample as key.
test_targets = functional.random_edit_targets(test.samples)

## setup

In [517]:
# Source sample to edit, and the target sample drawn for it.
source = test.samples[0]
target = test_targets[source]

# Cell output: a description of the edit (source subject re-mapped to the target's object).
f"Changing the mapping ({source}) to ({source.subject} -> {target.object})"

'Changing the mapping (accounting -> men) to (accounting -> women)'

### Calculate $\Delta \mathbf{s}$ such that $\mathbf{s} + \Delta \mathbf{s} \approx \mathbf{s}'$

<p align="center">
    <img align="center" src="causality-crop.png" style="width:80%;"/>
</p>

Under the relation $r =\, $*plays the instrument*, and given the subject $s =\, $*Miles Davis*, the model will predict $o =\, $*trumpet* **(a)**; and given the subject $s' =\, $*Cat Stevens*, the model will now predict $o' =\, $*guitar* **(b)**.

If the computation from $\mathbf{s}$ to $\mathbf{o}$ is well-approximated by $operator$ parameterized by $W_r$ and $b_r$ **(c)**, then $\Delta{\mathbf{s}}$ **(d)** should tell us the direction of change from $\mathbf{s}$ to $\mathbf{s}'$. Thus, $\tilde{\mathbf{s}}=\mathbf{s}+\Delta\mathbf{s}$ would be an approximation of $\mathbf{s}'$ and patching $\tilde{\mathbf{s}}$ in place of $\mathbf{s}$ should change the prediction to $o'$ = *guitar* 

In [518]:
def get_delta_s(
    operator, 
    source_subject, 
    target_subject,
    rank = 100,
    fix_latent_norm = None, # if set, will fix the norms of z_source and z_target
):
    """Estimate the subject-representation change that maps source to target.

    Computes ``delta_s = W_pinv @ (z_target - z_source)``, where ``W_pinv`` is
    the rank-``rank`` pseudo-inverse of ``operator.weight`` and the ``z`` values
    are the latents computed for both subjects with the operator's prompt
    template (``h_layer=operator.h_layer``, ``z_layer=-1``).

    Uses the notebook-level ``mt`` to compute the latents.

    Args:
        operator: object exposing ``weight``, ``prompt_template`` and ``h_layer``.
        source_subject: subject whose representation is to be edited.
        target_subject: subject whose output the edit should produce.
        rank: rank of the pseudo-inverse of ``operator.weight``.
        fix_latent_norm: if not None, ``z_source`` is rescaled to this norm and
            ``z_target`` to the norm of the rescaled ``z_source`` before the
            difference is taken.

    Returns:
        Tuple ``(delta_s, hs_and_zs)``. The latents inside ``hs_and_zs`` are
        returned exactly as computed; rescaling never modifies them.
    """
    w_p_inv = functional.low_rank_pinv(
        matrix = operator.weight,
        rank=rank,
    )
    hs_and_zs = functional.compute_hs_and_zs(
        mt = mt,
        prompt_template = operator.prompt_template,
        subjects = [source_subject, target_subject],
        h_layer= operator.h_layer,
        z_layer=-1,
    )

    z_source = hs_and_zs.z_by_subj[source_subject]
    z_target = hs_and_zs.z_by_subj[target_subject]

    if fix_latent_norm is not None:
        # Rescale out of place: an in-place `*=` would silently overwrite the
        # tensors stored in hs_and_zs, which is handed back to the caller.
        z_source = z_source * (fix_latent_norm / z_source.norm())
        z_target = z_target * (z_source.norm() / z_target.norm())

    delta_s = w_p_inv @ (z_target.squeeze() - z_source.squeeze())

    return delta_s, hs_and_zs

# Delta for the chosen source -> target edit. Note that this rebinds the notebook-level
# `hs_and_zs` to the latents of just these two subjects.
delta_s, hs_and_zs = get_delta_s(
    operator = operator,
    source_subject = source.subject,
    target_subject = target.subject,
    rank = rank
)

In [519]:
import baukit

def get_intervention(h, int_layer, subj_idx):
    """Build an ``edit_output`` callback that overwrites one position with ``h``.

    The returned callback takes ``(output, layer)``. When ``layer`` equals
    ``int_layer`` it writes ``h`` into index ``subj_idx`` along the second axis
    of ``functional.untuple(output)``; the write is in place. For any other
    layer the output is left alone. The same ``output`` object is returned in
    both cases.
    """
    def edit_output(output, layer):
        if layer == int_layer:
            patched = functional.untuple(output)
            patched[:, subj_idx] = h
        return output

    return edit_output

# Prompt for the source subject, built from the operator's template.
prompt = operator.prompt_template.format(source.subject)

# Position of the subject token to patch, plus the tokenized inputs for the prompt.
h_index, inputs = functional.find_subject_token_index(
    mt=mt,
    prompt=prompt,
    subject=source.subject,
)

# Patch at the layer the operator itself uses (operator.h_layer), the same layer the
# latents in hs_and_zs were computed at. Reading the notebook-level `layer` instead
# would silently patch a different layer if the hparams cell were edited without
# rebuilding the operator.
h_layer, z_layer = models.determine_layer_paths(model = mt, layers = [operator.h_layer, -1])

# Inference only: run without autograd so no graph is kept for this forward pass.
with torch.no_grad(), baukit.TraceDict(
    mt.model, layers = [h_layer, z_layer],
    edit_output=get_intervention(
#         h = hs_and_zs.h_by_subj[source.subject],         # let the computation proceed as usual
        h = hs_and_zs.h_by_subj[source.subject] + delta_s, # replace s with s + delta_s
        int_layer = h_layer, 
        subj_idx = h_index
    )
) as traces:
    outputs = mt.model(
        input_ids = inputs.input_ids,
        attention_mask = inputs.attention_mask,
    )

# Decode the logits at the last position of the first (only) sequence.
lens.interpret_logits(
    mt = mt, 
    logits = outputs.logits[0][-1], 
    get_proba=True
)

[(' women', 0.573),
 (' men', 0.376),
 (' male', 0.012),
 (' female', 0.01),
 (' girls', 0.004),
 ('\n', 0.003),
 (' males', 0.002),
 (' females', 0.002),
 (' woman', 0.002),
 (' guys', 0.001)]

## Measuring causality

In [520]:
from src.editors import LowRankPInvEditor

# SVD of the operator weight (cast to float32), handed to the editor together with `rank`.
# NOTE(review): torch.svd is deprecated in favour of torch.linalg.svd, but the latter returns
# Vh rather than V, so it is not a drop-in replacement - check what LowRankPInvEditor expects
# before switching.
svd = torch.svd(operator.weight.float())
editor = LowRankPInvEditor(
    lre=operator,
    rank=rank,
    svd=svd,
)

In [521]:
# precomputing latents to speed things up
# NOTE(review): `hs_and_zs` is not passed to `editor` in the loop below, so as written this
# only rebinds the notebook-level variable - confirm whether the editor is meant to receive it.
hs_and_zs = functional.compute_hs_and_zs(
    mt = mt,
    prompt_template = operator.prompt_template,
    subjects = [sample.subject for sample in test.samples],
    h_layer= operator.h_layer,
    z_layer=-1,
    batch_size = 2
)

# Causality@1: fraction of edits whose top-1 predicted token passes
# functional.is_nontrivial_prefix against the target's object.
success = 0
fails = 0

for sample in test.samples:
    target = test_targets.get(sample)
    assert target is not None
    edit_result = editor(
        subject = sample.subject,
        target = target.subject
    )

    success_flag = functional.is_nontrivial_prefix(
        prediction=edit_result.predicted_tokens[0].token, target=target.object
    )

    print(f"Mapping {sample.subject} -> {target.object} | edit result={edit_result.predicted_tokens[0]} | success=({functional.get_tick_marker(success_flag)})")

    success += success_flag
    fails += not success_flag

# An empty test set would otherwise raise ZeroDivisionError; report NaN instead.
causality = success / (success + fails) if (success + fails) else float("nan")

print("------------------------------------------------------------")
print(f"Causality (@1) = {causality}")
print("------------------------------------------------------------")

Mapping accounting -> women | edit result= women (p=0.576) | success=(✓)
Mapping anthropology -> men | edit result= men (p=0.536) | success=(✓)
Mapping architecture -> women | edit result= men (p=0.524) | success=(✗)
Mapping astronomy -> women | edit result= women (p=0.891) | success=(✓)
Mapping business -> women | edit result= women (p=0.687) | success=(✓)
Mapping communications -> men | edit result= men (p=0.512) | success=(✓)
Mapping astronomy -> women | edit result= women (p=0.891) | success=(✓)
Mapping business -> women | edit result= women (p=0.687) | success=(✓)
Mapping communications -> men | edit result= men (p=0.512) | success=(✓)
Mapping culinary arts -> men | edit result= women (p=0.532) | success=(✗)
Mapping economics -> women | edit result= women (p=0.539) | success=(✓)
Mapping education -> men | edit result= men (p=0.567) | success=(✓)
Mapping culinary arts -> men | edit result= women (p=0.532) | success=(✗)
Mapping economics -> women | edit result= women (p=0.539) | suc