# BookNLP Maximum Performance Quote Attribution (Unified)

**Goal**: Train the max-performance quote attribution model (target: 80–90% accuracy) in either Kaggle or Colab via one RUN_ENV-aware notebook.

**Features**
- DeBERTa-v3-large with quote/candidate masks + [QUOTE], [ALTQUOTE], [PAR]
- Candidate-level softmax with label smoothing; optional R-Drop; optional temperature scaling
- Optional multi-source loading + genre-balanced sampler; PDNC fallback; configurable hard negatives
- Curriculum sampler + light augmentation; gradient checkpointing + FP16
- Auto checkpoint/resume (model/optimizer/scheduler/best_acc) with cadence set by RUN_ENV; bucketed eval + placeholder postprocess hook

**Requirements**
- Kaggle: T4 x2 accelerator; storage under `/kaggle/working`
- Colab: T4 GPU; storage in Drive at `/content/drive`

**Quick start**
1) Choose the environment: set the `RUN_ENV` environment variable to `"kaggle"` or `"colab"`, or edit the default in the first code cell (default: kaggle).
2) Kaggle: no Drive mount; repo/output in `/kaggle/working`; multi-GPU via `accelerate`; checkpoints/evals every 500 steps; auto-resume from latest `checkpoint_*.pt`.
3) Colab: mounts Drive to `/content/drive`; repo in `/content`; outputs in Drive; single-GPU (no DDP) with gradient accumulation; checkpoints/evals every 300 steps; auto-resume from latest `checkpoint_*.pt`.
4) Run all cells—data is cloned automatically from the repo.



## RUN_ENV toggle
Set `RUN_ENV` to `"kaggle"` or `"colab"` in the next cell (default: kaggle); that cell reads the value from the `RUN_ENV` environment variable and falls back to the default written there. Paths, checkpoint cadence, and mounts adjust automatically. Kaggle uses multi-GPU via `accelerate` (no Drive mount); Colab mounts Drive and runs single-GPU with gradient accumulation.



In [None]:
import os, sys, torch

# CURSOR: Toggle once; everything else keys off this value.
# RUN_ENV comes from the RUN_ENV environment variable (default "kaggle");
# edit the default string here to switch environments from inside the notebook.
RUN_ENV = os.environ.get("RUN_ENV", "kaggle").strip().lower()
# Per-environment settings:
#   base_dir / output_root          : where outputs are written (a Drive folder on Colab)
#   repo_dir                        : where the code+data repository is cloned
#   checkpoint_every / eval_every   : step cadence for checkpointing and evaluation
#   grad_accum                      : gradient accumulation steps (larger for single-GPU Colab)
#   use_accelerate                  : build an accelerate.Accelerator (True) or use a plain device
#   mount_drive                     : mount Google Drive before touching base_dir
ENV_CFG = {
    "kaggle": {
        "base_dir": "/kaggle/working",
        "repo_dir": "/kaggle/working/speaker-attribution-acl2023",
        "output_root": "/kaggle/working",
        "checkpoint_every": 500,
        "eval_every": 500,
        "grad_accum": 4,
        "use_accelerate": True,
        "mount_drive": False,
    },
    "colab": {
        "base_dir": "/content/drive/MyDrive/quote_attribution",
        "repo_dir": "/content/speaker-attribution-acl2023",
        "output_root": "/content/drive/MyDrive/quote_attribution",
        "checkpoint_every": 300,
        "eval_every": 300,
        "grad_accum": 16,
        "use_accelerate": False,  # single-GPU path
        "mount_drive": True,
    },
}
assert RUN_ENV in ENV_CFG, f"Unsupported RUN_ENV: {RUN_ENV}"
ENV = ENV_CFG[RUN_ENV]

# Colab only: mount Drive, then make sure the output folder exists on it.
if ENV["mount_drive"]:
    from google.colab import drive
    drive.mount("/content/drive")
    os.makedirs(ENV["base_dir"], exist_ok=True)

