|<h2>Substack post:</h2>|<h1><a href="https://mikexcohen.substack.com/p/llm-breakdown-26-logits-and-next" target="_blank">LLM breakdown 2/6: Logits and next-token prediction</a></h1>|
|-|:-:|
|<h2>Teacher:</h2>|<h1>Mike X Cohen, <a href="https://sincxpress.com" target="_blank">sincxpress.com</a></h1>|

<br>

<i>Using the code without reading the post may lead to confusion or errors.</i>

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

# pytorch libraries
import torch
import torch.nn.functional as F

In [None]:
### Run this cell only if you're using "dark mode"

# render inline figures as SVG (sharper than the default raster output)
import matplotlib_inline.backend_inline
matplotlib_inline.backend_inline.set_matplotlib_formats('svg')

# dark backgrounds ('#383838' is an alternative shade), light foreground elements,
# no top/right spines, and bold titles/labels
plt.rcParams.update({
    **dict.fromkeys(['figure.facecolor', 'figure.edgecolor', 'axes.facecolor'], '#020617'),
    **dict.fromkeys(['axes.edgecolor', 'axes.labelcolor', 'xtick.color', 'ytick.color', 'text.color'], '#DDE2F4'),
    'axes.spines.right': False,
    'axes.spines.top':   False,
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
})

# Import the model, and inspect output logits

In [None]:
# huggingface LLM
from transformers import AutoModelForCausalLM, GPT2Tokenizer

# load the 'gpt2' tokenizer and the matching pretrained causal-LM weights
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = AutoModelForCausalLM.from_pretrained('gpt2')

In [None]:
# Some tokenized text
txt = 'I think more people would eat tumeric if it were purple.'

# return_tensors='pt' gives a PyTorch tensor; row 0 holds this text's token ids
tokens = tokenizer.encode(txt,return_tensors='pt')
print(f'There are {len(txt)} characters and {tokens.shape[1]} tokens.')

In [None]:
# switch the model to evaluation mode, then run one forward pass with autograd disabled
model.eval()
with torch.set_grad_enabled(False):
  output = model(tokens)

# list what the returned output object exposes (e.g., .logits)
dir(output)

In [None]:
output.logits.size()  # dimensions of the logits; indexed below as [sequence, token position, vocab index]

In [None]:
# decode the 5th and 6th tokens one at a time (zero-based positions 4 and 5)
print(f'5th token is "{tokenizer.decode(tokens[0][4])}"')
print(f'6th token is "{tokenizer.decode(tokens[0][5])}"')

In [None]:
fig,ax = plt.subplots(figsize=(10,3))

# every vocab item's logit at the 5th token position
# (.detach() separates the tensor from the computational graph before plotting)
ax.plot(output.logits[0,4,:].detach(),'h',color=[.3,.3,.3],markerfacecolor=[.7,.9,.7,.3])

# highlight the logit assigned to the token that actually comes next (the 6th token)
ax.plot(tokens[0,5],output.logits[0,4,tokens[0,5]].detach(),'rs',label='Logit for next token')

ax.legend()
ax.set(xlabel='Vocab index',ylabel='Logits (raw)',title='Logits from the 5th token',xlim=[-10,50280])
plt.show()

In [None]:
# What token does the model think should come next? (vocab index of the largest logit, decoded)
tokenizer.decode(output.logits[0,4].argmax())

In [None]:
# sum of raw logits at the 5th token position
torch.sum(output.logits[0,4])

# About softmaxing

In [None]:
# problems with direct translation of the softmax function:
logits = output.logits[0,4,:].detach()

# textbook formula exp(z)/sum(exp(z)), computed literally; exp() of large-magnitude
# logits can overflow or underflow in floating point
softmax_direct = logits.exp() / logits.exp().sum()

# PyTorch's built-in softmax, for comparison
softmax_logits = F.softmax(logits,dim=-1)

softmax_direct

In [None]:
# plot (note: there are lots of dots here! it takes 10-20 seconds to render the figure)
fig,(ax_raw,ax_soft,ax_cmp) = plt.subplots(1,3,figsize=(12,3.3))

