|<h2>Substack post:</h2>|<h1><a href="https://mikexcohen.substack.com/p/llm-breakdown-46-transformer-outputs" target="_blank">LLM breakdown 4/6: Transformer outputs (hidden states)</a></h1>|
|-|:-:|
|<h2>Teacher:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the post may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# pytorch libraries
import torch
import torch.nn.functional as F

# huggingface LLM
from transformers import GPT2Tokenizer

In [None]:
### Run this cell only if you're using "dark mode"

# render figures as SVG (higher resolution than the default raster output)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# Dark-theme plot defaults, grouped by the value they share.
# ('#020617' is an alternative background color.)
plt.rcParams.update({
    # background of the figure and of the axes
    **dict.fromkeys(['figure.facecolor','figure.edgecolor','axes.facecolor'], '#383838'),
    # light foreground for spines, labels, ticks, and text
    **dict.fromkeys(['axes.edgecolor','axes.labelcolor','xtick.color','ytick.color','text.color'], '#DDE2F4'),
    # hide the right and top spines
    **dict.fromkeys(['axes.spines.right','axes.spines.top'], False),
    # bold titles and axis labels
    **dict.fromkeys(['axes.titleweight','axes.labelweight'], 'bold'),
})

# Demo 1: Inspecting the hidden states

In [None]:
# huggingface LLM
from transformers import AutoModelForCausalLM, GPT2Tokenizer

# load the pretrained GPT2 tokenizer and the matching causal language model
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

# switch the model into "evaluation" mode (disable training-related operations);
# eval() returns the model, so this last line also displays its architecture
model.eval()

In [None]:
# display the model's configuration (e.g., n_embd, which is read later as the hidden-state dimensionality)
model.config

In [None]:
# tokenize a short sentence into pytorch tensors
txt = 'A wise man once said: Penguins are cute.'
tokens = tokenizer(txt,return_tensors='pt')

# sequence length (dimension 1 of the batch-by-sequence 'input_ids' matrix)
num_tokens = tokens['input_ids'].shape[1]

# show every field that the tokenizer returned
for name,value in tokens.items():
  print(f'"{name}" contains:\n  {value}\n')

In [None]:
# run the model without tracking gradients, requesting the hidden states
with torch.no_grad():
  outputs = model(**tokens,output_hidden_states=True)

# report the contents and sizes of the model outputs
for label,value in (
    ('Keys in "outputs":\n  ',          outputs.keys()),
    ('\nSize of outputs.logits:\n  ',   outputs.logits.shape),
    ('\nNumber of hidden states:\n  ',  len(outputs.hidden_states)),
    ('\nSize of each hidden state:\n  ',outputs.hidden_states[0].shape),
  ):
  print(label,value)

In [None]:
# shorthand for the sequence of hidden states (indexed by layer in the cells below)
hs = outputs.hidden_states

# number of hidden states, and the embedding dimensionality from the model config
num_hidden, hidden_dim = len(hs), model.config.n_embd

In [None]:
# all embeddings from one token
whichToken = 8

# Decode the token once for both panel titles. Doing this outside the f-strings
# also avoids nesting single quotes inside a single-quoted f-string, which is a
# SyntaxError in Python versions before 3.12.
token_label = tokenizer.decode(tokens['input_ids'][0,whichToken])

# setup the figure
_,axs = plt.subplots(1,2,figsize=(12,4))

# loop over layers
for layeri in range(num_hidden):

  # extract the activations from this layer and this token
  acts = hs[layeri][0,whichToken,:]

  # one color per layer, shared by both panels
  layer_color = mpl.cm.plasma((layeri+1)/num_hidden)

  # plot all the activations (small random x-axis jitter to reduce overplotting)
  axs[0].plot(np.random.normal(layeri,.05,hidden_dim),acts,'wo',markersize=8,
           markerfacecolor=layer_color,alpha=.4)

  # plot the variance of the activations
  axs[1].plot(layeri,acts.var(),'ws',markersize=12,
           markerfacecolor=layer_color)

# names of the layers, for the x-axis tick labels
# (hidden state 0 is the embeddings; the rest are the transformer blocks)
layer_labels = ['Emb'] + [f'L{i}' for i in range(num_hidden-1)]

# adjust the axes
axs[0].set(xticks=range(num_hidden),xticklabels=layer_labels,xlabel='Hidden layer (model depth)',ylabel='Activation value',
              title=f'Hidden state activations for token "{token_label}"')