# Report the runtime; abort immediately when no CUDA device is visible.
print(f"Python: {sys.version}")
print(f"PyTorch: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    print(f"GPUs: {torch.cuda.device_count()}")
    for i in range(torch.cuda.device_count()):
        print(f"  GPU {i}: {torch.cuda.get_device_name(i)}")
else:
    raise RuntimeError("GPU not available; enable a GPU runtime.")

# Clone the repository (code + PDNC data) once; later runs reuse the checkout.
# NOTE(review): `!git clone` is IPython shell syntax, so this cell only runs inside
# a notebook. Its exit status is not checked here; a failed clone surfaces as the
# FileNotFoundError raised by the DATA_DIR check below.
REPO_DIR = ENV["repo_dir"]
if not os.path.exists(REPO_DIR):
    print("\nðŸ“¥ Cloning repository with data...")
    !git clone https://github.com/Priya22/speaker-attribution-acl2023.git {REPO_DIR}
else:
    print(f"âœ… Repository present at {REPO_DIR}")

# Pre-split PDNC training data that lives inside the cloned repository.
DATA_DIR = f"{REPO_DIR}/training/data/pdnc"
if not os.path.exists(DATA_DIR):
    raise FileNotFoundError(f"PDNC data missing at {DATA_DIR}")
else:
    print(f"âœ… PDNC data found at {DATA_DIR}")

# Output locations (created if missing).
BASE_DIR = ENV["base_dir"]
OUTPUT_ROOT = ENV["output_root"]
os.makedirs(OUTPUT_ROOT, exist_ok=True)
print(f"Output root: {OUTPUT_ROOT}")


In [None]:
%pip install -q transformers>=4.30.0 accelerate>=0.20.0 datasets scikit-learn tqdm nlpaug nltk


In [None]:
# =============================================================================
# CONFIGURATION
# =============================================================================
TARGET_LEVEL = 1  # 1=PDNC, 2=multi-source, 3=ensemble placeholder

# CURSOR: For best generalization on unknown/new books, train all 5 folds
# Set to "all" for all folds, or list of fold indices [0, 1, 2, 3, 4] or [0, 2] etc.
FOLD_SELECTION = "all"  # "all" or list like [0, 1, 2] or [3]
SPLIT_TYPE = "leave-x-out"  # "leave-x-out" or "random"

# Per-target hyperparameters. Any key set here wins over FEATURE_DEFAULTS below.
CONFIGS = {
    1: {
        'name': 'Target 1: DeBERTa-large + Augmentation',
        'epochs': 50, 'batch_size': 8, 'lr': 5e-6,
        'use_augmentation': True, 'use_curriculum': True,
        'focal_gamma': 2.0, 'label_smoothing': 0.1, 'r_drop_alpha': 0.7,
        'target_accuracy': 0.85,
        'hard_negative_topk': 2,
        'calibrate_temperature': True,
        'use_multi_source': False,
        'use_postprocess': False,
        'balance_genres': False,
        'fold_selection': FOLD_SELECTION,
        'split_type': SPLIT_TYPE,
    },
    2: {
        'name': 'Target 2: Multi-Source + Genre Balancing',
        'epochs': 30, 'batch_size': 8, 'lr': 2e-6,
        'use_augmentation': True, 'use_curriculum': True,
        'balance_genres': True, 'min_genre_acc': 0.75,
        'target_accuracy': 0.88,
        'hard_negative_topk': 2,
        'calibrate_temperature': True,
        'use_multi_source': True,
        'use_postprocess': False,
        'fold_selection': FOLD_SELECTION,
        'split_type': SPLIT_TYPE,
    },
    3: {
        'name': 'Target 3: Ensemble + Distillation',
        'ensemble_models': ['microsoft/deberta-v3-large', 'roberta-large'],
        'student_model': 'microsoft/deberta-v3-base',
        'epochs': 10, 'batch_size': 4, 'lr': 5e-6,
        'distill_epochs': 10, 'temperature': 3.0, 'alpha': 0.7,
        'target_accuracy': 0.90,
        'use_multi_source': True,
        'use_augmentation': True,
        'use_curriculum': True,
        'balance_genres': True,
        'hard_negative_topk': 2,
        'calibrate_temperature': True,
        'use_postprocess': False,
        'fold_selection': FOLD_SELECTION,
        'split_type': SPLIT_TYPE,
    }
}

CONFIG = CONFIGS[TARGET_LEVEL].copy()

# Shared model / runtime / path settings, always applied on top of the target.
CONFIG.update({
    'base_model': 'microsoft/deberta-v3-large',
    'max_length': 512,
    'gradient_accumulation_steps': ENV['grad_accum'],
    'checkpoint_every': ENV['checkpoint_every'],  # CURSOR: env-specific cadence
    'eval_every': ENV['eval_every'],
    'fp16': True,
    'gradient_checkpointing': True,
    'seed': 42,
    'data_path': f'{REPO_DIR}/data/pdnc_source',
    'train_data_path': f'{REPO_DIR}/training/data/pdnc',
    'output_dir': f'{OUTPUT_ROOT}/target_{TARGET_LEVEL}',
    'multi_source_base': f'{REPO_DIR}/data',
    'use_accelerate': ENV['use_accelerate'],
})

# CURSOR: Feature toggles. These are *defaults only*: they are merged underneath
# CONFIG, so a key the selected target already sets keeps the target's value.
# Every target above sets 'use_postprocess': False, so post-processing stays off
# unless a target enables it. The remaining toggles are on unless a target
# overrides them.
FEATURE_DEFAULTS = {
    'use_combined_loss': True,
    'use_postprocess': True,
    'postprocess_confidence': 0.6,
    'use_ensemble_eval': True,
    'ensemble_model_names': ['microsoft/deberta-v3-large'],
    'ensemble_voting_strategy': 'weighted_average',
    'run_cross_domain_validation': True,
    'run_genre_adaptation': True,
    'run_error_analysis': True,
    'run_model_optimization': True,
    'optimize_quantize': True,
    'optimize_export_onnx': True,
}
CONFIG = {**FEATURE_DEFAULTS, **CONFIG}

os.makedirs(CONFIG['output_dir'], exist_ok=True)
print(f"Selected: {CONFIG['name']}")
print(f"Target accuracy: {CONFIG['target_accuracy']:.0%}")
print(f"Output dir: {CONFIG['output_dir']}")
print(f"RUN_ENV: {RUN_ENV} | checkpoint_every={CONFIG['checkpoint_every']} | grad_accum={CONFIG['gradient_accumulation_steps']}")


In [None]:
# CURSOR: Auto-download multi-source datasets when enabled
import shutil
import subprocess
from pathlib import Path

if not CONFIG.get("use_multi_source", False):
    print("Multi-source disabled; skipping extra dataset downloads.")
else:
    multi_base = Path(CONFIG["multi_source_base"])
    multi_base.mkdir(parents=True, exist_ok=True)

    def clone_if_missing(url: str, dest: Path) -> bool:
        """Shallow-clone *url* into *dest* unless *dest* already has content.

        Returns True when a clone was performed and False when an existing,
        non-empty *dest* was reused. A failing ``git clone`` raises
        ``subprocess.CalledProcessError`` (``check=True``).
        """
        if not (dest.exists() and any(dest.iterdir())):
            print(f"ðŸ“¥ Cloning {url} -> {dest}")
            subprocess.run(["git", "clone", "--depth", "1", url, str(dest)], check=True)
            return True
        print(f"âœ… Found {dest}")
        return False

    lit_dir = multi_base / "litbank"
    clone_if_missing("https://github.com/dbamman/litbank.git", lit_dir)

    dq_dir = multi_base / "directquote"
    if dq_dir.exists() and any(dq_dir.iterdir()):
        print(f"âœ… Found {dq_dir}")
    else:
        # DirectQuote is cloned into a scratch directory; only its JSON/JSONL
        # files are copied (flattened by basename) into dq_dir, then the
        # scratch clone is removed.
        tmp_repo = multi_base / "directquote_repo_tmp"
        if tmp_repo.exists():
            shutil.rmtree(tmp_repo, ignore_errors=True)
        clone_if_missing("https://github.com/THUNLP-MT/DirectQuote.git", tmp_repo)

        dq_dir.mkdir(parents=True, exist_ok=True)

        def _copy_flat(paths):
            """Copy every regular file in *paths* into dq_dir by basename; return the count."""
            count = 0
            for path in paths:
                if path.is_file():
                    shutil.copy2(path, dq_dir / path.name)
                    count += 1
            return count

        data_src = tmp_repo / "data"
        # Prefer JSON files anywhere under data/; fall back to the repo root.
        copied = _copy_flat(data_src.glob("**/*.json*")) if data_src.exists() else 0
        if copied == 0:
            copied = _copy_flat(tmp_repo.glob("*.json*"))

        shutil.rmtree(tmp_repo, ignore_errors=True)
        print(f"âœ… DirectQuote files copied: {copied}")



In [None]:
import glob, random, numpy as np, pandas as pd
from pathlib import Path
from collections import defaultdict

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.cuda.amp import GradScaler, autocast

from transformers import DebertaV2Model, DebertaV2Tokenizer
from accelerate import Accelerator
from sklearn.metrics import accuracy_score
from tqdm.auto import tqdm

from training.data.multi_source_data import MultiSourceDataLoader
from training.data.data_augmentation import QuoteAugmenter
from training.data.curriculum_loader import DifficultyClassifier, CurriculumSampler, CurriculumConfig
from training.evaluation.confidence_calibration import TemperatureScaling
from training.models.max_performance_model import MaxPerformanceSpeakerModel
from training.losses.focal_loss import CombinedLoss
from training.optimization.post_processing import PostProcessor
from training.evaluation.cross_domain_validation import CrossDomainValidator
from training.evaluation.genre_specific_adaptations import GenreSpecificAdaptation
from training.evaluation.error_analysis import ErrorAnalyzer
from training.models.ensemble import create_ensemble
from training.optimization.model_optimization import optimize_for_inference

# CURSOR: Deterministic setup for reproducibility

def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs (plus all CUDA devices when present)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

set_seed(CONFIG['seed'])

# With accelerate, the Accelerator owns device placement, mixed precision and
# gradient accumulation; otherwise use a single plain torch device.
if not CONFIG['use_accelerate']:
    accelerator = None
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
    accelerator = Accelerator(
        mixed_precision='fp16' if CONFIG['fp16'] else 'no',
        gradient_accumulation_steps=CONFIG['gradient_accumulation_steps'],
    )
    device = accelerator.device

print(f"Device: {device} | RUN_ENV={RUN_ENV} | accelerate={bool(accelerator)}")


In [None]:
# Use the richer model with span/mask support

class QuoteDataset(Dataset):
    """Tokenised quote-attribution examples with fixed-length quote/candidate masks.

    Each item is laid out as::

        [PAR] <text tokens> [ALTQUOTE] ([PAR] <candidate tokens>)* <pad>...

    Every returned sequence (``input_ids``, ``attention_mask``, ``quote_mask``
    and each entry of ``candidate_masks``) has exactly ``max_length`` elements,
    so items can be stacked by a collate function.

    Args:
        samples: list of dicts with ``text``, ``candidates``, ``gold_index``
            (``-1`` when the gold speaker is not among the candidates) and ``quote_id``.
        tokenizer: object exposing ``encode``, ``convert_tokens_to_ids`` and
            ``pad_token_id``; assumed to already contain the ``[PAR]`` and
            ``[ALTQUOTE]`` tokens (TODO confirm in the model's tokenizer setup).
        max_length: total sequence length after truncation and padding.
        augment: enable light synonym augmentation of ``text``.
        augmenter: augmentation backend; only used when ``augment`` is True.
    """

    def __init__(self, samples, tokenizer, max_length=512, augment=False, augmenter: "QuoteAugmenter | None" = None):
        self.samples, self.tok, self.max_len = samples, tokenizer, max_length
        self.augment = augment
        self.augmenter = augmenter
        self.par_id = self.tok.convert_tokens_to_ids("[PAR]")
        self.altq_id = self.tok.convert_tokens_to_ids("[ALTQUOTE]")

    def __len__(self):
        return len(self.samples)

    def _maybe_augment_text(self, text: str) -> str:
        """With probability 0.2, return a synonym-swapped copy of *text*; otherwise *text*."""
        if not self.augment or not self.augmenter:
            return text
        # CURSOR: Light, safe synonym swap
        if random.random() < 0.2:
            try:
                return self.augmenter.synonym_replace(text, protected_spans=[], n=2)
            except Exception:
                # Augmentation is best-effort: fall back to the original text.
                return text
        return text

    def _encode(self, sample):
        """Build fixed-length ``(tokens, attention, quote_mask, cand_masks)`` lists.

        The text is truncated (never below 8 tokens) so the markers and all
        candidates fit. If they still do not fit, the whole sequence is cut at
        ``max_length``, and candidate spans past the cut get partial or
        all-zero masks. All masks are built against the final, padded length.
        """
        base_text = self._maybe_augment_text(sample['text'])
        base_ids = self.tok.encode(base_text, add_special_tokens=False)
        cand_ids = [self.tok.encode(c, add_special_tokens=False) for c in sample['candidates']]

        # Budget: [PAR] + [ALTQUOTE] around the text, plus [PAR] + tokens per candidate.
        reserved = 2 + sum(1 + len(ci) for ci in cand_ids)
        room = max(self.max_len - reserved, 8)
        base_ids = base_ids[:room]

        tokens = [self.par_id] + base_ids + [self.altq_id]
        quote_span = (0, len(tokens))  # the quote mask also covers both markers

        # Record (start, end) spans first; masks are materialised only once the
        # final sequence length is known, so every mask has the same length.
        cand_spans = []
        for ci in cand_ids:
            tokens.append(self.par_id)
            start = len(tokens)
            tokens.extend(ci)
            cand_spans.append((start, len(tokens)))
        if not cand_spans:
            # Keep one all-zero candidate slot so callers always get >= 1 mask.
            tokens.append(self.par_id)
            cand_spans.append((0, 0))

        tokens = tokens[: self.max_len]
        n_real = len(tokens)
        pad_len = self.max_len - n_real
        tokens = tokens + [self.tok.pad_token_id] * pad_len
        attention = [1] * n_real + [0] * pad_len

        def span_mask(start, end):
            # Clamp to max_len so spans cut off by truncation shrink or vanish.
            end = min(end, self.max_len)
            start = min(start, end)
            return [0] * start + [1] * (end - start) + [0] * (self.max_len - end)

        quote_mask = span_mask(*quote_span)
        cand_masks = [span_mask(start, end) for start, end in cand_spans]
        return tokens, attention, quote_mask, cand_masks

    def __getitem__(self, idx):
        sample = self.samples[idx]
        tokens, attention, quote_mask, cand_masks = self._encode(sample)
        # -100 marks "gold not among candidates" so the loss can ignore the item.
        label_idx = sample['gold_index'] if sample['gold_index'] >= 0 else -100
        return {
            'input_ids': torch.tensor(tokens, dtype=torch.long),
            'attention_mask': torch.tensor(attention, dtype=torch.long),
            'quote_mask': torch.tensor(quote_mask, dtype=torch.long),
            'candidate_masks': [torch.tensor(cm, dtype=torch.long) for cm in cand_masks],
            'label_idx': torch.tensor(label_idx, dtype=torch.long),
            'quote_id': sample['quote_id']
        }


class FocalLoss(nn.Module):
    """Binary focal loss with symmetric label smoothing, computed from logits.

    Args:
        gamma: focusing parameter; the per-element loss is scaled by
            ``(1 - p_t) ** gamma`` to down-weight easy examples.
        label_smoothing: targets ``t`` are mapped to ``t * (1 - ls) + 0.5 * ls``.
    """

    def __init__(self, gamma=2.0, label_smoothing=0.1):
        super().__init__()
        self.gamma, self.ls = gamma, label_smoothing

    def forward(self, inputs, targets):
        """Return the mean focal loss for raw logits *inputs* and 0/1 *targets*."""
        smoothed = (targets.float() * (1 - self.ls) + 0.5 * self.ls).to(inputs.dtype)
        # BCE-with-logits is numerically stable for large |logits| (plain BCE on
        # sigmoid outputs clamps log(0) to -100, giving wrong losses) and is
        # autocast-safe, whereas F.binary_cross_entropy raises under autocast.
        ce = F.binary_cross_entropy_with_logits(inputs, smoothed, reduction='none')
        probs = torch.sigmoid(inputs)
        pt = torch.where(targets > 0, probs, 1 - probs).clamp(min=1e-6, max=1 - 1e-6)
        return ((1 - pt) ** self.gamma * ce).mean()


# Data helpers
import json

def _split_candidates(raw: str):
    if not raw:
        return []
    for sep in ["||", "|", ";", ","]:
        if sep in raw:
            return [c.strip() for c in raw.split(sep) if c.strip()]
    return [raw.strip()] if raw.strip() else []


def _add_hard_negatives(samples, topk):
    if not topk or not samples:
        return samples
    freq = {}
    for s in samples:
        for c in s['candidates']:
            freq[c] = freq.get(c, 0) + 1
    sorted_cands = [c for c, _ in sorted(freq.items(), key=lambda x: -x[1])]
    for s in samples:
        existing = set(s['candidates'])
        extras = [c for c in sorted_cands if c not in existing and c != s['gold']][:topk]
        s['candidates'] = s['candidates'] + extras
    return samples


def load_quote_data(filename):
    """Read a tab-separated PDNC quote file into a list of sample dicts.

    Columns used (0-based): 0 document id, 1 quote id, 2 gold speaker field,
    4 candidate list (split by ``_split_candidates``), 5 text. Rows with fewer
    than six columns, no candidates, or an empty gold field are skipped. Hard
    negatives are then appended (``CONFIG['hard_negative_topk']``). Finally
    ``gold_index`` is set to the gold's position in ``candidates``, or ``-1``
    when it is absent.
    """
    parsed = []
    with open(filename, 'r', encoding='utf-8', errors='ignore') as handle:
        for raw_line in handle:
            fields = raw_line.rstrip().split("\t")
            if len(fields) < 6:
                continue
            doc_id, quote_id, gold_id = fields[0].strip(), fields[1].strip(), fields[2].strip()
            candidates = _split_candidates(fields[4].strip())
            if not (candidates and gold_id):
                continue
            parsed.append({
                'quote_id': f"{doc_id}:{quote_id}" if (doc_id or quote_id) else quote_id,
                'text': fields[5],
                'candidates': candidates,
                'gold': gold_id,
                'genre': 'pdnc',
                'source': 'pdnc'
            })
    parsed = _add_hard_negatives(parsed, CONFIG.get('hard_negative_topk', 0))
    for sample in parsed:
        cands = sample['candidates']
        sample['gold_index'] = cands.index(sample['gold']) if sample['gold'] in cands else -1
    return parsed


def load_multi_source_samples(base_path: str):
    """Load every dataset found by ``MultiSourceDataLoader`` and split it 80/10/10.

    Each record with a non-empty speaker and text (or quote) becomes a sample.
    Its candidate list is the gold speaker plus up to
    ``CONFIG['hard_negative_topk']`` hard negatives, in shuffled order so the
    gold's position carries no signal. Samples are then shuffled with the
    global ``random`` state and split into val (first 10%), test (next 10%)
    and train (the rest).

    Returns:
        (train_samples, val_samples, test_samples)
    """
    loader = MultiSourceDataLoader(base_path=base_path, seed=CONFIG['seed'])
    loader.load_all()
    all_samples = []
    for source, samples in loader.data_by_source.items():
        for i, s in enumerate(samples):
            gold = s.get('speaker', '')
            text = s.get('text') or s.get('quote') or ''
            genre = s.get('genre', source)
            if not gold or not text:
                continue  # CURSOR: skip entries without speakers/text
            qid = f"{source}:{i}"
            all_samples.append({
                'quote_id': qid,
                'text': text,
                'candidates': [gold],
                'gold': gold,
                'genre': genre,
                'source': source
            })
    all_samples = _add_hard_negatives(all_samples, CONFIG.get('hard_negative_topk', 0))
    # Every candidate list starts as [gold], and hard negatives are appended
    # after it. Without this shuffle the gold speaker would always be at index 0,
    # and the model could learn "pick the first candidate". A dedicated seeded
    # RNG keeps the order reproducible without consuming the global `random`
    # state used for the split shuffle below.
    cand_rng = random.Random(CONFIG['seed'])
    for s in all_samples:
        cand_rng.shuffle(s['candidates'])
        s['gold_index'] = s['candidates'].index(s['gold']) if s['gold'] in s['candidates'] else -1
    random.shuffle(all_samples)
    split_val = int(0.1 * len(all_samples))
    split_test = int(0.1 * len(all_samples))
    val_samples = all_samples[:split_val]
    test_samples = all_samples[split_val:split_val + split_test]
    train_samples = all_samples[split_val + split_test:]
    return train_samples, val_samples, test_samples



In [None]:
# Load Data (PDNC or multi-source)
use_multi = CONFIG.get('use_multi_source', False)

# CURSOR: Determine which folds to train.
# Accepted forms: "all", a single fold index, or any sequence (list/tuple/range) of indices.
fold_selection = CONFIG.get('fold_selection', [0])
split_type = CONFIG.get('split_type', 'leave-x-out')

if fold_selection == "all":
    FOLDS_TO_TRAIN = list(range(5))  # CURSOR: Train all 5 folds for best generalization
elif isinstance(fold_selection, (list, tuple, range)):
    FOLDS_TO_TRAIN = [int(f) for f in fold_selection]
else:
    FOLDS_TO_TRAIN = [int(fold_selection)]

# Multi-source mode trains a single run regardless of folds, so only PDNC needs a fold.
if not FOLDS_TO_TRAIN and not use_multi:
    raise ValueError("fold_selection is empty; use 'all' or give at least one fold index.")

print(f"ðŸ“‹ Folds to train: {FOLDS_TO_TRAIN}")

# CURSOR: Helper to reload data for a specific fold
def load_fold_data(fold_idx: int):
    """Return ``(train, dev, test)`` sample lists for one cross-validation fold.

    Reads ``quotes.train.txt``, ``quotes.dev.txt`` and ``quotes.test.txt``, in
    that order, from ``<train_data_path>/<split_type>/split_<fold_idx>/``.
    """
    fold_dir = f"{CONFIG['train_data_path']}/{split_type}/split_{fold_idx}"
    return tuple(load_quote_data(f"{fold_dir}/quotes.{part}.txt") for part in ("train", "dev", "test"))

if not use_multi:
    # CURSOR: Load data for first fold (will reload in loop for each fold)
    current_fold = FOLDS_TO_TRAIN[0]

    train_file, dev_file, test_file = (
        f"{CONFIG['train_data_path']}/{split_type}/split_{current_fold}/quotes.{part}.txt"
        for part in ("train", "dev", "test")
    )

    print(f"ðŸ“‚ Loading pre-processed PDNC data (Fold {current_fold}, {split_type})...")
    print(f"   Train: {train_file}")
    print(f"   Dev: {dev_file}")
    print(f"   Test: {test_file}")

    train_samples, val_samples, test_samples = (
        load_quote_data(path) for path in (train_file, dev_file, test_file)
    )

    print(f"\nâœ… Data loaded!")
else:
    print("ðŸ“‚ Loading multi-source data...")
    train_samples, val_samples, test_samples = load_multi_source_samples(CONFIG['multi_source_base'])
    # CURSOR: For multi-source, we train once (no fold iteration)
    FOLDS_TO_TRAIN = [0]

# Split sizes are reported the same way for both data sources.
print(f"   Train quotes: {len(train_samples)}")
print(f"   Val quotes: {len(val_samples)}")
print(f"   Test quotes: {len(test_samples)}")

# quote_id -> validation sample, used to look up metadata during evaluation.
val_meta_map = {s['quote_id']: s for s in val_samples}


In [None]:
print("Loading model...")
model = MaxPerformanceSpeakerModel(CONFIG.get('base_model', 'microsoft/deberta-v3-large'))
if CONFIG['gradient_checkpointing']:
    # Recompute activations in the backward pass to save GPU memory.
    model.encoder.gradient_checkpointing_enable()
tokenizer = model.get_tokenizer()
print(f"Parameters: {sum(param.numel() for param in model.parameters()):,}")

# Optional helpers; each stays None when its feature flag is off.
augmenter = None
if CONFIG.get('use_augmentation', False):
    augmenter = QuoteAugmenter(seed=CONFIG['seed'])

post_processor = None
if CONFIG.get('use_postprocess', False):
    post_processor = PostProcessor(confidence_threshold=CONFIG.get('postprocess_confidence', 0.6))

ensemble_handle = None  # set below when ensemble evaluation is enabled

# CURSOR: Helper function to create dataloaders for a given dataset
def create_dataloaders(train_samples, val_samples, tokenizer, augmenter):
    """Build the training and validation DataLoaders.

    Training-sampler precedence:
      1. ``use_curriculum``: samples are sorted by text length and drawn through a
         ``CurriculumSampler``;
      2. otherwise, ``use_multi_source`` plus ``balance_genres``: a
         ``WeightedRandomSampler`` with inverse genre-frequency weights;
      3. otherwise: plain shuffling.
    The curriculum therefore takes precedence, and genre balancing is not
    applied while it is enabled.

    Returns:
        (train_loader, val_loader)
    """
    use_curriculum = CONFIG.get('use_curriculum', False)
    if use_curriculum:
        # Shorter texts first; the curriculum sampler indexes into this order.
        train_samples = sorted(train_samples, key=lambda s: len(s['text']))

    train_dataset = QuoteDataset(train_samples, tokenizer, CONFIG['max_length'], augment=CONFIG.get('use_augmentation', False), augmenter=augmenter)
    val_dataset = QuoteDataset(val_samples, tokenizer, CONFIG['max_length'], augment=False, augmenter=None)

    def make_genre_sampler(samples):
        """Sample with replacement, weighting each item by 1 / (count of its genre)."""
        counts = defaultdict(int)
        for s in samples:
            counts[s.get('genre', 'unknown')] += 1
        weights = []
        for s in samples:
            g = s.get('genre', 'unknown')
            weights.append(1.0 / max(counts[g], 1))
        return WeightedRandomSampler(weights=weights, num_samples=len(samples), replacement=True)

    def collate_fn(batch):
        """Stack a batch, padding each item's candidate masks to the batch maximum.

        ``candidate_attention_mask`` is 1 for real candidates and 0 for padding slots.
        """
        max_cands = max(len(item['candidate_masks']) for item in batch)
        input_ids = torch.stack([b['input_ids'] for b in batch])
        attention_mask = torch.stack([b['attention_mask'] for b in batch])
        quote_mask = torch.stack([b['quote_mask'] for b in batch])
        cand_masks = []
        cand_attn = []
        for b in batch:
            masks = b['candidate_masks']
            orig_len = len(masks)
            if orig_len == 0:
                masks = [torch.zeros_like(b['input_ids'])]
                orig_len = 1
            pad_count = max_cands - orig_len
            if pad_count > 0:
                pad_mask = torch.zeros_like(masks[0])
                masks = masks + [pad_mask] * pad_count
            cand_masks.append(torch.stack(masks))
            attn_row = [1] * orig_len + [0] * pad_count
            cand_attn.append(torch.tensor(attn_row, dtype=torch.long))
        cand_masks = torch.stack(cand_masks)
        cand_attn = torch.stack(cand_attn)
        labels = torch.stack([b['label_idx'] for b in batch])
        quote_ids = [b['quote_id'] for b in batch]
        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'quote_mask': quote_mask,
            'candidate_masks': cand_masks,
            'candidate_attention_mask': cand_attn,
            'label_idx': labels,
            'quote_ids': quote_ids
        }

    sampler = None
    if use_curriculum:
        difficulty = DifficultyClassifier()
        difficulty_indices = difficulty.classify_dataset(train_samples)
        sampler = CurriculumSampler(
            difficulty_indices=difficulty_indices,
            config=CurriculumConfig(),
            total_epochs=CONFIG['epochs'],
            current_epoch=0,
            batch_size=CONFIG['batch_size'],
            seed=CONFIG['seed']
        )
    elif CONFIG.get('use_multi_source', False) and CONFIG.get('balance_genres', False):
        sampler = make_genre_sampler(train_samples)

    train_loader = DataLoader(
        train_dataset,
        batch_size=CONFIG['batch_size'],
        # DataLoader rejects shuffle=True together with a sampler. Test presence
        # with `is None`: samplers define __len__, so an empty one is falsy.
        shuffle=(sampler is None) and not use_curriculum,
        sampler=sampler,
        num_workers=2,
        pin_memory=True,
        collate_fn=collate_fn
    )
    val_loader = DataLoader(val_dataset, batch_size=CONFIG['batch_size'], shuffle=False, num_workers=2, collate_fn=collate_fn)

    return train_loader, val_loader

