<a href="https://colab.research.google.com/github/mikexcohen/Substack/blob/main/textHeatmaps_GPT2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

|<h2>Substack post:</h2>|<h1><a href="https://mikexcohen.substack.com/p/drawing-text-heatmaps-to-visualize" target="_blank">Drawing text heatmaps to visualize LLM calculations</a></h1>|
|-|:-:|
|<h2>Teacher:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the post may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# pytorch libraries
import torch
import torch.nn.functional as F

# huggingface LLM
from transformers import AutoModelForCausalLM, GPT2Tokenizer

In [None]:
### Run this cell only if you're using "dark mode"

# switch the inline backend to SVG output (higher-res figures)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# the theme uses only two colors: a dark-gray canvas and a pale-blue "ink"
canvas_color = '#383838'
ink_color    = '#DDE2F4'

# group the rcParams by the value they share, then merge into one dictionary
dark_theme = {
    # backgrounds
    **dict.fromkeys(['figure.facecolor','figure.edgecolor','axes.facecolor'], canvas_color),
    # lines, ticks, labels, and text
    **dict.fromkeys(['axes.edgecolor','axes.labelcolor','xtick.color','ytick.color','text.color'], ink_color),
    # hide the right and top axis spines
    **dict.fromkeys(['axes.spines.right','axes.spines.top'], False),
    # bold titles and axis labels
    **dict.fromkeys(['axes.titleweight','axes.labelweight'], 'bold'),
}

plt.rcParams.update(dark_theme)

# Calculate the width of one letter

In [None]:
# temporary figure used only to measure the size of one character
fig,ax = plt.subplots(figsize=(10,2))

# one monospace letter is enough: every letter in a monospace font has the same width
temp_text = ax.text(0,0,'n',fontsize=12,fontfamily='monospace')

# bounding box of that letter in display coordinates
renderer = fig.canvas.get_renderer()
bbox = temp_text.get_window_extent(renderer=renderer)

# map the two corners of the box from display into axis coordinates
inv = ax.transAxes.inverted()
corners_display = [[bbox.x0,bbox.y0], [bbox.x1,bbox.y1]]
bbox_axes = inv.transform(corners_display) # rows are (x0,y0) and (x1,y1)

# letter width is the right edge minus the left edge (x values are in column 0)
x_left,x_right = bbox_axes[:,0]
en_width = x_right - x_left

# the figure has served its purpose
plt.close(fig)
en_width

# Demo: Color words according to character count

In [None]:
# placeholder text, written as consecutive fragments that are glued together
text = ''.join((
    "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut ",
    "labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris ",
    "nisi ut aliquip ex ea commodo consequat. Duis aute irure dolor in reprehenderit in voluptate velit ",
    "esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt ",
    "in culpa qui officia deserunt mollit anim id est laborum.",
))

# crude "tokenization": break the text apart at whitespace
words = text.split()
words

In [None]:
# number of characters in each word, stored as a numpy array
lens = np.array(list(map(len,words)))

# word length as a function of position in the text
plt.figure(figsize=(10,3))
ax_len = plt.gca()
ax_len.plot(lens,'ks',markerfacecolor=[.7,.9,.7])
ax_len.set(xlabel='Word index',ylabel='Word length',title='Raw word counts')
plt.show()

In [None]:
# rescale so that the shortest word maps to 0 and the longest word maps to 1
shortest,longest = lens.min(),lens.max()
charcountsScale = (lens-shortest) / (longest-shortest)

# same plot as for the raw lengths, now on the scaled values
plt.figure(figsize=(10,3))
ax_scaled = plt.gca()
ax_scaled.plot(charcountsScale,'kh',markerfacecolor=[.7,.9,.9])
ax_scaled.set(xlabel='Word index',ylabel='Word length',title='Min-max scaled word counts')
plt.show()

# Visualize the heatmap

In [None]:
# 1) drawing cursor (in axis coordinates): start at the left edge of the top row
x_pos,y_pos = 0,1

fig, ax = plt.subplots(figsize=(10,2),facecolor='w')
ax.axis('off')

# 2) walk through the words together with their scaled lengths
for word,colorval in zip(words,charcountsScale):

  # 3) horizontal space taken by this word (number of letters times letter width)
  word_width = len(word)*en_width

  # 4) rounded background patch; its color is the scaled length mapped through "Reds"
  patch = dict(boxstyle='round,pad=.3', facecolor=mpl.cm.Reds(colorval),
               edgecolor='none', alpha=.8)

  # 5) draw the word, centered on the middle of its slot
  x_center = x_pos + word_width/2
  ax.text(x_center, y_pos, word, fontsize=12, color='k',
          ha='center', va='center', fontfamily='monospace', bbox=patch)

  # 6) move the cursor past this word, plus a small gap
  x_pos += word_width + .015

  # 7) cursor is beyond the line limit: return to the left edge, one row down
  if x_pos>1.2:
    x_pos = 0
    y_pos -= .2

