In [None]:
"""MedAssist.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/17YokGk0b0BEg6UFQppdnfeJM02eRUUhR

MedAssist-GPT: Complete Medical LLM Pretraining Script
=====================================================
Modern architecture with RoPE, GQA, SwiGLU, RMSNorm
Optimized for A100 GPU with Flash Attention
Automatic checkpointing and HuggingFace uploads

## Test

## Imports
"""

In [None]:
# Core Python
import os
import json
import gc
import re
import html
import math
import pickle
import multiprocessing as mp
from pathlib import Path
from typing import Dict, List, Tuple, Any, Optional
from functools import partial
from concurrent.futures import ProcessPoolExecutor, as_completed

In [None]:
# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast
from torch.optim.lr_scheduler import OneCycleLR

In [None]:
# Data & ML
import numpy as np
import tiktoken
import wandb
from datasets import load_dataset, concatenate_datasets
from huggingface_hub import login, create_repo, upload_folder, HfApi
from tqdm import tqdm
from bs4 import BeautifulSoup, NavigableString

In [None]:
"""## CONFIGURATION"""

In [None]:
# Architecture hyperparameters read by MedAssistGPT / TransformerBlock.
MODEL_CONFIG = {
    "vocab_size": 50281,         # presumably the tiktoken "p50k_base" vocabulary size — verify
    "d_model": 1024,
    "n_heads": 16,
    "gqa_groups": 4,             # K/V heads = n_heads // gqa_groups (see GroupedQueryAttention)
    "max_len": 1024,             # size of the RoPE position table; sequences cannot exceed this
    "d_ff": 2560,                # SwiGLU hidden width (w1 projects to 2 * d_ff)
    "eps": 1e-5,                 # RMSNorm epsilon
    "dropout_p": 0.0,
    "blocks": 24,
}  # NOTE(review): with tied embeddings these settings work out to ~303M trainable
   # parameters, not ~400M; this matches the W&B run name but not the "401M" HF repo id.

# Optimisation / schedule settings for the training loop.
TRAINING_CONFIG = {
    "batch_size": 32,
    "max_length": 1024,
    "stride": 512,              # CHANGED: 50% overlap for better coverage
    "gradient_accumulation_steps": 4,
    "learning_rate": 6e-4,       # CHANGED: Higher LR, cosine will bring it down
    "weight_decay": 0.1,
    "beta1": 0.9,
    "beta2": 0.95,
    "eps": 1e-8,
    "warmup_steps": 2000,        # CHANGED: Longer warmup for stability
    "max_steps": 100_000,        # CHANGED: ~13B tokens, plenty for this model
    "eval_freq": 2000,
    "eval_iter": 200,
    "save_freq": 5000,
    "grad_clip": 1.0,            # CHANGED: Tighter clipping for stability
    "num_workers": 4,
    "seed": 42,
}

# Dataset selection and windowing settings read by prepare_medical_data().
DATA_CONFIG = {
    "dataset_name": "bigbio/pubmed_qa",  # CHANGED: Better quality QA dataset
    "text_column": "context",    # field read from each document (falls back to "text")
    "max_length": 1024,
    "stride": 512,               # Keep in sync with TRAINING_CONFIG["stride"]
    "train_split": 0.98,         # CHANGED: More training data
    "max_train_samples": 5_000_000,  # Total documents (train + val) taken from the stream
    "chunk_size": 10000,         # documents handed to each worker task
    "use_clean": False,          # True runs clean() (JATS/PMC XML -> text) on every document
}

# Weights & Biases run identification.
WANDB_CONFIG = {
    "project": "MedAssist-GPT-Pretraining",
    "entity": "kunjcr2-dreamable",  # Your wandb username
    "name": "medassist-303M-test",
}

# Hugging Face Hub upload settings.
HF_CONFIG = {
    "repo_id": "kunjcr2/MedAssist-GPT-401M",  # Change this!
    "upload_checkpoints": True,
    "upload_frequency": 20000,  # Upload every N steps
}

# Sampling settings for generation.
INFERENCE_CONFIG = {
    "max_new_tokens": 100,
    "temperature": 0.6,
    "prompt_text": "A patient was admitted with severe headache. Initial assessment revealed"
}

In [None]:
"""## Architecture"""

