|<h2>Substack post:</h2>|<h1><a href="https://mikexcohen.substack.com/p/llm-breakdown-56-attention" target="_blank">LLM breakdown 5/6: Attention</a></h1>|
|-|:-:|
|<h2>Teacher:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the post may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# pytorch libraries
import torch
import torch.nn.functional as F

# huggingface LLM
from transformers import GPT2Tokenizer

In [None]:
### Run this cell only if you're using "dark mode"

# render inline figures as SVG for sharper plots
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# dark backgrounds, light foreground elements, no top/right spines, bold titles/labels
plt.rcParams.update({
  **{key: '#171717' for key in ('figure.facecolor','figure.edgecolor','axes.facecolor')},
  **{key: '#DDE2F4' for key in ('axes.edgecolor','axes.labelcolor','xtick.color',
                                'ytick.color','text.color')},
  **{key: False for key in ('axes.spines.right','axes.spines.top')},
  **{key: 'bold' for key in ('axes.titleweight','axes.labelweight')},
})

# Demo 1: Inspecting the attention adjustment vectors

In [None]:
# huggingface LLM
from transformers import AutoModelForCausalLM, GPT2Tokenizer

# every GPT2 size shares the same tokenizer, so the small 'gpt2' one is loaded here
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

# the large GPT2 variant whose internals are inspected below
model = AutoModelForCausalLM.from_pretrained('gpt2-large')

# evaluation mode: switches off training-only layers (e.g. dropout)
model.eval()

In [None]:
# model dimensions read from the config
n_layers, n_embd, n_heads = model.config.n_layer, model.config.n_embd, model.config.n_head

# width of one attention head (equal slice of the embedding); also used in demo 3
head_dim = n_embd // n_heads

# sqrt(d_k): the scale factor in the attention equation (integer promoted to a float tensor)
sqrtD = torch.tensor(head_dim).sqrt()

In [None]:
# hook functions to store QVK vectors

# 1) store filled by the hooks: key 'L{layer}_qvk' -> detached module output
activations = {}

# 2) factory that builds one forward hook per layer
def implant_hook(layer_number):
  """Return a forward hook that saves a module's output under 'L{layer_number}_qvk'.

  The hook detaches the output (no autograd graph is kept) and writes it into the
  module-level `activations` dict, overwriting any earlier entry for that layer.
  It returns None, so the hooked module's output is left unchanged.
  """
  key = f'L{layer_number}_qvk'

  def hook(module, input, output):
    # 3) stash the activations; returning nothing leaves the forward pass untouched
    activations[key] = output.detach()

  return hook

# 4) attach a hook to every transformer block's fused Q/K/V projection (c_attn),
#    keeping the handles so the hooks can be removed later
hookhandles = [
  model.transformer.h[layeri].attn.c_attn.register_forward_hook(implant_hook(layeri))
  for layeri in range(n_layers)
]

In [None]:
txt = """Be who you are and say what you feel,
      because those who mind don't matter and those who matter don't mind"""
tokens = tokenizer.encode(txt,return_tensors='pt')   # a batch of one sequence

n_tokens = len(tokens[0])

# Count unique tokens on Python ints (.tolist()): iterating a tensor yields 0-d tensors,
# which hash by object identity, so a set of them would report every token as unique.
print('The text contains:')
print(f'  {len(txt)} characters ({len(set(txt))} unique)')
print(f'  {n_tokens} tokens ({len(set(tokens[0].tolist()))} unique)')

In [None]:
# forward pass only to fire the hooks (the model output itself is not kept)
with torch.no_grad():
  model(tokens)

# which layers were captured, and the shape of one captured tensor
activations.keys(),activations['L5_qvk'].shape

In [None]:
# split the fused layer-5 projection into separate Q, K and V matrices (n_embd columns each)
q,k,v = activations['L5_qvk'].split(n_embd,dim=-1)
q.shape,k.shape,v.shape

In [None]:
# histograms
# Density histograms of the layer-5 Q, K and V values. The [0,1:,:] slice takes the
# single sequence and skips token position 0 (NOTE(review): presumably because the
# first token's activations are atypical -- confirm against the post).
y_q,x_q = torch.histogram(q[0,1:,:].flatten(),bins=100,density=True)
y_k,x_k = torch.histogram(k[0,1:,:].flatten(),bins=100,density=True)
y_v,x_v = torch.histogram(v[0,1:,:].flatten(),bins=100,density=True)

