In [1]:
# ============================================================================
# CELL 1: Package Installation
# ============================================================================

!pip install bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.49.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Downloading bitsandbytes-0.49.0-py3-none-manylinux_2_24_x86_64.whl (59.1 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m59.1/59.1 MB[0m [31m20.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: bitsandbytes
Successfully installed bitsandbytes-0.49.0


In [2]:
# ============================================================================
# CELL 2: Import Libraries and Mount Drive
# ============================================================================

import os
import json
import time
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np
from tqdm.auto import tqdm
from datetime import datetime
import threading
import queue
from datasets import load_dataset
import tiktoken
from torch.cuda.amp import autocast, GradScaler
import bitsandbytes as bnb
import shutil
import gc

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

print(f"‚úÖ Libraries imported")
print(f"üîß PyTorch version: {torch.__version__}")
print(f"üéÆ CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"üéÆ GPU: {torch.cuda.get_device_name(0)}")
    print(f"üíæ GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

Mounted at /content/drive
‚úÖ Libraries imported
üîß PyTorch version: 2.9.0+cu126
üéÆ CUDA available: True
üéÆ GPU: Tesla T4
üíæ GPU Memory: 15.83 GB


In [14]:
# ============================================================================
# CELL 3: Configuration
# ============================================================================

class Config:
    """Training configuration"""

    # Paths
    DRIVE_DATA_DIR = "/content/drive/MyDrive/llm_training/data"
    DRIVE_CHECKPOINT_DIR = "/content/drive/MyDrive/llm_training/checkpoints"
    LOCAL_CACHE_DIR = "/content/training_cache"

    # Model architecture
    VOCAB_SIZE = 50257  # GPT-2 tokenizer vocab size
    D_MODEL = 768       # Model dimension
    N_LAYERS = 12       # Number of transformer layers
    N_HEADS = 12         # Number of attention heads
    D_FF = 3072         # Feed-forward dimension
    MAX_SEQ_LEN = 1024  # Maximum sequence length
    DROPOUT = 0.1

    # Training hyperparameters
    BATCH_SIZE = 6
    GRADIENT_ACCUM_STEPS = 21
    LEARNING_RATE = 5e-5
    WEIGHT_DECAY = 0.1
    MAX_GRAD_NORM = 1.0
    WARMUP_STEPS = 1750

    # Data loading
    CHUNK_SIZE_GB = 30      # Load 30GB at a time to Colab temp
    PREFETCH_CHUNKS = 1     # Number of chunks to prefetch

    # Checkpointing and logging
    CHECKPOINT_EVERY = 100  # Steps
    LOG_EVERY = 10           # Steps
    SAMPLE_EVERY = 100       # Generate samples
    SAVE_TOTAL_LIMIT = 3     # Keep only last 3 checkpoints

    # Training
    MAX_STEPS = 25000       # Total training steps
    USE_AMP = True           # Mixed precision training

    @classmethod
    def estimate_params(cls):
        """Estimate total parameters"""
        # Embedding
        embed_params = cls.VOCAB_SIZE * cls.D_MODEL

        # Transformer blocks
        # Each block: 4*d_model*d_model (attention) + 2*d_model*d_ff (FFN) + layer norms
        per_layer = 4 * cls.D_MODEL**2 + 2 * cls.D_MODEL * cls.D_FF
        transformer_params = cls.N_LAYERS * per_layer

        # Output head
        output_params = cls.D_MODEL * cls.VOCAB_SIZE

        total = embed_params + transformer_params + output_params
        return total / 1e6  # Return in millions

    @classmethod
    def print_config(cls):
        print("="*60)
        print("TRAINING CONFIGURATION")
        print("="*60)
        print(f"Model Parameters: ~{cls.estimate_params():.0f}M")
        print(f"Sequence Length: {cls.MAX_SEQ_LEN}")
        print(f"Batch Size: {cls.BATCH_SIZE} √ó {cls.GRADIENT_ACCUM_STEPS} = {cls.BATCH_SIZE * cls.GRADIENT_ACCUM_STEPS}")
        print(f"Learning Rate: {cls.LEARNING_RATE}")
        print(f"Max Steps: {cls.MAX_STEPS}")
        print(f"Checkpoint Every: {cls.CHECKPOINT_EVERY} steps")
        print(f"Mixed Precision: {cls.USE_AMP}")
        print("="*60)

# Create directories
os.makedirs(Config.DRIVE_DATA_DIR, exist_ok=True)
os.makedirs(Config.DRIVE_CHECKPOINT_DIR, exist_ok=True)
os.makedirs(Config.LOCAL_CACHE_DIR, exist_ok=True)

Config.print_config()

TRAINING CONFIGURATION
Model Parameters: ~162M
Sequence Length: 1024
Batch Size: 6 √ó 21 = 126
Learning Rate: 5e-05
Max Steps: 25000
Checkpoint Every: 100 steps
Mixed Precision: True


In [4]:
# ============================================================================
# CELL 4: Complete Smart Dataset
# ============================================================================

import threading

class SmartResumeTokenDataset(Dataset):
    '''Dataset with disk cleaner, smart offset, thread safety'''

    def __init__(self, token_files, seq_length, local_cache_dir, start_step=0):
        self.token_files = sorted(token_files)
        self.seq_length = seq_length
        self.local_cache_dir = Path(local_cache_dir)
        self.local_cache_dir.mkdir(exist_ok=True)

        self._file_locks = {}

        # Scan dataset
        print("üìä Scanning dataset...")
        self.file_token_counts = []
        self.total_tokens = 0

        for f in tqdm(self.token_files, desc="Scanning"):
            try:
                tokens_in_file = os.path.getsize(f) // 2
                self.file_token_counts.append(tokens_in_file)
                self.total_tokens += tokens_in_file
            except: pass

        self.num_sequences = self.total_tokens // seq_length

        # Smart offset calculation
        tokens_per_step = Config.BATCH_SIZE * Config.MAX_SEQ_LEN * Config.GRADIENT_ACCUM_STEPS
        tokens_consumed = start_step * tokens_per_step
        self.start_offset = tokens_consumed // seq_length

        # Calculate starting file
        cumulative = 0
        start_file_idx = 0
        for i, count in enumerate(self.file_token_counts):
            if cumulative + count > tokens_consumed:
                start_file_idx = i
                break
            cumulative += count

        print(f"‚úÖ Dataset ready: {self.total_tokens:,} tokens")
        if start_step > 0:
            print(f"üéØ Smart Resume: Starting from file #{start_file_idx}")

        self.current_chunk = None
        self.current_file_idx = -1
        self.files_seen = set()

    def __len__(self):
        return max(0, self.num_sequences - self.start_offset)

    def _get_file_lock(self, file_idx):
        if file_idx not in self._file_locks:
            self._file_locks[file_idx] = threading.Lock()
        return self._file_locks[file_idx]

    def _clean_cache(self, keep_file_name=None):
        '''Delete old cached files to prevent disk full'''
        for f in self.local_cache_dir.glob('*.npy'):
            if keep_file_name and f.name == keep_file_name:
                continue
            try: f.unlink()
            except: pass

    def _load_file(self, file_idx):
        source_file = self.token_files[file_idx]
        cache_file = self.local_cache_dir / source_file.name
        lock = self._get_file_lock(file_idx)

        with lock:
            if cache_file.exists():
                try: return np.load(cache_file, mmap_mode='r')
                except: cache_file.unlink()

            # Clean cache before loading new file
            self._clean_cache(keep_file_name=source_file.name)

            print(f"üì• Loading {source_file.name}...")
            temp_file = self.local_cache_dir / f".tmp_{source_file.name}"
            try:
                tokens = np.load(source_file)
                np.save(temp_file, tokens)
                temp_file.rename(cache_file)
                return np.load(cache_file, mmap_mode='r')
            except Exception as e:
                if temp_file.exists(): temp_file.unlink()
                raise RuntimeError(f"Failed to load: {e}")

    def __getitem__(self, idx):
        actual_idx = self.start_offset + idx
        if actual_idx >= self.num_sequences:
            actual_idx = actual_idx % self.num_sequences

        token_pos = actual_idx * self.seq_length

        # Find file
        cumsum = 0
        file_idx = 0
        for i, count in enumerate(self.file_token_counts):
            if cumsum + count > token_pos:
                file_idx = i
                break
            cumsum += count

        if file_idx != self.current_file_idx:
            self.current_chunk = self._load_file(file_idx)
            self.current_file_idx = file_idx
            self.files_seen.add(file_idx)

        pos_in_file = token_pos - sum(self.file_token_counts[:file_idx])

        if pos_in_file + self.seq_length + 1 > len(self.current_chunk):
            seq = self.current_chunk[pos_in_file:]
            if len(seq) < self.seq_length + 1:
                seq = np.pad(seq, (0, self.seq_length + 1 - len(seq)), constant_values=0)
        else:
            seq = self.current_chunk[pos_in_file:pos_in_file + self.seq_length + 1]

        x = torch.from_numpy(seq[:-1].copy().astype(np.int64))
        y = torch.from_numpy(seq[1:].copy().astype(np.int64))
        return x, y

print("‚úÖ Smart Dataset loaded!")

‚úÖ Smart Dataset loaded!


In [18]:
# ============================================================================
# CELL 5B: MIXED DATALOADER WITH EXPLICIT STATE TRACKING (FIXED)
# ============================================================================
# This version replaces "mathematical estimation" with a physical "bookmark" file.
# It saves the current file index to 'dataset_state.json' every time a new file loads.
# ============================================================================

import threading
import json
import time
from typing import List, Dict, Tuple
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import tiktoken

class DatasetStateManager:
    """
    Manages a simple JSON file that tracks the last used file index for each dataset.
    Location: Config.DRIVE_CHECKPOINT_DIR / dataset_state.json
    """
    def __init__(self, checkpoint_dir: Path):
        self.state_file = checkpoint_dir / "dataset_state.json"
        self._lock = threading.Lock()
        self.state = self._load_state()

    def _load_state(self) -> Dict[str, int]:
        if self.state_file.exists():
            try:
                with open(self.state_file, 'r') as f:
                    return json.load(f)
            except:
                return {}
        return {}

    def update(self, dataset_name: str, file_idx: int):
        """Update the file index for a dataset and save to disk immediately"""
        with self._lock:
            self.state[dataset_name] = file_idx
            try:
                with open(self.state_file, 'w') as f:
                    json.dump(self.state, f)
            except Exception as e:
                print(f"‚ö†Ô∏è Failed to save dataset state: {e}")

    def get_start_file(self, dataset_name: str) -> int:
        """Get the last saved file index, or 0 if new"""
        return self.state.get(dataset_name, 0)

class TokenPool:
    """Manages token files for a dataset with explicit state tracking"""

    def __init__(self, token_files: List[Path], name: str, cache_dir: Path, tokenizer, state_manager: DatasetStateManager):
        self.token_files = sorted(token_files)
        self.name = name
        self.cache_dir = cache_dir / name
        self.cache_dir.mkdir(exist_ok=True, parents=True)
        self.tokenizer = tokenizer
        self.state_manager = state_manager

        # Scan files to build local index (needed for intra-file navigation if strictly required,
        # but here we rely principally on file_idx from state_manager)
        self.file_token_counts = []
        self.total_tokens = 0
        for f in self.token_files:
            try:
                # Fast size check (files are uint16, so bytes/2)
                count = os.path.getsize(f) // 2
                self.file_token_counts.append(count)
                self.total_tokens += count
            except:
                self.file_token_counts.append(0)

        # LOAD STATE: Resume explicitly from the saved file index
        self.current_file_idx = self.state_manager.get_start_file(name)

        # Validation: Ensure index is within bounds
        if self.current_file_idx >= len(self.token_files):
            self.current_file_idx = 0

        self.current_chunk = None
        self._file_lock = threading.Lock()

        # Pre-load the starting file immediately so we are ready
        if self.total_tokens > 0:
            print(f"   ‚Ü™ {name}: Resuming explicitly at File #{self.current_file_idx}")
            self._load_file(self.current_file_idx)

    def get_sequence(self, cursor: int, seq_length: int) -> np.ndarray:
        """
        Get sequence.
        Note: 'cursor' argument is kept for compatibility with the scheduler,
        but we primarily drive data flow by iterating through our current loaded chunk.
        """
        if self.total_tokens == 0:
            return np.full(seq_length + 1, self.tokenizer.eot_token, dtype=np.uint16)

        # If chunk is not loaded or we ran off the end (shouldn't happen with logic below), reload
        if self.current_chunk is None:
            self._load_file(self.current_file_idx)

        chunk_len = len(self.current_chunk)
        max_start = chunk_len - seq_length - 1

        if max_start <= 0:
            # File too small, just pad
            seq = self.current_chunk.copy()
            return np.pad(seq, (0, seq_length + 1 - len(seq)), constant_values=self.tokenizer.eot_token)

        # Use the cursor to pick a spot in the current file, wrapping around
        start_idx = (cursor * 1024) % max_start # Arbitrary stride to utilize the file

        seq = self.current_chunk[start_idx : start_idx + seq_length + 1].copy()

        if not hasattr(self, '_access_count'): self._access_count = 0
        self._access_count += 1

        # Approx 1 file worth of context (assuming 10M tokens per file)
        # 10M tokens / 1024 seq_len ~= 10,000 samples
        if self._access_count > 10000:
            self._access_count = 0
            next_idx = (self.current_file_idx + 1) % len(self.token_files)
            self._load_file(next_idx)

        return seq

    def _load_file(self, file_idx: int):
        """Load file and SAVE STATE to Drive"""
        if not self.token_files: return

        source_file = self.token_files[file_idx]
        cache_file = self.cache_dir / source_file.name

        with self._file_lock:
            # 1. Update the JSON state file immediately
            print(f"üì• [{self.name}] Loading File #{file_idx} ({source_file.name}) -> Updating state.json")
            self.state_manager.update(self.name, file_idx)

            # 2. Load data
            if cache_file.exists():
                try:
                    self.current_chunk = np.load(cache_file, mmap_mode='r')
                    self.current_file_idx = file_idx
                    return
                except:
                    cache_file.unlink()

            # Clean old cache
            for old_file in self.cache_dir.glob('*.npy'):
                try: old_file.unlink()
                except: pass

            try:
                tokens = np.load(source_file)
                np.save(cache_file, tokens)
                self.current_chunk = np.load(cache_file, mmap_mode='r')
                self.current_file_idx = file_idx
            except Exception as e:
                raise RuntimeError(f"Failed to load {source_file}: {e}")

class HybridMixedDatasetLoader(Dataset):
    def __init__(
        self, c4_files, cosmopedia_files, alpaca_files, python_files,
        seq_length, local_cache_dir, tokenizer, current_training_step, dataset_probs, schedule_length=100000
    ):
        self.seq_length = seq_length
        self.tokenizer = tokenizer

        # Initialize State Manager (The "Bookmark" System)
        self.state_manager = DatasetStateManager(Path(Config.DRIVE_CHECKPOINT_DIR))

        # Force Python/Alpaca to always restart at 0 (they are small & loopable)
        # We only want to track C4 and Cosmopedia persistence
        # (Optional: you can remove this if you want to track them too)
        self.state_manager.update("alpaca", 0)
        self.state_manager.update("python", 0)

        # Standard Setup
        cache_path = Path(local_cache_dir)
        cache_path.mkdir(exist_ok=True, parents=True)
        self.datasets = {}

        if c4_files: self.datasets['c4'] = TokenPool(c4_files, 'c4', cache_path, tokenizer, self.state_manager)
        if cosmopedia_files: self.datasets['cosmopedia'] = TokenPool(cosmopedia_files, 'cosmopedia', cache_path, tokenizer, self.state_manager)
        if alpaca_files: self.datasets['alpaca'] = TokenPool(alpaca_files, 'alpaca', cache_path, tokenizer, self.state_manager)
        if python_files: self.datasets['python'] = TokenPool(python_files, 'python', cache_path, tokenizer, self.state_manager)

        # Renormalize Probs
        self.datasets = {k: v for k, v in self.datasets.items() if v.total_tokens > 0}
        available = set(self.datasets.keys())
        dataset_probs = {k: v for k, v in dataset_probs.items() if k in available}
        total = sum(dataset_probs.values())
        self.dataset_probs = {k: v/total for k,v in dataset_probs.items()}

        # Schedule
        self.schedule_length = schedule_length
        self.mixing_schedule = []
        rng = np.random.RandomState(42)
        names = list(self.dataset_probs.keys())
        probs = [self.dataset_probs[n] for n in names]
        for _ in range(schedule_length):
            self.mixing_schedule.append(rng.choice(names, p=probs))

        self.total_sequences = 10_000_000_000
        print(f"‚úÖ State-Aware Hybrid Loader Initialized")

    def __len__(self):
        return self.total_sequences

    def __getitem__(self, idx):
        # We just cycle through the schedule
        dataset_name = self.mixing_schedule[idx % self.schedule_length]

        # Pass the idx as a cursor so we move through the file
        # The TokenPool handles the actual "File Switching" logic
        x_tokens = self.datasets[dataset_name].get_sequence(idx, self.seq_length)

        x = torch.from_numpy(x_tokens[:-1].copy().astype(np.int64))
        y = torch.from_numpy(x_tokens[1:].copy().astype(np.int64))
        return x, y

def setup_mixed_dataloader(current_step: int = 0):
    print("="*70)
    print("SETTING UP MIXED DATALOADER (STATE-AWARE VERSION)")
    print("="*70)

    tokenizer = tiktoken.get_encoding("gpt2")
    base_dir = Path(Config.DRIVE_DATA_DIR)

    # Files
    c4 = sorted(base_dir.glob("tokens_*.npy"))
    cosmo = sorted((base_dir / "cosmopedia").glob("cosmopedia_tokens_*.npy"))
    alpaca = sorted((base_dir / "alpaca").glob("alpaca_tokens_*.npy"))
    python = sorted((base_dir / "python").glob("python_tokens_*.npy"))

    # Probs
    probs = {'cosmopedia': 0.50, #'c4': 0.00,
             'alpaca': 0.30, 'python': 0.20}
    if not python:
        probs = {'cosmopedia': 0.50, #'c4': 0.34,
                 'alpaca': 0.50}

    # Initialize
    # Note: We don't pass 'current_step' anymore because the StateManager
    # reads the actual file index from disk!
    mixed_dataset = HybridMixedDatasetLoader(
        c4_files=c4, cosmopedia_files=cosmo, alpaca_files=alpaca, python_files=python,
        seq_length=Config.MAX_SEQ_LEN, local_cache_dir=Config.LOCAL_CACHE_DIR,
        tokenizer=tokenizer, current_training_step=current_step, dataset_probs=probs
    )

    dataloader = DataLoader(mixed_dataset, batch_size=Config.BATCH_SIZE, num_workers=1, pin_memory=True)
    return dataloader

In [19]:
# ============================================================================
# CELL 6: Model Architecture (GPT-2 style)
# ============================================================================

from torch.utils.checkpoint import checkpoint
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, n_heads, dropout=0.1):
        super().__init__()
        assert d_model % n_heads == 0

        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.qkv_proj = nn.Linear(d_model, 3 * d_model)
        self.out_proj = nn.Linear(d_model, d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        batch_size, seq_len, d_model = x.shape

        # QKV projection
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.n_heads, self.d_k)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, batch, heads, seq, d_k)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Attention scores
        scores = torch.matmul(q, k.transpose(-2, -1)) / (self.d_k ** 0.5)

        # Apply causal mask
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        attn = torch.softmax(scores, dim=-1)
        attn = self.dropout(attn)

        # Apply attention to values
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).contiguous().reshape(batch_size, seq_len, d_model)

        return self.out_proj(out)

class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()

    def forward(self, x):
        return self.linear2(self.dropout(self.activation(self.linear1(x))))

class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attn = MultiHeadAttention(d_model, n_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)
        self.ln1 = nn.LayerNorm(d_model)
        self.ln2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        # Pre-norm architecture (like GPT-2)
        x = x + self.dropout(self.attn(self.ln1(x), mask))
        x = x + self.dropout(self.ff(self.ln2(x)))
        return x

class GPTModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Embeddings
        self.token_embed = nn.Embedding(config.VOCAB_SIZE, config.D_MODEL)
        self.pos_embed = nn.Embedding(config.MAX_SEQ_LEN, config.D_MODEL)
        self.dropout = nn.Dropout(config.DROPOUT)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(config.D_MODEL, config.N_HEADS, config.D_FF, config.DROPOUT)
            for _ in range(config.N_LAYERS)
        ])

        # Output
        self.ln_final = nn.LayerNorm(config.D_MODEL)
        self.head = nn.Linear(config.D_MODEL, config.VOCAB_SIZE, bias=False)

        # Tie weights
        self.token_embed.weight = self.head.weight

        # Initialize weights
        self.apply(self._init_weights)

        print(f"‚úÖ Model initialized with {self.count_parameters()/1e6:.2f}M parameters")

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def forward(self, x):
        batch_size, seq_len = x.shape

        # Embeddings
        positions = torch.arange(0, seq_len, device=x.device).unsqueeze(0)
        x = self.token_embed(x) + self.pos_embed(positions)
        x = self.dropout(x)

        # Causal mask
        mask = torch.tril(torch.ones(seq_len, seq_len, device=x.device)).unsqueeze(0).unsqueeze(0)

        # Transformer blocks
        for block in self.blocks:
            x = block(x, mask)

        # Output
        x = self.ln_final(x)
        logits = self.head(x)

        return logits

