In [None]:
# Install the third-party packages this notebook imports (pandas, torch, transformers)
# into the kernel's environment. NOTE(review): versions are unpinned, so results may
# differ between installs — consider pinning them (e.g. `transformers==X.Y.Z`).
%pip install pandas torch transformers

In [2]:
# Purpose: learn how words / tokens are used in LLMs.
# Using a small language model (e.g. gpt2 or one of the smaller Pythia models),
# experiment with sentences and see how the next word is predicted.
#
# Given a text, calculate the next-token probabilities and display each token with its probability.
# Also, generate longer sentences from the original text.

import torch
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

# --- configuration --------------------------------------------------------
# Checkpoint to load (uncomment the alternative to compare a Pythia model).
model_id = "gpt2"
# model_id = "EleutherAI/pythia-14m"

# Prompt whose continuation we study; alternatives kept for quick swaps.
original_text = "The fox is jumping over the"
# original_text = "The prince is in love with a"
# original_text = "The man is marrying with a"
text = original_text  # working copy that later statements extend
print('model_id:', model_id)

# --- load model + tokenizer -----------------------------------------------
model = AutoModelForCausalLM.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
total_tokens = len(tokenizer)  # number of entries in the tokenizer (50257 for gpt2)
print(f"Total tokens in tokenizer: {total_tokens}")

# --- tokenize the partial sentence ----------------------------------------
# For gpt2 and the default prompt this yields
#   {'input_ids': tensor([[  464, 21831,   318, 14284,   625,   262]]),
#    'attention_mask': tensor([[1, 1, 1, 1, 1, 1]])}
inputs = tokenizer(text, return_tensors="pt")

print('text:', text)
print('input_ids:', inputs["input_ids"])
print('id and word:', [(int(tok_id), tokenizer.decode(tok_id))
                       for tok_id in inputs["input_ids"][0]])

# Calculate probabilities for the next words (token)
def get_next_word_probs(model, inputs):
  """Return the next-token probability distribution for the first sequence.

  Runs a forward pass without gradient tracking, takes the logits at the
  final position of the first sequence in the batch, and softmaxes them.

  Args:
    model: causal LM whose call result exposes `.logits`
      (e.g. a model from `AutoModelForCausalLM`).
    inputs: mapping of model keyword arguments (such as the tokenizer's
      `input_ids` / `attention_mask`), unpacked as `model(**inputs)`.

  Returns:
    1-D tensor of probabilities, one per vocabulary entry
    (torch.Size([50257]) for gpt2).
  """
  with torch.no_grad():
    # logits: (batch, seq_len, vocab), e.g. torch.Size([1, 6, 50257])
    logits = model(**inputs).logits
    final_position_logits = logits[0, -1]  # first sequence, last position
    return torch.softmax(final_position_logits, dim=-1)

def show_next_token_choices(prob, top_n=5, tok=None):
  """Tabulate the `top_n` most probable next tokens.

  Only the top-k entries are selected (with `torch.topk`) and decoded. The
  previous version decoded every vocabulary entry (~50k `decode` calls for
  gpt2) and sorted the whole table just to keep a handful of rows.

  Args:
    prob: 1-D tensor of next-token probabilities indexed by token id
      (as returned by `get_next_word_probs`).
    top_n: maximum number of rows to return. It is clamped to the vocabulary
      size; values <= 0 yield an empty table.
    tok: tokenizer used to decode ids. Defaults to the notebook-level
      `tokenizer` when omitted.

  Returns:
    DataFrame with columns `id`, `token`, `prob`, sorted by `prob`
    descending and indexed by token id. Entries whose probability is
    exactly zero are omitted.
  """
  if tok is None:
    tok = tokenizer  # notebook-global tokenizer, looked up at call time
  k = max(0, min(top_n, prob.numel()))
  values, ids = torch.topk(prob, k)  # sorted, largest first
  rows = [
      (token_id, tok.decode(token_id), p)
      for token_id, p in zip(ids.tolist(), values.tolist())
      if p != 0
  ]
  return pd.DataFrame(
      rows,
      columns=["id", "token", "prob"],
      index=[token_id for token_id, _, _ in rows],
  )