In [None]:
class RoPE(nn.Module):
    """Rotary Position Embeddings (RoPE).

    Rotates consecutive (even, odd) feature pairs of the last dimension by a
    position-dependent angle. Pair ``i`` at position ``p`` is rotated by
    ``p * 10000 ** (-2i / d_model)``.

    Args:
        d_model: Size of the last dimension of the tensors to rotate (must be even).
        max_len: Number of positions in the precomputed position table.
    """
    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        assert d_model % 2 == 0, "d_model must be even for RoPE"

        self.d_model = d_model
        self.max_len = max_len

        # Position indices 0 .. max_len-1, shape (max_len, 1).
        # Kept as persistent buffers so existing checkpoints keep loading.
        self.register_buffer('position_ids', torch.arange(max_len).unsqueeze(1))

        # Per-pair inverse frequencies: exp(2i * (-log(10000) / d_model)), shape (d_model/2,)
        self.register_buffer(
            'div_term',
            torch.exp(torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model))
        )

    def forward(self, x: torch.Tensor, position_offset: int = 0) -> torch.Tensor:
        """Rotate ``x`` of shape (batch, seq_len, d_model).

        Args:
            x: Tensor whose last dimension equals ``d_model``.
            position_offset: Absolute position of ``x[:, 0]`` (number of tokens
                already held in a KV cache).

        Returns:
            Tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: If the requested positions fall outside ``[0, max_len)``.
        """
        batch_size, seq_len, d_model = x.shape

        # Fail loudly: slicing past the table would silently return fewer rows
        # and only surface later as an opaque shape error.
        last_position = position_offset + seq_len
        if position_offset < 0 or last_position > self.max_len:
            raise ValueError(
                f"RoPE positions [{position_offset}, {last_position}) exceed "
                f"max_len={self.max_len}"
            )

        position_ids = self.position_ids[position_offset:last_position]  # (seq_len, 1)

        # Compute angles in float32 even if the module was cast to half precision;
        # large position indices lose accuracy in bf16/fp16.
        angles = position_ids.float() * self.div_term.float()  # (seq_len, d_model/2)
        cos_vals = torch.cos(angles)
        sin_vals = torch.sin(angles)

        # Split the feature dimension into (even, odd) pairs
        x_pairs = x.reshape(batch_size, seq_len, d_model // 2, 2)  # (b, s, d//2, 2)
        x_even = x_pairs[..., 0]  # (b, s, d//2)
        x_odd = x_pairs[..., 1]  # (b, s, d//2)

        # 2-D rotation of every pair (promoted to float32 for half-precision inputs)
        rotated_even = x_even * cos_vals - x_odd * sin_vals
        rotated_odd = x_even * sin_vals + x_odd * cos_vals

        # Interleave the pairs back into the original layout
        rotated_pairs = torch.stack([rotated_even, rotated_odd], dim=-1)  # (b, s, d//2, 2)
        rotated_x = rotated_pairs.reshape(batch_size, seq_len, d_model)  # (b, s, d)

        # Hand back the caller's dtype so q/k stay consistent with v downstream.
        return rotated_x.to(x.dtype)

In [None]:
class GroupedQueryAttention(nn.Module):
    """Grouped Query Attention (GQA) with RoPE and an optional KV cache.

    ``n_heads`` query heads share ``n_heads // gqa_groups`` key/value heads.
    RoPE is applied to the flat projections before they are split into heads.

    Args:
        d_model: Model width (must be divisible by ``n_heads``).
        n_heads: Number of query heads (must be divisible by ``gqa_groups``).
        gqa_groups: Divisor that determines the number of K/V heads.
        max_len: Maximum absolute position supported by the RoPE tables.
    """
    def __init__(
        self,
        d_model: int = 512,
        n_heads: int = 8,
        gqa_groups: int = 2,
        max_len: int = 1024,
    ):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        assert n_heads % gqa_groups == 0, "n_heads must be divisible by gqa_groups"

        self.d_model = d_model
        self.n_heads = n_heads
        self.gqa_groups = gqa_groups
        self.head_dim = d_model // n_heads
        self.n_kv_heads = n_heads // gqa_groups
        self.max_len = max_len

        # Projections (bias-free)
        self.q_proj = nn.Linear(d_model, n_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(d_model, self.n_kv_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(n_heads * self.head_dim, d_model, bias=False)

        # RoPE for Q and K
        self.rope_q = RoPE(d_model=n_heads * self.head_dim, max_len=max_len)
        self.rope_k = RoPE(d_model=self.n_kv_heads * self.head_dim, max_len=max_len)

    def forward(
        self,
        x: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ):
        """Causal self-attention over ``x`` of shape (B, T, d_model).

        Args:
            x: Hidden states for the ``T`` new positions.
            past_key_value: Optional ``(k, v)`` cache, each (B, H_kv, past_len, D).
            use_cache: When True, also return the updated ``(k, v)`` cache.

        Returns:
            ``out`` of shape (B, T, d_model), or ``(out, (k, v))`` when
            ``use_cache`` is True.
        """
        B, T, _ = x.shape

        # Absolute position of the first new token = number of cached tokens
        position_offset = past_key_value[0].shape[2] if past_key_value is not None else 0

        # Project Q, K, V
        q = self.q_proj(x)  # (B, T, H*D)
        k = self.k_proj(x)  # (B, T, H_kv*D)
        v = self.v_proj(x)  # (B, T, H_kv*D)

        # Apply RoPE with position offset
        q = self.rope_q(q, position_offset=position_offset)
        k = self.rope_k(k, position_offset=position_offset)

        # RoPE multiplies by float32 tables, which can promote q/k above v's dtype
        # (e.g. bf16 weights). Align them so attention sees a single dtype.
        q = q.to(v.dtype)
        k = k.to(v.dtype)

        # Reshape to heads
        q = q.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)  # (B, H, T, D)
        k = k.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)  # (B, H_kv, T, D)
        v = v.view(B, T, self.n_kv_heads, self.head_dim).transpose(1, 2)  # (B, H_kv, T, D)

        # Concatenate with past KV if present
        if past_key_value is not None:
            past_k, past_v = past_key_value
            k = torch.cat([past_k, k], dim=2)  # (B, H_kv, past_len + T, D)
            v = torch.cat([past_v, v], dim=2)  # (B, H_kv, past_len + T, D)

        # Store new cache if needed (before GQA expansion, to keep it small)
        new_cache = (k, v) if use_cache else None

        # Expand K and V so every query head has a matching K/V head
        expand_factor = self.n_heads // self.n_kv_heads
        k_expanded = k.repeat_interleave(expand_factor, dim=1)  # (B, H, total_len, D)
        v_expanded = v.repeat_interleave(expand_factor, dim=1)  # (B, H, total_len, D)

        total_len = k_expanded.shape[2]
        if position_offset == 0:
            # No cached tokens: q and k have equal length, so the built-in causal mask applies.
            out = F.scaled_dot_product_attention(
                q, k_expanded, v_expanded,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=True
            )
        elif T == 1:
            # Single-token decode: the new query may see every cached key plus itself.
            out = F.scaled_dot_product_attention(
                q, k_expanded, v_expanded,
                attn_mask=None,
                dropout_p=0.0,
                is_causal=False
            )
        else:
            # Several new tokens on top of a cache (chunked prefill, speculative decoding):
            # query i sits at absolute position past_len + i and must not see later keys.
            query_pos = torch.arange(
                position_offset, position_offset + T, device=q.device
            ).unsqueeze(1)  # (T, 1)
            key_pos = torch.arange(total_len, device=q.device).unsqueeze(0)  # (1, total_len)
            causal_mask = key_pos <= query_pos  # (T, total_len), True = may attend
            out = F.scaled_dot_product_attention(
                q, k_expanded, v_expanded,
                attn_mask=causal_mask,
                dropout_p=0.0,
                is_causal=False
            )

        # Merge heads
        out = out.transpose(1, 2).contiguous().view(B, T, self.n_heads * self.head_dim)

        # Output projection
        out = self.o_proj(out)

        if use_cache:
            return out, new_cache
        return out

In [None]:
class SwiGLU_MLP(nn.Module):
    """Feed-forward network with a SwiGLU gate.

    A single fused linear layer produces both the value and gate halves;
    the gated product is projected back to the model width.
    """
    def __init__(self, d_model: int = 512, d_ff: int = 2048):
        super().__init__()
        # One matmul yields [value | gate], each of width d_ff
        self.w1 = nn.Linear(d_model, d_ff * 2, bias=False)
        # Projection back to d_model
        self.w2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        fused = self.w1(x)                                   # (b, s, 2*d_ff)
        value, gate = torch.chunk(fused, 2, dim=-1)          # two (b, s, d_ff) halves
        return self.w2(value * F.silu(gate))                 # (b, s, d_model)

In [None]:
class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: RMSNorm -> attention and RMSNorm -> MLP, each with a residual."""
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        width, norm_eps = config["d_model"], config["eps"]

        # Sub-modules are registered in this order on purpose: it fixes the
        # state_dict layout and the order in which weights get initialised.
        self.rms1 = nn.RMSNorm(width, eps=norm_eps)
        self.rms2 = nn.RMSNorm(width, eps=norm_eps)
        self.attn = GroupedQueryAttention(
            d_model=width,
            n_heads=config["n_heads"],
            gqa_groups=config["gqa_groups"],
            max_len=config["max_len"],
        )
        self.mlp = SwiGLU_MLP(d_model=width, d_ff=config["d_ff"])
        self.dropout = nn.Dropout(config["dropout_p"])

    def forward(
        self,
        x: torch.Tensor,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False
    ):
        """Return the updated hidden states, plus this layer's KV cache when ``use_cache`` is True.

        ``layer_past`` is only forwarded to attention on the cached path.
        """
        normed = self.rms1(x)
        if use_cache:
            attn_out, present = self.attn(normed, past_key_value=layer_past, use_cache=True)
        else:
            attn_out, present = self.attn(normed), None

        # Attention residual
        x = x + self.dropout(attn_out)

        # Feed-forward residual
        x = x + self.dropout(self.mlp(self.rms2(x)))

        return (x, present) if use_cache else x

In [None]:
class MedAssistGPT(nn.Module):
    """Decoder-only language model: token embedding -> N transformer blocks -> RMSNorm -> LM head.

    The LM head shares its weight matrix with the token embedding.

    Args:
        config: Dict with ``vocab_size``, ``d_model``, ``blocks`` and ``eps`` plus
            the keys TransformerBlock reads.
    """
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.config = config

        # Token embeddings
        self.embed = nn.Embedding(config["vocab_size"], config["d_model"])

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config["blocks"])
        ])

        # Final RMSNorm
        self.final_rms = nn.RMSNorm(config["d_model"], eps=config["eps"])

        # Language model head (weight-tied with embeddings)
        self.lm_head = nn.Linear(config["d_model"], config["vocab_size"], bias=False)
        self.lm_head.weight = self.embed.weight

        # Initialize weights (the tied matrix is visited twice; both draws use std=0.02)
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Normal(0, 0.02) init for Linear/Embedding weights; zero any Linear bias."""
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(
        self,
        input_ids: torch.Tensor,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False
    ):
        """
        Args:
            input_ids: (batch, seq_len)
            past_key_values: List of (k, v) tuples, one per layer. None for training.
            use_cache: If True, return (logits, new_past_key_values). If False, return logits only.

        Raises:
            ValueError: If ``past_key_values`` does not hold one entry per layer.
        """
        if past_key_values is not None and len(past_key_values) != len(self.blocks):
            raise ValueError(
                f"past_key_values has {len(past_key_values)} entries, "
                f"expected one per layer ({len(self.blocks)})"
            )

        h = self.embed(input_ids)  # (batch, seq_len, d_model)

        # A supplied cache must be attended to even if the caller does not want the
        # updated cache back; previously it was silently dropped when use_cache=False.
        run_cached = use_cache or past_key_values is not None
        presents = [] if use_cache else None

        for i, block in enumerate(self.blocks):
            if run_cached:
                layer_past = past_key_values[i] if past_key_values is not None else None
                h, present = block(h, layer_past=layer_past, use_cache=True)
                if use_cache:
                    presents.append(present)
            else:
                h = block(h)

        # Final normalization
        h = self.final_rms(h)

        # Language model head
        logits = self.lm_head(h)  # (batch, seq_len, vocab_size)

        if use_cache:
            return logits, presents
        return logits

    def count_parameters(self) -> int:
        """Number of trainable parameters (the tied embedding/head matrix is counted once)."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

In [None]:
"""## Data Loading

### clean()
"""

In [None]:
def clean(xml_str: str) -> str:
    """
    Convert a JATS/PMC-style XML document into readable plain text.

    Non-narrative elements (tables, figures, reference lists, captions, back
    matter, cross-references) are removed; titles and paragraphs become
    newline-separated text; leftover citation debris is stripped with regexes.

    The DOM passes below mutate the tree in place and are order-dependent:
    elements are dropped first, then replaced/unwrapped, and text is only
    extracted at the end.

    Args:
        xml_str: XML source, parsed with BeautifulSoup's "lxml-xml" parser.

    Returns:
        The cleaned text with leading/trailing whitespace stripped.
    """
    soup = BeautifulSoup(xml_str, "lxml-xml")

    # 1) Remove whole non-narrative blocks (the element and everything inside it)
    drop_whole = [
        "ref-list", "fig", "fig-group", "table-wrap", "table", "thead", "tbody",
        "tr", "td", "th", "graphic", "media", "supplementary-material", "back",
        "sec-meta", "table-wrap-foot", "caption"
    ]
    for name in drop_whole:
        for tag in soup.find_all(name):
            tag.decompose()

    # 2) Remove cross-references entirely (citations, table/fig pointers)
    for tag in soup.find_all("xref"):
        tag.decompose()

    # 3) Flatten disp-quote blocks to their text on a line of their own
    for dq in soup.find_all("disp-quote"):
        txt = dq.get_text(" ", strip=True)
        dq.replace_with(NavigableString(("\n" + txt + "\n") if txt else ""))

    # 4) Turn <title> into a section header preceded by a blank line
    for t in soup.find_all("title"):
        title_txt = t.get_text(" ", strip=True)
        t.replace_with(NavigableString("\n\n" + title_txt + "\n"))

    # 5) Inside each <p>: unwrap inline formatting tags (keeping their text),
    #    then add a paragraph break after the <p> and unwrap the <p> itself
    inline_unwrap = [
        "italic", "bold", "underline", "sc", "em", "strong", "sup", "sub",
        "styled-content", "inline-formula", "monospace"
    ]
    for p in soup.find_all("p"):
        for name in inline_unwrap:
            for it in p.find_all(name):
                it.unwrap()
        p.insert_after(NavigableString("\n\n"))
        p.unwrap()

    # 6) Unwrap remaining structural containers (children are kept)
    for name in ["sec", "body", "front", "article", "abstract", "boxed-text", "list", "list-item"]:
        for tag in soup.find_all(name):
            tag.unwrap()

    # 7) Extract text and decode any remaining HTML entities
    text = soup.get_text()
    text = html.unescape(text)

    # Post-processing cleanup
    # Parenthesised numbers / ranges / lists, e.g. "(12)", "(3-5, 8)" — note this
    # also removes any other purely numeric parenthetical.
    text = re.sub(r"\(\s*(?:\d+\s*(?:[-–]\s*\d+)?\s*(?:[,;]\s*)?)+\)", "", text)
    # Parenthesised figure/table pointers, e.g. "(Fig. 2)", "(Table 1)"
    text = re.sub(r"\(\s*(?:[Ff]ig(?:ure)?\.?\s*\d+|[Tt]able\s*\d+)\s*\)", "", text)
    # Empty brackets left behind by removed references
    text = re.sub(r"\(\s*\)", "", text)
    text = re.sub(r"\[\s*\]", "", text)
    # No whitespace before punctuation
    text = re.sub(r"\s+([,.;:!?])", r"\1", text)
    # Collapse runs of the same punctuation mark into one mark plus a space
    text = re.sub(r"([,.;:!?])\s*\1+", r"\1 ", text)
    # Collapse runs of spaces/tabs and cap consecutive newlines at one blank line
    text = re.sub(r"[ \t]{2,}", " ", text)
    text = re.sub(r"\n{3,}", "\n\n", text)

    # Strip trailing whitespace from every line
    text = "\n".join(line.rstrip() for line in text.splitlines())
    return text.strip()

In [None]:
"""### MemoryMappedDataset class"""

In [None]:
class MemoryMappedDataset(Dataset):
    """
    Map-style dataset backed by memory-mapped ``.npy`` files.

    ``cache_dir`` must contain ``metadata.pkl``, ``inputs.npy`` and
    ``targets.npy``. The arrays are opened with ``mmap_mode='r'``, so rows are
    paged in from disk on access instead of being loaded up front.

    NOTE: ``metadata.pkl`` is read with ``pickle.load`` — only point this at
    cache directories you created yourself.
    """
    def __init__(self, cache_dir: Path):
        self.cache_dir = Path(cache_dir)

        # Small metadata dict written next to the arrays
        meta_path = self.cache_dir / "metadata.pkl"
        with meta_path.open("rb") as handle:
            self.metadata = pickle.load(handle)

        # Read-only memory maps; nothing is copied into RAM here
        self.inputs_mmap = self._open_readonly("inputs.npy")
        self.targets_mmap = self._open_readonly("targets.npy")

        print(f"Memory-mapped dataset: {len(self.inputs_mmap):,} samples")
        print(f"RAM overhead: ~0 MB (OS manages it)")

    def _open_readonly(self, filename: str) -> np.ndarray:
        """Memory-map ``filename`` inside the cache directory in read-only mode."""
        return np.load(self.cache_dir / filename, mmap_mode='r')

    def __len__(self) -> int:
        return len(self.inputs_mmap)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Copy the selected row out of each map, then widen to int64 for embedding lookups
        input_ids, target_ids = (
            torch.from_numpy(store[idx].copy()).long()
            for store in (self.inputs_mmap, self.targets_mmap)
        )
        return input_ids, target_ids

In [None]:
"""### process_single_chunk()"""

In [None]:
def process_single_chunk(
    doc_batch: List[Dict],
    tokenizer_name: str,  # Can't pickle tokenizer, so pass name
    max_length: int,
    stride: int,
    use_clean: bool,
    text_column: str,
    chunk_id: int,
) -> Tuple[List[np.ndarray], List[np.ndarray], int, int]:
    """
    Worker: tokenize one batch of documents and cut the tokens into training windows.

    Runs inside a ProcessPoolExecutor worker, so it builds its own tokenizer
    from ``tokenizer_name`` instead of receiving a tokenizer object.

    Args:
        doc_batch: Documents (dicts) to process.
        tokenizer_name: tiktoken encoding name passed to ``tiktoken.get_encoding``.
        max_length: Window length in tokens.
        stride: Step between consecutive window starts.
        use_clean: When True, run ``clean()`` on each document before tokenizing.
        text_column: Key holding the document text; falls back to ``'text'``.
        chunk_id: Index of this batch (not used by the worker itself).

    Returns:
        ``(inputs, targets, docs_processed, total_tokens)`` where inputs/targets
        are parallel lists of ``max_length``-long arrays.

    Documents that are missing, too short (< 100 characters / < 10 tokens),
    not strings, or that fail cleaning/tokenization are skipped.
    """
    import tiktoken  # imported here so a freshly spawned worker has it in scope

    # Each worker creates its own tokenizer
    tokenizer = tiktoken.get_encoding(tokenizer_name)
    vocab_size = tokenizer.n_vocab

    chunk_tokens: List[int] = []
    docs_processed = 0
    total_tokens = 0

    for doc in doc_batch:
        # Extract text
        if text_column in doc:
            text = doc[text_column]
        elif 'text' in doc:
            text = doc['text']
        else:
            continue

        # Structured fields (dict/list) cannot be tokenized; skip them explicitly
        # instead of relying on a downstream exception.
        if not isinstance(text, str) or len(text) < 100:
            continue

        # Apply clean function if enabled
        if use_clean:
            try:
                text = clean(text)
            except Exception:
                # Best-effort: malformed documents are skipped. `Exception` (not a
                # bare except) lets KeyboardInterrupt/SystemExit stop the worker.
                continue

        if not text or len(text) < 100:
            continue

        # Tokenize
        try:
            tokens = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
        except Exception:
            continue

        tokens = [t for t in tokens if 0 <= t < vocab_size]
        if len(tokens) < 10:
            continue

        # NOTE(review): documents are concatenated without a separator token, so
        # windows can straddle document boundaries — confirm this is intended.
        chunk_tokens.extend(tokens)
        total_tokens += len(tokens)
        docs_processed += 1

    # Create sliding windows from all tokens in this chunk
    samples = create_sliding_windows(chunk_tokens, max_length, stride, vocab_size)

    if samples:
        inputs, targets = zip(*samples)
        return list(inputs), list(targets), docs_processed, total_tokens
    else:
        return [], [], docs_processed, total_tokens

In [None]:
"""### process_dataset_in_chunks()"""

In [None]:
def process_dataset_in_chunks(
    dataset_name: str,
    tokenizer,
    cache_dir: Path,
    max_length: int = 1024,
    stride: int = 1024,
    max_samples: int = None,
    chunk_size: int = 5000,
    use_clean: bool = True,
    text_column: str = "full_text",
    skip_samples: int = 0,
):
    """
    Stream a dataset, tokenize it in parallel, and cache the windows as ``.npy`` files.

    Steps:
    1. Stream the ``train`` split, optionally skipping ``skip_samples`` documents
       and keeping at most ``max_samples``.
    2. Group documents into batches of ``chunk_size`` (all held in memory).
    3. Fan the batches out to worker processes running ``process_single_chunk``.
    4. Write ``inputs.npy``, ``targets.npy`` and ``metadata.pkl`` into ``cache_dir``.

    Args:
        dataset_name: Name passed to ``datasets.load_dataset``.
        tokenizer: Tokenizer object; its ``name`` (if present) selects the
            encoding the workers load, and ``n_vocab`` is recorded in the metadata.
        cache_dir: Output directory (created if missing).
        max_length: Window length in tokens.
        stride: Step between consecutive window starts.
        max_samples: Maximum number of documents to read; ``None`` means no limit.
        chunk_size: Documents per worker task.
        use_clean: Run ``clean()`` on each document.
        text_column: Document field containing the text.
        skip_samples: Documents to skip at the start of the stream (e.g. to carve
            a validation split after the training documents).

    Returns:
        ``cache_dir`` as a ``Path``.
    """

    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Determine number of workers
    num_workers = min(mp.cpu_count(), 8)  # Max 8 workers (diminishing returns)

    # Workers rebuild the tokenizer by name. Use the name of the tokenizer we were
    # given so the cached ids agree with the vocab_size stored in the metadata;
    # fall back to the previous hard-coded encoding if the object has no name.
    tokenizer_name = getattr(tokenizer, "name", None) or "p50k_base"

    print(f"Processing dataset: {dataset_name}")
    print(f"Chunk size: {chunk_size} documents")
    print(f"Cleaning enabled: {use_clean}")
    print(f"Parallel workers: {num_workers}")

    # Load dataset in streaming mode
    dataset = load_dataset(dataset_name, split="train", streaming=True)

    if skip_samples:
        dataset = dataset.skip(skip_samples)

    # `is not None` so that max_samples=0 yields an empty split rather than everything
    if max_samples is not None:
        dataset = dataset.take(max_samples)

    total_docs_processed = 0
    total_tokens = 0

    print("Loading documents into batches...")

    # A streaming iterator cannot be split across processes directly,
    # so batch first and parallelise over the batches.
    doc_batches = []
    current_batch = []

    for doc in tqdm(dataset, desc="Batching docs"):
        current_batch.append(doc)

        if len(current_batch) >= chunk_size:
            doc_batches.append(current_batch)
            current_batch = []

    # Don't forget the last batch
    if current_batch:
        doc_batches.append(current_batch)

    print(f"Created {len(doc_batches)} batches of ~{chunk_size} documents each")
    print(f"Processing with {num_workers} parallel workers...")

    # Results are keyed by chunk index so the final sample order does not depend
    # on which worker happens to finish first (reproducible caches).
    chunk_results = {}

    with ProcessPoolExecutor(max_workers=num_workers) as executor:

        # partial() pre-fills the arguments that are the same for every task
        worker_fn = partial(
            process_single_chunk,
            tokenizer_name=tokenizer_name,  # Pass tokenizer name, not object
            max_length=max_length,
            stride=stride,
            use_clean=use_clean,
            text_column=text_column,
        )

        future_to_chunk = {
            executor.submit(worker_fn, doc_batch, chunk_id=chunk_id): chunk_id
            for chunk_id, doc_batch in enumerate(doc_batches)
        }

        # as_completed() yields futures as they finish, giving live progress
        for future in tqdm(as_completed(future_to_chunk), total=len(future_to_chunk), desc="Processing chunks"):
            try:
                inputs, targets, docs_proc, tokens_proc = future.result()
            except Exception as e:
                # Best-effort: a failed chunk is reported and dropped
                print(f"Worker failed: {e}")
                continue

            chunk_results[future_to_chunk[future]] = (inputs, targets)
            total_docs_processed += docs_proc
            total_tokens += tokens_proc

            # Show progress
            if len(inputs) > 0:
                print(f"Chunk done: {len(inputs)} samples, {docs_proc} docs, {tokens_proc:,} tokens")

    # Stitch the chunks back together in their original order
    all_inputs = []
    all_targets = []
    for chunk_id in sorted(chunk_results):
        inputs, targets = chunk_results[chunk_id]
        all_inputs.extend(inputs)
        all_targets.extend(targets)

    print(f"\nParallel processing complete!")
    print(f"   Documents processed: {total_docs_processed:,}")
    print(f"   Total tokens: {total_tokens:,}")
    print(f"   Training samples: {len(all_inputs):,}")

    if not all_inputs:
        print("WARNING: no samples were produced; the cached arrays will be empty.")

    # Save to disk
    print("Saving to disk as memory-mapped arrays...")

    # reshape keeps the arrays 2-D, (num_samples, max_length), even when empty
    inputs_array = np.array(all_inputs, dtype=np.int32).reshape(-1, max_length)
    targets_array = np.array(all_targets, dtype=np.int32).reshape(-1, max_length)

    np.save(cache_dir / "inputs.npy", inputs_array)
    np.save(cache_dir / "targets.npy", targets_array)

    # Save metadata
    metadata = {
        "num_samples": len(all_inputs),
        "max_length": max_length,
        "stride": stride,
        "vocab_size": tokenizer.n_vocab,
        "docs_processed": total_docs_processed,
        "total_tokens": total_tokens,
        "tokenizer_name": tokenizer_name,
        "skip_samples": skip_samples,
    }

    with open(cache_dir / "metadata.pkl", "wb") as f:
        pickle.dump(metadata, f)

    print(f"Saved to {cache_dir}")

    # Clear RAM
    del all_inputs, all_targets, inputs_array, targets_array, chunk_results
    gc.collect()

    return cache_dir

In [None]:
"""### create_sliding_windows()"""

In [None]:
def create_sliding_windows(
    tokens: List[int],
    max_length: int,
    stride: int,
    vocab_size: int
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """
    Cut a token stream into (input, target) windows for next-token prediction.

    Each target window is its input window shifted right by one token. Only
    windows where both halves are exactly ``max_length`` long are kept, so a
    trailing partial window is dropped.

    Example (max_length=4, stride=2, tokens 1..10):
        input [1, 2, 3, 4] -> target [2, 3, 4, 5]
        input [3, 4, 5, 6] -> target [4, 5, 6, 7]
        input [5, 6, 7, 8] -> target [6, 7, 8, 9]

    Token ids are stored as int32 and clipped into [0, vocab_size - 1].
    """
    token_arr = np.array(tokens, dtype=np.int32)

    # Nothing to window
    if len(token_arr) == 0:
        return []

    # Safety net: force every id into the valid embedding range
    token_arr = np.clip(token_arr, 0, vocab_size - 1)

    def _candidate_windows():
        # Stopping at len - max_length leaves room for the one-token shift of the target
        for start in range(0, len(token_arr) - max_length, stride):
            stop = start + max_length
            yield token_arr[start:stop], token_arr[start + 1:stop + 1]

    return [
        (window_in, window_out)
        for window_in, window_out in _candidate_windows()
        if len(window_in) == max_length and len(window_out) == max_length
    ]

In [None]:
"""### prepare_medical_data()"""

In [None]:
def prepare_medical_data(
    config: Dict[str, Any],
    tokenizer
) -> Tuple[Path, Path]:
    """
    Build (or reuse) the cached training and validation sets.

    The first ``train_size`` documents of the stream become the training set and
    the following ``val_size`` documents the validation set, where
    ``train_size = int(max_train_samples * train_split)``.

    A split is considered cached when ``metadata.pkl`` exists in its directory
    (``./data_cache/train`` and ``./data_cache/val``); only missing splits are
    processed.

    Args:
        config: Data configuration (``dataset_name`` is required; ``max_train_samples``,
            ``train_split``, ``max_length``, ``stride``, ``chunk_size``, ``use_clean``
            and ``text_column`` are optional).
        tokenizer: Tokenizer forwarded to ``process_dataset_in_chunks``.

    Returns:
        ``(train_cache, val_cache)`` directories.
    """
    import inspect

    train_cache = Path("./data_cache/train")
    val_cache = Path("./data_cache/val")

    # Check if already cached
    train_exists = (train_cache / "metadata.pkl").exists()
    val_exists = (val_cache / "metadata.pkl").exists()

    if train_exists and val_exists:
        print("Found cached data! Skipping processing.")
        print(f"   Train cache: {train_cache}")
        print(f"   Val cache: {val_cache}")
        return train_cache, val_cache

    # Calculate split sizes
    total_samples = config.get("max_train_samples", 1_000_000)
    train_size = int(total_samples * config.get("train_split", 0.95))
    val_size = total_samples - train_size

    print(f"Dataset split:")
    print(f"   Training: {train_size:,} documents")
    print(f"   Validation: {val_size:,} documents")

    # Arguments shared by both splits
    shared_kwargs = dict(
        dataset_name=config["dataset_name"],
        tokenizer=tokenizer,
        max_length=config.get("max_length", 1024),
        stride=config.get("stride", 1024),
        chunk_size=config.get("chunk_size", 5000),
        use_clean=config.get("use_clean", True),
        text_column=config.get("text_column", "full_text"),
    )

    # Process training data: the first train_size documents of the stream
    if not train_exists:
        print("\n" + "="*80)
        print("PROCESSING TRAINING DATA")
        print("="*80)

        process_dataset_in_chunks(
            cache_dir=train_cache,
            max_samples=train_size,
            **shared_kwargs,
        )

    # Process validation data: the val_size documents AFTER the training ones
    if not val_exists:
        print("\n" + "="*80)
        print("PROCESSING VALIDATION DATA")
        print("="*80)

        val_kwargs = dict(shared_kwargs)
        # The validation documents must start after the training documents, which
        # requires the chunk processor to support skipping. Without that support
        # it re-reads the stream from the start and validation becomes a subset
        # of training, so warn loudly instead of silently leaking.
        if "skip_samples" in inspect.signature(process_dataset_in_chunks).parameters:
            val_kwargs["skip_samples"] = train_size
        else:
            print(
                "WARNING: process_dataset_in_chunks cannot skip documents; the "
                "validation set will overlap the first documents of the training set."
            )

        process_dataset_in_chunks(
            cache_dir=val_cache,
            max_samples=val_size,
            **val_kwargs,
        )

    print("\n" + "="*80)
    print("DATA PREPARATION COMPLETE!")
    print("="*80)

    return train_cache, val_cache

In [None]:
"""### create_dataloader()"""

In [None]:
def create_dataloader(
    cache_dir: Path,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 2,  # Lower for memory-mapped files
) -> DataLoader:
    """
    Build a DataLoader over the memory-mapped cache in ``cache_dir``.

    Args:
        cache_dir: Directory holding ``inputs.npy``, ``targets.npy`` and ``metadata.pkl``.
        batch_size: Samples per batch; the final incomplete batch is dropped.
        shuffle: Shuffle sample order every epoch.
        num_workers: Loader worker processes; ``0`` loads in the main process.

    Returns:
        A DataLoader yielding ``(input_ids, target_ids)`` batches.
    """

    dataset = MemoryMappedDataset(cache_dir)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,  # Faster GPU transfer
        drop_last=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0,
        # so only keep workers alive between epochs when there are workers.
        persistent_workers=num_workers > 0,
    )

In [None]:
"""## Training Utils"""

In [None]:
def calc_loss_batch(
    input_batch: torch.Tensor,
    target_batch: torch.Tensor,
    model: nn.Module,
    device: torch.device
) -> torch.Tensor:
    """Cross-entropy loss of ``model`` on one batch, computed under bfloat16 autocast.

    Args:
        input_batch: Token ids, shape (batch, seq_len).
        target_batch: Next-token ids, same shape as ``input_batch``.
        model: Module mapping token ids to logits of shape (batch, seq_len, vocab).
        device: Device the batch is moved to; also selects the autocast backend.

    Returns:
        Scalar loss tensor (mean over all positions).
    """
    device = torch.device(device)  # also accept plain strings such as "cuda"
    input_batch = input_batch.to(device, non_blocking=True)
    target_batch = target_batch.to(device, non_blocking=True)

    # Autocast must target the device the model actually runs on; a hard-coded
    # "cuda" context does nothing for CPU runs.
    with autocast(device.type, dtype=torch.bfloat16):
        logits = model(input_batch)
        loss = F.cross_entropy(
            logits.flatten(0, 1),      # (batch * seq_len, vocab)
            target_batch.flatten()     # (batch * seq_len,)
        )

    return loss

In [None]:
def evaluate_model(
    model: nn.Module,
    data_loader: DataLoader,
    device: torch.device,
    num_batches: int = 100
) -> float:
    """Average loss over up to ``num_batches`` batches of ``data_loader``.

    The model is put in eval mode for the duration of the call and returned to
    whatever mode it was in before, even if evaluation raises.

    Returns:
        Mean per-batch loss, or ``nan`` when no batch could be evaluated
        (e.g. an empty validation loader).
    """
    was_training = model.training
    model.eval()

    total_loss = 0.0
    batches_done = 0
    max_batches = min(num_batches, len(data_loader))

    try:
        with torch.no_grad():
            for i, (input_batch, target_batch) in enumerate(data_loader):
                if i >= max_batches:
                    break
                loss = calc_loss_batch(input_batch, target_batch, model, device)
                total_loss += loss.item()
                batches_done += 1
    finally:
        # Restore the caller's mode instead of unconditionally switching to train
        model.train(was_training)

    if batches_done == 0:
        # Avoid ZeroDivisionError; nan never compares as "better" than a real loss
        return float("nan")
    return total_loss / batches_done

In [None]:
def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Any,
    step: int,
    loss: float,
    save_dir: Path,
    config: Dict[str, Any]
):
    """Write a resumable training checkpoint to ``save_dir/checkpoint_step_{step}.pt``.

    The file holds the model, optimizer and scheduler state dicts together with
    ``step``, ``loss`` and ``config``.

    Args:
        save_dir: Target directory (``Path`` or str); created if it does not exist.

    Returns:
        Path of the written checkpoint file.
    """
    # Accept plain strings and make sure the directory exists before writing
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
        'config': config,
    }

    save_path = save_dir / f"checkpoint_step_{step}.pt"
    torch.save(checkpoint, save_path)
    print(f"Checkpoint saved: {save_path}")

    return save_path

