# Model Pretrain

This script pretrains a SmolLM model on the tiny Shakespeare dataset using the Hugging Face Trainer, configured for efficient training.

We use the following optimizations for training:
- Model architecture taken from the original SmolLM configuration, with randomly initialized weights (normal distribution, std 0.02)
- Mixed precision training (fp16)
- torch.compile for acceleration
- Optimizer configuration:
  - Adam-style optimizer with betas (0.9, 0.95) and eps 1e-8
  - Gradient norm clipping (1.0)
  - Cosine decay learning rate with warmup
  - Weight decay (0.1) excluding LayerNorm and bias parameters
  - Fused Adam implementation when available
  - Larger effective batch size through gradient accumulation (per device: 16 sequences × 4 accumulation steps = 64 sequences per optimizer step)


The training data setup in this script is based on standard causal language modeling where:
- **Input processing:**
  - The Shakespeare text is split into chunks of `SEQUENCE_LENGTH` (512) characters. Note that this is a character count, not a token count.
  - Each chunk is independent, with no overlap between chunks.
  - Only chunks with more than 10 characters are kept.
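- **Label generation:**
  - `DataCollatorForLanguageModeling(mlm=False)` handles label creation automatically.
  - For causal language modeling it copies `input_ids` into `labels`. Padding positions are set to -100 so the loss ignores them.
  - The collator does not shift the labels. The one-position shift is expected to happen inside the model's loss computation (the usual Hugging Face convention), which yields the standard "predict the next token" objective.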
- **Text handling:**
  - The current implementation uses hard cutoffs at `SEQUENCE_LENGTH`.
  - It doesn't use a sliding-window approach (which would produce overlapping chunks).
  - As a result, the available context is limited to each chunk's boundaries.


This approach is simple but effective for pretraining. Each batch contains independent chunks of Shakespeare text, and the model learns to predict the next token given the preceding tokens within each chunk.


If you wanted to improve this approach, you could implement:
- Sliding window chunking instead of hard cutoffs
- Preserving document boundaries
- Better handling of truncation to avoid cutting in the middle of sentences

In [None]:
#!/usr/bin/env python
# Pretraining script for SmolLMForCausalLM model on tiny Shakespeare dataset

import os
import math
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from datasets import load_dataset
from transformers import (
    Trainer,
    TrainingArguments,
    DataCollatorForLanguageModeling,
    get_cosine_schedule_with_warmup,
    set_seed,
)
from transformers.trainer_pt_utils import get_parameter_names
from accelerate import Accelerator
from typing import Dict, List, Optional, Union, Any, Tuple
import gc
import logging
import sys

# Import the custom model implementation
from smol_model import SmolLMConfig, SmolLMForCausalLM

# Configure logging: timestamped INFO-level messages written to stdout.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    handlers=[logging.StreamHandler(sys.stdout)],
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Set seed for reproducibility
set_seed(42)

# Constants
MODEL_NAME = "HuggingFaceTB/SmolLM-135M-Instruct"  # Hub model whose config defines the architecture
SEQUENCE_LENGTH = 512  # Length of each training chunk
BATCH_SIZE = 16  # Per-device batch size
# Effective batch per optimizer step (per device), assuming each example is
# SEQUENCE_LENGTH tokens long:
# BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS * SEQUENCE_LENGTH = 16 * 4 * 512
# = 32,768 tokens. Reaching ~0.5M tokens per step would need roughly 16x more
# accumulation (e.g. 64 steps) or more devices.
GRADIENT_ACCUMULATION_STEPS = 4
LEARNING_RATE = 6e-4  # Peak learning rate of the cosine schedule
WEIGHT_DECAY = 0.1  # Applied only to the decayed parameter group
WARMUP_STEPS = 100  # Linear warmup length in optimizer steps
NUM_TRAIN_EPOCHS = 1
LOGGING_STEPS = 10
SAVE_STEPS = 1000  # Also used as the evaluation interval
USE_FLASH_ATTENTION = True  # Whether to use Flash Attention


# Dataset preparation
def get_shakespeare_dataset():
    """Load the tiny Shakespeare corpus as chunked raw-text splits.

    First tries ``load_dataset("tiny_shakespeare")`` from the Hugging Face Hub.
    If that fails for any reason, it downloads Karpathy's ``input.txt`` and
    makes a 90/10 train/validation split.

    In both cases every split is cut into non-overlapping chunks of
    ``SEQUENCE_LENGTH`` *characters* (not tokens). Chunks of 10 characters or
    fewer are dropped. No tokenization happens here: each split keeps a single
    ``"text"`` column.

    Returns:
        A ``DatasetDict`` containing at least ``"train"`` and ``"validation"``.

    Raises:
        requests.RequestException: if the fallback download fails or the
            server returns an HTTP error status.
    """
    # Direct download URL, used when the Hub dataset cannot be loaded
    dataset_url = "https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt"

    def chunk_text(examples):
        # Split each document into fixed-size character chunks. The short
        # trailing chunk is kept only if it has meaningful content.
        chunks = []
        for text in examples["text"]:
            for start in range(0, len(text), SEQUENCE_LENGTH):
                chunk = text[start : start + SEQUENCE_LENGTH]
                if len(chunk) > 10:
                    chunks.append(chunk)
        return {"text": chunks}

    try:
        logger.info("Attempting to load Shakespeare dataset from Hugging Face...")
        dataset = load_dataset("tiny_shakespeare")
    except Exception as err:
        # Any Hub failure (network, missing or unloadable dataset) falls back
        # to the raw file. Exception (not a bare except) lets Ctrl-C through.
        logger.warning(f"Could not load tiny_shakespeare from the Hub ({err}).")
        logger.info(f"Downloading Shakespeare dataset from {dataset_url}...")
        import requests
        from datasets import Dataset as HFDataset
        from datasets import DatasetDict

        response = requests.get(dataset_url, timeout=60)
        # Fail loudly instead of silently training on an HTML error page
        response.raise_for_status()
        text = response.text

        split_at = int(len(text) * 0.9)
        dataset = DatasetDict(
            {
                "train": HFDataset.from_dict({"text": [text[:split_at]]}),
                "validation": HFDataset.from_dict({"text": [text[split_at:]]}),
            }
        )

    # Chunk every split the same way, whichever source the text came from
    return dataset.map(chunk_text, batched=True, remove_columns=["text"])


# Model initialization with random weights
def init_model_for_training():
    """Build a freshly initialized SmolLM model for pretraining.

    The architecture comes from the Hub config of ``MODEL_NAME``. Pretrained
    weights are NOT loaded. Weights are initialized as follows:

    * ``nn.Linear``: weights ~ N(0, 0.02), biases zeroed.
    * ``nn.Embedding``: weights ~ N(0, 0.02), padding row (if any) zeroed.
    * Final norm weight (``model.model.norm.weight``, when present): set to 1.0.

    Input and output embeddings are tied when ``config.tie_word_embeddings``
    is set. When CUDA is available, the model is wrapped with
    ``torch.compile``.

    Returns:
        The initialized ``SmolLMForCausalLM``, or a ``torch.compile`` wrapper
        around it.
    """
    from transformers import AutoConfig

    logger.info("Initializing SmolLM model with random weights...")

    # Load the original configuration for architecture details.
    # NOTE(review): trust_remote_code lets the Hub repo execute its own code;
    # it is likely unnecessary for a standard architecture config. Confirm.
    official_config = AutoConfig.from_pretrained(MODEL_NAME, trust_remote_code=True)

    # Create our custom config with Flash Attention enabled if requested
    custom_config = SmolLMConfig(
        official_config=official_config, use_flashattention=USE_FLASH_ATTENTION
    )

    logger.info(
        f"Flash Attention is {'enabled' if USE_FLASH_ATTENTION else 'disabled'}"
    )
    model = SmolLMForCausalLM(custom_config)

    logger.info("Applying weight initialization...")

    def _init_weights(module):
        # Normal(0, 0.02) for matrices; zero biases and the padding embedding row
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if getattr(module, "padding_idx", None) is not None:
                with torch.no_grad():
                    module.weight[module.padding_idx].fill_(0)

    # Recursively apply the initialization to every submodule
    model.apply(_init_weights)

    # Reset the final norm's scale to 1.0 (identity scaling). Use getattr so a
    # model without a `.model` backbone attribute, or without a final norm,
    # does not crash here.
    backbone = getattr(model, "model", None)
    final_norm = getattr(backbone, "norm", None)
    if final_norm is not None and getattr(final_norm, "weight", None) is not None:
        with torch.no_grad():
            final_norm.weight.fill_(1.0)

    # Tie after initialization so the shared matrix holds the final values
    if model.config.tie_word_embeddings:
        model.tie_weights()
        logger.info("Tied input and output embedding weights")

    # Use torch.compile to speed up training if available
    if torch.cuda.is_available() and hasattr(torch, "compile"):
        logger.info("Applying torch.compile for faster training...")
        model = torch.compile(model)

    return model


# Custom AdamW optimizer (fused when possible) with weight-decay exclusion.
# Gradient-norm clipping is configured separately via TrainingArguments.
def get_optimizer(model, lr, weight_decay):
    """Create an AdamW optimizer with decoupled weight decay on matrix weights only.

    Parameters are split into two groups:

    * decayed: parameters with ``ndim >= 2`` (linear/embedding matrices) whose
      name does not contain ``"bias"``;
    * not decayed: everything else, i.e. biases and all 1-D parameters such as
      LayerNorm/RMSNorm scales. Testing ``ndim`` instead of
      ``isinstance(module, nn.LayerNorm)`` also catches custom norm classes.

    On CUDA the fused AdamW kernel is requested. If the installed PyTorch
    rejects it (e.g. ``fused`` is unsupported, or the parameters are not yet on
    a supported device), the default AdamW implementation is used instead.

    Args:
        model: Module whose parameters are optimized.
        lr: Peak learning rate.
        weight_decay: Decoupled weight-decay coefficient for the decayed group.

    Returns:
        A ``torch.optim.AdamW`` instance.
    """
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if param.ndim >= 2 and "bias" not in name:
            decay_params.append(param)
        else:
            no_decay_params.append(param)

    optimizer_grouped_parameters = [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params, "weight_decay": 0.0},
    ]
    adamw_kwargs = {"lr": lr, "betas": (0.9, 0.95), "eps": 1e-8}

    if torch.cuda.is_available():
        try:
            # AdamW (not Adam) so weight decay stays decoupled from the gradient
            optimizer = torch.optim.AdamW(
                optimizer_grouped_parameters, fused=True, **adamw_kwargs
            )
            logger.info("Using fused AdamW optimizer")
            return optimizer
        except (TypeError, RuntimeError, ValueError) as err:
            # Older PyTorch: no `fused` kwarg, or params must already be on GPU
            logger.info(f"Fused AdamW not available ({err}), using standard AdamW")

    return torch.optim.AdamW(optimizer_grouped_parameters, **adamw_kwargs)


def main():
    """Run the full pretraining pipeline.

    1. Report the device via ``Accelerator``.
    2. Load tiny Shakespeare as raw text (``get_shakespeare_dataset``).
    3. Build the randomly initialized model (``init_model_for_training``).
    4. Load the ``MODEL_NAME`` tokenizer, tokenize every split and pack the
       token stream into non-overlapping blocks of ``SEQUENCE_LENGTH`` tokens.
       The incomplete tail of each ``map`` batch is dropped.
    5. Train with ``Trainer`` using the custom AdamW optimizer and
       cosine-with-warmup schedule, then save the model to ``./final_model``.
    """
    from itertools import chain

    from transformers import AutoTokenizer

    # 1. Initialize accelerator (used here only to report the device)
    accelerator = Accelerator(
        mixed_precision="fp16", log_with="tensorboard", project_dir="./logs"
    )
    accelerator.print(f"Running on {accelerator.device}")

    # 2. Get dataset (splits with a raw "text" column)
    dataset = get_shakespeare_dataset()

    # 3. Initialize model
    model = init_model_for_training()

    # 4. Load a real tokenizer. The collator calls tokenizer.pad() and reads
    #    pad_token_id, which an embedding layer does not provide.
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    if tokenizer.pad_token is None:
        # The collator pads batches, so a pad token must be defined
        tokenizer.pad_token = tokenizer.eos_token

    def tokenize(batch):
        return tokenizer(batch["text"])

    def pack_into_blocks(batch):
        # Concatenate the token ids and cut them into equal-length blocks, so
        # every training example is exactly SEQUENCE_LENGTH tokens long
        ids = list(chain.from_iterable(batch["input_ids"]))
        usable = len(ids) - len(ids) % SEQUENCE_LENGTH
        return {
            "input_ids": [
                ids[i : i + SEQUENCE_LENGTH] for i in range(0, usable, SEQUENCE_LENGTH)
            ]
        }

    def prepare(split):
        # Drop the raw text (and attention masks) so only input_ids reach the collator
        tokenized = split.map(tokenize, batched=True, remove_columns=split.column_names)
        return tokenized.map(
            pack_into_blocks, batched=True, remove_columns=tokenized.column_names
        )

    train_dataset = prepare(dataset["train"])
    eval_dataset = prepare(dataset["validation"])
    logger.info(
        f"Prepared {len(train_dataset)} training and {len(eval_dataset)} "
        f"validation blocks of {SEQUENCE_LENGTH} tokens"
    )

    # 5. Create data collator (labels = input_ids; padding masked to -100)
    data_collator = DataCollatorForLanguageModeling(
        tokenizer=tokenizer,
        mlm=False,  # We're doing causal language modeling
    )

    # 6. Prepare training arguments
    training_args = TrainingArguments(
        output_dir="./results",
        overwrite_output_dir=True,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
        # NOTE: renamed to `eval_strategy` in newer transformers releases
        evaluation_strategy="steps",
        eval_steps=SAVE_STEPS,
        save_strategy="steps",
        save_steps=SAVE_STEPS,
        save_total_limit=2,
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        warmup_steps=WARMUP_STEPS,
        lr_scheduler_type="cosine",
        logging_dir="./logs",
        logging_steps=LOGGING_STEPS,
        report_to="tensorboard",
        fp16=True,  # Mixed precision training
        remove_unused_columns=False,
        dataloader_num_workers=4,
        # NOTE(review): requires SmolLMForCausalLM to support gradient
        # checkpointing. Confirm.
        gradient_checkpointing=True,  # Memory optimization
        max_grad_norm=1.0,  # Gradient norm clipping
    )

    # 7. Create optimizer and scheduler.
    optimizer = get_optimizer(model, LEARNING_RATE, WEIGHT_DECAY)
    # Approximate Trainer's step count: batches per epoch (per process) divided
    # by the accumulation factor, with at least one optimizer step per epoch.
    batches_per_epoch = math.ceil(
        len(train_dataset) / (BATCH_SIZE * training_args.world_size)
    )
    steps_per_epoch = max(batches_per_epoch // GRADIENT_ACCUMULATION_STEPS, 1)
    total_steps = steps_per_epoch * NUM_TRAIN_EPOCHS
    if total_steps <= WARMUP_STEPS:
        logger.warning(
            f"Only {total_steps} optimizer steps but WARMUP_STEPS={WARMUP_STEPS}; "
            "the learning rate will never reach its peak."
        )
    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=WARMUP_STEPS,
        num_training_steps=total_steps,
    )

    # 8. Create Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=data_collator,
        optimizers=(optimizer, lr_scheduler),
    )

    # 9. Start training
    logger.info("Starting training...")
    trainer.train()

    # 10. Save final model
    trainer.save_model("./final_model")
    logger.info("Training complete! Model saved to ./final_model")


# Script entry point: run the pretraining pipeline only when this file is
# executed directly, not when it is imported.
if __name__ == "__main__":
    main()
