# BookNLP Maximum Performance Quote Attribution (Unified)

**Goal**: Train the max-performance quote attribution model (target: 80–90% accuracy) in either Kaggle or Colab via one RUN_ENV-aware notebook.

**Features**
- DeBERTa-v3-large with quote/candidate masks + [QUOTE], [ALTQUOTE], [PAR]
- Candidate-level softmax with label smoothing; optional R-Drop; optional temperature scaling
- Optional multi-source loading + genre-balanced sampler; PDNC fallback; configurable hard negatives
- Curriculum sampler + light augmentation; gradient checkpointing + FP16
- Auto checkpoint/resume (model/optimizer/scheduler/best_acc) with cadence set by RUN_ENV; bucketed eval + placeholder postprocess hook

**Requirements**
- Kaggle: T4 x2 accelerator; storage under `/kaggle/working`
- Colab: T4 GPU; storage in Drive at `/content/drive`

**Quick start**
1) Choose the environment: set the `RUN_ENV` environment variable to `"kaggle"` or `"colab"`, or edit its default in the next cell (default: kaggle).
2) Kaggle: no Drive mount; repo/output in `/kaggle/working`; multi-GPU via `torch.nn.DataParallel` in a single process; checkpoints/evals every 500 steps; auto-resume from latest `checkpoint_*.pt`.
3) Colab: mounts Drive to `/content/drive`; repo in `/content`; outputs in Drive; single-GPU (no DDP) with gradient accumulation; checkpoints/evals every 300 steps; auto-resume from latest `checkpoint_*.pt`.
4) Run all cells — the training code is cloned from its GitHub repo and the selected datasets are downloaded automatically.



## RUN_ENV toggle
Set the `RUN_ENV` environment variable to `"kaggle"` or `"colab"` (or edit its default in the next cell; default: kaggle). Paths, checkpoint cadence, and mounts adjust automatically. Kaggle uses multi-GPU via `torch.nn.DataParallel` in a single process (no Drive mount); Colab mounts Drive and runs single-GPU with gradient accumulation.



In [None]:
import os, sys, torch

# CURSOR: Toggle once; everything else keys off this value.
# RUN_ENV is read from the environment variable of the same name (default "kaggle");
# change the default below to switch without touching the environment.
RUN_ENV = os.environ.get("RUN_ENV", "kaggle").strip().lower()
ENV_CFG = {
    "kaggle": {
        "base_dir": "/kaggle/working",
        "repo_dir": "/kaggle/working/speaker-attribution-acl2023",
        "training_repo_dir": "/kaggle/working/quote-attribution-training",
        "output_root": "/kaggle/working",
        "checkpoint_every": 500,
        "eval_every": 500,
        "grad_accum": 8,
        "mount_drive": False,
    },
    "colab": {
        "base_dir": "/content/drive/MyDrive/quote_attribution",
        "repo_dir": "/content/speaker-attribution-acl2023",
        "training_repo_dir": "/content/quote-attribution-training",
        "output_root": "/content/drive/MyDrive/quote_attribution",
        "checkpoint_every": 300,
        "eval_every": 300,
        "grad_accum": 16,
        "mount_drive": True,
    },
}
# Explicit raise rather than `assert`: asserts are stripped when Python runs with -O.
if RUN_ENV not in ENV_CFG:
    raise ValueError(f"Unsupported RUN_ENV: {RUN_ENV}")
ENV = ENV_CFG[RUN_ENV]

if ENV["mount_drive"]:
    from google.colab import drive
    drive.mount("/content/drive")
    os.makedirs(ENV["base_dir"], exist_ok=True)

print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    raise RuntimeError("GPU not available; enable a GPU runtime.")

# Data root (datasets auto-downloaded later based on CONFIG['datasets'])
REPO_DIR = ENV["repo_dir"]
os.makedirs(REPO_DIR, exist_ok=True)
print(f"Data repo root: {REPO_DIR}")

# Clone training code repository.
# Uses subprocess with a return-code check instead of the `!git clone` shell magic:
# a failed clone used to pass silently and only surface later as confusing
# ModuleNotFoundErrors when importing from the training repo.
TRAINING_REPO_DIR = ENV["training_repo_dir"]
if not os.path.exists(TRAINING_REPO_DIR):
    print("\n[DOWN] Cloning training code repository...")
    import subprocess
    _clone = subprocess.run(
        [
            "git",
            "clone",
            "https://github.com/bohdan-natsevych/quote-attribution-training.git",
            TRAINING_REPO_DIR,
        ],
        capture_output=True,
        text=True,
    )
    if _clone.returncode != 0:
        raise RuntimeError(
            f"git clone failed (exit code {_clone.returncode}); imports from "
            f"{TRAINING_REPO_DIR} would fail.\n{_clone.stderr}"
        )
    print(f"[OK] Cloned training repository into {TRAINING_REPO_DIR}")
else:
    print(f"[OK] Training repository present at {TRAINING_REPO_DIR}")

# Add training repo to Python path for imports
if TRAINING_REPO_DIR not in sys.path:
    sys.path.insert(0, TRAINING_REPO_DIR)
    print(f"[OK] Added {TRAINING_REPO_DIR} to Python path")

DATA_ROOT = f"{REPO_DIR}/data"
os.makedirs(DATA_ROOT, exist_ok=True)
print(f"Data root: {DATA_ROOT}")

BASE_DIR = ENV["base_dir"]
OUTPUT_ROOT = ENV["output_root"]
os.makedirs(OUTPUT_ROOT, exist_ok=True)
print(f"Output root: {OUTPUT_ROOT}")


In [None]:
%pip install -q transformers>=4.30.0 accelerate>=0.20.0 datasets scikit-learn tqdm nlpaug nltk matplotlib seaborn


In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================

# CURSOR: Import to get available datasets
import sys
from enum import Enum
from dataclasses import dataclass, field
from typing import List, Union
if TRAINING_REPO_DIR not in sys.path:
    sys.path.insert(0, TRAINING_REPO_DIR)

from data.multi_source_data import MultiSourceDataLoader

# Define Dataset enum from available datasets
class Dataset(str, Enum):
    """Available datasets for quote attribution training.

    NOTE: a later cell runs ``from torch.utils.data import Dataset``, which rebinds
    this name in the notebook namespace. ``TrainingConfig`` therefore keeps its own
    reference to this enum instead of looking ``Dataset`` up at validation time.
    """
    PDNC = "pdnc"
    LITBANK = "litbank"
    DIRECTQUOTE = "directquote"
    # CURSOR: QUOTEBANK = "quotebank"

    @classmethod
    def get_all(cls):
        """Return the string value of every member, in declaration order."""
        return [d.value for d in cls]

    @classmethod
    def validate(cls, datasets: list):
        """Return True if every entry is a known dataset value.

        Raises:
            ValueError: on the first entry that is not a member value.
        """
        valid = cls.get_all()
        for ds in datasets:
            if ds not in valid:
                raise ValueError(f"Invalid dataset '{ds}'. Must be one of: {valid}")
        return True


@dataclass
class TrainingConfig:
    """Validated configuration for quote attribution training.

    Validation runs in ``__post_init__``, so it also applies to
    ``dataclasses.replace(config, ...)``. Fields that are overwritten after
    construction (paths, cadence, accumulation) are not re-validated.

    Raises:
        NotImplementedError: for ``target_level == 3`` (ensemble is trained per fold).
        ValueError: for any other out-of-range field.
    """

    # Core settings
    target_level: int
    name: str
    target_accuracy: float

    # Training hyperparameters
    epochs: int = 15
    batch_size: int = 8
    lr: float = 5e-6

    # Dataset configuration
    datasets: List[str] = field(default_factory=lambda: ['pdnc'])
    fold_selection: Union[str, List[int]] = "all"

    # Advanced features
    use_augmentation: bool = True
    use_curriculum: bool = True
    focal_gamma: float = 2.0
    label_smoothing: float = 0.1
    r_drop_alpha: float = 0.0
    hard_negative_topk: int = 2
    max_candidates: int = 10
    shuffle_candidates: bool = True
    calibrate_temperature: bool = True
    balance_genres: bool = False
    min_genre_acc: float = 0.75

    # Model configuration (populated after init)
    base_model: str = 'microsoft/deberta-v3-large'
    max_length: int = 512
    gradient_accumulation_steps: int = 4
    checkpoint_every: int = 500
    eval_every: int = 500
    fp16: bool = True
    gradient_checkpointing: bool = True
    seed: int = 42
    output_dir: str = ''
    multi_source_base: str = ''

    # The dataset-name enum, captured when this class is defined. It has no annotation,
    # so it is a plain class attribute, not a dataclass field. This keeps validation
    # working after `Dataset` is shadowed by torch's Dataset class.
    _DATASET_ENUM = Dataset

    def __post_init__(self):
        """Validate configuration after initialization."""
        # Validate target level
        if self.target_level == 3:
            raise NotImplementedError(
                "TARGET_LEVEL=3 ensemble training not implemented in notebook. "
                "Train 5 folds with TARGET_LEVEL=2, then use models/ensemble.py for ensemble inference."
            )
        if self.target_level not in [1, 2]:
            raise ValueError(f"target_level must be 1 or 2, got {self.target_level}")
        if not 0 < self.target_accuracy <= 1:
            raise ValueError(f"target_accuracy must be in (0, 1], got {self.target_accuracy}")

        # Validate datasets (an empty list would only fail much later, at load time)
        if not self.datasets:
            raise ValueError("datasets must contain at least one dataset")
        self._DATASET_ENUM.validate(self.datasets)

        # Validate hyperparameters
        if self.batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {self.batch_size}")
        if self.epochs < 1:
            raise ValueError(f"epochs must be >= 1, got {self.epochs}")
        if self.lr <= 0:
            raise ValueError(f"lr must be > 0, got {self.lr}")
        if not 0 <= self.label_smoothing <= 0.5:
            raise ValueError(f"label_smoothing must be in [0, 0.5], got {self.label_smoothing}")
        if self.focal_gamma < 0:
            raise ValueError(f"focal_gamma must be >= 0, got {self.focal_gamma}")
        if self.r_drop_alpha < 0:
            raise ValueError(f"r_drop_alpha must be >= 0, got {self.r_drop_alpha}")
        if not 0 <= self.min_genre_acc <= 1:
            raise ValueError(f"min_genre_acc must be in [0, 1], got {self.min_genre_acc}")
        if self.gradient_accumulation_steps < 1:
            raise ValueError(
                f"gradient_accumulation_steps must be >= 1, got {self.gradient_accumulation_steps}"
            )

        if self.hard_negative_topk < 0:
            raise ValueError(f"hard_negative_topk must be >= 0, got {self.hard_negative_topk}")
        if self.max_candidates < 2:
            raise ValueError(f"max_candidates must be >= 2, got {self.max_candidates}")

        # Validate fold selection (5 PDNC folds, indices 0-4)
        if isinstance(self.fold_selection, list):
            if not self.fold_selection:
                raise ValueError("fold_selection list must not be empty")
            # bool is a subclass of int; reject it explicitly (True would mean fold 1).
            if not all(
                isinstance(f, int) and not isinstance(f, bool) and 0 <= f < 5
                for f in self.fold_selection
            ):
                raise ValueError(f"fold_selection list must contain integers in [0, 4], got {self.fold_selection}")
        elif self.fold_selection != "all":
            raise ValueError(f"fold_selection must be 'all' or list of ints, got {self.fold_selection}")


