In [None]:
import os
import sys
import io
import time
import random
import json
import gc
import warnings
from pathlib import Path
from datetime import timedelta
import psutil
import numpy as np
from contextlib import contextmanager, nullcontext
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm  # or tqdm.auto if needed
from typing import List, Union, Optional, Dict, Any
from dataclasses import dataclass
import threading
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import tempfile

# Check if running in Google Colab
def is_colab():
    """Return True when the active IPython shell is a Google Colab shell.

    ``get_ipython`` is only defined inside IPython/Jupyter; outside of it the
    lookup raises ``NameError`` and the function reports ``False``.
    """
    try:
        shell_repr = str(get_ipython())
    except NameError:
        return False
    return 'google.colab' in shell_repr

# On Colab, mount Google Drive so the project folder persists across sessions
if is_colab():
    from google.colab import drive
    drive.mount('/content/drive')

# Project root: the Drive folder on Colab, a local ./LLM folder otherwise
if is_colab():
    base_dir = Path('/content/drive/MyDrive/LLM')
else:
    base_dir = Path('./LLM')

# Make sure every working sub-directory exists (no error if already present)
for dir_name in ('checkpoints', 'models', 'logs', 'configs', 'data'):
    base_dir.joinpath(dir_name).mkdir(parents=True, exist_ok=True)

def is_package_installed(package_name):
    """Report whether ``package_name`` can be imported.

    The check performs a real import, so ``package_name`` must be the module's
    import name. Returns ``True`` on success and ``False`` when the import
    raises ``ImportError``.
    """
    try:
        __import__(package_name)
    except ImportError:
        return False
    else:
        return True

if is_colab():
    # PyTorch packages
    pytorch_packages = ['torch', 'torchvision', 'torchaudio']
    pytorch_install = [pkg for pkg in pytorch_packages if not is_package_installed(pkg)]
    if pytorch_install:
        !pip install {' '.join(pytorch_install)} --index-url https://download.pytorch.org/whl/cu118

    # Additional packages
    other_packages = ['pynvml', 'nvidia_ml_py3', 'gputil', 'fastapi', 'uvicorn', 'pydantic']
    other_install = [pkg for pkg in other_packages if not is_package_installed(pkg)]
    if other_install:
        !pip install {' '.join(other_install)}

# Suppress warnings
warnings.filterwarnings('ignore')

print("Section 1: Initial setup and core components complete")