# Create initial dataloaders
train_loader, val_loader = create_dataloaders(train_samples, val_samples, tokenizer, augmenter)

# Optimizer plus a warm-restart cosine schedule (first cycle lasts T_0=5
# scheduler.step() calls; each later cycle doubles in length).
optimizer = AdamW(model.parameters(), lr=CONFIG['lr'])
scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)

# The combined (focal / label-smoothing / R-Drop) loss is optional; plain
# cross-entropy that ignores label -100 is always built.
use_combined_loss = CONFIG.get('use_combined_loss', False)
loss_fn = None
if use_combined_loss:
    loss_fn = CombinedLoss(
        focal_gamma=CONFIG.get('focal_gamma', 2.0),
        label_smoothing=CONFIG.get('label_smoothing', 0.1),
        r_drop_alpha=CONFIG.get('r_drop_alpha', 0.0),
        use_focal=CONFIG.get('use_focal_loss', True),
        use_label_smoothing=CONFIG.get('label_smoothing', 0.0) > 0,
        use_r_drop=CONFIG.get('r_drop_alpha', 0.0) > 0,
    )
base_ce_loss = nn.CrossEntropyLoss(ignore_index=-100)

# Best-effort ensemble for evaluation: any failure just disables it.
if CONFIG.get('use_ensemble_eval', False):
    try:
        ensemble_handle = create_ensemble(
            model_names=CONFIG.get('ensemble_model_names'),
            device=device if not isinstance(device, torch.device) else str(device),
        )
        print(f"Ensemble initialized with {len(CONFIG.get('ensemble_model_names', []))} model(s)")
    except Exception as exc:
        ensemble_handle = None
        print(f"Ensemble init failed: {exc}")