# Configuration selection
TARGET_LEVEL = 1  # 1=PDNC, 2=multi-source
FOLD_SELECTION = [0,1]

# Presets keyed by TARGET_LEVEL. A level without a preset is still passed through
# TrainingConfig so that its __post_init__ raises (NotImplementedError for level 3).
_CONFIG_PRESETS = {
    1: dict(
        target_level=1,
        name='Target 1: DeBERTa-large + Augmentation',
        target_accuracy=0.85,
        epochs=6,  # CURSOR: Reduced from 15 to fit 5 folds in 30 hours on 2xT4
        batch_size=4,
        lr=5e-6,
        datasets=['pdnc'],
        fold_selection=FOLD_SELECTION,
        use_augmentation=True,
        use_curriculum=True,
        focal_gamma=2.0,
        label_smoothing=0.1,
        hard_negative_topk=2,
        calibrate_temperature=True,
        balance_genres=False,
    ),
    2: dict(
        target_level=2,
        name='Target 2: Multi-Source + Genre Balancing',
        target_accuracy=0.88,
        epochs=8,  # CURSOR: Reduced from 15 for reasonable training time
        batch_size=4,
        lr=2e-6,
        datasets=['pdnc', 'litbank', 'directquote'],
        fold_selection=FOLD_SELECTION,
        use_augmentation=True,
        use_curriculum=True,
        focal_gamma=2.0,
        label_smoothing=0.1,
        hard_negative_topk=2,
        calibrate_temperature=True,
        balance_genres=True,
        min_genre_acc=0.75,
    ),
}
_config_kwargs = _CONFIG_PRESETS.get(
    TARGET_LEVEL,
    dict(
        target_level=TARGET_LEVEL,
        name='Target 3: Ensemble (Not Implemented)',
        target_accuracy=0.90,
    ),
)
CONFIG = TrainingConfig(**_config_kwargs)

# Environment-specific overrides: data path, accumulation and cadence come from ENV.
multi_source_base = f"{REPO_DIR}/data"
CONFIG.multi_source_base = multi_source_base
for _cfg_field, _env_key in (
    ("gradient_accumulation_steps", "grad_accum"),
    ("checkpoint_every", "checkpoint_every"),
    ("eval_every", "eval_every"),
):
    setattr(CONFIG, _cfg_field, ENV[_env_key])
CONFIG.output_dir = f"{OUTPUT_ROOT}/target_{TARGET_LEVEL}"
os.makedirs(CONFIG.output_dir, exist_ok=True)

for _summary_line in (
    f"Selected: {CONFIG.name}",
    f"Target accuracy: {CONFIG.target_accuracy:.0%}",
    f"Datasets: {CONFIG.datasets}",
    f"Output dir: {CONFIG.output_dir}",
    f"RUN_ENV: {RUN_ENV} | checkpoint_every={CONFIG.checkpoint_every} | grad_accum={CONFIG.gradient_accumulation_steps}",
):
    print(_summary_line)


In [None]:
# =============================================================================
# AUTO-DOWNLOAD AND PREPARE ALL DATASETS
# =============================================================================

from data.multi_source_data import download_datasets

# Fetch/prepare only the datasets selected in CONFIG; the result is kept in
# `downloaded_datasets` for later cells.
_download_root = CONFIG.multi_source_base
_download_names = CONFIG.datasets
downloaded_datasets = download_datasets(base_path=_download_root, datasets=_download_names)


In [None]:
# =============================================================================
# GPU SETUP - DataParallel for multi-GPU on Kaggle (no process spawning)
# =============================================================================

import os
import torch

NUM_GPUS = torch.cuda.device_count()
print(f"[SEARCH] Detected {NUM_GPUS} GPU(s)")
for gpu_idx in range(NUM_GPUS):
    mem_gb = torch.cuda.get_device_properties(gpu_idx).total_memory / 1024**3
    print(f"   GPU {gpu_idx}: {torch.cuda.get_device_name(gpu_idx)} ({mem_gb:.1f} GB)")

# CURSOR: DataParallel keeps multi-GPU training inside one process, which sidesteps
# the DDP/DeepSpeed process-spawning problems of notebook environments.
USE_DATA_PARALLEL = NUM_GPUS > 1

# Effective batch = per-device batch x data-parallel replicas x accumulation steps.
_replicas = NUM_GPUS if USE_DATA_PARALLEL else 1
_effective_batch = CONFIG.batch_size * _replicas * CONFIG.gradient_accumulation_steps
print(
    "\n[OK] Multi-GPU training via DataParallel (single process)"
    if USE_DATA_PARALLEL
    else "\n[OK] Single-GPU training mode"
)
print(f"   Effective batch: {CONFIG.batch_size} x {_replicas} x {CONFIG.gradient_accumulation_steps} = {_effective_batch}")


In [None]:
import glob, random, numpy as np, pandas as pd
from pathlib import Path
from collections import defaultdict
from dataclasses import dataclass
from typing import Optional, Dict, Any, Tuple
import time
import csv
import logging
import sys
from datetime import timedelta

# CURSOR: Multi-GPU stability env vars.
# NOTE(review): torch was already imported in the setup cell, so these are NOT
# pre-import settings. NCCL reads them when a communicator is created, so they only
# matter if a distributed process group is initialised later. Confirm they are still
# needed with the single-process DataParallel path.
import os
os.environ["NCCL_P2P_DISABLE"] = "1"           # CURSOR: Disable P2P for T4 compatibility
os.environ["NCCL_IB_DISABLE"] = "1"            # CURSOR: Disable InfiniBand (not available on Kaggle)
os.environ["TOKENIZERS_PARALLELISM"] = "false" # CURSOR: Avoid tokenizer fork issues with DDP

import torch
import torch.nn as nn
import torch.nn.functional as F
# NOTE: this rebinds `Dataset`, shadowing the dataset-name enum from the config cell.
from torch.utils.data import Dataset
# CURSOR: Avoid TF/Keras imports; notebook uses PyTorch-only Trainer.
# Must precede the transformers import below.
os.environ.setdefault("USE_TF", "0")

# CURSOR: Configure logging to display in Jupyter/Kaggle notebooks
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    datefmt='%H:%M:%S',
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True  # CURSOR: Force reconfiguration even if logging was already set up
)

# CURSOR: Configure tqdm for notebook display
from tqdm.auto import tqdm
# NOTE(review): tqdm.auto already selects the notebook frontend; confirm anything reads this variable.
os.environ["TQDM_NOTEBOOK"] = "1"

# CURSOR: Force HuggingFace Hub to show download progress.
# huggingface_hub reads this variable once, when it is first imported (transformers
# pulls it in). It must therefore be set before the transformers import below.
# Previously it was set afterwards and had no effect. It is still a no-op if an
# earlier cell already imported huggingface_hub.
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "0"

from transformers import Trainer, TrainingArguments, EvalPrediction
from transformers import logging as hf_logging
hf_logging.set_verbosity_info()  # CURSOR: Ensure transformers logs are visible
hf_logging.enable_progress_bar()  # CURSOR: Enable progress bar in transformers

print("[OK] Logging configured - you should see output below during training")
from sklearn.metrics import accuracy_score

from data.multi_source_data import MultiSourceDataLoader
from data.data_augmentation import QuoteAugmenter
from data.curriculum_loader import DifficultyClassifier, CurriculumSampler, CurriculumConfig
from evaluation.confidence_calibration import TemperatureScaling
from evaluation.error_analysis import ErrorAnalyzer
from optimization.post_processing import PostProcessor
from models.max_performance_model import MaxPerformanceSpeakerModel
from losses.focal_loss import CombinedLoss

# CURSOR: wandb is optional. When it is missing, metrics go to CSV only and the
# WANDB_AVAILABLE flag tells later cells to skip wandb calls.
try:
    import wandb
except ImportError:
    WANDB_AVAILABLE = False
    print("[WARN]  wandb not available, logging to CSV only")
else:
    WANDB_AVAILABLE = True

# CURSOR: Seed every RNG the notebook uses so runs are reproducible.
def set_seed(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU, plus all CUDA devices if present).

    This only seeds the generators. It does not switch cuDNN into deterministic
    mode, so GPU kernels may still introduce run-to-run variation.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed every RNG from the config, then report the GPU count and precision mode.
set_seed(seed=CONFIG.seed)
print("GPUs: {} | FP16: {}".format(NUM_GPUS, CONFIG.fp16))


In [None]:
# =============================================================================
# DATASET CLASS AND DATA LOADING
# =============================================================================

from data.training_samples import finalize_candidate_sets, derive_quote_span

