In [None]:
# Section 1: Initial Setup and Core Components

import os
from pathlib import Path
import logging
import torch
import psutil
import time
import random
from datetime import timedelta
import sys
import io
import tempfile
import numpy as np
from contextlib import contextmanager
from torch.utils.data import Dataset, DataLoader
from tqdm.auto import tqdm
import warnings
from contextlib import nullcontext

# Check if running in Google Colab
def is_colab():
    """Report whether the code is executing inside a Google Colab kernel.

    Returns:
        bool: True when the string form of the active IPython shell contains
        ``google.colab``; False when ``get_ipython`` is not defined at all.
    """
    try:
        shell = get_ipython()
    except NameError:
        # ``get_ipython`` is not defined, so there is no shell to inspect.
        return False
    return 'google.colab' in str(shell)

# Mount Google Drive if in Colab.
# The google.colab import is kept inside the branch so it is only attempted
# when is_colab() reports a Colab kernel.
if is_colab():
    from google.colab import drive
    drive.mount('/content/drive')

# Create the LLM working directory tree: on Drive when in Colab, otherwise
# under ./LLM relative to the current working directory.
base_dir = Path('/content/drive/MyDrive/LLM') if is_colab() else Path('./LLM')
# mkdir(parents=True, exist_ok=True) makes this loop safe to re-run.
for dir_name in ['checkpoints', 'models', 'logs', 'configs', 'data']:
    (base_dir / dir_name).mkdir(parents=True, exist_ok=True)

def is_package_installed(package_name):
    """Check whether ``package_name`` can be imported in this interpreter.

    The check actually imports the module, so any import-time side effects of
    the package are triggered.

    Args:
        package_name: Importable module name to probe.

    Returns:
        bool: True if the import succeeds, False if it raises ImportError.
    """
    try:
        __import__(package_name)
    except ImportError:
        return False
    else:
        return True

# Install missing dependencies, but only when running in Colab.
if is_colab():
    # PyTorch packages: only the ones that fail to import are installed,
    # using the CUDA 11.8 wheel index given on the pip line below.
    pytorch_packages = ['torch', 'torchvision', 'torchaudio']
    pytorch_install = [pkg for pkg in pytorch_packages if not is_package_installed(pkg)]
    if pytorch_install:
        !pip install {' '.join(pytorch_install)} --index-url https://download.pytorch.org/whl/cu118

    # Additional packages.
    # NOTE(review): is_package_installed() tries to *import* each of these
    # strings. An entry whose pip name differs from its importable module name
    # would look "missing" and be reinstalled on every run — confirm the names.
    other_packages = ['pynvml', 'nvidia_ml_py3', 'gputil', 'fastapi', 'uvicorn', 'pydantic']
    other_install = [pkg for pkg in other_packages if not is_package_installed(pkg)]
    if other_install:
        !pip install {' '.join(other_install)}

# Suppress warnings: this installs a global 'ignore' filter, so every Python
# warning (including deprecation notices) is hidden from here on.
warnings.filterwarnings('ignore')

# Banner marking the end of the setup cell.
print("Section 1: Initial setup and core components complete")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
Section 1: Initial setup and core components complete


In [None]:
import logging
import torch
from pathlib import Path
from typing import List, Union

class DataManager:
    """Character-level corpus loader, tokenizer and random batch sampler.

    The vocabulary is built from the training files plus a fixed list of
    special tokens. Each split is concatenated, encoded to indices and cut
    into non-overlapping ``block_size`` windows stored as an
    ``(inputs, targets)`` tuple, where ``targets`` is ``inputs`` shifted by
    one position.

    Args:
        config: Object exposing ``device`` and ``block_size``.
        data_dir: Directory containing the corpus files. Defaults to the
            Google Drive location used by this notebook.
    """

    # Default corpus location (Google Drive layout used by this notebook).
    DEFAULT_DATA_DIR = Path('/content/drive/MyDrive/LLM/data')

    # File names per split, resolved against the data directory.
    SPLIT_FILES = {
        'train': [
            'huff_2252.txt',
            'theo_2253.txt',
            'hubbard_2251.txt',
            'raekwon_2250.txt',
            'pappas_2249.txt',
            'trussell_2247.txt',
            'fox_2246.txt',
            'blagojevich_2245.txt',
            'sorin_2242.txt',
            'strassman_2241.txt',
            'tarantino_2240.txt',
            'derek_2239.txt',
            'mcphee_2238.txt',
            'parks_2236.txt',
        ],
        'val': [
            # lennon_2243.txt used to be listed twice, which loaded the same
            # text twice into the validation split; it is now listed once.
            'lennon_2243.txt',
            'andreessen_2234.txt',
            'storch_2233.txt',
            # Add more validation files as needed
        ],
        'test': [
            'graves_2244.txt',
            # Add more test files as needed
        ],
    }

    # Tokens that are always part of the vocabulary, even when they never
    # occur in the training text.
    SPECIAL_CHARS = ['\n', '\t', ' ', '_', '[', ']', '(', ')', '{', '}', '*', '/', '\\', '|',
                     '@', '#', '$', '%', '^', '&', '+', '=', '`', '~', '<pad>', '<extra>']

    def __init__(self, config, data_dir: Union[str, Path, None] = None):
        self.config = config
        self.char_to_idx = {}
        self.idx_to_char = {}
        # Each split holds None until load_data() runs, then an
        # (input_ids, target_ids) tuple of tensors.
        self.train_data = None
        self.val_data = None
        self.test_data = None
        self.vocab_size = 0

        root = Path(data_dir) if data_dir is not None else self.DEFAULT_DATA_DIR
        # Define the data paths as lists of files
        self.data_paths = {
            split: [root / name for name in names]
            for split, names in self.SPLIT_FILES.items()
        }

    def calculate_vocab_size(self):
        """Build the vocabulary from all training files and return its size.

        Raises:
            FileNotFoundError: If any training file is missing.
        """
        special_chars = self.SPECIAL_CHARS
        try:
            # Collect all unique characters from all training files
            all_chars = set()
            for train_path in self.data_paths['train']:
                if not train_path.exists():
                    logging.error(f"Training data file does not exist: {train_path}")
                    raise FileNotFoundError(f"Training data file does not exist: {train_path}")

                with open(train_path, 'r', encoding='utf-8') as f:
                    all_chars.update(f.read())

            # Sorted corpus characters first, then any special token that is
            # not already present (appended in declaration order).
            unique_chars = sorted(all_chars)
            known = set(unique_chars)
            for char in special_chars:
                if char not in known:
                    unique_chars.append(char)
                    known.add(char)

            self.setup_tokenizer(unique_chars)

            print(f"Vocabulary created with {self.vocab_size} characters including special characters")
            print(f"Special characters included: {[c for c in special_chars if c in self.char_to_idx]}")

        except Exception as e:
            logging.error(f"Error during vocabulary calculation: {e}")
            raise

        return self.vocab_size

    def load_data(self):
        """Load every split from disk and store it as ``(inputs, targets)`` tensors.

        Characters missing from the vocabulary are replaced by the space
        token, and a warning listing them is logged per file.

        Raises:
            ValueError: If a split's combined text is not longer than
                ``config.block_size``.
        """
        try:
            # First calculate vocabulary from all training files
            self.calculate_vocab_size()

            block_size = self.config.block_size
            lookup = self.char_to_idx
            space_idx = lookup[' ']

            # Process each split (train, val, test)
            for split in ['train', 'val', 'test']:
                all_indices = []

                # Process each file in the split
                for file_path in self.data_paths[split]:
                    logging.info(f"Loading {split} data from {file_path}")
                    try:
                        with open(file_path, 'r', encoding='utf-8') as f:
                            text = f.read()

                        # Convert text to indices, falling back to the space token
                        unknown_chars = {c for c in set(text) if c not in lookup}
                        if unknown_chars:
                            logging.warning(f"Characters not in vocabulary in {file_path}: {unknown_chars}")

                        all_indices.extend(lookup.get(c, space_idx) for c in text)

                    except Exception as e:
                        logging.error(f"Error processing file {file_path}: {str(e)}")
                        raise

                # Convert combined indices to tensor
                data = torch.tensor(all_indices, dtype=torch.long, device=self.config.device)

                # Create input-target pairs
                n = len(data) - block_size
                if n <= 0:
                    raise ValueError(f"Combined text in {split} split is shorter than block_size")

                # Non-overlapping windows; every start i < n, so the shifted
                # target window always stays inside ``data``.
                starts = range(0, n, block_size)
                input_ids = torch.stack([data[i:i + block_size] for i in starts])
                target_ids = torch.stack([data[i + 1:i + block_size + 1] for i in starts])

                # Store as tuple of (inputs, targets)
                if split == 'train':
                    self.train_data = (input_ids, target_ids)
                elif split == 'val':
                    self.val_data = (input_ids, target_ids)
                else:
                    self.test_data = (input_ids, target_ids)

                logging.info(f"Loaded {split} data with {len(input_ids)} sequences from {len(self.data_paths[split])} files")

        except Exception as e:
            logging.error(f"Error loading data: {str(e)}")
            raise

    def setup_tokenizer(self, unique_chars: List[str]):
        """Create the token<->index mappings from an ordered token list."""
        # Create char_to_idx and idx_to_char mappings
        self.char_to_idx = {char: idx for idx, char in enumerate(unique_chars)}
        self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}
        self.vocab_size = len(self.char_to_idx)

        logging.info(f"Vocabulary size: {self.vocab_size}")

    def encode(self, text: str) -> torch.Tensor:
        """Encode ``text`` into a 1-D long tensor of indices on ``config.device``.

        Raises:
            ValueError: If ``text`` contains a character outside the vocabulary.
        """
        try:
            # Ensure all characters are in the vocabulary
            indices = [self.char_to_idx[char] for char in text]
            return torch.tensor(indices, dtype=torch.long, device=self.config.device)
        except KeyError as e:
            raise ValueError(f"Character '{e.args[0]}' not found in vocabulary. Update the vocabulary to include all characters.")

    def get_batch(self, split: str, batch_size: int):
        """Sample ``batch_size`` random sequences (with replacement) from a split.

        Args:
            split: One of ``'train'``, ``'val'`` or ``'test'``.
            batch_size: Number of sequences to draw.

        Returns:
            Tuple ``(x, y)`` of tensors, each with ``batch_size`` rows taken
            from the split's inputs and targets at the same random indices.

        Raises:
            ValueError: For an unknown split, unloaded data, malformed data,
                or a split that contains no sequences.
        """
        # Get the appropriate dataset
        if split == 'train':
            data = self.train_data
        elif split == 'val':
            data = self.val_data
        elif split == 'test':
            data = self.test_data
        else:
            raise ValueError("Split must be 'train', 'val', or 'test'")

        # Validate data
        if data is None:
            raise ValueError(f"{split.capitalize()} data is not loaded.")

        try:
            # Assuming data is a tuple of (inputs, targets)
            if not isinstance(data, tuple) or len(data) != 2:
                raise ValueError(f"Data format error: expected tuple of (inputs, targets), got {type(data)}")

            inputs, targets = data

            if not isinstance(inputs, torch.Tensor):
                raise ValueError(f"Inputs must be a tensor, got {type(inputs)}")

            # ``inputs`` is already chunked into block_size windows, so every
            # row is a valid sample. Subtracting block_size here (as the code
            # used to) hid the last rows and rejected small splits outright.
            num_sequences = len(inputs)
            if num_sequences == 0:
                raise ValueError(f"{split.capitalize()} data contains no sequences.")

            idx = torch.randint(0, num_sequences, (batch_size,), device=inputs.device)

            # Select batch data and ensure proper dimensions
            x = inputs[idx]
            y = targets[idx]

            return x, y

        except Exception as e:
            logging.error(f"Error in get_batch: {str(e)}")
            raise

