In [None]:
!pip uninstall -y transformers tokenizers sentence-transformers
!pip install transformers==4.30.2 --no-deps
!pip install "tokenizers<0.14" sacremoses
!pip install sentence-transformers==2.2.2
!pip install sacrebleu

In [None]:
import transformers
print(transformers.__version__)


In [None]:
# ==============================================================================
# CELL 0: TATN CONFIGURATION (BENGALI → ENGLISH) - FIXED
# ==============================================================================

import os
import sys
import math
import random
import re
import unicodedata
import time
import threading
from pathlib import Path
from collections import deque, defaultdict
from typing import List, Dict, Tuple, Optional, Union, Set, Any
from types import SimpleNamespace

import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import warnings
import gc

try:
    import pandas as pd
    _HAS_PANDAS = True
except Exception:
    pd = None
    _HAS_PANDAS = False

try:
    import transformers
    _HAS_TRANSFORMERS = True
except Exception:
    transformers = None
    _HAS_TRANSFORMERS = False

_HAS_M2M_TOKENIZER = False
if _HAS_TRANSFORMERS:
    try:
        from transformers import M2M100Tokenizer
        _HAS_M2M_TOKENIZER = True
    except Exception:
        _HAS_M2M_TOKENIZER = False

try:
    from datasets import load_dataset
    _HAS_DATASETS = True
except Exception:
    load_dataset = None
    _HAS_DATASETS = False

warnings.filterwarnings("ignore")

os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")

def _get_int_env(name: str, default: int) -> int:
    """Resolve an integer setting.

    Lookup order: a module-level global called *name*, then the environment
    variable *name*, then *default*. If converting the chosen value fails,
    ``int(default)`` is returned instead.
    """
    try:
        override = globals().get(name)
        if override is not None:
            return int(override)
        raw = os.environ.get(name)
        return int(default) if raw is None else int(raw)
    except Exception:
        return int(default)


def _get_float_env(name: str, default: float) -> float:
    """Resolve a float setting.

    Lookup order: a module-level global called *name*, then the environment
    variable *name*, then *default*. If converting the chosen value fails,
    ``float(default)`` is returned instead.
    """
    try:
        override = globals().get(name)
        if override is not None:
            return float(override)
        raw = os.environ.get(name)
        return float(default) if raw is None else float(raw)
    except Exception:
        return float(default)


def _get_bool_env(name: str, default: bool) -> bool:
    """Resolve a boolean setting.

    Lookup order: a module-level global called *name*, then the environment
    variable *name*, then *default*.

    String values from any of the three sources are parsed the same way:
    ``"1"``, ``"true"``, ``"yes"`` and ``"y"`` (case-insensitive, surrounding
    whitespace ignored) are True, every other string is False. Non-string
    values are converted with ``bool()``.
    """
    truthy = ("1", "true", "yes", "y")

    def _coerce(value) -> bool:
        # bool("false") is True, so strings must be parsed explicitly.
        if isinstance(value, str):
            return value.strip().lower() in truthy
        return bool(value)

    try:
        override = globals().get(name)
        if override is not None:
            return _coerce(override)
        raw = os.environ.get(name)
        if raw is None:
            return _coerce(default)
        return _coerce(raw)
    except Exception:
        return bool(default)


# GPU count: a NUM_GPUS global or environment variable overrides the detected
# CUDA device count (0 when CUDA is unavailable); clamped to be non-negative.
NUM_GPUS = max(0, _get_int_env("NUM_GPUS", torch.cuda.device_count() if torch.cuda.is_available() else 0))
# Multi-GPU mode defaults to on only when more than one GPU is configured.
USE_MULTI_GPU = _get_bool_env("USE_MULTI_GPU", NUM_GPUS > 1)

# Pick the torch device: unindexed "cuda" for multi-GPU, "cuda:0" for a single
# GPU, and "cpu" whenever CUDA is unavailable.
if torch.cuda.is_available():
    if USE_MULTI_GPU and NUM_GPUS > 1:
        DEVICE = torch.device("cuda")
    else:
        DEVICE = torch.device("cuda:0" if torch.cuda.device_count() > 0 else "cpu")
else:
    DEVICE = torch.device("cpu")

# Dataset location: the DATASET_PATH environment variable wins, then an
# existing DATASET_CSV_PATH global, then the hard-coded Kaggle path.
DATASET_CSV_PATH = os.environ.get(
    "DATASET_PATH",
    globals().get("DATASET_CSV_PATH", "/kaggle/input/bn-homo/bn_homograph_complete_dataset.csv"),
)

# Every setting below is resolved by a _get_*_env helper (global of the same
# name first, then environment variable, then the default passed in). The inner
# globals().get(...) only supplies that default. Integer sizes are clamped with
# max(...) so they cannot drop below their minimum.
BATCH_SIZE = max(1, _get_int_env("BATCH_SIZE", globals().get("BATCH_SIZE", 32)))
NUM_SAMPLES = max(1, _get_int_env("NUM_SAMPLES", globals().get("NUM_SAMPLES", 30000)))
MAX_LENGTH = max(1, _get_int_env("MAX_LENGTH", globals().get("MAX_LENGTH", 48)))

# Learning rates (one per component; names suggest NMT / TRG / PHI parameter groups).
LR_NMT = float(_get_float_env("LR_NMT", globals().get("LR_NMT", 2e-5)))
LR_TRG = float(_get_float_env("LR_TRG", globals().get("LR_TRG", 1e-5)))
LR_PHI = float(_get_float_env("LR_PHI", globals().get("LR_PHI", 1e-5)))

# Training loop controls.
EPOCHS = max(1, _get_int_env("EPOCHS", globals().get("EPOCHS", 5)))
GRAD_CLIP_NORM = float(_get_float_env("GRAD_CLIP_NORM", globals().get("GRAD_CLIP_NORM", 1.0)))
USE_AMP = _get_bool_env("USE_AMP", globals().get("USE_AMP", True))
PRINT_INTERVAL = max(1, _get_int_env("PRINT_INTERVAL", globals().get("PRINT_INTERVAL", 100)))
SEED = int(_get_int_env("SEED", globals().get("SEED", 42)))

ACCUMULATION_STEPS = max(1, _get_int_env("ACCUMULATION_STEPS", globals().get("ACCUMULATION_STEPS", 1)))

# Data-loading and memory options.
NUM_WORKERS = max(0, _get_int_env("NUM_WORKERS", globals().get("NUM_WORKERS", 0)))
PIN_MEMORY = _get_bool_env("PIN_MEMORY", globals().get("PIN_MEMORY", False))
PREFETCH_FACTOR = max(1, _get_int_env("PREFETCH_FACTOR", globals().get("PREFETCH_FACTOR", 2)))
GRADIENT_CHECKPOINTING = _get_bool_env("GRADIENT_CHECKPOINTING", globals().get("GRADIENT_CHECKPOINTING", False))

# Debug / logging switches (all default to True).
DEBUG_DISCOVERY = _get_bool_env("DEBUG_DISCOVERY", globals().get("DEBUG_DISCOVERY", True))
DEBUG_TIMING = _get_bool_env("DEBUG_TIMING", globals().get("DEBUG_TIMING", True))
DEBUG_VERBOSE = _get_bool_env("DEBUG_VERBOSE", globals().get("DEBUG_VERBOSE", True))
VERBOSE_LOGGING = _get_bool_env("VERBOSE_LOGGING", globals().get("VERBOSE_LOGGING", True))

# DSCD settings. Each value is resolved through the _get_*_env helpers
# (global, then environment variable, then the default shown here).
DSCD_BUFFER_SIZE = max(1, _get_int_env("DSCD_BUFFER_SIZE", globals().get("DSCD_BUFFER_SIZE", 80)))
DSCD_MAX_PROTOS = max(1, _get_int_env("DSCD_MAX_PROTOS", globals().get("DSCD_MAX_PROTOS", 8)))
DSCD_N_MIN = max(1, _get_int_env("DSCD_N_MIN", globals().get("DSCD_N_MIN", 2)))
DSCD_DISPERSION_THRESHOLD = _get_float_env("DSCD_DISPERSION_THRESHOLD", globals().get("DSCD_DISPERSION_THRESHOLD", 0.70))
DSCD_EMBED_DIM = max(1, _get_int_env("DSCD_EMBED_DIM", globals().get("DSCD_EMBED_DIM", 1024)))
DSCD_TEMPERATURE = float(_get_float_env("DSCD_TEMPERATURE", globals().get("DSCD_TEMPERATURE", 0.7)))
DSCD_DROPOUT = float(_get_float_env("DSCD_DROPOUT", globals().get("DSCD_DROPOUT", 0.1)))
DSCD_AUGMENT_SCALE = float(_get_float_env("DSCD_AUGMENT_SCALE", globals().get("DSCD_AUGMENT_SCALE", 0.1)))
DSCD_ENABLE_TRAINING_CLUSTERING = _get_bool_env("DSCD_ENABLE_TRAINING_CLUSTERING", globals().get("DSCD_ENABLE_TRAINING_CLUSTERING", True))
DSCD_ENABLE_ONLINE_CLUSTERING = _get_bool_env("DSCD_ENABLE_ONLINE_CLUSTERING", globals().get("DSCD_ENABLE_ONLINE_CLUSTERING", True))
DSCD_ONLINE_CLUSTERING_FREQUENCY = max(1, _get_int_env("DSCD_ONLINE_CLUSTERING_FREQUENCY", globals().get("DSCD_ONLINE_CLUSTERING_FREQUENCY", 10)))
DSCD_WARMUP_SAMPLES = max(0, _get_int_env("DSCD_WARMUP_SAMPLES", globals().get("DSCD_WARMUP_SAMPLES", 0)))
DSCD_NEWSENSE_LAMBDA = float(_get_float_env("DSCD_NEWSENSE_LAMBDA", globals().get("DSCD_NEWSENSE_LAMBDA", 1.5)))
DSCD_USE_COSINE_DISTANCE = _get_bool_env("DSCD_USE_COSINE_DISTANCE", globals().get("DSCD_USE_COSINE_DISTANCE", True))

# Discovery scheduling and token-filter limits.
PERIODIC_DISCOVERY_FREQUENCY = max(1, _get_int_env("PERIODIC_DISCOVERY_FREQUENCY", globals().get("PERIODIC_DISCOVERY_FREQUENCY", 150)))
_MAX_TOKENS_PER_DISCOVERY = max(1, _get_int_env("_MAX_TOKENS_PER_DISCOVERY", globals().get("_MAX_TOKENS_PER_DISCOVERY", 150)))
DSCD_MIN_LETTERS = max(1, _get_int_env("DSCD_MIN_LETTERS", globals().get("DSCD_MIN_LETTERS", 2)))
DSCD_MIN_LETTER_FRACTION = float(_get_float_env("DSCD_MIN_LETTER_FRACTION", globals().get("DSCD_MIN_LETTER_FRACTION", 0.5)))
DSCD_MAX_CLUSTERING_POINTS = max(1, _get_int_env("DSCD_MAX_CLUSTERING_POINTS", globals().get("DSCD_MAX_CLUSTERING_POINTS", 500)))

# Per-phase component switches; note ASBN defaults to off at inference.
ENABLE_ASBN_TRAINING = _get_bool_env("ENABLE_ASBN_TRAINING", globals().get("ENABLE_ASBN_TRAINING", True))
ENABLE_ASBN_INFERENCE = _get_bool_env("ENABLE_ASBN_INFERENCE", globals().get("ENABLE_ASBN_INFERENCE", False))
ENABLE_TRG_TRAINING = _get_bool_env("ENABLE_TRG_TRAINING", globals().get("ENABLE_TRG_TRAINING", True))
ENABLE_TRG_INFERENCE = _get_bool_env("ENABLE_TRG_INFERENCE", globals().get("ENABLE_TRG_INFERENCE", True))

MC_DROPOUT_PASSES = max(1, _get_int_env("MC_DROPOUT_PASSES", globals().get("MC_DROPOUT_PASSES", 5)))
TRG_EVIDENCE_K = max(1, _get_int_env("TRG_EVIDENCE_K", globals().get("TRG_EVIDENCE_K", 3)))
MAX_SILVER_BUFFER = max(1, _get_int_env("MAX_SILVER_BUFFER", globals().get("MAX_SILVER_BUFFER", 100)))

# Housekeeping intervals. CLUSTERING_TIMEOUT is printed with an "s" suffix in
# the banner below, i.e. it is treated as seconds.
CLUSTERING_TIMEOUT = max(1, _get_int_env("CLUSTERING_TIMEOUT", globals().get("CLUSTERING_TIMEOUT", 60)))
MEMORY_CLEANUP_FREQUENCY = max(0, _get_int_env("MEMORY_CLEANUP_FREQUENCY", globals().get("MEMORY_CLEANUP_FREQUENCY", 200)))
VALIDATION_CHECK_INTERVAL = max(1, _get_int_env("VALIDATION_CHECK_INTERVAL", globals().get("VALIDATION_CHECK_INTERVAL", 500)))

# Checkpoint location: an existing global wins, then the environment variable,
# then the Kaggle working directory.
CHECKPOINT_DIR = str(globals().get("CHECKPOINT_DIR", os.environ.get("CHECKPOINT_DIR", "/kaggle/working/")))
CHECKPOINT_SAVE_AFTER_TRAINING = _get_bool_env("CHECKPOINT_SAVE_AFTER_TRAINING", globals().get("CHECKPOINT_SAVE_AFTER_TRAINING", True))
CHECKPOINT_FILENAME = str(globals().get("CHECKPOINT_FILENAME", os.environ.get("CHECKPOINT_FILENAME", "tatn_final.pt")))
# 99999999 is the sentinel should_save_checkpoint() treats as "no interval saves".
# NOTE(review): this value is not clamped, so 0 or a negative number can be configured.
CHECKPOINT_INTERVAL = int(_get_int_env("CHECKPOINT_INTERVAL", globals().get("CHECKPOINT_INTERVAL", 99999999)))
SAVE_REPLAY_BUFFER = _get_bool_env("SAVE_REPLAY_BUFFER", globals().get("SAVE_REPLAY_BUFFER", False))
LOAD_REPLAY_BUFFER = _get_bool_env("LOAD_REPLAY_BUFFER", globals().get("LOAD_REPLAY_BUFFER", False))
REPLAY_BUFFER_SIZE = max(0, _get_int_env("REPLAY_BUFFER_SIZE", globals().get("REPLAY_BUFFER_SIZE", 25000)))
RESUME_FROM_CHECKPOINT = _get_bool_env("RESUME_FROM_CHECKPOINT", globals().get("RESUME_FROM_CHECKPOINT", False))
SAVE_DSCD_STATE = _get_bool_env("SAVE_DSCD_STATE", globals().get("SAVE_DSCD_STATE", True))

# Create the checkpoint directory up front; if that fails, fall back to the
# current working directory.
if not os.path.exists(CHECKPOINT_DIR):
    try:
        os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    except Exception:
        CHECKPOINT_DIR = "./"

# Threshold settings, resolved through the _get_*_env helpers (global, then
# environment variable, then the default shown here).
TAU_LOW = float(_get_float_env("TAU_LOW", globals().get("TAU_LOW", 0.15)))
TAU_HIGH = float(_get_float_env("TAU_HIGH", globals().get("TAU_HIGH", 0.85)))
TAU_ACCEPT = float(_get_float_env("TAU_ACCEPT", globals().get("TAU_ACCEPT", 0.8)))

# TRG generator sizes.
TRG_MAX_GEN_LEN = max(1, _get_int_env("TRG_MAX_GEN_LEN", globals().get("TRG_MAX_GEN_LEN", 16)))
TRG_GEN_EMBED = max(1, _get_int_env("TRG_GEN_EMBED", globals().get("TRG_GEN_EMBED", 64)))
TRG_GEN_HID = max(1, _get_int_env("TRG_GEN_HID", globals().get("TRG_GEN_HID", 64)))

TRG_SPAN_THRESHOLD = float(_get_float_env("TRG_SPAN_THRESHOLD", globals().get("TRG_SPAN_THRESHOLD", 0.15)))
TRG_UNCERTAINTY_THRESHOLD = float(_get_float_env("TRG_UNCERTAINTY_THRESHOLD", globals().get("TRG_UNCERTAINTY_THRESHOLD", 0.70)))
TRG_TEMPERATURE = float(_get_float_env("TRG_TEMPERATURE", globals().get("TRG_TEMPERATURE", 1.0)))

# ASBN module settings. ASBN_LAMBDA and LAMBDA_ASBN are two separate settings
# with different defaults (0.1 vs 0.05).
ASBN_HIDDEN_DIM = max(1, _get_int_env("ASBN_HIDDEN_DIM", globals().get("ASBN_HIDDEN_DIM", 64)))
ASBN_LAMBDA = float(_get_float_env("ASBN_LAMBDA", globals().get("ASBN_LAMBDA", 0.1)))
ASBN_DROPOUT = float(_get_float_env("ASBN_DROPOUT", globals().get("ASBN_DROPOUT", 0.1)))

LAMBDA_ASBN = float(_get_float_env("LAMBDA_ASBN", globals().get("LAMBDA_ASBN", 0.05)))
LAMBDA_DSCD = float(_get_float_env("LAMBDA_DSCD", globals().get("LAMBDA_DSCD", 0.15)))

# Domain labels.
TRAIN_DOMAIN = int(_get_int_env("TRAIN_DOMAIN", globals().get("TRAIN_DOMAIN", 0)))
TEST_DOMAIN = int(_get_int_env("TEST_DOMAIN", globals().get("TEST_DOMAIN", 1)))
USE_DOMAIN_LABELS = _get_bool_env("USE_DOMAIN_LABELS", globals().get("USE_DOMAIN_LABELS", True))

# GRL alpha schedule endpoints; the schedule name is read from globals only
# (no environment override).
GRL_ALPHA_START = float(_get_float_env("GRL_ALPHA_START", globals().get("GRL_ALPHA_START", 0.0)))
GRL_ALPHA_END = float(_get_float_env("GRL_ALPHA_END", globals().get("GRL_ALPHA_END", 1.0)))
GRL_ALPHA_SCHEDULE = str(globals().get("GRL_ALPHA_SCHEDULE", "linear"))

# Estimated optimizer updates per epoch (samples / (batch * accumulation)),
# multiplied by the epoch count to get the length of the GRL schedule.
_total_steps_estimate = max(1, NUM_SAMPLES // max(1, (BATCH_SIZE * ACCUMULATION_STEPS)))
GRL_ALPHA_STEPS = max(1, _total_steps_estimate * EPOCHS)

SOURCE_LANGUAGE = str(globals().get("SOURCE_LANGUAGE", os.environ.get("SOURCE_LANGUAGE", "bn")))
TARGET_LANGUAGE = str(globals().get("TARGET_LANGUAGE", os.environ.get("TARGET_LANGUAGE", "en")))

# NOTE(review): hard-coded default language-token IDs. Confirm both against
# tokenizer.get_lang_id("bn") / get_lang_id("en") for the checkpoint in use;
# the Bengali value in particular should be verified -- TODO confirm.
M2M100_BN_TOKEN_ID = int(globals().get("M2M100_BN_TOKEN_ID", os.environ.get("M2M100_BN_TOKEN_ID", 128025)))
M2M100_EN_TOKEN_ID = int(globals().get("M2M100_EN_TOKEN_ID", os.environ.get("M2M100_EN_TOKEN_ID", 128022)))

# Reference set of Bengali homographs; the watchlists default to a copy of it.
HOMOGRAPH_REFERENCE_LIST_BN: Set[str] = set(globals().get("HOMOGRAPH_REFERENCE_LIST_BN", [
    "কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা"
]))
HOMOGRAPH_WATCHLIST_BN: Set[str] = set(globals().get("HOMOGRAPH_WATCHLIST_BN", list(HOMOGRAPH_REFERENCE_LIST_BN)))
HOMOGRAPH_WATCHLIST: Set[str] = set(HOMOGRAPH_WATCHLIST_BN)
USE_WATCHLIST_PRIORITIZATION = _get_bool_env("USE_WATCHLIST_PRIORITIZATION", globals().get("USE_WATCHLIST_PRIORITIZATION", False))
WATCHLIST_ONLY_FOR_TRG = _get_bool_env("WATCHLIST_ONLY_FOR_TRG", globals().get("WATCHLIST_ONLY_FOR_TRG", False))

def normalize_bengali(t: str) -> str:
    """Return *t* in NFKC form with subword markers removed and edges trimmed.

    Falsy input (``None``, ``""``) yields the empty string. The markers
    removed are the SentencePiece prefix ``▁`` and the WordPiece prefix ``##``.
    """
    if t:
        normalized = unicodedata.normalize("NFKC", str(t))
        for marker in ("▁", "##"):
            normalized = normalized.replace(marker, "")
        return normalized.strip()
    return ""

def normalize_english(t: str) -> str:
    """Return *t* NFKC-normalized, lower-cased and stripped.

    Falsy input (``None``, ``""``) yields the empty string.
    """
    if t:
        folded = unicodedata.normalize("NFKC", str(t))
        return folded.lower().strip()
    return ""

def normalize_token_key(token: str) -> str:
    """Build a lookup key from a (sub)token.

    Subword markers (``▁``, ``##``, ``Ġ``) are removed first, then the ASCII
    punctuation characters ``. , ! ? ; : " ' ( ) -`` are deleted anywhere in
    the token. Falsy input yields the empty string.
    """
    if not token:
        return ""
    key = str(token)
    for marker in ("▁", "##", "Ġ"):
        key = key.replace(marker, "")
    # One translate pass deletes the same characters the punctuation loop would.
    punctuation_table = str.maketrans("", "", ".,!?;:\"'()-")
    return key.strip().translate(punctuation_table).strip()

def empty_cuda_cache() -> None:
    """Run the garbage collector, then empty the CUDA cache when CUDA is available.

    Any error raised by ``torch.cuda.empty_cache()`` is ignored (best-effort cleanup).
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass

def safe_cuda_synchronize() -> None:
    """Call ``torch.cuda.synchronize()`` when CUDA is available, ignoring any error."""
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.synchronize()
    except Exception:
        pass

def monitor_gpu_usage() -> None:
    """Print allocated and reserved memory (in GB) for every visible CUDA device.

    Prints a single notice instead when CUDA is unavailable. A device whose
    statistics cannot be read is reported as unavailable rather than raising.
    """
    if not torch.cuda.is_available():
        print("[GPU MONITOR] No CUDA devices available")
        return

    bytes_per_gb = 1024 ** 3
    device_total = torch.cuda.device_count()
    print(f"\n[GPU MONITOR] Checking {device_total} GPU(s):")
    for idx in range(device_total):
        try:
            allocated_gb = torch.cuda.memory_allocated(idx) / bytes_per_gb
            reserved_gb = torch.cuda.memory_reserved(idx) / bytes_per_gb
            print(f"  GPU {idx}: {allocated_gb:.2f}GB allocated / {reserved_gb:.2f}GB reserved")
        except Exception:
            print(f"  GPU {idx}: memory stats unavailable")

def get_checkpoint_path() -> str:
    """Return ``CHECKPOINT_DIR`` joined with ``CHECKPOINT_FILENAME``.

    The directory is created first if it does not exist; a failure to create
    it is ignored and the joined path is returned regardless.
    """
    target_dir = CHECKPOINT_DIR
    dir_missing = not os.path.exists(target_dir)
    if dir_missing:
        try:
            os.makedirs(target_dir, exist_ok=True)
        except Exception:
            pass
    return os.path.join(target_dir, CHECKPOINT_FILENAME)

def should_save_checkpoint(global_step: int, epoch: int, is_final: bool = False) -> bool:
    """Decide whether a checkpoint should be written now.

    Args:
        global_step: Current optimizer step.
        epoch: Current epoch (accepted for interface compatibility; not used).
        is_final: True when called after training has finished.

    Returns:
        True for the final save when ``CHECKPOINT_SAVE_AFTER_TRAINING`` is set,
        or when ``global_step`` is a positive multiple of
        ``CHECKPOINT_INTERVAL``. Interval saving is disabled when the interval
        is the sentinel ``99999999`` (or larger) or is not positive.
    """
    if is_final and CHECKPOINT_SAVE_AFTER_TRAINING:
        return True
    interval = CHECKPOINT_INTERVAL
    # A non-positive interval would divide by zero (0) or fire on arbitrary
    # steps (negative) in the modulo below, so treat it as "disabled".
    if interval <= 0 or interval >= 99999999:
        return False
    return global_step >= interval and global_step % interval == 0

class FunctionTimeoutError(Exception):
    """Timeout marker; when the wrapped function raises it, ``with_timeout`` returns None."""


def with_timeout(seconds: int):
    """Decorator factory that runs the wrapped function in a daemon thread.

    The wrapper waits up to *seconds* for the thread to finish and returns:

    * the function's return value when it finishes in time;
    * ``None`` when the thread is still running after the wait, or when the
      function raised ``FunctionTimeoutError``.

    Any other exception raised by the function is re-raised in the caller.

    The worker thread is never stopped: after a timeout it keeps running in
    the background until the function returns on its own.
    """
    import functools  # local import keeps this block self-contained

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Value and error live in separate slots so that a function which
            # legitimately *returns* an exception object is not mistaken for
            # one that raised it.
            outcome = {}

            def target():
                try:
                    outcome["value"] = func(*args, **kwargs)
                except Exception as exc:
                    outcome["error"] = exc

            worker = threading.Thread(target=target, daemon=True)
            worker.start()
            worker.join(timeout=seconds)
            if worker.is_alive():
                return None
            if "error" in outcome:
                error = outcome["error"]
                if isinstance(error, FunctionTimeoutError):
                    return None
                raise error
            # Missing "value" means the thread died on a non-Exception
            # BaseException; report that the same way as a timeout.
            return outcome.get("value")
        return wrapper
    return decorator

def get_tokenizer_special_tokens(tokenizer) -> Set[str]:
    """Return the tokenizer's special tokens plus the two language codes.

    Reads ``tokenizer.all_special_tokens`` (empty when the attribute is
    missing). If that value cannot be turned into a set, the four generic
    markers ``<pad>``, ``</s>``, ``<s>``, ``<unk>`` are used instead.
    ``SOURCE_LANGUAGE`` and ``TARGET_LANGUAGE`` are always added.

    NOTE: the tokenizer-utilities cell later in this file defines a function
    with the same name, which replaces this one once that cell has run.
    """
    fallback = {"<pad>", "</s>", "<s>", "<unk>"}
    try:
        declared = getattr(tokenizer, "all_special_tokens", [])
        tokens = set(declared)
    except Exception:
        tokens = fallback
    tokens.add(SOURCE_LANGUAGE)
    tokens.add(TARGET_LANGUAGE)
    return tokens


def get_special_tokens(tokenizer) -> Set[str]:
    """Alias for :func:`get_tokenizer_special_tokens`."""
    tokens = get_tokenizer_special_tokens(tokenizer)
    return tokens

# Shared memo of the script-based validity check, keyed by (token, language).
_token_validation_cache: Dict[Tuple[str, str], bool] = {}
_cache_lock = threading.Lock()
# Once the memo holds this many entries, new results are no longer stored.
_cache_max_size = 10000


def is_valid_token(token, special_tokens: Optional[Set[str]] = None, tokenizer=None, language: str = "bn") -> bool:
    """Return True when *token* looks like a Bengali word piece.

    A token is valid when, after removing the ``▁`` / ``##`` subword markers
    and trimming whitespace, it has at least two characters, contains at least
    one character in the Bengali block (U+0980-U+09FF), and the ratio of
    Bengali characters to alphanumeric characters is at least 0.5.

    Args:
        token: Token text; ``None`` is treated as the empty string.
        special_tokens: Tokens that are always invalid for this call.
        tokenizer: Unused; kept for interface compatibility.
        language: Only used as part of the cache key; the check itself is
            always the Bengali-script test described above.
    """
    token = "" if token is None else str(token)

    # The special-token verdict depends on the caller-supplied set, so it must
    # be decided before (and never stored in) the shared cache, whose key does
    # not include that set.
    if special_tokens and token in special_tokens:
        return False

    cache_key = (token, language)
    with _cache_lock:
        cached = _token_validation_cache.get(cache_key)
    if cached is not None:
        return cached

    clean = token.replace("▁", "").replace("##", "").strip()
    result = False
    if len(clean) >= 2:
        bengali_count = sum(1 for c in clean if "\u0980" <= c <= "\u09FF")
        alphanum_count = sum(1 for c in clean if c.isalnum())
        if bengali_count > 0 and alphanum_count > 0:
            result = bengali_count / alphanum_count >= 0.5

    with _cache_lock:
        if len(_token_validation_cache) < _cache_max_size:
            _token_validation_cache[cache_key] = result
    return result


def fallback_is_valid_token(token, special_tokens: Optional[Set[str]] = None, language: str = "bn") -> bool:
    """Call :func:`is_valid_token` without a tokenizer."""
    return is_valid_token(token, special_tokens, None, language)

def safe_tokenize_with_offsets(tokenizer, text: str, max_length: int = 512):
    """Tokenize *text* and return ``(tokens, offsets)``.

    The tokenizer is called with offsets requested, truncation to
    *max_length*, and no special tokens. When the encoding has no
    ``offset_mapping`` entry, one ``(0, 0)`` placeholder per token is returned.
    Any failure yields ``(None, None)``.
    """
    call_kwargs = dict(
        return_offsets_mapping=True,
        max_length=max_length,
        truncation=True,
        add_special_tokens=False,
    )
    try:
        encoding = tokenizer(text, **call_kwargs)
        token_ids = encoding.get("input_ids", [])
        pieces = tokenizer.convert_ids_to_tokens(token_ids)
        placeholder = [(0, 0)] * len(pieces)
        spans = encoding.get("offset_mapping", placeholder)
    except Exception:
        return None, None
    return pieces, spans

class DiscoveryTimer:
    """Records how long each discovery run took and at which step."""

    def __init__(self):
        # Parallel lists: entry i of each describes the i-th recorded run.
        self.discovery_times: List[float] = []
        self.discovery_steps: List[int] = []

    def record(self, step: int, duration: float) -> None:
        """Store one run's *duration* together with the *step* it ran at."""
        self.discovery_times.append(duration)
        self.discovery_steps.append(step)

    def get_stats(self) -> Dict[str, float]:
        """Return count, total, average and maximum of the recorded durations.

        All four values are zero when nothing has been recorded.
        """
        durations = self.discovery_times
        if not durations:
            return dict(count=0, total=0.0, avg=0.0, max=0.0)
        runs = len(durations)
        elapsed = sum(durations)
        return dict(count=runs, total=elapsed, avg=elapsed / runs, max=max(durations))


# Module-wide timer instance, exposed under two names.
_discovery_timer = DiscoveryTimer()
discoverytimer = _discovery_timer

# Seed every random number generator in use (torch, NumPy, Python, and all
# CUDA devices when available) from the configured SEED.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    try:
        torch.cuda.manual_seed_all(SEED)
    except Exception:
        pass

# Request "high" float32 matmul precision where this torch build exposes the setting.
try:
    if hasattr(torch, "set_float32_matmul_precision"):
        torch.set_float32_matmul_precision("high")
except Exception:
    pass
# cuDNN benchmark mode on, deterministic mode off: despite the seeding above,
# the run is not configured for bit-exact reproducibility.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# Samples per optimizer update: batch size times accumulation steps, times the
# GPU count in multi-GPU mode. Used for the banner below.
effective_batch = BATCH_SIZE * ACCUMULATION_STEPS
if USE_MULTI_GPU and NUM_GPUS > 1:
    effective_batch *= NUM_GPUS

# ---- Configuration banner: echo the resolved settings for the run log ----
print("\n" + "=" * 80)
print("TATN CONFIGURATION (Bengali to English)")
print("=" * 80)
print(f"User: {os.getenv('KAGGLE_USERNAME', os.getenv('USER', 'manas0003'))}")
print(f"Date: {time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime())} UTC")
# NOTE(review): this prints ENABLED for NUM_GPUS > 0, while device selection and
# effective_batch require NUM_GPUS > 1 -- confirm which is intended.
print(f"Multi-GPU: {'ENABLED' if USE_MULTI_GPU and NUM_GPUS>0 else 'DISABLED'} ({NUM_GPUS} GPUs detected)")
print(f"Device: {DEVICE}")
print(f"Dataset: {DATASET_CSV_PATH}")
print(f"Samples: {NUM_SAMPLES:,} | Batch: {BATCH_SIZE} | Accum: {ACCUMULATION_STEPS}")
print(f"Effective batch: {effective_batch}")
print(f"Max length: {MAX_LENGTH} | Epochs: {EPOCHS} | AMP: {USE_AMP}")
print()
print("DSCD Config:")
print(f"  Buffer: {DSCD_BUFFER_SIZE} | n_min: {DSCD_N_MIN} | Max protos: {DSCD_MAX_PROTOS}")
print(f"  Dispersion threshold: {DSCD_DISPERSION_THRESHOLD}")
print(f"  Use cosine distance: {DSCD_USE_COSINE_DISTANCE}")
print(f"  Online clustering: {DSCD_ENABLE_ONLINE_CLUSTERING} (freq={DSCD_ONLINE_CLUSTERING_FREQUENCY})")
print(f"  New-sense lambda: {DSCD_NEWSENSE_LAMBDA}")
print(f"  Warmup samples: {DSCD_WARMUP_SAMPLES}")
print(f"  Periodic discovery: Every {PERIODIC_DISCOVERY_FREQUENCY} optimizer updates")
print(f"  Max tokens per discovery: {_MAX_TOKENS_PER_DISCOVERY}")
print(f"  Clustering timeout: {CLUSTERING_TIMEOUT}s")
print()
print("TRG & Uncertainty:")
print(f"  MC Dropout passes: {MC_DROPOUT_PASSES} | TAU_LOW: {TAU_LOW}")
print(f"  TRG_SPAN_THRESHOLD: {TRG_SPAN_THRESHOLD} | TRG_UNCERTAINTY_THRESHOLD: {TRG_UNCERTAINTY_THRESHOLD}")
print()
print("ASBN / Loss:")
print(f"  LAMBDA_ASBN: {LAMBDA_ASBN} | LAMBDA_DSCD: {LAMBDA_DSCD}")
print(f"  Domain labels: {USE_DOMAIN_LABELS} | GRL: {GRL_ALPHA_SCHEDULE}")
print(f"  GRL steps: {GRL_ALPHA_STEPS}")
print(f"  ASBN training: {ENABLE_ASBN_TRAINING}")
print(f"  ASBN inference: {ENABLE_ASBN_INFERENCE}")
print()
print("Augmentation:")
# APPLY_DSCD_AUGMENTATION is not defined in this cell; it is read from globals
# and reported as False when absent.
print(f"  Apply DSCD augmentation: {globals().get('APPLY_DSCD_AUGMENTATION', False)}")
print()
print("Debug Flags:")
print(f"  Discovery logging: {DEBUG_DISCOVERY}")
print(f"  Timing monitoring: {DEBUG_TIMING}")
print(f"  Verbose mode: {DEBUG_VERBOSE}")
print(f"  Verbose logging: {VERBOSE_LOGGING}")
print()
print("Validation:")
print(f"  Check interval: {VALIDATION_CHECK_INTERVAL} steps")
print()
print("Language Tokens (defaults):")
print(f"  Bengali (bn): {M2M100_BN_TOKEN_ID}")
print(f"  English (en): {M2M100_EN_TOKEN_ID}")
print()
print("Checkpoint:")
print(f"  Path: {get_checkpoint_path()}")
print(f"  Save strategy: {'Final only' if CHECKPOINT_SAVE_AFTER_TRAINING else 'Interval'}")
print(f"  Save DSCD state: {SAVE_DSCD_STATE}")
print("=" * 80)

# Warn about optional dependencies that failed to import at the top of the cell.
if not _HAS_PANDAS:
    print("[WARN] pandas not available - CSV loading will fall back to builtin dataset")
if not _HAS_M2M_TOKENIZER:
    print("[WARN] M2M100 tokenizer class not detected (will be resolved at runtime if transformers available)")

# Probe the checkpoint directory by writing and deleting a small file, so an
# unwritable directory is reported now rather than at save time.
try:
    test_file = os.path.join(CHECKPOINT_DIR, ".test_write")
    with open(test_file, "w") as f:
        f.write("test")
    os.remove(test_file)
    print(f"[INFO] Checkpoint directory writable: {CHECKPOINT_DIR}")
except Exception as e:
    print(f"[WARN] Checkpoint directory not writable: {e}")

monitor_gpu_usage()

print("\n" + "=" * 80)
print("Cell 0: Configuration loaded")
print("=" * 80)


In [None]:
# ===========================================================================================
# CELL 1: TOKENIZER UTILITIES (BENGALI-FOCUSED) - FIXED
# ===========================================================================================

import threading
import unicodedata
from typing import Tuple, List, Dict, Optional, Set, Any
import numpy as np
import torch

# Local copies of configuration values. Each lookup is wrapped in try/except so
# this cell still runs with defaults when the configuration names are not
# defined (for example, when the configuration cell has not been executed).

# Default token limit for offset tokenization: MAX_LENGTH when it is a positive
# number, otherwise 48.
try:
    SAFE_OFFSET_MAX_LEN = int(MAX_LENGTH) if isinstance(MAX_LENGTH, (int, float)) and MAX_LENGTH > 0 else 48
except Exception:
    SAFE_OFFSET_MAX_LEN = 48
# int() can truncate a small positive float to 0, so re-check the result.
if SAFE_OFFSET_MAX_LEN <= 0:
    SAFE_OFFSET_MAX_LEN = 48

try:
    _SOURCE_LANG = str(SOURCE_LANGUAGE)
except Exception:
    _SOURCE_LANG = "bn"

try:
    _TARGET_LANG = str(TARGET_LANGUAGE)
except Exception:
    _TARGET_LANG = "en"

# Debug flags default to off here when the configuration names are missing.
try:
    _DEBUG_VERBOSE = bool(DEBUG_VERBOSE)
except Exception:
    _DEBUG_VERBOSE = False

try:
    _DEBUG_DISCOVERY = bool(DEBUG_DISCOVERY)
except Exception:
    _DEBUG_DISCOVERY = False

# Per-tokenizer memo of special-token sets, keyed by _special_token_cache_key().
_SPECIAL_TOKENS_CACHE: Dict[str, Set[str]] = {}
_SPECIAL_TOKENS_LOCK = threading.Lock()
# Counter and cap for the "text looks English" warning.
_LANGUAGE_WARNING_COUNT = 0
_MAX_LANGUAGE_WARNINGS = 3

_SPIECE_UNDERLINE = "\u2581"

def _special_token_cache_key(tokenizer) -> str:
    """Build a cache key from the tokenizer's name and vocabulary size.

    The name comes from ``name_or_path`` or ``name`` (``"unknown_tokenizer"``
    when neither is set); the size from ``vocab_size`` or ``len(get_vocab())``
    (``None`` when neither can be read).
    """
    if tokenizer is None:
        return "none_tokenizer__vocab=None"
    name = getattr(tokenizer, "name_or_path", None) or getattr(tokenizer, "name", None) or "unknown_tokenizer"
    vocab = None
    try:
        if hasattr(tokenizer, "vocab_size"):
            vocab = int(getattr(tokenizer, "vocab_size"))
        elif hasattr(tokenizer, "get_vocab") and callable(getattr(tokenizer, "get_vocab")):
            vocab = len(tokenizer.get_vocab())
    except Exception:
        vocab = None
    return f"{name}__vocab={vocab}"

def get_tokenizer_special_tokens(tokenizer) -> Set[str]:
    """Return the set of special tokens for *tokenizer*, memoized per tokenizer.

    Tokens are collected from ``all_special_tokens``,
    ``additional_special_tokens``, the individual ``*_token`` attributes and
    ``special_tokens_map``, then extended with a fixed list of common markers.
    When the tokenizer exposes a vocabulary, candidates are kept only if they
    are in that vocabulary or in a small always-preserved set (``</s>``,
    ``<pad>``, ``<s>``, ``<unk>``, the language tags and language codes).

    With ``tokenizer=None`` a fixed default set is returned and nothing is
    cached. The returned set is the cached object itself; do not mutate it.
    """
    if tokenizer is None:
        return {"</s>", "<pad>", "<s>", "<unk>", "__bn__", "__en__"}

    cache_key = _special_token_cache_key(tokenizer)
    with _SPECIAL_TOKENS_LOCK:
        if cache_key in _SPECIAL_TOKENS_CACHE:
            return _SPECIAL_TOKENS_CACHE[cache_key]

    special_tokens: Set[str] = set()
    try:
        # List-valued attributes.
        for list_attr in ("all_special_tokens", "additional_special_tokens"):
            if hasattr(tokenizer, list_attr):
                try:
                    result = getattr(tokenizer, list_attr)
                    if isinstance(result, (list, tuple, set)):
                        special_tokens.update(str(x) for x in result if x)
                except Exception:
                    pass
        # Single-token attributes.
        for attr in ("pad_token", "unk_token", "bos_token", "eos_token", "cls_token", "sep_token", "mask_token"):
            try:
                tok = getattr(tokenizer, attr, None)
                if tok:
                    special_tokens.add(str(tok))
            except Exception:
                pass
        # The special-tokens map may hold strings or collections of strings.
        try:
            stm = getattr(tokenizer, "special_tokens_map", None) or getattr(tokenizer, "special_tokens_map_extended", None)
            if isinstance(stm, dict):
                for v in stm.values():
                    if isinstance(v, str) and v:
                        special_tokens.add(v)
                    elif isinstance(v, (list, tuple, set)):
                        special_tokens.update(str(x) for x in v if x)
        except Exception:
            pass
    except Exception:
        special_tokens = set()

    special_tokens.update({
        "__bn__", "__en__", "</s>", "<pad>", "<s>", "<unk>",
        "[PAD]", "[EOS]", "[UNK]", "[CLS]", "[SEP]", "[MASK]",
    })

    try:
        vocab = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else {}
        # Tokens kept even when the vocabulary does not list them. "<unk>" was
        # previously misspelled "<unk__" here, which dropped it from the result.
        preserved = {"</s>", "<pad>", "<s>", "<unk>", "__bn__", "__en__"}
        try:
            preserved.add(_SOURCE_LANG)
            preserved.add(_TARGET_LANG)
        except Exception:
            preserved.update({"bn", "en"})
        if isinstance(vocab, dict):
            special_tokens = {tok for tok in special_tokens if tok in vocab or tok in preserved}
        else:
            special_tokens.update(preserved)
    except Exception:
        pass

    with _SPECIAL_TOKENS_LOCK:
        _SPECIAL_TOKENS_CACHE[cache_key] = special_tokens

    return special_tokens

def _normalize_offset_mapping_for_batchencoding(enc: dict) -> dict:
    """Normalize ``enc["offset_mapping"]`` in place to a flat list of 2-tuples.

    Accepts plain dicts and other mutable mappings (for example
    ``UserDict``-based encodings). The offset mapping may be:

    * a tensor/array-like object exposing ``tolist()``;
    * a batched sequence ``[[(start, end), ...]]`` -- the first batch item is used;
    * a flat sequence ``[(start, end), ...]`` -- used as is.

    Entries that are not a sequence of at least two items become
    ``(None, None)``. When no usable offset mapping exists, a placeholder list
    of ``(None, None)`` with one entry per input id is stored instead (empty
    when the length cannot be determined).

    Non-mapping input is returned unchanged.
    """
    from collections.abc import MutableMapping  # local import keeps the block self-contained

    if not isinstance(enc, (dict, MutableMapping)):
        return enc

    def _to_pair(item):
        if isinstance(item, (list, tuple)) and len(item) >= 2:
            return (item[0], item[1])
        return (None, None)

    try:
        raw = enc["offset_mapping"] if "offset_mapping" in enc else None
        if raw is not None and hasattr(raw, "tolist"):
            raw = raw.tolist()
        if isinstance(raw, (list, tuple)) and len(raw) > 0:
            first = raw[0]
            # Batched means the first element is itself a sequence of pairs;
            # a flat mapping's first element is a pair of scalars (or Nones).
            batched = isinstance(first, (list, tuple)) and (
                len(first) == 0 or isinstance(first[0], (list, tuple))
            )
            sequence = first if batched else raw
            enc["offset_mapping"] = [_to_pair(item) for item in sequence]
            return enc
    except Exception:
        pass

    # No usable offsets: emit one (None, None) placeholder per input id.
    try:
        seq_len = 0
        if "input_ids" in enc:
            input_ids = enc["input_ids"]
            if hasattr(input_ids, "shape") and len(input_ids.shape) > 0:
                seq_len = int(input_ids.shape[-1])
            elif isinstance(input_ids, (list, tuple)) and len(input_ids) > 0 and isinstance(input_ids[0], (list, tuple)):
                seq_len = len(input_ids[0])
        enc["offset_mapping"] = [(None, None)] * seq_len
    except Exception:
        enc["offset_mapping"] = []

    return enc

def safe_offsets_tokenize(tokenizer, text: str, max_length: Optional[int] = None, include_special_tokens: bool = False) -> dict:
    """Tokenize *text* and attach an ``offset_mapping`` entry, never raising.

    Strategy, in order:

    1. ``tokenizer is None``: return a one-token dummy encoding with an empty
       offset list.
    2. Fast tokenizer (``tokenizer.is_fast``): ask the tokenizer for offsets.
    3. Otherwise (or if step 2 raised): tokenize without offsets and rebuild
       them by searching for each token's text in the input, left to right.
    4. If tokenization itself fails: return a single pad-token encoding.

    Args:
        tokenizer: Callable tokenizer; also needs ``convert_ids_to_tokens`` for step 3.
        text: Input text; non-strings are converted with ``str()`` (``None`` becomes ``""``).
        max_length: Token limit; defaults to ``SAFE_OFFSET_MAX_LEN``.
        include_special_tokens: Passed to the tokenizer as ``add_special_tokens``.

    Returns:
        The encoding after it has been passed through
        ``_normalize_offset_mapping_for_batchencoding``.
    """
    if tokenizer is None:
        return {"input_ids": torch.tensor([[0]], dtype=torch.long), "attention_mask": torch.tensor([[1]], dtype=torch.long), "offset_mapping": []}

    if max_length is None:
        max_length = SAFE_OFFSET_MAX_LEN
    eff_max = max(1, int(max_length))

    try:
        if not isinstance(text, str):
            text = "" if text is None else str(text)
    except Exception:
        text = ""

    # Cap the characters handed to the tokenizer: 30 per allowed token, at most 8000.
    char_limit = min(eff_max * 30, 8000)
    sample_text = text[:char_limit] if len(text) > char_limit else text

    is_fast = getattr(tokenizer, "is_fast", False)

    if is_fast:
        try:
            enc = tokenizer(sample_text, return_offsets_mapping=True, return_tensors="pt", truncation=True, padding=False, max_length=eff_max, add_special_tokens=include_special_tokens)
            enc = _normalize_offset_mapping_for_batchencoding(enc)
            return enc
        except Exception:
            # Fall through to the manual-offset path below.
            pass

    try:
        enc = tokenizer(sample_text, return_tensors="pt", truncation=True, padding=False, max_length=eff_max, add_special_tokens=include_special_tokens)
    except Exception:
        # Tokenization failed entirely: return a single pad token.
        pad_id = getattr(tokenizer, "pad_token_id", 0)
        enc = {"input_ids": torch.tensor([[pad_id]], dtype=torch.long), "attention_mask": torch.tensor([[1]], dtype=torch.long)}
        enc = _normalize_offset_mapping_for_batchencoding(enc)
        return enc

    try:
        # Extract the ids of the first (only) sequence.
        input_ids = None
        try:
            input_ids = enc["input_ids"][0].tolist()
        except Exception:
            if hasattr(enc, "data") and "input_ids" in enc.data:
                try:
                    input_ids = enc.data["input_ids"][0]
                except Exception:
                    input_ids = None

        tokens: List[str] = []
        if input_ids is not None:
            try:
                tokens = tokenizer.convert_ids_to_tokens(input_ids)
            except Exception:
                tokens = []

        # Rebuild offsets by locating each token's surface text in the input.
        # cur_pos only moves forward, so repeated tokens match successive occurrences.
        offsets_list: List[Tuple[Optional[int], Optional[int]]] = []
        src = sample_text
        cur_pos = 0

        for tok in tokens:
            # Remove subword markers before searching.
            token_text = (tok or "").replace("▁", "").replace(_SPIECE_UNDERLINE, "").replace("##", "").replace("Ġ", "").strip()
            if not token_text:
                offsets_list.append((None, None))
                continue
            idx = src.find(token_text, cur_pos)
            if idx == -1:
                # Retry case-insensitively.
                idx = src.lower().find(token_text.lower(), cur_pos)
            if idx == -1:
                # Token text not found: no offset, and cur_pos is left unchanged.
                offsets_list.append((None, None))
            else:
                start = int(idx)
                end = int(idx + len(token_text))
                offsets_list.append((start, end))
                cur_pos = end

        enc["offset_mapping"] = offsets_list
        enc = _normalize_offset_mapping_for_batchencoding(enc)
        return enc
    except Exception:
        enc = _normalize_offset_mapping_for_batchencoding(enc)
        return enc

def _rws_clean_token(tok: Optional[str]) -> str:
    """Return *tok* with sub-word boundary markers ("▁", "Ġ", "##") removed."""
    return (tok or "").replace("▁", "").replace(_SPIECE_UNDERLINE, "").replace("Ġ", "").replace("##", "").strip()


def _rws_starts_word(tok: str) -> bool:
    """Return True when *tok* carries a leading word-boundary marker."""
    return tok.startswith(("▁", "Ġ", _SPIECE_UNDERLINE))


def _rws_offset_pairs(raw: Any, n_tokens: int) -> List[Tuple[Any, Any]]:
    """Coerce an ``offset_mapping`` value into a flat list of ``(start, end)`` pairs.

    Accepts both the flat layout ``[(s, e), ...]`` and the batched layout
    ``[[(s, e), ...]]``; pairs may be tuples or lists.  Entries that are not
    a pair become ``(None, None)``.  Anything that is not a non-empty
    list/tuple yields ``n_tokens`` blank pairs.
    """
    if not isinstance(raw, (list, tuple)) or len(raw) == 0:
        return [(None, None)] * n_tokens

    first = raw[0]
    # A batch row contains pairs (sequences); a single pair contains scalars/None.
    if isinstance(first, (list, tuple)) and any(isinstance(x, (list, tuple)) for x in first):
        rows = first
    else:
        rows = raw

    pairs: List[Tuple[Any, Any]] = []
    for item in rows:
        if isinstance(item, (list, tuple)) and len(item) >= 2:
            pairs.append((item[0], item[1]))
        else:
            pairs.append((None, None))
    return pairs


def _rws_commit(word: str, token_indices: List[int], words_out: List[str], mapping: Dict[int, Optional[str]]) -> None:
    """Record *word* and map every index in *token_indices* to it (no-op for an empty word)."""
    if not word:
        return
    words_out.append(word)
    for tidx in token_indices:
        mapping[tidx] = word


def reconstruct_word_spans(tokenizer, text: str, max_length: Optional[int] = None) -> Tuple[Dict[int, Optional[str]], List[str]]:
    """Group the tokens of *text* into surface words.

    The text is tokenized without special tokens via ``safe_offsets_tokenize``
    and then three strategies are tried in order, returning the first that
    produces a non-empty mapping:

    1. character offsets: tokens whose offsets touch or overlap form one word,
       and the word text is sliced from *text*;
    2. sub-word markers: a token starting with "▁"/"Ġ" opens a new word and
       unmarked tokens are appended to it (words longer than 100 characters
       are split or dropped);
    3. whitespace split of *text*, advancing one word per marked token.

    Args:
        tokenizer: tokenizer object passed to ``safe_offsets_tokenize`` and
            used for ``convert_ids_to_tokens``.  ``None`` returns ``({}, [])``.
        text: source sentence.  Non-string or blank input returns ``({}, [])``.
            It is truncated to ``min(max_length * 30, 8000)`` characters.
        max_length: token budget; defaults to ``SAFE_OFFSET_MAX_LEN``.

    Returns:
        ``(token_word_map, words)`` where ``token_word_map`` maps a token index
        to its word (or ``None`` for special/empty tokens) and ``words`` lists
        the reconstructed words in order.

    Side effects:
        May print a language warning and increments the module-level
        ``_LANGUAGE_WARNING_COUNT`` when the text looks English-only.
    """
    global _LANGUAGE_WARNING_COUNT

    if tokenizer is None:
        return {}, []

    if max_length is None:
        max_length = SAFE_OFFSET_MAX_LEN
    eff_max = max(1, int(max_length))

    if not isinstance(text, str) or len(text.strip()) == 0:
        return {}, []

    has_bengali = any('\u0980' <= c <= '\u09FF' for c in text)
    has_english = any('a' <= c.lower() <= 'z' for c in text)

    if _DEBUG_VERBOSE and _DEBUG_DISCOVERY:
        bengali_pct = (sum(1 for c in text if '\u0980' <= c <= '\u09FF') / max(1, len(text))) * 100.0
        print(f"[TOKENIZER] Text sample: {text[:50]}")
        print(f"[TOKENIZER] Bengali: {has_bengali} ({bengali_pct:.1f}%), English: {has_english}")

    if not has_bengali and has_english and _LANGUAGE_WARNING_COUNT < _MAX_LANGUAGE_WARNINGS:
        if _DEBUG_DISCOVERY:
            print("[TOKENIZER WARNING] Text appears to be ENGLISH, not BENGALI")
            print(f"  Sample: {text[:80]}")
        _LANGUAGE_WARNING_COUNT += 1
        if _LANGUAGE_WARNING_COUNT == _MAX_LANGUAGE_WARNINGS:
            print("[TOKENIZER] Suppressing further language warnings")

    # Bound the work done on pathological inputs before tokenizing.
    char_limit = min(eff_max * 30, 8000)
    text = text[:char_limit]

    special_tokens = get_tokenizer_special_tokens(tokenizer)

    try:
        encoded = safe_offsets_tokenize(tokenizer, text, max_length=eff_max, include_special_tokens=False)
    except Exception:
        return {}, []

    offsets = encoded.get("offset_mapping", [])
    try:
        input_ids = encoded["input_ids"][0].tolist()
    except Exception:
        input_ids = []
    try:
        tokens = tokenizer.convert_ids_to_tokens(input_ids) if input_ids else []
    except Exception:
        tokens = []

    if not tokens:
        return {}, []

    offsets_list = _rws_offset_pairs(offsets, len(tokens))

    token_word_map: Dict[int, Optional[str]] = {}
    words: List[str] = []

    # ---- Strategy 1: merge tokens whose character offsets touch/overlap ----
    if any(start is not None and end is not None for start, end in offsets_list):
        span_start: Optional[int] = None
        span_end: Optional[int] = None
        span_tokens: List[int] = []

        for idx, ((raw_start, raw_end), tok) in enumerate(zip(offsets_list, tokens)):
            try:
                off_start = int(raw_start) if raw_start is not None else None
                off_end = int(raw_end) if raw_end is not None else None
            except (TypeError, ValueError):
                off_start, off_end = None, None

            if off_start is None or off_end is None or tok in special_tokens:
                # Un-alignable or special token: close the open word, map this token to None.
                if span_start is not None:
                    _rws_commit(text[span_start:span_end].strip(), span_tokens, words, token_word_map)
                span_start, span_end, span_tokens = None, None, []
                token_word_map[idx] = None
                continue

            if span_start is not None and off_start <= span_end:
                # No gap between this token and the open word: extend it.
                span_end = max(span_end, off_end)
                span_tokens.append(idx)
                continue

            # Gap (or nothing open): close the previous word and start a new one.
            if span_start is not None:
                _rws_commit(text[span_start:span_end].strip(), span_tokens, words, token_word_map)
            span_start, span_end, span_tokens = off_start, off_end, [idx]

        if span_start is not None:
            _rws_commit(text[span_start:span_end].strip(), span_tokens, words, token_word_map)

        if token_word_map:
            return token_word_map, [w for w in words if isinstance(w, str) and w.strip()]

    # ---- Strategy 2: rebuild words from sub-word boundary markers ----
    token_word_map = {}
    assembled: List[str] = []
    current_parts: List[str] = []
    current_indices: List[int] = []
    max_word_len = 100

    for i, tok in enumerate(tokens):
        if tok in special_tokens:
            if current_parts:
                word = "".join(current_parts)
                if len(word) <= max_word_len:
                    _rws_commit(word, current_indices, assembled, token_word_map)
                current_parts = []
                current_indices = []
            token_word_map[i] = None
            continue

        clean = _rws_clean_token(tok)
        if not clean:
            token_word_map[i] = None
            continue

        if _rws_starts_word(tok):
            if current_parts:
                word = "".join(current_parts)
                if len(word) <= max_word_len:
                    _rws_commit(word, current_indices, assembled, token_word_map)
            current_parts = [clean]
            current_indices = [i]
            continue

        current_parts.append(clean)
        current_indices.append(i)
        if len("".join(current_parts)) > max_word_len:
            # Overlong run: emit everything before this token and restart from it.
            if current_parts[:-1]:
                _rws_commit("".join(current_parts[:-1]), current_indices[:-1], assembled, token_word_map)
            current_parts = [clean]
            current_indices = [i]

    if current_parts:
        word = "".join(current_parts)
        if len(word) <= max_word_len:
            _rws_commit(word, current_indices, assembled, token_word_map)

    if token_word_map:
        return token_word_map, [w for w in assembled if w and w.strip()]

    # ---- Strategy 3: align tokens to whitespace-separated words ----
    word_list = [w for w in text.split() if w.strip()]
    token_word_map = {}
    if word_list:
        # Start before the first word so the first real token lands on word 0
        # (previously the first marked token skipped straight to word 1).
        word_idx = -1
        for i, tok in enumerate(tokens):
            clean = _rws_clean_token(tok)
            if not clean or tok in special_tokens:
                token_word_map[i] = None
                continue
            if word_idx < 0 or _rws_starts_word(tok):
                word_idx = min(word_idx + 1, len(word_list) - 1)
            token_word_map[i] = word_list[word_idx]
    return token_word_map, word_list

def is_word_token(clean_token: str, min_letters: int = 2, min_letter_fraction: float = 0.5) -> bool:
    """Decide whether a marker-stripped token looks like a word fragment.

    A token qualifies when it is a non-empty string of at least
    ``min_letters`` characters, contains at least one alphabetic character,
    and the share of alphabetic characters among its alphanumeric characters
    is at least ``min_letter_fraction``.
    """
    if not clean_token or not isinstance(clean_token, str):
        return False
    if len(clean_token) < min_letters:
        return False

    n_alpha = 0
    n_alnum = 0
    for ch in clean_token:
        if ch.isalpha():
            n_alpha += 1
        if ch.isalnum():
            n_alnum += 1

    if n_alpha == 0 or n_alnum == 0:
        return False
    return (n_alpha / n_alnum) >= min_letter_fraction

def test_tokenizer_utilities_quick(tokenizer=None) -> bool:
    """Smoke-test the tokenizer helpers on one Bengali and one English sentence.

    Prints the encoded length, the first offsets and the reconstructed words
    for both samples.

    Args:
        tokenizer: tokenizer to exercise; when ``None`` the test is skipped.

    Returns:
        ``True`` when skipped, or when the Bengali sample yields words that
        contain Bengali characters and no Latin letters; ``False`` otherwise
        or when any step raises (the traceback is printed).
    """
    sample_bn = "কাল আমি বাজারে যাব।"
    sample_en = "Tomorrow I will go to the market."

    def _has_latin(s: str) -> bool:
        return any('a' <= c.lower() <= 'z' for c in s)

    def _has_bengali(s: str) -> bool:
        return any('\u0980' <= c <= '\u09FF' for c in s)

    print("\n" + "=" * 60)
    print("TOKENIZER UTILITIES TEST")
    print("=" * 60)

    try:
        if tokenizer is None:
            print("No tokenizer provided: skipping test")
            return True

        print("\n[TEST 1] Bengali text processing:")
        print(f"  Input: {sample_bn}")
        enc_bn = safe_offsets_tokenize(tokenizer, sample_bn, max_length=32, include_special_tokens=False)
        # Index the mapping directly instead of checking isinstance(..., dict):
        # mapping types that do not subclass dict would otherwise report "N/A".
        try:
            enc_len = int(enc_bn["input_ids"].shape[-1])
        except (KeyError, TypeError, AttributeError, IndexError):
            enc_len = "N/A"
        print(f"  Encoded length: {enc_len}")
        # Explicit None check: `x or []` raises on multi-element tensors.
        offsets_bn = enc_bn.get("offset_mapping")
        if offsets_bn is None:
            offsets_bn = []
        print(f"  Offsets (first 5): {offsets_bn[:5]}")

        token_map_bn, words_bn = reconstruct_word_spans(tokenizer, sample_bn, max_length=32)
        print(f"  Reconstructed words: {words_bn}")
        print(f"  Token map sample: {dict(list(token_map_bn.items())[:3])}")

        has_bengali_words = any(_has_bengali(w) for w in words_bn)
        print(f"  Contains Bengali words: {has_bengali_words}")

        print("\n[TEST 2] English text processing (should show warning):")
        print(f"  Input: {sample_en}")
        _, words_en = reconstruct_word_spans(tokenizer, sample_en, max_length=32)
        print(f"  Reconstructed words: {words_en}")

        has_english_words = any(_has_latin(w) for w in words_en)
        print(f"  Contains English words: {has_english_words}")

        # Only the Bengali sample decides the outcome; the English run is informational.
        if has_bengali_words and not _has_latin("".join(words_bn)):
            print("\nTest PASSED: Bengali processing works correctly")
            return True
        print("\nTest WARNING: Check language detection logic")
        return False

    except Exception as e:
        print(f"\nTest FAILED: {repr(e)}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        print("=" * 60 + "\n")

# Underscore-free aliases bound to the same function objects as the helpers above.
# NOTE(review): presumably kept for callers using the older names — confirm before removing.
safeoffsetstokenize = safe_offsets_tokenize
reconstructwordspans = reconstruct_word_spans
gettokenizerspecialtokens = get_tokenizer_special_tokens
iswordtoken = is_word_token

# Marker printed once this cell has finished defining its helpers.
print("Cell 1: Tokenizer utilities loaded")

In [None]:
# ==============================================================================
# CELL 2: MEMORY-EFFICIENT DATA LOADING (BENGALI → ENGLISH TASK) - FIXED
# ==============================================================================

from typing import Optional, List, Tuple, Dict, Any
from collections import defaultdict
import os
import time
import random
import traceback
import re

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, get_worker_info
from tqdm import tqdm

try:
    import pandas as pd
    _HAS_PANDAS = True
except Exception:
    pd = None
    _HAS_PANDAS = False

try:
    from datasets import load_dataset
    _HAS_DATASETS = True
except Exception:
    load_dataset = None
    _HAS_DATASETS = False

def _get_global(name, default):
    """Return the module-level value bound to *name*, or *default* when it is
    not defined or the lookup fails for any reason."""
    try:
        namespace = globals()
        if name in namespace:
            return namespace[name]
        return default
    except Exception:
        return default

# Mirror the notebook-wide VERBOSE_LOGGING flag; False when it is not defined
# or cannot be coerced to bool.
try:
    _VERBOSE_LOGGING = bool(_get_global("VERBOSE_LOGGING", False))
except Exception:
    _VERBOSE_LOGGING = False

# Same for DEBUG_VERBOSE.
try:
    _DEBUG_VERBOSE = bool(_get_global("DEBUG_VERBOSE", False))
except Exception:
    _DEBUG_VERBOSE = False

# Cell-2 debug output is enabled when either flag is set.
DEBUG_CELL2 = bool(_VERBOSE_LOGGING) or bool(_DEBUG_VERBOSE)
# Default number of messages cell2_dbg prints per key.
DEBUG_LIMIT = 10
# Per-key call counter used by cell2_dbg for rate limiting.
_cell2_dbg_counts: Dict[str, int] = defaultdict(int)


def cell2_dbg(key: str, msg: str, limit: int = DEBUG_LIMIT) -> None:
    """Print a rate-limited debug line.

    Nothing happens unless ``DEBUG_CELL2`` is set.  Each call bumps the
    counter for *key*; *msg* is printed only for the first *limit* calls
    made with that key.
    """
    if DEBUG_CELL2:
        seen = _cell2_dbg_counts[key] + 1
        _cell2_dbg_counts[key] = seen
        if seen <= limit:
            print(f"[CELL2-DBG] {msg}")


# ---------------------------------------------------------------------------
# Cell-2 configuration: each value is read from a same-named notebook global
# when present (via _get_global) and otherwise falls back to the literal
# default shown here.
# ---------------------------------------------------------------------------

# Default row budget for load_and_preprocess_optimized; non-positive -> 50000.
try:
    _NUM_SAMPLES = int(_get_global("NUM_SAMPLES", 50000))
    if _NUM_SAMPLES <= 0:
        _NUM_SAMPLES = 50000
except Exception:
    _NUM_SAMPLES = 50000

# Default token length used by the dataset and the word-count filter; non-positive -> 48.
try:
    _MAX_LENGTH = int(_get_global("MAX_LENGTH", 48))
    if _MAX_LENGTH <= 0:
        _MAX_LENGTH = 48
except Exception:
    _MAX_LENGTH = 48

# Language codes for the translation direction (defaults: bn -> en).
try:
    _SOURCE_LANG = str(_get_global("SOURCE_LANGUAGE", "bn"))
    _TARGET_LANG = str(_get_global("TARGET_LANGUAGE", "en"))
except Exception:
    _SOURCE_LANG = "bn"
    _TARGET_LANG = "en"

# Hard-coded fallback language-token ids.
# NOTE(review): for the public facebook/m2m100 checkpoints the Bengali id appears
# to be 128012 (128025 looks like Persian) — confirm with tokenizer.get_lang_id("bn").
try:
    _M2M_BN_TOKEN_ID = int(_get_global("M2M100_BN_TOKEN_ID", 128025))
    _M2M_EN_TOKEN_ID = int(_get_global("M2M100_EN_TOKEN_ID", 128022))
except Exception:
    _M2M_BN_TOKEN_ID = 128025
    _M2M_EN_TOKEN_ID = 128022

# GPU count defaults to what torch reports; multi-GPU defaults to "more than one GPU".
try:
    _NUM_GPUS = int(_get_global("NUM_GPUS", torch.cuda.device_count() if torch.cuda.is_available() else 0))
    _USE_MULTI_GPU = bool(_get_global("USE_MULTI_GPU", _NUM_GPUS > 1))
except Exception:
    _NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
    _USE_MULTI_GPU = _NUM_GPUS > 1

# Worker count; negative values are clamped to 0.
try:
    _NUM_WORKERS = int(_get_global("NUM_WORKERS", 0))
    if _NUM_WORKERS < 0:
        _NUM_WORKERS = 0
except Exception:
    _NUM_WORKERS = 0

try:
    _PIN_MEMORY = bool(_get_global("PIN_MEMORY", False))
except Exception:
    _PIN_MEMORY = False

# Prefetch factor; non-positive -> 2.
try:
    _PREFETCH_FACTOR = int(_get_global("PREFETCH_FACTOR", 2))
    if _PREFETCH_FACTOR <= 0:
        _PREFETCH_FACTOR = 2
except Exception:
    _PREFETCH_FACTOR = 2

# CSV read by load_and_preprocess_optimized.
try:
    _DATASET_CSV_PATH = str(_get_global("DATASET_CSV_PATH",
                                        "/kaggle/input/bengali-english-homograph/bengali_homograph_sentences.csv"))
except Exception:
    _DATASET_CSV_PATH = "/kaggle/input/bengali-english-homograph/bengali_homograph_sentences.csv"

# Domain settings; if any conversion fails all three are reset to their defaults.
try:
    _TRAIN_DOMAIN = int(_get_global("TRAIN_DOMAIN", 0))
    _TEST_DOMAIN = int(_get_global("TEST_DOMAIN", 1))
    _USE_DOMAIN_LABELS = bool(_get_global("USE_DOMAIN_LABELS", False))
except Exception:
    _TRAIN_DOMAIN = 0
    _TEST_DOMAIN = 1
    _USE_DOMAIN_LABELS = False

# Feature probes: record which optional helpers already exist in this namespace
# at the time this cell runs.
_has_normalize = ("normalize_bengali" in globals()) and ("normalize_english" in globals())
_has_reconstruct_word_spans = "reconstruct_word_spans" in globals()
_has_safe_offsets_tokenize = "safe_offsets_tokenize" in globals()

if not _has_normalize and DEBUG_CELL2:
    print("[CELL2] normalize_bengali/normalize_english not found; falling back to basic normalization")

# Matches any single character in the Bengali Unicode block (U+0980..U+09FF).
_BENGALI_CHAR_RE = re.compile(r"[\u0980-\u09FF]")


def is_bengali_text(s: Optional[str]) -> bool:
    """Return True when *s* is a non-empty string containing at least one
    character matched by ``_BENGALI_CHAR_RE``; False for anything else."""
    if isinstance(s, str) and s:
        return _BENGALI_CHAR_RE.search(s) is not None
    return False


def _get_safe_vocab_size(tokenizer) -> int:
    """Return a positive vocabulary size usable as an exclusive upper bound for token ids.

    Resolution order: ``tokenizer.vocab_size``, then ``len(tokenizer)``, then
    the built-in default of 128112.

    Args:
        tokenizer: tokenizer-like object, or ``None``.

    Returns:
        A strictly positive int.  The default is returned when the tokenizer
        is ``None``, when neither source yields a usable number, or when the
        reported size is zero or negative (callers clamp ids to
        ``size - 1``, so a non-positive size would produce an invalid bound).
    """
    default_size = 128112
    if tokenizer is None:
        return default_size
    try:
        size = getattr(tokenizer, "vocab_size", None)
        if size is None or int(size) <= 0:
            size = len(tokenizer)
        size = int(size)
    except Exception:
        # Tokenizer attributes can be properties that raise arbitrary errors.
        return default_size
    return size if size > 0 else default_size


def _dataloader_worker_init_fn(worker_id: int) -> None:
    """Per-worker DataLoader initialiser: rebind the tokenizer and seed the RNGs.

    Tokenizer: when the worker's dataset exposes ``_tokenizer_name_or_path``,
    a tokenizer is reloaded from that name and stored on the dataset; if the
    reload fails the dataset's tokenizer is set to ``None``.

    Seeding: Python's ``random`` and NumPy are seeded from
    ``torch.initial_seed()``.  Per the PyTorch data-loading notes, inside a
    worker that value is derived from the loader's base seed and the worker
    id, so the seed is distinct per worker yet reproducible under a fixed
    global seed.  The torch RNG is left untouched because it already carries
    that seed.  (The previous wall-clock-based seed was not reproducible and
    was skipped entirely when PYTHONHASHSEED was not an integer.)

    Args:
        worker_id: index of the worker, used only in log messages.
    """
    worker_info = get_worker_info()
    dataset = worker_info.dataset if worker_info is not None else None
    try:
        tok_name = getattr(dataset, "_tokenizer_name_or_path", None) if dataset is not None else None
        if tok_name:
            try:
                from transformers import AutoTokenizer
                dataset.tokenizer = AutoTokenizer.from_pretrained(tok_name, use_fast=True)
                dataset.is_fast = getattr(dataset.tokenizer, "is_fast", False)
                if DEBUG_CELL2:
                    print(f"[CELL2-WORKER-{worker_id}] Tokenizer loaded: {tok_name}")
            except Exception as e:
                cell2_dbg("worker_tokenizer_reload", f"Worker {worker_id} tokenizer reload failed: {e}")
                dataset.tokenizer = None
                dataset.is_fast = False
    except Exception:
        if DEBUG_CELL2:
            print(f"[CELL2-WORKER-INIT] Tokenizer rebind failed worker {worker_id}")

    # NumPy only accepts seeds below 2**32.
    seed = int(torch.initial_seed()) % (2 ** 32)
    random.seed(seed)
    np.random.seed(seed)


def load_and_preprocess_optimized(num_samples: Optional[int] = None, split: str = "train") -> List[Tuple[str, str]]:
    """Load Bengali -> English sentence pairs from the configured CSV.

    The first ``num_samples`` rows of ``_DATASET_CSV_PATH`` are read, the
    source/target columns are resolved by name (then by position), the
    columns are swapped if the first row looks English -> Bengali, and each
    row is filtered and normalised.

    A row is kept only when both sides are non-empty, the source contains a
    Bengali character, the target contains a Latin letter, and neither side
    exceeds ``max(20, _MAX_LENGTH // 2)`` whitespace-separated words.

    The built-in fallback dataset is returned when pandas is missing, the CSV
    is missing/unreadable/empty, distinct source and target columns cannot be
    resolved, or no row survives filtering.

    Args:
        num_samples: maximum number of CSV rows to read (before filtering);
            defaults to ``_NUM_SAMPLES``.
        split: accepted for interface compatibility; not used here.

    Returns:
        List of ``(bengali, english)`` pairs.

    Raises:
        ValueError: if ``num_samples`` is not positive.
    """
    if num_samples is None:
        num_samples = _NUM_SAMPLES
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")

    print(f"[CELL2] Loading up to {num_samples} samples from {_DATASET_CSV_PATH}")

    if not _HAS_PANDAS:
        print("[CELL2] pandas not available; using fallback dataset")
        return _get_fallback_dataset()

    if not os.path.exists(_DATASET_CSV_PATH):
        print(f"[CELL2] CSV not found at {_DATASET_CSV_PATH}; using fallback dataset")
        return _get_fallback_dataset()

    try:
        df = pd.read_csv(_DATASET_CSV_PATH, dtype=str, keep_default_na=False)
    except Exception as e:
        print(f"[CELL2] Failed to read CSV: {e}; using fallback dataset")
        return _get_fallback_dataset()

    if df is None or df.shape[0] == 0:
        print("[CELL2] CSV empty; using fallback dataset")
        return _get_fallback_dataset()

    latin_re = re.compile(r"[a-zA-Z]")
    cols = [str(c).lower().strip() for c in df.columns.tolist()]

    def _first_named(candidates):
        """Return the original column label of the first candidate present, else None."""
        for name in candidates:
            if name in cols:
                return df.columns[cols.index(name)]
        return None

    if "src" in cols and "tgt" in cols:
        src_col = df.columns[cols.index("src")]
        tgt_col = df.columns[cols.index("tgt")]
    else:
        src_col = _first_named(["source", "source_text", "sentence", "text", "bn", "bengali"])
        tgt_col = _first_named(["target", "target_text", "translation", "en", "english"])
        # Positional fallbacks; never let both roles land on the same column.
        if src_col is None and len(df.columns) >= 1:
            if df.columns[0] != tgt_col:
                src_col = df.columns[0]
            elif len(df.columns) >= 2:
                src_col = df.columns[1]
        if tgt_col is None and len(df.columns) >= 2:
            tgt_col = df.columns[1] if df.columns[1] != src_col else df.columns[0]

    if src_col is None or tgt_col is None or src_col == tgt_col:
        # e.g. a single-column CSV: indexing with a missing column would raise.
        print("[CELL2] Could not identify distinct source/target columns; using fallback dataset")
        return _get_fallback_dataset()

    df[src_col] = df[src_col].fillna("").astype(str)
    df[tgt_col] = df[tgt_col].fillna("").astype(str)

    # Direction is inferred from the first row only.
    sample_src = df[src_col].iloc[0] if len(df) > 0 else ""
    sample_tgt = df[tgt_col].iloc[0] if len(df) > 0 else ""

    src_is_bengali = bool(_BENGALI_CHAR_RE.search(str(sample_src)))
    tgt_is_bengali = bool(_BENGALI_CHAR_RE.search(str(sample_tgt)))
    src_is_english = bool(latin_re.search(str(sample_src))) and not src_is_bengali
    tgt_is_english = bool(latin_re.search(str(sample_tgt))) and not tgt_is_bengali

    if src_is_english and tgt_is_bengali:
        print("[CELL2] Detected src=English and tgt=Bengali. Swapping columns for bn->en task.")
        df = df.rename(columns={src_col: "__temp_src__", tgt_col: "__temp_tgt__"})
        df["src"] = df["__temp_tgt__"]
        df["tgt"] = df["__temp_src__"]
        src_col = "src"
        tgt_col = "tgt"
        if len(df) > 0:
            sample_src = df[src_col].iloc[0]
            sample_tgt = df[tgt_col].iloc[0]
            src_is_bengali = bool(_BENGALI_CHAR_RE.search(str(sample_src)))
            tgt_is_english = bool(latin_re.search(str(sample_tgt))) and not bool(_BENGALI_CHAR_RE.search(str(sample_tgt)))
            if not src_is_bengali or not tgt_is_english:
                print("[CELL2] WARNING: After swap columns don't clearly match bn->en. Proceeding but results may be noisy.")
    else:
        if not src_is_bengali or not tgt_is_english:
            if DEBUG_CELL2:
                print("[CELL2] Warning: detected languages may not match bn->en (proceeding)")

    df = df.head(int(num_samples))

    max_words = max(20, _MAX_LENGTH // 2)
    pairs: List[Tuple[str, str]] = []
    skipped = 0

    row_values = zip(df[src_col].tolist(), df[tgt_col].tolist())
    for raw_src, raw_tgt in tqdm(row_values, total=len(df), desc="Loading dataset"):
        src_val = str(raw_src).strip()
        tgt_val = str(raw_tgt).strip()

        if not src_val or not tgt_val:
            skipped += 1
            continue
        if not is_bengali_text(src_val) or not latin_re.search(tgt_val):
            skipped += 1
            continue
        if len(src_val.split()) > max_words or len(tgt_val.split()) > max_words:
            skipped += 1
            continue

        bn_norm = src_val
        en_norm = tgt_val.lower()
        if _has_normalize:
            try:
                bn_norm = normalize_bengali(src_val)
                en_norm = normalize_english(tgt_val)
            except Exception:
                # Best effort: keep the basic strip/lower normalisation.
                bn_norm = src_val
                en_norm = tgt_val.lower()

        if not bn_norm or not en_norm:
            skipped += 1
            continue
        pairs.append((bn_norm, en_norm))

    if DEBUG_CELL2:
        print(f"[CELL2] Loaded pairs: {len(pairs)}, skipped: {skipped}")

    if len(pairs) == 0:
        print("[CELL2] No valid pairs loaded from CSV; using fallback dataset")
        return _get_fallback_dataset()

    return pairs


def _get_fallback_dataset() -> List[Tuple[str, str]]:
    """Return the small built-in list of (Bengali, English) sentence pairs.

    The pairs are passed through ``normalize_bengali`` / ``normalize_english``
    when those helpers were detected; otherwise, or if normalisation raises,
    each side is stripped and the English side lower-cased.
    """
    seed_pairs = [
        ("আমি কল বন্ধ করেছি।", "i turned off the tap."),
        ("সে আমাকে পরে কল করবে।", "he will call me later."),
        ("আমরা প্রতিদিন তাজা ফল খাই।", "we eat fresh fruits every day."),
        ("তার কঠোর পরিশ্রমের ভালো ফল হয়েছে।", "his hard work has brought good results."),
        ("গাছে নতুন পাতাগুলো গজিয়েছে।", "new leaves have sprouted on the tree."),
        ("আমি বইয়ের পাতা উল্টাচ্ছি।", "i am turning the pages of the book."),
        ("কাল আমি বাজারে গিয়েছিলাম।", "yesterday i went to the market."),
        ("কাল আমি তোমার সাথে দেখা করব।", "tomorrow i will meet you."),
        ("তারা আকাশে উজ্জ্বল।", "the stars are bright in the sky."),
        ("তারা বাড়িতে নেই।", "they are not at home."),
        ("ব্যাংক নদীর ধারে ভেঙে গেছে।", "the bank by the river has collapsed."),
        ("আমি ব্যাংকে টাকা জমা দিয়েছি।", "i deposited money in the bank."),
        ("বার বার চেষ্টা করতে হবে।", "you have to try again and again."),
        ("আমি বার খুলে ভিতরে ঢুকলাম।", "i opened the bar and entered."),
        ("তার মাথা ব্যথা করছে।", "his head is hurting."),
        ("আমি মাথা নেড়ে সম্মতি দিলাম।", "i nodded my head in agreement."),
        ("সে হার মেনে নিয়েছে।", "he accepted defeat."),
        ("আমি গলায় সোনার হার পরেছি।", "i am wearing a gold necklace."),
        ("পানি খুব ঠান্ডা।", "the water is very cold."),
        ("আমি পানি খাচ্ছি।", "i am drinking water."),
        ("দল খেলায় জিতেছে।", "the team won the game."),
        ("বাজার থেকে সবজি কিনলাম।", "i bought vegetables from the market."),
        ("তার নাম আহমেদ।", "his name is ahmed."),
        ("নাম না করে কাজ করো।", "work without making a name."),
        ("কথা বলা বন্ধ করো।", "stop talking."),
        ("বই পড়তে ভালো লাগে।", "i like reading books."),
        ("আমি একটি নতুন বই কিনেছি।", "i bought a new book."),
        ("ঘর পরিষ্কার করা হয়েছে।", "the house has been cleaned."),
        ("আমি ঘরে বসে আছি।", "i am sitting at home."),
        ("মন ভালো নেই।", "my mind is not good."),
        ("হাত ধুয়ে নাও।", "wash your hands."),
        ("দেখতে চাই বন সুন্দর।", "the forest is beautiful."),
    ]

    def _basic(rows):
        # Minimal normalisation used when the dedicated helpers are unavailable.
        return [(bn.strip(), en.strip().lower()) for bn, en in rows]

    if not _has_normalize:
        return _basic(seed_pairs)
    try:
        return [(normalize_bengali(bn), normalize_english(en)) for bn, en in seed_pairs]
    except Exception:
        return _basic(seed_pairs)


class MemoryEfficientDataset(Dataset):
    def __init__(self, pairs: List[Tuple[str, str]], tokenizer: Any = None, max_length: Optional[int] = None, split: str = "train"):
        """Validate *pairs* and remember the tokenizer settings.

        Args:
            pairs: candidate ``(source, target)`` string pairs.  Entries that
                are not 2-item sequences of non-empty strings, or whose
                either side is longer than ``max_length * 20`` characters,
                are dropped and counted as invalid.
            tokenizer: tokenizer used for encoding; may be ``None``.
            max_length: token length (at least 1); defaults to ``_MAX_LENGTH``.
            split: split label stored on the instance.
        """
        self.max_length = max(1, int(_MAX_LENGTH if max_length is None else max_length))
        self.tokenizer = tokenizer
        self.split = split

        # Name kept so the tokenizer can be reloaded after pickling.
        try:
            self._tokenizer_name_or_path = getattr(tokenizer, "name_or_path", None)
        except Exception:
            self._tokenizer_name_or_path = None

        try:
            self.is_fast = False if self.tokenizer is None else getattr(self.tokenizer, "is_fast", False)
        except Exception:
            self.is_fast = False

        self.pairs: List[Tuple[str, str]] = []
        invalid = 0
        for i, p in enumerate(pairs):
            try:
                problem = None
                if not isinstance(p, (list, tuple)) or len(p) != 2:
                    problem = ("init_badpair", f"Bad pair structure idx={i}")
                else:
                    src, tgt = p
                    if not (isinstance(src, str) and isinstance(tgt, str)):
                        problem = ("init_badtype", f"Non-string at idx={i}")
                    elif not src or not tgt:
                        problem = ("init_empty", f"Empty at idx={i}")
                    elif len(src) > self.max_length * 20 or len(tgt) > self.max_length * 20:
                        problem = ("init_long", f"Too long at idx={i}")

                if problem is None:
                    self.pairs.append((src, tgt))
                else:
                    invalid += 1
                    cell2_dbg(*problem)
            except Exception:
                invalid += 1
                cell2_dbg("init_exc", f"Init exception idx={i}")
        if DEBUG_CELL2:
            print(f"[CELL2] Dataset init: {len(self.pairs)} valid, {invalid} invalid")

        # Special tokens: shared helper if present, else the tokenizer's own list.
        try:
            if self.tokenizer is None:
                specials = set()
            elif "get_tokenizer_special_tokens" in globals():
                specials = get_tokenizer_special_tokens(self.tokenizer)
            else:
                specials = set(getattr(self.tokenizer, "all_special_tokens", []))
        except Exception:
            specials = set()

        self.special_tokens = specials if specials else {"</s>", "<pad>", "<s>", "<unk>"}

    def __getstate__(self):
        state = self.__dict__.copy()
        state["tokenizer"] = None
        state["_tokenizer_name_or_path"] = getattr(self, "_tokenizer_name_or_path", None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.tokenizer = None
        self.is_fast = False

    def __len__(self) -> int:
        return len(self.pairs)

    def _encode_src(self, src_text: str):
        """Tokenize a source sentence and build its token -> word mapping.

        Uses ``safe_offsets_tokenize`` when that helper was detected, otherwise
        calls the tokenizer directly with padding to ``max_length``.  Token ids
        are clamped to ``[0, vocab_size - 1]``.  The word mapping comes from
        ``reconstruct_word_spans`` when available, with a marker-based
        fallback.

        NOTE(review): the ``safe_offsets_tokenize`` branch does not request
        padding here, so ``input_ids`` may be shorter than ``max_length``
        while the other branches return fixed-length tensors — confirm the
        collate function handles this.

        Args:
            src_text: source sentence (non-strings are converted with ``str``).

        Returns:
            ``(input_ids, attention_mask, tokens, token_word_map)``.  On any
            failure: ``max_length`` pad ids, an all-zero mask, ``[]`` and ``{}``.
        """
        src_text = src_text if isinstance(src_text, str) else str(src_text)
        try:
            if self.tokenizer is None:
                # Tokenizer is dropped on pickling; try the notebook-level one.
                self.tokenizer = globals().get("tokenizer", None)
                if self.tokenizer is not None:
                    self.is_fast = getattr(self.tokenizer, "is_fast", False)
            if self.tokenizer is None:
                raise RuntimeError("Tokenizer not available")

            vocab_size = _get_safe_vocab_size(self.tokenizer)

            if _has_safe_offsets_tokenize:
                enc = safe_offsets_tokenize(self.tokenizer, src_text, max_length=self.max_length)
                input_ids = enc.get("input_ids")
                attention_mask = enc.get("attention_mask", None)
                if isinstance(input_ids, torch.Tensor):
                    input_ids = input_ids.squeeze(0)
                elif isinstance(input_ids, list):
                    input_ids = torch.tensor(input_ids[0] if isinstance(input_ids[0], list) else input_ids, dtype=torch.long)
                else:
                    input_ids = torch.tensor([int(x) for x in input_ids], dtype=torch.long)
                if attention_mask is None:
                    mask_pad_id = getattr(self.tokenizer, "pad_token_id", 1)
                    if mask_pad_id is None:
                        # No pad token defined: nothing can be padding.
                        attention_mask = torch.ones_like(input_ids)
                    else:
                        attention_mask = (input_ids != mask_pad_id).long()
                elif isinstance(attention_mask, torch.Tensor) and attention_mask.dim() > 1:
                    attention_mask = attention_mask.squeeze(0)
                elif isinstance(attention_mask, list):
                    attention_mask = torch.tensor(attention_mask[0] if isinstance(attention_mask[0], list) else attention_mask, dtype=torch.long)
                try:
                    ids_list = input_ids.tolist() if isinstance(input_ids, torch.Tensor) else list(input_ids)
                    tokens = self.tokenizer.convert_ids_to_tokens(ids_list)
                except Exception:
                    tokens = []
            else:
                enc = self.tokenizer(src_text, max_length=self.max_length, padding="max_length", truncation=True, return_tensors="pt", add_special_tokens=False)
                input_ids = enc["input_ids"].squeeze(0)
                attention_mask = enc.get("attention_mask", torch.ones_like(input_ids)).squeeze(0)
                try:
                    tokens = self.tokenizer.convert_ids_to_tokens(input_ids.tolist())
                except Exception:
                    tokens = []

            input_ids = torch.clamp(input_ids, min=0, max=vocab_size - 1)

            token_word_map: Dict[int, Optional[str]] = {}
            if _has_reconstruct_word_spans:
                try:
                    wm, _words = reconstruct_word_spans(self.tokenizer, src_text, max_length=self.max_length)
                    if isinstance(wm, dict) and wm:
                        token_word_map = wm
                except Exception:
                    token_word_map = {}

            if not token_word_map and tokens:
                try:
                    word_parts: List[str] = []
                    word_indices: List[int] = []
                    for idx, tok in enumerate(tokens):
                        if not isinstance(tok, str) or tok in self.special_tokens:
                            continue
                        clean = tok.replace("▁", "").replace("Ġ", "").replace("##", "").strip()
                        if not clean:
                            continue
                        if tok.startswith(("▁", "Ġ")):
                            word_parts = [clean]
                            word_indices = [idx]
                        else:
                            word_parts.append(clean)
                            word_indices.append(idx)
                        # Track the real member indices: skipped tokens may sit
                        # between pieces, so the members need not be contiguous.
                        word = "".join(word_parts)
                        for member in word_indices:
                            token_word_map[member] = word
                except Exception:
                    token_word_map = {}

            return input_ids, attention_mask, tokens, token_word_map

        except Exception as e:
            if DEBUG_CELL2:
                cell2_dbg("encode_src_fail", f"Source encoding failed: {type(e).__name__}")
            pad_id = getattr(self.tokenizer, "pad_token_id", 1) if self.tokenizer is not None else 1
            if pad_id is None:
                # int(None) would raise from inside this handler.
                pad_id = 1
            input_ids = torch.full((self.max_length,), int(pad_id), dtype=torch.long)
            attention_mask = torch.zeros(self.max_length, dtype=torch.long)
            return input_ids, attention_mask, [], {}

    def _encode_tgt(self, tgt_text: str):
        """Tokenize a target sentence into a fixed-length label tensor.

        The text is tokenized without special tokens and without padding,
        truncated to ``max_length``, clamped to ``[0, vocab_size - 1]`` and
        right-padded with ``-100``.  The pad token id is not consulted, so
        tokenizers without a pad token are handled like any other.

        Args:
            tgt_text: target sentence (non-strings are converted with ``str``).

        Returns:
            1-D ``torch.long`` tensor of length ``max_length``; all ``-100``
            when encoding fails.
        """
        tgt_text = tgt_text if isinstance(tgt_text, str) else str(tgt_text)
        try:
            if self.tokenizer is None:
                self.tokenizer = globals().get("tokenizer", None)
            if self.tokenizer is None:
                raise RuntimeError("Tokenizer not available")

            vocab_size = _get_safe_vocab_size(self.tokenizer)

            dec = self.tokenizer(tgt_text, max_length=self.max_length, truncation=True, padding=False, return_tensors="pt", add_special_tokens=False)
            # Cast explicitly: an empty encoding may not come back as an integer tensor.
            labels = dec["input_ids"].squeeze(0).to(torch.long)[: self.max_length]
            labels = torch.clamp(labels, min=0, max=vocab_size - 1)

            pad_length = self.max_length - labels.size(0)
            if pad_length > 0:
                filler = torch.full((pad_length,), -100, dtype=torch.long)
                labels = torch.cat([labels, filler], dim=0)

            return labels

        except Exception as e:
            if DEBUG_CELL2:
                cell2_dbg("encode_tgt_fail", f"Target encoding failed: {type(e).__name__}")
            return torch.full((self.max_length,), -100, dtype=torch.long)

    def _make_safe_sample(self, reason: str = "fallback") -> Dict[str, Any]:
        """Build a placeholder sample for when a real item cannot be served.

        First encodes a fixed Bengali/English sentence pair through
        ``_encode_src`` / ``_encode_tgt``. If that raises, an all-padding
        sample (pad id 1, zero attention mask, ``-100`` labels) is returned.

        Args:
            reason: Accepted for call-site readability; not used in this method.

        Returns:
            A sample dict with the same keys ``__getitem__`` produces.
        """
        try:
            src = "আমি কল বন্ধ করেছি।"
            tgt = "i turned off the tap."
            input_ids, attention_mask, tokens, token_word_map = self._encode_src(src)
            labels = self._encode_tgt(tgt)

            raw_domain = int(_TRAIN_DOMAIN if self.split == "train" else _TEST_DOMAIN)
            # Keep the domain id inside [0, 255].
            domain_label = min(255, max(0, raw_domain))

            sample: Dict[str, Any] = {}
            sample["input_ids"] = input_ids
            sample["attention_mask"] = attention_mask
            sample["labels"] = labels
            sample["token_word_map"] = token_word_map
            sample["src_text"] = src
            sample["tokens"] = tokens
            sample["domain_label"] = domain_label
            return sample
        except Exception:
            pad_id = 1
            raw_domain = int(_TRAIN_DOMAIN if self.split == "train" else _TEST_DOMAIN)
            domain_label = min(255, max(0, raw_domain))

            length = self.max_length
            sample = {}
            sample["input_ids"] = torch.full((length,), int(pad_id), dtype=torch.long)
            sample["attention_mask"] = torch.zeros(length, dtype=torch.long)
            sample["labels"] = torch.full((length,), -100, dtype=torch.long)
            sample["token_word_map"] = {}
            sample["src_text"] = ""
            sample["tokens"] = []
            sample["domain_label"] = domain_label
            return sample

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the encoded sample at ``idx``, never raising.

        Out-of-range indices, non-string pairs and any unexpected error all
        resolve to the placeholder produced by ``_make_safe_sample``.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``labels``,
            ``token_word_map``, ``src_text``, ``tokens`` and ``domain_label``.
        """
        try:
            if idx < 0 or idx >= len(self.pairs):
                return self._make_safe_sample("oob")

            src, tgt = self.pairs[idx]
            if not (isinstance(src, str) and isinstance(tgt, str)):
                return self._make_safe_sample("bad_types")

            input_ids, attention_mask, tokens, token_word_map = self._encode_src(src)
            labels = self._encode_tgt(tgt)

            raw_domain = int(_TRAIN_DOMAIN if self.split == "train" else _TEST_DOMAIN)
            # Keep the domain id inside [0, 255].
            domain_label = min(255, max(0, raw_domain))

            item: Dict[str, Any] = {}
            item["input_ids"] = input_ids
            item["attention_mask"] = attention_mask
            item["labels"] = labels
            item["token_word_map"] = token_word_map
            item["src_text"] = src
            item["tokens"] = tokens
            item["domain_label"] = domain_label
            return item

        except Exception:
            return self._make_safe_sample("unhandled")


def _infer_pad_id_from_sample(sample: Dict[str, Any], default_pad_id: int = 1) -> int:
    """Return the pad token id of the module-level ``tokenizer``, if any.

    Args:
        sample: Accepted for interface compatibility; not inspected here.
        default_pad_id: Value returned when no global tokenizer exists, it
            has no ``pad_token_id``, or the lookup fails.

    Returns:
        The pad id as an ``int``.
    """
    try:
        # getattr on a missing (None) tokenizer simply yields None.
        pad = getattr(globals().get("tokenizer", None), "pad_token_id", None)
        if pad is not None:
            return int(pad)
    except Exception:
        pass
    return int(default_pad_id)


def _pad_or_truncate_array(tensor: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    """Flatten ``tensor`` to 1-D ``long`` and force it to exactly ``length`` items.

    Args:
        tensor: Input tensor of any shape, or ``None``.
        length: Required number of elements in the result.
        pad_value: Fill value appended when the input is too short (also
            used for the whole result when ``tensor`` is ``None``).

    Returns:
        A 1-D ``torch.long`` tensor; longer inputs are cut at ``length``.
    """
    if tensor is None:
        return torch.full((length,), int(pad_value), dtype=torch.long)

    # view() requires compatible strides; flatten() is the fallback.
    try:
        flat = tensor.view(-1).long()
    except Exception:
        flat = tensor.flatten().long()

    missing = length - flat.size(0)
    if missing > 0:
        filler = torch.full((missing,), int(pad_value), dtype=flat.dtype)
        return torch.cat([flat, filler], dim=0)
    if missing < 0:
        return flat[:length]
    return flat


def safe_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate dataset samples into fixed-width batch tensors, skipping bad ones.

    Every usable sample is flattened, padded or truncated to ``_MAX_LENGTH``
    and clamped into the vocabulary range. A sample that raises during
    processing is dropped as a whole, so all returned fields always have the
    same number of rows. If nothing usable remains, a single all-padding
    placeholder row is returned instead of raising.

    Args:
        batch: List of sample dicts; items that are not dicts or lack a
            tensor ``input_ids`` are ignored.

    Returns:
        Dict with ``input_ids``, ``attention_mask``, ``labels`` (stacked
        ``long`` tensors), ``token_word_map``, ``src_text``, ``tokens``
        (per-sample lists) and ``domain_labels`` (1-D ``long`` tensor).
    """

    def _placeholder_batch() -> Dict[str, Any]:
        # One all-padding row so callers always receive a well-formed batch.
        pad = _infer_pad_id_from_sample({}, default_pad_id=1)
        return {
            "input_ids": torch.full((1, _MAX_LENGTH), pad, dtype=torch.long),
            "attention_mask": torch.zeros(1, _MAX_LENGTH, dtype=torch.long),
            "labels": torch.full((1, _MAX_LENGTH), -100, dtype=torch.long),
            "token_word_map": [{}],
            "src_text": [""],
            "tokens": [[]],
            "domain_labels": torch.tensor([_TRAIN_DOMAIN], dtype=torch.long)
        }

    def _flatten(t):
        # view() needs compatible strides; flatten() is the fallback.
        try:
            return t.view(-1)
        except Exception:
            return t.flatten()

    if not batch:
        return _placeholder_batch()

    valid = [
        item for item in batch
        if isinstance(item, dict)
        and "input_ids" in item
        and isinstance(item["input_ids"], torch.Tensor)
    ]
    if not valid:
        return _placeholder_batch()

    pad_id = _infer_pad_id_from_sample(valid[0], default_pad_id=1)

    # Fallback vocabulary size used when no global tokenizer is available.
    vocab_size = 128112
    try:
        tk = globals().get("tokenizer", None)
        if tk is not None:
            vocab_size = _get_safe_vocab_size(tk)
    except Exception:
        pass

    rows = []
    for sample in valid:
        try:
            in_ids = sample["input_ids"]
            att = sample.get("attention_mask", None)
            lab = sample["labels"]
            domain = sample.get("domain_label", _TRAIN_DOMAIN)

            if att is None:
                att = (in_ids != pad_id).long()
            else:
                att = _flatten(att).long()

            in_ids = _pad_or_truncate_array(_flatten(in_ids), _MAX_LENGTH, pad_id)
            att = _pad_or_truncate_array(att, _MAX_LENGTH, 0)
            lab = _pad_or_truncate_array(_flatten(lab), _MAX_LENGTH, -100)

            in_ids = torch.clamp(in_ids, min=0, max=vocab_size - 1)

            # Clamp real label ids but leave the -100 ignore marker untouched.
            lab = torch.where(
                lab != -100,
                torch.clamp(lab, min=0, max=vocab_size - 1),
                torch.tensor(-100, dtype=lab.dtype),
            )

            domain_id = max(0, min(int(domain), 255))

            # Build the complete row BEFORE storing anything: a failure in any
            # field (e.g. a non-numeric domain label) must drop the whole
            # sample rather than leave the per-field lists with unequal length.
            row = (
                in_ids,
                att,
                lab,
                sample.get("token_word_map", {}),
                sample.get("src_text", ""),
                sample.get("tokens", []),
                domain_id,
            )
        except Exception:
            continue
        rows.append(row)

    if not rows:
        return _placeholder_batch()

    inputs, masks, labs, twmaps, srcs, toks, domains = (list(col) for col in zip(*rows))

    input_ids = torch.stack(inputs, dim=0)
    attention_mask = torch.stack(masks, dim=0)
    labels = torch.stack(labs, dim=0)

    try:
        domain_labels = torch.tensor(domains, dtype=torch.long)
    except Exception:
        domain_labels = torch.full((len(inputs),), _TRAIN_DOMAIN, dtype=torch.long)

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "token_word_map": twmaps,
        "src_text": srcs,
        "tokens": toks,
        "domain_labels": domain_labels
    }


def create_optimized_dataloader(dataset: Dataset, batch_size: Optional[int] = None, shuffle: bool = True, split: str = "train") -> DataLoader:
    """Create a ``DataLoader`` over ``dataset`` that uses ``safe_collate``.

    The batch size defaults to the global ``BATCH_SIZE`` (8 if unavailable)
    and, when multi-GPU mode is active, is rounded down to a multiple of the
    GPU count (never below the GPU count). Worker count is capped at
    ``cpu_count - 1``. If constructing the loader with workers fails, it is
    rebuilt with ``num_workers=0``.

    Args:
        dataset: Dataset to iterate.
        batch_size: Requested batch size, or ``None`` for the global default.
        shuffle: Forwarded to ``DataLoader``.
        split: Accepted for interface compatibility; not used in this function.

    Returns:
        The configured ``DataLoader``.
    """
    if batch_size is None:
        try:
            batch_size = int(_get_global("BATCH_SIZE", 8))
        except Exception:
            batch_size = 8
    batch_size = max(1, int(batch_size))

    requested = batch_size
    if _USE_MULTI_GPU and _NUM_GPUS > 1:
        leftover = batch_size % _NUM_GPUS
        if leftover:
            # Round down to a multiple of the GPU count; a result of zero is
            # replaced by one sample per GPU.
            batch_size = (batch_size - leftover) or _NUM_GPUS

    if batch_size != requested:
        original_batch_size = requested
        print(f"[CELL2] Adjusted batch size from {original_batch_size} to {batch_size} for DP across {_NUM_GPUS} GPUs")

    if isinstance(_NUM_WORKERS, int) and _NUM_WORKERS >= 0:
        num_workers = _NUM_WORKERS
    else:
        num_workers = 0

    try:
        # Leave one CPU core free for the main process.
        num_workers = min(num_workers, max(0, (os.cpu_count() or 1) - 1))
    except Exception:
        pass

    loader_kwargs: Dict[str, Any] = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=bool(_PIN_MEMORY and torch.cuda.is_available()),
        collate_fn=safe_collate,
        drop_last=False,
    )

    worker_only_keys = ("worker_init_fn", "prefetch_factor", "persistent_workers")
    if num_workers > 0:
        # These options are only passed when worker processes are used.
        loader_kwargs["worker_init_fn"] = _dataloader_worker_init_fn
        loader_kwargs["prefetch_factor"] = max(2, _PREFETCH_FACTOR)
        loader_kwargs["persistent_workers"] = False

    try:
        dataloader = DataLoader(**loader_kwargs)
    except Exception:
        # Retry in-process without any worker-specific options.
        loader_kwargs["num_workers"] = 0
        for key in worker_only_keys:
            loader_kwargs.pop(key, None)
        dataloader = DataLoader(**loader_kwargs)

    if _USE_MULTI_GPU and _NUM_GPUS > 1:
        per_gpu = batch_size // _NUM_GPUS
        print(f"[CELL2] DataLoader created: total_batch={batch_size}, per_gpu={per_gpu}, workers={loader_kwargs.get('num_workers', 0)}")
    else:
        print(f"[CELL2] DataLoader created: batch_size={batch_size}, workers={loader_kwargs.get('num_workers', 0)}")

    return dataloader


# Banner printed once the Cell 2 data-loading definitions above have executed.
print("Cell 2: Memory-efficient data loading ready")


In [None]:
# ==============================================================================
# CELL 3: DSCD MODULE - NaN/Inf FULLY HARDENED
# ==============================================================================

import threading
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gc
from collections import deque
import unicodedata
from typing import Optional, Dict, List, Any, Set, Tuple
from concurrent.futures import ThreadPoolExecutor, Future

# NOTE(review): presumably the step interval between periodic log lines;
# its consumer is not visible in this part of the file.
PRINT_INTERVAL = 200

# Optional clustering backends: the flags record whether each import worked.
try:
    from scipy.cluster.hierarchy import linkage, fcluster
    from scipy.spatial.distance import pdist
    HAS_CLUSTERING = True
except Exception:
    HAS_CLUSTERING = False

try:
    from sklearn.cluster import KMeans
    HAS_KMEANS = True
except Exception:
    HAS_KMEANS = False


def _dscd_setting(name, default, cast):
    """Return ``cast(<module global name>)`` or ``default``.

    The default is used when the global is missing or ``cast`` raises.
    Each setting is resolved on its own, so one missing or malformed global
    no longer resets every other user-provided setting to its default.
    """
    try:
        return cast(globals()[name])
    except Exception:
        return default


# Core DSCD settings, read from globals defined earlier in the notebook.
_DSCD_MAX_PROTOS = max(1, _dscd_setting("DSCD_MAX_PROTOS", 8, int))
_DSCD_BUFFER_SIZE = max(1, _dscd_setting("DSCD_BUFFER_SIZE", 50, int))
_DSCD_N_MIN = max(1, _dscd_setting("DSCD_N_MIN", 2, int))
_DSCD_DISPERSION_THRESHOLD = _dscd_setting("DSCD_DISPERSION_THRESHOLD", 0.70, float)
_VERBOSE_LOGGING = _dscd_setting("VERBOSE_LOGGING", False, bool)
_DSCD_ENABLE_TRAINING_CLUSTERING = _dscd_setting("DSCD_ENABLE_TRAINING_CLUSTERING", True, bool)
_DSCD_USE_COSINE_DISTANCE = _dscd_setting("DSCD_USE_COSINE_DISTANCE", True, bool)
_DSCD_ENABLE_ONLINE_CLUSTERING = _dscd_setting("DSCD_ENABLE_ONLINE_CLUSTERING", True, bool)
_DSCD_ONLINE_CLUSTERING_FREQUENCY = _dscd_setting("DSCD_ONLINE_CLUSTERING_FREQUENCY", 10, int)
_APPLY_DSCD_AUGMENTATION = _dscd_setting("APPLY_DSCD_AUGMENTATION", False, bool)

_DEBUG_DISCOVERY = _dscd_setting("DEBUG_DISCOVERY", False, bool)

# Re-reads a previously defined value of the same name, if one exists.
_MAX_TOKENS_PER_DISCOVERY = max(1, _dscd_setting("_MAX_TOKENS_PER_DISCOVERY", 150, int))

# Non-positive lambdas are replaced by the default.
_DSCD_NEW_SENSE_LAMBDA = _dscd_setting("DSCD_NEW_SENSE_LAMBDA", 1.5, float)
if _DSCD_NEW_SENSE_LAMBDA <= 0:
    _DSCD_NEW_SENSE_LAMBDA = 1.5

_HOMOGRAPH_REFERENCE_LIST_BN = _dscd_setting(
    "HOMOGRAPH_REFERENCE_LIST_BN",
    {"কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা", "পানি", "দল", "বাজার", "নাম", "কথা", "বই", "ঘর", "মন", "হাত"},
    set,
)

# Token-filter settings, each clamped to a sane range.
_DSCD_MIN_LETTERS = max(1, _dscd_setting("DSCD_MIN_LETTERS", 2, int))
_DSCD_MIN_LETTER_FRACTION = min(max(0.0, _dscd_setting("DSCD_MIN_LETTER_FRACTION", 0.5, float)), 1.0)
_DSCD_MAX_CLUSTERING_POINTS = max(1, _dscd_setting("DSCD_MAX_CLUSTERING_POINTS", 500, int))

_TRG_PUNCT_SET = set(".,!?;:-")
_PUNCT_SET = _TRG_PUNCT_SET


def normalize_token_key(token: str) -> Optional[str]:
    """Normalize a (sub)word token into a lookup key, or reject it.

    The token is NFKC-normalized, stripped of the sub-word markers ``▁``,
    ``Ġ``, ``##`` and ``</w>``, trimmed and lower-cased.

    Combining marks (Unicode category ``M*``: vowel signs, virama, nukta,
    accents) count as word characters for the letter-fraction test. They
    used to count against the token, which rejected ordinary Bengali words
    whose dependent signs outnumber their base letters.

    Args:
        token: Raw token; ``None`` is rejected, other values go through
            ``str()``.

    Returns:
        The normalized key, or ``None`` when the token is shorter than two
        characters, has fewer base letters than ``_DSCD_MIN_LETTERS``, or
        its letter-plus-mark fraction is below ``_DSCD_MIN_LETTER_FRACTION``.
    """
    if token is None:
        return None

    token = unicodedata.normalize("NFKC", str(token))
    token = token.replace("▁", "").replace("Ġ", "").replace("##", "").replace("</w>", "")
    token = token.strip().lower()
    if len(token) < 2:
        return None

    letter_count = 0
    mark_count = 0
    total_chars = 0
    for ch in token:
        if ch.isspace():
            continue
        total_chars += 1
        category = unicodedata.category(ch)
        if category.startswith("L"):
            letter_count += 1
        elif category.startswith("M"):
            mark_count += 1

    # A token without any base letter can never be a word. This also makes a
    # separate "all punctuation" test unnecessary.
    if total_chars == 0 or letter_count == 0:
        return None
    if letter_count < max(1, _DSCD_MIN_LETTERS):
        return None
    if (letter_count + mark_count) / total_chars < _DSCD_MIN_LETTER_FRACTION:
        return None
    return token


def is_word_token(token: str, min_letters: int = 2, min_letter_fraction: float = 0.6) -> bool:
    """Report whether ``token`` looks like a word rather than digits or symbols.

    Combining marks (Unicode category ``M*``, e.g. Bengali vowel signs and
    virama) count as word characters in the fraction test. Without this, a
    word such as "পানি" (two base letters, two vowel signs) scores 0.5 and
    is rejected at the default threshold. ``min_letters`` still applies to
    base letters only.

    Args:
        token: Candidate token; empty or non-string values return ``False``.
        min_letters: Minimum number of base letters (category ``L*``).
        min_letter_fraction: Minimum share of letters plus combining marks
            among the non-whitespace characters.

    Returns:
        ``True`` when both thresholds are met.
    """
    if not token or not isinstance(token, str):
        return False

    letters = 0
    marks = 0
    total = 0
    for ch in token.strip():
        if ch.isspace():
            continue
        total += 1
        category = unicodedata.category(ch)
        if category.startswith("L"):
            letters += 1
        elif category.startswith("M"):
            marks += 1

    if total == 0 or letters < min_letters:
        return False
    return (letters + marks) / total >= min_letter_fraction


def reconstruct_word_embeddings(
    token_embeddings: torch.Tensor,
    input_ids: Optional[torch.Tensor],
    tokenizer,
    device: torch.device
) -> Tuple[torch.Tensor, List[Dict[int, str]]]:
    """Pool sub-word token embeddings into one embedding per word.

    Tokens are obtained with ``tokenizer.convert_ids_to_tokens``; when
    ``input_ids`` is ``None`` or the conversion fails, placeholder tokens
    ``tok0, tok1, ...`` are used (these carry no word-start marker, so no
    word is produced). A word starts at a token beginning with ``▁`` or
    ``Ġ`` and ends at the next word start or special token. Each word
    embedding is the mean of its token embeddings.

    Args:
        token_embeddings: Tensor indexed as ``[batch, token, dim]``.
        input_ids: Token ids indexed per batch row, or ``None``.
        tokenizer: Object providing ``convert_ids_to_tokens``.
        device: Device for the returned word embeddings.

    Returns:
        ``(word_embeddings, word_maps)``: a float32 tensor of shape
        ``[batch, max_words, dim]`` zero-padded along the word axis (a row
        without words contributes a single zero vector), and one dict per
        row mapping word index to the word text.
    """
    special_tokens = {"<s>", "</s>", "<pad>", "<unk>", "[CLS]", "[SEP]", "[BOS]", "[EOS]"}
    num_rows = int(token_embeddings.size(0))
    num_tokens = int(token_embeddings.size(1))
    hidden = int(token_embeddings.size(-1))

    per_row_embeddings: List[torch.Tensor] = []
    per_row_maps: List[Dict[int, str]] = []

    for row in range(num_rows):
        use_placeholder = input_ids is None
        if not use_placeholder:
            try:
                tokens = tokenizer.convert_ids_to_tokens(input_ids[row].tolist())
            except Exception:
                use_placeholder = True
        if use_placeholder:
            tokens = [f"tok{i}" for i in range(num_tokens)]

        # Each span is (word_text, first_token_index, end_index_exclusive).
        spans: List[Tuple[str, int, int]] = []
        buffer = ""
        first: Optional[int] = None

        for pos, tok in enumerate(tokens):
            tok = tok if isinstance(tok, str) else str(tok)
            is_special = tok in special_tokens
            starts_word = (not is_special) and tok.startswith(("▁", "\u2581", "Ġ"))

            if not (is_special or starts_word):
                # Continuation piece: extend the word being built.
                buffer += tok.lstrip("▁\u2581Ġ")
                continue

            # A boundary closes the pending word, if one was actually opened.
            if buffer and first is not None:
                spans.append((buffer, first, pos))

            if is_special:
                buffer, first = "", None
            else:
                buffer, first = tok.lstrip("▁\u2581Ġ"), pos

        if buffer and first is not None:
            spans.append((buffer, first, len(tokens)))

        pooled: List[torch.Tensor] = []
        names: Dict[int, str] = {}
        for w_idx, (text, start, end) in enumerate(spans):
            in_range = end > start and 0 <= start < num_tokens and end <= num_tokens
            if not in_range:
                continue
            pieces = token_embeddings[row, start:end, :]
            if pieces.numel() == 0:
                continue
            pooled.append(pieces.mean(dim=0))
            names[w_idx] = text.replace("▁", "").strip()

        if pooled:
            row_tensor = torch.stack(pooled, dim=0).to(device).to(dtype=torch.float32)
        else:
            row_tensor = torch.zeros(1, hidden, device=device, dtype=torch.float32)

        per_row_embeddings.append(row_tensor)
        per_row_maps.append(names)

    # Zero-pad every row to the largest word count so rows can be stacked.
    widest = max((t.size(0) for t in per_row_embeddings), default=1)
    padded = []
    for t in per_row_embeddings:
        gap = widest - t.size(0)
        if gap > 0:
            filler = torch.zeros(gap, t.size(1), device=t.device, dtype=t.dtype)
            t = torch.cat([t, filler], dim=0)
        padded.append(t)

    return torch.stack(padded, dim=0), per_row_maps


class MemoryEfficientPrototypeStore:
    """Bounded store of normalized prototype centroids with rolling statistics.

    Centroids are kept as float32 CPU tensors together with a usage count
    and creation time. Rolling mean (``mu``) and deviation (``tau``) of
    assignment distances feed ``get_adaptive_threshold``. Vectors containing
    NaN, or that cannot be converted to a tensor, are rejected with a
    warning instead of being stored as an all-zero prototype.
    """

    def __init__(self, embed_dim, max_protos: Optional[int] = None):
        """Create an empty store.

        Args:
            embed_dim: Expected number of elements per centroid (min 1).
            max_protos: Capacity; defaults to the module-level
                ``_DSCD_MAX_PROTOS`` when ``None`` (min 1).
        """
        if max_protos is None:
            max_protos = _DSCD_MAX_PROTOS
        self.embed_dim = max(1, int(embed_dim))
        self.max_protos = max(1, int(max_protos))

        self.centroids: List[torch.Tensor] = []
        self.counts: List[int] = []
        self.creation_time: List[float] = []
        self.distances: List[float] = []

        self.mu = 0.0      # rolling mean of assignment distances
        self.tau = 1e-3    # rolling absolute deviation of assignment distances
        self.alpha = 0.1   # smoothing factor for mu / tau
        self.labels: Optional[torch.Tensor] = None
        self.corruption_warnings = 0

    def _warn(self, message: str) -> None:
        """Count a corruption event; print it for the first ten if logging is on."""
        flags = globals()
        logging_on = flags.get("_VERBOSE_LOGGING", False) or flags.get("_DEBUG_DISCOVERY", False)
        if logging_on and self.corruption_warnings < 10:
            print(message)
        self.corruption_warnings += 1

    def normalize_vec(self, v: torch.Tensor) -> torch.Tensor:
        """Return ``v`` as a clamped, L2-normalized float32 CPU tensor.

        Non-finite input yields zeros of the same shape; vectors with a norm
        below 1e-6 are returned un-normalized; any error yields zeros of
        length ``embed_dim``.
        """
        try:
            v = v.detach().cpu().to(dtype=torch.float32)
            if not torch.isfinite(v).all():
                self._warn("⚠️  DSCD-STORE: Non-finite vector detected, returning zeros")
                return torch.zeros_like(v)
            v = torch.clamp(v, min=-100.0, max=100.0)
            norm = v.norm()
            if norm.item() < 1e-6:
                return v
            return v / (norm + 1e-9)
        except Exception:
            return torch.zeros(self.embed_dim, dtype=torch.float32)

    def validate_centroid(self, v: torch.Tensor) -> bool:
        """Check that ``v`` is a finite float tensor with ``embed_dim`` elements and norm <= 1000."""
        try:
            if not isinstance(v, torch.Tensor):
                return False
            if v.numel() != self.embed_dim:
                return False
            if not torch.isfinite(v).all():
                return False
            if v.dtype not in [torch.float32, torch.float64, torch.float16]:
                return False
            norm = v.norm().item()
            if not np.isfinite(norm) or norm > 1000.0:
                return False
            return True
        except Exception:
            return False

    def add_prototype(self, vector: torch.Tensor, current_time: Optional[float] = None, count: int = 1) -> None:
        """Insert a new prototype, evicting the least-used one when full.

        Args:
            vector: Candidate centroid (tensor or array-like).
            current_time: Creation timestamp; defaults to ``time.time()``.
            count: Initial usage count (min 1).

        The vector is rejected, with ``corruption_warnings`` incremented,
        when it cannot be converted, contains NaN, or fails
        ``validate_centroid``.
        """
        if current_time is None:
            current_time = time.time()

        try:
            v = vector.detach().cpu().clone().to(dtype=torch.float32)
        except Exception:
            try:
                v = torch.tensor(np.asarray(vector), dtype=torch.float32)
            except Exception:
                # Unusable input must not become a degenerate zero prototype.
                self._warn("⚠️  DSCD-STORE: Invalid centroid rejected in add_prototype")
                return

        # normalize_vec() would turn a NaN vector into zeros, which passes
        # validate_centroid(); reject it here. Infinities are still clamped.
        if torch.isnan(v).any():
            self._warn("⚠️  DSCD-STORE: Invalid centroid rejected in add_prototype")
            return

        v = torch.clamp(v, min=-100.0, max=100.0)
        v = self.normalize_vec(v)

        if not self.validate_centroid(v):
            self._warn("⚠️  DSCD-STORE: Invalid centroid rejected in add_prototype")
            return

        if len(self.centroids) < self.max_protos:
            self.centroids.append(v)
            self.counts.append(max(1, int(count)))
            self.creation_time.append(float(current_time))
        else:
            # Store is full: overwrite the prototype with the smallest count.
            try:
                min_idx = int(np.argmin(self.counts) if self.counts else 0)
            except Exception:
                min_idx = 0
            min_idx = max(0, min(min_idx, len(self.centroids) - 1))
            self.centroids[min_idx] = v
            while len(self.counts) <= min_idx:
                self.counts.append(1)
            self.counts[min_idx] = max(1, int(count))
            while len(self.creation_time) <= min_idx:
                self.creation_time.append(float(current_time))
            self.creation_time[min_idx] = float(current_time)

    def update_prototype(self, idx: int, vector: torch.Tensor, eta: float = 0.05, assignment_distance: Optional[float] = None) -> None:
        """Move prototype ``idx`` toward ``vector`` and bump its count.

        Args:
            idx: Prototype index; an out-of-range index adds ``vector`` as a
                new prototype instead.
            vector: New observation.
            eta: Interpolation weight of the observation, clamped to [0, 0.5].
            assignment_distance: If given, folded into the rolling statistics
                after a successful update.

        A NaN observation is ignored (no centroid change, no count or
        statistics update) rather than being counted as a valid assignment.
        """
        if idx < 0 or idx >= len(self.centroids):
            self.add_prototype(vector, time.time(), count=1)
            return

        eta = max(0.0, min(0.5, float(eta)))

        try:
            old_centroid = self.centroids[idx]
            new_vector = vector.detach().cpu().to(dtype=torch.float32)
            if torch.isnan(new_vector).any():
                self._warn("⚠️  DSCD-STORE: Invalid centroid rejected in update_prototype")
                return
            new_vector = torch.clamp(new_vector, min=-100.0, max=100.0)
            new_vector = self.normalize_vec(new_vector)

            if not isinstance(old_centroid, torch.Tensor):
                old_centroid = torch.tensor(np.asarray(old_centroid), dtype=torch.float32)
            old_centroid = self.normalize_vec(old_centroid)

            updated = (1.0 - eta) * old_centroid + eta * new_vector
            updated = torch.clamp(updated, min=-100.0, max=100.0)
            updated = self.normalize_vec(updated)

            if not self.validate_centroid(updated):
                self._warn("⚠️  DSCD-STORE: Invalid centroid rejected in update_prototype")
                return

            self.centroids[idx] = updated
            while len(self.counts) <= idx:
                self.counts.append(1)
            self.counts[idx] = int(self.counts[idx]) + 1
        except Exception as e:
            self._warn(f"⚠️  DSCD-STORE: update_prototype failed: {type(e).__name__}")
            return

        if assignment_distance is not None:
            self.update_rolling_stats(float(assignment_distance))

    def update_rolling_stats(self, d: float) -> None:
        """Fold distance ``d`` (clamped to [0, 10]) into ``mu`` / ``tau``; keep the last 50 distances."""
        d = max(0.0, min(10.0, float(d)))

        if not self.distances:
            # First observation seeds the statistics.
            self.mu = float(d)
            self.tau = max(1e-3, min(10.0, abs(float(d) * 0.1)))
            self.distances = [float(d)]
            return

        prev_mu = self.mu
        self.mu = (1 - self.alpha) * self.mu + self.alpha * float(d)
        self.mu = max(0.0, min(10.0, self.mu))

        self.tau = (1 - self.alpha) * self.tau + self.alpha * abs(float(d) - prev_mu)
        self.tau = max(1e-3, min(10.0, self.tau))

        self.distances.append(float(d))
        if len(self.distances) > 50:
            self.distances.pop(0)

    def get_adaptive_threshold(self, lam: float = 1.0) -> float:
        """Return ``mu + lam * tau`` with ``lam`` clamped to [0, 5] and the result to [0, 10]."""
        lam = max(0.0, min(5.0, float(lam)))
        return float(max(0.0, min(10.0, self.mu + lam * self.tau)))

    def size(self) -> int:
        """Return the number of stored centroids."""
        return len(self.centroids)

    def ensure_consistency(self) -> None:
        """Re-align the parallel lists and drop centroids that fail validation."""
        n = len(self.centroids)
        if len(self.counts) != n:
            if len(self.counts) > n:
                self.counts = self.counts[:n]
            else:
                self.counts.extend([1] * (n - len(self.counts)))

        if len(self.creation_time) != n:
            if len(self.creation_time) > n:
                self.creation_time = self.creation_time[:n]
            else:
                self.creation_time.extend([time.time()] * (n - len(self.creation_time)))

        valid_centroids = []
        valid_counts = []
        valid_times = []
        for i in range(n):
            if self.validate_centroid(self.centroids[i]):
                valid_centroids.append(self.centroids[i])
                valid_counts.append(self.counts[i] if i < len(self.counts) else 1)
                valid_times.append(self.creation_time[i] if i < len(self.creation_time) else time.time())

        if len(valid_centroids) != n:
            self._warn(f"⚠️  DSCD-STORE: Removed {n - len(valid_centroids)} corrupted centroids in ensure_consistency")

        self.centroids = valid_centroids
        self.counts = valid_counts
        self.creation_time = valid_times


class SigmaNet(nn.Module):
    """Small MLP mapping an embedding to one scalar per input row.

    Layout: Linear(embed_dim, 256) -> ReLU -> Dropout(0.1) ->
    Linear(256, 256) -> ReLU -> Dropout(0.1) -> Linear(256, 1). Linear
    weights use Xavier-uniform initialization with gain 0.5 and biases
    start at zero.
    """

    def __init__(self, embed_dim: int = 1024):
        super().__init__()
        in_features = max(1, int(embed_dim))
        hidden = 256

        layers = [
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, 1),
        ]
        self.mlp = nn.Sequential(*layers)

        # Re-initialize only the linear layers, in declaration order.
        for layer in self.mlp:
            if not isinstance(layer, nn.Linear):
                continue
            nn.init.xavier_uniform_(layer.weight, gain=0.5)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x`` and return its output."""
        return self.mlp(x)


class MemoryEfficientDSCDOnline(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        tokenizer=None,
        buffer_size: Optional[int] = None,
        max_protos: Optional[int] = None,
        n_min: Optional[int] = None,
        dispersion_threshold: Optional[float] = None,
        language: str = "bn",
        enable_training_clustering: Optional[bool] = None,
        max_clustering_points: Optional[int] = None,
        max_candidates_per_step: int = 2,
        dscd_min_letters: int = 2,
        dscd_min_letter_fraction: float = 0.6,
    ):
        """Set up configuration, per-token containers, locks and the sigma network.

        Args:
            embed_dim: Embedding width; clamped to at least 1.
            tokenizer: Optional tokenizer, stored and used to collect special tokens.
            buffer_size, max_protos, n_min, dispersion_threshold,
            max_clustering_points, enable_training_clustering: When left as
                ``None`` the matching module-level ``_DSCD_*`` constant is used.
                Integer values are clamped to at least 1 and the threshold to
                at least 0.0.
            language: Stored as-is on the instance.
            max_candidates_per_step: Clamped to at least 1.
            dscd_min_letters: Clamped to at least 1.
            dscd_min_letter_fraction: Clamped into ``[0.0, 1.0]``.
        """
        super().__init__()

        # Resolve every argument left as None from the module-level defaults.
        buffer_size = _DSCD_BUFFER_SIZE if buffer_size is None else buffer_size
        max_protos = _DSCD_MAX_PROTOS if max_protos is None else max_protos
        n_min = _DSCD_N_MIN if n_min is None else n_min
        dispersion_threshold = (
            _DSCD_DISPERSION_THRESHOLD if dispersion_threshold is None else dispersion_threshold
        )
        max_clustering_points = (
            _DSCD_MAX_CLUSTERING_POINTS if max_clustering_points is None else max_clustering_points
        )
        enable_training_clustering = (
            _DSCD_ENABLE_TRAINING_CLUSTERING
            if enable_training_clustering is None
            else enable_training_clustering
        )

        # Core sizes and thresholds, clamped to sane lower bounds.
        self.embed_dim = max(1, int(embed_dim))
        self.buffer_size = max(1, int(buffer_size))
        self.max_protos = max(1, int(max_protos))
        self.n_min = max(1, int(n_min))
        self.dispersion_threshold = max(0.0, float(dispersion_threshold))
        self.language = language
        self.tokenizer = tokenizer

        self.dscd_min_letters = max(1, int(dscd_min_letters))
        self.dscd_min_letter_fraction = min(max(0.0, float(dscd_min_letter_fraction)), 1.0)

        # Behaviour switches taken from module-level configuration.
        self.use_cosine_distance = bool(_DSCD_USE_COSINE_DISTANCE)
        self.enable_online_clustering = bool(_DSCD_ENABLE_ONLINE_CLUSTERING)
        self.online_clustering_frequency = max(1, int(_DSCD_ONLINE_CLUSTERING_FREQUENCY))
        self.apply_augmentation = bool(_APPLY_DSCD_AUGMENTATION)

        self.sigma_net = SigmaNet(embed_dim=self.embed_dim)

        # Special tokens: prefer the file-level helper when it exists, else
        # fall back to the tokenizer attribute; any failure yields an empty set.
        try:
            if tokenizer is None:
                special = set()
            elif "get_tokenizer_special_tokens" in globals():
                special = get_tokenizer_special_tokens(tokenizer)
            else:
                special = set(getattr(tokenizer, "all_special_tokens", []))
        except Exception:
            special = set()
        self.special_tokens = special

        # Memoised should_track_token decisions (cleared once they exceed the cap).
        self.dscd_allowed_tokens: Set[str] = set()
        self.dscd_ignored_tokens: Set[str] = set()
        self.dscd_cache_max_size = 10000

        # Per-token prototype stores, embedding buffers and discovery records.
        self.prototype_stores: Dict[str, MemoryEfficientPrototypeStore] = {}
        self.buffers: Dict[str, deque] = {}
        self.buffers_raw: Dict[str, deque] = {}
        self.discovered_log: List[Dict[str, Any]] = []
        self.discovered_homographs: Set[str] = set()

        self.last_periodic_check = 0
        self.cleanup_counter = 0

        # Dispersion cache with its own lock.
        self.dispersion_cache: Dict[str, float] = {}
        self.dispersion_last_updated: Dict[str, float] = {}
        self.dispersion_lock = threading.Lock()

        self.clustering_lock = threading.Lock()
        self.buffer_lock = threading.Lock()

        # Single-worker executor plus a bounded record of submitted work.
        self.cluster_executor = ThreadPoolExecutor(max_workers=1)
        from collections import deque as thread_deque
        self.active_threads = thread_deque(maxlen=100)
        self.thread_lock = threading.Lock()

        self.last_cluster_time: Dict[str, float] = {}
        self.cluster_cooldown_seconds = 2.0

        self.enable_training_clustering = bool(enable_training_clustering)

        # Discovery bookkeeping.
        self.discovery_count = 0
        self.discovery_times: List[float] = []
        self.clustered_tokens: Set[str] = set()
        self.cluster_stats: Dict[str, Dict[str, Any]] = {}

        self.max_clustering_points = max(1, int(max_clustering_points))
        self.max_candidates_per_step = max(1, int(max_candidates_per_step))

        self.token_addition_counts: Dict[str, int] = {}

        # Counters used to rate-limit serialisation warnings.
        self.state_dict_errors = 0
        self.load_state_dict_errors = 0

        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            banner = "=" * 80
            print(banner)
            print("DSCD-INIT: MemoryEfficientDSCDOnline initialized - NaN/Inf PROTECTED")
            print(banner)
            print(f"  - embed_dim={self.embed_dim}, max_protos={self.max_protos}")
            print(f"  - n_min={self.n_min}, dispersion_threshold={self.dispersion_threshold}")
            print(f"  - use_cosine_distance={self.use_cosine_distance}")
            print(f"  - Centroid norm clamp: [-100, 100]")
            print(f"  - Distance clamp: [0, 10]")
            print(f"  - Rolling stats bounds: [0,10], [1e-3,10]")
            print(banner)
    
    def state_dict(self, destination=None, prefix='', keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)
        
        plain_stores = {}
        corrupted_count = 0
        for token, store in self.prototype_stores.items():
            try:
                cent_list = []
                for c in getattr(store, "centroids", []):
                    try:
                        if isinstance(c, torch.Tensor):
                            c_cpu = c.detach().cpu().to(dtype=torch.float32)
                            if not torch.isfinite(c_cpu).all():
                                corrupted_count += 1
                                continue
                            if c_cpu.numel() != self.embed_dim:
                                corrupted_count += 1
                                continue
                            cent_list.append(c_cpu)
                        else:
                            arr = np.asarray(c, dtype=np.float32)
                            if not np.isfinite(arr).all():
                                corrupted_count += 1
                                continue
                            if arr.size != self.embed_dim:
                                corrupted_count += 1
                                continue
                            cent_list.append(torch.from_numpy(arr).to(dtype=torch.float32))
                    except Exception:
                        corrupted_count += 1
                        continue
                
                if cent_list:
                    try:
                        centroids_tensor = torch.stack(cent_list, dim=0)
                    except Exception:
                        corrupted_count += 1
                        continue
                else:
                    continue
                
                counts = list(getattr(store, "counts", []))
                if len(counts) != len(cent_list):
                    counts = [1] * len(cent_list)
                
                creation_time = list(getattr(store, "creation_time", []))
                if len(creation_time) != len(cent_list):
                    creation_time = [time.time()] * len(cent_list)
                
                plain_stores[token] = {
                    "centroids": centroids_tensor,
                    "counts": [int(c) for c in counts],
                    "creation_time": [float(t) for t in creation_time],
                    "mu": max(0.0, min(10.0, float(getattr(store, "mu", 0.0)))),
                    "tau": max(1e-3, min(10.0, float(getattr(store, "tau", 1e-3)))),
                    "size": len(cent_list),
                }
            except Exception as e:
                corrupted_count += 1
                if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and self.state_dict_errors < 10:
                    print(f"⚠️  DSCD state_dict failed for token {token}: {type(e).__name__}")
                self.state_dict_errors += 1
                continue
        
        state[prefix + "prototype_stores_data"] = plain_stores
        state[prefix + "discovered_homographs"] = list(self.discovered_homographs)
        
        if corrupted_count > 0 and (_VERBOSE_LOGGING or _DEBUG_DISCOVERY):
            print(f"⚠️  DSCD state_dict: Skipped {corrupted_count} corrupted centroids")
        
        return state
    
    def load_state_dict(self, state_dict, strict=True):
        prefix = ""
        plain_stores = state_dict.pop(prefix + "prototype_stores_data", {})
        discovered = state_dict.pop(prefix + "discovered_homographs", [])
        
        super().load_state_dict(state_dict, strict=strict)
        
        if not plain_stores:
            return
        
        self.prototype_stores = {}
        self.discovered_homographs = set(discovered) if discovered else set()
        
        corrupted_count = 0
        loaded_count = 0
        for token, store_dict in plain_stores.items():
            try:
                store = MemoryEfficientPrototypeStore(embed_dim=self.embed_dim, max_protos=self.max_protos)
                
                centroids_data = store_dict.get("centroids", torch.empty(0, self.embed_dim, dtype=torch.float32))
                store.centroids = []
                
                try:
                    if isinstance(centroids_data, torch.Tensor) and centroids_data.numel() > 0:
                        if centroids_data.dim() == 2 and centroids_data.size(1) == self.embed_dim:
                            for i in range(centroids_data.size(0)):
                                t = centroids_data[i].detach().cpu().to(dtype=torch.float32)
                                if torch.isfinite(t).all():
                                    store.centroids.append(t)
                                else:
                                    corrupted_count += 1
                        else:
                            corrupted_count += 1
                            continue
                    elif isinstance(centroids_data, list):
                        for c in centroids_data:
                            if isinstance(c, torch.Tensor):
                                c_cpu = c.detach().cpu().to(dtype=torch.float32)
                                if c_cpu.numel() == self.embed_dim and torch.isfinite(c_cpu).all():
                                    store.centroids.append(c_cpu)
                                else:
                                    corrupted_count += 1
                            else:
                                try:
                                    arr = np.asarray(c, dtype=np.float32)
                                    if arr.size == self.embed_dim and np.isfinite(arr).all():
                                        store.centroids.append(torch.from_numpy(arr).to(dtype=torch.float32))
                                    else:
                                        corrupted_count += 1
                                except Exception:
                                    corrupted_count += 1
                    else:
                        corrupted_count += 1
                        continue
                except Exception as e:
                    if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and self.load_state_dict_errors < 10:
                        print(f"⚠️  DSCD load_state_dict centroid parsing failed for {token}: {type(e).__name__}")
                    self.load_state_dict_errors += 1
                    corrupted_count += 1
                    continue
                
                if not store.centroids:
                    corrupted_count += 1
                    continue
                
                store.counts = store_dict.get("counts", [])
                store.creation_time = store_dict.get("creation_time", [])
                store.mu = max(0.0, min(10.0, float(store_dict.get("mu", 0.0))))
                store.tau = max(1e-3, min(10.0, float(store_dict.get("tau", 1e-3))))
                
                try:
                    store.ensure_consistency()
                except Exception as e:
                    if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and self.load_state_dict_errors < 10:
                        print(f"⚠️  DSCD ensure_consistency failed for {token}: {type(e).__name__}")
                    self.load_state_dict_errors += 1
                    corrupted_count += 1
                    continue
                
                if store.size() > 0:
                    self.prototype_stores[token] = store
                    loaded_count += 1
                else:
                    corrupted_count += 1
                
            except Exception as e:
                if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and self.load_state_dict_errors < 10:
                    print(f"⚠️  DSCD load_state_dict failed for token {token}: {type(e).__name__}")
                self.load_state_dict_errors += 1
                corrupted_count += 1
                continue
        
        if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
            print(f"✅ DSCD load_state_dict: Loaded {loaded_count} stores, skipped {corrupted_count} corrupted")
    
    @staticmethod
    def clean_token(token):
        if token is None:
            return token
        token = unicodedata.normalize("NFKC", str(token))
        token = token.replace("▁", "").replace("Ġ", "").replace("##", "")
        for punct in [",", ".", ",", "!", "?", ";"]:
            token = token.replace(punct, "")
        return token.strip()
    
    def is_valid_multisense(self, token):
        if token not in self.prototype_stores:
            return False
        store = self.prototype_stores[token]
        total_occurrences = sum(store.counts) if getattr(store, "counts", None) else 0
        min_per_proto = min(store.counts) if getattr(store, "counts", None) else 0
        return store.size() >= 2 and total_occurrences >= 10 and min_per_proto >= 2
    
    def get_prototype_summary(self) -> Dict[str, Any]:
        total_tokens = len(self.prototype_stores)
        total_prototypes = sum(store.size() for store in self.prototype_stores.values())
        num_homographs = sum(1 for store in self.prototype_stores.values() if self.is_multisense_store(store))
        return {
            "total_tokens": total_tokens,
            "total_prototypes": total_prototypes,
            "num_homographs": num_homographs,
        }
    
    def is_multisense_store(self, store: MemoryEfficientPrototypeStore) -> bool:
        k = store.size()
        if k < 2:
            return False
        
        counts = store.counts if store.counts else [1] * k
        strong = sum(1 for c in counts if c >= max(2, self.n_min // 2))
        if strong < 2:
            return False
        
        try:
            cents = []
            for c in store.centroids:
                if isinstance(c, torch.Tensor):
                    cents.append(c.detach().cpu().numpy())
                else:
                    cents.append(np.asarray(c, dtype=np.float32))
            
            if len(cents) < 2:
                return False
            
            cents = np.stack(cents, axis=0)
            if not np.isfinite(cents).all():
                return False
            
            dists = np.linalg.norm(cents[:, None, :] - cents[None, :, :], axis=-1)
            dists = np.clip(dists, 0.0, 100.0)
            tri_idx = np.triu_indices(len(cents), 1)
            if tri_idx[0].size == 0:
                return False
            
            tri = dists[tri_idx]
            if tri.size == 0:
                return False
            
            min_dist = float(np.clip(tri.min(), 0.0, 100.0))
            base = max(store.tau, 1e-3)
            return min_dist >= base * _DSCD_NEW_SENSE_LAMBDA
        except Exception:
            return False
    
    def discover_homographs_for_tokens(self, token_names: List[str], min_cluster_samples: int, dispersion_threshold: float, global_step: int) -> int:
        discovered_in_run: List[str] = []
        for token in token_names:
            try:
                success = self.cluster_buffer_to_prototypes_kmeans(token)
                if success:
                    store = self.prototype_stores.get(token)
                    if store and store.size() >= 2:
                        clean_token = normalize_token_key(token)
                        if clean_token:
                            self.discovered_homographs.add(clean_token)
                            discovered_in_run.append(clean_token)
            except Exception:
                continue
        
        try:
            self.discovered_log.append({
                "timestamp": time.time(),
                "global_step": global_step,
                "candidates_processed": len(token_names),
                "discovered_count": len(discovered_in_run),
                "homographs": discovered_in_run,
                "total_discovered": len(self.discovered_homographs),
            })
        except Exception:
            pass
        
        return len(discovered_in_run)
    
    def discover_homographs(self, n_min: Optional[int] = None, dispersion_threshold: Optional[float] = None, min_cluster_size: int = 5, progress: bool = False, max_candidates: int = 500) -> int:
        if n_min is None:
            n_min = self.n_min
        if dispersion_threshold is None:
            dispersion_threshold = self.dispersion_threshold
        
        n_min = max(1, int(n_min))
        min_cluster_size = max(1, int(min_cluster_size))
        max_candidates = max(1, int(max_candidates))
        
        buffer_snapshot = {}
        with self.buffer_lock:
            for token, buffer in list(self.buffers_raw.items()):
                buffer_snapshot[token] = len(buffer)
        
        candidates = []
        for token, buffer_size in buffer_snapshot.items():
            if buffer_size >= n_min:
                dispersion = self.get_dispersion(token)
                if dispersion >= dispersion_threshold:
                    rank_score = dispersion * buffer_size
                    candidates.append((token, rank_score, buffer_size, dispersion))
        
        if not candidates:
            return 0
        
        candidates.sort(key=lambda x: x[1], reverse=True)
        candidates = candidates[:max_candidates]
        
        discovered = 0
        for token, score, buf_size, disp in candidates:
            try:
                success = self.cluster_buffer_to_prototypes_kmeans(token)
                if success:
                    store = self.prototype_stores.get(token)
                    if store and store.size() >= 2:
                        clean_token = normalize_token_key(token)
                        if clean_token:
                            self.discovered_homographs.add(clean_token)
                            discovered += 1
            except Exception:
                continue
        
        return discovered
    
    def periodic_discovery_check(self, global_step: int, discovery_frequency: int = 200, max_tokens_per_discovery: int = 150) -> int:
        """Pick the most dispersed buffered tokens and run homograph discovery on them.

        Candidates are tokens whose raw buffer holds at least ``self.n_min``
        entries and whose ``get_dispersion`` is at least
        ``self.dispersion_threshold``; they are ranked by
        ``dispersion * buffer size`` and truncated to
        ``max_tokens_per_discovery``.

        If no previously submitted future is still running, the work is
        submitted to ``self.cluster_executor`` and 0 is returned immediately.
        Otherwise (or when anything in the submission path raises) discovery
        runs synchronously in the calling thread.

        NOTE(review): ``discovery_frequency`` is not read anywhere in this
        method; presumably callers decide how often to invoke it — confirm.

        Returns:
            0 when there are no candidates, when the work was handed to the
            executor, or on failure; otherwise the number of tokens recorded
            by the synchronous run.
        """
        self.last_periodic_check = global_step
        max_tokens_per_discovery = max(1, int(max_tokens_per_discovery))
        
        # Snapshot buffer sizes under the lock; dispersion is evaluated after release.
        buffer_snapshot = {}
        with self.buffer_lock:
            for token, buffer in list(self.buffers_raw.items()):
                buffer_snapshot[token] = len(buffer)
        
        candidates = []
        for token, buffer_size in buffer_snapshot.items():
            if buffer_size >= self.n_min:
                dispersion = self.get_dispersion(token)
                if dispersion >= self.dispersion_threshold:
                    rank_score = dispersion * buffer_size
                    candidates.append((token, rank_score, buffer_size, dispersion))
        
        if not candidates:
            return 0
        
        # Highest score first, then keep only the token names.
        candidates.sort(key=lambda x: x[1], reverse=True)
        token_names = [c[0] for c in candidates[:max_tokens_per_discovery]]
        
        # Preferred path: background execution. Any exception raised in this
        # section is swallowed and control falls through to the synchronous run.
        try:
            with self.thread_lock:
                running = any(isinstance(f, Future) and not f.done() for f in list(self.active_threads))
            
            if hasattr(self, "cluster_executor") and not running:
                try:
                    fut = self.cluster_executor.submit(
                        self.discover_homographs_for_tokens,
                        token_names,
                        self.n_min,
                        self.dispersion_threshold,
                        global_step
                    )
                    with self.thread_lock:
                        self.active_threads.append(fut)
                    # Result is not awaited; discoveries appear later via the worker.
                    return 0
                except Exception:
                    pass
        except Exception:
            pass
        
        # Fallback path: run discovery in the calling thread.
        try:
            discovered = self.discover_homographs_for_tokens(
                token_names,
                self.n_min,
                self.dispersion_threshold,
                global_step
            )
            return discovered
        except Exception:
            return 0
    
    def get_dispersion(self, token_type: str) -> float:
        try:
            with self.dispersion_lock:
                last_update = self.dispersion_last_updated.get(token_type, 0.0)
                if token_type in self.dispersion_cache and (time.time() - last_update) < 3600:
                    return self.dispersion_cache[token_type]
        except Exception:
            pass
        
        with self.buffer_lock:
            if token_type not in self.buffers_raw or len(self.buffers_raw[token_type]) < 2:
                return 0.0
            
            try:
                embeddings = []
                for emb in self.buffers_raw[token_type]:
                    try:
                        if isinstance(emb, torch.Tensor):
                            embeddings.append(emb.detach().cpu().numpy())
                        else:
                            embeddings.append(np.asarray(emb, dtype=np.float32))
                    except Exception:
                        continue
                
                if len(embeddings) < 2:
                    return 0.0
                
                arr = np.stack(embeddings, axis=0)
            except Exception:
                return 0.0
        
        try:
            if not np.isfinite(arr).all():
                return 0.0
            arr = np.clip(arr, -100.0, 100.0)
            
            centroid = arr.mean(axis=0)
            distances = np.linalg.norm(arr - centroid[None, :], axis=1)
            distances = np.clip(distances, 0.0, 100.0)
            
            if distances.size < 2:
                return 0.0
            
            dispersion = float(np.std(distances))
            dispersion = max(0.0, min(10.0, dispersion))
            
            with self.dispersion_lock:
                self.dispersion_cache[token_type] = dispersion
                self.dispersion_last_updated[token_type] = time.time()
            
            return dispersion
        except Exception:
            return 0.0
    
    def validate_prototypes(self, homograph_list: Optional[List[str]] = None, cluster_missing: bool = False) -> Dict[str, Any]:
        if homograph_list is None:
            homograph_list = list(_HOMOGRAPH_REFERENCE_LIST_BN)
        
        validation_results: Dict[str, Any] = {
            "total_tokens": len(self.prototype_stores),
            "total_prototypes": 0,
            "multisense_tokens": 0,
            "homographs_found": 0,
            "homographs_missing": [],
            "avg_prototypes_per_token": 0.0,
            "avg_samples_per_prototype": 0.0,
            "quality_score": 0.0,
        }
        
        corrupted_stores = 0
        total_samples = 0
        corrupted = 0
        for token, store in self.prototype_stores.items():
            try:
                num_protos = len(getattr(store, "centroids", []))
                is_valid = True
                for c in getattr(store, "centroids", []):
                    if not store.validate_centroid(c):
                        is_valid = False
                        break
                if not is_valid:
                    corrupted += 1
                    continue
                
                validation_results["total_prototypes"] += num_protos
                if self.is_multisense_store(store):
                    validation_results["multisense_tokens"] += 1
                
                try:
                    total_samples += sum(getattr(store, "counts", []) or [])
                except Exception:
                    pass
            except Exception:
                corrupted += 1
                continue
        
        validation_results["corrupted_stores"] = corrupted
        
        if validation_results["total_tokens"] > 0:
            validation_results["avg_prototypes_per_token"] = validation_results["total_prototypes"] / validation_results["total_tokens"]
        
        if validation_results["total_prototypes"] > 0:
            validation_results["avg_samples_per_prototype"] = total_samples / validation_results["total_prototypes"]
        
        found = 0
        missing = []
        for homograph in homograph_list:
            clean_h = normalize_token_key(homograph)
            if not clean_h:
                continue
            
            found_key = None
            if homograph in self.prototype_stores:
                found_key = homograph
            elif clean_h in self.prototype_stores:
                found_key = clean_h
            else:
                for key in self.prototype_stores.keys():
                    c_k = normalize_token_key(key)
                    if c_k and (c_k == clean_h or clean_h in c_k or c_k in clean_h):
                        found_key = key
                        break
            
            if found_key:
                store = self.prototype_stores.get(found_key)
                found_protos = store.size() if store else 0
                if store and self.is_multisense_store(store):
                    found += 1
                elif found_protos >= 1:
                    missing.append(homograph)
                else:
                    missing.append(homograph)
            else:
                missing.append(homograph)
        
        validation_results["homographs_found"] = found
        validation_results["homographs_missing"] = missing
        
        homograph_coverage = found / len(homograph_list) if homograph_list else 0.0
        multisense_ratio = validation_results["multisense_tokens"] / validation_results["total_tokens"] if validation_results["total_tokens"] > 0 else 0.0
        validation_results["quality_score"] = homograph_coverage * 0.6 + multisense_ratio * 0.4
        
        return validation_results
    
    def should_track_token(self, token_text: str) -> bool:
        if not token_text or not isinstance(token_text, str):
            return False
        
        if len(self.dscd_allowed_tokens) > self.dscd_cache_max_size:
            self.dscd_allowed_tokens.clear()
        if len(self.dscd_ignored_tokens) > self.dscd_cache_max_size:
            self.dscd_ignored_tokens.clear()
        
        if token_text in self.dscd_allowed_tokens:
            return True
        if token_text in self.dscd_ignored_tokens:
            return False
        
        if token_text in self.special_tokens:
            self.dscd_ignored_tokens.add(token_text)
            return False
        
        clean = normalize_token_key(token_text)
        if not clean:
            self.dscd_ignored_tokens.add(token_text)
            return False
        
        self.dscd_allowed_tokens.add(token_text)
        return True
    
    def canonical_token_key(self, raw_token: str, token_word_map: Optional[Dict[int, Optional[str]]], idx: int) -> Optional[str]:
        try:
            if token_word_map and isinstance(token_word_map, dict) and idx in token_word_map and token_word_map[idx]:
                return str(token_word_map[idx]).strip()
        except Exception:
            pass
        return str(raw_token).strip() if raw_token is not None else None
    
    def cleanup_threads(self) -> None:
        try:
            with self.thread_lock:
                kept = []
                for f in list(self.active_threads):
                    try:
                        if isinstance(f, Future):
                            if not f.done():
                                kept.append(f)
                        else:
                            kept.append(f)
                    except Exception:
                        kept.append(f)
                self.active_threads.clear()
                self.active_threads.extend(kept)
        except Exception:
            pass
    
    def cleanup_memory(self) -> None:
        try:
            for token_type, buffer in list(self.buffers.items()):
                if len(buffer) > int(self.buffer_size * 1.5):
                    while len(buffer) > self.buffer_size:
                        buffer.popleft()
            
            for token_type, buffer in list(self.buffers_raw.items()):
                if len(buffer) > int(self.buffer_size * 1.5):
                    while len(buffer) > self.buffer_size:
                        buffer.popleft()
            
            now = time.time()
            expired = [k for k, v in self.dispersion_last_updated.items() if now - v > 3600]
            for k in expired:
                self.dispersion_cache.pop(k, None)
                self.dispersion_last_updated.pop(k, None)
            
            if gc.isenabled():
                gc.collect()
        except Exception:
            pass
    
    def forward(
        self,
        token_embeddings: Optional[torch.Tensor],
        token_types: Optional[List[List[str]]] = None,
        train_mode: bool = True,
        token_word_map: Optional[List[Dict[int, Optional[str]]]] = None,
        h_all: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs
    ) -> Dict[str, Any]:
        if token_embeddings is None and h_all is not None:
            token_embeddings = h_all
        
        if token_embeddings is None:
            raise ValueError("MemoryEfficientDSCDOnline.forward requires token_embeddings or h_all")
        
        device = token_embeddings.device
        batch_size = int(token_embeddings.size(0))
        
        try:
            word_embeddings, word_maps_from_reconstruction = reconstruct_word_embeddings(
                token_embeddings, input_ids, self.tokenizer, device
            )
        except Exception:
            word_embeddings = token_embeddings.to(device)
            word_maps_from_reconstruction = [{} for _ in range(batch_size)]
        
        if input_ids is not None and token_types is None:
            token_types = []
            for b in range(input_ids.size(0)):
                try:
                    token_types.append(self.tokenizer.convert_ids_to_tokens(input_ids[b].tolist()))
                except Exception:
                    token_types.append([f"tok{i}" for i in range(input_ids.size(1))])
        
        if token_types is None:
            word_seq_len = int(word_embeddings.size(1))
            token_types = [[f"tok{i}" for i in range(word_seq_len)] for _ in range(batch_size)]
        
        self.cleanup_counter += 1
        if self.cleanup_counter % 50 == 0:
            self.cleanup_counter = 0
            self.cleanup_memory()
            self.cleanup_threads()
        
        all_outputs: Dict[str, List[Any]] = {
            "proto_assignments": [],
            "proto_probs": [],
            "uncertainties": [],
            "span_preds": [],
            "gates": [],
            "h_augmented": [],
        }
        
        word_seq_len = int(word_embeddings.size(1))
        for b in range(batch_size):
            word_map_b = word_maps_from_reconstruction[b] if b < len(word_maps_from_reconstruction) else {}
            token_types_b = token_types[b] if token_types and len(token_types) > b else [f"tok{i}" for i in range(word_seq_len)]
            
            batch_outputs = self.process_sequence(
                word_embeddings[b],
                token_types_b,
                device,
                word_map=word_map_b,
                train_mode=train_mode
            )
            
            for k in all_outputs:
                all_outputs[k].append(batch_outputs[k])
        
        try:
            h_aug_list = []
            max_seq_len = word_seq_len
            for b in range(batch_size):
                h_batch_list = all_outputs["h_augmented"][b]
                if isinstance(h_batch_list, list) and len(h_batch_list) > 0 and isinstance(h_batch_list[0], torch.Tensor):
                    try:
                        h_batch = torch.stack(h_batch_list, dim=0)
                        if h_batch.size(0) < max_seq_len:
                            pad = max_seq_len - h_batch.size(0)
                            h_batch = F.pad(h_batch, (0, 0, 0, pad), value=0)
                        elif h_batch.size(0) > max_seq_len:
                            h_batch = h_batch[:max_seq_len]
                    except Exception:
                        h_batch = torch.zeros(max_seq_len, self.embed_dim, device=device)
                else:
                    h_batch = torch.zeros(max_seq_len, self.embed_dim, device=device)
                h_aug_list.append(h_batch)
            
            all_outputs["h_augmented"] = torch.stack(h_aug_list, dim=0)
        except Exception:
            all_outputs["h_augmented"] = word_embeddings
        
        try:
            proto_assign_tensor = []
            for row in all_outputs["proto_assignments"]:
                try:
                    arr = [int(x.item()) if isinstance(x, torch.Tensor) else int(x) for x in row]
                    proto_assign_tensor.append(torch.tensor(arr, dtype=torch.long))
                except Exception:
                    proto_assign_tensor.append(torch.full((word_seq_len,), -1, dtype=torch.long))
            all_outputs["proto_assignments"] = proto_assign_tensor
        except Exception:
            pass
        
        return all_outputs
    
    def process_sequence(
        self,
        token_embeddings: torch.Tensor,
        token_types: List[Any],
        device: torch.device,
        word_map: Optional[Dict[int, Optional[str]]] = None,
        train_mode: bool = True
    ) -> Dict[str, List[Any]]:
        """Score every position of one sequence against its token's prototypes.

        For each row ``j`` of ``token_embeddings`` the token label is turned
        into a key (``▁`` marker and surrounding whitespace removed), the
        embedding is appended to that key's buffers, a background
        re-clustering job may be submitted, and the embedding is compared with
        the centroids currently held by the key's prototype store.

        Args:
            token_embeddings: Tensor whose dim 0 indexes sequence positions.
            token_types: One label per position; non-strings are converted with
                ``str`` and missing positions fall back to ``"tok{j}"``.
            device: Device on which centroids and the query are compared.
            word_map: Accepted for interface compatibility; not read here.
            train_mode: Clustering jobs are only submitted when this is true.

        Returns:
            Dict with six parallel lists holding exactly one entry per position:
            ``proto_assignments`` (0-dim long tensor, ``-1`` when no prototype
            was used), ``proto_probs`` (1-D float32 tensor), ``uncertainties``,
            ``span_preds`` and ``gates`` (floats) and ``h_augmented`` (the
            embedding, plus the assigned centroid when augmentation applies).
            Positions that are untracked, non-finite, have no centroids, or
            fail during scoring receive the neutral record
            ``(-1, [1.0], 0.5, 0.0, 0.0, h_j)``.
        """
        seq_len = int(token_embeddings.size(0))
        outputs: Dict[str, List[Any]] = {
            "proto_assignments": [],
            "proto_probs": [],
            "uncertainties": [],
            "span_preds": [],
            "gates": [],
            "h_augmented": [],
        }

        w_entropy = 0.6
        w_margin = 0.25
        w_d1 = 0.15
        softmax_temp = 0.7
        eps = 1e-9

        def _emit(assignment, probs, uncertainty, span, gate, h_out):
            # All six lists are extended together so they can never drift apart.
            outputs["proto_assignments"].append(assignment)
            outputs["proto_probs"].append(probs)
            outputs["uncertainties"].append(uncertainty)
            outputs["span_preds"].append(span)
            outputs["gates"].append(gate)
            outputs["h_augmented"].append(h_out)

        def _emit_neutral(h_out):
            # Record used whenever a position cannot be scored against prototypes.
            _emit(
                torch.tensor(-1, dtype=torch.long),
                torch.tensor([1.0], dtype=torch.float32),
                0.5,
                0.0,
                0.0,
                h_out,
            )

        for j in range(seq_len):
            raw_tok = token_types[j] if j < len(token_types) else f"tok{j}"
            if not isinstance(raw_tok, str):
                raw_tok = str(raw_tok)
            # Fall back to the raw label when stripping leaves an empty key.
            token_key = raw_tok.replace("▁", "").strip() or raw_tok
            h_j = token_embeddings[j]

            if not self.should_track_token(token_key):
                _emit_neutral(h_j)
                continue

            try:
                h_j_clamped = torch.clamp(h_j, min=-100.0, max=100.0)
                q = F.normalize(h_j_clamped, p=2, dim=-1, eps=1e-9)
                if not torch.isfinite(q).all():
                    raise RuntimeError("non-finite query")
            except Exception:
                _emit_neutral(h_j)
                continue

            h_raw = h_j.detach().cpu().clone()

            with self.buffer_lock:
                if token_key not in self.buffers:
                    self.buffers[token_key] = deque(maxlen=self.buffer_size)
                    self.buffers_raw[token_key] = deque(maxlen=self.buffer_size)
                    self.prototype_stores[token_key] = MemoryEfficientPrototypeStore(self.embed_dim, max_protos=self.max_protos)
                    self.token_addition_counts[token_key] = 0

                try:
                    self.buffers[token_key].append(q.detach().cpu().clone())
                    self.buffers_raw[token_key].append(h_raw)
                    self.token_addition_counts[token_key] += 1
                except Exception:
                    # Buffering is best-effort; scoring below still proceeds.
                    pass

            clustering_due = (
                self.enable_online_clustering
                and train_mode
                and self.token_addition_counts[token_key] % self.online_clustering_frequency == 0
                and len(self.buffers_raw[token_key]) >= self.n_min
            )
            if clustering_due:
                try:
                    fut = self.cluster_executor.submit(self.cluster_buffer_to_prototypes_kmeans, token_key)
                    with self.thread_lock:
                        self.active_threads.append(fut)
                except Exception:
                    # A failed submission must not interrupt the forward pass.
                    pass

            store = self.prototype_stores[token_key]
            centroids_snapshot: List[torch.Tensor] = []

            with self.buffer_lock:
                for c in getattr(store, "centroids", []):
                    try:
                        if isinstance(c, torch.Tensor):
                            c_n = c.detach().to(device=device, dtype=torch.float32)
                        else:
                            c_n = torch.tensor(np.asarray(c, dtype=np.float32), device=device, dtype=torch.float32)

                        if c_n.numel() != self.embed_dim:
                            continue
                        if not torch.isfinite(c_n).all():
                            continue

                        c_n = torch.clamp(c_n, min=-100.0, max=100.0)
                        # Near-zero centroids are kept as-is instead of being normalised.
                        c_n = F.normalize(c_n, p=2, dim=-1, eps=1e-9) if c_n.norm().item() > 1e-6 else c_n
                        centroids_snapshot.append(c_n)
                    except Exception:
                        continue

            if not centroids_snapshot:
                _emit_neutral(h_j)
                continue

            try:
                centroids_stacked = torch.stack(centroids_snapshot, dim=0)

                if not torch.isfinite(centroids_stacked).all():
                    raise ValueError("centroids contain NaN/Inf")

                if self.use_cosine_distance:
                    # q is already the L2-normalised clamped embedding.
                    q_norm = q.to(device).unsqueeze(0)
                    cent_norm = F.normalize(centroids_stacked, p=2, dim=-1, eps=1e-9)
                    sims = (q_norm @ cent_norm.t()).squeeze(0).detach().cpu().numpy().astype(np.float32)
                    sims = np.clip(sims, -1.0, 1.0)
                    dists = np.clip(1.0 - sims, 0.0, 2.0)
                else:
                    q_vec = h_j_clamped.to(device).unsqueeze(0)
                    dists = torch.norm(centroids_stacked - q_vec, dim=1).detach().cpu().numpy().astype(np.float32)
                    dists = np.clip(dists, 0.0, 100.0)
                    sims = -dists

                # Temperature softmax over similarities, shifted by the max for stability.
                logits = sims / max(eps, softmax_temp)
                logits = np.clip(logits, -50.0, 50.0)
                logits = logits - np.max(logits)
                exp_logits = np.exp(logits)
                probs = exp_logits / (exp_logits.sum() + eps)

                probs = np.clip(probs, eps, 1.0)
                probs = probs / (probs.sum() + eps)
                p_max = float(np.max(probs))

                if probs.size > 1:
                    entropy_raw = -float(np.sum(probs * np.log(probs + eps)))
                    H_norm = entropy_raw / float(max(eps, np.log(probs.size)))
                    H_norm = float(np.clip(H_norm, 0.0, 1.0))
                else:
                    H_norm = 0.0

                sorted_idx = np.argsort(probs)[::-1]
                top0 = probs[sorted_idx[0]]
                top1 = probs[sorted_idx[1]] if probs.size > 1 else 0.0
                span_raw = float(np.clip(top0 - top1, 0.0, 1.0))

                d1 = float(np.min(dists))
                d1 = max(0.0, min(10.0, d1))
                d1_norm = d1 / (1.0 + abs(d1))

                uncertainty_raw = w_entropy * H_norm + w_margin * (1.0 - span_raw) + w_d1 * min(1.0, d1_norm)
                uncertainty = float(np.clip(uncertainty_raw, 0.0, 1.0))

                gate = float(np.clip(p_max * (1.0 - H_norm), 0.0, 1.0))

                assignment = int(sorted_idx[0])
                prob_tensor = torch.from_numpy(probs.astype(np.float32))

                h_aug = h_j
                if self.apply_augmentation and assignment >= 0 and p_max > 0.4:
                    try:
                        centroid_t = centroids_snapshot[assignment]
                        if centroid_t is not None and isinstance(centroid_t, torch.Tensor):
                            h_aug = h_j + centroid_t.to(h_j.device).to(dtype=h_j.dtype)
                            h_aug = torch.clamp(h_aug, min=-100.0, max=100.0)
                    except Exception:
                        h_aug = h_j

                record = (
                    torch.tensor(assignment, dtype=torch.long),
                    prob_tensor,
                    float(uncertainty),
                    float(span_raw),
                    float(gate),
                    h_aug,
                )
            except Exception:
                # Nothing has been appended for this position yet, so the
                # neutral record is the only one it receives.
                _emit_neutral(h_j)
                continue

            _emit(*record)

            try:
                store.update_rolling_stats(d1)
            except Exception:
                # Rolling statistics are advisory; ignore stores that cannot update.
                pass

        return outputs
    
    def print_clusters_summary(self) -> None:
        """Print one line totalling tracked tokens, samples, prototypes and buffers.

        Any error while gathering the numbers is swallowed, in which case
        nothing is printed.
        """
        try:
            rows = []
            for tok, proto_store in self.prototype_stores.items():
                # A store without usable counts contributes zero samples.
                try:
                    n_samples = sum(getattr(proto_store, "counts", []) or [])
                except Exception:
                    n_samples = 0

                if tok in self.buffers_raw:
                    n_buffered = len(self.buffers_raw.get(tok, []))
                else:
                    n_buffered = 0

                rows.append((
                    tok,
                    n_samples,
                    proto_store.size(),
                    getattr(proto_store, "mu", 0.0),
                    getattr(proto_store, "tau", 0.0),
                    n_buffered,
                ))

            # Largest sample counts first.
            rows = sorted(rows, key=lambda row: row[1], reverse=True)

            total_samples = sum(row[1] for row in rows)
            total_protos = sum(row[2] for row in rows)
            total_buffers = sum(row[5] for row in rows)

            print(f"Total: {len(rows)} clusters, {total_samples} samples, {total_protos} protos, {total_buffers} buffers")
        except Exception:
            pass
    
    def cluster_buffer_to_prototypes_kmeans(self, token_type: str) -> bool:
        """Rebuild the prototypes of ``token_type`` from its raw embedding buffer.

        The raw buffer is snapshotted under ``buffer_lock``, filtered to finite
        vectors of ``embed_dim`` elements with non-negligible norm, optionally
        subsampled to ``max_clustering_points`` and, in cosine mode, scaled to
        unit length. KMeans is tried first (when ``HAS_KMEANS`` is set); if it
        yields no prototype, average-linkage hierarchical clustering is tried
        (when ``HAS_CLUSTERING`` is set). Only clusters with at least ``n_min``
        members become prototypes, and at most ``max_protos`` are kept,
        largest first.

        Args:
            token_type: Key into ``buffers_raw`` / ``prototype_stores``.

        Returns:
            ``True`` when the token's store holds at least one prototype after
            the attempt. ``False`` when the token is untracked, has too little
            usable data, or any unexpected error occurs (errors are swallowed
            because this runs as a background job).
        """
        try:
            if not self.should_track_token(token_type):
                return False

            with self.buffer_lock:
                if token_type not in self.buffers_raw:
                    return False
                # Copy so clustering can run without holding the lock.
                buf_snapshot = [
                    e.clone() if isinstance(e, torch.Tensor)
                    else torch.tensor(np.asarray(e), dtype=torch.float32)
                    for e in self.buffers_raw[token_type]
                ]

            if len(buf_snapshot) < self.n_min:
                return False

            # Every snapshot entry is a tensor at this point.
            emb_list = []
            for e in buf_snapshot:
                try:
                    arr = e.detach().cpu().numpy()
                    if arr.size == self.embed_dim and np.isfinite(arr).all():
                        emb_list.append(np.clip(arr, -100.0, 100.0))
                except Exception:
                    continue

            if not emb_list:
                return False

            if len(emb_list) > self.max_clustering_points:
                idxs = np.random.choice(len(emb_list), size=self.max_clustering_points, replace=False)
                new_embeddings = np.stack([emb_list[i] for i in idxs], axis=0)
            else:
                new_embeddings = np.stack(emb_list, axis=0)

            if new_embeddings.shape[0] < 2:
                return False

            norms = np.linalg.norm(new_embeddings, axis=1)
            norms = np.clip(norms, 0.0, 100.0)
            valid_mask = norms > 1e-6
            if not np.any(valid_mask):
                return False

            new_embeddings = new_embeddings[valid_mask]
            norms = norms[valid_mask]

            if self.use_cosine_distance:
                points = new_embeddings / (norms[:, None] + 1e-9)
            else:
                points = new_embeddings

            store = self.prototype_stores[token_type]
            n_points = points.shape[0]
            protos_added = 0

            def _summarize_clusters(labels, n_clusters):
                # Turn label ids 0..n_clusters-1 into (centroids, counts, times),
                # keeping only clusters with at least n_min members.
                centroids, counts, times = [], [], []
                for c_idx in range(n_clusters):
                    mask = labels == c_idx
                    cluster_size = int(mask.sum())
                    if cluster_size < self.n_min:
                        continue
                    centroid = points[mask].mean(axis=0).astype(np.float32)
                    if not np.isfinite(centroid).all():
                        continue
                    centroid = np.clip(centroid, -100.0, 100.0)

                    c_vec = torch.from_numpy(centroid)
                    if c_vec.numel() != self.embed_dim:
                        continue
                    if self.use_cosine_distance:
                        c_vec = F.normalize(c_vec, p=2, dim=0, eps=1e-9)

                    centroids.append(c_vec)
                    counts.append(cluster_size)
                    times.append(time.time())
                return centroids, counts, times

            def _install_prototypes(centroids, counts, times, labels):
                # Trim to max_protos (largest clusters first) and publish to the store.
                if len(centroids) > self.max_protos:
                    order = np.argsort(counts)[::-1][:self.max_protos]
                    centroids = [centroids[i] for i in order]
                    counts = [counts[i] for i in order]
                    times = [times[i] for i in order]
                try:
                    labels_t = torch.tensor(labels)
                except Exception:
                    labels_t = None
                # process_sequence snapshots store.centroids under buffer_lock,
                # so the related fields are swapped in under the same lock.
                with self.buffer_lock:
                    store.centroids = centroids
                    store.counts = counts
                    store.creation_time = times
                    store.labels = labels_t
                return len(centroids)

            if HAS_KMEANS and n_points >= 2:
                try:
                    min_k = 1
                    max_k = min(self.max_protos, n_points, max(1, self.n_min))
                    k_guess = max(
                        min_k,
                        min(max_k, int(max(1, round(np.sqrt(max(1, n_points / 2))))))
                    )

                    if k_guess >= 2 and n_points >= k_guess:
                        unique_rows = np.unique(points, axis=0)
                        if unique_rows.shape[0] < 2:
                            return False

                        km = KMeans(n_clusters=k_guess, random_state=0, n_init=10).fit(points)
                        labels = km.labels_

                        centroids, counts, times = _summarize_clusters(labels, k_guess)
                        if centroids:
                            protos_added = _install_prototypes(centroids, counts, times, labels)
                except Exception:
                    # Fall through to the hierarchical attempt below.
                    protos_added = 0

            if protos_added == 0 and HAS_CLUSTERING and n_points >= 2:
                try:
                    condensed = pdist(points, metric='euclidean')
                    condensed = np.clip(condensed, 0.0, 100.0)

                    if condensed.size > 0:
                        Z = linkage(condensed, method='average')
                        max_dist = float(np.clip(np.max(condensed), 0.0, 100.0))
                        # Cut threshold is a fraction of the largest pairwise distance.
                        absolute_threshold = float(self.dispersion_threshold * max_dist)

                        # fcluster labels start at 1; shift them to start at 0.
                        clusters = fcluster(Z, t=absolute_threshold, criterion='distance') - 1

                        if clusters.size > 0:
                            centroids, counts, times = _summarize_clusters(clusters, int(clusters.max()) + 1)
                            if centroids:
                                protos_added = _install_prototypes(centroids, counts, times, clusters)
                except Exception:
                    protos_added = 0

            if protos_added > 0:
                try:
                    counts = store.counts if store.counts else [1] * len(store.centroids)
                    total_count = int(sum(counts))
                    mean_count = float(total_count / max(1, len(counts)))

                    self.cluster_stats[str(token_type)] = {
                        "num_prototypes": len(store.centroids),
                        "counts": [int(c) for c in counts],
                        "total_samples": total_count,
                        "mean_count": mean_count,
                        "mu": float(store.mu),
                        "tau": float(store.tau),
                    }
                except Exception:
                    # Statistics are informational only.
                    pass

            return store.size() > 0
        except Exception:
            return False
    
    def get_explanations(self, threshold_span: float = 0.3) -> List[Dict[str, Any]]:
        """List the tokens whose prototype store holds two or more prototypes.

        Args:
            threshold_span: Accepted for interface compatibility; not read here.

        Returns:
            One ``{"token": <str>, "protos": <store size>}`` dict per qualifying
            token, in ``prototype_stores`` iteration order.
        """
        return [
            {"token": str(tok), "protos": proto_store.size()}
            for tok, proto_store in self.prototype_stores.items()
            if proto_store.size() >= 2
        ]


# Cell-completion banner: rule, message, rule, each on its own line.
print(
    "=" * 80,
    "CELL3: DSCD MODULE - NaN/Inf HARDENED - COMPLETE",
    "=" * 80,
    sep="\n",
)


In [None]:
# ==============================================================================
# CELL 4: ASBN MODULE - NaN/Inf GRADIENT HARDENED
# ==============================================================================
import traceback
from typing import Any, List, Tuple, Optional, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import threading

def _setting_or_default(name, convert, default):
    """Return ``convert(<global name>)``, or ``default`` if it is missing or unconvertible."""
    try:
        return convert(globals()[name])
    except Exception:
        return default


def _load_grl_schedule():
    """Return (start, end, schedule, steps) for the GRL alpha schedule.

    If start, end or schedule cannot be read, all four values fall back to
    their defaults; the step count falls back on its own.
    """
    scope = globals()
    try:
        start = float(scope["GRL_ALPHA_START"])
        end = float(scope["GRL_ALPHA_END"])
        schedule = str(scope["GRL_ALPHA_SCHEDULE"])
    except Exception:
        return 0.01, 1.0, "linear", 10000

    try:
        steps = int(scope["GRL_ALPHA_STEPS"])
        if steps <= 0:
            steps = 10000
    except Exception:
        steps = 10000
    return start, end, schedule, steps


# Local copies of the notebook-wide settings, each with a default for when the
# configuration cell has not been run.
_MAX_LENGTH = _setting_or_default(
    "MAX_LENGTH", lambda value: int(value) if int(value) > 0 else 48, 48
)
_ENABLE_ASBN_TRAINING = _setting_or_default("ENABLE_ASBN_TRAINING", bool, True)
_VERBOSE_LOGGING = _setting_or_default("VERBOSE_LOGGING", bool, False)
_DEBUG_DISCOVERY = _setting_or_default("DEBUG_DISCOVERY", bool, False)
_DEBUG_TIMING = _setting_or_default("DEBUG_TIMING", bool, False)
_SOURCE_LANGUAGE = _setting_or_default("SOURCE_LANGUAGE", str, "bn")

(
    _GRL_ALPHA_START,
    _GRL_ALPHA_END,
    _GRL_ALPHA_SCHEDULE,
    _GRL_ALPHA_STEPS,
) = _load_grl_schedule()

# The loaders are one-shot; drop them so the namespace only gains the settings.
del _setting_or_default, _load_grl_schedule

_has_get_tokenizer_special_tokens = "get_tokenizer_special_tokens" in globals()


class GradientReversalFunction(torch.autograd.Function):
    """Identity in the forward pass; negated, scaled and clipped gradient in the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged and remember ``alpha`` limited to [0.01, 10.0]."""
        ctx.alpha = max(0.01, min(10.0, float(alpha)))
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-alpha * grad_output`` clamped to [-10, 10], plus ``None`` for ``alpha``.

        A gradient containing NaN/Inf is replaced by zeros and reported on stdout.
        """
        if not torch.isfinite(grad_output).all():
            print("❌ GRL: Received NaN/Inf gradient input")
            return torch.zeros_like(grad_output), None

        # A finite gradient times a bounded alpha can at worst overflow to
        # +/-inf, which the clamp maps back to +/-10, so the result is always
        # finite and needs no second finiteness check.
        reversed_grad = torch.clamp(-ctx.alpha * grad_output, min=-10.0, max=10.0)
        return reversed_grad, None


def gradient_reversal(x, alpha: float = 1.0):
    """Apply GradientReversalFunction to ``x`` with ``alpha`` limited to [0.01, 10.0]."""
    capped = min(10.0, float(alpha))
    strength = max(0.01, capped)
    return GradientReversalFunction.apply(x, strength)


class LightweightDiscriminator(nn.Module):
    """Small MLP (input -> 64 -> 2) that sanitises its input and bounds its logits."""

    def __init__(self, input_dim: int):
        super().__init__()
        # Guard against zero or negative feature counts.
        in_features = max(1, int(input_dim))
        layers = [
            nn.Linear(in_features, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 2),
        ]
        self.classifier = nn.Sequential(*layers)
        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform (gain 0.5) weights and zero biases for every linear layer."""
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight, gain=0.5)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return 2-way logits clamped to [-10, 10]; non-finite values become 0."""
        finite_in = torch.isfinite(x)
        if not finite_in.all():
            print(f"❌ LightweightDiscriminator: Input contains NaN/Inf")
            x = torch.where(finite_in, x, torch.zeros_like(x))

        logits = torch.clamp(self.classifier(x), min=-10.0, max=10.0)

        # clamp leaves NaN untouched, so the output is checked separately.
        finite_out = torch.isfinite(logits)
        if not finite_out.all():
            print(f"❌ LightweightDiscriminator: Output contains NaN/Inf")
            logits = torch.where(finite_out, logits, torch.zeros_like(logits))

        return logits


class DomainDiscriminator(nn.Module):
    """Three-layer MLP (input -> 512 -> 256 -> 2) that sanitises its input and bounds its logits."""

    def __init__(self, input_dim: int):
        super().__init__()
        # Guard against zero or negative feature counts.
        in_features = max(1, int(input_dim))
        layers = [
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(256, 2),
        ]
        self.classifier = nn.Sequential(*layers)
        self._initialize_weights()

    def _initialize_weights(self):
        """Xavier-uniform (gain 0.5) weights and zero biases for every linear layer."""
        linear_layers = (mod for mod in self.modules() if isinstance(mod, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_uniform_(layer.weight, gain=0.5)
            if layer.bias is not None:
                nn.init.constant_(layer.bias, 0.0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return 2-way logits clamped to [-10, 10]; non-finite values become 0."""
        finite_in = torch.isfinite(x)
        if not finite_in.all():
            print(f"❌ DomainDiscriminator: Input contains NaN/Inf")
            x = torch.where(finite_in, x, torch.zeros_like(x))

        logits = torch.clamp(self.classifier(x), min=-10.0, max=10.0)

        # clamp leaves NaN untouched, so the output is checked separately.
        finite_out = torch.isfinite(logits)
        if not finite_out.all():
            print(f"❌ DomainDiscriminator: Output contains NaN/Inf")
            logits = torch.where(finite_out, logits, torch.zeros_like(logits))

        return logits


class MemoryEfficientASBNModule(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        tokenizer=None,
        language: str = "bn",
        freq_threshold: float = 0.7,
        uncertainty_threshold: float = 0.3,
        gate_threshold: float = 0.5,
        warmup_steps: int = 1000,
        encoder_grl_scale: float = 0.5,
    ):
        """Build the per-domain batch norms, the critics and the running counters.

        Args:
            embed_dim: Width of the hidden states; coerced to an int >= 1.
            tokenizer: Optional tokenizer, used here only to collect its
                special tokens into ``self.special_tokens``.
            language: Language tag stored on the module.
            freq_threshold: Clamped to [0, 1]; stored as ``self.freq_threshold``.
            uncertainty_threshold: Clamped to [0, 1]; stored as
                ``self.uncertainty_threshold``.
            gate_threshold: Clamped to [0, 1]; stored as ``self.gate_threshold``.
            warmup_steps: Number of steps (>= 0) before ``current_step`` is
                considered past warm-up.
            encoder_grl_scale: Clamped to [0.1, 1.0] and stored.
        """
        super().__init__()
        self.language = language
        self.tokenizer = tokenizer
        self.embed_dim = max(1, int(embed_dim))

        # One BatchNorm per domain, with a larger eps than the PyTorch default.
        self.bn_source = nn.BatchNorm1d(
            self.embed_dim,
            eps=1e-3,
            momentum=0.1,
            affine=True,
            track_running_stats=True
        )
        self.bn_target = nn.BatchNorm1d(
            self.embed_dim,
            eps=1e-3,
            momentum=0.1,
            affine=True,
            track_running_stats=True
        )

        # Critics.  d_freq and d_ctx take the embedding plus two extra features.
        self.d_domain = DomainDiscriminator(self.embed_dim)
        self.d_freq = LightweightDiscriminator(self.embed_dim + 2)
        self.d_ctx = LightweightDiscriminator(self.embed_dim + 2)
        self.d_xl = LightweightDiscriminator(self.embed_dim)
        self.freq_threshold = max(0.0, min(1.0, float(freq_threshold)))
        self.uncertainty_threshold = max(0.0, min(1.0, float(uncertainty_threshold)))
        self.gate_threshold = max(0.0, min(1.0, float(gate_threshold)))
        self.warmup_steps = max(0, int(warmup_steps))
        self.current_step = 0
        # Base weight per critic, looked up by compute_lambda_scaled_tensor().
        self.lambda_base = {"freq": 0.5, "ctx": 0.3, "xl": 0.4, "domain": 0.5}
        self.lambda_max = 1.0
        self.encoder_grl_scale = max(0.1, min(1.0, float(encoder_grl_scale)))
        self.stats_reset_interval = 10000

        # Running counters, guarded by _stats_lock.
        self.correct_domain = 0
        self.correct_source = 0
        self.correct_target = 0
        self.total_samples = 0
        self.total_source = 0
        self.total_target = 0
        self.domain_loss_accumulator = 0.0
        self.asbn_loss_accumulator = 0.0
        # Re-entrant on purpose: get_detailed_stats() and reset_stats() acquire
        # this lock themselves and may be invoked by code that already holds
        # it.  A plain threading.Lock would block forever in that situation.
        self._stats_lock = threading.RLock()
        self.parse_errors = 0
        self.last_parse_error_log = 0.0

        self.stats = {
            "domain_loss": 0.0,
            "domain_accuracy": 0.0,
            "source_accuracy": 0.0,
            "target_accuracy": 0.0,
            "asbn_loss": 0.0,
            "num_updates": 0,
        }
        # Best effort: any failure while querying the tokenizer leaves the
        # special-token set empty instead of aborting construction.
        try:
            if tokenizer is not None and _has_get_tokenizer_special_tokens:
                self.special_tokens = get_tokenizer_special_tokens(tokenizer)
            elif tokenizer is not None:
                self.special_tokens = set(getattr(tokenizer, "all_special_tokens", []))
            else:
                self.special_tokens = set()
        except Exception:
            self.special_tokens = set()

        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            print("=" * 80)
            print("[ASBN-INIT] MemoryEfficientASBNModule initialized - NaN/Inf PROTECTED")
            print("=" * 80)
            print(f"  embed_dim={self.embed_dim}, warmup_steps={self.warmup_steps}")
            print(f"  encoder_grl_scale={self.encoder_grl_scale}")
            print(f"  BatchNorm eps=1e-3 (hardened for float32)")
            print(f"  GRL alpha range: [{_GRL_ALPHA_START:.3f}, {_GRL_ALPHA_END:.3f}]")
            print(f"  stats_reset_interval={self.stats_reset_interval}")
            print("=" * 80)

    def get_grl_alpha(self, global_step: Optional[int] = None) -> float:
        """Return the gradient-reversal strength for ``global_step``.

        The schedule is selected by the module-level ``_GRL_ALPHA_SCHEDULE``:

        * ``"linear"``: straight interpolation from ``_GRL_ALPHA_START`` to
          ``_GRL_ALPHA_END`` over ``_GRL_ALPHA_STEPS`` steps;
        * ``"exponential"``: geometric interpolation between the same end
          points;
        * anything else: constant ``_GRL_ALPHA_END``.

        Args:
            global_step: Step to evaluate; ``None`` means ``self.current_step``.
                Negative values are treated as 0.

        Returns:
            The scheduled value, limited to ``[0.01, 10.0]``.
        """
        if global_step is None:
            global_step = self.current_step
        step = max(0, int(global_step))
        if _GRL_ALPHA_SCHEDULE == "linear":
            progress = min(1.0, float(step) / float(max(1, _GRL_ALPHA_STEPS)))
            alpha = _GRL_ALPHA_START + progress * (_GRL_ALPHA_END - _GRL_ALPHA_START)
        elif _GRL_ALPHA_SCHEDULE == "exponential":
            progress = min(1.0, float(step) / float(max(1, _GRL_ALPHA_STEPS)))
            # A start of (almost) zero cannot be interpolated geometrically,
            # so it is replaced by a tiny positive base.  The SAME base must
            # be used as the multiplier: multiplying by the raw start value
            # would keep alpha at ~0 for the whole schedule.
            base = _GRL_ALPHA_START if abs(_GRL_ALPHA_START) > 1e-6 else 1e-6
            ratio = _GRL_ALPHA_END / base
            if ratio < 0:
                # End points of opposite sign: a negative ratio raised to a
                # fractional power is complex in Python and float() would
                # raise.  Interpolate linearly instead.
                alpha = _GRL_ALPHA_START + progress * (_GRL_ALPHA_END - _GRL_ALPHA_START)
            else:
                alpha = base * (ratio ** progress)
        else:
            alpha = _GRL_ALPHA_END
        return max(0.01, min(10.0, float(alpha)))

    def get_detailed_stats(self) -> Dict[str, Any]:
        with self._stats_lock:
            if self.total_samples == 0:
                return {
                    "domain_loss": None,
                    "domain_accuracy": None,
                    "source_accuracy": None,
                    "target_accuracy": None,
                    "asbn_loss": None,
                    "num_updates": 0,
                }
            domain_acc = (self.correct_domain / self.total_samples) if self.total_samples > 0 else 0.0
            source_acc = (self.correct_source / self.total_source) if self.total_source > 0 else 0.0
            target_acc = (self.correct_target / self.total_target) if self.total_target > 0 else 0.0
            avg_domain_loss = (self.domain_loss_accumulator / self.total_samples) if self.total_samples > 0 else 0.0
            avg_asbn_loss = (self.asbn_loss_accumulator / self.total_samples) if self.total_samples > 0 else 0.0
            return {
                "domain_loss": float(avg_domain_loss),
                "domain_accuracy": float(domain_acc),
                "source_accuracy": float(source_acc),
                "target_accuracy": float(target_acc),
                "asbn_loss": float(avg_asbn_loss),
                "num_updates": int(self.total_samples),
            }

    def get_asbn_stats(self) -> Dict[str, Any]:
        """Alias for :meth:`get_detailed_stats`; returns the same statistics dict."""
        return self.get_detailed_stats()

    def reset_stats(self) -> None:
        with self._stats_lock:
            self.correct_domain = 0
            self.correct_source = 0
            self.correct_target = 0
            self.total_samples = 0
            self.total_source = 0
            self.total_target = 0
            self.domain_loss_accumulator = 0.0
            self.asbn_loss_accumulator = 0.0
            self.stats = {k: 0.0 for k in self.stats}

    def critic_parameters(self):
        return list(self.d_domain.parameters()) + list(self.d_freq.parameters()) + list(self.d_ctx.parameters()) + list(self.d_xl.parameters())

    def _ensure_discriminators_on_device(self, device: torch.device) -> None:
        try:
            for mod in (self.d_domain, self.d_freq, self.d_ctx, self.d_xl, self.bn_source, self.bn_target):
                try:
                    mod.to(device)
                    if self.training:
                        mod.train()
                except Exception as e:
                    if _VERBOSE_LOGGING:
                        print(f"❌ ASBN: Failed to move {mod.__class__.__name__} to {device}: {e}")
        except Exception as e:
            print(f"❌ ASBN: _ensure_discriminators_on_device failed: {e}")
            if _VERBOSE_LOGGING:
                traceback.print_exc()

    def _parse_proto_probs_matrix(self, proto_probs: Any, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
        batch_size = max(1, int(batch_size))
        seq_len = max(1, int(seq_len))
        pmax = torch.full((batch_size, seq_len), 0.5, dtype=torch.float32, device=device)
        try:
            if proto_probs is None:
                return pmax

            if isinstance(proto_probs, torch.Tensor):
                p = proto_probs.detach().to(device)
                if p.dim() == 3:
                    B, T, K = p.shape
                    b_max = min(batch_size, B)
                    t_max = min(seq_len, T)
                    vals = p[:b_max, :t_max, :].max(dim=2)[0]
                    pmax[:b_max, :t_max] = torch.where(torch.isfinite(vals), vals, torch.full_like(vals, 0.5))
                    return pmax
                if p.dim() == 2:
                    if p.size(0) == batch_size and p.size(1) == seq_len:
                        pmax[:, :] = torch.where(torch.isfinite(p.float()), p.float(), pmax)
                        return pmax
                    if batch_size == 1:
                        vals = p.max(dim=1)[0]
                        t_max = min(seq_len, vals.size(0))
                        pmax[0, :t_max] = torch.where(torch.isfinite(vals[:t_max]), vals[:t_max], pmax[0, :t_max])
                        return pmax

            if isinstance(proto_probs, (list, tuple)):
                if len(proto_probs) == batch_size:
                    for b in range(batch_size):
                        row = proto_probs[b]
                        if row is None:
                            continue
                        if isinstance(row, torch.Tensor):
                            r = row.detach().to(device)
                            if r.dim() == 2:
                                t_max = min(seq_len, r.size(0))
                                vals = r[:t_max, :].max(dim=1)[0]
                                pmax[b, :t_max] = torch.where(torch.isfinite(vals), vals, pmax[b, :t_max])
                                continue
                        if isinstance(row, (list, tuple, np.ndarray)):
                            for t in range(min(seq_len, len(row))):
                                try:
                                    val = row[t]
                                    if isinstance(val, torch.Tensor):
                                        arr = val.detach().cpu().numpy().astype(np.float32).ravel()
                                    else:
                                        arr = np.asarray(val, dtype=np.float32).ravel()
                                    if arr.size == 0:
                                        continue
                                    pmax_val = float(np.nanmax(arr))
                                    if np.isfinite(pmax_val):
                                        pmax[b, t] = pmax_val
                                except Exception as e:
                                    import time
                                    now = time.time()
                                    if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and (now - self.last_parse_error_log) > 60.0:
                                        print(f"❌ ASBN: proto_probs parse error at b={b},t={t}: {type(e).__name__}")
                                        self.last_parse_error_log = now
                                    self.parse_errors += 1
                                    continue
                    return pmax

                if batch_size == 1 and len(proto_probs) >= 1:
                    row = proto_probs
                    for t in range(min(seq_len, len(row))):
                        try:
                            val = row[t]
                            if isinstance(val, torch.Tensor):
                                arr = val.detach().cpu().numpy().astype(np.float32).ravel()
                            else:
                                arr = np.asarray(val, dtype=np.float32).ravel()
                            if arr.size == 0:
                                continue
                            pmax_val = float(np.nanmax(arr))
                            if np.isfinite(pmax_val):
                                pmax[0, t] = pmax_val
                        except Exception as e:
                            import time
                            now = time.time()
                            if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and (now - self.last_parse_error_log) > 60.0:
                                print(f"❌ ASBN: proto_probs parse error at t={t}: {type(e).__name__}")
                                self.last_parse_error_log = now
                            self.parse_errors += 1
                            continue
                    return pmax
        except Exception as e:
            print(f"❌ ASBN: _parse_proto_probs_matrix exception: {type(e).__name__}")
            if _VERBOSE_LOGGING:
                traceback.print_exc()
        return pmax

    def _parse_scalar_matrix(self, mat: Any, batch_size: int, seq_len: int, device: torch.device, default: float = 0.0) -> torch.Tensor:
        batch_size = max(1, int(batch_size))
        seq_len = max(1, int(seq_len))
        out = torch.full((batch_size, seq_len), float(default), dtype=torch.float32, device=device)
        try:
            if mat is None:
                return out

            if isinstance(mat, torch.Tensor):
                m = mat.detach().to(device)
                if m.dim() == 3:
                    B, T, C = m.shape
                    b_max = min(batch_size, B)
                    t_max = min(seq_len, T)
                    try:
                        out[:b_max, :t_max] = m[:b_max, :t_max, 0].float()
                    except Exception:
                        out[:b_max, :t_max] = m[:b_max, :t_max].float()
                    return out
                if m.dim() == 2:
                    if m.size(0) == batch_size and m.size(1) >= 1:
                        t_max = min(seq_len, m.size(1))
                        out[:, :t_max] = m[:, :t_max].float()
                        return out
                    if batch_size == 1:
                        t_max = min(seq_len, m.size(0))
                        out[0, :t_max] = m[:t_max, 0].float() if m.dim() > 1 else m[:t_max].float()
                        return out
                if m.dim() == 1 and batch_size == 1:
                    t_max = min(seq_len, m.size(0))
                    out[0, :t_max] = m[:t_max].float()
                    return out

            if isinstance(mat, (list, tuple, np.ndarray)):
                if len(mat) == batch_size:
                    for b in range(batch_size):
                        row = mat[b]
                        if row is None:
                            continue
                        if isinstance(row, torch.Tensor):
                            r = row.detach().to(device).float()
                            t_max = min(seq_len, r.size(0))
                            out[b, :t_max] = r[:t_max]
                            continue
                        if isinstance(row, (list, tuple, np.ndarray)):
                            t_max = min(seq_len, len(row))
                            for t in range(t_max):
                                try:
                                    v = row[t]
                                    if isinstance(v, torch.Tensor):
                                        out[b, t] = float(v.item())
                                    else:
                                        out[b, t] = float(v)
                                except Exception as e:
                                    import time
                                    now = time.time()
                                    if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and (now - self.last_parse_error_log) > 60.0:
                                        print(f"❌ ASBN: scalar_matrix parse error at b={b},t={t}: {type(e).__name__}")
                                        self.last_parse_error_log = now
                                    self.parse_errors += 1
                                    out[b, t] = float(default)
                            continue
                if batch_size == 1 and len(mat) > 0:
                    row = mat
                    t_max = min(seq_len, len(row))
                    for t in range(t_max):
                        try:
                            v = row[t]
                            if isinstance(v, torch.Tensor):
                                out[0, t] = float(v.item())
                            else:
                                out[0, t] = float(v)
                        except Exception as e:
                            import time
                            now = time.time()
                            if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and (now - self.last_parse_error_log) > 60.0:
                                print(f"❌ ASBN: scalar_matrix parse error at t={t}: {type(e).__name__}")
                                self.last_parse_error_log = now
                            self.parse_errors += 1
                            out[0, t] = float(default)
                    return out
        except Exception as e:
            print(f"❌ ASBN: _parse_scalar_matrix exception: {type(e).__name__}")
            if _VERBOSE_LOGGING:
                traceback.print_exc()
        return out

    def compute_lambda_scaled_tensor(self, pmax: torch.Tensor, uncertainty: torch.Tensor, gate: torch.Tensor, lambda_type: str) -> torch.Tensor:
        try:
            device = pmax.device if isinstance(pmax, torch.Tensor) else (uncertainty.device if isinstance(uncertainty, torch.Tensor) else torch.device("cpu"))
            p = pmax.clone().detach().to(device) if isinstance(pmax, torch.Tensor) else torch.tensor(pmax, device=device)
            u = uncertainty.clone().detach().to(device) if isinstance(uncertainty, torch.Tensor) else torch.tensor(uncertainty, device=device)
            g = gate.clone().detach().to(device) if isinstance(gate, torch.Tensor) else torch.tensor(gate, device=device)

            p = torch.clamp(p, 0.0, 1.0)
            u = torch.clamp(u, 0.0, 1.0)
            g = torch.clamp(g, 0.0, 1.0)

            base = float(self.lambda_base.get(lambda_type, 0.2))
            lam = base * (1.0 - p + 0.05) * (u + 0.05) * (g + 0.05)
            lam = torch.clamp(lam, min=0.01, max=min(1.0, float(self.lambda_max)))
            lam = torch.where(torch.isfinite(lam), lam, torch.full_like(lam, 0.1))
            return lam
        except Exception as e:
            if _VERBOSE_LOGGING:
                print(f"❌ ASBN: compute_lambda_scaled_tensor failed: {e}")
            try:
                return torch.full_like(pmax if isinstance(pmax, torch.Tensor) else torch.tensor(pmax), 0.1)
            except Exception:
                return torch.tensor(0.1)

    def forward(self, h: torch.Tensor, domain_labels: Optional[torch.Tensor] = None, global_step: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor, float]:
        """Apply per-domain batch norm to ``h`` and, when training, score the domain critic.

        Tokens of sequences labelled 0 go through ``bn_source`` and tokens of
        sequences labelled 1 through ``bn_target``; a domain is only
        normalised when the batch holds at least two of its tokens.  Without
        labels every token is treated as domain 1.  The domain critic runs
        only in training mode, with ``_ENABLE_ASBN_TRAINING`` set, after
        warm-up, and when labels are given; its input passes through the
        gradient-reversal layer.

        Args:
            h: Hidden states of shape ``(B, T, H)``.  Any other input is
                returned untouched together with a zero loss.
            domain_labels: Optional integer labels, one per sequence; a single
                value is broadcast to the whole batch.
            global_step: When given, overwrites ``self.current_step``.

        Returns:
            ``(h_out, domain_loss, domain_accuracy)``: the normalised hidden
            states in the shape of ``h``, a scalar loss tensor limited to
            ``[0, 10]`` (zero when the critic did not run or failed) and the
            critic accuracy on this batch as a float.  If anything in the
            main body raises, the original ``h`` is returned with a zero loss.
        """
        if not isinstance(h, torch.Tensor) or h.dim() != 3:
            dev = h.device if isinstance(h, torch.Tensor) else torch.device("cpu")
            return h, torch.tensor(0.0, device=dev), 0.0

        if global_step is not None:
            self.current_step = max(0, int(global_step))

        B, T, H = h.size()
        device = h.device
        domain_accuracy = 0.0
        try:
            self._ensure_discriminators_on_device(device)
            # reshape, not view: view() raises for inputs whose memory layout
            # cannot be flattened in place (e.g. a transposed tensor), which
            # would send the whole call to the pass-through error path.
            h_flat = h.reshape(B * T, H)

            if domain_labels is not None:
                try:
                    domain_labels = domain_labels.to(device).long()
                except Exception:
                    try:
                        domain_labels = domain_labels.long()
                        domain_labels = domain_labels.to(device)
                    except Exception as e:
                        print(f"❌ ASBN: domain_labels conversion failed: {e}")
                        domain_labels = torch.ones((B,), dtype=torch.long, device=device)
                # Bring the labels to exactly one entry per sequence.
                if domain_labels.dim() == 0:
                    domain_labels = domain_labels.unsqueeze(0).expand(B)
                elif domain_labels.numel() == 1 and B > 1:
                    domain_labels = domain_labels.view(1).expand(B).contiguous()
                elif domain_labels.size(0) != B:
                    domain_labels = domain_labels[:B] if domain_labels.size(0) > B else domain_labels[0].unsqueeze(0).expand(B)

            # One label per token, flattened to match h_flat.
            domain_expanded = domain_labels.unsqueeze(1).expand(B, T).reshape(-1) if domain_labels is not None else torch.ones(B * T, dtype=torch.long, device=device)

            source_mask = domain_expanded == 0
            target_mask = domain_expanded == 1
            h_normalized = h_flat.clone()

            # A domain with fewer than two tokens is left as-is: h_normalized
            # already holds the unmodified values for those positions.
            try:
                source_count = int(source_mask.sum().item())
                if source_count >= 2:
                    src_idx = source_mask.nonzero(as_tuple=True)[0]
                    h_src = h_flat[src_idx]
                    if not torch.isfinite(h_src).all():
                        print(f"❌ ASBN: Source embeddings contain NaN/Inf before BN")
                        h_src = torch.where(torch.isfinite(h_src), h_src, torch.zeros_like(h_src))
                    h_src_bn = self.bn_source(h_src)
                    if not torch.isfinite(h_src_bn).all():
                        print(f"❌ ASBN: BN_source produced NaN/Inf")
                        h_src_bn = h_src
                    h_normalized[src_idx] = h_src_bn
            except Exception as e:
                if _VERBOSE_LOGGING:
                    print(f"❌ ASBN: bn_source failed: {type(e).__name__}")

            try:
                target_count = int(target_mask.sum().item())
                if target_count >= 2:
                    tgt_idx = target_mask.nonzero(as_tuple=True)[0]
                    h_tgt = h_flat[tgt_idx]
                    if not torch.isfinite(h_tgt).all():
                        print(f"❌ ASBN: Target embeddings contain NaN/Inf before BN")
                        h_tgt = torch.where(torch.isfinite(h_tgt), h_tgt, torch.zeros_like(h_tgt))
                    h_tgt_bn = self.bn_target(h_tgt)
                    if not torch.isfinite(h_tgt_bn).all():
                        print(f"❌ ASBN: BN_target produced NaN/Inf")
                        h_tgt_bn = h_tgt
                    h_normalized[tgt_idx] = h_tgt_bn
            except Exception as e:
                if _VERBOSE_LOGGING:
                    print(f"❌ ASBN: bn_target failed: {type(e).__name__}")

            h_out = h_normalized.reshape(B, T, H)
            domain_loss = torch.tensor(0.0, device=device)

            if self.training and _ENABLE_ASBN_TRAINING and self.current_step >= self.warmup_steps:
                if domain_labels is not None:
                    try:
                        grl_alpha = self.get_grl_alpha(self.current_step)
                        # Every token takes part, so no index gather is needed.
                        sel_emb = h_normalized
                        sel_labels = domain_expanded
                        if sel_emb.size(0) > 0:
                            domain_input = gradient_reversal(sel_emb, alpha=grl_alpha)
                            domain_logits = self.d_domain(domain_input).to(device)
                            if not torch.isfinite(domain_logits).all():
                                print(f"❌ ASBN: domain_logits contains NaN/Inf")
                                domain_logits = torch.where(torch.isfinite(domain_logits), domain_logits, torch.zeros_like(domain_logits))
                            domain_loss = F.cross_entropy(domain_logits, sel_labels)
                            domain_loss = torch.clamp(domain_loss, min=0.0, max=10.0)
                            if not torch.isfinite(domain_loss):
                                print(f"❌ ASBN: domain_loss is NaN/Inf: {domain_loss}")
                                domain_loss = torch.tensor(0.0, device=device)
                            else:
                                # Compute all numbers first so the lock is
                                # held only for the counter updates.
                                with torch.no_grad():
                                    domain_preds = torch.argmax(domain_logits, dim=1)
                                    hits = domain_preds == sel_labels
                                    correct = int(hits.sum().item())
                                    domain_accuracy = float(hits.float().mean().item())
                                    source_mask_sel = sel_labels == 0
                                    target_mask_sel = sel_labels == 1
                                    n_tokens = int(sel_labels.size(0))
                                    n_source = int(source_mask_sel.sum().item())
                                    n_target = int(target_mask_sel.sum().item())
                                    source_correct = int(hits[source_mask_sel].sum().item())
                                    target_correct = int(hits[target_mask_sel].sum().item())
                                    loss_value = float(domain_loss.item())
                                with self._stats_lock:
                                    self.correct_domain += correct
                                    self.total_samples += n_tokens
                                    self.domain_loss_accumulator += loss_value * n_tokens
                                    self.asbn_loss_accumulator += loss_value * n_tokens
                                    self.correct_source += source_correct
                                    self.total_source += n_source
                                    self.correct_target += target_correct
                                    self.total_target += n_target
                                    reset_due = self.total_samples >= self.stats_reset_interval
                                # get_detailed_stats() and reset_stats() take
                                # _stats_lock themselves.  Calling them while
                                # this thread still holds a non-reentrant lock
                                # would block forever, so they run only after
                                # the lock has been released.
                                if reset_due:
                                    if _DEBUG_DISCOVERY:
                                        stats = self.get_detailed_stats()
                                        if stats["domain_loss"] is not None:
                                            print(f"[ASBN-STATS] Resetting after {stats['num_updates']} samples: domain_loss={stats['domain_loss']:.4f}, domain_acc={stats['domain_accuracy']:.2%}")
                                    self.reset_stats()
                    except Exception as e:
                        print(f"❌ ASBN: Domain loss computation failed: {type(e).__name__}: {e}")
                        if _VERBOSE_LOGGING:
                            traceback.print_exc()
                        domain_loss = torch.tensor(0.0, device=device)
                else:
                    if self.current_step % 100 == 0 and _VERBOSE_LOGGING:
                        print(f"⚠️  ASBN: domain_labels is None at step {self.current_step}")
            if _DEBUG_DISCOVERY and self.current_step % 500 == 0:
                print(f"[ASBN] BN applied: src={int(source_mask.sum())}, tgt={int(target_mask.sum())}")
            return h_out, domain_loss, domain_accuracy
        except Exception as e:
            print(f"❌ ASBN: forward failed: {type(e).__name__}: {e}")
            if _VERBOSE_LOGGING:
                traceback.print_exc()
            return h, torch.tensor(0.0, device=device), 0.0

    def forward_with_grl_simplified(self, h: torch.Tensor, proto_probs: Any, uncertainties: Any, gates: Any, token_word_map: Optional[List[Dict[int, str]]] = None, domain_labels: Optional[torch.Tensor] = None, global_step: Optional[int] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        if global_step is not None:
            self.current_step = max(0, int(global_step))
        dev = h.device if isinstance(h, torch.Tensor) else torch.device("cpu")
        
        if self.current_step < self.warmup_steps:
            if self.current_step % 100 == 0 and (_VERBOSE_LOGGING or _DEBUG_DISCOVERY):
                print(f"⚠️  ASBN BLOCKED BY WARMUP: step {self.current_step}/{self.warmup_steps}")
            zero = torch.tensor(0.0, device=dev)
            return zero, zero, zero, zero
            
        if not self.training:
            zero = torch.tensor(0.0, device=dev)
            return zero, zero, zero, zero
            
        if not _ENABLE_ASBN_TRAINING:
            if self.current_step % 100 == 0 and (_VERBOSE_LOGGING or _DEBUG_DISCOVERY):
                print(f"⚠️  ASBN TRAINING DISABLED: ENABLE_ASBN_TRAINING={_ENABLE_ASBN_TRAINING}")
            zero = torch.tensor(0.0, device=dev)
            return zero, zero, zero, zero
            
        if not isinstance(h, torch.Tensor) or h.dim() != 3:
            if _VERBOSE_LOGGING:
                print(f"❌ ASBN: Invalid input h: type={type(h)}, shape={h.shape if isinstance(h, torch.Tensor) else 'N/A'}")
            zero = torch.tensor(0.0, device=dev)
            return zero, zero, zero, zero

        device = h.device
        self._ensure_discriminators_on_device(device)
        self.d_domain.train()
        self.d_freq.train()
        self.d_ctx.train()
        self.d_xl.train()
        B, T, H = h.size()

        if domain_labels is not None:
            try:
                domain_labels = domain_labels.to(device).long()
            except Exception:
                try:
                    domain_labels = domain_labels.long().to(device)
                except Exception as e:
                    print(f"❌ ASBN-GRL: domain_labels conversion failed: {e}")
                    domain_labels = torch.ones((B,), dtype=torch.long, device=device)
            if domain_labels.dim() == 0:
                domain_labels = domain_labels.unsqueeze(0).expand(B)
            elif domain_labels.numel() == 1 and B > 1:
                domain_labels = domain_labels.view(1).expand(B).contiguous()
            elif domain_labels.size(0) != B:
                domain_labels = domain_labels[:B] if domain_labels.size(0) > B else domain_labels[0].unsqueeze(0).expand(B)
        else:
            if self.current_step % 100 == 0 and (_VERBOSE_LOGGING or _DEBUG_DISCOVERY):
                print(f"⚠️  ASBN-GRL: domain_labels is None at step {self.current_step}")

        pmax_mat = self._parse_proto_probs_matrix(proto_probs, B, T, device)
        U_mat = self._parse_scalar_matrix(uncertainties, B, T, device, default=0.5)
        G_mat = self._parse_scalar_matrix(gates, B, T, device, default=0.0)
        pmax_mat = torch.clamp(pmax_mat, 0.0, 1.0)
        U_mat = torch.clamp(U_mat, 0.0, 1.0)
        G_mat = torch.clamp(G_mat, 0.0, 1.0)

        sel_mask = torch.ones((B, T), dtype=torch.bool, device=device)

        if token_word_map:
            try:
                for b in range(min(B, len(token_word_map))):
                    wm = token_word_map[b] or {}
                    for t in range(T):
                        if t in wm:
                            try:
                                token_str = wm[t]
                                if (not token_str) or (isinstance(token_str, str) and len(token_str.strip()) == 0) or (token_str in self.special_tokens):
                                    sel_mask[b, t] = False
                            except Exception:
                                pass
            except Exception as e:
                if _VERBOSE_LOGGING:
                    print(f"❌ ASBN-GRL: token_word_map filtering failed: {e}")

        sel_idx = sel_mask.view(-1).nonzero(as_tuple=True)[0]
        if sel_idx.numel() == 0:
            if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and self.current_step % 100 == 0:
                print("⚠️  ASBN-GRL: No valid tokens after filtering")
            zero = torch.tensor(0.0, device=device)
            return zero, zero, zero, zero

        h_flat = h.view(B * T, H)
        sel_emb = h_flat[sel_idx]
        pmax_flat = pmax_mat.view(-1)[sel_idx]
        U_flat = U_mat.view(-1)[sel_idx]
        G_flat = G_mat.view(-1)[sel_idx]

        seq_len_feature = float(T) / float(max(int(_MAX_LENGTH), 1))
        freq_feature = torch.stack([pmax_flat, U_flat], dim=1).to(device)
        ctx_feature = torch.stack([G_flat, torch.full_like(G_flat, seq_len_feature)], dim=1).to(device)
        xl_input = sel_emb

        grl_alpha = self.get_grl_alpha(global_step)
        freq_input = torch.cat([sel_emb, freq_feature], dim=1)
        ctx_input = torch.cat([sel_emb, ctx_feature], dim=1)
        xl_input_grl = gradient_reversal(xl_input, alpha=grl_alpha)
        freq_input_grl = gradient_reversal(freq_input, alpha=grl_alpha)
        ctx_input_grl = gradient_reversal(ctx_input, alpha=grl_alpha)

        freq_logits = self.d_freq(freq_input_grl).to(device)
        ctx_logits = self.d_ctx(ctx_input_grl).to(device)
        xl_logits = self.d_xl(xl_input_grl).to(device)

        freq_label = (pmax_flat >= self.freq_threshold).long().to(device)
        ctx_label = (U_flat <= self.uncertainty_threshold).long().to(device)
        xl_label = (G_flat >= self.gate_threshold).long().to(device)

        if freq_logits.size(0) == 0 or freq_label.size(0) == 0:
            if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and self.current_step % 100 == 0:
                print("⚠️  ASBN-GRL: Empty logits or labels")
            zero = torch.tensor(0.0, device=device)
            return zero, zero, zero, zero

        loss_freq = F.cross_entropy(freq_logits, freq_label, reduction="none")
        loss_ctx = F.cross_entropy(ctx_logits, ctx_label, reduction="none")
        loss_xl = F.cross_entropy(xl_logits, xl_label, reduction="none")

        loss_freq = torch.clamp(loss_freq, min=0.0, max=10.0)
        loss_ctx = torch.clamp(loss_ctx, min=0.0, max=10.0)
        loss_xl = torch.clamp(loss_xl, min=0.0, max=10.0)

        lam_freq = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "freq")
        lam_ctx = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "ctx")
        lam_xl = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "xl")

        weighted = lam_freq * loss_freq + lam_ctx * loss_ctx + lam_xl * loss_xl
        weighted = torch.clamp(weighted, min=0.0, max=10.0)
        mean_weighted = torch.mean(weighted) if weighted.numel() > 0 else torch.tensor(0.0, device=device)

        domain_loss = torch.tensor(0.0, device=device)
        domain_accuracy = torch.tensor(0.0, device=device)

        if domain_labels is not None:
            try:
                batch_indices = sel_idx // T
                batch_indices = torch.clamp(batch_indices, 0, B - 1)
                domain_flat = domain_labels[batch_indices].to(device).long() if domain_labels.numel() > 0 else torch.tensor([], dtype=torch.long, device=device)

                if domain_flat.numel() > 0:
                    domain_input = gradient_reversal(sel_emb, alpha=grl_alpha)
                    domain_logits = self.d_domain(domain_input).to(device)
                    domain_loss = F.cross_entropy(domain_logits, domain_flat)
                    domain_loss = torch.clamp(domain_loss, min=0.0, max=10.0)
                    if not torch.isfinite(domain_loss):
                        print(f"❌ ASBN-GRL: domain_loss is NaN/Inf: {domain_loss}")
                        domain_loss = torch.tensor(0.0, device=device)
                    else:
                        with torch.no_grad():
                            domain_preds = torch.argmax(domain_logits, dim=1)
                            correct = int((domain_preds == domain_flat).sum().item())
                            domain_accuracy = torch.tensor(float((domain_preds == domain_flat).float().mean().item()), device=device) if domain_flat.numel() > 0 else torch.tensor(0.0, device=device)
                            source_mask = domain_flat == 0
                            target_mask = domain_flat == 1
                            with self._stats_lock:
                                self.correct_domain += correct
                                self.total_samples += int(domain_flat.size(0))
                                self.domain_loss_accumulator += float(domain_loss.item()) * int(domain_flat.size(0))
                                if source_mask.any():
                                    source_correct = int((domain_preds[source_mask] == domain_flat[source_mask]).sum().item())
                                    self.correct_source += source_correct
                                    self.total_source += int(source_mask.sum().item())
                                if target_mask.any():
                                    target_correct = int((domain_preds[target_mask] == domain_flat[target_mask]).sum().item())
                                    self.correct_target += target_correct
                                    self.total_target += int(target_mask.sum().item())
            except Exception as e:
                print(f"❌ ASBN-GRL: domain classification failed: {type(e).__name__}: {e}")
                if _VERBOSE_LOGGING:
                    traceback.print_exc()

        encoder_loss = self.encoder_grl_scale * (mean_weighted + domain_loss)
        encoder_loss = torch.clamp(encoder_loss, min=0.0, max=10.0)

        if not torch.isfinite(encoder_loss):
            print(f"❌ ASBN-GRL: encoder_loss is NaN/Inf: {encoder_loss}")
            encoder_loss = torch.tensor(0.0, device=device)

        try:
            with self._stats_lock:
                self.asbn_loss_accumulator += float(encoder_loss.item()) * (sel_emb.size(0) if isinstance(sel_emb, torch.Tensor) else 1)
                if self.total_samples >= self.stats_reset_interval and _DEBUG_DISCOVERY:
                    stats = self.get_detailed_stats()
                    if stats['num_updates'] > 0:
                        print(f"[ASBN-GRL-STATS] Resetting after {stats['num_updates']} samples: domain_loss={stats['domain_loss']:.4f}, domain_acc={stats['domain_accuracy']:.2%}")
                    self.reset_stats()
        except Exception as e:
            if _VERBOSE_LOGGING:
                print(f"❌ ASBN-GRL: Stats update failed: {e}")

        if _DEBUG_DISCOVERY and self.current_step % 500 == 0:
            try:
                print(f"[ASBN-STEP-{self.current_step}] GRL alpha={grl_alpha:.3f}, encoder_loss={float(encoder_loss.item()):.4f}, mean_weighted={float(mean_weighted.item()):.4f}, domain_loss={float(domain_loss.item()):.4f}")
            except Exception:
                pass

        return encoder_loss, mean_weighted, domain_loss, domain_accuracy

    def test_asbn(self, batch_size: int = 2, seq_len: int = 10) -> bool:
        """Smoke-test ``forward`` and ``forward_with_grl_simplified`` on random inputs.

        The module is switched to training mode and ``current_step`` is moved
        just past ``warmup_steps`` for the duration of the check. Both are put
        back afterwards so that running the self-test does not change how the
        module behaves for the caller. Statistics counters updated by the
        forward paths themselves are not rolled back.

        Args:
            batch_size: Number of random sequences to generate.
            seq_len: Length of each random sequence.

        Returns:
            True when every check passes, False (after printing the failing
            check) otherwise. Exceptions raised by the forward paths propagate.
        """
        try:
            device = next(self.parameters()).device
        except Exception:
            device = torch.device("cpu")
        h = torch.randn(batch_size, seq_len, self.embed_dim, device=device)
        domain_labels = torch.randint(0, 2, (batch_size,), device=device)

        # Snapshot the state this test is about to overwrite.
        was_training = bool(getattr(self, "training", True))
        saved_step = getattr(self, "current_step", 0)

        self.train()
        self.current_step = self.warmup_steps + 1
        try:
            h_out, domain_loss, domain_acc = self.forward(h, domain_labels, global_step=self.current_step)
            if not (isinstance(h_out, torch.Tensor) and h_out.shape == h.shape):
                print("❌ ASBN test: forward shape mismatch")
                return False
            if not (isinstance(domain_loss, torch.Tensor) and domain_loss.item() >= 0.0):
                print("❌ ASBN test: domain_loss invalid")
                return False
            if not (isinstance(domain_acc, float) and 0.0 <= domain_acc <= 1.0):
                print("❌ ASBN test: domain_accuracy invalid")
                return False
            proto_probs = torch.rand(batch_size, seq_len, 3, device=device)
            uncertainties = torch.rand(batch_size, seq_len, device=device)
            gates = torch.rand(batch_size, seq_len, device=device)
            enc_loss, adv_loss, dom_loss, dom_acc = self.forward_with_grl_simplified(h, proto_probs, uncertainties, gates, domain_labels=domain_labels, global_step=self.current_step)
            if not (isinstance(enc_loss, torch.Tensor) and enc_loss.item() >= 0.0):
                print("❌ ASBN test: encoder loss invalid")
                return False
            if not (isinstance(dom_acc, torch.Tensor) and 0.0 <= dom_acc.item() <= 1.0):
                print("❌ ASBN test: domain accuracy from GRL invalid")
                return False
            print("✅ ASBN test: All checks passed")
            return True
        finally:
            # Restore the caller's mode and step even when a check fails or raises.
            self.train(was_training)
            self.current_step = saved_step


# Banner shown once the cell has finished executing.
print(
    "\n".join(
        [
            "\n" + "=" * 80,
            "Cell 4: ASBN module loaded - NaN/Inf GRADIENT HARDENED",
            "=" * 80,
            "Features:",
            "  - BatchNorm1d eps=1e-3 (increased from 1e-5)",
            "  - GRL gradient magnitude bounds: [-10, +10]",
            "  - GRL alpha minimum: 0.01 (prevents vanishing)",
            "  - Discriminator logit clamping: [-10, +10]",
            "  - Loss component clamping: [0, 10]",
            "  - Xavier init gain=0.5 (conservative)",
            "  - Lambda max reduced to 1.0 (from 2.0)",
            "  - Dropout reduced to 0.2 (from 0.3)",
            "=" * 80,
        ]
    )
)


In [None]:
# ==============================================================================
# CELL 5: TRG (TRANSLATION RATIONALE GENERATION) - NaN/Inf HARDENED
# ==============================================================================
from typing import List, Dict, Tuple, Optional, Set, Any
from collections import deque, defaultdict
import traceback
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import threading
import time

def _trg_setting(names, cast, default, accept=None):
    """Read an optional notebook-level setting, falling back to ``default``.

    Args:
        names: One global name, or a tuple of names tried in order. The first
            name present in ``globals()`` supplies the value.
        cast: Callable used to convert the raw value (``int``, ``float``, ...).
        default: Returned when no name is defined, the conversion raises, or
            ``accept`` rejects the converted value.
        accept: Optional predicate applied to the converted value.

    Returns:
        The converted, accepted value or ``default``.
    """
    if isinstance(names, str):
        names = (names,)
    scope = globals()
    for name in names:
        if name not in scope:
            continue
        try:
            value = cast(scope[name])
        except Exception:
            return default
        if accept is not None and not accept(value):
            return default
        return value
    return default


def _trg_positive(value) -> bool:
    """Return True for strictly positive numbers."""
    return value > 0


def _trg_in_unit_interval(value) -> bool:
    """Return True for values in [0, 1]; NaN fails both comparisons."""
    return 0.0 <= value <= 1.0


def _trg_finite_positive(value) -> bool:
    """Return True for strictly positive, finite numbers (rejects NaN and inf)."""
    return 0.0 < value < float("inf")


# Sizes and buffers.
_TRG_EVIDENCE_K = _trg_setting("TRG_EVIDENCE_K", int, 3, _trg_positive)
_TRG_GEN_EMBED = _trg_setting("TRG_GEN_EMBED", int, 64, _trg_positive)
_MAX_SILVER_BUFFER = _trg_setting("MAX_SILVER_BUFFER", int, 50, _trg_positive)

# Logging and feature switches.
_VERBOSE_LOGGING = _trg_setting("VERBOSE_LOGGING", bool, False)
_DEBUG_DISCOVERY = _trg_setting("DEBUG_DISCOVERY", bool, False)
_DEBUG_TIMING = _trg_setting("DEBUG_TIMING", bool, False)
_ENABLE_TRG_INFERENCE = _trg_setting("ENABLE_TRG_INFERENCE", bool, True)
_ENABLE_TRG_TRAINING = _trg_setting("ENABLE_TRG_TRAINING", bool, True)

_SOURCE_LANGUAGE = _trg_setting("SOURCE_LANGUAGE", str, "bn")

# Thresholds must lie in [0, 1]. The previous ``< 0 or > 1`` test let NaN
# through, which would make every later ``>=`` comparison silently false.
_TAU_HIGH = _trg_setting("TAU_HIGH", float, 0.85, _trg_in_unit_interval)
_TAU_LOW = _trg_setting("TAU_LOW", float, 0.25, _trg_in_unit_interval)
_TAU_ACCEPT = _trg_setting("TAU_ACCEPT", float, 0.80, _trg_in_unit_interval)
_TRG_UNCERTAINTY_THRESHOLD = _trg_setting("TRG_UNCERTAINTY_THRESHOLD", float, _TAU_LOW, _trg_in_unit_interval)
# TRG_SPAN_THRESHOLD takes precedence; SPAN_THRESHOLD is only consulted when
# the TRG-specific name is not defined at all.
_TRG_SPAN_THRESHOLD = _trg_setting(("TRG_SPAN_THRESHOLD", "SPAN_THRESHOLD"), float, 0.05, _trg_in_unit_interval)

_TRG_TEMPERATURE = _trg_setting("TRG_TEMPERATURE", float, 1.0, _trg_finite_positive)
_MAX_EXPLANATIONS_PER_SENTENCE = _trg_setting("MAX_EXPLANATIONS_PER_SENTENCE", int, 10, _trg_positive)

# Optional helpers that may or may not already be defined in this namespace.
_has_is_valid_token = "is_valid_token" in globals()
_has_get_tokenizer_special_tokens = "get_tokenizer_special_tokens" in globals()
_has_get_cached_special_tokens = "get_cached_special_tokens" in globals()

_TRG_PUNCT_SET = set(".,;:!?\\\"\\'()-[]{}\\/")

def _fallback_is_valid_token(token: Any, special_tokens: Set[str], tokenizer=None, language: str = "bn") -> bool:
    """Decide whether ``token`` looks like a real word piece.

    A token is accepted when, after removing sub-word markers, it is at least
    two characters long, contains at least one alphabetic character, is not
    made only of punctuation and is not purely numeric. ``None``, empty
    strings and members of ``special_tokens`` are rejected. ``tokenizer`` and
    ``language`` are accepted for signature compatibility and are not used.
    """
    if token is None:
        return False
    try:
        text = str(token).strip()
    except Exception:
        return False
    if text == "" or text in special_tokens:
        return False

    # Drop the sub-word markers used by the common tokenizer families.
    core = text
    for marker in ("▁", "Ġ", "##", "@@", "</w>"):
        core = core.replace(marker, "")
    core = core.strip()

    has_letter = any(ch.isalpha() for ch in core)
    if len(core) < 2 or not has_letter:
        return False
    only_punct = all(ch in _TRG_PUNCT_SET for ch in core)
    return not (only_punct or core.isdigit())

def _is_word_start(raw_token: Any, token_word_map: Optional[dict], idx: int) -> bool:
    """Report whether the token at ``idx`` should be treated as starting a word.

    Checks, in order: a non-blank string entry for ``idx`` in
    ``token_word_map``; a leading word-boundary marker on the token; and
    finally whether the marker-stripped token is at least two characters,
    not pure punctuation, and contains a letter. Any exception yields False.
    """
    try:
        text = raw_token if isinstance(raw_token, str) else str(raw_token)

        if token_word_map and isinstance(token_word_map, dict) and idx in token_word_map:
            mapped = token_word_map[idx]
            if isinstance(mapped, str) and mapped.strip():
                return True

        if text.startswith(("▁", "Ġ", "\u2581")):
            return True

        core = text
        for marker in ("▁", "Ġ", "##", "@@", "</w>"):
            core = core.replace(marker, "")
        core = core.strip()

        if len(core) < 2:
            return False
        if all(ch in _TRG_PUNCT_SET for ch in core):
            return False
        return any(ch.isalpha() for ch in core)
    except Exception:
        return False

class ComprehensiveTRGExplanationTemplate:
    """Turn an evidence dictionary into a one-line natural-language rationale.

    The template is chosen from the chosen sense's confidence: at or above
    ``_TAU_ACCEPT`` is "high", at or above ``_TRG_UNCERTAINTY_THRESHOLD`` is
    "medium", anything lower is "low".
    """

    # Sub-word markers removed from tokens before they are shown to a reader.
    _SUBWORD_MARKERS = ("▁", "Ġ", "##", "@@", "</w>")

    def __init__(self):
        self.explanation_templates = {
            "high_confidence": "Chose '{sense}' with high confidence ({confidence:.1%}) based on: '{evidence}'. {alternatives_text}",
            "medium_confidence": "Selected '{sense}' with moderate confidence ({confidence:.1%}). Evidence: '{evidence}'. {alternatives_text}",
            "low_confidence": "Uncertain; chose '{sense}' ({confidence:.1%}). Evidence: '{evidence}'. {alternatives_text} Review recommended.",
            "fallback": "Token '{token}' analyzed. Context: '{evidence}'.",
        }

    @classmethod
    def _strip_markers(cls, text: Any) -> str:
        """Return ``str(text)`` with every sub-word marker removed."""
        cleaned = str(text)
        for marker in cls._SUBWORD_MARKERS:
            cleaned = cleaned.replace(marker, "")
        return cleaned

    @staticmethod
    def _as_confidence(value: Any) -> Optional[float]:
        """Convert ``value`` to a finite float, or None when that is impossible."""
        try:
            number = float(value)
        except (TypeError, ValueError, OverflowError):
            return None
        # NaN is the only float not equal to itself; +/-inf is rejected as well.
        if number != number or abs(number) == float("inf"):
            return None
        return number

    def generate_explanation(self, evidence: Dict) -> str:
        """Render the rationale sentence for one token.

        Args:
            evidence: Dictionary with optional keys ``token``,
                ``chosen_sense`` (``(name, confidence)``), ``evidence_tokens``
                and ``alternatives`` (list of ``(name, confidence)``).

        Returns:
            The formatted sentence, or ``""`` when ``evidence`` is empty or
            not a dictionary. A chosen-sense confidence that is non-numeric or
            non-finite is replaced by the 0.5 default used for a missing
            sense (previously NaN was clamped to 1.0 and reported as high
            confidence, and non-numeric values raised). Alternatives with an
            unusable confidence are left out.
        """
        if not evidence or not isinstance(evidence, dict):
            return ""

        token = self._strip_markers(evidence.get("token", "unknown"))

        sense_name, confidence = "unknown", 0.5
        sense_info = evidence.get("chosen_sense", ("unknown", 0.5))
        if isinstance(sense_info, (tuple, list)) and len(sense_info) >= 2:
            sense_name = str(sense_info[0])
            parsed = self._as_confidence(sense_info[1])
            confidence = 0.5 if parsed is None else parsed
        confidence = max(0.0, min(1.0, confidence))

        evidence_tokens = evidence.get("evidence_tokens", []) or []
        evidence_str = ", ".join([self._strip_markers(tok) for tok in evidence_tokens[:_TRG_EVIDENCE_K]]) or "limited context"

        alternatives = evidence.get("alternatives", []) or []
        alternatives_text = ""
        if isinstance(alternatives, list) and alternatives:
            alt_parts = []
            for alt in alternatives[:2]:
                if not (isinstance(alt, (tuple, list)) and len(alt) >= 2):
                    continue
                alt_conf = self._as_confidence(alt[1])
                if alt_conf is None:
                    continue
                alt_conf = max(0.0, min(1.0, alt_conf))
                alt_parts.append(f"'{str(alt[0])}' ({alt_conf:.1%})")
            if alt_parts:
                alternatives_text = "Alternatives: " + ", ".join(alt_parts) + "."

        if confidence >= _TAU_ACCEPT:
            key = "high_confidence"
        elif confidence >= _TRG_UNCERTAINTY_THRESHOLD:
            key = "medium_confidence"
        else:
            key = "low_confidence"
        tpl = self.explanation_templates.get(key, self.explanation_templates["fallback"])
        try:
            return tpl.format(sense=sense_name, confidence=confidence, evidence=evidence_str, alternatives_text=alternatives_text, token=token)
        except Exception:
            return f"Token '{token}' -> '{sense_name}' ({confidence:.1%})."


class MemoryEfficientTRGExtractor:
    def __init__(self, tokenizer=None, language: str = "bn", dscd_module=None):
        self.tokenizer = tokenizer
        self.language = language
        self.dscd_module = dscd_module
        self.span_clamp_warnings = 0
        self.last_warning_time = 0.0
        self.extraction_failures = 0
        self.last_failure_log = 0.0
        try:
            if tokenizer is not None and _has_get_tokenizer_special_tokens:
                self.special_tokens = get_tokenizer_special_tokens(tokenizer)
            elif tokenizer is not None and _has_get_cached_special_tokens:
                try:
                    self.special_tokens = get_cached_special_tokens(tokenizer)
                except Exception:
                    self.special_tokens = set(getattr(tokenizer, "all_special_tokens", []))
            else:
                self.special_tokens = set(getattr(tokenizer, "all_special_tokens", []) if tokenizer is not None else [])
        except Exception:
            self.special_tokens = set()

    def extract_evidence_from_target(self, token_idx: int, span_start: int, span_end: int, tgt_preds: Any) -> Optional[List[str]]:
        if not isinstance(token_idx, int) or token_idx < 0:
            return None
        if not isinstance(span_start, int) or not isinstance(span_end, int):
            return None
        if span_start < 0 or span_end <= span_start:
            return None
        if not isinstance(tgt_preds, (torch.Tensor, list)):
            return None
        seq_len = len(tgt_preds) if isinstance(tgt_preds, list) else int(tgt_preds.size(0))
        if span_end > seq_len or token_idx >= seq_len:
            return None
        try:
            evidence_tokens = []
            for i in range(span_start, span_end):
                if i == token_idx:
                    continue
                if isinstance(tgt_preds, list):
                    evidence_tokens.append(str(tgt_preds[i]))
                else:
                    try:
                        evidence_tokens.append(str(int(tgt_preds[i].item())))
                    except Exception:
                        try:
                            evidence_tokens.append(str(tgt_preds[i].item()))
                        except Exception:
                            evidence_tokens.append(f"token_{i}")
            return evidence_tokens if evidence_tokens else None
        except Exception:
            return None

    def extract_evidence_efficiently(self, token_idx: int, tokens: List[str], dscd_outputs: Dict, token_word_map: Optional[dict] = None, decoder_attention: Optional[torch.Tensor] = None) -> Dict:
        if not isinstance(tokens, list):
            if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and time.time() - self.last_failure_log > 60.0:
                print(f"⚠️  [TRG] extract_evidence_efficiently: tokens not a list (type={type(tokens)})")
                self.last_failure_log = time.time()
            return self._create_fallback_evidence(token_idx if isinstance(token_idx, int) else 0, [])
        if not isinstance(token_idx, int):
            if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and time.time() - self.last_failure_log > 60.0:
                print(f"⚠️  [TRG] extract_evidence_efficiently: token_idx not int (type={type(token_idx)})")
                self.last_failure_log = time.time()
            return self._create_fallback_evidence(0, tokens)
        if token_idx < 0 or token_idx >= len(tokens):
            token_idx = max(0, min(token_idx, len(tokens) - 1)) if tokens else 0
        raw_token = tokens[token_idx]
        if _has_is_valid_token:
            try:
                is_valid = is_valid_token(raw_token, self.special_tokens, self.tokenizer, language=self.language)
            except Exception:
                is_valid = _fallback_is_valid_token(raw_token, self.special_tokens, self.tokenizer, self.language)
        else:
            is_valid = _fallback_is_valid_token(raw_token, self.special_tokens, self.tokenizer, self.language)
        if not is_valid:
            return self._create_fallback_evidence(token_idx, tokens)
        try:
            proto_probs = self._safe_extract_proto_probs(token_idx, dscd_outputs)
            uncertainty = self._safe_extract_uncertainty(token_idx, dscd_outputs)
            gate = self._safe_extract_gate(token_idx, dscd_outputs)
            span = self._safe_extract_span(token_idx, dscd_outputs)
            uncertainty = max(0.0, min(1.0, float(uncertainty)))
            gate = max(0.0, min(1.0, float(gate)))
            span = max(0.0, min(1.0, float(span)))
            evidence_tokens = None
            if isinstance(decoder_attention, torch.Tensor):
                try:
                    att = decoder_attention
                    if att.dim() == 4:
                        att_mean = att.mean(dim=(0, 1))
                    elif att.dim() == 3:
                        att_mean = att.mean(dim=0)
                    elif att.dim() == 2:
                        att_mean = att
                    else:
                        att_mean = None
                    if att_mean is not None and att_mean.dim() == 2 and token_idx < att_mean.size(0):
                        vec = att_mean[token_idx]
                        k = min(5, int(vec.numel()))
                        if k > 0:
                            top_k = torch.topk(vec, k=k).indices.cpu().numpy()
                            evidence_tokens = [tokens[int(i)] for i in top_k if int(i) < len(tokens) and int(i) != token_idx]
                except Exception as e:
                    if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and time.time() - self.last_failure_log > 60.0:
                        print(f"⚠️  [TRG] Decoder attention extraction failed: {type(e).__name__}")
                        self.last_failure_log = time.time()
                    evidence_tokens = None
            if evidence_tokens is None:
                evidence_tokens = self._extract_context_window(token_idx, tokens, token_word_map)
            seen = set()
            dedup = []
            for t in (evidence_tokens or []):
                if t not in seen:
                    seen.add(t)
                    dedup.append(t)
            evidence_tokens = dedup[:_TRG_EVIDENCE_K]
            top_senses = self._compute_sense_alternatives_fast(proto_probs, temperature=_TRG_TEMPERATURE)
            chosen_sense = top_senses[0] if top_senses else ("unknown", 0.5)
            alternatives = top_senses[1:3] if len(top_senses) > 1 else []
            token_value = token_word_map[token_idx] if token_word_map and token_idx in token_word_map and isinstance(token_word_map[token_idx], str) and token_word_map[token_idx].strip() else raw_token
            return {
                "token": token_value,
                "token_idx": token_idx,
                "evidence_tokens": evidence_tokens,
                "chosen_sense": chosen_sense,
                "alternatives": alternatives,
                "uncertainty": float(uncertainty),
                "gate": float(gate),
                "span": float(span),
            }
        except Exception as e:
            self.extraction_failures += 1
            if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and time.time() - self.last_failure_log > 60.0:
                print(f"❌ [TRG] extract_evidence_efficiently failed: {type(e).__name__}: {e}")
                self.last_failure_log = time.time()
            return self._create_fallback_evidence(token_idx, tokens)

    def _extract_context_window(self, token_idx: int, tokens: List[str], token_word_map: Optional[dict]) -> List[str]:
        context_window = 3
        start_idx = max(0, token_idx - context_window)
        end_idx = min(len(tokens), token_idx + context_window + 1)
        evidence_tokens = []
        for i in range(start_idx, end_idx):
            if i == token_idx or i >= len(tokens):
                continue
            rtok = tokens[i]
            if not _is_word_start(rtok, token_word_map, i):
                continue
            if _has_is_valid_token:
                try:
                    ok = is_valid_token(rtok, self.special_tokens, self.tokenizer, language=self.language)
                except Exception:
                    ok = _fallback_is_valid_token(rtok, self.special_tokens, self.tokenizer, self.language)
            else:
                ok = _fallback_is_valid_token(rtok, self.special_tokens, self.tokenizer, self.language)
            if ok:
                if token_word_map and isinstance(token_word_map.get(i, ""), str) and token_word_map[i].strip():
                    evidence_tokens.append(token_word_map[i].strip())
                else:
                    clean = str(rtok).replace("▁", "").replace("Ġ", "").replace("##", "").replace("@@", "").replace("</w>", "").strip()
                    if clean:
                        evidence_tokens.append(clean)
        return evidence_tokens

    def _safe_extract_proto_probs(self, token_idx: int, dscd_outputs: Dict) -> torch.Tensor:
        try:
            if not isinstance(dscd_outputs, dict):
                return torch.tensor([1.0], dtype=torch.float32)
            pp_all = dscd_outputs.get("proto_probs", None)
            if pp_all is None:
                return torch.tensor([1.0], dtype=torch.float32)
            if isinstance(pp_all, torch.Tensor):
                p = pp_all.detach().cpu()
                if not torch.isfinite(p).all():
                    return torch.tensor([1.0], dtype=torch.float32)
                if p.dim() == 3:
                    B, T, K = p.shape
                    row = p[0] if B > 0 else p
                    if token_idx < row.shape[0]:
                        vec = row[token_idx].flatten().float()
                    else:
                        vec = row.flatten().float()
                elif p.dim() == 2:
                    vec = p[token_idx].flatten().float() if token_idx < p.size(0) else p.flatten().float()
                else:
                    vec = p.flatten().float()
                if vec.numel() == 0:
                    return torch.tensor([1.0], dtype=torch.float32)
                if not torch.isfinite(vec).all():
                    return torch.tensor([1.0], dtype=torch.float32)
                vec = torch.clamp(vec, min=0.0, max=1.0)
                s = float(vec.sum().item()) if vec.numel() > 0 else 0.0
                if s <= 1e-9:
                    return torch.tensor([1.0], dtype=torch.float32)
                return (vec / (s + 1e-9)).to(dtype=torch.float32)
            if isinstance(pp_all, (list, tuple)):
                row = pp_all[0]
                if isinstance(row, torch.Tensor):
                    r = row.detach().cpu().float()
                    if not torch.isfinite(r).all():
                        return torch.tensor([1.0], dtype=torch.float32)
                    if r.dim() >= 1 and token_idx < r.size(0):
                        vec = r[token_idx].flatten().float()
                    else:
                        vec = r.flatten().float()
                    vec = torch.clamp(vec, min=0.0, max=1.0)
                    s = float(vec.sum().item()) if vec.numel() > 0 else 0.0
                    if s <= 1e-9:
                        return torch.tensor([1.0], dtype=torch.float32)
                    return (vec / (s + 1e-9)).to(dtype=torch.float32)
                if isinstance(row, (list, tuple, np.ndarray)):
                    if token_idx < len(row):
                        val = row[token_idx]
                        if isinstance(val, torch.Tensor):
                            vec = val.detach().cpu().float().flatten()
                        else:
                            arr = np.asarray(val, dtype=np.float32).flatten()
                            if arr.size == 0:
                                return torch.tensor([1.0], dtype=torch.float32)
                            if not np.isfinite(arr).all():
                                return torch.tensor([1.0], dtype=torch.float32)
                            vec = torch.from_numpy(arr).float()
                        vec = torch.clamp(vec, min=0.0, max=1.0)
                        s = float(vec.sum().item()) if vec.numel() > 0 else 0.0
                        if s <= 1e-9:
                            return torch.tensor([1.0], dtype=torch.float32)
                        return (vec / (s + 1e-9)).to(dtype=torch.float32)
                    maybe = row[0]
                    if isinstance(maybe, torch.Tensor):
                        vec = maybe.detach().cpu().float().flatten()
                        if not torch.isfinite(vec).all():
                            return torch.tensor([1.0], dtype=torch.float32)
                        s = float(vec.sum().item()) if vec.numel() > 0 else 0.0
                        if s <= 1e-9:
                            return torch.tensor([1.0], dtype=torch.float32)
                        return (vec / (s + 1e-9)).to(dtype=torch.float32)
                    arr = np.asarray(maybe, dtype=np.float32).flatten()
                    if arr.size == 0:
                        return torch.tensor([1.0], dtype=torch.float32)
                    if not np.isfinite(arr).all():
                        return torch.tensor([1.0], dtype=torch.float32)
                    vec = torch.from_numpy(arr).float()
                    s = float(vec.sum().item())
                    if s <= 1e-9:
                        return torch.tensor([1.0], dtype=torch.float32)
                    return (vec / (s + 1e-9)).to(dtype=torch.float32)
        except Exception as e:
            if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and self.extraction_failures % 100 == 0:
                print(f"⚠️  [TRG] _safe_extract_proto_probs failed: {type(e).__name__}")
        return torch.tensor([1.0], dtype=torch.float32)

    def _safe_extract_uncertainty(self, token_idx: int, dscd_outputs: Dict) -> float:
        try:
            if not isinstance(dscd_outputs, dict):
                return 0.5
            U_all = dscd_outputs.get("uncertainties", None)
            if U_all is None:
                return 0.5
            row = U_all[0] if isinstance(U_all, (list, tuple)) and len(U_all) > 0 else U_all
            if isinstance(row, torch.Tensor):
                r = row.detach().cpu()
                if r.dim() >= 1 and token_idx < r.size(0):
                    val = float(r[token_idx].item())
                    if not np.isfinite(val):
                        return 0.5
                    return max(0.0, min(1.0, float(val)))
                if r.dim() == 0:
                    val = float(r.item())
                    if not np.isfinite(val):
                        return 0.5
                    return max(0.0, min(1.0, val))
            if isinstance(row, (list, tuple, np.ndarray)) and token_idx < len(row):
                val = row[token_idx]
                if isinstance(val, torch.Tensor):
                    fval = float(val.item())
                else:
                    fval = float(val)
                if not np.isfinite(fval):
                    return 0.5
                return max(0.0, min(1.0, fval))
        except Exception as e:
            if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and self.extraction_failures % 100 == 0:
                print(f"⚠️  [TRG] _safe_extract_uncertainty failed: {type(e).__name__}")
        return 0.5

    def _safe_extract_gate(self, token_idx: int, dscd_outputs: Dict) -> float:
        """Read the gate value for ``token_idx`` from ``dscd_outputs["gates"]``.

        When ``gates`` is a non-empty list/tuple its first element is used as
        the per-token row; otherwise ``gates`` itself is the row. The row may
        be a tensor, list, tuple or ndarray. The result is clamped to
        ``[0, 1]``; anything missing, out of range, non-finite or raising
        yields ``0.0``.
        """
        fallback = 0.0
        try:
            gates = dscd_outputs.get("gates", None) if isinstance(dscd_outputs, dict) else None
            if gates is None:
                return fallback
            row = gates
            if isinstance(gates, (list, tuple)) and len(gates) > 0:
                row = gates[0]

            raw = None
            if isinstance(row, torch.Tensor):
                cpu_row = row.detach().cpu()
                if cpu_row.dim() >= 1 and token_idx < cpu_row.size(0):
                    raw = float(cpu_row[token_idx].item())
                elif cpu_row.dim() == 0:
                    # A scalar tensor applies to every token position.
                    raw = float(cpu_row.item())
            elif isinstance(row, (list, tuple, np.ndarray)) and token_idx < len(row):
                item = row[token_idx]
                raw = float(item.item()) if isinstance(item, torch.Tensor) else float(item)

            if raw is None or not np.isfinite(raw):
                return fallback
            return max(0.0, min(1.0, raw))
        except Exception as e:
            # Best-effort extraction: log occasionally, never propagate.
            if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and self.extraction_failures % 100 == 0:
                print(f"⚠️  [TRG] _safe_extract_gate failed: {type(e).__name__}")
        return fallback

    def _safe_extract_span(self, token_idx: int, dscd_outputs: Dict) -> float:
        """Read the span prediction for ``token_idx`` from ``dscd_outputs["span_preds"]``.

        When ``span_preds`` is a non-empty list/tuple its first element is the
        per-token row; otherwise ``span_preds`` itself is the row. Values
        outside ``[0, 1]`` are clamped, with a rate-limited warning (first 10
        occurrences, then at most one per 60 s). Missing, non-finite or
        unreadable values yield ``0.0``.
        """
        try:
            if not isinstance(dscd_outputs, dict):
                return 0.0
            spans = dscd_outputs.get("span_preds", None)
            if spans is None:
                return 0.0
            row = spans
            if isinstance(spans, (list, tuple)) and len(spans) > 0:
                row = spans[0]

            if isinstance(row, torch.Tensor):
                cpu_row = row.detach().cpu()
                if cpu_row.dim() >= 1 and token_idx < cpu_row.size(0):
                    span_val = float(cpu_row[token_idx].item())
                elif cpu_row.dim() == 0:
                    span_val = float(cpu_row.item())
                else:
                    return 0.0
            elif isinstance(row, (list, tuple, np.ndarray)) and token_idx < len(row):
                item = row[token_idx]
                span_val = float(item.item()) if isinstance(item, torch.Tensor) else float(item)
            else:
                return 0.0

            if not np.isfinite(span_val):
                return 0.0
            if not (span_val < 0.0 or span_val > 1.0):
                return float(span_val)

            # Out of range: clamp, and warn under the shared rate limit.
            is_negative = span_val < 0.0
            now = time.time()
            if self.span_clamp_warnings < 10 or (now - self.last_warning_time) > 60.0:
                if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
                    if is_negative:
                        print(f"⚠️  [TRG] Negative span {span_val:.3f} clamped to 0.0")
                    else:
                        print(f"⚠️  [TRG] Span {span_val:.3f} clamped to 1.0")
                self.span_clamp_warnings += 1
                self.last_warning_time = now
            return 0.0 if is_negative else 1.0
        except Exception as e:
            # Best-effort extraction: log occasionally, never propagate.
            if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and self.extraction_failures % 100 == 0:
                print(f"⚠️  [TRG] _safe_extract_span failed: {type(e).__name__}")
        return 0.0

    def compute_span(self, sense_probs: Any) -> float:
        """Return the margin between the two largest finite sense probabilities.

        ``sense_probs`` may be a dict (its values are used), a tensor, an
        ndarray or a sequence. The margin is clamped to ``[0, 1]``. Fewer than
        two (finite) entries, or any conversion error, yields ``0.0``.
        """
        try:
            values = list(sense_probs.values()) if isinstance(sense_probs, dict) else sense_probs
            if isinstance(values, torch.Tensor):
                values = values.detach().cpu().flatten().numpy().tolist()
            if isinstance(values, (np.ndarray, list)):
                values = list(values)
            if len(values) < 2:
                return 0.0

            finite_values = []
            for entry in values:
                as_float = float(entry)
                if np.isfinite(as_float):
                    finite_values.append(as_float)
            if len(finite_values) < 2:
                return 0.0

            finite_values.sort(reverse=True)
            margin = finite_values[0] - finite_values[1]
            return float(max(0.0, min(1.0, margin)))
        except Exception:
            return 0.0

    def _compute_sense_alternatives_fast(self, proto_probs: Any, temperature: float = 1.0) -> List[Tuple[str, float]]:
        """Rank prototype probabilities and return up to three ``(sense_<i>, prob)`` pairs.

        The input is flattened and clamped to ``[1e-9, 1]``. When
        ``temperature`` (bounded to ``[0.1, 10]``) differs from 1 and there is
        more than one entry, the distribution is renormalised and re-softmaxed
        in log space. Non-finite input or any error yields
        ``[("unknown", 0.5)]``.
        """
        try:
            if isinstance(proto_probs, torch.Tensor):
                as_tensor = proto_probs
            else:
                as_tensor = torch.as_tensor(proto_probs, dtype=torch.float32)
            dist = as_tensor.flatten().float()
            if not torch.isfinite(dist).all():
                return [("unknown", 0.5)]
            dist = torch.clamp(dist, min=1e-9, max=1.0)
            temperature = max(0.1, min(10.0, float(temperature)))
            has_many = dist.numel() > 1

            if has_many and temperature != 1.0:
                dist = dist / (dist.sum() + 1e-9)
                dist = F.softmax(torch.log(dist + 1e-9) / temperature, dim=0)

            if not has_many:
                only = float(dist[0].item())
                return [("sense_0", max(0.0, min(1.0, only)))]

            ranked, order = torch.sort(dist, descending=True)
            keep = min(3, int(order.numel()))
            results = []
            for rank in range(keep):
                sense_id = int(order[rank].item())
                prob = float(ranked[rank].item())
                results.append((f"sense_{sense_id}", max(0.0, min(1.0, prob))))
            return results
        except Exception:
            return [("unknown", 0.5)]

    def _create_fallback_evidence(self, token_idx: int, tokens: List[str]) -> Dict:
        """Build a neutral evidence record for a token that could not be analysed.

        The token text is ``tokens[token_idx]`` when ``tokens`` is a list and
        the index is in range, otherwise ``"UNK"``. All scores are set to
        their neutral defaults.
        """
        index_is_valid = isinstance(tokens, list) and 0 <= token_idx < len(tokens)
        surface = tokens[token_idx] if index_is_valid else "UNK"
        return {
            "token": surface,
            "token_idx": token_idx,
            "evidence_tokens": [],
            "chosen_sense": ("unknown", 0.5),
            "alternatives": [],
            "uncertainty": 0.5,
            "gate": 0.0,
            "span": 0.0,
        }

    def get_homograph_tokens_from_dscd(self) -> Set[str]:
        """Collect the homograph tokens known to the attached DSCD module.

        Uses ``dscd_module.discovered_homographs`` when that attribute exists.
        Otherwise falls back to ``dscd_module.prototype_stores``: every token
        whose store reports at least two prototypes is included, with subword
        markers stripped. Returns an empty set when no module is attached or
        on any failure.
        """
        homograph_tokens: Set[str] = set()
        try:
            module = self.dscd_module
            if module is None:
                # Crude throttle: only speak during one second of each 5-minute window.
                if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and time.time() % 300 < 1.0:
                    print("⚠️  [TRG] dscd_module is None, cannot get homographs")
            elif hasattr(module, "discovered_homographs"):
                homograph_tokens = set(module.discovered_homographs)
            elif hasattr(module, "prototype_stores"):
                for token, store in module.prototype_stores.items():
                    try:
                        if hasattr(store, "size"):
                            proto_count = store.size()
                        elif hasattr(store, "centroids"):
                            proto_count = len(getattr(store, "centroids", []))
                        else:
                            proto_count = 0
                        if not proto_count >= 2:
                            continue
                        stripped = str(token)
                        for marker in ("▁", "Ġ", "##", "@@", "</w>"):
                            stripped = stripped.replace(marker, "")
                        stripped = stripped.strip()
                        if stripped:
                            homograph_tokens.add(stripped)
                    except Exception:
                        # A malformed store must not hide the remaining tokens.
                        continue
        except Exception as e:
            if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
                print(f"❌ [TRG] get_homograph_tokens_from_dscd failed: {type(e).__name__}: {e}")
        return homograph_tokens


class CompleteTRGWithExplanations(nn.Module):
    """Select ambiguous tokens from DSCD outputs and attach templated explanations.

    The module owns an evidence extractor and an explanation template system,
    keeps running counters about the explanations it produced, and stores a
    bounded "silver" buffer of (token, explanation, confidence) records.

    Thread-safety: counters are guarded by ``_stats_lock`` and the silver
    buffer by ``_silver_lock``. ``_stats_lock`` is re-entrant because
    ``_update_stats`` calls ``get_statistics`` and ``reset_statistics`` while
    already holding it.
    """

    def __init__(self, embed_dim: Optional[int] = None, tokenizer=None, language: str = "bn", dscd_module=None):
        """Create the explanation generator.

        Args:
            embed_dim: Embedding width; falls back to ``_TRG_GEN_EMBED`` when
                ``None``. Always at least 1.
            tokenizer: Tokenizer used for special-token discovery and token
                validation; may be ``None``.
            language: Language code forwarded to token validation and the
                evidence extractor.
            dscd_module: Optional DSCD module queried for known homographs.
        """
        super().__init__()
        self.embed_dim = max(1, int(embed_dim) if embed_dim is not None else int(_TRG_GEN_EMBED))
        self.tokenizer = tokenizer
        self.language = language
        self.dscd_module = dscd_module
        # Resolve the special-token set through the best helper available,
        # degrading to the tokenizer's own list and finally to an empty set.
        try:
            if tokenizer is not None and _has_get_tokenizer_special_tokens:
                self.special_tokens = get_tokenizer_special_tokens(tokenizer)
            elif tokenizer is not None and _has_get_cached_special_tokens:
                try:
                    self.special_tokens = get_cached_special_tokens(tokenizer)
                except Exception:
                    self.special_tokens = set(getattr(tokenizer, "all_special_tokens", []))
            else:
                self.special_tokens = set(getattr(tokenizer, "all_special_tokens", []) if tokenizer is not None else [])
        except Exception:
            self.special_tokens = set()
        self.template_system = ComprehensiveTRGExplanationTemplate()
        self.evidence_extractor = MemoryEfficientTRGExtractor(tokenizer, language=language, dscd_module=dscd_module)
        self.silver_buffer = deque(maxlen=max(1, int(_MAX_SILVER_BUFFER)))
        self._silver_lock = threading.Lock()
        self.stats_reset_interval = 1000
        self.explanations_generated = 0
        self.explanations_generated_training = 0
        self.explanations_generated_inference = 0
        self.high_confidence_explanations = 0
        self.low_confidence_explanations = 0
        self.empty_evidence_count = 0
        self.total_evidence_tokens = 0
        self.tokens_filtered_word_start = 0
        self.tokens_filtered_validity = 0
        self.tokens_filtered_ambiguity = 0
        self.dscd_homographs_explained = 0
        # Must be re-entrant: _update_stats() calls get_statistics() and
        # reset_statistics() while holding this lock. A plain Lock deadlocks
        # as soon as explanations_generated reaches stats_reset_interval.
        self._stats_lock = threading.RLock()
        self._last_stats_log = 0.0
        if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
            print("=" * 80)
            print("[TRG-INIT] CompleteTRGWithExplanations - NaN/Inf HARDENED")
            print("=" * 80)
            print(f"  uncertainty={_TRG_UNCERTAINTY_THRESHOLD:.2f} span={_TRG_SPAN_THRESHOLD:.2f}")
            print(f"  Training: {'ENABLED' if _ENABLE_TRG_TRAINING else 'DISABLED'}")
            print(f"  Inference: {'ENABLED' if _ENABLE_TRG_INFERENCE else 'DISABLED'}")
            print(f"  All extracted values clamped to [0, 1]")
            print(f"  NaN/Inf protection on proto_probs, uncertainty, gate, span")
            print("=" * 80)

    def _update_stats(self, evidence: Dict, is_dscd_homograph: bool = False, is_training: bool = False) -> None:
        """Fold one generated explanation into the running counters.

        When ``explanations_generated`` reaches ``stats_reset_interval`` the
        counters are reset; with ``_DEBUG_DISCOVERY`` set, a snapshot is
        printed first (at most once per 60 s).
        """
        with self._stats_lock:
            self.explanations_generated += 1
            if is_training:
                self.explanations_generated_training += 1
            else:
                self.explanations_generated_inference += 1
            if is_dscd_homograph:
                self.dscd_homographs_explained += 1
            if not evidence.get("evidence_tokens"):
                self.empty_evidence_count += 1
            else:
                self.total_evidence_tokens += len(evidence["evidence_tokens"])
            # Confidence is the second element of chosen_sense; 0.5 if unusable.
            confidence = 0.5
            chosen = evidence.get("chosen_sense")
            if isinstance(chosen, (tuple, list)) and len(chosen) >= 2:
                try:
                    confidence = max(0.0, min(1.0, float(chosen[1])))
                except Exception:
                    confidence = 0.5
            if confidence >= _TAU_ACCEPT:
                self.high_confidence_explanations += 1
            elif confidence < _TRG_UNCERTAINTY_THRESHOLD:
                self.low_confidence_explanations += 1
            if self.explanations_generated >= self.stats_reset_interval:
                # Both calls below re-acquire _stats_lock (hence the RLock).
                if _DEBUG_DISCOVERY and time.time() - self._last_stats_log > 60.0:
                    stats = self.get_statistics()
                    print(f"[TRG-STATS] {stats}")
                    self._last_stats_log = time.time()
                self.reset_statistics()

    def _add_to_silver_buffer(self, evidence: Dict, explanation: str, tokens: List[str]) -> None:
        """Append a truncated (token, explanation, confidence) record to the silver buffer.

        Best-effort: any failure is ignored so buffering can never break
        explanation generation. ``tokens`` is accepted but not used.
        """
        try:
            conf = 0.5
            chosen = evidence.get("chosen_sense")
            if isinstance(chosen, (tuple, list)) and len(chosen) >= 2:
                conf = max(0.0, min(1.0, float(chosen[1])))
            entry = {"token": str(evidence.get("token", "UNK"))[:20], "explanation": str(explanation)[:150], "confidence": conf}
            with self._silver_lock:
                self.silver_buffer.append(entry)
        except Exception:
            pass

    def generate_explanation_for_token(self, token_idx: int, tokens: List[str], dscd_outputs: Dict, token_word_map: Optional[dict] = None, decoder_attention: Optional[torch.Tensor] = None, is_dscd_homograph: bool = False) -> Tuple[str, Dict]:
        """Generate an explanation for a single token.

        Returns:
            ``(explanation_text, evidence)``; ``("", {})`` when TRG is disabled
            for the current mode, the arguments are invalid, the token fails
            validation, no evidence was found for a token that is not a DSCD
            homograph, or an error occurred.
        """
        if not _ENABLE_TRG_INFERENCE and not self.training:
            return "", {}
        if not _ENABLE_TRG_TRAINING and self.training:
            return "", {}
        if not isinstance(tokens, list) or not isinstance(token_idx, int):
            return "", {}
        if token_idx < 0 or token_idx >= len(tokens):
            return "", {}
        raw_token = tokens[token_idx]
        if _has_is_valid_token:
            try:
                is_valid = is_valid_token(raw_token, self.special_tokens, self.tokenizer, language=self.language)
            except Exception:
                is_valid = _fallback_is_valid_token(raw_token, self.special_tokens, self.tokenizer, self.language)
        else:
            is_valid = _fallback_is_valid_token(raw_token, self.special_tokens, self.tokenizer, self.language)
        if not is_valid:
            return "", {}
        try:
            evidence = self.evidence_extractor.extract_evidence_efficiently(token_idx, tokens, dscd_outputs, token_word_map=token_word_map, decoder_attention=decoder_attention)
            if (not evidence.get("evidence_tokens")) and (not is_dscd_homograph):
                if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and self.explanations_generated % 100 == 0:
                    print(f"⚠️  [TRG] Token {token_idx} ({raw_token}): No evidence & not DSCD homograph, skipping")
                return "", {}
            explanation_text = self.template_system.generate_explanation(evidence)
            self._update_stats(evidence, is_dscd_homograph=is_dscd_homograph, is_training=self.training)
            self._add_to_silver_buffer(evidence, explanation_text, tokens)
            return explanation_text, evidence
        except Exception as e:
            if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and self.explanations_generated % 100 == 0:
                print(f"❌ [TRG] generate_explanation_for_token failed: {type(e).__name__}: {e}")
            return "", {}

    @staticmethod
    def _is_missing_or_empty(x: Any) -> bool:
        """Return True when ``x`` is ``None`` or holds no elements.

        Tensors and ndarrays are checked by element count because their
        truth value raises for more than one element. Objects without a
        length fall back to ordinary truthiness.
        """
        if x is None:
            return True
        if isinstance(x, torch.Tensor):
            return x.numel() == 0
        if isinstance(x, np.ndarray):
            return x.size == 0
        try:
            return len(x) == 0
        except TypeError:
            return not x

    @staticmethod
    def _to_list_helper(x: Any) -> List[float]:
        """Flatten ``x`` into a list of floats clamped to ``[0, 1]``.

        Non-finite or unconvertible elements become ``0.0``. ``None`` or a
        top-level failure yields an empty list.
        """
        if x is None:
            return []
        try:
            if isinstance(x, torch.Tensor):
                arr = x.detach().cpu().numpy().flatten()
                result = []
                for v in arr.tolist():
                    fv = float(v)
                    if not np.isfinite(fv):
                        result.append(0.0)
                    else:
                        result.append(max(0.0, min(1.0, fv)))
                return result
            if isinstance(x, (list, tuple, np.ndarray)):
                out = []
                for v in x:
                    try:
                        if isinstance(v, torch.Tensor):
                            # Only the first element of a per-token tensor is used.
                            fval = float(v.flatten()[0].item())
                        else:
                            fval = float(v)
                        if not np.isfinite(fval):
                            out.append(0.0)
                        else:
                            out.append(max(0.0, min(1.0, fval)))
                    except Exception:
                        out.append(0.0)
                return out
            fval = float(x)
            if not np.isfinite(fval):
                return [0.0]
            return [max(0.0, min(1.0, fval))]
        except Exception:
            return []

    def process_sentence_for_explanations(self, tokens: List[str], dscd_outputs: Dict, token_word_map: Optional[dict] = None, uncertainty_threshold: Optional[float] = None, span_threshold: Optional[float] = None, decoder_attention: Optional[torch.Tensor] = None, max_explanations: int = _MAX_EXPLANATIONS_PER_SENTENCE) -> List[Dict]:
        """Pick the most ambiguous tokens of one sentence and explain them.

        Only the first row of ``dscd_outputs["uncertainties"]`` and
        ``dscd_outputs["span_preds"]`` is used; rows may be lists, tuples,
        ndarrays or tensors. Candidates are ranked by priority
        (1 = token is in the DSCD homograph set, 2 = uncertainty and span both
        above threshold, 3 = uncertainty only, 4 = span only), then by
        descending uncertainty, descending span and position.

        Args:
            tokens: Token strings of the sentence.
            dscd_outputs: DSCD output dict.
            token_word_map: Optional index -> word mapping.
            uncertainty_threshold: Overrides ``_TRG_UNCERTAINTY_THRESHOLD``.
            span_threshold: Overrides ``_TRG_SPAN_THRESHOLD``.
            decoder_attention: Forwarded to evidence extraction.
            max_explanations: Upper bound on returned explanations (min 1).

        Returns:
            A list of explanation dicts; empty when TRG is disabled for the
            current mode, inputs are unusable, or an error occurred.
        """
        if not _ENABLE_TRG_INFERENCE and not self.training:
            return []
        if not _ENABLE_TRG_TRAINING and self.training:
            return []
        uncertainty_threshold = float(uncertainty_threshold) if uncertainty_threshold is not None else float(_TRG_UNCERTAINTY_THRESHOLD)
        span_threshold = float(span_threshold) if span_threshold is not None else float(_TRG_SPAN_THRESHOLD)
        max_explanations = max(1, int(max_explanations))
        explanations: List[Dict] = []
        try:
            if not tokens or not isinstance(tokens, list):
                return explanations
            if not isinstance(dscd_outputs, dict) or not dscd_outputs:
                if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and time.time() % 300 < 1.0:
                    print("⚠️  [TRG] dscd_outputs is empty or not dict")
                return explanations
            U_all = dscd_outputs.get("uncertainties", [])
            S_all = dscd_outputs.get("span_preds", [])
            # Plain truthiness raises on multi-element tensors/arrays, which
            # used to drop every explanation for tensor-valued rows.
            if self._is_missing_or_empty(U_all) or self._is_missing_or_empty(U_all[0]):
                if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and time.time() % 300 < 1.0:
                    print("⚠️  [TRG] uncertainties not found in dscd_outputs")
                return explanations
            U = self._to_list_helper(U_all[0])
            if self._is_missing_or_empty(S_all) or self._is_missing_or_empty(S_all[0]):
                S = [0.0] * len(U)
            else:
                S = self._to_list_helper(S_all[0])
            if len(S) < len(U):
                S.extend([0.0] * (len(U) - len(S)))
            if not U:
                return explanations
            dscd_homographs = self.evidence_extractor.get_homograph_tokens_from_dscd()
            # (token index, uncertainty, span, cleaned token, priority, position)
            candidates: List[Tuple[int, float, float, str, int, int]] = []
            for idx in range(min(len(tokens), len(U))):
                tok = tokens[idx]
                clean_tok = str(tok).replace("▁", "").replace("Ġ", "").replace("##", "").replace("@@", "").replace("</w>", "").strip()
                if not _is_word_start(tok, token_word_map, idx):
                    with self._stats_lock:
                        self.tokens_filtered_word_start += 1
                    continue
                if _has_is_valid_token:
                    try:
                        valid = is_valid_token(tok, self.special_tokens, self.tokenizer, language=self.language)
                    except Exception:
                        valid = _fallback_is_valid_token(tok, self.special_tokens, self.tokenizer, self.language)
                else:
                    valid = _fallback_is_valid_token(tok, self.special_tokens, self.tokenizer, self.language)
                if not valid:
                    with self._stats_lock:
                        self.tokens_filtered_validity += 1
                    continue
                u = float(U[idx]) if idx < len(U) else 0.5
                s = float(S[idx]) if idx < len(S) else 0.0
                u = max(0.0, min(1.0, u))
                s = max(0.0, min(1.0, s))
                in_dscd = clean_tok in dscd_homographs
                if in_dscd:
                    priority = 1
                elif (u >= uncertainty_threshold) and (s >= span_threshold):
                    priority = 2
                elif u >= uncertainty_threshold:
                    priority = 3
                elif s >= span_threshold:
                    priority = 4
                else:
                    with self._stats_lock:
                        self.tokens_filtered_ambiguity += 1
                    continue
                candidates.append((idx, u, s, clean_tok, priority, idx))
            if not candidates:
                return explanations
            candidates.sort(key=lambda t: (t[4], -t[1], -t[2], t[5]))
            count = 0
            for (token_idx, u, s, clean_tok, priority, _) in candidates:
                if count >= max_explanations:
                    break
                try:
                    proto_probs = None
                    try:
                        proto_probs = self.evidence_extractor._safe_extract_proto_probs(token_idx, dscd_outputs)
                    except Exception:
                        proto_probs = None
                    # A token with one prototype, no span signal and low
                    # uncertainty is not ambiguous unless DSCD flagged it.
                    single_proto = (isinstance(proto_probs, torch.Tensor) and proto_probs.numel() == 1) or (isinstance(proto_probs, (list, tuple, np.ndarray)) and len(proto_probs) == 1)
                    if single_proto and (s <= 1e-4) and (u <= (0.5 * uncertainty_threshold)) and (not (clean_tok in dscd_homographs)):
                        with self._stats_lock:
                            self.tokens_filtered_ambiguity += 1
                        continue
                    explanation_text, evidence = self.generate_explanation_for_token(token_idx, tokens, dscd_outputs, token_word_map=token_word_map, decoder_attention=decoder_attention, is_dscd_homograph=(priority == 1))
                    if explanation_text and evidence:
                        out_token = token_word_map[token_idx] if token_word_map and token_idx in token_word_map else tokens[token_idx].replace("▁", "").replace("Ġ", "").replace("##", "").replace("@@", "").replace("</w>", "")
                        explanations.append({"token_idx": token_idx, "token": out_token, "explanation": explanation_text, "uncertainty": max(0.0, min(1.0, u)), "span": max(0.0, min(1.0, s)), "dscd_discovered": (priority == 1), "priority": priority})
                        count += 1
                except Exception as e:
                    if (_VERBOSE_LOGGING or _DEBUG_DISCOVERY) and count % 50 == 0:
                        print(f"❌ [TRG] Explanation generation failed for token {token_idx}: {type(e).__name__}")
                    continue
        except Exception as e:
            if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
                print(f"❌ [TRG] process_sentence_for_explanations failed: {type(e).__name__}: {e}")
        return explanations

    def get_statistics(self) -> Dict:
        """Return a snapshot of all counters plus derived rates.

        Rates are divided by ``max(explanations_generated, 1)`` so an empty
        run reports zeros instead of raising.
        """
        with self._stats_lock:
            total = max(self.explanations_generated, 1)
            avg_evidence_tokens = (self.total_evidence_tokens / total) if self.explanations_generated > 0 else 0.0
            return {
                "explanations_generated": self.explanations_generated,
                "explanations_generated_training": self.explanations_generated_training,
                "explanations_generated_inference": self.explanations_generated_inference,
                "high_confidence_explanations": self.high_confidence_explanations,
                "low_confidence_explanations": self.low_confidence_explanations,
                "empty_evidence_count": self.empty_evidence_count,
                "total_evidence_tokens": self.total_evidence_tokens,
                "tokens_filtered_word_start": self.tokens_filtered_word_start,
                "tokens_filtered_validity": self.tokens_filtered_validity,
                "tokens_filtered_ambiguity": self.tokens_filtered_ambiguity,
                "dscd_homographs_explained": self.dscd_homographs_explained,
                "high_confidence_rate": self.high_confidence_explanations / total,
                "low_confidence_rate": self.low_confidence_explanations / total,
                "empty_evidence_rate": self.empty_evidence_count / total,
                "avg_evidence_tokens": avg_evidence_tokens,
                "silver_buffer_size": len(self.silver_buffer),
                "dscd_homograph_rate": self.dscd_homographs_explained / total
            }

    def reset_statistics(self) -> None:
        """Zero every counter (the silver buffer is left untouched)."""
        with self._stats_lock:
            self.explanations_generated = 0
            self.explanations_generated_training = 0
            self.explanations_generated_inference = 0
            self.high_confidence_explanations = 0
            self.low_confidence_explanations = 0
            self.empty_evidence_count = 0
            self.total_evidence_tokens = 0
            self.tokens_filtered_word_start = 0
            self.tokens_filtered_validity = 0
            self.tokens_filtered_ambiguity = 0
            self.dscd_homographs_explained = 0

    def clear_silver_buffer(self) -> None:
        """Drop all buffered silver explanations."""
        with self._silver_lock:
            self.silver_buffer.clear()

    def test_trg(self, tokenizer=None) -> bool:
        """Smoke-test the pipeline on a fixed Bengali sentence.

        Runs in eval mode and restores the previous train/eval mode
        afterwards. Statistics are reset on success. ``tokenizer`` is accepted
        but not used.

        Returns:
            ``True`` when the run completed without raising, else ``False``.
        """
        was_training = self.training
        try:
            tokens = ["▁আমি", "▁কল", "▁বন্ধ", "▁করেছি", "।"]
            dscd_outputs = {"proto_probs": [[torch.tensor([0.6, 0.4]) for _ in tokens]], "uncertainties": [[0.1, 0.5, 0.2, 0.1, 0.0]], "span_preds": [[0.05, 0.3, 0.1, 0.05, 0.0]], "gates": [[0.2, 0.8, 0.3, 0.2, 0.0]]}
            token_word_map = {0: "আমি", 1: "কল", 2: "বন্ধ", 3: "করেছি", 4: "।"}
            self.eval()
            explanations = self.process_sentence_for_explanations(tokens=tokens, dscd_outputs=dscd_outputs, token_word_map=token_word_map, max_explanations=3)
            if _VERBOSE_LOGGING:
                print(f"[TRG-TEST] Generated {len(explanations)} explanations")
            self.reset_statistics()
            return True
        except Exception:
            # Imported here so the self-test does not depend on another cell
            # having imported traceback first.
            import traceback
            traceback.print_exc()
            return False
        finally:
            # A self-test must not leave the module stuck in eval mode.
            self.train(was_training)


# Cell 5 summary banner: echoes the effective TRG configuration and the
# NaN/Inf safeguards. Emitted as a single write; the text is unchanged.
print("\n".join([
    "\n" + "=" * 80,
    "Cell 5: TRG Ready (DATA-DRIVEN) - NaN/Inf HARDENED",
    "=" * 80,
    "Configuration:",
    f"  - Uncertainty threshold: {_TRG_UNCERTAINTY_THRESHOLD:.2f}",
    f"  - Span threshold: {_TRG_SPAN_THRESHOLD:.2f}",
    f"  - Temperature: {_TRG_TEMPERATURE:.2f}",
    f"  - TAU_HIGH: {_TAU_HIGH:.2f}",
    f"  - TAU_LOW: {_TAU_LOW:.2f}",
    f"  - TAU_ACCEPT: {_TAU_ACCEPT:.2f}",
    f"  - Max explanations: {_MAX_EXPLANATIONS_PER_SENTENCE}",
    f"  - Evidence K: {_TRG_EVIDENCE_K}",
    f"  - Training mode: {'ENABLED' if _ENABLE_TRG_TRAINING else 'DISABLED'}",
    f"  - Inference mode: {'ENABLED' if _ENABLE_TRG_INFERENCE else 'DISABLED'}",
    "NaN/Inf Protections:",
    "  ✅ All extracted values clamped to [0, 1]",
    "  ✅ NaN check on proto_probs before normalization",
    "  ✅ Division by zero protection (epsilon=1e-9)",
    "  ✅ Temperature bounds: [0.1, 10.0]",
    "  ✅ Confidence bounds enforced in explanations",
    "=" * 80 + "\n",
]))


In [None]:
# ==============================================================================
# CELL 6: TATN MODEL - NaN/Inf GRADIENT FULLY PROTECTED - STEP 124 FIX
# ==============================================================================

from typing import List, Dict, Optional, Any, Tuple
import traceback
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import M2M100ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
import threading
import gc
import time

def _get_int_global(name: str, default: int) -> int:
    """Read a strictly positive integer setting from this module's globals.

    Returns ``default`` when the global is missing or ``None``, cannot be
    converted with ``int()``, or converts to a value that is zero or negative.
    """
    try:
        raw = globals().get(name)
        parsed = None if raw is None else int(raw)
    except Exception:
        return default
    if parsed is None or parsed <= 0:
        return default
    return parsed

def _get_float_global(name: str, default: float) -> float:
    """Read a float setting from this module's globals.

    Returns ``default`` when the global is missing, ``None``, or cannot be
    converted with ``float()``. No range or finiteness check is applied.
    """
    try:
        raw = globals().get(name)
        return default if raw is None else float(raw)
    except Exception:
        return default

def _get_bool_global(name: str, default: bool) -> bool:
    """Read a boolean setting from this module's globals.

    Returns ``default`` when the global is missing, ``None``, or its
    truthiness cannot be evaluated. String values are parsed
    case-insensitively so that ``"false"``, ``"0"``, ``"no"`` and ``"off"``
    mean ``False``; plain ``bool()`` would treat every non-empty string as
    ``True``. Any other value keeps ordinary Python truthiness.
    """
    try:
        val = globals().get(name)
        if val is None:
            return default
        if isinstance(val, str) and val.strip().lower() in {"", "0", "false", "no", "off"}:
            return False
        return bool(val)
    except Exception:
        return default

# ------------------------------------------------------------------------------
# Cell 6 configuration: resolve every setting from notebook globals, falling
# back to the hard-coded defaults below when a global is absent or unusable.
# ------------------------------------------------------------------------------

# Language pair. If either global is undefined, BOTH fall back to bn -> en.
try:
    _SOURCE_LANGUAGE = str(SOURCE_LANGUAGE)
    _TARGET_LANGUAGE = str(TARGET_LANGUAGE)
except Exception:
    _SOURCE_LANGUAGE = "bn"
    _TARGET_LANGUAGE = "en"

# DSCD settings.
_DSCD_BUFFER_SIZE         = max(1, _get_int_global("DSCD_BUFFER_SIZE", 50))
_DSCD_MAX_PROTOS          = max(1, _get_int_global("DSCD_MAX_PROTOS", 8))
_DSCD_N_MIN               = max(1, _get_int_global("DSCD_N_MIN", 2))
_DSCD_DISPERSION_THRESHOLD= float(_get_float_global("DSCD_DISPERSION_THRESHOLD", 0.70))

_ENABLE_ASBN_TRAINING     = _get_bool_global("ENABLE_ASBN_TRAINING", True)
_ENABLE_TRG_INFERENCE     = _get_bool_global("ENABLE_TRG_INFERENCE", True)
# NOTE(review): _get_int_global returns the default for values <= 0, so a
# configured MEMORY_CLEANUP_FREQUENCY of 0 resolves to 200 and the max(0, ...)
# guard can never produce 0 — confirm whether 0 was meant to be allowed.
_MEMORY_CLEANUP_FREQUENCY = max(0,  _get_int_global("MEMORY_CLEANUP_FREQUENCY", 200))

# GPU count defaults to the number of visible CUDA devices (1 without CUDA).
_NUM_GPUS                 = max(1, _get_int_global("NUM_GPUS",
                                   torch.cuda.device_count() if torch.cuda.is_available() else 1))
_USE_GC                   = _get_bool_global("GRADIENT_CHECKPOINTING", False)
_DSCD_ENABLE_TRAINING_CLUSTERING = _get_bool_global("DSCD_ENABLE_TRAINING_CLUSTERING", True)

# Loss weights (clamped to be non-negative) and logging switches.
_LAMBDA_ASBN              = max(0.0, float(_get_float_global("LAMBDA_ASBN", 0.05)))
_LAMBDA_DSCD              = max(0.0, float(_get_float_global("LAMBDA_DSCD", 0.15)))
_VERBOSE_LOGGING          = _get_bool_global("VERBOSE_LOGGING", False)
_DEBUG_DISCOVERY          = _get_bool_global("DEBUG_DISCOVERY", False)
_DEBUG_TIMING             = _get_bool_global("DEBUG_TIMING", False)

_PERIODIC_DISCOVERY_FREQUENCY = max(1, _get_int_global("PERIODIC_DISCOVERY_FREQUENCY", 150))
_VALIDATION_CHECK_INTERVAL    = max(1, _get_int_global("VALIDATION_CHECK_INTERVAL", 500))

# Thresholds, each clamped to [0, 1]. The TRG uncertainty threshold defaults
# to the general uncertainty threshold resolved on the line above it.
_SPAN_THRESHOLD           = max(0.0, min(1.0, float(_get_float_global("SPAN_THRESHOLD", 0.15))))
_UNCERTAINTY_THRESHOLD    = max(0.0, min(1.0, float(_get_float_global("UNCERTAINTY_THRESHOLD", 0.25))))
_TRG_UNCERTAINTY_THRESHOLD= max(0.0, min(1.0, float(_get_float_global("TRG_UNCERTAINTY_THRESHOLD",
                                                                       _UNCERTAINTY_THRESHOLD))))
_TAU_LOW                  = max(0.0, min(1.0, float(_get_float_global("TAU_LOW", 0.25))))

# NOTE(review): because _get_int_global rejects values <= 0, a configured
# TEST_DOMAIN of 0 resolves to 1 (TRAIN_DOMAIN = 0 only works because 0 is
# also its default) — verify domain ids are meant to behave this way.
_TRAIN_DOMAIN             = _get_int_global("TRAIN_DOMAIN", 0)
_TEST_DOMAIN              = _get_int_global("TEST_DOMAIN", 1)
_USE_DOMAIN_LABELS        = _get_bool_global("USE_DOMAIN_LABELS", True)

# Token ids read from globals, with fixed fallbacks.
# presumably the M2M100 language-tag ids for English and Bengali — TODO confirm against the tokenizer
try:
    _M2M100_EN_TOKEN_ID = int(globals().get("M2M100_EN_TOKEN_ID", 128022))
except Exception:
    _M2M100_EN_TOKEN_ID = 128022
try:
    _M2M100_BN_TOKEN_ID = int(globals().get("M2M100_BN_TOKEN_ID", 128025))
except Exception:
    _M2M100_BN_TOKEN_ID = 128025

_LABEL_SMOOTHING = max(0.0, min(1.0, float(_get_float_global("LABEL_SMOOTHING", 0.1))))
_DECODER_DROPOUT = max(0.0, min(1.0, float(_get_float_global("DECODER_DROPOUT", 0.1))))

# True when a reconstruct_word_spans name is already defined in the notebook globals.
_has_reconstruct_word_spans = "reconstruct_word_spans" in globals()


def _safe_get_last_hidden_state(enc_output):
    """Return the encoder's last hidden state, or ``None`` if it cannot be found.

    Accepts an object exposing ``last_hidden_state`` (returned as-is), or a
    non-empty list/tuple whose first element is a tensor.
    """
    if enc_output is not None:
        if hasattr(enc_output, "last_hidden_state"):
            return enc_output.last_hidden_state
        if isinstance(enc_output, (list, tuple)) and enc_output:
            first = enc_output[0]
            if isinstance(first, torch.Tensor):
                return first
    return None


def build_token_word_map_sentencepiece(input_ids: torch.Tensor, tokenizer) -> List[Dict[int, Optional[str]]]:
    """Map every token position to the surface word it belongs to.

    A token starting with a word marker (``▁`` or ``Ġ``) opens a new word;
    following unmarked tokens are appended to it. Each position of a word maps
    to the word with markers stripped. Empty tokens and ``<s>``, ``</s>``,
    ``<pad>``, ``<unk>`` map to ``None`` and also terminate the current word,
    so padding is never attributed to the preceding word. Positions whose word
    is empty after stripping get no entry.

    Args:
        input_ids: 2-D tensor of token ids, indexed as ``[batch, position]``.
        tokenizer: Object providing ``convert_ids_to_tokens``; if that call
            fails, the stringified ids are used as tokens.

    Returns:
        One ``{position: word or None}`` dict per batch row.
    """
    special_tokens = {"<s>", "</s>", "<pad>", "<unk>"}
    word_markers = ("▁", "Ġ", "\u2581")

    def _commit(word_map: Dict[int, Optional[str]], pieces: List[str], positions: List[int]) -> None:
        # Assign the cleaned word to exactly the positions that formed it.
        if not pieces:
            return
        clean = "".join(pieces)
        for marker in word_markers:
            clean = clean.replace(marker, "")
        clean = clean.strip()
        if clean:
            for pos in positions:
                word_map[pos] = clean

    batch_word_maps: List[Dict[int, Optional[str]]] = []
    for b in range(input_ids.size(0)):
        ids = input_ids[b].tolist()
        try:
            tokens = tokenizer.convert_ids_to_tokens(ids)
        except Exception:
            tokens = [str(x) for x in ids]

        word_map: Dict[int, Optional[str]] = {}
        pieces: List[str] = []
        positions: List[int] = []
        for i, token in enumerate(tokens):
            if not token or token in special_tokens:
                # Close the running word first so the None written here is
                # never overwritten by a later range-based assignment.
                _commit(word_map, pieces, positions)
                pieces, positions = [], []
                word_map[i] = None
            elif token.startswith(word_markers):
                _commit(word_map, pieces, positions)
                pieces, positions = [token], [i]
            else:
                pieces.append(token)
                positions.append(i)
        _commit(word_map, pieces, positions)
        batch_word_maps.append(word_map)
    return batch_word_maps


def _normalize_dscd_outputs(
    raw: Dict[str, Any],
    batch_size: int,
    seq_len: int,
    device: torch.device,
    embed_dim: int
) -> Dict[str, Any]:
    if not isinstance(device, torch.device):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    defaults = {
        "h_augmented": torch.zeros(batch_size, seq_len, embed_dim,
                                   device=device, dtype=torch.float32),
        "proto_probs": [[torch.tensor([1.0], device=device, dtype=torch.float32)
                         for _ in range(seq_len)] for _ in range(batch_size)],
        "uncertainties": [[torch.tensor(0.5, device=device, dtype=torch.float32)
                           for _ in range(seq_len)] for _ in range(batch_size)],
        "gates": [[torch.tensor(0.0, device=device, dtype=torch.float32)
                   for _ in range(seq_len)] for _ in range(batch_size)],
        "span_preds": [[torch.tensor(0.0, device=device, dtype=torch.float32)
                        for _ in range(seq_len)] for _ in range(batch_size)],
        "proto_assignments": [torch.full((seq_len,), -1, dtype=torch.long, device=device)
                              for _ in range(batch_size)],
    }
    if not isinstance(raw, dict):
        return defaults

    out = defaults.copy()

    try:
        h = raw.get("h_augmented", None)
        if isinstance(h, torch.Tensor):
            if not torch.isfinite(h).all():
                if _VERBOSE_LOGGING:
                    print("⚠️  TATN: h_augmented contains NaN/Inf, zeroing out")
                h = torch.zeros_like(h)
            h = torch.clamp(h, min=-100.0, max=100.0)
            if h.shape == (batch_size, seq_len, embed_dim):
                out["h_augmented"] = h.to(device)
            else:
                try:
                    out["h_augmented"] = h.to(device).reshape(batch_size, seq_len, embed_dim)
                except Exception:
                    pass
    except Exception:
        pass

    def _norm_scalar_matrix(key: str, default_scalar: float) -> List[List[torch.Tensor]]:
        mat = raw.get(key, None)
        res = [[torch.tensor(default_scalar, device=device, dtype=torch.float32)
                for _ in range(seq_len)] for _ in range(batch_size)]

        if mat is None:
            return res

        if isinstance(mat, (list, tuple)) and len(mat) == batch_size:
            for b in range(batch_size):
                row = mat[b]
                if isinstance(row, torch.Tensor):
                    r = row.detach().to(device)
                    if r.dim() == 1:
                        for t in range(min(seq_len, r.size(0))):
                            val = r[t].float()
                            if not torch.isfinite(val):
                                val = torch.tensor(default_scalar, device=device)
                            val = torch.clamp(val, min=0.0, max=1.0)
                            res[b][t] = val
                    elif r.dim() == 2:
                        for t in range(min(seq_len, r.size(0))):
                            val = r[t, 0].float()
                            if not torch.isfinite(val):
                                val = torch.tensor(default_scalar, device=device)
                            val = torch.clamp(val, min=0.0, max=1.0)
                            res[b][t] = val
                elif isinstance(row, (list, tuple, np.ndarray)):
                    for t in range(min(seq_len, len(row))):
                        try:
                            v = row[t]
                            if isinstance(v, torch.Tensor):
                                fv = v.to(device).float()
                                if not torch.isfinite(fv):
                                    fv = torch.tensor(default_scalar, device=device)
                                fv = torch.clamp(fv, min=0.0, max=1.0)
                                res[b][t] = fv
                            else:
                                fv = float(v)
                                if not np.isfinite(fv):
                                    fv = default_scalar
                                fv = max(0.0, min(1.0, fv))
                                res[b][t] = torch.tensor(fv, device=device)
                        except Exception:
                            res[b][t] = torch.tensor(default_scalar, device=device)
            return res

        if isinstance(mat, torch.Tensor):
            m = mat.detach().to(device)
            if m.dim() == 3 and m.size(0) >= batch_size:
                for b in range(batch_size):
                    for t in range(min(seq_len, m.size(1))):
                        v = m[b, t]
                        if not torch.isfinite(v).all():
                            v = torch.tensor(default_scalar, device=device)
                        else:
                            v = v if v.numel() == 1 else v.flatten()[0]
                            v = torch.clamp(v, min=0.0, max=1.0)
                        res[b][t] = v
                return res
            if m.dim() == 2:
                if m.size(0) >= batch_size:
                    for b in range(batch_size):
                        for t in range(min(seq_len, m.size(1))):
                            v = m[b, t].float()
                            if not torch.isfinite(v):
                                v = torch.tensor(default_scalar, device=device)
                            v = torch.clamp(v, min=0.0, max=1.0)
                            res[b][t] = v
                    return res
                else:
                    for t in range(min(seq_len, m.size(0))):
                        v = m[t, 0].float()
                        if not torch.isfinite(v):
                            v = torch.tensor(default_scalar, device=device)
                        v = torch.clamp(v, min=0.0, max=1.0)
                        res[0][t] = v
                    return res

        if isinstance(mat, (list, tuple, np.ndarray)) and batch_size == 1:
            for t in range(min(seq_len, len(mat))):
                try:
                    v = mat[t]
                    if isinstance(v, torch.Tensor):
                        fv = v.to(device).float()
                        if not torch.isfinite(fv):
                            fv = torch.tensor(default_scalar, device=device)
                        fv = torch.clamp(fv, min=0.0, max=1.0)
                        res[0][t] = fv
                    else:
                        fv = float(v)
                        if not np.isfinite(fv):
                            fv = default_scalar
                        fv = max(0.0, min(1.0, fv))
                        res[0][t] = torch.tensor(fv, device=device)
                except Exception:
                    res[0][t] = torch.tensor(default_scalar, device=device)
            return res

        return res

    def _norm_proto_probs() -> List[List[torch.Tensor]]:
        mat = raw.get("proto_probs", None)
        res = defaults["proto_probs"]
        if mat is None:
            return res

        def _to_vec(x) -> torch.Tensor:
            if isinstance(x, torch.Tensor):
                v = x.detach().to(device).float().flatten()
            else:
                arr = np.asarray(x, dtype=np.float32).flatten()
                v = torch.from_numpy(arr).to(device).float()
            if v.numel() == 0:
                v = torch.tensor([1.0], device=device)
            if not torch.isfinite(v).all():
                return torch.tensor([1.0], device=device)
            v = torch.clamp(v, min=1e-9, max=1.0)
            s = v.sum()
            if not torch.isfinite(s) or s <= 1e-9:
                return torch.tensor([1.0], device=device)
            v = v / (s + 1e-9)
            v = torch.clamp(v, min=1e-9, max=1.0)
            return v

        if isinstance(mat, list) and len(mat) == batch_size:
            out_pp = []
            for b in range(batch_size):
                row = mat[b]
                row_vecs = []
                if isinstance(row, (list, tuple)):
                    for t in range(min(seq_len, len(row))):
                        row_vecs.append(_to_vec(row[t]))
                    while len(row_vecs) < seq_len:
                        row_vecs.append(torch.tensor([1.0], device=device))
                elif isinstance(row, torch.Tensor):
                    r = row.detach().to(device)
                    if r.dim() == 2:
                        for t in range(min(seq_len, r.size(0))):
                            row_vecs.append(_to_vec(r[t]))
                    elif r.dim() == 1:
                        row_vecs.append(_to_vec(r))
                        while len(row_vecs) < seq_len:
                            row_vecs.append(_to_vec(r))
                if not row_vecs:
                    row_vecs = [torch.tensor([1.0], device=device) for _ in range(seq_len)]
                out_pp.append(row_vecs[:seq_len])
            return out_pp

        if isinstance(mat, torch.Tensor) and mat.dim() == 3:
            B, T, K = mat.shape
            out_pp = []
            for b in range(min(batch_size, B)):
                row_vecs = []
                for t in range(min(seq_len, T)):
                    row_vecs.append(_to_vec(mat[b, t]))
                while len(row_vecs) < seq_len:
                    row_vecs.append(torch.tensor([1.0], device=device))
                out_pp.append(row_vecs[:seq_len])
            while len(out_pp) < batch_size:
                out_pp.append([torch.tensor([1.0], device=device) for _ in range(seq_len)])
            return out_pp

        if isinstance(mat, torch.Tensor) and mat.dim() == 2:
            B, T = mat.shape
            out_pp = []
            for b in range(min(batch_size, B)):
                row_vecs = []
                for t in range(min(seq_len, T)):
                    p = float(mat[b, t].item())
                    if not np.isfinite(p):
                        p = 0.5
                    p = max(0.0, min(1.0, p))
                    row_vecs.append(_to_vec([p, 1.0-p]))
                while len(row_vecs) < seq_len:
                    row_vecs.append(torch.tensor([1.0], device=device))
                out_pp.append(row_vecs[:seq_len])
            while len(out_pp) < batch_size:
                out_pp.append([torch.tensor([1.0], device=device) for _ in range(seq_len)])
            return out_pp

        if isinstance(mat, (list, tuple)) and batch_size == 1:
            row = mat
            row_vecs = []
            for t in range(min(seq_len, len(row))):
                row_vecs.append(_to_vec(row[t]))
            while len(row_vecs) < seq_len:
                row_vecs.append(torch.tensor([1.0], device=device))
            return [row_vecs[:seq_len]]

        return res

    out["proto_probs"]   = _norm_proto_probs()
    out["uncertainties"] = _norm_scalar_matrix("uncertainties", default_scalar=0.5)
    out["gates"]         = _norm_scalar_matrix("gates",         default_scalar=0.0)
    out["span_preds"]    = _norm_scalar_matrix("span_preds",    default_scalar=0.0)

    pa = raw.get("proto_assignments", None)
    try:
        if isinstance(pa, list) and len(pa) == batch_size:
            safe_pa = []
            for brow in pa:
                if isinstance(brow, torch.Tensor):
                    t = brow.to(device).long().view(-1)
                    if t.size(0) < seq_len:
                        t = torch.cat([t, torch.full((seq_len - t.size(0),),
                                                     -1, dtype=torch.long, device=device)], dim=0)
                    elif t.size(0) > seq_len:
                        t = t[:seq_len]
                    safe_pa.append(t)
                elif isinstance(brow, (list, tuple, np.ndarray)):
                    arr = list(brow)[:seq_len]
                    arr += [-1] * max(0, seq_len - len(arr))
                    safe_pa.append(torch.tensor(arr, dtype=torch.long, device=device))
                else:
                    safe_pa.append(torch.full((seq_len,), -1, dtype=torch.long, device=device))
            out["proto_assignments"] = safe_pa
        elif isinstance(pa, torch.Tensor):
            pa = pa.to(device).long()
            if pa.dim() == 2:
                B, T = pa.shape
                safe_pa = []
                for b in range(min(batch_size, B)):
                    row = pa[b]
                    if row.size(0) < seq_len:
                        row = torch.cat(
                            [row, torch.full((seq_len - row.size(0),), -1, dtype=torch.long, device=device)],
                            dim=0
                        )
                    elif row.size(0) > seq_len:
                        row = row[:seq_len]
                    safe_pa.append(row)
                while len(safe_pa) < batch_size:
                    safe_pa.append(torch.full((seq_len,), -1, dtype=torch.long, device=device))
                out["proto_assignments"] = safe_pa
    except Exception:
        pass

    return out


class MemoryOptimizedTATNWithExplanations(nn.Module):
    def __init__(self, tokenizer):
        """Build the model: load M2M100, align vocab, wire DSCD/ASBN/TRG.

        Steps, in order:
          1. Initialise step counters, locks and error counters.
          2. Load ``facebook/m2m100_418M`` in float32 with ``use_cache=False``.
          3. Write label-smoothing and dropout values onto the model config.
          4. Compare the input-embedding count with the tokenizer size and
             resize the embeddings when they differ.
          5. Resolve the target/source language token ids, clamp them into
             the vocabulary range and store them on the config and on self.
          6. Optionally enable gradient checkpointing.
          7. Instantiate the DSCD module (required) and the ASBN and TRG
             modules (replaced by stubs when missing or failing to build).

        Args:
            tokenizer: Tokenizer kept on ``self.tokenizer`` and handed to the
                DSCD / ASBN / TRG constructors.

        Raises:
            RuntimeError: If the pretrained model cannot be loaded, if the
                embedding resize fails, or if ``MemoryEfficientDSCDOnline``
                is not available in module globals.
        """
        super().__init__()
        self.tokenizer = tokenizer
        # Forward-pass counter; forward() increments it under _step_lock.
        self.global_step = 0
        self._step_lock = threading.Lock()
        self._config_lock = threading.Lock()
        self.last_discovery_step = 0
        self.last_validation_step = 0
        self.asbn_forward_errors = 0
        self.dscd_forward_errors = 0

        # The attribute is called `mbart`, but the checkpoint loaded is M2M100.
        try:
            self.mbart = M2M100ForConditionalGeneration.from_pretrained(
                "facebook/m2m100_418M",
                torch_dtype=torch.float32,
                use_cache=False
            )
        except Exception as e:
            raise RuntimeError(f"Failed to load m2m100 model: {e}")

        # Best-effort config overrides; any failure here is ignored.
        # NOTE(review): these fields are written after the model has been
        # constructed; whether already-built layers pick up the new dropout
        # values depends on the transformers implementation — TODO confirm.
        try:
            self.mbart.config.use_cache = False
            self.mbart.config.label_smoothing_factor = _LABEL_SMOOTHING
            self.mbart.config.dropout = _DECODER_DROPOUT
            if hasattr(self.mbart.config, "attention_dropout"):
                self.mbart.config.attention_dropout = _DECODER_DROPOUT
            if hasattr(self.mbart.config, "activation_dropout"):
                self.mbart.config.activation_dropout = _DECODER_DROPOUT
            if hasattr(self.mbart.config, "decoder_dropout"):
                self.mbart.config.decoder_dropout = _DECODER_DROPOUT
        except Exception:
            pass

        # Vocabulary check: the embedding table must match the tokenizer size.
        try:
            emb = self.mbart.get_input_embeddings()
            model_emb_count = getattr(emb, "num_embeddings", None)

            # Prefer tokenizer.vocab_size; fall back to len(tokenizer).
            tok_len = getattr(tokenizer, "vocab_size", None)
            if tok_len is None and hasattr(tokenizer, "__len__"):
                try:
                    tok_len = len(tokenizer)
                except Exception:
                    tok_len = None

            if isinstance(model_emb_count, int) and isinstance(tok_len, int):
                if model_emb_count != tok_len:
                    print(f"⚠️  Vocab size mismatch detected: model={model_emb_count}, tokenizer={tok_len}")
                    print(f"🔧 Auto-resizing model embeddings to match tokenizer...")

                    try:
                        self.mbart.resize_token_embeddings(tok_len)

                        # Re-read the embedding table to verify the resize.
                        new_emb = self.mbart.get_input_embeddings()
                        new_count = getattr(new_emb, "num_embeddings", None)

                        if new_count == tok_len:
                            print(f"✅ Successfully resized embeddings: {model_emb_count} → {new_count}")
                        else:
                            print(f"❌ Resize failed: expected {tok_len}, got {new_count}")
                            raise RuntimeError(f"Embedding resize verification failed: expected {tok_len}, got {new_count}")

                    except Exception as e:
                        # The verification RuntimeError above also lands here
                        # and is re-wrapped.
                        print(f"❌ FATAL: Cannot resize embeddings: {type(e).__name__}: {e}")
                        print(f"   This will cause CUDA device-side assert errors!")
                        raise RuntimeError(f"Embedding resize failed: {e}")
                else:
                    print(f"✅ Vocab sizes match: model={model_emb_count}, tokenizer={tok_len}")
            else:
                if model_emb_count is None:
                    print(f"⚠️  WARNING: Cannot verify vocab sizes (model embedding count is None)")
                if tok_len is None:
                    print(f"⚠️  WARNING: Cannot verify vocab sizes (tokenizer length is None)")

            # 128104 is the fallback used when the tokenizer size is unknown.
            self.vocab_size = tok_len if tok_len is not None else 128104

        except Exception as e:
            print(f"❌ Vocab size check failed: {type(e).__name__}: {e}")
            if _VERBOSE_LOGGING:
                # NOTE(review): relies on a module-level `traceback` import
                # that is not visible in this section — confirm it exists.
                traceback.print_exc()
            raise

        # Language tokens: ask the tokenizer first, then fall back to the
        # module-level constants; ids are clamped into [0, vocab_size - 1].
        try:
            en_token_id = None
            bn_token_id = None
            if hasattr(self.tokenizer, "get_lang_id"):
                try:
                    en_token_id = self.tokenizer.get_lang_id(_TARGET_LANGUAGE)
                except Exception:
                    en_token_id = None
                try:
                    bn_token_id = self.tokenizer.get_lang_id(_SOURCE_LANGUAGE)
                except Exception:
                    bn_token_id = None
            if en_token_id is None:
                en_token_id = _M2M100_EN_TOKEN_ID
            if bn_token_id is None:
                bn_token_id = _M2M100_BN_TOKEN_ID

            en_token_id = int(en_token_id)
            bn_token_id = int(bn_token_id)

            if en_token_id >= self.vocab_size:
                print(f"⚠️  WARNING: EN token {en_token_id} >= vocab {self.vocab_size}, clamping to {self.vocab_size - 1}")
                en_token_id = self.vocab_size - 1
            if bn_token_id >= self.vocab_size:
                print(f"⚠️  WARNING: BN token {bn_token_id} >= vocab {self.vocab_size}, clamping to {self.vocab_size - 1}")
                bn_token_id = self.vocab_size - 1

            # Also guards the lower bound (negative ids become 0).
            en_token_id = max(0, min(en_token_id, self.vocab_size - 1))
            bn_token_id = max(0, min(bn_token_id, self.vocab_size - 1))

            # Both generation-related config fields are pointed at the target
            # language token.
            with self._config_lock:
                if hasattr(self.mbart.config, "forced_bos_token_id"):
                    self.mbart.config.forced_bos_token_id = en_token_id
                if hasattr(self.mbart.config, "decoder_start_token_id"):
                    self.mbart.config.decoder_start_token_id = en_token_id

            self.en_token_id = en_token_id
            self.bn_token_id = bn_token_id

            print(f"✅ Language tokens validated: EN={self.en_token_id}, BN={self.bn_token_id}, vocab_size={self.vocab_size}")

        except Exception as e:
            # On failure, keep going with the constant ids (upper-clamped only).
            print(f"❌ Language token validation failed: {type(e).__name__}: {e}")
            self.en_token_id = min(_M2M100_EN_TOKEN_ID, self.vocab_size - 1)
            self.bn_token_id = min(_M2M100_BN_TOKEN_ID, self.vocab_size - 1)

        # Gradient checkpointing is optional; failures are ignored.
        try:
            if _USE_GC and hasattr(self.mbart, "gradient_checkpointing_enable"):
                self.mbart.gradient_checkpointing_enable()
        except Exception:
            pass

        # Hidden size taken from the config (1024 if absent), at least 1.
        embed_dim = max(1, int(getattr(self.mbart.config, "d_model", 1024)))

        # DSCD is mandatory: the class is looked up by name in module globals.
        dscd_cls = globals().get("MemoryEfficientDSCDOnline", None)
        if callable(dscd_cls):
            self.dscd = dscd_cls(
                embed_dim=embed_dim,
                tokenizer=tokenizer,
                buffer_size=_DSCD_BUFFER_SIZE,
                max_protos=_DSCD_MAX_PROTOS,
                n_min=_DSCD_N_MIN,
                language=_SOURCE_LANGUAGE,
                dispersion_threshold=_DSCD_DISPERSION_THRESHOLD,
                enable_training_clustering=_DSCD_ENABLE_TRAINING_CLUSTERING,
                max_clustering_points=500,
                max_candidates_per_step=1,
            )
        else:
            raise RuntimeError("MemoryEfficientDSCDOnline not found")

        # ASBN is optional: a missing class or a failing constructor yields
        # the stub from _build_stub_asbn().
        asbn_cls = globals().get("MemoryEfficientASBNModule", None)
        if callable(asbn_cls):
            try:
                self.asbn = asbn_cls(embed_dim, tokenizer, language=_SOURCE_LANGUAGE)
            except Exception:
                self.asbn = self._build_stub_asbn()
        else:
            self.asbn = self._build_stub_asbn()

        # TRG is optional in the same way; it receives the DSCD module.
        trg_cls = globals().get("CompleteTRGWithExplanations", None)
        if callable(trg_cls):
            try:
                self.trg_system = trg_cls(embed_dim, tokenizer,
                                          language=_SOURCE_LANGUAGE,
                                          dscd_module=self.dscd)
            except Exception:
                self.trg_system = self._build_stub_trg()
        else:
            self.trg_system = self._build_stub_trg()

        # Start-up banner, printed only when debugging/verbose logging is on.
        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            print("\n" + "=" * 80)
            print("TATN Initialized - MemoryOptimizedTATNWithExplanations - STEP 124 HARDENED")
            print("=" * 80)
            print(f"  - Embed dim: {embed_dim}")
            print(f"  - Vocab size: {self.vocab_size}")
            print(f"  - EN token: {self.en_token_id}, BN token: {self.bn_token_id}")
            print(f"  - Label smoothing: {_LABEL_SMOOTHING}")
            print(f"  - Decoder dropout: {_DECODER_DROPOUT}")
            print(f"  - Discovery frequency: {_PERIODIC_DISCOVERY_FREQUENCY}")
            print(f"  - Validation interval: {_VALIDATION_CHECK_INTERVAL}")
            print(f"  - Lambda ASBN: {_LAMBDA_ASBN}, Lambda DSCD: {_LAMBDA_DSCD}")
            print("  - All losses clamped: translation [0, 100], asbn [0, 10], dscd [0, 5]")
            print("  - All scalar matrices clamped to [0, 1]")
            print("  - Proto probs NaN/Inf protected with element clamp after division")
            print("  - ASBN h_aug NaN revert protection enabled")
            print("=" * 80 + "\n")

    def _build_stub_asbn(self):
        class _StubASBN(nn.Module):
            def forward(self, h, domain_labels=None, global_step=None):
                dev = h.device if isinstance(h, torch.Tensor) else torch.device("cpu")
                return h, torch.tensor(0.0, device=dev), 0.0
            def forward_with_grl_simplified(self, h, *args, **kwargs):
                dev = h.device if isinstance(h, torch.Tensor) else torch.device("cpu")
                return torch.tensor(0.0, device=dev), torch.tensor(0.0, device=dev), \
                       torch.tensor(0.0, device=dev), torch.tensor(0.0, device=dev)
            def critic_parameters(self):
                return []
            def reset_stats(self):
                pass
            def get_detailed_stats(self):
                return {"domain_loss": 0.0, "domain_accuracy": 0.0,
                        "source_accuracy": 0.0, "target_accuracy": 0.0,
                        "asbn_loss": 0.0, "num_updates": 0}
            def get_asbn_stats(self):
                return self.get_detailed_stats()
        return _StubASBN()

    def _build_stub_trg(self):
        class _StubTRG:
            def process_sentence_for_explanations(self, *args, **kwargs):
                return []
            def get_statistics(self):
                return {"explanations_generated": 0}
            def reset_statistics(self):
                pass
        return _StubTRG()

    @staticmethod
    def _entropy_reg_from_proto_probs_static(proto_probs_list, gates_list=None, min_gate: float = 0.0) -> torch.Tensor:
        if not proto_probs_list or not isinstance(proto_probs_list, list):
            return torch.tensor(0.0)

        dev = None
        for row in proto_probs_list:
            if isinstance(row, list):
                for p in row:
                    if isinstance(p, torch.Tensor):
                        dev = p.device
                        break
            if dev is not None:
                break
        if dev is None:
            dev = torch.device("cpu")

        total = torch.tensor(0.0, device=dev)
        count = 0
        for b, row in enumerate(proto_probs_list):
            if not isinstance(row, list):
                continue
            gl = gates_list[b] if (gates_list and b < len(gates_list)) else None
            for j, probs in enumerate(row):
                if not isinstance(probs, torch.Tensor) or probs.numel() == 0:
                    continue
                if gl and j < len(gl):
                    try:
                        if float(gl[j]) < min_gate:
                            continue
                    except Exception:
                        pass
                try:
                    p = torch.clamp(probs.to(dev).float(), min=1e-9, max=1.0)
                    if not torch.isfinite(p).all():
                        continue
                    s = p.sum()
                    if not torch.isfinite(s) or s <= 1e-9:
                        continue
                    p = p / (s + 1e-9)
                    H = -torch.sum(p * torch.log(p + 1e-9))
                    H = torch.clamp(H, min=0.0, max=10.0)
                    if torch.isfinite(H):
                        total = total + H
                        count += 1
                except Exception:
                    continue
        if count == 0:
            return torch.tensor(0.0, device=dev)
        return total / max(1, count)

    def _reconstruct_word_maps_before_dscd(
        self,
        input_ids: torch.Tensor,
        batch_size: int,
        seq_len: int,
        src_texts: Optional[List[str]] = None,
        token_word_map: Optional[List[dict]] = None
    ) -> List[dict]:
        if token_word_map is not None and len(token_word_map) == batch_size:
            if all(isinstance(m, dict) for m in token_word_map):
                return token_word_map
        if not _has_reconstruct_word_spans:
            return build_token_word_map_sentencepiece(input_ids, self.tokenizer)
        word_maps_batch: List[dict] = []
        for b in range(batch_size):
            try:
                if src_texts and b < len(src_texts) and isinstance(src_texts[b], str):
                    text = src_texts[b]
                else:
                    text = self.tokenizer.decode(input_ids[b], skip_special_tokens=True)
                if not text.strip():
                    word_maps_batch.append({})
                    continue
                wm, _ = reconstruct_word_spans(self.tokenizer, text, max_length=seq_len)
                cleaned = {}
                for idx, word in wm.items():
                    if isinstance(idx, int) and 0 <= idx < seq_len and isinstance(word, str):
                        cleaned[idx] = word.replace("▁", "").strip()
                word_maps_batch.append(cleaned)
            except Exception:
                word_maps_batch.append({})
        return word_maps_batch

    def _extract_domain_labels(
        self,
        batch_size: int,
        device: torch.device,
        src_texts: Optional[List[str]] = None
    ) -> Optional[torch.Tensor]:
        """Assign a pseudo domain label to every sample of the batch.

        The batch indices are shuffled with a generator seeded from
        ``self.global_step + 42``; the first ``max(1, batch_size // 2)``
        shuffled indices receive ``_TRAIN_DOMAIN`` and the remainder
        ``_TEST_DOMAIN``. ``src_texts`` is accepted but not consulted.

        A private ``random.Random`` instance is used so the process-wide
        random state (shared with any other user of the ``random`` module)
        is left untouched; the permutation is the same one that seeding the
        module-level generator with that value would produce.

        Args:
            batch_size: Number of samples to label.
            device: Device on which the label tensor is created.
            src_texts: Unused.

        Returns:
            A long tensor of shape ``(batch_size,)``, or ``None`` when domain
            labels are disabled or label construction fails.
        """
        if not _USE_DOMAIN_LABELS:
            if _VERBOSE_LOGGING:
                print("⚠️  Domain labels disabled (USE_DOMAIN_LABELS=False)")
            return None
        try:
            import random
            # Local generator: deterministic per step, no global side effect.
            rng = random.Random(self.global_step + 42)
            ids = list(range(batch_size))
            rng.shuffle(ids)
            n_train = max(1, batch_size // 2)
            train_idx = set(ids[:n_train])
            labels = [_TRAIN_DOMAIN if i in train_idx else _TEST_DOMAIN for i in range(batch_size)]
            result = torch.tensor(labels, dtype=torch.long, device=device)
            if result.size(0) != batch_size:
                print(f"❌ TATN: Domain labels size mismatch: {result.size(0)} != {batch_size}")
                return None
            return result
        except Exception as e:
            print(f"❌ TATN: Domain label extraction failed: {type(e).__name__}: {e}")
            if _VERBOSE_LOGGING:
                traceback.print_exc()
            return None

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        src_texts: Optional[List[str]] = None,
        token_word_map: Optional[List[dict]] = None,
        labels: Optional[torch.Tensor] = None,
        use_dscd: bool = True,
        use_asbn: bool = True
    ):
        with self._step_lock:
            self.global_step += 1
            current_step = self.global_step

        if input_ids is None or attention_mask is None:
            raise ValueError("input_ids and attention_mask required")
        if input_ids.dim() != 2 or attention_mask.dim() != 2:
            raise ValueError("input_ids/attention_mask must be 2D")

        batch_size, seq_len = int(input_ids.size(0)), int(input_ids.size(1))
        device = input_ids.device

        if torch.cuda.is_available() and _MEMORY_CLEANUP_FREQUENCY > 0 and \
           current_step % _MEMORY_CLEANUP_FREQUENCY == 0:
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass
            if gc.isenabled():
                gc.collect()

        if self.training and _DSCD_ENABLE_TRAINING_CLUSTERING and use_dscd:
            if current_step - self.last_discovery_step >= _PERIODIC_DISCOVERY_FREQUENCY:
                try:
                    max_tokens_per = globals().get("_MAX_TOKENS_PER_DISCOVERY", 150)
                    self.dscd.periodic_discovery_check(
                        current_step,
                        _PERIODIC_DISCOVERY_FREQUENCY,
                        max_tokens_per
                    )
                except Exception as e:
                    if _VERBOSE_LOGGING:
                        print(f"❌ TATN: Periodic discovery failed: {e}")
                self.last_discovery_step = current_step

        try:
            enc_outputs = self.mbart.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
        except Exception as e:
            print(f"❌ TATN: Encoder forward failed: {type(e).__name__}: {e}")
            if _VERBOSE_LOGGING:
                traceback.print_exc()
            try:
                enc_outputs = self.mbart.get_encoder()(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
            except Exception as e2:
                print(f"❌ TATN: Fallback encoder also failed: {e2}")
                raise

        h = _safe_get_last_hidden_state(enc_outputs)
        if h is None or not isinstance(h, torch.Tensor) or h.dim() != 3:
            print(f"❌ TATN: Invalid encoder output h: type={type(h)}, shape={h.shape if isinstance(h, torch.Tensor) else 'N/A'}")
            try:
                h = self.mbart.get_input_embeddings()(input_ids).to(device)
            except Exception:
                h = torch.zeros(batch_size, seq_len,
                                int(getattr(self.mbart.config, "d_model", 1024)),
                                device=device)
        
        if not torch.isfinite(h).all():
            print("❌ TATN: Encoder hidden states contain NaN/Inf, zeroing out")
            h = torch.zeros_like(h)
        h = torch.clamp(h, min=-100.0, max=100.0)
        
        embed_dim = int(h.size(-1))

        training_mode = (labels is not None) and self.training

        token_word_map = self._reconstruct_word_maps_before_dscd(
            input_ids, batch_size, seq_len, src_texts, token_word_map
        )
        domain_labels = self._extract_domain_labels(batch_size, device, src_texts)

        raw_dscd = None
        if use_dscd:
            try:
                raw_dscd = self.dscd.forward(
                    h,
                    token_types=None,
                    train_mode=self.training,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_word_map=token_word_map
                )
                if not isinstance(raw_dscd, dict) or "h_augmented" not in raw_dscd:
                    print("⚠️  TATN: DSCD returned invalid output (not dict or missing h_augmented)")
                    raw_dscd = None
            except Exception as e:
                print(f"❌ TATN: DSCD forward failed: {type(e).__name__}: {e}")
                if _VERBOSE_LOGGING:
                    traceback.print_exc()
                self.dscd_forward_errors += 1
                raw_dscd = None

        if raw_dscd is None:
            if _VERBOSE_LOGGING and current_step % 100 == 0:
                print(f"⚠️  TATN: Using DSCD fallback (uncertainty=0.5) at step {current_step}")
            raw_dscd = {
                "h_augmented": h.detach().clone(),
                "proto_probs": [[torch.tensor([1.0], dtype=torch.float32, device=device)
                                 for _ in range(seq_len)] for _ in range(batch_size)],
                "uncertainties": [[torch.tensor(0.5, dtype=torch.float32, device=device)
                                   for _ in range(seq_len)] for _ in range(batch_size)],
                "gates": [[torch.tensor(0.0, dtype=torch.float32, device=device)
                           for _ in range(seq_len)] for _ in range(batch_size)],
                "span_preds": [[torch.tensor(0.0, dtype=torch.float32, device=device)
                                for _ in range(seq_len)] for _ in range(batch_size)],
                "proto_assignments": [torch.full((seq_len,), -1, dtype=torch.long, device=device)
                                      for _ in range(batch_size)],
            }

        dscd = _normalize_dscd_outputs(raw_dscd, batch_size, seq_len, device, embed_dim)

        h_aug = dscd.get("h_augmented", h)
        if not isinstance(h_aug, torch.Tensor) or h_aug.shape != h.shape:
            if _VERBOSE_LOGGING:
                print(f"⚠️  TATN: h_augmented invalid, using original h")
            h_aug = h
        
        if not torch.isfinite(h_aug).all():
            print("❌ TATN: h_augmented contains NaN/Inf after DSCD, using original h")
            h_aug = h
        
        h_aug_pre_asbn = h_aug.detach().clone()

        asbn_bn_loss = torch.tensor(0.0, device=device)
        domain_accuracy = 0.0
        if use_asbn and domain_labels is not None:
            try:
                asbn_result = self.asbn.forward(h_aug, domain_labels=domain_labels, global_step=current_step)
                if isinstance(asbn_result, tuple) and len(asbn_result) == 3:
                    h_aug, asbn_bn_loss, domain_accuracy = asbn_result
                elif isinstance(asbn_result, tuple) and len(asbn_result) == 2:
                    h_aug, asbn_bn_loss = asbn_result
                    domain_accuracy = 0.0
                    if self.asbn_forward_errors < 10:
                        print(f"⚠️  TATN: ASBN returned 2 values (expected 3), domain_accuracy set to 0.0")
                        self.asbn_forward_errors += 1
                else:
                    print(f"❌ TATN: ASBN returned unexpected type: {type(asbn_result)}")
                    asbn_bn_loss = torch.tensor(0.0, device=device)
                    domain_accuracy = 0.0
                    self.asbn_forward_errors += 1
                if not torch.isfinite(asbn_bn_loss):
                    print(f"❌ TATN: ASBN BN loss is NaN/Inf: {asbn_bn_loss}")
                    asbn_bn_loss = torch.tensor(0.0, device=device)
                else:
                    asbn_bn_loss = torch.clamp(asbn_bn_loss, 0.0, 10.0)
                
                if not torch.isfinite(h_aug).all():
                    print("❌ TATN: ASBN produced NaN/Inf in h_aug, reverting to pre-ASBN state")
                    h_aug = h_aug_pre_asbn
                
            except Exception as e:
                print(f"❌ TATN: ASBN forward failed: {type(e).__name__}: {e}")
                if _VERBOSE_LOGGING:
                    traceback.print_exc()
                h_aug = h_aug_pre_asbn
                asbn_bn_loss = torch.tensor(0.0, device=device)
                domain_accuracy = 0.0
                self.asbn_forward_errors += 1
        elif use_asbn and domain_labels is None:
            if current_step % 100 == 0 and _VERBOSE_LOGGING:
                print(f"⚠️  TATN: ASBN skipped (domain_labels=None) at step {current_step}")

        try:
            enc_for_decoder = BaseModelOutput(
                last_hidden_state=h_aug,
                hidden_states=getattr(enc_outputs, "hidden_states", None),
                attentions=getattr(enc_outputs, "attentions", None),
            )
        except Exception:
            enc_for_decoder = (h_aug,)

        if training_mode:
            try:
                pad_id = getattr(self.tokenizer, "pad_token_id", 1)
            except Exception:
                pad_id = 1

            try:
                bos = int(getattr(self.mbart.config, "decoder_start_token_id", self.en_token_id))
                
                if bos < 0 or bos >= self.vocab_size:
                    print(f"⚠️  WARNING: Invalid BOS token {bos}, clamping to vocab range [0, {self.vocab_size-1}]")
                    bos = max(0, min(bos, self.vocab_size - 1))
                
                bos = max(0, min(bos, self.vocab_size - 1))
                
                bos_col = torch.full((batch_size, 1), bos, dtype=torch.long, device=device)
                
                pad_mask = (labels == pad_id)
                labels_clamped = torch.clamp(labels, min=0, max=self.vocab_size - 1)
                labels_clamped[pad_mask] = -100
                
                labels_shifted = labels_clamped[:, :-1]
                
                decoder_input_ids = torch.cat([bos_col, labels_shifted], dim=1)
                
                decoder_input_ids = torch.clamp(decoder_input_ids, min=0, max=self.vocab_size - 1)
                
                decoder_attention_mask = (decoder_input_ids != pad_id).long()
                
            except Exception as e:
                print(f"❌ TATN: Decoder input construction failed: {type(e).__name__}: {e}")
                if _VERBOSE_LOGGING:
                    traceback.print_exc()
                decoder_input_ids = None
                decoder_attention_mask = None

            try:
                seq_outputs = self.mbart(
                    input_ids=None,
                    attention_mask=attention_mask,
                    encoder_outputs=enc_for_decoder,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    labels=labels_clamped,
                    use_cache=False,
                    return_dict=True,
                )
                translation_loss = seq_outputs.loss
                if translation_loss is None or not torch.isfinite(translation_loss):
                    print(f"❌ TATN: Translation loss is None or NaN/Inf: {translation_loss}")
                    translation_loss = torch.tensor(10.0, device=device)
                else:
                    translation_loss = torch.clamp(translation_loss, 0.0, 100.0)
            except Exception as e:
                print(f"❌ TATN: MBART forward failed: {type(e).__name__}: {e}")
                if _VERBOSE_LOGGING:
                    traceback.print_exc()
                translation_loss = torch.tensor(10.0, device=device)

            asbn_loss = torch.tensor(0.0, device=device)
            if use_asbn and _ENABLE_ASBN_TRAINING:
                try:
                    asbn_grl_result = self.asbn.forward_with_grl_simplified(
                        h_aug,
                        dscd.get("proto_probs", None),
                        dscd.get("uncertainties", None),
                        dscd.get("gates", None),
                        token_word_map=token_word_map,
                        domain_labels=domain_labels,
                        global_step=current_step,
                    )
                    if isinstance(asbn_grl_result, tuple) and len(asbn_grl_result) >= 1:
                        asbn_loss = asbn_grl_result[0]
                    else:
                        print(f"❌ TATN: ASBN GRL returned unexpected type: {type(asbn_grl_result)}")
                        asbn_loss = torch.tensor(0.0, device=device)
                    if not isinstance(asbn_loss, torch.Tensor):
                        print(f"❌ TATN: ASBN GRL returned non-tensor: {type(asbn_loss)}")
                        asbn_loss = torch.tensor(float(asbn_loss), device=device)
                    if not torch.isfinite(asbn_loss):
                        print(f"❌ TATN: ASBN GRL loss is NaN/Inf: {asbn_loss}")
                        asbn_loss = torch.tensor(0.0, device=device)
                    else:
                        asbn_loss = torch.clamp(asbn_loss, 0.0, 10.0)
                except Exception as e:
                    print(f"❌ TATN: ASBN GRL forward failed: {type(e).__name__}: {e}")
                    if _VERBOSE_LOGGING:
                        traceback.print_exc()
                    asbn_loss = torch.tensor(0.0, device=device)
            elif use_asbn and not _ENABLE_ASBN_TRAINING:
                if current_step % 100 == 0 and _VERBOSE_LOGGING:
                    print(f"⚠️  TATN: ASBN GRL disabled (ENABLE_ASBN_TRAINING=False)")

            dscd_reg = torch.tensor(0.0, device=device)
            try:
                dscd_reg = self._entropy_reg_from_proto_probs_static(
                    dscd.get("proto_probs", []),
                    gates_list=dscd.get("gates", []),
                    min_gate=0.0,
                )
                if not isinstance(dscd_reg, torch.Tensor):
                    dscd_reg = torch.tensor(float(dscd_reg), device=device)
                if not torch.isfinite(dscd_reg):
                    print(f"❌ TATN: DSCD reg is NaN/Inf: {dscd_reg}")
                    dscd_reg = torch.tensor(0.0, device=device)
                else:
                    dscd_reg = torch.clamp(dscd_reg.to(device), 0.0, 5.0)
            except Exception as e:
                print(f"❌ TATN: DSCD entropy regularizer failed: {type(e).__name__}: {e}")
                if _VERBOSE_LOGGING:
                    traceback.print_exc()
                dscd_reg = torch.tensor(0.0, device=device)

            if not torch.isfinite(translation_loss):
                print(f"❌ TATN: Translation loss NaN/Inf: {translation_loss}, resetting to 10.0")
                translation_loss = torch.tensor(10.0, device=device)
            if not torch.isfinite(asbn_loss):
                print(f"❌ TATN: ASBN loss NaN/Inf: {asbn_loss}, resetting to 0.0")
                asbn_loss = torch.tensor(0.0, device=device)
            if not torch.isfinite(asbn_bn_loss):
                print(f"❌ TATN: ASBN BN loss NaN/Inf: {asbn_bn_loss}, resetting to 0.0")
                asbn_bn_loss = torch.tensor(0.0, device=device)
            if not torch.isfinite(dscd_reg):
                print(f"❌ TATN: DSCD reg NaN/Inf: {dscd_reg}, resetting to 0.0")
                dscd_reg = torch.tensor(0.0, device=device)

            translation_loss = torch.clamp(translation_loss, 0.0, 100.0)
            asbn_loss = torch.clamp(asbn_loss, 0.0, 10.0)
            asbn_bn_loss = torch.clamp(asbn_bn_loss, 0.0, 10.0)
            dscd_reg = torch.clamp(dscd_reg, 0.0, 5.0)

            total_asbn_loss = asbn_loss + asbn_bn_loss
            total_loss = translation_loss + \
                         _LAMBDA_ASBN * total_asbn_loss + \
                         _LAMBDA_DSCD * dscd_reg
            if not isinstance(total_loss, torch.Tensor):
                total_loss = torch.tensor(float(total_loss), device=device)
            if total_loss.numel() != 1:
                total_loss = total_loss.mean()
            if not torch.isfinite(total_loss):
                print(f"❌ TATN: Total loss is NaN/Inf: {total_loss}, using translation_loss only")
                total_loss = translation_loss

            try:
                del enc_outputs, h, raw_dscd, h_aug_pre_asbn
            except Exception:
                pass
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            return {
                "loss": total_loss,
                "translation_loss": translation_loss,
                "asbn_loss": total_asbn_loss,
                "dscd_loss": dscd_reg,
                "model_output": {"loss": translation_loss},
                "domain_accuracy": torch.tensor(domain_accuracy, device=device, dtype=torch.float32),
            }

        explanations_list: List[List[Dict[str, Any]]] = [[] for _ in range(batch_size)]

        if (not self.training) and _ENABLE_TRG_INFERENCE:
            token_word_map_len = len(token_word_map) if token_word_map else 0
            if token_word_map_len != batch_size:
                if _VERBOSE_LOGGING:
                    print(f"⚠️  TATN: token_word_map length mismatch: {token_word_map_len} != {batch_size}")
            
            for b in range(batch_size):
                try:
                    tokens = self.tokenizer.convert_ids_to_tokens(input_ids[b].tolist())
                except Exception:
                    tokens = [str(x) for x in input_ids[b].tolist()]

                dscd_for_trg = {
                    "proto_probs": [dscd["proto_probs"][b]],
                    "uncertainties": [dscd["uncertainties"][b]],
                    "gates": [dscd["gates"][b]],
                    "span_preds": [dscd["span_preds"][b]],
                }
                tok_map_b = token_word_map[b] if token_word_map and b < token_word_map_len else None

                try:
                    exps = self.trg_system.process_sentence_for_explanations(
                        tokens=tokens,
                        dscd_outputs=dscd_for_trg,
                        token_word_map=tok_map_b,
                        uncertainty_threshold=_TRG_UNCERTAINTY_THRESHOLD,
                        span_threshold=_SPAN_THRESHOLD,
                        decoder_attention=None,
                        max_explanations=globals().get("MAX_EXPLANATIONS_PER_SENTENCE", 10),
                    )
                    if isinstance(exps, list):
                        explanations_list[b] = exps
                except Exception as e:
                    if _VERBOSE_LOGGING:
                        print(f"❌ TATN: TRG explanation generation failed for batch {b}: {e}")
                    explanations_list[b] = []

        outputs = {
            "encoder_outputs": enc_outputs,
            "dscd_outputs": dscd,
            "sense_augmented_embeddings": h_aug,
            "explanations": explanations_list,
            "asbn_loss": asbn_bn_loss,
            "domain_accuracy": domain_accuracy,
            "ambiguity_signals": {
                "span": dscd.get("span_preds", []),
                "uncertainty": dscd.get("uncertainties", []),
                "confidence": [
                    [
                        max(0.0, min(1.0, 1.0 - (float(u.item()) if isinstance(u, torch.Tensor) else float(u))))
                        for u in row
                    ]
                    for row in dscd.get("uncertainties", [])
                ],
                "proto_probs": dscd.get("proto_probs", []),
            },
        }

        try:
            del h, raw_dscd
        except Exception:
            pass
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return outputs

    def forward_with_explanations(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        src_texts: Optional[List[str]] = None,
        token_word_map: Optional[List[dict]] = None,
        **kwargs
    ):
        """Call ``forward`` with ``labels`` pinned to ``None``.

        Args:
            input_ids: Source token ids, passed through unchanged.
            attention_mask: Attention mask, passed through unchanged.
            src_texts: Optional raw source sentences, passed through unchanged.
            token_word_map: Optional per-sentence token-to-word maps, passed
                through unchanged.
            **kwargs: Any further keyword arguments accepted by ``forward``.
                Supplying ``labels`` here raises ``TypeError`` because it is
                already fixed to ``None``.

        Returns:
            Whatever ``self.forward`` returns for these arguments.
        """
        fixed_args = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "src_texts": src_texts,
            "token_word_map": token_word_map,
            # Labels are deliberately withheld from forward().
            "labels": None,
        }
        return self.forward(**fixed_args, **kwargs)

    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        max_length: int = 128,
        num_beams: int = 5,
        early_stopping: bool = True,
        **kwargs
    ) -> torch.Tensor:
        """Generate target-token ids with beam search over the base encoder states.

        The encoder of ``self.mbart`` is run directly and its last hidden state
        is handed to ``self.mbart.generate`` as ``encoder_outputs``.

        Args:
            input_ids: Source token ids; the first dimension is the batch size.
            attention_mask: Mask passed to both the encoder and ``generate``.
            max_length: Requested output length. Decoding itself is capped at
                100 tokens; the result is additionally truncated to this value.
            num_beams: Requested beam width, capped at 4.
            early_stopping: Forwarded to ``self.mbart.generate``.
            **kwargs: Extra generation arguments; they override the defaults
                assembled here.

        Returns:
            The generated id tensor. If anything in the generation path raises,
            the error is printed and a ``(batch_size, 2)`` tensor holding one
            BOS/EOS pair per row is returned instead.
        """
        device = input_ids.device
        batch_size = input_ids.size(0)

        def _bounded_token_id(value, default):
            # Config/tokenizer attributes can exist but hold None, in which case
            # int(None) would raise; substitute the default before converting,
            # then clamp into the valid vocabulary range.
            if value is None:
                value = default
            return max(0, min(int(value), self.vocab_size - 1))

        try:
            enc_outputs = self.mbart.model.encoder(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            enc_wrapped = BaseModelOutput(
                last_hidden_state=_safe_get_last_hidden_state(enc_outputs),
                hidden_states=getattr(enc_outputs, "hidden_states", None),
                attentions=getattr(enc_outputs, "attentions", None),
            )

            forced_bos_id = _bounded_token_id(
                getattr(self.mbart.config, "forced_bos_token_id", None), self.en_token_id
            )
            eos_token_id = _bounded_token_id(getattr(self.tokenizer, "eos_token_id", None), 2)
            pad_token_id = _bounded_token_id(getattr(self.tokenizer, "pad_token_id", None), 1)

            gen_kwargs = dict(
                input_ids=None,
                attention_mask=attention_mask,
                encoder_outputs=enc_wrapped,
                max_length=min(max_length, 100),
                min_length=1,
                num_beams=min(num_beams, 4),
                # Honour the caller's choice instead of a hard-coded True.
                early_stopping=early_stopping,
                no_repeat_ngram_size=3,
                repetition_penalty=3.0,
                length_penalty=1.2,
                do_sample=False,
                forced_bos_token_id=forced_bos_id,
                eos_token_id=eos_token_id,
                pad_token_id=pad_token_id,
                num_return_sequences=1,
                output_scores=False,
                return_dict_in_generate=False,
            )
            # Caller-supplied generation arguments win over the defaults above.
            gen_kwargs.update(kwargs)
            outputs = self.mbart.generate(**gen_kwargs)
            if outputs.size(1) > max_length:
                outputs = outputs[:, :max_length]
            if outputs.size(1) < 2:
                print(f"⚠️  TATN: Generate returned short output: {outputs.shape}")
            return outputs
        except Exception as e:
            print(f"❌ TATN: Generate failed: {type(e).__name__}: {e}")
            if _VERBOSE_LOGGING:
                traceback.print_exc()
            # The fallback must not raise itself, so None ids are replaced here too.
            fallback_bos = _bounded_token_id(
                getattr(self.tokenizer, "bos_token_id", None), self.en_token_id
            )
            fallback_eos = _bounded_token_id(getattr(self.tokenizer, "eos_token_id", None), 2)
            return torch.tensor([[fallback_bos, fallback_eos]], dtype=torch.long, device=device).expand(batch_size, 2)

    def get_component_stats(self) -> Dict[str, Any]:
        """Collect step counters, error counters and per-component statistics.

        Returns:
            A dict with the step/error counters read from ``self`` plus three
            nested entries: ``"dscd"`` (store, prototype and homograph counts),
            ``"asbn"`` and ``"trg"`` (whatever the component's own statistics
            method returns). A component whose statistics cannot be read
            contributes zeros (DSCD) or an empty dict (ASBN, TRG).
        """
        counter_names = (
            "global_step",
            "last_discovery_step",
            "last_validation_step",
            "asbn_forward_errors",
            "dscd_forward_errors",
        )
        stats: Dict[str, Any] = {name: getattr(self, name) for name in counter_names}

        try:
            dscd_stats = {
                "total_tokens": len(self.dscd.prototype_stores),
                "total_prototypes": sum(
                    store.size() for store in self.dscd.prototype_stores.values()
                ),
                "num_homographs": len(self.dscd.discovered_homographs),
            }
        except Exception:
            dscd_stats = {"total_tokens": 0, "total_prototypes": 0, "num_homographs": 0}
        stats["dscd"] = dscd_stats

        # ASBN and TRG share one pattern: call the stats method when it exists.
        for stats_key, component_attr, method_name in (
            ("asbn", "asbn", "get_detailed_stats"),
            ("trg", "trg_system", "get_statistics"),
        ):
            try:
                component = getattr(self, component_attr)
                if hasattr(component, method_name):
                    stats[stats_key] = getattr(component, method_name)()
                else:
                    stats[stats_key] = {}
            except Exception:
                stats[stats_key] = {}
        return stats


# Cell 6 start-up banner, emitted with a single print call (one argument per output line).
print(
    "\n" + "=" * 80,
    "Cell 6: TATN Ready - STEP 124 NaN/Inf FIX APPLIED",
    "=" * 80,
    "Config:",
    f"  - Source: {_SOURCE_LANGUAGE}, Target: {_TARGET_LANGUAGE}",
    f"  - Label smoothing: {_LABEL_SMOOTHING}",
    f"  - Decoder dropout: {_DECODER_DROPOUT}",
    f"  - DSCD clustering enabled: {_DSCD_ENABLE_TRAINING_CLUSTERING}",
    f"  - ASBN training enabled: {_ENABLE_ASBN_TRAINING}",
    f"  - TRG inference enabled: {_ENABLE_TRG_INFERENCE}",
    f"  - Discovery freq: {_PERIODIC_DISCOVERY_FREQUENCY}",
    f"  - λ_ASBN: {_LAMBDA_ASBN}, λ_DSCD: {_LAMBDA_DSCD}",
    "NaN/Inf Protections Applied:",
    "  ✅ Encoder h NaN check + clamp [-100, 100]",
    "  ✅ h_augmented NaN zeroing after DSCD",
    "  ✅ h_aug_pre_asbn saved before ASBN forward",
    "  ✅ ASBN h_aug NaN → REVERT to h_aug_pre_asbn (CRITICAL FIX)",
    "  ✅ Proto probs element clamp AFTER division normalization",
    "  ✅ All scalar matrices clamped [0, 1] with NaN check",
    "  ✅ All losses clamped: translation [0, 100], asbn [0, 10], dscd [0, 5]",
    "  ✅ Decoder input IDs clamped to [0, vocab_size-1]",
    "  ✅ BOS/EOS/PAD tokens validated and bounded",
    "=" * 80 + "\n",
    sep="\n",
)


In [None]:
# ==============================================================================
# CELL 7: TRAINING LOOP - NaN/Inf GRADIENT FULLY HARDENED + LABEL VALIDATION
# ==============================================================================
import os
import time
import math
import gc
import traceback
from datetime import datetime
from pathlib import Path
from collections import defaultdict, deque
from typing import Optional, Dict, Any, List

import numpy as np
import torch
from torch.cuda.amp import GradScaler, autocast as cuda_amp_autocast
from tqdm import tqdm
from contextlib import nullcontext
import threading

# Logging switches: start quiet, then adopt the notebook-level setting when it
# exists and can be coerced to bool. A missing or unusable setting keeps False.
_VERBOSE_LOGGING = False
try:
    _VERBOSE_LOGGING = bool(VERBOSE_LOGGING)
except Exception:
    pass

_DEBUG_DISCOVERY = False
try:
    _DEBUG_DISCOVERY = bool(DEBUG_DISCOVERY)
except Exception:
    pass

DEBUG_PRINT_INTERVAL = 200
# Number of times each cell7_dbg key has been seen.
_cell7_dbg_counts = defaultdict(int)

# Cell 7 bookkeeping counters, all starting from zero.
_CELL7_WORDMAP_BUILT_COUNT = _CELL7_WORDMAP_PROVIDED_COUNT = 0
_CELL7_FORWARD_CALL_COUNT = _CELL7_BACKWARD_SUCCESS_COUNT = 0


def cell7_dbg(key: str, msg: str, limit: int = 10):
    """Print a ``[CELL7-DBG]`` line, at most ``limit`` times per ``key``.

    Does nothing (and does not count the call) unless verbose logging or
    discovery debugging is switched on.

    Args:
        key: Identifier used to count how often this message was requested.
        msg: Text printed after the ``[CELL7-DBG]`` tag.
        limit: Maximum number of prints for ``key``.
    """
    if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
        times_seen = _cell7_dbg_counts[key] + 1
        _cell7_dbg_counts[key] = times_seen
        if times_seen <= limit:
            print(f"[CELL7-DBG] {msg}")


# Training device: use DEVICE as-is when it is already a torch.device, otherwise
# build one from its string form; auto-detect when DEVICE is missing or unusable.
try:
    _DEVICE = DEVICE if isinstance(DEVICE, torch.device) else torch.device(str(DEVICE))
except Exception:
    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Scalar settings: each starts from its built-in default and is overridden by the
# notebook-level value when that value exists and converts cleanly.
_EPOCHS = 1
try:
    _EPOCHS = int(EPOCHS)
except Exception:
    pass

_BATCH_SIZE = 8
try:
    _BATCH_SIZE = int(BATCH_SIZE)
except Exception:
    pass

_ACCUMULATION_STEPS = 1
try:
    _ACCUMULATION_STEPS = int(ACCUMULATION_STEPS)
except Exception:
    pass

_GRAD_CLIP_NORM = 1.0
try:
    _GRAD_CLIP_NORM = float(GRAD_CLIP_NORM)
except Exception:
    pass

_MEMORY_CLEANUP_FREQUENCY = 500
try:
    _MEMORY_CLEANUP_FREQUENCY = int(MEMORY_CLEANUP_FREQUENCY)
except Exception:
    pass

# Multi-GPU settings are taken as a pair: if either is unavailable, both are
# derived from the visible CUDA device count.
try:
    _USE_MULTI_GPU, _NUM_GPUS = bool(USE_MULTI_GPU), int(NUM_GPUS)
except Exception:
    _NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
    _USE_MULTI_GPU = _NUM_GPUS > 1

_USE_AMP = True
try:
    _USE_AMP = bool(USE_AMP)
except Exception:
    pass

# Language codes are also resolved as a pair.
try:
    _SOURCE_LANGUAGE, _TARGET_LANGUAGE = str(SOURCE_LANGUAGE), str(TARGET_LANGUAGE)
except Exception:
    _SOURCE_LANGUAGE, _TARGET_LANGUAGE = "bn", "en"

_MAX_LENGTH = 48
try:
    _MAX_LENGTH = int(MAX_LENGTH)
except Exception:
    pass

_VALIDATION_CHECK_INTERVAL = 500
try:
    _VALIDATION_CHECK_INTERVAL = int(VALIDATION_CHECK_INTERVAL)
except Exception:
    pass

_PERIODIC_DISCOVERY_FREQUENCY = 150
try:
    _PERIODIC_DISCOVERY_FREQUENCY = int(PERIODIC_DISCOVERY_FREQUENCY)
except Exception:
    pass

# Domain ids are resolved as a pair.
try:
    _TRAIN_DOMAIN, _TEST_DOMAIN = int(TRAIN_DOMAIN), int(TEST_DOMAIN)
except Exception:
    _TRAIN_DOMAIN, _TEST_DOMAIN = 0, 1

# Lower-cased reference homographs; a built-in Bengali list is used when the
# notebook-level list is missing or cannot be iterated.
try:
    _HOMOGRAPH_REFERENCE_LIST = {str(w).lower() for w in HOMOGRAPH_REFERENCE_LIST_BN}
except Exception:
    _HOMOGRAPH_REFERENCE_LIST = {
        str(w).lower()
        for w in (
            "কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার",
            "তারা", "পানি", "দল", "বাজার", "নাম", "কথা", "বই", "ঘর", "মন", "হাত",
        )
    }


def clear_all_gpu_caches():
    """Run the garbage collector, then empty the CUDA cache on every visible GPU.

    Best-effort: a failure on one device, or while enumerating devices, is
    ignored. Without CUDA only the garbage collection happens.
    """
    gc.collect()
    if torch.cuda.is_available():
        try:
            for gpu_index in range(torch.cuda.device_count()):
                # empty_cache() is issued with each device selected in turn.
                with torch.cuda.device(gpu_index):
                    try:
                        torch.cuda.empty_cache()
                    except Exception:
                        pass
        except Exception:
            pass


def get_amp_ctx():
    """Return the context manager to wrap a forward pass in.

    Returns:
        A CUDA autocast context when AMP is enabled and CUDA is available;
        otherwise (or if creating the autocast context fails) a no-op
        ``nullcontext``.
    """
    if _USE_AMP and torch.cuda.is_available():
        try:
            return cuda_amp_autocast(enabled=True)
        except Exception:
            pass
    return nullcontext()


# Keep the flag's current value if this name already exists in the module
# namespace (for example when the cell is executed again); otherwise start at False.
_PROTOBUF_COMPAT_ERROR_SHOWN = globals().get("_PROTOBUF_COMPAT_ERROR_SHOWN", False)


def _build_token_to_word_map(tokenizer, input_ids):
    """Map every token position in a batch to the surface word it belongs to.

    A word starts at a token carrying a word-start marker (``▁`` or ``Ġ``) and
    extends over the following marker-less tokens. Tokens that are empty or one
    of ``<s>``, ``</s>``, ``<pad>``, ``<unk>`` are mapped to ``None`` and never
    take over a neighbouring word's text, so trailing padding stays ``None``.

    Args:
        tokenizer: Object providing ``convert_ids_to_tokens(list_of_ids)``. If
            that call raises, the ids' string forms are used as tokens.
        input_ids: 2-D tensor of token ids, one row per sentence.

    Returns:
        A list with one dict per row, mapping token position to the cleaned
        word (markers removed, stripped) or ``None`` for special tokens.
        Positions whose word is empty after cleaning have no entry.
    """
    special_tokens = {"<s>", "</s>", "<pad>", "<unk>"}
    word_start_markers = ("▁", "\u2581", "Ġ")

    def _assign_word(word_map, pieces, positions):
        # Write the cleaned word only to the positions that actually contributed
        # pieces; special-token positions in between keep their None.
        clean_word = "".join(pieces)
        for marker in word_start_markers:
            clean_word = clean_word.replace(marker, "")
        clean_word = clean_word.strip()
        if clean_word:
            for pos in positions:
                word_map[pos] = clean_word

    batch_word_maps = []
    for batch_idx in range(input_ids.size(0)):
        ids = input_ids[batch_idx].tolist()
        try:
            tokens = tokenizer.convert_ids_to_tokens(ids)
        except Exception:
            tokens = [str(x) for x in ids]

        word_map = {}
        pieces = []
        positions = []
        for i, token in enumerate(tokens):
            if not token or token in special_tokens:
                word_map[i] = None
                continue

            if token.startswith(word_start_markers):
                # A marker opens a new word: close the one collected so far.
                _assign_word(word_map, pieces, positions)
                pieces = []
                positions = []
            pieces.append(token)
            positions.append(i)

        _assign_word(word_map, pieces, positions)
        batch_word_maps.append(word_map)
    return batch_word_maps


def _get_dscd_homographs(model: torch.nn.Module) -> set:
    """Return the words whose DSCD prototype store holds at least two prototypes.

    The DSCD component is looked up on ``model`` (or on ``model.module`` when
    present). Store keys are normalised by removing sub-word markers,
    stripping and lower-casing; when several keys normalise to the same word
    the largest prototype count is kept.

    Args:
        model: Model, possibly wrapped, exposing a ``dscd`` attribute.

    Returns:
        A set of normalised words with two or more prototypes; an empty set
        when DSCD is absent or anything goes wrong.
    """
    subword_markers = ("▁", "Ġ", "##", "@@", "</w>")
    try:
        core = getattr(model, "module", model)
        dscd = getattr(core, "dscd", None)
        if dscd is None:
            return set()

        # Snapshot the stores under whichever lock the DSCD object exposes.
        lock = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
        with (lock if lock else nullcontext()):
            stores = dict(getattr(dscd, "prototype_stores", {}) or {})

        best_count = defaultdict(int)
        for token_key, store in stores.items():
            try:
                size_fn = getattr(store, "size", None)
                if callable(size_fn):
                    try:
                        proto_count = int(size_fn())
                    except Exception:
                        proto_count = 0
                else:
                    centroids = getattr(store, "centroids", None)
                    try:
                        proto_count = 0 if centroids is None else len(centroids)
                    except Exception:
                        proto_count = 0

                word = str(token_key)
                for marker in subword_markers:
                    word = word.replace(marker, "")
                word = word.strip().lower()
                if word:
                    best_count[word] = max(best_count[word], proto_count)
            except Exception:
                continue

        return {word for word, count in best_count.items() if count >= 2}
    except Exception:
        return set()


def _print_gpu_mem(prefix: str = ""):
    """Print allocated and reserved CUDA memory, in GiB, for every visible GPU.

    Does nothing without CUDA. A device whose query fails is reported as
    ``mem query failed``; any other error is ignored.

    Args:
        prefix: Text placed in front of the report's header line.
    """
    if torch.cuda.is_available():
        try:
            bytes_per_gib = 1024**3
            report = [f"{prefix} GPU mem (GB):"]
            for gpu_index in range(torch.cuda.device_count()):
                try:
                    allocated = torch.cuda.memory_allocated(gpu_index) / bytes_per_gib
                    reserved = torch.cuda.memory_reserved(gpu_index) / bytes_per_gib
                    entry = f"  GPU {gpu_index}: alloc={allocated:.2f} resv={reserved:.2f}"
                except Exception:
                    entry = f"  GPU {gpu_index}: mem query failed"
                report.append(entry)
            print("\n".join(report))
        except Exception:
            pass


def _get_cluster_count(model: torch.nn.Module) -> int:
    """Return how many prototype stores the model's DSCD component holds.

    Args:
        model: Model, possibly wrapped (``model.module``), exposing ``dscd``.

    Returns:
        ``len(dscd.prototype_stores)``, or 0 when DSCD or its stores are
        missing or any error occurs.
    """
    try:
        dscd = getattr(getattr(model, "module", model), "dscd", None)
        if dscd is not None:
            return len(getattr(dscd, "prototype_stores", None) or {})
        return 0
    except Exception:
        return 0


def _get_dscd_safe(model: torch.nn.Module):
    """Return the model's ``dscd`` attribute, unwrapping ``model.module`` first.

    Returns ``None`` when the attribute does not exist or the lookup fails.
    """
    try:
        return getattr(getattr(model, "module", model), "dscd", None)
    except Exception:
        return None


def _print_top_clusters(model: torch.nn.Module, top_n: int = 5):
    """Print the ``top_n`` DSCD prototype stores with the largest sample totals.

    Output is produced only when discovery debugging or verbose logging is on;
    with both off the function returns before taking the DSCD lock or copying
    and sorting the stores. All errors are swallowed.

    Args:
        model: Model, possibly wrapped, whose DSCD component is inspected.
        top_n: Number of stores to list.
    """
    try:
        # Nothing can be printed with both flags off, so skip the lock, the copy
        # and the sort entirely.
        if not (_DEBUG_DISCOVERY or _VERBOSE_LOGGING):
            return
        dscd = _get_dscd_safe(model)
        if dscd is None:
            return
        lock = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
        if lock:
            with lock:
                stores = dict(getattr(dscd, "prototype_stores", {}) or {})
        else:
            stores = dict(getattr(dscd, "prototype_stores", {}) or {})

        items = []
        for token, store in stores.items():
            try:
                total_count = sum(getattr(store, "counts", []) or [])
                size_fn = getattr(store, "size", None)
                if callable(size_fn):
                    n_protos = int(size_fn())
                else:
                    centroids = getattr(store, "centroids", None)
                    n_protos = len(centroids) if centroids is not None else 0
                items.append((token, total_count, n_protos))
            except Exception:
                # A store that cannot be summarised is left out of the listing.
                continue
        items.sort(key=lambda x: x[1], reverse=True)
        print("[CLUSTER-DBG] Top clusters:")
        for i, (tok, cnt, prot) in enumerate(items[:top_n], 1):
            tok_str = str(tok)[:20]
            print(f"  {i:2d}. {tok_str:20s} samples={cnt:4d} protos={prot}")
    except Exception:
        pass


def _check_discovery_status(model: torch.nn.Module, global_step: int):
    """Print how many entries the DSCD ``discovered_log`` currently holds.

    Prints only when the log is non-empty and discovery debugging or verbose
    logging is on. All errors are swallowed.

    Args:
        model: Model, possibly wrapped, whose DSCD component is inspected.
        global_step: Step number included in the printed line.
    """
    try:
        dscd = getattr(getattr(model, "module", model), "dscd", None)
        if dscd is None:
            return
        events = getattr(dscd, "discovered_log", None)
        if not events:
            return
        total_discovered = len(events)
        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            print(f"[DISCOVERY-STATUS] Step {global_step}: {total_discovered} discovery events")
    except Exception:
        pass


def _check_gradients(model: torch.nn.Module, global_step: int):
    """Report how many DSCD and ASBN parameters currently hold finite gradients.

    For the ``dscd`` and ``asbn`` attributes of the (unwrapped) model, counts
    the parameters whose ``.grad`` is set and entirely finite versus the total
    number of parameters. The summary line is printed only when
    ``_VERBOSE_LOGGING`` or ``_DEBUG_DISCOVERY`` is truthy. Any exception is
    swallowed.

    Args:
        model: Model, possibly wrapped so the real model is ``model.module``.
        global_step: Step number included in the printed line.
    """

    def _tally(component):
        # Returns (parameters with a fully finite gradient, total parameters).
        with_grad = 0
        total = 0
        if component is None:
            return with_grad, total
        for param in component.parameters():
            total += 1
            grad = param.grad
            if grad is not None and torch.isfinite(grad).all():
                with_grad += 1
        return with_grad, total

    try:
        core = getattr(model, "module", model)
        dscd_module = getattr(core, "dscd", None)
        asbn_module = getattr(core, "asbn", None)

        dscd_ok, dscd_all = _tally(dscd_module)
        asbn_ok, asbn_all = _tally(asbn_module)

        dscd_status = "✗ NO GRADS" if dscd_ok <= 0 else "✓"
        asbn_status = "✗ NO GRADS" if asbn_ok <= 0 else "✓"

        if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
            print(f"[GRAD-CHECK] Step {global_step}: DSCD {dscd_status} ({dscd_ok}/{dscd_all}) | ASBN {asbn_status} ({asbn_ok}/{asbn_all})")
    except Exception:
        pass


def _loss_entry_to_float(value, upper, device, label=None):
    """Convert one forward-output entry to a float clamped to ``[0, upper]``.

    Tensors containing NaN/Inf are replaced by a zero tensor on ``device``
    (with a warning when ``label`` is given); tensors with more than one
    element are reduced with ``mean()``. Non-tensor values go through
    ``float()`` and non-finite results become 0.0.

    Returns:
        The clamped float, or ``None`` when ``value`` is ``None`` or cannot be
        converted (the caller then keeps its default).
    """
    try:
        if isinstance(value, torch.Tensor):
            if not torch.isfinite(value).all():
                if label is not None:
                    print(f"⚠️  [EXTRACT] {label} contains NaN/Inf, setting to 0.0")
                value = torch.tensor(0.0, device=device)
            value = torch.clamp(value, min=0.0, max=upper)
            # .item() requires a single element, so reduce larger tensors first.
            return float(value.mean().item()) if value.numel() > 1 else float(value.item())
        if value is None:
            return None
        scalar = float(value)
        if not np.isfinite(scalar):
            scalar = 0.0
        return max(0.0, min(upper, scalar))
    except Exception:
        return None


def _extract_loss_components(forward_out, device) -> Dict[str, float]:
    """Pull sanitised scalar loss components out of a forward-pass result.

    Reads ``translation_loss``, ``asbn_loss``, ``dscd_loss`` and
    ``domain_accuracy`` from ``forward_out`` and clamps them to ``[0, 100]``,
    ``[0, 10]``, ``[0, 5]`` and ``[0, 1]`` respectively. ``total_loss`` is the
    sum of the three loss components. When that sum is exactly 0.0 and
    ``forward_out`` has a ``loss`` entry, that entry (tensor or plain number,
    clamped to ``[0, 100]``) is used as ``total_loss`` and, if the translation
    component is also 0.0, as ``translation_loss``.

    Args:
        forward_out: Result of the model forward call. Anything that is not a
            dict yields all-zero components.
        device: Device on which replacement zero tensors are created when a
            tensor entry contains NaN/Inf.

    Returns:
        Dict with the keys ``total_loss``, ``translation_loss``, ``asbn_loss``,
        ``dscd_loss`` and ``domain_accuracy``. Entries that are missing or
        unreadable stay at 0.0; this function does not raise for bad entries.
    """
    components = {
        "total_loss": 0.0,
        "translation_loss": 0.0,
        "asbn_loss": 0.0,
        "dscd_loss": 0.0,
        "domain_accuracy": 0.0,
    }

    if not isinstance(forward_out, dict):
        return components

    # (key, upper clamp bound) for each loss component.
    loss_bounds = (
        ("translation_loss", 100.0),
        ("asbn_loss", 10.0),
        ("dscd_loss", 5.0),
    )
    for key, upper in loss_bounds:
        parsed = _loss_entry_to_float(forward_out.get(key), upper, device, label=key)
        if parsed is not None:
            components[key] = parsed

    # Accuracy only accepts tensors and plain Python numbers, and a non-finite
    # tensor is zeroed silently (no label -> no warning).
    accuracy = forward_out.get("domain_accuracy")
    if isinstance(accuracy, (torch.Tensor, int, float)):
        parsed = _loss_entry_to_float(accuracy, 1.0, device)
        if parsed is not None:
            components["domain_accuracy"] = parsed

    components["total_loss"] = (
        components["translation_loss"]
        + components["asbn_loss"]
        + components["dscd_loss"]
    )

    if components["total_loss"] == 0.0 and "loss" in forward_out:
        # Fall back to the combined loss. Plain numbers are accepted here too,
        # matching how the individual components are read above.
        main_loss = _loss_entry_to_float(forward_out["loss"], 100.0, device, label="main loss")
        if main_loss is not None:
            components["total_loss"] = main_loss
            if components["translation_loss"] == 0.0:
                components["translation_loss"] = main_loss

    return components


def _clear_all_gradients_aggressively(model, optimizer, phi_optimizer):
    """Drop every gradient held by the optimizers and the model.

    Runs three independent, best-effort stages; a failure in one stage is
    swallowed and does not prevent the following stages from running:

    1. ``zero_grad(set_to_none=True)`` on ``optimizer`` and then, when it is
       not ``None``, on ``phi_optimizer``.
    2. Set ``.grad = None`` on every parameter of ``model``.
    3. Call ``torch.cuda.empty_cache()`` when CUDA is available.
    """

    def _reset_optimizers():
        optimizer.zero_grad(set_to_none=True)
        if phi_optimizer is None:
            return
        phi_optimizer.zero_grad(set_to_none=True)

    def _drop_parameter_grads():
        for p in model.parameters():
            p.grad = None

    def _release_cuda_cache():
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    for stage in (_reset_optimizers, _drop_parameter_grads, _release_cuda_cache):
        try:
            stage()
        except Exception:
            pass


def _sanitize_labels(labels: torch.Tensor, vocab_size: int, pad_token_id: Optional[int], device: torch.device) -> torch.Tensor:
    """Return a copy of ``labels`` with every unusable id replaced by -100.

    A label is replaced when it is

    * greater than or equal to ``vocab_size``,
    * negative (this includes values in ``[-99, -1]``, which are neither a
      valid class index nor the -100 ignore value), or
    * equal to ``pad_token_id`` (only when that id is not ``None`` and >= 0).

    Args:
        labels: Integer label tensor; it is not modified.
        vocab_size: Exclusive upper bound for valid label ids.
        pad_token_id: Padding id to mask out, or ``None``/negative to skip.
        device: Unused; kept so existing call sites keep working.

    Returns:
        A new tensor with the same shape as ``labels``.
    """
    invalid = (labels >= vocab_size) | (labels < 0)
    if pad_token_id is not None and pad_token_id >= 0:
        invalid = invalid | (labels == pad_token_id)
    # masked_fill is out-of-place, so the caller's tensor stays untouched.
    return labels.masked_fill(invalid, -100)


def train_memory_efficient_tatn(
    model: torch.nn.Module,
    tokenizer,
    train_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    phi_optimizer: Optional[torch.optim.Optimizer] = None,
    epochs: Optional[int] = None,
    accumulation_steps: Optional[int] = None,
    validate_every: Optional[int] = None,
    enable_validation: bool = True,
    enable_asbn_training: bool = True,
) -> torch.nn.Module:
    global _CELL7_WORDMAP_BUILT_COUNT, _CELL7_WORDMAP_PROVIDED_COUNT
    global _CELL7_FORWARD_CALL_COUNT, _CELL7_BACKWARD_SUCCESS_COUNT

    if epochs is None:
        epochs = _EPOCHS
    if accumulation_steps is None:
        accumulation_steps = _ACCUMULATION_STEPS
    if validate_every is None:
        validate_every = _VALIDATION_CHECK_INTERVAL

    try:
        vocab_size = len(tokenizer)
    except Exception:
        try:
            vocab_size = tokenizer.vocab_size
        except Exception:
            vocab_size = 128104
            print(f"⚠️  [TRAIN] Cannot determine vocab_size, using default: {vocab_size}")
    
    try:
        pad_token_id = tokenizer.pad_token_id
        if pad_token_id is None:
            pad_token_id = tokenizer.eos_token_id if hasattr(tokenizer, "eos_token_id") else 1
    except Exception:
        pad_token_id = 1

    print(f"[TRAIN] Starting training: epochs={epochs}, batch={_BATCH_SIZE}, accum_steps={accumulation_steps}")
    print(f"[TRAIN] Vocab size: {vocab_size}, Pad token ID: {pad_token_id}")
    print(f"[TRAIN] Validation: {'enabled' if enable_validation and validate_every > 0 else 'disabled'}")
    print(f"[TRAIN] ASBN Training: {'ENABLED' if enable_asbn_training and phi_optimizer is not None else 'DISABLED'}")
    print(f"[TRAIN] DP enabled: {_USE_MULTI_GPU}, GPUs: {_NUM_GPUS}, Device: {_DEVICE}")
    print(f"[TRAIN] Discovery frequency: {_PERIODIC_DISCOVERY_FREQUENCY} steps")
    print(f"[TRAIN] Gradient clip norm: {_GRAD_CLIP_NORM}")

    model.train()
    clear_all_gpu_caches()
    scaler = GradScaler(enabled=_USE_AMP and torch.cuda.is_available())

    global_step = 0
    accumulated_steps = 0
    pending_validation = False

    training_stats: Dict[str, Any] = {
        "total_loss": [],
        "translation_losses": [],
        "asbn_losses": [],
        "dscd_losses": [],
        "domain_accuracies": [],
        "epoch_losses": [],
        "backward_losses": [],
        "batches_processed": 0,
        "optimizer_updates": 0,
        "asbn_updates": 0,
        "skipped_batches": 0,
        "oom_errors": 0,
        "runtime_errors": 0,
        "exceptions": 0,
        "epoch_validations": [],
        "dscd_quality_history": [],
        "multi_sense_ratio_history": [],
        "asbn_domain_accuracy_history": [],
        "trg_explanation_history": [],
        "gradient_checks": [],
        "loss_component_breakdown": [],
        "nan_gradient_events": 0,
        "extreme_gradient_events": 0,
        "label_corruption_fixes": 0,
    }

    last_forward_loss = 0.0
    last_backward_loss = 0.0
    last_asbn_loss = 0.0
    last_translation_loss = 0.0
    last_dscd_loss = 0.0
    last_domain_accuracy = 0.0

    for epoch in range(1, epochs + 1):
        epoch_start = time.time()
        epoch_losses: List[float] = []
        skip_reasons = defaultdict(int)

        print("\n" + "=" * 80)
        print(f"EPOCH {epoch}/{epochs} STARTED")
        print("=" * 80)

        model.train()

        try:
            core = model.module if hasattr(model, "module") else model
            trg = getattr(core, "trg_system", None)
            if trg and hasattr(trg, "reset_statistics"):
                try:
                    trg.reset_statistics()
                except Exception:
                    pass
            asbn = getattr(core, "asbn", None)
            if asbn and hasattr(asbn, "reset_stats"):
                try:
                    asbn.reset_stats()
                except Exception:
                    pass
        except Exception:
            pass

        try:
            optimizer.zero_grad(set_to_none=True)
            if phi_optimizer is not None:
                phi_optimizer.zero_grad(set_to_none=True)
        except Exception:
            pass

        progress = None
        try:
            progress = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", dynamic_ncols=True)
            for batch_idx, batch in enumerate(progress):
                global_step += 1
                training_stats["batches_processed"] += 1

                if _DEBUG_DISCOVERY and global_step % DEBUG_PRINT_INTERVAL == 0:
                    _check_discovery_status(model, global_step)

                if global_step % 50 == 0 and (_DEBUG_DISCOVERY or _VERBOSE_LOGGING):
                    try:
                        core_model = model.module if hasattr(model, "module") else model
                        dscd = getattr(core_model, "dscd", None)
                        if dscd and hasattr(dscd, "buffers"):
                            lock = getattr(dscd, "buffer_lock", None)
                            if lock:
                                with lock:
                                    buffer_sizes = [len(dscd.buffers[t]) for t in dscd.buffers]
                            else:
                                buffer_sizes = [len(dscd.buffers[t]) for t in dscd.buffers]
                            avg_size = sum(buffer_sizes) / len(buffer_sizes) if buffer_sizes else 0
                            ready_for_discovery = sum(1 for s in buffer_sizes if s >= max(2, getattr(dscd, "n_min", 2)))
                            print(f"[BUFFER] Step {global_step}: types={len(buffer_sizes)} avg_size={avg_size:.1f} ready={ready_for_discovery}")
                    except Exception:
                        pass

                if batch is None:
                    training_stats["skipped_batches"] += 1
                    skip_reasons["batch_none"] += 1
                    if global_step % 100 == 0:
                        print(f"⚠️  [TRAIN] Step {global_step}: Batch is None")
                    continue

                try:
                    input_ids = batch.get("input_ids", None) if isinstance(batch, dict) else None
                    attention_mask = batch.get("attention_mask", None) if isinstance(batch, dict) else None
                    labels = batch.get("labels", None) if isinstance(batch, dict) else None

                    if input_ids is None or attention_mask is None or labels is None:
                        training_stats["skipped_batches"] += 1
                        skip_reasons["missing_fields"] += 1
                        if global_step % 100 == 0:
                            missing = []
                            if input_ids is None:
                                missing.append("input_ids")
                            if attention_mask is None:
                                missing.append("attention_mask")
                            if labels is None:
                                missing.append("labels")
                            print(f"⚠️  [TRAIN] Step {global_step}: Missing fields: {missing}")
                        continue

                    if _USE_MULTI_GPU and _NUM_GPUS > 0:
                        bsz = int(input_ids.size(0))
                        keep = (bsz // _NUM_GPUS) * _NUM_GPUS
                        if keep == 0:
                            training_stats["skipped_batches"] += 1
                            skip_reasons["dp_keep_zero"] += 1
                            if global_step % 100 == 0:
                                print(f"⚠️  [TRAIN] Step {global_step}: Batch size {bsz} too small for {_NUM_GPUS} GPUs")
                            continue
                        if keep != bsz:
                            input_ids = input_ids[:keep]
                            attention_mask = attention_mask[:keep]
                            labels = labels[:keep]

                    input_ids = input_ids.to(_DEVICE, non_blocking=True)
                    attention_mask = attention_mask.to(_DEVICE, non_blocking=True)
                    labels = labels.to(_DEVICE, non_blocking=True)

                    if input_ids.size(0) == 0:
                        training_stats["skipped_batches"] += 1
                        skip_reasons["empty_batch"] += 1
                        if global_step % 100 == 0:
                            print(f"⚠️  [TRAIN] Step {global_step}: Empty batch after processing")
                        continue

                    labels = _sanitize_labels(labels, vocab_size, pad_token_id, _DEVICE)
                    training_stats["label_corruption_fixes"] += 1
                    
                    if (labels == -100).all():
                        training_stats["skipped_batches"] += 1
                        skip_reasons["all_labels_ignored"] += 1
                        if global_step % 100 == 0:
                            print(f"⚠️  [TRAIN] Step {global_step}: All labels are -100 (padding only)")
                        continue
                    
                    valid_label_count = (labels != -100).sum().item()
                    if valid_label_count == 0:
                        training_stats["skipped_batches"] += 1
                        skip_reasons["no_valid_labels"] += 1
                        if global_step % 100 == 0:
                            print(f"⚠️  [TRAIN] Step {global_step}: No valid labels in batch")
                        continue

                    token_word_map_for_batch = batch.get("token_word_map", None) if isinstance(batch, dict) else None
                    if token_word_map_for_batch is None or not isinstance(token_word_map_for_batch, list):
                        try:
                            token_word_map_for_batch = _build_token_to_word_map(tokenizer, input_ids)
                            _CELL7_WORDMAP_BUILT_COUNT += 1
                        except Exception:
                            token_word_map_for_batch = [{} for _ in range(input_ids.size(0))]
                    else:
                        _CELL7_WORDMAP_PROVIDED_COUNT += 1

                    if not token_word_map_for_batch or len(token_word_map_for_batch) != input_ids.size(0):
                        token_word_map_for_batch = [{} for _ in range(input_ids.size(0))]

                    _CELL7_FORWARD_CALL_COUNT += 1

                    forward_kwargs = {
                        "input_ids": input_ids,
                        "attention_mask": attention_mask,
                        "labels": labels,
                        "src_texts": batch.get("src_text", None) if isinstance(batch, dict) else None,
                        "token_word_map": token_word_map_for_batch,
                    }

                    with get_amp_ctx():
                        forward_out = model(**forward_kwargs)

                    loss_tensor = None
                    if isinstance(forward_out, torch.Tensor):
                        loss_tensor = forward_out
                    elif isinstance(forward_out, dict):
                        loss_candidates = ("loss", "total_loss", "translation_loss")
                        for k in loss_candidates:
                            if k in forward_out:
                                loss_tensor = forward_out.get(k)
                                break
                        if loss_tensor is None and "model_output" in forward_out and isinstance(forward_out["model_output"], dict) and "loss" in forward_out["model_output"]:
                            loss_tensor = forward_out["model_output"]["loss"]
                    elif isinstance(forward_out, (list, tuple)) and len(forward_out) > 0 and isinstance(forward_out[0], torch.Tensor):
                        loss_tensor = forward_out[0]

                    if loss_tensor is None:
                        print(f"❌ [TRAIN] Step {global_step}: Forward returned None loss! forward_out type={type(forward_out)}, keys={forward_out.keys() if isinstance(forward_out, dict) else 'N/A'}")
                        training_stats["skipped_batches"] += 1
                        skip_reasons["none_loss"] += 1
                        _clear_all_gradients_aggressively(model, optimizer, phi_optimizer)
                        accumulated_steps = 0
                        continue

                    loss_components = _extract_loss_components(forward_out, _DEVICE)
                    last_translation_loss = loss_components["translation_loss"]
                    last_asbn_loss = loss_components["asbn_loss"]
                    last_dscd_loss = loss_components["dscd_loss"]
                    last_domain_accuracy = loss_components["domain_accuracy"]
                    
                    training_stats["translation_losses"].append(last_translation_loss)
                    training_stats["asbn_losses"].append(last_asbn_loss)
                    training_stats["dscd_losses"].append(last_dscd_loss)
                    training_stats["domain_accuracies"].append(last_domain_accuracy)
                    training_stats["loss_component_breakdown"].append(loss_components)

                    if not isinstance(loss_tensor, torch.Tensor):
                        try:
                            loss_tensor = torch.tensor(float(loss_tensor), device=_DEVICE, dtype=torch.float32)
                        except Exception:
                            print(f"❌ [TRAIN] Step {global_step}: Cannot convert loss to tensor: {type(loss_tensor)}")
                            training_stats["skipped_batches"] += 1
                            skip_reasons["loss_convert_fail"] += 1
                            _clear_all_gradients_aggressively(model, optimizer, phi_optimizer)
                            accumulated_steps = 0
                            continue
                    else:
                        loss_tensor = loss_tensor.to(_DEVICE)

                    if loss_tensor.numel() > 1:
                        loss_tensor = loss_tensor.mean()

                    if not torch.isfinite(loss_tensor):
                        print(f"❌ [TRAIN] Step {global_step}: NaN/Inf loss BEFORE backward! trans={last_translation_loss:.4f} asbn={last_asbn_loss:.4f} dscd={last_dscd_loss:.4f}")
                        training_stats["skipped_batches"] += 1
                        skip_reasons["nan_loss_pre_backward"] += 1
                        _clear_all_gradients_aggressively(model, optimizer, phi_optimizer)
                        accumulated_steps = 0
                        continue

                    loss_tensor = torch.clamp(loss_tensor, min=0.0, max=100.0)
                    loss_val = float(loss_tensor.item())

                    last_forward_loss = loss_val
                    epoch_losses.append(loss_val)
                    training_stats["total_loss"].append(loss_val)

                    accum_steps_safe = max(1, accumulation_steps)
                    loss_scaled = loss_tensor / (accum_steps_safe + 1e-9)
                    
                    if not torch.isfinite(loss_scaled):
                        print(f"❌ [TRAIN] Step {global_step}: loss_scaled is NaN/Inf after division")
                        training_stats["skipped_batches"] += 1
                        skip_reasons["nan_loss_scaled"] += 1
                        _clear_all_gradients_aggressively(model, optimizer, phi_optimizer)
                        accumulated_steps = 0
                        continue
                    
                    loss_scaled = torch.clamp(loss_scaled, min=0.0, max=50.0)
                    last_backward_loss = float(loss_scaled.item())
                    training_stats["backward_losses"].append(last_backward_loss)

                    try:
                        if scaler.is_enabled():
                            scaler.scale(loss_scaled).backward()
                        else:
                            loss_scaled.backward()
                        _CELL7_BACKWARD_SUCCESS_COUNT += 1
                        
                        has_grads = False
                        has_nan_grads = False
                        has_extreme_grads = False
                        max_grad_norm_before_clip = 0.0
                        
                        for p in model.parameters():
                            if p.grad is not None:
                                has_grads = True
                                if not torch.isfinite(p.grad).all():
                                    has_nan_grads = True
                                    break
                                grad_norm = p.grad.data.norm(2).item()
                                max_grad_norm_before_clip = max(max_grad_norm_before_clip, grad_norm)
                                if grad_norm > 100.0:
                                    has_extreme_grads = True
                        
                        if not has_grads:
                            print(f"⚠️  [TRAIN] Step {global_step}: Backward completed but NO gradients created")
                            training_stats["skipped_batches"] += 1
                            skip_reasons["no_grads"] += 1
                            _clear_all_gradients_aggressively(model, optimizer, phi_optimizer)
                            accumulated_steps = 0
                            continue
                        
                        if has_nan_grads:
                            print(f"❌ [TRAIN] Step {global_step}: Backward created NaN/Inf gradients")
                            training_stats["skipped_batches"] += 1
                            training_stats["nan_gradient_events"] += 1
                            skip_reasons["nan_grads_post_backward"] += 1
                            _clear_all_gradients_aggressively(model, optimizer, phi_optimizer)
                            accumulated_steps = 0
                            continue
                        
                        if has_extreme_grads:
                            if _VERBOSE_LOGGING:
                                print(f"⚠️  [TRAIN] Step {global_step}: Extreme gradient detected (max_norm={max_grad_norm_before_clip:.2f}), will clip")
                            training_stats["extreme_gradient_events"] += 1
                            
                    except Exception as e:
                        print(f"❌ [TRAIN] Step {global_step}: Backward failed: {type(e).__name__}: {e}")
                        if _VERBOSE_LOGGING:
                            traceback.print_exc()
                        training_stats["skipped_batches"] += 1
                        skip_reasons["backward_failed"] += 1
                        _clear_all_gradients_aggressively(model, optimizer, phi_optimizer)
                        accumulated_steps = 0
                        continue

                    accumulated_steps += 1

                    if accumulated_steps >= accumulation_steps:
                        update_success = True
                        try:
                            if scaler.is_enabled():
                                scaler.unscale_(optimizer)
                                
                                try:
                                    grad_norm = torch.nn.utils.clip_grad_norm_(
                                        model.parameters(), 
                                        _GRAD_CLIP_NORM,
                                        error_if_nonfinite=False
                                    )
                                except RuntimeError as e:
                                    if "non-finite" in str(e).lower():
                                        print(f"❌ [TRAIN] Step {global_step}: Gradient clipping detected non-finite gradients")
                                        update_success = False
                                        training_stats["nan_gradient_events"] += 1
                                        grad_norm = torch.tensor(float('inf'))
                                    else:
                                        raise
                                
                                if not torch.isfinite(grad_norm):
                                    print(f"❌ [TRAIN] Step {global_step}: Gradient norm is NaN/Inf: {grad_norm}")
                                    update_success = False
                                    training_stats["nan_gradient_events"] += 1
                                else:
                                    scaler.step(optimizer)
                                    training_stats["optimizer_updates"] += 1
                                
                                if phi_optimizer is not None and enable_asbn_training:
                                    phi_params = []
                                    for g in phi_optimizer.param_groups:
                                        phi_params.extend([p for p in g.get("params", [])])
                                    
                                    if phi_params:
                                        try:
                                            scaler.unscale_(phi_optimizer)
                                        except RuntimeError as e:
                                            if "unscale_() has already been called" not in str(e):
                                                if _VERBOSE_LOGGING:
                                                    print(f"⚠️  [TRAIN] Step {global_step}: Phi unscale failed: {e}")
                                        
                                        try:
                                            phi_grad_norm = torch.nn.utils.clip_grad_norm_(
                                                phi_params, 
                                                _GRAD_CLIP_NORM,
                                                error_if_nonfinite=False
                                            )
                                        except RuntimeError as e:
                                            if "non-finite" in str(e).lower():
                                                if _VERBOSE_LOGGING:
                                                    print(f"❌ [TRAIN] Step {global_step}: Phi gradient clipping detected non-finite gradients")
                                                phi_grad_norm = torch.tensor(float('inf'))
                                            else:
                                                raise
                                        
                                        if not torch.isfinite(phi_grad_norm):
                                            if _VERBOSE_LOGGING:
                                                print(f"❌ [TRAIN] Step {global_step}: Phi gradient norm is NaN/Inf: {phi_grad_norm}")
                                        else:
                                            try:
                                                scaler.step(phi_optimizer)
                                                training_stats["asbn_updates"] += 1
                                            except Exception as e:
                                                print(f"❌ [TRAIN] Step {global_step}: Phi optimizer step failed: {type(e).__name__}: {e}")
                                    else:
                                        if global_step % 100 == 0 and _VERBOSE_LOGGING:
                                            print(f"⚠️  [TRAIN] Step {global_step}: Phi optimizer has no parameters")
                                
                                scaler.update()
                            else:
                                try:
                                    grad_norm = torch.nn.utils.clip_grad_norm_(
                                        model.parameters(), 
                                        _GRAD_CLIP_NORM,
                                        error_if_nonfinite=False
                                    )
                                except RuntimeError as e:
                                    if "non-finite" in str(e).lower():
                                        print(f"❌ [TRAIN] Step {global_step}: Gradient clipping detected non-finite gradients")
                                        update_success = False
                                        training_stats["nan_gradient_events"] += 1
                                        grad_norm = torch.tensor(float('inf'))
                                    else:
                                        raise
                                
                                if not torch.isfinite(grad_norm):
                                    print(f"❌ [TRAIN] Step {global_step}: Gradient norm is NaN/Inf: {grad_norm}")
                                    update_success = False
                                    training_stats["nan_gradient_events"] += 1
                                else:
                                    optimizer.step()
                                    training_stats["optimizer_updates"] += 1
                                
                                if phi_optimizer is not None and enable_asbn_training:
                                    phi_params = []
                                    for g in phi_optimizer.param_groups:
                                        phi_params.extend([p for p in g.get("params", [])])
                                    
                                    if phi_params:
                                        try:
                                            phi_grad_norm = torch.nn.utils.clip_grad_norm_(
                                                phi_params, 
                                                _GRAD_CLIP_NORM,
                                                error_if_nonfinite=False
                                            )
                                        except RuntimeError as e:
                                            if "non-finite" in str(e).lower():
                                                if _VERBOSE_LOGGING:
                                                    print(f"❌ [TRAIN] Step {global_step}: Phi gradient clipping detected non-finite gradients")
                                                phi_grad_norm = torch.tensor(float('inf'))
                                            else:
                                                raise
                                        
                                        if not torch.isfinite(phi_grad_norm):
                                            if _VERBOSE_LOGGING:
                                                print(f"❌ [TRAIN] Step {global_step}: Phi gradient norm is NaN/Inf: {phi_grad_norm}")
                                        else:
                                            try:
                                                phi_optimizer.step()
                                                training_stats["asbn_updates"] += 1
                                            except Exception as e:
                                                print(f"❌ [TRAIN] Step {global_step}: Phi optimizer step failed: {type(e).__name__}: {e}")
                                    else:
                                        if global_step % 100 == 0 and _VERBOSE_LOGGING:
                                            print(f"⚠️  [TRAIN] Step {global_step}: Phi optimizer has no parameters")

                            optimizer.zero_grad(set_to_none=True)
                            if phi_optimizer is not None and enable_asbn_training:
                                phi_optimizer.zero_grad(set_to_none=True)

                            if not update_success:
                                training_stats["skipped_batches"] += 1
                                skip_reasons["nan_grad_during_update"] += 1
                                _clear_all_gradients_aggressively(model, optimizer, phi_optimizer)

                        except RuntimeError as e:
                            if "out of memory" in str(e).lower():
                                training_stats["oom_errors"] += 1
                                training_stats["skipped_batches"] += 1
                                skip_reasons["oom"] += 1
                                print(f"❌ [OOM] OOM at step {global_step}")
                                _clear_all_gradients_aggressively(model, optimizer, phi_optimizer)
                                for p in model.parameters():
                                    p.grad = None
                                clear_all_gpu_caches()
                                accumulated_steps = 0
                                continue
                            else:
                                print(f"❌ [TRAIN] Step {global_step}: Optimizer step RuntimeError: {e}")
                                if _VERBOSE_LOGGING:
                                    traceback.print_exc()
                                training_stats["runtime_errors"] += 1
                                skip_reasons["opt_runtime"] += 1
                        except Exception as e:
                            print(f"❌ [TRAIN] Step {global_step}: Optimizer step failed: {type(e).__name__}: {e}")
                            if _VERBOSE_LOGGING:
                                traceback.print_exc()
                            training_stats["exceptions"] += 1
                            skip_reasons["opt_exception"] += 1
                        finally:
                            accumulated_steps = 0

                    if global_step % 100 == 0 and (_DEBUG_DISCOVERY or _VERBOSE_LOGGING):
                        _check_gradients(model, global_step)
                        training_stats["gradient_checks"].append({
                            "step": global_step,
                            "timestamp": time.time(),
                        })

                    if enable_validation and validate_every and validate_every > 0 and global_step % validate_every == 0:
                        if accumulated_steps > 0:
                            print(f"⚠️  [TRAIN] Step {global_step}: Validation triggered mid-accumulation, flushing {accumulated_steps} gradients")
                            _clear_all_gradients_aggressively(model, optimizer, phi_optimizer)
                            accumulated_steps = 0
                        try:
                            val_fn = globals().get("comprehensive_epoch_validation", None)
                            if callable(val_fn):
                                print(f"🔍 [VALIDATION] Running validation at step {global_step}...")
                                validation_results = val_fn(
                                    model,
                                    tokenizer,
                                    epoch,
                                    global_step,
                                    _SOURCE_LANGUAGE,
                                    _TARGET_LANGUAGE,
                                    _MAX_LENGTH,
                                    _DEVICE,
                                )
                                if validation_results and validation_results.get("validation_completed", False):
                                    training_stats["epoch_validations"].append(validation_results)
                                    print(f"✅ [VALIDATION] Completed at step {global_step}")
                                    pending_validation = False
                                else:
                                    print(f"⚠️  [VALIDATION] Failed at step {global_step}")
                            else:
                                print(f"⚠️  [VALIDATION] comprehensive_epoch_validation not found")
                        except Exception as e:
                            print(f"❌ [VALIDATION] Failed at step {global_step}: {type(e).__name__}: {e}")
                            if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                                traceback.print_exc()
                        model.train()

                    if global_step % DEBUG_PRINT_INTERVAL == 0 and (_DEBUG_DISCOVERY or _VERBOSE_LOGGING):
                        _print_gpu_mem("[TRAIN-DEBUG]")
                        cluster_count = _get_cluster_count(model)
                        print(f"[TRAIN-DEBUG] step={global_step} loss={last_forward_loss:.4f} trans={last_translation_loss:.4f} asbn={last_asbn_loss:.4f} dscd={last_dscd_loss:.4f} clusters={cluster_count}")
                        _print_top_clusters(model, top_n=5)

                    if global_step % _MEMORY_CLEANUP_FREQUENCY == 0:
                        clear_all_gpu_caches()

                except RuntimeError as e:
                    if "out of memory" in str(e).lower():
                        training_stats["oom_errors"] += 1
                        training_stats["skipped_batches"] += 1
                        skip_reasons["oom"] += 1
                        print(f"❌ [OOM] Caught OOM at step {global_step}")
                        _clear_all_gradients_aggressively(model, optimizer, phi_optimizer)
                        for p in model.parameters():
                            p.grad = None
                        clear_all_gpu_caches()
                        accumulated_steps = 0
                        continue
                    else:
                        print(f"❌ [TRAIN] Step {global_step}: RuntimeError: {type(e).__name__}: {e}")
                        if _VERBOSE_LOGGING:
                            traceback.print_exc()
                        training_stats["runtime_errors"] += 1
                        training_stats["skipped_batches"] += 1
                        skip_reasons["runtime"] += 1
                        _clear_all_gradients_aggressively(model, optimizer, phi_optimizer)
                        accumulated_steps = 0
                        continue

                except Exception as e:
                    print(f"❌ [TRAIN] Step {global_step}: Exception: {type(e).__name__}: {e}")
                    if _VERBOSE_LOGGING:
                        traceback.print_exc()
                    training_stats["exceptions"] += 1
                    training_stats["skipped_batches"] += 1
                    skip_reasons["exceptions"] += 1
                    _clear_all_gradients_aggressively(model, optimizer, phi_optimizer)
                    accumulated_steps = 0
                    continue

                processed_batches = training_stats["batches_processed"] - training_stats["skipped_batches"]
                expected_updates = max(1, math.floor(processed_batches / max(1, accumulation_steps)))
                success_rate = 100.0 * training_stats["optimizer_updates"] / expected_updates if expected_updates > 0 else 0.0
                cluster_count = _get_cluster_count(model)
                next_disc_str = "NA"
                try:
                    if _PERIODIC_DISCOVERY_FREQUENCY and _PERIODIC_DISCOVERY_FREQUENCY > 0:
                        steps_to_next = _PERIODIC_DISCOVERY_FREQUENCY - (global_step % _PERIODIC_DISCOVERY_FREQUENCY)
                        if steps_to_next >= _PERIODIC_DISCOVERY_FREQUENCY:
                            steps_to_next = 0
                        next_disc_str = f"next_disc_in={steps_to_next}"
                except Exception:
                    next_disc_str = "next_disc=err"
                
                safe_last_forward_loss = max(0.0, min(last_forward_loss, 1000.0))
                safe_last_translation_loss = max(0.0, min(last_translation_loss, 1000.0))
                safe_last_asbn_loss = max(0.0, min(last_asbn_loss, 100.0))
                safe_last_dscd_loss = max(0.0, min(last_dscd_loss, 100.0))
                
                progress.set_postfix_str(f"loss={safe_last_forward_loss:.4f} trans={safe_last_translation_loss:.4f} asbn={safe_last_asbn_loss:.4f} dscd={safe_last_dscd_loss:.4f} dom_acc={last_domain_accuracy:.2f} rate={success_rate:.1f}% clusters={cluster_count} {next_disc_str}")

        finally:
            if progress is not None:
                try:
                    progress.close()
                except Exception:
                    pass

        if accumulated_steps > 0:
            print(f"⚠️  [TRAIN] End of epoch {epoch}: Flushing {accumulated_steps} accumulated gradients")
            try:
                if scaler.is_enabled():
                    scaler.unscale_(optimizer)
                    try:
                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            model.parameters(), 
                            _GRAD_CLIP_NORM,
                            error_if_nonfinite=False
                        )
                    except RuntimeError as e:
                        if "non-finite" in str(e).lower():
                            print(f"❌ [TRAIN] End of epoch {epoch}: Gradient clipping detected non-finite gradients")
                            grad_norm = torch.tensor(float('inf'))
                        else:
                            raise
                    
                    if torch.isfinite(grad_norm):
                        scaler.step(optimizer)
                        training_stats["optimizer_updates"] += 1
                        print(f"✅ [TRAIN] End of epoch {epoch}: Flushed main optimizer (grad_norm={grad_norm:.4f})")
                    else:
                        print(f"❌ [TRAIN] End of epoch {epoch}: Gradient norm is NaN/Inf: {grad_norm}")
                    
                    if phi_optimizer is not None and enable_asbn_training:
                        try:
                            scaler.unscale_(phi_optimizer)
                        except RuntimeError as e:
                            if "unscale_() has already been called" not in str(e):
                                print(f"⚠️  [TRAIN] End of epoch {epoch}: Phi unscale failed: {e}")
                        
                        phi_params = []
                        for g in phi_optimizer.param_groups:
                            phi_params.extend([p for p in g.get("params", [])])
                        
                        if phi_params:
                            try:
                                phi_grad_norm = torch.nn.utils.clip_grad_norm_(
                                    phi_params, 
                                    _GRAD_CLIP_NORM,
                                    error_if_nonfinite=False
                                )
                            except RuntimeError as e:
                                if "non-finite" in str(e).lower():
                                    print(f"❌ [TRAIN] End of epoch {epoch}: Phi gradient clipping detected non-finite gradients")
                                    phi_grad_norm = torch.tensor(float('inf'))
                                else:
                                    raise
                            
                            if torch.isfinite(phi_grad_norm):
                                try:
                                    scaler.step(phi_optimizer)
                                    training_stats["asbn_updates"] += 1
                                    print(f"✅ [TRAIN] End of epoch {epoch}: Flushed phi optimizer (grad_norm={phi_grad_norm:.4f})")
                                except Exception as e:
                                    print(f"❌ [TRAIN] End of epoch {epoch}: Phi step failed: {e}")
                            else:
                                print(f"❌ [TRAIN] End of epoch {epoch}: Phi gradient norm is NaN/Inf: {phi_grad_norm}")
                    
                    scaler.update()
                else:
                    try:
                        grad_norm = torch.nn.utils.clip_grad_norm_(
                            model.parameters(), 
                            _GRAD_CLIP_NORM,
                            error_if_nonfinite=False
                        )
                    except RuntimeError as e:
                        if "non-finite" in str(e).lower():
                            print(f"❌ [TRAIN] End of epoch {epoch}: Gradient clipping detected non-finite gradients")
                            grad_norm = torch.tensor(float('inf'))
                        else:
                            raise
                    
                    if torch.isfinite(grad_norm):
                        optimizer.step()
                        training_stats["optimizer_updates"] += 1
                        print(f"✅ [TRAIN] End of epoch {epoch}: Flushed main optimizer (grad_norm={grad_norm:.4f})")
                    else:
                        print(f"❌ [TRAIN] End of epoch {epoch}: Gradient norm is NaN/Inf: {grad_norm}")
                    
                    if phi_optimizer is not None and enable_asbn_training:
                        phi_params = []
                        for g in phi_optimizer.param_groups:
                            phi_params.extend([p for p in g.get("params", [])])
                        
                        if phi_params:
                            try:
                                phi_grad_norm = torch.nn.utils.clip_grad_norm_(
                                    phi_params, 
                                    _GRAD_CLIP_NORM,
                                    error_if_nonfinite=False
                                )
                            except RuntimeError as e:
                                if "non-finite" in str(e).lower():
                                    print(f"❌ [TRAIN] End of epoch {epoch}: Phi gradient clipping detected non-finite gradients")
                                    phi_grad_norm = torch.tensor(float('inf'))
                                else:
                                    raise
                            
                            if torch.isfinite(phi_grad_norm):
                                try:
                                    phi_optimizer.step()
                                    training_stats["asbn_updates"] += 1
                                    print(f"✅ [TRAIN] End of epoch {epoch}: Flushed phi optimizer (grad_norm={phi_grad_norm:.4f})")
                                except Exception as e:
                                    print(f"❌ [TRAIN] End of epoch {epoch}: Phi step failed: {e}")
                            else:
                                print(f"❌ [TRAIN] End of epoch {epoch}: Phi gradient norm is NaN/Inf: {phi_grad_norm}")

                optimizer.zero_grad(set_to_none=True)
                if phi_optimizer is not None:
                    phi_optimizer.zero_grad(set_to_none=True)
            except Exception as e:
                print(f"❌ [TRAIN] End of epoch {epoch}: Gradient flush failed: {type(e).__name__}: {e}")
                if _VERBOSE_LOGGING:
                    traceback.print_exc()
            finally:
                accumulated_steps = 0

        epoch_duration_min = (time.time() - epoch_start) / 60.0
        processed_batches = training_stats["batches_processed"] - training_stats["skipped_batches"]
        expected_updates = max(1, math.floor(processed_batches / max(1, accumulation_steps)))
        success_rate = 100.0 * training_stats["optimizer_updates"] / expected_updates if expected_updates > 0 else 0.0
        cluster_count = _get_cluster_count(model)

        avg_epoch_loss = float(np.mean(epoch_losses)) if epoch_losses else 0.0
        epoch_start_idx = max(0, len(training_stats["loss_component_breakdown"]) - len(epoch_losses))
        epoch_components = training_stats["loss_component_breakdown"][epoch_start_idx:]
        avg_translation_loss = float(np.mean([c["translation_loss"] for c in epoch_components])) if epoch_components else 0.0
        avg_asbn_loss = float(np.mean([c["asbn_loss"] for c in epoch_components])) if epoch_components else 0.0
        avg_dscd_loss = float(np.mean([c["dscd_loss"] for c in epoch_components])) if epoch_components else 0.0
        avg_domain_accuracy = float(np.mean([c["domain_accuracy"] for c in epoch_components])) if epoch_components else 0.0
        
        training_stats["epoch_losses"].append(avg_epoch_loss)

        print("\n" + "=" * 80)
        print(f"EPOCH {epoch}/{epochs} SUMMARY")
        print("=" * 80)
        print(f" Duration (min): {epoch_duration_min:.2f}")
        print(f" Optimizer updates: {training_stats['optimizer_updates']}")
        print(f" ASBN updates: {training_stats['asbn_updates']}")
        print(f" Batches: processed={processed_batches}, skipped={training_stats['skipped_batches']}")
        print(f" Label corruption fixes: {training_stats['label_corruption_fixes']}")
        print(f" Success rate: {success_rate:.1f}%")
        print(f" Clustered Token Types: {cluster_count}")
        print(f" Avg Epoch Loss: {avg_epoch_loss:.4f}")
        print(f"   - Translation: {avg_translation_loss:.4f}")
        print(f"   - ASBN: {avg_asbn_loss:.4f}")
        print(f"   - DSCD: {avg_dscd_loss:.4f}")
        print(f"   - Domain Accuracy: {avg_domain_accuracy:.2%}")
        print(f" NaN gradient events: {training_stats['nan_gradient_events']}")
        print(f" Extreme gradient events: {training_stats['extreme_gradient_events']}")

        if skip_reasons:
            print(f" Skip reasons: {dict(skip_reasons)}")

        try:
            val_fn = globals().get("comprehensive_epoch_validation", None)
            if callable(val_fn):
                print(f"🔍 [VALIDATION] Running end-of-epoch validation for epoch {epoch}...")
                validation_results = val_fn(
                    model=model,
                    tokenizer=tokenizer,
                    epoch=epoch,
                    global_step=global_step,
                    source_lang=_SOURCE_LANGUAGE,
                    target_lang=_TARGET_LANGUAGE,
                    max_length=_MAX_LENGTH,
                    device=_DEVICE,
                )
                if validation_results is not None and validation_results.get("validation_completed", False):
                    training_stats["epoch_validations"].append(validation_results)
                    training_stats["dscd_quality_history"].append(validation_results.get("dscd_quality_score", 0.0))
                    training_stats["asbn_domain_accuracy_history"].append(validation_results.get("asbn_domain_accuracy", 0.0))
                    training_stats["trg_explanation_history"].append(validation_results.get("trg_total_explanations", 0))
                    print(f"✅ [VALIDATION] End-of-epoch validation completed for epoch {epoch}")
                else:
                    print(f"⚠️  [VALIDATION] End-of-epoch validation failed for epoch {epoch}")
        except Exception as e:
            print(f"❌ [VALIDATION] End-of-epoch validation failed for epoch {epoch}: {type(e).__name__}: {e}")
            if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                traceback.print_exc()

        print("-" * 80)
        if skip_reasons:
            for k, v in skip_reasons.items():
                print(f"  {k}: {v}")
        print("=" * 80)

    try:
        checkpoint_dir = Path("/kaggle/working")
        if not checkpoint_dir.exists():
            checkpoint_dir = Path(".")
            print(f"[CHECKPOINT] /kaggle/working not found, using current directory: {checkpoint_dir.absolute()}")

        checkpoint_path = checkpoint_dir / "tatn_final.pt"
        core_model = model.module if hasattr(model, "module") else model

        dscd_state = {}
        try:
            if hasattr(core_model, "dscd"):
                dscd = core_model.dscd
                lock = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
                
                prototype_stores_data = {}
                corrupted_count = 0
                valid_count = 0
                
                if lock:
                    with lock:
                        stores = dict(getattr(dscd, "prototype_stores", {}) or {})
                else:
                    stores = dict(getattr(dscd, "prototype_stores", {}) or {})

                for token, store in stores.items():
                    try:
                        centroids = getattr(store, "centroids", None)
                        counts = getattr(store, "counts", None)
                        
                        if not isinstance(centroids, list) or not isinstance(counts, list):
                            corrupted_count += 1
                            continue
                        
                        if len(centroids) == 0 or len(centroids) != len(counts):
                            corrupted_count += 1
                            continue
                        
                        if any(c <= 0 for c in counts):
                            corrupted_count += 1
                            continue
                        
                        cent_list = []
                        valid_entry = True
                        for c in centroids:
                            try:
                                if isinstance(c, torch.Tensor):
                                    if not torch.isfinite(c).all():
                                        corrupted_count += 1
                                        valid_entry = False
                                        break
                                    cent_list.append(c.detach().cpu().tolist())
                                else:
                                    arr = np.asarray(c)
                                    if not np.isfinite(arr).all():
                                        corrupted_count += 1
                                        valid_entry = False
                                        break
                                    cent_list.append(arr.tolist())
                            except Exception:
                                corrupted_count += 1
                                valid_entry = False
                                break
                        
                        if valid_entry and cent_list and len(cent_list) == len(counts):
                            store_data = {"centroids": cent_list, "counts": [int(c) for c in counts]}
                            prototype_stores_data[str(token)] = store_data
                            valid_count += 1
                    except Exception:
                        corrupted_count += 1
                        continue

                dscd_state = {
                    "prototype_stores_data": prototype_stores_data,
                    "valid_stores": valid_count,
                    "corrupted_stores": corrupted_count,
                }
                
                if corrupted_count > 0:
                    print(f"⚠️  [CHECKPOINT] DSCD: {corrupted_count} corrupted stores skipped, {valid_count} valid stores saved")
        except Exception as e:
            print(f"❌ [CHECKPOINT] DSCD state extraction failed: {type(e).__name__}: {e}")
            dscd_state = {}

        checkpoint_data = {
            "epochs_trained": epochs,
            "global_steps": global_step,
            "final_train_loss": training_stats["epoch_losses"][-1] if training_stats["epoch_losses"] else 0.0,
            "final_translation_loss": avg_translation_loss,
            "final_asbn_loss": avg_asbn_loss,
            "final_dscd_loss": avg_dscd_loss,
            "final_domain_accuracy": avg_domain_accuracy,
            "model_state_dict": core_model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "phi_optimizer_state_dict": phi_optimizer.state_dict() if phi_optimizer is not None else None,
            "scaler_state_dict": scaler.state_dict() if scaler is not None else None,
            "training_stats": training_stats,
            "dscd_state": dscd_state,
            "config": {
                "SPAN_THRESHOLD": globals().get("SPAN_THRESHOLD", 0.15),
                "TAU_LOW": globals().get("TAU_LOW", 0.25),
                "UNCERTAINTY_THRESHOLD": globals().get("UNCERTAINTY_THRESHOLD", 0.25),
                "TRG_UNCERTAINTY_THRESHOLD": globals().get("TRG_UNCERTAINTY_THRESHOLD", 0.25),
                "LAMBDA_ASBN": globals().get("LAMBDA_ASBN", 0.05),
                "LAMBDA_DSCD": globals().get("LAMBDA_DSCD", 0.15),
                "TRG_TEMPERATURE": globals().get("TRG_TEMPERATURE", 1.0),
                "PERIODIC_DISCOVERY_FREQUENCY": _PERIODIC_DISCOVERY_FREQUENCY,
                "NUM_EPOCHS": epochs,
                "BATCH_SIZE": _BATCH_SIZE,
                "LEARNING_RATE": optimizer.param_groups[0]["lr"] if optimizer.param_groups else 0.0,
            },
        }

        torch.save(checkpoint_data, checkpoint_path)
        print(f"[CHECKPOINT] Saved to {checkpoint_path}")
        try:
            print(f"[CHECKPOINT] Size: {checkpoint_path.stat().st_size / (1024**2):.2f} MB")
        except Exception:
            pass

    except Exception as e:
        print(f"❌ [CHECKPOINT] Failed to save: {type(e).__name__}: {e}")
        if _VERBOSE_LOGGING:
            traceback.print_exc()

    print("=" * 80)
    print("FINAL TRAINING STATISTICS")
    print("=" * 80)

    processed_batches = training_stats["batches_processed"] - training_stats["skipped_batches"]
    expected_updates = max(1, math.floor(processed_batches / max(1, accumulation_steps)))
    success_rate = 100.0 * training_stats["optimizer_updates"] / expected_updates if expected_updates > 0 else 0.0

    print(f"[TRAIN] Success Rate: {success_rate:.1f}%")
    print(f"[TRAIN] Total Steps: {global_step}")
    print(f"[TRAIN] Optimizer Updates: {training_stats['optimizer_updates']}")
    print(f"[TRAIN] ASBN Updates: {training_stats['asbn_updates']}")
    print(f"[TRAIN] Label Corruption Fixes: {training_stats['label_corruption_fixes']}")
    print(f"[TRAIN] Gradient Checks: {len(training_stats['gradient_checks'])}")
    print(f"[TRAIN] NaN Gradient Events: {training_stats['nan_gradient_events']}")
    print(f"[TRAIN] Extreme Gradient Events: {training_stats['extreme_gradient_events']}")
    print(f"[TRAIN] Clustered Token Types: {_get_cluster_count(model)}")

    if training_stats["dscd_quality_history"]:
        print("[TRAIN] DSCD Quality Score Trend:")
        for i, score in enumerate(training_stats["dscd_quality_history"], 1):
            print(f"  Epoch {i}: {score:.1%}")

    if training_stats["asbn_domain_accuracy_history"]:
        print("[TRAIN] ASBN Domain Accuracy Trend:")
        for i, acc in enumerate(training_stats["asbn_domain_accuracy_history"], 1):
            print(f"  Epoch {i}: {acc:.1%}")

    if training_stats["trg_explanation_history"]:
        print("[TRAIN] TRG Explanation Count Trend:")
        for i, count in enumerate(training_stats["trg_explanation_history"], 1):
            print(f"  Epoch {i}: {count} explanations")

    print("=" * 80)
    return model


# Startup banner for Cell 7: announces that the training loop is defined and
# lists the safeguards it advertises. Emitted as one newline-joined print so
# the text written to stdout is the same as printing each line separately.
print(
    "\n".join(
        [
            "\n" + "=" * 80,
            "Cell 7: Training loop ready - NaN/Inf GRADIENT + LABEL VALIDATION FULLY HARDENED",
            "=" * 80,
            "Features:",
            "  - Label sanitization BEFORE forward (vocab bounds + pad=-100)",
            "  - Empty label batch detection (all -100 check)",
            "  - Pre-backward NaN loss detection",
            "  - Post-backward gradient validation with magnitude check",
            "  - Epsilon-protected loss scaling (div by accum_steps + 1e-9)",
            "  - RuntimeError catch for clip_grad_norm_ non-finite detection",
            "  - Loss component extraction with NaN/Inf validation",
            "  - Domain accuracy clamped to [0, 1]",
            "  - Aggressive gradient clearing on failure",
            "  - NaN/extreme gradient event tracking",
            "  - Label corruption fix counter",
            "=" * 80,
        ]
    )
)


In [None]:
# ==============================================================================
# CELL 8: INFERENCE PIPELINE WITH TOKEN VALIDATION - DATAPARALLEL COMPATIBLE
# ==============================================================================
import os
import time
import math
import torch
import traceback
from typing import List, Dict, Any, Optional
from collections import defaultdict
import threading
import gc

# ------------------------------------------------------------------------------
# Cell-local settings.
#
# Each `_NAME` below is taken from the notebook-level global `NAME` when that
# global exists and converts cleanly; otherwise the fallback in the matching
# `except` branch is used. `except Exception` is deliberate: it covers both a
# missing global (NameError) and a value that fails conversion.
# ------------------------------------------------------------------------------

# The two language codes are resolved independently so that a missing or
# unconvertible value for one never discards a valid value for the other.
try:
    _SOURCE_LANGUAGE = str(SOURCE_LANGUAGE)
except Exception:
    _SOURCE_LANGUAGE = "bn"

try:
    _TARGET_LANGUAGE = str(TARGET_LANGUAGE)
except Exception:
    _TARGET_LANGUAGE = "en"

try:
    _MAX_LENGTH = int(MAX_LENGTH)
except Exception:
    _MAX_LENGTH = 48

# DEVICE may already be a torch.device; anything else is passed through str()
# and parsed by torch.device. If that fails, pick CUDA when available.
try:
    _DEVICE = DEVICE
    if not isinstance(_DEVICE, torch.device):
        _DEVICE = torch.device(str(_DEVICE))
except Exception:
    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging / debugging switches (all default to off).
try:
    _VERBOSE_LOGGING = bool(VERBOSE_LOGGING)
except Exception:
    _VERBOSE_LOGGING = False

try:
    _DEBUG_DISCOVERY = bool(DEBUG_DISCOVERY)
except Exception:
    _DEBUG_DISCOVERY = False

try:
    _DEBUG_TIMING = bool(DEBUG_TIMING)
except Exception:
    _DEBUG_TIMING = False

# Without an explicit setting, multi-GPU is assumed only when more than one
# CUDA device is visible.
try:
    _USE_MULTI_GPU = bool(USE_MULTI_GPU)
except Exception:
    _USE_MULTI_GPU = torch.cuda.is_available() and torch.cuda.device_count() > 1

try:
    _SPAN_THRESHOLD = float(SPAN_THRESHOLD)
except Exception:
    _SPAN_THRESHOLD = 0.15

try:
    _TAU_LOW = float(TAU_LOW)
except Exception:
    _TAU_LOW = 0.25

# Threshold fallbacks chain: UNCERTAINTY falls back to TAU_LOW, and
# TRG_UNCERTAINTY falls back to UNCERTAINTY. Statement order matters here.
try:
    _UNCERTAINTY_THRESHOLD = float(UNCERTAINTY_THRESHOLD)
except Exception:
    _UNCERTAINTY_THRESHOLD = _TAU_LOW

try:
    _TRG_UNCERTAINTY_THRESHOLD = float(TRG_UNCERTAINTY_THRESHOLD)
except Exception:
    _TRG_UNCERTAINTY_THRESHOLD = _UNCERTAINTY_THRESHOLD

try:
    _MAX_EXPLANATIONS_PER_SENTENCE = int(MAX_EXPLANATIONS_PER_SENTENCE)
except Exception:
    _MAX_EXPLANATIONS_PER_SENTENCE = 10

# Reference homograph words, stored lower-cased. When the configured list is
# unavailable, a built-in set of Bengali words is used instead.
try:
    _HOMOGRAPH_REFERENCE_LIST = set(str(w).lower() for w in HOMOGRAPH_REFERENCE_LIST_BN)
except Exception:
    _HOMOGRAPH_REFERENCE_LIST = {
        "কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার",
        "তারা", "পানি", "দল", "বাজার", "নাম", "কথা", "বই", "ঘর", "মন", "হাত",
    }
    _HOMOGRAPH_REFERENCE_LIST = set(str(w).lower() for w in _HOMOGRAPH_REFERENCE_LIST)

# NOTE(review): 128022 / 128112 are presumably the M2M100 English language
# token id and vocabulary size - confirm against the tokenizer actually loaded.
try:
    _M2M100_EN_TOKEN_ID = int(M2M100_EN_TOKEN_ID)
except Exception:
    _M2M100_EN_TOKEN_ID = 128022

try:
    _VOCAB_SIZE = int(VOCAB_SIZE)
except Exception:
    _VOCAB_SIZE = 128112

# Punctuation characters given special treatment by the subword handling in
# this cell (consumers are outside this block).
_SUBWORD_PUNCT_SET = {".", ",", "!", "?", "-"}


def _get_store_size(store) -> int:
    try:
        if hasattr(store, "size") and callable(getattr(store, "size")):
            return int(store.size())
        centroids = getattr(store, "centroids", None)
        if centroids is not None:
            if isinstance(centroids, torch.Tensor):
                return int(centroids.size(0))
            elif isinstance(centroids, list):
                return len(centroids)
        return 0
    except Exception:
        return 0


def get_dscd_homographs(model: torch.nn.Module) -> set:
    """Collect the words for which the model's DSCD component holds 2+ prototypes.

    The model is unwrapped through ``.module`` when present. The
    ``prototype_stores`` mapping is copied while holding ``buffer_lock`` (or
    ``clustering_lock``) if the component exposes one. Store keys are stripped
    of sub-word markers and lower-cased; when several keys collapse to the same
    word, the largest prototype count wins. Returns an empty set when there is
    no ``dscd`` attribute or on any error.
    """
    try:
        core = getattr(model, "module", model)
        dscd = getattr(core, "dscd", None)
        if dscd is None:
            return set()

        lock = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
        if lock:
            with lock:
                snapshot = dict(getattr(dscd, "prototype_stores", {}))
        else:
            snapshot = dict(getattr(dscd, "prototype_stores", {}))

        best_count: Dict[str, int] = {}
        for raw_key, proto_store in snapshot.items():
            try:
                count = _get_store_size(proto_store)
                word = str(raw_key)
                for marker in ("▁", "Ġ", "##", "@@", "</w>"):
                    word = word.replace(marker, "")
                word = word.strip().lower()
                if word:
                    best_count[word] = max(best_count.get(word, 0), count)
            except Exception:
                # A single malformed store must not abort the whole scan.
                continue

        return {word for word, count in best_count.items() if count >= 2}
    except Exception:
        return set()


def _group_tokens_into_words(tokens) -> Dict[int, str]:
    """Map token positions to the cleaned word their sub-word pieces spell out.

    A new word starts at every token beginning with "▁" and at the first token
    after a boundary. Empty tokens and the literal special tokens
    ``<s>``, ``</s>``, ``<pad>``, ``<unk>`` are boundaries: they close the word
    in progress and never receive a mapping themselves.
    """
    boundary_tokens = ("<s>", "</s>", "<pad>", "<unk>")
    word_map: Dict[int, str] = {}
    pieces: List[str] = []  # sub-word tokens of the word currently being assembled
    start = 0               # position of the first token in ``pieces``

    def close_word(end: int) -> None:
        # Assign the assembled word to every position in [start, end).
        word = "".join(pieces).replace("▁", "").strip()
        if word:
            for j in range(start, end):
                word_map[j] = word
        pieces.clear()

    for i, token in enumerate(tokens):
        if not token or token in boundary_tokens:
            # Closing here keeps the later flush from spilling the word over
            # special-token positions such as </s> or trailing <pad>.
            close_word(i)
            continue
        if token.startswith("▁") or not pieces:
            close_word(i)
            start = i
        pieces.append(token)
    close_word(len(tokens))
    return word_map


def build_token_to_word_map(tokenizer, input_ids: torch.Tensor) -> List[Dict[int, str]]:
    """Build, for every row of ``input_ids``, a map from token position to word.

    Token ids are clamped into ``[0, _VOCAB_SIZE - 1]`` and converted with
    ``tokenizer.convert_ids_to_tokens``; if that fails each id is decoded
    individually (falling back to the id's string form). If the row cannot be
    processed at all it is treated as all ``<unk>`` and yields an empty map.

    Args:
        tokenizer: object providing ``convert_ids_to_tokens`` and/or ``decode``.
        input_ids: 2-D tensor indexed as ``input_ids[b]`` per batch row.

    Returns:
        One ``{position: word}`` dict per row. Positions holding special or
        empty tokens are absent from the dict.
    """
    batch_word_maps: List[Dict[int, str]] = []
    for b in range(input_ids.size(0)):
        try:
            token_ids = [max(0, min(int(tid), _VOCAB_SIZE - 1)) for tid in input_ids[b].tolist()]
            try:
                tokens = tokenizer.convert_ids_to_tokens(token_ids)
            except Exception:
                tokens = []
                for tid in token_ids:
                    try:
                        tokens.append(tokenizer.decode([tid], skip_special_tokens=True))
                    except Exception:
                        tokens.append(str(tid))
        except Exception:
            tokens = ["<unk>"] * input_ids.size(1)

        word_map = _group_tokens_into_words(tokens)
        batch_word_maps.append({int(k): str(v) for k, v in word_map.items() if v})
    return batch_word_maps


class InferenceStatistics:
    """Thread-safe accumulator of translation / explanation statistics.

    All counters are guarded by ``self.lock`` (a non-reentrant
    ``threading.Lock``), so public methods must not call each other while
    holding it.
    """

    # Sub-word markers stripped from explained tokens before they are counted.
    _SUBWORD_MARKERS = ("▁", "Ġ", "##", "@@", "</w>")

    def __init__(self):
        self.lock = threading.Lock()
        self.reset()

    def reset(self):
        """Zero every counter and drop all per-token history."""
        with self.lock:
            self.total_inferences = 0
            self.successful_translations = 0
            self.failed_translations = 0
            self.total_explanations = 0
            self.high_confidence_explanations = 0
            self.low_confidence_explanations = 0
            self.total_confidence = 0.0
            self.dscd_homographs_explained = set()
            self.reference_homographs_explained = set()
            self.sum_span = 0.0
            self.sum_uncertainty = 0.0
            self.dscd_empty_warnings = 0
            self.token_counts = defaultdict(int)
            self.token_confidences = defaultdict(list)

    @staticmethod
    def _finite_or(value: Any, fallback: Optional[float]) -> Optional[float]:
        """Return ``float(value)`` if it is a finite number, else ``fallback``."""
        try:
            number = float(value)
        except (TypeError, ValueError, OverflowError):
            return fallback
        return number if math.isfinite(number) else fallback

    @classmethod
    def _normalize_word(cls, raw: Any) -> str:
        """Strip sub-word markers and surrounding whitespace, then lower-case."""
        word = str(raw)
        for marker in cls._SUBWORD_MARKERS:
            word = word.replace(marker, "")
        return word.strip().lower()

    def record_inference(self, result: Dict[str, Any], dscd_homographs: Optional[set] = None):
        """Fold one inference result into the running statistics.

        Args:
            result: dict with optional ``"translation"`` and ``"explanations"``
                entries; non-dict values are ignored entirely.
            dscd_homographs: words discovered by DSCD; explained words found in
                this set are remembered in ``dscd_homographs_explained``.

        Every entry of ``explanations`` counts towards ``total_explanations``,
        but only dict entries contribute to the other counters. Each
        explanation is fully parsed before any counter is touched, so a
        malformed field can never leave it half-recorded: a missing,
        non-numeric or non-finite ``confidence`` counts as 0.5, while such a
        ``span`` / ``uncertainty`` is simply left out of its sum.
        """
        if not isinstance(result, dict):
            return

        with self.lock:
            self.total_inferences += 1
            translation = result.get("translation", "")
            if translation and translation != "ERROR DURING TRANSLATION":
                self.successful_translations += 1
            else:
                self.failed_translations += 1

            explanations = result.get("explanations", [])
            if not isinstance(explanations, list):
                explanations = []

            self.total_explanations += len(explanations)

            for exp in explanations:
                if not isinstance(exp, dict):
                    continue

                # Phase 1: parse and classify without mutating any state.
                try:
                    conf = self._finite_or(exp.get("confidence", 0.5), 0.5)
                    span_val = self._finite_or(exp.get("span", 0.0), None)
                    unc_val = self._finite_or(exp.get("uncertainty", 0.0), None)
                    clean_word = self._normalize_word(
                        exp.get("ambiguous_word", exp.get("token", ""))
                    )
                    is_dscd = bool(clean_word and dscd_homographs and clean_word in dscd_homographs)
                    is_reference = bool(clean_word) and clean_word in _HOMOGRAPH_REFERENCE_LIST
                except Exception:
                    continue

                # Phase 2: commit. Nothing below is expected to raise.
                self.total_confidence += conf
                if conf >= 0.65:
                    self.high_confidence_explanations += 1
                elif conf < 0.4:
                    self.low_confidence_explanations += 1

                if span_val is not None:
                    self.sum_span += span_val
                if unc_val is not None:
                    self.sum_uncertainty += unc_val

                if clean_word:
                    self.token_counts[clean_word] += 1
                    self.token_confidences[clean_word].append(conf)
                    if is_dscd:
                        self.dscd_homographs_explained.add(clean_word)
                    if is_reference:
                        self.reference_homographs_explained.add(clean_word)

    def get_summary(self) -> Dict[str, Any]:
        """Return a snapshot dict of totals, rates and averages.

        Rates and averages are divided by ``max(count, 1)`` so an empty
        accumulator yields zeros instead of raising ``ZeroDivisionError``.
        """
        with self.lock:
            total_exp = max(self.total_explanations, 1)
            total_inf = max(self.total_inferences, 1)
            unique_tokens = len(self.token_counts)
            return {
                "total_inferences": self.total_inferences,
                "successful_translations": self.successful_translations,
                "failed_translations": self.failed_translations,
                "success_rate": self.successful_translations / total_inf,
                "total_explanations": self.total_explanations,
                "explanations_per_inference": self.total_explanations / total_inf,
                "high_confidence_rate": self.high_confidence_explanations / total_exp,
                "low_confidence_rate": self.low_confidence_explanations / total_exp,
                "avg_confidence": self.total_confidence / total_exp,
                "avg_span": self.sum_span / total_exp,
                "avg_uncertainty": self.sum_uncertainty / total_exp,
                "dscd_homographs_explained": list(self.dscd_homographs_explained),
                "reference_homographs_explained": list(self.reference_homographs_explained),
                "dscd_empty_warnings": self.dscd_empty_warnings,
                "unique_tokens_explained": unique_tokens,
                "diversity_ratio": unique_tokens / total_exp,
            }

    def print_summary(self):
        """Print the summary returned by :meth:`get_summary` to stdout."""
        summary = self.get_summary()
        print("=" * 80)
        print("INFERENCE STATISTICS SUMMARY")
        print("=" * 80)
        print(f"Total inferences: {summary['total_inferences']}")
        print(f"Success rate: {summary['success_rate']:.1%}")
        print(f"Total explanations: {summary['total_explanations']}")
        print(f"Explanations per inference: {summary['explanations_per_inference']:.2f}")
        print(f"Unique tokens explained: {summary['unique_tokens_explained']}")
        print(f"Diversity ratio: {summary['diversity_ratio']:.2f}")
        print(f"Avg confidence: {summary['avg_confidence']:.3f}")
        print(f"High confidence rate: {summary['high_confidence_rate']:.1%}")
        print(f"Avg span: {summary['avg_span']:.3f}")
        print(f"Avg uncertainty: {summary['avg_uncertainty']:.3f}")
        if summary["dscd_homographs_explained"]:
            print(f"DSCD homographs explained: {len(summary['dscd_homographs_explained'])}")
            print(f"  {', '.join(summary['dscd_homographs_explained'])}")
        if summary["reference_homographs_explained"]:
            print(f"Reference homographs explained: {len(summary['reference_homographs_explained'])}")
            print(f"  {', '.join(summary['reference_homographs_explained'])}")
        if summary["dscd_empty_warnings"] > 0:
            print(f"DSCD empty warnings: {summary['dscd_empty_warnings']}")
        print("=" * 80)


# Module-level statistics accumulator shared by translate_with_explanations()
# and demonstrate_system().
INFERENCE_STATS = InferenceStatistics()


def to_device(batch_enc: Any, device: torch.device):
    """Move ``batch_enc`` and any tensors nested inside it to ``device``.

    Anything exposing ``.to`` (tensors, tokenizer batch encodings, modules) is
    moved through that method. Plain dicts, lists and tuples are walked
    recursively at any depth; tuples stay tuples (named tuples keep their
    type). Values that cannot be moved are returned unchanged, so this never
    raises because of a single bad entry.

    Args:
        batch_enc: value or (nested) container to move.
        device: target device.

    Returns:
        The moved value. Dicts and lists are returned as new containers; a
        tuple whose items did not change is returned as-is.
    """
    try:
        if hasattr(batch_enc, "to"):
            return batch_enc.to(device)
    except Exception:
        # Best effort: fall through and try to treat the value as a container.
        pass

    if isinstance(batch_enc, dict):
        return {key: to_device(value, device) for key, value in batch_enc.items()}

    if isinstance(batch_enc, list):
        return [to_device(item, device) for item in batch_enc]

    if isinstance(batch_enc, tuple):
        moved = [to_device(item, device) for item in batch_enc]
        if all(new is old for new, old in zip(moved, batch_enc)):
            return batch_enc  # nothing moved; keep the exact original object/type
        if hasattr(batch_enc, "_fields"):
            return type(batch_enc)(*moved)  # named tuple
        return tuple(moved)

    return batch_enc


def extract_dscd_outputs(raw_out: Any) -> Dict[str, Any]:
    """Locate the DSCD output dict inside a model forward result.

    Resolution order for a dict input:
      1. a dict stored under ``"dscd_outputs"``, then under ``"dscd"``;
      2. the input itself if it already carries ``"explanations"``,
         ``"proto_probs"`` or ``"h_augmented"``;
      3. a dict stored under ``"dscd_out"``;
      4. otherwise the input itself.

    For a list/tuple the first dict element is resolved recursively. Anything
    else (including ``None``) yields ``{}``.
    """
    if isinstance(raw_out, dict):
        for wrapper_key in ("dscd_outputs", "dscd"):
            nested = raw_out.get(wrapper_key)
            if isinstance(nested, dict):
                return nested

        if any(marker in raw_out for marker in ("explanations", "proto_probs", "h_augmented")):
            return raw_out

        nested = raw_out.get("dscd_out")
        return nested if isinstance(nested, dict) else raw_out

    if isinstance(raw_out, (list, tuple)):
        first_dict = next((item for item in raw_out if isinstance(item, dict)), None)
        if first_dict is not None:
            return extract_dscd_outputs(first_dict)

    return {}


def get_explanations_list(dscd: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
    """Return the per-sentence explanation lists stored in a DSCD output dict.

    ``"explanations"`` is preferred; when it is missing or ``None`` the first
    present key among ``"explanations_per_sentence"``, ``"trg_explanations"``
    and ``"exps"`` is used instead. A flat list of dicts is wrapped into a
    single-sentence batch, a list of lists is returned unchanged, and every
    other shape (including an empty list) yields ``[]``.
    """
    if not dscd:
        return []

    found = dscd.get("explanations", None)
    if found is None:
        fallback_keys = ("explanations_per_sentence", "trg_explanations", "exps")
        # The first key that is present wins, even if its value is None.
        found = next((dscd[key] for key in fallback_keys if key in dscd), None)

    if not isinstance(found, list) or len(found) == 0:
        return []

    head = found[0]
    if isinstance(head, dict):
        return [found]
    if isinstance(head, list):
        return found
    return []


def is_subword_token(token: str) -> bool:
    """Tell whether ``token`` should be skipped when building explanations.

    Returns True for empty / whitespace-only tokens, for a single character
    listed in ``_SUBWORD_PUNCT_SET`` and for all-digit tokens. Tokens that
    begin with a word-boundary or sub-word marker ("▁", "Ġ", "##", "@@"), and
    every other token, return False.
    """
    stripped = token.strip() if token else ""
    if not stripped:
        return True
    if stripped.startswith(("▁", "Ġ", "##", "@@")):
        return False
    is_lone_punct = len(stripped) == 1 and stripped in _SUBWORD_PUNCT_SET
    return is_lone_punct or stripped.isdigit()


def should_filter_explanation(expl: Dict[str, Any], span_th: float, unc_th: float) -> bool:
    """Return True when an explanation should be dropped.

    An explanation is kept only if its word passes ``is_subword_token`` and its
    finite ``span`` is below ``span_th`` while its finite ``uncertainty`` is at
    least ``unc_th``. Any error while reading the explanation also drops it.
    """
    try:
        word = expl.get("ambiguous_word", expl.get("token", ""))
        if is_subword_token(str(word)):
            return True

        span_val = float(expl.get("span", 0.0))
        unc_val = float(expl.get("uncertainty", 0.0))
        if not (math.isfinite(span_val) and math.isfinite(unc_val)):
            return True

        keep = (span_val < span_th) and (unc_val >= unc_th)
        return not keep
    except Exception:
        return True


def force_english_bos(tokenizer, mbart_model) -> Optional[int]:
    """Resolve the target-language token id and install it on the model config.

    Candidate language codes (``_TARGET_LANGUAGE``, ``"en_XX"``, ``"en"``,
    ``"eng"``) are looked up first through ``tokenizer.get_lang_id`` and, if
    that yields nothing, through ``tokenizer.lang_code_to_id``. When neither
    lookup succeeds ``_M2M100_EN_TOKEN_ID`` is used.

    Side effect: when ``mbart_model`` has a ``config``, both
    ``forced_bos_token_id`` and ``decoder_start_token_id`` are overwritten
    with the resolved id, and they are NOT restored afterwards.

    Returns:
        The resolved token id.
    """
    forced_id = None
    try:
        candidates = (_TARGET_LANGUAGE, "en_XX", "en", "eng")

        get_lang_id = getattr(tokenizer, "get_lang_id", None)
        if callable(get_lang_id):
            for code in candidates:
                try:
                    lid = get_lang_id(code)
                    if lid is not None:
                        forced_id = int(lid)
                        break
                except Exception:
                    continue

        # Fall back to the lookup table even when get_lang_id exists but did
        # not recognise any candidate code.
        if forced_id is None:
            table = getattr(tokenizer, "lang_code_to_id", None)
            if table is not None and hasattr(table, "get"):
                for code in candidates:
                    try:
                        lid = table.get(code, None)
                        if lid is not None:
                            forced_id = int(lid)
                            break
                    except Exception:
                        continue
    except Exception:
        forced_id = None

    if forced_id is None:
        forced_id = _M2M100_EN_TOKEN_ID

    try:
        if forced_id is not None and hasattr(mbart_model, "config"):
            mbart_model.config.forced_bos_token_id = int(forced_id)
            # NOTE(review): the decoder start token is overridden as well —
            # confirm this matches the decoder input format used in training.
            mbart_model.config.decoder_start_token_id = int(forced_id)
    except Exception:
        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            print("[INF] Could not set forced BOS on mbart config")
    return forced_id


def safe_generate(mbart, input_ids=None, encoder_outputs=None, attention_mask=None, max_length=64, num_beams=4, **kwargs):
    """Run ``mbart.generate`` with a single out-of-memory retry.

    Generation is driven from ``encoder_outputs`` when given, otherwise from
    ``input_ids``. ``early_stopping=True`` is always passed and extra
    ``kwargs`` are forwarded unchanged. Tensor results are clamped into
    ``[0, _VOCAB_SIZE - 1]``.

    If the first attempt raises a ``RuntimeError`` whose message contains
    "out of memory", the CUDA cache is emptied and generation is retried once
    with greedy decoding (``num_beams=1``) and ``max_length`` capped at 48.

    Raises:
        RuntimeError: any non-OOM runtime error from the first attempt, or any
            error raised by the retry.
    """

    def _generate(length, beams):
        common = dict(
            attention_mask=attention_mask,
            max_length=length,
            num_beams=beams,
            early_stopping=True,
            **kwargs,
        )
        if encoder_outputs is not None:
            generated = mbart.generate(encoder_outputs=encoder_outputs, **common)
        else:
            generated = mbart.generate(input_ids, **common)

        if generated is not None and isinstance(generated, torch.Tensor):
            generated = torch.clamp(generated, min=0, max=_VOCAB_SIZE - 1)
        return generated

    try:
        return _generate(max_length, num_beams)
    except RuntimeError as e:
        if "out of memory" not in str(e).lower():
            raise

    # The retry deliberately happens AFTER leaving the except block: while the
    # handler is active the exception still references the failed frame (and
    # the tensors it allocated), which would keep that memory from being freed.
    if _DEBUG_DISCOVERY:
        print("[INF] OOM during generation, retrying with smaller beams...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return _generate(min(max_length, 48), 1)


def translate_with_explanations(
    model,
    tokenizer,
    input_sentence: str,
    device: Optional[torch.device] = None,
    span_threshold: Optional[float] = None,
    uncertainty_threshold: Optional[float] = None,
    track_stats: bool = True,
) -> Dict[str, Any]:
    """Translate one sentence and collect ambiguity explanations for it.

    Pipeline: tokenize -> run the ``mbart`` encoder -> run the model's DSCD
    forward pass to obtain (optionally) sense-augmented encoder states ->
    beam-search decode from those states -> build explanations through the
    model's ``trg_system`` (or from the DSCD output) -> filter them with
    ``should_filter_explanation``.

    Args:
        model: wrapper exposing ``.mbart`` (required) and optionally ``.dscd``
            and ``.trg_system``; unwrapped through ``.module`` when present.
        tokenizer: tokenizer matching ``model.mbart``.
        input_sentence: source sentence.
        device: target device; defaults to ``_DEVICE``.
        span_threshold: overrides ``_SPAN_THRESHOLD`` for the final filter.
        uncertainty_threshold: overrides ``_UNCERTAINTY_THRESHOLD`` for the
            final filter.
        track_stats: when True the outcome is recorded in ``INFERENCE_STATS``.

    Returns:
        Dict with keys ``input_sentence``, ``translation``,
        ``ambiguous_words_detected``, ``explanations``, ``quality_metrics`` and
        ``dscd_validated``. On failure ``translation`` is
        ``"ERROR DURING TRANSLATION"`` and an ``error`` message is added;
        exceptions raised inside the main pipeline are not propagated.

    Side effects: sets ``tokenizer.src_lang``, switches ``model`` to eval mode,
    lets ``force_english_bos`` rewrite the mbart config, and empties the CUDA
    cache / runs ``gc.collect()`` on exit.
    """
    if device is None:
        device = _DEVICE
    elif not isinstance(device, torch.device):
        try:
            device = torch.device(str(device))
        except Exception:
            device = _DEVICE

    span_th = _SPAN_THRESHOLD if span_threshold is None else float(span_threshold)
    unc_th = _UNCERTAINTY_THRESHOLD if uncertainty_threshold is None else float(uncertainty_threshold)

    if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
        print(f"[INF] Starting inference - input: {input_sentence[:120]}")

    try:
        tokenizer.src_lang = _SOURCE_LANGUAGE
    except Exception:
        pass

    dscd_homographs = set()
    try:
        dscd_homographs = get_dscd_homographs(model)
    except Exception:
        dscd_homographs = set()

    encoder_hidden = None
    encoder_hidden_adjusted = None

    try:
        enc = tokenizer(
            input_sentence,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=_MAX_LENGTH,
        )
        enc = to_device(enc, device)

        # Keep ids inside the embedding table to avoid out-of-range lookups.
        if "input_ids" in enc and isinstance(enc["input_ids"], torch.Tensor):
            enc["input_ids"] = torch.clamp(enc["input_ids"], min=0, max=_VOCAB_SIZE - 1)

        model.eval()
        core = model.module if hasattr(model, "module") else model
        src_texts = [input_sentence]
        dscd_validated = False

        # Sanity-check the DSCD prototype stores (purely informational).
        try:
            dscd = getattr(core, "dscd", None)
            if dscd is not None:
                lock = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
                if lock:
                    with lock:
                        stores = getattr(dscd, "prototype_stores", {})
                        num_stores = len(stores)
                        multi_sense = sum(1 for s in stores.values() if _get_store_size(s) >= 2)
                else:
                    stores = getattr(dscd, "prototype_stores", {})
                    num_stores = len(stores)
                    multi_sense = sum(1 for s in stores.values() if _get_store_size(s) >= 2)

                if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                    print(f"[INF] DSCD: types={num_stores}, multi_sense={multi_sense}, discovered={len(dscd_homographs)}")
                if num_stores == 0:
                    if track_stats:
                        # Every other counter update happens under the stats
                        # lock; do the same here to avoid a lost update.
                        with INFERENCE_STATS.lock:
                            INFERENCE_STATS.dscd_empty_warnings += 1
                else:
                    dscd_validated = True
        except Exception:
            pass

        with torch.inference_mode():
            raw_dscd_out = {}
            mbart = getattr(core, "mbart", None)
            if mbart is None:
                raise RuntimeError("Model backend missing .mbart")

            encoder_outputs_raw = mbart.model.encoder(
                input_ids=enc.get("input_ids"),
                attention_mask=enc.get("attention_mask"),
            )

            if hasattr(encoder_outputs_raw, "last_hidden_state"):
                encoder_hidden = encoder_outputs_raw.last_hidden_state
            elif isinstance(encoder_outputs_raw, tuple) and len(encoder_outputs_raw) > 0 and isinstance(
                encoder_outputs_raw[0], torch.Tensor
            ):
                encoder_hidden = encoder_outputs_raw[0]
            else:
                encoder_hidden = encoder_outputs_raw

            if not isinstance(encoder_hidden, torch.Tensor):
                raise RuntimeError(f"Invalid encoder hidden state type: {type(encoder_hidden)}")

            if encoder_hidden.dim() != 3:
                raise RuntimeError(f"Invalid encoder hidden state shape: {encoder_hidden.shape}")

            if not torch.isfinite(encoder_hidden).all():
                if _VERBOSE_LOGGING:
                    print(f"[INF] WARNING: Encoder hidden state contains NaN/Inf values")
                encoder_hidden = torch.nan_to_num(encoder_hidden, nan=0.0, posinf=1.0, neginf=-1.0)

            # DSCD forward pass; failures degrade to "no DSCD output".
            try:
                if hasattr(core, "forward_with_explanations"):
                    try:
                        raw_dscd_out = core.forward_with_explanations(
                            input_ids=enc.get("input_ids"),
                            attention_mask=enc.get("attention_mask"),
                            src_texts=src_texts,
                            use_dscd=True,
                            use_asbn=False,
                        )
                    except TypeError:
                        # Signature without the keyword arguments above.
                        raw_dscd_out = core.forward_with_explanations(
                            enc.get("input_ids"),
                            enc.get("attention_mask"),
                            src_texts,
                        )
                else:
                    out = core.forward(
                        input_ids=enc.get("input_ids"),
                        attention_mask=enc.get("attention_mask"),
                        src_texts=src_texts,
                        labels=None,
                        use_dscd=True,
                        use_asbn=False,
                    )
                    raw_dscd_out = extract_dscd_outputs(out) if isinstance(out, dict) else {}
            except Exception as e:
                if _VERBOSE_LOGGING:
                    print(f"[INF] forward_with_explanations failed: {e}")
                raw_dscd_out = {}

            dscd_out = extract_dscd_outputs(raw_dscd_out)

            # Prefer sense-augmented encoder states when the model provides them.
            if isinstance(raw_dscd_out, dict) and "sense_augmented_embeddings" in raw_dscd_out:
                encoder_hidden_adjusted = raw_dscd_out["sense_augmented_embeddings"]
            elif "h_augmented" in dscd_out:
                encoder_hidden_adjusted = dscd_out["h_augmented"]
            else:
                encoder_hidden_adjusted = encoder_hidden

            # Validate the adjusted states; fall back to the raw encoder output
            # whenever they are unusable.
            if isinstance(encoder_hidden_adjusted, torch.Tensor):
                if encoder_hidden_adjusted.dim() != 3:
                    if _VERBOSE_LOGGING:
                        print(f"[INF] Adjusted embeddings wrong dims: {encoder_hidden_adjusted.shape}")
                    encoder_hidden_adjusted = encoder_hidden
                elif encoder_hidden_adjusted.shape != encoder_hidden.shape:
                    if _VERBOSE_LOGGING:
                        print(f"[INF] Adjusted embeddings shape mismatch: {encoder_hidden_adjusted.shape} vs {encoder_hidden.shape}")
                    encoder_hidden_adjusted = encoder_hidden
                elif not torch.isfinite(encoder_hidden_adjusted).all():
                    if _VERBOSE_LOGGING:
                        print(f"[INF] Adjusted embeddings contain NaN/Inf")
                    encoder_hidden_adjusted = torch.nan_to_num(encoder_hidden_adjusted, nan=0.0, posinf=1.0, neginf=-1.0)
            else:
                encoder_hidden_adjusted = encoder_hidden

            forced_id = force_english_bos(tokenizer, mbart)
            orig_use_cache = getattr(mbart.config, "use_cache", None) if hasattr(mbart, "config") else None
            try:
                if hasattr(mbart, "config"):
                    mbart.config.use_cache = True
            except Exception:
                pass

            try:
                if isinstance(encoder_hidden_adjusted, torch.Tensor):
                    encoder_hidden_adjusted = encoder_hidden_adjusted.to(device)
                    from transformers.modeling_outputs import BaseModelOutput

                    encoder_outputs_for_decoder = BaseModelOutput(last_hidden_state=encoder_hidden_adjusted)
                    generated = safe_generate(
                        mbart,
                        encoder_outputs=encoder_outputs_for_decoder,
                        attention_mask=enc.get("attention_mask"),
                        max_length=min(_MAX_LENGTH, 64),
                        num_beams=4,
                        pad_token_id=getattr(tokenizer, "pad_token_id", None),
                        forced_bos_token_id=forced_id,
                        repetition_penalty=2.5,
                        no_repeat_ngram_size=2,
                        length_penalty=1.0,
                        do_sample=False,
                    )
                else:
                    generated = safe_generate(
                        mbart,
                        input_ids=enc.get("input_ids"),
                        attention_mask=enc.get("attention_mask"),
                        max_length=min(_MAX_LENGTH, 64),
                        num_beams=4,
                        pad_token_id=getattr(tokenizer, "pad_token_id", None),
                        forced_bos_token_id=forced_id,
                        repetition_penalty=2.5,
                        no_repeat_ngram_size=2,
                        length_penalty=1.0,
                        do_sample=False,
                    )

                if generated is not None and isinstance(generated, torch.Tensor):
                    generated = torch.clamp(generated, min=0, max=_VOCAB_SIZE - 1)

                translation = (
                    tokenizer.decode(generated[0], skip_special_tokens=True) if generated is not None and len(generated) > 0 else ""
                )
            finally:
                # Restore the caller's use_cache setting.
                try:
                    if hasattr(mbart, "config") and orig_use_cache is not None:
                        mbart.config.use_cache = orig_use_cache
                except Exception:
                    pass

            sentence_explanations = []
            try:
                trg = getattr(core, "trg_system", None)
                if trg and hasattr(trg, "process_sentence_for_explanations"):
                    token_word_map = build_token_to_word_map(tokenizer, enc.get("input_ids"))
                    token_word_map_single = token_word_map[0] if token_word_map else {}

                    # Computed before the try below so the fallback branch can
                    # never reference an unbound name.
                    token_id_list = torch.clamp(enc.get("input_ids"), min=0, max=_VOCAB_SIZE - 1)[0].tolist()
                    try:
                        tokens_batch = tokenizer.convert_ids_to_tokens(token_id_list)
                    except Exception:
                        tokens_batch = [
                            tokenizer.decode([i], skip_special_tokens=True)
                            for i in token_id_list
                        ]

                    dscd_for_trg = {}
                    for k in ("uncertainties", "span_preds", "gates", "proto_probs"):
                        if k in dscd_out:
                            dscd_for_trg[k] = dscd_out[k]

                    trg_result = trg.process_sentence_for_explanations(
                        tokens=tokens_batch,
                        dscd_outputs=dscd_for_trg,
                        token_word_map=token_word_map_single,
                        uncertainty_threshold=_TRG_UNCERTAINTY_THRESHOLD,
                        span_threshold=_SPAN_THRESHOLD,
                        decoder_attention=None,
                        max_explanations=_MAX_EXPLANATIONS_PER_SENTENCE,
                    )
                    sentence_explanations = trg_result if isinstance(trg_result, list) else []
                else:
                    explanations_list = get_explanations_list(dscd_out)
                    sentence_explanations = explanations_list[0] if explanations_list else []
            except Exception as e:
                if _VERBOSE_LOGGING:
                    print(f"[INF] TRG explanation generation failed: {e}")
                sentence_explanations = []

            def is_real_ambiguity(e: Dict[str, Any]) -> bool:
                # Low span combined with high uncertainty marks a real ambiguity.
                try:
                    s = float(e.get("span", 0.0))
                    u = float(e.get("uncertainty", 0.0))
                    if not math.isfinite(s) or not math.isfinite(u):
                        return False
                    return (s < span_th) and (u >= unc_th)
                except Exception:
                    return False

            real_amb_count = 0
            out_explanations: List[Dict[str, Any]] = []
            confidences: List[float] = []
            spans: List[float] = []
            uncertainties: List[float] = []

            if isinstance(sentence_explanations, list):
                for ex in sentence_explanations:
                    try:
                        word = ex.get("ambiguous_word", ex.get("token", ""))
                        if not isinstance(word, str):
                            word = ""
                        clean_word = (
                            word.replace("▁", "")
                            .replace("Ġ", "")
                            .replace("##", "")
                            .replace("@@", "")
                            .replace("</w>", "")
                            .strip()
                        )
                        if clean_word:
                            ex["ambiguous_word"] = clean_word

                        if should_filter_explanation(ex, span_th, unc_th):
                            continue

                        is_real = is_real_ambiguity(ex)
                        if is_real:
                            real_amb_count += 1

                        s = float(ex.get("span", 0.0))
                        u = float(ex.get("uncertainty", 0.0))

                        if not math.isfinite(s):
                            s = 0.0
                        if not math.isfinite(u):
                            u = 0.5

                        confidence = ex.get("confidence", None)
                        if confidence is None:
                            # Derive a confidence in [0, 1] when none is supplied.
                            confidence = max(0.0, min(1.0, (s * (1.0 - u))))
                        confidence = float(confidence)

                        if not math.isfinite(confidence):
                            confidence = 0.5

                        confidences.append(confidence)
                        spans.append(s)
                        uncertainties.append(u)

                        out_explanations.append(
                            {
                                "ambiguous_word": ex.get("ambiguous_word", ex.get("token", "N/A")),
                                "position": ex.get("position", ex.get("token_idx", "N/A")),
                                "explanation": ex.get("explanation", ex.get("explain", "")),
                                "uncertainty": float(u),
                                "span": float(s),
                                "confidence": confidence,
                                "is_real_amb": bool(is_real),
                            }
                        )
                    except Exception:
                        continue

            quality_metrics = {
                "total_raw_explanations": len(sentence_explanations) if isinstance(sentence_explanations, list) else 0,
                "filtered_explanations": (len(sentence_explanations) - len(out_explanations))
                if isinstance(sentence_explanations, list)
                else 0,
                "high_confidence_count": sum(1 for c in confidences if c >= 0.65),
                "low_confidence_count": sum(1 for c in confidences if c < 0.4),
                "avg_confidence": (sum(confidences) / len(confidences)) if confidences else 0.0,
                "avg_span": (sum(spans) / len(spans)) if spans else 0.0,
                "avg_uncertainty": (sum(uncertainties) / len(uncertainties)) if uncertainties else 0.0,
            }

            result = {
                "input_sentence": input_sentence,
                "translation": translation,
                "ambiguous_words_detected": int(real_amb_count),
                "explanations": out_explanations,
                "quality_metrics": quality_metrics,
                "dscd_validated": dscd_validated,
            }

            if track_stats:
                INFERENCE_STATS.record_inference(result, dscd_homographs=dscd_homographs)

            return result

    except Exception as e:
        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            print(f"[INF] ERROR: {type(e).__name__}: {str(e)[:200]}")
            try:
                # Imported locally so the traceback is printed even if the
                # module never imported ``traceback`` at the top level.
                import traceback

                traceback.print_exc()
            except Exception:
                pass

        error_result = {
            "input_sentence": input_sentence,
            "translation": "ERROR DURING TRANSLATION",
            "ambiguous_words_detected": 0,
            "explanations": [],
            "quality_metrics": {},
            "dscd_validated": False,
            "error": str(e)[:200],
        }
        if track_stats:
            INFERENCE_STATS.record_inference(error_result, dscd_homographs=dscd_homographs)
        return error_result

    finally:
        # Release cached GPU memory and collect garbage after every call.
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass
        try:
            if gc.isenabled():
                gc.collect()
        except Exception:
            pass


def demonstrate_system(model, tokenizer, sentences: Optional[List[str]] = None):
    """Translate a handful of sentences and print each result with its explanations.

    Args:
        model: Model handed unchanged to ``translate_with_explanations``.
        tokenizer: Tokenizer handed unchanged to ``translate_with_explanations``.
        sentences: Sentences to translate. When ``None`` a built-in list of
            five Bengali sentences is used.

    Side effects:
        Resets ``INFERENCE_STATS`` before the run and prints its summary after
        the last sentence.
    """
    rule = "=" * 80

    demo_inputs = sentences
    if demo_inputs is None:
        demo_inputs = [
            "আমি কল বন্ধ করেছি।",
            "কাল আমি বই কিনব।",
            "পাতা ঝরে পড়েছে।",
            "তিনি ব্যাংক গেছেন।",
            "আমি ভালো আছি।",
        ]

    print(rule)
    print("TATN DEMO: Translation + Explanations")
    print(rule)

    INFERENCE_STATS.reset()
    for text in demo_inputs:
        print(f"\n{text}")
        outcome = translate_with_explanations(model, tokenizer, text)
        print(f"Translation: {outcome.get('translation', 'N/A')}")
        print(f"Ambiguous words detected: {outcome.get('ambiguous_words_detected', 0)}")

        metrics = outcome.get("quality_metrics", {})
        if metrics:
            avg_conf = metrics.get('avg_confidence', 0)
            n_high = metrics.get('high_confidence_count', 0)
            n_low = metrics.get('low_confidence_count', 0)
            print(
                f"Quality: conf={avg_conf:.3f}, high={n_high}, low={n_low}"
            )

        notes = outcome.get("explanations")
        if not notes:
            print("  No explanations")
            continue

        for rank, note in enumerate(notes, 1):
            conf = note.get("confidence", 0.0)
            word = note['ambiguous_word']
            where = note.get('position', 'N/A')
            print(f"  {rank}. {word} (pos={where}, conf={conf:.3f})")
            # The explanation text is capped at 200 characters for display.
            print(f"     {note.get('explanation', 'N/A')[:200]}")

    print(rule)
    INFERENCE_STATS.print_summary()


def dscd_discovery_warmup(model, tokenizer, num_sents: int = 8000, batch_size: int = 64, max_len: Optional[int] = None):
    """Run source sentences through the model so DSCD can populate its prototype stores.

    While the warmup runs, ``enable_training_clustering`` is switched on and
    ``nmin`` / ``buffer_size`` are raised to at least 3 / 200 on the DSCD
    component (only for the attributes it actually has). The previous values
    are restored on every exit path, including the early "no texts" return
    and unexpected errors.

    Args:
        model: Model, or a wrapper exposing the real model as ``.module``,
            that carries a ``dscd`` attribute.
        tokenizer: Callable tokenizer; invoked with ``return_tensors="pt"``,
            padding and truncation enabled.
        num_sents: Number of sentences to request and process.
        batch_size: Sentences per forward pass; values below 1 are treated as 1.
        max_len: Truncation length; defaults to ``_MAX_LENGTH``.

    Returns:
        None. Progress and a summary are printed. Errors are swallowed; a
        traceback is printed only when ``_DEBUG_DISCOVERY`` is set.
    """

    def _debug_trace():
        """Print the active traceback when discovery debugging is on; never raises."""
        if _DEBUG_DISCOVERY:
            try:
                traceback.print_exc()
            except Exception:
                pass

    if max_len is None:
        max_len = _MAX_LENGTH
    core = model.module if hasattr(model, "module") else model
    try:
        dscd = getattr(core, "dscd", None)
        if dscd is None:
            print("[WARMUP] Model has no dscd component")
            return

        # range() rejects a zero step; clamp rather than fail silently.
        batch_size = max(1, int(batch_size))

        orig_enable = getattr(dscd, "enable_training_clustering", None)
        orig_nmin = getattr(dscd, "nmin", None)
        orig_buffer = getattr(dscd, "buffer_size", None)
        try:
            # Temporary overrides; undone in the ``finally`` below.
            try:
                if hasattr(dscd, "enable_training_clustering"):
                    dscd.enable_training_clustering = True
                if hasattr(dscd, "nmin"):
                    dscd.nmin = max(3, int(getattr(dscd, "nmin", 5)))
                if hasattr(dscd, "buffer_size"):
                    dscd.buffer_size = max(200, int(getattr(dscd, "buffer_size", 300)))
            except Exception:
                pass

            texts: List[str] = []
            try:
                if "load_and_preprocess_optimized" in globals():
                    pairs = load_and_preprocess_optimized(num_sents)
                    texts = [bn for bn, _ in pairs[:num_sents]]
                    if len(texts) < num_sents // 2:
                        print(f"[WARMUP] WARNING: Only loaded {len(texts)} texts, expected {num_sents}")
                else:
                    # No data loader available: cycle a few fixed sentences.
                    base = [
                        "আমি কল বন্ধ করেছি।",
                        "কাল আমি বই কিনব।",
                        "পাতা ঝরে পড়েছে।",
                        "তিনি ব্যাংক গেছেন।",
                        "আমি ভালো আছি।",
                    ]
                    while len(texts) < num_sents:
                        texts.extend(base)
                    texts = texts[:num_sents]
                    print(f"[WARMUP] WARNING: Using fallback data ({len(texts)} sentences)")
            except Exception:
                _debug_trace()
                texts = []

            if not texts:
                print("[WARMUP] ERROR: No texts loaded")
                return

            processed = 0
            failed_batches = 0
            core.eval()
            print(f"[WARMUP] Processing {len(texts)} sentences in batches of {batch_size}...")
            start_time = time.time()
            last_print = start_time

            with torch.inference_mode():
                for i in range(0, len(texts), batch_size):
                    batch = texts[i : i + batch_size]
                    try:
                        enc = tokenizer(
                            batch,
                            return_tensors="pt",
                            padding=True,
                            truncation=True,
                            max_length=max_len,
                        )
                        enc = to_device(enc, _DEVICE)

                        # Keep token ids inside [0, _VOCAB_SIZE - 1].
                        if "input_ids" in enc and isinstance(enc["input_ids"], torch.Tensor):
                            enc["input_ids"] = torch.clamp(enc["input_ids"], min=0, max=_VOCAB_SIZE - 1)

                        if hasattr(core, "forward_with_explanations"):
                            core.forward_with_explanations(
                                input_ids=enc.get("input_ids"),
                                attention_mask=enc.get("attention_mask"),
                                src_texts=batch,
                                use_dscd=True,
                                use_asbn=False,
                            )
                        else:
                            core.mbart.model.encoder(
                                input_ids=enc.get("input_ids"),
                                attention_mask=enc.get("attention_mask"),
                            )
                        processed += len(batch)
                        current_time = time.time()
                        # Report every 10th batch, or at least every 5 seconds.
                        if i % (batch_size * 10) == 0 or current_time - last_print > 5:
                            elapsed = current_time - start_time
                            rate = processed / elapsed if elapsed > 0 else 0
                            eta = (len(texts) - processed) / rate if rate > 0 else 0
                            print(
                                f"[WARMUP] {processed}/{len(texts)} ({processed/len(texts)*100:.1f}%) rate={rate:.1f} sents/s ETA {eta:.0f}s"
                            )
                            last_print = current_time
                        del enc
                    except Exception:
                        # Best effort: skip the batch but remember that it failed.
                        failed_batches += 1
                        if failed_batches == 1:
                            _debug_trace()
                        continue

            total_time = time.time() - start_time
            # A coarse clock can report zero elapsed time; avoid dividing by it.
            overall_rate = processed / total_time if total_time > 0 else 0.0
            print(f"[WARMUP] Completed in {total_time:.1f}s ({overall_rate:.1f} sents/s)")
            if failed_batches:
                print(f"[WARMUP] WARNING: {failed_batches} batch(es) failed and were skipped")

            try:
                lock = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
                if lock:
                    with lock:
                        stores = dict(dscd.prototype_stores)
                else:
                    stores = dict(dscd.prototype_stores)
                num_types = len(stores)
                total_protos = sum(_get_store_size(store) for store in stores.values())
                multi = sum(1 for store in stores.values() if _get_store_size(store) >= 2)
                print("[WARMUP] Summary:")
                print(f"  - Token types: {num_types}")
                print(f"  - Total prototypes: {total_protos}")
                print(f"  - Multi-sense tokens: {multi}")
                if num_types > 0:
                    print(f"  - Multi-sense ratio: {multi/num_types:.1%}")
                dscd_homographs = get_dscd_homographs(model)
                print(f"[WARMUP] Discovered Homographs: {len(dscd_homographs)}")
                reference_found = dscd_homographs.intersection(_HOMOGRAPH_REFERENCE_LIST)
                print(f"[WARMUP] Reference found: {len(reference_found)} / {len(_HOMOGRAPH_REFERENCE_LIST)}")
            except Exception:
                _debug_trace()
        finally:
            # Put the DSCD settings back no matter how the warmup ended.
            try:
                if hasattr(dscd, "enable_training_clustering") and orig_enable is not None:
                    dscd.enable_training_clustering = orig_enable
                if hasattr(dscd, "nmin") and orig_nmin is not None:
                    dscd.nmin = orig_nmin
                if hasattr(dscd, "buffer_size") and orig_buffer is not None:
                    dscd.buffer_size = orig_buffer
            except Exception:
                pass

    except Exception:
        _debug_trace()


# Cell 8 readiness banner: static header first, then one line per resolved
# setting so the effective inference configuration is visible in the output.
print(
    "=" * 80,
    "Cell 8: Inference pipeline ready - DATAPARALLEL COMPATIBLE",
    "=" * 80,
    "Configuration:",
    sep="\n",
)
print("  - Source language:", _SOURCE_LANGUAGE)
print("  - Target language:", _TARGET_LANGUAGE)
print("  - Max length:", _MAX_LENGTH)
print("  - Span threshold:", _SPAN_THRESHOLD)
print("  - TAU_LOW:", _TAU_LOW)
print("  - Uncertainty threshold:", _UNCERTAINTY_THRESHOLD)
print("  - TRG uncertainty threshold:", _TRG_UNCERTAINTY_THRESHOLD)
print("  - Max explanations per sentence:", _MAX_EXPLANATIONS_PER_SENTENCE)
print("  - Vocab size:", _VOCAB_SIZE)
print("=" * 80)


In [None]:
# ==============================================================================
# CELL 9: COMPREHENSIVE TESTING & EVALUATION WITH TOKEN VALIDATION - FIXED
# ==============================================================================

from typing import Dict, List, Tuple, Optional, Any
import torch
import traceback
import time
import functools
import math
from collections import defaultdict

# ------------------------------------------------------------------------------
# Cell-local settings.
#
# Every ``_NAME`` below mirrors the notebook-level ``NAME`` when that global is
# defined and converts cleanly; otherwise the literal fallback is used. The
# broad ``except Exception`` is deliberate: it covers both a missing global
# (NameError) and a value that cannot be converted.
# ------------------------------------------------------------------------------
try:
    _USE_MULTI_GPU = bool(USE_MULTI_GPU)
except Exception:
    # No usable USE_MULTI_GPU: infer it from the visible CUDA devices.
    _USE_MULTI_GPU = torch.cuda.is_available() and torch.cuda.device_count() > 1

try:
    _SOURCE_LANGUAGE = str(SOURCE_LANGUAGE)
except Exception:
    _SOURCE_LANGUAGE = "bn"

try:
    _VERBOSE_LOGGING = bool(VERBOSE_LOGGING)
except Exception:
    _VERBOSE_LOGGING = False

try:
    _DEBUG_DISCOVERY = bool(DEBUG_DISCOVERY)
except Exception:
    _DEBUG_DISCOVERY = False

try:
    _DEBUG_TIMING = bool(DEBUG_TIMING)
except Exception:
    _DEBUG_TIMING = False

# Span threshold: TRG_SPAN_THRESHOLD takes precedence over SPAN_THRESHOLD.
# The override is looked up on its own so it is still honoured when
# SPAN_THRESHOLD is undefined (passing SPAN_THRESHOLD as a ``.get`` default
# would evaluate it eagerly and raise before the override could be read).
try:
    _SPAN_THRESHOLD = float(globals()["TRG_SPAN_THRESHOLD"])
except Exception:
    try:
        _SPAN_THRESHOLD = float(SPAN_THRESHOLD)
    except Exception:
        _SPAN_THRESHOLD = 0.15

try:
    _TAU_LOW = float(TAU_LOW)
except Exception:
    _TAU_LOW = 0.25

# Uncertainty threshold: TRG_UNCERTAINTY_THRESHOLD if usable, else _TAU_LOW.
try:
    _UNCERTAINTY_THRESHOLD = float(globals().get("TRG_UNCERTAINTY_THRESHOLD", _TAU_LOW))
except Exception:
    _UNCERTAINTY_THRESHOLD = _TAU_LOW

try:
    _MAX_LENGTH = int(MAX_LENGTH)
except Exception:
    _MAX_LENGTH = 48

try:
    _DEVICE = DEVICE
    if not isinstance(_DEVICE, torch.device):
        _DEVICE = torch.device(str(_DEVICE))
except Exception:
    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    _VOCAB_SIZE = int(VOCAB_SIZE)
except Exception:
    _VOCAB_SIZE = 128112

# Reference homographs, stored as lower-cased strings so lookups can compare
# against cleaned, lower-cased token text.
try:
    _HOMOGRAPH_REFERENCE_LIST = set(str(w).lower() for w in HOMOGRAPH_REFERENCE_LIST_BN)
except Exception:
    _HOMOGRAPH_REFERENCE_LIST = {
        str(w).lower()
        for w in (
            "কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা",
            "পানি", "দল", "বাজার", "নাম", "কথা", "বই", "ঘর", "মন", "হাত",
        )
    }


def _get_store_size(store) -> int:
    """Return the number of prototypes held by ``store``, or 0 if unknown.

    A callable ``store.size`` takes precedence. Otherwise ``store.centroids``
    is measured when it is a tensor (length of its first dimension) or a
    list. Any other shape of store, and any exception, yields 0.
    """
    try:
        size_fn = getattr(store, "size", None)
        if callable(size_fn):
            return int(size_fn())

        protos = getattr(store, "centroids", None)
        if isinstance(protos, torch.Tensor):
            return int(protos.size(0))
        if isinstance(protos, list):
            return len(protos)
    except Exception:
        pass
    return 0


def _get_cluster_count(model: torch.nn.Module) -> int:
    """Return the number of entries in the model's DSCD ``prototype_stores``.

    The model is unwrapped through ``.module`` when present. If the DSCD
    component exposes ``buffer_lock`` or ``clustering_lock`` the count is
    taken while holding it. A missing DSCD component, a missing/empty store
    mapping, or any exception gives 0.
    """
    try:
        core = getattr(model, "module", model)
        dscd = getattr(core, "dscd", None)
        if dscd is None:
            return 0

        guard = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
        if not guard:
            return len(getattr(dscd, "prototype_stores", {}) or {})
        with guard:
            return len(getattr(dscd, "prototype_stores", {}) or {})
    except Exception:
        return 0


def _get_dscd_homographs(model: torch.nn.Module) -> set:
    """Collect the words for which DSCD holds at least two prototypes.

    Store keys are stringified, stripped of the sub-word markers
    ``▁``, ``Ġ``, ``##``, ``@@`` and ``</w>``, trimmed and lower-cased. When
    several keys clean to the same word, the largest prototype count wins.

    Returns:
        Set of cleaned words whose best prototype count is >= 2; an empty set
        when the model has no ``dscd`` attribute or on any error.
    """
    try:
        core = getattr(model, "module", model)
        dscd = getattr(core, "dscd", None)
        if dscd is None:
            return set()

        # Snapshot the mapping (under the lock when one exists) before iterating.
        guard = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
        if guard:
            with guard:
                entries = list((getattr(dscd, "prototype_stores", {}) or {}).items())
        else:
            entries = list((getattr(dscd, "prototype_stores", {}) or {}).items())

        subword_markers = ("▁", "Ġ", "##", "@@", "</w>")
        best_count = {}
        for raw_key, store in entries:
            try:
                n_protos = _get_store_size(store)
                word = str(raw_key)
                for marker in subword_markers:
                    word = word.replace(marker, "")
                word = word.strip().lower()
                if word and n_protos > best_count.get(word, 0):
                    best_count[word] = n_protos
            except Exception:
                continue

        return {word for word, n_protos in best_count.items() if n_protos >= 2}
    except Exception:
        return set()


def _print_top_clusters(model: torch.nn.Module, top_n: int = 5):
    """Print a table of the DSCD prototype stores with the highest observation counts.

    A store is listed only when it has at least one centroid and its
    ``counts`` sum to a positive number. Centroids and counts may be lists or
    tensors; emptiness is tested without calling ``bool()`` on a tensor, which
    raises for tensors holding more than one element and would otherwise hide
    every tensor-backed store.

    Args:
        model: Model, or a wrapper exposing the real model as ``.module``,
            that carries a ``dscd`` attribute.
        top_n: Maximum number of rows to print.

    Returns:
        None. Nothing is printed when the model has no ``dscd`` attribute.
        Errors are swallowed; a traceback is printed only when
        ``_DEBUG_DISCOVERY`` is set.
    """

    def _has_items(values) -> bool:
        """Tell whether ``values`` is non-empty without truth-testing a tensor."""
        if values is None:
            return False
        if isinstance(values, torch.Tensor):
            return values.numel() > 0
        try:
            return len(values) > 0
        except TypeError:
            return bool(values)

    def _count_total(values):
        """Sum a counts container; tensors are flattened to Python numbers first."""
        if values is None:
            return 0
        if isinstance(values, torch.Tensor):
            values = values.detach().reshape(-1).tolist()
        return sum(values)

    def _display_float(value) -> float:
        """Convert ``value`` for ``.6f`` formatting; unusable values show as 0.0."""
        try:
            return float(value)
        except Exception:
            return 0.0

    try:
        core = model.module if hasattr(model, "module") else model
        dscd = getattr(core, "dscd", None)
        if dscd is None:
            return

        # Copy the mapping (under the lock when one exists) before iterating.
        lock = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
        if lock:
            with lock:
                prototype_stores = dict(getattr(dscd, "prototype_stores", {}) or {})
        else:
            prototype_stores = dict(getattr(dscd, "prototype_stores", {}) or {})

        if not prototype_stores:
            print("[CLUSTER] No clusters found yet")
            return

        cluster_info = []
        for token, store in prototype_stores.items():
            try:
                if not _has_items(getattr(store, "centroids", None)):
                    continue
                total_count = _count_total(getattr(store, "counts", None))
                if total_count <= 0:
                    continue
                cluster_info.append(
                    {
                        "token": token,
                        "count": total_count,
                        "protos": _get_store_size(store),
                        "mu": _display_float(getattr(store, "mu", 0.0)),
                        "tau": _display_float(getattr(store, "tau", 0.0)),
                    }
                )
            except Exception:
                # A malformed store is skipped rather than aborting the table.
                continue

        if not cluster_info:
            print("[CLUSTER] No valid clusters to display")
            return

        cluster_info.sort(key=lambda x: x["count"], reverse=True)

        print(f"\n[CLUSTER] Top {min(top_n, len(cluster_info))} clusters:")
        print("-" * 90)
        print(f"{'Rank':<6}{'Token':<15}{'Count':<12}{'Protos':<10}{'Mu':<15}{'Tau':<12}")
        print("-" * 90)
        for rank, info in enumerate(cluster_info[:top_n], 1):
            # Tokens are cut to 12 characters so the column stays aligned.
            token_display = str(info["token"])[:12]
            print(
                f"{rank:<6}{token_display:<15}{info['count']:<12}"
                f"{info['protos']:<10}{info['mu']:<15.6f}{info['tau']:<12.6f}"
            )
        print("-" * 90)
    except Exception:
        if _DEBUG_DISCOVERY:
            try:
                traceback.print_exc()
            except Exception:
                pass


def _timed(func):
    """Decorator that reports wall-clock duration when ``_DEBUG_TIMING`` is on.

    ``_DEBUG_TIMING`` is read on every call, so timing can be toggled after
    decoration. With the flag off the wrapped function is called directly.
    The duration is printed only when the call returns normally.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not _DEBUG_TIMING:
            return func(*args, **kwargs)

        began = time.time()
        outcome = func(*args, **kwargs)
        elapsed = time.time() - began
        print(f"[TIMING] {func.__name__}: {elapsed:.2f}s")
        return outcome

    return wrapper


def _safe_float(x: Any, default: float = 0.0) -> float:
    """Convert ``x`` to a finite float, falling back to ``default``.

    Tensors are read through ``.item()``; everything else goes through
    ``float()``. ``None``, NaN, infinities and values that fail to convert
    all produce ``float(default)``.
    """
    try:
        if x is None:
            return float(default)
        number = float(x.item()) if isinstance(x, torch.Tensor) else float(x)
        return number if math.isfinite(number) else float(default)
    except Exception:
        return float(default)


def _is_valid_explanation_item(expl: Dict[str, Any]) -> bool:
    """Decide whether an explanation dict carries any usable signal.

    An item passes when its span or its uncertainty exceeds 1e-3, or when its
    confidence exceeds 0.4. A missing (``None``) confidence is derived as
    ``max(span, 1 - uncertainty)``. Anything that is not a dict is rejected.
    """
    if not isinstance(expl, dict):
        return False

    span_score = _safe_float(expl.get("span", 0.0), 0.0)
    uncertainty = _safe_float(expl.get("uncertainty", 0.0), 0.0)

    raw_conf = expl.get("confidence", None)
    confidence = _safe_float(
        max(span_score, 1.0 - uncertainty) if raw_conf is None else raw_conf,
        0.0,
    )

    return any((span_score > 1e-3, uncertainty > 1e-3, confidence > 0.4))


@torch.inference_mode()
@_timed
def comprehensive_post_training_testing(
    model: torch.nn.Module,
    tokenizer,
    run_warmup: bool = True,
    compare_baseline: bool = False,
    baseline_metrics: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    print("\n" + "=" * 80)
    print("COMPREHENSIVE POST-TRAINING EVALUATION (Pure Data-Driven)")
    print("=" * 80)

    test_sentences: List[Tuple[str, str, str, List[str]]] = [
        ("আমি কল বন্ধ করেছি।", "I turned off the tap", "কল = tap/call", ["কল"]),
        ("কাল আমি বই কিনব।", "Tomorrow I will buy a book", "কাল = tomorrow/yesterday", ["কাল"]),
        ("পাতা ঝরে পড়েছে।", "The leaf has fallen", "পাতা = leaf/page", ["পাতা"]),
        ("তিনি ব্যাংক গেছেন।", "He went to the bank", "ব্যাংক = bank/embankment", ["ব্যাংক"]),
        ("ফল খুব সুস্বাদু।", "The fruit is delicious", "ফল = fruit/result", ["ফল"]),
        ("মাথা ব্যথা করছে।", "Head is aching", "মাথা = head/top", ["মাথা"]),
        ("কল থেকে কল এসেছে।", "A call came from the tap", "Multiple কল", ["কল"]),
        ("কালকে কাল মেঘ দেখা গেছে।", "Yesterday black clouds were seen", "Multiple কাল", ["কাল"]),
        ("আজ ভাল আবহাওয়া।", "Weather is good today", "Simple", []),
        ("আমি ভালো আছি।", "I am fine", "Simple", []),
        ("সে খুব মিষ্টি কথা বলে।", "She speaks sweetly", "Simple", []),
        ("এটা আমার বই।", "This is my book", "Simple", []),
        (
            "তিনি ব্যাংকে কাজ করেন এবং ব্যাংকে বসে থাকেন।",
            "He works at the bank and sits on the embankment",
            "Long with multiple",
            ["ব্যাংক"],
        ),
    ]

    core_model = model.module if (_USE_MULTI_GPU and hasattr(model, "module")) else model
    core_model.eval()

    quality_metrics = {
        "total_confidence": 0.0,
        "confidence_samples": 0,
        "high_confidence_count": 0,
        "medium_confidence_count": 0,
        "low_confidence_count": 0,
        "confidences": [],
        "spans": [],
        "uncertainties": [],
    }

    homograph_tracking = {
        "test_expected_homographs": set(),
        "dscd_discovered_homographs": set(),
        "explained_homographs": set(),
        "homograph_explanations": defaultdict(list),
    }

    error_tracking = {
        "translation_failures": 0,
        "dscd_failures": 0,
        "trg_failures": 0,
        "timeout_errors": 0,
        "oom_errors": 0,
        "other_errors": 0,
        "error_details": [],
        "per_test_status": [],
    }

    timing_metrics = {
        "total_time": 0.0,
        "per_test_times": [],
        "avg_test_time": 0.0,
    }

    discovery_validated = False
    try:
        dscd = getattr(core_model, "dscd", None)
        if dscd and hasattr(dscd, "discovered_log") and dscd.discovered_log:
            discovery_validated = True
            if _DEBUG_DISCOVERY:
                last_discovery = dscd.discovered_log[-1]
                discovered = last_discovery.get("discovered", 0)
                candidates = last_discovery.get("candidates", 0)
                print(f"[EVAL] Discovery log: {discovered}/{candidates} homographs")
        else:
            if _DEBUG_DISCOVERY:
                print("[EVAL] No discovery log found")
    except Exception:
        if _DEBUG_DISCOVERY:
            try:
                traceback.print_exc()
            except Exception:
                pass

    asbn_stats: Dict[str, Any] = {}
    try:
        asbn = getattr(core_model, "asbn", None)
        if asbn and hasattr(asbn, "get_detailed_stats"):
            result = asbn.get_detailed_stats()
            if isinstance(result, dict):
                asbn_stats = result
        elif asbn and hasattr(asbn, "get_asbn_stats"):
            result = asbn.get_asbn_stats()
            if isinstance(result, dict):
                asbn_stats = result
        
        if asbn_stats:
            domain_acc = _safe_float(asbn_stats.get("domain_accuracy", 0.0), 0.0)
            if _DEBUG_DISCOVERY:
                print(f"[EVAL] ASBN: domain_acc={domain_acc:.2%}")
    except Exception:
        if _DEBUG_DISCOVERY:
            try:
                traceback.print_exc()
            except Exception:
                pass

    trg_stats: Dict[str, Any] = {}
    try:
        trg = getattr(core_model, "trg_system", None)
        if trg and hasattr(trg, "get_statistics"):
            result = trg.get_statistics()
            if isinstance(result, dict):
                trg_stats = result
            if _DEBUG_DISCOVERY and trg_stats:
                exp_gen = int(trg_stats.get("explanations_generated", 0))
                print(f"[EVAL] TRG: {exp_gen} total")
    except Exception:
        if _DEBUG_DISCOVERY:
            try:
                traceback.print_exc()
            except Exception:
                pass

    homograph_tracking["dscd_discovered_homographs"] = _get_dscd_homographs(core_model)
    print(f"[EVAL] DSCD discovered: {len(homograph_tracking['dscd_discovered_homographs'])} homographs")
    if homograph_tracking["dscd_discovered_homographs"] and _DEBUG_DISCOVERY:
        print(f"[EVAL] Sample: {list(homograph_tracking['dscd_discovered_homographs'])[:10]}")

    if run_warmup:
        try:
            dscd = getattr(core_model, "dscd", None)
            if dscd is not None:
                lock = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
                if lock:
                    with lock:
                        stores = getattr(dscd, "prototype_stores", None)
                        store_count = len(stores) if stores else 0
                else:
                    stores = getattr(dscd, "prototype_stores", None)
                    store_count = len(stores) if stores else 0
                if store_count == 0 and "dscd_discovery_warmup" in globals():
                    print("[EVAL] Running warmup (num_sents=4000)...")
                    try:
                        dscd_discovery_warmup(
                            model,
                            tokenizer,
                            num_sents=4000,
                            batch_size=64,
                            max_len=_MAX_LENGTH,
                        )
                        homograph_tracking[
                            "dscd_discovered_homographs"
                        ] = _get_dscd_homographs(core_model)
                    except Exception:
                        print("[EVAL] Warmup failed")
        except Exception:
            if _DEBUG_DISCOVERY:
                try:
                    traceback.print_exc()
                except Exception:
                    pass

    total_tests = len(test_sentences)
    successful_translations = 0
    total_explanations = 0
    total_high_span = 0
    total_real_ambiguous = 0

    print(f"\n[EVAL] Running {total_tests} tests...")
    print("-" * 80)

    try:
        tokenizer.src_lang = _SOURCE_LANGUAGE
    except Exception:
        pass

    def _is_real_amb(expl: Dict[str, Any]) -> bool:
        try:
            span = _safe_float(expl.get("span", 0.0), 0.0)
            u = _safe_float(expl.get("uncertainty", 0.0), 0.0)
            conf = expl.get("confidence", None)
            if conf is None:
                conf = max(span, 1.0 - u)
            conf = _safe_float(conf, 0.0)
            return (span > _SPAN_THRESHOLD) or (u < _UNCERTAINTY_THRESHOLD and conf >= 0.6)
        except Exception:
            return False

    def _compute_similarity(pred: str, expected: str) -> float:
        try:
            if not isinstance(pred, str) or not isinstance(expected, str):
                return 0.0
            pred_words = set(pred.lower().strip().split())
            exp_words = set(expected.lower().strip().split())
            if not exp_words:
                return 0.0
            overlap = len(pred_words & exp_words)
            return overlap / len(exp_words)
        except Exception:
            return 0.0

    for _, _, _, expected_homos in test_sentences:
        for h in expected_homos:
            clean_h = h.strip().lower()
            if clean_h:
                homograph_tracking["test_expected_homographs"].add(clean_h)

    eval_start = time.time()

    for idx, (src_text, expected_translation, desc, expected_homos) in enumerate(
        test_sentences, 1
    ):
        test_start = time.time()

        print(f"\nTest {idx}/{total_tests}: {desc}")
        print("=" * 60)

        test_status = {
            "test_id": idx,
            "success": False,
            "translation_ok": False,
            "explanations_count": 0,
            "error": None,
        }

        try:
            if "translate_with_explanations" not in globals():
                print("[EVAL] translate_with_explanations not available")
                error_tracking["other_errors"] += 1
                test_status["error"] = "function_not_available"
                error_tracking["per_test_status"].append(test_status)
                continue

            result = translate_with_explanations(
                core_model if core_model is not None else model,
                tokenizer,
                src_text,
                device=_DEVICE,
                span_threshold=_SPAN_THRESHOLD,
                uncertainty_threshold=_UNCERTAINTY_THRESHOLD,
                track_stats=False,
            )

            if not isinstance(result, dict):
                print("[EVAL] Invalid result type")
                error_tracking["other_errors"] += 1
                test_status["error"] = "invalid_result"
                error_tracking["per_test_status"].append(test_status)
                continue

            translation = str(result.get("translation", "") or "")
            raw_explanations = result.get("explanations", []) or []

            filtered_explanations: List[Dict[str, Any]] = []
            seen_tokens_positions = set()
            for ex in raw_explanations:
                try:
                    span_val = _safe_float(ex.get("span", 0.0), 0.0)
                    unc_val = _safe_float(ex.get("uncertainty", 0.0), 0.0)
                    conf_val = ex.get("confidence", None)
                    if conf_val is None:
                        conf_val = max(span_val, 1.0 - unc_val)
                    conf_val = _safe_float(conf_val, 0.0)

                    token_word = ex.get("ambiguous_word", ex.get("token", ""))
                    if not isinstance(token_word, str):
                        token_word = str(token_word)
                    token_word_clean = (
                        token_word.replace("▁", "")
                        .replace("Ġ", "")
                        .replace("##", "")
                        .replace("@@", "")
                        .replace("</w>", "")
                        .strip()
                    )
                    
                    if not token_word_clean:
                        continue
                    
                    pos = ex.get("position", ex.get("token_idx", None))
                    try:
                        pos_i = int(pos) if pos is not None else None
                    except Exception:
                        pos_i = None

                    dedupe_key = (token_word_clean.lower(), pos_i)
                    if dedupe_key in seen_tokens_positions:
                        continue
                    seen_tokens_positions.add(dedupe_key)

                    if _is_valid_explanation_item(
                        {"span": span_val, "uncertainty": unc_val, "confidence": conf_val}
                    ):
                        cleaned = {
                            "ambiguous_word": token_word_clean,
                            "position": pos_i,
                            "explanation": str(
                                ex.get("explanation", ex.get("explain", ""))
                            )[:512],
                            "uncertainty": float(unc_val),
                            "span": float(span_val),
                            "confidence": float(conf_val),
                            "is_raw": ex,
                        }
                        filtered_explanations.append(cleaned)
                except Exception:
                    continue

            amb_count = len(filtered_explanations)
            explanations = filtered_explanations

            similarity = _compute_similarity(translation, expected_translation)

            print(f"Input: {src_text}")
            print(f"Expected: {expected_translation}")
            print(f"Translation: {translation}")
            print(f"Similarity: {similarity:.1%}")
            print(f"Ambiguous: {amb_count}")

            if explanations:
                print("\nExplanations:")
                high_span_local = 0
                real_amb_local = 0
                for j, expl in enumerate(explanations, 1):
                    try:
                        span_val = float(expl.get("span", 0.0))
                        u_val = float(expl.get("uncertainty", 0.0))
                        conf_val = float(expl.get("confidence", 0.0))
                    except Exception:
                        span_val = 0.0
                        u_val = 0.0
                        conf_val = 0.0

                    marker = (
                        f"[S>{_SPAN_THRESHOLD:.2f}]" if span_val > _SPAN_THRESHOLD else "          "
                    )
                    word = expl.get("ambiguous_word", "N/A")
                    pos = expl.get("position", "N/A")

                    print(f"  {j}. {marker} '{word}' @ {pos}")
                    print(f"       conf={conf_val:.3f} | U={u_val:.3f} | S={span_val:.3f}")
                    text = str(expl.get("explanation", ""))
                    if len(text) > 120:
                        text = text[:120] + "..."
                    print(f"       {text}")

                    quality_metrics["confidences"].append(conf_val)
                    quality_metrics["spans"].append(span_val)
                    quality_metrics["uncertainties"].append(u_val)
                    quality_metrics["total_confidence"] += conf_val
                    quality_metrics["confidence_samples"] += 1

                    if conf_val >= 0.65:
                        quality_metrics["high_confidence_count"] += 1
                    elif conf_val >= 0.4:
                        quality_metrics["medium_confidence_count"] += 1
                    else:
                        quality_metrics["low_confidence_count"] += 1

                    if span_val > _SPAN_THRESHOLD:
                        high_span_local += 1
                    if _is_real_amb(expl):
                        real_amb_local += 1

                    clean_word = (
                        str(word)
                        .replace("▁", "")
                        .replace("Ġ", "")
                        .replace("##", "")
                        .replace("@@", "")
                        .replace("</w>", "")
                        .strip()
                        .lower()
                    )
                    if clean_word:
                        homograph_tracking["explained_homographs"].add(clean_word)
                        homograph_tracking["homograph_explanations"][clean_word].append(
                            {
                                "sentence": src_text,
                                "confidence": conf_val,
                                "span": span_val,
                                "uncertainty": u_val,
                            }
                        )

                total_explanations += len(explanations)
                total_high_span += high_span_local
                total_real_ambiguous += real_amb_local
                test_status["explanations_count"] = len(explanations)
            else:
                print("No explanations")

            sim_threshold = 0.25
            is_non_error = translation not in (
                "Error occurred",
                "Translation generation failed",
                "ERROR DURING TRANSLATION",
            )
            if (
                translation
                and translation.strip()
                and is_non_error
                and similarity >= sim_threshold
            ):
                successful_translations += 1
                test_status["translation_ok"] = True
                test_status["success"] = True
                print("Success")
            else:
                print("Translation failed")
                error_tracking["translation_failures"] += 1
                test_status["error"] = "translation_failed"

        except RuntimeError as e:
            error_str = str(e).lower()
            if "out of memory" in error_str:
                print(f"[EVAL] OOM: {str(e)[:100]}")
                error_tracking["oom_errors"] += 1
                test_status["error"] = "oom"
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
            elif "timeout" in error_str:
                print(f"[EVAL] Timeout: {str(e)[:100]}")
                error_tracking["timeout_errors"] += 1
                test_status["error"] = "timeout"
            else:
                print(f"[EVAL] Runtime: {type(e).__name__}")
                error_tracking["other_errors"] += 1
                test_status["error"] = "runtime"
            error_tracking["error_details"].append(f"Test {idx}: {type(e).__name__}")
        except Exception as e:
            print(f"[EVAL] Error: {type(e).__name__}")
            error_tracking["other_errors"] += 1
            test_status["error"] = type(e).__name__
            error_tracking["error_details"].append(f"Test {idx}: {type(e).__name__}")
            if _DEBUG_DISCOVERY:
                try:
                    traceback.print_exc()
                except Exception:
                    pass

        error_tracking["per_test_status"].append(test_status)

        test_time = time.time() - test_start
        timing_metrics["per_test_times"].append(test_time)

        print("-" * 60)

    timing_metrics["total_time"] = time.time() - eval_start
    timing_metrics["avg_test_time"] = (
        sum(timing_metrics["per_test_times"]) / len(timing_metrics["per_test_times"])
        if timing_metrics["per_test_times"]
        else 0.0
    )

    if quality_metrics["confidence_samples"] > 0:
        quality_metrics["avg_confidence"] = (
            quality_metrics["total_confidence"] / quality_metrics["confidence_samples"]
        )
        quality_metrics["avg_span"] = (
            sum(quality_metrics["spans"]) / len(quality_metrics["spans"])
            if quality_metrics["spans"]
            else 0.0
        )
        quality_metrics["avg_uncertainty"] = (
            sum(quality_metrics["uncertainties"])
            / len(quality_metrics["uncertainties"])
            if quality_metrics["uncertainties"]
            else 0.0
        )
        if quality_metrics["confidences"]:
            sorted_conf = sorted(quality_metrics["confidences"])
            n = len(sorted_conf)
            quality_metrics["confidence_p25"] = sorted_conf[max(0, n // 4)]
            quality_metrics["confidence_p50"] = sorted_conf[max(0, n // 2)]
            quality_metrics["confidence_p75"] = sorted_conf[min(n - 1, 3 * n // 4)]
    else:
        quality_metrics["avg_confidence"] = 0.0
        quality_metrics["avg_span"] = 0.0
        quality_metrics["avg_uncertainty"] = 0.0

    explained_from_dscd = set()
    if homograph_tracking.get("explained_homographs") and homograph_tracking.get(
        "dscd_discovered_homographs"
    ):
        explained_from_dscd = homograph_tracking["explained_homographs"].intersection(
            homograph_tracking["dscd_discovered_homographs"]
        )

    test_expected_discovered = set()
    if homograph_tracking.get("test_expected_homographs") and homograph_tracking.get(
        "dscd_discovered_homographs"
    ):
        test_expected_discovered = homograph_tracking[
            "test_expected_homographs"
        ].intersection(homograph_tracking["dscd_discovered_homographs"])

    reference_discovered = set()
    if _HOMOGRAPH_REFERENCE_LIST and homograph_tracking.get(
        "dscd_discovered_homographs"
    ):
        reference_discovered = _HOMOGRAPH_REFERENCE_LIST.intersection(
            homograph_tracking["dscd_discovered_homographs"]
        )

    homograph_tracking["explained_from_dscd_rate"] = (
        len(explained_from_dscd) / len(homograph_tracking["dscd_discovered_homographs"])
        if homograph_tracking.get("dscd_discovered_homographs")
        else 0.0
    )
    homograph_tracking["test_expected_discovery_rate"] = (
        len(test_expected_discovered)
        / len(homograph_tracking["test_expected_homographs"])
        if homograph_tracking.get("test_expected_homographs")
        else 0.0
    )
    homograph_tracking["reference_discovery_rate"] = (
        len(reference_discovered) / len(_HOMOGRAPH_REFERENCE_LIST)
        if _HOMOGRAPH_REFERENCE_LIST
        else 0.0
    )

    try:
        dscd_stats = {
            "total_words": 0,
            "multi_sense_words": 0,
            "total_prototypes": 0,
            "corrupted_stores": 0,
        }
        dscd = getattr(core_model, "dscd", None)
        if dscd is not None and hasattr(dscd, "prototype_stores"):
            lock = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
            if lock:
                with lock:
                    stores = dict(getattr(dscd, "prototype_stores") or {})
            else:
                stores = dict(getattr(dscd, "prototype_stores") or {})

            total_words = 0
            multi = 0
            total_protos = 0
            corrupted = 0

            for _, store in stores.items():
                sz = _get_store_size(store)
                
                is_valid = True
                try:
                    if (
                        not getattr(store, "centroids", None)
                        or len(getattr(store, "centroids", [])) == 0
                    ):
                        is_valid = False
                    if (
                        not getattr(store, "counts", None)
                        or sum(getattr(store, "counts", [])) <= 0
                    ):
                        is_valid = False
                    if hasattr(store, "mu"):
                        mu_val = getattr(store, "mu")
                        if isinstance(mu_val, (int, float)):
                            if not math.isfinite(mu_val) or mu_val < 0 or mu_val > 10:
                                is_valid = False
                except Exception:
                    is_valid = False

                if not is_valid:
                    corrupted += 1
                    continue

                total_words += 1
                total_protos += sz
                if sz >= 2:
                    multi += 1

            dscd_stats = {
                "total_words": total_words,
                "multi_sense_words": multi,
                "total_prototypes": total_protos,
                "corrupted_stores": corrupted,
            }
    except Exception:
        dscd_stats = {
            "total_words": 0,
            "multi_sense_words": 0,
            "total_prototypes": 0,
            "corrupted_stores": 0,
        }

    print("\n" + "=" * 80)
    print("COMPREHENSIVE EVALUATION SUMMARY")
    print("=" * 80)

    print(f"\n[TRANSLATION QUALITY]")
    print(f"  Total tests: {total_tests}")
    print(f"  Successful: {successful_translations}")
    print(
        f"  Success rate: {(successful_translations / total_tests * 100.0) if total_tests > 0 else 0.0:.1f}%"
    )

    print(f"\n[AMBIGUITY DETECTION]")
    print(f"  Total explanations: {total_explanations}")
    print(f"  High-span (S>{_SPAN_THRESHOLD}): {total_high_span}")
    print(f"  Real ambiguous: {total_real_ambiguous}")
    if total_tests > 0:
        print(f"  Avg explanations/test: {total_explanations / total_tests:.2f}")

    print(f"\n[EXPLANATION QUALITY]")
    print(f"  Avg confidence: {quality_metrics.get('avg_confidence', 0.0):.3f}")
    print(f"  Avg span: {quality_metrics.get('avg_span', 0.0):.3f}")
    print(f"  Avg uncertainty: {quality_metrics.get('avg_uncertainty', 0.0):.3f}")

    if "confidence_p50" in quality_metrics:
        print(
            f"  Confidence P25/P50/P75: {quality_metrics.get('confidence_p25', 0):.3f} / "
            f"{quality_metrics.get('confidence_p50', 0):.3f} / "
            f"{quality_metrics.get('confidence_p75', 0):.3f}"
        )

    print(f"  High (>=0.65): {quality_metrics['high_confidence_count']}")
    print(f"  Medium (0.4-0.65): {quality_metrics['medium_confidence_count']}")
    print(f"  Low (<0.4): {quality_metrics['low_confidence_count']}")

    print(f"\n[HOMOGRAPH DISCOVERY]")
    print(f"  DSCD discovered: {len(homograph_tracking['dscd_discovered_homographs'])}")
    print(f"  Explained: {len(homograph_tracking['explained_homographs'])}")
    print(f"  Explanation rate: {homograph_tracking['explained_from_dscd_rate']:.1%}")
    print(f"  Test discovery rate: {homograph_tracking['test_expected_discovery_rate']:.1%}")

    if homograph_tracking["explained_homographs"]:
        print(f"\n  Explained homographs (top 10):")
        for homo in sorted(homograph_tracking["explained_homographs"])[:10]:
            exps = homograph_tracking["homograph_explanations"].get(homo, [])
            count = len(exps)
            avg_conf = (
                sum(e["confidence"] for e in exps) / len(exps) if exps else 0.0
            )
            in_dscd = (
                "[D]" if homo in homograph_tracking["dscd_discovered_homographs"] else "   "
            )
            in_ref = "[R]" if homo in _HOMOGRAPH_REFERENCE_LIST else "   "
            print(f"    {in_dscd} {in_ref} '{homo}': {count} x conf={avg_conf:.3f}")

    print(f"\n[REFERENCE COMPARISON]")
    print(f"  Reference: {len(_HOMOGRAPH_REFERENCE_LIST)} words")
    print(
        f"  Discovered: {len(reference_discovered)}/{len(_HOMOGRAPH_REFERENCE_LIST)}"
    )
    print(f"  Coverage: {homograph_tracking['reference_discovery_rate']:.1%}")

    print(f"\n[DSCD PROTOTYPES]")
    print(f"  Word types: {dscd_stats['total_words']}")
    print(f"  Multi-sense: {dscd_stats['multi_sense_words']}")
    print(f"  Total prototypes: {dscd_stats['total_prototypes']}")
    if dscd_stats.get("corrupted_stores", 0) > 0:
        print(f"  Corrupted stores: {dscd_stats['corrupted_stores']}")
    if dscd_stats["total_words"] > 0:
        print(
            f"  Multi-sense ratio: {dscd_stats['multi_sense_words'] / dscd_stats['total_words']:.1%}"
        )

    if asbn_stats:
        print(f"\n[ASBN]")
        domain_acc = _safe_float(asbn_stats.get("domain_accuracy", 0.0), 0.0)
        print(f"  Domain accuracy: {domain_acc:.2%}")
        if "source_accuracy" in asbn_stats:
            src_acc = _safe_float(asbn_stats.get("source_accuracy", 0.0), 0.0)
            tgt_acc = _safe_float(asbn_stats.get("target_accuracy", 0.0), 0.0)
            print(f"  Source accuracy: {src_acc:.2%}")
            print(f"  Target accuracy: {tgt_acc:.2%}")

    if trg_stats:
        print(f"\n[TRG]")
        exp_gen = int(trg_stats.get("explanations_generated", 0))
        hc_rate = _safe_float(trg_stats.get("high_confidence_rate", 0.0), 0.0)
        print(f"  Total explanations: {exp_gen}")
        print(f"  High confidence: {hc_rate:.1%}")

    print(f"\n[PERFORMANCE]")
    print(f"  Total time: {timing_metrics['total_time']:.2f}s")
    print(f"  Avg time/test: {timing_metrics['avg_test_time']:.2f}s")

    total_errors = sum(
        [
            error_tracking["translation_failures"],
            error_tracking["dscd_failures"],
            error_tracking["trg_failures"],
            error_tracking["timeout_errors"],
            error_tracking["oom_errors"],
            error_tracking["other_errors"],
        ]
    )

    if total_errors > 0:
        print(f"\n[ERRORS]")
        print(f"  Total: {total_errors}")
        print(f"  Translation: {error_tracking['translation_failures']}")
        print(f"  OOM: {error_tracking['oom_errors']}")
        print(f"  Other: {error_tracking['other_errors']}")

    if compare_baseline and baseline_metrics and isinstance(baseline_metrics, dict):
        print(f"\n[BASELINE COMPARISON]")
        try:
            baseline_success = _safe_float(baseline_metrics.get("success_rate_pct", 0.0), 0.0)
            current_success = (
                successful_translations / total_tests * 100.0 if total_tests > 0 else 0.0
            )
            success_delta = current_success - baseline_success
            
            baseline_expl = int(baseline_metrics.get("total_explanations", 0))
            expl_delta = total_explanations - baseline_expl
            
            baseline_qm = baseline_metrics.get("quality_metrics", {})
            if isinstance(baseline_qm, dict):
                baseline_quality = _safe_float(baseline_qm.get("avg_confidence", 0.0), 0.0)
            else:
                baseline_quality = 0.0
            quality_delta = quality_metrics.get("avg_confidence", 0.0) - baseline_quality
            
            print(
                f"  Translation: {current_success:.1f}% ({success_delta:+.1f}%)"
            )
            print(f"  Explanations: {total_explanations} ({expl_delta:+d})")
            print(
                f"  Confidence: {quality_metrics.get('avg_confidence', 0.0):.3f} "
                f"({quality_delta:+.3f})"
            )
            
            baseline_ht = baseline_metrics.get("homograph_tracking", {})
            if isinstance(baseline_ht, dict):
                baseline_homo_rate = _safe_float(baseline_ht.get("explained_from_dscd_rate", 0.0), 0.0)
                homo_delta = (
                    homograph_tracking["explained_from_dscd_rate"]
                    - baseline_homo_rate
                )
                print(
                    f"  Explanation rate: "
                    f"{homograph_tracking['explained_from_dscd_rate']:.1%} "
                    f"({homo_delta:+.1%})"
                )
        except Exception as e:
            print(f"  Comparison failed: {type(e).__name__}")
            if _DEBUG_DISCOVERY:
                try:
                    traceback.print_exc()
                except Exception:
                    pass

    warnings = []
    if successful_translations < total_tests * 0.5:
        warnings.append("High translation failure (>50%)")
    if total_explanations == 0:
        warnings.append("No explanations generated")
    if dscd_stats["total_words"] < 100:
        warnings.append("Very few prototypes (<100)")
    
    corrupted_threshold = float(dscd_stats["total_words"]) * 0.1
    if dscd_stats["total_words"] > 0 and dscd_stats.get("corrupted_stores", 0) > corrupted_threshold:
        warnings.append(f"High corruption rate ({dscd_stats.get('corrupted_stores', 0)} stores)")
    
    if quality_metrics["low_confidence_count"] > quality_metrics["high_confidence_count"]:
        warnings.append("More low than high confidence")
    if homograph_tracking["explained_from_dscd_rate"] < 0.3:
        warnings.append("Low explanation rate (<30%)")
    if not discovery_validated:
        warnings.append("Discovery log missing")
    
    if asbn_stats:
        asbn_domain_acc = _safe_float(asbn_stats.get("domain_accuracy", 0.0), 0.0)
        if asbn_domain_acc < 0.5:
            warnings.append("ASBN domain accuracy <50%")

    if warnings:
        print(f"\n[WARNINGS]")
        for w in warnings:
            print(f"  - {w}")
    else:
        print(f"\n[HEALTH] All systems nominal")

    print("=" * 80)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {
        "total_tests": total_tests,
        "successful_translations": successful_translations,
        "success_rate_pct": (
            successful_translations / total_tests * 100.0 if total_tests > 0 else 0.0
        ),
        "total_explanations": total_explanations,
        "total_high_span": total_high_span,
        "total_real_ambiguous": total_real_ambiguous,
        "dscd_stats": dscd_stats,
        "quality_metrics": quality_metrics,
        "homograph_tracking": homograph_tracking,
        "error_tracking": error_tracking,
        "asbn_stats": asbn_stats,
        "trg_stats": trg_stats,
        "discovery_validated": discovery_validated,
        "timing_metrics": timing_metrics,
    }


def test_evaluation_pipeline(model, tokenizer) -> bool:
    """Smoke-test the evaluation entry point and report whether it succeeded.

    Runs ``comprehensive_post_training_testing`` with warm-up and baseline
    comparison disabled, then checks that the returned mapping exposes the
    ``total_tests``, ``quality_metrics`` and ``homograph_tracking`` keys.

    Args:
        model: Model forwarded unchanged to the evaluation routine.
        tokenizer: Tokenizer forwarded unchanged to the evaluation routine.

    Returns:
        True when the evaluation ran and every required key is present;
        False on any exception or missing key (details are printed).
    """
    # Imported locally so the diagnostic below works even when the enclosing
    # notebook cell did not import ``traceback`` itself.
    import traceback

    required_keys = ("total_tests", "quality_metrics", "homograph_tracking")

    print("\n" + "=" * 60)
    print("[TEST] Testing evaluation pipeline")
    print("=" * 60)

    try:
        result = comprehensive_post_training_testing(
            model, tokenizer, run_warmup=False, compare_baseline=False
        )
        # Explicit check rather than ``assert``: asserts are stripped under
        # ``python -O`` and carry no message about which key was missing.
        missing = [key for key in required_keys if key not in result]
        if missing:
            raise ValueError(f"evaluation result is missing keys: {missing}")
        print("Evaluation pipeline test passed")
        print("=" * 60 + "\n")
        return True
    except Exception as e:
        print(f"Evaluation pipeline test failed: {e}")
        traceback.print_exc()
        print("=" * 60 + "\n")
        return False


# Cell 9 load banner: announces that the testing/evaluation helpers are
# defined and lists the fixes bundled in this revision. The lines are joined
# and emitted with one print call; the text is unchanged.
print(
    "\n".join(
        [
            "\n" + "=" * 80,
            "Cell 9: Testing & evaluation ready - TOKEN VALIDATION FIXED",
            "=" * 80,
            "FIXES APPLIED:",
            "  ✓ Added _DEVICE type validation (torch.device)",
            "  ✓ Added _VOCAB_SIZE constant for token validation",
            "  ✓ Added empty token validation in explanation filtering",
            "  ✓ Added type validation in _compute_similarity()",
            "  ✓ Added math.isfinite() check for mu values in DSCD stores",
            "  ✓ Added OOM recovery with torch.cuda.empty_cache()",
            "  ✓ Enhanced baseline comparison with dict type checks",
            "  ✓ All existing fixes from previous version preserved",
            "=" * 80 + "\n",
        ]
    )
)


In [None]:
# ==============================================================================
# CELL 10: TATN MAIN PIPELINE (FINAL INTEGRATION) - FIXED
# ==============================================================================
import os
import time
import traceback
import math
from typing import Tuple, Optional, Dict, Any, List
import gc
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def _g(name, default):
    """Look up ``name`` in this module's globals, returning ``default`` if absent.

    Any error raised by the lookup itself (for example an unhashable ``name``)
    also yields ``default``.
    """
    try:
        namespace = globals()
        resolved = namespace.get(name, default)
    except Exception:
        resolved = default
    return resolved

def _safe_float(x: Any, default: float = 0.0) -> float:
    """Convert ``x`` to a finite float, or return ``float(default)``.

    ``torch.Tensor`` inputs are read through ``.item()``; everything else goes
    through ``float()``. ``None``, non-finite results (NaN, +/-inf) and any
    conversion error all produce the default.
    """
    try:
        if x is not None:
            number = float(x.item()) if isinstance(x, torch.Tensor) else float(x)
            if math.isfinite(number):
                return number
        return float(default)
    except Exception:
        return float(default)

def _safe_int(x: Any, default: int = 0) -> int:
    """Convert ``x`` with ``int()``; return ``default`` for ``None`` or on failure.

    ``default`` is returned as given (it is not coerced).
    """
    if x is None:
        return default
    try:
        return int(x)
    except Exception:
        return default

# Default homograph reference words, shared by the normal path and the
# fallback path so both resolve to the same set when no override is defined.
_DEFAULT_HOMOGRAPH_REFERENCE_LIST_BN = (
    "কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা",
)

# Resolve the pipeline settings from notebook globals (via ``_g``), coercing
# each to its expected type. If any lookup or coercion raises, every setting
# is reset to the built-in defaults in the ``except`` branch below.
try:
    _USE_MULTI_GPU = bool(_g("USE_MULTI_GPU", False))
    _NUM_GPUS = int(_g("NUM_GPUS", torch.cuda.device_count() if torch.cuda.is_available() else 0))

    # DEVICE may be a torch.device or anything torch.device() accepts as text.
    raw_device = _g("DEVICE", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    if isinstance(raw_device, torch.device):
        _DEVICE = raw_device
    else:
        try:
            _DEVICE = torch.device(str(raw_device))
        except Exception:
            _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    _SOURCE_LANGUAGE = str(_g("SOURCE_LANGUAGE", "bn"))
    _TARGET_LANGUAGE = str(_g("TARGET_LANGUAGE", "en"))
    _NUM_SAMPLES = int(_g("NUM_SAMPLES", 30000))
    _MAX_LENGTH = int(_g("MAX_LENGTH", 48))
    _BATCH_SIZE = int(_g("BATCH_SIZE", 8))
    _EPOCHS = int(_g("EPOCHS", 1))
    _ACCUMULATION_STEPS = int(_g("ACCUMULATION_STEPS", 1))
    _LR_NMT = float(_g("LR_NMT", 2e-5))
    _LR_PHI = float(_g("LR_PHI", 1e-5))
    _ENABLE_ASBN_TRAINING = bool(_g("ENABLE_ASBN_TRAINING", True))
    _ENABLE_TRG_INFERENCE = bool(_g("ENABLE_TRG_INFERENCE", True))
    _VALIDATION_CHECK_INTERVAL = int(_g("VALIDATION_CHECK_INTERVAL", 500))
    _PERIODIC_DISCOVERY_FREQUENCY = int(_g("PERIODIC_DISCOVERY_FREQUENCY", 150))
    _DSCD_WARMUP_SAMPLES = int(_g("DSCD_WARMUP_SAMPLES", 4000))

    raw_homo = _g("HOMOGRAPH_REFERENCE_LIST_BN", list(_DEFAULT_HOMOGRAPH_REFERENCE_LIST_BN))
    _HOMOGRAPH_REFERENCE_LIST_BN = set(str(w) for w in raw_homo)
    HOMOGRAPH_REFERENCE_LIST_BN = _HOMOGRAPH_REFERENCE_LIST_BN

    _FREEZE_ENCODER = bool(_g("FREEZE_ENCODER", False))
    _DEBUG_TIMING = bool(_g("DEBUG_TIMING", False))
    _VERBOSE_LOGGING = bool(_g("VERBOSE_LOGGING", False))
    _DEBUG_DISCOVERY = bool(_g("DEBUG_DISCOVERY", False))

    _M2M100_EN_TOKEN_ID = int(_g("M2M100_EN_TOKEN_ID", 128022))
    _M2M100_BN_TOKEN_ID = int(_g("M2M100_BN_TOKEN_ID", 128025))

    # Each threshold defaults to the previous one when it has no override.
    _SPAN_THRESHOLD = float(_g("SPAN_THRESHOLD", 0.15))
    _TAU_LOW = float(_g("TAU_LOW", 0.25))
    _UNCERTAINTY_THRESHOLD = float(_g("UNCERTAINTY_THRESHOLD", _TAU_LOW))
    _TRG_UNCERTAINTY_THRESHOLD = float(_g("TRG_UNCERTAINTY_THRESHOLD", _UNCERTAINTY_THRESHOLD))

except Exception as _config_error:
    # Say why the overrides were discarded instead of failing silently: a
    # single bad value resets *all* settings, which is otherwise hard to spot.
    print(
        f"[PIPELINE] Config resolution failed "
        f"({type(_config_error).__name__}: {_config_error}); using built-in defaults"
    )
    _NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
    # NOTE(review): the normal path defaults USE_MULTI_GPU to False, while this
    # fallback enables it whenever more than one GPU is visible - confirm
    # which of the two is intended.
    _USE_MULTI_GPU = _NUM_GPUS > 1
    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _SOURCE_LANGUAGE = "bn"
    _TARGET_LANGUAGE = "en"
    _NUM_SAMPLES = 30000
    _MAX_LENGTH = 48
    _BATCH_SIZE = 8
    _EPOCHS = 1
    _ACCUMULATION_STEPS = 1
    _LR_NMT = 2e-5
    _LR_PHI = 1e-5
    _ENABLE_ASBN_TRAINING = True
    _ENABLE_TRG_INFERENCE = True
    _VALIDATION_CHECK_INTERVAL = 500
    _PERIODIC_DISCOVERY_FREQUENCY = 150
    _DSCD_WARMUP_SAMPLES = 4000
    # Same words as the normal-path default (previously only four of them).
    _HOMOGRAPH_REFERENCE_LIST_BN = set(_DEFAULT_HOMOGRAPH_REFERENCE_LIST_BN)
    HOMOGRAPH_REFERENCE_LIST_BN = _HOMOGRAPH_REFERENCE_LIST_BN
    _FREEZE_ENCODER = False
    _DEBUG_TIMING = False
    _VERBOSE_LOGGING = False
    _DEBUG_DISCOVERY = False
    _M2M100_EN_TOKEN_ID = 128022
    _M2M100_BN_TOKEN_ID = 128025
    _SPAN_THRESHOLD = 0.15
    _TAU_LOW = 0.25
    _UNCERTAINTY_THRESHOLD = 0.25
    _TRG_UNCERTAINTY_THRESHOLD = 0.25

# Final checkpoint location. Both the directory and the file name can be
# overridden through notebook globals (CHECKPOINT_DIR / CHECKPOINT_FILENAME);
# otherwise "/kaggle/working" and "tatn_final.pt" are used.
_CHECKPOINT_DIR = _g("CHECKPOINT_DIR", "/kaggle/working")
_CHECKPOINT_PATH = os.path.join(
    _CHECKPOINT_DIR,
    _g("CHECKPOINT_FILENAME", "tatn_final.pt"),
)

def _safe_clear_gpu_caches():
    """Best-effort release of cached GPU memory; never raises.

    If a callable named ``clear_all_gpu_caches`` exists in the module globals
    it is used and nothing else is done. When that hook is absent, or when it
    raises, the built-in path runs instead: a garbage-collection pass (only if
    the collector is enabled) followed by ``torch.cuda.empty_cache()`` on
    every visible CUDA device.
    """
    try:
        hook = globals().get("clear_all_gpu_caches")
        if callable(hook):
            try:
                hook()
                return
            except Exception:
                # A broken hook must not disable cleanup altogether; fall
                # through to the built-in path below.
                pass

        # Collect first so that memory held only by unreachable tensors is
        # already free by the time the allocator cache is emptied.
        if gc.isenabled():
            gc.collect()

        if torch.cuda.is_available():
            for device_index in range(torch.cuda.device_count()):
                try:
                    with torch.cuda.device(device_index):
                        torch.cuda.empty_cache()
                except Exception:
                    # One failing device should not stop the others.
                    pass
    except Exception:
        # Cleanup is opportunistic: callers rely on this never raising.
        pass

def _safe_get(d: dict, *keys, default=None):
    """Walk nested dicts along ``keys`` and return the value found.

    ``default`` is returned when ``d`` is not a dict, when an intermediate
    value is not a dict, or when any step is missing or maps to ``None``.
    With no keys, ``d`` itself is returned.
    """
    node = d
    if not isinstance(node, dict):
        return default
    for key in keys:
        # A non-dict intermediate is treated exactly like a missing key.
        node = node.get(key, None) if isinstance(node, dict) else None
        if node is None:
            return default
    return node

def _safe_tokenizer_from_pretrained(model_name: str, local_files_only: bool = False):
    """Load an M2M100 tokenizer for ``model_name``, preferring the fast variant.

    ``M2M100TokenizerFast`` is tried first; if importing or loading it fails
    for any reason, ``M2M100Tokenizer`` is used instead.

    Args:
        model_name: Identifier or path passed to ``from_pretrained``.
        local_files_only: Forwarded to ``from_pretrained``.

    Returns:
        The tokenizer instance returned by ``from_pretrained``.

    Raises:
        RuntimeError: If the slow tokenizer cannot be imported or loaded
            either. The underlying exception is attached as ``__cause__``.
    """
    # NOTE(review): the pinned transformers release may not ship a fast M2M100
    # tokenizer at all; in that case this attempt always fails on import and
    # the slow tokenizer below is what actually loads - confirm.
    try:
        from transformers import M2M100TokenizerFast as FastTok
        return FastTok.from_pretrained(model_name, local_files_only=local_files_only)
    except Exception:
        # Any failure here just means "use the slow tokenizer".
        pass
    try:
        from transformers import M2M100Tokenizer
        return M2M100Tokenizer.from_pretrained(model_name, local_files_only=local_files_only)
    except Exception as e:
        # Chain explicitly so the original traceback is kept as the cause.
        raise RuntimeError(f"Failed to load tokenizer for {model_name}: {e}") from e

def initialize_environment():
    """Print a short hardware summary and clear GPU caches when CUDA is present.

    On a CUDA machine each visible GPU is listed with its name and total
    memory in GB ("Unknown" if the query fails) and
    ``_safe_clear_gpu_caches`` is called afterwards. Without CUDA only a
    "CPU only" line is printed.

    Returns:
        Always True.
    """
    print("[PIPELINE] Initializing environment...")

    if not torch.cuda.is_available():
        print("[PIPELINE] CPU only")
        return True

    device_total = torch.cuda.device_count()
    print(f"[PIPELINE] GPUs: {device_total}")
    for dev_idx in range(device_total):
        try:
            label = torch.cuda.get_device_name(dev_idx)
            total_gib = torch.cuda.get_device_properties(dev_idx).total_memory / 1024**3
            print(f"  GPU {dev_idx}: {label} ({total_gib:.1f} GB)")
        except Exception:
            print(f"  GPU {dev_idx}: Unknown")
    _safe_clear_gpu_caches()
    return True

def _resolve_thresholds():
    """Return the ``(span, tau)`` thresholds as floats.

    Each threshold is taken from the first global, in priority order, whose
    value is not ``None``:

    * span: ``TRG_SPAN_THRESHOLD``, ``SPAN_THRESHOLD``, ``TRG_SPAN``
      (fallback 0.15)
    * tau: ``TRG_UNCERTAINTY_THRESHOLD``, ``UNCERTAINTY_THRESHOLD``,
      ``TAU_LOW`` (fallback 0.25)

    A chosen value that ``float()`` rejects also yields the fallback.
    """

    def _pick(names, fallback):
        # Look up every candidate, keep those that are set, use the first.
        present = [value for value in (_g(n, None) for n in names) if value is not None]
        if not present:
            return fallback
        try:
            return float(present[0])
        except Exception:
            return fallback

    span = _pick(("TRG_SPAN_THRESHOLD", "SPAN_THRESHOLD", "TRG_SPAN"), 0.15)
    tau = _pick(("TRG_UNCERTAINTY_THRESHOLD", "UNCERTAINTY_THRESHOLD", "TAU_LOW"), 0.25)
    return span, tau

def main_pipeline() -> Tuple[Any, Any]:
    print("\n" + "=" * 80)
    print("TATN MAIN PIPELINE - COMPLETE INTEGRATION")
    print("=" * 80)

    span_thresh, tau_low = _resolve_thresholds()
    unc_thresh = float(_g("UNCERTAINTY_THRESHOLD", tau_low))
    trg_unc_thresh = float(_g("TRG_UNCERTAINTY_THRESHOLD", unc_thresh))

    print("Configuration:")
    print(f"  - Span threshold: {span_thresh}")
    print(f"  - TAU_LOW: {tau_low}")
    print(f"  - Uncertainty threshold: {unc_thresh}")
    print(f"  - TRG uncertainty threshold: {trg_unc_thresh}")
    print(f"  - Discovery frequency: {_PERIODIC_DISCOVERY_FREQUENCY}")
    print(f"  - ASBN training: {'ENABLED' if _ENABLE_ASBN_TRAINING else 'DISABLED'}")
    print(f"  - Epochs: {_EPOCHS}")
    print(f"  - Batch size: {_BATCH_SIZE}")

    print("=" * 80)
    pipeline_start = time.time()
    if _DEBUG_TIMING:
        phase_start = time.time()

    initialize_environment()
    if _DEBUG_TIMING:
        print(f"[TIMING] Initialization: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    print("\n[PHASE 1] Loading tokenizer...")
    try:
        tokenizer = _safe_tokenizer_from_pretrained("facebook/m2m100_418M")
    except Exception as e:
        print(f"[PHASE 1] Tokenizer load failed: {e}")
        raise

    try:
        if hasattr(tokenizer, "src_lang"):
            tokenizer.src_lang = _SOURCE_LANGUAGE
    except Exception:
        pass

    try:
        vocab_size = getattr(tokenizer, "vocab_size", None)
        if vocab_size is None:
            try:
                vocab_size = len(tokenizer)
            except Exception:
                vocab_size = None
    except Exception:
        vocab_size = None

    print(f"[PHASE 1] Tokenizer loaded (vocab: {vocab_size if vocab_size else 'unknown'})")
    if _DEBUG_TIMING:
        print(f"[TIMING] Tokenizer: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    print(f"\n[PHASE 2] Loading data ({_NUM_SAMPLES} samples)...")
    pairs = None
    if "load_and_preprocess_optimized" in globals() and callable(globals()["load_and_preprocess_optimized"]):
        try:
            pairs = globals()["load_and_preprocess_optimized"](_NUM_SAMPLES)
        except Exception as e:
            print(f"[PHASE 2] Data loading failed: {e}")
            pairs = [("আমি কল বন্ধ করেছি।", "I turned off the tap.")]
    else:
        print("[PHASE 2] Using fallback data")
        pairs = [("আমি কল বন্ধ করেছি।", "I turned off the tap.")]

    if "MemoryEfficientDataset" not in globals() or not callable(globals().get("MemoryEfficientDataset")):
        raise RuntimeError("MemoryEfficientDataset not found - please run Cell 2 or provide dataset implementation")

    DatasetCls = globals()["MemoryEfficientDataset"]
    try:
        dataset = DatasetCls(pairs, tokenizer, max_length=_MAX_LENGTH)
    except Exception:
        class _SimpleDataset:
            def __init__(self, pairs, tokenizer, max_length=48):
                self.pairs = pairs
                self.tokenizer = tokenizer
                self.max_length = max_length
            def __len__(self):
                return len(self.pairs)
            def __getitem__(self, idx):
                s, t = self.pairs[idx]
                enc_s = self.tokenizer(
                    s,
                    truncation=True,
                    max_length=self.max_length,
                    padding="max_length",
                    return_tensors="pt",
                )
                enc_t = self.tokenizer(
                    t,
                    truncation=True,
                    max_length=self.max_length,
                    padding="max_length",
                    return_tensors="pt",
                )
                labels = enc_t["input_ids"].squeeze(0)
                pad_token_id = getattr(self.tokenizer, "pad_token_id", 1)
                labels = labels.masked_fill(labels == pad_token_id, -100)
                return {
                    "input_ids": enc_s["input_ids"].squeeze(0),
                    "attention_mask": enc_s["attention_mask"].squeeze(0),
                    "labels": labels,
                    "src_text": s,
                    "tokens": [],
                    "token_word_map": {},
                    "domain_label": 0,
                }
        dataset = _SimpleDataset(pairs, tokenizer, max_length=_MAX_LENGTH)

    collate_fn = globals().get("safe_collate", None)
    if "create_optimized_dataloader" in globals() and callable(globals()["create_optimized_dataloader"]):
        try:
            train_loader = globals()["create_optimized_dataloader"](
                dataset, batch_size=_BATCH_SIZE, shuffle=True
            )
        except Exception:
            dl_kwargs = {
                "batch_size": _BATCH_SIZE,
                "shuffle": True,
                "num_workers": 0,
                "pin_memory": torch.cuda.is_available(),
            }
            if collate_fn is not None:
                dl_kwargs["collate_fn"] = collate_fn
            train_loader = DataLoader(dataset, **dl_kwargs)
    else:
        dl_kwargs = {
            "batch_size": _BATCH_SIZE,
            "shuffle": True,
            "num_workers": 0,
            "pin_memory": torch.cuda.is_available(),
        }
        if collate_fn is not None:
            dl_kwargs["collate_fn"] = collate_fn
        train_loader = DataLoader(dataset, **dl_kwargs)

    try:
        actual_batches = len(train_loader)
        if actual_batches == 0:
            raise RuntimeError("Dataloader is empty - no batches to train on")
        print(f"[PHASE 2] Dataset: {len(dataset)} samples, {actual_batches} batches")
    except TypeError as e:
        print(f"[PHASE 2] Warning: Cannot determine batch count: {e}")
        print(f"[PHASE 2] Dataset: {len(dataset)} samples")
    except Exception as e:
        print(f"[PHASE 2] Warning: Dataset validation error: {e}")
        print("[PHASE 2] Dataset loaded")
    
    del pairs
    _safe_clear_gpu_caches()
    if _DEBUG_TIMING:
        print(f"[TIMING] Data loading: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    print("\n[PHASE 3] Initializing model...")
    if "MemoryOptimizedTATNWithExplanations" not in globals() or not callable(globals().get("MemoryOptimizedTATNWithExplanations")):
        raise RuntimeError("Model class MemoryOptimizedTATNWithExplanations not found - run Cell 6")

    ModelCls = globals()["MemoryOptimizedTATNWithExplanations"]
    try:
        model_core = ModelCls(tokenizer)
    except Exception as e:
        raise RuntimeError(f"Failed to instantiate model: {e}")

    if _USE_MULTI_GPU and _NUM_GPUS > 1:
        device_ids = list(range(_NUM_GPUS))
        print(f"[PHASE 3] Using DataParallel on {device_ids}")
        model = nn.DataParallel(model_core, device_ids=device_ids)
    else:
        model = model_core

    try:
        model = model.to(_DEVICE)
    except Exception as e:
        print(f"[PHASE 3] Failed to move model to {_DEVICE}: {e}")
        raise

    core_model = model.module if hasattr(model, "module") else model

    try:
        dscd_present = hasattr(core_model, "dscd") and core_model.dscd is not None
        asbn_present = hasattr(core_model, "asbn") and core_model.asbn is not None
        trg_present = hasattr(core_model, "trg_system") and core_model.trg_system is not None
        
        print("[PHASE 3] Component check:")
        print(f"  - DSCD: {'OK' if dscd_present else 'MISSING'}")
        print(f"  - ASBN: {'OK' if asbn_present else 'MISSING'}")
        print(f"  - TRG: {'OK' if trg_present else 'MISSING'}")
        
        if not dscd_present:
            print("[PHASE 3] WARNING: DSCD module not found!")
        if not asbn_present:
            print("[PHASE 3] WARNING: ASBN module not found!")
        if not trg_present:
            print("[PHASE 3] WARNING: TRG module not found!")
            
        if not (dscd_present and asbn_present):
            raise RuntimeError("Critical components missing - ensure Cell 3, 4, 5 ran successfully")
    except Exception as e:
        print(f"[PHASE 3] Component validation failed: {e}")
        raise

    try:
        mbart = getattr(core_model, "mbart", None)
        if mbart is not None:
            try:
                emb = mbart.get_input_embeddings()
                model_emb_count = getattr(emb, "num_embeddings", None)
            except Exception:
                model_emb_count = None
            
            tok_len = vocab_size
            
            if isinstance(model_emb_count, int) and isinstance(tok_len, int):
                if model_emb_count != tok_len:
                    print(f"[PHASE 3] ❌ FATAL: Vocab mismatch: model={model_emb_count}, tokenizer={tok_len}")
                    print(f"[PHASE 3] Cell 6 should have fixed this - check Cell 6 output")
                    raise RuntimeError(f"Vocab size mismatch will cause CUDA errors: {model_emb_count} != {tok_len}")
                else:
                    print(f"[PHASE 3] ✅ Vocab sizes verified: model={model_emb_count}, tokenizer={tok_len}")
            else:
                print(f"[PHASE 3] ⚠️  Cannot verify vocab sizes: model={model_emb_count}, tokenizer={tok_len}")
            
            forced_bos = None
            try:
                if hasattr(tokenizer, "get_lang_id"):
                    try:
                        forced_bos = int(tokenizer.get_lang_id(_TARGET_LANGUAGE))
                    except Exception:
                        pass
                
                if forced_bos is None:
                    forced_bos = _M2M100_EN_TOKEN_ID
                
                if hasattr(mbart, "config"):
                    mbart.config.forced_bos_token_id = int(forced_bos)
                    mbart.config.decoder_start_token_id = int(forced_bos)
                    print(f"[PHASE 3] Set forced_bos_token_id to {forced_bos} for {_TARGET_LANGUAGE}")
            except Exception as e:
                print(f"[PHASE 3] Warning: Failed to set forced_bos_token_id: {e}")
    except Exception as e:
        print(f"[PHASE 3] Model config update failed: {e}")

    if _FREEZE_ENCODER:
        try:
            enc_obj = getattr(core_model, "mbart", None)
            if enc_obj is not None and hasattr(enc_obj, "model") and hasattr(enc_obj.model, "encoder"):
                for p in enc_obj.model.encoder.parameters():
                    p.requires_grad = False
            elif hasattr(core_model, "encoder"):
                for p in core_model.encoder.parameters():
                    p.requires_grad = False
            print("[PHASE 3] Encoder frozen")
        except Exception:
            pass

    print("[PHASE 3] Model initialized")
    if _DEBUG_TIMING:
        print(f"[TIMING] Model init: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    print("\n[PHASE 4] Setting up optimizers...")
    critic_params: List[torch.nn.Parameter] = []
    try:
        if hasattr(core_model, "asbn") and core_model.asbn is not None:
            if hasattr(core_model.asbn, "critic_parameters") and callable(getattr(core_model.asbn, "critic_parameters")):
                try:
                    critic_params = list(core_model.asbn.critic_parameters())
                except Exception:
                    print("[PHASE 4] WARNING: critic_parameters() failed, using all ASBN parameters")
                    critic_params = [p for p in core_model.asbn.parameters()]
            else:
                print("[PHASE 4] WARNING: No critic_parameters() method, using all ASBN parameters")
                critic_params = [p for p in core_model.asbn.parameters()]
    except Exception:
        critic_params = []

    critic_ids = {id(p) for p in critic_params}
    base_params = [p for p in core_model.parameters() if p.requires_grad and id(p) not in critic_ids]

    if not base_params:
        print("[PHASE 4] ERROR: No trainable base parameters found!")
        raise RuntimeError("No parameters to optimize - model might be frozen")

    optimizer = torch.optim.AdamW(base_params, lr=_LR_NMT)

    phi_optimizer = None
    if _ENABLE_ASBN_TRAINING:
        if critic_params:
            phi_params = [p for p in critic_params if p.requires_grad]
            if phi_params:
                phi_optimizer = torch.optim.AdamW(phi_params, lr=_LR_PHI)
                print(f"[PHASE 4] ASBN optimizer created ({len(phi_params)} params)")
            else:
                print("[PHASE 4] WARNING: ASBN critic parameters found but none require grad")
        else:
            print("[PHASE 4] WARNING: ASBN training enabled but no critic parameters found")
    else:
        print("[PHASE 4] ASBN training disabled")

    print(f"[PHASE 4] Base optimizer: {len(base_params)} parameters")
    print("[PHASE 4] Optimizers ready")
    if _DEBUG_TIMING:
        print(f"[TIMING] Optimizers: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    print("\n[PHASE 5] Pre-training validation...")
    try:
        print("[PHASE 5] Checking model vocab alignment...")
        core_check = model.module if hasattr(model, "module") else model
        mbart_check = getattr(core_check, "mbart", None)
        
        if mbart_check:
            emb_check = mbart_check.get_input_embeddings()
            model_vocab = _safe_int(emb_check.num_embeddings, 0)
            tok_vocab = _safe_int(vocab_size, 0) if isinstance(vocab_size, int) else 0
            
            if tok_vocab == 0:
                try:
                    tok_vocab = len(tokenizer)
                except Exception:
                    tok_vocab = 0
            
            print(f"[PHASE 5] Model vocab size: {model_vocab}")
            print(f"[PHASE 5] Tokenizer vocab size: {tok_vocab}")
            
            if model_vocab == 0 or tok_vocab == 0:
                print(f"[PHASE 5] WARNING: Cannot verify vocab sizes")
            elif model_vocab != tok_vocab:
                print(f"[PHASE 5] ❌ FATAL: Vocab mismatch! Training will crash!")
                raise RuntimeError(f"Vocab size mismatch: model={model_vocab}, tokenizer={tok_vocab}")
            else:
                print(f"[PHASE 5] ✅ Vocab sizes match - safe to train")
        
        print("[PHASE 5] Validating first batch...")
        first_batch = next(iter(train_loader))
        
        try:
            for k, v in first_batch.items():
                if isinstance(v, torch.Tensor):
                    first_batch[k] = v.to(_DEVICE)
        except Exception as e:
            print(f"[PHASE 5] Warning: Could not move batch to device: {e}")
        
        input_ids = first_batch["input_ids"]
        labels = first_batch["labels"]
        
        max_input = _safe_int(input_ids.max().item(), 0)
        valid_labels = labels[labels != -100]
        max_label = _safe_int(valid_labels.max().item(), 0) if valid_labels.numel() > 0 else 0
        
        print(f"[PHASE 5] Max input_id: {max_input} (must be < {model_vocab})")
        print(f"[PHASE 5] Max label: {max_label} (must be < {model_vocab})")
        
        if model_vocab > 0 and (max_input >= model_vocab or max_label >= model_vocab):
            print(f"[PHASE 5] ❌ FATAL: Batch contains out-of-range tokens!")
            raise RuntimeError("Dataset contains invalid token IDs")
        else:
            print(f"[PHASE 5] ✅ Batch tokens within valid range")
            
    except StopIteration:
        print("[PHASE 5] ERROR: Dataloader is empty!")
        raise RuntimeError("No batches in train_loader")
    except Exception as e:
        print(f"[PHASE 5] Pre-training validation failed: {e}")
        raise

    print("\n[PHASE 5B] Baseline evaluation...")
    baseline_metrics: Optional[Dict[str, Any]] = None
    try:
        dscd = getattr(core_model, "dscd", None)
        has_prototypes = False
        if dscd is not None:
            proto_stores = getattr(dscd, "prototype_stores", None)
            if proto_stores:
                lock = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
                if lock:
                    with lock:
                        has_prototypes = len(proto_stores) > 0
                else:
                    has_prototypes = len(proto_stores) > 0
        if has_prototypes:
            print("[PHASE 5B] Prototypes exist - skipping baseline")
        elif "comprehensive_post_training_testing" in globals() and callable(globals()["comprehensive_post_training_testing"]):
            result = globals()["comprehensive_post_training_testing"](
                model, tokenizer, run_warmup=False, compare_baseline=False
            )
            if isinstance(result, dict):
                baseline_metrics = result
                print("[PHASE 5B] Baseline complete")
            else:
                print("[PHASE 5B] Baseline returned invalid type")
        else:
            print("[PHASE 5B] Skipping baseline (function not found)")
    except Exception as e:
        print(f"[PHASE 5B] Baseline failed: {e}")
        baseline_metrics = None

    if _DEBUG_TIMING:
        print(f"[TIMING] Pre-training validation: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    _safe_clear_gpu_caches()

    print("\n[PHASE 6] Training...")
    trained_model = model
    training_stats = None
    if "train_memory_efficient_tatn" in globals() and callable(globals()["train_memory_efficient_tatn"]):
        try:
            res = globals()["train_memory_efficient_tatn"](
                model,
                tokenizer,
                train_loader,
                optimizer,
                phi_optimizer=phi_optimizer,
                epochs=_EPOCHS,
                accumulation_steps=_ACCUMULATION_STEPS,
                validate_every=_VALIDATION_CHECK_INTERVAL,
                enable_validation=(_VALIDATION_CHECK_INTERVAL > 0),
                enable_asbn_training=_ENABLE_ASBN_TRAINING,
            )
            if isinstance(res, tuple) and len(res) == 2:
                trained_model, training_stats = res
            else:
                trained_model = res
                training_stats = None
            print("[PHASE 6] Training complete")
        except Exception as e:
            print(f"[PHASE 6] Training failed: {e}")
            if _VERBOSE_LOGGING:
                try:
                    traceback.print_exc()
                except Exception:
                    pass
            trained_model = model
            training_stats = None
    else:
        print("[PHASE 6] Skipping training (function not found)")

    if _DEBUG_TIMING:
        print(f"[TIMING] Training: {time.time() - phase_start:.2f}s")
        phase_start = time.time()
    _safe_clear_gpu_caches()

    print("\n[PHASE 7] Post-training validation...")
    try:
        core_for_validation = trained_model.module if hasattr(trained_model, "module") else trained_model
        dscd = getattr(core_for_validation, "dscd", None)
        if dscd is None:
            print("[PHASE 7] No DSCD module")
        else:
            proto_stores = getattr(dscd, "prototype_stores", None) or {}
            lock = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
            if lock:
                with lock:
                    stores = dict(proto_stores)
            else:
                stores = dict(proto_stores)

            def _store_size(s):
                try:
                    if hasattr(s, "size") and callable(getattr(s, "size")):
                        return int(s.size())
                    cents = getattr(s, "centroids", None)
                    if isinstance(cents, torch.Tensor):
                        return int(cents.size(0))
                    elif isinstance(cents, list):
                        return len(cents)
                    return 0
                except Exception:
                    return 0

            total_protos = sum(_store_size(store) for store in stores.values())
            multi_sense = sum(1 for store in stores.values() if _store_size(store) >= 2)
            print("[PHASE 7] DSCD status:")
            print(f"  - Tokens: {len(stores)}")
            print(f"  - Prototypes: {total_protos}")
            print(f"  - Multi-sense: {multi_sense}")
            if len(stores) == 0 or total_protos == 0:
                print("[PHASE 7] WARNING: No prototypes created during training")
    except Exception as e:
        print(f"[PHASE 7] Validation failed: {e}")

    if _DEBUG_TIMING:
        print(f"[TIMING] Validation: {time.time() - phase_start:.2f}s")
        phase_start = time.time()
    _safe_clear_gpu_caches()

    print("\n[PHASE 8] Post-training evaluation...")
    eval_results: Optional[Dict[str, Any]] = None
    if "comprehensive_post_training_testing" in globals() and callable(globals()["comprehensive_post_training_testing"]):
        try:
            core_for_eval = trained_model.module if hasattr(trained_model, "module") else trained_model
            trg = getattr(core_for_eval, "trg_system", None)
            if trg and hasattr(trg, "reset_statistics"):
                try:
                    trg.reset_statistics()
                except Exception:
                    pass
            result = globals()["comprehensive_post_training_testing"](
                trained_model,
                tokenizer,
                run_warmup=False,
                compare_baseline=(isinstance(baseline_metrics, dict) and len(baseline_metrics) > 0),
                baseline_metrics=baseline_metrics,
            )
            if isinstance(result, dict):
                eval_results = result
                print("[PHASE 8] Evaluation complete")
            else:
                print("[PHASE 8] Evaluation returned invalid type")
        except Exception as e:
            print(f"[PHASE 8] Evaluation failed: {e}")
            eval_results = None
    else:
        print("[PHASE 8] Skipping evaluation (function not found)")

    if _DEBUG_TIMING:
        print(f"[TIMING] Evaluation: {time.time() - phase_start:.2f}s")
        phase_start = time.time()
    _safe_clear_gpu_caches()

    print("\n[PHASE 9] Saving checkpoint...")
    try:
        os.makedirs(_CHECKPOINT_DIR, exist_ok=True)
        core_for_save = trained_model.module if hasattr(trained_model, "module") else trained_model
        was_training = getattr(core_for_save, "training", False)
        core_for_save.eval()
        try:
            model_state = {}
            for k, v in core_for_save.state_dict().items():
                try:
                    if isinstance(v, torch.Tensor):
                        if v.numel() > 0 and torch.isfinite(v).all():
                            model_state[k] = v.cpu().detach().clone()
                        else:
                            print(f"[PHASE 9] WARNING: Skipping corrupted tensor: {k}")
                    else:
                        model_state[k] = v
                except Exception as e:
                    print(f"[PHASE 9] WARNING: Failed to save parameter {k}: {e}")

            dscd_state = {}
            if hasattr(core_for_save, "dscd") and core_for_save.dscd is not None:
                dscd_obj = core_for_save.dscd
                if hasattr(dscd_obj, "state_dict") and callable(getattr(dscd_obj, "state_dict")):
                    try:
                        raw_state = dscd_obj.state_dict()
                        if isinstance(raw_state, dict) and "prototype_stores_data" in raw_state:
                            stores = raw_state.get("prototype_stores_data", {}) or {}
                            valid_stores = {}
                            corrupted_count = 0
                            for token, store_dict in stores.items():
                                if not isinstance(store_dict, dict):
                                    corrupted_count += 1
                                    continue
                                centroids = store_dict.get("centroids", None)
                                counts = store_dict.get("counts", None)
                                is_valid = False
                                try:
                                    if isinstance(centroids, torch.Tensor):
                                        if centroids.numel() > 0 and centroids.dim() == 2 and torch.isfinite(centroids).all():
                                            if isinstance(counts, list) and len(counts) == centroids.size(0) and sum(counts) > 0:
                                                is_valid = True
                                    elif isinstance(centroids, list) and len(centroids) > 0:
                                        if isinstance(counts, list) and len(counts) == len(centroids) and sum(counts) > 0:
                                            is_valid = True
                                except Exception:
                                    is_valid = False
                                if is_valid:
                                    valid_stores[token] = store_dict
                                else:
                                    corrupted_count += 1
                            dscd_state = {"prototype_stores_data": valid_stores}
                            if corrupted_count > 0:
                                print(f"[PHASE 9] Filtered {corrupted_count} corrupted prototype stores")
                    except Exception as e:
                        print(f"[PHASE 9] DSCD state extraction failed: {e}")
                        dscd_state = {}

            optimizer_state = None
            try:
                optimizer_state = optimizer.state_dict()
            except Exception:
                optimizer_state = None

            phi_optimizer_state = None
            try:
                if phi_optimizer is not None:
                    phi_optimizer_state = phi_optimizer.state_dict()
            except Exception:
                phi_optimizer_state = None

            checkpoint = {
                "model_state_dict": model_state,
                "dscd_state": dscd_state,
                "optimizer_state_dict": optimizer_state,
                "phi_optimizer_state_dict": phi_optimizer_state,
                "baseline_metrics": baseline_metrics,
                "eval_results": eval_results,
                "training_stats": training_stats,
                "timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
                "config": {
                    "epochs": _EPOCHS,
                    "batch_size": _BATCH_SIZE,
                    "span_threshold": span_thresh,
                    "tau_low": tau_low,
                    "uncertainty_threshold": unc_thresh,
                    "trg_uncertainty_threshold": trg_unc_thresh,
                    "discovery_frequency": _PERIODIC_DISCOVERY_FREQUENCY,
                    "enable_asbn_training": _ENABLE_ASBN_TRAINING,
                    "enable_trg_inference": _ENABLE_TRG_INFERENCE,
                },
            }
            torch.save(checkpoint, _CHECKPOINT_PATH)
            try:
                verify = torch.load(_CHECKPOINT_PATH, map_location="cpu")
                has_model = isinstance(verify.get("model_state_dict", None), dict) and len(verify.get("model_state_dict", {})) > 0
                has_dscd = isinstance(verify.get("dscd_state", None), dict) and len(verify.get("dscd_state", {})) > 0
                has_phi = verify.get("phi_optimizer_state_dict", None) is not None
                
                print(f"[PHASE 9] Checkpoint saved: {_CHECKPOINT_PATH}")
                print(f"  - Model: {'OK' if has_model else 'MISSING'}")
                print(f"  - DSCD: {'OK' if has_dscd else 'MISSING'}")
                print(f"  - Phi optimizer: {'OK' if has_phi else 'N/A'}")
                
                if has_model:
                    try:
                        test_load = {}
                        corrupted_tensors = 0
                        for k, v in verify["model_state_dict"].items():
                            if isinstance(v, torch.Tensor):
                                if torch.isfinite(v).all():
                                    test_load[k] = v
                                else:
                                    corrupted_tensors += 1
                        if corrupted_tensors > 0:
                            print(f"  - WARNING: {corrupted_tensors} corrupted tensors in checkpoint")
                        else:
                            print(f"  - Model state loadable: OK ({len(test_load)} tensors)")
                    except Exception as e:
                        print(f"  - Model state loadable: FAILED ({e})")
                
                if has_dscd:
                    d_state = verify.get("dscd_state", {})
                    stores_data = d_state.get("prototype_stores_data", {}) if isinstance(d_state, dict) else {}
                    num_tokens = len(stores_data) if isinstance(stores_data, dict) else 0
                    multi_sense_count = 0
                    if isinstance(stores_data, dict):
                        for sd in stores_data.values():
                            try:
                                cent = sd.get("centroids", None)
                                if isinstance(cent, torch.Tensor) and cent.size(0) >= 2:
                                    multi_sense_count += 1
                                elif isinstance(cent, list) and len(cent) >= 2:
                                    multi_sense_count += 1
                            except Exception:
                                continue
                    print(f"  - DSCD tokens: {num_tokens}")
                    print(f"  - Multi-sense: {multi_sense_count}")
            except Exception as e:
                print(f"[PHASE 9] Checkpoint written but verification failed: {e}")
        finally:
            if was_training:
                try:
                    core_for_save.train()
                except Exception:
                    pass
    except Exception as e:
        print(f"[PHASE 9] Checkpoint failed: {e}")
        if _VERBOSE_LOGGING:
            try:
                traceback.print_exc()
            except Exception:
                pass

    if _DEBUG_TIMING:
        print(f"[TIMING] Checkpoint: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    print("\n[PHASE 10] Final component validation...")
    try:
        core_final = trained_model.module if hasattr(trained_model, "module") else trained_model
        dscd_ok = False
        if hasattr(core_final, "dscd"):
            proto_stores = getattr(core_final.dscd, "prototype_stores", None)
            if proto_stores:
                lock = getattr(core_final.dscd, "buffer_lock", None) or getattr(core_final.dscd, "clustering_lock", None)
                if lock:
                    with lock:
                        dscd_ok = len(proto_stores) > 0
                else:
                    dscd_ok = len(proto_stores) > 0
        asbn_ok = hasattr(core_final, "asbn") and callable(getattr(core_final.asbn, "forward", None))
        trg_ok = hasattr(core_final, "trg_system") and callable(getattr(core_final.trg_system, "process_sentence_for_explanations", None))
        print("[PHASE 10] Component validation:")
        print(f"  - DSCD: {'OK' if dscd_ok else 'MISSING'}")
        print(f"  - ASBN: {'OK' if asbn_ok else 'MISSING'}")
        print(f"  - TRG: {'OK' if trg_ok else 'MISSING'}")
    except Exception as e:
        print(f"[PHASE 10] Validation failed: {e}")

    pipeline_time = time.time() - pipeline_start
    print("\n" + "=" * 80)
    print("PIPELINE COMPLETE - FINAL SUMMARY")
    print("=" * 80)
    print(f"[TIMING] Total time: {pipeline_time:.2f}s ({pipeline_time/60:.2f} min)")
    print(f"[TRAINING] Completed {_EPOCHS} epoch(s)")

    if isinstance(baseline_metrics, dict) and isinstance(eval_results, dict):
        baseline_success = _safe_float(baseline_metrics.get("success_rate_pct", 0.0), 0.0)
        final_success = _safe_float(eval_results.get("success_rate_pct", 0.0), 0.0)
        improvement = final_success - baseline_success
        print(
            f"[EVALUATION] Baseline -> Final: {baseline_success:.1f}% -> "
            f"{final_success:.1f}%, Improvement: {improvement:+.1f}%"
        )
    elif isinstance(eval_results, dict):
        final_success = _safe_float(eval_results.get("success_rate_pct", 0.0), 0.0)
        print(f"[EVALUATION] Success rate: {final_success:.1f}%")
    else:
        print("[EVALUATION] No results")

    print(f"\n[CHECKPOINT] {_CHECKPOINT_PATH if os.path.exists(_CHECKPOINT_PATH) else 'Not saved'}")
    if os.path.exists(_CHECKPOINT_PATH):
        try:
            size_mb = os.path.getsize(_CHECKPOINT_PATH) / 1024**2
            print(f"  - Size: {size_mb:.2f} MB")
        except Exception:
            pass

    print("\n" + "=" * 80)
    print("Usage: trained_model, tokenizer = main_pipeline()")
    print("=" * 80)

    _safe_clear_gpu_caches()
    return trained_model, tokenizer

# Cell 10 footer: announce that the pipeline definition is loaded and list the
# fixes applied in this revision. A single print call with a newline separator
# emits exactly the same text as one print per line.
print(
    "\n" + "=" * 80,
    "Cell 10: Main pipeline ready - FIXED",
    "=" * 80,
    "KEY FIXES APPLIED:",
    "  ✓ REMOVED tokenizer modification code (lines 230-245)",
    "  ✓ Cell 6 now handles all vocab resizing automatically",
    "  ✓ Simplified vocab size handling (no string 'unknown')",
    "  ✓ Enhanced error messages with emoji indicators",
    "  ✓ All vocab checks now fail-fast with RuntimeError",
    "  ✓ Removed 'tokenizer_modified' from checkpoint config",
    sep="\n",
)



In [None]:
# ==============================================================================
# CELL 11: MAIN EXECUTION WRAPPER (FINAL) - FIXED
# ==============================================================================
from datetime import datetime, timezone
import os
import traceback
import math
import sys
import time
import torch
import gc
from collections import defaultdict
from typing import Any, Dict, Set

# Safe configuration load with defaults (robust retrieval from globals)
def _g(name, default):
    try:
        return globals().get(name, default)
    except Exception:
        return default

try:
    _NUM_SAMPLES = int(_g("NUM_SAMPLES", 30000))
    _EPOCHS = int(_g("EPOCHS", 2))
    _BATCH_SIZE = int(_g("BATCH_SIZE", 4))
    _ACCUMULATION_STEPS = int(_g("ACCUMULATION_STEPS", 16))

    raw_device = _g("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(raw_device, torch.device):
        _DEVICE = raw_device
    else:
        try:
            _DEVICE = torch.device(str(raw_device))
        except Exception:
            _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    _ENABLE_ASBN_TRAINING = bool(_g("ENABLE_ASBN_TRAINING", True))
    _ENABLE_TRG_INFERENCE = bool(_g("ENABLE_TRG_INFERENCE", True))
    _PERIODIC_DISCOVERY_FREQUENCY = int(_g("PERIODIC_DISCOVERY_FREQUENCY", 150))
    _VERBOSE_LOGGING = bool(_g("VERBOSE_LOGGING", False))
    _DEBUG_DISCOVERY = bool(_g("DEBUG_DISCOVERY", False))
    _DEBUG_TIMING = bool(_g("DEBUG_TIMING", False))
    _NUM_GPUS = int(_g("NUM_GPUS", torch.cuda.device_count() if torch.cuda.is_available() else 0))
    _USE_MULTI_GPU = bool(_g("USE_MULTI_GPU", _NUM_GPUS > 1))

    # TRG/ambiguity thresholds - prefer TRG-specific values if available
    _SPAN_THRESHOLD = float(_g("TRG_SPAN_THRESHOLD", _g("SPAN_THRESHOLD", 0.15)))
    _TAU_LOW = float(_g("TAU_LOW", 0.25))
    _UNCERTAINTY_THRESHOLD = float(_g("UNCERTAINTY_THRESHOLD", _TAU_LOW))
    _TRG_UNCERTAINTY_THRESHOLD = float(_g("TRG_UNCERTAINTY_THRESHOLD", _UNCERTAINTY_THRESHOLD))

    _M2M100_EN_TOKEN_ID = int(_g("M2M100_EN_TOKEN_ID", 128022))
    _M2M100_BN_TOKEN_ID = int(_g("M2M100_BN_TOKEN_ID", 128025))

    raw_list = _g("HOMOGRAPH_REFERENCE_LIST_BN", ["কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা"])
    _HOMOGRAPH_REFERENCE_LIST_BN = set(str(w).strip().lower() for w in raw_list if w is not None)

    cell0_loaded = "NUM_SAMPLES" in globals()
except Exception as e:
    # Fallbacks if global read fails
    print(f"[EXEC] Config load error: {e}")
    _NUM_SAMPLES = 30000
    _EPOCHS = 2
    _BATCH_SIZE = 4
    _ACCUMULATION_STEPS = 16
    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _ENABLE_ASBN_TRAINING = True
    _ENABLE_TRG_INFERENCE = True
    _PERIODIC_DISCOVERY_FREQUENCY = 150
    _VERBOSE_LOGGING = False
    _DEBUG_DISCOVERY = False
    _DEBUG_TIMING = False
    _NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
    _USE_MULTI_GPU = (_NUM_GPUS > 1)
    _SPAN_THRESHOLD = 0.15
    _TAU_LOW = 0.25
    _UNCERTAINTY_THRESHOLD = 0.25
    _TRG_UNCERTAINTY_THRESHOLD = 0.25
    _M2M100_EN_TOKEN_ID = 128022
    _M2M100_BN_TOKEN_ID = 128025
    _HOMOGRAPH_REFERENCE_LIST_BN = {"কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা"}
    cell0_loaded = False

_CHECKPOINT_PATH = _g("CHECKPOINT_PATH", _g("CHECKPOINT_DIR", "/kaggle/working") + "/tatn_final.pt")


def _safe_div_ceil(a: int, b: int) -> int:
    try:
        a_i = int(a)
        b_i = int(b)
        if b_i <= 0:
            return 0
        return math.ceil(a_i / b_i)
    except Exception:
        return 0


def _format_duration(seconds: float) -> str:
    try:
        seconds = float(seconds)
        if seconds < 60:
            return f"{seconds:.1f}s"
        if seconds < 3600:
            return f"{seconds/60:.1f}min"
        return f"{seconds/3600:.2f}hr"
    except Exception:
        return "N/A"


def _safe_get(d: dict, *keys, default=None):
    if not isinstance(d, dict):
        return default
    result = d
    for key in keys:
        if not isinstance(result, dict):
            return default
        result = result.get(key, default)
        if result is default:
            return default
    return result


def _get_dscd_homographs(model) -> Set[str]:
    """Return the cleaned tokens for which DSCD holds at least two prototypes.

    The model is unwrapped through ``.module`` when present, and the
    ``prototype_stores`` mapping of its ``dscd`` attribute is copied (under
    ``buffer_lock`` or ``clustering_lock`` when one exists). Token keys are
    stripped of sub-word markers and lower-cased; when several raw tokens
    collapse to the same word, the largest prototype count is kept.

    Any failure results in an empty set rather than an exception.
    """
    subword_markers = ("▁", "Ġ", "##", "@@", "</w>")

    def _prototype_count(store) -> int:
        # Prefer a callable size(); otherwise fall back to len(centroids).
        if hasattr(store, "size") and callable(getattr(store, "size")):
            try:
                return int(store.size())
            except Exception:
                return 0
        centroids = getattr(store, "centroids", None)
        try:
            return 0 if centroids is None else len(centroids)
        except Exception:
            return 0

    try:
        core = getattr(model, "module", model)
        dscd = getattr(core, "dscd", None)
        if not (dscd and hasattr(dscd, "prototype_stores")):
            return set()

        def _snapshot() -> dict:
            # Shallow copy so iteration is independent of later mutations.
            return dict(getattr(dscd, "prototype_stores", {}) or {})

        lock = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
        if lock:
            with lock:
                stores = _snapshot()
        else:
            stores = _snapshot()

        best_count = defaultdict(int)
        for raw_token, store in stores.items():
            try:
                count = _prototype_count(store)
                word = str(raw_token)
                for marker in subword_markers:
                    word = word.replace(marker, "")
                word = word.strip().lower()
            except Exception:
                # A store or key that cannot be inspected is simply skipped.
                continue
            if word:
                best_count[word] = max(best_count[word], count)

        return {word for word, count in best_count.items() if count >= 2}
    except Exception:
        return set()


def _safe_cleanup():
    """Release cached CUDA memory on every visible device and run the GC.

    Best-effort: a failure on one device does not stop the others, and no
    exception escapes. The garbage collector is only invoked when automatic
    collection is enabled.
    """
    try:
        device_total = torch.cuda.device_count() if torch.cuda.is_available() else 0
        for device_index in range(device_total):
            try:
                with torch.cuda.device(device_index):
                    torch.cuda.empty_cache()
            except Exception:
                pass
        if gc.isenabled():
            gc.collect()
    except Exception:
        pass


# Entrypoint: when this cell runs as the main module it invokes the training
# pipeline, then verifies the result (threads, prototypes, checkpoint,
# components, a few test translations) and prints an execution summary.
if __name__ == "__main__":
    print("=" * 80)
    print("MEMORY-OPTIMIZED TATN - COMPLETE EXECUTION")
    print("=" * 80)

    # Label for the run report: first non-empty of KAGGLE_USERNAME, USER, then a
    # hard-coded fallback name.
    user_login = os.getenv("KAGGLE_USERNAME") or os.getenv("USER") or "manas0003"
    # Wall-clock start; the execution summary at the end reports the elapsed time.
    start_time = time.time()
    # NOTE(review): `datetime` and `timezone` are not imported in the visible part
    # of this file - confirm they are imported earlier in this cell.
    now_utc = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")

    print(f"User: {user_login}")
    print(f"Started: {now_utc}")

    # Echo the effective settings. `cell0_loaded` tells whether they came from the
    # Cell 0 globals or from the hard-coded fallbacks.
    print("\n[CONFIGURATION]")
    print(f"  Cell 0 status: {'Loaded' if cell0_loaded else 'Using fallbacks'}")
    print(f"  Samples: {_NUM_SAMPLES}")
    print(f"  Epochs: {_EPOCHS}")
    print(f"  Batch Size: {_BATCH_SIZE}")
    print(f"  Accumulation: {_ACCUMULATION_STEPS}")
    print(f"  Device: {_DEVICE}")
    print(f"  Multi-GPU: {'ENABLED' if _USE_MULTI_GPU else 'DISABLED'} ({_NUM_GPUS} GPUs)")
    print(f"  Span threshold: {_SPAN_THRESHOLD}")
    print(f"  TRG uncertainty threshold: {_TRG_UNCERTAINTY_THRESHOLD}")
    print(f"  Discovery frequency: {_PERIODIC_DISCOVERY_FREQUENCY}")
    print(f"  ASBN: {'Enabled' if _ENABLE_ASBN_TRAINING else 'Disabled'}")
    print(f"  TRG: {'Enabled' if _ENABLE_TRG_INFERENCE else 'Disabled'}")
    print(f"  Debug: {'Enabled' if _DEBUG_DISCOVERY else 'Disabled'}")
    print("=" * 80)

    # Outcome state: filled in by the pipeline run below and read by the
    # success/failure reporting that follows it.
    trained_model, tokenizer = None, None
    pipeline_success = False
    failure_category = None
    failure_details = ""

    # Ensure main pipeline exists
    # The lookup goes through globals() so that a missing definition is reported
    # as a MISSING_DEPENDENCY failure instead of raising NameError.
    if "main_pipeline" not in globals() or not callable(globals().get("main_pipeline")):
        print("\nERROR: main_pipeline not found")
        print("   -> Run Cell 10 before executing Cell 11")
        failure_category = "MISSING_DEPENDENCY"
        failure_details = "Cell 10 not executed or main_pipeline not defined"
    else:
        try:
            print("\nStarting pipeline...")
            if _DEBUG_TIMING:
                print("   Expected runtime: depends on config and hardware")

            pipeline_start = time.time()
            # main_pipeline returns (model, tokenizer) in this notebook
            res = globals()["main_pipeline"]()
            # support both returning (model, tokenizer) or model only
            if isinstance(res, tuple) and len(res) >= 2:
                trained_model, tokenizer = res[0], res[1]
            elif res is None:
                trained_model, tokenizer = None, None
            else:
                # Model-only return: fall back to a global `tokenizer`, if any.
                trained_model = res
                tokenizer = globals().get("tokenizer", None)

            pipeline_duration = time.time() - pipeline_start
            print(f"\nPipeline completed: {_format_duration(pipeline_duration)}")
            # Returning without raising is not enough: both objects must exist.
            if trained_model is None or tokenizer is None:
                pipeline_success = False
                failure_category = "PIPELINE_NO_OUTPUT"
                failure_details = "main_pipeline did not return (model, tokenizer)"
            else:
                pipeline_success = True

        # KeyboardInterrupt does not derive from Exception, so it needs its own
        # handler to reach the failure report instead of aborting the cell.
        except KeyboardInterrupt:
            print("\nInterrupted by user")
            failure_category = "USER_INTERRUPT"
            failure_details = "Manual stop"

        # RuntimeErrors are classified by substring match on the lower-cased
        # message: tokenizer problems, out-of-memory, or generic runtime error.
        except RuntimeError as e:
            msg = str(e).lower()
            if "tokenizer" in msg or "sentencepiece" in msg:
                print("\nTokenizer error")
                failure_category = "TOKENIZER_ERROR"
                failure_details = str(e)[:200]
                print("\nFix:")
                print("   ! pip install transformers==4.30.2 sentencepiece tokenizers")
                print("   Then RESTART kernel and re-run Cells 0-11")
            elif "out of memory" in msg:
                print("\nOut of Memory")
                failure_category = "OOM_ERROR"
                failure_details = "GPU OOM"
                print("\nFixes:")
                print("   1. Reduce BATCH_SIZE (try 2-4)")
                print("   2. Reduce NUM_SAMPLES (try 10k-20k)")
                print("   3. Increase ACCUMULATION_STEPS (32-64)")
            else:
                print(f"\nRuntime error: {type(e).__name__}")
                print(f"   {str(e)[:400]}")
                failure_category = "RUNTIME_ERROR"
                failure_details = str(e)[:200]
            # Full traceback only when one of the debug/verbose flags is set.
            # NOTE(review): `traceback` is not imported in the visible part of this
            # file - confirm it is imported earlier; a NameError here is swallowed.
            if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
                print("\n[TRACEBACK]")
                try:
                    traceback.print_exc()
                except Exception:
                    pass

        # Anything else is reported as UNKNOWN_ERROR; the cell itself never raises.
        except Exception as e:
            print(f"\nUnexpected error: {type(e).__name__}")
            print(f"   {str(e)[:400]}")
            failure_category = "UNKNOWN_ERROR"
            failure_details = str(e)[:200]
            if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
                print("\n[TRACEBACK]")
                try:
                    traceback.print_exc()
                except Exception:
                    pass

    # If pipeline succeeded, proceed to verification and light evaluation
    if pipeline_success and trained_model is not None and tokenizer is not None:
        print("\n" + "=" * 80)
        print("PIPELINE SUCCEEDED")
        print("=" * 80)

        # THREAD CLEANUP: wait for clustering threads/futures to finish (safe, bounded wait)
        # Workers are taken from dscd.active_threads and handled by duck typing:
        # objects with .result() are waited on like futures, objects with
        # is_alive()/join() like threads. Each wait is capped at 10 seconds.
        print("\n[THREAD CLEANUP]")
        print("Waiting for clustering threads to complete (bounded)...")
        try:
            # Unwrap a wrapper that exposes the real model as `.module`.
            core = trained_model.module if hasattr(trained_model, "module") else trained_model
            dscd = getattr(core, "dscd", None)
            if dscd and hasattr(dscd, "active_threads"):
                try:
                    # Snapshot the worker list (under a lock when one exists) so the
                    # waits below do not hold the lock.
                    lock = getattr(dscd, "thread_lock", None) or getattr(dscd, "buffer_lock", None)
                    if lock:
                        with lock:
                            threads = list(dscd.active_threads) if hasattr(dscd.active_threads, "__iter__") else []
                    else:
                        threads = list(dscd.active_threads) if hasattr(dscd.active_threads, "__iter__") else []

                    if threads:
                        print(f"   Found {len(threads)} active threads/futures")
                        completed, timed_out = 0, 0
                        for i, worker in enumerate(threads):
                            try:
                                if hasattr(worker, "result") and callable(getattr(worker, "result")):
                                    # Future-like worker. Note that any exception from
                                    # result(), not only a timeout, counts as timed out.
                                    try:
                                        worker.result(timeout=10)
                                        completed += 1
                                    except Exception:
                                        timed_out += 1
                                elif hasattr(worker, "is_alive") and hasattr(worker, "join"):
                                    # Thread-like worker: join with a timeout, then
                                    # re-check liveness to tell whether it finished.
                                    if worker.is_alive():
                                        try:
                                            worker.join(timeout=10)
                                        except Exception:
                                            pass
                                        if worker.is_alive():
                                            timed_out += 1
                                        else:
                                            completed += 1
                                    else:
                                        completed += 1
                                else:
                                    # Unknown worker type: nothing to wait on.
                                    completed += 1
                            except Exception:
                                timed_out += 1
                                continue
                        # Drop all tracked workers, including those that timed out.
                        try:
                            if lock:
                                with lock:
                                    try:
                                        dscd.active_threads.clear()
                                    except Exception:
                                        pass
                            else:
                                try:
                                    dscd.active_threads.clear()
                                except Exception:
                                    pass
                        except Exception:
                            pass
                        print(f"   Cleanup complete: {completed} completed, {timed_out} timed out")
                        if timed_out > 0 and _DEBUG_DISCOVERY:
                            print(f"   Warning: {timed_out} workers abandoned")
                    else:
                        print("   No active threads/futures")
                except Exception as e:
                    print(f"   Thread cleanup error: {type(e).__name__}: {str(e)[:100]}")
            else:
                print("   No thread tracking available")
        except Exception as e:
            print(f"   Thread cleanup failed: {type(e).__name__}: {str(e)[:100]}")

        # Short fixed pause before the evaluation steps start.
        time.sleep(0.5)
        print("   Ready for evaluation")

        # PROTOTYPE EXTRACTION & VERIFICATION (safe extraction)
        # Copies the DSCD prototype stores into plain Python lists and, when a
        # checkpoint file exists, makes sure the checkpoint carries them.
        print("\n[PROTOTYPE EXTRACTION & VERIFICATION]")
        print("Extracting and validating DSCD prototypes...")
        # These three values are read again by the final report sections.
        prototype_extraction_success = False
        total_prototypes_extracted = 0
        multi_sense_tokens_extracted = 0

        try:
            core = trained_model.module if hasattr(trained_model, "module") else trained_model
            dscd = getattr(core, "dscd", None)
            if dscd and hasattr(dscd, "prototype_stores"):
                # Shallow-copy the store mapping, under a lock when one is available.
                lock = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
                if lock:
                    with lock:
                        stores = dict(getattr(dscd, "prototype_stores", {}) or {})
                else:
                    stores = dict(getattr(dscd, "prototype_stores", {}) or {})

                prototype_data = {}
                for token, store in stores.items():
                    try:
                        centroids = getattr(store, "centroids", None)
                        counts = getattr(store, "counts", None)
                        # Centroids may be one tensor (rows counted) or a list.
                        cent_count = 0
                        if isinstance(centroids, torch.Tensor):
                            cent_count = centroids.size(0)
                        elif isinstance(centroids, list):
                            cent_count = len(centroids)
                        # Export only consistent stores: at least one centroid, a
                        # list of counts of the same length, and a positive total.
                        if cent_count > 0 and isinstance(counts, list) and len(counts) == cent_count and sum(counts) > 0:
                            if isinstance(centroids, torch.Tensor):
                                centroids_out = centroids.detach().cpu().tolist()
                            else:
                                centroids_out = []
                                for c in centroids:
                                    if isinstance(c, torch.Tensor):
                                        centroids_out.append(c.detach().cpu().tolist())
                                    else:
                                        # Non-tensor centroid: convert to floats when
                                        # possible, otherwise keep the value as is.
                                        try:
                                            centroids_out.append([float(x) for x in c])
                                        except Exception:
                                            centroids_out.append(c)
                            prototype_data[str(token)] = {"centroids": centroids_out, "counts": [int(c) for c in counts]}
                            # The total covers every exported token, not only the
                            # multi-sense ones counted below.
                            total_prototypes_extracted += len(counts)
                            if len(counts) >= 2:
                                multi_sense_tokens_extracted += 1
                    except Exception:
                        if _DEBUG_DISCOVERY:
                            try:
                                traceback.print_exc()
                            except Exception:
                                pass
                        continue

                print(
                    f"  Extracted prototypes: tokens={len(prototype_data)} total_protos={total_prototypes_extracted} multi_sense={multi_sense_tokens_extracted}"
                )
                if len(prototype_data) == 0:
                    print("  WARNING: No prototypes extracted!")
                elif total_prototypes_extracted == 0:
                    print("  WARNING: Prototypes extracted but all empty!")
                else:
                    # Success is recorded before, and regardless of, the checkpoint
                    # refresh below.
                    prototype_extraction_success = True
                    if os.path.exists(_CHECKPOINT_PATH):
                        try:
                            ckpt = torch.load(_CHECKPOINT_PATH, map_location="cpu")
                            dscd_state = ckpt.get("dscd_state", {}) or {}
                            stored = dscd_state.get("prototype_stores_data", {}) if isinstance(dscd_state, dict) else {}
                            # Rewrite the checkpoint only when the stored mapping is
                            # not a dict or has a different number of tokens; the
                            # contents themselves are not compared.
                            if not isinstance(stored, dict) or len(stored) != len(prototype_data):
                                dscd_state["prototype_stores_data"] = prototype_data
                                ckpt["dscd_state"] = dscd_state
                                ckpt["prototype_stats"] = {
                                    "total_tokens": len(prototype_data),
                                    "total_prototypes": total_prototypes_extracted,
                                    "multi_sense_tokens": multi_sense_tokens_extracted,
                                }
                                torch.save(ckpt, _CHECKPOINT_PATH)
                                print("  Checkpoint updated with verified prototypes")
                            else:
                                print("  Checkpoint prototypes match extracted data")
                        except Exception as e:
                            print(f"  Failed to update checkpoint: {type(e).__name__}: {str(e)[:100]}")
            else:
                print("  DSCD module not found or has no prototype_stores")
        except Exception as e:
            print(f"  Prototype extraction failed: {type(e).__name__}: {str(e)[:150]}")

        # CHECKPOINT BASIC VALIDATION
        # Reloads the checkpoint file and sanity-checks it. The checkpoint is
        # VALID only when it holds a non-empty model state dict AND a non-empty
        # DSCD prototype mapping in which fewer than 10% of the stores are
        # malformed. The outcome is stored in `checkpoint_valid`, which the
        # execution summary reads.
        print("\n[CHECKPOINT]")
        checkpoint_valid = False
        try:
            if os.path.exists(_CHECKPOINT_PATH):
                size_mb = os.path.getsize(_CHECKPOINT_PATH) / (1024 ** 2)
                print(f"  File: {_CHECKPOINT_PATH} ({size_mb:.1f} MB)")
                ckpt = torch.load(_CHECKPOINT_PATH, map_location="cpu")
                model_state = ckpt.get("model_state_dict", {})
                has_model = isinstance(model_state, dict) and len(model_state) > 0
                dscd_state = ckpt.get("dscd_state", {}) or {}
                # Normalise the prototype mapping to a dict up front. A missing
                # key, an explicit None or any other non-dict payload means
                # "no DSCD data" instead of raising from len() and aborting the
                # whole validation.
                stores_data = dscd_state.get("prototype_stores_data") if isinstance(dscd_state, dict) else None
                if not isinstance(stores_data, dict):
                    stores_data = {}
                has_dscd = len(stores_data) > 0
                print(f"  Model: {'Present' if has_model else 'MISSING'}")
                print(f"  DSCD: {'Present' if has_dscd else 'MISSING'}")
                if has_dscd:
                    num_tokens = len(stores_data)
                    corrupted = 0
                    for token, sd in stores_data.items():
                        if not isinstance(sd, dict):
                            corrupted += 1
                            continue
                        cent = sd.get("centroids", [])
                        counts = sd.get("counts", [])
                        # Healthy store: at least one centroid, one count per
                        # centroid and a positive total. Fields that cannot be
                        # measured (None, non-numeric counts, ...) mark the store
                        # as corrupted instead of aborting the scan.
                        try:
                            healthy = len(cent) > 0 and len(cent) == len(counts) and sum(counts) > 0
                        except Exception:
                            healthy = False
                        if not healthy:
                            corrupted += 1
                    print(f"  Tokens: {num_tokens}")
                    if corrupted > 0:
                        print(f"  WARNING: {corrupted} corrupted stores")
                    # num_tokens > 0 is guaranteed by has_dscd at this point.
                    if not has_model:
                        # Prototypes without model weights are not a usable checkpoint.
                        print("  Status: MISSING MODEL")
                    elif corrupted < num_tokens * 0.1:
                        checkpoint_valid = True
                        print("  Status: VALID")
                    else:
                        print("  Status: CORRUPTED")
                else:
                    print("  Status: MISSING DSCD")
            else:
                print(f"  NOT FOUND: {_CHECKPOINT_PATH}")
        except Exception as e:
            print(f"  Validation failed: {type(e).__name__}: {str(e)[:100]}")

        # COMPONENT SUMMARY
        # Prints headline statistics for the dscd, asbn and trg_system attributes
        # of the model, each only when the attribute and its statistics method
        # exist. Failures are ignored: this section is informational only.
        print("\n[COMPONENTS]")
        try:
            core = trained_model.module if hasattr(trained_model, "module") else trained_model
            dscd = getattr(core, "dscd", None)
            if dscd and hasattr(dscd, "get_prototype_summary"):
                try:
                    # The summary is requested while holding the DSCD lock, if any.
                    lock = getattr(dscd, "buffer_lock", None) or getattr(dscd, "clustering_lock", None)
                    if lock:
                        with lock:
                            dscd_stats = dscd.get_prototype_summary()
                    else:
                        dscd_stats = dscd.get_prototype_summary()
                    if isinstance(dscd_stats, dict):
                        print("  DSCD:")
                        print(f"    - Tokens: {dscd_stats.get('total_tokens', 0)}")
                        print(f"    - Prototypes: {dscd_stats.get('total_prototypes', 0)}")
                        print(f"    - Homographs: {dscd_stats.get('num_homographs', 0)}")
                except Exception:
                    pass
            asbn = getattr(core, "asbn", None)
            if asbn and hasattr(asbn, "get_detailed_stats"):
                try:
                    result = asbn.get_detailed_stats()
                    if isinstance(result, dict):
                        print("  ASBN:")
                        # Formatted as a percentage; missing value is shown as 0.
                        print(f"    - Domain accuracy: {result.get('domain_accuracy', 0):.2%}")
                except Exception:
                    pass
            trg = getattr(core, "trg_system", None)
            if trg and hasattr(trg, "get_statistics"):
                try:
                    result = trg.get_statistics()
                    if isinstance(result, dict):
                        print("  TRG:")
                        print(f"    - Explanations: {result.get('explanations_generated', 0)}")
                except Exception:
                    pass
        except Exception:
            pass

        # INFERENCE VALIDATION (light)
        # Translates three fixed Bengali sentences, each built around one word
        # from the fallback homograph list, and reports the translation, timing
        # and any explanations returned for them.
        print("\n[INFERENCE VALIDATION]")
        print("Testing disambiguation on short set of sentences...")
        print("-" * 80)
        _safe_cleanup()

        inference_success = 0
        inference_failed = 0
        # Subset of the DSCD homographs that actually shows up in an explanation.
        dscd_homographs_detected = set()
        dscd_homographs = _get_dscd_homographs(trained_model)
        print(f"DSCD discovered: {len(dscd_homographs)} homographs")
        if dscd_homographs and _DEBUG_DISCOVERY:
            print(f"  Sample: {list(sorted(dscd_homographs))[:10]}")

        # (sentence, label) pairs; the label names the ambiguous word and its senses.
        test_sentences = [
            ("আমি কল বন্ধ করেছি।", "কল (tap/call)"),
            ("কাল আমি বই কিনব।", "কাল (tomorrow/yesterday)"),
            ("পাতা ঝরে পড়েছে।", "পাতা (leaf/page)"),
        ]

        inference_times = []
        try:
            # Resolved through globals() so a missing Cell 8 gives a hint, not a NameError.
            if "translate_with_explanations" not in globals() or not callable(globals().get("translate_with_explanations")):
                print("translate_with_explanations not available - run Cell 8")
            else:
                for idx, (sentence, desc) in enumerate(test_sentences, 1):
                    try:
                        print(f"\n{idx}. {desc}")
                        print(f"   Input: {sentence}")
                        inf_start = time.time()
                        # Use TRG-specific uncertainty threshold; span threshold from config
                        res = globals()["translate_with_explanations"](
                            trained_model,
                            tokenizer,
                            sentence,
                            device=_DEVICE,
                            span_threshold=_SPAN_THRESHOLD,
                            uncertainty_threshold=_TRG_UNCERTAINTY_THRESHOLD,
                            track_stats=False,
                        )
                        inf_time = time.time() - inf_start
                        inference_times.append(inf_time)
                        # Only a dict result is accepted; anything else is a failure.
                        if isinstance(res, dict):
                            translation = res.get("translation", "N/A")
                            amb_count = int(res.get("ambiguous_words_detected", 0) or 0)
                            exs = res.get("explanations", []) or []
                            print(f"   Translation: {translation}")
                            print(f"   Ambiguous: {amb_count}")
                            print(f"   Time: {inf_time:.3f}s")

                            if exs and isinstance(exs, list):
                                for exp in exs:
                                    if isinstance(exp, dict):
                                        word = exp.get("ambiguous_word", exp.get("token", "N/A"))
                                        # Same marker stripping as _get_dscd_homographs,
                                        # so both sides of the membership test agree.
                                        clean = (
                                            str(word)
                                            .replace("▁", "")
                                            .replace("Ġ", "")
                                            .replace("##", "")
                                            .replace("@@", "")
                                            .replace("</w>", "")
                                            .strip()
                                            .lower()
                                        )
                                        if clean in dscd_homographs:
                                            dscd_homographs_detected.add(clean)
                                        # Metrics are optional; values that cannot be
                                        # converted to float fall back to a plain line.
                                        try:
                                            conf = float(exp.get("confidence", 0.5))
                                            span = float(exp.get("span", 0.0))
                                            u = float(exp.get("uncertainty", 0.0))
                                            print(f"   -> '{word}': conf={conf:.3f}, s={span:.3f}, u={u:.3f}")
                                        except Exception:
                                            print(f"   -> '{word}': (no metrics)")
                                inference_success += 1
                            else:
                                # A translation without explanations still counts as success.
                                print("   No explanations")
                                inference_success += 1
                        else:
                            print("   Unexpected format")
                            inference_failed += 1
                        _safe_cleanup()
                    except Exception as e:
                        print(f"   Failed: {type(e).__name__}: {str(e)[:200]}")
                        if _DEBUG_DISCOVERY:
                            try:
                                traceback.print_exc()
                            except Exception:
                                pass
                        inference_failed += 1

                print("\n" + "-" * 80)
                print(f"Results: {inference_success}/{len(test_sentences)} successful")
                if inference_times:
                    avg_time = sum(inference_times) / len(inference_times)
                    print(f"Performance: {avg_time:.3f}s avg per sentence")
                if dscd_homographs_detected:
                    print(f"DSCD homographs detected: {', '.join(sorted(dscd_homographs_detected))}")
                else:
                    print("No DSCD homographs detected in test sentences")
                    if len(dscd_homographs) == 0:
                        print("   -> DSCD has no discoveries")
                    else:
                        print(f"   -> DSCD has {len(dscd_homographs)} homographs but none in test sentences")
        except Exception:
            # Errors outside the per-sentence handler are only shown in debug mode.
            if _DEBUG_DISCOVERY:
                try:
                    traceback.print_exc()
                except Exception:
                    pass

        # SYSTEM TEST: quick component presence checks
        # Purely structural: each component is "OK" when the attribute exists and
        # exposes the expected callable. Nothing is executed.
        print("\n[SYSTEM TEST]")
        try:
            core = trained_model.module if hasattr(trained_model, "module") else trained_model
            dscd_ok = hasattr(core, "dscd") and callable(getattr(core.dscd, "forward", None))
            asbn_ok = hasattr(core, "asbn") and callable(getattr(core.asbn, "forward", None))
            trg_ok = hasattr(core, "trg_system") and callable(getattr(core.trg_system, "process_sentence_for_explanations", None))
            # The translation backbone is read from the `mbart` attribute but is
            # reported under the label "M2M100" below.
            mbart_ok = False
            if hasattr(core, "mbart"):
                mbart = core.mbart
                mbart_ok = callable(getattr(mbart, "generate", None))
            print("  Component status:")
            print(f"    - DSCD: {'OK' if dscd_ok else 'MISSING'}")
            print(f"    - ASBN: {'OK' if asbn_ok else 'MISSING'}")
            print(f"    - TRG: {'OK' if trg_ok else 'MISSING'}")
            print(f"    - M2M100: {'OK' if mbart_ok else 'MISSING'}")
            if dscd_ok and asbn_ok and trg_ok and mbart_ok:
                print("  All components operational")
            else:
                print("  Some components missing")
        except Exception:
            # Errors here are only shown in debug mode.
            if _DEBUG_DISCOVERY:
                try:
                    traceback.print_exc()
                except Exception:
                    pass

        # NEXT STEPS: copy-paste hints for using the trained model interactively.
        print("\n" + "=" * 80)
        print("NEXT STEPS")
        print("=" * 80)
        print("\n1. Single translation:")
        print("   result = translate_with_explanations(trained_model, tokenizer, 'আমি কল বন্ধ করেছি।')")
        print("\n2. Batch translation:")
        print("   for sent in sentences:")
        print("       res = translate_with_explanations(trained_model, tokenizer, sent)")
        print("\n3. Load checkpoint:")
        # Show the checkpoint path this run actually used instead of a hard-coded
        # default, so the hint stays correct when CHECKPOINT_PATH/CHECKPOINT_DIR
        # is customised. With the default path the output is unchanged.
        print(f"   ckpt = torch.load({str(_CHECKPOINT_PATH)!r})")
        print("   model.load_state_dict(ckpt['model_state_dict'])")
        print("   model.dscd.load_state_dict(ckpt['dscd_state'])")
        print("\n4. Full evaluation:")
        print("   results = comprehensive_post_training_testing(trained_model, tokenizer)")
        print("\n5. Demo:")
        print("   demonstrate_system(trained_model, tokenizer)")

        if prototype_extraction_success:
            # The prototype total is summed over every exported token, so the
            # multi-sense token count is reported alongside it rather than as
            # the source of all prototypes.
            print(
                f"\nPrototypes saved: {total_prototypes_extracted} prototypes "
                f"({multi_sense_tokens_extracted} multi-sense tokens)"
            )
        else:
            print("\nPrototype extraction or checkpoint verification incomplete")

        print("\n" + "=" * 80)

    else:
        # Pipeline failed - diagnostics
        # Reports the failure category recorded above, shows which notebook cells
        # appear to have been executed, and prints a category-specific recovery hint.
        print("\n" + "=" * 80)
        print("PIPELINE FAILED")
        print("=" * 80)
        print(f"\nCategory: {failure_category or 'UNKNOWN'}")
        if failure_details:
            print(f"Details: {failure_details[:200]}")
        print("\n[DIAGNOSTICS]")
        # One sentinel name per cell: a cell counts as executed when that name is
        # present in globals().
        components = {
            "Cell 0": "NUM_SAMPLES" in globals(),
            "Cell 1": "reconstruct_word_spans" in globals(),
            "Cell 2": "MemoryEfficientDataset" in globals(),
            "Cell 3": "MemoryEfficientDSCDOnline" in globals(),
            "Cell 4": "MemoryEfficientASBNModule" in globals(),
            "Cell 5": "CompleteTRGWithExplanations" in globals(),
            "Cell 6": "MemoryOptimizedTATNWithExplanations" in globals(),
            "Cell 7": "train_memory_efficient_tatn" in globals(),
            "Cell 8": "translate_with_explanations" in globals(),
            "Cell 9": "comprehensive_post_training_testing" in globals(),
            "Cell 10": "main_pipeline" in globals(),
        }
        for comp, present in components.items():
            status = "OK" if present else "MISSING"
            print(f"  {status} {comp}")

        # Categories without a dedicated branch (e.g. UNKNOWN_ERROR,
        # PIPELINE_NO_OUTPUT) fall through to the general advice.
        print("\n[RECOVERY]")
        if failure_category == "MISSING_DEPENDENCY":
            print("\n-> Run Cells 0-10 in sequence, then re-run Cell 11")
        elif failure_category == "TOKENIZER_ERROR":
            print("\n-> Install dependencies:")
            print("  ! pip install transformers==4.30.2 sentencepiece tokenizers")
            print("  Then RESTART kernel and re-run Cells 0-11")
        elif failure_category == "OOM_ERROR":
            print("\n-> Reduce memory in Cell 0:")
            print("  BATCH_SIZE = 2")
            print("  NUM_SAMPLES = 15000")
            print("  ACCUMULATION_STEPS = 32")
        elif failure_category == "RUNTIME_ERROR":
            print("\n-> Enable debug in Cell 0 and re-run Cell 11 for details")
        elif failure_category == "USER_INTERRUPT":
            print("\n-> Check checkpoint exists: os.path.exists('%s')" % _CHECKPOINT_PATH)
        else:
            print("\n-> General steps: enable DEBUG, re-run Cells 0-11, verify data and GPU availability")

    # EXECUTION SUMMARY: printed for both outcomes. The duration is measured from
    # `start_time`, so it covers the pipeline run and all reporting above.
    total_duration = time.time() - start_time
    end_utc = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")

    print("\n" + "=" * 80)
    print("EXECUTION SUMMARY")
    print("=" * 80)
    print(f"User: {user_login}")
    print(f"Started: {now_utc}")
    print(f"Finished: {end_utc}")
    print(f"Duration: {_format_duration(total_duration)}")

    if pipeline_success:
        print("Status: SUCCESS")
        # checkpoint_valid and prototype_extraction_success are assigned only in
        # the success branch above, hence the defensive existence checks. At
        # module level locals() is the module namespace.
        if "checkpoint_valid" in locals() and checkpoint_valid:
            print("Checkpoint: VALID")
        else:
            print("Checkpoint: CHECK NEEDED")
        if "prototype_extraction_success" in locals() and prototype_extraction_success:
            print(f"Prototypes: SAVED ({total_prototypes_extracted} total)")
        else:
            print("Prototypes: EXTRACTION FAILED")
    else:
        print(f"Status: FAILED ({failure_category or 'UNKNOWN'})")

    print("=" * 80)
    # Final release of cached GPU memory and garbage.
    _safe_cleanup()

# Cell-load banner: emitted every time this cell is executed or imported,
# whether or not the __main__ guard ran.
print("\n" + "=" * 80, "Cell 11: Execution wrapper ready - FIXED", "=" * 80, sep="\n")