|<h2>Substack post:</h2>|<h1><a href="https://mikexcohen.substack.com/p/llm-breakdown-66-mlp" target="_blank">LLM breakdown 6/6: MLP</a></h1>|
|-|:-:|
|<h2>Teacher:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the post may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# pytorch libraries
import torch
import torch.nn.functional as F

# huggingface LLM
from transformers import AutoModelForCausalLM, GPT2Tokenizer

In [None]:
### Run this cell only if you're using "dark mode"

# render inline figures as SVG (vector output, sharper than raster)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# dark background everywhere, light color for all foreground elements
dark_theme = dict.fromkeys(('figure.facecolor','figure.edgecolor','axes.facecolor'),'#171717')
dark_theme.update(dict.fromkeys(('axes.edgecolor','axes.labelcolor',
                                 'xtick.color','ytick.color','text.color'),'#DDE2F4'))

# hide the right/top spines; bold titles and axis labels
dark_theme.update({
    'axes.spines.right': False,
    'axes.spines.top':   False,
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
})

plt.rcParams.update(dark_theme)

# Demo 0: Linear separation after dimensionality expansion

In [None]:
# angles: n evenly spaced points around the circle. endpoint=False leaves out
# 2*pi itself, so the last point does not (nearly) duplicate the first one and
# every gap between neighboring points is exactly 2*pi/n.
n = 100
theta = np.linspace(0,2*np.pi,n,endpoint=False)

# coordinates in 2D: two noisy concentric rings (radius 1 and radius 2)
x_inner = 1*np.cos(theta) + np.random.randn(n)/10
y_inner = 1*np.sin(theta) + np.random.randn(n)/10
x_outer = 2*np.cos(theta) + np.random.randn(n)/10
y_outer = 2*np.sin(theta) + np.random.randn(n)/10

# dimensionality-expansion via nonlinear transform: each point's distance from
# the origin becomes a third coordinate, along which the two rings separate
z_inner = np.sqrt(x_inner**2 + y_inner**2)
z_outer = np.sqrt(x_outer**2 + y_outer**2)



# shared marker styling: green circles = inner ring, red squares = outer ring
inner_style = dict(markerfacecolor=[.7,.9,.7],markersize=9)
outer_style = dict(markerfacecolor=[.9,.7,.7],markersize=9)

fig = plt.figure(figsize=(12,5))

### left panel: the rings in their original 2D coordinates
ax0 = fig.add_subplot(1,2,1)
ax0.plot(x_inner,y_inner,'ko',**inner_style)
ax0.plot(x_outer,y_outer,'ks',**outer_style)
ax0.axis('square')
ax0.set(title='Non-linearly separable in 2D',xlabel='x',ylabel='y',
        xticklabels=[],yticklabels=[])

### right panel: the same points lifted into 3D by their radius
ax1 = fig.add_subplot(1,2,2,projection='3d')
ax1.plot(x_inner,y_inner,z_inner,'ko',**inner_style)
ax1.plot(x_outer,y_outer,z_outer,'ks',**outer_style)
ax1.set(title='Linearly separable in 3D',xlabel='x',ylabel='y',zlabel='Radius',
        xticklabels=[],yticklabels=[])
ax1.view_init(elev=20,azim=20)

plt.tight_layout()
plt.show()

# Demo 0.5: GELU activation

In [None]:
# simulated "activations": 4000 draws from a standard normal distribution...
x = torch.randn(4000)
# ...and the same values after passing through the GELU nonlinearity
x_gelu = F.gelu(input=x)


# left panel: input vs. output of GELU; right panel: distributions of both
_,axs = plt.subplots(1,2,figsize=(9,3.3))

# the data (identity line) and its GELU-transformed version
axs[0].plot(x,x,'ks',markerfacecolor=[.7,.9,.7,.5],label='Data')
axs[0].plot(x,x_gelu,'ko',markerfacecolor=[.7,.7,.9,.5],label='GELU')
axs[0].legend()
axs[0].set(title='Impact of GELU activation',xlabel='Input',ylabel='Output')

# the histograms. Store results in `counts`/`edges` rather than `y`/`x`:
# rebinding `x` would replace the data tensor with the 71 bin edges, so
# re-running this cell on its own would plot the wrong data and then fail on
# the shape mismatch in plot(x, x_gelu).
binbounds = np.linspace(-3,3,71)
counts,edges = np.histogram(x,bins=binbounds)
axs[1].plot(edges[:-1],counts,label='Data')

counts,edges = np.histogram(x_gelu,bins=binbounds)
axs[1].plot(edges[:-1],counts,label='GELU')

axs[1].legend()
axs[1].set(title='Distributions',xlabel='Activation value',ylabel='Count',xlim=binbounds[[0,-1]])

plt.tight_layout()
plt.show()

# Demo 1: Inspecting the hidden states