print("‚úÖ Model architecture defined!")

‚úÖ Model architecture defined!


In [20]:
# ============================================================================
# CELL 7: Training Utilities
# ============================================================================

class TrainingLogger:
    """Comprehensive logging system"""

    def __init__(self, log_dir):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(exist_ok=True, parents=True)

        # Log file
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        self.log_file = self.log_dir / f"training_{timestamp}.log"

        # Metrics history
        self.metrics = {
            'step': [],
            'loss': [],
            'lr': [],
            'tokens_per_sec': [],
            'gpu_mem_gb': []
        }

    def log(self, message, print_console=True):
        """Log message to file and optionally console"""
        timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
        log_message = f"[{timestamp}] {message}"

        with open(self.log_file, 'a') as f:
            f.write(log_message + '\n')

        if print_console:
            print(log_message)

    def log_metrics(self, step, loss, lr, tokens_per_sec, gpu_mem_gb):
        """Log training metrics"""
        self.metrics['step'].append(step)
        self.metrics['loss'].append(loss)
        self.metrics['lr'].append(lr)
        self.metrics['tokens_per_sec'].append(tokens_per_sec)
        self.metrics['gpu_mem_gb'].append(gpu_mem_gb)

        # Save metrics
        metrics_file = self.log_dir / "metrics.json"
        with open(metrics_file, 'w') as f:
            json.dump(self.metrics, f, indent=2)

    def print_progress(self, step, loss, lr, tokens_per_sec, gpu_mem_gb, elapsed_time):
        """Print formatted progress"""
        message = (
            f"Step {step:6d} | "
            f"Loss: {loss:.4f} | "
            f"LR: {lr:.2e} | "
            f"Tokens/s: {tokens_per_sec:,.0f} | "
            f"GPU: {gpu_mem_gb:.1f}GB | "
            f"Time: {elapsed_time:.0f}s"
        )
        self.log(message)

