MedAssist-GPT: Complete Medical LLM Pretraining Script
=====================================================
Modern architecture with RoPE, GQA, SwiGLU, RMSNorm
Optimized for A100 GPU with Flash Attention
Automatic checkpointing and HuggingFace uploads

PRETRAINED WEIGHTS AT: https://huggingface.co/kunjcr2/MedAssistGPT

## Test

## Imports

In [1]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from datasets import load_dataset
from bs4 import BeautifulSoup, NavigableString
import tiktoken

import os
import json
import shutil
from pathlib import Path
from typing import Dict, List, Tuple, Any
import math

import pickle
import gc
from pathlib import Path
import re
import html
from bs4 import BeautifulSoup, NavigableString

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast
from torch.optim.lr_scheduler import OneCycleLR

import numpy as np
import pandas as pd
import multiprocessing as mp

from concurrent.futures import ProcessPoolExecutor, as_completed
from functools import partial

import tiktoken
import wandb
from datasets import load_dataset, concatenate_datasets
from huggingface_hub import login, create_repo, upload_folder, HfApi
from tqdm import tqdm

## CONFIGURATION

In [43]:
# Architecture hyperparameters read by MedAssistGPT / TransformerBlock.
MODEL_CONFIG = {
    "vocab_size": 50281,    # presumably tiktoken p50k_base's n_vocab (the tokenizer used for preprocessing) - verify
    "d_model": 512,
    "n_heads": 16,
    "gqa_groups": 4,        # query heads per shared K/V head: n_kv_heads = n_heads // gqa_groups = 4
    "max_len": 1024,
    "d_ff": 2048,           # 4x hidden dimension
    "eps": 1e-5,
    "dropout_p": 0.1,       # Residual dropout is ENABLED at p=0.1; set to 0.0 to train without dropout
    "blocks": 16,           # ~87M parameters with these sizes (embedding and LM head share weights)
}

# Optimizer / schedule / loop settings.
TRAINING_CONFIG = {
    "batch_size": 64,
    "max_length": 1024,
    "stride": 1024,
    "gradient_accumulation_steps": 2,  # Effective batch size: 128
    "learning_rate": 3e-4,
    "weight_decay": 0.1,
    "beta1": 0.9,
    "beta2": 0.95,
    "eps": 1e-8,
    "warmup_steps": 500,
    "max_steps": 50000,
    "eval_freq": 500,
    "eval_iter": 100,
    "save_freq": 1000,
    "grad_clip": 1.0,
    "num_workers": 4,
    "seed": 1496,
}

# Dataset source and preprocessing settings read by prepare_medical_data().
DATA_CONFIG = {
    "dataset_name": "Hack90/europe_pmc_articles_part_2",  # HuggingFace dataset id
    "text_column": "full_text",  # Column with text
    "max_length": 1024,  # Sequence length
    "stride": 1024,  # Window stride (equal to max_length => non-overlapping windows)
    "train_split": 0.95,  # 95% train, 5% val
    "max_train_samples": 1_000_000,  # Total documents to use (train + val together)
    "chunk_size": 10000,  # Documents per worker chunk (10K)
    "use_clean": True,  # Run clean() on every document before tokenizing
}

WANDB_CONFIG = {
    "project": "MedAssist-GPT-Pretraining",
    "entity": "kunjcr2-dreamable",  # wandb entity (user or team)
    "name": "medassist-86M",
}

HF_CONFIG = {
    # NOTE(review): the repo id says 125M while the wandb run name says 86M - confirm which is intended.
    "repo_id": "kunjcr2/MedAssist-GPT-125M",  # Change this!
    "upload_checkpoints": True,
    "upload_frequency": 5000,  # Upload every N steps
}

# Sampling settings for generation.
INFERENCE_CONFIG = {
    "max_new_tokens": 100,
    "temperature": 0.5,
    "prompt_text": "To live a good life"
}

## Architecture

In [15]:
class RoPE(nn.Module):
    """Rotary position embedding over the last dimension of a 3-D tensor.

    The last dimension is treated as ``d_model // 2`` consecutive
    (even, odd) pairs.  The pair with index ``i`` at sequence position ``p``
    is rotated by the angle ``p * 10000 ** (-2 * i / d_model)``.

    Args:
        d_model: Width of the last dimension; must be even.
        max_len: Number of positions for which indices are precomputed.
    """

    def __init__(self, d_model: int, max_len: int = 5000):
        super().__init__()
        assert d_model % 2 == 0, "d_model must be even for RoPE"

        self.d_model = d_model
        self.max_len = max_len

        # Column vector of positions 0 .. max_len - 1, shape (max_len, 1).
        positions = torch.arange(max_len).unsqueeze(1)
        self.register_buffer('position_ids', positions)

        # One frequency per pair: exp(2i * -ln(10000) / d_model), shape (d_model / 2,).
        pair_frequencies = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        self.register_buffer('div_term', pair_frequencies)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Rotate ``x`` of shape (batch, seq_len, d_model); the shape is preserved."""
        n_batch, n_pos, width = x.shape

        # (seq_len, 1) * (d_model / 2,) broadcasts to (seq_len, d_model / 2).
        theta = self.position_ids[:n_pos] * self.div_term
        cos_theta = torch.cos(theta)
        sin_theta = torch.sin(theta)

        # Split the feature axis into its even- and odd-indexed halves.
        first = x[..., 0::2]   # (b, s, d/2)
        second = x[..., 1::2]  # (b, s, d/2)

        # Plain 2-D rotation of every (first, second) pair.
        new_first = first * cos_theta - second * sin_theta
        new_second = first * sin_theta + second * cos_theta

        # Interleave the halves back into the original (even, odd) layout.
        interleaved = torch.stack((new_first, new_second), dim=-1)  # (b, s, d/2, 2)
        return interleaved.reshape(n_batch, n_pos, width)


class GroupedQueryAttention(nn.Module):
    """Causal self-attention in which several query heads share one K/V head.

    ``gqa_groups`` is the number of query heads served by each K/V head, so
    the layer has ``n_heads // gqa_groups`` K/V heads.  Rotary embeddings are
    applied to the flat Q and K projections before they are split into heads.

    NOTE(review): ``rope_q`` and ``rope_k`` are built with different widths
    (``n_heads * head_dim`` vs ``n_kv_heads * head_dim``), so when
    ``n_kv_heads != n_heads`` a query pair and the key pair it is compared
    with are rotated with different frequencies.  Standard RoPE uses a single
    per-head frequency table for both; confirm this is intended.  Changing it
    would alter the ``div_term`` buffer shapes stored in existing checkpoints.

    Args:
        d_model: Model width; must be divisible by ``n_heads``.
        n_heads: Number of query heads; must be divisible by ``gqa_groups``.
        gqa_groups: Query heads per K/V head.
        max_len: Longest sequence the rotary embeddings are sized for.
    """

    def __init__(
        self,
        d_model: int = 512,
        n_heads: int = 8,
        gqa_groups: int = 2,
        max_len: int = 1024,
    ):
        super().__init__()
        assert d_model % n_heads == 0, "d_model must be divisible by n_heads"
        assert n_heads % gqa_groups == 0, "n_heads must be divisible by gqa_groups"

        self.d_model = d_model
        self.n_heads = n_heads
        self.gqa_groups = gqa_groups
        self.head_dim = d_model // n_heads
        self.n_kv_heads = n_heads // gqa_groups
        self.max_len = max_len

        q_width = n_heads * self.head_dim
        kv_width = self.n_kv_heads * self.head_dim

        # All four projections are bias-free.
        self.q_proj = nn.Linear(d_model, q_width, bias=False)
        self.k_proj = nn.Linear(d_model, kv_width, bias=False)
        self.v_proj = nn.Linear(d_model, kv_width, bias=False)
        self.o_proj = nn.Linear(q_width, d_model, bias=False)

        # Separate rotary modules because Q and K have different flat widths.
        self.rope_q = RoPE(d_model=q_width, max_len=max_len)
        self.rope_k = RoPE(d_model=kv_width, max_len=max_len)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Attend over ``x`` of shape (B, T, d_model); returns the same shape."""
        bsz, seq_len, _ = x.shape

        def to_heads(flat: torch.Tensor, heads: int) -> torch.Tensor:
            # (B, T, heads * D) -> (B, heads, T, D)
            return flat.view(bsz, seq_len, heads, self.head_dim).transpose(1, 2)

        # Project, rotate Q/K while still flat, then split into heads.
        queries = to_heads(self.rope_q(self.q_proj(x)), self.n_heads)
        keys = to_heads(self.rope_k(self.k_proj(x)), self.n_kv_heads)
        values = to_heads(self.v_proj(x), self.n_kv_heads)

        # Repeat every K/V head so each query head has a partner: (B, H, T, D).
        copies = self.n_heads // self.n_kv_heads
        keys = keys.repeat_interleave(copies, dim=1)
        values = values.repeat_interleave(copies, dim=1)

        # Fused causal attention; no attention dropout is applied here.
        attended = F.scaled_dot_product_attention(
            queries, keys, values,
            attn_mask=None,
            dropout_p=0.0,
            is_causal=True
        )

        # (B, H, T, D) -> (B, T, H * D), then the output projection.
        merged = attended.transpose(1, 2).contiguous().view(
            bsz, seq_len, self.n_heads * self.head_dim
        )
        return self.o_proj(merged)