In [None]:
class DataManager:
    """Character-level tokenizer together with the train/val/test file lists.

    The vocabulary is a ``char -> index`` mapping (``char_to_idx``) and its
    inverse (``idx_to_char``), both built by :meth:`setup_tokenizer`. Encoded
    text is returned as a ``torch.long`` tensor on ``config.device``.

    Args:
        config: Object exposing a ``device`` attribute (read by :meth:`encode`).
        data_dir: Directory holding the corpus ``.txt`` files. When omitted the
            Colab Drive location ``/content/drive/MyDrive/LLM/data`` is used.
    """

    # Default corpus location (the Drive folder used by the notebook on Colab)
    DEFAULT_DATA_DIR = Path('/content/drive/MyDrive/LLM/data')

    # File names per split, relative to the data directory
    TRAIN_FILES = (
        'huff_2252.txt',
        'theo_2253.txt',
        'hubbard_2251.txt',
        'raekwon_2250.txt',
        'pappas_2249.txt',
        'trussell_2247.txt',
        'fox_2246.txt',
        'blagojevich_2245.txt',
        'sorin_2242.txt',
        'strassman_2241.txt',
        'tarantino_2240.txt',
        'derek_2239.txt',
        'mcphee_2238.txt',
        'parks_2236.txt',
    )
    VAL_FILES = (
        # 'lennon_2243.txt' used to be listed twice; each file is listed once
        'lennon_2243.txt',
        'andreessen_2234.txt',
        'storch_2233.txt',
        # Add more validation files as needed
    )
    TEST_FILES = (
        'graves_2244.txt',
        # Add more test files as needed
    )

    def __init__(self, config, data_dir: Optional[Union[str, Path]] = None):
        self.config = config
        self.char_to_idx = {}
        self.idx_to_char = {}
        self.train_data: Optional[torch.Tensor] = None
        self.val_data: Optional[torch.Tensor] = None
        self.test_data: Optional[torch.Tensor] = None
        self.vocab_size = 0

        # Define the data paths as lists of files
        root = Path(data_dir) if data_dir is not None else self.DEFAULT_DATA_DIR
        self.data_paths = {
            'train': [root / name for name in self.TRAIN_FILES],
            'val': [root / name for name in self.VAL_FILES],
            'test': [root / name for name in self.TEST_FILES],
        }

    @staticmethod
    def update_params():
        """Prompt the user for new generation parameters.

        Declared as a static method so it can be called both on the class and
        on an instance (it takes no ``self``). Keeps prompting until all three
        values are valid; pressing Enter on a prompt selects its default.

        Returns:
            Tuple ``(temperature, max_tokens, top_k)``.
        """
        while True:
            try:
                temperature = float(input("Enter new temperature (0.0-1.0, default 0.7): ") or 0.7)
                # 0.0 is excluded: sampling divides the logits by the temperature
                if not 0.0 < temperature <= 1.0:
                    raise ValueError("Temperature must be greater than 0.0 and at most 1.0")

                max_tokens = int(input("Enter new max tokens (positive integer, default 100): ") or 100)
                if max_tokens <= 0:
                    raise ValueError("Max tokens must be a positive integer")

                top_k = int(input("Enter new top k (positive integer, default 50): ") or 50)
                if top_k <= 0:
                    raise ValueError("Top k must be a positive integer")
                break  # Exit the loop if all inputs are valid
            except ValueError as e:
                print(f"Invalid input: {e}. Please try again.")
        return temperature, max_tokens, top_k

    def calculate_vocab_size(self):
        """Build the vocabulary from every training file and return its size.

        Collects the unique characters of all files in ``data_paths['train']``,
        appends the special tokens that are not already present, and installs
        the result through :meth:`setup_tokenizer`.

        Returns:
            The resulting ``vocab_size``.

        Raises:
            FileNotFoundError: If a training file does not exist.
        """
        try:
            # Collect all unique characters from all training files
            all_chars = set()
            for train_path in self.data_paths['train']:
                if not train_path.exists():
                    print(f"Training data file does not exist: {train_path}")
                    raise FileNotFoundError(f"Training data file does not exist: {train_path}")

                with open(train_path, 'r', encoding='utf-8') as f:
                    all_chars.update(f.read())

            # Sort the unique characters and add special tokens
            unique_chars = sorted(all_chars)
            special_chars = ['\n', '\t', ' ', '_', '[', ']', '(', ')', '{', '}', '*', '/', '\\', '|',
                           '@', '#', '$', '%', '^', '&', '+', '=', '`', '~', '<pad>', '<extra>']

            # Add special characters that aren't already in the corpus
            # (set membership test; special_chars itself has no duplicates)
            unique_chars.extend(char for char in special_chars if char not in all_chars)

            self.setup_tokenizer(unique_chars)

            print(f"Vocabulary created with {self.vocab_size} characters including special characters")
            print(f"Special characters included: {[c for c in special_chars if c in self.char_to_idx]}")

        except Exception as e:
            print(f"Error during vocabulary calculation: {e}")
            raise

        return self.vocab_size

    def load_and_encode_data(self, data_path: Union[str, Path]) -> torch.Tensor:
        """Read a text file, rebuild the tokenizer from it, and return it encoded.

        NOTE(review): the tokenizer is rebuilt from this single file, replacing
        any vocabulary installed earlier (e.g. by ``calculate_vocab_size`` or a
        checkpoint) - confirm that callers expect this.

        Args:
            data_path: File to read; a ``str`` or a ``Path``.

        Returns:
            The file's characters encoded as a ``torch.long`` tensor.

        Raises:
            FileNotFoundError: If ``data_path`` does not exist.
        """
        # Accept plain strings as promised by the signature
        data_path = Path(data_path)
        if not data_path.exists():
            print(f"Data file does not exist: {data_path}")
            raise FileNotFoundError(f"Data file does not exist: {data_path}")

        with open(data_path, 'r', encoding='utf-8') as f:
            text = f.read()

        print(f"Loaded {len(text)} characters from {data_path}.")
        unique_chars = sorted(set(text))

        # Add <pad> and <extra> tokens to the vocabulary
        unique_chars.extend(['<pad>', '<extra>'])

        self.setup_tokenizer(unique_chars)

        return self.encode(text)

    def setup_tokenizer(self, unique_chars: List[str]):
        """Install the vocabulary: index ``i`` is assigned to ``unique_chars[i]``."""
        # Create char_to_idx and idx_to_char mappings
        self.char_to_idx = {char: idx for idx, char in enumerate(unique_chars)}
        self.idx_to_char = {idx: char for char, idx in self.char_to_idx.items()}
        self.vocab_size = len(self.char_to_idx)

        print(f"Vocabulary size: {self.vocab_size}")

    def encode(self, text: str) -> torch.Tensor:
        """Encode ``text`` character by character.

        Returns:
            1-D ``torch.long`` tensor on ``config.device``.

        Raises:
            ValueError: If a character is not part of the vocabulary.
        """
        try:
            indices = [self.char_to_idx[char] for char in text]
        except KeyError as e:
            raise ValueError(
                f"Character '{e.args[0]}' not found in vocabulary. Update the vocabulary to include all characters."
            ) from e
        return torch.tensor(indices, dtype=torch.long, device=self.config.device)

    def decode(self, indices: torch.Tensor) -> str:
        """Decode a tensor of indices back into text.

        Batched (multi-dimensional) input is reduced to its first sequence,
        a 0-d tensor is treated as a single index, and an empty tensor decodes
        to ``''``. Indices missing from the vocabulary are reported on stdout
        and rendered as ``'<UNK>'``.
        """
        try:
            # Move tensor to CPU and detach from computation graph
            indices = indices.detach().cpu()

            if indices.numel() == 0:
                return ''

            # If we have leading batch dimensions, take the first sequence
            while indices.dim() > 1:
                indices = indices[0]

            # reshape(-1) also turns a 0-d tensor into a one-element sequence
            index_list = indices.reshape(-1).tolist()

            # Convert indices to characters
            result = []
            for idx in index_list:
                if idx in self.idx_to_char:
                    result.append(self.idx_to_char[idx])
                else:
                    print(f"Unknown index: {idx}")
                    result.append('<UNK>')

            return ''.join(result)

        except Exception as e:
            print(f"Decoding error: {str(e)}")
            raise

