In [None]:
# Remove any existing transformers / sentence-transformers installs before pinning.
!pip uninstall -y transformers sentence-transformers
# Pin transformers to 4.30.2; --no-deps skips its dependencies, which are installed explicitly below.
!pip install transformers==4.30.2 --no-deps --force-reinstall
!pip install sentencepiece tokenizers sacremoses
!pip install scipy scikit-learn
# Pin protobuf to 3.20.3.
!pip install --upgrade "protobuf==3.20.3"
# optional:
!pip install sentence-transformers==2.2.2
!pip install sacrebleu

In [None]:
import transformers
# Show the installed version so the pin from the install cell can be checked by eye.
print(transformers.__version__)


In [None]:
# ==============================================================================
# CELL 0: TATN CONFIGURATION (BENGALI → ENGLISH)
# ==============================================================================

import os
import sys
import math
import random
import re
import unicodedata
import time
import threading
from pathlib import Path
from collections import deque, defaultdict
from typing import List, Dict, Tuple, Optional, Union, Set, Any
from types import SimpleNamespace

import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import warnings
import gc

# ------------------------------------------------------------------------------
# Optional dependencies
# ------------------------------------------------------------------------------

# pandas: availability is recorded in _HAS_PANDAS and checked again in the summary below.
try:
    import pandas as pd
    _HAS_PANDAS = True
except ImportError:
    _HAS_PANDAS = False
    print("[WARN] pandas not available; CSV loading will fail")

# Tokenizer: try M2M100TokenizerFast first (bound to the name M2M100Tokenizer),
# then the regular M2M100Tokenizer; if neither imports, the name is set to None.
try:
    from transformers import M2M100TokenizerFast as M2M100Tokenizer
    _HAS_M2M_TOKENIZER = True
except Exception:
    try:
        from transformers import M2M100Tokenizer
        _HAS_M2M_TOKENIZER = True
    except Exception:
        M2M100Tokenizer = None
        _HAS_M2M_TOKENIZER = False
        print("[WARN] M2M100Tokenizer not available")

# datasets: silently optional; load_dataset is None when the package is missing.
try:
    from datasets import load_dataset
    _HAS_DATASETS = True
except Exception:
    load_dataset = None
    _HAS_DATASETS = False

# Suppress every Python warning for the rest of the session.
warnings.filterwarnings("ignore")

# Environment switches read by third-party libraries (tokenizers parallelism,
# TensorFlow C++ log level); they are set here before any model code runs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# ------------------------------------------------------------------------------
# Device / GPU configuration
# ------------------------------------------------------------------------------

# Number of CUDA devices visible to this process; more than one enables multi-GPU mode.
NUM_GPUS = torch.cuda.device_count()
USE_MULTI_GPU = NUM_GPUS > 1

if USE_MULTI_GPU:
    print(f"[Cell 0] Multi-GPU Mode: {NUM_GPUS} GPUs available")
    # In multi-GPU mode the primary device is always cuda:0.
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE: `mode` is only defined on this branch (single GPU or CPU).
    mode = "Single GPU Mode" if torch.cuda.is_available() else "CPU Mode"
    print(f"[Cell 0] {mode}")

print(f"[Cell 0] Device: {DEVICE} (visible GPUs: {NUM_GPUS})")

# ------------------------------------------------------------------------------
# Dataset path (NOTE: CSV is en→bn but model trains bn→en; swap happens in Cell 2)
# ------------------------------------------------------------------------------

# The CSV location can be overridden through the DATASET_PATH environment variable.
DATASET_CSV_PATH = os.environ.get(
    "DATASET_PATH",
    "/kaggle/input/bn-homo/bn_homograph_complete_dataset.csv"
)

# Early sanity check only: every problem is reported with a print, nothing is raised.
if not os.path.exists(DATASET_CSV_PATH):
    print(f"[WARN] Dataset CSV not found at: {DATASET_CSV_PATH}")
    print("[WARN] Training will use fallback dataset if file is not accessible")
else:
    print(f"[INFO] Dataset CSV found: {DATASET_CSV_PATH}")
    if _HAS_PANDAS:
        try:
            # Read a single row: enough to inspect the header without loading the whole file.
            _test_df = pd.read_csv(DATASET_CSV_PATH, nrows=1)
            if "src" not in _test_df.columns or "tgt" not in _test_df.columns:
                print("[ERROR] CSV missing required columns 'src' and/or 'tgt'")
                print(f"[ERROR] Found columns: {list(_test_df.columns)}")
            else:
                print(f"[INFO] CSV validation passed (columns: {list(_test_df.columns)})")
            del _test_df
        except Exception as e:
            print(f"[WARN] Could not validate CSV structure: {e}")

# ------------------------------------------------------------------------------
# Core training hyperparameters
# ------------------------------------------------------------------------------

# BATCH_SIZE * ACCUMULATION_STEPS (times NUM_GPUS in multi-GPU mode) is reported
# as the effective batch in the summary printout at the end of this cell.
BATCH_SIZE = 100
NUM_SAMPLES = 30000
# Cell 1 reuses MAX_LENGTH as its default truncation length for offset tokenization.
MAX_LENGTH = 50

# Learning rates; the components they apply to are defined in later cells.
LR_NMT = 2e-5
LR_TRG = 1e-5
LR_PHI = 1e-5

EPOCHS = 1
GRAD_CLIP_NORM = 1.0
USE_AMP = True
PRINT_INTERVAL = 300
# Seeds torch, numpy and random near the end of this cell.
SEED = 42

ACCUMULATION_STEPS = 16

# ------------------------------------------------------------------------------
# TRG / MC-dropout / data loader / memory
# ------------------------------------------------------------------------------

MC_DROPOUT_PASSES = 5
TRG_EVIDENCE_K = 3
MAX_SILVER_BUFFER = 50

NUM_WORKERS = 2
PIN_MEMORY = True
PREFETCH_FACTOR = 2
GRADIENT_CHECKPOINTING = True

# ------------------------------------------------------------------------------
# Debug flags (keep DSCD internal logs off by default to avoid spam)
# ------------------------------------------------------------------------------

# DEBUG_DISCOVERY and DEBUG_VERBOSE are mirrored into Cell 1 as
# _DEBUG_DISCOVERY / _DEBUG_VERBOSE and gate its diagnostic prints.
DEBUG_DISCOVERY = False
DEBUG_TIMING = True
DEBUG_VERBOSE = False

# ------------------------------------------------------------------------------
# DSCD configuration (more permissive for multi-sense)
# ------------------------------------------------------------------------------

DSCD_BUFFER_SIZE = 50
DSCD_MAX_PROTOS = 8
DSCD_N_MIN = 3
DSCD_DISPERSION_THRESHOLD = 0.20
DSCD_EMBED_DIM = 1024
DSCD_TEMPERATURE = 0.7
DSCD_DROPOUT = 0.1
DSCD_AUGMENT_SCALE = 0.1
DSCD_ENABLE_TRAINING_CLUSTERING = True
DSCD_WARMUP_SAMPLES = 8000

# Reported in the summary as "Every N steps" / "Max tokens per discovery".
PERIODIC_DISCOVERY_FREQUENCY = 50
_MAX_TOKENS_PER_DISCOVERY = 150

# ------------------------------------------------------------------------------
# ASBN / TRG configuration
# ------------------------------------------------------------------------------

ENABLE_ASBN_TRAINING = True
ENABLE_ASBN_INFERENCE = True

ENABLE_TRG_TRAINING = True
ENABLE_TRG_INFERENCE = True

CLUSTERING_TIMEOUT = 5
MEMORY_CLEANUP_FREQUENCY = 100
VALIDATION_CHECK_INTERVAL = 200
VERBOSE_LOGGING = False

# ------------------------------------------------------------------------------
# Checkpoint configuration
# ------------------------------------------------------------------------------

CHECKPOINT_DIR = "/kaggle/working/"
CHECKPOINT_SAVE_AFTER_TRAINING = True
CHECKPOINT_FILENAME = "tatn_final.pt"
# 99999999 is a sentinel: should_save_checkpoint() only saves periodically
# when the interval is strictly below this value.
CHECKPOINT_INTERVAL = 99999999
SAVE_REPLAY_BUFFER = False
LOAD_REPLAY_BUFFER = False
REPLAY_BUFFER_SIZE = 25000
RESUME_FROM_CHECKPOINT = False
CHECKPOINT_PATH = ""

# ------------------------------------------------------------------------------
# TRG uncertainty / span thresholds
# ------------------------------------------------------------------------------

TAU_LOW = 0.15
TAU_HIGH = 0.85
TAU_ACCEPT = 0.8

TRG_MAX_GEN_LEN = 16
TRG_GEN_EMBED = 64
TRG_GEN_HID = 64

SPAN_THRESHOLD = 0.12
UNCERTAINTY_THRESHOLD = 0.15
TRG_TEMPERATURE = 1.0

# ------------------------------------------------------------------------------
# ASBN loss weights
# ------------------------------------------------------------------------------

ASBN_HIDDEN_DIM = 64
ASBN_LAMBDA = 0.1
ASBN_DROPOUT = 0.1

LAMBDA_ASBN = 0.05
LAMBDA_DSCD = 0.15

# ------------------------------------------------------------------------------
# Domain labels / GRL schedule
# ------------------------------------------------------------------------------

TRAIN_DOMAIN = 0
TEST_DOMAIN = 1
USE_DOMAIN_LABELS = True

GRL_ALPHA_START = 0.0
GRL_ALPHA_END = 1.0
GRL_ALPHA_SCHEDULE = "linear"
# Samples divided by (batch * accumulation), floored, then multiplied by the epoch
# count; with the values above this is 30000 // 1600 * 1 = 18. Falls back to 10000
# when batch * accumulation is not positive.
GRL_ALPHA_STEPS = (
    NUM_SAMPLES // (BATCH_SIZE * ACCUMULATION_STEPS) * EPOCHS
    if BATCH_SIZE * ACCUMULATION_STEPS > 0
    else 10000
)

# ------------------------------------------------------------------------------
# Language configuration
# ------------------------------------------------------------------------------

SOURCE_LANGUAGE = "bn"
TARGET_LANGUAGE = "en"

# Vocabulary ids of the M2M100 language tokens "__bn__" and "__en__".
# M2M100 language ids start at 128004 and follow the fairseq language-code order
# (af, am, ar, ast, az, ba, be, bg, bn, ...), which puts "bn" at 128012 and "en"
# at 128022. The previous Bengali value, 128025, is the id of "__fa__" (Persian).
# Cross-check against the loaded tokenizer with tokenizer.get_lang_id("bn").
M2M100_BN_TOKEN_ID = 128012
M2M100_EN_TOKEN_ID = 128022

# ------------------------------------------------------------------------------
# Reference homograph list (evaluation only; DSCD unsupervised)
# ------------------------------------------------------------------------------

# Hand-written Bengali reference words; the summary printout reports this set's
# size and labels it "evaluation only".
HOMOGRAPH_REFERENCE_LIST_BN: Set[str] = {
    "কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা",
    "পানি", "দল", "বাজার", "নাম", "কথা", "বই", "ঘর", "মন", "হাত",
    "দিন", "রাত", "জল", "বাড়ি", "পার্ক", "নদী", "বন", "ফুল", "গাছ",
    "চোখ", "মুখ", "পা", "কান", "গলা", "নাক", "দাঁত", "কোমর",
    "পড়া", "দেখা", "যাওয়া", "আসা", "খেলা", "লেখা", "বলা", "শোনা",
    "চলা", "ধরা", "দেওয়া", "নেওয়া",
    "সময়", "বছর", "মাস", "সাল", "ঘন্টা", "মুহূর্ত",
    "গরম", "শীত", "বাতাস", "আগুন", "পাথর", "মাটি",
    "ভাব", "রং", "আলো", "ছায়া", "শব্দ", "অর্থ",
}

# Watchlists start empty and both prioritization switches are off.
HOMOGRAPH_WATCHLIST_BN: Set[str] = set()
HOMOGRAPH_WATCHLIST: Set[str] = set()
USE_WATCHLIST_PRIORITIZATION = False
WATCHLIST_ONLY_FOR_TRG = False

# ------------------------------------------------------------------------------
# Normalization utilities
# ------------------------------------------------------------------------------

def normalize_bengali(t: str) -> str:
    """Return *t* NFKC-normalized with subword markers removed.

    The marker strings "▁" and "##" are deleted wherever they occur and the
    result is stripped of surrounding whitespace. Any falsy input (``None``,
    ``""``) yields the empty string.
    """
    if t:
        cleaned = unicodedata.normalize("NFKC", t)
        for marker in ("▁", "##"):
            cleaned = cleaned.replace(marker, "")
        return cleaned.strip()
    return ""

def normalize_english(t: str) -> str:
    """Return *t* NFKC-normalized, lower-cased and stripped.

    Any falsy input (``None``, ``""``) yields the empty string.
    """
    return unicodedata.normalize("NFKC", t).lower().strip() if t else ""

# ------------------------------------------------------------------------------
# CUDA helpers
# ------------------------------------------------------------------------------

def empty_cuda_cache() -> None:
    """Run the garbage collector, then release cached CUDA memory if CUDA exists.

    Errors from ``torch.cuda.empty_cache`` are deliberately ignored.
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.empty_cache()
    except Exception:
        pass

def safe_cuda_synchronize() -> None:
    """Call ``torch.cuda.synchronize`` when CUDA is available, ignoring errors."""
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.synchronize()
    except Exception:
        pass

def monitor_gpu_usage() -> None:
    """Print allocated and reserved memory, in GiB, for every visible CUDA device.

    Prints a single notice instead when CUDA is unavailable. A device whose
    statistics cannot be read is reported as unavailable rather than raising.
    """
    if not torch.cuda.is_available():
        print("[GPU MONITOR] No CUDA devices available")
        return

    gpu_count = torch.cuda.device_count()
    print(f"\n[GPU MONITOR] Checking {gpu_count} GPU(s):")
    bytes_per_gib = 1024 ** 3
    for idx in range(gpu_count):
        try:
            allocated = torch.cuda.memory_allocated(idx) / bytes_per_gib
            reserved = torch.cuda.memory_reserved(idx) / bytes_per_gib
            print(f"  GPU {idx}: {allocated:.2f}GB allocated / {reserved:.2f}GB reserved")
        except Exception:
            print(f"  GPU {idx}: memory stats unavailable")

# ------------------------------------------------------------------------------
# Checkpoint helpers
# ------------------------------------------------------------------------------

def get_checkpoint_path() -> str:
    """Return CHECKPOINT_FILENAME joined onto CHECKPOINT_DIR."""
    directory, filename = CHECKPOINT_DIR, CHECKPOINT_FILENAME
    return os.path.join(directory, filename)

def should_save_checkpoint(global_step: int, epoch: int, is_final: bool = False) -> bool:
    """Decide whether a checkpoint should be written at this point of training.

    Args:
        global_step: Current global step counter.
        epoch: Current epoch; accepted for interface compatibility but not
            consulted by the decision.
        is_final: True when called after training has finished.

    Returns:
        True when this is the final call and CHECKPOINT_SAVE_AFTER_TRAINING is
        set, or when periodic saving is enabled and ``global_step`` is a
        positive multiple of CHECKPOINT_INTERVAL; False otherwise.
    """
    if is_final and CHECKPOINT_SAVE_AFTER_TRAINING:
        return True

    interval = CHECKPOINT_INTERVAL
    # 99999999 is the "periodic saving disabled" sentinel. Non-positive intervals
    # are treated as disabled too: 0 would otherwise divide by zero in the modulo.
    if not 0 < interval < 99999999:
        return False
    return global_step >= interval and global_step % interval == 0

# ------------------------------------------------------------------------------
# Function timeout utility
# ------------------------------------------------------------------------------

class FunctionTimeoutError(Exception):
    """Timeout marker; the wrapper built by with_timeout turns it into ``None``."""


def with_timeout(seconds: int):
    """Build a decorator that gives up waiting for a call after *seconds*.

    The wrapped function runs in a daemon thread. The wrapper returns the
    function's return value, re-raises any ``Exception`` it raised, and returns
    ``None`` when the thread is still running after *seconds* or when the
    function raised ``FunctionTimeoutError``.

    The worker thread is not stopped on timeout; it keeps running in the
    background and its eventual result is discarded.

    Args:
        seconds: Maximum time to wait, passed to ``Thread.join(timeout=...)``.
    """
    import functools

    def decorator(func):
        @functools.wraps(func)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            # The return value and a raised exception live under separate keys,
            # so a function that legitimately *returns* an exception object is
            # not mistaken for one that raised it.
            outcome = {}

            def target():
                try:
                    outcome["value"] = func(*args, **kwargs)
                except Exception as exc:
                    outcome["error"] = exc

            thread = threading.Thread(target=target, daemon=True)
            thread.start()
            thread.join(timeout=seconds)
            if thread.is_alive():
                return None

            error = outcome.get("error")
            if error is not None:
                if isinstance(error, FunctionTimeoutError):
                    return None
                raise error
            # No "value" key means the worker ended without producing a result
            # (for example on a non-Exception BaseException); report None.
            return outcome.get("value")

        return wrapper

    return decorator

# ------------------------------------------------------------------------------
# Token utilities for DSCD / tokenizer
# ------------------------------------------------------------------------------

def get_special_tokens(tokenizer) -> Set[str]:
    """Collect the tokenizer's special tokens plus the two language codes.

    Reads ``tokenizer.all_special_tokens`` (an empty list when the attribute is
    missing). If building a set from it fails, a default set of four common
    special tokens is used instead. SOURCE_LANGUAGE and TARGET_LANGUAGE are
    always added.
    """
    try:
        tokens = set(getattr(tokenizer, "all_special_tokens", []))
    except Exception:
        tokens = {"<pad>", "</s>", "<s>", "<unk>"}
    tokens |= {SOURCE_LANGUAGE, TARGET_LANGUAGE}
    return tokens

# Memo of verdicts keyed by (token, language). Only the part of the verdict that
# depends on the token text alone is stored here.
_token_validation_cache: Dict[Tuple[str, str], bool] = {}
_cache_lock = threading.Lock()
_cache_max_size = 10000

def is_valid_token(
    token,
    special_tokens: Optional[Set[str]] = None,
    tokenizer=None,
    language: str = "bn",
) -> bool:
    """Return True when *token* looks like a Bengali word piece.

    A token is accepted when it is not in *special_tokens*, has at least two
    characters once the "▁" / "##" markers and surrounding whitespace are
    removed, contains at least one character in U+0980..U+09FF, and its count
    of such characters is at least half its count of alphanumeric characters.

    Args:
        token: Token to check; ``None`` is treated as the empty string and any
            other value is converted with ``str``.
        special_tokens: Optional set of tokens that are always rejected.
        tokenizer: Unused; kept for interface compatibility.
        language: Only used as part of the cache key.

    Returns:
        The verdict as a bool.
    """
    token = "" if token is None else str(token)

    # Special-token membership depends on the caller's set, so it is checked
    # before the cache and never stored in it. Caching it under (token, language)
    # would leak one caller's verdict to callers passing a different set.
    if special_tokens and token in special_tokens:
        return False

    cache_key = (token, language)
    with _cache_lock:
        cached = _token_validation_cache.get(cache_key)
    if cached is not None:
        return cached

    clean = token.replace("▁", "").replace("##", "").strip()
    result = False
    if len(clean) >= 2:
        bengali_count = sum(1 for c in clean if "\u0980" <= c <= "\u09FF")
        alphanum_count = sum(1 for c in clean if c.isalnum())
        # Bengali combining signs lie in the block but are not alphanumeric,
        # so this ratio can exceed 1 for ordinary Bengali words.
        if bengali_count > 0 and alphanum_count > 0:
            result = bengali_count / alphanum_count >= 0.5

    with _cache_lock:
        # The cache stops growing once it holds _cache_max_size entries.
        if len(_token_validation_cache) < _cache_max_size:
            _token_validation_cache[cache_key] = result

    return result

def safe_tokenize_with_offsets(tokenizer, text: str, max_length: int = 512):
    """Tokenize *text* without special tokens and return ``(tokens, offsets)``.

    The tokenizer is called with offset mapping requested and truncation at
    *max_length*. When the encoding has no "offset_mapping" entry, one
    ``(0, 0)`` placeholder per token is returned instead. Any exception results
    in ``(None, None)``.
    """
    try:
        encoding = tokenizer(
            text,
            return_offsets_mapping=True,
            max_length=max_length,
            truncation=True,
            add_special_tokens=False,
        )
        token_ids = encoding.get("input_ids", [])
        tokens = tokenizer.convert_ids_to_tokens(token_ids)
        placeholder = [(0, 0)] * len(tokens)
        return tokens, encoding.get("offset_mapping", placeholder)
    except Exception:
        return None, None

# ------------------------------------------------------------------------------
# Discovery timing helper used by DSCD
# ------------------------------------------------------------------------------

class DiscoveryTimer:
    """Accumulates the duration and step number of each recorded discovery run."""

    def __init__(self):
        # Parallel lists: entry i of each describes the i-th recorded run.
        self.discovery_times: List[float] = []
        self.discovery_steps: List[int] = []

    def record(self, step: int, duration: float) -> None:
        """Store one run's *duration* together with the *step* it belongs to."""
        self.discovery_times.append(duration)
        self.discovery_steps.append(step)

    def get_stats(self) -> Dict[str, float]:
        """Return count, total, average and maximum of the recorded durations.

        All four values are zero when nothing has been recorded.
        """
        durations = self.discovery_times
        if not durations:
            return {"count": 0, "total": 0.0, "avg": 0.0, "max": 0.0}
        run_count = len(durations)
        elapsed = sum(durations)
        stats = {"count": run_count, "total": elapsed}
        stats["avg"] = elapsed / run_count
        stats["max"] = max(durations)
        return stats

# One shared timer instance, exposed under two names.
_discovery_timer = DiscoveryTimer()
discoverytimer = _discovery_timer

# ------------------------------------------------------------------------------
# Seeding and CuDNN behaviour
# ------------------------------------------------------------------------------

# Seed every random number generator used in this notebook from the single SEED value.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Only available on newer PyTorch builds, hence the hasattr guard; failures are ignored.
if hasattr(torch, "set_float32_matmul_precision"):
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

# NOTE(review): benchmark=True with deterministic=False favours speed; per the
# PyTorch reproducibility notes, runs may then differ despite the fixed seed.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# ------------------------------------------------------------------------------
# Summary printout
# ------------------------------------------------------------------------------

# Samples contributing to one reported "effective batch": per-step batch times
# accumulation steps, scaled by the GPU count in multi-GPU mode.
effective_batch = BATCH_SIZE * ACCUMULATION_STEPS
if USE_MULTI_GPU:
    effective_batch *= NUM_GPUS

print("\n" + "=" * 80)
print("TATN CONFIGURATION (Bengali to English)")
print("=" * 80)
# User name comes from KAGGLE_USERNAME, then USER, then a hard-coded default.
print(f"User: {os.getenv('KAGGLE_USERNAME', os.getenv('USER', 'manas0003'))}")
print(f"Date: {time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime())} UTC")
print(f"Multi-GPU: {'ENABLED' if USE_MULTI_GPU else 'DISABLED'} ({NUM_GPUS} GPUs)")
print(f"Dataset: {DATASET_CSV_PATH}")
print(f"Samples: {NUM_SAMPLES:,} | Batch: {BATCH_SIZE} | Accum: {ACCUMULATION_STEPS}")
print(f"Effective batch: {effective_batch}")
print(f"Max length: {MAX_LENGTH} | Epochs: {EPOCHS} | AMP: {USE_AMP}")
print()
print("DSCD Config:")
print(f"  Buffer: {DSCD_BUFFER_SIZE} | n_min: {DSCD_N_MIN} | Max protos: {DSCD_MAX_PROTOS}")
print(f"  Dispersion threshold: {DSCD_DISPERSION_THRESHOLD}")
print(f"  Periodic discovery: Every {PERIODIC_DISCOVERY_FREQUENCY} steps")
print(f"  Max tokens per discovery: {_MAX_TOKENS_PER_DISCOVERY}")
print()
print("TRG & Uncertainty:")
print(f"  MC Dropout passes: {MC_DROPOUT_PASSES} | TAU_LOW: {TAU_LOW}")
print(f"  SPAN_THRESHOLD: {SPAN_THRESHOLD} | UNCERTAINTY_THRESHOLD: {UNCERTAINTY_THRESHOLD}")
print(f"  TAU_HIGH: {TAU_HIGH} | Temperature: {TRG_TEMPERATURE}")
print()
print("ASBN / Loss:")
print(f"  LAMBDA_ASBN: {LAMBDA_ASBN} | LAMBDA_DSCD: {LAMBDA_DSCD}")
print(f"  Domain labels: {USE_DOMAIN_LABELS} | GRL: {GRL_ALPHA_SCHEDULE}")
print(f"  GRL steps: {GRL_ALPHA_STEPS}")
print()
print("Debug Flags:")
print(f"  Discovery logging: {DEBUG_DISCOVERY}")
print(f"  Timing monitoring: {DEBUG_TIMING}")
print(f"  Verbose mode: {DEBUG_VERBOSE}")
print()
print("Validation:")
print(f"  Check interval: {VALIDATION_CHECK_INTERVAL} steps")
print()
print("Language Tokens:")
print(f"  Bengali (bn): {M2M100_BN_TOKEN_ID}")
print(f"  English (en): {M2M100_EN_TOKEN_ID}")
print()
print("Checkpoint:")
print(f"  Path: {get_checkpoint_path()}")
# NOTE(review): hard-coded text; it does not reflect CHECKPOINT_INTERVAL.
print(f"  Save strategy: Final only")
print()
print("Discovery Mode:")
# NOTE(review): the two status lines below are hard-coded and do not read
# USE_WATCHLIST_PRIORITIZATION / WATCHLIST_ONLY_FOR_TRG.
print("  PURE UNSUPERVISED (no watchlist bias)")
print(f"  Reference list: {len(HOMOGRAPH_REFERENCE_LIST_BN)} words (evaluation only)")
print("  Watchlist prioritization: DISABLED")
print("=" * 80)

# Repeat the missing-dependency warnings after the summary so they are easy to spot.
if not _HAS_PANDAS:
    print("[ERROR] pandas not available - CSV loading will fail!")
if not _HAS_M2M_TOKENIZER:
    print("[ERROR] M2M100Tokenizer not available - tokenization will fail!")

# Probe the checkpoint directory by writing and deleting a small marker file;
# the outcome is only printed, never raised.
try:
    test_file = os.path.join(CHECKPOINT_DIR, ".test_write")
    with open(test_file, "w") as f:
        f.write("test")
    os.remove(test_file)
    print(f"Checkpoint directory writable: {CHECKPOINT_DIR}")
except Exception as e:
    print(f"Checkpoint directory not writable: {e}")

monitor_gpu_usage()

print("\n" + "=" * 80)
print("Cell 0: Configuration loaded successfully")
print("=" * 80)


In [None]:
# ===========================================================================================
# CELL 1: TOKENIZER UTILITIES + PROTOTYPE DATABASE (BENGALI-FOCUSED)
# ===========================================================================================

import threading
import json
import pickle
from typing import Tuple, List, Dict, Optional, Set, Any
from pathlib import Path
from collections import defaultdict
import numpy as np
import torch
import datetime

# Default truncation length for offset tokenization: MAX_LENGTH from Cell 0 when it
# is a positive number, otherwise 48 (also used when Cell 0 has not been run).
try:
    if isinstance(MAX_LENGTH, (int, float)) and MAX_LENGTH > 0:
        SAFE_OFFSET_MAX_LEN = int(MAX_LENGTH)
    else:
        SAFE_OFFSET_MAX_LEN = 48
except (NameError, ValueError, TypeError):
    SAFE_OFFSET_MAX_LEN = 48

# Local copies of Cell 0 settings, each with a default so this cell can run on its own.
try:
    _SOURCE_LANG = SOURCE_LANGUAGE
except NameError:
    _SOURCE_LANG = "bn"

try:
    _DEBUG_VERBOSE = DEBUG_VERBOSE
except NameError:
    _DEBUG_VERBOSE = False

try:
    _DEBUG_DISCOVERY = DEBUG_DISCOVERY
except NameError:
    _DEBUG_DISCOVERY = False

# Cache of special-token sets, keyed by the string from _special_token_cache_key
# and guarded by _SPECIAL_TOKENS_LOCK.
_SPECIAL_TOKENS_CACHE: Dict[str, Set[str]] = {}
_SPECIAL_TOKENS_LOCK = threading.Lock()
# Counter and cap for the "text appears to be English" warning in reconstruct_word_spans.
_LANGUAGE_WARNING_COUNT = 0
_MAX_LANGUAGE_WARNINGS = 3

def _special_token_cache_key(tokenizer) -> str:
    """Build a cache key string from the tokenizer's name and vocabulary size.

    The name is ``name_or_path``, else ``name``, else "unknown_tokenizer". The
    size is ``int(vocab_size)`` when that attribute exists, otherwise
    ``len(get_vocab())`` when that method exists; it is ``None`` when neither is
    present or the chosen lookup fails.
    """
    label = (
        getattr(tokenizer, "name_or_path", None)
        or getattr(tokenizer, "name", None)
        or "unknown_tokenizer"
    )

    vocab_count = None
    if hasattr(tokenizer, "vocab_size"):
        try:
            vocab_count = int(tokenizer.vocab_size)
        except Exception:
            vocab_count = None
    elif callable(getattr(tokenizer, "get_vocab", None)):
        try:
            vocab_count = len(tokenizer.get_vocab())
        except Exception:
            vocab_count = None

    return f"{label}__vocab={vocab_count}"

def get_tokenizer_special_tokens(tokenizer) -> Set[str]:
    """Return the set of special tokens for *tokenizer*, cached per tokenizer.

    Tokens are gathered from ``all_special_tokens``,
    ``additional_special_tokens``, the individual ``*_token`` attributes and the
    special-tokens map, then extended with a fixed list of common markers
    ("__bn__", "__en__", "</s>", "[PAD]", ...). When the tokenizer exposes a
    non-empty vocabulary, tokens missing from it are dropped, except
    "</s>", "<pad>", "<s>" and "<unk>", which are always kept.

    The result is stored in _SPECIAL_TOKENS_CACHE under the key from
    _special_token_cache_key. The cached set object itself is returned, so
    callers should treat it as read-only.
    """
    cache_key = _special_token_cache_key(tokenizer)
    with _SPECIAL_TOKENS_LOCK:
        if cache_key in _SPECIAL_TOKENS_CACHE:
            return _SPECIAL_TOKENS_CACHE[cache_key]

        special_tokens: Set[str] = set()
        try:
            # Token collections exposed directly by the tokenizer.
            for attr in ("all_special_tokens", "additional_special_tokens"):
                try:
                    listed = getattr(tokenizer, attr, None)
                    if isinstance(listed, (list, tuple, set)):
                        special_tokens.update(x for x in listed if x)
                except Exception:
                    pass

            # Individual named special tokens.
            for attr in ("pad_token", "unk_token", "bos_token", "eos_token",
                         "cls_token", "sep_token", "mask_token"):
                try:
                    tok = getattr(tokenizer, attr, None)
                    if tok:
                        special_tokens.add(tok)
                except Exception:
                    pass

            # String values of the special-tokens map (or its extended variant).
            try:
                stm = (
                    getattr(tokenizer, "special_tokens_map", None)
                    or getattr(tokenizer, "special_tokens_map_extended", None)
                )
                if isinstance(stm, dict):
                    for v in stm.values():
                        if isinstance(v, str) and v:
                            special_tokens.add(v)
            except Exception:
                pass
        except Exception:
            special_tokens = set()

        special_tokens.update({
            "__bn__", "__en__",
            "</s>", "<pad>", "<s>", "<unk>",
            "[PAD]", "[EOS]", "[UNK]", "[CLS]", "[SEP]", "[MASK]",
        })

        try:
            vocab = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else None
            # Filter only against a real vocabulary. With no vocabulary there is
            # nothing to check against, and filtering on an empty one would throw
            # away the tokenizer's own special tokens.
            if vocab:
                always_keep = {"</s>", "<pad>", "<s>", "<unk>"}
                special_tokens = {
                    tok
                    for tok in special_tokens
                    if tok in vocab or tok in always_keep
                }
        except Exception:
            pass

        _SPECIAL_TOKENS_CACHE[cache_key] = special_tokens
        return special_tokens

def _normalize_offset_mapping_for_batchencoding(enc):
    """Rewrite ``enc["offset_mapping"]`` as a flat list of ``(start, end)`` tuples.

    The offsets are looked up in ``enc`` itself and then in ``enc.data``. Both
    layouts are accepted: a batched structure, of which only the first example
    is kept, and an already-flat sequence of pairs. Anything exposing
    ``tolist()`` is converted first. Entries that are not sequences of at least
    two items become ``(None, None)``.

    When no usable offsets are found, the mapping is filled with one
    ``(None, None)`` per position of ``enc["input_ids"]`` (last dimension of a
    shaped object, or length of the first row of a nested list). If even that
    fails, it is set to an empty list.

    Args:
        enc: Mapping-like encoding; it is modified in place.

    Returns:
        The same ``enc`` object.
    """

    def _to_pairs(entries):
        # Coerce each entry to a 2-tuple; malformed entries become (None, None).
        return [
            (item[0], item[1])
            if isinstance(item, (list, tuple)) and len(item) >= 2
            else (None, None)
            for item in entries
        ]

    def _first_example(raw):
        # Return the per-token sequence for the first example, or None if unusable.
        if hasattr(raw, "tolist"):
            raw = raw.tolist()
        if not isinstance(raw, (list, tuple)) or len(raw) == 0:
            return None
        head = raw[0]
        if not isinstance(head, (list, tuple)):
            return None
        # A first element made of scalars/None is itself one (start, end) pair, so
        # `raw` is already flat. Treating it as a batch would iterate the numbers
        # of that single pair and discard every real offset.
        if len(head) > 0 and not isinstance(head[0], (list, tuple)):
            return raw
        return head

    candidates = []
    try:
        if "offset_mapping" in enc and enc["offset_mapping"] is not None:
            candidates.append(enc["offset_mapping"])
    except Exception:
        pass
    try:
        data = getattr(enc, "data", None)
        if isinstance(data, dict) and data.get("offset_mapping") is not None:
            candidates.append(data["offset_mapping"])
    except Exception:
        pass

    for raw in candidates:
        try:
            entries = _first_example(raw)
            if entries is not None:
                enc["offset_mapping"] = _to_pairs(entries)
                return enc
        except Exception:
            continue

    # No usable offsets: emit placeholders sized to the first input sequence.
    try:
        seq_len = 0
        if "input_ids" in enc:
            input_ids = enc["input_ids"]
            if hasattr(input_ids, "shape") and len(input_ids.shape) > 0:
                seq_len = int(input_ids.shape[-1])
            elif (
                isinstance(input_ids, (list, tuple))
                and len(input_ids) > 0
                and isinstance(input_ids[0], (list, tuple))
            ):
                seq_len = len(input_ids[0])
        enc["offset_mapping"] = [(None, None)] * seq_len
    except Exception:
        enc["offset_mapping"] = []

    return enc

def _approximate_token_offsets(tokens, text):
    """Locate each token in *text* left to right; ``(None, None)`` when not found."""
    lowered = text.lower()
    # Lower-casing can change a string's length for a few code points. Indices
    # found in the lowered text are only valid while the lengths still agree.
    can_fold = len(lowered) == len(text)

    offsets: List[Tuple[Optional[int], Optional[int]]] = []
    cursor = 0
    for tok in tokens:
        piece = (tok or "").replace("▁", "").replace("##", "").replace("Ġ", "").strip()
        if not piece:
            offsets.append((None, None))
            continue
        idx = text.find(piece, cursor)
        if idx == -1 and can_fold:
            idx = lowered.find(piece.lower(), cursor)
        if idx == -1:
            offsets.append((None, None))
            continue
        end = idx + len(piece)
        offsets.append((idx, end))
        cursor = end  # later tokens are searched after this match
    return offsets

def safe_offsets_tokenize(
    tokenizer,
    text: str,
    max_length: Optional[int] = None,
    include_special_tokens: bool = False,
) -> dict:
    """Tokenize *text* and attach character offsets, without ever raising for bad input.

    Fast tokenizers (``is_fast`` true) are asked for offsets directly. For other
    tokenizers, or when the fast call fails, offsets are approximated by
    searching each token's text in the input from left to right. If
    tokenization itself fails, a one-token placeholder encoding is returned.

    Args:
        tokenizer: Callable tokenizer that also provides ``convert_ids_to_tokens``.
        text: Input text; non-strings are converted with ``str`` and ``None``
            becomes "".
        max_length: Truncation length; defaults to SAFE_OFFSET_MAX_LEN.
        include_special_tokens: Forwarded as ``add_special_tokens``.

    Returns:
        The tokenizer's encoding (or a plain dict on failure) whose
        "offset_mapping" entry is a list of ``(start, end)`` tuples, with
        ``(None, None)`` for positions that could not be located.
    """
    if max_length is None:
        max_length = SAFE_OFFSET_MAX_LEN
    eff_max = int(max_length)

    try:
        if not isinstance(text, str):
            text = "" if text is None else str(text)
    except Exception:
        if _DEBUG_VERBOSE:
            print("[WARN] Failed to convert input to string, using empty string")
        text = ""

    # Cap the characters handed to the tokenizer; truncation handles the rest.
    char_limit = min(eff_max * 30, 8000)
    sample_text = text[:char_limit]

    if getattr(tokenizer, "is_fast", False):
        try:
            enc = tokenizer(
                sample_text,
                return_offsets_mapping=True,
                return_tensors="pt",
                truncation=True,
                padding=False,
                max_length=eff_max,
                add_special_tokens=include_special_tokens,
            )
            return _normalize_offset_mapping_for_batchencoding(enc)
        except Exception:
            pass  # fall through to the offset-free path below

    try:
        enc = tokenizer(
            sample_text,
            return_tensors="pt",
            truncation=True,
            padding=False,
            max_length=eff_max,
            add_special_tokens=include_special_tokens,
        )
    except Exception as e:
        if _DEBUG_VERBOSE:
            print(f"[WARN] Tokenization failed: {e}, returning empty encoding")
        # pad_token_id may be missing or None; torch.tensor cannot hold None.
        try:
            pad_id = int(getattr(tokenizer, "pad_token_id", 0))
        except Exception:
            pad_id = 0
        enc = {
            "input_ids": torch.tensor([[pad_id]], dtype=torch.long),
            "attention_mask": torch.tensor([[1]], dtype=torch.long),
        }
        return _normalize_offset_mapping_for_batchencoding(enc)

    try:
        input_ids = None
        try:
            input_ids = enc["input_ids"][0].tolist()
        except Exception:
            if hasattr(enc, "data") and "input_ids" in enc.data:
                input_ids = enc.data["input_ids"][0]

        tokens: List[str] = []
        if input_ids is not None:
            try:
                tokens = tokenizer.convert_ids_to_tokens(input_ids)
            except Exception:
                tokens = []

        offsets_list = _approximate_token_offsets(tokens, sample_text)
        if offsets_list:
            # Already a flat per-token list: store it as is. Passing it through
            # the batch normaliser is unnecessary and must not reinterpret it
            # as a batched structure.
            enc["offset_mapping"] = offsets_list
            return enc
    except Exception:
        pass

    # No tokens recovered or the search failed: let the normaliser fill placeholders.
    return _normalize_offset_mapping_for_batchencoding(enc)

def reconstruct_word_spans(
    tokenizer,
    text: str,
    max_length: Optional[int] = None,
) -> Tuple[Dict[int, Optional[str]], List[str]]:
    """Map each subword token of *text* back to the surface word it belongs to.

    The text is tokenized (without special tokens) through
    ``safe_offsets_tokenize`` and words are rebuilt using the first strategy
    that yields a non-empty mapping:

    1. character offsets from the encoding's ``offset_mapping``;
    2. subword boundary markers (a leading "▁" or "Ġ" starts a new word);
    3. whitespace splitting of the raw text.

    Args:
        tokenizer: Tokenizer object handed to ``safe_offsets_tokenize``,
            ``get_tokenizer_special_tokens`` and used for
            ``convert_ids_to_tokens``.
        text: Input sentence. Non-string or blank input yields ``({}, [])``.
        max_length: Token limit passed to the tokenizer; defaults to
            ``SAFE_OFFSET_MAX_LEN``. The text itself is cut to
            ``min(max_length * 30, 8000)`` characters before tokenizing.

    Returns:
        ``(token_word_map, words)`` where ``token_word_map`` maps a token
        index to the word text it is part of (``None`` for special, empty or
        unmappable tokens) and ``words`` is the list of reconstructed words.

    Side effects:
        Increments the module-level ``_LANGUAGE_WARNING_COUNT`` and may print
        diagnostics, depending on the ``_DEBUG_*`` flags.
    """
    global _LANGUAGE_WARNING_COUNT

    if max_length is None:
        max_length = SAFE_OFFSET_MAX_LEN
    eff_max = int(max_length)

    if not isinstance(text, str) or len(text.strip()) == 0:
        return {}, []

    # Cheap script detection: Bengali Unicode block vs. ASCII letters.
    has_bengali = any("\u0980" <= c <= "\u09FF" for c in text)
    has_english = any("a" <= c.lower() <= "z" for c in text)

    if _DEBUG_VERBOSE and _DEBUG_DISCOVERY:
        bengali_pct = (
            sum(1 for c in text if "\u0980" <= c <= "\u09FF")
            / max(1, len(text))
            * 100.0
        )
        print(f"[TOKENIZER] Text sample: {text[:50]}")
        print(
            f"[TOKENIZER] Bengali: {has_bengali} ({bengali_pct:.1f}%), "
            f"English: {has_english}"
        )

    # Rate-limited warning for input that contains Latin letters but no
    # Bengali characters. The counter advances even when _DEBUG_DISCOVERY is
    # off, so the "suppressing" notice is printed once the cap is reached.
    if not has_bengali and has_english and _LANGUAGE_WARNING_COUNT < _MAX_LANGUAGE_WARNINGS:
        if _DEBUG_DISCOVERY:
            print("[TOKENIZER WARNING] Text appears to be ENGLISH, not BENGALI")
            print(f"  Sample: {text[:80]}")
        _LANGUAGE_WARNING_COUNT += 1
        if _LANGUAGE_WARNING_COUNT == _MAX_LANGUAGE_WARNINGS:
            print("[TOKENIZER] Suppressing further language warnings")

    # Bound the amount of text handed to the tokenizer; all offsets below are
    # interpreted relative to this truncated string.
    char_limit = min(eff_max * 30, 8000)
    text = text[:char_limit]
    text_len = len(text)

    special_tokens = get_tokenizer_special_tokens(tokenizer)

    # NOTE(review): current_lang is assigned here but never read afterwards
    # in this function.
    try:
        current_lang = SOURCE_LANGUAGE
    except NameError:
        current_lang = _SOURCE_LANG

    try:
        encoded = safe_offsets_tokenize(
            tokenizer, text, max_length=eff_max, include_special_tokens=False
        )
    except Exception:
        return {}, []

    offsets = encoded.get("offset_mapping", [])
    try:
        input_ids = encoded["input_ids"][0].tolist()
    except Exception:
        input_ids = []
    try:
        tokens = tokenizer.convert_ids_to_tokens(input_ids) if input_ids else []
    except Exception:
        tokens = []

    # offset_mapping is accepted in two layouts: a flat list of (start, end)
    # tuples, or a nested list whose first element holds the per-token pairs.
    # Anything else is treated as "no offsets available".
    if isinstance(offsets, list) and len(offsets) > 0 and all(
        isinstance(x, tuple) for x in offsets
    ):
        offsets_list = offsets
    elif isinstance(offsets, list) and len(offsets) > 0 and isinstance(
        offsets[0], (list, tuple)
    ):
        offsets_list = [
            (x[0], x[1])
            if (isinstance(x, (list, tuple)) and len(x) >= 2)
            else (None, None)
            for x in offsets[0]
        ]
    else:
        offsets_list = [(None, None)] * len(tokens)

    token_word_map: Dict[int, Optional[str]] = {}
    words: List[str] = []

    # ------------------------------------------------------------------
    # Strategy 1: merge tokens into words using character offsets.
    # ------------------------------------------------------------------
    used_any_offset = any(
        isinstance(o, tuple) and o[0] is not None and o[1] is not None
        for o in offsets_list
    )
    if used_any_offset:
        # [word_start, word_end) is the character span of the word currently
        # being accumulated; None means no word is open.
        word_start: Optional[int] = None
        word_end: Optional[int] = None

        for idx, (off, tok) in enumerate(zip(offsets_list, tokens)):
            try:
                off_start = int(off[0]) if off[0] is not None else None
                off_end = int(off[1]) if off[1] is not None else None
            except Exception:
                off_start, off_end = None, None

            if off_start is not None and off_end is not None:
                if off_start < 0 or off_end < 0:
                    if _DEBUG_VERBOSE:
                        print(
                            f"[WARN] Negative offset detected: "
                            f"({off_start}, {off_end}), skipping"
                        )
                    off_start, off_end = None, None
                else:
                    # Clamp into the truncated text and keep end >= start.
                    off_start = max(0, min(off_start, text_len))
                    off_end = max(off_start, min(off_end, text_len))

            # A token without a usable offset closes the open word (if any)
            # and is itself left unmapped.
            if off_start is None or off_end is None:
                if word_start is not None and word_end is not None:
                    try:
                        wtext = text[word_start:word_end].strip()
                        if wtext:
                            words.append(wtext)
                    except Exception:
                        pass
                word_start = None
                word_end = None
                token_word_map[idx] = None
                continue

            # Special tokens are unmapped but do not close the open word.
            if tok in special_tokens:
                token_word_map[idx] = None
                continue

            if word_start is None:
                word_start = off_start
                word_end = off_end
            else:
                # A gap between the open word's end and this token's start
                # marks a word boundary; otherwise the token extends the word.
                if off_start > word_end:
                    try:
                        wtext = text[word_start:word_end].strip()
                        if wtext:
                            words.append(wtext)
                    except Exception:
                        pass
                    word_start = off_start
                    word_end = off_end
                else:
                    word_end = max(word_end, off_end)

            # The token is mapped to the word text accumulated so far, so
            # earlier pieces of a multi-token word may hold a shorter prefix.
            try:
                current_word = text[word_start:word_end].strip()
                token_word_map[idx] = current_word if current_word else None
            except Exception:
                token_word_map[idx] = None

        # Flush the word still open at the end of the sequence.
        if word_start is not None and word_end is not None:
            try:
                wtext = text[word_start:word_end].strip()
                if wtext:
                    words.append(wtext)
            except Exception:
                pass

        if token_word_map:
            words = [w for w in words if isinstance(w, str) and w.strip()]
            return token_word_map, words

    # ------------------------------------------------------------------
    # Strategy 2: rebuild words from subword boundary markers.
    # ------------------------------------------------------------------
    token_word_map = {}
    assembled: List[str] = []
    current_parts: List[str] = []
    running_word = ""
    max_word_len = 100  # words longer than this are split or dropped

    for i, tok in enumerate(tokens):
        if tok in special_tokens:
            token_word_map[i] = None
            continue
        clean = (tok or "").replace("▁", "").replace("Ġ", "").replace("##", "").strip()
        if not clean:
            token_word_map[i] = None
            continue

        if tok.startswith("▁") or tok.startswith("Ġ"):
            # Word-initial piece: emit the previous word, start a new one.
            if current_parts:
                word = "".join(current_parts)
                if len(word) <= max_word_len:
                    assembled.append(word)
            current_parts = [clean]
            running_word = clean
        else:
            # Continuation piece: extend the word, splitting when it would
            # exceed max_word_len.
            current_parts.append(clean)
            running_word = "".join(current_parts)
            if len(running_word) > max_word_len:
                if current_parts[:-1]:
                    word = "".join(current_parts[:-1])
                    assembled.append(word)
                current_parts = [clean]
                running_word = clean

        token_word_map[i] = running_word if running_word else None

    if current_parts:
        word = "".join(current_parts)
        if len(word) <= max_word_len:
            assembled.append(word)

    if token_word_map:
        words = [w for w in assembled if w and w.strip()]
        return token_word_map, words

    # ------------------------------------------------------------------
    # Strategy 3: whitespace split. Only reached when strategy 2 produced an
    # empty map (for example when no tokens were obtained).
    # ------------------------------------------------------------------
    try:
        word_list = [w for w in text.split() if w.strip()]
        token_word_map = {}

        if tokens and word_list:
            word_idx = 0

            for i, tok in enumerate(tokens):
                clean = (tok or "").replace("▁", "").replace("Ġ", "").replace("##", "").strip()
                if not clean or tok in special_tokens:
                    token_word_map[i] = None
                    continue

                if word_idx < len(word_list):
                    current_word = word_list[word_idx]
                    if clean in current_word or current_word.startswith(clean):
                        token_word_map[i] = current_word
                    else:
                        # Piece does not fit the current word: advance to the
                        # next one (clamped to the last word).
                        word_idx = min(word_idx + 1, len(word_list) - 1)
                        token_word_map[i] = word_list[word_idx]
                else:
                    token_word_map[i] = word_list[-1] if word_list else None

        return token_word_map, word_list
    except Exception:
        return {}, []

def is_valid_token(
    token: str,
    special_tokens: Optional[Set[str]] = None,
    tokenizer=None,
    language: str = "bn"
) -> bool:
    """Return True when *token* looks like a real word piece.

    A token is accepted when, after removing subword markers ("▁", "Ġ",
    "##") and commas, at least two characters remain and at least one of
    them is alphabetic.

    Args:
        token: Raw tokenizer token. Non-strings and blank strings are rejected.
        special_tokens: Optional set of tokens (compared after stripping
            surrounding whitespace) that are always rejected.
        tokenizer: Unused; kept so existing call sites remain valid.
        language: Unused; kept so existing call sites remain valid.

    Returns:
        True if the token passes all checks, False otherwise.
    """
    if not token or not isinstance(token, str):
        return False

    token = token.strip()
    if not token:
        return False

    if special_tokens and token in special_tokens:
        return False

    clean = token.replace("▁", "").replace("Ġ", "").replace("##", "").replace(",", "").strip()
    if len(clean) < 2:
        return False

    # Requiring one alphabetic character already rejects punctuation-only and
    # digit-only strings, so the former explicit checks for those cases were
    # unreachable and have been dropped.
    return any(c.isalpha() for c in clean)

def map_subwords_to_words(tokens: List[str], tokenizer) -> Dict[int, Optional[str]]:
    """Group subword tokens into words using their boundary markers.

    A token starting with "▁" or "Ġ" opens a new word; every other
    non-special token extends the word that is currently open. Each token
    index of a word is mapped to the complete word text.

    Args:
        tokens: Token strings in sequence order.
        tokenizer: Passed to ``get_tokenizer_special_tokens`` to find the
            tokens that must never be mapped to a word.

    Returns:
        Dict with one entry per token index. Special tokens, and tokens
        whose word is empty after marker removal, map to ``None``.
    """
    special_tokens = get_tokenizer_special_tokens(tokenizer)
    word_map: Dict[int, Optional[str]] = {}
    current_word = ""
    # Indices of the (non-special) tokens that make up the open word. Tracking
    # them explicitly, instead of back-filling an index range, keeps special
    # tokens inside or after a word mapped to None.
    member_indices: List[int] = []

    def _assign(indices: List[int], word: str) -> None:
        """Record *word* (or None when empty) for every index in *indices*."""
        label = word if word else None
        for j in indices:
            word_map[j] = label

    for i, tok in enumerate(tokens):
        if tok in special_tokens:
            word_map[i] = None
            continue

        piece = tok or ""
        clean = piece.replace("▁", "").replace("Ġ", "").replace("##", "").strip()

        if piece.startswith(("▁", "Ġ")):
            # Close the previous word before opening a new one. Empty words
            # are recorded as None so that no index is left without an entry.
            _assign(member_indices, current_word)
            current_word = clean
            member_indices = [i]
        else:
            current_word += clean
            member_indices.append(i)

    _assign(member_indices, current_word)
    return word_map

class PrototypeDatabase:
    """Disk-backed store of per-word sense prototypes.

    In memory the data is ``self.prototypes[word][sense_id]`` ->
    ``{'centroid', 'count', 'validity'}``. On disk it is a pickle file
    (``prototypes.pkl``) plus a small JSON summary (``metadata.json``), both
    inside ``save_dir``.
    """

    def __init__(self, save_dir='./dscd_prototypes'):
        """Create the store and make sure ``save_dir`` exists.

        Nothing is read from disk here; call ``load()`` for that.
        """
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(exist_ok=True, parents=True)

        self.prototype_file = self.save_dir / 'prototypes.pkl'
        self.metadata_file = self.save_dir / 'metadata.json'

        self.prototypes = {}
        self.metadata = {
            'total_words': 0,
            'total_prototypes': 0,
            'last_updated': None,
            'version': '1.0'
        }

    def load(self):
        """Read prototypes and metadata from disk, if the files exist.

        Failures are reported with a printed message instead of raised; a
        failed prototype load resets ``self.prototypes`` to an empty dict.
        """
        if self.prototype_file.exists():
            try:
                # SECURITY: pickle.load can execute arbitrary code. Only load
                # prototype files that this pipeline wrote itself.
                with open(self.prototype_file, 'rb') as f:
                    loaded = pickle.load(f)
                if not isinstance(loaded, dict):
                    raise TypeError(
                        f"expected a dict of prototypes, got {type(loaded).__name__}"
                    )
                self.prototypes = loaded
                print(f"✓ Loaded {len(self.prototypes)} words from {self.prototype_file}")
            except Exception as e:
                print(f"⚠ Failed to load prototypes: {e}")
                self.prototypes = {}

        if self.metadata_file.exists():
            try:
                with open(self.metadata_file, 'r', encoding='utf-8') as f:
                    loaded_meta = json.load(f)
                if not isinstance(loaded_meta, dict):
                    raise TypeError(
                        f"expected a JSON object, got {type(loaded_meta).__name__}"
                    )
                # Merge rather than replace so save() can always rely on the
                # default keys being present.
                self.metadata.update(loaded_meta)
            except Exception as e:
                print(f"⚠ Failed to load metadata: {e}")

    def save(self):
        """Refresh the summary metadata and write both files to disk.

        Write errors are printed, not raised.
        """
        self.metadata['total_words'] = len(self.prototypes)
        self.metadata['total_prototypes'] = sum(len(senses) for senses in self.prototypes.values())
        self.metadata['last_updated'] = datetime.datetime.now().isoformat()

        try:
            with open(self.prototype_file, 'wb') as f:
                pickle.dump(self.prototypes, f)

            with open(self.metadata_file, 'w', encoding='utf-8') as f:
                json.dump(self.metadata, f, indent=2, ensure_ascii=False)

            print(f"✓ Saved {self.metadata['total_words']} words, {self.metadata['total_prototypes']} prototypes")
            print(f"  → {self.prototype_file}")
        except Exception as e:
            print(f"⚠ Failed to save prototypes: {e}")

    def add_or_update_prototype(self, word: str, sense_id: int, centroid, count: int, validity: bool = True):
        """Insert or overwrite one sense prototype of *word*.

        The word key is stripped and lower-cased. Tensor centroids are
        detached and moved to the CPU before being stored; any other centroid
        value is stored unchanged.
        """
        word = str(word).strip().lower()

        if word not in self.prototypes:
            self.prototypes[word] = {}
            if _DEBUG_DISCOVERY:
                print(f"  NEW WORD: '{word}'")

        if sense_id not in self.prototypes[word]:
            if _DEBUG_DISCOVERY:
                print(f"    NEW SENSE: '{word}' → sense {sense_id}")

        self.prototypes[word][sense_id] = {
            'centroid': centroid.detach().cpu() if torch.is_tensor(centroid) else centroid,
            'count': count,
            'validity': validity
        }

    def get_prototypes(self, word: str):
        """Return the ``{sense_id: data}`` dict of *word*, or None if unknown."""
        word = str(word).strip().lower()
        return self.prototypes.get(word, None)

    def sync_from_dscd(self, dscd_module):
        """Copy every prototype store of *dscd_module* into this database and save.

        The module's ``buffer_lock`` (or, failing that, ``clustering_lock``)
        is held while ``prototype_stores`` is snapshotted. Stores lacking
        ``centroids`` or ``counts`` are skipped.

        Returns:
            Number of stores found in the snapshot.
        """
        print("Syncing prototypes from DSCD module...")

        lock = None
        if hasattr(dscd_module, 'buffer_lock'):
            lock = dscd_module.buffer_lock
        elif hasattr(dscd_module, 'clustering_lock'):
            lock = dscd_module.clustering_lock

        # Compare against None explicitly; do not depend on the lock object's
        # truthiness.
        if lock is not None:
            with lock:
                stores = dict(dscd_module.prototype_stores)
        else:
            stores = dict(dscd_module.prototype_stores)

        for word, store in stores.items():
            if not hasattr(store, 'centroids') or not hasattr(store, 'counts'):
                continue

            # Every synced sense is recorded as valid. (The previous
            # size-based branch assigned the same value and had no effect.)
            for sense_id, (centroid, count) in enumerate(zip(store.centroids, store.counts)):
                self.add_or_update_prototype(
                    word=word,
                    sense_id=sense_id,
                    centroid=centroid,
                    count=count,
                    validity=True
                )

        self.save()
        return len(stores)

    def load_into_dscd(self, dscd_module):
        """Install the valid in-memory prototypes into *dscd_module*.

        For each word with at least one valid sense, a ``SimpleNamespace``
        store (``centroids``, ``counts``, ``size``, ``mu=0.5``, ``tau=1.0``)
        is written to ``dscd_module.prototype_stores[word]``. Centroids that
        are not tensors are converted with ``torch.as_tensor`` and everything
        is moved to ``dscd_module.device`` when that attribute exists.

        Returns:
            Number of words installed.
        """
        from types import SimpleNamespace

        print("Loading prototypes into DSCD module...")

        device = getattr(dscd_module, 'device', None)
        loaded_count = 0

        for word, senses in self.prototypes.items():
            centroids = []
            counts = []

            for sense_id in sorted(senses.keys()):
                sense_data = senses[sense_id]
                if not sense_data.get('validity', True):
                    continue

                centroid = sense_data['centroid']
                # add_or_update_prototype stores non-tensor centroids as-is,
                # so convert here instead of assuming a ``.to`` method.
                if not torch.is_tensor(centroid):
                    centroid = torch.as_tensor(centroid)
                if device is not None:
                    centroid = centroid.to(device)

                centroids.append(centroid)
                counts.append(sense_data['count'])

            if len(centroids) == 0:
                continue

            store = SimpleNamespace(
                centroids=centroids,
                counts=counts,
                size=len(centroids),
                mu=0.5,
                tau=1.0
            )

            dscd_module.prototype_stores[word] = store
            loaded_count += 1

        print(f"✓ Loaded {loaded_count} words into DSCD")
        return loaded_count

def inspect_prototypes(word: Optional[str] = None, proto_db=None):
    """Print the contents of a prototype database.

    Args:
        word: When given, print every sense stored for this word (matched
            after stripping and lower-casing). When omitted, print a summary
            and the 20 words with the most senses.
        proto_db: Database to inspect. Defaults to the module-level
            ``PROTOTYPE_DB``.

    Returns:
        None. All output goes to stdout. Note that ``proto_db.load()`` is
        called first, so the database is refreshed from disk.
    """
    if proto_db is None:
        try:
            proto_db = PROTOTYPE_DB
        except NameError:
            proto_db = None
        # PROTOTYPE_DB is set to None when its construction failed, so an
        # existing name is not enough; the value has to be checked as well.
        if proto_db is None:
            print("Error: PROTOTYPE_DB not found. Create instance first.")
            return

    proto_db.load()

    if word:
        word = word.strip().lower()
        senses = proto_db.get_prototypes(word)
        if senses is None:
            print(f"Word '{word}' not found in database")
            return

        print(f"\n{'='*80}")
        print(f"WORD: '{word}' ({len(senses)} senses)")
        print(f"{'='*80}")

        for sense_id, data in sorted(senses.items()):
            centroid = data['centroid']
            # Centroids are not guaranteed to be tensors; fall back to a
            # length-based shape, or 'unknown' for unsized values.
            shape = getattr(centroid, 'shape', None)
            if shape is None:
                try:
                    shape = (len(centroid),)
                except TypeError:
                    shape = 'unknown'

            print(f"\n  Sense {sense_id}:")
            print(f"    Count: {data['count']}")
            print(f"    Valid: {data['validity']}")
            print(f"    Centroid shape: {shape}")
    else:
        print(f"\n{'='*80}")
        print("PROTOTYPE DATABASE SUMMARY")
        print(f"{'='*80}")
        print(f"Total words: {len(proto_db.prototypes)}")
        print(f"Total prototypes: {sum(len(s) for s in proto_db.prototypes.values())}")
        print(f"\nTop 20 words by sense count:")

        word_sense_counts = [(w, len(s)) for w, s in proto_db.prototypes.items()]
        word_sense_counts.sort(key=lambda x: x[1], reverse=True)

        # A separate loop name keeps the ``word`` parameter from being shadowed.
        for entry, num_senses in word_sense_counts[:20]:
            counts = [sense['count'] for sense in proto_db.prototypes[entry].values()]
            print(f"  '{entry}': {num_senses} senses, counts={counts}")

def test_tokenizer_utilities_quick(tokenizer=None) -> bool:
    """Run a printed smoke test of the tokenizer helper functions.

    Exercises ``safe_offsets_tokenize``, ``reconstruct_word_spans`` and
    ``map_subwords_to_words`` on one Bengali and one English sentence.

    Args:
        tokenizer: Tokenizer to test with. When None the test is skipped.

    Returns:
        True when skipped, or when the Bengali sentence yields words that
        contain Bengali characters and no Latin letters. False when that
        check fails or any exception is raised.
    """
    bn_sentence = "কাল আমি বাজারে যাব।"
    en_sentence = "Tomorrow I will go to the market."
    rule = "=" * 60

    def _any_bengali(s) -> bool:
        # True when any character lies in the Bengali Unicode block.
        return any("\u0980" <= ch <= "\u09FF" for ch in s)

    def _any_latin(s) -> bool:
        # True when any character is an ASCII letter (case-insensitive).
        return any("a" <= ch.lower() <= "z" for ch in s)

    print("\n" + rule)
    print("TOKENIZER UTILITIES TEST")
    print(rule)

    try:
        if tokenizer is None:
            print("No tokenizer provided: skipping test")
            return True

        # ---- Test 1: Bengali sentence ---------------------------------
        print("\n[TEST 1] Bengali text processing:")
        print(f"  Input: {bn_sentence}")
        bn_encoding = safe_offsets_tokenize(
            tokenizer, bn_sentence, max_length=32, include_special_tokens=False
        )
        if isinstance(bn_encoding, dict) and "input_ids" in bn_encoding:
            encoded_length = int(bn_encoding["input_ids"].shape[-1])
        else:
            encoded_length = "N/A"
        print(f"  Encoded length: {encoded_length}")
        bn_offsets = bn_encoding.get("offset_mapping") or []
        print(f"  Offsets (first 5): {bn_offsets[:5]}")

        bn_token_map, bn_words = reconstruct_word_spans(tokenizer, bn_sentence, max_length=32)
        print(f"  Reconstructed words: {bn_words}")
        print(f"  Token map sample: {dict(list(bn_token_map.items())[:3])}")

        bn_words_found = any(_any_bengali(w) for w in bn_words)
        print(f"  Contains Bengali words: {bn_words_found}")

        # ---- Test 2: English sentence ---------------------------------
        print("\n[TEST 2] English text processing (should show warning):")
        print(f"  Input: {en_sentence}")
        _en_token_map, en_words = reconstruct_word_spans(tokenizer, en_sentence, max_length=32)
        print(f"  Reconstructed words: {en_words}")

        en_words_found = any(_any_latin(w) for w in en_words)
        print(f"  Contains English words: {en_words_found}")

        # ---- Test 3: marker-based grouping ----------------------------
        print("\n[TEST 3] Token-to-word mapping:")
        probe_tokens = ["▁আমি", "▁", "ক", "ল", "▁বন্ধ"]
        grouped = map_subwords_to_words(probe_tokens, tokenizer)
        print(f"  Tokens: {probe_tokens}")
        print(f"  Word map: {grouped}")
        print(f"  '▁কল' split correctly: {all(grouped.get(k) == 'কল' for k in (2, 3))}")

        # Verdict: Bengali words present and free of Latin letters.
        passed = bn_words_found and not _any_latin("".join(bn_words))
        if not passed:
            print("\nTest WARNING: Check language detection logic")
            return False

        print("\nTest PASSED: Bengali processing works correctly")
        return True

    except Exception as e:
        print(f"\nTest FAILED: {repr(e)}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        print(rule + "\n")

# Underscore-free aliases for the tokenizer helpers defined in this cell.
safeoffsetstokenize = safe_offsets_tokenize
reconstructwordspans = reconstruct_word_spans
gettokenizerspecialtokens = get_tokenizer_special_tokens
isvalidtoken = is_valid_token
mapsubwordstowords = map_subwords_to_words

# Module-level prototype database. It is left as None when construction fails
# (the constructor creates save_dir on disk), so users must handle None.
try:
    PROTOTYPE_DB = PrototypeDatabase(save_dir='/kaggle/working')
except Exception as e:
    print(f"Warning: Could not initialize PROTOTYPE_DB: {e}")
    PROTOTYPE_DB = None

print("Cell 1: Tokenizer utilities + PrototypeDatabase loaded")


In [None]:
# ==============================================================================
# CELL 2: MEMORY-EFFICIENT DATA LOADING (BENGALI → ENGLISH TASK)
# ==============================================================================

from typing import Optional, List, Tuple, Dict, Any
from collections import defaultdict
import os
import time
import random
import traceback
import re

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, get_worker_info
from tqdm import tqdm

try:
    import pandas as pd
    _HAS_PANDAS = True
except ImportError:
    pd = None
    _HAS_PANDAS = False
    print("[CELL2] WARNING: pandas not available; CSV loading will fail!")

try:
    from datasets import load_dataset
    _HAS_DATASETS = True
except Exception:
    load_dataset = None
    _HAS_DATASETS = False

# Logging switches: read the notebook-level flags when they exist, otherwise
# default to quiet.
try:
    _VERBOSE_LOGGING = bool(VERBOSE_LOGGING)
except NameError:
    _VERBOSE_LOGGING = False

# NOTE(review): this assigns the module-level name _DEBUG_VERBOSE, which
# reconstruct_word_spans also reads — confirm the rebinding is intended.
try:
    _DEBUG_VERBOSE = bool(DEBUG_VERBOSE)
except NameError:
    _DEBUG_VERBOSE = False

# Cell 2 debug output is enabled when either flag is set.
DEBUG_CELL2 = bool(_VERBOSE_LOGGING) or bool(_DEBUG_VERBOSE)
# Default number of messages cell2_dbg prints per key.
DEBUG_LIMIT = 10
# Per-key counter of cell2_dbg calls, used for rate limiting.
_cell2_dbg_counts: Dict[str, int] = defaultdict(int)

def cell2_dbg(key: str, msg: str, limit: int = DEBUG_LIMIT) -> None:
    """Print a rate-limited Cell 2 debug message.

    Does nothing unless ``DEBUG_CELL2`` is set. Otherwise every call is
    counted under *key* and only the first *limit* calls per key are printed.
    """
    if DEBUG_CELL2:
        seen = _cell2_dbg_counts[key] + 1
        _cell2_dbg_counts[key] = seen
        if seen <= limit:
            print(f"[CELL2-DBG] {msg}")

# ------------------------------------------------------------------------------
# Cell 2 settings. Each value is taken from the notebook-level configuration
# when it is defined and convertible, otherwise a default is used. A setting
# that is defined but holds an unconvertible value (for example None) falls
# back to its default as well, instead of raising at import time.
# ------------------------------------------------------------------------------

# Maximum number of CSV rows to load.
try:
    _NUM_SAMPLES = int(NUM_SAMPLES)
except Exception:
    _NUM_SAMPLES = 50000
    print("[CELL2] WARNING: NUM_SAMPLES not defined, using default 50000")

# Token length limit; also drives the per-sentence word cap used when loading.
try:
    _MAX_LENGTH = int(MAX_LENGTH)
except Exception:
    _MAX_LENGTH = 48
    print("[CELL2] WARNING: MAX_LENGTH not defined, using default 48")

# Language codes of the translation direction.
try:
    _SOURCE_LANG = str(SOURCE_LANGUAGE)
    _TARGET_LANG = str(TARGET_LANGUAGE)
except (NameError, TypeError, ValueError):
    _SOURCE_LANG = "bn"
    _TARGET_LANG = "en"
    print("[CELL2] WARNING: SOURCE_LANGUAGE/TARGET_LANGUAGE not defined, using defaults bn/en")

# Token ids of the language tags.
try:
    _M2M_BN_TOKEN_ID = int(M2M100_BN_TOKEN_ID)
    _M2M_EN_TOKEN_ID = int(M2M100_EN_TOKEN_ID)
except (NameError, TypeError, ValueError):
    _M2M_BN_TOKEN_ID = 128025
    _M2M_EN_TOKEN_ID = 128022
    print("[CELL2] WARNING: M2M100 token IDs not defined, using defaults")

# GPU topology; detected from torch when not configured.
try:
    _NUM_GPUS = int(NUM_GPUS)
    _USE_MULTI_GPU = bool(USE_MULTI_GPU)
except (NameError, TypeError, ValueError):
    _NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
    _USE_MULTI_GPU = _NUM_GPUS > 1
    print(f"[CELL2] WARNING: GPU config not defined, detected {_NUM_GPUS} GPUs")

# DataLoader settings.
try:
    _NUM_WORKERS = int(NUM_WORKERS)
except (NameError, TypeError, ValueError):
    _NUM_WORKERS = 0
    print("[CELL2] WARNING: NUM_WORKERS not defined, using 0")

try:
    _PIN_MEMORY = bool(PIN_MEMORY)
except (NameError, TypeError, ValueError):
    _PIN_MEMORY = False

try:
    _PREFETCH_FACTOR = int(PREFETCH_FACTOR)
except (NameError, TypeError, ValueError):
    _PREFETCH_FACTOR = 2

# Location of the parallel corpus CSV.
try:
    _DATASET_CSV_PATH = str(DATASET_CSV_PATH)
except (NameError, TypeError, ValueError):
    _DATASET_CSV_PATH = "/kaggle/input/bengali-english-homograph/bengali_homograph_sentences.csv"
    print(f"[CELL2] WARNING: DATASET_CSV_PATH not defined, using default: {_DATASET_CSV_PATH}")

# Domain-label settings; labels are disabled when any of them is unavailable.
try:
    _TRAIN_DOMAIN = int(TRAIN_DOMAIN)
    _TEST_DOMAIN = int(TEST_DOMAIN)
    _USE_DOMAIN_LABELS = bool(USE_DOMAIN_LABELS)
except (NameError, TypeError, ValueError):
    _TRAIN_DOMAIN = 0
    _TEST_DOMAIN = 1
    _USE_DOMAIN_LABELS = False
    print("[CELL2] WARNING: Domain label config not found, disabling domain labels")

# Record which tokenizer helpers are already defined in this namespace.
_has_reconstruct_word_spans = "reconstruct_word_spans" in globals()
_has_safe_offsets_tokenize = "safe_offsets_tokenize" in globals()
_has_map_subwords_to_words = "map_subwords_to_words" in globals()

# Matches any character of the Bengali Unicode block.
_BENGALI_CHAR_RE = re.compile(r"[\u0980-\u09FF]")

def is_bengali_text(s: str) -> bool:
    """Return True when *s* is a string with at least one Bengali character.

    None, non-string values and the empty string all give False.
    """
    return isinstance(s, str) and bool(_BENGALI_CHAR_RE.search(s))

def normalize_bengali(text: str) -> str:
    """Normalize a Bengali sentence for training.

    Trims the text, collapses every whitespace run to one space and removes
    a single trailing danda ("।") or period. Non-string input gives "".

    Whitespace left in front of the removed punctuation (corpora often write
    "... ।") is trimmed as well, so the result never ends with a space.
    """
    if not isinstance(text, str):
        return ""
    text = re.sub(r'\s+', ' ', text.strip())
    text = re.sub(r'[।\.]$', '', text)
    return text.rstrip()

def normalize_english(text: str) -> str:
    """Normalize an English sentence for training.

    Trims and lower-cases the text, collapses every whitespace run to one
    space and removes a trailing run of ".", "!" or "?". Non-string input
    gives "".

    Whitespace left in front of the removed punctuation (e.g. "hello !") is
    trimmed as well, so the result never ends with a space.
    """
    if not isinstance(text, str):
        return ""
    text = re.sub(r'\s+', ' ', text.strip().lower())
    text = re.sub(r'[\.!?]+$', '', text)
    return text.rstrip()

def _dataloader_worker_init_fn(worker_id: int) -> None:
    """Per-worker DataLoader initializer: rebind the tokenizer and seed RNGs.

    If the worker's dataset exposes a non-empty ``_tokenizer_name_or_path``,
    a fresh ``M2M100Tokenizer`` is loaded from it and stored on the dataset
    together with ``is_fast``. On failure both are reset (tokenizer to None,
    ``is_fast`` to False).

    Python's ``random`` and NumPy are then seeded from the seed PyTorch
    assigned to this worker, so augmentation is reproducible for a fixed
    DataLoader seed and still differs between workers. torch's own generator
    is left alone because the DataLoader has already seeded it per worker.

    Args:
        worker_id: Index of the worker, as passed by the DataLoader.
    """
    worker_info = get_worker_info()
    dataset = worker_info.dataset if worker_info is not None else None
    try:
        if dataset is not None and hasattr(dataset, "_tokenizer_name_or_path") and dataset._tokenizer_name_or_path:
            try:
                from transformers import M2M100Tokenizer
                dataset.tokenizer = M2M100Tokenizer.from_pretrained(dataset._tokenizer_name_or_path)
                dataset.is_fast = getattr(dataset.tokenizer, "is_fast", False)
                if DEBUG_CELL2:
                    print(f"[CELL2-WORKER-{worker_id}] Tokenizer reloaded successfully")
            except Exception as e:
                cell2_dbg("worker_tokenizer_reload", f"Worker {worker_id} tokenizer reload failed: {e}")
                dataset.tokenizer = None
                dataset.is_fast = False
    except Exception:
        if DEBUG_CELL2:
            print(f"[CELL2-WORKER-INIT] Tokenizer rebind failed in worker {worker_id}")

    try:
        # Outside a worker process there is no worker_info; use the current
        # torch seed so the function stays usable (and deterministic) there.
        base_seed = worker_info.seed if worker_info is not None else torch.initial_seed()
        # NumPy only accepts seeds in [0, 2**32).
        seed = int(base_seed) % (2**32)
        random.seed(seed)
        np.random.seed(seed)
    except Exception as e:
        # Seeding is best-effort, but a failure should be visible when
        # debugging instead of disappearing silently.
        cell2_dbg("worker_seed", f"Worker {worker_id} seeding failed: {e}")

def load_and_preprocess_optimized(
    num_samples: Optional[int] = None,
    split: str = "train",
) -> List[Tuple[str, str]]:
    """Load normalized (Bengali, English) sentence pairs from the dataset CSV.

    The CSV at ``_DATASET_CSV_PATH`` must have ``src`` and ``tgt`` columns.
    If the first complete row shows ``src`` to be English and ``tgt`` to be
    Bengali, the columns are swapped so the result is always bn -> en.

    A row is skipped when a field is missing or blank, when ``src`` has no
    Bengali character, when ``tgt`` has no Latin letter, when either side has
    more than ``max(20, _MAX_LENGTH // 2)`` whitespace-separated words, or
    when normalization leaves an empty string.

    Args:
        num_samples: Maximum number of CSV rows to read (rows, not resulting
            pairs). Defaults to ``_NUM_SAMPLES``.
        split: Currently ignored; kept so existing call sites remain valid.

    Returns:
        List of ``(bengali, english)`` pairs. The built-in fallback dataset
        is returned instead when pandas is unavailable, the file is missing
        or empty, the columns are wrong, no row survives filtering, or
        reading fails for any other reason.

    Raises:
        ValueError: If ``num_samples`` is not positive.
    """
    if num_samples is None:
        num_samples = _NUM_SAMPLES
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")

    print(f"[CELL2] Loading up to {num_samples} samples from local CSV: {_DATASET_CSV_PATH}")

    if not _HAS_PANDAS:
        print("[CELL2] ERROR: pandas not available; cannot load CSV!")
        print("[CELL2] Using fallback dataset for debugging.")
        return _get_fallback_dataset()

    if not os.path.exists(_DATASET_CSV_PATH):
        print(f"[CELL2] ERROR: CSV file not found at: {_DATASET_CSV_PATH}")
        print("[CELL2] Using fallback dataset for debugging.")
        return _get_fallback_dataset()

    try:
        print("[CELL2] Reading CSV file...")
        # Parse only the rows that will be used instead of loading the whole
        # file and truncating afterwards.
        df = pd.read_csv(_DATASET_CSV_PATH, nrows=num_samples)
        if df.empty:
            print("[CELL2] ERROR: CSV file is empty")
            return _get_fallback_dataset()

        if "src" not in df.columns or "tgt" not in df.columns:
            print(f"[CELL2] ERROR: CSV missing required columns. Found columns: {list(df.columns)}")
            print("[CELL2] Expected format: src (Bengali), tgt (English) OR src (English), tgt (Bengali)")
            return _get_fallback_dataset()

        latin_re = re.compile(r"[a-zA-Z]")

        # Detect the column orientation from the first row in which both
        # fields are present. A NaN in row 0 would otherwise stringify to
        # "nan" and be mistaken for English text.
        complete_rows = df.dropna(subset=["src", "tgt"])
        if len(complete_rows) > 0:
            sample_src = str(complete_rows["src"].iloc[0])
            sample_tgt = str(complete_rows["tgt"].iloc[0])
        else:
            sample_src = ""
            sample_tgt = ""

        src_is_bengali = bool(_BENGALI_CHAR_RE.search(sample_src))
        tgt_is_bengali = bool(_BENGALI_CHAR_RE.search(sample_tgt))
        src_is_english = bool(latin_re.search(sample_src)) and not src_is_bengali
        tgt_is_english = bool(latin_re.search(sample_tgt)) and not tgt_is_bengali

        if src_is_english and tgt_is_bengali:
            print("[CELL2] Detected src=English, tgt=Bengali: Swapping columns for bn→en task.")
            # rename() applies the whole mapping at once, so this exchanges
            # the two labels. The swapped samples are exactly the ones just
            # classified, so a post-swap re-check could never fail.
            df = df.rename(columns={"src": "tgt", "tgt": "src"})
            print("[CELL2] Swap successful: src=Bengali, tgt=English")
        elif not src_is_bengali or not tgt_is_english:
            print("[CELL2] WARNING: After column check, src not Bengali or tgt not English. Proceeding but output may be incorrect.")

        print(f"[CELL2] Processing {len(df)} rows from CSV...")

        pairs: List[Tuple[str, str]] = []
        skipped = 0
        # Word cap per sentence; the same for every row, so computed once.
        max_words = max(20, _MAX_LENGTH // 2)

        for src_val, tgt_val in tqdm(zip(df["src"], df["tgt"]), total=len(df), desc="Loading dataset"):
            try:
                if pd.isna(src_val) or pd.isna(tgt_val):
                    skipped += 1
                    cell2_dbg("nan_value", "NaN value detected")
                    continue
                bn = str(src_val).strip()
                en = str(tgt_val).strip()
                if not bn or not en:
                    skipped += 1
                    cell2_dbg("empty_field", "Empty src/tgt field")
                    continue
                if not is_bengali_text(bn):
                    skipped += 1
                    cell2_dbg("not_bengali_src", "src field not Bengali")
                    continue
                if not latin_re.search(en):
                    skipped += 1
                    cell2_dbg("not_english_tgt", "tgt field not English")
                    continue
                if len(bn.split()) > max_words or len(en.split()) > max_words:
                    skipped += 1
                    cell2_dbg("too_long", "Text too long")
                    continue
                bn_norm = normalize_bengali(bn)
                en_norm = normalize_english(en)
                if not bn_norm or not en_norm:
                    skipped += 1
                    cell2_dbg("empty_after_norm", "Empty after normalization")
                    continue
                pairs.append((bn_norm, en_norm))
            except Exception as e:
                # One malformed row must not abort the whole load.
                skipped += 1
                cell2_dbg("row_exception", f"Row load exception: {type(e).__name__}")
                continue

        print(f"[CELL2] Loaded {len(pairs)} pairs from CSV, skipped {skipped} rows")
        if len(pairs) == 0:
            print("[CELL2] ERROR: No valid pairs loaded from CSV!")
            print("[CELL2] Check that src column contains Bengali and tgt column contains English.")
            return _get_fallback_dataset()

        return pairs

    except pd.errors.EmptyDataError:
        print(f"[CELL2] ERROR: CSV file is empty: {_DATASET_CSV_PATH}")
        return _get_fallback_dataset()
    except Exception as e:
        print(f"[CELL2] ERROR loading CSV: {type(e).__name__}: {str(e)}")
        traceback.print_exc()
        print("[CELL2] Using fallback dataset")
        return _get_fallback_dataset()

def _get_fallback_dataset() -> List[Tuple[str, str]]:
    """Return the built-in Bengali/English sample pairs after normalization.

    The CSV loader falls back to this when the file is empty, unreadable, or
    yields no valid rows. Each Bengali sentence goes through
    ``normalize_bengali`` and each English sentence through
    ``normalize_english``.

    Returns:
        A list of ``(normalized_bengali, normalized_english)`` tuples, in the
        order the samples are listed below.
    """
    print("[CELL2] Using fallback dataset (50 unique samples)")
    # (Bengali source, English target). Several Bengali words appear twice
    # with different English renderings (e.g. "কল" as tap / call).
    raw_samples = (
        ("আমি কল বন্ধ করেছি।", "i turned off the tap."),
        ("সে আমাকে পরে কল করবে।", "he will call me later."),
        ("আমরা প্রতিদিন তাজা ফল খাই।", "we eat fresh fruits every day."),
        ("তার কঠোর পরিশ্রমের ভালো ফল হয়েছে।", "his hard work has brought good results."),
        ("গাছে নতুন পাতাগুলো গজিয়েছে।", "new leaves have sprouted on the tree."),
        ("আমি বইয়ের পাতা উল্টাচ্ছি।", "i am turning the pages of the book."),
        ("কাল আমি বাজারে গিয়েছিলাম।", "yesterday i went to the market."),
        ("কাল আমি তোমার সাথে দেখা করব।", "tomorrow i will meet you."),
        ("তারা আকাশে উজ্জ্বল।", "the stars are bright in the sky."),
        ("তারা বাড়িতে নেই।", "they are not at home."),
        ("ব্যাংক নদীর ধারে ভেঙে গেছে।", "the bank by the river has collapsed."),
        ("আমি ব্যাংকে টাকা জমা দিয়েছি।", "i deposited money in the bank."),
        ("বার বার চেষ্টা করতে হবে।", "you have to try again and again."),
        ("আমি বার খুলে ভিতরে ঢুকলাম।", "i opened the bar and entered."),
        ("তার মাথা ব্যথা করছে।", "his head is hurting."),
        ("আমি মাথা নেড়ে সম্মতি দিলাম।", "i nodded my head in agreement."),
        ("সে হার মেনে নিয়েছে।", "he accepted defeat."),
        ("আমি গলায় সোনার হার পরেছি।", "i am wearing a gold necklace."),
        ("পানি খুব ঠান্ডা।", "the water is very cold."),
        ("আমি পানি খাচ্ছি।", "i am drinking water."),
        ("দল খেলায় জিতেছে।", "the team won the game."),
        ("আমি মাটি দল দিয়ে ফেললাম।", "i trampled the soil."),
        ("বাজার থেকে সবজি কিনলাম।", "i bought vegetables from the market."),
        ("বাজার অনেক ভিড় ছিল।", "the market was very crowded."),
        ("তার নাম আহমেদ।", "his name is ahmed."),
        ("নাম না করে কাজ করো।", "work without making a name."),
        ("কথা বলা বন্ধ করো।", "stop talking."),
        ("তার কথা শুনে ভালো লাগল।", "i felt good hearing his words."),
        ("বই পড়তে ভালো লাগে।", "i like reading books."),
        ("আমি একটি নতুন বই কিনেছি।", "i bought a new book."),
        ("ঘর পরিষ্কার করা হয়েছে।", "the house has been cleaned."),
        ("আমি ঘরে বসে আছি।", "i am sitting at home."),
        ("মন ভালো নেই।", "my mind is not good."),
        ("আমার মন চায় বেড়াতে যেতে।", "my mind wants to go for a walk."),
        ("হাত ধুয়ে নাও।", "wash your hands."),
        ("আমি তার হাত ধরলাম।", "i held his hand."),
        ("দিন কেটে যাচ্ছে।", "the day is passing by."),
        ("আজ কি দিন?", "what day is today?"),
        ("রাত হয়ে এসেছে।", "night has come."),
        ("আমি রাত জেগে পড়েছি।", "i studied staying up at night."),
        ("জল খুব গরম।", "the water is very hot."),
        ("আমি জল দিয়ে গাছ সিঞ্চন করেছি।", "i watered the plants."),
        ("বাড়ি যাচ্ছি।", "i am going home."),
        ("আমার বাড়ি ঢাকায়।", "my house is in dhaka."),
        ("পার্কে অনেক মানুষ।", "there are many people in the park."),
        ("আমি প্রতিদিন পার্কে হাঁটি।", "i walk in the park every day."),
        ("নদী বইছে।", "the river is flowing."),
        ("আমি নদীর ধারে দাঁড়িয়ে আছি।", "i am standing by the river."),
        ("বন খুব সুন্দর।", "the forest is very beautiful."),
        ("আমি বন দেখতে গিয়েছিলাম।", "i went to see the forest."),
    )
    normalized: List[Tuple[str, str]] = []
    for bengali, english in raw_samples:
        normalized.append((normalize_bengali(bengali), normalize_english(english)))
    return normalized

class MemoryEfficientDataset(Dataset):
    """Map-style dataset that tokenizes (source, target) text pairs on access.

    Pairs are validated once in ``__init__``; tokenization happens lazily in
    ``__getitem__``. The tokenizer is dropped when the instance is pickled and
    re-resolved from the module-level ``tokenizer`` global on first use.

    Failure paths return a placeholder sample instead of raising: encoding
    errors, bad indices and bad row types all degrade to ``_make_safe_sample``.
    """

    def __init__(
        self,
        pairs: List[Tuple[str, str]],
        tokenizer: Any = None,
        max_length: Optional[int] = None,
        split: str = "train",
    ):
        """Validate ``pairs`` and remember the tokenizer settings.

        Args:
            pairs: Sequence of 2-item (source, target) string pairs. Entries
                with the wrong structure, non-string or empty fields, or a
                field longer than ``max_length * 20`` characters are dropped.
            tokenizer: Tokenizer used for encoding; may be ``None`` and
                resolved later from the ``tokenizer`` global.
            max_length: Token length used for encoding; defaults to
                ``_MAX_LENGTH``.
            split: Stored on the instance as ``self.split``.
        """
        if max_length is None:
            max_length = _MAX_LENGTH
        self.max_length = int(max_length)
        self.tokenizer = tokenizer
        self.split = split

        try:
            self._tokenizer_name_or_path = getattr(tokenizer, "name_or_path", None)
        except Exception:
            self._tokenizer_name_or_path = None

        try:
            self.is_fast = getattr(self.tokenizer, "is_fast", False)
        except Exception:
            self.is_fast = False

        self.pairs: List[Tuple[str, str]] = []
        invalid = 0

        for i, p in enumerate(pairs):
            try:
                if not isinstance(p, (list, tuple)) or len(p) != 2:
                    invalid += 1
                    cell2_dbg("init_badpair", f"Bad pair structure at idx={i}")
                    continue
                src, tgt = p
                if not isinstance(src, str) or not isinstance(tgt, str):
                    invalid += 1
                    cell2_dbg("init_badtype", f"Non-string src/tgt at idx={i}")
                    continue
                if not src or not tgt:
                    invalid += 1
                    cell2_dbg("init_empty", f"Empty src/tgt at idx={i}")
                    continue
                # Character-count guard; the token-level limit is applied at encode time.
                if len(src) > self.max_length * 20 or len(tgt) > self.max_length * 20:
                    invalid += 1
                    cell2_dbg("init_long", f"Extremely long text at idx={i}")
                    continue
                self.pairs.append((src, tgt))
            except Exception as e:
                invalid += 1
                cell2_dbg("init_exc", f"Init pair exception idx={i}: {type(e).__name__}")

        print(f"[CELL2] Dataset initialized: {len(self.pairs)} valid pairs, {invalid} invalid")

        try:
            if "get_tokenizer_special_tokens" in globals():
                self.special_tokens = get_tokenizer_special_tokens(self.tokenizer)
            else:
                self.special_tokens = set(getattr(self.tokenizer, "all_special_tokens", [])) if self.tokenizer is not None else set()
        except Exception:
            self.special_tokens = {
                f"__{_SOURCE_LANG}__",
                f"__{_TARGET_LANG}__",
                "</s>",
                "<pad>",
                "<s>",
                "<unk>",
            }

    @staticmethod
    def _resolve_pad_id(tokenizer: Any, default: int = 1) -> int:
        """Return ``tokenizer.pad_token_id`` as an int, else ``default``.

        ``default`` is used when the tokenizer is ``None``, has no
        ``pad_token_id`` attribute, reports ``None`` for it, or reports a
        value that cannot be converted to ``int``. This keeps the error
        handlers that call it from raising a second exception.
        """
        pad_id = getattr(tokenizer, "pad_token_id", None)
        if pad_id is not None:
            try:
                return int(pad_id)
            except (TypeError, ValueError):
                pass  # fall through to the default below
        return int(default)

    def _ensure_tokenizer(self) -> Any:
        """Return a usable tokenizer, resolving the ``tokenizer`` global if needed.

        Also refreshes ``self.is_fast`` whenever the tokenizer is re-resolved,
        so source and target encoding see the same state.

        Raises:
            RuntimeError: If no tokenizer is set and none is found in globals.
        """
        if self.tokenizer is None:
            self.tokenizer = globals().get("tokenizer", None)
            self.is_fast = getattr(self.tokenizer, "is_fast", False) if self.tokenizer is not None else False
        if self.tokenizer is None:
            raise RuntimeError("Tokenizer not available")
        return self.tokenizer

    def __getstate__(self):
        """Pickle everything except the tokenizer object itself."""
        state = self.__dict__.copy()
        state["tokenizer"] = None
        state["_tokenizer_name_or_path"] = getattr(self, "_tokenizer_name_or_path", None)
        return state

    def __setstate__(self, state):
        """Restore state; the tokenizer is left unset and resolved lazily."""
        self.__dict__.update(state)
        self.tokenizer = None
        self.is_fast = False

    def __len__(self) -> int:
        return len(self.pairs)

    def __iter__(self):
        """Yield exactly ``len(self)`` samples.

        ``__getitem__`` answers out-of-range indices with a placeholder sample
        instead of raising ``IndexError``, so Python's implicit sequence
        iteration (``for x in ds`` / ``list(ds)``) would never stop without an
        explicit bound.
        """
        for idx in range(len(self)):
            yield self[idx]

    def _encode_src(self, src_text: str):
        """Tokenize a source sentence and build its token-to-word map.

        Returns:
            ``(input_ids, attention_mask, tokens, token_word_map)``. On any
            failure: ``input_ids`` filled with the pad id, an all-zero mask,
            an empty token list and an empty map.
        """
        src_text = src_text if isinstance(src_text, str) else str(src_text)
        try:
            tok = self._ensure_tokenizer()

            if _has_safe_offsets_tokenize:
                enc = safe_offsets_tokenize(tok, src_text, max_length=self.max_length)
                try:
                    if isinstance(enc["input_ids"], torch.Tensor):
                        input_ids = enc["input_ids"].squeeze(0)
                    else:
                        input_ids = torch.tensor(enc["input_ids"][0])
                except Exception:
                    input_ids = torch.tensor(enc.get("input_ids", [[1]])[0])
                attention_mask = enc.get("attention_mask", torch.ones_like(input_ids))
                if isinstance(attention_mask, list):
                    attention_mask = torch.tensor(attention_mask[0]) if attention_mask else torch.ones_like(input_ids)
                # Return 1-D tensors from this branch too: previously a
                # (1, L) tensor mask was passed through unsqueezed while
                # input_ids had its batch axis removed.
                if isinstance(input_ids, torch.Tensor):
                    input_ids = input_ids.reshape(-1)
                if isinstance(attention_mask, torch.Tensor):
                    attention_mask = attention_mask.reshape(-1)
                try:
                    ids_list = input_ids.tolist() if isinstance(input_ids, torch.Tensor) else list(input_ids)
                    tokens = tok.convert_ids_to_tokens(ids_list)
                except Exception:
                    tokens = []
            else:
                # NOTE(review): add_special_tokens=False means the tokenizer
                # adds no language/EOS markers here — confirm this is intended.
                enc = tok(
                    src_text,
                    max_length=self.max_length,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                    add_special_tokens=False,
                )
                input_ids = enc["input_ids"].squeeze(0)
                attention_mask = enc.get("attention_mask", torch.ones_like(input_ids)).squeeze(0)
                try:
                    tokens = tok.convert_ids_to_tokens(input_ids.tolist())
                except Exception:
                    tokens = []

            # Word map: try the span reconstructor, then the subword mapper,
            # then a marker-based heuristic; the first non-empty result wins.
            token_word_map: Dict[int, str] = {}
            if _has_reconstruct_word_spans:
                try:
                    wm, words = reconstruct_word_spans(tok, src_text, max_length=self.max_length)
                    if isinstance(wm, dict) and wm:
                        token_word_map = wm
                except Exception as e:
                    cell2_dbg("wm_exc", f"reconstruct_word_spans failed: {e}")

            if not token_word_map and tokens and _has_map_subwords_to_words:
                try:
                    token_word_map = map_subwords_to_words(tokens, tok)
                except Exception as e:
                    cell2_dbg("map_subwords_exc", f"map_subwords_to_words failed: {e}")

            if not token_word_map and tokens:
                try:
                    current_word: List[str] = []
                    for idx, piece in enumerate(tokens):
                        if isinstance(piece, str) and piece not in self.special_tokens:
                            clean = (piece.replace("▁", "").replace("Ġ", "").replace("##", "").strip())
                            if clean:
                                # A leading "▁"/"Ġ" marks the start of a new word.
                                if piece.startswith("▁") or piece.startswith("Ġ"):
                                    current_word = [clean]
                                else:
                                    current_word.append(clean)
                                token_word_map[idx] = "".join(current_word)
                except Exception as e:
                    cell2_dbg("fallback_wm", f"Fallback word map failed: {e}")

            return input_ids, attention_mask, tokens, token_word_map

        except Exception as e:
            cell2_dbg("encode_src_exc", f"Encoding source failed: {type(e).__name__}")
            # _resolve_pad_id never raises, so a tokenizer without a pad token
            # cannot turn this handler into a second exception.
            pad_id = self._resolve_pad_id(self.tokenizer)
            input_ids = torch.full((self.max_length,), pad_id, dtype=torch.long)
            attention_mask = torch.zeros(self.max_length, dtype=torch.long)
            return input_ids, attention_mask, [], {}

    def _encode_tgt(self, tgt_text: str):
        """Tokenize a target sentence into labels with padding set to -100.

        Returns:
            A 1-D long tensor of length ``max_length``; all ``-100`` when
            encoding fails.
        """
        tgt_text = tgt_text if isinstance(tgt_text, str) else str(tgt_text)
        try:
            tok = self._ensure_tokenizer()

            dec = tok(
                tgt_text,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
                add_special_tokens=False,
            )
            labels = dec["input_ids"].squeeze(0)
            # -100 is the value the rest of this cell uses for ignored label positions.
            labels[labels == self._resolve_pad_id(tok)] = -100
            return labels
        except Exception as e:
            cell2_dbg("encode_tgt_exc", f"Encoding tgt failed: {type(e).__name__}")
            return torch.full((self.max_length,), -100, dtype=torch.long)

    def _make_safe_sample(self, reason: str = "fallback") -> Dict[str, Any]:
        """Build a minimal, well-formed sample for use when a real one fails.

        ``reason`` is accepted for call-site readability and is not used.
        """
        try:
            src = "আমি"
            tgt = "i"
            input_ids, attention_mask, tokens, token_word_map = self._encode_src(src)
            labels = self._encode_tgt(tgt)
            domain_label = random.choice([_TRAIN_DOMAIN, _TEST_DOMAIN])
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
                "token_word_map": token_word_map,
                "src_text": src,
                "tokens": tokens,
                "domain_label": domain_label,
            }
        except Exception:
            # Last resort: no tokenizer call at all.
            pad_id = self._resolve_pad_id(getattr(self, "tokenizer", None))
            domain_label = random.choice([_TRAIN_DOMAIN, _TEST_DOMAIN])
            return {
                "input_ids": torch.full((self.max_length,), pad_id, dtype=torch.long),
                "attention_mask": torch.zeros(self.max_length, dtype=torch.long),
                "labels": torch.full((self.max_length,), -100, dtype=torch.long),
                "token_word_map": {},
                "src_text": "",
                "tokens": [],
                "domain_label": domain_label,
            }

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the encoded sample at ``idx``.

        Out-of-range (including negative) indices and any internal error
        return a placeholder from ``_make_safe_sample`` rather than raising.
        """
        try:
            if idx < 0 or idx >= len(self.pairs):
                cell2_dbg("getitem_oob", f"Index out of range idx={idx}")
                return self._make_safe_sample("oob")

            src, tgt = self.pairs[idx]
            if not isinstance(src, str) or not isinstance(tgt, str):
                cell2_dbg("getitem_bad_types", f"Bad types at idx={idx}")
                return self._make_safe_sample("bad_types")

            if DEBUG_CELL2 and idx < 3:
                has_bengali = is_bengali_text(src)
                has_english = any("a" <= c.lower() <= "z" for c in src)
                print(f"[CELL2-GETITEM-{idx}] src sample: {src[:50]}")
                print(f"[CELL2-GETITEM-{idx}] Bengali: {has_bengali}, English: {has_english}")
                if not has_bengali:
                    print(f"[CELL2] WARNING: src_text is NOT Bengali at idx={idx}!")

            input_ids, attention_mask, tokens, token_word_map = self._encode_src(src)
            labels = self._encode_tgt(tgt)
            # NOTE(review): the domain label is drawn at random and self.split
            # is not consulted — confirm this is the intended labelling.
            domain_label = random.choice([_TRAIN_DOMAIN, _TEST_DOMAIN])

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
                "token_word_map": token_word_map,
                "src_text": src,
                "tokens": tokens,
                "domain_label": domain_label,
            }
        except Exception as e:
            cell2_dbg("getitem_exc", f"Unhandled __getitem__ exception idx={idx}: {type(e).__name__}")
            return self._make_safe_sample("unhandled")

def _infer_pad_id_from_sample(sample: Dict[str, Any], default_pad_id: int = 1) -> int:
    try:
        tk = globals().get("tokenizer", None)
        if tk is not None:
            pad = getattr(tk, "pad_token_id", None)
            if pad is not None:
                return int(pad)
    except Exception:
        cell2_dbg("infer_pad_exc", "infer pad id failed")
    return int(default_pad_id)

def _pad_or_truncate_array(tensor: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    if tensor is None:
        return torch.full((length,), int(pad_value), dtype=torch.long)
    t = tensor.view(-1).long()
    L = t.size(0)
    if L == length:
        return t
    if L < length:
        pad = torch.full((length - L,), int(pad_value), dtype=t.dtype)
        return torch.cat([t, pad], dim=0)
    return t[:length]

def safe_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate dataset samples into fixed-width batch tensors, skipping bad items.

    Only dict items whose ``"input_ids"`` is a tensor are considered. Each
    kept item's ids, mask and labels are flattened and padded or truncated to
    ``_MAX_LENGTH``; items that raise during processing are logged and
    dropped. If nothing survives, a single all-padding placeholder row is
    returned.

    Returns:
        Dict with stacked ``input_ids``, ``attention_mask`` and ``labels``
        tensors, a ``domain_labels`` long tensor, and the per-sample lists
        ``token_word_map``, ``src_text`` and ``tokens``.
    """

    def _placeholder_batch() -> Dict[str, Any]:
        # One all-padding row, used whenever no sample is usable.
        fill = _infer_pad_id_from_sample({}, default_pad_id=1)
        return {
            "input_ids": torch.full((1, _MAX_LENGTH), fill, dtype=torch.long),
            "attention_mask": torch.zeros(1, _MAX_LENGTH, dtype=torch.long),
            "labels": torch.full((1, _MAX_LENGTH), -100, dtype=torch.long),
            "token_word_map": [{}],
            "src_text": [""],
            "tokens": [[]],
            "domain_labels": torch.tensor([_TRAIN_DOMAIN], dtype=torch.long),
        }

    usable = [
        item
        for item in batch
        if isinstance(item, dict) and "input_ids" in item and isinstance(item["input_ids"], torch.Tensor)
    ]
    if not usable:
        return _placeholder_batch()

    pad_id = _infer_pad_id_from_sample(usable[0], default_pad_id=1)
    id_rows: List[torch.Tensor] = []
    mask_rows: List[torch.Tensor] = []
    label_rows: List[torch.Tensor] = []
    word_maps: List[Any] = []
    sources: List[Any] = []
    token_lists: List[Any] = []
    domain_ids: List[Any] = []

    for pos, item in enumerate(usable):
        try:
            ids = item["input_ids"]
            mask = item.get("attention_mask", None)
            target = item["labels"]
            # dict.get evaluates its default eagerly, so one random draw
            # happens per item even when "domain_label" is present.
            domain = item.get("domain_label", random.choice([_TRAIN_DOMAIN, _TEST_DOMAIN]))

            if mask is not None:
                try:
                    mask = mask.view(-1).long()
                except Exception:
                    mask = None
            if mask is None:
                # No usable mask supplied: treat every non-pad id as attended.
                mask = (ids != pad_id).long()

            try:
                ids = ids.view(-1)
            except Exception:
                ids = ids.flatten()

            try:
                target = target.view(-1)
            except Exception:
                target = target.flatten()

            ids = _pad_or_truncate_array(ids, _MAX_LENGTH, pad_id)
            mask = _pad_or_truncate_array(mask, _MAX_LENGTH, 0)
            target = _pad_or_truncate_array(target, _MAX_LENGTH, -100)

            id_rows.append(ids)
            mask_rows.append(mask)
            label_rows.append(target)
            word_maps.append(item.get("token_word_map", {}))
            sources.append(item.get("src_text", ""))
            token_lists.append(item.get("tokens", []))
            domain_ids.append(domain)
        except Exception as exc:
            cell2_dbg("collate_item_exc", f"Collate item exception idx={pos}: {type(exc).__name__}")
            continue

    if not id_rows:
        return _placeholder_batch()

    input_ids = torch.stack(id_rows, dim=0)
    attention_mask = torch.stack(mask_rows, dim=0)
    labels = torch.stack(label_rows, dim=0)
    try:
        domain_labels = torch.tensor(domain_ids, dtype=torch.long)
    except Exception:
        # Collected domain values could not be tensorized: draw fresh ones.
        domain_labels = torch.tensor(
            [random.choice([_TRAIN_DOMAIN, _TEST_DOMAIN]) for _ in id_rows],
            dtype=torch.long,
        )

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "token_word_map": word_maps,
        "src_text": sources,
        "tokens": token_lists,
        "domain_labels": domain_labels,
    }

def create_optimized_dataloader(
    dataset: Dataset,
    batch_size: Optional[int] = None,
    shuffle: bool = True,
    split: str = "train",
) -> DataLoader:
    """Build a ``DataLoader`` over ``dataset`` using ``safe_collate``.

    Args:
        dataset: Dataset to wrap.
        batch_size: Batch size; defaults to ``BATCH_SIZE`` (or 8 when that
            global is undefined). With multi-GPU enabled it is rounded down
            to a multiple of ``_NUM_GPUS`` unless that would give zero.
        shuffle: Passed through to the ``DataLoader``.
        split: Accepted for API symmetry; not used by this function.

    Returns:
        The constructed ``DataLoader``. If construction fails it is retried
        once with ``num_workers=0`` and without the worker-only options.
    """
    if batch_size is None:
        try:
            batch_size = int(BATCH_SIZE)
        except NameError:
            batch_size = 8
    batch_size = int(batch_size)
    requested = batch_size

    multi_gpu = _USE_MULTI_GPU and _NUM_GPUS > 0
    if multi_gpu and batch_size % _NUM_GPUS != 0:
        rounded = (batch_size // _NUM_GPUS) * _NUM_GPUS
        if rounded != 0:
            batch_size = rounded
            if batch_size != requested:
                print(f"[CELL2] Adjusted batch size {requested} to {batch_size} (DP-divisible, GPUs={_NUM_GPUS})")
        elif DEBUG_CELL2:
            print(f"[CELL2] WARNING: batch_size {batch_size} < num_gpus {_NUM_GPUS}. Keeping original.")

    # Worker count: configured value if it is a non-negative int, capped at
    # (CPU count - 1).
    workers = 0
    if isinstance(_NUM_WORKERS, int) and _NUM_WORKERS >= 0:
        workers = _NUM_WORKERS
    try:
        workers = min(workers, max(0, (os.cpu_count() or 1) - 1))
    except Exception:
        pass

    loader_kwargs: Dict[str, Any] = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=workers,
        pin_memory=bool(_PIN_MEMORY and torch.cuda.is_available()),
        collate_fn=safe_collate,
        drop_last=False,
    )
    if workers > 0:
        # These options are only set when worker processes are in use.
        loader_kwargs.update(
            worker_init_fn=_dataloader_worker_init_fn,
            prefetch_factor=_PREFETCH_FACTOR,
            persistent_workers=False,
        )

    try:
        dataloader = DataLoader(**loader_kwargs)
    except Exception as exc:
        print(f"[CELL2] DataLoader init failed with num_workers={workers}: {type(exc).__name__}")
        print("[CELL2] Retrying with num_workers=0")
        loader_kwargs["num_workers"] = 0
        for worker_only_key in ("prefetch_factor", "persistent_workers", "worker_init_fn"):
            loader_kwargs.pop(worker_only_key, None)
        dataloader = DataLoader(**loader_kwargs)

    if multi_gpu:
        per_gpu = batch_size // _NUM_GPUS
        print(f"[CELL2] DataLoader created: total_batch={batch_size}, per_gpu={per_gpu}, workers={loader_kwargs.get('num_workers', 0)}")
    else:
        print(f"[CELL2] DataLoader created: batch_size={batch_size}, workers={loader_kwargs.get('num_workers', 0)}")

    return dataloader

# Status line printed after the Cell 2 definitions above have been executed.
print("Cell 2: Memory-efficient data loading ready")


In [None]:
# ==============================================================================
# CELL 3: DSCD MODULE (PURE UNSUPERVISED DISCOVERY - FIXED DYNAMIC MULTI-SENSE)
# ==============================================================================

import threading
import time
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gc
from collections import deque
import unicodedata
from typing import Optional, Dict, List, Any, Set, Tuple

# NOTE(review): not referenced in the visible part of this cell — presumably
# a logging interval; confirm against its consumers.
PRINT_INTERVAL = 200

# Optional SciPy hierarchical-clustering helpers. _HAS_CLUSTERING records
# whether linkage / fcluster / pdist were imported successfully.
try:
    from scipy.cluster.hierarchy import linkage, fcluster
    from scipy.spatial.distance import pdist
    _HAS_CLUSTERING = True
except Exception:
    _HAS_CLUSTERING = False
    print("[CELL3] WARNING: scipy not available")

# Optional scikit-learn KMeans. _HAS_KMEANS records whether it was imported
# successfully.
try:
    from sklearn.cluster import KMeans
    _HAS_KMEANS = True
except Exception:
    _HAS_KMEANS = False
    print("[CELL3] WARNING: sklearn not available")

def _dscd_config_value(namespace: Dict[str, Any], name: str, cast: Any, default: Any, defaulted: List[str]) -> Any:
    """Return ``cast(namespace[name])``, or ``default`` if that is not possible.

    When ``name`` is missing from ``namespace`` or its value cannot be
    converted by ``cast``, ``name`` is appended to ``defaulted`` and
    ``default`` is returned.
    """
    try:
        return cast(namespace[name])
    except (KeyError, ValueError, TypeError):
        defaulted.append(name)
        return default


# Resolve each DSCD setting independently. A single shared try/except would
# reset ALL six settings to their defaults as soon as one of them is missing
# or invalid, silently discarding values that were configured correctly.
_dscd_defaulted: List[str] = []
DSCD_MAX_PROTOS = _dscd_config_value(globals(), "DSCD_MAX_PROTOS", int, 8, _dscd_defaulted)
DSCD_BUFFER_SIZE = _dscd_config_value(globals(), "DSCD_BUFFER_SIZE", int, 50, _dscd_defaulted)
DSCD_N_MIN = _dscd_config_value(globals(), "DSCD_N_MIN", int, 5, _dscd_defaulted)
DSCD_DISPERSION_THRESHOLD = _dscd_config_value(globals(), "DSCD_DISPERSION_THRESHOLD", float, 0.50, _dscd_defaulted)
VERBOSE_LOGGING = _dscd_config_value(globals(), "VERBOSE_LOGGING", bool, True, _dscd_defaulted)
DSCD_ENABLE_TRAINING_CLUSTERING = _dscd_config_value(globals(), "DSCD_ENABLE_TRAINING_CLUSTERING", bool, True, _dscd_defaulted)
if _dscd_defaulted:
    # Name the affected settings: the others kept their configured values.
    print(f"[CELL3] WARNING: Using default DSCD config for: {', '.join(_dscd_defaulted)}")

# Discovery debug switch: coerce an existing value to bool, default to off.
DEBUG_DISCOVERY = bool(globals()["DEBUG_DISCOVERY"]) if "DEBUG_DISCOVERY" in globals() else False

# Integer limit, overridable through a pre-existing global; falls back to 150
# when the existing value cannot be converted.
_MAX_TOKENS_PER_DISCOVERY = globals().get("_MAX_TOKENS_PER_DISCOVERY", 150)
try:
    _MAX_TOKENS_PER_DISCOVERY = int(_MAX_TOKENS_PER_DISCOVERY)
except Exception:
    _MAX_TOKENS_PER_DISCOVERY = 150

# Float setting, overridable the same way; falls back to 1.5.
DSCD_NEW_SENSE_LAMBDA = globals().get("DSCD_NEW_SENSE_LAMBDA", 1.5)
try:
    DSCD_NEW_SENSE_LAMBDA = float(DSCD_NEW_SENSE_LAMBDA)
except Exception:
    DSCD_NEW_SENSE_LAMBDA = 1.5

# Reference homograph list: reuse an already-defined iterable as a set,
# otherwise install the built-in default words.
try:
    HOMOGRAPH_REFERENCE_LIST_BN = set(HOMOGRAPH_REFERENCE_LIST_BN)
except (NameError, TypeError):
    HOMOGRAPH_REFERENCE_LIST_BN = {
        "কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা"
    }
    print("[CELL3] Using default reference list")
else:
    print(f"[CELL3] Loaded reference list for evaluation: {len(HOMOGRAPH_REFERENCE_LIST_BN)} words")

DSCD_MAX_CLUSTERING_POINTS = 500
# Punctuation characters collected as a set of single characters.
_PUNCT_SET = set(".,!?;:()[]{}\"'-—–/\\")

def normalize_token_key(token: str) -> str:
    """Return a lookup key for ``token`` with subword markers removed.

    The value is converted with ``str()``, every occurrence of ``"▁"``,
    ``" "``, ``"Ġ"`` and ``"##"`` is deleted (in that order), surrounding
    whitespace is stripped and the result is lower-cased.
    """
    key = str(token)
    for marker in ("▁", " ", "Ġ", "##"):
        key = key.replace(marker, "")
    return key.strip().lower()

def is_word_token(token: str, min_letters: int = 2, min_letter_fraction: float = 0.6) -> bool:
    """Return True when ``token`` looks like a word rather than noise.

    A token qualifies when it contains at least ``min_letters`` letters
    (Unicode category ``L*``) and letters plus combining marks (category
    ``M*``) make up at least ``min_letter_fraction`` of its non-whitespace
    characters.

    Combining marks count towards the fraction because in Bengali and other
    Indic scripts the dependent vowel signs and virama are category ``Mc`` /
    ``Mn``: with letters alone, "পাতা" is 2 letters out of 4 characters (0.5)
    and would be rejected at the default threshold.

    Args:
        token: Candidate token; non-string or empty values return False.
        min_letters: Minimum number of category-``L*`` characters.
        min_letter_fraction: Minimum (letters + marks) / non-space ratio.
    """
    if not token or not isinstance(token, str):
        return False
    token = token.strip()
    if not token:
        return False

    letters = 0
    marks = 0
    total = 0
    for ch in token:
        if ch.isspace():
            continue
        total += 1
        major_class = unicodedata.category(ch)[0]
        if major_class == "L":
            letters += 1
        elif major_class == "M":
            marks += 1

    # Marks alone never make a word: the base-letter minimum still applies.
    if total == 0 or letters < min_letters:
        return False
    return (letters + marks) / total >= min_letter_fraction

class MemoryEfficientPrototypeStore:
    """Bounded collection of CPU centroid vectors with per-centroid counts.

    Alongside the centroids it keeps a creation timestamp per centroid and an
    exponentially smoothed mean (``mu``) and deviation (``tau``) of the
    assignment distances reported through ``update_prototype``.
    """

    def __init__(self, embeddim, maxprotos: Optional[int] = None):
        """Create an empty store holding at most ``maxprotos`` centroids.

        ``maxprotos`` defaults to the module-level ``DSCD_MAX_PROTOS``.
        """
        self.embeddim = embeddim
        self.maxprotos = int(DSCD_MAX_PROTOS if maxprotos is None else maxprotos)
        self.centroids: List[torch.Tensor] = []
        self.counts: List[int] = []
        self.creation_time: List[float] = []
        self.distances: List[float] = []
        self.mu = 0.0
        self.tau = 1e-6
        self.alpha = 0.1  # smoothing factor for mu / tau
        self.labels: Optional[torch.Tensor] = None

    def add_prototype(self, vector: torch.Tensor, current_time: Optional[float] = None, count: int = 1) -> None:
        """Store a detached CPU copy of ``vector`` as a centroid.

        While there is room the centroid is appended; once ``maxprotos`` is
        reached it overwrites the centroid with the smallest count.
        """
        if current_time is None:
            current_time = time.time()
        snapshot = vector.detach().cpu().clone()

        if len(self.centroids) >= self.maxprotos:
            # Full: evict the least-used slot (first one on ties).
            slot = int(np.argmin(self.counts)) if len(self.counts) > 0 else 0
            self.centroids[slot] = snapshot
            self.counts[slot] = int(count)
            self.creation_time[slot] = float(current_time)
            return

        self.centroids.append(snapshot)
        self.counts.append(int(count))
        self.creation_time.append(float(current_time))

    def update_prototype(self, idx: int, vector: torch.Tensor, eta: float = 0.05, assignment_distance: Optional[float] = None) -> None:
        """Move centroid ``idx`` towards ``vector`` by step ``eta``.

        An out-of-range ``idx`` adds ``vector`` as a new prototype instead.
        When ``assignment_distance`` is given it is folded into the rolling
        distance statistics.
        """
        if not (0 <= idx < len(self.centroids)):
            self.add_prototype(vector, time.time(), count=1)
            return

        previous = self.centroids[idx]
        incoming = vector.detach().cpu()
        self.centroids[idx] = (1.0 - eta) * previous + eta * incoming
        self.counts[idx] = int(self.counts[idx]) + 1
        if assignment_distance is not None:
            self.update_rolling_stats(float(assignment_distance))

    def update_rolling_stats(self, d: float) -> None:
        """Fold distance ``d`` into ``mu`` / ``tau`` and the recent-distance list.

        The first observation initializes ``mu`` to ``d`` and ``tau`` to
        ``1e-6``; at most the 50 latest distances are retained.
        """
        value = float(d)
        if not self.distances:
            self.mu = value
            self.tau = 1e-6
            self.distances = [value]
            return

        # Deviation is measured against the mean BEFORE this update.
        deviation = abs(value - self.mu)
        self.mu = (1 - self.alpha) * self.mu + self.alpha * value
        self.tau = (1 - self.alpha) * self.tau + self.alpha * deviation
        self.distances.append(value)
        if len(self.distances) > 50:
            del self.distances[0]

    def get_adaptive_threshold(self, lam: float = 1.0) -> float:
        """Return ``mu + lam * tau`` as a float."""
        return float(self.mu + lam * self.tau)

    def size(self) -> int:
        """Return the number of stored centroids."""
        return len(self.centroids)

    def ensure_consistency(self) -> None:
        """Resize ``counts`` and ``creation_time`` to match ``centroids``.

        Extra entries are cut off; missing counts become ``1`` and missing
        timestamps the current time.
        """
        target = len(self.centroids)

        have = len(self.counts)
        if have > target:
            self.counts = self.counts[:target]
        elif have < target:
            self.counts = self.counts + [1] * (target - have)

        have = len(self.creation_time)
        if have > target:
            self.creation_time = self.creation_time[:target]
        elif have < target:
            self.creation_time = self.creation_time + [time.time()] * (target - have)

class MemoryEfficientDSCDOnline(nn.Module):
    def __init__(
        self,
        embeddim: int,
        tokenizer=None,
        buffersize: Optional[int] = None,
        maxprotos: Optional[int] = None,
        nmin: Optional[int] = None,
        dispersion_threshold: Optional[float] = None,
        language: str = "bn",
        enable_training_clustering: Optional[bool] = None,
        max_clustering_points: Optional[int] = None,
        max_candidates_per_step: int = 2,
        dscd_min_letters: int = 2,
        dscd_min_letter_fraction: float = 0.6,
    ):
        super().__init__()
        if buffersize is None:
            buffersize = DSCD_BUFFER_SIZE
        if maxprotos is None:
            maxprotos = DSCD_MAX_PROTOS
        if nmin is None:
            nmin = DSCD_N_MIN
        if dispersion_threshold is None:
            dispersion_threshold = DSCD_DISPERSION_THRESHOLD
        if max_clustering_points is None:
            max_clustering_points = DSCD_MAX_CLUSTERING_POINTS
        if enable_training_clustering is None:
            enable_training_clustering = DSCD_ENABLE_TRAINING_CLUSTERING

        self.embeddim = int(embeddim)
        self.buffersize = int(buffersize)
        self.maxprotos = int(maxprotos)
        self.nmin = int(nmin)
        self.dispersion_threshold = float(dispersion_threshold)
        self.language = language
        self.tokenizer = tokenizer
        self.dscd_min_letters = int(dscd_min_letters)
        self.dscd_min_letter_fraction = float(dscd_min_letter_fraction)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        try:
            if tokenizer is not None and "get_tokenizer_special_tokens" in globals():
                self.special_tokens = get_tokenizer_special_tokens(tokenizer)
            else:
                self.special_tokens = set(getattr(tokenizer, "all_special_tokens", [])) if tokenizer is not None else set()
        except Exception:
            self.special_tokens = set()

        self.dscd_allowed_tokens: Set[str] = set()
        self.dscd_ignored_tokens: Set[str] = set()
        self.dscd_cache_max_size = 10000

        self.prototype_stores: Dict[str, MemoryEfficientPrototypeStore] = {}
        self.prototype_stores = self.prototype_stores
        self.buffers: Dict[str, deque] = {}
        self.discovered_log: List[Dict[str, Any]] = []
        self.discovered_homographs: Set[str] = set()
        self.last_periodic_check = 0
        self.cleanup_counter = 0

        self.dispersion_cache: Dict[str, float] = {}
        self.dispersion_last_updated: Dict[str, float] = {}
        self.dispersion_lock = threading.Lock()
        self.clustering_lock = threading.RLock()
        self.buffer_lock = threading.Lock()

        from collections import deque as thread_deque
        self.active_threads = thread_deque(maxlen=100)
        self.thread_lock = threading.Lock()
        self.last_cluster_time: Dict[str, float] = {}
        self.cluster_cooldown_seconds = 5.0
        self.enable_training_clustering = bool(enable_training_clustering)
        self.discovery_count = 0
        self.discovery_times: List[float] = []
        self.clustered_tokens: Set[str] = set()
        self.cluster_stats: Dict[str, Dict[str, Any]] = {}

        self.span_head = nn.Sequential(
            nn.Linear(self.embeddim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 1),
        )

        self.sigmanet = nn.Sequential(
            nn.Linear(self.embeddim, 16),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(16, 1),
        )

        self.gate_w = nn.Parameter(torch.tensor(1.0))
        self.gate_b = nn.Parameter(torch.tensor(0.4))
        self.gamma = nn.Parameter(torch.tensor(0.3))

        self.temperature = 0.07

        self.max_clustering_points = int(max_clustering_points)
        self.max_candidates_per_step = int(max_candidates_per_step)

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)
        plain_stores = {}
        for token, store in self.prototype_stores.items():
            plain_stores[token] = {
                "centroids": [c.cpu() for c in store.centroids] if hasattr(store, 'centroids') else [],
                "counts": list(store.counts) if hasattr(store, 'counts') else [],
                "creation_time": list(store.creation_time) if hasattr(store, 'creation_time') else [],
                "mu": float(store.mu) if hasattr(store, 'mu') else 0.0,
                "tau": float(store.tau) if hasattr(store, 'tau') else 0.0,
                "size": int(store.size()) if hasattr(store, 'size') else 0,
            }
        state[prefix + "prototype_stores"] = plain_stores
        state[prefix + "discovered_homographs"] = list(self.discovered_homographs)
        return state

    def load_state_dict(self, state_dict, strict=True):
        prefix = ""
        plain_stores = state_dict.pop('prototype_stores', {})
        discovered = state_dict.pop('discovered_homographs', [])
        super().load_state_dict(state_dict, strict=strict)

        if not plain_stores:
            print("[DSCD] WARNING: Empty prototype_stores in checkpoint")
            return

        self.prototype_stores = {}
        self.discovered_homographs = set(discovered)

        for token, store_dict in plain_stores.items():
            store = MemoryEfficientPrototypeStore(embeddim=self.embeddim, maxprotos=self.maxprotos)
            centroids_data = store_dict.get("centroids", [])
            store.centroids = []
            for c in centroids_data:
                if isinstance(c, torch.Tensor):
                    store.centroids.append(c)
                else:
                    store.centroids.append(torch.tensor(c))
            store.counts = store_dict.get("counts", [])
            store.creation_time = store_dict.get("creation_time", [])
            store.mu = store_dict.get("mu", 0.0)
            store.tau = store_dict.get("tau", 0.0)
            store.ensure_consistency()
            self.prototype_stores[token] = store

        print(f"[DSCD] Loaded {len(self.prototype_stores)} tokens, {sum(s.size() for s in self.prototype_stores.values())} prototypes")

    @staticmethod
    def clean_token(token):
        token = str(token)
        token = token.replace("▁", "").replace("##", "")
        for punct in [",", ".", ",", "!", "?", ":", ";", "-"]:
            token = token.replace(punct, "")
        return token.strip()

    def is_valid_multisense(self, token):
        """Tell whether ``token`` has a well-supported multi-prototype store.

        True only when the token has a store with at least two prototypes, at
        least 10 occurrences in total, and at least 2 occurrences on its
        weakest prototype.
        """
        if token not in self.prototype_stores:
            return False
        store = self.prototype_stores[token]

        has_counts = hasattr(store, 'counts')
        total = sum(store.counts) if has_counts else 0
        weakest = min(store.counts) if has_counts and store.counts else 0

        if not store.size() >= 2:
            return False
        if not total >= 10:
            return False
        return weakest >= 2

    def is_multisense_store(self, store: MemoryEfficientPrototypeStore) -> bool:
        k = store.size()
        if k < 2:
            return False
        counts = store.counts if store.counts else [1] * k
        strong = sum(1 for c in counts if c >= max(2, self.nmin // 2))
        if strong < 2:
            return False
        try:
            cents = []
            for c in store.centroids:
                if isinstance(c, torch.Tensor):
                    cents.append(c.cpu().numpy())
                else:
                    cents.append(np.asarray(c, dtype=np.float32))
            if len(cents) < 2:
                return False
            cents = np.stack(cents, axis=0)
            dists = np.linalg.norm(cents[:, None, :] - cents[None, :, :], axis=-1)
            tri = dists[np.triu_indices(len(cents), k=1)]
            if tri.size == 0:
                return False
            min_dist = float(tri.min())
            base = max(store.tau, 1e-3)
            return min_dist > base * DSCD_NEW_SENSE_LAMBDA
        except Exception:
            return True

    def periodic_discovery_check(self, current_step: int, frequency: int) -> None:
        """Run homograph discovery when ``current_step`` is a multiple of ``frequency``.

        Does nothing (beyond a log line) when training-time clustering is
        disabled, when ``frequency`` is not a positive number, when the step is
        not due, or when no buffer holds at least ``nmin`` samples. Otherwise
        ``clustering_lock`` is acquired (5 s timeout) and ``discover_homographs``
        is called; any exception it raises is printed with a traceback and
        swallowed. The lock is always released.

        Args:
            current_step: Current training step.
            frequency: Discovery interval in steps; values <= 0 (or ``None``)
                disable the check instead of raising ``ZeroDivisionError``.
        """
        if not self.enable_training_clustering:
            print(f"[DSCD-DISCOVERY] Clustering disabled (enable_training_clustering={self.enable_training_clustering})")
            return

        # A non-positive interval can never be "due"; it must not reach the
        # modulo below (0 would raise ZeroDivisionError inside the train loop).
        if not frequency or frequency <= 0:
            return
        if current_step % frequency != 0:
            return

        print(f"\n{'='*80}")
        print(f"[DSCD-DISCOVERY] Starting @ step {current_step}")
        print(f"{'='*80}")

        # Cheap pre-check under buffer_lock: is there anything worth clustering?
        try:
            with self.buffer_lock:
                total_buffers = len(self.buffers)
                buffer_sizes = {k: len(v) for k, v in self.buffers.items()}
                large_buffers = {k: v for k, v in buffer_sizes.items() if v >= self.nmin}

            print(f"[DSCD-DISCOVERY] Total tokens with buffers: {total_buffers}")
            print(f"[DSCD-DISCOVERY] Tokens with enough samples (>={self.nmin}): {len(large_buffers)}")

            if large_buffers:
                top_5 = sorted(large_buffers.items(), key=lambda x: x[1], reverse=True)[:5]
                print(f"[DSCD-DISCOVERY] Top 5 buffer sizes:")
                for tok, size in top_5:
                    print(f"  - '{tok}': {size} samples")
        except Exception as e:
            print(f"[DSCD-DISCOVERY] Error checking buffers: {e}")
            return

        if not large_buffers:
            print(f"[DSCD-DISCOVERY] No tokens with >={self.nmin} samples - skipping")
            print(f"{'='*80}\n")
            return

        print(f"[DSCD-DISCOVERY] Attempting to acquire clustering_lock (timeout=5s)...")
        lock_start = time.time()
        acquired = self.clustering_lock.acquire(timeout=5.0)
        lock_time = time.time() - lock_start

        if not acquired:
            print(f"[DSCD-DISCOVERY] ❌ LOCK TIMEOUT after {lock_time:.2f}s @ step {current_step}")
            print(f"[DSCD-DISCOVERY] Lock is held by another process - skipping discovery")
            print(f"{'='*80}\n")
            return

        # Everything after a successful acquire sits inside try/finally so the
        # lock is released on every path.
        try:
            print(f"[DSCD-DISCOVERY] ✅ Lock acquired in {lock_time:.2f}s")
            discovery_start = time.time()

            print(f"[DSCD-DISCOVERY] Calling discover_homographs()...")
            discovered_count = self.discover_homographs()

            discovery_time = time.time() - discovery_start

            print(f"\n[DSCD-DISCOVERY] ✅ COMPLETED in {discovery_time:.2f}s")
            print(f"[DSCD-DISCOVERY] Discovered {discovered_count} new homographs")
            print(f"[DSCD-DISCOVERY] Total discovered so far: {len(self.discovered_homographs)}")

        except Exception as e:
            print(f"[DSCD-DISCOVERY] ❌ FAILED with error: {type(e).__name__}: {e}")
            import traceback
            traceback.print_exc()
        finally:
            self.clustering_lock.release()
            print(f"[DSCD-DISCOVERY] Lock released")
            print(f"{'='*80}\n")

    def discover_homographs_for_tokens(
        self,
        token_names: List[str],
        min_cluster_samples: int,
        dispersion_threshold: float,
        global_step: int,
    ) -> int:
        """Cluster the given tokens and record those that end up with >= 2 prototypes.

        Each token is passed to ``cluster_buffer_to_prototypes_hierarchical``;
        failures for individual tokens are skipped silently. A summary entry is
        appended to ``discovered_log``.

        Args:
            token_names: Buffer keys to cluster.
            min_cluster_samples: Accepted but not used by this method.
            dispersion_threshold: Accepted but not used by this method.
            global_step: Stored in the log entry.

        Returns:
            Number of tokens recorded as homographs during this call.
        """
        found: List[str] = []
        for token in token_names:
            try:
                if not self.cluster_buffer_to_prototypes_hierarchical(token):
                    continue
                store = self.prototype_stores.get(token)
                if not (store and store.size() >= 2):
                    continue
                key = normalize_token_key(token)
                self.discovered_homographs.add(key)
                found.append(key)
            except Exception:
                continue

        summary = None
        try:
            summary = {
                "timestamp": time.time(),
                "global_step": global_step,
                "candidates_processed": len(token_names),
                "discovered_count": len(found),
                "homographs": found,
                "total_discovered": len(self.discovered_homographs),
            }
            self.discovered_log.append(summary)
        except Exception:
            pass

        return len(found)

    def discover_homographs(
        self,
        min_cluster_samples: Optional[int] = None,
        dispersion_threshold: Optional[float] = None,
        max_candidates: int = 500,
    ) -> int:
        """Scan the token buffers for high-dispersion tokens and cluster the best ones.

        A token becomes a candidate when its buffer holds at least
        ``min_cluster_samples`` embeddings and the standard deviation of their
        distances to the buffer mean reaches ``dispersion_threshold``.
        Candidates are ranked by ``dispersion * buffer_size``; the top
        ``max_candidates`` are passed to
        ``cluster_buffer_to_prototypes_hierarchical``. A token whose store then
        has two or more prototypes is added to ``discovered_homographs``.

        Args:
            min_cluster_samples: Minimum buffer size; defaults to ``self.nmin``.
            dispersion_threshold: Minimum dispersion; defaults to
                ``self.dispersion_threshold``.
            max_candidates: Upper bound on the number of tokens clustered.

        Returns:
            Number of tokens recorded during this call; 0 when ``buffer_lock``
            could not be acquired within 15 seconds or nothing qualified.
        """
        min_cluster_samples = self.nmin if min_cluster_samples is None else min_cluster_samples
        dispersion_threshold = (
            self.dispersion_threshold if dispersion_threshold is None else dispersion_threshold
        )

        print(f"[DISCOVER] Scanning buffers (min_samples={min_cluster_samples}, min_dispersion={dispersion_threshold:.3f})...")

        # ---- 1. snapshot the buffers so the scan runs without the lock ----
        if not self.buffer_lock.acquire(timeout=15.0):
            print("[DISCOVER] ❌ Lock timeout - buffer_lock held by another thread")
            return 0
        try:
            total_tokens = len(self.buffers)
            print(f"[DISCOVER] Checking {total_tokens} tokens...")

            buffers_snapshot = {}
            for k, v in self.buffers.items():
                buffers_snapshot[k] = (len(v), list(v))
        finally:
            self.buffer_lock.release()

        # ---- 2. score every sufficiently large buffer ----
        candidates: List[Tuple[str, float, int, float]] = []
        skipped_small = 0
        start_time = time.time()

        for checked, (token, (buffer_size, buffer_data)) in enumerate(buffers_snapshot.items(), 1):
            if checked % 200 == 0:
                elapsed = time.time() - start_time
                rate = checked / elapsed if elapsed > 0 else 0
                eta = (total_tokens - checked) / rate if rate > 0 else 0
                print(f"[DISCOVER] Progress: {checked}/{total_tokens} ({100*checked//total_tokens}%) "
                      f"| {rate:.0f} tok/s | ETA: {eta:.0f}s | candidates: {len(candidates)}")

            if buffer_size < min_cluster_samples:
                skipped_small += 1
                continue

            try:
                if len(buffer_data) < 2:
                    continue

                rows = []
                for e in buffer_data:
                    if isinstance(e, torch.Tensor):
                        rows.append(e.cpu().numpy())
                    else:
                        rows.append(np.asarray(e, dtype=np.float32))
                embeddings_np = np.stack(rows, axis=0)

                # Dispersion = std of the distances to the buffer mean.
                centroid = embeddings_np.mean(axis=0)
                spread = np.linalg.norm(embeddings_np - centroid[None, :], axis=1)
                dispersion = float(spread.std())

                if dispersion >= dispersion_threshold:
                    candidates.append((token, dispersion * buffer_size, buffer_size, dispersion))
            except Exception:
                continue

        print(f"[DISCOVER] ✅ Scan complete: {len(candidates)} candidates found (skipped {skipped_small} small buffers)")

        if not candidates:
            print(f"[DISCOVER] No candidates meet criteria - returning")
            return 0

        # ---- 3. keep the highest-ranked candidates ----
        candidates = sorted(candidates, key=lambda item: item[1], reverse=True)[:max_candidates]

        print(f"[DISCOVER] Top {len(candidates)} candidates selected for clustering")
        print(f"[DISCOVER] Top 3:")
        for i, (tok, score, bufsize, disp) in enumerate(candidates[:3], 1):
            print(f"  {i}. '{tok}': score={score:.2f}, buffer={bufsize}, dispersion={disp:.3f}")

        # ---- 4. cluster each candidate ----
        try:
            from tqdm import tqdm
            candidate_iter = tqdm(candidates, desc="[DISCOVER] Clustering", ncols=100)
            use_tqdm = True
        except ImportError:
            candidate_iter = candidates
            use_tqdm = False
            print(f"[DISCOVER] Processing {len(candidates)} candidates...")

        discovered: List[str] = []
        for idx, (token, score, bufsize, disp) in enumerate(candidate_iter):
            if not use_tqdm and idx % 10 == 0:
                print(f"[DISCOVER] Progress: {idx}/{len(candidates)} ({100*idx//len(candidates)}%)")

            try:
                if not self.cluster_buffer_to_prototypes_hierarchical(token):
                    continue
                store = self.prototype_stores.get(token)
                if not (store and store.size() >= 2):
                    continue

                clean_token = normalize_token_key(token)
                self.discovered_homographs.add(clean_token)
                discovered.append(clean_token)

                if len(discovered) <= 5:
                    print(f"\n[DISCOVER] ✓ '{token}' → {store.size()} prototypes (counts={store.counts})")
            except Exception as e:
                if DEBUG_DISCOVERY:
                    print(f"\n[DISCOVER] ✗ '{token}' failed: {type(e).__name__}")
                continue

        # ---- 5. report ----
        print(f"\n[DISCOVER] ✅ Clustering complete!")
        print(f"[DISCOVER] New homographs discovered: {len(discovered)}")

        if len(discovered) > 0:
            print(f"[DISCOVER] Sample homographs: {discovered[:10]}")

        try:
            self.discovered_log.append({
                "timestamp": time.time(),
                "candidates": len(candidates),
                "discovered": len(discovered),
                "homographs": discovered[:20],
            })
        except Exception:
            pass

        return len(discovered)

    def get_dispersion(self, tokentype: str) -> float:
        """Return the dispersion of ``tokentype``'s buffer, cached for one hour.

        Dispersion is the standard deviation of the distances between the
        buffered embeddings and their mean. The buffer is copied under
        ``buffer_lock`` and the computation runs after the lock is released, so
        ``buffer_lock`` and ``dispersion_lock`` are never held at the same time.

        Args:
            tokentype: Buffer key.

        Returns:
            The dispersion, or ``0.0`` when the buffer is missing, holds fewer
            than two usable embeddings, or the computation fails.
        """
        # 1. Fresh cache hit?
        with self.dispersion_lock:
            if tokentype in self.dispersion_cache:
                try:
                    last_update = self.dispersion_last_updated.get(tokentype, 0.0)
                    if (time.time() - last_update) < 3600:
                        return self.dispersion_cache[tokentype]
                except Exception:
                    pass

        # 2. Snapshot the buffer; hold buffer_lock only for the copy.
        with self.buffer_lock:
            buffer = self.buffers.get(tokentype)
            if buffer is None or len(buffer) < 2:
                return 0.0
            samples = list(buffer)

        # 3. Compute outside any lock.
        try:
            embeddings: List[np.ndarray] = []
            for emb in samples:
                try:
                    if isinstance(emb, torch.Tensor):
                        embeddings.append(emb.cpu().numpy())
                    else:
                        embeddings.append(np.asarray(emb, dtype=np.float32))
                except Exception:
                    # Skip entries that cannot be converted.
                    continue

            if len(embeddings) < 2:
                return 0.0

            embeddings_np = np.stack(embeddings, axis=0)
            centroid = embeddings_np.mean(axis=0)
            distances = np.linalg.norm(embeddings_np - centroid[None, :], axis=1)
            dispersion = float(distances.std())
        except Exception:
            return 0.0

        # 4. Publish to the cache.
        with self.dispersion_lock:
            self.dispersion_cache[tokentype] = dispersion
            self.dispersion_last_updated[tokentype] = time.time()

        return dispersion

    def validate_prototypes(
        self,
        homograph_list: Optional[List[str]] = None,
        cluster_missing: bool = False,
    ) -> Dict[str, Any]:
        """Print and return a quality report for the current prototype stores.

        Each reference homograph is looked up by exact key, then by its cleaned
        form (subword markers and spaces removed), then by a substring match
        against the cleaned store keys. A homograph counts as found only when
        the matched store passes ``is_multisense_store``.

        Args:
            homograph_list: Reference homographs. Defaults to
                ``HOMOGRAPH_REFERENCE_LIST_BN`` when available, otherwise to a
                small built-in list.
            cluster_missing: When true, tokens that were missing or had a
                single prototype are clustered afterwards. The returned counts
                are computed before that clustering and are not refreshed.

        Returns:
            Dict with ``total_tokens``, ``total_prototypes``,
            ``multisense_tokens``, ``homographs_found``, ``homographs_missing``,
            ``avg_prototypes_per_token``, ``avg_samples_per_prototype`` and
            ``quality_score`` (``0.6 * coverage + 0.4 * multisense ratio``).
        """

        def _strip_markers(text) -> str:
            # Remove subword markers and spaces so keys compare by surface form.
            return str(text).replace("▁", "").replace("##", "").replace("Ġ", "").replace(" ", "").strip()

        if homograph_list is None:
            try:
                homograph_list = list(HOMOGRAPH_REFERENCE_LIST_BN)
            except (NameError, TypeError):
                homograph_list = ["কল", "কাল", "পাতা", "ব্যাংক", "ফল"]

        print("=" * 80)
        print("[DSCD-VALIDATION] Prototype Quality Check")
        print("=" * 80)

        validation_results: Dict[str, Any] = {
            "total_tokens": len(self.prototype_stores),
            "total_prototypes": 0,
            "multisense_tokens": 0,
            "homographs_found": 0,
            "homographs_missing": [],
            "avg_prototypes_per_token": 0.0,
            "avg_samples_per_prototype": 0.0,
            "quality_score": 0.0,
        }

        # ---- global statistics over all stores ----
        total_samples = 0
        for token, store in self.prototype_stores.items():
            validation_results["total_prototypes"] += len(store.centroids)

            if self.is_multisense_store(store):
                validation_results["multisense_tokens"] += 1

            try:
                total_samples += sum(store.counts)
            except Exception:
                pass

        if validation_results["total_tokens"] > 0:
            validation_results["avg_prototypes_per_token"] = validation_results["total_prototypes"] / validation_results["total_tokens"]
        if validation_results["total_prototypes"] > 0:
            validation_results["avg_samples_per_prototype"] = total_samples / validation_results["total_prototypes"]

        print("[VALIDATION] Reference Homograph Coverage:")
        print("-" * 80)

        # ---- per-homograph coverage ----
        missing_tokens_to_cluster: List[str] = []
        for homograph in homograph_list:
            clean_h = _strip_markers(homograph)
            found_key = None

            if homograph in self.prototype_stores:
                found_key = homograph
            elif clean_h in self.prototype_stores:
                found_key = clean_h
            elif clean_h:
                for key in self.prototype_stores:
                    clean_key = _strip_markers(key)
                    # An empty string is a substring of every string: a key made
                    # only of markers must not match every homograph.
                    if not clean_key:
                        continue
                    if clean_key == clean_h or clean_h in clean_key or clean_key in clean_h:
                        found_key = key
                        break

            found = found_key is not None
            found_protos = len(self.prototype_stores[found_key].centroids) if found else 0

            if found and self.is_multisense_store(self.prototype_stores[found_key]):
                validation_results["homographs_found"] += 1
                try:
                    counts = self.prototype_stores[found_key].counts
                    print(f"  ✓ '{homograph}' → {found_protos} prototypes (counts={counts})")
                except Exception:
                    print(f"  ✓ '{homograph}' → {found_protos} prototypes")
            elif found and found_protos == 1:
                validation_results["homographs_missing"].append(homograph)
                print(f"  ⚠ '{homograph}' → Only 1 prototype")
                if cluster_missing:
                    missing_tokens_to_cluster.append(found_key)
            else:
                validation_results["homographs_missing"].append(homograph)
                print(f"  ✗ '{homograph}' → NOT FOUND")
                if cluster_missing:
                    if homograph in self.buffers or clean_h in self.buffers:
                        key_to_cluster = homograph if homograph in self.buffers else clean_h
                        if len(self.buffers[key_to_cluster]) >= max(5, self.nmin // 2):
                            print(f"      → Found in buffer, will cluster")
                            missing_tokens_to_cluster.append(key_to_cluster)

        # ---- optional clustering of the gaps ----
        if cluster_missing and missing_tokens_to_cluster:
            print(f"[VALIDATION] Clustering {len(missing_tokens_to_cluster)} missing tokens...")
            for token in missing_tokens_to_cluster:
                try:
                    self.cluster_buffer_to_prototypes_hierarchical(token)
                    if token in self.prototype_stores and self.is_multisense_store(self.prototype_stores[token]):
                        print(f"  ✓ Successfully clustered '{token}'")
                except Exception as e:
                    print(f"  ✗ Failed to cluster '{token}': {e}")

        homograph_coverage = validation_results["homographs_found"] / len(homograph_list) if homograph_list else 0.0
        multisense_ratio = validation_results["multisense_tokens"] / validation_results["total_tokens"] if validation_results["total_tokens"] > 0 else 0.0
        validation_results["quality_score"] = (
            homograph_coverage * 0.6 + multisense_ratio * 0.4
        )

        print("-" * 80)
        print("[VALIDATION] Summary:")
        print(f"  - Total tokens: {validation_results['total_tokens']}")
        print(f"  - Total prototypes: {validation_results['total_prototypes']}")
        print(f"  - Multi-sense tokens: {validation_results['multisense_tokens']}")
        print(f"  - Reference found: {validation_results['homographs_found']}/{len(homograph_list)}")
        # Fixed-point with two decimals; a bare ".2" means two significant
        # digits and can switch to scientific notation for small scores.
        print(f"  - Quality Score: {validation_results['quality_score']:.2f}")
        print("=" * 80)

        return validation_results

    def should_track_token(self, tokentext: str) -> bool:
        """Decide whether ``tokentext`` should be tracked, caching the decision.

        Decisions are cached in ``dscd_allowed_tokens`` / ``dscd_ignored_tokens``;
        either cache is emptied once it grows past ``dscd_cache_max_size``.
        Outside training mode, any token that already has a prototype store
        (by raw or cleaned key) is accepted immediately. Otherwise special
        tokens, tokens whose cleaned form is shorter than two characters, and
        tokens without letters are rejected; tokens containing a character in
        U+0980..U+09FF are accepted; everything else is decided by
        ``is_word_token``.
        """
        if not tokentext or not isinstance(tokentext, str):
            return False

        def _accept() -> bool:
            self.dscd_allowed_tokens.add(tokentext)
            return True

        def _reject() -> bool:
            self.dscd_ignored_tokens.add(tokentext)
            return False

        # Crude cache bound: drop everything once a cache gets too large.
        for cache in (self.dscd_allowed_tokens, self.dscd_ignored_tokens):
            if len(cache) > self.dscd_cache_max_size:
                cache.clear()

        if tokentext in self.dscd_allowed_tokens:
            return True
        if tokentext in self.dscd_ignored_tokens:
            return False

        clean = tokentext.replace("▁", "").replace("##", "").replace("Ġ", "").replace(" ", "").strip()

        # Outside training, a token with an existing store is always tracked.
        if not getattr(self, "training", False):
            if tokentext in self.prototype_stores:
                return _accept()
            if clean and clean in self.prototype_stores:
                return _accept()

        if tokentext in self.special_tokens:
            return _reject()
        if not clean:
            return _reject()
        if len(clean) < 2:
            return _reject()
        if not any(ch.isalpha() for ch in clean):
            return _reject()
        if clean.isdigit():
            return _reject()
        if all(ch in _PUNCT_SET for ch in clean):
            return _reject()

        try:
            if any("\u0980" <= ch <= "\u09FF" for ch in clean):
                return _accept() if len(clean) >= 2 else _reject()
        except Exception:
            pass

        looks_like_word = is_word_token(
            clean,
            min_letters=self.dscd_min_letters,
            min_letter_fraction=self.dscd_min_letter_fraction,
        )
        return _accept() if looks_like_word else _reject()

    def canonical_token_key(
        self,
        raw_token: str,
        token_word_map: Optional[Dict[int, Optional[str]]],
        idx: int,
    ) -> Optional[str]:
        """Return the key under which the token at ``idx`` is tracked.

        The stripped entry of ``token_word_map`` for ``idx`` is used when one
        exists and is non-empty; otherwise ``normalize_token_key(raw_token)``.
        Keys shorter than two characters yield ``None``.
        """
        mapped_word: Optional[str] = None
        try:
            if token_word_map and isinstance(token_word_map, dict):
                entry = token_word_map.get(idx)
                if entry:
                    mapped_word = str(entry).strip()
        except Exception:
            mapped_word = None

        key = mapped_word or normalize_token_key(raw_token)
        if key and len(key) >= 2:
            return key
        return None

    def cleanup_threads(self) -> None:
        """Drop finished threads from ``active_threads`` (best effort, never raises)."""
        try:
            with self.thread_lock:
                survivors = []
                for worker in list(self.active_threads):
                    if worker.is_alive():
                        survivors.append(worker)
                self.active_threads.clear()
                self.active_threads.extend(survivors)
        except Exception:
            pass

    def cleanup_memory(self) -> None:
        """Trim oversized buffers, expire old dispersion cache entries, run the GC.

        A buffer is trimmed from the left down to ``buffersize`` only once it
        exceeds 1.5x that size. Dispersion cache entries older than one hour are
        removed. Any failure is swallowed.
        """
        try:
            for buf in list(self.buffers.values()):
                if len(buf) <= int(self.buffersize * 1.5):
                    continue
                while len(buf) > self.buffersize:
                    buf.popleft()

            try:
                now = time.time()
                stale = []
                for key, stamp in self.dispersion_last_updated.items():
                    if (now - stamp) > 3600:
                        stale.append(key)
                for key in stale:
                    self.dispersion_cache.pop(key, None)
                    self.dispersion_last_updated.pop(key, None)
            except Exception:
                pass

            if gc.isenabled():
                gc.collect()
        except Exception:
            pass

    def forward(
        self,
        token_embeddings,
        token_types=None,
        train_mode: bool = True,
        token_word_map=None,
        h_all=None,
        input_ids=None,
        attention_mask=None,
    ):
        """Run ``process_sequence`` on every sequence of the batch and collate the results.

        Args:
            token_embeddings: Batch of embeddings; ``h_all`` is used when this
                is ``None``. Indexed as ``[batch][position]``.
            token_types: Optional per-sequence token strings. When missing and
                ``input_ids`` is given, they are derived with the tokenizer
                (or ``tok{i}`` placeholders if that is unavailable or fails).
            train_mode: Forwarded to ``process_sequence``.
            token_word_map: Optional per-sequence word maps, forwarded per row.
            h_all: Fallback for ``token_embeddings``.
            input_ids: Optional ids used only to derive ``token_types``.
            attention_mask: Accepted but not used here.

        Returns:
            Dict with keys ``proto_assignments``, ``proto_probs``,
            ``uncertainties``, ``span_preds``, ``gates`` and ``h_augmented``.
            ``h_augmented`` is stacked into one tensor, padded or truncated to
            the input sequence length (falling back to ``token_embeddings`` if
            stacking fails); ``proto_assignments`` becomes a list of tensors
            when conversion succeeds. The other entries are per-sequence lists.

        Raises:
            ValueError: If neither ``token_embeddings`` nor ``h_all`` is given.
        """
        if token_embeddings is None:
            token_embeddings = h_all
        if token_embeddings is None:
            raise ValueError("MemoryEfficientDSCDOnline.forward requires token_embeddings or h_all")

        # Derive token strings from ids when the caller did not supply them.
        if input_ids is not None and token_types is None:
            n_rows, n_cols = input_ids.shape
            token_types = []
            for row in range(n_rows):
                names = [f"tok{i}" for i in range(n_cols)]
                if self.tokenizer is not None:
                    try:
                        names = self.tokenizer.convert_ids_to_tokens(input_ids[row].tolist())
                    except Exception:
                        pass  # keep the placeholders
                token_types.append(names)

        # Housekeeping every 50th call.
        self.cleanup_counter += 1
        if self.cleanup_counter % 50 == 0:
            self.cleanup_counter = 0
            self.cleanup_memory()
            self.cleanup_threads()

        device = token_embeddings.device
        batch_size = int(token_embeddings.size(0))
        seq_len = int(token_embeddings.size(1))

        all_outputs: Dict[str, List[Any]] = {
            name: []
            for name in (
                "proto_assignments",
                "proto_probs",
                "uncertainties",
                "span_preds",
                "gates",
                "h_augmented",
            )
        }

        for b in range(batch_size):
            word_map = token_word_map[b] if token_word_map and len(token_word_map) > b else None
            if token_types and len(token_types) > b:
                types_b = token_types[b]
            else:
                types_b = [f"tok{i}" for i in range(seq_len)]

            batch_outputs = self.process_sequence(
                token_embeddings[b],
                types_b,
                device,
                word_map=word_map,
                train_mode=train_mode,
            )
            for k in all_outputs:
                all_outputs[k].append(batch_outputs[k])

        # Collate augmented states into one (batch, seq_len, ...) tensor.
        try:
            collated: List[torch.Tensor] = []
            for per_token in all_outputs["h_augmented"]:
                if len(per_token) > 0 and isinstance(per_token[0], torch.Tensor):
                    h_seq = torch.stack(per_token, dim=0)
                    missing = seq_len - h_seq.size(0)
                    if missing > 0:
                        h_seq = F.pad(h_seq, (0, 0, 0, missing), value=0)
                    elif missing < 0:
                        h_seq = h_seq[:seq_len]
                else:
                    h_seq = torch.zeros(seq_len, self.embeddim, device=device)
                collated.append(h_seq)
            all_outputs["h_augmented"] = torch.stack(collated, dim=0)
        except Exception:
            all_outputs["h_augmented"] = token_embeddings

        # Convert each sequence's assignments to a tensor when possible.
        try:
            assignment_tensors = []
            for row in all_outputs["proto_assignments"]:
                try:
                    converted = torch.stack(
                        [x if isinstance(x, torch.Tensor) else torch.tensor(x) for x in row],
                        dim=0,
                    )
                except Exception:
                    converted = torch.tensor(
                        [int(x) if not isinstance(x, torch.Tensor) else int(x.item()) for x in row],
                        dtype=torch.long,
                    )
                assignment_tensors.append(converted)
            all_outputs["proto_assignments"] = assignment_tensors
        except Exception:
            pass

        return all_outputs

    def process_sequence(
        self,
        token_embeddings: torch.Tensor,
        token_types: List[Any],
        device: torch.device,
        word_map: Optional[Dict[int, Optional[str]]] = None,
        train_mode: bool = True,
    ) -> Dict[str, List[Any]]:
        """Assign every position of one sequence to a per-token prototype.

        For each index ``j`` along dim 0 of ``token_embeddings`` the method
        derives a token key, records the embedding in that key's buffer,
        compares it against the key's current prototype centroids and emits
        one entry per output list.

        Args:
            token_embeddings: Tensor indexed along dim 0, one row per position
                (assumes shape (seq_len, embed_dim) -- TODO confirm with caller).
            token_types: Per-position token labels. Missing positions fall back
                to ``"tok{j}"``; non-string entries are converted with ``str``.
            device: Device used for the softmax over distances and for moving
                the chosen centroid before augmentation.
            word_map: Optional mapping forwarded unchanged to
                ``self.canonical_token_key``.
            train_mode: Only consulted at the end: the periodic cluster summary
                is printed when this is False.

        Returns:
            Dict with six parallel lists of length ``seq_len``:
            ``proto_assignments`` (0-dim tensors, -1 when unassigned),
            ``proto_probs`` (list of floats per position, possibly empty),
            ``uncertainties``, ``span_preds``, ``gates`` (floats) and
            ``h_augmented`` (one tensor per position).
        """
        seq_len = int(token_embeddings.size(0))

        outputs: Dict[str, List[Any]] = {
            "proto_assignments": [],
            "proto_probs": [],
            "uncertainties": [],
            "span_preds": [],
            "gates": [],
            "h_augmented": [],
        }

        for j in range(seq_len):
            # Resolve a printable token label for this position.
            raw_tok = token_types[j] if j < len(token_types) else f"tok{j}"
            if not isinstance(raw_tok, str):
                raw_tok = str(raw_tok) if raw_tok is not None else f"tok{j}"

            token_key = self.canonical_token_key(raw_tok, word_map, j)
            h_j = token_embeddings[j]

            # No usable key: emit neutral placeholders and pass the embedding through.
            if not token_key:
                outputs["proto_assignments"].append(torch.tensor(-1))
                outputs["proto_probs"].append([])
                outputs["uncertainties"].append(0.0)
                outputs["span_preds"].append(0.0)
                outputs["gates"].append(0.0)
                outputs["h_augmented"].append(h_j)
                continue

            # Key exists but is filtered out by should_track_token: same placeholders.
            if not self.should_track_token(token_key):
                outputs["proto_assignments"].append(torch.tensor(-1))
                outputs["proto_probs"].append([])
                outputs["uncertainties"].append(0.0)
                outputs["span_preds"].append(0.0)
                outputs["gates"].append(0.0)
                outputs["h_augmented"].append(h_j)
                continue

            # Buffer creation and append happen under buffer_lock.
            with self.buffer_lock:
                if token_key not in self.buffers:
                    self.buffers[token_key] = deque(maxlen=self.buffersize)
                    self.prototype_stores[token_key] = MemoryEfficientPrototypeStore(
                        self.embeddim, self.maxprotos
                    )

                # Store a detached CPU copy; the fallback keeps the append best-effort.
                try:
                    self.buffers[token_key].append(h_j.detach().clone().cpu())
                except Exception:
                    try:
                        self.buffers[token_key].append(h_j.cpu())
                    except Exception:
                        pass

            # NOTE(review): the store is looked up outside the lock and assumes that
            # every key in self.buffers also exists in self.prototype_stores -- confirm.
            store = self.prototype_stores[token_key]
            centroids_snapshot: Optional[List[torch.Tensor]] = None

            # Copy the current centroids to CPU tensors so the distance math below
            # works on a private snapshot; unconvertible centroids are skipped.
            try:
                if hasattr(store, 'centroids') and len(store.centroids) > 0:
                    centroids_snapshot = []
                    for c in store.centroids:
                        try:
                            if isinstance(c, torch.Tensor):
                                centroids_snapshot.append(c.clone().cpu())
                            else:
                                centroids_snapshot.append(
                                    torch.from_numpy(np.asarray(c, dtype=np.float32)).cpu()
                                )
                        except Exception:
                            continue
                    if not centroids_snapshot:
                        centroids_snapshot = None
            except Exception:
                centroids_snapshot = None

            # Defaults used when there are no centroids or the assignment step fails.
            assignment = -1
            prob_list: List[float] = []
            uncertainty = 0.0
            span_pred = 0.0
            gate_val = 0.0
            h_aug = h_j

            if centroids_snapshot and len(centroids_snapshot) >= 1:
                try:
                    try:
                        h_cpu = h_j.detach().cpu().numpy()
                    except Exception:
                        h_cpu = h_j.cpu().numpy()

                    try:
                        cents_np = np.stack([c.numpy() for c in centroids_snapshot], axis=0)
                    except Exception:
                        cents_np = np.stack(
                            [np.asarray(c, dtype=np.float32) for c in centroids_snapshot],
                            axis=0,
                        )

                    # Euclidean distance from this embedding to every centroid.
                    dists_np = np.linalg.norm(cents_np - h_cpu[None, :], axis=1)

                    if dists_np.size > 0:
                        min_dist = float(dists_np.min())
                        min_idx = int(np.argmin(dists_np))
                        max_dist = float(dists_np.max())
                        num_valid = len(dists_np)

                        # Open a new prototype when there is room and the nearest
                        # centroid is farther than the store's adaptive threshold;
                        # otherwise assign to the nearest centroid.
                        if store.size() < self.maxprotos and min_dist > store.get_adaptive_threshold(DSCD_NEW_SENSE_LAMBDA):
                            store.add_prototype(h_j, time.time(), count=1)
                            assignment = store.size() - 1
                            # NOTE(review): h_j is appended here without detach(), unlike
                            # the buffer copy above -- confirm this is intended.
                            centroids_snapshot.append(h_j.cpu())
                            cents_np = np.vstack([cents_np, h_cpu[None, :]])
                            dists_np = np.append(dists_np, 0.0)
                            num_valid += 1
                        else:
                            assignment = min_idx

                        # Span score from the relative spread of the centroid distances.
                        # min_dist / max_dist are the values computed before any new
                        # prototype was appended.
                        if num_valid >= 2:
                            span_range = max_dist - min_dist
                            ratio = span_range / (max_dist + 1e-8)
                            if ratio > 0.15:
                                span_pred = float((0.3 + (ratio - 0.15) * 1.5))
                            else:
                                span_pred = float(ratio * 2.0)
                            span_pred = min(1.0, max(0.0, span_pred))
                        else:
                            if max_dist > 0.3:
                                span_pred = min(1.0, max_dist * 0.8)
                            else:
                                span_pred = 0.1

                        try:
                            store.update_rolling_stats(min_dist)
                        except Exception:
                            pass

                        try:
                            # Temperature-scaled softmax over negative distances.
                            dist_tensor = torch.from_numpy(dists_np).to(device)
                            probs_tensor = F.softmax(-dist_tensor / self.temperature, dim=0)
                            prob_list = probs_tensor.tolist()

                            # Entropy normalised by log(num_valid).
                            entropy = -torch.sum(probs_tensor * torch.log(probs_tensor + 1e-10))
                            max_entropy = np.log(num_valid) if num_valid > 1 else 1.0
                            H_norm = float(entropy.item() / max_entropy) if max_entropy > 0 else 0.0

                            # sigmanet output squashed to (0, 1); 0.5 when it fails.
                            try:
                                sigma_pred = self.sigmanet(h_j.unsqueeze(0))
                                sigma_norm = float(torch.sigmoid(sigma_pred).item())
                            except Exception:
                                sigma_norm = 0.5

                            dmin_norm = min(1.0, min_dist / 0.5)

                            # Weighted blend of the three signals, clipped to [0, 1].
                            uncertainty = 0.5 * H_norm + 0.25 * sigma_norm + 0.25 * dmin_norm
                            uncertainty = min(1.0, max(0.0, uncertainty))

                        except Exception:
                            # NumPy fallback: softmax without the temperature, and the
                            # uncertainty is the normalised entropy alone.
                            exps = np.exp(-dists_np - np.max(-dists_np)) if dists_np.size > 0 else np.array([])
                            if exps.size > 0:
                                probs = exps / (exps.sum() + 1e-12)
                                prob_list = probs.tolist()
                                entropy_val = -np.sum(probs * np.log(probs + 1e-10))
                                max_entropy = np.log(num_valid) if num_valid > 1 else 1.0
                                uncertainty = float(entropy_val / max_entropy) if max_entropy > 0 else 0.0
                            else:
                                prob_list = []
                                uncertainty = 0.0

                        # Gate is a logistic of (uncertainty - 0.15) scaled by 10, and
                        # is only non-zero with at least two prototypes.
                        if num_valid >= 2 and uncertainty > 0.15:
                            logit = 10.0 * (uncertainty - 0.15)
                            gate_val = float(1.0 / (1.0 + math.exp(-logit)))
                        else:
                            gate_val = 0.0

                        # When the gate is open, move the embedding 10% of the way
                        # toward its assigned centroid.
                        if gate_val > 0.3 and 0 <= assignment < len(centroids_snapshot):
                            try:
                                centroid_t = centroids_snapshot[assignment]
                                if device != torch.device("cpu"):
                                    try:
                                        centroid_t = centroid_t.to(device)
                                    except Exception:
                                        pass
                                h_aug = h_j + 0.1 * (centroid_t - h_j)
                            except Exception:
                                h_aug = h_j

                except Exception as e:
                    # Values assigned before the failure are kept; the error is only
                    # reported when DEBUG_DISCOVERY is set.
                    if DEBUG_DISCOVERY:
                        print(f"[DSCD] Assignment error for {token_key}: {str(e)[:200]}")

            outputs["proto_assignments"].append(torch.tensor(assignment))
            outputs["proto_probs"].append(prob_list)
            outputs["uncertainties"].append(uncertainty)
            outputs["span_preds"].append(span_pred)
            outputs["gates"].append(gate_val)
            outputs["h_augmented"].append(h_aug)

        # Outside training, print the cluster summary whenever the call counter is a
        # multiple of PRINT_INTERVAL; any failure here is ignored.
        try:
            if not train_mode and len(self.prototype_stores) > 0 and VERBOSE_LOGGING:
                if self.last_periodic_check % PRINT_INTERVAL == 0:
                    self.print_clusters_summary()
                self.last_periodic_check += 1
        except Exception:
            pass

        return outputs

    def print_clusters_summary(self) -> None:
        """Print a table of the five largest token stores followed by totals.

        A store's size is the sum of its ``counts``; when that sum is not
        positive the length of the token's buffer is used instead. Output is
        produced only when ``VERBOSE_LOGGING`` is truthy, and any failure is
        reported (or silently dropped) rather than raised.
        """
        try:
            rows: List[Tuple[str, int, int, float, float, int]] = []
            for tok, proto_store in self.prototype_stores.items():
                try:
                    sampled = sum(getattr(proto_store, 'counts', []) or [])
                except Exception:
                    sampled = 0

                if tok in self.buffers:
                    buffered = len(self.buffers.get(tok, []))
                else:
                    buffered = 0

                # Prefer the prototype sample count; fall back to the buffer length.
                effective = sampled if sampled > 0 else buffered
                rows.append(
                    (
                        tok,
                        effective,
                        proto_store.size(),
                        getattr(proto_store, 'mu', 0.0),
                        getattr(proto_store, 'tau', 0.0),
                        buffered,
                    )
                )

            # Largest first (stable for ties), then keep the leading five.
            rows = sorted(rows, key=lambda row: row[1], reverse=True)
            leaders = rows[:5]

            if not VERBOSE_LOGGING:
                return

            rule = "-" * 100
            print("[CLUSTER] Top 5 clusters:")
            print(rule)
            print(f"{'Rank':<6} {'Token':<18} {'Count':<12} {'Protos':<8} {'BufLen':<8} {'mu':<15} {'tau':<15}")
            print(rule)
            for position, (tok, cnt, prot, mu, tau, buflen) in enumerate(leaders, start=1):
                tok_str = str(tok)[:18]
                print(f"{position:<6} {tok_str:<18} {cnt:<12} {prot:<8} {buflen:<8} {mu:<15.6f} {tau:<15.6f}")
            print(rule)

            total_samples = sum(row[1] for row in rows)
            total_protos = sum(row[2] for row in rows)
            total_buffers = sum(row[5] for row in rows)
            print(f"Total: {len(rows)} clusters | {total_samples} samples | {total_protos} protos | {total_buffers} buffers")
        except Exception as err:
            try:
                if VERBOSE_LOGGING:
                    print(f"[CLUSTER] Error printing summary: {str(err)[:200]}")
            except Exception:
                pass

    def cluster_buffer_to_prototypes_hierarchical(self, tokentype: str) -> bool:
        """Rebuild the prototype set of ``tokentype`` by clustering its buffer.

        The buffered embeddings (optionally subsampled) are stacked together
        with the store's existing centroids and clustered, first with
        average-linkage hierarchical clustering and, if that yields no
        prototype, with KMeans. ``pdist``, ``linkage``, ``fcluster`` and
        ``KMeans`` are expected to be bound elsewhere in the file (presumably
        SciPy / scikit-learn -- confirm) and are guarded by ``_HAS_CLUSTERING``
        and ``_HAS_KMEANS``.

        Args:
            tokentype: Key into ``self.buffers`` / ``self.prototype_stores``.

        Returns:
            True when the store holds at least one prototype afterwards.
            False when the token is not tracked, has no buffer, has fewer than
            ``self.nmin`` buffered items, yields fewer than two usable
            embeddings, contains only near-zero vectors, or any exception is
            raised (exceptions are never propagated).
        """
        try:
            if not self.should_track_token(tokentype):
                if DEBUG_DISCOVERY:
                    print(f"[DSCD-CLUSTER] Skipping non-word token: {tokentype}")
                return False

            # Copy the buffer while holding the lock; all later work uses the copy.
            with self.buffer_lock:
                if tokentype not in self.buffers:
                    return False

                buf_snapshot = [e.clone() if isinstance(e, torch.Tensor) else e for e in self.buffers[tokentype]]

            if len(buf_snapshot) < self.nmin:
                if DEBUG_DISCOVERY:
                    print(f"[DSCD-CLUSTER] {tokentype} buffer={len(buf_snapshot)} < nmin={self.nmin}")
                return False

            # Convert every buffered item to a NumPy array; unconvertible items are dropped.
            emb_list: List[np.ndarray] = []
            for e in buf_snapshot:
                try:
                    if isinstance(e, torch.Tensor):
                        try:
                            emb_list.append(e.numpy())
                        except Exception:
                            emb_list.append(e.cpu().numpy())
                    else:
                        emb_list.append(np.asarray(e, dtype=np.float32))
                except Exception:
                    continue

            if len(emb_list) == 0:
                return False

            # Cap the clustering input with a random sample drawn without replacement.
            if len(emb_list) > self.max_clustering_points:
                idxs = np.random.choice(len(emb_list), size=self.max_clustering_points, replace=False)
                new_embeddings = np.stack([emb_list[i] for i in idxs], axis=0)
            else:
                new_embeddings = np.stack(emb_list, axis=0)

            if new_embeddings.shape[0] < 2:
                return False

            norms = np.linalg.norm(new_embeddings, axis=1)
            if np.all(norms < 1e-6):
                if DEBUG_DISCOVERY:
                    print(f"[DSCD-CLUSTER] {tokentype} all zero vectors, skipping")
                return False

            if DEBUG_DISCOVERY:
                print(f"[DSCD-CLUSTER] {tokentype} buflen={len(buf_snapshot)} "
                      f"sampled={new_embeddings.shape[0]} meannorm={norms.mean():.4f}")

            store = self.prototype_stores[tokentype]

            # Existing centroids join the clustering input, one row each.
            # NOTE(review): a centroid therefore weighs as a single point and its
            # stored count is not carried over -- confirm this is intended.
            existing_centroids: List[np.ndarray] = []
            if hasattr(store, 'centroids') and len(store.centroids) > 0:
                for c in store.centroids:
                    try:
                        if isinstance(c, torch.Tensor):
                            try:
                                existing_centroids.append(c.cpu().numpy())
                            except Exception:
                                existing_centroids.append(c.numpy())
                        else:
                            existing_centroids.append(np.asarray(c, dtype=np.float32))
                    except Exception:
                        continue

            if len(existing_centroids) >= 1:
                existing_centroids_np = np.stack(existing_centroids, axis=0)
                combined_embeddings = np.vstack([existing_centroids_np, new_embeddings])
                if DEBUG_DISCOVERY:
                    print(f"[DSCD-CLUSTER] {tokentype} Incremental - "
                          f"{len(existing_centroids)} existing + {new_embeddings.shape[0]} new "
                          f"= {combined_embeddings.shape[0]} total embeddings")
                embeddings = combined_embeddings
            else:
                embeddings = new_embeddings

            protos_added = 0
            new_centroids: List[torch.Tensor] = []
            new_counts: List[int] = []
            new_times: List[float] = []

            if _HAS_CLUSTERING:
                try:
                    condensed = pdist(embeddings, metric='euclidean')
                    if condensed.size > 0:
                        Z = linkage(condensed, method='average')

                        # The cut height is a fraction (dispersion_threshold) of the
                        # largest pairwise distance.
                        max_dist = condensed.max() if condensed.size > 0 else 1.0
                        relative_threshold = self.dispersion_threshold
                        absolute_threshold = relative_threshold * max_dist

                        # Subtract 1 so cluster ids start at 0.
                        clusters = fcluster(Z, t=absolute_threshold, criterion='distance') - 1

                        if clusters.size > 0:
                            max_c = int(clusters.max())
                            for c_id in range(max_c + 1):
                                mask = (clusters == c_id)
                                cluster_size = int(mask.sum())
                                # Only clusters with at least nmin members become prototypes.
                                if cluster_size >= self.nmin:
                                    centroid = embeddings[mask].mean(axis=0).astype(np.float32)
                                    centroid_tensor = torch.from_numpy(centroid)
                                    new_centroids.append(centroid_tensor)
                                    new_counts.append(cluster_size)
                                    new_times.append(time.time())
                                    protos_added += 1

                            # Too many prototypes: keep the maxprotos largest by count
                            # (argsort leaves them in ascending count order).
                            if len(new_centroids) > self.maxprotos:
                                sorted_indices = np.argsort(new_counts)[-self.maxprotos:]
                                new_centroids = [new_centroids[i] for i in sorted_indices]
                                new_counts = [new_counts[i] for i in sorted_indices]
                                new_times = [new_times[i] for i in sorted_indices]
                                protos_added = len(new_centroids)

                            # NOTE(review): the store is overwritten even when no cluster
                            # reached nmin (new_centroids empty), which discards the
                            # previous prototypes -- confirm this is intended.
                            store.centroids = new_centroids
                            store.counts = new_counts
                            store.creation_time = new_times
                            store.labels = torch.tensor(clusters)

                            if DEBUG_DISCOVERY and protos_added > 0:
                                print(f"[DSCD-CLUSTER] Hierarchical created {protos_added} prototypes for {tokentype}")
                except Exception as e:
                    if DEBUG_DISCOVERY:
                        print(f"[DSCD-CLUSTER] Hierarchical failed for {tokentype}: {type(e).__name__} {str(e)[:200]}")

            # KMeans runs only when the hierarchical pass produced no prototype.
            if protos_added == 0 and _HAS_KMEANS:
                try:
                    min_k = 1
                    max_k = min(self.maxprotos, len(embeddings) // self.nmin)
                    if max_k < min_k:
                        max_k = min_k

                    # Size-based guess for k: sqrt(n)/2 (at least 2) above 20 points,
                    # 2 above 10 points, otherwise 1 -- always capped by max_k.
                    if len(embeddings) > 20:
                        k_guess = min(max_k, max(2, int(np.sqrt(len(embeddings)) / 2)))
                    elif len(embeddings) > 10:
                        k_guess = min(max_k, 2)
                    else:
                        k_guess = 1

                    k_guess = max(min_k, min(k_guess, len(embeddings)))

                    if k_guess >= 1 and len(embeddings) >= k_guess:
                        km = KMeans(n_clusters=k_guess, random_state=0, n_init=10).fit(embeddings)
                        labels = km.labels_

                        new_centroids = []
                        new_counts = []
                        new_times = []

                        for c in range(k_guess):
                            mask = (labels == c)
                            cluster_size = int(mask.sum())
                            if cluster_size >= self.nmin:
                                centroid = embeddings[mask].mean(axis=0).astype(np.float32)
                                centroid_tensor = torch.from_numpy(centroid)
                                new_centroids.append(centroid_tensor)
                                new_counts.append(cluster_size)
                                new_times.append(time.time())
                                protos_added += 1

                        # As above, the store is replaced regardless of how many
                        # clusters qualified.
                        store.centroids = new_centroids
                        store.counts = new_counts
                        store.creation_time = new_times
                        store.labels = torch.tensor(labels)

                        if DEBUG_DISCOVERY and protos_added > 0:
                            print(f"[DSCD-CLUSTER] KMeans created {protos_added} prototypes for {tokentype}")
                except Exception as e:
                    if DEBUG_DISCOVERY:
                        print(f"[DSCD-CLUSTER] KMeans failed for {tokentype}: {type(e).__name__} {str(e)[:200]}")

            if DEBUG_DISCOVERY:
                print(f"[DSCD-CLUSTER] {tokentype} final={store.size()} protos, counts={store.counts}")

            # Record per-token statistics; failures here do not affect the result.
            try:
                if store.centroids:
                    counts = store.counts if store.counts else [1] * len(store.centroids)
                    total_count = sum(counts)
                    mean_count = float(total_count / max(1, len(counts)))

                    self.cluster_stats[str(tokentype)] = {
                        "num_prototypes": len(store.centroids),
                        "counts": [int(c) for c in counts],
                        "total_samples": int(total_count),
                        "mean_count": float(mean_count),
                        "mu": float(store.mu),
                        "tau": float(store.tau),
                    }
            except Exception:
                pass

            return store.size() > 0

        except Exception as e:
            if DEBUG_DISCOVERY:
                print(f"[DSCD-ERROR] Clustering error for {tokentype}: {type(e).__name__} {str(e)[:200]}")
            return False

    def get_prototype_summary(self) -> Dict[str, Any]:
        """Return aggregate counts over all prototype stores.

        Returns:
            Dict with ``total_tokens`` (number of stores), ``total_prototypes``
            (sum of every store's ``size()``) and ``num_homographs`` (stores
            whose ``size()`` is at least 2).
        """
        stores = self.prototype_stores
        summary: Dict[str, Any] = {"total_tokens": len(stores)}
        summary["total_prototypes"] = sum(store.size() for store in stores.values())
        summary["num_homographs"] = len(
            [store for store in stores.values() if store.size() >= 2]
        )
        return summary

    def get_explanations(self, threshold_span: float = 0.3) -> List[Dict[str, Any]]:
        """List every token whose store holds at least two prototypes.

        Args:
            threshold_span: Accepted for interface compatibility; the current
                implementation does not read it.

        Returns:
            One ``{"token": str, "protos": int}`` dict per qualifying store, in
            the iteration order of ``self.prototype_stores``.
        """
        return [
            {"token": str(token_key), "protos": proto_store.size()}
            for token_key, proto_store in self.prototype_stores.items()
            if proto_store.size() >= 2
        ]

# Cell-completion banner, emitted as a single write; the text is unchanged.
print(
    "\n".join(
        [
            "=" * 80,
            "Cell 3: DSCD Ready (DEADLOCK FIXED + OPTIMIZED)",
            "=" * 80,
            "CRITICAL FIXES APPLIED:",
            "=" * 80,
            " ✅ FIX-DEADLOCK-1: discover_homographs() uses lock timeout (15s)",
            " ✅ FIX-DEADLOCK-2: Snapshot pattern - release lock before computation",
            " ✅ FIX-PROGRESS-1: Progress updates every 200 tokens (was 500)",
            " ✅ FIX-PERFORMANCE-1: Inline dispersion calculation (no nested locks)",
            " ✅ FIX-PERFORMANCE-2: Added ETA calculation with rate tracking",
            "=" * 80 + "\n",
        ]
    )
)


In [None]:
# ==============================================================================
# CELL 4: ASBN MODULE - FIXED
# ==============================================================================

import traceback
from typing import Any, List, Tuple, Optional, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Cell-local copies of notebook-wide settings. Each one starts at its fallback
# value and is overridden only when the corresponding global exists and
# converts cleanly, so this cell also runs without the configuration cell.
_MAX_LENGTH = 48
try:
    _MAX_LENGTH = int(MAX_LENGTH)
except Exception:
    pass

_ENABLE_ASBN_TRAINING = True
try:
    _ENABLE_ASBN_TRAINING = bool(ENABLE_ASBN_TRAINING)
except Exception:
    pass

_VERBOSE_LOGGING = False
try:
    _VERBOSE_LOGGING = bool(VERBOSE_LOGGING)
except Exception:
    pass

_DEBUG_DISCOVERY = False
try:
    _DEBUG_DISCOVERY = bool(DEBUG_DISCOVERY)
except Exception:
    pass

_DEBUG_TIMING = False
try:
    _DEBUG_TIMING = bool(DEBUG_TIMING)
except Exception:
    pass

_SOURCE_LANGUAGE = "bn"
try:
    _SOURCE_LANGUAGE = str(SOURCE_LANGUAGE)
except Exception:
    pass

# Gradient-reversal schedule: start/end/schedule are taken together or not at
# all. The step count is read only when those three succeeded, and it has its
# own fallback.
_GRL_ALPHA_START = 0.0
_GRL_ALPHA_END = 1.0
_GRL_ALPHA_SCHEDULE = "linear"
_GRL_ALPHA_STEPS = 10000
try:
    _GRL_ALPHA_START, _GRL_ALPHA_END, _GRL_ALPHA_SCHEDULE = (
        float(GRL_ALPHA_START),
        float(GRL_ALPHA_END),
        str(GRL_ALPHA_SCHEDULE),
    )
except Exception:
    pass
else:
    try:
        _GRL_ALPHA_STEPS = int(GRL_ALPHA_STEPS)
    except Exception:
        pass

# Record which optional helper functions earlier cells have defined.
_has_is_valid_token, _has_get_tokenizer_special_tokens, _has_should_track_token = (
    _helper_name in globals()
    for _helper_name in (
        "is_valid_token",
        "get_tokenizer_special_tokens",
        "should_track_token",
    )
)

class GradientReversalFunction(torch.autograd.Function):
    """Autograd function that is the identity forward and negates gradients.

    ``forward`` returns a view of ``x``; ``backward`` multiplies the incoming
    gradient by ``-alpha`` and returns no gradient for ``alpha``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep the scale as a plain float for use in backward().
        ctx.alpha = float(alpha)
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        flipped = grad_output * (-ctx.alpha)
        # Second slot corresponds to ``alpha``, which is not differentiated.
        return flipped, None

def gradient_reversal(x, alpha: float = 1.0):
    """Apply ``GradientReversalFunction`` to ``x`` with scale ``alpha``."""
    reversed_view = GradientReversalFunction.apply(x, alpha)
    return reversed_view

class LightweightDiscriminator(nn.Module):
    """Small MLP classifier: Linear(input_dim, 64) -> ReLU -> Dropout(0.1) -> Linear(64, 2)."""

    def __init__(self, input_dim: int):
        super().__init__()
        # Layers are created in forward order so parameter names and
        # initialisation order stay ``classifier.0`` then ``classifier.3``.
        stack = [
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 2),
        ]
        self.classifier = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the two-way logits produced by the classifier stack."""
        logits = self.classifier(x)
        return logits

class DomainDiscriminator(nn.Module):
    """MLP classifier with hidden widths 512 and 256, dropout 0.3, and two output logits."""

    def __init__(self, input_dim: int):
        super().__init__()
        widths = (input_dim, 512, 256)
        stack = []
        # One Linear -> ReLU -> Dropout(0.3) group per hidden width, built in
        # forward order so parameter indices stay 0, 3 and 6.
        for fan_in, fan_out in zip(widths, widths[1:]):
            stack.extend([nn.Linear(fan_in, fan_out), nn.ReLU(), nn.Dropout(0.3)])
        stack.append(nn.Linear(widths[-1], 2))
        self.classifier = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the two-way logits produced by the classifier stack."""
        logits = self.classifier(x)
        return logits

class MemoryEfficientASBNModule(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        tokenizer=None,
        language: str = "bn",
        freq_threshold: float = 0.7,
        uncertainty_threshold: float = 0.3,
        gate_threshold: float = 0.5,
        warmup_steps: int = 1000,
        encoder_grl_scale: float = 0.5,
    ):
        """Create the per-domain BatchNorm layers, the four critics and the bookkeeping state.

        Args:
            embed_dim: Feature width used by the BatchNorm layers and the critics.
            tokenizer: Optional tokenizer; only used here to collect special tokens.
            language: Language code stored on the module.
            freq_threshold: Stored as ``self.freq_threshold``.
            uncertainty_threshold: Stored as ``self.uncertainty_threshold``.
            gate_threshold: Stored as ``self.gate_threshold``.
            warmup_steps: Stored as ``self.warmup_steps``.
            encoder_grl_scale: Stored as ``self.encoder_grl_scale``.
        """
        super().__init__()

        self.language = language
        self.tokenizer = tokenizer
        self.embed_dim = int(embed_dim)
        width = self.embed_dim

        # Sub-modules are registered in this exact order (it determines parameter order).
        self.bn_source = nn.BatchNorm1d(width, track_running_stats=True)
        self.bn_target = nn.BatchNorm1d(width, track_running_stats=True)
        self.d_domain = DomainDiscriminator(width)
        # The frequency and context critics take two extra scalar features per row.
        self.d_freq = LightweightDiscriminator(width + 2)
        self.d_ctx = LightweightDiscriminator(width + 2)
        self.d_xl = LightweightDiscriminator(width)

        self.freq_threshold = float(freq_threshold)
        self.uncertainty_threshold = float(uncertainty_threshold)
        self.gate_threshold = float(gate_threshold)

        self.warmup_steps = int(warmup_steps)
        self.current_step = 0
        self.encoder_grl_scale = float(encoder_grl_scale)

        self.lambda_base = {"freq": 1.0, "ctx": 0.5, "xl": 0.8, "domain": 1.0}
        self.lambda_max = 2.0

        # Running sums of the metrics; they are dropped every stats_reset_interval updates.
        self.stats_reset_interval = 1000
        self.stats = dict(
            domain_loss=0.0,
            domain_accuracy=0.0,
            source_accuracy=0.0,
            target_accuracy=0.0,
            asbn_loss=0.0,
            num_updates=0,
        )

        # Special tokens come from the shared helper when it exists, otherwise from the
        # tokenizer attribute; any failure leaves an empty set.
        special = set()
        try:
            if tokenizer is not None:
                if _has_get_tokenizer_special_tokens:
                    special = get_tokenizer_special_tokens(tokenizer)
                else:
                    special = set(getattr(tokenizer, "all_special_tokens", []))
        except Exception:
            special = set()
        self.special_tokens = special

        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            print("[ASBN-INIT] Initialized MemoryEfficientASBNModule:")
            print(f"  - embed_dim: {self.embed_dim}")
            print(f"  - warmup_steps: {self.warmup_steps}")
            print(f"  - encoder_grl_scale: {self.encoder_grl_scale}")
            print(f"  - GRL_ALPHA_STEPS: {_GRL_ALPHA_STEPS}")
            print(f"  - thresholds: freq={self.freq_threshold}, uncert={self.uncertainty_threshold}, gate={self.gate_threshold}")

    def get_grl_alpha(self, global_step: Optional[int] = None) -> float:
        if global_step is None:
            global_step = self.current_step
        step = max(0, int(global_step))
        if _GRL_ALPHA_SCHEDULE == "linear":
            progress = min(1.0, float(step) / float(_GRL_ALPHA_STEPS))
            alpha = _GRL_ALPHA_START + progress * (_GRL_ALPHA_END - _GRL_ALPHA_START)
        elif _GRL_ALPHA_SCHEDULE == "exponential":
            progress = min(1.0, float(step) / float(_GRL_ALPHA_STEPS))
            ratio = _GRL_ALPHA_END / max(1e-8, _GRL_ALPHA_START if _GRL_ALPHA_START > 0 else 1e-3)
            alpha = _GRL_ALPHA_START * (ratio ** progress)
        else:
            alpha = _GRL_ALPHA_END
        alpha = float(torch.clamp(torch.tensor(alpha), 0.0, 1.0).item())
        return alpha

    def get_asbn_stats(self) -> Dict[str, float]:
        return self.get_detailed_stats()

    def get_detailed_stats(self) -> Dict[str, float]:
        if self.stats["num_updates"] > 0:
            n = float(self.stats["num_updates"])
            return {
                "domain_loss": self.stats["domain_loss"] / n,
                "domain_accuracy": self.stats["domain_accuracy"] / n,
                "source_accuracy": self.stats["source_accuracy"] / n,
                "target_accuracy": self.stats["target_accuracy"] / n,
                "asbn_loss": self.stats["asbn_loss"] / n,
                "num_updates": self.stats["num_updates"],
            }
        return {
            "domain_loss": 0.0,
            "domain_accuracy": 0.0,
            "source_accuracy": 0.0,
            "target_accuracy": 0.0,
            "asbn_loss": 0.0,
            "num_updates": 0,
        }

    def reset_stats(self) -> None:
        self.stats = {
            "domain_loss": 0.0,
            "domain_accuracy": 0.0,
            "source_accuracy": 0.0,
            "target_accuracy": 0.0,
            "asbn_loss": 0.0,
            "num_updates": 0,
        }

    def critic_parameters(self):
        return (
            list(self.d_domain.parameters())
            + list(self.d_freq.parameters())
            + list(self.d_ctx.parameters())
            + list(self.d_xl.parameters())
        )

    def _ensure_discriminators_on_device(self, device: torch.device) -> None:
        try:
            for mod in (
                self.d_domain,
                self.d_freq,
                self.d_ctx,
                self.d_xl,
                self.bn_source,
                self.bn_target,
            ):
                try:
                    p = next(mod.parameters())
                    if p.device != device:
                        mod.to(device)
                except StopIteration:
                    mod.to(device)
                except Exception:
                    pass
        except Exception:
            if _VERBOSE_LOGGING:
                try:
                    print("[ASBN] Device migration failed:", traceback.format_exc().splitlines()[-1])
                except Exception:
                    print("[ASBN] Device migration failed")

    def _parse_proto_probs_matrix(self, proto_probs: Any, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
        pmax = torch.full((batch_size, seq_len), 0.5, dtype=torch.float32, device=device)
        try:
            if proto_probs is None:
                return pmax
            if isinstance(proto_probs, torch.Tensor):
                if proto_probs.dim() == 3:
                    B, T, K = proto_probs.shape
                    p = proto_probs.detach().to(device)
                    b_max = min(batch_size, B)
                    t_max = min(seq_len, T)
                    pmax[:b_max, :t_max] = p[:b_max, :t_max].max(dim=2)[0]
                    return pmax
                if proto_probs.dim() == 2:
                    p = proto_probs.detach().to(device)
                    if batch_size >= 1:
                        t_max = min(seq_len, p.size(0))
                        pmax[0, :t_max] = p[:t_max].max(dim=1)[0]
                        return pmax
            if isinstance(proto_probs, (list, tuple)):
                if len(proto_probs) == batch_size:
                    for b in range(batch_size):
                        row = proto_probs[b]
                        if isinstance(row, torch.Tensor) and row.dim() == 2:
                            t_max = min(seq_len, row.size(0))
                            pmax[b, :t_max] = row[:t_max].max(dim=1)[0].to(device)
                        elif isinstance(row, (list, tuple)):
                            for t in range(min(seq_len, len(row))):
                                try:
                                    val = row[t]
                                    if isinstance(val, torch.Tensor):
                                        pmax[b, t] = float(val.max().item())
                                    else:
                                        arr = np.asarray(val, dtype=np.float32)
                                        pmax[b, t] = float(np.max(arr))
                                except Exception:
                                    pmax[b, t] = 0.5
                else:
                    if batch_size == 1:
                        row = proto_probs
                        for t in range(min(seq_len, len(row))):
                            try:
                                val = row[t]
                                if isinstance(val, torch.Tensor):
                                    pmax[0, t] = float(val.max().item())
                                else:
                                    pmax[0, t] = float(np.max(np.asarray(val, dtype=np.float32)))
                            except Exception:
                                pmax[0, t] = 0.5
        except Exception as e:
            if _VERBOSE_LOGGING:
                print(f"[ASBN] parse_proto_probs exception: {e}")
        return pmax

    def _parse_scalar_matrix(self, mat: Any, batch_size: int, seq_len: int, device: torch.device,
                            default: float = 0.0) -> torch.Tensor:
        out = torch.full((batch_size, seq_len), float(default), dtype=torch.float32, device=device)
        try:
            if mat is None:
                return out
            if isinstance(mat, torch.Tensor):
                if mat.dim() == 3:
                    if mat.size(2) == 1:
                        mat = mat.squeeze(2)
                    B, T = mat.size(0), mat.size(1)
                    b_max = min(batch_size, B)
                    t_max = min(seq_len, T)
                    out[:b_max, :t_max] = mat[:b_max, :t_max].to(device)
                elif mat.dim() == 2:
                    if mat.size(0) == batch_size:
                        t_max = min(seq_len, mat.size(1))
                        out[:, :t_max] = mat[:, :t_max].to(device)
                    elif batch_size == 1:
                        t_max = min(seq_len, mat.size(0))
                        out[0, :t_max] = mat[:t_max].to(device)
                elif mat.dim() == 1 and batch_size == 1:
                    t_max = min(seq_len, mat.size(0))
                    out[0, :t_max] = mat[:t_max].to(device)
            elif isinstance(mat, (list, tuple)):
                if len(mat) == batch_size:
                    for b in range(batch_size):
                        row = mat[b]
                        if isinstance(row, torch.Tensor) and row.dim() >= 1:
                            t_max = min(seq_len, row.size(0))
                            for t in range(t_max):
                                out[b, t] = float(row[t].item())
                        elif isinstance(row, (list, tuple, np.ndarray)):
                            t_max = min(seq_len, len(row))
                            for t in range(t_max):
                                try:
                                    v = row[t]
                                    out[b, t] = (float(v.item()) if isinstance(v, torch.Tensor) else float(v))
                                except Exception:
                                    out[b, t] = float(default)
                elif batch_size == 1:
                    row = mat
                    t_max = min(seq_len, len(row))
                    for t in range(t_max):
                        try:
                            v = row[t]
                            out[0, t] = (float(v.item()) if isinstance(v, torch.Tensor) else float(v))
                        except Exception:
                            out[0, t] = float(default)
        except Exception:
            if _VERBOSE_LOGGING:
                try:
                    print("[ASBN] parse_scalar_matrix exception:", traceback.format_exc().splitlines()[-1])
                except Exception:
                    pass
        return out

    def compute_lambda_scaled_tensor(self, pmax: torch.Tensor, uncertainty: torch.Tensor,
                                    gate: torch.Tensor, lambda_type: str) -> torch.Tensor:
        base = float(self.lambda_base.get(lambda_type, 0.2))
        
        lam = base * (1.0 - pmax + 0.15) * (uncertainty + 0.15) * (gate + 0.15)
        
        lam = torch.clamp(lam, min=0.05, max=float(self.lambda_max))
        
        lam = torch.where(torch.isfinite(lam), lam, torch.full_like(lam, 0.05))
        
        if _DEBUG_DISCOVERY and torch.any(lam < 0.1):
            low_count = (lam < 0.1).sum().item()
            if low_count > 0:
                print(f"[ASBN-WARN] {low_count} tokens have low lambda (<0.1) for {lambda_type}")
        
        return lam

    def forward(self, h: torch.Tensor, domain_labels: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        if not isinstance(h, torch.Tensor) or h.dim() != 3:
            dev = h.device if isinstance(h, torch.Tensor) else torch.device("cpu")
            return h, torch.tensor(0.0, device=dev)
        
        B, T, H = h.size()
        device = h.device
        
        if domain_labels is not None:
            if domain_labels.dim() == 0:
                domain_labels = domain_labels.unsqueeze(0)
            if domain_labels.size(0) == 1 and B > 1:
                domain_labels = domain_labels.expand(B)
            elif domain_labels.size(0) != B:
                if _DEBUG_DISCOVERY:
                    print(f"[ASBN] Domain label size mismatch: {domain_labels.size(0)} vs batch {B}, using first label")
                domain_labels = domain_labels[0].unsqueeze(0).expand(B)
        
        h_flat = h.view(B * T, H)
        h_normalized = h_flat.clone()
        
        if domain_labels is not None:
            domain_expanded = domain_labels.unsqueeze(1).expand(B, T).reshape(-1)
            source_mask = domain_expanded == 0
            target_mask = domain_expanded == 1
            
            if self.training:
                if source_mask.sum() >= 2:
                    h_normalized[source_mask] = self.bn_source(h_flat[source_mask])
                elif source_mask.sum() == 1:
                    if self.bn_source.track_running_stats and self.bn_source.num_batches_tracked > 0:
                        self.bn_source.eval()
                        h_normalized[source_mask] = self.bn_source(h_flat[source_mask])
                        self.bn_source.train()
                    else:
                        h_normalized[source_mask] = h_flat[source_mask]
                
                if target_mask.sum() >= 2:
                    h_normalized[target_mask] = self.bn_target(h_flat[target_mask])
                elif target_mask.sum() == 1:
                    if self.bn_target.track_running_stats and self.bn_target.num_batches_tracked > 0:
                        self.bn_target.eval()
                        h_normalized[target_mask] = self.bn_target(h_flat[target_mask])
                        self.bn_target.train()
                    else:
                        h_normalized[target_mask] = h_flat[target_mask]
        
        h_out = h_normalized.view(B, T, H)
        
        if not self.training or not _ENABLE_ASBN_TRAINING or domain_labels is None:
            return h_out, torch.tensor(0.0, device=device)
        
        if self.current_step < self.warmup_steps:
            return h_out, torch.tensor(0.0, device=device)
        
        try:
            self._ensure_discriminators_on_device(device)
            
            grl_alpha = self.get_grl_alpha(self.current_step)
            
            domain_flat = domain_labels.unsqueeze(1).expand(B, T).reshape(-1)
            
            domain_input = gradient_reversal(h_normalized, alpha=grl_alpha)
            domain_logits = self.d_domain(domain_input)
            domain_loss = F.cross_entropy(domain_logits, domain_flat)
            
            with torch.no_grad():
                domain_preds = torch.argmax(domain_logits, dim=1)
                domain_accuracy = (domain_preds == domain_flat).float().mean()
                
                source_mask = domain_flat == 0
                target_mask = domain_flat == 1
                
                if source_mask.any():
                    source_acc = (domain_preds[source_mask] == domain_flat[source_mask]).float().mean()
                    self.stats["source_accuracy"] += float(source_acc.item())
                
                if target_mask.any():
                    target_acc = (domain_preds[target_mask] == domain_flat[target_mask]).float().mean()
                    self.stats["target_accuracy"] += float(target_acc.item())
                
                self.stats["domain_loss"] += float(domain_loss.item())
                self.stats["domain_accuracy"] += float(domain_accuracy.item())
                self.stats["asbn_loss"] += float(domain_loss.item())
                self.stats["num_updates"] += 1
                
                if self.stats["num_updates"] >= self.stats_reset_interval:
                    if _DEBUG_DISCOVERY:
                        stats = self.get_detailed_stats()
                        print(f"\n[ASBN] Stats after {stats['num_updates']} updates:")
                        print(f"  Domain loss: {stats['domain_loss']:.4f}")
                        print(f"  Domain accuracy: {stats['domain_accuracy']:.2%}")
                        print(f"  Source accuracy: {stats['source_accuracy']:.2%}")
                        print(f"  Target accuracy: {stats['target_accuracy']:.2%}")
                    self.reset_stats()
            
            return h_out, domain_loss
            
        except Exception as e:
            if _VERBOSE_LOGGING:
                print(f"[ASBN] Adversarial training failed: {e}")
            return h_out, torch.tensor(0.0, device=device)

    def forward_with_grl_simplified(self, h: torch.Tensor, proto_probs: Any, uncertainties: Any, gates: Any,
                                   token_word_map: Optional[List[Dict[int, str]]] = None,
                                   domain_labels: Optional[torch.Tensor] = None,
                                   global_step: Optional[int] = None) \
            -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        if global_step is not None:
            self.current_step = int(global_step)
        dev = h.device if isinstance(h, torch.Tensor) else torch.device("cpu")
        if self.current_step < self.warmup_steps:
            if _DEBUG_DISCOVERY and self.current_step % 100 == 0:
                print(f"[ASBN] Warmup: {self.current_step}/{self.warmup_steps}")
            zero = torch.tensor(0.0, device=dev)
            return zero, zero, zero, zero
        if not self.training or not _ENABLE_ASBN_TRAINING:
            zero = torch.tensor(0.0, device=dev)
            return zero, zero, zero, zero
        if not isinstance(h, torch.Tensor) or h.dim() != 3:
            zero = torch.tensor(0.0, device=dev)
            return zero, zero, zero, zero
        device = h.device
        self._ensure_discriminators_on_device(device)
        self.d_domain.train()
        self.d_freq.train()
        self.d_ctx.train()
        self.d_xl.train()
        B, T, H = h.size()
        pmax_mat = self._parse_proto_probs_matrix(proto_probs, B, T, device)
        U_mat = self._parse_scalar_matrix(uncertainties, B, T, device, default=0.2)
        G_mat = self._parse_scalar_matrix(gates, B, T, device, default=0.1)
        sel_mask = torch.ones((B, T), dtype=torch.bool, device=device)
        batch_indices = torch.arange(B, device=device).unsqueeze(1).expand(B, T)
        if token_word_map:
            try:
                for b in range(min(B, len(token_word_map))):
                    wm = token_word_map[b] or {}
                    for t in range(T):
                        if t in wm:
                            try:
                                token_str = wm[t]
                                tracked = True
                                if _has_should_track_token:
                                    tracked = bool(globals()["should_track_token"](token_str))
                                elif _has_is_valid_token:
                                    tracked = bool(is_valid_token(token_str, self.special_tokens, self.tokenizer, language=self.language))
                                if not tracked:
                                    sel_mask[b, t] = False
                            except Exception:
                                pass
            except Exception:
                if _VERBOSE_LOGGING:
                    try:
                        print("[ASBN] Token filtering failed:", traceback.format_exc().splitlines()[-1])
                    except Exception:
                        pass
        sel_idx = sel_mask.view(-1).nonzero(as_tuple=False).squeeze(1)
        batch_idx = batch_indices.view(-1)[sel_idx]
        if sel_idx.numel() == 0:
            if _DEBUG_DISCOVERY:
                print("[ASBN] No valid tokens after filtering")
            zero = torch.tensor(0.0, device=device)
            return zero, zero, zero, zero
        h_flat = h.view(B * T, H)
        sel_emb = h_flat[sel_idx]
        pmax_flat = pmax_mat.view(-1)[sel_idx]
        U_flat = U_mat.view(-1)[sel_idx]
        G_flat = G_mat.view(-1)[sel_idx]
        seq_len_feature = float(T) / max(int(_MAX_LENGTH), 1)
        freq_feature = torch.stack([pmax_flat, U_flat], dim=1).to(device)
        ctx_feature = torch.stack([G_flat, torch.full_like(G_flat, seq_len_feature)], dim=1).to(device)
        xl_input = sel_emb
        grl_alpha = self.get_grl_alpha(global_step)
        freq_input = torch.cat([sel_emb, freq_feature], dim=1)
        ctx_input = torch.cat([sel_emb, ctx_feature], dim=1)
        xl_input_grl = gradient_reversal(xl_input, alpha=grl_alpha)
        freq_input_grl = gradient_reversal(freq_input, alpha=grl_alpha)
        ctx_input_grl = gradient_reversal(ctx_input, alpha=grl_alpha)
        freq_logits = self.d_freq(freq_input_grl)
        ctx_logits = self.d_ctx(ctx_input_grl)
        xl_logits = self.d_xl(xl_input_grl)
        freq_label = (pmax_flat > self.freq_threshold).long().to(device)
        ctx_label = (U_flat < self.uncertainty_threshold).long().to(device)
        xl_label = (G_flat > self.gate_threshold).long().to(device)
        loss_freq = F.cross_entropy(freq_logits, freq_label, reduction="none")
        loss_ctx = F.cross_entropy(ctx_logits, ctx_label, reduction="none")
        loss_xl = F.cross_entropy(xl_logits, xl_label, reduction="none")
        lam_freq = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "freq")
        lam_ctx = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "ctx")
        lam_xl = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "xl")
        weighted = lam_freq * loss_freq + lam_ctx * loss_ctx + lam_xl * loss_xl
        mean_weighted = torch.mean(weighted)
        domain_loss = torch.tensor(0.0, device=device)
        domain_accuracy = torch.tensor(0.0, device=device)
        if domain_labels is not None:
            try:
                if domain_labels.dim() == 0:
                    domain_labels = domain_labels.unsqueeze(0)
                if domain_labels.size(0) == 1 and B > 1:
                    domain_labels = domain_labels.expand(B)
                elif domain_labels.size(0) != B:
                    domain_labels = domain_labels[0].unsqueeze(0).expand(B)
                domain_flat = domain_labels[batch_idx]
                domain_input = gradient_reversal(sel_emb, alpha=grl_alpha)
                domain_logits = self.d_domain(domain_input)
                domain_loss = F.cross_entropy(domain_logits, domain_flat)
                with torch.no_grad():
                    domain_preds = torch.argmax(domain_logits, dim=1)
                    domain_accuracy = (domain_preds == domain_flat).float().mean()
                    source_mask = domain_flat == 0
                    target_mask = domain_flat == 1
                    if source_mask.any():
                        source_acc = ((domain_preds[source_mask] == domain_flat[source_mask]).float().mean())
                        self.stats["source_accuracy"] += float(source_acc.item())
                    if target_mask.any():
                        target_acc = ((domain_preds[target_mask] == domain_flat[target_mask]).float().mean())
                        self.stats["target_accuracy"] += float(target_acc.item())
            except Exception as e:
                if _VERBOSE_LOGGING:
                    print(f"[ASBN] Domain loss failed: {e}")
        warmup_progress = min(1.0, float(self.current_step - self.warmup_steps) / float(self.warmup_steps))
        effective_scale = self.encoder_grl_scale * warmup_progress
        encoder_loss = effective_scale * (mean_weighted + domain_loss)
        
        if torch.isnan(encoder_loss) or torch.isinf(encoder_loss):
            if _VERBOSE_LOGGING:
                print("[ASBN-WARN] NaN/Inf detected in encoder_loss, returning zero")
            encoder_loss = torch.tensor(0.0, device=device)
        
        try:
            with torch.no_grad():
                self.stats["domain_loss"] += float(domain_loss.item())
                self.stats["domain_accuracy"] += float(domain_accuracy.item())
                self.stats["asbn_loss"] += float(encoder_loss.item())
                self.stats["num_updates"] += 1
                if self.stats["num_updates"] >= self.stats_reset_interval:
                    if _DEBUG_DISCOVERY:
                        stats = self.get_detailed_stats()
                        print(f"\n[ASBN-STATS] After {stats['num_updates']} updates:")
                        print(f"  Domain loss: {stats['domain_loss']:.4f}")
                        print(f"  Domain acc: {stats['domain_accuracy']:.2%}")
                        print(f"  Source acc: {stats['source_accuracy']:.2%}")
                        print(f"  Target acc: {stats['target_accuracy']:.2%}")
                        print(f"  ASBN loss: {stats['asbn_loss']:.4f}")
                    self.reset_stats()
        except Exception:
            pass
        if _DEBUG_DISCOVERY and self.current_step % 500 == 0:
            print(f"\n[ASBN-STEP-{self.current_step}]")
            print(f"  GRL alpha: {grl_alpha:.3f}")
            print(f"  Effective scale: {effective_scale:.3f}")
            print(f"  Encoder loss: {encoder_loss.item():.4f}")
            print(f"  Domain loss: {domain_loss.item():.4f}")
            print(f"  Domain acc: {domain_accuracy.item():.2%}")
        return encoder_loss, mean_weighted, domain_loss, domain_accuracy

    def test_asbn(self, batch_size: int = 2, seq_len: int = 10) -> bool:
        print("\n" + "=" * 60)
        print("[ASBN-TEST] Testing ASBN module")
        print("=" * 60)
        try:
            try:
                device = next(self.parameters()).device
            except StopIteration:
                device = torch.device("cpu")
            h = torch.randn(batch_size, seq_len, self.embed_dim, device=device)
            domain_labels = torch.randint(0, 2, (batch_size,), device=device)
            h_out, _ = self.forward(h, domain_labels)
            assert h_out.shape == h.shape, "Forward output shape mismatch"
            print("  forward() passed")
            proto_probs = torch.rand(batch_size, seq_len, 3, device=device)
            uncertainties = torch.rand(batch_size, seq_len, device=device)
            gates = torch.rand(batch_size, seq_len, device=device)
            self.train()
            self.current_step = self.warmup_steps + 1
            enc_loss, adv_loss, dom_loss, dom_acc = self.forward_with_grl_simplified(
                h,
                proto_probs,
                uncertainties,
                gates,
                domain_labels=domain_labels,
                global_step=self.current_step,
            )
            assert enc_loss.item() >= 0.0, "Encoder loss negative"
            assert 0.0 <= dom_acc.item() <= 1.0, "Domain accuracy out of range"
            print("  forward_with_grl_simplified() passed")
            stats = self.get_detailed_stats()
            assert "domain_loss" in stats, "Missing domain_loss in stats"
            print("  Statistics tracking passed")
            print("\nAll ASBN tests passed")
            print("=" * 60 + "\n")
            return True
        except Exception as e:
            print(f"\nASBN test failed: {e}")
            traceback.print_exc()
            print("=" * 60 + "\n")
            return False

# Cell-4 status banner, emitted with a single print call. The first and last entries
# carry the leading/trailing blank line of the original output.
print(
    "\n".join(
        [
            "\n" + "=" * 80,
            "Cell 4: ASBN Ready (dynamic GRL, DSCD-aware) - FIXED",
            "=" * 80,
            "FIXES APPLIED:",
            "=" * 80,
            " F1:  Lambda offset increased 0.1 → 0.15, min clamp 0.01 → 0.05",
            " F2:  _parse_scalar_matrix now handles [B,T,1] by squeezing dimension 2",
            " F3:  GRL alpha now clamped to [0.0, 1.0] range",
            " F4:  Encoder_grl_scale ramps gradually during warmup period",
            " F5:  BatchNorm single-token fallback uses running stats when available",
            " F6:  Added NaN/Inf check before returning encoder_loss",
            " F7:  Default uncertainty changed from 0.1 → 0.2 for realistic signals",
            " F8:  Default gate changed from 0.0 → 0.1 to ensure gradient flow",
            " F9:  Added debug warning when lambda values drop below 0.1",
            " F10: Lambda computation uses safer addition instead of multiplication chain",
            " F11: CRITICAL - forward() now does adversarial training with domain labels",
            " F12: Added domain discrimination in forward() method",
            " F13: Fixed domain_labels dimension handling in forward()",
            " F14: Added statistics tracking to forward() method",
            " F15: BatchNorm applied before adversarial loss in forward()",
            " F16: Domain classifier input passes through GRL in forward()",
            " F17: Safe lambda formula prevents gradient vanishing",
            "=" * 80 + "\n",
        ]
    )
)


In [None]:
# ==============================================================================
# CELL 5: TRG (PURE DATA-DRIVEN) - FIXED
# ==============================================================================

from typing import List, Dict, Tuple, Optional, Set, Any
from collections import deque
import traceback
import numpy as np
import torch
import torch.nn as nn
import threading
import time

# Each setting starts at its default and is overridden only when the matching
# configuration global exists and converts cleanly.

# --- TRG sizes -----------------------------------------------------------------
_TRG_EVIDENCE_K = 3
try:
    _TRG_EVIDENCE_K = int(TRG_EVIDENCE_K)
except (NameError, ValueError, TypeError):
    pass

_TRG_GEN_EMBED = 64
try:
    _TRG_GEN_EMBED = int(TRG_GEN_EMBED)
except (NameError, ValueError, TypeError):
    pass

_MAX_SILVER_BUFFER = 50
try:
    _MAX_SILVER_BUFFER = int(MAX_SILVER_BUFFER)
except (NameError, ValueError, TypeError):
    pass

# --- Logging / feature switches ------------------------------------------------
_VERBOSE_LOGGING = False
try:
    _VERBOSE_LOGGING = bool(VERBOSE_LOGGING)
except NameError:
    pass

_DEBUG_DISCOVERY = False
try:
    _DEBUG_DISCOVERY = bool(DEBUG_DISCOVERY)
except NameError:
    pass

_DEBUG_TIMING = False
try:
    _DEBUG_TIMING = bool(DEBUG_TIMING)
except NameError:
    pass

_ENABLE_TRG_INFERENCE = True
try:
    _ENABLE_TRG_INFERENCE = bool(ENABLE_TRG_INFERENCE)
except NameError:
    pass

_SOURCE_LANGUAGE = "bn"
try:
    _SOURCE_LANGUAGE = str(SOURCE_LANGUAGE)
except (NameError, TypeError):
    pass

# --- Thresholds ----------------------------------------------------------------
# The uncertainty threshold is read from TAU_LOW, the same global as _TAU_LOW below.
_TRG_UNCERTAINTY_THRESHOLD = 0.15
try:
    _TRG_UNCERTAINTY_THRESHOLD = float(TAU_LOW)
except (NameError, ValueError, TypeError):
    pass

_TRG_SPAN_THRESHOLD = 0.12
try:
    _TRG_SPAN_THRESHOLD = float(SPAN_THRESHOLD)
except (NameError, ValueError, TypeError):
    pass

_TAU_HIGH = 0.85
try:
    _TAU_HIGH = float(TAU_HIGH)
except (NameError, ValueError, TypeError):
    pass

_TAU_LOW = 0.15
try:
    _TAU_LOW = float(TAU_LOW)
except (NameError, ValueError, TypeError):
    pass

_TAU_ACCEPT = 0.80
try:
    _TAU_ACCEPT = float(TAU_ACCEPT)
except (NameError, ValueError, TypeError):
    pass

_TRG_TEMPERATURE = 1.0
try:
    _TRG_TEMPERATURE = float(TRG_TEMPERATURE)
except (NameError, ValueError, TypeError):
    pass

_MAX_EXPLANATIONS_PER_SENTENCE = 10
if "MAX_EXPLANATIONS_PER_SENTENCE" in globals():
    try:
        _MAX_EXPLANATIONS_PER_SENTENCE = int(MAX_EXPLANATIONS_PER_SENTENCE)
    except Exception:
        pass

# --- Optional helpers provided by earlier cells --------------------------------
(
    _has_is_valid_token,
    _has_get_tokenizer_special_tokens,
    _has_get_cached_special_tokens,
) = [
    helper_name in globals()
    for helper_name in (
        "is_valid_token",
        "get_tokenizer_special_tokens",
        "get_cached_special_tokens",
    )
]

# Characters treated as punctuation by the TRG token filters.
_TRG_PUNCT_SET = set(".,;:!?\"\'-()[]{}/\\")


def _fallback_is_valid_token(
    token: str, special_tokens: set, tokenizer=None, language: str = "bn"
) -> bool:
    """Heuristically decide whether ``token`` is a content-bearing subword.

    A token is accepted when, after removing the subword markers ``▁``,
    ``Ġ``, ``##`` and ``</w>``, at least two characters remain and at least
    one of them is alphabetic.

    Args:
        token: Raw tokenizer piece. Non-string values are converted with
            ``str()``; ``None`` is rejected.
        special_tokens: Collection of tokens that must never be accepted.
        tokenizer: Unused; kept for signature compatibility with
            ``is_valid_token``.
        language: Unused; kept for signature compatibility.

    Returns:
        ``True`` if the token passes the heuristic, ``False`` otherwise.
    """
    if token is None:
        return False

    if not isinstance(token, str):
        try:
            token = str(token)
        except Exception:
            return False

    token = token.strip()
    if not token or token in special_tokens:
        return False

    clean = (
        token.replace("▁", "")
        .replace("Ġ", "")
        .replace("##", "")
        .replace("</w>", "")
        .strip()
    )

    if len(clean) < 2:
        return False

    # One alphabetic character is sufficient: a string that contains a letter
    # can be neither punctuation-only nor digit-only, so the separate checks
    # for those two cases were unreachable and have been dropped.
    return any(c.isalpha() for c in clean)


def _is_word_start(raw_token: str, token_word_map: Optional[dict], idx: int) -> bool:
    """Report whether the token at position ``idx`` begins a word.

    The decision is made in this order:

    1. ``token_word_map`` is a dict holding a non-blank string for ``idx``.
    2. The raw token starts with a ``▁`` or ``Ġ`` marker.
    3. No map was supplied at all (``None``) and the marker-stripped token
       has at least two characters, is not pure punctuation, and contains an
       alphabetic character.

    Anything else, including non-string tokens and any unexpected error,
    yields ``False``.
    """
    if not isinstance(raw_token, str):
        return False

    try:
        if isinstance(token_word_map, dict) and idx in token_word_map:
            mapped = token_word_map[idx]
            if isinstance(mapped, str) and mapped.strip():
                return True

        if raw_token.startswith(("▁", "Ġ")):
            return True

        stripped = raw_token
        for marker in ("▁", "Ġ", "##", "</w>"):
            stripped = stripped.replace(marker, "")
        stripped = stripped.strip()

        punctuation = '.,;:!?"\'()[]{}-/'
        if len(stripped) < 2 or all(ch in punctuation for ch in stripped):
            return False

        # The alphabetic heuristic applies only when no map exists at all;
        # a supplied map that lacks this index is treated as "not a start".
        return token_word_map is None and any(ch.isalpha() for ch in stripped)

    except Exception:
        return False


class ComprehensiveTRGExplanationTemplate:
    """Renders an evidence dictionary into a one-line textual explanation.

    Three confidence-banded templates are available, plus a ``fallback``
    template that is used only if the selected band is missing from
    ``explanation_templates``.
    """

    def __init__(self):
        # Templates are ``str.format`` patterns; the placeholders are
        # sense, confidence, evidence, alternatives_text and token.
        high = (
            "Chose '{sense}' with high confidence ({confidence:.1%}) based on: '{evidence}'.   "
            "Pattern matches learned data.   {alternatives_text}"
        )
        medium = (
            "Selected '{sense}' with moderate confidence ({confidence:.1%}). "
            "Evidence: '{evidence}'. Some uncertainty.   {alternatives_text}"
        )
        low = (
            "Uncertain; chose '{sense}' ({confidence:.1%}). "
            "Evidence: '{evidence}'.   {alternatives_text} Review recommended."
        )
        self.explanation_templates = {
            "high_confidence": high,
            "medium_confidence": medium,
            "low_confidence": low,
            "fallback": "Token '{token}' analyzed.   Context: '{evidence}'.",
        }

    @staticmethod
    def _strip_markers(value) -> str:
        """Return ``str(value)`` without the ``▁`` / ``Ġ`` subword markers."""
        return str(value).replace("▁", "").replace("Ġ", "")

    def generate_explanation(self, evidence: Dict) -> str:
        """Format ``evidence`` into a human-readable explanation.

        Args:
            evidence: Mapping that may carry ``token``, ``chosen_sense``
                (a ``(name, confidence)`` pair), ``evidence_tokens`` and
                ``alternatives`` (a list of ``(name, confidence)`` pairs).

        Returns:
            The rendered explanation, or ``""`` when ``evidence`` is empty
            or not a dict. If template formatting fails, a short
            ``Token ... -> ...`` summary is returned instead.
        """
        if not evidence or not isinstance(evidence, dict):
            return ""

        token = self._strip_markers(evidence.get("token", "unknown"))

        sense_name, confidence = "unknown", 0.5
        chosen = evidence.get("chosen_sense", ("unknown", 0.5))
        if isinstance(chosen, (tuple, list)) and len(chosen) >= 2:
            sense_name, confidence = str(chosen[0]), float(chosen[1])

        supporting = evidence.get("evidence_tokens", [])
        shown = [self._strip_markers(tok) for tok in supporting[:_TRG_EVIDENCE_K]]
        evidence_str = ", ".join(shown) or "limited context"

        # At most two alternatives are mentioned; malformed entries are skipped.
        alternatives_text = ""
        candidates = evidence.get("alternatives", [])
        if isinstance(candidates, list) and candidates:
            rendered = [
                f"'{str(alt[0])}' ({float(alt[1]):.1%})"
                for alt in candidates[:2]
                if isinstance(alt, (tuple, list)) and len(alt) >= 2
            ]
            if rendered:
                alternatives_text = f"Alternatives: {', '.join(rendered)}."

        template_key = (
            "high_confidence"
            if confidence >= _TAU_ACCEPT
            else "medium_confidence"
            if confidence >= _TRG_UNCERTAINTY_THRESHOLD
            else "low_confidence"
        )
        template = self.explanation_templates.get(
            template_key, self.explanation_templates["fallback"]
        )

        fields = {
            "sense": sense_name,
            "confidence": confidence,
            "evidence": evidence_str,
            "alternatives_text": alternatives_text,
            "token": token,
        }
        try:
            return template.format(**fields)
        except Exception:
            return f"Token '{token}' -> '{sense_name}' ({confidence:.1%})."


class MemoryEfficientTRGExtractor:
    """Gathers the evidence used to explain a single token.

    For one token position the extractor collects supporting context tokens
    (from decoder attention when usable, otherwise from a +/-2 token window),
    the ranked sense probabilities, and the per-token ``uncertainties``,
    ``gates`` and ``span_preds`` scalars found in a DSCD output dictionary.
    Only the first entry (index 0) of each DSCD container is consulted.

    All extraction helpers are best-effort: malformed inputs produce neutral
    defaults instead of exceptions.
    """

    def __init__(self, tokenizer=None, language: str = "bn", dscd_module=None):
        """Store collaborators and resolve the tokenizer's special tokens.

        Args:
            tokenizer: Optional tokenizer; used to look up special tokens.
            language: Language code forwarded to the token validity check.
            dscd_module: Optional DSCD module queried for homographs.
        """
        self.tokenizer = tokenizer
        self.language = language
        self.dscd_module = dscd_module
        # Rate-limiting state for the out-of-range span diagnostics.
        self.span_clamp_warnings = 0
        self.last_warning_time = 0.0

        if tokenizer is not None:
            try:
                if _has_get_tokenizer_special_tokens:
                    self.special_tokens = get_tokenizer_special_tokens(tokenizer)
                elif _has_get_cached_special_tokens:
                    self.special_tokens = get_cached_special_tokens(tokenizer)
                else:
                    self.special_tokens = set(tokenizer.all_special_tokens)
            except Exception:
                self.special_tokens = set()
        else:
            self.special_tokens = set()

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------

    def _token_is_valid(self, raw_token) -> bool:
        """Validate a token with ``is_valid_token`` when defined, else the fallback.

        The fallback heuristic is also used when ``is_valid_token`` raises.
        """
        if _has_is_valid_token:
            try:
                return is_valid_token(
                    raw_token,
                    self.special_tokens,
                    self.tokenizer,
                    language=self.language,
                )
            except Exception:
                pass
        return _fallback_is_valid_token(
            raw_token, self.special_tokens, self.tokenizer, self.language
        )

    def _safe_extract_scalar(
        self,
        token_idx: int,
        dscd_outputs: Dict,
        key: str,
        default: Optional[float],
    ) -> Optional[float]:
        """Read ``dscd_outputs[key][0][token_idx]`` as a float.

        The first entry may be a tensor (1-D, 2-D with the value in column 0,
        or 3-D with a trailing singleton axis) or a list/tuple of scalars.
        ``default`` is returned whenever the value cannot be located.
        """
        try:
            if not isinstance(dscd_outputs, dict):
                return default

            values = dscd_outputs.get(key, None)
            # Explicit None/length test: truthiness of a multi-element tensor
            # raises, which previously made tensor containers fall back to
            # the default silently.
            if values is None or len(values) == 0:
                return default

            row = values[0]
            if isinstance(row, torch.Tensor):
                if row.ndim == 3 and row.size(2) == 1:
                    row = row.squeeze(2)
                if row.ndim == 2 and token_idx < row.shape[0]:
                    return float(row[token_idx, 0].item())
                if row.ndim == 1 and token_idx < row.shape[0]:
                    return float(row[token_idx].item())
            elif isinstance(row, (list, tuple)) and token_idx < len(row):
                val = row[token_idx]
                return float(val.item()) if isinstance(val, torch.Tensor) else float(val)
        except Exception:
            pass
        return default

    def _warn_span_clamped(self, message: str) -> None:
        """Emit a rate-limited diagnostic for an out-of-range span value.

        The first ten occurrences are counted unconditionally, later ones at
        most once per 60 seconds; printing additionally requires the verbose
        or discovery-debug flag.
        """
        now = time.time()
        if self.span_clamp_warnings < 10 or (now - self.last_warning_time) > 60.0:
            if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
                print(message)
            self.span_clamp_warnings += 1
            self.last_warning_time = now

    @staticmethod
    def _attention_row(
        decoder_attention: Any, token_idx: int
    ) -> Optional[torch.Tensor]:
        """Reduce an attention tensor to the 1-D weight vector for ``token_idx``.

        4-D input is averaged over its first two axes and 3-D input over its
        first axis, giving a 2-D matrix whose row ``token_idx`` is returned.
        1-D input is returned unchanged. ``None`` is returned when the input
        is not a tensor, has another rank, or has no row for ``token_idx``.
        """
        if not isinstance(decoder_attention, torch.Tensor):
            return None

        attn = decoder_attention
        if attn.dim() == 4:
            # Always collapse both leading axes. Reducing only one of them
            # (as happened when either had size 1) left a 3-D tensor that was
            # then flattened, so top-k positions no longer mapped to tokens.
            attn = attn.mean(dim=(0, 1))
        elif attn.dim() == 3:
            attn = attn.mean(dim=0)

        if attn.dim() == 1:
            return attn
        if attn.dim() == 2 and 0 <= token_idx < attn.size(0):
            return attn[token_idx]
        # No usable row: let the caller fall back to the context window
        # instead of ranking a flattened matrix.
        return None

    def _evidence_from_attention(
        self, token_idx: int, tokens: List[str], decoder_attention: Any
    ) -> Optional[List[str]]:
        """Return the tokens at the top-5 attention positions, or ``None``.

        Positions beyond ``len(tokens)`` and the explained token itself are
        skipped. ``None`` signals that attention was absent or unusable.
        """
        try:
            row = self._attention_row(decoder_attention, token_idx)
            if row is None or row.numel() == 0:
                return None
            k = min(5, int(row.size(0)))
            top_indices = torch.topk(row, k=k).indices.tolist()
            return [
                tokens[i]
                for i in top_indices
                if i < len(tokens) and i != token_idx
            ]
        except Exception:
            return None

    # ------------------------------------------------------------------
    # Evidence extraction
    # ------------------------------------------------------------------

    def extract_evidence_from_target(
        self,
        token_idx: int,
        span_start: int,
        span_end: int,
        tgt_preds: torch.Tensor,
    ) -> Optional[List[str]]:
        """List the predictions inside ``[span_start, span_end)`` except ``token_idx``.

        Args:
            token_idx: Position being explained; must lie inside the span.
            span_start: Inclusive start of the span (>= 0).
            span_end: Exclusive end of the span (<= sequence length).
            tgt_preds: 1-D tensor or list of predictions.

        Returns:
            The stringified predictions (tensor entries are rendered as
            integers), or ``None`` if any argument is invalid or the span
            holds nothing besides ``token_idx``.
        """
        if not isinstance(token_idx, int) or token_idx < 0:
            return None
        if not isinstance(span_start, int) or not isinstance(span_end, int):
            return None
        if span_start < 0:
            return None
        if not isinstance(tgt_preds, (torch.Tensor, list)):
            return None

        try:
            seq_len = (
                len(tgt_preds)
                if isinstance(tgt_preds, list)
                else int(tgt_preds.size(0))
            )
        except Exception:
            # e.g. a 0-dim tensor has no leading axis.
            return None

        if span_end > seq_len or span_start >= span_end:
            return None
        if not (span_start <= token_idx < span_end) or token_idx >= seq_len:
            return None

        try:
            evidence_tokens: List[str] = []
            for i in range(span_start, span_end):
                if i == token_idx:
                    continue
                if isinstance(tgt_preds, list):
                    evidence_tokens.append(str(tgt_preds[i]))
                    continue
                try:
                    evidence_tokens.append(str(int(tgt_preds[i].item())))
                except Exception:
                    evidence_tokens.append(f"token_{i}")
            return evidence_tokens if evidence_tokens else None
        except Exception:
            return None

    def extract_evidence_efficiently(
        self,
        token_idx: int,
        tokens: List[str],
        dscd_outputs: Dict,
        token_word_map: Optional[dict] = None,
        decoder_attention: Optional[torch.Tensor] = None,
    ) -> Dict:
        """Build the evidence dictionary for ``tokens[token_idx]``.

        Args:
            token_idx: Position of the token to explain.
            tokens: Tokenized sentence.
            dscd_outputs: DSCD output mapping (``proto_probs``,
                ``uncertainties``, ``gates``, ``span_preds``).
            token_word_map: Optional ``index -> word`` mapping; when it has a
                non-blank entry for a position that word is reported instead
                of the raw subword.
            decoder_attention: Optional attention tensor (1-D to 4-D) used to
                pick evidence tokens; the context window is used otherwise.

        Returns:
            A dict with keys ``token``, ``token_idx``, ``evidence_tokens``,
            ``chosen_sense``, ``alternatives``, ``uncertainty``, ``gate`` and
            ``span``. Invalid input, an invalid token, or any internal error
            yields the neutral fallback dict.
        """
        if not isinstance(tokens, list):
            return self._create_fallback_evidence(token_idx, [])

        if not isinstance(token_idx, int):
            return self._create_fallback_evidence(0, tokens)

        if token_idx < 0 or token_idx >= len(tokens):
            return self._create_fallback_evidence(
                max(0, min(token_idx, len(tokens) - 1)), tokens
            )

        raw_token = tokens[token_idx]
        if not self._token_is_valid(raw_token):
            return self._create_fallback_evidence(token_idx, tokens)

        try:
            proto_probs = self._safe_extract_proto_probs(token_idx, dscd_outputs)
            uncertainty = self._safe_extract_uncertainty(token_idx, dscd_outputs)
            gate = self._safe_extract_gate(token_idx, dscd_outputs)
            span = self._safe_extract_span(token_idx, dscd_outputs)

            evidence_tokens = None
            if decoder_attention is not None:
                evidence_tokens = self._evidence_from_attention(
                    token_idx, tokens, decoder_attention
                )
            if evidence_tokens is None:
                evidence_tokens = self._extract_context_window(
                    token_idx, tokens, token_word_map
                )

            # Order-preserving de-duplication, then cap at the evidence budget.
            evidence_tokens = list(dict.fromkeys(evidence_tokens))[:_TRG_EVIDENCE_K]

            top_senses = self._compute_sense_alternatives_fast(
                proto_probs, temperature=_TRG_TEMPERATURE
            )
            chosen_sense = top_senses[0] if len(top_senses) > 0 else ("unknown", 0.5)
            alternatives = top_senses[1:3] if len(top_senses) > 1 else []

            if (
                token_word_map
                and token_idx in token_word_map
                and isinstance(token_word_map[token_idx], str)
                and token_word_map[token_idx].strip()
            ):
                token_value = token_word_map[token_idx]
            else:
                token_value = raw_token

            return {
                "token": token_value,
                "token_idx": token_idx,
                "evidence_tokens": evidence_tokens,
                "chosen_sense": chosen_sense,
                "alternatives": alternatives,
                "uncertainty": float(uncertainty),
                "gate": float(gate),
                "span": float(span),
            }

        except Exception as e:
            if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
                print(f"[TRG] Evidence error @ {token_idx}: {e}")
            return self._create_fallback_evidence(token_idx, tokens)

    def _extract_context_window(
        self,
        token_idx: int,
        tokens: List[str],
        token_word_map: Optional[dict],
    ) -> List[str]:
        """Collect valid neighbouring tokens within two positions of ``token_idx``.

        A neighbour is kept when it is a word start (or, with no word map at
        all, when its cleaned form has >= 2 characters including a letter)
        and passes the token validity check. The mapped word is reported when
        ``token_word_map`` has a non-blank string for the position, otherwise
        the marker-stripped token.
        """
        context_window = 2
        start_idx = max(0, token_idx - context_window)
        end_idx = min(len(tokens), token_idx + context_window + 1)
        evidence_tokens: List[str] = []

        for i in range(start_idx, end_idx):
            if i == token_idx:
                continue

            rtok = tokens[i]
            clean_token = (
                str(rtok)
                .replace("▁", "")
                .replace("Ġ", "")
                .replace("</w>", "")
                .strip()
            )

            if not _is_word_start(rtok, token_word_map, i):
                keep_without_map = (
                    token_word_map is None
                    and len(clean_token) >= 2
                    and any(c.isalpha() for c in clean_token)
                )
                if not keep_without_map:
                    continue

            if not self._token_is_valid(rtok) or not clean_token:
                continue

            # Use .get() for the lookup: indexing a map that lacks ``i``
            # raised KeyError, which discarded the whole evidence record.
            mapped = token_word_map.get(i) if token_word_map else None
            if isinstance(mapped, str) and mapped.strip():
                evidence_tokens.append(mapped.strip())
            else:
                evidence_tokens.append(clean_token)

        return evidence_tokens

    # ------------------------------------------------------------------
    # DSCD output readers
    # ------------------------------------------------------------------

    def _safe_extract_proto_probs(
        self, token_idx: int, dscd_outputs: Dict
    ) -> torch.Tensor:
        """Return the flattened sense-probability vector for ``token_idx``.

        Reads ``dscd_outputs["proto_probs"][0]``, which may be a tensor or a
        list/tuple of per-token entries. Falls back to ``tensor([1.0])`` when
        nothing usable is found.
        """
        try:
            if not isinstance(dscd_outputs, dict):
                return torch.tensor([1.0], dtype=torch.float32)

            pp_all = dscd_outputs.get("proto_probs", None)
            # Explicit None/length test so that tensor containers are
            # supported (tensor truthiness raises for >1 element).
            if pp_all is not None and len(pp_all) > 0:
                row = pp_all[0]
                if isinstance(row, torch.Tensor):
                    if row.ndim == 2 and token_idx < row.shape[0]:
                        return row[token_idx].detach().cpu().flatten()
                    return row.detach().cpu().flatten()
                if isinstance(row, (list, tuple)):
                    if token_idx < len(row):
                        val = row[token_idx]
                        if isinstance(val, torch.Tensor):
                            return val.detach().cpu().flatten()
                        if isinstance(val, (list, tuple, np.ndarray)):
                            return torch.as_tensor(
                                val, dtype=torch.float32
                            ).flatten()
                        return torch.tensor([float(val)], dtype=torch.float32)
                    if len(row) > 0:
                        maybe = row[0]
                        if isinstance(maybe, torch.Tensor):
                            return maybe.detach().cpu().flatten()
        except Exception:
            pass
        return torch.tensor([1.0], dtype=torch.float32)

    def _safe_extract_uncertainty(
        self, token_idx: int, dscd_outputs: Dict
    ) -> float:
        """Return the ``uncertainties`` value for ``token_idx`` (default 0.5)."""
        return self._safe_extract_scalar(token_idx, dscd_outputs, "uncertainties", 0.5)

    def _safe_extract_gate(self, token_idx: int, dscd_outputs: Dict) -> float:
        """Return the ``gates`` value for ``token_idx`` (default 0.0)."""
        return self._safe_extract_scalar(token_idx, dscd_outputs, "gates", 0.0)

    def _safe_extract_span(self, token_idx: int, dscd_outputs: Dict) -> float:
        """Return the ``span_preds`` value for ``token_idx`` clamped to [0, 1].

        Missing values yield 0.0. Out-of-range values are clamped and
        reported through the rate-limited diagnostic.
        """
        span_val = self._safe_extract_scalar(token_idx, dscd_outputs, "span_preds", None)
        if span_val is None:
            return 0.0

        if span_val < 0.0:
            clamped, message = 0.0, f"[TRG] Negative span {span_val:.3f} -> 0.0"
        elif span_val > 1.0:
            clamped, message = 1.0, f"[TRG] Span {span_val:.3f} > 1.0 -> 1.0"
        else:
            return span_val

        try:
            self._warn_span_clamped(message)
        except Exception:
            # Diagnostics must never interfere with extraction.
            pass
        return clamped

    # ------------------------------------------------------------------
    # Sense statistics
    # ------------------------------------------------------------------

    def compute_span(self, sense_probs) -> float:
        """Return the gap between the two largest probabilities, in [0, 1].

        Args:
            sense_probs: Dict (values are used), tensor, ndarray or sequence
                of probabilities.

        Returns:
            ``top1 - top2`` clamped to [0, 1]; 0.0 for fewer than two values
            or on any conversion error.
        """
        try:
            if isinstance(sense_probs, dict):
                probs = list(sense_probs.values())
            else:
                probs = sense_probs

            if isinstance(probs, torch.Tensor):
                # detach() first: converting a tensor that requires grad
                # raised and made the method silently return 0.0.
                probs = probs.detach().cpu().flatten().tolist()

            if isinstance(probs, (np.ndarray, list)):
                probs = list(probs)

            if len(probs) < 2:
                return 0.0

            ranked = sorted((float(p) for p in probs), reverse=True)
            return max(0.0, min(1.0, ranked[0] - ranked[1]))

        except Exception:
            return 0.0

    def _compute_sense_alternatives_fast(
        self, proto_probs: torch.Tensor, temperature: float = 1.0
    ) -> List[Tuple[str, float]]:
        """Rank senses by probability and return up to the top three.

        Probabilities are clamped to ``[1e-10, 1]``. With ``temperature``
        other than 1.0 they are re-scaled via a softmax over
        ``log(p) / max(0.1, temperature)``; otherwise they are normalised to
        sum to one. A single-element input is returned as ``sense_0``.

        Returns:
            ``[("sense_<index>", probability), ...]`` in descending order, or
            ``[("unknown", 0.5)]`` on failure.
        """
        try:
            if not isinstance(proto_probs, torch.Tensor):
                proto_probs = torch.as_tensor(proto_probs, dtype=torch.float32)

            probs = proto_probs.flatten().float()
            probs = torch.clamp(probs, min=1e-10, max=1.0)

            if temperature != 1.0 and probs.numel() > 1:
                log_probs = torch.log(probs + 1e-10)
                scaled_log_probs = log_probs / max(0.1, float(temperature))
                probs = torch.softmax(scaled_log_probs, dim=0)
            elif probs.numel() > 1:
                probs = probs / (probs.sum() + 1e-10)

            if probs.numel() > 1:
                probs_sorted, indices = torch.sort(probs, descending=True)
                top_k = min(3, int(indices.numel()))
                return [
                    (f"sense_{int(indices[i].item())}", float(probs_sorted[i].item()))
                    for i in range(top_k)
                ]
            return [("sense_0", float(probs[0].item()))]
        except Exception:
            return [("unknown", 0.5)]

    def _create_fallback_evidence(
        self, token_idx: int, tokens: List[str]
    ) -> Dict:
        """Return a neutral evidence dict for tokens that cannot be analysed.

        ``token`` is ``tokens[token_idx]`` when that position exists, else
        ``"UNK"``.
        """
        if (
            isinstance(tokens, list)
            and isinstance(token_idx, int)
            and 0 <= token_idx < len(tokens)
        ):
            token = tokens[token_idx]
        else:
            token = "UNK"

        return {
            "token": token,
            "token_idx": token_idx,
            "evidence_tokens": [],
            "chosen_sense": ("unknown", 0.5),
            "alternatives": [],
            "uncertainty": 0.5,
            "gate": 0.0,
            "span": 0.0,
        }

    def get_homograph_tokens_from_dscd(self) -> Set[str]:
        """Return the homograph tokens known to the attached DSCD module.

        Uses ``dscd_module.get_discovered_homographs()`` when available;
        otherwise every key of ``dscd_module.prototype_stores`` whose store
        reports ``size() >= 2`` is included with subword markers stripped.
        Returns an empty (or partial) set if the module is missing or raises.
        """
        homograph_tokens: Set[str] = set()
        try:
            if self.dscd_module is not None:
                if hasattr(self.dscd_module, "get_discovered_homographs"):
                    homograph_tokens = set(
                        self.dscd_module.get_discovered_homographs()
                    )
                elif hasattr(self.dscd_module, "prototype_stores"):
                    for token, store in self.dscd_module.prototype_stores.items():
                        if hasattr(store, "size") and store.size() >= 2:
                            clean = (
                                str(token)
                                .replace("▁", "")
                                .replace("Ġ", "")
                                .replace("##", "")
                                .strip()
                            )
                            homograph_tokens.add(clean)
        except Exception:
            pass
        return homograph_tokens


class CompleteTRGWithExplanations(nn.Module):
    def __init__(
        self,
        embed_dim: Optional[int] = None,
        tokenizer=None,
        language: str = "bn",
        dscd_module=None,
    ):
        """Set up the explanation pipeline.

        Args:
            embed_dim: Embedding width; the configured default is used when
                ``None``.
            tokenizer: Optional tokenizer used to resolve special tokens.
            language: Language code passed on to the evidence extractor.
            dscd_module: Optional DSCD module; without it only a diagnostic
                line is printed (when verbose/debug logging is on).
        """
        super().__init__()
        log_enabled = _VERBOSE_LOGGING or _DEBUG_DISCOVERY

        self.embed_dim = int(_TRG_GEN_EMBED if embed_dim is None else embed_dim)
        self.tokenizer = tokenizer
        self.language = language
        self.dscd_module = dscd_module

        if dscd_module is None and log_enabled:
            print("[TRG] No DSCD module - homograph detection disabled")

        # Resolve special tokens via the first helper that exists; any
        # failure leaves an empty set.
        resolved_special = set()
        if tokenizer is not None:
            try:
                if _has_get_tokenizer_special_tokens:
                    resolved_special = get_tokenizer_special_tokens(tokenizer)
                elif _has_get_cached_special_tokens:
                    resolved_special = get_cached_special_tokens(tokenizer)
                else:
                    resolved_special = set(tokenizer.all_special_tokens)
            except Exception:
                resolved_special = set()
        self.special_tokens = resolved_special

        self.template_system = ComprehensiveTRGExplanationTemplate()
        self.evidence_extractor = MemoryEfficientTRGExtractor(
            tokenizer, language=language, dscd_module=dscd_module
        )

        # Bounded buffer of recent explanations, guarded by its own lock.
        self.silver_buffer = deque(maxlen=int(_MAX_SILVER_BUFFER))
        self._silver_lock = threading.Lock()

        # Counters are reset once this many explanations have been produced.
        self.stats_reset_interval = 1000
        counter_names = (
            "explanations_generated",
            "high_confidence_explanations",
            "low_confidence_explanations",
            "empty_evidence_count",
            "total_evidence_tokens",
            "tokens_filtered_word_start",
            "tokens_filtered_validity",
            "tokens_filtered_ambiguity",
            "dscd_homographs_explained",
        )
        self.stats = {name: 0 for name in counter_names}
        self._stats_lock = threading.Lock()

        if log_enabled:
            print("[TRG] Initialized:")
            print(f"  - Uncertainty: {_TRG_UNCERTAINTY_THRESHOLD:.2f}")
            print(f"  - Span: {_TRG_SPAN_THRESHOLD:.2f}")
            print(f"  - Temperature: {_TRG_TEMPERATURE:.2f}")
            print("  - Mode: DATA-DRIVEN")

    def _update_stats(self, evidence: Dict, is_dscd_homograph: bool = False) -> None:
        """Fold one generated explanation into the running counters.

        Runs entirely under ``self._stats_lock``. When the number of
        explanations reaches ``self.stats_reset_interval`` a summary is
        printed (discovery-debug mode only) and the counters are reset.

        Args:
            evidence: Evidence dict of the explanation just produced.
            is_dscd_homograph: Whether the token was flagged as a DSCD
                homograph.
        """
        with self._stats_lock:
            counters = self.stats
            counters["explanations_generated"] += 1

            if is_dscd_homograph:
                counters["dscd_homographs_explained"] += 1

            if evidence.get("evidence_tokens"):
                counters["total_evidence_tokens"] += len(evidence["evidence_tokens"])
            else:
                counters["empty_evidence_count"] += 1

            # Confidence is the second element of chosen_sense; anything
            # unusable counts as the neutral 0.5.
            confidence = 0.5
            chosen = evidence.get("chosen_sense")
            if isinstance(chosen, (tuple, list)) and len(chosen) >= 2:
                try:
                    confidence = float(chosen[1])
                except Exception:
                    confidence = 0.5

            if confidence >= _TAU_ACCEPT:
                counters["high_confidence_explanations"] += 1
            elif confidence < _TRG_UNCERTAINTY_THRESHOLD:
                counters["low_confidence_explanations"] += 1

            if counters["explanations_generated"] < self.stats_reset_interval:
                return

            if _DEBUG_DISCOVERY:
                current_stats = self.get_statistics()
                print(
                    f"\n[TRG-STATS] After {self.stats['explanations_generated']}:"
                )
                print(
                    f"  High conf: {current_stats['high_confidence_rate']:.2%}"
                )
                print(
                    f"  DSCD: {current_stats['dscd_homograph_rate']:.2%}"
                )
            self.reset_statistics()

    def _add_to_silver_buffer(
        self, evidence: Dict, explanation: str, tokens: List[str]
    ) -> None:
        try:
            conf = 0.5
            chosen = evidence.get("chosen_sense")
            if isinstance(chosen, (tuple, list)) and len(chosen) >= 2:
                conf = float(chosen[1])

            entry = {
                "token": str(evidence.get("token", "UNK"))[:20],
                "explanation": str(explanation)[:150],
                "confidence": conf,
            }

            with self._silver_lock:
                self.silver_buffer.append(entry)

        except Exception:
            pass

    def generate_explanation_for_token(
        self,
        token_idx: int,
        tokens: List[str],
        dscd_outputs: Dict,
        token_word_map: Optional[dict] = None,
        decoder_attention: Optional[torch.Tensor] = None,
        is_dscd_homograph: bool = False,
    ) -> Tuple[str, Dict]:
        """Produce the explanation text and evidence for one token.

        Args:
            token_idx: Position of the token inside ``tokens``.
            tokens: Tokenized sentence.
            dscd_outputs: DSCD output mapping handed to the evidence
                extractor.
            token_word_map: Optional ``index -> word`` mapping.
            decoder_attention: Optional attention tensor for evidence
                selection.
            is_dscd_homograph: Whether the token was flagged as a DSCD
                homograph (only affects statistics).

        Returns:
            ``(explanation_text, evidence)``. ``("", {})`` is returned when
            TRG inference is disabled, the arguments are invalid, the token
            fails the validity check, or explanation generation fails.
        """
        if not _ENABLE_TRG_INFERENCE:
            return "", {}

        if not isinstance(tokens, list) or not isinstance(token_idx, int):
            return "", {}

        if token_idx < 0 or token_idx >= len(tokens):
            return "", {}

        raw_token = tokens[token_idx]
        if _has_is_valid_token:
            try:
                is_valid = is_valid_token(
                    raw_token,
                    self.special_tokens,
                    self.tokenizer,
                    language=self.language,
                )
            except Exception:
                is_valid = _fallback_is_valid_token(
                    raw_token, self.special_tokens, self.tokenizer, self.language
                )
        else:
            is_valid = _fallback_is_valid_token(
                raw_token, self.special_tokens, self.tokenizer, self.language
            )

        if not is_valid:
            return "", {}

        try:
            evidence = self.evidence_extractor.extract_evidence_efficiently(
                token_idx,
                tokens,
                dscd_outputs,
                token_word_map=token_word_map,
                decoder_attention=decoder_attention,
            )

            explanation_text = self.template_system.generate_explanation(evidence)
            self._update_stats(evidence, is_dscd_homograph=is_dscd_homograph)
            self._add_to_silver_buffer(evidence, explanation_text, tokens)
            return explanation_text, evidence
        except Exception as e:
            # Still best-effort, but no longer silent: report the failure
            # under the same flags the evidence extractor uses.
            if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
                print(f"[TRG] Explanation error @ {token_idx}: {e}")
            return "", {}

    @staticmethod
    def _to_list_helper(x: Any) -> List[float]:
        """Flatten a scalar, sequence, or tensor into a plain list of floats.

        Tensors: a 2-D tensor contributes only its first row; every other rank
        is flattened in full. Lists/tuples are converted element by element:
        a tensor element contributes its first value (``0.0`` when it is
        empty) and any other element that ``float`` rejects becomes ``0.0``.

        Returns:
            The converted values, or ``[]`` for ``None`` and for any input
            whose conversion raises.
        """
        if x is None:
            return []

        def _element_to_float(item: Any) -> float:
            # Tensor elements are reduced to their first value.
            if isinstance(item, torch.Tensor):
                return float(item.flatten()[0].item()) if item.numel() > 0 else 0.0
            if isinstance(item, (int, float, np.number)):
                return float(item)
            try:
                return float(item)
            except Exception:
                return 0.0

        try:
            if isinstance(x, torch.Tensor):
                # Only rank 2 is treated as batched (first row kept).
                selected = x[0] if x.ndim == 2 else x.flatten()
                return [float(value) for value in selected.tolist()]

            if isinstance(x, (list, tuple)):
                return [_element_to_float(item) for item in x]

            return [float(x)]

        except Exception:
            return []

    def process_sentence_for_explanations(
        self,
        tokens: List[str],
        dscd_outputs: Dict,
        token_word_map: Optional[dict] = None,
        uncertainty_threshold: Optional[float] = None,
        decoder_attention: Optional[torch.Tensor] = None,
        max_explanations: int = _MAX_EXPLANATIONS_PER_SENTENCE,
    ) -> List[Dict]:
        """Pick the most ambiguous word-start tokens of a sentence and explain them.

        Only the first batch row of ``dscd_outputs["uncertainties"]`` and
        ``dscd_outputs["span_preds"]`` is used. Each may be a list, a tuple or
        a tensor; missing span predictions are treated as ``0.0``.

        A token becomes a candidate when it is a word start, passes the
        validity filter, and matches one of these rules (lower = ranked first):

        1. its cleaned text is in the DSCD homograph set,
        2. span > ``_TRG_SPAN_THRESHOLD``,
        3. span > 0.08 and uncertainty > 0.3,
        4. uncertainty > ``max(_TRG_UNCERTAINTY_THRESHOLD, uncertainty_threshold)``.

        Candidates are ordered by (priority, -span, -uncertainty, index) and
        the first ``max_explanations`` are passed to
        ``generate_explanation_for_token``.

        Args:
            tokens: Subword tokens of the sentence.
            dscd_outputs: DSCD output dict; read keys are ``"uncertainties"``
                and ``"span_preds"`` (the whole dict is forwarded downstream).
            token_word_map: Optional token-index -> word mapping.
            uncertainty_threshold: Optional threshold; it can only raise, not
                lower, the module-level ``_TRG_UNCERTAINTY_THRESHOLD``.
            decoder_attention: Forwarded unchanged to the explanation step.
            max_explanations: Maximum number of candidates to explain.

        Returns:
            A list of dicts with keys ``token_idx``, ``token``, ``explanation``,
            ``uncertainty``, ``span``, ``dscd_discovered`` and ``priority``.
            Empty when inference is disabled, inputs are unusable, or an
            unexpected error occurs (this method is best-effort and never raises).
        """
        if not _ENABLE_TRG_INFERENCE:
            return []

        if uncertainty_threshold is None:
            uncertainty_threshold = float(_TRG_UNCERTAINTY_THRESHOLD)
        strict_uncertainty = max(_TRG_UNCERTAINTY_THRESHOLD, uncertainty_threshold)

        def _first_row(batched: Any) -> Any:
            # Return batch row 0, or None when the container is absent/empty.
            # Uses len() instead of truthiness: `not tensor` raises for tensors
            # with more than one element, which previously aborted the whole
            # method through the outer best-effort handler.
            if batched is None:
                return None
            try:
                if len(batched) == 0:
                    return None
                return batched[0]
            except (TypeError, IndexError, KeyError):
                return None

        explanations: List[Dict] = []

        try:
            if not tokens or not isinstance(tokens, list):
                return explanations

            if not isinstance(dscd_outputs, dict) or not dscd_outputs:
                return explanations

            uncertainty_row = _first_row(dscd_outputs.get("uncertainties"))
            if uncertainty_row is None:
                return explanations

            U = self._to_list_helper(uncertainty_row)
            if not U:
                return explanations

            span_row = _first_row(dscd_outputs.get("span_preds"))
            S = self._to_list_helper(span_row) if span_row is not None else []
            # Missing span predictions count as 0.0 for the remaining positions.
            if len(S) < len(U):
                S.extend([0.0] * (len(U) - len(S)))

            dscd_homographs = self.evidence_extractor.get_homograph_tokens_from_dscd()

            candidates: List[Tuple[int, float, float, str, int, int]] = []

            for idx in range(min(len(tokens), len(U))):
                tok = tokens[idx]
                clean_tok = tok.replace("▁", "").replace("Ġ", "").strip()

                if not _is_word_start(tok, token_word_map, idx):
                    self.stats["tokens_filtered_word_start"] += 1
                    continue

                if _has_is_valid_token:
                    try:
                        valid = is_valid_token(
                            tok,
                            self.special_tokens,
                            self.tokenizer,
                            language=self.language,
                        )
                    except Exception:
                        valid = _fallback_is_valid_token(
                            tok, self.special_tokens, self.tokenizer, self.language
                        )
                else:
                    valid = _fallback_is_valid_token(
                        tok, self.special_tokens, self.tokenizer, self.language
                    )

                if not valid:
                    self.stats["tokens_filtered_validity"] += 1
                    continue

                # idx < len(U) <= len(S) is guaranteed by the loop bound and padding.
                u = float(U[idx])
                s = float(S[idx])

                if clean_tok in dscd_homographs:
                    priority = 1
                elif s > _TRG_SPAN_THRESHOLD:
                    priority = 2
                elif s > 0.08 and u > 0.3:
                    priority = 3
                elif u > strict_uncertainty:
                    priority = 4
                else:
                    self.stats["tokens_filtered_ambiguity"] += 1
                    continue

                candidates.append((idx, u, s, clean_tok, priority, idx))

            if not candidates:
                return explanations

            candidates.sort(key=lambda t: (t[4], -t[2], -t[1], t[5]))

            for (token_idx, u, s, clean_tok, priority, _) in candidates[
                :max_explanations
            ]:
                try:
                    explanation_text, evidence = self.generate_explanation_for_token(
                        token_idx,
                        tokens,
                        dscd_outputs,
                        token_word_map=token_word_map,
                        decoder_attention=decoder_attention,
                        is_dscd_homograph=(priority == 1),
                    )
                    if not (explanation_text and evidence):
                        continue

                    # Prefer the mapped word; fall back to the de-prefixed subword.
                    if token_word_map and token_idx in token_word_map:
                        display_token = token_word_map[token_idx]
                    else:
                        display_token = (
                            tokens[token_idx].replace("▁", "").replace("Ġ", "")
                        )

                    explanations.append(
                        {
                            "token_idx": token_idx,
                            "token": display_token,
                            "explanation": explanation_text,
                            "uncertainty": u,
                            "span": s,
                            "dscd_discovered": (priority == 1),
                            "priority": priority,
                        }
                    )
                except Exception:
                    # One failing token must not block the remaining candidates.
                    continue

        except Exception:
            # Deliberate best-effort: explanation failures never break the caller.
            pass

        return explanations

    def get_statistics(self) -> Dict:
        """Return a snapshot of the raw counters plus derived rates.

        Rates are divided by ``max(explanations_generated, 1)`` so the call is
        safe before any explanation exists; ``avg_evidence_tokens`` is ``0.0``
        in that case. ``silver_buffer_size`` is the current buffer length.
        """
        with self._stats_lock:
            counters = self.stats
            generated = counters["explanations_generated"]
            denominator = max(generated, 1)

            avg_evidence_tokens = (
                counters["total_evidence_tokens"] / denominator
                if generated > 0
                else 0.0
            )

            report = dict(counters)
            report["high_confidence_rate"] = (
                counters["high_confidence_explanations"] / denominator
            )
            report["low_confidence_rate"] = (
                counters["low_confidence_explanations"] / denominator
            )
            report["empty_evidence_rate"] = (
                counters["empty_evidence_count"] / denominator
            )
            report["avg_evidence_tokens"] = avg_evidence_tokens
            report["silver_buffer_size"] = len(self.silver_buffer)
            report["dscd_homograph_rate"] = (
                counters["dscd_homographs_explained"] / denominator
            )
            return report

    def reset_statistics(self) -> None:
        """Replace ``self.stats`` with a fresh dict in which every counter is 0."""
        counter_names = (
            "explanations_generated",
            "high_confidence_explanations",
            "low_confidence_explanations",
            "empty_evidence_count",
            "total_evidence_tokens",
            "tokens_filtered_word_start",
            "tokens_filtered_validity",
            "tokens_filtered_ambiguity",
            "dscd_homographs_explained",
        )
        zeroed = dict.fromkeys(counter_names, 0)
        with self._stats_lock:
            self.stats = zeroed

    def clear_silver_buffer(self) -> None:
        """Empty the silver buffer in place while holding ``_silver_lock``."""
        lock = self._silver_lock
        with lock:
            buffer = self.silver_buffer
            buffer.clear()

    def test_trg(self, tokenizer=None) -> bool:
        """Run a smoke test of the explanation pipeline on a fixed Bengali sentence.

        The test feeds five hand-written tokens with synthetic DSCD outputs to
        ``process_sentence_for_explanations``, prints what was produced, then
        checks that ``reset_statistics`` zeroes ``explanations_generated``.

        Side effects: calls ``self.eval()``, resets the statistics counters,
        and prints progress. When ``_ENABLE_TRG_INFERENCE`` is off it is
        switched on for the duration of the test and restored afterwards.

        Args:
            tokenizer: Unused; kept for call compatibility.

        Returns:
            ``True`` when the run completes without an exception, else ``False``.
        """
        # Function-scope import: `traceback` is not imported by this cell, so
        # the failure path previously died on a (swallowed) NameError.
        import traceback

        global _ENABLE_TRG_INFERENCE

        print("\n" + "=" * 60)
        print("[TRG-TEST] Testing")
        print("=" * 60)

        inference_was_enabled = _ENABLE_TRG_INFERENCE
        if not inference_was_enabled:
            print("TRG inference disabled, enabling for test...")
            # Actually enable it; otherwise the pipeline returns [] immediately
            # and the test passes without exercising anything.
            _ENABLE_TRG_INFERENCE = True

        try:
            tokens = ["▁আমি", "▁কল", "▁বন্ধ", "▁করেছি", "।"]

            dscd_outputs = {
                "proto_probs": [[torch.tensor([0.6, 0.4]) for _ in tokens]],
                "uncertainties": [[0.1, 0.5, 0.2, 0.1, 0.0]],
                "span_preds": [[0.05, 0.3, 0.1, 0.05, 0.0]],
                "gates": [[0.2, 0.8, 0.3, 0.2, 0.0]],
            }

            token_word_map = {
                0: "আমি",
                1: "কল",
                2: "বন্ধ",
                3: "করেছি",
                4: "।",
            }

            self.eval()

            explanations = self.process_sentence_for_explanations(
                tokens=tokens,
                dscd_outputs=dscd_outputs,
                token_word_map=token_word_map,
                max_explanations=3,
            )

            print(f"  Generated {len(explanations)} explanations")

            for i, expl in enumerate(explanations, 1):
                print(
                    f"    {i}. '{expl['token']}' (u={expl['uncertainty']:.2f}, s={expl['span']:.2f})"
                )

            stats = self.get_statistics()
            print(f"  Stats: {stats['explanations_generated']} total")

            self.reset_statistics()
            stats_after = self.get_statistics()
            # Explicit raise instead of `assert`, which is stripped under -O.
            if stats_after["explanations_generated"] != 0:
                raise AssertionError(
                    "reset_statistics did not zero explanations_generated"
                )
            print("  Reset OK")

            print("\nAll tests passed")
            print("=" * 60 + "\n")
            return True

        except Exception as e:
            print(f"\nTest failed: {e}")
            try:
                traceback.print_exc()
            except Exception:
                pass
            print("=" * 60 + "\n")
            return False

        finally:
            # Restore the caller's configuration no matter how the test ended.
            _ENABLE_TRG_INFERENCE = inference_was_enabled


# Startup banner for the TRG cell: the static part is emitted in one call,
# the config values line by line so a missing constant still shows what
# preceded it.
print(
    "\n".join(
        (
            "\n" + "=" * 80,
            "Cell 5: TRG Ready (DATA-DRIVEN) - FIXED",
            "=" * 80,
            "FIXES APPLIED:",
            " F1:  Span extraction handles [B,T], [T], [B,T,1] with squeeze",
            " F2:  Uncertainty/gate extraction handles 3D tensors",
            " F3:  Span threshold lowered 0.20 → 0.12 (matches DSCD fix)",
            " F4:  Priority system: DSCD(1) > span>0.12(2) > span>0.08&u>0.3(3) > u>thresh(4)",
            " F5:  Temperature applied BEFORE softmax (log-space scaling)",
            " F6:  _to_list_helper handles [B,T] by taking first batch",
            " F7:  Added try-except around attention vector indexing",
            " F8:  Fallback spans inferred from uncertainty when extraction returns 0.0",
            " F9:  Combined threshold logic: span OR (uncertainty AND moderate-span)",
            " F10: Safe normalization when probs don't sum to 1.0",
            " F11: CRITICAL - Removed training check that blocked explanations",
            " F12: Fixed inference mode to work in both train and eval",
            " F13: Now generates explanations during training for buffer/stats",
            " F14: Removed duplicate training check in process_sentence",
            "",
            "Config:",
        )
    )
)
print(f"  - Uncertainty: {_TRG_UNCERTAINTY_THRESHOLD:.2f}")
print(f"  - Span: {_TRG_SPAN_THRESHOLD:.2f}")
print(f"  - Temperature: {_TRG_TEMPERATURE:.2f}")
print(f"  - TAU_HIGH: {_TAU_HIGH:.2f}")
print(f"  - TAU_LOW: {_TAU_LOW:.2f}")
print("=" * 80 + "\n")


In [None]:
# ==============================================================================
# CELL 6: TATN MODEL (COMPLETE INTEGRATION - FIXED)
# ==============================================================================

from typing import List, Dict, Optional, Any, Tuple
import traceback
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import M2M100ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
import threading
import gc
import time

# Language codes come from the notebook-level SOURCE_LANGUAGE / TARGET_LANGUAGE
# when both are usable; otherwise both fall back together to "bn" / "en".
try:
    _SOURCE_LANGUAGE, _TARGET_LANGUAGE = str(SOURCE_LANGUAGE), str(TARGET_LANGUAGE)
except (NameError, TypeError):
    _SOURCE_LANGUAGE, _TARGET_LANGUAGE = "bn", "en"

def _get_int_global(name: str, default: int) -> int:
    """Read module-level global ``name`` as an ``int``.

    Args:
        name: Name of the global to look up.
        default: Value returned when the global is missing, ``None``, or
            cannot be converted.

    Returns:
        ``int(value)`` on success, otherwise ``default``.
    """
    val = globals().get(name)
    if val is None:
        return default
    try:
        return int(val)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: int(float("inf")) - previously escaped to the caller.
        return default

def _get_float_global(name: str, default: float) -> float:
    """Read module-level global ``name`` as a ``float``.

    Returns ``default`` when the global is missing, ``None``, or ``float()``
    rejects it with ``ValueError``/``TypeError``.
    """
    candidate = globals().get(name)
    if candidate is None:
        return default
    try:
        return float(candidate)
    except (ValueError, TypeError):
        return default

def _get_bool_global(name: str, default: bool) -> bool:
    """Read module-level global ``name`` as a ``bool``.

    Strings spelling a negative ("0", "false", "no", "off"; case and
    surrounding whitespace ignored) are read as ``False``. Plain ``bool()``
    would treat every non-empty string as ``True``, so a flag set to
    ``"false"`` could never disable anything. All other values keep ordinary
    truthiness.

    Args:
        name: Name of the global to look up.
        default: Value returned when the global is missing, ``None``, or its
            truth value cannot be determined.

    Returns:
        The interpreted flag, or ``default``.
    """
    val = globals().get(name)
    if val is None:
        return default
    if isinstance(val, str) and val.strip().lower() in ("0", "false", "no", "off"):
        return False
    try:
        return bool(val)
    except (ValueError, TypeError):
        # e.g. multi-element arrays whose truth value is ambiguous.
        return default

# ------------------------------------------------------------------------------
# Cell-local configuration snapshot.
# Every value is read from the notebook-level global of the same name (without
# the leading underscore) through the _get_*_global helpers, which fall back to
# the literal default when the global is missing, None, or not convertible.
# ------------------------------------------------------------------------------

# DSCD settings
_DSCD_BUFFER_SIZE = _get_int_global("DSCD_BUFFER_SIZE", 50)
_DSCD_MAX_PROTOS = _get_int_global("DSCD_MAX_PROTOS", 8)
_DSCD_N_MIN = _get_int_global("DSCD_N_MIN", 5)
_DSCD_DISPERSION_THRESHOLD = _get_float_global("DSCD_DISPERSION_THRESHOLD", 0.50)

# Feature switches and housekeeping
_ENABLE_ASBN_TRAINING = _get_bool_global("ENABLE_ASBN_TRAINING", True)
_ENABLE_TRG_INFERENCE = _get_bool_global("ENABLE_TRG_INFERENCE", True)
_MEMORY_CLEANUP_FREQUENCY = _get_int_global("MEMORY_CLEANUP_FREQUENCY", 2000)

# Default GPU count: visible CUDA devices, or 1 when CUDA is unavailable.
_NUM_GPUS = _get_int_global(
    "NUM_GPUS",
    torch.cuda.device_count() if torch.cuda.is_available() else 1,
)
_USE_GC = _get_bool_global("GRADIENT_CHECKPOINTING", False)
_DSCD_ENABLE_TRAINING_CLUSTERING = _get_bool_global(
    "DSCD_ENABLE_TRAINING_CLUSTERING", True
)

# Loss weights
_LAMBDA_ASBN = _get_float_global("LAMBDA_ASBN", 0.05)
_LAMBDA_DSCD = _get_float_global("LAMBDA_DSCD", 0.15)
_LAMBDA_DOMAIN = _get_float_global("LAMBDA_DOMAIN", 0.1)

# Logging / debug flags (same helper as every other flag in this block)
_VERBOSE_LOGGING = _get_bool_global("VERBOSE_LOGGING", False)
_DEBUG_DISCOVERY = _get_bool_global("DEBUG_DISCOVERY", False)
_DEBUG_TIMING = _get_bool_global("DEBUG_TIMING", False)

_PERIODIC_DISCOVERY_FREQUENCY = _get_int_global(
    "PERIODIC_DISCOVERY_FREQUENCY", 50
)
_VALIDATION_CHECK_INTERVAL = _get_int_global("VALIDATION_CHECK_INTERVAL", 500)

# Thresholds
_SPAN_THRESHOLD = _get_float_global("SPAN_THRESHOLD", 0.12)
_UNCERTAINTY_THRESHOLD = _get_float_global("UNCERTAINTY_THRESHOLD", 0.15)

# TAU_LOW is resolved once and reused as the TRG threshold's default.
_TAU_LOW = _get_float_global("TAU_LOW", 0.15)
_TRG_UNCERTAINTY_THRESHOLD = _get_float_global(
    "TRG_UNCERTAINTY_THRESHOLD", _TAU_LOW
)

# Domain labels
_TRAIN_DOMAIN = _get_int_global("TRAIN_DOMAIN", 0)
_TEST_DOMAIN = _get_int_global("TEST_DOMAIN", 1)
_USE_DOMAIN_LABELS = _get_bool_global("USE_DOMAIN_LABELS", True)

# Fallback language-token ids, used when M2M100_*_TOKEN_ID is not defined.
_M2M100_EN_TOKEN_ID = _get_int_global("M2M100_EN_TOKEN_ID", 128022)
_M2M100_BN_TOKEN_ID = _get_int_global("M2M100_BN_TOKEN_ID", 128025)

# True when an earlier cell defined reconstruct_word_spans.
_has_reconstruct_word_spans = "reconstruct_word_spans" in globals()

def _safe_get_last_hidden_state(enc_output):
    """Extract the hidden-state value from an encoder output of unknown form.

    Returns ``enc_output.last_hidden_state`` when that attribute exists,
    otherwise the first item of a non-empty list/tuple, otherwise ``None``.
    """
    if enc_output is None:
        return None

    missing = object()
    hidden = getattr(enc_output, "last_hidden_state", missing)
    if hidden is not missing:
        return hidden

    is_nonempty_sequence = (
        isinstance(enc_output, (list, tuple)) and len(enc_output) > 0
    )
    return enc_output[0] if is_nonempty_sequence else None

def _normalize_dscd_outputs(
    raw: Dict[str, Any],
    batch_size: int,
    seq_len: int,
    device: torch.device,
    embed_dim: int,
) -> Dict[str, Any]:
    """Coerce raw DSCD outputs into a fixed, fully-populated structure.

    The result always has these keys:

    * ``h_augmented``: tensor ``(batch_size, seq_len, embed_dim)``; zeros unless
      ``raw`` supplies a tensor of that shape or one reshapeable to it.
    * ``proto_probs``, ``uncertainties``, ``gates``, ``span_preds``:
      ``batch_size`` lists of ``seq_len`` tensors. Missing or unusable
      positions hold the neutral filler: ``tensor([1.0])`` for ``proto_probs``
      and a 0-d ``tensor(0.0)`` for the others.
    * ``proto_assignments``: ``batch_size`` long tensors (zeros of length
      ``seq_len`` by default; supplied rows are converted but not resized).

    Accepted inputs for the per-token keys are a ``[B, T]`` / ``[B, T, 1]``
    tensor, or a list of ``B`` rows where each row is a tensor or a list.
    Anything else (including a batch-size mismatch) leaves the defaults.
    Normalization is best-effort: a key that fails to convert keeps its
    default instead of raising.

    Args:
        raw: Raw DSCD output mapping; a non-dict yields pure defaults.
        batch_size: Expected number of batch rows.
        seq_len: Number of token positions to emit per row.
        device: Device for every tensor created or moved here.
        embed_dim: Last dimension of ``h_augmented``.

    Returns:
        A new dict with the keys described above.
    """
    list_keys = ("proto_probs", "uncertainties", "gates", "span_preds")

    def _filler(list_key: str) -> torch.Tensor:
        # Neutral value for a position with no usable data.
        if list_key == "proto_probs":
            return torch.tensor([1.0], device=device, dtype=torch.float32)
        return torch.tensor(0.0, device=device, dtype=torch.float32)

    def _filler_row(list_key: str) -> List[torch.Tensor]:
        return [_filler(list_key) for _ in range(seq_len)]

    def _scalar_cell(list_key: str, value: Any) -> torch.Tensor:
        # Wrap one scalar; proto_probs cells are 1-element vectors.
        if list_key == "proto_probs":
            return torch.tensor([float(value)], device=device, dtype=torch.float32)
        return torch.tensor(float(value), device=device, dtype=torch.float32)

    def _cell_from_tensor_row(list_key: str, row: torch.Tensor, t: int) -> torch.Tensor:
        try:
            if t >= row.size(0):
                return _filler(list_key)
            item = row[t]
            if item.numel() == 1:
                return _scalar_cell(list_key, item.item())
            return item.to(device)
        except Exception:
            # Was a 0-d 0.0 for every key; proto_probs now gets its own filler
            # so the cell shape matches the defaults and the padding.
            return _filler(list_key)

    def _cell_from_list_row(list_key: str, row: list, t: int) -> torch.Tensor:
        try:
            if t >= len(row):
                return _filler(list_key)
            item = row[t]
            if isinstance(item, torch.Tensor):
                return item.to(device)
            return torch.as_tensor(item, device=device, dtype=torch.float32)
        except Exception:
            return _filler(list_key)

    def _normalize_row(list_key: str, row: Any) -> List[torch.Tensor]:
        if isinstance(row, torch.Tensor):
            if row.ndim == 2 and row.size(1) == 1:
                row = row.squeeze(1)
            return [_cell_from_tensor_row(list_key, row, t) for t in range(seq_len)]
        if isinstance(row, list):
            return [_cell_from_list_row(list_key, row, t) for t in range(seq_len)]
        return _filler_row(list_key)

    def _normalize_field(list_key: str, val: Any) -> Optional[List[List[torch.Tensor]]]:
        # Returns the normalized batch, or None when the layout is unsupported.
        if isinstance(val, torch.Tensor):
            if val.ndim == 3 and val.size(2) == 1:
                val = val.squeeze(2)
            if val.ndim == 2 and val.size(0) == batch_size:
                # One bulk transfer instead of a .item() call per element.
                values = val[:, :seq_len].tolist()
                return [
                    [_scalar_cell(list_key, v) for v in row_values]
                    + [_filler(list_key) for _ in range(seq_len - len(row_values))]
                    for row_values in values
                ]
            return None
        if isinstance(val, list) and len(val) == batch_size:
            return [_normalize_row(list_key, row) for row in val]
        return None

    out: Dict[str, Any] = {
        "h_augmented": torch.zeros(
            batch_size, seq_len, embed_dim, device=device, dtype=torch.float32
        ),
    }
    for list_key in list_keys:
        out[list_key] = [_filler_row(list_key) for _ in range(batch_size)]
    out["proto_assignments"] = [
        torch.zeros(seq_len, dtype=torch.long, device=device)
        for _ in range(batch_size)
    ]

    if not isinstance(raw, dict):
        return out

    h = raw.get("h_augmented")
    if h is not None:
        try:
            expected_shape = (batch_size, seq_len, embed_dim)
            if isinstance(h, torch.Tensor) and h.shape == expected_shape:
                out["h_augmented"] = h.to(device)
            else:
                out["h_augmented"] = h.to(device).reshape(*expected_shape)
        except Exception:
            # Not a tensor, or not reshapeable: keep the zero default.
            pass

    for list_key in list_keys:
        val = raw.get(list_key)
        if val is None:
            continue
        try:
            normalized = _normalize_field(list_key, val)
        except Exception:
            normalized = None
        if normalized is not None:
            out[list_key] = normalized

    pa = raw.get("proto_assignments")
    if isinstance(pa, list) and len(pa) == batch_size:
        safe_pa = []
        for row in pa:
            try:
                if isinstance(row, torch.Tensor):
                    safe_pa.append(row.to(device).long())
                else:
                    safe_pa.append(
                        torch.tensor(row, dtype=torch.long, device=device)
                    )
            except Exception:
                safe_pa.append(
                    torch.zeros(seq_len, dtype=torch.long, device=device)
                )
        out["proto_assignments"] = safe_pa

    return out

class MemoryOptimizedTATNWithExplanations(nn.Module):
    def __init__(self, tokenizer):
        """Load the M2M100 backbone and attach the DSCD, ASBN and TRG components.

        Args:
            tokenizer: Tokenizer shared with the sub-modules. When it exposes
                ``get_lang_id`` or ``lang_code_to_id`` it is used to resolve the
                source/target language token ids; otherwise the module-level
                ``_M2M100_EN_TOKEN_ID`` / ``_M2M100_BN_TOKEN_ID`` constants are
                used.

        Raises:
            RuntimeError: If ``MemoryEfficientDSCDOnline`` is not present in
                ``globals()`` or its constructor fails. ASBN and TRG are
                optional: when they are missing or fail to build, no-op stubs
                are installed instead.
        """
        super().__init__()
        self.tokenizer = tokenizer

        # Step bookkeeping; forward() bumps global_step while holding the lock.
        self.global_step = 0
        self._step_lock = threading.Lock()
        self.last_discovery_step = 0
        self.last_validation_step = 0

        # The attribute is called "mbart" but the checkpoint loaded is M2M100.
        self.mbart = M2M100ForConditionalGeneration.from_pretrained(
            "facebook/m2m100_418M",
            torch_dtype=torch.float32,
            use_cache=False,
        )
        try:
            self.mbart.config.use_cache = False
        except Exception:
            # Best effort only: use_cache=False was already requested above.
            pass

        # ------------------------------------------------------------------
        # Language token ids
        # ------------------------------------------------------------------
        try:
            if hasattr(self.tokenizer, "get_lang_id"):
                en_token_id = self.tokenizer.get_lang_id(_TARGET_LANGUAGE)
                bn_token_id = self.tokenizer.get_lang_id(_SOURCE_LANGUAGE)
            elif hasattr(self.tokenizer, "lang_code_to_id"):
                en_token_id = self.tokenizer.lang_code_to_id.get(
                    _TARGET_LANGUAGE, _M2M100_EN_TOKEN_ID
                )
                bn_token_id = self.tokenizer.lang_code_to_id.get(
                    _SOURCE_LANGUAGE, _M2M100_BN_TOKEN_ID
                )
            else:
                en_token_id = _M2M100_EN_TOKEN_ID
                bn_token_id = _M2M100_BN_TOKEN_ID

            self.mbart.config.forced_bos_token_id = int(en_token_id)
            self.mbart.config.decoder_start_token_id = int(en_token_id)
            self.en_token_id = int(en_token_id)
            self.bn_token_id = int(bn_token_id)

            if _DEBUG_DISCOVERY:
                print(
                    f"[TATN-INIT] Language tokens: BN={bn_token_id}, EN={en_token_id}"
                )

        except Exception as e:
            if _DEBUG_DISCOVERY:
                print(f"[TATN-INIT] Failed to set language tokens: {e}")
            # NOTE(review): on this path only the instance attributes receive
            # the fallback ids; the model config is not updated here - confirm
            # that is intended.
            self.en_token_id = _M2M100_EN_TOKEN_ID
            self.bn_token_id = _M2M100_BN_TOKEN_ID

        try:
            if _USE_GC and hasattr(self.mbart, "gradient_checkpointing_enable"):
                self.mbart.gradient_checkpointing_enable()
        except Exception:
            # Gradient checkpointing is an optimisation; continue without it.
            pass

        embed_dim = int(getattr(self.mbart.config, "d_model", 1024))

        # ------------------------------------------------------------------
        # DSCD (required)
        # ------------------------------------------------------------------
        dscd_cls = globals().get("MemoryEfficientDSCDOnline", None)
        if not callable(dscd_cls):
            raise RuntimeError("MemoryEfficientDSCDOnline not found in globals()")
        try:
            self.dscd = dscd_cls(
                embeddim=embed_dim,
                tokenizer=tokenizer,
                buffersize=_DSCD_BUFFER_SIZE,
                maxprotos=_DSCD_MAX_PROTOS,
                nmin=_DSCD_N_MIN,
                language=_SOURCE_LANGUAGE,
                dispersion_threshold=_DSCD_DISPERSION_THRESHOLD,
                enable_training_clustering=_DSCD_ENABLE_TRAINING_CLUSTERING,
                max_clustering_points=500,
                max_candidates_per_step=1,
            )
        except Exception as e:
            # Chain the cause so the underlying traceback is not lost.
            raise RuntimeError(
                f"Failed to instantiate MemoryEfficientDSCDOnline: {e}"
            ) from e

        # ------------------------------------------------------------------
        # ASBN (optional, replaced by a no-op stub when unavailable)
        # ------------------------------------------------------------------
        class _StubASBN(nn.Module):
            """No-op ASBN: passes hidden states through and reports zero losses."""

            @staticmethod
            def _zero_for(h):
                dev = (
                    h.device
                    if isinstance(h, torch.Tensor)
                    else torch.device("cpu")
                )
                return torch.tensor(0.0, device=dev)

            def forward(self, h, domain_labels=None):
                return h, self._zero_for(h)

            def forward_with_grl_simplified(self, h, *args, **kwargs):
                zero = self._zero_for(h)
                return zero, zero, zero, zero

        asbn_module = None
        asbn_cls = globals().get("MemoryEfficientASBNModule", None)
        if callable(asbn_cls):
            try:
                asbn_module = asbn_cls(
                    embed_dim, tokenizer, language=_SOURCE_LANGUAGE
                )
            except Exception as e:
                if _DEBUG_DISCOVERY:
                    print(f"[TATN-INIT] ASBN init failed, using stub: {e}")
        elif _DEBUG_DISCOVERY:
            print("[TATN-INIT] MemoryEfficientASBNModule not found, using stub")
        self.asbn = asbn_module if asbn_module is not None else _StubASBN()

        # ------------------------------------------------------------------
        # TRG (optional, replaced by a no-op stub when unavailable)
        # ------------------------------------------------------------------
        class _StubTRG:
            """No-op TRG: never produces explanations."""

            def process_sentence_for_explanations(
                self,
                tokens,
                dscd_outputs,
                token_word_map=None,
                uncertainty_threshold=0.1,
                decoder_attention=None,
            ):
                return []

        trg_system = None
        trg_cls = globals().get("CompleteTRGWithExplanations", None)
        if callable(trg_cls):
            try:
                trg_system = trg_cls(
                    embed_dim,
                    tokenizer,
                    language=_SOURCE_LANGUAGE,
                    dscd_module=self.dscd,
                )
            except Exception as e:
                if _DEBUG_DISCOVERY:
                    print(f"[TATN-INIT] TRG init failed, using stub: {e}")
        elif _DEBUG_DISCOVERY:
            print("[TATN-INIT] CompleteTRGWithExplanations not found, using stub")
        self.trg_system = trg_system if trg_system is not None else _StubTRG()

        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            print("[TATN-INIT] Initialized MemoryOptimizedTATNWithExplanations:")
            print(f"  - Embed dim: {embed_dim}")
            print(f"  - Discovery frequency: {_PERIODIC_DISCOVERY_FREQUENCY}")
            print(f"  - Validation interval: {_VALIDATION_CHECK_INTERVAL}")
            print(f"  - Lambda ASBN: {_LAMBDA_ASBN}")
            print(f"  - Lambda DSCD: {_LAMBDA_DSCD}")
            print(f"  - Lambda Domain: {_LAMBDA_DOMAIN}")
            print(f"  - Span threshold: {_SPAN_THRESHOLD}")

    @staticmethod
    def _entropy_reg_from_proto_probs_static(
        proto_probs_list, gates_list=None, min_gate: float = 0.0
    ) -> torch.Tensor:
        if not proto_probs_list or not isinstance(proto_probs_list, list):
            return torch.tensor(0.0)

        dev = None
        for row in proto_probs_list:
            if isinstance(row, list):
                for p in row:
                    if isinstance(p, torch.Tensor):
                        dev = p.device
                        break
            if dev is not None:
                break

        if dev is None:
            return torch.tensor(0.0)

        total = torch.tensor(0.0, device=dev)
        count = 0

        for b, row in enumerate(proto_probs_list):
            if not isinstance(row, list):
                continue
            gl = gates_list[b] if (gates_list and b < len(gates_list)) else None
            for j, probs in enumerate(row):
                if not isinstance(probs, torch.Tensor) or probs.numel() == 0:
                    continue
                if gl and j < len(gl):
                    try:
                        if float(gl[j]) < min_gate:
                            continue
                    except Exception:
                        pass

                try:
                    p = torch.clamp(probs.to(dev).float(), 1e-8, 1.0)
                    H = -torch.sum(p * torch.log(p))
                    if torch.isfinite(H):
                        total = total + H
                        count += 1
                except Exception:
                    continue

        if count == 0:
            return torch.tensor(0.0, device=dev)
        return total / count

    def _reconstruct_word_maps_before_dscd(
        self,
        input_ids: torch.Tensor,
        batch_size: int,
        seq_len: int,
        src_texts: Optional[List[str]] = None,
        token_word_map: Optional[List[dict]] = None,
    ) -> List[dict]:
        """Produce one ``{token index: word}`` map per sample in the batch.

        Resolution order:
          1. ``token_word_map`` is returned untouched when it has exactly
             ``batch_size`` entries and every entry is a non-empty dict.
          2. If ``_has_reconstruct_word_spans`` is falsy, each map is built from
             the tokenizer's tokens, keeping pieces that are at least two
             characters long after the sub-word markers are stripped.
          3. Otherwise ``reconstruct_word_spans`` is run on the source text (or
             on the decoded ``input_ids`` when no usable text is supplied).

        Whenever a sample yields nothing, or raises, a placeholder map with up
        to five ``tok{i}`` entries is used for it.
        """
        subword_markers = ("▁", "Ġ", "##", "@@")

        def _strip_markers(text):
            # Remove the markers in a fixed order, then trim whitespace.
            for marker in subword_markers:
                text = text.replace(marker, "")
            return text.strip()

        def _placeholder_map():
            return {pos: f"tok{pos}" for pos in range(min(5, seq_len))}

        # 1. Caller-supplied maps win when every one of them is usable.
        if token_word_map is not None and len(token_word_map) == batch_size:
            if all(isinstance(m, dict) and len(m) > 0 for m in token_word_map):
                if _DEBUG_DISCOVERY:
                    total_words = sum(len(m) for m in token_word_map)
                    print(
                        f"[TATN-WORDMAP] Using provided word maps: {total_words} words"
                    )
                return token_word_map

        rebuilt: List[dict] = []

        # 2. Token-level fallback when the span reconstructor is unavailable.
        if not _has_reconstruct_word_spans:
            if _DEBUG_DISCOVERY:
                print(
                    "[TATN-WORDMAP] reconstruct_word_spans() not available - using fallback"
                )
            for sample_idx in range(batch_size):
                try:
                    piece_ids = input_ids[sample_idx].detach().cpu().tolist()
                    pieces = self.tokenizer.convert_ids_to_tokens(piece_ids)
                    token_words: Dict[int, str] = {}
                    for pos, piece in enumerate(pieces):
                        stripped = _strip_markers(piece)
                        if len(stripped) >= 2:
                            token_words[pos] = stripped
                    rebuilt.append(
                        token_words if token_words else _placeholder_map()
                    )
                except Exception:
                    rebuilt.append(_placeholder_map())
            return rebuilt

        # 3. Span reconstruction from text.
        if _DEBUG_DISCOVERY:
            print(f"[TATN-WORDMAP] Reconstructing word maps for {batch_size} samples...")

        for sample_idx in range(batch_size):
            try:
                provided = (
                    src_texts[sample_idx]
                    if src_texts and sample_idx < len(src_texts)
                    else None
                )
                if isinstance(provided, str) and provided.strip():
                    text = provided
                else:
                    # No usable source text: decode the ids instead.
                    try:
                        text = self.tokenizer.decode(
                            input_ids[sample_idx], skip_special_tokens=True
                        )
                    except Exception:
                        text = ""

                if not text.strip():
                    rebuilt.append(_placeholder_map())
                    continue

                spans, _ = reconstruct_word_spans(
                    self.tokenizer, text, max_length=seq_len
                )
                if not isinstance(spans, dict):
                    spans = {}

                cleaned: Dict[int, str] = {}
                for pos, word in spans.items():
                    if not isinstance(word, str) or not word.strip():
                        continue
                    stripped = _strip_markers(word)
                    if stripped:
                        cleaned[pos] = stripped

                rebuilt.append(cleaned if cleaned else _placeholder_map())

                if _DEBUG_DISCOVERY and sample_idx == 0:
                    print(
                        f"[TATN-WORDMAP] Sample 0: {len(cleaned)} word spans"
                    )

            except Exception as e:
                if _DEBUG_DISCOVERY:
                    print(
                        f"[TATN-WORDMAP] Reconstruction failed for sample {sample_idx}: {e}"
                    )
                rebuilt.append(_placeholder_map())

        total_words = sum(len(m) for m in rebuilt)
        if _DEBUG_DISCOVERY:
            print(f"[TATN-WORDMAP] Reconstructed {total_words} words")

        return rebuilt

    def _extract_domain_labels(
        self,
        batch_size: int,
        device: torch.device,
        src_texts: Optional[List[str]] = None,
    ) -> Optional[torch.Tensor]:
        """Build a ``(batch_size,)`` long tensor holding one domain id per sample.

        The id is ``_TRAIN_DOMAIN`` while the module is in training mode and
        ``_TEST_DOMAIN`` otherwise. ``src_texts`` is accepted but not used.

        Returns:
            The label tensor on ``device``, or ``None`` when
            ``_USE_DOMAIN_LABELS`` is falsy or building the tensor fails.
        """
        if not _USE_DOMAIN_LABELS:
            return None

        try:
            domain_id = _TRAIN_DOMAIN if self.training else _TEST_DOMAIN
            return torch.full(
                (batch_size,), domain_id, dtype=torch.long, device=device
            )
        except Exception:
            return None

    @staticmethod
    def _safe_take_key_static(
        dscd_struct: Dict[str, Any],
        key: str,
        b_index: int,
        seq_len: int,
        device: torch.device,
    ):
        if key == "proto_probs":
            out = [
                torch.tensor([1.0], dtype=torch.float32, device=device)
                for _ in range(seq_len)
            ]
        else:
            out = [
                torch.tensor(0.0, dtype=torch.float32, device=device)
                for _ in range(seq_len)
            ]

        try:
            val = dscd_struct.get(key, None)
            if val is None:
                return out

            if key == "proto_probs":
                if isinstance(val, list) and len(val) > b_index:
                    row = val[b_index]
                    if isinstance(row, list):
                        for t in range(min(seq_len, len(row))):
                            v = row[t]
                            if isinstance(v, torch.Tensor):
                                out[t] = v.to(device)
                            else:
                                try:
                                    out[t] = torch.as_tensor(
                                        v,
                                        dtype=torch.float32,
                                        device=device,
                                    ).flatten()
                                except Exception:
                                    pass
                return out

            if isinstance(val, list) and len(val) > b_index:
                row = val[b_index]
                if isinstance(row, list):
                    for t in range(min(seq_len, len(row))):
                        v = row[t]
                        try:
                            if isinstance(v, torch.Tensor):
                                out[t] = v.to(device)
                            else:
                                out[t] = torch.tensor(
                                    float(v), device=device
                                )
                        except Exception:
                            pass
                elif isinstance(row, torch.Tensor):
                    if row.ndim == 2 and row.size(1) == 1:
                        row = row.squeeze(1)
                    
                    if row.dim() == 1:
                        for t in range(min(seq_len, int(row.size(0)))):
                            try:
                                out[t] = torch.tensor(
                                    float(row[t].item()), device=device
                                )
                            except Exception:
                                pass
                return out

            if isinstance(val, torch.Tensor):
                if val.ndim == 3 and val.size(2) == 1:
                    val = val.squeeze(2)
                
                if val.dim() >= 2 and int(val.size(0)) > b_index:
                    for t in range(min(seq_len, int(val.size(1)))):
                        try:
                            v = val[b_index, t]
                            if isinstance(v, torch.Tensor) and v.numel() == 1:
                                out[t] = torch.tensor(
                                    float(v.item()), device=device
                                )
                            elif isinstance(v, torch.Tensor):
                                out[t] = v.to(device)
                            else:
                                out[t] = torch.tensor(
                                    float(v), device=device
                                )
                        except Exception:
                            pass
        except Exception:
            pass

        return out

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        src_texts: Optional[List[str]] = None,
        token_word_map: Optional[List[dict]] = None,
        labels: Optional[torch.Tensor] = None,
        use_dscd: bool = True,
        use_asbn: bool = True,
    ):
        with self._step_lock:
            self.global_step += 1
            current_step = self.global_step

        if input_ids is None or attention_mask is None:
            raise ValueError("input_ids and attention_mask cannot be None")
        if input_ids.dim() != 2 or attention_mask.dim() != 2:
            raise ValueError(
                f"Expected 2D tensors, got {input_ids.shape}, {attention_mask.shape}"
            )

        batch_size, seq_len = int(input_ids.size(0)), int(input_ids.size(1))
        device = input_ids.device

        if (
            torch.cuda.is_available()
            and _MEMORY_CLEANUP_FREQUENCY > 0
            and current_step % _MEMORY_CLEANUP_FREQUENCY == 0
        ):
            for i in range(min(_NUM_GPUS, torch.cuda.device_count())):
                try:
                    with torch.cuda.device(i):
                        torch.cuda.empty_cache()
                except Exception:
                    pass
            if gc.isenabled():
                gc.collect()

        if self.training and _DSCD_ENABLE_TRAINING_CLUSTERING and use_dscd:
            if (
                current_step - self.last_discovery_step
                >= _PERIODIC_DISCOVERY_FREQUENCY
            ):
                try:
                    print("\n" + "=" * 80)
                    print(f"[TATN-DISCOVERY] TRIGGER @ step {current_step}")
                    print("=" * 80)

                    num_buffers_before = len(self.dscd.buffers)
                    num_stores_before = len(self.dscd.prototype_stores)
                    total_protos_before = sum(s.size() for s in self.dscd.prototype_stores.values())
                    
                    print(f"[TATN-DISCOVERY] BEFORE:")
                    print(f"  - Buffers: {num_buffers_before}")
                    print(f"  - Stores: {num_stores_before}")
                    print(f"  - Prototypes: {total_protos_before}")

                    start_time = time.time()

                    self.dscd.periodic_discovery_check(
                        current_step, _PERIODIC_DISCOVERY_FREQUENCY
                    )

                    elapsed = time.time() - start_time
                    self.last_discovery_step = current_step

                    summary = self.dscd.get_prototype_summary()
                    total_protos_after = summary.get('total_prototypes', 0)
                    new_protos = total_protos_after - total_protos_before
                    
                    print(f"\n[TATN-DISCOVERY] AFTER ({elapsed:.2f}s):")
                    print(f"  - Homographs: {summary.get('num_homographs', 0)}")
                    print(f"  - Total prototypes: {total_protos_after}")
                    print(f"  - NEW prototypes created: {new_protos}")
                    
                    if new_protos == 0 and num_buffers_before > 0:
                        print(f"[TATN-DISCOVERY] ⚠ WARNING: No prototypes created despite {num_buffers_before} buffers!")
                        print(f"[TATN-DISCOVERY] Check: dispersion_threshold={_DSCD_DISPERSION_THRESHOLD}, nmin={_DSCD_N_MIN}")
                    
                    print("=" * 80 + "\n")

                except Exception as e:
                    print(f"[TATN-DISCOVERY] ❌ FAILED: {e}")
                    try:
                        traceback.print_exc()
                    except Exception:
                        pass

        if not self.training and _VALIDATION_CHECK_INTERVAL > 0:
            if (
                current_step - self.last_validation_step
                >= _VALIDATION_CHECK_INTERVAL
            ):
                try:
                    if _DEBUG_DISCOVERY:
                        print(f"\n[TATN-VALIDATION] Step {current_step}")
                        summary = self.dscd.get_prototype_summary()
                        print(f"  - Tokens: {summary.get('total_tokens', 0)}")
                        print(
                            f"  - Prototypes: {summary.get('total_prototypes', 0)}"
                        )
                        print(
                            f"  - Homographs: {summary.get('num_homographs', 0)}"
                        )
                    self.last_validation_step = current_step
                except Exception:
                    pass

        enc_outputs = None
        try:
            enc_outputs = self.mbart.model.encoder(
                input_ids=input_ids, attention_mask=attention_mask
            )
        except Exception:
            try:
                enc_outputs = self.mbart.get_encoder()(
                    input_ids=input_ids, attention_mask=attention_mask
                )
            except Exception as e:
                if _DEBUG_DISCOVERY:
                    print(f"[TATN] Encoder failed: {e}")
                enc_outputs = None

        h = _safe_get_last_hidden_state(enc_outputs)
        if h is None:
            try:
                emb = self.mbart.get_input_embeddings()(input_ids).to(device)
                h = emb
            except Exception:
                h = torch.zeros(
                    batch_size,
                    seq_len,
                    int(getattr(self.mbart.config, "d_model", 1024)),
                    device=device,
                )

        embed_dim = int(h.size(-1))
        training_mode = labels is not None and self.training

        token_word_map = self._reconstruct_word_maps_before_dscd(
            input_ids, batch_size, seq_len, src_texts, token_word_map
        )

        domain_labels = self._extract_domain_labels(
            batch_size, device, src_texts
        )

        if use_dscd:
            try:
                raw_dscd = self.dscd.forward(
                    h,
                    token_types=None,
                    train_mode=self.training,
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    token_word_map=token_word_map,
                )
            except Exception as e:
                if _DEBUG_DISCOVERY:
                    print(f"[TATN] DSCD forward failed: {e}")
                raw_dscd = {
                    "h_augmented": h.detach().clone(),
                    "proto_probs": [
                        [
                            torch.tensor(
                                [1.0],
                                dtype=torch.float32,
                                device=device,
                            )
                            for _ in range(seq_len)
                        ]
                        for _ in range(batch_size)
                    ],
                    "uncertainties": [
                        [
                            torch.tensor(0.0, device=device)
                            for _ in range(seq_len)
                        ]
                        for _ in range(batch_size)
                    ],
                    "gates": [
                        [
                            torch.tensor(0.0, device=device)
                            for _ in range(seq_len)
                        ]
                        for _ in range(batch_size)
                    ],
                    "span_preds": [
                        [
                            torch.tensor(0.0, device=device)
                            for _ in range(seq_len)
                        ]
                        for _ in range(batch_size)
                    ],
                    "proto_assignments": [
                        torch.zeros(
                            seq_len, dtype=torch.long, device=device
                        )
                        for _ in range(batch_size)
                    ],
                }
        else:
            raw_dscd = {
                "h_augmented": h.detach().clone(),
                "proto_probs": [
                    [
                        torch.tensor(
                            [1.0], dtype=torch.float32, device=device
                        )
                        for _ in range(seq_len)
                    ]
                    for _ in range(batch_size)
                ],
                "uncertainties": [
                    [
                        torch.tensor(0.0, device=device)
                        for _ in range(seq_len)
                    ]
                    for _ in range(batch_size)
                ],
                "gates": [
                    [
                        torch.tensor(0.0, device=device)
                        for _ in range(seq_len)
                    ]
                    for _ in range(batch_size)
                ],
                "span_preds": [
                    [
                        torch.tensor(0.0, device=device)
                        for _ in range(seq_len)
                    ]
                    for _ in range(batch_size)
                ],
                "proto_assignments": [
                    torch.zeros(seq_len, dtype=torch.long, device=device)
                    for _ in range(batch_size)
                ],
            }

        dscd = _normalize_dscd_outputs(
            raw_dscd, batch_size, seq_len, device, embed_dim
        )
        h_aug = dscd.get("h_augmented", h)

        if not isinstance(h_aug, torch.Tensor) or h_aug.shape != h.shape:
            if _DEBUG_DISCOVERY:
                print(
                    f"[TATN] h_augmented shape mismatch "
                    f"(expected {h.shape}, got {getattr(h_aug, 'shape', None)})"
                )
            h_aug = h

        domain_loss_from_forward = torch.tensor(0.0, device=device)
        
        if use_asbn and domain_labels is not None:
            try:
                h_aug, domain_loss_from_forward = self.asbn.forward(
                    h_aug, domain_labels=domain_labels
                )
                
                if not isinstance(domain_loss_from_forward, torch.Tensor):
                    domain_loss_from_forward = torch.tensor(0.0, device=device)
                else:
                    domain_loss_from_forward = domain_loss_from_forward.to(device)
                
                if not torch.isfinite(domain_loss_from_forward):
                    domain_loss_from_forward = torch.tensor(0.0, device=device)
                
            except Exception as e:
                if _DEBUG_DISCOVERY:
                    print(f"[TATN] ASBN forward (BN+domain) failed: {e}")
                domain_loss_from_forward = torch.tensor(0.0, device=device)

        try:
            enc_for_decoder = BaseModelOutput(
                last_hidden_state=h_aug,
                hidden_states=(
                    getattr(enc_outputs, "hidden_states", None)
                    if enc_outputs
                    else None
                ),
                attentions=(
                    getattr(enc_outputs, "attentions", None)
                    if enc_outputs
                    else None
                ),
            )
        except Exception:
            enc_for_decoder = (h_aug,)

        if training_mode:
            try:
                if labels is not None:
                    if labels.size(1) < 2:
                        if _DEBUG_DISCOVERY:
                            print("[TATN] Labels too short for decoder_input_ids construction")
                        decoder_input_ids = None
                        decoder_attention_mask = None
                    else:
                        decoder_input_ids = labels.clone()
                        decoder_input_ids = torch.where(
                            decoder_input_ids == -100,
                            torch.full_like(
                                decoder_input_ids,
                                self.tokenizer.pad_token_id,
                            ),
                            decoder_input_ids,
                        )

                        bos_column = torch.full(
                            (batch_size, 1),
                            int(self.mbart.config.decoder_start_token_id),
                            dtype=torch.long,
                            device=device,
                        )
                        decoder_input_ids = torch.cat(
                            [bos_column, decoder_input_ids[:, :-1]], dim=1
                        )
                        decoder_attention_mask = (
                            decoder_input_ids != self.tokenizer.pad_token_id
                        ).long()
                else:
                    decoder_input_ids = None
                    decoder_attention_mask = None

                seq_outputs = self.mbart(
                    input_ids=None,
                    attention_mask=attention_mask,
                    encoder_outputs=enc_for_decoder,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    labels=labels,
                    use_cache=False,
                    return_dict=True,
                )
                translation_loss = getattr(seq_outputs, "loss", None)
                if translation_loss is None:
                    translation_loss = torch.tensor(0.0, device=device)
            except Exception as e:
                if _DEBUG_DISCOVERY:
                    print(f"[TATN] Decoder forward failed: {e}")
                    try:
                        traceback.print_exc()
                    except Exception:
                        pass
                translation_loss = torch.tensor(0.0, device=device)

            encoder_loss_from_grl = torch.tensor(0.0, device=device)
            adversarial_loss = torch.tensor(0.0, device=device)
            domain_loss_from_grl = torch.tensor(0.0, device=device)
            domain_accuracy = torch.tensor(0.0, device=device)
            
            if use_asbn:
                try:
                    asbn_ret = self.asbn.forward_with_grl_simplified(
                        h_aug,
                        dscd.get("proto_probs", None),
                        dscd.get("uncertainties", None),
                        dscd.get("gates", None),
                        token_word_map=token_word_map,
                        domain_labels=domain_labels,
                        global_step=current_step,
                    )
                    
                    if isinstance(asbn_ret, (tuple, list)) and len(asbn_ret) >= 4:
                        encoder_loss_from_grl = asbn_ret[0]
                        adversarial_loss = asbn_ret[1]
                        domain_loss_from_grl = asbn_ret[2]
                        domain_accuracy = asbn_ret[3]
                    elif isinstance(asbn_ret, (tuple, list)):
                        encoder_loss_from_grl = asbn_ret[0]
                    else:
                        encoder_loss_from_grl = asbn_ret
                    
                    if not isinstance(encoder_loss_from_grl, torch.Tensor):
                        encoder_loss_from_grl = torch.tensor(float(encoder_loss_from_grl), device=device)
                    else:
                        encoder_loss_from_grl = encoder_loss_from_grl.to(device)
                    
                    if not torch.isfinite(encoder_loss_from_grl):
                        encoder_loss_from_grl = torch.tensor(0.0, device=device)
                    
                    encoder_loss_from_grl = torch.clamp(encoder_loss_from_grl, 0.0, 10.0)
                    
                    if _DEBUG_DISCOVERY and current_step % 100 == 0:
                        print(f"[TATN-ASBN] Step {current_step}:")
                        print(f"  Domain loss (forward): {domain_loss_from_forward.item():.4f}")
                        print(f"  Domain loss (GRL): {domain_loss_from_grl.item():.4f}")
                        print(f"  Domain accuracy: {domain_accuracy.item():.2%}")
                        print(f"  Encoder loss: {encoder_loss_from_grl.item():.4f}")
                    
                except Exception as e:
                    if _DEBUG_DISCOVERY:
                        print(f"[TATN] ASBN forward_with_grl_simplified failed: {e}")
                    encoder_loss_from_grl = torch.tensor(0.0, device=device)

            try:
                dscd_reg = self._entropy_reg_from_proto_probs_static(
                    dscd.get("proto_probs", []),
                    gates_list=dscd.get("gates", []),
                    min_gate=0.0,
                )
                if not isinstance(dscd_reg, torch.Tensor):
                    dscd_reg = torch.tensor(
                        float(dscd_reg), device=device
                    )
                else:
                    dscd_reg = dscd_reg.to(device)
                if not torch.isfinite(dscd_reg):
                    dscd_reg = torch.tensor(0.0, device=device)
            except Exception as e:
                if _DEBUG_DISCOVERY:
                    print(f"[TATN] DSCD reg failed: {e}")
                dscd_reg = torch.tensor(0.0, device=device)

            total_loss = (
                translation_loss
                + _LAMBDA_ASBN * encoder_loss_from_grl
                + _LAMBDA_DOMAIN * domain_loss_from_grl
                + _LAMBDA_DSCD * dscd_reg
            )
            
            if not isinstance(total_loss, torch.Tensor):
                total_loss = torch.tensor(float(total_loss), device=device)
            if total_loss.numel() != 1:
                total_loss = total_loss.mean()

            if not torch.isfinite(total_loss):
                if _DEBUG_DISCOVERY:
                    print(
                        "[TATN] NaN/Inf detected in total_loss - "
                        "using translation_loss only"
                    )
                total_loss = (
                    translation_loss
                    if torch.isfinite(translation_loss)
                    else torch.tensor(1.0, device=device)
                )

            try:
                del enc_outputs, h, raw_dscd
            except Exception:
                pass
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return total_loss

        explanations_list: List[List[Dict[str, Any]]] = []

        if (not self.training) and _ENABLE_TRG_INFERENCE:
            if _DEBUG_DISCOVERY:
                print(
                    f"\n[TATN-INFERENCE] Starting TRG for {batch_size} samples"
                )

            tokens_batch: List[List[str]] = []

            for b in range(batch_size):
                try:
                    ids_b = input_ids[b].detach().cpu().tolist()
                    if hasattr(self.tokenizer, "convert_ids_to_tokens"):
                        toks = self.tokenizer.convert_ids_to_tokens(ids_b)
                    else:
                        toks = []
                    if not toks:
                        toks = ["UNK"] * seq_len
                    elif len(toks) < seq_len:
                        toks = toks + [""] * (seq_len - len(toks))
                    elif len(toks) > seq_len:
                        toks = toks[:seq_len]
                except Exception:
                    toks = ["UNK"] * seq_len

                tokens_batch.append(toks)

            decoder_attention = None

            try:
                total_explanations = 0

                for b in range(batch_size):
                    per_sent = {
                        "proto_probs": [self._safe_take_key_static(
                            dscd, "proto_probs", b, seq_len, device
                        )],
                        "uncertainties": [self._safe_take_key_static(
                            dscd, "uncertainties", b, seq_len, device
                        )],
                        "gates": [self._safe_take_key_static(
                            dscd, "gates", b, seq_len, device
                        )],
                        "span_preds": [self._safe_take_key_static(
                            dscd, "span_preds", b, seq_len, device
                        )],
                    }

                    try:
                        exps = self.trg_system.process_sentence_for_explanations(
                            tokens_batch[b],
                            per_sent,
                            token_word_map=(
                                token_word_map[b]
                                if token_word_map
                                and b < len(token_word_map)
                                else None
                            ),
                            uncertainty_threshold=_TRG_UNCERTAINTY_THRESHOLD,
                            decoder_attention=decoder_attention,
                        )
                        batch_exps = exps if isinstance(exps, list) else []
                        explanations_list.append(batch_exps)
                        total_explanations += len(batch_exps)

                        if _DEBUG_DISCOVERY and b < 2:
                            print(
                                f"[TATN-INFERENCE] Sample {b}: "
                                f"{len(batch_exps)} explanations"
                            )
                            if len(batch_exps) == 0:
                                print(f"  U: {[float(u) for u in per_sent['uncertainties'][0][:5]]}")
                                print(f"  S: {[float(s) for s in per_sent['span_preds'][0][:5]]}")

                    except Exception as e:
                        if _DEBUG_DISCOVERY:
                            print(
                                f"[TATN-INFERENCE] TRG failed for sample {b}: {e}"
                            )
                            try:
                                traceback.print_exc()
                            except Exception:
                                pass
                        explanations_list.append([])

                if _DEBUG_DISCOVERY:
                    print(
                        f"\n[TATN-INFERENCE] Total explanations: {total_explanations}"
                    )
                    if total_explanations == 0:
                        print("[TATN-INFERENCE] NO EXPLANATIONS GENERATED - check thresholds")

            except Exception as e:
                if _DEBUG_DISCOVERY:
                    print(f"[TATN-INFERENCE] TRG generation failed: {e}")
                    try:
                        traceback.print_exc()
                    except Exception:
                        pass
                explanations_list = [[] for _ in range(batch_size)]
        else:
            explanations_list = [[] for _ in range(batch_size)]

        outputs = {
            "encoder_outputs": enc_outputs,
            "dscd_outputs": dscd,
            "sense_augmented_embeddings": h_aug,
            "explanations": explanations_list,
            "asbn_loss": torch.tensor(0.0, device=device),
            "ambiguity_signals": {
                "span": dscd.get("span_preds", []),
                "uncertainty": dscd.get("uncertainties", []),
                "confidence": [
                    [
                        1.0
                        - (
                            float(u)
                            if isinstance(u, (float, int))
                            else (
                                float(u.item())
                                if isinstance(u, torch.Tensor)
                                else 1.0
                            )
                        )
                        for u in row
                    ]
                    for row in dscd.get("uncertainties", [])
                ],
                "proto_probs": dscd.get("proto_probs", []),
            },
        }

        try:
            del h, raw_dscd
        except Exception:
            pass
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return outputs

    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        max_length: int = 128,
        num_beams: int = 5,
        early_stopping: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Generate output ids by encoding once and decoding from that result.

        The encoder of ``self.mbart`` is called explicitly, its output is
        re-wrapped in a ``BaseModelOutput`` and passed to
        ``self.mbart.generate`` as ``encoder_outputs`` (``input_ids`` is
        passed as ``None`` there).

        Args:
            input_ids: Source token ids given to the encoder.
            attention_mask: Mask given to both the encoder and ``generate``.
            max_length: Forwarded to ``self.mbart.generate``.
            num_beams: Forwarded to ``self.mbart.generate``.
            early_stopping: Forwarded to ``self.mbart.generate``.
            **kwargs: Additional generation options, forwarded as-is. A
                ``forced_bos_token_id`` given here overrides the value stored
                in ``self.mbart.config``.

        Returns:
            Whatever ``self.mbart.generate`` returns.

        Raises:
            Exception: Any error raised by the encoder or by ``generate`` is
                re-raised unchanged (after an optional debug print).
        """
        try:
            enc_outputs = self.mbart.model.encoder(
                input_ids=input_ids, attention_mask=attention_mask
            )

            # The encoder may return a ModelOutput or a plain tuple.
            enc_wrapped = BaseModelOutput(
                last_hidden_state=(
                    enc_outputs.last_hidden_state
                    if hasattr(enc_outputs, "last_hidden_state")
                    else enc_outputs[0]
                ),
                hidden_states=getattr(enc_outputs, "hidden_states", None),
                attentions=getattr(enc_outputs, "attentions", None),
            )

            # Resolve the forced BOS id exactly once: an explicit caller value
            # wins, otherwise fall back to the model config. Popping it from
            # kwargs avoids a "multiple values for keyword argument" TypeError,
            # and skipping it when unset avoids int(None).
            forced_bos = kwargs.pop("forced_bos_token_id", None)
            if forced_bos is None:
                forced_bos = getattr(
                    self.mbart.config, "forced_bos_token_id", None
                )
            if forced_bos is not None:
                kwargs["forced_bos_token_id"] = int(forced_bos)

            return self.mbart.generate(
                input_ids=None,
                attention_mask=attention_mask,
                encoder_outputs=enc_wrapped,
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=early_stopping,
                **kwargs,
            )
        except Exception as e:
            if _DEBUG_DISCOVERY:
                print(f"[TATN-GENERATE] Failed: {e}")
            raise

    def forward_with_explanations(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        src_texts: Optional[List[str]] = None,
        token_word_map: Optional[List[dict]] = None,
    ):
        """Call ``self.forward`` with the given inputs and ``labels=None``.

        Args:
            input_ids: Passed through to ``forward``.
            attention_mask: Passed through to ``forward``.
            src_texts: Passed through to ``forward``.
            token_word_map: Passed through to ``forward``.

        Returns:
            The value returned by ``self.forward``.
        """
        forward_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "src_texts": src_texts,
            "token_word_map": token_word_map,
            # No targets are supplied on this entry point.
            "labels": None,
        }
        return self.forward(**forward_kwargs)

    def get_component_stats(self) -> Dict[str, Any]:
        """Gather step counters and whatever statistics the components expose.

        The three step counters are always included. For each of the DSCD,
        ASBN and TRG components, its reporting method is called if it exists;
        a missing component, a missing method, or a failing call simply
        leaves that key out of the result.

        Returns:
            Dict with ``global_step``, ``last_discovery_step`` and
            ``last_validation_step``, plus optional ``dscd``, ``asbn`` and
            ``trg`` entries.
        """
        stats: Dict[str, Any] = {
            "global_step": self.global_step,
            "last_discovery_step": self.last_discovery_step,
            "last_validation_step": self.last_validation_step,
        }

        # (result key, attribute holding the component, reporting method)
        probes = (
            ("dscd", "dscd", "get_prototype_summary"),
            ("asbn", "asbn", "get_detailed_stats"),
            ("trg", "trg_system", "get_statistics"),
        )
        for result_key, component_attr, method_name in probes:
            try:
                component = getattr(self, component_attr)
                if hasattr(component, method_name):
                    stats[result_key] = getattr(component, method_name)()
            except Exception:
                # Statistics are best-effort; skip anything that fails.
                continue

        return stats

# Cell 6 readiness banner: the same lines as before, emitted through a single
# print call (each element is one output line).
print(
    "\n".join(
        (
            "\n" + "=" * 80,
            "Cell 6: TATN Ready (DISCOVERY FREQUENCY FIXED)",
            "=" * 80,
            "",
            "Config:",
            f"  - Discovery frequency: {_PERIODIC_DISCOVERY_FREQUENCY} steps ← FIXED!",
            f"  - Span threshold: {_SPAN_THRESHOLD:.2f}",
            f"  - Uncertainty threshold: {_UNCERTAINTY_THRESHOLD:.2f}",
            f"  - TRG uncertainty: {_TRG_UNCERTAINTY_THRESHOLD:.2f}",
            f"  - Lambda ASBN: {_LAMBDA_ASBN}",
            f"  - Lambda Domain: {_LAMBDA_DOMAIN}",
            f"  - Lambda DSCD: {_LAMBDA_DSCD}",
            "=" * 80 + "\n",
        )
    )
)


In [None]:
# ==============================================================================
# CELL 7: TRAINING LOOP (PURE UNSUPERVISED) - FIXED
# ==============================================================================

import os
import time
import math
import gc
import traceback
from datetime import datetime
from pathlib import Path
from collections import defaultdict, deque
from typing import Optional, Dict, Any, List

import numpy as np
import torch
from torch.cuda.amp import GradScaler, autocast as cuda_amp_autocast
from tqdm import tqdm
from contextlib import nullcontext
import threading

# Interval constant and per-key message counter used by the Cell 7 debug helper.
DEBUG_PRINT_INTERVAL = 200
_cell7_dbg_counts = defaultdict(int)

# Logging switches: start from "off" and adopt the notebook-level setting
# when one has been defined.
_VERBOSE_LOGGING = False
try:
    _VERBOSE_LOGGING = bool(VERBOSE_LOGGING)
except (NameError, TypeError):
    pass

_DEBUG_DISCOVERY = False
try:
    _DEBUG_DISCOVERY = bool(DEBUG_DISCOVERY)
except (NameError, TypeError):
    pass

def cell7_dbg(key: str, msg: str, limit: int = 10):
    """Print a rate-limited debug message for Cell 7.

    Nothing happens unless verbose logging or discovery debugging is on.
    Otherwise the call is counted under ``key`` and ``msg`` is printed only
    while that count has not exceeded ``limit``.

    Args:
        key: Bucket under which calls are counted.
        msg: Text printed after the ``[CELL7-DBG]`` prefix.
        limit: Maximum number of prints per ``key``.
    """
    logging_enabled = _VERBOSE_LOGGING or _DEBUG_DISCOVERY
    if logging_enabled:
        times_seen = _cell7_dbg_counts[key] + 1
        _cell7_dbg_counts[key] = times_seen
        if times_seen <= limit:
            print(f"[CELL7-DBG] {msg}")

# ------------------------------------------------------------------------------
# Runtime settings: each private name takes the notebook-level value when it is
# defined and convertible, and a built-in fallback otherwise.
# ------------------------------------------------------------------------------

try:
    _DEVICE = DEVICE
except (NameError, TypeError):
    _DEVICE = (
        torch.device("cpu")
        if not torch.cuda.is_available()
        else torch.device("cuda")
    )

_EPOCHS = 1
try:
    _EPOCHS = int(EPOCHS)
except (NameError, ValueError, TypeError):
    pass

_BATCH_SIZE = 8
try:
    _BATCH_SIZE = int(BATCH_SIZE)
except (NameError, ValueError, TypeError):
    pass

_ACCUMULATION_STEPS = 1
try:
    _ACCUMULATION_STEPS = int(ACCUMULATION_STEPS)
except (NameError, ValueError, TypeError):
    pass

_GRAD_CLIP_NORM = 1.0
try:
    _GRAD_CLIP_NORM = float(GRAD_CLIP_NORM)
except (NameError, ValueError, TypeError):
    pass

_MEMORY_CLEANUP_FREQUENCY = 500
try:
    _MEMORY_CLEANUP_FREQUENCY = int(MEMORY_CLEANUP_FREQUENCY)
except (NameError, ValueError, TypeError):
    pass

# Both GPU settings are taken together; if either is unusable, both are
# derived from what torch reports.
try:
    _USE_MULTI_GPU, _NUM_GPUS = bool(USE_MULTI_GPU), int(NUM_GPUS)
except (NameError, ValueError, TypeError):
    _NUM_GPUS = 0 if not torch.cuda.is_available() else torch.cuda.device_count()
    _USE_MULTI_GPU = _NUM_GPUS > 1

_USE_AMP = True
try:
    _USE_AMP = bool(USE_AMP)
except (NameError, TypeError):
    pass

# Language codes are adopted as a pair or not at all.
_SOURCE_LANGUAGE, _TARGET_LANGUAGE = "bn", "en"
try:
    _SOURCE_LANGUAGE, _TARGET_LANGUAGE = str(SOURCE_LANGUAGE), str(TARGET_LANGUAGE)
except (NameError, TypeError):
    pass

_MAX_LENGTH = 48
try:
    _MAX_LENGTH = int(MAX_LENGTH)
except (NameError, ValueError, TypeError):
    pass

_VALIDATION_CHECK_INTERVAL = 500
try:
    _VALIDATION_CHECK_INTERVAL = int(VALIDATION_CHECK_INTERVAL)
except (NameError, ValueError, TypeError):
    pass

_PERIODIC_DISCOVERY_FREQUENCY = 50
try:
    _PERIODIC_DISCOVERY_FREQUENCY = int(PERIODIC_DISCOVERY_FREQUENCY)
except (NameError, ValueError, TypeError):
    pass

# Domain ids are adopted as a pair or not at all.
_TRAIN_DOMAIN, _TEST_DOMAIN = 0, 1
try:
    _TRAIN_DOMAIN, _TEST_DOMAIN = int(TRAIN_DOMAIN), int(TEST_DOMAIN)
except (NameError, ValueError, TypeError):
    pass

# Lower-cased reference homographs; a built-in Bengali list is used when the
# notebook-level list is unavailable.
try:
    _HOMOGRAPH_REFERENCE_LIST = {
        str(w).lower() for w in HOMOGRAPH_REFERENCE_LIST_BN
    }
except (NameError, TypeError):
    _HOMOGRAPH_REFERENCE_LIST = {
        str(w).lower()
        for w in (
            "কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা",
            "পানি", "দল", "বাজার", "নাম", "কথা", "বই", "ঘর", "মন", "হাত",
        )
    }

def clear_all_gpu_caches():
    """Run the garbage collector, then empty the CUDA cache on every device.

    Without CUDA only the garbage collection happens. Errors while switching
    devices or emptying a cache are ignored so cleanup never interrupts the
    caller.
    """
    gc.collect()
    if torch.cuda.is_available():
        try:
            device_total = torch.cuda.device_count()
            for device_index in range(device_total):
                with torch.cuda.device(device_index):
                    try:
                        torch.cuda.empty_cache()
                    except Exception:
                        # A failure on one device must not stop the others.
                        pass
        except Exception:
            pass

def get_amp_ctx():
    """Return the context manager to wrap mixed-precision regions in.

    A CUDA autocast context is returned when AMP is enabled and CUDA is
    available; in every other case (including a failure to build the
    autocast context) a no-op ``nullcontext`` is returned.
    """
    amp_usable = _USE_AMP and torch.cuda.is_available()
    if amp_usable:
        try:
            return cuda_amp_autocast()
        except Exception:
            # Fall through to the no-op context below.
            pass
    return nullcontext()

# One-shot flag for the protobuf-compatibility warning printed during
# validation: it is set to True once the warning has been shown. Reading it
# back through globals() keeps an existing value when this cell is re-run
# and starts at False otherwise.
_PROTOBUF_COMPAT_ERROR_SHOWN = globals().get("_PROTOBUF_COMPAT_ERROR_SHOWN", False)

def _get_dscd_homographs(model: torch.nn.Module) -> set:
    """Return the tokens DSCD currently holds two or more prototypes for.

    The model is unwrapped via ``.module`` when present, then its ``dscd``
    attribute is inspected. If DSCD offers ``get_discovered_homographs`` its
    result is returned directly. Otherwise every entry of
    ``dscd.prototype_stores`` whose ``size()`` is at least 2 contributes its
    token, with the sub-word markers ``▁``, ``Ġ`` and ``##`` removed and the
    rest stripped and lower-cased. Tokens that are empty after cleaning are
    skipped, so a bare marker can never show up as a "homograph".

    The scan runs under ``dscd.buffer_lock`` (or ``dscd.clustering_lock``)
    when DSCD exposes one.

    Args:
        model: The model, optionally wrapped so the real model sits on
            ``.module``.

    Returns:
        Set of cleaned token strings; an empty set when DSCD is missing or
        anything goes wrong.
    """
    try:
        core = model.module if hasattr(model, 'module') else model
        dscd = getattr(core, 'dscd', None)
        if dscd is None:
            return set()

        if hasattr(dscd, 'get_discovered_homographs'):
            return dscd.get_discovered_homographs()

        lock = None
        if hasattr(dscd, 'buffer_lock'):
            lock = dscd.buffer_lock
        elif hasattr(dscd, 'clustering_lock'):
            lock = dscd.clustering_lock

        homographs = set()
        # A no-op context stands in when DSCD exposes no usable lock, so the
        # scan is written only once.
        with (lock if lock else nullcontext()):
            for token, store in dscd.prototype_stores.items():
                try:
                    if store.size() < 2:
                        continue
                    clean_token = (
                        str(token)
                        .replace('▁', '')
                        .replace('Ġ', '')
                        .replace('##', '')
                        .strip()
                        .lower()
                    )
                except Exception:
                    # One bad store must not hide the remaining tokens.
                    continue
                if clean_token:
                    homographs.add(clean_token)

        return homographs
    except Exception:
        return set()

@torch.inference_mode()
def comprehensive_epoch_validation(
    model: torch.nn.Module,
    tokenizer,
    epoch: int,
    global_step: int,
    source_lang: str,
    target_lang: str,
    max_length: int,
    device: torch.device
) -> Dict[str, Any]:
    global _PROTOBUF_COMPAT_ERROR_SHOWN

    print("\n" + "=" * 80)
    print(f"EPOCH {epoch} COMPREHENSIVE VALIDATION (Step {global_step})")
    print("=" * 80)

    core_model = model.module if hasattr(model, "module") else model
    was_training = core_model.training

    if not isinstance(device, torch.device):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dscd_homographs = _get_dscd_homographs(model)
    if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
        print(f"[VALIDATION] DSCD discovered homographs: {len(dscd_homographs)}")
        if dscd_homographs:
            print(f"[VALIDATION] Sample: {list(dscd_homographs)[:10]}")

    validation_results = {
        'epoch': epoch,
        'step': global_step,
        'translations_success': 0,
        'translations_failed': 0,
        'explanations_generated': 0,
        'dscd_homographs_explained': 0,
        'reference_homographs_explained': 0,
        'avg_explanation_confidence': 0.0,
        'dscd_quality_score': 0.0,
        'dscd_multi_sense_tokens': 0,
        'dscd_total_prototypes': 0,
        'asbn_domain_loss': 0.0,
        'asbn_domain_accuracy': 0.0,
        'asbn_source_accuracy': 0.0,
        'asbn_target_accuracy': 0.0,
        'trg_total_explanations': 0,
        'validation_completed': False,
    }

    try:
        core_model.eval()
        
        try:
            trg_system = getattr(core_model, 'trg_system', None)
            if trg_system is not None and hasattr(trg_system, 'eval'):
                trg_system.eval()
        except Exception:
            pass

        val_sentences = [
            ("আমি কল বন্ধ করেছি।", "I turned off the tap", "কল=tap/call"),
            ("কাল আমি বই কিনব।", "Tomorrow I will buy a book", "কাল=tomorrow/yesterday"),
            ("পাতা ঝরে পড়েছে।", "The leaf has fallen", "পাতা=leaf/page"),
            ("তিনি ব্যাংক গেছেন।", "He went to the bank", "ব্যাংক=bank/embankment"),
            ("আমি ভালো আছি।", "I am fine", "No ambiguity"),
            ("সে খুব মিষ্টি কথা বলে।", "She speaks sweetly", "No ambiguity"),
            ("এটা আমার বই।", "This is my book", "No ambiguity"),
            ("আজ আবহাওয়া ভালো।", "Weather is good today", "No ambiguity"),
            ("ফল খুব সুস্বাদু।", "The fruit is delicious", "ফল=fruit/result"),
            ("মাথা ব্যথা করছে।", "Head is aching", "মাথা=head/top"),
        ]

        print(f"\n[VALIDATION] Testing {len(val_sentences)} samples:")
        print("-" * 80)

        confidences = []
        dscd_homograph_words_detected = set()
        reference_homograph_words_detected = set()

        mbart_obj = None
        try:
            mbart_obj = getattr(core_model, "mbart", None)
        except Exception:
            mbart_obj = None

        try:
            try:
                tokenizer.src_lang = source_lang
            except Exception:
                pass

            forced_id = None
            try:
                if hasattr(tokenizer, "get_lang_id"):
                    for code in (target_lang, "en_XX", "en", "eng"):
                        try:
                            lid = tokenizer.get_lang_id(code)
                            if lid is not None:
                                forced_id = int(lid)
                                break
                        except Exception:
                            continue
                elif hasattr(tokenizer, "lang_code_to_id"):
                    forced_id = tokenizer.lang_code_to_id.get(target_lang, None)
                    if forced_id is not None:
                        forced_id = int(forced_id)
            except Exception:
                forced_id = None

            if forced_id is None:
                try:
                    forced_id = int(globals().get('M2M100_EN_TOKEN_ID', 128022))
                except Exception:
                    forced_id = 128022

            orig_use_cache = None
            try:
                if mbart_obj is not None and hasattr(mbart_obj.config, "use_cache"):
                    orig_use_cache = mbart_obj.config.use_cache
                    mbart_obj.config.use_cache = True
            except Exception:
                orig_use_cache = None

            for idx, (src, expected, note) in enumerate(val_sentences, 1):
                try:
                    enc = tokenizer(
                        src,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=max_length,
                    )
                    enc = {
                        k: (
                            v.to(device, non_blocking=True)
                            if isinstance(v, torch.Tensor)
                            else v
                        )
                        for k, v in enc.items()
                    }

                    if forced_id is not None:
                        try:
                            if mbart_obj is not None:
                                mbart_obj.config.forced_bos_token_id = int(forced_id)
                                mbart_obj.config.decoder_start_token_id = int(forced_id)
                        except Exception:
                            pass

                    out_ids = None
                    try:
                        gen_src = mbart_obj if mbart_obj is not None else core_model
                        if hasattr(gen_src, "generate"):
                            out_ids = gen_src.generate(
                                enc.get("input_ids"),
                                attention_mask=enc.get("attention_mask"),
                                max_length=max_length,
                                num_beams=2,
                                do_sample=False,
                                early_stopping=True,
                                pad_token_id=int(
                                    getattr(tokenizer, "pad_token_id", 1)
                                ),
                                forced_bos_token_id=int(forced_id)
                                if forced_id is not None
                                else None,
                            )
                    except AttributeError:
                        if not _PROTOBUF_COMPAT_ERROR_SHOWN:
                            print(
                                "[VALIDATION] Warning: generation raised AttributeError (protobuf incompatibility)."
                            )
                            _PROTOBUF_COMPAT_ERROR_SHOWN = True
                        out_ids = None
                    except Exception as e:
                        print(
                            f"[VALIDATION] Generation error: {type(e).__name__}: {str(e)[:200]}"
                        )
                        out_ids = None

                    translation = ""
                    if out_ids is not None and (
                        (isinstance(out_ids, torch.Tensor) and out_ids.numel() > 0)
                        or (
                            isinstance(out_ids, (list, tuple))
                            and len(out_ids) > 0
                        )
                    ):
                        try:
                            if isinstance(out_ids, (list, tuple)):
                                translation = tokenizer.batch_decode(
                                    out_ids, skip_special_tokens=True
                                )[0] if out_ids else ""
                            else:
                                translation = (
                                    tokenizer.decode(
                                        out_ids[0], skip_special_tokens=True
                                    )
                                    if out_ids.size(0) > 0
                                    else ""
                                )
                        except AttributeError:
                            if not _PROTOBUF_COMPAT_ERROR_SHOWN:
                                print(
                                    "[VALIDATION] Warning: decode raised AttributeError (protobuf)."
                                )
                                _PROTOBUF_COMPAT_ERROR_SHOWN = True
                            translation = ""
                        except Exception as e:
                            print(
                                f"[VALIDATION] Decode error: {type(e).__name__}: {str(e)[:200]}"
                            )
                            translation = ""
                    else:
                        translation = ""

                    if translation:
                        validation_results['translations_success'] += 1
                    else:
                        validation_results['translations_failed'] += 1
                        print(
                            f"  {idx:2d}. Translation failed: {note[:30]:30s}"
                        )
                        continue

                    explanation_status = ""
                    try:
                        if 'translate_with_explanations' in globals():
                            res = translate_with_explanations(
                                model, tokenizer, src
                            )
                            exps = res.get('explanations', [])
                            validation_results['explanations_generated'] += len(
                                exps
                            )

                            if exps:
                                explanation_status = f"{len(exps)} expl"
                                for exp in exps:
                                    try:
                                        conf = exp.get('confidence', 0.5)
                                        confidences.append(float(conf))

                                        word = exp.get('ambiguous_word', '').strip()
                                        clean_word = (
                                            word.replace('▁', '')
                                            .replace('Ġ', '')
                                            .replace('##', '')
                                            .lower()
                                        )

                                        if clean_word in dscd_homographs:
                                            validation_results[
                                                'dscd_homographs_explained'
                                            ] += 1
                                            dscd_homograph_words_detected.add(
                                                clean_word
                                            )

                                        if clean_word in _HOMOGRAPH_REFERENCE_LIST:
                                            validation_results[
                                                'reference_homographs_explained'
                                            ] += 1
                                            reference_homograph_words_detected.add(
                                                clean_word
                                            )
                                    except Exception:
                                        pass
                            else:
                                explanation_status = "no expl"
                                if _DEBUG_DISCOVERY:
                                    print(f"  [VALIDATION] Sample {idx}: 0 explanations (check thresholds)")
                        else:
                            explanation_status = "unavailable"
                    except Exception as e:
                        explanation_status = f"error: {type(e).__name__}"

                    print(
                        f"  {idx:2d}. {explanation_status:15s} "
                        f"{note[:30]:30s} -> {translation[:200]}"
                    )
                    del enc
                    if out_ids is not None:
                        del out_ids

                except Exception as e:
                    validation_results['translations_failed'] += 1
                    print(
                        f"  {idx:2d}. ERROR: {note[:30]:30s} -> {type(e).__name__}"
                    )
                    if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                        try:
                            traceback.print_exc()
                        except Exception:
                            pass

        finally:
            try:
                if mbart_obj is not None and orig_use_cache is not None:
                    mbart_obj.config.use_cache = orig_use_cache
            except Exception:
                pass
            if torch.cuda.is_available():
                try:
                    torch.cuda.synchronize()
                except Exception:
                    pass

            clear_all_gpu_caches()

        print("\n" + "-" * 80)
        print("[VALIDATION] DSCD Prototype Quality Check:")
        try:
            dscd = core_model.dscd if hasattr(core_model, 'dscd') else None
            if dscd and hasattr(dscd, 'validate_prototypes'):
                lock = None
                if hasattr(dscd, 'buffer_lock'):
                    lock = dscd.buffer_lock
                elif hasattr(dscd, 'clustering_lock'):
                    lock = dscd.clustering_lock

                if lock:
                    with lock:
                        quality_results = dscd.validate_prototypes()
                else:
                    quality_results = dscd.validate_prototypes()

                validation_results['dscd_quality_score'] = quality_results.get(
                    'quality_score', 0.0
                )
                validation_results['dscd_multi_sense_tokens'] = quality_results.get(
                    'multi_sense_tokens', 0
                )
                validation_results['dscd_total_prototypes'] = quality_results.get(
                    'total_prototypes', 0
                )
                print(
                    f"  - Quality Score: {validation_results['dscd_quality_score']:.1%}"
                )
                print(
                    f"  - Multi-sense tokens: {validation_results['dscd_multi_sense_tokens']}"
                )
                print(
                    f"  - Total prototypes: {validation_results['dscd_total_prototypes']}"
                )
            else:
                print("  - Validation not available")
        except Exception as e:
            print(f"  - Validation failed: {type(e).__name__}")

        print("\n" + "-" * 80)
        print("[VALIDATION] ASBN Training Statistics:")
        try:
            asbn = core_model.asbn if hasattr(core_model, 'asbn') else None
            if asbn and hasattr(asbn, 'get_detailed_stats'):
                asbn_stats = asbn.get_detailed_stats()
                validation_results['asbn_domain_loss'] = asbn_stats.get(
                    'domain_loss', 0.0
                )
                validation_results['asbn_domain_accuracy'] = asbn_stats.get(
                    'domain_accuracy', 0.0
                )
                validation_results['asbn_source_accuracy'] = asbn_stats.get(
                    'source_accuracy', 0.0
                )
                validation_results['asbn_target_accuracy'] = asbn_stats.get(
                    'target_accuracy', 0.0
                )
                print(
                    f"  - Domain Loss: {validation_results['asbn_domain_loss']:.4f}"
                )
                print(
                    f"  - Domain Accuracy: {validation_results['asbn_domain_accuracy']:.2%}"
                )
                print(
                    f"  - Source Accuracy: {validation_results['asbn_source_accuracy']:.2%}"
                )
                print(
                    f"  - Target Accuracy: {validation_results['asbn_target_accuracy']:.2%}"
                )
            elif asbn and hasattr(asbn, 'get_asbn_stats'):
                asbn_stats = asbn.get_asbn_stats()
                validation_results['asbn_domain_loss'] = asbn_stats.get(
                    'domain_loss', 0.0
                )
                validation_results['asbn_domain_accuracy'] = asbn_stats.get(
                    'domain_accuracy', 0.0
                )
                print(
                    f"  - Domain Loss: {validation_results['asbn_domain_loss']:.4f}"
                )
                print(
                    f"  - Domain Accuracy: {validation_results['asbn_domain_accuracy']:.2%}"
                )
            else:
                print("  - ASBN statistics not available")
        except Exception as e:
            print(f"  - ASBN stats retrieval failed: {type(e).__name__}")

        print("\n" + "-" * 80)
        print("[VALIDATION] TRG Explanation Statistics:")
        try:
            trg = core_model.trg_system if hasattr(core_model, 'trg_system') else None
            if trg and hasattr(trg, 'get_statistics'):
                trg_stats = trg.get_statistics()
                validation_results['trg_total_explanations'] = trg_stats.get(
                    'explanations_generated', 0
                )
                print(
                    f"  - Total explanations: {validation_results['trg_total_explanations']}"
                )
                print(
                    f"  - High confidence rate: {trg_stats.get('high_confidence_rate', 0):.1%}"
                )
                print(
                    f"  - DSCD homograph rate: {trg_stats.get('dscd_homograph_rate', 0):.1%}"
                )
            else:
                print("  - TRG statistics not available")
        except Exception as e:
            print(f"  - TRG stats retrieval failed: {type(e).__name__}")

        if confidences:
            validation_results['avg_explanation_confidence'] = sum(
                confidences
            ) / len(confidences)

        print("-" * 80)
        print("\n[VALIDATION] Summary:")
        print(
            f"  - Translations: {validation_results['translations_success']}/{len(val_sentences)} successful"
        )
        print(
            f"  - Explanations generated: {validation_results['explanations_generated']}"
        )
        print(
            f"  - Avg explanation confidence: {validation_results['avg_explanation_confidence']:.3f}"
        )
        print(
            f"  - DSCD homographs explained: {validation_results['dscd_homographs_explained']}"
        )
        print(
            f"  - Reference homographs explained: {validation_results['reference_homographs_explained']}"
        )

        if dscd_homograph_words_detected:
            print(
                f"  - DSCD homographs detected: {', '.join(sorted(dscd_homograph_words_detected))}"
            )

        print(
            f"  - DSCD Quality Score: {validation_results['dscd_quality_score']:.1%}"
        )
        print(
            f"  - Multi-sense tokens: {validation_results['dscd_multi_sense_tokens']}"
        )
        print(
            f"  - ASBN Domain Accuracy: {validation_results['asbn_domain_accuracy']:.2%}"
        )

        warnings = []
        if validation_results['translations_failed'] > len(val_sentences) // 2:
            warnings.append("High translation failure rate")
        if validation_results['explanations_generated'] == 0:
            warnings.append("No explanations generated - check thresholds (span=0.12, uncertainty=0.15)")
        if validation_results['dscd_quality_score'] < 0.3:
            warnings.append("Low DSCD quality score")
        if validation_results['dscd_multi_sense_tokens'] < 10:
            warnings.append("Very few multi-sense tokens")

        if warnings:
            print("\n[VALIDATION] Health Warnings:")
            for w in warnings:
                print(f"  - {w}")
        else:
            print("\n[VALIDATION] All systems healthy")

        validation_results['validation_completed'] = True

    except Exception as e:
        print(
            f"\n[VALIDATION] Critical error: {type(e).__name__}: {str(e)[:200]}"
        )
        if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
            try:
                traceback.print_exc()
            except Exception:
                pass
        validation_results['validation_completed'] = False

    finally:
        if was_training:
            core_model.train()
        clear_all_gpu_caches()

    print("=" * 80 + "\n")
    return validation_results

def _print_gpu_mem(prefix: str = ""):
    """Print allocated and reserved CUDA memory, in GB, for each device.

    Args:
        prefix: Text placed in front of the ``GPU mem (GB):`` header line.

    Returns nothing. The function is a no-op when CUDA is unavailable, and
    every failure is swallowed so this diagnostic can never interrupt the
    caller.
    """
    if not torch.cuda.is_available():
        return

    bytes_per_gb = 1024**3
    try:
        report = [f"{prefix} GPU mem (GB):"]
        for gpu_idx in range(torch.cuda.device_count()):
            # A failing query on one device must not hide the other devices.
            try:
                allocated = torch.cuda.memory_allocated(gpu_idx) / bytes_per_gb
                reserved = torch.cuda.memory_reserved(gpu_idx) / bytes_per_gb
                entry = (
                    f"  GPU {gpu_idx}: alloc={allocated:.2f} resv={reserved:.2f}"
                )
            except Exception:
                entry = f"  GPU {gpu_idx}: mem query failed"
            report.append(entry)
        # Emit the whole report with a single print call.
        print("\n".join(report))
    except Exception:
        pass

def _get_cluster_count(model: torch.nn.Module) -> int:
    """Return the number of entries in the model's DSCD prototype stores.

    Nested ``.module`` wrappers are peeled off first. The result is
    ``len(dscd.prototype_stores)``, read while holding ``dscd.buffer_lock``
    (or ``dscd.clustering_lock`` when ``buffer_lock`` is absent) if the lock
    object is truthy.

    Returns:
        The store count, or 0 when the model has no ``dscd``, the DSCD object
        has no ``prototype_stores``, or any exception is raised.
    """
    try:
        unwrapped = model
        while hasattr(unwrapped, 'module'):
            unwrapped = unwrapped.module

        dscd = getattr(unwrapped, 'dscd', None)
        stores = None if dscd is None else getattr(dscd, 'prototype_stores', None)
        if stores is None:
            return 0

        # buffer_lock wins whenever the attribute exists (even if it is
        # falsy); clustering_lock is only consulted when it is missing.
        guard = None
        for lock_name in ('buffer_lock', 'clustering_lock'):
            if hasattr(dscd, lock_name):
                guard = getattr(dscd, lock_name)
                break

        if not guard:
            return len(stores)
        with guard:
            return len(stores)

    except Exception:
        return 0

def _get_dscd_safe(model: torch.nn.Module):
    """Return the ``dscd`` attribute of the unwrapped model, or ``None``.

    Follows ``.module`` attributes until an object without one is reached,
    then reads ``dscd`` from it. ``None`` is returned when the attribute is
    missing or when any step raises.
    """
    try:
        unwrapped = model
        while True:
            if not hasattr(unwrapped, 'module'):
                break
            unwrapped = unwrapped.module
    except Exception:
        return None
    try:
        return getattr(unwrapped, 'dscd', None)
    except Exception:
        return None

def _print_top_clusters(model: torch.nn.Module, top_n: int = 5):
    """Debug-print the DSCD prototype stores that hold the most samples.

    Output is produced only when ``_DEBUG_DISCOVERY`` or ``_VERBOSE_LOGGING``
    is truthy. For every ``(token, store)`` pair in ``dscd.prototype_stores``
    the function collects:

    * the sum of ``store.counts`` (treated as 0 when missing or empty),
    * the prototype count (``store.size()`` if available, otherwise the
      length of ``store.centroids``),
    * the length of ``dscd.buffers[token]`` when ``dscd`` has ``buffers``,
    * whether the token, with the markers ``'▁'``, ``'Ġ'`` and ``'##'``
      removed and lower-cased, is in the collection returned by
      ``_get_dscd_homographs(model)``.

    The entries are sorted by sample count (descending). The first ``top_n``
    are printed, followed by up to five homograph entries in the same order.

    Args:
        model: Model (possibly wrapped) whose ``dscd`` component is inspected.
        top_n: Number of entries shown in the "Top clusters" listing.

    Returns nothing and never raises for errors inside the report itself;
    such errors are reported as a one-line ``[CLUSTER-DBG]`` message.
    """
    dscd = _get_dscd_safe(model)
    if dscd is None:
        return

    # Everything this function emits (including its error message) is gated
    # on the debug flags, so skip the lock, the snapshot and the sort when
    # nothing could be printed anyway.
    if not (_DEBUG_DISCOVERY or _VERBOSE_LOGGING):
        return

    try:
        dscd_homographs = _get_dscd_homographs(model)

        lock = None
        if hasattr(dscd, 'buffer_lock'):
            lock = dscd.buffer_lock
        elif hasattr(dscd, 'clustering_lock'):
            lock = dscd.clustering_lock

        # Copy the (token, store) pairs while holding the lock, then do the
        # per-store work on the copy so the lock is held only briefly.
        if lock:
            with lock:
                stores_snapshot = list(dscd.prototype_stores.items())
        else:
            stores_snapshot = list(dscd.prototype_stores.items())

        has_buffers = hasattr(dscd, 'buffers')
        items = []
        for token, store in stores_snapshot:
            # A malformed store is skipped instead of aborting the report.
            try:
                total_count = sum(getattr(store, "counts", []) or [])
                protos = store.size() if hasattr(store, "size") else len(
                    getattr(store, "centroids", [])
                )
                clean_token = (
                    str(token).replace('▁', '').replace('Ġ', '').replace('##', '').strip().lower()
                )
                is_homograph = clean_token in dscd_homographs
                buffer_len = (
                    len(dscd.buffers.get(token, [])) if has_buffers else 0
                )
                items.append(
                    (token, total_count, protos, buffer_len, is_homograph)
                )
            except Exception:
                continue

        items.sort(key=lambda x: x[1], reverse=True)
        # Derive the homograph list from the sorted entries so its truncated
        # listing shows the largest homograph clusters, not arbitrary ones.
        homograph_items = [item for item in items if item[4]]

        print("[CLUSTER-DBG] Top clusters:")
        for i, (tok, cnt, prot, buflen, is_homo) in enumerate(
            items[:top_n], 1
        ):
            marker = "HOMO" if is_homo else "    "
            print(
                f"{marker} {i:2d}. {str(tok)[:20]:20s} "
                f"samples={cnt:4d} protos={prot} buf={buflen}"
            )
        if homograph_items:
            print(
                f"[CLUSTER-DBG] DSCD-discovered homographs: {len(homograph_items)}"
            )
            for tok, cnt, prot, buflen, _ in homograph_items[:5]:
                print(
                    f"  HOMO {str(tok)[:20]:20s} samples={cnt:4d} protos={prot}"
                )
    except Exception as e:
        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            print(f"[CLUSTER-DBG] _print_top_clusters error: {type(e).__name__}")

def _check_discovery_status(model: torch.nn.Module, global_step: int):
    """Print a short summary of the DSCD discovery log when debugging.

    The model is unwrapped through any ``.module`` attributes and its
    ``dscd.discovered_log`` is inspected. All output is gated on
    ``_DEBUG_DISCOVERY`` or ``_VERBOSE_LOGGING``:

    * with a non-empty log, the total number of events and the
      ``discovered``/``candidates`` values of the last three entries;
    * otherwise a "no discoveries yet" line.

    Args:
        model: Model (possibly wrapped) that may carry a ``dscd`` attribute.
        global_step: Training step number shown in the messages.

    Returns nothing. Exceptions are caught and, under the same flags,
    reported on a single line.
    """
    try:
        core = model
        while hasattr(core, 'module'):
            core = core.module

        dscd = getattr(core, 'dscd', None)
        if dscd is None:
            return

        if not (hasattr(dscd, 'discovered_log') and dscd.discovered_log):
            if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                print(
                    f"[DISCOVERY-STATUS] No discoveries yet at step {global_step}"
                )
            return

        total_discovered = len(dscd.discovered_log)
        if not (_DEBUG_DISCOVERY or _VERBOSE_LOGGING):
            return

        print(
            f"[DISCOVERY-STATUS] Step {global_step}: {total_discovered} discovery events"
        )

        # Only slice when there are at least three entries; shorter logs are
        # iterated as they are.
        if len(dscd.discovered_log) >= 3:
            recent = dscd.discovered_log[-3:]
        else:
            recent = dscd.discovered_log

        for entry in recent:
            found = entry.get('discovered', 0)
            pool = entry.get('candidates', 0)
            print(
                f"  - {found}/{pool} homographs discovered"
            )
    except Exception as err:
        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            print(f"[DISCOVERY-STATUS] Error: {err}")

def train_memory_efficient_tatn(
    model: torch.nn.Module,
    tokenizer,
    train_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    phi_optimizer: Optional[torch.optim.Optimizer] = None,
    epochs: Optional[int] = None,
    accumulation_steps: Optional[int] = None,
    validate_every: Optional[int] = None,
    enable_validation: bool = True
) -> torch.nn.Module:
    if epochs is None:
        epochs = _EPOCHS
    if accumulation_steps is None:
        accumulation_steps = _ACCUMULATION_STEPS
    if validate_every is None:
        validate_every = _VALIDATION_CHECK_INTERVAL

    print(
        f"[TRAIN] Starting training: epochs={epochs}, batch={_BATCH_SIZE}, "
        f"accum_steps={accumulation_steps}"
    )
    print(
        f"[TRAIN] Validation: "
        f"{'enabled' if enable_validation and validate_every > 0 else 'disabled'}"
    )
    print(
        f"[TRAIN] DP enabled: {_USE_MULTI_GPU}, GPUs: {_NUM_GPUS}, Device: {_DEVICE}"
    )
    print(
        f"[TRAIN] Discovery frequency: {_PERIODIC_DISCOVERY_FREQUENCY} steps"
    )
    print(
        "[TRAIN] Checkpoint: Will save to /kaggle/working/tatn_final.pt "
        "after all epochs"
    )

    model.train()
    clear_all_gpu_caches()
    scaler = GradScaler(enabled=(_USE_AMP and torch.cuda.is_available()))

    global_step = 0
    accumulated_steps = 0
    pending_validation = False

    training_stats: Dict[str, Any] = {
        "total_loss": [],
        "epoch_losses": [],
        "backward_losses": [],
        "batches_processed": 0,
        "optimizer_updates": 0,
        "skipped_batches": 0,
        "oom_errors": 0,
        "runtime_errors": 0,
        "exceptions": 0,
        "epoch_validations": [],
        "dscd_quality_history": [],
        "multi_sense_ratio_history": [],
        "asbn_domain_accuracy_history": [],
        "trg_explanation_history": [],
    }

    last_forward_loss = 0.0
    last_backward_loss = 0.0

    for epoch in range(1, epochs + 1):
        epoch_start = time.time()
        epoch_losses: List[float] = []
        skip_reasons = defaultdict(int)

        print(f"\n{'='*80}")
        print(f"EPOCH {epoch}/{epochs} STARTED")
        print(f"{'='*80}")

        try:
            core = model.module if hasattr(model, 'module') else model
            trg = getattr(core, 'trg_system', None)
            if trg and hasattr(trg, 'reset_statistics'):
                try:
                    trg.reset_statistics()
                    print(f"[TRAIN] TRG statistics reset for epoch {epoch}")
                except Exception:
                    pass
        except Exception as e:
            if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                print(f"[TRAIN] TRG stats reset failed: {e}")

        try:
            core = model.module if hasattr(model, 'module') else model
            asbn = getattr(core, 'asbn', None)
            if asbn and hasattr(asbn, 'reset_stats'):
                try:
                    asbn.reset_stats()
                    print(f"[TRAIN] ASBN statistics reset for epoch {epoch}")
                except Exception:
                    pass
        except Exception as e:
            if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                print(f"[TRAIN] ASBN stats reset failed: {e}")

        try:
            optimizer.zero_grad(set_to_none=True)
        except Exception:
            pass

        progress = None
        try:
            progress = tqdm(
                train_loader,
                desc=f"Epoch {epoch}/{epochs}",
                dynamic_ncols=True,
            )

            for batch_idx, batch in enumerate(progress):
                global_step += 1
                training_stats["batches_processed"] += 1

                if (
                    _DEBUG_DISCOVERY or _VERBOSE_LOGGING
                ) and global_step % DEBUG_PRINT_INTERVAL == 0:
                    print(
                        f"\n[TRAIN-DEBUG] Epoch {epoch} Batch {batch_idx} "
                        f"GlobalStep {global_step}"
                    )
                    _check_discovery_status(model, global_step)

                if (
                    enable_validation
                    and validate_every
                    and validate_every > 0
                    and (global_step % validate_every == 0)
                ):
                    if accumulated_steps == 0:
                        try:
                            optimizer.zero_grad(set_to_none=True)
                        except Exception:
                            pass

                        val_result = comprehensive_epoch_validation(
                            model,
                            tokenizer,
                            epoch,
                            global_step,
                            _SOURCE_LANGUAGE,
                            _TARGET_LANGUAGE,
                            _MAX_LENGTH,
                            _DEVICE,
                        )

                        if val_result:
                            training_stats['epoch_validations'].append(
                                val_result
                            )
                    else:
                        pending_validation = True

                if batch is None:
                    training_stats["skipped_batches"] += 1
                    skip_reasons["batch_none"] += 1
                    continue

                try:
                    input_ids = batch["input_ids"]
                    attention_mask = batch["attention_mask"]
                    labels = batch["labels"]

                    domain_labels = batch.get("domain_labels", None)
                    if domain_labels is not None:
                        if not isinstance(domain_labels, torch.Tensor):
                            domain_labels = None
                        elif domain_labels.dim() == 0:
                            domain_labels = domain_labels.unsqueeze(0)

                    if _USE_MULTI_GPU and _NUM_GPUS > 0:
                        bsz = int(input_ids.size(0))
                        keep = (bsz // _NUM_GPUS) * _NUM_GPUS
                        if keep == 0:
                            training_stats["skipped_batches"] += 1
                            skip_reasons["dp_keep_zero"] += 1
                            continue
                        if keep != bsz:
                            input_ids = input_ids[:keep]
                            attention_mask = attention_mask[:keep]
                            labels = labels[:keep]
                            if domain_labels is not None:
                                domain_labels = domain_labels[:keep]

                    input_ids = input_ids.to(_DEVICE, non_blocking=True)
                    attention_mask = attention_mask.to(
                        _DEVICE, non_blocking=True
                    )
                    labels = labels.to(_DEVICE, non_blocking=True)

                    if domain_labels is not None:
                        domain_labels = domain_labels.to(
                            _DEVICE, non_blocking=True
                        )

                    if input_ids.size(0) == 0:
                        training_stats["skipped_batches"] += 1
                        skip_reasons["empty_batch"] += 1
                        continue

                    forward_kwargs = {
                        "input_ids": input_ids,
                        "attention_mask": attention_mask,
                        "labels": labels,
                        "src_texts": batch.get("src_text", None),
                        "token_word_map": batch.get("token_word_map", None),
                    }

                    amp_ctx = get_amp_ctx()
                    with amp_ctx:
                        forward_out = model(**forward_kwargs)

                        if isinstance(forward_out, torch.Tensor):
                            loss_tensor = forward_out
                        elif isinstance(forward_out, dict) and "loss" in forward_out:
                            loss_tensor = forward_out["loss"]
                        else:
                            if isinstance(forward_out, (list, tuple)) and len(
                                forward_out
                            ) > 0 and isinstance(
                                forward_out[0], torch.Tensor
                            ):
                                loss_tensor = forward_out[0]
                            else:
                                raise RuntimeError(
                                    "Model forward did not return a recognizable loss tensor"
                                )

                        if not isinstance(loss_tensor, torch.Tensor):
                            loss_tensor = torch.tensor(
                                float(loss_tensor), device=_DEVICE
                            )
                        else:
                            loss_tensor = loss_tensor.to(_DEVICE)

                        if loss_tensor.numel() > 1:
                            loss_val = float(loss_tensor.mean().item())
                            loss_tensor = loss_tensor.mean()
                        else:
                            loss_val = float(loss_tensor.item())

                        last_forward_loss = loss_val
                        epoch_losses.append(loss_val)
                        training_stats["total_loss"].append(loss_val)

                    loss_scaled = loss_tensor / max(1, accumulation_steps)
                    last_backward_loss = float(loss_scaled.item())
                    training_stats["backward_losses"].append(last_backward_loss)

                    if scaler.is_enabled():
                        scaler.scale(loss_scaled).backward()
                    else:
                        loss_scaled.backward()

                    accumulated_steps += 1

                    if accumulated_steps >= accumulation_steps:
                        try:
                            if scaler.is_enabled():
                                scaler.unscale_(optimizer)
                                torch.nn.utils.clip_grad_norm_(
                                    model.parameters(), _GRAD_CLIP_NORM
                                )
                                scaler.step(optimizer)
                                scaler.update()
                            else:
                                torch.nn.utils.clip_grad_norm_(
                                    model.parameters(), _GRAD_CLIP_NORM
                                )
                                optimizer.step()
                            optimizer.zero_grad(set_to_none=True)
                            training_stats["optimizer_updates"] += 1
                        except RuntimeError as e:
                            if "out of memory" in str(e).lower():
                                training_stats["oom_errors"] += 1
                                training_stats["skipped_batches"] += 1
                                skip_reasons["oom"] += 1
                                print(f"[OOM] OOM at step {global_step}")
                                try:
                                    optimizer.zero_grad(set_to_none=True)
                                except Exception:
                                    pass
                                for p in model.parameters():
                                    p.grad = None
                                clear_all_gpu_caches()
                                accumulated_steps = 0
                                continue
                            else:
                                training_stats["runtime_errors"] += 1
                                skip_reasons["opt_runtime"] += 1
                                print(
                                    f"[ERROR] Runtime error during optimizer step: {type(e).__name__}"
                                )
                        except Exception as e:
                            training_stats["exceptions"] += 1
                            skip_reasons["opt_exception"] += 1
                            print(
                                f"[ERROR] Exception during optimizer step: {type(e).__name__}"
                            )
                        finally:
                            accumulated_steps = 0
                            if pending_validation:
                                try:
                                    optimizer.zero_grad(set_to_none=True)
                                except Exception:
                                    pass

                                val_result = comprehensive_epoch_validation(
                                    model,
                                    tokenizer,
                                    epoch,
                                    global_step,
                                    _SOURCE_LANGUAGE,
                                    _TARGET_LANGUAGE,
                                    _MAX_LENGTH,
                                    _DEVICE,
                                )

                                if val_result:
                                    training_stats['epoch_validations'].append(
                                        val_result
                                    )

                                pending_validation = False

                    if global_step % DEBUG_PRINT_INTERVAL == 0:
                        _print_gpu_mem("[TRAIN-DEBUG]")
                        cluster_count = _get_cluster_count(model)
                        print(
                            f"[TRAIN-DEBUG] step={global_step} "
                            f"loss={last_forward_loss:.4f} clusters={cluster_count}"
                        )
                        _print_top_clusters(model, top_n=5)

                    if global_step % _MEMORY_CLEANUP_FREQUENCY == 0:
                        clear_all_gpu_caches()

                except RuntimeError as e:
                    if "out of memory" in str(e).lower():
                        training_stats["oom_errors"] += 1
                        training_stats["skipped_batches"] += 1
                        skip_reasons["oom"] += 1
                        print(f"[OOM] Caught OOM at step {global_step}")
                        try:
                            optimizer.zero_grad(set_to_none=True)
                        except Exception:
                            pass
                        for p in model.parameters():
                            p.grad = None
                        clear_all_gpu_caches()
                        accumulated_steps = 0
                        continue
                    else:
                        training_stats["runtime_errors"] += 1
                        training_stats["skipped_batches"] += 1
                        skip_reasons["runtime"] += 1
                        print(
                            f"[RUNTIME] RuntimeError at step {global_step}: {type(e).__name__}"
                        )
                        try:
                            optimizer.zero_grad(set_to_none=True)
                        except Exception:
                            pass
                        accumulated_steps = 0
                        continue
                except Exception as e:
                    training_stats["exceptions"] += 1
                    training_stats["skipped_batches"] += 1
                    skip_reasons["exceptions"] += 1
                    print(
                        f"[EXCEPTION] Exception at step {global_step}: {type(e).__name__}"
                    )
                    if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                        try:
                            traceback.print_exc()
                        except Exception:
                            pass
                    try:
                        optimizer.zero_grad(set_to_none=True)
                    except Exception:
                        pass
                    accumulated_steps = 0
                    continue

                processed_batches = (
                    training_stats["batches_processed"]
                    - training_stats["skipped_batches"]
                )
                expected_updates = max(
                    1,
                    math.floor(
                        processed_batches / max(1, accumulation_steps)
                    ),
                )
                success_rate = (
                    100.0
                    * training_stats["optimizer_updates"]
                    / expected_updates
                    if expected_updates > 0
                    else 0.0
                )
                cluster_count = _get_cluster_count(model)

                next_disc_str = "N/A"
                try:
                    if (
                        _PERIODIC_DISCOVERY_FREQUENCY
                        and _PERIODIC_DISCOVERY_FREQUENCY > 0
                    ):
                        steps_to_next = (
                            _PERIODIC_DISCOVERY_FREQUENCY
                            - (global_step % _PERIODIC_DISCOVERY_FREQUENCY)
                        )
                        next_disc_str = f"next_disc_in={steps_to_next}"
                except Exception:
                    next_disc_str = "next_disc_err"

                progress.set_postfix_str(
                    f"fwd_loss={last_forward_loss:.4f} "
                    f"bwd_loss={last_backward_loss:.4f} "
                    f"rate={success_rate:.1f}% "
                    f"clusters={cluster_count} {next_disc_str}"
                )

        finally:
            if progress is not None:
                try:
                    progress.close()
                except Exception:
                    pass

        if accumulated_steps > 0:
            try:
                if scaler.is_enabled():
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), _GRAD_CLIP_NORM
                    )
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    torch.nn.utils.clip_grad_norm_(
                        model.parameters(), _GRAD_CLIP_NORM
                    )
                    optimizer.step()
                optimizer.zero_grad(set_to_none=True)
                training_stats["optimizer_updates"] += 1
            except Exception as e:
                print(
                    f"[EPOCH-FLUSH] Exception on epoch flush: {type(e).__name__}"
                )
            finally:
                accumulated_steps = 0

        epoch_duration_min = (time.time() - epoch_start) / 60.0
        processed_batches = (
            training_stats["batches_processed"]
            - training_stats["skipped_batches"]
        )
        expected_updates = max(
            1,
            math.floor(processed_batches / max(1, accumulation_steps)),
        )
        success_rate = (
            100.0
            * training_stats["optimizer_updates"]
            / expected_updates
            if expected_updates > 0
            else 0.0
        )
        cluster_count = _get_cluster_count(model)

        avg_epoch_loss = float(np.mean(epoch_losses)) if epoch_losses else 0.0
        training_stats["epoch_losses"].append(avg_epoch_loss)

        print("\n" + "=" * 80)
        print(f"EPOCH {epoch}/{epochs} SUMMARY")
        print("=" * 80)
        print(f"  Duration (min): {epoch_duration_min:.2f}")
        print(f"  Optimizer updates: {training_stats['optimizer_updates']}")
        print(
            f"  Batches: processed={processed_batches}, "
            f"skipped={training_stats['skipped_batches']}"
        )
        print(f"  Success rate: {success_rate:.1f}%")
        print(f"  Clustered tokens: {cluster_count}")
        print(f"  Avg epoch loss: {avg_epoch_loss:.6f}")
        if skip_reasons:
            print("  Skip reasons:")
            for k, v in sorted(skip_reasons.items(), key=lambda x: -x[1]):
                print(f"    - {k}: {v}")
        print("=" * 80)

        try:
            print(
                f"\n[TRAIN] Running comprehensive validation after epoch {epoch}..."
            )

            try:
                optimizer.zero_grad(set_to_none=True)
            except Exception:
                pass

            validation_results = comprehensive_epoch_validation(
                model=model,
                tokenizer=tokenizer,
                epoch=epoch,
                global_step=global_step,
                source_lang=_SOURCE_LANGUAGE,
                target_lang=_TARGET_LANGUAGE,
                max_length=_MAX_LENGTH,
                device=_DEVICE,
            )

            if validation_results and validation_results.get(
                'validation_completed', False
            ):
                training_stats['epoch_validations'].append(
                    validation_results
                )
                training_stats['dscd_quality_history'].append(
                    validation_results.get('dscd_quality_score', 0.0)
                )
                training_stats['asbn_domain_accuracy_history'].append(
                    validation_results.get('asbn_domain_accuracy', 0.0)
                )
                training_stats['trg_explanation_history'].append(
                    validation_results.get('trg_total_explanations', 0)
                )

                try:
                    dscd = (
                        model.module.dscd
                        if hasattr(model, 'module')
                        else getattr(model, 'dscd', None)
                    )

                    lock = None
                    if dscd is not None:
                        if hasattr(dscd, 'buffer_lock'):
                            lock = dscd.buffer_lock
                        elif hasattr(dscd, 'clustering_lock'):
                            lock = dscd.clustering_lock

                    if dscd is not None:
                        if lock:
                            with lock:
                                total_tokens = len(dscd.prototype_stores)
                        else:
                            total_tokens = len(dscd.prototype_stores)

                        multi_sense = validation_results.get(
                            'dscd_multi_sense_tokens', 0
                        )
                        ratio = (
                            multi_sense / total_tokens
                            if total_tokens > 0
                            else 0.0
                        )
                        training_stats['multi_sense_ratio_history'].append(
                            ratio
                        )
                    else:
                        training_stats['multi_sense_ratio_history'].append(
                            0.0
                        )
                except Exception:
                    training_stats['multi_sense_ratio_history'].append(0.0)
            else:
                print("[TRAIN] Validation incomplete")

        except Exception as e:
            print(f"[TRAIN] Epoch validation failed: {type(e).__name__}")
            if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                try:
                    traceback.print_exc()
                except Exception:
                    pass

    print(f"\n{'='*80}")
    print("TRAINING COMPLETE - SAVING FINAL CHECKPOINT")
    print(f"{'='*80}")

    try:
        checkpoint_path = Path("/kaggle/working/tatn_final.pt")

        core_model = model.module if hasattr(model, 'module') else model
        
        while hasattr(core_model, 'module'):
            core_model = core_model.module

        dscd_state = {}
        try:
            if hasattr(core_model, 'dscd'):
                try:
                    dscd_state = core_model.dscd.state_dict()
                except Exception:
                    dscd_state = {}
        except Exception:
            dscd_state = {}

        checkpoint_data = {
            'epochs_trained': epochs,
            'global_steps': global_step,
            'final_train_loss': training_stats['epoch_losses'][-1]
            if training_stats['epoch_losses']
            else 0.0,
            'model_state_dict': core_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict()
            if scaler is not None
            else None,
            'training_stats': training_stats,
            'dscd_state': dscd_state,
            'config': {
                'SPAN_THRESHOLD': globals().get('SPAN_THRESHOLD', 0.12),
                'TAU_LOW': globals().get('TAU_LOW', 0.15),
                'LAMBDA_ASBN': globals().get('LAMBDA_ASBN', 0.05),
                'LAMBDA_DSCD': globals().get('LAMBDA_DSCD', 0.15),
                'TRG_TEMPERATURE': globals().get('TRG_TEMPERATURE', 1.0),
                'PERIODIC_DISCOVERY_FREQUENCY': _PERIODIC_DISCOVERY_FREQUENCY,
                'NUM_EPOCHS': epochs,
                'BATCH_SIZE': _BATCH_SIZE,
                'LEARNING_RATE': optimizer.param_groups[0]['lr']
                if optimizer and optimizer.param_groups
                else 0.0,
            },
        }

        torch.save(checkpoint_data, checkpoint_path)

        file_size_mb = checkpoint_path.stat().st_size / (1024**2)

        print("\nFINAL CHECKPOINT SAVED")
        print(f"   Path: {checkpoint_path}")
        print(f"   Size: {file_size_mb:.2f} MB")
        print(f"   Epochs trained: {epochs}")
        print(f"   Global steps: {global_step}")
        print(
            f"   Final train loss: "
            f"{training_stats['epoch_losses'][-1] if training_stats['epoch_losses'] else 0.0:.4f}"
        )
        print(f"{'='*80}\n")

    except Exception as e:
        print(f"FINAL CHECKPOINT SAVE FAILED: {type(e).__name__}")
        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            try:
                traceback.print_exc()
            except Exception:
                pass

    print("\n" + "=" * 80)
    print("TRAINING COMPLETED - FINAL SUMMARY")
    print("=" * 80)

    processed_batches = (
        training_stats["batches_processed"]
        - training_stats["skipped_batches"]
    )
    expected_updates = max(
        1,
        math.floor(processed_batches / max(1, accumulation_steps)),
    )
    success_rate = (
        100.0
        * training_stats["optimizer_updates"]
        / expected_updates
        if expected_updates > 0
        else 0.0
    )

    print(f"[TRAIN] Success Rate: {success_rate:.1f}%")
    print(f"[TRAIN] Total Steps: {global_step}")
    print(
        f"[TRAIN] Clustered Token Types: {_get_cluster_count(model)}"
    )

    if training_stats['dscd_quality_history']:
        print("\n[TRAIN] DSCD Quality Score Trend:")
        for i, score in enumerate(training_stats['dscd_quality_history'], 1):
            print(f"  Epoch {i}: {score:.1%}")

    if training_stats['asbn_domain_accuracy_history']:
        print("\n[TRAIN] ASBN Domain Accuracy Trend:")
        for i, acc in enumerate(
            training_stats['asbn_domain_accuracy_history'], 1
        ):
            print(f"  Epoch {i}: {acc:.1%}")

    if training_stats['trg_explanation_history']:
        print("\n[TRAIN] TRG Explanation Count Trend:")
        for i, count in enumerate(
            training_stats['trg_explanation_history'], 1
        ):
            print(f"  Epoch {i}: {count} explanations")

    print("=" * 80)
    return model

# Cell 7 load banner: emitted once when the cell finishes defining the
# training loop. The lines are joined and printed in a single call; the
# resulting output is identical to printing them one by one.
print(
    "\n".join(
        [
            "\n" + "=" * 80,
            "Cell 7: Training loop ready (PURE UNSUPERVISED) - FIXED",
            "=" * 80,
            "FIXES APPLIED:",
            "=" * 80,
            " F1-F8: (Previous fixes)",
            " F9: CRITICAL: Changed discovery frequency from 3000 → 50 steps",
            " F10: Synced with Cell 6 discovery frequency",
            " F11: Discovery now happens DURING training (not just at end)",
            "=" * 80 + "\n",
        ]
    )
)


In [None]:
# ==============================================================================
# CELL 8: INFERENCE PIPELINE (PURE UNSUPERVISED) - FIXED
# ==============================================================================

import os
import time
import math
import torch
import traceback
from typing import List, Dict, Any, Tuple, Optional
from collections import defaultdict
import threading
import gc

# ------------------------------------------------------------------------------
# Cell 8 settings. Each private `_NAME` is taken from the notebook-level global
# of the same name when that global exists and converts cleanly; otherwise the
# literal fallback in the `except` branch is used.
# ------------------------------------------------------------------------------

# The two language codes fall back together: if either global is missing, both
# are reset to the defaults below.
try:
    _SOURCE_LANGUAGE = str(SOURCE_LANGUAGE)
    _TARGET_LANGUAGE = str(TARGET_LANGUAGE)
except (NameError, TypeError):
    _SOURCE_LANGUAGE = "bn"
    _TARGET_LANGUAGE = "en"

# Used as the tokenizer truncation length and as a cap on generation length in
# translate_with_explanations.
try:
    _MAX_LENGTH = int(MAX_LENGTH)
except (NameError, ValueError, TypeError):
    _MAX_LENGTH = 48

# DEVICE is used as-is (no conversion); the fallback prefers CUDA when available.
try:
    _DEVICE = DEVICE
except (NameError, TypeError):
    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging / debug switches; all default to off.
try:
    _VERBOSE_LOGGING = bool(VERBOSE_LOGGING)
except (NameError, TypeError):
    _VERBOSE_LOGGING = False

try:
    _DEBUG_DISCOVERY = bool(DEBUG_DISCOVERY)
except (NameError, TypeError):
    _DEBUG_DISCOVERY = False

try:
    _DEBUG_TIMING = bool(DEBUG_TIMING)
except (NameError, TypeError):
    _DEBUG_TIMING = False

# Fallback: treat the run as multi-GPU only when more than one CUDA device is
# visible.
try:
    _USE_MULTI_GPU = bool(USE_MULTI_GPU)
except (NameError, TypeError):
    _USE_MULTI_GPU = torch.cuda.is_available() and (torch.cuda.device_count() > 1)

# Default thresholds for translate_with_explanations (span / uncertainty) when
# the caller does not pass explicit values.
try:
    _REAL_AMB_SPAN_THRESHOLD = float(SPAN_THRESHOLD)
except (NameError, ValueError, TypeError):
    _REAL_AMB_SPAN_THRESHOLD = 0.12

try:
    _REAL_AMB_UNCERTAINTY_THRESHOLD = float(TAU_LOW)
except (NameError, ValueError, TypeError):
    _REAL_AMB_UNCERTAINTY_THRESHOLD = 0.15

# Lower-cased reference homographs; InferenceStatistics checks explained words
# against this set. The built-in list is used only when
# HOMOGRAPH_REFERENCE_LIST_BN is unavailable.
try:
    _HOMOGRAPH_REFERENCE_LIST = set(str(w).lower() for w in HOMOGRAPH_REFERENCE_LIST_BN)
except (NameError, TypeError):
    _HOMOGRAPH_REFERENCE_LIST = {
        "কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা",
        "পানি", "দল", "বাজার", "নাম", "কথা", "বই", "ঘর", "মন", "হাত"
    }
    # Same normalisation as the primary branch, so both paths yield a
    # lower-cased set of str.
    _HOMOGRAPH_REFERENCE_LIST = set(str(w).lower() for w in _HOMOGRAPH_REFERENCE_LIST)

# Language-token ids. _force_english_bos uses the English id as its last-resort
# value when the tokenizer cannot supply one.
try:
    _M2M100_EN_TOKEN_ID = int(M2M100_EN_TOKEN_ID)
except (NameError, ValueError, TypeError):
    _M2M100_EN_TOKEN_ID = 128022

# NOTE(review): verify the 128025 fallback against tokenizer.get_lang_id("bn");
# it may not be the Bengali language id in the stock M2M100 vocabulary.
try:
    _M2M100_BN_TOKEN_ID = int(M2M100_BN_TOKEN_ID)
except (NameError, ValueError, TypeError):
    _M2M100_BN_TOKEN_ID = 128025

# Punctuation characters referenced by _is_subword_token.
_SUBWORD_PUNCT_SET = set(".,!?;:()[]{}\"'-")


def _get_dscd_homographs(model: torch.nn.Module) -> set:
    """Return the normalised tokens that DSCD currently holds as multi-sense.

    The DSCD component is looked up on ``model.module`` when the model is
    wrapped, otherwise on ``model`` itself. If it exposes
    ``get_discovered_homographs()``, that result is returned (coerced to a
    fresh ``set``). Otherwise ``prototype_stores`` is scanned and every token
    whose store reports ``size() >= 2`` is collected.

    Tokens are normalised by removing the subword markers ``▁``, ``Ġ`` and
    ``##`` and all spaces, then lower-casing. Tokens that normalise to an
    empty string are skipped, since an empty string is not a word.

    Args:
        model: The model (optionally wrapped) that carries a ``dscd`` attribute.

    Returns:
        A set of normalised token strings; empty when DSCD is missing or any
        unexpected error occurs (this helper never raises).
    """
    # Function-scope import keeps this block self-contained.
    from contextlib import nullcontext

    def _normalise(raw_token):
        return (
            str(raw_token)
            .replace('▁', '')
            .replace('Ġ', '')
            .replace('##', '')
            .replace(' ', '')
            .strip()
            .lower()
        )

    try:
        core = model.module if hasattr(model, 'module') else model
        dscd = getattr(core, 'dscd', None)
        if dscd is None:
            return set()

        if hasattr(dscd, 'get_discovered_homographs'):
            try:
                # Coerce so callers always get a set they can safely mutate.
                return set(dscd.get_discovered_homographs())
            except Exception:
                # Fall back to scanning the prototype stores directly.
                pass

        lock = None
        if hasattr(dscd, 'buffer_lock'):
            lock = dscd.buffer_lock
        elif hasattr(dscd, 'clustering_lock'):
            lock = dscd.clustering_lock

        # One scan for both cases: hold the lock when there is one, else no-op.
        guard = lock if lock else nullcontext()

        homographs = set()
        with guard:
            for token, store in dscd.prototype_stores.items():
                try:
                    if store.size() < 2:
                        continue
                    clean_token = _normalise(token)
                except Exception:
                    # A single broken store must not abort the whole scan.
                    continue
                if clean_token:
                    homographs.add(clean_token)

        return homographs
    except Exception:
        return set()


class InferenceStatistics:
    """Lock-protected accumulator for translation and explanation statistics.

    Counters are updated by :meth:`record_inference` and read back through
    :meth:`get_summary` / :meth:`print_summary`. All reads and writes made by
    these methods happen under an internal ``threading.Lock``.
    """

    def __init__(self):
        self._lock = threading.Lock()
        self.reset()

    def reset(self):
        """Zero every counter and clear all per-token bookkeeping."""
        with self._lock:
            self.total_inferences = 0
            self.successful_translations = 0
            self.failed_translations = 0
            self.total_explanations = 0
            self.high_confidence_explanations = 0
            self.low_confidence_explanations = 0
            self.total_confidence = 0.0
            self.dscd_homographs_explained = set()
            self.reference_homographs_explained = set()
            # Running sums despite the names; get_summary divides them.
            self.avg_span = 0.0
            self.avg_uncertainty = 0.0
            self.dscd_empty_warnings = 0
            self.token_counts = defaultdict(int)
            self.token_confidences = defaultdict(list)

    def record_inference(self, result: Dict[str, Any], dscd_homographs: Optional[set] = None):
        """Fold one inference result into the running statistics.

        Args:
            result: Mapping with an optional ``'translation'`` string and an
                optional ``'explanations'`` list of dicts. Each explanation may
                carry ``'confidence'``, ``'span'``, ``'uncertainty'`` and
                ``'ambiguous_word'``. A missing or ``None`` explanation list is
                treated as empty.
            dscd_homographs: Optional collection of normalised tokens; an
                explained word found in it is remembered as a DSCD homograph.

        An explanation whose numeric fields cannot be converted to ``float``
        still counts towards ``total_explanations`` but contributes nothing
        else, so the per-field sums never end up half-updated.
        """
        with self._lock:
            self.total_inferences += 1

            translation = result.get('translation')
            if translation and translation != "ERROR DURING TRANSLATION":
                self.successful_translations += 1
            else:
                self.failed_translations += 1

            # `or []` also covers an explicit None stored under the key.
            explanations = result.get('explanations') or []
            self.total_explanations += len(explanations)

            for exp in explanations:
                # Parse everything first; commit only when the entry is sound.
                try:
                    conf = float(exp.get('confidence', 0.5))
                    span = float(exp.get('span', 0.0))
                    uncertainty = float(exp.get('uncertainty', 0.0))
                    word = str(exp.get('ambiguous_word', '')).strip()
                except Exception:
                    continue

                clean_word = (
                    word.replace('▁', '')
                    .replace('Ġ', '')
                    .replace('##', '')
                    .replace(' ', '')
                    .lower()
                )

                self.total_confidence += conf
                if conf >= 0.65:
                    self.high_confidence_explanations += 1
                elif conf < 0.4:
                    self.low_confidence_explanations += 1

                self.token_counts[clean_word] += 1
                self.token_confidences[clean_word].append(conf)

                if dscd_homographs and clean_word in dscd_homographs:
                    self.dscd_homographs_explained.add(clean_word)

                if clean_word in _HOMOGRAPH_REFERENCE_LIST:
                    self.reference_homographs_explained.add(clean_word)

                self.avg_span += span
                self.avg_uncertainty += uncertainty

    def get_summary(self) -> Dict[str, Any]:
        """Return a snapshot dict of counts, rates and averages.

        Rates and averages use ``max(count, 1)`` as the denominator, so a fresh
        instance reports zeros instead of dividing by zero.
        """
        with self._lock:
            total_exp = max(self.total_explanations, 1)
            total_inf = max(self.total_inferences, 1)
            unique_tokens = len(self.token_counts)

            return {
                'total_inferences': self.total_inferences,
                'successful_translations': self.successful_translations,
                'failed_translations': self.failed_translations,
                'success_rate': self.successful_translations / total_inf,
                'total_explanations': self.total_explanations,
                'explanations_per_inference': self.total_explanations / total_inf,
                'high_confidence_rate': self.high_confidence_explanations / total_exp,
                'low_confidence_rate': self.low_confidence_explanations / total_exp,
                'avg_confidence': self.total_confidence / total_exp,
                'avg_span': self.avg_span / total_exp,
                'avg_uncertainty': self.avg_uncertainty / total_exp,
                'dscd_homographs_explained': list(self.dscd_homographs_explained),
                'reference_homographs_explained': list(self.reference_homographs_explained),
                'dscd_empty_warnings': self.dscd_empty_warnings,
                'unique_tokens_explained': unique_tokens,
                'diversity_ratio': unique_tokens / total_exp,
            }

    def print_summary(self):
        """Print the :meth:`get_summary` snapshot as a formatted report."""
        summary = self.get_summary()
        print("\n" + "=" * 80)
        print("INFERENCE STATISTICS SUMMARY")
        print("=" * 80)
        print(f"Total inferences: {summary['total_inferences']}")
        print(f"Success rate: {summary['success_rate']:.1%}")
        print(f"Total explanations: {summary['total_explanations']}")
        print(f"Explanations per inference: {summary['explanations_per_inference']:.2f}")
        print(f"Unique tokens explained: {summary['unique_tokens_explained']}")
        print(f"Diversity ratio: {summary['diversity_ratio']:.2%}")
        print(f"Avg confidence: {summary['avg_confidence']:.3f}")
        print(f"High confidence rate: {summary['high_confidence_rate']:.1%}")
        print(f"Avg span: {summary['avg_span']:.3f}")
        print(f"Avg uncertainty: {summary['avg_uncertainty']:.3f}")

        if summary['dscd_homographs_explained']:
            print(f"\nDSCD homographs explained ({len(summary['dscd_homographs_explained'])}):")
            print(f"  {', '.join(summary['dscd_homographs_explained'])}")

        if summary['reference_homographs_explained']:
            print(f"\nReference homographs explained ({len(summary['reference_homographs_explained'])}):")
            print(f"  {', '.join(summary['reference_homographs_explained'])}")

        if summary['dscd_empty_warnings'] > 0:
            print(f"\nDSCD empty warnings: {summary['dscd_empty_warnings']}")
        print("=" * 80 + "\n")


# Module-level statistics accumulator for this cell. translate_with_explanations
# increments its dscd_empty_warnings counter when track_stats is enabled and the
# DSCD prototype stores are found empty.
_INFERENCE_STATS = InferenceStatistics()


def _to_device_batch(enc: Any, device: torch.device):
    """Move an encoded batch onto ``device`` on a best-effort basis.

    Anything exposing ``.to`` is moved with a single ``enc.to(device)`` call.
    If that is unavailable or raises, a dict is rebuilt entry by entry:
    tensors are moved, nested dicts are handled recursively, and lists/tuples
    become lists with their tensor items moved. Any entry that fails to move
    is kept unchanged. Non-dict inputs without a usable ``.to`` are returned
    as they are.
    """

    def _relocate(value):
        # Per-value move used for the dict fallback path.
        if isinstance(value, torch.Tensor):
            return value.to(device)
        if isinstance(value, dict):
            return _to_device_batch(value, device)
        if isinstance(value, (list, tuple)):
            return [
                item.to(device) if isinstance(item, torch.Tensor) else item
                for item in value
            ]
        return value

    try:
        if hasattr(enc, "to"):
            return enc.to(device)
    except Exception:
        # Fall through to the manual dict walk below.
        pass

    if not isinstance(enc, dict):
        return enc

    moved = {}
    for key, value in enc.items():
        try:
            moved[key] = _relocate(value)
        except Exception:
            moved[key] = value
    return moved


def _extract_dscd_outputs(raw_out: Any) -> Dict[str, Any]:
    """Locate the DSCD output mapping inside a model output.

    For a dict, the lookup order is: a dict stored under ``"dscd_outputs"``,
    then one under ``"dscd"``; if the dict itself carries ``"explanations"``
    or ``"proto_probs"`` it is returned directly; then a dict stored under
    ``"dscd_out"``; finally the dict itself. For a list or tuple, the first
    dict element is processed the same way. Anything else yields ``{}``.
    """
    if isinstance(raw_out, (list, tuple)):
        first_mapping = next(
            (element for element in raw_out if isinstance(element, dict)), None
        )
        if first_mapping is None:
            return {}
        return _extract_dscd_outputs(first_mapping)

    if not isinstance(raw_out, dict):
        return {}

    for primary_key in ("dscd_outputs", "dscd"):
        nested = raw_out.get(primary_key)
        if isinstance(nested, dict):
            return nested

    # The output already looks like a DSCD payload.
    if "explanations" in raw_out or "proto_probs" in raw_out:
        return raw_out

    alternate = raw_out.get("dscd_out")
    if isinstance(alternate, dict):
        return alternate

    return raw_out


def _get_explanations_list(dscd: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
    """Return explanations as a list of per-sentence lists.

    Reads ``"explanations"`` and, when that is absent or ``None``, the first
    key present among ``"explanations_per_sentence"``, ``"trg_explanations"``
    and ``"exps"``. A flat list of dicts is wrapped as a single sentence; a
    list whose first element is a list is returned unchanged. Every other
    shape (including an empty list) yields ``[]``.
    """
    if not dscd:
        return []

    candidate = dscd.get("explanations", None)
    if candidate is None:
        fallback_keys = ("explanations_per_sentence", "trg_explanations", "exps")
        # Only the first key that exists is consulted, even if its value is None.
        candidate = next(
            (dscd[name] for name in fallback_keys if name in dscd), None
        )

    if not isinstance(candidate, list) or len(candidate) == 0:
        return []

    head = candidate[0]
    if isinstance(head, dict):
        return [candidate]
    if isinstance(head, list):
        return candidate
    return []


def _is_subword_token(token: str) -> bool:
    """Return True when ``token`` should not be treated as a standalone word.

    A token is rejected when it is empty or whitespace-only, when (after
    stripping) it starts with one of the marker prefixes below, when it is
    shorter than two characters, or when it consists only of digits.
    """
    stripped = token.strip() if token else ""
    if not stripped:
        return True

    marker_prefixes = ("##", "▁", "  ", "@@", " ")
    if stripped.startswith(marker_prefixes):
        return True

    # Single characters (punctuation included) are caught by the length test.
    return len(stripped) < 2 or stripped.isdigit()


def _should_filter_explanation(expl: Dict[str, Any], span_th: float, u_th: float) -> bool:
    """Decide whether an explanation entry should be dropped.

    Args:
        expl: Explanation mapping. The word is read from ``'ambiguous_word'``
            (falling back to ``'token'`` when that key is absent); scores are
            read from ``'span'`` and ``'uncertainty'`` (default ``0.0``).
        span_th: An explanation is kept only if its span is strictly greater.
        u_th: An explanation is kept only if its uncertainty is strictly
            greater.

    Returns:
        ``True`` (drop) when the word is missing or not a ``str``, when
        ``_is_subword_token`` rejects it, when either score fails its
        threshold, or when any error occurs while reading the entry;
        ``False`` (keep) otherwise.
    """
    try:
        token = expl.get('ambiguous_word', expl.get('token', ''))
        if not token or not isinstance(token, str):
            return True

        span = float(expl.get('span', 0.0))
        uncertainty = float(expl.get('uncertainty', 0.0))

        if _is_subword_token(token):
            return True

        # Positive comparison on purpose: NaN fails every comparison, so a NaN
        # score is filtered here, whereas `span <= span_th` would let it through.
        return not (span > span_th and uncertainty > u_th)
    except Exception:
        return True


def _force_english_bos(tokenizer, mbart_model) -> Optional[int]:
    """Resolve the target-language token id and pin it on the model config.

    The id is looked up through ``tokenizer.get_lang_id`` (trying the
    configured target language, then ``"en_XX"``, ``"en"``, ``"eng"``) when
    that method exists; otherwise through ``tokenizer.lang_code_to_id``. When
    no id can be resolved, ``_M2M100_EN_TOKEN_ID`` is used. If the model has a
    ``config``, both ``forced_bos_token_id`` and ``decoder_start_token_id``
    are set to the id.

    Returns:
        The id that was resolved (and written to the config when possible).
    """

    def _lookup_lang_id():
        if hasattr(tokenizer, "get_lang_id"):
            for candidate in (_TARGET_LANGUAGE, "en_XX", "en", "eng"):
                try:
                    lang_id = tokenizer.get_lang_id(candidate)
                    if lang_id is None:
                        continue
                    return int(lang_id)
                except Exception:
                    continue
            return None
        if hasattr(tokenizer, "lang_code_to_id"):
            mapped = tokenizer.lang_code_to_id.get(_TARGET_LANGUAGE, None)
            return None if mapped is None else int(mapped)
        return None

    try:
        target_id = _lookup_lang_id()
    except Exception:
        target_id = None

    if target_id is None:
        # Last resort: the configured English token id.
        target_id = _M2M100_EN_TOKEN_ID

    if target_id is None or not hasattr(mbart_model, "config"):
        return target_id

    try:
        mbart_model.config.forced_bos_token_id = int(target_id)
        mbart_model.config.decoder_start_token_id = int(target_id)
    except Exception:
        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            print("[INF] Could not set forced BOS on mbart config")

    return target_id


def _safe_generate(
    mbart,
    input_ids=None,
    encoder_outputs=None,
    attention_mask=None,
    max_length=64,
    num_beams=2,
    **kwargs,
):
    """Call ``mbart.generate`` and retry once with a smaller budget on OOM.

    Args:
        mbart: Object exposing ``generate``.
        input_ids: Source ids, passed positionally when ``encoder_outputs`` is
            ``None``.
        encoder_outputs: Precomputed encoder outputs; when given they are used
            instead of ``input_ids``.
        attention_mask: Forwarded to ``generate``.
        max_length: Generation length limit for the first attempt.
        num_beams: Beam count for the first attempt.
        **kwargs: Extra ``generate`` options. ``early_stopping`` defaults to
            ``True`` and may be overridden here.

    Returns:
        Whatever ``mbart.generate`` returns.

    Raises:
        RuntimeError: Re-raised unchanged when the failure is not an
            out-of-memory error, or when the reduced retry fails as well.
    """
    # Merge instead of passing early_stopping explicitly so a caller-supplied
    # value does not raise a duplicate-keyword TypeError.
    options = {"early_stopping": True, **kwargs}

    def _run(length, beams):
        # Single place that builds the generate call for both input modes.
        if encoder_outputs is not None:
            return mbart.generate(
                encoder_outputs=encoder_outputs,
                attention_mask=attention_mask,
                max_length=length,
                num_beams=beams,
                **options,
            )
        return mbart.generate(
            input_ids,
            attention_mask=attention_mask,
            max_length=length,
            num_beams=beams,
            **options,
        )

    try:
        return _run(max_length, num_beams)
    except RuntimeError as e:
        if "out of memory" not in str(e).lower():
            raise

    # The retry runs after the except block has exited: while an exception is
    # being handled its traceback keeps the failed call's frames (and their
    # tensors) alive, so freeing the cache inside the handler is ineffective.
    if _DEBUG_DISCOVERY:
        print("[INF] OOM during generation, reducing beam size...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return _run(min(max_length, 48), 1)


def translate_with_explanations(
    model,
    tokenizer,
    input_sentence: str,
    device: Optional[torch.device] = None,
    span_threshold: Optional[float] = None,
    uncertainty_threshold: Optional[float] = None,
    track_stats: bool = True,
) -> Dict[str, Any]:
    device = _DEVICE if device is None else device
    span_th = _REAL_AMB_SPAN_THRESHOLD if span_threshold is None else float(span_threshold)
    u_th = _REAL_AMB_UNCERTAINTY_THRESHOLD if uncertainty_threshold is None else float(uncertainty_threshold)

    if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
        print(f"\n[INF] Starting inference:")
        print(f"[INF]   Input: {input_sentence[:60]}")
        print(f"[INF]   Thresholds: span={span_th:.2f}, uncertainty={u_th:.2f}")

    cleanup_vars = []
    dscd = None
    encoder_hidden = None
    encoder_hidden_adjusted = None

    dscd_homographs = _get_dscd_homographs(model)

    try:
        try:
            tokenizer.src_lang = _SOURCE_LANGUAGE
        except Exception:
            pass

        enc = tokenizer(
            input_sentence,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=_MAX_LENGTH,
        )
        enc = _to_device_batch(enc, device)
        cleanup_vars.append("enc")

        model.eval()
        core = model.module if (_USE_MULTI_GPU and hasattr(model, "module")) else model

        try:
            trg_system = getattr(core, 'trg_system', None)
            if trg_system is not None and hasattr(trg_system, 'eval'):
                trg_system.eval()
        except Exception:
            pass

        src_texts = [input_sentence]

        dscd_validated = False
        try:
            dscd = core.dscd if hasattr(core, 'dscd') else None
            if dscd:
                lock = None
                if hasattr(dscd, 'buffer_lock'):
                    lock = dscd.buffer_lock
                elif hasattr(dscd, 'clustering_lock'):
                    lock = dscd.clustering_lock

                if lock:
                    with lock:
                        num_stores = len(dscd.prototype_stores)
                        multi_sense = sum(
                            1
                            for store in dscd.prototype_stores.values()
                            if hasattr(store, 'centroids')
                            and len(store.centroids) >= 2
                        )
                else:
                    num_stores = len(dscd.prototype_stores)
                    multi_sense = sum(
                        1
                        for store in dscd.prototype_stores.values()
                        if hasattr(store, 'centroids')
                        and len(store.centroids) >= 2
                    )

                if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                    print(
                        f"[INF] DSCD state: {num_stores} tokens, "
                        f"{multi_sense} multi-sense, {len(dscd_homographs)} discovered"
                    )

                if num_stores == 0:
                    print("[INF] CRITICAL WARNING: DSCD prototype stores are EMPTY - run warmup first!")
                    if track_stats:
                        _INFERENCE_STATS.dscd_empty_warnings += 1
                else:
                    dscd_validated = True
        except Exception as e:
            if _DEBUG_DISCOVERY:
                print(f"[INF] DSCD validation failed: {e}")

        with torch.inference_mode():
            raw_dscd_out: Dict[str, Any] = {}

            try:
                if not hasattr(core, "mbart"):
                    raise RuntimeError("Model backend missing .mbart")

                mbart = core.mbart

                encoder_outputs_raw = mbart.model.encoder(
                    input_ids=enc.get("input_ids"),
                    attention_mask=enc.get("attention_mask"),
                )
                cleanup_vars.append("encoder_outputs_raw")

                if hasattr(encoder_outputs_raw, 'last_hidden_state'):
                    encoder_hidden = encoder_outputs_raw.last_hidden_state
                elif isinstance(encoder_outputs_raw, tuple):
                    encoder_hidden = encoder_outputs_raw[0]
                else:
                    encoder_hidden = encoder_outputs_raw
                cleanup_vars.append("encoder_hidden")

                if not isinstance(encoder_hidden, torch.Tensor) or encoder_hidden.dim() != 3:
                    raise RuntimeError(
                        f"Invalid encoder hidden: {type(encoder_hidden)}, "
                        f"shape={encoder_hidden.shape if isinstance(encoder_hidden, torch.Tensor) else 'N/A'}"
                    )

                if _DEBUG_DISCOVERY:
                    print(f"[INF] Encoder hidden: {encoder_hidden.shape}")

                if hasattr(core, "forward_with_explanations"):
                    try:
                        raw_dscd_out = core.forward_with_explanations(
                            input_ids=enc.get("input_ids"),
                            attention_mask=enc.get("attention_mask"),
                            src_texts=src_texts,
                        )
                    except TypeError:
                        raw_dscd_out = core.forward_with_explanations(
                            enc.get("input_ids"),
                            enc.get("attention_mask"),
                            src_texts,
                        )
                else:
                    if _DEBUG_DISCOVERY:
                        print("[INF] forward_with_explanations not found, using forward()")
                    out = core.forward(
                        input_ids=enc.get("input_ids"),
                        attention_mask=enc.get("attention_mask"),
                        src_texts=src_texts,
                        labels=None,
                    )
                    if isinstance(out, dict):
                        raw_dscd_out = _extract_dscd_outputs(out)

                dscd_out = _extract_dscd_outputs(raw_dscd_out)
                if isinstance(raw_dscd_out, dict) and 'sense_augmented_embeddings' in raw_dscd_out:
                    encoder_hidden_adjusted = raw_dscd_out['sense_augmented_embeddings']
                elif 'h_augmented' in dscd_out:
                    encoder_hidden_adjusted = dscd_out['h_augmented']
                else:
                    encoder_hidden_adjusted = encoder_hidden
                cleanup_vars.append("encoder_hidden_adjusted")

                if isinstance(encoder_hidden_adjusted, torch.Tensor):
                    if encoder_hidden_adjusted.shape != encoder_hidden.shape:
                        if _DEBUG_DISCOVERY:
                            print("[INF] Shape mismatch, using original")
                        encoder_hidden_adjusted = encoder_hidden
                else:
                    encoder_hidden_adjusted = encoder_hidden

                if _DEBUG_DISCOVERY:
                    print("[INF] DSCD forward completed")

            except Exception as e:
                if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                    print(f"[INF] DSCD forward error: {e}")
                raw_dscd_out = {}
                if 'encoder_hidden' in locals() and encoder_hidden is not None:
                    encoder_hidden_adjusted = encoder_hidden
                else:
                    encoder_hidden_adjusted = None

            forced_id = _force_english_bos(tokenizer, mbart)
            orig_use_cache = (
                getattr(mbart.config, "use_cache", None)
                if hasattr(mbart, "config")
                else None
            )
            if hasattr(mbart, "config"):
                try:
                    mbart.config.use_cache = True
                except Exception:
                    pass

            try:
                if _DEBUG_DISCOVERY:
                    print("[INF] Generating translation...")

                if encoder_hidden_adjusted is not None and isinstance(
                    encoder_hidden_adjusted, torch.Tensor
                ):
                    encoder_hidden_adjusted = encoder_hidden_adjusted.to(device)

                    from transformers.modeling_outputs import BaseModelOutput

                    encoder_outputs_for_decoder = BaseModelOutput(
                        last_hidden_state=encoder_hidden_adjusted
                    )

                    generated = _safe_generate(
                        mbart,
                        encoder_outputs=encoder_outputs_for_decoder,
                        attention_mask=enc.get("attention_mask"),
                        max_length=min(_MAX_LENGTH, 64),
                        num_beams=2,
                        pad_token_id=getattr(tokenizer, "pad_token_id", None),
                        forced_bos_token_id=forced_id,
                    )
                else:
                    generated = _safe_generate(
                        mbart,
                        input_ids=enc.get("input_ids"),
                        attention_mask=enc.get("attention_mask"),
                        max_length=min(_MAX_LENGTH, 64),
                        num_beams=2,
                        pad_token_id=getattr(tokenizer, "pad_token_id", None),
                        forced_bos_token_id=forced_id,
                    )
                cleanup_vars.append("generated")

                translation = (
                    tokenizer.decode(generated[0], skip_special_tokens=True)
                    if generated is not None and len(generated) > 0
                    else ""
                )

                if _DEBUG_DISCOVERY:
                    print(f"[INF] Translation: {translation[:60]}")

            finally:
                if hasattr(mbart, "config") and orig_use_cache is not None:
                    try:
                        mbart.config.use_cache = orig_use_cache
                    except Exception:
                        pass

            if _DEBUG_DISCOVERY:
                print("[INF] Extracting explanations...")

            dscd_out = _extract_dscd_outputs(raw_dscd_out)
            explanations_list = _get_explanations_list(dscd_out)
            sentence_explanations = (
                explanations_list[0]
                if (isinstance(explanations_list, list) and len(explanations_list) > 0)
                else []
            )

            if _DEBUG_DISCOVERY:
                print(f"[INF] Raw explanations: {len(sentence_explanations)}")

            def _is_real_ambiguity(e: Dict[str, Any]) -> bool:
                try:
                    s = float(e.get("span", 0.0))
                    u = float(e.get("uncertainty", 0.0))
                    return (s > span_th) or (u > u_th)
                except Exception:
                    return False

            real_amb_count = 0
            out_explanations: List[Dict[str, Any]] = []
            filtered_count = 0

            quality_metrics = {
                'total_raw_explanations': len(sentence_explanations)
                if isinstance(sentence_explanations, list)
                else 0,
                'filtered_explanations': 0,
                'high_confidence_count': 0,
                'low_confidence_count': 0,
                'avg_confidence': 0.0,
                'avg_span': 0.0,
                'avg_uncertainty': 0.0,
            }

            confidences: List[float] = []
            spans: List[float] = []
            uncertainties: List[float] = []

            if isinstance(sentence_explanations, list):
                for ex in sentence_explanations:
                    try:
                        word = ex.get('ambiguous_word', ex.get('token', ''))
                        if isinstance(word, str):
                            clean_word = (
                                word.replace('▁', '')
                                .replace('Ġ', '')
                                .replace('##', '')
                                .replace(' ', '')
                                .strip()
                            )
                            if clean_word:
                                ex['ambiguous_word'] = clean_word

                        if _should_filter_explanation(ex, span_th, u_th):
                            filtered_count += 1
                            continue

                        is_real = _is_real_ambiguity(ex)
                        if is_real:
                            real_amb_count += 1

                        confidence = ex.get('confidence', None)
                        if confidence is None:
                            s = float(ex.get('span', 0.0))
                            u = float(ex.get('uncertainty', 0.0))
                            confidence = max(s, u)
                        confidence = float(confidence)

                        confidences.append(confidence)
                        spans.append(float(ex.get('span', 0.0)))
                        uncertainties.append(float(ex.get('uncertainty', 0.0)))

                        if confidence >= 0.65:
                            quality_metrics['high_confidence_count'] += 1
                        elif confidence < 0.4:
                            quality_metrics['low_confidence_count'] += 1

                        out_explanations.append(
                            {
                                "ambiguous_word": ex.get(
                                    "ambiguous_word", ex.get("token", "N/A")
                                ),
                                "position": ex.get(
                                    "position", ex.get("token_idx", "N/A")
                                ),
                                "explanation": ex.get("explanation", "")
                                or ex.get("explain", "")
                                or "",
                                "uncertainty": float(ex.get("uncertainty", 0.0)),
                                "span": float(ex.get("span", 0.0)),
                                "confidence": confidence,
                                "is_real_amb": bool(is_real),
                            }
                        )
                    except Exception:
                        continue

            quality_metrics['filtered_explanations'] = filtered_count
            if confidences:
                quality_metrics['avg_confidence'] = sum(confidences) / len(confidences)
                quality_metrics['avg_span'] = sum(spans) / len(spans)
                quality_metrics['avg_uncertainty'] = sum(uncertainties) / len(uncertainties)

            if _DEBUG_DISCOVERY:
                print(
                    f"[INF] Final: {len(out_explanations)} explanations "
                    f"(filtered: {filtered_count})"
                )

            result = {
                "input_sentence": input_sentence,
                "translation": translation,
                "ambiguous_words_detected": int(real_amb_count),
                "explanations": out_explanations,
                "quality_metrics": quality_metrics,
                "dscd_validated": dscd_validated,
            }

            if track_stats:
                _INFERENCE_STATS.record_inference(result, dscd_homographs=dscd_homographs)

            return result

    except Exception as e:
        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            print(f"[INF] ERROR: {type(e).__name__}: {str(e)[:200]}")
            try:
                traceback.print_exc()
            except Exception:
                pass

        error_result = {
            "input_sentence": input_sentence,
            "translation": "ERROR DURING TRANSLATION",
            "ambiguous_words_detected": 0,
            "explanations": [],
            "quality_metrics": {},
            "dscd_validated": False,
            "error": str(e)[:200],
        }

        if track_stats:
            _INFERENCE_STATS.record_inference(error_result, dscd_homographs=dscd_homographs)

        return error_result

    finally:
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass

        try:
            if gc.isenabled():
                gc.collect()
        except Exception:
            pass


def demonstrate_system(model, tokenizer, sentences: Optional[List[str]] = None):
    """Translate a handful of sentences and print the results to stdout.

    Args:
        model: Model handed straight to ``translate_with_explanations``.
        tokenizer: Tokenizer handed straight to ``translate_with_explanations``.
        sentences: Sentences to translate; when ``None`` a built-in list of
            five Bengali sentences is used.

    Side effects:
        Resets ``_INFERENCE_STATS`` before the run and prints its summary
        after the last sentence.
    """
    demo_inputs = sentences
    if demo_inputs is None:
        demo_inputs = [
            "আমি কল বন্ধ করেছি।",
            "কাল আমি বই কিনব।",
            "পাতা ঝরে পড়েছে।",
            "তিনি ব্যাংক গেছেন।",
            "আজ ভাল আবহাওয়া।",
        ]

    banner = "=" * 80
    print(banner)
    print("TATN DEMO: Translation + Explanations")
    print(banner)

    # Start from clean counters so the summary only covers this demo run.
    _INFERENCE_STATS.reset()

    for sentence in demo_inputs:
        print(f"\nInput: {sentence}")
        outcome = translate_with_explanations(model, tokenizer, sentence)
        print("Translation:", outcome.get("translation", ""))
        print("Ambiguous words detected:", outcome.get("ambiguous_words_detected", 0))

        # Error results carry an empty metrics dict, so the line is skipped.
        metrics = outcome.get("quality_metrics", {})
        if metrics:
            avg_conf = metrics.get('avg_confidence', 0)
            n_high = metrics.get('high_confidence_count', 0)
            n_low = metrics.get('low_confidence_count', 0)
            print(f"Quality: conf={avg_conf:.3f}, high={n_high}, low={n_low}")

        found = outcome.get("explanations")
        if not found:
            print("  No explanations")
            continue

        for number, item in enumerate(found, 1):
            headline = (
                f"  {number}. '{item['ambiguous_word']}' "
                f"pos={item['position']} conf={item.get('confidence', 0):.3f}"
            )
            print(headline)
            print("     ", item.get("explanation", "")[:200])

    print(banner)
    _INFERENCE_STATS.print_summary()


def dscd_discovery_warmup(
    model,
    tokenizer,
    num_sents: int = 8000,
    batch_size: int = 64,
    max_len: Optional[int] = None,
):
    """Push source sentences through the model so DSCD can fill its prototype stores.

    While the warmup runs, ``enable_training_clustering`` is switched on and
    ``n_min`` / ``buffer_size`` are raised to at least 3 / 200 on the DSCD
    component (only for attributes that exist). Sentences are run through
    ``forward_with_explanations`` when the model has it, otherwise through
    ``core.mbart.model.encoder``. A discovery report is printed afterwards and
    the original DSCD settings are restored in all cases.

    Args:
        model: Model (optionally wrapped; ``.module`` is used when
            ``_USE_MULTI_GPU`` is set) expected to expose a ``dscd`` attribute.
        tokenizer: Callable tokenizer producing ``input_ids`` /
            ``attention_mask`` tensors.
        num_sents: Number of sentences to process.
        batch_size: Sentences per forward pass; values below 1 are treated
            as 1.
        max_len: Truncation length; defaults to ``_MAX_LENGTH``.

    Returns:
        None. All results are reported on stdout.

    NOTE(review): the model is switched to ``eval()`` and is not switched back;
    callers that continue training must call ``train()`` themselves.
    """
    from contextlib import nullcontext

    if max_len is None:
        max_len = _MAX_LENGTH

    # range() rejects a zero step and a negative one would silently skip
    # every batch, so clamp to a usable size.
    batch_size = max(1, int(batch_size))

    core = model.module if (_USE_MULTI_GPU and hasattr(model, "module")) else model

    # Bound up front so the ``finally`` block never depends on how far the
    # ``try`` body got.
    dscd = None
    saved_settings = None

    try:
        dscd = getattr(core, "dscd", None)
        if dscd is None:
            print("[WARMUP] Model has no dscd component")
            return

        print("\n" + "=" * 80)
        print("[WARMUP] Starting DSCD discovery warmup")
        print("=" * 80)

        saved_settings = (
            getattr(dscd, "enable_training_clustering", False),
            getattr(dscd, "n_min", None),
            getattr(dscd, "buffer_size", None),
        )

        lock = None
        if hasattr(dscd, 'buffer_lock'):
            lock = dscd.buffer_lock
        elif hasattr(dscd, 'clustering_lock'):
            lock = dscd.clustering_lock
        # One code path for both cases: a no-op context when there is no lock.
        guard = lock if lock else nullcontext()

        with guard:
            initial_token_count = len(dscd.prototype_stores)

        print(f"[WARMUP] Initial prototype stores: {initial_token_count}")

        try:
            if hasattr(dscd, "enable_training_clustering"):
                dscd.enable_training_clustering = True
            if hasattr(dscd, "n_min"):
                dscd.n_min = max(3, int(getattr(dscd, "n_min", 5)))
            if hasattr(dscd, "buffer_size"):
                dscd.buffer_size = max(200, int(getattr(dscd, "buffer_size", 300)))
        except Exception:
            # Best effort: warmup still runs with whatever settings stuck.
            pass

        texts: List[str] = []
        try:
            if "load_and_preprocess_optimized" in globals():
                pairs = load_and_preprocess_optimized(num_sents)
                texts = [bn for (bn, _) in pairs][:num_sents]
            else:
                base = [
                    "আমি কল বন্ধ করেছি।",
                    "কাল আমি বই কিনব।",
                    "পাতা ঝরে পড়েছে।",
                    "তিনি ব্যাংক গেছেন।",
                ]
                while len(texts) < num_sents:
                    texts.extend(base)
                texts = texts[:num_sents]
        except Exception:
            texts = ["আমি কল বন্ধ করেছি।"] * num_sents

        processed = 0
        core.eval()

        print(f"\n[WARMUP] Processing {len(texts)} sentences (batch={batch_size})...")

        start_time = time.time()
        last_print = start_time

        with torch.inference_mode():
            for i in range(0, len(texts), batch_size):
                batch = texts[i : i + batch_size]
                try:
                    enc = tokenizer(
                        batch,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=max_len,
                    )
                    enc = _to_device_batch(enc, _DEVICE)

                    if hasattr(core, "forward_with_explanations"):
                        core.forward_with_explanations(
                            input_ids=enc.get("input_ids"),
                            attention_mask=enc.get("attention_mask"),
                            src_texts=batch,
                        )
                    else:
                        core.mbart.model.encoder(
                            input_ids=enc.get("input_ids"),
                            attention_mask=enc.get("attention_mask"),
                        )

                    processed += len(batch)

                    # Report every 10th batch, or at least every 5 seconds.
                    current_time = time.time()
                    if (i // batch_size) % 10 == 0 or (current_time - last_print) > 5:
                        elapsed = current_time - start_time
                        rate = processed / elapsed if elapsed > 0 else 0
                        eta = (len(texts) - processed) / rate if rate > 0 else 0
                        print(
                            f"[WARMUP] {processed}/{len(texts)} "
                            f"({processed/len(texts)*100:.1f}%) | "
                            f"{rate:.1f} sent/s | ETA {eta:.0f}s"
                        )
                        last_print = current_time

                    del enc

                except Exception as e:
                    # A failing batch is reported and skipped, not fatal.
                    print(
                        f"[WARMUP] Batch {i//batch_size} failed: {str(e)[:100]}"
                    )
                    continue

        total_time = time.time() - start_time
        # A coarse clock (or an empty run) can yield a zero duration.
        overall_rate = processed / total_time if total_time > 0 else 0.0
        print(
            f"\n[WARMUP] Completed in {total_time:.1f}s "
            f"({overall_rate:.1f} sent/s)"
        )
        print("-" * 80)

        try:
            # Copy under the lock so the report works on a stable snapshot.
            with guard:
                stores = dict(dscd.prototype_stores)

            num_types = len(stores)
            sizes = [store.size() for store in stores.values()]
            total_protos = sum(sizes)
            multi = sum(1 for size in sizes if size >= 2)

            print("[WARMUP] Summary:")
            print(f"  - Initial token types: {initial_token_count}")
            print(f"  - Final token types: {num_types}")
            print(f"  - Growth: +{num_types - initial_token_count}")
            print(f"  - Total prototypes: {total_protos}")
            print(f"  - Multi-sense tokens: {multi}")

            if num_types > 0:
                print(f"  - Multi-sense ratio: {multi/num_types:.1%}")

            if hasattr(core, 'asbn') and hasattr(core.asbn, 'get_detailed_stats'):
                try:
                    asbn_stats = core.asbn.get_detailed_stats()
                    print(f"\n[WARMUP] ASBN Stats:")
                    print(f"  - Domain accuracy: {asbn_stats.get('domain_accuracy', 0):.2%}")
                    print(f"  - Source accuracy: {asbn_stats.get('source_accuracy', 0):.2%}")
                    print(f"  - Target accuracy: {asbn_stats.get('target_accuracy', 0):.2%}")
                except Exception:
                    # ASBN statistics are optional extras in the report.
                    pass

            dscd_homographs = _get_dscd_homographs(model)

            print(f"\n[WARMUP] Discovered Homographs: {len(dscd_homographs)}")
            if dscd_homographs:
                print(f"  Sample: {list(dscd_homographs)[:10]}")

            reference_found = dscd_homographs.intersection(_HOMOGRAPH_REFERENCE_LIST)
            reference_total = len(_HOMOGRAPH_REFERENCE_LIST)
            # An empty reference list must not abort the whole report.
            coverage = len(reference_found) / reference_total if reference_total else 0.0

            print(f"\n[WARMUP] Reference List Comparison:")
            print(f"  - Reference list: {reference_total} words")
            print(f"  - Found in DSCD: {len(reference_found)}")
            print(f"  - Coverage: {coverage:.1%}")

            # The empty-store check comes first: when the stores started
            # empty and stayed empty both conditions hold, and "nothing in
            # the stores" is the more accurate diagnosis.
            if num_types == 0:
                print("\n[WARMUP] CRITICAL: NO PROTOTYPES IN STORES - DSCD not working")
            elif num_types == initial_token_count:
                print("\n[WARMUP] CRITICAL: NO NEW PROTOTYPES CREATED - check DSCD training mode")
            elif len(reference_found) < reference_total // 2:
                print("\n[WARMUP] WARNING: < 50% reference coverage - may need more data")
            else:
                print("\n[WARMUP] SUCCESS - DSCD ready for inference")

        except Exception as e:
            print(f"[WARMUP] Validation failed: {e}")

    finally:
        try:
            # Restore only what was actually captured before the overrides.
            if dscd is not None and saved_settings is not None:
                orig_enable, orig_n_min, orig_buffer = saved_settings
                if hasattr(dscd, "enable_training_clustering"):
                    dscd.enable_training_clustering = orig_enable
                if hasattr(dscd, "n_min") and orig_n_min is not None:
                    dscd.n_min = orig_n_min
                if hasattr(dscd, "buffer_size") and orig_buffer is not None:
                    dscd.buffer_size = orig_buffer
        except Exception:
            pass

        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass

        try:
            if gc.isenabled():
                gc.collect()
        except Exception:
            pass

        print("=" * 80 + "\n")


def load_checkpoint_for_resume(
    model: torch.nn.Module, optimizer, checkpoint_path: str
) -> Tuple[bool, int, int, float]:
    """Restore model, optimizer and DSCD state from a checkpoint file.

    Args:
        model: Model to load weights into; ``.module`` is used when
            ``_USE_MULTI_GPU`` is set and the attribute exists.
        optimizer: Optimizer to restore, or ``None`` to skip that step.
        checkpoint_path: Path of the file written by ``torch.save``.

    Returns:
        ``(loaded, epoch, step, avg_loss)``. ``loaded`` is ``False`` (with
        zeros for the rest) when the file is missing, cannot be read, or does
        not contain a dict. Failures of the individual restore steps (model
        weights, optimizer, DSCD) are printed but do not change ``loaded``.
    """
    from contextlib import nullcontext

    failure = (False, 0, 0, 0.0)

    if not os.path.exists(checkpoint_path):
        print(f"[CHECKPOINT] Not found: {checkpoint_path}")
        return failure

    try:
        # NOTE(security): torch.load unpickles arbitrary objects; only load
        # checkpoints from trusted sources.
        ckpt = torch.load(checkpoint_path, map_location=_DEVICE)
    except Exception as e:
        print(f"[CHECKPOINT] Load failed: {e}")
        return failure

    if not isinstance(ckpt, dict):
        # Everything below reads the checkpoint through dict lookups.
        print(
            f"[CHECKPOINT] Load failed: unsupported checkpoint type "
            f"{type(ckpt).__name__}"
        )
        return failure

    core = model.module if (_USE_MULTI_GPU and hasattr(model, "module")) else model

    state = ckpt.get("model_state_dict", ckpt)
    prefix = "module."

    # With strict=False, keys that still carry a wrapper "module." prefix are
    # silently ignored, so the prefix has to be removed *before* loading.
    # It is only stripped when the target itself does not use it.
    if isinstance(state, dict) and any(
        isinstance(k, str) and k.startswith(prefix) for k in state
    ):
        try:
            target_uses_prefix = any(
                k.startswith(prefix) for k in core.state_dict()
            )
        except Exception:
            target_uses_prefix = False
        if not target_uses_prefix:
            # Slice instead of str.replace so inner names such as
            # "submodule." are left untouched.
            state = {
                (k[len(prefix):] if isinstance(k, str) and k.startswith(prefix) else k): v
                for k, v in state.items()
            }

    try:
        load_result = core.load_state_dict(state, strict=False)
        missing = list(getattr(load_result, "missing_keys", None) or [])
        unexpected = list(getattr(load_result, "unexpected_keys", None) or [])
        if missing or unexpected:
            print(
                f"[CHECKPOINT] Partial model load: {len(missing)} missing, "
                f"{len(unexpected)} unexpected keys"
            )
    except Exception as e:
        print(f"[CHECKPOINT] model.load_state_dict failed: {e}")

    try:
        if optimizer is not None and "optimizer_state_dict" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    except Exception as e:
        print(f"[CHECKPOINT] optimizer.load_state_dict failed: {e}")

    try:
        if "dscd_state" in ckpt and ckpt["dscd_state"]:
            dscd_state = ckpt["dscd_state"]

            print("[CHECKPOINT] Restoring DSCD...")
            dscd = core.dscd if hasattr(core, 'dscd') else None

            if dscd is not None and hasattr(dscd, 'load_state_dict'):
                lock = None
                if hasattr(dscd, 'buffer_lock'):
                    lock = dscd.buffer_lock
                elif hasattr(dscd, 'clustering_lock'):
                    lock = dscd.clustering_lock

                # Restore and measure under the lock when one exists.
                with (lock if lock else nullcontext()):
                    dscd.load_state_dict(dscd_state)
                    num_tokens = len(dscd.prototype_stores)
                    sizes = [
                        store.size() for store in dscd.prototype_stores.values()
                    ]

                total_protos = sum(sizes)
                multi_sense = sum(1 for size in sizes if size >= 2)

                print("[CHECKPOINT] DSCD restored:")
                print(f"  - Tokens: {num_tokens}")
                print(f"  - Prototypes: {total_protos}")
                print(f"  - Multi-sense: {multi_sense}")

                if num_tokens == 0 or total_protos == 0:
                    print(
                        "[CHECKPOINT] CRITICAL WARNING: DSCD state empty - run warmup before inference!"
                    )
            else:
                print("[CHECKPOINT] Model has no dscd.load_state_dict")
        else:
            print("[CHECKPOINT] No DSCD state in checkpoint")
    except Exception as e:
        print(f"[CHECKPOINT] DSCD restore failed: {e}")

    def _first_number(keys, cast, default):
        # First key that is present and convertible wins; a stored None or
        # other junk falls through to the next alias instead of raising.
        for key in keys:
            if key in ckpt:
                try:
                    return cast(ckpt[key])
                except (TypeError, ValueError):
                    continue
        return default

    epoch = _first_number(("epochs_trained", "epoch"), int, 0)
    step = _first_number(("global_steps", "global_step", "step"), int, 0)
    avg_loss = _first_number(
        ("final_train_loss", "avg_epoch_loss", "avg_loss"), float, 0.0
    )

    print(f"[CHECKPOINT] Loaded: epoch={epoch} step={step} loss={avg_loss:.6f}")
    return True, epoch, step, avg_loss


# Cell 8 readiness banner: echoes the ambiguity thresholds and the size of the
# homograph reference list currently in effect.
print(
    "\n" + "=" * 80,
    "Cell 8: Inference pipeline ready (PURE UNSUPERVISED) - FIXED",
    "=" * 80,
    "Configuration:",
    f"  - Span threshold: {_REAL_AMB_SPAN_THRESHOLD}",
    f"  - Uncertainty threshold: {_REAL_AMB_UNCERTAINTY_THRESHOLD}",
    f"  - Reference list: {len(_HOMOGRAPH_REFERENCE_LIST)} words",
    "=" * 80 + "\n",
    sep="\n",
)


In [None]:
# ==============================================================================
# CELL 9: COMPREHENSIVE TESTING & EVALUATION (PURE DATA-DRIVEN) - FIXED
# ==============================================================================
from typing import Dict, List, Tuple, Optional, Any
import torch
import traceback
import time
import functools
from collections import defaultdict

def _cell9_setting(name, cast, cast_errors, make_default):
    """Return ``cast(<global called name>)`` or ``make_default()``.

    The default is produced (lazily) when the global is not defined, or when
    ``cast`` raises ``NameError`` or one of ``cast_errors``.
    """
    scope = globals()
    if name not in scope:
        return make_default()
    try:
        return cast(scope[name])
    except (NameError,) + tuple(cast_errors):
        return make_default()


# Each private setting mirrors an optional global from the configuration cell
# and falls back to a built-in default when that global is absent or unusable.
_USE_MULTI_GPU = _cell9_setting(
    "USE_MULTI_GPU",
    bool,
    (TypeError,),
    lambda: torch.cuda.is_available() and torch.cuda.device_count() > 1,
)
_SOURCE_LANGUAGE = _cell9_setting(
    "SOURCE_LANGUAGE", str, (TypeError,), lambda: "bn"
)
_VERBOSE_LOGGING = _cell9_setting(
    "VERBOSE_LOGGING", bool, (TypeError,), lambda: False
)
_DEBUG_DISCOVERY = _cell9_setting(
    "DEBUG_DISCOVERY", bool, (TypeError,), lambda: False
)
_DEBUG_TIMING = _cell9_setting(
    "DEBUG_TIMING", bool, (TypeError,), lambda: False
)
_SPAN_THRESHOLD = _cell9_setting(
    "SPAN_THRESHOLD", float, (ValueError, TypeError), lambda: 0.12
)
# The uncertainty threshold is sourced from the TAU_LOW global.
_UNCERTAINTY_THRESHOLD = _cell9_setting(
    "TAU_LOW", float, (ValueError, TypeError), lambda: 0.15
)
_HOMOGRAPH_REFERENCE_LIST = _cell9_setting(
    "HOMOGRAPH_REFERENCE_LIST_BN",
    lambda words: set(str(w).lower() for w in words),
    (TypeError,),
    lambda: {
        str(w).lower()
        for w in (
            "কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা",
            "পানি", "দল", "বাজার", "নাম", "কথা", "বই", "ঘর", "মন", "হাত",
        )
    },
)

# The resolver is only needed while this cell runs; keep the namespace clean.
del _cell9_setting

def _get_cluster_count(model: torch.nn.Module) -> int:
    """Return the number of entries in the model's DSCD ``prototype_stores``.

    ``model.module`` is used when present. The count is taken while holding
    ``buffer_lock`` (or, failing that, ``clustering_lock``) when the DSCD
    component has one. Returns 0 when there is no DSCD component, no stores,
    or anything goes wrong.
    """
    try:
        core = getattr(model, "module", model)
        dscd = getattr(core, "dscd", None)
        if dscd is None:
            return 0

        # The first lock attribute that exists wins, even if its value is falsy.
        guard = None
        for lock_name in ('buffer_lock', 'clustering_lock'):
            if hasattr(dscd, lock_name):
                guard = getattr(dscd, lock_name)
                break

        def _count() -> int:
            return len(getattr(dscd, "prototype_stores", {}) or {})

        if not guard:
            return _count()
        with guard:
            return _count()
    except Exception:
        return 0

def _get_dscd_homographs(model: torch.nn.Module) -> set:
    """Return the set of tokens DSCD currently treats as multi-sense.

    When the DSCD component offers ``get_discovered_homographs`` its result is
    returned (coerced to a ``set`` if it is some other iterable). Otherwise,
    or when that call fails, ``prototype_stores`` is scanned under the DSCD
    lock (if any) and every token whose store reports ``size() >= 2`` is
    collected after stripping subword markers (``▁``, ``Ġ``, ``##``), removing
    spaces and lower-casing.

    Tokens that are empty after cleaning are skipped, so the result never
    contains ``""``. Returns an empty set when there is no DSCD component or
    on any unexpected error.
    """
    try:
        core = model.module if hasattr(model, "module") else model
        dscd = getattr(core, "dscd", None)
        if dscd is None:
            return set()

        if hasattr(dscd, 'get_discovered_homographs'):
            try:
                discovered = dscd.get_discovered_homographs()
                # Callers rely on set operations, so honour the annotation.
                return discovered if isinstance(discovered, set) else set(discovered)
            except Exception:
                # Fall back to scanning the stores directly.
                pass

        def _scan() -> set:
            found = set()
            prototype_stores = getattr(dscd, "prototype_stores", {}) or {}
            for token, store in prototype_stores.items():
                try:
                    if not (hasattr(store, 'size') and store.size() >= 2):
                        continue
                    clean_token = (
                        str(token)
                        .replace('▁', '')
                        .replace('Ġ', '')
                        .replace('##', '')
                        .replace(' ', '')
                        .strip()
                        .lower()
                    )
                    # A bare subword marker cleans down to nothing; that is
                    # not a word and must not count as a homograph.
                    if clean_token:
                        found.add(clean_token)
                except Exception:
                    continue
            return found

        lock = None
        if hasattr(dscd, 'buffer_lock'):
            lock = dscd.buffer_lock
        elif hasattr(dscd, 'clustering_lock'):
            lock = dscd.clustering_lock

        if lock:
            with lock:
                return _scan()
        return _scan()
    except Exception:
        return set()

def _print_top_clusters(model: torch.nn.Module, top_n: int = 5):
    """Print a table of the ``top_n`` DSCD prototype stores by total count.

    The stores are copied (under ``buffer_lock`` / ``clustering_lock`` when
    the DSCD component has one) and ranked by the sum of each store's
    ``counts``. Each row shows the token (first 12 characters), that sum,
    the number of ``centroids`` and the store's ``mu`` / ``tau``.
    Nothing is printed when the model has no DSCD component; any error is
    swallowed and only reported when ``_DEBUG_DISCOVERY`` is set.
    """
    try:
        core = getattr(model, "module", model)
        dscd = getattr(core, "dscd", None)
        if dscd is None:
            return

        guard = None
        for lock_name in ('buffer_lock', 'clustering_lock'):
            if hasattr(dscd, lock_name):
                guard = getattr(dscd, lock_name)
                break

        def _snapshot() -> dict:
            return dict(getattr(dscd, "prototype_stores", {}) or {})

        if guard:
            with guard:
                snapshot = _snapshot()
        else:
            snapshot = _snapshot()

        if not snapshot:
            print("[CLUSTER] No clusters found yet")
            return

        # Rows are (token, total count, prototype count, mu, tau).
        rows = []
        for token, store in snapshot.items():
            try:
                occurrences = sum(getattr(store, "counts", []))
            except Exception:
                occurrences = 0
            try:
                proto_total = len(getattr(store, "centroids", []))
            except Exception:
                proto_total = 0
            rows.append(
                (
                    token,
                    occurrences,
                    proto_total,
                    getattr(store, "mu", 0.0),
                    getattr(store, "tau", 0.0),
                )
            )

        # Stable sort: stores with equal counts keep their insertion order.
        ranked = sorted(rows, key=lambda row: row[1], reverse=True)
        rule = "-" * 90

        print(f"\n[CLUSTER] Top {min(top_n, len(ranked))} clusters:")
        print(rule)
        print(f"{'Rank':<6}{'Token':<15}{'Count':<12}{'Protos':<10}{'Mu':<15}{'Tau':<12}")
        print(rule)

        for rank, (token, occurrences, proto_total, mu, tau) in enumerate(
            ranked[:top_n], 1
        ):
            label = str(token)[:12]
            print(
                f"{rank:<6}{label:<15}{occurrences:<12}{proto_total:<10}"
                f"{mu:<15.6f}{tau:<12.6f}"
            )

        print(rule)

    except Exception as e:
        if _DEBUG_DISCOVERY:
            print(f"[CLUSTER] Error: {str(e)[:100]}")

def _timed(func):
    """Decorator that prints the wall-clock duration of ``func`` calls.

    Timing is only measured and printed while ``_DEBUG_TIMING`` is truthy
    (checked on every call); otherwise the call is passed straight through.
    The wrapped function's metadata is preserved via ``functools.wraps``.
    """
    @functools.wraps(func)
    def timed_call(*args, **kwargs):
        if not _DEBUG_TIMING:
            return func(*args, **kwargs)

        began = time.time()
        outcome = func(*args, **kwargs)
        elapsed = time.time() - began
        print(f"[TIMING] {func.__name__}: {elapsed:.2f}s")
        return outcome

    return timed_call

@torch.inference_mode()
@_timed
def comprehensive_post_training_testing(
    model: torch.nn.Module,
    tokenizer,
    run_warmup: bool = True,
    compare_baseline: bool = False,
    baseline_metrics: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Run the built-in homograph test sentences through the model and report.

    Every test sentence is passed to ``translate_with_explanations``; the
    translation, its explanations and any errors are printed and aggregated.
    A summary (translation success, explanation quality, homograph coverage,
    DSCD prototype counts, ASBN/TRG statistics, timing, errors and health
    warnings) is printed at the end.

    Side effects: the evaluated model is switched to ``eval()`` mode and is
    not switched back, ``tokenizer.src_lang`` is set to ``_SOURCE_LANGUAGE``
    when possible, and the CUDA cache is emptied before returning.

    Args:
        model: Model to evaluate. When ``_USE_MULTI_GPU`` is set and the model
            exposes ``module``, that wrapped module is the one evaluated.
        tokenizer: Tokenizer forwarded to ``translate_with_explanations``.
        run_warmup: When true and the model's ``dscd`` component holds no
            prototype stores, run ``dscd_discovery_warmup`` (if it is defined
            at module level) before the tests.
        compare_baseline: When true and ``baseline_metrics`` is a non-empty
            dict, print deltas against it.
        baseline_metrics: Earlier metrics to compare against. The keys read
            are ``success_rate_pct``, ``total_explanations``,
            ``quality_metrics['avg_confidence']`` and
            ``homograph_tracking['explained_from_dscd_rate']``.

    Returns:
        Dict with the keys ``total_tests``, ``successful_translations``,
        ``success_rate_pct``, ``total_explanations``, ``total_high_span``,
        ``total_real_ambiguous``, ``dscd_stats``, ``quality_metrics``,
        ``homograph_tracking``, ``error_tracking``, ``asbn_stats``,
        ``trg_stats``, ``discovery_validated`` and ``timing_metrics``.
    """
    # Function-scope import so the block does not depend on a file-level import.
    from contextlib import nullcontext

    print("\n" + "=" * 80)
    print("COMPREHENSIVE POST-TRAINING EVALUATION (Pure Data-Driven)")
    print("=" * 80)

    # (source sentence, expected translation, description, expected homographs)
    test_sentences: List[Tuple[str, str, str, List[str]]] = [
        ("আমি কল বন্ধ করেছি।", "I turned off the tap", "কল = tap/call", ["কল"]),
        ("কাল আমি বই কিনব।", "Tomorrow I will buy a book", "কাল = tomorrow/yesterday", ["কাল"]),
        ("পাতা ঝরে পড়েছে।", "The leaf has fallen", "পাতা = leaf/page", ["পাতা"]),
        ("তিনি ব্যাংক গেছেন।", "He went to the bank", "ব্যাংক = bank/embankment", ["ব্যাংক"]),
        ("ফল খুব সুস্বাদু।", "The fruit is delicious", "ফল = fruit/result", ["ফল"]),
        ("মাথা ব্যথা করছে।", "Head is aching", "মাথা = head/top", ["মাথা"]),
        ("কল থেকে কল এসেছে।", "A call came from the tap", "Multiple কল", ["কল"]),
        ("কালকে কাল মেঘ দেখা গেছে।", "Yesterday black clouds were seen", "Multiple কাল", ["কাল"]),
        ("আজ ভাল আবহাওয়া।", "Weather is good today", "Simple", []),
        ("আমি ভালো আছি।", "I am fine", "Simple", []),
        ("সে খুব মিষ্টি কথা বলে।", "She speaks sweetly", "Simple", []),
        ("এটা আমার বই।", "This is my book", "Simple", []),
        ("তিনি ব্যাংকে কাজ করেন এবং ব্যাংকে বসে থাকেন।",
         "He works at the bank and sits on the embankment",
         "Long with multiple", ["ব্যাংক"]),
    ]

    core_model = model.module if (_USE_MULTI_GPU and hasattr(model, "module")) else model
    core_model.eval()

    def _dscd_guard(dscd_obj):
        # Context manager protecting reads of the DSCD stores: `buffer_lock`
        # is preferred, then `clustering_lock`; a no-op context otherwise.
        lock = None
        if hasattr(dscd_obj, 'buffer_lock'):
            lock = dscd_obj.buffer_lock
        elif hasattr(dscd_obj, 'clustering_lock'):
            lock = dscd_obj.clustering_lock
        return lock if lock else nullcontext()

    quality_metrics = {
        'total_confidence': 0.0,
        'confidence_samples': 0,
        'high_confidence_count': 0,
        'medium_confidence_count': 0,
        'low_confidence_count': 0,
        'confidences': [],
        'spans': [],
        'uncertainties': [],
    }

    homograph_tracking = {
        'test_expected_homographs': set(),
        'dscd_discovered_homographs': set(),
        'explained_homographs': set(),
        'homograph_explanations': defaultdict(list),
    }

    error_tracking = {
        'translation_failures': 0,
        'dscd_failures': 0,
        'trg_failures': 0,
        'timeout_errors': 0,
        'oom_errors': 0,
        'other_errors': 0,
        'error_details': [],
        'per_test_status': [],
    }

    timing_metrics = {
        'total_time': 0.0,
        'per_test_times': [],
        'avg_test_time': 0.0,
    }

    # --- Component status: discovery log, ASBN stats, TRG stats -------------
    discovery_validated = False
    try:
        dscd = getattr(core_model, "dscd", None)
        if dscd and hasattr(dscd, 'discovered_log') and dscd.discovered_log:
            discovery_validated = True
            last_discovery = dscd.discovered_log[-1]
            discovered = last_discovery.get('discovered', 0)
            candidates = last_discovery.get('candidates', 0)
            if _DEBUG_DISCOVERY:
                print(f"[EVAL] Discovery log: {discovered}/{candidates} homographs")
        else:
            if _DEBUG_DISCOVERY:
                print(f"[EVAL] No discovery log found")
    except Exception as e:
        if _DEBUG_DISCOVERY:
            print(f"[EVAL] Discovery validation failed: {e}")

    asbn_stats: Dict[str, Any] = {}
    try:
        asbn = getattr(core_model, "asbn", None)
        if asbn and hasattr(asbn, 'get_detailed_stats'):
            asbn_stats = asbn.get_detailed_stats()
        elif asbn and hasattr(asbn, 'get_asbn_stats'):
            asbn_stats = asbn.get_asbn_stats()

        # Normalise None / empty results to an empty dict.
        if not asbn_stats:
            asbn_stats = {}

        if asbn_stats and _DEBUG_DISCOVERY:
            print(f"[EVAL] ASBN: domain_acc={asbn_stats.get('domain_accuracy', 0):.2%}")
            if 'source_accuracy' in asbn_stats and 'target_accuracy' in asbn_stats:
                print(f"[EVAL] ASBN: source_acc={asbn_stats.get('source_accuracy', 0):.2%}, target_acc={asbn_stats.get('target_accuracy', 0):.2%}")
    except Exception as e:
        if _DEBUG_DISCOVERY:
            print(f"[EVAL] ASBN stats failed: {e}")
        asbn_stats = {}

    trg_stats: Dict[str, Any] = {}
    try:
        trg = getattr(core_model, "trg_system", None)
        if trg and hasattr(trg, 'get_statistics'):
            trg_stats = trg.get_statistics()
            if _DEBUG_DISCOVERY:
                print(f"[EVAL] TRG: {trg_stats.get('explanations_generated', 0)} total")
    except Exception as e:
        if _DEBUG_DISCOVERY:
            print(f"[EVAL] TRG stats failed: {e}")

    homograph_tracking['dscd_discovered_homographs'] = _get_dscd_homographs(core_model)
    print(f"[EVAL] DSCD discovered: {len(homograph_tracking['dscd_discovered_homographs'])} homographs")
    if homograph_tracking['dscd_discovered_homographs'] and _DEBUG_DISCOVERY:
        print(f"[EVAL] Sample: {list(homograph_tracking['dscd_discovered_homographs'])[:10]}")

    # --- Optional warmup when no prototype stores exist yet -----------------
    if run_warmup:
        try:
            dscd = getattr(core_model, "dscd", None)
            if dscd is not None:
                with _dscd_guard(dscd):
                    stores = getattr(dscd, "prototype_stores", None)
                    store_count = len(stores) if stores else 0
                    clustering_enabled = getattr(dscd, "enable_training_clustering", False)

                if store_count == 0 and 'dscd_discovery_warmup' in globals():
                    if not clustering_enabled:
                        print("[EVAL] CRITICAL WARNING: Clustering disabled - warmup may not discover prototypes!")
                        print("[EVAL] Set DSCD_ENABLE_TRAINING_CLUSTERING=True before training")
                    print("[EVAL] Running warmup (num_sents=4000)...")
                    try:
                        dscd_discovery_warmup(model, tokenizer, num_sents=4000, batch_size=64)
                        homograph_tracking['dscd_discovered_homographs'] = _get_dscd_homographs(core_model)
                    except Exception as e:
                        print(f"[EVAL] Warmup failed: {e}")
                    else:
                        # The flag was computed before warmup; re-check so a
                        # log written during warmup (if any) does not leave a
                        # stale "Discovery log missing" warning.
                        if getattr(dscd, 'discovered_log', None):
                            discovery_validated = True
        except Exception:
            if _DEBUG_DISCOVERY:
                try:
                    traceback.print_exc()
                except Exception:
                    pass

    total_tests = len(test_sentences)
    successful_translations = 0
    total_explanations = 0
    total_high_span = 0
    total_real_ambiguous = 0

    print(f"\n[EVAL] Running {total_tests} tests...")
    print("-" * 80)

    try:
        tokenizer.src_lang = _SOURCE_LANGUAGE
    except Exception:
        pass

    def _is_real_amb(expl: Dict[str, Any]) -> bool:
        # An explanation counts as "real" when either score reaches its threshold.
        try:
            s = float(expl.get("span", 0.0))
            u = float(expl.get("uncertainty", 0.0))
            return (s >= _SPAN_THRESHOLD) or (u >= _UNCERTAINTY_THRESHOLD)
        except Exception:
            return False

    def _compute_similarity(pred: str, expected: str) -> float:
        # Fraction of the expected (lower-cased, whitespace-split) words that
        # also occur in the prediction.
        try:
            if not pred or not expected:
                return 0.0
            pred_words = set(str(pred).lower().split())
            exp_words = set(str(expected).lower().split())
            if not exp_words:
                return 0.0
            overlap = len(pred_words & exp_words)
            return overlap / len(exp_words)
        except Exception:
            return 0.0

    for _, _, _, expected_homos in test_sentences:
        homograph_tracking['test_expected_homographs'].update([h.lower() for h in expected_homos])

    eval_start = time.time()

    # --- Per-sentence evaluation --------------------------------------------
    for idx, (src_text, expected_translation, desc, expected_homos) in enumerate(test_sentences, 1):
        test_start = time.time()

        print(f"\nTest {idx}/{total_tests}: {desc}")
        print("=" * 60)

        test_status = {
            'test_id': idx,
            'success': False,
            'translation_ok': False,
            'explanations_count': 0,
            'error': None,
        }

        try:
            if 'translate_with_explanations' not in globals():
                print("[EVAL] translate_with_explanations not available")
                error_tracking['other_errors'] += 1
                test_status['error'] = 'function_not_available'
                error_tracking['per_test_status'].append(test_status)
                continue

            result = translate_with_explanations(
                core_model if core_model is not None else model,
                tokenizer,
                src_text,
                span_threshold=_SPAN_THRESHOLD,
                uncertainty_threshold=_UNCERTAINTY_THRESHOLD
            )

            translation = str(result.get("translation", "") or "")
            amb_count = int(result.get("ambiguous_words_detected", 0))
            explanations = result.get("explanations", []) or []

            similarity = _compute_similarity(translation, expected_translation)

            print(f"Input: {src_text}")
            print(f"Expected: {expected_translation}")
            print(f"Translation: {translation}")
            print(f"Similarity: {similarity:.1%}")
            print(f"Ambiguous: {amb_count}")

            if explanations:
                print("\nExplanations:")
                high_span_local = 0
                real_amb_local = 0

                for j, expl in enumerate(explanations, 1):
                    span_val = float(expl.get("span", 0.0))
                    u_val = float(expl.get("uncertainty", 0.0))
                    # Without an explicit confidence, use the larger score.
                    conf_val = float(expl.get("confidence", max(span_val, u_val)))

                    marker = f"[S>={_SPAN_THRESHOLD:.2f}]" if span_val >= _SPAN_THRESHOLD else "           "

                    word = expl.get("ambiguous_word", expl.get("token", "N/A"))
                    pos = expl.get("position", expl.get("token_idx", "N/A"))

                    print(f"  {j}. {marker} '{word}' @ {pos}")
                    print(f"       conf={conf_val:.3f} | U={u_val:.3f} | S={span_val:.3f}")
                    text = str(expl.get("explanation", ""))
                    if len(text) > 120:
                        text = text[:120] + "..."
                    print(f"       {text}")

                    quality_metrics['confidences'].append(conf_val)
                    quality_metrics['spans'].append(span_val)
                    quality_metrics['uncertainties'].append(u_val)
                    quality_metrics['total_confidence'] += conf_val
                    quality_metrics['confidence_samples'] += 1

                    if conf_val >= 0.65:
                        quality_metrics['high_confidence_count'] += 1
                    elif conf_val >= 0.4:
                        quality_metrics['medium_confidence_count'] += 1
                    else:
                        quality_metrics['low_confidence_count'] += 1

                    if span_val >= _SPAN_THRESHOLD:
                        high_span_local += 1
                    if _is_real_amb(expl):
                        real_amb_local += 1

                    # Strip sub-word markers so the word can be matched
                    # against the discovered / reference homograph sets.
                    clean_word = (
                        str(word)
                        .replace('▁', '')
                        .replace('Ġ', '')
                        .replace('##', '')
                        .replace(' ', '')
                        .strip()
                        .lower()
                    )
                    homograph_tracking['explained_homographs'].add(clean_word)
                    homograph_tracking['homograph_explanations'][clean_word].append({
                        'sentence': src_text,
                        'confidence': conf_val,
                        'span': span_val,
                        'uncertainty': u_val,
                    })

                total_explanations += len(explanations)
                total_high_span += high_span_local
                total_real_ambiguous += real_amb_local
                test_status['explanations_count'] = len(explanations)
            else:
                print("No explanations")

            # A translation is a success unless it is blank or one of the
            # known failure sentinel strings.
            if translation and translation.strip() and translation not in (
                "Error occurred",
                "Translation generation failed",
                "ERROR DURING TRANSLATION",
            ):
                successful_translations += 1
                test_status['translation_ok'] = True
                test_status['success'] = True
                print("Success")
            else:
                print("Translation failed")
                error_tracking['translation_failures'] += 1
                test_status['error'] = 'translation_failed'

        except RuntimeError as e:
            error_str = str(e).lower()
            if "out of memory" in error_str:
                print(f"[EVAL] OOM: {str(e)[:100]}")
                error_tracking['oom_errors'] += 1
                test_status['error'] = 'oom'
            elif "timeout" in error_str:
                print(f"[EVAL] Timeout: {str(e)[:100]}")
                error_tracking['timeout_errors'] += 1
                test_status['error'] = 'timeout'
            else:
                print(f"[EVAL] Runtime: {type(e).__name__}")
                error_tracking['other_errors'] += 1
                test_status['error'] = 'runtime'
            error_tracking['error_details'].append(f"Test {idx}: {type(e).__name__}")
        except Exception as e:
            print(f"[EVAL] Error: {type(e).__name__}")
            error_tracking['other_errors'] += 1
            test_status['error'] = type(e).__name__
            error_tracking['error_details'].append(f"Test {idx}: {type(e).__name__}")
            if _DEBUG_DISCOVERY:
                try:
                    traceback.print_exc()
                except Exception:
                    pass

        error_tracking['per_test_status'].append(test_status)

        test_time = time.time() - test_start
        timing_metrics['per_test_times'].append(test_time)

        print("-" * 60)

    # --- Aggregation ----------------------------------------------------------
    timing_metrics['total_time'] = time.time() - eval_start
    if timing_metrics['per_test_times']:
        timing_metrics['avg_test_time'] = (
            sum(timing_metrics['per_test_times']) / len(timing_metrics['per_test_times'])
        )

    # Computed once and reused by the summary, the baseline comparison and the
    # returned metrics; guarded so an empty test list cannot divide by zero.
    success_rate_pct = (
        successful_translations / total_tests * 100.0
    ) if total_tests > 0 else 0.0

    if quality_metrics['confidence_samples'] > 0:
        quality_metrics['avg_confidence'] = (
            quality_metrics['total_confidence'] / quality_metrics['confidence_samples']
        )
        quality_metrics['avg_span'] = (
            sum(quality_metrics['spans']) / len(quality_metrics['spans'])
            if quality_metrics['spans']
            else 0.0
        )
        quality_metrics['avg_uncertainty'] = (
            sum(quality_metrics['uncertainties']) / len(quality_metrics['uncertainties'])
            if quality_metrics['uncertainties']
            else 0.0
        )

        if quality_metrics['confidences']:
            # Index-based quartiles; with fewer than four samples P25/P75
            # fall back to the minimum/maximum.
            sorted_conf = sorted(quality_metrics['confidences'])
            n = len(sorted_conf)
            quality_metrics['confidence_p25'] = sorted_conf[n // 4] if n >= 4 else sorted_conf[0]
            quality_metrics['confidence_p50'] = sorted_conf[n // 2]
            quality_metrics['confidence_p75'] = sorted_conf[3 * n // 4] if n >= 4 else sorted_conf[-1]
    else:
        quality_metrics['avg_confidence'] = 0.0
        quality_metrics['avg_span'] = 0.0
        quality_metrics['avg_uncertainty'] = 0.0

    explained_from_dscd = homograph_tracking['explained_homographs'].intersection(
        homograph_tracking['dscd_discovered_homographs']
    )

    test_expected_discovered = homograph_tracking['test_expected_homographs'].intersection(
        homograph_tracking['dscd_discovered_homographs']
    )

    reference_discovered = _HOMOGRAPH_REFERENCE_LIST.intersection(
        homograph_tracking['dscd_discovered_homographs']
    )

    homograph_tracking['explained_from_dscd_rate'] = (
        len(explained_from_dscd) / len(homograph_tracking['dscd_discovered_homographs'])
        if homograph_tracking['dscd_discovered_homographs']
        else 0.0
    )
    homograph_tracking['test_expected_discovery_rate'] = (
        len(test_expected_discovered) / len(homograph_tracking['test_expected_homographs'])
        if homograph_tracking['test_expected_homographs']
        else 0.0
    )
    homograph_tracking['reference_discovery_rate'] = (
        len(reference_discovered) / len(_HOMOGRAPH_REFERENCE_LIST)
        if _HOMOGRAPH_REFERENCE_LIST
        else 0.0
    )

    try:
        dscd_stats = {"total_words": 0, "multi_sense_words": 0, "total_prototypes": 0}
        dscd = getattr(core_model, "dscd", None)
        if dscd is not None and hasattr(dscd, "prototype_stores"):
            # Snapshot the mapping under the lock, then iterate the copy.
            with _dscd_guard(dscd):
                stores = dict(getattr(dscd, "prototype_stores") or {})

            total_words = 0
            multi = 0
            total_protos = 0
            for key, store in stores.items():
                try:
                    if hasattr(store, "size"):
                        sz = int(store.size())
                    else:
                        sz = 0
                except Exception:
                    sz = 0
                total_words += 1
                total_protos += sz
                if sz >= 2:
                    multi += 1
            dscd_stats = {
                "total_words": total_words,
                "multi_sense_words": multi,
                "total_prototypes": total_protos,
            }
    except Exception as e:
        if _DEBUG_DISCOVERY:
            print(f"[EVAL] DSCD stats failed: {e}")
        dscd_stats = {"total_words": 0, "multi_sense_words": 0, "total_prototypes": 0}

    # --- Summary report -------------------------------------------------------
    print("\n" + "=" * 80)
    print("COMPREHENSIVE EVALUATION SUMMARY")
    print("=" * 80)

    print(f"\n[TRANSLATION QUALITY]")
    print(f"  Total tests: {total_tests}")
    print(f"  Successful: {successful_translations}")
    print(f"  Success rate: {success_rate_pct:.1f}%")

    print(f"\n[AMBIGUITY DETECTION]")
    print(f"  Total explanations: {total_explanations}")
    print(f"  High-span (S>={_SPAN_THRESHOLD}): {total_high_span}")
    print(f"  Real ambiguous: {total_real_ambiguous}")
    if total_tests > 0:
        print(f"  Avg explanations/test: {total_explanations / total_tests:.2f}")

    print(f"\n[EXPLANATION QUALITY]")
    print(f"  Avg confidence: {quality_metrics['avg_confidence']:.3f}")
    print(f"  Avg span: {quality_metrics['avg_span']:.3f}")
    print(f"  Avg uncertainty: {quality_metrics['avg_uncertainty']:.3f}")

    if 'confidence_p50' in quality_metrics:
        print(
            f"  Confidence P25/P50/P75: "
            f"{quality_metrics.get('confidence_p25', 0):.3f} / "
            f"{quality_metrics.get('confidence_p50', 0):.3f} / "
            f"{quality_metrics.get('confidence_p75', 0):.3f}"
        )

    print(f"  High (>=0.65): {quality_metrics['high_confidence_count']}")
    print(f"  Medium (0.4-0.65): {quality_metrics['medium_confidence_count']}")
    print(f"  Low (<0.4): {quality_metrics['low_confidence_count']}")

    print(f"\n[HOMOGRAPH DISCOVERY]")
    print(f"  DSCD discovered: {len(homograph_tracking['dscd_discovered_homographs'])}")
    print(f"  Explained: {len(homograph_tracking['explained_homographs'])}")
    print(f"  Explanation rate: {homograph_tracking['explained_from_dscd_rate']:.1%}")
    print(f"  Test discovery rate: {homograph_tracking['test_expected_discovery_rate']:.1%}")

    if homograph_tracking['explained_homographs']:
        print(f"\n  Explained homographs (top 10):")
        for homo in sorted(homograph_tracking['explained_homographs'])[:10]:
            exps = homograph_tracking['homograph_explanations'].get(homo, [])
            count = len(exps)
            avg_conf = sum(e['confidence'] for e in exps) / len(exps) if exps else 0.0
            in_dscd = "[D]" if homo in homograph_tracking['dscd_discovered_homographs'] else "   "
            in_ref = "[R]" if homo in _HOMOGRAPH_REFERENCE_LIST else "   "
            print(f"    {in_dscd} {in_ref} '{homo}': {count} x conf={avg_conf:.3f}")

    print(f"\n[REFERENCE COMPARISON]")
    print(f"  Reference: {len(_HOMOGRAPH_REFERENCE_LIST)} words")
    print(f"  Discovered: {len(reference_discovered)}/{len(_HOMOGRAPH_REFERENCE_LIST)}")
    print(f"  Coverage: {homograph_tracking['reference_discovery_rate']:.1%}")

    print(f"\n[DSCD PROTOTYPES]")
    print(f"  Word types: {dscd_stats['total_words']}")
    print(f"  Multi-sense: {dscd_stats['multi_sense_words']}")
    print(f"  Total prototypes: {dscd_stats['total_prototypes']}")
    if dscd_stats['total_words'] > 0:
        print(
            f"  Multi-sense ratio: "
            f"{dscd_stats['multi_sense_words'] / dscd_stats['total_words']:.1%}"
        )

    if asbn_stats:
        print(f"\n[ASBN]")
        print(f"  Domain accuracy: {asbn_stats.get('domain_accuracy', 0):.2%}")
        if 'source_accuracy' in asbn_stats:
            print(f"  Source accuracy: {asbn_stats['source_accuracy']:.2%}")
        if 'target_accuracy' in asbn_stats:
            print(f"  Target accuracy: {asbn_stats['target_accuracy']:.2%}")
        if 'num_updates' in asbn_stats:
            print(f"  Updates: {asbn_stats['num_updates']}")

    if trg_stats:
        print(f"\n[TRG]")
        print(f"  Total explanations: {trg_stats.get('explanations_generated', 0)}")
        print(f"  High confidence: {trg_stats.get('high_confidence_rate', 0):.1%}")

    print(f"\n[PERFORMANCE]")
    print(f"  Total time: {timing_metrics['total_time']:.2f}s")
    print(f"  Avg time/test: {timing_metrics['avg_test_time']:.2f}s")

    total_errors = sum([
        error_tracking['translation_failures'],
        error_tracking['dscd_failures'],
        error_tracking['trg_failures'],
        error_tracking['timeout_errors'],
        error_tracking['oom_errors'],
        error_tracking['other_errors'],
    ])

    if total_errors > 0:
        print(f"\n[ERRORS]")
        print(f"  Total: {total_errors}")
        print(f"  Translation: {error_tracking['translation_failures']}")
        print(f"  OOM: {error_tracking['oom_errors']}")
        print(f"  Other: {error_tracking['other_errors']}")

    if compare_baseline and baseline_metrics and isinstance(baseline_metrics, dict):
        print(f"\n[BASELINE COMPARISON]")
        try:
            baseline_success = baseline_metrics.get('success_rate_pct', 0)
            current_success = success_rate_pct
            success_delta = current_success - baseline_success

            baseline_expl = baseline_metrics.get('total_explanations', 0)
            expl_delta = total_explanations - baseline_expl

            baseline_quality = 0.0
            if 'quality_metrics' in baseline_metrics:
                baseline_quality_metrics = baseline_metrics['quality_metrics']
                if isinstance(baseline_quality_metrics, dict):
                    baseline_quality = baseline_quality_metrics.get('avg_confidence', 0.0)

            quality_delta = quality_metrics['avg_confidence'] - baseline_quality

            print(f"  Translation: {current_success:.1f}% ({success_delta:+.1f}%)")
            print(f"  Explanations: {total_explanations} ({expl_delta:+d})")
            print(
                f"  Confidence: {quality_metrics['avg_confidence']:.3f} "
                f"({quality_delta:+.3f})"
            )

            baseline_homo_rate = 0.0
            if 'homograph_tracking' in baseline_metrics:
                baseline_homo_tracking = baseline_metrics['homograph_tracking']
                if isinstance(baseline_homo_tracking, dict):
                    baseline_homo_rate = baseline_homo_tracking.get('explained_from_dscd_rate', 0.0)

            homo_delta = homograph_tracking['explained_from_dscd_rate'] - baseline_homo_rate
            print(
                f"  Explanation rate: "
                f"{homograph_tracking['explained_from_dscd_rate']:.1%} "
                f"({homo_delta:+.1%})"
            )
        except Exception as e:
            print(f"  Comparison failed: {e}")

    # Named `health_warnings` so the local list does not shadow the stdlib
    # `warnings` module imported at the top of the file.
    health_warnings = []
    if successful_translations < total_tests * 0.5:
        health_warnings.append("High translation failure (>50%)")
    if total_explanations == 0:
        health_warnings.append(f"No explanations generated - check thresholds (span={_SPAN_THRESHOLD}, U={_UNCERTAINTY_THRESHOLD})")
    if dscd_stats['total_words'] < 100:
        health_warnings.append("Very few prototypes (<100) - run warmup or increase training data")
    if quality_metrics['low_confidence_count'] > quality_metrics['high_confidence_count']:
        health_warnings.append("More low than high confidence")
    if homograph_tracking['explained_from_dscd_rate'] < 0.3:
        health_warnings.append("Low explanation rate (<30%)")
    if not discovery_validated:
        health_warnings.append("Discovery log missing")
    if asbn_stats and asbn_stats.get('domain_accuracy', 0) < 0.5:
        health_warnings.append("ASBN domain accuracy <50% - check Cell 4/6 fixes")

    if health_warnings:
        print(f"\n[WARNINGS]")
        for w in health_warnings:
            print(f"  - {w}")
    else:
        print(f"\n[HEALTH] All systems nominal")

    print("=" * 80)

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {
        "total_tests": total_tests,
        "successful_translations": successful_translations,
        "success_rate_pct": success_rate_pct,
        "total_explanations": total_explanations,
        "total_high_span": total_high_span,
        "total_real_ambiguous": total_real_ambiguous,
        "dscd_stats": dscd_stats,
        "quality_metrics": quality_metrics,
        "homograph_tracking": homograph_tracking,
        "error_tracking": error_tracking,
        "asbn_stats": asbn_stats,
        "trg_stats": trg_stats,
        "discovery_validated": discovery_validated,
        "timing_metrics": timing_metrics,
    }

def test_evaluation_pipeline(model, tokenizer) -> bool:
    """Smoke-test ``comprehensive_post_training_testing`` and report pass/fail.

    Runs the evaluation without warmup or baseline comparison and checks that
    the returned metrics contain the keys ``total_tests``, ``quality_metrics``
    and ``homograph_tracking``.

    Args:
        model: Model forwarded to the evaluation function.
        tokenizer: Tokenizer forwarded to the evaluation function.

    Returns:
        True when the evaluation ran and every required key is present;
        False otherwise. Exceptions are printed (with traceback), not raised.
    """
    print("\n" + "="*60)
    print("[TEST] Testing evaluation pipeline")
    print("="*60)

    required_keys = ('total_tests', 'quality_metrics', 'homograph_tracking')

    try:
        result = comprehensive_post_training_testing(
            model,
            tokenizer,
            run_warmup=False,
            compare_baseline=False
        )

        # Explicit check rather than `assert`: asserts are stripped under
        # `python -O` and give no message naming what was missing.
        missing = [key for key in required_keys if key not in result]
        if missing:
            raise ValueError(f"result is missing required keys: {missing}")

        print("Evaluation pipeline test passed")
        print("="*60 + "\n")
        return True

    except Exception as e:
        print(f"Evaluation pipeline test failed: {e}")
        try:
            traceback.print_exc()
        except Exception:
            pass
        print("="*60 + "\n")
        return False

# Cell 9 readiness banner: header first, then the active evaluation settings.
print(
    "\n" + "=" * 80,
    "Cell 9: Testing & evaluation ready (PURE DATA-DRIVEN) - FIXED",
    "=" * 80,
    "",
    "Configuration:",
    sep="\n",
)
print(
    f"  - Span threshold: {_SPAN_THRESHOLD}",
    f"  - Uncertainty threshold: {_UNCERTAINTY_THRESHOLD}",
    f"  - Reference list: {len(_HOMOGRAPH_REFERENCE_LIST)} words",
    "=" * 80 + "\n",
    sep="\n",
)


In [None]:
# ==============================================================================
# CELL 10: TATN MAIN PIPELINE (FINAL INTEGRATION, ALL FIXES)
# ==============================================================================

import os
import time
import traceback
from typing import Tuple, Optional, Dict, Any
import gc
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def _g(name, default):
    """Return the module-level global called ``name``, or ``default`` if unset."""
    module_vars = globals()
    if name in module_vars:
        return module_vars[name]
    return default

# Resolve the pipeline settings. Each value is read through `_g`, i.e. from a
# module-level global of the given name when one exists, otherwise from the
# literal default, and is then coerced to the expected type.
# NOTE(review): bool() of any non-empty string is True (including "False");
# confirm the boolean globals are real booleans.
try:
    _USE_MULTI_GPU = bool(_g("USE_MULTI_GPU", False))
    _NUM_GPUS = int(_g("NUM_GPUS", torch.cuda.device_count() if torch.cuda.is_available() else 0))
    _DEVICE = _g("DEVICE", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    _SOURCE_LANGUAGE = str(_g("SOURCE_LANGUAGE", "bn"))
    _TARGET_LANGUAGE = str(_g("TARGET_LANGUAGE", "en"))
    _NUM_SAMPLES = int(_g("NUM_SAMPLES", 30000))
    _MAX_LENGTH = int(_g("MAX_LENGTH", 48))
    _BATCH_SIZE = int(_g("BATCH_SIZE", 8))
    _EPOCHS = int(_g("EPOCHS", 1))
    _ACCUMULATION_STEPS = int(_g("ACCUMULATION_STEPS", 1))
    _LR_NMT = float(_g("LR_NMT", 2e-5))
    _LR_PHI = float(_g("LR_PHI", 1e-5))
    _ENABLE_ASBN_TRAINING = bool(_g("ENABLE_ASBN_TRAINING", False))
    _VALIDATION_CHECK_INTERVAL = int(_g("VALIDATION_CHECK_INTERVAL", 500))
    _PERIODIC_DISCOVERY_FREQUENCY = int(_g("PERIODIC_DISCOVERY_FREQUENCY", 50))
    _DSCD_WARMUP_SAMPLES = int(_g("DSCD_WARMUP_SAMPLES", 4000))
    _HOMOGRAPH_REFERENCE_LIST_BN = set(_g("HOMOGRAPH_REFERENCE_LIST_BN",
        ["কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা"]))
    # Publish the resolved set under the public name as well.
    HOMOGRAPH_REFERENCE_LIST_BN = _HOMOGRAPH_REFERENCE_LIST_BN
    _FREEZE_ENCODER = bool(_g("FREEZE_ENCODER", False))
    _DEBUG_TIMING = bool(_g("DEBUG_TIMING", False))
except (ValueError, TypeError):
    # A single value that cannot be coerced discards everything resolved above
    # and resets ALL settings to the hard-coded defaults below.
    # NOTE(review): these fallbacks differ from the primary defaults in two
    # places: `_USE_MULTI_GPU` is derived from the GPU count here (primary
    # default is False), and the reference list holds 4 words here versus 9
    # above. Confirm whether that is intentional.
    _NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
    _USE_MULTI_GPU = _NUM_GPUS > 1
    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _SOURCE_LANGUAGE = "bn"
    _TARGET_LANGUAGE = "en"
    _NUM_SAMPLES = 30000
    _MAX_LENGTH = 48
    _BATCH_SIZE = 8
    _EPOCHS = 1
    _ACCUMULATION_STEPS = 1
    _LR_NMT = 2e-5
    _LR_PHI = 1e-5
    _ENABLE_ASBN_TRAINING = False
    _VALIDATION_CHECK_INTERVAL = 500
    _PERIODIC_DISCOVERY_FREQUENCY = 50
    _DSCD_WARMUP_SAMPLES = 4000
    _HOMOGRAPH_REFERENCE_LIST_BN = {"কল", "কাল", "পাতা", "ব্যাংক"}
    HOMOGRAPH_REFERENCE_LIST_BN = _HOMOGRAPH_REFERENCE_LIST_BN
    _FREEZE_ENCODER = False
    _DEBUG_TIMING = False

# Checkpoint location; the directory is hard-coded to the Kaggle working dir.
_CHECKPOINT_DIR = "/kaggle/working"
_CHECKPOINT_PATH = os.path.join(_CHECKPOINT_DIR, "tatn_final.pt")

def _safe_clear_gpu_caches():
    """Best-effort release of cached GPU memory; never raises.

    When a module-level ``clear_all_gpu_caches`` is defined it is called and
    nothing else is done. Otherwise a garbage-collection pass is run (only if
    the collector is enabled) and the CUDA cache is then emptied on every
    visible device. All errors are swallowed on purpose: cache clearing is an
    optimisation and must not abort the caller.
    """
    try:
        if "clear_all_gpu_caches" in globals():
            globals()["clear_all_gpu_caches"]()
            return
        # Collect before emptying the cache: objects kept alive only by
        # reference cycles are freed by the collector, and only then can
        # memory they held be released by `empty_cache()` in this same call.
        if gc.isenabled():
            gc.collect()
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                try:
                    with torch.cuda.device(i):
                        torch.cuda.empty_cache()
                except Exception:
                    # One failing device must not stop the others.
                    pass
    except Exception:
        pass

def _safe_get(d: dict, *keys, default=None):
    """Walk ``keys`` through nested dicts, returning ``default`` on any miss.

    ``default`` is returned when ``d`` is not a dict, when an intermediate
    value is not a dict, or when a looked-up value is missing or ``None``.
    With no keys, ``d`` itself is returned.
    """
    if not isinstance(d, dict):
        return default

    node = d
    for step in keys:
        # A non-dict level is treated exactly like a missing key.
        node = node.get(step, None) if isinstance(node, dict) else None
        if node is None:
            return default
    return node

def _safe_tokenizer_from_pretrained(model_name: str, local_files_only: bool = False):
    """Load an ``M2M100Tokenizer`` and verify it exposes the expected methods.

    Args:
        model_name: Identifier passed to ``M2M100Tokenizer.from_pretrained``.
        local_files_only: Forwarded unchanged to ``from_pretrained``.

    Returns:
        The loaded tokenizer.

    Raises:
        RuntimeError: If the tokenizer lacks ``encode``, ``decode``,
            ``convert_ids_to_tokens`` or ``__call__``.
        Exception: Any import or load error is logged and re-raised unchanged.
    """
    try:
        from transformers import M2M100Tokenizer
        tok = M2M100Tokenizer.from_pretrained(model_name, local_files_only=local_files_only)
        # Report the first required attribute that is absent, if any.
        absent = next(
            (
                attr
                for attr in ('encode', 'decode', 'convert_ids_to_tokens', '__call__')
                if not hasattr(tok, attr)
            ),
            None,
        )
        if absent is not None:
            raise RuntimeError(f"Tokenizer missing: {absent}")
        return tok
    except Exception as e:
        print(f"[TOKENIZER] Load failed: {e}")
        raise

def initialize_environment():
    """Print a summary of the available compute devices.

    On a CUDA machine each GPU's name and total memory (GiB) is listed, then
    the GPU caches are cleared via ``_safe_clear_gpu_caches``. Without CUDA a
    single "CPU only" line is printed.

    Returns:
        Always ``True``.
    """
    print("[PIPELINE] Initializing environment...")

    if not torch.cuda.is_available():
        print("[PIPELINE] CPU only")
        return True

    device_total = torch.cuda.device_count()
    print(f"[PIPELINE] GPUs: {device_total}")
    for idx in range(device_total):
        try:
            label = torch.cuda.get_device_name(idx)
            gib = torch.cuda.get_device_properties(idx).total_memory / 1024**3
            print(f"  GPU {idx}: {label} ({gib:.1f} GB)")
        except Exception:
            # A device that cannot be queried is still listed.
            print(f"  GPU {idx}: Unknown")
    _safe_clear_gpu_caches()
    return True

def _resolve_dscd_lock(dscd):
    """Return the lock main_pipeline holds while reading ``dscd.prototype_stores``.

    ``buffer_lock`` wins whenever the attribute exists (even if its value is
    ``None``); ``clustering_lock`` is consulted only when ``buffer_lock`` is
    absent. Returns ``None`` when neither attribute exists.
    """
    if hasattr(dscd, 'buffer_lock'):
        return dscd.buffer_lock
    if hasattr(dscd, 'clustering_lock'):
        return dscd.clustering_lock
    return None


def main_pipeline() -> Tuple[object, object]:
    """Run the complete TATN workflow and return the trained model and tokenizer.

    Phases, in order: environment report, tokenizer loading, data loading,
    model construction, optimizer setup, optional baseline evaluation,
    training, discovery, warmup, post-training evaluation, checkpoint saving,
    component validation and a final printed summary.

    Optional stages (data loader factory, training, warmup, evaluation) are
    looked up in ``globals()`` and skipped with a log line when absent. Errors
    in phases 5-11 are logged and do not abort the run.

    Returns:
        ``(trained_model, tokenizer)``. ``trained_model`` is the object
        returned by ``train_memory_efficient_tatn`` or, if training was
        skipped or failed, the untrained model.

    Raises:
        RuntimeError: If ``MemoryEfficientDataset`` or
            ``MemoryOptimizedTATNWithExplanations`` is not defined, or if the
            model cannot be instantiated.
        Exception: Tokenizer loading and model construction errors propagate.
    """
    # Local import: this cell does not import traceback at module level, so
    # the debug tracebacks below would otherwise be silently lost.
    import traceback

    print("\n" + "=" * 80)
    print("TATN MAIN PIPELINE - COMPLETE INTEGRATION")
    print("=" * 80)

    span_thresh = _g('SPAN_THRESHOLD', None)
    uncertainty_thresh = _g('TAU_LOW', None)

    print(f"Configuration:")
    print(f"  - Span threshold: {span_thresh if span_thresh is not None else 'NOT SET'}")
    print(f"  - Uncertainty threshold: {uncertainty_thresh if uncertainty_thresh is not None else 'NOT SET'}")
    print(f"  - Discovery frequency: {_PERIODIC_DISCOVERY_FREQUENCY}")
    print(f"  - DSCD warmup samples: {_DSCD_WARMUP_SAMPLES}")
    print(f"  - Epochs: {_EPOCHS}")
    print(f"  - Batch size: {_BATCH_SIZE}")
    print(f"  - ASBN training: {'ENABLED' if _ENABLE_ASBN_TRAINING else 'DISABLED'}")

    # Advisory checks only: deviations are reported, never corrected.
    config_warnings = []
    if span_thresh is None or (isinstance(span_thresh, (int, float)) and abs(span_thresh - 0.12) > 0.001):
        config_warnings.append("SPAN_THRESHOLD not set to 0.12 - may affect explanation generation")
    if uncertainty_thresh is None or (isinstance(uncertainty_thresh, (int, float)) and abs(uncertainty_thresh - 0.15) > 0.001):
        config_warnings.append("TAU_LOW not set to 0.15 - may affect explanation generation")
    if _PERIODIC_DISCOVERY_FREQUENCY <= 0:
        config_warnings.append("Discovery frequency is 0 - periodic discovery disabled")
    elif _PERIODIC_DISCOVERY_FREQUENCY != 50:
        config_warnings.append(f"Discovery frequency is {_PERIODIC_DISCOVERY_FREQUENCY} - should be 50 for optimal performance")
    if _DSCD_WARMUP_SAMPLES < 1000:
        config_warnings.append("DSCD warmup samples < 1000 - may not discover enough prototypes")

    if config_warnings:
        print("\n[CONFIG WARNINGS]")
        for w in config_warnings:
            print(f"  - {w}")

    print("=" * 80)

    pipeline_start = time.time()
    if _DEBUG_TIMING:
        phase_start = time.time()

    initialize_environment()
    if _DEBUG_TIMING:
        print(f"[TIMING] Initialization: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    # ------------------------------------------------------------------
    # PHASE 1: tokenizer
    # ------------------------------------------------------------------
    print("\n[PHASE 1] Loading tokenizer...")
    tokenizer = _safe_tokenizer_from_pretrained("facebook/m2m100_418M")
    try:
        tokenizer.src_lang = _SOURCE_LANGUAGE
    except Exception:
        pass

    try:
        if not hasattr(tokenizer, 'pad_token_id') or tokenizer.pad_token_id is None:
            if hasattr(tokenizer, 'add_special_tokens'):
                tokenizer.add_special_tokens({"pad_token": "<pad>"})
    except Exception:
        pass

    vocab_size = getattr(tokenizer, 'vocab_size', None)
    if vocab_size is None:
        try:
            vocab_size = len(tokenizer)
        except Exception:
            vocab_size = 128000

    print(f"[PHASE 1] Tokenizer loaded (vocab: {vocab_size})")
    if _DEBUG_TIMING:
        print(f"[TIMING] Tokenizer: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    # ------------------------------------------------------------------
    # PHASE 2: data
    # ------------------------------------------------------------------
    print(f"\n[PHASE 2] Loading data ({_NUM_SAMPLES} samples)...")
    # Single-sentence stand-in used when the real loader is missing or fails.
    fallback_pairs = [("আমি কল বন্ধ করেছি।", "I turned off the tap.")]
    if "load_and_preprocess_optimized" in globals():
        try:
            pairs = load_and_preprocess_optimized(_NUM_SAMPLES)
        except Exception as e:
            print(f"[PHASE 2] Data loading failed: {e}")
            pairs = fallback_pairs
    else:
        print("[PHASE 2] Using fallback data")
        pairs = fallback_pairs

    if "MemoryEfficientDataset" not in globals():
        raise RuntimeError("MemoryEfficientDataset not found - run Cell 2")
    dataset = MemoryEfficientDataset(pairs, tokenizer, max_length=_MAX_LENGTH)
    collate_fn = globals().get("safe_collate", None)

    def _build_plain_loader():
        # Plain single-process DataLoader used when the optimized factory is
        # unavailable or raises.
        dataloader_kwargs = {
            'batch_size': _BATCH_SIZE,
            'shuffle': True,
            'num_workers': 0,
            'pin_memory': torch.cuda.is_available()
        }
        if collate_fn is not None:
            dataloader_kwargs['collate_fn'] = collate_fn
        return DataLoader(dataset, **dataloader_kwargs)

    if "create_optimized_dataloader" in globals():
        try:
            train_loader = create_optimized_dataloader(dataset, batch_size=_BATCH_SIZE, shuffle=True)
        except Exception:
            train_loader = _build_plain_loader()
    else:
        train_loader = _build_plain_loader()

    try:
        print(f"[PHASE 2] Dataset: {len(dataset)} samples, {len(train_loader)} batches")
    except Exception:
        print("[PHASE 2] Dataset loaded")

    del pairs
    _safe_clear_gpu_caches()

    if _DEBUG_TIMING:
        print(f"[TIMING] Data loading: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    # ------------------------------------------------------------------
    # PHASE 3: model
    # ------------------------------------------------------------------
    print("\n[PHASE 3] Initializing model...")
    if "MemoryOptimizedTATNWithExplanations" not in globals():
        raise RuntimeError("Model class not found - run Cell 6")

    try:
        model_core = MemoryOptimizedTATNWithExplanations(tokenizer)
    except RuntimeError as e:
        error_msg = str(e)
        if "embed_dim" in error_msg or "unexpected keyword argument" in error_msg:
            print(f"\n[ERROR] Model initialization failed: {error_msg}")
            print("\n[FIX REQUIRED IN CELL 6]")
            print("  In MemoryOptimizedTATNWithExplanations.__init__(), find:")
            print("    self.dscd = dscdcls(embed_dim=embeddim, ...)")
            print("  Replace with:")
            print("    self.dscd = dscdcls(embeddim=embeddim, ...)")
            print("  (Change 'embed_dim' to 'embeddim' - remove underscore)")
            raise RuntimeError(f"Failed to instantiate MemoryEfficientDSCDOnline: {error_msg}") from e
        else:
            raise
    except Exception as e:
        print(f"\n[ERROR] Model initialization failed: {type(e).__name__}: {e}")
        raise

    if hasattr(model_core, 'dscd') and model_core.dscd is not None:
        clustering_enabled = getattr(model_core.dscd, 'enable_training_clustering', False)
        print(f"[PHASE 3] DSCD component initialized successfully")
        print(f"  - Clustering: {'ENABLED' if clustering_enabled else 'DISABLED'}")
        if not clustering_enabled:
            print(f"  - WARNING: Clustering disabled - no prototypes will be discovered during training!")
    else:
        print("[PHASE 3] WARNING: DSCD component missing or None")

    if _USE_MULTI_GPU and _NUM_GPUS > 1:
        device_ids = list(range(_NUM_GPUS))
        print(f"[PHASE 3] Using DataParallel on {device_ids}")
        model = nn.DataParallel(model_core, device_ids=device_ids)
    else:
        model = model_core

    model = model.to(_DEVICE)
    # Unwrap DataParallel so component attributes can be reached directly.
    core_model = model.module if hasattr(model, "module") else model

    try:
        mbart = getattr(core_model, "mbart", None)
        if mbart and hasattr(mbart, "resize_token_embeddings"):
            try:
                current_size = mbart.get_input_embeddings().num_embeddings
                if isinstance(vocab_size, int):
                    target_size = vocab_size
                else:
                    target_size = current_size
                if current_size != target_size:
                    mbart.resize_token_embeddings(target_size)
                    print(f"[PHASE 3] Resized embeddings: {current_size} -> {target_size}")
            except Exception:
                pass
    except Exception:
        pass

    if _FREEZE_ENCODER:
        try:
            for p in core_model.mbart.model.encoder.parameters():
                p.requires_grad = False
            print("[PHASE 3] Encoder frozen")
        except Exception:
            pass

    print(f"[PHASE 3] Model initialized")
    if _DEBUG_TIMING:
        print(f"[TIMING] Model init: {time.time() - phase_start:.2f}s")

    # ------------------------------------------------------------------
    # PHASE 4: optimizers
    # ------------------------------------------------------------------
    print("\n[PHASE 4] Setting up optimizers...")

    try:
        critic_params = list(core_model.asbn.critic_parameters()) if hasattr(core_model, "asbn") and hasattr(core_model.asbn, "critic_parameters") else []
    except Exception:
        critic_params = []

    # Critic parameters get their own optimizer, so keep them out of the base one.
    critic_ids = {id(p) for p in critic_params}
    base_params = [p for p in core_model.parameters() if p.requires_grad and id(p) not in critic_ids]

    trainable_base = [p for p in base_params if p.requires_grad]
    if len(trainable_base) == 0:
        print("[PHASE 4] WARNING: No trainable base parameters - model may not train!")

    optimizer = torch.optim.AdamW(base_params, lr=_LR_NMT)
    print(f"[PHASE 4] Base optimizer created ({len(trainable_base)} trainable params)")

    phi_optimizer = None
    if critic_params and _ENABLE_ASBN_TRAINING:
        trainable_critic = [p for p in critic_params if p.requires_grad]
        if len(trainable_critic) > 0:
            phi_optimizer = torch.optim.AdamW(trainable_critic, lr=_LR_PHI)
            print(f"[PHASE 4] ASBN optimizer created ({len(trainable_critic)} params)")
        else:
            print(f"[PHASE 4] WARNING: ASBN enabled but no trainable critic parameters!")
    elif _ENABLE_ASBN_TRAINING and not critic_params:
        print(f"[PHASE 4] WARNING: ASBN enabled but no critic parameters found!")

    print(f"[PHASE 4] Optimizers ready")
    if _DEBUG_TIMING:
        phase_start = time.time()

    # ------------------------------------------------------------------
    # PHASE 5: baseline evaluation (only when clustering is enabled and no
    # prototypes exist yet)
    # ------------------------------------------------------------------
    print("\n[PHASE 5] Baseline evaluation...")
    baseline_metrics = None

    try:
        dscd = getattr(core_model, 'dscd', None)
        has_prototypes = False
        clustering_enabled = False
        should_run_baseline = True

        if dscd:
            prototype_stores = getattr(dscd, 'prototype_stores', None)
            clustering_enabled = getattr(dscd, 'enable_training_clustering', False)

            if prototype_stores is not None:
                lock = _resolve_dscd_lock(dscd)
                if lock:
                    with lock:
                        has_prototypes = len(prototype_stores) > 0
                else:
                    has_prototypes = len(prototype_stores) > 0

        if has_prototypes:
            print("[PHASE 5] Prototypes exist - skipping baseline")
            should_run_baseline = False
        elif not clustering_enabled:
            print("[PHASE 5] Clustering disabled - skipping baseline (no discoveries expected)")
            should_run_baseline = False

        if should_run_baseline and "comprehensive_post_training_testing" in globals():
            try:
                trg = getattr(core_model, 'trg_system', None)
                if trg and hasattr(trg, 'reset_statistics'):
                    trg.reset_statistics()
            except Exception:
                pass

            baseline_metrics = comprehensive_post_training_testing(model, tokenizer, run_warmup=False)
            baseline_success = baseline_metrics.get('success_rate_pct', 0)
            baseline_expl = baseline_metrics.get('total_explanations', 0)
            print(f"[PHASE 5] Baseline: {baseline_success:.1f}% success, {baseline_expl} explanations")
        elif not should_run_baseline:
            pass
        else:
            print("[PHASE 5] Skipping baseline (function not found)")
    except Exception as e:
        print(f"[PHASE 5] Baseline failed: {e}")

    if _DEBUG_TIMING:
        print(f"[TIMING] Baseline: {time.time() - phase_start:.2f}s")

    _safe_clear_gpu_caches()
    if _DEBUG_TIMING:
        phase_start = time.time()

    # ------------------------------------------------------------------
    # PHASE 6: training
    # ------------------------------------------------------------------
    print("\n[PHASE 6] Training...")
    trained_model = model
    # Never populated in this function; kept for the checkpoint and summary.
    training_stats = None

    if "train_memory_efficient_tatn" in globals():
        try:
            try:
                trg = getattr(core_model, 'trg_system', None)
                if trg and hasattr(trg, 'reset_statistics'):
                    trg.reset_statistics()
            except Exception:
                pass
            trained_model = train_memory_efficient_tatn(
                model,
                tokenizer,
                train_loader,
                optimizer,
                phi_optimizer=phi_optimizer,
                epochs=_EPOCHS,
                accumulation_steps=_ACCUMULATION_STEPS,
                validate_every=_VALIDATION_CHECK_INTERVAL,
                enable_validation=(_VALIDATION_CHECK_INTERVAL > 0)
            )
            print("[PHASE 6] Training complete")
        except Exception as e:
            print(f"[PHASE 6] Training failed: {e}")
            if _DEBUG_TIMING:
                traceback.print_exc()
            # Continue the remaining phases with the untrained model.
            trained_model = model
    else:
        print("[PHASE 6] Skipping training (function not found)")

    if _DEBUG_TIMING:
        print(f"[TIMING] Training: {time.time() - phase_start:.2f}s")
    _safe_clear_gpu_caches()
    if _DEBUG_TIMING:
        phase_start = time.time()

    # ------------------------------------------------------------------
    # PHASE 7: discovery
    # ------------------------------------------------------------------
    print("\n[PHASE 7] Discovery...")
    discovery_success = False
    try:
        core_for_discovery = trained_model.module if hasattr(trained_model, 'module') else trained_model
        dscd = getattr(core_for_discovery, 'dscd', None)
        if dscd is None:
            print("[PHASE 7] No DSCD module")
        else:
            initial_proto_count = 0
            lock = None
            prototype_stores = getattr(dscd, 'prototype_stores', None)
            if prototype_stores is not None:
                lock = _resolve_dscd_lock(dscd)
                if lock:
                    with lock:
                        initial_proto_count = len(prototype_stores)
                else:
                    initial_proto_count = len(prototype_stores)

            if hasattr(dscd, 'periodic_discovery_check') and _PERIODIC_DISCOVERY_FREQUENCY > 0:
                print("[PHASE 7] Using periodic_discovery_check()...")
                try:
                    total_steps = int(_EPOCHS * max(1, len(train_loader)))
                    dscd.periodic_discovery_check(total_steps, _PERIODIC_DISCOVERY_FREQUENCY)
                    discovery_success = True
                except Exception as e:
                    print(f"[PHASE 7] periodic_discovery_check failed: {e}")
                    if hasattr(dscd, 'discover_homographs'):
                        try:
                            print("[PHASE 7] Fallback: forcing discover_homographs()...")
                            dscd.discover_homographs()
                        except Exception as e2:
                            print(f"[PHASE 7] Fallback discovery failed: {e2}")
                    else:
                        print("[PHASE 7] No discover_homographs() method available")

            if prototype_stores is not None:
                # Snapshot under the lock, then inspect the copy lock-free.
                if lock:
                    with lock:
                        stores = dict(prototype_stores)
                else:
                    stores = dict(prototype_stores)

                def _store_size(s):
                    # ``size`` may be a method or a plain attribute.
                    try:
                        if callable(getattr(s, "size", None)):
                            return int(s.size())
                        return int(getattr(s, "size", 0))
                    except Exception:
                        return 0

                total_protos = sum(_store_size(store) for store in stores.values())
                multi_sense = sum(1 for store in stores.values() if _store_size(store) >= 2)

                print("[PHASE 7] Discovery complete:")
                print(f"  - Tokens: {len(stores)} (was {initial_proto_count})")
                print(f"  - Prototypes: {total_protos}")
                print(f"  - Multi-sense: {multi_sense}")

                if len(stores) == initial_proto_count and initial_proto_count > 0:
                    print("[PHASE 7] WARNING: No new prototypes created during discovery")
                    discovery_success = False
                elif len(stores) == 0:
                    print("[PHASE 7] CRITICAL: No prototypes created - check DSCD clustering enabled")
                    discovery_success = False
                elif total_protos > 0:
                    discovery_success = True
    except Exception as e:
        print(f"[PHASE 7] Discovery failed: {e}")

    if _DEBUG_TIMING:
        print(f"[TIMING] Discovery: {time.time() - phase_start:.2f}s")
    _safe_clear_gpu_caches()
    if _DEBUG_TIMING:
        phase_start = time.time()

    # ------------------------------------------------------------------
    # PHASE 8: warmup
    # ------------------------------------------------------------------
    print("\n[PHASE 8] Warmup...")
    if "dscd_discovery_warmup" in globals():
        try:
            warmup_samples = _DSCD_WARMUP_SAMPLES
            dscd_discovery_warmup(trained_model, tokenizer, num_sents=warmup_samples, batch_size=64, max_len=_MAX_LENGTH)
            print(f"[PHASE 8] Warmup complete ({warmup_samples} samples)")
        except Exception as e:
            print(f"[PHASE 8] Warmup failed: {e}")
    else:
        print("[PHASE 8] Skipping warmup (function not found)")

    if _DEBUG_TIMING:
        print(f"[TIMING] Warmup: {time.time() - phase_start:.2f}s")
    _safe_clear_gpu_caches()
    if _DEBUG_TIMING:
        phase_start = time.time()

    # ------------------------------------------------------------------
    # PHASE 9: post-training evaluation
    # ------------------------------------------------------------------
    print("\n[PHASE 9] Post-training evaluation...")
    eval_results: Dict[str, Any] = {}

    if "comprehensive_post_training_testing" in globals():
        try:
            try:
                core_for_eval = trained_model.module if hasattr(trained_model, 'module') else trained_model
                trg = getattr(core_for_eval, 'trg_system', None)
                if trg and hasattr(trg, 'reset_statistics'):
                    trg.reset_statistics()
            except Exception:
                pass
            eval_results = comprehensive_post_training_testing(
                trained_model,
                tokenizer,
                run_warmup=False,
                compare_baseline=(baseline_metrics is not None),
                baseline_metrics=baseline_metrics
            )
            final_success = eval_results.get('success_rate_pct', 0)
            final_expl = eval_results.get('total_explanations', 0)
            print(f"[PHASE 9] Evaluation: {final_success:.1f}% success, {final_expl} explanations")
        except Exception as e:
            print(f"[PHASE 9] Evaluation failed: {e}")
    else:
        print("[PHASE 9] Skipping evaluation (function not found)")

    if _DEBUG_TIMING:
        print(f"[TIMING] Evaluation: {time.time() - phase_start:.2f}s")
    _safe_clear_gpu_caches()
    if _DEBUG_TIMING:
        phase_start = time.time()

    # ------------------------------------------------------------------
    # PHASE 10: checkpoint
    # ------------------------------------------------------------------
    print("\n[PHASE 10] Saving checkpoint...")
    try:
        os.makedirs(_CHECKPOINT_DIR, exist_ok=True)
        core_for_save = trained_model.module if hasattr(trained_model, "module") else trained_model
        was_training = getattr(core_for_save, "training", False)
        core_for_save.eval()
        try:
            model_state = core_for_save.state_dict()
            dscd_state = {}
            if hasattr(core_for_save, 'dscd') and core_for_save.dscd is not None and hasattr(core_for_save.dscd, 'state_dict'):
                try:
                    dscd_state = core_for_save.dscd.state_dict()
                except Exception as e:
                    print(f"[PHASE 10] DSCD state_dict failed: {e}")
                    dscd_state = {}

            optimizer_state = None
            if optimizer is not None:
                try:
                    optimizer_state = optimizer.state_dict()
                    # Drop the per-parameter moment buffers before saving.
                    # NOTE(review): state_dict() appears to share these dicts
                    # with the live optimizer, so this likely strips them from
                    # `optimizer` as well - confirm before reusing it.
                    if 'state' in optimizer_state and optimizer_state['state'] is not None:
                        for param_state in optimizer_state['state'].values():
                            if isinstance(param_state, dict):
                                for buffer_key in ['momentum_buffer', 'exp_avg', 'exp_avg_sq']:
                                    try:
                                        if buffer_key in param_state:
                                            del param_state[buffer_key]
                                    except Exception:
                                        pass
                except Exception as e:
                    print(f"[PHASE 10] Optimizer state failed: {e}")
                    optimizer_state = None

            checkpoint = {
                'model_state_dict': model_state,
                'dscd_state': dscd_state,
                'optimizer_state_dict': optimizer_state,
                'training_stats': training_stats,
                'baseline_metrics': baseline_metrics,
                'eval_results': eval_results,
                'discovery_success': discovery_success,
                'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
                'config': {
                    'epochs': _EPOCHS,
                    'batch_size': _BATCH_SIZE,
                    'span_threshold': span_thresh,
                    'uncertainty_threshold': uncertainty_thresh,
                    'discovery_frequency': _PERIODIC_DISCOVERY_FREQUENCY,
                }
            }

            torch.save(checkpoint, _CHECKPOINT_PATH)

            # Read the file back to confirm it is loadable and non-empty.
            try:
                verify = torch.load(_CHECKPOINT_PATH, map_location='cpu')
                has_model = 'model_state_dict' in verify and verify.get('model_state_dict') is not None and len(verify['model_state_dict']) > 0
                has_dscd = 'dscd_state' in verify and verify.get('dscd_state') is not None and len(verify.get('dscd_state', {})) > 0
                print(f"[PHASE 10] Checkpoint saved: {_CHECKPOINT_PATH}")
                print(f"  - Model: {'OK' if has_model else 'MISSING'}")
                print(f"  - DSCD: {'OK' if has_dscd else 'MISSING'}")
                if has_dscd:
                    dscd_state_dict = verify.get('dscd_state', {})
                    num_tokens = 0
                    if dscd_state_dict is not None and 'prototype_stores' in dscd_state_dict and isinstance(dscd_state_dict['prototype_stores'], dict):
                        num_tokens = len(dscd_state_dict['prototype_stores'])
                    print(f"  - DSCD tokens: {num_tokens}")
            except Exception as e:
                print(f"[PHASE 10] Checkpoint verification failed: {e}")
                print(f"[PHASE 10] Checkpoint may be corrupted - recommend re-saving")
        finally:
            # Restore the mode the model was in before saving.
            if was_training:
                try:
                    core_for_save.train()
                except Exception:
                    pass
    except Exception as e:
        print(f"[PHASE 10] Checkpoint failed: {e}")
        if _DEBUG_TIMING:
            traceback.print_exc()

    if _DEBUG_TIMING:
        print(f"[TIMING] Checkpoint: {time.time() - phase_start:.2f}s")

    # ------------------------------------------------------------------
    # PHASE 11: component validation
    # ------------------------------------------------------------------
    print("\n[PHASE 11] Final validation...")
    try:
        core_final = trained_model.module if hasattr(trained_model, 'module') else trained_model
        dscd_ok = False
        if hasattr(core_final, 'dscd') and core_final.dscd is not None:
            prototype_stores = getattr(core_final.dscd, 'prototype_stores', None)
            if prototype_stores is not None:
                # The lock lives on the DSCD component, not on the model.
                lock = _resolve_dscd_lock(core_final.dscd)
                if lock:
                    with lock:
                        dscd_ok = len(prototype_stores) > 0
                else:
                    dscd_ok = len(prototype_stores) > 0

        asbn_ok = hasattr(core_final, 'asbn') and hasattr(core_final.asbn, 'forward')

        trg_ok = False
        if hasattr(core_final, 'trg_system') and core_final.trg_system is not None:
            if hasattr(core_final.trg_system, 'process_sentence_for_explanations'):
                trg = core_final.trg_system
                trg_ok = not getattr(trg, 'training', True)

        print(f"[PHASE 11] Component validation:")
        print(f"  - DSCD: {'OK' if dscd_ok else 'MISSING/EMPTY'}")
        print(f"  - ASBN: {'OK' if asbn_ok else 'MISSING'}")
        print(f"  - TRG: {'OK' if trg_ok else 'IN TRAINING MODE' if hasattr(core_final, 'trg_system') and core_final.trg_system is not None else 'MISSING'}")

        all_ok = dscd_ok and asbn_ok and trg_ok
        if all_ok:
            print("[PHASE 11] All components validated ✓")
        else:
            print("[PHASE 11] Some components missing or misconfigured")
    except Exception as e:
        print(f"[PHASE 11] Validation failed: {e}")

    pipeline_time = time.time() - pipeline_start

    # ------------------------------------------------------------------
    # Final summary
    # ------------------------------------------------------------------
    print("\n" + "=" * 80)
    print("PIPELINE COMPLETE - FINAL SUMMARY")
    print("=" * 80)
    print(f"\n[TIMING]")
    print(f"  Total time: {pipeline_time:.2f}s ({pipeline_time/60:.2f} min)")

    print(f"\n[TRAINING]")
    if training_stats:
        total_loss = training_stats.get('total_loss', [])
        optimizer_updates = training_stats.get('optimizer_updates', 0)
        print(f"  Completed: {optimizer_updates} optimizer updates")
        if total_loss:
            # Mean over the last (up to) 100 recorded losses.
            recent_loss = sum(total_loss[-100:]) / len(total_loss[-100:])
            print(f"  - Final loss: {recent_loss:.6f}")
    else:
        print("  No stats available")

    print(f"\n[DISCOVERY]")
    if discovery_success:
        print("  Success ✓")
    else:
        print("  Issues detected - check DSCD clustering enabled and discovery frequency")

    print(f"\n[EVALUATION]")
    if baseline_metrics is not None and eval_results:
        baseline_success = baseline_metrics.get('success_rate_pct', 0)
        final_success = eval_results.get('success_rate_pct', 0)
        improvement = final_success - baseline_success

        print(f"  Baseline -> Final: {baseline_success:.1f}% -> {final_success:.1f}%")
        print(f"  Improvement: {improvement:+.1f}%")

        baseline_dscd_stats = baseline_metrics.get('dscd_stats', {})
        final_dscd_stats = eval_results.get('dscd_stats', {})

        baseline_dscd = None
        if baseline_dscd_stats is not None and isinstance(baseline_dscd_stats, dict):
            baseline_dscd = baseline_dscd_stats.get('multi_sense_words', 0)

        final_dscd = None
        if final_dscd_stats is not None and isinstance(final_dscd_stats, dict):
            final_dscd = final_dscd_stats.get('multi_sense_words', 0)

        if baseline_dscd is not None and final_dscd is not None:
            print(f"  DSCD multi-sense: {baseline_dscd} -> {final_dscd}")

        baseline_asbn_stats = baseline_metrics.get('asbn_stats', {})
        final_asbn_stats = eval_results.get('asbn_stats', {})

        baseline_asbn = None
        if baseline_asbn_stats is not None and isinstance(baseline_asbn_stats, dict):
            baseline_asbn = baseline_asbn_stats.get('domain_accuracy', 0)

        final_asbn = None
        if final_asbn_stats is not None and isinstance(final_asbn_stats, dict):
            final_asbn = final_asbn_stats.get('domain_accuracy', 0)

        if baseline_asbn is not None and final_asbn is not None:
            print(f"  ASBN accuracy: {baseline_asbn:.2%} -> {final_asbn:.2%}")
    elif eval_results:
        print(f"  Success rate: {eval_results.get('success_rate_pct', 0):.1f}%")
    else:
        print("  No results")

    print(f"\n[CHECKPOINT]")
    if os.path.exists(_CHECKPOINT_PATH):
        try:
            size_mb = os.path.getsize(_CHECKPOINT_PATH) / 1024**2
            print(f"  Saved: {_CHECKPOINT_PATH}")
            print(f"  - Size: {size_mb:.2f} MB")
        except Exception:
            print(f"  Saved: {_CHECKPOINT_PATH}")
    else:
        print("  Not saved")

    print("\n" + "=" * 80)
    print("Usage: trained_model, tokenizer = main_pipeline()")
    print("=" * 80)

    _safe_clear_gpu_caches()

    return trained_model, tokenizer

# Banner printed when this cell is executed, after its definitions above.
print("\n" + "=" * 80)
print("Cell 10: Main pipeline ready (FINAL INTEGRATION) - FIXED")
print("=" * 80 + "\n")


In [None]:
# ==============================================================================
# CELL 11: MAIN EXECUTION WRAPPER (FINAL) - FIXED
# ==============================================================================
from datetime import datetime, timezone
import os
import traceback
import math
import sys
import time
import torch
import gc

# Resolve this cell's private run settings (_NUM_SAMPLES, _EPOCHS, ...) from
# the same-named public globals (NUM_SAMPLES, EPOCHS, ...), using the literal
# defaults below for any that are not defined. If a present value cannot be
# coerced (for example int(None) or an invalid device string), the except
# branch resets every setting to its hard-coded fallback.
try:
    _NUM_SAMPLES = int(globals().get('NUM_SAMPLES', 30000))
    _EPOCHS = int(globals().get('EPOCHS', 2))
    _BATCH_SIZE = int(globals().get('BATCH_SIZE', 4))
    _ACCUMULATION_STEPS = int(globals().get('ACCUMULATION_STEPS', 16))

    raw_device = globals().get('DEVICE', "cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(raw_device, torch.device):
        _DEVICE = raw_device
    else:
        # torch.device() rejects a malformed device string with RuntimeError,
        # which is why RuntimeError is part of the except clause below.
        _DEVICE = torch.device(str(raw_device))

    _ENABLE_ASBN_TRAINING = bool(globals().get('ENABLE_ASBN_TRAINING', True))
    _ENABLE_TRG_INFERENCE = bool(globals().get('ENABLE_TRG_INFERENCE', True))
    _PERIODIC_DISCOVERY_FREQUENCY = int(globals().get('PERIODIC_DISCOVERY_FREQUENCY', 50))
    _VERBOSE_LOGGING = bool(globals().get('VERBOSE_LOGGING', False))
    _DEBUG_DISCOVERY = bool(globals().get('DEBUG_DISCOVERY', False))
    _DEBUG_TIMING = bool(globals().get('DEBUG_TIMING', False))
    _NUM_GPUS = int(globals().get('NUM_GPUS', torch.cuda.device_count() if torch.cuda.is_available() else 0))
    _USE_MULTI_GPU = bool(globals().get('USE_MULTI_GPU', _NUM_GPUS > 1))
    _SPAN_THRESHOLD = float(globals().get('SPAN_THRESHOLD', 0.12))
    _TAU_LOW = float(globals().get('TAU_LOW', 0.15))

    raw_list = globals().get('HOMOGRAPH_REFERENCE_LIST_BN', ["কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা"])
    _HOMOGRAPH_REFERENCE_LIST_BN = set(str(w) for w in raw_list)
    # The presence of NUM_SAMPLES is used as the marker that Cell 0 has run.
    cell0_loaded = 'NUM_SAMPLES' in globals()

except (NameError, TypeError, ValueError, RuntimeError) as e:
    print(f"[EXEC] Config load error: {e}")
    _NUM_SAMPLES = 30000
    _EPOCHS = 2
    _BATCH_SIZE = 4
    _ACCUMULATION_STEPS = 16
    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _ENABLE_ASBN_TRAINING = True
    _ENABLE_TRG_INFERENCE = True
    _PERIODIC_DISCOVERY_FREQUENCY = 50
    _VERBOSE_LOGGING = False
    _DEBUG_DISCOVERY = False
    _DEBUG_TIMING = False
    _NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
    _USE_MULTI_GPU = (_NUM_GPUS > 1)
    _SPAN_THRESHOLD = 0.12
    _TAU_LOW = 0.15
    _HOMOGRAPH_REFERENCE_LIST_BN = {"কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা"}
    cell0_loaded = False
    print("[EXEC] Using fallback configuration (Cell 0 not executed)")

# Path of the final checkpoint file.
_CHECKPOINT_PATH = "/kaggle/working/tatn_final.pt"

def _safe_div_ceil(a: int, b: int) -> int:
    """Return the ceiling of ``a / b`` using exact integer arithmetic.

    Args:
        a: Dividend.
        b: Divisor; must be a positive integer.

    Returns:
        ``ceil(a / b)`` when both arguments are ints and ``b > 0``;
        otherwise ``0``.
    """
    if isinstance(a, int) and isinstance(b, int) and b > 0:
        # Negated floor division is an exact ceiling; it avoids the float
        # round-off (and OverflowError for huge ints) of math.ceil(a / b).
        return -(-a // b)
    return 0

def _format_duration(seconds: float) -> str:
    """Render a duration in seconds as a short human-readable string.

    Values under 60 are shown in seconds (one decimal), values under 3600 in
    minutes (one decimal), and everything else in hours (two decimals).
    """
    if seconds < 60:
        return f"{seconds:.1f}s"
    if seconds < 3600:
        minutes = seconds / 60
        return f"{minutes:.1f}min"
    hours = seconds / 3600
    return f"{hours:.2f}hr"

def _safe_get(d: dict, *keys, default=None):
    """Follow ``keys`` through nested dictionaries without raising.

    Args:
        d: Outermost mapping to start from.
        *keys: Keys to look up, one nesting level per key.
        default: Value returned when the walk cannot be completed.

    Returns:
        The value found after applying every key. ``default`` is returned
        when ``d`` is not a dict, when an intermediate value is not a dict,
        when a key is missing, or when a looked-up value is the ``default``
        object itself. With no keys, ``d`` is returned unchanged.
    """
    if not isinstance(d, dict):
        return default

    node = d
    for step in keys:
        # A non-dict node cannot be descended into; treat it like a miss.
        node = node.get(step, default) if isinstance(node, dict) else default
        if node is default:
            return default
    return node

def _clean_homograph_token(token) -> str:
    """Strip subword markers and spaces from ``token`` and lowercase it."""
    return (
        str(token)
        .replace('▁', '')
        .replace('Ġ', '')
        .replace('##', '')
        .replace(' ', '')
        .strip()
        .lower()
    )


def _store_has_multiple_prototypes(store) -> bool:
    """Return True when ``store.size`` (method or int attribute) is >= 2."""
    try:
        size_attr = getattr(store, 'size', None)
        if callable(size_attr):
            return bool(size_attr() >= 2)
        return isinstance(size_attr, int) and size_attr >= 2
    except Exception:
        # Best effort: a store that cannot report its size is not counted.
        return False


def _get_dscd_homographs(model):
    """Collect the homograph tokens known to the model's DSCD component.

    The model is unwrapped through ``model.module`` when that attribute
    exists, and its ``dscd`` attribute is inspected:

    1. If ``dscd.get_discovered_homographs()`` exists and succeeds, its
       result is returned as-is.
    2. Otherwise ``dscd.prototype_stores`` is snapshotted (under
       ``buffer_lock`` or ``clustering_lock`` when one is present) and every
       token whose store reports a size of at least 2 is cleaned and added.

    Tokens that clean to an empty string (for example a bare ``▁`` marker)
    are skipped so they do not show up as a bogus homograph or inflate the
    discovered count.

    Args:
        model: Model object, optionally wrapped so that the real model is
            reachable as ``model.module``.

    Returns:
        A set of cleaned token strings from the fallback path, or whatever
        ``get_discovered_homographs()`` returned. An empty set is returned
        when no DSCD component is available or any unexpected error occurs.
    """
    try:
        core = model.module if hasattr(model, 'module') else model
        dscd = getattr(core, 'dscd', None)
        if not dscd:
            return set()

        if hasattr(dscd, 'get_discovered_homographs'):
            try:
                return dscd.get_discovered_homographs()
            except Exception:
                # Fall through to the prototype-store scan below.
                pass

        if not hasattr(dscd, 'prototype_stores'):
            return set()

        lock = None
        if hasattr(dscd, 'buffer_lock'):
            lock = dscd.buffer_lock
        elif hasattr(dscd, 'clustering_lock'):
            lock = dscd.clustering_lock

        # Copy the mapping so iteration below does not hold the lock.
        if lock:
            with lock:
                stores = dict(dscd.prototype_stores)
        else:
            stores = dict(dscd.prototype_stores)

        homographs = set()
        for token, store in stores.items():
            if not _store_has_multiple_prototypes(store):
                continue
            clean = _clean_homograph_token(token)
            if clean:
                homographs.add(clean)
        return homographs
    except Exception:
        # Diagnostics helper: never let a reporting failure abort the caller.
        pass
    return set()

def _safe_cleanup():
    """Release cached CUDA memory on every device and run a GC pass.

    Best-effort helper: any exception raised while cleaning up is
    suppressed so that it can be called freely between pipeline stages.
    """

    def _empty_cache_on(device_index):
        # A failure on one device must not stop the remaining ones.
        try:
            with torch.cuda.device(device_index):
                torch.cuda.empty_cache()
        except Exception:
            pass

    try:
        if torch.cuda.is_available():
            for device_index in range(torch.cuda.device_count()):
                _empty_cache_on(device_index)
        # Collection is only triggered while automatic GC is enabled.
        if gc.isenabled():
            gc.collect()
    except Exception:
        pass

# ------------------------------------------------------------------------------
# Cell 11 driver. When run as a script/notebook cell it:
#   1. prints the resolved configuration and any deviations from expected values,
#   2. calls main_pipeline() (looked up in globals) and classifies failures,
#   3. on success: inspects the saved checkpoint, prints component statistics
#      and metrics, runs three sample sentences through
#      translate_with_explanations(), and checks that components are present,
#   4. on failure: prints which cells appear to be missing plus recovery hints,
#   5. prints an execution summary and frees cached memory.
# ------------------------------------------------------------------------------
if __name__ == "__main__":
    print("=" * 80)
    print("MEMORY-OPTIMIZED TATN - COMPLETE EXECUTION")
    print("=" * 80)

    user_login = os.getenv("KAGGLE_USERNAME") or os.getenv("USER") or "manas0003"
    start_time = time.time()
    # NOTE(review): `datetime` and `timezone` are not among the imports visible
    # at the top of the file - confirm they are imported earlier in this cell,
    # otherwise this line raises NameError.
    now_utc = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")

    print(f"User: {user_login}")
    print(f"Started: {now_utc}")

    print("\n[CONFIGURATION]")
    print(f"  Cell 0 status: {'Loaded' if cell0_loaded else 'Using fallbacks'}")
    print(f"  Samples: {_NUM_SAMPLES}")
    print(f"  Epochs: {_EPOCHS}")
    print(f"  Batch Size: {_BATCH_SIZE}")
    print(f"  Accumulation: {_ACCUMULATION_STEPS}")
    print(f"  Device: {_DEVICE}")
    print(f"  Multi-GPU: {'ENABLED' if _USE_MULTI_GPU else 'DISABLED'} ({_NUM_GPUS} GPUs)")
    print(f"  Span threshold: {_SPAN_THRESHOLD}")
    print(f"  Uncertainty threshold: {_TAU_LOW}")
    print(f"  Discovery frequency: {_PERIODIC_DISCOVERY_FREQUENCY}")

    # Warn (without aborting) when the thresholds differ from the values this
    # cell treats as expected: span 0.12, tau_low 0.15, discovery frequency 50.
    config_issues = []
    if abs(_SPAN_THRESHOLD - 0.12) > 0.001:
        config_issues.append(f"SPAN_THRESHOLD={_SPAN_THRESHOLD} should be 0.12")
    if abs(_TAU_LOW - 0.15) > 0.001:
        config_issues.append(f"TAU_LOW={_TAU_LOW} should be 0.15")
    if _PERIODIC_DISCOVERY_FREQUENCY <= 0:
        config_issues.append("Discovery frequency <= 0 (disabled)")
    elif _PERIODIC_DISCOVERY_FREQUENCY != 50:
        config_issues.append(f"Discovery frequency={_PERIODIC_DISCOVERY_FREQUENCY} should be 50")
    
    if config_issues:
        print("\n  [CONFIG WARNINGS]")
        for issue in config_issues:
            print(f"    - {issue}")

    if _USE_MULTI_GPU and _NUM_GPUS > 0:
        per_gpu = _safe_div_ceil(_BATCH_SIZE, _NUM_GPUS)
        print(f"  Batch per GPU: {per_gpu}")

    print(f"  ASBN: {'Enabled' if _ENABLE_ASBN_TRAINING else 'Disabled'}")
    print(f"  TRG: {'Enabled' if _ENABLE_TRG_INFERENCE else 'Disabled'}")
    print(f"  Debug: {'Enabled' if _DEBUG_DISCOVERY else 'Disabled'}")
    print("=" * 80)

    # Outcome of the pipeline run; read by the report sections further down.
    trained_model, tokenizer = None, None
    pipeline_success = False
    failure_category = None
    failure_details = ""

    # main_pipeline is not defined in this cell; it must already be a global.
    if 'main_pipeline' not in globals():
        print("\nERROR: main_pipeline not found")
        print("   -> Run Cell 10 before executing Cell 11")
        failure_category = "MISSING_DEPENDENCY"
        failure_details = "Cell 10 not executed"
    else:
        try:
            print("\nStarting pipeline...")

            if _DEBUG_TIMING:
                print("   Expected: ~15-45 min (config dependent)")

            pipeline_start = time.time()
            trained_model, tokenizer = main_pipeline()
            pipeline_duration = time.time() - pipeline_start

            print(f"\nPipeline completed: {_format_duration(pipeline_duration)}")
            pipeline_success = True

        except KeyboardInterrupt:
            print("\nInterrupted by user")
            failure_category = "USER_INTERRUPT"
            failure_details = "Manual stop"

        except RuntimeError as e:
            # Classify the failure by substrings of the lowercased message.
            msg = str(e).lower()

            # NOTE(review): Python reports an unexpected keyword argument as a
            # TypeError, which this `except RuntimeError` clause does not catch;
            # such errors would reach the generic handler below instead -
            # confirm whether this branch is reachable in practice.
            if "embed_dim" in msg or "unexpected keyword argument" in msg:
                print("\nDSCD Initialization Error")
                print(f"   {str(e)}")
                failure_category = "DSCD_INIT_ERROR"
                failure_details = str(e)[:200]

                print("\nFix in Cell 6 (MemoryOptimizedTATNWithExplanations.__init__):")
                print("   Find line ~70-80:")
                print("      self.dscd = dscdcls(embed_dim=embeddim, ...)")
                print("   Replace with:")
                print("      self.dscd = dscdcls(embeddim=embeddim, ...)")
                print("   (Change parameter name from 'embed_dim' to 'embeddim')")
                print("\n   Then re-run Cell 6, Cell 10, and Cell 11")

            elif "tokenizer" in msg or "sentencepiece" in msg:
                print("\nTokenizer error")
                failure_category = "TOKENIZER_ERROR"
                failure_details = str(e)[:200]

                print("\nFix:")
                print("   ! pip install transformers==4.30.2 sentencepiece tokenizers")
                print("   Then RESTART kernel and re-run Cells 0-11")

            elif "out of memory" in msg:
                print("\nOut of Memory")
                failure_category = "OOM_ERROR"
                failure_details = "GPU OOM"

                print("\nFixes (edit in Cell 0):")
                print("   BATCH_SIZE = 2")
                print("   NUM_SAMPLES = 15000")
                print("   ACCUMULATION_STEPS = 32")
                print("   Then re-run Cells 0-11")

            else:
                print(f"\nRuntime error: {type(e).__name__}")
                print(f"   {str(e)[:400]}")
                failure_category = "RUNTIME_ERROR"
                failure_details = str(e)[:200]

            # NOTE(review): `traceback` is not among the imports visible at the
            # top of the file; if it is not imported elsewhere, the NameError is
            # swallowed by the inner except and no traceback is printed.
            if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
                print("\n[TRACEBACK]")
                try:
                    traceback.print_exc()
                except Exception:
                    pass

        except Exception as e:
            print(f"\nUnexpected error: {type(e).__name__}")
            print(f"   {str(e)[:400]}")
            failure_category = "UNKNOWN_ERROR"
            failure_details = str(e)[:200]

            if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
                print("\n[TRACEBACK]")
                try:
                    traceback.print_exc()
                except Exception:
                    pass

    # ------------------------------ success report ------------------------------
    if pipeline_success and trained_model is not None and tokenizer is not None:
        print("\n" + "=" * 80)
        print("PIPELINE SUCCEEDED")
        print("=" * 80)

        # [CHECKPOINT]: the file counts as valid only when it holds a non-empty
        # 'dscd_state' whose 'prototype_stores' dict has at least one entry.
        print("\n[CHECKPOINT]")
        checkpoint_valid = False

        try:
            if os.path.exists(_CHECKPOINT_PATH):
                size_mb = os.path.getsize(_CHECKPOINT_PATH) / (1024**2)
                print(f"  File: {_CHECKPOINT_PATH}")
                print(f"  Size: {size_mb:.1f} MB")

                # NOTE(review): the same checkpoint is loaded again in the
                # [METRICS] and [INFERENCE VALIDATION] sections below; loading
                # it once and reusing `ckpt` would avoid two extra reads.
                ckpt = torch.load(_CHECKPOINT_PATH, map_location='cpu')

                has_model = 'model_state_dict' in ckpt and ckpt.get('model_state_dict') is not None and len(ckpt['model_state_dict']) > 0
                has_dscd = 'dscd_state' in ckpt and ckpt.get('dscd_state') is not None and len(ckpt.get('dscd_state', {})) > 0

                print(f"  Model: {'Present' if has_model else 'MISSING'}")
                print(f"  DSCD: {'Present' if has_dscd else 'MISSING'}")

                if has_dscd:
                    dscd_state_dict = ckpt.get('dscd_state', {})
                    num_tokens = 0
                    if dscd_state_dict is not None and 'prototype_stores' in dscd_state_dict:
                        proto_stores = dscd_state_dict['prototype_stores']
                        if isinstance(proto_stores, dict):
                            num_tokens = len(proto_stores)
                    
                    print(f"  Tokens: {num_tokens}")

                    if num_tokens > 0:
                        checkpoint_valid = True
                        print("  Status: VALID ✓")
                    else:
                        print("  Status: EMPTY DSCD (no prototypes)")
                        print("    -> Run warmup or check DSCD clustering enabled")
                else:
                    print("  Status: MISSING DSCD STATE")
                    print("    -> DSCD was not saved properly")
            else:
                print(f"  NOT FOUND: {_CHECKPOINT_PATH}")

        except Exception as e:
            print(f"  Validation failed: {e}")

        # [COMPONENTS]: print statistics from the live model's dscd, asbn and
        # trg_system attributes when they expose the expected accessor methods.
        print("\n[COMPONENTS]")

        try:
            # Unwrap a wrapper that exposes the real model as `.module`.
            core = trained_model.module if hasattr(trained_model, 'module') else trained_model

            dscd = getattr(core, 'dscd', None)
            if dscd:
                if hasattr(dscd, 'get_prototype_summary'):
                    try:
                        dscd_stats = dscd.get_prototype_summary()
                        print("  DSCD:")
                        print(f"    - Tokens: {dscd_stats.get('total_tokens', 0)}")
                        print(f"    - Prototypes: {dscd_stats.get('total_prototypes', 0)}")
                        print(f"    - Homographs: {dscd_stats.get('num_homographs', 0)}")
                    except Exception:
                        pass
                else:
                    # No summary method: fall back to counting prototype_stores,
                    # holding whichever lock attribute the object provides.
                    proto_stores = getattr(dscd, 'prototype_stores', None)
                    if proto_stores is not None:
                        lock = None
                        if hasattr(dscd, 'buffer_lock'):
                            lock = dscd.buffer_lock
                        elif hasattr(dscd, 'clustering_lock'):
                            lock = dscd.clustering_lock
                        
                        if lock:
                            with lock:
                                num_tokens = len(proto_stores)
                        else:
                            num_tokens = len(proto_stores)
                        
                        print("  DSCD:")
                        print(f"    - Tokens: {num_tokens}")

            asbn = getattr(core, 'asbn', None)
            if asbn and hasattr(asbn, 'get_detailed_stats'):
                try:
                    asbn_stats = asbn.get_detailed_stats()
                    print("  ASBN:")
                    print(f"    - Domain accuracy: {asbn_stats.get('domain_accuracy', 0):.2%}")
                    if 'source_accuracy' in asbn_stats:
                        print(f"    - Source: {asbn_stats['source_accuracy']:.2%}")
                        print(f"    - Target: {asbn_stats['target_accuracy']:.2%}")
                except Exception:
                    pass

            trg = getattr(core, 'trg_system', None)
            if trg and hasattr(trg, 'get_statistics'):
                try:
                    trg_stats = trg.get_statistics()
                    print("  TRG:")
                    print(f"    - Explanations: {trg_stats.get('explanations_generated', 0)}")
                    print(f"    - High confidence: {trg_stats.get('high_confidence_rate', 0):.1%}")
                    print(f"    - DSCD homograph rate: {trg_stats.get('dscd_homograph_rate', 0):.1%}")
                except Exception:
                    pass

        except Exception as e:
            print(f"  Stats failed: {e}")

        # [METRICS]: training and evaluation numbers read from the checkpoint.
        print("\n[METRICS]")

        try:
            if os.path.exists(_CHECKPOINT_PATH):
                ckpt = torch.load(_CHECKPOINT_PATH, map_location='cpu')

                training_stats = ckpt.get('training_stats', {})
                if training_stats is not None and isinstance(training_stats, dict):
                    total_loss = training_stats.get('total_loss', [])
                    updates = training_stats.get('optimizer_updates', 0)

                    print("  Training:")
                    print(f"    - Updates: {updates}")
                    if total_loss:
                        # "Final loss" is the mean of the last 100 recorded
                        # values, or of all values when fewer were recorded.
                        if len(total_loss) >= 100:
                            final = sum(total_loss[-100:]) / len(total_loss[-100:])
                        else:
                            final = sum(total_loss) / len(total_loss)
                        print(f"    - Final loss: {final:.6f}")

                eval_results = ckpt.get('eval_results', {})
                # NOTE(review): when 'baseline_metrics' is absent the `{}`
                # default is still a dict, so the "Baseline -> Final" branch
                # below runs with a baseline of 0 rather than falling through
                # to the plain "Success" line - confirm this is intended.
                baseline = ckpt.get('baseline_metrics', {})

                if eval_results is not None and isinstance(eval_results, dict):
                    final_success = eval_results.get('success_rate_pct', 0)
                    total_expl = eval_results.get('total_explanations', 0)

                    print("  Evaluation:")
                    if baseline is not None and isinstance(baseline, dict):
                        baseline_success = baseline.get('success_rate_pct', 0)
                        improvement = final_success - baseline_success
                        print(f"    - Baseline -> Final: {baseline_success:.1f}% -> {final_success:.1f}%")
                        print(f"    - Improvement: {improvement:+.1f}%")
                    else:
                        print(f"    - Success: {final_success:.1f}%")

                    print(f"    - Explanations: {total_expl}")

                    quality = eval_results.get('quality_metrics', {})
                    if quality is not None and isinstance(quality, dict):
                        print(f"    - Avg confidence: {quality.get('avg_confidence', 0):.3f}")

        except Exception as e:
            print(f"  Metrics failed: {e}")

        # [INFERENCE VALIDATION]: make sure trg_system is in eval mode, pick the
        # thresholds, then run the sample sentences through the model.
        print("\n[INFERENCE VALIDATION]")
        
        core_for_inf = trained_model.module if hasattr(trained_model, 'module') else trained_model
        trg = getattr(core_for_inf, 'trg_system', None)
        # trg_mode_ok stays False when there is no trg_system at all.
        trg_mode_ok = False
        if trg is not None:
            # A missing `training` attribute is treated as "training mode".
            trg_mode = getattr(trg, 'training', True)
            if trg_mode:
                print("  [WARNING] TRG in training mode - switching to eval")
                try:
                    trg.eval()
                    trg_mode_after = getattr(trg, 'training', True)
                    if not trg_mode_after:
                        trg_mode_ok = True
                        print("  [OK] TRG successfully switched to eval mode")
                    else:
                        print("  [WARNING] TRG.eval() called but still in training mode")
                except Exception as e:
                    print(f"  [ERROR] Failed to switch TRG to eval: {e}")
            else:
                trg_mode_ok = True
        
        # Start from the configured thresholds; values stored in the
        # checkpoint's 'config' dict take precedence when present.
        inference_span_threshold = _SPAN_THRESHOLD
        inference_uncertainty_threshold = _TAU_LOW
        
        try:
            if os.path.exists(_CHECKPOINT_PATH):
                ckpt = torch.load(_CHECKPOINT_PATH, map_location='cpu')
                config = ckpt.get('config', {})
                if config is not None and isinstance(config, dict):
                    ckpt_span = config.get('span_threshold', None)
                    ckpt_uncertainty = config.get('uncertainty_threshold', None)
                    if ckpt_span is not None:
                        inference_span_threshold = float(ckpt_span)
                    if ckpt_uncertainty is not None:
                        inference_uncertainty_threshold = float(ckpt_uncertainty)
        except Exception:
            pass
        
        print(f"  Using thresholds: span={inference_span_threshold:.2f}, uncertainty={inference_uncertainty_threshold:.2f}")
        print("\nTesting disambiguation on ambiguous sentences...")
        print("-" * 80)

        _safe_cleanup()

        inference_success = 0
        inference_failed = 0
        # Explained words that are also in the DSCD set / all explained words.
        dscd_homographs_detected = set()
        explained_words_all = set()

        dscd_homographs = _get_dscd_homographs(trained_model)
        print(f"DSCD discovered: {len(dscd_homographs)} homographs")
        if dscd_homographs and _DEBUG_DISCOVERY:
            print(f"  Sample: {list(dscd_homographs)[:10]}")

        # (input sentence, label printed next to it) pairs.
        test_sentences = [
            ("আমি কল বন্ধ করেছি।", "কল (tap/call)"),
            ("কাল আমি বই কিনব।", "কাল (tomorrow/yesterday)"),
            ("পাতা ঝরে পড়েছে।", "পাতা (leaf/page)"),
        ]

        inference_times = []

        try:
            if 'translate_with_explanations' not in globals():
                print("translate_with_explanations not available")
                print("   -> Run Cell 8 before Cell 11")
            else:
                for idx, (sentence, desc) in enumerate(test_sentences, 1):
                    try:
                        print(f"\n{idx}.  {desc}")
                        print(f"   Input: {sentence}")

                        inf_start = time.time()
                        res = translate_with_explanations(
                            trained_model, 
                            tokenizer, 
                            sentence,
                            span_threshold=inference_span_threshold,
                            uncertainty_threshold=inference_uncertainty_threshold
                        )
                        inf_time = time.time() - inf_start
                        inference_times.append(inf_time)

                        # Only a dict result counts as success; a dict without
                        # explanations is still counted as successful.
                        if isinstance(res, dict):
                            translation = res.get('translation', 'N/A')
                            amb_count = res.get('ambiguous_words_detected', 0)
                            exs = res.get('explanations', []) or []

                            print(f"   Translation: {translation}")
                            print(f"   Ambiguous: {amb_count}")
                            print(f"   Time: {inf_time:.3f}s")

                            if exs:
                                for exp in exs:
                                    word = exp.get('ambiguous_word', exp.get('token', 'N/A'))
                                    # Same cleaning steps as _get_dscd_homographs
                                    # so the membership test below can match.
                                    clean = (
                                        str(word)
                                        .replace('▁', '')
                                        .replace('Ġ', '')
                                        .replace('##', '')
                                        .replace(' ', '')
                                        .strip()
                                        .lower()
                                    )
                                    
                                    explained_words_all.add(clean)

                                    if clean in dscd_homographs:
                                        dscd_homographs_detected.add(clean)

                                    try:
                                        conf = float(exp.get('confidence', 0.5))
                                        span = float(exp.get('span', 0.0))
                                        u = float(exp.get('uncertainty', 0.0))
                                        
                                        # "[D]" marks words that are also in the
                                        # DSCD homograph set.
                                        in_dscd_marker = "[D]" if clean in dscd_homographs else "   "
                                        print(f"   {in_dscd_marker} '{word}': conf={conf:.3f}, s={span:.3f}, u={u:.3f}")
                                    except Exception:
                                        print(f"   -> '{word}': (no metrics)")

                                inference_success += 1
                            else:
                                print("   No explanations")
                                inference_success += 1
                        else:
                            print("   Unexpected format")
                            inference_failed += 1

                        _safe_cleanup()

                    except Exception as e:
                        print(f"   Failed: {type(e).__name__}")
                        if _DEBUG_DISCOVERY:
                            print(f"   Error: {str(e)[:100]}")
                        inference_failed += 1

                print("\n" + "-" * 80)
                print(f"Results: {inference_success}/{len(test_sentences)} successful")

                if inference_times:
                    avg_time = sum(inference_times) / len(inference_times)
                    min_time = min(inference_times)
                    max_time = max(inference_times)
                    print(f"Performance: {avg_time:.3f}s avg ({min_time:.3f}s min, {max_time:.3f}s max)")

                if dscd_homographs_detected:
                    print(f"DSCD homographs detected: {', '.join(sorted(dscd_homographs_detected))}")
                    coverage = len(dscd_homographs_detected) / len(dscd_homographs) if dscd_homographs else 0
                    print(f"  Coverage: {coverage:.1%} of DSCD discovered homographs")
                else:
                    # Explain the most likely reason nothing matched.
                    print("No DSCD homographs detected in explanations")
                    if len(dscd_homographs) == 0:
                        print("   -> DSCD has no discoveries (run warmup)")
                    elif len(explained_words_all) > 0:
                        print(f"   -> Explanations generated but not matching DSCD ({len(explained_words_all)} unique words)")
                        print(f"   -> Check thresholds (span={inference_span_threshold}, u={inference_uncertainty_threshold})")
                    else:
                        print(f"   -> No explanations generated at all")
                        print(f"   -> Check TRG mode ({'eval' if trg_mode_ok else 'training'})")

        except Exception as e:
            print(f"Validation failed: {e}")
            if _DEBUG_DISCOVERY:
                try:
                    traceback.print_exc()
                except Exception:
                    pass

        # [SYSTEM TEST]: attribute-presence checks only; no forward pass is run.
        print("\n[SYSTEM TEST]")

        try:
            core = trained_model.module if hasattr(trained_model, 'module') else trained_model

            # DSCD is OK only if it has a forward attribute and a non-empty
            # prototype_stores container.
            dscd_ok = False
            if hasattr(core, 'dscd') and core.dscd is not None:
                if hasattr(core.dscd, 'forward'):
                    proto_stores = getattr(core.dscd, 'prototype_stores', None)
                    if proto_stores is not None:
                        lock = None
                        if hasattr(core.dscd, 'buffer_lock'):
                            lock = core.dscd.buffer_lock
                        elif hasattr(core.dscd, 'clustering_lock'):
                            lock = core.dscd.clustering_lock
                        
                        if lock:
                            with lock:
                                dscd_ok = len(proto_stores) > 0
                        else:
                            dscd_ok = len(proto_stores) > 0
            
            asbn_ok = hasattr(core, 'asbn') and hasattr(core.asbn, 'forward')
            # TRG additionally requires the eval-mode check from above.
            trg_ok = hasattr(core, 'trg_system') and hasattr(core.trg_system, 'process_sentence_for_explanations') and trg_mode_ok
            # The attribute is named `mbart`, but the status line labels it M2M100.
            mbart_ok = hasattr(core, 'mbart') and hasattr(core.mbart, 'generate')

            print("  Component status:")
            print(f"    - DSCD: {'OK (with prototypes)' if dscd_ok else 'MISSING/EMPTY'}")
            print(f"    - ASBN: {'OK' if asbn_ok else 'MISSING'}")
            print(f"    - TRG: {'OK (eval mode)' if trg_ok else 'MISSING/TRAINING MODE'}")
            print(f"    - M2M100: {'OK' if mbart_ok else 'MISSING'}")

            all_ok = dscd_ok and asbn_ok and trg_ok and mbart_ok

            if all_ok:
                print("  All components operational ✓")
            else:
                print("  Some components missing or misconfigured")
                if not dscd_ok:
                    print("    -> DSCD: run warmup or check clustering enabled")
                if not trg_ok:
                    print("    -> TRG: switch to eval mode")

        except Exception as e:
            print(f"  Test failed: {e}")

        # Static usage hints for the user; nothing below is executed here.
        print("\n" + "=" * 80)
        print("NEXT STEPS")
        print("=" * 80)

        print("\n1. Single translation:")
        print("   result = translate_with_explanations(trained_model, tokenizer, 'আমি কল বন্ধ করেছি।')")

        print("\n2. Batch translation:")
        print("   for sent in sentences:")
        print("       res = translate_with_explanations(trained_model, tokenizer, sent)")

        print("\n3. Load checkpoint:")
        print("   ckpt = torch.load('/kaggle/working/tatn_final.pt')")
        print("   model.load_state_dict(ckpt['model_state_dict'])")
        print("   model.dscd.load_state_dict(ckpt['dscd_state'])")

        print("\n4. Full evaluation:")
        print("   results = comprehensive_post_training_testing(trained_model, tokenizer)")

        print("\n5. Demo:")
        print("   demonstrate_system(trained_model, tokenizer)")

        if not checkpoint_valid:
            print("\n[ACTION REQUIRED]")
            print("  Checkpoint needs verification - DSCD state may be empty")
            print("  -> Run dscd_discovery_warmup(trained_model, tokenizer) to populate")

        print("\n" + "=" * 80)

    # ------------------------------ failure report ------------------------------
    else:
        print("\n" + "=" * 80)
        print("PIPELINE FAILED")
        print("=" * 80)

        print(f"\nCategory: {failure_category or 'UNKNOWN'}")
        if failure_details:
            print(f"Details: {failure_details[:200]}")

        print("\n[DIAGNOSTICS]")

        # One marker name per cell: its presence in globals() is taken as
        # evidence that the cell has been executed.
        components = {
            'Cell 0': 'NUM_SAMPLES' in globals(),
            'Cell 1': 'reconstruct_word_spans' in globals(),
            'Cell 2': 'MemoryEfficientDataset' in globals(),
            'Cell 3': 'MemoryEfficientDSCDOnline' in globals(),
            'Cell 4': 'MemoryEfficientASBNModule' in globals(),
            'Cell 5': 'CompleteTRGWithExplanations' in globals(),
            'Cell 6': 'MemoryOptimizedTATNWithExplanations' in globals(),
            'Cell 7': 'train_memory_efficient_tatn' in globals(),
            'Cell 8': 'translate_with_explanations' in globals(),
            'Cell 9': 'comprehensive_post_training_testing' in globals(),
            'Cell 10': 'main_pipeline' in globals(),
        }

        all_present = True
        for comp, present in components.items():
            status = "OK" if present else "MISSING"
            print(f"  {status:7} {comp}")
            if not present:
                all_present = False

        # With every cell present, re-check the raw Cell 0 globals against the
        # same expected values used in the configuration warnings above.
        if all_present:
            span_thresh_global = globals().get('SPAN_THRESHOLD', None)
            tau_low_global = globals().get('TAU_LOW', None)
            discovery_freq_global = globals().get('PERIODIC_DISCOVERY_FREQUENCY', None)
            
            threshold_ok = True
            if span_thresh_global is not None:
                if abs(float(span_thresh_global) - 0.12) > 0.001:
                    print(f"\n  [WARNING] SPAN_THRESHOLD={span_thresh_global} should be 0.12 (set in Cell 0)")
                    threshold_ok = False
            
            if tau_low_global is not None:
                if abs(float(tau_low_global) - 0.15) > 0.001:
                    print(f"  [WARNING] TAU_LOW={tau_low_global} should be 0.15 (set in Cell 0)")
                    threshold_ok = False
            
            if discovery_freq_global is not None:
                if int(discovery_freq_global) != 50:
                    print(f"  [WARNING] PERIODIC_DISCOVERY_FREQUENCY={discovery_freq_global} should be 50 (set in Cell 0)")
                    threshold_ok = False
            
            if not threshold_ok:
                print("  -> Fix thresholds in Cell 0 and re-run Cells 0-11")

        # Recovery hints keyed on the failure category assigned above.
        print("\n[RECOVERY]")

        if failure_category == "MISSING_DEPENDENCY":
            print("\n-> Run Cells 0-10 in sequence, then re-run Cell 11")
            print("   Order: 0 -> 1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 -> 8 -> 9 -> 10 -> 11")

        elif failure_category == "DSCD_INIT_ERROR":
            print("\n-> Fix parameter name in Cell 6:")
            print("   In MemoryOptimizedTATNWithExplanations.__init__() around line 70-80:")
            print("   Search for:")
            print("      self.dscd = dscdcls(embed_dim=embeddim, ...)")
            print("   Replace with:")
            print("      self.dscd = dscdcls(embeddim=embeddim, ...)")
            print("   (Change 'embed_dim' to 'embeddim' - no underscore)")
            print("\n   Then re-run: Cell 6 -> Cell 10 -> Cell 11")

        elif failure_category == "TOKENIZER_ERROR":
            print("\n-> Install dependencies:")
            print("   ! pip install transformers==4.30.2 sentencepiece tokenizers")
            print("   Then RESTART kernel and re-run Cells 0-11")

        elif failure_category == "OOM_ERROR":
            print("\n-> Reduce memory usage in Cell 0:")
            print("   BATCH_SIZE = 2")
            print("   NUM_SAMPLES = 15000")
            print("   ACCUMULATION_STEPS = 32")
            print("   Then re-run Cells 0-11")

        elif failure_category == "RUNTIME_ERROR":
            print("\n-> Enable debug in Cell 0:")
            print("   VERBOSE_LOGGING = True")
            print("   DEBUG_DISCOVERY = True")
            print("   Then re-run Cell 11 for detailed traceback")

        elif failure_category == "USER_INTERRUPT":
            print("\n-> Check if checkpoint exists:")
            print(f"   import os")
            print(f"   os.path.exists('{_CHECKPOINT_PATH}')")
            print("   If checkpoint exists, can load it and skip training:")
            print("   -> Run Cell 8 (inference) directly")

        else:
            print("\n-> General recovery steps:")
            print("   1. Enable DEBUG in Cell 0:")
            print("      VERBOSE_LOGGING = True")
            print("      DEBUG_DISCOVERY = True")
            print("   2. Re-run Cells 0-11")
            print("   3. Check GPU availability:")
            print("      torch.cuda.is_available()")
            print("   4. Verify sufficient GPU memory")

        print("\n" + "=" * 80)

    # ------------------------------ execution summary ----------------------------
    total_duration = time.time() - start_time
    end_utc = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")

    print("\n" + "=" * 80)
    print("EXECUTION SUMMARY")
    print("=" * 80)
    print(f"User: {user_login}")
    print(f"Started: {now_utc}")
    print(f"Finished: {end_utc}")
    print(f"Duration: {_format_duration(total_duration)}")

    if pipeline_success:
        print("Status: SUCCESS ✓")
        # checkpoint_valid is only assigned inside the success-report branch,
        # hence the existence check before reading it.
        if 'checkpoint_valid' in locals() and checkpoint_valid:
            print("Checkpoint: VALID ✓")
        else:
            print("Checkpoint: NEEDS ATTENTION")
            print("  -> Run warmup to populate DSCD prototypes")
    else:
        print(f"Status: FAILED ({failure_category or 'UNKNOWN'})")
        print(f"  -> See [RECOVERY] section above for fix instructions")

    print("=" * 80)

    _safe_cleanup()

# Announce that the Cell 11 definitions have been loaded (runs on every import/execution).
print(
    "\n" + "=" * 80,
    "Cell 11: Execution wrapper ready (FINAL) - FIXED",
    "=" * 80 + "\n",
    sep="\n",
)