class QuoteDataset(Dataset):
    """
    Dataset for quote attribution with candidate masking.

    Handles tokenization, optional augmentation, and mask generation. Each item
    is laid out as

        [PAR] context-with-[QUOTE]-markers [ALTQUOTE] [PAR] cand_1 [PAR] cand_2 ...

    and then truncated/padded to ``max_length``. ``quote_mask`` marks the tokens
    between the two [QUOTE] markers (or the best-effort fallback described in
    ``_build_quote_mask``). ``candidate_masks`` holds one mask per candidate,
    covering that candidate's tokens.
    """
    def __init__(self, samples, tokenizer, max_length=512, augment=False, augmenter: Optional["QuoteAugmenter"] = None):
        self.samples, self.tok, self.max_len = samples, tokenizer, max_length
        self.augment = augment
        self.augmenter = augmenter
        # Ids of the custom marker tokens. Assumes the markers were added to the
        # tokenizer vocabulary beforehand (TODO confirm); otherwise they may all
        # resolve to the same unknown-token id.
        self.par_id = self.tok.convert_tokens_to_ids("[PAR]")
        self.altq_id = self.tok.convert_tokens_to_ids("[ALTQUOTE]")
        self.quote_id = self.tok.convert_tokens_to_ids("[QUOTE]")

    def __len__(self):
        return len(self.samples)

    def _maybe_augment_text(self, sample: dict) -> str:
        """Return the sample text, augmented about half the time when augmentation is enabled.

        The quote span, derived from the ORIGINAL text, is passed to the augmenter
        as a protected span. Any augmenter failure falls back to the unaugmented
        text, on purpose, so one bad sample cannot stop training.
        """
        text = sample.get('text', '')
        if not self.augment or not self.augmenter:
            return text

        # Apply augmentation with 50% probability
        if random.random() > 0.5:
            return text

        quote = sample.get('quote', '')
        q_start = sample.get('quote_start')
        q_end = sample.get('quote_end')
        if quote and (q_start is None or q_end is None):
            q_start, q_end = derive_quote_span(text, quote)

        protected = []
        if q_start is not None and q_end is not None and q_start >= 0 and q_end > q_start:
            protected = [(int(q_start), int(q_end))]

        augmented = text
        try:
            if random.random() < 0.4:
                n_synonyms = random.randint(3, 5)
                augmented = self.augmenter.synonym_replace(augmented, protected_spans=protected, n=n_synonyms)
            if random.random() < 0.25:
                augmented = self.augmenter.random_insert(augmented, protected_spans=protected, n=1)
            if random.random() < 0.2:
                augmented = self.augmenter.random_swap(augmented, protected_spans=protected, n=1)
            if random.random() < 0.15:
                augmented = self.augmenter.random_delete(augmented, protected_spans=protected, p=0.03)
            return augmented
        except Exception:
            return text

    def _ensure_quote_markers(self, text: str, quote: str, q_start: Optional[int], q_end: Optional[int]) -> str:
        """Wrap the quote in ``text`` with " [QUOTE] " markers.

        Uses the character offsets when they are a valid span of ``text``. Otherwise
        it looks for the quote as a substring (case-sensitive, then case-insensitive).
        As a last resort it appends the marked quote to the end of the text.
        """
        if not quote:
            return text

        if q_start is not None and q_end is not None and 0 <= q_start < q_end <= len(text):
            return text[:q_start] + " [QUOTE] " + text[q_start:q_end] + " [QUOTE] " + text[q_end:]

        # Best-effort substring match
        pos = text.find(quote)
        if pos < 0:
            pos = text.lower().find(quote.lower())
        if pos >= 0:
            end = pos + len(quote)
            return text[:pos] + " [QUOTE] " + text[pos:end] + " [QUOTE] " + text[end:]

        # Fallback: append quote with markers
        return text + " [QUOTE] " + quote + " [QUOTE]"

    def _truncate_around_quote(self, base_ids: list, room: int) -> list:
        """Cut ``base_ids`` to at most ``room`` tokens, keeping the [QUOTE]...[QUOTE] span.

        When both markers are present, the window is roughly centred on the quote.
        If the span alone exceeds ``room``, the window starts at the opening marker.
        Without two markers, the head of the sequence is kept.
        """
        if len(base_ids) <= room:
            return base_ids

        qpos = [i for i, tid in enumerate(base_ids) if tid == self.quote_id]
        if len(qpos) >= 2:
            q0, q1 = qpos[0], qpos[1]
            span_len = (q1 - q0) + 1

            if span_len >= room:
                start = max(0, q0)
                end = min(len(base_ids), start + room)
                return base_ids[start:end]

            left_ctx = (room - span_len) // 2
            start = max(0, q0 - left_ctx)
            end = start + room
            if end <= q1:
                end = q1 + 1
                start = max(0, end - room)
            end = min(len(base_ids), end)
            start = max(0, end - room)
            return base_ids[start:end]

        return base_ids[:room]

    def _build_quote_mask(self, base_ids: list, quote: str) -> list:
        """Return a 0/1 mask over ``base_ids`` marking the quote tokens.

        Prefers the tokens strictly between the first two [QUOTE] markers. Next it
        tries the first exact token-subsequence match of the encoded quote. If both
        fail, the whole context is marked so the mask is never all zeros.
        """
        mask = [0] * len(base_ids)
        qpos = [i for i, tid in enumerate(base_ids) if tid == self.quote_id]
        if len(qpos) >= 2:
            for i in range(qpos[0] + 1, qpos[1]):
                mask[i] = 1
            return mask

        # Fallback: subsequence match for quote token ids
        if quote:
            q_ids = self.tok.encode(quote, add_special_tokens=False)
            if q_ids and len(q_ids) <= len(base_ids):
                for i in range(0, len(base_ids) - len(q_ids) + 1):
                    if base_ids[i:i+len(q_ids)] == q_ids:
                        for j in range(i, i + len(q_ids)):
                            mask[j] = 1
                        break

        if sum(mask) == 0:
            mask = [1] * len(base_ids)
        return mask

    def _encode(self, sample):
        """Build ``(input_ids, attention_mask, quote_mask, candidate_masks)`` lists of length ``max_len``."""
        # CURSOR: Augment context only; keep quote span intact via protected spans.
        raw_text = sample.get('text', '')
        context_text = self._maybe_augment_text(sample)
        quote = sample.get('quote', '')

        q_start = sample.get('quote_start')
        q_end = sample.get('quote_end')
        # Stored offsets index into the raw text. Augmentation (synonyms, insertions,
        # deletions) can shift characters before the quote. Once the text has changed,
        # the offsets are stale: slicing with them would put the [QUOTE] markers (and
        # hence quote_mask) on the wrong tokens. So the span is re-derived from the
        # text actually being encoded.
        text_changed = context_text != raw_text
        if quote and (q_start is None or q_end is None or text_changed):
            q_start, q_end = derive_quote_span(context_text, quote)

        base_text = self._ensure_quote_markers(context_text, quote, q_start, q_end)
        base_ids = self.tok.encode(base_text, add_special_tokens=False)

        candidates = sample['candidates']
        cand_ids = [self.tok.encode(c, add_special_tokens=False) for c in candidates]

        # Reserve room for the leading [PAR], the [ALTQUOTE] separator, and one
        # [PAR] + tokens per candidate; the context gets what is left (at least 8).
        reserved = 1 + 1 + sum(1 + len(ci) for ci in cand_ids)
        room = max(self.max_len - reserved, 8)
        base_ids = self._truncate_around_quote(base_ids, room)

        base_quote_mask = self._build_quote_mask(base_ids, quote)

        tokens = [self.par_id] + base_ids + [self.altq_id]
        quote_mask = [0] + base_quote_mask + [0]

        # Track candidate start/end positions
        cand_spans = []
        for ci in cand_ids:
            tokens.append(self.par_id)
            quote_mask.append(0)
            start = len(tokens)
            tokens.extend(ci)
            quote_mask.extend([0] * len(ci))
            end = len(tokens)
            cand_spans.append((start, end))

        if not cand_spans:
            tokens.append(self.par_id)
            quote_mask.append(0)
            cand_spans.append((len(tokens), len(tokens)))

        # Create masks with uniform length
        final_len = len(tokens)
        cand_masks = []
        for start, end in cand_spans:
            mask = [0] * final_len
            for i in range(start, min(end, final_len)):
                mask[i] = 1
            cand_masks.append(mask)

        # Extend quote_mask to match final token length
        if len(quote_mask) < final_len:
            quote_mask += [0] * (final_len - len(quote_mask))

        # Truncate to max_len (can drop trailing candidate tokens if candidates are
        # very long), then right-pad.
        tokens = tokens[: self.max_len]
        quote_mask = quote_mask[: self.max_len]
        attention = [1] * len(tokens)
        if len(tokens) < self.max_len:
            pad_len = self.max_len - len(tokens)
            tokens += [self.tok.pad_token_id] * pad_len
            attention += [0] * pad_len
            quote_mask += [0] * pad_len
            cand_masks = [cm + [0] * pad_len for cm in cand_masks]
        else:
            cand_masks = [cm[: self.max_len] for cm in cand_masks]

        return tokens, attention, quote_mask, cand_masks

    def __getitem__(self, idx):
        """Return tensors for one sample; a negative ``gold_index`` becomes label -100 (ignored)."""
        sample = self.samples[idx]
        tokens, attention, quote_mask, cand_masks = self._encode(sample)
        label_idx = sample['gold_index'] if sample['gold_index'] >= 0 else -100
        return {
            'input_ids': torch.tensor(tokens, dtype=torch.long),
            'attention_mask': torch.tensor(attention, dtype=torch.long),
            'quote_mask': torch.tensor(quote_mask, dtype=torch.long),
            'candidate_masks': [torch.tensor(cm, dtype=torch.long) for cm in cand_masks],
            'label_idx': torch.tensor(label_idx, dtype=torch.long),
            'quote_id': sample['quote_id']
        }


# =============================================================================
# DATA LOADING HELPERS WITH ERROR HANDLING
# =============================================================================

# CURSOR: Candidate construction/shuffling is handled by finalize_candidate_sets(...) for all datasets.


def _convert_multi_source_sample(s, idx, source='unknown'):
    """Convert MultiSourceDataLoader sample format to QuoteDataset format."""
    gold = s.get('speaker', '') or s.get('gold', '')
    text = s.get('text') or ''
    quote = s.get('quote') or ''
    quote_start = s.get('quote_start')
    quote_end = s.get('quote_end')

    if not text:
        text = quote or ''

    source = s.get('source', source)
    genre = s.get('genre', source)
    book_id = s.get('book_id', '')

    if not gold or not text:
        return None

    qid = f"{source}:{book_id}:{idx}" if book_id else f"{source}:{idx}"

    out = {
        'quote_id': qid,
        'text': text,
        'quote': quote,
        'quote_start': quote_start,
        'quote_end': quote_end,
        'gold': gold,
        'genre': genre,
        'source': source,
        'book_id': book_id,
    }

    # CURSOR: Preserve provided candidate sets (e.g., PDNC) when available.
    if isinstance(s.get('candidates'), list) and s.get('candidates'):
        out['candidates'] = s.get('candidates')
        out['gold_index'] = s.get('gold_index', -1)

    return out


