## SmolLM2-135M
A reverse-engineered implementation of Hugging Face's SmolLM2 135M model in PyTorch.

## Imports

In [None]:
# Import statements
import os
import math
import time
import torch
import torch.nn as nn
from torch.nn import functional as F
from typing import Optional, Tuple
from dataclasses import dataclass
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.amp import autocast, GradScaler
from transformers import AutoTokenizer
import time

In [None]:
# Mount drive: Google Drive (Colab-only API) is mounted at /content/drive so the
# Drive-backed working directory set in the next cell is reachable.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# Set working directory: later cells open 'input.txt' and read/write files under
# 'model/' using paths relative to this folder.
work_dir = "/content/drive/MyDrive/ERAV4/Session_13"
os.chdir(work_dir)
os.getcwd()  # last expression of the cell, so Jupyter echoes the new working directory

'/content/drive/MyDrive/ERAV4/Session_13'

In [None]:
import os

# Sanity check: the working directory should contain input.txt and the model/ folder
print("CWD:", os.getcwd())
print("Files here:", os.listdir(os.curdir))

CWD: /content/drive/MyDrive/ERAV4/Session_13
Files here: ['input.txt', 'model']


# Model

## Details

SmolLM2-135M Model Architecture Implementation
<br>Based on Llama2 architecture with Grouped Query Attention (GQA)

Architecture Details (as configured in `SmolLM2Config` below):
- Parameters: 135M
- Layers: 30
- Hidden Size: 576
- Intermediate Size: 1536
- Attention Heads: 9
- Key-Value Heads: 3 (GQA)
- Vocabulary Size: 49152
- Max Sequence Length: 2048
- RoPE Theta: 10000.0
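- Training Tokens: ~10M processed in this notebook (5,000 steps × 2,048 tokens, i.e. roughly 30 passes over a 341K-token corpus); the original model was trained on ~2T tokens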

## Layer Normalization

In [None]:
class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization.

    Divides each vector along the last dimension by its root mean square
    (no mean-centering, no bias), then multiplies by a learned per-channel gain.

    Args:
        hidden_size: size of the last dimension (length of the gain vector).
        eps: added to the mean square before the reciprocal square root.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.hidden_size = hidden_size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        input_dtype = x.dtype
        # Accumulate the mean square in at least float32: squaring fp16/bf16
        # activations overflows (fp16 max is ~65504) or loses precision, which
        # would zero out or distort the normalised output. float64 stays float64.
        compute_dtype = torch.promote_types(input_dtype, torch.float32)
        x_c = x.to(compute_dtype)
        variance = x_c.pow(2).mean(-1, keepdim=True)
        x_c = x_c * torch.rsqrt(variance + self.eps)
        # Cast back before applying the gain so the result dtype follows the
        # usual promotion of (weight dtype, input dtype), as before.
        return self.weight * x_c.to(input_dtype)

    def extra_repr(self) -> str:
        return f'({self.hidden_size},), eps={self.eps}'

## Positional Encoding

In [None]:
class RotaryEmbedding(nn.Module):
    """Rotary Position Embedding (RoPE) cos/sin table provider.

    Precomputes cos/sin tables of shape [seq_len, dim] for positions
    0..max_position_embeddings-1 and grows the cache on demand when a longer
    sequence is requested.

    Args:
        dim: per-head dimension the rotation is applied to.
        max_position_embeddings: number of positions cached up front.
        base: RoPE theta; inverse frequencies are base ** (-2i / dim).
    """

    def __init__(self, dim: int, max_position_embeddings: int = 2048, base: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base

        # Compute inverse frequencies
        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float() / self.dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)

        # Build position indices
        self._set_cos_sin_cache(
            seq_len=max_position_embeddings,
            device=self.inv_freq.device,
            dtype=torch.get_default_dtype()
        )

    def _set_cos_sin_cache(self, seq_len: int, device: torch.device, dtype: torch.dtype):
        """(Re)build the non-persistent cos/sin buffers for positions [0, seq_len)."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        freqs = torch.outer(t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos().to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin().to(dtype), persistent=False)

    def forward(self, x: torch.Tensor, seq_len: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(cos, sin)`` tables of shape [seq_len, dim] in ``x.dtype``.

        Args:
            x: reference tensor for dtype/device. When ``seq_len`` is omitted it is
               read from ``x.shape[-2]``, which is the sequence axis for both
               [batch, seq_len, hidden] and [batch, num_heads, seq_len, head_dim].
            seq_len: number of positions to return.
        """
        if seq_len is None:
            # The declared default used to crash on `None > int`; infer it instead.
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)

        return (
            self.cos_cached[:seq_len].to(dtype=x.dtype),
            self.sin_cached[:seq_len].to(dtype=x.dtype),
        )


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Swap the two halves of the last dim and negate the (new) first half.

    For last-dim layout [a, b] (split at size // 2) this returns [-b, a].
    """
    split_at = x.shape[-1] // 2
    first, second = x[..., :split_at], x[..., split_at:]
    return torch.cat((second.neg(), first), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate query and key tensors by the RoPE angles in ``cos``/``sin``.

    Expected layout: q, k are [batch, heads, seq_len, head_dim] and cos, sin are
    [seq_len, head_dim]. Two leading axes are added to the tables so they
    broadcast over batch and heads.
    """
    cos_b = cos[None, None]
    sin_b = sin[None, None]

    q_rot = q * cos_b + rotate_half(q) * sin_b
    k_rot = k * cos_b + rotate_half(k) * sin_b
    return q_rot, k_rot

## Attention (Grouped Query Attention)

In [None]:
class SiLUActivation(nn.Module):
    """Element-wise SiLU (a.k.a. Swish) non-linearity as a parameter-free module."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        activated = F.silu(x, inplace=False)
        return activated

In [None]:
class GroupedQueryAttention(nn.Module):
    """Causal multi-head self-attention with Grouped Query Attention (GQA).

    ``num_attention_heads`` query heads share ``num_key_value_heads`` key/value
    heads; each K/V head is repeated ``num_attention_heads // num_key_value_heads``
    times before attention is computed.

    Raises:
        ValueError: if ``hidden_size`` is not divisible by ``num_attention_heads``,
            or ``num_attention_heads`` is not a positive multiple of
            ``num_key_value_heads``.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # Without this check a bad config only fails later inside SDPA with an
        # opaque head-count mismatch (or with ZeroDivisionError below).
        if self.num_key_value_heads <= 0 or self.num_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_heads must be a multiple of num_key_value_heads (got `num_heads`: {self.num_heads}"
                f" and `num_key_value_heads`: {self.num_key_value_heads})."
            )
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend causally over ``hidden_states`` ([batch, seq_len, hidden_size]).

        ``attention_mask`` and ``position_ids`` are accepted for interface
        compatibility but are NOT used: masking is purely causal (``is_causal=True``),
        so padding tokens are attended to. ``cos``/``sin`` ([seq_len, head_dim]) are
        the RoPE tables; when either is None, no rotary embedding is applied.
        """
        batch_size, seq_length, _ = hidden_states.size()

        # Project queries, keys, values
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape to [batch_size, num_heads, seq_len, head_dim]
        query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Apply rotary embeddings (passed from model level)
        if cos is not None and sin is not None:
            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Repeat k/v heads if num_key_value_heads < num_heads (GQA)
        key_states = self._repeat_kv(key_states, self.num_key_value_groups)
        value_states = self._repeat_kv(value_states, self.num_key_value_groups)

        # Flash attention (Speedup 4)
        # Memory efficient and faster than manual attention computation
        # is_causal=True handles causal masking automatically, no explicit mask needed
        attn_output = F.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=None,
            is_causal=True
        )

        # Reshape back to [batch_size, seq_len, hidden_size]
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, self.hidden_size)

        # Output projection
        attn_output = self.o_proj(attn_output)

        return attn_output

    def _repeat_kv(self, hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
        """
        Repeat key/value tensors n_rep times for grouped query attention.

        [batch, kv_heads, seq, head_dim] -> [batch, kv_heads * n_rep, seq, head_dim],
        with each KV head's copies adjacent (matching query-head grouping).
        """
        batch, num_key_value_heads, slen, head_dim = hidden_states.shape
        if n_rep == 1:
            return hidden_states
        hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
        return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

## MLP Block

In [None]:
class MLP(nn.Module):
    """SwiGLU feed-forward block: down_proj(silu(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        d_model, d_ff = self.hidden_size, self.intermediate_size
        # Two parallel expansions (gate and value) and one contraction back to d_model
        self.gate_proj = nn.Linear(d_model, d_ff, bias=False)
        self.up_proj = nn.Linear(d_model, d_ff, bias=False)
        self.down_proj = nn.Linear(d_ff, d_model, bias=False)
        self.act_fn = SiLUActivation()  # the "Swish" in SwiGLU

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.act_fn(self.gate_proj(x))
        return self.down_proj(gate * self.up_proj(x))

## Transformer Block (Decoder)

In [None]:
class DecoderLayer(nn.Module):
    """Pre-norm decoder block: x + Attn(Norm(x)), followed by x + MLP(Norm(x))."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size

        # Sub-module creation order is kept as-is (parameter init consumes the RNG)
        self.self_attn = GroupedQueryAttention(config)
        self.mlp = MLP(config)

        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        cos: Optional[torch.Tensor] = None,
        sin: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Attention sub-layer with residual connection
        attn_update = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            position_ids=position_ids,
            cos=cos,
            sin=sin,
        )
        hidden_states = hidden_states + attn_update

        # Feed-forward sub-layer with residual connection
        mlp_update = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_update

## Model Config

In [None]:
class SmolLM2Config:
    """Configuration class for SmolLM2-135M.

    A plain attribute container consumed by the model classes. Keyword
    arguments not named in the signature (e.g. extra fields from an HF
    config.json) are stored as attributes instead of being silently discarded.
    """

    def __init__(
        self,
        vocab_size: int = 49152,  # Speedup 5: 49152 = 3 * 2**14, a multiple of 128 (not a power of 2) -> tensor-core friendly
        hidden_size: int = 576,
        intermediate_size: int = 1536,
        num_hidden_layers: int = 30,
        num_attention_heads: int = 9,
        num_key_value_heads: int = 3,
        max_position_embeddings: int = 2048,
        rms_norm_eps: float = 1e-5,
        rope_theta: float = 10000.0,
        pad_token_id: int = 0,
        bos_token_id: int = 1,
        eos_token_id: int = 2,
        tie_word_embeddings: bool = True,
        **kwargs
    ):
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads
        self.max_position_embeddings = max_position_embeddings
        self.rms_norm_eps = rms_norm_eps
        self.rope_theta = rope_theta
        self.pad_token_id = pad_token_id
        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.tie_word_embeddings = tie_word_embeddings

        # Keep any extra fields visible rather than dropping them (a typo in a
        # known field name shows up as an unexpected attribute instead of vanishing).
        for name, value in kwargs.items():
            setattr(self, name, value)

## Model Architecture

In [None]:
class SmolLM2Model(nn.Module):
    """SmolLM2-135M backbone: token embedding -> decoder layers -> final RMSNorm.

    ``forward`` returns the final normalised hidden states; the vocabulary
    projection lives in ``SmolLM2ForCausalLM``.
    """

    def __init__(self, config: SmolLM2Config):
        super().__init__()
        self.config = config
        self.vocab_size = config.vocab_size

        # Construction order is kept fixed: default initialisers consume the global
        # RNG and self.apply() below visits sub-modules in registration order.
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList(
            DecoderLayer(config) for _ in range(config.num_hidden_layers)
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        head_dim = config.hidden_size // config.num_attention_heads
        self.rotary_emb = RotaryEmbedding(
            head_dim,
            max_position_embeddings=config.max_position_embeddings,
            base=config.rope_theta,
        )

        # Re-initialise every Linear/Embedding weight with N(0, 0.02)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """N(0, 0.02) weights for Linear/Embedding; zero Linear biases."""
        if isinstance(module, (nn.Linear, nn.Embedding)):
            module.weight.data.normal_(mean=0.0, std=0.02)
        if isinstance(module, nn.Linear) and module.bias is not None:
            module.bias.data.zero_()

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
    ) -> torch.Tensor:
        _, seq_len = input_ids.shape

        hidden = self.embed_tokens(input_ids)
        cos, sin = self.rotary_emb(hidden, seq_len=seq_len)

        # attention_mask is deliberately not forwarded: the attention layers mask
        # causally via is_causal=True, so no explicit mask is built here.
        for layer in self.layers:
            hidden = layer(
                hidden,
                attention_mask=None,
                position_ids=position_ids,
                cos=cos,
                sin=sin,
            )

        return self.norm(hidden)

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, dtype, device):
        """Build an additive [batch, 1, seq, seq] causal mask (0 = keep, finfo.min = block).

        Currently unused by ``forward``. If ``attention_mask`` ([batch, seq], 1 = keep)
        is given, padded key positions are additionally blocked.
        """
        bsz, tgt_len = input_shape
        neg_inf = torch.finfo(dtype).min

        positions = torch.arange(tgt_len, device=device)
        causal = torch.full((tgt_len, tgt_len), neg_inf, device=device)
        # Row i may attend to columns j <= i
        causal.masked_fill_(positions[None, :] <= positions[:, None], 0)
        causal = causal.to(dtype)[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len)

        if attention_mask is None:
            return causal

        padding = (1.0 - attention_mask[:, None, None, :].to(dtype)) * neg_inf
        return causal + padding

In [None]:
class SmolLM2ForCausalLM(nn.Module):
    """SmolLM2 backbone plus a linear language-modeling head.

    When ``config.tie_word_embeddings`` is true the head shares its weight
    tensor with the input embedding.
    """

    def __init__(self, config: "SmolLM2Config"):
        super().__init__()
        self.config = config
        self.model = SmolLM2Model(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Tie weights between embedding and lm_head if specified
        if config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.LongTensor] = None,
    ) -> Tuple[Optional[torch.Tensor], torch.Tensor]:
        """Return ``(loss, logits)``.

        Args:
            input_ids: token ids, [batch, seq_len].
            attention_mask: forwarded to the backbone, which currently does not apply it.
            labels: optional targets ALIGNED with ``input_ids`` (typically
                ``labels=input_ids``). They are shifted here so position t is scored
                against ``labels[t + 1]``; do not pass already-shifted targets.

        Returns:
            ``loss``: mean cross-entropy over the shifted positions, or None when
            ``labels`` is omitted. ``logits``: [batch, seq_len, vocab_size].
        """
        # Forward pass through model
        hidden_states = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
        )

        # Compute logits
        logits = self.lm_head(hidden_states)

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            # Flatten the tokens
            loss_fct = nn.CrossEntropyLoss()
            shift_logits = shift_logits.view(-1, self.config.vocab_size)
            shift_labels = shift_labels.view(-1)

            # Enable model parallelism
            shift_labels = shift_labels.to(shift_logits.device)
            loss = loss_fct(shift_logits, shift_labels)

        return loss, logits

    def generate(
        self,
        input_ids: torch.LongTensor,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
        top_k: Optional[int] = None,
        top_p: Optional[float] = None,
    ) -> torch.LongTensor:
        """Autoregressively extend ``input_ids`` ([batch, seq_len]) by up to ``max_new_tokens``.

        Each step re-runs the full sequence (no KV cache). With ``temperature > 0``
        the next token is sampled from the (optionally top-k / top-p filtered)
        softmax; with ``temperature <= 0`` decoding is greedy (argmax).
        Generation stops early once every sequence in the batch emitted
        ``config.eos_token_id`` on the same step. The module's train/eval mode is
        restored on return.

        Returns:
            ``input_ids`` with the generated tokens appended along the last dim.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                for _ in range(max_new_tokens):
                    _, logits = self.forward(input_ids)
                    next_token_logits = logits[:, -1, :]

                    if temperature <= 0:
                        # Greedy decoding; dividing by 0 would yield inf/nan logits
                        next_token = next_token_logits.argmax(dim=-1, keepdim=True)
                    else:
                        next_token_logits = next_token_logits / temperature

                        # Apply top-k filtering (k clamped so topk never exceeds the vocab)
                        if top_k is not None:
                            k = min(top_k, next_token_logits.size(-1))
                            kth_best = torch.topk(next_token_logits, k)[0][..., -1, None]
                            next_token_logits[next_token_logits < kth_best] = float('-inf')

                        # Apply top-p (nucleus) filtering
                        if top_p is not None:
                            sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True)
                            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

                            # Remove tokens with cumulative probability above the threshold,
                            # always keeping at least the most likely token
                            sorted_indices_to_remove = cumulative_probs > top_p
                            sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                            sorted_indices_to_remove[..., 0] = 0

                            indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
                            next_token_logits[indices_to_remove] = float('-inf')

                        # Sample next token
                        probs = F.softmax(next_token_logits, dim=-1)
                        next_token = torch.multinomial(probs, num_samples=1)

                    # Append to sequence
                    input_ids = torch.cat([input_ids, next_token], dim=-1)

                    # Stop when every sequence produced EOS (.item() only works for batch 1)
                    if bool((next_token == self.config.eos_token_id).all()):
                        break
        finally:
            self.train(was_training)

        return input_ids

In [None]:
def count_parameters(model):
    """Return ``(total, trainable)`` element counts over ``model.parameters()``.

    Shared (tied) parameters are counted once, as ``parameters()`` yields them once.
    """
    total = 0
    trainable = 0
    for param in model.parameters():
        size = param.numel()
        total += size
        if param.requires_grad:
            trainable += size
    return total, trainable

## Test Model

In [None]:
# Smoke test: build the model, report its size, then run one forward pass with labels
config = SmolLM2Config()
model = SmolLM2ForCausalLM(config)

total_params, trainable_params = count_parameters(model)
print("SmolLM2-135M Model initialized")
print(f"Total parameters: {total_params:,} ({total_params / 1e6:.1f}M)")
print(f"Trainable parameters: {trainable_params:,}")

# Random token ids are enough to check output shapes and that the loss path runs
batch_size, seq_length = 2, 32
input_ids = torch.randint(0, config.vocab_size, (batch_size, seq_length))

print(f"\nTesting forward pass with input shape: {input_ids.shape}")
loss, logits = model(input_ids, labels=input_ids)
print(f"Output logits shape: {logits.shape}")
print(f"Loss: {loss.item():.4f}")
print("\nModel architecture verified successfully!")

SmolLM2-135M Model initialized
Total parameters: 134,515,008 (134.5M)
Trainable parameters: 134,515,008

Testing forward pass with input shape: torch.Size([2, 32])
Output logits shape: torch.Size([2, 32, 49152])
Loss: 10.8403

Model architecture verified successfully!


## Verify Model Architecture

### Model details comparison

In [None]:
# # Model details
# print(model)

In [None]:
# # Model from HF
# from transformers import AutoModelForCausalLM
# model_name = "HuggingFaceTB/SmolLM2-135M"
# hf_model = AutoModelForCausalLM.from_pretrained(model_name)
# print(hf_model)  # Model details

### Model weights loading check

In [None]:
# # Load HF model weights
# hf_weights = hf_model.state_dict()
# model.load_state_dict(hf_weights)

All keys matched successfully

# Train

## Device

In [None]:
# DEVICE
# Preference order: CUDA, then Apple MPS, otherwise CPU
if torch.cuda.is_available():
    device = 'cuda'
elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'
print(f"Using device: {device}")

# SEED (global CPU RNG, plus the CUDA RNG when a GPU is present)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed(42)

Using device: cuda


## Dataloader

In [None]:
# DATALOADER
class DataLoaderLite:
    """
    A data loader that loads tokens from a text file and provides batches for training.

    The whole file is tokenized once and kept in memory. Each ``next_batch`` call
    returns ``(x, y)`` of shape [B, T] where ``y`` is ``x`` shifted by one token
    (next-token targets: y[b, t] is the token following x[b, t]). Batches are read
    sequentially; when the next batch would run past the end of the data the
    position wraps to 0, so a tail shorter than one batch is skipped.

    Args:
        B: sequences per batch.
        T: tokens per sequence.
        path: text file to read (default ``'input.txt'`` in the working directory).
        tokenizer: object with an ``encode(str) -> list[int]`` method. Defaults to the
            ``HuggingFaceTB/SmolLM2-135M`` tokenizer, imported lazily so callers that
            pass their own tokenizer do not need ``transformers``.

    Raises:
        ValueError: if the file yields fewer than ``B * T + 1`` tokens (not even one
            batch of inputs plus shifted targets can be formed).
    """

    def __init__(self, B, T, path='input.txt', tokenizer=None):
        self.B = B
        self.T = T

        # at init load tokens from disk and store them in memory
        with open(path, 'r', encoding='utf-8') as f:
            text = f.read()
        if tokenizer is None:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")  # SmolLM2-135M model
        self.tokens = torch.tensor(tokenizer.encode(text))

        if len(self.tokens) < B * T + 1:
            raise ValueError(
                f'{path} yields {len(self.tokens)} tokens; need at least B*T+1 = {B * T + 1}'
            )

        print(f'Loaded {len(self.tokens)} tokens')
        print(f'Batch size = {B * T} tokens')
        # Number of full batches in one pass over the data (i.e. one epoch)
        print(f'1 epoch = {len(self.tokens) // (B * T)} batches\n')

        # state
        self.current_position = 0

    def next_batch(self):
        """Return the next ``(inputs, targets)`` pair, each of shape [B, T]."""
        B, T = self.B, self.T
        # One extra token so the targets can be the inputs shifted by one
        buf = self.tokens[self.current_position: self.current_position + B * T + 1]
        x = (buf[:-1]).view(B, T) # inputs
        y = (buf[1:]).view(B, T) # targets
        # advance the position in the tensor
        self.current_position += B*T
        # if loading the next batch would be out of bounds, reset
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y

## Training Loop

In [None]:
# FOR SAMPLE RESPONSE
# Tokenizer used to encode the sample prompt and decode generated samples
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M")  # SmolLM2-135M model

# Fixed prompt, encoded once and reused by every periodic sample in the training loop.
# NOTE: the name `input` shadows the builtin; kept because the training loop reads it.
prompt = "To be or not to be"
prompt_ids = tokenizer.encode(prompt, return_tensors="pt")
input = prompt_ids.to(device)

In [None]:

# MODEL / DATALOADER
# Speedup 1 (Matrix multiplication): allow faster reduced-precision float32 matmuls
torch.set_float32_matmul_precision('high')

# nn.Module.to() returns the module itself, so construction and placement chain
model = SmolLM2ForCausalLM(SmolLM2Config()).to(device)
model = torch.compile(model)  # Speedup 3 (Linux-based compile)

# 2 sequences x 1024 tokens per batch
train_loader = DataLoaderLite(B=2, T=1024)

Token indices sequence length is longer than the specified maximum sequence length for this model (341094 > 8192). Running this sequence through the model will result in indexing errors


Loaded 341094 tokens
Batch size = 2048 tokens
1 Step = 166 batches



In [None]:
# OPTIMIZER / SCHEDULER

# Training configuration
total_steps = 5000
warmup_steps = 500  # 10% warmup
save_interval = 500  # Model checkpoint

# Learning-rate schedule parameters: single source of truth for both the
# optimizer's initial lr and get_lr(), so changing the peak cannot desync them.
max_lr = 5e-4  # Peak learning rate, reached at the end of warmup
min_lr = 5e-5  # 10% of peak lr
max_steps = total_steps

# Optimizer configuration
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=max_lr,  # overwritten every step from get_lr()
    betas=(0.9, 0.95),  # Beta values commonly used for transformer pre-training
    eps=1e-8,
    weight_decay=0.1  # AdamW's decoupled weight decay (not an L2 term in the loss)
)


def get_lr(step):
    """Learning rate for ``step``.

    Linear warmup from 0 to ``max_lr`` over ``warmup_steps``, cosine decay from
    ``max_lr`` to ``min_lr`` until ``max_steps``, then constant ``min_lr``.
    """
    # Warmup phase
    if step < warmup_steps:
        return max_lr * step / warmup_steps

    # Past the schedule: hold at the floor
    if step > max_steps:
        return min_lr

    # Cosine annealing phase
    decay_ratio = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return min_lr + coeff * (max_lr - min_lr)

In [None]:
# TRAINING LOOP
model.train()
for step in range(1, total_steps+1):
    t0 = time.time()

    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    with torch.autocast(device_type=device, dtype=torch.bfloat16):  # Speedup 2 (Data type)
        # DataLoaderLite already returns next-token targets (y[t] is the token after
        # x[t]). Passing them as `labels=` would shift a second time inside forward()
        # and train the model to predict two tokens ahead, so score logits against y directly.
        _, logits = model(x)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
    loss.backward()

    # Update learning rate according to schedule
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Gradient clipping for stability
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    # Wait for queued GPU work so dt measures the real step time (no-op need on CPU/MPS)
    if device == 'cuda':
        torch.cuda.synchronize()
    t1 = time.time()
    dt = (t1 - t0) * 1000
    tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
    print(f'[TRAIN] Step {step} | Loss: {loss.item():.3f} | LR: {lr:.6f} | dt: {dt:.2f}ms | tok/sec: {tokens_per_sec: .2f}')

    if step % save_interval == 0:
        # Save model checkpoint periodically
        torch.save({
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'step': step,
            'loss': loss.item(),
        }, 'model/checkpoint.pth')
        print(f'[SAVE] Checkpoint saved at step {step}')

        # Generate sample response
        output = model.generate(input, max_new_tokens=30)
        model.train()  # generate() may leave the model in eval mode
        print('Sample response:', tokenizer.decode(output[0]))



[1;30;43mStreaming output truncated to the last 5000 lines.[0m
[TRAIN] Step 59 | Loss: 8.426 | LR: 0.000059 | dt: 1054.16ms | tok/sec:  1942.77
[TRAIN] Step 60 | Loss: 8.485 | LR: 0.000060 | dt: 1054.52ms | tok/sec:  1942.12
[TRAIN] Step 61 | Loss: 8.420 | LR: 0.000061 | dt: 1066.30ms | tok/sec:  1920.67
[TRAIN] Step 62 | Loss: 8.330 | LR: 0.000062 | dt: 1053.65ms | tok/sec:  1943.72
[TRAIN] Step 63 | Loss: 8.292 | LR: 0.000063 | dt: 1053.47ms | tok/sec:  1944.04
[TRAIN] Step 64 | Loss: 8.354 | LR: 0.000064 | dt: 1053.82ms | tok/sec:  1943.41
[TRAIN] Step 65 | Loss: 8.240 | LR: 0.000065 | dt: 1057.86ms | tok/sec:  1935.99
[TRAIN] Step 66 | Loss: 8.067 | LR: 0.000066 | dt: 1045.13ms | tok/sec:  1959.56
[TRAIN] Step 67 | Loss: 8.090 | LR: 0.000067 | dt: 1050.61ms | tok/sec:  1949.34
[TRAIN] Step 68 | Loss: 8.108 | LR: 0.000068 | dt: 1047.54ms | tok/sec:  1955.06
[TRAIN] Step 69 | Loss: 8.141 | LR: 0.000069 | dt: 1046.57ms | tok/sec:  1956.86
[TRAIN] Step 70 | Loss: 7.912 | LR: 0.000070

## Resume Training

<u>Confirm last checkpoint save and proceed</u>

In [None]:
# Load training checkpoint (map_location keeps tensors on the active device)
checkpoint = torch.load('model/checkpoint.pth', map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
step = checkpoint['step']
loss = checkpoint['loss']
print(f'[LOAD] Checkpoint loaded at step {step}')
# NOTE: the data-loader position is not part of the checkpoint; train_loader
# continues from wherever it currently is.

# Resume training
print(f'Training resumed from step {step+1}')
total_steps = 5000
extra_steps = 50  # steps to run past the scheduled total; get_lr() holds min_lr there
model.train()
for step in range(step + 1, total_steps + extra_steps + 1):
    t0 = time.time()

    x, y = train_loader.next_batch()
    x, y = x.to(device), y.to(device)
    optimizer.zero_grad()
    with torch.autocast(device_type=device, dtype=torch.bfloat16):  # Speedup 2 (Data type)
        # y is already shifted by the loader; passing labels=y would shift it again
        # inside forward(), so compute next-token cross-entropy against y directly.
        _, logits = model(x)
        loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)), y.reshape(-1))
    loss.backward()

    # Update learning rate according to schedule
    lr = get_lr(step)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    # Gradient clipping for stability
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optimizer.step()

    # Wait for queued GPU work so dt measures the real step time
    if device == 'cuda':
        torch.cuda.synchronize()
    t1 = time.time()
    dt = (t1 - t0) * 1000
    tokens_per_sec = (train_loader.B * train_loader.T) / (t1 - t0)
    print(f'[TRAIN] Step {step} | Loss: {loss.item():.3f} | LR: {lr:.6f} | dt: {dt:.2f}ms | tok/sec: {tokens_per_sec: .2f}')

print('Training completed!')

[LOAD] Checkpoint loaded at step 5000
Training resumed from step 5001
[TRAIN] Step 5001 | Loss: 0.034 | LR: 0.000050 | dt: 1122.43ms | tok/sec:  1824.61
[TRAIN] Step 5002 | Loss: 0.038 | LR: 0.000050 | dt: 1058.27ms | tok/sec:  1935.24
[TRAIN] Step 5003 | Loss: 0.044 | LR: 0.000050 | dt: 1061.47ms | tok/sec:  1929.39
[TRAIN] Step 5004 | Loss: 0.048 | LR: 0.000050 | dt: 1068.50ms | tok/sec:  1916.70
[TRAIN] Step 5005 | Loss: 0.044 | LR: 0.000050 | dt: 1070.05ms | tok/sec:  1913.93
[TRAIN] Step 5006 | Loss: 0.032 | LR: 0.000050 | dt: 1071.22ms | tok/sec:  1911.84
[TRAIN] Step 5007 | Loss: 0.042 | LR: 0.000050 | dt: 1074.76ms | tok/sec:  1905.54
[TRAIN] Step 5008 | Loss: 0.037 | LR: 0.000050 | dt: 1078.00ms | tok/sec:  1899.81
[TRAIN] Step 5009 | Loss: 0.042 | LR: 0.000050 | dt: 1076.89ms | tok/sec:  1901.77
[TRAIN] Step 5010 | Loss: 0.044 | LR: 0.000050 | dt: 1082.14ms | tok/sec:  1892.55
[TRAIN] Step 5011 | Loss: 0.038 | LR: 0.000050 | dt: 1085.98ms | tok/sec:  1885.85
[TRAIN] Step 5012

## Model Weights

In [None]:
# Save final model weights. Assuming `model` is still the torch.compile wrapper
# from the training cells, the saved keys carry an "_orig_mod." prefix, which the
# following cell strips.
torch.save(model.state_dict(), 'model/model.pth')

<u>Confirm last weights save and proceed</u>

In [None]:
# Strip prefix from keys (torch.compile wraps the model, prefixing every key with "_orig_mod.")
state_dict = torch.load('model/model.pth', map_location=device)
compile_prefix = "_orig_mod."
state_dict = {
    (k[len(compile_prefix):] if k.startswith(compile_prefix) else k): v
    for k, v in state_dict.items()
}

# Convert to half precision to shrink file size.
# With tie_word_embeddings, lm_head.weight and model.embed_tokens.weight are the same
# tensor; torch.save/torch.load preserve that sharing. Calling .half() per key would
# create two independent fp16 copies (saving the embedding matrix twice), so each
# distinct tensor is converted once and reused. Non-float tensors are left unchanged.
half_cache = {}
for k, v in state_dict.items():
    if not v.is_floating_point():
        continue
    tensor_id = (v.data_ptr(), tuple(v.shape), v.stride(), v.dtype)
    if tensor_id not in half_cache:
        half_cache[tensor_id] = v.half()
    state_dict[k] = half_cache[tensor_id]
torch.save(state_dict, 'model/model_final.pth')

<u>Confirm last weights save and proceed</u>

In [None]:
# Verify model weights: load the fp16 export into a fresh, uncompiled model.
# load_state_dict copies values into the model's existing parameters.
model = SmolLM2ForCausalLM(SmolLM2Config()).to(device)
weights = torch.load('model/model_final.pth', map_location=device)
model.load_state_dict(weights)

<All keys matched successfully>