def get_lr(step, warmup_steps, max_lr, max_steps):
    """Learning rate schedule with warmup and cosine decay"""
    if step < warmup_steps:
        # Linear warmup
        return max_lr * step / warmup_steps
    elif step < max_steps:
        # Cosine decay
        progress = (step - warmup_steps) / (max_steps - warmup_steps)
        return max_lr * 0.5 * (1 + np.cos(np.pi * progress))
    else:
        return max_lr * 0.1

def save_checkpoint(model, optimizer, scaler, step, loss, checkpoint_dir, logger):
    """Save training checkpoint"""
    checkpoint_dir = Path(checkpoint_dir)
    checkpoint_dir.mkdir(exist_ok=True, parents=True)

    checkpoint_path = checkpoint_dir / f"checkpoint_step_{step}.pt"

    checkpoint = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scaler_state_dict': scaler.state_dict() if scaler else None,
        'loss': loss,
        'config': {
            'vocab_size': Config.VOCAB_SIZE,
            'd_model': Config.D_MODEL,
            'n_layers': Config.N_LAYERS,
            'n_heads': Config.N_HEADS,
            'd_ff': Config.D_FF,
            'max_seq_len': Config.MAX_SEQ_LEN,
            'dropout': Config.DROPOUT,
        }
    }

    torch.save(checkpoint, checkpoint_path)
    logger.log(f"üíæ Checkpoint saved: {checkpoint_path}")

    # Keep only last N checkpoints
    checkpoints = sorted(checkpoint_dir.glob("checkpoint_step_*.pt"), key=lambda p: int(p.stem.split('_')[-1]))
    if len(checkpoints) > Config.SAVE_TOTAL_LIMIT:
        for old_checkpoint in checkpoints[:-Config.SAVE_TOTAL_LIMIT]:
            old_checkpoint.unlink()
            logger.log(f"üóëÔ∏è  Removed old checkpoint: {old_checkpoint.name}")

    return checkpoint_path

def load_checkpoint(checkpoint_path, model, optimizer, scaler):
    """Load checkpoint"""
    checkpoint = torch.load(checkpoint_path, map_location='cuda', weights_only=False)

    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    if scaler and checkpoint['scaler_state_dict']:
        scaler.load_state_dict(checkpoint['scaler_state_dict'])

    return checkpoint['step'], checkpoint['loss']