plt.show()

# Importing GPT2

In [None]:
# name of the pretrained weights, shared by the tokenizer and the model
checkpoint = 'gpt2'

# load the tokenizer and the causal language model
tokenizer = GPT2Tokenizer.from_pretrained(checkpoint)
gpt2 = AutoModelForCausalLM.from_pretrained(checkpoint)

# show the model's layers
gpt2

In [None]:
# the passage to analyze, written as consecutive fragments that are glued together
text = ''.join((
    "There are many statistical analyses that quantify bivariate (two variables) ",
    "relationships; in this post I will describe two — correlation and cosine ",
    "similarity — discuss how they relate to each other, and advise you on when ",
    "to use which one. The upshot is that the measures can be identical but are ",
    "often different. Which one to use depends entirely on whether the scale of ",
    "the data is important.",
))

# convert the text into token indices, returned as a PyTorch tensor
tokens = tokenizer.encode(text,return_tensors='pt')
tokens

In [None]:
# push the tokens through the model; only the output values are analyzed,
# so gradient tracking is switched off to avoid building the autograd graph
with torch.no_grad():
  outputs = gpt2(tokens)

# size of the logits tensor (indexed elsewhere as [batch, token position, vocab index])
outputs.logits.shape

In [None]:
# decode the two tokens of interest (zero-based positions 48 and 49)
tok49 = tokenizer.decode(tokens[0,48])
tok50 = tokenizer.decode(tokens[0,49])

print(f'The 49th token is "{tok49}"')
print(f'The 50th token is "{tok50}"')

In [None]:
plt.figure(figsize=(10,3))

# logits produced at position 48 (the 49th token), detached for plotting
logits48 = outputs.logits[0,48,:].detach()

# vocab index of the token that actually comes next (the 50th token)
next_id = tokens[0,49]

# red square: the logit assigned to the actual next token
plt.plot(next_id,logits48[next_id].item(),'rs',label='Logit for next token')

# gray hexagons: the logits for every entry in the vocab
plt.plot(logits48,'h',color=[.3,.3,.3],markerfacecolor=[.7,.9,.7,.3])

plt.legend()
ax_logits = plt.gca()
ax_logits.set(xlabel='Vocab index',ylabel='Logits (raw)',title='Logits from the 49th token',xlim=[-10,50280])
plt.show()

In [None]:
# 1) initialize vector of log-probabilities for each token
#    (the first token has no preceding context, so its entry stays at zero)
token_ids = tokens[0]
predicted_logSM = np.zeros(len(token_ids))

# gradients are not needed for this analysis
with torch.no_grad():

  # 2) log-softmax over the vocab at every position except the last;
  #    the output at position t-1 is the prediction for the token at position t
  lsm = F.log_softmax(outputs.logits[0,:-1,:],dim=-1)

  # 3) at each position, pick out the log-softmax value of the token that actually came next
  next_ids = token_ids[1:].unsqueeze(-1)
  actual_lsm = lsm.gather(-1,next_ids).squeeze(-1)

# 4) store as numpy, leaving index 0 untouched
predicted_logSM[1:] = actual_lsm.cpu().numpy()

In [None]:
# rescale the predictions to the range [0,1], leaving out index 0
# (note: this overwrites predicted_logSM[1:] with the scaled values)
y = predicted_logSM[1:]
lowest,highest = y.min(),y.max()
predicted_logSM[1:] = (y-lowest) / (highest-lowest)

In [None]:
# 1) drawing cursor (in axis coordinates): start at the left edge of the top row
x_pos,y_pos = 0,1

fig, ax = plt.subplots(figsize=(10,2),facecolor='w')
ax.axis('off')

# 2) walk through the tokens together with their scaled predictions
for token_id,colorval in zip(tokens[0],predicted_logSM):

  # convert the token index back into text
  word = tokenizer.decode(token_id)

  # 3) horizontal space taken by this token (number of characters times letter width)
  word_width = len(word)*en_width

  # 4) rounded background patch; its color is the prediction mapped through "Reds"
  patch = dict(boxstyle='round,pad=.3', facecolor=mpl.cm.Reds(colorval),
               edgecolor='none', alpha=.8)

  # 5) draw the token, centered on the middle of its slot
  x_center = x_pos + word_width/2
  ax.text(x_center, y_pos, word, fontsize=12, color='k',
          ha='center', va='center', fontfamily='monospace', bbox=patch)

  # 6) move the cursor past this token, plus a small gap
  x_pos += word_width + .015

  # 7) cursor is beyond the line limit: return to the left edge, one row down
  if x_pos>1.:
    x_pos = 0
    y_pos -= .2

plt.show()