class PDNCFoldIterator:
    """Lazy leave-book-out fold provider for PDNC.

    Books (not individual quotes) are shuffled with ``seed`` and dealt round-robin
    into ``n_folds`` folds. For fold ``k``:
    - the test split is fold ``k``;
    - validation is fold ``(k + 1) % n_folds``;
    - training is every remaining fold.
    No book therefore appears in more than one split. Raw samples are read at most
    once, either during pre-flight validation or on the first ``load_fold`` call,
    and cached on the instance.
    """

    def __init__(self, base_path: str, n_folds: int = 5, seed: int = 42, validate_on_init: bool = True):
        # Test and validation each take one fold, so at least one more is needed
        # for training; with fewer folds every train split would be empty.
        if n_folds < 3:
            raise ValueError(
                f"n_folds must be >= 3 (test, val and at least one train fold), got {n_folds}"
            )
        self.base_path = base_path
        self.n_folds = n_folds
        self.seed = seed
        self._book_assignments = None
        self._all_samples = None
        # Samples read by the pre-flight check; reused by _load_pdnc_data so the
        # dataset is not read from disk twice.
        self._preflight_samples = None

        # CURSOR: Pre-flight validation - catch data issues early before training starts
        if validate_on_init:
            self._preflight_validation()

    def _read_pdnc_samples(self) -> list:
        """Read all PDNC samples through MultiSourceDataLoader, flattened across genres."""
        loader = MultiSourceDataLoader(base_path=self.base_path, datasets=['pdnc'], seed=self.seed)
        loader.load_all()
        samples = []
        for genre_samples in loader.data_by_genre.values():
            samples.extend(genre_samples)
        return samples

    def _preflight_validation(self):
        """Validate that the data is reachable and splittable before training starts.

        Raises:
            RuntimeError: if the base path or the 'pdnc' subdirectory is missing,
                the loader returns no samples, or there are fewer books than folds.
                Errors raised while loading are wrapped in RuntimeError.
        """
        print("   [SEARCH] Running pre-flight data validation...")

        # Check base path exists
        from pathlib import Path
        base = Path(self.base_path)
        if not base.exists():
            raise RuntimeError(
                f"Data path does not exist: {self.base_path}. "
                "Run dataset download cell first."
            )

        # Check PDNC subdirectory exists
        pdnc_path = base / 'pdnc'
        if not pdnc_path.exists():
            raise RuntimeError(
                f"PDNC dataset not found at {pdnc_path}. "
                "Ensure dataset download completed successfully."
            )

        # Load the data once to validate format and fold viability
        try:
            samples = self._read_pdnc_samples()
            if not samples:
                raise RuntimeError("PDNC loader returned 0 samples - check dataset format")

            # Check book distribution for fold viability
            by_book = defaultdict(int)
            for s in samples:
                by_book[s.get('book_id', 'unknown')] += 1

            n_books = len(by_book)
            if n_books < self.n_folds:
                raise ValueError(
                    f"Only {n_books} books found, but {self.n_folds} folds requested. "
                    "Reduce n_folds or check dataset."
                )

            print(f"   [OK] Pre-flight passed: {len(samples):,} samples from {n_books} books")

        except Exception as e:
            raise RuntimeError(
                f"Pre-flight data validation failed: {e}. "
                "Check dataset download and format."
            ) from e

        self._preflight_samples = samples

    def _load_pdnc_data(self):
        """Load PDNC samples (once) and assign books to folds."""
        if self._all_samples is not None:
            return

        if self._preflight_samples is not None:
            all_samples = self._preflight_samples
            self._preflight_samples = None
        else:
            all_samples = self._read_pdnc_samples()

        if not all_samples:
            raise RuntimeError(
                f"Failed to load PDNC data from {self.base_path}. "
                "Check that dataset download completed successfully."
            )

        self._all_samples = all_samples

        # Group samples by book_id for leave-book-out cross-validation
        by_book = defaultdict(list)
        for i, s in enumerate(all_samples):
            book_id = s.get('book_id', 'unknown')
            by_book[book_id].append((i, s))

        book_ids = sorted(by_book.keys())
        n_books = len(book_ids)

        if n_books < self.n_folds:
            raise ValueError(
                f"Only {n_books} books found, but {self.n_folds} folds requested. "
                f"Reduce n_folds or check dataset."
            )

        # Assign books to folds. A private RNG gives the same shuffle as seeding the
        # global `random` module, without resetting global RNG state as a side effect.
        shuffled_books = book_ids.copy()
        random.Random(self.seed).shuffle(shuffled_books)

        fold_book_assignments = [[] for _ in range(self.n_folds)]
        for i, book_id in enumerate(shuffled_books):
            fold_book_assignments[i % self.n_folds].append(book_id)

        self._book_assignments = {
            'fold_books': fold_book_assignments,
            'by_book': by_book,
            'all_books': shuffled_books
        }

    def load_fold(self, fold_idx: int, other_datasets: list = None) -> Tuple[list, list, list]:
        """Load one fold on demand.

        Args:
            fold_idx: Fold to use as the test split, in ``[0, n_folds - 1]``.
            other_datasets: Optional non-PDNC dataset names whose train/val/test
                splits (from ``load_datasets``) are appended to the PDNC splits.

        Returns:
            ``(train_samples, val_samples, test_samples)`` in QuoteDataset format.

        Raises:
            ValueError: if ``fold_idx`` is out of range.
            RuntimeError: if any PDNC split ends up empty.
        """
        # Validate fold index
        if fold_idx < 0 or fold_idx >= self.n_folds:
            raise ValueError(
                f"Fold index {fold_idx} out of range [0, {self.n_folds-1}]"
            )

        # Ensure data is loaded
        self._load_pdnc_data()

        # Get book assignments for this fold
        fold_books = self._book_assignments['fold_books']
        by_book = self._book_assignments['by_book']

        test_books = set(fold_books[fold_idx])
        val_books = set(fold_books[(fold_idx + 1) % self.n_folds])

        train_samples, val_samples, test_samples = [], [], []

        for book_id, samples_list in by_book.items():
            for i, s in samples_list:
                converted = _convert_multi_source_sample(s, i, 'pdnc')
                if converted is None:
                    continue
                if book_id in test_books:
                    test_samples.append(converted)
                elif book_id in val_books:
                    val_samples.append(converted)
                else:
                    train_samples.append(converted)

        # Validate splits
        if not train_samples or not val_samples or not test_samples:
            raise RuntimeError(
                f"Fold {fold_idx} has empty splits: "
                f"train={len(train_samples)}, val={len(val_samples)}, test={len(test_samples)}"
            )

        # Combine with other datasets if provided
        if other_datasets:
            print(f"   + Adding datasets: {other_datasets}")
            other_train, other_val, other_test = load_datasets(self.base_path, other_datasets)
            train_samples = train_samples + other_train
            val_samples = val_samples + other_val
            test_samples = test_samples + other_test

        return train_samples, val_samples, test_samples


def load_datasets(base_path: str, datasets: list):
    """
    Load ``datasets`` via MultiSourceDataLoader and convert each split into
    the training-sample format.

    Args:
        base_path: Root directory handed to MultiSourceDataLoader.
        datasets: Dataset names to load; must be non-empty.

    Returns:
        Tuple ``(train, val, test)`` of converted sample lists. Samples for
        which ``_convert_multi_source_sample`` returns None are dropped.

    Raises:
        ValueError: If ``datasets`` is empty.
        RuntimeError: If no training sample survives conversion.
    """
    if not datasets:
        raise ValueError("datasets list is empty")

    loader = MultiSourceDataLoader(base_path=base_path, datasets=datasets, seed=CONFIG.seed)
    loader.load_all()

    # Split with the loader module's split_by_genre (val_ratio=0.1, test_ratio=0.1).
    raw_train, raw_val, raw_test = loader.split_by_genre(val_ratio=0.1, test_ratio=0.1)

    def _convert(raw_split):
        # Positional index restarts at 0 for every split.
        kept = []
        for position, raw in enumerate(raw_split):
            sample = _convert_multi_source_sample(raw, position)
            if sample is not None:
                kept.append(sample)
        return kept

    train_converted = _convert(raw_train)
    val_converted = _convert(raw_val)
    test_converted = _convert(raw_test)

    # Only the training split must be non-empty at this stage.
    if not train_converted:
        raise RuntimeError(f"No training samples loaded from datasets: {datasets}")

    return train_converted, val_converted, test_converted


In [None]:
# =============================================================================
# PREPARE FOLD LOADING
# =============================================================================

import numbers

# Number of PDNC cross-validation folds. Used both for "all" and for the fold
# iterator so the two can never disagree.
N_PDNC_FOLDS = 5

datasets_to_load = CONFIG.datasets
use_pdnc_folds = 'pdnc' in datasets_to_load
other_datasets = [d for d in datasets_to_load if d != 'pdnc']

if use_pdnc_folds:
    fold_selection = CONFIG.fold_selection

    if fold_selection == "all":
        FOLDS_TO_TRAIN = list(range(N_PDNC_FOLDS))
    elif isinstance(fold_selection, (list, tuple)):
        # Validate up front: a bad index would otherwise only fail inside
        # load_fold(), possibly after earlier folds have trained for hours.
        invalid_folds = [
            f for f in fold_selection
            if isinstance(f, bool)
            or not isinstance(f, numbers.Integral)
            or not 0 <= f < N_PDNC_FOLDS
        ]
        if invalid_folds or not fold_selection:
            raise ValueError(
                f"Invalid fold_selection: {fold_selection}. Must be 'all' or a non-empty "
                f"list of ints in [0, {N_PDNC_FOLDS - 1}]."
            )
        # De-duplicate while keeping order so no fold is trained twice.
        FOLDS_TO_TRAIN = list(dict.fromkeys(int(f) for f in fold_selection))
    else:
        raise ValueError(f"Invalid fold_selection: {fold_selection}. Must be 'all' or list of ints.")

    print(f"[NOTE] PDNC Folds to train: {FOLDS_TO_TRAIN}")

    # Create lazy fold iterator (fold data is loaded on demand per fold)
    fold_iterator = PDNCFoldIterator(
        base_path=CONFIG.multi_source_base,
        n_folds=N_PDNC_FOLDS,
        seed=CONFIG.seed
    )
    print("   [OK] Lazy fold iterator created (folds load on-demand)")