In [None]:
def upload_to_huggingface(
    model: nn.Module,
    save_dir: Path,
    repo_id: str,
    config: Dict[str, Any],
    step: int
):
    """Export weights and config to ``save_dir`` and upload the folder to the Hub.

    This is best-effort: any exception is caught and reported with a printed
    message so that a failed upload does not abort the caller.

    Args:
        model: Model to export. If it exposes an ``_orig_mod`` attribute (as a
            ``torch.compile`` wrapper does), the wrapped module's state dict is
            saved so the keys do not carry the ``_orig_mod.`` prefix.
        save_dir: Staging directory; created (with parents) if missing. Its
            entire contents are uploaded.
        repo_id: Target model repository on the HuggingFace Hub.
        config: JSON-serialisable configuration written to ``config.json``.
        step: Training step, used in the commit message.
    """
    try:
        # The staging directory is not guaranteed to exist yet.
        save_dir = Path(save_dir)
        save_dir.mkdir(parents=True, exist_ok=True)

        # Save model weights (unwrapped, so keys match the plain model).
        export_model = getattr(model, "_orig_mod", model)
        torch.save(export_model.state_dict(), save_dir / "pytorch_model.bin")

        # Save config
        with open(save_dir / "config.json", "w", encoding="utf-8") as f:
            json.dump(config, f, indent=2)

        # Upload to HF
        api = HfApi()
        api.upload_folder(
            folder_path=str(save_dir),
            repo_id=repo_id,
            repo_type="model",
            commit_message=f"Training checkpoint at step {step}"
        )

        print(f"Uploaded to HuggingFace: {repo_id}")
    except Exception as e:
        # Deliberately non-fatal: training should continue if the Hub is unreachable.
        print(f"Failed to upload to HuggingFace: {e}")

