In [1]:
!pip uninstall -y transformers tokenizers sentence-transformers
!pip install transformers==4.30.2 --no-deps
!pip install "tokenizers<0.14" sacremoses
!pip install sentence-transformers==2.2.2
!pip install sacrebleu

Found existing installation: transformers 4.57.1
Uninstalling transformers-4.57.1:
  Successfully uninstalled transformers-4.57.1
Found existing installation: tokenizers 0.22.1
Uninstalling tokenizers-0.22.1:
  Successfully uninstalled tokenizers-0.22.1
Found existing installation: sentence-transformers 5.1.1
Uninstalling sentence-transformers-5.1.1:
  Successfully uninstalled sentence-transformers-5.1.1
Collecting transformers==4.30.2
  Downloading transformers-4.30.2-py3-none-any.whl.metadata (113 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m113.6/113.6 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.2/7.2 MB[0m [31m79.5 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[?25hInstalling collected packages: transformers
Successfully installed transformers-4.30.2
Collecting tokenizers<0.14
  Downloading tokenizers-0.13.3.tar.gz (314 kB)
[2

In [2]:
# Confirm that the interpreter imports the pinned transformers build installed above.
import transformers
print(transformers.__version__)


4.30.2


In [3]:
# ==============================================================================
# CELL 0: TATN CONFIGURATION (BENGALI → ENGLISH) - FIXED & VERIFIED
# ==============================================================================

import os
import sys
import math
import random
import re
import unicodedata
import time
import threading
from pathlib import Path
from collections import deque, defaultdict
from typing import List, Dict, Tuple, Optional, Union, Set, Any
from types import SimpleNamespace

import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import warnings
import gc

# ------------------------------------------------------------------------------
# Optional dependencies
# ------------------------------------------------------------------------------

# pandas is optional at import time: record its availability in _HAS_PANDAS
# instead of failing the whole cell.
try:
    import pandas as pd
    _HAS_PANDAS = True
except ImportError:
    _HAS_PANDAS = False
    print("[WARN] pandas not available; CSV loading will fail")

# Bind the name M2M100Tokenizer to the fast tokenizer class if it can be
# imported, otherwise to the regular class, otherwise to None.
# NOTE(review): transformers is not known to ship an M2M100TokenizerFast, so the
# first import presumably always fails and the slow tokenizer is used — confirm.
try:
    from transformers import M2M100TokenizerFast as M2M100Tokenizer
    _HAS_M2M_TOKENIZER = True
except Exception:
    try:
        from transformers import M2M100Tokenizer
        _HAS_M2M_TOKENIZER = True
    except Exception:
        M2M100Tokenizer = None
        _HAS_M2M_TOKENIZER = False
        print("[WARN] M2M100Tokenizer not available")

# `datasets` is optional; load_dataset is None when the package is unavailable.
try:
    from datasets import load_dataset
    _HAS_DATASETS = True
except Exception:
    load_dataset = None
    _HAS_DATASETS = False

# NOTE(review): this silences every warning process-wide, including
# deprecation and numerical warnings raised by torch / transformers.
warnings.filterwarnings("ignore")

# Environment switches set before any tokenizer / TensorFlow code runs.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

# ------------------------------------------------------------------------------
# Device / GPU configuration
# ------------------------------------------------------------------------------

# Count the CUDA devices visible to PyTorch; more than one enables multi-GPU mode.
NUM_GPUS = torch.cuda.device_count()
USE_MULTI_GPU = NUM_GPUS > 1

if USE_MULTI_GPU:
    print(f"[Cell 0] Multi-GPU Mode: {NUM_GPUS} GPUs available")
    # cuda:0 is the device bound to DEVICE when several GPUs are visible.
    DEVICE = torch.device("cuda:0")
else:
    # Zero or one GPU: use CUDA when available, otherwise fall back to the CPU.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE: `mode` is a module-level name that is only assigned on this branch.
    mode = "Single GPU Mode" if torch.cuda.is_available() else "CPU Mode"
    print(f"[Cell 0] {mode}")

print(f"[Cell 0] Device: {DEVICE} (visible GPUs: {NUM_GPUS})")

# ------------------------------------------------------------------------------
# Dataset path (NOTE: CSV is en→bn but model trains bn→en; swap happens in Cell 2)
# ------------------------------------------------------------------------------

# Location of the dataset CSV; override with the DATASET_PATH environment variable.
DATASET_CSV_PATH = os.environ.get(
    "DATASET_PATH",
    "/kaggle/input/bn-homo/bn_homograph_complete_dataset.csv"
)

if not os.path.exists(DATASET_CSV_PATH):
    print(f"[WARN] Dataset CSV not found at: {DATASET_CSV_PATH}")
    print("[WARN] Training will use fallback dataset if file is not accessible")
else:
    print(f"[INFO] Dataset CSV found: {DATASET_CSV_PATH}")
    if _HAS_PANDAS:
        try:
            # Read the header plus a single row: enough to inspect the columns
            # without loading the whole file.
            cols = pd.read_csv(DATASET_CSV_PATH, nrows=1).columns.tolist()
            # Relaxed check: accept the file when 'src' OR 'tgt' exists; the
            # actual column handling happens in Cell 2. Previously the
            # "passed" message was printed without checking anything.
            if "src" in cols or "tgt" in cols:
                print(f"[INFO] CSV validation passed (columns: {cols})")
            else:
                print(
                    f"[WARN] CSV has neither a 'src' nor a 'tgt' column (columns: {cols})"
                )
        except Exception as e:
            print(f"[WARN] Could not validate CSV structure: {e}")

# ------------------------------------------------------------------------------
# Core training hyperparameters
# ------------------------------------------------------------------------------

# With the values below: 30000 samples / batch 100 = 300 batches per epoch, and
# 30000 // (100 * 16) = 18 full accumulation groups (this is the "GRL steps: 18"
# figure printed by the summary at the end of this cell).
BATCH_SIZE = 100
NUM_SAMPLES =30000
MAX_LENGTH = 48

# Separate learning rates, one per LR_* constant (NMT / TRG / PHI).
LR_NMT = 2e-5
LR_TRG = 1e-5
LR_PHI = 1e-5

EPOCHS = 1
GRAD_CLIP_NORM = 1.0
USE_AMP = True
# NOTE(review): if PRINT_INTERVAL counts batches, 300 equals a whole epoch at
# the current settings — confirm against the training loop.
PRINT_INTERVAL = 300
SEED = 42

# Number of batches accumulated per update; effective batch = BATCH_SIZE * ACCUMULATION_STEPS
# (times NUM_GPUS in multi-GPU mode, see the summary printout).
ACCUMULATION_STEPS = 16

# ------------------------------------------------------------------------------
# TRG / MC-dropout / data loader / memory
# ------------------------------------------------------------------------------

MC_DROPOUT_PASSES = 5
TRG_EVIDENCE_K = 3
MAX_SILVER_BUFFER = 50

NUM_WORKERS = 2
PIN_MEMORY = True
PREFETCH_FACTOR = 2
GRADIENT_CHECKPOINTING = True

# ------------------------------------------------------------------------------
# Debug flags (keep DSCD internal logs off by default to avoid spam)
# ------------------------------------------------------------------------------

DEBUG_DISCOVERY = False
DEBUG_TIMING = True
DEBUG_VERBOSE = False

# ------------------------------------------------------------------------------
# DSCD configuration (more permissive for multi-sense)
# ------------------------------------------------------------------------------

DSCD_BUFFER_SIZE = 50
DSCD_MAX_PROTOS = 8
DSCD_N_MIN = 2                 # allow minority sense clusters
DSCD_DISPERSION_THRESHOLD = 0.15  # encourage cluster splitting
DSCD_EMBED_DIM = 1024
DSCD_TEMPERATURE = 0.7
DSCD_DROPOUT = 0.1
DSCD_AUGMENT_SCALE = 0.1
DSCD_ENABLE_TRAINING_CLUSTERING = True
DSCD_WARMUP_SAMPLES = 8000

PERIODIC_DISCOVERY_FREQUENCY = 200
_MAX_TOKENS_PER_DISCOVERY = 150

# ------------------------------------------------------------------------------
# ASBN / TRG configuration
# ------------------------------------------------------------------------------

ENABLE_ASBN_TRAINING = True
ENABLE_ASBN_INFERENCE = False

ENABLE_TRG_TRAINING = True
ENABLE_TRG_INFERENCE = True

CLUSTERING_TIMEOUT = 5
MEMORY_CLEANUP_FREQUENCY = 100
VALIDATION_CHECK_INTERVAL = 200
VERBOSE_LOGGING = False

# ------------------------------------------------------------------------------
# Checkpoint configuration
# ------------------------------------------------------------------------------

CHECKPOINT_DIR = "/kaggle/working/"
CHECKPOINT_SAVE_AFTER_TRAINING = True
CHECKPOINT_FILENAME = "tatn_final.pt"
CHECKPOINT_INTERVAL = 99999999
SAVE_REPLAY_BUFFER = False
LOAD_REPLAY_BUFFER = False
REPLAY_BUFFER_SIZE = 25000
RESUME_FROM_CHECKPOINT = False
CHECKPOINT_PATH = ""

# ------------------------------------------------------------------------------
# TRG uncertainty / span thresholds
# ------------------------------------------------------------------------------

TAU_LOW = 0.15
TAU_HIGH = 0.85
TAU_ACCEPT = 0.8

TRG_MAX_GEN_LEN = 16
TRG_GEN_EMBED = 64
TRG_GEN_HID = 64

SPAN_THRESHOLD = 0.20
UNCERTAINTY_THRESHOLD = 0.25
TRG_TEMPERATURE = 1.0

# ------------------------------------------------------------------------------
# ASBN loss weights
# ------------------------------------------------------------------------------

ASBN_HIDDEN_DIM = 64
# NOTE(review): ASBN_LAMBDA (0.1) and LAMBDA_ASBN (0.05) are two distinct
# constants with near-identical names; only LAMBDA_ASBN is echoed in the
# summary printout — confirm which one downstream cells actually read.
ASBN_LAMBDA = 0.1
ASBN_DROPOUT = 0.1

# Weights printed under "ASBN / Loss" in the configuration summary.
LAMBDA_ASBN = 0.05
LAMBDA_DSCD = 0.15

# ------------------------------------------------------------------------------
# Domain labels / GRL schedule
# ------------------------------------------------------------------------------

# Integer domain labels for the train / test domains.
TRAIN_DOMAIN = 0
TEST_DOMAIN = 1
USE_DOMAIN_LABELS = True

# Gradient-reversal alpha is scheduled from GRL_ALPHA_START to GRL_ALPHA_END.
GRL_ALPHA_START = 0.0
GRL_ALPHA_END = 1.0
GRL_ALPHA_SCHEDULE = "linear"
# Number of schedule steps = full accumulation groups per epoch * epochs.
# Clamped to at least 1: the floor division is 0 whenever NUM_SAMPLES is smaller
# than one accumulation group (e.g. a quick debug run), and a zero-length
# schedule is never meaningful. The 10000 fallback covers a non-positive
# BATCH_SIZE * ACCUMULATION_STEPS, which would otherwise divide by zero.
GRL_ALPHA_STEPS = (
    max(1, NUM_SAMPLES // (BATCH_SIZE * ACCUMULATION_STEPS) * EPOCHS)
    if BATCH_SIZE * ACCUMULATION_STEPS > 0
    else 10000
)

# ------------------------------------------------------------------------------
# Language configuration
# ------------------------------------------------------------------------------

# Translation direction: Bengali source, English target.
SOURCE_LANGUAGE = "bn"
TARGET_LANGUAGE = "en"

# Vocabulary ids of the M2M100 language tokens "__bn__" and "__en__".
# M2M100 language tokens start at 128004 ("__af__") and follow the tokenizer's
# language-code order, which puts "bn" at 128012 and "en" at 128022. The
# previous Bengali value, 128025, is the id of "__fa__" (Persian).
# Verify against the loaded tokenizer with tokenizer.get_lang_id("bn") / ("en").
M2M100_BN_TOKEN_ID = 128012
M2M100_EN_TOKEN_ID = 128022

# ------------------------------------------------------------------------------
# Reference homograph list (evaluation only; DSCD unsupervised)
# ------------------------------------------------------------------------------

HOMOGRAPH_REFERENCE_LIST_BN: Set[str] = {
    "কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা",
    "পানি", "দল", "বাজার", "নাম", "কথা", "বই", "ঘর", "মন", "হাত",
    "দিন", "রাত", "জল", "বাড়ি", "পার্ক", "নদী", "বন", "ফুল", "গাছ",
    "চোখ", "মুখ", "পা", "কান", "গলা", "নাক", "দাঁত", "কোমর",
    "পড়া", "দেখা", "যাওয়া", "আসা", "খেলা", "লেখা", "বলা", "শোনা",
    "চলা", "ধরা", "দেওয়া", "নেওয়া",
    "সময়", "বছর", "মাস", "সাল", "ঘন্টা", "মুহূর্ত",
    "গরম", "শীত", "বাতাস", "আগুন", "পাথর", "মাটি",
    "ভাব", "রং", "আলো", "ছায়া", "শব্দ", "অর্থ",
}

HOMOGRAPH_WATCHLIST_BN: Set[str] = set()
HOMOGRAPH_WATCHLIST: Set[str] = set()
USE_WATCHLIST_PRIORITIZATION = False
WATCHLIST_ONLY_FOR_TRG = False

# ------------------------------------------------------------------------------
# Normalization utilities
# ------------------------------------------------------------------------------

def normalize_bengali(t: str) -> str:
    """Return ``t`` NFKC-normalized with subword markers removed and edges trimmed.

    Falsy input (``None`` or an empty string) yields ``""``. The markers
    stripped are the "▁" word-boundary prefix and the "##" continuation prefix.
    """
    if t:
        cleaned = unicodedata.normalize("NFKC", t)
        for marker in ("▁", "##"):
            cleaned = cleaned.replace(marker, "")
        return cleaned.strip()
    return ""

def normalize_english(t: str) -> str:
    """Return ``t`` NFKC-normalized, lower-cased and trimmed; falsy input gives ``""``."""
    return unicodedata.normalize("NFKC", t).lower().strip() if t else ""

# ------------------------------------------------------------------------------
# CUDA helpers
# ------------------------------------------------------------------------------

def empty_cuda_cache() -> None:
    """Run the garbage collector, then ask PyTorch to empty its CUDA cache.

    The CUDA call is skipped when no GPU is available, and any error it raises
    is ignored (best-effort cleanup).
    """
    gc.collect()
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.empty_cache()
    except Exception:
        # Cleanup must never interrupt the caller.
        pass

def safe_cuda_synchronize() -> None:
    """Call ``torch.cuda.synchronize()`` when CUDA is available, ignoring any error."""
    if not torch.cuda.is_available():
        return
    try:
        torch.cuda.synchronize()
    except Exception:
        # Best effort only: a failed synchronize is deliberately not propagated.
        pass

def monitor_gpu_usage() -> None:
    """Print allocated / reserved memory (in GB) for every visible CUDA device.

    Prints a single notice instead when CUDA is unavailable. A device whose
    statistics cannot be read is reported as unavailable rather than raising.
    """
    if not torch.cuda.is_available():
        print("[GPU MONITOR] No CUDA devices available")
        return

    bytes_per_gb = 1024 ** 3
    gpu_count = torch.cuda.device_count()
    print(f"\n[GPU MONITOR] Checking {gpu_count} GPU(s):")
    for gpu_index in range(gpu_count):
        try:
            allocated_gb = torch.cuda.memory_allocated(gpu_index) / bytes_per_gb
            reserved_gb = torch.cuda.memory_reserved(gpu_index) / bytes_per_gb
            print(
                f"  GPU {gpu_index}: {allocated_gb:.2f}GB allocated / {reserved_gb:.2f}GB reserved"
            )
        except Exception:
            print(f"  GPU {gpu_index}: memory stats unavailable")

# ------------------------------------------------------------------------------
# Checkpoint helpers
# ------------------------------------------------------------------------------

def get_checkpoint_path() -> str:
    """Return CHECKPOINT_FILENAME joined onto CHECKPOINT_DIR."""
    directory, filename = CHECKPOINT_DIR, CHECKPOINT_FILENAME
    return os.path.join(directory, filename)

def should_save_checkpoint(global_step: int, epoch: int, is_final: bool = False) -> bool:
    """Decide whether a checkpoint should be written at this point of training.

    Args:
        global_step: Current global step counter.
        epoch: Current epoch (accepted for interface compatibility; not used).
        is_final: True when called after training has finished.

    Returns:
        True when this is the final call and CHECKPOINT_SAVE_AFTER_TRAINING is
        set, or when periodic checkpointing is enabled and ``global_step`` is a
        positive multiple of CHECKPOINT_INTERVAL. False otherwise.
    """
    if is_final and CHECKPOINT_SAVE_AFTER_TRAINING:
        return True

    # Periodic saving is disabled by the 99999999 sentinel. A non-positive
    # interval is treated as disabled too: 0 previously raised
    # ZeroDivisionError in the modulo below.
    if CHECKPOINT_INTERVAL <= 0 or CHECKPOINT_INTERVAL >= 99999999:
        return False

    # The ">=" guard keeps step 0 (0 % n == 0) from triggering a save.
    return (
        global_step >= CHECKPOINT_INTERVAL
        and global_step % CHECKPOINT_INTERVAL == 0
    )

# ------------------------------------------------------------------------------
# Function timeout utility
# ------------------------------------------------------------------------------

class FunctionTimeoutError(Exception):
    """Signals that a function wrapped by ``with_timeout`` ran out of time."""


def with_timeout(seconds: int):
    """Decorator factory: run the wrapped function in a daemon thread with a deadline.

    Args:
        seconds: Maximum time to wait for the wrapped call to finish.

    Returns:
        A decorator. The decorated function returns the original return value
        when the call finishes in time, and ``None`` when it is still running
        after ``seconds`` (the worker thread is a daemon and is NOT stopped; it
        keeps running in the background). ``None`` is also returned when the
        call raises ``FunctionTimeoutError`` itself.

    Raises:
        Any other ``Exception`` raised by the wrapped function is re-raised in
        the calling thread.
    """
    import functools  # local import: keeps this block self-contained

    def decorator(func):
        @functools.wraps(func)  # preserve __name__ / __doc__ of the wrapped function
        def wrapper(*args, **kwargs):
            # Value and error live in separate slots so that a function which
            # *returns* an exception object is not mistaken for one that raised.
            outcome = {}

            def target():
                try:
                    outcome["value"] = func(*args, **kwargs)
                except Exception as exc:
                    outcome["error"] = exc

            thread = threading.Thread(target=target, daemon=True)
            thread.start()
            thread.join(timeout=seconds)

            if thread.is_alive():
                return None  # deadline exceeded
            error = outcome.get("error")
            if error is not None:
                if isinstance(error, FunctionTimeoutError):
                    return None
                raise error
            # No value is recorded if the worker died from a BaseException.
            return outcome.get("value")

        return wrapper

    return decorator

# ------------------------------------------------------------------------------
# Token utilities for DSCD / tokenizer
# ------------------------------------------------------------------------------

def get_special_tokens(tokenizer) -> Set[str]:
    """Collect the tokenizer's special tokens plus the two language codes.

    Reads ``tokenizer.all_special_tokens`` (empty when the attribute is
    missing). If building the set fails, a default set of four common markers
    is used instead. SOURCE_LANGUAGE and TARGET_LANGUAGE are always added.
    """
    try:
        declared = getattr(tokenizer, "all_special_tokens", [])
        collected = set(declared)
    except Exception:
        collected = {"<pad>", "</s>", "<s>", "<unk>"}
    collected.add(SOURCE_LANGUAGE)
    collected.add(TARGET_LANGUAGE)
    return collected

# Memo of script-based validation results, keyed by (token, language).
# Guarded by _cache_lock; stops growing once _cache_max_size entries exist.
_token_validation_cache: Dict[Tuple[str, str], bool] = {}
_cache_lock = threading.Lock()
_cache_max_size = 10000

def is_valid_token(
    token,
    special_tokens: Optional[Set[str]] = None,
    tokenizer=None,
    language: str = "bn",
) -> bool:
    """Return True when ``token`` looks like a usable Bengali word piece.

    A token is valid when it is not one of ``special_tokens``, has at least two
    characters after removing the "▁" / "##" subword markers, contains Bengali
    characters (U+0980–U+09FF), and its Bengali characters make up at least
    half of its alphanumeric characters.

    Args:
        token: Token to check; ``None`` is treated as the empty string and any
            other value is converted with ``str``.
        special_tokens: Optional set of tokens that are always rejected.
        tokenizer: Accepted for interface compatibility; not used.
        language: Only used as part of the cache key; the script test is
            always the Bengali one.

    Returns:
        The validity verdict.
    """
    token = "" if token is None else str(token)

    # The special-token verdict depends on the caller-supplied set, so it must
    # be decided BEFORE consulting the cache and must never be cached: the cache
    # key does not include ``special_tokens``, and caching it made a verdict
    # from one call leak into calls made with a different set.
    if special_tokens and token in special_tokens:
        return False

    cache_key = (token, language)
    with _cache_lock:
        cached = _token_validation_cache.get(cache_key)
    if cached is not None:
        return cached

    clean = token.replace("▁", "").replace("##", "").strip()
    result = False
    if len(clean) >= 2:
        bengali_count = sum(1 for c in clean if "\u0980" <= c <= "\u09FF")
        if bengali_count:
            alphanum_count = sum(1 for c in clean if c.isalnum())
            # No alphanumerics at all (e.g. only combining signs) is invalid.
            result = alphanum_count > 0 and bengali_count / alphanum_count >= 0.5

    with _cache_lock:
        if len(_token_validation_cache) < _cache_max_size:
            _token_validation_cache[cache_key] = result

    return result

def safe_tokenize_with_offsets(tokenizer, text: str, max_length: int = 512):
    """Tokenize ``text`` without special tokens and return ``(tokens, offsets)``.

    The tokenizer is asked for an offset mapping; when the encoding does not
    contain one, a list of ``(0, 0)`` placeholders (one per token) is returned
    in its place. Any exception during tokenization yields ``(None, None)``.
    """
    encode_options = dict(
        return_offsets_mapping=True,
        max_length=max_length,
        truncation=True,
        add_special_tokens=False,
    )
    try:
        encoding = tokenizer(text, **encode_options)
        pieces = tokenizer.convert_ids_to_tokens(encoding.get("input_ids", []))
        placeholder_offsets = [(0, 0)] * len(pieces)
        return pieces, encoding.get("offset_mapping", placeholder_offsets)
    except Exception:
        return None, None

# ------------------------------------------------------------------------------
# Discovery timing helper used by DSCD
# ------------------------------------------------------------------------------

class DiscoveryTimer:
    """Accumulates (step, duration) records and summarizes the durations."""

    def __init__(self):
        # Parallel lists: entry i of each list belongs to the same record.
        self.discovery_times: List[float] = []
        self.discovery_steps: List[int] = []

    def record(self, step: int, duration: float) -> None:
        """Append one record: the step it happened at and how long it took."""
        self.discovery_times.append(duration)
        self.discovery_steps.append(step)

    def get_stats(self) -> Dict[str, float]:
        """Return count, total, average and maximum of the recorded durations.

        With no records, every statistic is zero.
        """
        durations = self.discovery_times
        n_records = len(durations)
        if n_records == 0:
            return {"count": 0, "total": 0.0, "avg": 0.0, "max": 0.0}
        elapsed = sum(durations)
        return {
            "count": n_records,
            "total": elapsed,
            "avg": elapsed / n_records,
            "max": max(durations),
        }

# Shared module-level timer instance.
_discovery_timer = DiscoveryTimer()
discoverytimer = _discovery_timer # Alias for older DSCD code

# ------------------------------------------------------------------------------
# Seeding and CuDNN behaviour
# ------------------------------------------------------------------------------

# Seed every random number generator used in this notebook with the same SEED.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Only available on newer PyTorch builds, hence the hasattr guard; failures are ignored.
if hasattr(torch, "set_float32_matmul_precision"):
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

# NOTE(review): with cudnn.benchmark enabled and cudnn.deterministic disabled,
# GPU runs are presumably not bit-for-bit reproducible despite the fixed seed —
# confirm that this speed/reproducibility trade-off is intended.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# ------------------------------------------------------------------------------
# Summary printout
# ------------------------------------------------------------------------------

effective_batch = BATCH_SIZE * ACCUMULATION_STEPS
if USE_MULTI_GPU:
    effective_batch *= NUM_GPUS

print("\n" + "=" * 80)
print("TATN CONFIGURATION (Bengali to English)")
print("=" * 80)
print(f"User: {os.getenv('KAGGLE_USERNAME', os.getenv('USER', 'manas0003'))}")
print(f"Date: {time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime())} UTC")
print(f"Multi-GPU: {'ENABLED' if USE_MULTI_GPU else 'DISABLED'} ({NUM_GPUS} GPUs)")
print(f"Dataset: {DATASET_CSV_PATH}")
print(f"Samples: {NUM_SAMPLES:,} | Batch: {BATCH_SIZE} | Accum: {ACCUMULATION_STEPS}")
print(f"Effective batch: {effective_batch}")
print(f"Max length: {MAX_LENGTH} | Epochs: {EPOCHS} | AMP: {USE_AMP}")
print()
print("DSCD Config:")
print(f"  Buffer: {DSCD_BUFFER_SIZE} | n_min: {DSCD_N_MIN} | Max protos: {DSCD_MAX_PROTOS}")
print(f"  Dispersion threshold: {DSCD_DISPERSION_THRESHOLD}")
print(f"  Periodic discovery: Every {PERIODIC_DISCOVERY_FREQUENCY} steps")
print(f"  Max tokens per discovery: {_MAX_TOKENS_PER_DISCOVERY}")
print()
print("TRG & Uncertainty:")
print(f"  MC Dropout passes: {MC_DROPOUT_PASSES} | TAU_LOW: {TAU_LOW}")
print(f"  SPAN_THRESHOLD: {SPAN_THRESHOLD} | TAU_HIGH: {TAU_HIGH}")
print(f"  Temperature: {TRG_TEMPERATURE}")
print()
print("ASBN / Loss:")
print(f"  LAMBDA_ASBN: {LAMBDA_ASBN} | LAMBDA_DSCD: {LAMBDA_DSCD}")
print(f"  Domain labels: {USE_DOMAIN_LABELS} | GRL: {GRL_ALPHA_SCHEDULE}")
print(f"  GRL steps: {GRL_ALPHA_STEPS}")
print()
print("Debug Flags:")
print(f"  Discovery logging: {DEBUG_DISCOVERY}")
print(f"  Timing monitoring: {DEBUG_TIMING}")
print(f"  Verbose mode: {DEBUG_VERBOSE}")
print()
print("Validation:")
print(f"  Check interval: {VALIDATION_CHECK_INTERVAL} steps")
print()
print("Language Tokens:")
print(f"  Bengali (bn): {M2M100_BN_TOKEN_ID}")
print(f"  English (en): {M2M100_EN_TOKEN_ID}")
print()
print("Checkpoint:")
print(f"  Path: {get_checkpoint_path()}")
# Derive the label from the checkpoint settings instead of hard-coding it, so
# the summary stays truthful when CHECKPOINT_INTERVAL or
# CHECKPOINT_SAVE_AFTER_TRAINING is changed. 99999999 is the "no periodic
# saving" sentinel used by should_save_checkpoint.
if 0 < CHECKPOINT_INTERVAL < 99999999:
    _save_strategy = f"Every {CHECKPOINT_INTERVAL} steps"
    if CHECKPOINT_SAVE_AFTER_TRAINING:
        _save_strategy += " + final"
elif CHECKPOINT_SAVE_AFTER_TRAINING:
    _save_strategy = "Final only"
else:
    _save_strategy = "Disabled"
print(f"  Save strategy: {_save_strategy}")
print()
print("Discovery Mode:")
print("  PURE UNSUPERVISED (no watchlist bias)")
print(f"  Reference list: {len(HOMOGRAPH_REFERENCE_LIST_BN)} words (evaluation only)")
print("  Watchlist prioritization: DISABLED")
print("=" * 80)

if not _HAS_PANDAS:
    print("[ERROR] pandas not available - CSV loading will fail!")
if not _HAS_M2M_TOKENIZER:
    print("[ERROR] M2M100Tokenizer not available - tokenization will fail!")

try:
    test_file = os.path.join(CHECKPOINT_DIR, ".test_write")
    with open(test_file, "w") as f:
        f.write("test")
    os.remove(test_file)
    print(f"Checkpoint directory writable: {CHECKPOINT_DIR}")
except Exception as e:
    print(f"Checkpoint directory not writable: {e}")

monitor_gpu_usage()

print("\n" + "=" * 80)
print("Cell 0: Configuration loaded successfully")
print("=" * 80)


[Cell 0] Multi-GPU Mode: 2 GPUs available
[Cell 0] Device: cuda:0 (visible GPUs: 2)
[INFO] Dataset CSV found: /kaggle/input/bn-homo/bn_homograph_complete_dataset.csv
[INFO] CSV validation passed (columns: ['idx', 'src', 'tgt', 'word', 'sense'])

TATN CONFIGURATION (Bengali to English)
User: manas0003
Date: 2026-01-07 14:57:43 UTC
Multi-GPU: ENABLED (2 GPUs)
Dataset: /kaggle/input/bn-homo/bn_homograph_complete_dataset.csv
Samples: 30,000 | Batch: 100 | Accum: 16
Effective batch: 3200
Max length: 48 | Epochs: 1 | AMP: True

DSCD Config:
  Buffer: 50 | n_min: 2 | Max protos: 8
  Dispersion threshold: 0.15
  Periodic discovery: Every 200 steps
  Max tokens per discovery: 150

TRG & Uncertainty:
  MC Dropout passes: 5 | TAU_LOW: 0.15
  SPAN_THRESHOLD: 0.2 | TAU_HIGH: 0.85
  Temperature: 1.0

ASBN / Loss:
  LAMBDA_ASBN: 0.05 | LAMBDA_DSCD: 0.15
  Domain labels: True | GRL: linear
  GRL steps: 18

Debug Flags:
  Discovery logging: False
  Timing monitoring: True
  Verbose mode: False

Validat

In [4]:
# ===========================================================================================
# CELL 1: TOKENIZER UTILITIES (BENGALI-FOCUSED)
# ===========================================================================================

import threading
from typing import Tuple, List, Dict, Optional, Set
import numpy as np
import torch

# ------------------------------------------------------------------------------
# Safe defaults based on MAX_LENGTH and global flags from Cell 0
# ------------------------------------------------------------------------------

# Each global from Cell 0 is read inside try/except NameError so that this cell
# still runs, with a default, when Cell 0 has not been executed.

# Default token limit for offset tokenization: MAX_LENGTH when it is a positive
# number, otherwise 48.
try:
    if isinstance(MAX_LENGTH, (int, float)) and MAX_LENGTH > 0:
        SAFE_OFFSET_MAX_LEN = int(MAX_LENGTH)
    else:
        SAFE_OFFSET_MAX_LEN = 48
except (NameError, ValueError, TypeError):
    SAFE_OFFSET_MAX_LEN = 48

try:
    _SOURCE_LANG = SOURCE_LANGUAGE
except NameError:
    _SOURCE_LANG = "bn"

try:
    _DEBUG_VERBOSE = DEBUG_VERBOSE
except NameError:
    _DEBUG_VERBOSE = False

try:
    _DEBUG_DISCOVERY = DEBUG_DISCOVERY
except NameError:
    _DEBUG_DISCOVERY = False

# Cache of special-token sets, keyed by the string built in
# _special_token_cache_key and guarded by _SPECIAL_TOKENS_LOCK.
_SPECIAL_TOKENS_CACHE: Dict[str, Set[str]] = {}
_SPECIAL_TOKENS_LOCK = threading.Lock()
# Counter / cap used by reconstruct_word_spans to limit its English-text warnings.
_LANGUAGE_WARNING_COUNT = 0
_MAX_LANGUAGE_WARNINGS = 3

# ------------------------------------------------------------------------------
# Special-token utilities
# ------------------------------------------------------------------------------

def _special_token_cache_key(tokenizer) -> str:
    name = getattr(tokenizer, "name_or_path", None) or getattr(tokenizer, "name", None)
    if not name:
        name = "unknown_tokenizer"
    vocab = None
    if hasattr(tokenizer, "vocab_size"):
        try:
            vocab = int(getattr(tokenizer, "vocab_size"))
        except Exception:
            vocab = None
    elif hasattr(tokenizer, "get_vocab") and callable(getattr(tokenizer, "get_vocab")):
        try:
            vocab = len(tokenizer.get_vocab())
        except Exception:
            vocab = None
    return f"{name}__vocab={vocab}"

def get_tokenizer_special_tokens(tokenizer) -> Set[str]:
    """Return the set of special tokens for ``tokenizer`` (cached per tokenizer).

    Tokens are gathered from ``all_special_tokens``,
    ``additional_special_tokens``, the individual ``*_token`` attributes and
    the string values of ``special_tokens_map`` (or its ``_extended`` variant).
    A fixed list of common markers and the "__bn__" / "__en__" language tokens
    is added, and the result is then restricted to tokens present in the
    tokenizer's vocabulary ("</s>", "<pad>", "<s>", "<unk>" are always kept).

    The vocabulary filter is only applied when a non-empty vocabulary can be
    obtained. Previously a tokenizer without ``get_vocab`` (or with an empty
    vocabulary) had every collected token except those four defaults discarded.

    Args:
        tokenizer: Tokenizer-like object; every attribute access is best-effort.

    Returns:
        The cached set object itself; callers should not mutate it.
    """
    cache_key = _special_token_cache_key(tokenizer)
    # The lock is held for the whole computation so that concurrent callers
    # for the same tokenizer do not build the set twice.
    with _SPECIAL_TOKENS_LOCK:
        if cache_key in _SPECIAL_TOKENS_CACHE:
            return _SPECIAL_TOKENS_CACHE[cache_key]

        special_tokens: Set[str] = set()
        try:
            # List-valued attributes.
            for list_attr in ("all_special_tokens", "additional_special_tokens"):
                if hasattr(tokenizer, list_attr):
                    try:
                        listed = getattr(tokenizer, list_attr)
                        if isinstance(listed, (list, tuple, set)):
                            special_tokens.update(x for x in listed if x)
                    except Exception:
                        pass
            # Single-token attributes.
            for attr in ("pad_token", "unk_token", "bos_token", "eos_token",
                         "cls_token", "sep_token", "mask_token"):
                if hasattr(tokenizer, attr):
                    try:
                        tok = getattr(tokenizer, attr)
                        if tok:
                            special_tokens.add(tok)
                    except Exception:
                        pass
            # String values of the special-token map (list values are covered
            # by the list-valued attributes above).
            try:
                stm = (
                    getattr(tokenizer, "special_tokens_map", None)
                    or getattr(tokenizer, "special_tokens_map_extended", None)
                )
                if isinstance(stm, dict):
                    for v in stm.values():
                        if isinstance(v, str) and v:
                            special_tokens.add(v)
            except Exception:
                pass
        except Exception:
            special_tokens = set()

        special_tokens.update({
            "__bn__", "__en__",
            "</s>", "<pad>", "<s>", "<unk>",
            "[PAD]", "[EOS]", "[UNK]", "[CLS]", "[SEP]", "[MASK]",
        })

        # Drop candidates the tokenizer does not know — but only when there is
        # a vocabulary to check against.
        try:
            vocab = tokenizer.get_vocab() if hasattr(tokenizer, "get_vocab") else {}
            if vocab:
                always_keep = {"</s>", "<pad>", "<s>", "<unk>"}
                special_tokens = {
                    tok
                    for tok in special_tokens
                    if tok in vocab or tok in always_keep
                }
        except Exception:
            pass

        _SPECIAL_TOKENS_CACHE[cache_key] = special_tokens
        return special_tokens

# ------------------------------------------------------------------------------
# Offset mapping normalization for fast / batch encodings
# ------------------------------------------------------------------------------

def _normalize_offset_mapping_for_batchencoding(enc):
    try:
        if "offset_mapping" in enc and enc["offset_mapping"] is not None:
            off = enc["offset_mapping"]
            try:
                if hasattr(off, "tolist"):
                    arr = off.tolist()
                    if isinstance(arr, list) and len(arr) > 0 and isinstance(arr[0], list):
                        enc["offset_mapping"] = [
                            (x[0], x[1])
                            if (isinstance(x, (list, tuple)) and len(x) >= 2)
                            else (None, None)
                            for x in arr[0]
                        ]
                        return enc
                if isinstance(off, (list, tuple)):
                    if len(off) > 0 and isinstance(off[0], (list, tuple)):
                        enc["offset_mapping"] = [
                            (x[0], x[1])
                            if (isinstance(x, (list, tuple)) and len(x) >= 2)
                            else (None, None)
                            for x in off[0]
                        ]
                        return enc
            except Exception:
                pass
    except Exception:
        pass

    try:
        data = getattr(enc, "data", None)
        if (
            data
            and isinstance(data, dict)
            and "offset_mapping" in data
            and data["offset_mapping"] is not None
        ):
            om = data["offset_mapping"]
            if isinstance(om, (list, tuple)) and len(om) > 0 and isinstance(om[0], (list, tuple)):
                enc["offset_mapping"] = [
                    (x[0], x[1])
                    if (isinstance(x, (list, tuple)) and len(x) >= 2)
                    else (None, None)
                    for x in om[0]
                ]
                return enc
    except Exception:
        pass

    try:
        seq_len = 0
        if "input_ids" in enc:
            input_ids = enc["input_ids"]
            if hasattr(input_ids, "shape") and len(input_ids.shape) > 0:
                seq_len = int(input_ids.shape[-1])
            elif (
                isinstance(input_ids, (list, tuple))
                and len(input_ids) > 0
                and isinstance(input_ids[0], (list, tuple))
            ):
                seq_len = len(input_ids[0])
        enc["offset_mapping"] = [(None, None)] * seq_len
    except Exception:
        enc["offset_mapping"] = []

    return enc

# ------------------------------------------------------------------------------
# Safe tokenization with offsets (single sentence)
# ------------------------------------------------------------------------------

def _align_tokens_to_text(tokens, text):
    """Locate each token in ``text`` left to right; return ``(start, end)`` per token.

    Subword markers ("▁", "##", "Ġ") are removed before searching. A token that
    cannot be found at or after the current position gets ``(None, None)``.
    """
    lowered = text.lower()
    # Case-insensitive matching is only safe when lower-casing keeps every
    # character position unchanged (some characters change length).
    can_fold_case = len(lowered) == len(text)

    offsets = []
    cursor = 0
    for tok in tokens:
        piece = (tok or "").replace("▁", "").replace("##", "").replace("Ġ", "").strip()
        if not piece:
            offsets.append((None, None))
            continue
        idx = text.find(piece, cursor)
        if idx == -1 and can_fold_case:
            idx = lowered.find(piece.lower(), cursor)
        if idx == -1:
            offsets.append((None, None))
            continue
        end = idx + len(piece)
        offsets.append((int(idx), int(end)))
        cursor = end
    return offsets


def safe_offsets_tokenize(
    tokenizer,
    text: str,
    max_length: Optional[int] = None,
    include_special_tokens: bool = False,
) -> dict:
    """Tokenize one sentence and attach a flat ``offset_mapping`` to the encoding.

    Fast tokenizers (``tokenizer.is_fast``) are asked for offsets directly. For
    other tokenizers, or when that call fails, offsets are reconstructed by
    searching each token's surface text in the input, left to right.

    Args:
        tokenizer: Callable tokenizer providing ``convert_ids_to_tokens``.
        text: Input sentence; non-strings are converted, ``None`` becomes "".
        max_length: Token limit; defaults to SAFE_OFFSET_MAX_LEN.
        include_special_tokens: Forwarded as ``add_special_tokens``.

    Returns:
        The tokenizer's encoding with ``"offset_mapping"`` set to a list of
        ``(start, end)`` tuples (``(None, None)`` where unknown). If
        tokenization itself fails, a one-token encoding holding the pad id is
        returned instead.
    """
    if max_length is None:
        max_length = SAFE_OFFSET_MAX_LEN
    eff_max = int(max_length)

    try:
        if not isinstance(text, str):
            text = "" if text is None else str(text)
    except Exception:
        if _DEBUG_VERBOSE:
            print("[WARN] Failed to convert input to string, using empty string")
        text = ""

    # Cap the amount of text handed to the tokenizer.
    char_limit = min(eff_max * 30, 8000)
    sample_text = text[:char_limit]

    if getattr(tokenizer, "is_fast", False):
        try:
            enc = tokenizer(
                sample_text,
                return_offsets_mapping=True,
                return_tensors="pt",
                truncation=True,
                padding=False,
                max_length=eff_max,
                add_special_tokens=include_special_tokens,
            )
            return _normalize_offset_mapping_for_batchencoding(enc)
        except Exception:
            pass  # fall through to manual offset reconstruction

    try:
        enc = tokenizer(
            sample_text,
            return_tensors="pt",
            truncation=True,
            padding=False,
            max_length=eff_max,
            add_special_tokens=include_special_tokens,
        )
    except Exception as e:
        if _DEBUG_VERBOSE:
            print(f"[WARN] Tokenization failed: {e}, returning empty encoding")
        pad_id = getattr(tokenizer, "pad_token_id", None)
        if pad_id is None:
            # A tokenizer without a pad token must not break tensor creation.
            pad_id = 0
        enc = {
            "input_ids": torch.tensor([[pad_id]], dtype=torch.long),
            "attention_mask": torch.tensor([[1]], dtype=torch.long),
        }
        return _normalize_offset_mapping_for_batchencoding(enc)

    try:
        input_ids = None
        try:
            first_row = enc["input_ids"][0]
            input_ids = first_row.tolist() if hasattr(first_row, "tolist") else list(first_row)
        except Exception:
            if hasattr(enc, "data") and "input_ids" in enc.data:
                input_ids = enc.data["input_ids"][0]

        tokens: List[str] = []
        if input_ids is not None:
            try:
                tokens = tokenizer.convert_ids_to_tokens(input_ids)
            except Exception:
                tokens = []

        if tokens:
            # The reconstructed list is already flat, so it is stored directly
            # instead of being passed through the batch normalizer (which
            # previously reduced it to two (None, None) entries).
            enc["offset_mapping"] = _align_tokens_to_text(tokens, sample_text)
            return enc
    except Exception:
        pass

    # No tokens recovered: let the normalizer emit per-token placeholders.
    return _normalize_offset_mapping_for_batchencoding(enc)

# ------------------------------------------------------------------------------
# Word reconstruction from token offsets (used by DSCD to get surface words)
# ------------------------------------------------------------------------------

def reconstruct_word_spans(
    tokenizer,
    text: str,
    max_length: Optional[int] = None,
) -> Tuple[Dict[int, Optional[str]], List[str]]:
    """Map token indices of ``text`` to the surface words they belong to.

    The text is tokenized with ``safe_offsets_tokenize`` (special tokens
    excluded) and words are rebuilt with the first strategy that produces a
    non-empty token map:

    1. Character offsets: tokens whose spans touch or overlap are merged into
       one word and the word text is sliced out of ``text``. A token is mapped
       to the word text accumulated up to and including that token.
    2. Sub-word markers: a token starting with "▁" or "Ġ" opens a new word and
       the following tokens are appended to it (words are capped at 100
       characters).
    3. No tokens at all: the whitespace-split words of ``text`` are returned
       with an empty token map.

    Args:
        tokenizer: Tokenizer handed to ``safe_offsets_tokenize``; it must
            provide ``convert_ids_to_tokens``.
        text: Source sentence. It is truncated to
            ``min(max_length * 30, 8000)`` characters before tokenization.
        max_length: Token limit for tokenization; defaults to
            ``SAFE_OFFSET_MAX_LEN``.

    Returns:
        ``(token_word_map, words)`` where ``token_word_map`` maps a token index
        to its word (``None`` for special/empty/unaligned tokens) and ``words``
        lists the reconstructed words in order. Both are empty when ``text`` is
        not a non-blank string or when tokenization raises.

    Side effects:
        Increments the module-level ``_LANGUAGE_WARNING_COUNT`` (and may print
        a warning) when the text contains Latin letters but no Bengali ones.
    """
    global _LANGUAGE_WARNING_COUNT

    if max_length is None:
        max_length = SAFE_OFFSET_MAX_LEN
    eff_max = int(max_length)

    if not isinstance(text, str) or len(text.strip()) == 0:
        return {}, []

    has_bengali = any("\u0980" <= c <= "\u09FF" for c in text)
    has_english = any("a" <= c.lower() <= "z" for c in text)

    if _DEBUG_VERBOSE and _DEBUG_DISCOVERY:
        bengali_pct = (
            sum(1 for c in text if "\u0980" <= c <= "\u09FF")
            / max(1, len(text))
            * 100.0
        )
        print(f"[TOKENIZER] Text sample: {text[:50]}")
        print(
            f"[TOKENIZER] Bengali: {has_bengali} ({bengali_pct:.1f}%), "
            f"English: {has_english}"
        )

    # Rate-limited warning: the counter advances even when the message itself
    # is silenced by _DEBUG_DISCOVERY being off.
    if not has_bengali and has_english and _LANGUAGE_WARNING_COUNT < _MAX_LANGUAGE_WARNINGS:
        if _DEBUG_DISCOVERY:
            print("[TOKENIZER WARNING] Text appears to be ENGLISH, not BENGALI")
            print(f"  Sample: {text[:80]}")
        _LANGUAGE_WARNING_COUNT += 1
        if _LANGUAGE_WARNING_COUNT == _MAX_LANGUAGE_WARNINGS:
            print("[TOKENIZER] Suppressing further language warnings")

    # Bound the amount of text handed to the tokenizer.
    char_limit = min(eff_max * 30, 8000)
    text = text[:char_limit]
    text_len = len(text)

    special_tokens = get_tokenizer_special_tokens(tokenizer)

    # NOTE: the former `current_lang` lookup was removed. Its value was never
    # used, and it raised NameError when neither SOURCE_LANGUAGE nor
    # _SOURCE_LANG had been defined by an earlier cell.

    try:
        encoded = safe_offsets_tokenize(
            tokenizer, text, max_length=eff_max, include_special_tokens=False
        )
    except Exception:
        return {}, []

    offsets = encoded.get("offset_mapping", [])
    try:
        input_ids = encoded["input_ids"][0].tolist()
    except Exception:
        input_ids = []
    try:
        tokens = tokenizer.convert_ids_to_tokens(input_ids) if input_ids else []
    except Exception:
        tokens = []

    # Accept either a flat list of (start, end) tuples or a batched list whose
    # first element holds the per-token pairs.
    if isinstance(offsets, list) and len(offsets) > 0 and all(
        isinstance(x, tuple) for x in offsets
    ):
        offsets_list = offsets
    elif isinstance(offsets, list) and len(offsets) > 0 and isinstance(
        offsets[0], (list, tuple)
    ):
        offsets_list = [
            (x[0], x[1])
            if (isinstance(x, (list, tuple)) and len(x) >= 2)
            else (None, None)
            for x in offsets[0]
        ]
    else:
        offsets_list = [(None, None)] * len(tokens)

    token_word_map: Dict[int, Optional[str]] = {}
    words: List[str] = []

    # ---- Strategy 1: merge tokens by character offsets --------------------
    used_any_offset = any(
        isinstance(o, tuple) and o[0] is not None and o[1] is not None
        for o in offsets_list
    )
    if used_any_offset:
        word_start: Optional[int] = None
        word_end: Optional[int] = None

        for idx, (off, tok) in enumerate(zip(offsets_list, tokens)):
            try:
                off_start = int(off[0]) if off[0] is not None else None
                off_end = int(off[1]) if off[1] is not None else None
            except Exception:
                off_start, off_end = None, None

            if off_start is not None and off_end is not None:
                if off_start < 0 or off_end < 0:
                    if _DEBUG_VERBOSE:
                        print(
                            f"[WARN] Negative offset detected: "
                            f"({off_start}, {off_end}), skipping"
                        )
                    off_start, off_end = None, None
                else:
                    # Clamp into the (possibly truncated) text.
                    off_start = max(0, min(off_start, text_len))
                    off_end = max(off_start, min(off_end, text_len))

            if off_start is None or off_end is None:
                # An unaligned token closes the word currently being built.
                if word_start is not None and word_end is not None:
                    try:
                        wtext = text[word_start:word_end].strip()
                        if wtext:
                            words.append(wtext)
                    except Exception:
                        pass
                word_start = None
                word_end = None
                token_word_map[idx] = None
                continue

            if tok in special_tokens:
                # Special tokens are unmapped and do not close the open word.
                token_word_map[idx] = None
                continue

            if word_start is None:
                word_start = off_start
                word_end = off_end
            else:
                if off_start > word_end:
                    # A gap between spans marks a word boundary.
                    try:
                        wtext = text[word_start:word_end].strip()
                        if wtext:
                            words.append(wtext)
                    except Exception:
                        pass
                    word_start = off_start
                    word_end = off_end
                else:
                    word_end = max(word_end, off_end)

            try:
                current_word = text[word_start:word_end].strip()
                token_word_map[idx] = current_word if current_word else None
            except Exception:
                token_word_map[idx] = None

        if word_start is not None and word_end is not None:
            try:
                wtext = text[word_start:word_end].strip()
                if wtext:
                    words.append(wtext)
            except Exception:
                pass

        if token_word_map:
            words = [w for w in words if isinstance(w, str) and w.strip()]
            return token_word_map, words

    # ---- Strategy 2: rebuild words from sub-word markers -------------------
    token_word_map = {}
    assembled: List[str] = []
    current_parts: List[str] = []
    running_word = ""
    max_word_len = 100

    for i, tok in enumerate(tokens):
        if tok in special_tokens:
            token_word_map[i] = None
            continue
        clean = (tok or "").replace("▁", "").replace("Ġ", "").replace("##", "").strip()
        if not clean:
            token_word_map[i] = None
            continue

        if tok.startswith("▁") or tok.startswith("Ġ"):
            # Word-initial marker: emit the finished word and start a new one.
            if current_parts:
                word = "".join(current_parts)
                if len(word) <= max_word_len:
                    assembled.append(word)
            current_parts = [clean]
            running_word = clean
        else:
            current_parts.append(clean)
            running_word = "".join(current_parts)
            if len(running_word) > max_word_len:
                # Over-long run: cut it before this piece and restart.
                if current_parts[:-1]:
                    word = "".join(current_parts[:-1])
                    assembled.append(word)
                current_parts = [clean]
                running_word = clean

        token_word_map[i] = running_word if running_word else None

    if current_parts:
        word = "".join(current_parts)
        if len(word) <= max_word_len:
            assembled.append(word)

    if token_word_map:
        words = [w for w in assembled if w and w.strip()]
        return token_word_map, words

    # ---- Strategy 3: no tokens were produced -------------------------------
    # Strategy 2 assigns an entry for every token, so this point is only
    # reached when `tokens` is empty; nothing can be aligned, and the
    # whitespace-split words are returned with an empty map.
    word_list = [w for w in text.split() if w.strip()]
    return {}, word_list

# ------------------------------------------------------------------------------
# Quick self-test helper
# ------------------------------------------------------------------------------

def test_tokenizer_utilities_quick(tokenizer=None) -> bool:
    """Smoke-test ``safe_offsets_tokenize`` and ``reconstruct_word_spans``.

    Runs a Bengali and an English sample through the utilities and prints the
    intermediate results between two banner lines.

    Args:
        tokenizer: Tokenizer to exercise. When ``None`` the test is skipped.

    Returns:
        ``True`` when the test is skipped, or when the words reconstructed
        from the Bengali sample contain Bengali characters and no Latin
        letters; ``False`` otherwise, including when any step raises (the
        traceback is printed).
    """
    sample_bn = "কাল আমি বাজারে যাব।"
    sample_en = "Tomorrow I will go to the market."

    print("\n" + "=" * 60)
    print("TOKENIZER UTILITIES TEST")
    print("=" * 60)

    try:
        if tokenizer is None:
            print("No tokenizer provided: skipping test")
            return True

        print("\n[TEST 1] Bengali text processing:")
        print(f"  Input: {sample_bn}")
        enc_bn = safe_offsets_tokenize(
            tokenizer, sample_bn, max_length=32, include_special_tokens=False
        )
        # The encoding may be a mapping that is not a `dict` subclass, so probe
        # it by use instead of by type.
        try:
            enc_len = int(enc_bn["input_ids"].shape[-1])
        except Exception:
            enc_len = "N/A"
        print(f"  Encoded length: {enc_len}")
        # Explicit None check: `x or []` would evaluate the truthiness of the
        # offsets object, which is ambiguous for multi-element tensors.
        offsets_bn = enc_bn.get("offset_mapping")
        if offsets_bn is None:
            offsets_bn = []
        print(f"  Offsets (first 5): {offsets_bn[:5]}")

        token_map_bn, words_bn = reconstruct_word_spans(tokenizer, sample_bn, max_length=32)
        print(f"  Reconstructed words: {words_bn}")
        print(f"  Token map sample: {dict(list(token_map_bn.items())[:3])}")

        has_bengali_words = any(
            any("\u0980" <= c <= "\u09FF" for c in w) for w in words_bn
        )
        print(f"  Contains Bengali words: {has_bengali_words}")

        print("\n[TEST 2] English text processing (should show warning):")
        print(f"  Input: {sample_en}")
        token_map_en, words_en = reconstruct_word_spans(tokenizer, sample_en, max_length=32)
        print(f"  Reconstructed words: {words_en}")

        # Informational only: the verdict below depends on the Bengali sample.
        has_english_words = any(
            any("a" <= c.lower() <= "z" for c in w) for w in words_en
        )
        print(f"  Contains English words: {has_english_words}")

        if has_bengali_words and not any(
            "a" <= c.lower() <= "z" for c in "".join(words_bn)
        ):
            print("\nTest PASSED: Bengali processing works correctly")
            return True
        else:
            print("\nTest WARNING: Check language detection logic")
            return False

    except Exception as e:
        print(f"\nTest FAILED: {repr(e)}")
        import traceback
        traceback.print_exc()
        return False
    finally:
        print("=" * 60 + "\n")

# ------------------------------------------------------------------------------
# Backwards-compatible aliases for other cells (critical for DSCD & data loader)
# ------------------------------------------------------------------------------

safeoffsetstokenize = safe_offsets_tokenize
reconstructwordspans = reconstruct_word_spans
gettokenizerspecialtokens = get_tokenizer_special_tokens


def _make_is_valid_token_alias(impl):
    """Wrap ``impl`` so the trailing arguments get the defaults used by other cells."""
    def _is_valid_token_alias(token, special_tokens=None, tokenizer=None, language="bn"):
        return impl(token, special_tokens, tokenizer, language)
    return _is_valid_token_alias


# The implementation is captured in a closure BEFORE the public name is
# rebound. The previous one-liner assigned a lambda to `is_valid_token` whose
# body looked `is_valid_token` up again at call time, i.e. it called itself
# and ended in RecursionError on first use.
_is_valid_token_impl = globals().get("is_valid_token")
if callable(_is_valid_token_impl):
    isvalidtoken = is_valid_token = _make_is_valid_token_alias(_is_valid_token_impl)
else:
    print("[CELL1] WARNING: is_valid_token is not defined; 'isvalidtoken' alias not created")

print("Cell 1: Tokenizer utilities loaded")

Cell 1: Tokenizer utilities loaded


In [5]:
# ==============================================================================
# CELL 2: MEMORY-EFFICIENT DATA LOADING (BENGALI → ENGLISH TASK)
# ==============================================================================

from typing import Optional, List, Tuple, Dict, Any
from collections import defaultdict
import os
import time
import random
import traceback
import re

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, get_worker_info
from tqdm import tqdm

# Optional dependency: pandas (needed to read the CSV dataset).
try:
    import pandas as pd
except ImportError:
    pd, _HAS_PANDAS = None, False
    print("[CELL2] WARNING: pandas not available; CSV loading will fail!")
else:
    _HAS_PANDAS = True

# Optional dependency: `datasets`; any failure while importing disables it.
try:
    from datasets import load_dataset
except Exception:
    load_dataset, _HAS_DATASETS = None, False
else:
    _HAS_DATASETS = True

# Verbosity switches come from an earlier cell when present, else default off.
_VERBOSE_LOGGING = (
    bool(globals()["VERBOSE_LOGGING"]) if "VERBOSE_LOGGING" in globals() else False
)
_DEBUG_VERBOSE = (
    bool(globals()["DEBUG_VERBOSE"]) if "DEBUG_VERBOSE" in globals() else False
)

# Cell 2 debug output is on when either switch is on; each message key is
# printed at most DEBUG_LIMIT times (tracked in _cell2_dbg_counts).
DEBUG_CELL2 = _VERBOSE_LOGGING or _DEBUG_VERBOSE
DEBUG_LIMIT = 10
_cell2_dbg_counts: Dict[str, int] = defaultdict(int)

def cell2_dbg(key: str, msg: str, limit: int = DEBUG_LIMIT) -> None:
    """Print a rate-limited Cell 2 debug message.

    Does nothing unless ``DEBUG_CELL2`` is set. Otherwise the occurrence count
    for ``key`` is incremented and ``msg`` is printed only while that count
    has not exceeded ``limit``.
    """
    if DEBUG_CELL2:
        occurrences = _cell2_dbg_counts[key] + 1
        _cell2_dbg_counts[key] = occurrences
        if not occurrences > limit:
            print(f"[CELL2-DBG] {msg}")

def _cell2_cast_globals(*specs):
    """Read notebook-level settings for the Cell 2 configuration.

    Each spec is a ``(global_name, cast)`` pair, processed in order. Returns
    the list of cast values, or ``None`` as soon as a name is not defined in
    the notebook namespace. Errors raised by ``cast`` are not handled here.
    """
    namespace = globals()
    values = []
    for name, cast in specs:
        if name not in namespace:
            return None
        values.append(cast(namespace[name]))
    return values


# NUM_SAMPLES / MAX_LENGTH: any failure (undefined or not castable) falls back.
try:
    _cfg_values = _cell2_cast_globals(("NUM_SAMPLES", int))
except Exception:
    _cfg_values = None
if _cfg_values is None:
    _NUM_SAMPLES = 50000
    print("[CELL2] WARNING: NUM_SAMPLES not defined, using default 50000")
else:
    (_NUM_SAMPLES,) = _cfg_values

try:
    _cfg_values = _cell2_cast_globals(("MAX_LENGTH", int))
except Exception:
    _cfg_values = None
if _cfg_values is None:
    _MAX_LENGTH = 48
    print("[CELL2] WARNING: MAX_LENGTH not defined, using default 48")
else:
    (_MAX_LENGTH,) = _cfg_values

# The remaining settings fall back only when a name is undefined; a defined
# value that cannot be cast still raises. Grouped settings fall back together.
_cfg_values = _cell2_cast_globals(("SOURCE_LANGUAGE", str), ("TARGET_LANGUAGE", str))
if _cfg_values is None:
    _SOURCE_LANG, _TARGET_LANG = "bn", "en"
    print("[CELL2] WARNING: SOURCE_LANGUAGE/TARGET_LANGUAGE not defined, using defaults bn/en")
else:
    _SOURCE_LANG, _TARGET_LANG = _cfg_values

_cfg_values = _cell2_cast_globals(("M2M100_BN_TOKEN_ID", int), ("M2M100_EN_TOKEN_ID", int))
if _cfg_values is None:
    _M2M_BN_TOKEN_ID, _M2M_EN_TOKEN_ID = 128025, 128022
    print("[CELL2] WARNING: M2M100 token IDs not defined, using defaults")
else:
    _M2M_BN_TOKEN_ID, _M2M_EN_TOKEN_ID = _cfg_values

_cfg_values = _cell2_cast_globals(("NUM_GPUS", int), ("USE_MULTI_GPU", bool))
if _cfg_values is None:
    # Probe the hardware only when the configuration is missing.
    _NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
    _USE_MULTI_GPU = _NUM_GPUS > 1
    print(f"[CELL2] WARNING: GPU config not defined, detected {_NUM_GPUS} GPUs")
else:
    _NUM_GPUS, _USE_MULTI_GPU = _cfg_values

_cfg_values = _cell2_cast_globals(("NUM_WORKERS", int))
if _cfg_values is None:
    _NUM_WORKERS = 0
    print("[CELL2] WARNING: NUM_WORKERS not defined, using 0")
else:
    (_NUM_WORKERS,) = _cfg_values

_cfg_values = _cell2_cast_globals(("PIN_MEMORY", bool))
_PIN_MEMORY = False if _cfg_values is None else _cfg_values[0]

_cfg_values = _cell2_cast_globals(("PREFETCH_FACTOR", int))
_PREFETCH_FACTOR = 2 if _cfg_values is None else _cfg_values[0]

_cfg_values = _cell2_cast_globals(("DATASET_CSV_PATH", str))
if _cfg_values is None:
    _DATASET_CSV_PATH = "/kaggle/input/bengali-english-homograph/bengali_homograph_sentences.csv"
    print(f"[CELL2] WARNING: DATASET_CSV_PATH not defined, using default: {_DATASET_CSV_PATH}")
else:
    (_DATASET_CSV_PATH,) = _cfg_values

_cfg_values = _cell2_cast_globals(
    ("TRAIN_DOMAIN", int), ("TEST_DOMAIN", int), ("USE_DOMAIN_LABELS", bool)
)
if _cfg_values is None:
    _TRAIN_DOMAIN, _TEST_DOMAIN, _USE_DOMAIN_LABELS = 0, 1, False
    print("[CELL2] WARNING: Domain label config not found, disabling domain labels")
else:
    _TRAIN_DOMAIN, _TEST_DOMAIN, _USE_DOMAIN_LABELS = _cfg_values

del _cfg_values

# Feature detection for helpers that earlier cells may have defined.
_has_normalize = all(
    helper in globals() for helper in ("normalize_bengali", "normalize_english")
)
_has_reconstruct_word_spans = "reconstruct_word_spans" in globals()
_has_safe_offsets_tokenize = "safe_offsets_tokenize" in globals()

if not _has_normalize:
    print("[CELL2] WARNING: normalize_bengali/normalize_english not found; using simple .strip()")

# Matches any character in the Bengali Unicode block.
_BENGALI_CHAR_RE = re.compile(r"[\u0980-\u09FF]")

def is_bengali_text(s: str) -> bool:
    """Return True when ``s`` is a non-empty string with at least one Bengali character.

    ``None``, non-string values and the empty string all yield ``False``.
    """
    if isinstance(s, str) and s:
        return _BENGALI_CHAR_RE.search(s) is not None
    return False

def _dataloader_worker_init_fn(worker_id: int) -> None:
    worker_info = get_worker_info()
    dataset = worker_info.dataset if worker_info is not None else None
    try:
        if dataset is not None and hasattr(dataset, "_tokenizer_name_or_path") and dataset._tokenizer_name_or_path:
            try:
                from transformers import M2M100Tokenizer
                dataset.tokenizer = M2M100Tokenizer.from_pretrained(dataset._tokenizer_name_or_path)
                dataset.is_fast = getattr(dataset.tokenizer, "is_fast", False)
                if DEBUG_CELL2:
                    print(f"[CELL2-WORKER-{worker_id}] Tokenizer reloaded successfully")
            except Exception as e:
                cell2_dbg("worker_tokenizer_reload", f"Worker {worker_id} tokenizer reload failed: {e}")
                dataset.tokenizer = None
                dataset.is_fast = False
    except Exception:
        if DEBUG_CELL2:
            print(f"[CELL2-WORKER-INIT] Tokenizer rebind failed in worker {worker_id}")

    try:
        base = int(os.environ.get("PYTHONHASHSEED", "0"))
        seed = (base ^ (worker_id + 1) ^ int(time.time())) & 0xFFFFFFFF
        random.seed(seed)
        np.random.seed(seed % (2**31 - 1))
        torch.manual_seed(seed % (2**31 - 1))
    except Exception:
        pass

def _first_complete_row_texts(df) -> Tuple[str, str]:
    """Return (src, tgt) text of the first row whose two cells are both non-null (row 0 if none)."""
    complete = (df["src"].notna() & df["tgt"].notna()).to_numpy()
    position = int(complete.argmax()) if complete.any() else 0
    return str(df["src"].iloc[position]), str(df["tgt"].iloc[position])


def load_and_preprocess_optimized(
    num_samples: Optional[int] = None,
    split: str = "train",
) -> List[Tuple[str, str]]:
    """Load (Bengali, English) sentence pairs from the configured CSV file.

    The CSV at ``_DATASET_CSV_PATH`` must have ``src`` and ``tgt`` columns.
    If a sample row shows English in ``src`` and Bengali in ``tgt`` the
    columns are swapped. The sample is the first row in which both cells are
    non-null: probing row 0 unconditionally misreads a missing cell as the
    string "nan", which could suppress the swap and get every row rejected.

    Only the first ``num_samples`` rows are considered. A row is skipped when
    a cell is NaN or empty, the source has no Bengali character, the target
    has no Latin letter, either side has more than
    ``max(20, _MAX_LENGTH // 2)`` whitespace-separated words, or the text is
    empty after normalization.

    Args:
        num_samples: Maximum number of CSV rows to read; defaults to
            ``_NUM_SAMPLES``.
        split: Accepted for interface compatibility; not used by this loader.

    Returns:
        List of ``(bengali, english)`` pairs. The built-in fallback dataset is
        returned when pandas or the file is missing, the CSV is empty or
        malformed, loading raises, or no row survives filtering.

    Raises:
        ValueError: If ``num_samples`` is not positive.
    """
    if num_samples is None:
        num_samples = _NUM_SAMPLES
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")

    print(f"[CELL2] Loading up to {num_samples} samples from local CSV: {_DATASET_CSV_PATH}")

    if not _HAS_PANDAS:
        print("[CELL2] ERROR: pandas not available; cannot load CSV!")
        print("[CELL2] Using fallback dataset for debugging.")
        return _get_fallback_dataset()

    if not os.path.exists(_DATASET_CSV_PATH):
        print(f"[CELL2] ERROR: CSV file not found at: {_DATASET_CSV_PATH}")
        print("[CELL2] Using fallback dataset for debugging.")
        return _get_fallback_dataset()

    latin_re = re.compile(r"[a-zA-Z]")

    try:
        print("[CELL2] Reading CSV file...")
        df = pd.read_csv(_DATASET_CSV_PATH)
        if df.empty:
            print("[CELL2] ERROR: CSV file is empty")
            return _get_fallback_dataset()

        if "src" not in df.columns or "tgt" not in df.columns:
            print(f"[CELL2] ERROR: CSV missing required columns. Found columns: {list(df.columns)}")
            print("[CELL2] Expected format: src (Bengali), tgt (English) OR src (English), tgt (Bengali)")
            return _get_fallback_dataset()

        # Detect the column orientation from one fully populated row.
        sample_src, sample_tgt = _first_complete_row_texts(df)

        src_is_bengali = bool(_BENGALI_CHAR_RE.search(sample_src))
        tgt_is_bengali = bool(_BENGALI_CHAR_RE.search(sample_tgt))
        src_is_english = bool(latin_re.search(sample_src)) and not src_is_bengali
        tgt_is_english = bool(latin_re.search(sample_tgt)) and not tgt_is_bengali

        if src_is_english and tgt_is_bengali:
            print("[CELL2] Detected src=English, tgt=Bengali: Swapping columns for bn→en task.")
            # rename() maps every label independently, so one call swaps both.
            df = df.rename(columns={"src": "tgt", "tgt": "src"})
            sample_src, sample_tgt = _first_complete_row_texts(df)
            src_is_bengali = bool(_BENGALI_CHAR_RE.search(sample_src))
            tgt_is_english = bool(latin_re.search(sample_tgt)) and not bool(_BENGALI_CHAR_RE.search(sample_tgt))
            if not src_is_bengali or not tgt_is_english:
                print("[CELL2] ERROR: Swap failed, after swap src is not Bengali or tgt is not English.")
                return _get_fallback_dataset()
            else:
                print("[CELL2] Swap successful: src=Bengali, tgt=English")
        elif not src_is_bengali or not tgt_is_english:
            print("[CELL2] WARNING: After column check, src not Bengali or tgt not English. Proceeding but output may be incorrect.")

        df = df.head(num_samples)
        print(f"[CELL2] Processing {len(df)} rows from CSV...")

        pairs: List[Tuple[str, str]] = []
        skipped = 0
        # Loop-invariant word limit, computed once.
        max_words = max(20, _MAX_LENGTH // 2)

        for row_tuple in tqdm(df.itertuples(index=False), total=len(df), desc="Loading dataset"):
            try:
                src_val = row_tuple.src
                tgt_val = row_tuple.tgt
                if pd.isna(src_val) or pd.isna(tgt_val):
                    skipped += 1
                    cell2_dbg("nan_value", "NaN value detected")
                    continue
                bn = str(src_val).strip()
                en = str(tgt_val).strip()
                if not bn or not en:
                    skipped += 1
                    cell2_dbg("empty_field", "Empty src/tgt field")
                    continue
                if not is_bengali_text(bn):
                    skipped += 1
                    cell2_dbg("not_bengali_src", "src field not Bengali")
                    continue
                if not latin_re.search(en):
                    skipped += 1
                    cell2_dbg("not_english_tgt", "tgt field not English")
                    continue
                if len(bn.split()) > max_words or len(en.split()) > max_words:
                    skipped += 1
                    cell2_dbg("too_long", "Text too long")
                    continue
                if _has_normalize:
                    bn_norm = normalize_bengali(bn)
                    en_norm = normalize_english(en)
                else:
                    bn_norm = bn.strip()
                    en_norm = en.lower().strip()
                if not bn_norm or not en_norm:
                    skipped += 1
                    cell2_dbg("empty_after_norm", "Empty after normalization")
                    continue
                pairs.append((bn_norm, en_norm))
            except Exception as e:
                # A single bad row must not abort the whole load.
                skipped += 1
                cell2_dbg("row_exception", f"Row load exception: {type(e).__name__}")
                continue

        print(f"[CELL2] Loaded {len(pairs)} pairs from CSV, skipped {skipped} rows")
        if len(pairs) == 0:
            print("[CELL2] ERROR: No valid pairs loaded from CSV!")
            print("[CELL2] Check that src column contains Bengali and tgt column contains English.")
            return _get_fallback_dataset()

        return pairs

    except pd.errors.EmptyDataError:
        print(f"[CELL2] ERROR: CSV file is empty: {_DATASET_CSV_PATH}")
        return _get_fallback_dataset()
    except Exception as e:
        print(f"[CELL2] ERROR loading CSV: {type(e).__name__}: {str(e)}")
        traceback.print_exc()
        print("[CELL2] Using fallback dataset")
        return _get_fallback_dataset()

def _get_fallback_dataset() -> List[Tuple[str, str]]:
    print("[CELL2] Using fallback dataset (50 unique samples)")
    fallback_pairs = [
        ("আমি কল বন্ধ করেছি।", "i turned off the tap."),
        ("সে আমাকে পরে কল করবে।", "he will call me later."),
        ("আমরা প্রতিদিন তাজা ফল খাই।", "we eat fresh fruits every day."),
        ("তার কঠোর পরিশ্রমের ভালো ফল হয়েছে।", "his hard work has brought good results."),
        ("গাছে নতুন পাতাগুলো গজিয়েছে।", "new leaves have sprouted on the tree."),
        ("আমি বইয়ের পাতা উল্টাচ্ছি।", "i am turning the pages of the book."),
        ("কাল আমি বাজারে গিয়েছিলাম।", "yesterday i went to the market."),
        ("কাল আমি তোমার সাথে দেখা করব।", "tomorrow i will meet you."),
        ("তারা আকাশে উজ্জ্বল।", "the stars are bright in the sky."),
        ("তারা বাড়িতে নেই।", "they are not at home."),
        ("ব্যাংক নদীর ধারে ভেঙে গেছে।", "the bank by the river has collapsed."),
        ("আমি ব্যাংকে টাকা জমা দিয়েছি।", "i deposited money in the bank."),
        ("বার বার চেষ্টা করতে হবে।", "you have to try again and again."),
        ("আমি বার খুলে ভিতরে ঢুকলাম।", "i opened the bar and entered."),
        ("তার মাথা ব্যথা করছে।", "his head is hurting."),
        ("আমি মাথা নেড়ে সম্মতি দিলাম।", "i nodded my head in agreement."),
        ("সে হার মেনে নিয়েছে।", "he accepted defeat."),
        ("আমি গলায় সোনার হার পরেছি।", "i am wearing a gold necklace."),
        ("পানি খুব ঠান্ডা।", "the water is very cold."),
        ("আমি পানি খাচ্ছি।", "i am drinking water."),
        ("দল খেলায় জিতেছে।", "the team won the game."),
        ("আমি মাটি দল দিয়ে ফেললাম।", "i trampled the soil."),
        ("বাজার থেকে সবজি কিনলাম।", "i bought vegetables from the market."),
        ("বাজার অনেক ভিড় ছিল।", "the market was very crowded."),
        ("তার নাম আহমেদ।", "his name is ahmed."),
        ("নাম না করে কাজ করো।", "work without making a name."),
        ("কথা বলা বন্ধ করো।", "stop talking."),
        ("তার কথা শুনে ভালো লাগল।", "i felt good hearing his words."),
        ("বই পড়তে ভালো লাগে।", "i like reading books."),
        ("আমি একটি নতুন বই কিনেছি।", "i bought a new book."),
        ("ঘর পরিষ্কার করা হয়েছে।", "the house has been cleaned."),
        ("আমি ঘরে বসে আছি।", "i am sitting at home."),
        ("মন ভালো নেই।", "my mind is not good."),
        ("আমার মন চায় বেড়াতে যেতে।", "my mind wants to go for a walk."),
        ("হাত ধুয়ে নাও।", "wash your hands."),
        ("আমি তার হাত ধরলাম।", "i held his hand."),
        ("দিন কেটে যাচ্ছে।", "the day is passing by."),
        ("আজ কি দিন?", "what day is today?"),
        ("রাত হয়ে এসেছে।", "night has come."),
        ("আমি রাত জেগে পড়েছি।", "i studied staying up at night."),
        ("জল খুব গরম।", "the water is very hot."),
        ("আমি জল দিয়ে গাছ সিঞ্চন করেছি।", "i watered the plants."),
        ("বাড়ি যাচ্ছি।", "i am going home."),
        ("আমার বাড়ি ঢাকায়।", "my house is in dhaka."),
        ("পার্কে অনেক মানুষ।", "there are many people in the park."),
        ("আমি প্রতিদিন পার্কে হাঁটি।", "i walk in the park every day."),
        ("নদী বইছে।", "the river is flowing."),
        ("আমি নদীর ধারে দাঁড়িয়ে আছি।", "i am standing by the river."),
        ("বন খুব সুন্দর।", "the forest is very beautiful."),
        ("আমি বন দেখতে গিয়েছিলাম।", "i went to see the forest."),
    ]
    if _has_normalize:
        return [
            (normalize_bengali(bn), normalize_english(en))
            for bn, en in fallback_pairs
        ]
    else:
        return [(bn.strip(), en.lower().strip()) for bn, en in fallback_pairs]

class MemoryEfficientDataset(Dataset):
    def __init__(
        self,
        pairs: List[Tuple[str, str]],
        tokenizer: Any = None,
        max_length: Optional[int] = None,
        split: str = "train",
    ):
        if max_length is None:
            max_length = _MAX_LENGTH
        self.max_length = int(max_length)
        self.tokenizer = tokenizer
        self.split = split

        try:
            self._tokenizer_name_or_path = getattr(tokenizer, "name_or_path", None)
        except Exception:
            self._tokenizer_name_or_path = None

        try:
            self.is_fast = getattr(self.tokenizer, "is_fast", False)
        except Exception:
            self.is_fast = False

        self.pairs: List[Tuple[str, str]] = []
        invalid = 0

        for i, p in enumerate(pairs):
            try:
                if not isinstance(p, (list, tuple)) or len(p) != 2:
                    invalid += 1
                    cell2_dbg("init_badpair", f"Bad pair structure at idx={i}")
                    continue
                src, tgt = p
                if not isinstance(src, str) or not isinstance(tgt, str):
                    invalid += 1
                    cell2_dbg("init_badtype", f"Non-string src/tgt at idx={i}")
                    continue
                if not src or not tgt:
                    invalid += 1
                    cell2_dbg("init_empty", f"Empty src/tgt at idx={i}")
                    continue
                if len(src) > self.max_length * 20 or len(tgt) > self.max_length * 20:
                    invalid += 1
                    cell2_dbg("init_long", f"Extremely long text at idx={i}")
                    continue
                self.pairs.append((src, tgt))
            except Exception as e:
                invalid += 1
                cell2_dbg("init_exc", f"Init pair exception idx={i}: {type(e).__name__}")

        print(f"[CELL2] Dataset initialized: {len(self.pairs)} valid pairs, {invalid} invalid")

        try:
            if "get_tokenizer_special_tokens" in globals():
                self.special_tokens = get_tokenizer_special_tokens(self.tokenizer)
            else:
                self.special_tokens = set(getattr(self.tokenizer, "all_special_tokens", [])) if self.tokenizer is not None else set()
        except Exception:
            self.special_tokens = {
                f"__{_SOURCE_LANG}__",
                f"__{_TARGET_LANG}__",
                "</s>",
                "<pad>",
                "<s>",
                "<unk>",
            }

    def __getstate__(self):
        state = self.__dict__.copy()
        state["tokenizer"] = None
        state["_tokenizer_name_or_path"] = getattr(self, "_tokenizer_name_or_path", None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.tokenizer = None
        self.is_fast = False

    def __len__(self) -> int:
        return len(self.pairs)

    def _encode_src(self, src_text: str):
        src_text = src_text if isinstance(src_text, str) else str(src_text)
        try:
            if self.tokenizer is None:
                self.tokenizer = globals().get("tokenizer", None)
                self.is_fast = getattr(self.tokenizer, "is_fast", False) if self.tokenizer is not None else False
            if self.tokenizer is None:
                raise RuntimeError("Tokenizer not available")

            if _has_safe_offsets_tokenize:
                enc = safe_offsets_tokenize(self.tokenizer, src_text, max_length=self.max_length)
                try:
                    if isinstance(enc["input_ids"], torch.Tensor):
                        input_ids = enc["input_ids"].squeeze(0)
                    else:
                        input_ids = torch.tensor(enc["input_ids"][0])
                except Exception:
                    input_ids = torch.tensor(enc.get("input_ids", [[1]])[0])
                attention_mask = enc.get("attention_mask", torch.ones_like(input_ids))
                if isinstance(attention_mask, list):
                    attention_mask = torch.tensor(attention_mask[0]) if attention_mask else torch.ones_like(input_ids)
                try:
                    ids_list = input_ids.tolist() if isinstance(input_ids, torch.Tensor) else list(input_ids)
                    tokens = self.tokenizer.convert_ids_to_tokens(ids_list)
                except Exception:
                    tokens = []
            else:
                enc = self.tokenizer(
                    src_text,
                    max_length=self.max_length,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                    add_special_tokens=False,
                )
                input_ids = enc["input_ids"].squeeze(0)
                attention_mask = enc.get("attention_mask", torch.ones_like(input_ids)).squeeze(0)
                try:
                    tokens = self.tokenizer.convert_ids_to_tokens(input_ids.tolist())
                except Exception:
                    tokens = []

            token_word_map: Dict[int, str] = {}
            if _has_reconstruct_word_spans:
                try:
                    wm, words = reconstruct_word_spans(self.tokenizer, src_text, max_length=self.max_length)
                    if isinstance(wm, dict) and wm:
                        token_word_map = wm
                except Exception as e:
                    cell2_dbg("wm_exc", f"reconstruct_word_spans failed: {e}")

            if not token_word_map and tokens:
                try:
                    word_buffer: List[Tuple[int, str]] = []
                    
                    for idx, tok in enumerate(tokens):
                        if not isinstance(tok, str) or tok in self.special_tokens:
                            continue
                        
                        clean = tok.replace("▁", "").replace("Ġ", "").replace("##", "").strip()
                        if not clean:
                            continue
                        
                        if tok.startswith("▁") or tok.startswith("Ġ"):
                            if word_buffer:
                                full_word = "".join([part for _, part in word_buffer])
                                for w_idx, _ in word_buffer:
                                    token_word_map[w_idx] = full_word
                            word_buffer = [(idx, clean)]
                        else:
                            word_buffer.append((idx, clean))
                    
                    if word_buffer:
                        full_word = "".join([part for _, part in word_buffer])
                        for w_idx, _ in word_buffer:
                            token_word_map[w_idx] = full_word
                            
                except Exception as e:
                    cell2_dbg("fallback_wm", f"Fallback word map failed: {e}")

            return input_ids, attention_mask, tokens, token_word_map

        except Exception as e:
            cell2_dbg("encode_src_exc", f"Encoding source failed: {type(e).__name__}")
            pad_id = getattr(self.tokenizer, "pad_token_id", 1) if self.tokenizer is not None else 1
            input_ids = torch.full((self.max_length,), int(pad_id), dtype=torch.long)
            attention_mask = torch.zeros(self.max_length, dtype=torch.long)
            return input_ids, attention_mask, [], {}

    def _encode_tgt(self, tgt_text: str):
        tgt_text = tgt_text if isinstance(tgt_text, str) else str(tgt_text)
        try:
            if self.tokenizer is None:
                self.tokenizer = globals().get("tokenizer", None)
            if self.tokenizer is None:
                raise RuntimeError("Tokenizer not available")

            dec = self.tokenizer(
                tgt_text,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
                add_special_tokens=False,
            )
            labels = dec["input_ids"].squeeze(0)
            pad_id = getattr(self.tokenizer, "pad_token_id", 1) if self.tokenizer is not None else 1
            labels[labels == int(pad_id)] = -100
            return labels
        except Exception as e:
            cell2_dbg("encode_tgt_exc", f"Encoding tgt failed: {type(e).__name__}")
            return torch.full((self.max_length,), -100, dtype=torch.long)

    def _make_safe_sample(self, reason: str = "fallback") -> Dict[str, Any]:
        """Build a well-formed placeholder sample for use when a real one fails.

        A tiny hard-coded Bengali/English pair is run through the normal
        encoders. If even that raises, an all-padding sample is produced
        (pad id 1, zero attention mask, labels of -100).

        Args:
            reason: Label for why the fallback was requested. Accepted for
                call-site readability; it is not used inside this method.

        Returns:
            Dict with the same keys as a regular ``__getitem__`` result.
        """
        placeholder_src = "আমি"
        placeholder_tgt = "i"
        try:
            ids, mask, toks, word_map = self._encode_src(placeholder_src)
            encoded_tgt = self._encode_tgt(placeholder_tgt)
            domain = _TRAIN_DOMAIN if self.split == "train" else _TEST_DOMAIN
            sample: Dict[str, Any] = {}
            sample["input_ids"] = ids
            sample["attention_mask"] = mask
            sample["labels"] = encoded_tgt
            sample["token_word_map"] = word_map
            sample["src_text"] = placeholder_src
            sample["tokens"] = toks
            sample["domain_label"] = domain
            return sample
        except Exception:
            # Last resort: nothing but padding, so downstream code still gets tensors.
            hard_pad = 1
            domain = _TRAIN_DOMAIN if self.split == "train" else _TEST_DOMAIN
            seq_len = self.max_length
            return {
                "input_ids": torch.full((seq_len,), int(hard_pad), dtype=torch.long),
                "attention_mask": torch.zeros(seq_len, dtype=torch.long),
                "labels": torch.full((seq_len,), -100, dtype=torch.long),
                "token_word_map": {},
                "src_text": "",
                "tokens": [],
                "domain_label": domain,
            }

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the encoded sample at ``idx``; never raises.

        Out-of-range indices (including negative ones), non-string pairs and
        any unexpected exception all yield the placeholder from
        ``_make_safe_sample`` after a ``cell2_dbg`` log line.

        Returns:
            Dict with ``input_ids``, ``attention_mask``, ``labels``,
            ``token_word_map``, ``src_text``, ``tokens`` and ``domain_label``.
        """
        try:
            if not (0 <= idx < len(self.pairs)):
                cell2_dbg("getitem_oob", f"Index out of range idx={idx}")
                return self._make_safe_sample("oob")

            src, tgt = self.pairs[idx]
            if not (isinstance(src, str) and isinstance(tgt, str)):
                cell2_dbg("getitem_bad_types", f"Bad types at idx={idx}")
                return self._make_safe_sample("bad_types")

            # Debug aid: show the script of the first few source sentences.
            if DEBUG_CELL2 and idx < 3:
                has_bengali = is_bengali_text(src)
                has_english = any("a" <= ch.lower() <= "z" for ch in src)
                print(f"[CELL2-GETITEM-{idx}] src sample: {src[:50]}")
                print(f"[CELL2-GETITEM-{idx}] Bengali: {has_bengali}, English: {has_english}")
                if not has_bengali:
                    print(f"[CELL2] WARNING: src_text is NOT Bengali at idx={idx}!")

            src_ids, src_mask, src_tokens, word_map = self._encode_src(src)
            tgt_labels = self._encode_tgt(tgt)
            if self.split == "train":
                domain = _TRAIN_DOMAIN
            else:
                domain = _TEST_DOMAIN

            return dict(
                input_ids=src_ids,
                attention_mask=src_mask,
                labels=tgt_labels,
                token_word_map=word_map,
                src_text=src,
                tokens=src_tokens,
                domain_label=domain,
            )
        except Exception as e:
            cell2_dbg("getitem_exc", f"Unhandled __getitem__ exception idx={idx}: {type(e).__name__}")
            return self._make_safe_sample("unhandled")

def _infer_pad_id_from_sample(sample: Dict[str, Any], default_pad_id: int = 1) -> int:
    try:
        tk = globals().get("tokenizer", None)
        if tk is not None:
            pad = getattr(tk, "pad_token_id", None)
            if pad is not None:
                return int(pad)
    except Exception:
        cell2_dbg("infer_pad_exc", "infer pad id failed")
    return int(default_pad_id)

def _pad_or_truncate_array(tensor: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    if tensor is None:
        return torch.full((length,), int(pad_value), dtype=torch.long)
    t = tensor.view(-1).long()
    L = t.size(0)
    if L == length:
        return t
    if L < length:
        pad = torch.full((length - L,), int(pad_value), dtype=t.dtype)
        return torch.cat([t, pad], dim=0)
    return t[:length]

def safe_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collate dataset samples into one batch without ever raising on bad items.

    Items that are not dicts holding a tensor ``input_ids`` are dropped, as is
    any item whose processing raises. Each surviving sequence is flattened and
    padded/truncated to ``_MAX_LENGTH``. If nothing survives, a single
    all-padding placeholder row is returned so the caller still receives
    tensors of the expected rank.

    Returns:
        Dict with stacked ``input_ids``, ``attention_mask`` and ``labels``
        tensors, per-sample lists ``token_word_map``, ``src_text`` and
        ``tokens``, and a long tensor ``domain_labels`` (plural, built from
        each sample's ``domain_label``).
    """

    def _placeholder_batch() -> Dict[str, Any]:
        # One fully padded row used whenever no usable sample is available.
        pad = _infer_pad_id_from_sample({}, default_pad_id=1)
        return {
            "input_ids": torch.full((1, _MAX_LENGTH), pad, dtype=torch.long),
            "attention_mask": torch.zeros(1, _MAX_LENGTH, dtype=torch.long),
            "labels": torch.full((1, _MAX_LENGTH), -100, dtype=torch.long),
            "token_word_map": [{}],
            "src_text": [""],
            "tokens": [[]],
            "domain_labels": torch.tensor([_TRAIN_DOMAIN], dtype=torch.long),
        }

    def _flat(t):
        # view() is tried first; flatten() is the fallback when view() refuses.
        try:
            return t.view(-1)
        except Exception:
            return t.flatten()

    usable = []
    for item in batch:
        if isinstance(item, dict) and "input_ids" in item and isinstance(item["input_ids"], torch.Tensor):
            usable.append(item)
    if not usable:
        return _placeholder_batch()

    pad_id = _infer_pad_id_from_sample(usable[0], default_pad_id=1)
    id_rows, mask_rows, label_rows = [], [], []
    word_maps, sources, token_lists, domain_ids = [], [], [], []

    for pos, sample in enumerate(usable):
        try:
            raw_ids = sample["input_ids"]
            raw_mask = sample.get("attention_mask", None)
            raw_labels = sample["labels"]
            domain = sample.get("domain_label", _TRAIN_DOMAIN)

            # Without a usable mask, derive one from the non-pad positions.
            if raw_mask is None:
                raw_mask = (raw_ids != pad_id).long()
            else:
                try:
                    raw_mask = raw_mask.view(-1).long()
                except Exception:
                    raw_mask = (raw_ids != pad_id).long()

            row_ids = _pad_or_truncate_array(_flat(raw_ids), _MAX_LENGTH, pad_id)
            row_mask = _pad_or_truncate_array(raw_mask, _MAX_LENGTH, 0)
            row_labels = _pad_or_truncate_array(_flat(raw_labels), _MAX_LENGTH, -100)

            id_rows.append(row_ids)
            mask_rows.append(row_mask)
            label_rows.append(row_labels)
            word_maps.append(sample.get("token_word_map", {}))
            sources.append(sample.get("src_text", ""))
            token_lists.append(sample.get("tokens", []))
            domain_ids.append(domain)
        except Exception as e:
            cell2_dbg("collate_item_exc", f"Collate item exception idx={pos}: {type(e).__name__}")
            continue

    if not id_rows:
        return _placeholder_batch()

    try:
        domain_labels = torch.tensor(domain_ids, dtype=torch.long)
    except Exception:
        domain_labels = torch.full((len(id_rows),), _TRAIN_DOMAIN, dtype=torch.long)

    return {
        "input_ids": torch.stack(id_rows, dim=0),
        "attention_mask": torch.stack(mask_rows, dim=0),
        "labels": torch.stack(label_rows, dim=0),
        "token_word_map": word_maps,
        "src_text": sources,
        "tokens": token_lists,
        "domain_labels": domain_labels,
    }

def create_optimized_dataloader(
    dataset: Dataset,
    batch_size: Optional[int] = None,
    shuffle: bool = True,
    split: str = "train",
) -> DataLoader:
    """Build a ``DataLoader`` that uses ``safe_collate`` and the cell's global settings.

    Args:
        dataset: Dataset to iterate.
        batch_size: Batch size; when None, ``BATCH_SIZE`` is used (8 if that
            global is undefined).
        shuffle: Forwarded to ``DataLoader``.
        split: Accepted for interface compatibility; not used in this function.

    With multi-GPU enabled, the batch size is rounded down to a multiple of
    ``_NUM_GPUS`` unless that would make it zero. The worker count is capped at
    ``cpu_count - 1``. If constructing the loader fails, it is retried once with
    ``num_workers=0`` and the worker-only options removed.
    """
    if batch_size is None:
        try:
            batch_size = int(BATCH_SIZE)
        except NameError:
            batch_size = 8
    batch_size = int(batch_size)
    requested_batch_size = batch_size

    if _USE_MULTI_GPU and _NUM_GPUS > 0 and batch_size % _NUM_GPUS != 0:
        rounded = (batch_size // _NUM_GPUS) * _NUM_GPUS
        if rounded != 0:
            batch_size = rounded
            if batch_size != requested_batch_size:
                print(f"[CELL2] Adjusted batch size {requested_batch_size} to {batch_size} (DP-divisible, GPUs={_NUM_GPUS})")
        elif DEBUG_CELL2:
            print(f"[CELL2] WARNING: batch_size {batch_size} < num_gpus {_NUM_GPUS}. Keeping original.")

    if isinstance(_NUM_WORKERS, int) and _NUM_WORKERS >= 0:
        num_workers = _NUM_WORKERS
    else:
        num_workers = 0
    try:
        # Leave one core free for the main process.
        worker_cap = max(0, (os.cpu_count() or 1) - 1)
        num_workers = min(num_workers, worker_cap)
    except Exception:
        pass

    loader_kwargs: Dict[str, Any] = dict(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=bool(_PIN_MEMORY and torch.cuda.is_available()),
        collate_fn=safe_collate,
        drop_last=False,
    )
    worker_only_options = ("prefetch_factor", "persistent_workers", "worker_init_fn")

    if num_workers > 0:
        loader_kwargs["worker_init_fn"] = _dataloader_worker_init_fn
        loader_kwargs["prefetch_factor"] = _PREFETCH_FACTOR
        loader_kwargs["persistent_workers"] = False

    try:
        dataloader = DataLoader(**loader_kwargs)
    except Exception as e:
        print(f"[CELL2] DataLoader init failed with num_workers={num_workers}: {type(e).__name__}")
        print("[CELL2] Retrying with num_workers=0")
        loader_kwargs["num_workers"] = 0
        for option in worker_only_options:
            loader_kwargs.pop(option, None)
        dataloader = DataLoader(**loader_kwargs)

    effective_workers = loader_kwargs.get("num_workers", 0)
    if _USE_MULTI_GPU and _NUM_GPUS > 0:
        per_gpu = batch_size // _NUM_GPUS
        print(f"[CELL2] DataLoader created: total_batch={batch_size}, per_gpu={per_gpu}, workers={effective_workers}")
    else:
        print(f"[CELL2] DataLoader created: batch_size={batch_size}, workers={effective_workers}")

    return dataloader

def aggregate_subwords_to_words(
    subword_embeds: torch.Tensor,
    tokens: List[str],
    device: torch.device
) -> Tuple[torch.Tensor, List[List[int]], List[str]]:
    """Mean-pool subword embeddings into one embedding per word.

    A token starting with '▁' or 'Ġ' opens a new word; any other token
    continues the current word. The special tokens '__bn__', '__en__', '<s>',
    '</s>' and '<pad>' are skipped without closing the current word.

    Args:
        subword_embeds: Tensor indexed by token position along dim 0.
        tokens: Token strings aligned with the rows of ``subword_embeds``.
            Tokens beyond the number of available rows and non-string entries
            are ignored instead of raising.
        device: Device for the zero fallback tensor only; pooled embeddings
            stay on the device of ``subword_embeds``.

    Returns:
        ``(word_embeds, word_boundaries, word_strings)`` where
        ``word_boundaries[i]`` lists the token positions pooled into word ``i``.
        When no word is found, a single zero row with boundaries ``[[0]]`` and
        strings ``['']`` is returned.
    """
    skipped_tokens = {'__bn__', '__en__', '<s>', '</s>', '<pad>'}
    available_rows = subword_embeds.shape[0]

    word_boundaries: List[List[int]] = []
    word_strings: List[str] = []
    current_indices: List[int] = []
    current_parts: List[str] = []

    for i, tok in enumerate(tokens):
        if i >= available_rows:
            # No embedding row exists for this (or any later) token.
            break
        if not isinstance(tok, str) or tok in skipped_tokens:
            continue

        piece = tok.replace('▁', '').replace('Ġ', '').strip()
        starts_word = tok.startswith('▁') or tok.startswith('Ġ')
        if starts_word and current_indices:
            word_boundaries.append(current_indices)
            word_strings.append(''.join(current_parts))
            current_indices, current_parts = [], []
        current_indices.append(i)
        current_parts.append(piece)

    if current_indices:
        word_boundaries.append(current_indices)
        word_strings.append(''.join(current_parts))

    if word_boundaries:
        word_embeds = torch.stack(
            [subword_embeds[indices].mean(dim=0) for indices in word_boundaries],
            dim=0,
        )
    else:
        # Keep the fallback in the input dtype so fp16/fp64 pipelines do not
        # receive an unexpected float32 tensor.
        word_embeds = torch.zeros(
            1, subword_embeds.shape[-1], dtype=subword_embeds.dtype, device=device
        )
        word_boundaries = [[0]]
        word_strings = ['']

    return word_embeds, word_boundaries, word_strings

def broadcast_word_to_subword(
    word_outputs: Any,
    word_boundaries: List[List[int]],
    subword_len: int,
    device: torch.device
) -> Any:
    """Copy each word-level output back onto the subword positions of that word.

    Args:
        word_outputs: Either a list (one entry per word) or a tensor whose
            first dimension indexes words. Any other type is returned as is.
        word_boundaries: For each word, the subword positions belonging to it.
        subword_len: Number of subword positions in the result.
        device: Device on which the result tensor is allocated (tensor case).

    Returns:
        For a list input: a list of length ``subword_len`` with ``None`` at
        positions no word covers. For a tensor input: a tensor of shape
        ``(subword_len, *word_outputs.shape[1:])`` with zeros at uncovered
        positions, in the same dtype as ``word_outputs``.

    Positions outside ``[0, subword_len)`` and words without a matching entry
    in ``word_outputs`` are ignored.
    """
    if isinstance(word_outputs, list):
        subword_out = [None] * subword_len
        n_words = len(word_outputs)
    elif isinstance(word_outputs, torch.Tensor):
        # Allocate in the source dtype: the default float32 would silently
        # cast integer or half-precision word outputs.
        subword_out = torch.zeros(
            subword_len,
            *word_outputs.shape[1:],
            dtype=word_outputs.dtype,
            device=device,
        )
        n_words = word_outputs.shape[0]
    else:
        return word_outputs

    for word_idx, indices in enumerate(word_boundaries):
        if word_idx >= n_words:
            continue
        word_val = word_outputs[word_idx]
        for idx in indices:
            # Reject negative positions too; they would otherwise wrap around
            # and overwrite entries counted from the end.
            if 0 <= idx < subword_len:
                subword_out[idx] = word_val
    return subword_out

# Printed once the cell has executed every definition above without error.
print("Cell 2: Memory-efficient data loading ready")


Cell 2: Memory-efficient data loading ready


In [6]:
# ==============================================================================
# CELL 3: DSCD MODULE (PURE UNSUPERVISED DISCOVERY) - FAST INFERENCE MODE ADDED
# ==============================================================================

import threading
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gc
from collections import deque
import unicodedata
from typing import Optional, Dict, List, Any, Set, Tuple

# Module-level constant; it is not referenced in the part of the cell shown
# here -- presumably a logging interval used further down (TODO confirm).
PRINTINTERVAL = 200

# Optional dependency: SciPy hierarchical clustering helpers.
# HASCLUSTERING records whether the import succeeded.
try:
    from scipy.cluster.hierarchy import linkage, fcluster
    from scipy.spatial.distance import pdist
    HASCLUSTERING = True
except Exception:
    HASCLUSTERING = False
    print("[CELL3] WARNING: scipy not available")

# Optional dependency: scikit-learn KMeans. HASKMEANS records availability.
try:
    from sklearn.cluster import KMeans
    HASKMEANS = True
except Exception:
    HASKMEANS = False
    print("[CELL3] WARNING: sklearn not available")

# Pull DSCD settings from globals defined by an earlier cell.
# NOTE(review): this is all-or-nothing -- if any single name is missing or
# not convertible, every setting below falls back to the defaults in the
# except branch, including those that were read successfully.
try:
    DSCDMAXPROTOS = int(DSCD_MAX_PROTOS)
    DSCDBUFFERSIZE = int(DSCD_BUFFER_SIZE)
    DSCDNMIN = max(3, int(DSCD_N_MIN))  # never below 3
    DSCDDISPERSIONTHRESHOLD = min(0.08, float(DSCD_DISPERSION_THRESHOLD))  # capped at 0.08
    VERBOSELOGGING = bool(VERBOSE_LOGGING)
    DSCDENABLETRAININGCLUSTERING = bool(DSCD_ENABLE_TRAINING_CLUSTERING)
except (NameError, ValueError, TypeError):
    DSCDMAXPROTOS = 8
    DSCDBUFFERSIZE = 50
    DSCDNMIN = 3
    DSCDDISPERSIONTHRESHOLD = 0.08
    VERBOSELOGGING = True
    DSCDENABLETRAININGCLUSTERING = True
    print("[CELL3] WARNING: Using default DSCD config")

# Discovery debug flag; off when DEBUG_DISCOVERY is not defined.
try:
    DEBUGDISCOVERY = bool(DEBUG_DISCOVERY)
except NameError:
    DEBUGDISCOVERY = False

# Optional override read from the notebook globals; defaults to 150.
try:
    MAXTOKENSPERDISCOVERY = int(globals().get("_MAX_TOKENS_PER_DISCOVERY", 150))
except Exception:
    MAXTOKENSPERDISCOVERY = 150

# Multiplier applied to a store's tau when ismultisensestore() compares
# centroid distances; defaults to 1.5.
try:
    DSCDNEWSENSELAMBDA = float(globals().get("DSCD_NEW_SENSE_LAMBDA", 1.5))
except Exception:
    DSCDNEWSENSELAMBDA = 1.5

# Reference set of Bengali homographs; a small built-in list is used when
# HOMOGRAPH_REFERENCE_LIST_BN is not defined or not iterable.
try:
    HOMOGRAPHREFERENCELISTBN = set(HOMOGRAPH_REFERENCE_LIST_BN)
    print(f"[CELL3] Loaded reference list for evaluation: {len(HOMOGRAPHREFERENCELISTBN)} words")
except (NameError, TypeError):
    HOMOGRAPHREFERENCELISTBN = {"কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা"}
    print("[CELL3] Using default reference list")

# Default for MemoryEfficientDSCDOnline's maxclusteringpoints argument.
DSCDMAXCLUSTERINGPOINTS = 500
# ASCII punctuation used by shouldtracktoken(); the Bengali danda is not in it.
PUNCTSET = set(list(".,!?-;:"))

def normalizetokenkey(token: str) -> str:
    """Return a canonical lookup key for a (sub)word token.

    ``None`` becomes the empty string and any other value is passed through
    ``str()``. The text is NFKC-normalized, the subword markers '▁', '##' and
    'Ġ' and all plain spaces are removed, and the result is stripped and
    lower-cased.
    """
    text = unicodedata.normalize("NFKC", "" if token is None else str(token))
    for marker in ("▁", "##", "Ġ", " "):
        text = text.replace(marker, "")
    return text.strip().lower()

def iswordtoken(token: str, minletters: int = 2, minletterfraction: float = 0.6) -> bool:
    """Decide whether ``token`` looks like a natural-language word.

    Characters in Unicode letter categories (L*) count as word characters.
    Combining marks (M*, e.g. Bengali vowel signs such as U+09BE, or a
    decomposed Latin accent) also count, but only when the token contains at
    least one letter. Previously marks counted against the letter fraction, so
    words like "মাথা" (two letters plus two vowel signs) were rejected.

    Args:
        token: Candidate text; non-strings and empty/blank strings are rejected.
        minletters: Minimum number of word characters required.
        minletterfraction: Minimum share of word characters among all
            non-whitespace characters.

    Returns:
        True when both thresholds are met, otherwise False.
    """
    if not token or not isinstance(token, str):
        return False
    token = token.strip()
    if token == "":
        return False

    letters = 0
    marks = 0
    total = 0
    for ch in token:
        major = unicodedata.category(ch)[0]
        if major == "L":
            letters += 1
        elif major == "M":
            marks += 1
        if not ch.isspace():
            total += 1
    if total == 0:
        return False

    # Marks only make sense attached to letters; a token made of marks alone
    # contributes no word characters.
    wordchars = letters + marks if letters else 0
    if wordchars < minletters:
        return False
    return not (wordchars / total < minletterfraction)

class MemoryEfficientPrototypeStore:
    """Bounded collection of CPU centroid vectors with usage counts.

    Alongside the centroids the store keeps a running mean (``mu``) and a
    running mean absolute deviation (``tau``) of assignment distances, from
    which an adaptive distance threshold is derived.
    """

    def __init__(self, embeddim: int, maxprotos: Optional[int] = None):
        """Create an empty store; ``maxprotos`` defaults to ``DSCDMAXPROTOS``."""
        self.embeddim = embeddim
        self.maxprotos = int(DSCDMAXPROTOS if maxprotos is None else maxprotos)
        # Parallel lists: entry i of each describes prototype i.
        self.centroids: List[torch.Tensor] = []
        self.counts: List[int] = []
        self.creationtime: List[float] = []
        # Recent assignment distances (at most 50 are kept).
        self.distances: List[float] = []
        self.mu = 0.0
        self.tau = 1e-6
        # Smoothing factor of the exponential moving averages mu and tau.
        self.alpha = 0.1
        self.labels: Optional[torch.Tensor] = None

    def addprototype(self, vector: torch.Tensor, currenttime: Optional[float] = None, count: int = 1) -> None:
        """Store a detached CPU copy of ``vector`` as a prototype.

        When the store is full, the prototype with the smallest count is
        overwritten instead.
        """
        stamp = time.time() if currenttime is None else currenttime
        proto = vector.detach().cpu().clone()
        if len(self.centroids) < self.maxprotos:
            self.centroids.append(proto)
            self.counts.append(int(count))
            self.creationtime.append(float(stamp))
            return
        # Full: recycle the least-used slot.
        slot = int(np.argmin(self.counts)) if self.counts else 0
        self.centroids[slot] = proto
        self.counts[slot] = int(count)
        self.creationtime[slot] = float(stamp)

    def updateprototype(self, idx: int, vector: torch.Tensor, eta: float = 0.05, assignmentdistance: Optional[float] = None) -> None:
        """Blend ``vector`` into prototype ``idx`` with weight ``eta``.

        An out-of-range ``idx`` adds ``vector`` as a new prototype instead.
        When ``assignmentdistance`` is given, the rolling statistics are
        updated with it.
        """
        if not (0 <= idx < len(self.centroids)):
            self.addprototype(vector, time.time(), count=1)
            return
        incoming = vector.detach().cpu()
        self.centroids[idx] = (1.0 - eta) * self.centroids[idx] + eta * incoming
        self.counts[idx] = int(self.counts[idx] + 1)
        if assignmentdistance is not None:
            self.updaterollingstats(float(assignmentdistance))

    def updaterollingstats(self, d: float) -> None:
        """Fold one assignment distance into ``mu``, ``tau`` and the history."""
        value = float(d)
        if not self.distances:
            # First observation seeds the mean; the deviation starts tiny.
            self.mu = value
            self.tau = 1e-6
            self.distances.append(value)
            return
        previous_mean = self.mu
        keep = 1 - self.alpha
        self.mu = keep * self.mu + self.alpha * value
        self.tau = keep * self.tau + self.alpha * abs(value - previous_mean)
        self.distances.append(value)
        if len(self.distances) > 50:
            del self.distances[0]

    def getadaptivethreshold(self, lam: float = 1.0) -> float:
        """Return ``mu + lam * tau``."""
        return float(self.mu + lam * self.tau)

    def size(self) -> int:
        """Number of stored prototypes."""
        return len(self.centroids)

    def ensureconsistency(self) -> None:
        """Trim or pad ``counts`` and ``creationtime`` to match ``centroids``."""
        n = len(self.centroids)

        missing_counts = n - len(self.counts)
        if missing_counts > 0:
            self.counts = self.counts + [1] * missing_counts
        elif missing_counts < 0:
            self.counts = self.counts[:n]

        missing_times = n - len(self.creationtime)
        if missing_times > 0:
            self.creationtime = self.creationtime + [time.time()] * missing_times
        elif missing_times < 0:
            self.creationtime = self.creationtime[:n]

class MemoryEfficientDSCDOnline(nn.Module):
    def __init__(
        self,
        embeddim: int,
        tokenizer=None,
        buffersize: Optional[int] = None,
        maxprotos: Optional[int] = None,
        nmin: Optional[int] = None,
        dispersionthreshold: Optional[float] = None,
        language: str = "bn",
        enabletrainingclustering: Optional[bool] = None,
        maxclusteringpoints: Optional[int] = None,
        maxcandidatesperstep: int = 2,
        dscdminletters: int = 2,
        dscdminletterfraction: float = 0.6,
    ):
        super().__init__()
        if buffersize is None:
            buffersize = DSCDBUFFERSIZE
        if maxprotos is None:
            maxprotos = DSCDMAXPROTOS
        if nmin is None:
            nmin = DSCDNMIN
        if dispersionthreshold is None:
            dispersionthreshold = DSCDDISPERSIONTHRESHOLD
        if maxclusteringpoints is None:
            maxclusteringpoints = DSCDMAXCLUSTERINGPOINTS
        if enabletrainingclustering is None:
            enabletrainingclustering = DSCDENABLETRAININGCLUSTERING

        self.embeddim = int(embeddim)
        self.buffersize = int(buffersize)
        self.maxprotos = int(maxprotos)
        self.nmin = max(3, int(nmin))
        self.dispersionthreshold = min(0.08, float(dispersionthreshold))
        self.language = language
        self.tokenizer = tokenizer
        self.dscdminletters = int(dscdminletters)
        self.dscdminletterfraction = float(dscdminletterfraction)

        try:
            if tokenizer is not None and "get_special_tokens" in globals():
                self.specialtokens = get_special_tokens(tokenizer)
            else:
                self.specialtokens = set(getattr(tokenizer, "all_special_tokens", [])) if tokenizer is not None else set()
        except Exception:
            self.specialtokens = set()

        self.dscdallowedtokens: Set[str] = set()
        self.dscdignoredtokens: Set[str] = set()
        self.dscdcachemaxsize = 10000

        self.prototypestores: Dict[str, MemoryEfficientPrototypeStore] = {}
        self.buffers: Dict[str, deque] = {}
        self.discoveredlog: List[Dict[str, Any]] = []
        self.discoveredhomographs: Set[str] = set()

        self.lastperiodiccheck = 0
        self.cleanupcounter = 0

        self.dispersioncache: Dict[str, float] = {}
        self.dispersionlastupdated: Dict[str, float] = {}
        self.dispersionlock = threading.Lock()

        self.clusteringlock = threading.Lock()
        self.bufferlock = threading.Lock()

        from collections import deque as threaddeque
        self.activethreads = threaddeque(maxlen=100)
        self.threadlock = threading.Lock()

        self.lastclustertime: Dict[str, float] = {}
        self.clustercooldownseconds = 5.0
        self.enabletrainingclustering = bool(enabletrainingclustering)

        self.discoverycount = 0
        self.discoverytimes: List[float] = []
        self.clusteredtokens: Set[str] = set()
        self.clusterstats: Dict[str, Dict[str, Any]] = {}

        self.spanhead = nn.Sequential(
            nn.Linear(self.embeddim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 1),
        )
        self.sigmanet = nn.Sequential(
            nn.Linear(self.embeddim, 16),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(16, 1),
        )

        self.gatew = nn.Parameter(torch.tensor(1.0))
        self.gateb = nn.Parameter(torch.tensor(0.4))
        self.gamma = nn.Parameter(torch.tensor(0.3))

        self.maxclusteringpoints = int(maxclusteringpoints)
        self.maxcandidatesperstep = int(maxcandidatesperstep)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)
        plainstores = {}
        for token, store in self.prototypestores.items():
            plainstores[token] = {
                "centroids": [c.cpu().numpy() for c in store.centroids] if hasattr(store, "centroids") else [],
                "counts": list(store.counts) if hasattr(store, "counts") else [],
                "creationtime": list(store.creationtime) if hasattr(store, "creationtime") else [],
                "mu": float(store.mu) if hasattr(store, "mu") else 0.0,
                "tau": float(store.tau) if hasattr(store, "tau") else 0.0,
                "size": int(store.size()) if hasattr(store, "size") else 0,
            }
        state[prefix + "prototypestores"] = plainstores
        state[prefix + "discoveredhomographs"] = list(self.discoveredhomographs)
        return state

    def load_state_dict(self, state_dict, strict=True):
        plainstores = state_dict.pop("prototypestores", state_dict.pop("prototypestoresdata", None))
        discovered = state_dict.pop("discoveredhomographs", [])
        super().load_state_dict(state_dict, strict=strict)

        if not plainstores:
            print("[DSCD WARNING] Empty prototypestores in checkpoint")
            return

        self.prototypestores = {}
        self.discoveredhomographs = set(discovered)

        for token, storedict in plainstores.items():
            store = MemoryEfficientPrototypeStore(embeddim=self.embeddim, maxprotos=self.maxprotos)
            centroidsdata = storedict.get("centroids", [])
            for c in centroidsdata:
                if isinstance(c, torch.Tensor):
                    store.centroids.append(c)
                else:
                    store.centroids.append(torch.tensor(c))
            store.counts = storedict.get("counts", [])
            store.creationtime = storedict.get("creationtime", [])
            store.mu = storedict.get("mu", 0.0)
            store.tau = storedict.get("tau", 0.0)
            store.ensureconsistency()
            self.prototypestores[token] = store

        print(f"[DSCD] Loaded {len(self.prototypestores)} tokens, {sum(s.size() for s in self.prototypestores.values())} prototypes")

    @staticmethod
    def clean_token(token):
        token = str(token)
        token = token.replace("▁", "").replace("Ġ", "").replace("##", "")
        for punct in ["।", ".", ",", "!", "?", ":", ";", "-"]:
            token = token.replace(punct, "")
        return token.strip()

    def isvalidmultisense(self, token):
        if token not in self.prototypestores:
            return False
        store = self.prototypestores[token]
        totaloccurrences = sum(store.counts) if hasattr(store, "counts") else 0
        minperproto = min(store.counts) if hasattr(store, "counts") and store.counts else 0
        return store.size() >= 2 and totaloccurrences >= 10 and minperproto >= 2

    def ismultisensestore(self, store: MemoryEfficientPrototypeStore) -> bool:
        k = store.size()
        if k < 2:
            return False
        counts = store.counts if store.counts else [1] * k
        strong = sum(1 for c in counts if c >= max(3, self.nmin))
        if strong < 2:
            return False
        try:
            cents = []
            for c in store.centroids:
                if isinstance(c, torch.Tensor):
                    cents.append(c.cpu().numpy())
                else:
                    cents.append(np.asarray(c, dtype=np.float32))
            if len(cents) < 2:
                return False
            cents = np.stack(cents, axis=0)
            dists = np.linalg.norm(cents[:, None, :] - cents[None, :, :], axis=-1)
            tri = dists[np.triu_indices(len(cents), k=1)]
            if tri.size == 0:
                return False
            mindist = float(tri.min())
            base = max(store.tau, 1e-3)
            return mindist > base * DSCDNEWSENSELAMBDA
        except Exception:
            return True

    def getdispersion(self, tokentype: str) -> float:
        with self.dispersionlock:
            if tokentype in self.dispersioncache:
                try:
                    lastupdate = self.dispersionlastupdated.get(tokentype, 0.0)
                    if time.time() - lastupdate < 3600:
                        return self.dispersioncache[tokentype]
                except Exception:
                    pass

        with self.bufferlock:
            if tokentype not in self.buffers or len(self.buffers[tokentype]) < 2:
                return 0.0

        try:
            embeddings: List[np.ndarray] = []
            for emb in self.buffers[tokentype]:
                try:
                    if isinstance(emb, torch.Tensor):
                        embeddings.append(emb.cpu().numpy())
                    else:
                        embeddings.append(np.asarray(emb, dtype=np.float32))
                except Exception:
                    continue
            if len(embeddings) < 2:
                return 0.0
            embeddingsnp = np.stack(embeddings, axis=0)
            centroid = embeddingsnp.mean(axis=0)
            distances = np.linalg.norm(embeddingsnp - centroid[None, :], axis=1)
            dispersion = float(distances.std())
            with self.dispersionlock:
                self.dispersioncache[tokentype] = dispersion
                self.dispersionlastupdated[tokentype] = time.time()
            return dispersion
        except Exception:
            return 0.0

    def shouldtracktoken(self, tokentext: str) -> bool:
        if not tokentext or not isinstance(tokentext, str):
            return False

        if len(self.dscdallowedtokens) > self.dscdcachemaxsize:
            self.dscdallowedtokens.clear()
        if len(self.dscdignoredtokens) > self.dscdcachemaxsize:
            self.dscdignoredtokens.clear()

        if tokentext in self.dscdallowedtokens:
            return True
        if tokentext in self.dscdignoredtokens:
            return False

        if not getattr(self, "training", False):
            if tokentext in self.prototypestores:
                self.dscdallowedtokens.add(tokentext)
                return True
            clean = tokentext.replace("▁", "").replace("##", "").replace("Ġ", "").replace(" ", "").strip()
            clean = unicodedata.normalize("NFKC", clean)
            if clean and clean in self.prototypestores:
                self.dscdallowedtokens.add(tokentext)
                return True

        if tokentext in self.specialtokens:
            self.dscdignoredtokens.add(tokentext)
            return False

        clean = tokentext.replace("▁", "").replace("##", "").replace("Ġ", "").replace(" ", "").strip()
        clean = unicodedata.normalize("NFKC", clean)

        if not clean:
            self.dscdignoredtokens.add(tokentext)
            return False

        if len(clean) < 2:
            self.dscdignoredtokens.add(tokentext)
            return False

        if not any(c.isalpha() for c in clean):
            self.dscdignoredtokens.add(tokentext)
            return False

        if clean.isdigit():
            self.dscdignoredtokens.add(tokentext)
            return False

        if all(c in PUNCTSET for c in clean):
            self.dscdignoredtokens.add(tokentext)
            return False

        try:
            bengaliblock = any("\u0980" <= c <= "\u09FF" for c in clean)
            if bengaliblock:
                if len(clean) >= 2:
                    self.dscdallowedtokens.add(tokentext)
                    return True
                self.dscdignoredtokens.add(tokentext)
                return False
        except Exception:
            pass

        if iswordtoken(clean, minletters=self.dscdminletters, minletterfraction=self.dscdminletterfraction):
            self.dscdallowedtokens.add(tokentext)
            return True

        self.dscdignoredtokens.add(tokentext)
        return False

    def canonicaltokenkey(self, rawtoken: str, tokenwordmap: Optional[Dict[int, Optional[str]]], idx: int) -> Optional[str]:
        """Derive the lookup key used to track a token.

        Preference order:
          1. The word stored at ``tokenwordmap[idx]`` (when the map is a non-empty
             dict with a truthy entry for ``idx``), NFKC-normalised, stripped and
             lower-cased.
          2. ``self.clean_token(rawtoken)``, NFKC-normalised, with the sub-word
             markers ``▁``, ``##``, ``Ġ`` and spaces removed, then stripped and
             lower-cased.

        Returns:
            The key, or ``None`` when both candidates end up empty.
        """
        mapped_word: Optional[str] = None
        try:
            has_entry = bool(tokenwordmap) and isinstance(tokenwordmap, dict) and idx in tokenwordmap
            if has_entry and tokenwordmap[idx]:
                mapped_word = str(tokenwordmap[idx]).strip()
        except Exception:
            # A misbehaving map must not block the raw-token fallback below.
            mapped_word = None

        if mapped_word:
            normalised_word = unicodedata.normalize("NFKC", mapped_word).strip().lower()
            if normalised_word:
                return normalised_word

        fallback = unicodedata.normalize("NFKC", self.clean_token(rawtoken))
        for marker in ("▁", "##", "Ġ", " "):
            fallback = fallback.replace(marker, "")
        fallback = fallback.strip().lower()

        return fallback if fallback else None

    def cleanupthreads(self) -> None:
        """Forget worker threads that have finished.

        Under ``self.threadlock`` the contents of ``self.activethreads`` are
        replaced by the subset whose ``is_alive()`` is still true. Any error is
        swallowed: this is best-effort housekeeping.
        """
        try:
            with self.threadlock:
                still_running = []
                for worker in list(self.activethreads):
                    if worker.is_alive():
                        still_running.append(worker)
                self.activethreads.clear()
                self.activethreads.extend(still_running)
        except Exception:
            pass

    def cleanupmemory(self) -> None:
        """Trim oversized token buffers, expire old dispersion entries, run the GC.

        * A buffer longer than ``int(self.buffersize * 1.5)`` is popped from the
          left until it holds ``self.buffersize`` items.
        * Entries of ``self.dispersionlastupdated`` older than 3600 seconds are
          removed together with their ``self.dispersioncache`` counterpart.
        * ``gc.collect()`` runs when the garbage collector is enabled.

        All failures are swallowed; the dispersion sweep has its own guard so a
        problem there does not skip the GC step.
        """
        try:
            for _token, samples in list(self.buffers.items()):
                if len(samples) <= int(self.buffersize * 1.5):
                    continue
                while len(samples) > self.buffersize:
                    samples.popleft()

            try:
                reference_time = time.time()
                stale_keys = []
                for key, stamp in self.dispersionlastupdated.items():
                    if reference_time - stamp > 3600:
                        stale_keys.append(key)
                for key in stale_keys:
                    self.dispersioncache.pop(key, None)
                    self.dispersionlastupdated.pop(key, None)
            except Exception:
                pass

            if gc.isenabled():
                gc.collect()
        except Exception:
            pass

    def forward(
        self,
        tokenembeddings,
        tokentypes=None,
        trainmode: bool = True,
        tokenwordmap=None,
        hall=None,
        inputids=None,
        attentionmask=None,
        fast_inference: bool = False,
    ):
        """Run sense disambiguation over a batch of token embeddings.

        Args:
            tokenembeddings: 3-D tensor; the shape is unpacked into three values,
                so any other rank raises. ``hall`` is used when this is ``None``.
            tokentypes: Optional per-sequence token strings. When omitted and
                ``inputids`` is given, tokens come from
                ``self.tokenizer.convert_ids_to_tokens`` (or ``tok{i}``
                placeholders when there is no tokenizer or conversion fails).
            trainmode: Forwarded unchanged to ``processsequence``.
            tokenwordmap: Optional per-sequence index -> word maps.
            hall: Alternative source for ``tokenembeddings``.
            inputids: Optional 2-D id tensor used only to derive ``tokentypes``.
            attentionmask: Accepted but not referenced in this method.
            fast_inference: When true, skip all prototype work and return neutral
                outputs with a detached copy of the embeddings.

        Returns:
            Dict with keys ``protoassignments``, ``protoprobs``, ``uncertainties``,
            ``spanpreds``, ``gates`` and ``haugmented``. ``haugmented`` is a
            stacked tensor, or the input embeddings if stacking fails.

        Raises:
            ValueError: when both ``tokenembeddings`` and ``hall`` are ``None``.
        """
        if tokenembeddings is None:
            if hall is None:
                raise ValueError("MemoryEfficientDSCDOnline.forward requires tokenembeddings or hall")
            tokenembeddings = hall

        n_rows, n_cols, _ = tokenembeddings.shape
        device = tokenembeddings.device

        if fast_inference:
            def _constant_grid(value):
                # One fresh float32 tensor per (sequence, position).
                return [
                    [torch.tensor(value, device=device, dtype=torch.float32) for _ in range(n_cols)]
                    for _ in range(n_rows)
                ]

            return {
                "haugmented": tokenembeddings.detach().clone(),
                "protoprobs": _constant_grid([1.0]),
                "uncertainties": _constant_grid(0.0),
                "gates": _constant_grid(0.0),
                "spanpreds": _constant_grid(0.0),
                "protoassignments": [
                    torch.zeros(n_cols, dtype=torch.long, device=device) for _ in range(n_rows)
                ],
            }

        if inputids is not None and tokentypes is None:
            id_rows, id_cols = inputids.shape
            tokentypes = []
            for row in range(id_rows):
                if self.tokenizer is None:
                    tokentypes.append([f"tok{i}" for i in range(id_cols)])
                    continue
                try:
                    tokentypes.append(self.tokenizer.convert_ids_to_tokens(inputids[row].tolist()))
                except Exception:
                    tokentypes.append([f"tok{i}" for i in range(id_cols)])

        # Housekeeping on every 50th call.
        self.cleanupcounter += 1
        if self.cleanupcounter % 50 == 0:
            self.cleanupcounter = 0
            self.cleanupmemory()
            self.cleanupthreads()

        output_keys = ("protoassignments", "protoprobs", "uncertainties", "spanpreds", "gates", "haugmented")
        alloutputs: Dict[str, List[Any]] = {key: [] for key in output_keys}

        for b in range(n_rows):
            wordmap = tokenwordmap[b] if tokenwordmap and len(tokenwordmap) > b else None
            if tokentypes and len(tokentypes) > b:
                sequence_tokens = tokentypes[b]
            else:
                sequence_tokens = [f"tok{i}" for i in range(n_cols)]
            sequence_out = self.processsequence(
                tokenembeddings[b],
                sequence_tokens,
                device,
                wordmap=wordmap,
                trainmode=trainmode,
            )
            for key, bucket in alloutputs.items():
                bucket.append(sequence_out[key])

        # Re-assemble per-token vectors into one tensor, padding or truncating each
        # sequence to the input length; any failure falls back to the raw input.
        try:
            padded_sequences: List[torch.Tensor] = []
            for b in range(n_rows):
                vectors = alloutputs["haugmented"][b]
                if len(vectors) > 0 and isinstance(vectors[0], torch.Tensor):
                    sequence_matrix = torch.stack(vectors, dim=0)
                    rows_present = sequence_matrix.size(0)
                    if rows_present < n_cols:
                        sequence_matrix = F.pad(sequence_matrix, (0, 0, 0, n_cols - rows_present), value=0)
                    elif rows_present > n_cols:
                        sequence_matrix = sequence_matrix[:n_cols]
                else:
                    sequence_matrix = torch.zeros(n_cols, self.embeddim, device=device)
                padded_sequences.append(sequence_matrix)
            alloutputs["haugmented"] = torch.stack(padded_sequences, dim=0)
        except Exception:
            alloutputs["haugmented"] = tokenembeddings

        # Collapse each sequence's assignments into one tensor per sequence.
        try:
            assignment_rows = []
            for row in alloutputs["protoassignments"]:
                try:
                    assignment_rows.append(
                        torch.stack([x if isinstance(x, torch.Tensor) else torch.tensor(x) for x in row], dim=0)
                    )
                except Exception:
                    assignment_rows.append(
                        torch.tensor(
                            [int(x) if not isinstance(x, torch.Tensor) else int(x.item()) for x in row],
                            dtype=torch.long,
                        )
                    )
            alloutputs["protoassignments"] = assignment_rows
        except Exception:
            pass

        return alloutputs

    def processsequence(
        self,
        tokenembeddings: torch.Tensor,
        tokentypes: List[Any],
        device: torch.device,
        wordmap: Optional[Dict[int, Optional[str]]] = None,
        trainmode: bool = True,
    ) -> Dict[str, List[Any]]:
        """Assign every token of one sequence to a sense prototype.

        For each position the token is canonicalised, filtered through
        ``shouldtracktoken``, appended to its per-token buffer, optionally handed
        to a background clustering thread, and then compared against a snapshot
        of the token's centroids. When a centroid is selected, the embedding is
        moved 10% of the way toward it.

        Args:
            tokenembeddings: Embeddings of one sequence; iterated along dim 0.
            tokentypes: Token strings (non-strings are ``str()``-ed; missing
                positions become ``tok{j}``).
            device: Device on which the distance softmax is evaluated.
            wordmap: Optional position -> word map passed to ``canonicaltokenkey``.
            trainmode: Accepted for interface compatibility; not referenced here.

        Returns:
            Dict of equal-length per-token lists: ``protoassignments`` (0-dim
            tensors, ``-1`` when unassigned), ``protoprobs`` (lists of floats),
            ``uncertainties``, ``spanpreds``, ``gates`` (floats) and
            ``haugmented`` (tensors).

        Changes relative to the previous implementation:
            * A prototype store is created only when the token has none. The old
              code created one whenever the *buffer* was missing, which replaced
              an already-populated store that had no buffer yet.
            * Distances are computed on a float32 copy (NumPy cannot convert
              bfloat16 tensors) and the chosen centroid is cast to the
              embedding's own dtype/device before blending, so ``haugmented``
              keeps the input dtype. float32 inputs behave exactly as before.
        """
        seqlen = int(tokenembeddings.size(0))
        outputs: Dict[str, List[Any]] = {
            "protoassignments": [],
            "protoprobs": [],
            "uncertainties": [],
            "spanpreds": [],
            "gates": [],
            "haugmented": [],
        }

        def _record(assignment, problist, uncertainty, spanpred, gateval, haug) -> None:
            # Append one token's results to every output list.
            outputs["protoassignments"].append(torch.tensor(assignment))
            outputs["protoprobs"].append(problist)
            outputs["uncertainties"].append(uncertainty)
            outputs["spanpreds"].append(spanpred)
            outputs["gates"].append(gateval)
            outputs["haugmented"].append(haug)

        for j in range(seqlen):
            rawtok = tokentypes[j] if j < len(tokentypes) else f"tok{j}"
            if not isinstance(rawtok, str):
                rawtok = str(rawtok) if rawtok is not None else f"tok{j}"

            tokenkey = self.canonicaltokenkey(rawtok, wordmap, j)
            hj = tokenembeddings[j]

            if not tokenkey:
                _record(-1, [], 0.0, 0.0, 0.0, hj)
                continue

            tokenkey = unicodedata.normalize("NFKC", tokenkey).strip().lower()

            if not self.shouldtracktoken(tokenkey):
                _record(-1, [], 0.0, 0.0, 0.0, hj)
                continue

            with self.bufferlock:
                if tokenkey not in self.buffers:
                    self.buffers[tokenkey] = deque(maxlen=self.buffersize)
                # Checked independently of the buffer so an existing store is never
                # replaced by an empty one.
                if tokenkey not in self.prototypestores:
                    self.prototypestores[tokenkey] = MemoryEfficientPrototypeStore(self.embeddim, self.maxprotos)

                try:
                    # Stored as float32 on CPU; the clustering step converts buffer
                    # entries with .numpy(), which rejects bfloat16.
                    self.buffers[tokenkey].append(hj.detach().clone().float().cpu())
                except Exception:
                    try:
                        self.buffers[tokenkey].append(hj.cpu())
                    except Exception:
                        pass

                bufferlen = len(self.buffers[tokenkey])

            # Kick off background clustering, rate-limited per token.
            try:
                if self.enabletrainingclustering and bufferlen >= max(self.nmin, 4):
                    now = time.time()
                    lastt = self.lastclustertime.get(tokenkey, 0.0)
                    if now - lastt > self.clustercooldownseconds:
                        self.lastclustertime[tokenkey] = now

                        # Default argument binds the current key; the loop variable
                        # will have moved on by the time the thread runs.
                        def bgcluster(tok: str = tokenkey) -> None:
                            try:
                                with self.clusteringlock:
                                    self.clusterbuffertoprototypeshierarchical(tok)
                            except Exception:
                                pass

                        th = threading.Thread(target=bgcluster, daemon=True)
                        th.start()
                        with self.threadlock:
                            self.activethreads.append(th)
            except Exception:
                pass

            store = self.prototypestores[tokenkey]

            # Copy the centroids under the clustering lock so a concurrent
            # re-clustering cannot swap them out mid-computation.
            centroidssnapshot: Optional[List[torch.Tensor]] = None
            with self.clusteringlock:
                try:
                    if hasattr(store, "centroids") and len(store.centroids) > 0:
                        centroidssnapshot = []
                        for c in store.centroids:
                            try:
                                if isinstance(c, torch.Tensor):
                                    centroidssnapshot.append(c.detach().clone().float().cpu())
                                else:
                                    centroidssnapshot.append(torch.from_numpy(np.asarray(c, dtype=np.float32)).cpu())
                            except Exception:
                                continue
                        if not centroidssnapshot:
                            centroidssnapshot = None
                except Exception:
                    centroidssnapshot = None

            assignment = -1
            problist: List[float] = []
            uncertainty = 0.0
            spanpred = 0.0
            gateval = 0.0
            haug = hj

            if centroidssnapshot and len(centroidssnapshot) >= 1:
                try:
                    # float32 copy: works for fp16/bf16 inputs and matches the
                    # float32 centroid snapshot.
                    hcpu = hj.detach().float().cpu().numpy()
                    try:
                        centsnp = np.stack([c.numpy() for c in centroidssnapshot], axis=0)
                    except Exception:
                        centsnp = np.stack([np.asarray(c, dtype=np.float32) for c in centroidssnapshot], axis=0)

                    distsnp = np.linalg.norm(centsnp - hcpu[None, :], axis=1)
                    if distsnp.size > 0:
                        mindist = float(distsnp.min())
                        minidx = int(np.argmin(distsnp))
                        maxdist = float(distsnp.max())

                        # Relative spread between nearest and farthest centroid.
                        spanrange = maxdist - mindist
                        ratio = spanrange / (maxdist + 1e-8) if maxdist > 0 else 0.0
                        if ratio > 0.25:
                            spanpred = float(min(1.0, max(0.0, 0.5 * (ratio - 0.25))))
                        else:
                            spanpred = float(min(1.0, max(0.0, ratio)))

                        try:
                            store.updaterollingstats(mindist)
                        except Exception:
                            pass

                        # Softmax over negative distances; uncertainty is the entropy
                        # normalised by log(number of centroids).
                        try:
                            disttensor = torch.from_numpy(distsnp).to(device)
                            probstensor = F.softmax(-disttensor, dim=0)
                            problist = probstensor.tolist()
                            entropy = -torch.sum(probstensor * torch.log(probstensor + 1e-10))
                            maxentropy = np.log(len(distsnp))
                            uncertainty = float(entropy.item() / maxentropy) if maxentropy > 0 else 0.0
                        except Exception:
                            exps = np.exp(-distsnp - np.max(-distsnp)) if distsnp.size > 0 else np.array([])
                            if exps.size > 0:
                                probs = exps / (exps.sum() + 1e-12)
                                problist = probs.tolist()
                                entropyval = -np.sum(probs * np.log(probs + 1e-10))
                                maxentropy = np.log(len(distsnp))
                                uncertainty = float(entropyval / maxentropy) if maxentropy > 0 else 0.0

                        try:
                            gateval = float(torch.sigmoid(self.gatew * torch.norm(hj) - self.gateb).item())
                        except Exception:
                            gateval = 0.5

                        if gateval > 0.3:
                            assignment = minidx

                        # Open a new prototype when the nearest centroid is farther
                        # than the store's adaptive threshold and there is room.
                        try:
                            if store.size() < self.maxprotos and mindist > store.getadaptivethreshold(DSCDNEWSENSELAMBDA):
                                store.addprototype(hj, time.time(), count=1)
                                assignment = store.size() - 1
                                centroidssnapshot.append(hj.detach().cpu())
                        except Exception:
                            pass

                        if 0 <= assignment < len(centroidssnapshot):
                            centroidt = centroidssnapshot[assignment]
                            try:
                                # Match the embedding's device AND dtype so the blend
                                # does not promote the result.
                                centroidt = centroidt.to(device=hj.device, dtype=hj.dtype)
                            except Exception:
                                pass
                            haug = hj + 0.1 * (centroidt - hj)
                except Exception as e:
                    if DEBUGDISCOVERY:
                        print(f"[DSCD] Assignment error for {tokenkey}: {str(e)[:200]}")

            _record(assignment, problist, uncertainty, spanpred, gateval, haug)

        return outputs

    def clusterbuffertoprototypeshierarchical(self, tokentype: str) -> bool:
        """Rebuild the prototype store of ``tokentype`` from its buffered embeddings.

        Steps:
          1. Bail out (``False``) when the token is not trackable, has no buffer,
             has fewer than ``self.nmin`` buffered items, yields fewer than two
             usable vectors, or all vectors have near-zero norm.
          2. If ``HASCLUSTERING`` is truthy, run average-linkage hierarchical
             clustering on Euclidean distances and cut the tree at
             ``max(self.dispersionthreshold, 0.5 * median pairwise distance)``.
             Clusters with at least ``self.nmin`` members become centroids; at
             most ``self.maxprotos`` (the largest) are kept.
          3. If that produced nothing and ``HASKMEANS`` is truthy, fall back to
             KMeans with a size-based guess for ``k`` (only run when ``k > 1``).

        When new centroids are found they REPLACE ``store.centroids``,
        ``store.counts``, ``store.creationtime`` and ``store.labels``.

        This method does not take ``self.clusteringlock`` itself; the background
        caller in ``processsequence`` acquires it before calling.

        Returns:
            ``True`` when the store holds at least one prototype afterwards,
            ``False`` on any early exit or unexpected error.
        """
        try:
            if not self.shouldtracktoken(tokentype):
                if DEBUGDISCOVERY:
                    print(f"[DSCD-CLUSTER] Skipping non-word token {tokentype}")
                return False

            # Copy the buffer under the lock so clustering works on a stable view.
            with self.bufferlock:
                if tokentype not in self.buffers:
                    return False
                bufsnapshot = [e.clone() if isinstance(e, torch.Tensor) else e for e in self.buffers[tokentype]]

            if len(bufsnapshot) < self.nmin:
                if DEBUGDISCOVERY:
                    print(f"[DSCD-CLUSTER] {tokentype}: buffer len {len(bufsnapshot)} < nmin {self.nmin}")
                return False

            # Convert every entry to a NumPy vector; entries that cannot be
            # converted are skipped silently.
            emblist: List[np.ndarray] = []
            for e in bufsnapshot:
                try:
                    if isinstance(e, torch.Tensor):
                        try:
                            emblist.append(e.numpy())
                        except Exception:
                            emblist.append(e.cpu().numpy())
                    else:
                        emblist.append(np.asarray(e, dtype=np.float32))
                except Exception:
                    continue

            if len(emblist) == 0:
                return False

            # Cap the clustering cost by sampling without replacement.
            # NOTE(review): uses the global NumPy RNG, so results depend on its state.
            if len(emblist) > self.maxclusteringpoints:
                idxs = np.random.choice(len(emblist), size=self.maxclusteringpoints, replace=False)
                newembeddings = np.stack([emblist[i] for i in idxs], axis=0)
            else:
                newembeddings = np.stack(emblist, axis=0)

            if newembeddings.shape[0] < 2:
                return False

            norms = np.linalg.norm(newembeddings, axis=1)
            if np.all(norms < 1e-6):
                if DEBUGDISCOVERY:
                    print(f"[DSCD-CLUSTER] {tokentype}: all zero vectors, skipping")
                return False

            store = self.prototypestores[tokentype]
            # Counts qualifying clusters; also gates the KMeans fallback below.
            protosadded = 0

            if HASCLUSTERING:
                try:
                    condensed = pdist(newembeddings, metric="euclidean")
                    if condensed.size == 0:
                        return False
                    Z = linkage(condensed, method="average")

                    # Cut distance adapts to the data: never below the configured
                    # dispersion threshold, otherwise half the median pairwise distance.
                    median_dist = float(np.median(condensed)) if condensed.size > 0 else 0.5
                    absolutethreshold = max(self.dispersionthreshold, median_dist * 0.5)

                    # fcluster labels start at 1; shift to 0-based ids.
                    clusters = fcluster(Z, t=absolutethreshold, criterion="distance") - 1
                    if clusters.size == 0:
                        return False

                    maxc = int(clusters.max()) if clusters.size > 0 else 0
                    newcentroids: List[torch.Tensor] = []
                    newcounts: List[int] = []
                    newtimes: List[float] = []

                    # Only clusters with at least nmin members become prototypes.
                    for cid in range(maxc + 1):
                        mask = clusters == cid
                        clustersize = int(mask.sum())
                        if clustersize >= self.nmin:
                            centroid = newembeddings[mask].mean(axis=0).astype(np.float32)
                            centroidtensor = torch.from_numpy(centroid)
                            newcentroids.append(centroidtensor)
                            newcounts.append(clustersize)
                            newtimes.append(time.time())
                            protosadded += 1

                    # Too many clusters: keep the maxprotos most populated ones.
                    if len(newcentroids) > self.maxprotos:
                        sortedindices = np.argsort(newcounts)[::-1][: self.maxprotos]
                        newcentroids = [newcentroids[i] for i in sortedindices]
                        newcounts = [newcounts[i] for i in sortedindices]
                        newtimes = [newtimes[i] for i in sortedindices]

                    if protosadded > 0:
                        # Replaces the previous prototypes wholesale.
                        # NOTE(review): store.labels holds the raw cluster ids of the
                        # (possibly subsampled) points; after the truncation above they
                        # need not line up with centroid positions.
                        store.centroids = newcentroids
                        store.counts = newcounts
                        store.creationtime = newtimes
                        store.labels = torch.tensor(clusters)
                        return store.size() > 0
                except Exception as e:
                    # Falls through to the KMeans attempt (if protosadded is still 0).
                    if DEBUGDISCOVERY:
                        print(f"[DSCD-CLUSTER] Hierarchical failed for {tokentype}: {type(e).__name__} {str(e)[:200]}")

            if protosadded == 0 and HASKMEANS:
                try:
                    embeddings = newembeddings
                    lenembeddings = int(embeddings.shape[0])
                    mink = 1
                    # Upper bound on k: each cluster should be able to hold nmin points.
                    maxk = min(self.maxprotos, max(1, lenembeddings // max(self.nmin, 1)))
                    if maxk < mink:
                        maxk = mink

                    # Heuristic k: ~sqrt(n)/2 for n >= 20, 2 for n >= 10, otherwise 1.
                    if lenembeddings >= 20:
                        kguess = min(maxk, max(2, int(np.sqrt(lenembeddings) / 2)))
                    elif lenembeddings >= 10:
                        kguess = min(maxk, 2)
                    else:
                        kguess = 1

                    kguess = max(mink, min(kguess, lenembeddings))
                    # With kguess == 1 nothing is clustered and the store is left as is.
                    if kguess > 1 and lenembeddings >= kguess:
                        km = KMeans(n_clusters=kguess, random_state=0, n_init=10).fit(embeddings)
                        labels = km.labels_
                        newcentroids = []
                        newcounts = []
                        newtimes = []
                        for c in range(kguess):
                            mask = labels == c
                            clustersize = int(mask.sum())
                            if clustersize >= self.nmin:
                                centroid = embeddings[mask].mean(axis=0).astype(np.float32)
                                newcentroids.append(torch.from_numpy(centroid))
                                newcounts.append(clustersize)
                                newtimes.append(time.time())
                                protosadded += 1
                        if len(newcentroids) > 0:
                            store.centroids = newcentroids
                            store.counts = newcounts
                            store.creationtime = newtimes
                            store.labels = torch.tensor(labels)
                    return store.size() > 0
                except Exception as e:
                    if DEBUGDISCOVERY:
                        print(f"[DSCD-CLUSTER] KMeans failed for {tokentype}: {type(e).__name__} {str(e)[:200]}")

            # Neither path replaced the store; report whether it already had prototypes.
            return store.size() > 0
        except Exception as e:
            if DEBUGDISCOVERY:
                print(f"[DSCD-ERROR] Clustering error for {tokentype}: {type(e).__name__} {str(e)[:200]}")
            return False

    def periodic_discovery_check(self, current_step: int, frequency: int) -> None:
        """Run ``discover_homographs`` once at least ``frequency`` steps have passed.

        The step of the last run is kept in ``self.lastperiodiccheck`` and is
        updated before discovery starts. Errors are reported only when
        ``DEBUGDISCOVERY`` is set and never propagate.
        """
        try:
            steps_elapsed = current_step - self.lastperiodiccheck
            if steps_elapsed < frequency:
                return None

            self.lastperiodiccheck = current_step
            if DEBUGDISCOVERY:
                print(f"\n[DSCD] Periodic discovery check at step {current_step}")
            self.discover_homographs()
        except Exception as err:
            if DEBUGDISCOVERY:
                print(f"[DSCD] Periodic check failed: {err}")

    def discover_homographs(self) -> Dict[str, Any]:
        """Scan all prototype stores and register newly multi-sense tokens.

        Under ``self.clusteringlock`` every store is tested with
        ``self.ismultisensestore``. For each hit, the key returned by
        ``normalizetokenkey`` is added to ``self.discoveredhomographs`` if it is
        non-empty and not yet present. The run is logged to
        ``self.discoveredlog``, ``self.discoverycount`` and
        ``self.discoverytimes``.

        Returns:
            Dict with ``discovered`` (new this run), ``candidates`` (stores
            inspected), ``total_homographs`` and ``timestamp``. On failure the
            counts are zero. The failure result now also carries ``timestamp``
            so both outcomes have the same keys; it was missing before.
        """
        try:
            if DEBUGDISCOVERY:
                print(f"\n[DSCD] Running discovery...")

            discovered_count = 0
            total_candidates = 0

            with self.clusteringlock:
                # list(...) so the dict can change while we iterate.
                for tokentype, store in list(self.prototypestores.items()):
                    total_candidates += 1

                    if not self.ismultisensestore(store):
                        continue

                    clean_token = normalizetokenkey(tokentype)
                    if clean_token and clean_token not in self.discoveredhomographs:
                        self.discoveredhomographs.add(clean_token)
                        discovered_count += 1

                        if DEBUGDISCOVERY:
                            print(f"[DSCD] Discovered: {clean_token} ({store.size()} senses)")

            discovery_result = {
                'discovered': discovered_count,
                'candidates': total_candidates,
                'total_homographs': len(self.discoveredhomographs),
                'timestamp': time.time()
            }

            self.discoveredlog.append(discovery_result)
            self.discoverycount += 1
            self.discoverytimes.append(time.time())

            if DEBUGDISCOVERY:
                print(f"[DSCD] Discovery: {discovered_count}/{total_candidates} homographs discovered")
                print(f"[DSCD] Total homographs: {len(self.discoveredhomographs)}")

            return discovery_result

        except Exception as e:
            if DEBUGDISCOVERY:
                print(f"[DSCD] discover_homographs failed: {e}")
            # Same keys as the success result so callers can index either one.
            return {'discovered': 0, 'candidates': 0, 'total_homographs': 0, 'timestamp': time.time()}

    def get_discovered_homographs(self) -> Set[str]:
        """Return a copy of the discovered-homograph collection.

        The copy is taken while holding ``self.clusteringlock``. If anything goes
        wrong an empty set is returned instead.
        """
        snapshot: Set[str] = set()
        try:
            with self.clusteringlock:
                snapshot = self.discoveredhomographs.copy()
        except Exception:
            snapshot = set()
        return snapshot

    def validate_prototypes(self) -> Dict[str, Any]:
        """Summarise the prototype stores.

        Counts, under ``self.clusteringlock``: all tokens, all prototypes, tokens
        with two or more prototypes ("multi-sense"), and the multi-sense tokens
        that also pass ``self.ismultisensestore`` ("strong").
        ``quality_score`` is strong / multi-sense, or 0.0 when there are no
        multi-sense tokens. On any error every field is zero.
        """
        try:
            n_tokens = 0
            n_prototypes = 0
            n_multi = 0
            n_strong = 0

            with self.clusteringlock:
                for _token, store in self.prototypestores.items():
                    n_tokens += 1
                    store_size = store.size()
                    n_prototypes += store_size

                    if store_size < 2:
                        continue
                    n_multi += 1
                    if self.ismultisensestore(store):
                        n_strong += 1

            if n_multi > 0:
                quality = n_strong / max(1, n_multi)
            else:
                quality = 0.0

            return {
                'quality_score': quality,
                'multi_sense_tokens': n_multi,
                'strong_multi_sense_tokens': n_strong,
                'total_prototypes': n_prototypes,
                'total_tokens': n_tokens,
            }
        except Exception:
            return {
                'quality_score': 0.0,
                'multi_sense_tokens': 0,
                'strong_multi_sense_tokens': 0,
                'total_prototypes': 0,
                'total_tokens': 0,
            }

    def printclusterssummary(self) -> None:
        """Print a table of the five most-populated tokens (if ``VERBOSELOGGING``).

        A token's count is the sum of ``store.counts`` when positive, otherwise
        the length of its buffer. Rows show count, number of prototypes, buffer
        length and the store's ``mu`` / ``tau`` attributes (0.0 when absent).
        Any error is swallowed.
        """
        try:
            rows: List[Tuple[str, int, int, float, float, int]] = []
            for token, store in self.prototypestores.items():
                try:
                    sampled = sum(getattr(store, "counts", []) or [])
                except Exception:
                    sampled = 0
                buffered = len(self.buffers.get(token, [])) if token in self.buffers else 0
                shown_count = sampled if sampled > 0 else buffered
                n_protos = store.size()
                mu_value = getattr(store, "mu", 0.0)
                tau_value = getattr(store, "tau", 0.0)
                rows.append((token, shown_count, n_protos, mu_value, tau_value, buffered))

            leaders = sorted(rows, key=lambda entry: entry[1], reverse=True)[:5]

            if VERBOSELOGGING:
                rule = "-" * 100
                print("[CLUSTER] Top 5 clusters")
                print(rule)
                print(f"{'Rank':<6} {'Token':<18} {'Count':<12} {'Protos':<8} {'BufLen':<8} {'mu':<15} {'tau':<15}")
                print(rule)
                for rank, (tok, cnt, prot, mu, tau, buflen) in enumerate(leaders, 1):
                    tokstr = str(tok)[:18]
                    print(f"{rank:<6} {tokstr:<18} {cnt:<12} {prot:<8} {buflen:<8} {mu:<15.6f} {tau:<15.6f}")
                print(rule)
        except Exception:
            pass

    def getexplanations(self, thresholdspan: float = 0.3) -> List[Dict[str, Any]]:
        """List every token that currently has two or more prototypes.

        Args:
            thresholdspan: Accepted for interface compatibility; not referenced.

        Returns:
            One ``{"token": str, "protos": int}`` dict per qualifying token, in
            the iteration order of ``self.prototypestores``.
        """
        return [
            {"token": str(token), "protos": store.size()}
            for token, store in self.prototypestores.items()
            if store.size() >= 2
        ]

# Cell 3 status banner: echoes the DSCD settings this cell was loaded with.
print(
    "\n".join(
        (
            "=" * 80,
            "Cell 3 DSCD Ready: FAST INFERENCE MODE ADDED",
            "=" * 80,
            "Configuration:",
            f" - Max Protos: {DSCDMAXPROTOS}",
            f" - Buffer Size: {DSCDBUFFERSIZE}",
            f" - N Min: {DSCDNMIN} (enforced minimum: 3)",
            f" - Dispersion Threshold: {DSCDDISPERSIONTHRESHOLD} (enforced maximum: 0.08)",
            " - Fast inference mode: ENABLED",
            "=" * 80,
        )
    )
)


[CELL3] Loaded reference list for evaluation: 65 words
Cell 3 DSCD Ready: FAST INFERENCE MODE ADDED
Configuration:
 - Max Protos: 8
 - Buffer Size: 50
 - N Min: 3 (enforced minimum: 3)
 - Dispersion Threshold: 0.08 (enforced maximum: 0.08)
 - Fast inference mode: ENABLED


In [7]:
# ==============================================================================
# CELL 4: ASBN MODULE - FIXED & ROBUST
# ==============================================================================

import traceback
from typing import Any, List, Tuple, Optional, Dict
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

def resolve_notebook_setting(names, default, cast):
    """Read a notebook-level setting that may be spelled in more than one way.

    ``names`` is tried in order; the first name present in this module's globals
    is converted with ``cast`` and returned. ``default`` is returned when none of
    the names exists or when the conversion fails.

    Why several names: this cell looks for ``VERBOSE_LOGGING`` /
    ``DEBUG_DISCOVERY`` while the DSCD cell of the same notebook reads
    ``VERBOSELOGGING`` / ``DEBUGDISCOVERY``. Accepting both spellings stops this
    cell from silently falling back to defaults when only the un-underscored
    form is defined. NOTE(review): confirm which spelling the configuration cell
    actually defines.
    """
    scope = globals()
    for name in names:
        if name in scope:
            try:
                return cast(scope[name])
            except Exception:
                return default
    return default


_MAX_LENGTH = resolve_notebook_setting(("MAX_LENGTH", "MAXLENGTH"), 48, int)
_ENABLE_ASBN_TRAINING = resolve_notebook_setting(("ENABLE_ASBN_TRAINING", "ENABLEASBNTRAINING"), True, bool)
_VERBOSE_LOGGING = resolve_notebook_setting(("VERBOSE_LOGGING", "VERBOSELOGGING"), False, bool)
_DEBUG_DISCOVERY = resolve_notebook_setting(("DEBUG_DISCOVERY", "DEBUGDISCOVERY"), False, bool)
_DEBUG_TIMING = resolve_notebook_setting(("DEBUG_TIMING", "DEBUGTIMING"), False, bool)
_SOURCE_LANGUAGE = resolve_notebook_setting(("SOURCE_LANGUAGE", "SOURCELANGUAGE"), "bn", str)

# The GRL schedule is all-or-nothing, as before: if start, end or schedule cannot
# be resolved, every GRL value (including the step count) uses its default.
_GRL_ALPHA_START = resolve_notebook_setting(("GRL_ALPHA_START", "GRLALPHASTART"), None, float)
_GRL_ALPHA_END = resolve_notebook_setting(("GRL_ALPHA_END", "GRLALPHAEND"), None, float)
_GRL_ALPHA_SCHEDULE = resolve_notebook_setting(("GRL_ALPHA_SCHEDULE", "GRLALPHASCHEDULE"), None, str)
if _GRL_ALPHA_START is None or _GRL_ALPHA_END is None or _GRL_ALPHA_SCHEDULE is None:
    _GRL_ALPHA_START = 0.0
    _GRL_ALPHA_END = 1.0
    _GRL_ALPHA_SCHEDULE = "linear"
    _GRL_ALPHA_STEPS = 10000
else:
    _GRL_ALPHA_STEPS = resolve_notebook_setting(("GRL_ALPHA_STEPS", "GRLALPHASTEPS"), 10000, int)

# Feature probes for optional helpers that earlier cells may have defined.
_has_is_valid_token = "is_valid_token" in globals()
_has_get_tokenizer_special_tokens = "get_tokenizer_special_tokens" in globals()
_has_should_track_token = "should_track_token" in globals()

class GradientReversalFunction(torch.autograd.Function):
    """Autograd function: identity going forward, gradient scaled by ``-alpha`` going back."""

    @staticmethod
    def forward(ctx, x, alpha):
        # Keep alpha as a plain float on the context for the backward pass.
        ctx.alpha = float(alpha)
        passthrough = x.view_as(x)
        return passthrough

    @staticmethod
    def backward(ctx, grad_output):
        # Second return value is the (non-existent) gradient w.r.t. alpha.
        flipped = grad_output * (-ctx.alpha)
        return flipped, None

def gradient_reversal(x, alpha: float = 1.0):
    """Apply ``GradientReversalFunction`` to ``x`` with reversal strength ``alpha``."""
    reverse = GradientReversalFunction.apply
    return reverse(x, alpha)

class LightweightDiscriminator(nn.Module):
    """Small MLP head: ``input_dim`` -> 64 -> 2 logits, with ReLU and 10% dropout."""

    def __init__(self, input_dim: int):
        super().__init__()
        hidden_width = 64
        # Layer positions inside the Sequential define the state-dict keys
        # (classifier.0.*, classifier.3.*), so the order must stay as is.
        layers = [
            nn.Linear(input_dim, hidden_width),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_width, 2),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.classifier(x)
        return logits

class DomainDiscriminator(nn.Module):
    """Three-layer MLP head: ``input_dim`` -> 512 -> 256 -> 2 logits, ReLU + 30% dropout."""

    def __init__(self, input_dim: int):
        super().__init__()
        first_width, second_width = 512, 256
        # Layer positions inside the Sequential define the state-dict keys
        # (classifier.0.*, classifier.3.*, classifier.6.*); keep the order.
        layers = [
            nn.Linear(input_dim, first_width),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(first_width, second_width),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(second_width, 2),
        ]
        self.classifier = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        logits = self.classifier(x)
        return logits

class MemoryEfficientASBNModule(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        tokenizer=None,
        language: str = "bn",
        freq_threshold: float = 0.7,
        uncertainty_threshold: float = 0.3,
        gate_threshold: float = 0.5,
        warmup_steps: int = 1000,
        encoder_grl_scale: float = 0.5,
    ):
        """Build the per-domain BatchNorm layers, the four discriminators and bookkeeping.

        Args:
            embed_dim: Feature width; used for both BatchNorm layers and as the
                discriminator input width (the freq/ctx heads get ``embed_dim + 2``).
            tokenizer: Optional tokenizer, stored and used to collect special tokens.
            language: Language code stored on the module.
            freq_threshold / uncertainty_threshold / gate_threshold: Cut-offs
                stored as floats for later label construction.
            warmup_steps: Stored as int; compared against ``current_step`` elsewhere.
            encoder_grl_scale: Stored as float; scales the encoder loss elsewhere.
        """
        super().__init__()

        # Plain configuration attributes.
        self.language = language
        self.tokenizer = tokenizer
        self.embed_dim = int(embed_dim)

        # Sub-modules are registered in a fixed order (BN source/target, then
        # domain/freq/ctx/xl heads) so parameter order and seeded init are stable.
        self.bn_source = nn.BatchNorm1d(self.embed_dim, track_running_stats=True)
        self.bn_target = nn.BatchNorm1d(self.embed_dim, track_running_stats=True)

        aux_width = self.embed_dim + 2
        self.d_domain = DomainDiscriminator(self.embed_dim)
        self.d_freq = LightweightDiscriminator(aux_width)
        self.d_ctx = LightweightDiscriminator(aux_width)
        self.d_xl = LightweightDiscriminator(self.embed_dim)

        self.freq_threshold = float(freq_threshold)
        self.uncertainty_threshold = float(uncertainty_threshold)
        self.gate_threshold = float(gate_threshold)
        self.warmup_steps = int(warmup_steps)
        self.current_step = 0

        # Per-loss base weights and the upper clamp used when scaling them.
        self.lambda_base = {"freq": 1.0, "ctx": 0.5, "xl": 0.8, "domain": 1.0}
        self.lambda_max = 2.0
        self.encoder_grl_scale = float(encoder_grl_scale)

        # Running sums; averaged on read and cleared every `stats_reset_interval` updates.
        self.stats_reset_interval = 1000
        self.stats = {
            "domain_loss": 0.0,
            "domain_accuracy": 0.0,
            "source_accuracy": 0.0,
            "target_accuracy": 0.0,
            "asbn_loss": 0.0,
            "num_updates": 0,
        }

        # Special tokens: prefer the notebook helper when present, otherwise
        # read `all_special_tokens`; any failure yields an empty set.
        resolved_specials = set()
        if tokenizer is not None:
            try:
                if _has_get_tokenizer_special_tokens:
                    resolved_specials = get_tokenizer_special_tokens(tokenizer)
                else:
                    resolved_specials = set(getattr(tokenizer, "all_special_tokens", []))
            except Exception:
                resolved_specials = set()
        self.special_tokens = resolved_specials

        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            summary_lines = (
                "[ASBN-INIT] Initialized MemoryEfficientASBNModule:",
                f"  - embed_dim: {self.embed_dim}",
                f"  - warmup_steps: {self.warmup_steps}",
                f"  - encoder_grl_scale: {self.encoder_grl_scale}",
                f"  - GRL_ALPHA_STEPS: {_GRL_ALPHA_STEPS}",
                f"  - thresholds: freq={self.freq_threshold}, uncert={self.uncertainty_threshold}, gate={self.gate_threshold}",
            )
            for line in summary_lines:
                print(line)

    def get_grl_alpha(self, global_step: Optional[int] = None) -> float:
        if global_step is None:
            global_step = self.current_step
        step = max(0, int(global_step))
        if _GRL_ALPHA_SCHEDULE == "linear":
            progress = min(1.0, float(step) / float(_GRL_ALPHA_STEPS))
            alpha = _GRL_ALPHA_START + progress * (_GRL_ALPHA_END - _GRL_ALPHA_START)
        elif _GRL_ALPHA_SCHEDULE == "exponential":
            progress = min(1.0, float(step) / float(_GRL_ALPHA_STEPS))
            ratio = _GRL_ALPHA_END / max(1e-8, _GRL_ALPHA_START if _GRL_ALPHA_START > 0 else 1e-3)
            alpha = _GRL_ALPHA_START * (ratio ** progress)
        else:
            alpha = _GRL_ALPHA_END
        return float(alpha)

    def get_asbn_stats(self) -> Dict[str, float]:
        return self.get_detailed_stats()

    def get_detailed_stats(self) -> Dict[str, float]:
        if self.stats["num_updates"] > 0:
            n = float(self.stats["num_updates"])
            return {
                "domain_loss": self.stats["domain_loss"] / n,
                "domain_accuracy": self.stats["domain_accuracy"] / n,
                "source_accuracy": self.stats["source_accuracy"] / n,
                "target_accuracy": self.stats["target_accuracy"] / n,
                "asbn_loss": self.stats["asbn_loss"] / n,
                "num_updates": self.stats["num_updates"],
            }
        return {
            "domain_loss": 0.0,
            "domain_accuracy": 0.0,
            "source_accuracy": 0.0,
            "target_accuracy": 0.0,
            "asbn_loss": 0.0,
            "num_updates": 0,
        }

    def reset_stats(self) -> None:
        self.stats = {
            "domain_loss": 0.0,
            "domain_accuracy": 0.0,
            "source_accuracy": 0.0,
            "target_accuracy": 0.0,
            "asbn_loss": 0.0,
            "num_updates": 0,
        }

    def critic_parameters(self):
        return (
            list(self.d_domain.parameters())
            + list(self.d_freq.parameters())
            + list(self.d_ctx.parameters())
            + list(self.d_xl.parameters())
        )

    def _ensure_discriminators_on_device(self, device: torch.device) -> None:
        try:
            for mod in (
                self.d_domain,
                self.d_freq,
                self.d_ctx,
                self.d_xl,
                self.bn_source,
                self.bn_target,
            ):
                try:
                    p = next(mod.parameters())
                    if p.device != device:
                        mod.to(device)
                except StopIteration:
                    mod.to(device)
                except Exception:
                    pass
        except Exception:
            if _VERBOSE_LOGGING:
                try:
                    print("[ASBN] Device migration failed:", traceback.format_exc().splitlines()[-1])
                except Exception:
                    print("[ASBN] Device migration failed")

    def _parse_proto_probs_matrix(self, proto_probs: Any, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
        pmax = torch.full((batch_size, seq_len), 0.5, dtype=torch.float32, device=device)
        try:
            if proto_probs is None:
                return pmax
            if isinstance(proto_probs, torch.Tensor):
                if proto_probs.dim() == 3:
                    B, T, K = proto_probs.shape
                    p = proto_probs.detach().to(device)
                    b_max = min(batch_size, B)
                    t_max = min(seq_len, T)
                    pmax[:b_max, :t_max] = p[:b_max, :t_max].max(dim=2)[0]
                    return pmax
                if proto_probs.dim() == 2:
                    p = proto_probs.detach().to(device)
                    if batch_size >= 1:
                        t_max = min(seq_len, p.size(0))
                        pmax[0, :t_max] = p[:t_max].max(dim=1)[0]
                        return pmax
            if isinstance(proto_probs, (list, tuple)):
                if len(proto_probs) == batch_size:
                    for b in range(batch_size):
                        row = proto_probs[b]
                        if isinstance(row, torch.Tensor) and row.dim() == 2:
                            t_max = min(seq_len, row.size(0))
                            pmax[b, :t_max] = row[:t_max].max(dim=1)[0].to(device)
                        elif isinstance(row, (list, tuple)):
                            for t in range(min(seq_len, len(row))):
                                try:
                                    val = row[t]
                                    if isinstance(val, torch.Tensor):
                                        pmax[b, t] = float(val.max().item())
                                    else:
                                        arr = np.asarray(val, dtype=np.float32)
                                        pmax[b, t] = float(np.max(arr))
                                except Exception:
                                    pmax[b, t] = 0.5
                else:
                    if batch_size == 1:
                        row = proto_probs
                        for t in range(min(seq_len, len(row))):
                            try:
                                val = row[t]
                                if isinstance(val, torch.Tensor):
                                    pmax[0, t] = float(val.max().item())
                                else:
                                    pmax[0, t] = float(np.max(np.asarray(val, dtype=np.float32)))
                            except Exception:
                                pmax[0, t] = 0.5
        except Exception as e:
            if _VERBOSE_LOGGING:
                print(f"[ASBN] parse_proto_probs exception: {e}")
        return pmax

    def _parse_scalar_matrix(self, mat: Any, batch_size: int, seq_len: int, device: torch.device,
                            default: float = 0.0) -> torch.Tensor:
        out = torch.full((batch_size, seq_len), float(default), dtype=torch.float32, device=device)
        try:
            if mat is None:
                return out
            if isinstance(mat, torch.Tensor):
                if mat.dim() == 3:
                    B, T, D = mat.shape
                    b_max = min(batch_size, B)
                    t_max = min(seq_len, T)
                    if D == 1:
                        out[:b_max, :t_max] = mat[:b_max, :t_max, 0].to(device)
                    else:
                        out[:b_max, :t_max] = mat[:b_max, :t_max, 0].to(device)
                elif mat.dim() == 2:
                    if mat.size(0) == batch_size:
                        t_max = min(seq_len, mat.size(1))
                        out[:, :t_max] = mat[:, :t_max].to(device)
                    elif batch_size == 1:
                        t_max = min(seq_len, mat.size(0))
                        out[0, :t_max] = mat[:t_max].to(device)
                elif mat.dim() == 1 and batch_size == 1:
                    t_max = min(seq_len, mat.size(0))
                    out[0, :t_max] = mat[:t_max].to(device)
            elif isinstance(mat, (list, tuple)):
                if len(mat) == batch_size:
                    for b in range(batch_size):
                        row = mat[b]
                        if isinstance(row, torch.Tensor) and row.dim() >= 1:
                            t_max = min(seq_len, row.size(0))
                            for t in range(t_max):
                                out[b, t] = float(row[t].item())
                        elif isinstance(row, (list, tuple, np.ndarray)):
                            t_max = min(seq_len, len(row))
                            for t in range(t_max):
                                try:
                                    v = row[t]
                                    out[b, t] = (float(v.item()) if isinstance(v, torch.Tensor) else float(v))
                                except Exception:
                                    out[b, t] = float(default)
                elif batch_size == 1:
                    row = mat
                    t_max = min(seq_len, len(row))
                    for t in range(t_max):
                        try:
                            v = row[t]
                            out[0, t] = (float(v.item()) if isinstance(v, torch.Tensor) else float(v))
                        except Exception:
                            out[0, t] = float(default)
        except Exception:
            if _VERBOSE_LOGGING:
                try:
                    print("[ASBN] parse_scalar_matrix exception:", traceback.format_exc().splitlines()[-1])
                except Exception:
                    pass
        return out

    def compute_lambda_scaled_tensor(self, pmax: torch.Tensor, uncertainty: torch.Tensor,
                                    gate: torch.Tensor, lambda_type: str) -> torch.Tensor:
        base = float(self.lambda_base.get(lambda_type, 0.2))
        
        lam = base * (1.0 - pmax + 1e-3) * (uncertainty + 1e-3) * (gate + 1e-3)
        
        lam = torch.clamp(lam, min=1e-4, max=float(self.lambda_max))
        
        lam = torch.where(torch.isfinite(lam), lam, torch.zeros_like(lam))
        return lam

    def forward(self, h: torch.Tensor, domain_labels: Optional[torch.Tensor] = None, 
                device: Optional[torch.device] = None) -> Tuple[torch.Tensor, torch.Tensor]:
        if not isinstance(h, torch.Tensor) or h.dim() != 3:
            dev = h.device if isinstance(h, torch.Tensor) else torch.device("cpu")
            return h, torch.tensor(0.0, device=dev)
        
        B, T, H = h.size()
        if device is None:
            device = h.device
        
        if domain_labels is not None:
            if domain_labels.dim() == 0:
                domain_labels = domain_labels.unsqueeze(0)
            if domain_labels.size(0) == 1 and B > 1:
                domain_labels = domain_labels.expand(B)
            elif domain_labels.size(0) != B:
                if _DEBUG_DISCOVERY:
                    print(f"[ASBN] Domain label size mismatch: {domain_labels.size(0)} vs batch {B}, using first label")
                domain_labels = domain_labels[0].unsqueeze(0).expand(B)
        
        if domain_labels is not None and self.training:
            try:
                self._ensure_discriminators_on_device(device)
                h_flat = h.view(B * T, H)
                domain_expanded = domain_labels.unsqueeze(1).expand(B, T).reshape(-1)
                source_mask = domain_expanded == 0
                target_mask = domain_expanded == 1
                
                h_normalized = h_flat.clone()
                
                source_count = int(source_mask.sum().item())
                target_count = int(target_mask.sum().item())
                
                if source_count >= 2:
                    h_normalized[source_mask] = self.bn_source(h_flat[source_mask])
                elif source_count == 1:
                    h_normalized[source_mask] = h_flat[source_mask]
                
                if target_count >= 2:
                    h_normalized[target_mask] = self.bn_target(h_flat[target_mask])
                elif target_count == 1:
                    h_normalized[target_mask] = h_flat[target_mask]
                
                h_out = h_normalized.view(B, T, H)
                if _DEBUG_DISCOVERY:
                    print(f"[ASBN] Applied BN: {source_count} source, {target_count} target tokens")
                return h_out, torch.tensor(0.0, device=device)
            except Exception as e:
                if _VERBOSE_LOGGING:
                    print(f"[ASBN] BN failed: {e}")
                return h, torch.tensor(0.0, device=device)
        else:
            dev = h.device if isinstance(h, torch.Tensor) else torch.device("cpu")
            return h, torch.tensor(0.0, device=dev)

    def forward_with_grl_simplified(self, h: torch.Tensor, proto_probs: Any, uncertainties: Any, gates: Any,
                                   token_word_map: Optional[List[Dict[int, str]]] = None,
                                   domain_labels: Optional[torch.Tensor] = None,
                                   global_step: Optional[int] = None) \
            -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute the gradient-reversal adversarial losses for one batch.

        Args:
            h: Token features; must be a 3-D tensor, unpacked as ``(B, T, H)``.
            proto_probs: Prototype probabilities, reduced to per-token maxima by
                ``_parse_proto_probs_matrix``.
            uncertainties: Per-token scalars parsed by ``_parse_scalar_matrix``
                (default 0.1 where missing).
            gates: Per-token scalars parsed by ``_parse_scalar_matrix``
                (default 0.0 where missing).
            token_word_map: Optional per-sequence ``{token_index: text}`` dicts
                used to drop tokens that the notebook's token filters reject.
            domain_labels: Optional per-sequence domain ids for the domain head.
            global_step: When given, overwrites ``self.current_step``.

        Returns:
            ``(encoder_loss, mean_weighted, domain_loss, domain_accuracy)``.
            All four are the same zero scalar during warm-up, when the module is
            not training, when ``_ENABLE_ASBN_TRAINING`` is falsy, when ``h`` is
            not a 3-D tensor, or when no token survives filtering.

        Side effects: updates ``self.current_step``, puts the four
        discriminators in train mode, and accumulates ``self.stats`` (which is
        reset once ``num_updates`` reaches ``self.stats_reset_interval``).

        NOTE(review): ``_ENABLE_ASBN_TRAINING`` and ``_MAX_LENGTH`` are read
        from module globals defined outside this method.
        """
        # Adopt the caller's step so the warm-up check and alpha schedule follow it.
        if global_step is not None:
            self.current_step = int(global_step)
        dev = h.device if isinstance(h, torch.Tensor) else torch.device("cpu")
        # Early exits: each returns the same zero scalar four times.
        if self.current_step < self.warmup_steps:
            if _DEBUG_DISCOVERY and self.current_step % 100 == 0:
                print(f"[ASBN] Warmup: {self.current_step}/{self.warmup_steps}")
            zero = torch.tensor(0.0, device=dev)
            return zero, zero, zero, zero
        if not self.training or not _ENABLE_ASBN_TRAINING:
            zero = torch.tensor(0.0, device=dev)
            return zero, zero, zero, zero
        if not isinstance(h, torch.Tensor) or h.dim() != 3:
            zero = torch.tensor(0.0, device=dev)
            return zero, zero, zero, zero
        device = h.device
        self._ensure_discriminators_on_device(device)
        self.d_domain.train()
        self.d_freq.train()
        self.d_ctx.train()
        self.d_xl.train()
        B, T, H = h.size()
        # Normalise the three side inputs to (B, T) float matrices.
        pmax_mat = self._parse_proto_probs_matrix(proto_probs, B, T, device)
        U_mat = self._parse_scalar_matrix(uncertainties, B, T, device, default=0.1)
        G_mat = self._parse_scalar_matrix(gates, B, T, device, default=0.0)
        # Every token starts selected; the filter below can only deselect.
        sel_mask = torch.ones((B, T), dtype=torch.bool, device=device)
        # batch_indices[b, t] == b, used later to look up each token's domain label.
        batch_indices = torch.arange(B, device=device).unsqueeze(1).expand(B, T)
        if token_word_map and isinstance(token_word_map, (list, tuple)) and len(token_word_map) > 0:
            try:
                for b in range(min(B, len(token_word_map))):
                    wm = token_word_map[b]
                    if not isinstance(wm, dict):
                        continue
                    for t in range(T):
                        if t in wm:
                            try:
                                token_str = wm[t]
                                # Tokens stay selected unless an available filter rejects
                                # them; `should_track_token` takes precedence over
                                # `is_valid_token`.
                                tracked = True
                                if _has_should_track_token:
                                    tracked = bool(globals()["should_track_token"](token_str))
                                elif _has_is_valid_token:
                                    tracked = bool(is_valid_token(token_str, self.special_tokens, self.tokenizer, language=self.language))
                                if not tracked:
                                    sel_mask[b, t] = False
                            except Exception:
                                # A failing filter leaves the token selected.
                                pass
            except Exception:
                if _VERBOSE_LOGGING:
                    try:
                        print("[ASBN] Token filtering failed:", traceback.format_exc().splitlines()[-1])
                    except Exception:
                        pass
        # Flat indices of the selected tokens and the batch row each one belongs to.
        sel_idx = sel_mask.view(-1).nonzero(as_tuple=False).squeeze(1)
        batch_idx = batch_indices.view(-1)[sel_idx]
        if sel_idx.numel() == 0:
            if _DEBUG_DISCOVERY:
                print("[ASBN] No valid tokens after filtering")
            zero = torch.tensor(0.0, device=device)
            return zero, zero, zero, zero
        # NOTE(review): .view() raises for a non-contiguous `h` and this call is
        # not inside a try block — confirm callers always pass contiguous features.
        h_flat = h.view(B * T, H)
        sel_emb = h_flat[sel_idx]
        pmax_flat = pmax_mat.view(-1)[sel_idx]
        U_flat = U_mat.view(-1)[sel_idx]
        G_flat = G_mat.view(-1)[sel_idx]
        # Two extra features per head: (pmax, uncertainty) for freq and
        # (gate, T / _MAX_LENGTH) for ctx — matching the `embed_dim + 2` input width.
        seq_len_feature = float(T) / max(int(_MAX_LENGTH), 1)
        freq_feature = torch.stack([pmax_flat, U_flat], dim=1).to(device)
        ctx_feature = torch.stack([G_flat, torch.full_like(G_flat, seq_len_feature)], dim=1).to(device)
        xl_input = sel_emb
        grl_alpha = self.get_grl_alpha(global_step)
        freq_input = torch.cat([sel_emb, freq_feature], dim=1)
        ctx_input = torch.cat([sel_emb, ctx_feature], dim=1)
        # Gradient reversal sits between the features and each discriminator.
        xl_input_grl = gradient_reversal(xl_input, alpha=grl_alpha)
        freq_input_grl = gradient_reversal(freq_input, alpha=grl_alpha)
        ctx_input_grl = gradient_reversal(ctx_input, alpha=grl_alpha)
        freq_logits = self.d_freq(freq_input_grl)
        ctx_logits = self.d_ctx(ctx_input_grl)
        xl_logits = self.d_xl(xl_input_grl)
        # Binary targets obtained by thresholding the parsed side inputs.
        freq_label = (pmax_flat > self.freq_threshold).long().to(device)
        ctx_label = (U_flat < self.uncertainty_threshold).long().to(device)
        xl_label = (G_flat > self.gate_threshold).long().to(device)
        # Per-token losses (no reduction) so each token can get its own weight.
        loss_freq = F.cross_entropy(freq_logits, freq_label, reduction="none")
        loss_ctx = F.cross_entropy(ctx_logits, ctx_label, reduction="none")
        loss_xl = F.cross_entropy(xl_logits, xl_label, reduction="none")
        lam_freq = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "freq")
        lam_ctx = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "ctx")
        lam_xl = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "xl")
        weighted = lam_freq * loss_freq + lam_ctx * loss_ctx + lam_xl * loss_xl
        mean_weighted = torch.mean(weighted)
        # Domain head: stays at zero when labels are absent or the block below fails.
        domain_loss = torch.tensor(0.0, device=device)
        domain_accuracy = torch.tensor(0.0, device=device)
        if domain_labels is not None:
            try:
                # Same label broadcasting rules as in forward().
                if domain_labels.dim() == 0:
                    domain_labels = domain_labels.unsqueeze(0)
                if domain_labels.size(0) == 1 and B > 1:
                    domain_labels = domain_labels.expand(B)
                elif domain_labels.size(0) != B:
                    domain_labels = domain_labels[0].unsqueeze(0).expand(B)
                # One domain label per selected token.
                domain_flat = domain_labels[batch_idx]
                domain_input = gradient_reversal(sel_emb, alpha=grl_alpha)
                domain_logits = self.d_domain(domain_input)
                domain_loss = F.cross_entropy(domain_logits, domain_flat)
                with torch.no_grad():
                    domain_preds = torch.argmax(domain_logits, dim=1)
                    domain_accuracy = (domain_preds == domain_flat).float().mean()
                    source_mask = domain_flat == 0
                    target_mask = domain_flat == 1
                    # NOTE(review): these sums grow only when the domain is present in
                    # the batch, while `num_updates` grows on every call.
                    if source_mask.any():
                        source_acc = ((domain_preds[source_mask] == domain_flat[source_mask]).float().mean())
                        self.stats["source_accuracy"] += float(source_acc.item())
                    if target_mask.any():
                        target_acc = ((domain_preds[target_mask] == domain_flat[target_mask]).float().mean())
                        self.stats["target_accuracy"] += float(target_acc.item())
            except Exception as e:
                if _VERBOSE_LOGGING:
                    print(f"[ASBN] Domain loss failed: {e}")
        encoder_loss = self.encoder_grl_scale * (mean_weighted + domain_loss)
        # Bookkeeping is best-effort: any failure here is ignored.
        try:
            with torch.no_grad():
                self.stats["domain_loss"] += float(domain_loss.item())
                self.stats["domain_accuracy"] += float(domain_accuracy.item())
                self.stats["asbn_loss"] += float(encoder_loss.item())
                self.stats["num_updates"] += 1
                if self.stats["num_updates"] >= self.stats_reset_interval:
                    if _DEBUG_DISCOVERY:
                        stats = self.get_detailed_stats()
                        print(f"\n[ASBN-STATS] After {stats['num_updates']} updates:")
                        print(f"  Domain loss: {stats['domain_loss']:.4f}")
                        print(f"  Domain acc: {stats['domain_accuracy']:.2%}")
                        print(f"  Source acc: {stats['source_accuracy']:.2%}")
                        print(f"  Target acc: {stats['target_accuracy']:.2%}")
                        print(f"  ASBN loss: {stats['asbn_loss']:.4f}")
                    self.reset_stats()
        except Exception:
            pass
        if _DEBUG_DISCOVERY and self.current_step % 500 == 0:
            print(f"\n[ASBN-STEP-{self.current_step}]")
            print(f"  GRL alpha: {grl_alpha:.3f}")
            print(f"  Encoder loss: {encoder_loss.item():.4f}")
            print(f"  Domain loss: {domain_loss.item():.4f}")
            print(f"  Domain acc: {domain_accuracy.item():.2%}")
        return encoder_loss, mean_weighted, domain_loss, domain_accuracy

    def state_dict(self, *args, **kwargs):
        state = super().state_dict(*args, **kwargs)
        state['current_step'] = self.current_step
        state['stats'] = self.stats.copy()
        return state

    def load_state_dict(self, state_dict, strict=True):
        self.current_step = state_dict.pop('current_step', 0)
        self.stats = state_dict.pop('stats', {
            "domain_loss": 0.0,
            "domain_accuracy": 0.0,
            "source_accuracy": 0.0,
            "target_accuracy": 0.0,
            "asbn_loss": 0.0,
            "num_updates": 0,
        })
        return super().load_state_dict(state_dict, strict=strict)

    def test_asbn(self, batch_size: int = 2, seq_len: int = 10) -> bool:
        """Run a smoke test of the module on random data and report the outcome.

        Exercises ``forward``, ``forward_with_grl_simplified``, ``get_grl_alpha``,
        the statistics accessors and ``state_dict``. Progress is printed.

        The module is left as it was found: training mode, ``current_step``,
        ``stats`` and all buffers (e.g. BatchNorm running statistics) are
        snapshotted up front and restored afterwards, so the check can be run
        on a module that is in the middle of training.

        Args:
            batch_size: Number of random sequences to generate.
            seq_len: Length of each random sequence.

        Returns:
            ``True`` when every check passes, ``False`` otherwise (the failure
            and its traceback are printed).
        """
        import traceback  # local import so the failure path cannot hit a missing name

        print("\n" + "=" * 60)
        print("[ASBN-TEST] Testing ASBN module")
        print("=" * 60)

        # Snapshot everything the checks below mutate.
        was_training = self.training
        saved_step = self.current_step
        saved_stats = dict(self.stats)
        saved_buffers = {name: buf.detach().clone() for name, buf in self.named_buffers()}

        try:
            try:
                device = next(self.parameters()).device
            except StopIteration:
                device = torch.device("cpu")
            h = torch.randn(batch_size, seq_len, self.embed_dim, device=device)
            domain_labels = torch.randint(0, 2, (batch_size,), device=device)

            # Switch to train mode first: forward() only normalises while training.
            self.train()
            h_out, _ = self.forward(h, domain_labels, device=device)
            assert h_out.shape == h.shape, "Forward output shape mismatch"
            print("  forward() passed")

            proto_probs = torch.rand(batch_size, seq_len, 3, device=device)
            uncertainties = torch.rand(batch_size, seq_len, device=device)
            gates = torch.rand(batch_size, seq_len, device=device)
            # Step just past warm-up so the adversarial path is not short-circuited.
            self.current_step = self.warmup_steps + 1
            enc_loss, adv_loss, dom_loss, dom_acc = self.forward_with_grl_simplified(
                h,
                proto_probs,
                uncertainties,
                gates,
                domain_labels=domain_labels,
                global_step=self.current_step,
            )
            assert enc_loss.item() >= 0.0, "Encoder loss negative"
            assert 0.0 <= dom_acc.item() <= 1.0, "Domain accuracy out of range"
            print("  forward_with_grl_simplified() passed")

            alpha = self.get_grl_alpha(self.current_step)
            assert 0.0 <= alpha <= max(_GRL_ALPHA_START, _GRL_ALPHA_END) * 1.1, "GRL alpha out of range"
            print(f"  get_grl_alpha() passed (alpha={alpha:.3f})")

            stats = self.get_detailed_stats()
            assert "domain_loss" in stats, "Missing domain_loss in stats"
            print("  Statistics tracking passed")

            state = self.state_dict()
            assert 'current_step' in state, "Missing current_step in state_dict"
            print("  state_dict() passed")

            print("\nAll ASBN tests passed")
            print("=" * 60 + "\n")
            return True
        except Exception as e:
            print(f"\nASBN test failed: {e}")
            traceback.print_exc()
            print("=" * 60 + "\n")
            return False
        finally:
            # Undo every side effect of the checks above.
            self.current_step = saved_step
            self.stats = saved_stats
            with torch.no_grad():
                for name, buf in self.named_buffers():
                    if name in saved_buffers:
                        buf.copy_(saved_buffers[name])
            self.train(was_training)

# Banner printed once when this cell finishes executing.
print("\n" + "=" * 80)
print("Cell 4: ASBN Ready (dynamic GRL, DSCD-aware) - ALL FIXES APPLIED")
print("=" * 80)



Cell 4: ASBN Ready (dynamic GRL, DSCD-aware) - ALL FIXES APPLIED


In [8]:
# ==============================================================================
# CELL 5: TRG MODULE (TRANSLATION RATIONALE GENERATION) - COMPLETE FIXES
# ==============================================================================
from typing import List, Dict, Tuple, Optional, Set, Any
from collections import deque, defaultdict
import traceback
import numpy as np
import torch
import torch.nn as nn
import threading
import time

# ------------------------------------------------------------------------------
# TRG settings. Each block re-reads a notebook-level global, coerces it to the
# expected type, and falls back to the literal default when the name is missing
# or the conversion fails. Unlike the ASBN cell, these assignments rebind the
# public (un-prefixed) names.
# ------------------------------------------------------------------------------
try:
    TRG_EVIDENCE_K = int(TRG_EVIDENCE_K)
except (NameError, ValueError, TypeError):
    TRG_EVIDENCE_K = 3

try:
    TRG_GEN_EMBED = int(TRG_GEN_EMBED)
except (NameError, ValueError, TypeError):
    TRG_GEN_EMBED = 64

try:
    MAX_SILVER_BUFFER = int(MAX_SILVER_BUFFER)
except (NameError, ValueError, TypeError):
    MAX_SILVER_BUFFER = 50

try:
    VERBOSE_LOGGING = bool(VERBOSE_LOGGING)
except NameError:
    VERBOSE_LOGGING = False

try:
    DEBUG_DISCOVERY = bool(DEBUG_DISCOVERY)
except NameError:
    DEBUG_DISCOVERY = False

try:
    DEBUG_TIMING = bool(DEBUG_TIMING)
except NameError:
    DEBUG_TIMING = False

try:
    ENABLE_TRG_INFERENCE = bool(ENABLE_TRG_INFERENCE)
except NameError:
    ENABLE_TRG_INFERENCE = True

try:
    SOURCE_LANGUAGE = str(SOURCE_LANGUAGE)
except (NameError, TypeError):
    SOURCE_LANGUAGE = "bn"

# The TRG uncertainty threshold is derived from TAU_LOW; it is read here
# before TAU_LOW itself is normalised further down, and both share the same
# 0.15 default.
try:
    TRG_UNCERTAINTY_THRESHOLD = float(TAU_LOW)
except (NameError, ValueError, TypeError):
    TRG_UNCERTAINTY_THRESHOLD = 0.15

try:
    TRG_SPAN_THRESHOLD = float(SPAN_THRESHOLD)
except (NameError, ValueError, TypeError):
    TRG_SPAN_THRESHOLD = 0.20

try:
    TAU_HIGH = float(TAU_HIGH)
except (NameError, ValueError, TypeError):
    TAU_HIGH = 0.85

try:
    TAU_LOW = float(TAU_LOW)
except (NameError, ValueError, TypeError):
    TAU_LOW = 0.15

try:
    TAU_ACCEPT = float(TAU_ACCEPT)
except (NameError, ValueError, TypeError):
    TAU_ACCEPT = 0.80

try:
    TRG_TEMPERATURE = float(TRG_TEMPERATURE)
except (NameError, ValueError, TypeError):
    TRG_TEMPERATURE = 1.0

try:
    MAX_EXPLANATIONS_PER_SENTENCE = int(MAX_EXPLANATIONS_PER_SENTENCE) if "MAX_EXPLANATIONS_PER_SENTENCE" in globals() else 10
except Exception:
    MAX_EXPLANATIONS_PER_SENTENCE = 10

# Feature flags: record whether optional helper functions exist in the
# notebook's global namespace at the time this cell runs.
_has_is_valid_token = "is_valid_token" in globals()
_has_get_tokenizer_special_tokens = "get_tokenizer_special_tokens" in globals()
_has_get_cached_special_tokens = "get_cached_special_tokens" in globals()

# Individual punctuation characters (the string is split into a set of chars).
TRG_PUNCT_SET = set(".,!?-;:|")

def fallback_is_valid_token(token: str, special_tokens: set, tokenizer=None, language: str = "bn") -> bool:
    """Decide whether a token is substantive enough to be analysed.

    A token is valid when, after stripping whitespace and sub-word markers
    (``▁``, ``##``, ``Ġ``), it is not a special token, has at least two
    characters and contains at least one alphabetic character. Pure
    punctuation and pure digits are therefore rejected as well.

    Args:
        token: Token text; non-strings are converted with ``str()``.
        special_tokens: Collection of special-token strings; ``None`` or an
            empty collection disables that check.
        tokenizer: Unused; accepted for signature compatibility.
        language: Unused; accepted for signature compatibility.

    Returns:
        ``True`` if the token passes every check, otherwise ``False``.
    """
    if token is None:
        return False
    if not isinstance(token, str):
        try:
            token = str(token)
        except Exception:
            return False

    token = token.strip()
    if not token:
        return False
    # Guarded so a missing collection does not raise on the membership test.
    if special_tokens and token in special_tokens:
        return False

    clean = token.replace("▁", "").replace("##", "").replace("Ġ", "").replace(" ", "").strip()
    if len(clean) < 2:
        return False

    # Requiring one alphabetic character already excludes all-punctuation and
    # all-digit strings, so no separate checks are needed for those.
    return any(c.isalpha() for c in clean)

def is_word_start(raw_token: str, token_word_map: Optional[dict], idx: int) -> bool:
    """Heuristically decide whether the token at ``idx`` begins a word.

    Checks, in order:
        1. ``token_word_map`` has a non-blank string for ``idx``: word start.
        2. The raw token begins with a word-boundary marker (space, ``Ġ``,
           ``▁``): word start.
        3. The raw token begins with ``##`` (a continuation piece by
           construction): not a word start.
        4. Otherwise, only when no map was supplied at all
           (``token_word_map is None``): a token with at least two characters
           after marker stripping and at least one alphabetic character
           counts as a word start.

    Args:
        raw_token: Token text as produced by the tokenizer; non-strings yield ``False``.
        token_word_map: Optional ``{token_index: word}`` mapping for the sentence.
        idx: Index of the token in the sentence.

    Returns:
        ``True`` for a word start, ``False`` otherwise (including on any error).
    """
    if not isinstance(raw_token, str):
        return False
    try:
        if isinstance(token_word_map, dict):
            mapped_word = token_word_map.get(idx)
            if isinstance(mapped_word, str) and mapped_word.strip():
                return True

        if raw_token.startswith((" ", "Ġ", "▁")):
            return True

        # Without this, "##ing" would be stripped to "ing" below and reported
        # as a word start in the no-map fallback.
        if raw_token.startswith("##"):
            return False

        clean = raw_token.replace("▁", "").replace("##", "").replace("Ġ", "").replace(" ", "").strip()
        if len(clean) < 2:
            return False
        # Punctuation-only tokens have no alphabetic character, so this one
        # test also covers them.
        return token_word_map is None and any(c.isalpha() for c in clean)
    except Exception:
        return False

class ComprehensiveTRGExplanationTemplate:
    """Render a human-readable rationale for a sense decision from an evidence dict.

    Three confidence bands select the wording:
        * ``confidence >= TAU_ACCEPT``: "high_confidence"
        * ``confidence <= TRG_UNCERTAINTY_THRESHOLD``: "low_confidence"
        * anything in between: "medium_confidence"
    """

    def __init__(self):
        self.explanation_templates = {
            "high_confidence": "Chose sense '{sense}' with high confidence ({confidence:.1%}) based on evidence: {evidence}. Pattern matches learned data. {alternatives_text}",
            "medium_confidence": "Selected sense '{sense}' with moderate confidence ({confidence:.1%}). Evidence: {evidence}. Some uncertainty. {alternatives_text}",
            "low_confidence": "Uncertain: chose sense '{sense}' ({confidence:.1%}). Evidence: {evidence}. {alternatives_text} Review recommended.",
            "fallback": "Token '{token}' analyzed. Context: {evidence}.",
        }

    @staticmethod
    def _strip_markers(text) -> str:
        """Return ``str(text)`` without the ``▁`` and ``##`` sub-word markers."""
        return str(text).replace("▁", "").replace("##", "")

    def generate_explanation(self, evidence: Dict) -> str:
        """Build the explanation sentence for one token.

        Args:
            evidence: Dict with optional keys ``token``, ``chosen_sense``
                (``(name, confidence)``), ``evidence_tokens`` (iterable of
                tokens) and ``alternatives`` (list of ``(name, confidence)``).

        Returns:
            The formatted explanation, or ``""`` when ``evidence`` is empty or
            not a dict.
        """
        if not evidence or not isinstance(evidence, dict):
            return ""

        token = self._strip_markers(evidence.get("token", "unknown"))

        # Chosen sense: malformed entries fall back to ("unknown", 0.5); a
        # non-numeric confidence keeps the name and uses 0.5.
        sense_name, confidence = "unknown", 0.5
        sense_info = evidence.get("chosen_sense", ("unknown", 0.5))
        if isinstance(sense_info, (tuple, list)) and len(sense_info) >= 2:
            sense_name = str(sense_info[0])
            try:
                confidence = float(sense_info[1])
            except (TypeError, ValueError):
                confidence = 0.5

        # Evidence tokens: tolerate None and non-sliceable iterables.
        raw_tokens = evidence.get("evidence_tokens", [])
        if raw_tokens is None:
            raw_tokens = []
        try:
            top_tokens = list(raw_tokens)[:TRG_EVIDENCE_K]
        except TypeError:
            top_tokens = []
        evidence_str = ", ".join(self._strip_markers(tok) for tok in top_tokens) or "limited context"

        # At most two alternatives are mentioned; unparsable ones are skipped.
        alternatives = evidence.get("alternatives", [])
        alternatives_text = ""
        if isinstance(alternatives, list) and len(alternatives) > 0:
            alt_parts = []
            for alt in alternatives[:2]:
                if isinstance(alt, (tuple, list)) and len(alt) >= 2:
                    try:
                        alt_name, alt_conf = str(alt[0]), float(alt[1])
                    except (TypeError, ValueError):
                        continue
                    alt_parts.append(f"'{alt_name}' ({alt_conf:.1%})")
            if alt_parts:
                alternatives_text = f"Alternatives: {', '.join(alt_parts)}."

        # Band selection. Confidence at or below the uncertainty threshold is
        # the LOW band; the range between the thresholds is the MEDIUM band.
        if confidence >= TAU_ACCEPT:
            template_key = "high_confidence"
        elif confidence <= TRG_UNCERTAINTY_THRESHOLD:
            template_key = "low_confidence"
        else:
            template_key = "medium_confidence"

        template = self.explanation_templates.get(template_key, self.explanation_templates["fallback"])

        try:
            return template.format(
                sense=sense_name,
                confidence=confidence,
                evidence=evidence_str,
                alternatives_text=alternatives_text,
                token=token,
            )
        except Exception:
            return f"Token {token} - {sense_name} ({confidence:.1%})."

class MemoryEfficientTRGExtractor:
    def __init__(self, tokenizer=None, language: str = "bn", dscd_module=None):
        self.tokenizer = tokenizer
        self.language = language
        self.dscd_module = dscd_module
        self.span_clamp_warnings = 0
        self.last_warning_time = 0.0

        if tokenizer is not None:
            try:
                if _has_get_tokenizer_special_tokens:
                    self.special_tokens = get_tokenizer_special_tokens(tokenizer)
                elif _has_get_cached_special_tokens:
                    self.special_tokens = get_cached_special_tokens(tokenizer)
                else:
                    self.special_tokens = set(tokenizer.all_special_tokens)
            except Exception:
                self.special_tokens = set()
        else:
            self.special_tokens = set()

    def extract_evidence_from_target(
        self,
        token_idx: int,
        span_start: int,
        span_end: int,
        tgt_preds: torch.Tensor,
    ) -> Optional[List[str]]:
        """Stringify the predictions in ``[span_start, span_end)`` except ``token_idx``.

        Args:
            token_idx: Position to exclude; must lie inside the span.
            span_start: Inclusive span start (non-negative).
            span_end: Exclusive span end (at most the sequence length).
            tgt_preds: A list or a tensor indexed along its first axis.

        Returns:
            The stringified neighbours, or ``None`` when validation fails or
            nothing is left after excluding ``token_idx``. Tensor entries that
            cannot be read as a scalar are rendered as ``"token_<i>"``.
        """
        if not (isinstance(token_idx, int) and token_idx >= 0):
            return None
        if not (isinstance(span_start, int) and isinstance(span_end, int)):
            return None
        if span_start < 0 or not isinstance(tgt_preds, (torch.Tensor, list)):
            return None

        preds_are_list = isinstance(tgt_preds, list)
        seq_len = len(tgt_preds) if preds_are_list else int(tgt_preds.size(0))

        # The span must be non-empty, fit in the sequence and contain token_idx.
        if not (span_start < span_end <= seq_len):
            return None
        if not (span_start <= token_idx < span_end):
            return None

        try:
            collected: List[str] = []
            for pos in range(span_start, span_end):
                if pos == token_idx:
                    continue
                if preds_are_list:
                    collected.append(str(tgt_preds[pos]))
                    continue
                try:
                    collected.append(str(int(tgt_preds[pos].item())))
                except Exception:
                    collected.append(f"token_{pos}")
            return collected or None
        except Exception:
            return None

    def _get_key(self, dscd_outputs: Dict, *names: str):
        if not isinstance(dscd_outputs, dict):
            return None
        for k in names:
            if k in dscd_outputs and dscd_outputs[k] is not None:
                return dscd_outputs[k]
        return None

    def extract_evidence_efficiently(
        self,
        token_idx: int,
        tokens: List[str],
        dscd_outputs: Dict,
        token_word_map: Optional[dict] = None,
        decoder_attention: Optional[torch.Tensor] = None,
    ) -> Dict:
        """Build the evidence record used to explain the token at ``token_idx``.

        Args:
            token_idx: Index into ``tokens`` of the token being explained.
            tokens: Token strings for the sentence.
            dscd_outputs: Dict handed to the ``safe_extract_*`` helpers.
            token_word_map: Optional mapping from token index to a word
                string; when present it supplies the reported token text and
                is used to skip pieces of the same word.
            decoder_attention: Optional attention tensor (1-D to 4-D). When
                given, the highest-weighted positions become the evidence.

        Returns:
            A dict with keys ``token``, ``token_idx``, ``evidence_tokens``,
            ``chosen_sense``, ``alternatives``, ``uncertainty``, ``gate`` and
            ``span``. Invalid inputs, an invalid token, or any exception raised
            while assembling the record yield ``create_fallback_evidence``.
        """
        if not isinstance(tokens, list):
            return self.create_fallback_evidence(token_idx, [])

        if not isinstance(token_idx, int):
            return self.create_fallback_evidence(0, tokens)

        if token_idx < 0 or token_idx >= len(tokens):
            # Clamp the index into range before building the fallback record.
            return self.create_fallback_evidence(max(0, min(token_idx, len(tokens) - 1)), tokens)

        raw_token = tokens[token_idx]

        # Validate the mapped word when one exists, otherwise the raw token.
        # If the primary validator raises, the fallback validator is used.
        if token_word_map and token_idx in token_word_map and token_word_map[token_idx]:
            canonical_word = str(token_word_map[token_idx]).strip()
            if _has_is_valid_token:
                try:
                    is_valid = is_valid_token(canonical_word, self.special_tokens, self.tokenizer, language=self.language)
                except Exception:
                    is_valid = fallback_is_valid_token(canonical_word, self.special_tokens, self.tokenizer, self.language)
            else:
                is_valid = fallback_is_valid_token(canonical_word, self.special_tokens, self.tokenizer, self.language)
        else:
            if _has_is_valid_token:
                try:
                    is_valid = is_valid_token(raw_token, self.special_tokens, self.tokenizer, language=self.language)
                except Exception:
                    is_valid = fallback_is_valid_token(raw_token, self.special_tokens, self.tokenizer, self.language)
            else:
                is_valid = fallback_is_valid_token(raw_token, self.special_tokens, self.tokenizer, self.language)

        if not is_valid:
            return self.create_fallback_evidence(token_idx, tokens)

        try:
            proto_probs = self.safe_extract_proto_probs(token_idx, dscd_outputs)
            uncertainty = self.safe_extract_uncertainty(token_idx, dscd_outputs)
            gate = self.safe_extract_gate(token_idx, dscd_outputs)
            span = self.safe_extract_span(token_idx, dscd_outputs)

            evidence_tokens: Optional[List[str]] = None
            if decoder_attention is not None and isinstance(decoder_attention, torch.Tensor):
                try:
                    # Reduce the attention tensor to a 1-D weight vector `vec`.
                    # NOTE(review): the axis layout of decoder_attention is not
                    # established in this block; the branches below appear to
                    # treat leading axes as ones to average away - confirm.
                    if decoder_attention.dim() == 4:
                        if decoder_attention.size(0) == 1 and decoder_attention.size(1) == 1:
                            attn_avg = decoder_attention.mean(dim=(0, 1))
                        elif decoder_attention.size(0) == 1:
                            attn_avg = decoder_attention.mean(dim=1).squeeze(0)
                        else:
                            # Averaging only dim 0 leaves a 3-D tensor, so the
                            # check below falls through to the flattened vector.
                            # NOTE(review): flattened indices are later compared
                            # with token positions - confirm this is intended.
                            attn_avg = decoder_attention.mean(dim=0)

                        if attn_avg.dim() == 2 and token_idx < attn_avg.size(0):
                            vec = attn_avg[token_idx]
                        else:
                            vec = attn_avg.reshape(-1)

                    elif decoder_attention.dim() == 3:
                        attn_avg = decoder_attention.mean(dim=0)
                        if attn_avg.dim() == 2 and token_idx < attn_avg.size(0):
                            vec = attn_avg[token_idx]
                        else:
                            vec = attn_avg.reshape(-1)

                    elif decoder_attention.dim() == 2:
                        if token_idx < decoder_attention.size(0):
                            vec = decoder_attention[token_idx]
                        else:
                            vec = decoder_attention.reshape(-1)

                    elif decoder_attention.dim() == 1:
                        vec = decoder_attention
                    else:
                        vec = None

                    if vec is not None and vec.numel() > 0:
                        # Take up to five highest-weighted positions.
                        k = min(5, int(vec.size(0)))
                        topk_indices = torch.topk(vec, k).indices.detach().cpu().numpy()
                        evidence_tokens = []
                        for i in topk_indices:
                            ii = int(i)
                            # Exclude the token itself and, when a word map is
                            # available, any position mapped to the same word.
                            is_same_word = False
                            if token_word_map and token_idx in token_word_map:
                                if ii in token_word_map and token_word_map[ii] == token_word_map[token_idx]:
                                    is_same_word = True
                            elif ii == token_idx:
                                is_same_word = True

                            if ii < len(tokens) and not is_same_word:
                                evidence_tokens.append(tokens[ii])
                except Exception:
                    # Any attention failure switches to the context window.
                    evidence_tokens = None

            # Only None triggers the context-window fallback; an empty list
            # produced by the attention path is kept as-is.
            if evidence_tokens is None:
                evidence_tokens = self.extract_context_window(token_idx, tokens, token_word_map)

            # Order-preserving de-duplication, then cap at TRG_EVIDENCE_K.
            seen: Dict[str, bool] = {}
            dedup_evidence: List[str] = []
            for t in evidence_tokens:
                if t not in seen:
                    seen[t] = True
                    dedup_evidence.append(t)
            evidence_tokens = dedup_evidence[:TRG_EVIDENCE_K]

            top_senses = self.compute_sense_alternatives_fast(proto_probs, temperature=TRG_TEMPERATURE)
            chosen_sense = top_senses[0] if len(top_senses) > 0 else ("unknown", 0.5)
            alternatives = top_senses[1:3] if len(top_senses) > 1 else []

            # Report the mapped word when it is a non-blank string.
            if token_word_map and token_idx in token_word_map and isinstance(token_word_map[token_idx], str) and token_word_map[token_idx].strip():
                token_value = token_word_map[token_idx]
            else:
                token_value = raw_token

            return {
                "token": token_value,
                "token_idx": token_idx,
                "evidence_tokens": evidence_tokens,
                "chosen_sense": chosen_sense,
                "alternatives": alternatives,
                "uncertainty": float(uncertainty),
                "gate": float(gate),
                "span": float(span),
            }
        except Exception as e:
            if VERBOSE_LOGGING or DEBUG_DISCOVERY:
                print(f"[TRG] Evidence error {token_idx}: {e}")
            return self.create_fallback_evidence(token_idx, tokens)

    def extract_context_window(self, token_idx: int, tokens: List[str], token_word_map: Optional[dict]) -> List[str]:
        """Collect valid neighbouring words within two positions of ``token_idx``.

        Args:
            token_idx: Index of the token being explained (never returned).
            tokens: Token strings for the sentence.
            token_word_map: Optional mapping from token index to word string.
                It may be partial; indices missing from it fall back to the
                cleaned token text.

        Returns:
            Evidence strings in sentence order: the stripped mapped word when
            the map holds a non-blank string for the position, otherwise the
            token with subword markers removed.
        """
        context_window = 2
        start_idx = max(0, token_idx - context_window)
        end_idx = min(len(tokens), token_idx + context_window + 1)
        evidence_tokens: List[str] = []

        current_word_id = token_word_map.get(token_idx) if token_word_map else None

        for i in range(start_idx, end_idx):
            if i == token_idx:
                continue

            # Single guarded lookup. Indexing the map directly raised KeyError
            # for positions missing from a partial map, which made the caller
            # discard the whole evidence record.
            mapped_word = token_word_map.get(i) if token_word_map else None

            # Skip other pieces belonging to the word being explained.
            if current_word_id and mapped_word == current_word_id:
                continue

            r_tok = tokens[i]
            clean_token = str(r_tok).replace("▁", "").replace("##", "").replace("Ġ", "").strip()

            if not is_word_start(r_tok, token_word_map, i):
                # Without any word map, longer alphabetic continuation pieces
                # are still accepted as evidence.
                keep_piece = (
                    token_word_map is None
                    and len(clean_token) > 2
                    and any(c.isalpha() for c in clean_token)
                )
                if not keep_piece:
                    continue

            if _has_is_valid_token:
                try:
                    ok = is_valid_token(r_tok, self.special_tokens, self.tokenizer, language=self.language)
                except Exception:
                    ok = fallback_is_valid_token(r_tok, self.special_tokens, self.tokenizer, self.language)
            else:
                ok = fallback_is_valid_token(r_tok, self.special_tokens, self.tokenizer, self.language)

            if ok and len(clean_token) > 0:
                if isinstance(mapped_word, str) and mapped_word.strip():
                    evidence_tokens.append(mapped_word.strip())
                else:
                    evidence_tokens.append(clean_token)
        return evidence_tokens

    def safe_extract_proto_probs(self, token_idx: int, dscd_outputs: Dict) -> torch.Tensor:
        try:
            if not isinstance(dscd_outputs, dict):
                return torch.tensor([1.0], dtype=torch.float32)

            pp_all = self._get_key(dscd_outputs, "proto_probs", "protoprobs")
            if pp_all is None:
                return torch.tensor([1.0], dtype=torch.float32)

            if isinstance(pp_all, list):
                if len(pp_all) == 0:
                    return torch.tensor([1.0], dtype=torch.float32)

                row0 = pp_all[0]

                if isinstance(row0, list):
                    if token_idx < len(row0):
                        v = row0[token_idx]
                        if isinstance(v, torch.Tensor):
                            vv = v.detach().cpu().float().flatten()
                            return vv if vv.numel() > 0 else torch.tensor([1.0], dtype=torch.float32)
                        if isinstance(v, (list, tuple, np.ndarray)):
                            vv = torch.as_tensor(v, dtype=torch.float32).flatten()
                            return vv if vv.numel() > 0 else torch.tensor([1.0], dtype=torch.float32)
                        return torch.tensor([float(v)], dtype=torch.float32)

                if isinstance(row0, torch.Tensor):
                    t = row0
                    if t.ndim == 2 and token_idx < t.shape[0]:
                        vv = t[token_idx].detach().cpu().float().flatten()
                        return vv if vv.numel() > 0 else torch.tensor([1.0], dtype=torch.float32)
                    vv = t.detach().cpu().float().flatten()
                    return vv if vv.numel() > 0 else torch.tensor([1.0], dtype=torch.float32)

                if isinstance(pp_all, list) and token_idx < len(pp_all):
                    v = pp_all[token_idx]
                    if isinstance(v, torch.Tensor):
                        vv = v.detach().cpu().float().flatten()
                        return vv if vv.numel() > 0 else torch.tensor([1.0], dtype=torch.float32)

            if isinstance(pp_all, torch.Tensor):
                t = pp_all
                if t.ndim == 3 and t.size(0) > 0 and token_idx < t.size(1):
                    vv = t[0, token_idx].detach().cpu().float().flatten()
                    return vv if vv.numel() > 0 else torch.tensor([1.0], dtype=torch.float32)
                if t.ndim == 2 and token_idx < t.size(0):
                    vv = t[token_idx].detach().cpu().float().flatten()
                    return vv if vv.numel() > 0 else torch.tensor([1.0], dtype=torch.float32)
                vv = t.detach().cpu().float().flatten()
                return vv if vv.numel() > 0 else torch.tensor([1.0], dtype=torch.float32)

        except Exception:
            pass
        return torch.tensor([1.0], dtype=torch.float32)

    def safe_extract_uncertainty(self, token_idx: int, dscd_outputs: Dict) -> float:
        try:
            if not isinstance(dscd_outputs, dict):
                return 0.5
            U_all = self._get_key(dscd_outputs, "uncertainties")
            if U_all and len(U_all) > 0:
                row = U_all[0]
                if isinstance(row, torch.Tensor):
                    if row.ndim == 2 and token_idx < row.shape[0]:
                        return float(row[token_idx].item())
                    if row.ndim == 1 and token_idx < row.shape[0]:
                        return float(row[token_idx].item())
                if isinstance(row, (list, tuple)) and token_idx < len(row):
                    val = row[token_idx]
                    return float(val.item()) if isinstance(val, torch.Tensor) else float(val)
        except Exception:
            pass
        return 0.5

    def safe_extract_gate(self, token_idx: int, dscd_outputs: Dict) -> float:
        try:
            if not isinstance(dscd_outputs, dict):
                return 0.0
            G_all = self._get_key(dscd_outputs, "gates")
            if G_all and len(G_all) > 0:
                row = G_all[0]
                if isinstance(row, torch.Tensor):
                    if row.ndim == 2 and token_idx < row.shape[0]:
                        return float(row[token_idx].item())
                    if row.ndim == 1 and token_idx < row.shape[0]:
                        return float(row[token_idx].item())
                if isinstance(row, (list, tuple)) and token_idx < len(row):
                    val = row[token_idx]
                    return float(val.item()) if isinstance(val, torch.Tensor) else float(val)
        except Exception:
            pass
        return 0.0

    def safe_extract_span(self, token_idx: int, dscd_outputs: Dict) -> float:
        try:
            if not isinstance(dscd_outputs, dict):
                return 0.0
            S_all = self._get_key(dscd_outputs, "span_preds", "spanpreds")
            if S_all and len(S_all) > 0:
                row = S_all[0]
                if isinstance(row, torch.Tensor):
                    if row.ndim == 2 and token_idx < row.shape[0]:
                        span_val = float(row[token_idx].item())
                    elif row.ndim == 1 and token_idx < row.shape[0]:
                        span_val = float(row[token_idx].item())
                    else:
                        return 0.0
                elif isinstance(row, (list, tuple)) and token_idx < len(row):
                    val = row[token_idx]
                    span_val = float(val.item()) if isinstance(val, torch.Tensor) else float(val)
                else:
                    return 0.0

                if span_val < 0.0:
                    current_time = time.time()
                    if self.span_clamp_warnings < 10 or current_time - self.last_warning_time > 60.0:
                        if VERBOSE_LOGGING or DEBUG_DISCOVERY:
                            print(f"[TRG] Negative span {span_val:.3f} -> 0.0")
                        self.span_clamp_warnings += 1
                        self.last_warning_time = current_time
                    return 0.0
                if span_val > 1.0:
                    current_time = time.time()
                    if self.span_clamp_warnings < 10 or current_time - self.last_warning_time > 60.0:
                        if VERBOSE_LOGGING or DEBUG_DISCOVERY:
                            print(f"[TRG] Span {span_val:.3f} > 1.0 -> 1.0")
                        self.span_clamp_warnings += 1
                        self.last_warning_time = current_time
                    return 1.0
                return span_val
        except Exception:
            pass
        return 0.0

    def compute_span(self, sense_probs) -> float:
        """Return the margin between the two largest sense probabilities.

        Args:
            sense_probs: A dict (its values are used), tensor, ndarray or any
                iterable of numbers. Tensors and arrays are flattened first.

        Returns:
            ``top1 - top2`` clamped to ``[0.0, 1.0]``; ``0.0`` when fewer than
            two values are available or the input cannot be interpreted.
        """
        try:
            probs = list(sense_probs.values()) if isinstance(sense_probs, dict) else sense_probs

            if isinstance(probs, torch.Tensor):
                # detach() first: converting a tensor that tracks gradients
                # raised and silently produced a 0.0 margin.
                probs = probs.detach().cpu().flatten().tolist()
            elif isinstance(probs, np.ndarray):
                probs = probs.flatten().tolist()
            else:
                probs = list(probs)

            if len(probs) < 2:
                return 0.0

            sorted_probs = sorted((float(p) for p in probs), reverse=True)
            span = sorted_probs[0] - sorted_probs[1]
            return max(0.0, min(1.0, span))
        except Exception:
            return 0.0

    def compute_sense_alternatives_fast(self, proto_probs: torch.Tensor, temperature: float = 1.0) -> List[Tuple[str, float]]:
        """Rank senses by probability and return up to the top three.

        Probabilities are flattened and clamped to ``[1e-10, 1.0]``. When
        ``temperature`` differs from 1.0 and there is more than one value, they
        are re-normalised with ``softmax(log(p) / temperature)``.

        Returns:
            ``("sense_<index>", probability)`` pairs in descending order, a
            single ``("sense_0", p)`` pair for one-element input, or
            ``[("unknown", 0.5)]`` on any error.
        """
        try:
            as_tensor = (
                proto_probs
                if isinstance(proto_probs, torch.Tensor)
                else torch.as_tensor(proto_probs, dtype=torch.float32)
            )
            flat = torch.clamp(as_tensor.flatten().float(), min=1e-10, max=1.0)
            has_several = flat.numel() > 1

            if has_several and temperature != 1.0:
                flat = torch.softmax(torch.log(flat) / float(temperature), dim=0)

            if not has_several:
                return [("sense_0", float(flat[0].item()))]

            ranked_probs, ranked_ids = torch.sort(flat, descending=True)
            keep = min(3, int(ranked_ids.numel()))
            return [
                (f"sense_{int(sense_id)}", float(prob))
                for sense_id, prob in zip(ranked_ids[:keep].tolist(), ranked_probs[:keep].tolist())
            ]
        except Exception:
            return [("unknown", 0.5)]

    def create_fallback_evidence(self, token_idx: int, tokens: List[str]) -> Dict:
        """Build a neutral evidence record for a token that cannot be analysed.

        The record carries no evidence tokens, an ``("unknown", 0.5)`` sense,
        uncertainty 0.5 and zero gate/span. ``token`` is ``tokens[token_idx]``
        when ``tokens`` is a list and the index is in range, else ``"UNK"``.
        """
        index_ok = isinstance(tokens, list) and 0 <= token_idx < len(tokens)
        record: Dict = {
            "token": tokens[token_idx] if index_ok else "UNK",
            "token_idx": token_idx,
            "evidence_tokens": [],
            "chosen_sense": ("unknown", 0.5),
            "alternatives": [],
            "uncertainty": 0.5,
            "gate": 0.0,
            "span": 0.0,
        }
        return record

    def get_homograph_tokens_from_dscd(self) -> Set[str]:
        homograph_tokens: Set[str] = set()
        try:
            if self.dscd_module is not None:
                if hasattr(self.dscd_module, "get_discovered_homographs"):
                    homograph_tokens = set(self.dscd_module.get_discovered_homographs())
                elif hasattr(self.dscd_module, "prototypestores"):
                    for token, store in self.dscd_module.prototypestores.items():
                        try:
                            sz = int(store.size()) if hasattr(store, "size") and callable(store.size) else 0
                        except Exception:
                            try:
                                cents = getattr(store, "centroids", None)
                                sz = int(len(cents)) if cents is not None else 0
                            except Exception:
                                sz = 0
                        if sz >= 2:
                            clean = str(token).replace("▁", "").replace("##", "").replace("Ġ", "").strip()
                            homograph_tokens.add(clean)
        except Exception:
            pass
        return homograph_tokens

class CompleteTRGWithExplanations(nn.Module):
    def __init__(self, embed_dim: Optional[int] = None, tokenizer=None, language: str = "bn", dscd_module=None):
        """Set up the explanation templates, evidence extractor and counters.

        Args:
            embed_dim: Embedding size; defaults to ``TRG_GEN_EMBED`` when ``None``.
            tokenizer: Optional tokenizer used for special-token filtering.
            language: Language code forwarded to the token-validity helpers.
            dscd_module: Optional DSCD module used for homograph lookup.
        """
        super().__init__()
        self.embed_dim = int(TRG_GEN_EMBED) if embed_dim is None else int(embed_dim)
        self.tokenizer = tokenizer
        self.language = language
        self.dscd_module = dscd_module

        if dscd_module is None and (VERBOSE_LOGGING or DEBUG_DISCOVERY):
            print("[TRG] No DSCD module - homograph detection disabled")

        resolved_specials = set()
        if tokenizer is not None:
            try:
                # Prefer the project helpers when available, otherwise read
                # the tokenizer attribute directly.
                if _has_get_tokenizer_special_tokens:
                    resolved_specials = get_tokenizer_special_tokens(tokenizer)
                elif _has_get_cached_special_tokens:
                    resolved_specials = get_cached_special_tokens(tokenizer)
                else:
                    resolved_specials = set(tokenizer.all_special_tokens)
            except Exception:
                resolved_specials = set()
        self.special_tokens = resolved_specials

        self.template_system = ComprehensiveTRGExplanationTemplate()
        self.evidence_extractor = MemoryEfficientTRGExtractor(tokenizer, language=language, dscd_module=dscd_module)

        # Bounded buffer of generated explanations, guarded by its own lock.
        self.silver_buffer = deque(maxlen=int(MAX_SILVER_BUFFER))
        self.silver_lock = threading.Lock()

        # Counters are reset every `stats_reset_interval` explanations.
        self.stats_reset_interval = 1000
        self.stats = dict.fromkeys(
            (
                "explanations_generated",
                "high_confidence_explanations",
                "low_confidence_explanations",
                "empty_evidence_count",
                "total_evidence_tokens",
                "tokens_filtered_wordstart",
                "tokens_filtered_validity",
                "tokens_filtered_ambiguity",
                "dscd_homographs_explained",
            ),
            0,
        )
        self.stats_lock = threading.Lock()

        if VERBOSE_LOGGING or DEBUG_DISCOVERY:
            print("[TRG] Initialized")
            print(f"  - Uncertainty: {TRG_UNCERTAINTY_THRESHOLD:.2f}")
            print(f"  - Span: {TRG_SPAN_THRESHOLD:.2f}")
            print(f"  - Temperature: {TRG_TEMPERATURE:.2f}")
            print("  - Mode: DATA-DRIVEN")

    def update_stats(self, evidence: Dict, is_dscd_homograph: bool = False) -> None:
        """Update the running counters for one generated explanation.

        Args:
            evidence: Evidence record; ``evidence_tokens`` and ``chosen_sense``
                are read from it.
            is_dscd_homograph: Whether the token was flagged by the DSCD module.

        Every ``stats_reset_interval`` explanations the counters are reported
        (when ``DEBUG_DISCOVERY`` is on) and reset.
        """
        confidence = 0.5
        chosen = evidence.get("chosen_sense")
        if isinstance(chosen, (tuple, list)) and len(chosen) >= 2:
            try:
                confidence = float(chosen[1])
            except Exception:
                confidence = 0.5

        with self.stats_lock:
            self.stats["explanations_generated"] += 1
            if is_dscd_homograph:
                self.stats["dscd_homographs_explained"] += 1

            if not evidence.get("evidence_tokens"):
                self.stats["empty_evidence_count"] += 1
            else:
                self.stats["total_evidence_tokens"] += len(evidence["evidence_tokens"])

            if confidence >= TAU_ACCEPT:
                self.stats["high_confidence_explanations"] += 1
            elif confidence <= TRG_UNCERTAINTY_THRESHOLD:
                self.stats["low_confidence_explanations"] += 1

            generated = self.stats["explanations_generated"]
            rollover_due = generated % self.stats_reset_interval == 0

        # Report and reset only after releasing stats_lock: it is a plain
        # (non-reentrant) lock and get_statistics() acquires it again, so
        # calling it while still holding the lock would deadlock.
        if rollover_due:
            if DEBUG_DISCOVERY:
                current_stats = self.get_statistics()
                print(f"[TRG-STATS] After {generated}")
                print(f"  High conf: {current_stats['high_confidence_rate']:.2%}")
                print(f"  DSCD: {current_stats['dscd_homograph_rate']:.2%}")
            self.reset_statistics()

    def add_to_silver_buffer(self, evidence: Dict, explanation: str, tokens: List[str]) -> None:
        """Append a compact record of this explanation to the silver buffer.

        The record keeps the token (first 20 characters), the explanation
        (first 150 characters) and the chosen-sense confidence (0.5 when no
        ``(name, confidence)`` pair is present). ``tokens`` is accepted but not
        used. Failures are ignored: buffering is best effort.
        """
        try:
            chosen = evidence.get("chosen_sense")
            has_pair = isinstance(chosen, (tuple, list)) and len(chosen) >= 2
            record = {
                "token": str(evidence.get("token", "UNK"))[:20],
                "explanation": str(explanation)[:150],
                "confidence": float(chosen[1]) if has_pair else 0.5,
            }
            with self.silver_lock:
                self.silver_buffer.append(record)
        except Exception:
            pass

    def generate_explanation_for_token(
        self,
        token_idx: int,
        tokens: List[str],
        dscd_outputs: Dict,
        token_word_map: Optional[dict] = None,
        decoder_attention: Optional[torch.Tensor] = None,
        is_dscd_homograph: bool = False,
    ) -> Tuple[str, Dict]:
        """Produce the explanation text and evidence record for one token.

        Returns ``("", {})`` when the module is in training mode, TRG
        inference is disabled, the inputs are malformed, the token is neither
        valid nor a word start, or any step raises. On success the statistics
        and the silver buffer are updated as side effects.
        """
        nothing: Tuple[str, Dict] = ("", {})

        if self.training or not ENABLE_TRG_INFERENCE:
            return nothing
        if not (isinstance(tokens, list) and isinstance(token_idx, int)):
            return nothing
        if not (0 <= token_idx < len(tokens)):
            return nothing

        raw_token = tokens[token_idx]
        validity_args = (raw_token, self.special_tokens, self.tokenizer)
        if _has_is_valid_token:
            try:
                is_valid = is_valid_token(*validity_args, language=self.language)
            except Exception:
                is_valid = fallback_is_valid_token(*validity_args, self.language)
        else:
            is_valid = fallback_is_valid_token(*validity_args, self.language)

        # An invalid token is still explained when it starts a word.
        if not (is_valid or is_word_start(raw_token, token_word_map, token_idx)):
            return "", {}

        try:
            evidence = self.evidence_extractor.extract_evidence_efficiently(
                token_idx,
                tokens,
                dscd_outputs,
                token_word_map=token_word_map,
                decoder_attention=decoder_attention,
            )
            explanation_text = self.template_system.generate_explanation(evidence)

            self.update_stats(evidence, is_dscd_homograph=is_dscd_homograph)
            self.add_to_silver_buffer(evidence, explanation_text, tokens)
        except Exception:
            return "", {}
        return explanation_text, evidence

    @staticmethod
    def _tolist_helper(x: Any) -> List[float]:
        """Convert a tensor, sequence or scalar into a flat list of floats.

        Tensors: rank 0 gives one value, rank 1 its elements, rank 2 the
        elements of its first row, and higher ranks are flattened. Sequences
        are converted element-wise; entries that are multi-element tensors or
        cannot be converted become 0.0. ``None`` or any failure returns ``[]``.
        """
        if x is None:
            return []

        def element_to_float(item: Any) -> float:
            if isinstance(item, torch.Tensor):
                # Only single-element tensors carry a usable scalar.
                return float(item.flatten()[0].item()) if item.numel() == 1 else 0.0
            if isinstance(item, (int, float, np.number)):
                return float(item)
            try:
                return float(item)
            except Exception:
                return 0.0

        try:
            if isinstance(x, torch.Tensor):
                if x.ndim == 0:
                    return [float(x.item())]
                if x.ndim == 1:
                    source = x
                elif x.ndim == 2:
                    source = x[0]
                else:
                    source = x.flatten()
                return [float(value) for value in source.tolist()]

            if isinstance(x, (list, tuple)):
                return [element_to_float(item) for item in x]

            return [float(x)]
        except Exception:
            return []

    def process_sentence_for_explanations(
        self,
        tokens: List[str],
        dscd_outputs: Dict,
        token_word_map: Optional[dict] = None,
        uncertainty_threshold: Optional[float] = None,
        decoder_attention: Optional[torch.Tensor] = None,
        max_explanations: int = MAX_EXPLANATIONS_PER_SENTENCE,
    ) -> List[Dict]:
        """Select the most ambiguous words of a sentence and explain them.

        Tokens are grouped into words (via ``token_word_map`` when available).
        Each valid word is ranked by priority: 1 for DSCD-discovered
        homographs, 2 when its span exceeds ``TRG_SPAN_THRESHOLD``, 3 when its
        uncertainty exceeds the effective threshold. The top
        ``max_explanations`` candidates are then explained.

        Args:
            tokens: Token strings for the sentence.
            dscd_outputs: Dict with ``"uncertainties"`` and optionally
                ``"span_preds"``/``"spanpreds"``; only the first element of
                each is used. Elements may be lists or tensors.
            token_word_map: Optional mapping from token index to word string.
            uncertainty_threshold: Optional threshold; it can only raise the
                configured ``TRG_UNCERTAINTY_THRESHOLD``, never lower it.
            decoder_attention: Optional attention tensor forwarded to the
                evidence extractor.
            max_explanations: Maximum number of candidates to explain.

        Returns:
            A list of explanation dicts (possibly empty). Nothing is produced
            in training mode or when TRG inference is disabled; unexpected
            errors are swallowed and yield the explanations gathered so far.
        """
        if self.training or not ENABLE_TRG_INFERENCE:
            return []

        if uncertainty_threshold is None:
            uncertainty_threshold = float(TRG_UNCERTAINTY_THRESHOLD)

        strict_uncertainty = max(TRG_UNCERTAINTY_THRESHOLD, uncertainty_threshold)

        def _is_empty(value: Any) -> bool:
            # Emptiness test that is safe for tensors: truth-testing a tensor
            # with more than one element raises, which previously aborted the
            # whole sentence for tensor-valued DSCD outputs.
            if value is None:
                return True
            if isinstance(value, torch.Tensor):
                return value.numel() == 0
            try:
                return len(value) == 0
            except TypeError:
                return not value

        explanations: List[Dict] = []
        try:
            if not tokens or not isinstance(tokens, list):
                return explanations

            if not isinstance(dscd_outputs, dict) or not dscd_outputs:
                return explanations

            U_all = dscd_outputs.get("uncertainties")
            S_all = dscd_outputs.get("span_preds", None)
            if S_all is None:
                S_all = dscd_outputs.get("spanpreds", None)

            if _is_empty(U_all) or _is_empty(U_all[0]):
                return explanations

            U = self._tolist_helper(U_all[0])
            if not _is_empty(S_all) and not _is_empty(S_all[0]):
                S = self._tolist_helper(S_all[0])
            else:
                S = [0.0] * len(U)

            # Pad missing span scores so S can be indexed alongside U.
            if len(S) < len(U):
                S.extend([0.0] * (len(U) - len(S)))

            if not U:
                return explanations

            dscd_homographs = self.evidence_extractor.get_homograph_tokens_from_dscd()

            # Group subword tokens by word; unmapped tokens get a unique key.
            word_groups = defaultdict(list)
            for idx, tok in enumerate(tokens):
                if idx >= len(U):
                    break

                if token_word_map and idx in token_word_map and token_word_map[idx]:
                    word_key = token_word_map[idx]
                else:
                    word_key = f"RAW_{idx}_{tok}"

                word_groups[word_key].append({
                    "idx": idx,
                    "tok": tok,
                    "u": float(U[idx]),
                    "s": float(S[idx])
                })

            candidates: List[Tuple[int, float, float, str, int, int]] = []

            for word_key, group in word_groups.items():
                max_u = max(g["u"] for g in group)
                max_s = max(g["s"] for g in group)

                # The subword with the highest combined score represents the word.
                best_subword = max(group, key=lambda x: x["u"] + x["s"])
                idx = best_subword["idx"]

                clean_tok = str(word_key).replace("▁", "").replace("##", "").replace("Ġ", "").strip()

                if _has_is_valid_token:
                    try:
                        valid = is_valid_token(clean_tok, self.special_tokens, self.tokenizer, language=self.language)
                    except Exception:
                        valid = fallback_is_valid_token(clean_tok, self.special_tokens, self.tokenizer, self.language)
                else:
                    valid = fallback_is_valid_token(clean_tok, self.special_tokens, self.tokenizer, self.language)

                if not valid:
                    with self.stats_lock:
                        self.stats["tokens_filtered_validity"] += 1
                    continue

                in_dscd = clean_tok in dscd_homographs

                if in_dscd:
                    priority = 1
                elif max_s > TRG_SPAN_THRESHOLD:
                    priority = 2
                elif max_u > strict_uncertainty:
                    priority = 3
                else:
                    with self.stats_lock:
                        self.stats["tokens_filtered_ambiguity"] += 1
                    continue

                candidates.append((idx, max_u, max_s, clean_tok, priority, idx))

            if not candidates:
                return explanations

            # Lowest priority number first, then larger span, then larger
            # uncertainty, then earlier position.
            candidates.sort(key=lambda t: (t[4], -t[2], -t[1], t[5]))

            for token_idx, u, s, clean_tok, priority, _ in candidates[:max_explanations]:
                try:
                    explanation_text, evidence = self.generate_explanation_for_token(
                        token_idx, tokens, dscd_outputs,
                        token_word_map=token_word_map,
                        decoder_attention=decoder_attention,
                        is_dscd_homograph=(priority == 1)
                    )

                    if explanation_text and evidence:
                        explanations.append({
                            "token_idx": token_idx,
                            "token": clean_tok,
                            "explanation": explanation_text,
                            "uncertainty": u,
                            "span": s,
                            "dscd_discovered": (priority == 1),
                            "priority": priority
                        })
                except Exception:
                    continue

        except Exception:
            # Best effort: return whatever was produced before the failure.
            pass

        return explanations

    def get_statistics(self) -> Dict:
        """Return a snapshot of the TRG counters together with derived rates.

        The raw counters are copied from ``self.stats`` while holding
        ``self.stats_lock``. Every rate divides by the number of generated
        explanations clamped to at least 1, so an instance that has not
        generated anything reports 0 instead of raising ``ZeroDivisionError``.

        Returns:
            Dict: every key of ``self.stats`` plus ``high_confidence_rate``,
            ``low_confidence_rate``, ``empty_evidence_rate``,
            ``avg_evidence_tokens``, ``silver_buffer_size`` and
            ``dscd_homograph_rate``.
        """
        with self.stats_lock:
            counters = self.stats
            generated = counters["explanations_generated"]
            denominator = max(generated, 1)

            # The evidence average is only computed once something was generated.
            mean_evidence = (
                counters["total_evidence_tokens"] / denominator if generated > 0 else 0.0
            )

            snapshot = dict(counters)
            snapshot["high_confidence_rate"] = counters["high_confidence_explanations"] / denominator
            snapshot["low_confidence_rate"] = counters["low_confidence_explanations"] / denominator
            snapshot["empty_evidence_rate"] = counters["empty_evidence_count"] / denominator
            snapshot["avg_evidence_tokens"] = mean_evidence
            snapshot["silver_buffer_size"] = len(self.silver_buffer)
            snapshot["dscd_homograph_rate"] = counters["dscd_homographs_explained"] / denominator
            return snapshot

    def reset_statistics(self) -> None:
        """Replace ``self.stats`` with a fresh dict in which every counter is 0.

        The swap happens while holding ``self.stats_lock``.
        """
        counter_names = (
            "explanations_generated",
            "high_confidence_explanations",
            "low_confidence_explanations",
            "empty_evidence_count",
            "total_evidence_tokens",
            "tokens_filtered_wordstart",
            "tokens_filtered_validity",
            "tokens_filtered_ambiguity",
            "dscd_homographs_explained",
        )
        zeroed = dict.fromkeys(counter_names, 0)
        with self.stats_lock:
            self.stats = zeroed

    def clear_silver_buffer(self) -> None:
        """Empty ``self.silver_buffer`` in place while holding ``self.silver_lock``."""
        with self.silver_lock:
            buffer = self.silver_buffer
            buffer.clear()

    def test_trg(self, tokenizer=None) -> bool:
        """Smoke-test explanation generation and statistics reset on a toy sentence.

        Feeds six hand-written tokens (two multi-subword words, "Kaler" and
        "Pata") with synthetic DSCD outputs through
        ``process_sentence_for_explanations``, prints what came back, then
        checks that ``reset_statistics`` zeroes ``explanations_generated``.

        Side effects:
            * ``self.stats`` is reset as part of the check.
            * The module is switched to eval mode for the test; if it was in
              training mode beforehand, training mode is restored on exit.

        Args:
            tokenizer: unused; kept so existing callers can keep passing it.

        Returns:
            bool: True when the test ran without raising, False otherwise
            (the error and its traceback are printed).
        """
        print("=" * 60)
        print("[TRG-TEST] Testing")
        print("=" * 60)

        if not ENABLE_TRG_INFERENCE:
            # This method does not change the flag, so the message must not claim it does.
            print("[TRG] inference disabled globally; running self-test anyway...")

        # Remember the mode so a self-test in the middle of training does not
        # silently leave the module in eval mode.
        was_training = bool(getattr(self, "training", False))

        try:
            tokens = ["_", "Kal", "er", "_", "Pa", "ta"]
            dscd_outputs = {
                "proto_probs": [[torch.tensor([0.6, 0.4])] for _ in tokens],
                "uncertainties": [[0.1], [0.5], [0.2], [0.1], [0.05], [0.0]],
                "span_preds": [[0.05], [0.3], [0.1], [0.05], [0.0], [0.0]],
                "gates": [[0.2], [0.8], [0.3], [0.2], [0.0], [0.0]]
            }
            # Subword index -> surface word, so "Kal"+"er" and "Pa"+"ta" are grouped.
            token_word_map = {
                0: "_",
                1: "Kaler",
                2: "Kaler",
                3: "_",
                4: "Pata",
                5: "Pata"
            }

            self.eval()
            explanations = self.process_sentence_for_explanations(
                tokens=tokens,
                dscd_outputs=dscd_outputs,
                token_word_map=token_word_map,
                max_explanations=3
            )

            print(f"  Generated {len(explanations)} explanations")
            for i, expl in enumerate(explanations, 1):
                print(f"  {i}. {expl['token']}: {expl['explanation'][:50]}... (u={expl['uncertainty']:.2f})")

            found_kaler = any(e["token"] == "Kaler" for e in explanations)
            if found_kaler:
                print("  [SUCCESS] Found aggregated word 'Kaler'")
            else:
                print("  [WARN] Did not find 'Kaler' (check thresholds)")

            stats = self.get_statistics()
            print(f"  Stats: {stats['explanations_generated']} total")
            self.reset_statistics()
            stats_after = self.get_statistics()
            assert stats_after["explanations_generated"] == 0
            print("  Reset OK")
            print("  All tests passed")
            print("=" * 60)
            return True
        except Exception as e:
            # Imported here so the traceback is printed even when the module
            # has not been imported at notebook level yet.
            import traceback

            print(f"  [FAIL] failed: {e}")
            traceback.print_exc()
            print("=" * 60)
            return False
        finally:
            if was_training:
                self.train()

# Cell 5 completion banner: blank line, rule, status line, rule.
print(
    "\n" + "=" * 80,
    "Cell 5: TRG Ready (DATA-DRIVEN, SUBWORD-AWARE) - ALL FIXES APPLIED",
    "=" * 80,
    sep="\n",
)



Cell 5: TRG Ready (DATA-DRIVEN, SUBWORD-AWARE) - ALL FIXES APPLIED


In [9]:
# ==============================================================================
# CELL 6: MEMORY-OPTIMIZED TATN MODEL (FAST INFERENCE MODE ADDED)
# ==============================================================================
from typing import List, Dict, Optional, Any, Tuple
import traceback
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import M2M100ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
import threading
import gc
import time
import unicodedata

# Resolve the language codes independently: previously both lived in one
# try-block, so a missing TARGET_LANGUAGE also overwrote a configured
# SOURCE_LANGUAGE with "bn".
try:
    SOURCE_LANGUAGE = str(SOURCE_LANGUAGE)
except (NameError, TypeError):
    SOURCE_LANGUAGE = "bn"

try:
    TARGET_LANGUAGE = str(TARGET_LANGUAGE)
except (NameError, TypeError):
    TARGET_LANGUAGE = "en"

def _get_int_global(name: str, default: int) -> int:
    """Read the notebook-level global ``name`` as an ``int``.

    Args:
        name: name of the global variable to look up.
        default: value returned when the global is undefined, ``None`` or
            not convertible.

    Returns:
        ``int(value)`` of the global, or ``default``.
    """
    try:
        val = globals().get(name)
        if val is not None:
            return int(val)
    except (ValueError, TypeError, OverflowError):
        # ValueError: int("abc") / int(nan); OverflowError: int(inf).
        pass
    return default

def _get_float_global(name: str, default: float) -> float:
    """Read the notebook-level global ``name`` as a ``float``.

    Args:
        name: name of the global variable to look up.
        default: value returned when the global is undefined, ``None`` or
            not convertible.

    Returns:
        ``float(value)`` of the global, or ``default``.
    """
    try:
        val = globals().get(name)
        if val is not None:
            return float(val)
    except (ValueError, TypeError, OverflowError):
        # OverflowError: an int too large to be represented as a float.
        pass
    return default

def _get_bool_global(name: str, default: bool) -> bool:
    """Read the notebook-level global ``name`` as a ``bool``.

    String values are interpreted textually, because ``bool("False")`` is
    ``True`` in Python: after stripping and lower-casing, the spellings
    ``""``, ``"0"``, ``"false"``, ``"no"``, ``"off"`` and ``"none"`` mean
    False and any other string means True. Non-string values go through
    ``bool()``.

    Args:
        name: name of the global variable to look up.
        default: value returned when the global is undefined, ``None`` or
            its truth value cannot be determined.

    Returns:
        The boolean interpretation of the global, or ``default``.
    """
    try:
        val = globals().get(name)
        if val is None:
            return default
        if isinstance(val, str):
            return val.strip().lower() not in ("", "0", "false", "no", "off", "none")
        return bool(val)
    except (ValueError, TypeError):
        # e.g. an array whose truth value is ambiguous.
        pass
    return default

# ------------------------------------------------------------------------------
# Cell 6 configuration. Every value is taken from an already-defined notebook
# global when present and convertible, otherwise from the default given here.
# ------------------------------------------------------------------------------

# DSCD clustering parameters.
DSCD_BUFFER_SIZE = _get_int_global("DSCD_BUFFER_SIZE", 50)
DSCD_MAX_PROTOS = _get_int_global("DSCD_MAX_PROTOS", 8)
DSCD_N_MIN = _get_int_global("DSCD_N_MIN", 2)
DSCD_DISPERSION_THRESHOLD = _get_float_global("DSCD_DISPERSION_THRESHOLD", 0.15)

# Feature switches and housekeeping.
ENABLE_ASBN_TRAINING = _get_bool_global("ENABLE_ASBN_TRAINING", True)
ENABLE_TRG_INFERENCE = _get_bool_global("ENABLE_TRG_INFERENCE", True)
MEMORY_CLEANUP_FREQUENCY = _get_int_global("MEMORY_CLEANUP_FREQUENCY", 2000)

NUM_GPUS = _get_int_global("NUM_GPUS", torch.cuda.device_count() if torch.cuda.is_available() else 1)
# USE_GC mirrors the GRADIENT_CHECKPOINTING global (it is not related to the gc module).
USE_GC = _get_bool_global("GRADIENT_CHECKPOINTING", False)
DSCD_ENABLE_TRAINING_CLUSTERING = _get_bool_global("DSCD_ENABLE_TRAINING_CLUSTERING", True)

# Loss weights.
LAMBDA_ASBN = _get_float_global("LAMBDA_ASBN", 0.05)
LAMBDA_DSCD = _get_float_global("LAMBDA_DSCD", 0.15)

# Logging / debugging switches (same helper as every other boolean flag).
VERBOSE_LOGGING = _get_bool_global("VERBOSE_LOGGING", False)
DEBUG_DISCOVERY = _get_bool_global("DEBUG_DISCOVERY", False)
DEBUG_TIMING = _get_bool_global("DEBUG_TIMING", False)

PERIODIC_DISCOVERY_FREQUENCY = _get_int_global("PERIODIC_DISCOVERY_FREQUENCY", 200)
VALIDATION_CHECK_INTERVAL = _get_int_global("VALIDATION_CHECK_INTERVAL", 200)

# Thresholds.
SPAN_THRESHOLD = _get_float_global("SPAN_THRESHOLD", 0.20)
UNCERTAINTY_THRESHOLD = _get_float_global("UNCERTAINTY_THRESHOLD", 0.15)

# TAU_LOW is resolved first because it is the fallback for the TRG threshold.
TAU_LOW = _get_float_global("TAU_LOW", 0.15)
TRG_UNCERTAINTY_THRESHOLD = _get_float_global("TRG_UNCERTAINTY_THRESHOLD", TAU_LOW)

# Domain labels.
TRAIN_DOMAIN = _get_int_global("TRAIN_DOMAIN", 0)
TEST_DOMAIN = _get_int_global("TEST_DOMAIN", 1)
USE_DOMAIN_LABELS = _get_bool_global("USE_DOMAIN_LABELS", True)

# Fallback language token ids, used when the tokenizer cannot resolve them.
M2M100_EN_TOKEN_ID = _get_int_global("M2M100_EN_TOKEN_ID", 128022)
M2M100_BN_TOKEN_ID = _get_int_global("M2M100_BN_TOKEN_ID", 128025)

# True only when a usable reconstruct_word_spans() is defined in the notebook;
# a non-callable leftover with that name must not enable the code path.
_has_reconstruct_word_spans = callable(globals().get("reconstruct_word_spans"))


def _safe_get_last_hidden_state(enc_output):
    """Extract the hidden-state payload from an encoder output, if there is one.

    Returns ``enc_output.last_hidden_state`` when that attribute exists,
    otherwise the first element of a non-empty list/tuple, otherwise ``None``
    (including when ``enc_output`` itself is ``None``).
    """
    hidden = None
    if enc_output is not None:
        if hasattr(enc_output, "last_hidden_state"):
            hidden = enc_output.last_hidden_state
        elif isinstance(enc_output, (list, tuple)) and len(enc_output) != 0:
            hidden = enc_output[0]
    return hidden


def _normalize_dscd_outputs(
    raw: Dict[str, Any],
    batch_size: int,
    seq_len: int,
    device: torch.device,
    embed_dim: int,
) -> Dict[str, Any]:
    """Coerce raw DSCD outputs into a fixed, fully-populated layout.

    The result always has exactly these keys:

    * ``h_augmented``: tensor of shape ``(batch_size, seq_len, embed_dim)``.
    * ``proto_probs``: ``batch_size`` lists of ``seq_len`` 1-D float tensors.
    * ``uncertainties`` / ``gates`` / ``span_preds``: ``batch_size`` lists of
      ``seq_len`` single-value float tensors.
    * ``proto_assignments``: ``batch_size`` long tensors with ``seq_len``
      elements.

    Both key spellings are accepted (``h_augmented``/``haugmented``,
    ``proto_probs``/``protoprobs``, ...); the first spelling whose value is
    not ``None`` wins. Anything missing, malformed, or of the wrong batch
    size is replaced by a neutral default: zeros for ``h_augmented``,
    ``[1.0]`` for a ``proto_probs`` entry, ``0.0`` for scalar entries and
    zeros for ``proto_assignments``. Empty tensors are treated as missing, so
    a scalar entry never becomes NaN through ``mean()`` of nothing.

    Args:
        raw: DSCD output mapping; a non-dict yields all defaults.
        batch_size: expected number of rows in every per-batch list.
        seq_len: number of token positions to produce per row.
        device: device on which every returned tensor is placed.
        embed_dim: last dimension of ``h_augmented``.

    Returns:
        A new dict in the layout described above.
    """

    def _neutral(list_key: str) -> torch.Tensor:
        # One fresh tensor per position, so entries never alias each other.
        if list_key == "proto_probs":
            return torch.tensor([1.0], device=device, dtype=torch.float32)
        return torch.tensor(0.0, device=device, dtype=torch.float32)

    def _neutral_row(list_key: str) -> list:
        return [_neutral(list_key) for _ in range(seq_len)]

    def _first_present(*keys: str):
        # First value that is present and not None, honouring key priority.
        for key in keys:
            value = raw.get(key)
            if value is not None:
                return value
        return None

    def _coerce_entry(list_key: str, value) -> torch.Tensor:
        if isinstance(value, torch.Tensor):
            moved = value.to(device)
            if list_key == "proto_probs":
                moved = moved.float().flatten()
                return moved if moved.numel() > 0 else _neutral(list_key)
            if moved.numel() == 0:
                # mean() of an empty tensor is NaN; fall back to the default.
                return _neutral(list_key)
            if moved.numel() != 1:
                moved = moved.float().mean()
            return moved.float()
        if list_key == "proto_probs":
            probs = torch.as_tensor(value, device=device, dtype=torch.float32).flatten()
            return probs if probs.numel() > 0 else _neutral(list_key)
        return torch.tensor(float(value), device=device, dtype=torch.float32)

    def _normalize_row(list_key: str, row) -> list:
        if not isinstance(row, list):
            return _neutral_row(list_key)
        safe_row = []
        for t_idx in range(seq_len):
            if t_idx >= len(row):
                # Row shorter than the sequence: pad with defaults.
                safe_row.append(_neutral(list_key))
                continue
            try:
                safe_row.append(_coerce_entry(list_key, row[t_idx]))
            except Exception:
                safe_row.append(_neutral(list_key))
        return safe_row

    is_mapping = isinstance(raw, dict)
    out: Dict[str, Any] = {}

    # --- h_augmented -------------------------------------------------------
    h_aug = None
    if is_mapping:
        h = _first_present("haugmented", "h_augmented")
        if h is not None:
            try:
                if isinstance(h, torch.Tensor) and h.shape == (batch_size, seq_len, embed_dim):
                    h_aug = h.to(device)
                else:
                    h_aug = h.to(device).reshape(batch_size, seq_len, embed_dim)
            except Exception:
                h_aug = None
    if h_aug is None:
        h_aug = torch.zeros(batch_size, seq_len, embed_dim, device=device, dtype=torch.float32)
    out["h_augmented"] = h_aug

    # --- per-token lists ---------------------------------------------------
    for list_key in ("proto_probs", "uncertainties", "gates", "span_preds"):
        normalized = None
        if is_mapping:
            val = _first_present(list_key, list_key.replace("_", ""))
            if isinstance(val, list) and len(val) == batch_size:
                try:
                    normalized = [_normalize_row(list_key, b_row) for b_row in val]
                except Exception:
                    normalized = None
        if normalized is None:
            normalized = [_neutral_row(list_key) for _ in range(batch_size)]
        out[list_key] = normalized

    # --- proto_assignments -------------------------------------------------
    assignments = None
    if is_mapping:
        pa = _first_present("protoassignments", "proto_assignments")
        if isinstance(pa, list) and len(pa) == batch_size:
            assignments = []
            for b_row in pa:
                try:
                    if isinstance(b_row, torch.Tensor):
                        t = b_row.to(device).long()
                    else:
                        t = torch.tensor(b_row, dtype=torch.long, device=device)
                    if t.numel() != seq_len:
                        t = torch.zeros(seq_len, dtype=torch.long, device=device)
                except Exception:
                    t = torch.zeros(seq_len, dtype=torch.long, device=device)
                assignments.append(t)
    if assignments is None:
        assignments = [torch.zeros(seq_len, dtype=torch.long, device=device) for _ in range(batch_size)]
    out["proto_assignments"] = assignments

    return out


def aggregate_subwords_to_words(
    subword_embeddings: torch.Tensor,
    token_to_word_map: Dict[int, int],
    num_words: int,
    device: torch.device,
) -> torch.Tensor:
    """Mean-pool subword embeddings into one embedding per word.

    Args:
        subword_embeddings: tensor indexed by token position along dim 0,
            with the embedding along the last dim.
        token_to_word_map: token position -> word index. Pairs whose token
            position or word index is out of range are ignored.
        num_words: number of rows in the result.
        device: device on which the result is allocated.

    Returns:
        Tensor of shape ``(num_words, embed_dim)`` with the dtype of
        ``subword_embeddings``; words that received no token stay zero.
    """
    embed_dim = subword_embeddings.size(-1)
    num_tokens = subword_embeddings.size(0)
    word_embeddings = torch.zeros(num_words, embed_dim, device=device, dtype=subword_embeddings.dtype)

    # Member counts are plain Python ints: testing a device tensor in an `if`
    # forces a host synchronisation for every word.
    members_per_word = [0] * num_words

    for token_idx, word_idx in token_to_word_map.items():
        if 0 <= token_idx < num_tokens and 0 <= word_idx < num_words:
            word_embeddings[word_idx] += subword_embeddings[token_idx]
            members_per_word[word_idx] += 1

    for w_idx, n_members in enumerate(members_per_word):
        if n_members > 0:
            word_embeddings[w_idx] /= n_members

    return word_embeddings


def broadcast_word_to_subword(
    word_level_outputs: Dict[str, Any],
    token_to_word_map: Dict[int, int],
    seq_len: int,
    device: torch.device,
) -> Dict[str, Any]:
    """Expand word-level DSCD outputs back to one entry per token position.

    For every position ``0 .. seq_len - 1`` the value of the word it maps to
    is copied (by reference) into the result. Positions without a word (no
    map entry or a negative index) and words missing from a field get that
    field's default instead.

    Args:
        word_level_outputs: mapping with the keys ``haugmented``,
            ``protoprobs``, ``uncertainties``, ``gates``, ``spanpreds`` and
            ``protoassignments``, each indexed by word. ``haugmented`` and
            ``protoassignments`` may be a list or a tensor; the others are
            only read when they are lists. An optional ``embed_dim`` entry
            (default 1024) sizes the zero vector used for ``haugmented``.
        token_to_word_map: token position -> word index.
        seq_len: number of token positions to produce.
        device: device for every default tensor, including the ``-1``
            assignment placeholder (which was previously always created on
            the CPU).

    Returns:
        Dict with the same six keys, each a list of ``seq_len`` entries.
    """

    def _zero_scalar() -> torch.Tensor:
        return torch.tensor(0.0, device=device, dtype=torch.float32)

    def _zero_embedding() -> torch.Tensor:
        return torch.zeros(word_level_outputs.get("embed_dim", 1024), device=device)

    # (output key, whether a tensor container is accepted, default factory)
    fields = (
        ("haugmented", True, _zero_embedding),
        ("protoprobs", False, lambda: torch.tensor([1.0], device=device, dtype=torch.float32)),
        ("uncertainties", False, _zero_scalar),
        ("gates", False, _zero_scalar),
        ("spanpreds", False, _zero_scalar),
        ("protoassignments", True, lambda: torch.tensor(-1, device=device)),
    )

    containers = {key: word_level_outputs.get(key, []) for key, _, _ in fields}
    subword_outputs: Dict[str, Any] = {key: [] for key, _, _ in fields}

    def _has_word(container, word_idx: int, allow_tensor: bool) -> bool:
        if isinstance(container, list):
            return word_idx < len(container)
        if allow_tensor and isinstance(container, torch.Tensor):
            return word_idx < container.size(0)
        return False

    for token_idx in range(seq_len):
        word_idx = token_to_word_map.get(token_idx, -1)
        for key, allow_tensor, make_default in fields:
            container = containers[key]
            if word_idx >= 0 and _has_word(container, word_idx, allow_tensor):
                subword_outputs[key].append(container[word_idx])
            else:
                subword_outputs[key].append(make_default())

    return subword_outputs


class MemoryOptimizedTATNWithExplanations(nn.Module):
    def __init__(self, tokenizer):
        """Build the TATN model: an M2M100 backbone plus DSCD, ASBN and TRG parts.

        DSCD is mandatory. ASBN and TRG are optional: when their class is not
        defined in the notebook, or its constructor raises, a no-op stub is
        used instead (a constructor failure is reported on stdout).

        Args:
            tokenizer: tokenizer handed to the DSCD/ASBN/TRG components. It is
                also queried (``get_lang_id`` / ``lang_code_to_id``) for the
                source and target language token ids.

        Raises:
            RuntimeError: if ``MemoryEfficientDSCDOnline`` is not defined in
                the notebook globals or cannot be instantiated.
        """
        super().__init__()
        self.tokenizer = tokenizer

        self.global_step = 0
        self._step_lock = threading.Lock()
        self.last_discovery_step = 0
        self.last_validation_step = 0

        # The attribute is called ``mbart`` but holds an M2M100 checkpoint.
        self.mbart = M2M100ForConditionalGeneration.from_pretrained(
            "facebook/m2m100_418M",
            torch_dtype=torch.float32,
            use_cache=False,
        )
        try:
            self.mbart.config.use_cache = False
        except Exception:
            pass

        def _resolve_lang_id(primary_code, alias_codes, fallback_id):
            """Ask the tokenizer for a language token id; else ``fallback_id``."""
            if hasattr(self.tokenizer, "get_lang_id"):
                for code in (primary_code,) + tuple(alias_codes):
                    try:
                        found = self.tokenizer.get_lang_id(code)
                        if found is not None:
                            return int(found)
                    except Exception:
                        continue
            if hasattr(self.tokenizer, "lang_code_to_id"):
                try:
                    return int(self.tokenizer.lang_code_to_id.get(primary_code, fallback_id))
                except Exception:
                    pass
            return fallback_id

        try:
            en_token_id = _resolve_lang_id(TARGET_LANGUAGE, ("en", "en_XX", "eng"), M2M100_EN_TOKEN_ID)
            bn_token_id = _resolve_lang_id(SOURCE_LANGUAGE, ("bn", "bn_IN", "ben"), M2M100_BN_TOKEN_ID)

            self.mbart.config.forced_bos_token_id = int(en_token_id)
            # NOTE(review): M2M100 normally starts decoding from EOS and only
            # forces the language id as the first generated token; confirm that
            # overriding decoder_start_token_id as well is intended.
            self.mbart.config.decoder_start_token_id = int(en_token_id)
            self.en_token_id = int(en_token_id)
            self.bn_token_id = int(bn_token_id)

            if DEBUG_DISCOVERY:
                print(f"[TATN-INIT] Language tokens: BN={bn_token_id}, EN={en_token_id}")

        except Exception as e:
            if DEBUG_DISCOVERY:
                print(f"[TATN-INIT] Failed to set language tokens: {e}")
            self.en_token_id = M2M100_EN_TOKEN_ID
            self.bn_token_id = M2M100_BN_TOKEN_ID

        try:
            if USE_GC and hasattr(self.mbart, "gradient_checkpointing_enable"):
                self.mbart.gradient_checkpointing_enable()
        except Exception:
            pass

        embed_dim = int(getattr(self.mbart.config, "d_model", 1024))

        # --- DSCD (mandatory) ---------------------------------------------
        dscd_cls = globals().get("MemoryEfficientDSCDOnline", None)
        if not callable(dscd_cls):
            raise RuntimeError("MemoryEfficientDSCDOnline not found in globals()")
        try:
            self.dscd = dscd_cls(
                embeddim=embed_dim,
                tokenizer=tokenizer,
                buffersize=DSCD_BUFFER_SIZE,
                maxprotos=DSCD_MAX_PROTOS,
                nmin=max(3, DSCD_N_MIN),
                dispersionthreshold=min(0.08, DSCD_DISPERSION_THRESHOLD),
                language=SOURCE_LANGUAGE,
                enabletrainingclustering=DSCD_ENABLE_TRAINING_CLUSTERING,
                maxclusteringpoints=500,
                maxcandidatesperstep=1,
            )
        except Exception as e:
            # Chain the cause so the original traceback is not lost.
            raise RuntimeError(f"Failed to instantiate MemoryEfficientDSCDOnline: {e}") from e

        # --- ASBN (optional) ----------------------------------------------
        class _StubASBN(nn.Module):
            """No-op ASBN: returns the input unchanged and zero-valued losses."""

            def forward(self, h, domain_labels=None):
                dev = h.device if isinstance(h, torch.Tensor) else torch.device("cpu")
                return h, torch.tensor(0.0, device=dev)

            def forward_with_grl_simplified(self, h, *args, **kwargs):
                dev = h.device if isinstance(h, torch.Tensor) else torch.device("cpu")
                zero = torch.tensor(0.0, device=dev)
                return zero, zero, zero, zero

        asbn_module = None
        asbn_cls = globals().get("MemoryEfficientASBNModule", None)
        if callable(asbn_cls):
            try:
                asbn_module = asbn_cls(embed_dim, tokenizer, language=SOURCE_LANGUAGE)
            except Exception as e:
                # Previously swallowed silently, which hid a disabled ASBN.
                print(f"[TATN-INIT] WARNING: MemoryEfficientASBNModule init failed ({e}); using no-op stub")
        elif DEBUG_DISCOVERY or VERBOSE_LOGGING:
            print("[TATN-INIT] MemoryEfficientASBNModule not found; using no-op stub")
        self.asbn = asbn_module if asbn_module is not None else _StubASBN()

        # --- TRG (optional) -----------------------------------------------
        class _StubTRG:
            """No-op TRG: never produces explanations."""

            def process_sentence_for_explanations(
                self,
                tokens,
                dscd_outputs,
                token_word_map=None,
                uncertainty_threshold=0.1,
                decoder_attention=None,
                **kwargs,
            ):
                # **kwargs tolerates keywords the real TRG is called with
                # elsewhere in this notebook (e.g. ``max_explanations``).
                return []

        trg_system = None
        trg_cls = globals().get("CompleteTRGWithExplanations", None)
        if callable(trg_cls):
            try:
                trg_system = trg_cls(
                    embed_dim,
                    tokenizer,
                    language=SOURCE_LANGUAGE,
                    dscd_module=self.dscd,
                )
            except Exception as e:
                print(f"[TATN-INIT] WARNING: CompleteTRGWithExplanations init failed ({e}); using no-op stub")
        elif DEBUG_DISCOVERY or VERBOSE_LOGGING:
            print("[TATN-INIT] CompleteTRGWithExplanations not found; using no-op stub")
        self.trg_system = trg_system if trg_system is not None else _StubTRG()

        if DEBUG_DISCOVERY or VERBOSE_LOGGING:
            print("=" * 80)
            print("[TATN-INIT] Initialized MemoryOptimizedTATNWithExplanations:")
            print(f"  - Embed dim: {embed_dim}")
            print(f"  - Discovery frequency: {PERIODIC_DISCOVERY_FREQUENCY}")
            print(f"  - Validation interval: {VALIDATION_CHECK_INTERVAL}")
            print(f"  - Lambda ASBN: {LAMBDA_ASBN}")
            print(f"  - Lambda DSCD: {LAMBDA_DSCD}")
            print(f"  - DSCD nmin: {max(3, DSCD_N_MIN)} (enforced)")
            print(f"  - DSCD dispersion: {min(0.08, DSCD_DISPERSION_THRESHOLD)} (enforced)")
            print("=" * 80)
    
    @staticmethod
    def _entropy_reg_from_proto_probs_static(proto_probs_list, gates_list=None, min_gate: float = 0.0) -> torch.Tensor:
        if not proto_probs_list or not isinstance(proto_probs_list, list):
            return torch.tensor(0.0)
        
        dev = None
        for row in proto_probs_list:
            if isinstance(row, list):
                for p in row:
                    if isinstance(p, torch.Tensor):
                        dev = p.device
                        break
            if dev is not None:
                break
        
        if dev is None:
            return torch.tensor(0.0)
        
        total = torch.tensor(0.0, device=dev)
        count = 0
        
        for b, row in enumerate(proto_probs_list):
            if not isinstance(row, list):
                continue
            gl = gates_list[b] if (isinstance(gates_list, list) and b < len(gates_list)) else None
            for j, probs in enumerate(row):
                if not isinstance(probs, torch.Tensor) or probs.numel() == 0:
                    continue
                
                if isinstance(gl, list) and j < len(gl):
                    try:
                        gv = gl[j]
                        if isinstance(gv, torch.Tensor):
                            gv = float(gv.detach().float().mean().item())
                        else:
                            gv = float(gv)
                        if gv < min_gate:
                            continue
                    except Exception:
                        pass
                
                try:
                    p = torch.clamp(probs.to(dev).float().flatten(), 1e-8, 1.0)
                    H = -torch.sum(p * torch.log(p))
                    if torch.isfinite(H):
                        total = total + H
                        count += 1
                except Exception:
                    continue
        
        if count == 0:
            return torch.tensor(0.0, device=dev)
        return total / count
    
    def _reconstruct_word_maps_before_dscd(
        self,
        input_ids: torch.Tensor,
        batch_size: int,
        seq_len: int,
        src_texts: Optional[List[str]] = None,
        token_word_map: Optional[List[dict]] = None,
    ) -> Tuple[List[Dict[int, int]], List[List[str]]]:
        token_to_word_maps_batch: List[Dict[int, int]] = []
        word_lists_batch: List[List[str]] = []
        
        if not _has_reconstruct_word_spans:
            if DEBUG_DISCOVERY:
                print("[TATN-WORDMAP] reconstruct_word_spans() not available - using fallback")
            for b in range(batch_size):
                try:
                    ids_b = input_ids[b].detach().cpu().tolist()
                    tokens = self.tokenizer.convert_ids_to_tokens(ids_b)
                    token_to_word_map: Dict[int, int] = {}
                    word_list: List[str] = []
                    current_word = ""
                    word_idx = 0
                    
                    for i, tok in enumerate(tokens):
                        clean = (
                            tok.replace("▁", "")
                            .replace("Ġ", "")
                            .replace("##", "")
                            .replace("@@", "")
                            .strip()
                            .lower()
                        )
                        
                        if tok.startswith("▁") or tok.startswith("Ġ"):
                            if current_word:
                                word_list.append(current_word)
                                word_idx += 1
                            current_word = clean
                            token_to_word_map[i] = word_idx
                        else:
                            current_word += clean
                            token_to_word_map[i] = word_idx
                    
                    if current_word:
                        word_list.append(current_word)
                    
                    if word_list:
                        token_to_word_maps_batch.append(token_to_word_map)
                        word_lists_batch.append(word_list)
                    else:
                        token_to_word_maps_batch.append({i: 0 for i in range(min(5, seq_len))})
                        word_lists_batch.append([f"tok{i}" for i in range(min(5, seq_len))])
                except Exception:
                    token_to_word_maps_batch.append({i: 0 for i in range(min(5, seq_len))})
                    word_lists_batch.append([f"tok{i}" for i in range(min(5, seq_len))])
            return token_to_word_maps_batch, word_lists_batch
        
        if DEBUG_DISCOVERY:
            print(f"[TATN-WORDMAP] Reconstructing word maps for {batch_size} samples...")
        
        for b in range(batch_size):
            try:
                if src_texts and b < len(src_texts) and isinstance(src_texts[b], str) and src_texts[b].strip():
                    orig_text = src_texts[b]
                else:
                    try:
                        orig_text = self.tokenizer.decode(input_ids[b], skip_special_tokens=True)
                    except Exception:
                        orig_text = ""
                
                if not orig_text.strip():
                    token_to_word_maps_batch.append({i: 0 for i in range(min(5, seq_len))})
                    word_lists_batch.append([f"tok{i}" for i in range(min(5, seq_len))])
                    continue
                
                wm, words = reconstruct_word_spans(self.tokenizer, orig_text, max_length=seq_len)
                
                if not isinstance(wm, dict):
                    wm = {}
                if not isinstance(words, list):
                    words = []
                
                token_to_word_map: Dict[int, int] = {}
                word_list: List[str] = []
                
                for word in words:
                    if isinstance(word, str) and word.strip():
                        clean_word = unicodedata.normalize("NFKC", word.strip().lower())
                        if clean_word:
                            word_list.append(clean_word)
                
                for token_idx, word_str in wm.items():
                    if isinstance(word_str, str) and word_str.strip():
                        clean_word = unicodedata.normalize("NFKC", word_str.strip().lower())
                        if clean_word and clean_word in word_list:
                            word_idx = word_list.index(clean_word)
                            token_to_word_map[token_idx] = word_idx
                
                if word_list and token_to_word_map:
                    token_to_word_maps_batch.append(token_to_word_map)
                    word_lists_batch.append(word_list)
                else:
                    token_to_word_maps_batch.append({i: 0 for i in range(min(5, seq_len))})
                    word_lists_batch.append([f"tok{i}" for i in range(min(5, seq_len))])
                
                if DEBUG_DISCOVERY and b == 0:
                    print(f"[TATN-WORDMAP] Sample 0: {len(word_list)} words, {len(token_to_word_map)} mapped tokens")
            
            except Exception as e:
                if DEBUG_DISCOVERY:
                    print(f"[TATN-WORDMAP] Reconstruction failed for sample {b}: {e}")
                token_to_word_maps_batch.append({i: 0 for i in range(min(5, seq_len))})
                word_lists_batch.append([f"tok{i}" for i in range(min(5, seq_len))])
        
        total_words = sum(len(wl) for wl in word_lists_batch)
        if DEBUG_DISCOVERY:
            print(f"[TATN-WORDMAP] Reconstructed {total_words} words across {batch_size} samples")
        
        return token_to_word_maps_batch, word_lists_batch
    
    def _extract_domain_labels(
        self,
        batch_size: int,
        device: torch.device,
        src_texts: Optional[List[str]] = None,
    ) -> Optional[torch.Tensor]:
        """Build a per-sample domain-id vector for the current mode.

        Args:
            batch_size: Number of samples; the result has shape ``(batch_size,)``.
            device: Device on which the label tensor is allocated.
            src_texts: Not read by this method; accepted for interface
                compatibility with callers that pass the source sentences.

        Returns:
            A ``torch.long`` tensor filled with ``TRAIN_DOMAIN`` while the
            module is in training mode and ``TEST_DOMAIN`` otherwise, or
            ``None`` when ``USE_DOMAIN_LABELS`` is falsy or the tensor could
            not be created.
        """
        if not USE_DOMAIN_LABELS:
            return None

        try:
            # Every sample in the batch shares one domain id chosen by the mode.
            domain_id = TRAIN_DOMAIN if self.training else TEST_DOMAIN
            return torch.full((batch_size,), domain_id, dtype=torch.long, device=device)
        except Exception:
            return None
    
    @staticmethod
    def _safe_take_key_static(
        dscd_struct: Dict[str, Any],
        key: str,
        b_index: int,
        seq_len: int,
        device: torch.device,
    ):
        if key == "proto_probs":
            out = [torch.tensor([1.0], dtype=torch.float32, device=device) for _ in range(seq_len)]
        else:
            out = [torch.tensor(0.0, dtype=torch.float32, device=device) for _ in range(seq_len)]
        
        try:
            val = dscd_struct.get(key, None)
            if val is None:
                return out
            
            if key == "proto_probs":
                if isinstance(val, list) and len(val) > b_index:
                    row = val[b_index]
                    if isinstance(row, list):
                        for t in range(min(seq_len, len(row))):
                            v = row[t]
                            if isinstance(v, torch.Tensor):
                                vv = v.to(device).float().flatten()
                                out[t] = vv if vv.numel() > 0 else torch.tensor([1.0], device=device, dtype=torch.float32)
                            else:
                                try:
                                    vv = torch.as_tensor(v, dtype=torch.float32, device=device).flatten()
                                    out[t] = vv if vv.numel() > 0 else torch.tensor([1.0], device=device, dtype=torch.float32)
                                except Exception:
                                    pass
                return out
            
            if isinstance(val, list) and len(val) > b_index:
                row = val[b_index]
                if isinstance(row, list):
                    for t in range(min(seq_len, len(row))):
                        v = row[t]
                        try:
                            if isinstance(v, torch.Tensor):
                                out[t] = torch.tensor(float(v.detach().float().mean().item()), device=device)
                            else:
                                out[t] = torch.tensor(float(v), device=device)
                        except Exception:
                            pass
                elif isinstance(row, torch.Tensor):
                    if row.dim() == 1:
                        for t in range(min(seq_len, int(row.size(0)))):
                            try:
                                out[t] = torch.tensor(float(row[t].item()), device=device)
                            except Exception:
                                pass
                return out
            
            if isinstance(val, torch.Tensor):
                if val.dim() >= 2 and int(val.size(0)) > b_index:
                    for t in range(min(seq_len, int(val.size(1)))):
                        try:
                            v = val[b_index, t]
                            if isinstance(v, torch.Tensor):
                                if v.numel() == 1:
                                    out[t] = torch.tensor(float(v.item()), device=device)
                                else:
                                    out[t] = v.to(device).float().mean()
                            else:
                                out[t] = torch.tensor(float(v), device=device)
                        except Exception:
                            pass
        except Exception:
            pass
        
        return out
    
    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        src_texts: Optional[List[str]] = None,
        token_word_map: Optional[List[dict]] = None,
        labels: Optional[torch.Tensor] = None,
        use_dscd: bool = True,
        use_asbn: bool = True,
        track_stats: bool = False,
        fast_inference: bool = False,
    ):
        """Run encoder -> word-level DSCD -> ASBN -> decoder/explanations.

        Args:
            input_ids: 2-D tensor of source token ids ``(batch, seq_len)``.
            attention_mask: 2-D mask matching ``input_ids``.
            src_texts: Optional raw source sentences, forwarded to the
                word-map reconstruction and domain-label helpers.
            token_word_map: Optional precomputed maps, forwarded to the
                word-map reconstruction helper.
            labels: Target ids. When given while the module is in training
                mode, the method computes and returns the loss.
            use_dscd: When False, DSCD is bypassed and neutral outputs are used.
            use_asbn: When False, both ASBN normalisation and the ASBN loss
                are skipped.
            track_stats: Not read by this method; kept for interface
                compatibility.
            fast_inference: When True, DSCD is bypassed exactly as with
                ``use_dscd=False``.

        Returns:
            In training mode with labels: the scalar total loss
            ``translation + LAMBDA_ASBN * asbn + LAMBDA_DSCD * dscd_reg``.
            Otherwise: a dict with keys ``encoder_outputs``, ``dscd_outputs``,
            ``sense_augmented_embeddings``, ``explanations``, ``asbn_loss`` and
            ``ambiguity_signals``.

        Raises:
            ValueError: If ``input_ids``/``attention_mask`` is None or not 2-D.
        """
        # The step counter is advanced on every call, training or not.
        with self._step_lock:
            self.global_step += 1
            current_step = self.global_step

        if input_ids is None or attention_mask is None:
            raise ValueError("input_ids and attention_mask cannot be None")
        if input_ids.dim() != 2 or attention_mask.dim() != 2:
            raise ValueError(f"Expected 2D tensors, got {input_ids.shape}, {attention_mask.shape}")

        batch_size, seq_len = int(input_ids.size(0)), int(input_ids.size(1))
        device = input_ids.device

        def _passthrough_dscd(hidden: torch.Tensor):
            """Neutral DSCD result: unchanged hidden states, one prototype, zero signals."""

            def _zero_signal():
                return [
                    [torch.tensor(0.0, device=device) for _ in range(seq_len)]
                    for _ in range(batch_size)
                ]

            return {
                "haugmented": hidden.detach().clone(),
                "protoprobs": [
                    [torch.tensor([1.0], dtype=torch.float32, device=device) for _ in range(seq_len)]
                    for _ in range(batch_size)
                ],
                "uncertainties": _zero_signal(),
                "gates": _zero_signal(),
                "spanpreds": _zero_signal(),
                "protoassignments": [
                    torch.zeros(seq_len, dtype=torch.long, device=device) for _ in range(batch_size)
                ],
            }

        # ---- Periodic CUDA cache / GC cleanup --------------------------------
        if (
            torch.cuda.is_available()
            and MEMORY_CLEANUP_FREQUENCY > 0
            and current_step % MEMORY_CLEANUP_FREQUENCY == 0
        ):
            for i in range(min(NUM_GPUS, torch.cuda.device_count())):
                try:
                    with torch.cuda.device(i):
                        torch.cuda.empty_cache()
                except Exception:
                    pass
            if gc.isenabled():
                gc.collect()

        # ---- Periodic prototype discovery (training only) --------------------
        if self.training and DSCD_ENABLE_TRAINING_CLUSTERING and use_dscd:
            if (current_step - self.last_discovery_step) >= PERIODIC_DISCOVERY_FREQUENCY:
                try:
                    if DEBUG_DISCOVERY or VERBOSE_LOGGING:
                        print("\n" + "=" * 80)
                        print(f"[TATN] PERIODIC DISCOVERY @ step {current_step}")
                        print("=" * 80)

                    start_time = time.time()
                    if hasattr(self.dscd, 'periodic_discovery_check'):
                        self.dscd.periodic_discovery_check(current_step, PERIODIC_DISCOVERY_FREQUENCY)
                    elapsed = time.time() - start_time
                    self.last_discovery_step = current_step

                    if DEBUG_DISCOVERY or VERBOSE_LOGGING:
                        num_tokens = len(self.dscd.prototypestores) if hasattr(self.dscd, 'prototypestores') else 0
                        total_protos = sum(s.size() for s in self.dscd.prototypestores.values()) if hasattr(self.dscd, 'prototypestores') else 0
                        num_homographs = len(self.dscd.discoveredhomographs) if hasattr(self.dscd, 'discoveredhomographs') else 0

                        print(f"[TATN] Discovery completed in {elapsed:.2f}s")
                        print(f"[TATN]   Tokens: {num_tokens}")
                        print(f"[TATN]   Prototypes: {total_protos}")
                        print(f"[TATN]   Homographs: {num_homographs}")
                        print("=" * 80 + "\n")
                except Exception as e:
                    if DEBUG_DISCOVERY or VERBOSE_LOGGING:
                        print(f"[TATN] Discovery failed: {e}")
                        try:
                            traceback.print_exc()
                        except Exception:
                            pass

        # ---- Periodic validation-time status report --------------------------
        if not self.training and VALIDATION_CHECK_INTERVAL > 0:
            if (current_step - self.last_validation_step) >= VALIDATION_CHECK_INTERVAL:
                try:
                    if DEBUG_DISCOVERY:
                        print(f"\n[TATN-VALIDATION] Step {current_step}")
                        num_tokens = len(self.dscd.prototypestores) if hasattr(self.dscd, 'prototypestores') else 0
                        total_protos = sum(s.size() for s in self.dscd.prototypestores.values()) if hasattr(self.dscd, 'prototypestores') else 0
                        num_homographs = len(self.dscd.discoveredhomographs) if hasattr(self.dscd, 'discoveredhomographs') else 0
                        print(f"  - Tokens: {num_tokens}")
                        print(f"  - Prototypes: {total_protos}")
                        print(f"  - Homographs: {num_homographs}")
                    self.last_validation_step = current_step
                except Exception:
                    pass

        # ---- Encoder: two access paths, then embedding / zero fallbacks ------
        enc_outputs = None
        try:
            enc_outputs = self.mbart.model.encoder(input_ids=input_ids, attention_mask=attention_mask)
        except Exception:
            try:
                enc_outputs = self.mbart.get_encoder()(input_ids=input_ids, attention_mask=attention_mask)
            except Exception as e:
                if DEBUG_DISCOVERY:
                    print(f"[TATN] Encoder failed: {e}")
                enc_outputs = None

        h = _safe_get_last_hidden_state(enc_outputs)
        if h is None:
            try:
                emb = self.mbart.get_input_embeddings()(input_ids).to(device)
                h = emb
            except Exception:
                h = torch.zeros(
                    batch_size,
                    seq_len,
                    int(getattr(self.mbart.config, "d_model", 1024)),
                    device=device,
                )

        embed_dim = int(h.size(-1))
        training_mode = labels is not None and self.training

        token_to_word_maps_batch, word_lists_batch = self._reconstruct_word_maps_before_dscd(
            input_ids, batch_size, seq_len, src_texts, token_word_map
        )

        domain_labels = self._extract_domain_labels(batch_size, device, src_texts)

        # ---- DSCD on word-level embeddings, broadcast back to subwords -------
        if use_dscd and not fast_inference:
            try:
                if DEBUG_DISCOVERY:
                    print(f"[TATN-DSCD] Processing {batch_size} samples with word-level aggregation...")

                batch_word_embeddings = []
                for b in range(batch_size):
                    token_to_word_map_b = token_to_word_maps_batch[b]
                    num_words = len(word_lists_batch[b])
                    word_emb = aggregate_subwords_to_words(h[b], token_to_word_map_b, num_words, device)
                    batch_word_embeddings.append(word_emb)

                # Pad every sample to the longest word sequence in the batch.
                max_word_len = max(w.size(0) for w in batch_word_embeddings)
                h_words_padded = torch.zeros(batch_size, max_word_len, embed_dim, device=device)
                for b, w_emb in enumerate(batch_word_embeddings):
                    h_words_padded[b, :w_emb.size(0)] = w_emb

                tokentypes_batch = word_lists_batch

                raw_dscd = self.dscd.forward(
                    tokenembeddings=h_words_padded,
                    tokentypes=tokentypes_batch,
                    trainmode=self.training,
                    tokenwordmap=None,
                    inputids=None,
                    attentionmask=None,
                    fast_inference=False,
                )

                dscd_subword_batch = []
                for b in range(batch_size):
                    word_level_out = {
                        "haugmented": raw_dscd.get("haugmented", [])[b] if isinstance(raw_dscd.get("haugmented"), (list, tuple)) else (raw_dscd.get("haugmented")[b] if isinstance(raw_dscd.get("haugmented"), torch.Tensor) else []),
                        "protoprobs": raw_dscd.get("protoprobs", [])[b] if isinstance(raw_dscd.get("protoprobs"), list) else [],
                        "uncertainties": raw_dscd.get("uncertainties", [])[b] if isinstance(raw_dscd.get("uncertainties"), list) else [],
                        "gates": raw_dscd.get("gates", [])[b] if isinstance(raw_dscd.get("gates"), list) else [],
                        "spanpreds": raw_dscd.get("spanpreds", [])[b] if isinstance(raw_dscd.get("spanpreds"), list) else [],
                        "protoassignments": raw_dscd.get("protoassignments", [])[b] if isinstance(raw_dscd.get("protoassignments"), list) else torch.zeros(0),
                        "embed_dim": embed_dim,
                    }

                    subword_out = broadcast_word_to_subword(
                        word_level_out,
                        token_to_word_maps_batch[b],
                        seq_len,
                        device
                    )
                    dscd_subword_batch.append(subword_out)

                combined_raw_dscd = {
                    "haugmented": [],
                    "protoprobs": [],
                    "uncertainties": [],
                    "gates": [],
                    "spanpreds": [],
                    "protoassignments": [],
                }

                for b in range(batch_size):
                    sample_aug = dscd_subword_batch[b]["haugmented"]
                    if isinstance(sample_aug, list):
                        combined_raw_dscd["haugmented"].append(torch.stack(sample_aug))
                    elif isinstance(sample_aug, torch.Tensor) and tuple(sample_aug.shape) == (seq_len, embed_dim):
                        # Already a (seq_len, embed_dim) tensor: use it as-is.
                        combined_raw_dscd["haugmented"].append(sample_aug)
                    else:
                        # Fall back to the un-augmented encoder states for this
                        # sample (not zeros), matching the other fallbacks here.
                        combined_raw_dscd["haugmented"].append(h[b])

                    combined_raw_dscd["protoprobs"].append(dscd_subword_batch[b]["protoprobs"])
                    combined_raw_dscd["uncertainties"].append(dscd_subword_batch[b]["uncertainties"])
                    combined_raw_dscd["gates"].append(dscd_subword_batch[b]["gates"])
                    combined_raw_dscd["spanpreds"].append(dscd_subword_batch[b]["spanpreds"])
                    combined_raw_dscd["protoassignments"].append(dscd_subword_batch[b]["protoassignments"])

                try:
                    combined_raw_dscd["haugmented"] = torch.stack(combined_raw_dscd["haugmented"])
                except Exception:
                    combined_raw_dscd["haugmented"] = h.clone()

                raw_dscd = combined_raw_dscd

                if DEBUG_DISCOVERY:
                    print(f"[TATN-DSCD] Word-level processing complete")

            except Exception as e:
                if DEBUG_DISCOVERY:
                    print(f"[TATN] DSCD word-level processing failed: {e}")
                    # Guarded like the other call sites so a logging problem
                    # cannot escape from inside this exception handler.
                    try:
                        traceback.print_exc()
                    except Exception:
                        pass
                raw_dscd = _passthrough_dscd(h)
        else:
            raw_dscd = _passthrough_dscd(h)

        dscd = _normalize_dscd_outputs(raw_dscd, batch_size, seq_len, device, embed_dim)

        h_aug = dscd.get("h_augmented", h)

        if not isinstance(h_aug, torch.Tensor) or h_aug.shape != h.shape:
            if DEBUG_DISCOVERY:
                print(
                    f"[TATN] h_augmented shape mismatch (expected {h.shape}, got {getattr(h_aug, 'shape', None)})"
                )
            h_aug = h

        # ---- ASBN normalisation of the (augmented) encoder states ------------
        if use_asbn and domain_labels is not None:
            try:
                h_aug, _ = self.asbn.forward(h_aug, domain_labels=domain_labels)
            except Exception as e:
                if DEBUG_DISCOVERY:
                    print(f"[TATN] ASBN forward (BN) failed: {e}")

        try:
            enc_for_decoder = BaseModelOutput(
                last_hidden_state=h_aug,
                hidden_states=(getattr(enc_outputs, "hidden_states", None) if enc_outputs else None),
                attentions=(getattr(enc_outputs, "attentions", None) if enc_outputs else None),
            )
        except Exception:
            enc_for_decoder = (h_aug,)

        # ---- Training path: returns the scalar loss --------------------------
        if training_mode:
            try:
                if labels is not None:
                    # Build decoder inputs: ignore-index -> pad, then shift right
                    # behind the configured decoder start token.
                    decoder_input_ids = labels.clone()
                    decoder_input_ids = torch.where(
                        decoder_input_ids == -100,
                        torch.full_like(decoder_input_ids, self.tokenizer.pad_token_id),
                        decoder_input_ids,
                    )

                    bos_column = torch.full(
                        (batch_size, 1),
                        int(self.mbart.config.decoder_start_token_id),
                        dtype=torch.long,
                        device=device,
                    )
                    decoder_input_ids = torch.cat([bos_column, decoder_input_ids[:, :-1]], dim=1)
                    decoder_attention_mask = (decoder_input_ids != self.tokenizer.pad_token_id).long()
                else:
                    decoder_input_ids = None
                    decoder_attention_mask = None

                seq_outputs = self.mbart(
                    input_ids=None,
                    attention_mask=attention_mask,
                    encoder_outputs=enc_for_decoder,
                    decoder_input_ids=decoder_input_ids,
                    decoder_attention_mask=decoder_attention_mask,
                    labels=labels,
                    use_cache=False,
                    return_dict=True,
                )
                translation_loss = getattr(seq_outputs, "loss", None)
                if translation_loss is None:
                    translation_loss = torch.tensor(0.0, device=device)
            except Exception as e:
                if DEBUG_DISCOVERY:
                    print(f"[TATN] Decoder forward failed: {e}")
                    try:
                        traceback.print_exc()
                    except Exception:
                        pass
                translation_loss = torch.tensor(0.0, device=device)

            if use_asbn:
                try:
                    asbn_ret = self.asbn.forward_with_grl_simplified(
                        h_aug,
                        dscd.get("proto_probs", None),
                        dscd.get("uncertainties", None),
                        dscd.get("gates", None),
                        token_word_map=token_to_word_maps_batch,
                        domain_labels=domain_labels,
                        global_step=current_step,
                    )
                    if isinstance(asbn_ret, (tuple, list)):
                        asbn_loss = asbn_ret[0]
                    else:
                        asbn_loss = asbn_ret

                    if not isinstance(asbn_loss, torch.Tensor):
                        asbn_loss = torch.tensor(float(asbn_loss), device=device)
                    else:
                        asbn_loss = asbn_loss.to(device)

                    if not torch.isfinite(asbn_loss):
                        asbn_loss = torch.tensor(0.0, device=device)
                    asbn_loss = torch.clamp(asbn_loss, 0.0, 10.0)
                except Exception as e:
                    if DEBUG_DISCOVERY:
                        print(f"[TATN] ASBN forward failed: {e}")
                    asbn_loss = torch.tensor(0.0, device=device)
            else:
                asbn_loss = torch.tensor(0.0, device=device)

            try:
                dscd_reg = self._entropy_reg_from_proto_probs_static(
                    dscd.get("proto_probs", []),
                    gates_list=dscd.get("gates", []),
                    min_gate=0.0,
                )
                if not isinstance(dscd_reg, torch.Tensor):
                    dscd_reg = torch.tensor(float(dscd_reg), device=device)
                else:
                    dscd_reg = dscd_reg.to(device)
                if not torch.isfinite(dscd_reg):
                    dscd_reg = torch.tensor(0.0, device=device)
            except Exception as e:
                if DEBUG_DISCOVERY:
                    print(f"[TATN] DSCD reg failed: {e}")
                dscd_reg = torch.tensor(0.0, device=device)

            total_loss = translation_loss + LAMBDA_ASBN * asbn_loss + LAMBDA_DSCD * dscd_reg
            if not isinstance(total_loss, torch.Tensor):
                total_loss = torch.tensor(float(total_loss), device=device)
            if total_loss.numel() != 1:
                total_loss = total_loss.mean()

            if not torch.isfinite(total_loss):
                if DEBUG_DISCOVERY:
                    print("[TATN] NaN/Inf detected in total_loss - using translation_loss only")
                total_loss = translation_loss if torch.isfinite(translation_loss) else torch.tensor(1.0, device=device)

            try:
                del enc_outputs, h, raw_dscd
            except Exception:
                pass
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            return total_loss

        # ---- Inference path: optional TRG explanations -----------------------
        explanations_list: List[List[Dict[str, Any]]] = []

        if (not self.training) and ENABLE_TRG_INFERENCE:
            if DEBUG_DISCOVERY:
                print(f"\n[TATN-INFERENCE] Starting TRG for {batch_size} samples")

            tokens_batch: List[List[str]] = []

            for b in range(batch_size):
                try:
                    ids_b = input_ids[b].detach().cpu().tolist()
                    if hasattr(self.tokenizer, "convert_ids_to_tokens"):
                        toks = self.tokenizer.convert_ids_to_tokens(ids_b)
                    else:
                        toks = []
                    # Force exactly seq_len token strings per sample.
                    if not toks:
                        toks = ["UNK"] * seq_len
                    elif len(toks) < seq_len:
                        toks = toks + [""] * (seq_len - len(toks))
                    elif len(toks) > seq_len:
                        toks = toks[:seq_len]
                except Exception:
                    toks = ["UNK"] * seq_len
                tokens_batch.append(toks)

            # No decoder pass happens on this path, so no attention is available.
            decoder_attention = None

            try:
                total_explanations = 0

                for b in range(batch_size):
                    per_sent = {
                        "proto_probs": self._safe_take_key_static(dscd, "proto_probs", b, seq_len, device),
                        "uncertainties": self._safe_take_key_static(dscd, "uncertainties", b, seq_len, device),
                        "gates": self._safe_take_key_static(dscd, "gates", b, seq_len, device),
                        "span_preds": self._safe_take_key_static(dscd, "span_preds", b, seq_len, device),
                    }

                    try:
                        # TRG receives token index -> word string, not word index.
                        word_map_for_trg = {}
                        for tok_idx, word_idx in token_to_word_maps_batch[b].items():
                            if word_idx < len(word_lists_batch[b]):
                                word_map_for_trg[tok_idx] = word_lists_batch[b][word_idx]

                        exps = self.trg_system.process_sentence_for_explanations(
                            tokens_batch[b],
                            per_sent,
                            token_word_map=word_map_for_trg,
                            uncertainty_threshold=TRG_UNCERTAINTY_THRESHOLD,
                            decoder_attention=decoder_attention,
                        )
                        batch_exps = exps if isinstance(exps, list) else []
                        explanations_list.append(batch_exps)
                        total_explanations += len(batch_exps)

                        if DEBUG_DISCOVERY and b < 2:
                            print(f"[TATN-INFERENCE] Sample {b}: {len(batch_exps)} explanations")

                    except Exception as e:
                        if DEBUG_DISCOVERY:
                            print(f"[TATN-INFERENCE] TRG failed for sample {b}: {e}")
                        explanations_list.append([])

                if DEBUG_DISCOVERY:
                    print(f"\n[TATN-INFERENCE] Total explanations: {total_explanations}")
                    if total_explanations == 0:
                        print("[TATN-INFERENCE] NO EXPLANATIONS GENERATED")

            except Exception as e:
                if DEBUG_DISCOVERY:
                    print(f"[TATN-INFERENCE] TRG generation failed: {e}")
                    try:
                        traceback.print_exc()
                    except Exception:
                        pass
                explanations_list = [[] for _ in range(batch_size)]
        else:
            explanations_list = [[] for _ in range(batch_size)]

        outputs = {
            "encoder_outputs": enc_outputs,
            "dscd_outputs": dscd,
            "sense_augmented_embeddings": h_aug,
            "explanations": explanations_list,
            "asbn_loss": torch.tensor(0.0, device=device),
            "ambiguity_signals": {
                "span": dscd.get("span_preds", []),
                "uncertainty": dscd.get("uncertainties", []),
                # confidence = 1 - uncertainty; unreadable entries count as uncertainty 1.0
                "confidence": [
                    [
                        1.0
                        - (
                            float(u)
                            if isinstance(u, (float, int))
                            else (float(u.item()) if isinstance(u, torch.Tensor) else 1.0)
                        )
                        for u in row
                    ]
                    for row in dscd.get("uncertainties", [])
                ],
                "proto_probs": dscd.get("proto_probs", []),
            },
        }

        try:
            del h, raw_dscd
        except Exception:
            pass
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return outputs
    
    def generate(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        max_length: int = 128,
        num_beams: int = 5,
        early_stopping: bool = True,
        **kwargs,
    ) -> torch.Tensor:
        """Generate output token ids from the plain encoder states.

        The source is encoded with ``self.mbart.model.encoder`` and decoded
        with ``self.mbart.generate``; the DSCD/ASBN stages used in
        ``forward`` are not applied here.

        Args:
            input_ids: Source token ids.
            attention_mask: Attention mask matching ``input_ids``.
            max_length: Maximum generated length.
            num_beams: Beam width.
            early_stopping: Passed through to ``self.mbart.generate``.
            **kwargs: Extra generation options. They take precedence over
                this method's defaults (``repetition_penalty=1.2``,
                ``no_repeat_ngram_size=3``, ``length_penalty=1.0``,
                ``do_sample=False`` and the config's ``forced_bos_token_id``)
                instead of colliding with them.

        Returns:
            Whatever ``self.mbart.generate`` returns.

        Raises:
            Exception: Any failure is logged when ``DEBUG_DISCOVERY`` is set
                and then re-raised unchanged.
        """
        try:
            # Only token ids leave this method, so the encoder pass needs no
            # autograd graph.
            with torch.no_grad():
                enc_outputs = self.mbart.model.encoder(input_ids=input_ids, attention_mask=attention_mask)

            enc_wrapped = BaseModelOutput(
                last_hidden_state=(enc_outputs.last_hidden_state if hasattr(enc_outputs, "last_hidden_state") else enc_outputs[0]),
                hidden_states=getattr(enc_outputs, "hidden_states", None),
                attentions=getattr(enc_outputs, "attentions", None),
            )

            gen_kwargs = {
                "max_length": max_length,
                "num_beams": num_beams,
                "early_stopping": early_stopping,
                "repetition_penalty": 1.2,
                "no_repeat_ngram_size": 3,
                "length_penalty": 1.0,
                "do_sample": False,
            }
            if "forced_bos_token_id" not in kwargs:
                # Leave the option unset when the config has no value, rather
                # than crashing on int(None).
                forced_bos = getattr(self.mbart.config, "forced_bos_token_id", None)
                if forced_bos is not None:
                    gen_kwargs["forced_bos_token_id"] = int(forced_bos)
            # Caller-supplied options override the defaults above.
            gen_kwargs.update(kwargs)

            return self.mbart.generate(
                input_ids=None,
                attention_mask=attention_mask,
                encoder_outputs=enc_wrapped,
                **gen_kwargs,
            )
        except Exception as e:
            if DEBUG_DISCOVERY:
                print(f"[TATN-GENERATE] Failed: {e}")
            raise
    
    def forward_with_explanations(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        src_texts: Optional[List[str]] = None,
        token_word_map: Optional[List[dict]] = None,
    ):
        """Call ``forward`` without labels and return its result unchanged.

        Args:
            input_ids: Source token ids, forwarded as-is.
            attention_mask: Attention mask, forwarded as-is.
            src_texts: Optional raw source sentences, forwarded as-is.
            token_word_map: Optional precomputed word maps, forwarded as-is.

        Returns:
            The value returned by ``self.forward`` for these arguments with
            ``labels=None``.
        """
        forward_args = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "src_texts": src_texts,
            "token_word_map": token_word_map,
            # Always label-free: this entry point never requests a loss.
            "labels": None,
        }
        return self.forward(**forward_args)
    
    def get_component_stats(self) -> Dict[str, Any]:
        """Collect step counters and whatever statistics the components expose.

        Returns:
            A dict that always holds ``global_step``, ``last_discovery_step``
            and ``last_validation_step``. ``"dscd"`` is added when
            ``self.dscd`` has ``prototypestores``; ``"asbn"`` and ``"trg"``
            are added when the respective component offers
            ``get_detailed_stats`` / ``get_statistics``. A section whose
            collection raises is left out.
        """
        report: Dict[str, Any] = dict(
            global_step=self.global_step,
            last_discovery_step=self.last_discovery_step,
            last_validation_step=self.last_validation_step,
        )

        # DSCD: counts derived from the prototype stores.
        try:
            dscd = self.dscd
            if hasattr(dscd, 'prototypestores'):
                stores = dscd.prototypestores
                token_count = len(stores)
                proto_count = sum(store.size() for store in stores.values())
                if hasattr(dscd, 'discoveredhomographs'):
                    homograph_count = len(dscd.discoveredhomographs)
                else:
                    homograph_count = 0

                report["dscd"] = {
                    "total_tokens": token_count,
                    "total_prototypes": proto_count,
                    "num_homographs": homograph_count,
                }
        except Exception:
            pass

        # ASBN and TRG: each contributes through its own reporting method.
        reporters = (
            ("asbn", "asbn", "get_detailed_stats"),
            ("trg", "trg_system", "get_statistics"),
        )
        for section, attr_name, method_name in reporters:
            try:
                component = getattr(self, attr_name)
                if hasattr(component, method_name):
                    report[section] = getattr(component, method_name)()
            except Exception:
                pass

        return report

# End-of-cell banner, emitted as one write; the text is unchanged.
print(
    "\n".join(
        [
            "\n" + "=" * 80,
            "Cell 6: MemoryOptimizedTATNWithExplanations Ready",
            "=" * 80,
            "NEW FEATURES:",
            "  ✓ Word-level DSCD aggregation",
            "  ✓ Subword broadcasting",
            "  ✓ Case-insensitive word keys",
            "  ✓ Enforced nmin=3, dispersion=0.08",
            "=" * 80 + "\n",
        ]
    )
)



Cell 6: MemoryOptimizedTATNWithExplanations Ready
NEW FEATURES:
  ✓ Word-level DSCD aggregation
  ✓ Subword broadcasting
  ✓ Case-insensitive word keys
  ✓ Enforced nmin=3, dispersion=0.08



In [10]:
# ==============================================================================
# CELL 7: TRAINING LOOP (PURE UNSUPERVISED) - FIXED
# ==============================================================================

import os
import time
import math
import gc
import traceback
from datetime import datetime
from pathlib import Path
from collections import defaultdict, deque
from typing import Optional, Dict, Any, List
import numpy as np
import torch
from torch.cuda.amp import GradScaler, autocast as cuda_amp_autocast
from tqdm import tqdm
from contextlib import nullcontext
import threading

try:
    _VERBOSE_LOGGING = bool(VERBOSE_LOGGING)
except (NameError, TypeError):
    _VERBOSE_LOGGING = False

try:
    _DEBUG_DISCOVERY = bool(DEBUG_DISCOVERY)
except (NameError, TypeError):
    _DEBUG_DISCOVERY = False

DEBUG_PRINT_INTERVAL = 200
_cell7_dbg_counts = defaultdict(int)

def cell7_dbg(key: str, msg: str, limit: int = 10):
    """Print a rate-limited debug line for this cell.

    Nothing happens unless ``_VERBOSE_LOGGING`` or ``_DEBUG_DISCOVERY`` is set.
    Otherwise the per-``key`` counter is incremented and ``msg`` is printed
    only while that counter has not exceeded ``limit``.

    Args:
        key: Bucket name used to count how often this message was requested.
        msg: Text to print after the ``[CELL7-DBG]`` prefix.
        limit: Maximum number of times a given ``key`` is printed.
    """
    if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
        times_seen = _cell7_dbg_counts[key] + 1
        _cell7_dbg_counts[key] = times_seen
        if times_seen <= limit:
            print(f"[CELL7-DBG] {msg}")

try:
    _DEVICE = DEVICE
except (NameError, TypeError):
    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

try:
    _EPOCHS = int(EPOCHS)
except (NameError, ValueError, TypeError):
    _EPOCHS = 1

try:
    _BATCH_SIZE = int(BATCH_SIZE)
except (NameError, ValueError, TypeError):
    _BATCH_SIZE = 8

try:
    _ACCUMULATION_STEPS = int(ACCUMULATION_STEPS)
except (NameError, ValueError, TypeError):
    _ACCUMULATION_STEPS = 1

try:
    _GRAD_CLIP_NORM = float(GRAD_CLIP_NORM)
except (NameError, ValueError, TypeError):
    _GRAD_CLIP_NORM = 1.0

try:
    _MEMORY_CLEANUP_FREQUENCY = int(MEMORY_CLEANUP_FREQUENCY)
except (NameError, ValueError, TypeError):
    _MEMORY_CLEANUP_FREQUENCY = 500

try:
    _USE_MULTI_GPU = bool(USE_MULTI_GPU)
    _NUM_GPUS = int(NUM_GPUS)
except (NameError, ValueError, TypeError):
    _NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
    _USE_MULTI_GPU = _NUM_GPUS > 1

try:
    _USE_AMP = bool(USE_AMP)
except (NameError, TypeError):
    _USE_AMP = True

try:
    _SOURCE_LANGUAGE = str(SOURCE_LANGUAGE)
    _TARGET_LANGUAGE = str(TARGET_LANGUAGE)
except (NameError, TypeError):
    _SOURCE_LANGUAGE = "bn"
    _TARGET_LANGUAGE = "en"

try:
    _MAX_LENGTH = int(MAX_LENGTH)
except (NameError, ValueError, TypeError):
    _MAX_LENGTH = 48

try:
    _VALIDATION_CHECK_INTERVAL = int(VALIDATION_CHECK_INTERVAL)
except (NameError, ValueError, TypeError):
    _VALIDATION_CHECK_INTERVAL = 500

try:
    _PERIODIC_DISCOVERY_FREQUENCY = int(PERIODIC_DISCOVERY_FREQUENCY)
except (NameError, ValueError, TypeError):
    _PERIODIC_DISCOVERY_FREQUENCY = 3000

try:
    _TRAIN_DOMAIN = int(TRAIN_DOMAIN)
    _TEST_DOMAIN = int(TEST_DOMAIN)
except (NameError, ValueError, TypeError):
    _TRAIN_DOMAIN = 0
    _TEST_DOMAIN = 1

try:
    _HOMOGRAPH_REFERENCE_LIST = set(str(w).lower() for w in HOMOGRAPH_REFERENCE_LIST_BN)
except (NameError, TypeError):
    _HOMOGRAPH_REFERENCE_LIST = {
        "কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা",
        "পানি", "দল", "বাজার", "নাম", "কথা", "বই", "ঘর", "মন", "হাত"
    }
    _HOMOGRAPH_REFERENCE_LIST = set(str(w).lower() for w in _HOMOGRAPH_REFERENCE_LIST)

def clear_all_gpu_caches():
    """Run the garbage collector, then empty the CUDA cache on every visible GPU.

    Best-effort: without CUDA only ``gc.collect()`` runs, and any error raised
    while switching devices or emptying a cache is ignored.
    """
    gc.collect()
    if torch.cuda.is_available():
        try:
            gpu_total = torch.cuda.device_count()
            for gpu_index in range(gpu_total):
                # empty_cache() is issued with each device made current in turn.
                with torch.cuda.device(gpu_index):
                    try:
                        torch.cuda.empty_cache()
                    except Exception:
                        pass
        except Exception:
            pass

def get_amp_ctx():
    """Return the context manager to wrap a forward pass in.

    Yields a CUDA autocast context when ``_USE_AMP`` is set and CUDA is
    available; in every other case (including a failure to build the autocast
    context) a no-op ``nullcontext()`` is returned.
    """
    amp_wanted = _USE_AMP and torch.cuda.is_available()
    if amp_wanted:
        try:
            return cuda_amp_autocast()
        except Exception:
            pass
    return nullcontext()

_PROTOBUF_COMPAT_ERROR_SHOWN = globals().get("_PROTOBUF_COMPAT_ERROR_SHOWN", False)

def _resolve_dscd_stores(dscd):
    if dscd is None:
        return {}
    for attr_name in ("prototypestores", "prototype_stores"):
        stores = getattr(dscd, attr_name, None)
        if isinstance(stores, dict):
            return stores
        if stores is not None:
            try:
                return dict(stores)
            except Exception:
                pass
    return {}

def _resolve_dscd_lock(dscd):
    if dscd is None:
        return None
    for name in ("bufferlock", "buffer_lock", "clusteringlock", "clustering_lock"):
        lock = getattr(dscd, name, None)
        if lock is not None:
            return lock
    return None

def _get_dscd_homographs(model: torch.nn.Module) -> set:
    try:
        core = model.module if hasattr(model, 'module') else model
        dscd = getattr(core, 'dscd', None)
        if dscd is None:
            return set()

        if hasattr(dscd, 'get_discovered_homographs'):
            return dscd.get_discovered_homographs()

        homographs = set()
        stores = _resolve_dscd_stores(dscd)
        lock = _resolve_dscd_lock(dscd)

        if lock:
            with lock:
                for token, store in stores.items():
                    try:
                        size_val = store.size() if hasattr(store, 'size') else len(getattr(store, 'centroids', []))
                        if size_val >= 2:
                            clean_token = str(token).replace('▁', '').replace('Ġ', '').replace('##', '').strip().lower()
                            homographs.add(clean_token)
                    except Exception:
                        continue
        else:
            for token, store in stores.items():
                try:
                    size_val = store.size() if hasattr(store, 'size') else len(getattr(store, 'centroids', []))
                    if size_val >= 2:
                        clean_token = str(token).replace('▁', '').replace('Ġ', '').replace('##', '').strip().lower()
                        homographs.add(clean_token)
                except Exception:
                    continue

        return homographs
    except Exception:
        return set()

def synchronous_dscd_clustering(model: torch.nn.Module, force: bool = True) -> Dict[str, Any]:
    """Cluster every DSCD token buffer now, then run homograph discovery.

    For each token key in the DSCD prototype stores,
    ``dscd.clusterbuffertoprototypeshierarchical(token)`` is called (under the
    DSCD lock when one exists) and counted as clustered or failed by its
    truthiness. Afterwards ``dscd.discover_homographs()`` is invoked if
    available.

    Args:
        model: The model, or a wrapper exposing it as ``.module``.
        force: Accepted for interface compatibility; not consulted here.

    Returns:
        Dict[str, Any]: ``{'success': False, 'reason': ...}`` when no DSCD is
        attached or an unexpected error occurs; ``{'success': True,
        'clustered': 0, 'reason': 'no_stores'}`` when there is nothing to
        cluster; otherwise ``success``, ``clustered``, ``failed``,
        ``total_tokens``, ``discovered`` and ``total_homographs``.
    """
    try:
        core = model.module if hasattr(model, 'module') else model
        dscd = getattr(core, 'dscd', None)

        if dscd is None:
            return {'success': False, 'reason': 'no_dscd'}

        stores = _resolve_dscd_stores(dscd)
        lock = _resolve_dscd_lock(dscd)

        if not stores:
            return {'success': True, 'clustered': 0, 'reason': 'no_stores'}

        debug_enabled = _DEBUG_DISCOVERY or _VERBOSE_LOGGING
        guard = lock if lock else nullcontext()

        cluster_count = 0
        failed_count = 0

        # Snapshot the keys so clustering may mutate the store mapping safely.
        tokens_to_cluster = list(stores.keys())

        if debug_enabled:
            print(f"[SYNC-CLUSTER] Clustering {len(tokens_to_cluster)} tokens...")

        # Looked up once; when absent every token counts as failed.
        cluster_fn = getattr(dscd, 'clusterbuffertoprototypeshierarchical', None)

        for token in tokens_to_cluster:
            if cluster_fn is None:
                failed_count += 1
                continue
            try:
                with guard:
                    success = cluster_fn(token)
                if success:
                    cluster_count += 1
                else:
                    failed_count += 1
            except Exception:
                failed_count += 1

        discovery_result: Dict[str, Any] = {}
        discover_fn = getattr(dscd, 'discover_homographs', None)
        if discover_fn is not None:
            try:
                with guard:
                    raw_discovery = discover_fn()
                # A None / non-dict return must not turn a successful
                # clustering pass into a reported failure further down.
                if isinstance(raw_discovery, dict):
                    discovery_result = raw_discovery
            except Exception as discovery_error:
                if debug_enabled:
                    print(
                        f"[SYNC-CLUSTER] Discovery failed: "
                        f"{type(discovery_error).__name__}: {str(discovery_error)[:200]}"
                    )

        result = {
            'success': True,
            'clustered': cluster_count,
            'failed': failed_count,
            'total_tokens': len(tokens_to_cluster),
            'discovered': discovery_result.get('discovered', 0),
            'total_homographs': discovery_result.get('total_homographs', 0),
        }

        if debug_enabled:
            print(f"[SYNC-CLUSTER] Results: {cluster_count} clustered, {failed_count} failed")
            print(f"[SYNC-CLUSTER] Discovery: {result.get('discovered', 0)} new homographs")
            print(f"[SYNC-CLUSTER] Total homographs: {result.get('total_homographs', 0)}")

        return result

    except Exception as e:
        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            print(f"[SYNC-CLUSTER] Error: {type(e).__name__}: {str(e)[:200]}")
        return {'success': False, 'reason': str(e)[:200]}

@torch.inference_mode()
def comprehensive_epoch_validation(
    model: torch.nn.Module,
    tokenizer,
    epoch: int,
    global_step: int,
    source_lang: str,
    target_lang: str,
    max_length: int,
    device: torch.device
) -> Dict[str, Any]:
    """Run the end-of-epoch smoke validation and return its metrics.

    Steps, all executed under ``torch.inference_mode()`` with the model in
    eval mode:

    1. Translate ten fixed Bengali sentences with beam search and, when
       ``translate_with_explanations`` is defined in the notebook, collect the
       explanations it returns for each sentence.
    2. Query DSCD (``validate_prototypes``), ASBN (``get_detailed_stats`` or
       ``get_asbn_stats``) and TRG (``get_statistics``) for their statistics.
    3. Print a summary plus health warnings.

    Model state is put back afterwards: the training flag, and every field of
    ``core_model.mbart.config`` that was overridden for generation
    (``use_cache``, ``forced_bos_token_id``, ``decoder_start_token_id``).

    Args:
        model: The model, or a wrapper exposing it as ``.module``.
        tokenizer: Tokenizer used for encoding the sources and decoding outputs.
        epoch: Epoch number, used for reporting only.
        global_step: Global step, used for reporting only.
        source_lang: Language code assigned to ``tokenizer.src_lang``.
        target_lang: Language code used to resolve the forced BOS token id.
        max_length: Truncation length for inputs and ``max_length`` for generation.
        device: Device for the encoded inputs; anything that is not a
            ``torch.device`` is replaced by CUDA-if-available.

    Returns:
        Dict[str, Any]: The metrics dictionary. ``validation_completed`` is
        True only when every stage ran without an unexpected exception.
    """
    global _PROTOBUF_COMPAT_ERROR_SHOWN

    print("\n" + "=" * 80)
    print(f"EPOCH {epoch} COMPREHENSIVE VALIDATION (Step {global_step})")
    print("=" * 80)

    core_model = model.module if hasattr(model, "module") else model
    was_training = core_model.training

    if not isinstance(device, torch.device):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    debug_enabled = _DEBUG_DISCOVERY or _VERBOSE_LOGGING

    dscd_homographs = _get_dscd_homographs(model)
    if debug_enabled:
        print(f"[VALIDATION] DSCD discovered homographs: {len(dscd_homographs)}")
        if dscd_homographs:
            print(f"[VALIDATION] Sample: {list(dscd_homographs)[:10]}")

    validation_results = {
        'epoch': epoch,
        'step': global_step,
        'translations_success': 0,
        'translations_failed': 0,
        'explanations_generated': 0,
        'dscd_homographs_explained': 0,
        'reference_homographs_explained': 0,
        'avg_explanation_confidence': 0.0,
        'dscd_quality_score': 0.0,
        'dscd_multi_sense_tokens': 0,
        'dscd_total_prototypes': 0,
        'asbn_domain_loss': 0.0,
        'asbn_domain_accuracy': 0.0,
        'asbn_source_accuracy': 0.0,
        'asbn_target_accuracy': 0.0,
        'trg_total_explanations': 0,
        'validation_completed': False,
    }

    try:
        core_model.eval()

        # (source sentence, reference translation, note shown in the report)
        val_sentences = [
            ("আমি কল বন্ধ করেছি।", "I turned off the tap", "কল=tap/call"),
            ("কাল আমি বই কিনব।", "Tomorrow I will buy a book", "কাল=tomorrow/yesterday"),
            ("পাতা ঝরে পড়েছে।", "The leaf has fallen", "পাতা=leaf/page"),
            ("তিনি ব্যাংক গেছেন।", "He went to the bank", "ব্যাংক=bank/embankment"),
            ("আমি ভালো আছি।", "I am fine", "No ambiguity"),
            ("সে খুব মিষ্টি কথা বলে।", "She speaks sweetly", "No ambiguity"),
            ("এটা আমার বই।", "This is my book", "No ambiguity"),
            ("আজ আবহাওয়া ভালো।", "Weather is good today", "No ambiguity"),
            ("ফল খুব সুস্বাদু।", "The fruit is delicious", "ফল=fruit/result"),
            ("মাথা ব্যথা করছে।", "Head is aching", "মাথা=head/top"),
        ]

        print(f"\n[VALIDATION] Testing {len(val_sentences)} samples:")
        print("-" * 80)

        confidences = []
        dscd_homograph_words_detected = set()

        try:
            mbart_obj = getattr(core_model, "mbart", None)
        except Exception:
            mbart_obj = None

        # Original value of every config field this function overrides, so the
        # `finally` below can restore the shared config. `config_unset` marks a
        # field that did not exist before.
        config_unset = object()
        saved_config: Dict[str, Any] = {}

        try:
            try:
                tokenizer.src_lang = source_lang
            except Exception:
                pass

            # Resolve the target-language token id that generation is forced to start with.
            forced_id = None
            try:
                if hasattr(tokenizer, "get_lang_id"):
                    for code in (target_lang, "en_XX", "en", "eng"):
                        try:
                            lid = tokenizer.get_lang_id(code)
                            if lid is not None:
                                forced_id = int(lid)
                                break
                        except Exception:
                            continue
                elif hasattr(tokenizer, "lang_code_to_id"):
                    forced_id = tokenizer.lang_code_to_id.get(target_lang, None)
                    if forced_id is not None:
                        forced_id = int(forced_id)
            except Exception:
                forced_id = None

            if forced_id is None:
                try:
                    forced_id = int(globals().get('M2M100_EN_TOKEN_ID', 128022))
                except Exception:
                    forced_id = 128022

            try:
                if mbart_obj is not None and hasattr(mbart_obj.config, "use_cache"):
                    saved_config["use_cache"] = mbart_obj.config.use_cache
                    mbart_obj.config.use_cache = True
            except Exception:
                pass

            for idx, (src, _expected, note) in enumerate(val_sentences, 1):
                try:
                    enc = tokenizer(
                        src,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=max_length,
                        add_special_tokens=True
                    )
                    enc = {
                        k: (
                            v.to(device, non_blocking=True)
                            if isinstance(v, torch.Tensor)
                            else v
                        )
                        for k, v in enc.items()
                    }

                    # Re-applied before every sample (as before), but the
                    # pre-validation values are remembered the first time.
                    if mbart_obj is not None:
                        try:
                            for field in ("forced_bos_token_id", "decoder_start_token_id"):
                                if field not in saved_config:
                                    saved_config[field] = getattr(
                                        mbart_obj.config, field, config_unset
                                    )
                                setattr(mbart_obj.config, field, forced_id)
                        except Exception:
                            pass

                    out_ids = None
                    try:
                        gen_src = mbart_obj if mbart_obj is not None else core_model
                        if hasattr(gen_src, "generate"):
                            # A tokenizer without a pad token reports None; fall back to 1.
                            pad_id = getattr(tokenizer, "pad_token_id", None)
                            out_ids = gen_src.generate(
                                enc.get("input_ids"),
                                attention_mask=enc.get("attention_mask"),
                                max_length=max_length,
                                num_beams=4,
                                do_sample=False,
                                early_stopping=True,
                                pad_token_id=1 if pad_id is None else int(pad_id),
                                forced_bos_token_id=forced_id,
                                repetition_penalty=1.2,
                                no_repeat_ngram_size=3,
                                length_penalty=1.0,
                            )
                    except AttributeError:
                        if not _PROTOBUF_COMPAT_ERROR_SHOWN:
                            print(
                                "[VALIDATION] Warning: generation raised AttributeError (protobuf incompatibility)."
                            )
                            _PROTOBUF_COMPAT_ERROR_SHOWN = True
                        out_ids = None
                    except Exception as e:
                        print(
                            f"[VALIDATION] Generation error: {type(e).__name__}: {str(e)[:200]}"
                        )
                        out_ids = None

                    translation = ""
                    has_output = out_ids is not None and (
                        (isinstance(out_ids, torch.Tensor) and out_ids.numel() > 0)
                        or (isinstance(out_ids, (list, tuple)) and len(out_ids) > 0)
                    )
                    if has_output:
                        try:
                            if isinstance(out_ids, (list, tuple)):
                                translation = tokenizer.batch_decode(
                                    out_ids, skip_special_tokens=True
                                )[0]
                            else:
                                translation = tokenizer.decode(
                                    out_ids[0], skip_special_tokens=True
                                )
                        except AttributeError:
                            if not _PROTOBUF_COMPAT_ERROR_SHOWN:
                                print(
                                    "[VALIDATION] Warning: decode raised AttributeError (protobuf)."
                                )
                                _PROTOBUF_COMPAT_ERROR_SHOWN = True
                            translation = ""
                        except Exception as e:
                            print(
                                f"[VALIDATION] Decode error: {type(e).__name__}: {str(e)[:200]}"
                            )
                            translation = ""

                    if not translation:
                        validation_results['translations_failed'] += 1
                        print(
                            f"  {idx:2d}. Translation failed: {note[:30]:30s}"
                        )
                        continue
                    validation_results['translations_success'] += 1

                    explanation_status = ""
                    try:
                        if 'translate_with_explanations' in globals():
                            res = translate_with_explanations(
                                model, tokenizer, src
                            )
                            exps = res.get('explanations', [])
                            validation_results['explanations_generated'] += len(exps)

                            if exps:
                                explanation_status = f"{len(exps)} expl"
                                for exp in exps:
                                    try:
                                        conf = exp.get('confidence', 0.5)
                                        confidences.append(float(conf))

                                        # Same normalisation as _get_dscd_homographs,
                                        # so the membership test below compares like with like.
                                        clean_word = (
                                            exp.get('ambiguous_word', '')
                                            .replace('▁', '')
                                            .replace('Ġ', '')
                                            .replace('##', '')
                                            .strip()
                                            .lower()
                                        )

                                        if clean_word in dscd_homographs:
                                            validation_results['dscd_homographs_explained'] += 1
                                            dscd_homograph_words_detected.add(clean_word)

                                        if clean_word in _HOMOGRAPH_REFERENCE_LIST:
                                            validation_results['reference_homographs_explained'] += 1
                                    except Exception:
                                        pass
                            else:
                                explanation_status = "no expl"
                        else:
                            explanation_status = "unavailable"
                    except Exception as e:
                        explanation_status = f"error: {type(e).__name__}"

                    print(
                        f"  {idx:2d}. {explanation_status:15s} "
                        f"{note[:30]:30s} -> {translation[:200]}"
                    )
                    # Drop tensor references before the next sample is encoded.
                    del enc, out_ids

                except Exception as e:
                    validation_results['translations_failed'] += 1
                    print(
                        f"  {idx:2d}. ERROR: {note[:30]:30s} -> {type(e).__name__}"
                    )
                    if debug_enabled:
                        try:
                            traceback.print_exc()
                        except Exception:
                            pass

        finally:
            # Undo every generation-time override on the shared model config;
            # leaving them in place would change the config seen by whatever
            # runs after validation (e.g. the next training step).
            if mbart_obj is not None:
                for field, previous in saved_config.items():
                    try:
                        if previous is config_unset:
                            delattr(mbart_obj.config, field)
                        else:
                            setattr(mbart_obj.config, field, previous)
                    except Exception:
                        pass
            if torch.cuda.is_available():
                try:
                    torch.cuda.synchronize()
                except Exception:
                    pass

            clear_all_gpu_caches()

        print("\n" + "-" * 80)
        print("[VALIDATION] DSCD Prototype Quality Check:")
        try:
            dscd = core_model.dscd if hasattr(core_model, 'dscd') else None
            if dscd and hasattr(dscd, 'validate_prototypes'):
                lock = _resolve_dscd_lock(dscd)
                with (lock if lock else nullcontext()):
                    quality_results = dscd.validate_prototypes()

                for result_key, stat_key, fallback in (
                    ('dscd_quality_score', 'quality_score', 0.0),
                    ('dscd_multi_sense_tokens', 'multi_sense_tokens', 0),
                    ('dscd_total_prototypes', 'total_prototypes', 0),
                ):
                    validation_results[result_key] = quality_results.get(stat_key, fallback)
                print(
                    f"  - Quality Score: {validation_results['dscd_quality_score']:.1%}"
                )
                print(
                    f"  - Multi-sense tokens: {validation_results['dscd_multi_sense_tokens']}"
                )
                print(
                    f"  - Total prototypes: {validation_results['dscd_total_prototypes']}"
                )
            else:
                print("  - Validation not available")
        except Exception as e:
            print(f"  - Validation failed: {type(e).__name__}")

        print("\n" + "-" * 80)
        print("[VALIDATION] ASBN Training Statistics:")
        try:
            asbn = core_model.asbn if hasattr(core_model, 'asbn') else None
            # (result key, stats key, printed label, format spec)
            asbn_fields = [
                ('asbn_domain_loss', 'domain_loss', "Domain Loss", ".4f"),
                ('asbn_domain_accuracy', 'domain_accuracy', "Domain Accuracy", ".2%"),
                ('asbn_source_accuracy', 'source_accuracy', "Source Accuracy", ".2%"),
                ('asbn_target_accuracy', 'target_accuracy', "Target Accuracy", ".2%"),
            ]
            if asbn and hasattr(asbn, 'get_detailed_stats'):
                asbn_stats = asbn.get_detailed_stats()
            elif asbn and hasattr(asbn, 'get_asbn_stats'):
                asbn_stats = asbn.get_asbn_stats()
                # The basic getter only reports loss and overall accuracy.
                asbn_fields = asbn_fields[:2]
            else:
                asbn_stats = None
                asbn_fields = None

            if asbn_fields is None:
                print("  - ASBN statistics not available")
            else:
                # Store every value first, then print, so a formatting failure
                # cannot leave later metrics unrecorded.
                for result_key, stat_key, _label, _spec in asbn_fields:
                    validation_results[result_key] = asbn_stats.get(stat_key, 0.0)
                for result_key, _stat_key, label, spec in asbn_fields:
                    print(f"  - {label}: {validation_results[result_key]:{spec}}")
        except Exception as e:
            print(f"  - ASBN stats retrieval failed: {type(e).__name__}")

        print("\n" + "-" * 80)
        print("[VALIDATION] TRG Explanation Statistics:")
        try:
            trg = core_model.trg_system if hasattr(core_model, 'trg_system') else None
            if trg and hasattr(trg, 'get_statistics'):
                trg_stats = trg.get_statistics()
                validation_results['trg_total_explanations'] = trg_stats.get(
                    'explanations_generated', 0
                )
                print(
                    f"  - Total explanations: {validation_results['trg_total_explanations']}"
                )
                print(
                    f"  - High confidence rate: {trg_stats.get('high_confidence_rate', 0):.1%}"
                )
                print(
                    f"  - DSCD homograph rate: {trg_stats.get('dscd_homograph_rate', 0):.1%}"
                )
            else:
                print("  - TRG statistics not available")
        except Exception as e:
            print(f"  - TRG stats retrieval failed: {type(e).__name__}")

        if confidences:
            validation_results['avg_explanation_confidence'] = sum(
                confidences
            ) / len(confidences)

        print("-" * 80)
        print("\n[VALIDATION] Summary:")
        print(
            f"  - Translations: {validation_results['translations_success']}/{len(val_sentences)} successful"
        )
        print(
            f"  - Explanations generated: {validation_results['explanations_generated']}"
        )
        print(
            f"  - Avg explanation confidence: {validation_results['avg_explanation_confidence']:.3f}"
        )
        print(
            f"  - DSCD homographs explained: {validation_results['dscd_homographs_explained']}"
        )
        print(
            f"  - Reference homographs explained: {validation_results['reference_homographs_explained']}"
        )

        if dscd_homograph_words_detected:
            print(
                f"  - DSCD homographs detected: {', '.join(sorted(dscd_homograph_words_detected))}"
            )

        print(
            f"  - DSCD Quality Score: {validation_results['dscd_quality_score']:.1%}"
        )
        print(
            f"  - Multi-sense tokens: {validation_results['dscd_multi_sense_tokens']}"
        )
        print(
            f"  - ASBN Domain Accuracy: {validation_results['asbn_domain_accuracy']:.2%}"
        )

        # Named so it does not shadow the `warnings` module imported by the notebook.
        health_warnings = []
        if validation_results['translations_failed'] > len(val_sentences) // 2:
            health_warnings.append("High translation failure rate")
        if validation_results['explanations_generated'] == 0:
            health_warnings.append("No explanations generated")
        if validation_results['dscd_quality_score'] < 0.3:
            health_warnings.append("Low DSCD quality score")
        if validation_results['dscd_multi_sense_tokens'] < 10:
            health_warnings.append("Very few multi-sense tokens")

        if health_warnings:
            print("\n[VALIDATION] Health Warnings:")
            for w in health_warnings:
                print(f"  - {w}")
        else:
            print("\n[VALIDATION] All systems healthy")

        validation_results['validation_completed'] = True

    except Exception as e:
        print(
            f"\n[VALIDATION] Critical error: {type(e).__name__}: {str(e)[:200]}"
        )
        if debug_enabled:
            try:
                traceback.print_exc()
            except Exception:
                pass
        validation_results['validation_completed'] = False

    finally:
        # Hand the model back in the mode it arrived in.
        if was_training:
            core_model.train()
        clear_all_gpu_caches()

    print("=" * 80 + "\n")
    return validation_results

def _print_gpu_mem(prefix: str = ""):
    if not torch.cuda.is_available():
        return
    try:
        lines = [f"{prefix} GPU mem (GB):"]
        for i in range(torch.cuda.device_count()):
            try:
                alloc = torch.cuda.memory_allocated(i) / (1024**3)
                resv = torch.cuda.memory_reserved(i) / (1024**3)
                lines.append(f"  GPU {i}: alloc={alloc:.2f} resv={resv:.2f}")
            except Exception:
                lines.append(f"  GPU {i}: mem query failed")
        print("\n".join(lines))
    except Exception:
        pass

def _get_cluster_count(model: torch.nn.Module) -> int:
    try:
        core = model
        while hasattr(core, 'module'):
            core = core.module

        dscd = getattr(core, 'dscd', None)
        if dscd is None:
            return 0

        stores = _resolve_dscd_stores(dscd)
        lock = _resolve_dscd_lock(dscd)

        if lock:
            with lock:
                return len(stores)
        else:
            return len(stores)

    except Exception:
        return 0

def _get_dscd_safe(model: torch.nn.Module):
    try:
        core = model
        while hasattr(core, 'module'):
            core = core.module
        return getattr(core, 'dscd', None)
    except Exception:
        return None

def _print_top_clusters(model: torch.nn.Module, top_n: int = 5):
    """Debug report: the largest DSCD prototype stores and the discovered homographs.

    Output is produced only when ``_DEBUG_DISCOVERY`` or ``_VERBOSE_LOGGING``
    is set, so the function returns immediately otherwise (no lock is taken
    and no store is scanned). Stores are ranked by the sum of their ``counts``;
    a store is flagged ``HOMO`` when its cleaned token is in the set returned
    by ``_get_dscd_homographs``.

    Args:
        model: The model, possibly wrapped via ``.module``.
        top_n: Number of top-ranked stores to list.
    """
    if not (_DEBUG_DISCOVERY or _VERBOSE_LOGGING):
        return

    dscd = _get_dscd_safe(model)
    if dscd is None:
        return

    def _clean(token) -> str:
        # Must match the normalisation in _get_dscd_homographs (including the
        # '##' marker), otherwise such tokens can never be flagged.
        return (
            str(token)
            .replace('▁', '')
            .replace('Ġ', '')
            .replace('##', '')
            .strip()
            .lower()
        )

    try:
        dscd_homographs = _get_dscd_homographs(model)
        items = []
        homograph_items = []

        stores = _resolve_dscd_stores(dscd)
        lock = _resolve_dscd_lock(dscd)

        # Copy the items under the lock so the scan below works on a stable view.
        if lock:
            with lock:
                stores_snapshot = list(stores.items())
        else:
            stores_snapshot = list(stores.items())

        buffers = getattr(dscd, 'buffers', None) or {}

        for token, store in stores_snapshot:
            try:
                # int(): the ':4d' format below rejects float sums.
                total_count = int(sum(getattr(store, "counts", []) or []))
                protos = store.size() if hasattr(store, "size") else len(
                    getattr(store, "centroids", [])
                )
                is_homograph = _clean(token) in dscd_homographs
                item = (
                    token,
                    total_count,
                    protos,
                    len(buffers.get(token, [])),
                    is_homograph,
                )
                items.append(item)
                if is_homograph:
                    homograph_items.append(item)
            except Exception:
                continue

        items.sort(key=lambda x: x[1], reverse=True)

        print("[CLUSTER-DBG] Top clusters:")
        for i, (tok, cnt, prot, buflen, is_homo) in enumerate(
            items[:top_n], 1
        ):
            marker = "HOMO" if is_homo else "    "
            print(
                f"{marker} {i:2d}. {str(tok)[:20]:20s} "
                f"samples={cnt:4d} protos={prot} buf={buflen}"
            )
        if homograph_items:
            print(
                f"[CLUSTER-DBG] DSCD-discovered homographs: {len(homograph_items)}"
            )
            for tok, cnt, prot, buflen, _ in homograph_items[:5]:
                print(
                    f"  HOMO {str(tok)[:20]:20s} samples={cnt:4d} protos={prot}"
                )
    except Exception as e:
        print(f"[CLUSTER-DBG] _print_top_clusters error: {type(e).__name__}")

def _check_discovery_status(model: torch.nn.Module, global_step: int):
    """Print a short report on the DSCD discovery log (debug/verbose runs only).

    The model is unwrapped by following `.module` attributes until none is
    left, then its `dscd` attribute is looked up. Nothing happens when there
    is no `dscd`. Output is produced only when `_DEBUG_DISCOVERY` or
    `_VERBOSE_LOGGING` is truthy:

    * a non-empty `dscd.discovered_log` yields the number of logged events
      followed by the `discovered`/`candidates` values of up to the last
      three entries;
    * a missing or empty log yields a "No discoveries yet" line.

    Args:
        model: Model (possibly wrapped) expected to expose a `dscd` attribute.
        global_step: Training step number used in the printed messages.

    Returns:
        None. Any exception is caught and, in debug/verbose mode, printed.
    """
    try:
        # Peel off every wrapper layer that exposes `.module`.
        unwrapped = model
        while hasattr(unwrapped, 'module'):
            unwrapped = unwrapped.module

        dscd = getattr(unwrapped, 'dscd', None)
        if dscd is None:
            return

        # Missing or empty log: report "nothing yet" and stop.
        if not (hasattr(dscd, 'discovered_log') and dscd.discovered_log):
            if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                print(
                    f"[DISCOVERY-STATUS] No discoveries yet at step {global_step}"
                )
            return

        event_count = len(dscd.discovered_log)
        if not (_DEBUG_DISCOVERY or _VERBOSE_LOGGING):
            return

        print(
            f"[DISCOVERY-STATUS] Step {global_step}: {event_count} discovery events"
        )

        # Show at most the three most recent log entries.
        if len(dscd.discovered_log) >= 3:
            latest_entries = dscd.discovered_log[-3:]
        else:
            latest_entries = dscd.discovered_log

        for record in latest_entries:
            found = record.get('discovered', 0)
            considered = record.get('candidates', 0)
            print(
                f"  - {found}/{considered} homographs discovered"
            )
    except Exception as e:
        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            print(f"[DISCOVERY-STATUS] Error: {e}")

def _tatn_apply_optimizer_step(model, optimizer, scaler) -> None:
    """Clip gradients and apply one optimizer update, AMP-aware.

    With an enabled GradScaler the sequence is unscale_ -> clip -> step ->
    update. If anything fails before `update()`, a best-effort `update()` is
    issued so the scaler's per-optimizer stage is reset; otherwise the next
    `unscale_()` call would raise because the previous cycle was never closed.
    The original exception is always re-raised for the caller to classify.
    Gradients are NOT zeroed here; the caller does that.
    """
    if scaler.is_enabled():
        try:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(
                model.parameters(), _GRAD_CLIP_NORM
            )
            scaler.step(optimizer)
        except Exception:
            try:
                # Close the scaler cycle so later steps are not blocked.
                scaler.update()
            except Exception:
                # update() itself fails when no inf checks were recorded,
                # in which case there is no open cycle to close.
                pass
            raise
        scaler.update()
    else:
        torch.nn.utils.clip_grad_norm_(
            model.parameters(), _GRAD_CLIP_NORM
        )
        optimizer.step()


def _tatn_trim_per_sample(value, batch_size: int, keep: int):
    """Return `value` cut to its first `keep` samples when it is per-sample data.

    Only tensors whose first dimension equals `batch_size` and lists/tuples
    whose length equals `batch_size` are sliced; anything else (None, dicts,
    differently sized containers) is returned unchanged.
    """
    if isinstance(value, torch.Tensor):
        if value.dim() > 0 and value.size(0) == batch_size:
            return value[:keep]
        return value
    if isinstance(value, (list, tuple)) and len(value) == batch_size:
        return value[:keep]
    return value


def _tatn_run_step_validation(
    model,
    tokenizer,
    optimizer,
    epoch: int,
    global_step: int,
    device,
    training_stats: Dict[str, Any],
) -> None:
    """Run a step-interval validation without letting it abort training.

    Gradients are cleared first, the result (if truthy) is appended to
    `training_stats['epoch_validations']`, failures are reported and
    swallowed, and the model is put back into training mode afterwards.
    """
    try:
        optimizer.zero_grad(set_to_none=True)
    except Exception:
        pass

    try:
        val_result = comprehensive_epoch_validation(
            model,
            tokenizer,
            epoch,
            global_step,
            _SOURCE_LANGUAGE,
            _TARGET_LANGUAGE,
            _MAX_LENGTH,
            device,
        )

        if val_result:
            training_stats['epoch_validations'].append(
                val_result
            )
    except Exception as e:
        print(
            f"[TRAIN] Validation at step {global_step} failed: {type(e).__name__}"
        )
        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            try:
                traceback.print_exc()
            except Exception:
                pass
    finally:
        # comprehensive_epoch_validation is defined elsewhere; re-assert
        # training mode in case it leaves the model in eval mode.
        try:
            model.train()
        except Exception:
            pass


def train_memory_efficient_tatn(
    model: torch.nn.Module,
    tokenizer,
    train_loader: torch.utils.data.DataLoader,
    optimizer: torch.optim.Optimizer,
    phi_optimizer: Optional[torch.optim.Optimizer] = None,
    epochs: Optional[int] = None,
    accumulation_steps: Optional[int] = None,
    validate_every: Optional[int] = None,
    enable_validation: bool = True,
    device: Optional[torch.device] = None
) -> torch.nn.Module:
    """Train the TATN model with gradient accumulation and save a final checkpoint.

    Per epoch the function resets TRG/ASBN statistics (when the model exposes
    them), iterates over `train_loader`, runs forward/backward with optional
    AMP, applies an optimizer update every `accumulation_steps` batches,
    triggers periodic DSCD discovery and step-interval validation, flushes any
    leftover accumulated gradients, runs synchronous DSCD clustering and then a
    comprehensive validation. After all epochs it runs a final clustering pass,
    writes a checkpoint to /kaggle/working/tatn_final.pt and prints a summary.

    Failures inside a batch (OOM, RuntimeError, other exceptions) are counted
    in the statistics, gradients are dropped, and training continues with the
    next batch.

    Args:
        model: Module to train. `.module` is unwrapped when accessing the
            `dscd`, `trg_system` and `asbn` sub-components.
        tokenizer: Forwarded to `comprehensive_epoch_validation`.
        train_loader: Yields batches that are either None (skipped) or
            mappings with "input_ids", "attention_mask", "labels" and
            optionally "domain_labels", "src_text", "token_word_map".
        optimizer: Optimizer stepped on every accumulation boundary.
        phi_optimizer: Accepted for interface compatibility; not used here.
        epochs: Number of epochs; defaults to `_EPOCHS`.
        accumulation_steps: Batches per optimizer update; defaults to
            `_ACCUMULATION_STEPS`.
        validate_every: Step interval for in-epoch validation; defaults to
            `_VALIDATION_CHECK_INTERVAL`. Values <= 0 disable it.
        enable_validation: Master switch for the step-interval validation
            (the epoch-end validation always runs).
        device: Target device for batch tensors; defaults to `_DEVICE`.

    Returns:
        The same `model` object that was passed in.
    """
    if epochs is None:
        epochs = _EPOCHS
    if accumulation_steps is None:
        accumulation_steps = _ACCUMULATION_STEPS
    if validate_every is None:
        validate_every = _VALIDATION_CHECK_INTERVAL
    if device is None:
        device = _DEVICE

    print(
        f"[TRAIN] Starting training: epochs={epochs}, batch={_BATCH_SIZE}, "
        f"accum_steps={accumulation_steps}"
    )
    print(
        f"[TRAIN] Validation: "
        f"{'enabled' if enable_validation and validate_every > 0 else 'disabled'}"
    )
    print(
        f"[TRAIN] DP enabled: {_USE_MULTI_GPU}, GPUs: {_NUM_GPUS}, Device: {device}"
    )
    print(
        f"[TRAIN] Discovery frequency: {_PERIODIC_DISCOVERY_FREQUENCY} steps"
    )
    print(
        "[TRAIN] Checkpoint: Will save to /kaggle/working/tatn_final.pt "
        "after all epochs"
    )

    model.train()
    clear_all_gpu_caches()
    scaler = GradScaler(enabled=(_USE_AMP and torch.cuda.is_available()))

    global_step = 0
    # Number of backward passes accumulated since the last optimizer update.
    accumulated_steps = 0
    # Set when a validation is due but gradients are mid-accumulation; the
    # validation then runs right after the next optimizer update.
    pending_validation = False

    training_stats: Dict[str, Any] = {
        "total_loss": [],
        "epoch_losses": [],
        "backward_losses": [],
        "batches_processed": 0,
        "optimizer_updates": 0,
        "skipped_batches": 0,
        "oom_errors": 0,
        "runtime_errors": 0,
        "exceptions": 0,
        "epoch_validations": [],
        "dscd_quality_history": [],
        "multi_sense_ratio_history": [],
        "asbn_domain_accuracy_history": [],
        "trg_explanation_history": [],
        "synchronous_clustering_results": [],
    }

    last_forward_loss = 0.0
    last_backward_loss = 0.0

    for epoch in range(1, epochs + 1):
        epoch_start = time.time()
        epoch_losses: List[float] = []
        skip_reasons = defaultdict(int)

        print(f"\n{'='*80}")
        print(f"EPOCH {epoch}/{epochs} STARTED")
        print(f"{'='*80}")

        # The previous epoch ended with clustering/validation routines that
        # are defined elsewhere; make sure we are back in training mode.
        model.train()

        try:
            core = model.module if hasattr(model, 'module') else model
            trg = getattr(core, 'trg_system', None)
            if trg and hasattr(trg, 'reset_statistics'):
                try:
                    trg.reset_statistics()
                    print(f"[TRAIN] TRG statistics reset for epoch {epoch}")
                except Exception:
                    pass
        except Exception as e:
            if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                print(f"[TRAIN] TRG stats reset failed: {e}")

        try:
            core = model.module if hasattr(model, 'module') else model
            asbn = getattr(core, 'asbn', None)
            if asbn and hasattr(asbn, 'reset_stats'):
                try:
                    asbn.reset_stats()
                    print(f"[TRAIN] ASBN statistics reset for epoch {epoch}")
                except Exception:
                    pass
        except Exception as e:
            if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                print(f"[TRAIN] ASBN stats reset failed: {e}")

        try:
            optimizer.zero_grad(set_to_none=True)
        except Exception:
            pass

        progress = None
        try:
            progress = tqdm(
                train_loader,
                desc=f"Epoch {epoch}/{epochs}",
                dynamic_ncols=True,
            )

            for batch_idx, batch in enumerate(progress):
                global_step += 1
                training_stats["batches_processed"] += 1

                if (
                    _DEBUG_DISCOVERY or _VERBOSE_LOGGING
                ) and global_step % DEBUG_PRINT_INTERVAL == 0:
                    print(
                        f"\n[TRAIN-DEBUG] Epoch {epoch} Batch {batch_idx} "
                        f"GlobalStep {global_step}"
                    )
                    _check_discovery_status(model, global_step)

                if _PERIODIC_DISCOVERY_FREQUENCY and _PERIODIC_DISCOVERY_FREQUENCY > 0:
                    if global_step % _PERIODIC_DISCOVERY_FREQUENCY == 0:
                        try:
                            core = model.module if hasattr(model, 'module') else model
                            dscd = getattr(core, 'dscd', None)
                            if dscd and hasattr(dscd, 'periodic_discovery_check'):
                                print(f"\n[TRAIN] Triggering periodic discovery at step {global_step}...")
                                dscd.periodic_discovery_check(global_step, _PERIODIC_DISCOVERY_FREQUENCY)
                        except Exception as e:
                            if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                                print(f"[TRAIN] Periodic discovery failed: {e}")

                if (
                    enable_validation
                    and validate_every
                    and validate_every > 0
                    and (global_step % validate_every == 0)
                ):
                    if accumulated_steps == 0:
                        # No partially accumulated gradients: safe to
                        # validate right away.
                        _tatn_run_step_validation(
                            model,
                            tokenizer,
                            optimizer,
                            epoch,
                            global_step,
                            device,
                            training_stats,
                        )
                    else:
                        # Defer until the current accumulation window closes
                        # so validation does not discard pending gradients.
                        pending_validation = True

                if batch is None:
                    training_stats["skipped_batches"] += 1
                    skip_reasons["batch_none"] += 1
                    continue

                try:
                    input_ids = batch["input_ids"]
                    attention_mask = batch["attention_mask"]
                    labels = batch["labels"]

                    # NOTE: domain_labels are normalised and moved to the
                    # device below but are not part of forward_kwargs.
                    domain_labels = batch.get("domain_labels", None)
                    if domain_labels is not None:
                        if not isinstance(domain_labels, torch.Tensor):
                            domain_labels = None
                        elif domain_labels.dim() == 0:
                            domain_labels = domain_labels.unsqueeze(0)

                    src_texts = batch.get("src_text", None)
                    token_word_map = batch.get("token_word_map", None)

                    if _USE_MULTI_GPU and _NUM_GPUS > 0:
                        # Keep the batch size a multiple of the GPU count.
                        bsz = int(input_ids.size(0))
                        keep = (bsz // _NUM_GPUS) * _NUM_GPUS
                        if keep == 0:
                            training_stats["skipped_batches"] += 1
                            skip_reasons["dp_keep_zero"] += 1
                            continue
                        if keep != bsz:
                            input_ids = input_ids[:keep]
                            attention_mask = attention_mask[:keep]
                            labels = labels[:keep]
                            if domain_labels is not None:
                                domain_labels = domain_labels[:keep]
                            # Per-sample metadata must shrink with the
                            # tensors, otherwise it no longer lines up with
                            # the rows that are actually fed to the model.
                            src_texts = _tatn_trim_per_sample(
                                src_texts, bsz, keep
                            )
                            token_word_map = _tatn_trim_per_sample(
                                token_word_map, bsz, keep
                            )

                    input_ids = input_ids.to(device, non_blocking=True)
                    attention_mask = attention_mask.to(
                        device, non_blocking=True
                    )
                    labels = labels.to(device, non_blocking=True)

                    if domain_labels is not None:
                        domain_labels = domain_labels.to(
                            device, non_blocking=True
                        )

                    if input_ids.size(0) == 0:
                        training_stats["skipped_batches"] += 1
                        skip_reasons["empty_batch"] += 1
                        continue

                    forward_kwargs = {
                        "input_ids": input_ids,
                        "attention_mask": attention_mask,
                        "labels": labels,
                        "src_texts": src_texts,
                        "token_word_map": token_word_map,
                        "track_stats": True,
                    }

                    amp_ctx = get_amp_ctx()
                    with amp_ctx:
                        forward_out = model(**forward_kwargs)

                        # Accept a bare tensor, a dict with "loss", or a
                        # sequence whose first element is the loss tensor.
                        if isinstance(forward_out, torch.Tensor):
                            loss_tensor = forward_out
                        elif isinstance(forward_out, dict) and "loss" in forward_out:
                            loss_tensor = forward_out["loss"]
                        else:
                            if isinstance(forward_out, (list, tuple)) and len(
                                forward_out
                            ) > 0 and isinstance(
                                forward_out[0], torch.Tensor
                            ):
                                loss_tensor = forward_out[0]
                            else:
                                raise RuntimeError(
                                    "Model forward did not return a recognizable loss tensor"
                                )

                        if not isinstance(loss_tensor, torch.Tensor):
                            loss_tensor = torch.tensor(
                                float(loss_tensor), device=device
                            )
                        else:
                            loss_tensor = loss_tensor.to(device)

                        # A multi-element loss (e.g. one value per replica)
                        # is reduced to its mean.
                        if loss_tensor.numel() > 1:
                            loss_val = float(loss_tensor.mean().item())
                            loss_tensor = loss_tensor.mean()
                        else:
                            loss_val = float(loss_tensor.item())

                        last_forward_loss = loss_val
                        epoch_losses.append(loss_val)
                        training_stats["total_loss"].append(loss_val)

                    # Scale so the accumulated gradient is the window mean.
                    loss_scaled = loss_tensor / max(1, accumulation_steps)
                    last_backward_loss = float(loss_scaled.item())
                    training_stats["backward_losses"].append(last_backward_loss)

                    if scaler.is_enabled():
                        scaler.scale(loss_scaled).backward()
                    else:
                        loss_scaled.backward()

                    accumulated_steps += 1

                    if accumulated_steps >= accumulation_steps:
                        try:
                            _tatn_apply_optimizer_step(model, optimizer, scaler)
                            training_stats["optimizer_updates"] += 1
                        except RuntimeError as e:
                            if "out of memory" in str(e).lower():
                                training_stats["oom_errors"] += 1
                                training_stats["skipped_batches"] += 1
                                skip_reasons["oom"] += 1
                                print(f"[OOM] OOM at step {global_step}")
                                try:
                                    optimizer.zero_grad(set_to_none=True)
                                except Exception:
                                    pass
                                for p in model.parameters():
                                    p.grad = None
                                clear_all_gpu_caches()
                                accumulated_steps = 0
                                continue
                            else:
                                training_stats["runtime_errors"] += 1
                                skip_reasons["opt_runtime"] += 1
                                print(
                                    f"[ERROR] Runtime error during optimizer step: {type(e).__name__}"
                                )
                        except Exception as e:
                            training_stats["exceptions"] += 1
                            skip_reasons["opt_exception"] += 1
                            print(
                                f"[ERROR] Exception during optimizer step: {type(e).__name__}"
                            )
                        finally:
                            accumulated_steps = 0
                            # Always start the next window from clean
                            # gradients, including after a failed step, so
                            # stale values cannot leak into later updates.
                            try:
                                optimizer.zero_grad(set_to_none=True)
                            except Exception:
                                pass

                            if pending_validation:
                                pending_validation = False
                                _tatn_run_step_validation(
                                    model,
                                    tokenizer,
                                    optimizer,
                                    epoch,
                                    global_step,
                                    device,
                                    training_stats,
                                )

                    if global_step % DEBUG_PRINT_INTERVAL == 0:
                        _print_gpu_mem("[TRAIN-DEBUG]")
                        cluster_count = _get_cluster_count(model)
                        print(
                            f"[TRAIN-DEBUG] step={global_step} "
                            f"loss={last_forward_loss:.4f} clusters={cluster_count}"
                        )
                        _print_top_clusters(model, top_n=5)

                    if global_step % _MEMORY_CLEANUP_FREQUENCY == 0:
                        clear_all_gpu_caches()

                except RuntimeError as e:
                    if "out of memory" in str(e).lower():
                        training_stats["oom_errors"] += 1
                        training_stats["skipped_batches"] += 1
                        skip_reasons["oom"] += 1
                        print(f"[OOM] Caught OOM at step {global_step}")
                        try:
                            optimizer.zero_grad(set_to_none=True)
                        except Exception:
                            pass
                        for p in model.parameters():
                            p.grad = None
                        clear_all_gpu_caches()
                        accumulated_steps = 0
                        continue
                    else:
                        training_stats["runtime_errors"] += 1
                        training_stats["skipped_batches"] += 1
                        skip_reasons["runtime"] += 1
                        print(
                            f"[RUNTIME] RuntimeError at step {global_step}: {type(e).__name__}"
                        )
                        try:
                            optimizer.zero_grad(set_to_none=True)
                        except Exception:
                            pass
                        accumulated_steps = 0
                        continue
                except Exception as e:
                    training_stats["exceptions"] += 1
                    training_stats["skipped_batches"] += 1
                    skip_reasons["exceptions"] += 1
                    print(
                        f"[EXCEPTION] Exception at step {global_step}: {type(e).__name__}"
                    )
                    if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                        try:
                            traceback.print_exc()
                        except Exception:
                            pass
                    try:
                        optimizer.zero_grad(set_to_none=True)
                    except Exception:
                        pass
                    accumulated_steps = 0
                    continue

                # Progress-bar bookkeeping (cumulative over all epochs).
                processed_batches = (
                    training_stats["batches_processed"]
                    - training_stats["skipped_batches"]
                )
                expected_updates = max(
                    1,
                    math.floor(
                        processed_batches / max(1, accumulation_steps)
                    ),
                )
                success_rate = (
                    100.0
                    * training_stats["optimizer_updates"]
                    / expected_updates
                    if expected_updates > 0
                    else 0.0
                )
                cluster_count = _get_cluster_count(model)

                next_disc_str = "N/A"
                try:
                    if (
                        _PERIODIC_DISCOVERY_FREQUENCY
                        and _PERIODIC_DISCOVERY_FREQUENCY > 0
                    ):
                        steps_to_next = (
                            _PERIODIC_DISCOVERY_FREQUENCY
                            - (global_step % _PERIODIC_DISCOVERY_FREQUENCY)
                        )
                        next_disc_str = f"next_disc_in={steps_to_next}"
                except Exception:
                    next_disc_str = "next_disc_err"

                progress.set_postfix_str(
                    f"fwd_loss={last_forward_loss:.4f} "
                    f"bwd_loss={last_backward_loss:.4f} "
                    f"rate={success_rate:.1f}% "
                    f"clusters={cluster_count} {next_disc_str}"
                )

        finally:
            if progress is not None:
                try:
                    progress.close()
                except Exception:
                    pass

        # Flush gradients left over from an incomplete accumulation window.
        if accumulated_steps > 0:
            try:
                _tatn_apply_optimizer_step(model, optimizer, scaler)
                training_stats["optimizer_updates"] += 1
            except Exception as e:
                print(
                    f"[EPOCH-FLUSH] Exception on epoch flush: {type(e).__name__}"
                )
            finally:
                accumulated_steps = 0
                try:
                    optimizer.zero_grad(set_to_none=True)
                except Exception:
                    pass

        epoch_duration_min = (time.time() - epoch_start) / 60.0
        processed_batches = (
            training_stats["batches_processed"]
            - training_stats["skipped_batches"]
        )
        expected_updates = max(
            1,
            math.floor(processed_batches / max(1, accumulation_steps)),
        )
        success_rate = (
            100.0
            * training_stats["optimizer_updates"]
            / expected_updates
            if expected_updates > 0
            else 0.0
        )
        cluster_count = _get_cluster_count(model)

        avg_epoch_loss = float(np.mean(epoch_losses)) if epoch_losses else 0.0
        training_stats["epoch_losses"].append(avg_epoch_loss)

        print("\n" + "=" * 80)
        print(f"EPOCH {epoch}/{epochs} SUMMARY")
        print("=" * 80)
        print(f"  Duration (min): {epoch_duration_min:.2f}")
        print(f"  Optimizer updates: {training_stats['optimizer_updates']}")
        print(
            f"  Batches: processed={processed_batches}, "
            f"skipped={training_stats['skipped_batches']}"
        )
        print(f"  Success rate: {success_rate:.1f}%")
        print(f"  Clustered tokens: {cluster_count}")
        print(f"  Avg epoch loss: {avg_epoch_loss:.6f}")
        if skip_reasons:
            print("  Skip reasons:")
            for k, v in sorted(skip_reasons.items(), key=lambda x: -x[1]):
                print(f"    - {k}: {v}")
        print("=" * 80)

        try:
            print(f"\n[TRAIN] Running synchronous DSCD clustering after epoch {epoch}...")
            clustering_result = synchronous_dscd_clustering(model, force=True)
            training_stats['synchronous_clustering_results'].append(clustering_result)

            if clustering_result.get('success', False):
                print(f"[TRAIN] Clustered {clustering_result.get('clustered', 0)} tokens")
                print(f"[TRAIN] Discovered {clustering_result.get('discovered', 0)} new homographs")
                print(f"[TRAIN] Total homographs: {clustering_result.get('total_homographs', 0)}")
            else:
                print(f"[TRAIN] Clustering incomplete: {clustering_result.get('reason', 'unknown')}")

        except Exception as e:
            print(f"[TRAIN] Epoch-end clustering failed: {type(e).__name__}")
            if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                try:
                    traceback.print_exc()
                except Exception:
                    pass

        try:
            print(
                f"\n[TRAIN] Running comprehensive validation after epoch {epoch}..."
            )

            try:
                optimizer.zero_grad(set_to_none=True)
            except Exception:
                pass

            validation_results = comprehensive_epoch_validation(
                model=model,
                tokenizer=tokenizer,
                epoch=epoch,
                global_step=global_step,
                source_lang=_SOURCE_LANGUAGE,
                target_lang=_TARGET_LANGUAGE,
                max_length=_MAX_LENGTH,
                device=device,
            )

            if validation_results and validation_results.get(
                'validation_completed', False
            ):
                training_stats['epoch_validations'].append(
                    validation_results
                )
                training_stats['dscd_quality_history'].append(
                    validation_results.get('dscd_quality_score', 0.0)
                )
                training_stats['asbn_domain_accuracy_history'].append(
                    validation_results.get('asbn_domain_accuracy', 0.0)
                )
                training_stats['trg_explanation_history'].append(
                    validation_results.get('trg_total_explanations', 0)
                )

                # Ratio of multi-sense tokens to all tokens with a DSCD
                # store; 0.0 is recorded whenever it cannot be computed.
                try:
                    dscd = (
                        model.module.dscd
                        if hasattr(model, 'module')
                        else getattr(model, 'dscd', None)
                    )

                    if dscd is not None:
                        stores = _resolve_dscd_stores(dscd)
                        lock = _resolve_dscd_lock(dscd)

                        if lock:
                            with lock:
                                total_tokens = len(stores)
                        else:
                            total_tokens = len(stores)

                        multi_sense = validation_results.get(
                            'dscd_multi_sense_tokens', 0
                        )
                        ratio = (
                            multi_sense / total_tokens
                            if total_tokens > 0
                            else 0.0
                        )
                        training_stats['multi_sense_ratio_history'].append(
                            ratio
                        )
                    else:
                        training_stats['multi_sense_ratio_history'].append(
                            0.0
                        )
                except Exception:
                    training_stats['multi_sense_ratio_history'].append(0.0)
            else:
                print("[TRAIN] Validation incomplete")

        except Exception as e:
            print(f"[TRAIN] Epoch validation failed: {type(e).__name__}")
            if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
                try:
                    traceback.print_exc()
                except Exception:
                    pass

    print(f"\n{'='*80}")
    print("TRAINING COMPLETE - RUNNING FINAL CLUSTERING")
    print(f"{'='*80}")

    try:
        print("\n[TRAIN] Running final synchronous DSCD clustering...")
        final_clustering_result = synchronous_dscd_clustering(model, force=True)
        training_stats['final_clustering_result'] = final_clustering_result

        if final_clustering_result.get('success', False):
            print(f"[TRAIN] Final clustering: {final_clustering_result.get('clustered', 0)} tokens")
            print(f"[TRAIN] Final discovery: {final_clustering_result.get('discovered', 0)} new homographs")
            print(f"[TRAIN] Total homographs: {final_clustering_result.get('total_homographs', 0)}")
        else:
            print(f"[TRAIN] Final clustering incomplete: {final_clustering_result.get('reason', 'unknown')}")

    except Exception as e:
        print(f"[TRAIN] Final clustering failed: {type(e).__name__}")
        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            try:
                traceback.print_exc()
            except Exception:
                pass

    print(f"\n{'='*80}")
    print("SAVING FINAL CHECKPOINT")
    print(f"{'='*80}")

    try:
        checkpoint_path = Path("/kaggle/working/tatn_final.pt")

        # Save the unwrapped module so state-dict keys carry no wrapper prefix.
        core_model = model.module if hasattr(model, 'module') else model

        dscd_state = {}
        try:
            if hasattr(core_model, 'dscd'):
                try:
                    dscd_state = core_model.dscd.state_dict()
                except Exception:
                    dscd_state = {}
        except Exception:
            dscd_state = {}

        checkpoint_data = {
            'epochs_trained': epochs,
            'global_steps': global_step,
            'final_train_loss': training_stats['epoch_losses'][-1]
            if training_stats['epoch_losses']
            else 0.0,
            'model_state_dict': core_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_state_dict': scaler.state_dict()
            if scaler is not None
            else None,
            'training_stats': training_stats,
            'dscd_state': dscd_state,
            'config': {
                'SPAN_THRESHOLD': globals().get('SPAN_THRESHOLD', 0.20),
                'TAU_LOW': globals().get('TAU_LOW', 0.15),
                'LAMBDA_ASBN': globals().get('LAMBDA_ASBN', 0.05),
                'LAMBDA_DSCD': globals().get('LAMBDA_DSCD', 0.15),
                'TRG_TEMPERATURE': globals().get('TRG_TEMPERATURE', 1.0),
                'PERIODIC_DISCOVERY_FREQUENCY': _PERIODIC_DISCOVERY_FREQUENCY,
                'NUM_EPOCHS': epochs,
                'BATCH_SIZE': _BATCH_SIZE,
                'LEARNING_RATE': optimizer.param_groups[0]['lr']
                if optimizer and optimizer.param_groups
                else 0.0,
            },
        }

        torch.save(checkpoint_data, checkpoint_path)

        file_size_mb = checkpoint_path.stat().st_size / (1024**2)

        print("\nFINAL CHECKPOINT SAVED")
        print(f"   Path: {checkpoint_path}")
        print(f"   Size: {file_size_mb:.2f} MB")
        print(f"   Epochs trained: {epochs}")
        print(f"   Global steps: {global_step}")
        print(
            f"   Final train loss: "
            f"{training_stats['epoch_losses'][-1] if training_stats['epoch_losses'] else 0.0:.4f}"
        )
        print(f"{'='*80}\n")

    except Exception as e:
        print(f"FINAL CHECKPOINT SAVE FAILED: {type(e).__name__}")
        if _DEBUG_DISCOVERY or _VERBOSE_LOGGING:
            try:
                traceback.print_exc()
            except Exception:
                pass

    print("\n" + "=" * 80)
    print("TRAINING COMPLETED - FINAL SUMMARY")
    print("=" * 80)

    processed_batches = (
        training_stats["batches_processed"]
        - training_stats["skipped_batches"]
    )
    expected_updates = max(
        1,
        math.floor(processed_batches / max(1, accumulation_steps)),
    )
    success_rate = (
        100.0
        * training_stats["optimizer_updates"]
        / expected_updates
        if expected_updates > 0
        else 0.0
    )

    print(f"[TRAIN] Success Rate: {success_rate:.1f}%")
    print(f"[TRAIN] Total Steps: {global_step}")
    print(
        f"[TRAIN] Clustered Token Types: {_get_cluster_count(model)}"
    )

    if training_stats['dscd_quality_history']:
        print("\n[TRAIN] DSCD Quality Score Trend:")
        for i, score in enumerate(training_stats['dscd_quality_history'], 1):
            print(f"  Epoch {i}: {score:.1%}")

    if training_stats['asbn_domain_accuracy_history']:
        print("\n[TRAIN] ASBN Domain Accuracy Trend:")
        for i, acc in enumerate(
            training_stats['asbn_domain_accuracy_history'], 1
        ):
            print(f"  Epoch {i}: {acc:.1%}")

    if training_stats['trg_explanation_history']:
        print("\n[TRAIN] TRG Explanation Count Trend:")
        for i, count in enumerate(
            training_stats['trg_explanation_history'], 1
        ):
            print(f"  Epoch {i}: {count} explanations")

    if training_stats.get('synchronous_clustering_results'):
        print("\n[TRAIN] Synchronous Clustering Results:")
        for i, result in enumerate(training_stats['synchronous_clustering_results'], 1):
            if result.get('success'):
                print(f"  Epoch {i}: {result.get('clustered', 0)} tokens, {result.get('discovered', 0)} homographs")

    print("=" * 80)
    return model

# Cell 7 readiness banner, emitted as one write; the text matches what the
# individual print calls produced line for line.
print(
    "\n".join(
        [
            "\n" + "=" * 80,
            "Cell 7: Training loop ready (PURE UNSUPERVISED) - ALL FIXES APPLIED",
            "=" * 80,
            "NEW FEATURES:",
            "  ✓ Synchronous clustering after each epoch",
            "  ✓ Forced discovery before validation",
            "  ✓ Final clustering before checkpoint save",
            "=" * 80,
        ]
    )
)



Cell 7: Training loop ready (PURE UNSUPERVISED) - ALL FIXES APPLIED
NEW FEATURES:
  ✓ Synchronous clustering after each epoch
  ✓ Forced discovery before validation
  ✓ Final clustering before checkpoint save


In [11]:
# ==============================================================================
# CELL 8: INFERENCE & EVALUATION PIPELINE - COMPLETE FIXED VERSION
# ==============================================================================
import os
import time
import math
import torch
import traceback
import unicodedata
from typing import List, Dict, Any, Tuple, Optional
from collections import defaultdict
import threading
import gc
from transformers.modeling_outputs import BaseModelOutput

# ------------------------------------------------------------------------------
# Configuration fallbacks
# Each try/except below re-binds a global that an earlier cell is expected to
# have defined, coercing it to the type this cell relies on. If the name is
# missing (NameError) or the coercion raises, the default in the except branch
# is bound instead, so this cell can also run on its own.
# ------------------------------------------------------------------------------

# Language codes, used further down for tokenizer.src_lang and the forced BOS lookup.
try:
    SOURCE_LANGUAGE = str(SOURCE_LANGUAGE)
    TARGET_LANGUAGE = str(TARGET_LANGUAGE)
except (NameError, TypeError):
    SOURCE_LANGUAGE = "bn"
    TARGET_LANGUAGE = "en"

# Token budget for tokenization/truncation (and generation length caps).
try:
    MAX_LENGTH = int(MAX_LENGTH)
except (NameError, ValueError, TypeError):
    MAX_LENGTH = 48

# Self-assignment only probes whether DEVICE exists; no coercion is applied.
try:
    DEVICE = DEVICE
except (NameError, TypeError):
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging/debug switches; all default to off when not configured upstream.
try:
    VERBOSE_LOGGING = bool(VERBOSE_LOGGING)
except (NameError, TypeError):
    VERBOSE_LOGGING = False

try:
    DEBUG_DISCOVERY = bool(DEBUG_DISCOVERY)
except (NameError, TypeError):
    DEBUG_DISCOVERY = False

try:
    DEBUG_TIMING = bool(DEBUG_TIMING)
except (NameError, TypeError):
    DEBUG_TIMING = False

# Falls back to auto-detection: true only when more than one CUDA device is visible.
try:
    USE_MULTI_GPU = bool(USE_MULTI_GPU)
except (NameError, TypeError):
    USE_MULTI_GPU = torch.cuda.is_available() and torch.cuda.device_count() > 1

# Explanation-filter thresholds. Note the source names differ from the local
# ones: SPAN_THRESHOLD -> REAL_AMB_SPAN_THRESHOLD, TAU_LOW -> REAL_AMB_UNCERTAINTY_THRESHOLD.
try:
    REAL_AMB_SPAN_THRESHOLD = float(SPAN_THRESHOLD)
except (NameError, ValueError, TypeError):
    REAL_AMB_SPAN_THRESHOLD = 0.05

try:
    REAL_AMB_UNCERTAINTY_THRESHOLD = float(TAU_LOW)
except (NameError, ValueError, TypeError):
    REAL_AMB_UNCERTAINTY_THRESHOLD = 0.15

# Reference homograph vocabulary, stored lower-cased as a set. Built from
# HOMOGRAPH_REFERENCE_LIST_BN when available, otherwise from the literal list below.
try:
    HOMOGRAPH_REFERENCE_LIST = set(w.lower() for w in HOMOGRAPH_REFERENCE_LIST_BN)
except (NameError, TypeError):
    HOMOGRAPH_REFERENCE_LIST = {
        "কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা",
        "পানি", "দল", "বাজার", "নাম", "কথা", "বই", "ঘর", "মন", "হাত"
    }
    HOMOGRAPH_REFERENCE_LIST = set(w.lower() for w in HOMOGRAPH_REFERENCE_LIST)

# Hard-coded language-token ids used as a last resort.
# NOTE(review): presumably the M2M100 vocabulary ids for the English and
# Bengali language tokens - confirm against the tokenizer actually loaded.
try:
    M2M100_EN_TOKEN_ID = int(M2M100_EN_TOKEN_ID)
except (NameError, ValueError, TypeError):
    M2M100_EN_TOKEN_ID = 128022

try:
    M2M100_BN_TOKEN_ID = int(M2M100_BN_TOKEN_ID)
except (NameError, ValueError, TypeError):
    M2M100_BN_TOKEN_ID = 128025

# Single characters treated as punctuation by the sub-word filter.
SUBWORD_PUNCT_SET = set(".,!?-;:|")

# Evaluated once, when this cell runs: if reconstruct_word_spans is defined
# later, this flag stays False until the cell is re-executed.
_has_reconstruct_word_spans = "reconstruct_word_spans" in globals()

def build_token_word_map_with_cell2(
    text: str,
    tokenizer,
    max_length: int = None
) -> Tuple[Dict[int, int], List[str]]:
    """Map token positions to word indices, preferring ``reconstruct_word_spans``.

    Args:
        text: Source sentence to analyse.
        tokenizer: Tokenizer handed through to the span builder.
        max_length: Token budget; ``None`` means use the module-level ``MAX_LENGTH``.

    Returns:
        ``(token_to_word_map, words)`` where ``words`` is lower-cased. When
        ``reconstruct_word_spans`` was not available at cell-load time, or it
        raises, the result of ``build_token_word_map_fallback`` is returned.
    """
    limit = MAX_LENGTH if max_length is None else max_length

    try:
        if not _has_reconstruct_word_spans:
            return build_token_word_map_fallback(text, tokenizer, limit)

        span_map, words = reconstruct_word_spans(
            tokenizer,
            text,
            max_length=limit
        )
        # Downstream lookups compare against lower-cased vocabulary.
        return span_map, [word.lower() for word in words]
    except Exception as err:
        if DEBUG_DISCOVERY or VERBOSE_LOGGING:
            print(f"[INF] build_token_word_map_with_cell2 error: {err}")
        return build_token_word_map_fallback(text, tokenizer, limit)

def build_token_word_map_fallback(
    text: str,
    tokenizer,
    max_length: int = None
) -> Tuple[Dict[int, int], List[str]]:
    """Group tokenizer pieces into words using word-start markers.

    Pieces are obtained from ``tokenizer.tokenize(text)`` (so indices are
    positions in that list) and truncated to ``max_length``. A piece that
    starts with ``▁`` or ``Ġ`` opens a new word; any other piece continues
    the current word. Marker characters (``▁``, ``##``, ``Ġ``, ``@@``) are
    stripped and the text lower-cased before pieces are joined.

    A piece that consists only of a word-start marker carries no text, but it
    still marks a word boundary: the word in progress is closed so the next
    piece begins a new word instead of being glued to the previous one. Such
    marker-only pieces (and any other piece that is empty after cleaning) get
    no entry in the returned map.

    Args:
        text: Sentence to tokenize; coerced with ``str`` and stripped.
        tokenizer: Object exposing ``tokenize(text) -> sequence of pieces``.
        max_length: Maximum number of pieces considered; ``None`` means use
            the module-level ``MAX_LENGTH``.

    Returns:
        ``(token_to_word_idx, word_list)``. Both are empty for blank input,
        when tokenization fails, or on any unexpected error.
    """
    if max_length is None:
        max_length = MAX_LENGTH

    try:
        text = str(text).strip()
        if not text:
            return {}, []

        try:
            toks = tokenizer.tokenize(text)
        except Exception:
            return {}, []

        if len(toks) > max_length:
            toks = toks[:max_length]

        token_to_word_idx: Dict[int, int] = {}
        word_list: List[str] = []
        current_word_parts: List[str] = []

        for token_idx, tok in enumerate(toks):
            t = str(tok)
            starts_word = t.startswith("▁") or t.startswith("Ġ")
            clean = (
                t.replace("▁", "")
                .replace("##", "")
                .replace("Ġ", "")
                .replace("@@", "")
                .strip()
                .lower()
            )

            # Close the word in progress at every word-start marker, including
            # a marker-only piece whose text arrives in the following piece.
            if starts_word and current_word_parts:
                word_list.append("".join(current_word_parts))
                current_word_parts = []

            if not clean:
                continue

            current_word_parts.append(clean)
            # len(word_list) is the index the word in progress will receive
            # once it is appended.
            token_to_word_idx[token_idx] = len(word_list)

        if current_word_parts:
            word_list.append("".join(current_word_parts))

        return token_to_word_idx, word_list

    except Exception as e:
        if DEBUG_DISCOVERY or VERBOSE_LOGGING:
            print(f"[INF] build_token_word_map_fallback error: {e}")
        return {}, []

def get_dscd_homographs(model: torch.nn.Module) -> set:
    """Return the set of tokens the model's DSCD component treats as multi-sense.

    The model is unwrapped through ``.module`` when present. If the ``dscd``
    object offers ``get_discovered_homographs()``, its result is returned
    as-is. Otherwise every entry of ``dscd.prototypestores`` holding at least
    two prototypes contributes its key, with sub-word markers removed and the
    text lower-cased.

    Returns:
        A set of token strings; empty when there is no ``dscd`` attribute or
        anything goes wrong.
    """

    def _prototype_count(store) -> int:
        # Prefer an explicit size() method, then fall back to len(centroids).
        try:
            if hasattr(store, "size") and callable(store.size):
                return int(store.size())
        except Exception:
            pass
        try:
            centroids = getattr(store, "centroids", None)
            return 0 if centroids is None else int(len(centroids))
        except Exception:
            return 0

    try:
        core = model.module if hasattr(model, "module") else model
        dscd = getattr(core, "dscd", None)
        if dscd is None:
            return set()

        if hasattr(dscd, "get_discovered_homographs"):
            try:
                return dscd.get_discovered_homographs()
            except Exception:
                pass  # fall through to the prototype-store scan

        # First lock attribute that exists wins, even if its value is falsy.
        lock = None
        for lock_attr in ("bufferlock", "clusteringlock"):
            if hasattr(dscd, lock_attr):
                lock = getattr(dscd, lock_attr)
                break

        # Snapshot the stores (under the lock when there is one) so the scan
        # below works on a private list.
        if lock:
            with lock:
                entries = list(getattr(dscd, "prototypestores", {}).items())
        else:
            entries = list(getattr(dscd, "prototypestores", {}).items())

        found = set()
        for token, store in entries:
            try:
                if _prototype_count(store) < 2:
                    continue
                text = str(token)
                for marker in ("▁", "##", "Ġ"):
                    text = text.replace(marker, "")
                found.add(text.strip().lower())
            except Exception:
                continue

        return found
    except Exception:
        return set()

def synchronous_clustering_for_warmup(model: torch.nn.Module) -> Dict[str, Any]:
    """Cluster every DSCD prototype store in the calling thread, then run discovery.

    The model is unwrapped through ``.module`` when present. For each key of
    ``dscd.prototypestores`` the method
    ``dscd.clusterbuffertoprototypeshierarchical(token)`` is invoked; a truthy
    return counts as clustered, while a falsy return, an exception, or a
    missing method counts as failed. Afterwards ``dscd.discover_homographs()``
    is called if it exists. Every call into ``dscd`` is made while holding
    ``dscd.bufferlock`` (or ``dscd.clusteringlock``) when such an attribute
    exists and is truthy.

    Discovery is best-effort: if it raises or returns something other than a
    dict, the discovery counters are reported as 0 and the clustering counts
    are still returned.

    Returns:
        On success: ``{'success': True, 'clustered', 'failed', 'total_tokens',
        'discovered', 'total_homographs'}``. With no stores:
        ``{'success': True, 'clustered': 0, 'reason': 'no_stores'}``.
        With no ``dscd``: ``{'success': False, 'reason': 'no_dscd'}``.
        On an unexpected error: ``{'success': False, 'reason': <message, at
        most 200 characters>}``.
    """
    try:
        core = model.module if hasattr(model, "module") else model
        dscd = getattr(core, "dscd", None)

        if dscd is None:
            return {'success': False, 'reason': 'no_dscd'}

        lock = None
        if hasattr(dscd, "bufferlock"):
            lock = dscd.bufferlock
        elif hasattr(dscd, "clusteringlock"):
            lock = dscd.clusteringlock

        def _call_locked(fn, *args):
            # Run fn under the DSCD lock when one is available.
            if lock:
                with lock:
                    return fn(*args)
            return fn(*args)

        stores_list = _call_locked(
            lambda: list(getattr(dscd, "prototypestores", {}).keys())
        )

        if not stores_list:
            return {'success': True, 'clustered': 0, 'reason': 'no_stores'}

        cluster_count = 0
        failed_count = 0

        if hasattr(dscd, 'clusterbuffertoprototypeshierarchical'):
            for token in stores_list:
                try:
                    clustered = _call_locked(
                        dscd.clusterbuffertoprototypeshierarchical, token
                    )
                except Exception:
                    failed_count += 1
                    continue

                if clustered:
                    cluster_count += 1
                else:
                    failed_count += 1
        else:
            # Without the clustering method nothing can be clustered.
            failed_count = len(stores_list)

        discovery_result: Dict[str, Any] = {}
        if hasattr(dscd, 'discover_homographs'):
            try:
                outcome = _call_locked(dscd.discover_homographs)
            except Exception:
                outcome = None
            # Only a dict can supply the counters read below; anything else
            # (including None) must not discard the clustering statistics.
            if isinstance(outcome, dict):
                discovery_result = outcome

        return {
            'success': True,
            'clustered': cluster_count,
            'failed': failed_count,
            'total_tokens': len(stores_list),
            'discovered': discovery_result.get('discovered', 0),
            'total_homographs': discovery_result.get('total_homographs', 0),
        }

    except Exception as e:
        return {'success': False, 'reason': str(e)[:200]}

class InferenceStatistics:
    """Lock-protected running totals for inference calls and their explanations.

    All methods take ``self.lock`` while touching the counters. Note that
    ``avg_span`` and ``avg_uncertainty`` hold running *sums*; they are divided
    by the explanation count only in ``get_summary``. ``dscd_empty_warnings``
    is not updated by any method of this class; it is a plain counter for
    callers to increment.
    """

    def __init__(self):
        self.lock = threading.Lock()
        self.reset()

    def reset(self):
        """Zero every counter and empty every per-token collection."""
        with self.lock:
            self.total_inferences = 0
            self.successful_translations = 0
            self.failed_translations = 0
            self.total_explanations = 0
            self.high_confidence_explanations = 0
            self.low_confidence_explanations = 0
            self.total_confidence = 0.0
            self.dscd_homographs_explained = set()
            self.reference_homographs_explained = set()
            self.avg_span = 0.0
            self.avg_uncertainty = 0.0
            self.dscd_empty_warnings = 0
            self.token_counts = defaultdict(int)
            self.token_confidences = defaultdict(list)

    def record_inference(self, result: Dict[str, Any], dscd_homographs: Optional[set] = None):
        """Fold one inference result into the running totals.

        Args:
            result: Dict with an optional ``"translation"`` string and an
                optional ``"explanations"`` list of dicts. A missing or
                ``None`` explanations value is treated as an empty list.
            dscd_homographs: Optional set of cleaned tokens; explained words
                found in it are remembered in ``dscd_homographs_explained``.

        A translation counts as successful when it is non-empty and differs
        from ``"ERROR DURING TRANSLATION"``. Every explanation counts towards
        ``total_explanations``. An explanation whose confidence, span or
        uncertainty cannot be converted to ``float`` (or that is not a
        mapping) contributes nothing else, so it is never half-counted. An
        explanation whose word is empty after cleaning contributes only to
        the confidence totals.
        """
        with self.lock:
            self.total_inferences += 1
            translation = result.get("translation")
            if translation and translation != "ERROR DURING TRANSLATION":
                self.successful_translations += 1
            else:
                self.failed_translations += 1

            explanations = result.get("explanations") or []
            self.total_explanations += len(explanations)

            for exp in explanations:
                # Parse everything first so that a malformed entry cannot
                # leave the counters partially updated.
                try:
                    conf = float(exp.get("confidence", 0.5))
                    span = float(exp.get("span", 0.0))
                    uncertainty = float(exp.get("uncertainty", 0.0))
                    word = str(exp.get("ambiguous_word", exp.get("token", ""))).strip()
                except Exception:
                    continue

                clean_word = word.replace("▁", "").replace("##", "").replace("Ġ", "").lower().strip()

                self.total_confidence += conf
                # Thresholds are applied to the parsed float, so string or
                # other numeric-like confidences are bucketed correctly.
                if conf >= 0.65:
                    self.high_confidence_explanations += 1
                elif conf <= 0.4:
                    self.low_confidence_explanations += 1

                if not clean_word:
                    continue

                self.token_counts[clean_word] += 1
                self.token_confidences[clean_word].append(conf)

                if dscd_homographs and clean_word in dscd_homographs:
                    self.dscd_homographs_explained.add(clean_word)

                if clean_word in HOMOGRAPH_REFERENCE_LIST:
                    self.reference_homographs_explained.add(clean_word)

                self.avg_span += span
                self.avg_uncertainty += uncertainty

    def get_summary(self) -> Dict[str, Any]:
        """Return a snapshot dict of totals, rates and averages.

        Denominators are clamped to at least 1, so every rate and average is
        0 before anything has been recorded.
        """
        with self.lock:
            total_exp = max(self.total_explanations, 1)
            total_inf = max(self.total_inferences, 1)
            unique_tokens = len(self.token_counts)

            return {
                "total_inferences": self.total_inferences,
                "successful_translations": self.successful_translations,
                "failed_translations": self.failed_translations,
                "success_rate": self.successful_translations / total_inf,
                "total_explanations": self.total_explanations,
                "explanations_per_inference": self.total_explanations / total_inf,
                "high_confidence_rate": self.high_confidence_explanations / total_exp,
                "low_confidence_rate": self.low_confidence_explanations / total_exp,
                "avg_confidence": self.total_confidence / total_exp,
                "avg_span": self.avg_span / total_exp,
                "avg_uncertainty": self.avg_uncertainty / total_exp,
                "dscd_homographs_explained": list(self.dscd_homographs_explained),
                "reference_homographs_explained": list(self.reference_homographs_explained),
                "dscd_empty_warnings": self.dscd_empty_warnings,
                "unique_tokens_explained": unique_tokens,
                "diversity_ratio": unique_tokens / total_exp,
            }

    def print_summary(self):
        """Print the values from ``get_summary`` as a human-readable report."""
        summary = self.get_summary()
        print("=" * 80)
        print("INFERENCE STATISTICS SUMMARY")
        print("=" * 80)
        print(f"Total inferences: {summary['total_inferences']}")
        print(f"Success rate: {summary['success_rate']:.1%}")
        print(f"Total explanations: {summary['total_explanations']}")
        print(f"Explanations per inference: {summary['explanations_per_inference']:.2f}")
        print(f"Unique tokens explained: {summary['unique_tokens_explained']}")
        print(f"Diversity ratio: {summary['diversity_ratio']:.2%}")
        print(f"Avg confidence: {summary['avg_confidence']:.3f}")
        print(f"High confidence rate: {summary['high_confidence_rate']:.1%}")
        print(f"Avg span: {summary['avg_span']:.3f}")
        print(f"Avg uncertainty: {summary['avg_uncertainty']:.3f}")

        # Only the first ten words of each list are shown.
        if summary["dscd_homographs_explained"]:
            print(f"{len(summary['dscd_homographs_explained'])} DSCD homographs explained")
            print(f"  {', '.join(summary['dscd_homographs_explained'][:10])}...")

        if summary["reference_homographs_explained"]:
            print(f"{len(summary['reference_homographs_explained'])} Reference homographs explained")
            print(f"  {', '.join(summary['reference_homographs_explained'][:10])}...")

        if summary["dscd_empty_warnings"] > 0:
            print(f"DSCD empty warnings: {summary['dscd_empty_warnings']}")
        print("=" * 80)

# Shared module-level collector used by the inference helpers in this cell.
INFERENCE_STATS = InferenceStatistics()

def to_device_batch(enc: Any, device: torch.device):
    """Move an encoded batch to ``device``.

    An object exposing ``.to`` is moved by calling ``enc.to(device)``. If that
    is unavailable or fails and ``enc`` is a dict, a new dict is built:
    tensors are moved, nested dicts are handled recursively, and lists/tuples
    become lists whose tensor items are moved. Any value that cannot be
    converted is kept unchanged. Every other input is returned untouched.
    """
    try:
        if hasattr(enc, "to"):
            return enc.to(device)
    except Exception:
        pass

    if not isinstance(enc, dict):
        return enc

    def _relocate(value):
        if isinstance(value, torch.Tensor):
            return value.to(device)
        if isinstance(value, dict):
            return to_device_batch(value, device)
        if isinstance(value, (list, tuple)):
            # Only one level deep; tuples come back as lists.
            return [item.to(device) if isinstance(item, torch.Tensor) else item for item in value]
        return value

    moved = {}
    for key, value in enc.items():
        try:
            moved[key] = _relocate(value)
        except Exception:
            moved[key] = value
    return moved

def extract_dscd_outputs(raw_out: Any) -> Dict[str, Any]:
    """Locate the DSCD output dict inside a model's raw return value.

    For a dict, the lookup order is: a dict under ``"dscd_outputs"``, a dict
    under ``"dscd"``, the dict itself when it has an ``"explanations"`` or
    ``"protoprobs"`` key, a dict under ``"dscdout"``, and finally the dict
    itself. For a list or tuple, the first dict item is examined the same
    way. Anything else (including ``None``) yields ``{}``.
    """
    if isinstance(raw_out, dict):
        for key in ("dscd_outputs", "dscd"):
            nested = raw_out.get(key)
            if isinstance(nested, dict):
                return nested

        if "explanations" in raw_out or "protoprobs" in raw_out:
            return raw_out

        nested = raw_out.get("dscdout")
        if isinstance(nested, dict):
            return nested
        return raw_out

    if isinstance(raw_out, (list, tuple)):
        first_dict = next((item for item in raw_out if isinstance(item, dict)), None)
        if first_dict is not None:
            return extract_dscd_outputs(first_dict)

    return {}

def get_explanations_list(dscd: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
    """Return explanations as a list of per-sentence lists.

    Reads ``dscd["explanations"]``; when that is missing or ``None``, the
    first key present among ``"explanations_per_sentence"``,
    ``"trg_explanations"`` and ``"exps"`` is used instead. A flat list of
    dicts is wrapped into a single-sentence batch; a list of lists is
    returned as-is. Everything else, including an empty list, yields ``[]``.
    """
    if not dscd:
        return []

    expl = dscd.get("explanations", None)
    if expl is None:
        # The first alternative key that exists wins, even if its value is None.
        alt_key = next(
            (key for key in ("explanations_per_sentence", "trg_explanations", "exps") if key in dscd),
            None,
        )
        if alt_key is not None:
            expl = dscd[alt_key]

    if not isinstance(expl, list) or len(expl) == 0:
        return []

    head = expl[0]
    if isinstance(head, dict):
        return [expl]
    if isinstance(head, list):
        return expl
    return []

def is_subword_token(token: str) -> bool:
    """Tell whether ``token`` looks like a word fragment rather than a word.

    After stripping whitespace, a token is a fragment when it is shorter than
    two characters (this covers empty strings, lone punctuation marks and
    single digits) or when it begins with one of the markers ``##``, ``▁``,
    ``Ġ`` or ``_``. A falsy ``token`` is also reported as a fragment.
    """
    if not token:
        return True

    stripped = token.strip()
    if len(stripped) < 2:
        return True

    return stripped.startswith(("##", "▁", "Ġ", "_"))

def should_filter_explanation(expl: Dict[str, Any], span_th: float, u_th: float) -> bool:
    """Decide whether an explanation should be dropped.

    The word is read from ``"ambiguous_word"`` (falling back to ``"token"``).
    An explanation is dropped when that word is a sub-word fragment according
    to ``is_subword_token``, or when its span is below ``span_th`` and its
    uncertainty is below ``u_th`` at the same time. Any error while reading
    or converting the fields also drops it.

    Returns:
        ``True`` to filter the explanation out, ``False`` to keep it.
    """
    try:
        word = expl.get("ambiguous_word", expl.get("token", ""))
        span_value = float(expl.get("span", 0.0))
        uncertainty_value = float(expl.get("uncertainty", 0.0))

        if is_subword_token(str(word)):
            return True
        # Keep the explanation when at least one signal reaches its threshold.
        return span_value < span_th and uncertainty_value < u_th
    except Exception:
        return True

def force_english_bos(tokenizer, mbart_model) -> int:
    """Resolve the target-language token id and pin it on the model config.

    Resolution order:
      1. ``tokenizer.get_lang_id(code)`` for ``TARGET_LANGUAGE``, ``"en_XX"``,
         ``"en"`` and ``"eng"`` in turn; the first non-``None`` answer wins.
      2. ``tokenizer.lang_code_to_id[TARGET_LANGUAGE]`` when that mapping exists.
      3. The module-level ``M2M100_EN_TOKEN_ID``.

    When ``mbart_model`` has a ``config``, both ``forced_bos_token_id`` and
    ``decoder_start_token_id`` are overwritten with the resolved id. The
    config object is mutated in place and is not restored here.

    Returns:
        The resolved token id.
    """

    def _first_known_lang_id(candidates):
        # Codes the tokenizer rejects (by raising) are simply skipped.
        for code in candidates:
            try:
                lang_id = tokenizer.get_lang_id(code)
                if lang_id is not None:
                    return int(lang_id)
            except Exception:
                continue
        return None

    forced_id = None
    try:
        if hasattr(tokenizer, "get_lang_id"):
            forced_id = _first_known_lang_id([TARGET_LANGUAGE, "en_XX", "en", "eng"])

        if forced_id is None and hasattr(tokenizer, "lang_code_to_id"):
            try:
                forced_id = tokenizer.lang_code_to_id.get(TARGET_LANGUAGE, None)
                if forced_id is not None:
                    forced_id = int(forced_id)
            except Exception:
                pass
    except Exception:
        pass

    if forced_id is None:
        forced_id = M2M100_EN_TOKEN_ID

    if forced_id is not None and hasattr(mbart_model, "config"):
        try:
            bos_id = int(forced_id)
            mbart_model.config.forced_bos_token_id = bos_id
            # NOTE(review): this also replaces decoder_start_token_id with the
            # language id - confirm that matches how the model was trained.
            mbart_model.config.decoder_start_token_id = bos_id
        except Exception:
            if DEBUG_DISCOVERY or VERBOSE_LOGGING:
                print("[INF] Could not set forced BOS on mbart config")

    return forced_id

def safe_generate(mbart, input_ids=None, encoder_outputs=None, attention_mask=None, max_length=64, num_beams=2, **kwargs):
    """Call ``mbart.generate`` with fixed decoding defaults and one OOM retry.

    Args:
        mbart: Object exposing ``generate(**kwargs)``.
        input_ids: Source token ids; used only when ``encoder_outputs`` is None.
        encoder_outputs: Pre-computed encoder outputs. When given, generation
            is driven from them and ``input_ids=None`` is passed.
        attention_mask: Forwarded to ``generate`` unchanged.
        max_length: Maximum generated length for the first attempt.
        num_beams: Beam count for the first attempt.
        **kwargs: Extra ``generate`` arguments. They take precedence over the
            defaults set here (``do_sample=False``, ``early_stopping=True``,
            ``num_return_sequences=1``, ``repetition_penalty=1.2``,
            ``no_repeat_ngram_size=3``, ``length_penalty=1.0``).

    Returns:
        Whatever ``mbart.generate`` returns.

    Raises:
        RuntimeError: Re-raised unchanged when the message does not contain
            "out of memory". An out-of-memory error triggers one retry with a
            single beam and ``max_length`` capped at 48; a failure of that
            retry propagates.
    """

    def _generate(beams, length_cap):
        # One place for the decoding defaults; caller kwargs override them
        # instead of colliding as duplicate keyword arguments.
        options = {
            "attention_mask": attention_mask,
            "max_length": length_cap,
            "num_beams": beams,
            "do_sample": False,
            "early_stopping": True,
            "num_return_sequences": 1,
            "repetition_penalty": 1.2,
            "no_repeat_ngram_size": 3,
            "length_penalty": 1.0,
        }
        options.update(kwargs)
        if encoder_outputs is not None:
            return mbart.generate(input_ids=None, encoder_outputs=encoder_outputs, **options)
        return mbart.generate(input_ids=input_ids, **options)

    try:
        return _generate(num_beams, max_length)
    except RuntimeError as e:
        if "out of memory" not in str(e).lower():
            raise

    # The retry deliberately runs after the except block has ended: while an
    # exception is being handled it keeps the failed call's stack frames (and
    # the tensors they reference) alive, which can prevent the memory from
    # being released before the second attempt.
    if DEBUG_DISCOVERY:
        print("[INF] OOM during generation, reducing beam size...")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return _generate(1, min(max_length, 48))

def translate_with_explanations(
    model,
    tokenizer,
    input_sentence: str,
    device: Optional[torch.device] = None,
    span_threshold: Optional[float] = None,
    uncertainty_threshold: Optional[float] = None,
    track_stats: bool = True
) -> Dict[str, Any]:

    device = DEVICE if device is None else device
    span_th = REAL_AMB_SPAN_THRESHOLD if span_threshold is None else float(span_threshold)
    u_th = REAL_AMB_UNCERTAINTY_THRESHOLD if uncertainty_threshold is None else float(uncertainty_threshold)

    if DEBUG_DISCOVERY or VERBOSE_LOGGING:
        print(f"[INF] Starting inference")
        print(f"[INF] Input: {input_sentence[:60]}...")
        print(f"[INF] Thresholds: span={span_th:.2f}, uncertainty={u_th:.2f}")

    cleanup_vars = []
    dscd = None
    mbart = None
    encoder_hidden = None
    encoder_hidden_adjusted = None
    dscd_homographs = get_dscd_homographs(model)

    try:
        try:
            tokenizer.src_lang = SOURCE_LANGUAGE
        except Exception:
            pass

        token_to_word_idx_map, word_list = build_token_word_map_with_cell2(
            input_sentence,
            tokenizer,
            max_length=MAX_LENGTH
        )
        
        token_word_map_batch = [token_to_word_idx_map]

        enc = tokenizer(
            input_sentence,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LENGTH,
            add_special_tokens=True
        )
        enc = to_device_batch(enc, device)
        cleanup_vars.append(enc)

        model.eval()
        core = model.module if USE_MULTI_GPU and hasattr(model, "module") else model
        src_texts = [input_sentence]
        dscd_validated = False

        try:
            dscd = core.dscd if hasattr(core, "dscd") else None
            if dscd:
                lock = None
                if hasattr(dscd, "bufferlock"):
                    lock = dscd.bufferlock
                elif hasattr(dscd, "clusteringlock"):
                    lock = dscd.clusteringlock

                if lock:
                    with lock:
                        num_stores = len(dscd.prototypestores)
                        multisense = sum(
                            1 for store in dscd.prototypestores.values()
                            if hasattr(store, "centroids") and len(store.centroids) >= 2
                        )
                else:
                    num_stores = len(dscd.prototypestores)
                    multisense = sum(
                        1 for store in dscd.prototypestores.values()
                        if hasattr(store, "centroids") and len(store.centroids) >= 2
                    )

                if DEBUG_DISCOVERY or VERBOSE_LOGGING:
                    print(f"[INF] DSCD state: {num_stores} tokens, {multisense} multi-sense, {len(dscd_homographs)} discovered")

                if num_stores == 0:
                    print("[INF] WARNING: DSCD prototype stores are EMPTY")
                    if track_stats:
                        INFERENCE_STATS.dscd_empty_warnings += 1
                else:
                    dscd_validated = True
        except Exception as e:
            if DEBUG_DISCOVERY:
                print(f"[INF] DSCD validation failed: {e}")

        with torch.inference_mode():
            raw_dscd_out: Dict[str, Any] = {}

            try:
                if not hasattr(core, "mbart"):
                    raise RuntimeError("Model backend missing .mbart")

                mbart = core.mbart
                
                orig_cache = getattr(mbart.config, "use_cache", None) if hasattr(mbart, "config") else None
                if hasattr(mbart, "config"):
                    try:
                        mbart.config.use_cache = False
                    except Exception:
                        pass
                
                encoder_outputs_raw = mbart.model.encoder(
                    input_ids=enc.get("input_ids"),
                    attention_mask=enc.get("attention_mask")
                )
                cleanup_vars.append(encoder_outputs_raw)

                if hasattr(encoder_outputs_raw, "last_hidden_state"):
                    encoder_hidden = encoder_outputs_raw.last_hidden_state
                elif isinstance(encoder_outputs_raw, tuple):
                    encoder_hidden = encoder_outputs_raw[0]
                else:
                    encoder_hidden = encoder_outputs_raw
                
                cleanup_vars.append(encoder_hidden)

                if not isinstance(encoder_hidden, torch.Tensor) or encoder_hidden.dim() != 3:
                    raise RuntimeError(
                        f"Invalid encoder hidden: type={type(encoder_hidden)}, "
                        f"shape={encoder_hidden.shape if isinstance(encoder_hidden, torch.Tensor) else 'NA'}"
                    )

                if DEBUG_DISCOVERY:
                    print(f"[INF] Encoder hidden: {encoder_hidden.shape}")

                if hasattr(core, "forward_with_explanations"):
                    try:
                        raw_dscd_out = core.forward_with_explanations(
                            input_ids=enc.get("input_ids"),
                            attention_mask=enc.get("attention_mask"),
                            src_texts=src_texts,
                            token_word_map=token_word_map_batch
                        )
                    except TypeError:
                        raw_dscd_out = core.forward_with_explanations(
                            enc.get("input_ids"),
                            enc.get("attention_mask"),
                            src_texts
                        )
                else:
                    if DEBUG_DISCOVERY:
                        print("[INF] forward_with_explanations not found, using forward")
                    out = core.forward(
                        input_ids=enc.get("input_ids"),
                        attention_mask=enc.get("attention_mask"),
                        src_texts=src_texts,
                        labels=None
                    )
                    if isinstance(out, dict):
                        raw_dscd_out = extract_dscd_outputs(out)

                dscd_out = extract_dscd_outputs(raw_dscd_out)
                
                if isinstance(raw_dscd_out, dict) and "sense_augmented_embeddings" in raw_dscd_out:
                    encoder_hidden_adjusted = raw_dscd_out["sense_augmented_embeddings"]
                elif "h_augmented" in dscd_out:
                    encoder_hidden_adjusted = dscd_out["h_augmented"]
                else:
                    encoder_hidden_adjusted = encoder_hidden
                
                cleanup_vars.append(encoder_hidden_adjusted)

                if isinstance(encoder_hidden_adjusted, torch.Tensor):
                    if encoder_hidden_adjusted.shape != encoder_hidden.shape:
                        if DEBUG_DISCOVERY:
                            print("[INF] Shape mismatch, using original")
                        encoder_hidden_adjusted = encoder_hidden
                else:
                    encoder_hidden_adjusted = encoder_hidden

                if DEBUG_DISCOVERY:
                    print("[INF] DSCD forward completed")

            except Exception as e:
                if DEBUG_DISCOVERY or VERBOSE_LOGGING:
                    print(f"[INF] DSCD forward error: {e}")
                raw_dscd_out = {}
                encoder_hidden_adjusted = encoder_hidden if encoder_hidden is not None else None

            if mbart is None:
                raise RuntimeError("mbart is not available for generation")

            forced_id = force_english_bos(tokenizer, mbart)

            if hasattr(mbart, "config"):
                try:
                    mbart.config.use_cache = True
                except Exception:
                    pass

            try:
                if DEBUG_DISCOVERY:
                    print("[INF] Generating translation...")

                if encoder_hidden_adjusted is not None and isinstance(encoder_hidden_adjusted, torch.Tensor):
                    encoder_hidden_adjusted = encoder_hidden_adjusted.to(device)
                    encoder_outputs_for_decoder = BaseModelOutput(
                        last_hidden_state=encoder_hidden_adjusted,
                        hidden_states=getattr(encoder_outputs_raw, "hidden_states", None) if encoder_outputs_raw else None,
                        attentions=getattr(encoder_outputs_raw, "attentions", None) if encoder_outputs_raw else None
                    )

                    generated = safe_generate(
                        mbart,
                        encoder_outputs=encoder_outputs_for_decoder,
                        attention_mask=enc.get("attention_mask"),
                        max_length=min(MAX_LENGTH, 64),
                        num_beams=2,
                        pad_token_id=getattr(tokenizer, "pad_token_id", 1),
                        forced_bos_token_id=forced_id
                    )
                else:
                    generated = safe_generate(
                        mbart,
                        input_ids=enc.get("input_ids"),
                        attention_mask=enc.get("attention_mask"),
                        max_length=min(MAX_LENGTH, 64),
                        num_beams=2,
                        pad_token_id=getattr(tokenizer, "pad_token_id", 1),
                        forced_bos_token_id=forced_id
                    )

                cleanup_vars.append(generated)
                translation = tokenizer.decode(generated[0], skip_special_tokens=True) if generated is not None and len(generated) > 0 else ""

                if DEBUG_DISCOVERY:
                    print(f"[INF] Translation: {translation[:60]}...")

            finally:
                if hasattr(mbart, "config") and orig_cache is not None:
                    try:
                        mbart.config.use_cache = orig_cache
                    except Exception:
                        pass

        if DEBUG_DISCOVERY:
            print("[INF] Extracting explanations...")

        dscd_out = extract_dscd_outputs(raw_dscd_out)
        explanations_list = get_explanations_list(dscd_out)
        sentence_explanations = explanations_list[0] if isinstance(explanations_list, list) and len(explanations_list) > 0 else []

        if DEBUG_DISCOVERY:
            print(f"[INF] Raw explanations: {len(sentence_explanations)}")

        def is_real_ambiguity(e: Dict[str, Any]) -> bool:
            try:
                s = float(e.get("span", 0.0))
                u = float(e.get("uncertainty", 0.0))
                return s >= span_th or u >= u_th
            except Exception:
                return False

        real_amb_count = 0
        out_explanations: List[Dict[str, Any]] = []
        filtered_count = 0

        cleaned_explanations = []
        for ex in sentence_explanations:
            try:
                word = ex.get("ambiguous_word", ex.get("token", ""))
                if isinstance(word, str):
                    clean_word = word.replace("▁", "").replace("##", "").replace("Ġ", "").strip()
                    if clean_word and ex.get("ambiguous_word", None) is not None and clean_word != ex.get("ambiguous_word"):
                        ex["ambiguous_word"] = clean_word
                    elif clean_word and ex.get("ambiguous_word", None) is None:
                        ex["ambiguous_word"] = clean_word
                cleaned_explanations.append(ex)
            except Exception:
                cleaned_explanations.append(ex)

        sentence_explanations = cleaned_explanations

        quality_metrics = {
            "total_raw_explanations": len(sentence_explanations) if isinstance(sentence_explanations, list) else 0,
            "filtered_explanations": 0,
            "high_confidence_count": 0,
            "low_confidence_count": 0,
            "avg_confidence": 0.0,
            "avg_span": 0.0,
            "avg_uncertainty": 0.0
        }

        confidences: List[float] = []
        spans: List[float] = []
        uncertainties: List[float] = []

        if isinstance(sentence_explanations, list):
            for ex in sentence_explanations:
                try:
                    if should_filter_explanation(ex, span_th, u_th):
                        filtered_count += 1
                        continue

                    is_real = is_real_ambiguity(ex)
                    if is_real:
                        real_amb_count += 1

                    confidence = ex.get("confidence", None)
                    if confidence is None:
                        s = float(ex.get("span", 0.0))
                        u = float(ex.get("uncertainty", 0.0))
                        confidence = max(s, u)
                    confidence = float(confidence)

                    confidences.append(confidence)
                    spans.append(float(ex.get("span", 0.0)))
                    uncertainties.append(float(ex.get("uncertainty", 0.0)))

                    if confidence >= 0.65:
                        quality_metrics["high_confidence_count"] += 1
                    elif confidence <= 0.4:
                        quality_metrics["low_confidence_count"] += 1

                    out_explanations.append({
                        "ambiguous_word": ex.get("ambiguous_word", ex.get("token", "NA")),
                        "position": ex.get("position", ex.get("token_idx", "NA")),
                        "explanation": ex.get("explanation", "") or ex.get("explain", "") or "",
                        "uncertainty": float(ex.get("uncertainty", 0.0)),
                        "span": float(ex.get("span", 0.0)),
                        "confidence": confidence,
                        "is_real_amb": bool(is_real)
                    })
                except Exception:
                    continue

        quality_metrics["filtered_explanations"] = filtered_count

        if confidences:
            quality_metrics["avg_confidence"] = sum(confidences) / len(confidences)
            quality_metrics["avg_span"] = sum(spans) / len(spans)
            quality_metrics["avg_uncertainty"] = sum(uncertainties) / len(uncertainties)

        if DEBUG_DISCOVERY:
            print(f"[INF] Final {len(out_explanations)} explanations (filtered {filtered_count})")

        result = {
            "input_sentence": input_sentence,
            "translation": translation,
            "ambiguous_words_detected": int(real_amb_count),
            "explanations": out_explanations,
            "quality_metrics": quality_metrics,
            "dscd_validated": dscd_validated
        }

        if track_stats:
            INFERENCE_STATS.record_inference(result, dscd_homographs=dscd_homographs)

        return result

    except Exception as e:
        if DEBUG_DISCOVERY or VERBOSE_LOGGING:
            print(f"[INF] ERROR: {type(e).__name__} {str(e)[:200]}")
            try:
                traceback.print_exc()
            except Exception:
                pass

        error_result = {
            "input_sentence": input_sentence,
            "translation": "ERROR DURING TRANSLATION",
            "ambiguous_words_detected": 0,
            "explanations": [],
            "quality_metrics": {},
            "dscd_validated": False,
            "error": str(e)[:200]
        }

        if track_stats:
            INFERENCE_STATS.record_inference(error_result, dscd_homographs=dscd_homographs)

        return error_result

    finally:
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass
        try:
            if gc.isenabled():
                gc.collect()
        except Exception:
            pass

def demonstrate_system(model, tokenizer, sentences: Optional[List[str]] = None):
    """Translate a few sentences and print the translation plus explanations.

    Args:
        model: Model handed unchanged to ``translate_with_explanations``.
        tokenizer: Tokenizer handed unchanged to ``translate_with_explanations``.
        sentences: Source sentences to run; when ``None`` three built-in
            Bengali examples are used.

    Side effects:
        Resets ``INFERENCE_STATS`` before the run, prints one report per
        sentence, and prints the ``INFERENCE_STATS`` summary at the end.
    """
    banner = "=" * 80

    demo_inputs = sentences
    if demo_inputs is None:
        demo_inputs = [
            "আমি কল বন্ধ করেছি",
            "আগামীকাল আমি একটি বই কিনব",
            "পাতা পড়ে গেছে"
        ]

    print(banner)
    print("TATN DEMO: Translation + Explanations")
    print(banner)

    # Start from clean counters so the summary below covers only this demo run.
    INFERENCE_STATS.reset()

    for sentence in demo_inputs:
        print(f"\n> {sentence}")
        outcome = translate_with_explanations(model, tokenizer, sentence)
        print(f"Translation: {outcome.get('translation', '')}")
        print(f"Ambiguous words detected: {outcome.get('ambiguous_words_detected', 0)}")

        metrics = outcome.get("quality_metrics", {})
        if metrics:
            print(
                f"Quality: conf={metrics.get('avg_confidence', 0):.3f}, "
                f"high={metrics.get('high_confidence_count', 0)}, "
                f"low={metrics.get('low_confidence_count', 0)}"
            )

        if not outcome.get("explanations"):
            print("  No explanations.")
            continue

        for rank, item in enumerate(outcome["explanations"], 1):
            print(f"  {rank}. {item['ambiguous_word']} (pos={item['position']}, conf={item.get('confidence', 0):.3f})")
            # Explanation text is capped at 200 characters to keep the demo readable.
            print(f"     {item.get('explanation', '')[:200]}")

    print(banner)
    INFERENCE_STATS.print_summary()

def dscd_discovery_warmup(
    model,
    tokenizer,
    num_sents: int = 8000,
    batch_size: int = 64,
    max_len: Optional[int] = None
):
    """Run source sentences through the model so DSCD can collect data, then cluster and report.

    The function temporarily enables ``dscd.enable_training_clustering`` and
    enforces lower bounds on ``dscd.n_min`` / ``dscd.buffer_size`` (when those
    attributes exist), feeds up to ``num_sents`` sentences through the model in
    inference mode, calls ``synchronous_clustering_for_warmup(model)`` and prints
    a prototype / homograph summary. The original DSCD settings are restored in
    the ``finally`` block.

    Args:
        model: Model (optionally wrapped; ``.module`` is unwrapped when
            ``USE_MULTI_GPU`` is set) exposing a ``dscd`` attribute.
        tokenizer: Callable tokenizer returning a batch accepted by
            ``to_device_batch``.
        num_sents: Maximum number of sentences to process.
        batch_size: Sentences per forward pass.
        max_len: Tokenizer truncation length; defaults to ``MAX_LENGTH``.

    Returns:
        None. All results are reported through ``print``.
    """
    if max_len is None:
        max_len = MAX_LENGTH

    core = model.module if USE_MULTI_GPU and hasattr(model, "module") else model

    try:
        dscd = getattr(core, "dscd", None)
        if dscd is None:
            print("WARMUP: Model has no dscd component")
            return

        print("=" * 80)
        print("WARMUP: Starting DSCD discovery warmup")
        print("=" * 80)

        # Remember the current settings so the finally block can put them back.
        orig_enable = getattr(dscd, "enable_training_clustering", False)
        orig_nmin = getattr(dscd, "n_min", None)
        orig_buffer = getattr(dscd, "buffer_size", None)

        try:
            if hasattr(dscd, "enable_training_clustering"):
                dscd.enable_training_clustering = True
            if hasattr(dscd, "n_min"):
                dscd.n_min = max(3, int(getattr(dscd, "n_min", 5)))
            if hasattr(dscd, "buffer_size"):
                dscd.buffer_size = max(200, int(getattr(dscd, "buffer_size", 300)))
        except Exception:
            # Best effort: warmup still runs with whatever settings are in place.
            pass

        texts: List[str] = []
        try:
            if "load_and_preprocess_optimized" in globals():
                pairs = load_and_preprocess_optimized(num_sents)
                texts = [bn for bn, _ in pairs[:num_sents]]
            else:
                # No dataset loader in this kernel: cycle a handful of built-in sentences.
                base = [
                    "আমি কল বন্ধ করেছি",
                    "আগামীকাল আমি একটি বই কিনব",
                    "পাতা পড়ে গেছে",
                    "তিনি ব্যাংকে গিয়েছেন",
                    "নদীর ধারে বসে আছি"
                ]
                while len(texts) < num_sents:
                    texts.extend(base)
                texts = texts[:num_sents]
        except Exception as e:
            # Previously swallowed silently, which made an empty warmup look intentional.
            print(f"WARMUP: Failed to load sentences: {str(e)[:100]}")
            texts = []

        processed = 0
        core.eval()
        print(f"WARMUP: Processing {len(texts)} sentences (batch={batch_size})...")
        start_time = time.time()
        last_print = start_time

        with torch.inference_mode():
            for i in range(0, len(texts), batch_size):
                batch = texts[i:i+batch_size]
                try:
                    enc = tokenizer(
                        batch,
                        return_tensors="pt",
                        padding=True,
                        truncation=True,
                        max_length=max_len,
                        add_special_tokens=False
                    )
                    enc = to_device_batch(enc, DEVICE)

                    if hasattr(core, "forward_with_explanations"):
                        core.forward_with_explanations(
                            input_ids=enc.get("input_ids"),
                            attention_mask=enc.get("attention_mask"),
                            src_texts=batch
                        )
                    else:
                        core.mbart.model.encoder(
                            input_ids=enc.get("input_ids"),
                            attention_mask=enc.get("attention_mask")
                        )

                    processed += len(batch)

                    current_time = time.time()
                    # Report every 10th batch or at least every 5 seconds. The old
                    # `(i + batch_size) % 10` test depended on batch_size arithmetic.
                    batch_number = i // batch_size + 1
                    if batch_number % 10 == 0 or (current_time - last_print) >= 5:
                        elapsed = current_time - start_time
                        rate = processed / elapsed if elapsed > 0 else 0
                        eta = (len(texts) - processed) / rate if rate > 0 else 0
                        # `.1%` already multiplies by 100; the extra `*100` printed 5000.0% for 50%.
                        print(
                            f"WARMUP: {processed}/{len(texts)} "
                            f"({processed/len(texts):.1%}) @ {rate:.1f} sents/s "
                            f"ETA: {eta:.0f}s"
                        )
                        last_print = current_time

                    del enc
                except Exception as e:
                    print(f"WARMUP: Batch {i}/{batch_size} failed: {str(e)[:100]}")
                    continue

        total_time = time.time() - start_time
        # Guard against a zero-length interval (empty input or a coarse clock).
        throughput = processed / total_time if total_time > 0 else 0.0
        print(f"WARMUP: Completed in {total_time:.1f}s ({throughput:.1f} sents/s)")
        print("-" * 80)

        print("WARMUP: Running synchronous clustering...")
        clustering_result = synchronous_clustering_for_warmup(model)

        if clustering_result.get('success'):
            print(f"WARMUP: Clustered {clustering_result.get('clustered', 0)} tokens")
            print(f"WARMUP: Discovered {clustering_result.get('discovered', 0)} new homographs")
            print(f"WARMUP: Total homographs: {clustering_result.get('total_homographs', 0)}")
        else:
            print(f"WARMUP: Clustering incomplete: {clustering_result.get('reason', 'unknown')}")

        print("-" * 80)

        try:
            lock = None
            if hasattr(dscd, "bufferlock"):
                lock = dscd.bufferlock
            elif hasattr(dscd, "clusteringlock"):
                lock = dscd.clusteringlock

            # Copy the store mapping (under the lock when one exists) before reading it.
            if lock:
                with lock:
                    stores = dict(dscd.prototypestores)
            else:
                stores = dict(dscd.prototypestores)

            num_types = len(stores)
            total_protos = sum(store.size() for store in stores.values()) if stores else 0
            multi = sum(1 for store in stores.values() if store.size() >= 2) if stores else 0

            print("WARMUP: Summary")
            print(f"  - Token types: {num_types}")
            print(f"  - Total prototypes: {total_protos}")
            print(f"  - Multi-sense tokens: {multi}")
            if num_types > 0:
                print(f"  - Multi-sense ratio: {multi/num_types:.1%}")

            dscd_homographs = get_dscd_homographs(model)
            print(f"\nWARMUP: Discovered Homographs: {len(dscd_homographs)}")
            if dscd_homographs:
                print(f"  Sample: {list(dscd_homographs)[:10]}")

            reference_found = dscd_homographs.intersection(HOMOGRAPH_REFERENCE_LIST)
            reference_total = len(HOMOGRAPH_REFERENCE_LIST)
            # An empty reference list means 0% coverage rather than a ZeroDivisionError.
            coverage = len(reference_found) / reference_total if reference_total > 0 else 0.0
            print(f"\nWARMUP: Reference List Comparison")
            print(f"  - Reference list: {reference_total} words")
            print(f"  - Found in DSCD: {len(reference_found)}")
            print(f"  - Coverage: {coverage:.1%}")

            if num_types == 0:
                print("WARMUP: CRITICAL - NO PROTOTYPES CREATED")
            elif coverage < 0.2:
                print("WARMUP: WARNING - <20% reference coverage")
            else:
                print("WARMUP: SUCCESS")

        except Exception as e:
            print(f"WARMUP: Validation failed: {e}")

    finally:
        # Restore the DSCD settings changed above, even on error or early return.
        try:
            if dscd is not None:
                if hasattr(dscd, "enable_training_clustering"):
                    dscd.enable_training_clustering = orig_enable
                if hasattr(dscd, "n_min") and orig_nmin is not None:
                    dscd.n_min = orig_nmin
                if hasattr(dscd, "buffer_size") and orig_buffer is not None:
                    dscd.buffer_size = orig_buffer
        except Exception:
            pass

        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass

        try:
            if gc.isenabled():
                gc.collect()
        except Exception:
            pass

    print("=" * 80)

def load_checkpoint_for_resume(
    model: torch.nn.Module,
    optimizer,
    checkpoint_path: str
) -> Tuple[bool, int, int, float]:
    """Restore model, optimizer and DSCD state from a checkpoint file.

    Args:
        model: Model to load into; ``.module`` is unwrapped when
            ``USE_MULTI_GPU`` is set and the attribute exists.
        optimizer: Optimizer to restore, or ``None`` to skip optimizer state.
        checkpoint_path: Path of the checkpoint file.

    Returns:
        ``(loaded, epoch, step, avg_loss)``. ``loaded`` is ``False`` (with zeros)
        when the file is missing, cannot be read, or is not a dict. Failures
        while restoring the individual model / optimizer / DSCD parts are
        printed but do not change ``loaded``.
    """
    if not os.path.exists(checkpoint_path):
        print(f"CHECKPOINT: Not found: {checkpoint_path}")
        return False, 0, 0, 0.0

    try:
        # SECURITY: torch.load unpickles the file; only load checkpoints from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location=DEVICE)
    except Exception as e:
        print(f"CHECKPOINT: Load failed: {e}")
        return False, 0, 0, 0.0

    if not isinstance(ckpt, dict):
        # Everything below reads the checkpoint through dict lookups.
        print(f"CHECKPOINT: Load failed: unsupported checkpoint type {type(ckpt).__name__}")
        return False, 0, 0, 0.0

    core = model.module if USE_MULTI_GPU and hasattr(model, "module") else model
    state = ckpt.get("model_state_dict", ckpt)

    # With strict=False a checkpoint saved from a wrapped model ("module."-prefixed
    # keys) does not raise; the keys are merely reported as unexpected and nothing
    # is loaded. So retry with the prefix stripped when that is detected, not only
    # when load_state_dict raises.
    retry_without_prefix = False
    try:
        load_result = core.load_state_dict(state, strict=False)
        unexpected = getattr(load_result, "unexpected_keys", None) or []
        retry_without_prefix = any(str(k).startswith("module.") for k in unexpected)
    except Exception as e:
        print(f"CHECKPOINT: model.load_state_dict failed: {e}")
        retry_without_prefix = True

    if retry_without_prefix and isinstance(state, dict):
        prefix = "module."
        try:
            # Strip only the leading prefix; str.replace would also remove inner occurrences.
            new_state = {
                (k[len(prefix):] if isinstance(k, str) and k.startswith(prefix) else k): v
                for k, v in state.items()
            }
            core.load_state_dict(new_state, strict=False)
        except Exception as e:
            print(f"CHECKPOINT: model.load_state_dict retry failed: {e}")

    try:
        if optimizer is not None and "optimizer_state_dict" in ckpt:
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    except Exception as e:
        print(f"CHECKPOINT: optimizer.load_state_dict failed: {e}")

    try:
        if "dscd_state" in ckpt and ckpt["dscd_state"]:
            dscd_state = ckpt["dscd_state"]
            print("CHECKPOINT: Restoring DSCD...")
            dscd = core.dscd if hasattr(core, "dscd") else None
            if dscd and hasattr(dscd, "load_state_dict"):
                lock = None
                if hasattr(dscd, "bufferlock"):
                    lock = dscd.bufferlock
                elif hasattr(dscd, "clusteringlock"):
                    lock = dscd.clusteringlock

                # Restore under the DSCD lock when one is exposed.
                if lock:
                    with lock:
                        dscd.load_state_dict(dscd_state)
                else:
                    dscd.load_state_dict(dscd_state)

                num_tokens = len(dscd.prototypestores)
                total_protos = sum(store.size() for store in dscd.prototypestores.values())
                multisense = sum(1 for store in dscd.prototypestores.values() if store.size() >= 2)

                print(f"CHECKPOINT: DSCD restored")
                print(f"  - Tokens: {num_tokens}")
                print(f"  - Prototypes: {total_protos}")
                print(f"  - Multi-sense: {multisense}")

                if num_tokens == 0:
                    print("CHECKPOINT: WARNING - DSCD state empty - consider running warmup")
            else:
                print("CHECKPOINT: Model has no dscd.load_state_dict")
        else:
            print("CHECKPOINT: No DSCD state in checkpoint")
    except Exception as e:
        print(f"CHECKPOINT: DSCD restore failed: {e}")

    # Several key spellings are accepted for the training progress counters.
    epoch = int(ckpt.get("epochs_trained", ckpt.get("epoch", 0)))
    step = int(ckpt.get("global_steps", ckpt.get("global_step", ckpt.get("step", 0))))
    avg_loss = float(ckpt.get("final_train_loss", ckpt.get("avg_epoch_loss", ckpt.get("avg_loss", 0.0))))

    print(f"CHECKPOINT: Loaded (epoch={epoch}, step={step}, loss={avg_loss:.6f})")
    return True, epoch, step, avg_loss

# Cell 8 load confirmation: one print call emitting the same banner lines.
print("\n".join([
    "=" * 80,
    "Cell 8: Inference pipeline ready - COMPLETE FIXED VERSION",
    "=" * 80,
    "KEY FIXES:",
    "  ✓ Word map format aligned with Cell 6 TATN (Dict[int, int] + List[str])",
    "  ✓ Uses Cell 2 reconstruct_word_spans() when available",
    "  ✓ Synchronous clustering added to warmup",
    "  ✓ Proper word-level aggregation support",
    "=" * 80,
]))


Cell 8: Inference pipeline ready - COMPLETE FIXED VERSION
KEY FIXES:
  ✓ Word map format aligned with Cell 6 TATN (Dict[int, int] + List[str])
  ✓ Uses Cell 2 reconstruct_word_spans() when available
  ✓ Synchronous clustering added to warmup
  ✓ Proper word-level aggregation support


In [12]:
# ==============================================================================
# CELL 9: COMPREHENSIVE TESTING & EVALUATION (WITH BLEU/COMET) - COMPLETE FIXED
# ==============================================================================
from typing import Dict, List, Tuple, Optional, Any
import torch
import traceback
import time
import functools
import gc
from collections import defaultdict

# ------------------------------------------------------------------------------
# Cell-local configuration.
# Each setting below is copied from a notebook-level global into an
# underscore-prefixed name. If the global is undefined (NameError) or cannot be
# converted, a local default is used instead.
# ------------------------------------------------------------------------------

# Multi-GPU flag; the fallback is "CUDA is available and more than one device is visible".
try:
    _USE_MULTI_GPU = bool(USE_MULTI_GPU)
except (NameError, TypeError):
    _USE_MULTI_GPU = torch.cuda.is_available() and torch.cuda.device_count() > 1

# Source language code; fallback "bn".
try:
    _SOURCE_LANGUAGE = str(SOURCE_LANGUAGE)
except (NameError, TypeError):
    _SOURCE_LANGUAGE = "bn"

# Target language code; fallback "en".
try:
    _TARGET_LANGUAGE = str(TARGET_LANGUAGE)
except (NameError, TypeError):
    _TARGET_LANGUAGE = "en"

# Torch device; fallback is CUDA when available, otherwise CPU.
try:
    _DEVICE = DEVICE
except (NameError, TypeError):
    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Logging / debug switches; all default to off.
try:
    _VERBOSE_LOGGING = bool(VERBOSE_LOGGING)
except (NameError, TypeError):
    _VERBOSE_LOGGING = False

try:
    _DEBUG_DISCOVERY = bool(DEBUG_DISCOVERY)
except (NameError, TypeError):
    _DEBUG_DISCOVERY = False

try:
    _DEBUG_TIMING = bool(DEBUG_TIMING)
except (NameError, TypeError):
    _DEBUG_TIMING = False

# Span threshold handed to the translation call; fallback 0.05.
try:
    _SPAN_THRESHOLD = float(SPAN_THRESHOLD)
except (NameError, ValueError, TypeError):
    _SPAN_THRESHOLD = 0.05

# Uncertainty threshold is read from the TAU_LOW global; fallback 0.15.
try:
    _UNCERTAINTY_THRESHOLD = float(TAU_LOW)
except (NameError, ValueError, TypeError):
    _UNCERTAINTY_THRESHOLD = 0.15

# Reference homograph set (lower-cased strings). Read from HOMOGRAPH_REFERENCE_LIST_BN;
# when that global is missing, a small built-in Bengali list is used.
try:
    _HOMOGRAPH_REFERENCE_LIST = set(str(w).lower() for w in HOMOGRAPH_REFERENCE_LIST_BN)
except (NameError, TypeError):
    _HOMOGRAPH_REFERENCE_LIST = {
        "কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা", "বার", "হার", "তারা",
        "পানি", "দল", "বাজার", "নাম", "কথা", "বই", "ঘর", "মন", "হাত"
    }
    _HOMOGRAPH_REFERENCE_LIST = set(str(w).lower() for w in _HOMOGRAPH_REFERENCE_LIST)

# Optional metric dependencies: each flag flips to True only if the import succeeds.
_SACREBLEU_AVAILABLE = False
_COMET_AVAILABLE = False

try:
    import sacrebleu
    _SACREBLEU_AVAILABLE = True
except ImportError:
    pass

try:
    from comet import download_model, load_from_checkpoint
    _COMET_AVAILABLE = True
except ImportError:
    pass

def _resolve_dscd(model: torch.nn.Module):
    core = model.module if hasattr(model, "module") else model
    dscd = getattr(core, "dscd", None)
    return core, dscd

def _resolve_dscd_lock(dscd):
    if dscd is None:
        return None
    for name in ("bufferlock", "buffer_lock", "clusteringlock", "clustering_lock"):
        lock = getattr(dscd, name, None)
        if lock is not None:
            return lock
    return None

def _resolve_prototype_stores(dscd) -> Dict[Any, Any]:
    if dscd is None:
        return {}
    for name in ("prototypestores", "prototype_stores"):
        stores = getattr(dscd, name, None)
        if isinstance(stores, dict):
            return stores
        if stores is not None:
            try:
                return dict(stores)
            except Exception:
                pass
    return {}

def _store_size(store) -> int:
    if store is None:
        return 0
    try:
        size_attr = getattr(store, "size", None)
        if callable(size_attr):
            return int(size_attr())
        if isinstance(size_attr, int):
            return int(size_attr)
    except Exception:
        pass
    try:
        cents = getattr(store, "centroids", None)
        return int(len(cents)) if cents is not None else 0
    except Exception:
        return 0

def _clean_token_for_set(x: Any) -> str:
    return (
        str(x)
        .replace(" ", "")
        .replace("Ġ", "")
        .replace("##", "")
        .replace("▁", "")
        .strip()
        .lower()
    )

def _get_cluster_count(model: torch.nn.Module) -> int:
    try:
        _, dscd = _resolve_dscd(model)
        if dscd is None:
            return 0

        lock = _resolve_dscd_lock(dscd)
        if lock:
            with lock:
                stores = _resolve_prototype_stores(dscd)
                return len(stores)
        else:
            stores = _resolve_prototype_stores(dscd)
            return len(stores)
    except Exception:
        return 0

def _get_dscd_homographs(model: torch.nn.Module) -> set:
    try:
        _, dscd = _resolve_dscd(model)
        if dscd is None:
            return set()

        if hasattr(dscd, "get_discovered_homographs"):
            try:
                return set(_clean_token_for_set(w) for w in dscd.get_discovered_homographs())
            except Exception:
                pass

        homographs = set()
        lock = _resolve_dscd_lock(dscd)

        if lock:
            with lock:
                prototype_stores = _resolve_prototype_stores(dscd)
                items = list(prototype_stores.items())
        else:
            prototype_stores = _resolve_prototype_stores(dscd)
            items = list(prototype_stores.items())

        for token, store in items:
            try:
                if _store_size(store) >= 2:
                    homographs.add(_clean_token_for_set(token))
            except Exception:
                continue

        return homographs
    except Exception:
        return set()

def _print_top_clusters(model: torch.nn.Module, top_n: int = 5):
    try:
        _, dscd = _resolve_dscd(model)
        if dscd is None:
            return

        lock = _resolve_dscd_lock(dscd)
        if lock:
            with lock:
                prototype_stores = dict(_resolve_prototype_stores(dscd))
        else:
            prototype_stores = dict(_resolve_prototype_stores(dscd))

        if not prototype_stores:
            print("[CLUSTER] No clusters found yet")
            return

        cluster_info = []
        for token, store in prototype_stores.items():
            try:
                total_count = sum(getattr(store, "counts", []))
            except Exception:
                total_count = 0
            try:
                n_protos = len(getattr(store, "centroids", []))
            except Exception:
                n_protos = 0
            cluster_info.append({
                "token": token,
                "count": total_count,
                "protos": n_protos,
                "mu": getattr(store, "mu", 0.0),
                "tau": getattr(store, "tau", 0.0)
            })

        cluster_info.sort(key=lambda x: x["count"], reverse=True)

        print(f"\n[CLUSTER] Top {min(top_n, len(cluster_info))} clusters:")
        print("-" * 90)
        print(f"{'Rank':<6}{'Token':<15}{'Count':<12}{'Protos':<10}{'Mu':<15}{'Tau':<12}")
        print("-" * 90)

        for rank, info in enumerate(cluster_info[:top_n], 1):
            token_str = str(info["token"])
            token_display = token_str[:12] if len(token_str) > 12 else token_str
            print(
                f"{rank:<6}{token_display:<15}{info['count']:<12}{info['protos']:<10}"
                f"{info['mu']:<15.6f}{info['tau']:<12.6f}"
            )

        print("-" * 90)

    except Exception as e:
        if _DEBUG_DISCOVERY:
            print(f"[CLUSTER] Error: {str(e)[:100]}")

def _timed(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if _DEBUG_TIMING:
            start = time.time()
            result = func(*args, **kwargs)
            elapsed = time.time() - start
            print(f"[TIMING] {func.__name__}: {elapsed:.2f}s")
            return result
        else:
            return func(*args, **kwargs)
    return wrapper

def compute_bleu_score(predictions: List[str], references: List[str]) -> float:
    """Compute corpus-level BLEU with sacrebleu.

    Args:
        predictions: Hypothesis sentences, one per source sentence.
        references: Reference sentences aligned index-by-index with
            ``predictions`` (exactly one reference per hypothesis).

    Returns:
        The sacrebleu corpus BLEU score. 0.0 is returned when sacrebleu is not
        installed, the inputs are empty or differ in length, or scoring fails.
    """
    if not _SACREBLEU_AVAILABLE:
        print("[BLEU] sacrebleu not available, install: pip install sacrebleu")
        return 0.0

    if not predictions or not references or len(predictions) != len(references):
        return 0.0

    try:
        predictions_clean = [p.strip() if p else "" for p in predictions]
        reference_stream = [r.strip() if r else "" for r in references]

        # sacrebleu takes a list of reference *streams*, each as long as the
        # hypothesis list. With a single reference per sentence that is one stream
        # wrapping all references, not one single-item list per sentence.
        bleu = sacrebleu.corpus_bleu(predictions_clean, [reference_stream])
        return float(bleu.score)
    except Exception as e:
        if _DEBUG_DISCOVERY:
            print(f"[BLEU] Calculation error: {e}")
        return 0.0

def compute_comet_score(
    sources: List[str],
    predictions: List[str],
    references: List[str],
    model_name: str = "Unbabel/wmt22-comet-da"
) -> float:
    """Compute the system-level COMET score for a set of translations.

    Args:
        sources: Source sentences.
        predictions: Machine translations aligned with ``sources``.
        references: Reference translations aligned with ``sources``.
        model_name: COMET model identifier passed to ``download_model``.

    Returns:
        ``output.system_score`` from the COMET model's ``predict`` call. 0.0 is
        returned when COMET is not installed, any input is empty, the inputs
        differ in length, or scoring fails.
    """
    if not _COMET_AVAILABLE:
        print("[COMET] comet-ml not available, install: pip install unbabel-comet")
        return 0.0

    if not (sources and predictions and references):
        return 0.0

    if not (len(sources) == len(predictions) == len(references)):
        return 0.0

    try:
        # The model is downloaded (if needed) and loaded on every call.
        scorer = load_from_checkpoint(download_model(model_name))

        samples = []
        for src, pred, ref in zip(sources, predictions, references):
            samples.append({"src": src, "mt": pred, "ref": ref})

        gpu_count = 1 if torch.cuda.is_available() else 0
        output = scorer.predict(samples, batch_size=8, gpus=gpu_count)
        return float(output.system_score)
    except Exception as e:
        if _DEBUG_DISCOVERY:
            print(f"[COMET] Calculation error: {e}")
        return 0.0

def batch_translate_with_explanations(
    model,
    tokenizer,
    test_pairs: List[Tuple[str, str]],
    batch_size: int = 32,
    max_samples: Optional[int] = None,
    compute_bleu: bool = True,
    compute_comet: bool = False
) -> Dict[str, Any]:
    """Translate ``(source, reference)`` pairs one sentence at a time and score the output.

    Each source is passed to ``translate_with_explanations`` (when that function
    exists in the notebook globals). A failed sentence contributes an empty
    prediction and an empty explanation list, so ``predictions`` always stays
    aligned index-by-index with ``sources``.

    Args:
        model: Model to evaluate; ``model.eval()`` is called before translating.
        tokenizer: Tokenizer passed through to the translation function.
        test_pairs: ``(source, reference)`` tuples.
        batch_size: Slice size used when walking the sources; translation itself
            is still per sentence.
        max_samples: When truthy, only the first ``max_samples`` pairs are used.
        compute_bleu: Compute BLEU if sacrebleu is available.
        compute_comet: Compute COMET if the comet package is available.

    Returns:
        Dict with ``sources``, ``predictions``, ``references``, ``bleu_score``,
        ``comet_score``, ``total_explanations``, ``total_ambiguous``,
        ``inference_stats`` and ``elapsed_time``.
    """
    print("\n" + "=" * 80)
    print("BATCH TRANSLATION WITH EXPLANATIONS")
    print("=" * 80)

    if max_samples and len(test_pairs) > max_samples:
        test_pairs = test_pairs[:max_samples]

    sources = [src for src, _ in test_pairs]
    references = [ref for _, ref in test_pairs]

    predictions = []
    all_explanations = []
    total_ambiguous = 0

    print(f"[BATCH] Translating {len(sources)} sentences (batch={batch_size})...")
    start_time = time.time()

    model.eval()

    # Loop-invariant lookup: None means the inference cell has not been run.
    translate_fn = globals().get("translate_with_explanations")

    # Report each time another 100 sentences are done. The previous
    # `(i + batch_size) % 100 == 0` test depended on batch_size arithmetic.
    progress_every = 100
    next_report = progress_every

    for i in range(0, len(sources), batch_size):
        batch_sources = sources[i:i+batch_size]

        for src in batch_sources:
            translation = ""
            explanations = []
            ambiguous = 0

            if translate_fn is not None:
                try:
                    result = translate_fn(
                        model,
                        tokenizer,
                        src,
                        device=_DEVICE,
                        span_threshold=_SPAN_THRESHOLD,
                        uncertainty_threshold=_UNCERTAINTY_THRESHOLD,
                        track_stats=True
                    )
                    # Read every field before appending anything, so a malformed
                    # result cannot leave a half-recorded sentence behind.
                    translation = result.get("translation", "") or ""
                    explanations = result.get("explanations", []) or []
                    ambiguous = int(result.get("ambiguous_words_detected", 0) or 0)
                except Exception as e:
                    if _DEBUG_DISCOVERY:
                        print(f"[BATCH] Translation error: {e}")
                    translation, explanations, ambiguous = "", [], 0

            predictions.append(translation)
            all_explanations.append(explanations)
            total_ambiguous += ambiguous

        if len(predictions) >= next_report:
            elapsed = time.time() - start_time
            rate = len(predictions) / elapsed if elapsed > 0 else 0
            print(f"[BATCH] {len(predictions)}/{len(sources)} @ {rate:.1f} sents/s")
            next_report = (len(predictions) // progress_every + 1) * progress_every

    elapsed = time.time() - start_time
    # Guard against a zero-length interval (empty input or a coarse clock).
    throughput = len(sources) / elapsed if elapsed > 0 else 0.0
    print(f"[BATCH] Completed in {elapsed:.1f}s ({throughput:.1f} sents/s)")

    bleu_score = 0.0
    if compute_bleu and _SACREBLEU_AVAILABLE:
        print("[BATCH] Computing BLEU...")
        bleu_score = compute_bleu_score(predictions, references)
        print(f"[BATCH] BLEU: {bleu_score:.2f}")

    comet_score = 0.0
    if compute_comet and _COMET_AVAILABLE:
        print("[BATCH] Computing COMET...")
        comet_score = compute_comet_score(sources, predictions, references)
        print(f"[BATCH] COMET: {comet_score:.4f}")

    total_explanations = sum(len(exps) for exps in all_explanations)

    inference_stats_summary = {}
    if "INFERENCE_STATS" in globals():
        try:
            inference_stats_summary = INFERENCE_STATS.get_summary()
        except Exception:
            # The stats summary is optional; leave it empty if it cannot be produced.
            pass

    print("=" * 80)

    return {
        "sources": sources,
        "predictions": predictions,
        "references": references,
        "bleu_score": bleu_score,
        "comet_score": comet_score,
        "total_explanations": total_explanations,
        "total_ambiguous": total_ambiguous,
        "inference_stats": inference_stats_summary,
        "elapsed_time": elapsed,
    }

def evaluate_on_test_set(
    model,
    tokenizer,
    test_size: int = 1000,
    batch_size: int = 32,
    compute_bleu: bool = True,
    compute_comet: bool = False
) -> Dict[str, Any]:
    """Load a test split with the notebook's dataset loader and evaluate the model on it.

    ``load_and_preprocess_optimized(test_size * 2)`` is called and the last
    ``test_size`` pairs of its result are used as the test set, which is then
    handed to ``batch_translate_with_explanations``.

    Returns:
        The result dict from ``batch_translate_with_explanations``, or
        ``{"error": ...}`` when the loader is missing, raises, or yields no pairs.
    """
    rule = "=" * 80
    print("\n" + rule)
    print("EVALUATE ON TEST SET")
    print(rule)

    held_out = []

    try:
        if "load_and_preprocess_optimized" not in globals():
            print("[EVAL] load_and_preprocess_optimized not available")
            return {"error": "dataset_loader_not_available"}

        print(f"[EVAL] Loading {test_size} test pairs from dataset...")
        loaded_pairs = load_and_preprocess_optimized(test_size * 2)
        # The tail of the loaded data serves as the test split.
        held_out = loaded_pairs[-test_size:]
        print(f"[EVAL] Loaded {len(held_out)} pairs")
    except Exception as e:
        print(f"[EVAL] Failed to load test set: {e}")
        return {"error": str(e)}

    if not held_out:
        print("[EVAL] No test pairs loaded")
        return {"error": "no_test_pairs"}

    result = batch_translate_with_explanations(
        model,
        tokenizer,
        held_out,
        batch_size=batch_size,
        max_samples=test_size,
        compute_bleu=compute_bleu,
        compute_comet=compute_comet
    )

    summary_lines = [
        "\n[EVAL] Test Set Evaluation Complete",
        f"  BLEU: {result.get('bleu_score', 0):.2f}",
    ]
    if compute_comet:
        summary_lines.append(f"  COMET: {result.get('comet_score', 0):.4f}")
    summary_lines.append(f"  Explanations: {result.get('total_explanations', 0)}")
    summary_lines.append(f"  Ambiguous: {result.get('total_ambiguous', 0)}")
    summary_lines.append(rule)
    for line in summary_lines:
        print(line)

    return result

@torch.inference_mode()
@_timed
def comprehensive_post_training_testing(
    model: torch.nn.Module,
    tokenizer,
    run_warmup: bool = True,
    compare_baseline: bool = False,
    baseline_metrics: Optional[Dict[str, Any]] = None
) -> Dict[str, Any]:
    """Run the 13 hard-coded Bengali test sentences and print/return a health report.

    Each sentence is passed to ``translate_with_explanations``; the function
    prints the translation, a word-level Jaccard similarity against the
    expected English text and every returned explanation, while accumulating
    confidence/span/uncertainty statistics, homograph coverage, per-test
    status and timings. Afterwards it prints a summary (translation success,
    explanation quality, DSCD prototype counts, ASBN/TRG statistics, timing,
    errors, optional baseline deltas, health warnings) and releases cached
    CUDA memory.

    The whole body runs under ``torch.inference_mode()`` (including the
    optional warmup call) and is wrapped by ``_timed``, which is defined
    elsewhere in the notebook.

    Args:
        model: Model to evaluate. When ``_USE_MULTI_GPU`` is truthy and the
            object has a ``module`` attribute, that inner module is used.
            ``eval()`` is called on the resolved module.
        tokenizer: Passed to ``translate_with_explanations``; its ``src_lang``
            is set to ``_SOURCE_LANGUAGE`` on a best-effort basis.
        run_warmup: When True, and the DSCD prototype stores are empty, and
            ``dscd_discovery_warmup`` exists in globals, that warmup is run
            with ``num_sents=4000, batch_size=64``.
        compare_baseline: When True (and ``baseline_metrics`` is a non-empty
            dict) print deltas against the baseline.
        baseline_metrics: Dict read with the keys ``success_rate_pct``,
            ``total_explanations``, ``quality_metrics["avg_confidence"]`` and
            ``homograph_tracking["explained_from_dscd_rate"]`` -- i.e. the
            same layout as this function's own return value.

    Returns:
        Dict with the keys ``total_tests``, ``successful_translations``,
        ``success_rate_pct``, ``total_explanations``, ``total_high_span``,
        ``total_real_ambiguous``, ``dscd_stats``, ``quality_metrics``,
        ``homograph_tracking``, ``error_tracking``, ``asbn_stats``,
        ``trg_stats``, ``discovery_validated`` and ``timing_metrics``.
        ``homograph_tracking`` holds ``set`` and ``defaultdict`` values, so
        the result cannot be passed to ``json.dumps`` as-is.
    """
    print("\n" + "=" * 80)
    print("COMPREHENSIVE POST-TRAINING EVALUATION (Pure Data-Driven)")
    print("=" * 80)

    # Each entry: (Bengali source, expected English, description, expected homographs).
    test_sentences: List[Tuple[str, str, str, List[str]]] = [
        ("আমি কল বন্ধ করেছি।", "I turned off the tap", "কল = tap/call", ["কল"]),
        ("কাল আমি বই কিনব।", "Tomorrow I will buy a book", "কাল = tomorrow/yesterday", ["কাল"]),
        ("পাতা ঝরে পড়েছে।", "The leaf has fallen", "পাতা = leaf/page", ["পাতা"]),
        ("তিনি ব্যাংকে গেছেন।", "He went to the bank", "ব্যাংক = bank/embankment", ["ব্যাংক"]),
        ("ফল খুব সুস্বাদু।", "The fruit is delicious", "ফল = fruit/result", ["ফল"]),
        ("মাথা ব্যথা করছে।", "Head is aching", "মাথা = head/top", ["মাথা"]),
        ("কল থেকে কল এসেছে।", "A call came from the tap", "Multiple কল", ["কল"]),
        ("কালকে কাল মেঘ দেখা গেছে।", "Yesterday black clouds were seen", "Multiple কাল", ["কাল"]),
        ("আজ ভাল আবহাওয়া।", "Weather is good today", "Simple", []),
        ("আমি ভালো আছি।", "I am fine", "Simple", []),
        ("সে খুব মিষ্টি কথা বলে।", "She speaks sweetly", "Simple", []),
        ("এটা আমার বই।", "This is my book", "Simple", []),
        ("তিনি ব্যাংকে কাজ করেন এবং ব্যাংকে বসে থাকেন।",
         "He works at the bank and sits on the embankment",
         "Long with multiple", ["ব্যাংক"]),
    ]

    # Unwrap a multi-GPU wrapper (anything exposing `.module`) and switch to eval mode.
    core_model = model.module if (_USE_MULTI_GPU and hasattr(model, "module")) else model
    core_model.eval()

    # Accumulators for explanation confidence / span / uncertainty values.
    quality_metrics = {
        "total_confidence": 0.0,
        "confidence_samples": 0,
        "high_confidence_count": 0,
        "medium_confidence_count": 0,
        "low_confidence_count": 0,
        "confidences": [],
        "spans": [],
        "uncertainties": [],
    }

    # Word sets used for the coverage rates computed after the test loop.
    homograph_tracking = {
        "test_expected_homographs": set(),
        "dscd_discovered_homographs": set(),
        "explained_homographs": set(),
        "homograph_explanations": defaultdict(list),
    }

    # NOTE(review): "dscd_failures" and "trg_failures" are initialised here and
    # summed into the error total below, but nothing in this function ever
    # increments them.
    error_tracking = {
        "translation_failures": 0,
        "dscd_failures": 0,
        "trg_failures": 0,
        "timeout_errors": 0,
        "oom_errors": 0,
        "other_errors": 0,
        "error_details": [],
        "per_test_status": [],
    }

    timing_metrics = {
        "total_time": 0.0,
        "per_test_times": [],
        "avg_test_time": 0.0,
    }

    # --- Check that the DSCD module carries a non-empty discovery log -------
    # Two attribute spellings are tried: "discoveredlog", then "discovered_log".
    discovery_validated = False
    try:
        _, dscd = _resolve_dscd(core_model)
        disc_log = getattr(dscd, "discoveredlog", None)
        if not disc_log:
            disc_log = getattr(dscd, "discovered_log", None)

        if dscd and isinstance(disc_log, list) and disc_log:
            discovery_validated = True
            last_discovery = disc_log[-1] if isinstance(disc_log[-1], dict) else {}
            # `discovered` / `candidates` are only used by the debug print below.
            discovered = last_discovery.get("discovered", 0)
            candidates = last_discovery.get("candidates", 0)
            if _DEBUG_DISCOVERY:
                print(f"[EVAL] Discovery log: {discovered}/{candidates} homographs")
        else:
            if _DEBUG_DISCOVERY:
                print(f"[EVAL] No discovery log found")
    except Exception as e:
        if _DEBUG_DISCOVERY:
            print(f"[EVAL] Discovery validation failed: {e}")

    # --- Optional ASBN statistics (either accessor name is accepted) --------
    asbn_stats: Dict[str, Any] = {}
    try:
        asbn = getattr(core_model, "asbn", None)
        if asbn and hasattr(asbn, "get_detailed_stats"):
            asbn_stats = asbn.get_detailed_stats()
        elif asbn and hasattr(asbn, "get_asbn_stats"):
            asbn_stats = asbn.get_asbn_stats()

        if asbn_stats and _DEBUG_DISCOVERY:
            print(f"[EVAL] ASBN: domain_acc={asbn_stats.get('domain_accuracy', 0):.2%}")
    except Exception as e:
        if _DEBUG_DISCOVERY:
            print(f"[EVAL] ASBN stats failed: {e}")

    # --- Optional TRG statistics --------------------------------------------
    trg_stats: Dict[str, Any] = {}
    try:
        trg = getattr(core_model, "trg_system", None)
        if trg and hasattr(trg, "get_statistics"):
            trg_stats = trg.get_statistics()
            if _DEBUG_DISCOVERY:
                print(f"[EVAL] TRG: {trg_stats.get('explanations_generated', 0)} total")
    except Exception as e:
        if _DEBUG_DISCOVERY:
            print(f"[EVAL] TRG stats failed: {e}")

    homograph_tracking["dscd_discovered_homographs"] = _get_dscd_homographs(core_model)
    print(f"[EVAL] DSCD discovered: {len(homograph_tracking['dscd_discovered_homographs'])} homographs")
    if homograph_tracking["dscd_discovered_homographs"] and _DEBUG_DISCOVERY:
        print(f"[EVAL] Sample: {list(homograph_tracking['dscd_discovered_homographs'])[:10]}")

    # --- Optional warmup: only when the prototype stores are empty ----------
    if run_warmup:
        try:
            _, dscd = _resolve_dscd(core_model)
            if dscd is not None:
                lock = _resolve_dscd_lock(dscd)

                # Read the store count under the DSCD lock when one is exposed.
                if lock:
                    with lock:
                        stores = _resolve_prototype_stores(dscd)
                        store_count = len(stores) if stores else 0
                else:
                    stores = _resolve_prototype_stores(dscd)
                    store_count = len(stores) if stores else 0

                if store_count == 0 and "dscd_discovery_warmup" in globals():
                    print("[EVAL] Running warmup (num_sents=4000)...")
                    try:
                        # Note: the original `model` (not `core_model`) is passed here.
                        dscd_discovery_warmup(model, tokenizer, num_sents=4000, batch_size=64)
                        homograph_tracking["dscd_discovered_homographs"] = _get_dscd_homographs(core_model)

                        if lock:
                            with lock:
                                stores_after = _resolve_prototype_stores(dscd)
                                store_count_after = len(stores_after) if stores_after else 0
                        else:
                            stores_after = _resolve_prototype_stores(dscd)
                            store_count_after = len(stores_after) if stores_after else 0

                        if store_count_after == 0:
                            print("[EVAL] WARNING: Warmup completed but DSCD stores still empty")
                        else:
                            print(f"[EVAL] Warmup completed: {store_count_after} stores created")
                    except Exception as e:
                        print(f"[EVAL] Warmup failed: {e}")
        except Exception:
            # Failures are silent unless discovery debugging is enabled.
            if _DEBUG_DISCOVERY:
                try:
                    traceback.print_exc()
                except Exception:
                    pass

    total_tests = len(test_sentences)
    successful_translations = 0
    total_explanations = 0
    total_high_span = 0
    total_real_ambiguous = 0

    print(f"\n[EVAL] Running {total_tests} tests...")
    print("-" * 80)

    # Best effort: mutates the caller's tokenizer and does not restore it.
    try:
        tokenizer.src_lang = _SOURCE_LANGUAGE
    except Exception:
        pass

    def _is_real_amb(expl: Dict[str, Any]) -> bool:
        # True when span or uncertainty exceeds its threshold; False on any
        # lookup/conversion error.
        try:
            s = float(expl.get("span", 0.0))
            u = float(expl.get("uncertainty", 0.0))
            return (s > _SPAN_THRESHOLD) or (u > _UNCERTAINTY_THRESHOLD)
        except Exception:
            return False

    def _compute_similarity(pred: str, expected: str) -> float:
        # Jaccard similarity over lower-cased, whitespace-split word sets.
        # Two empty texts count as identical (1.0); one empty text gives 0.0.
        try:
            pred_words = set(pred.lower().split())
            exp_words = set(expected.lower().split())
            if not pred_words and not exp_words:
                return 1.0
            if not pred_words or not exp_words:
                return 0.0
            overlap = len(pred_words & exp_words)
            union = len(pred_words | exp_words)
            return overlap / union if union > 0 else 0.0
        except Exception:
            return 0.0

    # Collect every expected homograph up front (lower-cased).
    for _, _, _, expected_homos in test_sentences:
        homograph_tracking["test_expected_homographs"].update([h.lower() for h in expected_homos])

    eval_start = time.time()

    # --- Per-sentence test loop ---------------------------------------------
    # NOTE(review): `expected_homos` is unpacked here but not used inside the
    # loop body; per-sentence expectations are never compared to explanations.
    for idx, (src_text, expected_translation, desc, expected_homos) in enumerate(test_sentences, 1):
        test_start = time.time()

        print(f"\nTest {idx}/{total_tests}: {desc}")
        print("=" * 60)

        test_status = {
            "test_id": idx,
            "success": False,
            "translation_ok": False,
            "explanations_count": 0,
            "error": None,
        }

        try:
            if "translate_with_explanations" not in globals():
                print("[EVAL] translate_with_explanations not available")
                error_tracking["other_errors"] += 1
                test_status["error"] = "function_not_available"
                error_tracking["per_test_status"].append(test_status)
                # This early exit records the status but no per-test timing.
                continue

            result = translate_with_explanations(
                core_model if core_model is not None else model,
                tokenizer,
                src_text,
                device=_DEVICE,
                span_threshold=_SPAN_THRESHOLD,
                uncertainty_threshold=_UNCERTAINTY_THRESHOLD,
                track_stats=True
            )

            translation = str(result.get("translation", "") or "")
            amb_count = int(result.get("ambiguous_words_detected", 0))
            explanations = result.get("explanations", []) or []

            # Similarity is printed only; it does not affect success counting.
            similarity = _compute_similarity(translation, expected_translation)

            print(f"Input: {src_text}")
            print(f"Expected: {expected_translation}")
            print(f"Translation: {translation}")
            print(f"Similarity: {similarity:.1%}")
            print(f"Ambiguous: {amb_count}")

            if explanations:
                print("\nExplanations:")
                high_span_local = 0
                real_amb_local = 0

                for j, expl in enumerate(explanations, 1):
                    span_val = float(expl.get("span", 0.0))
                    u_val = float(expl.get("uncertainty", 0.0))
                    # Missing confidence falls back to the larger of span/uncertainty.
                    conf_val = float(expl.get("confidence", max(span_val, u_val)))

                    marker = f"[S>{_SPAN_THRESHOLD:.2f}]" if span_val > _SPAN_THRESHOLD else "          "

                    # Two key spellings are accepted for the word and its position.
                    word = expl.get("ambiguous_word", expl.get("token", "N/A"))
                    pos = expl.get("position", expl.get("token_idx", "N/A"))

                    print(f"  {j}. {marker} '{word}' @ {pos}")
                    print(f"       conf={conf_val:.3f} | U={u_val:.3f} | S={span_val:.3f}")
                    text = str(expl.get("explanation", ""))
                    if len(text) > 120:
                        text = text[:120] + "..."
                    print(f"       {text}")

                    quality_metrics["confidences"].append(conf_val)
                    quality_metrics["spans"].append(span_val)
                    quality_metrics["uncertainties"].append(u_val)
                    quality_metrics["total_confidence"] = quality_metrics.get("total_confidence", 0.0) + conf_val
                    quality_metrics["confidence_samples"] += 1

                    # Confidence buckets: high >= 0.65, medium >= 0.4, otherwise low.
                    if conf_val >= 0.65:
                        quality_metrics["high_confidence_count"] += 1
                    elif conf_val >= 0.4:
                        quality_metrics["medium_confidence_count"] += 1
                    else:
                        quality_metrics["low_confidence_count"] += 1

                    if span_val > _SPAN_THRESHOLD:
                        high_span_local += 1
                    if _is_real_amb(expl):
                        real_amb_local += 1

                    # NOTE(review): when neither word key is present, the "N/A"
                    # placeholder is what gets cleaned and recorded here.
                    clean_word = _clean_token_for_set(word)
                    homograph_tracking["explained_homographs"].add(clean_word)
                    homograph_tracking["homograph_explanations"][clean_word].append({
                        "sentence": src_text,
                        "confidence": conf_val,
                        "span": span_val,
                        "uncertainty": u_val,
                    })

                total_explanations += len(explanations)
                total_high_span += high_span_local
                total_real_ambiguous += real_amb_local
                test_status["explanations_count"] = len(explanations)
            else:
                print("No explanations")

            # A translation succeeds when it is non-blank and is not one of
            # the sentinel failure strings listed below.
            if translation and translation.strip() and translation not in (
                "Error occurred",
                "Translation generation failed",
                "ERROR DURING TRANSLATION",
            ):
                successful_translations += 1
                test_status["translation_ok"] = True
                test_status["success"] = True
                print("Success")
            else:
                print("Translation failed")
                error_tracking["translation_failures"] += 1
                test_status["error"] = "translation_failed"

        except RuntimeError as e:
            # RuntimeErrors are classified by substring of their message.
            error_str = str(e).lower()
            if "out of memory" in error_str:
                print(f"[EVAL] OOM: {str(e)[:100]}")
                error_tracking["oom_errors"] += 1
                test_status["error"] = "oom"
            elif "timeout" in error_str:
                print(f"[EVAL] Timeout: {str(e)[:100]}")
                error_tracking["timeout_errors"] += 1
                test_status["error"] = "timeout"
            else:
                print(f"[EVAL] Runtime: {type(e).__name__}")
                error_tracking["other_errors"] += 1
                test_status["error"] = "runtime"
            error_tracking["error_details"].append(f"Test {idx}: {type(e).__name__}")
        except Exception as e:
            print(f"[EVAL] Error: {type(e).__name__}")
            error_tracking["other_errors"] += 1
            test_status["error"] = type(e).__name__
            error_tracking["error_details"].append(f"Test {idx}: {type(e).__name__}")
            if _DEBUG_DISCOVERY:
                try:
                    traceback.print_exc()
                except Exception:
                    pass

        error_tracking["per_test_status"].append(test_status)

        test_time = time.time() - test_start
        timing_metrics["per_test_times"].append(test_time)

        print("-" * 60)

    # --- Aggregate timing and explanation-quality statistics ----------------
    timing_metrics["total_time"] = time.time() - eval_start
    if timing_metrics["per_test_times"]:
        timing_metrics["avg_test_time"] = (
            sum(timing_metrics["per_test_times"]) / len(timing_metrics["per_test_times"])
        )

    if quality_metrics["confidence_samples"] > 0:
        quality_metrics["avg_confidence"] = (
            quality_metrics["total_confidence"] / quality_metrics["confidence_samples"]
        )
        quality_metrics["avg_span"] = (
            sum(quality_metrics["spans"]) / len(quality_metrics["spans"])
            if quality_metrics["spans"]
            else 0.0
        )
        quality_metrics["avg_uncertainty"] = (
            sum(quality_metrics["uncertainties"]) / len(quality_metrics["uncertainties"])
            if quality_metrics["uncertainties"]
            else 0.0
        )

        # Quartiles are picked by index from the sorted list (no interpolation).
        if quality_metrics["confidences"]:
            sorted_conf = sorted(quality_metrics["confidences"])
            quality_metrics["confidence_p25"] = sorted_conf[len(sorted_conf) // 4]
            quality_metrics["confidence_p50"] = sorted_conf[len(sorted_conf) // 2]
            quality_metrics["confidence_p75"] = sorted_conf[3 * len(sorted_conf) // 4]
    else:
        quality_metrics["avg_confidence"] = 0.0
        quality_metrics["avg_span"] = 0.0
        quality_metrics["avg_uncertainty"] = 0.0

    # --- Homograph coverage rates -------------------------------------------
    explained_from_dscd = homograph_tracking["explained_homographs"].intersection(
        homograph_tracking["dscd_discovered_homographs"]
    )

    test_expected_discovered = homograph_tracking["test_expected_homographs"].intersection(
        homograph_tracking["dscd_discovered_homographs"]
    )

    reference_discovered = _HOMOGRAPH_REFERENCE_LIST.intersection(
        homograph_tracking["dscd_discovered_homographs"]
    )

    # NOTE(review): the denominator is every DSCD-discovered word, while the
    # numerator can only come from words explained in the 13 test sentences.
    homograph_tracking["explained_from_dscd_rate"] = (
        len(explained_from_dscd) / len(homograph_tracking["dscd_discovered_homographs"])
        if homograph_tracking["dscd_discovered_homographs"]
        else 0.0
    )
    homograph_tracking["test_expected_discovery_rate"] = (
        len(test_expected_discovered) / len(homograph_tracking["test_expected_homographs"])
        if homograph_tracking["test_expected_homographs"]
        else 0.0
    )
    homograph_tracking["reference_discovery_rate"] = (
        len(reference_discovered) / len(_HOMOGRAPH_REFERENCE_LIST)
        if _HOMOGRAPH_REFERENCE_LIST
        else 0.0
    )

    # --- DSCD prototype-store statistics ------------------------------------
    try:
        dscd_stats = {"total_words": 0, "multi_sense_words": 0, "total_prototypes": 0}
        _, dscd = _resolve_dscd(core_model)
        if dscd is not None:
            lock = _resolve_dscd_lock(dscd)
            # Take a shallow copy of the store mapping (under the lock if any)
            # and iterate over the copy.
            if lock:
                with lock:
                    stores = dict(_resolve_prototype_stores(dscd))
            else:
                stores = dict(_resolve_prototype_stores(dscd))

            total_words = 0
            multi = 0
            total_protos = 0
            for _, store in stores.items():
                sz = 0
                try:
                    sz = _store_size(store)
                except Exception:
                    sz = 0
                total_words += 1
                total_protos += sz
                # A store of size >= 2 is counted as a multi-sense word.
                if sz >= 2:
                    multi += 1
            dscd_stats = {
                "total_words": total_words,
                "multi_sense_words": multi,
                "total_prototypes": total_protos,
            }
    except Exception as e:
        if _DEBUG_DISCOVERY:
            print(f"[EVAL] DSCD stats failed: {e}")
        dscd_stats = {"total_words": 0, "multi_sense_words": 0, "total_prototypes": 0}

    # --- Printed summary ----------------------------------------------------
    print("\n" + "=" * 80)
    print("COMPREHENSIVE EVALUATION SUMMARY")
    print("=" * 80)

    print(f"\n[TRANSLATION QUALITY]")
    print(f"  Total tests: {total_tests}")
    print(f"  Successful: {successful_translations}")
    # NOTE(review): unguarded division, unlike the guarded ones further down;
    # safe only because `test_sentences` is a non-empty literal.
    print(f"  Success rate: {successful_translations / total_tests * 100:.1f}%")

    print(f"\n[AMBIGUITY DETECTION]")
    print(f"  Total explanations: {total_explanations}")
    print(f"  High-span (S>{_SPAN_THRESHOLD}): {total_high_span}")
    print(f"  Real ambiguous: {total_real_ambiguous}")
    if total_tests > 0:
        print(f"  Avg explanations/test: {total_explanations / total_tests:.2f}")

    print(f"\n[EXPLANATION QUALITY]")
    print(f"  Avg confidence: {quality_metrics['avg_confidence']:.3f}")
    print(f"  Avg span: {quality_metrics['avg_span']:.3f}")
    print(f"  Avg uncertainty: {quality_metrics['avg_uncertainty']:.3f}")

    if "confidence_p50" in quality_metrics:
        print(
            f"  Confidence P25/P50/P75: "
            f"{quality_metrics.get('confidence_p25', 0):.3f} / "
            f"{quality_metrics.get('confidence_p50', 0):.3f} / "
            f"{quality_metrics.get('confidence_p75', 0):.3f}"
        )

    print(f"  High (>=0.65): {quality_metrics['high_confidence_count']}")
    print(f"  Medium (0.4-0.65): {quality_metrics['medium_confidence_count']}")
    print(f"  Low (<0.4): {quality_metrics['low_confidence_count']}")

    print(f"\n[HOMOGRAPH DISCOVERY]")
    print(f"  DSCD discovered: {len(homograph_tracking['dscd_discovered_homographs'])}")
    print(f"  Explained: {len(homograph_tracking['explained_homographs'])}")
    print(f"  Explanation rate: {homograph_tracking['explained_from_dscd_rate']:.1%}")
    print(f"  Test discovery rate: {homograph_tracking['test_expected_discovery_rate']:.1%}")

    if homograph_tracking["explained_homographs"]:
        # The ten shown are the first ten in sorted (alphabetical) order.
        # [D] = also in the DSCD-discovered set, [R] = also in the reference list.
        print(f"\n  Explained homographs (top 10):")
        for homo in sorted(homograph_tracking["explained_homographs"])[:10]:
            exps = homograph_tracking["homograph_explanations"].get(homo, [])
            count = len(exps)
            avg_conf = sum(e["confidence"] for e in exps) / len(exps) if exps else 0.0
            in_dscd = "[D]" if homo in homograph_tracking["dscd_discovered_homographs"] else "   "
            in_ref = "[R]" if homo in _HOMOGRAPH_REFERENCE_LIST else "   "
            print(f"    {in_dscd} {in_ref} '{homo}': {count} x conf={avg_conf:.3f}")

    print(f"\n[REFERENCE COMPARISON]")
    print(f"  Reference: {len(_HOMOGRAPH_REFERENCE_LIST)} words")
    print(f"  Discovered: {len(reference_discovered)}/{len(_HOMOGRAPH_REFERENCE_LIST)}")
    print(f"  Coverage: {homograph_tracking['reference_discovery_rate']:.1%}")

    print(f"\n[DSCD PROTOTYPES]")
    print(f"  Word types: {dscd_stats['total_words']}")
    print(f"  Multi-sense: {dscd_stats['multi_sense_words']}")
    print(f"  Total prototypes: {dscd_stats['total_prototypes']}")
    if dscd_stats["total_words"] > 0:
        print(
            f"  Multi-sense ratio: "
            f"{dscd_stats['multi_sense_words'] / dscd_stats['total_words']:.1%}"
        )

    if asbn_stats:
        print(f"\n[ASBN]")
        print(f"  Domain accuracy: {asbn_stats.get('domain_accuracy', 0):.2%}")
        # NOTE(review): "target_accuracy" is indexed directly once
        # "source_accuracy" is present; a dict with only the former key raises.
        if "source_accuracy" in asbn_stats:
            print(f"  Source accuracy: {asbn_stats['source_accuracy']:.2%}")
            print(f"  Target accuracy: {asbn_stats['target_accuracy']:.2%}")

    if trg_stats:
        print(f"\n[TRG]")
        print(f"  Total explanations: {trg_stats.get('explanations_generated', 0)}")
        print(f"  High confidence: {trg_stats.get('high_confidence_rate', 0):.1%}")

    print(f"\n[PERFORMANCE]")
    print(f"  Total time: {timing_metrics['total_time']:.2f}s")
    print(f"  Avg time/test: {timing_metrics['avg_test_time']:.2f}s")

    total_errors = sum([
        error_tracking["translation_failures"],
        error_tracking["dscd_failures"],
        error_tracking["trg_failures"],
        error_tracking["timeout_errors"],
        error_tracking["oom_errors"],
        error_tracking["other_errors"],
    ])

    # NOTE(review): the breakdown below omits the timeout count, so the printed
    # lines need not add up to "Total".
    if total_errors > 0:
        print(f"\n[ERRORS]")
        print(f"  Total: {total_errors}")
        print(f"  Translation: {error_tracking['translation_failures']}")
        print(f"  OOM: {error_tracking['oom_errors']}")
        print(f"  Other: {error_tracking['other_errors']}")

    # --- Optional comparison against a previous run -------------------------
    if compare_baseline and baseline_metrics and isinstance(baseline_metrics, dict):
        print(f"\n[BASELINE COMPARISON]")
        try:
            baseline_success = float(baseline_metrics.get("success_rate_pct", 0))
            current_success = (successful_translations / total_tests * 100.0) if total_tests > 0 else 0.0
            success_delta = current_success - baseline_success

            baseline_expl = int(baseline_metrics.get("total_explanations", 0))
            expl_delta = total_explanations - baseline_expl

            baseline_quality_dict = baseline_metrics.get("quality_metrics", {})
            baseline_quality = float(baseline_quality_dict.get("avg_confidence", 0)) if isinstance(baseline_quality_dict, dict) else 0.0
            quality_delta = quality_metrics["avg_confidence"] - baseline_quality

            print(f"  Translation: {current_success:.1f}% ({success_delta:+.1f}%)")
            print(f"  Explanations: {total_explanations} ({expl_delta:+d})")
            print(
                f"  Confidence: {quality_metrics['avg_confidence']:.3f} "
                f"({quality_delta:+.3f})"
            )

            if "homograph_tracking" in baseline_metrics and isinstance(baseline_metrics["homograph_tracking"], dict):
                baseline_homo_dict = baseline_metrics["homograph_tracking"]
                baseline_homo_rate = float(baseline_homo_dict.get("explained_from_dscd_rate", 0))

                homo_delta = (homograph_tracking["explained_from_dscd_rate"] - baseline_homo_rate)
                print(
                    f"  Explanation rate: "
                    f"{homograph_tracking['explained_from_dscd_rate']:.1%} "
                    f"({homo_delta:+.1%})"
                )
        except Exception as e:
            print(f"  Comparison failed: {e}")

    # --- Health warnings (printed only; not part of the return value) -------
    # NOTE(review): this local list shadows the `warnings` module imported at
    # the top of the notebook for the whole body of this function.
    warnings = []
    if successful_translations < total_tests * 0.5:
        warnings.append("High translation failure (>50%)")
    if total_explanations == 0:
        warnings.append("No explanations generated")
    # NOTE(review): the condition counts word types (stores), while the
    # message talks about prototypes -- confirm which one is intended.
    if dscd_stats["total_words"] < 100:
        warnings.append("Very few prototypes (<100)")
    if quality_metrics["low_confidence_count"] > quality_metrics["high_confidence_count"]:
        warnings.append("More low than high confidence")
    if homograph_tracking["explained_from_dscd_rate"] < 0.3:
        warnings.append("Low explanation rate (<30%)")
    if not discovery_validated:
        warnings.append("Discovery log missing")
    if asbn_stats and asbn_stats.get("domain_accuracy", 0) < 0.5:
        warnings.append("ASBN domain accuracy <50%")

    if warnings:
        print(f"\n[WARNINGS]")
        for w in warnings:
            print(f"  - {w}")
    else:
        print(f"\n[HEALTH] All systems nominal")

    print("=" * 80)

    # Best-effort memory cleanup; the collector only runs when automatic GC is enabled.
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        if gc.isenabled():
            gc.collect()
    except Exception:
        pass

    return {
        "total_tests": total_tests,
        "successful_translations": successful_translations,
        "success_rate_pct": (successful_translations / total_tests * 100.0) if total_tests > 0 else 0.0,
        "total_explanations": total_explanations,
        "total_high_span": total_high_span,
        "total_real_ambiguous": total_real_ambiguous,
        "dscd_stats": dscd_stats,
        "quality_metrics": quality_metrics,
        "homograph_tracking": homograph_tracking,
        "error_tracking": error_tracking,
        "asbn_stats": asbn_stats,
        "trg_stats": trg_stats,
        "discovery_validated": discovery_validated,
        "timing_metrics": timing_metrics,
    }

def test_evaluation_pipeline(model, tokenizer) -> bool:
    print("\n" + "=" * 60)
    print("[TEST] Testing evaluation pipeline")
    print("=" * 60)

    try:
        result = comprehensive_post_training_testing(
            model,
            tokenizer,
            run_warmup=False,
            compare_baseline=False
        )

        assert "total_tests" in result
        assert "quality_metrics" in result
        assert "homograph_tracking" in result

        print("Evaluation pipeline test passed")
        print("=" * 60 + "\n")
        return True

    except Exception as e:
        print(f"Evaluation pipeline test failed: {e}")
        try:
            traceback.print_exc()
        except Exception:
            pass
        print("=" * 60 + "\n")
        return False

# Cell 9 readiness banner: lists what this cell provides and which optional
# metric libraries were detected. The COMET metric is shipped in the
# `unbabel-comet` package (as the install hint below already says); `comet-ml`
# is an unrelated experiment-tracking package, so the banner names the former.
print("\n" + "=" * 80)
print("Cell 9: Testing & evaluation ready - COMPLETE FIXED VERSION")
print("=" * 80)
print("KEY FIXES:")
print("  ✓ Added batch_translate_with_explanations() for large test sets")
print("  ✓ Added compute_bleu_score() with sacrebleu")
print("  ✓ Added compute_comet_score() with unbabel-comet")
print("  ✓ Added evaluate_on_test_set() for dataset evaluation")
print("  ✓ Fixed _compute_similarity() to proper Jaccard similarity")
print("  ✓ Fixed typo in test sentence (বন্ধেছি → বন্ধ করেছি)")
print("  ✓ Changed track_stats=True for Cell 8 integration")
print("  ✓ Import detection for sacrebleu and unbabel-comet")
print()
print("Available functions:")
print("  - comprehensive_post_training_testing() - 13 curated tests")
print("  - batch_translate_with_explanations() - Batch inference")
print("  - evaluate_on_test_set() - Full test set with BLEU/COMET")
print("  - compute_bleu_score() - BLEU calculation")
print("  - compute_comet_score() - COMET calculation")
print()
print("Evaluation metrics:")
print(f"  - BLEU: {'Available' if _SACREBLEU_AVAILABLE else 'Not installed (pip install sacrebleu)'}")
print(f"  - COMET: {'Available' if _COMET_AVAILABLE else 'Not installed (pip install unbabel-comet)'}")
print(f"  - Span threshold: {_SPAN_THRESHOLD}")
print(f"  - Uncertainty threshold: {_UNCERTAINTY_THRESHOLD}")
print(f"  - Reference list: {len(_HOMOGRAPH_REFERENCE_LIST)} words")
print("=" * 80 + "\n")



Cell 9: Testing & evaluation ready - COMPLETE FIXED VERSION
KEY FIXES:
  ✓ Added batch_translate_with_explanations() for large test sets
  ✓ Added compute_bleu_score() with sacrebleu
  ✓ Added compute_comet_score() with comet-ml
  ✓ Added evaluate_on_test_set() for dataset evaluation
  ✓ Fixed _compute_similarity() to proper Jaccard similarity
  ✓ Fixed typo in test sentence (বন্ধেছি → বন্ধ করেছি)
  ✓ Changed track_stats=True for Cell 8 integration
  ✓ Import detection for sacrebleu and comet-ml

Available functions:
  - comprehensive_post_training_testing() - 13 curated tests
  - batch_translate_with_explanations() - Batch inference
  - evaluate_on_test_set() - Full test set with BLEU/COMET
  - compute_bleu_score() - BLEU calculation
  - compute_comet_score() - COMET calculation

Evaluation metrics:
  - BLEU: Available
  - COMET: Not installed (pip install unbabel-comet)
  - Span threshold: 0.2
  - Uncertainty threshold: 0.15
  - Reference list: 65 words



In [13]:
# ==============================================================================
# CELL 10: TATN MAIN PIPELINE - COMPLETE FIXED VERSION
# ==============================================================================

import os
import time
import traceback
from typing import Tuple, Optional, Dict, Any
import gc

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

def g(name, default):
    """Return the module-level global called ``name``, or ``default`` if it is not defined."""
    namespace = globals()
    if name in namespace:
        return namespace[name]
    return default

# ------------------------------------------------------------------------------
# Pipeline configuration. Every value is taken from an already-defined notebook
# global of the same name when one exists (via `g`), otherwise from the default
# written here, and is coerced to the expected type.
# ------------------------------------------------------------------------------

# Hardware
USE_MULTI_GPU = bool(g("USE_MULTI_GPU", False))
NUM_GPUS = int(
    g("NUM_GPUS", torch.cuda.device_count() if torch.cuda.is_available() else 0)
)
DEVICE = g(
    "DEVICE",
    torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)
USE_AMP = bool(g("USE_AMP", True))

# Language pair
SOURCE_LANGUAGE = str(g("SOURCE_LANGUAGE", "bn"))
TARGET_LANGUAGE = str(g("TARGET_LANGUAGE", "en"))

# Data and batching
NUM_SAMPLES = int(g("NUM_SAMPLES", 30000))
MAX_LENGTH = int(g("MAX_LENGTH", 48))
BATCH_SIZE = int(g("BATCH_SIZE", 100))
EPOCHS = int(g("EPOCHS", 1))
ACCUMULATION_STEPS = int(g("ACCUMULATION_STEPS", 16))

# Optimisation
LR_NMT = float(g("LR_NMT", 2e-5))
LR_PHI = float(g("LR_PHI", 1e-5))
LR_TRG = float(g("LR_TRG", 1e-5))
GRAD_CLIP_NORM = float(g("GRAD_CLIP_NORM", 1.0))
FREEZE_ENCODER = bool(g("FREEZE_ENCODER", False))

# Component toggles
ENABLE_ASBN_TRAINING = bool(g("ENABLE_ASBN_TRAINING", True))
ENABLE_TRG_TRAINING = bool(g("ENABLE_TRG_TRAINING", True))

# Validation / discovery cadence
VALIDATION_CHECK_INTERVAL = int(g("VALIDATION_CHECK_INTERVAL", 200))
PERIODIC_DISCOVERY_FREQUENCY = int(g("PERIODIC_DISCOVERY_FREQUENCY", 200))
DSCD_WARMUP_SAMPLES = int(g("DSCD_WARMUP_SAMPLES", 8000))

# Homograph reference words (both names refer to the same set object)
HOMOGRAPH_REFERENCE_LIST_BN = set(
    g("HOMOGRAPH_REFERENCE_LIST_BN", {"কল", "কাল", "পাতা"})
)
HOMOGRAPH_REFERENCE_LIST = HOMOGRAPH_REFERENCE_LIST_BN

# Ambiguity thresholds
SPAN_THRESHOLD = float(g("SPAN_THRESHOLD", 0.20))
# NOTE(review): the uncertainty threshold is read from the global `TAU_LOW`,
# not from a global named `UNCERTAINTY_THRESHOLD` -- confirm this is intended.
UNCERTAINTY_THRESHOLD = float(g("TAU_LOW", 0.15))

# Logging
DEBUG_TIMING = bool(g("DEBUG_TIMING", True))
VERBOSE_LOGGING = bool(g("VERBOSE_LOGGING", False))

# Checkpointing (the directory is a fixed path, not overridable through `g`)
CHECKPOINT_SAVE_AFTER_TRAINING = bool(g("CHECKPOINT_SAVE_AFTER_TRAINING", True))
CHECKPOINT_DIR = "/kaggle/working/"
CHECKPOINT_PATH = os.path.join(CHECKPOINT_DIR, "tatn_final.pt")

def safe_clear_gpu_caches():
    """Best-effort release of unreachable Python objects and cached CUDA memory.

    Runs a full garbage collection and then empties the CUDA caching
    allocator on every visible device. Any failure is swallowed so that the
    helper can be called from cleanup paths; it returns ``None``.
    """
    try:
        # Collect BEFORE emptying the cache: tensors kept alive only by
        # reference cycles are not handed back to the caching allocator until
        # the collector has run, so emptying first would leave their blocks
        # reserved. `gc.collect()` is effective even when automatic collection
        # has been disabled, so no `gc.isenabled()` guard is needed.
        gc.collect()
    except Exception:
        pass

    try:
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                try:
                    with torch.cuda.device(i):
                        torch.cuda.empty_cache()
                except Exception:
                    # One unusable device must not stop the others being cleared.
                    pass
    except Exception:
        pass

def safe_tokenizer_from_pretrained(model_name: str, local_files_only: bool = False):
    """Load an ``M2M100Tokenizer`` and check that it exposes the expected methods.

    Args:
        model_name: Identifier or path handed to ``M2M100Tokenizer.from_pretrained``.
        local_files_only: Forwarded to ``from_pretrained``.

    Returns:
        The loaded tokenizer.

    Raises:
        RuntimeError: If the tokenizer lacks ``encode``, ``decode``,
            ``convert_ids_to_tokens`` or ``__call__`` (the first one missing
            is named).
        Exception: Any error from the import or from ``from_pretrained`` is
            logged to stdout and re-raised unchanged.
    """
    try:
        from transformers import M2M100Tokenizer

        tokenizer = M2M100Tokenizer.from_pretrained(
            model_name, local_files_only=local_files_only
        )

        # Lazily locate the first required attribute that is absent, if any.
        first_absent = next(
            (
                attr
                for attr in ("encode", "decode", "convert_ids_to_tokens", "__call__")
                if not hasattr(tokenizer, attr)
            ),
            None,
        )
        if first_absent is not None:
            raise RuntimeError(f"Tokenizer missing {first_absent}")

        return tokenizer
    except Exception as exc:
        print(f"[TOKENIZER] Load failed: {exc}")
        raise

def resolve_dscd_stores(dscd):
    """Return the prototype-store mapping of a DSCD object.

    ``prototype_stores`` is tried first, then ``_prototype_stores``. A value
    that already is a ``dict`` is returned as-is (no copy); any other non-None
    value is converted with ``dict(...)``, and a failed conversion moves on to
    the next attribute.

    Args:
        dscd: Object to inspect, or ``None``.

    Returns:
        The resolved mapping, or an empty dict when nothing usable is found.
    """
    if dscd is None:
        return {}

    # Generator keeps the lookups lazy: the second attribute is only read
    # when the first one did not produce a result.
    candidates = (
        getattr(dscd, attr, None)
        for attr in ("prototype_stores", "_prototype_stores")
    )
    for found in candidates:
        if found is None:
            continue
        if isinstance(found, dict):
            return found
        try:
            return dict(found)
        except Exception:
            continue

    return {}

def resolve_dscd_lock(dscd):
    """Return the first non-None lock attribute found on a DSCD object.

    Attributes are tried in this order: ``buffer_lock``, ``_buffer_lock``,
    ``clustering_lock``, ``_clustering_lock``.

    Args:
        dscd: Object to inspect, or ``None``.

    Returns:
        The attribute value, or ``None`` when ``dscd`` is ``None`` or none of
        the attributes holds a non-None value.
    """
    if dscd is None:
        return None

    lock_attrs = ("buffer_lock", "_buffer_lock", "clustering_lock", "_clustering_lock")
    # Lazy generator: stop reading attributes as soon as one is set.
    values = (getattr(dscd, attr, None) for attr in lock_attrs)
    return next((value for value in values if value is not None), None)

def initialize_environment():
    """Print the available compute devices and clear GPU caches.

    With CUDA available, one line is printed per GPU (name and total memory in
    GB, or "Unknown" if the query fails) and ``safe_clear_gpu_caches`` is
    called. Without CUDA, only a "CPU only" notice is printed.

    Returns:
        Always ``True``.
    """
    print("[PIPELINE] Initializing environment...")

    if not torch.cuda.is_available():
        print("[PIPELINE] CPU only")
        return True

    device_total = torch.cuda.device_count()
    print(f"[PIPELINE] GPUs: {device_total}")
    for idx in range(device_total):
        try:
            label = torch.cuda.get_device_name(idx)
            total_gb = torch.cuda.get_device_properties(idx).total_memory / (1024**3)
            print(f"  GPU {idx}: {label} ({total_gb:.1f} GB)")
        except Exception:
            print(f"  GPU {idx}: Unknown")

    safe_clear_gpu_caches()
    return True

def warmup_model(
    model,
    tokenizer,
    num_samples: int = 100,
    batch_size: int = 16,
    max_length: int = 48,
) -> bool:
    """Run one dummy forward pass through the model as a warmup.

    A single batch of random token ids is pushed through ``core.forward`` with
    ``use_dscd=True``, ``use_asbn=False`` and ``fast_inference=True`` under
    ``torch.no_grad()``. The model is switched to eval mode for the pass; its
    previous training mode is restored and the CUDA cache is released
    afterwards, whether or not the pass succeeded.

    Args:
        model: The model, optionally wrapped so that the real module is
            reachable as ``model.module``.
        tokenizer: Only ``tokenizer.vocab_size`` is read (50000 is used when
            the attribute is absent) to bound the random token ids.
        num_samples: Unused by this function; kept for interface compatibility.
        batch_size: Number of rows in the dummy batch.
        max_length: Sequence length of the dummy batch.

    Returns:
        ``True`` if the forward pass completed, ``False`` on any exception.
    """
    print("=" * 80)
    print("[WARMUP] Starting fast model warmup...")
    print("=" * 80)

    core = None
    was_training = False

    try:
        core = model.module if hasattr(model, "module") else model
        was_training = getattr(core, "training", False)
        core.eval()

        start_time = time.time()

        with torch.no_grad():
            # Tensors are created directly on the target device.
            dummy_input_ids = torch.randint(
                0,
                getattr(tokenizer, "vocab_size", 50000),
                (batch_size, max_length),
                dtype=torch.long,
                device=DEVICE,
            )

            dummy_attention_mask = torch.ones(
                (batch_size, max_length),
                dtype=torch.long,
                device=DEVICE,
            )

            print(f"[WARMUP] Running fast inference pass (batch={batch_size}, seqlen={max_length})...")

            core.forward(
                input_ids=dummy_input_ids,
                attention_mask=dummy_attention_mask,
                use_dscd=True,
                use_asbn=False,
                fast_inference=True,
            )

        elapsed = time.time() - start_time
        print(f"[WARMUP] Completed in {elapsed:.2f}s")

        print("[WARMUP] Model warmup successful")
        print("=" * 80)
        return True

    except Exception as e:
        print(f"[WARMUP] Failed: {e}")
        if VERBOSE_LOGGING:
            # Imported here so the error handler cannot itself fail with a
            # NameError if the cell did not import traceback at module level.
            import traceback
            traceback.print_exc()
        return False

    finally:
        # Restore state on the failure path too: a failed warmup must not
        # leave the model stuck in eval mode right before training starts.
        if core is not None and was_training:
            try:
                core.train()
            except Exception:
                pass
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
        except Exception:
            pass

def dscd_discovery_warmup(
    model,
    tokenizer,
    num_sents: int = 4000,
    batch_size: int = 64,
    max_len: int = 48,
    timeout_per_batch: float = 30.0,
) -> bool:
    """Push real sentences through the model with DSCD enabled, before training.

    Sentence pairs come from the notebook-level ``load_and_preprocess_optimized``
    and are wrapped in ``MemoryEfficientDataset`` (both looked up in
    ``globals()``). Each batch is run through ``core.forward`` with
    ``use_dscd=True``, ``use_asbn=False``, ``fast_inference=False`` under
    ``torch.no_grad()``. While batches run, the model is in eval mode and
    ``dscd.enable_training_clustering`` (when present) is forced to ``False``;
    both are restored afterwards, including when iteration fails.

    Args:
        model: The model, optionally wrapped (real module at ``model.module``).
        tokenizer: Passed to ``MemoryEfficientDataset``.
        num_sents: Number of sentences requested from the loader; also the
            denominator of the progress display.
        batch_size: DataLoader batch size.
        max_len: ``max_length`` passed to the dataset.
        timeout_per_batch: If a batch takes longer than this many seconds the
            loop stops *after* that batch (running batches are not interrupted).

    Returns:
        ``True`` when the DSCD prototype stores contain at least one token
        (or, if the model exposes no ``prototype_stores``, when at least one
        sentence was processed); ``False`` otherwise or on a critical failure.
    """
    print("=" * 80)
    print("[WARMUP] Starting DSCD discovery warmup")
    print("=" * 80)

    try:
        if "load_and_preprocess_optimized" in globals():
            pairs = globals()["load_and_preprocess_optimized"](num_sents)
        else:
            raise RuntimeError("load_and_preprocess_optimized not found")

        dataset = globals()["MemoryEfficientDataset"](pairs, tokenizer, max_length=max_len)

        collate_fn = globals().get("safe_collate", None)

        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            collate_fn=collate_fn,
        )

        print(f"[WARMUP] Processing {num_sents} sentences (batch={batch_size})...")

        core = model.module if hasattr(model, "module") else model
        was_training = getattr(core, "training", False)
        dscd = getattr(core, "dscd", None)

        # Explicit "did we override it" marker so that a saved value of None
        # is still restored correctly.
        clustering_overridden = False
        original_clustering_flag = None

        processed = 0
        skipped = 0

        try:
            core.eval()

            if dscd is not None and hasattr(dscd, "enable_training_clustering"):
                original_clustering_flag = dscd.enable_training_clustering
                dscd.enable_training_clustering = False
                clustering_overridden = True
                print("[WARMUP] Clustering DISABLED during warmup")

            start_time = time.time()
            last_print_time = start_time

            with torch.no_grad():
                for batch_idx, batch in enumerate(dataloader):
                    if batch is None:
                        continue

                    batch_start = time.time()

                    try:
                        input_ids = batch["input_ids"].to(DEVICE)
                        attention_mask = batch["attention_mask"].to(DEVICE)
                        src_texts = batch.get("src_texts", None)

                        core.forward(
                            input_ids=input_ids,
                            attention_mask=attention_mask,
                            src_texts=src_texts,
                            labels=None,
                            use_dscd=True,
                            use_asbn=False,
                            fast_inference=False,
                        )

                        processed += int(input_ids.size(0))

                        batch_elapsed = time.time() - batch_start
                        if batch_elapsed > timeout_per_batch:
                            print(f"\n[WARMUP] Batch {batch_idx} timeout ({batch_elapsed:.1f}s) - stopping warmup")
                            break

                        current_time = time.time()
                        if (current_time - last_print_time >= 5.0) or (processed % (batch_size * 5) == 0):
                            elapsed = current_time - start_time
                            rate = processed / elapsed if elapsed > 0 else 0
                            eta = (num_sents - processed) / rate if rate > 0 else 0
                            # Guarded so num_sents == 0 cannot turn a good
                            # batch into a counted failure via ZeroDivisionError.
                            progress = (processed / num_sents) * 100.0 if num_sents > 0 else 0.0
                            print(
                                f"\r[WARMUP] {processed}/{num_sents} ({progress:.1f}%) | "
                                f"{rate:.1f} sent/s | ETA {eta:.0f}s",
                                end="",
                            )
                            last_print_time = current_time

                    except KeyboardInterrupt:
                        print("\n[WARMUP] Interrupted by user")
                        break

                    except Exception as e:
                        # Counted by the nominal batch size: the real batch
                        # may not be readable when the failure happened early.
                        skipped += batch_size
                        if VERBOSE_LOGGING or skipped > 100:
                            print(f"\n[WARMUP] Batch {batch_idx} failed: {e}")

                        if skipped > 500:
                            print(f"\n[WARMUP] Too many failures ({skipped}) - aborting")
                            break

                        continue

            # Terminate the "\r"-style progress line.
            print()

        finally:
            # Runs on failure as well: errors raised while the DataLoader
            # fetches or collates a batch escape the per-batch handler above,
            # and must not leave clustering disabled or the model in eval mode.
            if dscd is not None and clustering_overridden:
                dscd.enable_training_clustering = original_clustering_flag
                print("[WARMUP] Clustering flag restored")

            if was_training:
                core.train()

        print("[WARMUP] Skipping discovery check to avoid stalls")

        if dscd is not None and hasattr(dscd, "prototype_stores"):
            prototype_stores = resolve_dscd_stores(dscd)
            lock = resolve_dscd_lock(dscd)

            # Snapshot the mapping (under the DSCD lock when one exists) so
            # the statistics below are computed on a stable copy.
            if lock is not None:
                with lock:
                    stores = dict(prototype_stores)
            else:
                stores = dict(prototype_stores)

            def store_size(s):
                # `size` may be a method or a plain attribute; anything
                # unreadable counts as 0.
                try:
                    if callable(getattr(s, "size", None)):
                        return int(s.size())
                    return int(getattr(s, "size", 0))
                except Exception:
                    return 0

            sizes = [store_size(s) for s in stores.values()]
            num_tokens = len(stores)
            total_protos = sum(sizes)
            multi_sense = sum(1 for size in sizes if size >= 2)

            print(f"[WARMUP] Complete - Processed {processed}/{num_sents}")
            print(f"  - Tokens: {num_tokens}")
            print(f"  - Prototypes: {total_protos}")
            print(f"  - Multi-sense: {multi_sense}")
            print(f"  - Skipped: {skipped}")

            return num_tokens > 0

        print(f"[WARMUP] Completed {processed}/{num_sents} sentences")
        return processed > 0

    except Exception as e:
        print(f"[WARMUP] Critical failure: {e}")
        if VERBOSE_LOGGING:
            # Local import: keeps the handler safe even if traceback was not
            # imported at the top of this cell.
            import traceback
            traceback.print_exc()
        return False

    finally:
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
        except Exception:
            pass

def main_pipeline() -> Tuple[object, object]:
    """Run the complete TATN workflow and return the model and tokenizer.

    Phases, in order:
      1.   Load the ``facebook/m2m100_418M`` tokenizer.
      2.   Load sentence pairs and build the training DataLoader.
      3.   Build the model (optionally wrapped in ``nn.DataParallel``), adjust
           the embedding size and optionally freeze the encoder.
      3.5. Run ``warmup_model``.
      4.   Create the optimizers and call ``train_memory_efficient_tatn``.
      5.   Save a checkpoint to ``CHECKPOINT_PATH`` (if enabled).
      6.   Run ``comprehensive_post_training_testing`` (if defined).

    All collaborators are looked up in ``globals()``, so the earlier notebook
    cells must have been executed.

    Returns:
        ``(trained_model, tokenizer)``. If training raises, the error is
        logged and the model in its current state is returned instead.

    Raises:
        RuntimeError: If a required definition is missing from ``globals()``
            or the dataset cannot be created.
        Exception: Tokenizer loading and optimizer creation errors propagate.
    """
    # Local import: the handlers below call traceback.print_exc(), and this
    # keeps them working even if the cell did not import traceback up front.
    import traceback

    banner = "=" * 80
    print(banner)
    print("TATN MAIN PIPELINE - COMPLETE INTEGRATION")
    print(banner)
    print("Configuration:")
    print(f"  - Span threshold: {SPAN_THRESHOLD}")
    print(f"  - Uncertainty threshold: {UNCERTAINTY_THRESHOLD}")
    print(f"  - Discovery frequency: {PERIODIC_DISCOVERY_FREQUENCY}")
    print(f"  - Epochs: {EPOCHS}")
    print(f"  - Batch size: {BATCH_SIZE}")
    print(f"  - Accumulation steps: {ACCUMULATION_STEPS}")
    print(f"  - Device: {DEVICE}")
    print(banner)

    required_functions = ["MemoryEfficientDataset", "MemoryOptimizedTATNWithExplanations", "train_memory_efficient_tatn"]
    missing = [fn for fn in required_functions if fn not in globals()]
    if missing:
        print(f"[PIPELINE] ERROR: Missing critical functions: {missing}")
        print("[PIPELINE] Please run all previous cells (1-9) first")
        raise RuntimeError(f"Missing required functions: {missing}")

    optional_functions = ["comprehensive_post_training_testing"]
    missing_optional = [fn for fn in optional_functions if fn not in globals()]
    if missing_optional:
        print(f"[PIPELINE] Warning: Missing optional functions: {missing_optional}")
        print("[PIPELINE] Some evaluation features will be skipped")

    pipeline_start = time.time()

    # phase_start only exists when DEBUG_TIMING is on; every later use is
    # guarded by the same flag.
    if DEBUG_TIMING:
        phase_start = time.time()

    initialize_environment()

    if DEBUG_TIMING:
        print(f"[TIMING] Initialization: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    # ------------------------------------------------------------------ PHASE 1
    print("[PHASE 1] Loading tokenizer...")
    tokenizer = safe_tokenizer_from_pretrained("facebook/m2m100_418M")

    try:
        tokenizer.src_lang = SOURCE_LANGUAGE
    except Exception as e:
        # Best effort, but no longer silent: a tokenizer without the intended
        # source language changes how inputs are encoded.
        print(f"[PHASE 1] Warning: could not set tokenizer.src_lang: {e}")

    try:
        if not hasattr(tokenizer, "pad_token_id") or tokenizer.pad_token_id is None:
            if hasattr(tokenizer, "add_special_tokens"):
                tokenizer.add_special_tokens({"pad_token": "<pad>"})
    except Exception as e:
        print(f"[PHASE 1] Warning: could not ensure a pad token: {e}")

    vocab_size = getattr(tokenizer, "vocab_size", None)
    if vocab_size is None:
        try:
            vocab_size = len(tokenizer)
        except Exception:
            vocab_size = "unknown"

    print(f"[PHASE 1] Tokenizer loaded (vocab: {vocab_size})")

    if DEBUG_TIMING:
        print(f"[TIMING] Tokenizer: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    # ------------------------------------------------------------------ PHASE 2
    print(f"[PHASE 2] Loading data ({NUM_SAMPLES} samples)...")

    # Single hard-coded pair used when the real loader is absent or fails.
    fallback_pairs = [("আমি কল বন্ধ করেছি।", "I turned off the tap.")]

    if "load_and_preprocess_optimized" in globals():
        try:
            pairs = globals()["load_and_preprocess_optimized"](NUM_SAMPLES)
        except Exception as e:
            print(f"[PHASE 2] Data loading failed: {e}")
            pairs = fallback_pairs
    else:
        print("[PHASE 2] Using fallback data")
        pairs = fallback_pairs

    try:
        dataset = globals()["MemoryEfficientDataset"](pairs, tokenizer, max_length=MAX_LENGTH)
    except Exception as e:
        print(f"[PHASE 2] Dataset creation failed: {e}")
        raise RuntimeError(f"Failed to create dataset: {e}")

    collate_fn = globals().get("safe_collate", None)

    def _plain_dataloader():
        # Plain torch DataLoader used whenever the optimized factory is
        # missing or fails; warns when default collation has to be used.
        loader_kwargs = {
            "batch_size": BATCH_SIZE,
            "shuffle": True,
            "num_workers": 0,
            "pin_memory": torch.cuda.is_available(),
        }
        if collate_fn is not None:
            loader_kwargs["collate_fn"] = collate_fn
        else:
            print("[PHASE 2] Warning: safe_collate not found, using default collation")
        return DataLoader(dataset, **loader_kwargs)

    if "create_optimized_dataloader" in globals():
        try:
            train_loader = globals()["create_optimized_dataloader"](dataset, batch_size=BATCH_SIZE, shuffle=True)
        except Exception as e:
            print(f"[PHASE 2] create_optimized_dataloader failed: {e}, using DataLoader")
            train_loader = _plain_dataloader()
    else:
        train_loader = _plain_dataloader()

    try:
        print(f"[PHASE 2] Dataset: {len(dataset)} samples, {len(train_loader)} batches")
    except Exception:
        print("[PHASE 2] Dataset loaded")

    # The dataset was built from the pairs list; drop this function's
    # reference to it.
    del pairs
    safe_clear_gpu_caches()

    if DEBUG_TIMING:
        print(f"[TIMING] Data loading: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    # ------------------------------------------------------------------ PHASE 3
    print("[PHASE 3] Initializing model...")
    model_core = globals()["MemoryOptimizedTATNWithExplanations"](tokenizer)

    if USE_MULTI_GPU and NUM_GPUS > 1:
        device_ids = list(range(NUM_GPUS))
        print(f"[PHASE 3] Using DataParallel on {device_ids}")
        model = nn.DataParallel(model_core, device_ids=device_ids)
    else:
        model = model_core

    model = model.to(DEVICE)
    core_model = model.module if hasattr(model, "module") else model

    try:
        mbart = getattr(core_model, "mbart", None)
        if mbart is not None and hasattr(mbart, "resize_token_embeddings"):
            current_size = mbart.get_input_embeddings().num_embeddings
            target_size = vocab_size if isinstance(vocab_size, int) else current_size
            # tokenizer.vocab_size is the base vocabulary only; tokens added
            # through add_special_tokens (e.g. the pad token above) are only
            # counted by len(tokenizer). Cover them so their ids stay in range.
            try:
                target_size = max(target_size, len(tokenizer))
            except Exception:
                pass

            if current_size != target_size:
                mbart.resize_token_embeddings(target_size)
                print(f"[PHASE 3] Resized embeddings: {current_size} -> {target_size}")
    except Exception as e:
        print(f"[PHASE 3] Warning: embedding resize skipped: {e}")

    if FREEZE_ENCODER:
        try:
            for p in core_model.mbart.model.encoder.parameters():
                p.requires_grad = False
            print("[PHASE 3] Encoder frozen")
        except Exception as e:
            # Surface this: otherwise training silently proceeds with an
            # unfrozen encoder although freezing was requested.
            print(f"[PHASE 3] Warning: FREEZE_ENCODER requested but freezing failed: {e}")

    print("[PHASE 3] Model initialized")

    if DEBUG_TIMING:
        print(f"[TIMING] Model init: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    # ---------------------------------------------------------------- PHASE 3.5
    print("[PHASE 3.5] Initial model warmup...")

    try:
        warmup_success = warmup_model(model, tokenizer, num_samples=100, batch_size=16, max_length=MAX_LENGTH)
        if warmup_success:
            print("[PHASE 3.5] Initial warmup successful")
        else:
            print("[PHASE 3.5] Initial warmup completed with issues")
    except Exception as e:
        print(f"[PHASE 3.5] Initial warmup failed: {e}")

    if DEBUG_TIMING:
        print(f"[TIMING] Initial warmup: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    # ------------------------------------------------------------------ PHASE 4
    print(f"[PHASE 4] Training for {EPOCHS} epoch(s)...")
    print("[PHASE 4] Creating optimizers...")

    try:
        nmt_params = []
        phi_params = []

        if hasattr(core_model, "mbart") and core_model.mbart is not None:
            nmt_params.extend([p for p in core_model.mbart.parameters() if p.requires_grad])

        # Trainable parameters of the auxiliary components all go to the
        # second ("phi") optimizer.
        for component_name in ("dscd", "asbn", "trg_system"):
            component = getattr(core_model, component_name, None)
            if component is not None:
                phi_params.extend([p for p in component.parameters() if p.requires_grad])

        if not nmt_params:
            # No mbart parameters found: fall back to every trainable parameter.
            nmt_params = [p for p in model.parameters() if p.requires_grad]

        optimizer = torch.optim.AdamW(nmt_params, lr=LR_NMT, weight_decay=0.01)
        print(f"[PHASE 4] Created optimizer with {len(nmt_params)} parameters (lr={LR_NMT})")

        phi_optimizer = None
        if phi_params:
            phi_optimizer = torch.optim.AdamW(phi_params, lr=LR_PHI, weight_decay=0.01)
            print(f"[PHASE 4] Created phi_optimizer with {len(phi_params)} parameters (lr={LR_PHI})")
        else:
            print("[PHASE 4] No phi parameters found, phi_optimizer=None")

    except Exception as e:
        print(f"[PHASE 4] Optimizer creation failed: {e}")
        if VERBOSE_LOGGING:
            traceback.print_exc()
        raise

    training_succeeded = False
    try:
        train_fn = globals()["train_memory_efficient_tatn"]

        trained_model = train_fn(
            model=model,
            tokenizer=tokenizer,
            train_loader=train_loader,
            optimizer=optimizer,
            phi_optimizer=phi_optimizer,
            epochs=EPOCHS,
            accumulation_steps=ACCUMULATION_STEPS,
            validate_every=VALIDATION_CHECK_INTERVAL,
            enable_validation=True,
            device=DEVICE,
        )

        training_succeeded = True
        print("[PHASE 4] Training completed successfully")

    except Exception as e:
        print(f"[PHASE 4] Training failed: {e}")
        if VERBOSE_LOGGING:
            traceback.print_exc()
        # Deliberate best effort: continue with the model as it currently is.
        trained_model = model

    if DEBUG_TIMING:
        print(f"[TIMING] Training: {time.time() - phase_start:.2f}s")
        phase_start = time.time()

    # ------------------------------------------------------------------ PHASE 5
    if CHECKPOINT_SAVE_AFTER_TRAINING:
        print("[PHASE 5] Saving checkpoint...")
        if not training_succeeded:
            print("[PHASE 5] Warning: training did not complete; checkpoint is marked training_completed=False")

        try:
            core_model_to_save = trained_model.module if hasattr(trained_model, "module") else trained_model

            checkpoint = {
                "model_state_dict": core_model_to_save.state_dict(),
                "epoch": EPOCHS,
                # Lets consumers tell a finished run from a failed one; the
                # "epoch" value above is the configured count either way.
                "training_completed": training_succeeded,
                "config": {
                    "source_lang": SOURCE_LANGUAGE,
                    "target_lang": TARGET_LANGUAGE,
                    "num_samples": NUM_SAMPLES,
                    "batch_size": BATCH_SIZE,
                    "max_length": MAX_LENGTH,
                }
            }

            if hasattr(core_model_to_save, "dscd") and hasattr(core_model_to_save.dscd, "state_dict"):
                try:
                    checkpoint["dscd_state"] = core_model_to_save.dscd.state_dict()
                    print("[PHASE 5] DSCD state included in checkpoint")
                except Exception:
                    pass

            # Make sure the target directory exists before writing.
            checkpoint_dir = os.path.dirname(CHECKPOINT_PATH)
            if checkpoint_dir:
                os.makedirs(checkpoint_dir, exist_ok=True)

            torch.save(checkpoint, CHECKPOINT_PATH)

            size_mb = os.path.getsize(CHECKPOINT_PATH) / (1024**2)
            print(f"[PHASE 5] Checkpoint saved: {CHECKPOINT_PATH} ({size_mb:.1f} MB)")

        except Exception as e:
            print(f"[PHASE 5] Checkpoint save failed: {e}")

        if DEBUG_TIMING:
            print(f"[TIMING] Checkpoint save: {time.time() - phase_start:.2f}s")
            phase_start = time.time()

    # ------------------------------------------------------------------ PHASE 6
    if "comprehensive_post_training_testing" in globals():
        print("[PHASE 6] Running evaluation...")

        try:
            eval_fn = globals()["comprehensive_post_training_testing"]
            results = eval_fn(trained_model, tokenizer, run_warmup=False)

            print("[PHASE 6] Evaluation completed")
            print(f"  - Success rate: {results.get('success_rate_pct', 0):.1f}%")
            print(f"  - Total explanations: {results.get('total_explanations', 0)}")

        except Exception as e:
            print(f"[PHASE 6] Evaluation failed: {e}")
            if VERBOSE_LOGGING:
                traceback.print_exc()

        if DEBUG_TIMING:
            print(f"[TIMING] Evaluation: {time.time() - phase_start:.2f}s")
    else:
        print("[PHASE 6] Evaluation skipped (function not available)")

    pipeline_duration = time.time() - pipeline_start

    print(banner)
    print(f"PIPELINE COMPLETED IN {pipeline_duration:.1f}s")
    print(banner)

    safe_clear_gpu_caches()

    return trained_model, tokenizer

# Announce that Cell 10 finished defining the pipeline (rule, message, rule).
print("\n".join(("=" * 80, "Cell 10: Main pipeline ready - COMPLETE VERSION", "=" * 80)))


Cell 10: Main pipeline ready - COMPLETE VERSION


In [None]:
# ==============================================================================
# CELL 11: MAIN EXECUTION WRAPPER (FINAL) - FIXED
# ==============================================================================
from datetime import datetime, timezone
import os
import traceback
import math
import sys
import time
import torch
import gc

# Resolve this cell's private (underscore-prefixed) settings from names that
# earlier cells may have left in the notebook namespace. Every lookup has a
# default, and any conversion error switches the whole set to the hard-coded
# fallbacks in the except branch.
try:
    _NUM_SAMPLES = int(globals().get('NUM_SAMPLES', 30000))
    _EPOCHS = int(globals().get('EPOCHS', 2))
    _BATCH_SIZE = int(globals().get('BATCH_SIZE', 4))
    _ACCUMULATION_STEPS = int(globals().get('ACCUMULATION_STEPS', 16))

    # DEVICE may be a torch.device or something string-like; normalize it.
    raw_device = globals().get('DEVICE', "cuda" if torch.cuda.is_available() else "cpu")
    if isinstance(raw_device, torch.device):
        _DEVICE = raw_device
    else:
        _DEVICE = torch.device(str(raw_device))

    _ENABLE_ASBN_TRAINING = bool(globals().get('ENABLE_ASBN_TRAINING', True))
    _ENABLE_TRG_INFERENCE = bool(globals().get('ENABLE_TRG_INFERENCE', True))
    _PERIODIC_DISCOVERY_FREQUENCY = int(globals().get('PERIODIC_DISCOVERY_FREQUENCY', 3000))
    _VERBOSE_LOGGING = bool(globals().get('VERBOSE_LOGGING', False))
    _DEBUG_DISCOVERY = bool(globals().get('DEBUG_DISCOVERY', False))
    _DEBUG_TIMING = bool(globals().get('DEBUG_TIMING', False))
    _NUM_GPUS = int(globals().get('NUM_GPUS', torch.cuda.device_count() if torch.cuda.is_available() else 0))
    _USE_MULTI_GPU = bool(globals().get('USE_MULTI_GPU', _NUM_GPUS > 1))
    _SPAN_THRESHOLD = float(globals().get('SPAN_THRESHOLD', 0.20))
    _TAU_LOW = float(globals().get('TAU_LOW', 0.15))

    raw_list = globals().get('HOMOGRAPH_REFERENCE_LIST_BN', ["কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা"])
    if isinstance(raw_list, str):
        # A bare string would otherwise be split into single characters.
        raw_list = [raw_list]
    _HOMOGRAPH_REFERENCE_LIST_BN = set(str(w) for w in raw_list)
    # Presence of NUM_SAMPLES is used as the marker that the config cell ran.
    cell0_loaded = 'NUM_SAMPLES' in globals()

# RuntimeError is included because torch.device() raises it for an
# unrecognized device string.
except (NameError, TypeError, ValueError, RuntimeError) as e:
    print(f"[EXEC] Config load error: {e}")
    _NUM_SAMPLES = 30000
    _EPOCHS = 2
    _BATCH_SIZE = 4
    _ACCUMULATION_STEPS = 16
    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _ENABLE_ASBN_TRAINING = True
    _ENABLE_TRG_INFERENCE = True
    _PERIODIC_DISCOVERY_FREQUENCY = 3000
    _VERBOSE_LOGGING = False
    _DEBUG_DISCOVERY = False
    _DEBUG_TIMING = False
    _NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
    _USE_MULTI_GPU = (_NUM_GPUS > 1)
    _SPAN_THRESHOLD = 0.20
    _TAU_LOW = 0.15
    _HOMOGRAPH_REFERENCE_LIST_BN = {"কল", "কাল", "পাতা", "ব্যাংক", "ফল", "মাথা"}
    cell0_loaded = False
    print("[EXEC] Using fallback configuration (Cell 0 not executed)")

# Prefer the path Cell 10 writes its checkpoint to (CHECKPOINT_PATH) when that
# cell has run, so both cells always refer to the same file.
_CHECKPOINT_PATH = str(globals().get('CHECKPOINT_PATH') or "/kaggle/working/tatn_final.pt")

def _safe_div_ceil(a: int, b: int) -> int:
    try:
        if isinstance(a, int) and isinstance(b, int) and b > 0:
            return math.ceil(a / b)
    except Exception:
        pass
    return 0

def _format_duration(seconds: float) -> str:
    if seconds < 60:
        return f"{seconds:.1f}s"
    elif seconds < 3600:
        return f"{seconds/60:.1f}min"
    else:
        return f"{seconds/3600:.2f}hr"

def _safe_get(d: dict, *keys, default=None):
    if not isinstance(d, dict):
        return default
    result = d
    for key in keys:
        if not isinstance(result, dict):
            return default
        result = result.get(key, default)
        if result is default:
            return default
    return result

def _get_dscd_homographs(model):
    try:
        core = model.module if hasattr(model, 'module') else model
        dscd = getattr(core, 'dscd', None)

        if dscd and hasattr(dscd, 'get_discovered_homographs'):
            try:
                return dscd.get_discovered_homographs()
            except Exception:
                pass

        if dscd and hasattr(dscd, 'prototype_stores'):
            homographs = set()

            lock = None
            if hasattr(dscd, 'buffer_lock'):
                lock = dscd.buffer_lock
            elif hasattr(dscd, 'clustering_lock'):
                lock = dscd.clustering_lock

            # FIX: Safe locking
            if lock:
                with lock:
                    stores = dict(dscd.prototype_stores)
            else:
                stores = dict(dscd.prototype_stores)

            for token, store in stores.items():
                try:
                    if hasattr(store, 'size') and callable(getattr(store, 'size', None)):
                        size_ok = store.size() >= 2
                    else:
                        size_val = getattr(store, 'size', None)
                        if isinstance(size_val, int):
                            size_ok = size_val >= 2
                        else:
                            size_ok = False
                except Exception:
                    size_ok = False

                if size_ok:
                    clean = str(token).replace(' ', '').replace('Ġ', '').replace('##', '').strip().lower()
                    homographs.add(clean)
            return homographs
    except Exception:
        pass
    return set()

def _safe_cleanup():
    try:
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                try:
                    with torch.cuda.device(i):
                        torch.cuda.empty_cache()
                except Exception:
                    pass
        if gc.isenabled():
            gc.collect()
    except Exception:
        pass

if __name__ == "__main__":
    # Cell 11 entry point. Prints the active configuration, runs main_pipeline()
    # (the failure message below points at Cell 10 as its source), then prints
    # either a post-training report (checkpoint, components, metrics, inference
    # smoke test) or a failure diagnosis, and always ends with a run summary.
    print("=" * 80)
    print("MEMORY-OPTIMIZED TATN - COMPLETE EXECUTION")
    print("=" * 80)

    user_login = os.getenv("KAGGLE_USERNAME") or os.getenv("USER") or "manas0003"
    start_time = time.time()
    now_utc = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")

    print(f"User: {user_login}")
    print(f"Started: {now_utc}")

    print("\n[CONFIGURATION]")
    print(f"  Cell 0 status: {'Loaded' if cell0_loaded else 'Using fallbacks'}")
    print(f"  Samples: {_NUM_SAMPLES}")
    print(f"  Epochs: {_EPOCHS}")
    print(f"  Batch Size: {_BATCH_SIZE}")
    print(f"  Accumulation: {_ACCUMULATION_STEPS}")
    print(f"  Device: {_DEVICE}")
    print(f"  Multi-GPU: {'ENABLED' if _USE_MULTI_GPU else 'DISABLED'} ({_NUM_GPUS} GPUs)")
    print(f"  Span threshold: {_SPAN_THRESHOLD}")
    print(f"  Uncertainty threshold: {_TAU_LOW}")
    print(f"  Discovery frequency: {_PERIODIC_DISCOVERY_FREQUENCY}")

    if _USE_MULTI_GPU and _NUM_GPUS > 0:
        per_gpu = _safe_div_ceil(_BATCH_SIZE, _NUM_GPUS)
        print(f"  Batch per GPU: {per_gpu}")

    print(f"  ASBN: {'Enabled' if _ENABLE_ASBN_TRAINING else 'Disabled'}")
    print(f"  TRG: {'Enabled' if _ENABLE_TRG_INFERENCE else 'Disabled'}")
    print(f"  Debug: {'Enabled' if _DEBUG_DISCOVERY else 'Disabled'}")
    print("=" * 80)

    # Run state. Everything is initialised here so that no value left over from
    # a previous execution of this cell can leak into this run's report.
    trained_model, tokenizer = None, None
    pipeline_success = False
    failure_category = None
    failure_details = ""
    checkpoint_valid = False
    ckpt = None  # checkpoint object, loaded at most once and shared by the report sections

    if 'main_pipeline' not in globals():
        print("\nERROR: main_pipeline not found")
        print("   -> Run Cell 10 before executing Cell 11")
        failure_category = "MISSING_DEPENDENCY"
        failure_details = "Cell 10 not executed"
    else:
        try:
            print("\nStarting pipeline...")

            if _DEBUG_TIMING:
                print("   Expected: ~15-45 min (config dependent)")

            pipeline_start = time.time()
            trained_model, tokenizer = main_pipeline()
            pipeline_duration = time.time() - pipeline_start

            print(f"\nPipeline completed: {_format_duration(pipeline_duration)}")

            # A run that hands back no model/tokenizer is a failure; record it as
            # one so the report below and the final summary agree with each other.
            if trained_model is None or tokenizer is None:
                failure_category = "INVALID_RESULT"
                failure_details = "main_pipeline returned None for the model and/or tokenizer"
            else:
                pipeline_success = True

        except KeyboardInterrupt:
            print("\nInterrupted by user")
            failure_category = "USER_INTERRUPT"
            failure_details = "Manual stop"

        except RuntimeError as e:
            # Classify by message text so the recovery section can give targeted advice.
            msg = str(e).lower()

            if "tokenizer" in msg or "sentencepiece" in msg:
                print("\nTokenizer error")
                failure_category = "TOKENIZER_ERROR"
                failure_details = str(e)[:200]

                print("\nFix:")
                print("   ! pip install transformers==4.30.2 sentencepiece tokenizers")
                print("   Then RESTART kernel and re-run Cells 0-11")

            elif "out of memory" in msg:
                print("\nOut of Memory")
                failure_category = "OOM_ERROR"
                failure_details = "GPU OOM"

                print("\nFixes:")
                print("   1. Reduce BATCH_SIZE (try 2-4)")
                print("   2. Reduce NUM_SAMPLES (try 10k-20k)")
                print("   3. Increase ACCUMULATION_STEPS (32-64)")

            else:
                print(f"\nRuntime error: {type(e).__name__}")
                print(f"   {str(e)[:400]}")
                failure_category = "RUNTIME_ERROR"
                failure_details = str(e)[:200]

            if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
                print("\n[TRACEBACK]")
                try:
                    traceback.print_exc()
                except Exception:
                    pass

        except Exception as e:
            print(f"\nUnexpected error: {type(e).__name__}")
            print(f"   {str(e)[:400]}")
            failure_category = "UNKNOWN_ERROR"
            failure_details = str(e)[:200]

            if _VERBOSE_LOGGING or _DEBUG_DISCOVERY:
                print("\n[TRACEBACK]")
                try:
                    traceback.print_exc()
                except Exception:
                    pass

    if pipeline_success and trained_model is not None and tokenizer is not None:
        print("\n" + "=" * 80)
        print("PIPELINE SUCCEEDED")
        print("=" * 80)

        # ---- Checkpoint validation: file exists, has model weights and DSCD state.
        print("\n[CHECKPOINT]")

        try:
            if os.path.exists(_CHECKPOINT_PATH):
                size_mb = os.path.getsize(_CHECKPOINT_PATH) / (1024**2)
                print(f"  File: {_CHECKPOINT_PATH}")
                print(f"  Size: {size_mb:.1f} MB")

                ckpt = torch.load(_CHECKPOINT_PATH, map_location='cpu')

                has_model = 'model_state_dict' in ckpt and len(ckpt['model_state_dict']) > 0
                has_dscd = 'dscd_state' in ckpt and len(ckpt.get('dscd_state', {})) > 0

                print(f"  Model: {'Present' if has_model else 'MISSING'}")
                print(f"  DSCD: {'Present' if has_dscd else 'MISSING'}")

                if has_dscd:
                    num_tokens = len(ckpt['dscd_state'].get('prototype_stores_data', {})) # Changed to match DSCD format
                    print(f"  Tokens: {num_tokens}")

                    # The checkpoint only counts as valid when DSCD stored at least one token.
                    if num_tokens > 0:
                        checkpoint_valid = True
                        print("  Status: VALID")
                    else:
                        print("  Status: EMPTY DSCD")
                else:
                    print("  Status: MISSING DSCD")
            else:
                print(f"  NOT FOUND: {_CHECKPOINT_PATH}")

        except Exception as e:
            print(f"  Validation failed: {e}")

        # ---- Live statistics from the in-memory model's sub-components.
        print("\n[COMPONENTS]")

        try:
            core = trained_model.module if hasattr(trained_model, 'module') else trained_model

            dscd = getattr(core, 'dscd', None)
            if dscd and hasattr(dscd, 'get_prototype_summary'):
                try:
                    dscd_stats = dscd.get_prototype_summary()
                    print("  DSCD:")
                    print(f"    - Tokens: {dscd_stats.get('total_tokens', 0)}")
                    print(f"    - Prototypes: {dscd_stats.get('total_prototypes', 0)}")
                    print(f"    - Homographs: {dscd_stats.get('num_homographs', 0)}")
                except Exception:
                    pass

            asbn = getattr(core, 'asbn', None)
            if asbn and hasattr(asbn, 'get_detailed_stats'):
                try:
                    asbn_stats = asbn.get_detailed_stats()
                    print("  ASBN:")
                    print(f"    - Domain accuracy: {asbn_stats.get('domain_accuracy', 0):.2%}")
                    if 'source_accuracy' in asbn_stats:
                        print(f"    - Source: {asbn_stats['source_accuracy']:.2%}")
                        print(f"    - Target: {asbn_stats['target_accuracy']:.2%}")
                except Exception:
                    pass

            trg = getattr(core, 'trg_system', None)
            if trg and hasattr(trg, 'get_statistics'):
                try:
                    trg_stats = trg.get_statistics()
                    print("  TRG:")
                    print(f"    - Explanations: {trg_stats.get('explanations_generated', 0)}")
                    print(f"    - High confidence: {trg_stats.get('high_confidence_rate', 0):.1%}")
                    print(f"    - DSCD homograph rate: {trg_stats.get('dscd_homograph_rate', 0):.1%}")
                except Exception:
                    pass

        except Exception as e:
            print(f"  Stats failed: {e}")

        # ---- Training / evaluation numbers stored inside the checkpoint.
        print("\n[METRICS]")

        try:
            # Reuse the object deserialised during validation. Only go back to
            # disk when that first attempt never produced one, so a load error
            # is still reported in this section.
            if ckpt is None and os.path.exists(_CHECKPOINT_PATH):
                ckpt = torch.load(_CHECKPOINT_PATH, map_location='cpu')

            if ckpt is not None:
                training_stats = ckpt.get('training_stats', {})
                if training_stats:
                    total_loss = training_stats.get('total_loss', [])
                    updates = training_stats.get('optimizer_updates', 0)

                    print("  Training:")
                    print(f"    - Updates: {updates}")
                    if total_loss:
                        # "Final loss" is the mean of the last 100 entries (or of all, if fewer).
                        if len(total_loss) >= 100:
                            final = sum(total_loss[-100:]) / len(total_loss[-100:])
                        else:
                            final = sum(total_loss) / len(total_loss)
                        print(f"    - Final loss: {final:.6f}")

                eval_results = ckpt.get('eval_results', {})
                baseline = ckpt.get('baseline_metrics', {})

                if eval_results:
                    final_success = eval_results.get('success_rate_pct', 0)
                    total_expl = eval_results.get('total_explanations', 0)

                    print("  Evaluation:")
                    if baseline:
                        baseline_success = baseline.get('success_rate_pct', 0)
                        improvement = final_success - baseline_success
                        print(f"    - Baseline -> Final: {baseline_success:.1f}% -> {final_success:.1f}%")
                        print(f"    - Improvement: {improvement:+.1f}%")
                    else:
                        print(f"    - Success: {final_success:.1f}%")

                    print(f"    - Explanations: {total_expl}")

                    quality = eval_results.get('quality_metrics', {})
                    if quality:
                        print(f"    - Avg confidence: {quality.get('avg_confidence', 0):.3f}")

        except Exception as e:
            print(f"  Metrics failed: {e}")

        # ---- Inference smoke test on three ambiguous Bengali sentences.
        print("\n[INFERENCE VALIDATION]")
        print("Testing disambiguation on ambiguous sentences...")
        print("-" * 80)

        _safe_cleanup()

        inference_success = 0
        inference_failed = 0
        dscd_homographs_detected = set()

        dscd_homographs = _get_dscd_homographs(trained_model)
        print(f"DSCD discovered: {len(dscd_homographs)} homographs")
        if dscd_homographs and _DEBUG_DISCOVERY:
            print(f"  Sample: {list(dscd_homographs)[:10]}")

        test_sentences = [
            ("আমি কল বন্ধ করেছি।", "কল (tap/call)"),
            ("কাল আমি বই কিনব।", "কাল (tomorrow/yesterday)"),
            ("পাতা ঝরে পড়েছে।", "পাতা (leaf/page)"),
        ]

        inference_times = []

        try:
            if 'translate_with_explanations' not in globals():
                print("translate_with_explanations not available")
                print("   -> Run Cell 8 before Cell 11")
            else:
                for idx, (sentence, desc) in enumerate(test_sentences, 1):
                    try:
                        print(f"\n{idx}.  {desc}")
                        print(f"   Input: {sentence}")

                        inf_start = time.time()
                        res = translate_with_explanations(trained_model, tokenizer, sentence)
                        inf_time = time.time() - inf_start
                        inference_times.append(inf_time)

                        if isinstance(res, dict):
                            translation = res.get('translation', 'N/A')
                            amb_count = res.get('ambiguous_words_detected', 0)
                            exs = res.get('explanations', []) or []

                            print(f"   Translation: {translation}")
                            print(f"   Ambiguous: {amb_count}")
                            print(f"   Time: {inf_time:.3f}s")

                            if exs:
                                for exp in exs:
                                    word = exp.get('ambiguous_word', exp.get('token', 'N/A'))
                                    # Same normalisation as _get_dscd_homographs (including
                                    # the '##' strip) so the membership test can match.
                                    clean = str(word).replace(' ', '').replace('Ġ', '').replace('##', '').strip().lower()

                                    if clean in dscd_homographs:
                                        dscd_homographs_detected.add(clean)

                                    try:
                                        conf = float(exp.get('confidence', 0.5))
                                        span = float(exp.get('span', 0.0))
                                        u = float(exp.get('uncertainty', 0.0))
                                        print(f"   -> '{word}': conf={conf:.3f}, s={span:.3f}, u={u:.3f}")
                                    except Exception:
                                        print(f"   -> '{word}': (no metrics)")

                                inference_success += 1
                            else:
                                # A translation without explanations still counts as a success.
                                print("   No explanations")
                                inference_success += 1
                        else:
                            print("   Unexpected format")
                            inference_failed += 1

                        _safe_cleanup()

                    except Exception as e:
                        # Show the message too; the type alone is rarely enough to diagnose.
                        print(f"   Failed: {type(e).__name__}: {str(e)[:200]}")
                        inference_failed += 1

                print("\n" + "-" * 80)
                print(f"Results: {inference_success}/{len(test_sentences)} successful")

                if inference_times:
                    avg_time = sum(inference_times) / len(inference_times)
                    print(f"Performance: {avg_time:.3f}s avg per sentence")

                if dscd_homographs_detected:
                    print(f"DSCD homographs detected: {', '.join(sorted(dscd_homographs_detected))}")
                else:
                    print("No DSCD homographs detected")
                    if len(dscd_homographs) == 0:
                        print("   -> DSCD has no discoveries (run warmup)")
                    else:
                        print(f"   -> Check TRG thresholds (span={_SPAN_THRESHOLD}, u={_TAU_LOW})")

        except Exception as e:
            print(f"Validation failed: {e}")
            if _DEBUG_DISCOVERY:
                try:
                    traceback.print_exc()
                except Exception:
                    pass

        # ---- Structural check: each sub-component exposes its expected method.
        print("\n[SYSTEM TEST]")

        try:
            core = trained_model.module if hasattr(trained_model, 'module') else trained_model

            dscd_ok = hasattr(core, 'dscd') and hasattr(core.dscd, 'forward')
            asbn_ok = hasattr(core, 'asbn') and hasattr(core.asbn, 'forward')
            trg_ok = hasattr(core, 'trg_system') and hasattr(core.trg_system, 'process_sentence_for_explanations')
            mbart_ok = hasattr(core, 'mbart') and hasattr(core.mbart, 'generate')

            print("  Component status:")
            print(f"    - DSCD: {'OK' if dscd_ok else 'MISSING'}")
            print(f"    - ASBN: {'OK' if asbn_ok else 'MISSING'}")
            print(f"    - TRG: {'OK' if trg_ok else 'MISSING'}")
            print(f"    - M2M100: {'OK' if mbart_ok else 'MISSING'}")

            all_ok = dscd_ok and asbn_ok and trg_ok and mbart_ok

            if all_ok:
                print("  All components operational")
            else:
                print("  Some components missing")

        except Exception as e:
            print(f"  Test failed: {e}")

        print("\n" + "=" * 80)
        print("NEXT STEPS")
        print("=" * 80)

        print("\n1. Single translation:")
        print("   result = translate_with_explanations(trained_model, tokenizer, 'আমি কল বন্ধ করেছি।')")

        print("\n2. Batch translation:")
        print("   for sent in sentences:")
        print("       res = translate_with_explanations(trained_model, tokenizer, sent)")

        print("\n3. Load checkpoint:")
        # Show the configured path rather than a hard-coded one.
        print(f"   ckpt = torch.load('{_CHECKPOINT_PATH}')")
        print("   model.load_state_dict(ckpt['model_state_dict'])")
        print("   model.dscd.load_state_dict(ckpt['dscd_state'])")

        print("\n4. Full evaluation:")
        print("   results = comprehensive_post_training_testing(trained_model, tokenizer)")

        print("\n5. Demo:")
        print("   demonstrate_system(trained_model, tokenizer)")

        if not checkpoint_valid:
            print("\nCheckpoint needs verification - re-run Cell 10 if needed")

        print("\n" + "=" * 80)

    else:
        print("\n" + "=" * 80)
        print("PIPELINE FAILED")
        print("=" * 80)

        print(f"\nCategory: {failure_category or 'UNKNOWN'}")
        if failure_details:
            print(f"Details: {failure_details[:200]}")

        # ---- Which earlier cells have been executed (detected via a marker name each).
        print("\n[DIAGNOSTICS]")

        components = {
            'Cell 0': 'NUM_SAMPLES' in globals(),
            'Cell 1': 'reconstruct_word_spans' in globals(),
            'Cell 2': 'MemoryEfficientDataset' in globals(),
            'Cell 3': 'MemoryEfficientDSCDOnline' in globals(),
            'Cell 4': 'MemoryEfficientASBNModule' in globals(),
            'Cell 5': 'CompleteTRGWithExplanations' in globals(),
            'Cell 6': 'MemoryOptimizedTATNWithExplanations' in globals(),
            'Cell 7': 'train_memory_efficient_tatn' in globals(),
            'Cell 8': 'translate_with_explanations' in globals(),
            'Cell 9': 'comprehensive_post_training_testing' in globals(),
            'Cell 10': 'main_pipeline' in globals(),
        }

        all_present = True
        for comp, present in components.items():
            status = "OK" if present else "MISSING"
            print(f"  {status} {comp}")
            if not present:
                all_present = False

        # ---- Category-specific advice.
        print("\n[RECOVERY]")

        if failure_category == "MISSING_DEPENDENCY":
            print("\n-> Run Cells 0-10 in sequence, then re-run Cell 11")

        elif failure_category == "TOKENIZER_ERROR":
            print("\n-> Install dependencies:")
            print("  ! pip install transformers==4.30.2 sentencepiece tokenizers")
            print("  Then RESTART kernel and re-run Cells 0-11")

        elif failure_category == "OOM_ERROR":
            print("\n-> Reduce memory in Cell 0:")
            print("  BATCH_SIZE = 2")
            print("  NUM_SAMPLES = 15000")
            print("  ACCUMULATION_STEPS = 32")
            print("  Then re-run Cells 0-11")

        elif failure_category == "RUNTIME_ERROR":
            print("\n-> Enable debug in Cell 0:")
            print("  VERBOSE_LOGGING = True")
            print("  DEBUG_DISCOVERY = True")
            print("  Then re-run Cell 11 for details")

        elif failure_category == "USER_INTERRUPT":
            print("\n-> Check checkpoint exists:")
            print(f"  os.path.exists('{_CHECKPOINT_PATH}')")
            print("  If yes, can load and skip training")
            print("  If no, re-run Cell 11")

        else:
            print("\n-> General steps:")
            print("  1.  Enable DEBUG in Cell 0")
            print("  2. Re-run Cells 0-11")
            print("  3. Check GPU: torch.cuda.is_available()")
            print("  4.  Verify data loaded")

        print("\n" + "=" * 80)

    total_duration = time.time() - start_time
    end_utc = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")

    print("\n" + "=" * 80)
    print("EXECUTION SUMMARY")
    print("=" * 80)
    print(f"User: {user_login}")
    print(f"Started: {now_utc}")
    print(f"Finished: {end_utc}")
    print(f"Duration: {_format_duration(total_duration)}")

    if pipeline_success:
        print("Status: SUCCESS")
        # checkpoint_valid is always defined for this run (initialised above),
        # so no namespace lookup is needed and no stale value can be picked up.
        if checkpoint_valid:
            print("Checkpoint: VALID")
        else:
            print("Checkpoint: CHECK NEEDED")
    else:
        print(f"Status: FAILED ({failure_category or 'UNKNOWN'})")

    print("=" * 80)

    _safe_cleanup()

# Cell-ready banner: a blank line, a rule, the message, a rule, a blank line.
print(
    "\n" + "=" * 80,
    "Cell 11: Execution wrapper ready (FINAL) - FIXED",
    "=" * 80 + "\n",
    sep="\n",
)


MEMORY-OPTIMIZED TATN - COMPLETE EXECUTION
User: manas0003
Started: 2026-01-07 14:57:46 UTC

[CONFIGURATION]
  Cell 0 status: Loaded
  Samples: 30000
  Epochs: 1
  Batch Size: 100
  Accumulation: 16
  Device: cuda:0
  Multi-GPU: ENABLED (2 GPUs)
  Span threshold: 0.2
  Uncertainty threshold: 0.15
  Discovery frequency: 200
  Batch per GPU: 50
  ASBN: Enabled
  TRG: Enabled
  Debug: Disabled

Starting pipeline...
   Expected: ~15-45 min (config dependent)
TATN MAIN PIPELINE - COMPLETE INTEGRATION
Configuration:
  - Span threshold: 0.2
  - Uncertainty threshold: 0.15
  - Discovery frequency: 200
  - Epochs: 1
  - Batch size: 100
  - Accumulation steps: 16
  - Device: cuda:0
[PIPELINE] Initializing environment...
[PIPELINE] GPUs: 2
  GPU 0: Tesla T4 (14.7 GB)
  GPU 1: Tesla T4 (14.7 GB)
[TIMING] Initialization: 0.35s
[PHASE 1] Loading tokenizer...


vocab.json: 0.00B [00:00, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/298 [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

config.json:   0%|          | 0.00/908 [00:00<?, ?B/s]

[PHASE 1] Tokenizer loaded (vocab: 128104)
[TIMING] Tokenizer: 1.42s
[PHASE 2] Loading data (30000 samples)...
[CELL2] Loading up to 30000 samples from local CSV: /kaggle/input/bn-homo/bn_homograph_complete_dataset.csv
[CELL2] Reading CSV file...
[CELL2] Detected src=English, tgt=Bengali: Swapping columns for bn→en task.
[CELL2] Swap successful: src=Bengali, tgt=English
[CELL2] Processing 30000 rows from CSV...


Loading dataset: 100%|██████████| 30000/30000 [00:00<00:00, 114972.71it/s]
Using cls_token, but it is not set yet.
Using cls_token, but it is not set yet.
Using mask_token, but it is not set yet.
Using mask_token, but it is not set yet.


[CELL2] Loaded 30000 pairs from CSV, skipped 0 rows
[CELL2] Dataset initialized: 30000 valid pairs, 0 invalid
[CELL2] DataLoader created: total_batch=100, per_gpu=50, workers=2
[PHASE 2] Dataset: 30000 samples, 300 batches
[TIMING] Data loading: 0.83s
[PHASE 3] Initializing model...


pytorch_model.bin:   0%|          | 0.00/1.94G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/233 [00:00<?, ?B/s]

[PHASE 3] Using DataParallel on [0, 1]
[PHASE 3] Resized embeddings: 128112 -> 128104
[PHASE 3] Model initialized
[TIMING] Model init: 15.49s
[PHASE 3.5] Initial model warmup...
[WARMUP] Starting fast model warmup...
[WARMUP] Running fast inference pass (batch=16, seqlen=48)...


E0000 00:00:1767797887.416796      55 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1767797887.476643      55 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1767797887.986171      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767797887.986199      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767797887.986207      55 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767797887.986209      55 computation_placer.cc:177] computation placer already registered. Please check linka

[WARMUP] Completed in 15.51s
[WARMUP] Model warmup successful
[PHASE 3.5] Initial warmup successful
[TIMING] Initial warmup: 15.51s
[PHASE 4] Training for 1 epoch(s)...
[PHASE 4] Creating optimizers...
[PHASE 4] Created optimizer with 509 parameters (lr=2e-05)
[PHASE 4] Created phi_optimizer with 44 parameters (lr=1e-05)
[TRAIN] Starting training: epochs=1, batch=100, accum_steps=16
[TRAIN] Validation: enabled
[TRAIN] DP enabled: True, GPUs: 2, Device: cuda:0
[TRAIN] Discovery frequency: 200 steps
[TRAIN] Checkpoint: Will save to /kaggle/working/tatn_final.pt after all epochs

EPOCH 1/1 STARTED
[TRAIN] TRG statistics reset for epoch 1
[TRAIN] ASBN statistics reset for epoch 1


Epoch 1/1:  66%|██████▋   | 199/300 [45:41<23:32, 13.98s/it, fwd_loss=4.3094 bwd_loss=0.2693 rate=100.0% clusters=6357 next_disc_in=1]   


[TRAIN] Triggering periodic discovery at step 200...
[TRAIN-DEBUG] GPU mem (GB):
  GPU 0: alloc=7.25 resv=12.20
  GPU 1: alloc=1.30 resv=8.30
[TRAIN-DEBUG] step=200 loss=4.1804 clusters=6375

[CLUSTER] Top 5 clusters:
------------------------------------------------------------------------------------------
Rank  Token          Count       Protos    Mu             Tau         
------------------------------------------------------------------------------------------
1     সিদ্ধান্ত      53          6         24.379122      3.906033    
2     লাগে           51          4         23.324217      4.017983    
3     আলোয়          51          4         23.527962      3.982990    
4     পরিবর্তন       51          4         22.262837      4.183382    
5     মান            51          4         24.489047      3.032672    
------------------------------------------------------------------------------------------


Epoch 1/1:  69%|██████▉   | 207/300 [47:32<21:51, 14.10s/it, fwd_loss=4.2356 bwd_loss=0.2647 rate=100.0% clusters=6482 next_disc_in=193]


EPOCH 1 COMPREHENSIVE VALIDATION (Step 208)

[VALIDATION] Testing 10 samples:
--------------------------------------------------------------------------------
   1. no expl         কল=tap/call                    -> I closed the call.
   2. no expl         কাল=tomorrow/yesterday         -> I will buy the book tomorrow..
   3. no expl         পাতা=leaf/page                 -> the page is shred.
   4. no expl         ব্যাংক=bank/embankment         -> he went to the bank.
   5. no expl         No ambiguity                   -> I is good.
   6. no expl         No ambiguity                   -> she speak very sweetly.
   7. no expl         No ambiguity                   -> this is my book.
   8. no expl         No ambiguity                   -> the weather is good today.
   9. no expl         ফল=fruit/result                -> the fruit is very good.
  10. no expl         মাথা=head/top                  -> he he he is pain....

-----------------------------------------------------------------

Epoch 1/1:  69%|██████▉   | 208/300 [47:58<26:39, 17.39s/it, fwd_loss=4.3017 bwd_loss=0.2689 rate=100.0% clusters=6498 next_disc_in=192]




Epoch 1/1: 100%|██████████| 300/300 [1:09:11<00:00, 13.84s/it, fwd_loss=3.7065 bwd_loss=0.2317 rate=100.0% clusters=7675 next_disc_in=100]



EPOCH 1/1 SUMMARY
  Duration (min): 69.19
  Optimizer updates: 19
  Batches: processed=300, skipped=0
  Success rate: 105.6%
  Clustered tokens: 7675
  Avg epoch loss: 5.337965

[TRAIN] Running synchronous DSCD clustering after epoch 1...


# Debug cell: TATN deep diagnostics (Cell 12)

In [None]:
# ==============================================================================
# CELL 12: DEEP DIAGNOSTIC DEBUGGER (ROOT CAUSE ANALYSIS) - FIXED
# ==============================================================================
import torch
import torch.nn.functional as F
import numpy as np

def deep_debug_tatn(model, tokenizer, test_sentences):
    """Print a three-part diagnostic report for a TATN model.

    The report contains:
      1. the span / uncertainty thresholds read from notebook globals
         (falling back to 0.05 / 0.15),
      2. a health check of the ``dscd`` and ``asbn`` sub-modules,
      3. a per-token trace (span, uncertainty, gate, prototype id) for every
         sentence in ``test_sentences``.

    Args:
        model: The model to inspect. If it exposes ``.module`` the wrapped
            object is used instead. It must provide ``eval()`` and, for the
            sentence trace, ``mbart`` and ``dscd`` attributes.
        tokenizer: Callable tokenizer; must also provide
            ``convert_ids_to_tokens``. Offset mappings are requested from it,
            and when that fails the word alignment is simply empty.
        test_sentences: Iterable of ``(sentence, target_word)`` pairs.

    Returns:
        None. Everything is written with ``print``.

    Side effects:
        The inspected model is switched to eval mode.
    """
    print("\n" + "=" * 80)
    print("TATN DEEP DIAGNOSTIC REPORT (FIXED)")
    print("=" * 80)

    # 1. Configuration Check
    print("[1] CONFIGURATION CHECK")
    try:
        if "REAL_AMB_SPAN_THRESHOLD" in globals():
            span_thresh = float(REAL_AMB_SPAN_THRESHOLD)
        elif "_SPAN_THRESHOLD" in globals():
            span_thresh = float(_SPAN_THRESHOLD)
        else:
            span_thresh = 0.05

        if "REAL_AMB_UNCERTAINTY_THRESHOLD" in globals():
            tau_low = float(REAL_AMB_UNCERTAINTY_THRESHOLD)
        elif "_TAU_LOW" in globals():
            tau_low = float(_TAU_LOW)
        else:
            tau_low = 0.15

        print(f"   SPAN_THRESHOLD: {span_thresh}")
        print(f"   TAU_LOW (Uncertainty): {tau_low}")
    except Exception:
        print("   Warning: Could not read global config variables. Using defaults.")
        span_thresh = 0.05
        tau_low = 0.15

    # 2. Component Health Check
    print("\n[2] COMPONENT HEALTH CHECK")
    core = model.module if hasattr(model, "module") else model
    # Switch to eval mode before any probing so the ASBN probe below is not
    # perturbed by train-mode layers and does not update their statistics.
    core.eval()

    # Check DSCD
    dscd = getattr(core, "dscd", None)
    if dscd:
        try:
            stores = {}
            lock = None
            if hasattr(dscd, "bufferlock"):
                lock = dscd.bufferlock
            elif hasattr(dscd, "clusteringlock"):
                lock = dscd.clusteringlock

            # Copy the store mapping (under the lock when one exists) so the
            # iteration below works on a stable snapshot.
            if lock:
                with lock:
                    stores = dict(getattr(dscd, "prototypestores", {}))
            else:
                stores = dict(getattr(dscd, "prototypestores", {}))

            # A word counts as a homograph when its store reports >= 2 entries.
            homographs = []
            for k, v in stores.items():
                try:
                    sz = v.size() if hasattr(v, "size") and callable(getattr(v, "size", None)) else 0
                except Exception:
                    sz = 0
                if sz >= 2:
                    homographs.append(k)

            print(f"   DSCD: Alive. Found {len(homographs)} homographs.")
            if len(homographs) > 0:
                print(f"   Example homographs: {homographs[:5]}")
            else:
                print("   CRITICAL WARNING: DSCD found 0 homographs. TRG will never trigger.")
        except Exception as e:
            print(f"   DSCD Error: {e}")
    else:
        print("   CRITICAL ERROR: DSCD module missing.")

    # Check ASBN
    asbn = getattr(core, "asbn", None)
    if asbn:
        print("   ASBN: Module exists.")
        try:
            try:
                device = next(asbn.parameters()).device
            except Exception:
                device = next(core.parameters()).device

            # Use the encoder width from the mBART config when available.
            embed_dim = getattr(getattr(core, "mbart", None), "config", None)
            if embed_dim is not None and hasattr(embed_dim, "d_model"):
                embed_dim = int(embed_dim.d_model)
            else:
                embed_dim = 1024

            dummy_h = torch.randn(2, 10, embed_dim, device=device)

            if hasattr(asbn, "d_domain"):
                # Pure probe: no autograd graph is needed.
                with torch.no_grad():
                    d_dom_out = asbn.d_domain(dummy_h)
                mean_val = float(d_dom_out.mean().item())
                std_val = float(d_dom_out.std().item())
                print(f"   ASBN Discriminator Output Mean: {mean_val:.4f} (Should not be exactly 0 or +/- inf)")
                if std_val < 1e-6:
                    print("   CRITICAL WARNING: ASBN Discriminator collapsed (zero variance).")
        except Exception as e:
            print(f"   ASBN Probe Failed: {e}")
    else:
        print("   ASBN: Module missing.")

    # 3. Sentence-Level Trace
    print("\n[3] SENTENCE TRACE ANALYSIS")

    # --------------------------------------------------------------------------
    # Alignment logic for manual trace
    # --------------------------------------------------------------------------
    def _build_alignment(text, tok):
        """Map token index -> whitespace-delimited word containing that token.

        The text is encoded with the tokenizer's default special tokens, i.e.
        the same way the model input is encoded below, so that indices in the
        returned map line up with positions in ``input_ids``. Zero-width
        offsets (special tokens) are left unmapped.
        """
        words = text.split()
        try:
            enc = tok(text, return_offsets_mapping=True)
            offsets = enc["offset_mapping"]
        except Exception:
            return {}

        # Character span of every word, found left to right.
        current_pos = 0
        w_spans = []
        for w in words:
            start = text.find(w, current_pos)
            if start != -1:
                w_spans.append((start, start + len(w), w))
                current_pos = start + len(w)

        m = {}
        for idx, (s, e) in enumerate(offsets):
            if e <= s:
                # Special tokens carry an empty span; without this guard a
                # (0, 0) span would be attributed to the first word.
                continue
            for ws, we, wtext in w_spans:
                if s >= ws and e <= we:
                    m[idx] = wtext
                    break
        return m

    def _first_sentence(val):
        """Drop a leading batch axis from a tensor that still has one.

        Values that were already unwrapped from a per-sentence list are 1-D
        and must not be indexed a second time.
        """
        if isinstance(val, torch.Tensor) and val.dim() >= 2:
            return val[0]
        return val

    for sent_txt, target_word in test_sentences:
        print(f"\nAnalyzing: '{sent_txt}' (Target: {target_word})")
        print("-" * 60)

        try:
            device = next(core.parameters()).device
        except Exception:
            device = torch.device("cpu")

        inputs = tokenizer(sent_txt, return_tensors="pt", padding=True).to(device)

        token_word_map = _build_alignment(sent_txt, tokenizer)

        with torch.no_grad():
            if hasattr(core, "mbart"):
                try:
                    encoder_out = core.mbart.model.encoder(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs.get("attention_mask", None),
                    )
                except Exception:
                    encoder_out = core.mbart.get_encoder()(
                        input_ids=inputs["input_ids"],
                        attention_mask=inputs.get("attention_mask", None),
                    )
            else:
                print("   Error: core.mbart not found")
                continue

            hidden_states = encoder_out.last_hidden_state

            # DSCD Trace: call with same signature as main pipeline
            dscd_out = None
            try:
                dscd_out = core.dscd.forward(
                    tokenembeddings=hidden_states,
                    tokentypes=None,
                    trainmode=False,
                    tokenwordmap=[token_word_map],
                    inputids=inputs["input_ids"],
                    attentionmask=inputs.get("attention_mask", None),
                )
            except TypeError:
                # Keyword name for the embeddings not accepted: retry positionally.
                try:
                    dscd_out = core.dscd(
                        hidden_states,
                        tokentypes=None,
                        trainmode=False,
                        tokenwordmap=[token_word_map],
                        inputids=inputs["input_ids"],
                        attentionmask=inputs.get("attention_mask", None),
                    )
                except Exception as e:
                    print(f"   DSCD call failed: {e}")
                    continue
            except Exception as e:
                print(f"   DSCD call failed: {e}")
                continue

            try:
                tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
            except Exception:
                tokens = [str(i) for i in range(int(inputs["input_ids"].size(1)))]

            def _get_metric(key_main, key_alt=None, default_val=0.0):
                """Fetch a metric from the DSCD output; lists yield their first item."""
                if not isinstance(dscd_out, dict):
                    return default_val
                if key_main in dscd_out:
                    val = dscd_out[key_main]
                elif key_alt is not None and key_alt in dscd_out:
                    val = dscd_out[key_alt]
                else:
                    return default_val
                if isinstance(val, (list, tuple)):
                    return val[0] if len(val) > 0 else default_val
                return val

            proto_assignments = _first_sentence(_get_metric("protoassignments", "proto_assignments", []))
            span_preds = _first_sentence(_get_metric("spanpreds", "span_preds", []))
            uncertainties = _first_sentence(_get_metric("uncertainties", None, []))
            gates = _first_sentence(_get_metric("gates", None, []))

            found_target = False
            target_clean = str(target_word).strip().lower()

            print(f"   {'Token':<15} {'Word':<15} {'Span':<10} {'Uncert':<10} {'Gate':<10} {'ProtoID':<8} {'Status'}")
            print(f"   {'-'*15} {'-'*15} {'-'*10} {'-'*10} {'-'*10} {'-'*8} {'-'*20}")

            for i, tok in enumerate(tokens):
                mapped_word = token_word_map.get(i, "")

                clean_tok = (
                    tok.replace(" ", "")
                    .replace("▁", "")
                    .replace("Ġ", "")
                    .replace("##", "")
                    .strip()
                    .lower()
                )
                clean_map = mapped_word.strip().lower()

                # Substring matching in either direction catches sub-word
                # pieces of the target. Empty strings are excluded because
                # "" is a substring of everything and would flag every
                # unmapped token as the target.
                tok_hit = bool(clean_tok) and (target_clean in clean_tok or clean_tok in target_clean)
                map_hit = bool(clean_map) and (target_clean in clean_map or clean_map in target_clean)
                is_target = bool(target_clean) and (tok_hit or map_hit)

                try:
                    s_val = float(span_preds[i]) if i < len(span_preds) else 0.0
                except Exception:
                    s_val = 0.0

                try:
                    u_val = float(uncertainties[i]) if i < len(uncertainties) else 0.0
                except Exception:
                    u_val = 0.0

                try:
                    g_val = float(gates[i]) if i < len(gates) else 0.0
                except Exception:
                    g_val = 0.0

                try:
                    p_val = int(proto_assignments[i]) if i < len(proto_assignments) else -1
                except Exception:
                    p_val = -1

                status = []
                if s_val < span_thresh:
                    status.append(f"Low Span (<{span_thresh})")
                if u_val < tau_low:
                    status.append(f"Low Uncert (<{tau_low})")
                if p_val == -1:
                    status.append("No Proto")

                # A mapped token without a word-start marker continues a word.
                is_subword = (
                    bool(mapped_word)
                    and not tok.startswith(" ")
                    and not tok.startswith("Ġ")
                    and not tok.startswith("▁")
                    and i > 0
                )
                if is_subword:
                    status.append("Subword Fragment")

                status_str = " | ".join(status) if status else "READY"

                should_explain = (s_val > span_thresh) or (u_val > tau_low)
                if should_explain:
                    status_str += " [TRIGGER]"

                if is_target or should_explain:
                    print(
                        f"   {tok:<15} {mapped_word:<15} {s_val:<10.4f} {u_val:<10.4f} "
                        f"{g_val:<10.4f} {p_val:<8} {status_str}"
                    )
                    # Only a real target match counts; rows printed solely
                    # because they trigger must not suppress the note below.
                    if is_target:
                        found_target = True

            if not found_target:
                print(f"   Note: Target word '{target_word}' not matched or completely filtered.")

    print("\n" + "=" * 80)
    print("ROOT CAUSE CONCLUSION")
    print("=" * 80)
    print("1. If 'Span' is 0.0000 -> DSCD clustering is too tight or Threshold too high.")
    print("2. If 'Uncert' is low -> Model is confident (ASBN might be too strong).")
    print("3. If 'No Proto' -> Word wasn't seen enough in training to form a cluster.")
    print("4. **Subword Fragment**: If you see this, Cell 5 aggregation fix is REQUIRED to combine these metrics.")

# ==============================================================================
# EXECUTE DEBUGGER
# ==============================================================================
# Guard first: the diagnostic needs both objects left in the kernel namespace
# by the training cell.
if "trained_model" not in globals() or "tokenizer" not in globals():
    print("Error: trained_model and tokenizer not found. Run Cell 11 first.")
else:
    # (sentence, homograph to trace) pairs
    debug_sentences = [
        ("আমি কল বন্ধ করেছি।", "কল"),
        ("কাল আমি বই কিনব।", "কাল"),
        ("পাতা ঝরে পড়েছে।", "পাতা"),
    ]
    deep_debug_tatn(trained_model, tokenizer, debug_sentences)


In [None]:
# ==============================================================================
# CELL 13: EXTENDED INFERENCE TESTING (FINAL)
# ==============================================================================
import os
import time
import traceback
import json
from typing import Tuple, Any, Dict, List, Optional
from collections import defaultdict
import torch
import gc

# Mirror the Cell 0 settings into private, type-normalised names. Any missing
# or malformed Cell 0 global switches the whole cell to the fallback values.
try:
    if isinstance(DEVICE, torch.device):
        _DEVICE = DEVICE
    elif isinstance(DEVICE, str):
        _DEVICE = torch.device(str(DEVICE))
    else:
        _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    _USE_MULTI_GPU = bool(USE_MULTI_GPU)
    _NUM_GPUS = int(NUM_GPUS)

    _VERBOSE_LOGGING = bool(VERBOSE_LOGGING)
    _DEBUG_DISCOVERY = bool(DEBUG_DISCOVERY)
    _DEBUG_TIMING = bool(DEBUG_TIMING)

    _SPAN_THRESHOLD = float(SPAN_THRESHOLD)
    _TAU_LOW = float(TAU_LOW)

    _HOMOGRAPH_REFERENCE_LIST_BN = {str(w) for w in HOMOGRAPH_REFERENCE_LIST_BN}
    cell0_loaded = True
except (NameError, TypeError, ValueError):
    # Fallback: every setting is re-assigned here, so partial results from the
    # attempt above never leak through.
    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
    _USE_MULTI_GPU = _NUM_GPUS > 1
    _VERBOSE_LOGGING = False
    _DEBUG_DISCOVERY = False
    _DEBUG_TIMING = False
    _SPAN_THRESHOLD = 0.20
    _TAU_LOW = 0.15
    _HOMOGRAPH_REFERENCE_LIST_BN = {
        "কল", "কাল", "পাতা",
        "ব্যাংক", "ফল", "মাথা",
        "বার", "হার", "তারা",
    }
    cell0_loaded = False
    print("[TEST] Using fallback config (Cell 0 not executed)")

# Checkpoint that the loader further down looks for.
_CHECKPOINT_PATH = "/kaggle/working/tatn_final.pt"


def _safe_print(msg: str):
    try:
        print(msg)
    except Exception:
        pass


def _maybe_traceback(exc: Exception):
    """Print the traceback of ``exc`` when verbose or discovery debugging is on.

    Nothing happens unless ``_VERBOSE_LOGGING`` or ``_DEBUG_DISCOVERY`` is
    truthy. The traceback is taken from ``exc`` itself, so the helper also
    reports the right error when it is called outside the ``except`` block
    that caught it. If ``exc`` is not an exception instance, the exception
    currently being handled (if any) is printed instead.

    Args:
        exc: The exception to report.

    Returns:
        None. Failures while printing are ignored.
    """
    if not (_VERBOSE_LOGGING or _DEBUG_DISCOVERY):
        return
    try:
        if isinstance(exc, BaseException):
            traceback.print_exception(type(exc), exc, exc.__traceback__)
        else:
            traceback.print_exc()
    except Exception:
        # Diagnostics are best-effort and must not mask the original error path.
        pass


def _compute_similarity(translation: str, expected: str) -> float:
    try:
        trans_words = set(translation.lower().split())
        exp_words = set(expected.lower().split())
        if not exp_words:
            return 0.0
        overlap = len(trans_words & exp_words)
        return overlap / len(exp_words)
    except Exception:
        return 0.0


def _get_dscd_homographs(model) -> set:
    try:
        core = model.module if hasattr(model, "module") else model
        dscd = getattr(core, "dscd", None)
        if dscd is None:
            return set()

        if hasattr(dscd, "getdiscoveredhomographs"):
            try:
                return set(dscd.getdiscoveredhomographs())
            except Exception:
                pass

        homographs = set()

        lock = None
        if hasattr(dscd, "bufferlock"):
            lock = dscd.bufferlock
        elif hasattr(dscd, "clusteringlock"):
            lock = dscd.clusteringlock

        if lock:
            with lock:
                stores = dict(getattr(dscd, "prototypestores", {}) or {})
        else:
            stores = dict(getattr(dscd, "prototypestores", {}) or {})

        for token, store in stores.items():
            try:
                size_ok = False
                size_attr = getattr(store, "size", None)
                try:
                    if callable(size_attr):
                        size_ok = size_attr() >= 2
                    elif isinstance(size_attr, int):
                        size_ok = size_attr >= 2
                except Exception:
                    size_ok = False

                if not size_ok:
                    counts = getattr(store, "counts", None)
                    if isinstance(counts, (list, tuple)) and len(counts) >= 2:
                        size_ok = True

                if size_ok:
                    clean = (
                        str(token)
                        .replace("▁", "")
                        .replace("Ġ", "")
                        .replace("##", "")
                        .strip()
                        .lower()
                    )
                    if clean:
                        homographs.add(clean)
            except Exception:
                continue

        return homographs
    except Exception:
        return set()


# Record which prerequisites earlier cells left in the kernel namespace.
# A name that is bound to None counts as missing for the model and tokenizer.
trained_model_available = globals().get("trained_model") is not None
tokenizer_available = globals().get("tokenizer") is not None
translate_available = "translate_with_explanations" in globals()

# Tell the reader what is missing and where it comes from.
if not trained_model_available:
    _safe_print("trained_model not found - will try checkpoint")
if not tokenizer_available:
    _safe_print("tokenizer not found - run pipeline first")
if not translate_available:
    _safe_print("translate_with_explanations not found - run Cell 8")


def try_load_checkpoint(checkpoint_path: str, tokenizer) -> Tuple[bool, Any]:
    """Rebuild a TATN model from a checkpoint file and move it to ``_DEVICE``.

    Steps: load the file on CPU, locate the model weights (under
    ``model_state_dict`` / ``state_dict`` / ``model``, or the checkpoint
    itself when it looks like a bare state dict), instantiate
    ``MemoryOptimizedTATNWithExplanations(tokenizer)``, load the weights
    non-strictly, restore ``dscd_state`` when the checkpoint has one, then
    switch the model to eval mode on ``_DEVICE``.

    Args:
        checkpoint_path: Path of the checkpoint file.
        tokenizer: Tokenizer handed to the model constructor and used to size
            the embedding matrix.

    Returns:
        ``(True, model)`` on success, otherwise ``(False, reason)`` where
        ``reason`` is a message string or the exception that was raised.
    """
    if not os.path.exists(checkpoint_path):
        return False, f"Not found: {checkpoint_path}"

    if "MemoryOptimizedTATNWithExplanations" not in globals():
        return False, "Model class not available"

    _safe_print(f"[TEST] Loading: {checkpoint_path}")

    try:
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location="cpu")
    except Exception as e:
        _safe_print(f"[TEST] Load failed: {type(e).__name__}")
        _maybe_traceback(e)
        return False, e

    # Locate the model weights inside the checkpoint.
    state = None
    if isinstance(ckpt, dict):
        for k in ("model_state_dict", "state_dict", "model"):
            if k in ckpt and isinstance(ckpt[k], dict):
                state = ckpt[k]
                break
        if state is None:
            # No wrapper key: accept the dict itself if it holds tensors.
            values_sample = list(ckpt.values())[:10]
            if any(torch.is_tensor(v) for v in values_sample):
                state = ckpt
    else:
        state = ckpt

    if state is None:
        return False, "No model state found"

    try:
        _safe_print(f"[TEST] Model state: {len(state)} keys")
    except Exception:
        _safe_print("[TEST] Model state: (unknown size)")

    # With strict=False, keys that all carry a "module." prefix would be
    # reported as unexpected and silently ignored, leaving the model at its
    # initial weights. Strip the prefix up front so those weights are used.
    if (
        isinstance(state, dict)
        and state
        and all(isinstance(k, str) and k.startswith("module.") for k in state)
    ):
        state = {k[len("module."):]: v for k, v in state.items()}
        _safe_print("[TEST] Stripped 'module.' prefix from checkpoint keys")

    dscd_state = None
    if isinstance(ckpt, dict) and "dscd_state" in ckpt:
        dscd_state = ckpt["dscd_state"]
        if isinstance(dscd_state, dict):
            # handle plain state_dict or wrapped structure
            proto_obj = dscd_state
            if "prototypestores" in dscd_state and isinstance(dscd_state["prototypestores"], dict):
                proto_obj = dscd_state["prototypestores"]
            num_tokens = len(proto_obj)
            _safe_print(f"[TEST] DSCD state: {num_tokens} tokens")
            if num_tokens == 0:
                _safe_print("[TEST] DSCD empty - warmup needed")
        else:
            _safe_print("[TEST] DSCD state invalid")
    else:
        _safe_print("[TEST] No DSCD state")

    try:
        model_inst = MemoryOptimizedTATNWithExplanations(tokenizer)
    except Exception as e:
        _safe_print(f"[TEST] Instantiation failed: {type(e).__name__}")
        _maybe_traceback(e)
        return False, e

    # Best effort: make the embedding matrix match the tokenizer size.
    # NOTE(review): ``vocab_size`` is preferred over ``len(tokenizer)`` here;
    # confirm this matches how the checkpoint was trained.
    try:
        mbart = getattr(model_inst, "mbart", None)
        if mbart and hasattr(mbart, "get_input_embeddings"):
            cur = mbart.get_input_embeddings().num_embeddings
            tok_len = getattr(
                tokenizer,
                "vocab_size",
                len(tokenizer) if hasattr(tokenizer, "__len__") else None,
            )
            if tok_len and cur != tok_len:
                try:
                    mbart.resize_token_embeddings(tok_len)
                    _safe_print(f"[TEST] Resized: {cur} -> {tok_len}")
                except Exception:
                    pass
    except Exception:
        pass

    try:
        res = model_inst.load_state_dict(state, strict=False)
        # load_state_dict returns a named tuple with ``missing_keys`` and
        # ``unexpected_keys``; a plain dict result is still accepted.
        if isinstance(res, dict):
            missing = res.get("missing_keys", []) or res.get("missing", [])
            unexpected = res.get("unexpected_keys", []) or res.get("unexpected", [])
        else:
            missing = getattr(res, "missing_keys", None) or []
            unexpected = getattr(res, "unexpected_keys", None) or []
        _safe_print(
            f"[TEST] State loaded (missing: {len(missing)}, unexpected: {len(unexpected)})"
        )
        if isinstance(state, dict) and state and len(unexpected) == len(state):
            _safe_print("[TEST] WARNING: no checkpoint key matched the model")
    except Exception:
        # Retry after removing the first "module." segment from every key.
        try:
            new_state = {k.replace("module.", "", 1): v for k, v in state.items()}
            model_inst.load_state_dict(new_state, strict=False)
            _safe_print("[TEST] Loaded (stripped prefixes)")
        except Exception as e2:
            _safe_print(f"[TEST] Load failed: {type(e2).__name__}")
            _maybe_traceback(e2)
            return False, e2

    # Restoring the DSCD state is optional; failures are logged only.
    if dscd_state:
        try:
            dscd = getattr(model_inst, "dscd", None)
            if dscd and hasattr(dscd, "load_state_dict"):
                dscd.load_state_dict(dscd_state)
                num_tokens = len(getattr(dscd, "prototypestores", {}) or {}) if hasattr(
                    dscd, "prototypestores"
                ) else 0
                _safe_print(f"[TEST] DSCD loaded: {num_tokens} tokens")
                if num_tokens == 0:
                    _safe_print("[TEST] DSCD has 0 tokens - warmup needed")
            else:
                _safe_print("[TEST] No DSCD load_state_dict")
        except Exception as e:
            _safe_print(f"[TEST] DSCD load failed: {type(e).__name__}")
            _maybe_traceback(e)

    try:
        model_inst.to(_DEVICE)
        model_inst.eval()
    except Exception as e:
        _safe_print(f"[TEST] Device move failed: {type(e).__name__}")
        return False, e

    _safe_print(f"[TEST] Ready on: {_DEVICE}")
    return True, model_inst


# When a checkpoint file exists and a tokenizer is available, rebuild the model
# from disk; on success it replaces whatever ``trained_model`` is in the kernel.
if os.path.exists(_CHECKPOINT_PATH) and tokenizer_available:
    succ, model_or_err = try_load_checkpoint(_CHECKPOINT_PATH, globals().get("tokenizer"))
    if not succ:
        _safe_print("[TEST] Checkpoint load failed")
    else:
        globals()["trained_model"] = model_or_err
        trained_model_available = True
        _safe_print("[TEST] Checkpoint loaded")


def maybe_run_warmup_if_needed(model, tokenizer, warmup_sents: int = 4000) -> bool:
    """Make sure the model's DSCD holds prototypes, running a warmup if it is empty.

    When ``prototypestores`` already has entries, their number is logged and
    ``True`` is returned without doing any work. Otherwise the notebook-level
    ``dscd_discovery_warmup`` function is called with ``warmup_sents``
    sentences, batch size 64 and ``MAX_LENGTH`` (48 when undefined).

    Returns:
        ``True`` if prototypes exist on return, ``False`` when there is no
        DSCD, no warmup function, the warmup produced nothing, or an error
        occurred.
    """
    try:
        core = getattr(model, "module", model)
        dscd = getattr(core, "dscd", None)

        if dscd is None:
            _safe_print("[TEST] No DSCD - skip warmup")
            return False

        # First lock attribute that exists wins, whatever its value.
        lock = None
        for lock_attr in ("bufferlock", "clusteringlock"):
            if hasattr(dscd, lock_attr):
                lock = getattr(dscd, lock_attr)
                break

        def _snapshot_stores():
            # Copy of the store mapping, taken under the lock when there is one.
            if lock:
                with lock:
                    return dict(getattr(dscd, "prototypestores", {}) or {})
            return dict(getattr(dscd, "prototypestores", {}) or {})

        def _count_multi_sense(store_map):
            # Number of stores whose callable ``size`` reports >= 2.
            n_multi = 0
            for store in store_map.values():
                if (
                    hasattr(store, "size")
                    and callable(getattr(store, "size", None))
                    and store.size() >= 2
                ):
                    n_multi += 1
            return n_multi

        stores = _snapshot_stores()
        initial_count = len(stores)

        if initial_count > 0:
            multi_sense = _count_multi_sense(stores)
            _safe_print(f"[TEST] DSCD has {initial_count} tokens ({multi_sense} multi-sense)")
            return True

        _safe_print("[TEST] DSCD empty - running warmup...")

        warmup_fn = globals().get("dscd_discovery_warmup", None)
        if not callable(warmup_fn):
            _safe_print("[TEST] Warmup function not available")
            return False

        try:
            warmup_start = time.time()
            warmup_fn(
                model,
                tokenizer,
                num_sents=warmup_sents,
                batch_size=64,
                max_len=globals().get("MAX_LENGTH", 48),
            )
            warmup_time = time.time() - warmup_start

            stores_after = _snapshot_stores()
            final_count = len(stores_after)
            multi_sense = _count_multi_sense(stores_after)

            if final_count <= 0:
                _safe_print("[TEST] Warmup complete but NO prototypes")
                return False

            ratio = multi_sense / final_count
            _safe_print(f"[TEST] Warmup success ({warmup_time:.1f}s)")
            _safe_print(
                f"[TEST]    Tokens: {final_count}, "
                f"Multi-sense: {multi_sense} ({ratio:.1%})"
            )
            if ratio < 0.1:
                _safe_print("[TEST] Low multi-sense ratio (<10%)")
            return True

        except Exception as e:
            _safe_print(f"[TEST] Warmup failed: {type(e).__name__}")
            _maybe_traceback(e)
            return False

    except Exception as e:
        _safe_print(f"[TEST] Warmup check failed: {type(e).__name__}")
        return False


# Inference test set. Each entry is
#   (Bengali source sentence, reference English translation, note).
# The note either names the homograph under test with its candidate senses
# ("word = sense/sense") or labels the sentence type.
test_sentences: List[Tuple[str, str, str]] = [
    ("আমি কল বন্ধ করেছি।", "I turned off the tap", "কল = tap/call"),
    ("কাল আমি বই কিনব।", "Tomorrow I will buy a book", "কাল = tomorrow/yesterday"),
    ("পাতা ঝরে পড়েছে।", "The leaf has fallen", "পাতা = leaf/page"),
    ("তিনি ব্যাংক গেছেন।", "He went to the bank", "ব্যাংক = bank/embankment"),
    ("আমি ভালো আছি।", "I am fine", "Simple"),
    ("সে খুব মিষ্টি কথা বলে।", "She speaks sweetly", "Adjective"),
    ("এটা আমার বই।", "This is my book", "Demonstrative"),
    ("তুমি কি আমাকে সাহায্য করতে পারো? ", "Can you help me?", "Question"),
    ("আজ আবহাওয়া ভালো।", "Weather is good", "Simple"),
    ("আমরা বাংলাদেশে বাস করি।", "We live in Bangladesh", "Country"),
    ("সূর্য পূর্ব দিকে ওঠে।", "Sun rises in east", "Directional"),
    ("পাখি আকাশে উড়ে।", "Birds fly in sky", "Simple present"),
    ("সে স্কুলে যাচ্ছে।", "She is going to school", "Continuous"),
]

# Summary averages, pre-set to 0.0 so the names exist even when no test runs.
# NOTE(review): presumably read by the summary code after the test loop - confirm.
avg_conf = 0.0
avg_span = 0.0
avg_u = 0.0
avg_time = 0.0

if not (trained_model_available and tokenizer_available and translate_available):
    _safe_print("\nCannot run tests - missing prerequisites")
    _safe_print("   Run Cells 0-11 or load checkpoint")
else:
    warmup_success = False
    try:
        warmup_success = maybe_run_warmup_if_needed(
            globals().get("trained_model"),
            globals().get("tokenizer"),
            warmup_sents=4000,
        )
    except Exception as e:
        _safe_print(f"[TEST] Warmup failed: {type(e).__name__}")
        _maybe_traceback(e)

    dscd_homographs = _get_dscd_homographs(globals().get("trained_model"))
    _safe_print(f"\n[TEST] DSCD discovered: {len(dscd_homographs)} homographs")
    if dscd_homographs and _DEBUG_DISCOVERY:
        _safe_print(f"[TEST] Sample: {list(dscd_homographs)[:10]}")

    _safe_print(f"\n[COMPONENT HEALTH]")
    try:
        core = globals().get("trained_model")
        core = core.module if hasattr(core, "module") else core

        dscd = getattr(core, "dscd", None)
        if dscd and hasattr(dscd, "get_prototype_summary"):
            try:
                dscd_stats = dscd.get_prototype_summary()
                _safe_print(
                    f"  DSCD: {dscd_stats.get('total_tokens', 0)} tokens, "
                    f"{dscd_stats.get('num_homographs', 0)} homographs"
                )
            except Exception:
                pass

        asbn = getattr(core, "asbn", None)
        if asbn and hasattr(asbn, "get_detailed_stats"):
            try:
                asbn_stats = asbn.get_detailed_stats()
                _safe_print(
                    f"  ASBN: {asbn_stats.get('domain_accuracy', 0):.2%} domain accuracy"
                )
            except Exception:
                pass

        trg = getattr(core, "trg_system", None)
        if trg and hasattr(trg, "get_statistics"):
            try:
                trg_stats = trg.get_statistics()
                _safe_print(
                    f"  TRG: {trg_stats.get('explanations_generated', 0)} total explanations"
                )
            except Exception:
                pass
    except Exception:
        pass

    total = len(test_sentences)
    successes = 0
    tests_with_explanations = 0
    total_ambiguous = 0

    quality_metrics = {
        "confidences": [],
        "spans": [],
        "uncertainties": [],
        "similarities": [],
    }

    dscd_homographs_explained = set()
    reference_homographs_explained = set()
    homograph_explanations = defaultdict(list)

    inference_times = []

    _safe_print("\n" + "=" * 80)
    _safe_print("EXTENDED INFERENCE TESTING")
    _safe_print("=" * 80)
    _safe_print("Configuration:")
    _safe_print(f"  Cell 0: {'Loaded' if cell0_loaded else 'Fallback'}")
    _safe_print(f"  Span threshold: {_SPAN_THRESHOLD}")
    _safe_print(f"  Uncertainty threshold: {_TAU_LOW}")
    _safe_print(f"  Tests: {total}")
    _safe_print("=" * 80)

    if not warmup_success:
        _safe_print("\nWARNING: Warmup failed")
        _safe_print("   Homograph detection may not work\n")

    for idx, (sent, expected, note) in enumerate(test_sentences, 1):
        _safe_print("\n" + "-" * 70)
        _safe_print(f"Test {idx}/{total}: {note}")
        _safe_print(f"Input: {sent}")

        try:
            model_for_infer = globals().get("trained_model")
            tokenizer = globals().get("tokenizer")

            if model_for_infer is None or tokenizer is None:
                raise RuntimeError("Model/tokenizer missing")

            inf_start = time.time()
            res = translate_with_explanations(model_for_infer, tokenizer, sent)
            inf_time = time.time() - inf_start
            inference_times.append(inf_time)

            if res is None or not isinstance(res, dict):
                _safe_print("[TEST] Invalid result - skip")
                continue

            translation = str(res.get("translation", ""))
            amb_count = int(res.get("ambiguous_words_detected", 0))
            explanations = res.get("explanations", []) or []

            _safe_print(f"Translation: {translation}")
            _safe_print(f"Time: {inf_time:.3f}s")

            similarity = _compute_similarity(translation, expected)
            quality_metrics["similarities"].append(similarity)
            _safe_print(f"Similarity: {similarity:.1%}")

            _safe_print(f"Ambiguous: {amb_count}")

            if amb_count > 0:
                tests_with_explanations += 1
                total_ambiguous += amb_count
                _safe_print("Explanations:")

                for j, e in enumerate(explanations, 1):
                    try:
                        word = e.get("ambiguous_word", e.get("token", "N/A"))
                        conf = float(e.get("confidence", 0.5))
                        u = float(e.get("uncertainty", 0.0))
                        s = float(e.get("span", 0.0))

                        quality_metrics["confidences"].append(conf)
                        quality_metrics["spans"].append(s)
                        quality_metrics["uncertainties"].append(u)

                        clean = (
                            str(word)
                            .replace("▁", "")
                            .replace("Ġ", "")
                            .strip()
                            .lower()
                        )

                        if clean in dscd_homographs:
                            dscd_homographs_explained.add(clean)
                            homograph_explanations[clean].append(
                                {
                                    "sentence": sent,
                                    "confidence": conf,
                                    "span": s,
                                    "uncertainty": u,
                                }
                            )

                        if clean in _HOMOGRAPH_REFERENCE_LIST_BN:
                            reference_homographs_explained.add(clean)

                        marker = "[HIGH]" if s > _SPAN_THRESHOLD else "      "
                        _safe_print(
                            f"  {j}. {marker} '{word}' conf={conf:.3f} u={u:.3f} s={s:.3f}"
                        )

                    except Exception:
                        if _DEBUG_DISCOVERY:
                            try:
                                traceback.print_exc()
                            except Exception:
                                pass
            else:
                _safe_print("No ambiguity")

            if translation and translation.strip():
                successes += 1
                _safe_print("Success")
            else:
                _safe_print("Failed")

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        except Exception as e:
            _safe_print(f"Test {idx} failed: {type(e).__name__}")
            if _DEBUG_DISCOVERY:
                _maybe_traceback(e)

    _safe_print("\n" + "=" * 80)
    _safe_print("TEST SUMMARY")
    _safe_print("=" * 80)

    _safe_print("\n[TRANSLATION]")
    _safe_print(f"  Total: {total}")
    if total > 0:
        _safe_print(f"  Success: {successes} ({successes/total*100:.1f}%)")
        _safe_print(f"  Failed: {total - successes} ({(total - successes)/total*100:.1f}%)")

        if quality_metrics["similarities"]:
            avg_sim = sum(quality_metrics["similarities"]) / len(
                quality_metrics["similarities"]
            )
            _safe_print(f"  Avg similarity: {avg_sim:.1%}")

    if inference_times:
        avg_time = sum(inference_times) / len(inference_times)
        _safe_print("\n[PERFORMANCE]")
        _safe_print(f"  Avg time: {avg_time:.3f}s per sentence")
        _safe_print(f"  Throughput: {1 / avg_time:.1f} sentences/sec")

    _safe_print("\n[AMBIGUITY]")
    _safe_print(
        f"  Tests with explanations: {tests_with_explanations}/{total} "
        f"({tests_with_explanations/total*100:.1f}%)"
    )
    _safe_print(f"  Total ambiguous: {total_ambiguous}")
    if total > 0:
        _safe_print(f"  Avg per sentence: {total_ambiguous/total:.2f}")

    if quality_metrics["confidences"]:
        avg_conf = sum(quality_metrics["confidences"]) / len(
            quality_metrics["confidences"]
        )
        avg_span = (
            sum(quality_metrics["spans"]) / len(quality_metrics["spans"])
            if quality_metrics["spans"]
            else 0.0
        )
        avg_u = (
            sum(quality_metrics["uncertainties"])
            / len(quality_metrics["uncertainties"])
            if quality_metrics["uncertainties"]
            else 0.0
        )

        high_conf = sum(1 for c in quality_metrics["confidences"] if c >= 0.65)

        _safe_print("\n[QUALITY]")
        _safe_print(f"  Avg confidence: {avg_conf:.3f}")
        _safe_print(f"  Avg span: {avg_span:.3f}")
        _safe_print(f"  Avg uncertainty: {avg_u:.3f}")
        _safe_print(
            f"  High confidence: {high_conf}/{len(quality_metrics['confidences'])} "
            f"({high_conf/len(quality_metrics['confidences']):.1%})"
        )
    else:
        _safe_print("\n[QUALITY]")
        _safe_print("  NO EXPLANATIONS")
        _safe_print("     Possible causes:")
        _safe_print("     1.  DSCD empty (warmup failed)")
        _safe_print("     2. TRG thresholds too strict")

    _safe_print("\n[HOMOGRAPHS (DATA-DRIVEN)]")
    _safe_print(f"  DSCD discovered: {len(dscd_homographs)}")
    _safe_print(f"  Explained: {len(dscd_homographs_explained)}")
    if dscd_homographs:
        try:
            _safe_print(
                f"  Rate: "
                f"{len(dscd_homographs_explained)/len(dscd_homographs):.1%}"
            )
        except ZeroDivisionError:
            _safe_print("  Rate: 0.0%")

    if dscd_homographs_explained:
        _safe_print("\n  Explained:")
        for homo in sorted(dscd_homographs_explained):
            exps = homograph_explanations[homo]
            avg_conf_local = (
                sum(e["confidence"] for e in exps) / len(exps) if exps else 0.0
            )
            in_ref = "[R]" if homo in _HOMOGRAPH_REFERENCE_LIST_BN else "   "
            _safe_print(f"    {in_ref} '{homo}': {len(exps)}x conf={avg_conf_local:.3f}")

    _safe_print("\n[REFERENCE COMPARISON]")
    _safe_print(f"  Size: {len(_HOMOGRAPH_REFERENCE_LIST_BN)}")
    _safe_print(f"  Explained: {len(reference_homographs_explained)}")
    try:
        coverage = (
            len(reference_homographs_explained) / len(_HOMOGRAPH_REFERENCE_LIST_BN)
            if len(_HOMOGRAPH_REFERENCE_LIST_BN) > 0
            else 0.0
        )
        _safe_print(f"  Coverage: {coverage:.1%}")
    except Exception:
        _safe_print("  Coverage: N/A")

    _safe_print("\n[HEALTH]")
    warnings = []

    if successes < total * 0.7:
        warnings.append("Low success (<70%)")
    if tests_with_explanations == 0:
        warnings.append("NO explanations")
    if quality_metrics["confidences"] and avg_conf < 0.5:
        warnings.append("Low confidence (<0.5)")
    if dscd_homographs and len(dscd_homographs_explained) < len(dscd_homographs) * 0.3:
        warnings.append("Low explanation rate (<30%)")

    if warnings:
        for w in warnings:
            _safe_print(f"  - {w}")
    else:
        _safe_print("  All systems OK")

    try:
        results = {
            "total_tests": total,
            "successes": successes,
            "tests_with_explanations": tests_with_explanations,
            "quality_metrics": {
                "avg_confidence": avg_conf if quality_metrics["confidences"] else 0,
                "avg_span": avg_span if quality_metrics["spans"] else 0,
                "avg_uncertainty": avg_u
                if quality_metrics["uncertainties"]
                else 0,
            },
            "dscd_discovered": len(dscd_homographs),
            "dscd_explained": len(dscd_homographs_explained),
            "reference_explained": len(reference_homographs_explained),
            "avg_inference_time": (sum(inference_times) / len(inference_times))
            if inference_times
            else 0,
        }

        results_path = "/kaggle/working/test_results.json"
        with open(results_path, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        _safe_print(f"\nResults saved: {results_path}")
    except Exception:
        pass

    _safe_print("\n" + "=" * 80)
    _safe_print(f"Thresholds: span>{_SPAN_THRESHOLD}, uncertainty>{_TAU_LOW}")
    _safe_print("Testing complete (DATA-DRIVEN)")
    _safe_print("=" * 80)

# Cell-load banner: reports that Cell 13 finished loading and lists the fixes
# noted for this revision. Each positional argument is one output line; with
# sep="\n" plus print's default trailing newline, the emitted text is the same
# as printing every line with its own print() call.
print(
    "\n" + "=" * 80,
    "Cell 13: Extended testing ready (FINAL)",
    "=" * 80,
    "FIXES APPLIED:",
    " F1:  Correct checkpoint path (removed space)",
    " F2:  DEBUG flag integration",
    " F3:  Correct DSCD dict key (dscd_state)",
    " F4:  Performance metrics (time/throughput)",
    " F5:  Memory cleanup between tests",
    " F6:  Warmup validation (multi-sense ratio)",
    " F7:  Component health (DSCD/ASBN/TRG)",
    " F8:  Removed stray spaces in attribute access (e.g., . replace -> .replace)",
    " F9:  Safe fallbacks before calling .size()/.prototype_stores",
    " F10: Results export (JSON)",
    "=" * 80 + "\n",
    sep="\n",
)