else:
    FOLDS_TO_TRAIN = [0]  # Single iteration for multi-dataset
    fold_iterator = None
    print(f"[NOTE] Training with datasets: {datasets_to_load}")

print(f"[LOOP] Will train {len(FOLDS_TO_TRAIN)} fold(s): {FOLDS_TO_TRAIN}")


In [None]:
# CURSOR: No-op placeholder (do not delete the training repo mid-session).

In [None]:
# =============================================================================
# TRAINING LOGGER AND CHECKPOINT MANAGEMENT
# =============================================================================

class TrainingLogger:
    """Append training metrics to a long-format CSV, optionally mirroring them to wandb.

    The CSV lives at ``<log_dir>/training_log.csv`` with one row per metric:
    ``timestamp, fold, epoch, step, metric, value``. The header is written only
    when the file does not exist yet, so several runs/folds append to one file.

    wandb is used only when a non-empty ``config`` is given, the environment
    variable ``WANDB_MODE`` equals ``"online"`` and ``WANDB_AVAILABLE`` is true.
    Any wandb failure degrades to CSV-only logging instead of raising.
    """

    CSV_HEADER = ['timestamp', 'fold', 'epoch', 'step', 'metric', 'value']

    def __init__(self, log_dir: str, run_name: str, config: dict = None):
        self.log_dir = Path(log_dir)
        self.log_dir.mkdir(parents=True, exist_ok=True)
        self.csv_path = self.log_dir / "training_log.csv"
        self.run_name = run_name

        # Initialize CSV (header only for a brand-new file)
        if not self.csv_path.exists():
            with open(self.csv_path, 'w', newline='', encoding='utf-8') as f:
                csv.writer(f).writerow(self.CSV_HEADER)

        self.wandb_run = None
        # CURSOR: wandb is opt-in to avoid hanging on interactive login prompts.
        # Set WANDB_MODE=online (and log in beforehand) to enable it.
        # Cheap env/config checks come first; WANDB_AVAILABLE is only consulted
        # when wandb is actually requested.
        if config and os.environ.get("WANDB_MODE") == "online" and WANDB_AVAILABLE:
            try:
                self.wandb_run = wandb.init(
                    project="quote-attribution-training",
                    name=run_name,
                    config=config,
                    reinit=True,
                    # Bound the init wait so an unreachable server cannot hang the notebook.
                    settings=wandb.Settings(init_timeout=30),
                )
                print("[OK] wandb logging enabled")
            except Exception as e:
                print(f"[WARN]  wandb init failed (using CSV only): {e}")
        else:
            print("[INFO]  wandb disabled (set WANDB_MODE=online to enable)")

    def log(self, metrics: dict, fold: int = 0, epoch: int = 0, step: int = 0):
        """Append one CSV row per metric and mirror the dict to wandb if enabled."""
        timestamp = time.time()

        with open(self.csv_path, 'a', newline='', encoding='utf-8') as f:
            csv.writer(f).writerows(
                [timestamp, fold, epoch, step, metric_name, value]
                for metric_name, value in metrics.items()
            )

        if self.wandb_run is not None:
            try:
                # Log through this logger's own run, not the global wandb state.
                self.wandb_run.log({**metrics, 'fold': fold, 'epoch': epoch, 'step': step})
            except Exception as e:
                # Keep training alive, but report once and stop retrying wandb.
                print(f"[WARN]  wandb log failed, continuing with CSV only: {e}")
                self.wandb_run = None

    def finish(self):
        """Close the wandb run (if any). Safe to call more than once."""
        if self.wandb_run is not None:
            try:
                self.wandb_run.finish()
            except Exception as e:
                print(f"[WARN]  wandb finish failed: {e}")
            finally:
                self.wandb_run = None


def _list_step_checkpoints(fold_output_dir: str) -> list:
    """Return ``[(step, Path), ...]`` for ``checkpoint-<step>`` directories, oldest first.

    Only directories named exactly ``checkpoint-`` followed by ASCII digits are
    considered, so stray entries such as ``checkpoint-tmp`` or a plain file
    cannot break the numeric sort.
    """
    import re

    fold_path = Path(fold_output_dir)
    if not fold_path.is_dir():
        return []

    step_re = re.compile(r"checkpoint-([0-9]+)")
    found = []
    for entry in fold_path.iterdir():
        match = step_re.fullmatch(entry.name)
        if match and entry.is_dir():
            found.append((int(match.group(1)), entry))
    found.sort(key=lambda item: item[0])
    return found


def find_latest_checkpoint(fold_output_dir: str) -> Optional[str]:
    """Return the newest resumable checkpoint directory for a fold, or None.

    A checkpoint counts as resumable only if it contains ``trainer_state.json``
    (the HF Trainer state file read on resume). A directory left half-written
    by an interrupted save is skipped in favour of the next older one.
    """
    for _step, checkpoint_dir in reversed(_list_step_checkpoints(fold_output_dir)):
        if (checkpoint_dir / "trainer_state.json").is_file():
            return str(checkpoint_dir)
    return None


def cleanup_old_checkpoints(fold_output_dir: str, keep_last: int = 2):
    """
    Delete all but the newest ``keep_last`` ``checkpoint-<step>`` directories.

    Runs at fold boundaries to avoid training slowdown. Only step checkpoints
    are touched; ``best_model`` and any other entry in the fold directory are
    left alone. ``keep_last <= 0`` removes every step checkpoint.
    """
    import shutil

    checkpoints = _list_step_checkpoints(fold_output_dir)
    keep = max(0, keep_last)
    if len(checkpoints) <= keep:
        return  # Nothing to clean

    # Slice by explicit length: checkpoints[:-0] would be empty and delete nothing.
    for _step, checkpoint_path in checkpoints[:len(checkpoints) - keep]:
        try:
            shutil.rmtree(checkpoint_path)
            print(f"   [CLEAN]  Cleaned old checkpoint: {checkpoint_path.name}")
        except OSError as e:
            print(f"   [WARN]  Failed to delete {checkpoint_path}: {e}")


# =============================================================================
# TRAINING SETUP AND HELPERS
# =============================================================================

# CURSOR: Gradient checkpointing is driven by CONFIG; it is needed to fit
# DeBERTa-large on T4 GPUs. The fold loop reads this flag when building the model.
USE_GRADIENT_CHECKPOINTING = CONFIG.gradient_checkpointing
print("[CONFIG] Gradient checkpointing: {}".format(USE_GRADIENT_CHECKPOINTING))

# Collate function for variable-length candidate masks
def collate_fn(batch):
    """Stack per-sample tensors and pad ``candidate_masks`` to the widest sample.

    Each padded candidate slot is an all-zero mask and is marked 0 in
    ``candidate_attention_mask``. A sample with no candidates gets a single
    all-zero mask that is still marked valid (1), so every row keeps at least
    one "real" slot.
    """
    widest = max(len(item['candidate_masks']) for item in batch)

    stacked = {
        key: torch.stack([item[key] for item in batch])
        for key in ('input_ids', 'attention_mask', 'quote_mask')
    }

    padded_masks = []
    validity_rows = []
    for item in batch:
        cands = item['candidate_masks']
        n_real = len(cands) or 1
        if not cands:
            cands = [torch.zeros_like(item['input_ids'])]
        n_pad = widest - n_real
        if n_pad > 0:
            cands = cands + [torch.zeros_like(cands[0])] * n_pad
        padded_masks.append(torch.stack(cands))
        validity_rows.append(torch.tensor([1] * n_real + [0] * n_pad, dtype=torch.long))

    stacked['candidate_masks'] = torch.stack(padded_masks)
    stacked['candidate_attention_mask'] = torch.stack(validity_rows)
    stacked['labels'] = torch.stack([item['label_idx'] for item in batch])
    return stacked

# CURSOR: Custom callback for visible progress output in notebooks
from transformers import TrainerCallback

class ProgressPrintCallback(TrainerCallback):
    """Print training progress to stdout so it is visible in notebook output.

    Args:
        print_every: Minimum number of global steps between progress lines.
    """

    def __init__(self, print_every=50):
        self.print_every = print_every
        self.last_print_step = 0

    @staticmethod
    def _latest_logged(log_history, key, default=0):
        """Return the most recent value logged under ``key``, or ``default``.

        The last history entry is often an evaluation record (``eval_*`` keys
        only), so reading just ``log_history[-1]`` would report 0 for the
        training loss / learning rate after every evaluation.
        """
        for entry in reversed(log_history or []):
            if key in entry:
                return entry[key]
        return default

    def on_step_end(self, args, state, control, **kwargs):
        if state.global_step - self.last_print_step >= self.print_every:
            loss = self._latest_logged(state.log_history, 'loss')
            lr = self._latest_logged(state.log_history, 'learning_rate')
            epoch = state.epoch or 0
            total_steps = state.max_steps or 1
            pct = (state.global_step / total_steps) * 100
            print(f"   [STATS] Step {state.global_step}/{total_steps} ({pct:.1f}%) | Epoch {epoch:.2f} | Loss: {loss:.4f} | LR: {lr:.2e}", flush=True)
            self.last_print_step = state.global_step

    def on_epoch_end(self, args, state, control, **kwargs):
        epoch = int(state.epoch) if state.epoch else 0
        print(f"   [LOOP] Epoch {epoch} complete", flush=True)

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        if metrics:
            acc = metrics.get('eval_accuracy', 0)
            loss = metrics.get('eval_loss', 0)
            print(f"   [GRAPH] EVAL @ step {state.global_step}: accuracy={acc:.4f}, loss={loss:.4f}", flush=True)