In [None]:
"""## Training loop"""

In [None]:
def train(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: Any,
    device: torch.device,
    config: Dict[str, Any],
    save_dir: Path,
    hf_repo_id: Optional[str] = None
) -> Tuple[List[float], List[float]]:
    """Run the pretraining loop with gradient accumulation, eval and checkpointing.

    One optimizer step is taken every ``gradient_accumulation_steps``
    micro-batches (and on the last micro-batch of an epoch). After each
    optimizer step the loop logs to wandb, periodically evaluates on
    ``val_loader``, periodically saves a checkpoint, and stops once
    ``max_steps`` optimizer steps have been taken. ``KeyboardInterrupt`` ends
    training gracefully. A final checkpoint is always written.

    Args:
        model: Model to train; must expose ``count_parameters()``.
        train_loader: Yields ``(input_batch, target_batch)`` training pairs.
        val_loader: Loader passed to ``evaluate_model``.
        optimizer: Optimizer stepped once per accumulation window.
        scheduler: LR scheduler stepped together with the optimizer.
        device: Device forwarded to the loss/eval helpers.
        config: Training config. Keys read here: ``max_steps``, ``batch_size``,
            ``gradient_accumulation_steps``, ``grad_clip``, ``eval_freq``,
            ``eval_iter``, ``save_freq`` and, optionally, ``max_epochs``
            (default 100), ``upload_checkpoints`` and ``upload_frequency``.
        save_dir: Directory for checkpoints.
        hf_repo_id: HuggingFace repo to upload to, or ``None`` to disable.

    Returns:
        ``(train_losses, val_losses)``: the loss recorded at every optimizer
        step and the validation loss recorded at every evaluation.
    """

    print("=" * 80)
    print("STARTING MEDICAL LLM PRETRAINING")
    print("=" * 80)
    print(f"Model: {model.count_parameters():,} parameters")
    print(f"Training batches: {len(train_loader):,}")
    print(f"Max steps: {config['max_steps']:,}")
    print(f"Effective batch size: {config['batch_size'] * config['gradient_accumulation_steps']}")
    print(f"Device: {device}")
    print("=" * 80)

    model.train()
    global_step = 0
    tokens_seen = 0
    best_val_loss = float('inf')

    train_losses: List[float] = []
    val_losses: List[float] = []

    grad_accum = config["gradient_accumulation_steps"]
    max_steps = config["max_steps"]
    max_epochs = config.get("max_epochs", 100)  # Virtually unlimited epochs
    num_train_batches = len(train_loader)
    reached_max_steps = False

    try:
        for epoch in range(max_epochs):
            epoch_loss = 0.0
            epoch_steps = 0

            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

            for batch_idx, (input_batch, target_batch) in enumerate(progress_bar):
                # Forward pass
                loss = calc_loss_batch(input_batch, target_batch, model, device)

                # Scale for gradient accumulation; keep `loss` unscaled for logging.
                (loss / grad_accum).backward()

                # Count every micro-batch so a short final window is not over-counted.
                tokens_seen += input_batch.numel()

                window_full = (batch_idx + 1) % grad_accum == 0
                last_batch = (batch_idx + 1) == num_train_batches
                if not (window_full or last_batch):
                    continue

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=config["grad_clip"]
                )

                # Optimizer step
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                global_step += 1

                # Loss of the last micro-batch in the window; read once to avoid
                # repeated device synchronisation.
                step_loss = loss.item()
                current_lr = scheduler.get_last_lr()[0]

                train_losses.append(step_loss)
                epoch_loss += step_loss
                epoch_steps += 1

                # Update progress bar
                progress_bar.set_postfix({
                    'loss': f"{step_loss:.4f}",
                    'lr': f"{current_lr:.2e}",
                    'step': global_step
                })

                # Logged before the max-step check so the final step is recorded too.
                wandb.log({
                    "train_loss": step_loss,
                    "tokens_seen": tokens_seen
                })

                # Evaluation
                if global_step % config["eval_freq"] == 0:
                    val_loss = evaluate_model(
                        model, val_loader, device, config["eval_iter"]
                    )
                    val_losses.append(val_loss)

                    # Log to wandb
                    wandb.log({
                        "val_loss": val_loss,
                        "learning_rate": current_lr,
                        "step": global_step,
                    })

                    # Check for improvement
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        print(f"\nNew best validation loss: {val_loss:.4f}")

                # Save checkpoint
                if global_step % config["save_freq"] == 0:
                    save_checkpoint(
                        model, optimizer, scheduler,
                        global_step, step_loss,
                        save_dir, MODEL_CONFIG
                    )

                    # Upload to HuggingFace
                    # NOTE(review): the upload flag/frequency are read from the
                    # training `config`, not from a separate HF config - confirm
                    # the caller sets them there when passing `hf_repo_id`.
                    if hf_repo_id and config.get("upload_checkpoints", False):
                        if global_step % config.get("upload_frequency", 1000) == 0:
                            upload_to_huggingface(
                                model, save_dir / "hf_upload",
                                hf_repo_id, MODEL_CONFIG, global_step
                            )

                # Check if max steps reached
                if global_step >= max_steps:
                    print(f"\nReached max steps ({config['max_steps']})")
                    reached_max_steps = True
                    break

            if reached_max_steps:
                print("\nTraining stopped")
                break

            # End of epoch summary
            avg_epoch_loss = epoch_loss / epoch_steps if epoch_steps > 0 else float('inf')
            print(f"\nEpoch {epoch+1} complete - Avg loss: {avg_epoch_loss:.4f}")

    except KeyboardInterrupt:
        print("\nTraining stopped")

    # Final checkpoint
    print("\nSaving final checkpoint...")
    save_checkpoint(
        model, optimizer, scheduler,
        global_step, train_losses[-1] if train_losses else 0,
        save_dir, MODEL_CONFIG
    )

    print(f"\nTraining complete!")
    print(f" Total steps: {global_step:,}")
    print(f" Total tokens: {tokens_seen:,}")
    print(f" Best validation loss: {best_val_loss:.4f}")

    return train_losses, val_losses