axs[1].set(xticks=range(num_hidden),xticklabels=layer_labels,xlabel='Hidden layer (model depth)',ylabel='Activation variance',
              title=f'Activation variances for token "{token_label}"')

plt.tight_layout()
plt.show()

# Demo 2: Cosine similarities within and across layers

In [None]:
# pick 4 evenly spaced tokens including the first and final
tokens2analyze = np.linspace(0,len(tokens['input_ids'][0])-1,4,dtype=int)

fig,axs = plt.subplots(1,4,layout='constrained',figsize=(12,3))

# loop over selected tokens (toki = panel index, token_pos = position in the sequence)
for toki,token_pos in enumerate(tokens2analyze):

  # Stack this token's hidden-state activations from all layers into a
  # layers-by-dimensions matrix. The activations must be indexed by the
  # selected token position (not by the panel index) to match the title.
  all_hiddens = torch.stack([ hs[layeri][0,token_pos,:] for layeri in range(num_hidden) ])

  # and calculate the cosine similarity matrix on all pairs of layers
  cos_sim = F.cosine_similarity(all_hiddens.unsqueeze(0),all_hiddens.unsqueeze(1),dim=-1)

  # show the matrix
  h = axs[toki].imshow(cos_sim,cmap='plasma',vmin=.8,vmax=1,origin='lower')
  axs[toki].set(xticks=range(0,num_hidden,3),yticks=range(1,num_hidden,3),
                title=f'CS matrix for "{tokenizer.decode(tokens["input_ids"][0,token_pos])}"')

# adjustments
axs[0].set(xlabel='Hidden layer (model depth)',ylabel='Hidden layer (model depth)')
fig.colorbar(h,ax=axs[-1],label='Cosine similarity',pad=.02,shrink=.97)

plt.show()

In [None]:
# convert tokens into a list for axis labeling
toks_list = [tokenizer.decode(tokens['input_ids'][0,i]) for i in range(num_tokens)]

# 4 evenly spaced layers
layers2analyze = np.linspace(0,num_hidden-1,4,dtype=int)

fig,axs = plt.subplots(1,4, layout='constrained',figsize=(12,3))

# loop over the selected layers (axi = panel index, layer_idx = index into hs)
for axi,layer_idx in enumerate(layers2analyze):

  # Activations of all tokens in the selected layer. The hidden states must be
  # indexed by the selected layer (not by the panel index) to match the title.
  layer_acts = hs[layer_idx][0,:,:]

  # cosine similarity matrix over all token pairs for this layer
  cos_sim = F.cosine_similarity(layer_acts.unsqueeze(0),layer_acts.unsqueeze(1),dim=-1)

  # show the matrix
  h = axs[axi].imshow(cos_sim,cmap='plasma',vmin=.5,vmax=1,origin='lower')
  axs[axi].set(xticks=range(num_tokens),yticks=range(num_tokens),yticklabels=toks_list,
                title=f'CS matrix for layer {layer_idx}')
  axs[axi].set_xticklabels(toks_list,rotation=90)

fig.colorbar(h,ax=axs[-1],label='Cosine similarity',pad=.02,shrink=.91)

plt.show()

In [None]:
plt.figure(figsize=(10,4))

# Row/column indices of the strict upper triangle, so that each unordered token
# pair is extracted exactly once. (Filling the lower triangle with zeros and
# then dropping the smallest unique value would discard a real data point
# whenever a similarity is negative, and would merge identical values.)
num_compared = hs[0].shape[1]-1 # all tokens except the first
rows,cols = torch.triu_indices(num_compared,num_compared,offset=1)

for layeri in range(num_hidden):

  # similarities across all tokens, excluding the first
  acts = hs[layeri][0,1:,:]
  cos_sim = F.cosine_similarity(acts.unsqueeze(0),acts.unsqueeze(1),dim=-1)
  unique_sim = cos_sim[rows,cols]

  # and plot all the dots (small random x-axis jitter to reduce overplotting)
  plt.plot(np.random.normal(layeri,.05,len(unique_sim)),unique_sim,'wo',markersize=8,
           markerfacecolor=mpl.cm.plasma((layeri+1)/num_hidden),alpha=.4)

