# Latent Circuits

We construct a toy model that implements a known circuit (either AND or OR).

Aim: confirm that patching in two directions recovers latent components.

In [1]:
%load_ext autoreload
%autoreload 2

import torch
import torch.nn as nn
import torch.optim as optim
from torch import Tensor
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device, get_act_name

from attribution_methods import integrated_gradients, activation_patching, highlight_components, asymmetry_score
from testing import Task, TaskDataset, logit_diff_metric, average_correlation, measure_overlap, test_multi_ablated_performance
from plotting import plot_attn, plot_attn_comparison, plot_correlation, plot_correlation_comparison, plot_bar_chart

from captum.attr import LayerIntegratedGradients

  from .autonotebook import tqdm as notebook_tqdm


### Set up

In [2]:
# Define a simple feedforward neural network
class ANDORNet(nn.Module):
    """Single logistic unit used as a toy AND / OR gate.

    The network is ``nn.Linear(2, 1)`` followed by ``nn.Sigmoid``, held in an
    ``nn.Sequential`` on ``self.model``.
    """

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            # Optional hidden layer, currently disabled:
            # nn.Linear(2, 2),
            # nn.ReLU(),
            nn.Linear(2, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Pass ``x`` through the sequential stack and return its output."""
        return self.model(x)

    def run_with_cache(self, x):
        """Run a gradient-free forward pass and record every layer's output.

        Args:
            x: Input tensor handed to ``forward``.

        Returns:
            Tuple ``(output, cache)``. ``cache`` maps ``"layer_{idx}"`` (the
            layer's position inside ``self.model``) to that layer's detached
            output tensor.

        Raises:
            Whatever the forward pass raises. The temporary hooks are removed
            in that case as well, so a failed call leaves the module clean.
        """
        cache = {}

        # Factory so that every hook closes over its own cache key.
        def save_activation(name):
            def hook(module, input, output):
                cache[name] = output.detach()
            return hook

        # Attach one temporary forward hook per layer.
        handles = [
            layer.register_forward_hook(save_activation(f"layer_{idx}"))
            for idx, layer in enumerate(self.model)
        ]

        try:
            with torch.no_grad():
                output = self.forward(x)
        finally:
            # Always detach the hooks; otherwise a failing forward pass would
            # leave them registered and they would fire on every later call.
            for handle in handles:
                handle.remove()

        return output, cache

In [3]:
def toy_activation_patching(model: nn.Module, baseline_inputs, corrupt_inputs):
    """Activation-patching scores for the two input neurons (A and B).

    For each neuron in turn, that neuron's value from ``corrupt_inputs`` is
    copied into ``baseline_inputs``. The resulting change in the model output
    is reported as a fraction of the change obtained by swapping the whole
    input (``model(baseline_inputs) - model(corrupt_inputs)``).

    Args:
        model: Callable mapping an ``(n_samples, 2)`` tensor to one value per
            sample (e.g. an output of shape ``(n_samples, 1)``).
        baseline_inputs: ``(n_samples, 2)`` inputs that are patched into.
        corrupt_inputs: ``(n_samples, 2)`` inputs that patch values come from.

    Returns:
        Tensor of shape ``(n_samples, 2)``; column ``i`` holds the score of
        neuron ``i``.

    Note:
        The ratio is not guarded: if the model gives the same output for
        ``baseline_inputs`` and ``corrupt_inputs``, the scores are inf/nan.
    """
    n_samples = baseline_inputs.size(0)
    # The unpatched output is the same for every neuron, so compute it once.
    baseline_output = model(baseline_inputs)
    baseline_diff = baseline_output - model(corrupt_inputs)
    print(f"Patch {corrupt_inputs} into {baseline_inputs}")

    attributions = torch.zeros((n_samples, 2))
    for neuron_idx in range(2):
        # Overwrite a single neuron of the baseline with its corrupt value.
        patched_inputs = baseline_inputs.clone()
        patched_inputs[:, neuron_idx] = corrupt_inputs[:, neuron_idx]

        patch_diff = baseline_output - model(patched_inputs)
        # Flatten the per-sample ratio to (n_samples,) so it fits the column
        # slice; an (n_samples, 1) value only broadcasts when n_samples == 1.
        attributions[:, neuron_idx] = (patch_diff / baseline_diff).reshape(n_samples)

    return attributions

## AND Circuit

### Construct toy model

In [15]:
# Training data for AND logic gate: the full two-input truth table,
# enumerated in the order (0, 0), (0, 1), (1, 0), (1, 1).
input_pairs = [(a, b) for a in (0., 1.) for b in (0., 1.)]
X = torch.tensor(input_pairs)

# AND target: 1 only when both inputs are 1, i.e. the smaller of the two inputs.
y = torch.tensor([[min(a, b)] for a, b in input_pairs])

In [16]:
# Initialize the network, loss function and optimizer.
# Seed torch's global RNG first: the Linear layer's initial weights and the
# per-epoch shuffle below are both random, so an unseeded Restart & Run All
# gives different weights and different attribution values on every run.
# (Outputs saved before this seed was added will not match exactly.)
torch.manual_seed(0)
and_model = ANDORNet()
criterion = nn.BCELoss()  # Binary classification
optimizer = optim.SGD(and_model.parameters(), lr=0.1)

# Training loop: full-batch gradient descent on the four truth-table rows
with torch.enable_grad():
    for epoch in range(1000):
        # Present the rows in a fresh random order each epoch
        shuffle_order = torch.randperm(X.size(0))
        shuffled_X = X[shuffle_order]
        shuffled_y = y[shuffle_order]

        optimizer.zero_grad()
        outputs = and_model(shuffled_X)
        loss = criterion(outputs, shuffled_y)
        loss.backward()
        optimizer.step()
        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# Evaluate model
with torch.no_grad():
    preds = and_model(X)
    print("Predictions:")
    print(preds.round())  # Round predictions to get binary output

Epoch 0, Loss: 0.6489
Epoch 100, Loss: 0.4553
Epoch 200, Loss: 0.3584
Epoch 300, Loss: 0.2983
Epoch 400, Loss: 0.2569
Epoch 500, Loss: 0.2262
Epoch 600, Loss: 0.2024
Epoch 700, Loss: 0.1832
Epoch 800, Loss: 0.1673
Epoch 900, Loss: 0.1540
Predictions:
tensor([[0.],
        [0.],
        [0.],
        [1.]])


In [35]:
# Show the learned weight and bias of the AND model's linear layer
for param_name, param_tensor in and_model.named_parameters():
    print(f"{param_name}: {param_tensor.data}")

model.0.weight: tensor([[3.0482, 3.0556]])
model.0.bias: tensor([-4.7946])


### Run attributions

In [17]:
# Clean input switches both operands on; corrupt input switches both off.
clean_input = torch.tensor([[1., 1.]], requires_grad=True)
corrupt_input = torch.tensor([[0., 0.]], requires_grad=True)

# Forward each input once, keeping the per-layer activations with the output.
(positive_output, clean_cache), (negative_output, corrupt_cache) = (
    and_model.run_with_cache(example) for example in (clean_input, corrupt_input)
)

print(positive_output, clean_cache)
print(negative_output, corrupt_cache)

tensor([[0.7872]]) {'layer_0': tensor([[1.3082]]), 'layer_1': tensor([[0.7872]])}
tensor([[0.0082]]) {'layer_0': tensor([[-4.7919]]), 'layer_1': tensor([[0.0082]])}


In [18]:
# Standard integrated gradients: no `baselines` argument is passed, so Captum
# falls back to its default baseline. Attributions are taken at the input of
# the linear layer, i.e. one score per input neuron.
ig_and = LayerIntegratedGradients(and_model, and_model.model[0], multiply_by_inputs=True)
ig_options = dict(internal_batch_size=1, attribute_to_layer_input=True)

# Clean input
ig_and_clean_zero = ig_and.attribute(inputs=clean_input, **ig_options)
print(f"Standard integrated gradients for clean input: {ig_and_clean_zero}")

# Corrupt input
ig_and_corrupt_zero = ig_and.attribute(inputs=corrupt_input, **ig_options)
print(f"Standard integrated gradients for corrupt input: {ig_and_corrupt_zero}")

Standard integrated gradients for clean input: tensor([[0.3901, 0.3889]])
Standard integrated gradients for corrupt input: tensor([[0., 0.]])


In [41]:
# Activation patching on the AND model, once per direction.
# `baseline_inputs` is the run being patched into; `corrupt_inputs` supplies
# the values that get patched in.
with torch.no_grad():
    # Clean -> corrupt: start from the corrupt run and patch in clean values
    ap_and_clean_corrupt = toy_activation_patching(
        model=and_model, baseline_inputs=corrupt_input, corrupt_inputs=clean_input
    )
    print(f"Clean->Corrupt: {ap_and_clean_corrupt}\n")

    # Corrupt -> clean: start from the clean run and patch in corrupt values
    ap_and_corrupt_clean = toy_activation_patching(
        model=and_model, baseline_inputs=clean_input, corrupt_inputs=corrupt_input
    )
    print("Corrupt->Clean", ap_and_corrupt_clean)

Patch tensor([[1., 1.]], requires_grad=True) into tensor([[0., 0.]], requires_grad=True)
Clean->Corrupt: tensor([[0.1800, 0.1813]])

Patch tensor([[0., 0.]], requires_grad=True) into tensor([[1., 1.]], requires_grad=True)
Corrupt->Clean tensor([[0.8187, 0.8200]])


In [19]:
# Integrated gradients with an explicit baseline, once per direction
ig_and = LayerIntegratedGradients(and_model, and_model.model[0], multiply_by_inputs=True)
ig_options = dict(internal_batch_size=1, attribute_to_layer_input=True)

# Clean input, corrupt baseline
ig_and_clean_corrupt = ig_and.attribute(inputs=clean_input, baselines=corrupt_input, **ig_options)
print(f"Integrated gradients for clean input patched into corrupt input: {ig_and_clean_corrupt}")

# Corrupt input, clean baseline
ig_and_corrupt_clean = ig_and.attribute(inputs=corrupt_input, baselines=clean_input, **ig_options)
print(f"Integrated gradients for corrupt input patched into clean input: {ig_and_corrupt_clean}")

Integrated gradients for clean input patched into corrupt input: tensor([[0.3901, 0.3889]])
Integrated gradients for corrupt input patched into clean input: tensor([[-0.3901, -0.3889]])


In the example of the AND toy model, we can see that:

- IG with corrupt baseline attributes roughly equal importance to neurons A and B (≈0.39 each, i.e. each receives about half of the total clean–corrupt output difference of ≈0.78). This is correct and holds for gradients in both directions (swapping clean/corrupt as input/baseline only flips the sign).
- Activation patching from clean ([1, 1]) to corrupt ([0, 0]) suggests that both A and B have low attribution scores (≈0.18), because patching in only one of A or B barely changes the output: an AND gate needs both inputs on. Denoising therefore fails to identify the circuit components.
- Activation patching from corrupt ([0, 0]) to clean ([1, 1]) suggests that both A and B have high attribution scores (≈0.82), because switching off either A or B is enough to break the AND output.

Therefore, IG with corrupt baseline correctly identifies AND circuit components, but AP only detects AND circuit components when patching from corrupt to clean.

## OR Circuit

### Construct toy model

In [4]:
# Training data for OR logic gate: the full two-input truth table,
# enumerated in the order (0, 0), (0, 1), (1, 0), (1, 1).
input_pairs = [(a, b) for a in (0., 1.) for b in (0., 1.)]
X = torch.tensor(input_pairs)

# OR target: 1 as soon as either input is 1, i.e. the larger of the two inputs.
y = torch.tensor([[max(a, b)] for a, b in input_pairs])

In [5]:
# Initialize the model, loss function, and optimizer.
# Seed torch's global RNG first: the Linear layer's initial weights and the
# per-epoch shuffle below are both random, so an unseeded Restart & Run All
# gives different weights and different attribution values on every run.
# (Outputs saved before this seed was added will not match exactly.)
torch.manual_seed(0)
or_model = ANDORNet()
criterion = nn.BCELoss()  # Binary classification
optimizer = optim.SGD(or_model.parameters(), lr=0.1)

# Training loop: full-batch gradient descent on the four truth-table rows
with torch.enable_grad():
    for epoch in range(1000):
        # Present the rows in a fresh random order each epoch
        shuffle_order = torch.randperm(X.size(0))
        shuffled_X = X[shuffle_order]
        shuffled_y = y[shuffle_order]

        optimizer.zero_grad()
        outputs = or_model(shuffled_X)
        loss = criterion(outputs, shuffled_y)
        loss.backward()
        optimizer.step()
        if epoch % 100 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

# Evaluate
with torch.no_grad():
    preds = or_model(X)
    print("Predictions:")
    print(preds.round())  # Round predictions to get binary output

Epoch 0, Loss: 0.7877
Epoch 100, Loss: 0.3939
Epoch 200, Loss: 0.2994
Epoch 300, Loss: 0.2392
Epoch 400, Loss: 0.1980
Epoch 500, Loss: 0.1683
Epoch 600, Loss: 0.1460
Epoch 700, Loss: 0.1286
Epoch 800, Loss: 0.1147
Epoch 900, Loss: 0.1034
Predictions:
tensor([[0.],
        [1.],
        [1.],
        [1.]])


In [45]:
# Show the learned weight and bias of the OR model's linear layer
for param_name, param_tensor in or_model.named_parameters():
    print(f"{param_name}: {param_tensor.data}")

model.0.weight: tensor([[4.0211, 4.0242]])
model.0.bias: tensor([-1.4725])


### Run attributions

In [12]:
# Clean input switches both operands on; corrupt input switches both off.
clean_input = torch.tensor([[1., 1.]], requires_grad=True)
corrupt_input = torch.tensor([[0., 0.]], requires_grad=True)

# Forward each input once, keeping the per-layer activations with the output.
(positive_output, clean_cache), (negative_output, corrupt_cache) = (
    or_model.run_with_cache(example) for example in (clean_input, corrupt_input)
)

print(positive_output, clean_cache)
print(negative_output, corrupt_cache)

tensor([[0.9984]]) {'layer_0': tensor([[6.4254]]), 'layer_1': tensor([[0.9984]])}
tensor([[0.1952]]) {'layer_0': tensor([[-1.4166]]), 'layer_1': tensor([[0.1952]])}


In [10]:
# Standard integrated gradients: no `baselines` argument is passed, so Captum
# falls back to its default baseline. Attributions are taken at the input of
# the linear layer, i.e. one score per input neuron.
ig_or = LayerIntegratedGradients(or_model, or_model.model[0], multiply_by_inputs=True)
ig_options = dict(internal_batch_size=1, attribute_to_layer_input=True)

# Clean input
ig_or_clean_zero = ig_or.attribute(inputs=clean_input, **ig_options)
print(f"Standard integrated gradients for clean input (zero baseline): {ig_or_clean_zero}")

# Corrupt input
ig_or_corrupt_zero = ig_or.attribute(inputs=corrupt_input, **ig_options)
print(f"Standard integrated gradients for corrupt input (zero baseline): {ig_or_corrupt_zero}")

Standard integrated gradients for clean input (zero baseline): tensor([[0.3997, 0.4035]])
Standard integrated gradients for corrupt input (zero baseline): tensor([[0., 0.]])


In [13]:
# Integrated gradients with an explicit baseline, once per direction
ig_options = dict(internal_batch_size=1, attribute_to_layer_input=True)

# Clean input, corrupt baseline
ig_or_clean_corrupt = ig_or.attribute(inputs=clean_input, baselines=corrupt_input, **ig_options)
print(f"Integrated gradients for clean input patched into corrupt input: {ig_or_clean_corrupt}")

# Corrupt input, clean baseline
ig_or_corrupt_clean = ig_or.attribute(inputs=corrupt_input, baselines=clean_input, **ig_options)
print(f"Integrated gradients for corrupt input patched into clean input: {ig_or_corrupt_clean}")


Integrated gradients for clean input patched into corrupt input: tensor([[0.3997, 0.4035]])
Integrated gradients for corrupt input patched into clean input: tensor([[-0.3997, -0.4035]])


In [9]:
# Activation patching on the OR model, once per direction.
# `baseline_inputs` is the run being patched into; `corrupt_inputs` supplies
# the values that get patched in.
with torch.no_grad():
    # Clean -> corrupt: start from the corrupt run and patch in clean values
    ap_or_clean_corrupt = toy_activation_patching(
        model=or_model, baseline_inputs=corrupt_input, corrupt_inputs=clean_input
    )
    print(f"Clean->Corrupt: {ap_or_clean_corrupt}\n")

    # Corrupt -> clean: start from the clean run and patch in corrupt values
    ap_or_corrupt_clean = toy_activation_patching(
        model=or_model, baseline_inputs=clean_input, corrupt_inputs=corrupt_input
    )
    print("Corrupt->Clean", ap_or_corrupt_clean)

Patch tensor([[1., 1.]], requires_grad=True) into tensor([[0., 0.]], requires_grad=True)
Clean->Corrupt: tensor([[0.9063, 0.9096]])

Patch tensor([[0., 0.]], requires_grad=True) into tensor([[1., 1.]], requires_grad=True)
Corrupt->Clean tensor([[0.0904, 0.0937]])


By the same reasoning as for the AND model, for the OR model:

- IG with corrupt baseline correctly identifies that both A and B are important.
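- Activation patching from clean ([1, 1]) to corrupt ([0, 0]) also identifies A and B as important (≈0.91 each), because switching on either input alone is enough to activate the OR output.
- Activation patching from corrupt ([0, 0]) to clean ([1, 1]) does not identify either A or B as important (≈0.09 each), because switching off one neuron while the other stays on barely affects the OR output.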

## Conclusions

Integrated gradients with corrupt activations as the baseline can correctly assign equal, non-trivial attribution scores to both latent components in AND and OR circuits. Activation patching in a single direction may fail to identify latent components: patching clean into corrupt misses the AND inputs, while patching corrupt into clean misses the OR inputs.