In [None]:
from utils import ModelConfig, TrainingConfig, StreamingTextDataset, Trainer, ToyTransformer
import json
import torch
from datasets import load_dataset
from torch.utils.data import DataLoader
# Vocabulary size of the reduced TinyStories tokenizer.
vocab_size = 10000

# Single source of truth for the batch size. The DataLoaders in the training
# cell are built from this module-level value, so the TrainingConfig is given
# the same number instead of a separate literal that silently disagreed (16).
# NOTE(review): confirm Trainer does not rely on the old config value of 16.
batch_size = 32  # Adjust based on GPU memory

# Configure model
model_config = ModelConfig(
    model_type='attention_only_1L',  # alternative to experiment with: 'attention_only_2L'
    vocab_size=vocab_size,  # must match the tokenizer that produced the token ids
    d_model=512,  # Moderate size for experiments
    n_head=8,
    n_ctx=512,  # Shorter context for faster training
    dropout=0.1
)

# NOTE(review): the saved output of the training cell shows
# "AttributeError: 'TrainingConfig' object has no attribute 'max_epochs'",
# which looks like that cell ran against a stale `training_config`; re-check
# with Restart Kernel -> Run All.
training_config = TrainingConfig(
    model_config=model_config,
    batch_size=batch_size,
    learning_rate=3e-3,
    eval_interval=500,  # read by the training loop: evaluate every N iterations
    log_interval=50,  # read by the training loop: log every N iterations
    max_epochs=1,
)

dataset_name = 'noanabeshima/TinyModelTokIds'
split = 'train'

# For training
train_dataset = load_dataset(dataset_name, split=split)
# For validation
val_dataset = load_dataset(dataset_name, split='validation')


In [2]:
import numpy as np
from tokenization.tokenization import tokenizer
# Create model
model = ToyTransformer(model_config)
print(f"Model type: {model_config.model_type}")
print(f"Parameters: {sum(p.numel() for p in model.parameters()):,}")

# Create trainer
trainer = Trainer(model, training_config)

# Training loop
# `global_step` counts optimisation steps across all epochs. The original code
# used the name `iter`, which was never assigned and therefore resolved to the
# builtin function, making `iter % ...` raise a TypeError.
global_step = 0
print("Starting training...")
for epoch in range(training_config.max_epochs):
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    for batch in train_dataloader:
        # The collated "tok_ids" entry is a sequence of tensors; stack them and
        # transpose (presumably yielding (batch, sequence) -- TODO confirm).
        batch = torch.stack(batch.get("tok_ids")).T
        # Train step
        loss, lr = trainer.train_step(batch)

        # Logging
        if global_step % training_config.log_interval == 0:
            print(f"Iter {global_step}: loss={loss:.4f}, lr={lr:.6f}")

        # Evaluation
        if global_step > 0 and global_step % training_config.eval_interval == 0:
            with torch.no_grad():
                val_losses = []
                # Order does not matter for a mean, so no shuffling is needed.
                val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
                for val_batch in val_dataloader:
                    val_batch = torch.stack(val_batch.get("tok_ids")).T
                    # NOTE(review): this calls `train_step` on validation data
                    # inside `torch.no_grad()`. If `train_step` backpropagates
                    # or steps the optimiser, this either fails or trains on
                    # the validation set. Replace with a forward-only loss
                    # computation once the Trainer/model API is confirmed.
                    val_step_loss, _ = trainer.train_step(val_batch)
                    # float() accepts both Python numbers and 0-d tensors.
                    val_losses.append(float(val_step_loss))
                val_loss = np.mean(val_losses)
                print(f"Validation loss: {val_loss:.4f}")

                # Generate sample text
                model.eval()
                # Build the prompt on the model's own device rather than
                # assuming CUDA is available.
                model_device = next(model.parameters()).device
                context = torch.zeros((1, 1), dtype=torch.long, device=model_device)
                generated = model.generate(context, max_new_tokens=100, temperature=0.8)
                print(f"Sample generation: {tokenizer.decode(generated[0].tolist())}")
                model.train()

        global_step += 1

print("Training complete!")

# Save model
checkpoint_path = f'toy_transformer_{model_config.model_type}.pt'
torch.save(model.state_dict(), checkpoint_path)
print(f"Model saved to {checkpoint_path}")
# turn model config into a dictionary (a copy, so the config object's own
# attribute dict is not aliased)
model_config_dict = dict(vars(model_config))
# save model config to a json file
with open(f'{model_config.model_type}_config.json', 'w') as f:
    json.dump(model_config_dict, f)

Model type: attention_only_1L
Parameters: 6,170,176
Starting training...


AttributeError: 'TrainingConfig' object has no attribute 'max_epochs'

In [3]:
from utils import visualize_prediction_error

# Grab the first sample from the training dataset.
# `train_dataset` is what `load_dataset(...)` returned, which offers neither
# `.get_batch(...)` nor a `.tokenizer` attribute, so index the dataset directly
# and decode with the `tokenizer` object imported for the training cell.
first_tok_ids = list(train_dataset[0]["tok_ids"])

# Visualize the prediction error on the decoded text.
text = tokenizer.decode(first_tok_ids)
visualize_prediction_error(model, tokenizer, text)

In [5]:
# Display the model's module tree as the cell output.
# NOTE(review): the saved output below shows Embedding(50257, 512) and two
# attention layers, which does not match vocab_size=10000 and
# model_type='attention_only_1L' configured in the first cell -- it looks
# stale; re-run after Restart Kernel -> Run All.
model

ToyTransformer(
  (embed): Embedding(50257, 512)
  (dropout): Dropout(p=0.1, inplace=False)
  (layers): ModuleList(
    (0-1): 2 x QuadraticAttention(
      (rotary): Rotary()
      (norm): RMSNorm((64,), eps=None, elementwise_affine=True)
      (mask): Mask()
      (q): Linear(in_features=512, out_features=512, bias=True)
      (k): Linear(in_features=512, out_features=512, bias=True)
      (v): Linear(in_features=512, out_features=512, bias=True)
      (o): Linear(in_features=512, out_features=512, bias=False)
    )
  )
  (head): Linear(in_features=512, out_features=50257, bias=False)
)