# Background, motivation and setup

*Objective*: investigate the relationship between attribution scores and output gradients, and utilise this relationship to generate "optimal" counterfactual inputs, i.e. inputs for which a chosen model component is assigned a high attribution score by integrated gradients (IG) or activation patching (AP).

In [None]:
import torch
import numpy as np

from captum.attr import LayerIntegratedGradients
from captum.attr._utils.approximation_methods import approximation_parameters

from transformer_lens.utils import get_act_name, get_device
from transformer_lens import ActivationCache, HookedTransformer, HookedTransformerConfig
from transformer_lens.hook_points import HookPoint

import seaborn as sns
import matplotlib.pyplot as plt

# Number of interpolation points (alphas) taken along the baseline -> input path.
n_steps = 50

# Load GPT-2 small onto the device returned by TransformerLens' get_device().
device = get_device()
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    device=device,
)

In [None]:
#| output: true

# The two prompts differ only in the name before "gave": " Mary" (clean) vs " John" (corrupted).
clean_prompt = "After John and Mary went to the store, Mary gave a bottle of milk to"
corrupted_prompt = "After John and Mary went to the store, John gave a bottle of milk to"

# Tokenise both prompts (clean first, then corrupted) with the model's tokenizer.
clean_input, corrupted_input = (
    model.to_tokens(prompt) for prompt in (clean_prompt, corrupted_prompt)
)

def logits_to_logit_diff(logits, correct_answer=" John", incorrect_answer=" Mary"):
    """Return the last-position logit difference between two single-token answers.

    Args:
        logits: model output indexed as ``logits[batch, position, vocab]``; only batch 0
            and the final position are read.
        correct_answer: string that must correspond to exactly one token
            (resolved with ``model.to_single_token``).
        incorrect_answer: string resolved the same way.

    Returns:
        ``logits[0, -1, correct] - logits[0, -1, incorrect]`` as a tensor.
    """
    final_logits = logits[0, -1]
    # Resolve the answer strings to vocabulary indices (correct first, then incorrect).
    correct_id, incorrect_id = (
        model.to_single_token(answer) for answer in (correct_answer, incorrect_answer)
    )
    return final_logits[correct_id] - final_logits[incorrect_id]

# Explicitly calculate and expose the result for each attention head
# (the per-head "result" activations are read later via get_act_name("result", layer)).
model.set_use_attn_result(True)
# NOTE(review): presumably exposes the MLP input as a hook point; nothing in the code
# visible here reads it — confirm whether this is still needed.
model.set_use_hook_mlp_in(True)

# Clean run: its cached activations serve as the baseline in the functions below.
clean_logits, clean_cache = model.run_with_cache(clean_input)
clean_logit_diff = logits_to_logit_diff(clean_logits)
print(f"Clean logit difference: {clean_logit_diff.item():.3f}")

# Corrupted run: its cached activations supply the values patched in at the target head.
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_input)
corrupted_logit_diff = logits_to_logit_diff(corrupted_logits)
print(f"Corrupted logit difference: {corrupted_logit_diff.item():.3f}")

In [None]:
def run_from_layer_fn(x, original_input, prev_layer, reset_hooks_end=True):
    """Run the model with the activation at hook point ``prev_layer`` replaced by ``x``.

    A forward hook on ``prev_layer.name`` ignores the activation the model computed there
    and returns ``x`` instead, so everything downstream of that hook point is computed
    from ``x``. ``original_input`` only supplies the tokens (and hence the shapes) for the
    forward pass; its own activation at that hook point is overwritten.

    Args:
        x: replacement activation; must match the shape the hook point normally produces.
            Its ``requires_grad`` flag is switched on in place when the hook fires, so
            gradients with respect to ``x`` can be taken afterwards.
        original_input: token tensor fed to the model.
        prev_layer: hook point object; only its ``.name`` is used.
        reset_hooks_end: forwarded unchanged to ``model.run_with_hooks``.

    Returns:
        1-element tensor containing ``logits_to_logit_diff`` of the hooked run.
    """
    def replace_activation(_activation, hook):
        x.requires_grad_(True)
        return x

    hooked_logits = model.run_with_hooks(
        original_input,
        fwd_hooks=[(prev_layer.name, replace_activation)],
        reset_hooks_end=reset_hooks_end,
    )
    return logits_to_logit_diff(hooked_logits).unsqueeze(0)

# Gradients with respect to interpolated inputs

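We compute the gradient of the model output with respect to the interpolated input at the target component. We quantify how much this gradient changes along the interpolation path by its range, i.e. the maximum minus the minimum of the mean gradient across interpolation steps.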

In [None]:
def visualise_attn_interpolated_outputs(target_layer_num, target_pos, ylim=(0, 6)):
    """Plot the model output as one attention head moves from clean to corrupted.

    Every head of layer ``target_layer_num`` keeps its clean ``result`` activation except
    head ``target_pos``, which is moved along the straight line from its clean value
    (alpha = 0) to its corrupted value (alpha = 1). The output (logit difference) is
    evaluated at the Gauss-Legendre interpolation points used for integrated gradients.

    Args:
        target_layer_num: layer index of the head.
        target_pos: index of the head within that layer.
        ylim: y-axis limits ``(bottom, top)`` for the plot, or ``None`` to let matplotlib
            choose. Defaults to the previously hard-coded ``(0, 6)``.
    """
    hook_name = get_act_name("result", target_layer_num)
    target_layer = model.hook_dict[hook_name]

    layer_clean_input = clean_cache[hook_name]  # Baseline (alpha = 0)

    # Only corrupt at target head (dim 2 of the result activation indexes heads).
    layer_corrupt_input = layer_clean_input.clone()
    layer_corrupt_input[:, :, target_pos] = corrupted_cache[hook_name][:, :, target_pos]

    # Take the model starting from the target layer
    forward_fn = lambda x: run_from_layer_fn(x, clean_input, target_layer)
    _, alphas_func = approximation_parameters("gausslegendre")
    alphas = alphas_func(n_steps)

    # Only forward values are needed here, so don't build autograd graphs.
    with torch.no_grad():
        interpolated_inputs = [
            layer_clean_input + alpha * (layer_corrupt_input - layer_clean_input)
            for alpha in alphas
        ]
        output_values = [forward_fn(i).item() for i in interpolated_inputs]

    print(output_values)

    plt.title(f"Model output at interpolated inputs: head {(target_layer_num, target_pos)}")
    # Plot against the actual interpolation coefficients, matching the x-axis label.
    plt.plot(alphas, output_values)
    plt.xlabel("Interpolation coefficient")
    plt.ylabel("Output (logit difference)")
    if ylim is not None:
        plt.ylim(*ylim)
    plt.show()

In [None]:
def get_layer_baseline_inputs(target_layer_num, target_pos, clean_cache, corrupt_cache):
    """Build the IG baseline/input pair for a single attention head.

    Args:
        target_layer_num: layer whose ``result`` activation is used.
        target_pos: head index (dim 2 of the activation) to corrupt.
        clean_cache: activation cache from the clean run.
        corrupt_cache: activation cache from the corrupted run.

    Returns:
        ``(baseline, patched, hook_point)`` where ``baseline`` is the clean activation,
        ``patched`` is a copy of it with head ``target_pos`` taken from the corrupted run,
        and ``hook_point`` is the model's hook point for that activation.
    """
    act_name = get_act_name("result", target_layer_num)
    hook_point = model.hook_dict[act_name]

    clean_act = clean_cache[act_name]  # baseline: left untouched
    patched_act = clean_act.clone()
    patched_act[:, :, target_pos] = corrupt_cache[act_name][:, :, target_pos]

    return clean_act, patched_act, hook_point


# Collapse attribution scores to a single value per (leading index, head).
def mean_attribution(attribution_scores, pos=None):
    """Average a 4-D attribution tensor over dims 3 and 1.

    The mean is taken over dim 3 first, then over dim 1. With the
    (batch, seq_len, n_heads, d_model) layout of the ``result`` activations used in this
    notebook (TODO confirm for other callers), that averages over the embedding and then
    over *all tokens*, giving one score per (batch, head).

    Args:
        attribution_scores: 4-D tensor of scores or gradients.
        pos: optional index into the remaining second axis (the head axis for the layout
            above). When given, only that column is returned.

    Returns:
        The averaged tensor, or its ``[:, pos]`` slice when ``pos`` is not ``None``.
    """
    averaged = attribution_scores.mean(dim=3).mean(dim=1)
    return averaged if pos is None else averaged[:, pos]


def attn_interpolated_gradients(target_layer_num, target_pos):
    """Gradient of the output w.r.t. one attention head at each interpolation step.

    The baseline is the clean ``result`` activation of layer ``target_layer_num``; the
    input is that activation with head ``target_pos`` replaced by its corrupted value
    (see ``get_layer_baseline_inputs``). For each Gauss-Legendre interpolation point the
    model is re-run from the interpolated activation, and the gradient of the logit
    difference with respect to that activation is taken. Only the target head's slice
    is kept.

    Args:
        target_layer_num: layer index of the head.
        target_pos: index of the head within that layer.

    Returns:
        ``(alphas, grad_history)``. ``alphas`` holds the interpolation coefficients.
        ``grad_history`` is a CPU tensor of shape (n_steps, seq_len, 1, d_model); the
        singleton axis stands for the selected head, so ``mean_attribution`` reduces it
        to shape (n_steps, 1).
    """
    # Get the baseline inputs
    layer_baseline, layer_input, target_layer = get_layer_baseline_inputs(
        target_layer_num, target_pos, clean_cache, corrupted_cache
    )
    # Hooks are left registered after each forward pass and removed manually below.
    forward_fn = lambda x: run_from_layer_fn(x, clean_input, target_layer, reset_hooks_end=False)

    # Get interpolated inputs according to step sizes
    _, alphas_func = approximation_parameters("gausslegendre")
    alphas = alphas_func(n_steps)
    interpolated_inputs = [layer_baseline + alpha * (layer_input - layer_baseline) for alpha in alphas]

    # layer_input has shape (batch, seq_len, n_heads, d_model); batch is 1 (single prompt).
    _, seq_len, _, d_model = layer_input.shape
    grad_history = torch.zeros((len(alphas), seq_len, 1, d_model))

    with torch.autograd.set_grad_enabled(True):
        for idx, interpolated in enumerate(interpolated_inputs):
            try:
                output = forward_fn(interpolated)
                # Same shape as `interpolated`: (batch, seq_len, n_heads, d_model)
                grad = torch.autograd.grad(output, interpolated)[0]
            finally:
                # Always remove the replacement hook, even if the forward/backward failed.
                model.reset_hooks()

            # Select the target head on the head axis (dim 2) for batch 0, then add a
            # singleton head axis: (seq_len, d_model) -> (seq_len, 1, d_model).
            grad_history[idx] = grad[0, :, target_pos, :].unsqueeze(1).to(grad_history.device)

    return alphas, grad_history


def quantify_gradients_range(mean_grad_history):
    """Return the spread (maximum minus minimum) of ``mean_grad_history``.

    Meant for the (n_steps, 1) output of ``mean_attribution``. The reduction runs over
    every element, so any tensor shape is accepted. The result is a 0-d tensor.
    """
    return mean_grad_history.max() - mean_grad_history.min()

We visualise how the gradient changes for a specific attention head and compare it to the shape of that head's output curve. We also check that the gradient range makes sense.

In [None]:
# Head 11.10: model output along the clean -> corrupted path for this head, then the
# mean gradient of the output w.r.t. the head's activation at each interpolation step.
visualise_attn_interpolated_outputs(target_layer_num=11, target_pos=10)

alphas, grad_history_1110 = attn_interpolated_gradients(target_layer_num=11, target_pos=10)

mean_grad_history_1110 = mean_attribution(grad_history_1110)  # Shape (n_steps, 1)

plt.title("Mean gradient of output wrt interpolated inputs (head 11.10)")
plt.plot(alphas, [grad.item() for grad in mean_grad_history_1110])
plt.xlabel("Interpolation coefficient")
plt.ylabel("Gradient of output wrt input at head")
plt.show()

# Spread (max - min) of the mean gradient along the path.
print(quantify_gradients_range(mean_grad_history_1110))

In [None]:
# Head 9.6: model output along the clean -> corrupted path for this head, then the
# mean gradient of the output w.r.t. the head's activation at each interpolation step.
visualise_attn_interpolated_outputs(target_layer_num=9, target_pos=6)

alphas, grad_history_96 = attn_interpolated_gradients(target_layer_num=9, target_pos=6)

mean_grad_history_96 = mean_attribution(grad_history_96)  # Shape (n_steps, 1)

plt.title("Mean gradient of output wrt interpolated inputs (head 9.6)")
plt.plot(alphas, [grad.item() for grad in mean_grad_history_96])
plt.xlabel("Interpolation coefficient")
plt.ylabel("Gradient of output wrt input at head")
plt.show()

# Spread (max - min) of the mean gradient along the path.
print(quantify_gradients_range(mean_grad_history_96))

### Correlation between gradient range and activation patching scores

We compute the gradient range for every attention head and compare it to that head's activation patching score.

In [None]:
# Gradient range for every attention head: rows are layers, columns are heads.
attn_grad_ranges = torch.zeros((model.cfg.n_layers, model.cfg.n_heads))

# n_layers / n_heads are integer counts, so iterate over range(...) of them;
# iterating the int directly raises TypeError.
for layer_num in range(model.cfg.n_layers):
    for head_pos in range(model.cfg.n_heads):
        _, grad_history = attn_interpolated_gradients(layer_num, head_pos)
        mean_grad_history = mean_attribution(grad_history)
        grad_range = quantify_gradients_range(mean_grad_history)
        attn_grad_ranges[layer_num, head_pos] = grad_range

In [None]:
# Precomputed activation patching scores, presumably one per (layer, head) — confirm
# against the notebook that saved them. Load onto the CPU so a tensor saved from a GPU
# session still loads here and can be plotted / passed to NumPy below.
attn_patch_results = torch.load("attn_patch_results.pt", map_location="cpu")

In [None]:
# Plot gradient ranges and activation patching scores side-by-side in one figure.
# (Previously each heatmap got its own figure at subplot(1, 2, 1), so they were never
# side by side.)
plt.figure(figsize=(16, 8))

heatmap_panels = [
    (1, "Gradient ranges for attention heads", attn_grad_ranges),
    (2, "Activation patching scores for attention heads", attn_patch_results),
]
for panel_idx, panel_title, panel_scores in heatmap_panels:
    plt.subplot(1, 2, panel_idx)
    plt.title(panel_title)
    # detach/cpu so matplotlib can convert the tensor regardless of grad flag or device.
    plt.imshow(panel_scores.detach().cpu(), cmap="RdBu")
    plt.xlabel("Head Index")
    plt.xticks(list(range(model.cfg.n_heads)))
    plt.ylabel("Layer")
    plt.yticks(list(range(model.cfg.n_layers)))
    plt.colorbar(orientation="horizontal")

plt.show()

In [None]:
# Plot correlation between gradient range and activation patching score

# Flatten both grids into paired 1-D arrays, one entry per head (assumes both tensors
# share the (n_layers, n_heads) layout — TODO confirm for attn_patch_results).
# Convert explicitly to CPU NumPy: seaborn and np.corrcoef cannot consume tensors that
# live on a GPU or require grad.
attn_grad_ranges_1d = attn_grad_ranges.detach().cpu().flatten().numpy()
attn_patch_results_1d = attn_patch_results.detach().cpu().flatten().numpy()

sns.regplot(x=attn_grad_ranges_1d, y=attn_patch_results_1d)
plt.xlabel("Range in gradient of output wrt interpolated inputs")
plt.ylabel("Activation patching attribution scores")
plt.show()

print(f"Correlation coefficient between gradient range and activation patching score: {np.corrcoef(attn_grad_ranges_1d, attn_patch_results_1d)[0, 1]}")