|<h2>Substack post:</h2>|<h1><a href="https://mikexcohen.substack.com/p/king-man-woman-queen-is-fake-news" target="_blank">Gender bias in large language models, part 2 (correcting the bias)</a></h1>|
|-|:-:|
|<h2>Teacher:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the post may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn.functional as F

In [None]:
### Run this cell only if you're using "dark mode"

# svg plots (higher-res): inline figures are rendered as vector graphics
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Dark theme: keys are grouped by the value they receive (insertion order matches
# a flat listing of the same settings).
plt.rcParams.update({
    # near-black backgrounds
    **dict.fromkeys(['figure.facecolor', 'figure.edgecolor', 'axes.facecolor'], '#171717'),
    # pale foreground for spines, labels, ticks and text
    **dict.fromkeys(['axes.edgecolor', 'axes.labelcolor', 'xtick.color', 'ytick.color', 'text.color'], '#DDE2F4'),
    # hide the right/top spines
    **dict.fromkeys(['axes.spines.right', 'axes.spines.top'], False),
    # bold titles and axis labels
    **dict.fromkeys(['axes.titleweight', 'axes.labelweight'], 'bold'),
})

# Demo 1: Setup (repeat of Post 1)

In [None]:
from transformers import BertTokenizer, BertForMaskedLM

# Load the pretrained 'bert-large-uncased' tokenizer and masked-LM model (tokenizer first, then model)
tokenizer, model = (cls.from_pretrained('bert-large-uncased') for cls in (BertTokenizer, BertForMaskedLM))

# put the model in evaluation mode (last expression, so the model is echoed as cell output)
model.eval()

In [None]:
# the two pronouns whose masked-token predictions are compared
target_words = [ 'he','she' ]

# Tokenize the same sentence three ways: with "he", with "she", and with the pronoun
# replaced by the tokenizer's mask token (returned as PyTorch tensors).
tokens_he, tokens_she, tokens_mask = (
  tokenizer(sentence, return_tensors='pt')
  for sentence in (
    'The engineer informed the client that he would need more time.',
    'The engineer informed the client that she would need more time.',
    f'The engineer informed the client that {tokenizer.mask_token} would need more time.',
  )
)

In [None]:
# Position of the mask token in the masked sentence. `.item()` raises if the
# sentence contains zero or several mask tokens, so exactly one is required.
maskTarget_idx = torch.where(tokens_mask['input_ids'][0] == tokenizer.mask_token_id)[0].item()

def single_token_id(word):
  """Return the vocabulary id of `word`.

  The word is encoded without special tokens ([CLS]/[SEP]). A ValueError is raised
  if the tokenizer splits it into more than one piece: comparing the log-softmax of
  only the first piece would silently give a meaningless he/she comparison.
  """
  piece_ids = tokenizer.encode(word, add_special_tokens=False)
  if len(piece_ids) != 1:
    raise ValueError(f'"{word}" is not a single token: {tokenizer.convert_ids_to_tokens(piece_ids)}')
  return piece_ids[0]

# token indices of target words (same order as target_words)
targets_idx = [single_token_id(t) for t in target_words]

# Demo 2: Inspect hidden states

In [None]:
# Run the "he" sentence through the model with gradient tracking off,
# asking for the hidden states of every layer.
with torch.set_grad_enabled(False):
  out = model(
    **tokens_he,
    output_hidden_states=True,
  )

In [None]:
# report how many hidden-state tensors were returned and the shape of the first one
print(f'There are {len(out.hidden_states)} layers of hidden states',
      f'Each layer has shape {out.hidden_states[0].shape}',
      sep='\n')

In [None]:
layer = 13
# activation vector at the mask position (batch item 0) in the chosen hidden-state index
hs = out.hidden_states[layer][0,maskTarget_idx,:]

# the word occupying that position in the "he" sentence, for the title
word_at_mask = tokenizer.decode(tokens_he['input_ids'][0,maskTarget_idx])

fig, ax = plt.subplots(figsize=(10,3))
ax.plot(hs,'ko',markerfacecolor=[.7,.9,.7,.5],markersize=10,linewidth=.5)
ax.set(xlabel='Embeddings dimension',
       ylabel='Embeddings value',
       title=f"Embeddings for \"{word_at_mask}\" in layer {layer}")
plt.show()

In [None]:
n_hidden = model.config.num_hidden_layers

# Rows = hidden_states[1] .. hidden_states[n_hidden] (index 0 is skipped), each taken at
# the mask position of batch item 0. Cast to float64 to match a np.zeros buffer.
hss = torch.stack(
  [layer_states[0,maskTarget_idx,:] for layer_states in out.hidden_states[1:n_hidden+1]]
).cpu().numpy().astype(np.float64)

fig, ax = plt.subplots(figsize=(10,3))
img = ax.imshow(hss,aspect='auto',vmin=-1,vmax=1,cmap='plasma',origin='lower')

fig.colorbar(img, ax=ax, pad=.01)
ax.set(xlabel='Hidden state index',ylabel='Layer')
plt.show()

# Demo 3: Manipulate internal activations

In [None]:
# Globals read by the hook every time it fires (redefined in later cells)
layer2replace = 40000 # sentinel: no encoder layer has this index, so nothing is replaced
he_vector = torch.zeros(model.config.hidden_size)
she_vector = torch.zeros(model.config.hidden_size)

# proportion "he" vector in the implanted mixture (the remaining 1-p_he is "she")
p_he = .5

# 1) hook functions
def implant_hook(layer_number):
  """Build a forward hook for encoder layer `layer_number`.

  When the global `layer2replace` equals `layer_number`, the hook overwrites the
  layer output at position `maskTarget_idx` of batch item 0 with
  p_he*he_vector + (1-p_he)*she_vector and prints a note. Otherwise the output is
  returned untouched. The globals are looked up when the hook runs, not when it is built.
  """
  def hook(module, input, output):

    # 2) only change this layer if there's a matching variable value
    if layer_number != layer2replace:
      return output

    # 3) unpack: the layer output is normally a tuple (hidden, *extras); a bare tensor
    #    is also accepted, in case the installed transformers version returns one
    is_tuple = isinstance(output, tuple)
    hidden = output[0] if is_tuple else output

    # 4) overwrite (not blend with) the mask-position activation with the he/she mixture
    hidden[0,maskTarget_idx,:] = p_he*he_vector + (1-p_he)*she_vector
    print(f'Replaced layer {layer_number:2}')

    # 5) reconstruct output in the same form it arrived in
    return (hidden,) + tuple(output[1:]) if is_tuple else hidden
  return hook


# 6) loop over layers and do surgery.
# First detach hooks from a previous run of this cell. Otherwise re-running stacks a
# second hook on every layer, and the final "remove hooks" cell would remove only the newest set.
for old_handle in globals().get('handles', []):
  old_handle.remove()

handles = [model.bert.encoder.layer[layeri].register_forward_hook(implant_hook(layeri))
           for layeri in range(n_hidden)]

In [None]:
# disarm the hooks (no encoder layer has this index) so these runs are unmodified
layer2replace = 40000

# Forward-pass the "he", "she" and masked sentences (in that order), keeping every
# layer's hidden states so they can be implanted later.
with torch.no_grad():
  out_he, out_she, out_mask = (
    model(**toks, output_hidden_states=True) for toks in (tokens_he, tokens_she, tokens_mask)
  )

In [None]:
# Get the he/she activations at the mask position from one hidden state.
# The hook on encoder layer k pairs with hidden_states[k+1] (hidden_states[0] precedes layer 0).
layer2replace = 23
he_vector, she_vector = (
  run.hidden_states[layer2replace+1][0,maskTarget_idx,:] for run in (out_he, out_she)
)

# masked sentence again, now with the hook active on layer `layer2replace`
with torch.set_grad_enabled(False):
  out_mask_replace = model(**tokens_mask,output_hidden_states=True)

In [None]:
# log-softmax over the vocabulary at the mask position: unmodified vs. implanted run
logsm_orig = F.log_softmax(out_mask.logits[0,maskTarget_idx,:],dim=-1).detach()
logsm_repl = F.log_softmax(out_mask_replace.logits[0,maskTarget_idx,:],dim=-1).detach()

fig,axs = plt.subplots(1,2,figsize=(10,3.5))

# left panel shows log-softmax as-is; right panel exponentiates it back to probabilities
for ax, transform, ylabel, title in (
    (axs[0], lambda v: v, 'Log-softmax',   'Log-softmax for masked word'),
    (axs[1], torch.exp,   'Softmax prob.', 'Softmax probability for masked word'),
):
  # side-by-side bars: original shifted left, modified shifted right
  for offset, values, label in ((-.2, logsm_orig, 'Original'), (.2, logsm_repl, 'Modified')):
    ax.bar(np.arange(2)+offset, transform(values[targets_idx]), width=.5, label=label)
  ax.legend()
  ax.set(xticks=range(2),xticklabels=target_words,xlabel='Target words',ylabel=ylabel,title=title)

# the masked sentence (without [CLS]/[SEP]) as the figure title
fig.suptitle(tokenizer.decode(tokens_mask['input_ids'][0,1:-1]),fontweight='bold')

plt.tight_layout()
plt.show()

In [None]:
# bias = log p(he) - log p(she) at the mask position; positive values favor "he"
bias_orig, bias_repl = (
  logsm[targets_idx[0]] - logsm[targets_idx[1]] for logsm in (logsm_orig, logsm_repl)
)

print(f'Bias (he-she) in original model: {bias_orig:.3f}')
print(f'Bias (he-she) in modified model: {bias_repl:.3f}')

# Demo 4: Laminar profile of anti-bias impact

In [None]:
# redefine mixing
p_he = .5

# 1) one bias score per encoder layer at which the he/she mixture is implanted
bias_scores = torch.zeros(n_hidden)

for layer2replace in range(n_hidden):

  # 2) activations to implant: this layer's output at the mask position in the clean "she"/"he" runs
  she_vector = out_she.hidden_states[layer2replace+1][0,maskTarget_idx,:]
  he_vector  = out_he.hidden_states[layer2replace+1][0,maskTarget_idx,:]

  # 3) forward-pass the masked sentence; the hook fires on encoder layer `layer2replace`
  with torch.no_grad():
    out_mask_replace = model(**tokens_mask,output_hidden_states=True)

  # 4) log-softmax over the vocabulary at the mask position
  logsm_repl = F.log_softmax(out_mask_replace.logits[0,maskTarget_idx,:],dim=-1)

  # 5) bias towards "he": log p(he) - log p(she)
  bias_scores[layer2replace] = logsm_repl[targets_idx[0]] - logsm_repl[targets_idx[1]]

# 6) Disarm the hooks. The loop leaves layer2replace at the last layer index, so any later
#    forward pass would silently keep implanting until the hooks are removed.
layer2replace = 40000


In [None]:
# bias score as a function of the encoder layer where the he/she mixture was implanted
fig, ax = plt.subplots(figsize=(8,3))
ax.plot(bias_scores,'wh',markerfacecolor=[.7,.9,.7],markersize=12,linewidth=.5)
# dashed zero line = no preference between "he" and "she"
ax.axhline(0,linestyle='--',zorder=-3,color='gray')
ax.set(xlabel='Layer of replacement',ylabel='Bias score')
plt.show()

In [None]:
# remove hooks: detach every forward hook stored in `handles` so the model runs unmodified again
for handle in handles:
  handle.remove()