# Saturated Gradients

From the ablation experiments, we can see that IG assigns higher attribution scores than AP, but some of these attribution scores are overestimated. AP also underestimates the attribution scores for some heads!

- IG has more true positives, but also more false positives: IG has higher recall, but AP has higher precision.
- Overall the results between the methods are very similar.

What causes false positives in IG?

In [1]:
# Set up

%load_ext autoreload
%autoreload 2

import torch
import random
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device, get_act_name
import numpy as np
import matplotlib.pyplot as plt

from attribution_methods import run_from_layer_fn, compute_layer_to_output_attributions
from testing import Task, TaskDataset, logit_diff_metric, identify_outliers, test_single_ablated_performance
from plotting import plot_attn, plot_attn_comparison, plot_correlation, plot_correlation_comparison, plot_bar_chart

from split_ig import SplitIntegratedGradients, SplitLayerIntegratedGradients, compute_layer_to_output_attributions_split_ig, split_integrated_gradients

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
torch.set_grad_enabled(False)
torch.cuda.empty_cache()

device = get_device()
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

Loaded pretrained model gpt2-small into HookedTransformer


## Output shapes

Hypothesis: outliers overestimated by IG are due to the shape of output curve in between the baseline and inputs to IG.

- IG calculates change in loss based on integrating gradients between two input values.
- A high attribution score could be caused by strong gradients (sensitivity) up until an intermediate input value (in between the two input values). In this case, the highlighted component would be important for the task "in between" (represented by different counterfactual inputs) instead of the target task.

![Overestimation](reference/overestimation.png)

To test this, we can visualise the gradients for intervals which are summed up by IG. We focus on attention head (9, 6) because it is highlighted more strongly by IG than by AP.

In [3]:
ioi_dataset = TaskDataset(Task.IOI)
ioi_dataloader = ioi_dataset.to_dataloader(batch_size=10)

clean_input, corrupted_input, labels = next(iter(ioi_dataloader))

clean_tokens = model.to_tokens(clean_input)
corrupted_tokens = model.to_tokens(corrupted_input)

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logit_diff_metric(clean_logits, labels)
print(f"Clean logit difference: {clean_logit_diff}")

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logit_diff_metric(corrupted_logits, labels)
print(f"Corrupted logit difference: {corrupted_logit_diff}")

Clean logit difference: tensor([-0.0307, -0.9269, -0.4937,  2.2320,  0.6754,  4.0447, -0.1785,  1.1947,
         1.1514,  1.7507], device='cuda:0')
Corrupted logit difference: tensor([-0.0387, -0.9451, -0.5103,  2.2153,  0.6299, -3.2074, -0.1823,  1.1766,
        -3.0072,  1.7392], device='cuda:0')


In [14]:
from captum.attr._utils.approximation_methods import approximation_parameters

n_steps = 50

def visualise_attn_interpolated_outputs(target_layer_num, target_pos):
    hook_name = get_act_name("result", target_layer_num)
    visualise_interpolated_integrated_gradients(hook_name, target_layer_num, target_pos)


def visualise_mlp_interpolated_outputs(target_layer_num, target_pos):
    hook_name = get_act_name("pos", target_layer_num)
    visualise_interpolated_integrated_gradients(hook_name, target_layer_num, target_pos)    


def visualise_interpolated_integrated_gradients(hook_name, target_layer_num, target_pos):
    target_layer = model.hook_dict[hook_name]
    layer_clean_input = clean_cache[hook_name] # Baseline

    # Only corrupt at target head
    layer_corrupt_input = layer_clean_input.clone()
    layer_corrupt_input[:, :, target_pos] = corrupted_cache[hook_name][:, :, target_pos]

    # Take the model starting from the target layer
    forward_fn = lambda x: run_from_layer_fn(model, clean_input, target_layer, x, logit_diff_metric, labels)
    _, alphas_func = approximation_parameters("gausslegendre")
    alphas = alphas_func(n_steps)

    with torch.autograd.set_grad_enabled(True):
        interpolated_inputs = [layer_clean_input + alpha * (layer_corrupt_input - layer_clean_input) for alpha in alphas]
        outputs = [forward_fn(i) for i in interpolated_inputs]

    plt.title(f"Model output at interpolated gradients: {(target_layer_num, target_pos)}")
    plt.plot(alphas, [o.item() for o in outputs])
    plt.xlabel("Interpolation coefficient")
    plt.ylabel("Output (logit difference)")
    plt.ylim(0, 6)
    plt.show()