In [None]:
@dataclass
class ModelConfig:
    """Model, training, generation and system settings.

    ``__post_init__`` derives several values after construction:

    * ``device`` is always set from ``torch.cuda.is_available()`` (the value
      passed in is overridden);
    * ``dtype`` is ``torch.bfloat16`` on CUDA and ``torch.float32`` on CPU;
    * ``ff_dim`` defaults to ``4 * n_embed`` when left as ``None``;
    * ``save_dir`` is converted to a ``Path`` and created if missing.
    """
    save_dir: Union[str, Path] = Path("/content/drive/MyDrive/LLM/checkpoints")
    vocab_size: int = 0
    block_size: int = 64              # Context length
    n_embed: int = 64                 # Embedding dimension
    n_head: int = 4                   # Number of attention heads
    n_layer: int = 4                  # Number of transformer layers
    ff_dim: Optional[int] = None      # Feed-forward dimension; None -> 4 * n_embed
    #dropout: float = 0.1              # Dropout rate
    bias: bool = True
    flash_attn: bool = True
    head_dim: int = 64
    #attn_pdrop: float = 0.1
    resid_pdrop: float = 0.2
    #embd_pdrop: float = 0.1

    # Training Parameters
    batch_size: int = 128
    gradient_accumulation_steps: int = 1
    learning_rate: float = 1e-3
    min_learning_rate: float = 1e-3
    weight_decay: float = 0.1
    epochs: int = 10
    warmup_steps: int = 200            # Number of warmup steps
    max_grad_norm: float = 1.0
    within_epoch_loss_threshold: float = 0.7  # Early stopping within epoch
    eval_steps: int = 100  # Check loss every N steps

    # Learning Rate Schedule
    lr_schedule: str = 'cosine_with_warmup'
    lr_decay_epochs: int = 8

    # Memory Optimization
    use_amp: bool = True
    amp_dtype: torch.dtype = torch.bfloat16
    dtype: torch.dtype = torch.bfloat16
    use_gradient_checkpointing: bool = False

    # Data Loading
    pin_memory: bool = True
    num_workers: int = 8
    prefetch_factor: int = 5

    # Generation Parameters
    gen_temperature: float = 0.8
    max_gen_tokens: int = 256
    top_k: int = 50
    top_p: float = 0.9

    # Early Stopping
    patience: int = 5
    min_delta: float = 1e-4
    loss_threshold: float = 1.0

    # Evaluation and Checkpointing
    eval_every: int = 1000
    save_every: int = 1
    keep_last_n_checkpoints: int = 5

    # System Parameters
    device: str = "cuda"

    def __post_init__(self):
        """Initialize derived parameters and create the checkpoint directory.

        This is the single post-init hook: the class previously defined two,
        so the ``ff_dim`` default below was never applied.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        if self.ff_dim is None:
            self.ff_dim = 4 * self.n_embed

        # Honour the configured directory instead of overwriting it
        self.save_dir = Path(self.save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)

    def print_params(self) -> None:
        """Print every configuration field and its current value."""
        print("Model Parameters:")
        for key, value in self.__dict__.items():
            print(f"  {key}: {value}")

    def save(self, path: Union[str, Path]) -> None:
        """Save the configuration to a JSON file.

        ``Path`` and ``torch.dtype`` values are stored as strings; the parent
        directory of ``path`` is created if needed.
        """
        path = Path(path)
        config_dict = {
            k: str(v) if isinstance(v, (Path, torch.dtype)) else v
            for k, v in self.__dict__.items()
        }
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open('w') as f:
            json.dump(config_dict, f, indent=4)

    @classmethod
    def load(cls, path: Union[str, Path]) -> 'ModelConfig':
        """Load a configuration previously written by :meth:`save`.

        Raises:
            TypeError: If the file contains keys that are not fields of this class.
        """
        path = Path(path)
        with path.open('r') as f:
            config_dict = json.load(f)

        # save() stringifies every torch.dtype (e.g. "torch.bfloat16"); turn
        # both dtype fields back into real dtypes, not only `dtype`.
        for key in ('dtype', 'amp_dtype'):
            value = config_dict.get(key)
            if isinstance(value, str):
                config_dict[key] = getattr(torch, value.split('.')[-1])

        return cls(**config_dict)

In [None]:
class RotaryEmbedding(nn.Module):
    """Cache of cosine/sine tables for rotary position embeddings (RoPE).

    Args:
        dim: Size of the feature dimension the tables are built for.
        max_position_embeddings: Number of positions pre-computed at
            construction; the cache grows on demand in :meth:`forward`.
        base: Base of the geometric progression of inverse frequencies.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer('inv_freq', inv_freq)
        self.max_seq_len_cached = max_position_embeddings

        # Initialize cache. The tables are non-persistent buffers: they follow
        # the module across devices (no per-call host-to-device copy) but stay
        # out of the state_dict, so checkpoint keys are unchanged.
        cos, sin = self._build_cache(max_position_embeddings)
        self.register_buffer('cos_cached', cos, persistent=False)
        self.register_buffer('sin_cached', sin, persistent=False)

    def _build_cache(self, seq_len):
        """Return ``(cos, sin)`` tables of shape ``(1, 1, seq_len, 2 * len(inv_freq))``."""
        t = torch.arange(seq_len, device=self.inv_freq.device).type_as(self.inv_freq)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        return emb.cos()[None, None, :, :], emb.sin()[None, None, :, :]

    def forward(self, x, seq_len=None):
        """Return the cos/sin tables for the first ``seq_len`` positions.

        Args:
            x: Tensor whose device the returned tables are moved to.
            seq_len: Number of positions. When ``None`` it is taken from
                ``x.shape[-2]`` (assumes ``x`` is laid out as
                ``(..., seq, dim)`` - TODO confirm for new callers).

        Returns:
            Tuple ``(cos, sin)``, each of shape ``(1, 1, seq_len, dim)``.
        """
        if seq_len is None:
            seq_len = x.shape[-2]

        # Grow the cache when a longer sequence than ever seen arrives
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            self.cos_cached, self.sin_cached = self._build_cache(seq_len)

        return (
            self.cos_cached[:, :, :seq_len, ...].to(x.device),
            self.sin_cached[:, :, :seq_len, ...].to(x.device)
        )