def generate_sample(model, tokenizer, prompt="The future of AI is", max_length=100, temperature=0.8):
    """Generate text sample"""
    model.eval()

    tokens = tokenizer.encode(prompt)
    tokens = torch.tensor(tokens, dtype=torch.long, device='cuda').unsqueeze(0)

    with torch.no_grad():
        for _ in range(max_length):
            if tokens.size(1) >= Config.MAX_SEQ_LEN:
                tokens = tokens[:, -Config.MAX_SEQ_LEN:]

            logits = model(tokens)
            logits = logits[:, -1, :] / temperature

            probs = torch.softmax(logits, dim=-1)
            next_token = torch.multinomial(probs, num_samples=1)

            tokens = torch.cat([tokens, next_token], dim=1)

            # Stop at end of text token
            if next_token.item() == tokenizer.eot_token:
                break

    generated = tokenizer.decode(tokens[0].cpu().numpy())
    model.train()

    return generated

print("‚úÖ Training utilities defined!")

‚úÖ Training utilities defined!


In [21]:
# ============================================================================
# Cell 8 - Complete Training Function (Handles torch.compile checkpoints)
# ============================================================================

def train_complete():
    '''Complete training with smart resume, disk cleaning, error handling'''

    logger = TrainingLogger(Config.DRIVE_CHECKPOINT_DIR)
    logger.log("="*60)
    logger.log("STARTING TRAINING")
    logger.log("="*60)

    # Find token files
    token_files = list(Path(Config.DRIVE_DATA_DIR).glob("tokens_*.npy"))
    if not token_files:
        logger.log("‚ùå No token files!")
        return

    # Check for checkpoint
    checkpoints = sorted(Path(Config.DRIVE_CHECKPOINT_DIR).glob("checkpoint_step_*.pt"))
    start_step = 0

    # Create model
    logger.log("üèóÔ∏è  Building model...")
    model = GPTModel(Config).cuda()

    optimizer = bnb.optim.AdamW8bit(
        model.parameters(),
        lr=Config.LEARNING_RATE,
        weight_decay=Config.WEIGHT_DECAY,
        betas=(0.9, 0.95)
    )

    scaler = GradScaler() if Config.USE_AMP else None

    # Load checkpoint if exists
    if checkpoints:
        latest = checkpoints[-1]
        logger.log(f"üì• Loading: {latest.name}")
        checkpoint = torch.load(latest, map_location='cuda', weights_only=False)

        # ============================================================
        # FIX: Handle torch.compile() checkpoints
        # ============================================================
        state_dict = checkpoint['model_state_dict']

        # Check if checkpoint was saved from compiled model
        if any(key.startswith('_orig_mod.') for key in state_dict.keys()):
            logger.log("‚ö†Ô∏è  Checkpoint from compiled model, removing _orig_mod. prefix...")
            # Remove _orig_mod. prefix from all keys
            new_state_dict = {}
            for key, value in state_dict.items():
                new_key = key.replace('_orig_mod.', '')
                new_state_dict[new_key] = value
            state_dict = new_state_dict
        # ============================================================

        model.load_state_dict(state_dict)
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        if scaler and checkpoint.get('scaler_state_dict'):
            scaler.load_state_dict(checkpoint['scaler_state_dict'])
        start_step = checkpoint['step']
        logger.log(f"‚úÖ Resumed from step {start_step}")

    # Create smart dataset
    dataset = SmartResumeTokenDataset(
        token_files, Config.MAX_SEQ_LEN,
        Config.LOCAL_CACHE_DIR, start_step=start_step
    )

    dataloader = mixed_dataloader

    tokenizer = tiktoken.get_encoding("gpt2")

    # Training loop
    logger.log("üöÄ Starting training...")
    model.train()

    step = start_step
    total_tokens = 0
    start_time = time.time()
    step_time = time.time()
    accum_counter = 0

    try:
        for batch_idx, (x, y) in enumerate(dataloader):
            x, y = x.cuda(), y.cuda()

            with autocast(enabled=Config.USE_AMP):
                logits = model(x)
                loss = nn.functional.cross_entropy(
                    logits.view(-1, Config.VOCAB_SIZE), y.view(-1)
                )
                loss = loss / Config.GRADIENT_ACCUM_STEPS

            if Config.USE_AMP: scaler.scale(loss).backward()
            else: loss.backward()

            accum_counter += 1

            if accum_counter % Config.GRADIENT_ACCUM_STEPS == 0:
                if Config.USE_AMP:
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), Config.MAX_GRAD_NORM)

                if Config.USE_AMP:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                optimizer.zero_grad()

                lr = get_lr(step, Config.WARMUP_STEPS, Config.LEARNING_RATE, Config.MAX_STEPS)
                for pg in optimizer.param_groups:
                    pg['lr'] = lr

                step += 1
                total_tokens += Config.BATCH_SIZE * Config.MAX_SEQ_LEN * Config.GRADIENT_ACCUM_STEPS

                elapsed = time.time() - step_time
                tokens_per_sec = (Config.BATCH_SIZE * Config.MAX_SEQ_LEN * Config.GRADIENT_ACCUM_STEPS) / elapsed
                gpu_mem = torch.cuda.max_memory_allocated() / 1e9

                if step % Config.LOG_EVERY == 0:
                    logger.print_progress(step, loss.item()*Config.GRADIENT_ACCUM_STEPS,
                                        lr, tokens_per_sec, gpu_mem, time.time()-start_time)
                    # logger.log(f"   üìä File #{dataset.current_file_idx} | Seen: {len(dataset.files_seen)}/{len(token_files)}")

                if step % Config.SAMPLE_EVERY == 0:
                    logger.log("\n" + "="*60)
                    logger.log("üìù SAMPLES")
                    for prompt in ["The future of AI is", "Once upon a time,", "Physics is", "A list is different from a tuple because", "A for", "Define gravity in one sentence."
]:
                        sample = generate_sample(model, tokenizer, prompt, 50)
                        logger.log(f"{prompt} ‚Üí {sample}")
                    logger.log("="*60 + "\n")

                if step % Config.CHECKPOINT_EVERY == 0:
                    save_checkpoint(model, optimizer, scaler, step,
                                  loss.item()*Config.GRADIENT_ACCUM_STEPS,
                                  Config.DRIVE_CHECKPOINT_DIR, logger)

                step_time = time.time()

                if step >= Config.MAX_STEPS:
                    break

    except KeyboardInterrupt:
        logger.log("\n‚ö†Ô∏è  Interrupted! Saving...")
        save_checkpoint(model, optimizer, scaler, step,
                       loss.item()*Config.GRADIENT_ACCUM_STEPS,
                       Config.DRIVE_CHECKPOINT_DIR, logger)
        return model

    logger.log("\n‚úÖ TRAINING COMPLETE!")
    save_checkpoint(model, optimizer, scaler, step,
                   loss.item()*Config.GRADIENT_ACCUM_STEPS,
                   Config.DRIVE_CHECKPOINT_DIR, logger)
    return model

print("‚úÖ Fixed training function loaded (handles torch.compile checkpoints)!")

‚úÖ Fixed training function loaded (handles torch.compile checkpoints)!


In [24]:
# ============================================================================
# CELL 9: START TRAINING
# ============================================================================

"""
üöÄ START TRAINING

This cell will:
1. Load tokenized data from Google Drive
2. Train the model with automatic checkpointing
3. Generate samples periodically
4. Save progress to Drive

‚ö†Ô∏è  This will take DAYS to complete!
‚ö†Ô∏è  Keep Colab Pro active or it will disconnect

Press Ctrl+C to stop training (checkpoint will be saved)
"""

try:
    mixed_dataloader = setup_mixed_dataloader(current_step=0)
    trained_model = train_complete()
    print("\nüéâ Training completed successfully!")
except KeyboardInterrupt:
    print("\n‚ö†Ô∏è  Training interrupted by user")
    print("üíæ Latest checkpoint saved to Drive")
except Exception as e:
    print(f"\n‚ùå Training failed with error: {e}")
    import traceback
    traceback.print_exc()

SETTING UP MIXED DATALOADER (STATE-AWARE VERSION)
   ‚Ü™ c4: Resuming explicitly at File #0
üì• [c4] Loading File #0 (tokens_0000.npy) -> Updating state.json
   ‚Ü™ cosmopedia: Resuming explicitly at File #6
üì• [cosmopedia] Loading File #6 (cosmopedia_tokens_0006.npy) -> Updating state.json
   ‚Ü™ alpaca: Resuming explicitly at File #0
üì• [alpaca] Loading File #0 (alpaca_tokens_0000.npy) -> Updating state.json
   ‚Ü™ python: Resuming explicitly at File #0
üì• [python] Loading File #0 (python_tokens_0000.npy) -> Updating state.json
‚úÖ State-Aware Hybrid Loader Initialized
[2026-01-04 08:23:46] STARTING TRAINING
[2026-01-04 08:23:46] üèóÔ∏è  Building model...

‚ö†Ô∏è  Training interrupted by user
üíæ Latest checkpoint saved to Drive


