# Scratchpad 2 - Generate text, save and load model, pretrain

## 1. Generate text

### Simple text generation function with greedy decoding

In greedy decoding, the model always picks the token with the highest probability as the next token. Because softmax preserves the ordering of the logits, this is the same as picking the token with the largest logit.

In [1]:
import torch

def generate_text_simple(model, input_idx, max_new_tokens, context_size):
    """Extend ``input_idx`` with greedily decoded tokens.

    At every step the model is run on (at most) the last ``context_size``
    tokens generated so far, and the token with the largest logit at the
    final position is appended to the sequence.

    Args:
        model: Callable that maps a ``(batch_size, n_tokens)`` tensor of token
            IDs to logits of shape ``(batch_size, n_tokens, vocab_size)``.
        input_idx: Token IDs of the current context, shape
            ``(batch_size, n_tokens)``.
        max_new_tokens: Number of tokens to append. If it is zero or negative
            the input is returned unchanged.
        context_size: Maximum number of trailing tokens fed to the model.
            Must be at least 1.

    Returns:
        The input tokens followed by the generated tokens, shape
        ``(batch_size, n_tokens + max_new_tokens)``.

    Raises:
        ValueError: If ``context_size`` is smaller than 1.
    """
    # `-0` is `0`, so a context_size of 0 would slice `[:, 0:]` and silently
    # keep the *whole* sequence; a negative value would drop leading tokens.
    # Reject both instead of cropping incorrectly.
    if context_size < 1:
        raise ValueError(f"context_size must be >= 1, got {context_size}")

    # we will output the input tokens plus the generated new tokens
    output_idx = input_idx

    # gradients are never needed for generation, so disable them once for
    # the whole loop instead of re-entering the context on every step
    with torch.no_grad():
        for _ in range(max_new_tokens):
            # if the sequence is longer than the supported context, crop the
            # tokens at the front and keep only the last `context_size` ones
            idx = output_idx[:, -context_size:]

            # model prediction for the current context
            logits = model(idx)

            # the prediction for the next token sits at the last position:
            # (batch_size, n_tokens, vocab_size) -> (batch_size, vocab_size)
            next_token_logits = logits[:, -1, :]

            # the most probable token is the one with the largest logit;
            # keepdim=True keeps the reduced axis so the result has shape
            # (batch_size, 1) and can be concatenated directly
            next_token_idx = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            # append the new token to the running sequence
            output_idx = torch.cat((output_idx, next_token_idx), dim=1)

    return output_idx

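Test the `generate_text_simple` function on our untrained GPT-2 model. Because the weights are still randomly initialised, the generated continuation is expected to be gibberish.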

In [2]:
import torch
import tiktoken
from gpt.gpt_model import GPTModel

# Hyperparameters passed to GPTModel (the name suggests the 124M-parameter
# GPT-2 variant).
GPT_CONFIG_124M = dict(
    vocab_size=50257,     # vocabulary size
    context_length=1024,  # context length
    embedding_dim=768,    # embedding dimension
    n_heads=12,           # number of attention heads
    n_layers=12,          # number of layers
    dropout_rate=0.1,     # dropout rate
    qkv_bias=False,       # query-key-value bias
)

# fix the RNG seed before the model is constructed
torch.manual_seed(123)
model = GPTModel(GPT_CONFIG_124M)

# switch to eval mode to disable dropout; kept as the last expression so the
# notebook cell still displays the returned module
model.eval()

GPTModel(
  (tok_emb): Embedding(50257, 768)
  (pos_emb): Embedding(1024, 768)
  (dropout_emb): Dropout(p=0.1, inplace=False)
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (norm1): LayerNorm()
      (att): MultiHeadAttention(
        (W_q): Linear(in_features=768, out_features=768, bias=False)
        (W_k): Linear(in_features=768, out_features=768, bias=False)
        (W_v): Linear(in_features=768, out_features=768, bias=False)
        (out_proj): Linear(in_features=768, out_features=768, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (norm2): LayerNorm()
      (ff): FeedForward(
        (layers): Sequential(
          (0): Linear(in_features=768, out_features=3072, bias=True)
          (1): GELU()
          (2): Linear(in_features=3072, out_features=768, bias=True)
        )
      )
      (drop_shortcut): Dropout(p=0.1, inplace=False)
    )
    (1): TransformerBlock(
      (norm1): LayerNorm()
      (att): MultiHeadAttention(
        

In [3]:
# Prompt text that is tokenized below and fed to the model as the initial context.
start_context = "Hi, I am a large language model"

In [4]:
# load tiktoken's "gpt2" encoding
tokenizer = tiktoken.get_encoding("gpt2")

# prompt as a plain Python list of token IDs
encoded = tokenizer.encode(start_context)

# prepend a batch axis: (n_tokens,) -> (1, n_tokens)
encoded_tensor = torch.tensor(encoded)[None, :]

In [11]:
# Print an "IN" banner (a 50-character rule above and below the centred
# label), followed by the prompt, its token IDs and the tensor shape.
print(f"\n{50*'='}\n{22*' '}IN\n{50*'='}")
print(f"Input text: {start_context}")
print(f"Encoded input text: {encoded}")  # list of token IDs produced by the tokenizer
print(f"Encoded tensor shape: {encoded_tensor.shape}")  # shape: (batch_size, n_tokens)


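==================================================
                      IN
==================================================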
Input text: Hi, I am a large language model
Encoded input text: [17250, 11, 314, 716, 257, 1588, 3303, 2746]
Encoded tensor shape: torch.Size([1, 8])


In [10]:
# the model's maximum context length bounds how many trailing tokens are fed in
max_context = GPT_CONFIG_124M["context_length"]

# greedily append 10 new tokens to the encoded prompt
out = generate_text_simple(
    model=model,
    input_idx=encoded_tensor,
    max_new_tokens=10,
    context_size=max_context,
)

# drop the batch axis and turn the token IDs back into text
generated_ids = out.squeeze(0).tolist()
decoded_text = tokenizer.decode(generated_ids)

In [13]:
# Print an "OUT" banner (a 50-character rule above and below the label),
# followed by the generated token IDs and the text decoded from them.
print(f"\n{50*'='}\n{22*' '}OUT\n{50*'='}")
print(f"Output tensor: {out}")  # prompt token IDs followed by the generated ones
print(f"Output text: {decoded_text}")  # decoded from `out` in the previous cell


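==================================================
                      OUT
==================================================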
Output tensor: tensor([[17250,    11,   314,   716,   257,  1588,  3303,  2746, 45199, 41518,
         45173, 31263, 23195,  8603,  7384, 10261, 18815, 30220]])
Output text: Hi, I am a large language model fixme Satanic cordsulkan275 equally attacked 93 colony corrobor