def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For ``x = [front, back]`` along the last axis the result is
    ``[-back, front]``.
    """
    midpoint = x.shape[-1] // 2
    front = x[..., :midpoint]
    back = x[..., midpoint:]
    return torch.cat((-back, front), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin):
    """Apply the rotary transform to ``q`` and ``k`` using the same tables.

    Each tensor ``t`` becomes ``t * cos + rotate_half(t) * sin``.

    Returns:
        Tuple ``(q_embed, k_embed)``.
    """
    def _rotate(t):
        return (t * cos) + (rotate_half(t) * sin)

    return _rotate(q), _rotate(k)

class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention layer with RoPE and residual dropout.

    Reads ``n_embed``, ``n_head``, ``bias``, ``resid_pdrop`` and ``block_size``
    from ``config``. Uses ``F.scaled_dot_product_attention`` when the running
    PyTorch provides it, otherwise an explicit masked softmax.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embed % config.n_head == 0

        # Store necessary dimensions
        self.n_embed = config.n_embed
        self.n_head = config.n_head
        self.head_dim = config.n_embed // config.n_head

        # Key, query, value projections for all heads
        self.c_attn = nn.Linear(config.n_embed, 3 * config.n_embed, bias=config.bias)
        self.c_proj = nn.Linear(config.n_embed, config.n_embed, bias=config.bias)

        # Regularization
        self.resid_dropout = nn.Dropout(config.resid_pdrop)

        # Initialize RoPE
        self.rope = RotaryEmbedding(
            self.head_dim,
            max_position_embeddings=config.block_size
        )

        # Flash attention support
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            # Lower-triangular mask for the manual attention path
            self.register_buffer(
                "mask",
                torch.tril(torch.ones(config.block_size, config.block_size))
                .view(1, 1, config.block_size, config.block_size)
            )

        # Scaling factor for attention (used by the manual path)
        self.scale = 1.0 / math.sqrt(self.head_dim)

    def forward(self, x):
        """Attend over ``x`` of shape ``(B, T, C)`` and return the same shape."""
        B, T, C = x.size()

        # Calculate query, key, values for all heads
        q, k, v = self.c_attn(x).split(self.n_embed, dim=2)

        # Reshape for multi-head attention: (B, n_head, T, head_dim)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)

        # Apply RoPE to queries and keys
        cos, sin = self.rope(q, seq_len=T)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # Causal self-attention with flash attention optimization
        if self.flash:
            # No explicit `scale=`: that keyword only exists from PyTorch 2.1,
            # while the hasattr() gate above also passes on 2.0. The default
            # scale is 1/sqrt(head_dim), i.e. exactly self.scale.
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                attn_mask=None,
                dropout_p=0.0,  # Disabled in favor of residual dropout
                is_causal=True
            )
        else:
            att = (q @ k.transpose(-2, -1)) * self.scale
            att = att.masked_fill(self.mask[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            y = att @ v

        # Re-assemble and project
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))
        return y

class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with an optional bias term.

    The weight starts at ones; the bias (when enabled) starts at zeros and is
    ``None`` otherwise. Normalization uses ``eps=1e-5``.
    """

    def __init__(self, ndim, bias=True):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, input):
        return F.layer_norm(
            input,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=1e-5,
        )