# left: raw logits across the vocabulary
ax_raw.plot(logits,'.',markeredgecolor='none',markersize=3,markerfacecolor=[.7,.9,.7,.3])
ax_raw.set(xlabel='Vocab index',ylabel='Logits (raw)',title='Raw logits',xlim=[-10,50280])

# middle: the same values after softmax
ax_soft.plot(softmax_logits,'.',markeredgecolor='none',markersize=5,markerfacecolor=[.9,.7,.7])
ax_soft.set(xlabel='Vocab index',ylabel='Softmax probabilities',title='Softmax logits',xlim=[-10,50280])

# right: softmax output as a function of the raw logit
ax_cmp.plot(logits,softmax_logits,'.',markeredgecolor='none',markersize=5,markerfacecolor=[.7,.7,.9])
ax_cmp.set(xlabel='Logits (raw)',ylabel='Softmax logits',title='Logits vs. softmax logits')

fig.tight_layout()
plt.show()

In [None]:
# log-softmax (method form of F.log_softmax)
log_softmax_logits = logits.log_softmax(dim=-1)

# plot
fig,(ax_soft,ax_log,ax_cmp) = plt.subplots(1,3,figsize=(12,3.3))

# left: softmax probabilities across the vocabulary
ax_soft.plot(softmax_logits,'.',markeredgecolor='none',markersize=5,markerfacecolor=[.7,.9,.7])
ax_soft.set(xlabel='Vocab index',ylabel='Softmax probabilities',title='Softmax logits',xlim=[-10,50280])

# middle: log-softmax across the vocabulary
ax_log.plot(log_softmax_logits,'.',markeredgecolor='none',markersize=3,markerfacecolor=[.9,.7,.7,.3])
ax_log.set(xlabel='Vocab index',ylabel='Log-softmax',title='Log-softmax logits',xlim=[-10,50280])

# right: log-softmax as a function of softmax
ax_cmp.plot(softmax_logits,log_softmax_logits,'.',markeredgecolor='none',markersize=5,markerfacecolor=[.7,.7,.9])
ax_cmp.set(ylabel='Log-softmax',xlabel='Softmax',title='Softmax vs. log-softmax')

fig.tight_layout()
plt.show()

# Generating new tokens

In [None]:
tokens = tokenizer.encode('I like oat milk in my',return_tensors='pt')

# forward pass without gradient tracking (like the other inference cells), so no
# autograd graph is built; keep only the logits at the final token position
with torch.no_grad():
  final_logits = model(tokens).logits[0,-1,:]

# greedy choice: vocab index of the largest final logit
max_logit = torch.argmax(final_logits)
print(f'The most likely next token is "{tokenizer.decode(max_logit)}"')

In [None]:
fig,ax = plt.subplots(figsize=(10,3))

# background: logits for the entire vocabulary at the final position
ax.plot(final_logits,'h',color=[.3,.3,.3],markerfacecolor=[.7,.9,.7,.3])

# foreground: the single largest logit, labelled with its decoded token
ax.plot(max_logit,final_logits[max_logit],'rs',label=f'Max logit ("{tokenizer.decode(max_logit)}")')

ax.legend()
ax.set(xlabel='Vocab index',ylabel='Logits (raw)',title='Logits from the final token',xlim=[-150,50290])
plt.show()

In [None]:
# top 10 choices: the ten largest final logits, highest first, with their decoded tokens
print('   Logit   |    Token')
print('-----------+---------------')
top_vals,top_idxs = torch.topk(final_logits,10)
for val,idx in zip(top_vals,top_idxs):
  print(f' {val:.3f}  |  "{tokenizer.decode(idx)}"')

# Where's the coffee??

In [None]:
# vocab index of " coffee" (note the leading space in the encoded string)
coffee_idx = tokenizer.encode(' coffee')[0]
print(f'" coffee" has index {coffee_idx}')

# vocab indices ordered from largest to smallest final logit
sidx = final_logits.argsort(descending=True)

# rank position of " coffee" in that ordering (0 = the model's top choice)
(sidx == coffee_idx).nonzero(as_tuple=True)[0]

# Probabilistic token sampling