In [None]:
# GPT2 model and its tokenizer, both loaded from one checkpoint name so the
# tokenizer cannot drift out of sync with the model if the name is changed
model_name = 'gpt2-medium'
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)

# toggle model into "evaluation" mode (disable training-related operations)
model.eval()

In [None]:
# size variables used later
n_layers = model.config.n_layer
n_embd = model.config.n_embd

In [None]:
# captured MLP activations, keyed '<stage>_L<layer>' (stages 1-4, see below)
acts = {}

def make_mlp_hook(layer_idx):
  """Build a forward hook that records four stages of one MLP into `acts`.

  Parameters
  ----------
  layer_idx : int
    Layer number used in the stored keys.

  Returns
  -------
  callable
    A hook `hook(module, inputs, output)` for `register_forward_hook`. When
    it fires, it stores detached tensors under:
      '1_L{layer_idx}'  input to the MLP
      '2_L{layer_idx}'  expansion through `module.c_fc` (pre-activation)
      '3_L{layer_idx}'  expansion after the module's activation function
      '4_L{layer_idx}'  MLP output (contraction)
  """
  def hook(module,inputs,output):
    mlp_in = inputs[0].detach()

    # 1) input into MLP
    acts[f'1_L{layer_idx}'] = mlp_in

    # 2) expansion (pre-gelu). Only the MLP's input and output reach this
    #    hook, so the intermediate stage is recomputed from the input.
    with torch.no_grad():
      pre = module.c_fc(mlp_in)
    acts[f'2_L{layer_idx}'] = pre

    # 3) expansion (post-gelu). Use the module's own activation (`act`) so the
    #    values match what the model actually computes. GPT-2's config uses
    #    the tanh-approximate GELU, whereas plain F.gelu defaults to the exact
    #    erf form. If the module has no `act`, fall back to tanh-GELU.
    act_fn = getattr(module,'act',None)
    if act_fn is None:
      acts[f'3_L{layer_idx}'] = F.gelu(pre,approximate='tanh')
    else:
      acts[f'3_L{layer_idx}'] = act_fn(pre)

    # 4) contraction (output of MLP)
    acts[f'4_L{layer_idx}'] = output.detach()
  return hook

# 5) register one hook per layer on that block's MLP, keeping every handle
#    so the hooks can be detached after the forward pass
hookHandles = [
  model.transformer.h[layer].mlp.register_forward_hook(make_mlp_hook(layer))
  for layer in range(n_layers)
]

In [None]:
# two sentences to process in parallel, as one padded batch
sentences = [ "If the sun is round, then the moon is round.",
              "If a square is square, then why isn't a triangle square?" ]

# padding requires a pad token; reuse the end-of-sequence token for it
tokenizer.pad_token = tokenizer.eos_token

# tokenize both sentences into PyTorch tensors, padded to equal length
tokens = tokenizer(sentences,return_tensors='pt',padding=True)

# inspect the token ids and attention mask of each sentence
for sentence_idx,header in enumerate(('*First sentence:','\n*Second sentence:')):
  print(header)
  print('  input_ids:',tokens['input_ids'][sentence_idx].tolist())
  print('  attention_mask:',tokens['attention_mask'][sentence_idx].tolist())

In [None]:
# positions to analyze in each sentence: every non-padding position (attention
# mask == 1) except the first one.
# NOTE(review): the reason for dropping the first position isn't stated here —
# confirm against the accompanying post.
tokens2use = [ torch.where(tokens['attention_mask'][row])[0].tolist()[1:] for row in (0,1) ]
tox1,tox2 = tokens2use
tokens2use

In [None]:
# forward pass and remove hooks
with torch.no_grad(): outputs = model(**tokens)
for h in hookHandles: h.remove()

# check the activations
acts.keys()

In [None]:
# check the sizes
print(f'1_L0 shape:',acts['1_L0'].shape)
print(f'2_L0 shape:',acts['2_L0'].shape)
print(f'3_L0 shape:',acts['3_L0'].shape)
print(f'4_L0 shape:',acts['4_L0'].shape)

In [None]:
# labels for the four data parts
labels = [ 'Input to MLP','Expanded pre-gelu','Expanded post-gelu','Contraction' ]

# create a figure and histogram bins shared by all four curves
plt.figure(figsize=(9,3))
binedges = np.linspace(-3,2,201)

# one histogram curve per MLP stage, all from layer 10
for stage,label in enumerate(labels,start=1):
  layer_acts = acts[f'{stage}_L10']

  # pool the selected token positions of both sentences into one flat vector
  allacts = np.concatenate(
      [ layer_acts[s,tokens2use[s],:].flatten() for s in (0,1) ], axis=0 )

  # counts per bin, drawn at each bin's left edge
  y,x = np.histogram(allacts,bins=binedges)
  plt.plot(x[:-1],y,label=label,linewidth=2)

