# Latent components

Latent components are model components that only activate in conjunction with other components. They can only be detected by comparing the attribution scores obtained from activation patching in both directions: from clean to corrupt, and from corrupt to clean.

In [None]:
import torch
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device

from attribution_methods import integrated_gradients, activation_patching
from testing import Task, TaskDataset, logit_diff_metric, average_correlation, test_ablated_performance
from plotting import plot_attn, plot_attn_comparison, plot_correlation, plot_correlation_comparison

In [None]:
# Disable autograd globally for the rest of the session.
# NOTE(review): integrated_gradients is called in later cells; presumably it
# re-enables gradients internally — confirm.
torch.set_grad_enabled(False)

# Load GPT-2 small onto the device reported by transformer_lens' get_device().
device = get_device()
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
# NOTE(review): presumably also exposes a hook point on each MLP block's
# input — confirm against the TransformerLens documentation.
model.set_use_hook_mlp_in(True)

## Approximating activation patching

Hypothesis: integrated gradients can approximate activation patching in either direction.

To evaluate this hypothesis, we use 100 samples from the IOI dataset and run:

- Activation patching from clean to corrupt
- Activation patching from corrupt to clean
- Integrated gradients with corrupt input and clean baseline
- Integrated gradients with clean input and corrupt baseline

We then compare the resulting attribution scores from the four runs.

### Experiment

In [None]:
# Draw a single batch of 100 IOI samples: clean prompts, corrupted prompts
# and the labels consumed by the logit-difference metric.
ioi_dataset = TaskDataset(Task.IOI)
ioi_dataloader = ioi_dataset.to_dataloader(batch_size=100)
clean_input, corrupted_input, labels = next(iter(ioi_dataloader))

# Clean prompts: tokenise, run once while caching activations, then score.
clean_tokens = model.to_tokens(clean_input)
clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logit_diff_metric(clean_logits, labels)
print(f"Clean logit difference: {clean_logit_diff}")

# Corrupted prompts: same pipeline, scored against the same labels.
corrupted_tokens = model.to_tokens(corrupted_input)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logit_diff_metric(corrupted_logits, labels)
print(f"Corrupted logit difference: {corrupted_logit_diff}")

In [None]:
# Integrated gradients driven by the clean tokens/cache with the corrupted
# cache as the other endpoint.
# NOTE(review): which cache acts as the IG baseline depends on the signature
# of integrated_gradients, which is not visible here — confirm.
ig_clean_corrupt_mlp, ig_clean_corrupt_attn = integrated_gradients(
    model,
    clean_tokens,
    clean_cache,
    corrupted_cache,
    logit_diff_metric,
    labels,
)

# Persist both score tensors; the analysis section reloads them from these paths.
torch.save(
    ig_clean_corrupt_mlp,
    "results/latent_components/ig_clean_corrupt_mlp.pt",
)
torch.save(
    ig_clean_corrupt_attn,
    "results/latent_components/ig_clean_corrupt_attn.pt",
)

In [None]:
# Activation patching driven by the clean run, with the corrupted cache and
# corrupted logit difference supplied as the counterpart.
# NOTE(review): the exact patch direction depends on activation_patching's
# signature, which is not visible here — confirm.
ap_clean_corrupt_mlp, ap_clean_corrupt_attn = activation_patching(
    model,
    clean_tokens,
    clean_cache,
    clean_logit_diff,
    corrupted_cache,
    corrupted_logit_diff,
    logit_diff_metric,
    labels,
)

# Persist both score tensors; the analysis section reloads them from these paths.
torch.save(
    ap_clean_corrupt_mlp,
    "results/latent_components/ap_clean_corrupt_mlp.pt",
)
torch.save(
    ap_clean_corrupt_attn,
    "results/latent_components/ap_clean_corrupt_attn.pt",
)

In [None]:
# Integrated gradients in the opposite direction: corrupted tokens/cache
# first, clean cache as the other endpoint.
# NOTE(review): which cache acts as the IG baseline depends on the signature
# of integrated_gradients, which is not visible here — confirm.
ig_corrupt_clean_mlp, ig_corrupt_clean_attn = integrated_gradients(
    model,
    corrupted_tokens,
    corrupted_cache,
    clean_cache,
    logit_diff_metric,
    labels,
)

# Persist both score tensors; the analysis section reloads them from these paths.
torch.save(
    ig_corrupt_clean_mlp,
    "results/latent_components/ig_corrupt_clean_mlp.pt",
)
torch.save(
    ig_corrupt_clean_attn,
    "results/latent_components/ig_corrupt_clean_attn.pt",
)

In [None]:
# Activation patching in the opposite direction: driven by the corrupted run,
# with the clean cache and clean logit difference supplied as the counterpart.
# NOTE(review): the exact patch direction depends on activation_patching's
# signature, which is not visible here — confirm.
ap_corrupt_clean_mlp, ap_corrupt_clean_attn = activation_patching(
    model,
    corrupted_tokens,
    corrupted_cache,
    corrupted_logit_diff,
    clean_cache,
    clean_logit_diff,
    logit_diff_metric,
    labels,
)