# torch.histogram returns (densities, bin edges); edges have one extra element, hence [:-1]
plt.figure(figsize=(9,3))
plt.plot(x_k[:-1],y_k,linewidth=2,label='K')
plt.plot(x_q[:-1],y_q,linewidth=2,label='Q')
plt.plot(x_v[:-1],y_v,linewidth=2,label='V')

plt.legend()
# x-limits follow the K histogram's bin range only
plt.gca().set(xlabel='Activation value',ylabel='Density',xlim=x_k[[0,-2]],
              title='Distribution of Layer 5 attention activations')

plt.show()

In [None]:
# shared histogram bins (100 bins spanning -5..5) so every layer is directly comparable
binEdges = np.linspace(-5,5,101)

# one row per transformer block, one column per histogram bin
Qhist,Khist,Vhist = (np.zeros((n_layers,len(binEdges)-1)) for _ in range(3))

# one row per transformer block; columns are Q, K, V
variances = np.zeros((n_layers,3))

for layeri in range(n_layers):

  # this layer's fused output split into Q, K, V
  # (after the loop these names still hold the last layer's matrices)
  q,k,v = torch.split(activations[f'L{layeri}_qvk'],n_embd,dim=-1)

  for col,(mat,hist) in enumerate(zip((q,k,v),(Qhist,Khist,Vhist))):
    vals = mat[0,1:,:]   # the single sequence, skipping token position 0
    hist[layeri,:] = np.histogram(vals.flatten().numpy(),bins=binEdges,density=True)[0]
    variances[layeri,col] = torch.var(vals).item()

In [None]:
fig,axs = plt.subplots(1,3,figsize=(10,4))

# (layer-by-bin density matrix, colour-scale maximum, panel title) for each panel
panels = ((Qhist,.4,'Q activations\n\n\n'),
          (Khist,.3,'K activations\n\n\n'),
          (Vhist,.7,'V activations\n\n\n'))

for ax,(hist,vmax,title) in zip(axs,panels):
  # extent gives the outer pixel edges: the last bin's right edge is binEdges[-1]
  # (using binEdges[-2] would squeeze the image by one bin width)
  h = ax.imshow(hist,aspect='auto',origin='lower',cmap='magma',vmin=0,vmax=vmax,
                extent=[binEdges[0],binEdges[-1],0,n_layers])
  ax.set(xlabel='Activation value',title=title)
  # colourbar above each panel (presumably the blank title lines leave room for it)
  ch = fig.colorbar(h,ax=ax,location='top',pad=.02)

axs[0].set(ylabel='Transformer block')

plt.tight_layout()
plt.show()

In [None]:
# how the spread of Q, K and V activations changes across transformer blocks
_,ax = plt.subplots(figsize=(8,3))
ax.plot(variances,'.-')   # one line per column: Q, K, V

ax.legend(['Q','K','V'])
ax.set(xlabel='Transformer block',ylabel='Variance',xlim=[0,n_layers],
       title='Variance of attention activations')

plt.tight_layout()
plt.show()

# Demo 2: Distribution of QK^T

In [None]:
# k still holds the last transformer block's keys (if the cells ran in order);
# dropping the batch index and first token, then transposing, gives the K^T used below
print(k.shape, k[0,1:,:].T.shape, sep='\n')

In [None]:
# scaled dot products between every (query, key) token pair, token 0 excluded;
# no causal mask is applied, so later tokens are included as keys too
qkt = (q[0,1:,:] @ k[0,1:,:].T) / sqrtD
qkt.shape

In [None]:
fig,axs = plt.subplots(1,2,figsize=(10,3))

# the matrix
# left panel: scaled QK^T as an image, colour range fixed to [-90, 0]
h = axs[0].imshow(qkt,cmap='magma',vmin=-90,vmax=0)
fig.colorbar(h,ax=axs[0],pad=.02)
axs[0].set(title=r'$QK^T \;/\; \sqrt{d_k}$',xlabel='Tokens',ylabel='Tokens')

# the matrix vectorized
# right panel: every entry in flattened order, coloured by its min-max scaled value
# (NOTE: the scaling divides by zero if all entries were equal)
qkt_f = qkt.flatten()
scaled = (qkt_f-qkt_f.min()) / (qkt_f.max()-qkt_f.min())
axs[1].scatter(range(len(qkt_f)),qkt_f,20,edgecolor='w',linewidth=.3,
               marker='s',c=mpl.cm.magma(scaled),alpha=.8)