class SwiGLU_MLP(nn.Module):
    """Gated feed-forward layer: ``w2(up * silu(gate))``.

    ``w1`` produces the ``up`` and ``gate`` halves in a single matmul (the
    first ``d_ff`` output features are ``up``, the last ``d_ff`` are
    ``gate``); ``w2`` projects back to ``d_model``.  Neither layer has a bias.
    """

    def __init__(self, d_model: int = 512, d_ff: int = 2048):
        super().__init__()
        # Fused up/gate projection followed by the down projection.
        self.w1 = nn.Linear(d_model, 2 * d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (..., d_model) to (..., d_model)."""
        fused = self.w1(x)                                   # (..., 2 * d_ff)
        up_half, gate_half = torch.chunk(fused, 2, dim=-1)   # each (..., d_ff)
        return self.w2(up_half * F.silu(gate_half))


class TransformerBlock(nn.Module):
    """One pre-norm decoder layer: attention and MLP, each in a residual branch.

    Reads ``d_model``, ``eps``, ``n_heads``, ``gqa_groups``, ``max_len``,
    ``d_ff`` and ``dropout_p`` from ``config``.  The same dropout module is
    applied to the output of both sub-layers before the residual add.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        width = config["d_model"]
        norm_eps = config["eps"]

        self.rms1 = nn.RMSNorm(width, eps=norm_eps)
        self.rms2 = nn.RMSNorm(width, eps=norm_eps)

        self.attn = GroupedQueryAttention(
            d_model=width,
            n_heads=config["n_heads"],
            gqa_groups=config["gqa_groups"],
            max_len=config["max_len"],
        )
        self.mlp = SwiGLU_MLP(d_model=width, d_ff=config["d_ff"])

        self.dropout = nn.Dropout(config["dropout_p"])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x``; the output has the same shape as the input."""
        # Attention branch: normalise, attend, drop, add back.
        attn_branch = self.attn(self.rms1(x))
        x = x + self.dropout(attn_branch)

        # Feed-forward branch, same pattern.
        mlp_branch = self.mlp(self.rms2(x))
        return x + self.dropout(mlp_branch)


class MedAssistGPT(nn.Module):
    """Main model class"""
    def __init__(self, config: Dict[str, Any]):
        super().__init__()
        self.config = config

        # Token embeddings
        self.embed = nn.Embedding(config["vocab_size"], config["d_model"])

        # Transformer blocks
        self.blocks = nn.ModuleList([
            TransformerBlock(config) for _ in range(config["blocks"])
        ])

        # Final RMSNorm
        self.final_rms = nn.RMSNorm(config["d_model"], eps=config["eps"])

        # Language model head (weight-tied with embeddings)
        self.lm_head = nn.Linear(config["d_model"], config["vocab_size"], bias=False)
        self.lm_head.weight = self.embed.weight

        # Initialize weights
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # input_ids: (batch, seq_len)
        h = self.embed(input_ids)  # (batch, seq_len, d_model)

        # Pass through transformer blocks
        for block in self.blocks:
            h = block(h)

        # Final normalization
        h = self.final_rms(h)

        # Language model head
        logits = self.lm_head(h)  # (batch, seq_len, vocab_size)

        return logits

    def count_parameters(self) -> int:
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

## Data Loading

### clean()

In [16]:
def clean(xml_str: str) -> str:
    """Convert a JATS/PMC XML article into plain narrative text.

    Tables, figures, captions, reference lists, back matter and ``<xref>``
    cross-references are dropped; section titles and paragraphs are kept and
    separated by blank lines.  Afterwards numeric citation groups such as
    ``(1, 2-5)``, figure/table pointers such as ``(Fig. 2)``, empty brackets
    and redundant whitespace/punctuation are stripped.

    Args:
        xml_str: The article XML as a string.

    Returns:
        The cleaned text with leading/trailing whitespace removed.
    """
    soup = BeautifulSoup(xml_str, "lxml-xml")

    # 1) Remove whole non-narrative blocks
    drop_whole = [
        "ref-list", "fig", "fig-group", "table-wrap", "table", "thead", "tbody",
        "tr", "td", "th", "graphic", "media", "supplementary-material", "back",
        "sec-meta", "table-wrap-foot", "caption"
    ]
    for name in drop_whole:
        for tag in soup.find_all(name):
            tag.decompose()

    # 2) Remove cross-references entirely (citations, table/fig pointers)
    for tag in soup.find_all("xref"):
        tag.decompose()

    # 3) Preserve disp-quote as plain paragraphs
    for dq in soup.find_all("disp-quote"):
        txt = dq.get_text(" ", strip=True)
        dq.replace_with(NavigableString(("\n" + txt + "\n") if txt else ""))

    # 4) Turn <title> into clean section headers
    for t in soup.find_all("title"):
        title_txt = t.get_text(" ", strip=True)
        t.replace_with(NavigableString("\n\n" + title_txt + "\n"))

    # 5) Ensure paragraphs end cleanly; unwrap inline tags
    inline_unwrap = [
        "italic", "bold", "underline", "sc", "em", "strong", "sup", "sub",
        "styled-content", "inline-formula", "monospace"
    ]
    for p in soup.find_all("p"):
        for name in inline_unwrap:
            for it in p.find_all(name):
                it.unwrap()
        p.insert_after(NavigableString("\n\n"))
        p.unwrap()

    # 6) Unwrap remaining structural containers
    for name in ["sec", "body", "front", "article", "abstract", "boxed-text", "list", "list-item"]:
        for tag in soup.find_all(name):
            tag.unwrap()

    # 7) Extract text and clean up
    text = soup.get_text()
    text = html.unescape(text)

    # Post-processing cleanup.
    # Numeric citation groups, e.g. "(3)", "(1, 2)", "(4-7)".  The range
    # separator is a hyphen or an en dash; the en dash is written as \u2013 so
    # the pattern does not depend on how this source file is encoded.
    # NOTE(review): the nested quantifiers can backtrack heavily on a long
    # digit run after "(" with no closing ")"; tighten if that shows up.
    text = re.sub(r"\(\s*(?:\d+\s*(?:[-\u2013]\s*\d+)?\s*(?:[,;]\s*)?)+\)", "", text)
    # Parenthesised figure/table pointers, e.g. "(Fig. 2)", "(Table 1)".
    text = re.sub(r"\(\s*(?:[Ff]ig(?:ure)?\.?\s*\d+|[Tt]able\s*\d+)\s*\)", "", text)
    # Brackets left empty by the removals above.
    text = re.sub(r"\(\s*\)", "", text)
    text = re.sub(r"\[\s*\]", "", text)
    # No whitespace before punctuation; collapse repeated punctuation marks.
    text = re.sub(r"\s+([,.;:!?])", r"\1", text)
    text = re.sub(r"([,.;:!?])\s*\1+", r"\1 ", text)
    # Collapse runs of spaces/tabs and of blank lines.
    text = re.sub(r"[ \t]{2,}", " ", text)
    text = re.sub(r"\n{3,}", "\n\n", text)

    text = "\n".join(line.rstrip() for line in text.splitlines())
    return text.strip()

### MemoryMappedDataset class

In [17]:
class MemoryMappedDataset(Dataset):
    """Dataset backed by memory-mapped ``inputs.npy`` / ``targets.npy`` files.

    The arrays are opened with ``np.load(..., mmap_mode='r')``, so they are
    mapped read-only rather than read into memory up front; each item access
    copies a single row out of the mapping.

    ``cache_dir`` must contain ``metadata.pkl``, ``inputs.npy`` and
    ``targets.npy``.

    NOTE: ``metadata.pkl`` is unpickled on load - only point this at cache
    directories you created yourself.
    """

    def __init__(self, cache_dir: Path):
        self.cache_dir = Path(cache_dir)

        # Small pickled dict written next to the arrays.
        metadata_path = self.cache_dir / "metadata.pkl"
        with open(metadata_path, "rb") as handle:
            self.metadata = pickle.load(handle)

        # Map both arrays read-only; rows are fetched lazily on indexing.
        inputs_path = self.cache_dir / "inputs.npy"
        targets_path = self.cache_dir / "targets.npy"
        self.inputs_mmap = np.load(inputs_path, mmap_mode='r')
        self.targets_mmap = np.load(targets_path, mmap_mode='r')

        print(f"üìÇ Memory-mapped dataset: {len(self.inputs_mmap):,} samples")
        print(f"üíæ RAM overhead: ~0 MB (OS manages it)")

    def __len__(self) -> int:
        return len(self.inputs_mmap)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the ``(input_ids, target_ids)`` pair at ``idx`` as int64 tensors."""
        # Copy out of the read-only mapping so the tensors own writable memory.
        input_row = self.inputs_mmap[idx].copy()
        target_row = self.targets_mmap[idx].copy()
        return torch.from_numpy(input_row).long(), torch.from_numpy(target_row).long()

### process_single_chunk()

In [18]:
def process_single_chunk(
    doc_batch: List[Dict],
    tokenizer_name: str,  # Can't pickle tokenizer, so pass name
    max_length: int,
    stride: int,
    use_clean: bool,
    text_column: str,
    chunk_id: int,
) -> Tuple[List[np.ndarray], List[np.ndarray], int, int]:
    """Clean, tokenize and window one batch of documents (worker entry point).

    Intended to run inside a ``ProcessPoolExecutor`` worker, so it builds its
    own tokenizer from ``tokenizer_name`` instead of receiving a tokenizer
    object.  The tokens of all accepted documents are concatenated in order
    and then cut into windows by ``create_sliding_windows``.

    A document is skipped (never fatal) when it has no text under
    ``text_column`` or ``'text'``, when the raw or cleaned text is shorter
    than 100 characters, when cleaning or tokenizing raises, or when it
    yields fewer than 10 tokens.

    Args:
        doc_batch: Documents (dict-like rows) to process.
        tokenizer_name: tiktoken encoding name.
        max_length: Window length in tokens.
        stride: Step between window starts.
        use_clean: Whether to run ``clean()`` on each document first.
        text_column: Preferred key holding the document text.
        chunk_id: Identifier of this chunk; currently unused.

    Returns:
        ``(inputs, targets, docs_processed, total_tokens)`` where ``inputs``
        and ``targets`` are parallel lists of window arrays.
    """
    import tiktoken

    # Each worker creates its own tokenizer
    tokenizer = tiktoken.get_encoding(tokenizer_name)
    vocab_size = tokenizer.n_vocab

    chunk_tokens: List[int] = []
    docs_processed = 0
    total_tokens = 0

    for doc in doc_batch:
        # Extract text, falling back to a generic 'text' column.
        if text_column in doc:
            text = doc[text_column]
        elif 'text' in doc:
            text = doc['text']
        else:
            continue

        if not text or len(text) < 100:
            continue

        if use_clean:
            # Catch Exception rather than using a bare except so that
            # KeyboardInterrupt / SystemExit still stop the worker.
            try:
                text = clean(text)
            except Exception:
                continue

        if not text or len(text) < 100:
            continue

        try:
            tokens = tokenizer.encode(text, allowed_special={"<|endoftext|>"})
        except Exception:
            continue

        # Defensive filter: keep only ids that fit the vocabulary.
        tokens = [t for t in tokens if 0 <= t < vocab_size]
        if len(tokens) < 10:
            continue

        chunk_tokens.extend(tokens)
        total_tokens += len(tokens)
        docs_processed += 1

    # Create sliding windows from all tokens in this chunk
    samples = create_sliding_windows(chunk_tokens, max_length, stride, vocab_size)

    if not samples:
        return [], [], docs_processed, total_tokens

    inputs, targets = zip(*samples)
    return list(inputs), list(targets), docs_processed, total_tokens

### process_dataset_in_chunks()

In [19]:
def process_dataset_in_chunks(
    dataset_name: str,
    tokenizer,
    cache_dir: Path,
    max_length: int = 1024,
    stride: int = 1024,
    max_samples: int = None,
    chunk_size: int = 5000,
    use_clean: bool = True,
    text_column: str = "full_text"
):
    """
    PARALLEL VERSION - Drop-in replacement for your original function

    HOW PARALLELIZATION WORKS:
    1. Load documents in batches (streaming)
    2. Distribute batches to worker processes
    3. Each worker: clean ‚Üí tokenize ‚Üí create windows
    4. Main process: collect results and save

    NUM WORKERS:
    - Uses all CPU cores by default
    - On 8-core: 8 chunks processed simultaneously
    - On Colab: ~2 cores (still 2x speedup!)
    """

    cache_dir = Path(cache_dir)
    cache_dir.mkdir(parents=True, exist_ok=True)

    # Determine number of workers
    num_workers = min(mp.cpu_count(), 8)  # Max 8 workers (diminishing returns)

    print(f"üî• Processing dataset: {dataset_name}")
    print(f"üìä Chunk size: {chunk_size} documents")
    print(f"üßπ Cleaning enabled: {use_clean}")
    print(f"‚ö° Parallel workers: {num_workers}")

    # Load dataset in streaming mode
    dataset = load_dataset(dataset_name, split="train", streaming=True)

    if max_samples:
        dataset = dataset.take(max_samples)

    # Accumulate results from all workers
    all_inputs = []
    all_targets = []
    total_docs_processed = 0
    total_tokens = 0

    print("üîÑ Loading documents into batches...")

    # Collect documents into batches
    # WHY: Can't parallelize streaming iterator directly
    # So we batch first, then parallelize batch processing
    doc_batches = []
    current_batch = []

    for doc in tqdm(dataset, desc="Batching docs"):
        current_batch.append(doc)

        if len(current_batch) >= chunk_size:
            doc_batches.append(current_batch)
            current_batch = []

    # Don't forget the last batch
    if current_batch:
        doc_batches.append(current_batch)

    print(f"üì¶ Created {len(doc_batches)} batches of ~{chunk_size} documents each")
    print(f"‚ö° Processing with {num_workers} parallel workers...")

    # PARALLEL PROCESSING STARTS HERE!
    # ProcessPoolExecutor creates worker processes
    # Each worker gets its own batch to process
    with ProcessPoolExecutor(max_workers=num_workers) as executor:

        # Submit all batches to workers
        # partial() pre-fills the arguments that are same for all workers
        worker_fn = partial(
            process_single_chunk,
            tokenizer_name="p50k_base",  # Pass tokenizer name, not object
            max_length=max_length,
            stride=stride,
            use_clean=use_clean,
            text_column=text_column,
        )

        # Submit all jobs
        futures = []
        for chunk_id, doc_batch in enumerate(doc_batches):
            future = executor.submit(worker_fn, doc_batch, chunk_id=chunk_id)
            futures.append(future)

        # Collect results as they complete
        # as_completed() returns futures as soon as they finish (not in order)
        # This gives us progress updates in real-time!
        try:
          for future in tqdm(as_completed(futures), total=len(futures), desc="Processing chunks"):
            try:
                # Get results from this worker
                inputs, targets, docs_proc, tokens_proc = future.result()

                # Accumulate
                all_inputs.extend(inputs)
                all_targets.extend(targets)
                total_docs_processed += docs_proc
                total_tokens += tokens_proc

                # Show progress
                if len(inputs) > 0:
                    print(f"‚úÖ Chunk done: {len(inputs)} samples, {docs_proc} docs, {tokens_proc:,} tokens")

            except Exception as e:
                print(f"‚ö†Ô∏è  Worker failed: {e}")
                continue

        except KeyboardInterrupt as e:
            print("‚ö†Ô∏è  Interrupted by user")

    print(f"\nüìä Parallel processing complete!")
    print(f"   Documents processed: {total_docs_processed:,}")
    print(f"   Total tokens: {total_tokens:,}")
    print(f"   Training samples: {len(all_inputs):,}")

    # Save to disk (same as before)
    print("üíæ Saving to disk as memory-mapped arrays...")

    inputs_array = np.array(all_inputs, dtype=np.int32)
    targets_array = np.array(all_targets, dtype=np.int32)

    np.save(cache_dir / "inputs.npy", inputs_array)
    np.save(cache_dir / "targets.npy", targets_array)

    # Save metadata
    metadata = {
        "num_samples": len(all_inputs),
        "max_length": max_length,
        "stride": stride,
        "vocab_size": tokenizer.n_vocab,
        "docs_processed": total_docs_processed,
        "total_tokens": total_tokens,
    }

    with open(cache_dir / "metadata.pkl", "wb") as f:
        pickle.dump(metadata, f)

    print(f"‚úÖ Saved to {cache_dir}")

    # Clear RAM
    del all_inputs, all_targets, inputs_array, targets_array
    gc.collect()

    return cache_dir

### create_sliding_windows()

In [20]:
def create_sliding_windows(
    tokens: List[int],
    max_length: int,
    stride: int,
    vocab_size: int
) -> List[Tuple[np.ndarray, np.ndarray]]:
    """
    Create sliding window samples from token list

    EXAMPLE:
    tokens = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
    max_length = 4, stride = 2

    Samples:
    input:  [1, 2, 3, 4]  target: [2, 3, 4, 5]
    input:  [3, 4, 5, 6]  target: [4, 5, 6, 7]
    input:  [5, 6, 7, 8]  target: [6, 7, 8, 9]
    etc.
    """
    samples = []
    tokens = np.array(tokens, dtype=np.int32)

    # Verify tokens are in valid range
    if len(tokens) == 0:
        return samples

    # Clip to valid range (safety)
    tokens = np.clip(tokens, 0, vocab_size - 1)

    # Create sliding windows
    for i in range(0, len(tokens) - max_length, stride):
        input_ids = tokens[i:i+max_length]
        target_ids = tokens[i+1:i+max_length+1]

        # Ensure both are exactly max_length
        if len(input_ids) == max_length and len(target_ids) == max_length:
            samples.append((input_ids, target_ids))

    return samples

### prepare_mdeical_data()

In [21]:
def prepare_medical_data(
    config: Dict[str, Any],
    tokenizer
) -> Tuple[Path, Path]:
    """
    Main function: Prepares training and validation data efficiently

    WHAT THIS DOES:
    1. Checks if cache exists (skip if already processed)
    2. Splits dataset into train/val
    3. Processes each split in chunks
    4. Saves as memory-mapped files
    5. Returns paths to cached data

    FIRST RUN: ~30 minutes (processes and caches)
    SUBSEQUENT RUNS: ~1 second (loads from cache)
    """

    train_cache = Path("./data_cache/train")
    val_cache = Path("./data_cache/val")

    # Check if already cached
    train_exists = (train_cache / "metadata.pkl").exists()
    val_exists = (val_cache / "metadata.pkl").exists()

    if train_exists and val_exists:
        print("‚úÖ Found cached data! Skipping processing.")
        print(f"   Train cache: {train_cache}")
        print(f"   Val cache: {val_cache}")
        return train_cache, val_cache

    # Calculate split sizes
    total_samples = config.get("max_train_samples", 1_000_000)
    train_size = int(total_samples * config.get("train_split", 0.95))
    val_size = total_samples - train_size

    print(f"üìä Dataset split:")
    print(f"   Training: {train_size:,} documents")
    print(f"   Validation: {val_size:,} documents")

    # Process training data
    if not train_exists:
        print("\n" + "="*80)
        print("PROCESSING TRAINING DATA")
        print("="*80)

        # Create temporary streaming dataset for train split
        dataset = load_dataset(
            config["dataset_name"],
            split="train",
            streaming=True
        )
        train_dataset = dataset.take(train_size)

        # Process and cache
        # We create a temporary generator to process
        def train_generator():
            for item in train_dataset:
                yield item

        # Process using your dataset directly
        process_dataset_in_chunks(
            dataset_name=config["dataset_name"],
            tokenizer=tokenizer,
            cache_dir=train_cache,
            max_length=config.get("max_length", 1024),
            stride=config.get("stride", 1024),
            max_samples=train_size,
            chunk_size=config.get("chunk_size", 5000),
            use_clean=config.get("use_clean", True),
            text_column=config.get("text_column", "full_text")
        )

    # Process validation data
    if not val_exists:
        print("\n" + "="*80)
        print("PROCESSING VALIDATION DATA")
        print("="*80)

        # Create dataset that skips training samples
        dataset = load_dataset(
            config["dataset_name"],
            split="train",
            streaming=True
        )
        val_dataset = dataset.skip(train_size).take(val_size)

        # Process and cache
        # For validation, we need to handle the skip
        # Easier to just process with offset
        temp_config = config.copy()
        temp_config["dataset_offset"] = train_size

        process_dataset_in_chunks(
            dataset_name=config["dataset_name"],
            tokenizer=tokenizer,
            cache_dir=val_cache,
            max_length=config.get("max_length", 1024),
            stride=config.get("stride", 1024),
            max_samples=val_size,
            chunk_size=config.get("chunk_size", 5000),
            use_clean=config.get("use_clean", True),
            text_column=config.get("text_column", "full_text")
        )

    print("\n" + "="*80)
    print("‚úÖ DATA PREPARATION COMPLETE!")
    print("="*80)

    return train_cache, val_cache

### create_dataloader()

In [22]:
def create_dataloader(
    cache_dir: Path,
    batch_size: int,
    shuffle: bool = True,
    num_workers: int = 2,  # Lower for memory-mapped files
) -> DataLoader:
    """Build a DataLoader over the memory-mapped cache in ``cache_dir``.

    Incomplete final batches are dropped and batches are returned in pinned
    memory.  Worker processes are kept alive between epochs whenever worker
    processes are used at all.

    Args:
        cache_dir: Directory holding ``inputs.npy``, ``targets.npy`` and
            ``metadata.pkl``.
        batch_size: Samples per batch.
        shuffle: Whether to reshuffle every epoch.
        num_workers: Loader worker processes; ``0`` loads in the main process.

    Returns:
        The configured ``DataLoader``.
    """

    dataset = MemoryMappedDataset(cache_dir)

    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,  # Faster GPU transfer
        drop_last=True,
        # DataLoader rejects persistent_workers=True when num_workers == 0,
        # so only request it when there are workers to keep alive.
        persistent_workers=num_workers > 0,
    )

