In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
from model import *

In [2]:
class NumberSequenceDataset(Dataset):
    """
    A simple dataset that teaches the model to count.

    Every item is a window of ``seq_len + 1`` consecutive integers drawn from
    the token range ``[0, vocab_size - 1]``, split into an input sequence and
    a target sequence shifted by one position.

    Args:
        vocab_size: Number of distinct tokens; valid ids are ``0 .. vocab_size - 1``.
        seq_len: Length of each input (and target) sequence.
        length: Nominal number of items reported by ``__len__``.

    Raises:
        ValueError: If the vocabulary cannot hold a single window of
            ``seq_len + 1`` consecutive tokens.
    """
    def __init__(self, vocab_size, seq_len, length=1000):
        # Fail early with a clear message instead of an opaque error from
        # torch.randint on the first __getitem__ call.
        if vocab_size < seq_len + 1:
            raise ValueError(
                f"vocab_size ({vocab_size}) must be at least seq_len + 1 "
                f"({seq_len + 1}) to fit one training window"
            )
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.length = length

    def __len__(self):
        """Return the nominal number of items per epoch."""
        return self.length

    def __getitem__(self, idx):
        """
        Return one ``(x, y)`` pair of ``torch.long`` tensors of length ``seq_len``.

        ``idx`` is ignored: each access draws a fresh random window, so the
        same index can yield different samples.
        """
        # torch.randint's upper bound is exclusive. The largest valid start is
        # vocab_size - seq_len - 1, which makes the final target token equal to
        # vocab_size - 1, so the whole vocabulary is reachable.
        num_starts = self.vocab_size - self.seq_len
        start = torch.randint(0, num_starts, (1,)).item()

        # Generate a sequence of consecutive numbers: [10, 11, 12, 13...]
        data = torch.arange(start, start + self.seq_len + 1, dtype=torch.long)

        # x is the input: [10, 11, 12]
        # y is the target: [11, 12, 13] (shifted by 1)
        x = data[:-1]
        y = data[1:]
        return x, y


In [3]:
# Seed PyTorch's global RNG before any stochastic step so that
# Restart Kernel -> Run All gives repeatable results (at least on CPU).
# Dataset sampling (torch.randint) and DataLoader shuffling draw from this RNG;
# model weight initialisation presumably does too -- confirm in model.py.
SEED = 0
torch.manual_seed(SEED)

# Hyperparameters for a deliberately tiny model so the demo trains in seconds.
# NOTE(review): field semantics (e.g. `h`, `d_ff`) are defined by Config in
# model.py and are not visible in this notebook.
config = Config(
    vocab_size=100,
    embed_size=64,
    seq_len=8,
    n_layer=2,
    h=2,
    d_ff=128,
    total_epochs=5,
    lr=1e-3,
    dropout=0.0,
)
print(f"Running on: {config.device}")

model = GPT.build_gpt(config)
print("Model built successfully.")

Running on: cpu
Model built successfully.


In [4]:
# Build the counting dataset from the sizes stored in the shared config.
dataset = NumberSequenceDataset(
    vocab_size=config.vocab_size,
    seq_len=config.seq_len,
    length=500,
)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Sanity check: pull a single batch and show the input/target tensor shapes.
batch_iter = iter(train_loader)
batch = next(batch_iter)
print(f"x: {batch[0].shape}\ny: {batch[1].shape}")

x: torch.Size([32, 8])
y: torch.Size([32, 8])


In [5]:
# Train the model on the counting dataset.
# NOTE(review): train_gpt is defined in model.py; the epoch count and learning
# rate presumably come from the Config used to build the model -- confirm there.
print("\nStarting Training...")
try:
    model.train_gpt(train_loader)
except KeyboardInterrupt:
    # Let the user stop a long run from the notebook UI without a traceback;
    # the model keeps whatever weights it had reached at that point.
    print("Training interrupted.")


Starting Training...


0/5: 100%|███████████████████████████| 16/16 [00:00<00:00, 67.23it/s, loss=4.12]


Epoch: 0/5 Loss: 4.429825156927109


1/5: 100%|███████████████████████████| 16/16 [00:00<00:00, 75.80it/s, loss=3.18]
2/5: 100%|███████████████████████████| 16/16 [00:00<00:00, 79.77it/s, loss=2.37]
3/5: 100%|███████████████████████████| 16/16 [00:00<00:00, 66.00it/s, loss=1.73]
4/5: 100%|███████████████████████████| 16/16 [00:00<00:00, 69.97it/s, loss=1.25]


In [6]:
print("\nStarting Generation Test...")
model.eval()

# Single source of truth for the prompt and the generation length, so the
# tensor, the printed prompt, and the verification range cannot drift apart.
start_value = 10
num_new_tokens = 15

# The model should continue the count: [start_value, start_value + 1, ...]
expected = list(range(start_value, start_value + num_new_tokens + 1))
assert expected[-1] < config.vocab_size, (
    "expected continuation runs past the vocabulary; "
    "lower start_value or num_new_tokens"
)

# Prompt tensor containing the single starting token.
start_token = torch.tensor([[start_value]], dtype=torch.long).to(config.device)

# Generation is inference only, so skip autograd bookkeeping.
# NOTE(review): top_k=1 presumably makes decoding greedy -- confirm in model.generate.
with torch.no_grad():
    generated = model.generate(start_token, max_new_token=num_new_tokens, top_k=1)

# Convert to simple list for printing
result = generated[0].tolist()

print(f"Input: [{start_value}]")
print(f"Generated: {result}")

# Verification
if result == expected:
    print("\nSUCCESS: The model learned to count!")
else:
    print("\nPARTIAL: The model produced valid output but maybe not perfect counting yet.")


Starting Generation Test...
Input: [10]
Generated: [10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]

SUCCESS: The model learned to count!