In [23]:
import json
import os
import shutil

# --- CONFIGURATION ---
# Check your folder to confirm the exact name.
# It is usually 'loader_state.json' or 'model/loader_state.json'
STATE_FILE = f"{Config.DRIVE_CHECKPOINT_DIR}/dataset_state.json"

def remove_c4_from_state():
    if not os.path.exists(STATE_FILE):
        # Try looking in the root if not found in model/
        if os.path.exists("loader_state.json"):
            state_path = "loader_state.json"
        else:
            print(f"‚ùå Could not find state file at {STATE_FILE} or ./loader_state.json")
            return
    else:
        state_path = STATE_FILE

    print(f"üìÇ Found state file: {state_path}")

    # 1. Create a Backup (Safety First)
    backup_path = state_path + ".bak"
    shutil.copy(state_path, backup_path)
    print(f"‚úÖ Backup created at: {backup_path}")

    # 2. Load JSON
    with open(state_path, 'r') as f:
        data = json.load(f)

    # 3. Check and Remove C4
    # The structure usually looks like { "c4": {...}, "cosmopedia": {...} }
    # Or sometimes { "dataset_states": { "c4": ... } }

    modified = False

    # Direct Key Check
    if "c4" in data:
        print(f"üóëÔ∏è  Found 'c4' entry. Removing...")
        del data["c4"]
        modified = True

    # Nested Key Check (just in case your script nests it)
    elif "dataset_states" in data and "c4" in data["dataset_states"]:
        print(f"üóëÔ∏è  Found nested 'c4' entry. Removing...")
        del data["dataset_states"]["c4"]
        modified = True

    if modified:
        # 4. Save Changes
        with open(state_path, 'w') as f:
            json.dump(data, f, indent=4)
        print(f"üíæ Saved updated state file (C4 removed).")
        print("üöÄ You can now restart your training script.")
    else:
        print("‚ö†Ô∏è  'c4' key was not found in the file. It might already be clean.")
        print(f"Current Keys: {list(data.keys())}")

if __name__ == "__main__":
    remove_c4_from_state()

üìÇ Found state file: /content/drive/MyDrive/llm_training/checkpoints/dataset_state.json
‚úÖ Backup created at: /content/drive/MyDrive/llm_training/checkpoints/dataset_state.json.bak
üóëÔ∏è  Found 'c4' entry. Removing...
üíæ Saved updated state file (C4 removed).
üöÄ You can now restart your training script.


In [None]:
# ============================================================================
# CELL 10: Interactive Text Generation
# ============================================================================

"""
üí¨ INTERACTIVE TEXT GENERATION

Chat with your model!
"""

def interactive_generation(checkpoint_path):
    """Interactive text generation loop"""

    print("üì• Loading model...")
    model = GPTModel(Config).cuda()
    checkpoint = torch.load(checkpoint_path, map_location='cuda')
    model.load_state_dict(checkpoint['model_state_dict'])
    tokenizer = tiktoken.get_encoding("gpt2")

    print(f"‚úÖ Model loaded from step {checkpoint['step']}")
    print("\n" + "="*60)
    print("INTERACTIVE GENERATION")
    print("="*60)
    print("Enter your prompts (type 'quit' to exit)")
    print("="*60 + "\n")

    while True:
        prompt = input("You: ")

        if prompt.lower() in ['quit', 'exit', 'q']:
            print("üëã Goodbye!")
            break

        if not prompt.strip():
            continue

        print("\nü§ñ Model: ", end="")

        # Generate with streaming
        tokens = tokenizer.encode(prompt)
        tokens = torch.tensor(tokens, dtype=torch.long, device='cuda').unsqueeze(0)

        model.eval()
        with torch.no_grad():
            for i in range(150):
                if tokens.size(1) >= Config.MAX_SEQ_LEN:
                    tokens = tokens[:, -Config.MAX_SEQ_LEN:]

                logits = model(tokens)
                logits = logits[:, -1, :] / 0.8

                probs = torch.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

                tokens = torch.cat([tokens, next_token], dim=1)

                # Decode and print
                decoded = tokenizer.decode([next_token.item()])
                print(decoded, end="", flush=True)

                if next_token.item() == tokenizer.eot_token:
                    break

        model.train()
        print("\n" + "-"*60 + "\n")

# Run interactive mode
checkpoints = sorted(Path(Config.DRIVE_CHECKPOINT_DIR).glob("checkpoint_step_*.pt"))
if checkpoints:
    latest = checkpoints[-1]
    run = input(f"Start interactive generation with {latest.name}? (y/n): ")
    if run.lower() == 'y':
        interactive_generation(latest)
else:
    print("‚ùå No checkpoints found!")

In [None]:
# ============================================================================
# CELL 11: Training Statistics and Visualization
# ============================================================================

"""
üìä VISUALIZE TRAINING PROGRESS
"""

import matplotlib.pyplot as plt
import json

def plot_training_metrics():
    """Plot training metrics from logs"""

    metrics_file = Path(Config.DRIVE_CHECKPOINT_DIR) / "metrics.json"

    if not metrics_file.exists():
        print("‚ùå No metrics file found!")
        return

    with open(metrics_file, 'r') as f:
        metrics = json.load(f)

    if not metrics['step']:
        print("‚ùå No metrics recorded yet!")
        return

    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Loss curve
    axes[0, 0].plot(metrics['step'], metrics['loss'])
    axes[0, 0].set_xlabel('Step')
    axes[0, 0].set_ylabel('Loss')
    axes[0, 0].set_title('Training Loss')
    axes[0, 0].grid(True)

    # Learning rate
    axes[0, 1].plot(metrics['step'], metrics['lr'])
    axes[0, 1].set_xlabel('Step')
    axes[0, 1].set_ylabel('Learning Rate')
    axes[0, 1].set_title('Learning Rate Schedule')
    axes[0, 1].grid(True)

    # Tokens per second
    axes[1, 0].plot(metrics['step'], metrics['tokens_per_sec'])
    axes[1, 0].set_xlabel('Step')
    axes[1, 0].set_ylabel('Tokens/sec')
    axes[1, 0].set_title('Training Throughput')
    axes[1, 0].grid(True)

    # GPU memory
    axes[1, 1].plot(metrics['step'], metrics['gpu_mem_gb'])
    axes[1, 1].set_xlabel('Step')
    axes[1, 1].set_ylabel('GPU Memory (GB)')
    axes[1, 1].set_title('GPU Memory Usage')
    axes[1, 1].grid(True)

    plt.tight_layout()
    plt.savefig(Path(Config.DRIVE_CHECKPOINT_DIR) / 'training_curves.png', dpi=150)
    plt.show()

    print("‚úÖ Training curves plotted!")
    print(f"üìä Total steps: {metrics['step'][-1]}")
    print(f"üìâ Final loss: {metrics['loss'][-1]:.4f}")
    print(f"‚ö° Avg throughput: {sum(metrics['tokens_per_sec'])/len(metrics['tokens_per_sec']):,.0f} tokens/sec")

# Plot metrics
plot_training_metrics()

print("\n‚úÖ ALL CELLS COMPLETE!")
print("="*60)
print("üìö USAGE GUIDE")
print("="*60)
print("1. Run Cell 4 to download and tokenize C4 (one-time, 2-6 hours)")
print("2. Run Cell 9 to start training (will take days)")
print("3. Run Cell 10 to generate text from checkpoints")
print("4. Run Cell 11 for interactive generation")
print("5. Run Cell 12 to visualize training progress")
print("="*60)

In [None]:
# ============================================================================
# DOWNLOAD CELL 1: Data Preprocessing - Tokenize C4 and Save (WITH CHECKPOINTING)
# ============================================================================

