In [None]:
import os
import sys
import io
import time
import random
import json
import gc
import warnings
from pathlib import Path
from datetime import timedelta
import psutil
import numpy as np
from contextlib import contextmanager, nullcontext
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm  # or tqdm.auto if needed
from typing import List, Union, Optional, Dict, Any
from dataclasses import dataclass
import threading
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import tempfile


# Check if running in Google Colab
def is_colab():
    """Return True when this kernel is a Google Colab runtime, else False.

    Outside IPython the ``get_ipython`` builtin does not exist, which is treated
    as "not Colab".
    """
    try:
        shell = get_ipython()
    except NameError:
        return False
    return 'google.colab' in str(shell)

# In Colab, mount Google Drive so artefacts survive the runtime being recycled
if is_colab():
    from google.colab import drive
    drive.mount('/content/drive')

# Project root: the LLM folder on Drive inside Colab, ./LLM everywhere else
if is_colab():
    base_dir = Path('/content/drive/MyDrive/LLM')
else:
    base_dir = Path('./LLM')

# Create the standard sub-directories (no-op when they already exist)
for dir_name in ('checkpoints', 'models', 'logs', 'configs', 'data'):
    base_dir.joinpath(dir_name).mkdir(parents=True, exist_ok=True)

def is_package_installed(package_name):
    """Report whether ``package_name`` can be imported in this interpreter.

    The module is actually imported as part of the check; any ImportError
    (including ModuleNotFoundError) counts as "not installed".
    """
    installed = True
    try:
        __import__(package_name)
    except ImportError:
        installed = False
    return installed

if is_colab():
    import subprocess

    # Each group maps a pip distribution name to the module that distribution
    # installs. The two differ for some packages (`gputil` -> `GPUtil`,
    # `nvidia_ml_py3` -> `pynvml`); probing the pip name instead would report
    # those as missing and reinstall them on every run.
    # NOTE(review): `torch` is already imported at the top of this cell, so the
    # torch entry can only ever be found installed here.
    install_groups = [
        # PyTorch packages
        {'torch': 'torch', 'torchvision': 'torchvision', 'torchaudio': 'torchaudio'},
        # Additional packages
        {
            'pynvml': 'pynvml',
            'nvidia_ml_py3': 'pynvml',
            'gputil': 'GPUtil',
            'fastapi': 'fastapi',
            'uvicorn': 'uvicorn',
            'pydantic': 'pydantic',
            'sentencepiece': 'sentencepiece',
        },
    ]
    for install_group in install_groups:
        missing = [pip_name for pip_name, module_name in install_group.items()
                   if not is_package_installed(module_name)]
        if missing:
            # `sys.executable -m pip` installs into the running kernel's
            # interpreter (a bare `!pip` may resolve to a different one). Kept
            # best-effort like the original shell call: a failure is reported
            # but does not abort the cell.
            result = subprocess.run([sys.executable, '-m', 'pip', 'install', *missing])
            if result.returncode != 0:
                print(f"pip install failed for: {' '.join(missing)}")

# Suppress warnings
warnings.filterwarnings('ignore')

print("Section 1: Initial setup and core components complete")