# Persist both score tensors; the analysis section reloads them from these paths.
torch.save(
    ap_corrupt_clean_mlp,
    "results/latent_components/ap_corrupt_clean_mlp.pt",
)
torch.save(
    ap_corrupt_clean_attn,
    "results/latent_components/ap_corrupt_clean_attn.pt",
)

### Analysis

We first visualise ten results from each of the four runs. We then calculate the average correlation coefficient across all 100 samples.

In [None]:
# Reload the attribution scores written by the experiment cells so the
# analysis below can be run without recomputing them. Each pair holds the
# integrated-gradients (ig_*) and activation-patching (ap_*) result for the
# same direction and component type.

# corrupt -> clean, MLP
ig_corrupt_clean_mlp, ap_corrupt_clean_mlp = (
    torch.load("results/latent_components/ig_corrupt_clean_mlp.pt"),
    torch.load("results/latent_components/ap_corrupt_clean_mlp.pt"),
)

# corrupt -> clean, attention
ig_corrupt_clean_attn, ap_corrupt_clean_attn = (
    torch.load("results/latent_components/ig_corrupt_clean_attn.pt"),
    torch.load("results/latent_components/ap_corrupt_clean_attn.pt"),
)

# clean -> corrupt, MLP
ig_clean_corrupt_mlp, ap_clean_corrupt_mlp = (
    torch.load("results/latent_components/ig_clean_corrupt_mlp.pt"),
    torch.load("results/latent_components/ap_clean_corrupt_mlp.pt"),
)

# clean -> corrupt, attention
ig_clean_corrupt_attn, ap_clean_corrupt_attn = (
    torch.load("results/latent_components/ig_clean_corrupt_attn.pt"),
    torch.load("results/latent_components/ap_clean_corrupt_attn.pt"),
)

In [None]:
# Visual comparison of attention attribution scores from integrated gradients
# (IG) and activation patching (AP), one figure per direction. Only the first
# ten entries of each score tensor are plotted; `model` is passed whole.
plot_attn_comparison(
    ig_clean_corrupt_attn[:10],
    ap_clean_corrupt_attn[:10],
    model,
    "IG clean baseline -> corrupt input",
    "AP clean -> corrupt",
)

# The `[:10]` slice belongs on the AP scores, not on `model`: slicing the
# model is invalid and would leave the AP tensor unsliced, unlike the call above.
plot_attn_comparison(
    ig_corrupt_clean_attn[:10],
    ap_corrupt_clean_attn[:10],
    model,
    "IG corrupt baseline -> clean input",
    "AP corrupt -> clean",
)

In [None]:
# Correlation plots of IG against AP for the clean -> corrupt direction,
# restricted to the first ten entries of each score tensor (MLP pair first,
# then attention pair).
print("Correlation between clean -> corrupt IG and AP")
plot_correlation_comparison(
    ig_clean_corrupt_mlp[:10],
    ap_clean_corrupt_mlp[:10],
    ig_clean_corrupt_attn[:10],
    ap_clean_corrupt_attn[:10],
    Task.IOI,
)

In [None]:
# Correlation plots of IG against AP for the corrupt -> clean direction,
# restricted to the first ten entries of each score tensor. The argument
# order mirrors the clean -> corrupt call: IG MLP, AP MLP, IG attention,
# AP attention, task.
print("Correlation between corrupt -> clean IG and AP")
plot_correlation_comparison(
    ig_corrupt_clean_mlp[:10],
    ap_corrupt_clean_mlp[:10],
    ig_corrupt_clean_attn[:10],
    ap_corrupt_clean_attn[:10],
    Task.IOI,
)

In [None]:
# Average IG/AP correlation over all samples, reported per direction and per
# component type. Each line must pair the IG and AP tensors of the direction
# named in its message.

# clean -> corrupt
print(f"Average correlation between clean->corrupt IG and AP (MLP): {average_correlation(ig_clean_corrupt_mlp, ap_clean_corrupt_mlp)}")
print(f"Average correlation between clean->corrupt IG and AP (attention): {average_correlation(ig_clean_corrupt_attn, ap_clean_corrupt_attn)}")

# corrupt -> clean (uses the corrupt_clean tensors, not the clean_corrupt ones)
print(f"Average correlation between corrupt->clean IG and AP (MLP): {average_correlation(ig_corrupt_clean_mlp, ap_corrupt_clean_mlp)}")
print(f"Average correlation between corrupt->clean IG and AP (attention): {average_correlation(ig_corrupt_clean_attn, ap_corrupt_clean_attn)}")

## Efficiently identifying latent components

Hypothesis: we can efficiently identify latent components using two passes of integrated gradients.