class GPTBlock(nn.Module):
    """Pre-norm transformer block: attention and MLP sub-layers, each added to
    the residual stream through a learnable per-channel scale (initialized to
    0.1). Sub-layers may be run under gradient checkpointing.
    """

    def __init__(self, config):
        super().__init__()
        width = config.n_embed
        hidden = 4 * width
        use_bias = config.bias

        self.ln1 = LayerNorm(width, bias=use_bias)
        self.attn = CausalSelfAttention(config)
        self.ln2 = LayerNorm(width, bias=use_bias)
        self.mlp = nn.Sequential(
            nn.Linear(width, hidden, bias=use_bias),
            nn.GELU(),
            nn.Linear(hidden, width, bias=use_bias),
            nn.Dropout(config.resid_pdrop)
        )

        # Training optimizations
        self.use_checkpointing = config.use_gradient_checkpointing
        self.layer_scale_1 = nn.Parameter(torch.ones(width) * 0.1)
        self.layer_scale_2 = nn.Parameter(torch.ones(width) * 0.1)

    def _sublayer(self, fn, norm, x):
        """Run ``fn(norm(x))``, checkpointed when enabled and ``x`` needs grad."""
        normed = norm(x)
        if self.use_checkpointing and x.requires_grad:
            return torch.utils.checkpoint.checkpoint(fn, normed)
        return fn(normed)

    def forward(self, x):
        # Attention block
        attn_output = self._sublayer(self.attn, self.ln1, x)
        x = x + self.layer_scale_1.unsqueeze(0).unsqueeze(0) * attn_output

        # MLP block
        mlp_output = self._sublayer(self.mlp, self.ln2, x)
        x = x + self.layer_scale_2.unsqueeze(0).unsqueeze(0) * mlp_output

        return x

