In [1]:
"""
LLM Pretraining Script
======================
Modern implementation for pretraining language models from scratch using the latest techniques:
- Fused (Flash-style) attention via PyTorch's scaled_dot_product_attention (PyTorch 2.0+)
- Mixed precision training (bfloat16)
- Gradient checkpointing for memory efficiency
- AdamW optimizer with cosine learning rate schedule
- Gradient accumulation for large effective batch sizes
- Distributed Data Parallel (DDP) support

Dataset: BEE-spoke-data/TxT360-1M-sample (1M high-quality text samples)
"""

import os
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.distributed as dist
from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tokenizers.processors import TemplateProcessing
from transformers import PreTrainedTokenizerFast
from transformers.optimization import get_cosine_schedule_with_warmup
from dataclasses import dataclass
from typing import Optional, Dict, Any
import wandb
from tqdm.auto import tqdm
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [2]:


@dataclass
class ModelConfig:
    """Configuration for the LLM architecture.

    The head geometry is validated on construction so that an inconsistent
    configuration fails immediately with a clear message instead of surfacing
    later as a tensor shape error inside the attention layer.

    Raises:
        ValueError: If ``hidden_size`` is not divisible by
            ``num_attention_heads``, if ``num_attention_heads`` is not
            divisible by ``num_key_value_heads``, if the resulting head
            dimension is odd (rotary embeddings split each head in two equal
            halves), or if any of these sizes is not positive.
    """

    vocab_size: int = 50304  # Padded to a multiple of 64 for efficiency
    max_seq_length: int = 2048  # Length of the precomputed rotary cos/sin cache
    hidden_size: int = 2048
    num_hidden_layers: int = 24
    num_attention_heads: int = 16
    num_key_value_heads: int = 8  # Grouped-Query Attention (GQA): 2 query heads per KV head
    intermediate_size: int = 5632  # 2.75x hidden_size (SwiGLU)
    hidden_act: str = "silu"
    rms_norm_eps: float = 1e-6
    rope_theta: float = 10000.0
    attention_dropout: float = 0.0
    hidden_dropout: float = 0.0
    use_flash_attention: bool = True  # Prefer F.scaled_dot_product_attention when available
    gradient_checkpointing: bool = True

    def __post_init__(self) -> None:
        """Validate that the attention head layout is internally consistent."""
        for name in ("hidden_size", "num_attention_heads", "num_key_value_heads"):
            value = getattr(self, name)
            if value <= 0:
                raise ValueError(f"{name} must be positive, got {value}")

        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError(
                f"hidden_size ({self.hidden_size}) must be divisible by "
                f"num_attention_heads ({self.num_attention_heads})"
            )

        if self.num_attention_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_attention_heads ({self.num_attention_heads}) must be divisible by "
                f"num_key_value_heads ({self.num_key_value_heads})"
            )

        head_dim = self.hidden_size // self.num_attention_heads
        if head_dim % 2 != 0:
            # Rotary embeddings rotate the two halves of each head against each other.
            raise ValueError(f"head dimension ({head_dim}) must be even for rotary embeddings")

In [3]:


@dataclass
class TrainingConfig:
    """Configuration for training hyperparameters.

    An instance is passed to ``train()``, which reads the data, optimisation,
    precision, checkpointing and logging fields below. Fields that ``train()``
    does not currently read are marked as such.
    """

    # Data
    dataset_name: str = "BEE-spoke-data/TxT360-1M-sample"  # HuggingFace dataset id given to load_dataset
    dataset_split: str = "train"  # Split used for the training dataset (tokenizer training always uses "train")
    max_seq_length: int = 2048  # Tokenizer pad/truncate length; also forwarded to ModelConfig

    # Training
    batch_size: int = 4  # Micro-batch size of the DataLoader in each process
    gradient_accumulation_steps: int = 32  # Effective batch size per process = 4 * 32 = 128
    num_epochs: int = 1
    max_steps: Optional[int] = None  # When set, used as the scheduler length and as the stop condition

    # Optimization (AdamW + cosine schedule with warmup)
    learning_rate: float = 3e-4
    weight_decay: float = 0.1
    adam_beta1: float = 0.9
    adam_beta2: float = 0.95
    adam_epsilon: float = 1e-8
    max_grad_norm: float = 1.0  # Threshold passed to clip_grad_norm_
    warmup_steps: int = 2000  # Counted in optimizer steps, not micro-batches

    # Mixed precision
    # train() picks bfloat16 whenever use_bf16 is True, regardless of use_fp16;
    # use_fp16 only selects float16 when use_bf16 is False.
    use_fp16: bool = False
    use_bf16: bool = True

    # Checkpointing
    save_steps: int = 5000  # Save a checkpoint every N optimizer steps
    save_total_limit: int = 3  # NOTE: not read by train() in this file; old checkpoints are never pruned
    output_dir: str = "./checkpoints/llm"

    # Logging
    logging_steps: int = 10  # Log every N optimizer steps (main process only)
    use_wandb: bool = False
    wandb_project: str = "llm-pretraining"

    # Distributed
    # NOTE: train() does not read these two fields; it takes rank information
    # from setup_distributed(), which reads the process environment.
    local_rank: int = -1
    world_size: int = 1

In [4]:


class RMSNorm(nn.Module):
    """Root Mean Square Layer Normalization (more efficient than LayerNorm).

    Each vector along the last dimension is divided by its root-mean-square and
    then scaled by a learned per-feature weight. There is no mean subtraction
    and no bias.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        """
        Initialize RMSNorm layer.

        Args:
            hidden_size: Dimension of the hidden states
            eps: Epsilon for numerical stability
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """
        Apply RMS normalization.

        For float16/bfloat16 inputs the statistics are computed in float32:
        squaring half-precision activations overflows easily (float16 tops out
        at 65504), which would turn the variance into ``inf`` and the output
        into zeros. The normalized tensor is cast back to the input dtype
        before the weight is applied, so the result dtype follows the same
        promotion rules as before.

        Args:
            hidden_states: Input tensor of shape (batch, seq_len, hidden_size)

        Returns:
            Normalized tensor of same shape
        """
        input_dtype = hidden_states.dtype
        # Only upcast reduced-precision inputs; float32/float64 are left untouched.
        if input_dtype in (torch.float16, torch.bfloat16):
            hidden_states = hidden_states.to(torch.float32)

        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)


In [5]:

class RotaryPositionalEmbedding(nn.Module):
    """
    Rotary Position Embedding (RoPE), see https://arxiv.org/abs/2104.09864.

    Holds a table of cos/sin values indexed by position. The table is built
    once for ``max_seq_length`` positions and rebuilt on demand when a longer
    sequence is requested.
    """

    def __init__(self, dim: int, max_seq_length: int = 2048, base: float = 10000.0):
        """
        Build the inverse-frequency vector and the initial cos/sin table.

        Args:
            dim: Dimension of each attention head
            max_seq_length: Number of positions to precompute
            base: Base of the geometric progression of frequencies
        """
        super().__init__()
        self.dim = dim
        self.max_seq_length = max_seq_length
        self.base = base

        # One frequency per pair of channels: base^(-2i/dim) for i = 0 .. dim/2 - 1
        exponents = torch.arange(0, dim, 2).float() / dim
        self.register_buffer("inv_freq", 1.0 / (base ** exponents), persistent=False)

        self._set_cos_sin_cache(max_seq_length)

    def _set_cos_sin_cache(self, seq_len: int):
        """(Re)build the cos/sin lookup tables for ``seq_len`` positions."""
        self.max_seq_len_cached = seq_len

        positions = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        # Outer product written as a broadcasted multiply: (seq_len, 1) * (1, dim/2)
        angles = positions[:, None] * self.inv_freq[None, :]
        # Duplicate so the table spans the full head dimension
        table = torch.cat([angles, angles], dim=-1)

        for buffer_name, values in (("cos_cached", table.cos()), ("sin_cached", table.sin())):
            self.register_buffer(buffer_name, values, persistent=False)

    def forward(self, x: torch.Tensor, seq_len: int):
        """
        Look up the rotary tables for the first ``seq_len`` positions.

        Args:
            x: Input tensor (not read; kept for interface compatibility)
            seq_len: Sequence length

        Returns:
            Tuple of (cos, sin) tensors, each of shape (seq_len, dim)
        """
        if self.max_seq_len_cached < seq_len:
            # Requested more positions than cached: grow the tables.
            self._set_cos_sin_cache(seq_len)

        cos = self.cos_cached[:seq_len]
        sin = self.sin_cached[:seq_len]
        return (cos, sin)