# Plain path: move the model and use our own GradScaler. Accelerate path: let
# the Accelerator wrap everything (mixed precision is configured on it).
if not accelerator:
    model = model.to(device)
    scaler = GradScaler(enabled=CONFIG['fp16'])
else:
    model, optimizer, train_loader, val_loader = accelerator.prepare(model, optimizer, train_loader, val_loader)
    device = accelerator.device
    scaler = None

if loss_fn:
    loss_fn = loss_fn.to(device)
base_ce_loss = base_ce_loss.to(device)

os.makedirs(CONFIG['output_dir'], exist_ok=True)
print(f"Dataloaders ready | train batches: {len(train_loader)} | val batches: {len(val_loader)}")
print("Ready to train!")


In [None]:
# Training + evaluation utilities

def save_checkpoint(step, metrics, best_acc, epoch, fold_idx=None):
    """Persist training state to ``<output_dir>/checkpoint_<step>[_fold_<k>].pt``.

    Stores the model (unwrapped under accelerate), optimizer, scheduler,
    GradScaler state (``None`` when no local scaler exists), metrics, best_acc,
    epoch, step and fold. The file is first written to a ``.tmp`` path and then
    atomically renamed. An interrupted save (session timeout, preemption)
    therefore never leaves a truncated ``checkpoint_*.pt`` behind for
    auto-resume to pick up.
    """
    suffix = f"_fold_{fold_idx}" if fold_idx is not None else ""
    final_path = f"{CONFIG['output_dir']}/checkpoint_{step}{suffix}.pt"
    tmp_path = final_path + ".tmp"  # does not match the 'checkpoint_*.pt' resume glob
    is_main = (not accelerator) or accelerator.is_main_process

    if accelerator:
        accelerator.wait_for_everyone()
        model_state = accelerator.unwrap_model(model).state_dict()
    else:
        model_state = model.state_dict()

    state = {
        'step': step,
        'epoch': epoch,
        'model': model_state,
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict(),
        'metrics': metrics,
        'best_acc': best_acc,
        'fold': fold_idx,
        'scaler': scaler.state_dict() if scaler is not None else None,
    }

    if accelerator:
        accelerator.save(state, tmp_path)  # writes from the main process only
    else:
        torch.save(state, tmp_path)
    if is_main:
        os.replace(tmp_path, final_path)
        print(f"Checkpoint saved at step {step}{suffix}")


