# Tracing Back Gendered Pronouns in Biased Prompts
In 2023, the authors of [We Found An Neuron in GPT-2](https://www.lesswrong.com/posts/cgqh99SHsCv3jJYDS/we-found-an-neuron-in-gpt-2#Finding_3__We_can_use_neurons__output_congruence_to_find_specific_neurons_that_predict_a_token) found that a single MLP neuron in GPT-2 Large is crucial for predicting the token "an" over "a".
After reading this blog post, I began to wonder whether the same method could be used to analyse the prediction of one binary-gendered pronoun (he/she) over the other. Even if this approach is not enough to understand where or why a model chooses one over the other (spoiler: it isn't!), it can give us insight into the granularity of grammatical gender information in a transformer.

## Neurons, Pronouns, and Mechanistic Interpretability
The role of individual neurons in predicting a binary-gendered pronoun has been [analysed in the past](https://arxiv.org/abs/2004.12265). Vig et al. frame specific neurons as mediators in the causal path between input and output, finding that the neurons in layer 0 and the first hidden layer have the strongest indirect effects, with a small "bump" in the middle layers where the effect increases after the initial decline. Additionally, they find that most neurons don't contribute much, indicating that they may be specialized.

## Imports

In [40]:
import torch
from fancy_einsum import einsum
from transformer_lens import HookedTransformer, utils, ActivationCache
from torchtyping import TensorType as TT
import plotly.express as px
import numpy as np
from functools import partial
import transformer_lens.patching as patching

# setup env variables
import os
# change gpu here
os.environ["CUDA_VISIBLE_DEVICES"] = "1, 3"

## Utils

In [41]:
# for plots

def imshow(tensor, **kwargs):
    px.imshow(
        utils.to_numpy(tensor),
        color_continuous_midpoint=0.0,
        color_continuous_scale="RdBu",
        **kwargs,
    ).show()


def line(tensor, **kwargs):
    px.line(
        y=utils.to_numpy(tensor),
        **kwargs,
    ).show()


def scatter(x, y, xaxis="", yaxis="", caxis="", **kwargs):
    x = utils.to_numpy(x)
    y = utils.to_numpy(y)
    px.scatter(
        y=y,
        x=x,
        labels={"x": xaxis, "y": yaxis, "color": caxis},
        **kwargs,
    ).show()

## Model Setup

In [42]:
model = HookedTransformer.from_pretrained(
    "gpt2-small",
    center_unembed=True,
    center_writing_weights=True,
    fold_ln=True,
    refactor_factored_attn_matrices=True,
)

device: torch.device = utils.get_device()


`clean_up_tokenization_spaces` was not set. It will be set to `True` by default. This behavior will be depracted in transformers v4.45, and will be then set to `False` by default. For more details check this issue: https://github.com/huggingface/transformers/issues/31884



Loaded pretrained model gpt2-small into HookedTransformer


## Prompt Setup

Most work on gender bias and interpretability focuses on *occupational bias*, where a profession is associated with a certain gender (nurse -> woman, doctor -> man, etc.). However, gender bias in text can manifest in different [contextual and structural ways](https://aclanthology.org/W19-3802.pdf). For this experiment to be comprehensive, we should test our hypothesis on several prompts that cover diverse contextual and structural biases, and that are inclusive of non-binary genders and pronouns (we leave that as future work).
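One way to move beyond a single prompt is to generate clean/corrupted pairs from a template. The sketch below is only an illustration: the template, the noun pairs, and the helper `build_prompt_pairs` are hypothetical and are not used in the experiment that follows.

```python
# Hypothetical template and noun pairs, chosen only for illustration.
TEMPLATE = "The {noun} was working because"

# Each pair is (clean noun, corrupted noun); the corrupted noun is meant to
# push the model towards the opposite pronoun.
NOUN_PAIRS = [
    ("parent", "man"),
    ("nurse", "man"),
    ("doctor", "woman"),
]


def build_prompt_pairs(template, noun_pairs):
    """Return parallel lists of clean and corrupted prompts."""
    clean = [template.format(noun=c) for c, _ in noun_pairs]
    corrupted = [template.format(noun=x) for _, x in noun_pairs]
    return clean, corrupted


clean_prompts, corrupted_prompts_batch = build_prompt_pairs(TEMPLATE, NOUN_PAIRS)
for c, x in zip(clean_prompts, corrupted_prompts_batch):
    print(f"{c!r} -> {x!r}")
```

Before patching, each pair would still need to be checked with `model.to_tokens` to confirm that the clean and corrupted prompts have the same number of tokens, and each pair needs its own choice of "correct" pronoun.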

In [43]:
he_tok, she_tok = model.to_single_token(" he"), model.to_single_token(" she")
answer_token_indices = torch.tensor([[she_tok, he_tok]], device=model.cfg.device)
prompts = ["The parent was working because"]
utils.test_prompt(prompts[0], " she", model, prepend_bos=True, top_k=5)

Tokenized prompt: ['<|endoftext|>', 'The', ' parent', ' was', ' working', ' because']
Tokenized answer: [' she']


Top 0th token. Logit: 16.21 Prob: 32.86% Token: | she|
Top 1th token. Logit: 15.74 Prob: 20.59% Token: | he|
Top 2th token. Logit: 14.83 Prob:  8.33% Token: | of|
Top 3th token. Logit: 14.44 Prob:  5.60% Token: | the|
Top 4th token. Logit: 14.35 Prob:  5.15% Token: | they|


As expected, `gpt2-small` associates *parent* with a female pronoun! We now take the stereotypical *she* pronoun as the "correct" answer, in order to trace back the neuron activations that caused this output. 

In [45]:
def ave_correct_incorrect_logit_diff(logits, correct_tok, incorrect_tok, per_prompt=False):
    """Returns the logit difference between the correct and incorrect answer tokens."""
    final_token_logits = logits[:, -1, :] # Only the final logits are relevant for the answer
    answer_logits = final_token_logits.gather(dim=-1, index=torch.tensor([[correct_tok, incorrect_tok]], device=device))
    answer_logit_diff = answer_logits[:, 0] - answer_logits[:, 1]
    return answer_logit_diff if per_prompt else answer_logit_diff.mean()

tokens = model.to_tokens(prompts, prepend_bos=True).to(device=device)
clean_logits, cache = model.run_with_cache(tokens)
clean_logit_diff = ave_correct_incorrect_logit_diff(clean_logits, correct_tok=she_tok, incorrect_tok=he_tok)
print("' she' / ' he' logit difference:", clean_logit_diff.item())

' she' / ' he' logit difference: 0.46780967712402344
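A logit difference is easier to read as an odds ratio. The short check below uses only the numbers printed above (rounded as shown in the outputs) and confirms that the logit difference and the reported probabilities tell the same story.

```python
import math

# Values copied from the outputs above.
logit_diff = 0.4678            # ' she' minus ' he' logit difference
p_she, p_he = 0.3286, 0.2059   # probabilities reported by utils.test_prompt

# Softmax shares one normaliser across all tokens, so it cancels in the ratio:
# p(she) / p(he) = exp(logit_she - logit_he).
odds_from_logits = math.exp(logit_diff)
odds_from_probs = p_she / p_he

print(f"odds from logit diff:    {odds_from_logits:.3f}")
print(f"odds from probabilities: {odds_from_probs:.3f}")
```

Both values come out at roughly 1.6, so the model considers *she* about 1.6 times as likely as *he* for this prompt: a preference, but not an overwhelming one.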


We now run a **corrupted** prompt, in which the model predicts *he*. To elicit that output, we replace *parent* with *man*.

In [48]:
corrupted_prompts = ["The man was working because"]
corrupted_tokens = model.to_tokens(corrupted_prompts, prepend_bos=True)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens, return_type="logits")
corrupted_logit_diff = ave_correct_incorrect_logit_diff(corrupted_logits, correct_tok=she_tok, incorrect_tok=he_tok)
print("Corrupted Average Logit Diff", corrupted_logit_diff)
print("Clean Logit Diff", clean_logit_diff)

Corrupted Average Logit Diff tensor(-4.5342, device='cuda:0', grad_fn=<MeanBackward0>)
Clean Logit Diff tensor(0.4678, device='cuda:0', grad_fn=<MeanBackward0>)


## Activation Patching

Let's start by patching the residual stream at the start of each block.

In [49]:
# Whether to do the runs by head and by position, which are much slower
DO_SLOW_RUNS = True

def get_logit_diff(logits, answer_token_indices):
    if len(logits.shape)==3:
        # Get final logits only
        logits = logits[:, -1, :]
    correct_logits = logits.gather(1, answer_token_indices[:, 0].unsqueeze(1))
    incorrect_logits = logits.gather(1, answer_token_indices[:, 1].unsqueeze(1))
    return (correct_logits - incorrect_logits).mean()

clean_logits, clean_cache = model.run_with_cache(tokens)
corrupted_logits, corrupted_cache = model.run_with_cache(corrupted_tokens)


CLEAN_BASELINE = clean_logit_diff
CORRUPTED_BASELINE = corrupted_logit_diff
def ioi_metric(logits, answer_token_indices=answer_token_indices):
    # Subtract corrupted logit diff to measure the improvement, divide by the total improvement from clean to corrupted to normalise
    # 0 means zero change, negative means actively made worse, 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
    return (get_logit_diff(logits, answer_token_indices) - CORRUPTED_BASELINE) / (CLEAN_BASELINE  - CORRUPTED_BASELINE)

resid_pre_act_patch_results = patching.get_act_patch_resid_pre(model, corrupted_tokens, cache, ioi_metric)

imshow(resid_pre_act_patch_results, 
       labels={"x": "Position", "y": "Layer"},
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(tokens[0]))],
       title="resid_pre Activation Patching")

100%|██████████| 72/72 [00:01<00:00, 46.60it/s]
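To make the scale of the patching plots concrete, here is a minimal, standalone version of the normalised metric, using the clean and corrupted baselines printed earlier. The patched values fed into it are hypothetical and only illustrate how to read the numbers.

```python
# Baselines copied from the outputs above.
CLEAN = 0.4678       # logit diff on "The parent was working because"
CORRUPTED = -4.5342  # logit diff on "The man was working because"


def normalised_metric(patched_logit_diff, clean=CLEAN, corrupted=CORRUPTED):
    """Fraction of the clean-vs-corrupted gap recovered by a patch."""
    return (patched_logit_diff - corrupted) / (clean - corrupted)


# Hypothetical patched logit differences, chosen only to illustrate the scale.
examples = [
    ("no effect (equals corrupted)", CORRUPTED),
    ("halfway", (CLEAN + CORRUPTED) / 2),
    ("full recovery (equals clean)", CLEAN),
    ("worse than corrupted", CORRUPTED - 1.0),
]
for name, value in examples:
    print(f"{name:32s} {normalised_metric(value):+.2f}")
```

A cell near 1 in the heatmap therefore means that patching that single activation from the clean run is almost enough to flip the corrupted prompt back to the clean behaviour.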


The model processes information regarding the *parent* token up until layer 9, and around layer 10 that information is moved to the last token.
We now attempt to replicate the results of Vig et al. by patching the output of the MLP layers and the attention layers.

In [50]:
mlp_out_act_patch_results = patching.get_act_patch_mlp_out(model, corrupted_tokens, cache, ioi_metric)

imshow(mlp_out_act_patch_results,
       labels={"x": "Position", "y": "Layer"},
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(tokens[0]))],
       title="mlp_out Activation Patching")

100%|██████████| 72/72 [00:01<00:00, 46.98it/s]


MLP 0 has the strongest effect, but in GPT-2 it is generally regarded as a layer that extends the functionality of the embedding, so we ignore it. Disregarding MLP 0, layers 5 and 7 have the largest effect at the *parent* position. Layer 6 has a surprisingly negative effect.

In [51]:
attn_out_act_patch_results = patching.get_act_patch_attn_out(model, corrupted_tokens, cache, ioi_metric)

imshow(attn_out_act_patch_results,
       labels={"x": "Position", "y": "Layer"},
       x=[f"{tok} {i}" for i, tok in enumerate(model.to_str_tokens(tokens[0]))],
       title="attn_out Activation Patching")

100%|██████████| 72/72 [00:01<00:00, 47.25it/s]


In the attention modules, the highest impact appears in the last layers. Following Vig et al., the hypothesis here is that the biased information emerges in the MLP modules and is then picked up by the top attention modules.

In [52]:
attn_head_out_act_patch_results = patching.get_act_patch_attn_head_out_all_pos(model, corrupted_tokens, cache, ioi_metric)

imshow(attn_head_out_act_patch_results,
       labels={"x": "Head", "y": "Layer"},
       title="attn_head_out Activation Patching (All Positions)")

100%|██████████| 144/144 [00:03<00:00, 39.15it/s]


If we compare the two attention plots, it seems that L9H7 and L10H9 have a very high impact. To interpret this, let's use singular value decomposition (SVD) to obtain interpretable [semantic clusters](https://www.lesswrong.com/posts/mkbGjzxD8d8XqKHzA/the-singular-value-decompositions-of-transformer-weight) in the OV circuits!
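As a rough picture of what this involves, the sketch below runs the same steps on random toy matrices with NumPy. It is an assumption-laden illustration, not the `SVDInterpreter` implementation: the matrix sizes are made up, and the shapes follow the row-vector convention in which a head's output is `x @ W_V @ W_O`.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy sizes, far smaller than a real model.
d_model, d_head, d_vocab = 16, 4, 50

# Random stand-ins for one head's value and output weights and the unembedding.
W_V = rng.normal(size=(d_model, d_head))
W_O = rng.normal(size=(d_head, d_model))
W_U = rng.normal(size=(d_model, d_vocab))

# The OV circuit maps the residual stream back to itself through the head.
W_OV = W_V @ W_O  # (d_model, d_model), rank at most d_head

# Rows of Vh span the directions the head can write into the residual stream.
U, S, Vh = np.linalg.svd(W_OV)

# Project the leading directions onto the vocabulary and keep the top-k tokens.
k = 5
token_scores = Vh[:d_head] @ W_U                       # (d_head, d_vocab)
top_token_ids = np.argsort(-token_scores, axis=1)[:, :k]

print("singular values:", np.round(S[: d_head + 1], 3))
print("top token ids per singular direction:")
print(top_token_ids)
```

Only the first `d_head` singular values are meaningfully non-zero, because the OV matrix is a product of two thin matrices. Singular vectors are also only defined up to sign, so the tokens with the most negative scores can be just as informative as the top ones.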

### Utils for SVD decomposition

In [54]:
from copy import deepcopy
from tabulate import tabulate
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")

def get_max_token_length(tokens):
    """Length of the longest token string (0 for an empty list)."""
    return max((len(t) for t in tokens), default=0)

def pad_with_space(t, maxlen):
    """Right-pad `t` with spaces up to `maxlen` characters."""
    return t.ljust(maxlen)

def convert_to_tokens(indices, tokenizer, strip=True, pad_to_maxlen=False):
    res = tokenizer.convert_ids_to_tokens(indices)
    if strip:
        res = list(map(lambda x: x[1:] if x[0] == 'Ġ' else "#" + x, res))
    if pad_to_maxlen:
      maxlen = get_max_token_length(res)
      res = list(map(lambda t: pad_with_space(t, maxlen), res))
    return res


def top_tokens(v_tok, k=100, pad_to_maxlen=False):
    v_tok = deepcopy(v_tok)
    ignored_indices = []
    v_tok[ignored_indices] = -np.inf
    values, indices = torch.topk(v_tok, k=k)
    res = convert_to_tokens(indices, tokenizer, pad_to_maxlen = pad_to_maxlen)
    return res

# Utility function to print the top tokens in the same style as the Conjecture post.
# The Conjecture post's plotting library doesn't work for some reason, so we compute the top tokens manually and print them with tabulate.
def plot_matrix(matrix, k=20):
  tensor = matrix.squeeze(1)
  Vs = list(torch.unbind(tensor, dim=1))
  Vs = [top_tokens(Vs[i].detach().float().cpu(), k = k, pad_to_maxlen=True) for i in range(len(Vs))]
  print(tabulate([*zip(*Vs)]))





In [55]:
from transformer_lens import SVDInterpreter

svd_interpreter = SVDInterpreter(model)

ov1 = svd_interpreter.get_singular_vectors('OV', layer_index=9, head_index=7, num_vectors=20)
ov2 = svd_interpreter.get_singular_vectors('OV', layer_index=10, head_index=9, num_vectors=20)

plot_matrix(ov1)
plot_matrix(ov2)

--------  --------  -------  --------  ----------  ------------  -----------------  ---------  ----------  ------------  -----------  -------------  ---------  ---------------  -----------  -------------------------  ------------------  ------------------  --------  ---------
#his      #His      #He      #xon      rigid       Both          #ocious            #mented    #ãĤ®        latter        #otions      #aughs         #ESA       #CHAT            #bledon      #çīĪ                       #catentry           #*/(                Syd       Sta
#He       #his      #His     #ãĥ¯ãĥ³   strict      #Both         Highly             #ql        #knife      #both         #redo        happily        #osta      #ILA             #estinal     #etheless                  #wcsstore           Flavoring           Saur      aviation
his       his       himself  #warts    #areth      Neither       #azeera            #cious     #ubi        #Both         #zed         #ignt          hers       #ARI            

We see that both attention heads focus on binary-gendered pronouns, with L9H7 focusing more on male pronouns and L10H9 focusing more on female pronouns.
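A note on reading the table: the `#` prefix comes from the `strip=True` branch of `convert_to_tokens`. The snippet below reproduces that one rule on a hand-written list of raw token strings (the list itself is illustrative), without loading the tokenizer.

```python
def strip_gpt2_marker(token):
    """Mirror the `strip=True` branch of `convert_to_tokens` for one token."""
    # GPT-2's byte-level BPE writes a leading space as 'Ġ'.
    return token[1:] if token[0] == "Ġ" else "#" + token


# Raw token strings in the form the GPT-2 vocabulary stores them.
raw_tokens = ["Ġhis", "his", "ĠHe", "He"]
cleaned = [strip_gpt2_marker(t) for t in raw_tokens]
print(cleaned)  # ['his', '#his', 'He', '#He']
```

So `his` in the table is the token with a leading space (typically a word on its own), while `#his` is the same string without a leading space, such as a word continuation or a token at the very start of the text.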

## Neuron Activation (finally)

In [57]:
def patch_resid(corrupted_resid: TT["batch", "pos", "d_model"], hook, pos, clean_cache):
    corrupted_resid[:, pos, :] = clean_cache[hook.name][:, pos, :]
    return corrupted_resid

def normalize_patched_logit_diff(patched_logit_diff):
    # Subtract corrupted logit diff to measure the improvement,
    # divide by the total improvement from clean to corrupted to normalise.
    # 0 means zero change, negative means actively made worse,
    # 1 means totally recovered clean performance, >1 means actively *improved* on clean performance
    return (patched_logit_diff - corrupted_logit_diff)/(clean_logit_diff - corrupted_logit_diff)

def patch_neuron_activation(corrupted_mlp_act: TT["batch", "pos", "d_mlp"], hook, neuron, clean_cache):
    corrupted_mlp_act[:, :, neuron] = clean_cache[hook.name][:, :, neuron]
    return corrupted_mlp_act

from tqdm import tqdm

# Buffer reused across layers: entry i holds the normalised improvement from patching neuron i.
patched_neurons_normalized_improvement = torch.zeros(model.cfg.d_mlp, device=device, dtype=torch.float32)

def plot_patched_neurons(layer):
    with torch.no_grad():
        for neuron in tqdm(range(model.cfg.d_mlp)):
            hook_fn = partial(patch_neuron_activation, neuron=neuron, clean_cache=cache)
            patched_neuron_logits = model.run_with_hooks(
                corrupted_tokens,
                fwd_hooks = [(f"blocks.{layer}.mlp.hook_post", hook_fn)],
                return_type="logits"
            )
            patched_neuron_logit_diff = ave_correct_incorrect_logit_diff(patched_neuron_logits, correct_tok=she_tok, incorrect_tok=he_tok)
            patched_neurons_normalized_improvement[neuron] = normalize_patched_logit_diff(patched_neuron_logit_diff)
    patched_neuron_fig = px.scatter(y=patched_neurons_normalized_improvement.cpu().detach(),
        x=list(range(len(patched_neurons_normalized_improvement))), 
        title=f"Logit Difference From Patched Neurons in MLP Layer {layer}", 
        labels={"x":"Neuron", "y":"Patch Improvement"},
        )
    #patched_neuron_fig.add_annotation(x=1000, y=0.485, text="Neuron 892 stands out", showarrow=True, arrowhead=1, ax=50, ay=40)
    patched_neuron_fig.show()
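The loop above can be hard to picture, so here is the same idea on a toy network in plain Python. Everything in it (the weights, the inputs, the three "neurons") is made up for illustration; the point is only the mechanics of keeping the corrupted run and overwriting one activation with its clean value.

```python
# Toy two-layer network: hidden = relu(x * w_in), output = sum(hidden * w_out).
W_IN = [1.0, -2.0, 0.5]
W_OUT = [0.1, 3.0, -0.2]


def hidden_acts(x):
    return [max(0.0, x * w) for w in W_IN]


def output_from_hidden(hidden):
    return sum(h * w for h, w in zip(hidden, W_OUT))


clean_x, corrupted_x = -1.0, 1.0
clean_hidden = hidden_acts(clean_x)          # plays the role of the clean cache
corrupted_hidden = hidden_acts(corrupted_x)

clean_out = output_from_hidden(clean_hidden)
corrupted_out = output_from_hidden(corrupted_hidden)

# Patch one neuron at a time: keep the corrupted run, overwrite one activation.
improvement = []
for neuron in range(len(W_IN)):
    patched = list(corrupted_hidden)
    patched[neuron] = clean_hidden[neuron]
    patched_out = output_from_hidden(patched)
    improvement.append((patched_out - corrupted_out) / (clean_out - corrupted_out))

print([round(v, 3) for v in improvement])
```

In this toy, neuron 1 recovers essentially all of the clean output and the other two barely matter. Because the toy readout is linear, the three improvements also sum to 1; in a real transformer the patched activation passes through later non-linear layers, so per-neuron effects need not add up this neatly.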

Let's explore the middle MLP layers of the model (4, 5, 6 and 7)!

In [58]:
plot_patched_neurons(4)

100%|██████████| 3072/3072 [00:50<00:00, 60.88it/s]


The neuron with the most negative impact, neuron 2814, is mostly associated with family. Interestingly, the 20 text examples given by Neuroscope (https://neuroscope.io/gpt2-small/4/2814.html) refer to *family* in connection with a male subject, which may explain why patching this neuron hurts the prediction of *she*!

![title](img/4.png)

I don't want to give the impression that all these neurons are relevant for the task in an interpretable sense. The top positive neuron's top example on Neuroscope is:

![title](img/irrelevant.png)

In [59]:
plot_patched_neurons(5)

100%|██████████| 3072/3072 [00:50<00:00, 60.58it/s]


Whoa! In layer 5, neuron 1777 is an outlier. If we look it up on Neuroscope (https://neuroscope.io/gpt2-small/5/1777.html), we see that it is activated on male/female pronouns.

![title](img/5.png)
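Spotting outliers by eye works for one plot, but a simple rule makes it repeatable across layers. The helper below is a hypothetical sketch that flags neurons by z-score; the function name, the threshold of 4, and the toy scores are all assumptions, not values taken from the experiment.

```python
import statistics


def find_outlier_neurons(scores, z_threshold=4.0):
    """Return (index, score) pairs whose z-score magnitude exceeds the threshold."""
    mean = statistics.fmean(scores)
    std = statistics.pstdev(scores)
    if std == 0:
        return []
    return [
        (i, s) for i, s in enumerate(scores)
        if abs(s - mean) / std > z_threshold
    ]


# Made-up scores: 99 neurons near zero plus one large value at index 42.
toy_scores = [0.001 * ((i % 7) - 3) for i in range(100)]
toy_scores[42] = 0.3

print(find_outlier_neurons(toy_scores))  # [(42, 0.3)]
```

In the notebook, the same function could be called on `patched_neurons_normalized_improvement.tolist()` right after a layer has been processed, since that buffer is overwritten on every call to `plot_patched_neurons`.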

In [60]:
plot_patched_neurons(7)

100%|██████████| 3072/3072 [00:50<00:00, 61.20it/s]


If we look at neuron 3050 (https://neuroscope.io/gpt2-small/7/3050.html), we see that it is activated on female names and pronouns.

![title](img/7.png)

In short, it feels like a small number of neurons are responsible for the selection of grammatical gender, and some are focused on a specific gender.