In [None]:
# turn the final logits into probabilities, then draw 5 tokens from that distribution
# (torch.multinomial samples without replacement by default, so the 5 are distinct)
softmax_logits = F.softmax(final_logits,dim=-1)
multin_tokens = torch.multinomial(softmax_logits,5)

# one quoted token per line
print('\n'.join(f'"{tokenizer.decode(t)}"' for t in multin_tokens))

# Generating a token sequence

In [None]:
tokens = tokenizer.encode('I like oat milk in my',return_tensors='pt')

# generation settings
n_new_tokens = 10    # how many tokens to append to the prompt
use_sampling = True  # True: sample from the softmax distribution; False: greedy (argmax)

for i in range(n_new_tokens):

  # extract logits from the final token
  with torch.no_grad():
    logits = model(tokens).logits[0,-1,:]

  # transform to softmax-probability
  softmax_logits = F.softmax(logits,dim=-1)

  # pick the next token, either through sampling or via greedy
  if use_sampling:
    next_token = torch.multinomial(softmax_logits,1)
  else:
    next_token = torch.argmax(softmax_logits)

  # append the new token: reshape it to (1,1) so it concatenates onto the (1, n_tokens)
  # token tensor, rather than re-wrapping a tensor inside torch.tensor([[...]])
  tokens = torch.cat([tokens,next_token.reshape(1,1)],dim=-1)

  # and print the results so far
  print(f'Iteration {i}:  {tokenizer.decode(tokens[0])}')

In [None]:
# The better way
tokens = tokenizer.encode('I like oat milk in my',return_tensors='pt')

# generate() warns when no attention mask or pad token id is supplied (GPT-2 has no
# pad token). An all-ones mask for the single unpadded prompt and EOS as the pad id
# state those defaults explicitly, so the output is the same but without the warnings.
token_seq = model.generate(
  tokens,
  attention_mask=torch.ones_like(tokens),
  pad_token_id=tokenizer.eos_token_id,
  max_new_tokens=10,
  do_sample=True,
)
tokenizer.decode(token_seq[0])

# Manipulating model internals with hooks

In [None]:
# inspect the architecture: displaying the model prints its tree of submodules
# (including lm_head, the layer the hook function in this section is attached to)
model

In [None]:
# define and implant the hook function
def hook(module, input, output):
  """Forward hook for model.lm_head that forces " coffee" to be the top next-token logit.

  At the final token position of every sequence in the batch, the logit at vocab
  index `coffee_idx` (a notebook global) is set to 10 above that position's current
  maximum logit. `output` is modified in place and also returned, so PyTorch uses it
  as the layer's output.
  """
  # current maximum logit at the final position, one value per sequence in the batch
  actual_max = output[:,-1,:].max(dim=-1).values

  # replace token index coffee_idx with max+10
  output[:,-1,coffee_idx] = actual_max + 10

  # and return the modified version
  return output

# Re-running this cell would otherwise attach a second hook to lm_head and overwrite
# the only handle that could remove the first one. Removing an already-removed
# handle is a no-op, so this is safe on every run.
if 'hookHandle' in globals():
  hookHandle.remove()
hookHandle = model.lm_head.register_forward_hook(hook)

In [None]:
# get outputs (with the hook active) and find the next token.
# try/finally guarantees the hook is removed even if the forward pass raises,
# so it cannot silently alter later forward passes.
try:
  with torch.no_grad():
    final_logits = model(tokens).logits[0,-1,:]
finally:
  # remove the hook function
  hookHandle.remove()

max_logit = torch.argmax(final_logits)
print(f'The most likely next token is "{tokenizer.decode(max_logit)}"')

In [None]:
plt.figure(figsize=(10,3))

# plot all of the tokens first...
plt.plot(final_logits,'h',color=[.3,.3,.3],markerfacecolor=[.7,.9,.7,.3])

# ...then the prediction for the next token, so the red marker is drawn on top of
# (not hidden under) the vocabulary markers, matching the earlier final-token plot
plt.plot(max_logit,final_logits[max_logit],'rs',label=f'Max logit ("{tokenizer.decode(max_logit)}")')

plt.legend()
plt.gca().set(xlabel='Vocab index',ylabel='Logits (raw)',title='Logits from the final token',xlim=[-150,50290])
plt.show()