axs[1].set(xlabel='Dot product index',ylabel='Dot product value',
           title='Distribution of $QK^T$',xlim=[0,len(qkt_f)])

plt.tight_layout()
plt.show()

In [None]:
_,axs = plt.subplots(1,2,figsize=(10,3))

# one colour per transformer block, shared by both panels and the colourbar
cmap = mpl.cm.plasma
norm = mpl.colors.Normalize(vmin=0,vmax=n_layers)

# per-layer mean and standard deviation of the scaled QK^T values
meenz = np.zeros(n_layers)
stdz = np.zeros(n_layers)

# loop over all the layers
for layeri in range(n_layers):

  # split this layer's fused output into Q, K, V
  q,k,v = torch.split(activations[f'L{layeri}_qvk'],n_embd,dim=-1)

  # scaled dot products (token 0 excluded, no causal mask)
  qkt = q[0,1:,:] @ k[0,1:,:].transpose(-2,-1)
  attn_acts = qkt / sqrtD

  # same colour mapping in both panels so lines match the colourbar
  color = cmap(norm(layeri))

  # distribution of "raw" values
  y,x = np.histogram(attn_acts.flatten(),20,density=True)
  axs[0].plot(x[:-1],y,color=color,label=f'Layer {layeri}')

  # distribution of softmax-prob values (softmax along the last axis, i.e. over keys)
  y,x = np.histogram(F.softmax(attn_acts,dim=-1).flatten(),20,density=True)
  axs[1].plot(x[:-1],y,color=color,label=f'Layer {layeri}')

  # store the descriptive characteristics to be plotted later (as plain floats)
  meenz[layeri] = attn_acts.mean().item()
  stdz[layeri] = attn_acts.std().item()


# plot adjustments
axs[0].set(xlabel='Activation value',ylabel='Density',title='Distribution of $QK^T$',ylim=[0,None])
axs[1].set(xlabel='Softmax probability',ylabel='log(density)',yscale='log',
           title='Distribution of $\\sigma(QK^T)$')

# create a colorbar
sm = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
cbar = plt.colorbar(sm,ax=axs[-1],pad=.02)
cbar.set_label('Transformer block')

plt.tight_layout()
plt.show()

In [None]:
_,ax = plt.subplots(1,1,figsize=(6,3))

# plot the means in green
# per-layer means of the scaled QK^T values (filled by the per-layer QK^T loop)
ax.plot(meenz,'s-',color=[.7,.9,.7],label='Mean')
ax.set(xlabel='Transformer block')
ax.set_ylabel('Activation mean',color=[.7,.9,.7])
ax.tick_params(axis='y',colors=[.7,.9,.7])


# and the standard deviations in lavender
# on a twin y-axis; its right spine is re-enabled explicitly because the dark-mode
# rcParams (if applied) hide right spines
axx = ax.twinx()
axx.plot(stdz,'o-',color=[.7,.7,.9],label='Stdev.')
axx.spines['right'].set_visible(True)
axx.set_ylabel('Activation stdev.',color=[.7,.7,.9])
axx.tick_params(axis='y',colors=[.7,.7,.9])

# get both legends
# merge the handles from both axes into one legend on the primary axis
lines, labels = ax.get_legend_handles_labels()
lines2, labels2 = axx.get_legend_handles_labels()
ax.legend(lines + lines2, labels + labels2)

plt.tight_layout()
plt.show()

# Demo 3: Impact of attention head lesion on token prediction

In [None]:
# layer-9 fused output for the single sequence, split into Q, K, V (each tokens x n_embd)
q,k,v = activations['L9_qvk'][0].split(n_embd,dim=1)

# carve Q's columns into consecutive head_dim-wide slices, one per attention head
q_h = q.split(head_dim,dim=1)

print(f'There are {len(q_h)} heads')
print(f'Each head has size {q_h[2].shape}')

In [None]:
# detach every Q/K/V-capturing hook so later forward passes run unmodified
for handle in hookhandles:
  handle.remove()

In [None]:
# baseline ("clean") forward pass with no hooks attached
with torch.no_grad():
  outputs_clean = model(tokens)

# one vector of next-token logits per input position
outputs_clean.logits.size()