In [None]:
@dataclass
class ModelConfig:
    """Configuration for model training and architecture.

    Every setting is an annotated dataclass field, so it can be passed to the
    constructor and appears in the instance ``__dict__`` (and therefore in
    :meth:`print_params`, :meth:`save` and :meth:`load`).
    """
    # Model Architecture
    vocab_size: int = 0
    block_size: int = 64
    n_embed: int = 64
    n_head: int = 4
    n_layer: int = 4
    ff_dim: int = 256
    head_dim: int = 32

    # Regularization
    resid_pdrop: float = 0.2
    weight_decay: float = 0.1

    # Architecture Features
    bias: bool = True
    flash_attn: bool = True
    use_gradient_checkpointing: bool = False

    # Training Parameters
    batch_size: int = 32
    gradient_accumulation_steps: int = 1
    epochs: int = 10
    max_grad_norm: float = 1.0
    gradient_clip_val: float = 1.0
    log_interval: int = 100  # Added this parameter

    # Enhanced Learning Rate Parameters
    learning_rate: float = 1e-4
    min_learning_rate: float = 1e-5
    warmup_steps: int = 200
    lr_schedule: str = 'cosine_with_warmup'  # Options: 'cosine_with_warmup', 'linear_with_warmup', 'step'
    lr_decay_epochs: int = 8  # For step scheduler
    warmup_ratio: float = 0.001  # Alternative to warmup_steps (ratio of total training steps)

    # Early Stopping
    patience: int = 5
    min_delta: float = 1e-4
    loss_threshold: float = 1.0
    within_epoch_loss_threshold: float = 0.3

    # Evaluation and Checkpointing
    eval_steps: int = 10000
    eval_every: int = 10000
    save_every: int = 1
    keep_last_n_checkpoints: int = 5

    # Precision
    use_amp: bool = True
    amp_dtype: torch.dtype = torch.bfloat16
    dtype: torch.dtype = torch.bfloat16

    # System Parameters
    pin_memory: bool = True
    device: str = "cuda"
    num_workers: int = 8
    prefetch_factor: int = 5

    # Paths
    save_dir: Union[str, Path] = Path("/content/drive/MyDrive/LLM/checkpoints")
    # Annotated so these are real dataclass fields (constructor arguments that
    # round-trip through save()/load()); unannotated they were plain class
    # attributes that could not be overridden via __init__.
    tokenizer_model_path: str = '/content/drive/MyDrive/LLM/configs/sentencepiece.model'
    training_data_path: str = '/content/drive/MyDrive/LLM/data/training.txt'
    validation_data_path: str = '/content/drive/MyDrive/LLM/data/validation.txt'

    # Generation Parameters
    gen_temperature: float = 0.8
    max_gen_tokens: int = 256
    top_k: int = 50
    top_p: float = 0.9

    def __post_init__(self):
        """Pick device/dtype for this machine and make sure ``save_dir`` exists.

        ``device`` and ``dtype`` are always overwritten: CUDA + bfloat16 when a
        GPU is available, otherwise CPU + float32.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        if not isinstance(self.save_dir, Path):
            self.save_dir = Path(self.save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def print_params(self) -> None:
        """Print model and training parameters"""
        print("Model Parameters:")
        for key, value in self.__dict__.items():
            print(f"  {key}: {value}")

    def save(self, path: Union[str, Path]) -> None:
        """Save configuration to JSON.

        ``Path`` and ``torch.dtype`` values are stored as their ``str()`` form
        (e.g. ``"torch.bfloat16"``); :meth:`load` converts them back.
        """
        path = Path(path)
        config_dict = {
            k: str(v) if isinstance(v, (Path, torch.dtype)) else v
            for k, v in self.__dict__.items()
        }
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open('w') as f:
            json.dump(config_dict, f, indent=4)

    @classmethod
    def load(cls, path: Union[str, Path]) -> 'ModelConfig':
        """Load configuration from JSON written by :meth:`save`.

        Keys that are not constructor fields (e.g. attributes attached to an
        instance at runtime) are ignored rather than raising TypeError.
        """
        from dataclasses import fields

        path = Path(path)
        with path.open('r') as f:
            config_dict = json.load(f)

        # save() stringified every torch.dtype field; map each back to the dtype.
        for key in ('dtype', 'amp_dtype'):
            value = config_dict.get(key)
            if isinstance(value, str):
                config_dict[key] = getattr(torch, value.split('.')[-1])
        if 'save_dir' in config_dict:
            config_dict['save_dir'] = str(config_dict['save_dir'])

        init_names = {f.name for f in fields(cls) if f.init}
        return cls(**{k: v for k, v in config_dict.items() if k in init_names})

In [None]:

# NOTE(review): this cell is intentionally inert. The VocabularyOptimizer below
# lives inside a bare string literal, so none of it executes. Known issues to
# fix before re-enabling it:
#   * it subclasses ModelConfig but never calls the dataclass __init__;
#   * the training file is read fully into memory twice, and the second read
#     (`open(self.data_path).read()` in _evaluate_vocab_size) never closes its
#     file handle;
#   * `_calculate_entropy` ignores its `vocab_size` argument;
#   * `optimize_vocab_size` breaks at the first candidate that does not improve
#     the score, so later candidates are never evaluated.
'''
from pathlib import Path
import sentencepiece as spm
import math
import os
from typing import List, Tuple
import random

class VocabularyOptimizer(ModelConfig):
    def __init__(self, data_path: Path, max_candidates=5):
        self.data_path = data_path
        self.max_candidates = max_candidates
        self.dataset_stats = self._analyze_dataset()

    def _analyze_dataset(self) -> dict:
        """Analyze dataset characteristics to inform vocab size selection"""
        with open(self.data_path, 'r', encoding='utf-8') as f:
            text = f.read()

        return {
            'size_mb': os.path.getsize(self.data_path) / (1024 ** 2),
            'total_chars': len(text),
            'unique_chars': len(set(text)),
            'avg_word_length': self._calculate_avg_word_length(text)
        }

    def _calculate_avg_word_length(self, text: str) -> float:
        """Calculate average word length for whitespace tokenization"""
        words = text.split()
        return sum(len(word) for word in words) / len(words) if words else 0

    def _generate_vocab_candidates(self) -> List[int]:
        """Generate candidate vocab sizes based on dataset characteristics"""
        base_size = max(
            self.dataset_stats['unique_chars'] * 2,  # Minimum coverage
            int(self.dataset_stats['total_chars'] ** 0.5)  # Square root heuristic
        )
        return sorted({
            base_size,
            base_size * 2,
            base_size * 4,
            int(2 ** (math.log2(base_size) + 1)),
            min(32000, base_size * 8)  # Upper bound
        })

    def _evaluate_vocab_size(self, vocab_size: int) -> Tuple[float, float]:
        """Train and evaluate a single vocab size using perplexity"""
        model_prefix = f"temp_model_{vocab_size}"

        # Train SentencePiece model
        spm.SentencePieceTrainer.train(
            input=str(self.data_path),
            model_prefix=model_prefix,
            vocab_size=vocab_size,
            character_coverage=1.0,
            model_type='bpe',
            num_threads=os.cpu_count()
        )

        # Evaluate model quality
        sp_model = spm.SentencePieceProcessor()
        sp_model.load(f"{model_prefix}.model")

        # Calculate compression ratio and entropy
        encoded = sp_model.encode_as_ids(open(self.data_path).read())
        entropy = self._calculate_entropy(encoded, vocab_size)
        compression_ratio = len(encoded) / self.dataset_stats['total_chars']

        # Cleanup temporary model
        os.remove(f"{model_prefix}.model")
        os.remove(f"{model_prefix}.vocab")

        return entropy, compression_ratio

    def _calculate_entropy(self, token_ids: List[int], vocab_size: int) -> float:
        """Calculate empirical entropy of token distribution"""
        from collections import Counter
        counts = Counter(token_ids)
        total = len(token_ids)
        return -sum((count/total) * math.log2(count/total)
                  for count in counts.values() if count > 0)

    def optimize_vocab_size(self) -> int:
        """Main optimization pipeline with early stopping"""
        candidates = self._generate_vocab_candidates()[:self.max_candidates]
        best_size = None
        best_score = float('inf')
        results = []

        for vocab_size in sorted(candidates):
            try:
                entropy, compression = self._evaluate_vocab_size(vocab_size)
                # Combined score weights entropy 70% and compression 30%
                score = 0.7 * entropy + 0.3 * compression

                results.append((vocab_size, score, entropy, compression))

                if score < best_score:
                    best_score = score
                    best_size = vocab_size
                else:
                    # Early stopping if performance plateaus
                    break

            except Exception as e:
                print(f"Failed training with vocab_size {vocab_size}: {str(e)}")
                continue

        # Print formatted results table
        print("Vocab Size\tTotal Score\tEntropy\tCompression Ratio")
        for result in results:
            print(f"{result[0]}\t{result[1]:.4f}\t{result[2]:.4f}\t{result[3]:.4f}")

        return best_size

# Usage example
if __name__ == "__main__":
    optimizer = VocabularyOptimizer(Path(ModelConfig.training_data_path))
    optimal_vocab_size = optimizer.optimize_vocab_size()
    print(f"Recommended vocabulary size: {optimal_vocab_size}")

    '''

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
from typing import Union
import sentencepiece as spm


class TokenDataset(Dataset):
    """Sliding-window dataset over a flat sequence of token ids.

    Item ``i`` is ``(tokens[i : i + block_size], tokens[i + 1 : i + block_size + 1])``,
    i.e. an input window and the same window shifted one token to the right, both
    returned as fresh ``torch.long`` tensors.
    """
    def __init__(self, tokens: list, block_size: int):
        # Store as one contiguous int64 tensor (8 bytes/token) instead of a list of
        # Python ints. This also avoids refcount-driven copy-on-write duplication
        # of the whole list in every forked DataLoader worker.
        self.tokens = torch.as_tensor(tokens, dtype=torch.long)
        self.block_size = block_size

    def __len__(self):
        # A sequence no longer than block_size has no complete window; clamp at 0
        # because len() must never be negative.
        return max(0, len(self.tokens) - self.block_size)

    def __getitem__(self, idx: int):
        x = self.tokens[idx : idx + self.block_size]
        y = self.tokens[idx + 1 : idx + self.block_size + 1]
        # clone() so callers get independent tensors, as before, not views into
        # the shared token buffer.
        return x.clone(), y.clone()

class DataManager:
    """Builds the SentencePiece tokenizer and the train/validation DataLoaders.

    Paths, vocabulary size, block size and loader settings are read from
    ``config``.
    """
    def __init__(self, config: ModelConfig):
        self.config = config
        self._tokenizer_initialized = False
        self.tokenizer = None
        self.train_loader = None
        self.val_loader = None
        self.train_dataset = None
        self.val_dataset = None
        # Piece <-> id lookups and vocabulary size, filled by initialize_tokenizer().
        # Defined up front so callers can inspect them before training a tokenizer.
        self.char_to_idx = {}
        self.idx_to_char = {}
        self.vocab_size = None

    def initialize_tokenizer(self):
        """Train a SentencePiece BPE tokenizer on the training data and load it.

        The model files are written with ``config.tokenizer_model_path`` minus its
        ``.model`` suffix as the prefix, so the trained model lands exactly at the
        configured path (SentencePiece appends ``.model``/``.vocab`` itself).

        Raises:
            Exception: anything SentencePiece raises (e.g. for an invalid
                ``vocab_size``) is printed and re-raised.
        """
        try:
            model_path = Path(self.config.tokenizer_model_path)
            model_path.parent.mkdir(parents=True, exist_ok=True)
            model_prefix = model_path.with_suffix('') if model_path.suffix == '.model' else model_path

            spm.SentencePieceTrainer.train(
                input=str(self.config.training_data_path),
                model_prefix=str(model_prefix),
                vocab_size=self.config.vocab_size,
                character_coverage=1.0,
                model_type='bpe'
            )

            self.tokenizer = spm.SentencePieceProcessor()
            self.tokenizer.load(str(model_prefix) + '.model')

            piece_count = self.tokenizer.get_piece_size()
            self.idx_to_char = {i: self.tokenizer.id_to_piece(i) for i in range(piece_count)}
            self.char_to_idx = {piece: i for i, piece in self.idx_to_char.items()}
            self.vocab_size = piece_count

            self._tokenizer_initialized = True

        except Exception as e:
            print(f"Error initializing tokenizer: {str(e)}")
            raise

    def _load_and_tokenize(self, file_path: Union[str, Path]) -> list:
        """Tokenize a UTF-8 text file into a flat list of token IDs.

        Each non-blank line is stripped and encoded on its own; the per-line token
        sequences are concatenated without any separator token.
        """
        tokens = []
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                text = line.strip()
                if text:
                    tokens.extend(self.tokenizer.encode(text, out_type=int))
        return tokens

    def _loader_kwargs(self) -> dict:
        """DataLoader keyword arguments shared by the train and validation loaders."""
        kwargs = {
            'batch_size': self.config.batch_size,
            'num_workers': self.config.num_workers,
            'pin_memory': self.config.pin_memory,
        }
        # prefetch_factor only applies to worker processes; DataLoader raises
        # ValueError if it is given while num_workers == 0.
        if self.config.num_workers > 0:
            kwargs['prefetch_factor'] = self.config.prefetch_factor
        return kwargs

    def load_data(self):
        """Load and tokenize training/validation data into DataLoaders.

        Raises:
            ValueError: if the training file yields no more tokens than
                ``block_size`` (no training window could be formed).
        """
        if not self._tokenizer_initialized:
            self.initialize_tokenizer()

        # Load training data
        train_tokens = self._load_and_tokenize(self.config.training_data_path)
        if len(train_tokens) <= self.config.block_size:
            raise ValueError(
                f"Training data has {len(train_tokens)} tokens; need more than "
                f"block_size={self.config.block_size} to form a training window"
            )
        self.train_dataset = TokenDataset(train_tokens, self.config.block_size)
        self.train_loader = DataLoader(self.train_dataset, shuffle=True, **self._loader_kwargs())

        # Load validation data
        val_tokens = self._load_and_tokenize(self.config.validation_data_path)
        self.val_dataset = TokenDataset(val_tokens, self.config.block_size)
        self.val_loader = DataLoader(self.val_dataset, shuffle=False, **self._loader_kwargs())

    def cleanup(self):
        """Release loaders, datasets and tokenizer, then free cached CUDA memory.

        Attributes are reset to None rather than deleted, and the tokenizer flag is
        cleared, so the manager stays usable: a later load_data() call simply
        re-initialises the tokenizer.
        """
        self.train_loader = None
        self.val_loader = None
        self.train_dataset = None
        self.val_dataset = None
        self.tokenizer = None
        self._tokenizer_initialized = False
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

In [None]:
class RotaryEmbedding(nn.Module):
    """Produces the cos/sin tables used by rotary position embeddings (RoPE).

    Tables cover ``max_position_embeddings`` positions up front and are rebuilt
    for longer sequences on demand. Each returned table has shape
    ``(1, 1, seq_len, dim)`` (frequencies duplicated across both halves of ``dim``).
    """
    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.max_seq_len_cached = max_position_embeddings

        # Plain attributes (not buffers) so they stay out of state_dict and keep
        # float32 precision; forward() moves them to the input's device lazily.
        self.cos_cached, self.sin_cached = self._build_tables(max_position_embeddings, inv_freq.device)

    def _build_tables(self, seq_len, device):
        """Return ``(cos, sin)`` tables of shape (1, 1, seq_len, dim) for positions 0..seq_len-1."""
        t = torch.arange(seq_len, device=device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]

    def forward(self, x, seq_len=None):
        """Return ``(cos, sin)`` for the first ``seq_len`` positions on ``x.device``.

        ``seq_len`` defaults to ``x.shape[-2]`` (assumes x is (..., seq, head_dim) —
        the layout CausalSelfAttention passes).
        """
        if seq_len is None:
            seq_len = x.shape[-2]
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            self.cos_cached, self.sin_cached = self._build_tables(seq_len, x.device)
        if self.cos_cached.device != x.device:
            # Move the cache once instead of copying it to the device on every call.
            self.cos_cached = self.cos_cached.to(x.device)
            self.sin_cached = self.sin_cached.to(x.device)

        return (
            self.cos_cached[:, :, :seq_len, ...],
            self.sin_cached[:, :, :seq_len, ...]
        )

def rotate_half(x):
    """Swap the two halves of the last dimension, negating the moved-up half.

    With the split at ``size // 2``: ``[a, b] -> [-b, a]``.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    return torch.cat([back.neg(), front], dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Rotate queries and keys by the RoPE angles encoded in ``cos``/``sin``.

    Computes ``t * cos + rotate_half(t) * sin`` for ``t`` in ``(q, k)``;
    ``cos``/``sin`` must broadcast against both. Returns ``(q_rot, k_rot)``.
    """
    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention layer with RoPE and residual dropout.

    Uses ``F.scaled_dot_product_attention`` when PyTorch provides it and
    ``config.flash_attn`` is true (the default). Otherwise it computes an explicit
    masked softmax.
    """
    def __init__(self, config):
        super().__init__()
        assert config.n_embed % config.n_head == 0

        # Store necessary dimensions (head size is derived; config.head_dim is not read)
        self.n_embed = config.n_embed
        self.n_head = config.n_head
        self.head_dim = config.n_embed // config.n_head

        # Key, query, value projections for all heads
        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embed, config.n_embed, bias=config.bias)

        # Regularization
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        # Initialize RoPE
        self.rope = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.block_size
        )

        # Fused attention when available, unless disabled through config.flash_attn
        self.flash = (
            bool(getattr(config, 'flash_attn', True))
            and hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        )
        if not self.flash:
            # Non-persistent: the mask is a derived constant, so keeping it out of
            # state_dict keeps checkpoints loadable whichever attention path is used.
            self.register_buffer(
                "mask",
                torch.tril(torch.ones(config.block_size, config.block_size))
                .view(1, 1, config.block_size, config.block_size),
                persistent=False
            )

        # Scaling factor for attention
        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(self, x):
        """Apply causal attention to ``x`` of shape (B, T, C), C == n_embed; returns (B, T, C).

        The explicit-softmax fallback only supports T <= block_size (mask size).
        """
        B, T, C = x.size()

        # Calculate query, key, values for all heads
        q, k, v = self.c_attn(x).split(self.n_embed, dim=2)

        # Reshape for multi-head attention: (B, T, C) -> (B, n_head, T, head_dim)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # Apply RoPE to queries and keys
        cos, sin = self.rope(q, seq_len=T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Causal self-attention with flash attention optimization
        if self.flash:
            # `scale` is left at SDPA's default, 1/sqrt(q.size(-1)) == self.scale;
            # the explicit keyword only exists from PyTorch 2.1 onwards.
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=0.0,  # Disabled in favor of residual dropout
                is_causal=True
            )
        else:
            att = (q @ k.transpose(-2, -1)) * self.scale
            att = att.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            y = att @ v

        # Re-assemble and project
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y