class C4Preprocessor:
    """Download C4, tokenize, and save to Drive with TRUE resume support"""

    def __init__(self, output_dir, target_size_gb=100):
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(exist_ok=True, parents=True)
        self.tokenizer = tiktoken.get_encoding("gpt2")
        self.target_size_gb = target_size_gb

        # Local temp directory for fast saving
        self.local_temp_dir = Path("/content/c4_temp_tokens")
        self.local_temp_dir.mkdir(exist_ok=True)

        # Progress tracking file
        self.progress_file = self.output_dir / "preprocessing_progress.json"

        # Partial chunk file for true resume
        self.partial_chunk_file = self.local_temp_dir / "partial_chunk.npy"

    def save_progress(self, current_bytes, file_count, current_tokens):
        """Save preprocessing progress including partial tokens"""
        progress = {
            'current_bytes': current_bytes,
            'file_count': file_count,
            'timestamp': datetime.now().isoformat()
        }

        # Save progress metadata
        with open(self.progress_file, 'w') as f:
            json.dump(progress, f)

        # Save partial tokens if any
        if current_tokens:
            token_array = np.array(current_tokens, dtype=np.uint16)
            np.save(self.partial_chunk_file, token_array)

    def load_progress(self):
        """Load preprocessing progress"""
        if self.progress_file.exists():
            with open(self.progress_file, 'r') as f:
                progress = json.load(f)

            # Load partial tokens if they exist
            if self.partial_chunk_file.exists():
                partial_tokens = np.load(self.partial_chunk_file).tolist()
                progress['partial_tokens'] = partial_tokens
            else:
                progress['partial_tokens'] = []

            return progress
        return None

    def preprocess_and_save(self):
        """Download C4, tokenize, and save as binary files with TRUE resume support"""

        print(f"üöÄ Starting C4 preprocessing (target: {self.target_size_gb}GB raw text)")

        # Check for existing progress
        existing_progress = self.load_progress()
        existing_files = sorted(self.output_dir.glob("tokens_*.npy"))

        if existing_progress and existing_files:
            current_bytes = existing_progress['current_bytes']
            file_count = existing_progress['file_count']
            current_tokens = existing_progress.get('partial_tokens', [])

            print(f"‚úÖ Found {len(existing_files)} existing token files")
            print(f"üìä Previous progress: {current_bytes/(1024**3):.2f} GB / {self.target_size_gb} GB")
            print(f"üì¶ {file_count} complete files, {len(current_tokens):,} tokens in partial chunk")

            if current_bytes >= self.target_size_gb * (1024**3):
                print("üéâ Target already reached!")
                return existing_files

            response = input("\nContinue downloading? (y/n): ")
            if response.lower() != 'y':
                print("‚è≠Ô∏è  Skipping preprocessing")
                return existing_files

            print(f"‚ñ∂Ô∏è  Continuing from {current_bytes/(1024**3):.2f} GB")
        else:
            current_bytes = 0
            file_count = 0
            current_tokens = []
            print("üÜï Starting new preprocessing...")

        # Load C4 streaming - NO SKIPPING, just continue from where we left off
        print("üì• Loading C4 dataset (streaming)...")
        dataset = load_dataset(
            "allenai/c4",
            "en",
            split="train",
            streaming=True,
            trust_remote_code=True
        )

        target_bytes = self.target_size_gb * (1024**3)
        tokens_per_file = 10_000_000  # ~10M tokens per file (~20MB)

        # For periodic saves
        save_progress_every_mb = 100  # Save every 100MB
        bytes_since_last_save = 0

        with tqdm(total=self.target_size_gb, unit='GB', desc="Processing C4",
                  initial=current_bytes/(1024**3)) as pbar:
            try:
                for item in dataset:
                    # Skip if we've reached target
                    if current_bytes >= target_bytes:
                        print("\nüéØ Target data size reached!")
                        break

                    text = item['text']
                    text_bytes = len(text.encode('utf-8'))

                    # Check if adding this document would exceed target
                    if current_bytes + text_bytes > target_bytes:
                        print(f"\n‚ö†Ô∏è  Next document ({text_bytes/1024:.1f} KB) would exceed target")
                        print("Stopping here to stay within limit")
                        break

                    # Tokenize
                    tokens = self.tokenizer.encode(text)
                    current_tokens.extend(tokens)
                    current_tokens.append(self.tokenizer.eot_token)

                    # Update progress
                    current_bytes += text_bytes
                    bytes_since_last_save += text_bytes

                    pbar.update(text_bytes / (1024**3))
                    pbar.set_postfix({
                        'files': file_count,
                        'tokens': f'{len(current_tokens):,}',
                        'GB': f'{current_bytes/(1024**3):.2f}'
                    })

                    # Save progress periodically
                    if bytes_since_last_save >= save_progress_every_mb * (1024**2):
                        self.save_progress(current_bytes, file_count, current_tokens)
                        bytes_since_last_save = 0

                    # Save file when chunk is full
                    if len(current_tokens) >= tokens_per_file:
                        filename = f"tokens_{file_count:04d}.npy"

                        # Save to local SSD first (FAST)
                        token_array = np.array(current_tokens, dtype=np.uint16)
                        local_file_path = self.local_temp_dir / filename
                        np.save(local_file_path, token_array)

                        # Copy to Drive with retry
                        max_retries = 3
                        for attempt in range(max_retries):
                            try:
                                shutil.copy(local_file_path, self.output_dir / filename)
                                break
                            except Exception as e:
                                if attempt < max_retries - 1:
                                    print(f"\n‚ö†Ô∏è  Copy failed (attempt {attempt+1}/{max_retries}): {e}")
                                    time.sleep(5)
                                else:
                                    print(f"\n‚ùå Failed to copy after {max_retries} attempts")
                                    raise

                        # Clean up
                        local_file_path.unlink()

                        file_count += 1
                        current_tokens = []

                        # Save progress after each file
                        self.save_progress(current_bytes, file_count, current_tokens)
                        bytes_since_last_save = 0

            except KeyboardInterrupt:
                print("\n\n‚ö†Ô∏è  Interrupted! Saving progress...")
                self.save_progress(current_bytes, file_count, current_tokens)
                print(f"‚úÖ Progress saved!")
                print(f"üìä Processed: {current_bytes/(1024**3):.2f} GB")
                print(f"üì¶ Complete files: {file_count}")
                print(f"üî§ Partial tokens: {len(current_tokens):,}")
                print("\nüí° Run this cell again to continue from here!")
                return list(self.output_dir.glob("tokens_*.npy"))

            except Exception as e:
                print(f"\n‚ùå Error occurred: {e}")
                print("üíæ Saving progress before exit...")
                self.save_progress(current_bytes, file_count, current_tokens)
                raise

        # Save any remaining tokens
        if current_tokens:
            filename = f"tokens_{file_count:04d}.npy"
            print(f"\nüíæ Saving final chunk ({len(current_tokens):,} tokens)...")
            token_array = np.array(current_tokens, dtype=np.uint16)
            local_file_path = self.local_temp_dir / filename
            np.save(local_file_path, token_array)
            shutil.copy(local_file_path, self.output_dir / filename)
            local_file_path.unlink()
            file_count += 1

        # Clean up partial chunk file
        if self.partial_chunk_file.exists():
            self.partial_chunk_file.unlink()

        # Final save
        self.save_progress(current_bytes, file_count, [])

        print(f"\n‚úÖ Preprocessing complete!")
        print(f"üìä Total processed: {current_bytes/(1024**3):.2f} GB")
        print(f"üì¶ Created {file_count} token files")
        print(f"üíæ Location: {self.output_dir}")

        # Calculate total tokens
        all_files = list(self.output_dir.glob("tokens_*.npy"))
        total_tokens = sum(len(np.load(f)) for f in all_files)
        print(f"üî¢ Total tokens: {total_tokens:,} ({total_tokens/1e9:.2f}B)")
        print(f"üéØ Optimal model size: {total_tokens / 20 / 1e6:.0f}M parameters")

        return all_files

# Run preprocessing
print("‚ö†Ô∏è  This cell will download and tokenize C4 data")
print("‚ö†Ô∏è  It will take 2-6 hours and use Google Drive space")
run_preprocessing = input("\nRun preprocessing now? (y/n): ")

if run_preprocessing.lower() == 'y':
    target_gb = int(input("How many GB of C4? (recommended: 50-100): ") or "50")
    preprocessor = C4Preprocessor(Config.DRIVE_DATA_DIR, target_size_gb=target_gb)
    token_files = preprocessor.preprocess_and_save()
    print(f"\n‚úÖ Ready for training with {len(token_files)} token files!")
else:
    print("‚è≠Ô∏è  Skipping preprocessing (assuming data already exists)")

In [None]:
# ============================================================================
# DOWNLOAD CELL 2: ONE-TIME DATASET PREPARATION
# ============================================================================
# ‚ö†Ô∏è  RUN THIS ONLY ONCE - Takes 3-5 hours
# Downloads Cosmopedia, Alpaca, and optionally Python to Drive
# After running, datasets are saved permanently - never run again!
# ============================================================================

import os
import json
import numpy as np
import tiktoken
from pathlib import Path
from datasets import load_dataset
from tqdm.auto import tqdm

