|<h2>Substack post:</h2>|<h1><a href=" " target="_blank">Least squares part 4: modeling GPT activations</a></h1>|
|-|:-:|
|<h2>Teacher:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the post may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# pytorch libraries
import torch
import torch.nn.functional as F

# for running regression models
import statsmodels.api as sm

from transformers import AutoModelForCausalLM, GPT2Tokenizer

In [None]:
### Run this cell only if you're using "dark mode"

# render inline figures as SVG (vector graphics, so they stay sharp when zoomed)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

plt.rcParams.update({
    # backgrounds: dark gray
    **{ key: '#383838' for key in ('figure.facecolor','figure.edgecolor','axes.facecolor') },
    # foreground elements (axes frame, labels, ticks, text): light lavender
    **{ key: '#DDE2F4' for key in ('axes.edgecolor','axes.labelcolor',
                                   'xtick.color','ytick.color','text.color') },
    # hide the right/top spines and use bold titles and axis labels
    'axes.spines.right': False,
    'axes.spines.top':   False,
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
})

# Import and inspect the GPT2 model

In [None]:
# load the pretrained GPT2 tokenizer and model
# (the import is repeated here so this cell can also be run on its own)
from transformers import AutoModelForCausalLM, GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
gpt2 = AutoModelForCausalLM.from_pretrained('gpt2')

# "evaluation" mode disables training-related operations; eval() returns the
# model itself, so leaving it as the cell's last line displays the model
gpt2.eval()

In [None]:
# number of transformer blocks, read from the model's configuration
n_layers = gpt2.config.n_layer

# axis labels for the plots: 'L0', 'L1', ..., one per transformer block
layerNames = list(map('L{}'.format, range(n_layers)))

# Tokenize text and get attention adjustments

In [None]:
# tokenize some text
# source: https://en.wikipedia.org/wiki/Rock_music_in_Hungary

txt = 'Hungarian rock has been a part of the popular music of Hungary since the early 1960s. The first major bands were Illés, Metró and Omega. At the time, rock was not approved of by the Hungarian Communist authorities. In the 1970s, the Communists cracked down on rock, and Illés was banned from recording. Some members of the other bands formed a supergroup called Locomotiv GT, while the band Omega became very popular in Germany.'

# encode as a PyTorch tensor; tokens[0] is the (single) sequence of token ids
tokens = tokenizer.encode(txt,return_tensors='pt')
n_tokens = len(tokens[0])

# report the sizes, then list each token id next to the text it decodes to
print(f'The text contains {len(txt)} characters and {n_tokens} tokens.\n')
for token_id in tokens[0]:
  print(f'Token {token_id:6}: "{tokenizer.decode(token_id)}"')

In [None]:
# hook functions to store attention adjustment vectors

# 0) remove hooks left over from a previous run of this cell; without this,
#    every re-run registers another copy of the hook on every layer
for handle in globals().get('hookHandles',[]):
  handle.remove()
hookHandles = []

# 1) initialize an empty dictionary
adjustments = {}

# 2) an "outer" function that creates a hook function. Passing layer_number
#    as an argument binds it separately for each layer (a closure over the
#    loop variable would only ever see its final value).
def implant_hook(layer_number):
  def hook(module, inputs, output):

    # 3) grab the attention adjustments (first element of the attention
    #    module's output), detach from autograd, and store as a numpy array
    adjustments[f'L{layer_number}_attn'] = output[0].detach().numpy()

  return hook

# 4) implant hooks into all layers, keeping the handles so they can be removed
for layeri in range(n_layers):
  attnModule = gpt2.transformer.h[layeri].attn
  hookHandles.append( attnModule.register_forward_hook(implant_hook(layeri)) )

In [None]:
# push the tokens through the model; the hooks fill `adjustments` during this pass.
# no_grad() skips building the autograd graph, which is not needed because
# nothing is trained here (and the hooks detach their outputs anyway)
with torch.no_grad():
  gpt2(tokens)

# check the key names and adjustments sizes
print(adjustments.keys(),'\n')
print(adjustments['L0_attn'].shape)

In [None]:
# visualize the distribution of attention-adjustment norms in each layer
_,axs = plt.subplots(1,2,figsize=(12,4))

# 1) loop over the layers
for layeri in range(n_layers):
  layerColor = mpl.cm.plasma(layeri/n_layers)

  # 2) extract the per-token norms (token 0 is excluded by the 1: index) and log-transform
  norms = np.linalg.norm(adjustments[f'L{layeri}_attn'][0,1:,:],axis=-1)
  norms = np.log(norms)

  # 3) show all norms, jittered horizontally around the layer index
  axs[0].plot(np.random.normal(layeri,.08,len(norms)),norms,'wh',
              markerfacecolor=layerColor,
              markersize=10,alpha=.4,linewidth=.3)

  # 4) get and show histograms, with counts plotted at the bin centers.
  #    Named counts/binEdges (not y/x) so re-running this cell cannot
  #    overwrite the regression's dependent variable `y`.
  counts,binEdges = np.histogram(norms,bins='fd')
  axs[1].plot((binEdges[:-1]+binEdges[1:])/2,counts,'s-',color=layerColor)