In [None]:
# RUN FOR TRAINING and INFERENCE

from dataclasses import dataclass
from typing import Optional, Dict, Any, Union
import logging
import json
from pathlib import Path
import torch
import time
import psutil
from datetime import timedelta
import threading

@dataclass
class ModelConfig:
    """Hyper-parameters and runtime settings for the model and its training.

    ``__post_init__`` derives ``device`` and ``dtype`` from CUDA availability,
    fills ``ff_dim`` with ``4 * n_embed`` when it is not given, and creates
    ``save_dir`` on disk.
    """
    save_dir: Union[str, Path] = Path("/content/drive/MyDrive/LLM/checkpoints")
    vocab_size: int = 0
    block_size: int = 64              # Context length
    n_embed: int = 64                 # Embedding dimension
    n_head: int = 4                   # Number of attention heads
    n_layer: int = 4                  # Number of transformer layers
    ff_dim: Optional[int] = None      # Feed-forward dimension; None -> 4 * n_embed
    dropout: float = 0.1              # Dropout rate
    bias: bool = True
    flash_attn: bool = True
    head_dim: int = 64
    attn_pdrop: float = 0.1
    resid_pdrop: float = 0.1
    embd_pdrop: float = 0.1

    # Training Parameters
    batch_size: int = 128
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-3
    min_learning_rate: float = 1e-3
    weight_decay: float = 0.1
    epochs: int = 10
    warmup_steps: int = 200            # Number of warmup steps
    max_grad_norm: float = 1.0
    within_epoch_loss_threshold: float = 0.7  # New parameter for early stopping within epoch
    eval_steps: int = 100  # Check loss every N steps

    # Learning Rate Schedule
    lr_schedule: str = 'cosine_with_warmup'
    lr_decay_epochs: int = 8

    # Memory Optimization
    use_amp: bool = True
    amp_dtype: torch.dtype = torch.bfloat16  # Changed from float16
    dtype: torch.dtype = torch.bfloat16      # Changed from float32
    use_gradient_checkpointing: bool = False

    # Data Loading
    pin_memory: bool = True
    num_workers: int = 8
    prefetch_factor: int = 5

    # Generation Parameters
    gen_temperature: float = 0.8
    max_gen_tokens: int = 256
    top_k: int = 50
    top_p: float = 0.9

    # Early Stopping
    patience: int = 5
    min_delta: float = 1e-4
    loss_threshold: float = 1.0

    # Evaluation and Checkpointing
    eval_every: int = 1000
    save_every: int = 1
    keep_last_n_checkpoints: int = 5

    # System Parameters
    device: str = "cuda"

    def __post_init__(self):
        """Initialize derived parameters and create directories.

        ``device`` and ``dtype`` are always recomputed from CUDA availability,
        so values passed to the constructor for them are overwritten.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32  # Changed from float16

        # Derive the feed-forward width when it was not given explicitly.
        # This used to live in a first __post_init__ that was silently
        # shadowed by this one, so ff_dim stayed None.
        if self.ff_dim is None:
            self.ff_dim = 4 * self.n_embed

        # Honour the configured save directory instead of overwriting it
        # with the hard-coded default path.
        self.save_dir = Path(self.save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def print_params(self) -> None:
        """Log every configuration attribute at INFO level."""
        logging.info("Model Parameters:")
        for key, value in self.__dict__.items():
            logging.info(f"  {key}: {value}")

    def setup_logging(self, iteration: int) -> None:
        """Configure logging to the console and to ``logs/training_<iteration>.log``.

        The ``logs`` directory is relative to the current working directory.
        After configuring, environment details and all parameters are logged.
        """
        log_dir = Path('logs')
        log_dir.mkdir(parents=True, exist_ok=True)

        log_file = log_dir / f'training_{iteration}.log'

        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s',
            handlers=[
                logging.StreamHandler(),
                logging.FileHandler(log_file)
            ]
        )

        # Log parameters
        logging.info(f"PyTorch version: {torch.__version__}")
        logging.info(f"CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            logging.info(f"CUDA device: {torch.cuda.get_device_name(0)}")
            logging.info(f"CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

        # Print model and training parameters
        self.print_params()

    def save(self, path: Union[str, Path]) -> None:
        """Save configuration to JSON.

        ``Path`` and ``torch.dtype`` values are stored as their string form
        because they are not JSON-serialisable.
        """
        path = Path(path)
        config_dict = {
            k: str(v) if isinstance(v, (Path, torch.dtype)) else v
            for k, v in self.__dict__.items()
        }
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open('w') as f:
            json.dump(config_dict, f, indent=4)

    @classmethod
    def load(cls, path: Union[str, Path]) -> 'ModelConfig':
        """Load configuration from JSON written by :meth:`save`.

        Both dtype fields are converted back from strings such as
        ``"torch.bfloat16"`` to ``torch.dtype`` objects.
        """
        path = Path(path)
        with path.open('r') as f:
            config_dict = json.load(f)

        # save() stringifies every torch.dtype, so restore each of them;
        # amp_dtype used to be left as a plain string.
        for key in ('dtype', 'amp_dtype'):
            value = config_dict.get(key)
            if isinstance(value, str):
                config_dict[key] = getattr(torch, value.split('.')[-1])
        if 'save_dir' in config_dict:
            config_dict['save_dir'] = str(config_dict['save_dir'])

        return cls(**config_dict)

'''
class SystemMonitor:
    """System resource monitoring with memory optimization"""

    def __init__(self):
        self.gpu_available = torch.cuda.is_available()
        self.start_time = time.time()
        self._lock = threading.Lock()
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'learning_rate': [],
            'epoch_times': []
        }
        self.epoch_start_time = None

    def start_epoch(self):
        """Mark the start of an epoch"""
        self.epoch_start_time = time.time()

    def end_epoch(self):
        """Mark the end of an epoch and record duration"""
        if self.epoch_start_time is not None:
            duration = time.time() - self.epoch_start_time
            self.history['epoch_times'].append(duration)
            self.epoch_start_time = None

    def cleanup(self):
        """Cleanup any resources"""
        if self.gpu_available:
            torch.cuda.empty_cache()
'''

In [None]:
# Section 4: Model Architecture
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RotaryEmbedding(nn.Module):
    """Cached cosine/sine tables for rotary position embeddings (RoPE).

    Args:
        dim: Size of the dimension the tables are built for (the attention
            layer in this file passes its per-head size).
        max_position_embeddings: Number of positions pre-computed at
            construction; the cache grows on demand in ``forward``.
        base: Base of the geometric frequency progression.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.max_seq_len_cached = max_position_embeddings

        # Initialize cache. The tables are registered as non-persistent
        # buffers so they follow ``module.to(device)`` without adding keys to
        # the state_dict (checkpoints keep their existing layout).
        cos, sin = self._build_tables(max_position_embeddings, inv_freq.device)
        self.register_buffer('cos_cached', cos, persistent=False)
        self.register_buffer('sin_cached', sin, persistent=False)

    def _build_tables(self, seq_len, device):
        """Return ``(cos, sin)`` tables of shape ``(1, 1, seq_len, 2 * len(inv_freq))``."""
        inv_freq = self.inv_freq.to(device)
        t = torch.arange(seq_len, device=device).type_as(inv_freq)
        freqs = torch.einsum('i,j->ij', t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]

    def forward(self, x, seq_len=None):
        """Return the ``(cos, sin)`` tables for the first ``seq_len`` positions.

        Args:
            x: Tensor whose device the returned tables are moved to. When
                ``seq_len`` is omitted, its second-to-last dimension is used
                as the sequence length.
            seq_len: Number of positions required.

        Returns:
            Tuple ``(cos, sin)``, each of shape ``(1, 1, seq_len, D)``.
        """
        if seq_len is None:
            # Comparing None with ``>`` below would raise a TypeError.
            seq_len = x.shape[-2]

        if seq_len > self.max_seq_len_cached:
            # Grow the cache to cover the longer sequence.
            self.max_seq_len_cached = seq_len
            cos, sin = self._build_tables(seq_len, x.device)
            self.cos_cached = cos
            self.sin_cached = sin

        return (
            self.cos_cached[:, :, :seq_len, ...].to(x.device),
            self.sin_cached[:, :, :seq_len, ...].to(x.device)
        )