def download_cosmopedia(output_dir: Path, tokenizer, target_gb: float = 20.0):
    """
    Download REAL Cosmopedia from HuggingFace.
    Educational/textbook content for explainer-style answers.
    """
    output_dir.mkdir(exist_ok=True, parents=True)

    print(f"üì• Downloading Cosmopedia from HuggingFace...")
    print(f"   Target: {target_gb} GB of educational content")
    print(f"   This will take 1-3 hours...")

    try:
        dataset = load_dataset(
            "HuggingFaceTB/cosmopedia",
            "web_samples_v1",
            split="train",
            streaming=True,
            trust_remote_code=True
        )
        print("‚úÖ Loaded web_samples_v1 subset")
    except Exception as e:
        print(f"‚ö†Ô∏è  web_samples_v1 failed: {e}")
        print("   Trying 'stories' subset...")
        dataset = load_dataset(
            "HuggingFaceTB/cosmopedia",
            "stories",
            split="train",
            streaming=True,
            trust_remote_code=True
        )
        print("‚úÖ Loaded stories subset")

    all_tokens = []
    target_bytes = int(target_gb * (1024**3))
    current_bytes = 0
    tokens_per_file = 5_000_000
    file_count = 0

    print("üîÑ Tokenizing Cosmopedia...")

    with tqdm(total=target_gb, unit='GB', desc="Progress") as pbar:
        try:
            for item in dataset:
                text = item.get('text', '').strip()

                if not text:
                    continue

                text_bytes = len(text.encode('utf-8'))

                if current_bytes + text_bytes > target_bytes:
                    print(f"\nüéØ Reached target {target_gb} GB!")
                    break

                tokens = tokenizer.encode(text)
                all_tokens.extend(tokens)
                all_tokens.append(tokenizer.eot_token)

                current_bytes += text_bytes
                pbar.update(text_bytes / (1024**3))

                # Save chunk
                if len(all_tokens) >= tokens_per_file:
                    filename = output_dir / f"cosmopedia_tokens_{file_count:04d}.npy"
                    np.save(filename, np.array(all_tokens[:tokens_per_file], dtype=np.uint16))
                    all_tokens = all_tokens[tokens_per_file:]
                    file_count += 1
                    pbar.set_postfix({'files': file_count})

        except KeyboardInterrupt:
            print("\n‚ö†Ô∏è  Interrupted! Saving progress...")
        except Exception as e:
            print(f"\n‚ö†Ô∏è  Error: {e}")
            print("Saving what we have...")

    # Save remaining
    if all_tokens:
        filename = output_dir / f"cosmopedia_tokens_{file_count:04d}.npy"
        np.save(filename, np.array(all_tokens, dtype=np.uint16))
        file_count += 1

    print(f"\n‚úÖ Cosmopedia Complete!")
    print(f"   üìä {current_bytes/(1024**3):.2f} GB raw text")
    print(f"   üì¶ {file_count} token files")
    print(f"   üíæ Saved to: {output_dir}")

    return list(output_dir.glob("cosmopedia_tokens_*.npy"))


def download_alpaca(output_dir: Path, tokenizer):
    """
    Download and tokenize Alpaca instruction dataset.
    Format: <|user|> instruction <|assistant|> response
    """
    output_dir.mkdir(exist_ok=True, parents=True)

    print(f"üì• Downloading Alpaca...")

    # Download if not exists
    if not os.path.exists("alpaca_data.json"):
        os.system("wget -q https://raw.githubusercontent.com/tatsu-lab/stanford_alpaca/main/alpaca_data.json")

    with open("alpaca_data.json", 'r') as f:
        data = json.load(f)

    print(f"üîÑ Tokenizing {len(data)} Alpaca samples...")

    all_tokens = []
    skipped = 0

    for item in tqdm(data, desc="Processing"):
        instruction = item.get('instruction', '').strip()
        input_text = item.get('input', '').strip()
        output_text = item.get('output', '').strip()

        if not instruction or not output_text:
            skipped += 1
            continue

        # Strong formatting with anchors
        if input_text:
            text = f"<|user|> {instruction}\n{input_text}\n<|assistant|> {output_text}\n"
        else:
            text = f"<|user|> {instruction}\n<|assistant|> {output_text}\n"

        tokens = tokenizer.encode(text)
        all_tokens.extend(tokens)
        all_tokens.append(tokenizer.eot_token)

    # Save in chunks
    tokens_per_file = 5_000_000
    file_count = 0

    for i in range(0, len(all_tokens), tokens_per_file):
        chunk = all_tokens[i:i + tokens_per_file]
        filename = output_dir / f"alpaca_tokens_{file_count:04d}.npy"
        np.save(filename, np.array(chunk, dtype=np.uint16))
        file_count += 1

    if skipped > 0:
        print(f"‚ö†Ô∏è  Skipped {skipped} empty samples")

    print(f"‚úÖ Alpaca Complete!")
    print(f"   üì¶ {file_count} token files")
    print(f"   üìù ~{len(data) - skipped:,} instruction pairs")
    print(f"   üíæ Saved to: {output_dir}")

    return list(output_dir.glob("alpaca_tokens_*.npy"))


def download_python(output_dir: Path, tokenizer, target_gb: float = 5.0):
    """
    Download Python code dataset from HuggingFace.
    Uses 'bigcode/the-stack-dedup' dataset filtered for Python.
    """
    output_dir.mkdir(exist_ok=True, parents=True)

    print(f"üì• Downloading Python code dataset...")
    print(f"   Target: {target_gb} GB of Python code")
    print(f"   This will take 30-60 minutes...")

    try:
        # Option 1: Try bigcode/the-stack-dedup (most reliable)
        print("   Trying bigcode/the-stack-dedup...")
        dataset = load_dataset(
            "bigcode/the-stack-dedup",
            data_dir="data/python",
            split="train",
            streaming=True,
            trust_remote_code=True
        )
        print("‚úÖ Loaded The Stack (Python)")
    except Exception as e1:
        print(f"   Failed: {e1}")
        try:
            # Option 2: Try codeparrot/github-code-clean
            print("   Trying codeparrot/github-code-clean...")
            dataset = load_dataset(
                "codeparrot/github-code-clean",
                languages=["Python"],
                split="train",
                streaming=True,
                trust_remote_code=True
            )
            print("‚úÖ Loaded GitHub Code Clean (Python)")
        except Exception as e2:
            print(f"   Failed: {e2}")
            try:
                # Option 3: Try smaller dataset
                print("   Trying sahil2801/CodeAlpaca-20k...")
                dataset = load_dataset(
                    "sahil2801/CodeAlpaca-20k",
                    split="train",
                    streaming=True,
                    trust_remote_code=True
                )
                print("‚úÖ Loaded CodeAlpaca (smaller, but reliable)")
            except Exception as e3:
                print(f"‚ö†Ô∏è  All Python datasets failed:")
                print(f"      Option 1: {e1}")
                print(f"      Option 2: {e2}")
                print(f"      Option 3: {e3}")
                print("   Skipping Python dataset...")
                print("   üí° You can train without it - just set python: 0.0 in probs")
                return []

    all_tokens = []
    target_bytes = int(target_gb * (1024**3))
    current_bytes = 0
    tokens_per_file = 10_000_000
    file_count = 0

    print("üîÑ Tokenizing Python code...")

    with tqdm(total=target_gb, unit='GB', desc="Progress") as pbar:
        try:
            for item in dataset:
                # Try different field names (datasets vary)
                code = None
                for field in ['content', 'code', 'text', 'output']:
                    if field in item and item[field]:
                        code = item[field].strip()
                        break

                if not code or len(code) < 100:  # Skip tiny snippets
                    continue

                code_bytes = len(code.encode('utf-8'))

                if current_bytes + code_bytes > target_bytes:
                    print(f"\nüéØ Reached target {target_gb} GB!")
                    break

                tokens = tokenizer.encode(code)
                all_tokens.extend(tokens)
                all_tokens.append(tokenizer.eot_token)

                current_bytes += code_bytes
                pbar.update(code_bytes / (1024**3))

                # Save chunk
                if len(all_tokens) >= tokens_per_file:
                    filename = output_dir / f"python_tokens_{file_count:04d}.npy"
                    np.save(filename, np.array(all_tokens[:tokens_per_file], dtype=np.uint16))
                    all_tokens = all_tokens[tokens_per_file:]
                    file_count += 1
                    pbar.set_postfix({'files': file_count})

        except KeyboardInterrupt:
            print("\n‚ö†Ô∏è  Interrupted! Saving progress...")
        except Exception as e:
            print(f"\n‚ö†Ô∏è  Error during processing: {e}")
            print("Saving what we have...")

    # Save remaining
    if all_tokens:
        filename = output_dir / f"python_tokens_{file_count:04d}.npy"
        np.save(filename, np.array(all_tokens, dtype=np.uint16))
        file_count += 1

    if file_count == 0:
        print(f"\n‚ö†Ô∏è  No Python data collected!")
        print("   Training will continue without Python dataset")
        return []

    print(f"\n‚úÖ Python Complete!")
    print(f"   üìä {current_bytes/(1024**3):.2f} GB code")
    print(f"   üì¶ {file_count} token files")
    print(f"   üíæ Saved to: {output_dir}")

    return list(output_dir.glob("python_tokens_*.npy"))