def load_latest_checkpoint(fold_idx=None):
    """Restore the newest readable checkpoint and return ``(start_step, best_acc, start_epoch)``.

    With *fold_idx*, only ``checkpoint_<step>_fold_<fold_idx>.pt`` files are
    considered. Without it, any ``checkpoint_<step>[_fold_<k>].pt`` is. Files
    are tried from the highest step down, and unreadable (e.g. truncated)
    files are skipped with a message. Restores model, optimizer, scheduler and,
    when both are present, the GradScaler. Returns ``(0, 0.0, 0)`` when no
    usable checkpoint exists.
    """
    import pickle
    import re

    is_main = (not accelerator) or accelerator.is_main_process
    # CURSOR: Use numeric parsing to find highest step (not lexicographic)
    if fold_idx is not None:
        name_re = re.compile(rf"checkpoint_(\d+)_fold_{re.escape(str(fold_idx))}\.pt")
    else:
        name_re = re.compile(r"checkpoint_(\d+)(?:_fold_\d+)?\.pt")
    ranked = []
    for path in glob.glob(os.path.join(CONFIG['output_dir'], 'checkpoint_*.pt')):
        match = name_re.fullmatch(os.path.basename(path))
        if match:
            ranked.append((int(match.group(1)), path))
    ranked.sort(reverse=True)

    state, latest = None, None
    for _, path in ranked:
        try:
            # Our own checkpoint (trusted input). It holds optimizer state and
            # metrics that may contain non-tensor objects, so opt out of the
            # weights_only default that torch>=2.6 applies.
            state = torch.load(path, map_location='cpu', weights_only=False)
        except (RuntimeError, EOFError, OSError, pickle.UnpicklingError) as exc:
            if is_main:
                print(f"Skipping unreadable checkpoint {path}: {exc}")
            continue
        latest = path
        break
    if state is None:
        return 0, 0.0, 0

    target = accelerator.unwrap_model(model) if accelerator else model
    target.load_state_dict(state.get('model', {}))
    optimizer.load_state_dict(state.get('optimizer', {}))
    if state.get('scheduler'):
        scheduler.load_state_dict(state['scheduler'])
    if state.get('scaler') and scaler is not None:
        scaler.load_state_dict(state['scaler'])
    start_step = state.get('step', 0)
    start_epoch = state.get('epoch', 0)
    best_acc = state.get('best_acc', 0.0)
    if is_main:
        print(f"Resumed from {latest} (step {start_step}, epoch {start_epoch}, best_acc {best_acc:.4f})")
    return start_step, best_acc, start_epoch