In [None]:
"""## main.py"""

In [None]:
def main():
    """Entry point: prepare data, build model/optimizer/scheduler and train.

    Steps, in order: seed torch and numpy, pick the device, create the
    checkpoint directory, load the tokenizer, run data preparation, build the
    train/val dataloaders from the on-disk cache directories, build and
    (when available) compile the model, create AdamW + OneCycleLR, start a
    wandb run, optionally log in to HuggingFace, and call ``train``. The wandb
    run is always closed, even if setup or training raises.
    """
    # Set random seeds
    torch.manual_seed(TRAINING_CONFIG["seed"])
    np.random.seed(TRAINING_CONFIG["seed"])

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Create save directory
    save_dir = Path("./checkpoints")
    save_dir.mkdir(parents=True, exist_ok=True)

    # Initialize tokenizer
    print("Loading tokenizer...")
    tokenizer = tiktoken.get_encoding("p50k_base")

    # Load and prepare data. The return value is not kept: the dataloaders
    # below read from the cache directories, so holding it would only pin
    # memory for the whole run (presumably this call populates that cache -
    # confirm against prepare_medical_data).
    prepare_medical_data(DATA_CONFIG, tokenizer)

    # Create dataloaders
    print("Creating dataloaders...")
    train_loader = create_dataloader(
        Path("/content/data_cache/train/"),
        batch_size=TRAINING_CONFIG["batch_size"],
        # max_length=TRAINING_CONFIG["max_length"],
        # stride=TRAINING_CONFIG["stride"],
        shuffle=True,
        num_workers=TRAINING_CONFIG["num_workers"]
    )

    val_loader = create_dataloader(
        Path("/content/data_cache/val/"),
        batch_size=TRAINING_CONFIG["batch_size"],
        # max_length=TRAINING_CONFIG["max_length"],
        # stride=TRAINING_CONFIG["stride"],
        shuffle=False,
        num_workers=TRAINING_CONFIG["num_workers"]
    )

    # Initialize model
    print("Initializing model...")
    model = MedAssistGPT(MODEL_CONFIG)
    model = model.to(device)

    # Compile model (PyTorch 2.0+)
    if hasattr(torch, 'compile'):
        print("Compiling model...")
        model = torch.compile(model, mode="default", fullgraph=False, dynamic=True)

    print(f"Model has {model.count_parameters():,} parameters")

    # Initialize optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=TRAINING_CONFIG["learning_rate"],
        weight_decay=TRAINING_CONFIG["weight_decay"],
        betas=(TRAINING_CONFIG["beta1"], TRAINING_CONFIG["beta2"]),
        eps=TRAINING_CONFIG["eps"]
    )

    # Initialize scheduler: warmup takes `warmup_steps` of the `max_steps`
    # total, followed by cosine annealing.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=TRAINING_CONFIG["learning_rate"],
        total_steps=TRAINING_CONFIG["max_steps"],
        pct_start=TRAINING_CONFIG["warmup_steps"] / TRAINING_CONFIG["max_steps"],
        anneal_strategy='cos',
        div_factor=10,
        final_div_factor=100
    )

    # Initialize wandb
    wandb.init(
        project=WANDB_CONFIG["project"],
        entity=WANDB_CONFIG["entity"],
        name=WANDB_CONFIG["name"],
        config={**MODEL_CONFIG, **TRAINING_CONFIG, **DATA_CONFIG}
    )

    try:
        # Login to HuggingFace (if uploading)
        if HF_CONFIG.get("upload_checkpoints", False):
            try:
                login()
                create_repo(HF_CONFIG["repo_id"], repo_type="model", exist_ok=True)
                print(f"HuggingFace repo ready: {HF_CONFIG['repo_id']}")
            except Exception as e:
                # Uploads are optional: disable them and keep training.
                print(f"HuggingFace setup failed: {e}")
                HF_CONFIG["upload_checkpoints"] = False

        # Train!
        train(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            config=TRAINING_CONFIG,
            save_dir=save_dir,
            hf_repo_id=HF_CONFIG["repo_id"] if HF_CONFIG.get("upload_checkpoints") else None
        )
    finally:
        # Close the wandb run even when training fails.
        wandb.finish()

    print("\nAll done!")