# ============================================================================
# MAIN EXECUTION
# ============================================================================

print("="*70)
print("ONE-TIME DATASET PREPARATION")
print("="*70)
print("This cell downloads and tokenizes:")
print("  1. Cosmopedia (educational content) - ~20GB, 1-3 hours")
print("  2. Alpaca (instruction following) - ~100MB, 5 minutes")
print("  3. Python (code, optional) - ~5GB, 30-60 minutes")
print("")
print("‚ö†Ô∏è  WARNING: This will use ~50GB of Google Drive space")
print("‚ö†Ô∏è  WARNING: Only run this ONCE - data persists in Drive")
print("="*70)

# Check if already exists
base_dir = Path(Config.DRIVE_DATA_DIR)
cosmopedia_exists = len(list((base_dir / "cosmopedia").glob("*.npy"))) > 0
alpaca_exists = len(list((base_dir / "alpaca").glob("*.npy"))) > 0
python_exists = len(list((base_dir / "python").glob("*.npy"))) > 0

print(f"\nüìä Current Status:")
print(f"   Cosmopedia: {'‚úÖ Already exists' if cosmopedia_exists else '‚ùå Not found'}")
print(f"   Alpaca: {'‚úÖ Already exists' if alpaca_exists else '‚ùå Not found'}")
print(f"   Python: {'‚úÖ Already exists' if python_exists else '‚ùå Not found (optional)'}")

if cosmopedia_exists and alpaca_exists:
    print("\n‚úÖ All required datasets already exist!")
    print("   You can skip this cell and go directly to Cell 5B")
    should_run = input("\nRe-download anyway? (yes/no): ").strip().lower()
    if should_run != 'yes':
        print("Skipping download.")
        import sys
        sys.exit(0)

print("\n" + "="*70)
run_confirmation = input("Start downloading? Type 'yes' to continue: ").strip().lower()

if run_confirmation != 'yes':
    print("‚ùå Download cancelled")
else:
    print("\nüöÄ Starting download and tokenization...\n")

    # Initialize tokenizer
    tokenizer = tiktoken.get_encoding("gpt2")

    # 1. Cosmopedia
    if not cosmopedia_exists:
        print("\n" + "="*70)
        print("DOWNLOADING COSMOPEDIA")
        print("="*70)
        cosmopedia_files = download_cosmopedia(
            base_dir / "cosmopedia",
            tokenizer,
            target_gb=20.0
        )
    else:
        print("\n‚è≠Ô∏è  Skipping Cosmopedia (already exists)")

    # 2. Alpaca
    if not alpaca_exists:
        print("\n" + "="*70)
        print("DOWNLOADING ALPACA")
        print("="*70)
        alpaca_files = download_alpaca(
            base_dir / "alpaca",
            tokenizer
        )
    else:
        print("\n‚è≠Ô∏è  Skipping Alpaca (already exists)")

    # 3. Python (optional)
    if not python_exists:
        print("\n" + "="*70)
        print("DOWNLOADING PYTHON (OPTIONAL)")
        print("="*70)
        want_python = input("Download Python dataset? (yes/no): ").strip().lower()
        if want_python == 'yes':
            python_files = download_python(
                base_dir / "python",
                tokenizer,
                target_gb=5.0
            )
        else:
            print("‚è≠Ô∏è  Skipping Python dataset")
    else:
        print("\n‚è≠Ô∏è  Skipping Python (already exists)")

    print("\n" + "="*70)
    print("‚úÖ ALL DOWNLOADS COMPLETE!")
    print("="*70)
    print("Next steps:")
    print("  1. Run Cell 5B to create the mixed dataloader")
    print("  2. Run Cell 9 to continue training")
    print("="*70)

In [None]:
# ============================================================================
# DOWNLOAD CELL 3: PYTHON DATASET DOWNLOAD (Instruction-Aligned, Logic-Focused)
# ============================================================================

import os
import numpy as np
import tiktoken
from pathlib import Path
from datasets import load_dataset
from tqdm.auto import tqdm

def download_python_dataset(output_dir: Path, tokenizer, target_gb: float = 5.0):
    """
    Download Python instruction-style dataset.
    Python is used to teach LOGICAL ANSWERING, not raw code completion.
    """
    output_dir.mkdir(exist_ok=True, parents=True)

    print("="*70)
    print("DOWNLOADING PYTHON DATASET (LOGIC-ALIGNED)")
    print("="*70)
    print(f"Target: {target_gb} GB of instruction-style Python")
    print("="*70 + "\n")

    dataset = None

    # ------------------------------------------------------------
    # OPTION 1: CodeAlpaca (BEST for your goal)
    # ------------------------------------------------------------
    print("üì• Option 1: Trying 'sahil2801/CodeAlpaca-20k'...")
    try:
        dataset = load_dataset(
            "sahil2801/CodeAlpaca-20k",
            split="train",
            streaming=True,
            trust_remote_code=True
        )
        print("‚úÖ Loaded CodeAlpaca\n")
    except Exception as e:
        print(f"‚ùå Failed: {str(e)[:100]}...\n")

    # ------------------------------------------------------------
    # OPTION 2: Python instruction fallback
    # ------------------------------------------------------------
    if dataset is None:
        print("üì• Option 2: Trying 'iamtarun/python_code_instructions_18k_alpaca'...")
        try:
            dataset = load_dataset(
                "iamtarun/python_code_instructions_18k_alpaca",
                split="train",
                streaming=True,
                trust_remote_code=True
            )
            print("‚úÖ Loaded Python Instructions\n")
        except Exception as e:
            print(f"‚ùå Failed: {str(e)[:100]}...\n")

    if dataset is None:
        print("‚ùå No suitable Python instruction dataset found.")
        print("Training will continue WITHOUT Python.")
        return []

    # ------------------------------------------------------------
    # TOKENIZATION (CRITICAL FIX: instruction alignment)
    # ------------------------------------------------------------
    all_tokens = []
    target_bytes = int(target_gb * (1024**3))
    current_bytes = 0
    tokens_per_file = 10_000_000
    file_count = 0

    print("üîÑ Tokenizing Python instructions...")
    print("Python will reinforce reasoning + response discipline\n")

    with tqdm(total=target_gb, unit='GB', desc="Progress") as pbar:
        try:
            for item in dataset:
                instruction = str(item.get("instruction", "")).strip()
                output = str(item.get("output", "")).strip()

                if not instruction or not output:
                    continue

                # STRONG alignment with Alpaca / Cosmopedia
                text = f"<|user|> {instruction}\n<|assistant|> {output}\n"
                text_bytes = len(text.encode("utf-8"))

                if current_bytes + text_bytes > target_bytes:
                    print("\nüéØ Target reached")
                    break

                tokens = tokenizer.encode(text)
                all_tokens.extend(tokens)
                all_tokens.append(tokenizer.eot_token)

                current_bytes += text_bytes
                pbar.update(text_bytes / (1024**3))

                if len(all_tokens) >= tokens_per_file:
                    filename = output_dir / f"python_tokens_{file_count:04d}.npy"
                    np.save(filename, np.array(all_tokens[:tokens_per_file], dtype=np.uint16))
                    all_tokens = all_tokens[tokens_per_file:]
                    file_count += 1
                    pbar.set_postfix({
                        "files": file_count,
                        "GB": f"{current_bytes/(1024**3):.2f}"
                    })

        except KeyboardInterrupt:
            print("\n‚ö†Ô∏è Interrupted, saving progress...")
        except Exception as e:
            print(f"\n‚ö†Ô∏è Error: {e}")
            print("Saving what we have...")

    # Save remainder
    if all_tokens:
        filename = output_dir / f"python_tokens_{file_count:04d}.npy"
        np.save(filename, np.array(all_tokens, dtype=np.uint16))
        file_count += 1

    if file_count == 0:
        print("‚ùå No Python data collected")
        return []

    print("\n" + "="*70)
    print("‚úÖ PYTHON DATASET READY")
    print("="*70)
    print(f"üì¶ Files: {file_count}")
    print(f"üìä Raw text: {current_bytes/(1024**3):.2f} GB")
    print(f"üíæ Location: {output_dir}")
    print("üéØ Purpose: Logical answering + structured responses")
    print("="*70)

    return list(output_dir.glob("python_tokens_*.npy"))


# ============================================================================
# RUN
# ============================================================================

print("\nüêç PREPARING PYTHON DATASET (LOGIC MODE)")
print("="*70)

tokenizer = tiktoken.get_encoding("gpt2")
python_dir = Path(Config.DRIVE_DATA_DIR) / "python"

python_files = download_python_dataset(
    python_dir,
    tokenizer,
    target_gb=5.0
)

print(f"\n‚úÖ Python files ready: {len(python_files)}")
print("Next step: plug into MixedDatasetLoader")