import string

# Pronouns counted by categorize_sample when estimating pronoun density.
_PRONOUNS = frozenset([
    'he', 'she', 'they', 'him', 'her', 'them', 'his', 'hers', 'their', 'theirs',
    'i', 'me', 'my', 'mine', 'we', 'us', 'our', 'ours',
])
# Trimmed from word edges before matching: ASCII punctuation plus typographic
# quotes, dashes and ellipsis common in novels.
_WORD_EDGE_CHARS = string.punctuation + "\u201c\u201d\u2018\u2019\u2014\u2013\u2026"


def categorize_sample(meta):
    """Bucket a sample for error analysis: 'pronoun_heavy', 'verb_attribution' or 'other'.

    'pronoun_heavy' wins when more than 12% of the words are pronouns (words
    are compared with surrounding punctuation stripped, so "them," and "her."
    count). Otherwise 'verb_attribution' applies if the text contains 'said',
    'asked' or 'replied'. A missing or ``None`` text yields 'other'.
    """
    text = (meta.get('text') or '').lower()
    words = [w.strip(_WORD_EDGE_CHARS) for w in text.split()]
    pronoun_density = sum(1 for w in words if w in _PRONOUNS) / max(len(words), 1)
    if pronoun_density > 0.12:
        return 'pronoun_heavy'
    if 'said' in text or 'asked' in text or 'replied' in text:
        return 'verb_attribution'
    return 'other'


def apply_postprocess(pred_idx: int, confidence: float, qid: str) -> int:
    """Let the post-processor revise a predicted candidate index for validation quote *qid*.

    Returns *pred_idx* unchanged in these cases: post-processing is disabled;
    *qid* has no candidates in ``val_meta_map``; *pred_idx* is out of range;
    or the post-processor's speaker is not one of the candidates.
    """
    enabled = CONFIG.get('use_postprocess', False) and post_processor
    if not enabled:
        return pred_idx
    meta = val_meta_map.get(qid, {})
    candidates = meta.get('candidates', [])
    if not candidates:
        return pred_idx
    if pred_idx is None or not (0 <= pred_idx < len(candidates)):
        return pred_idx
    quote_text = meta.get('text', '')
    result = post_processor.process(
        speaker=candidates[pred_idx],
        confidence=float(confidence),
        quote=quote_text,
        context=quote_text,
        candidates=candidates,
        position=0
    )
    return candidates.index(result.speaker) if result.speaker in candidates else pred_idx


def r_drop_penalty(logits_a, logits_b, alpha):
    """Symmetric KL divergence between two forward passes (R-Drop regulariser).

    Returns ``alpha * 0.5 * (KL(P_a || P_b) + KL(P_b || P_a))``, summed over the
    last dimension and averaged over the rest. When ``alpha <= 0`` it returns
    a zero tensor on the logits' device.
    """
    if alpha <= 0:
        return torch.tensor(0.0, device=logits_a.device)
    # Work in float32 and floor the log-probs. Masked candidates may carry -inf
    # logits, and -inf - (-inf) = NaN would otherwise poison the penalty even
    # though those candidates have zero probability.
    log_pa = torch.log_softmax(logits_a.float(), dim=-1).clamp(min=-1e4)
    log_pb = torch.log_softmax(logits_b.float(), dim=-1).clamp(min=-1e4)
    kl_ab = (log_pa.exp() * (log_pa - log_pb)).sum(dim=-1).mean()
    kl_ba = (log_pb.exp() * (log_pb - log_pa)).sum(dim=-1).mean()
    return alpha * 0.5 * (kl_ab + kl_ba)


def move_to_device(batch, device):
    """Return a new dict with tensors (and non-empty lists of tensors) moved to *device*.

    Other values, including empty lists and lists of non-tensors, are passed
    through as the same objects.
    """
    def _transfer(value):
        if torch.is_tensor(value):
            return value.to(device)
        if isinstance(value, list) and value and torch.is_tensor(value[0]):
            return [t.to(device) for t in value]
        return value

    return {key: _transfer(value) for key, value in batch.items()}


# Module-level slot for the most recent evaluation's per-sample details.
# evaluate() declares it `global` (presumably to store its eval_details dict
# here; confirm in evaluate's body). It is None until the first evaluation.
latest_eval_details = None

def evaluate(temp: float = 1.0):
    global latest_eval_details
    model.eval()
    total, correct = 0, 0
    buckets = defaultdict(lambda: {'total':0, 'correct':0})
    eval_details = {'preds': [], 'labels': [], 'genres': [], 'confidences': [], 'qids': [], 'samples': []}
    is_main = (not accelerator) or accelerator.is_main_process
    if post_processor:
        post_processor.dialogue_tracker.reset()
    with torch.no_grad():
        for batch in val_loader:
            if accelerator:
                logits, _ = model(
                    batch['input_ids'],
                    batch['attention_mask'],
                    batch['quote_mask'],
                    batch['candidate_masks'],
                    batch['candidate_attention_mask']
                )
                if 'candidate_attention_mask' in batch and batch['candidate_attention_mask'] is not None:
                    logits = logits.masked_fill(batch['candidate_attention_mask'] == 0, -1e9)
                probs = torch.softmax(logits / temp, dim=-1)
                conf_vals, preds = torch.max(probs, dim=-1)
                labels = batch['label_idx']
                preds_all = accelerator.gather_for_metrics(preds)
                labels_all = accelerator.gather_for_metrics(labels)
                conf_all = accelerator.gather_for_metrics(conf_vals)
                qids_all = accelerator.gather_object(batch['quote_ids'])
                flat_qids = [qid for sub in qids_all for qid in sub]
                preds_list = preds_all.cpu().tolist()
                labels_list = labels_all.cpu().tolist()
                conf_list = conf_all.cpu().tolist()
            else:
                batch = move_to_device(batch, device)
                logits, _ = model(
                    batch['input_ids'],
                    batch['attention_mask'],
                    batch['quote_mask'],
                    batch['candidate_masks'],
                    batch['candidate_attention_mask']
                )
                if 'candidate_attention_mask' in batch and batch['candidate_attention_mask'] is not None:
                    logits = logits.masked_fill(batch['candidate_attention_mask'] == 0, -1e9)
                probs = torch.softmax(logits / temp, dim=-1)
                conf_vals, preds = torch.max(probs, dim=-1)
                labels = batch['label_idx']
                flat_qids = batch['quote_ids']
                preds_list = preds.cpu().tolist()
                labels_list = labels.cpu().tolist()
                conf_list = conf_vals.cpu().tolist()
            for pred, label, qid, conf in zip(preds_list, labels_list, flat_qids, conf_list):
                adjusted_pred = apply_postprocess(pred, conf, qid)
                meta = val_meta_map.get(qid, {})
                eval_details['preds'].append(adjusted_pred)
                eval_details['labels'].append(label)
                eval_details['genres'].append(meta.get('genre', 'unknown'))
                eval_details['confidences'].append(float(conf))
                eval_details['qids'].append(qid)
                eval_details['samples'].append({
                    'qid': qid,
                    'pred': adjusted_pred,
                    'label': label,
                    'genre': meta.get('genre', 'unknown'),
                    'confidence': float(conf)
                })
                if label < 0:
                    continue
                bucket = categorize_sample(meta)
                buckets[bucket]['total'] += 1
                if adjusted_pred == label:
                    buckets[bucket]['correct'] += 1
                correct += 1 if adjusted_pred == label else 0
                total += 1
    overall = correct / total if total else 0.0
    bucket_metrics = {k: (v['correct']/v['total'] if v['total'] else 0.0) for k,v in buckets.items()}
    if is_main:
        latest_eval_details = eval_details
    return (overall if is_main else 0.0), (bucket_metrics if is_main else {}), eval_details