## Training Utils

In [23]:
def calc_loss_batch(
    input_batch: torch.Tensor,
    target_batch: torch.Tensor,
    model: nn.Module,
    device: torch.device
) -> torch.Tensor:
    """Calculate loss for a single batch"""
    input_batch = input_batch.to(device, non_blocking=True)
    target_batch = target_batch.to(device, non_blocking=True)

    with autocast("cuda", torch.bfloat16):
        logits = model(input_batch)
        loss = F.cross_entropy(
            logits.flatten(0, 1),
            target_batch.flatten()
        )

    return loss


def evaluate_model(
    model: nn.Module,
    data_loader: DataLoader,
    device: torch.device,
    num_batches: int = 100
) -> float:
    """Return the mean loss of ``model`` over up to ``num_batches`` batches.

    The model is switched to eval mode for the measurement and put back into
    whatever mode it was in before the call, even if evaluation raises.
    Gradients are disabled throughout.

    Args:
        model: Model to evaluate.
        data_loader: Iterable of ``(input_batch, target_batch)`` pairs; it
            does not need to support ``len()``.
        device: Device the batches are moved to.
        num_batches: Maximum number of batches to evaluate.

    Returns:
        The average per-batch loss, or ``nan`` when no batch was evaluated
        (empty loader or ``num_batches <= 0``).
    """
    was_training = model.training
    model.eval()

    total_loss = 0.0
    batches_seen = 0

    try:
        if num_batches > 0:
            with torch.no_grad():
                for input_batch, target_batch in data_loader:
                    loss = calc_loss_batch(input_batch, target_batch, model, device)
                    total_loss += loss.item()
                    batches_seen += 1
                    # Stop before the loader fetches a batch we would discard.
                    if batches_seen >= num_batches:
                        break
    finally:
        model.train(was_training)

    # Average over the batches actually evaluated; avoids dividing by zero.
    if batches_seen == 0:
        return float("nan")
    return total_loss / batches_seen