In [None]:
# show the input text again; its final token is the prediction target analysed below
print(txt)

In [None]:
# convert to log-softmax
# logits at the penultimate position are the model's prediction for the final token
log_smax_clean = F.log_softmax(outputs_clean.logits[0,-2,:],dim=-1)

# every vocabulary entry as a dot; the actual final token highlighted as a red square
plt.figure(figsize=(10,3.5))
plt.plot(log_smax_clean,'ko',markerfacecolor=[.7,.7,.9,.4],markersize=6)
plt.plot(tokens[0,-1],log_smax_clean[tokens[0,-1]],'rs',zorder=-2,alpha=.7,
         label=f'Final token ("{tokenizer.decode(tokens[0,-1])}")')

plt.legend()
plt.gca().set(xlabel='Vocab index',ylabel='Log probability',xlim=[-20,len(log_smax_clean)+19],
              title='Log softmax for penultimate token')

plt.tight_layout()
plt.show()

In [None]:
# Now for the experiment :)
# In each transformer block in turn, zero one attention head's output at the penultimate
# token and record the resulting log-probability of the text's final token.

In [None]:
# 1) initialize results: the final token's log-probability after ablating each layer
ablation_logits = np.zeros(n_layers)

# 2) forward pre-hook that silences one head; defined once since it does not depend on the layer
def hook2ablate(module,input,head_idx=4,token_idx=-2):
  """Zero one attention head's output at one token position.

  Registered with ``register_forward_pre_hook`` on an attention block's output
  projection (``attn.c_proj``), so ``input`` is the tuple of positional inputs that
  projection receives. ``input[0]`` is expected to hold the concatenated head
  outputs with last dimension n_heads*head_dim (assumed (batch, tokens, n_embd)).

  Parameters
  ----------
  module : the hooked module (unused).
  input : tuple of the module's positional inputs.
  head_idx : head to silence (default 4, i.e. the 5th head).
  token_idx : token position to silence (default -2, the penultimate token, whose
      output is used to predict the final token).

  Returns
  -------
  tuple
      Replacement positional inputs, same length as ``input``.
  """
  merged = input[0]

  # 2a) expose the head axis: (..., tokens, n_embd) -> (..., tokens, n_heads, head_dim)
  head_tensor = merged.reshape(*merged.shape[:-1],n_heads,head_dim)

  # 2b) zero the chosen head at the chosen token
  head_tensor[...,token_idx,head_idx,:] = 0

  # 2c) restore the original layout
  head_tensor = head_tensor.reshape(merged.shape)

  # 2d) build the replacement inputs with a tuple literal: tuple() accepts a single
  #     iterable, so tuple(t, ...) raises with extra inputs and tuple(t) would split t
  #     along its batch axis
  return (head_tensor,*input[1:])


# 3) ablate the same head in each transformer block in turn
for layeri in range(n_layers):

  # 4) implant the hook into this layer's attention output projection
  layer2implant = model.transformer.h[layeri].attn.c_proj
  h = layer2implant.register_forward_pre_hook(hook2ablate)

  # 5) forward pass; 6) always remove the hook, even if the pass fails, so it
  #    cannot stay attached and silently alter later runs
  try:
    with torch.no_grad():
      outputs_ablated = model(tokens)
  finally:
    h.remove()

  # 7) log-probability of the actual final token, predicted from the penultimate position
  log_smax = F.log_softmax(outputs_ablated.logits[0,-2,:],dim=-1)
  ablation_logits[layeri] = log_smax[tokens[0,-1]]

In [None]:
plt.figure(figsize=(9,3))

# one hexagon per transformer block: white edge, face coloured by block depth
for i,logprob in enumerate(ablation_logits):
  plt.plot(i,logprob,'wh',markerfacecolor=mpl.cm.plasma(i/n_layers),markersize=12)

# reference line: the final token's log-probability without any ablation
plt.axhline(log_smax_clean[tokens[0,-1]],linestyle='--',color=[.7,.7,.7],zorder=-2,label='Clean log-prob')

plt.legend()
plt.gca().set(xlabel='Transformer block',ylabel='Log probability',
              title='Log softmax for penultimate token')

plt.tight_layout()
plt.show()

In [None]:
# FYI
# converts a log-probability back to a probability: exp(-0.161) ~= 0.85
# NOTE(review): -0.161 is hard-coded, presumably read off the plot above
np.exp(-.161)