In [1]:
import torch
import torch.nn as nn
from torch.optim import Adam
import numpy as np

In [2]:
torch.__version__

'2.5.1+cu121'

In [3]:
# ============ SIMPLE TEXT DATA ============
text = """the cat sat on the mat. the dog ran in the park.
the cat jumped over the fence. the dog barked at the cat.
the mat was soft and warm. the park was green and quiet."""

# ============ TOKENIZATION ============
# Split into words and create vocab
words = text.lower().replace(".", "").split()
vocab = sorted(set(words))
word_to_id = {w: i for i, w in enumerate(vocab)}
id_to_word = {i: w for w, i in word_to_id.items()}

print(f"Vocab size: {len(vocab)}")
print(f"Vocab: {vocab}\n")

# Convert text to token IDs
tokens = [word_to_id[w] for w in words]
print(f"Tokens: {tokens[:10]}...\n")

# ============ CREATE TRAINING DATA ============
# For each token, predict the next token
seq_len = 3  # Use 3 tokens to predict the 4th
X, y = [], []

for i in range(len(tokens) - seq_len):
    X.append(tokens[i:i+seq_len])
    y.append(tokens[i+seq_len])

X = torch.tensor(X, dtype=torch.long)
y = torch.tensor(y, dtype=torch.long)

print(f"Training examples: {len(X)}")
print(f"Example input: {X[0].tolist()} -> {y[0].item()}")
print(f"Words: {[id_to_word[i.item()] for i in X[0]]} -> {id_to_word[y[0].item()]}\n")

# ============ SIMPLE LLM MODEL ============
class SimpleLLM(nn.Module):
    def __init__(self, vocab_size, embed_dim=16, hidden_dim=32):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.fc1 = nn.Linear(embed_dim * 3, hidden_dim)  # 3 tokens flattened
        self.fc2 = nn.Linear(hidden_dim, vocab_size)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        # x shape: (batch_size, seq_len)
        x = self.embedding(x)  # (batch_size, seq_len, embed_dim)
        x = x.reshape(x.size(0), -1)  # Flatten: (batch_size, seq_len * embed_dim)
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# ============ TRAINING ============
model = SimpleLLM(vocab_size=len(vocab), embed_dim=16, hidden_dim=32)
optimizer = Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()

epochs = 100
batch_size = 4

print("Training...\n")
for epoch in range(epochs):
    # Shuffle data
    perm = torch.randperm(len(X))
    X_shuffled = X[perm]
    y_shuffled = y[perm]
    
    total_loss = 0
    
    # Mini-batch training
    for i in range(0, len(X), batch_size):
        X_batch = X_shuffled[i:i+batch_size]
        y_batch = y_shuffled[i:i+batch_size]
        
        # Forward pass
        logits = model(X_batch)
        loss = loss_fn(logits, y_batch)
        
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
    
    avg_loss = total_loss / (len(X) // batch_size)
    
    if (epoch + 1) % 20 == 0:
        print(f"Epoch {epoch+1:3d}, Loss: {avg_loss:.4f}")

print("\nTraining complete!\n")

# ============ INFERENCE ============
def generate(prompt_words, num_generate=5):
    """Generate text given a prompt"""
    model.eval()
    
    # Convert prompt to tokens
    current_tokens = [word_to_id[w] for w in prompt_words]
    generated = prompt_words.copy()
    
    with torch.no_grad():
        for _ in range(num_generate):
            # Get last 3 tokens
            input_tokens = torch.tensor([current_tokens[-3:]], dtype=torch.long)
            logits = model(input_tokens)
            
            # Pick the token with highest probability
            next_token = torch.argmax(logits, dim=1).item()
            next_word = id_to_word[next_token]
            
            generated.append(next_word)
            current_tokens.append(next_token)
    
    return " ".join(generated)

# Test generation
print("Generated text:")
print(f"> {generate(['the', 'cat', 'sat'], num_generate=5)}")
print(f"> {generate(['the', 'dog', 'ran'], num_generate=5)}")
print(f"> {generate(['the', 'mat', 'was'], num_generate=5)}")

Vocab size: 20
Vocab: ['and', 'at', 'barked', 'cat', 'dog', 'fence', 'green', 'in', 'jumped', 'mat', 'on', 'over', 'park', 'quiet', 'ran', 'sat', 'soft', 'the', 'warm', 'was']

Tokens: [17, 3, 15, 10, 17, 9, 17, 4, 14, 7]...

Training examples: 33
Example input: [17, 3, 15] -> 10
Words: ['the', 'cat', 'sat'] -> on

Training...

Epoch  20, Loss: 0.0061
Epoch  40, Loss: 0.0016
Epoch  60, Loss: 0.0008
Epoch  80, Loss: 0.0004
Epoch 100, Loss: 0.0003

Training complete!

Generated text:
> the cat sat on the mat the dog
> the dog ran in the park the cat
> the mat was soft and warm the park