def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    scheduler: Any,
    step: int,
    loss: float,
    save_dir: Path,
    config: Dict[str, Any]
):
    """Write a resumable training checkpoint to ``save_dir``.

    The checkpoint holds the step, the model / optimizer / scheduler state
    dicts, the loss and the config, and is written to
    ``checkpoint_step_{step}.pt``.  ``save_dir`` is created if it does not
    exist and may be given as a ``str`` or a ``Path``.

    Returns:
        Path of the file that was written.
    """
    # Accept str paths and make sure the target directory exists, so a first
    # checkpoint does not fail on a missing directory.
    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)

    checkpoint = {
        'step': step,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': loss,
        'config': config,
    }

    save_path = save_dir / f"checkpoint_step_{step}.pt"
    torch.save(checkpoint, save_path)
    print(f"üíæ Checkpoint saved: {save_path}")

    return save_path


def upload_to_huggingface(
    model: nn.Module,
    save_dir: Path,
    repo_id: str,
    config: Dict[str, Any],
    step: int
):
    """Best-effort upload of the current weights and config to the HF Hub.

    Writes ``pytorch_model.bin`` (the model state dict) and ``config.json``
    into ``save_dir`` and then uploads the whole ``save_dir`` folder to the
    model repo ``repo_id``.  Any exception is reported on stdout and
    swallowed, so a failed upload never interrupts the caller.
    """
    try:
        # Weights first, then the JSON config next to them.
        torch.save(model.state_dict(), save_dir / "pytorch_model.bin")
        with open(save_dir / "config.json", "w") as config_file:
            json.dump(config, config_file, indent=2)

        # Push everything currently in save_dir as a single commit.
        upload_args = {
            "folder_path": str(save_dir),
            "repo_id": repo_id,
            "repo_type": "model",
            "commit_message": f"Training checkpoint at step {step}",
        }
        HfApi().upload_folder(**upload_args)
    except Exception as err:
        print(f"‚ö†Ô∏è  Failed to upload to HuggingFace: {err}")
    else:
        print(f"‚òÅÔ∏è  Uploaded to HuggingFace: {repo_id}")