class GPT(nn.Module):
    """GPT-like transformer with residual dropout"""
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.transformer = nn.ModuleDict({
            'wte': nn.Embedding(config.vocab_size, config.n_embed),
            'h': nn.ModuleList([GPTBlock(config) for _ in range(config.n_layer)]),
            'ln_f': nn.LayerNorm(config.n_embed)
        })
        self.lm_head = nn.Linear(config.n_embed, config.vocab_size, bias=False)

        # Init
        self.apply(self._init_weights)
        for pn, p in self.named_parameters():
            if pn.endswith('c_proj.weight'):
                torch.nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

        print(f"Parameters: {sum(p.numel() for p in self.parameters()) / 1e6:.2f}M")

    def _init_weights(self, module):
        if isinstance(module, (nn.Linear, nn.Embedding)):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if isinstance(module, nn.Linear) and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, LayerNorm):
            torch.nn.init.ones_(module.weight)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

    def forward(self, idx, targets=None):
        device = idx.device
        b, t = idx.size()

        with torch.cuda.amp.autocast(enabled=self.config.use_amp, dtype=torch.bfloat16):
            x = self.transformer.wte(idx)
            for block in self.transformer.h:
                x = block(x)
            x = self.transformer.ln_f(x)

            logits = self.lm_head(x)

            # Fixed loss calculation
            if targets is not None:
                loss = F.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    targets.view(-1)
                )
            else:
                # For generation mode
                logits = logits[:, -1, :]
                loss = None

        return logits, loss

    @torch.no_grad()
    def generate(self, idx: torch.Tensor, max_new_tokens: int,
                temperature: float = 1.0, top_k: Optional[int] = None) -> torch.Tensor:
        """Generate tokens with shape validation"""
        if idx.dim() != 2:
            idx = idx.unsqueeze(0)  # Add batch dimension if not present

        for _ in range(max_new_tokens):
            # Crop sequence if needed
            idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]

            # Forward pass
            logits, _ = self(idx_cond)

            # If logits has 3 dimensions (batch, sequence, vocab), take last token
            if logits.dim() == 3:
                logits = logits[:, -1, :]

            # Apply temperature
            logits = logits / temperature

            # Apply top-k sampling
            if top_k is not None:
                v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                logits[logits < v[:, [-1]]] = float('-inf')

            # Sample next token
            probs = F.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            # Append next token
            idx = torch.cat((idx, next_token), dim=1)

        return idx

