In [None]:
# imports
import torch
from importlib import reload
import transformer
reload(transformer)
from transformer import GPT, GPTConfig, train

In [None]:
# Create a small GPT model
# Seed PyTorch's global RNG before anything random happens, so the
# torch.randint draws further down (and, presumably, any torch-based weight
# initialisation inside GPT) repeat exactly under Restart Kernel -> Run All.
SEED = 1337
torch.manual_seed(SEED)

config = GPTConfig(
    context_size=256,
    vocab_size=50257,
    n_layer=6,
    n_head=6,
    n_embd=384
)
model = GPT(config)

# Print model info
# numel() is the element count of one parameter tensor; the sum is the total
# number of scalar parameters, reported in millions.
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params / 1e6:.2f}M")
print(f"Config: {config}")

In [None]:
# Forward pass with random tokens
# Ids are sampled from [0, config.vocab_size) rather than a repeated literal,
# so they stay inside the model's vocabulary if the config cell is edited.
# NOTE(review): the sequence length (128) presumably must not exceed
# config.context_size — confirm against transformer.GPT.
tokens = torch.randint(0, config.vocab_size, (4, 128))  # Batch of 4, seq len 128
# Called without targets; the loss returned for that case is printed below.
logits, loss = model(tokens)  # logits presumably (4, 128, vocab_size) — shape printed below
print(f"Input shape: {tokens.shape}")
print(f"Output logits shape: {logits.shape}")
print(f"Loss (no targets): {loss}")

In [None]:
# Forward pass with targets (for training)
# Targets are random ids with exactly the shape of `tokens` and the same
# vocabulary bound as the model config, so the two tensors cannot drift apart
# if the batch size, sequence length or vocab size is changed elsewhere.
targets = torch.randint(0, config.vocab_size, tokens.shape)
logits, loss = model(tokens, targets)
# .item() pulls the scalar loss out of its tensor for formatting.
print(f"Loss (with targets): {loss.item():.4f}")

In [None]:
# Text generation (with random tokens as prompt)
# The prompt ids are drawn from [0, config.vocab_size) so they always belong
# to the vocabulary the model was built with.
prompt = torch.randint(0, config.vocab_size, (1, 10))  # Starting tokens
print(f"Prompt shape: {prompt.shape}")

# Ask for 50 new tokens; temperature and top_k are passed straight through to
# GPT.generate (their exact sampling semantics live in transformer.py).
generated = model.generate(prompt, max_new_tokens=50, temperature=0.8, top_k=40)
print(f"Generated shape: {generated.shape}")
print(f"Generated tokens: {generated[0].tolist()}")

In [None]:
# Simple training example with a dummy dataloader
class DummyDataLoader:
    """Synthetic batch source that emits random token ids.

    Each call to next_batch() draws a fresh block of integers in
    [0, vocab_size) and splits it into an input/target pair for
    next-token prediction.
    """

    def __init__(self, vocab_size, seq_len, batch_size):
        # Only the sizes are stored; sampling is deferred to next_batch().
        self.vocab_size, self.seq_len, self.batch_size = vocab_size, seq_len, batch_size

    def next_batch(self):
        """Return one (inputs, targets) pair of shape (batch_size, seq_len).

        A block one token longer than seq_len is sampled so the targets can
        be the same sequence advanced by one position.
        """
        block_shape = (self.batch_size, self.seq_len + 1)
        block = torch.randint(0, self.vocab_size, block_shape)
        # inputs drop the last column, targets drop the first column
        inputs, shifted = block[:, :-1], block[:, 1:]
        return inputs, shifted

# Create dataloader
# The vocabulary bound comes from the model config instead of a repeated
# literal, so the sampled ids always match the model that consumes them.
dataloader = DummyDataLoader(vocab_size=config.vocab_size, seq_len=128, batch_size=4)

# Create fresh model and optimizer
# A new GPT instance replaces the one used in the earlier demo cells; the
# optimizer is bound to this new instance's parameters.
model = GPT(config)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)

# Train for a few steps
# train() comes from transformer.py; its return value is rebound to `model`.
model = train(
    model=model,
    dataloader=dataloader,
    n_steps=20,
    optimizer=optimizer,
    device='cpu',
    log_interval=5
)