class LayerNorm(nn.Module):
    """Layer normalisation over the last dimension, with an optional bias (eps = 1e-5)."""
    def __init__(self, ndim, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        eps = 1e-5
        return F.layer_norm(input, self.weight.shape, weight=self.weight, bias=self.bias, eps=eps)

class GPTBlock(nn.Module):
    """Pre-norm transformer block with residual dropout and learnable layer scales.

    Each sub-layer's output is multiplied by a per-channel scale (initialised to
    0.1) before being added to the residual stream. Sub-layers can optionally be
    run under activation checkpointing.
    """
    def __init__(self, config):
        super().__init__()
        self.ln1 = LayerNorm(config.n_embed, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(config.n_embed, bias=config.bias)
        # NOTE(review): hidden width is 4 * n_embed; config.ff_dim is not used here.
        self.mlp = nn.Sequential(
            nn.Linear(config.n_embed, 4 * config.n_embed, bias=config.bias),
            nn.GELU(),
            nn.Linear(4 * config.n_embed, config.n_embed, bias=config.bias),
            nn.Dropout(config.resid_pdrop)
        )

        # Training optimizations
        self.use_checkpointing = config.use_gradient_checkpointing
        self.layer_scale_1 = nn.Parameter(torch.ones(config.n_embed) * 0.1)
        self.layer_scale_2 = nn.Parameter(torch.ones(config.n_embed) * 0.1)

    @staticmethod
    def _run_sublayer(module, inp, use_checkpoint):
        """Apply ``module`` to ``inp``, recomputing its activations in backward if ``use_checkpoint``."""
        if use_checkpoint:
            # Explicit import: `import torch` alone does not guarantee that the
            # torch.utils.checkpoint submodule is loaded.
            from torch.utils.checkpoint import checkpoint
            # use_reentrant=False is the variant PyTorch recommends; recent
            # releases warn when the argument is left unspecified.
            return checkpoint(module, inp, use_reentrant=False)
        return module(inp)

    def forward(self, x):
        # Attention block (checkpoint only while gradients are being tracked)
        attn_output = self._run_sublayer(self.attn, self.ln1(x), self.use_checkpointing and x.requires_grad)
        # The (C,) scale broadcasts over the (B, T, C) activations.
        x = x + self.layer_scale_1 * attn_output

        # MLP block (condition re-evaluated on the updated residual stream)
        mlp_output = self._run_sublayer(self.mlp, self.ln2(x), self.use_checkpointing and x.requires_grad)
        x = x + self.layer_scale_2 * mlp_output

        return x

class GPT(nn.Module):
    """GPT-like decoder-only transformer with RoPE attention and residual dropout.

    ``forward`` returns ``(logits, loss)``; see its docstring for the training
    and generation modes.
    """
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embed),
            'h': nn.ModuleList([GPTBlock(config) for _ in range(config.n_layer)]),
            'ln_f': nn.LayerNorm(config.n_embed)
        })
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

        # Init: N(0, 0.02) for Linear/Embedding weights, then a depth-scaled std for
        # the attention output projections (parameters named `c_proj`).
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

        print(f"Parameters: {sum(p.numel() for p in self.parameters()) / 1e6:.2f}M")

    def _init_weights(self, module):
        """Per-module initializer applied via ``self.apply``.

        ``ln_f`` is an ``nn.LayerNorm`` and is not matched by the custom
        ``LayerNorm`` branch; it keeps PyTorch's default ones/zeros init.
        """
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, LayerNorm):
            torch.nn.init.ones_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None):
        """Run the model on token ids.

        Args:
            idx: (batch, time) LongTensor of token ids.
            targets: optional (batch, time) LongTensor of next-token ids.

        Returns:
            ``(logits, loss)``. With ``targets``: logits for every position,
            (batch, time, vocab), and the mean cross-entropy loss. Without
            ``targets``: only the last position's logits, (batch, vocab), and None.

        Raises:
            ValueError: if ``idx`` is not 2-D.
        """
        if idx.dim() != 2:
            raise ValueError(f"idx must be 2-D (batch, time), got shape {tuple(idx.shape)}")

        # bf16 autocast only matters for CUDA tensors. The previous
        # torch.cuda.amp.autocast (deprecated) had no effect on CPU inputs and only
        # warned, so it stays disabled there.
        use_autocast = bool(self.config.use_amp) and idx.device.type == 'cuda'
        with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=use_autocast):
            x = self.transformer.wte(idx)
            for block in self.transformer.h:
                x = block(x)
            x = self.transformer.ln_f(x)

            logits = self.lm_head(x)

            if targets is not None:
                loss = F.cross_entropy(
                    logits.reshape(-1, logits.size(-1)),
                    targets.reshape(-1)
                )
            else:
                # For generation mode
                logits = logits[:, -1, :]
                loss = None

        return logits, loss