In [None]:
# Start the full training pipeline when this file/cell is executed as the
# main module. Executing this cell calls main() directly.
if __name__ == "__main__":
    main()

In [None]:
"""## Inference"""

In [None]:
def generate_text(model, tokenizer, inference_config: Dict[str, Any], device: str = 'cuda'):
    """
    Generate text using KV cache for efficient autoregressive generation.

    First step: Process the full prompt, cache K/V.
    Subsequent steps: Only feed the new token, reuse cached K/V.

    Args:
        model: Callable as ``model(ids, past_key_values=..., use_cache=True)``
            returning ``(logits, past_key_values)``. It is switched to eval
            mode and moved to ``device``.
        tokenizer: Object providing ``encode``, ``decode`` and ``eot_token``.
        inference_config: Must contain ``prompt_text``; optional keys are
            ``max_new_tokens`` (default 50) and ``temperature`` (default 0.8;
            ``0.0`` selects greedy decoding).
        device: Device the model and token ids are placed on.

    Returns:
        The prompt followed by the decoded generated tokens. Generation stops
        early when the end-of-text token is produced (that token is included
        in the decoded output).

    Raises:
        ValueError: If ``temperature`` is negative or the prompt encodes to
            zero tokens.
    """
    text = inference_config["prompt_text"]
    max_new_tokens = inference_config.get("max_new_tokens", 50)
    temperature = inference_config.get("temperature", 0.8)

    # A negative temperature would silently invert the distribution.
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    # Encode the input text
    encoded_input = tokenizer.encode(text)
    if not encoded_input:
        # The model needs at least one position to predict from.
        raise ValueError("prompt_text must encode to at least one token")

    model.eval()
    model.to(device)

    input_ids = torch.tensor(encoded_input, dtype=torch.long, device=device).unsqueeze(0)

    generated_tokens: List[int] = []
    past_key_values = None

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # First step: process full prompt
            # Subsequent steps: only feed the last generated token
            if past_key_values is None:
                current_input = input_ids
            else:
                current_input = input_ids[:, -1:]  # Only the new token

            # Forward pass with caching
            logits, past_key_values = model(current_input, past_key_values=past_key_values, use_cache=True)

            # Get logits for the last position
            next_token_logits = logits[:, -1, :]  # (1, vocab_size)

            # Apply temperature sampling
            if temperature == 0.0:
                next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(1)
            else:
                probs = torch.softmax(next_token_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            # Read the id once; each .item() forces a device sync.
            token_id = next_token.item()
            generated_tokens.append(token_id)
            input_ids = torch.cat([input_ids, next_token], dim=1)

            # Stop if end-of-text token is generated
            if token_id == tokenizer.eot_token:
                break

    # Decode the generated tokens
    decoded_output = tokenizer.decode(generated_tokens)
    return text + decoded_output

In [None]:
def load_model_from_checkpoint(checkpoint_path: Path, model_config: Dict[str, Any], device: torch.device) -> nn.Module:
    """Build a MedAssistGPT and restore its weights from a checkpoint file.

    The checkpoint is expected to hold the weights under ``model_state_dict``.
    Keys that start with ``_orig_mod.`` have that prefix removed before
    loading; other keys are used unchanged.

    Args:
        checkpoint_path: Checkpoint file to read with ``torch.load``.
        model_config: Configuration passed to the ``MedAssistGPT`` constructor.
        device: Used as ``map_location`` and as the model's final device.

    Returns:
        The model on ``device``, in evaluation mode.
    """
    print(f"Loading model from {checkpoint_path}...")
    model = MedAssistGPT(model_config)
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Normalise key names so they match the freshly constructed model.
    wrapper_prefix = '_orig_mod.'
    cleaned_weights = {
        (name[len(wrapper_prefix):] if name.startswith(wrapper_prefix) else name): tensor
        for name, tensor in checkpoint['model_state_dict'].items()
    }

    model.load_state_dict(cleaned_weights)
    model.to(device)
    model.eval()  # inference mode: disables dropout etc.
    print("Model loaded successfully.")
    return model

In [None]:
# Example usage: load a saved checkpoint and generate text with it.
# The device and tokenizer are created here at module level because both the
# checkpoint loading and the generation cells below use them. The tokenizer
# encoding ("p50k_base") is the same one main() uses for training.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = tiktoken.get_encoding("p50k_base")

In [None]:
# Path to the checkpoint to load. Files follow the pattern
# "checkpoint_step_{step}.pt" inside the checkpoint directory; change the step
# number below to a checkpoint that actually exists on disk.
checkpoint_file_path = Path("checkpoints/checkpoint_step_15000.pt")

In [None]:
# Rebuild the model from MODEL_CONFIG and restore its weights from the
# checkpoint; the returned model is on `device` and in eval mode.
loaded_model = load_model_from_checkpoint(checkpoint_file_path, MODEL_CONFIG, device)

In [None]:
# Generate a continuation with the loaded model. The prompt and sampling
# settings are read from INFERENCE_CONFIG; the result is the prompt text
# followed by the generated text.
print("\n--- Generating text with loaded model ---")
generated_output = generate_text(loaded_model, tokenizer, INFERENCE_CONFIG, device)

In [None]:
# Show the prompt together with the generated continuation.
print(generated_output)