## Training loop

In [27]:
def train(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    optimizer: torch.optim.Optimizer,
    scheduler: Any,
    device: torch.device,
    config: Dict[str, Any],
    save_dir: Path,
    hf_repo_id: str = None
):
    """Run the pretraining loop with gradient accumulation.

    One "step" is one optimizer update, performed every
    ``config["gradient_accumulation_steps"]`` micro-batches (and on the last
    batch of an epoch). After each step the loop logs to wandb, periodically
    evaluates on ``val_loader``, periodically checkpoints and optionally
    uploads to the Hugging Face Hub. Training ends when
    ``config["max_steps"]`` is reached, on ``KeyboardInterrupt``, or after
    100 epochs; a final checkpoint is written in every case.

    Args:
        model: Network to train; must provide ``count_parameters()``.
        train_loader: Yields ``(input_batch, target_batch)`` pairs.
        val_loader: Loader passed to ``evaluate_model``.
        optimizer: Optimizer stepped once per accumulation window.
        scheduler: LR scheduler stepped together with the optimizer.
        device: Device forwarded to the loss/eval helpers.
        config: Reads ``batch_size``, ``gradient_accumulation_steps``,
            ``grad_clip``, ``max_steps``, ``eval_freq``, ``eval_iter``,
            ``save_freq`` and, optionally, ``upload_checkpoints`` /
            ``upload_frequency``.
        save_dir: Directory that receives the checkpoints.
        hf_repo_id: Hub repository id; ``None`` disables uploads.

    Returns:
        ``(train_losses, val_losses)``: the loss recorded at every optimizer
        step and the validation loss of every evaluation.
    """

    print("=" * 80)
    print("üöÄ STARTING MEDICAL LLM PRETRAINING")
    print("=" * 80)
    print(f"üìä Model: {model.count_parameters():,} parameters")
    print(f"üìä Training batches: {len(train_loader):,}")
    print(f"üìä Max steps: {config['max_steps']:,}")
    print(f"üìä Effective batch size: {config['batch_size'] * config['gradient_accumulation_steps']}")
    print(f"üìä Device: {device}")
    print("=" * 80)

    model.train()
    global_step = 0
    tokens_seen = 0
    best_val_loss = float('inf')

    train_losses = []
    val_losses = []

    grad_accum = config["gradient_accumulation_steps"]
    num_train_batches = len(train_loader)
    reached_max_steps = False

    try:
        for epoch in range(100):  # Virtually unlimited epochs
            epoch_loss = 0.0
            epoch_steps = 0

            progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}")

            for batch_idx, (input_batch, target_batch) in enumerate(progress_bar):
                # Forward pass; scale only the backward signal so the
                # accumulated gradient is the mean over the window.
                loss = calc_loss_batch(input_batch, target_batch, model, device)
                (loss / grad_accum).backward()

                # Count every micro-batch as it is consumed, so a short
                # final window is not over-counted.
                tokens_seen += input_batch.numel()

                is_update_step = (
                    (batch_idx + 1) % grad_accum == 0
                    or (batch_idx + 1) == num_train_batches
                )
                if not is_update_step:
                    continue

                # Clip gradients
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(),
                    max_norm=config["grad_clip"]
                )

                # Optimizer step
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                global_step += 1

                # Read the loss back once; every .item() call is a
                # host/device round trip. This is the loss of the last
                # micro-batch in the window.
                step_loss = loss.item()
                train_losses.append(step_loss)
                epoch_loss += step_loss
                epoch_steps += 1

                # Update progress bar
                progress_bar.set_postfix({
                    'loss': f"{step_loss:.4f}",
                    'lr': f"{scheduler.get_last_lr()[0]:.2e}",
                    'step': global_step
                })

                # Evaluation
                if global_step % config["eval_freq"] == 0:
                    val_loss = evaluate_model(
                        model, val_loader, device, config["eval_iter"]
                    )
                    val_losses.append(val_loss)

                    # Log to wandb
                    wandb.log({
                        "val_loss": val_loss,
                        "learning_rate": scheduler.get_last_lr()[0],
                        "step": global_step,
                    })

                    # Check for improvement
                    if val_loss < best_val_loss:
                        best_val_loss = val_loss
                        print(f"\n‚ú® New best validation loss: {val_loss:.4f}")

                # Save checkpoint
                if global_step % config["save_freq"] == 0:
                    save_checkpoint(
                        model, optimizer, scheduler,
                        global_step, train_losses[-1],
                        save_dir, MODEL_CONFIG
                    )

                    # Upload to HuggingFace
                    # NOTE(review): `config` is the training config here;
                    # confirm it really carries "upload_checkpoints" /
                    # "upload_frequency", otherwise uploads never trigger
                    # even when hf_repo_id is set.
                    if hf_repo_id and config.get("upload_checkpoints", False):
                        if global_step % config.get("upload_frequency", 1000) == 0:
                            upload_to_huggingface(
                                model, save_dir / "hf_upload",
                                hf_repo_id, MODEL_CONFIG, global_step
                            )

                # Log before the max-steps check so the final step is
                # recorded as well.
                wandb.log({
                    "train_loss": step_loss,
                    "tokens_seen": tokens_seen
                })

                # Check if max steps reached
                if global_step >= config["max_steps"]:
                    print(f"\nüéâ Reached max steps ({config['max_steps']})")
                    reached_max_steps = True
                    break

            if reached_max_steps:
                break

            # End of epoch summary
            avg_epoch_loss = epoch_loss / epoch_steps if epoch_steps > 0 else float('inf')
            print(f"\nEpoch {epoch+1} complete - Avg loss: {avg_epoch_loss:.4f}")

    except KeyboardInterrupt:
        # A manual interrupt still falls through to the final checkpoint.
        print("\n‚ö†Ô∏è  Training stopped")

    # Final checkpoint
    print("\nüíæ Saving final checkpoint...")
    save_checkpoint(
        model, optimizer, scheduler,
        global_step, train_losses[-1] if train_losses else 0,
        save_dir, MODEL_CONFIG
    )

    print(f"\nüéâ Training complete!")
    print(f"üìä Total steps: {global_step:,}")
    print(f"üìä Total tokens: {tokens_seen:,}")
    print(f"üìä Best validation loss: {best_val_loss:.4f}")

    return train_losses, val_losses

## main.py

