In [2]:
# Install required packages
!pip install -q transformers datasets tqdm huggingface-hub

In [3]:
# Add this at the start of the notebook for better formatting
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

In [7]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Optional, Tuple
import math
import logging
import time
from datetime import datetime
from datasets import load_dataset
from transformers import AutoTokenizer
from tqdm import tqdm
import random
import warnings

# Create project directories
PROJECT_DIR = "deepseek-training"  # Changed from Google Drive path
CHECKPOINTS_DIR = os.path.join(PROJECT_DIR, "checkpoints")
LOGS_DIR = os.path.join(PROJECT_DIR, "logs")

# Create directories if they don't exist
for dir_path in [PROJECT_DIR, CHECKPOINTS_DIR, LOGS_DIR]:
    os.makedirs(dir_path, exist_ok=True)

# Configure logging
log_file = os.path.join(LOGS_DIR, f"training_{datetime.now().strftime('%Y%m%d_%H%M%S')}.log")
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(message)s',
    handlers=[
        logging.FileHandler(log_file),
        logging.StreamHandler()
    ]
)

warnings.filterwarnings("ignore")  # Suppress warnings

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * norm * self.weight

class RotaryEmbedding(nn.Module):
    def __init__(self, dim: int, max_seq_length: int = 2048, theta: float = 10000.0):
        super().__init__()
        self.dim = dim
        self.max_seq_length = max_seq_length

        position = torch.arange(0, max_seq_length, dtype=torch.float)
        freqs = torch.exp(
            -torch.arange(0, dim, 2, dtype=torch.float) * (math.log(theta) / dim)
        )
        angles = position.unsqueeze(1) * freqs.unsqueeze(0)
        self.register_buffer("cos", angles.cos())
        self.register_buffer("sin", angles.sin())

    def forward(self, x: torch.Tensor, seq_len: Optional[int] = None) -> torch.Tensor:
        if seq_len is None:
            seq_len = x.size(1)

        cos = self.cos[:seq_len]
        sin = self.sin[:seq_len]

        cos = cos.view(1, seq_len, 1, -1)
        sin = sin.view(1, seq_len, 1, -1)

        x_even = x[..., ::2]
        x_odd = x[..., 1::2]

        rotated = torch.cat([
            x_even * cos - x_odd * sin,
            x_even * sin + x_odd * cos
        ], dim=-1)

        return rotated

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, num_key_value_heads: int, dropout: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.num_key_value_heads = num_key_value_heads
        self.head_dim = d_model // num_heads
        self.d_model = d_model

        self.q_proj = nn.Linear(d_model, d_model, bias=False)
        self.k_proj = nn.Linear(d_model, (d_model // num_heads) * num_key_value_heads, bias=False)
        self.v_proj = nn.Linear(d_model, (d_model // num_heads) * num_key_value_heads, bias=False)
        self.o_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        batch_size, seq_len, _ = x.shape

        q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim)
        k = self.k_proj(x).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
        v = self.v_proj(x).view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)

        if self.num_key_value_heads != self.num_heads:
            k = k.repeat_interleave(self.num_heads // self.num_key_value_heads, dim=2)
            v = v.repeat_interleave(self.num_heads // self.num_key_value_heads, dim=2)

        q = q.transpose(1, 2)
        k = k.transpose(1, 2)
        v = v.transpose(1, 2)

        scale = 1.0 / math.sqrt(self.head_dim)
        scores = torch.matmul(q, k.transpose(-2, -1)) * scale

        if attention_mask is not None:
            scores = scores.masked_fill(attention_mask == 0, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous()
        out = out.view(batch_size, seq_len, self.d_model)

        return self.o_proj(out)

class MoELayer(nn.Module):
    def __init__(self, d_model: int, d_ff: int, num_experts: int = 4, dropout: float = 0.0):
        super().__init__()
        self.num_experts = num_experts

        # Initialize experts
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(d_model, d_ff),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(d_ff, d_model),
                nn.Dropout(dropout)
            )
            for _ in range(num_experts)
        ])

        # Router
        self.router = nn.Linear(d_model, num_experts)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len, d_model = x.shape

        # Get routing probabilities
        route_logits = self.router(x)  # [batch_size, seq_len, num_experts]
        route_probs = F.softmax(route_logits, dim=-1)

        # Process each token through experts
        combined_output = torch.zeros_like(x)
        for i in range(self.num_experts):
            # Get expert outputs
            expert_output = self.experts[i](x)
            # Weight by routing probability
            combined_output += expert_output * route_probs[..., i:i+1]

        return combined_output

class TransformerBlock(nn.Module):
    def __init__(self, d_model: int, num_heads: int, num_key_value_heads: int, d_ff: int, num_experts: int, dropout: float = 0.0):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads, num_key_value_heads, dropout)
        self.moe = MoELayer(d_model, d_ff, num_experts, dropout)
        self.norm1 = RMSNorm(d_model)
        self.norm2 = RMSNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Self-attention
        residual = x
        x = self.norm1(x)
        x = self.attention(x, attention_mask)
        x = self.dropout(x)
        x = residual + x

        # MoE FFN
        residual = x
        x = self.norm2(x)
        x = self.moe(x)
        x = residual + x

        return x

class DeepSeekLM(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 384,
        num_layers: int = 20,
        num_heads: int = 6,
        d_ff: int = 1536,
        num_experts: int = 4,
        rope_theta: float = 10000.0,
        max_seq_length: int = 2048,
        dropout: float = 0.1
    ):
        super().__init__()
        self.d_model = d_model
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.d_ff = d_ff
        self.num_experts = num_experts
        self.head_dim = d_model // num_heads
        self.num_key_value_heads = max(1, num_heads // 3)

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_embedding = RotaryEmbedding(
            self.head_dim,
            max_seq_length=max_seq_length,
            theta=rope_theta
        )

        self.blocks = nn.ModuleList([
            TransformerBlock(
                d_model=d_model,
                num_heads=num_heads,
                num_key_value_heads=self.num_key_value_heads,
                d_ff=d_ff,
                num_experts=num_experts,
                dropout=dropout
            )
            for _ in range(num_layers)
        ])

        self.norm = RMSNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)

        # Initialize weights
        self.apply(self._init_weights)

        self.gradient_checkpointing = False

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        x = self.embedding(input_ids)

        batch_size, seq_len, _ = x.shape
        x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
        x = self.pos_embedding(x)
        x = x.view(batch_size, seq_len, self.d_model)

        if self.gradient_checkpointing and self.training:
            for block in self.blocks:
                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        return module(*inputs)
                    return custom_forward

                x = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block),
                    x, attention_mask
                )
        else:
            for block in self.blocks:
                x = block(x, attention_mask)

        x = self.norm(x)
        logits = self.lm_head(x)

        return logits

    def gradient_checkpointing_enable(self):
        """Enable gradient checkpointing"""
        self.gradient_checkpointing = True

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing"""
        self.gradient_checkpointing = False

def count_parameters(model):
    """Count and format model parameters"""
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    # Count parameters by layer type
    layer_params = {}
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        layer_type = name.split('.')[0]
        layer_params[layer_type] = layer_params.get(layer_type, 0) + p.numel()

    return total_params, layer_params

def format_number(num):
    """Format large numbers with commas and calculate size in MB"""
    return f"{num:,} (~{num/1e6:.1f}M)"

def get_latest_checkpoint():
    """Find the latest valid checkpoint in the checkpoints directory"""
    if not os.path.exists(CHECKPOINTS_DIR):
        return None

    checkpoints = [f for f in os.listdir(CHECKPOINTS_DIR)
                  if (f.startswith('step_') or f.startswith('interrupt_step_'))
                  and f.endswith('.pt')]
    if not checkpoints:
        return None

    # Extract step numbers and find latest valid checkpoint
    valid_checkpoints = []
    for f in checkpoints:
        try:
            checkpoint_path = os.path.join(CHECKPOINTS_DIR, f)
            # Try to load checkpoint to verify integrity
            torch.load(checkpoint_path, map_location='cpu')  # Load on CPU first to verify

            if f.startswith('interrupt_step_'):
                step = int(f.split('_')[2].split('.')[0])
            else:
                step = int(f.split('_')[1].split('.')[0])
            valid_checkpoints.append((step, f))
        except Exception as e:
            print(f"Skipping corrupted checkpoint {f}: {str(e)}")
            continue

    if not valid_checkpoints:
        return None

    # Get the checkpoint with highest step count
    _, latest_file = max(valid_checkpoints, key=lambda x: x[0])
    checkpoint_path = os.path.join(CHECKPOINTS_DIR, latest_file)
    return checkpoint_path

def save_checkpoint(model, optimizer, step, loss, tokens_processed, is_interrupt=False, is_final=False):
    """Save checkpoint with proper error handling"""
    try:
        if is_final:
            filename = "final_model.pt"
        elif is_interrupt:
            filename = f"interrupt_step_{step}.pt"
        else:
            filename = f"step_{step}.pt"

        checkpoint_path = os.path.join(CHECKPOINTS_DIR, filename)

        # Save to temporary file first
        temp_path = checkpoint_path + '.tmp'
        torch.save({
            'step': step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            'tokens_processed': tokens_processed,
        }, temp_path)

        # If save was successful, rename to final filename
        os.replace(temp_path, checkpoint_path)
        return checkpoint_path
    except Exception as e:
        print(f"Error saving checkpoint: {str(e)}")
        if os.path.exists(temp_path):
            os.remove(temp_path)
        return None

def setup_tokenizer_and_dataset(max_length=256):
    """Setup tokenizer and dataset with progress tracking"""
    print("\nInitializing training components:")
    print("="*50)

    # Initialize tokenizer with progress tracking
    print("Loading tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(
        "HuggingFaceTB/cosmo2-tokenizer",
        trust_remote_code=True,
        use_fast=True  # Use fast tokenizer
    )
    tokenizer.pad_token = tokenizer.eos_token
    print("✓ Tokenizer loaded successfully")

    # Load dataset with progress tracking
    print("\nLoading dataset...")
    try:
        dataset = load_dataset(
            "HuggingFaceTB/smollm-corpus",
            name="cosmopedia-v2",
            streaming=True,
            split="train"
        )

        # Calculate approximate dataset size
        dataset_size = dataset.dataset_size
        sample_size = int(dataset_size * 0.3)
        print(f"✓ Dataset loaded successfully")
        print(f"  - Total samples: {dataset_size:,}")
        print(f"  - Using {sample_size:,} samples (30%)")

        # Take 30% of data with progress tracking
        dataset = dataset.take(sample_size)
        dataset = dataset.shuffle(seed=42)

        return tokenizer, dataset

    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        raise

def setup_logging():
    """Setup logging with both file and console handlers"""
    # Create a unique session ID
    session_id = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Ensure directories exist
    os.makedirs(LOGS_DIR, exist_ok=True)

    # Setup file handler to append to existing log file
    log_file = os.path.join(LOGS_DIR, "training_history.log")
    file_handler = logging.FileHandler(log_file, mode='a')
    file_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))

    # Setup console handler with minimal output
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(logging.Formatter('%(message)s'))
    console_handler.setLevel(logging.WARNING)

    # Configure root logger
    logging.getLogger().handlers = []
    logging.basicConfig(
        level=logging.INFO,
        handlers=[file_handler, console_handler]
    )

    # Create CSV progress file for this session
    progress_file = os.path.join(LOGS_DIR, f"progress_{session_id}.csv")
    os.makedirs(os.path.dirname(progress_file), exist_ok=True)
    with open(progress_file, 'w') as f:
        f.write("step,loss,tokens_processed,time,tokens_per_sec\n")

    return log_file, progress_file, session_id

def train(resume_training=True):
    """Train the model with support for resuming from checkpoints"""
    # Setup logging first
    log_file, progress_file, session_id = setup_logging()

    # Log session start
    logging.info("\n" + "="*50)
    logging.info(f"Starting new training session: {session_id}")

    # Create directories if needed
    os.makedirs(CHECKPOINTS_DIR, exist_ok=True)
    os.makedirs(LOGS_DIR, exist_ok=True)

    # Initialize components
    tokenizer, dataset = setup_tokenizer_and_dataset()
    dataset_iter = iter(dataset)

    # Training parameters - optimized for speed
    batch_size = 4  # Increased from 2
    learning_rate = 3e-4
    max_steps = 10000
    save_every = 1000
    log_every = 100
    max_length = 256  # Keep this the same
    grad_clip = 1.0
    warmup_steps = 1000
    gradient_accumulation_steps = 4  # Reduced from 8 (maintains same effective batch size)

    # Model parameters - adjusted to target exactly 135M parameters
    model = DeepSeekLM(
        vocab_size=tokenizer.vocab_size,
        d_model=512,        # Reduced from 576
        num_layers=12,      # Adjusted from 10
        num_heads=8,        # Reduced from 9
        d_ff=2048,         # 4x d_model
        num_experts=4,      # Reduced from 6
        max_seq_length=max_length,
        dropout=0.1
    )

    # Print model size details
    total_params = sum(p.numel() for p in model.parameters())/1e6
    print("\nModel Configuration:")
    print("="*50)
    print(f"Total Parameters: {total_params:.2f}M")
    print(f"Embedding dim: {model.d_model}")
    print(f"Layers: {model.num_layers}")
    print(f"Attention heads: {model.num_heads}")
    print(f"FF dim: {model.d_ff}")
    print(f"MoE experts: {model.num_experts}")
    print(f"Sequence length: {max_length}")
    print(f"Batch size: {batch_size} (effective: {batch_size * gradient_accumulation_steps})")
    print("="*50)

    # Enable memory optimizations
    model.gradient_checkpointing_enable()
    print("✓ Gradient checkpointing enabled")

    # Set memory efficient attention
    torch.backends.cuda.max_memory_split_size = None
    torch.backends.cuda.max_memory_cached = None

    # Empty CUDA cache before moving model to device
    torch.cuda.empty_cache()

    # Move model to device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Initialize mixed precision training with lower precision
    scaler = torch.cuda.amp.GradScaler()
    torch.backends.cudnn.benchmark = True

    # Enable faster training options
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True  # Allow TF32 on Ampere GPUs
    torch.backends.cudnn.allow_tf32 = True

    # Data loading optimization
    def get_next_batch(dataset_iter):
        """Get next batch of data with error handling"""
        max_retries = 3
        for retry in range(max_retries):
            try:
                batch_texts = []
                while len(batch_texts) < batch_size:
                    try:
                        sample = next(dataset_iter)
                        if not isinstance(sample['text'], str):
                            continue
                        if len(sample['text'].strip()) < 10:  # Skip very short texts
                            continue
                        batch_texts.append(sample['text'])
                    except StopIteration:
                        dataset_iter = iter(dataset.shuffle())
                        sample = next(dataset_iter)
                        batch_texts.append(sample['text'])

                # Process current batch
                inputs = tokenizer(
                    batch_texts,
                    truncation=True,
                    max_length=max_length,
                    padding='max_length',
                    return_tensors='pt'
                )
                return inputs, dataset_iter
            except Exception as e:
                if retry == max_retries - 1:
                    raise Exception(f"Failed to get batch after {max_retries} attempts: {str(e)}")
                print(f"Retry {retry + 1}/{max_retries}: Error getting batch: {str(e)}")
                continue

    # Initialize session
    start_time = time.time()

    # Initialize optimizer and learning rate scheduler
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=learning_rate,
        betas=(0.9, 0.95),
        eps=1e-8,
        weight_decay=0.1
    )

    # Add learning rate scheduler
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=max_steps,
        eta_min=learning_rate/10
    )

    # Initialize training state
    step = 0
    tokens_processed = 0
    best_loss = float('inf')

    # Try to load checkpoint if resume_training is True
    if resume_training:
        checkpoint_path = get_latest_checkpoint()
        if checkpoint_path:
            print(f"\nFound valid checkpoint at {checkpoint_path}")
            try:
                checkpoint = torch.load(checkpoint_path, map_location=device)

                # Verify checkpoint contents
                required_keys = ['model_state_dict', 'optimizer_state_dict', 'step', 'loss']
                if not all(k in checkpoint for k in required_keys):
                    raise ValueError("Checkpoint missing required keys")

                model.load_state_dict(checkpoint['model_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

                # Move optimizer state to correct device
                for state in optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
                            state[k] = v.to(device)

                step = checkpoint['step']
                tokens_processed = checkpoint.get('tokens_processed', 0)
                best_loss = checkpoint.get('loss', float('inf'))

                # Update scheduler to correct step
                for _ in range(step):
                    scheduler.step()

                print(f"✓ Successfully resumed training from step {step}")
                print(f"  - Previous loss: {best_loss:.4f}")
                print(f"  - Tokens processed: {tokens_processed:,}")

            except Exception as e:
                print(f"Error loading checkpoint: {str(e)}")
                print("Starting fresh training...")
                step = 0
                tokens_processed = 0
                best_loss = float('inf')
        else:
            print("No valid checkpoints found, starting fresh training")
            step = 0
            tokens_processed = 0
            best_loss = float('inf')
    else:
        print("Starting fresh training")
        step = 0
        tokens_processed = 0
        best_loss = float('inf')

    # Get detailed parameter counts and log architecture details
    total_params, layer_params = count_parameters(model)

    logging.info("\nModel Architecture Details:")
    logging.info("="*50)
    logging.info(f"Total Parameters: {format_number(total_params)}")
    logging.info("\nParameters by layer:")
    for layer, params in layer_params.items():
        logging.info(f"- {layer}: {format_number(params)}")

    logging.info("\nArchitecture Configuration:")
    logging.info(f"- Model dimension: {model.d_model}")
    logging.info(f"- Number of layers: {model.num_layers}")
    logging.info(f"- Attention heads: {model.num_heads}")
    logging.info(f"- KV heads: {model.num_key_value_heads} (MLHA)")
    logging.info(f"- FF dimension: {model.d_ff}")
    logging.info(f"- MoE experts: {model.num_experts}")
    logging.info("="*50 + "\n")

    # Training loop
    model.train()
    running_loss = 0.0

    try:
        print("\nStarting training loop...")
        progress_bar = tqdm(
            total=max_steps,
            desc="Training",
            position=0,
            leave=True,
            ncols=100,
            disable=True  # Disable tqdm bar, we'll use our own progress format
        )

        while step < max_steps:
            try:
                accumulated_loss = 0
                optimizer.zero_grad(set_to_none=True)

                # Gradient accumulation loop
                for accum_step in range(gradient_accumulation_steps):
                    # Get next batch with progress info
                    if step == 0 and accum_step == 0:
                        print("Getting first batch...")

                    inputs, dataset_iter = get_next_batch(dataset_iter)
                    if step == 0 and accum_step == 0:
                        print("✓ First batch processed successfully")

                    inputs = {k: v.to(device, non_blocking=True) for k, v in inputs.items()}

                    # Training step with progress info
                    if step == 0 and accum_step == 0:
                        print("Starting first forward pass...")

                    with torch.cuda.amp.autocast(dtype=torch.float16):
                        logits = model(inputs['input_ids'])
                        shift_logits = logits[..., :-1, :].contiguous()
                        shift_labels = inputs['input_ids'][..., 1:].contiguous()
                        loss = F.cross_entropy(
                            shift_logits.view(-1, tokenizer.vocab_size),
                            shift_labels.view(-1)
                        )
                        loss = loss / gradient_accumulation_steps

                    if step == 0 and accum_step == 0:
                        print(f"✓ First forward pass complete (loss: {loss.item():.4f})")
                        print("Starting first backward pass...")

                    # Scale loss and backward pass
                    scaler.scale(loss).backward()
                    accumulated_loss += loss.item()

                    if step == 0 and accum_step == 0:
                        print("✓ First backward pass complete")

                    # Memory cleanup
                    del logits, shift_logits, shift_labels
                    if step % 100 == 0:
                        torch.cuda.empty_cache()

                # Optimizer step with gradient clipping
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
                scaler.step(optimizer)
                scaler.update()

                if step < warmup_steps:
                    lr = learning_rate * (step / warmup_steps)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr
                else:
                    scheduler.step()

                # Update counters
                step += 1
                tokens_processed += inputs['input_ids'].numel() * gradient_accumulation_steps
                running_loss += accumulated_loss

                # Print progress for every step
                elapsed = time.time() - start_time
                tokens_per_sec = tokens_processed / elapsed
                gpu_mem_alloc = torch.cuda.memory_allocated() / 1024**2
                gpu_mem_reserved = torch.cuda.memory_reserved() / 1024**2

                progress_msg = (
                    f"\rStep {step:5d}/{max_steps} ({step/max_steps*100:.1f}%) | "
                    f"Loss: {accumulated_loss:.4f} | "
                    f"Time: {elapsed:.1f}s | "
                    f"Tokens/sec: {tokens_per_sec:.2f} | "
                    f"Total Tokens: {tokens_processed:,} | "
                    f"GPU Memory: {gpu_mem_alloc:.0f}MB/{gpu_mem_reserved:.0f}MB"
                )
                print(progress_msg, end="", flush=True)

                # Log detailed stats every 100 steps
                if step % log_every == 0:
                    avg_loss = running_loss / log_every

                    # Log to file
                    logging.info(progress_msg)

                    # Update CSV progress
                    try:
                        with open(progress_file, 'a') as f:
                            f.write(f"{step},{avg_loss},{tokens_processed},{elapsed},{tokens_per_sec}\n")
                    except Exception as e:
                        print(f"\nWarning: Could not write to progress file: {str(e)}")

                    running_loss = 0.0
                    print()  # New line after logging

                # Save checkpoint every 1000 steps
                if step % save_every == 0:
                    saved_path = save_checkpoint(
                        model, optimizer, step, accumulated_loss,
                        tokens_processed, is_interrupt=False
                    )
                    if saved_path:
                        logging.info(f"Saved checkpoint to {saved_path}")

            except Exception as e:
                print(f"\nError during training step {step}: {str(e)}")
                raise

        progress_bar.close()

        # Log completion
        logging.info("\nTraining completed successfully!")
        logging.info(f"Total steps: {step}")
        logging.info(f"Total tokens processed: {tokens_processed:,}")
        logging.info(f"Best loss achieved: {best_loss:.4f}")
        logging.info("="*50 + "\n")

    except KeyboardInterrupt:
        progress_bar.close()
        logging.warning("\nTraining interrupted! Saving checkpoint...")
        save_checkpoint(
            model, optimizer, step, accumulated_loss,
            tokens_processed, is_interrupt=True
        )

    return model, tokenizer, log_file

def generate_samples(model, tokenizer, device, num_samples=5, max_new_tokens=100, temperature=1.0):
    """Generate text samples from the model"""
    print("\nGenerating samples:")
    print("="*50)

    model.eval()
    prompts = [
        "The quantum mechanics of particles describes how",
        "In computer science, neural networks can learn to",
        "The theory of relativity fundamentally changed our understanding of",
        "Machine learning algorithms have revolutionized",
        "The structure of DNA contains information about"
    ]

    # Generation parameters - tuned for better output
    temperature = 1.0       # Higher temperature for more randomness
    top_k = 100            # Increased for more variety
    top_p = 0.95          # Slightly higher nucleus sampling threshold
    repetition_penalty = 1.5  # Increased repetition penalty
    min_length = 20       # Ensure minimum output length

    for i, prompt in enumerate(prompts[:num_samples], 1):
        print(f"\nSample {i}:")
        print(f"Prompt: {prompt}")

        with torch.no_grad():
            input_ids = tokenizer(prompt, return_tensors='pt')['input_ids'].to(device)
            output_sequence = input_ids.clone()

            # Track generated tokens and their counts
            token_counts = {}
            generated_tokens = []

            for _ in range(max_new_tokens):
                outputs = model(output_sequence)
                next_token_logits = outputs[:, -1, :] / temperature

                # Apply stronger repetition penalty based on token frequency
                for token, count in token_counts.items():
                    next_token_logits[0, token] /= (repetition_penalty ** count)

                # Apply top-k filtering
                top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)

                # Apply dynamic nucleus sampling
                probs = torch.softmax(top_k_logits, dim=-1)
                cumulative_probs = torch.cumsum(probs, dim=-1)
                nucleus_mask = cumulative_probs < top_p
                nucleus_mask[..., 1:] = nucleus_mask[..., :-1].clone()
                nucleus_mask[..., 0] = True

                # Apply stronger filtering for common tokens
                filtered_logits = top_k_logits.masked_fill(~nucleus_mask, float('-inf'))
                filtered_probs = torch.softmax(filtered_logits, dim=-1)

                next_token_idx = torch.multinomial(filtered_probs, num_samples=1)
                next_token = top_k_indices.gather(-1, next_token_idx)

                # Update token counts
                token = next_token.item()
                token_counts[token] = token_counts.get(token, 0) + 1

                # Stop if we generate an EOS token and passed minimum length
                if next_token.item() == tokenizer.eos_token_id and len(generated_tokens) >= min_length:
                    break

                # Add token and check for repetition
                generated_tokens.append(token)
                output_sequence = torch.cat([output_sequence, next_token], dim=1)

                # Check for repetition patterns
                if len(generated_tokens) >= 4:
                    last_4 = generated_tokens[-4:]
                    if len(set(last_4)) == 1 or (  # Same token repeated
                        len(generated_tokens) >= 8 and
                        generated_tokens[-4:] == generated_tokens[-8:-4]  # Repeating pattern
                    ):
                        break

            generated_text = tokenizer.decode(output_sequence[0], skip_special_tokens=True)
            print(f"Generated: {generated_text}")
            print("-"*50)

if __name__ == "__main__":

    from google.colab import drive
    drive.mount('/content/drive')

    # Update paths for Google Drive
    global PROJECT_DIR, CHECKPOINTS_DIR, LOGS_DIR
    PROJECT_DIR = "/content/drive/MyDrive/deepseek-training"
    CHECKPOINTS_DIR = os.path.join(PROJECT_DIR, "checkpoints")
    LOGS_DIR = os.path.join(PROJECT_DIR, "logs")

    # Create directories
    for dir_path in [PROJECT_DIR, CHECKPOINTS_DIR, LOGS_DIR]:
        os.makedirs(dir_path, exist_ok=True)
    # Set resume_training=False to force fresh training
    model, tokenizer, log_file = train(resume_training=True)

    print("\nTraining Summary:")
    print("="*50)

    # Display model architecture
    with open(log_file, 'r') as f:
        lines = f.readlines()
        arch_section = False
        for line in lines:
            if "Model Architecture Details:" in line:
                arch_section = True
                print("\n" + line.strip())
            elif arch_section and "Training" in line:
                arch_section = False
            elif arch_section:
                print(line.strip())

    # Display final training stats
    print("\nFinal Training Stats:")
    print("-"*50)
    for line in lines[-10:]:  # Show last 10 lines of training log
        if "step" in line.lower() or "loss" in line.lower():
            print(line.strip())

    # Generate samples
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generate_samples(model, tokenizer, device, temperature=1.0)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).

Initializing training components:
Loading tokenizer...
✓ Tokenizer loaded successfully

Loading dataset...


Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/104 [00:00<?, ?it/s]

✓ Dataset loaded successfully
  - Total samples: 212,503,640,747
  - Using 63,751,092,224 samples (30%)

Model Configuration:
Total Parameters: 159.02M
Embedding dim: 512
Layers: 12
Attention heads: 8
FF dim: 2048
MoE experts: 4
Sequence length: 256
Batch size: 4 (effective: 16)
✓ Gradient checkpointing enabled

Found valid checkpoint at /content/drive/MyDrive/deepseek-training/checkpoints/step_10000.pt
✓ Successfully resumed training from step 10000
  - Previous loss: 0.3837
  - Tokens processed: 40,960,000

Starting training loop...

Training Summary:

Model Architecture Details:
2025-02-07 12:48:43,292 - Total Parameters: 159,019,568 (~159.0M)
2025-02-07 12:48:43,292 -
Parameters by layer:
2025-02-07 12:48:43,292 - - embedding: 25,165,824 (~25.2M)
2025-02-07 12:48:43,292 - - blocks: 108,687,408 (~108.7M)
2025-02-07 12:48:43,292 - - norm: 512 (~0.0M)
2025-02-07 12:48:43,292 - - lm_head: 25,165,824 (~25.2M)
2025-02-07 12:48:43,292 -
Architecture Configuration:
2025-02-07 12:48:43,292 