In [6]:

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dimension in two and return ``[-second_half, first_half]``."""
    first_half, second_half = torch.chunk(x, 2, dim=-1)
    return torch.cat([second_half.neg(), first_half], dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor):
    """
    Rotate queries and keys by their position-dependent angles.

    Each tensor ``t`` is mapped to ``t * cos + rotate_half(t) * sin``; ``cos``
    and ``sin`` are broadcast against ``q`` and ``k``.

    Args:
        q: Query tensor
        k: Key tensor
        cos: Cosine tensor
        sin: Sine tensor

    Returns:
        Tuple of rotated (q, k) tensors
    """
    rotated_q, rotated_k = (
        (states * cos) + (rotate_half(states) * sin)
        for states in (q, k)
    )
    return rotated_q, rotated_k


In [None]:

class GroupedQueryAttention(nn.Module):
    """
    Grouped-Query Attention (GQA) - hybrid between Multi-Head and Multi-Query Attention.
    More efficient than MHA while maintaining better quality than MQA.
    Paper: https://arxiv.org/abs/2305.13245

    Queries use ``num_attention_heads`` heads while keys and values use the
    smaller ``num_key_value_heads``; each key/value head is shared by
    ``num_attention_heads // num_key_value_heads`` consecutive query heads.
    Attention is causal unless the caller supplies its own mask.
    """

    def __init__(self, config: ModelConfig):
        """
        Initialize GQA layer.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.attention_dropout = config.attention_dropout
        self.use_flash_attention = config.use_flash_attention

        # Projections (keys/values are narrower than queries under GQA)
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        # RoPE
        self.rotary_emb = RotaryPositionalEmbedding(
            self.head_dim,
            max_seq_length=config.max_seq_length,
            base=config.rope_theta,
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Forward pass of GQA.

        Args:
            hidden_states: Input tensor of shape (batch, seq_len, hidden_size)
            attention_mask: Optional mask broadcastable to
                (batch, heads, seq_len, seq_len). A floating-point mask is
                added to the attention scores (``-inf`` blocks a position); a
                boolean mask marks the positions that may be attended to. When
                omitted, a causal (lower-triangular) mask is applied. When
                given, the mask is used as-is and must already encode
                causality.

        Returns:
            Output tensor of shape (batch, seq_len, hidden_size)
        """
        batch_size, seq_length, _ = hidden_states.shape

        # Project to Q, K, V
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # Reshape to (batch, heads, seq_len, head_dim)
        query_states = query_states.view(batch_size, seq_length, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(batch_size, seq_length, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        # Apply RoPE; cos/sin are (seq_len, head_dim) and broadcast over batch and heads
        cos, sin = self.rotary_emb(value_states, seq_length)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Repeat K, V so that every query head has a matching key/value head
        if self.num_key_value_groups > 1:
            key_states = key_states.repeat_interleave(self.num_key_value_groups, dim=1)
            value_states = value_states.repeat_interleave(self.num_key_value_groups, dim=1)

        if self.use_flash_attention and hasattr(F, 'scaled_dot_product_attention'):
            # PyTorch's fused attention (requires PyTorch 2.0+).
            # scaled_dot_product_attention raises if an explicit mask is combined
            # with is_causal=True, so causality is requested only when no mask
            # was supplied.
            if attention_mask is not None and attention_mask.is_floating_point():
                attention_mask = attention_mask.to(query_states.dtype)

            attn_output = F.scaled_dot_product_attention(
                query_states,
                key_states,
                value_states,
                attn_mask=attention_mask,
                dropout_p=self.attention_dropout if self.training else 0.0,
                is_causal=attention_mask is None,
            )
        else:
            # Standard attention computation
            attn_weights = torch.matmul(query_states, key_states.transpose(-2, -1)) / math.sqrt(self.head_dim)

            if attention_mask is None:
                # Match the fused path: default to causal masking.
                attention_mask = torch.triu(
                    torch.full(
                        (seq_length, seq_length),
                        float("-inf"),
                        dtype=attn_weights.dtype,
                        device=attn_weights.device,
                    ),
                    diagonal=1,
                )
            elif attention_mask.dtype == torch.bool:
                # Convert "True = may attend" into an additive mask.
                attention_mask = torch.zeros(
                    attention_mask.shape, dtype=attn_weights.dtype, device=attn_weights.device
                ).masked_fill(~attention_mask, float("-inf"))

            attn_weights = attn_weights + attention_mask

            # Softmax in float32 for stability, then back to the working dtype
            attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_weights = F.dropout(attn_weights, p=self.attention_dropout, training=self.training)
            attn_output = torch.matmul(attn_weights, value_states)

        # Merge heads back: (batch, heads, seq, head_dim) -> (batch, seq, heads * head_dim)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_length, self.num_heads * self.head_dim)
        attn_output = self.o_proj(attn_output)

        return attn_output


In [None]:

class SwiGLUMLP(nn.Module):
    """
    Gated feed-forward block with a SiLU (Swish) gate, i.e. SwiGLU.

    Computes ``down_proj(SiLU(gate_proj(x)) * up_proj(x))`` with bias-free
    linear layers.
    Paper: https://arxiv.org/abs/2002.05202
    """

    def __init__(self, config: ModelConfig):
        """
        Create the three projections and the gate activation.

        Args:
            config: Model configuration (reads ``hidden_size`` and
                ``intermediate_size``)
        """
        super().__init__()
        model_width = config.hidden_size
        inner_width = config.intermediate_size
        self.hidden_size = model_width
        self.intermediate_size = inner_width

        self.gate_proj = nn.Linear(model_width, inner_width, bias=False)
        self.up_proj = nn.Linear(model_width, inner_width, bias=False)
        self.down_proj = nn.Linear(inner_width, model_width, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the gated feed-forward transformation.

        Args:
            x: Input tensor of shape (batch, seq_len, hidden_size)

        Returns:
            Output tensor of shape (batch, seq_len, hidden_size)
        """
        gate = self.act_fn(self.gate_proj(x))
        candidate = self.up_proj(x)
        return self.down_proj(gate * candidate)


In [None]:

class TransformerBlock(nn.Module):
    """One pre-norm transformer layer: attention and MLP, each wrapped in a residual."""

    def __init__(self, config: ModelConfig):
        """
        Build the attention, MLP and the two RMSNorm layers.

        Args:
            config: Model configuration
        """
        super().__init__()

        def build_norm() -> RMSNorm:
            return RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.attention = GroupedQueryAttention(config)
        self.mlp = SwiGLUMLP(config)
        self.input_layernorm = build_norm()
        self.post_attention_layernorm = build_norm()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Run the block: ``x + attn(norm(x))`` followed by ``x + mlp(norm(x))``.

        Args:
            hidden_states: Input tensor
            attention_mask: Optional attention mask, forwarded to the attention layer

        Returns:
            Output tensor
        """
        # Attention sub-layer (normalize first, then add back onto the stream)
        attention_out = self.attention(self.input_layernorm(hidden_states), attention_mask)
        hidden_states = hidden_states + attention_out

        # Feed-forward sub-layer, same pre-norm + residual pattern
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        return hidden_states + mlp_out


In [None]:

class LLMModel(nn.Module):
    """Complete Language Model architecture.

    Token embedding -> stack of ``TransformerBlock`` -> final RMSNorm -> LM head
    whose weight is tied to the embedding matrix. Attention is always causal.
    """

    def __init__(self, config: ModelConfig):
        """
        Initialize LLM.

        Args:
            config: Model configuration
        """
        super().__init__()
        self.config = config
        # Activation checkpointing is opt-in via gradient_checkpointing_enable().
        self.gradient_checkpointing = False

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)  # input token embeddings
        self.layers = nn.ModuleList([TransformerBlock(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)  # final normalization
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)  # output logits

        # Tie embeddings: the LM head shares the embedding matrix
        self.lm_head.weight = self.embed_tokens.weight

        # Initialize weights
        self.apply(self._init_weights)

    def gradient_checkpointing_enable(self) -> None:
        """Recompute each transformer block in the backward pass while training.

        Trades extra compute for lower activation memory. Has no effect in
        evaluation mode.
        """
        self.gradient_checkpointing = True

    def gradient_checkpointing_disable(self) -> None:
        """Turn activation checkpointing back off."""
        self.gradient_checkpointing = False

    def _init_weights(self, module):
        """Initialize Linear and Embedding weights from N(0, 0.02); zero any bias."""
        std = 0.02
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the model.

        Args:
            input_ids: Token IDs of shape (batch, seq_len)
            attention_mask: Accepted for interface compatibility but not used;
                only causal masking is applied. Exclude padding from the loss
                by setting the corresponding ``labels`` to -100.
            labels: Optional labels for computing loss. They are shifted
                internally, so pass the same token ids as ``input_ids``;
                positions equal to -100 are ignored.

        Returns:
            Dictionary with ``"loss"`` (``None`` when no labels are given) and
            ``"logits"`` of shape (batch, seq_len, vocab_size)
        """
        # Embed tokens
        hidden_states = self.embed_tokens(input_ids)

        # Prepare causal attention mask.
        # The fused scaled_dot_product_attention path applies causality itself
        # (is_causal=True) and rejects an explicit mask combined with that
        # flag, so an additive mask is only materialized for the manual
        # attention path.
        _, seq_length = input_ids.shape
        layers_use_sdpa = self.config.use_flash_attention and hasattr(F, "scaled_dot_product_attention")
        if layers_use_sdpa:
            causal_mask = None
        else:
            causal_mask = torch.triu(
                torch.full((seq_length, seq_length), float("-inf"), device=input_ids.device),
                diagonal=1,
            )

        # Pass through transformer blocks
        use_checkpointing = self.gradient_checkpointing and self.training
        if use_checkpointing:
            from torch.utils.checkpoint import checkpoint

        for layer in self.layers:
            if use_checkpointing:
                # Non-reentrant checkpointing: activations inside the block are
                # recomputed during backward instead of being stored.
                hidden_states = checkpoint(layer, hidden_states, causal_mask, use_reentrant=False)
            else:
                hidden_states = layer(hidden_states, causal_mask)

        # Final normalization and projection
        hidden_states = self.norm(hidden_states)
        logits = self.lm_head(hidden_states)

        # Compute loss if labels provided
        loss = None
        if labels is not None:
            # Position t predicts token t + 1
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            loss = F.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
                ignore_index=-100,
            )

        return {"loss": loss, "logits": logits}