# Custom Trainer for our model
class QuoteAttributionTrainer(Trainer):
    """Trainer for the candidate-scoring speaker model.

    Differences from the stock HF ``Trainer``:
      * the model is called positionally with the tensors built by
        ``collate_fn`` and returns ``(logits, _)``;
      * padded candidate slots (``candidate_attention_mask == 0``) are filled
        with a large negative logit before the loss;
      * the training sampler can be curriculum- and/or genre-balanced,
        depending on ``CONFIG.use_curriculum`` / ``CONFIG.balance_genres``.

    Args:
        loss_fn: Optional loss module, called as ``loss_fn(logits, labels)``
            or, when it exposes ``r_drop_alpha > 0``, as
            ``loss_fn(logits1, labels, logits2=logits2)``. When None,
            ``nn.CrossEntropyLoss(ignore_index=-100)`` is used.
        training_logger: Stored on the instance; not used by this class itself.
        **kwargs: Forwarded to ``transformers.Trainer``.
    """

    # Fill value for padded candidate logits. It is applied in float32 because
    # -1e9 does not fit in float16 and masked_fill raises on overflow.
    _MASKED_LOGIT = -1e9

    def __init__(self, loss_fn=None, training_logger=None, **kwargs):
        super().__init__(**kwargs)
        self.custom_loss_fn = loss_fn
        self.ce_loss = nn.CrossEntropyLoss(ignore_index=-100)
        self.training_logger = training_logger

    @staticmethod
    def _genre_balance_weights(samples):
        """Inverse-frequency weights so each genre gets equal total sampling mass."""
        from collections import Counter
        genres = [s.get('genre', 'unknown') for s in samples]
        counts = Counter(genres)
        total = len(genres)
        n_genres = max(1, len(counts))
        return [total / (n_genres * counts[g]) for g in genres]

    def _get_train_sampler(self, train_dataset: Optional[Dataset] = None):
        """Choose the training sampler according to CONFIG.

        Precedence: curriculum sampler (with genre weights when
        ``balance_genres`` is set) > genre-weighted random sampler > default
        HF sampler. Datasets without a ``samples`` attribute always get the
        default sampler.
        """
        train_dataset = train_dataset or self.train_dataset
        if train_dataset is None or not hasattr(train_dataset, 'samples'):
            return super()._get_train_sampler(train_dataset)

        samples = train_dataset.samples

        if CONFIG.use_curriculum:
            classifier = DifficultyClassifier()
            difficulty_indices = classifier.classify_dataset(samples)
            sample_weights = self._genre_balance_weights(samples) if CONFIG.balance_genres else None

            # CURSOR: Starting epoch is read from trainer state when the sampler is built.
            # NOTE(review): the sampler is created together with the train dataloader;
            # whether it advances on later epochs depends on CurriculumSampler
            # (e.g. a set_epoch hook) — confirm.
            current_epoch = int(self.state.epoch) if self.state.epoch is not None else 0
            return CurriculumSampler(
                difficulty_indices=difficulty_indices,
                config=CurriculumConfig(),
                total_epochs=CONFIG.epochs,
                current_epoch=current_epoch,
                batch_size=self.args.per_device_train_batch_size,
                seed=CONFIG.seed,
                sample_weights=sample_weights,
            )

        if CONFIG.balance_genres:
            from torch.utils.data import WeightedRandomSampler
            weights = self._genre_balance_weights(samples)
            return WeightedRandomSampler(weights=weights, num_samples=len(weights), replacement=True)

        return super()._get_train_sampler(train_dataset)

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        """Candidate-level loss (optionally with R-Drop) for one batch.

        ``labels`` is popped from ``inputs``. Returns the loss, or
        ``(loss, {'logits': logits})`` when ``return_outputs`` is set; with
        R-Drop the returned logits are the mean of the two forward passes.
        """
        labels = inputs.pop('labels')

        def forward_once():
            logits, _ = model(
                inputs['input_ids'],
                inputs['attention_mask'],
                inputs['quote_mask'],
                inputs['candidate_masks'],
                inputs['candidate_attention_mask']
            )
            if inputs['candidate_attention_mask'] is not None:
                # Upcast before masking: under fp16/AMP the logits are float16 and
                # masked_fill(-1e9) raises "value cannot be converted ... without
                # overflow". Computing the loss in float32 is also more stable.
                logits = logits.float().masked_fill(
                    inputs['candidate_attention_mask'] == 0, self._MASKED_LOGIT
                )
            return logits

        logits1 = forward_once()
        logits_out = logits1

        # CURSOR: R-Drop performs a second forward pass (different dropout mask).
        if self.custom_loss_fn is not None and getattr(self.custom_loss_fn, 'r_drop_alpha', 0) > 0:
            logits2 = forward_once()
            loss = self.custom_loss_fn(logits1, labels, logits2=logits2)
            logits_out = (logits1 + logits2) / 2
        elif self.custom_loss_fn is not None:
            loss = self.custom_loss_fn(logits1, labels)
        else:
            loss = self.ce_loss(logits1, labels)

        return (loss, {'logits': logits_out}) if return_outputs else loss

# Metrics function
def compute_metrics(eval_pred: "EvalPrediction") -> Dict[str, float]:
    """Candidate-selection accuracy for ``Trainer.evaluate``.

    Args:
        eval_pred: Object exposing ``predictions`` (candidate logits; argmax is
            taken over the last axis) and ``label_ids`` (gold index; negative
            values such as -100 are ignored). Either may be a tuple/list, in
            which case its first element is used.

    Returns:
        ``{'accuracy': float}``. It is ``0.0`` when no example has a valid label,
        instead of the NaN produced by averaging an empty array.
    """
    logits, labels = eval_pred.predictions, eval_pred.label_ids
    # HF may hand back tuples when the model produces extra outputs.
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    if isinstance(labels, (tuple, list)):
        labels = labels[0]

    preds = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)
    valid = labels >= 0
    if not valid.any():
        return {'accuracy': 0.0}
    return {'accuracy': float((preds[valid] == labels[valid]).mean())}

# Loss function - built from the project's CombinedLoss module
def get_loss_fn():
    """Construct the training loss from CONFIG.

    Returns a ``losses.focal_loss.CombinedLoss`` configured with CONFIG's focal
    gamma, label smoothing and R-Drop weight. R-Drop is flagged on exactly when
    ``CONFIG.r_drop_alpha > 0``, and ``ignore_index=-100`` is forwarded.
    """
    from losses.focal_loss import CombinedLoss

    loss_settings = {
        'focal_gamma': CONFIG.focal_gamma,
        'label_smoothing': CONFIG.label_smoothing,
        'r_drop_alpha': CONFIG.r_drop_alpha,
        'use_r_drop': CONFIG.r_drop_alpha > 0,
        'ignore_index': -100,
    }
    return CombinedLoss(**loss_settings)


# Make sure the run-level output directory exists before any fold writes into it.
os.makedirs(CONFIG.output_dir, exist_ok=True)

print("[OK] Training helpers ready!")

In [None]:
# =============================================================================
# MULTI-FOLD TRAINING LOOP
# =============================================================================

import gc
import matplotlib.pyplot as plt
import seaborn as sns

# CURSOR: Print immediately so user knows cell is running
print("=" * 70, flush=True)
print(f"[START] MULTI-FOLD TRAINING: {CONFIG.name}", flush=True)
print("   [WAIT] Initializing training logger...", flush=True)

# Initialize training logger (wandb init can be slow)
training_logger = TrainingLogger(
    log_dir=CONFIG.output_dir,
    run_name=f"target_{CONFIG.target_level}_{int(time.time())}",
    config={
        'target_level': CONFIG.target_level,
        'epochs': CONFIG.epochs,
        'batch_size': CONFIG.batch_size,
        'lr': CONFIG.lr,
        'datasets': CONFIG.datasets,
        'num_gpus': NUM_GPUS,
    }
)
print("   [OK] Training logger ready", flush=True)

# Track results across all folds
fold_results = {}
all_fold_accuracies = []
fold_timings = []
print(f"   Folds to train: {FOLDS_TO_TRAIN}", flush=True)

# With DataParallel, CONFIG.batch_size is the TOTAL batch split across GPUs
# (the fold loop sets per-device batch = batch_size // NUM_GPUS), so the
# effective batch must not be multiplied by NUM_GPUS again. Without
# DataParallel, the fold loop may shrink the per-step batch and raise grad
# accumulation to compensate; exact per-fold values are printed there.
if USE_DATA_PARALLEL:
    print(f"   GPUs: {NUM_GPUS} (DataParallel) | Total batch: {CONFIG.batch_size} "
          f"(~{CONFIG.batch_size // max(1, NUM_GPUS)}/GPU)", flush=True)
else:
    print(f"   GPUs: {NUM_GPUS} | Configured batch: {CONFIG.batch_size}", flush=True)
print(f"   Nominal effective batch: {CONFIG.batch_size * CONFIG.gradient_accumulation_steps} "
      f"(grad accum {CONFIG.gradient_accumulation_steps})", flush=True)
print("=" * 70, flush=True)