In [None]:
class ModelInterface:
    """Interface for model operations: builds the model and data manager from a
    config, restores checkpoints, generates text, and releases resources."""

    def __init__(self, config: 'ModelConfig'):
        """Store ``config`` and immediately build the data manager and model."""
        self.config = config
        self.model = None
        self.data_manager = None
        self._is_initialized = False
        self.checkpoint_info = None
        self._initialize()

    def _initialize(self):
        """Create the data manager and the model and move the model to ``config.device``.

        Raises:
            RuntimeError: If any step fails; resources are cleaned up first.
        """
        try:
            # Initialize data manager first
            self.data_manager = DataManager(self.config)

            # Initialize model with proper device placement
            self.model = GPT(self.config)
            self.model = self.model.to(self.config.device)

            self._is_initialized = True

            # Log initialization details
            print(f"Model initialized with {sum(p.numel() for p in self.model.parameters())/1e6:.2f}M parameters")
            print(f"Using device: {self.config.device}")
            print(f"Batch size: {self.config.batch_size}")

            # Adjust gradient accumulation based on GPU memory if needed
            if self.config.device == "cuda":
                gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
                if gpu_mem < 8:  # Only adjust if very limited memory
                    self.config.gradient_accumulation_steps = 2
                    print(f"Adjusted gradient accumulation steps to {self.config.gradient_accumulation_steps} due to limited GPU memory")

        except Exception as e:
            self.cleanup()
            raise RuntimeError(f"Failed to initialize model interface: {str(e)}") from e

    def verify_state(self):
        """Check that model, data manager and vocabulary are present and consistent.

        Raises:
            RuntimeError: If the interface, model, data manager or vocabulary
                mappings are missing.
            ValueError: If the vocabulary sizes of data manager, model and
                config disagree.
        """
        if not self._is_initialized:
            raise RuntimeError("Model interface not properly initialized")

        if self.model is None:
            raise RuntimeError("Model not loaded")

        if self.data_manager is None:
            raise RuntimeError("Data manager not initialized")

        if not self.data_manager.char_to_idx:
            raise RuntimeError("Vocabulary mappings not loaded")

        # Verify vocabulary consistency
        if self.data_manager.vocab_size != self.config.vocab_size:
            raise ValueError(f"Vocabulary size mismatch: {self.data_manager.vocab_size} != {self.config.vocab_size}")

        # Verify model vocabulary size matches config
        if self.model.config.vocab_size != self.config.vocab_size:
            raise ValueError(f"Model vocabulary size mismatch: {self.model.config.vocab_size} != {self.config.vocab_size}")

    def load_model(self, model_path: Union[str, Path]):
        """Restore vocabulary and weights from a checkpoint and switch to eval mode.

        The checkpoint must contain a ``'config'`` mapping. The vocabulary is
        taken from ``config['char_to_idx']``/``config['idx_to_char']`` when both
        are present, otherwise it is recomputed from the training files.
        Weights come from ``'model_state_dict'`` when present, else from the
        checkpoint object itself.

        Raises:
            ValueError: If the checkpoint has no ``'config'`` entry or the
                restored state is inconsistent.
        """
        try:
            # SECURITY NOTE: torch.load unpickles the file; only load
            # checkpoints from sources you trust.
            checkpoint = torch.load(model_path, map_location=self.config.device)

            # First, set up vocabulary mappings before loading model state
            if 'config' not in checkpoint:
                raise ValueError("Checkpoint missing required config")

            ckpt_config = checkpoint['config']

            # Get vocab size from checkpoint config
            vocab_size = ckpt_config.get('vocab_size')
            if vocab_size:
                self.config.vocab_size = vocab_size

            # Set up vocabulary if present in checkpoint
            if 'char_to_idx' in ckpt_config and 'idx_to_char' in ckpt_config:
                self.data_manager.char_to_idx = ckpt_config['char_to_idx']
                self.data_manager.idx_to_char = ckpt_config['idx_to_char']
                self.data_manager.vocab_size = len(self.data_manager.char_to_idx)
            else:
                # If no vocabulary in checkpoint, calculate it from training data
                self.data_manager.calculate_vocab_size()

            # Now load model state
            if 'model_state_dict' in checkpoint:
                self.model.load_state_dict(checkpoint['model_state_dict'])
            else:
                self.model.load_state_dict(checkpoint)

            self.model.eval()
            self.verify_state()

        except Exception as e:
            print(f"Error loading model: {str(e)}")
            raise

    def generate_text(self, prompt: str, max_tokens: int = 100,
                     temperature: float = 0.7, top_k: int = 50) -> str:
        """Generate a continuation of ``prompt`` and return prompt plus continuation.

        Args:
            prompt: Non-empty text; every character must be in the vocabulary.
            max_tokens: Number of tokens to generate (positive).
            temperature: Sampling temperature (positive). The default is 0.7;
                it used to be 0.0, which this method itself rejects, so a call
                relying on the default always failed.
            top_k: Size of the top-k sampling pool (positive).

        Raises:
            ValueError: On an empty prompt, non-positive parameters, or a
                prompt character outside the vocabulary.
            RuntimeError: If the interface is not ready or decoding fails.
        """
        try:
            # Verify state before generation
            self.verify_state()

            if not prompt:
                raise ValueError("Empty prompt provided")

            if temperature <= 0:
                raise ValueError("Temperature must be positive")

            if max_tokens <= 0:
                raise ValueError("max_tokens must be positive")

            if top_k <= 0:
                raise ValueError("top_k must be positive")

            # Set model to evaluation mode
            self.model.eval()

            with torch.no_grad():
                try:
                    # Encode the prompt
                    tokens = self.data_manager.encode(prompt)
                except KeyError as e:
                    raise ValueError(f"Character in prompt not in vocabulary: {e}") from e

                # Ensure proper tensor shape
                if tokens.dim() == 1:
                    tokens = tokens.unsqueeze(0)

                # Move to correct device
                tokens = tokens.to(self.config.device)

                # Generate text
                generated = self.model.generate(
                    tokens,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    top_k=top_k
                )

                # Decode the generated tokens (first sequence of a batch)
                try:
                    if generated.dim() == 2:
                        generated_text = self.data_manager.decode(generated[0])
                    else:
                        generated_text = self.data_manager.decode(generated)
                except Exception as e:
                    raise RuntimeError(f"Error decoding generated tokens: {str(e)}") from e

                return generated_text

        except Exception as e:
            print(f"Error in text generation: {str(e)}")
            raise
        finally:
            # Cleanup
            if self.config.device == "cuda":
                torch.cuda.empty_cache()

    def cleanup(self):
        """Release the model and data manager and free cached GPU memory.

        Best-effort: errors are reported on stdout and not re-raised.
        """
        try:
            if getattr(self, 'model', None) is not None:
                self.model.cpu()
                self.model = None

            if getattr(self, 'data_manager', None) is not None:
                self.data_manager = None

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            gc.collect()

            self._is_initialized = False
            print("Model interface cleaned up successfully")

        except Exception as e:
            print(f"Error during cleanup: {str(e)}")

    def get_model_info(self) -> Dict[str, Any]:
        """Return current model and configuration information.

        ``model_parameters`` is expressed in millions; the checkpoint fields
        are ``None`` unless ``checkpoint_info`` has been set.
        """
        return {
            'initialized': self._is_initialized,
            'device': self.config.device,
            'vocab_size': getattr(self.config, 'vocab_size', None),
            'model_parameters': sum(p.numel() for p in self.model.parameters())/1e6 if self.model else None,
            'has_checkpoint': self.checkpoint_info is not None,
            'checkpoint_epoch': self.checkpoint_info.get('epoch') if self.checkpoint_info else None,
            'checkpoint_loss': self.checkpoint_info.get('loss') if self.checkpoint_info else None
        }