plt.legend()
# x-limit ends at binedges[-2]: the last plotted point is the final bin's left edge
plt.gca().set(xlabel='Activation values',ylabel='Count',xlim=binedges[[0,-2]],
              title='Distribution of MLP activations from one layer',ylim=[0,2000])

plt.tight_layout()
plt.show()

In [None]:
# intialize output tensor
allhistograms = np.zeros((n_layers,4,len(binedges)-1))

# loop over layers
for layeri in range(n_layers):

  # loop over the MLP parts
  for i in range(1,5):

    # gather vectorized activations for the useable tokens
    allacts = np.concatenate(
        (acts[f'{i}_L{layeri}'][0,tokens2use[0],:].flatten(),
         acts[f'{i}_L{layeri}'][1,tokens2use[1],:].flatten()),axis=0 )

    # extract the histogram and store
    allhistograms[layeri,i-1,:] = np.histogram(allacts,bins=binedges)[0]

# check the size
allhistograms.shape

In [None]:
fig,axs = plt.subplots(1,4,figsize=(12,3.5))

# one image per MLP stage: rows are layers, columns are histogram bins
per_stage = allhistograms.transpose(1,0,2)
for ax,stage_hist,label in zip(axs,per_stage,labels):

  h = ax.imshow(stage_hist,aspect='auto',cmap='magma',origin='lower',
                extent=[binedges[0],binedges[-1],0,n_layers],vmin=0,vmax=1500)
  # blank lines in the title leave room for the colorbar placed above the axes
  ax.set(title=f'{label}\n\n',xlabel='Activation values',ylabel='Layer',
         ylim=[0,n_layers])

  ch = fig.colorbar(h,ax=ax,location='top',pad=.02)
  ch.ax.tick_params(labelsize=7)

plt.tight_layout()
plt.show()

# Demo 2: Imputing high-activation MLP projections

In [None]:
text = 'It was a dark and stormy'

# the continuation whose log-probability is tracked. It must encode to exactly
# one token; otherwise only its first piece would be measured, silently.
target_ids = tokenizer.encode(' night')
if len(target_ids)!=1:
  raise ValueError(f'Target word must encode to a single token, got ids {target_ids}')
target_idx = target_ids[0]

tokens = tokenizer.encode(text,return_tensors='pt')
tokens.shape, tokens, target_idx

In [None]:
# 1) initialize results vector: entry 0 = clean run, entry k = hook in block k-1
target_logprobs = np.zeros(n_layers+1)

# 2) proportion of units to manipulate
pct_ablation = .08

# 3) define the hook once; it does not depend on the layer, so there is no
#    reason to re-create it on every loop iteration
def replace_hook(module, inputs, output):
  """Mean-impute the largest values of the last position's output vector.

  Only batch item 0 at the final sequence position is modified, in place.
  Its int(pct_ablation*n_embd) largest values are replaced by the mean of the
  whole vector; the mean is computed before the replacement. Returns the
  modified output so the forward pass continues with it.
  """
  # 3a) get the indices of the top p%
  idx = torch.topk(output[0,-1,:],int(pct_ablation*n_embd)).indices

  # 3b) replace with the mean
  output[0,-1,idx] = torch.mean(output[0,-1,:])

  return output

# 4) loop over layers; layeri == -1 is the clean (un-hooked) baseline
for layeri in range(-1,n_layers):

  # 5) implant the hook on this block's MLP output projection
  handle = None
  if layeri>-1:
    handle = model.transformer.h[layeri].mlp.c_proj.register_forward_hook(replace_hook)

  # 6) forward pass to get output logits. The hook is removed in `finally`,
  #    so a failed forward pass cannot leave an ablation hook attached that
  #    would silently alter every later run of the model.
  try:
    with torch.no_grad(): out=model(tokens)
  finally:
    if handle is not None: handle.remove()

  # 7) get log-prob of target at the final position
  target_logprobs[layeri+1] = F.log_softmax(out.logits[0,-1,:].detach(),dim=-1)[target_idx]

In [None]:
plt.figure(figsize=(8,3))

# dashed reference line (with label) for the clean, un-ablated run
clean_logprob = target_logprobs[0]
plt.axhline(clean_logprob,linestyle='--',color=[.7,.7,.7])
plt.text(0,clean_logprob+.003,'Clean',va='bottom',fontsize=12,color=[.7,.7,.7])

# one marker per transformer block whose MLP projection was mean-imputed
ablated_logprobs = target_logprobs[1:]
plt.plot(ablated_logprobs,'kh',markerfacecolor=[.9,.7,.7],markersize=12)

# axis labels and title
plt.gca().set(xlabel='Transformer block',ylabel='Log probability',
              title=f'Impact of mean-imputing the top {100*pct_ablation:.1f}% of MLP neurons')

plt.tight_layout()
plt.show()