# Probability of every vocabulary entry being the token that follows `text`.
prob = get_next_word_probs(model, inputs)
print('prob.shape', prob.shape)  # torch.Size([50257]) for gpt2
print(show_next_token_choices(prob, 10))

# Greedy choice: take the single highest-probability token and append it.
next_token_id = prob.argmax().item()
next_word = tokenizer.decode(next_token_id)
print(f"Next most probable token id: {next_token_id}")
print(f"Next probable token: {next_word}")
text = text + next_word
print(f"New sentence: {text}")

model_id: gpt2
Total tokens in tokenizer: 50257
text: The fox is jumping over the
input_ids: tensor([[  464, 21831,   318, 14284,   625,   262]])
id and word: [(464, 'The'), (21831, ' fox'), (318, ' is'), (14284, ' jumping'), (625, ' over'), (262, ' the')]
prob.shape torch.Size([50257])
          id    token      prob
13990  13990    fence  0.106365
5743    5743     edge  0.027167
5509    5509     tree  0.026133
3355    3355     wall  0.023594
1353    1353      top  0.014355
7150    7150    trees  0.013775
19516  19516    cliff  0.012385
2046    2046     fire  0.009482
37413  37413   bushes  0.009091
2318    2318      bar  0.009064
Next most probable token id: 13990
Next probable token:  fence
New sentence: The fox is jumping over the fence


In [3]:
# Generate next words with a for-loop to see in detail (slower, streaming).
# Greedy decoding: at each step append the single most probable token.
# The growing sequence is kept as token ids instead of re-tokenizing the decoded
# text each step. Decode -> re-encode is not guaranteed to reproduce the same ids
# (e.g. consecutive newline tokens can merge into one), which would silently
# change the context the model conditions on.
NUM_NEW_TOKENS = 20
text = original_text
inputs = tokenizer(original_text, return_tensors="pt")
print(text, end='')

for _ in range(NUM_NEW_TOKENS):
  prob = get_next_word_probs(model, inputs)
  next_token_id = torch.argmax(prob).item()
  next_word = tokenizer.decode(next_token_id)
  print(next_word, end='')
  text += next_word
  # Extend the model input with the chosen id, and the mask to match it.
  next_ids = torch.tensor([[next_token_id]], dtype=inputs["input_ids"].dtype)
  inputs = {
    "input_ids": torch.cat([inputs["input_ids"], next_ids], dim=-1),
    "attention_mask": torch.cat(
        [inputs["attention_mask"], torch.ones_like(inputs["attention_mask"][:, :1])],
        dim=-1,
    ),
  }

The fox is jumping over the fence.


"I'm going to kill you!"


"I'm going to

In [6]:
# Generate the entire continuation in one call (faster). Note that `max_length`
# bounds the TOTAL length, prompt tokens included.
# With do_sample=True the continuation is random and usually differs between runs;
# set SEED to an int to make a run reproducible.
# repetition_penalty > 1 discourages re-using tokens that already appeared; it
# reduces, but does not eliminate, repeated phrases.
SEED = None  # e.g. 42 for a reproducible sample
if SEED is not None:
  torch.manual_seed(SEED)

inputs = tokenizer(original_text, return_tensors="pt")
output = model.generate(**inputs, max_length=50, pad_token_id=tokenizer.eos_token_id,
    do_sample=True,  # sample from the temperature-scaled distribution instead of taking the argmax
    temperature=0.9,
    repetition_penalty=1.5,
    num_beams=5  # together with do_sample=True this is beam-search multinomial sampling
    )
# skip_special_tokens drops markers such as the end-of-text token from the printout.
decoded = tokenizer.decode(output[0], skip_special_tokens=True)
print(decoded)

The fox is jumping over the fence.

"I don't know what's going on, but I think he's going to jump over the fence," he said. "He's going to jump over the fence. He's going to jump