In [None]:
import gc

def run_interactive_chat(model_path: Union[str, Path], temperature=0.7, max_tokens=100, top_k=50):
    """Load a trained model and run an interactive chat session on stdin/stdout.

    Commands typed at the prompt: ``quit`` ends the session, ``params``
    re-prompts for the generation parameters; anything else is used as a
    generation prompt.

    Args:
        model_path: Checkpoint file; must contain a ``'config'`` mapping.
        temperature: Initial sampling temperature.
        max_tokens: Initial number of tokens generated per reply.
        top_k: Initial top-k sampling pool size.

    Raises:
        FileNotFoundError: If ``model_path`` does not exist.
        RuntimeError: If the checkpoint cannot be read.
        ValueError: If the checkpoint has no ``'config'`` entry.
    """
    interface = None
    try:
        # Validate model path
        model_path = Path(model_path)
        if not model_path.exists():
            raise FileNotFoundError(f"Model checkpoint not found: {model_path}")

        # Determine device
        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {device}")

        # Load checkpoint
        try:
            checkpoint = torch.load(model_path, map_location=device)
        except Exception as e:
            raise RuntimeError(f"Failed to load checkpoint: {e}") from e

        if 'config' not in checkpoint:
            raise ValueError("Invalid checkpoint format: missing config")

        # Build the config from the checkpoint, keeping only real ModelConfig
        # fields: the stored config may carry extra entries (e.g. the
        # vocabulary mappings read by ModelInterface.load_model) that would
        # make ModelConfig(**...) fail. 'save_dir' is dropped so the default
        # checkpoint directory is used.
        known_fields = set(ModelConfig.__dataclass_fields__)
        config_dict = {
            key: value
            for key, value in dict(checkpoint['config']).items()
            if key in known_fields and key != 'save_dir'
        }
        config_dict['device'] = device  # Override device in loaded config

        config = ModelConfig(**config_dict)

        # Initialize interface
        interface = ModelInterface(config)
        interface.load_model(model_path)

        # Use interface's data_manager
        data_manager = interface.data_manager

        print(f"\nModel loaded successfully:")
        model_info = interface.get_model_info()
        print(f"- Model size: {model_info['model_parameters']:.2f}M parameters")
        print(f"- Device: {config.device}")
        print(f"- Vocabulary size: {config.vocab_size}")
        print(f"- Generation parameters:")
        print(f"  - Temperature: {temperature}")
        print(f"  - Max tokens: {max_tokens}")
        print(f"  - Top k: {top_k}")

        # Interactive loop
        print("\nStarting interactive session... (type 'quit' to exit, 'params' to update parameters)")
        while True:
            try:
                prompt = input("\nYou: ").strip()
                if not prompt:
                    continue

                if prompt.lower() == 'quit':
                    break

                if prompt.lower() == 'params':
                    # update_params takes no `self`; call it through the class
                    # so no instance is passed as an unexpected argument.
                    temperature, max_tokens, top_k = type(data_manager).update_params()
                    print(f"\nUpdated parameters:")
                    print(f"- Temperature: {temperature}")
                    print(f"- Max tokens: {max_tokens}")
                    print(f"- Top k: {top_k}")
                    continue

                response = interface.generate_text(
                    prompt,
                    max_tokens=max_tokens,
                    temperature=temperature,
                    top_k=top_k
                )
                print("\nModel:", response)

            except KeyboardInterrupt:
                print("\nInterrupted by user")
                break
            except EOFError:
                # stdin was closed: retrying input() would fail forever
                print("\nInput stream closed")
                break
            except Exception as e:
                print(f"\nGeneration error: {str(e)}")
                continue

    except Exception as e:
        print(f"\nError initializing chat: {str(e)}")
        raise

    finally:
        if interface is not None:
            interface.cleanup()
        print("\nSession ended. Resources cleaned up.")

In [None]:
if __name__ == "__main__":
    # Checkpoint to chat with (hard-coded Drive location)
    model_path = '/content/drive/MyDrive/LLM/models/trained_model_loss_1_1983_20250124_191919.pt'
    print(f"Using model checkpoint: {model_path}")

    try:
        # Start the interactive session; it returns when the user quits
        run_interactive_chat(model_path)
    except KeyboardInterrupt:
        print("\nInterrupted by user")
    except Exception as e:
        print(f"\nError: {str(e)}")
    finally:
        # Always release cached GPU memory, whatever the outcome
        if torch.cuda.is_available():
            torch.cuda.empty_cache()