def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For a last dimension split at ``size // 2`` into ``(a, b)``, the result
    is ``(-b, a)`` concatenated along that dimension.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat((-back, front), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply the rotary position transform to query and key tensors.

    Each tensor ``t`` is mapped to ``t * cos + rotate_half(t) * sin``.

    Returns:
        Tuple ``(q_embed, k_embed)`` with the rotated query and key.
    """
    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with rotary position embeddings.

    Uses ``torch.nn.functional.scaled_dot_product_attention`` when the
    installed PyTorch provides it, otherwise an explicit masked-softmax
    implementation with a lower-triangular ``mask`` buffer.

    Args:
        config: Object exposing ``n_embed``, ``n_head``, ``dropout``,
            ``bias`` and ``block_size``.
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embed % config.n_head == 0

        # Store necessary dimensions
        self.n_embed = config.n_embed
        self.n_head = config.n_head
        self.head_dim = config.n_embed // config.n_head
        self.dropout = config.dropout

        # Key, query, value projections for all heads
        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embed, config.n_embed, bias=config.bias)

        # Regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Initialize RoPE
        self.rope = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.block_size
        )

        # Fused attention support; the mask is only needed by the fallback path.
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            self.register_buffer(
                "mask",
                torch.tril(torch.ones(config.block_size, config.block_size))
                .view(1, 1, config.block_size, config.block_size)
            )

        # Scaling factor for attention (used by the manual fallback path)
        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(self, x):
        """Run causal self-attention.

        Args:
            x: Tensor of shape ``(B, T, C)`` with ``C == n_embed``.

        Returns:
            Tensor of shape ``(B, T, C)``.
        """
        B, T, C = x.size()  # batch size, sequence length, embedding dimensionality

        # Calculate query, key, values for all heads
        q, k, v = self.c_attn(x).split(self.n_embed, dim=2)

        # Reshape for multi-head attention
        k = k.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, self.head_dim).transpose(1, 2)  # (B, nh, T, hs)

        # Apply RoPE to queries and keys
        cos, sin = self.rope(q, seq_len=T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Causal self-attention with the fused kernel if available
        if self.flash:
            # The explicit ``scale=`` keyword is intentionally not passed: it
            # is rejected by PyTorch 2.0 (which already has this function),
            # and the default scaling is 1/sqrt(head_dim), i.e. self.scale.
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=self.dropout if self.training else 0.0,
                is_causal=True
            )
        else:
            # Manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * self.scale
            att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        # Re-assemble all head outputs side by side
        y = y.transpose(1, 2).contiguous().view(B, T, C)

        # Output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension with an optional bias term."""

    def __init__(self, ndim, bias=True):
        super().__init__()
        # Learnable scale, initialised to ones.
        self.weight = nn.Parameter(torch.ones(ndim))
        # Learnable shift initialised to zeros, or None when bias is disabled.
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        normalized_shape = self.weight.shape
        return F.layer_norm(input, normalized_shape, weight=self.weight, bias=self.bias, eps=1e-5)

class GPTBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP, each with a scaled residual.

    Args:
        config: Object exposing ``n_embed``, ``bias``, ``dropout`` and
            ``use_gradient_checkpointing`` (plus whatever
            ``CausalSelfAttention`` reads).
    """
    def __init__(self, config):
        super().__init__()
        # Layer normalization with optional bias
        self.ln1 = LayerNorm(config.n_embed, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embed, bias=config.bias)

        # MLP with 4x expansion ratio
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embed, 4 * config.n_embed, bias=config.bias),
            nn.GELU(),
            nn.Linear(4 * config.n_embed, config.n_embed, bias=config.bias),
            nn.Dropout(config.dropout),
        )

        # Optional gradient checkpointing
        self.use_checkpointing = config.use_gradient_checkpointing

        # LayerScale: per-channel gains on each residual branch, initialised to 0.1
        self.layer_scale_1 = nn.Parameter(torch.ones(config.n_embed) * 0.1)
        self.layer_scale_2 = nn.Parameter(torch.ones(config.n_embed) * 0.1)

    @staticmethod
    def _checkpoint(module, inputs):
        """Run ``module(inputs)`` under activation checkpointing."""
        # Imported explicitly instead of relying on ``import torch`` having
        # loaded the submodule as a side effect.
        from torch.utils.checkpoint import checkpoint
        # ``use_reentrant`` is passed explicitly, as PyTorch asks callers to do;
        # the non-reentrant variant is the recommended one.
        return checkpoint(module, inputs, use_reentrant=False)

    def forward(self, x):
        """Apply the block to ``x`` and return a tensor of the same shape."""
        # Attention block with residual connection and layer scale
        if self.use_checkpointing and x.requires_grad:
            attn_output = self._checkpoint(self.attn, self.ln1(x))
        else:
            attn_output = self.attn(self.ln1(x))
        x = x + self.layer_scale_1.unsqueeze(0).unsqueeze(0) * attn_output

        # MLP block with residual connection and layer scale
        if self.use_checkpointing and x.requires_grad:
            mlp_output = self._checkpoint(self.mlp, self.ln2(x))
        else:
            mlp_output = self.mlp(self.ln2(x))
        x = x + self.layer_scale_2.unsqueeze(0).unsqueeze(0) * mlp_output

        return x

class GPT(nn.Module):
    """Decoder-only language model built from ``GPTBlock`` layers.

    There is no learned positional embedding: position information comes from
    the rotary embeddings inside each attention layer.

    Args:
        config: Object exposing ``vocab_size``, ``n_embed``, ``n_layer``,
            ``dropout`` and ``use_amp`` (plus whatever ``GPTBlock`` reads).
            ``amp_dtype`` is used when present.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Initialize transformer components
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embed),  # Token embeddings
            'drop': nn.Dropout(config.dropout),  # Dropout layer
            'h': nn.ModuleList([GPTBlock(config) for _ in range(config.n_layer)]),  # Transformer blocks
            'ln_f': nn.LayerNorm(config.n_embed)  # Final layer normalization
        })

        # Output layer to predict vocab size
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

        # Apply special scaled initialization to the residual projections
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

        # Report number of parameters
        print(f"Number of parameters: {sum(p.numel() for p in self.parameters()) / 1e6:.2f}M")

    def _init_weights(self, module):
        """Initialise one sub-module; used as the callback for ``self.apply``."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
        elif isinstance(module, LayerNorm):
            torch.nn.init.ones_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, RotaryEmbedding):
            # RoPE has no learnable parameters to initialise
            pass

    def forward(self, idx, targets=None):
        """Forward pass through the model.

        Args:
            idx: Long tensor of token indices with shape ``(B, T)``.
            targets: Optional long tensor of the same shape with next-token
                labels.

        Returns:
            Tuple ``(logits, loss)``. With ``targets``, logits cover every
            position ``(B, T, vocab_size)`` and ``loss`` is the cross-entropy.
            Without ``targets``, only the last position's logits
            ``(B, vocab_size)`` are returned and ``loss`` is None.
        """
        # Mixed precision is only applied for CUDA inputs; elsewhere the
        # context is a no-op, which matches what the deprecated
        # ``torch.cuda.amp.autocast`` did when CUDA was unavailable.
        amp_enabled = bool(self.config.use_amp) and idx.is_cuda
        amp_dtype = getattr(self.config, 'amp_dtype', torch.bfloat16)
        if not isinstance(amp_dtype, torch.dtype):
            # e.g. a config restored from JSON that left the dtype as a string
            amp_dtype = torch.bfloat16

        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp_enabled):
            # Only token embeddings now, no positional embeddings needed (handled by RoPE)
            x = self.transformer.wte(idx)
            x = self.transformer.drop(x)

            # Pass through transformer blocks
            for block in self.transformer.h:
                x = block(x)
            x = self.transformer.ln_f(x)

            # Compute logits and loss if targets are provided
            if targets is not None:
                logits = self.lm_head(x)
                loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1))
            else:
                logits = self.lm_head(x[:, -1, :])
                loss = None

        return logits, loss

In [None]:
class ModelInterface:
    """Interface for model operations with comprehensive error handling and state management.

    Builds a ``DataManager`` and a ``GPT`` model from a ``ModelConfig`` and
    offers checkpoint loading, state verification and resource cleanup.
    """

    def __init__(self, config: ModelConfig):
        """Initialize the interface with configuration and necessary components."""
        self.config = config
        self.original_batch_size = config.batch_size  # Store original batch size
        self.model = None
        self.data_manager = None
        self.scaler = torch.cuda.amp.GradScaler(enabled=getattr(config, 'use_amp', True))
        self._is_initialized = False
        self.checkpoint_info = None
        self._initialize()

    def _initialize(self):
        """Create the data manager and model, cleaning up on any failure.

        Raises:
            RuntimeError: If any step fails; the original exception is
                chained as ``__cause__``.
        """
        try:
            # Verify and adjust batch size if needed
            if self.config.batch_size != self.original_batch_size:
                logging.warning(f"Batch size was modified from {self.original_batch_size} to {self.config.batch_size}")
                self.config.batch_size = self.original_batch_size  # Restore original batch size

            # Initialize data manager first
            self.data_manager = DataManager(self.config)

            # Initialize model with proper device placement
            self.model = GPT(self.config)
            self.model = self.model.to(self.config.device)

            self._is_initialized = True

            # Log initialization details
            print(f"Model initialized with {sum(p.numel() for p in self.model.parameters())/1e6:.2f}M parameters")
            print(f"Using device: {self.config.device}")
            print(f"Batch size: {self.config.batch_size}")

            # Adjust gradient accumulation based on GPU memory if needed
            if self.config.device == "cuda":
                gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
                if gpu_mem < 8:  # Only adjust if very limited memory
                    self.config.gradient_accumulation_steps = 2
                    logging.info(f"Adjusted gradient accumulation steps to {self.config.gradient_accumulation_steps} due to limited GPU memory")

        except Exception as e:
            self.cleanup()
            # Chain the original exception so its traceback is not lost.
            raise RuntimeError(f"Failed to initialize model interface: {str(e)}") from e

    def verify_state(self):
        """Verify that the model and vocabulary are in a valid state.

        Raises:
            RuntimeError: If the interface, model, data manager or vocabulary
                mappings are missing.
            ValueError: If the vocabulary sizes disagree.
        """
        if not self._is_initialized:
            raise RuntimeError("Model interface not properly initialized")

        if self.model is None:
            raise RuntimeError("Model not loaded")

        if self.data_manager is None:
            raise RuntimeError("Data manager not initialized")

        if not self.data_manager.char_to_idx:
            raise RuntimeError("Vocabulary mappings not loaded")

        # Verify vocabulary consistency
        if self.data_manager.vocab_size != self.config.vocab_size:
            raise ValueError(f"Vocabulary size mismatch: {self.data_manager.vocab_size} != {self.config.vocab_size}")

        # Verify model vocabulary size matches config
        if self.model.config.vocab_size != self.config.vocab_size:
            raise ValueError(f"Model vocabulary size mismatch: {self.model.config.vocab_size} != {self.config.vocab_size}")

        # Verify batch size hasn't been modified
        if self.config.batch_size != self.original_batch_size:
            logging.warning(f"Batch size mismatch detected. Restoring original batch size: {self.original_batch_size}")
            self.config.batch_size = self.original_batch_size

    def load_model(self, model_path: Union[str, Path]):
        """Load a trained model from a checkpoint.

        The checkpoint must contain a ``'config'`` mapping. Its ``vocab_size``
        and vocabulary mappings (when present) are restored before the
        weights; if the vocabulary size differs from the one the current
        model was built with, the model is rebuilt first.

        Raises:
            ValueError: If the checkpoint has no ``'config'`` entry or the
                loaded state fails verification.
        """
        try:
            # NOTE: torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(model_path, map_location=self.config.device)

            # Preserve original batch size before loading checkpoint config
            original_batch_size = self.config.batch_size

            # First, set up vocabulary mappings before loading model state
            if 'config' in checkpoint:
                # Get vocab size from checkpoint config
                vocab_size = checkpoint['config'].get('vocab_size')
                if vocab_size:
                    self.config.vocab_size = vocab_size

                # Restore original batch size
                checkpoint['config']['batch_size'] = original_batch_size

                # Set up vocabulary if present in checkpoint
                if 'char_to_idx' in checkpoint['config'] and 'idx_to_char' in checkpoint['config']:
                    self.data_manager.char_to_idx = checkpoint['config']['char_to_idx']
                    self.data_manager.idx_to_char = checkpoint['config']['idx_to_char']
                    self.data_manager.vocab_size = len(self.data_manager.char_to_idx)
                else:
                    # If no vocabulary in checkpoint, calculate it from training data
                    self.data_manager.calculate_vocab_size()
            else:
                raise ValueError("Checkpoint missing required config")

            # The model was built with the vocabulary size known at init time.
            # If the checkpoint uses a different size, the embedding and output
            # layers have the wrong shape and load_state_dict would fail, so
            # rebuild the model with the updated config first.
            if self.model is None or self.model.lm_head.out_features != self.config.vocab_size:
                self.model = GPT(self.config).to(self.config.device)

            # Now load model state
            if 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
            else:
                self.model.load_state_dict(checkpoint)

            self.model.eval()
            self.verify_state()

        except Exception as e:
            logging.error(f"Error loading model: {str(e)}")
            raise

    def cleanup(self):
        """Release the model and data manager and free cached GPU memory.

        Best-effort: any error is logged rather than raised.
        """
        # Imported here because the notebook only imports gc in a later cell.
        import gc

        try:
            if hasattr(self, 'model') and self.model is not None:
                self.model.cpu()
                del self.model
                self.model = None

            if hasattr(self, 'data_manager') and self.data_manager is not None:
                del self.data_manager
                self.data_manager = None

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            gc.collect()

            self._is_initialized = False
            logging.info("Model interface cleaned up successfully")

        except Exception as e:
            logging.error(f"Error during cleanup: {str(e)}")

In [None]:
# RUN FOR TRAINING and INFERENCE

import datetime
import logging
import torch
from tqdm import tqdm
from pathlib import Path
from contextlib import nullcontext
import gc

class Trainer:
    """Handles model training and optimization"""

    def __init__(self, model: GPT, data_manager: DataManager, config: ModelConfig):
        self.model = model
        self.data_manager = data_manager
        self.config = config
        self.device = config.device
        #self.monitor = SystemMonitor()

        # Training state
        self.epoch = 0
        self.global_step = 0
        self.best_val_loss = float('inf')
        self.no_improvement_count = 0
        self.setup_training()

    def setup_training(self):
        """Setup training parameters."""
        # Ensure data is loaded
        if self.data_manager.train_data is None:
            logging.error("Training data is not loaded. Please load the data before training.")
            raise ValueError("Training data is not loaded.")

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
            betas=(0.9, 0.95)
        )

        if self.config.use_amp:
            self.scaler = torch.cuda.amp.GradScaler()

        total_steps = (
            self.config.epochs *
            len(self.data_manager.train_data) //
            (self.config.batch_size * self.config.gradient_accumulation_steps)
        )

        self.scheduler = self.get_cosine_schedule_with_warmup(
            self.optimizer,
            self.config.warmup_steps,
            total_steps
        )

        # Now you can safely access len(self.data_manager.train_data)
        num_batches = self.config.epochs * len(self.data_manager.train_data) // self.config.batch_size

        # Learning rate scheduler
        if self.config.lr_schedule == 'cosine_with_warmup':
            self.scheduler = self.get_cosine_schedule_with_warmup(
                self.optimizer,
                self.config.warmup_steps,
                self.config.epochs * len(self.data_manager.train_data) // self.config.batch_size
            )
        else:
            self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
                self.optimizer, mode='min', factor=0.5, patience=5
            )

        # Gradient scaler for mixed precision training
        self.scaler = torch.cuda.amp.GradScaler() if self.config.use_amp else None

        # Setup loss tracking
        self.train_losses = []
        self.val_losses = []

    @staticmethod
    def get_cosine_schedule_with_warmup(optimizer, warmup_steps, total_steps):
        """Create a schedule with linear warmup and cosine decay"""
        def lr_lambda(current_step):
            if current_step < warmup_steps:
                return float(current_step) / float(max(1, warmup_steps))
            progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * progress)))

        return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

    def save_checkpoint(self, loss: float, is_best: bool = False, custom_path: Path = None):
      """Save training checkpoint with epoch number and loss rate in the filename"""
      timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

      # Ensure save_dir is a Path object
      if not isinstance(self.config.save_dir, Path):
          self.config.save_dir = Path(self.config.save_dir)

      # Create directories if they don't exist
      self.config.save_dir.mkdir(parents=True, exist_ok=True)

      # Always save periodic checkpoints and best model
      should_save = True  # Remove restrictive conditions

      # Format loss for filename
      loss_str = f"{loss:.4f}".replace('.', '_')

      # Use custom path if provided, otherwise use default naming
      if custom_path is None:
          if is_best:
              checkpoint_name = f"best_model_epoch_{self.epoch + 1}_loss_{loss_str}_{timestamp}.pt"
          else:
              checkpoint_name = f"checkpoint_epoch_{self.epoch + 1}_loss_{loss_str}_{timestamp}.pt"
          checkpoint_path = self.config.save_dir / checkpoint_name
      else:
          checkpoint_path = custom_path

      checkpoint = {
          'epoch': self.epoch,
          'model_state_dict': self.model.state_dict(),
          'optimizer_state_dict': self.optimizer.state_dict(),
          'scheduler_state_dict': self.scheduler.state_dict() if self.scheduler else None,
          'loss': loss,
          'config': self.config.__dict__,
          'global_step': self.global_step,
          'best_val_loss': self.best_val_loss,
          'train_losses': self.train_losses,
          'val_losses': self.val_losses
      }

      try:
          # Save checkpoint
          torch.save(checkpoint, str(checkpoint_path))
          logging.info(f"Checkpoint saved: {checkpoint_path}")

          # Save metadata file
          metadata_path = checkpoint_path.with_suffix('.txt')
          with open(metadata_path, 'w') as f:
              f.write(f"Checkpoint Information:\n")
              f.write(f"Epoch: {self.epoch + 1}\n")
              f.write(f"Loss: {loss:.6f}\n")
              f.write(f"Timestamp: {timestamp}\n")
              f.write(f"Best model: {is_best}\n")
              f.write(f"Learning rate: {self.optimizer.param_groups[0]['lr']:.6f}\n")
              # Add other metadata as needed

      except Exception as e:
          logging.error(f"Error saving checkpoint: {str(e)}")
          raise

    def train(self):
        """Run the epoch loop with validation, checkpointing and early stopping.

        Each epoch is trained, validated and checkpointed; a ``best_model``
        checkpoint is written whenever the validation loss improves. Training
        stops early once ``config.patience`` consecutive epochs pass without
        improvement. An error checkpoint is written when an epoch fails and an
        interrupted checkpoint on Ctrl-C; a final checkpoint and a test-set
        evaluation always run on the way out.
        """
        logging.info("Starting training")
        # This class does not define log_training_params; call it only when a
        # subclass or caller supplies one, otherwise log the key settings here.
        log_params = getattr(self, 'log_training_params', None)
        if callable(log_params):
            log_params()
        else:
            logging.info(
                f"Training parameters: epochs={self.config.epochs}, "
                f"batch_size={self.config.batch_size}, "
                f"learning_rate={self.config.learning_rate}"
            )

        # Create checkpoint directory if it doesn't exist
        checkpoint_dir = Path('/content/drive/MyDrive/LLM/checkpoints') if is_colab() else Path('./LLM/checkpoints')
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        current_val_loss = float('inf')  # used by the fallback checkpoints below

        try:
            for epoch in range(self.epoch, self.config.epochs):
                self.epoch = epoch
                logging.info(f"Epoch {self.epoch + 1}/{self.config.epochs} starting...")

                try:
                    # Train for one epoch, then validate
                    train_loss = self.train_epoch()
                    self.train_losses.append(train_loss)
                    current_val_loss = self.evaluate('val')
                    self.val_losses.append(current_val_loss)

                    # Save periodic checkpoint
                    self.save_checkpoint(current_val_loss, False)

                    # Track improvement so the patience counter actually moves
                    if current_val_loss < self.best_val_loss:
                        self.best_val_loss = current_val_loss
                        self.no_improvement_count = 0
                        self.save_checkpoint(current_val_loss, True)
                        logging.info(f"New best validation loss: {current_val_loss:.4f}")
                    else:
                        self.no_improvement_count += 1

                    # Early stopping check
                    if self.no_improvement_count >= self.config.patience:
                        logging.info(f"Early stopping triggered after {epoch + 1} epochs")
                        break

                except Exception as e:
                    logging.error(f"Error in epoch {epoch + 1}: {str(e)}")
                    # Save checkpoint on error
                    error_path = checkpoint_dir / f"error_checkpoint_epoch_{epoch + 1}.pt"
                    self.save_checkpoint(current_val_loss, False, custom_path=error_path)
                    raise

        except KeyboardInterrupt:
            logging.info("Training interrupted by user")
            interrupt_path = checkpoint_dir / f"interrupted_checkpoint_epoch_{self.epoch + 1}.pt"
            self.save_checkpoint(current_val_loss, False, custom_path=interrupt_path)

        finally:
            # Save final checkpoint
            final_path = checkpoint_dir / f"final_checkpoint_epoch_{self.epoch + 1}.pt"
            self.save_checkpoint(current_val_loss, False, custom_path=final_path)

            # Final evaluation
            test_loss = self.evaluate('test')
            logging.info(f"Final test loss: {test_loss:.4f}")

    def _cleanup_old_checkpoints(self):
        """Remove old checkpoints keeping only the latest ones based on epoch intervals"""
        try:
            checkpoints = []
            for checkpoint in self.config.save_dir.glob("checkpoint_epoch_*.pt"):
                # Extract epoch and loss from filename
                parts = checkpoint.stem.split('_')
                try:
                    epoch_num = int(parts[parts.index('epoch') + 1])
                    loss_val = float(parts[parts.index('loss') + 1].replace('_', '.'))
                    checkpoints.append((checkpoint, epoch_num, loss_val))
                except (ValueError, IndexError):
                    logging.warning(f"Couldn't parse epoch/loss from checkpoint name: {checkpoint}")
                    continue

            # Sort checkpoints by epoch number
            checkpoints.sort(key=lambda x: x[1])  # Sort by epoch number

            # Always keep first and last checkpoint
            checkpoints_to_keep = set()
            if checkpoints:
                checkpoints_to_keep.add(checkpoints[0][0])  # First checkpoint
                checkpoints_to_keep.add(checkpoints[-1][0])  # Last checkpoint

            # Keep checkpoints at intervals specified by keep_last_n_checkpoints
            for checkpoint, epoch_num, _ in checkpoints:
                if epoch_num % self.config.keep_last_n_checkpoints == 0:
                    checkpoints_to_keep.add(checkpoint)

            # Keep best model checkpoints (they start with 'best_model')
            best_models = list(self.config.save_dir.glob("best_model_*.pt"))
            checkpoints_to_keep.update(best_models)

            # Remove checkpoints that are not in the keep list
            all_checkpoints = set(self.config.save_dir.glob("checkpoint_epoch_*.pt"))
            for checkpoint in all_checkpoints - checkpoints_to_keep:
                # Remove both the checkpoint and its metadata file
                checkpoint.unlink()
                metadata_file = checkpoint.with_suffix('.txt')
                if metadata_file.exists():
                    metadata_file.unlink()
                logging.info(f"Removed checkpoint: {checkpoint}")

        except Exception as e:
            logging.warning(f"Error during checkpoint cleanup: {str(e)}")

    def save_model(self, model_path: Path, eval_loss: float):
        """Write the model weights, config and vocabulary to the models directory.

        The destination is the Drive models folder when ``is_colab()`` is true
        and ``./LLM/models`` otherwise; ``model_path`` is accepted but not
        consulted. The file name carries the 1-based epoch, the loss and a
        timestamp, and a ``.txt`` summary is written beside the ``.pt`` file.

        Returns:
            Path of the ``.pt`` file that was written.
        """
        stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        loss_tag = f"{eval_loss:.4f}".replace('.', '_')

        # Pick the output folder for the current environment
        if is_colab():
            target_dir = Path('/content/drive/MyDrive/LLM/models')
        else:
            target_dir = Path('./LLM/models')
        target_dir.mkdir(parents=True, exist_ok=True)

        unique_model_path = target_dir / f'trained_model_epoch_{self.epoch + 1}_loss_{loss_tag}_{stamp}.pt'

        # Everything needed to rebuild the model and its vocabulary
        payload = dict(
            model_state_dict=self.model.state_dict(),
            config=self.config.__dict__,
            epoch=self.epoch + 1,
            final_loss=eval_loss,
            vocab_size=self.data_manager.vocab_size,
            char_to_idx=self.data_manager.char_to_idx,
            idx_to_char=self.data_manager.idx_to_char,
            timestamp=stamp,
        )
        torch.save(payload, unique_model_path)
        logging.info(f"Model saved to {unique_model_path}")

        # Plain-text summary next to the weights
        with open(unique_model_path.with_suffix('.txt'), 'w') as summary:
            emit = summary.write
            emit(f"Model Information:\n")
            emit(f"Epoch: {self.epoch + 1}\n")
            emit(f"Final Loss: {eval_loss:.6f}\n")
            emit(f"Training completed at: {stamp}\n")
            emit(f"Vocabulary size: {self.data_manager.vocab_size}\n")
            emit(f"Model parameters: {sum(p.numel() for p in self.model.parameters())/1e6:.2f}M\n")
            emit("\nModel Configuration:\n")
            for option, setting in self.config.__dict__.items():
                emit(f"{option}: {setting}\n")

        return unique_model_path

    def load_checkpoint(self, checkpoint_path: Union[str, Path]):
        """Load training checkpoint"""
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        self.epoch = checkpoint['epoch']
        self.global_step = checkpoint['global_step']

        logging.info(f"Loaded checkpoint from epoch {self.epoch}")

    def train_epoch(self) -> float:
        """Train for one pass over the training data.

        Runs ``n_sequences // batch_size`` training steps, showing a progress
        bar with the latest loss, the running average and the learning rate.
        Every ``config.eval_steps`` batches the average loss of that window is
        compared with ``config.within_epoch_loss_threshold`` and the epoch
        ends early once it drops below.

        Returns:
            Mean of the per-batch losses, or ``inf`` when no batch was run.

        Raises:
            Exception: Any failure is logged and re-raised.
        """
        self.model.train()
        total_loss = 0.0
        running_loss = 0.0
        batch_count = 0

        try:
            # train_data is indexed as a pair; its first element holds the sequences
            n_sequences = len(self.data_manager.train_data[0])
            num_batches = n_sequences // self.config.batch_size

            # The context manager closes the bar on normal exit, early break,
            # exceptions and keyboard interrupts alike.
            with tqdm(range(num_batches),
                      desc=f'Epoch {self.epoch + 1}/{self.config.epochs}',
                      dynamic_ncols=True) as progress_bar:
                for _ in progress_bar:
                    # Get batch and perform one training step
                    x, y = self.data_manager.get_batch('train', self.config.batch_size)
                    batch_loss = self.train_step(x, y)

                    # Update metrics
                    total_loss += batch_loss
                    running_loss += batch_loss
                    batch_count += 1

                    # Update progress bar only once per iteration
                    avg_loss = total_loss / batch_count
                    progress_bar.set_postfix({
                        'loss': f'{batch_loss:.4f}',
                        'avg_loss': f'{avg_loss:.4f}',
                        'lr': f'{self.optimizer.param_groups[0]["lr"]:.6f}'
                    })

                    if batch_count % self.config.eval_steps == 0:
                        # Average over the window since the previous check
                        current_avg_loss = running_loss / self.config.eval_steps
                        running_loss = 0.0

                        if current_avg_loss < self.config.within_epoch_loss_threshold:
                            logging.info(f"Loss threshold reached: {current_avg_loss:.4f}")
                            break

            return total_loss / batch_count if batch_count > 0 else float('inf')

        except Exception as e:
            logging.error(f"Error in train_epoch: {str(e)}")
            raise

    def train_step(self, x: torch.Tensor, y: torch.Tensor) -> float:
        """Perform one training step with gradient accumulation"""
        accumulated_loss = 0.0
        micro_batch_size = self.config.batch_size // self.config.gradient_accumulation_steps

        for i in range(self.config.gradient_accumulation_steps):
            start_idx = i * micro_batch_size
            end_idx = start_idx + micro_batch_size
            micro_x = x[start_idx:end_idx]
            micro_y = y[start_idx:end_idx]

            with torch.cuda.amp.autocast() if self.config.use_amp else nullcontext():
                logits, loss = self.model(micro_x, targets=micro_y)
                loss = loss / self.config.gradient_accumulation_steps

            accumulated_loss += loss.item() * self.config.gradient_accumulation_steps

            if self.config.use_amp:
                self.scaler.scale(loss).backward()
            else:
                loss.backward()

            if (i + 1) % self.config.gradient_accumulation_steps == 0:
                if self.config.max_grad_norm > 0:
                    if self.config.use_amp:
                        self.scaler.unscale_(self.optimizer)
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.config.max_grad_norm)

                if self.config.use_amp:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                else:
                    self.optimizer.step()

                self.optimizer.zero_grad(set_to_none=True)

                if isinstance(self.scheduler, torch.optim.lr_scheduler.LambdaLR):
                    self.scheduler.step()

                self.global_step += 1

        return accumulated_loss

    @torch.no_grad()
    def evaluate(self, split: str = 'val') -> float:
        """Evaluate model on validation or test set"""
        was_training = self.model.training
        self.model.eval()

        try:
            total_loss = 0.0
            num_batches = 0

            data = getattr(self.data_manager, f'{split}_data')
            total_samples = data[0].size(0)

            for i in range(0, total_samples, self.config.batch_size):
                batch_x = data[0][i:i+self.config.batch_size]
                batch_y = data[1][i:i+self.config.batch_size]

                if len(batch_x) == 0:
                    continue

                _, loss = self.model(batch_x.to(self.device),
                                  targets=batch_y.to(self.device))
                total_loss += loss.item()
                num_batches += 1

            return total_loss / num_batches if num_batches > 0 else float('inf')

        finally:
            # Restore original training state
            self.model.train(was_training)

# Cell-completion marker: printed once the Trainer class above has been defined
print("Section 5: Training system setup complete")

Section 5: Training system setup complete


In [None]:
# RUN FOR TRAINING -> Pipeline!

import logging
import torch
import datetime
from pathlib import Path
from typing import Union
import gc

class TrainingPipeline:
    """Handles the overall training pipeline."""

    def __init__(self, model: GPT, config: ModelConfig):
        """Load the data and wire up a Trainer for the given model.

        Args:
            model: Network to train.
            config: Configuration handed to both the DataManager and the Trainer.

        Raises:
            ValueError: If the data manager has no training data after loading.
        """
        self.config = config
        self.model = model

        # Build the data manager and pull the datasets in straight away
        loader = DataManager(config)
        self.data_manager = loader
        loader.load_data()

        # Refuse to continue without training data
        if loader.train_data is None:
            raise ValueError("Training data failed to load")

        # The trainer shares this pipeline's model, data and config
        self.trainer = Trainer(model, loader, config)

    def run(self):
        """Run the training pipeline."""
        try:
            # Log training parameters before starting training
            self.log_training_params()

            # Step 3: Train the model
            logging.info("Starting training...")
            self.trainer.train()  # This will handle the entire training process

        except ValueError as ve:
            logging.error(f"ValueError: {str(ve)}")
            raise
        except FileNotFoundError as fnf_error:
            logging.error(f"File not found: {str(fnf_error)}")
            raise
        except torch.cuda.OutOfMemoryError:
            logging.error("CUDA out of memory. Reduce batch size or use a smaller model.")
            raise
        except Exception as e:
            logging.error(f"Error running training pipeline: {str(e)}")
            raise

    def train_and_evaluate(self):
        """Train for ``config.epochs`` epochs, checkpointing each one, then save the model.

        After every epoch the validation loss is computed and a checkpoint is
        written (flagged as best when the loss improved). The trainer's epoch
        counter, loss histories and best loss are kept current so progress
        bars, checkpoint names and checkpoint contents reflect the real epoch.
        On Ctrl-C the current state is evaluated and saved instead of being
        lost. Other errors are printed and re-raised. Memory is released on
        the way out in every case.
        """
        # Same location rule as save_model: Drive on Colab, ./LLM otherwise
        on_colab = is_colab()
        model_path = Path('/content/drive/MyDrive/LLM/models') if on_colab else Path('./LLM/models')
        checkpoint_path = Path('/content/drive/MyDrive/LLM/checkpoints') if on_colab else Path('./LLM/checkpoints')

        for path in [model_path, checkpoint_path]:
            path.mkdir(parents=True, exist_ok=True)

        try:
            # Print final derived configuration
            print("\nFinal Configuration:")
            for key, value in vars(self.config).items():
                print(f"{key}: {value}")
            print("-" * 50)

            print("\nStarting training with:")
            print(f"- Model size: {sum(p.numel() for p in self.model.parameters()) / 1e6:.2f}M parameters")
            print(f"- Device: {self.config.device}")
            print(f"- Batch size: {self.config.batch_size}")
            print(f"- Learning rate: {self.config.learning_rate}")

            # Train model
            best_val_loss = float('inf')
            for epoch in range(self.config.epochs):
                # The trainer labels its progress bar and checkpoint files
                # with its own epoch counter, so advance it here.
                self.trainer.epoch = epoch

                train_loss = self.trainer.train_epoch()
                val_loss = self.trainer.evaluate('val')

                # Keep the histories that save_checkpoint serialises up to date
                self.trainer.train_losses.append(train_loss)
                self.trainer.val_losses.append(val_loss)

                # Save checkpoint after each epoch
                is_best = val_loss < best_val_loss
                if is_best:
                    best_val_loss = val_loss
                    self.trainer.best_val_loss = val_loss
                self.trainer.save_checkpoint(val_loss, is_best)

                print(f"Epoch {epoch+1}: Train Loss = {train_loss:.4f}, Val Loss = {val_loss:.4f}")

            # Final evaluation and save
            final_eval_loss = self.trainer.evaluate('val')
            print(f"\nFinal validation loss: {final_eval_loss:.4f}")

            # Save final model
            final_model_path = self.save_model(model_path, final_eval_loss)
            print(f"\nFinal model saved to: {final_model_path}")

        except KeyboardInterrupt:
            print("\nInterrupted by user. Saving current state...")
            current_val_loss = self.trainer.evaluate('val')
            self.trainer.save_checkpoint(current_val_loss, False)
            self.save_model(model_path, current_val_loss)

        except Exception as e:
            print(f"\nError occurred: {str(e)}")
            raise

        finally:
            print("\nCleaning up...")
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

    def save_model(self, model_path: Path, eval_loss: float):
        """Write the model weights, config and vocabulary to the models directory.

        The destination is the Drive models folder when ``is_colab()`` is true
        and ``./LLM/models`` otherwise; ``model_path`` is accepted but not
        consulted. The file name carries the loss (``1.3369`` becomes
        ``1_3369``) and a timestamp, and a ``.txt`` summary is written beside
        the ``.pt`` file.

        Returns:
            Path of the ``.pt`` file that was written.
        """
        stamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        loss_tag = f"{eval_loss:.4f}".replace('.', '_')

        # Pick the output folder for the current environment
        if is_colab():
            target_dir = Path('/content/drive/MyDrive/LLM/models')
        else:
            target_dir = Path('./LLM/models')
        target_dir.mkdir(parents=True, exist_ok=True)

        unique_model_path = target_dir / f'trained_model_loss_{loss_tag}_{stamp}.pt'

        # Everything needed to rebuild the model and its vocabulary
        payload = dict(
            model_state_dict=self.model.state_dict(),
            config=self.config.__dict__,
            final_loss=eval_loss,
            vocab_size=self.data_manager.vocab_size,
            char_to_idx=self.data_manager.char_to_idx,
            idx_to_char=self.data_manager.idx_to_char,
        )
        torch.save(payload, unique_model_path)
        logging.info(f"Model saved to {unique_model_path}")

        # Plain-text summary next to the weights
        with open(unique_model_path.with_suffix('.txt'), 'w') as summary:
            emit = summary.write
            emit(f"Training completed at: {stamp}\n")
            emit(f"Final validation loss: {eval_loss:.4f}\n")
            emit(f"Vocabulary size: {self.data_manager.vocab_size}\n")
            emit(f"Model parameters: {sum(p.numel() for p in self.model.parameters())/1e6:.2f}M\n")
            emit("\nModel Configuration:\n")
            for option, setting in vars(self.config).items():
                emit(f"{option}: {setting}\n")

        return unique_model_path

In [None]:
# Entry point: build the configuration, size the vocabulary from the data,
# construct the model and run the training pipeline.
if __name__ == "__main__":
    try:
        # Initialize the configuration with optimized values for small LLM
        config = ModelConfig(
            # Model Architecture
            vocab_size=0,  # Placeholder; overwritten below once the data has been scanned
            block_size=256,  # Context window size
            n_layer=6,     # Transformer layers
            n_head=8,      # Attention heads
            n_embed=384,   # Embedding dimension
            ff_dim=1536,   # Feed-forward dimension (4x n_embed)
            head_dim=64,   # New: dimension per attention head

            # Regularization
            dropout=0.2,
            attn_pdrop=0.1,   # New: specific attention dropout
            resid_pdrop=0.1,  # New: residual dropout
            embd_pdrop=0.1,   # New: embedding dropout
            weight_decay=0.1,

            # Architecture Features
            bias=True,
            flash_attn=True,  # New: use flash attention if available
            use_gradient_checkpointing=True,  # Memory efficiency

            # Training Parameters
            batch_size=16,
            gradient_accumulation_steps=2,
            learning_rate=0.0001,
            min_learning_rate=1e-5,  # New: minimum learning rate
            warmup_steps=500,
            epochs=100,
            lr_decay_epochs=30,

            # Early Stopping
            loss_threshold=0.3,
            within_epoch_loss_threshold=0.3,
            patience=5,        # New: early stopping patience
            min_delta=1e-4,   # New: minimum change for improvement

            # Evaluation and Checkpointing
            eval_steps=1000,
            save_every=5,
            keep_last_n_checkpoints=5,

            # Generation Parameters
            gen_temperature=0.8,
            max_gen_tokens=64,
            top_k=50,
            top_p=0.9,        # New: nucleus sampling parameter

            # System and Memory
            device='cuda' if torch.cuda.is_available() else 'cpu',
            pin_memory=True,
            num_workers=8,
            prefetch_factor=2,

            # Precision
            # NOTE(review): Trainer.train_step in this notebook calls
            # torch.cuda.amp.autocast() without a dtype argument, so amp_dtype
            # does not appear to be consumed there - confirm where it is used.
            use_amp=True,
            amp_dtype=torch.bfloat16,
            dtype=torch.bfloat16
        )

        # Initialize data pipeline with error handling
        logging.info("Initializing data manager...")
        data_manager = DataManager(config)

        # Load and process data
        logging.info("Loading and processing data...")
        data_manager.load_data()

        # Calculate vocabulary size
        vocab_size = data_manager.calculate_vocab_size()
        logging.info(f"Calculated vocabulary size: {vocab_size}")

        # Analyze vocabulary
        unique_characters = list(data_manager.char_to_idx.keys())
        logging.info(f"Found {len(unique_characters)} unique characters")

        # The full character list is only emitted when DEBUG logging is on
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            logging.debug("Unique characters:")
            logging.debug(unique_characters)

        # Update config with calculated vocab size (must happen before GPT is built)
        config.vocab_size = vocab_size
        logging.info(f"Updated config vocabulary size: {config.vocab_size}")

        # Initialize model with memory optimization
        # NOTE(review): autocast here only wraps construction and device
        # placement - confirm it has any effect at this point.
        logging.info("Initializing model...")
        with torch.cuda.amp.autocast(enabled=config.use_amp):
            model = GPT(config).to(config.device)

        # Log model parameters
        total_params = sum(p.numel() for p in model.parameters())
        logging.info(f"Model initialized with {total_params/1e6:.2f}M parameters")

        # Optional: Print model architecture in debug mode
        if logging.getLogger().isEnabledFor(logging.DEBUG):
            logging.debug("Model Architecture:")
            logging.debug(str(model))

        # Initialize and run training pipeline with error handling
        # TrainingPipeline builds its own DataManager and calls load_data()
        # again, so the data is loaded a second time here; the data_manager
        # above is only used to size the vocabulary.
        logging.info("Initializing training pipeline...")
        pipeline = TrainingPipeline(model, config)

        # Run training with resource monitoring
        logging.info("Starting training...")
        try:
            pipeline.train_and_evaluate()
        except KeyboardInterrupt:
            # train_and_evaluate already catches KeyboardInterrupt itself, so
            # this handler is only reached when an interrupt escapes it
            # (for example one raised while its own handler is running).
            logging.info("Training interrupted by user. Saving checkpoint...")
            pipeline.trainer.save_checkpoint(
                pipeline.trainer.evaluate('val'),
                is_best=False,
                custom_path=Path(config.save_dir) / "interrupted_checkpoint.pt"
            )

        # Reached after a normal run and also after an interrupt that was handled above
        logging.info("Training completed successfully")

    except Exception as e:
        # Log with traceback, then let the error surface in the notebook
        logging.error(f"Error during execution: {str(e)}", exc_info=True)
        raise
    finally:
        # Cleanup: release cached GPU memory (when CUDA is present) and collect garbage
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()

Number of parameters: 10.73M

Final Configuration:
save_dir: /content/drive/MyDrive/LLM/checkpoints
vocab_size: 100
block_size: 256
n_embed: 384
n_head: 8
n_layer: 6
ff_dim: 1536
dropout: 0.2
bias: True
flash_attn: True
head_dim: 64
attn_pdrop: 0.1
resid_pdrop: 0.1
embd_pdrop: 0.1
batch_size: 16
gradient_accumulation_steps: 2
learning_rate: 0.0001
min_learning_rate: 1e-05
weight_decay: 0.1
epochs: 100
warmup_steps: 500
max_grad_norm: 1.0
within_epoch_loss_threshold: 0.3
eval_steps: 1000
lr_schedule: cosine_with_warmup
lr_decay_epochs: 30
use_amp: True
amp_dtype: torch.bfloat16
dtype: torch.bfloat16
use_gradient_checkpointing: True
pin_memory: True
num_workers: 8
prefetch_factor: 2
gen_temperature: 0.8
max_gen_tokens: 64
top_k: 50
top_p: 0.9
patience: 5
min_delta: 0.0001
loss_threshold: 0.3
eval_every: 1000
save_every: 5
keep_last_n_checkpoints: 5
device: cuda
--------------------------------------------------

Starting training with:
- Model size: 10.73M parameters
- Device: cuda
- Bat

Epoch 1/100: 100%|██████████| 546/546 [00:41<00:00, 13.31it/s, loss=3.9880, avg_loss=5.1647, lr=0.000100]


Epoch 1: Train Loss = 5.1647, Val Loss = 1.9076


Epoch 1/100: 100%|██████████| 546/546 [00:40<00:00, 13.46it/s, loss=3.3794, avg_loss=3.7313, lr=0.000100]


Epoch 2: Train Loss = 3.7313, Val Loss = 1.6483


Epoch 1/100: 100%|██████████| 546/546 [00:40<00:00, 13.42it/s, loss=3.1667, avg_loss=3.4077, lr=0.000100]


Epoch 3: Train Loss = 3.4077, Val Loss = 1.5358


Epoch 1/100: 100%|██████████| 546/546 [00:40<00:00, 13.44it/s, loss=3.1511, avg_loss=3.1972, lr=0.000100]


Epoch 4: Train Loss = 3.1972, Val Loss = 1.4599


Epoch 1/100: 100%|██████████| 546/546 [00:40<00:00, 13.36it/s, loss=2.9809, avg_loss=3.0695, lr=0.000100]


Epoch 5: Train Loss = 3.0695, Val Loss = 1.4150


Epoch 1/100:  55%|█████▌    | 301/546 [00:22<00:18, 13.39it/s, loss=2.9770, avg_loss=2.9820, lr=0.000000]



Interrupted by user. Saving current state...

Cleaning up...