In [28]:
def main():
    """Entry point: prepare data, build the model and run pretraining.

    Seeds the RNGs, prepares the cached dataset, builds the dataloaders, the
    (optionally compiled) model, the AdamW optimizer and the OneCycleLR
    schedule, starts a wandb run, optionally prepares the Hugging Face
    repository and finally calls ``train``. The wandb run is always closed,
    and marked as failed when training raises.
    """
    # Set random seeds
    torch.manual_seed(TRAINING_CONFIG["seed"])
    np.random.seed(TRAINING_CONFIG["seed"])

    # Setup device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"üîß Using device: {device}")

    # Create save directory
    save_dir = Path("./checkpoints")
    save_dir.mkdir(exist_ok=True)

    # Initialize tokenizer
    print("üîß Loading tokenizer...")
    tokenizer = tiktoken.get_encoding("p50k_base")

    # Load and prepare data. The returned token arrays are not used here;
    # the loaders below read from the cache directories instead
    # (presumably the ones this call populates: confirm the paths match).
    prepare_medical_data(DATA_CONFIG, tokenizer)

    # Create dataloaders
    print("üîß Creating dataloaders...")
    train_loader = create_dataloader(
        Path("/content/data_cache/train/"),
        batch_size=TRAINING_CONFIG["batch_size"],
        shuffle=True,
        num_workers=TRAINING_CONFIG["num_workers"]
    )

    val_loader = create_dataloader(
        Path("/content/data_cache/val/"),
        batch_size=TRAINING_CONFIG["batch_size"],
        shuffle=False,
        num_workers=TRAINING_CONFIG["num_workers"]
    )

    # Initialize model
    print("üîß Initializing model...")
    model = MedAssistGPT(MODEL_CONFIG)
    model = model.to(device)

    # Compile model (PyTorch 2.0+)
    if hasattr(torch, 'compile'):
        print("üîß Compiling model...")
        model = torch.compile(model, mode="default", fullgraph=False, dynamic=True)

    print(f"‚úÖ Model has {model.count_parameters():,} parameters")

    # Initialize optimizer
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=TRAINING_CONFIG["learning_rate"],
        weight_decay=TRAINING_CONFIG["weight_decay"],
        betas=(TRAINING_CONFIG["beta1"], TRAINING_CONFIG["beta2"]),
        eps=TRAINING_CONFIG["eps"]
    )

    # Initialize scheduler: warm up for `warmup_steps`, then cosine-anneal
    # over the remaining steps of the schedule.
    scheduler = OneCycleLR(
        optimizer,
        max_lr=TRAINING_CONFIG["learning_rate"],
        total_steps=TRAINING_CONFIG["max_steps"],
        pct_start=TRAINING_CONFIG["warmup_steps"] / TRAINING_CONFIG["max_steps"],
        anneal_strategy='cos',
        div_factor=10,
        final_div_factor=100
    )

    # Initialize wandb
    wandb.init(
        project=WANDB_CONFIG["project"],
        entity=WANDB_CONFIG["entity"],
        name=WANDB_CONFIG["name"],
        config={**MODEL_CONFIG, **TRAINING_CONFIG, **DATA_CONFIG}
    )

    # Login to HuggingFace (if uploading)
    if HF_CONFIG.get("upload_checkpoints", False):
        try:
            login()
            create_repo(HF_CONFIG["repo_id"], repo_type="model", exist_ok=True)
            print(f"‚úÖ HuggingFace repo ready: {HF_CONFIG['repo_id']}")
        except Exception as e:
            # Best effort: fall back to training without uploads.
            print(f"‚ö†Ô∏è  HuggingFace setup failed: {e}")
            HF_CONFIG["upload_checkpoints"] = False

    # Train!
    try:
        train(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            optimizer=optimizer,
            scheduler=scheduler,
            device=device,
            config=TRAINING_CONFIG,
            save_dir=save_dir,
            hf_repo_id=HF_CONFIG["repo_id"] if HF_CONFIG.get("upload_checkpoints") else None
        )
    except BaseException:
        # Close the run as failed instead of leaving it open, then let the
        # original error propagate.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()
    print("\n‚úÖ All done!")


if __name__ == "__main__":
    main()

üîß Using device: cuda
üîß Loading tokenizer...
‚úÖ Found cached data! Skipping processing.
   Train cache: data_cache/train
   Val cache: data_cache/val
üîß Creating dataloaders...
üìÇ Memory-mapped dataset: 524,046 samples
üíæ RAM overhead: ~0 MB (OS manages it)
üìÇ Memory-mapped dataset: 33,328 samples
üíæ RAM overhead: ~0 MB (OS manages it)
üîß Initializing model...
üîß Compiling model...
‚úÖ Model has 86,578,176 parameters


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv‚Ä¶

‚ö†Ô∏è  HuggingFace setup failed: 401 Client Error: Unauthorized for url: https://huggingface.co/api/repos/create (Request ID: Root=1-69114ead-705d2d345910c14b0c0335da;12dc0f30-7093-442c-a008-7dc113ca8e79)

Invalid username or password.
üöÄ STARTING MEDICAL LLM PRETRAINING
üìä Model: 86,578,176 parameters
üìä Training batches: 8,188
üìä Max steps: 50,000
üìä Effective batch size: 128
üìä Device: cuda


Epoch 1:  12%|‚ñà‚ñè        | 1000/8188 [08:29<13:07:57,  6.58s/it, loss=4.8076, lr=3.00e-04, step=500]


‚ú® New best validation loss: 4.9242


Epoch 1:  24%|‚ñà‚ñà‚ñç       | 1998/8188 [14:46<38:41,  2.67it/s, loss=3.9421, lr=3.00e-04, step=1000]


‚ú® New best validation loss: 3.9290