# Banner announcing which configuration is about to be trained.
_banner_rule = "=" * 60
print(_banner_rule)
print(f"TRAINING: {CONFIG['name']}")
print(_banner_rule)

def _forward_logits(batch):
    """Run the speaker model on one collated batch and return its candidate logits.

    The batch must already be on the right device. The accelerate-prepared
    loaders handle this; the single-GPU path calls ``move_to_device`` first.
    The model's second return value is discarded.
    """
    logits, _ = model(
        batch['input_ids'],
        batch['attention_mask'],
        batch['quote_mask'],
        batch['candidate_masks'],
        batch['candidate_attention_mask']
    )
    return logits


def _training_loss(batch, r_drop_enabled):
    """Compute the training objective for one batch.

    Returns ``(loss, loss_main)``:
    - ``loss`` is back-propagated.
    - ``loss_main`` is the scalar recorded in checkpoints.

    A second forward pass is only run when R-Drop is enabled.
    """
    labels = batch['label_idx']
    logits1 = _forward_logits(batch)
    logits2 = _forward_logits(batch) if r_drop_enabled else None
    if use_combined_loss and loss_fn is not None:
        loss = loss_fn(logits1, labels, logits2)
        return loss, loss
    if r_drop_enabled and logits2 is not None:
        loss_main = ce_loss(logits1, labels)
        loss_aux = ce_loss(logits2, labels)
        kl_pen = r_drop_penalty(logits1, logits2, CONFIG.get('r_drop_alpha', 0.0))
        return 0.5 * (loss_main + loss_aux) + kl_pen, loss_main
    loss_main = ce_loss(logits1, labels)
    return loss_main, loss_main


_is_main_proc = (not accelerator) or accelerator.is_main_process
multi_fold = len(FOLDS_TO_TRAIN) > 1

# CURSOR: Multi-fold training loop
for fold_idx in FOLDS_TO_TRAIN:
    fold_key = fold_idx if multi_fold else None  # checkpoint namespace (None = single fold)
    fold_str = f" [Fold {fold_idx}]" if multi_fold else ""
    if multi_fold:
        print(f"\n{'='*60}")
        print(f"FOLD {fold_idx + 1}/{len(FOLDS_TO_TRAIN)} (split_{fold_idx})")
        print(f"{'='*60}")

        # Release the previous fold's model/optimizer/scheduler before building new ones.
        # accelerator.prepare() retains references to everything it wraps, so without
        # free_memory() each fold's model and optimizer state would stay on the GPU.
        if accelerator:
            accelerator.free_memory()
        model = optimizer = scheduler = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Reload data for this fold
        if not use_multi:
            train_samples, val_samples, test_samples = load_fold_data(fold_idx)
            val_meta_map = {s['quote_id']: s for s in val_samples}
            print(f"   Train quotes: {len(train_samples)}")
            print(f"   Val quotes: {len(val_samples)}")
            print(f"   Test quotes: {len(test_samples)}")

            # Recreate dataloaders for this fold
            train_loader, val_loader = create_dataloaders(train_samples, val_samples, tokenizer, augmenter)

            # Re-prepare with accelerator if needed
            if accelerator:
                train_loader, val_loader = accelerator.prepare(train_loader, val_loader)

        # Reinitialize model for each fold
        model = MaxPerformanceSpeakerModel(CONFIG.get('base_model', 'microsoft/deberta-v3-large'))
        if CONFIG['gradient_checkpointing']:
            model.encoder.gradient_checkpointing_enable()

        # Reinitialize optimizer and scheduler.
        # NOTE(review): scheduler.step() below runs per batch (accelerate path) or per
        # optimizer step (single-GPU path), so T_0=5 is measured in steps, not epochs.
        # Confirm this matches the scheduler configured before this loop.
        optimizer = AdamW(model.parameters(), lr=CONFIG['lr'])
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=5, T_mult=2)

        if accelerator:
            model, optimizer = accelerator.prepare(model, optimizer)
            device = accelerator.device
        else:
            model = model.to(device)

    global_step, best_acc, start_epoch = load_latest_checkpoint(fold_key)
    ce_loss = base_ce_loss
    # CURSOR: Only run second forward pass when R-Drop is enabled (saves ~2x compute otherwise)
    use_r_drop = CONFIG.get('r_drop_alpha', 0.0) > 0

    for epoch in range(start_epoch, CONFIG['epochs']):
        if isinstance(train_loader.sampler, CurriculumSampler):
            train_loader.sampler.set_epoch(epoch)
        if _is_main_proc:
            print(f"\n--- Epoch {epoch + 1}/{CONFIG['epochs']}{fold_str} ---")
        model.train()
        accum_step = 0

        for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
            if accelerator:
                with accelerator.accumulate(model):
                    loss, loss_main = _training_loss(batch, use_r_drop)
                    accelerator.backward(loss)
                    # Clip only when gradients are about to be applied.
                    # accelerator.clip_grad_norm_ unscales fp16 gradients, and
                    # unscaling twice before an optimizer step raises. Clipping a
                    # partially accumulated gradient would also be wrong.
                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
            else:
                batch = move_to_device(batch, device)
                with autocast(enabled=CONFIG['fp16']):
                    loss, loss_main = _training_loss(batch, use_r_drop)
                    loss = loss / CONFIG['gradient_accumulation_steps']
                if scaler:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()
                accum_step += 1
                if accum_step % CONFIG['gradient_accumulation_steps'] == 0:
                    if scaler:
                        # Gradients are still multiplied by the loss scale; unscale
                        # first so the 1.0 clipping threshold applies to real values.
                        scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    if scaler:
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
            global_step += 1

            if global_step % CONFIG['checkpoint_every'] == 0:
                save_checkpoint(global_step, {'loss': float(loss_main.detach().cpu())}, best_acc, epoch, fold_key)

            if global_step % CONFIG['eval_every'] == 0:
                # evaluate() uses collectives under accelerate, so every process calls it.
                acc, buckets, latest_eval_details = evaluate()
                if _is_main_proc:
                    print(f"Step {global_step} | Accuracy: {acc:.4f} | buckets: {buckets}")
                    if acc > best_acc:
                        best_acc = acc
                        model_suffix = f"_split_{fold_idx}" if multi_fold else ""
                        save_path = f"{CONFIG['output_dir']}/best_model{model_suffix}.pt"
                        torch.save((accelerator.unwrap_model(model) if accelerator else model).state_dict(), save_path)
                        print(f"New best model! ({best_acc:.4f}) -> {save_path}")
                model.train()

        # best_acc is only updated on the main process, so the early-stop decision
        # must be shared. Otherwise the main process would leave the loop while the
        # other ranks keep training and block in the next collective call.
        stop_fold = _is_main_proc and best_acc >= CONFIG['target_accuracy']
        if accelerator:
            stop_flag = torch.tensor([1.0 if stop_fold else 0.0], device=accelerator.device)
            stop_fold = bool(accelerator.gather(stop_flag).max().item() > 0)
        if stop_fold:
            if _is_main_proc:
                print(f"\n🎉 Target accuracy {CONFIG['target_accuracy']:.0%} reached!")
            break

    # Temperature calibration on val set after training this fold
    best_temp = 1.0
    if CONFIG.get('calibrate_temperature', False):
        scaler_ts = TemperatureScaling()
        val_logits_all = []
        val_labels_all = []
        model.eval()
        with torch.no_grad():
            for batch in val_loader:
                if not accelerator:
                    batch = move_to_device(batch, device)
                logits = _forward_logits(batch)
                cand_mask = batch['candidate_attention_mask']
                if cand_mask is not None:
                    # Same padded-candidate masking as evaluate(), so the temperature
                    # is fitted on the distribution that is actually evaluated.
                    logits = logits.masked_fill(cand_mask == 0, -1e9)
                labels = batch['label_idx']
                if accelerator:
                    logits = accelerator.gather_for_metrics(logits)
                    labels = accelerator.gather_for_metrics(labels)
                val_logits_all.append(logits.detach().cpu())
                val_labels_all.append(labels.detach().cpu())
        if val_logits_all:
            val_logits = torch.cat(val_logits_all, dim=0)
            val_labels = torch.cat(val_labels_all, dim=0)
            # evaluate() ignores negative (unlabelled) targets; do the same here
            # rather than feeding invalid class indices to the calibrator.
            labelled = val_labels >= 0
            if labelled.any():
                scaler_ts.calibrate(val_logits[labelled], val_labels[labelled])
                best_temp = scaler_ts.get_temperature()
                acc_cal, buckets_cal, latest_eval_details = evaluate(temp=best_temp)
                if _is_main_proc:
                    print(f"Calibrated temperature{fold_str}: {best_temp:.3f} (val acc={acc_cal:.4f}, buckets={buckets_cal})")

    if _is_main_proc:
        print(f"\n{'='*60}")
        print(f"FOLD {fold_idx} COMPLETE{fold_str}")
        print(f"Best Accuracy: {best_acc:.4f}")
        print(f"Temperature: {best_temp}")
        model_suffix = f"_split_{fold_idx}" if multi_fold else ""
        print(f"Model saved to: {CONFIG['output_dir']}/best_model{model_suffix}.pt")
        print(f"{'='*60}")

