In [2]:
!pip install "transformers==4.40.1" huggingface_hub sentencepiece tokenizers



In [8]:
#!/usr/bin/env python3
"""
Resume Training SmolLM2-135M from a Checkpoint
Fixes:
- Dataset chunk size issues
- PyTorch 2.6+ safe loading of checkpoints
- Handles token sequences safely for max_position_embeddings
- Mixed precision support
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from dataclasses import dataclass, fields
import json, os
from tqdm import tqdm
from torch.serialization import add_safe_globals

# =============================
# Device Configuration
# =============================
def get_device_config():
    if torch.cuda.is_available():
        return torch.device("cuda"), True, torch.bfloat16, f"CUDA ({torch.cuda.get_device_name(0)})"
    elif torch.backends.mps.is_available():
        return torch.device("mps"), True, torch.float16, "Apple Silicon (MPS)"
    else:
        return torch.device("cpu"), False, torch.float32, "CPU"

# =============================
# Load config from JSON
# =============================
@dataclass
class SmolLM2Config:
    vocab_size: int = 49152
    hidden_size: int = 576
    intermediate_size: int = 1536
    num_hidden_layers: int = 30
    num_attention_heads: int = 9
    num_key_value_heads: int = 3
    max_position_embeddings: int = 8192
    rms_norm_eps: float = 1e-5
    rope_theta: float = 100000.0
    attention_dropout: float = 0.0
    hidden_dropout: float = 0.0
    initializer_range: float = 0.041666666666666664
    tie_word_embeddings: bool = True

    def __post_init__(self):
        self.head_dim = self.hidden_size // self.num_attention_heads

    @classmethod
    def from_json(cls, path):
        with open(path, "r") as f:
            cfg = json.load(f)
        allowed_keys = {f.name for f in fields(cls)}
        filtered_cfg = {k: v for k, v in cfg.items() if k in allowed_keys}
        return cls(**filtered_cfg)

# =============================
# Dataset
# =============================
class TextDataset(Dataset):
    def __init__(self, file_path, tokenizer, block_size=32, batch_size=4):
        text = open(file_path, "r", encoding="utf-8").read()
        ids = tokenizer.encode(text, add_special_tokens=False)
        self.block_size = block_size
        self.batch_size = batch_size
        tp = block_size * batch_size

        # Trim to multiple of (tp+1)
        total_tokens = len(ids) - ((len(ids) - 1) % (tp + 1))
        ids = ids[:total_tokens]

        self.data = torch.tensor(ids, dtype=torch.long)
        self.num_batches = len(self.data) // (tp + 1)

    def __len__(self):
        return self.num_batches

    def __getitem__(self, idx):
        tp = self.block_size * self.batch_size
        start = idx * (tp + 1)
        chunk = self.data[start:start + tp + 1]
        if len(chunk) != tp + 1:
            # pad last chunk (safety)
            pad_len = tp + 1 - len(chunk)
            chunk = torch.cat([chunk, torch.zeros(pad_len, dtype=torch.long)])
        x = chunk[:-1].view(self.batch_size, self.block_size)
        y = chunk[1:].view(self.batch_size, self.block_size)
        return x, y

# =============================
# SmolLM2 Model Classes
# =============================
class RMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-5):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps
    def forward(self, x):
        var = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(var + self.eps)
        return self.weight * x

class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=8192, base=100000.0):
        super().__init__()
        self.dim = dim
        self.base = base
        inv_freq = 1.0 / (self.base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
    def forward(self, seq_len, device):
        t = torch.arange(seq_len, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos()[None, :, :], emb.sin()[None, :, :]

def rotate_half(x):
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat([-x2, x1], dim=-1)

def apply_rotary(q, k, cos, sin):
    return (q * cos + rotate_half(q) * sin, k * cos + rotate_half(k) * sin)

class GroupedQueryAttention(nn.Module):
    def __init__(self, cfg: SmolLM2Config):
        super().__init__()
        self.cfg = cfg
        self.q_proj = nn.Linear(cfg.hidden_size, cfg.num_attention_heads * cfg.head_dim, bias=False)
        self.k_proj = nn.Linear(cfg.hidden_size, cfg.num_key_value_heads * cfg.head_dim, bias=False)
        self.v_proj = nn.Linear(cfg.hidden_size, cfg.num_key_value_heads * cfg.head_dim, bias=False)
        self.o_proj = nn.Linear(cfg.num_attention_heads * cfg.head_dim, cfg.hidden_size, bias=False)
        self.rotary = RotaryEmbedding(cfg.head_dim, max_position_embeddings=cfg.max_position_embeddings, base=cfg.rope_theta)
        self.num_key_value_groups = cfg.num_attention_heads // cfg.num_key_value_heads

    def forward(self, x):
        B, T, _ = x.shape
        q = self.q_proj(x).view(B, T, self.cfg.num_attention_heads, self.cfg.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.cfg.num_key_value_heads, self.cfg.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.cfg.num_key_value_heads, self.cfg.head_dim).transpose(1, 2)
        cos, sin = self.rotary(T, x.device)
        q, k = apply_rotary(q, k, cos, sin)
        k = k.repeat_interleave(self.num_key_value_groups, dim=1)
        v = v.repeat_interleave(self.num_key_value_groups, dim=1)
        attn = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.cfg.attention_dropout if self.training else 0.0, is_causal=True)
        attn = attn.transpose(1, 2).contiguous().view(B, T, self.cfg.hidden_size)
        return self.o_proj(attn)

class MLP(nn.Module):
    def __init__(self, cfg: SmolLM2Config):
        super().__init__()
        self.gate = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
        self.up = nn.Linear(cfg.hidden_size, cfg.intermediate_size, bias=False)
        self.down = nn.Linear(cfg.intermediate_size, cfg.hidden_size, bias=False)
    def forward(self, x):
        return self.down(F.silu(self.gate(x)) * self.up(x))

class DecoderLayer(nn.Module):
    def __init__(self, cfg: SmolLM2Config):
        super().__init__()
        self.ln1 = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.attn = GroupedQueryAttention(cfg)
        self.ln2 = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
        self.mlp = MLP(cfg)
    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

class SmolLM2Model(nn.Module):
    def __init__(self, cfg: SmolLM2Config):
        super().__init__()
        self.cfg = cfg
        self.embed_tokens = nn.Embedding(cfg.vocab_size, cfg.hidden_size)
        self.layers = nn.ModuleList([DecoderLayer(cfg) for _ in range(cfg.num_hidden_layers)])
        self.ln_f = RMSNorm(cfg.hidden_size, eps=cfg.rms_norm_eps)
    def forward(self, input_ids):
        x = self.embed_tokens(input_ids)
        for layer in self.layers:
            x = layer(x)
        x = self.ln_f(x)
        return x

class SmolLM2ForCausalLM(nn.Module):
    def __init__(self, config: SmolLM2Config):
        super().__init__()
        self.config = config
        self.model = SmolLM2Model(config)
        if config.tie_word_embeddings:
            self.lm_head = None
        else:
            self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.apply(self._init_weights)
    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
    def forward(self, input_ids, labels=None):
        hidden_states = self.model(input_ids)
        if self.config.tie_word_embeddings:
            logits = F.linear(hidden_states, self.model.embed_tokens.weight)
        else:
            logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            loss = F.cross_entropy(
                shift_logits.view(-1, self.config.vocab_size),
                shift_labels.view(-1),
                ignore_index=-100
            )
        return logits, loss

# =============================
# Resume Training
# =============================
@dataclass
class TrainingConfig:
    input_file: str = "input.txt"
    block_size: int = 32
    batch_size: int = 4
    gradient_accumulation_steps: int = 16
    learning_rate: float = 6e-4
    min_lr: float = 6e-5
    warmup_steps: int = 100
    weight_decay: float = 0.1
    beta1: float = 0.9
    beta2: float = 0.95
    grad_clip: float = 1.0
    checkpoint_dir: str = "checkpoints"
    save_every: int = 500
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp: bool = True
    resume_checkpoint: str = "checkpoint_step_5000.pt"
    max_steps: int = 50  # Continue for 50 steps

# =============================
# Training function
# =============================
def train_resume():
    device, use_amp, amp_dtype, device_name = get_device_config()
    cfg = SmolLM2Config.from_json("config.json")
    tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")
    dataset = TextDataset(TrainingConfig().input_file, tokenizer, block_size=32, batch_size=4)
    dataloader = DataLoader(dataset, batch_size=1)
    all_batches = [b for b in dataloader]

    model = SmolLM2ForCausalLM(cfg).to(device)
    try:
        optimizer = torch.optim.AdamW(model.parameters(), lr=TrainingConfig().learning_rate,
                                      betas=(TrainingConfig().beta1, TrainingConfig().beta2),
                                      weight_decay=TrainingConfig().weight_decay, fused=True)
    except:
        optimizer = torch.optim.AdamW(model.parameters(), lr=TrainingConfig().learning_rate,
                                      betas=(TrainingConfig().beta1, TrainingConfig().beta2),
                                      weight_decay=TrainingConfig().weight_decay)

    # Allow loading config object from checkpoint
    add_safe_globals([SmolLM2Config])
    ckpt = torch.load(TrainingConfig().resume_checkpoint, map_location=device, weights_only=False)
    model.load_state_dict(ckpt["model_state_dict"])
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    start_step = ckpt["global_step"]

    print(f"Resuming from step {start_step}")

    global_step = start_step
    micro_batch_count = 0
    optimizer.zero_grad()
    pbar = tqdm(total=start_step + TrainingConfig().max_steps, desc="Training", dynamic_ncols=True, initial=start_step)

    while global_step < start_step + TrainingConfig().max_steps:
        x, y = all_batches[micro_batch_count % len(all_batches)]
        x, y = x.squeeze(0).to(device), y.squeeze(0).to(device)

        if use_amp:
            with torch.autocast(device.type, dtype=amp_dtype):
                logits, loss = model(x, labels=y)
                loss = loss / TrainingConfig().gradient_accumulation_steps
        else:
            logits, loss = model(x, labels=y)
            loss = loss / TrainingConfig().gradient_accumulation_steps

        loss.backward()
        micro_batch_count += 1

        if micro_batch_count % TrainingConfig().gradient_accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), TrainingConfig().grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1
            pbar.update(1)
            pbar.set_postfix({'loss': f'{loss.item() * TrainingConfig().gradient_accumulation_steps:.4f}'})

            # Save checkpoint if needed
            if global_step % TrainingConfig().save_every == 0:
                os.makedirs(TrainingConfig().checkpoint_dir, exist_ok=True)
                checkpoint_path = os.path.join(TrainingConfig().checkpoint_dir, f"checkpoint_step_{global_step}.pt")
                torch.save({
                    'global_step': global_step,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'config': cfg
                }, checkpoint_path)
                print(f"\n✓ Checkpoint saved: {checkpoint_path}")

    pbar.close()
    print("✅ Resume training complete!")

# =============================
# Run training
# =============================
if __name__ == "__main__":
    train_resume()


Token indices sequence length is longer than the specified maximum sequence length for this model (341094 > 8192). Running this sequence through the model will result in indexing errors


Resuming from step 5000


Training: 100%|██████████| 5050/5050 [01:58<00:00,  2.38s/it, loss=1.0300]

✅ Resume training complete!