In [None]:
class ModelInterface:
    """Interface for model operations with comprehensive error handling and state management.

    Owns a DataManager and a GPT model, both built from the shared ``config``.
    """

    def __init__(self, config: ModelConfig):
        """Initialize the interface with configuration and necessary components"""
        self.config = config
        self.original_batch_size = config.batch_size  # Store original batch size
        self.model = None
        self.data_manager = None
        self.scaler = torch.cuda.amp.GradScaler(enabled=getattr(config, 'use_amp', True))
        self._is_initialized = False
        self.checkpoint_info = None
        self._initialize()

    def _initialize(self):
        """Create the DataManager and GPT model and move the model to ``config.device``.

        Raises:
            RuntimeError: wrapping any failure (resources are cleaned up first).
        """
        try:
            # Verify and adjust batch size if needed
            if self.config.batch_size != self.original_batch_size:
                print(f"Batch size was modified from {self.original_batch_size} to {self.config.batch_size}")
                self.config.batch_size = self.original_batch_size  # Restore original batch size

            # Initialize data manager first
            self.data_manager = DataManager(self.config)

            # Initialize model with proper device placement
            self.model = GPT(self.config)
            self.model = self.model.to(self.config.device)

            self._is_initialized = True

            # Log initialization details
            print(f"Model initialized with {sum(p.numel() for p in self.model.parameters())/1e6:.2f}M parameters")
            print(f"Using device: {self.config.device}")
            print(f"Batch size: {self.config.batch_size}")

            # Adjust gradient accumulation based on GPU memory if needed
            if self.config.device == "cuda":
                gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
                if gpu_mem < 8:  # Only adjust if very limited memory
                    self.config.gradient_accumulation_steps = 2
                    print(f"Adjusted gradient accumulation steps to {self.config.gradient_accumulation_steps} due to limited GPU memory")

        except Exception as e:
            self.cleanup()
            raise RuntimeError(f"Failed to initialize model interface: {str(e)}") from e

    def verify_state(self):
        """Verify that the model and vocabulary are in a valid state.

        Raises:
            RuntimeError: if a component is missing or the vocabulary is not loaded.
            ValueError: if the tokenizer, config and model disagree on vocab size.
        """
        if not self._is_initialized:
            raise RuntimeError("Model interface not properly initialized")

        if self.model is None:
            raise RuntimeError("Model not loaded")

        if self.data_manager is None:
            raise RuntimeError("Data manager not initialized")

        # getattr: the mappings only exist once the tokenizer has been initialized
        char_to_idx = getattr(self.data_manager, 'char_to_idx', None)
        if not char_to_idx:
            raise RuntimeError("Vocabulary mappings not loaded")

        # Verify vocabulary consistency; fall back to the mapping size when the
        # data manager does not record a vocab_size attribute.
        data_vocab_size = getattr(self.data_manager, 'vocab_size', None)
        if data_vocab_size is None:
            data_vocab_size = len(char_to_idx)
        if data_vocab_size != self.config.vocab_size:
            raise ValueError(f"Vocabulary size mismatch: {data_vocab_size} != {self.config.vocab_size}")

        # Verify model vocabulary size matches config
        if self.model.config.vocab_size != self.config.vocab_size:
            raise ValueError(f"Model vocabulary size mismatch: {self.model.config.vocab_size} != {self.config.vocab_size}")

        # Verify batch size hasn't been modified
        if self.config.batch_size != self.original_batch_size:
            print(f"Batch size mismatch detected. Restoring original batch size: {self.original_batch_size}")
            self.config.batch_size = self.original_batch_size

    @staticmethod
    def _restore_config_types(config_dict: Dict[str, Any]) -> Dict[str, Any]:
        """Undo the ``str()`` conversion applied to Path/torch.dtype values in a saved config."""
        restored = dict(config_dict)
        for key in ('dtype', 'amp_dtype'):
            value = restored.get(key)
            if isinstance(value, str) and value.startswith('torch.'):
                restored[key] = getattr(torch, value.split('.')[-1])
        if isinstance(restored.get('save_dir'), str):
            restored['save_dir'] = Path(restored['save_dir'])
        return restored

    def load_model(self, model_path: Union[str, Path]):
        """Rebuild tokenizer and model from a checkpoint file.

        The checkpoint must contain ``'config'`` (a dict of config values) and
        ``'model_state_dict'``. The current machine's ``device``/``dtype`` are kept
        rather than the ones stored in the checkpoint, so a GPU-trained checkpoint
        can be loaded on a CPU-only host.

        Raises:
            RuntimeError: wrapping any failure; the interface is cleaned up first.
        """
        try:
            # NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
            checkpoint = torch.load(model_path, map_location=self.config.device)

            # 1. Update config from checkpoint, keeping this machine's device/dtype
            runtime_settings = {'device': self.config.device, 'dtype': self.config.dtype}
            saved_config = self._restore_config_types(checkpoint['config'])
            self.config.__dict__.update(saved_config)
            self.config.__dict__.update(runtime_settings)

            if self.data_manager is None:
                # cleanup() after an earlier failure drops the data manager
                self.data_manager = DataManager(self.config)
            else:
                self.data_manager.config.__dict__.update(saved_config)
                self.data_manager.config.__dict__.update(runtime_settings)

            # 2. Initialize the SentencePiece tokenizer (trained from the configured data)
            self.data_manager.initialize_tokenizer()

            # 3. Release the previous model before allocating the new one
            if self.model is not None:
                self.model.cpu()
                self.model = None
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()

            # 4. Initialize model with updated config
            self.model = GPT(self.config).to(self.config.device)

            # 5. Load weights with architecture validation (strict key matching)
            self.model.load_state_dict(checkpoint['model_state_dict'])
            self.model.eval()

            # 6. Warm up with a single forward pass. GPT as defined in this notebook
            #    has no generate() method, so the previous generate() warmup always
            #    raised; GPT.forward manages autocast itself.
            with torch.no_grad():
                dummy = torch.zeros((1, 1), dtype=torch.long, device=self.config.device)
                self.model(dummy)

            self._is_initialized = True
            print(f"Model loaded successfully on {self.config.device}")

        except Exception as e:
            self.cleanup()
            raise RuntimeError(f"Model loading failed: {str(e)}") from e

    def cleanup(self):
        """Comprehensive cleanup of resources and memory (best-effort: errors are printed)."""
        try:
            if hasattr(self, 'model') and self.model is not None:
                self.model.cpu()
                del self.model
                self.model = None

            if hasattr(self, 'data_manager') and self.data_manager is not None:
                del self.data_manager
                self.data_manager = None

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            gc.collect()

            self._is_initialized = False
            print("Model interface cleaned up successfully")

        except Exception as e:
            print(f"Error during cleanup: {str(e)}")

