# Training a Tiny Language Model from Scratch

This is a minimal example of training a language model from scratch with PyTorch. The code demonstrates:

- Tokenizing a small dataset.
- Converting the text into a sequence of integer indices.
- Creating training samples for next-word prediction.
- Defining a simple feedforward neural network model.
- Training the model on the dataset.
- Using the trained model to predict the next words.

## Libraries and settings

In [None]:
# Libraries
import os
import torch
import torch.nn as nn

# Ignore warnings
import warnings
warnings.filterwarnings("ignore")

# Show current working directory
print("Current working directory:", os.getcwd())



## Dataset

We start with a very small dataset: a short story about three little pigs. It serves as the training text for next-word prediction.


In [None]:
text = (
    "once upon a time there were three little pigs who went out into the world "
    "to build their own houses the first little pig built his house out of straw "
    "the second little pig built his house out of sticks "
    "the third little pig built his house out of bricks "
    "one day a big bad wolf came along and saw the first little pig in his straw house "
    "he knocked on the door and said little pig little pig let me come in "
    "not by the hair of my chinny chin chin said the pig "
    "then I'll huff and I'll puff and I'll blow your house in said the wolf "
    "and he did the straw house blew down and the first little pig ran to his brother's house made of sticks "
    "but the wolf followed and said little pigs little pigs let me come in "
    "not by the hairs of our chinny chin chins they said "
    "so the wolf huffed and puffed and blew the stick house down too "
    "the two little pigs ran to their brother's house made of bricks "
    "the wolf followed once more and said little pigs little pigs let me come in "
    "not by the hairs of our chinny chin chins they said again "
    "so the wolf huffed and puffed but this time he could not blow the house in "
    "the brick house was strong and sturdy and the three little pigs were safe inside"
)


## Tokenization

Next, we tokenize the text by splitting it on whitespace and assigning each unique word an integer index. The set of unique words is our vocabulary.


In [None]:
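# Tokenize the text and assign indices
# (sorted so the word-to-index mapping is identical on every run;
# the iteration order of a plain set of strings can change between sessions)
words = sorted(set(text.split()))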
word2idx = {w: i for i, w in enumerate(words)}
idx2word = {i: w for w, i in word2idx.items()}
vocab_size = len(words)
print("Vocabulary Size:", vocab_size)
print("Sample word2idx mapping:", dict(list(word2idx.items())[:5]))
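
To make the vocabulary step concrete, here is a minimal sketch on a toy sentence. The sentence and the `toy_*` names are assumptions made for illustration and are not part of the notebook. The unique words are sorted so that the mapping comes out the same on every run.

```python
# Toy sentence, assumed for illustration only
toy_text = "the pig saw the wolf and the wolf saw the pig"

# Unique words, sorted for a reproducible word -> index mapping
toy_vocab = sorted(set(toy_text.split()))
toy_word2idx = {w: i for i, w in enumerate(toy_vocab)}
toy_idx2word = {i: w for w, i in toy_word2idx.items()}

print("Vocabulary:", toy_vocab)
print("word2idx:", toy_word2idx)
```

Eleven words in the sentence collapse to five vocabulary entries, because repeated words share one index.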


## Encoding

We convert the text into a sequence of integer indices using the vocabulary built above.


In [None]:
# Text encoding
encoded = [word2idx[w] for w in text.split()]
print("Encoded sequence (first 10):", encoded[:10])
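
Encoding is reversible as long as both mappings are kept. The sketch below, on an assumed toy sentence with illustrative `toy_*` names, encodes the words to indices and decodes them back to check the round trip.

```python
# Toy sentence, assumed for illustration only
toy_text = "little pig little pig let me come in"

toy_vocab = sorted(set(toy_text.split()))
toy_word2idx = {w: i for i, w in enumerate(toy_vocab)}
toy_idx2word = {i: w for w, i in toy_word2idx.items()}

# Encode: one integer per word, in reading order
toy_encoded = [toy_word2idx[w] for w in toy_text.split()]

# Decode: map each integer back to its word
toy_decoded = " ".join(toy_idx2word[i] for i in toy_encoded)

print("Encoded:", toy_encoded)
print("Round trip OK:", toy_decoded == toy_text)
```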



## Training Data Preparation

We create input-target pairs using a fixed context window. Each input is a sequence of `context_size` consecutive words, and the target is the word that follows them. A text of `N` words yields `N - context_size` pairs.


In [None]:
# Function to create input-target pairs for training
def get_batches(encoded, context_size=2):
    """
    Creates input-target pairs for training.
    Each input is a sequence of context_size word indices,
    and the target is the index of the next word.
    Returns the whole dataset as two tensors (no mini-batching).
    """
    inputs, targets = [], []
    for i in range(len(encoded) - context_size):
        context = encoded[i:i+context_size]
        target = encoded[i+context_size]
        inputs.append(context)
        targets.append(target)
    return torch.tensor(inputs), torch.tensor(targets)

x, y = get_batches(encoded)
print("Input shape:", x.shape)
print("Target shape:", y.shape)
print("Example input-target pair:", x[0], '->', y[0])
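
The sliding window is easier to see on a short sequence. This is a sketch in plain Python; `make_pairs` and the toy indices are hypothetical names and values chosen for illustration, not part of the notebook.

```python
def make_pairs(seq, context_size=2):
    """Hypothetical helper: slide a window over seq and collect (context, target) pairs."""
    pairs = []
    for i in range(len(seq) - context_size):
        context = seq[i:i + context_size]   # the window of input indices
        target = seq[i + context_size]      # the index right after the window
        pairs.append((context, target))
    return pairs

toy_encoded = [10, 11, 12, 13, 14]  # assumed toy indices
pairs = make_pairs(toy_encoded, context_size=2)
for context, target in pairs:
    print(context, "->", target)
```

Five indices with a window of two give three pairs. A larger window gives fewer pairs but more context per prediction.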


## Model Definition

We define a simple feedforward model. An embedding layer maps each context word to a vector, the vectors are concatenated, and a single linear layer turns them into one score (logit) per vocabulary word.


In [None]:
# Class for the TinyLLM model
class TinyLLM(nn.Module):
    def __init__(self, vocab_size, embed_dim, context_size):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.fc = nn.Linear(embed_dim * context_size, vocab_size)

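    def forward(self, x):
        # x: (batch, context_size) word indices
        x = self.embed(x)           # (batch, context_size, embed_dim)
        x = x.view(x.size(0), -1)   # (batch, context_size * embed_dim)
        return self.fc(x)           # (batch, vocab_size) logits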
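
To follow the tensor shapes through the forward pass, the sketch below runs the same three steps (embed, flatten, linear) on a small batch. The sizes are assumed toy values, not the notebook's actual vocabulary size.

```python
import torch
import torch.nn as nn

# Assumed toy sizes, for illustration only
vocab_size, embed_dim, context_size = 20, 10, 2

embed = nn.Embedding(vocab_size, embed_dim)
fc = nn.Linear(embed_dim * context_size, vocab_size)

# A batch of 3 contexts, each with 2 word indices
batch = torch.tensor([[1, 2], [3, 4], [5, 6]])

embedded = embed(batch)                        # one vector per context word
flat = embedded.view(embedded.size(0), -1)     # concatenate the vectors per sample
logits = fc(flat)                              # one score per vocabulary word

n_params = sum(p.numel() for p in embed.parameters()) \
         + sum(p.numel() for p in fc.parameters())

print("embedded:", tuple(embedded.shape))
print("flat:", tuple(flat.shape))
print("logits:", tuple(logits.shape))
print("parameters:", n_params)
```

With these sizes the embedding table holds 20 × 10 = 200 weights and the linear layer holds 20 × 20 weights plus 20 biases, 620 parameters in total.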


## Model Initialization

We initialize the model, optimizer, and loss function.


In [None]:
# Define context size (must match the context_size used in get_batches)
context_size = 2

# Initialize model
model = TinyLLM(vocab_size, 
                embed_dim=10, 
                context_size=context_size)

# Initialize optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Define loss function
loss_fn = nn.CrossEntropyLoss()
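
A useful reference point when reading the loss: a model that gives every word the same score has a cross-entropy loss of ln(vocab_size). The sketch below checks this with an assumed toy vocabulary size and arbitrary targets.

```python
import math
import torch
import torch.nn as nn

vocab_size = 50  # assumed toy vocabulary size
loss_fn = nn.CrossEntropyLoss()

# All-zero logits mean a uniform distribution over the vocabulary
uniform_logits = torch.zeros(4, vocab_size)
targets = torch.tensor([0, 7, 21, 49])  # arbitrary target indices

uniform_loss = loss_fn(uniform_logits, targets).item()
print(f"Loss with uniform predictions: {uniform_loss:.4f}")
print(f"ln(vocab_size):                {math.log(vocab_size):.4f}")
```

A freshly initialized model usually starts near this value, and a training loss well below it suggests the model has learned something from the data.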


## Training Loop

The model is trained for 500 epochs using cross-entropy loss. Because the dataset is tiny, each epoch is a single full-batch update: the model predicts the next word for every context, and the optimizer adjusts the weights to reduce the loss.


In [None]:
# Train the model
for epoch in range(500):
    logits = model(x)
    loss = loss_fn(logits, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

print(f"Final training loss: {loss.item():.4f}")
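
Printing only the final loss hides how training progressed. The sketch below records the loss at every step and measures training accuracy afterwards. It is self-contained and uses an assumed toy sequence and a small `nn.Sequential` stand-in model, so the names and sizes here are illustrative rather than the notebook's own.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)  # make the run reproducible

# Assumed toy data: a repeating pattern of 4 "words"
seq = [0, 1, 2, 3] * 10
context_size = 2
n_pairs = len(seq) - context_size
x = torch.tensor([seq[i:i + context_size] for i in range(n_pairs)])
y = torch.tensor([seq[i + context_size] for i in range(n_pairs)])

# Stand-in model with the same embed -> flatten -> linear structure
model = nn.Sequential(
    nn.Embedding(4, 8),
    nn.Flatten(),
    nn.Linear(8 * context_size, 4),
)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

losses = []
for epoch in range(200):
    logits = model(x)
    loss = loss_fn(logits, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    losses.append(loss.item())
    if epoch % 50 == 0:
        print(f"epoch {epoch:3d}  loss {loss.item():.4f}")

# Fraction of training contexts whose next word is predicted correctly
with torch.no_grad():
    accuracy = (model(x).argmax(dim=1) == y).float().mean().item()
print(f"Training accuracy: {accuracy:.2f}")
```

Note that this is accuracy on the training data itself. On a dataset this small it mostly measures memorization, not generalization.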


## Prediction

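We define a function that generates the next `n` words from a starting context of at least `context_size` words. At each step it picks the highest-scoring word (greedy decoding) and appends it to the context, so the same starting words always produce the same continuation.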


In [None]:
# Function to predict the next n words
def predict_next_n(context_words, n=3):
    """
    Predict the next n words given a context of words.
    All context words must be in the vocabulary.
    """
    result = context_words.copy()
    with torch.no_grad():  # inference only, so skip gradient tracking
        for _ in range(n):
            context_idx = torch.tensor([[word2idx[w] for w in result[-context_size:]]])
            logits = model(context_idx)
            predicted_idx = torch.argmax(logits, dim=1).item()
            predicted_word = idx2word[predicted_idx]
            result.append(predicted_word)
    return result


# Example prediction
print("Prediction example:", predict_next_n(["one", "day"], n=6))

### Jupyter notebook --footer info-- (please always provide this at the end of each notebook)

In [None]:
import os
import platform
from platform import python_version
from datetime import datetime

print('-----------------------------------')
print(os.name.upper())
print(platform.system(), '|', platform.release())
print('Datetime:', datetime.now().strftime("%Y-%m-%d %H:%M:%S"))
print('Python Version:', python_version())
print('-----------------------------------')