In [9]:
ig_mlp = torch.load("results/aligned/ioi/ig_mlp.pt")
ig_attn = torch.load("results/aligned/ioi/ig_attn.pt")

ap_mlp = torch.load("results/aligned/ioi/ap_mlp.pt")
ap_attn = torch.load("results/aligned/ioi/ap_attn.pt")

# Identify disagreements between the two attribution methods

scaled_ig_attn = ig_attn * 1e5
scaled_ig_attn = ig_attn * 1e5
attn_outliers = []
for i in range(ig_attn.size(0)):
    outliers = identify_outliers(scaled_ig_attn[i], ap_attn[i])
    attn_outliers.append(outliers)

scaled_ig_mlp = ig_mlp * 1e5
mlp_outliers = []
for i in range(ig_mlp.size(0)):
    outliers = identify_outliers(scaled_ig_mlp[i], ap_mlp[i])
    mlp_outliers.append(outliers)

In [15]:
for layer, idx in attn_outliers[0]:
    visualise_attn_interpolated_outputs(layer, idx)

for layer_idx, in mlp_outliers[0]:
    visualise_mlp_interpolated_outputs(layer, idx)

OutOfMemoryError: CUDA out of memory. Tried to allocate 474.00 MiB. GPU 0 has a total capacity of 10.75 GiB of which 209.44 MiB is free. Process 253904 has 1.74 GiB memory in use. Process 915710 has 6.47 GiB memory in use. Including non-PyTorch memory, this process has 2.23 GiB memory in use. Of the allocated memory 1.92 GiB is allocated by PyTorch, and 133.00 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

## Split Integrated Gradients

The shape of outputs in integrated gradients suggests that IG may overestimate some attribution values due to saturated gradients at interpolated inputs. To confirm this, we run SplitIG (which cuts off the interpolated inputs if the gradients are saturated) and examine the level of agreement.

### Sanity check: no splitting

We check that Split IG with a split ratio of 1 (i.e. no splitting) produces the same result as regular IG, for a random attention head (5, 5).

In [None]:
hook_name = get_act_name("result", 5)
target_layer = model.hook_dict[hook_name]
prev_layer_hook = get_act_name("z", 5)
prev_layer = model.hook_dict[prev_layer_hook]

layer_clean_input = clean_cache[prev_layer_hook]
layer_corrupt_input = corrupted_cache[prev_layer_hook]

# Shape [batch, seq_len, d_head, d_model]
left_ig, _, _ = compute_layer_to_output_attributions_split_ig(
    clean_tokens, layer_corrupt_input, layer_clean_input, target_layer, prev_layer, logit_diff_metric, labels, ratio=1)

original_attributions = compute_layer_to_output_attributions(
    model, clean_tokens, layer_corrupt_input, layer_clean_input, target_layer, prev_layer, logit_diff_metric, labels
)

assert torch.allclose(left_ig, original_attributions.detach().cpu()), f"Split IG does not produce expected IG result"

Verify that Split IG at ratio 1 produces the same outputs as standard IG.

In [None]:
from captum.attr import LayerIntegratedGradients

n_samples = clean_tokens.size(0)
forward_fn = lambda x: run_from_layer_fn(model, clean_tokens, prev_layer, x, logit_diff_metric, labels)

split_ig = SplitLayerIntegratedGradients(forward_fn, target_layer, multiply_by_inputs=True)
split_ig_attributions, _, _, interpolated_inputs = split_ig.attribute(inputs=layer_corrupt_input,
                                baselines=layer_clean_input,
                                internal_batch_size=n_samples, # Needs to match patching shape
                                attribute_to_layer_input=False,
                                return_convergence_delta=False)