# adjust the axis properties (layer_labels is defined in the Demo 1 plotting cell)
plt.gca().set(xticks=range(num_hidden),xticklabels=layer_labels,
              xlabel='Hidden layer (model depth)',ylabel='Cosine similarity',
              title='Laminar profile of inter-token cosine similarities')

plt.tight_layout()
plt.savefig('cosine_similarities.png',dpi=300,transparent=True)
plt.show()

# Demo 3: Manipulating hidden states

In [None]:
txt = 'As Gregor Samsa awoke one morning from uneasy dreams, he found himself transformed in his bed into a gigantic' # next word is "insect"

# encode the text into a 1-by-N tensor of token indices
tokens = tokenizer.encode(txt,return_tensors='pt')

print('The text contains:')
print(f'  {len(txt)} characters ({len(set(txt))} unique)')

# Convert to a list of Python ints before building the set: tensor elements
# hash by object identity, so a set of them never merges repeated tokens.
print(f'  {len(tokens[0])} tokens ({len(set(tokens[0].tolist()))} unique)')

In [None]:
# "clean" forward pass
with torch.no_grad():
  outputs = model(tokens)

# indices of the 21 largest logits at the final sequence position
indices = torch.topk(outputs.logits[0,-1,:],21).indices

print('Top 21 possible next words:')
for c,t in enumerate(indices):
  # separate the words with commas, and break the line after every 7th word
  separator = ',   \n' if c%7==6 else ',   '
  print(f'"{tokenizer.decode(t)}"',end=separator)

In [None]:
# index of the target token (the first token in the encoding of " insect")
target_token_idx = tokenizer.encode(' insect')[0]

# log-softmax over the vocabulary at the final sequence position
log_sm_logits = outputs.logits[0,-1,:].log_softmax(dim=-1)

# the target's value in the unmanipulated ("clean") model; shown as the cell output
target_logsm_clean = log_sm_logits[target_token_idx]
target_logsm_clean

In [None]:
# 1) initialize
target_logsm = np.zeros(num_hidden-1)

# 2) create a hook function (defined once; it is identical for every layer)
def hookfun(module,input,output):
  """Forward hook that scales a transformer block's output hidden state by .8.

  Args:
    module: the block the hook is attached to (unused).
    input: the inputs to the block's forward pass (unused).
    output: the block's output; either a tuple whose first element is the
      hidden state, or the hidden-state tensor itself.

  Returns:
    The output in its original form, with the hidden state scaled in-place.
  """
  # Guard against a bare tensor: tuple-unpacking a tensor would silently split
  # it along its first (batch) dimension instead of separating the hidden state.
  if isinstance(output,torch.Tensor):
    return output.mul_(.8)

  hidden, *rest = output      # 2a) separate the hidden state from any other outputs
  hidden.mul_(.8)             # 2b) scale in-place
  return tuple([hidden]+rest) # 2c) repackage in the original format

# 3) loop over layers
for layeri in range(num_hidden-1):

  # 4) implant the hook
  hookHandle = model.transformer.h[layeri].register_forward_hook(hookfun)

  try:
    # 5) forward pass
    with torch.no_grad():
      outputs = model(tokens,output_hidden_states=True)
  finally:
    # 6) remove the hook, even if the forward pass fails, so that the
    #    manipulation never lingers in the model
    hookHandle.remove()

  # 7) measure log-softmax logit for " insect"
  log_sm_logits = F.log_softmax(outputs.logits[0,-1,:],dim=-1)
  target_logsm[layeri] = log_sm_logits[target_token_idx]

In [None]:
# one wide panel
_,ax = plt.subplots(figsize=(9,3))

# log-softmax of the target token after scaling each transformer block in turn
ax.plot(target_logsm,'kh',markersize=14,markerfacecolor=[.9,.7,.7])

# reference line (drawn behind the markers) and label for the unmanipulated model
ax.axhline(target_logsm_clean,linestyle='--',color=[.7,.7,.7],zorder=-3)
ax.text(0,target_logsm_clean+.01,f'Clean: {target_logsm_clean:.2f}',va='bottom')

# one tick per transformer block
block_ticks = range(num_hidden-1)
ax.set(xlabel='Transformer block',ylabel='Log softmax',xticks=block_ticks,
       xticklabels=[f'L{i}' for i in block_ticks],
       title='Impact of global hidden-state scaling on log-softmax')

plt.show()