# Run analysis and optimization only once at the end
# Both analyses below read `latest_eval_details` (from the last evaluate() call in
# the training loop) and `val_meta_map`. When several folds are trained, these
# reflect only the final fold's validation set.
if (not accelerator) or accelerator.is_main_process:
    # Per-genre validation: flag genres whose accuracy falls below the configured
    # minimum (CONFIG['min_genre_acc'], default 0.75).
    if CONFIG.get('run_cross_domain_validation', False) and latest_eval_details:
        validator = CrossDomainValidator(
            thresholds={
                'accuracy': CONFIG.get('min_genre_acc', 0.75),
                'f1': 0.70,
                'min_samples': 10,
            }
        )
        validator.add_predictions(
            latest_eval_details.get('preds', []),
            latest_eval_details.get('labels', []),
            latest_eval_details.get('genres', [])
        )
        # `passed` is not used further in this cell; `domain_report` is read as a
        # dict for its 'failing_genres' list.
        passed, domain_report = validator.validate()
        validator.print_report()
        if CONFIG.get('run_genre_adaptation', False):
            # Only prints suggested strategies for failing genres; nothing is retrained here.
            adapter = GenreSpecificAdaptation(min_accuracy_threshold=CONFIG.get('min_genre_acc', 0.75))
            adapter.underperforming_genres = domain_report.get('failing_genres', [])
            for genre in domain_report.get('failing_genres', []):
                print(f"Genre {genre}: suggested fallback -> {adapter.get_adaptation_strategy(genre)}")
    # Error analysis over misattributed, labelled validation quotes.
    if CONFIG.get('run_error_analysis', False) and latest_eval_details:
        analyzer = ErrorAnalyzer()
        for sample in latest_eval_details.get('samples', []):
            # Skip unlabelled samples (negative label) and correct predictions.
            if sample['label'] < 0 or sample['pred'] == sample['label']:
                continue
            meta = val_meta_map.get(sample['qid'], {})
            candidates = meta.get('candidates', [])
            # Map candidate indices to speaker names; fall back to the raw index
            # when it is out of range for this quote's candidate list.
            predicted_name = candidates[sample['pred']] if 0 <= sample['pred'] < len(candidates) else str(sample['pred'])
            actual_name = candidates[sample['label']] if 0 <= sample['label'] < len(candidates) else str(sample['label'])
            # The quote text is passed as both `text` and `quote`; no separate
            # surrounding context is looked up here.
            analyzer.add_error(
                sample_id=sample['qid'],
                text=meta.get('text', ''),
                quote=meta.get('text', ''),
                predicted=predicted_name,
                actual=actual_name,
                candidates=candidates,
                confidence=sample.get('confidence', 0.0),
                genre=meta.get('genre', 'unknown')
            )
        summary = analyzer.get_summary()
        recs = analyzer.get_recommendations()
        print("Error analysis summary:", summary)
        if recs:
            print("Recommendations:")
            for rec in recs:
                print(f"- {rec}")

# Main process only: optional inference optimisation, then a run summary.
if (not accelerator) or accelerator.is_main_process:
    if CONFIG.get('run_model_optimization', False):
        try:
            print("Running model optimization for inference...")
            opt_model = accelerator.unwrap_model(model) if accelerator else model
            # nn.Module.to() moves parameters in place, so `model` itself ends up on CPU.
            opt_model = opt_model.to('cpu')
            optimize_for_inference(
                model=opt_model,
                output_dir=f"{CONFIG['output_dir']}/optimized",
                quantize=CONFIG.get('optimize_quantize', True),
                export_onnx=CONFIG.get('optimize_export_onnx', False),
                calibration_data=None
            )
        except Exception as exc:
            # Best-effort: a failed export must not hide the completed training run.
            print(f"Model optimization failed: {exc}")

    print(f"\n{'='*60}")
    print("ALL TRAINING COMPLETE!")
    if len(FOLDS_TO_TRAIN) > 1:
        print(f"Trained {len(FOLDS_TO_TRAIN)} folds: {FOLDS_TO_TRAIN}")
        # List the real per-fold filenames; FOLDS_TO_TRAIN need not be 0..N-1.
        _saved_models = ", ".join(f"best_model_split_{i}.pt" for i in FOLDS_TO_TRAIN)
        print(f"Models saved: {_saved_models}")
    else:
        print(f"Trained single fold: {FOLDS_TO_TRAIN[0]}")
        print("Model saved: best_model.pt")
    print(f"Output dir: {CONFIG['output_dir']}")
    print(f"{'='*60}")