# split_ig_attributions = split_ig_attributions.reshape((n_samples, 50,) + split_ig_attributions.shape[1:])
split_ig_attributions = np.reshape(split_ig_attributions.detach().cpu().numpy(), (n_samples, 50,) + split_ig_attributions.shape[1:], order='F')
split_ig_attributions = torch.tensor(split_ig_attributions).to(device).sum(dim=1)

ig_embed = LayerIntegratedGradients(forward_fn, target_layer, multiply_by_inputs=True)
ig_attributions = ig_embed.attribute(inputs=layer_corrupt_input,
                                baselines=layer_clean_input, 
                                internal_batch_size=n_samples,
                                attribute_to_layer_input=False,
                                return_convergence_delta=False)

assert torch.allclose(split_ig_attributions, ig_attributions), f"Split IG does not produces same output as IG"

### Run Split IG

We run Split IG on the same IOI dataset, using a split ratio of 0.9.

In [None]:
ioi_split_ig_mlp, ioi_split_ig_attn = split_integrated_gradients(
    model, clean_tokens, clean_cache, corrupted_cache, logit_diff_metric, labels, ratio=0.9
)

torch.save(ioi_split_ig_mlp, "results/saturated/ioi_split_ig_mlp.pt")
torch.save(ioi_split_ig_attn, "results/saturated/ioi_split_ig_attn.pt")

### Analysis

In [None]:
ioi_ig_mlp = torch.load("saved_results/ioi_ig_mlp.pt")
ioi_ig_attn = torch.load("saved_results/ioi_ig_attn.pt")

ioi_split_ig_mlp = torch.load("results/saturated/ioi_split_ig_mlp.pt")
ioi_split_ig_attn = torch.load("results/saturated/ioi_split_ig_attn.pt")

In [None]:
plot_attn_comparison(ioi_ig_attn[:3].unsqueeze(0), ioi_split_ig_attn[:3].unsqueeze(0), model, "Integrated Gradients", "Split Integrated Gradients (0.9)")

In [None]:
plot_correlation(ioi_ig_mlp, ioi_split_ig_mlp, "Integrated Gradients Attribution Scores", "Split IG Attribution Scores", "Attribution Scores for Neurons in IOI")


In [None]:
plot_correlation(ioi_ig_attn, ioi_split_ig_attn, "Integrated Gradients Attribution Scores", "Split IG Attribution Scores", "Attribution Scores for Attention Heads in IOI")


## Elimination of noise

We want to check if noisy (method-exclusive) components are eliminated under split IG.

To do this, we compare the attribution scores of method-exclusive components under IG and Split IG.

### Method-exclusive components in Split IG

In [None]:
attn_outliers_split_ig = {(layer, idx): ioi_split_ig_attn[layer][idx] for layer, idx in attn_outliers}
mlp_outliers_split_ig = {(layer, idx): ioi_split_ig_mlp[layer][idx] for layer, idx in mlp_outliers}

plot_bar_chart(attn_outliers_split_ig, "Attention Heads", "Split IG Attribution Scores", "Split IG Attribution Scores for Attention Head Outliers in IOI")
plot_bar_chart(mlp_outliers_split_ig, "MLP Neurons", "Split IG Attribution Scores", "Split IG Attribution Scores for MLP Neuron Outliers in IOI")

In [None]:
diff_attn = ioi_split_ig_attn - ioi_ig_attn
diff_mlp = ioi_split_ig_mlp - ioi_ig_mlp

In [None]:
attn_outliers_discrepancies = {(layer, idx): diff_attn[layer][idx] for layer, idx in attn_outliers}
mlp_outliers_discrepancies = {(layer, idx): diff_mlp[layer][idx] for layer, idx in mlp_outliers}

plot_bar_chart(attn_outliers_discrepancies, "Attention Heads", "Discrepancy in Attribution Scores", "Discrepancy between IG and Split IG for Attention Heads")
plot_bar_chart(mlp_outliers_discrepancies, "MLP Neurons", "Discrepancy in Attribution Scores", "Discrepancy between IG and Split IG for MLP Neurons")