for fold_pos, fold_idx in enumerate(FOLDS_TO_TRAIN, start=1):
    fold_start_time = time.time()
    
    # fold_pos counts folds in this run; fold_idx is the fold id itself
    # (e.g. FOLDS_TO_TRAIN=[3, 4] prints "FOLD 1/2 (index=3)", not "FOLD 4/2").
    print(f"\n{'='*70}")
    print(f"[DIR] FOLD {fold_pos}/{len(FOLDS_TO_TRAIN)} (index={fold_idx})")
    print(f"{'='*70}")
    
    # Fold-specific output directory
    fold_output_dir = f"{CONFIG.output_dir}/fold_{fold_idx}"
    os.makedirs(fold_output_dir, exist_ok=True)
    
    # Check if fold already completed (for resume support)
    best_model_path = f"{fold_output_dir}/best_model"
    model_file = Path(best_model_path) / "pytorch_model.bin"
    model_safetensors = Path(best_model_path) / "model.safetensors"
    
    if Path(best_model_path).exists() and (model_file.exists() or model_safetensors.exists()):
        # CURSOR: Validate model is actually loadable before skipping.
        # NOTE(review): config.json is only written when the saved model provides
        # save_pretrained(); for a plain nn.Module this check may never pass, in
        # which case the fold is re-run from its latest checkpoint - confirm.
        try:
            from transformers import AutoConfig
            config_path = Path(best_model_path) / "config.json"
            if config_path.exists():
                AutoConfig.from_pretrained(best_model_path)
                print(f"[SKIP]  Fold {fold_idx} already trained and validated, skipping")
                print(f"   Model exists at: {best_model_path}")
                continue
            else:
                print(f"[WARN]  Fold {fold_idx} model incomplete (no config.json), re-training...")
        except Exception as e:
            print(f"[WARN]  Fold {fold_idx} model validation failed: {e}")
            print(f"   Re-training fold {fold_idx}...")
    
    # Load data for this fold (lazy loading)
    if use_pdnc_folds:
        print(f"   [WAIT] Loading PDNC fold {fold_idx}...", flush=True)
        load_start = time.time()
        train_samples, val_samples, test_samples = fold_iterator.load_fold(fold_idx, other_datasets)
        print(f"   [OK] Data loaded in {time.time() - load_start:.1f}s", flush=True)
    else:
        print(f"   [WAIT] Loading datasets: {datasets_to_load}...", flush=True)
        load_start = time.time()
        train_samples, val_samples, test_samples = load_datasets(CONFIG.multi_source_base, datasets_to_load)
        print(f"   [OK] Data loaded in {time.time() - load_start:.1f}s", flush=True)
    
    # CURSOR: Build realistic candidate sets and eliminate candidate-order label leakage.
    print(f"   [WAIT] Finalizing candidate sets...", flush=True)
    cand_start = time.time()
    train_samples = finalize_candidate_sets(
        train_samples,
        seed=CONFIG.seed + fold_idx,
        hard_negative_topk=CONFIG.hard_negative_topk,
        max_candidates=CONFIG.max_candidates,
        shuffle_candidates=CONFIG.shuffle_candidates,
    )
    print(f"      Train candidates done ({len(train_samples)} samples)", flush=True)
    val_samples = finalize_candidate_sets(
        val_samples,
        seed=CONFIG.seed + fold_idx,
        hard_negative_topk=CONFIG.hard_negative_topk,
        max_candidates=CONFIG.max_candidates,
        shuffle_candidates=CONFIG.shuffle_candidates,
    )
    print(f"      Val candidates done ({len(val_samples)} samples)", flush=True)
    test_samples = finalize_candidate_sets(
        test_samples,
        seed=CONFIG.seed + fold_idx,
        hard_negative_topk=CONFIG.hard_negative_topk,
        max_candidates=CONFIG.max_candidates,
        shuffle_candidates=CONFIG.shuffle_candidates,
    )
    print(f"   [OK] Candidate sets finalized in {time.time() - cand_start:.1f}s", flush=True)

    # CURSOR: Quick leakage check (gold_index should not collapse to a constant).
    gi_preview = [s.get('gold_index', -1) for s in train_samples[:2000]]
    if gi_preview and len(set(gi_preview)) == 1:
        print(f"[WARN]  gold_index collapsed to {gi_preview[0]} for preview set; check candidate construction.")

    if not train_samples or not val_samples or not test_samples:
        raise RuntimeError(
            f"Empty split after candidate finalization: train={len(train_samples)}, val={len(val_samples)}, test={len(test_samples)}"
        )

    print(f"   Train: {len(train_samples)} | Val: {len(val_samples)} | Test: {len(test_samples)}")
    
    # Create fresh model for each fold (important for proper cross-validation)
    print(f"   [WAIT] Initializing fresh model ({CONFIG.base_model})...", flush=True)
    
    model_start = time.time()
    set_seed(CONFIG.seed + fold_idx)  # Different seed per fold
    model = MaxPerformanceSpeakerModel(CONFIG.base_model)
    if USE_GRADIENT_CHECKPOINTING:
        # Enabled directly on the encoder; the Trainer flag is kept off below.
        model.encoder.gradient_checkpointing_enable()
    tokenizer = model.get_tokenizer()
    print(f"   [OK] Model initialized in {time.time() - model_start:.1f}s", flush=True)
    
    # Create augmenter if enabled
    augmenter = QuoteAugmenter(seed=CONFIG.seed + fold_idx) if CONFIG.use_augmentation else None
    
    # CURSOR: Curriculum scheduling is handled by a sampler in the Trainer (no pre-sorting).
    
    # Create datasets
    print(f"   [WAIT] Creating datasets (tokenizing {len(train_samples)} train + {len(val_samples)} val samples)...", flush=True)
    ds_start = time.time()
    train_dataset = QuoteDataset(
        train_samples, tokenizer, CONFIG.max_length,
        augment=CONFIG.use_augmentation, augmenter=augmenter
    )
    val_dataset = QuoteDataset(val_samples, tokenizer, CONFIG.max_length)
    print(f"   [OK] Datasets created in {time.time() - ds_start:.1f}s", flush=True)
    
    # Check for resume checkpoint
    resume_checkpoint = find_latest_checkpoint(fold_output_dir)
    if resume_checkpoint:
        print(f"[LOOP] Resuming fold {fold_idx} from {resume_checkpoint}")
    
    # CURSOR: Wrap model with DataParallel for multi-GPU (single process)
    if USE_DATA_PARALLEL:
        model = nn.DataParallel(model)
        print(f"   [OK] Model wrapped with DataParallel for {NUM_GPUS} GPUs")
    
    # CURSOR: Set batch size - with DataParallel, batch is split across GPUs automatically
    import math
    if USE_DATA_PARALLEL:
        # CURSOR: Total batch across all GPUs
        effective_batch = CONFIG.batch_size 
        effective_grad_accum = CONFIG.gradient_accumulation_steps
    else:
        # CURSOR: Single GPU
        if USE_GRADIENT_CHECKPOINTING:
            effective_batch = min(CONFIG.batch_size, 4)
        else:
            effective_batch = 2
        effective_batch = max(1, int(effective_batch))
        effective_grad_accum = max(1, int(CONFIG.gradient_accumulation_steps * math.ceil(CONFIG.batch_size / effective_batch)))
    
    print(f"   Batch total: {effective_batch} | Grad accum: {effective_grad_accum} | Effective: {effective_batch * effective_grad_accum}")
    
    # Per-device batch; never 0 even if CONFIG.batch_size < NUM_GPUS.
    per_device_batch = max(1, effective_batch // NUM_GPUS) if USE_DATA_PARALLEL else effective_batch
    
    # Training arguments for this fold
    training_args = TrainingArguments(
        output_dir=fold_output_dir,
        num_train_epochs=CONFIG.epochs,
        per_device_train_batch_size=per_device_batch,
        per_device_eval_batch_size=per_device_batch,
        gradient_accumulation_steps=effective_grad_accum,
        # The encoder is already checkpointed above. Letting the Trainer do it
        # calls self.model.gradient_checkpointing_enable(), which a DataParallel
        # wrapper (and possibly the custom model) lacks -> AttributeError.
        gradient_checkpointing=False,
        learning_rate=CONFIG.lr,
        weight_decay=0.01,
        fp16=CONFIG.fp16,
        fp16_full_eval=CONFIG.fp16,  # Half-precision evaluation to save memory
        warmup_ratio=0.1,
        lr_scheduler_type="cosine",
        
        # Evaluation and saving
        eval_strategy="steps",
        eval_steps=CONFIG.eval_every,
        save_strategy="steps",
        save_steps=CONFIG.checkpoint_every,
        save_total_limit=3,  # Keep 3 checkpoints for safe resume
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        
        # The model is called positionally and may be wrapped in DataParallel,
        # so the Trainer cannot infer the label key from forward(). Name it
        # explicitly so evaluation computes loss and returns label_ids.
        label_names=["labels"],
        
        # Logging - frequent updates for real-time monitoring
        logging_steps=10,  # Log every 10 steps for real-time feedback
        logging_first_step=True,
        logging_strategy="steps",
        report_to="none",  # Custom logging to CSV + optional wandb
        disable_tqdm=False,  # Enable progress bars
        
        # Performance
        dataloader_num_workers=0,
        dataloader_pin_memory=True,
        remove_unused_columns=False,
        
        # CURSOR: No DDP/DeepSpeed - using DataParallel instead
        
        # Seed
        seed=CONFIG.seed + fold_idx,
    )
    
    # Create trainer for this fold
    trainer = QuoteAttributionTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collate_fn,
        compute_metrics=compute_metrics,
        loss_fn=get_loss_fn(),
        training_logger=training_logger,
        callbacks=[ProgressPrintCallback(print_every=50)],  # CURSOR: Print progress every 50 steps
    )
    
    # Train this fold
    print(f"\n   [TRAIN] Training fold {fold_idx}...", flush=True)
    print(f"   [GO]  GPU training starting NOW - you should see progress bars below!", flush=True)
    print(f"   " + "="*50, flush=True)
    train_result = trainer.train(resume_from_checkpoint=resume_checkpoint)
    print(f"   " + "="*50, flush=True)
    
    # Evaluate on validation set
    print(f"\n   [STATS] Evaluating fold {fold_idx}...")
    eval_results = trainer.evaluate()
    fold_accuracy = eval_results.get('eval_accuracy', 0.0)
    all_fold_accuracies.append(fold_accuracy)
    
    # Save best model for this fold
    trainer.save_model(best_model_path)
    print(f"   [SAVE] Model saved: {best_model_path}")
    
    # Detailed evaluation with modules
    print(f"\n   [SEARCH] Running detailed evaluation...")
    
    # 1. Confidence calibration
    if CONFIG.calibrate_temperature:
        try:
            temp_scaler = TemperatureScaling()
            # Get validation predictions for calibration; upcast in case
            # fp16_full_eval returned float16 logits.
            val_preds = trainer.predict(val_dataset)
            val_logits = torch.from_numpy(val_preds.predictions).float()
            val_labels = torch.from_numpy(val_preds.label_ids)
            temp_scaler.calibrate(val_logits, val_labels)
            optimal_temp = temp_scaler.get_temperature()
            print(f"   [MEASURE] Optimal temperature: {optimal_temp:.3f}")
            training_logger.log({'temperature': optimal_temp}, fold=fold_idx)
        except Exception as e:
            print(f"   [WARN]  Temperature calibration failed: {e}")
    
    # 2. Error analysis (sample N=1000 for speed)
    # Reset per-fold analysis outputs so a failure here can never let the
    # post-processing step below reuse predictions from the previous fold.
    analysis_samples = None
    test_preds = None
    test_predictions = None
    test_confidences = None
    try:
        error_analyzer = ErrorAnalyzer()
        analysis_samples = test_samples[:1000]
        test_preds = trainer.predict(QuoteDataset(analysis_samples, tokenizer, CONFIG.max_length))
        test_predictions = np.argmax(test_preds.predictions, axis=-1)
        # Upcast: with fp16_full_eval the logits are float16, and CPU softmax
        # on half tensors is not supported by older PyTorch builds.
        test_confidences = torch.softmax(
            torch.from_numpy(test_preds.predictions).float(), dim=-1
        ).numpy().max(axis=-1)
        
        # CURSOR: Add errors one by one using the correct ErrorAnalyzer API
        for i, (pred, label, sample, conf) in enumerate(zip(
            test_predictions, test_preds.label_ids, analysis_samples, test_confidences
        )):
            if pred != label and label >= 0:
                candidates = sample.get('candidates', [])
                pred_speaker = candidates[pred] if 0 <= pred < len(candidates) else 'unknown'
                actual_speaker = candidates[label] if 0 <= label < len(candidates) else 'unknown'
                error_analyzer.add_error(
                    sample_id=sample.get('quote_id', str(i)),
                    text=sample.get('text', ''),
                    quote=sample.get('text', '')[:200],
                    predicted=pred_speaker,
                    actual=actual_speaker,
                    candidates=candidates,
                    confidence=float(conf),
                    genre=sample.get('genre', 'unknown')
                )
        
        error_summary = error_analyzer.get_summary()
        top_patterns = error_analyzer.get_top_error_patterns(n=3)
        
        print(f"   [GRAPH] Error analysis (N={len(analysis_samples)}):")
        print(f"      Total errors: {error_summary['total_errors']}")
        print(f"      High confidence errors: {error_summary['high_confidence_errors']:.1%}")
        for error_type, count, pct in top_patterns:
            print(f"      {error_type.value}: {count} ({pct:.1%})")
        
        training_logger.log({
            'error_count': error_summary['total_errors'],
            'high_confidence_error_rate': error_summary['high_confidence_errors'],
        }, fold=fold_idx)
    except Exception as e:
        print(f"   [WARN]  Error analysis failed: {e}")
    
    # 3. Post-processing impact (if implemented)
    if test_preds is None or test_predictions is None or test_confidences is None:
        print("   [WARN]  Post-processing test skipped: no test predictions for this fold")
    else:
        try:
            post_processor = PostProcessor()
            post_processor.reset_context()
            
            # CURSOR: Build proper input format for PostProcessor.batch_process()
            pp_samples = analysis_samples[:100]
            pp_preds = test_predictions[:100]
            pp_confs = test_confidences[:100]
            pp_labels = test_preds.label_ids[:100]
            n_pp = len(pp_samples)  # May be < 100 for small test splits
            
            # Convert to format expected by PostProcessor
            pp_input = []
            for i, (sample, pred, conf) in enumerate(zip(pp_samples, pp_preds, pp_confs)):
                candidates = sample.get('candidates', [])
                pred_speaker = candidates[pred] if 0 <= pred < len(candidates) else 'unknown'
                pp_input.append({
                    'speaker': pred_speaker,
                    'confidence': float(conf),
                    'quote': sample.get('text', '')[:200],
                    'context': sample.get('text', ''),
                    'candidates': candidates,
                    'position': i * 100,
                })
            
            pp_results = post_processor.batch_process(pp_input)
            
            # Count modifications that land on the gold speaker
            n_improved = 0
            for i, (result, label) in enumerate(zip(pp_results, pp_labels)):
                if result.was_modified and label >= 0:
                    candidates = pp_samples[i].get('candidates', [])
                    actual_speaker = candidates[label] if 0 <= label < len(candidates) else ''
                    if result.speaker == actual_speaker:
                        n_improved += 1
            
            print(f"   [FIX] Post-processing improved: {n_improved}/{n_pp} samples")
            training_logger.log(
                {'post_process_improvement': n_improved / n_pp if n_pp else 0.0},
                fold=fold_idx,
            )
        except Exception as e:
            print(f"   [WARN]  Post-processing test failed: {e}")
    
    # Log fold results
    fold_elapsed = time.time() - fold_start_time
    fold_timings.append(fold_elapsed)
    
    fold_results[fold_idx] = {
        'accuracy': fold_accuracy,
        'train_loss': train_result.training_loss,
        'model_path': best_model_path,
        'training_time': fold_elapsed,
    }
    
    training_logger.log({
        'fold_accuracy': fold_accuracy,
        'fold_train_loss': train_result.training_loss,
        'fold_time_seconds': fold_elapsed,
    }, fold=fold_idx)
    
    print(f"\n   [OK] Fold {fold_idx} complete! ({fold_elapsed/60:.1f} minutes)")
    print(f"      Accuracy: {fold_accuracy:.4f}")
    
    # Clean up old checkpoints at fold boundary
    cleanup_old_checkpoints(fold_output_dir, keep_last=2)
    
    # Clean up GPU memory before next fold
    del model, trainer, train_dataset, val_dataset
    gc.collect()
    torch.cuda.empty_cache()