# adjust the axes
axs[0].set(xticks=range(n_layers),xticklabels=layerNames,title='Attention norms',
           xlabel='Transformer block',ylabel='Attention norm (log)')
axs[1].set(xlabel='Attention norm (log)',ylabel='Count',title='Attention norm distributions')

plt.tight_layout()
plt.show()

# Build and fit a model for one transformer block

In [None]:
# the model (autoregression on the log-norms of layer 0's attention adjustments)
# y[t] = b0 + b1*y[t-1] + b2*y[t-2]

# 1) get the norms
norms = np.linalg.norm(adjustments['L0_attn'][0,:,:],axis=-1)
norms = np.log(norms)

# 2) dependent variable: norms for tokens t = 3, ..., n_tokens-1.
#    Starting at t=3 means the lags reach back only to token 1, so token 0
#    never enters the model. astype(float) gives a float64 copy.
y = norms[3:].astype(float)

# 3) design matrix: one row per t, columns [intercept, y[t-1], y[t-2]]
designMat = np.column_stack(( np.ones(len(y)), norms[2:-1], norms[1:-2] ))

# check sizes
print(f'The design matrix has shape {designMat.shape}')
print(f'The dependent variable has shape {y.shape}')

In [None]:
# fit the regression by ordinary least squares (statsmodels) and print the full report
regmodel = sm.OLS(endog=y, exog=designMat).fit()
print(regmodel.summary())

In [None]:
# extracting parameters from the fitted model
print('.params = ',regmodel.params)
print('.pvalues = ',regmodel.pvalues)
print('')

# one line per coefficient, in design-matrix order: intercept, t-1, t-2.
# The loop variable is `pval` (singular) so re-running this cell cannot
# clobber the per-layer `pvals` array used by the plot cell.
for beta,pval in zip(regmodel.params,regmodel.pvalues):
  print(f'beta = {beta:7.4f}, p = {pval:5.3f}')

In [None]:
# confirm that numpy gives the same beta values as sm.OLS
# (lstsq returns (solution, residuals, rank, singular values); keep the solution)
betas_np = np.linalg.lstsq(designMat,y,rcond=None)[0]

print('\n'.join(f'beta = {coef:7.4f}' for coef in betas_np))

# Run the regression for all layers

In [None]:
# initialize: one row per layer; columns hold the t-1 and t-2 coefficients
betas = np.zeros((n_layers,2))
pvals = np.zeros((n_layers,2))

for layeri in range(n_layers):

  # 1) get the log-norms of this layer's attention adjustments for every token
  norms = np.linalg.norm(adjustments[f'L{layeri}_attn'][0,:,:],axis=-1)
  norms = np.log(norms)

  # 2) create this layer's design matrix [intercept, y[t-1], y[t-2]] for
  #    t = 3, ..., n_tokens-1. It is built fresh here, so this cell neither
  #    depends on nor overwrites the single-layer `designMat`.
  layerDesign = np.column_stack(( np.ones(len(norms)-3), norms[2:-1], norms[1:-2] ))

  # 3) fit the model (separate name, so the single-layer `regmodel` is kept)
  layerModel = sm.OLS(norms[3:],layerDesign).fit()

  # 4) extract the parameters and p-values, skipping the intercept
  betas[layeri,:] = layerModel.params[1:]
  pvals[layeri,:] = layerModel.pvalues[1:]


In [None]:
# plot the results!
plt.figure(figsize=(10,4))

sigThresh = .05                  # significance threshold
layerIdx  = np.arange(n_layers)  # x-position of each layer

# one pass per lag: (column in betas/pvals, legend label, color, marker for significant layers)
for lagi,label,color,marker in [ (0,'t-1',[.9,.7,.7],'s'),
                                 (1,'t-2',[.7,.7,.9],'o') ]:

  # line plot of this lag's betas across layers
  plt.plot(betas[:,lagi],color=color,linewidth=.6,label=label)

  # indicate the significant (p<.05) layers with large colored markers...
  isSig = pvals[:,lagi] < sigThresh
  plt.plot(layerIdx[isSig],betas[isSig,lagi],marker,color=color,markersize=12)

  # ...and every other layer with a small white x. Using the complement mask
  # guarantees each layer gets a marker, including one with p exactly .05.
  plt.plot(layerIdx[~isSig],betas[~isSig,lagi],'wx',markersize=6)

# beautify the plot
plt.legend(fontsize=14)
plt.gca().set(xticks=range(n_layers),xticklabels=layerNames,
              xlabel='Transformer block',ylabel='Beta',
              title='Predicting token update from previous token updates')
plt.grid(linewidth=.1,linestyle='--')
plt.axhline(0,color='gray')

plt.show()