# Saturated Gradients

The ablation experiments show that IG assigns higher attribution scores than AP, and that some of these scores are overestimates. Conversely, AP underestimates the attribution scores for some heads.

- IG has more true positives, but also more false positives: IG has higher recall, but AP has higher precision.
- Overall, the two methods produce very similar results.
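
As a reminder of how precision and recall trade off here, the sketch below scores two sets of highlighted heads against a set of heads treated as truly important. The head indices are made up purely for illustration and are not taken from the ablation results; `precision_recall` is a hypothetical helper, not part of this repository.

```python
def precision_recall(predicted, relevant):
    """Precision and recall of a predicted set against a ground-truth set."""
    predicted, relevant = set(predicted), set(relevant)
    true_positives = len(predicted & relevant)
    precision = true_positives / len(predicted) if predicted else 0.0
    recall = true_positives / len(relevant) if relevant else 0.0
    return precision, recall


# Made-up (layer, head) pairs, for illustration only.
relevant = {(0, 0), (0, 1), (1, 0), (1, 1)}
ig_heads = {(0, 0), (0, 1), (1, 0), (2, 0), (2, 1)}  # 3 true positives, 2 false positives
ap_heads = {(0, 0), (0, 1)}                          # 2 true positives, 0 false positives

ig_precision, ig_recall = precision_recall(ig_heads, relevant)
ap_precision, ap_recall = precision_recall(ap_heads, relevant)
print(f"IG: precision={ig_precision:.2f}, recall={ig_recall:.2f}")
print(f"AP: precision={ap_precision:.2f}, recall={ap_recall:.2f}")
```

In this toy setting the larger set finds more of the relevant heads (higher recall) at the cost of extra false positives (lower precision), which is the pattern described above.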

What causes false positives in IG?

In [None]:
# Set up

%load_ext autoreload
%autoreload 2

import torch
import random
from transformer_lens import HookedTransformer
from transformer_lens.utils import get_device, get_act_name
import numpy as np
import matplotlib.pyplot as plt

from attribution_methods import run_from_layer_fn, compute_layer_to_output_attributions
from testing import Task, TaskDataset, logit_diff_metric, identify_outliers
from plotting import plot_attn, plot_attn_comparison, plot_correlation, plot_correlation_comparison, plot_bar_chart

from split_ig import SplitIntegratedGradients, SplitLayerIntegratedGradients, compute_layer_to_output_attributions_split_ig, split_integrated_gradients

In [None]:
torch.set_grad_enabled(False)
torch.cuda.empty_cache()

device = get_device()
model = HookedTransformer.from_pretrained("gpt2-small", device=device)

# Explicitly calculate and expose the result for each attention head
model.set_use_attn_result(True)
model.set_use_hook_mlp_in(True)

## Output shapes

Hypothesis: the outliers overestimated by IG are caused by the shape of the output curve between the baseline and the input passed to IG.

- IG estimates the change in the model output by integrating gradients along the path between two input values.
- A high attribution score can be caused by strong gradients (high sensitivity) that occur only up to an intermediate point on that path. In this case, the highlighted component is important for the "in-between" task (represented by different counterfactual inputs) rather than for the target task.

![Overestimation](reference/overestimation.png)
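
A minimal one-dimensional sketch of this effect, assuming a toy saturating response `tanh(4x)` in place of the real model (the function, the path endpoints and the midpoint Riemann rule are all illustrative choices, not the setup used in this notebook). It approximates IG as a sum of per-step gradient contributions and checks how much of the total comes from the first half of the path.

```python
import numpy as np


def toy_output(x):
    # Hypothetical saturating response: rises steeply, then flattens out.
    return np.tanh(4.0 * x)


def toy_gradient(x):
    # Analytic derivative of toy_output.
    return 4.0 / np.cosh(4.0 * x) ** 2


baseline, target = 0.0, 1.0
n_steps = 50

# Midpoint Riemann rule: evaluate the gradient at the centre of each sub-interval.
alphas = (np.arange(n_steps) + 0.5) / n_steps
points = baseline + alphas * (target - baseline)
per_step = toy_gradient(points) * (target - baseline) / n_steps

ig_attribution = per_step.sum()
first_half_share = per_step[: n_steps // 2].sum() / ig_attribution

print(f"IG attribution:            {ig_attribution:.4f}")
print(f"Actual change in output:   {toy_output(target) - toy_output(baseline):.4f}")
print(f"Share from first half:     {first_half_share:.2%}")
```

In this toy case more than 90% of the attribution is accumulated before the midpoint of the path, so the score mostly reflects sensitivity at intermediate inputs rather than near the target input.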

To test this, we plot the model output at each of the interpolated inputs whose gradients IG sums. We focus on attention head (9, 6) because it is highlighted more strongly by IG than by AP.

In [None]:
ioi_dataset = TaskDataset(Task.IOI)
ioi_dataloader = ioi_dataset.to_dataloader(batch_size=100)

clean_input, corrupted_input, labels = next(iter(ioi_dataloader))

clean_tokens = model.to_tokens(clean_input)
corrupted_tokens = model.to_tokens(corrupted_input)

clean_logits, clean_cache = model.run_with_cache(clean_tokens)
clean_logit_diff = logit_diff_metric(clean_logits, labels)
print(f"Clean logit difference: {clean_logit_diff}")

corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)
corrupted_logit_diff = logit_diff_metric(corrupted_logits, labels)
print(f"Corrupted logit difference: {corrupted_logit_diff}")

In [None]:
from captum.attr._utils.approximation_methods import approximation_parameters

n_steps = 50

def visualise_attn_interpolated_outputs(target_layer_num, target_pos):
    hook_name = get_act_name("result", target_layer_num)
    visualise_interpolated_integrated_gradients(hook_name, target_layer_num, target_pos)


def visualise_mlp_interpolated_outputs(target_layer_num, target_pos):
    # "post" is the MLP activation after the non-linearity (blocks.{layer}.mlp.hook_post)
    hook_name = get_act_name("post", target_layer_num)
    visualise_interpolated_integrated_gradients(hook_name, target_layer_num, target_pos)


def visualise_interpolated_integrated_gradients(hook_name, target_layer_num, target_pos):
    target_layer = model.hook_dict[hook_name]
    layer_clean_input = clean_cache[hook_name] # Baseline

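    # Only corrupt the target head (or neuron); every other component keeps its clean activation
    layer_corrupt_input = layer_clean_input.clone()
    layer_corrupt_input[:, :, target_pos] = corrupted_cache[hook_name][:, :, target_pos]

    # Run the model on the clean tokens, patching the interpolated activation in at the target layer
    forward_fn = lambda x: run_from_layer_fn(model, clean_tokens, target_layer, x, logit_diff_metric, labels)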
    _, alphas_func = approximation_parameters("gausslegendre")
    alphas = alphas_func(n_steps)

    with torch.autograd.set_grad_enabled(True):
        interpolated_inputs = [layer_clean_input + alpha * (layer_corrupt_input - layer_clean_input) for alpha in alphas]
        outputs = [forward_fn(i) for i in interpolated_inputs]

    plt.title(f"Model output at interpolated inputs: {(target_layer_num, target_pos)}")
    plt.plot(alphas, [o.item() for o in outputs])
    plt.xlabel("Interpolation coefficient")
    plt.ylabel("Output (logit difference)")
    plt.ylim(0, 6)
    plt.show()

In [None]:
ig_mlp = torch.load("results/aligned/ioi/ig_mlp.pt")
ig_attn = torch.load("results/aligned/ioi/ig_attn.pt")

ap_mlp = torch.load("results/aligned/ioi/ap_mlp.pt")
ap_attn = torch.load("results/aligned/ioi/ap_attn.pt")

scaled_ig_attn = ig_attn * 1e5
attn_outliers = identify_outliers(scaled_ig_attn, ap_attn)
mlp_outliers = identify_outliers(ig_mlp, ap_mlp)

In [None]:
for layer, idx in attn_outliers[:5]:
    visualise_attn_interpolated_outputs(layer, idx)

for layer, idx in mlp_outliers[:5]:
    visualise_mlp_interpolated_outputs(layer, idx)

## Split Integrated Gradients

The shape of the output curves suggests that IG may overestimate some attribution values because the gradients saturate partway along the interpolation path. To confirm this, we run Split IG (which cuts off the interpolated inputs once the gradients are saturated) and examine how closely its attributions agree with those of standard IG.
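
The idea can be sketched in one dimension. This is an illustrative toy under stated assumptions, not the implementation in `split_ig`: it assumes the split point is the first interpolation step at which the output has covered `ratio` of its total change, and that only the gradient contributions before that point are kept. The function `split_ig_1d` and the `tanh` example are hypothetical.

```python
import numpy as np


def split_ig_1d(f, grad_f, baseline, target, ratio=0.9, n_steps=50):
    """Toy Split IG: return (left, right, split_idx) for a scalar function f."""
    alphas = (np.arange(n_steps) + 0.5) / n_steps          # midpoint rule
    points = baseline + alphas * (target - baseline)
    per_step = grad_f(points) * (target - baseline) / n_steps

    total_change = f(target) - f(baseline)
    if total_change == 0:
        # Nothing to split on: keep the whole path.
        return per_step.sum(), 0.0, n_steps

    # Fraction of the total output change reached at each interpolation step.
    progress = (f(points) - f(baseline)) / total_change
    reached = progress >= ratio
    split_idx = int(np.argmax(reached)) if reached.any() else n_steps

    left = per_step[:split_idx].sum()    # contributions before saturation
    right = per_step[split_idx:].sum()   # contributions discarded by the split
    return left, right, split_idx


def sech2(x):
    # Derivative of tanh.
    return 1.0 / np.cosh(x) ** 2


left, right, split_idx = split_ig_1d(np.tanh, sech2, 0.0, 4.0, ratio=0.9)
full, _, _ = split_ig_1d(np.tanh, sech2, 0.0, 4.0, ratio=1.0)
print(f"split at step {split_idx}: left={left:.4f}, right={right:.4f}, full IG={full:.4f}")
```

With `ratio=1.0` the threshold is never reached before the end of the path in this example, so no steps are discarded and the result equals ordinary IG. This is the property the sanity checks below rely on.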

### Sanity check: no splitting

We check that Split IG with a split ratio of 1 (i.e. no splitting) produces the same result as regular IG, for a random attention head (5, 5).

In [None]:
hook_name = get_act_name("result", 5)
target_layer = model.hook_dict[hook_name]
prev_layer_hook = get_act_name("z", 5)
prev_layer = model.hook_dict[prev_layer_hook]

layer_clean_input = clean_cache[prev_layer_hook]
layer_corrupt_input = corrupted_cache[prev_layer_hook]

# Shape [batch, seq_len, n_heads, d_model]
left_ig, _, _ = compute_layer_to_output_attributions_split_ig(
    clean_tokens, layer_corrupt_input, layer_clean_input, target_layer, prev_layer, logit_diff_metric, labels, ratio=1)

original_attributions = compute_layer_to_output_attributions(
    model, clean_tokens, layer_corrupt_input, layer_clean_input, target_layer, prev_layer, logit_diff_metric, labels
)

assert torch.allclose(left_ig, original_attributions.detach().cpu()), "Split IG with ratio=1 does not reproduce the standard IG result"

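We also verify that the underlying `SplitLayerIntegratedGradients` class, once its per-step attributions are summed, produces the same output as Captum's `LayerIntegratedGradients`.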

In [None]:
from captum.attr import LayerIntegratedGradients

n_samples = clean_tokens.size(0)
forward_fn = lambda x: run_from_layer_fn(model, clean_tokens, prev_layer, x, logit_diff_metric, labels)

split_ig = SplitLayerIntegratedGradients(forward_fn, target_layer, multiply_by_inputs=True)
split_ig_attributions, _, _, interpolated_inputs = split_ig.attribute(inputs=layer_corrupt_input,
                                baselines=layer_clean_input,
                                internal_batch_size=n_samples, # Needs to match patching shape
                                attribute_to_layer_input=False,
                                return_convergence_delta=False)
# Regroup the flat first dimension into [batch, n_steps] (50 interpolation steps),
# then sum over the steps to recover one attribution per sample.
split_ig_attributions = np.reshape(split_ig_attributions.detach().cpu().numpy(), (n_samples, 50,) + split_ig_attributions.shape[1:], order='F')
split_ig_attributions = torch.tensor(split_ig_attributions).to(device).sum(dim=1)

ig_embed = LayerIntegratedGradients(forward_fn, target_layer, multiply_by_inputs=True)
ig_attributions = ig_embed.attribute(inputs=layer_corrupt_input,
                                baselines=layer_clean_input, 
                                internal_batch_size=n_samples,
                                attribute_to_layer_input=False,
                                return_convergence_delta=False)

assert torch.allclose(split_ig_attributions, ig_attributions), "Split IG does not produce the same output as IG"
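
The Fortran-order reshape in the cell above is easy to misread, so here is a small NumPy-only sketch of what it does. It assumes the flat leading dimension is step-major (all samples for the first interpolation step, then all samples for the second, and so on), which is what you get when per-step batches are concatenated along dimension 0. The sizes and array names are illustrative.

```python
import numpy as np

n_steps, batch, d = 3, 2, 4

# Ground truth laid out as [n_steps, batch, d].
per_step = np.arange(n_steps * batch * d, dtype=float).reshape(n_steps, batch, d)

# Step-major flattening: rows 0..batch-1 are step 0, the next `batch` rows are step 1, ...
flat = per_step.reshape(n_steps * batch, d)

# order="F" makes the first new axis (batch) vary fastest, so row `b + batch * s`
# of `flat` lands at regrouped[b, s]; trailing dimensions are left unchanged.
regrouped = np.reshape(flat, (batch, n_steps, d), order="F")
summed = regrouped.sum(axis=1)  # one attribution per sample

# A default (C-order) reshape would instead mix different samples and steps together.
naive = flat.reshape(batch, n_steps, d)

print(np.array_equal(regrouped, per_step.transpose(1, 0, 2)))
print(np.array_equal(naive, per_step.transpose(1, 0, 2)))
```

The first comparison holds and the second does not, which is why the plain `reshape` shown commented out in the original cell would not be equivalent.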

### Run Split IG

We run Split IG on the same IOI dataset, using a split ratio of 0.9.

In [None]:
ioi_split_ig_mlp, ioi_split_ig_attn = split_integrated_gradients(
    model, clean_tokens, clean_cache, corrupted_cache, logit_diff_metric, labels, ratio=0.9
)

torch.save(ioi_split_ig_mlp, "saved_results/ioi_split_ig_mlp.pt")
torch.save(ioi_split_ig_attn, "saved_results/ioi_split_ig_attn.pt")

### Analysis

In [None]:
ioi_ig_mlp = torch.load("saved_results/ioi_ig_mlp.pt")
ioi_ig_attn = torch.load("saved_results/ioi_ig_attn.pt")

ioi_split_ig_mlp = torch.load("saved_results/ioi_split_ig_mlp.pt")
ioi_split_ig_attn = torch.load("saved_results/ioi_split_ig_attn.pt")

In [None]:
plot_attn_comparison(ioi_ig_attn[:3].unsqueeze(0), ioi_split_ig_attn[:3].unsqueeze(0), model, "Integrated Gradients", "Split Integrated Gradients (0.9)")

In [None]:
plot_correlation(ioi_ig_mlp, ioi_split_ig_mlp, "Integrated Gradients Attribution Scores", "Split IG Attribution Scores", "Attribution Scores for Neurons in IOI")


In [None]:
plot_correlation(ioi_ig_attn, ioi_split_ig_attn, "Integrated Gradients Attribution Scores", "Split IG Attribution Scores", "Attribution Scores for Attention Heads in IOI")
