# Mediators

We aim to investigate exactly how identified latent components affect other components.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import Tensor
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device, get_act_name

from attribution_methods import integrated_gradients, activation_patching, highlight_components
from testing import Task, TaskDataset, logit_diff_metric, average_correlation, measure_overlap, test_multi_ablated_performance
from plotting import plot_attn, plot_attn_comparison, plot_correlation, plot_correlation_comparison, plot_bar_chart

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
torch.set_grad_enabled(False)
torch.cuda.empty_cache()

device = get_device()
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


## Specific Example

For a specific example, we will ablate each component and measure the change in activations across all other components in the layer immediately downstream. We then analyse how much identified latent components change other components' activations.

### Experiment - measure change in activations per ablation

In [12]:
clean_input = "When John and Mary went to the store, John gave a bottle of milk to "
corrupted_input = "When John and Mary went to the store, Mary gave a bottle of milk to "
labels = torch.tensor([
    [model.to_tokens("Mary", prepend_bos=False).item(), 
     model.to_tokens("John", prepend_bos=False).item()]
])

clean_tokens = model.to_tokens(clean_input)
corrupted_tokens = model.to_tokens(corrupted_input)

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logit_diff_metric(clean_logits, labels)
print(f"Clean logit difference: {clean_logit_diff}")

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logit_diff_metric(corrupted_logits, labels)
print(f"Corrupted logit difference: {corrupted_logit_diff}")

Clean logit difference: tensor([0.6779], device='cuda:0')
Corrupted logit difference: tensor([0.1698], device='cuda:0')


In [18]:
def activation_cache_to_tensor(model, activation_cache, get_attn: bool):
    """
    Convert the activation cache to a tensor of shape:
    - (num_layers, num_heads, d_model) for attention heads
    - (num_layers, d_model) for MLP heads
    """
    matrix = []
    for layer in range(model.cfg.n_layers):
        if get_attn:
            hook_name = get_act_name("result", layer)
        else:
            hook_name = get_act_name("post", layer)
        
        matrix.append(activation_cache[hook_name][0, -1])
    
    matrix = torch.stack(matrix, dim=0)
    return matrix

In [31]:
baseline_attn_cache = activation_cache_to_tensor(model, clean_cache, get_attn=True)
print("Baseline attention cache shape:", baseline_attn_cache.shape)

baseline_mlp_cache = activation_cache_to_tensor(model, clean_cache, get_attn=False)
print("Baseline MLP cache shape:", baseline_mlp_cache.shape)

Baseline attention cache shape: torch.Size([12, 12, 768])
Baseline MLP cache shape: torch.Size([12, 3072])


In [None]:
# How much does each head affect other heads in other layers?
ablation_act_diff_attn = torch.zeros((model.cfg.n_layers * model.cfg.n_heads, model.cfg.n_layers, model.cfg.n_heads))

for layer in range(model.cfg.n_layers):
    for head in range(model.cfg.n_heads):
        hook_name = get_act_name("result", layer)

        def ablate_hook(act, hook):
            act[:, head] = corrupted_cache[hook.name][:, head]
            return act
        
        with model.hooks(fwd_hooks=[(hook_name, ablate_hook)]):
            _, patched_attn_cache = model.run_with_cache(clean_tokens)
        
        diff = activation_cache_to_tensor(model, patched_attn_cache, get_attn=True) - baseline_attn_cache
        print("Number of components with changed activations:", torch.count_nonzero(diff.mean(-1)).item())
        ablation_act_diff_attn[layer * model.cfg.n_heads + head] = diff.mean(dim=-1)

torch.save(ablation_act_diff_attn, "results/mediators/ablation_act_diff_attn.pt")


Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 101376
Number of components with changed activations: 101370
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of components with changed activations: 0
Number of 

In [None]:
# How much does each neuron affect other neurons in other layers?
ablation_act_diff_mlp = torch.zeros((model.cfg.n_layers * model.cfg.d_mlp, model.cfg.n_layers, model.cfg.d_mlp))

for layer in range(model.cfg.n_layers):
    for neuron in range(model.cfg.d_mlp):
        hook_name = get_act_name("post", layer)

        def ablate_hook(act, hook):
            act[:, head] = corrupted_cache[hook.name][:, head]
            return act
        
        with model.hooks(fwd_hooks=[(hook_name, ablate_hook)]):
            _, patched_mlp_cache = model.run_with_cache(clean_tokens)
        
        diff = activation_cache_to_tensor(model, patched_mlp_cache, get_attn=False) - baseline_mlp_cache
        print("Number of components with changed activations:", torch.count_nonzero(diff).item())
        ablation_act_diff_mlp[layer * model.cfg.d_mlp + neuron] = diff

torch.save(ablation_act_diff_mlp, "results/mediators/ablation_act_diff_mlp.pt")


Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations: 33775
Number of components with changed activations:

### Experiment - identify outliers

In [None]:
# Run integrated gradients in opposite directions

mlp_corrupt_clean, attn_corrupt_clean = integrated_gradients(
    model,
    clean_tokens,
    clean_cache,
    corrupted_cache,
    logit_diff_metric,
    labels
)

torch.save(mlp_corrupt_clean, "results/mediators/mlp_corrupt_clean.pt")
torch.save(attn_corrupt_clean, "results/mediators/attn_corrupt_clean.pt")

In [None]:
mlp_clean_corrupt, attn_clean_corrupt = integrated_gradients(
    model,
    corrupted_tokens,
    corrupted_cache,
    clean_cache,
    logit_diff_metric,
    labels
)

torch.save(mlp_clean_corrupt, "results/mediators/mlp_clean_corrupt.pt")
torch.save(attn_clean_corrupt, "results/mediators/attn_clean_corrupt.pt")

In [None]:
# Identify latent components

mlp_latent = highlight_components(mlp_corrupt_clean)[0] ^ highlight_components(mlp_clean_corrupt)[0]
attn_latent = highlight_components(attn_corrupt_clean)[0] ^ highlight_components(attn_clean_corrupt)[0]

mlp_latent_indices = mlp_latent.nonzero()
attn_latent_indices = attn_latent.nonzero()

### Analysis

In [None]:
baseline_act_diff_attn = ablation_act_diff_attn.mean()
baseline_act_diff_mlp = ablation_act_diff_mlp.mean()

In [None]:
latent_act_diffs = {"None": baseline_act_diff_attn}

for layer, idx in attn_latent_indices:
    act_diff = ablation_act_diff_attn[layer * model.cfg.n_heads + idx]
    latent_act_diffs[(layer, idx)] = act_diff

    print(f"Layer {layer}, Head {idx}: {act_diff.mean()}")
    plot_attn(act_diff, model)