Epoch 1:  24%|‚ñà‚ñà‚ñç       | 2000/8188 [15:01<4:33:07,  2.65s/it, loss=3.9421, lr=3.00e-04, step=1000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_1000.pt


Epoch 1:  37%|‚ñà‚ñà‚ñà‚ñã      | 3000/8188 [21:30<3:31:43,  2.45s/it, loss=3.6426, lr=3.00e-04, step=1500]


‚ú® New best validation loss: 3.5518


Epoch 1:  49%|‚ñà‚ñà‚ñà‚ñà‚ñâ     | 3998/8188 [27:45<26:12,  2.67it/s, loss=3.2987, lr=2.99e-04, step=2000]


‚ú® New best validation loss: 3.3488


Epoch 1:  49%|‚ñà‚ñà‚ñà‚ñà‚ñâ     | 4000/8188 [28:00<3:04:48,  2.65s/it, loss=3.2987, lr=2.99e-04, step=2000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_2000.pt


Epoch 1:  61%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà    | 5000/8188 [34:29<2:10:18,  2.45s/it, loss=3.3058, lr=2.99e-04, step=2500]


‚ú® New best validation loss: 3.2273


Epoch 1:  73%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé  | 5998/8188 [40:44<13:41,  2.66it/s, loss=3.1478, lr=2.98e-04, step=3000]


‚ú® New best validation loss: 3.1449


Epoch 1:  73%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé  | 6000/8188 [41:00<1:36:34,  2.65s/it, loss=3.1478, lr=2.98e-04, step=3000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_3000.pt


Epoch 1:  85%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå | 7000/8188 [47:29<48:30,  2.45s/it, loss=3.1070, lr=2.97e-04, step=3500]


‚ú® New best validation loss: 3.0813


Epoch 1:  98%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä| 7998/8188 [53:45<01:11,  2.66it/s, loss=3.0844, lr=2.96e-04, step=4000]


‚ú® New best validation loss: 3.0293


Epoch 1:  98%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä| 8000/8188 [54:00<08:18,  2.65s/it, loss=3.0844, lr=2.96e-04, step=4000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_4000.pt


Epoch 1: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8188/8188 [55:11<00:00,  2.47it/s, loss=2.8545, lr=2.96e-04, step=4094]



Epoch 1 complete - Avg loss: 3.8584


Epoch 2:  10%|‚ñâ         | 812/8188 [05:18<5:01:20,  2.45s/it, loss=3.0777, lr=2.95e-04, step=4500]


‚ú® New best validation loss: 2.9929


Epoch 2:  22%|‚ñà‚ñà‚ñè       | 1810/8188 [11:34<39:56,  2.66it/s, loss=3.1004, lr=2.94e-04, step=5000]


‚ú® New best validation loss: 2.9633


Epoch 2:  22%|‚ñà‚ñà‚ñè       | 1812/8188 [11:49<4:41:08,  2.65s/it, loss=3.1004, lr=2.94e-04, step=5000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_5000.pt


Epoch 2:  34%|‚ñà‚ñà‚ñà‚ñç      | 2812/8188 [18:19<3:39:27,  2.45s/it, loss=2.9451, lr=2.93e-04, step=5500]


‚ú® New best validation loss: 2.9359


Epoch 2:  47%|‚ñà‚ñà‚ñà‚ñà‚ñã     | 3810/8188 [24:34<27:21,  2.67it/s, loss=3.0218, lr=2.91e-04, step=6000]


‚ú® New best validation loss: 2.9105


Epoch 2:  47%|‚ñà‚ñà‚ñà‚ñà‚ñã     | 3812/8188 [24:49<3:13:16,  2.65s/it, loss=3.0218, lr=2.91e-04, step=6000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_6000.pt


Epoch 2:  59%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ    | 4812/8188 [31:18<2:17:52,  2.45s/it, loss=2.9147, lr=2.89e-04, step=6500]


‚ú® New best validation loss: 2.8890


Epoch 2:  71%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà   | 5810/8188 [37:34<15:02,  2.63it/s, loss=2.9481, lr=2.87e-04, step=7000]


‚ú® New best validation loss: 2.8693


Epoch 2:  71%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà   | 5812/8188 [37:49<1:47:44,  2.72s/it, loss=2.9481, lr=2.87e-04, step=7000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_7000.pt


Epoch 2:  83%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé | 6812/8188 [44:18<56:15,  2.45s/it, loss=2.9766, lr=2.85e-04, step=7500]


‚ú® New best validation loss: 2.8529


Epoch 2:  95%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå| 7810/8188 [50:34<02:21,  2.66it/s, loss=2.9394, lr=2.83e-04, step=8000]


‚ú® New best validation loss: 2.8375


Epoch 2:  95%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñå| 7812/8188 [50:49<16:35,  2.65s/it, loss=2.9394, lr=2.83e-04, step=8000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_8000.pt


Epoch 2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8188/8188 [53:10<00:00,  2.57it/s, loss=2.8984, lr=2.83e-04, step=8188]



Epoch 2 complete - Avg loss: 2.9498


Epoch 3:   8%|‚ñä         | 624/8188 [04:08<5:09:16,  2.45s/it, loss=2.8120, lr=2.81e-04, step=8500]


‚ú® New best validation loss: 2.8221


Epoch 3:  20%|‚ñà‚ñâ        | 1622/8188 [10:23<41:04,  2.66it/s, loss=2.7628, lr=2.79e-04, step=9000]


‚ú® New best validation loss: 2.8109


Epoch 3:  20%|‚ñà‚ñâ        | 1624/8188 [10:38<4:49:15,  2.64s/it, loss=2.7628, lr=2.79e-04, step=9000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_9000.pt


Epoch 3:  32%|‚ñà‚ñà‚ñà‚ñè      | 2624/8188 [17:08<3:47:15,  2.45s/it, loss=2.8661, lr=2.76e-04, step=9500]


‚ú® New best validation loss: 2.8004


Epoch 3:  44%|‚ñà‚ñà‚ñà‚ñà‚ñç     | 3622/8188 [23:24<28:32,  2.67it/s, loss=2.8194, lr=2.74e-04, step=1e+4]


‚ú® New best validation loss: 2.7870


Epoch 3:  44%|‚ñà‚ñà‚ñà‚ñà‚ñç     | 3624/8188 [23:39<3:21:19,  2.65s/it, loss=2.8194, lr=2.74e-04, step=1e+4]

üíæ Checkpoint saved: checkpoints/checkpoint_step_10000.pt


Epoch 3:  56%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã    | 4624/8188 [30:08<2:25:33,  2.45s/it, loss=2.8632, lr=2.71e-04, step=10500]


‚ú® New best validation loss: 2.7777


Epoch 3:  69%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä   | 5622/8188 [36:24<16:02,  2.67it/s, loss=2.8964, lr=2.68e-04, step=11000]


‚ú® New best validation loss: 2.7674


Epoch 3:  69%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä   | 5624/8188 [36:39<1:53:14,  2.65s/it, loss=2.8964, lr=2.68e-04, step=11000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_11000.pt


Epoch 3:  81%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà  | 6624/8188 [43:08<1:04:02,  2.46s/it, loss=2.7124, lr=2.65e-04, step=11500]


‚ú® New best validation loss: 2.7577


Epoch 3:  93%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé| 7622/8188 [49:24<03:32,  2.67it/s, loss=2.7280, lr=2.62e-04, step=12000]


‚ú® New best validation loss: 2.7499


Epoch 3:  93%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé| 7624/8188 [49:39<24:52,  2.65s/it, loss=2.7280, lr=2.62e-04, step=12000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_12000.pt


Epoch 3: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8188/8188 [53:11<00:00,  2.57it/s, loss=2.8848, lr=2.60e-04, step=12282]



Epoch 3 complete - Avg loss: 2.8274


Epoch 4:   5%|‚ñå         | 436/8188 [02:57<5:16:50,  2.45s/it, loss=2.7571, lr=2.59e-04, step=12500]


‚ú® New best validation loss: 2.7406


Epoch 4:  18%|‚ñà‚ñä        | 1434/8188 [09:13<42:15,  2.66it/s, loss=2.7287, lr=2.55e-04, step=13000]


‚ú® New best validation loss: 2.7349


Epoch 4:  18%|‚ñà‚ñä        | 1436/8188 [09:28<4:57:05,  2.64s/it, loss=2.7287, lr=2.55e-04, step=13000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_13000.pt


Epoch 4:  30%|‚ñà‚ñà‚ñâ       | 2436/8188 [15:57<3:54:54,  2.45s/it, loss=2.7125, lr=2.52e-04, step=13500]


‚ú® New best validation loss: 2.7279


Epoch 4:  42%|‚ñà‚ñà‚ñà‚ñà‚ñè     | 3434/8188 [22:12<29:44,  2.66it/s, loss=2.7483, lr=2.48e-04, step=14000]


‚ú® New best validation loss: 2.7223


Epoch 4:  42%|‚ñà‚ñà‚ñà‚ñà‚ñè     | 3436/8188 [22:27<3:29:37,  2.65s/it, loss=2.7483, lr=2.48e-04, step=14000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_14000.pt


Epoch 4:  54%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç    | 4436/8188 [28:57<2:33:18,  2.45s/it, loss=2.7263, lr=2.45e-04, step=14500]


‚ú® New best validation loss: 2.7135


Epoch 4:  66%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã   | 5434/8188 [35:13<17:13,  2.66it/s, loss=2.7985, lr=2.41e-04, step=15000]


‚ú® New best validation loss: 2.7097


Epoch 4:  66%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã   | 5436/8188 [35:28<2:01:25,  2.65s/it, loss=2.7985, lr=2.41e-04, step=15000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_15000.pt


Epoch 4:  79%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä  | 6436/8188 [41:57<1:11:37,  2.45s/it, loss=2.8543, lr=2.37e-04, step=15500]


‚ú® New best validation loss: 2.7041


Epoch 4:  91%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 7435/8188 [48:13<04:11,  3.00it/s, loss=2.6264, lr=2.33e-04, step=16000]


‚ú® New best validation loss: 2.6974


Epoch 4:  91%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà | 7436/8188 [48:28<43:20,  3.46s/it, loss=2.6264, lr=2.33e-04, step=16000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_16000.pt


Epoch 4: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 8188/8188 [53:11<00:00,  2.57it/s, loss=2.7740, lr=2.30e-04, step=16376]



Epoch 4 complete - Avg loss: 2.7630


Epoch 5:   3%|‚ñé         | 248/8188 [01:46<5:24:28,  2.45s/it, loss=2.7209, lr=2.29e-04, step=16500]


‚ú® New best validation loss: 2.6899


Epoch 5:  15%|‚ñà‚ñå        | 1246/8188 [08:02<43:24,  2.67it/s, loss=2.6805, lr=2.25e-04, step=17000]


‚ú® New best validation loss: 2.6849


Epoch 5:  15%|‚ñà‚ñå        | 1248/8188 [08:17<5:05:35,  2.64s/it, loss=2.6805, lr=2.25e-04, step=17000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_17000.pt


Epoch 5:  27%|‚ñà‚ñà‚ñã       | 2248/8188 [14:46<4:02:36,  2.45s/it, loss=2.7280, lr=2.21e-04, step=17500]


‚ú® New best validation loss: 2.6794


Epoch 5:  40%|‚ñà‚ñà‚ñà‚ñâ      | 3246/8188 [21:02<30:54,  2.66it/s, loss=2.7333, lr=2.17e-04, step=18000]


‚ú® New best validation loss: 2.6754


Epoch 5:  40%|‚ñà‚ñà‚ñà‚ñâ      | 3248/8188 [21:17<3:37:52,  2.65s/it, loss=2.7333, lr=2.17e-04, step=18000]

üíæ Checkpoint saved: checkpoints/checkpoint_step_18000.pt


Epoch 5:  52%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè    | 4247/8188 [27:34<25:35,  2.57it/s, loss=2.7553, lr=2.12e-04, step=18500]



‚ö†Ô∏è  Training stopped

üíæ Saving final checkpoint...
üíæ Checkpoint saved: checkpoints/checkpoint_step_18500.pt

üéâ Training complete!
üìä Total steps: 18,500
üìä Total tokens: 2,424,832,000
üìä Best validation loss: 2.6754


0,1
learning_rate,‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÑ‚ñÑ‚ñÑ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ
step,‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
tokens_seen,‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà
train_loss,‚ñà‚ñà‚ñà‚ñá‚ñÜ‚ñÜ‚ñÑ‚ñÑ‚ñÑ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÇ‚ñÅ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
val_loss,‚ñà‚ñÖ‚ñÑ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ

0,1
learning_rate,0.00022
step,18000.0
tokens_seen,2424700928.0
train_loss,2.62974
val_loss,2.67538



‚úÖ All done!


## Inference

In [None]:
def generate_text(model, tokenizer, inference_config: Dict[str, Any], device: str = 'cuda'):
    """Generate a continuation for ``inference_config["prompt_text"]``.

    Tokens are produced one at a time until ``max_new_tokens`` have been
    generated or the tokenizer's end-of-text token appears (that token is
    kept in the output).

    Args:
        model: Callable returning logits for a batch of token ids (expected
            shape ``(1, seq_len, vocab_size)``); must expose
            ``config['max_len']``, ``eval()`` and ``to()``.
        tokenizer: Object with ``encode``, ``decode`` and ``eot_token``.
        inference_config: Reads ``"prompt_text"`` (required),
            ``"max_new_tokens"`` (default 50) and ``"temperature"``
            (default 0.8). A temperature of 0 selects greedy decoding.
        device: Device the model and the token tensor are moved to.

    Returns:
        The prompt followed by the decoded generated tokens.

    Raises:
        ValueError: If the temperature is negative or the prompt encodes to
            no tokens.
    """
    text = inference_config["prompt_text"]
    max_new_tokens = inference_config.get("max_new_tokens", 50)
    temperature = inference_config.get("temperature", 0.8)

    # Dividing logits by a negative temperature would silently favour the
    # least likely tokens; reject it explicitly.
    if temperature < 0:
        raise ValueError(f"temperature must be >= 0, got {temperature}")

    # Encode the input text
    encoded_input = tokenizer.encode(text)
    if not encoded_input:
        # There is no "last position" to predict from with an empty context.
        raise ValueError("prompt_text must encode to at least one token")

    model.eval()
    model.to(device)

    input_ids = torch.tensor(encoded_input, dtype=torch.long).unsqueeze(0).to(device)
    max_len = model.config['max_len']
    generated_tokens = []

    with torch.no_grad():
        for _ in range(max_new_tokens):
            # Keep at most the last `max_len` tokens as context (the slice
            # is a no-op while the sequence is still short enough).
            context = input_ids[:, -max_len:]

            # Logits for the last position only: (1, vocab_size)
            last_logits = model(context)[:, -1, :]

            if temperature == 0:
                # Greedy decoding; keepdim keeps the (1, 1) shape.
                next_token = torch.argmax(last_logits, dim=-1, keepdim=True)
            else:
                probs = torch.softmax(last_logits / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            token_id = next_token.item()
            generated_tokens.append(token_id)
            input_ids = torch.cat((input_ids, next_token), dim=1)

            # Stop if endoftext token is generated
            if token_id == tokenizer.eot_token:
                break

    # Decode the generated tokens
    decoded_output = tokenizer.decode(generated_tokens)
    return text + decoded_output

def load_model_from_checkpoint(checkpoint_path: Path, model_config: Dict[str, Any], device: torch.device) -> nn.Module:
    """Build a MedAssistGPT and restore its weights from a checkpoint file.

    Args:
        checkpoint_path: File holding a dict with a ``'model_state_dict'``
            entry.
        model_config: Configuration passed to the ``MedAssistGPT``
            constructor.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        The restored model, on ``device`` and in evaluation mode.
    """
    print(f"Loading model from {checkpoint_path}...")
    model = MedAssistGPT(model_config)

    # torch.load unpickles the file: only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)

    # Keys may carry an '_orig_mod.' prefix; drop it so they match the
    # freshly built model's parameter names.
    cleaned_state = {
        name.removeprefix('_orig_mod.'): tensor
        for name, tensor in checkpoint['model_state_dict'].items()
    }

    model.load_state_dict(cleaned_state)
    model.to(device)
    model.eval()  # inference only: switch to evaluation mode
    print("Model loaded successfully.")
    return model

# Example usage: sample from each periodic checkpoint so the generations can
# be compared as training progresses.
# Device and tokenizer are needed for both loading and inference.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = tiktoken.get_encoding("p50k_base")

generated_output = []
# Checkpoint steps 1000, 2000, ..., 18000: every multiple of 1000 up to the
# 18500 bound.
for ckpt in range(1000, 18500 + 1, 1000):
    checkpoint_file_path = Path(f"checkpoints/checkpoint_step_{ckpt}.pt")

    # Load the model
    loaded_model = load_model_from_checkpoint(checkpoint_file_path, MODEL_CONFIG, device)

    # Generate with the restored weights
    generated_output.append(generate_text(loaded_model, tokenizer, INFERENCE_CONFIG, device))

In [50]:
# Print one generation per checkpoint on a single line each.
for i, res in enumerate(generated_output):
    # Flatten newlines outside the f-string: reusing the enclosing quote
    # type and a backslash inside an f-string expression is a SyntaxError
    # before Python 3.12.
    flattened = res.replace("\n", " ")
    print(f"Checkpoint{(i+1)*1000}:\n{flattened}\n\n")

Checkpoint1000:
To live a good life, the capacity for a person is a major challenge to the health and health system (; ). In addition, the general population has a high prevalence of disability and disability, with a high prevalence of disability, but it is estimated that about 10% of the population is at high risk. The total number of people living with disability in the country is increasing in the United States (; ).  Patients with dementia are at risk of developing dementia, and may also have a high risk of developing dementia


Checkpoint2000:
To live a good life expectancy, the weight gain in the middle-aged population is also an important factor in the development of the elderly. However, the benefits of weight loss in elderly people are not well understood.  In this study, we investigated the effects of weight loss on the elderly population. We hypothesized that the effects of weight loss on the elderly population, as well as the age-related health outcomes, were assessed.Intro

## Saving to HF

In [None]:
import os, json, torch
from safetensors.torch import save_file
from huggingface_hub import create_repo, upload_folder, login

# Never embed the access token in the notebook: read it from the environment.
# With HF_TOKEN unset, login() falls back to its interactive prompt.
# Any token that was ever stored in this file in plain text must be treated
# as compromised and revoked.
login(token=os.environ.get("HF_TOKEN"))

REPO_ID = "kunjcr2/MedAssistGPT"
CKPT = "/content/checkpoints/checkpoint_step_17000.pt"
OUT  = "/content/medassistgpt_repo"
os.makedirs(OUT, exist_ok=True)

# --- load checkpoint -> state_dict ---
sd = torch.load(CKPT, map_location="cpu")
if isinstance(sd, dict):
    # A training checkpoint nests the weights under one of these keys;
    # unwrap the first one that holds a dict.
    for k in ["state_dict", "model_state_dict", "module", "model"]:
        if k in sd and isinstance(sd[k], dict):
            sd = sd[k]
            break
# A pickled module object instead of a dict: take its state dict.
if hasattr(sd, "state_dict"):
    sd = sd.state_dict()

# --- break shared storages (tied weights) ---
def break_shared_storage(state_dict):
    """Give every tensor entry that aliases an earlier one its own storage.

    Two tensor entries count as aliases when they have the same storage
    pointer, dtype and size. The first entry is kept as-is; each later alias
    is replaced by a clone so that safetensors accepts the dict. Non-tensor
    values are left untouched. The dict is modified in place and returned.
    """
    first_owner = {}
    for name in list(state_dict):
        tensor = state_dict[name]
        if not torch.is_tensor(tensor):
            continue

        # Fingerprint of the underlying memory plus the tensor's layout.
        fingerprint = (
            tensor.untyped_storage().data_ptr(),
            tensor.dtype,
            tuple(tensor.size()),
        )
        if fingerprint not in first_owner:
            first_owner[fingerprint] = name
            continue

        # Already seen: detach this entry onto fresh storage.
        state_dict[name] = tensor.clone()
    return state_dict

sd = break_shared_storage(sd)

# Checkpoints written from a torch.compile-wrapped model carry an
# "_orig_mod." prefix on every key; publish plain parameter names so the
# weights load into an uncompiled model. Keys without the prefix are unchanged.
sd = {name.removeprefix("_orig_mod."): tensor for name, tensor in sd.items()}

# save safetensors
save_file(sd, f"{OUT}/model.safetensors")

# minimal configs (your values)
config = {
    "model_type": "medassist_gpt",
    "architectures": ["MedAssistGPTForCausalLM"],
    "vocab_size": 50281,
    "hidden_size": 512,
    "num_attention_heads": 16,
    "num_key_value_heads": 4,
    "intermediate_size": 2048,
    "num_hidden_layers": 16,
    "max_position_embeddings": 1024,
    "layer_norm_eps": 1e-5,
    "hidden_dropout_prob": 0.1,
    "attention_dropout_prob": 0.1,
    # NOTE(review): the tensors are saved in whatever dtype the checkpoint
    # holds; confirm that it matches the dtype advertised here.
    "torch_dtype": "bfloat16"
}
# Context managers make sure both files are flushed and closed before the
# folder is uploaded.
with open(f"{OUT}/config.json", "w") as f:
    json.dump(config, f, indent=2)

tok = {"tokenizer_type": "tiktoken", "tiktoken_model": "p50k_base"}
with open(f"{OUT}/tokenizer_config.json", "w") as f:
    json.dump(tok, f, indent=2)

create_repo(REPO_ID, exist_ok=True, private=False)
upload_folder(repo_id=REPO_ID, folder_path=OUT, commit_message="weights+config")
print("done")