# =============================================================================
# TRAINING SUMMARY AND VISUALIZATION
# =============================================================================
print(f"\n{'='*70}")
print("[WIN] MULTI-FOLD TRAINING COMPLETE!")
print(f"{'='*70}")

print(f"\n[STATS] Results per fold:")
for fold_idx, results in fold_results.items():
    print(f"   Fold {fold_idx}: Accuracy = {results['accuracy']:.4f} ({results['training_time']/60:.1f} min)")

if len(all_fold_accuracies) > 1:
    mean_acc = np.mean(all_fold_accuracies)
    std_acc = np.std(all_fold_accuracies)
    print(f"\n[GRAPH] Cross-validation summary:")
    print(f"   Mean accuracy: {mean_acc:.4f} +/- {std_acc:.4f}")
    print(f"   Min: {min(all_fold_accuracies):.4f} | Max: {max(all_fold_accuracies):.4f}")
    
    training_logger.log({
        'cv_mean_accuracy': mean_acc,
        'cv_std_accuracy': std_acc,
        'cv_min_accuracy': min(all_fold_accuracies),
        'cv_max_accuracy': max(all_fold_accuracies),
    })
elif all_fold_accuracies:
    print(f"\n[GRAPH] Single fold accuracy: {all_fold_accuracies[0]:.4f}")
else:
    # Every requested fold was skipped (e.g. a saved best_model already existed),
    # so there is no new accuracy to report; avoid indexing an empty list.
    print("\n[GRAPH] No folds were trained in this session (all requested folds were skipped).")

print(f"\n[TIME]  Total training time: {sum(fold_timings)/3600:.2f} hours")
print(f"[FILE] Models saved to: {CONFIG.output_dir}/fold_*/best_model")

# Generate training summary visualization. This is best-effort: any failure is
# reported as a warning and does not stop the rest of the cell.
print("\n[STATS] Generating training visualization...")
fig = None
try:
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Plot 1: Accuracy per fold
    folds = sorted(fold_results.keys())
    accs = [fold_results[f]['accuracy'] for f in folds]
    axes[0, 0].bar(folds, accs, color='steelblue')
    if len(accs) > 1:
        axes[0, 0].axhline(np.mean(accs), color='red', linestyle='--', label=f'Mean: {np.mean(accs):.4f}')
        axes[0, 0].legend()
    axes[0, 0].set_xlabel('Fold')
    axes[0, 0].set_ylabel('Accuracy')
    axes[0, 0].set_title('Accuracy by Fold')
    # Keep the 0.7 floor for readability, but lower it when a fold scored
    # below 0.7 so that fold's bar is not clipped out of the plot.
    y_floor = min([0.7] + [acc - 0.05 for acc in accs])
    axes[0, 0].set_ylim([max(0.0, y_floor), 1.0])

    # Plot 2: Training time per fold (training_time treated as seconds -> minutes)
    times = [fold_results[f]['training_time']/60 for f in folds]
    axes[0, 1].bar(folds, times, color='coral')
    axes[0, 1].set_xlabel('Fold')
    axes[0, 1].set_ylabel('Time (minutes)')
    axes[0, 1].set_title('Training Time by Fold')

    # Plot 3: Read training log for loss curves. Assumes the CSV has
    # 'metric', 'fold', 'step' and 'value' columns -- TODO confirm against
    # the training_logger implementation.
    if training_logger.csv_path.exists():
        log_df = pd.read_csv(training_logger.csv_path)
        loss_df = log_df[log_df['metric'] == 'loss']
        if not loss_df.empty:
            for fold in folds:
                fold_loss = loss_df[loss_df['fold'] == fold]
                if not fold_loss.empty:
                    axes[1, 0].plot(fold_loss['step'], fold_loss['value'], label=f'Fold {fold}', alpha=0.7)
            axes[1, 0].set_xlabel('Step')
            axes[1, 0].set_ylabel('Loss')
            axes[1, 0].set_title('Training Loss Curves')
            # Only draw a legend when at least one fold curve was plotted;
            # otherwise matplotlib warns about having no labelled artists.
            if axes[1, 0].get_legend_handles_labels()[0]:
                axes[1, 0].legend()

    # Plot 4: Summary statistics
    stats_text = f"""
Training Summary
{'='*30}
Target Level: {CONFIG.target_level}
Datasets: {', '.join(CONFIG.datasets)}
Folds Trained: {len(folds)}

Cross-Validation:
  Mean Accuracy: {np.mean(accs):.4f}
  Std Accuracy: {np.std(accs):.4f}
  Min/Max: {min(accs):.4f} / {max(accs):.4f}

Training Time:
  Total: {sum(fold_timings)/3600:.2f} hours
  Avg/Fold: {np.mean(fold_timings)/60:.1f} min
"""
    axes[1, 1].text(0.1, 0.5, stats_text, fontsize=10, family='monospace', 
                    verticalalignment='center')
    axes[1, 1].axis('off')

    plt.tight_layout()
    plot_path = f"{CONFIG.output_dir}/training_summary.png"
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    print(f"   [OK] Saved visualization: {plot_path}")
    plt.show()
except Exception as e:
    print(f"   [WARN]  Visualization failed: {e}")
finally:
    # Release the figure on both the success and failure paths so repeated
    # runs of this cell do not accumulate open figures.
    if fig is not None:
        plt.close(fig)

# Finalize the run: shut down the training logger first, then print the
# closing banner.
training_logger.finish()

print("\n" + "=" * 70)

print("[OK] Training complete! Models ready for inference.")
print("=" * 70)