### Ablation of IG exclusive components

To evaluate the noisiness of IG versus Split IG, we ablate IG-exclusive and SIG-exclusive components.

In [None]:
# Get the mean activations over a corrupt dataset

attn_outlier_hooks = [get_act_name("result", layer_idx) for layer_idx, _ in attn_outliers]
mlp_outlier_hooks = [get_act_name("post", layer_idx) for layer_idx, _ in mlp_outliers]

test_dataset = TaskDataset(Task.IOI)
random_dataloader = test_dataset.to_dataloader(batch_size=100, shuffle=True)
random_prompts, _, _ = next(iter(random_dataloader))

prompts_tokens = model.to_tokens(random_prompts)
_, prompt_cache = model.run_with_cache(
    prompts_tokens, 
    names_filter=lambda x: x in attn_outlier_hooks or x in mlp_outlier_hooks
)

mean_corrupt_activations = {}
for key in prompt_cache.keys():
    mean_values_over_prompts = torch.mean(prompt_cache[key], dim=0, keepdim=True)
    mean_corrupt_activations[key] = torch.mean(mean_values_over_prompts, dim=1, keepdim=True)

In [None]:
# Identify components highlighted by IG but not by Split IG
attn_ig_outliers = identify_outliers(ioi_ig_attn, ioi_split_ig_attn, only_collect_x_outliers=True)
mlp_ig_outliers = identify_outliers(ioi_ig_mlp, ioi_split_ig_mlp, only_collect_x_outliers=True)

# Ablate the components highlighted by IG but not by Split IG
ig_only_attn_ablated_scores = dict()
for layer, idx in attn_ig_outliers:
    performance = test_single_ablated_performance(model, layer, idx, mean_corrupt_activations, Task.IOI, is_attn=True)
    ig_only_attn_ablated_scores[(layer, idx)] = performance

ig_only_mlp_ablated_scores = dict()
for layer, idx in mlp_ig_outliers:
    performance = test_single_ablated_performance(model, layer, idx, mean_corrupt_activations, Task.IOI, is_attn=False)
    ig_only_mlp_ablated_scores[(layer, idx)] = performance

In [None]:
# Identify components highlighted by Split IG but not by IG
split_ig_only_attn_outliers = identify_outliers(ioi_split_ig_attn, ioi_ig_attn, only_collect_x_outliers=True)
split_ig_only_mlp_outliers = identify_outliers(ioi_split_ig_mlp, ioi_ig_mlp, only_collect_x_outliers=True)

# Ablate the components highlighted by Split IG but not by IG
split_ig_only_attn_ablated_scores = dict()
for layer, idx in split_ig_only_attn_outliers:
    performance = test_single_ablated_performance(model, layer, idx, mean_corrupt_activations, Task.IOI, is_attn=True)
    split_ig_only_attn_ablated_scores[(layer, idx)] = performance
    
split_ig_only_mlp_ablated_scores = dict()
for layer, idx in split_ig_only_mlp_outliers:
    performance = test_single_ablated_performance(model, layer, idx, mean_corrupt_activations, Task.IOI, is_attn=False)
    split_ig_only_mlp_ablated_scores[(layer, idx)] = performance

#### Results

In [None]:
plot_bar_chart(ig_only_attn_ablated_scores, "Abalted Attention Heads", "Model Performance", "Ablated Performance for Attention Heads Highlighted only by IG")
plot_bar_chart(ig_only_mlp_ablated_scores, "Abalted MLP Neurons", "Model Performance", "Ablated Performance for MLP Neurons Highlighted only by IG")

In [None]:
plot_bar_chart(split_ig_only_attn_ablated_scores, "Abalted Attention Heads", "Model Performance", "Ablated Performance for Attention Heads Highlighted only by Split IG")
plot_bar_chart(split_ig_only_mlp_ablated_scores, "Abalted MLP Neurons", "Model Performance", "Ablated Performance for MLP Neurons Highlighted only by Split IG")