In [None]:

class TextDataset(Dataset):
    """Dataset for loading and tokenizing text data.

    Every sample is tokenized on access, truncated/padded to ``max_length`` and
    returned together with labels in which padding positions are masked out.
    """

    def __init__(
        self,
        dataset_name: str,
        tokenizer: "PreTrainedTokenizerFast",
        max_length: int = 2048,
        split: str = "train",
    ):
        """
        Initialize dataset.

        Args:
            dataset_name: Name of the HuggingFace dataset
            tokenizer: Tokenizer instance
            max_length: Maximum sequence length
            split: Dataset split to use
        """
        self.tokenizer = tokenizer
        self.max_length = max_length

        logger.info(f"Loading dataset {dataset_name}...")
        self.dataset = load_dataset(dataset_name, split=split, streaming=False)
        logger.info(f"Dataset loaded with {len(self.dataset)} samples")

    def __len__(self) -> int:
        """Number of samples in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """
        Get a single tokenized sample.

        Args:
            idx: Sample index

        Returns:
            Dictionary with ``input_ids``, ``attention_mask`` and ``labels``,
            each a 1-D tensor of length ``max_length``. ``labels`` equals
            ``input_ids`` except at padding positions, which are set to -100
            so that the cross-entropy loss ignores them.
        """
        text = self.dataset[idx]["text"]

        # Tokenize
        encoding = self.tokenizer(
            text,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            return_tensors="pt",
        )

        # return_tensors="pt" adds a leading batch dimension of size 1
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # For causal LM the labels are the inputs; padding must not be learned.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }

In [None]:


def train_tokenizer(
    dataset_name: str,
    vocab_size: int = 50304,
    output_path: str = "./tokenizer",
    max_samples: int = 100000,
) -> PreTrainedTokenizerFast:
    """
    Train a BPE tokenizer on the dataset.

    The "train" split is streamed and only the first ``max_samples`` records
    are used. Texts are handed to the trainer in batches, and records whose
    ``"text"`` field is empty are skipped.

    Args:
        dataset_name: Name of the HuggingFace dataset
        vocab_size: Vocabulary size
        output_path: Path to save tokenizer
        max_samples: Number of leading dataset records to train on

    Returns:
        Trained tokenizer
    """
    logger.info("Training tokenizer...")

    special_tokens = ["<unk>", "<s>", "</s>", "<pad>"]

    # Stream the data so the full dataset never has to be downloaded up front
    dataset = load_dataset(dataset_name, split="train", streaming=True)

    def batch_iterator(batch_size=1000):
        """Yield lists of up to ``batch_size`` texts from the first ``max_samples`` records."""
        pending = []
        for i, sample in enumerate(dataset):
            if i >= max_samples:
                break
            text = sample["text"]
            if not text:
                continue
            pending.append(text)
            if len(pending) >= batch_size:
                yield pending
                pending = []
        if pending:
            yield pending

    # Initialize tokenizer
    tokenizer = Tokenizer(BPE(unk_token="<unk>"))
    tokenizer.pre_tokenizer = Whitespace()

    # Train tokenizer
    trainer = BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=special_tokens,
        show_progress=True,
    )

    tokenizer.train_from_iterator(batch_iterator(), trainer=trainer)

    # Wrap every sequence as "<s> ... </s>". The ids are looked up from the
    # trained vocabulary rather than assumed from the order of special_tokens.
    tokenizer.post_processor = TemplateProcessing(
        single="<s> $A </s>",
        special_tokens=[
            ("<s>", tokenizer.token_to_id("<s>")),
            ("</s>", tokenizer.token_to_id("</s>")),
        ],
    )

    # Convert to HuggingFace tokenizer
    hf_tokenizer = PreTrainedTokenizerFast(
        tokenizer_object=tokenizer,
        unk_token="<unk>",
        bos_token="<s>",
        eos_token="</s>",
        pad_token="<pad>",
    )

    # Save tokenizer
    os.makedirs(output_path, exist_ok=True)
    hf_tokenizer.save_pretrained(output_path)
    logger.info(f"Tokenizer saved to {output_path}")

    return hf_tokenizer


In [None]:

def setup_distributed():
    """Initialize distributed training from the launcher's environment variables.

    When both ``RANK`` and ``WORLD_SIZE`` are set, the current CUDA device is
    selected and the NCCL process group is created. Otherwise nothing is
    initialized and single-process defaults are returned.

    Returns:
        Tuple ``(rank, world_size, local_rank)``; ``(0, 1, -1)`` when not
        running under a distributed launcher.
    """
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        return 0, 1, -1

    rank = int(os.environ["RANK"])  # global index of this process
    world_size = int(os.environ["WORLD_SIZE"])  # total number of processes

    local_rank_env = os.environ.get("LOCAL_RANK")
    if local_rank_env is not None:
        local_rank = int(local_rank_env)  # index of this process on its node
    else:
        # Some launchers do not export LOCAL_RANK; derive it from the global
        # rank, assuming every node exposes the same number of GPUs.
        local_rank = rank % max(torch.cuda.device_count(), 1)

    # Bind the device before creating the process group so that NCCL
    # communicators are set up on this process's own GPU.
    torch.cuda.set_device(local_rank)

    # Re-running the notebook cell must not initialize the group twice.
    if not dist.is_initialized():
        dist.init_process_group("nccl")

    return rank, world_size, local_rank

In [None]:


def _save_checkpoint(model, tokenizer, model_config, checkpoint_dir, extra_state=None):
    """Write model weights (plus optional extra state) and tokenizer files to ``checkpoint_dir``."""
    os.makedirs(checkpoint_dir, exist_ok=True)

    # Unwrap DDP so the saved keys carry no "module." prefix
    model_to_save = model.module if hasattr(model, "module") else model
    state = {"model": model_to_save.state_dict()}
    if extra_state:
        state.update(extra_state)
    state["config"] = model_config

    torch.save(state, os.path.join(checkpoint_dir, "pytorch_model.bin"))
    tokenizer.save_pretrained(checkpoint_dir)


def train(config: TrainingConfig):
    """
    Main training function.

    Steps: set up (optional) distributed training, train or load the tokenizer,
    build the model and data pipeline, then run the optimisation loop with
    gradient accumulation, gradient clipping, a cosine learning-rate schedule,
    periodic checkpoints and a final save. Logging, tokenizer training and
    checkpoint writing happen on the main process (rank 0) only.

    Notes:
        * One optimizer step is taken every ``gradient_accumulation_steps``
          micro-batches, counted across epoch boundaries.
        * ``use_bf16`` takes precedence over ``use_fp16``; a gradient scaler is
          created only when float16 is the precision actually used.
        * ``config.save_total_limit`` is not enforced: old checkpoints are kept.

    Args:
        config: Training configuration
    """
    from contextlib import nullcontext

    # Setup distributed training
    rank, world_size, local_rank = setup_distributed()
    is_main_process = rank == 0

    if is_main_process:
        logger.info("Starting LLM pretraining...")
        if config.use_wandb:
            wandb.init(project=config.wandb_project, config=config.__dict__)

    # Set device
    if local_rank >= 0:
        device = torch.device(f"cuda:{local_rank}")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Train the tokenizer once (main process), then load it everywhere
    tokenizer_path = "./tokenizer"
    if is_main_process and not os.path.exists(tokenizer_path):
        train_tokenizer(config.dataset_name, output_path=tokenizer_path)

    if world_size > 1:
        dist.barrier()  # Wait for main process to finish tokenizer training

    tokenizer = PreTrainedTokenizerFast.from_pretrained(tokenizer_path)

    # Create model
    model_config = ModelConfig(vocab_size=len(tokenizer), max_seq_length=config.max_seq_length)
    model = LLMModel(model_config).to(device)

    if is_main_process:
        total_params = sum(p.numel() for p in model.parameters())
        logger.info(f"Model parameters: {total_params:,} ({total_params / 1e9:.2f}B)")

    # Enable gradient checkpointing (only if the model implements it)
    if model_config.gradient_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable()
        elif is_main_process:
            logger.warning("Model does not support gradient checkpointing; continuing without it")

    # Wrap model with DDP
    if world_size > 1:
        model = DDP(model, device_ids=[local_rank], output_device=local_rank)

    # Create dataset and dataloader
    dataset = TextDataset(config.dataset_name, tokenizer, config.max_seq_length, config.dataset_split)
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank) if world_size > 1 else None
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        sampler=sampler,
        shuffle=(sampler is None),
        num_workers=4,
        pin_memory=(device.type == "cuda"),  # pinned memory only helps host-to-GPU copies
    )

    # Setup optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        eps=config.adam_epsilon,
        weight_decay=config.weight_decay,
    )

    # Calculate total optimizer steps
    accumulation_steps = config.gradient_accumulation_steps
    if config.max_steps:
        total_steps = config.max_steps
    else:
        total_steps = max(1, len(dataloader) * config.num_epochs // accumulation_steps)

    # Setup scheduler
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=config.warmup_steps,
        num_training_steps=total_steps,
    )

    # Setup mixed precision: bf16 wins over fp16; loss scaling is only needed for fp16
    if config.use_bf16:
        dtype = torch.bfloat16
    elif config.use_fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32
    use_autocast = dtype != torch.float32
    scaler = torch.cuda.amp.GradScaler() if dtype == torch.float16 else None

    # Training loop
    global_step = 0  # optimizer steps taken
    micro_step = 0  # micro-batches processed, counted across epochs
    window_loss = torch.zeros((), device=device)  # mean loss over the current accumulation window
    model.zero_grad()

    for epoch in range(config.num_epochs):
        if sampler is not None:
            sampler.set_epoch(epoch)

        progress_bar = tqdm(dataloader, disable=not is_main_process)

        for batch in progress_bar:
            micro_step += 1
            is_update_step = micro_step % accumulation_steps == 0

            # Move batch to device
            input_ids = batch["input_ids"].to(device)
            labels = batch["labels"].to(device)

            # Under DDP, skip the gradient all-reduce on micro-batches that do
            # not end an accumulation window; gradients are synchronized once
            # per optimizer step instead.
            if world_size > 1 and not is_update_step:
                sync_context = model.no_sync()
            else:
                sync_context = nullcontext()

            with sync_context:
                # Forward pass with mixed precision
                with torch.autocast(device_type=device.type, dtype=dtype, enabled=use_autocast):
                    outputs = model(input_ids=input_ids, labels=labels)
                    loss = outputs["loss"] / accumulation_steps

                # Backward pass
                if scaler:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

            # Each micro-batch loss is already divided by accumulation_steps,
            # so the sum over a window is the window's mean loss.
            window_loss += loss.detach().float()

            if not is_update_step:
                continue

            # Update weights
            if scaler:
                scaler.unscale_(optimizer)

            # Gradient clipping
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_grad_norm)

            if scaler:
                scaler.step(optimizer)
                scaler.update()
            else:
                optimizer.step()

            scheduler.step()
            optimizer.zero_grad()
            global_step += 1

            # Logging
            if global_step % config.logging_steps == 0 and is_main_process:
                lr = scheduler.get_last_lr()[0]
                loss_val = window_loss.item()
                progress_bar.set_postfix({"loss": f"{loss_val:.4f}", "lr": f"{lr:.2e}"})

                if config.use_wandb:
                    wandb.log({
                        "train/loss": loss_val,
                        "train/learning_rate": lr,
                        "train/step": global_step,
                    })

            window_loss.zero_()

            # Save checkpoint
            if global_step % config.save_steps == 0 and is_main_process:
                checkpoint_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")
                _save_checkpoint(
                    model,
                    tokenizer,
                    model_config,
                    checkpoint_path,
                    extra_state={
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "step": global_step,
                    },
                )
                logger.info(f"Checkpoint saved at step {global_step}")

            # Check if max steps reached
            if config.max_steps and global_step >= config.max_steps:
                break

        if config.max_steps and global_step >= config.max_steps:
            break

    # Save final model
    if is_main_process:
        final_path = os.path.join(config.output_dir, "final")
        _save_checkpoint(model, tokenizer, model_config, final_path)
        logger.info(f"Final model saved to {final_path}")

        if config.use_wandb:
            wandb.finish()

    if world_size > 1:
        dist.destroy_process_group()


In [None]:

# Launch training with the default configuration.
# The guard keeps training from starting again when this module is imported
# rather than run directly (for example by DataLoader worker processes on
# platforms that start workers with "spawn"). In a notebook kernel __name__ is
# "__main__", so the cell behaves exactly as before.
if __name__ == "__main__":
    config = TrainingConfig()
    train(config)