In [None]:
class LRScheduler:
    """Learning rate scheduler with multiple scheduling strategies.

    Each ``step()`` advances one scheduler step and writes the new learning rate
    into every optimizer param group. ``config.lr_schedule`` selects one of
    'cosine_with_warmup', 'linear_with_warmup' or 'step'.

    Args:
        optimizer: object exposing ``param_groups`` (a list of dicts with 'lr').
        config: supplies learning_rate, min_learning_rate, warmup_steps, epochs,
            lr_schedule, lr_decay_epochs and batch_size.
        total_steps: total number of scheduler steps; defaults to
            ``config.epochs * steps_per_epoch``.
        steps_per_epoch: scheduler steps per epoch. It sets the default
            ``total_steps`` and the 'step' schedule's decay interval. When omitted,
            ``config.batch_size`` stands in for it (the historical behaviour);
            pass the real number of optimizer steps per epoch to get a schedule
            that actually spans training.

    After ``total_steps`` the cosine and linear schedules hold at ``min_lr``.

    Raises:
        ValueError: for an unknown ``config.lr_schedule``.
    """

    def __init__(self, optimizer, config, total_steps: Optional[int] = None,
                 steps_per_epoch: Optional[int] = None):
        self.optimizer = optimizer
        self.config = config
        self._steps_per_epoch = steps_per_epoch
        if total_steps is None:
            total_steps = config.epochs * self._epoch_steps()
        self.total_steps = total_steps
        self.warmup_steps = config.warmup_steps
        self.current_step = 0
        self.base_lr = config.learning_rate
        self.min_lr = config.min_learning_rate

        # Select scheduling strategy
        if config.lr_schedule == 'cosine_with_warmup':
            self.get_lr = self._cosine_with_warmup
        elif config.lr_schedule == 'linear_with_warmup':
            self.get_lr = self._linear_with_warmup
        elif config.lr_schedule == 'step':
            self.get_lr = self._step_decay
        else:
            raise ValueError(f"Unknown lr schedule: {config.lr_schedule}")

    def _epoch_steps(self):
        """Steps per epoch: the explicit value if given, else config.batch_size (legacy)."""
        if self._steps_per_epoch is not None:
            return self._steps_per_epoch
        return self.config.batch_size

    def _warmup_lr(self):
        """Linear ramp from 0 to base_lr over warmup_steps."""
        return self.base_lr * (self.current_step / max(1, self.warmup_steps))

    def _decay_progress(self):
        """Fraction of the post-warmup phase completed, clamped to [0, 1].

        Without the clamp, steps past total_steps made linear decay go below
        min_lr (eventually negative) and cosine decay climb back up to base_lr.
        """
        span = max(1, self.total_steps - self.warmup_steps)
        return min(1.0, max(0.0, (self.current_step - self.warmup_steps) / span))

    def _cosine_with_warmup(self):
        """Cosine annealing with warmup"""
        if self.current_step < self.warmup_steps:
            return self._warmup_lr()
        progress = self._decay_progress()
        return self.min_lr + 0.5 * (self.base_lr - self.min_lr) * (1 + math.cos(math.pi * progress))

    def _linear_with_warmup(self):
        """Linear decay with warmup"""
        if self.current_step < self.warmup_steps:
            return self._warmup_lr()
        progress = self._decay_progress()
        return self.min_lr + (self.base_lr - self.min_lr) * (1 - progress)

    def _step_decay(self):
        """Step decay with warmup: multiply by 0.1 every lr_decay_epochs epochs, floored at min_lr."""
        if self.current_step < self.warmup_steps:
            return self._warmup_lr()
        decay_interval = max(1, self.config.lr_decay_epochs * self._epoch_steps())
        decay_factor = 0.1 ** (self.current_step // decay_interval)
        return max(self.min_lr, self.base_lr * decay_factor)

    def step(self):
        """Advance one step, apply the new learning rate to all param groups and return it."""
        self.current_step += 1
        lr = self.get_lr()
        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr
        return lr

    def get_last_lr(self):
        """Get current learning rate"""
        return self.optimizer.param_groups[0]['lr']

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`."""
        return {
            'current_step': self.current_step,
            'base_lr': self.base_lr,
            'total_steps': self.total_steps,
            'config': self.config.__dict__
        }

    def load_state_dict(self, state_dict):
        """Loads the schedulers state (``total_steps`` is optional for older states)."""
        self.current_step = state_dict['current_step']
        self.base_lr = state_dict['base_lr']
        self.total_steps = state_dict.get('total_steps', self.total_steps)
        self.config.__dict__.update(state_dict['config'])

In [None]:
import datetime
import torch
from typing import Optional, Union
from pathlib import Path
from tqdm import tqdm

class Trainer:
    """Single-model training driver: optimisation, validation and checkpointing.

    The model is called as ``model(input_ids, targets=targets)`` and must return
    ``(logits, loss)``. Batches are read from ``data_manager.train_loader`` and
    ``data_manager.val_loader`` as ``(input_ids, targets)`` pairs and moved to
    ``config.device``.
    """

    def __init__(self, model: torch.nn.Module, data_manager: DataManager, config: ModelConfig):
        self.model = model
        self.data_manager = data_manager
        self.config = config
        self.device = config.device
        self.epoch = 0
        self.global_step = 0             # optimizer steps taken (incremented in train_epoch)
        self.best_val_loss = float('inf')
        self.no_improvement_count = 0    # consecutive epochs without a min_delta improvement
        self.train_losses = []
        self.val_losses = []
        self.setup_training()

    def setup_training(self):
        """Create the loss module, AdamW optimizer, gradient scaler and LR scheduler."""
        # In this class the model computes its own loss (train_epoch/evaluate).
        # The criterion is only saved to and restored from checkpoints.
        self.criterion = torch.nn.CrossEntropyLoss()

        self.optimizer = torch.optim.AdamW(
            self.model.parameters(),
            lr=self.config.learning_rate,
            weight_decay=self.config.weight_decay,
            betas=(0.9, 0.95)
        )

        # Always create the scaler. With enabled=False:
        #   - scale() returns the loss unchanged,
        #   - unscale_() and update() do nothing,
        #   - step() just calls optimizer.step().
        # So the training loop needs no separate non-AMP path. Previously the
        # scaler only existed when use_amp was True, and train_epoch then
        # crashed with AttributeError.
        self.scaler = torch.cuda.amp.GradScaler(enabled=self.config.use_amp)

        # Initialize the custom LR scheduler
        self.scheduler = LRScheduler(self.optimizer, self.config)

    @staticmethod
    def _serializable_config(config) -> dict:
        """Copy ``config.__dict__``, converting ``Path`` / ``torch.dtype`` values to strings.

        Plain strings keep checkpoints loadable by restricted unpicklers such as
        ``torch.load(..., weights_only=True)``, which rejects ``pathlib.Path``.
        """
        return {
            k: str(v) if isinstance(v, (Path, torch.dtype)) else v
            for k, v in config.__dict__.items()
        }

    def save_checkpoint(self, loss: float, is_best: bool = False, custom_path: Path = None):
        """Serialize model, optimizer, scheduler, scaler and training state to a ``.pt`` file.

        Args:
            loss: Loss stored in the checkpoint and embedded in the default filename.
            is_best: Use the ``best_model`` filename prefix instead of ``checkpoint``.
            custom_path: Explicit destination file. When given, the default
                directory and filename scheme are bypassed.
        """
        timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
        loss_str = f"{loss:.4f}".replace('.', '_')
        scheduler_state = {
            'current_step': self.scheduler.current_step,
            'base_lr': self.scheduler.base_lr,
            # Stringified like the top-level 'config'. The raw __dict__ stored
            # here previously could contain Path objects (e.g. save_dir).
            'config': self._serializable_config(self.scheduler.config),
        }
        checkpoint = {
            'epoch': self.epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'criterion_state_dict': self.criterion.state_dict(),
            'scheduler_state_dict': scheduler_state,
            'scaler_state_dict': self.scaler.state_dict(),
            'loss': loss,
            'global_step': self.global_step,
            'best_val_loss': self.best_val_loss,
            'no_improvement_count': self.no_improvement_count,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'config': self._serializable_config(self.config),  # full config as a plain dict
        }

        if custom_path is None:
            checkpoint_dir = Path('/content/drive/MyDrive/LLM/checkpoints') if is_colab() else Path('./LLM/checkpoints')
            checkpoint_name = f"{'best_model' if is_best else 'checkpoint'}_epoch_{self.epoch + 1}_loss_{loss_str}_{timestamp}.pt"
            checkpoint_path = checkpoint_dir / checkpoint_name
        else:
            checkpoint_path = Path(custom_path)

        # Create the destination directory, including for custom paths whose
        # directory may not exist yet.
        checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(checkpoint, str(checkpoint_path))
        print(f"Checkpoint saved: {checkpoint_path}")

    def _validate_training_setup(self):
        """Raise RuntimeError if the data manager has no tokenizer."""
        if not self.data_manager.tokenizer:
            raise RuntimeError("Tokenizer not initialized")

    def train(self):
        """Train for ``config.epochs`` epochs with best-model saving and early stopping.

        Any strictly lower validation loss is saved as the new best checkpoint.
        For early stopping, an epoch counts as an improvement only when it beats
        the previous best by more than ``config.min_delta`` (0 if the config has
        no such field). Training stops after ``config.patience`` consecutive
        epochs without improvement.
        """
        self._validate_training_setup()
        min_delta = getattr(self.config, 'min_delta', 0.0)

        for epoch in range(self.config.epochs):
            self.epoch = epoch
            train_loss = self.train_epoch()
            val_loss = self.evaluate()

            # Track losses
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)

            previous_best = self.best_val_loss
            if val_loss < previous_best:
                self.best_val_loss = val_loss
                self.save_checkpoint(val_loss, is_best=True)

            if val_loss < previous_best - min_delta:
                self.no_improvement_count = 0
            else:
                self.no_improvement_count += 1

            # Report before a possible early stop so the final epoch is logged too.
            print(f"Epoch {epoch + 1}/{self.config.epochs} | "
                  f"Train Loss: {train_loss:.4f} | "
                  f"Val Loss: {val_loss:.4f} | "
                  f"LR: {self.scheduler.get_last_lr():.2e}")

            if self.no_improvement_count >= self.config.patience:
                print(f"Early stopping triggered after {epoch + 1} epochs")
                break

    def train_epoch(self):
        """Run one pass over ``train_loader`` and return the mean per-batch loss.

        Gradients from ``config.gradient_accumulation_steps`` consecutive batches
        are accumulated before one clip, optimizer and scheduler step. A shorter
        trailing group at the end of the epoch is also applied, so its gradients
        are neither dropped nor carried into the next epoch. The returned and
        displayed loss is the model's unscaled loss, not divided by the
        accumulation factor. Returns 0.0 for an empty loader.
        """
        self.model.train()
        total_loss = 0.0
        start_time = time.time()
        total_batches = len(self.data_manager.train_loader)
        accum_steps = max(1, self.config.gradient_accumulation_steps)

        # Start from clean gradients in case earlier work left some behind.
        self.optimizer.zero_grad(set_to_none=True)

        for batch_idx, (input_ids, targets) in enumerate(self.data_manager.train_loader):
            # Move data to device
            input_ids = input_ids.to(self.config.device)
            targets = targets.to(self.config.device)

            # Forward pass
            with torch.cuda.amp.autocast(enabled=self.config.use_amp):
                logits, loss = self.model(input_ids, targets=targets)

            # Average gradients over this accumulation group. The last group of
            # the epoch may be shorter than accum_steps.
            group_start = (batch_idx // accum_steps) * accum_steps
            group_size = min(accum_steps, total_batches - group_start)
            self.scaler.scale(loss / group_size).backward()

            is_last_batch = (batch_idx + 1) == total_batches
            if (batch_idx + 1) % accum_steps == 0 or is_last_batch:
                self._optimizer_step()

            # Update metrics with the unscaled loss
            total_loss += loss.item()
            self._print_progress(batch_idx, total_batches, total_loss, self._current_lr(), start_time)

        print()  # New line after epoch completion
        return total_loss / max(1, total_batches)

    def _optimizer_step(self):
        """Clip, apply and clear the accumulated gradients, then advance the LR schedule."""
        self.scaler.unscale_(self.optimizer)  # clip in true gradient units
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(),
            self.config.gradient_clip_val
        )
        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.optimizer.zero_grad(set_to_none=True)
        self.global_step += 1

        if self.scheduler is not None:
            self.scheduler.step()

    def _current_lr(self) -> float:
        """Current learning rate, whether ``get_last_lr()`` returns a float or a list."""
        if hasattr(self.scheduler, 'get_last_lr'):
            last_lr = self.scheduler.get_last_lr()
            return last_lr[0] if isinstance(last_lr, list) else last_lr
        return self.optimizer.param_groups[0]['lr']

    @staticmethod
    def _format_time(seconds: float) -> str:
        """Format a non-negative duration in seconds as ``mm:ss`` (minutes not capped)."""
        minutes, secs = divmod(int(seconds), 60)
        return f"{minutes:02d}:{secs:02d}"

    def _print_progress(self, batch_idx: int, total_batches: int, total_loss: float,
                        current_lr: float, start_time: float):
        """Redraw the single-line progress bar (carriage return, no newline)."""
        avg_loss = total_loss / (batch_idx + 1)

        # Calculate progress and timing
        progress = (batch_idx + 1) / total_batches
        elapsed_time = time.time() - start_time
        remaining_time = elapsed_time / progress - elapsed_time if progress > 0 else 0

        # Progress bar
        bar_length = 20
        filled_length = int(bar_length * progress)
        bar = '█' * filled_length + '░' * (bar_length - filled_length)

        # Status line
        status = (
            f'\r│ E{self.epoch+1:02d} │ '
            f'{bar} │ '
            f'{progress:>3.0%} │ '
            f'Loss: {avg_loss:.4f} │ '
            f'LR: {current_lr:.2e} │ '
            f'{self._format_time(elapsed_time)}/{self._format_time(remaining_time)} │'
        )

        print(status, end='', flush=True)

    def _training_step(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Single un-accumulated step: forward, backward, clip and optimizer step.

        Unlike ``train_epoch``, this does not step the LR scheduler or update
        ``global_step``. Clips to ``config.max_grad_norm`` when the config defines
        it, otherwise to ``config.gradient_clip_val``.
        """
        with torch.cuda.amp.autocast(enabled=self.config.use_amp):
            _, loss = self.model(x, targets=y)

        max_norm = getattr(self.config, 'max_grad_norm', None)
        if max_norm is None:
            max_norm = self.config.gradient_clip_val

        # The scaler is created with enabled=use_amp, so when AMP is off these
        # calls reduce to loss.backward() and optimizer.step().
        self.scaler.scale(loss).backward()
        self.scaler.unscale_(self.optimizer)
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm)
        self.scaler.step(self.optimizer)
        self.scaler.update()

        self.optimizer.zero_grad(set_to_none=True)
        return loss

    @torch.no_grad()
    def evaluate(self):
        """Return the mean per-batch validation loss over ``val_loader``.

        Returns ``float('nan')`` when there is no validation loader or it yields
        no batches; previously this raised TypeError or ZeroDivisionError. The
        model is left in eval mode, and ``train_epoch`` switches it back.
        """
        self.model.eval()
        val_loader = self.data_manager.val_loader
        if val_loader is None:
            return float('nan')

        total_loss = 0.0
        num_batches = 0
        for input_ids, targets in val_loader:
            input_ids = input_ids.to(self.config.device)
            targets = targets.to(self.config.device)

            with torch.cuda.amp.autocast(enabled=self.config.use_amp):
                logits, loss = self.model(input_ids, targets=targets)  # Model returns both logits and loss

            total_loss += loss.item()
            num_batches += 1

        return total_loss / num_batches if num_batches else float('nan')

    def load_checkpoint(self, checkpoint_path: Union[str, Path]):
        """Restore model, optimizer, scheduler, scaler and training state from a checkpoint.

        Some keys may be absent from older checkpoints: ``criterion_state_dict``,
        ``scheduler_state_dict``, ``scaler_state_dict``, ``no_improvement_count``
        and ``config``. These are optional. Saved config values are copied onto
        ``self.config`` for keys it already has, except ``device``. The device of
        the machine doing the loading is kept, so a GPU checkpoint can resume on
        CPU.
        """
        # torch.load unpickles the file; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

        # Load model and optimizer state
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        # Load criterion state if it exists
        if 'criterion_state_dict' in checkpoint:
            self.criterion.load_state_dict(checkpoint['criterion_state_dict'])

        # Load scheduler progress
        if 'scheduler_state_dict' in checkpoint:
            scheduler_state = checkpoint['scheduler_state_dict']
            self.scheduler.current_step = scheduler_state['current_step']
            self.scheduler.base_lr = scheduler_state['base_lr']

        # A disabled scaler saves an empty dict, which GradScaler.load_state_dict
        # rejects, so only restore non-empty scaler states.
        scaler_state = checkpoint.get('scaler_state_dict')
        if scaler_state:
            self.scaler.load_state_dict(scaler_state)

        # Load training state
        self.epoch = checkpoint['epoch']
        self.global_step = checkpoint['global_step']
        self.best_val_loss = checkpoint['best_val_loss']
        self.train_losses = checkpoint['train_losses']
        self.val_losses = checkpoint['val_losses']
        self.no_improvement_count = checkpoint.get('no_improvement_count', self.no_improvement_count)

        # Load config if it exists
        if 'config' in checkpoint:
            # Convert string representations back to proper types
            config_dict = checkpoint['config']
            for dtype_key in ('dtype', 'amp_dtype'):
                if isinstance(config_dict.get(dtype_key), str):
                    config_dict[dtype_key] = getattr(torch, config_dict[dtype_key].split('.')[-1])
            if config_dict.get('save_dir') is not None:
                config_dict['save_dir'] = Path(config_dict['save_dir'])

            # Update the current config with saved values
            for key, value in config_dict.items():
                if key == 'device':
                    continue  # keep the device of the machine we are loading on
                if hasattr(self.config, key):
                    setattr(self.config, key, value)

        print(f"Checkpoint loaded from: {checkpoint_path}")

In [None]:
class TrainingPipeline:
    """Build the data manager and Trainer for ``model``, then drive epoch-level training.

    Each epoch it:
      - calls ``Trainer.train_epoch`` and ``Trainer.evaluate``,
      - saves a checkpoint when the validation loss reaches a new minimum,
      - prints a one-row summary,
      - applies early stopping using ``config.patience`` (and
        ``config.min_delta`` when the config defines it).
    """

    def __init__(self, model: GPT, config: ModelConfig):
        self.model = model
        self.config = config
        self.best_val_loss = float('inf')

        # Setup logging
        self.setup_logging()

        # Initialize components
        self.initialize_pipeline()

    def setup_logging(self):
        """Define per-column format strings, the table header and the divider line."""
        self.log_formats = {
            'epoch': "{: >4}/{}", # epoch/total_epochs
            'train_loss': "{:.4f}",
            'val_loss': "{:.4f}",
            'lr': "{:.2e}",
            'time': "{:.1f}s",
            'eta': "{}",
            'best': "✓",
        }

        self.header = (
            "\n\033[1m"  # Bold
            f"{'Epoch':>6} | {'Train Loss':>10} | {'Val Loss':>10} | "
            f"{'LR':>10} | {'Time':>8} | {'ETA':>12} | {'Best':>4}"
            "\033[0m"  # Reset bold
        )

        self.divider = "-" * 70

    def initialize_pipeline(self):
        """Create and load the data manager, derive warmup steps, and build the Trainer.

        Raises:
            ValueError: if the data manager produced no training loader.
        """
        print("\n🔧 Initializing Pipeline:")
        print(self.divider)

        # Initialize data manager
        print("📂 Initializing data manager...")
        self.data_manager = DataManager(self.config)

        print("📥 Loading data...")
        self.data_manager.load_data()

        if self.data_manager.train_loader is None:
            raise ValueError("❌ Training data failed to load")

        # Derive warmup steps from the ratio (assumes one scheduler step per batch).
        if hasattr(self.config, 'warmup_ratio'):
            total_steps = len(self.data_manager.train_loader) * self.config.epochs
            self.config.warmup_steps = int(total_steps * self.config.warmup_ratio)
            print(f"🔥 Warmup steps: {self.config.warmup_steps:,} ({self.config.warmup_ratio*100:.1f}% of total)")

        print("🚀 Initializing trainer...")
        self.trainer = Trainer(self.model, self.data_manager, self.config)

        print(self.divider)

    def print_training_summary(self):
        """Print the model size and key training hyperparameters."""
        print("\n📊 Training Configuration:")
        print(self.divider)
        print(f"• Model Parameters: {sum(p.numel() for p in self.model.parameters())/1e6:.2f}M")
        print(f"• Device: {self.config.device}")
        print(f"• Batch Size: {self.config.batch_size}")
        print(f"• Learning Rate: {self.config.learning_rate:.2e}")
        print(f"• Training Epochs: {self.config.epochs}")
        print(f"• Gradient Accumulation: {self.config.gradient_accumulation_steps}")
        print(self.divider)

    def train_and_evaluate(self):
        """Main training loop with per-epoch summary rows and early stopping.

        This method keeps the trainer's ``best_val_loss``,
        ``no_improvement_count`` and loss histories in sync, so checkpoints record
        them and early stopping actually fires. Previously
        ``no_improvement_count`` was never updated on this path.

        A KeyboardInterrupt saves a checkpoint tagged with the most recent
        validation loss (NaN if no epoch finished) instead of propagating.
        Other exceptions are reported and re-raised.
        """
        trainer = self.trainer
        min_delta = getattr(self.config, 'min_delta', 0.0)
        # Validation loss of the last finished epoch; used if training is interrupted.
        last_val_loss = float('nan')

        try:
            self.print_training_summary()
            print(self.header)

            # Respect a best loss restored into the trainer from a checkpoint.
            self.best_val_loss = min(self.best_val_loss, trainer.best_val_loss)

            start_time = time.time()
            epoch_times = []

            while trainer.epoch < self.config.epochs:
                epoch_start = time.time()
                current_epoch = trainer.epoch

                # Training and evaluation
                train_loss = trainer.train_epoch()
                val_loss = trainer.evaluate()
                last_val_loss = val_loss
                trainer.train_losses.append(train_loss)
                trainer.val_losses.append(val_loss)

                # Best-model tracking: any strictly lower loss is saved.
                previous_best = self.best_val_loss
                is_best = val_loss < previous_best
                if is_best:
                    self.best_val_loss = val_loss
                    trainer.best_val_loss = val_loss
                    trainer.save_checkpoint(val_loss, is_best)

                # Early-stopping counter: only improvements larger than min_delta reset it.
                if val_loss < previous_best - min_delta:
                    trainer.no_improvement_count = 0
                else:
                    trainer.no_improvement_count += 1

                # ETA from the mean of the last (up to) 5 epoch durations
                epoch_time = time.time() - epoch_start
                epoch_times.append(epoch_time)
                avg_epoch_time = float(np.mean(epoch_times[-5:]))
                remaining_epochs = self.config.epochs - (current_epoch + 1)
                eta = str(datetime.timedelta(seconds=int(avg_epoch_time * remaining_epochs)))

                # Print progress
                progress = (
                    f"{current_epoch + 1:>6}/{self.config.epochs} | "
                    f"{train_loss:>10.4f} | "
                    f"{val_loss:>10.4f} | "
                    f"{trainer.scheduler.get_last_lr():>10.2e} | "
                    f"{epoch_time:>8.1f}s | "
                    f"{eta:>12} | "
                    f"{'✓' if is_best else ' ':>4}"
                )
                print(progress)

                trainer.epoch += 1

                # Early stopping check
                if trainer.no_improvement_count >= self.config.patience:
                    print(f"\n⚠️  Early stopping triggered after {current_epoch + 1} epochs")
                    break

            # Training summary
            total_time = time.time() - start_time
            print(self.divider)
            print(f"\n✨ Training completed in {str(datetime.timedelta(seconds=int(total_time)))}")
            print(f"🏆 Best validation loss: {self.best_val_loss:.4f}")

        except KeyboardInterrupt:
            print("\n\n⚠️  Training interrupted by user. Saving checkpoint...")
            trainer.save_checkpoint(last_val_loss, is_best=False)

        except Exception as e:
            print(f"\n❌ Error during training: {str(e)}")
            raise

        finally:
            # Cleanup
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()

In [None]:
if __name__ == "__main__":
    try:
        # Seed every RNG the pipeline may touch so runs are repeatable.
        SEED = 42
        random.seed(SEED)
        np.random.seed(SEED)
        torch.manual_seed(SEED)  # also seeds all CUDA devices

        # Initialize the configuration with optimized values for small LLM
        config = ModelConfig(
            # Model Architecture
            # NOTE(review): nothing in this cell updates vocab_size after loading;
            # presumably DataManager.load_data() does — confirm.
            vocab_size=100,
            block_size=64,  # Context window size
            n_layer=3,        # Reduced from 6
            n_head=4,
            n_embed=128,      # Reduced from 256
            ff_dim=512,       # Reduced from 1024
            head_dim=64,      # Reduced from 128

            # Increase Regularization
            resid_pdrop=0.2,      # Increased from 0.1
            weight_decay=0.05,    # Increased from 0.01

            # Architecture Features
            bias=True,
            flash_attn=True,  # use flash attention if available
            use_gradient_checkpointing=False,  # Memory efficiency

            # Training Parameters
            batch_size=128,
            gradient_accumulation_steps=1,
            epochs=300,
            gradient_clip_val=1.0,

            # Learning Rate - conservative
            learning_rate=1e-4,      # peak learning rate
            min_learning_rate=5e-6,  # floor for the decay schedules
            warmup_ratio=0.0002,     # fraction of total steps used for warmup (0.02%)
            lr_schedule='cosine_with_warmup',
            lr_decay_epochs=9,       # decay period for the step-decay schedule

            # Early Stopping - aggressive
            loss_threshold=0.3,
            within_epoch_loss_threshold=0.3,
            patience=5,           # Reduced from 100 - much more aggressive
            min_delta=1e-3,      # Increased from 1e-4

            # Evaluation and Checkpointing
            eval_steps=500,
            eval_every=1000,
            save_every=5,
            keep_last_n_checkpoints=2,

            # Generation Parameters
            gen_temperature=0.8,
            max_gen_tokens=64,
            top_k=50,
            top_p=0.9,        # nucleus sampling parameter

            # System and Memory
            device='cuda' if torch.cuda.is_available() else 'cpu',
            pin_memory=True,
            num_workers=8,
            prefetch_factor=2,

            # Precision
            use_amp=True,
            amp_dtype=torch.bfloat16,
            dtype=torch.bfloat16
        )

        # Initialize data pipeline with error handling
        print("Initializing data manager...")
        data_manager = DataManager(config)

        # Load and process data
        print("Loading and processing data...")
        data_manager.load_data()

        # Analyze vocabulary
        unique_characters = list(data_manager.char_to_idx.keys())
        print(f"Found {len(unique_characters)} unique characters")

        # Initialize model. autocast changes the dtype of ops run inside it
        # (the forward pass); it does not change how parameters are created,
        # so construction is not wrapped in it.
        print("Initializing model...")
        model = GPT(config).to(config.device)

        # Log model parameters
        total_params = sum(p.numel() for p in model.parameters())
        print(f"Model initialized with {total_params/1e6:.2f}M parameters")

        # Initialize and run training pipeline with error handling
        print("Initializing training pipeline...")
        pipeline = TrainingPipeline(model, config)

        # Run training
        print("Starting training...")
        try:
            pipeline.train_and_evaluate()
        except KeyboardInterrupt:
            print("Training interrupted by user. Saving checkpoint...")
            pipeline.trainer.save_checkpoint(
                pipeline.trainer.evaluate(),
                is_best=False,
                custom_path=Path(config.save_dir) / "interrupted_checkpoint.pt"
            )

        print("Training completed successfully")

    except Exception as e:
        print(f"Error during execution: {str(e)}")
        raise
    finally:
        # Cleanup
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()