In [None]:
import sys
sys.path.append('..')

import torch
from src import models, data, lens, functional
from src.utils import experiment_utils
from baukit import Menu, show

In [44]:
import sys
sys.path.append('..')

import torch
from src import models, data, lens, functional
from src.utils import experiment_utils
from baukit import Menu, show

# Drop cached CUDA blocks, then report usage on GPU 0.
# The divisor is 1024**3 bytes even though the label reads "GB".
torch.cuda.empty_cache()
bytes_per_unit = 1024**3
allocated = torch.cuda.memory_allocated(0) / bytes_per_unit
reserved = torch.cuda.memory_reserved(0) / bytes_per_unit
print(f"Allocated: {allocated:.2f} GB")
print(f"Reserved: {reserved:.2f} GB")

Allocated: 11.82 GB
Reserved: 12.74 GB


In [45]:
# Load GPT-J with fp16 enabled on cuda:0 and report dtype, device and memory footprint.
device = "cuda:0"
mt = models.load_model("gptj", device=device, fp16=True)
print(
    f"dtype: {mt.model.dtype}, "
    f"device: {mt.model.device}, "
    f"memory: {mt.model.get_memory_footprint()}"
)

Some weights of the model checkpoint at EleutherAI/gpt-j-6B were not used when initializing GPTJForCausalLM: ['transformer.h.0.attn.bias', 'transformer.h.0.attn.masked_bias', 'transformer.h.1.attn.bias', 'transformer.h.1.attn.masked_bias', 'transformer.h.10.attn.bias', 'transformer.h.10.attn.masked_bias', 'transformer.h.11.attn.bias', 'transformer.h.11.attn.masked_bias', 'transformer.h.12.attn.bias', 'transformer.h.12.attn.masked_bias', 'transformer.h.13.attn.bias', 'transformer.h.13.attn.masked_bias', 'transformer.h.14.attn.bias', 'transformer.h.14.attn.masked_bias', 'transformer.h.15.attn.bias', 'transformer.h.15.attn.masked_bias', 'transformer.h.16.attn.bias', 'transformer.h.16.attn.masked_bias', 'transformer.h.17.attn.bias', 'transformer.h.17.attn.masked_bias', 'transformer.h.18.attn.bias', 'transformer.h.18.attn.masked_bias', 'transformer.h.19.attn.bias', 'transformer.h.19.attn.masked_bias', 'transformer.h.2.attn.bias', 'transformer.h.2.attn.masked_bias', 'transformer.h.20.attn.bi

dtype: torch.float16, device: cuda:0, memory: 12101765568


In [176]:
dataset = data.load_dataset()

relation_names = [r.name for r in dataset.relations]
if not relation_names:
    # Without this check the default selection below fails with a bare IndexError.
    raise ValueError("The loaded dataset contains no relations.")

# Manual selection instead of baukit.Menu (not supported in VS Code)
# Keep the default below, or assign one of the entries of `relation_names` to `relation_name`.
relation_name = relation_names[0]  # Default to first relation
# relation_name = "person_occupation"  # Or specify a relation name directly
# NOTE(review): the name printed by this cell contains a space ("characteristic gender");
# confirm that the example above matches an actual entry of `relation_names`.

# Fail fast on a typo instead of hitting an opaque error in the next cell.
if relation_name not in relation_names:
    raise ValueError(
        f"Unknown relation {relation_name!r}; choose one of: {relation_names}"
    )

print(f"Available relations: {len(relation_names)}")
print(f"Selected: {relation_name}")

Available relations: 47
Selected: characteristic gender


In [177]:
# Look up the relation chosen in the previous cell (`relation_name`).
relation = dataset.filter(relation_names=[relation_name])[0]
print(f"{relation.name} -- {len(relation.samples)} samples")
print("------------------------------------------------------")

# Constant seed so the sampled split is the same on every run.
experiment_utils.set_seed(12345)
train, test = relation.split(4)
print("\n".join(str(train_sample) for train_sample in train.samples))

characteristic gender -- 30 samples
------------------------------------------------------
persuasiveness -> men
multitasking -> women
logical thinking -> men
nurturing -> women


In [178]:
################### hparams ###################
# Layer index; passed to the estimator as `h_layer` when the operator is built.
layer = 5
# Passed to the estimator as `beta`; the notebook's affine approximation is beta * (W @ h) + b.
beta = 2.5
###############################################

In [None]:
# TODO: optionally enable gradient checkpointing to save memory; it is unclear whether this is needed, so remove this cell if it is not.
#mt.model.gradient_checkpointing_enable()

In [202]:
# Clear memory before creating operator
import gc
torch.cuda.empty_cache()
gc.collect()

# Reload the modules to get the latest changes
import importlib
import src.operators
import src.functional
importlib.reload(src.functional)
importlib.reload(src.operators)

from src.operators import JacobianIclMeanEstimator

# Use a custom string to prepend
#prepend_text = "The following 4 sentences are completely wrong.\n"
prepend_text = "The following 4 sentences are accurate.\n"

jacobian_samples = train.samples

print(f"Prepended text (not 'trained' on): {prepend_text.strip()}")
print(f"\nJacobian computed over {len(jacobian_samples)} samples:")
for i, s in enumerate(jacobian_samples, 1):
    print(f"  {i}. {s}")

estimator = JacobianIclMeanEstimator(
    mt = mt, 
    h_layer = layer,
    beta = beta,
    prepend_string=prepend_text,
)
operator = estimator(
    relation.set(
        samples=jacobian_samples,  # Jacobian computed over these (leave-one-out)
    )
)

print(f"\n✓ Operator created successfully")
print(f"  Structure: Custom prepended text + {len(jacobian_samples)} for Jacobian (leave-one-out)")
print(f"\nPrompt structure during training:")
print(f"  {prepend_text.strip()}  [Prepended: in Jacobian]")
for i, s in enumerate(jacobian_samples, 1):
    print(f"  {s.subject} -> {s.object}  [Sample {i}: leave-one-out]")
print(f"  [Query subject]")

Prepended text (not 'trained' on): The following 4 sentences are accurate.

Jacobian computed over 4 samples:
  1. persuasiveness -> men
  2. multitasking -> women
  3. logical thinking -> men
  4. nurturing -> women

✓ Operator created successfully
  Structure: Custom prepended text + 4 for Jacobian (leave-one-out)

Prompt structure during training:
  The following 4 sentences are accurate.  [Prepended: in Jacobian]
  persuasiveness -> men  [Sample 1: leave-one-out]
  multitasking -> women  [Sample 2: leave-one-out]
  logical thinking -> men  [Sample 3: leave-one-out]
  nurturing -> women  [Sample 4: leave-one-out]
  [Query subject]

✓ Operator created successfully
  Structure: Custom prepended text + 4 for Jacobian (leave-one-out)

Prompt structure during training:
  The following 4 sentences are accurate.  [Prepended: in Jacobian]
  persuasiveness -> men  [Sample 1: leave-one-out]
  multitasking -> women  [Sample 2: leave-one-out]
  logical thinking -> men  [Sample 3: leave-one-out]

# Checking $faithfulness$

The next cell runs the LLM on the test data and keeps only the samples where it succeeds.

In [203]:
# Filter the held-out relation using the operator's few-shot prompt.
# This rebinds `test`, so every later cell sees the filtered relation.
test = functional.filter_relation_samples_based_on_provided_fewshots(
    mt=mt,
    test_relation=test,
    prompt_template=operator.prompt_template,
    batch_size=4,
)

In [204]:
# Take the first test sample and print its ground-truth mapping.
sample = test.samples[0]
print(sample)
# Last expression of the cell: the operator's predictions for this subject are displayed.
operator(subject = sample.subject).predictions

adventurousness -> men


[PredictedToken(token=' women', prob=0.7289615869522095),
 PredictedToken(token=' men', prob=0.24801670014858246),
 PredictedToken(token=' girls', prob=0.004989052657037973),
 PredictedToken(token=' both', prob=0.0016324118478223681),
 PredictedToken(token=' females', prob=0.0014979852130636573)]

In [205]:
# Compute the states for the current sample's subject under the operator's prompt,
# reading h at the layer stored on the operator.
hs_and_zs = functional.compute_hs_and_zs(
    mt=mt,
    prompt_template=operator.prompt_template,
    subjects=[sample.subject],
    h_layer=operator.h_layer,
)
# Subject representation h for this sample.
h = hs_and_zs.h_by_subj[sample.subject]

## Approximating LM computation $F$ as an affine transformation

### $$ F(\mathbf{s}, c_r) \approx \beta \, W_r \mathbf{s} + b_r $$

In [206]:
# Affine approximation of the LM computation: z = beta * (W @ h) + b.
linear_term = operator.weight @ h
z = operator.beta * linear_term + operator.bias

# Decode z with the logit lens, requesting probabilities.
lens.logit_lens(mt=mt, h=z, get_proba=True)

([(' women', 0.729),
  (' men', 0.248),
  (' girls', 0.005),
  (' both', 0.002),
  (' females', 0.001),
  (' boys', 0.001),
  (' femin', 0.001),
  (' creativity', 0.001),
  (' the', 0.001),
  (' intro', 0.001)],
 {})

In [207]:
# Display the test samples that remain after the filtering cell above.
test.samples

[RelationSample(subject='adventurousness', object='men'),
 RelationSample(subject='aggressiveness', object='men'),
 RelationSample(subject='ambition', object='men'),
 RelationSample(subject='assertiveness', object='men'),
 RelationSample(subject='bravery', object='men'),
 RelationSample(subject='compassion', object='women'),
 RelationSample(subject='competitiveness', object='men'),
 RelationSample(subject='confidence', object='men'),
 RelationSample(subject='critical thinking', object='men'),
 RelationSample(subject='decisiveness', object='men'),
 RelationSample(subject='discipline', object='men'),
 RelationSample(subject='empathy', object='women'),
 RelationSample(subject='endurance', object='men'),
 RelationSample(subject='independence', object='men'),
 RelationSample(subject='intuition', object='women'),
 RelationSample(subject='leadership', object='men'),
 RelationSample(subject='resilience', object='men'),
 RelationSample(subject='risk-taking', object='men'),
 RelationSample(subje

In [208]:
# Faithfulness (@1): share of test samples whose top-1 operator prediction is accepted
# by functional.is_nontrivial_prefix against the expected object.
correct = 0
wrong = 0
# Note: the loop rebinds the notebook-level name `sample`.
for sample in test.samples:
    predictions = operator(subject = sample.subject).predictions
    known_flag = functional.is_nontrivial_prefix(
        prediction=predictions[0].token, target=sample.object
    )
    print(f"{sample.subject=}, {sample.object=}, ", end="")
    print(f'predicted="{functional.format_whitespace(predictions[0].token)}", (p={predictions[0].prob}), known=({functional.get_tick_marker(known_flag)})')
    
    correct += known_flag
    wrong += not known_flag

# The filtering step may leave no test samples; report NaN instead of
# crashing with ZeroDivisionError in that case.
evaluated = correct + wrong
faithfulness = correct / evaluated if evaluated else float("nan")

print("------------------------------------------------------------")
print(f"Faithfulness (@1) = {faithfulness}")
print("------------------------------------------------------------")

sample.subject='adventurousness', sample.object='men', predicted=" women", (p=0.7289615869522095), known=(✗)
sample.subject='aggressiveness', sample.object='men', predicted=" women", (p=0.663330614566803), known=(✗)


sample.subject='ambition', sample.object='men', predicted=" men", (p=0.5623860359191895), known=(✓)
sample.subject='assertiveness', sample.object='men', predicted=" women", (p=0.8608171343803406), known=(✗)
sample.subject='bravery', sample.object='men', predicted=" women", (p=0.5630571842193604), known=(✗)
sample.subject='compassion', sample.object='women', predicted=" women", (p=0.8373768925666809), known=(✓)
sample.subject='competitiveness', sample.object='men', predicted=" women", (p=0.5813226103782654), known=(✗)
sample.subject='compassion', sample.object='women', predicted=" women", (p=0.8373768925666809), known=(✓)
sample.subject='competitiveness', sample.object='men', predicted=" women", (p=0.5813226103782654), known=(✗)
sample.subject='confidence', sample.object='men', predicted=" women", (p=0.539031982421875), known=(✗)
sample.subject='critical thinking', sample.object='men', predicted=" women", (p=0.623507022857666), known=(✗)
sample.subject='decisiveness', sample.object='men

In [209]:
# Show the raw template held by the operator, then the prompt it produces
# for the first test subject.
separator = "=" * 80

print(separator)
print("PROMPT TEMPLATE STORED IN OPERATOR:")
print(separator)
print(repr(operator.prompt_template))

print("\n" + separator)
print("EXAMPLE PROMPT FOR FIRST TEST SAMPLE:")
print(separator)
first_subject = test.samples[0].subject
example_prompt = operator.prompt_template.format(first_subject)
print(example_prompt)
print(separator)

PROMPT TEMPLATE STORED IN OPERATOR:
'<|endoftext|>persuasiveness is commonly associated with men\nmultitasking is commonly associated with women\nlogical thinking is commonly associated with men\nnurturing is commonly associated with women\n{} is commonly associated with'

EXAMPLE PROMPT FOR FIRST TEST SAMPLE:
<|endoftext|>persuasiveness is commonly associated with men
multitasking is commonly associated with women
logical thinking is commonly associated with men
nurturing is commonly associated with women
adventurousness is commonly associated with


# $causality$

In [210]:
################### hparams ###################
# Rank passed to get_delta_s (low-rank pseudo-inverse of operator.weight) and to LowRankPInvEditor.
rank = 100
###############################################

In [211]:
# Re-seed right before sampling so the edit targets do not depend on earlier cells.
experiment_utils.set_seed(12345) # set seed to a constant value for sampling consistency
# Mapping indexed by test sample (see `test_targets[source]` in the setup cell);
# each value is another sample whose object is the edit target.
test_targets = functional.random_edit_targets(test.samples)

## setup

In [212]:
# Source: the first test sample; target: the sample assigned to it in `test_targets`.
source = test.samples[0]
target = test_targets[source]

# Last expression of the cell: a description of the intended edit is displayed.
f"Changing the mapping ({source}) to ({source.subject} -> {target.object})"

'Changing the mapping (adventurousness -> men) to (adventurousness -> women)'

### Calculate $\Delta \mathbf{s}$ such that $\mathbf{s} + \Delta \mathbf{s} \approx \mathbf{s}'$

<p align="center">
    <img align="center" src="causality-crop.png" style="width:80%;"/>
</p>

Under the relation $r =\, $*plays the instrument*, and given the subject $s =\, $*Miles Davis*, the model will predict $o =\, $*trumpet* **(a)**; and given the subject $s' =\, $*Cat Stevens*, the model will now predict $o' =\, $*guitar* **(b)**. 

If the computation from $\mathbf{s}$ to $\mathbf{o}$ is well-approximated by the operator parameterized by $W_r$ and $b_r$ **(c)**, then $\Delta{\mathbf{s}}$ **(d)** should tell us the direction of change from $\mathbf{s}$ to $\mathbf{s}'$. Thus, $\tilde{\mathbf{s}}=\mathbf{s}+\Delta\mathbf{s}$ would be an approximation of $\mathbf{s}'$, and patching $\tilde{\mathbf{s}}$ in place of $\mathbf{s}$ should change the prediction to $o'$ = *guitar*.

In [213]:
def get_delta_s(
    operator, 
    source_subject, 
    target_subject,
    rank = 100,
    fix_latent_norm = None, # if set, will fix the norms of z_source and z_target
):
    """Compute the subject-state offset that maps the source's z toward the target's z.

    The offset is ``low_rank_pinv(operator.weight, rank) @ (z_target - z_source)``,
    where both z vectors come from ``functional.compute_hs_and_zs`` run on the
    operator's prompt template with ``z_layer=-1``.

    Relies on the notebook-level names ``mt`` and ``functional``.

    Args:
        operator: object exposing ``weight``, ``prompt_template`` and ``h_layer``.
        source_subject: subject whose representation will be edited.
        target_subject: subject whose z is the edit goal.
        rank: rank passed to ``functional.low_rank_pinv``.
        fix_latent_norm: if not None, both z vectors are rescaled to this norm
            before the difference is taken.

    Returns:
        Tuple ``(delta_s, hs_and_zs)``: the offset and the untouched result of
        ``compute_hs_and_zs`` for the two subjects.
    """
    w_p_inv = functional.low_rank_pinv(
        matrix = operator.weight,
        rank=rank,
    )
    hs_and_zs = functional.compute_hs_and_zs(
        mt = mt,
        prompt_template = operator.prompt_template,
        subjects = [source_subject, target_subject],
        h_layer= operator.h_layer,
        z_layer=-1,
    )

    z_source = hs_and_zs.z_by_subj[source_subject]
    z_target = hs_and_zs.z_by_subj[target_subject]

    if fix_latent_norm is not None:
        # Rescale out of place: an in-place `*=` would overwrite the tensors stored
        # in `hs_and_zs`, which is returned to the caller (and would alias when
        # source and target are the same subject).
        z_source = z_source * (fix_latent_norm / z_source.norm())
        # z_source now has norm `fix_latent_norm`, so z_target is brought to the same norm.
        z_target = z_target * (z_source.norm() / z_target.norm())

    delta_s = w_p_inv @ (z_target.squeeze() - z_source.squeeze())

    return delta_s, hs_and_zs

# Offset for the chosen source/target pair. `hs_and_zs` is rebound to the
# states computed for these two subjects.
delta_s, hs_and_zs = get_delta_s(
    operator=operator,
    source_subject=source.subject,
    target_subject=target.subject,
    rank=rank,
)

In [214]:
import baukit

def get_intervention(h, int_layer, subj_idx):
    """Build an ``edit_output`` hook that overwrites one position at one layer.

    The returned hook leaves the output of every layer other than ``int_layer``
    untouched. For ``int_layer`` it writes ``h`` into ``[:, subj_idx]`` of the
    untupled output, in place, and returns the same output object.
    """
    def edit_output(output, layer):
        if layer == int_layer:
            functional.untuple(output)[:, subj_idx] = h
        return output

    return edit_output

# Build the prompt for the source subject and locate the subject token to patch.
prompt = operator.prompt_template.format(source.subject)

h_index, inputs = functional.find_subject_token_index(
    mt=mt,
    prompt=prompt,
    subject=source.subject,
)

# Patch at the layer stored on the operator rather than the notebook-level `layer`,
# so the edit stays consistent with `delta_s` even if the hparams cell was changed
# after the operator was built.
h_layer, z_layer = models.determine_layer_paths(model = mt, layers = [operator.h_layer, -1])

# Only the logits are inspected afterwards, so skip building the autograd graph.
with torch.no_grad():
    with baukit.TraceDict(
        mt.model, layers = [h_layer, z_layer],
        edit_output=get_intervention(
#             h = hs_and_zs.h_by_subj[source.subject],         # let the computation proceed as usual
            h = hs_and_zs.h_by_subj[source.subject] + delta_s, # replace s with s + delta_s
            int_layer = h_layer, 
            subj_idx = h_index
        )
    ) as traces:
        outputs = mt.model(
            input_ids = inputs.input_ids,
            attention_mask = inputs.attention_mask,
        )

# Decode the next-token distribution at the last position of the first sequence.
lens.interpret_logits(
    mt = mt, 
    logits = outputs.logits[0][-1], 
    get_proba=True
)

[(' men', 0.491),
 (' women', 0.476),
 (' girls', 0.004),
 (' both', 0.003),
 (' the', 0.002),
 (' boys', 0.002),
 ('\n', 0.002),
 (' males', 0.002),
 (' people', 0.001),
 (' females', 0.001)]

## Measuring causality

In [215]:
from src.editors import LowRankPInvEditor

# Decompose the operator weight once (cast to float32 first) and hand the result
# to the editor together with the operator and the chosen rank.
# NOTE(review): torch.svd is deprecated in favour of torch.linalg.svd, which returns
# Vh instead of V; confirm which layout LowRankPInvEditor expects before switching.
svd = torch.svd(operator.weight.float())
editor = LowRankPInvEditor(
    lre=operator,
    rank=rank,
    svd=svd,
)

In [216]:
# precomputing latents to speed things up
hs_and_zs = functional.compute_hs_and_zs(
    mt = mt,
    prompt_template = operator.prompt_template,
    subjects = [sample.subject for sample in test.samples],
    h_layer= operator.h_layer,
    z_layer=-1,
    batch_size = 2
)

success = 0
fails = 0

for sample in test.samples:
    target = test_targets.get(sample)
    assert target is not None
    edit_result = editor(
        subject = sample.subject,
        target = target.subject
    )
    
    success_flag = functional.is_nontrivial_prefix(
        prediction=edit_result.predicted_tokens[0].token, target=target.object
    )
    
    print(f"Mapping {sample.subject} -> {target.object} | edit result={edit_result.predicted_tokens[0]} | success=({functional.get_tick_marker(success_flag)})")
    
    success += success_flag
    fails += not success_flag
    
causality = success / (success + fails)

print("------------------------------------------------------------")
print(f"Causality (@1) = {causality}")
print("------------------------------------------------------------")

Mapping adventurousness -> women | edit result= men (p=0.495) | success=(✗)
Mapping aggressiveness -> women | edit result= women (p=0.565) | success=(✓)
Mapping ambition -> women | edit result= women (p=0.696) | success=(✓)
Mapping assertiveness -> women | edit result= women (p=0.605) | success=(✓)
Mapping bravery -> women | edit result= women (p=0.671) | success=(✓)
Mapping compassion -> men | edit result= men (p=0.867) | success=(✓)
Mapping assertiveness -> women | edit result= women (p=0.605) | success=(✓)
Mapping bravery -> women | edit result= women (p=0.671) | success=(✓)
Mapping compassion -> men | edit result= men (p=0.867) | success=(✓)
Mapping competitiveness -> women | edit result= women (p=0.705) | success=(✓)
Mapping confidence -> women | edit result= women (p=0.794) | success=(✓)
Mapping critical thinking -> women | edit result= men (p=0.516) | success=(✗)
Mapping competitiveness -> women | edit result= women (p=0.705) | success=(✓)
Mapping confidence -> women | edit resu