In [1]:
!pip uninstall -y transformers sentence-transformers
!pip install transformers==4.30.2 --no-deps --force-reinstall
!pip install sentencepiece tokenizers sacremoses
!pip install scipy scikit-learn
!pip install --upgrade "protobuf==3.20.3"
# optional:
!pip install sentence-transformers==2.2.2
!pip install sacrebleu

Found existing installation: transformers 4.53.3
Uninstalling transformers-4.53.3:
  Successfully uninstalled transformers-4.53.3
Found existing installation: sentence-transformers 4.1.0
Uninstalling sentence-transformers-4.1.0:
  Successfully uninstalled sentence-transformers-4.1.0
Collecting transformers==4.30.2
  Downloading transformers-4.30.2-py3-none-any.whl.metadata (113 kB)
[2K     [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m113.6/113.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hDownloading transformers-4.30.2-py3-none-any.whl (7.2 MB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m7.2/7.2 MB[0m [31m71.0 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[?25hInstalling collected packages: transformers
Successfully installed transformers-4.30.2
Collecting sacremoses
  Downloading

In [2]:
import transformers
print(transformers.__version__)


4.30.2


In [3]:
# ==============================================================================
# CELL 0 (FIXED): ‚ö° OPTIMIZED ULTRA-FAST TATN CONFIGURATION (DEBUGGED + FIXED)
# ==============================================================================
# What I changed (summary):
# - Fixed _get_csv_row_count: correct chunk-based counting and robust fallback to line counting.
# - Hardened safe_tokenize_with_offsets to handle BatchEncoding shapes (tensor / list) reliably.
# - Safer reading of CSV line-count sample check (text mode with errors='ignore').
# - Minor defensive guards for pandas/transformers presence in places where they are used.
# - Small clarifying comments and consistent returns for helper functions.
#
# These changes remove the primary logic bugs and make the cell robust in common environments.
# ==============================================================================

import os
import sys
import math
import random
import re
import unicodedata
import time
import threading
from collections import deque, defaultdict
from typing import List, Dict, Tuple, Optional, Union
import numpy as np
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import warnings
import gc

# Add pandas for CSV reading (optional but recommended)
try:
    import pandas as pd
    _HAS_PANDAS = True
except Exception:
    pd = None
    _HAS_PANDAS = False
    print("[WARN] pandas not available; CSV loading/validation will be skipped")

# Prefer fast tokenizer class if available, but be resilient if not
try:
    # Try to import fast variant first (no model download here)
    from transformers import M2M100TokenizerFast as M2M100Tokenizer
except Exception:
    try:
        from transformers import M2M100Tokenizer  # type: ignore
    except Exception:
        M2M100Tokenizer = None

# datasets import is used in other data cells; keep import but avoid heavy ops here
try:
    from datasets import load_dataset
    _HAS_DATASETS = True
except Exception:
    load_dataset = None
    _HAS_DATASETS = False

# Reduce noisy warnings; keep tokenizer workers single-threaded for stability
warnings.filterwarnings("ignore")
os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
os.environ.setdefault("TF_CPP_MIN_LOG_LEVEL", "3")

# ==============================================================================
# HARDWARE / DEVICE DETECTION
# ==============================================================================
NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
USE_MULTI_GPU = NUM_GPUS > 1
CUDA_AVAILABLE = torch.cuda.is_available()

# For general code simplicity prefer "cuda" device (lets torch pick device:0)
if CUDA_AVAILABLE:
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

if USE_MULTI_GPU and CUDA_AVAILABLE:
    print(f"[Cell 0] Multi-GPU Mode: {NUM_GPUS} GPUs available (using device={DEVICE})")
else:
    mode = "Single GPU Mode" if CUDA_AVAILABLE else "CPU Mode"
    print(f"[Cell 0] {mode} (using device={DEVICE})")

print(f"[Cell 0] Device: {DEVICE} (visible GPUs: {NUM_GPUS})")

# ==============================================================================
# DATASET CONFIGURATION (LOCAL CSV FILE) - update path to your dataset
# ==============================================================================
DATASET_CSV_PATH = "/kaggle/input/homo-bn-dataset/bn_homograph_complete_dataset.csv"  # ‚Üê CHANGE THIS

# Validate dataset path exists (early warning). If not present we keep a small fallback.
if not os.path.exists(DATASET_CSV_PATH):
    print(f"[WARN] Dataset CSV not found at: {DATASET_CSV_PATH}")
    print("[WARN] Training will use a small fallback dataset (to avoid immediate crash).")
    _CSV_AVAILABLE = False
else:
    print(f"[INFO] Dataset CSV found: {DATASET_CSV_PATH}")
    _CSV_AVAILABLE = True

# If pandas is available and CSV exists, try a lightweight validation
def _get_csv_row_count(path: str) -> Optional[int]:
    """
    Return the number of rows in a CSV using pandas chunks (memory-efficient).
    Fallbacks to a safe text-mode line counting if chunked read fails.
    If pandas not available or file not present, returns None.
    """
    if not _HAS_PANDAS or not os.path.exists(path):
        return None
    try:
        # Use chunksize iteration and sum actual chunk lengths
        count = 0
        for chunk in pd.read_csv(path, chunksize=100000, usecols=[0], dtype=str):
            count += len(chunk)
        return int(count)
    except Exception:
        # Fallback: try a robust text-mode line count (handles large files)
        try:
            cnt = 0
            with open(path, "r", encoding="utf-8", errors="ignore") as f:
                for _ in f:
                    cnt += 1
            return int(cnt)
        except Exception:
            return None

if _CSV_AVAILABLE and _HAS_PANDAS:
    try:
        _test_df = pd.read_csv(DATASET_CSV_PATH, nrows=1)
        if 'src' not in _test_df.columns or 'tgt' not in _test_df.columns:
            print(f"[ERROR] CSV missing required columns 'src' and/or 'tgt'. Found: {list(_test_df.columns)}")
        else:
            print(f"[INFO] CSV validation passed (columns: {list(_test_df.columns)})")
        del _test_df
    except Exception as e:
        print(f"[WARN] Could not validate CSV structure: {e}")

# ==============================================================================
# ULTRA-FAST CONFIGURATION (user-tunable)
# ==============================================================================

BATCH_SIZE = 100              # batch size per step
NUM_SAMPLES = 50000           # Maximum samples to load from CSV (cap)
MAX_LENGTH = 48               # Maximum sequence length for tokenization
LR_NMT = 2e-5                 # Learning rate for main NMT model
LR_TRG = 1e-5                 # Learning rate for TRG component
LR_PHI = 1e-5                 # Learning rate for sense disambiguation
EPOCHS = 2                    # Number of training epochs
GRAD_CLIP_NORM = 1.0          # Gradient clipping threshold
USE_AMP = True                # Automatic Mixed Precision (saves memory)
PRINT_INTERVAL = 500          # Print training stats every N steps
SEED = 42                     # Random seed for reproducibility

# ==============================================================================
# MEMORY / PERFORMANCE SETTINGS
# ==============================================================================

# NOTE: the default of 16 accumulation steps is intentional; keep as constant
ACCUMULATION_STEPS = max(1, 16)       # must be >= 1
MC_DROPOUT_PASSES = 0                 # Monte Carlo dropout passes (0 = disabled)
TRG_EVIDENCE_K = 3                    # Top-K evidence for TRG
MAX_SILVER_BUFFER = 50                # Maximum silver label buffer size

NUM_WORKERS = max(0, 2)               # DataLoader workers (0 safe fallback)
# Pin memory only if CUDA is available
PIN_MEMORY = bool(CUDA_AVAILABLE)
PREFETCH_FACTOR = 2                   # Number of batches to prefetch per worker

# ==============================================================================
# DSCD PARAMETERS (balanced defaults; change if you know resource limits)
# ==============================================================================

DSCD_BUFFER_SIZE = 20
DSCD_MAX_PROTOS = 8
DSCD_N_MIN = 5
DSCD_DISPERSION_THRESHOLD = 0.25
DSCD_EMBED_DIM = 1024
DSCD_TEMPERATURE = 0.7
DSCD_DROPOUT = 0.1
DSCD_AUGMENT_SCALE = 0.1
DSCD_ENABLE_TRAINING_CLUSTERING = True
DSCD_WARMUP_SAMPLES = 8000

# ==============================================================================
# CONTROL FLAGS
# ==============================================================================

ENABLE_ASBN_TRAINING = True
ENABLE_ASBN_INFERENCE = True
ENABLE_TRG_TRAINING = False
ENABLE_TRG_INFERENCE = True

CLUSTERING_TIMEOUT = 5
MEMORY_CLEANUP_FREQUENCY = 100
PERIODIC_DISCOVERY_FREQUENCY = 100

VALIDATION_CHECK_INTERVAL = 200  # 0 = disabled

VERBOSE_LOGGING = False

# ==============================================================================
# CHECKPOINT SETTINGS
# ==============================================================================
CHECKPOINT_DIR = "./checkpoints"
os.makedirs(CHECKPOINT_DIR, exist_ok=True)
CHECKPOINT_INTERVAL = 20000
SAVE_REPLAY_BUFFER = False
LOAD_REPLAY_BUFFER = False
REPLAY_BUFFER_SIZE = 25000
RESUME_FROM_CHECKPOINT = False
CHECKPOINT_PATH = ""

# ==============================================================================
# TRG / UNCERTAINTY HYPERPARAMETERS
# ==============================================================================
TAU_HIGH = 0.85
TAU_LOW = 0.4
TAU_ACCEPT = 0.8
TRG_MAX_GEN_LEN = 16
TRG_GEN_EMBED = 64
TRG_GEN_HID = 64
SPAN_THRESHOLD = 0.3

# ==============================================================================
# ASBN PARAMETERS
# ==============================================================================
ASBN_HIDDEN_DIM = 64
ASBN_LAMBDA = 0.1
ASBN_DROPOUT = 0.1

LAMBDA_ASBN = 0.10
LAMBDA_DSCD = 0.05

# ==============================================================================
# LANGUAGE / WATCHLIST
# ==============================================================================
BN_LANG = "bn"
EN_LANG = "en"
SOURCE_LANGUAGE = "bn"

HOMOGRAPH_WATCHLIST_BN = {"‡¶ï‡¶≤", "‡¶ï‡¶æ‡¶≤", "‡¶™‡¶æ‡¶§‡¶æ", "‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï", "‡¶´‡¶≤", "‡¶Æ‡¶æ‡¶•‡¶æ"}
WATCHLIST_ONLY_FOR_TRG = False

# ==============================================================================
# MEMORY OPTIMIZATION FLAGS
# ==============================================================================
GRADIENT_CHECKPOINTING = True

# ==============================================================================
# UTILITY FUNCTIONS
# ==============================================================================

def normalize_bengali(t: str) -> str:
    """Normalize Bengali text using NFKC Unicode normalization."""
    if not t:
        return ""
    return unicodedata.normalize("NFKC", t).strip()

def normalize_english(t: str) -> str:
    """Normalize English text: NFKC + lowercase + strip."""
    if not t:
        return ""
    return unicodedata.normalize("NFKC", t).lower().strip()

def empty_cuda_cache():
    """Safely empty CUDA cache and run garbage collection."""
    gc.collect()
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
        except Exception:
            pass

def safe_cuda_synchronize():
    """Safely synchronize CUDA operations."""
    if torch.cuda.is_available():
        try:
            torch.cuda.synchronize()
        except Exception:
            pass

def monitor_gpu_usage():
    """Print GPU memory usage for all visible GPUs."""
    if torch.cuda.is_available():
        visible_gpus = torch.cuda.device_count()
        for i in range(visible_gpus):
            try:
                mem_alloc = torch.cuda.memory_allocated(i) / (1024**3)
                mem_reserved = torch.cuda.memory_reserved(i) / (1024**3)
                print(f"[GPU] {i}: {mem_alloc:.2f}GB allocated / {mem_reserved:.2f}GB reserved")
            except Exception:
                print(f"[GPU] {i}: memory stats unavailable")
    else:
        print("[GPU] CUDA not available")

# ==============================================================================
# TIMEOUT DECORATOR
# ==============================================================================

class FunctionTimeoutError(Exception):
    """Custom exception for function timeout."""
    pass

def with_timeout(seconds):
    """
    Decorator to enforce timeout on functions.
    Returns None if function exceeds timeout.
    """
    def decorator(func):
        def wrapper(*args, **kwargs):
            result = [FunctionTimeoutError("Function timed out")]
            def target():
                try:
                    result[0] = func(*args, **kwargs)
                except Exception as e:
                    result[0] = e
            thread = threading.Thread(target=target, daemon=True)
            thread.start()
            thread.join(timeout=seconds)
            if thread.is_alive():
                return None  # Timeout occurred
            if isinstance(result[0], Exception):
                if isinstance(result[0], FunctionTimeoutError):
                    return None
                # Re-raise original exception for obvious failures
                raise result[0]
            return result[0]
        return wrapper
    return decorator

# ==============================================================================
# SPECIAL TOKENS & VALIDATION HELPERS
# ==============================================================================

def get_special_tokens(tokenizer) -> set:
    """Extract special tokens from tokenizer safely."""
    try:
        s = getattr(tokenizer, "all_special_tokens", None)
        if s:
            return set(s)
    except Exception:
        pass
    # Conservative fallback
    return {"<pad>", "</s>", "<s>", "<unk>", "[PAD]", "[CLS]", "[SEP]", "[MASK]"}

# Lightweight token validity with thread-safe caching
_token_validation_cache: Dict[Tuple[str, str], bool] = {}
_cache_lock = threading.Lock()
_cache_max_size = 10000

def is_valid_token(token, special_tokens: Optional[set] = None,
                   tokenizer=None, language: str = "bn") -> bool:
    """
    Check if token is valid for homograph disambiguation.
    Uses thread-safe caching for performance.
    """
    token = "" if token is None else str(token)
    cache_key = (token, language)

    # Check cache first
    with _cache_lock:
        if cache_key in _token_validation_cache:
            return _token_validation_cache[cache_key]

    # Clean token (remove common subword markers)
    clean = token.replace("‚ñÅ", "").replace("##", "").strip()

    # Bengali homograph watchlist check (always valid)
    try:
        if language == "bn" and clean in HOMOGRAPH_WATCHLIST_BN:
            result = True
            with _cache_lock:
                if len(_token_validation_cache) < _cache_max_size:
                    _token_validation_cache[cache_key] = result
            return result
    except Exception:
        pass

    # Special token check
    if special_tokens and token in special_tokens:
        result = False
    else:
        # Length check (Bengali needs 2+ chars, English needs 3+)
        min_len = 2 if language == "bn" else 3
        if len(clean) < min_len:
            result = False
        elif not any(c.isalpha() for c in clean):
            # Must contain at least one alphabetic character
            result = False
        else:
            # Must be at least 60% alphabetic
            alpha_count = sum(c.isalpha() for c in clean)
            if alpha_count / max(1, len(clean)) < 0.6:
                result = False
            else:
                result = True

    # Cache result safely
    with _cache_lock:
        if len(_token_validation_cache) < _cache_max_size:
            _token_validation_cache[cache_key] = result
    return result

def safe_tokenize_with_offsets(tokenizer, text: str, max_length: int = 512):
    """
    Safely tokenize text with offset mapping.
    Returns (tokens, offsets) or (None, None) on failure.
    Supports tokenizers that follow the HuggingFace call API.
    """
    try:
        encoded = tokenizer(
            text,
            return_offsets_mapping=True,
            max_length=max_length,
            truncation=True,
            add_special_tokens=False
        )
        # extract input_ids robustly
        input_ids = encoded.get("input_ids", None)
        # input_ids may be tensor or list-of-lists. Normalize to a plain python list for token conversion.
        if input_ids is None:
            # try alternative access
            if hasattr(encoded, "data") and isinstance(encoded.data, dict):
                input_ids = encoded.data.get("input_ids", None)
        # Normalize input_ids to list of ints representing first example
        ids_list = []
        if isinstance(input_ids, list) and input_ids:
            # Could be list-of-lists
            first = input_ids[0]
            if isinstance(first, list):
                ids_list = list(first)
            else:
                ids_list = list(input_ids)
        elif hasattr(input_ids, "tolist"):
            try:
                arr = input_ids.tolist()
                if isinstance(arr, list) and len(arr) > 0 and isinstance(arr[0], list):
                    ids_list = arr[0]
                else:
                    ids_list = arr
            except Exception:
                ids_list = []
        else:
            ids_list = []

        # offsets extraction
        offsets = encoded.get("offset_mapping", None)
        if offsets is None and hasattr(encoded, "data") and isinstance(encoded.data, dict):
            offsets = encoded.data.get("offset_mapping", None)

        # Normalize offsets to list-of-(start,end) for first example
        offsets_list = []
        if offsets is not None:
            if isinstance(offsets, list) and len(offsets) > 0:
                # offsets might be list of lists: offsets[0] is for first example
                first = offsets[0] if isinstance(offsets[0], (list, tuple)) else offsets
                offsets_list = [tuple(o) if isinstance(o, (list, tuple)) and len(o) == 2 else (None, None) for o in first]
            elif hasattr(offsets, "tolist"):
                try:
                    arr = offsets.tolist()
                    if isinstance(arr, list) and len(arr) > 0 and isinstance(arr[0], list):
                        offsets_list = [tuple(o) if isinstance(o, (list, tuple)) and len(o) == 2 else (None, None) for o in arr[0]]
                except Exception:
                    offsets_list = []
        # convert ids_list -> tokens
        toks = []
        if ids_list:
            try:
                if hasattr(tokenizer, "convert_ids_to_tokens"):
                    toks = tokenizer.convert_ids_to_tokens(ids_list)
                else:
                    # best-effort tokenization fallback
                    toks = tokenizer.tokenize(text) if hasattr(tokenizer, "tokenize") else [str(i) for i in ids_list]
            except Exception:
                toks = tokenizer.tokenize(text) if hasattr(tokenizer, "tokenize") else [str(i) for i in ids_list]
        else:
            # fallback: try tokenizer.tokenize on text
            try:
                toks = tokenizer.tokenize(text)
            except Exception:
                toks = []

        return toks, offsets_list
    except Exception:
        return None, None

# ==============================================================================
# RANDOM SEEDS & BACKEND TWEAKS
# ==============================================================================

torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)
if torch.cuda.is_available():
    try:
        torch.cuda.manual_seed_all(SEED)
    except Exception:
        pass

# PyTorch performance optimizations (safe guarded)
if hasattr(torch, "set_float32_matmul_precision"):
    try:
        torch.set_float32_matmul_precision("high")
    except Exception:
        pass

# cuDNN optimizations (benchmark/deterministic balance)
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.deterministic = False

# ==============================================================================
# FALLBACK: small synthetic dataset when CSV is missing
# ==============================================================================

FALLBACK_DATASET = [
    {"src": "‡¶§‡¶ø‡¶®‡¶ø ‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï ‡¶ó‡ßá‡¶õ‡ßá‡¶®‡•§", "tgt": "He went to the bank."},
    {"src": "‡¶¨‡¶æ‡¶∞‡ßç‡¶• ‡¶™‡ßá‡ßü‡ßá‡¶õ‡¶ø‡¶≤‡¶æ‡¶Æ‡•§", "tgt": "I received a birthday present."},
    {"src": "‡¶∏‡ßá ‡¶è‡¶ï‡¶ü‡¶ø ‡¶ï‡¶≤ ‡¶Ü‡¶Æ‡¶æ‡¶ï‡ßá ‡¶¶‡¶ø‡ßü‡ßá‡¶õ‡ßá‡•§", "tgt": "He gave me a call."},
]

def get_effective_num_samples() -> int:
    """Return the number of samples we will actually attempt to use."""
    if _CSV_AVAILABLE and _HAS_PANDAS:
        try:
            # quick probe: read small head to ensure file is readable, then compute count using chunk-based helper
            _ = pd.read_csv(DATASET_CSV_PATH, nrows=1)
            cnt = _get_csv_row_count(DATASET_CSV_PATH)
            if cnt is None:
                return min(NUM_SAMPLES, len(FALLBACK_DATASET))
            return min(NUM_SAMPLES, int(cnt))
        except Exception:
            return min(NUM_SAMPLES, len(FALLBACK_DATASET))
    else:
        return min(NUM_SAMPLES, len(FALLBACK_DATASET))

EFFECTIVE_NUM_SAMPLES = get_effective_num_samples()

# ==============================================================================
# CONFIGURATION SUMMARY (sanity checks)
# ==============================================================================

print("\n" + "="*80)
print("‚ö° OPTIMIZED ULTRA-FAST TATN CONFIGURATION (Cell 0 - DEBUGGED)")
print("="*80)
print(f"User: {os.getenv('KAGGLE_USERNAME', os.getenv('USER', 'manas0003'))}")
print(f"Date: {time.strftime('%Y-%m-%d %H:%M:%S', time.gmtime())} UTC")
print(f"Multi-GPU: {'ENABLED' if USE_MULTI_GPU else 'DISABLED'} ({NUM_GPUS} GPUs visible)")
print(f"Dataset source: {'LOCAL CSV' if _CSV_AVAILABLE else 'FALLBACK_EMBEDDED_SMALLSET'}")
print(f"Dataset path: {DATASET_CSV_PATH}")
print(f"Dataset samples (cap): {NUM_SAMPLES:,} (effective: {EFFECTIVE_NUM_SAMPLES:,})")
print(f"Batch Size: {BATCH_SIZE} x {ACCUMULATION_STEPS} grad-accum steps")
print(f"Effective batch size: {BATCH_SIZE * ACCUMULATION_STEPS}")
print(f"Max Length: {MAX_LENGTH} tokens")
print(f"Epochs: {EPOCHS}")
print(f"Workers: {NUM_WORKERS}, Prefetch: {PREFETCH_FACTOR}, Pin memory: {PIN_MEMORY}")
print(f"AMP: {'ENABLED' if USE_AMP else 'DISABLED'}")
print(f"Validation interval: {VALIDATION_CHECK_INTERVAL} ({'DISABLED' if VALIDATION_CHECK_INTERVAL == 0 else 'ENABLED'})")
print()
print("DSCD Config:")
print(f"  Buffer size: {DSCD_BUFFER_SIZE}")
print(f"  Max prototypes: {DSCD_MAX_PROTOS}")
print(f"  n_min: {DSCD_N_MIN}")
print(f"  dispersion threshold: {DSCD_DISPERSION_THRESHOLD}")
print(f"  embedding dim: {DSCD_EMBED_DIM}")
print(f"  temperature: {DSCD_TEMPERATURE}")
print(f"  training clustering: {'ENABLED' if DSCD_ENABLE_TRAINING_CLUSTERING else 'DISABLED (warmup only)'}")
print(f"  warmup samples: {DSCD_WARMUP_SAMPLES}")
print()
print("TRG & Uncertainty:")
print(f"  TAU_LOW: {TAU_LOW}, TAU_HIGH: {TAU_HIGH}, TAU_ACCEPT: {TAU_ACCEPT}")
print(f"  span threshold: {SPAN_THRESHOLD}")
print(f"  TRG training: {'ENABLED' if ENABLE_TRG_TRAINING else 'DISABLED'}")
print(f"  TRG inference: {'ENABLED' if ENABLE_TRG_INFERENCE else 'DISABLED'}")
print()
print("ASBN / Loss weights:")
print(f"  ASBN training: {'ENABLED' if ENABLE_ASBN_TRAINING else 'DISABLED'}")
print(f"  ASBN inference: {'ENABLED' if ENABLE_ASBN_INFERENCE else 'DISABLED'}")
print(f"  LAMBDA_ASBN: {LAMBDA_ASBN}")
print(f"  LAMBDA_DSCD: {LAMBDA_DSCD}")
print()
print("Learning Rates:")
print(f"  NMT: {LR_NMT}, TRG: {LR_TRG}, PHI: {LR_PHI}")
print("="*80)
print("üîß MEMORY OPTIMIZATIONS APPLIED:")
print(f"  ‚Ä¢ Batch size: {BATCH_SIZE}")
print(f"  ‚Ä¢ Accumulation steps: {ACCUMULATION_STEPS}")
print(f"  ‚Ä¢ DSCD buffer reduced: {DSCD_BUFFER_SIZE}")
print(f"  ‚Ä¢ Gradient checkpointing: {'ENABLED' if GRADIENT_CHECKPOINTING else 'DISABLED'}")
print("="*80)

# Final sanity checks and small auto-corrections
if not (0.0 <= TAU_LOW <= 1.0):
    print("[WARN] TAU_LOW out of range [0, 1]; resetting to 0.4")
    TAU_LOW = 0.4

if not (0.0 <= TAU_HIGH <= 1.0):
    print("[WARN] TAU_HIGH out of range [0, 1]; resetting to 0.85")
    TAU_HIGH = 0.85

if TAU_LOW >= TAU_HIGH:
    print("[WARN] TAU_LOW >= TAU_HIGH; resetting to TAU_LOW=0.4, TAU_HIGH=0.85")
    TAU_LOW, TAU_HIGH = 0.4, 0.85

if VALIDATION_CHECK_INTERVAL != 0:
    print(f"[INFO] Validation enabled every {VALIDATION_CHECK_INTERVAL} steps")

if not _HAS_PANDAS:
    print("[WARN] pandas not installed. CSV loading will use fallback dataset or will require installing pandas.")

if _CSV_AVAILABLE and _HAS_PANDAS:
    # quick sample size validation (avoid pathological configs)
    try:
        # robust line counting in text mode to estimate file size
        nrows = None
        try:
            nrows = _get_csv_row_count(DATASET_CSV_PATH)
        except Exception:
            nrows = None
        if nrows is not None and nrows < 10 and EFFECTIVE_NUM_SAMPLES > nrows:
            print("[WARN] CSV seems very small relative to NUM_SAMPLES. Adjust NUM_SAMPLES if needed.")
    except Exception:
        pass

print("‚úÖ Cell 0: Configuration loaded and debugged (OOM prevention + DSCD defaults + CSV fallback).")

[Cell 0] Multi-GPU Mode: 2 GPUs available (using device=cuda)
[Cell 0] Device: cuda (visible GPUs: 2)
[INFO] Dataset CSV found: /kaggle/input/homo-bn-dataset/bn_homograph_complete_dataset.csv
[INFO] CSV validation passed (columns: ['idx', 'src', 'tgt', 'word', 'sense'])

‚ö° OPTIMIZED ULTRA-FAST TATN CONFIGURATION (Cell 0 - DEBUGGED)
User: manas0003
Date: 2025-11-22 14:56:05 UTC
Multi-GPU: ENABLED (2 GPUs visible)
Dataset source: LOCAL CSV
Dataset path: /kaggle/input/homo-bn-dataset/bn_homograph_complete_dataset.csv
Dataset samples (cap): 50,000 (effective: 50,000)
Batch Size: 100 x 16 grad-accum steps
Effective batch size: 1600
Max Length: 48 tokens
Epochs: 2
Workers: 2, Prefetch: 2, Pin memory: True
AMP: ENABLED
Validation interval: 200 (ENABLED)

DSCD Config:
  Buffer size: 20
  Max prototypes: 8
  n_min: 5
  dispersion threshold: 0.25
  embedding dim: 1024
  temperature: 0.7
  training clustering: ENABLED
  warmup samples: 8000

TRG & Uncertainty:
  TAU_LOW: 0.4, TAU_HIGH: 0.85, TA

In [4]:
# ===========================================================================================
# CELL 1 - SAFE TOKENIZER UTILITIES (HARDENED)
# - Robust special-token caching
# - Deterministic offset normalization (encoded["offset_mapping"] always present)
# - Fast / slow tokenizer handling improved
# - Word-span reconstruction fallback order: offsets -> SPM markers -> whitespace
# ===========================================================================================

import threading
from typing import Tuple, List, Dict, Optional
import numpy as np
import torch

# Local defaults to avoid hard dependency on other cells
try:
    SAFE_OFFSET_MAX_LEN = int(MAX_LENGTH)
except NameError:
    SAFE_OFFSET_MAX_LEN = 48

try:
    _SOURCE_LANG = SOURCE_LANGUAGE
except NameError:
    _SOURCE_LANG = "bn"  # default to Bengali if not specified

# Thread-safe cache for special tokens
_SPECIAL_TOKENS_CACHE: Dict[str, set] = {}
_SPECIAL_TOKENS_LOCK = threading.Lock()


def _special_token_cache_key(tokenizer) -> str:
    """Build a stable key for caching special token sets for a tokenizer."""
    # tokenizer.name_or_path is preferred; fallback to repr
    name = getattr(tokenizer, "name_or_path", None) or getattr(tokenizer, "name", None) or repr(tokenizer)
    # determine vocab size safely
    vocab = None
    if hasattr(tokenizer, "vocab_size"):
        try:
            vocab = int(getattr(tokenizer, "vocab_size"))
        except Exception:
            vocab = None
    elif hasattr(tokenizer, "get_vocab") and callable(getattr(tokenizer, "get_vocab")):
        try:
            vocab = len(tokenizer.get_vocab())
        except Exception:
            vocab = None
    # final key:
    return f"{name}__vocab={vocab}"


def get_tokenizer_special_tokens(tokenizer) -> set:
    """
    Return a cached set of special tokens for `tokenizer`.
    The result is conservative (includes common placeholders) and avoids
    repeated expensive introspection.
    """
    cache_key = _special_token_cache_key(tokenizer)
    with _SPECIAL_TOKENS_LOCK:
        if cache_key in _SPECIAL_TOKENS_CACHE:
            return _SPECIAL_TOKENS_CACHE[cache_key]

        special_tokens = set()
        try:
            # Try common tokenizer attributes in order of availability
            if hasattr(tokenizer, "all_special_tokens"):
                try:
                    special_tokens.update(x for x in getattr(tokenizer, "all_special_tokens") or [] if x)
                except Exception:
                    pass
            if hasattr(tokenizer, "additional_special_tokens"):
                try:
                    special_tokens.update(x for x in getattr(tokenizer, "additional_special_tokens") or [] if x)
                except Exception:
                    pass
            # single-token attributes
            for attr in ("pad_token", "unk_token", "bos_token", "eos_token", "cls_token", "sep_token", "mask_token"):
                if hasattr(tokenizer, attr):
                    try:
                        tok = getattr(tokenizer, attr)
                        if tok:
                            special_tokens.add(tok)
                    except Exception:
                        pass
            # special_tokens_map or extended map
            try:
                stm = getattr(tokenizer, "special_tokens_map", None) or getattr(tokenizer, "special_tokens_map_extended", None)
                if isinstance(stm, dict):
                    for v in stm.values():
                        if isinstance(v, str) and v:
                            special_tokens.add(v)
            except Exception:
                pass

        except Exception:
            # fallback to safe conservative set
            special_tokens = set()

        # Add conservative language / placeholder tokens likely useful for m2m100 & friends
        special_tokens.update({
            "bn_IN", "en_XX",
            "</s>", "<pad>", "<s>", "<unk>",
            "[PAD]", "[EOS]", "[UNK]", "[CLS]", "[SEP]", "[MASK]"
        })

        _SPECIAL_TOKENS_CACHE[cache_key] = special_tokens
        return special_tokens


def _normalize_offset_mapping_for_batchencoding(enc):
    """
    Normalize a BatchEncoding (from HF tokenizer) so that enc['offset_mapping']
    is set and in Python list-of-(start,end) tuples for the first example in the batch.
    This function mutates enc in-place and returns it.
    """
    # prefer the direct key if present (works for fast tokenizers)
    try:
        if "offset_mapping" in enc and enc["offset_mapping"] is not None:
            off = enc["offset_mapping"]
            # Case: tensor (pt) or list-of-lists
            try:
                # If pt tensor
                if hasattr(off, "tolist"):
                    arr = off.tolist()
                    # arr is typically [[ [s,e], [s,e], ... ]]
                    if isinstance(arr, list) and len(arr) > 0 and isinstance(arr[0], list):
                        enc["offset_mapping"] = [tuple(x) if isinstance(x, list) and len(x) == 2 else (None, None) for x in arr[0]]
                        return enc
                # If already list-like
                if isinstance(off, (list, tuple)):
                    # ensure first-element list -> normalize its elements to tuples
                    if len(off) > 0 and isinstance(off[0], (list, tuple)):
                        enc["offset_mapping"] = [tuple(x) if isinstance(x, (list, tuple)) and len(x) == 2 else (None, None) for x in off[0]]
                        return enc
            except Exception:
                pass
    except Exception:
        pass

    # Last resort: if BatchEncoding exposes .data with offset_mapping, try that
    try:
        data = getattr(enc, "data", None)
        if data and isinstance(data, dict) and "offset_mapping" in data and data["offset_mapping"] is not None:
            om = data["offset_mapping"]
            if isinstance(om, (list, tuple)) and len(om) > 0 and isinstance(om[0], (list, tuple)):
                enc["offset_mapping"] = [tuple(x) if isinstance(x, (list, tuple)) and len(x) == 2 else (None, None) for x in om[0]]
                return enc
    except Exception:
        pass

    # If we reach here, ensure enc["offset_mapping"] exists and is a list for the first example (sequence length placeholder)
    try:
        seq_len = 0
        if "input_ids" in enc:
            input_ids = enc["input_ids"]
            # input_ids may be tensor or list
            if hasattr(input_ids, "shape"):
                seq_len = int(input_ids.shape[-1])
            elif isinstance(input_ids, (list, tuple)) and len(input_ids) > 0 and isinstance(input_ids[0], (list, tuple)):
                seq_len = len(input_ids[0])
        # create placeholder offsets
        enc["offset_mapping"] = [(None, None)] * seq_len
    except Exception:
        enc["offset_mapping"] = []

    return enc


def safe_offsets_tokenize(tokenizer, text: str, max_length: Optional[int] = None,
                          include_special_tokens: bool = False) -> dict:
    """
    Tokenize `text` with tokenizer and *guarantee* that the return value has:
      - 'input_ids' and optionally 'attention_mask' (as returned by HF tokenizer)
      - 'offset_mapping' key present and normalized to a list of (start,end) tuples
        for the first example in the batch (or an empty list if unavailable).

    Parameters:
      tokenizer: HF tokenizer instance (fast or slow)
      text: input string
      max_length: token truncation max (defaults to SAFE_OFFSET_MAX_LEN)
      include_special_tokens: whether to include special tokens in tokenization
    """
    if max_length is None:
        max_length = SAFE_OFFSET_MAX_LEN
    eff_max = int(max_length)

    if not isinstance(text, str):
        text = "" if text is None else str(text)

    # Limit characters to avoid pathological inputs
    char_limit = min(eff_max * 20, 2000)
    sample_text = text[:char_limit]

    is_fast = getattr(tokenizer, "is_fast", False)

    # Prefer the fast path; ensure we ask for offsets and tensor outputs for convenience
    if is_fast:
        try:
            enc = tokenizer(
                sample_text,
                return_offsets_mapping=True,
                return_tensors="pt",
                truncation=True,
                padding=False,
                max_length=eff_max,
                add_special_tokens=include_special_tokens,
            )
            enc = _normalize_offset_mapping_for_batchencoding(enc)
            return enc
        except Exception:
            # fallthrough to slow path
            pass

    # Slow tokenizer path: ask for ids, then build best-effort offsets
    try:
        enc = tokenizer(
            sample_text,
            return_tensors="pt",
            truncation=True,
            padding=False,
            max_length=eff_max,
            add_special_tokens=include_special_tokens,
        )
    except Exception:
        # If the tokenizer call fails completely, produce a minimal encoding
        # that downstream code can still handle.
        enc = {"input_ids": torch.tensor([[tokenizer.pad_token_id if hasattr(tokenizer, "pad_token_id") else 0]]),
               "attention_mask": torch.tensor([[1]])}
        enc = _normalize_offset_mapping_for_batchencoding(enc)
        return enc

    # Try to compute a fallback offset map by aligning decoded token text to source
    try:
        # get sequence of token ids (first example)
        input_ids = None
        try:
            input_ids = enc["input_ids"][0].tolist()
        except Exception:
            # try alternative access
            if hasattr(enc, "data") and "input_ids" in enc.data:
                input_ids = enc.data["input_ids"][0]
        tokens = []
        if input_ids is not None:
            try:
                tokens = tokenizer.convert_ids_to_tokens(input_ids)
            except Exception:
                tokens = []
        # Build offsets by searching token text in source progressively
        offsets_list = []
        src = sample_text
        cur_pos = 0
        for tok in tokens:
            # clean subword markers commonly used by SPM/BPE/fast tokenizers
            token_text = (tok or "").replace("‚ñÅ", "").replace("ƒ†", "").strip()
            if not token_text:
                offsets_list.append((None, None))
                continue
            # naive search from current position
            idx = src.find(token_text, cur_pos)
            if idx == -1:
                idx = src.lower().find(token_text.lower(), cur_pos)
            if idx == -1:
                offsets_list.append((None, None))
            else:
                start = int(idx)
                end = int(idx + len(token_text))
                offsets_list.append((start, end))
                cur_pos = end
        # normalize to same format expected by _normalize_offset_mapping_for_batchencoding
        enc["offset_mapping"] = offsets_list
        # ensure normalized (wrap as first-example list)
        enc = _normalize_offset_mapping_for_batchencoding(enc)
        return enc
    except Exception:
        # fallback: ensure offset_mapping exists
        enc = _normalize_offset_mapping_for_batchencoding(enc)
        return enc


def reconstruct_word_spans(tokenizer, text: str, max_length: Optional[int] = None) -> Tuple[Dict[int, str], List[str]]:
    """
    Return:
      - token_word_map: mapping token_index -> reconstructed word string (best-effort)
      - words: list[str] of words discovered in order

    Strategy:
      1) Use tokenizer offsets when available -> group contiguous character spans into words.
      2) If offsets unavailable or unhelpful, use SPM-style '‚ñÅ' or 'ƒ†' markers to assemble subwords.
      3) Finally fallback to whitespace-splitting.
    """
    if max_length is None:
        max_length = SAFE_OFFSET_MAX_LEN
    eff_max = int(max_length)

    if not isinstance(text, str) or len(text.strip()) == 0:
        return {}, []

    char_limit = min(eff_max * 20, 2000)
    text = text[:char_limit]
    text_len = len(text)

    special_tokens = get_tokenizer_special_tokens(tokenizer)

    try:
        current_lang = SOURCE_LANGUAGE
    except NameError:
        current_lang = _SOURCE_LANG

    # Get normalized encoding (guarantees offset_mapping exists)
    try:
        encoded = safe_offsets_tokenize(tokenizer, text, max_length=eff_max, include_special_tokens=False)
    except Exception:
        return {}, []

    offsets = encoded.get("offset_mapping", [])
    # ensure input_ids and tokens exist
    try:
        input_ids = encoded["input_ids"][0].tolist()
    except Exception:
        input_ids = []
    try:
        tokens = tokenizer.convert_ids_to_tokens(input_ids) if input_ids else []
    except Exception:
        tokens = []

    # Ensure offsets is a list with len(tokens) (if possible)
    if isinstance(offsets, list) and len(offsets) > 0 and all(isinstance(x, tuple) for x in offsets):
        offsets_list = offsets
    elif isinstance(offsets, list) and len(offsets) > 0 and isinstance(offsets[0], (list, tuple)):
        offsets_list = [tuple(x) if isinstance(x, (list, tuple)) and len(x) == 2 else (None, None) for x in offsets[0]]
    else:
        # not usable
        offsets_list = [(None, None)] * len(tokens)

    token_word_map: Dict[int, str] = {}
    words: List[str] = []

    # 1) Use offsets to group contiguous spans into words
    used_any_offset = any((isinstance(o, tuple) and o[0] is not None and o[1] is not None) for o in offsets_list)
    if used_any_offset:
        word_start = None
        word_end = None
        word_accum = ""
        for idx, (off, tok) in enumerate(zip(offsets_list, tokens)):
            try:
                off_start, off_end = (int(off[0]) if off[0] is not None else None, int(off[1]) if off[1] is not None else None)
            except Exception:
                off_start, off_end = None, None
            if off_start is None or off_end is None:
                # token with no offsets: close existing word and skip
                if word_start is not None and word_end is not None:
                    try:
                        wtext = text[word_start:word_end].strip()
                        if wtext:
                            words.append(wtext)
                    except Exception:
                        pass
                word_start = None
                word_end = None
                word_accum = ""
                token_word_map[idx] = "UNK"
                continue

            # optionally skip special tokens
            if tok in special_tokens:
                token_word_map[idx] = ""
                continue

            # Start new word if needed
            if word_start is None:
                word_start = off_start
                word_end = off_end
            else:
                # If this token begins after the previous end -> new word
                if off_start > word_end:
                    # flush previous
                    try:
                        wtext = text[word_start:word_end].strip()
                        if wtext:
                            words.append(wtext)
                    except Exception:
                        pass
                    word_start = off_start
                    word_end = off_end
                else:
                    word_end = max(word_end, off_end)

            # map token to the current word slice (best-effort)
            try:
                current_word = text[word_start:word_end].strip()
                token_word_map[idx] = current_word if current_word else "UNK"
            except Exception:
                token_word_map[idx] = "UNK"

        # flush last
        if word_start is not None and word_end is not None:
            try:
                wtext = text[word_start:word_end].strip()
                if wtext:
                    words.append(wtext)
            except Exception:
                pass

        if token_word_map:
            words = [w for w in words if isinstance(w, str) and w.strip()]
            return token_word_map, words

    # 2) Fallback to SPM/BPE marker assembly (tokens marked with '‚ñÅ' or 'ƒ†')
    token_word_map = {}
    assembled = []
    current = ""
    running_word = ""
    for i, tok in enumerate(tokens):
        # skip special tokens
        if tok in special_tokens:
            token_word_map[i] = ""
            continue
        # normalize token text
        clean = (tok or "").replace("‚ñÅ", "").replace("ƒ†", "").strip()
        if not clean:
            token_word_map[i] = ""
            continue
        if (tok.startswith("‚ñÅ") or tok.startswith("ƒ†")):
            # new word
            if current:
                assembled.append(current)
            current = clean
            running_word = current
        else:
            # continuation subword
            current = current + clean
            running_word = current
        token_word_map[i] = running_word if running_word else "UNK"
    if current:
        assembled.append(current)
    if token_word_map:
        words = [w for w in assembled if w and w.strip()]
        return token_word_map, words

    # 3) Final fallback: whitespace-split the original text and assign tokens approximately
    try:
        word_list = [w for w in text.split() if w.strip()]
        token_word_map = {}
        if tokens and word_list:
            widx = 0
            for i, tok in enumerate(tokens):
                clean = (tok or "").replace("‚ñÅ", "").replace("ƒ†", "").strip()
                if not clean:
                    token_word_map[i] = ""
                    continue
                token_word_map[i] = word_list[min(widx, len(word_list) - 1)]
                # Heuristic: if token looks long or contains punctuation advance
                if len(clean) > len(token_word_map[i]) or clean.endswith((".", ",", ";", "‡•§", "?" , "!" )):
                    widx = min(widx + 1, len(word_list) - 1)
        return token_word_map, word_list
    except Exception:
        return {}, []


# ===========================================================================================
# LIGHTWEIGHT SELF-TEST
# ===========================================================================================
def test_tokenizer_utilities_quick(tokenizer=None):
    """
    If tokenizer is None, this will only sanity-check Python-level logic.
    If tokenizer is provided (HF tokenizer), it will run a quick encode + reconstruct.
    """
    sample = "‡¶ï‡¶æ‡¶≤ ‡¶Ü‡¶Æ‡¶ø ‡¶¨‡¶æ‡¶ú‡¶æ‡¶∞‡ßá ‡¶Ø‡¶æ‡¶¨‡•§"  # Bengali: "Tomorrow I will go to the market."
    print("Running tokenizer-utils quick test...")
    try:
        if tokenizer is None:
            print("No tokenizer provided: basic logic OK.")
            return True
        enc = safe_offsets_tokenize(tokenizer, sample, max_length=32, include_special_tokens=False)
        print("  Encoded input_ids len:", int(enc["input_ids"].shape[-1]) if "input_ids" in enc else "N/A")
        print("  Offset mapping (first 10):", (enc.get("offset_mapping") or [])[:10])
        token_map, words = reconstruct_word_spans(tokenizer, sample, max_length=32)
        print("  Reconstructed words:", words)
        print("  Token->word examples:", {k: token_map[k] for k in list(token_map.keys())[:6]})
        return True
    except Exception as e:
        print("Tokenizer utilities quick test failed:", repr(e))
        return False


# This print is a gentle confirmation that the utilities loaded.
print("‚úÖ Cell 1 (tokenizer utilities) loaded and hardened.")

‚úÖ Cell 1 (tokenizer utilities) loaded and hardened.


In [5]:
# ==============================================================================
# CELL 2: MEMORY-EFFICIENT DATA LOADING (FIXED & HARDENED + CSV SUPPORT)
# ==============================================================================
# ‚úÖ FIXED: Replaced Samanantar with local CSV loading
# ‚úÖ FIXED: Added pandas-based CSV reader with proper column mapping
# ‚úÖ FIXED: Enhanced error handling and validation
# - Robust fallbacks when datasets/tokenizer utilities are missing
# - Safer DP-divisible batching logic (floor to nearest multiple by default)
# - Worker init rebinds tokenizer safely for multiprocessing workers
# - Deterministic per-worker seeding
# - Safe collate that always returns stackable tensors and preserves token_word_map
# - Defensive behaviors and verbose debug prints controlled by VERBOSE_LOGGING
# ==============================================================================
from typing import Optional, List, Tuple, Dict, Any
from collections import defaultdict
import os
import time
import random
import traceback
import re

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, get_worker_info
from tqdm import tqdm

# Pandas import for CSV reading (required for local dataset)
try:
    import pandas as pd
    _HAS_PANDAS = True
except ImportError:
    pd = None
    _HAS_PANDAS = False
    print("[CELL2] WARNING: pandas not available; CSV loading will fail!")

# Optional import - datasets library (not needed for CSV mode)
try:
    from datasets import load_dataset
    _HAS_DATASETS = True
except Exception:
    load_dataset = None
    _HAS_DATASETS = False

# -------------------------
# Debug control
# -------------------------
try:
    _VERBOSE_LOGGING = bool(VERBOSE_LOGGING)
except NameError:
    _VERBOSE_LOGGING = False

DEBUG_CELL2 = bool(_VERBOSE_LOGGING)
DEBUG_LIMIT = 10
_cell2_dbg_counts: Dict[str, int] = defaultdict(int)


def cell2_dbg(key: str, msg: str, limit: int = DEBUG_LIMIT):
    """Debug print with rate limiting."""
    if not DEBUG_CELL2:
        return
    _cell2_dbg_counts[key] += 1
    if _cell2_dbg_counts[key] <= limit:
        print(f"[CELL2-DBG] {msg}")


# -------------------------
# Local fallbacks for globals (explicit, safe)
# -------------------------
try:
    _NUM_SAMPLES = int(NUM_SAMPLES)
except Exception:
    _NUM_SAMPLES = 50000
    print("[CELL2] WARNING: NUM_SAMPLES not defined, using default 50000")

try:
    _MAX_LENGTH = int(MAX_LENGTH)
except Exception:
    _MAX_LENGTH = 48
    print("[CELL2] WARNING: MAX_LENGTH not defined, using default 48")

try:
    _BN_LANG = BN_LANG
    _EN_LANG = EN_LANG
except NameError:
    _BN_LANG = "bn"
    _EN_LANG = "en"
    print("[CELL2] WARNING: BN_LANG/EN_LANG not defined, using defaults")

try:
    _NUM_GPUS = int(NUM_GPUS)
    _USE_MULTI_GPU = bool(USE_MULTI_GPU)
except NameError:
    _NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
    _USE_MULTI_GPU = _NUM_GPUS > 1
    print(f"[CELL2] WARNING: GPU config not defined, detected {_NUM_GPUS} GPUs")

try:
    _NUM_WORKERS = int(NUM_WORKERS)
except NameError:
    _NUM_WORKERS = 0
    print("[CELL2] WARNING: NUM_WORKERS not defined, using 0")

try:
    _PIN_MEMORY = bool(PIN_MEMORY)
except NameError:
    _PIN_MEMORY = False

try:
    _PREFETCH_FACTOR = int(PREFETCH_FACTOR)
except NameError:
    _PREFETCH_FACTOR = 2

try:
    _DATASET_CSV_PATH = str(DATASET_CSV_PATH)
except NameError:
    _DATASET_CSV_PATH = "/kaggle/input/bengali-english-homograph/bengali_homograph_sentences.csv"
    print(f"[CELL2] WARNING: DATASET_CSV_PATH not defined, using default: {_DATASET_CSV_PATH}")

# Check availability of utility functions from Cell 0
_has_normalize = ('normalize_bengali' in globals()) and ('normalize_english' in globals())
_has_reconstruct_word_spans = 'reconstruct_word_spans' in globals()
_has_safe_offsets_tokenize = 'safe_offsets_tokenize' in globals()

if not _has_normalize:
    print("[CELL2] WARNING: normalize_bengali/normalize_english not found; using simple .strip()")

# -------------------------
# Utility: detect Bengali text heuristically
# -------------------------
_BENGALI_CHAR_RE = re.compile(r'[\u0980-\u09FF]')

def is_bengali_text(s: str) -> bool:
    """Check if text contains Bengali Unicode characters."""
    if not isinstance(s, str) or not s:
        return False
    # if any Bengali char present, treat as Bengali
    return bool(_BENGALI_CHAR_RE.search(s))


# -------------------------
# Worker init: reattach tokenizer and set per-worker seed
# -------------------------
def _dataloader_worker_init_fn(worker_id: int):
    """Initialize DataLoader worker with tokenizer and deterministic seed."""
    worker_info = get_worker_info()
    dataset = worker_info.dataset if worker_info is not None else None
    
    # Try to rebind tokenizer from the main process globals into the worker dataset
    try:
        if dataset is not None:
            tk = globals().get('tokenizer', None)
            if tk is not None:
                try:
                    # attach tokenizer reference only (avoid copying heavy state)
                    dataset.tokenizer = tk
                    dataset.is_fast = getattr(tk, "is_fast", False)
                except Exception:
                    dataset.tokenizer = tk
                    dataset.is_fast = False
    except Exception:
        if DEBUG_CELL2:
            print(f"[CELL2-WORKER-INIT] tokenizer rebind failed in worker {worker_id}")
            traceback.print_exc()
    
    # Set a deterministic-ish per-worker seed to avoid RNG issues
    try:
        base = int(os.environ.get("PYTHONHASHSEED", "0"))
        # incorporate worker id and time low bits to change per-worker seed
        seed = (base ^ (worker_id + 1) ^ int(time.time())) & 0xFFFFFFFF
        random.seed(seed)
        np.random.seed(seed % (2**31 - 1))
        torch.manual_seed(seed % (2**31 - 1))
    except Exception:
        pass


# -------------------------
# Data loading and preprocessing (CSV-BASED)
# -------------------------
def load_and_preprocess_optimized(num_samples: Optional[int] = None) -> List[Tuple[str, str]]:
    """
    Load parallel bn-en pairs from local CSV file.
    CSV format: idx,src,tgt (where src=English, tgt=Bengali)
    Returns list of (bn, en) pairs.
    Falls back to a small hard-coded set if CSV load fails.
    """
    if num_samples is None:
        num_samples = _NUM_SAMPLES
    if num_samples <= 0:
        raise ValueError("num_samples must be positive")

    print(f"[CELL2] Loading up to {num_samples} samples from local CSV: {_DATASET_CSV_PATH}")
    
    # Validate pandas availability
    if not _HAS_PANDAS:
        print("[CELL2] ERROR: pandas not available; cannot load CSV!")
        print("[CELL2] Install with: !pip install pandas")
        print("[CELL2] Using fallback small dataset for debugging.")
        return _get_fallback_dataset()
    
    # Validate CSV file exists
    if not os.path.exists(_DATASET_CSV_PATH):
        print(f"[CELL2] ERROR: CSV file not found at: {_DATASET_CSV_PATH}")
        print("[CELL2] Using fallback small dataset for debugging.")
        return _get_fallback_dataset()
    
    try:
        # Read CSV file
        print(f"[CELL2] Reading CSV file...")
        df = pd.read_csv(_DATASET_CSV_PATH)
        
        # Validate required columns
        if 'src' not in df.columns:
            print(f"[CELL2] ERROR: CSV missing 'src' column. Found columns: {list(df.columns)}")
            return _get_fallback_dataset()
        
        if 'tgt' not in df.columns:
            print(f"[CELL2] ERROR: CSV missing 'tgt' column. Found columns: {list(df.columns)}")
            return _get_fallback_dataset()
        
        # Limit to num_samples
        df = df.head(num_samples)
        
        print(f"[CELL2] Processing {len(df)} rows from CSV...")
        
        pairs: List[Tuple[str, str]] = []
        skipped = 0
        
        for idx, row in tqdm(df.iterrows(), total=len(df), desc="Loading dataset"):
            try:
                # src = English, tgt = Bengali
                en = str(row['src']).strip()
                bn = str(row['tgt']).strip()
                
                # Basic validation
                if not en or not bn:
                    skipped += 1
                    cell2_dbg("empty_field", f"Empty src/tgt at idx={idx}")
                    continue
                
                # Check for "nan" string from pandas
                if en.lower() == 'nan' or bn.lower() == 'nan':
                    skipped += 1
                    cell2_dbg("nan_value", f"NaN value at idx={idx}")
                    continue
                
                # Length check (avoid extremely long sentences)
                max_words = max(40, _MAX_LENGTH)
                if len(en.split()) > max_words or len(bn.split()) > max_words:
                    skipped += 1
                    cell2_dbg("too_long", f"Too long at idx={idx}: en={len(en.split())} bn={len(bn.split())} words")
                    continue
                
                # Normalize if available
                if _has_normalize:
                    bn_norm = normalize_bengali(bn)
                    en_norm = normalize_english(en)
                else:
                    bn_norm = bn
                    en_norm = en.lower()
                
                # Ensure normalization didn't create empty strings
                if not bn_norm or not en_norm:
                    skipped += 1
                    cell2_dbg("empty_after_norm", f"Empty after normalization at idx={idx}")
                    continue
                
                # Store as (Bengali, English) pair - IMPORTANT ORDER!
                pairs.append((bn_norm, en_norm))
                
            except Exception as e:
                skipped += 1
                cell2_dbg("row_exception", f"Row load exception idx={idx}: {type(e).__name__}: {str(e)[:100]}")
                continue
        
        print(f"[CELL2] Loaded {len(pairs)} pairs from CSV, skipped {skipped} rows")
        
        if len(pairs) == 0:
            print("[CELL2] ERROR: No valid pairs loaded from CSV!")
            return _get_fallback_dataset()
        
        return pairs
        
    except FileNotFoundError:
        print(f"[CELL2] ERROR: CSV file not found at: {_DATASET_CSV_PATH}")
        print("[CELL2] Using fallback small dataset for debugging.")
        return _get_fallback_dataset()
    
    except pd.errors.EmptyDataError:
        print(f"[CELL2] ERROR: CSV file is empty: {_DATASET_CSV_PATH}")
        return _get_fallback_dataset()
    
    except Exception as e:
        print(f"[CELL2] ERROR loading CSV: {type(e).__name__}: {str(e)}")
        print(f"[CELL2] Traceback: {traceback.format_exc().splitlines()[-3:]}")
        print("[CELL2] Using fallback dataset")
        return _get_fallback_dataset()


def _get_fallback_dataset() -> List[Tuple[str, str]]:
    """Return a small fallback dataset for debugging/testing."""
    print("[CELL2] Using fallback small dataset (5 samples)")
    fallback_pairs = [
        ("‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§", "i turned off the tap."),
        ("‡¶∏‡ßá ‡¶Ü‡¶Æ‡¶æ‡¶ï‡ßá ‡¶™‡¶∞‡ßá ‡¶ï‡¶≤ ‡¶ï‡¶∞‡¶¨‡ßá‡•§", "he will call me later."),
        ("‡¶Ü‡¶Æ‡¶∞‡¶æ ‡¶™‡ßç‡¶∞‡¶§‡¶ø‡¶¶‡¶ø‡¶® ‡¶§‡¶æ‡¶ú‡¶æ ‡¶´‡¶≤ ‡¶ñ‡¶æ‡¶á‡•§", "we eat fresh fruits every day."),
        ("‡¶§‡¶æ‡¶∞ ‡¶ï‡¶†‡ßã‡¶∞ ‡¶™‡¶∞‡¶ø‡¶∂‡ßç‡¶∞‡¶Æ‡ßá‡¶∞ ‡¶≠‡¶æ‡¶≤‡ßã ‡¶´‡¶≤ ‡¶π‡¶Ø‡¶º‡ßá‡¶õ‡ßá‡•§", "his hard work has brought good results."),
        ("‡¶ó‡¶æ‡¶õ‡ßá ‡¶®‡¶§‡ßÅ‡¶® ‡¶™‡¶æ‡¶§‡¶æ‡¶ó‡ßÅ‡¶≤‡ßã ‡¶ó‡¶ú‡¶ø‡¶Ø‡¶º‡ßá‡¶õ‡ßá‡•§", "new leaves have sprouted on the tree.")
    ]
    if _has_normalize:
        return [(normalize_bengali(bn), normalize_english(en)) for bn, en in fallback_pairs]
    else:
        return [(bn.strip(), en.lower().strip()) for bn, en in fallback_pairs]


# -------------------------
# Dataset Class
# -------------------------
class MemoryEfficientDataset(Dataset):
    """
    Memory-efficient dataset that returns dicts with:
      - input_ids, attention_mask: torch.LongTensor [L]
      - labels: torch.LongTensor [L] with pad->-100
      - token_word_map: dict token_idx->word
      - src_text: original source string
      - tokens: list of token strings
    The tokenizer attribute is excluded from pickled state so DataLoader workers don't crash.
    """

    def __init__(self, pairs: List[Tuple[str, str]], tokenizer: Any = None, max_length: Optional[int] = None):
        if max_length is None:
            max_length = _MAX_LENGTH
        self.max_length = int(max_length)
        self.tokenizer = tokenizer
        try:
            self._tokenizer_name_or_path = getattr(tokenizer, "name_or_path", None)
        except Exception:
            self._tokenizer_name_or_path = None

        try:
            self.is_fast = getattr(self.tokenizer, "is_fast", False)
        except Exception:
            self.is_fast = False

        self.pairs: List[Tuple[str, str]] = []
        invalid = 0
        
        # Validate and filter pairs
        for i, p in enumerate(pairs):
            try:
                if not isinstance(p, (list, tuple)) or len(p) != 2:
                    invalid += 1
                    cell2_dbg("init_badpair", f"Bad pair structure at idx={i}")
                    continue
                
                src, tgt = p
                
                # Type validation
                if not isinstance(src, str) or not isinstance(tgt, str):
                    invalid += 1
                    cell2_dbg("init_badtype", f"Non-string src/tgt at idx={i}")
                    continue
                
                # Empty check
                if not src or not tgt:
                    invalid += 1
                    cell2_dbg("init_empty", f"Empty src/tgt at idx={i}")
                    continue
                
                # Length sanity check (character level)
                if len(src) > self.max_length * 20 or len(tgt) > self.max_length * 20:
                    invalid += 1
                    cell2_dbg("init_long", f"Extremely long text at idx={i}")
                    continue
                
                self.pairs.append((src, tgt))
                
            except Exception as e:
                invalid += 1
                cell2_dbg("init_exc", f"Init pair exception idx={i}: {type(e).__name__}")
        
        print(f"[CELL2] Dataset initialized: {len(self.pairs)} valid pairs, {invalid} invalid pairs filtered")

        # Get special tokens for filtering
        try:
            if 'get_special_tokens' in globals():
                self.special_tokens = get_special_tokens(self.tokenizer)
            elif 'get_tokenizer_special_tokens' in globals():
                self.special_tokens = get_tokenizer_special_tokens(self.tokenizer)
            else:
                self.special_tokens = set(getattr(self.tokenizer, "all_special_tokens", [])) if self.tokenizer is not None else set()
        except Exception:
            self.special_tokens = {_BN_LANG, _EN_LANG, "</s>", "<pad>", "<s>", "<unk>"}
            cell2_dbg("special_tokens_fallback", "Used explicit fallback special tokens")

    def __getstate__(self):
        """Prepare state for pickling (exclude tokenizer)."""
        state = self.__dict__.copy()
        # avoid serializing tokenizer into worker processes
        state['tokenizer'] = None
        state['_tokenizer_name_or_path'] = getattr(self, "_tokenizer_name_or_path", None)
        return state

    def __setstate__(self, state):
        """Restore state after unpickling (rebind tokenizer)."""
        self.__dict__.update(state)
        try:
            # rebind tokenizer from global if available (set by worker_init_fn)
            self.tokenizer = globals().get('tokenizer', None)
            self.is_fast = getattr(self.tokenizer, "is_fast", False) if self.tokenizer is not None else False
        except Exception:
            self.tokenizer = None
            self.is_fast = False

    def __len__(self) -> int:
        return len(self.pairs)

    def _encode_src(self, src_text: str):
        """Encode source (Bengali) text."""
        src_text = src_text if isinstance(src_text, str) else str(src_text)
        
        try:
            # Ensure tokenizer is available
            if self.tokenizer is None:
                try:
                    self.tokenizer = globals().get('tokenizer', None)
                    self.is_fast = getattr(self.tokenizer, "is_fast", False) if self.tokenizer is not None else False
                except Exception:
                    self.tokenizer = None
                    self.is_fast = False

            if self.tokenizer is None:
                raise RuntimeError("Tokenizer not available")

            # Set source language hints if tokenizer supports it
            try:
                if hasattr(self.tokenizer, "src_lang"):
                    self.tokenizer.src_lang = _BN_LANG
            except Exception:
                pass

            # Prefer safe_offsets_tokenize if available
            if _has_safe_offsets_tokenize:
                enc = safe_offsets_tokenize(self.tokenizer, src_text, max_length=self.max_length)
                try:
                    input_ids = enc["input_ids"].squeeze(0) if isinstance(enc["input_ids"], torch.Tensor) else torch.tensor(enc["input_ids"][0])
                except Exception:
                    input_ids = torch.tensor(enc.get("input_ids", [[1]])[0])
                
                attention_mask = enc.get("attention_mask", torch.ones_like(input_ids))
                if isinstance(attention_mask, list):
                    attention_mask = torch.tensor(attention_mask[0]) if attention_mask else torch.ones_like(input_ids)
                
                try:
                    ids_list = input_ids.tolist() if isinstance(input_ids, torch.Tensor) else list(input_ids)
                    tokens = self.tokenizer.convert_ids_to_tokens(ids_list)
                except Exception:
                    tokens = []
            else:
                # Standard tokenization
                enc = self.tokenizer(
                    src_text,
                    max_length=self.max_length,
                    padding="max_length",
                    truncation=True,
                    return_tensors="pt",
                    add_special_tokens=False
                )
                input_ids = enc["input_ids"].squeeze(0)
                attention_mask = enc.get("attention_mask", torch.ones_like(input_ids)).squeeze(0)
                try:
                    tokens = self.tokenizer.convert_ids_to_tokens(input_ids.tolist())
                except Exception:
                    tokens = []

            # Build token-word mapping
            token_word_map: Dict[int, str] = {}
            if _has_reconstruct_word_spans:
                try:
                    wm, words = reconstruct_word_spans(self.tokenizer, src_text, max_length=self.max_length)
                    if isinstance(wm, dict):
                        token_word_map = wm
                except Exception:
                    cell2_dbg("wm_exc", f"reconstruct_word_spans failed: {traceback.format_exc().splitlines()[-1]}")
                    token_word_map = {}
            else:
                # Fallback: mark tokens starting with SPM markers as word starts
                try:
                    for idx, tok in enumerate(tokens):
                        if isinstance(tok, str) and (tok.startswith("‚ñÅ") or tok.startswith("ƒ†")):
                            token_word_map[idx] = tok.replace("‚ñÅ", "").replace("ƒ†", "").strip()
                except Exception:
                    token_word_map = {}

            return input_ids, attention_mask, tokens, token_word_map
            
        except Exception as e:
            cell2_dbg("encode_src_exc", f"Encoding source failed: {type(e).__name__}: {str(e)[:60]}")
            # Return safe placeholder
            pad_id = getattr(self.tokenizer, "pad_token_id", 1) if self.tokenizer else 1
            input_ids = torch.full((self.max_length,), int(pad_id), dtype=torch.long)
            attention_mask = torch.zeros(self.max_length, dtype=torch.long)
            return input_ids, attention_mask, [], {}

    def _encode_tgt(self, tgt_text: str):
        """Encode target (English) text."""
        tgt_text = tgt_text if isinstance(tgt_text, str) else str(tgt_text)
        
        try:
            if self.tokenizer is None:
                self.tokenizer = globals().get('tokenizer', None)
            
            if self.tokenizer is None:
                raise RuntimeError("Tokenizer not available")
            
            # Set target language hints where supported
            try:
                if hasattr(self.tokenizer, "tgt_lang"):
                    self.tokenizer.tgt_lang = _EN_LANG
            except Exception:
                pass
            
            dec = self.tokenizer(
                tgt_text,
                max_length=self.max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
                add_special_tokens=False
            )
            labels = dec["input_ids"].squeeze(0)
            
            # Replace pad tokens with -100 (ignore index for loss)
            pad_id = getattr(self.tokenizer, "pad_token_id", 1) if self.tokenizer else 1
            labels[labels == int(pad_id)] = -100
            
            return labels
            
        except Exception as e:
            cell2_dbg("encode_tgt_exc", f"Encoding tgt failed: {type(e).__name__}: {str(e)[:60]}")
            return torch.full((self.max_length,), -100, dtype=torch.long)

    def _make_safe_sample(self, reason: str = "fallback"):
        """Create a safe fallback sample."""
        try:
            src = "‡¶Ü‡¶Æ‡¶ø"
            tgt = "i"
            input_ids, attention_mask, tokens, token_word_map = self._encode_src(src)
            labels = self._encode_tgt(tgt)
            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
                "token_word_map": token_word_map,
                "src_text": src,
                "tokens": tokens
            }
        except Exception:
            pad_id = 1
            return {
                "input_ids": torch.full((self.max_length,), int(pad_id), dtype=torch.long),
                "attention_mask": torch.zeros(self.max_length, dtype=torch.long),
                "labels": torch.full((self.max_length,), -100, dtype=torch.long),
                "token_word_map": {},
                "src_text": "",
                "tokens": []
            }

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Get a single sample by index."""
        try:
            if idx < 0 or idx >= len(self.pairs):
                cell2_dbg("getitem_oob", f"Index out of range idx={idx} len={len(self.pairs)}")
                return self._make_safe_sample("oob")
            
            src, tgt = self.pairs[idx]
            
            if not isinstance(src, str) or not isinstance(tgt, str):
                cell2_dbg("getitem_bad_types", f"Bad types at idx={idx}")
                return self._make_safe_sample("bad_types")

            input_ids, attention_mask, tokens, token_word_map = self._encode_src(src)
            labels = self._encode_tgt(tgt)

            return {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
                "labels": labels,
                "token_word_map": token_word_map,
                "src_text": src,
                "tokens": tokens
            }
        except Exception as e:
            cell2_dbg("getitem_exc", f"Unhandled __getitem__ exception idx={idx}: {type(e).__name__}")
            return self._make_safe_sample("unhandled")


# ---------------------------
# Collation and DataLoader helpers
# ---------------------------
def _infer_pad_id_from_sample(sample: Dict[str, Any], default_pad_id: int = 1) -> int:
    """Infer pad token id from tokenizer."""
    try:
        tk = globals().get("tokenizer", None)
        if tk is not None:
            pad = getattr(tk, "pad_token_id", None)
            if pad is not None:
                return int(pad)
    except Exception:
        cell2_dbg("infer_pad_exc", "infer pad id failed")
    return int(default_pad_id)


def _pad_or_truncate_array(tensor: torch.Tensor, length: int, pad_value: int) -> torch.Tensor:
    """Pad or truncate tensor to exact length."""
    if tensor is None:
        return torch.full((length,), int(pad_value), dtype=torch.long)
    
    t = tensor.view(-1).long()
    L = t.size(0)
    
    if L == length:
        return t
    if L < length:
        pad = torch.full((length - L,), int(pad_value), dtype=t.dtype)
        return torch.cat([t, pad], dim=0)
    return t[:length]


def safe_collate(batch: List[Dict[str, Any]]) -> Dict[str, Any]:
    """
    Robust collate: ensures stackable tensors and safe structure.
    Pads/truncates all sequences to _MAX_LENGTH deterministically.
    """
    valid = [b for b in batch if isinstance(b, dict) and "input_ids" in b and isinstance(b["input_ids"], torch.Tensor)]
    
    if not valid:
        pad = _infer_pad_id_from_sample({}, default_pad_id=1)
        return {
            "input_ids": torch.full((1, _MAX_LENGTH), pad, dtype=torch.long),
            "attention_mask": torch.zeros(1, _MAX_LENGTH, dtype=torch.long),
            "labels": torch.full((1, _MAX_LENGTH), -100, dtype=torch.long),
            "token_word_map": [{}],
            "src_text": [""],
            "tokens": [[]]
        }

    pad_id = _infer_pad_id_from_sample(valid[0], default_pad_id=1)

    inputs, masks, labs, twmaps, srcs, toks = [], [], [], [], [], []
    
    for i, s in enumerate(valid):
        try:
            in_ids = s["input_ids"]
            att = s.get("attention_mask", None)
            lab = s["labels"]

            if att is None:
                att = (in_ids != pad_id).long()
            else:
                try:
                    att = att.view(-1).long()
                except Exception:
                    att = (in_ids != pad_id).long()

            try:
                in_ids = in_ids.view(-1)
            except Exception:
                in_ids = in_ids.flatten()
            try:
                lab = lab.view(-1)
            except Exception:
                lab = lab.flatten()

            in_ids = _pad_or_truncate_array(in_ids, _MAX_LENGTH, pad_id)
            att = _pad_or_truncate_array(att, _MAX_LENGTH, 0)
            lab = _pad_or_truncate_array(lab, _MAX_LENGTH, -100)

            inputs.append(in_ids)
            masks.append(att)
            labs.append(lab)
            twmaps.append(s.get("token_word_map", {}))
            srcs.append(s.get("src_text", ""))
            toks.append(s.get("tokens", []))
        except Exception as e:
            cell2_dbg("collate_item_exc", f"Collate item exception idx={i}: {type(e).__name__}")
            continue

    if not inputs:
        pad = _infer_pad_id_from_sample({}, default_pad_id=1)
        return {
            "input_ids": torch.full((1, _MAX_LENGTH), pad, dtype=torch.long),
            "attention_mask": torch.zeros(1, _MAX_LENGTH, dtype=torch.long),
            "labels": torch.full((1, _MAX_LENGTH), -100, dtype=torch.long),
            "token_word_map": [{}],
            "src_text": [""],
            "tokens": [[]]
        }

    input_ids = torch.stack(inputs, dim=0)
    attention_mask = torch.stack(masks, dim=0)
    labels = torch.stack(labs, dim=0)

    # DP-divisible adjustment: trim downward to nearest multiple to avoid OOM
    if _USE_MULTI_GPU and _NUM_GPUS > 0:
        bsz = input_ids.size(0)
        keep = (bsz // _NUM_GPUS) * _NUM_GPUS
        if keep > 0 and keep < bsz:
            cell2_dbg("dp_trunc", f"DP truncate from {bsz} to {keep}")
            input_ids = input_ids[:keep]
            attention_mask = attention_mask[:keep]
            labels = labels[:keep]
            twmaps = twmaps[:keep]
            srcs = srcs[:keep]
            toks = toks[:keep]

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "token_word_map": twmaps,
        "src_text": srcs,
        "tokens": toks
    }


def create_optimized_dataloader(dataset: Dataset, batch_size: Optional[int] = None, shuffle: bool = True) -> DataLoader:
    """
    Create a DataLoader with safe defaults and deterministic worker init.
    By default, if _USE_MULTI_GPU the batch_size will be floored to nearest multiple of _NUM_GPUS
    to avoid oversubscribing GPU memory.
    """
    if batch_size is None:
        try:
            batch_size = int(BATCH_SIZE)
        except NameError:
            batch_size = 8
    batch_size = int(batch_size)

    # Floor to nearest multiple for multi-GPU
    adjust_upwards = False  # change to True if you prefer increasing to next multiple

    if _USE_MULTI_GPU and _NUM_GPUS > 0 and batch_size % _NUM_GPUS != 0:
        if adjust_upwards:
            adjusted = ((batch_size + _NUM_GPUS - 1) // _NUM_GPUS) * _NUM_GPUS
            print(f"[CELL2] Adjusting batch size {batch_size} ‚Üí {adjusted} to be DP-divisible (GPUs={_NUM_GPUS})")
            batch_size = adjusted
        else:
            adjusted = (batch_size // _NUM_GPUS) * _NUM_GPUS
            if adjusted == 0:
                print(f"[CELL2] WARNING: batch_size {batch_size} < num_gpus {_NUM_GPUS}. Keeping original batch_size.")
            else:
                print(f"[CELL2] Adjusting batch size {batch_size} ‚Üí {adjusted} (floor to multiple of {_NUM_GPUS}) to avoid OOM.")
                batch_size = adjusted

    # Validate num_workers
    num_workers = _NUM_WORKERS if isinstance(_NUM_WORKERS, int) and _NUM_WORKERS >= 0 else 0
    try:
        max_possible = max(0, (os.cpu_count() or 1) - 1)
        if num_workers > max_possible:
            num_workers = max_possible
    except Exception:
        pass

    loader_kwargs = {
        "dataset": dataset,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": bool(_PIN_MEMORY and torch.cuda.is_available()),
        "collate_fn": safe_collate,
        "drop_last": False,
    }
    
    # Only set worker_init_fn if using workers
    if num_workers > 0:
        loader_kwargs["worker_init_fn"] = _dataloader_worker_init_fn
        loader_kwargs["prefetch_factor"] = _PREFETCH_FACTOR
        loader_kwargs["persistent_workers"] = False

    try:
        dataloader = DataLoader(**loader_kwargs)
    except Exception as e:
        print(f"[CELL2] DataLoader init failed with num_workers={num_workers}: {type(e).__name__}: {str(e)[:200]}")
        print("[CELL2] Retrying with num_workers=0")
        loader_kwargs["num_workers"] = 0
        loader_kwargs.pop("prefetch_factor", None)
        loader_kwargs.pop("persistent_workers", None)
        loader_kwargs.pop("worker_init_fn", None)
        dataloader = DataLoader(**loader_kwargs)

    if _USE_MULTI_GPU and _NUM_GPUS > 0:
        per_gpu = batch_size // _NUM_GPUS if _NUM_GPUS > 0 else batch_size
        print(f"[CELL2] DataLoader created: total_batch={batch_size}, per_gpu={per_gpu}, workers={loader_kwargs.get('num_workers', 0)}")
    else:
        print(f"[CELL2] DataLoader created: batch_size={batch_size}, workers={loader_kwargs.get('num_workers', 0)}")

    return dataloader


print("‚úÖ Cell 2: Memory-efficient data loading ready (FIXED: CSV support + hardened error handling)")


‚úÖ Cell 2: Memory-efficient data loading ready (FIXED: CSV support + hardened error handling)


In [6]:
# ==============================================================================
# CELL 3 (ENHANCED): DSCD WITH HIERARCHICAL CLUSTERING + KMEANS FALLBACK + WORD-KEYS
# DEBUGGED, HARDENED, and SELF-CONTAINED
# ==============================================================================
# Fixes applied (high level):
# - Robust global config fetching via globals().get with sensible defaults.
# - Added safe helper get_special_tokens(tokenizer) fallback (avoids reliance on external Cell 0 helper).
# - Rewrote numpy KMeans seeding & iterations to be robust and vectorized.
# - Defensive handling of tensor <-> numpy conversions to avoid device/stride errors.
# - Ensure clustering uses CPU numpy arrays only; centroids stored as CPU tensors.
# - Added explicit checks and guarding for empty buffers and shapes.
# - Minor readability / stability improvements and additional logging under VERBOSE_LOGGING.
# ==============================================================================

import threading
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import gc
from collections import deque
import unicodedata
from typing import Optional, List, Tuple

# -------------------------
# Config / feature detection
# -------------------------
PRINT_INTERVAL = int(globals().get("PRINT_INTERVAL", 200))

# SciPy hierarchical clustering (optional)
try:
    from scipy.cluster.hierarchy import linkage, fcluster
    from scipy.spatial.distance import pdist
    HAS_CLUSTERING = True
except Exception:
    HAS_CLUSTERING = False
    print("[CELL3] WARNING: scipy not available - hierarchical clustering disabled")

# sklearn KMeans (optional)
try:
    from sklearn.cluster import KMeans
    HAS_KMEANS = True
except Exception:
    HAS_KMEANS = False
    print("[CELL3] WARNING: sklearn not available - KMeans fallback disabled")

# DSCD config with safe defaults
DSCD_MAX_PROTOS = int(globals().get("DSCD_MAX_PROTOS", 8))
DSCD_BUFFER_SIZE = int(globals().get("DSCD_BUFFER_SIZE", 20))
DSCD_N_MIN = int(globals().get("DSCD_N_MIN", 5))
DSCD_DISPERSION_THRESHOLD = float(globals().get("DSCD_DISPERSION_THRESHOLD", 0.25))
VERBOSE_LOGGING = bool(globals().get("VERBOSE_LOGGING", False))
HOMOGRAPH_WATCHLIST_BN = set(globals().get("HOMOGRAPH_WATCHLIST_BN",
                                         {"‡¶ï‡¶≤", "‡¶ï‡¶æ‡¶≤", "‡¶™‡¶æ‡¶§‡¶æ", "‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï", "‡¶´‡¶≤", "‡¶Æ‡¶æ‡¶•‡¶æ"}))
DSCD_MAX_CLUSTERING_POINTS = int(globals().get("DSCD_MAX_CLUSTERING_POINTS", 2000))

# small deny-prefix set for combining/vowel marks (avoid clustering noise)
DSCD_TOKEN_DENY_PREFIXES = set(['‡ßç', '‡¶ø', '‡ßá', '‡¶æ', '‡ßÄ', '‡ßÅ', '‡ßÇ', '‡ßó', '‡ßç‡¶∞', '‡ßé', '‡¶Å'])

# -------------------------
# Helper: safe special tokens extractor (fallbacks)
# -------------------------
def get_special_tokens_safe(tok):
    """
    Return a set of special tokens from tokenizer in a robust way.
    """
    if tok is None:
        return set()
    try:
        s = getattr(tok, "all_special_tokens", None)
        if s:
            return set(s)
    except Exception:
        pass
    try:
        # HF tokenizer mapping
        stm = getattr(tok, "special_tokens_map", None)
        if isinstance(stm, dict):
            vals = [v for v in stm.values() if isinstance(v, str)]
            return set(vals)
    except Exception:
        pass
    try:
        # fallback to keys of special_tokens_map_extended or similar
        stmap = getattr(tok, "additional_special_tokens", None)
        if stmap:
            return set(stmap)
    except Exception:
        pass
    # last resort: look for attributes that may contain token ids/names
    out = set()
    for attr in ("bos_token", "eos_token", "pad_token", "unk_token", "sep_token", "cls_token", "mask_token"):
        try:
            v = getattr(tok, attr, None)
            if isinstance(v, str):
                out.add(v)
        except Exception:
            pass
    return out

# -------------------------
# Token predicate
# -------------------------
def is_word_token(token: str, min_letters: int = 2, min_letter_fraction: float = 0.6) -> bool:
    """
    Unicode-aware test whether a token is likely a real word to track.
    """
    if not token or not isinstance(token, str):
        return False
    token = token.strip()
    if token == "":
        return False

    letters = 0
    total = 0
    for ch in token:
        cat = unicodedata.category(ch)
        if cat.startswith("L"):
            letters += 1
        if not ch.isspace():
            total += 1

    if total == 0:
        return False
    if letters < min_letters:
        return False
    if (letters / total) < min_letter_fraction:
        return False
    return True

# -------------------------
# Robust small numpy KMeans fallback
# -------------------------
def _numpy_kmeans(X: np.ndarray, n_clusters: int, n_iter: int = 10, random_state: int = 0) -> Tuple[np.ndarray, np.ndarray]:
    """
    Simple, robust KMeans implemented with numpy.
    - X: (N, D)
    - returns: labels (N,), centroids (n_clusters, D)
    This implementation uses random initialization with a KMeans++-like heuristic
    (choose first centroid randomly, subsequent centroids by distance weighting).
    """
    X = np.asarray(X, dtype=np.float32)
    N, D = X.shape
    if N == 0:
        return np.zeros((0,), dtype=np.int32), np.zeros((0, D), dtype=np.float32)
    n_clusters = int(max(1, min(n_clusters, N)))
    rng = np.random.RandomState(random_state)

    # KMeans++ style initialization
    centroids = np.empty((n_clusters, D), dtype=np.float32)
    first_idx = rng.randint(0, N)
    centroids[0] = X[first_idx]
    for k in range(1, n_clusters):
        # compute distance to nearest existing centroid
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :k, :], axis=2)  # (N, k)
        nearest = dists.min(axis=1)  # (N,)
        probs = nearest / (nearest.sum() + 1e-12)
        chosen = rng.choice(N, p=probs)
        centroids[k] = X[chosen]

    labels = np.zeros(N, dtype=np.int32)
    for it in range(n_iter):
        # assign
        dists = np.linalg.norm(X[:, None, :] - centroids[None, :, :], axis=2)  # (N, k)
        new_labels = dists.argmin(axis=1)
        # update
        changed = False
        for j in range(n_clusters):
            members = (new_labels == j)
            if members.sum() == 0:
                # reinitialize empty centroid
                centroids[j] = X[rng.randint(0, N)]
                changed = True
            else:
                new_cent = X[members].mean(axis=0)
                if not np.allclose(new_cent, centroids[j], atol=1e-6):
                    centroids[j] = new_cent.astype(np.float32)
                    changed = True
        labels = new_labels
        if not changed:
            break
    return labels, centroids

# -------------------------
# Prototype store (CPU)
# -------------------------
class MemoryEfficientPrototypeStore:
    def __init__(self, embed_dim: int, max_protos: Optional[int] = None):
        self.embed_dim = int(embed_dim)
        self.max_protos = int(max_protos) if max_protos is not None else DSCD_MAX_PROTOS
        self.centroids: List[torch.Tensor] = []   # CPU tensors
        self.counts: List[int] = []
        self.creation_time: List[float] = []
        self.distances: List[float] = []
        self.mu: float = 0.0
        self.tau: float = 1e-6
        self.alpha: float = 0.1

    def add_prototype(self, vector, current_time=None, count=1):
        if current_time is None:
            current_time = time.time()
        try:
            if isinstance(vector, torch.Tensor):
                v = vector.detach().cpu().float().clone()
            else:
                v = torch.from_numpy(np.asarray(vector, dtype=np.float32)).cpu()
        except Exception:
            return
        if len(self.centroids) < self.max_protos:
            self.centroids.append(v)
            self.counts.append(int(count))
            self.creation_time.append(current_time)
            return
        # replace least-supported prototype
        try:
            min_idx = int(np.argmin(self.counts)) if self.counts else 0
        except Exception:
            min_idx = 0
        min_idx = max(0, min_idx)
        if min_idx < len(self.centroids):
            self.centroids[min_idx] = v
            self.counts[min_idx] = int(count)
            self.creation_time[min_idx] = current_time
        else:
            # pad lists
            while len(self.centroids) <= min_idx:
                self.centroids.append(v.clone())
                self.counts.append(int(count))
                self.creation_time.append(current_time)

    def update_prototype(self, idx, vector, eta=0.05, assignment_distance=None):
        try:
            if idx < 0 or idx >= len(self.centroids):
                self.add_prototype(vector, time.time(), count=1)
                return
            old = self.centroids[idx]
            newv = vector.detach().cpu() if isinstance(vector, torch.Tensor) else torch.from_numpy(np.asarray(vector, dtype=np.float32)).cpu()
            try:
                self.centroids[idx] = (1.0 - eta) * old + eta * newv
            except Exception:
                self.centroids[idx] = newv.clone()
            try:
                self.counts[idx] = int(self.counts[idx]) + 1
            except Exception:
                while len(self.counts) < len(self.centroids):
                    self.counts.append(1)
                self.counts[idx] = int(self.counts[idx]) + 1
        except Exception:
            try:
                self.add_prototype(vector, time.time(), count=1)
            except Exception:
                pass
        if assignment_distance is not None:
            try:
                self.update_rolling_stats(float(assignment_distance))
            except Exception:
                pass

    def update_rolling_stats(self, d: float):
        try:
            if not self.distances:
                self.mu = float(d)
                self.tau = 1e-6
                self.distances = [float(d)]
                return
            prev_mu = self.mu
            self.mu = (1 - self.alpha) * self.mu + self.alpha * float(d)
            self.tau = (1 - self.alpha) * self.tau + self.alpha * abs(float(d) - prev_mu)
            self.distances.append(float(d))
            if len(self.distances) > 50:
                self.distances.pop(0)
        except Exception:
            pass

    def get_adaptive_threshold(self, lam=1.0) -> float:
        try:
            return float(self.mu + lam * self.tau)
        except Exception:
            return float(self.mu)

    def get_centroids(self, device=torch.device("cpu")) -> Optional[torch.Tensor]:
        if not self.centroids:
            return None
        try:
            return torch.stack([c.to(device) for c in self.centroids], dim=0)
        except Exception:
            try:
                return torch.stack([c.cpu() for c in self.centroids], dim=0).to(device)
            except Exception:
                return None

    def get_valid_centroids(self, device=torch.device("cpu"), min_count=None):
        if min_count is None:
            min_count = DSCD_N_MIN
        idxs = [i for i, ct in enumerate(self.counts) if ct >= int(min_count)]
        if not idxs:
            return None, None
        cents = [self.centroids[i].to(device) for i in idxs]
        return torch.stack(cents, dim=0), idxs

    def set_centroids_from_arrays(self, array_list, counts=None):
        try:
            self.centroids = [torch.from_numpy(np.asarray(a, dtype=np.float32)).cpu() for a in array_list]
            if counts and len(counts) == len(array_list):
                self.counts = [int(c) for c in counts]
            else:
                self.counts = [1 for _ in array_list]
            self.creation_time = [time.time()] * len(array_list)
        except Exception:
            self.centroids = []
            self.counts = []
            self.creation_time = []

    def size(self) -> int:
        return len(self.centroids)

# -------------------------
# DSCD Online Module
# -------------------------
class MemoryEfficientDSCDOnline(nn.Module):
    def __init__(self, embed_dim, tokenizer=None, buffer_size=None, max_protos=None,
                 n_min=None, dispersion_threshold=None, language='bn',
                 enable_training_clustering=False, max_clustering_points=None,
                 max_candidates_per_step=2, dscd_min_letters: int = 2,
                 dscd_min_letter_fraction: float = 0.6):
        super().__init__()
        self.embed_dim = int(embed_dim)
        self.buffer_size = int(buffer_size) if buffer_size is not None else DSCD_BUFFER_SIZE
        self.max_protos = int(max_protos) if max_protos is not None else DSCD_MAX_PROTOS
        self.n_min = int(n_min) if n_min is not None else DSCD_N_MIN
        self.dispersion_threshold = float(dispersion_threshold) if dispersion_threshold is not None else DSCD_DISPERSION_THRESHOLD
        self.language = language
        self.tokenizer = tokenizer
        self.dscd_min_letters = int(dscd_min_letters)
        self.dscd_min_letter_fraction = float(dscd_min_letter_fraction)

        # special tokens
        try:
            if tokenizer is not None:
                self.special_tokens = get_special_tokens_safe(tokenizer)
            else:
                self.special_tokens = set()
        except Exception:
            self.special_tokens = set()

        self._dscd_allowed_tokens = set()
        self._dscd_ignored_tokens = set()

        self.prototype_stores = {}
        self.buffers = {}
        self.discovery_log = []
        self.last_periodic_check = 0
        self.cleanup_counter = 0
        self.clustering_lock = threading.Lock()

        self.last_cluster_time = {}
        self.cluster_cooldown_seconds = 60
        self.enable_training_clustering = bool(enable_training_clustering)

        # small heads kept for compatibility
        self.span_head = nn.Sequential(
            nn.Linear(self.embed_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 1)
        )
        self.sigma_net = nn.Sequential(
            nn.Linear(self.embed_dim, 16),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(16, 1)
        )
        self.gate_w = nn.Parameter(torch.tensor(1.0))
        self.gate_b = nn.Parameter(torch.tensor(0.4))
        self.gamma = nn.Parameter(torch.tensor(0.3))

        self.max_clustering_points = int(max_clustering_points) if max_clustering_points is not None else DSCD_MAX_CLUSTERING_POINTS
        self.max_candidates_per_step = int(max_candidates_per_step)

        if VERBOSE_LOGGING:
            print(f"[DSCD-INIT] embed_dim={self.embed_dim}, buffer_size={self.buffer_size}, max_protos={self.max_protos}, n_min={self.n_min}")
            print(f"[DSCD-INIT] dispersion_threshold={self.dispersion_threshold}, max_clustering_points={self.max_clustering_points}")

    # ------------------------
    def should_track_token(self, token_text: str) -> bool:
        if not token_text or not isinstance(token_text, str):
            return False
        if token_text in self._dscd_allowed_tokens:
            return True
        if token_text in self._dscd_ignored_tokens:
            return False

        try:
            clean = token_text.replace('‚ñÅ', '').replace('ƒ†', '').replace('##', '').replace('@@', '').strip()
        except Exception:
            clean = token_text.strip()

        try:
            if len(clean) <= 2 and any(clean.startswith(p) for p in DSCD_TOKEN_DENY_PREFIXES):
                self._dscd_ignored_tokens.add(token_text)
                return False
        except Exception:
            pass

        try:
            if clean in HOMOGRAPH_WATCHLIST_BN:
                self._dscd_allowed_tokens.add(token_text)
                if VERBOSE_LOGGING and len(self._dscd_allowed_tokens) <= 20:
                    print(f"[DSCD] ‚úÖ Homograph watchlist token tracked: '{clean}'")
                return True
        except Exception:
            pass

        if token_text in self.special_tokens:
            self._dscd_ignored_tokens.add(token_text)
            return False

        if clean == "":
            self._dscd_ignored_tokens.add(token_text)
            return False

        if len(clean) < 2:
            self._dscd_ignored_tokens.add(token_text)
            return False

        if not any(c.isalpha() for c in clean):
            self._dscd_ignored_tokens.add(token_text)
            return False

        if clean.isdigit():
            self._dscd_ignored_tokens.add(token_text)
            return False

        if all(c in '.,!?;:()[]{}"\'-‚Äî‚Äì/\\' for c in clean):
            self._dscd_ignored_tokens.add(token_text)
            return False

        try:
            bengali_block = any('\u0980' <= c <= '\u09FF' for c in clean)
            if bengali_block and len(clean) >= 2:
                self._dscd_allowed_tokens.add(token_text)
                return True
        except Exception:
            pass

        if is_word_token(clean, min_letters=self.dscd_min_letters, min_letter_fraction=self.dscd_min_letter_fraction):
            self._dscd_allowed_tokens.add(token_text)
            return True

        self._dscd_ignored_tokens.add(token_text)
        return False

    def _canonical_token_key(self, raw_token: str, token_word_map: Optional[dict], idx: int) -> str:
        canonical = None
        try:
            if token_word_map and isinstance(token_word_map, dict) and idx in token_word_map and token_word_map[idx]:
                canonical = str(token_word_map[idx]).strip()
        except Exception:
            canonical = None
        if not canonical:
            try:
                canonical = raw_token.replace('‚ñÅ', '').replace('ƒ†', '').replace('##', '').replace('@@', '').strip()
            except Exception:
                canonical = raw_token
        if not canonical:
            canonical = raw_token
        return canonical

    def forward(self, token_embeddings, token_types=None, train_mode=True,
                token_word_map=None, h_all=None, input_ids=None, attention_mask=None):
        if token_embeddings is None and h_all is not None:
            token_embeddings = h_all
        if token_embeddings is None:
            raise ValueError("MemoryEfficientDSCDOnline.forward requires token_embeddings or h_all")

        # build token_types if missing
        if input_ids is not None and token_types is None:
            try:
                batch_size, seq_len = input_ids.shape
            except Exception:
                batch_size = int(token_embeddings.size(0))
                seq_len = int(token_embeddings.size(1))
            token_types = []
            for b in range(batch_size):
                if self.tokenizer is not None:
                    try:
                        token_types.append(self.tokenizer.convert_ids_to_tokens(input_ids[b].tolist()))
                    except Exception:
                        token_types.append([f'tok_{i}' for i in range(seq_len)])
                else:
                    token_types.append([f'tok_{i}' for i in range(seq_len)])

        self.cleanup_counter += 1
        if self.cleanup_counter % 50 == 0:
            self.cleanup_counter = 0
            self.cleanup_memory()

        device = token_embeddings.device
        batch_size = int(token_embeddings.size(0))
        seq_len = int(token_embeddings.size(1))

        all_outputs = {
            'proto_assignments': [],
            'proto_probs': [],
            'uncertainties': [],
            'span_preds': [],
            'gates': [],
            'h_augmented': []
        }

        for b in range(batch_size):
            word_map = token_word_map[b] if token_word_map and len(token_word_map) > b else None
            tt = token_types[b] if token_types and len(token_types) > b else [f'tok_{i}' for i in range(seq_len)]
            batch_outputs = self.process_sequence(
                token_embeddings[b],
                tt,
                device,
                word_map=word_map,
                train_mode=train_mode
            )
            for k in all_outputs:
                all_outputs[k].append(batch_outputs[k])

        # assemble h_augmented into tensor (batch, seq_len, embed_dim) where possible
        try:
            h_aug_list = []
            max_seq_len = seq_len
            for b in range(batch_size):
                h_batch_list = all_outputs['h_augmented'][b]
                if len(h_batch_list) > 0 and isinstance(h_batch_list[0], torch.Tensor):
                    h_batch = torch.stack(h_batch_list, dim=0)
                    if h_batch.size(0) < max_seq_len:
                        pad = max_seq_len - h_batch.size(0)
                        h_batch = F.pad(h_batch, (0, 0, 0, pad), value=0)
                    elif h_batch.size(0) > max_seq_len:
                        h_batch = h_batch[:max_seq_len]
                else:
                    h_batch = torch.zeros(max_seq_len, self.embed_dim, device=device)
                h_aug_list.append(h_batch)
            all_outputs['h_augmented'] = torch.stack(h_aug_list, dim=0)
        except Exception:
            all_outputs['h_augmented'] = token_embeddings

        return all_outputs

    def process_sequence(self, token_embeddings, token_types, device, word_map=None, train_mode=True):
        seq_len = int(token_embeddings.size(0))
        outputs = {
            'proto_assignments': [],
            'proto_probs': [],
            'uncertainties': [],
            'span_preds': [],
            'gates': [],
            'h_augmented': []
        }

        for j in range(seq_len):
            raw_tok = token_types[j] if j < len(token_types) else f'tok_{j}'
            token_key = self._canonical_token_key(raw_tok, word_map, j)
            h_j = token_embeddings[j]

            if not self.should_track_token(token_key):
                outputs['proto_assignments'].append(torch.tensor(-1))
                outputs['proto_probs'].append([])
                outputs['uncertainties'].append(0.0)
                outputs['span_preds'].append(0.0)
                outputs['gates'].append(0.0)
                outputs['h_augmented'].append(h_j)
                continue

            if token_key not in self.buffers:
                self.buffers[token_key] = deque(maxlen=self.buffer_size)
                self.prototype_stores[token_key] = MemoryEfficientPrototypeStore(self.embed_dim, self.max_protos)

            try:
                self.buffers[token_key].append(h_j.detach().cpu())
            except Exception:
                try:
                    self.buffers[token_key].append(h_j.cpu())
                except Exception:
                    pass

            # background clustering trigger
            try:
                if self.enable_training_clustering and len(self.buffers[token_key]) >= max(self.n_min, 4):
                    now = time.time()
                    last_t = self.last_cluster_time.get(token_key, 0.0)
                    if now - last_t > self.cluster_cooldown_seconds:
                        self.last_cluster_time[token_key] = now
                        def _bg_cluster(tok=token_key):
                            try:
                                with self.clustering_lock:
                                    self._cluster_buffer_to_prototypes_hierarchical(tok)
                            except Exception:
                                if VERBOSE_LOGGING:
                                    import traceback as _tb
                                    print(f"[DSCD] Background clustering error for token '{tok}': {_tb.format_exc().splitlines()[-1]}")
                        th = threading.Thread(target=_bg_cluster, daemon=True)
                        th.start()
            except Exception:
                if VERBOSE_LOGGING:
                    import traceback as _tb
                    print(f"[DSCD] Failed to trigger background clustering for token {token_key}: {_tb.format_exc().splitlines()[-1]}")

            store = self.prototype_stores[token_key]

            # atomic centroid snapshot
            centroids_snapshot = []
            with self.clustering_lock:
                try:
                    for c in getattr(store, "centroids", []):
                        if isinstance(c, torch.Tensor):
                            centroids_snapshot.append(c.clone().cpu())
                        else:
                            centroids_snapshot.append(torch.from_numpy(np.asarray(c)).cpu())
                except Exception:
                    centroids_snapshot = []

            assignment = -1
            prob_list = []
            uncertainty = 0.0
            span_pred = 0.0
            gate_val = 0.0
            h_aug = h_j

            if centroids_snapshot and len(centroids_snapshot) >= 1:
                try:
                    h_cpu = h_j.detach().cpu().numpy()
                    cents_np = np.stack([c.numpy() for c in centroids_snapshot], axis=0).astype(np.float32)  # (K, D)
                    dists_np = np.linalg.norm(cents_np - h_cpu[None, :], axis=1)  # (K,)
                    if dists_np.size > 0:
                        assignment = int(np.argmin(dists_np))
                        min_dist = float(dists_np[assignment])
                        try:
                            store.update_rolling_stats(min_dist)
                        except Exception:
                            pass

                        # softmax over negative distances -> probabilities
                        try:
                            neg = -dists_np
                            exps = np.exp(neg - np.max(neg))
                            probs = exps / (exps.sum() + 1e-12)
                            prob_list = probs.tolist()
                            uncertainty = 1.0 - float(np.max(probs))
                        except Exception:
                            prob_list = []
                            uncertainty = 0.0

                        try:
                            span_pred = float(torch.sigmoid(self.span_head(h_j)).item())
                        except Exception:
                            try:
                                span_pred = float(torch.sigmoid(self.span_head(h_j.cpu())).item())
                            except Exception:
                                span_pred = 0.0

                        try:
                            gate_val = float(torch.sigmoid(self.gate_w * torch.norm(h_j) + self.gate_b).item())
                        except Exception:
                            gate_val = 0.5

                        if gate_val > 0.3 and 0 <= assignment < len(centroids_snapshot):
                            centroid_t = centroids_snapshot[assignment].to(device)
                            try:
                                h_aug = h_j + 0.1 * (centroid_t - h_j)
                            except Exception:
                                h_aug = h_j
                except Exception as e:
                    if VERBOSE_LOGGING:
                        print(f"[DSCD] Assignment error for '{token_key}': {str(e)[:200]}")

            outputs['proto_assignments'].append(torch.tensor(assignment))
            outputs['proto_probs'].append(prob_list)
            outputs['uncertainties'].append(uncertainty)
            outputs['span_preds'].append(span_pred)
            outputs['gates'].append(gate_val)
            outputs['h_augmented'].append(h_aug)

        if not train_mode and len(self.prototype_stores) > 0 and VERBOSE_LOGGING:
            if self.last_periodic_check % PRINT_INTERVAL == 0:
                self._print_clusters_summary()
            self.last_periodic_check += 1

        return outputs

    def _print_clusters_summary(self):
        try:
            items = []
            for token, store in self.prototype_stores.items():
                try:
                    proto_sample_count = sum(getattr(store, 'counts', []) or [])
                except Exception:
                    proto_sample_count = 0
                buffer_len = len(self.buffers.get(token, [])) if token in self.buffers else 0
                total_count = proto_sample_count if proto_sample_count > 0 else buffer_len
                protos = store.size()
                mu = getattr(store, 'mu', 0.0)
                tau = getattr(store, 'tau', 0.0)
                items.append((token, total_count, protos, mu, tau, buffer_len))
            items.sort(key=lambda x: x[1], reverse=True)
            if VERBOSE_LOGGING:
                print("\n[CLUSTER] Top 5 clusters (by sample count or buffer size):")
                print("-" * 100)
                print(f"{'Rank':<6} {'Token':<18} {'Count':<12} {'Protos':<8} {'BufLen':<8} {'Œº (mean)':<15} {'œÑ (dev)':<15}")
                print("-" * 100)
                for rank, (tok, cnt, prot, mu, tau, buflen) in enumerate(items[:5], 1):
                    tok_str = str(tok)[:18]
                    print(f"{rank:<6} {tok_str:<18} {cnt:<12} {prot:<8} {buflen:<8} {mu:<15.6f} {tau:<15.6f}")
                print("-" * 100)
                total_samples = sum(item[1] for item in items)
                total_protos = sum(item[2] for item in items)
                total_buffers = sum(item[5] for item in items)
                print(f"Total clusters: {len(items)} | Total samples: {total_samples} | Total protos: {total_protos} | Sum buffers: {total_buffers}\n")
        except Exception as e:
            if VERBOSE_LOGGING:
                print(f"[CLUSTER] Error printing summary: {str(e)[:200]}")

    def cleanup_memory(self):
        try:
            for token_type, buffer in list(self.buffers.items()):
                if len(buffer) > int(self.buffer_size * 1.5):
                    while len(buffer) > self.buffer_size:
                        buffer.popleft()
            try:
                gc.collect()
            except Exception:
                pass
        except Exception:
            pass

    def _cluster_buffer_to_prototypes_hierarchical(self, token_type):
        try:
            if not self.should_track_token(token_type):
                if VERBOSE_LOGGING:
                    print(f"[DSCD-CLUSTER] Skipping clustering for non-word token '{token_type}'")
                return False
            if token_type not in self.buffers:
                return False
            buf = self.buffers[token_type]
            if len(buf) < self.n_min:
                if VERBOSE_LOGGING:
                    print(f"[DSCD-CLUSTER] '{token_type}' buffer size {len(buf)} < n_min {self.n_min}")
                return False

            emb_list = []
            for e in buf:
                try:
                    if isinstance(e, torch.Tensor):
                        emb_list.append(e.numpy())
                    else:
                        emb_list.append(np.asarray(e))
                except Exception:
                    continue
            if len(emb_list) == 0:
                return False

            if len(emb_list) > self.max_clustering_points:
                idxs = np.random.choice(len(emb_list), size=self.max_clustering_points, replace=False)
                embeddings = np.stack([emb_list[i] for i in idxs], axis=0)
            else:
                embeddings = np.stack(emb_list, axis=0)

            if embeddings.shape[0] < 2:
                return False

            if VERBOSE_LOGGING:
                norms = np.linalg.norm(embeddings, axis=1)
                print(f"[DSCD-CLUSTER] Token '{token_type}' buffer={len(buf)} sampled={embeddings.shape[0]} mean_norm={norms.mean():.4f} std_norm={norms.std():.4f}")

            store = self.prototype_stores[token_type]
            store.centroids = []
            store.counts = []
            store.creation_time = []

            protos_added = 0

            # hierarchical clustering (scipy)
            if HAS_CLUSTERING:
                try:
                    condensed = pdist(embeddings, metric='euclidean')
                    if condensed.size > 0:
                        k_guess = min(self.max_protos, max(2, len(embeddings) // max(1, self.n_min)))
                        k_guess = max(1, int(k_guess))
                        Z = linkage(condensed, method='ward')
                        clusters = fcluster(Z, t=k_guess, criterion='maxclust') - 1
                        if clusters.size > 0:
                            maxc = int(clusters.max())
                            for cid in range(maxc + 1):
                                mask = (clusters == cid)
                                if mask.sum() >= self.n_min:
                                    centroid = torch.from_numpy(embeddings[mask].mean(axis=0).astype(np.float32))
                                    store.add_prototype(centroid, time.time(), count=int(mask.sum()))
                                    protos_added += 1
                    if VERBOSE_LOGGING and protos_added > 0:
                        print(f"[DSCD-CLUSTER] Hierarchical clustering created {protos_added} prototypes for '{token_type}'")
                except Exception as e:
                    if VERBOSE_LOGGING:
                        print(f"[DSCD-CLUSTER] Hierarchical clustering failed for '{token_type}': {type(e).__name__}: {str(e)[:200]}")

            # sklearn KMeans fallback
            if protos_added == 0 and HAS_KMEANS:
                try:
                    k_guess = min(self.max_protos, max(1, len(embeddings) // max(1, self.n_min)))
                    k_guess = int(min(k_guess, len(embeddings)))
                    if k_guess >= 1 and len(embeddings) >= k_guess:
                        km = KMeans(n_clusters=k_guess, random_state=0, n_init=10).fit(embeddings)
                        labels = km.labels_
                        for c in range(k_guess):
                            mask = (labels == c)
                            if mask.sum() >= self.n_min:
                                centroid = torch.from_numpy(embeddings[mask].mean(axis=0).astype(np.float32))
                                store.add_prototype(centroid, time.time(), count=int(mask.sum()))
                                protos_added += 1
                        if VERBOSE_LOGGING and protos_added > 0:
                            print(f"[DSCD-CLUSTER] KMeans fallback created {protos_added} prototypes for '{token_type}'")
                except Exception as e:
                    if VERBOSE_LOGGING:
                        print(f"[DSCD-CLUSTER] KMeans fallback failed for '{token_type}': {type(e).__name__}: {str(e)[:200]}")

            # pure-numpy kmeans fallback
            if protos_added == 0:
                try:
                    k_guess = min(self.max_protos, max(1, len(embeddings) // max(1, self.n_min)))
                    k_guess = int(min(k_guess, len(embeddings)))
                    if k_guess >= 1 and len(embeddings) >= k_guess:
                        labels, cents = _numpy_kmeans(embeddings.astype(np.float32), n_clusters=k_guess, n_iter=10, random_state=0)
                        for c in range(k_guess):
                            mask = (labels == c)
                            if mask.sum() >= self.n_min:
                                centroid = torch.from_numpy(cents[c].astype(np.float32))
                                store.add_prototype(centroid, time.time(), count=int(mask.sum()))
                                protos_added += 1
                        if VERBOSE_LOGGING and protos_added > 0:
                            print(f"[DSCD-CLUSTER] numpy-kmeans created {protos_added} prototypes for '{token_type}'")
                except Exception as e:
                    if VERBOSE_LOGGING:
                        print(f"[DSCD-CLUSTER] numpy-kmeans failed for '{token_type}': {type(e).__name__}: {str(e)[:200]}")

            if VERBOSE_LOGGING:
                print(f"[DSCD-CLUSTER] Token '{token_type}': final_protos={store.size()} counts={store.counts}")

            return store.size() > 0

        except Exception as e:
            if VERBOSE_LOGGING:
                print(f"[DSCD-ERROR] Clustering error for '{token_type}': {type(e).__name__}: {str(e)[:200]}")
            return False

    def get_explanations(self, threshold_span=0.3):
        expl = []
        for token_type, store in self.prototype_stores.items():
            try:
                if store.size() >= 2:
                    expl.append({'token': str(token_type), 'protos': store.size(), 'counts': list(store.counts)})
            except Exception:
                continue
        return expl

# ==============================================================================
# VERIFICATION MESSAGE
# ==============================================================================
print("\n" + "=" * 80)
print("‚úÖ Cell 3 (ENHANCED): DSCD Ready with Homograph Watchlist Integration (Debugged)")
print("=" * 80)
print("Key features and fixes:")
print(" ‚úÖ Robust global config loading & defaults")
print(" ‚úÖ Deny-prefix set for short combining/vowel marks to reduce noise")
print(" ‚úÖ Atomic centroid snapshot under clustering_lock (race fix)")
print(" ‚úÖ Hierarchical clustering (scipy) with sklearn KMeans and numpy-KMeans fallback")
print(" ‚úÖ CPU-only prototype storage and clustering")
print(" ‚úÖ Unicode-aware token filtering (Bengali/Latin aware)")
print(" ‚úÖ Sampling for large buffers to avoid OOMs")
print(" ‚úÖ Safe guards for missing scipy/sklearn with robust logging")
print("=" * 80 + "\n")


‚úÖ Cell 3 (ENHANCED): DSCD Ready with Homograph Watchlist Integration (Debugged)
Key features and fixes:
 ‚úÖ Robust global config loading & defaults
 ‚úÖ Deny-prefix set for short combining/vowel marks to reduce noise
 ‚úÖ Atomic centroid snapshot under clustering_lock (race fix)
 ‚úÖ Hierarchical clustering (scipy) with sklearn KMeans and numpy-KMeans fallback
 ‚úÖ CPU-only prototype storage and clustering
 ‚úÖ Unicode-aware token filtering (Bengali/Latin aware)
 ‚úÖ Sampling for large buffers to avoid OOMs
 ‚úÖ Safe guards for missing scipy/sklearn with robust logging



In [7]:
# Fixed Bangla normalize_bn_word (vowel-aware suffix stripping)
# - Avoids removing final consonant when suffix includes vowel signs,
#   e.g. "‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï‡ßá" -> "‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï" (not "‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç").
import re
import unicodedata
from typing import Optional, List

# Original suffix list (kept as provided, sorted longest-first below)
_BN_COMMON_SUFFIXES = [
    # (same items as in your original list)
    "‡¶ó‡ßÅ‡¶≤‡ßã‡¶∞‡¶á", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶∞‡¶ì", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶§‡ßá‡¶ì", "‡¶ó‡ßÅ‡¶≤‡¶ø‡¶ï‡ßá‡¶ì", "‡¶ó‡ßÅ‡¶≤‡¶ø‡¶ï‡ßá‡¶ì‡¶á", "‡¶ó‡ßÅ‡¶≤‡¶ø‡¶ï‡ßá‡¶ì‡¶ì",
    "‡¶ó‡ßÅ‡¶≤‡¶ø ‡¶•‡ßá‡¶ï‡ßá‡¶ì", "‡¶ó‡ßÅ‡¶≤‡¶ø‡¶§‡ßá ‡¶•‡ßá‡¶ï‡ßá‡¶ì", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶§‡ßá ‡¶•‡ßá‡¶ï‡ßá‡¶ì", "‡¶ó‡ßÅ‡¶≤‡¶ø‡¶§‡ßá", "‡¶ó‡ßÅ‡¶≤‡¶ø‡¶§‡ßá‡¶ì",
    "‡¶ó‡ßÅ‡¶≤‡ßã‡¶§‡ßá‡¶ì", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶ï‡ßá‡¶á", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶ï‡ßá‡¶á‡¶á", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶∞‡ßá", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶ó‡ßÅ‡¶≤‡ßã‡¶ï‡ßá", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶ó‡ßÅ‡¶≤‡ßã‡¶∞",
    "‡¶ó‡ßÅ‡¶≤‡ßã‡¶ü‡¶æ ‡¶•‡ßá‡¶ï‡ßá", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶ü‡¶æ", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶ü‡¶ø", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶ü‡¶æ ‡¶•‡ßá‡¶ï‡ßá", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶∞‡¶á", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶∞‡¶ì",
    "‡¶ó‡ßÅ‡¶≤‡ßã‡¶∞‡ßá", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶ï‡ßá", "‡¶ó‡ßÅ‡¶≤‡¶ø‡¶ï‡ßá‡¶ì", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶∞‡¶æ", "‡¶ó‡ßÅ‡¶≤‡ßã‡¶¶‡ßá‡¶∞", "‡¶ó‡ßÅ‡¶≤‡¶ø‡¶∏‡¶π",
    "‡¶¶‡ßá‡¶∞‡¶ï‡ßá", "‡¶¶‡ßá‡¶∞‡¶á", "‡¶¶‡ßá‡¶∞‡¶ì", "‡¶¶‡ßá‡¶∞‡¶ì‡¶á", "‡¶¶‡ßá‡¶∞‡ßá", "‡¶¶‡ßá‡¶∞‡¶æ", "‡¶¶‡ßá‡¶∞‡¶á‡¶ì",
    "‡¶ü‡¶æ‡¶∞‡¶á", "‡¶ü‡¶æ‡¶∞‡¶ì", "‡¶ü‡¶æ‡¶∞", "‡¶ü‡¶æ‡¶∞‡¶ü‡¶æ", "‡¶ü‡¶ø‡¶∞‡¶á", "‡¶ü‡¶ø‡¶∞‡¶ì", "‡¶ü‡¶ø‡¶∞", "‡¶ü‡¶æ‡¶ì", "‡¶ü‡¶æ‡¶á",
    "‡¶ü‡¶ø‡¶á", "‡¶ü‡¶æ", "‡¶ü‡¶ø", "‡¶ü‡¶æ‡ßü", "‡¶ü‡¶æ‡¶§‡ßá", "‡¶ü‡¶æ‡ßü‡¶ì",
    "‡¶•‡ßá‡¶ï‡ßá", "‡¶•‡ßá‡¶ï‡ßá‡¶ì", "‡¶•‡ßá‡¶ï‡ßá‡¶ì‡¶á", "‡¶•‡ßá‡¶ï‡ßá‡¶§‡ßá", "‡¶¶‡¶ø‡ßü‡ßá", "‡¶¶‡¶ø‡¶Ø‡¶º‡ßá", "‡¶¶‡¶ø‡¶Ø‡¶º‡ßá‡¶ì", "‡¶¶‡¶ø‡¶Ø‡¶º‡ßá‡¶á",
    "‡¶¶‡ßç‡¶¨‡¶æ‡¶∞‡¶æ", "‡¶Æ‡¶ß‡ßç‡¶Ø‡ßá", "‡¶Æ‡¶ß‡ßç‡¶Ø‡ßá‡¶ì", "‡¶™‡¶∞‡ßá", "‡¶™‡¶∞‡ßá ‡¶•‡ßá‡¶ï‡ßá‡¶á", "‡¶ú‡¶®‡ßç‡¶Ø", "‡¶ú‡¶®‡ßç‡¶Ø‡¶á", "‡¶™‡¶ï‡ßç‡¶∑‡ßá",
    "‡¶®‡¶ø‡¶Ø‡¶º‡ßá", "‡¶®‡¶ø‡ßü‡ßá", "‡¶∏‡¶π", "‡¶∏‡¶π‡ßá‡¶á", "‡¶¨‡¶ø‡¶®‡ßç‡¶¶‡ßÅ‡¶§‡ßá", "‡¶∏‡¶Æ‡ßç‡¶™‡¶∞‡ßç‡¶ï‡ßá", "‡¶Ö‡¶®‡ßÅ‡¶Ø‡¶æ‡ßü‡ßÄ", "‡¶Ö‡¶®‡ßÅ‡¶Ø‡¶æ‡¶Ø‡¶º‡ßÄ",
    "‡¶Ö‡¶®‡ßÅ‡¶Ø‡¶æ‡¶Ø‡¶º‡ßÄ‡¶§", "‡¶Ö‡¶®‡ßÅ‡¶∏‡¶æ‡¶∞‡ßá", "‡¶Æ‡¶§‡ßã", "‡¶∏‡¶Æ‡ßç‡¶™‡¶®‡ßç‡¶®", "‡¶®‡¶ø‡¶Æ‡¶ø‡¶§‡ßç‡¶§‡ßá",
    "‡¶è‡¶∞‡¶á", "‡¶è‡¶∞‡¶ì", "‡¶è‡¶∞‡¶ì‡¶á", "‡¶è‡¶∞", "‡¶∞‡¶á", "‡¶∞‡¶æ‡¶ì", "‡¶∞‡¶æ‡¶ì‡¶á", "‡¶∞‡¶á‡¶ì", "‡¶∞‡ßá", "‡¶∞‡ßã", "‡¶∞",
    "‡¶ï‡ßá", "‡¶ï‡ßá‡¶á", "‡¶ï‡ßá‡¶ì", "‡¶ï‡ßá‡¶ì‡¶á", "‡¶ï‡ßá‡¶æ", "‡¶ï‡ßá‡¶æ‡¶ì", "‡¶ï‡ßá‡¶æ‡¶ì‡¶á", "‡¶§‡ßá", "‡¶§‡ßá‡¶ì", "‡¶§‡ßá‡¶á", "‡ßá‡¶§‡ßá",
    "‡¶§‡ßá‡¶æ", "‡¶§‡ßá‡¶á‡¶ì", "‡¶§‡ßá ‡¶•‡ßá‡¶ï‡ßá‡¶á",
    "‡¶õ‡¶ø‡¶≤‡¶æ‡¶Æ", "‡¶õ‡¶ø‡¶≤‡ßá", "‡¶õ‡¶ø‡¶≤‡ßá‡¶®", "‡¶õ‡¶ø‡¶≤", "‡¶õ‡¶ø‡¶≤‡ßã", "‡¶õ‡ßá‡¶®", "‡¶õ‡ßá‡¶®‡¶á", "‡¶õ‡ßá", "‡¶õ‡¶ø", "‡¶¨‡ßã", "‡¶¨‡ßá‡¶®",
    "‡¶¨‡ßá", "‡¶¨", "‡¶§‡ßá‡¶õ‡¶ø", "‡¶§‡ßá‡¶õ‡ßá", "‡¶§‡ßá‡¶õ‡¶ø‡¶≤", "‡¶§‡ßá‡¶õ‡¶ø‡¶≤‡ßá‡¶®", "‡¶§‡ßá‡¶õ‡¶ø‡¶≤‡¶æ‡¶Æ", "‡¶Ü‡¶õ‡ßá", "‡¶Ü‡¶õ‡¶ø‡¶≤",
    "‡¶π‡¶Ø‡¶º‡ßá ‡¶ó‡ßá‡¶õ‡ßá", "‡¶π‡¶Ø‡¶º‡ßá‡¶õ‡ßá", "‡¶π‡ßü‡ßá‡¶õ‡ßá", "‡¶π‡ßü‡ßá‡¶õ‡¶ø‡¶≤", "‡¶π‡¶Ø‡¶º‡ßá‡¶õ‡¶ø‡¶≤", "‡¶Ø‡¶æ‡¶ö‡ßç‡¶õ‡¶ø", "‡¶Ø‡¶æ‡¶ö‡ßç‡¶õ‡ßá",
    "‡¶Ø‡¶æ‡¶ö‡ßç‡¶õ‡ßá‡¶®", "‡¶Ø‡¶æ‡¶¨‡ßá‡¶®", "‡¶Ø‡¶æ‡¶¨‡ßá", "‡¶ó‡ßá‡¶õ‡ßá", "‡¶Ü‡¶∏‡¶õ‡ßá", "‡¶Ü‡¶∏‡¶õ‡ßá‡¶®", "‡¶ï‡¶∞‡ßá‡¶õ‡¶ø", "‡¶ï‡¶∞‡ßá‡¶õ‡ßá",
    "‡¶ï‡¶∞‡ßá‡¶õ‡ßá‡¶®", "‡¶ï‡¶∞‡¶õ‡¶ø‡¶≤", "‡¶ï‡¶∞‡¶õ‡¶ø‡¶≤‡¶æ‡¶Æ", "‡¶ï‡¶∞‡¶õ‡¶ø‡¶≤‡ßá‡¶®", "‡¶ï‡¶∞‡¶¨‡ßá", "‡¶ï‡¶∞‡¶¨‡ßá‡¶®", "‡¶ï‡¶∞‡¶õ‡ßá",
    "‡¶§‡ßá ‡¶™‡¶æ‡¶∞‡ßá", "‡¶§‡ßá ‡¶™‡¶æ‡¶∞‡¶ø", "‡¶§‡ßá ‡¶™‡¶æ‡¶∞‡ßá‡¶®", "‡¶§‡ßá ‡¶™‡¶æ‡¶∞‡¶¨‡ßá", "‡¶§‡ßá ‡¶™‡¶æ‡¶∞‡¶¨‡ßá‡¶æ", "‡¶§‡ßá ‡¶™‡¶æ‡¶∞‡¶§", "‡¶§‡ßá ‡¶™‡¶æ‡¶∞‡¶§‡ßá‡¶®",
    "‡¶§‡¶æ", "‡¶§‡ßç‡¶Ø", "‡¶§‡ßç‡¶¨", "‡ßÄ‡¶§‡ßç‡¶¨", "‡¶ø‡¶§‡ßç‡¶¨", "‡ßÄ‡¶ï", "‡ßÄ‡¶Ø‡¶º", "‡ßÄ‡ßü", "‡ßÄ‡¶§‡¶æ", "‡¶ø‡¶§‡¶æ", "‡¶ï‡¶æ‡¶∞‡ßÄ", "‡¶¨‡¶æ‡¶¶",
    "‡¶¨‡¶æ‡¶ö‡¶ï", "‡¶Æ‡ßü", "‡¶Æ‡¶Ø‡¶º", "‡¶∏‡¶Æ‡ßÇ‡¶π", "‡¶ó‡¶£", "‡¶ú‡¶æ‡¶§", "‡¶Ø‡ßã‡¶ó‡ßç‡¶Ø", "‡¶Ø‡ßã‡¶ó‡ßç‡¶Ø‡¶§‡¶æ", "‡¶™‡ßÇ‡¶∞‡ßç‡¶£", "‡¶™‡ßÇ‡¶∞‡ßç‡¶£‡¶§‡¶æ",
    "‡¶¨‡ßÉ‡¶§‡ßç‡¶§‡¶ø", "‡¶¨‡ßã‡¶ß", "‡¶∏‡ßÅ‡¶≤‡¶≠",
    "‡¶∞‡¶æ", "‡¶ú‡¶®", "‡¶ú‡¶®‡¶∞‡¶æ", "‡¶ú‡¶®‡ßá‡¶∞", "‡¶ú‡¶®‡¶ï‡ßá", "‡¶≤‡ßã‡¶ï", "‡¶≤‡ßã‡¶ï‡ßá‡¶∞‡¶æ", "‡¶≤‡ßã‡¶ï‡¶ú‡¶®", "‡¶ú‡¶®‡¶ó‡¶£",
    "‡¶ú‡ßÄ", "‡¶ú‡¶ø", "‡¶∏‡¶æ‡¶π‡ßá‡¶¨", "‡¶¨‡¶æ‡¶¨‡ßÅ", "‡¶Æ‡¶∂‡¶æ‡¶á", "‡¶¶‡¶æ‡¶¶‡¶æ", "‡¶¶‡¶ø‡¶¶‡¶æ", "‡¶Æ‡¶æ", "‡¶¨‡¶æ‡¶¨‡¶æ", "‡¶Æ‡¶æ‡¶Æ‡¶æ", "‡¶§‡¶æ‡¶á",
    "‡¶∏‡¶∞‡¶ø", "‡¶Æ‡¶∞‡ßç‡¶®",
    "‡¶ì", "‡¶á", "‡¶®", "‡¶®‡¶æ", "‡¶§‡ßã", "‡¶§‡¶æ", "‡¶á‡¶ì", "‡¶á‡¶§‡ßá‡¶á", "‡¶¶‡ßá‡¶®", "‡¶¶‡ßá‡¶á", "‡¶´‡¶≤‡ßá", "‡¶•‡¶æ‡¶ï‡¶≤‡ßá",
    "‡¶™‡¶æ‡¶ì", "‡¶™‡¶æ‡¶á", "‡¶™‡ßá‡¶≤‡ßá", "‡¶™‡ßá‡¶≤‡ßá‡¶æ", "‡¶™‡¶æ‡¶ì‡ßü‡¶æ", "‡¶™‡ßá‡¶Ø‡¶º‡ßá‡¶õ‡ßá", "‡¶™‡ßá‡ßü‡ßá‡¶õ‡ßá",
    "‡¶Æ‡ßü", "‡¶Æ‡¶Ø‡¶º", "‡¶Ø‡ßá", "‡¶Ø‡¶æ‡¶∞", "‡¶Ø‡¶æ‡¶ï‡ßá", "‡¶Ø‡¶æ‡¶§‡ßá", "‡¶Ø‡¶æ‡¶ì", "‡¶Ø‡¶æ‡¶ö‡ßç‡¶õ‡ßá", "‡¶Ø‡¶æ‡¶¨‡ßá", "‡¶™‡¶∞‡ßá‡¶æ", "‡¶™‡ßú‡ßá‡¶õ‡ßá",
    "‡¶™‡¶°‡¶º‡ßá‡¶õ‡ßá", "‡¶¶‡¶ø‡ßü‡ßá‡¶æ", "‡¶¶‡ßá‡¶ì‡ßü‡¶æ", "‡¶¶‡ßá‡¶ì‡¶Ø‡¶º‡¶æ", "‡¶®‡ßá‡¶á", "‡¶•‡¶æ‡¶ï‡¶ø", "‡¶•‡¶æ‡¶ï‡ßá‡¶®", "‡¶•‡¶æ‡¶ï‡¶õ‡ßá",
    "‡¶Ö‡¶™‡¶∞‡¶∞‡ßÇ‡¶™", "‡¶Ö‡¶™‡¶∞‡¶™‡ßç‡¶∞‡¶§", "‡¶Ö‡¶™‡¶∞", "‡¶Ö‡¶§‡¶ø", "‡¶Ö‡¶§‡¶ø‡¶∂‡¶Ø‡¶º", "‡¶Ö‡¶§‡¶ø‡¶∂‡ßü", "‡¶Ö‡¶®‡ßÅ‡¶™‡ßç‡¶∞‡¶¨‡ßá‡¶∂", "‡¶Ö‡¶®‡ßÅ‡¶™‡ßç‡¶∞‡ßá‡¶∞‡¶£‡¶æ",
    "‡¶Ö‡¶®‡ßÅ", "‡¶Ö‡¶®", "‡¶â‡¶™", "‡¶â‡¶™-‡¶™", "‡¶â‡¶™‡ßã‡¶∏", "‡¶™‡ßç‡¶∞‡¶§‡¶ø", "‡¶™‡ßÅ‡¶®‡¶∞", "‡¶™‡ßÅ‡¶®‡¶É", "‡¶™‡ßÅ‡¶®", "‡¶∏‡ßç‡¶¨",
    "‡¶∏‡¶Æ", "‡¶¶‡ßç‡¶¨‡¶ø", "‡¶§‡ßç‡¶∞‡¶ø", "‡¶§‡ßç‡¶∞‡ßà", "‡¶Ö‡¶¨", "‡¶¨‡¶π‡ßÅ", "‡¶â‡¶¶", "‡¶Ö‡¶®‡ßç‡¶§‡¶∞", "‡¶Ö‡¶®‡ßç‡¶§‡¶É", "‡¶™‡¶∞",
    "‡¶¨‡¶ø‡¶∞‡ßÅ", "‡¶¨‡¶ø‡¶®", "‡¶Ü‡¶§‡ßç‡¶Æ", "‡¶Ü‡¶§‡ßç‡¶Æ-", "‡¶®‡¶ø‡¶∞", "‡¶®‡¶ø‡¶π‡¶ø‡¶§", "‡¶Ö‡¶§‡¶ø-‡¶§‡ßÉ‡¶∑‡ßç‡¶£‡¶æ", "‡¶∏‡ßÅ", "‡¶∏‡ßç‡¶¨‡¶∞", "‡¶∏‡ßç‡¶¨‡¶ß‡ßÄ",
    "‡¶™‡¶∂‡ßç‡¶ö‡¶æ‡¶§", "‡¶™‡ßÇ‡¶∞‡ßç‡¶¨", "‡¶™‡ßÇ‡¶∞‡ßç‡¶¨-" , "‡¶Ö‡¶™‡ßç‡¶∞", "‡¶™‡ßç‡¶∞‡¶§‡¶ø-", "‡¶¨‡¶ø‡¶®‡¶æ", "‡¶∏‡¶Ç", "‡¶∏‡¶Ç-",
    "‡¶Ö", "‡¶™‡ßÅ‡¶®‡¶∞", "‡¶™‡ßÅ‡¶®‡¶É", "‡¶Ö‡¶¨", "‡¶â‡¶™", "‡¶™‡ßç‡¶∞", "‡¶®‡¶æ", "‡¶®‡¶ø", "‡¶Ö‡¶§‡¶ø", "‡¶â‡ßé", "‡¶â‡ßé‡¶™", "‡¶â‡¶¶‡ßç",
    "‡¶™‡¶∞‡¶ø", "‡¶∏‡¶Æ‡ßç‡¶¨", "‡¶∏‡¶Æ‡¶∞‡ßç‡¶•", "‡¶∏‡ßç‡¶¨‡¶®", "‡¶∏‡ßÅ-",
    # safety: very short items last (avoid over-stripping)
    "‡ßá‡¶á", "‡¶á‡¶á", "‡¶á", "‡¶ì‡¶á", "‡¶ì", "‡ßç‡¶∞", "‡ßç‡¶∑", "‡ßç‡¶§", "‡ßç‡¶ï",
]

# longest-first sort
_BN_COMMON_SUFFIXES_SORTED: List[str] = sorted(_BN_COMMON_SUFFIXES, key=lambda s: len(s), reverse=True)

# regex to remove punctuation
_RE_PUNCT = re.compile(r"[^\w\u0980-\u09FF\-]+", flags=re.UNICODE)

# Bengali vowel signs (common combining marks)
_VOWEL_SIGNS = {
    "\u09BE",  # ‡¶æ
    "\u09BF",  # ‡¶ø
    "\u09C0",  # ‡ßÄ
    "\u09C1",  # ‡ßÅ
    "\u09C2",  # ‡ßÇ
    "\u09C3",  # ‡ßÉ
    "\u09C7",  # ‡ßá  <-- important for "‡¶ï‡ßá"
    "\u09C8",  # ‡ßà
    "\u09CB",  # ‡ßã
    "\u09CC",  # ‡ßå
    "\u0982",  # ‡¶Ç
    "\u0983",  # ‡¶É
}

# zero-width cleanup
_ZW_RE = re.compile(r"[\u200b\u200c\u200d]+")

def _ends_with_vowel_sign(s: str) -> bool:
    return len(s) > 0 and s[-1] in _VOWEL_SIGNS

def normalize_bn_word(raw: Optional[str]) -> str:
    """
    Normalizes Bengali token and does vowel-aware suffix stripping.
    - NFC unicode normalization
    - removes common token markers (subword, BPE)
    - removes punctuation
    - iterates over sorted suffixes: if suffix matches and suffix ends with vowel sign
      only remove trailing vowel sign(s) from the word (preserve consonant),
      otherwise remove whole suffix (longest-first).
    - requires resulting stem to be of minimal sensible length (>=2)
    """
    if raw is None:
        return ""
    s = str(raw).strip()
    if not s:
        return ""

    # Normalize unicode form (use NFC for stable composed form)
    s = unicodedata.normalize("NFC", s)

    # remove subword markers
    for mk in ("‚ñÅ", "##", "ƒ†", "@@"):
        s = s.replace(mk, "")

    # remove zero-width joiners/markers
    s = _ZW_RE.sub("", s)

    # strip surrounding ascii punctuation
    s = s.strip(" \t\n\r.,;:!?\"'()[]{}‚Äî‚Äì-")

    # remove internal punctuation (non-Bengali letter/number/dash)
    s = _RE_PUNCT.sub("", s)

    # iterate over suffix list (longest-first)
    for suf in _BN_COMMON_SUFFIXES_SORTED:
        try:
            if not suf:
                continue
            if s.endswith(suf) and (len(s) - len(suf) >= 2):
                # If the suffix ends with a vowel sign (e.g., "‡¶ï‡ßá" ends with "‡ßá"),
                # prefer removing only the trailing vowel sign(s) from the word
                if any(ch in _VOWEL_SIGNS for ch in suf):
                    # strip trailing vowel signs from the word (but keep consonant)
                    while _ends_with_vowel_sign(s) and len(s) > 1:
                        s = s[:-1]
                    s = s.strip()
                else:
                    # safe remove whole suffix (as before)
                    s = s[: -len(suf)].strip()
                break
        except Exception:
            continue

    # final normalization
    s = unicodedata.normalize("NFC", s).strip()
    return s

# expose globally if run inside a notebook/script
globals()['normalize_bn_word'] = normalize_bn_word

# quick smoke test
if __name__ == "__main__":
    tests = ["‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï‡ßá", "‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï‡ßá‡¶∞", "‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï", "‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï‡¶ï", "‡¶ï‡¶≤‡¶ó‡ßÅ‡¶≤‡ßã‡¶§‡ßá", "‡¶ï‡¶≤‡¶ó‡ßÅ‡¶≤‡ßã‡¶ï‡ßá"]
    for t in tests:
        print(f"{t} -> {normalize_bn_word(t)}")

‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï‡ßá -> ‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï
‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï‡ßá‡¶∞ -> ‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï‡ßá
‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï -> ‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï
‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï‡¶ï -> ‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï‡¶ï
‡¶ï‡¶≤‡¶ó‡ßÅ‡¶≤‡ßã‡¶§‡ßá -> ‡¶ï‡¶≤‡¶ó‡ßÅ‡¶≤‡ßã‡¶§
‡¶ï‡¶≤‡¶ó‡ßÅ‡¶≤‡ßã‡¶ï‡ßá -> ‡¶ï‡¶≤‡¶ó‡ßÅ‡¶≤‡ßã‡¶ï


In [8]:
# ==============================================================================
# Cell 4 (patched): ASBN module ‚Äî device-safe, memory-friendly, defensive
# - Thoroughly hardened and line-by-line defensive fixes applied.
# - Exposes a toggle ASBN_MONITOR_IN_EVAL to allow monitoring in eval mode if desired.
# - Robust parsing of DSCD outputs (many possible shapes).
# - Functional frozen forward for discriminator parameters (GRL-style encoder loss).
# - Safe device movement for discriminator submodules.
# ==============================================================================
import traceback
from typing import Any, List, Tuple, Optional, Dict

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Local safe globals (use globals().get to avoid NameError)
_MAX_LENGTH = int(globals().get("MAX_LENGTH", 48))
_ENABLE_ASBN_TRAINING = bool(globals().get("ENABLE_ASBN_TRAINING", True))
_VERBOSE_LOGGING = bool(globals().get("VERBOSE_LOGGING", False))
_SOURCE_LANGUAGE = str(globals().get("SOURCE_LANGUAGE", "bn"))
_ASBN_MONITOR_IN_EVAL = bool(globals().get("ASBN_MONITOR_IN_EVAL", False))  # New: allow monitoring even when module.eval()

_has_is_valid_token = "is_valid_token" in globals()
_has_get_special_tokens = "get_special_tokens" in globals()

# Utility: safe device selector
def _device_of(x: Any) -> torch.device:
    if isinstance(x, torch.Tensor):
        return x.device
    # default device
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


class LightweightDiscriminator(nn.Module):
    """Small discriminator head used by ASBN (kept intentionally tiny)."""

    def __init__(self, input_dim: int):
        super().__init__()
        # two-layer MLP
        self.classifier = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(64, 2)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.classifier(x)


class MemoryEfficientASBNModule(nn.Module):
    """
    ASBN module (robust/hardened).
    - forward_discriminators_simplified: monitoring pass (no grad).
    - forward_with_grl_simplified: computes encoder loss with frozen discriminator params.
    """

    def __init__(self, embed_dim: int, tokenizer=None, language: str = "bn"):
        super().__init__()
        self.language = language
        self.tokenizer = tokenizer

        # discriminators
        self.d_freq = LightweightDiscriminator(embed_dim + 2)
        self.d_ctx = LightweightDiscriminator(embed_dim + 2)
        self.d_xl = LightweightDiscriminator(embed_dim)

        # scaling knobs
        self.lambda_base = {"freq": 1.0, "ctx": 0.5, "xl": 0.8}
        self.lambda_max = 2.0
        self.encoder_grl_scale = float(globals().get("ASBN_ENCODER_GRL_SCALE", 0.1))

        # Cache special tokens robustly
        try:
            if tokenizer is not None and _has_get_special_tokens:
                self.special_tokens = globals()["get_special_tokens"](tokenizer)
            elif tokenizer is not None:
                self.special_tokens = set(getattr(tokenizer, "all_special_tokens", []))
            else:
                self.special_tokens = set()
        except Exception:
            self.special_tokens = set()

    def critic_parameters(self):
        return list(self.d_freq.parameters()) + list(self.d_ctx.parameters()) + list(self.d_xl.parameters())

    # -----------------------
    # helpers
    # -----------------------
    def _ensure_discriminators_on_device(self, device: torch.device):
        """
        Best-effort: move discriminator modules to device. Do not raise.
        """
        try:
            for mod in (self.d_freq, self.d_ctx, self.d_xl):
                try:
                    p = next(mod.parameters(), None)
                    if p is not None and p.device != device:
                        mod.to(device)
                except Exception:
                    try:
                        mod.to(device)
                    except Exception:
                        if _VERBOSE_LOGGING:
                            print("[ASBN] warning moving discriminator to device failed")
        except Exception:
            if _VERBOSE_LOGGING:
                print("[ASBN] _ensure_discriminators_on_device failed:", traceback.format_exc().splitlines()[-1])

    def _parse_proto_probs_matrix(self, proto_probs: Any, batch_size: int, seq_len: int, device: torch.device) -> torch.Tensor:
        """
        Normalize proto_probs into pmax tensor [B, T] containing max prototype prob per token.
        Accepts torch.Tensor (B,T,K | T,K | T | K), lists of lists, numpy arrays, etc.
        """
        pmax = torch.full((batch_size, seq_len), 0.5, dtype=torch.float32, device=device)
        try:
            if proto_probs is None:
                return pmax

            # case: torch tensor
            if isinstance(proto_probs, torch.Tensor):
                p = proto_probs.detach().to(device)
                if p.dim() == 3:
                    B, T, K = p.shape
                    vals = p.max(dim=2)[0]
                    pmax[:min(batch_size, B), :min(seq_len, T)] = vals[:batch_size, :seq_len]
                    return pmax
                elif p.dim() == 2:
                    # treat rows as per-token vectors for a single batch
                    if p.size(0) <= seq_len and batch_size >= 1:
                        vals = p.max(dim=1)[0]
                        pmax[0, :min(seq_len, vals.size(0))] = vals[:seq_len]
                        return pmax
                    else:
                        vals = p.max(dim=1)[0]
                        pmax[0, :min(seq_len, vals.size(0))] = vals[:seq_len]
                        return pmax
                elif p.dim() == 1:
                    pmax[0, :min(seq_len, p.size(0))] = p[:seq_len]
                    return pmax

            # list/tuple handling
            if isinstance(proto_probs, (list, tuple)):
                # if matches batch length
                if len(proto_probs) == batch_size:
                    for b in range(batch_size):
                        row = proto_probs[b]
                        if isinstance(row, torch.Tensor):
                            if row.dim() == 2:
                                vals = row.max(dim=1)[0].detach().to(device)
                                pmax[b, :min(seq_len, vals.size(0))] = vals[:seq_len]
                            elif row.dim() == 1:
                                vals = row.detach().to(device)
                                pmax[b, :min(seq_len, vals.size(0))] = vals[:seq_len]
                        elif isinstance(row, (list, tuple, np.ndarray)):
                            for t in range(min(seq_len, len(row))):
                                try:
                                    val = row[t]
                                    if isinstance(val, torch.Tensor):
                                        arr = val.detach().cpu().numpy()
                                        if arr.size:
                                            pmax[b, t] = float(arr.max())
                                    else:
                                        arr = np.asarray(val, dtype=np.float32)
                                        if arr.size:
                                            pmax[b, t] = float(np.max(arr))
                                except Exception:
                                    pmax[b, t] = 0.5
                    return pmax
                else:
                    # maybe single batch list-of-per-token
                    if batch_size == 1:
                        for t in range(min(seq_len, len(proto_probs))):
                            try:
                                val = proto_probs[t]
                                if isinstance(val, torch.Tensor):
                                    arr = val.detach().cpu().numpy()
                                    pmax[0, t] = float(np.max(arr)) if arr.size else 0.5
                                else:
                                    arr = np.asarray(val, dtype=np.float32)
                                    pmax[0, t] = float(np.max(arr)) if arr.size else 0.5
                            except Exception:
                                pmax[0, t] = 0.5
                        return pmax
        except Exception:
            if _VERBOSE_LOGGING:
                print("[ASBN] parse_proto_probs exception:", traceback.format_exc().splitlines()[-1])
        return pmax

    def _parse_scalar_matrix(self, mat: Any, batch_size: int, seq_len: int, device: torch.device, default: float = 0.0) -> torch.Tensor:
        """
        Normalize scalar-like structures into [B, T] tensor (supports torch.Tensor, list, tuple, numpy).
        """
        out = torch.full((batch_size, seq_len), float(default), dtype=torch.float32, device=device)
        try:
            if mat is None:
                return out

            if isinstance(mat, torch.Tensor):
                m = mat.detach().to(device)
                if m.dim() == 3:
                    out[:min(batch_size, m.size(0)), :min(seq_len, m.size(1))] = m[:, :seq_len, 0]
                elif m.dim() == 2:
                    if m.size(0) == batch_size:
                        out[:, :min(seq_len, m.size(1))] = m[:, :seq_len]
                    elif batch_size == 1:
                        out[0, :min(seq_len, m.size(0))] = m[:seq_len]
                elif m.dim() == 1:
                    if batch_size == 1:
                        out[0, :min(seq_len, m.size(0))] = m[:seq_len]
                return out

            if isinstance(mat, (list, tuple, np.ndarray)):
                if len(mat) == batch_size:
                    for b in range(batch_size):
                        row = mat[b]
                        if isinstance(row, torch.Tensor):
                            r = row.detach().to(device)
                            for t in range(min(seq_len, r.size(0))):
                                out[b, t] = float(r[t].item())
                        else:
                            for t in range(min(seq_len, len(row))):
                                try:
                                    v = row[t]
                                    out[b, t] = float(v.item()) if isinstance(v, torch.Tensor) else float(v)
                                except Exception:
                                    out[b, t] = float(default)
                    return out
                else:
                    # single-batch sequence
                    if batch_size == 1:
                        row = mat
                        for t in range(min(seq_len, len(row))):
                            try:
                                v = row[t]
                                out[0, t] = float(v.item()) if isinstance(v, torch.Tensor) else float(v)
                            except Exception:
                                out[0, t] = float(default)
                        return out
        except Exception:
            if _VERBOSE_LOGGING:
                print("[ASBN] parse_scalar_matrix exception:", traceback.format_exc().splitlines()[-1])
        return out

    def compute_lambda_scaled_tensor(self, pmax: torch.Tensor, uncertainty: torch.Tensor, gate: torch.Tensor, lambda_type: str) -> torch.Tensor:
        """
        lam = base * pmax * (1 - uncertainty) * gate, clipped to [0, lambda_max]
        """
        base = float(self.lambda_base.get(lambda_type, 0.2))
        lam = base * pmax * (1.0 - uncertainty) * gate
        lam = torch.clamp(lam, 0.0, float(self.lambda_max))
        lam = torch.where(torch.isfinite(lam), lam, torch.zeros_like(lam))
        return lam

    # -----------------------
    # Monitor (no grad) - safe even if discriminators on CPU
    # -----------------------
    def forward_discriminators_simplified(
        self,
        h: Optional[torch.Tensor],
        proto_probs: Any,
        uncertainties: Any,
        gates: Any,
        token_word_map: Optional[List[Dict[int, str]]] = None
    ) -> torch.Tensor:
        """
        Monitoring pass under torch.no_grad(). Returns scalar Tensor on same device as h.
        Controlled by module training or ASBN_MONITOR_IN_EVAL flag.
        """
        device = _device_of(h)
        zero = torch.tensor(0.0, device=device)

        # Monitor only when training by default unless user forces monitoring in eval
        if (not self.training) and (not _ASBN_MONITOR_IN_EVAL):
            return zero

        if not isinstance(h, torch.Tensor) or h.dim() != 3:
            return zero

        B, T, H = h.size()

        # Ensure discriminators are available on device (best-effort)
        try:
            self._ensure_discriminators_on_device(device)
        except Exception:
            pass

        pmax_mat = self._parse_proto_probs_matrix(proto_probs, B, T, device)
        U_mat = self._parse_scalar_matrix(uncertainties, B, T, device, default=0.1)
        G_mat = self._parse_scalar_matrix(gates, B, T, device, default=0.0)

        sel_mask = torch.ones((B, T), dtype=torch.bool, device=device)

        # Filter using token_word_map and optional is_valid_token helper
        if token_word_map:
            try:
                for b in range(min(B, len(token_word_map))):
                    wm = token_word_map[b] or {}
                    for t in range(T):
                        if t in wm:
                            if _has_is_valid_token:
                                try:
                                    if not is_valid_token(wm[t], self.special_tokens, self.tokenizer, language=self.language):
                                        sel_mask[b, t] = False
                                except Exception:
                                    sel_mask[b, t] = False
                            else:
                                w = str(wm[t])
                                if len(w.strip()) < 2:
                                    sel_mask[b, t] = False
            except Exception:
                if _VERBOSE_LOGGING:
                    print("[ASBN] token_word_map filter failed:", traceback.format_exc().splitlines()[-1])

        sel_idx = sel_mask.view(-1).nonzero(as_tuple=False).squeeze(1)
        if sel_idx.numel() == 0:
            return zero

        # Gather features
        h_flat = h.view(B * T, H)
        sel_emb = h_flat[sel_idx]
        pmax_flat = pmax_mat.view(-1)[sel_idx]
        U_flat = U_mat.view(-1)[sel_idx]
        G_flat = G_mat.view(-1)[sel_idx]

        seq_len_feature = float(T) / max(int(_MAX_LENGTH), 1)
        ctx_feature = torch.stack([G_flat, torch.full_like(G_flat, seq_len_feature)], dim=1)
        freq_feature = torch.stack([pmax_flat, U_flat], dim=1)

        freq_input = torch.cat([sel_emb, freq_feature.to(device)], dim=1)
        ctx_input = torch.cat([sel_emb, ctx_feature.to(device)], dim=1)
        xl_input = sel_emb

        try:
            with torch.no_grad():
                self._ensure_discriminators_on_device(device)
                freq_logits = self.d_freq(freq_input)
                ctx_logits = self.d_ctx(ctx_input)
                xl_logits = self.d_xl(xl_input)

                # pseudo labels
                freq_label = (pmax_flat > 0.7).long().to(device)
                ctx_label = (U_flat < 0.3).long().to(device)
                xl_label = (G_flat > 0.5).long().to(device)

                loss_freq = F.cross_entropy(freq_logits, freq_label, reduction="none")
                loss_ctx = F.cross_entropy(ctx_logits, ctx_label, reduction="none")
                loss_xl = F.cross_entropy(xl_logits, xl_label, reduction="none")

                lam_freq = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "freq")
                lam_ctx = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "ctx")
                lam_xl = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "xl")

                weighted = lam_freq * loss_freq + lam_ctx * loss_ctx + lam_xl * loss_xl
                avg_loss = torch.mean(weighted) if weighted.numel() > 0 else torch.tensor(0.0, device=device)
            return avg_loss
        except Exception:
            if _VERBOSE_LOGGING:
                print("[ASBN] Monitor forward failed (device/param issue):", traceback.format_exc().splitlines()[-1])
            return zero

    # -----------------------
    # Encoder GRL using detached/cloned params (functional forward)
    # -----------------------
    def forward_with_grl_simplified(
        self,
        h: Optional[torch.Tensor],
        proto_probs: Any,
        uncertainties: Any,
        gates: Any,
        token_word_map: Optional[List[Dict[int, str]]] = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Returns (encoder_loss, disc_monitor_loss, zero, zero)
        encoder_loss is suitable for backprop into encoder representations h.
        """
        device = _device_of(h)
        zero = torch.tensor(0.0, device=device)

        if (not self.training) or (not _ENABLE_ASBN_TRAINING):
            return zero, zero, zero, zero

        if not isinstance(h, torch.Tensor) or h.dim() != 3:
            return zero, zero, zero, zero

        # monitor (no_grad)
        try:
            with torch.no_grad():
                disc_monitor_loss = self.forward_discriminators_simplified(h, proto_probs, uncertainties, gates, token_word_map)
                if not isinstance(disc_monitor_loss, torch.Tensor):
                    disc_monitor_loss = torch.tensor(float(disc_monitor_loss), device=device)
        except Exception:
            if _VERBOSE_LOGGING:
                print("[ASBN] forward_discriminators_simplified (monitor) failed:", traceback.format_exc().splitlines()[-1])
            disc_monitor_loss = torch.tensor(0.0, device=device)

        # compute encoder loss (functional forward)
        try:
            B, T, H = h.size()
            pmax_mat = self._parse_proto_probs_matrix(proto_probs, B, T, device)
            U_mat = self._parse_scalar_matrix(uncertainties, B, T, device, default=0.1)
            G_mat = self._parse_scalar_matrix(gates, B, T, device, default=0.0)

            sel_mask = torch.ones((B, T), dtype=torch.bool, device=device)
            if token_word_map:
                try:
                    for b in range(min(B, len(token_word_map))):
                        wm = token_word_map[b] or {}
                        for t in range(T):
                            if t in wm:
                                if _has_is_valid_token:
                                    try:
                                        if not is_valid_token(wm[t], self.special_tokens, self.tokenizer, language=self.language):
                                            sel_mask[b, t] = False
                                    except Exception:
                                        sel_mask[b, t] = False
                                else:
                                    w = str(wm[t])
                                    if len(w.strip()) < 2:
                                        sel_mask[b, t] = False
                except Exception:
                    if _VERBOSE_LOGGING:
                        print("[ASBN] token_word_map filter (GRL) failed:", traceback.format_exc().splitlines()[-1])

            sel_idx = sel_mask.view(-1).nonzero(as_tuple=False).squeeze(1)
            if sel_idx.numel() == 0:
                encoder_loss = torch.tensor(0.0, device=device, requires_grad=True)
            else:
                h_flat = h.view(B * T, H)
                sel_emb = h_flat[sel_idx]           # [N, H]
                pmax_flat = pmax_mat.view(-1)[sel_idx]
                U_flat = U_mat.view(-1)[sel_idx]
                G_flat = G_mat.view(-1)[sel_idx]

                max_len = max(int(_MAX_LENGTH), 1)
                seq_len_feature = float(T) / float(max_len)
                freq_feature = torch.stack([pmax_flat, U_flat], dim=1).to(device)
                ctx_feature = torch.stack([G_flat, torch.full_like(G_flat, seq_len_feature)], dim=1).to(device)

                freq_input = torch.cat([sel_emb, freq_feature], dim=1)     # [N, Df]
                ctx_input = torch.cat([sel_emb, ctx_feature], dim=1)       # [N, Dc]
                xl_input = sel_emb                                         # [N, H]

                # extract frozen params (as leaf tensors)
                def get_frozen_params(module: nn.Module, device: torch.device):
                    try:
                        # primary path relies on module.classifier structure used in LightweightDiscriminator
                        l0 = module.classifier[0]
                        l1 = module.classifier[3]
                        w0 = l0.weight.detach().clone().to(device)
                        b0 = l0.bias.detach().clone().to(device) if l0.bias is not None else None
                        w1 = l1.weight.detach().clone().to(device)
                        b1 = l1.bias.detach().clone().to(device) if l1.bias is not None else None
                        for t in (w0, b0, w1, b1):
                            if t is not None:
                                t.requires_grad = False
                        return (w0, b0, w1, b1)
                    except Exception:
                        params = list(module.parameters())
                        if len(params) >= 4:
                            try:
                                w0 = params[0].detach().clone().to(device)
                                b0 = params[1].detach().clone().to(device)
                                w1 = params[2].detach().clone().to(device)
                                b1 = params[3].detach().clone().to(device)
                                for t in (w0, b0, w1, b1):
                                    if t is not None:
                                        t.requires_grad = False
                                return (w0, b0, w1, b1)
                            except Exception:
                                pass
                        raise RuntimeError("Failed to extract frozen params from discriminator module")

                frozen_freq = get_frozen_params(self.d_freq, device)
                frozen_ctx = get_frozen_params(self.d_ctx, device)
                frozen_xl = get_frozen_params(self.d_xl, device)

                def functional_classifier_forward(x: torch.Tensor, frozen_params, dropout_p: float = 0.1):
                    w0, b0, w1, b1 = frozen_params
                    y = F.linear(x, w0, b0)
                    y = F.relu(y)
                    # dropout in functional form (training False here)
                    y = F.dropout(y, p=dropout_p, training=False)
                    y = F.linear(y, w1, b1)
                    return y

                freq_logits = functional_classifier_forward(freq_input, frozen_freq, dropout_p=0.1)
                ctx_logits = functional_classifier_forward(ctx_input, frozen_ctx, dropout_p=0.1)
                xl_logits = functional_classifier_forward(xl_input, frozen_xl, dropout_p=0.1)

                freq_label = (pmax_flat > 0.7).long().to(device)
                ctx_label = (U_flat < 0.3).long().to(device)
                xl_label = (G_flat > 0.5).long().to(device)

                loss_freq = F.cross_entropy(freq_logits, freq_label, reduction="none")
                loss_ctx = F.cross_entropy(ctx_logits, ctx_label, reduction="none")
                loss_xl = F.cross_entropy(xl_logits, xl_label, reduction="none")

                lam_freq = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "freq")
                lam_ctx = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "ctx")
                lam_xl = self.compute_lambda_scaled_tensor(pmax_flat, U_flat, G_flat, "xl")

                weighted = lam_freq * loss_freq + lam_ctx * loss_ctx + lam_xl * loss_xl
                mean_weighted = torch.mean(weighted) if weighted.numel() > 0 else torch.tensor(0.0, device=device)
                encoder_loss = -float(self.encoder_grl_scale) * mean_weighted
                encoder_loss = encoder_loss.to(device)
        except Exception:
            if _VERBOSE_LOGGING:
                print("[ASBN] GRL computation failed:", traceback.format_exc().splitlines()[-1])
            encoder_loss = torch.tensor(0.0, device=device, requires_grad=True)

        return encoder_loss, disc_monitor_loss, torch.tensor(0.0, device=device), torch.tensor(0.0, device=device)


print("‚úÖ Cell 4 (patched final, device-safe): ASBN module ready (functional frozen-forward + discriminator device safety)")

‚úÖ Cell 4 (patched final, device-safe): ASBN module ready (functional frozen-forward + discriminator device safety)


In [9]:
# ==============================================================================
# CELL 5 (patched): TRG EXPLANATION SYSTEM (INFERENCE-ONLY + MULTI-GPU OPTIMIZED)
# - Thorough, line-by-line hardening
# - Aggregates subword pieces using normalized token_word_map (norm) when available
# - Avoids early-return on self.training so eval-mode toggling handled externally
# - Robust proto_probs / scalar parsing, safe paddings and averaging for aggregation
# ==============================================================================
from typing import List, Dict, Tuple, Optional, Any
from collections import deque
import numpy as np
import torch
import torch.nn as nn

# Robust config defaults
_TRG_EVIDENCE_K = int(globals().get("TRG_EVIDENCE_K", 3))
_TRG_GEN_EMBED = int(globals().get("TRG_GEN_EMBED", 64))
_MAX_SILVER_BUFFER = int(globals().get("MAX_SILVER_BUFFER", 50))
_VERBOSE_LOGGING = bool(globals().get("VERBOSE_LOGGING", False))
_ENABLE_TRG_INFERENCE = bool(globals().get("ENABLE_TRG_INFERENCE", True))
_SOURCE_LANGUAGE = str(globals().get("SOURCE_LANGUAGE", "bn"))
_TRG_UNCERTAINTY_THRESHOLD = float(globals().get("TAU_LOW", 0.40))

_has_is_valid_token = "is_valid_token" in globals()
_has_get_tokenizer_special_tokens = "get_tokenizer_special_tokens" in globals()
_has_get_cached_special_tokens = "get_cached_special_tokens" in globals()

# Optional normalizer (may be provided by bn_normalizer cell)
_normalize_fn = globals().get("normalize_bn_word", None)


def _is_word_start(raw_token: str, token_word_map: Optional[dict], idx: int) -> bool:
    """
    Robust word-start detection (SPM/BPE markers or reconstructed word in token_word_map).
    """
    try:
        if token_word_map and isinstance(token_word_map, dict) and idx in token_word_map:
            w = token_word_map[idx]
            return isinstance(w, str) and len(w.strip()) > 0

        if isinstance(raw_token, str):
            if raw_token.startswith("‚ñÅ") or raw_token.startswith("ƒ†"):
                return True
            clean = raw_token.replace("‚ñÅ", "").replace("ƒ†", "").strip()
            if len(clean) >= 2 and not all(ch in '.,;:!?"\'()[]{}-/' for ch in clean):
                return True
    except Exception:
        pass
    return False


class ComprehensiveTRGExplanationTemplate:
    """Templates to render explanation strings."""

    def __init__(self):
        self.explanation_templates = {
            "high_confidence": (
                "Chose '{sense}' with high confidence ({confidence:.1%}) based on contextual evidence: '{evidence}'. "
                "This matches the learned pattern. {alternatives_text}"
            ),
            "medium_confidence": (
                "Selected '{sense}' with moderate confidence ({confidence:.1%}). "
                "Evidence: '{evidence}'. Some uncertainty remains. {alternatives_text}"
            ),
            "low_confidence": (
                "Uncertain between senses; chose '{sense}' ({confidence:.1%}). "
                "Evidence: '{evidence}'. {alternatives_text} Review recommended."
            ),
            "fallback": (
                "Token '{token}' processed with standard analysis. Context: '{evidence}'."
            ),
        }

    def generate_explanation(self, evidence: Dict) -> str:
        token = str(evidence.get("token", "unknown")).replace("‚ñÅ", "")
        sense_info = evidence.get("chosen_sense", ("unknown", 0.5))

        if isinstance(sense_info, (tuple, list)) and len(sense_info) >= 2:
            sense_name, confidence = str(sense_info[0]), float(sense_info[1])
        else:
            sense_name, confidence = "unknown", 0.5

        evidence_tokens = evidence.get("evidence_tokens", [])
        evidence_str = ", ".join([str(tok).replace("‚ñÅ", "") for tok in evidence_tokens[:_TRG_EVIDENCE_K]]) or "limited context"

        alternatives = evidence.get("alternatives", [])
        alternatives_text = ""
        if isinstance(alternatives, list) and len(alternatives) > 0:
            alt_parts = []
            for alt in alternatives[:2]:
                if isinstance(alt, (tuple, list)) and len(alt) >= 2:
                    alt_name, alt_conf = str(alt[0]), float(alt[1])
                    alt_parts.append(f"'{alt_name}' ({alt_conf:.1%})")
            if alt_parts:
                alternatives_text = f"Alternatives: {', '.join(alt_parts)} considered."

        if confidence >= 0.65:
            template_key = "high_confidence"
        elif confidence >= 0.4:
            template_key = "medium_confidence"
        else:
            template_key = "low_confidence"

        template = self.explanation_templates.get(template_key, self.explanation_templates["fallback"])

        try:
            return template.format(
                sense=sense_name,
                confidence=confidence,
                evidence=evidence_str,
                alternatives_text=alternatives_text,
                token=token,
            )
        except Exception:
            return f"Token '{token}' disambiguated as '{sense_name}' ({confidence:.1%})."


class MemoryEfficientTRGExtractor:
    """Extracts evidence around a token for explanation rendering and handles aggregation."""

    def __init__(self, tokenizer=None, language="bn"):
        self.tokenizer = tokenizer
        self.language = language

        # special tokens retrieval robustly
        if tokenizer is not None:
            try:
                if _has_get_tokenizer_special_tokens:
                    self.special_tokens = globals()["get_tokenizer_special_tokens"](tokenizer)
                elif _has_get_cached_special_tokens:
                    self.special_tokens = globals()["get_cached_special_tokens"](tokenizer)
                else:
                    self.special_tokens = set(getattr(tokenizer, "all_special_tokens", []))
            except Exception:
                self.special_tokens = set()
        else:
            self.special_tokens = set()

    # -------------------------
    # HIGH-LEVEL: evidence extraction with aggregation across subword pieces
    # -------------------------
    def extract_evidence_efficiently(
        self,
        token_idx: int,
        tokens: List[str],
        dscd_outputs: Dict,
        token_word_map: Optional[dict] = None,
    ) -> Dict:
        """Extract evidence safely and aggregate subword pieces using normalized token_word_map when available."""
        if not isinstance(tokens, list) or token_idx < 0 or token_idx >= len(tokens):
            return self._create_fallback_evidence(token_idx, tokens or [])

        raw_token = tokens[token_idx]

        # Basic validity check
        if _has_is_valid_token:
            try:
                is_valid = is_valid_token(raw_token, self.special_tokens, self.tokenizer, language=self.language)
            except Exception:
                is_valid = raw_token not in self.special_tokens and len(str(raw_token)) >= 2
        else:
            is_valid = raw_token not in self.special_tokens and len(str(raw_token)) >= 2

        if not is_valid:
            return self._create_fallback_evidence(token_idx, tokens)

        try:
            # Determine indices to aggregate: prefer token_word_map['norm'] equality if available
            agg_indices = [token_idx]
            try:
                if token_word_map and isinstance(token_word_map, dict):
                    # support two common formats:
                    # 1) token_word_map is mapping idx -> reconstructed word
                    # 2) token_word_map is dict with 'orig' and 'norm' maps
                    norm_map = token_word_map.get("norm") if isinstance(token_word_map.get("norm", None), dict) else None
                    orig_map = token_word_map.get("orig") if isinstance(token_word_map.get("orig", None), dict) else token_word_map if isinstance(token_word_map, dict) else None

                    if norm_map and token_idx in norm_map and norm_map[token_idx]:
                        main_norm = norm_map[token_idx]
                        for k, v in norm_map.items():
                            try:
                                if k != token_idx and v == main_norm:
                                    agg_indices.append(k)
                            except Exception:
                                continue
                    elif orig_map and token_idx in orig_map and orig_map[token_idx]:
                        # try normalized comparator if normalization fn available
                        try:
                            if _normalize_fn:
                                main_norm = _normalize_fn(orig_map[token_idx])
                                for k, v in orig_map.items():
                                    try:
                                        if k != token_idx and isinstance(v, str) and _normalize_fn(v) == main_norm:
                                            agg_indices.append(k)
                                    except Exception:
                                        continue
                            else:
                                # fallback: aggregate adjacent subword pieces heuristically
                                if token_idx - 1 >= 0 and not str(tokens[token_idx - 1]).startswith('‚ñÅ'):
                                    agg_indices.insert(0, token_idx - 1)
                                if token_idx + 1 < len(tokens) and not str(tokens[token_idx + 1]).startswith('‚ñÅ'):
                                    agg_indices.append(token_idx + 1)
                        except Exception:
                            pass
                    else:
                        # heuristic: include neighbors that are continuation pieces (do not start with word-start marker)
                        if token_idx - 1 >= 0 and not str(tokens[token_idx - 1]).startswith('‚ñÅ') and not str(tokens[token_idx - 1]).startswith('ƒ†'):
                            agg_indices.insert(0, token_idx - 1)
                        if token_idx + 1 < len(tokens) and not str(tokens[token_idx + 1]).startswith('‚ñÅ') and not str(tokens[token_idx + 1]).startswith('ƒ†'):
                            agg_indices.append(token_idx + 1)
            except Exception:
                # leave agg_indices as [token_idx]
                agg_indices = [token_idx]

            # Collect proto_probs, uncertainties, gates, spans for aggregated indices
            proto_tensors = []
            uncerts = []
            gates = []
            spans = []
            for i in agg_indices:
                p = self._safe_extract_proto_probs(i, dscd_outputs)  # torch tensor
                if not isinstance(p, torch.Tensor):
                    try:
                        p = torch.as_tensor(np.asarray(p, dtype=np.float32), dtype=torch.float32)
                    except Exception:
                        p = torch.tensor([1.0], dtype=torch.float32)
                proto_tensors.append(p.flatten())

                u = self._safe_extract_uncertainty(i, dscd_outputs)
                g = self._safe_extract_gate(i, dscd_outputs)
                s = self._safe_extract_span(i, dscd_outputs)
                try:
                    uncerts.append(float(u))
                except Exception:
                    uncerts.append(0.5)
                try:
                    gates.append(float(g))
                except Exception:
                    gates.append(0.0)
                try:
                    spans.append(float(s))
                except Exception:
                    spans.append(0.0)

            # Pad proto vectors to same length and average
            maxk = max([int(p.numel()) for p in proto_tensors]) if proto_tensors else 1
            padded = []
            for p in proto_tensors:
                if int(p.numel()) < maxk:
                    p2 = torch.zeros(maxk, dtype=torch.float32)
                    p2[:p.numel()] = p
                    padded.append(p2)
                else:
                    padded.append(p[:maxk])
            stacked = torch.stack(padded, dim=0)
            agg_proto = torch.mean(stacked, dim=0)
            # normalize if sum > 0
            try:
                ssum = float(torch.sum(agg_proto).item())
                if ssum > 0:
                    agg_proto = agg_proto / (ssum + 1e-12)
            except Exception:
                pass

            agg_uncert = float(sum(uncerts) / len(uncerts)) if uncerts else 0.5
            agg_gate = float(sum(gates) / len(gates)) if gates else 0.0
            agg_span = float(sum(spans) / len(spans)) if spans else 0.0

            # Build evidence tokens (context window)
            context_window = 2
            start_idx = max(0, token_idx - context_window)
            end_idx = min(len(tokens), token_idx + context_window + 1)
            evidence_tokens = []
            for i in range(start_idx, end_idx):
                if i == token_idx or i >= len(tokens):
                    continue
                rtok = tokens[i]
                clean_token = str(rtok).replace("‚ñÅ", "").replace("</w>", "").strip()
                # require word-start or token_word_map entry
                if not _is_word_start(rtok, token_word_map, i):
                    if token_word_map is None and len(clean_token) >= 2:
                        pass
                    else:
                        continue
                # validity
                if _has_is_valid_token:
                    try:
                        ok = is_valid_token(rtok, self.special_tokens, self.tokenizer, language=self.language)
                    except Exception:
                        ok = rtok not in self.special_tokens and len(clean_token) > 1
                else:
                    ok = rtok not in self.special_tokens and len(clean_token) > 1
                if not ok:
                    continue
                # prefer normalized/orig mapping when available
                chosen = None
                if token_word_map and isinstance(token_word_map, dict):
                    # try norm -> orig
                    try:
                        if isinstance(token_word_map.get("norm", None), dict) and i in token_word_map["norm"]:
                            chosen = token_word_map["norm"][i]
                        elif isinstance(token_word_map.get("orig", None), dict) and i in token_word_map["orig"]:
                            chosen = token_word_map["orig"][i]
                        elif i in token_word_map:
                            chosen = token_word_map[i]
                    except Exception:
                        chosen = None
                if chosen and isinstance(chosen, str) and chosen.strip():
                    evidence_tokens.append(chosen.strip())
                else:
                    evidence_tokens.append(clean_token)

            # dedupe & trim
            seen = set()
            dedup = []
            for t in evidence_tokens:
                if t not in seen:
                    seen.add(t)
                    dedup.append(t)
            evidence_tokens = dedup[:_TRG_EVIDENCE_K]

            # compute top senses
            top_senses = self._compute_sense_alternatives_fast(agg_proto)
            chosen_sense = top_senses[0] if len(top_senses) > 0 else ("unknown", 0.5)
            alternatives = top_senses[1:3] if len(top_senses) > 1 else []

            # token value prefer normalized form if available
            token_value = raw_token
            try:
                if token_word_map and isinstance(token_word_map, dict):
                    if isinstance(token_word_map.get("norm", None), dict) and token_idx in token_word_map["norm"]:
                        token_value = token_word_map["norm"][token_idx]
                    elif isinstance(token_word_map.get("orig", None), dict) and token_idx in token_word_map["orig"]:
                        token_value = token_word_map["orig"][token_idx]
                    elif token_idx in token_word_map:
                        token_value = token_word_map[token_idx]
                # fall back to normalize_fn if available
                if (_normalize_fn is not None) and (isinstance(token_value, str) and token_value.strip()):
                    token_value = _normalize_fn(token_value)
            except Exception:
                pass

            return {
                "token": token_value,
                "token_idx": token_idx,
                "evidence_tokens": evidence_tokens,
                "chosen_sense": chosen_sense,
                "alternatives": alternatives,
                "uncertainty": float(agg_uncert),
                "gate": float(agg_gate),
                "span": float(agg_span),
                "proto_probs": agg_proto,
            }
        except Exception as e:
            if _VERBOSE_LOGGING:
                import traceback as _tb
                print(f"[TRG] evidence extraction error at token {token_idx}: {_tb.format_exc().splitlines()[-1]}")
            return self._create_fallback_evidence(token_idx, tokens)

    # -------------------------
    # SAFE EXTRACTORS (robust shapes)
    # -------------------------
    def _safe_extract_proto_probs(self, token_idx: int, dscd_outputs: Dict) -> torch.Tensor:
        """
        Robust extraction of prototype probabilities for a single token.
        Returns 1D torch.Tensor.
        """
        try:
            pp_all = dscd_outputs.get("proto_probs", None) if isinstance(dscd_outputs, dict) else None
            if pp_all is None:
                return torch.tensor([1.0], dtype=torch.float32)

            # Torch tensor cases
            if isinstance(pp_all, torch.Tensor):
                p = pp_all.detach().cpu()
                if p.dim() == 3:
                    # B, T, K -> choose batch 0 if present
                    B, T, K = p.shape
                    if token_idx < T:
                        return p[0, token_idx, :].float()
                    else:
                        return p[0].max(dim=1)[0].float()
                elif p.dim() == 2:
                    # T,K or B,T (ambiguous). Heuristics:
                    if p.shape[0] > 1 and token_idx < p.size(0):
                        return p[token_idx, :].float()
                    elif token_idx < p.size(0):
                        return p[token_idx, :].float()
                    else:
                        return p.max(dim=1)[0].float()
                elif p.dim() == 1:
                    return p.float()
                else:
                    return torch.tensor([1.0], dtype=torch.float32)

            # numpy
            if isinstance(pp_all, np.ndarray):
                arr = pp_all
                if arr.ndim == 3:
                    if token_idx < arr.shape[1]:
                        return torch.from_numpy(arr[0, token_idx, :].astype(np.float32))
                elif arr.ndim == 2:
                    if token_idx < arr.shape[0]:
                        return torch.from_numpy(arr[token_idx].astype(np.float32))
                elif arr.ndim == 1:
                    return torch.from_numpy(arr.astype(np.float32))
                return torch.tensor([1.0], dtype=torch.float32)

            # list/tuple: many shapes possible
            if isinstance(pp_all, (list, tuple)):
                # If it's a batch list
                if len(pp_all) > 0 and isinstance(pp_all[0], (list, tuple, np.ndarray, torch.Tensor)):
                    first = pp_all[0]
                    if isinstance(first, torch.Tensor):
                        row = first.detach().cpu()
                        if row.dim() == 2 and token_idx < row.size(0):
                            return row[token_idx, :].float()
                        elif row.dim() == 1:
                            return row.float()
                    elif isinstance(first, np.ndarray):
                        if first.ndim >= 1 and token_idx < first.shape[0]:
                            return torch.from_numpy(first[token_idx].astype(np.float32))
                    elif isinstance(first, (list, tuple)):
                        # assume per-token arrays in first
                        if token_idx < len(first):
                            val = first[token_idx]
                            return torch.as_tensor(np.asarray(val, dtype=np.float32), dtype=torch.float32)
                    # fallback: if outer list length matches token count, use that
                    if token_idx < len(pp_all):
                        val = pp_all[token_idx]
                        if isinstance(val, torch.Tensor):
                            return val.detach().cpu().float()
                        else:
                            return torch.as_tensor(np.asarray(val, dtype=np.float32), dtype=torch.float32)
                # otherwise try flattening first element
                if token_idx < len(pp_all):
                    val = pp_all[token_idx]
                    if isinstance(val, torch.Tensor):
                        return val.detach().cpu().float()
                    else:
                        return torch.as_tensor(np.asarray(val, dtype=np.float32), dtype=torch.float32)

            # unknown -> fallback
            return torch.tensor([1.0], dtype=torch.float32)
        except Exception:
            if _VERBOSE_LOGGING:
                import traceback as _tb
                print("[TRG] _safe_extract_proto_probs failed:", _tb.format_exc().splitlines()[-1])
            return torch.tensor([1.0], dtype=torch.float32)

    def _safe_extract_uncertainty(self, token_idx: int, dscd_outputs: Dict) -> float:
        try:
            U_all = dscd_outputs.get("uncertainties", None) if isinstance(dscd_outputs, dict) else None
            if U_all is None:
                return 0.5
            # tensor
            if isinstance(U_all, torch.Tensor):
                u = U_all.detach().cpu()
                if u.dim() == 2:
                    if token_idx < u.size(1):
                        return float(u[0, token_idx].item())
                    elif token_idx < u.size(0):
                        return float(u[token_idx].item())
                elif u.dim() == 1:
                    if token_idx < u.size(0):
                        return float(u[token_idx].item())
                return 0.5
            if isinstance(U_all, np.ndarray):
                if U_all.ndim >= 1 and token_idx < U_all.shape[0]:
                    return float(U_all[token_idx])
                return 0.5
            if isinstance(U_all, (list, tuple)):
                # try batch->row
                first = U_all[0] if len(U_all) > 0 else None
                if isinstance(first, (list, tuple, np.ndarray, torch.Tensor)):
                    row = first
                    if isinstance(row, torch.Tensor):
                        if row.dim() >= 1 and token_idx < row.size(0):
                            return float(row[token_idx].item())
                    elif isinstance(row, (list, tuple, np.ndarray)):
                        if token_idx < len(row):
                            return float(row[token_idx])
                if token_idx < len(U_all):
                    v = U_all[token_idx]
                    return float(v.item()) if isinstance(v, torch.Tensor) else float(v)
            return 0.5
        except Exception:
            if _VERBOSE_LOGGING:
                import traceback as _tb
                print("[TRG] _safe_extract_uncertainty failed:", _tb.format_exc().splitlines()[-1])
            return 0.5

    def _safe_extract_gate(self, token_idx: int, dscd_outputs: Dict) -> float:
        try:
            G_all = dscd_outputs.get("gates", None) if isinstance(dscd_outputs, dict) else None
            if G_all is None:
                return 0.0
            if isinstance(G_all, torch.Tensor):
                g = G_all.detach().cpu()
                if g.dim() == 2:
                    if token_idx < g.size(1):
                        return float(g[0, token_idx].item())
                    elif token_idx < g.size(0):
                        return float(g[token_idx].item())
                elif g.dim() == 1:
                    if token_idx < g.size(0):
                        return float(g[token_idx].item())
                return 0.0
            if isinstance(G_all, np.ndarray):
                if G_all.ndim >= 1 and token_idx < G_all.shape[0]:
                    return float(G_all[token_idx])
                return 0.0
            if isinstance(G_all, (list, tuple)):
                first = G_all[0] if len(G_all) > 0 else None
                if isinstance(first, (list, tuple, np.ndarray, torch.Tensor)):
                    row = first
                    if isinstance(row, torch.Tensor):
                        if row.dim() >= 1 and token_idx < row.size(0):
                            return float(row[token_idx].item())
                    elif isinstance(row, (list, tuple, np.ndarray)):
                        if token_idx < len(row):
                            return float(row[token_idx])
                if token_idx < len(G_all):
                    v = G_all[token_idx]
                    return float(v.item()) if isinstance(v, torch.Tensor) else float(v)
            return 0.0
        except Exception:
            if _VERBOSE_LOGGING:
                import traceback as _tb
                print("[TRG] _safe_extract_gate failed:", _tb.format_exc().splitlines()[-1])
            return 0.0

    def _safe_extract_span(self, token_idx: int, dscd_outputs: Dict) -> float:
        try:
            S_all = dscd_outputs.get("span_preds", None) if isinstance(dscd_outputs, dict) else None
            if S_all is None:
                return 0.0
            if isinstance(S_all, torch.Tensor):
                s = S_all.detach().cpu()
                if s.dim() == 2:
                    if token_idx < s.size(1):
                        return float(s[0, token_idx].item())
                    elif token_idx < s.size(0):
                        return float(s[token_idx].item())
                elif s.dim() == 1 and token_idx < s.size(0):
                    return float(s[token_idx].item())
                return 0.0
            if isinstance(S_all, np.ndarray):
                if S_all.ndim >= 1 and token_idx < S_all.shape[0]:
                    return float(S_all[token_idx])
                return 0.0
            if isinstance(S_all, (list, tuple)):
                first = S_all[0] if len(S_all) > 0 else None
                if isinstance(first, torch.Tensor):
                    row = first
                    if row.dim() >= 1 and token_idx < row.size(0):
                        return float(row[token_idx].item())
                elif isinstance(first, (list, tuple, np.ndarray)):
                    row = first
                    if token_idx < len(row):
                        return float(row[token_idx])
                if token_idx < len(S_all):
                    v = S_all[token_idx]
                    return float(v.item()) if isinstance(v, torch.Tensor) else float(v)
            return 0.0
        except Exception:
            if _VERBOSE_LOGGING:
                import traceback as _tb
                print("[TRG] _safe_extract_span failed:", _tb.format_exc().splitlines()[-1])
            return 0.0

    # -------------------------
    # SENSE / UTIL
    # -------------------------
    def compute_span(self, sense_probs: Any) -> float:
        try:
            if isinstance(sense_probs, dict):
                probs = list(sense_probs.values())
            else:
                probs = sense_probs

            if isinstance(probs, torch.Tensor):
                probs = probs.detach().cpu().numpy().flatten().tolist()
            elif isinstance(probs, np.ndarray):
                probs = probs.flatten().tolist()
            elif isinstance(probs, (list, tuple)):
                probs = list(probs)
            else:
                return 0.0

            if len(probs) < 2:
                return 0.0
            sorted_probs = sorted([float(x) for x in probs], reverse=True)
            span = float(sorted_probs[0]) - float(sorted_probs[1])
            return max(0.0, span)
        except Exception:
            if _VERBOSE_LOGGING:
                print("[TRG] compute_span error")
            return 0.0

    def _compute_sense_alternatives_fast(self, proto_probs: torch.Tensor) -> List[Tuple[str, float]]:
        try:
            if not isinstance(proto_probs, torch.Tensor):
                proto_probs = torch.as_tensor(np.asarray(proto_probs, dtype=np.float32))
            probs = proto_probs.flatten()
            if probs.numel() > 1:
                probs_sorted, indices = torch.sort(probs, descending=True)
                top_k = min(3, int(indices.numel()))
                return [(f"sense_{int(indices[i].item())}", float(probs_sorted[i].item())) for i in range(top_k)]
            else:
                return [("sense_0", float(probs[0].item()))]
        except Exception:
            return [("unknown", 0.5)]

    def _create_fallback_evidence(self, token_idx: int, tokens: List[str]) -> Dict:
        token = tokens[token_idx] if isinstance(tokens, list) and 0 <= token_idx < len(tokens) else "UNK"
        return {
            "token": token,
            "token_idx": token_idx,
            "evidence_tokens": [],
            "chosen_sense": ("unknown", 0.5),
            "alternatives": [],
            "uncertainty": 0.5,
            "gate": 0.0,
            "span": 0.0,
            "proto_probs": torch.tensor([1.0], dtype=torch.float32),
        }


class CompleteTRGWithExplanations(nn.Module):
    """
    Inference-only disambiguation and explanation component.
    """

    def __init__(self, embed_dim: Optional[int] = None, tokenizer=None, language: str = "bn"):
        super().__init__()
        self.embed_dim = int(embed_dim) if embed_dim is not None else int(_TRG_GEN_EMBED)
        self.tokenizer = tokenizer
        self.language = language

        # Cache special tokens robustly
        if tokenizer is not None:
            try:
                if _has_get_tokenizer_special_tokens:
                    self.special_tokens = globals()["get_tokenizer_special_tokens"](tokenizer)
                elif _has_get_cached_special_tokens:
                    self.special_tokens = globals()["get_cached_special_tokens"](tokenizer)
                else:
                    self.special_tokens = set(getattr(tokenizer, "all_special_tokens", []))
            except Exception:
                self.special_tokens = set()
        else:
            self.special_tokens = set()

        self.template_system = ComprehensiveTRGExplanationTemplate()
        self.evidence_extractor = MemoryEfficientTRGExtractor(tokenizer, language=language)

        self.silver_buffer = deque(maxlen=int(_MAX_SILVER_BUFFER))
        self.stats = {
            "explanations_generated": 0,
            "high_confidence_explanations": 0,
            "low_confidence_explanations": 0,
        }

        if _VERBOSE_LOGGING:
            print("[TRG] system initialized (inference-only, multi-GPU compatible)")

    def generate_explanation_for_token(
        self,
        token_idx: int,
        tokens: List[str],
        dscd_outputs: Dict,
        token_word_map: Optional[dict] = None,
    ) -> Tuple[str, Dict]:
        """Generate an explanation string and its evidence for a single token."""
        # Feature flag must be enabled
        if not _ENABLE_TRG_INFERENCE:
            return "", {}

        if not isinstance(tokens, list) or token_idx < 0 or token_idx >= len(tokens):
            return "", {}

        raw_token = tokens[token_idx]
        if _has_is_valid_token:
            try:
                is_valid = is_valid_token(raw_token, self.special_tokens, self.tokenizer, language=self.language)
            except Exception:
                is_valid = raw_token not in self.special_tokens and len(str(raw_token)) >= 2
        else:
            is_valid = raw_token not in self.special_tokens and len(str(raw_token)) >= 2

        if not is_valid:
            return "", {}

        try:
            evidence = self.evidence_extractor.extract_evidence_efficiently(token_idx, tokens, dscd_outputs, token_word_map=token_word_map)
            if not evidence:
                return "", {}
            explanation_text = self.template_system.generate_explanation(evidence)
            self._update_stats(evidence)
            self._add_to_silver_buffer(evidence, explanation_text, tokens)
            return explanation_text, evidence
        except Exception as e:
            if _VERBOSE_LOGGING:
                import traceback as _tb
                print(f"[TRG] generate_explanation error at token {token_idx}: {_tb.format_exc().splitlines()[-1]}")
            return "", {}

    def process_sentence_for_explanations(
        self,
        tokens: List[str],
        dscd_outputs: Dict,
        token_word_map: Optional[dict] = None,
        uncertainty_threshold: Optional[float] = None,
        top_k: int = 3,
    ) -> List[Dict]:
        """Select up to top_k tokens and generate explanations for them."""
        if not _ENABLE_TRG_INFERENCE:
            return []

        if uncertainty_threshold is None:
            uncertainty_threshold = float(_TRG_UNCERTAINTY_THRESHOLD)

        strict_uncertainty = max(0.40, uncertainty_threshold)

        explanations: List[Dict] = []
        try:
            if not tokens or not isinstance(dscd_outputs, dict):
                return explanations

            U_all = dscd_outputs.get("uncertainties", [])
            S_all = dscd_outputs.get("span_preds", [])

            # Normalize U and S to python lists safely
            def _to_list_safe(x):
                if isinstance(x, torch.Tensor):
                    x = x.detach().cpu()
                    if x.dim() == 2:
                        return [float(v) for v in x[0].tolist()]
                    elif x.dim() == 1:
                        return [float(v) for v in x.tolist()]
                    else:
                        return []
                if isinstance(x, (list, tuple, np.ndarray)):
                    out = []
                    for v in x:
                        if isinstance(v, torch.Tensor):
                            out.append(float(v.item()))
                        else:
                            try:
                                out.append(float(v))
                            except Exception:
                                out.append(0.0)
                    return out
                return []

            U = _to_list_safe(U_all[0]) if (isinstance(U_all, (list, tuple)) and len(U_all) > 0 and (isinstance(U_all[0], (list, tuple, torch.Tensor, np.ndarray)))) else _to_list_safe(U_all)
            S = _to_list_safe(S_all[0]) if (isinstance(S_all, (list, tuple)) and len(S_all) > 0 and (isinstance(S_all[0], (list, tuple, torch.Tensor, np.ndarray)))) else _to_list_safe(S_all)
            if not U:
                return explanations

            # Collect candidates
            candidates: List[Tuple[int, float, float]] = []
            for idx in range(min(len(tokens), len(U))):
                tok = tokens[idx]

                if not _is_word_start(tok, token_word_map, idx):
                    if token_word_map is None:
                        if not isinstance(tok, str) or len(tok.replace("‚ñÅ", "").replace("ƒ†", "")) < 2:
                            continue
                    else:
                        continue

                if _has_is_valid_token:
                    try:
                        valid = is_valid_token(tok, self.special_tokens, self.tokenizer, language=self.language)
                    except Exception:
                        valid = tok not in self.special_tokens and len(str(tok)) >= 2
                else:
                    valid = tok not in self.special_tokens and len(str(tok)) >= 2
                if not valid:
                    continue

                u = float(U[idx]) if idx < len(U) else 0.5
                s = float(S[idx]) if idx < len(S) else 0.0

                probs = self.evidence_extractor._safe_extract_proto_probs(idx, dscd_outputs)
                has_multi_sense = isinstance(probs, torch.Tensor) and probs.numel() >= 2

                if not (has_multi_sense or (s > 0.3) or (u > strict_uncertainty)):
                    continue

                candidates.append((idx, u, s))

            if not candidates:
                return explanations

            # Prioritize and select
            span_first = [c for c in candidates if c[2] > 0.3]
            span_first.sort(key=lambda t: (t[2], t[1]), reverse=True)

            uncertain = [c for c in candidates if c[1] > strict_uncertainty]
            uncertain.sort(key=lambda t: t[1], reverse=True)

            selected = []
            selected.extend(span_first)
            for t in uncertain:
                if t not in selected:
                    selected.append(t)
                if len(selected) >= top_k:
                    break

            if not selected and candidates:
                candidates.sort(key=lambda t: (t[2], t[1]), reverse=True)
                selected = candidates[:max(1, top_k)]

            for (token_idx, u, s) in selected[:top_k]:
                try:
                    explanation_text, evidence = self.generate_explanation_for_token(token_idx, tokens, dscd_outputs, token_word_map=token_word_map)
                    if explanation_text and evidence:
                        token_label = token_word_map[token_idx] if (token_word_map and isinstance(token_word_map, dict) and token_idx in token_word_map) else tokens[token_idx].replace("‚ñÅ", "")
                        explanations.append({
                            "token_idx": token_idx,
                            "token": token_label,
                            "explanation": explanation_text,
                            "uncertainty": u,
                            "span": s,
                        })
                except Exception as e:
                    if _VERBOSE_LOGGING:
                        import traceback as _tb
                        print(f"[TRG] explanation generation failure @ idx {token_idx}: {_tb.format_exc().splitlines()[-1]}")
                    continue

        except Exception as e:
            if _VERBOSE_LOGGING:
                import traceback as _tb
                print(f"[TRG] sentence processing error: {_tb.format_exc().splitlines()[-1]}")

        return explanations

    def _update_stats(self, evidence: Dict):
        try:
            self.stats["explanations_generated"] += 1
            confidence = 0.5
            chosen = evidence.get("chosen_sense")
            if isinstance(chosen, (tuple, list)) and len(chosen) >= 2:
                try:
                    confidence = float(chosen[1])
                except Exception:
                    confidence = 0.5

            if confidence >= 0.65:
                self.stats["high_confidence_explanations"] += 1
            elif confidence < 0.4:
                self.stats["low_confidence_explanations"] += 1
        except Exception:
            pass

    def _add_to_silver_buffer(self, evidence: Dict, explanation: str, tokens: List[str]):
        try:
            conf = 0.5
            chosen = evidence.get("chosen_sense")
            if isinstance(chosen, (tuple, list)) and len(chosen) >= 2:
                conf = float(chosen[1])
            entry = {
                "token": str(evidence.get("token", "UNK"))[:20],
                "explanation": str(explanation)[:150],
                "confidence": conf,
            }
            self.silver_buffer.append(entry)
        except Exception:
            pass

    def get_statistics(self) -> Dict:
        total = max(self.stats.get("explanations_generated", 0), 1)
        return {
            **self.stats,
            "high_confidence_rate": self.stats.get("high_confidence_explanations", 0) / total,
            "low_confidence_rate": self.stats.get("low_confidence_explanations", 0) / total,
            "silver_buffer_size": len(self.silver_buffer),
        }


print("‚úÖ Cell 5 (patched): TRG explanation system ready (robust aggregation & safe proto handling)")

‚úÖ Cell 5 (patched): TRG explanation system ready (robust aggregation & safe proto handling)


In [10]:
# ==============================================================================
# CELL 6 (fixed): ‚ö° OPTIMIZED TATN MODEL WITH GRADIENT CHECKPOINTING (M2M100 418M)
# - Thorough line-by-line hardening and practical fixes
# - Integrates bn_normalizer usage for normalized token_word_map
# - Ensures TRG instance is placed into eval() mode by default for inference
# - Makes DSCD/ASBN/TRG instantiation robust and non-fatal
# - Device-safe handling of all intermediate tensors and shapes
# ==============================================================================
from typing import List, Dict, Optional, Any
import traceback
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Attempt to import transformers model class (best-effort).
try:
    from transformers import M2M100ForConditionalGeneration
    from transformers.modeling_outputs import BaseModelOutput
    _HAS_TRANSFORMERS = True
except Exception:
    M2M100ForConditionalGeneration = None
    BaseModelOutput = None
    _HAS_TRANSFORMERS = False

# -----------------------------------------------------------------------------
# Defensive global fallback helpers
# -----------------------------------------------------------------------------
def _get_int_global(name: str, default: int) -> int:
    try:
        v = globals().get(name, default)
        return int(v) if v is not None else default
    except Exception:
        return default

def _get_float_global(name: str, default: float) -> float:
    try:
        v = globals().get(name, default)
        return float(v) if v is not None else default
    except Exception:
        return default

def _get_bool_global(name: str, default: bool) -> bool:
    try:
        v = globals().get(name, default)
        return bool(v)
    except Exception:
        return default

# Read globals safely (Cell 0 may not have run)
_DSCD_BUFFER_SIZE = _get_int_global('DSCD_BUFFER_SIZE', 20)
_DSCD_MAX_PROTOS = _get_int_global('DSCD_MAX_PROTOS', 8)
_DSCD_N_MIN = _get_int_global('DSCD_N_MIN', 5)
_DSCD_DISPERSION_THRESHOLD = _get_float_global('DSCD_DISPERSION_THRESHOLD', 0.25)
_SOURCE_LANGUAGE = globals().get('SOURCE_LANGUAGE', 'bn')
_ENABLE_ASBN_TRAINING = _get_bool_global('ENABLE_ASBN_TRAINING', True)
_ENABLE_TRG_INFERENCE = _get_bool_global('ENABLE_TRG_INFERENCE', True)
_MEMORY_CLEANUP_FREQUENCY = _get_int_global('MEMORY_CLEANUP_FREQUENCY', 100)
_NUM_GPUS = _get_int_global('NUM_GPUS', torch.cuda.device_count() if torch.cuda.is_available() else 0)
_USE_GC = _get_bool_global('GRADIENT_CHECKPOINTING', False)
_DSCD_ENABLE_TRAINING_CLUSTERING = _get_bool_global('DSCD_ENABLE_TRAINING_CLUSTERING', False)
_LAMBDA_ASBN = _get_float_global('LAMBDA_ASBN', 0.10)
_LAMBDA_DSCD = _get_float_global('LAMBDA_DSCD', 0.05)
_VERBOSE_LOGGING = _get_bool_global('VERBOSE_LOGGING', False)

_has_reconstruct_word_spans = 'reconstruct_word_spans' in globals()
_normalize_fn = globals().get("normalize_bn_word", None)  # bn_normalizer cell (may be present)

# -----------------------------------------------------------------------------
# Safe helper to obtain last hidden state from various HF encoder outputs
# -----------------------------------------------------------------------------
def _safe_get_last_hidden_state(enc_output: Any) -> Optional[torch.Tensor]:
    try:
        if enc_output is None:
            return None
        # HuggingFace BaseModelOutput
        if hasattr(enc_output, 'last_hidden_state'):
            return enc_output.last_hidden_state
        # tuple/list like (last_hidden_state, ...)
        if isinstance(enc_output, (list, tuple)) and len(enc_output) > 0:
            cand = enc_output[0]
            if isinstance(cand, torch.Tensor):
                return cand
        # dict-like
        if isinstance(enc_output, dict) and 'last_hidden_state' in enc_output:
            return enc_output['last_hidden_state']
    except Exception:
        if _VERBOSE_LOGGING:
            print("[TATN] _safe_get_last_hidden_state error:", traceback.format_exc().splitlines()[-1])
    return None

# -----------------------------------------------------------------------------
# Normalize DSCD outputs into canonical, CPU/device-consistent structures
# -----------------------------------------------------------------------------
def _normalize_dscd_outputs(raw: Dict[str, Any],
                            batch_size: int,
                            seq_len: int,
                            device: torch.device,
                            embed_dim: int) -> Dict[str, Any]:
    """
    Defensive normalization of DSCD raw outputs into canonical forms:
      - proto_probs: List[List[Tensor]] indexed [B][T] (each entry 1D tensor)
      - uncertainties/gates/span_preds: List[List[Tensor]] [B][T] (scalars as 0-d/1-d tensors)
      - proto_assignments: List[Tensor] length B each [T] (long)
      - h_augmented: Tensor [B, T, H] or zeros fallback
    This function never raises; logs only when VERBOSE_LOGGING=True.
    """
    def _log(msg: str):
        if _VERBOSE_LOGGING:
            print("[DSCD-NORM]", msg)

    # defaults
    proto_probs = [[torch.tensor([1.0], dtype=torch.float32, device=device) for _ in range(seq_len)] for _ in range(batch_size)]
    uncertainties = [[torch.tensor(0.0, dtype=torch.float32, device=device) for _ in range(seq_len)] for _ in range(batch_size)]
    gates = [[torch.tensor(0.0, dtype=torch.float32, device=device) for _ in range(seq_len)] for _ in range(batch_size)]
    span_preds = [[torch.tensor(0.0, dtype=torch.float32, device=device) for _ in range(seq_len)] for _ in range(batch_size)]
    proto_assignments = [torch.zeros(seq_len, dtype=torch.long, device=device) for _ in range(batch_size)]
    h_aug = None

    try:
        if not isinstance(raw, dict):
            _log("raw DSCD output not a dict; using defaults")
            raw = {} if raw is None else dict(raw)

        # --- h_augmented
        h_raw = raw.get('h_augmented', None)
        if isinstance(h_raw, torch.Tensor):
            try:
                if h_raw.dim() == 3 and int(h_raw.size(0)) == batch_size and int(h_raw.size(1)) == seq_len:
                    h_aug = h_raw.to(device)
                else:
                    # coerce as much as possible
                    tmp = torch.zeros(batch_size, seq_len, embed_dim, device=device, dtype=h_raw.dtype)
                    max_b = min(batch_size, int(h_raw.size(0)))
                    for b in range(max_b):
                        row = h_raw[b]
                        if isinstance(row, torch.Tensor) and row.dim() >= 2:
                            L = min(seq_len, int(row.size(0)))
                            tmp[b, :L] = row[:L].to(device)
                    h_aug = tmp
            except Exception:
                _log("h_aug coercion from tensor failed; fallback to None")
                h_aug = None
        elif isinstance(h_raw, (list, tuple, np.ndarray)):
            try:
                stacked = []
                for b in range(min(batch_size, len(h_raw))):
                    row = h_raw[b]
                    if isinstance(row, torch.Tensor):
                        stacked.append(row.to(device))
                    else:
                        stacked.append(torch.as_tensor(row, device=device))
                if stacked:
                    tensor = torch.stack(stacked, dim=0)
                    if tensor.dim() == 3:
                        tmp = torch.zeros(batch_size, seq_len, embed_dim, device=device, dtype=tensor.dtype)
                        for b in range(min(batch_size, tensor.size(0))):
                            L = min(seq_len, int(tensor.size(1)))
                            tmp[b, :L] = tensor[b, :L]
                        h_aug = tmp
            except Exception:
                _log("h_aug list coercion failed; fallback to None")
                h_aug = None

        # --- proto_probs
        try:
            pp = raw.get('proto_probs', None)
            if pp is not None:
                def _to_tensor(v):
                    try:
                        if isinstance(v, torch.Tensor):
                            return v.detach().cpu().float()
                        else:
                            a = np.asarray(v, dtype=np.float32)
                            return torch.from_numpy(a).cpu().float()
                    except Exception:
                        return torch.tensor([1.0], dtype=torch.float32)

                if isinstance(pp, torch.Tensor):
                    p = pp.detach().cpu()
                    if p.dim() == 3:
                        B, T, K = p.shape
                        for b in range(min(batch_size, int(B))):
                            for t in range(min(seq_len, int(T))):
                                proto_probs[b][t] = _to_tensor(p[b, t])
                    elif p.dim() == 2:
                        # either [B, T] or [T, K]
                        if int(p.size(0)) == batch_size:
                            for b in range(batch_size):
                                for t in range(min(seq_len, int(p.size(1)))):
                                    proto_probs[b][t] = _to_tensor(p[b, t])
                        elif batch_size == 1:
                            for t in range(min(seq_len, int(p.size(0)))):
                                proto_probs[0][t] = _to_tensor(p[t])
                    elif p.dim() == 1:
                        for t in range(min(seq_len, int(p.size(0)))):
                            proto_probs[0][t] = _to_tensor(p[t])
                elif isinstance(pp, (list, tuple)):
                    if len(pp) == batch_size:
                        for b in range(batch_size):
                            row = pp[b]
                            if isinstance(row, torch.Tensor):
                                r = row.detach().cpu()
                                if r.dim() == 2:
                                    for t in range(min(seq_len, int(r.size(0)))):
                                        proto_probs[b][t] = _to_tensor(r[t])
                                elif r.dim() == 1:
                                    for t in range(min(seq_len, int(r.size(0)))):
                                        proto_probs[b][t] = _to_tensor(r[t])
                            else:
                                for t in range(min(seq_len, len(row))):
                                    proto_probs[b][t] = _to_tensor(row[t])
                    elif batch_size == 1:
                        for t in range(min(seq_len, len(pp))):
                            proto_probs[0][t] = _to_tensor(pp[t])
                    else:
                        for t in range(min(seq_len, len(pp))):
                            proto_probs[0][t] = _to_tensor(pp[t])
        except Exception as e:
            _log(f"proto_probs parsing failed: {e}")

        # --- uncertainties/gates/span_preds normalization helper
        def _normalize_scalar_matrix(key: str, target):
            try:
                val = raw.get(key, None)
                if val is None:
                    return
                if isinstance(val, torch.Tensor):
                    m = val.detach().cpu()
                    if m.dim() == 3 and int(m.size(0)) == batch_size:
                        for b in range(batch_size):
                            for t in range(min(seq_len, int(m.size(1)))):
                                target[b][t] = torch.tensor(float(m[b, t].item()), device=device)
                    elif m.dim() == 2:
                        if int(m.size(0)) == batch_size:
                            for b in range(batch_size):
                                for t in range(min(seq_len, int(m.size(1)))):
                                    target[b][t] = torch.tensor(float(m[b, t].item()), device=device)
                        elif batch_size == 1:
                            for t in range(min(seq_len, int(m.size(0)))):
                                target[0][t] = torch.tensor(float(m[t].item()), device=device)
                    elif m.dim() == 1 and batch_size == 1:
                        for t in range(min(seq_len, int(m.size(0)))):
                            target[0][t] = torch.tensor(float(m[t].item()), device=device)
                elif isinstance(val, (list, tuple, np.ndarray)):
                    if len(val) == batch_size:
                        for b in range(batch_size):
                            row = val[b]
                            if isinstance(row, torch.Tensor):
                                r = row.detach().cpu()
                                for t in range(min(seq_len, int(r.size(0)))):
                                    target[b][t] = torch.tensor(float(r[t].item()), device=device)
                            else:
                                for t in range(min(seq_len, len(row))):
                                    try:
                                        target[b][t] = torch.tensor(float(row[t]), device=device)
                                    except Exception:
                                        pass
                    elif batch_size == 1:
                        row = val
                        for t in range(min(seq_len, len(row))):
                            try:
                                target[0][t] = torch.tensor(float(row[t]), device=device)
                            except Exception:
                                pass
            except Exception as e:
                _log(f"{key} normalization failed: {e}")

        _normalize_scalar_matrix('uncertainties', uncertainties)
        _normalize_scalar_matrix('gates', gates)
        _normalize_scalar_matrix('span_preds', span_preds)

        # --- proto_assignments normalization
        try:
            pa = raw.get('proto_assignments', None)
            if pa is not None:
                if isinstance(pa, list) and len(pa) == batch_size:
                    for b in range(batch_size):
                        row = pa[b]
                        try:
                            if isinstance(row, torch.Tensor):
                                arr = row.detach().cpu().long().view(-1)
                                if arr.numel() < seq_len:
                                    pad = torch.zeros(seq_len - arr.numel(), dtype=torch.long, device=device)
                                    proto_assignments[b] = torch.cat([arr.to(device), pad], dim=0)
                                else:
                                    proto_assignments[b] = arr[:seq_len].to(device)
                            else:
                                arr = torch.as_tensor(row, dtype=torch.long, device=device).view(-1)
                                if arr.numel() < seq_len:
                                    pad = torch.zeros(seq_len - arr.numel(), dtype=torch.long, device=device)
                                    proto_assignments[b] = torch.cat([arr, pad], dim=0)
                                else:
                                    proto_assignments[b] = arr[:seq_len]
                        except Exception:
                            proto_assignments[b] = torch.zeros(seq_len, dtype=torch.long, device=device)
                elif isinstance(pa, torch.Tensor):
                    p = pa.detach().cpu().long()
                    if p.dim() == 2 and int(p.size(0)) == batch_size:
                        for b in range(batch_size):
                            arr = p[b].view(-1)
                            if arr.numel() < seq_len:
                                pad = torch.zeros(seq_len - arr.numel(), dtype=torch.long, device=device)
                                proto_assignments[b] = torch.cat([arr.to(device), pad], dim=0)
                            else:
                                proto_assignments[b] = arr[:seq_len].to(device)
                    elif p.dim() == 1 and batch_size == 1:
                        arr = p.view(-1)
                        if arr.numel() < seq_len:
                            pad = torch.zeros(seq_len - arr.numel(), dtype=torch.long, device=device)
                            proto_assignments[0] = torch.cat([arr.to(device), pad], dim=0)
                        else:
                            proto_assignments[0] = arr[:seq_len].to(device)
        except Exception as e:
            _log(f"proto_assignments parse failed: {e}")

    except Exception as outer:
        _log(f"overall normalization failure: {outer}")

    # final fallback for h_aug
    if h_aug is None:
        h_aug = torch.zeros(batch_size, seq_len, embed_dim, device=device, dtype=torch.float32)

    return {
        'proto_probs': proto_probs,
        'uncertainties': uncertainties,
        'gates': gates,
        'span_preds': span_preds,
        'proto_assignments': proto_assignments,
        'h_augmented': h_aug
    }

# -----------------------------------------------------------------------------
# Main model wrapper
# -----------------------------------------------------------------------------
class MemoryOptimizedTATNWithExplanations(nn.Module):
    def __init__(self, tokenizer):
        super().__init__()
        self.tokenizer = tokenizer
        self.global_step = 0

        # Attempt to load backbone model only if transformers available
        self.mbart = None
        if _HAS_TRANSFORMERS and M2M100ForConditionalGeneration is not None:
            try:
                if os.environ.get("SKIP_MODEL_LOAD", "0") != "1":
                    self.mbart = M2M100ForConditionalGeneration.from_pretrained(
                        "facebook/m2m100_418M",
                        torch_dtype=torch.float32,
                        use_cache=False
                    )
                    try:
                        self.mbart.config.use_cache = False
                    except Exception:
                        pass
                    try:
                        if _USE_GC and hasattr(self.mbart, "gradient_checkpointing_enable"):
                            self.mbart.gradient_checkpointing_enable()
                    except Exception:
                        pass
            except Exception:
                if _VERBOSE_LOGGING:
                    print("[TATN] Could not load M2M100 model:", traceback.format_exc().splitlines()[-1])
                self.mbart = None
        else:
            if _VERBOSE_LOGGING:
                print("[TATN] transformers or model class missing; running without backbone")

        # embed_dim fallback
        embed_dim = 512
        try:
            if self.mbart is not None:
                embed_dim = int(getattr(self.mbart.config, "d_model", embed_dim))
        except Exception:
            pass

        # Initialize DSCD (class must be defined in Cell 3)
        DSC_CLASS = globals().get('MemoryEfficientDSCDOnline', None)
        if callable(DSC_CLASS):
            try:
                self.dscd = DSC_CLASS(
                    embed_dim=embed_dim,
                    tokenizer=tokenizer,
                    buffer_size=_DSCD_BUFFER_SIZE,
                    max_protos=_DSCD_MAX_PROTOS,
                    n_min=_DSCD_N_MIN,
                    language=_SOURCE_LANGUAGE,
                    dispersion_threshold=_DSCD_DISPERSION_THRESHOLD,
                    enable_training_clustering=_DSCD_ENABLE_TRAINING_CLUSTERING,
                    max_clustering_points=500,
                    max_candidates_per_step=1
                )
            except Exception:
                if _VERBOSE_LOGGING:
                    print("[TATN] DSCD instantiation failed:", traceback.format_exc().splitlines()[-1])
                self.dscd = None
        else:
            self.dscd = None

        # ASBN: instantiate if available (Cell 4)
        ASBN_CLASS = globals().get('MemoryEfficientASBNModule', None)
        if callable(ASBN_CLASS):
            try:
                self.asbn = ASBN_CLASS(embed_dim, tokenizer, language=_SOURCE_LANGUAGE)
            except Exception:
                if _VERBOSE_LOGGING:
                    print("[TATN] ASBN instantiation failed:", traceback.format_exc().splitlines()[-1])
                self.asbn = None
        else:
            self.asbn = None

        # TRG system: instantiate if available (Cell 5)
        TRG_CLASS = globals().get('CompleteTRGWithExplanations', None)
        if callable(TRG_CLASS):
            try:
                self.trg_system = TRG_CLASS(embed_dim, tokenizer, language=_SOURCE_LANGUAGE)
                # ENSURE TRG is in eval mode for inference so it doesn't early-return
                try:
                    self.trg_system.eval()
                except Exception:
                    try:
                        self.trg_system.training = False
                    except Exception:
                        pass
            except Exception:
                if _VERBOSE_LOGGING:
                    print("[TATN] TRG instantiation failed:", traceback.format_exc().splitlines()[-1])
                self.trg_system = None
        else:
            self.trg_system = None

    # entropy regularizer helper
    @staticmethod
    def _entropy_reg_from_proto_probs_static(proto_probs_list, gates_list=None, min_gate=0.0):
        dev = None
        try:
            if isinstance(proto_probs_list, list):
                for row in proto_probs_list:
                    if isinstance(row, list):
                        for p in row:
                            if isinstance(p, torch.Tensor):
                                dev = p.device
                                break
                    if dev is not None:
                        break
        except Exception:
            pass
        if dev is None:
            dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        total = torch.tensor(0.0, device=dev)
        count = 0
        try:
            for b, row in enumerate(proto_probs_list or []):
                if not isinstance(row, list):
                    continue
                gl = gates_list[b] if (gates_list and b < len(gates_list)) else None
                for j, probs in enumerate(row):
                    try:
                        if not isinstance(probs, torch.Tensor) or probs.numel() == 0:
                            continue
                        if gl and j < len(gl):
                            if float(gl[j]) < min_gate:
                                continue
                        p = torch.clamp(probs.to(dev), 1e-8, 1.0)
                        H = -torch.sum(p * torch.log(p))
                        total = total + H
                        count += 1
                    except Exception:
                        continue
        except Exception:
            pass
        if count == 0:
            return torch.tensor(0.0, device=dev)
        return total / count

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        src_texts: Optional[List[str]] = None,
        token_word_map: Optional[List[dict]] = None,
        labels: Optional[torch.Tensor] = None,
    ):
        self.global_step += 1

        if input_ids is None or attention_mask is None:
            raise ValueError("input_ids and attention_mask cannot be None")
        if input_ids.dim() != 2 or attention_mask.dim() != 2:
            raise ValueError(f"Expected 2D tensors for input_ids/attention_mask, got {input_ids.shape}, {attention_mask.shape}")

        batch_size, seq_len = int(input_ids.size(0)), int(input_ids.size(1))
        device = input_ids.device

        # periodic GPU cleanup
        if torch.cuda.is_available() and (self.global_step % max(1, _MEMORY_CLEANUP_FREQUENCY) == 0):
            try:
                torch.cuda.empty_cache()
            except Exception:
                pass

        # Encoder forward
        enc_outputs = None
        try:
            if self.mbart is not None:
                # Prefer calling model.encoder or get_encoder if available
                try:
                    if hasattr(self.mbart, "get_encoder"):
                        enc = self.mbart.get_encoder()
                        enc_outputs = enc(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
                    elif hasattr(self.mbart, "model") and hasattr(self.mbart.model, "encoder"):
                        enc_outputs = self.mbart.model.encoder(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
                    else:
                        # as a last resort, run the full model and take its encoder output (costly)
                        full = self.mbart(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
                        enc_outputs = getattr(full, "encoder_last_hidden_state", None) or _safe_get_last_hidden_state(full)
                except Exception:
                    enc_outputs = None
            else:
                enc_outputs = None
        except Exception:
            enc_outputs = None
            if _VERBOSE_LOGGING:
                print("[TATN] Encoder forward failed:", traceback.format_exc().splitlines()[-1])

        h = _safe_get_last_hidden_state(enc_outputs)
        if h is None:
            try:
                if self.mbart is not None and hasattr(self.mbart, "get_input_embeddings"):
                    emb = self.mbart.get_input_embeddings()(input_ids).to(device)
                    h = emb
                else:
                    h = torch.zeros(batch_size, seq_len, 512, device=device)
            except Exception:
                h = torch.zeros(batch_size, seq_len, 512, device=device)

        embed_dim = int(h.size(-1))

        training_mode = (labels is not None and self.training)

        if token_word_map is None:
            token_word_map = [{} for _ in range(batch_size)]

        # DSCD forward
        raw_dscd = {}
        try:
            if self.dscd is not None:
                raw_dscd = self.dscd.forward(h, token_types=None, train_mode=self.training,
                                             input_ids=input_ids, attention_mask=attention_mask,
                                             token_word_map=token_word_map)
            else:
                raw_dscd = {}
        except Exception:
            if _VERBOSE_LOGGING:
                print("[TATN] DSCD forward failed; using fallback:", traceback.format_exc().splitlines()[-1])
            raw_dscd = {}

        # Normalize DSCD outputs into canonical structure
        dscd = _normalize_dscd_outputs(raw_dscd, batch_size, seq_len, device, embed_dim)
        h_aug = dscd.get('h_augmented', h)
        if not isinstance(h_aug, torch.Tensor) or h_aug.shape != h.shape:
            h_aug = h

        # embedding-based fallback for spans if DSCD did not provide meaningful spans
        try:
            span_missing = True
            for b in range(batch_size):
                row = dscd['span_preds'][b]
                if any(float(x) > 1e-6 for x in row):
                    span_missing = False
                    break
            if span_missing:
                norms = torch.norm(h_aug, dim=-1)
                for b in range(batch_size):
                    n = norms[b]
                    if n.numel() == 0 or torch.all(n == 0):
                        continue
                    mn = float(n.min().item())
                    mx = float(n.max().item())
                    rng = mx - mn + 1e-8
                    scaled = (n - mn) / rng
                    for t in range(min(seq_len, scaled.size(0))):
                        try:
                            dscd['span_preds'][b][t] = torch.tensor(float(scaled[t].item()), device=device)
                        except Exception:
                            pass
                if _VERBOSE_LOGGING:
                    print("[TATN] Applied embedding-norm fallback for span_preds.")
        except Exception:
            if _VERBOSE_LOGGING:
                print("[TATN] Span fallback error:", traceback.format_exc().splitlines()[-1])

        # TRAINING path: produce scalar loss
        if training_mode:
            try:
                enc_for_decoder = BaseModelOutput(last_hidden_state=h_aug) if BaseModelOutput is not None else (h_aug,)
            except Exception:
                enc_for_decoder = (h_aug,)

            translation_loss = torch.tensor(0.0, device=device)
            try:
                if self.mbart is not None:
                    seq_outputs = self.mbart(encoder_outputs=enc_for_decoder,
                                             attention_mask=attention_mask,
                                             labels=labels,
                                             use_cache=False,
                                             return_dict=True)
                    translation_loss = getattr(seq_outputs, 'loss', torch.tensor(0.0, device=device))
                else:
                    translation_loss = torch.tensor(0.0, device=device)
            except Exception:
                if _VERBOSE_LOGGING:
                    print("[TATN] Decoder forward failed during training:", traceback.format_exc().splitlines()[-1])
                translation_loss = torch.tensor(0.0, device=device)

            try:
                if self.asbn is not None:
                    asbn_ret = self.asbn.forward_with_grl_simplified(h_aug, dscd.get('proto_probs', None),
                                                                    dscd.get('uncertainties', None),
                                                                    dscd.get('gates', None),
                                                                    token_word_map=token_word_map)
                    asbn_loss = asbn_ret[0] if isinstance(asbn_ret, (tuple, list)) else asbn_ret
                    if not isinstance(asbn_loss, torch.Tensor):
                        asbn_loss = torch.tensor(float(asbn_loss), device=device)
                    else:
                        asbn_loss = asbn_loss.to(device)
                    if not torch.isfinite(asbn_loss):
                        asbn_loss = torch.tensor(0.0, device=device)
                else:
                    asbn_loss = torch.tensor(0.0, device=device)
            except Exception:
                if _VERBOSE_LOGGING:
                    print("[TATN] ASBN forward failed:", traceback.format_exc().splitlines()[-1])
                asbn_loss = torch.tensor(0.0, device=device)

            try:
                dscd_reg = self._entropy_reg_from_proto_probs_static(dscd.get('proto_probs', []),
                                                                     gates_list=dscd.get('gates', []),
                                                                     min_gate=0.0)
                if not isinstance(dscd_reg, torch.Tensor):
                    dscd_reg = torch.tensor(float(dscd_reg), device=device)
                else:
                    dscd_reg = dscd_reg.to(device)
                if not torch.isfinite(dscd_reg):
                    dscd_reg = torch.tensor(0.0, device=device)
            except Exception:
                if _VERBOSE_LOGGING:
                    print("[TATN] DSCD reg computation failed:", traceback.format_exc().splitlines()[-1])
                dscd_reg = torch.tensor(0.0, device=device)

            total_loss = translation_loss + _LAMBDA_ASBN * asbn_loss + _LAMBDA_DSCD * dscd_reg
            if not isinstance(total_loss, torch.Tensor):
                total_loss = torch.tensor(float(total_loss), device=device)
            try:
                if total_loss.numel() != 1:
                    total_loss = total_loss.mean()
            except Exception:
                total_loss = torch.tensor(float(total_loss), device=device)
            return total_loss

        # INFERENCE path: produce explanations (no loss)
        explanations = {i: [] for i in range(batch_size)}
        if (not self.training) and _ENABLE_TRG_INFERENCE and self.trg_system is not None:
            tokens_batch: List[List[str]] = []
            word_maps_batch: List[dict] = []

            # build tokens and word maps
            for b in range(batch_size):
                try:
                    ids_b = input_ids[b].detach().cpu().tolist()
                    if hasattr(self.tokenizer, 'convert_ids_to_tokens'):
                        toks = self.tokenizer.convert_ids_to_tokens(ids_b)
                    else:
                        try:
                            decoded = self.tokenizer.decode(ids_b, skip_special_tokens=True)
                            toks = decoded.split()[:seq_len]
                        except Exception:
                            toks = [''] * seq_len
                    if len(toks) < seq_len:
                        toks = toks + [''] * (seq_len - len(toks))
                    else:
                        toks = toks[:seq_len]
                except Exception:
                    toks = [''] * seq_len
                tokens_batch.append(toks)

                # reconstruct word spans and create normalized map if available
                if _has_reconstruct_word_spans:
                    try:
                        if src_texts and b < len(src_texts) and isinstance(src_texts[b], str) and src_texts[b].strip():
                            orig_text = src_texts[b]
                        else:
                            try:
                                orig_text = self.tokenizer.decode(input_ids[b], skip_special_tokens=True)
                            except Exception:
                                orig_text = ""
                        wm, _ = reconstruct_word_spans(self.tokenizer, orig_text, max_length=seq_len)
                        if not isinstance(wm, dict):
                            wm = {}
                    except Exception:
                        if _VERBOSE_LOGGING:
                            print("[TATN] reconstruct_word_spans failed:", traceback.format_exc().splitlines()[-1])
                        wm = {}
                else:
                    wm = {}

                # build normalized word map using normalize_bn_word if available
                norm_map = {}
                try:
                    if isinstance(wm, dict) and _normalize_fn:
                        for k, v in wm.items():
                            try:
                                norm_map[k] = _normalize_fn(v) if isinstance(v, str) and v else v
                            except Exception:
                                norm_map[k] = v
                    else:
                        # if wm not dict or no normalize fn, try making a trivial map
                        if isinstance(wm, dict):
                            norm_map = wm.copy()
                except Exception:
                    norm_map = wm.copy() if isinstance(wm, dict) else {}

                word_maps_batch.append({"orig": wm, "norm": norm_map})

            # helper: safe extractor that returns per-token lists (lists of tensors/scalars)
            def _safe_take_key(dscd_struct, key, b_index):
                out = []
                try:
                    val = dscd_struct.get(key, None)
                    if val is None:
                        # default: scalar zeros or proto tensor fallback
                        if key == 'proto_probs':
                            return [torch.tensor([1.0], device=device) for _ in range(seq_len)]
                        else:
                            return [torch.tensor(0.0, device=device) for _ in range(seq_len)]

                    # handle list-of-batches
                    if isinstance(val, list) and len(val) == batch_size:
                        row = val[b_index]
                        if isinstance(row, list):
                            for t in range(min(seq_len, len(row))):
                                v = row[t]
                                if isinstance(v, torch.Tensor):
                                    out.append(v.to(device))
                                else:
                                    try:
                                        out.append(torch.tensor(float(v), device=device))
                                    except Exception:
                                        out.append(torch.tensor(0.0, device=device))
                            # pad
                            while len(out) < seq_len:
                                if key == 'proto_probs':
                                    out.append(torch.tensor([1.0], device=device))
                                else:
                                    out.append(torch.tensor(0.0, device=device))
                            return out
                        elif isinstance(row, torch.Tensor):
                            r = row.detach().cpu()
                            if r.dim() == 1:
                                for t in range(min(seq_len, int(r.size(0)))):
                                    out.append(torch.tensor(float(r[t].item()), device=device))
                            elif r.dim() == 2:
                                for t in range(min(seq_len, int(r.size(0)))):
                                    out.append(r[t].to(device))
                            while len(out) < seq_len:
                                if key == 'proto_probs':
                                    out.append(torch.tensor([1.0], device=device))
                                else:
                                    out.append(torch.tensor(0.0, device=device))
                            return out

                    # handle tensor layout [B, T, ...]
                    if isinstance(val, torch.Tensor):
                        v = val.detach().cpu()
                        if v.dim() >= 2 and int(v.size(0)) == batch_size:
                            for t in range(min(seq_len, int(v.size(1)))):
                                if v.dim() == 3:
                                    out.append(v[b_index, t].to(device))
                                else:
                                    out.append(torch.tensor(float(v[b_index, t].item()), device=device))
                            while len(out) < seq_len:
                                if key == 'proto_probs':
                                    out.append(torch.tensor([1.0], device=device))
                                else:
                                    out.append(torch.tensor(0.0, device=device))
                            return out
                        elif v.dim() == 1 and batch_size == 1:
                            for t in range(min(seq_len, int(v.size(0)))):
                                out.append(torch.tensor(float(v[t].item()), device=device))
                            while len(out) < seq_len:
                                out.append(torch.tensor(0.0, device=device))
                            return out

                    # numpy arrays or list-of-per-token
                    if isinstance(val, (list, tuple, np.ndarray)):
                        if isinstance(val, np.ndarray) and val.ndim >= 2 and val.shape[0] == batch_size:
                            for t in range(min(seq_len, int(val.shape[1]))):
                                try:
                                    out.append(torch.tensor(float(val[b_index, t]), device=device))
                                except Exception:
                                    out.append(torch.tensor(0.0, device=device))
                            while len(out) < seq_len:
                                out.append(torch.tensor(0.0, device=device))
                            return out
                        # treat as per-token sequence for single batch
                        if len(val) >= seq_len:
                            for t in range(min(seq_len, len(val))):
                                vt = val[t]
                                if isinstance(vt, torch.Tensor):
                                    out.append(vt.detach().to(device))
                                else:
                                    try:
                                        out.append(torch.tensor(float(vt), device=device))
                                    except Exception:
                                        out.append(torch.tensor(0.0, device=device))
                            return out

                except Exception:
                    if _VERBOSE_LOGGING:
                        print("[TATN] _safe_take_key error:", traceback.format_exc().splitlines()[-1])
                # fallback fill
                if key == 'proto_probs':
                    return [torch.tensor([1.0], device=device) for _ in range(seq_len)]
                return [torch.tensor(0.0, device=device) for _ in range(seq_len)]

            # call TRG system per sentence safely
            try:
                for b in range(batch_size):
                    per_sent = {
                        'proto_probs': _safe_take_key(dscd, 'proto_probs', b),
                        'uncertainties': _safe_take_key(dscd, 'uncertainties', b),
                        'gates': _safe_take_key(dscd, 'gates', b),
                        'span_preds': _safe_take_key(dscd, 'span_preds', b),
                    }
                    try:
                        # TRG expects dscd_outputs-like dict for a single sentence; pass token_word_map[b] (with 'norm' available)
                        exps = self.trg_system.process_sentence_for_explanations(
                            tokens_batch[b],
                            per_sent,
                            token_word_map=word_maps_batch[b],
                            uncertainty_threshold=float(globals().get('TAU_LOW', 0.4)),
                            top_k=3
                        )
                        explanations[b] = exps if isinstance(exps, list) else []
                    except Exception:
                        if _VERBOSE_LOGGING:
                            print(f"[TATN] TRG explanation generation failed for idx={b}:", traceback.format_exc().splitlines()[-1])
                        explanations[b] = []
            except Exception:
                if _VERBOSE_LOGGING:
                    print("[TATN] TRG generation failed overall:", traceback.format_exc().splitlines()[-1])
                explanations = {i: [] for i in range(batch_size)}

        outputs = {
            'encoder_outputs': enc_outputs,
            'dscd_outputs': dscd,
            'sense_augmented_embeddings': h_aug,
            'explanations': [explanations.get(i, []) for i in range(batch_size)],
            'asbn_loss': torch.tensor(0.0, device=device),
        }
        return outputs

    def forward_with_explanations(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        src_texts: Optional[List[str]] = None,
        token_word_map: Optional[List[dict]] = None,
    ):
        return self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            src_texts=src_texts,
            token_word_map=token_word_map,
            labels=None,
        )

# -----------------------------------------------------------------------------
# Verification prints (concise)
# -----------------------------------------------------------------------------
print("=" * 80)
print("‚úÖ Cell 6: TATN model wrapper ready (DEBUGGED & HARDENED)")
print("=" * 80)
print(f"‚úì transformers available: {_HAS_TRANSFORMERS}")
print(f"‚úì Gradient checkpointing enabled (config): {_USE_GC}")
print(f"‚úì DSCD training clustering: {'ENABLED' if _DSCD_ENABLE_TRAINING_CLUSTERING else 'DISABLED'}")
print(f"‚úì DSCD buffer: {_DSCD_BUFFER_SIZE}, n_min: {_DSCD_N_MIN}, disp_th: {_DSCD_DISPERSION_THRESHOLD}")
print("=" * 80)

‚úÖ Cell 6: TATN model wrapper ready (DEBUGGED & HARDENED)
‚úì transformers available: True
‚úì Gradient checkpointing enabled (config): True
‚úì DSCD training clustering: ENABLED
‚úì DSCD buffer: 20, n_min: 5, disp_th: 0.25


In [11]:
# ==============================================================================
# CELL 7 (FIXED): TRAINING LOOP (DP + AMP + AccUM + Progress + DEBUG + CLUSTER TRACKING)
# Fully debugged, hardened, and compatible with multiple torch versions.
# ==============================================================================
# Fixes applied (summary):
#  - Removed use of private GradScaler API scaler._maybe_opt_step(None) and handled optimizer=None safely.
#  - Added robust helper scaler_enabled(...) to support older/newer torch versions.
#  - Added support for ModelOutput-like objects (access .loss attribute) in forward outputs.
#  - Replaced fragile calls to scaler.is_enabled() with scaler_enabled(scaler) wrapper.
#  - Graceful epoch-flush when optimizer is None (no exception raised).
#  - Wrapped scaler.unscale_/step/update calls in try/except for compatibility and safety.
#  - Ensured we do not call scaler-specific APIs when scaler is disabled.
#  - Minor logging and defensive guards (avoid uninitialized vars, better OOM handling).
# ==============================================================================
import os
import time
import math
import gc
import traceback
from datetime import datetime
from collections import defaultdict, deque
from typing import Optional, Dict, Any, List, Tuple, Union

import numpy as np
import torch
from torch.cuda.amp import GradScaler, autocast as cuda_amp_autocast
from tqdm import tqdm
from contextlib import nullcontext

# ---------------- Debug control ----------------
_VERBOSE_LOGGING = bool(globals().get("VERBOSE_LOGGING", False))

DEBUG_PRINT_INTERVAL = int(globals().get("DEBUG_PRINT_INTERVAL", 200))
_cell7_dbg_counts = defaultdict(int)


def cell7_dbg(key: str, msg: str, limit: int = 10):
    if not _VERBOSE_LOGGING:
        return
    _cell7_dbg_counts[key] += 1
    if _cell7_dbg_counts[key] <= limit:
        print(f"[CELL7-DBG] {msg}")


# ---------------- Fallback globals ----------------
_DEVICE = globals().get("DEVICE", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
_EPOCHS = int(globals().get("EPOCHS", 1))
_BATCH_SIZE = int(globals().get("BATCH_SIZE", 8))
_ACCUMULATION_STEPS = int(globals().get("ACCUMULATION_STEPS", 1))
_GRAD_CLIP_NORM = float(globals().get("GRAD_CLIP_NORM", 1.0))
_MEMORY_CLEANUP_FREQUENCY = int(globals().get("MEMORY_CLEANUP_FREQUENCY", 100))
_USE_MULTI_GPU = bool(globals().get("USE_MULTI_GPU", torch.cuda.device_count() > 1))
_NUM_GPUS = int(globals().get("NUM_GPUS", torch.cuda.device_count() if torch.cuda.is_available() else 0))
_USE_AMP = bool(globals().get("USE_AMP", True))
_BN_LANG = str(globals().get("BN_LANG", "bn"))
_EN_LANG = str(globals().get("EN_LANG", "en"))
_MAX_LENGTH = int(globals().get("MAX_LENGTH", 48))
VALIDATION_CHECK_INTERVAL = int(globals().get("VALIDATION_CHECK_INTERVAL", 0))

# ---------------- Helpers ----------------
def clear_all_gpu_caches():
    gc.collect()
    if not torch.cuda.is_available():
        return
    try:
        for i in range(torch.cuda.device_count()):
            with torch.cuda.device(i):
                try:
                    torch.cuda.empty_cache()
                except Exception:
                    pass
    except Exception:
        pass


def get_amp_ctx():
    """
    Return a context manager for mixed-precision if enabled and available.
    Otherwise return a nullcontext.
    """
    if not _USE_AMP or not torch.cuda.is_available():
        return nullcontext()
    try:
        return cuda_amp_autocast()
    except Exception:
        return nullcontext()


def save_checkpoint(model: torch.nn.Module, optimizer: Optional[torch.optim.Optimizer], training_stats: Dict[str, Any],
                    epoch: int, global_step: int, epoch_losses: List[float], ckpt_dir: str = "checkpoints"):
    os.makedirs(ckpt_dir, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    fname = f"tatn_e{epoch}_s{global_step}_{timestamp}.pt"
    path = os.path.join(ckpt_dir, fname)
    core_model = model.module if hasattr(model, "module") else model
    ckpt = {
        "epoch": epoch,
        "global_step": global_step,
        "model_state_dict": core_model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict() if optimizer is not None else None,
        "training_stats": training_stats,
        "avg_epoch_loss": float(np.mean(epoch_losses)) if epoch_losses else 0.0,
    }
    try:
        torch.save(ckpt, path)
        print(f"[CHECKPOINT] Saved {fname} avg_loss={ckpt['avg_epoch_loss']:.6f}")
    except Exception as e:
        print(f"[CHECKPOINT] Save failed: {type(e).__name__}: {str(e)[:200]}")


# ---------------- Validation (hardened) ----------------
_PROTOBUF_COMPAT_ERROR_SHOWN = globals().get("_PROTOBUF_COMPAT_ERROR_SHOWN", False)


@torch.inference_mode()
def quick_validation_check(model: torch.nn.Module, tokenizer, step: int, bn_lang: str, en_lang: str, max_length: int, device: torch.device):
    """
    Run a few simple translations to sanity-check the model.
    Robust to protobuf/getprototype errors.
    """
    global _PROTOBUF_COMPAT_ERROR_SHOWN
    core_model = model.module if hasattr(model, "module") else model
    gen_target = getattr(core_model, "mbart", core_model)
    was_training = core_model.training
    core_model.eval()

    samples = [
        "‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§",
        "‡¶ï‡¶æ‡¶≤ ‡¶Ü‡¶Æ‡¶ø ‡¶¨‡¶á ‡¶ï‡¶ø‡¶®‡¶¨‡•§",
        "‡¶™‡¶æ‡¶§‡¶æ ‡¶ù‡¶∞‡ßá ‡¶™‡¶°‡¶º‡ßá‡¶õ‡ßá‡•§",
        "‡¶Ü‡¶Æ‡¶ø ‡¶≠‡¶æ‡¶≤‡ßã ‡¶Ü‡¶õ‡¶ø‡•§",
        "‡¶Ü‡¶ú ‡¶Ü‡¶¨‡¶π‡¶æ‡¶ì‡¶Ø‡¶º‡¶æ ‡¶≠‡¶æ‡¶≤‡ßã‡•§",
    ]
    print("\n" + "=" * 70)
    print(f"[VALIDATION] Quick validation at step {step}")
    print("=" * 70)
    try:
        try:
            tokenizer.src_lang = bn_lang
        except Exception:
            pass

        forced_id = None
        try:
            if hasattr(tokenizer, "get_lang_id"):
                for code in (en_lang, "en_XX", "en", "eng"):
                    try:
                        lid = tokenizer.get_lang_id(code)
                        if lid is not None:
                            forced_id = lid
                            break
                    except Exception:
                        continue
            elif hasattr(tokenizer, "lang_code_to_id"):
                forced_id = tokenizer.lang_code_to_id.get(en_lang, None)
        except Exception:
            forced_id = None

        mbart_obj = getattr(core_model, "mbart", None)
        orig_use_cache = None
        try:
            if mbart_obj is not None and hasattr(mbart_obj.config, "use_cache"):
                orig_use_cache = mbart_obj.config.use_cache
                mbart_obj.config.use_cache = True
        except Exception:
            orig_use_cache = None

        for i, src in enumerate(samples, 1):
            try:
                enc = tokenizer(src, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
                enc = {k: (v.to(device) if isinstance(v, torch.Tensor) else v) for k, v in enc.items()}
                if forced_id is not None:
                    try:
                        if mbart_obj is not None:
                            mbart_obj.config.forced_bos_token_id = int(forced_id)
                            mbart_obj.config.decoder_start_token_id = int(forced_id)
                    except Exception:
                        pass
                out_ids = None
                try:
                    gen_src = getattr(core_model, "mbart", None) or core_model
                    if hasattr(gen_src, "generate"):
                        out_ids = gen_src.generate(
                            enc.get("input_ids"),
                            attention_mask=enc.get("attention_mask"),
                            max_length=max_length,
                            num_beams=2,
                            do_sample=False,
                            early_stopping=True,
                            pad_token_id=int(getattr(tokenizer, "pad_token_id", 1)),
                            forced_bos_token_id=int(forced_id) if forced_id is not None else None
                        )
                except AttributeError as ae:
                    if not _PROTOBUF_COMPAT_ERROR_SHOWN:
                        print("[VALIDATION] Warning: generation raised AttributeError (often protobuf incompatibility).")
                        print("  Suggestion: pip install 'protobuf==3.20.3' and restart the kernel.")
                        _PROTOBUF_COMPAT_ERROR_SHOWN = True
                    out_ids = None
                except Exception as e:
                    print(f"[VALIDATION] Generation error: {type(e).__name__}: {str(e)[:200]}")
                    out_ids = None

                if out_ids is not None:
                    try:
                        if isinstance(out_ids, (list, tuple)):
                            pred = tokenizer.batch_decode(out_ids, skip_special_tokens=True)[0]
                        else:
                            # ensure out_ids is tensor
                            if isinstance(out_ids, torch.Tensor):
                                pred = tokenizer.decode(out_ids[0], skip_special_tokens=True)
                            else:
                                pred = str(out_ids)
                    except AttributeError:
                        if not _PROTOBUF_COMPAT_ERROR_SHOWN:
                            print("[VALIDATION] Warning: decode raised AttributeError (protobuf). Pin protobuf and restart.")
                            _PROTOBUF_COMPAT_ERROR_SHOWN = True
                        pred = ""
                    except Exception as e:
                        print(f"[VALIDATION] Decode error: {type(e).__name__}: {str(e)[:200]}")
                        pred = ""
                else:
                    pred = ""
                print(f"{i}. {src} -> {pred}")
            except Exception as e:
                print(f"{i}. Validation error: {type(e).__name__}: {str(e)[:200]}")
                if _VERBOSE_LOGGING:
                    traceback.print_exc()
    finally:
        try:
            if mbart_obj is not None and orig_use_cache is not None:
                mbart_obj.config.use_cache = orig_use_cache
        except Exception:
            pass
        if torch.cuda.is_available():
            try:
                torch.cuda.synchronize()
            except Exception:
                pass
        clear_all_gpu_caches()
        if was_training:
            core_model.train()
    print("=" * 70)


def _print_gpu_mem(prefix: str = ""):
    if not torch.cuda.is_available():
        return
    try:
        lines = [f"{prefix} GPU mem (GB):"]
        for i in range(torch.cuda.device_count()):
            try:
                alloc = torch.cuda.memory_allocated(i) / (1024**3)
                resv = torch.cuda.memory_reserved(i) / (1024**3)
                lines.append(f"  GPU {i}: alloc={alloc:.2f} resv={resv:.2f}")
            except Exception:
                lines.append(f"  GPU {i}: mem query failed")
        print("\n".join(lines))
    except Exception:
        pass


def _get_cluster_count(model: torch.nn.Module) -> int:
    try:
        dscd = model.module.dscd if hasattr(model, "module") else model.dscd
        return len(getattr(dscd, "prototype_stores", {}) or {})
    except Exception:
        return 0


def _get_dscd_safe(model: torch.nn.Module):
    try:
        core = model.module if hasattr(model, "module") else model
        return getattr(core, "dscd", None)
    except Exception:
        return None


def _print_top_clusters(model: torch.nn.Module, top_n: int = 5):
    dscd = _get_dscd_safe(model)
    if dscd is None:
        if _VERBOSE_LOGGING:
            print("[CLUSTER-DBG] No DSCD instance attached to model.")
        return
    try:
        items = []
        for token, store in dscd.prototype_stores.items():
            total_count = sum(getattr(store, "counts", []) or [])
            protos = store.size() if hasattr(store, "size") else len(getattr(store, "centroids", []))
            items.append((token, total_count, protos, len(dscd.buffers.get(token, []))))
        items.sort(key=lambda x: x[1], reverse=True)
        if _VERBOSE_LOGGING:
            print("[CLUSTER-DBG] Top clusters:")
            for i, (tok, cnt, prot, buflen) in enumerate(items[:top_n], 1):
                print(f"  {i:2d}. {str(tok)[:20]:20s} samples={cnt:4d} protos={prot} buf={buflen}")
    except Exception as e:
        if _VERBOSE_LOGGING:
            print(f"[CLUSTER-DBG] _print_top_clusters error: {type(e).__name__}: {str(e)[:200]}")


def _print_cluster_stats(model: torch.nn.Module):
    dscd = _get_dscd_safe(model)
    if dscd is None:
        return
    try:
        total_tokens = len(dscd.prototype_stores)
        total_protos = 0
        total_samples = 0
        total_buffers = 0
        for token, store in dscd.prototype_stores.items():
            total_protos += store.size() if hasattr(store, "size") else len(getattr(store, "centroids", []))
            total_samples += sum(getattr(store, "counts", []) or [])
            total_buffers += len(dscd.buffers.get(token, []))
        if _VERBOSE_LOGGING:
            print(f"[CLUSTER-DBG] tokens_with_stores={total_tokens} total_prototypes={total_protos} total_samples={total_samples} total_buffered_embeddings={total_buffers}")
    except Exception as e:
        if _VERBOSE_LOGGING:
            print(f"[CLUSTER-DBG] _print_cluster_stats error: {type(e).__name__}: {str(e)[:200]}")


# ---------------- batch unpacking helper ----------------
def _unpack_batch(batch: Any) -> Dict[str, Any]:
    """
    Accept common batch formats:
      - dict with keys
      - tuple/list: (input_ids, attention_mask, labels, src_texts?, token_word_map?)
    Return dict with keys 'input_ids','attention_mask','labels','src_text','token_word_map'
    """
    if batch is None:
        return {}
    if isinstance(batch, dict):
        return dict(batch)
    if isinstance(batch, (list, tuple)):
        out = {}
        # heuristics: first two are input_ids, attention_mask
        try:
            if len(batch) >= 2:
                out['input_ids'] = batch[0]
                out['attention_mask'] = batch[1]
            if len(batch) >= 3:
                out['labels'] = batch[2]
            if len(batch) >= 4:
                out['src_text'] = batch[3]
            if len(batch) >= 5:
                out['token_word_map'] = batch[4]
        except Exception:
            pass
        return out
    # fallback: can't unpack
    return {}


# Helper to check scaler availability robustly across torch versions
def scaler_enabled(scaler: Optional[GradScaler]) -> bool:
    if scaler is None:
        return False
    try:
        # GradScaler in newer torch versions has is_enabled()
        return bool(getattr(scaler, "is_enabled", lambda: False)())
    except Exception:
        # Backwards compat: use enabled attribute or assume True if instance
        return getattr(scaler, "enabled", False) if hasattr(scaler, "enabled") else True


# ---------------- Main training loop ----------------
def train_memory_efficient_tatn(
    model: torch.nn.Module,
    tokenizer,
    train_loader: torch.utils.data.DataLoader,
    optimizer: Optional[torch.optim.Optimizer],
    phi_optimizer: Optional[torch.optim.Optimizer] = None,
    epochs: Optional[int] = None,
    accumulation_steps: Optional[int] = None,
    validate_every: Optional[int] = None,
    enable_validation: bool = True
) -> torch.nn.Module:
    if epochs is None:
        epochs = _EPOCHS
    if accumulation_steps is None:
        accumulation_steps = max(1, _ACCUMULATION_STEPS)
    if validate_every is None:
        validate_every = VALIDATION_CHECK_INTERVAL

    print(f"[TRAIN] Starting training: epochs={epochs}, batch={_BATCH_SIZE}, accum_steps={accumulation_steps}")
    print(f"[TRAIN] Validation: {'enabled' if enable_validation and validate_every > 0 else 'disabled'}")
    print(f"[TRAIN] DP enabled: {_USE_MULTI_GPU}, GPUs: {_NUM_GPUS}, Device: {_DEVICE}")

    model.train()
    clear_all_gpu_caches()

    # GradScaler enabled only if AMP requested and CUDA present
    scaler = GradScaler(enabled=(_USE_AMP and torch.cuda.is_available()))

    global_step = 0
    accumulated_steps = 0
    pending_validation = False

    training_stats: Dict[str, Any] = {
        "total_loss": [],
        "batches_processed": 0,
        "optimizer_updates": 0,
        "skipped_batches": 0,
        "oom_errors": 0,
        "runtime_errors": 0,
        "exceptions": 0,
    }

    skip_reasons = defaultdict(int)
    last_forward_loss = 0.0
    last_backward_loss = 0.0

    for epoch in range(1, epochs + 1):
        epoch_start = time.time()
        epoch_losses: List[float] = []
        try:
            if optimizer is not None:
                try:
                    optimizer.zero_grad(set_to_none=True)
                except Exception:
                    pass
        except Exception:
            pass

        progress = tqdm(train_loader, desc=f"Epoch {epoch}/{epochs}", ncols=180, dynamic_ncols=False)

        for batch_idx, batch in enumerate(progress):
            global_step += 1
            training_stats["batches_processed"] += 1

            if _VERBOSE_LOGGING and global_step % DEBUG_PRINT_INTERVAL == 0:
                print(f"[TRAIN-DEBUG] Epoch {epoch} Batch {batch_idx} GlobalStep {global_step}")

            # Validation scheduling
            if enable_validation and validate_every and validate_every > 0 and (global_step % validate_every == 0):
                if accumulated_steps == 0:
                    try:
                        quick_validation_check(model, tokenizer, global_step, _BN_LANG, _EN_LANG, _MAX_LENGTH, _DEVICE)
                    except Exception:
                        if _VERBOSE_LOGGING:
                            print("[TRAIN] quick_validation_check failed:", traceback.format_exc().splitlines()[-1])
                else:
                    pending_validation = True

            # Validate batch
            if batch is None:
                training_stats["skipped_batches"] += 1
                skip_reasons["batch_none"] += 1
                cell7_dbg("batch_none", f"Batch is None at idx={batch_idx}")
                continue

            try:
                # Unpack batch robustly
                bdict = _unpack_batch(batch)
                input_ids = bdict.get("input_ids", None)
                attention_mask = bdict.get("attention_mask", None)
                labels = bdict.get("labels", None)

                # If key tensors missing, skip
                if input_ids is None or attention_mask is None:
                    training_stats["skipped_batches"] += 1
                    skip_reasons["missing_tensors"] += 1
                    cell7_dbg("missing_tensors", f"Missing tensors in batch idx={batch_idx}")
                    continue

                # ensure tensors are on correct dtype/device
                try:
                    if isinstance(input_ids, torch.Tensor):
                        input_ids = input_ids.to(_DEVICE, non_blocking=True)
                        if input_ids.dtype not in (torch.long, torch.int64):
                            input_ids = input_ids.long()
                    if isinstance(attention_mask, torch.Tensor):
                        attention_mask = attention_mask.to(_DEVICE, non_blocking=True)
                    if labels is not None and isinstance(labels, torch.Tensor):
                        labels = labels.to(_DEVICE, non_blocking=True)
                except Exception:
                    # fallback: move using .to with try/except
                    try:
                        input_ids = input_ids.to(_DEVICE)
                    except Exception:
                        pass
                    try:
                        attention_mask = attention_mask.to(_DEVICE)
                    except Exception:
                        pass
                    try:
                        if labels is not None and isinstance(labels, torch.Tensor):
                            labels = labels.to(_DEVICE)
                    except Exception:
                        pass

                # DP-divisible truncation safety (should already be handled by collate)
                if _USE_MULTI_GPU and _NUM_GPUS > 0:
                    try:
                        bsz = int(input_ids.size(0))
                        keep = (bsz // _NUM_GPUS) * _NUM_GPUS
                        if keep == 0:
                            training_stats["skipped_batches"] += 1
                            skip_reasons["dp_keep_zero"] += 1
                            cell7_dbg("dp_keep_zero", f"DP keep==0 bsz={bsz}, gpus={_NUM_GPUS}")
                            continue
                        if keep != bsz:
                            input_ids = input_ids[:keep]
                            attention_mask = attention_mask[:keep]
                            if labels is not None:
                                labels = labels[:keep]
                    except Exception:
                        # If we can't determine size, skip to be safe
                        training_stats["skipped_batches"] += 1
                        skip_reasons["dp_size_error"] += 1
                        continue

                if isinstance(input_ids, torch.Tensor) and input_ids.size(0) == 0:
                    training_stats["skipped_batches"] += 1
                    skip_reasons["empty_batch"] += 1
                    continue

                # Optional debugging: token_word_map presence in the batch (non-tensor)
                if _VERBOSE_LOGGING and 'token_word_map' in bdict:
                    try:
                        sample_map = bdict['token_word_map'][:2]
                        cell7_dbg("tokmap_sample", f"token_word_map sample lens: {[len(x) if x else 0 for x in sample_map]}", limit=3)
                    except Exception:
                        pass

                forward_kwargs = {
                    "input_ids": input_ids,
                    "attention_mask": attention_mask,
                    "labels": labels,
                    "src_texts": bdict.get("src_text", None),
                    "token_word_map": bdict.get("token_word_map", None),
                }

                amp_ctx = get_amp_ctx()
                with amp_ctx:
                    forward_out = model(**forward_kwargs)

                    # Determine loss tensor
                    loss_tensor = None
                    # handle HF ModelOutput-like (with .loss attribute)
                    try:
                        if hasattr(forward_out, "loss"):
                            loss_tensor = getattr(forward_out, "loss")
                    except Exception:
                        pass

                    if loss_tensor is None:
                        if isinstance(forward_out, torch.Tensor):
                            loss_tensor = forward_out
                        elif isinstance(forward_out, dict):
                            # common keys to check
                            possible_loss_keys = ["loss", "total_loss", "translation_loss"]
                            for k in possible_loss_keys:
                                if k in forward_out:
                                    loss_tensor = forward_out[k]
                                    break
                            # if no explicit loss, model may have returned scalar in fieldless dict
                            if loss_tensor is None:
                                # try to find any tensor value that is scalar-like
                                for v in forward_out.values():
                                    if isinstance(v, torch.Tensor) and v.numel() == 1:
                                        loss_tensor = v
                                        break
                        elif isinstance(forward_out, (list, tuple)) and len(forward_out) > 0:
                            if isinstance(forward_out[0], torch.Tensor):
                                loss_tensor = forward_out[0]

                    if loss_tensor is None:
                        # As a last resort try converting numeric outputs
                        try:
                            if isinstance(forward_out, (int, float, np.floating, np.integer)):
                                loss_tensor = torch.tensor(float(forward_out), device=_DEVICE)
                        except Exception:
                            pass

                    if loss_tensor is None:
                        raise RuntimeError("Model forward did not return a recognizable loss tensor")

                    # Ensure scalar and on device
                    if not isinstance(loss_tensor, torch.Tensor):
                        loss_tensor = torch.tensor(float(loss_tensor), device=_DEVICE)
                    else:
                        try:
                            loss_tensor = loss_tensor.to(_DEVICE)
                        except Exception:
                            pass

                    if loss_tensor.numel() > 1:
                        loss_val = float(loss_tensor.mean().item())
                        loss_tensor = loss_tensor.mean()
                    else:
                        loss_val = float(loss_tensor.item())

                    last_forward_loss = loss_val
                    epoch_losses.append(loss_val)
                    training_stats["total_loss"].append(loss_val)

                # backward + accumulation
                loss_scaled = loss_tensor / max(1, accumulation_steps)
                try:
                    last_backward_loss = float(loss_scaled.item())
                except Exception:
                    try:
                        last_backward_loss = float(loss_scaled.detach().cpu().item()) if isinstance(loss_scaled, torch.Tensor) else float(loss_scaled)
                    except Exception:
                        last_backward_loss = 0.0

                # Backward: use scaler only if enabled
                try:
                    if scaler_enabled(scaler):
                        scaler.scale(loss_scaled).backward()
                    else:
                        loss_scaled.backward()
                except RuntimeError as e:
                    # immediate OOM during backward
                    if "out of memory" in str(e).lower():
                        training_stats["oom_errors"] += 1
                        training_stats["skipped_batches"] += 1
                        skip_reasons["oom_backward"] += 1
                        print(f"[OOM] OOM during backward at step {global_step}: {str(e)[:200]}")
                        try:
                            if optimizer is not None:
                                optimizer.zero_grad(set_to_none=True)
                        except Exception:
                            pass
                        for p in model.parameters():
                            if p is not None:
                                p.grad = None
                        clear_all_gpu_caches()
                        accumulated_steps = 0
                        continue
                    else:
                        raise

                accumulated_steps += 1

                # optimizer step
                if accumulated_steps >= accumulation_steps:
                    try:
                        if optimizer is None:
                            # Nothing to step; just zero grads and move on
                            training_stats["skipped_batches"] += 1
                            skip_reasons["no_optimizer"] += 1
                            try:
                                model.zero_grad(set_to_none=True)
                            except Exception:
                                for p in model.parameters():
                                    if p.grad is not None:
                                        p.grad = None
                        else:
                            # Unscale first if using scaler, then clip
                            if scaler_enabled(scaler):
                                try:
                                    scaler.unscale_(optimizer)
                                except Exception:
                                    # unscale_ might not exist in some versions; ignore and proceed
                                    pass
                            # gradient clip (guard against generator/params that might be empty)
                            try:
                                torch.nn.utils.clip_grad_norm_(model.parameters(), _GRAD_CLIP_NORM)
                            except Exception:
                                pass
                            # step
                            if scaler_enabled(scaler):
                                try:
                                    scaler.step(optimizer)
                                    scaler.update()
                                except Exception as e:
                                    # If scaler.step failed, try a plain step (best-effort)
                                    try:
                                        optimizer.step()
                                    except Exception:
                                        raise
                            else:
                                optimizer.step()
                            # zero grads
                            try:
                                optimizer.zero_grad(set_to_none=True)
                            except Exception:
                                for p in model.parameters():
                                    if p.grad is not None:
                                        p.grad.detach_()
                                        p.grad.zero_()
                            training_stats["optimizer_updates"] += 1
                    except RuntimeError as e:
                        if "out of memory" in str(e).lower():
                            training_stats["oom_errors"] += 1
                            training_stats["skipped_batches"] += 1
                            skip_reasons["oom"] += 1
                            print(f"[OOM] OOM at step {global_step}: {str(e)[:200]}")
                            try:
                                if optimizer is not None:
                                    optimizer.zero_grad(set_to_none=True)
                            except Exception:
                                pass
                            for p in model.parameters():
                                p.grad = None
                            clear_all_gpu_caches()
                            accumulated_steps = 0
                            continue
                        else:
                            training_stats["runtime_errors"] += 1
                            skip_reasons["opt_runtime"] += 1
                            print(f"[ERROR] Runtime error during optimizer step: {type(e).__name__}: {str(e)[:200]}")
                    except Exception as e:
                        training_stats["exceptions"] += 1
                        skip_reasons["opt_exception"] += 1
                        print(f"[ERROR] Exception during optimizer step: {type(e).__name__}: {str(e)[:200]}")
                    finally:
                        accumulated_steps = 0
                        if pending_validation:
                            try:
                                quick_validation_check(model, tokenizer, global_step, _BN_LANG, _EN_LANG, _MAX_LENGTH, _DEVICE)
                            except Exception:
                                if _VERBOSE_LOGGING:
                                    print("[TRAIN] deferred quick_validation_check failed:", traceback.format_exc().splitlines()[-1])
                            pending_validation = False

                # periodic housekeeping & logs
                if global_step % DEBUG_PRINT_INTERVAL == 0:
                    _print_gpu_mem("[TRAIN-DEBUG]")
                    try:
                        cluster_count = _get_cluster_count(model)
                    except Exception:
                        cluster_count = 0
                    print(f"[TRAIN-DEBUG] step={global_step} loss={last_forward_loss:.4f} opt_updates={training_stats['optimizer_updates']} clusters={cluster_count}")
                    _print_top_clusters(model, top_n=5)
                    _print_cluster_stats(model)

                if global_step % _MEMORY_CLEANUP_FREQUENCY == 0:
                    clear_all_gpu_caches()

            except RuntimeError as e:
                if "out of memory" in str(e).lower():
                    training_stats["oom_errors"] += 1
                    training_stats["skipped_batches"] += 1
                    skip_reasons["oom"] += 1
                    print(f"[OOM] Caught OOM at step {global_step}: {str(e)[:200]}")
                    try:
                        if optimizer is not None:
                            optimizer.zero_grad(set_to_none=True)
                    except Exception:
                        pass
                    for p in model.parameters():
                        p.grad = None
                    clear_all_gpu_caches()
                    accumulated_steps = 0
                    continue
                else:
                    training_stats["runtime_errors"] += 1
                    training_stats["skipped_batches"] += 1
                    skip_reasons["runtime"] += 1
                    print(f"[RUNTIME] RuntimeError at step {global_step}: {type(e).__name__}: {str(e)[:200]}")
                    if _VERBOSE_LOGGING:
                        traceback.print_exc()
                    try:
                        if optimizer is not None:
                            optimizer.zero_grad(set_to_none=True)
                    except Exception:
                        pass
                    accumulated_steps = 0
                    continue
            except Exception as e:
                training_stats["exceptions"] += 1
                training_stats["skipped_batches"] += 1
                skip_reasons["exceptions"] += 1
                print(f"[EXCEPTION] Exception at step {global_step}: {type(e).__name__}: {str(e)[:200]}")
                if _VERBOSE_LOGGING:
                    traceback.print_exc()
                try:
                    if optimizer is not None:
                        optimizer.zero_grad(set_to_none=True)
                except Exception:
                    pass
                accumulated_steps = 0
                continue

            # update progress bar postfix
            processed_batches = training_stats["batches_processed"] - training_stats["skipped_batches"]
            expected_updates = max(1, math.floor(processed_batches / max(1, accumulation_steps)))
            success_rate = 100.0 * training_stats["optimizer_updates"] / expected_updates if expected_updates > 0 else 0.0
            cluster_count = _get_cluster_count(model)
            try:
                progress.set_postfix_str(
                    f"fwd_loss={last_forward_loss:.4f} bwd_loss={last_backward_loss:.6f} rate={success_rate:.1f}% proc={processed_batches} skip={training_stats['skipped_batches']} clusters={cluster_count}"
                )
            except Exception:
                # ignore progress bar update errors
                pass

        # end epoch: flush remaining grads if any
        if accumulated_steps > 0:
            try:
                if optimizer is None:
                    # Cannot flush without optimizer: just zero gradients and log
                    try:
                        model.zero_grad(set_to_none=True)
                    except Exception:
                        for p in model.parameters():
                            if p.grad is not None:
                                p.grad = None
                    print("[EPOCH-FLUSH] Skipped flush because optimizer is None.")
                else:
                    if scaler_enabled(scaler):
                        try:
                            scaler.unscale_(optimizer)
                        except Exception:
                            pass
                        try:
                            torch.nn.utils.clip_grad_norm_(model.parameters(), _GRAD_CLIP_NORM)
                        except Exception:
                            pass
                        try:
                            scaler.step(optimizer)
                            scaler.update()
                        except Exception:
                            try:
                                optimizer.step()
                            except Exception:
                                raise
                    else:
                        try:
                            torch.nn.utils.clip_grad_norm_(model.parameters(), _GRAD_CLIP_NORM)
                        except Exception:
                            pass
                        optimizer.step()
                    try:
                        optimizer.zero_grad(set_to_none=True)
                    except Exception:
                        for p in model.parameters():
                            if p.grad is not None:
                                p.grad.detach_()
                                p.grad.zero_()
                    training_stats["optimizer_updates"] += 1
            except Exception as e:
                print(f"[EPOCH-FLUSH] Exception on epoch flush: {type(e).__name__}: {str(e)[:200]}")
            finally:
                accumulated_steps = 0

        epoch_duration_min = (time.time() - epoch_start) / 60.0
        processed_batches = training_stats["batches_processed"] - training_stats["skipped_batches"]
        expected_updates = max(1, math.floor(processed_batches / max(1, accumulation_steps)))
        success_rate = 100.0 * training_stats["optimizer_updates"] / expected_updates if expected_updates > 0 else 0.0
        cluster_count = _get_cluster_count(model)

        print("\n" + "=" * 80)
        print(f"Epoch {epoch} summary:")
        print(f"  duration (min): {epoch_duration_min:.2f}")
        print(f"  optimizer updates: {training_stats['optimizer_updates']}")
        print(f"  batches processed: {training_stats['batches_processed']} (processed={processed_batches}, skipped={training_stats['skipped_batches']})")
        print(f"  success rate (updates/expected): {success_rate:.1f}%")
        print(f"  clustered token types: {cluster_count}")
        if training_stats["total_loss"]:
            print(f"  avg forward loss: {float(np.mean(training_stats['total_loss'])):.6f}")
        if skip_reasons:
            print("  skip reasons:")
            for k, v in sorted(skip_reasons.items(), key=lambda x: -x[1]):
                print(f"    - {k}: {v}")
        print("=" * 80)

        # save checkpoint at epoch end
        try:
            save_checkpoint(model, optimizer, training_stats, epoch, global_step, epoch_losses)
        except Exception as e:
            print(f"[CHECKPOINT] Save at epoch end failed: {type(e).__name__}: {str(e)[:200]}")

    print("\n[TRAIN] Training completed")
    processed_batches = training_stats["batches_processed"] - training_stats["skipped_batches"]
    expected_updates = max(1, math.floor(processed_batches / max(1, accumulation_steps)))
    success_rate = 100.0 * training_stats["optimizer_updates"] / expected_updates if expected_updates > 0 else 0.0
    print(f"[TRAIN] Success Rate (updates/expected): {success_rate:.1f}%")
    print(f"[TRAIN] Batches processed={processed_batches} skipped={training_stats['skipped_batches']}")
    print(f"[TRAIN] Clustered Token Types: {_get_cluster_count(model)}")
    return model


print("\n‚úÖ Cell 7: Training loop ready (patched & hardened)")


‚úÖ Cell 7: Training loop ready (patched & hardened)


In [12]:
# ==============================================================================
# CELL 8 (patched): INFERENCE + PIPELINE (HARDENED + FIXED)
# ==============================================================================
# Fixes applied (line-by-line hardening & behavioral fixes):
#  - Robust global fallback reads via globals().get(...) to avoid NameError
#  - BatchEncoding.to(device) guarded and per-tensor fallback to ensure tensors moved safely
#  - Corrected subword detection: tokens that start with '‚ñÅ' are WORD STARTS (not subwords)
#    (previous logic incorrectly treated '‚ñÅ' as subword prefix)
#  - Stronger filtering logic: remove tokens that are punctuation/short fragments,
#    strip SentencePiece markers before length tests, and optionally normalize via normalize_bn_word
#  - Safer mbart.generate invocation with progressive fallbacks:
#      * reduced max_length and beams on OOM
#      * per-sentence smaller-enc fallback
#      * reliable restore of mbart.config.use_cache
#  - Decoding handles tensor shapes (1D/2D) and list/tuple results
#  - _extract_dscd_outputs made more permissive and defensive: looks for nested dicts and list-contained dicts
#  - _get_explanations_list normalizes many shapes into list-of-lists
#  - All exception paths respect VERBOSE_LOGGING; stack traces only when verbose
#  - Prefer normalize_bn_word (if provided by bn_normalizer cell) to canonicalize tokens used for filtering/display
#  - Ensure model.eval() is set for generation; keep TRG in eval externally (model wrapper should have set it)
#  - Return structure consistent and stable even on internal errors
# ==============================================================================
import os
import time
import math
import traceback
from typing import List, Dict, Any, Tuple, Optional

import numpy as np
import torch

# Local fallbacks (read from Cell 0 if available)
_BN_LANG = globals().get("BN_LANG", "bn")
_EN_LANG = globals().get("EN_LANG", "en")
_MAX_LENGTH = int(globals().get("MAX_LENGTH", 48))
_DEVICE = globals().get("DEVICE", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
_VERBOSE_LOGGING = bool(globals().get("VERBOSE_LOGGING", False))
_USE_MULTI_GPU = bool(globals().get("USE_MULTI_GPU", torch.cuda.is_available() and torch.cuda.device_count() > 1))

# Real ambiguity thresholds (defaults safe)
_REAL_AMB_SPAN_THRESHOLD = float(globals().get("SPAN_THRESHOLD", 0.3))
_REAL_AMB_UNCERTAINTY_THRESHOLD = float(globals().get("TAU_LOW", 0.4))

# Optional canonicalizer from bn_normalizer
_normalize_fn = globals().get("normalize_bn_word", None)

# ------------------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------------------
def _to_device_batch(enc: Any, device: torch.device):
    """
    Move tokenizer output to device. Prefer BatchEncoding.to(device) if present.
    Otherwise, move any tensor values in the dict to device.
    Returns a dict-like object with tensor values on the requested device.
    """
    try:
        # HF BatchEncoding has .to(device)
        if hasattr(enc, "to") and callable(getattr(enc, "to")):
            try:
                return enc.to(device)
            except Exception:
                if _VERBOSE_LOGGING:
                    print("[CELL8] BatchEncoding.to() raised; falling back to per-tensor move")
    except Exception:
        pass

    # fallback: assume mapping of key -> tensor
    out = {}
    try:
        for k, v in dict(enc).items():
            try:
                if isinstance(v, torch.Tensor):
                    out[k] = v.to(device)
                else:
                    out[k] = v
            except Exception:
                out[k] = v
        return out
    except Exception:
        # last-resort: return input unchanged
        if _VERBOSE_LOGGING:
            print("[CELL8] _to_device_batch fallback failed; returning original enc")
        return enc


def _extract_dscd_outputs(raw_out: Any) -> Dict[str, Any]:
    """
    Accept many possible model forward outputs and return a dict that contains DSCD/TRG outputs.
    Heuristics:
      - If dict and contains common DSCD keys -> return (or nested dict value)
      - If list/tuple, search for a dict element that looks like DSCD outputs
    """
    if raw_out is None:
        return {}

    # If it's already a dict with DSCD-like keys, prefer that
    if isinstance(raw_out, dict):
        # common nested keys
        for key in ("dscd_outputs", "dscd", "dscd_out", "dscd_outputs_cpu"):
            v = raw_out.get(key, None)
            if isinstance(v, dict):
                return v
        # if the dict itself contains proto_probs/explanations
        if any(k in raw_out for k in ("proto_probs", "explanations", "span_preds", "uncertainties", "trg_explanations")):
            return raw_out
        # sometimes inside 'outputs' / 'result' fields
        for key in ("outputs", "result", "result_dict"):
            v = raw_out.get(key, None)
            if isinstance(v, dict) and any(k in v for k in ("proto_probs", "explanations", "span_preds", "uncertainties")):
                return v
        # best-effort: return raw_out (caller will handle missing keys)
        return raw_out

    # If list/tuple, search for a dict inside
    if isinstance(raw_out, (list, tuple)):
        for item in raw_out:
            if isinstance(item, dict):
                sub = _extract_dscd_outputs(item)
                if sub:
                    return sub
    # otherwise unknown shape
    return {}


def _get_explanations_list(dscd: Dict[str, Any]) -> List[List[Dict[str, Any]]]:
    """
    Normalize various 'explanations' layouts into a list-of-lists where each outer entry
    corresponds to a sentence.
    Accepts:
      - explanations: [ {..}, {..} ]  -> wrapped -> [ [..] ]
      - explanations: [ [ {..}, ... ], [ ... ] ] -> returned as-is
      - explanations: dict keyed by sentence idx -> converted to ordered list if possible
    """
    if not dscd:
        return []
    expl = None
    for k in ("explanations", "trg_explanations", "explanations_per_sentence", "exps", "explanations_list"):
        if k in dscd:
            expl = dscd[k]
            break
    if expl is None:
        return []

    # list-of-lists -> pass through
    if isinstance(expl, list):
        if len(expl) == 0:
            return []
        if isinstance(expl[0], list):
            return expl
        # list-of-dicts -> treat as single sentence
        if isinstance(expl[0], dict):
            return [expl]
    # if dict keyed by sentence index
    if isinstance(expl, dict):
        try:
            # try numeric keys first
            numeric_keys = sorted((int(k) for k in expl.keys() if str(k).isdigit()))
            if numeric_keys:
                out = []
                for nk in numeric_keys:
                    v = expl.get(str(nk), expl.get(nk))
                    if isinstance(v, list):
                        out.append(v)
                    elif isinstance(v, dict):
                        out.append([v])
                if out:
                    return out
        except Exception:
            pass
    return []


def _is_subword_token(token: Optional[str]) -> bool:
    """
    Heuristic for detecting subword tokens/fragments to filter.
    Important: SentencePiece uses '‚ñÅ' to mark word-start. Tokens that START with '‚ñÅ'
    are word-beginnings and should NOT be treated as subword fragments.
    Treat '##' and '@@' as continuation markers (subword fragments).
    """
    if token is None:
        return True
    t = str(token).strip()
    if t == "":
        return True
    # continuation markers (BPE style)
    if t.startswith("##") or t.startswith("@@"):
        return True
    # SentencePiece word start marker -> NOT subword
    if t.startswith("‚ñÅ"):
        return False
    # short fragments (after stripping leading markers)
    clean = t.lstrip("‚ñÅ").lstrip("ƒ†").replace("</w>", "").strip()
    if len(clean) < 2:
        return True
    # punctuation-only or digit-only
    if all(ch in '.,!?;:()[]{}"\'-‚Äî‚Äì/\\' for ch in clean):
        return True
    if clean.isdigit():
        return True
    return False


def _should_filter_explanation(expl: Dict[str, Any], span_th: float, u_th: float) -> bool:
    """
    Return True if an explanation should be filtered out because it is low-quality.
    Filter if:
      - token is subword/empty/punctuation
      - BOTH span <= span_th and uncertainty <= u_th (i.e., not enough signal)
    """
    try:
        token_raw = expl.get("token", "") or expl.get("ambiguous_word", "") or expl.get("token_value", "")
        # prefer token field that exists; sanitize
        token = str(token_raw)
        # remove SPM markers for length check
        token_clean = token.lstrip("‚ñÅ").lstrip("ƒ†").replace("</w>", "").strip()
        # canonicalize if normalizer available (helps group inflected forms)
        if _normalize_fn and token_clean:
            try:
                token_clean = _normalize_fn(token_clean)
            except Exception:
                pass
        # filter tiny/punct tokens
        if not token_clean or len(token_clean) < 2 or all(ch in '.,!?;:()[]{}"\'-‚Äî‚Äì/\\' for ch in token_clean):
            return True

        span = float(expl.get("span", 0.0) or 0.0)
        uncertainty = float(expl.get("uncertainty", 0.0) or 0.0)

        # If both metrics are below thresholds, filter out
        if span <= span_th and uncertainty <= u_th:
            return True
        return False
    except Exception:
        if _VERBOSE_LOGGING:
            traceback.print_exc()
        return True


def _force_english_bos(tokenizer, mbart_model) -> Optional[int]:
    """
    Try to determine English forced BOS id for tokenizer and set it in mbart_model.config.
    Return the forced_id or None.
    """
    forced_id = None
    try:
        if hasattr(tokenizer, "get_lang_id"):
            forced_id = tokenizer.get_lang_id(_EN_LANG)
        elif hasattr(tokenizer, "lang_code_to_id"):
            forced_id = tokenizer.lang_code_to_id.get(_EN_LANG, None)
    except Exception:
        forced_id = None

    if forced_id is not None and hasattr(mbart_model, "config"):
        try:
            mbart_model.config.forced_bos_token_id = int(forced_id)
            mbart_model.config.decoder_start_token_id = int(forced_id)
        except Exception:
            if _VERBOSE_LOGGING:
                print("[CELL8] Could not set forced_bos_token_id on mbart config")
    return forced_id


# ------------------------------------------------------------------------------
# translate_with_explanations
# ------------------------------------------------------------------------------
def translate_with_explanations(
    model,
    tokenizer,
    input_sentence: str,
    device: Optional[torch.device] = None,
    span_threshold: Optional[float] = None,
    uncertainty_threshold: Optional[float] = None,
) -> Dict[str, Any]:
    device = _DEVICE if device is None else device
    span_th = _REAL_AMB_SPAN_THRESHOLD if span_threshold is None else float(span_threshold)
    u_th = _REAL_AMB_UNCERTAINTY_THRESHOLD if uncertainty_threshold is None else float(uncertainty_threshold)

    try:
        # prepare encoding
        try:
            # Some tokenizers use src_lang attribute (M2M100)
            if hasattr(tokenizer, "src_lang"):
                try:
                    setattr(tokenizer, "src_lang", _BN_LANG)
                except Exception:
                    pass
        except Exception:
            pass

        enc = tokenizer(
            input_sentence,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=_MAX_LENGTH
        )
        enc = _to_device_batch(enc, device)

        # ensure model in eval
        model.eval()
        core = model.module if (_USE_MULTI_GPU and hasattr(model, "module")) else model

        # Attempt to get DSCD/TRG outputs via forward_with_explanations or forward
        raw_dscd_out = {}
        try:
            with torch.inference_mode():
                if hasattr(core, "forward_with_explanations"):
                    try:
                        raw_dscd_out = core.forward_with_explanations(
                            input_ids=enc.get("input_ids"),
                            attention_mask=enc.get("attention_mask"),
                            src_texts=[input_sentence]
                        )
                    except TypeError:
                        # fallback positional argument order
                        raw_dscd_out = core.forward_with_explanations(enc.get("input_ids"), enc.get("attention_mask"), [input_sentence])
                else:
                    # try generic forward and extract DSCD outputs
                    try:
                        out = core.forward(input_ids=enc.get("input_ids"), attention_mask=enc.get("attention_mask"), src_texts=[input_sentence], labels=None)
                    except TypeError:
                        out = core.forward(enc.get("input_ids"), enc.get("attention_mask"), [input_sentence], None)
                    raw_dscd_out = _extract_dscd_outputs(out)
        except Exception as e:
            if _VERBOSE_LOGGING:
                print("[CELL8] DSCD/TRG forward error:", str(e))
                traceback.print_exc()
            raw_dscd_out = {}

        # Prepare mbart.generate (if available)
        translation = ""
        mbart_obj = getattr(core, "mbart", None)
        if mbart_obj is None:
            if _VERBOSE_LOGGING:
                print("[CELL8] core.mb is missing .mbart -> skipping generation")
            translation = ""
        else:
            forced_id = _force_english_bos(tokenizer, mbart_obj)
            orig_use_cache = None
            try:
                if hasattr(mbart_obj, "config"):
                    orig_use_cache = getattr(mbart_obj.config, "use_cache", None)
                    mbart_obj.config.use_cache = True
            except Exception:
                orig_use_cache = None

            generated = None
            try:
                try:
                    # primary generation call
                    pad_id = getattr(tokenizer, "pad_token_id", None) or getattr(tokenizer, "eos_token_id", None) or 1
                    generated = mbart_obj.generate(
                        enc.get("input_ids"),
                        attention_mask=enc.get("attention_mask"),
                        max_length=min(_MAX_LENGTH, 64),
                        num_beams=2,
                        early_stopping=True,
                        pad_token_id=int(pad_id),
                        forced_bos_token_id=forced_id if forced_id is not None else getattr(mbart_obj.config, "forced_bos_token_id", None),
                    )
                except RuntimeError as gen_err:
                    # handle OOM by trying more conservative generation options
                    if "out of memory" in str(gen_err).lower():
                        if torch.cuda.is_available():
                            try:
                                torch.cuda.empty_cache()
                            except Exception:
                                pass
                        try:
                            small_enc = tokenizer(input_sentence, return_tensors="pt", padding=True, truncation=True, max_length=min(_MAX_LENGTH, 48))
                            small_enc = _to_device_batch(small_enc, device)
                            pad_id = getattr(tokenizer, "pad_token_id", None) or getattr(tokenizer, "eos_token_id", None) or 1
                            generated = mbart_obj.generate(
                                small_enc.get("input_ids"),
                                attention_mask=small_enc.get("attention_mask"),
                                max_length=min(_MAX_LENGTH, 48),
                                num_beams=1,
                                early_stopping=True,
                                pad_token_id=int(pad_id),
                                forced_bos_token_id=forced_id if forced_id is not None else getattr(mbart_obj.config, "forced_bos_token_id", None),
                            )
                        except Exception as e2:
                            if _VERBOSE_LOGGING:
                                print("[CELL8] fallback generation also failed:", str(e2))
                                traceback.print_exc()
                            generated = None
                    else:
                        # other runtime error -> re-raise to outer handler
                        raise
            finally:
                # restore original cache setting
                try:
                    if hasattr(mbart_obj, "config") and orig_use_cache is not None:
                        mbart_obj.config.use_cache = orig_use_cache
                except Exception:
                    pass

            # decode translation safely
            if generated is not None:
                try:
                    if isinstance(generated, (list, tuple)):
                        translation = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
                    elif isinstance(generated, torch.Tensor):
                        # generated may be shape (1, L) or (N, L)
                        if generated.dim() == 2:
                            translation = tokenizer.decode(generated[0], skip_special_tokens=True)
                        else:
                            translation = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
                    else:
                        translation = str(generated)
                except Exception:
                    try:
                        translation = tokenizer.batch_decode(generated, skip_special_tokens=True)[0]
                    except Exception:
                        if _VERBOSE_LOGGING:
                            print("[CELL8] decode failed for generated; returning empty translation")
                        translation = ""
            else:
                translation = ""

        # Process DSCD/TRG explanations
        dscd_out = _extract_dscd_outputs(raw_dscd_out)
        explanations_list = _get_explanations_list(dscd_out)
        sentence_explanations = explanations_list[0] if (isinstance(explanations_list, list) and len(explanations_list) > 0) else []

        real_amb_count = 0
        out_explanations: List[Dict[str, Any]] = []
        if isinstance(sentence_explanations, list):
            for ex in sentence_explanations:
                try:
                    if _should_filter_explanation(ex, span_th, u_th):
                        continue
                    s_val = float(ex.get("span", 0.0) or 0.0)
                    u_val = float(ex.get("uncertainty", 0.0) or 0.0)
                    is_real = (s_val > span_th) or (u_val > u_th)
                    if is_real:
                        real_amb_count += 1
                    # canonical ambiguous token for output: try several keys and clean markers
                    raw_tok = ex.get("token") or ex.get("ambiguous_word") or ex.get("token_value") or ""
                    tok_str = str(raw_tok)
                    tok_clean = tok_str.lstrip("‚ñÅ").lstrip("ƒ†").replace("</w>", "").strip()
                    if _normalize_fn and tok_clean:
                        try:
                            tok_clean = _normalize_fn(tok_clean)
                        except Exception:
                            pass

                    out_explanations.append({
                        "ambiguous_word": tok_clean,
                        "position": ex.get("token_idx", ex.get("position", "N/A")),
                        "explanation": ex.get("explanation", "") or ex.get("explain", "") or ex.get("text", "") or "",
                        "uncertainty": float(u_val),
                        "span": float(s_val),
                        "is_real_amb": bool(is_real),
                    })
                except Exception:
                    if _VERBOSE_LOGGING:
                        traceback.print_exc()
                    continue

        result = {
            "input_sentence": input_sentence,
            "translation": translation,
            "ambiguous_words_detected": int(real_amb_count),
            "explanations": out_explanations,
        }
        return result

    except Exception as e:
        if _VERBOSE_LOGGING:
            traceback.print_exc()
        return {
            "input_sentence": input_sentence,
            "translation": "",
            "ambiguous_words_detected": 0,
            "explanations": [],
            "error": str(e)[:200],
        }


# ------------------------------------------------------------------------------
# demonstrate_system: small runner that prints nicely
# ------------------------------------------------------------------------------
def demonstrate_system(model, tokenizer, sentences: Optional[List[str]] = None):
    if sentences is None:
        sentences = [
            "‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§",
            "‡¶ï‡¶æ‡¶≤ ‡¶Ü‡¶Æ‡¶ø ‡¶¨‡¶á ‡¶ï‡¶ø‡¶®‡¶¨‡•§",
            "‡¶™‡¶æ‡¶§‡¶æ ‡¶ù‡¶∞‡ßá ‡¶™‡¶°‡¶º‡ßá‡¶õ‡ßá‡•§",
            "‡¶§‡¶ø‡¶®‡¶ø ‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï ‡¶ó‡ßá‡¶õ‡ßá‡¶®‡•§",
            "‡¶Ü‡¶ú ‡¶≠‡¶æ‡¶≤ ‡¶Ü‡¶¨‡¶π‡¶æ‡¶ì‡¶Ø‡¶º‡¶æ‡•§",
        ]
    print("=" * 80)
    print("TATN DEMO: translating and listing DSCD/TRG explanations")
    print("=" * 80)
    for s in sentences:
        print(f"\nInput: {s}")
        res = translate_with_explanations(model, tokenizer, s)
        print("Translation:", res.get("translation", ""))
        print("Ambiguous words detected (real):", res.get("ambiguous_words_detected", 0))
        if res.get("explanations"):
            for idx, ex in enumerate(res["explanations"], 1):
                print(f"  {idx}. word='{ex['ambiguous_word']}' pos={ex['position']} span={ex['span']:.3f} U={ex['uncertainty']:.3f} real={ex['is_real_amb']}")
                print("     ", (ex.get("explanation") or "")[:200])
        else:
            print("  No explanations")
    print("=" * 80)


# ------------------------------------------------------------------------------
# dscd_discovery_warmup: warm-up helper (kept for convenience)
# ------------------------------------------------------------------------------
def dscd_discovery_warmup(model, tokenizer, num_sents: int = 8000, batch_size: int = 64, max_len: Optional[int] = None):
    if max_len is None:
        max_len = _MAX_LENGTH

    core = model.module if (_USE_MULTI_GPU and hasattr(model, "module")) else model
    dscd = getattr(core, "dscd", None)
    if dscd is None:
        print("[WARMUP] No DSCD attached to model; skipping.")
        return

    print("[WARMUP] Starting DSCD discovery warmup...")
    orig_enable = getattr(dscd, "enable_training_clustering", False)
    orig_n_min = getattr(dscd, "n_min", None)
    orig_buffer = getattr(dscd, "buffer_size", None)

    try:
        dscd.enable_training_clustering = True
        dscd.n_min = max(3, int(getattr(dscd, "n_min", 5)))
        dscd.buffer_size = max(200, int(getattr(dscd, "buffer_size", 300)))
    except Exception:
        if _VERBOSE_LOGGING:
            traceback.print_exc()

    texts = []
    try:
        if "load_and_preprocess_optimized" in globals():
            pairs = load_and_preprocess_optimized(num_sents)
            texts = [bn for (bn, _) in pairs][:num_sents]
        else:
            base = ["‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§", "‡¶ï‡¶æ‡¶≤ ‡¶Ü‡¶Æ‡¶ø ‡¶¨‡¶á ‡¶ï‡¶ø‡¶®‡¶¨‡•§", "‡¶™‡¶æ‡¶§‡¶æ ‡¶ù‡¶∞‡ßá ‡¶™‡¶°‡¶º‡ßá‡¶õ‡ßá‡•§", "‡¶§‡¶ø‡¶®‡¶ø ‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï ‡¶ó‡ßá‡¶õ‡ßá‡¶®‡•§"]
            while len(texts) < num_sents:
                texts.extend(base)
            texts = texts[:num_sents]
    except Exception:
        if _VERBOSE_LOGGING:
            traceback.print_exc()
        texts = ["‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§"] * num_sents

    processed = 0
    core.eval()
    with torch.inference_mode():
        for i in range(0, len(texts), batch_size):
            batch = texts[i:i + batch_size]
            try:
                enc = tokenizer(batch, return_tensors="pt", padding=True, truncation=True, max_length=max_len)
                enc = _to_device_batch(enc, _DEVICE)
                if hasattr(core, "forward_with_explanations"):
                    try:
                        core.forward_with_explanations(input_ids=enc.get("input_ids"), attention_mask=enc.get("attention_mask"), src_texts=batch)
                    except TypeError:
                        core.forward_with_explanations(enc.get("input_ids"), enc.get("attention_mask"), batch)
                else:
                    try:
                        if hasattr(core, "mbart") and hasattr(core.mbart.model, "encoder"):
                            core.mbart.model.encoder(input_ids=enc.get("input_ids"), attention_mask=enc.get("attention_mask"))
                    except Exception:
                        pass
                processed += len(batch)
                if _VERBOSE_LOGGING and ((i // batch_size) % 10 == 0):
                    print(f"[WARMUP] processed {processed}/{len(texts)} ({processed/len(texts)*100:.1f}%)")
            except Exception as e:
                if _VERBOSE_LOGGING:
                    print("[WARMUP] batch failed:", str(e))
                    traceback.print_exc()
                continue

    try:
        stores = getattr(dscd, "prototype_stores", {}) or {}
        num_types = len(stores)
        total_protos = sum(store.size() for store in stores.values()) if stores else 0
        multi = sum(1 for store in stores.values() if store.size() >= 2) if stores else 0
        print(f"[WARMUP] Prototype discovery: word_types={num_types}, total_protos={total_protos}, multi_sense={multi}")
    except Exception:
        if _VERBOSE_LOGGING:
            traceback.print_exc()
    finally:
        try:
            dscd.enable_training_clustering = orig_enable
            if orig_n_min is not None:
                dscd.n_min = orig_n_min
            if orig_buffer is not None:
                dscd.buffer_size = orig_buffer
            print("[WARMUP] Restored DSCD configuration")
        except Exception:
            if _VERBOSE_LOGGING:
                traceback.print_exc()


# End of Cell 8
print("‚úÖ Cell 8: Inference pipeline & warmup helpers loaded (patched and hardened)")

‚úÖ Cell 8: Inference pipeline & warmup helpers loaded (patched and hardened)


In [13]:
# ==============================================================================
# CELL 9 (patched): COMPREHENSIVE TESTING & EVALUATION (MULTI-GPU OPTIMIZED)
# DEBUGGED, HARDENED, and MORE DEFENSIVE
# ==============================================================================
# Fixes applied (line-by-line highlights):
#  - Robust global lookups via globals().get(...) with safe defaults.
#  - Defensive handling when model is DataParallel / wrapped / None.
#  - Guarded access to DSCD internals (prototype_stores) and safe numeric conversions.
#  - Resilient use of translate_with_explanations (handles missing function gracefully).
#  - Added normalization (normalize_bn_word) for printed tokens when available.
#  - Clearer cluster-stat printing and protections for empty/no-prototype cases.
#  - Wrapped all per-test calls with try/except so one failing test doesn't abort evaluation.
#  - Ensured all numeric computations guard division-by-zero and invalid types.
#  - Useful verbose logging controlled by VERBOSE_LOGGING; debug traces only when enabled.
# ==============================================================================
from typing import Dict, List, Tuple, Optional, Any
import torch
import traceback
import math

# Robust reads from globals (Cell 0)
_USE_MULTI_GPU = bool(globals().get("USE_MULTI_GPU", torch.cuda.is_available() and torch.cuda.device_count() > 1))
_BN_LANG = str(globals().get("BN_LANG", "bn"))
_VERBOSE_LOGGING = bool(globals().get("VERBOSE_LOGGING", False))

# thresholds fallback consistent with earlier cells
_SPAN_THRESHOLD = float(globals().get("SPAN_THRESHOLD", 0.3))
_UNCERTAINTY_THRESHOLD = float(globals().get("TAU_LOW", 0.4))

# optional normalizer
_normalize_fn = globals().get("normalize_bn_word", None)


# ---------
# Cluster analysis helpers (defensive)
# ---------
def _get_cluster_count(model: torch.nn.Module) -> int:
    try:
        dscd = model.module.dscd if hasattr(model, "module") else getattr(model, "dscd", None)
        stores = getattr(dscd, "prototype_stores", None) if dscd is not None else None
        if not stores:
            return 0
        return len(stores)
    except Exception:
        if _VERBOSE_LOGGING:
            traceback.print_exc()
        return 0


def _print_top_clusters(model: torch.nn.Module, top_n: int = 5):
    """
    Print top N clusters by sample count (homographs discovered by DSCD).
    Defensive against missing attributes and types.
    """
    try:
        dscd = model.module.dscd if hasattr(model, "module") else getattr(model, "dscd", None)
        prototype_stores = getattr(dscd, "prototype_stores", None) or {}
        if not prototype_stores:
            print("[CLUSTER] No clusters found yet")
            return

        cluster_info = []
        for token, store in prototype_stores.items():
            try:
                total_count = int(sum(getattr(store, "counts", []) or []))
            except Exception:
                total_count = 0
            try:
                n_protos = int(store.size()) if hasattr(store, "size") else len(getattr(store, "centroids", []) or [])
            except Exception:
                n_protos = 0
            mu = float(getattr(store, "mu", 0.0) or 0.0)
            tau = float(getattr(store, "tau", 0.0) or 0.0)
            cluster_info.append({
                "token": token,
                "count": total_count,
                "protos": n_protos,
                "mu": mu,
                "tau": tau
            })

        cluster_info.sort(key=lambda x: x["count"], reverse=True)

        display_n = min(top_n, len(cluster_info))
        print(f"\n[CLUSTER] Top {display_n} clusters (by sample count):")
        print("-" * 90)
        print(f"{'Rank':<6}{'Token':<18}{'Count':<12}{'Protos':<10}{'Œº (mean)':<15}{'œÑ (dev)':<12}")
        print("-" * 90)
        for rank, info in enumerate(cluster_info[:display_n], 1):
            tstr = str(info["token"])
            token_display = (tstr[:15] + "..") if len(tstr) > 17 else tstr
            print(f"{rank:<6}{token_display:<18}{info['count']:<12}{info['protos']:<10}{info['mu']:<15.6f}{info['tau']:<12.6f}")
        print("-" * 90)
        total_samples = sum(c["count"] for c in cluster_info)
        print(f"Total clusters: {len(cluster_info)} | Total samples in clusters: {total_samples}")
    except Exception as e:
        print(f"[CLUSTER] Error: {type(e).__name__}: {str(e)[:200]}")
        if _VERBOSE_LOGGING:
            traceback.print_exc()


def _print_cluster_stats(model: torch.nn.Module):
    """
    Aggregate cluster statistics: totals and simple distribution values.
    """
    try:
        dscd = model.module.dscd if hasattr(model, "module") else getattr(model, "dscd", None)
        prototype_stores = getattr(dscd, "prototype_stores", None) or {}
        if not prototype_stores:
            if _VERBOSE_LOGGING:
                print("[CLUSTER-STATS] No prototype stores.")
            return

        total_clusters = len(prototype_stores)
        total_samples = 0
        total_protos = 0
        cluster_counts = []
        for token, store in prototype_stores.items():
            try:
                cnt = int(sum(getattr(store, "counts", []) or []))
            except Exception:
                cnt = 0
            protos = int(store.size()) if hasattr(store, "size") else len(getattr(store, "centroids", []) or [])
            total_samples += cnt
            total_protos += protos
            cluster_counts.append(cnt)

        avg_samples = (total_samples / total_clusters) if total_clusters > 0 else 0.0
        avg_protos = (total_protos / total_clusters) if total_clusters > 0 else 0.0
        max_samples = max(cluster_counts) if cluster_counts else 0
        min_samples = min(cluster_counts) if cluster_counts else 0

        print("\n[CLUSTER-STATS] Cluster Statistics:")
        print(f"  ‚Ä¢ Total clusters: {total_clusters}")
        print(f"  ‚Ä¢ Total samples: {total_samples}")
        print(f"  ‚Ä¢ Total prototypes: {total_protos}")
        print(f"  ‚Ä¢ Avg samples/cluster: {avg_samples:.1f}")
        print(f"  ‚Ä¢ Avg protos/cluster: {avg_protos:.1f}")
        print(f"  ‚Ä¢ Max samples/cluster: {max_samples}")
        print(f"  ‚Ä¢ Min samples/cluster: {min_samples}")
    except Exception as e:
        print(f"[CLUSTER-STATS] Error: {type(e).__name__}: {str(e)[:200]}")
        if _VERBOSE_LOGGING:
            traceback.print_exc()


# ----------------------------
# Evaluation routine
# ----------------------------
@torch.inference_mode()
def comprehensive_post_training_testing(model: torch.nn.Module, tokenizer) -> Dict[str, Any]:
    """
    Compact comprehensive evaluation:
      - Translate curated Bengali sentences
      - Count detected ambiguous tokens (real ambiguity: span>_SPAN_THRESHOLD or uncertainty>_UNCERTAINTY_THRESHOLD)
      - Print explanations and DSCD prototype stats
      - Optionally run small DSCD warmup if no prototypes and helper exists
    Returns aggregated metrics dict.
    """
    print("\n" + "=" * 80)
    print("COMPREHENSIVE POST-TRAINING EVALUATION (Cell 9)")
    print("=" * 80)

    test_sentences: List[Tuple[str, str]] = [
        ("‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§", "‡¶ï‡¶≤ = tap / call"),
        ("‡¶ï‡¶æ‡¶≤ ‡¶Ü‡¶Æ‡¶ø ‡¶¨‡¶á ‡¶ï‡¶ø‡¶®‡¶¨‡•§", "‡¶ï‡¶æ‡¶≤ = tomorrow / yesterday"),
        ("‡¶™‡¶æ‡¶§‡¶æ ‡¶ù‡¶∞‡ßá ‡¶™‡¶°‡¶º‡ßá‡¶õ‡ßá‡•§", "‡¶™‡¶æ‡¶§‡¶æ = leaf / page"),
        ("‡¶§‡¶ø‡¶®‡¶ø ‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï ‡¶ó‡ßá‡¶õ‡ßá‡¶®‡•§", "‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï = bank / embankment"),
        ("‡¶Ü‡¶ú ‡¶≠‡¶æ‡¶≤ ‡¶Ü‡¶¨‡¶π‡¶æ‡¶ì‡ßü‡¶æ‡•§", "Simple sentence (no ambiguity expected)"),
    ]

    # prefer underlying core if DataParallel wrapping was used
    core_model = model.module if (_USE_MULTI_GPU and hasattr(model, "module")) else model
    try:
        core_model.eval()
    except Exception:
        pass

    # If DSCD has no prototypes and warmup helper exists, run a shorter warmup (best-effort)
    try:
        dscd = getattr(core_model, "dscd", None)
        stores = getattr(dscd, "prototype_stores", None) if dscd is not None else None
        # only run warmup if no prototypes at all and warmup helper available
        if (not stores or len(stores) == 0) and "dscd_discovery_warmup" in globals():
            try:
                print("[EVAL] No DSCD prototypes found. Running moderate warmup (num_sents=2000)...")
                # run a modest warmup to seed prototypes (user can skip if heavy)
                dscd_discovery_warmup(core_model, tokenizer, num_sents=2000, batch_size=64)
            except Exception as e:
                print(f"[EVAL] DSCD warmup failed/skipped: {type(e).__name__}: {str(e)[:200]}")
                if _VERBOSE_LOGGING:
                    traceback.print_exc()
    except Exception:
        if _VERBOSE_LOGGING:
            traceback.print_exc()

    # Metrics
    total_tests = len(test_sentences)
    successful_translations = 0
    total_explanations = 0
    total_high_span = 0
    total_real_ambiguous = 0

    print(f"\n[EVAL] Running {total_tests} tests...")
    print("-" * 80)

    # Configure tokenizer if it supports src_lang attribute
    try:
        tokenizer.src_lang = _BN_LANG
    except Exception:
        pass

    def _is_real_amb(expl: Dict[str, Any]) -> bool:
        try:
            s = float(expl.get("span", 0.0) or 0.0)
            u = float(expl.get("uncertainty", 0.0) or 0.0)
            return (s > _SPAN_THRESHOLD) or (u > _UNCERTAINTY_THRESHOLD)
        except Exception:
            return False

    # iterate tests
    for idx, (src_text, desc) in enumerate(test_sentences, 1):
        print(f"\nTest {idx}/{total_tests}: {desc}")
        print("=" * 60)
        try:
            if "translate_with_explanations" not in globals():
                print("[EVAL] translate_with_explanations not available; skipping this test.")
                continue

            try:
                # call the inference wrapper - pass core_model (best-effort)
                result = translate_with_explanations(core_model if core_model is not None else model, tokenizer, src_text)
            except Exception as e:
                print(f"[EVAL] translate_with_explanations raised: {type(e).__name__}: {str(e)[:200]}")
                if _VERBOSE_LOGGING:
                    traceback.print_exc()
                result = {"translation": "", "ambiguous_words_detected": 0, "explanations": []}

            translation = str(result.get("translation", "") or "")
            try:
                amb_count = int(result.get("ambiguous_words_detected", 0) or 0)
            except Exception:
                amb_count = 0
            explanations = result.get("explanations", []) or []

            print(f"Input: {src_text}")
            print(f"Translation: {translation}")
            print(f"Ambiguous Words (real, counted): {amb_count}")

            if explanations:
                print("\nExplanations:")
                high_span_local = 0
                real_amb_local = 0
                for j, expl in enumerate(explanations, 1):
                    try:
                        span_val = float(expl.get("span", 0.0) or 0.0)
                        u_val = float(expl.get("uncertainty", 0.0) or 0.0)
                        marker = "[SPAN>0.3]" if span_val > _SPAN_THRESHOLD else "           "
                        raw_word = expl.get("ambiguous_word", expl.get("token", "N/A"))
                        word = str(raw_word or "N/A")
                        # normalize for display if possible
                        try:
                            if _normalize_fn and isinstance(word, str) and word.strip():
                                word = _normalize_fn(word)
                        except Exception:
                            pass
                        pos = expl.get("position", expl.get("token_idx", "N/A"))
                        print(f"  {j}. {marker} '{word}' @ pos {pos}")
                        print(f"       U={u_val:.3f} | S={span_val:.3f}")
                        text = str(expl.get("explanation", "") or "")
                        if len(text) > 120:
                            text = text[:120] + "..."
                        print(f"       {text}")
                        if span_val > _SPAN_THRESHOLD:
                            high_span_local += 1
                        if _is_real_amb(expl):
                            real_amb_local += 1
                    except Exception:
                        if _VERBOSE_LOGGING:
                            traceback.print_exc()
                        continue

                total_explanations += len(explanations)
                total_high_span += high_span_local
                total_real_ambiguous += real_amb_local
            else:
                print("No explanations produced (likely high-confidence translation)")

            # Consider translation successful if non-empty and not an error sentinel
            try:
                bad_sentinels = {"", "Error occurred", "Translation generation failed", "ERROR DURING TRANSLATION"}
                if translation and translation.strip() and translation not in bad_sentinels:
                    successful_translations += 1
                    print("Translation successful")
                else:
                    print("Translation failed or empty")
            except Exception:
                print("Translation check encountered an error; counted as failure")

        except Exception as e:
            print(f"[EVAL] Test {idx} failed: {type(e).__name__}: {str(e)[:200]}")
            if _VERBOSE_LOGGING:
                traceback.print_exc()
            continue

        print("-" * 60)

    # DSCD statistics (best-effort)
    try:
        dscd_stats = {"total_words": 0, "multi_sense_words": 0, "total_prototypes": 0}
        dscd = getattr(core_model, "dscd", None)
        if dscd is not None and hasattr(dscd, "prototype_stores"):
            stores = getattr(dscd, "prototype_stores") or {}
            total_words = 0
            multi = 0
            total_protos = 0
            for key, store in stores.items():
                try:
                    sz = int(store.size()) if hasattr(store, "size") else len(getattr(store, "centroids", []) or [])
                except Exception:
                    sz = 0
                total_words += 1
                total_protos += sz
                if sz >= 2:
                    multi += 1
            dscd_stats = {"total_words": total_words, "multi_sense_words": multi, "total_prototypes": total_protos}
        else:
            dscd_stats = {"total_words": 0, "multi_sense_words": 0, "total_prototypes": 0}
    except Exception as e:
        print(f"[EVAL] Could not retrieve DSCD stats: {type(e).__name__}: {str(e)[:200]}")
        if _VERBOSE_LOGGING:
            traceback.print_exc()
        dscd_stats = {"total_words": 0, "multi_sense_words": 0, "total_prototypes": 0}

    # Summary
    print("\n" + "=" * 80)
    print("EVALUATION SUMMARY")
    print("=" * 80)
    print(f"Total tests: {total_tests}")
    print(f"Successful translations: {successful_translations}")
    success_rate = (successful_translations / total_tests * 100.0) if total_tests > 0 else 0.0
    print(f"Success rate: {success_rate:.1f}%")
    print("")
    print("Ambiguity detection:")
    print(f"  - Total explanations produced: {total_explanations}")
    print(f"  - High-span (S>{_SPAN_THRESHOLD}): {total_high_span}")
    print(f"  - Real ambiguous (S>{_SPAN_THRESHOLD} or U>{_UNCERTAINTY_THRESHOLD}): {total_real_ambiguous}")
    if total_tests > 0:
        print(f"  - Avg explanations/test: {total_explanations / total_tests:.2f}")
        print(f"  - Avg real ambiguous/test: {total_real_ambiguous / total_tests:.2f}")
    print("")
    print("DSCD Prototype Discovery:")
    print(f"  - Word types tracked: {dscd_stats.get('total_words', 0)}")
    print(f"  - Multi-sense words (>=2 protos): {dscd_stats.get('multi_sense_words', 0)}")
    print(f"  - Total prototypes: {dscd_stats.get('total_prototypes', 0)}")
    if dscd_stats.get("total_words", 0) > 0:
        avg_protos_word = dscd_stats.get("total_prototypes", 0) / max(1, dscd_stats.get("total_words", 1))
        print(f"  - Avg prototypes/word: {avg_protos_word:.2f}")
    print("=" * 80)

    return {
        "total_tests": total_tests,
        "successful_translations": successful_translations,
        "success_rate_pct": success_rate,
        "total_explanations": total_explanations,
        "total_high_span": total_high_span,
        "total_real_ambiguous": total_real_ambiguous,
        "dscd_stats": dscd_stats,
    }


print("‚úÖ Cell 9: Comprehensive testing & evaluation ready (debugged + hardened).")

‚úÖ Cell 9: Comprehensive testing & evaluation ready (debugged + hardened).


In [14]:
# ==============================================================================
# CELL 10 (patched): TATN MAIN PIPELINE (DISCOVERY FIXES + HOMOGRAPH VERIFICATION)
# Debugged, hardened, and self-contained replacement of original Cell 10.
# ==============================================================================
# Key behavior:
#  - Robust, defensive global reads and fallbacks
#  - Safe tokenizer loading with whitespace fallback if transformers missing
#  - Minimal dataset fallback when data loading utilities are absent
#  - Robust DataLoader construction and safe collate fallback
#  - Conservative, non-forcing clustering limits and clear diagnostics
#  - Normalized homograph verification (canonical forms) and helpful warnings
#  - Safe saving of model state
# ==============================================================================

import os
import time
import traceback
from typing import Tuple, Optional, Iterable, List

import gc
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset

import unicodedata

# -------------------------
# Safe defaults (if Cell 0 not executed)
# -------------------------
FREEZE_ENCODER = False

def _g(name, default):
    """Defensive global getter."""
    return globals().get(name, default)

# Pull globals defensively (fall back to sane defaults)
try:
    _USE_MULTI_GPU = bool(_g("USE_MULTI_GPU", False))
    _NUM_GPUS = int(_g("NUM_GPUS", torch.cuda.device_count() if torch.cuda.is_available() else 0))
    _DEVICE = _g("DEVICE", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    _BN_LANG = _g("BN_LANG", "bn")
    _EN_LANG = _g("EN_LANG", "en")
    _NUM_SAMPLES = int(_g("NUM_SAMPLES", 30000))
    _MAX_LENGTH = int(_g("MAX_LENGTH", 48))
    _BATCH_SIZE = int(_g("BATCH_SIZE", 8))
    _EPOCHS = int(_g("EPOCHS", 1))
    _ACCUMULATION_STEPS = int(_g("ACCUMULATION_STEPS", 1))
    _LR_NMT = float(_g("LR_NMT", 2e-5))
    _LR_PHI = float(_g("LR_PHI", 1e-5))
    _ENABLE_ASBN_TRAINING = bool(_g("ENABLE_ASBN_TRAINING", False))
    _VALIDATION_CHECK_INTERVAL = int(_g("VALIDATION_CHECK_INTERVAL", 0))
    _DSCD_WARMUP_SAMPLES = int(_g("DSCD_WARMUP_SAMPLES", 8000))
    _VERBOSE_LOGGING = bool(_g("VERBOSE_LOGGING", False))
    _HOMOGRAPH_WATCHLIST_BN = set(_g("HOMOGRAPH_WATCHLIST_BN", {"‡¶ï‡¶≤", "‡¶ï‡¶æ‡¶≤", "‡¶™‡¶æ‡¶§‡¶æ", "‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï"}))
except Exception:
    _NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
    _USE_MULTI_GPU = _NUM_GPUS > 1
    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _BN_LANG = "bn"
    _EN_LANG = "en"
    _NUM_SAMPLES = 30000
    _MAX_LENGTH = 48
    _BATCH_SIZE = 8
    _EPOCHS = 1
    _ACCUMULATION_STEPS = 1
    _LR_NMT = 2e-5
    _LR_PHI = 1e-5
    _ENABLE_ASBN_TRAINING = False
    _VALIDATION_CHECK_INTERVAL = 0
    _DSCD_WARMUP_SAMPLES = 8000
    _VERBOSE_LOGGING = False
    _HOMOGRAPH_WATCHLIST_BN = {"‡¶ï‡¶≤", "‡¶ï‡¶æ‡¶≤", "‡¶™‡¶æ‡¶§‡¶æ", "‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï"}

# DSCD clustering thresholds (defensive)
DSCD_N_MIN = int(globals().get("DSCD_N_MIN", 5))
DEFAULT_CLUSTER_MIN_SAMPLES = 20
_CLUSTER_MIN_SAMPLES = int(globals().get("DSCD_MIN_CLUSTER_SAMPLES", max(DEFAULT_CLUSTER_MIN_SAMPLES, DSCD_N_MIN * 2)))

# -------------------------
# Helper: Clear GPU caches safely
# -------------------------
def _safe_clear_gpu_caches():
    try:
        if "clear_all_gpu_caches" in globals():
            try:
                clear_all_gpu_caches()
            except Exception:
                pass
            return
        if torch.cuda.is_available():
            for i in range(torch.cuda.device_count()):
                try:
                    with torch.cuda.device(i):
                        torch.cuda.empty_cache()
                except Exception:
                    pass
    except Exception:
        pass

# -------------------------
# Small normalization helpers used for homograph matching
# -------------------------
def _norm_clean_token(tok: Optional[str]) -> str:
    if tok is None:
        return ""
    s = str(tok)
    # remove common subword markers and normalize to NFKC
    for marker in ('‚ñÅ', '##', 'ƒ†', '@@'):
        s = s.replace(marker, '')
    s = s.strip()
    s = unicodedata.normalize('NFKC', s)
    return s

def _token_matches_homograph(token_key: str, homograph: str) -> bool:
    clean_tok = _norm_clean_token(token_key)
    clean_h = _norm_clean_token(homograph)
    if not clean_tok or not clean_h:
        return False
    # exact or substring match both ways are considered (conservative)
    if clean_tok == clean_h:
        return True
    if clean_h in clean_tok:
        return True
    if clean_tok in clean_h:
        return True
    return False

# -------------------------
# Robust tokenizer loader (lazy imports + helpful errors + fallback)
# -------------------------
def _safe_tokenizer_from_pretrained(model_name: str, local_files_only: bool = False, prefer_fast: bool = True):
    """
    Robustly load a tokenizer. If transformers missing or unavailable, return a whitespace fallback
    that implements the key methods used downstream (decode, convert_ids_to_tokens, __len__, vocab_size).
    """
    try:
        import transformers as _tf
        from transformers import AutoTokenizer
    except Exception as e_tf:
        # Transformers not importable: return a richer whitespace fallback
        class _WhitespaceFallback:
            def __init__(self):
                self.pad_token = "<pad>"
                self.pad_token_id = 0
                self.vocab_size = 0
            def __len__(self):
                return int(self.vocab_size)
            def encode(self, text, add_special_tokens=True):
                if text is None:
                    return []
                return text.split()
            def convert_ids_to_tokens(self, ids):
                if ids is None:
                    return []
                out = []
                for x in ids:
                    if isinstance(x, str):
                        out.append(x)
                    else:
                        out.append(str(x))
                return out
            def decode(self, ids, skip_special_tokens=True, **kwargs):
                if ids is None:
                    return ""
                if isinstance(ids, (list, tuple)):
                    # join as strings
                    return " ".join([str(t) for t in ids])
                return str(ids)
            def __call__(self, texts, padding=False, truncation=False, return_tensors=None, max_length=None, add_special_tokens=True):
                if isinstance(texts, str):
                    texts = [texts]
                input_ids = []
                attention_mask = []
                for t in texts:
                    toks = (t or "").split()
                    input_ids.append(toks)
                    attention_mask.append([1] * len(toks))
                if return_tensors == "pt":
                    maxlen = max((len(x) for x in input_ids), default=0)
                    import torch as _torch
                    ids_t = _torch.zeros((len(input_ids), maxlen), dtype=_torch.long)
                    mask_t = _torch.zeros((len(input_ids), maxlen), dtype=_torch.long)
                    for i, row in enumerate(input_ids):
                        for j, tok in enumerate(row):
                            ids_t[i, j] = 0
                            mask_t[i, j] = 1
                    return {"input_ids": ids_t, "attention_mask": mask_t}
                return {"input_ids": input_ids, "attention_mask": attention_mask}
        if _VERBOSE_LOGGING:
            print("WARNING: 'transformers' import failed in _safe_tokenizer_from_pretrained(). Using whitespace fallback.")
            print(f"         Original error: {type(e_tf).__name__}: {e_tf}")
        return _WhitespaceFallback()

    tried = []
    try:
        from transformers import M2M100TokenizerFast as _M2MFast
    except Exception:
        _M2MFast = None

    if _M2MFast is not None:
        try:
            return _M2MFast.from_pretrained(model_name, local_files_only=local_files_only)
        except Exception as e:
            tried.append(("M2M100TokenizerFast", e))

    try:
        tok = AutoTokenizer.from_pretrained(model_name, use_fast=prefer_fast, local_files_only=local_files_only)
        return tok
    except Exception as e_auto:
        tried.append(("AutoTokenizer(use_fast=%s)" % prefer_fast, e_auto))
        msg = str(e_auto).lower()
        if "sentencepiece" in msg or "tokenizers" in msg or "sacremoses" in msg:
            raise RuntimeError(
                f"Failed to instantiate tokenizer for '{model_name}'. This often happens because optional deps like 'sentencepiece' or 'tokenizers' are missing.\n"
                "Please run: pip install transformers sentencepiece tokenizers\n"
                "Then RESTART the kernel and re-run cells 0‚Üí10.\n"
                f"Original tokenizer error: {e_auto}"
            ) from e_auto
        # try slow tokenizer as fallback
        try:
            tok = AutoTokenizer.from_pretrained(model_name, use_fast=False, local_files_only=local_files_only)
            return tok
        except Exception as e_slow:
            tried.append(("AutoTokenizer(use_fast=False)", e_slow))
            summary = "; ".join([f"{name}:{type(exc).__name__}" for name, exc in tried])
            raise RuntimeError(
                f"No usable tokenizer class available for '{model_name}'. Tried: {summary}.\n"
                "Make sure you have a compatible 'transformers' installed and the optional dependencies (sentencepiece, tokenizers) for the model.\n"
                "Suggested command:\n"
                "  pip install transformers sentencepiece tokenizers\n"
                "Then RESTART the kernel and re-run the notebook.\n"
                f"Last error: {e_slow}"
            ) from e_slow

# -------------------------
# Minimal fallback dataset if MemoryEfficientDataset is missing
# -------------------------
class _SimpleDataset(Dataset):
    """Minimal dataset used as a safe fallback. Tokenizes on the fly."""
    def __init__(self, pairs: Iterable[Tuple[str, str]], tokenizer, max_length: int = 48):
        self.pairs = list(pairs)
        self.tokenizer = tokenizer
        self.max_length = int(max_length)
    def __len__(self):
        return len(self.pairs)
    def __getitem__(self, idx):
        src, tgt = self.pairs[idx]
        try:
            enc = self.tokenizer(src, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length)
            tgt_enc = self.tokenizer(tgt, return_tensors="pt", padding="max_length", truncation=True, max_length=self.max_length)
            input_ids = enc["input_ids"].squeeze(0)
            attention_mask = enc["attention_mask"].squeeze(0)
            labels = tgt_enc["input_ids"].squeeze(0)
        except Exception:
            # tokenizer fallback (whitespace)
            toks = (src or "").split()
            L = min(len(toks), self.max_length)
            import torch as _torch
            input_ids = _torch.zeros(self.max_length, dtype=_torch.long)
            attention_mask = _torch.zeros(self.max_length, dtype=_torch.long)
            for i in range(L):
                input_ids[i] = 0
                attention_mask[i] = 1
            labels = input_ids.clone()
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
            "src_text": src,
            "token_word_map": {}
        }

# -------------------------
# Main pipeline
# -------------------------
def initialize_environment():
    print("[CELL10] Initializing environment...")
    if torch.cuda.is_available():
        gcnt = torch.cuda.device_count()
        print(f"[CELL10] GPUs available: {gcnt}")
        for i in range(gcnt):
            try:
                name = torch.cuda.get_device_name(i)
            except Exception:
                name = "Unknown GPU"
            try:
                mem = torch.cuda.get_device_properties(i).total_memory / 1024 ** 3
                print(f"  - GPU {i}: {name} ({mem:.1f} GB)")
            except Exception:
                print(f"  - GPU {i}: {name} (mem unknown)")
        _safe_clear_gpu_caches()
        if gcnt > 1:
            print("[CELL10] Multi-GPU detected")
    else:
        print("[CELL10] No GPU detected - running on CPU")
    return True

def main_pipeline() -> Tuple[object, object]:
    """
    End-to-end orchestration. Returns (trained_model, tokenizer).
    """
    print("=" * 80)
    print("CELL10: TATN MAIN PIPELINE (patched) - Discovery + Homograph verification")
    print("=" * 80)

    initialize_environment()

    # -----------------------
    # Tokenizer
    # -----------------------
    print("[CELL10] Loading tokenizer...")
    tokenizer = _safe_tokenizer_from_pretrained("facebook/m2m100_418M")
    try:
        tokenizer.src_lang = _BN_LANG
    except Exception:
        pass

    # Ensure pad token exists (best-effort)
    try:
        pad_id = getattr(tokenizer, "pad_token_id", None)
        if pad_id is None and hasattr(tokenizer, "add_special_tokens"):
            try:
                tokenizer.add_special_tokens({"pad_token": "<pad>"})
            except Exception:
                pass
    except Exception:
        pass

    # compute a useful vocab_info for logging
    vocab_info = "unknown"
    try:
        if hasattr(tokenizer, "vocab_size") and getattr(tokenizer, "vocab_size") is not None:
            vocab_info = int(getattr(tokenizer, "vocab_size"))
        elif hasattr(tokenizer, "__len__"):
            try:
                vocab_info = int(len(tokenizer))
            except Exception:
                vocab_info = "unknown"
        else:
            vocab_info = "unknown"
    except Exception:
        vocab_info = "unknown"
    print(f"[CELL10] Tokenizer loaded (vocab size approx {vocab_info})")

    # -----------------------
    # Data loading (fallbacks)
    # -----------------------
    print(f"[CELL10] Loading/preprocessing up to {_NUM_SAMPLES} samples...")
    if "load_and_preprocess_optimized" in globals():
        try:
            pairs = load_and_preprocess_optimized(_NUM_SAMPLES)
        except Exception:
            print("[CELL10] load_and_preprocess_optimized failed; using fallback single example")
            pairs = [("‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§", "i turned off the tap.")]
    else:
        if _VERBOSE_LOGGING:
            print("[CELL10] Warning: load_and_preprocess_optimized not found; using small fallback dataset")
        pairs = [("‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§", "i turned off the tap.")]

    # Dataset: prefer existing MemoryEfficientDataset, else fallback to simple dataset
    if "MemoryEfficientDataset" in globals():
        DatasetClass = globals()["MemoryEfficientDataset"]
        try:
            dataset = DatasetClass(pairs, tokenizer, max_length=_MAX_LENGTH)
        except Exception:
            if _VERBOSE_LOGGING:
                print("[CELL10] MemoryEfficientDataset constructor failed; using fallback _SimpleDataset")
            dataset = _SimpleDataset(pairs, tokenizer, max_length=_MAX_LENGTH)
    else:
        if _VERBOSE_LOGGING:
            print("[CELL10] MemoryEfficientDataset not present - using fallback _SimpleDataset")
        dataset = _SimpleDataset(pairs, tokenizer, max_length=_MAX_LENGTH)

    batch_size = int(_BATCH_SIZE)
    active_device_ids = list(range(_NUM_GPUS)) if (_USE_MULTI_GPU and _NUM_GPUS > 1) else []
    if active_device_ids and batch_size < len(active_device_ids):
        usable = max(1, batch_size)
        active_device_ids = active_device_ids[:usable]
        print(f"[CELL10] Adjusting DataParallel devices to {len(active_device_ids)} due to small batch_size")

    # synchronize global BATCH_SIZE for compatibility with other cells
    try:
        global BATCH_SIZE
        BATCH_SIZE = batch_size
    except Exception:
        pass

    # collate function if provided
    collate_fn = globals().get("safe_collate", None)
    collate_fn = collate_fn if callable(collate_fn) else None

    # Prefer an optimized dataloader if available, else fallback to vanilla DataLoader
    try:
        if "create_optimized_dataloader" in globals():
            train_loader = create_optimized_dataloader(dataset, batch_size=batch_size, shuffle=True)
        else:
            train_loader = DataLoader(
                dataset,
                batch_size=batch_size,
                shuffle=True,
                num_workers=0,
                pin_memory=torch.cuda.is_available(),
                collate_fn=collate_fn,
                drop_last=False
            )
    except Exception:
        if _VERBOSE_LOGGING:
            print("[CELL10] DataLoader construction failed; attempting simple fallback")
        train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=0, collate_fn=collate_fn)

    try:
        dataset_len = len(dataset)
    except Exception:
        dataset_len = "unknown"
    try:
        batches_count = len(train_loader)
    except Exception:
        batches_count = "unknown"
    print(f"[CELL10] Dataset: {dataset_len} examples, {batches_count} batches (batch_size={batch_size})")

    # -----------------------
    # Model init
    # -----------------------
    print("[CELL10] Initializing model...")
    if "MemoryOptimizedTATNWithExplanations" not in globals():
        # Do not raise ‚Äî give a clear message and return a safe no-op
        print("[CELL10] ERROR: Model class MemoryOptimizedTATNWithExplanations not found (Cell 6). Aborting pipeline initialization.")
        return None, tokenizer
    model_core = MemoryOptimizedTATNWithExplanations(tokenizer)

    # Wrap into DataParallel if multiple device ids chosen
    if active_device_ids and len(active_device_ids) > 1:
        print(f"[CELL10] Wrapping model in DataParallel on devices {active_device_ids}")
        model = nn.DataParallel(model_core, device_ids=active_device_ids)
    else:
        model = model_core
        if _VERBOSE_LOGGING:
            print("[CELL10] Single-GPU / CPU mode (no DataParallel)")

    # Move to device carefully (avoid .to on DataParallel in some setups)
    try:
        model = model.to(_DEVICE)
    except Exception:
        try:
            core = model.module if hasattr(model, "module") else model
            core.to(_DEVICE)
        except Exception:
            pass

    core_model = model.module if hasattr(model, "module") else model

    # Resize embeddings if tokenizer vocabulary differs from model embedding size
    try:
        mb = getattr(core_model, "mbart", None)
        if mb is not None and hasattr(mb, "get_input_embeddings"):
            emb = mb.get_input_embeddings()
            current_emb = None
            try:
                current_emb = getattr(emb, "num_embeddings", None) or (emb.weight.shape[0] if hasattr(emb, "weight") else None)
            except Exception:
                current_emb = None
            new_size = None
            try:
                if hasattr(tokenizer, "vocab_size") and getattr(tokenizer, "vocab_size") is not None:
                    new_size = int(getattr(tokenizer, "vocab_size"))
                elif hasattr(tokenizer, "__len__"):
                    new_size = int(len(tokenizer))
            except Exception:
                new_size = None
            if new_size and current_emb and int(current_emb) != int(new_size):
                try:
                    mb.resize_token_embeddings(new_size)
                    print(f"[CELL10] Resized token embeddings: {current_emb} -> {new_size}")
                except Exception:
                    if _VERBOSE_LOGGING:
                        print("[CELL10] Warning: resize_token_embeddings failed; continuing")
    except Exception:
        pass

    # Optional encoder freeze
    if FREEZE_ENCODER:
        try:
            if hasattr(core_model, "mbart") and hasattr(core_model.mbart, "model"):
                for p in core_model.mbart.model.encoder.parameters():
                    p.requires_grad = False
                print("[CELL10] Encoder frozen for faster training")
        except Exception:
            if _VERBOSE_LOGGING:
                print("[CELL10] Encoder freeze failed; continuing")

    # -----------------------
    # Optimizers
    # -----------------------
    print("[CELL10] Preparing optimizers...")
    try:
        critic_params = list(core_model.asbn.critic_parameters()) if hasattr(core_model, "asbn") and hasattr(core_model.asbn, "critic_parameters") else []
    except Exception:
        critic_params = []
    critic_ids = {id(p) for p in critic_params}
    base_params = [p for p in core_model.parameters() if p.requires_grad and id(p) not in critic_ids]

    optimizer = torch.optim.AdamW(base_params, lr=_LR_NMT)
    phi_optimizer = None
    if critic_params and _ENABLE_ASBN_TRAINING:
        try:
            phi_optimizer = torch.optim.AdamW([p for p in critic_params if p.requires_grad], lr=_LR_PHI)
            print(f"[CELL10] ASBN critic optimizer created (params: {len([p for p in critic_params if p.requires_grad])})")
        except Exception:
            phi_optimizer = None
            print("[CELL10] ASBN critic optimizer creation failed; continuing without it")
    else:
        if _VERBOSE_LOGGING:
            print("[CELL10] ASBN critic optimizer disabled")

    # -----------------------
    # Training
    # -----------------------
    print("[CELL10] Starting training phase...")
    trained_model = model
    if "train_memory_efficient_tatn" in globals():
        try:
            trained_model = train_memory_efficient_tatn(
                model,
                tokenizer,
                train_loader,
                optimizer,
                phi_optimizer=phi_optimizer,
                epochs=_EPOCHS,
                accumulation_steps=_ACCUMULATION_STEPS,
                validate_every=_VALIDATION_CHECK_INTERVAL,
                enable_validation=bool(_VALIDATION_CHECK_INTERVAL > 0)
            )
        except Exception as e:
            print(f"[CELL10] Training failed: {type(e).__name__}: {str(e)[:200]}")
            if _VERBOSE_LOGGING:
                traceback.print_exc()
            trained_model = model
    else:
        if _VERBOSE_LOGGING:
            print("[CELL10] Training function not found (Cell 7). Skipping training.")
        trained_model = model

    # ‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê
    # Discovery phase: cluster buffered embeddings + verify homographs
    # ‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê
    print("\n" + "=" * 80)
    print("DISCOVERY PHASE: Clustering DSCD buffers to create prototypes...")
    print("=" * 80)

    _safe_clear_gpu_caches()

    try:
        core_for_discovery = trained_model.module if hasattr(trained_model, 'module') else trained_model

        if not hasattr(core_for_discovery, "dscd"):
            raise RuntimeError("Trained model does not have a .dscd attribute (DSCD instance)")

        dscd = core_for_discovery.dscd

        # Collect clusterable tokens using a conservative threshold (use .buffers safely)
        buffers_iter = getattr(dscd, "buffers", {}) or {}
        clusterable_tokens: List[Tuple[str, int]] = []
        for token_type, buffer in buffers_iter.items():
            try:
                buf_len = len(buffer)
            except Exception:
                buf_len = 0
            if buf_len >= _CLUSTER_MIN_SAMPLES:
                clusterable_tokens.append((token_type, buf_len))

        # relax threshold if nothing meets strict threshold
        if len(clusterable_tokens) == 0:
            relaxed = []
            for token_type, buffer in buffers_iter.items():
                try:
                    buf_len = len(buffer)
                except Exception:
                    buf_len = 0
                if buf_len >= DSCD_N_MIN:
                    relaxed.append((token_type, buf_len))
            if relaxed:
                print(f"[DISCOVERY] No tokens >= {_CLUSTER_MIN_SAMPLES}. Relaxing threshold to DSCD_N_MIN={DSCD_N_MIN} (found {len(relaxed)})")
                clusterable_tokens = relaxed

        # Sort by buffer size (descending) and limit to top K (do not force minimum)
        clusterable_tokens.sort(key=lambda x: x[1], reverse=True)
        MAX_TO_CLUSTER = min(500, max(1, len(clusterable_tokens)))
        clusterable_tokens = clusterable_tokens[:MAX_TO_CLUSTER]

        print(f"[DISCOVERY] Found {len(clusterable_tokens)} tokens meeting threshold for clustering (threshold={_CLUSTER_MIN_SAMPLES})")

        if len(clusterable_tokens) == 0:
            print("[DISCOVERY] WARNING: No tokens with sufficient samples! DSCD will not work reliably.")
        else:
            clustered_count = 0
            failed_count = 0
            start_time = time.time()

            for idx, (token_type, buffer_size) in enumerate(clusterable_tokens):
                try:
                    success = False
                    if hasattr(dscd, "_cluster_buffer_to_prototypes_hierarchical"):
                        try:
                            success = dscd._cluster_buffer_to_prototypes_hierarchical(token_type)
                        except Exception as e:
                            if _VERBOSE_LOGGING:
                                print(f"  [WARN] Clustering call raised for token '{token_type}': {type(e).__name__}: {str(e)[:200]}")
                            success = False
                    else:
                        if _VERBOSE_LOGGING:
                            print("  [WARN] DSCD instance has no _cluster_buffer_to_prototypes_hierarchical method; skipping clustering.")
                        success = False

                    if success:
                        clustered_count += 1
                    else:
                        failed_count += 1

                    if (idx + 1) % 50 == 0:
                        elapsed = time.time() - start_time
                        print(f"  Progress: {idx + 1}/{len(clusterable_tokens)} tokens processed "
                              f"({clustered_count} successful, {failed_count} failed) "
                              f"[{elapsed:.1f}s elapsed]")

                except Exception as e:
                    failed_count += 1
                    if failed_count <= 10:
                        token_str = str(token_type)[:40]
                        print(f"  [WARN] Clustering failed for token '{token_str}': {type(e).__name__}: {str(e)[:200]}")
                    if _VERBOSE_LOGGING:
                        traceback.print_exc()
                    continue

            # Final stats: defensive access to prototype_stores
            prototype_stores = getattr(dscd, "prototype_stores", {}) or {}
            try:
                total_prototypes = 0
                for store in prototype_stores.values():
                    try:
                        if hasattr(store, "size") and callable(store.size):
                            total_prototypes += int(store.size())
                        elif hasattr(store, "__len__"):
                            total_prototypes += int(len(store))
                        else:
                            total_prototypes += int(getattr(store, "n_prototypes", 0))
                    except Exception:
                        try:
                            total_prototypes += int(getattr(store, "n_prototypes", 0) or 0)
                        except Exception:
                            pass
            except Exception:
                total_prototypes = 0

            try:
                multi_sense_words = sum(1 for store in prototype_stores.values() if ((store.size() if hasattr(store, "size") and callable(store.size) else (len(store) if hasattr(store, "__len__") else 0)) >= 2))
            except Exception:
                multi_sense_words = 0

            elapsed_total = time.time() - start_time

            print("=" * 80)
            print("‚úì DISCOVERY PHASE COMPLETE")
            print("=" * 80)
            print(f"  ‚Ä¢ Tokens processed: {len(clusterable_tokens)}")
            print(f"  ‚Ä¢ Successfully clustered: {clustered_count}")
            print(f"  ‚Ä¢ Failed: {failed_count}")
            print(f"  ‚Ä¢ Total prototypes created: {total_prototypes}")
            print(f"  ‚Ä¢ Multi-sense words (‚â•2 prototypes): {multi_sense_words}")
            print(f"  ‚Ä¢ Time elapsed: {elapsed_total:.2f}s ({elapsed_total/60:.2f} min)")
            print("=" * 80)

            # Homograph verification (normalized matching)
            print("\n[DISCOVERY] ‚úÖ Verifying homograph words were clustered:")
            print("-" * 80)
            homographs_found = 0
            homographs_missing = 0

            # Build a normalized map from proto store keys -> (orig_key, store)
            proto_map = {}
            for token_key, store in prototype_stores.items():
                try:
                    nk = _norm_clean_token(token_key)
                except Exception:
                    nk = str(token_key)
                if nk not in proto_map:
                    proto_map[nk] = (token_key, store)

            for homograph in (list(_HOMOGRAPH_WATCHLIST_BN) if _HOMOGRAPH_WATCHLIST_BN else []):
                matched_store = None
                matched_key = None
                nh = _norm_clean_token(homograph)
                if nh and nh in proto_map:
                    matched_key, matched_store = proto_map[nh]
                else:
                    for nk, (orig_k, store) in proto_map.items():
                        try:
                            if _token_matches_homograph(orig_k, homograph):
                                matched_key, matched_store = orig_k, store
                                break
                        except Exception:
                            continue

                try:
                    store_size = 0
                    if matched_store is not None:
                        if hasattr(matched_store, "size") and callable(matched_store.size):
                            store_size = int(matched_store.size())
                        elif hasattr(matched_store, "__len__"):
                            store_size = int(len(matched_store))
                        else:
                            store_size = int(getattr(matched_store, "n_prototypes", 0) or 0)
                    if matched_store is not None and store_size >= 2:
                        counts = getattr(matched_store, "counts", None)
                        print(f"  ‚úì '{homograph}' ‚Üí {store_size} prototypes (key='{matched_key}') counts={counts}")
                        homographs_found += 1
                    else:
                        print(f"  ‚úó WARNING: '{homograph}' has NO multi-sense prototypes")
                        print(f"            This word will NOT be disambiguated in inference!")
                        homographs_missing += 1
                except Exception:
                    print(f"  ‚úó WARNING: '{homograph}' verification encountered an error")
                    homographs_missing += 1

            print("-" * 80)
            print(f"Homograph verification: {homographs_found}/{len(list(_HOMOGRAPH_WATCHLIST_BN))} detected")

            if homographs_missing > 0:
                print(f"\n‚ö†Ô∏è WARNING: {homographs_missing} known homographs were NOT properly clustered!")
                print("Possible causes:")
                print("  1. Not enough training samples containing these words")
                print("  2. Words were filtered out by should_track_token() (Cell 3)")
                print("  3. Buffer/cluster thresholds too strict")
                print("  4. Clustering backend unavailable or failed (SciPy/sklearn)")
            else:
                print("\n‚úÖ All homographs successfully clustered! Disambiguation ready.")

            # Clear buffers only if prototypes were actually created
            if total_prototypes > 0:
                if _VERBOSE_LOGGING:
                    print("[DISCOVERY] Clearing DSCD buffers to save memory (prototypes present).")
                try:
                    if hasattr(dscd, "buffers") and hasattr(dscd.buffers, "clear"):
                        dscd.buffers.clear()
                    else:
                        dscd.buffers = {}
                except Exception:
                    try:
                        dscd.buffers = {}
                    except Exception:
                        pass
                _safe_clear_gpu_caches()
            else:
                if _VERBOSE_LOGGING:
                    print("[DISCOVERY] Not clearing DSCD buffers (no prototypes created) to preserve data for debugging/warmup.")

    except Exception as e:
        print(f"[DISCOVERY] CRITICAL ERROR: Discovery phase failed!")
        print(f"  Error type: {type(e).__name__}")
        print(f"  Error message: {str(e)[:300]}")
        if _VERBOSE_LOGGING:
            print("\n[DISCOVERY] Full traceback:")
            traceback.print_exc()
        print("\n[DISCOVERY] WARNING: DSCD homograph detection will NOT work!")
        print("  The model will function but only at baseline M2M100 quality.")

    # Optional: Run additional warmup inference (if dscd_discovery_warmup exists)
    if "dscd_discovery_warmup" in globals():
        try:
            print("\n[CELL10] Running additional DSCD inference warmup...")
            warmup_samples = min(1000, int(_DSCD_WARMUP_SAMPLES))
            dscd_discovery_warmup(trained_model, tokenizer, num_sents=warmup_samples, max_len=_MAX_LENGTH)
            print(f"[CELL10] Inference warmup complete")
        except Exception as e:
            print(f"[CELL10] Inference warmup failed: {str(e)[:200]}")
            if _VERBOSE_LOGGING:
                traceback.print_exc()

    # -----------------------
    # Step 8: Post-training evaluation
    # -----------------------
    print("\n[CELL10] Step 8: Evaluation")
    _safe_clear_gpu_caches()
    if "comprehensive_post_training_testing" in globals():
        try:
            eval_results = comprehensive_post_training_testing(trained_model, tokenizer)
        except Exception as e:
            print(f"[CELL10] Evaluation failed: {type(e).__name__}: {str(e)[:200]}")
            if _VERBOSE_LOGGING:
                traceback.print_exc()
            eval_results = {}
    else:
        if _VERBOSE_LOGGING:
            print("[CELL10] comprehensive_post_training_testing not found")
        eval_results = {}

    # -----------------------
    # Save model (core state dict)
    # -----------------------
    print("[CELL10] Saving model...")
    try:
        core_for_save = trained_model.module if hasattr(trained_model, "module") else trained_model
        save_path = "tatn_kaggle_final.pt"
        # Ensure directory exists
        sdir = os.path.dirname(save_path)
        if sdir and not os.path.exists(sdir):
            try:
                os.makedirs(sdir, exist_ok=True)
            except Exception:
                pass
        torch.save(core_for_save.state_dict(), save_path)
        print(f"[CELL10] Model state saved to {save_path}")
    except Exception as e:
        print(f"[CELL10] Save failed: {type(e).__name__}: {str(e)[:200]}")
        if _VERBOSE_LOGGING:
            traceback.print_exc()

    # Final report
    print("\n[CELL10] Final Report Summary:")
    if eval_results:
        try:
            sr = eval_results.get('success_rate_pct', eval_results.get('success_rate', 0.0))
            print(f"  Success Rate: {float(sr):.2f}%")
        except Exception:
            print(f"  Success Rate: {eval_results.get('success_rate_pct', eval_results.get('success_rate', 'N/A'))}")
        print(f"  DSCD prototype stats: {eval_results.get('dscd_stats', {})}")
    else:
        print("  No evaluation metrics available")

    # Clear caches and return
    _safe_clear_gpu_caches()
    return trained_model, tokenizer

# When this cell is executed, the user can call main_pipeline() to execute.
print("‚úÖ Cell 10 (patched): Discovery phase + homograph verification ready. Call main_pipeline() to execute.")

‚úÖ Cell 10 (patched): Discovery phase + homograph verification ready. Call main_pipeline() to execute.


In [15]:
# ==============================================================================
# CELL 11 (patched): MAIN EXECUTION WRAPPER (MULTI-GPU OPTIMIZED - DEBUGGED)
# ==============================================================================
# - Robust globals via globals().get(...) with sensible defaults
# - Safer invocation of main_pipeline() and tolerant unpacking of return values
# - Controlled verbose tracebacks via VERBOSE_LOGGING flag
# - Guarded quick inference check that tolerates many return shapes and call signatures
# - Improved ceil division helper that accepts numeric-like inputs
# ==============================================================================
from datetime import datetime, timezone
import os
import traceback
import math
import sys
import torch
from typing import Any

def _safe_get(name: str, default: Any):
    try:
        return globals().get(name, default)
    except Exception:
        return default

def _safe_div_ceil(a: Any, b: Any) -> int:
    """Return ceil(a/b) for numeric-like inputs, otherwise 0."""
    try:
        a_f = float(a)
        b_f = float(b)
        if b_f == 0:
            return 0
        return int(math.ceil(a_f / b_f))
    except Exception:
        return 0

def _is_model_like(obj: Any) -> bool:
    """Heuristic: object that looks like a model (has forward or predict)."""
    try:
        return hasattr(obj, "forward") or hasattr(obj, "generate") or hasattr(obj, "state_dict") or hasattr(obj, "dscd")
    except Exception:
        return False

def _is_tokenizer_like(obj: Any) -> bool:
    """Heuristic: object that looks like a tokenizer (has decode or convert_ids_to_tokens)."""
    try:
        return hasattr(obj, "decode") or hasattr(obj, "convert_ids_to_tokens") or callable(getattr(obj, "__call__", None))
    except Exception:
        return False

# Entry point guard for script invocation
if __name__ == "__main__":
    print("=" * 80)
    print("MEMORY-OPTIMIZED TATN FOR KAGGLE T4√ó2 (Cell 11 - RUNNER)")
    print("=" * 80)

    # Read configuration safely from globals (Cell 0 may not have run)
    _NUM_SAMPLES = _safe_get("NUM_SAMPLES", 30000)
    _EPOCHS = _safe_get("EPOCHS", 2)
    _BATCH_SIZE = _safe_get("BATCH_SIZE", 4)
    _ACCUMULATION_STEPS = _safe_get("ACCUMULATION_STEPS", 16)
    _DEVICE = _safe_get("DEVICE", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
    _ENABLE_ASBN_TRAINING = _safe_get("ENABLE_ASBN_TRAINING", True)
    _ENABLE_TRG_INFERENCE = _safe_get("ENABLE_TRG_INFERENCE", True)
    _PERIODIC_DISCOVERY_FREQUENCY = _safe_get("PERIODIC_DISCOVERY_FREQUENCY", 5000)
    _VERBOSE_LOGGING = _safe_get("VERBOSE_LOGGING", False)
    _USE_MULTI_GPU = _safe_get("USE_MULTI_GPU", torch.cuda.is_available() and torch.cuda.device_count() > 1)
    _NUM_GPUS = _safe_get("NUM_GPUS", torch.cuda.device_count() if torch.cuda.is_available() else 0)

    # user and timestamp
    user_login = os.getenv("KAGGLE_USERNAME") or os.getenv("USER") or _safe_get("CURRENT_USER", "manas0003")
    now_utc = datetime.now(timezone.utc).strftime("%Y-%m-%d %H:%M:%S UTC")
    print(f"User: {user_login}")
    print(f"Started: {now_utc}")

    # Configuration summary
    print("\nConfiguration:")
    print(f"   ‚Ä¢ Samples: {_NUM_SAMPLES}")
    print(f"   ‚Ä¢ Epochs: {_EPOCHS}")
    print(f"   ‚Ä¢ Batch Size: {_BATCH_SIZE}")
    print(f"   ‚Ä¢ Accumulation: {_ACCUMULATION_STEPS}")
    print(f"   ‚Ä¢ Device: {_DEVICE}")
    print(f"   ‚Ä¢ Multi-GPU: {'ENABLED' if _USE_MULTI_GPU else 'DISABLED'} ({_NUM_GPUS} GPU(s))")
    if _USE_MULTI_GPU and _NUM_GPUS > 0:
        per_gpu = _safe_div_ceil(int(_BATCH_SIZE), int(max(1, _NUM_GPUS)))
        print(f"   ‚Ä¢ Batch per GPU: {per_gpu}")
    print(f"   ‚Ä¢ ASBN Training: {'Enabled' if _ENABLE_ASBN_TRAINING else 'Disabled'}")
    print(f"   ‚Ä¢ TRG Inference: {'Enabled' if _ENABLE_TRG_INFERENCE else 'Disabled'}")
    print(f"   ‚Ä¢ Periodic Discovery: Every {_PERIODIC_DISCOVERY_FREQUENCY} steps")
    print("=" * 80)

    trained_model, tokenizer = None, None

    # Ensure main_pipeline exists and is callable
    mp = globals().get("main_pipeline", None)
    if mp is None or not callable(mp):
        print("\nERROR: main_pipeline not found or not callable - please run Cell 10 before executing this cell.")
    else:
        try:
            print("\nStarting full pipeline (this may take a while)...")
            # call main_pipeline; accept several possible return patterns
            ret = mp()

            # Normalize and unpack the return robustly
            if isinstance(ret, tuple):
                # common case: (model, tokenizer) or (model,)
                if len(ret) >= 2:
                    trained_model, tokenizer = ret[0], ret[1]
                elif len(ret) == 1:
                    trained_model = ret[0]
                    # attempt to find tokenizer in ret[0] attributes or globals
                    if hasattr(trained_model, "tokenizer") and _is_tokenizer_like(getattr(trained_model, "tokenizer")):
                        tokenizer = trained_model.tokenizer
                    else:
                        tokenizer = globals().get("tokenizer", None)
            elif isinstance(ret, dict):
                # returned a dict of artifacts
                trained_model = ret.get("model") or ret.get("trained_model") or ret.get("core_model") or ret.get("tatn")
                tokenizer = ret.get("tokenizer") or ret.get("tok") or globals().get("tokenizer", None)
                # if values are not model/tokenizer, try heuristics
                if trained_model is None:
                    for v in ret.values():
                        if _is_model_like(v):
                            trained_model = v
                            break
                if tokenizer is None:
                    for v in ret.values():
                        if _is_tokenizer_like(v):
                            tokenizer = v
                            break
            else:
                # single-object return - try to infer
                if _is_model_like(ret):
                    trained_model = ret
                    tokenizer = globals().get("tokenizer", None)
                elif _is_tokenizer_like(ret):
                    tokenizer = ret
                    trained_model = globals().get("trained_model", None) or globals().get("model", None)
                else:
                    # fallback: look into globals for likely objects if pipeline stored them
                    trained_model = globals().get("trained_model", None) or globals().get("model", None)
                    tokenizer = globals().get("tokenizer", None)

        except KeyboardInterrupt:
            print("\nExecution interrupted by user (KeyboardInterrupt).")
        except Exception as e:
            msg = str(e).lower()
            if isinstance(e, RuntimeError) and (
                "no usable tokenizer class available" in msg
                or "failed to instantiate tokenizer" in msg
                or "sentencepiece" in msg
                or "tokenizers" in msg
            ):
                print(f"\nPipeline execution failed: {type(e).__name__}: {str(e)[:400]}")
                print("\nThis error indicates the tokenizer could not be instantiated. Common causes and fixes:")
                print("  ‚Ä¢ Missing or incompatible 'transformers' package.")
                print("  ‚Ä¢ Missing optional tokenizer dependencies (sentencepiece, tokenizers, sacremoses).")
                print("\nSuggested actions (pick one):")
                print("  1) Install the recommended packages (in a notebook cell or terminal):")
                print("       !pip install transformers==4.30.2 sentencepiece tokenizers sacremoses --quiet")
                print("     Then RESTART the kernel and re-run Cells 0‚Üí11 in order.")
                print("")
                print("  2) If you are offline but have a cached tokenizer folder, set local_files_only=True in the tokenizer loader or")
                print("     provide MODEL_LOCAL_TOKENIZER_PATH in your config and re-run.")
                print("")
                print("  3) If you want to continue debugging without real tokenization, ensure Cell 10's _safe_tokenizer_from_pretrained")
                print("     returns the whitespace fallback (it will allow the pipeline to continue but translations will be incorrect).")
                if _VERBOSE_LOGGING:
                    print("\nFull traceback (VERBOSE):")
                    traceback.print_exc()
                else:
                    print("\nSet VERBOSE_LOGGING = True in Cell 0 to see the full traceback.")
            else:
                print(f"\nPipeline execution failed: {type(e).__name__}: {str(e)[:400]}")
                if _VERBOSE_LOGGING:
                    traceback.print_exc()
                else:
                    print("Set VERBOSE_LOGGING = True in Cell 0 to see full traceback.")

    # Post-run summary and quick inference check
    if trained_model is not None and tokenizer is not None:
        print("\n" + "=" * 80)
        print("SYSTEM INITIALIZATION SUCCEEDED")
        print("=" * 80)
        print("\nCapabilities:")
        print("   ‚Ä¢ Bengali ‚Üí English translation")
        print("   ‚Ä¢ Automatic homograph disambiguation (DSCD + TRG)")
        print("   ‚Ä¢ Dynamic prototype discovery (hierarchical clustering)")
        if _USE_MULTI_GPU:
            print(f"   ‚Ä¢ Multi-GPU acceleration ({_NUM_GPUS} GPUs)")
        print("=" * 80)

        # Quick inference validation (best-effort; guarded)
        print("\nQuick Inference Validation (single sample):")
        try:
            tw = globals().get("translate_with_explanations", None)
            if callable(tw):
                sample = "‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§"
                print(f"  Testing sentence: {sample}")

                # Try several plausible call signatures until one works
                res = None
                call_attempts = []
                # define candidate argument permutations
                arg_permutations = [
                    (trained_model, tokenizer, sample),
                    (trained_model, tokenizer, [sample]),
                    (trained_model, sample, tokenizer),
                    (tokenizer, trained_model, sample),
                    (sample, trained_model, tokenizer),
                    (sample,),  # some wrappers accept only sentence and read model/tokenizer from globals
                    (trained_model, sample),
                    (tokenizer, sample),
                ]

                for args in arg_permutations:
                    try:
                        candidate = tw(*args)
                        res = candidate
                        call_attempts.append(("ok", args))
                        break
                    except TypeError as te:
                        call_attempts.append(("type_error", args, str(te)))
                        continue
                    except Exception as e:
                        # record and continue; verbose if requested
                        call_attempts.append(("error", args, f"{type(e).__name__}: {e}"))
                        if _VERBOSE_LOGGING:
                            traceback.print_exc()
                        continue

                if res is None:
                    # try calling with named args as a last resort
                    try:
                        res = tw(model=trained_model, tokenizer=tokenizer, input_sentence=sample)
                    except Exception:
                        res = None

                # Report results defensively
                if isinstance(res, dict):
                    print(f"  Translation: {res.get('translation', 'N/A')}")
                    print(f"  Ambiguous Words Detected: {res.get('ambiguous_words_detected', 0)}")
                    exs = res.get('explanations', []) or []
                    if exs:
                        e0 = exs[0]
                        print("  Example explanation (first):")
                        print(f"    Word: {e0.get('ambiguous_word', e0.get('token', 'N/A'))}")
                        try:
                            u = float(e0.get('uncertainty', 0.0))
                            s = float(e0.get('span', 0.0))
                            print(f"    Uncertainty: {u:.3f}")
                            print(f"    Span: {s:.3f}")
                        except Exception:
                            print(f"    Uncertainty/Span: {e0.get('uncertainty','N/A')} / {e0.get('span','N/A')}")
                    else:
                        print("  No explanations returned (high-confidence translation)")
                elif res is None:
                    print("  Quick inference returned None (check translate_with_explanations signature or pipeline logs)")
                    if _VERBOSE_LOGGING:
                        print("  Call attempts summary:")
                        for rec in call_attempts:
                            print("   ", rec)
                else:
                    # non-dict result: print repr for debugging
                    print("  translate_with_explanations returned non-dict result; here's its repr:")
                    print("  ", repr(res)[:1000])
            else:
                print("  translate_with_explanations not available - ensure Cell 8 is run")
        except Exception as e:
            print(f"  Quick inference failed: {type(e).__name__}: {str(e)[:200]}")
            if _VERBOSE_LOGGING:
                traceback.print_exc()
    else:
        print("\n" + "=" * 80)
        print("SYSTEM INITIALIZATION FAILED")
        print("=" * 80)
        print("Troubleshooting tips:")
        print("  1) Run Cells 0‚Üí10 in order to ensure dependencies are loaded.")
        print("  2) Set VERBOSE_LOGGING = True in Cell 0 to see detailed tracebacks.")
        print("  3) Ensure GPUs are available and CUDA visible to the process.")
        print("  4) If warmup/prototype building missed some words, run dscd_discovery_warmup(...) manually.")
        print("")
        print("If the failure was tokenizer-related, run the following and then RESTART the kernel:")
        print("  pip install transformers==4.30.2 sentencepiece tokenizers sacremoses")
        print("=" * 80)

    print("\nCELL 11: Execution wrapper finished.")

MEMORY-OPTIMIZED TATN FOR KAGGLE T4√ó2 (Cell 11 - RUNNER)
User: manas0003
Started: 2025-11-22 14:56:08 UTC

Configuration:
   ‚Ä¢ Samples: 50000
   ‚Ä¢ Epochs: 2
   ‚Ä¢ Batch Size: 100
   ‚Ä¢ Accumulation: 16
   ‚Ä¢ Device: cuda
   ‚Ä¢ Multi-GPU: ENABLED (2 GPU(s))
   ‚Ä¢ Batch per GPU: 50
   ‚Ä¢ ASBN Training: Enabled
   ‚Ä¢ TRG Inference: Enabled
   ‚Ä¢ Periodic Discovery: Every 100 steps

Starting full pipeline (this may take a while)...
CELL10: TATN MAIN PIPELINE (patched) - Discovery + Homograph verification
[CELL10] Initializing environment...
[CELL10] GPUs available: 2
  - GPU 0: Tesla T4 (14.7 GB)
  - GPU 1: Tesla T4 (14.7 GB)
[CELL10] Multi-GPU detected
[CELL10] Loading tokenizer...


tokenizer_config.json:   0%|          | 0.00/298 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/908 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

special_tokens_map.json: 0.00B [00:00, ?B/s]

[CELL10] Tokenizer loaded (vocab size approx 128104)
[CELL10] Loading/preprocessing up to 50000 samples...
[CELL2] Loading up to 50000 samples from local CSV: /kaggle/input/homo-bn-dataset/bn_homograph_complete_dataset.csv
[CELL2] Reading CSV file...
[CELL2] Processing 50000 rows from CSV...


Loading dataset: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 50000/50000 [00:02<00:00, 23195.56it/s]


[CELL2] Loaded 50000 pairs from CSV, skipped 0 rows
[CELL2] Dataset initialized: 50000 valid pairs, 0 invalid pairs filtered
[CELL2] DataLoader created: total_batch=100, per_gpu=50, workers=2
[CELL10] Dataset: 50000 examples, 500 batches (batch_size=100)
[CELL10] Initializing model...


pytorch_model.bin:   0%|          | 0.00/1.94G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/233 [00:00<?, ?B/s]

Using cls_token, but it is not set yet.
Using cls_token, but it is not set yet.
Using mask_token, but it is not set yet.
Using mask_token, but it is not set yet.


[CELL10] Wrapping model in DataParallel on devices [0, 1]
[CELL10] Resized token embeddings: 128112 -> 128104
[CELL10] Preparing optimizers...
[CELL10] ASBN critic optimizer created (params: 12)
[CELL10] Starting training phase...
[TRAIN] Starting training: epochs=2, batch=100, accum_steps=16
[TRAIN] Validation: enabled
[TRAIN] DP enabled: True, GPUs: 2, Device: cuda


Epoch 1/2:  40%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ                              | 199/500 [44:22<1:07:42, 13.50s/it, fwd_loss=2.0640 bwd_loss=0.129002 rate=100.0% proc=199 skip=0 clusters=12710]

[TRAIN-DEBUG] GPU mem (GB):
  GPU 0: alloc=7.73 resv=12.90
  GPU 1: alloc=1.30 resv=8.53
[TRAIN-DEBUG] step=200 loss=2.2224 opt_updates=12 clusters=12741

[CLUSTER] Top 5 clusters (by sample count):
------------------------------------------------------------------------------------------
Rank  Token             Count       Protos    Œº (mean)       œÑ (dev)     
------------------------------------------------------------------------------------------
1     ‡¶∂‡¶ï‡ßç‡¶§              20          4         23.279076      4.302098    
2     ‡¶ï‡¶†‡ßã‡¶∞              20          4         23.584515      3.812131    
3     ‡¶∏‡ßã‡¶®‡¶æ              19          3         26.127859      4.105221    
4     ‡¶ú‡ßç‡¶û‡¶æ‡¶®             19          3         23.873760      4.940936    
5     ‡¶∏‡ßÅ‡¶∞               19          3         24.755248      4.159158    
------------------------------------------------------------------------------------------
Total clusters: 12741 | Total samples in c

Epoch 1/2:  41%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñã                             | 207/500 [46:10<1:05:35, 13.43s/it, fwd_loss=2.0851 bwd_loss=0.130319 rate=100.0% proc=207 skip=0 clusters=12942]


[VALIDATION] Quick validation at step 208


2025-11-22 15:42:59.652126: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1763826179.867045      48 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1763826179.921018      48 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


1. ‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§ -> ..
2. ‡¶ï‡¶æ‡¶≤ ‡¶Ü‡¶Æ‡¶ø ‡¶¨‡¶á ‡¶ï‡¶ø‡¶®‡¶¨‡•§ -> to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to to
3. ‡¶™‡¶æ‡¶§‡¶æ ‡¶ù‡¶∞‡ßá ‡¶™‡¶°‡¶º‡ßá‡¶õ‡ßá‡•§ -> ..
4. ‡¶Ü‡¶Æ‡¶ø ‡¶≠‡¶æ‡¶≤‡ßã ‡¶Ü‡¶õ‡¶ø‡•§ -> ..
5. ‡¶Ü‡¶ú ‡¶Ü‡¶¨‡¶π‡¶æ‡¶ì‡¶Ø‡¶º‡¶æ ‡¶≠‡¶æ‡¶≤‡ßã‡•§ -> ..


Epoch 1/2:  42%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñä                             | 208/500 [46:39<1:28:39, 18.22s/it, fwd_loss=1.8908 bwd_loss=0.118173 rate=100.0% proc=208 skip=0 clusters=12981]



Epoch 1/2:  80%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ          | 399/500 [1:29:33<22:48, 13.55s/it, fwd_loss=2.0988 bwd_loss=0.131177 rate=100.0% proc=399 skip=0 clusters=17515]


[VALIDATION] Quick validation at step 400
1. ‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§ -> i closed the call.
2. ‡¶ï‡¶æ‡¶≤ ‡¶Ü‡¶Æ‡¶ø ‡¶¨‡¶á ‡¶ï‡¶ø‡¶®‡¶¨‡•§ -> i will buy this tomorrow.
3. ‡¶™‡¶æ‡¶§‡¶æ ‡¶ù‡¶∞‡ßá ‡¶™‡¶°‡¶º‡ßá‡¶õ‡ßá‡•§ -> the page fell.
4. ‡¶Ü‡¶Æ‡¶ø ‡¶≠‡¶æ‡¶≤‡ßã ‡¶Ü‡¶õ‡¶ø‡•§ -> i am good.
5. ‡¶Ü‡¶ú ‡¶Ü‡¶¨‡¶π‡¶æ‡¶ì‡¶Ø‡¶º‡¶æ ‡¶≠‡¶æ‡¶≤‡ßã‡•§ -> today is good.
[TRAIN-DEBUG] GPU mem (GB):
  GPU 0: alloc=5.93 resv=7.14
  GPU 1: alloc=1.25 resv=3.09
[TRAIN-DEBUG] step=400 loss=2.3060 opt_updates=25 clusters=17543

[CLUSTER] Top 5 clusters (by sample count):
------------------------------------------------------------------------------------------
Rank  Token             Count       Protos    Œº (mean)       œÑ (dev)     
------------------------------------------------------------------------------------------
1     ‡¶ó‡¶§                20          4         20.176730      10.448194   
2     ‡¶™‡ßç‡¶∞‡¶∂‡¶∏‡ßç‡¶§           20          4         23.658534      6.460830   

Epoch 1/2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 500/500 [1:52:11<00:00, 13.46s/it, fwd_loss=1.5028 bwd_loss=0.093923 rate=100.0% proc=500 skip=0 clusters=19278]



Epoch 1 summary:
  duration (min): 112.19
  optimizer updates: 32
  batches processed: 500 (processed=500, skipped=0)
  success rate (updates/expected): 103.2%
  clustered token types: 19278
  avg forward loss: 2.808499
[CHECKPOINT] Saved tatn_e1_s500_20251122_164843.pt avg_loss=2.808499


Epoch 2/2:  20%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà                                         | 99/500 [22:27<1:29:59, 13.46s/it, fwd_loss=1.4471 bwd_loss=0.090441 rate=102.7% proc=599 skip=0 clusters=19975]

[TRAIN-DEBUG] GPU mem (GB):
  GPU 0: alloc=7.74 resv=13.75
  GPU 1: alloc=1.33 resv=8.44
[TRAIN-DEBUG] step=600 loss=1.2463 opt_updates=38 clusters=19978

[CLUSTER] Top 5 clusters (by sample count):
------------------------------------------------------------------------------------------
Rank  Token             Count       Protos    Œº (mean)       œÑ (dev)     
------------------------------------------------------------------------------------------
1     ‡¶ß‡¶æ‡¶®               20          4         21.095281      8.081010    
2     ‡¶ï‡ßå                20          4         21.031949      5.118068    
3     ‡¶¶‡ßÇ‡¶∞               20          4         25.049261      3.580823    
4     ‡¶®‡¶æ‡¶ó‡¶∞‡¶ø‡¶ï            20          4         22.141394      4.999902    
5     ‡ßÉ‡¶§‡¶ø               20          4         17.211136      3.584237    
------------------------------------------------------------------------------------------
Total clusters: 19978 | Total samples in cluster

Epoch 2/2:  22%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà                                       | 111/500 [25:09<1:26:15, 13.30s/it, fwd_loss=1.4946 bwd_loss=0.093412 rate=100.0% proc=611 skip=0 clusters=20054]


[VALIDATION] Quick validation at step 612
1. ‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§ -> i closed the call.
2. ‡¶ï‡¶æ‡¶≤ ‡¶Ü‡¶Æ‡¶ø ‡¶¨‡¶á ‡¶ï‡¶ø‡¶®‡¶¨‡•§ -> i will buy it tomorrow.
3. ‡¶™‡¶æ‡¶§‡¶æ ‡¶ù‡¶∞‡ßá ‡¶™‡¶°‡¶º‡ßá‡¶õ‡ßá‡•§ -> the page has fallen.
4. ‡¶Ü‡¶Æ‡¶ø ‡¶≠‡¶æ‡¶≤‡ßã ‡¶Ü‡¶õ‡¶ø‡•§ -> i am well.


Epoch 2/2:  22%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñè                                      | 112/500 [25:24<1:28:48, 13.73s/it, fwd_loss=1.5706 bwd_loss=0.098166 rate=102.6% proc=612 skip=0 clusters=20059]

5. ‡¶Ü‡¶ú ‡¶Ü‡¶¨‡¶π‡¶æ‡¶ì‡¶Ø‡¶º‡¶æ ‡¶≠‡¶æ‡¶≤‡ßã‡•§ -> the weather is good today.


Epoch 2/2:  60%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ                    | 299/500 [1:07:34<44:51, 13.39s/it, fwd_loss=1.1700 bwd_loss=0.073125 rate=102.0% proc=799 skip=0 clusters=21332]

[TRAIN-DEBUG] GPU mem (GB):
  GPU 0: alloc=7.74 resv=12.90
  GPU 1: alloc=1.28 resv=8.46
[TRAIN-DEBUG] step=800 loss=1.3455 opt_updates=50 clusters=21342

[CLUSTER] Top 5 clusters (by sample count):
------------------------------------------------------------------------------------------
Rank  Token             Count       Protos    Œº (mean)       œÑ (dev)     
------------------------------------------------------------------------------------------
1     ‡¶∂‡¶ø‡¶ï‡ßç‡¶∑             20          4         21.354459      6.465549    
2     ‡¶®‡¶ø‡¶∑‡ßç‡¶†             20          4         23.850510      5.149038    
3     ‡¶¶‡¶æ‡¶Ø‡¶º‡¶ø‡¶§‡ßç‡¶¨          19          3         22.813807      5.406342    
4     ‡¶®‡¶ø‡¶Ø‡¶º‡ßá             19          3         23.675864      5.383530    
5     ‡¶ñ‡ßá                19          3         21.088390      4.672610    
------------------------------------------------------------------------------------------
Total clusters: 21342 | Total sa

Epoch 2/2:  61%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñé                   | 303/500 [1:08:27<43:41, 13.31s/it, fwd_loss=1.3216 bwd_loss=0.082600 rate=100.0% proc=803 skip=0 clusters=21362]


[VALIDATION] Quick validation at step 804
1. ‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§ -> i stopped the call.
2. ‡¶ï‡¶æ‡¶≤ ‡¶Ü‡¶Æ‡¶ø ‡¶¨‡¶á ‡¶ï‡¶ø‡¶®‡¶¨‡•§ -> i will buy it tomorrow.
3. ‡¶™‡¶æ‡¶§‡¶æ ‡¶ù‡¶∞‡ßá ‡¶™‡¶°‡¶º‡ßá‡¶õ‡ßá‡•§ -> the page has fallen.
4. ‡¶Ü‡¶Æ‡¶ø ‡¶≠‡¶æ‡¶≤‡ßã ‡¶Ü‡¶õ‡¶ø‡•§ -> i am well.


Epoch 2/2:  61%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñç                   | 304/500 [1:08:42<44:48, 13.72s/it, fwd_loss=1.2416 bwd_loss=0.077603 rate=102.0% proc=804 skip=0 clusters=21363]

5. ‡¶Ü‡¶ú ‡¶Ü‡¶¨‡¶π‡¶æ‡¶ì‡¶Ø‡¶º‡¶æ ‡¶≠‡¶æ‡¶≤‡ßã‡•§ -> the weather is good today.


Epoch 2/2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñâ| 499/500 [1:52:34<00:13, 13.52s/it, fwd_loss=1.2068 bwd_loss=0.075424 rate=101.6% proc=999 skip=0 clusters=22648]

[TRAIN-DEBUG] GPU mem (GB):
  GPU 0: alloc=7.74 resv=13.88
  GPU 1: alloc=1.28 resv=8.34
[TRAIN-DEBUG] step=1000 loss=1.1186 opt_updates=63 clusters=22652

[CLUSTER] Top 5 clusters (by sample count):
------------------------------------------------------------------------------------------
Rank  Token             Count       Protos    Œº (mean)       œÑ (dev)     
------------------------------------------------------------------------------------------
1     ‡¶∏‡¶æ‡¶ï‡ßç‡¶∑‡¶æ‡ßé           20          4         23.171738      4.528618    
2     ‡¶§‡¶§‡ßç‡¶§‡ßç‡¶¨            20          4         21.703555      4.709562    
3     ‡¶∏‡ßÅ‡¶ñ               20          4         21.752443      6.131968    
4     ‡¶ú‡¶∞‡ßÅ‡¶∞              20          4         20.615450      5.506017    
5     ‡¶¨‡¶õ‡¶∞               20          4         22.549204      5.459606    
------------------------------------------------------------------------------------------
Total clusters: 22652 | Total sampl

Epoch 2/2: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 500/500 [1:52:49<00:00, 13.54s/it, fwd_loss=1.1186 bwd_loss=0.069913 rate=101.6% proc=1000 skip=0 clusters=22652]



Epoch 2 summary:
  duration (min): 112.83
  optimizer updates: 64
  batches processed: 1000 (processed=1000, skipped=0)
  success rate (updates/expected): 103.2%
  clustered token types: 22652
  avg forward loss: 2.060247
[CHECKPOINT] Saved tatn_e2_s1000_20251122_184143.pt avg_loss=1.311995

[TRAIN] Training completed
[TRAIN] Success Rate (updates/expected): 103.2%
[TRAIN] Batches processed=1000 skipped=0
[TRAIN] Clustered Token Types: 22652

DISCOVERY PHASE: Clustering DSCD buffers to create prototypes...
[DISCOVERY] Found 500 tokens meeting threshold for clustering (threshold=20)
  Progress: 50/500 tokens processed (50 successful, 0 failed) [0.0s elapsed]
  Progress: 100/500 tokens processed (100 successful, 0 failed) [0.1s elapsed]
  Progress: 150/500 tokens processed (150 successful, 0 failed) [0.1s elapsed]
  Progress: 200/500 tokens processed (200 successful, 0 failed) [0.1s elapsed]
  Progress: 250/500 tokens processed (250 successful, 0 failed) [0.2s elapsed]
  Progress: 300/5

Loading dataset: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 1000/1000 [00:00<00:00, 20586.85it/s]


[CELL2] Loaded 1000 pairs from CSV, skipped 0 rows
[WARMUP] Prototype discovery: word_types=22655, total_protos=18685, multi_sense=6222
[WARMUP] Restored DSCD configuration
[CELL10] Inference warmup complete

[CELL10] Step 8: Evaluation

COMPREHENSIVE POST-TRAINING EVALUATION (Cell 9)

[EVAL] Running 5 tests...
--------------------------------------------------------------------------------

Test 1/5: ‡¶ï‡¶≤ = tap / call
Input: ‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§
Translation: i closed the call.
Ambiguous Words (real, counted): 0
No explanations produced (likely high-confidence translation)
Translation successful
------------------------------------------------------------

Test 2/5: ‡¶ï‡¶æ‡¶≤ = tomorrow / yesterday
Input: ‡¶ï‡¶æ‡¶≤ ‡¶Ü‡¶Æ‡¶ø ‡¶¨‡¶á ‡¶ï‡¶ø‡¶®‡¶¨‡•§
Translation: i will buy it tomorrow.
Ambiguous Words (real, counted): 0
No explanations produced (likely high-confidence translation)
Translation successful
-------------------------------------------------------

In [16]:
# ==============================================================================
# CELL 12 (fixed): EXTENDED INFERENCE TESTING & ROBUST CHECKPOINT LOADER
# - Robust load_state_dict handling across PyTorch versions (namedtuple or tuple)
# - Clearer error messages and tracebacks controlled by VERBOSE_LOGGING
# - Safe device mapping and embedding resize before state load
# - Optional DSCD warm-up when prototypes are empty
# - Defensive guards for missing globals / helpers
# ==============================================================================
import os
import time
import traceback
from typing import Tuple, Any, Dict, List, Optional, Union

import torch

# -------------------------
# Local fallbacks for globals (safe)
# -------------------------
try:
    _DEVICE = DEVICE  # may be a torch.device or str in the user's globals
    _USE_MULTI_GPU = USE_MULTI_GPU
    _NUM_GPUS = NUM_GPUS
    _VERBOSE_LOGGING = VERBOSE_LOGGING
except Exception:
    _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    _NUM_GPUS = torch.cuda.device_count() if torch.cuda.is_available() else 0
    _USE_MULTI_GPU = _NUM_GPUS > 1
    _VERBOSE_LOGGING = False
    print("[CELL12] Warning: using fallback device/settings")

# Determine real-ambiguity uncertainty threshold consistently (use TAU_LOW if present)
_REAL_AMB_UNCERTAINTY = float(globals().get("TAU_LOW", 0.4))

# Helpers -----------------------------------------------------------------------
def _safe_print(msg: str):
    try:
        print(msg)
    except Exception:
        pass


def _maybe_traceback(exc: Exception):
    if _VERBOSE_LOGGING:
        traceback.print_exc()
    else:
        _safe_print("   (set VERBOSE_LOGGING = True for full traceback)")


def _looks_like_state_dict(d: Any) -> bool:
    """
    Heuristic to detect whether `d` is a PyTorch state_dict-like mapping:
      - mapping and keys are strings that contain '.' (module.weight style)
      - or many values are tensors
    """
    try:
        if not isinstance(d, dict):
            return False
        if not d:
            return False
        # If many keys contain dots, likely state dict
        dot_keys = sum(1 for k in d.keys() if isinstance(k, str) and "." in k)
        if dot_keys >= max(1, len(d) // 5):
            return True
        # If many values are tensors -> also likely state dict
        sample_vals = list(d.values())[:min(20, len(d))]
        tensor_count = sum(1 for v in sample_vals if torch.is_tensor(v))
        if tensor_count >= max(1, len(sample_vals) // 3):
            return True
    except Exception:
        pass
    return False

# ------------------------------------------------------------------------------
# try_load_checkpoint: robust loader
# ------------------------------------------------------------------------------
def try_load_checkpoint(checkpoint_path: str, tokenizer: Any) -> Tuple[bool, Union[Any, str]]:
    """
    Try to load a checkpoint file into a freshly instantiated model.
    Returns (success, model_instance_or_error_message_or_exception).

    Robust behaviors:
      - Accepts common checkpoint layouts (dict with 'model_state_dict'|'state_dict', nested 'model':{'state_dict':...}, or direct state-dict)
      - Attempts embedding resize before loading when tokenizer differs
      - Retries with stripped 'module.' prefixes if needed
      - Tries key decoding (bytes -> str) and simple key normalization
      - Loads to CPU first then moves model to _DEVICE
      - Sets model.eval() and returns it on success
    """
    if not os.path.exists(checkpoint_path):
        return False, f"Checkpoint path not found: {checkpoint_path}"

    if "MemoryOptimizedTATNWithExplanations" not in globals():
        return False, "Model class MemoryOptimizedTATNWithExplanations not available in current session."

    _safe_print(f"[CELL12] Loading checkpoint from: {checkpoint_path}")
    try:
        ckpt = torch.load(checkpoint_path, map_location="cpu")
    except Exception as e:
        _safe_print(f"[CELL12] Failed to load checkpoint file: {type(e).__name__}: {str(e)[:200]}")
        _maybe_traceback(e)
        return False, e

    # Locate a plausible state dict inside ckpt
    state = None
    try:
        if isinstance(ckpt, dict):
            # 1) conventional top-level keys
            cand_keys = [
                "model_state_dict", "state_dict", "model", "model_state", "state", "net", "model_state_dicts",
                "module_state_dict", "module"
            ]
            for k in cand_keys:
                v = ckpt.get(k, None)
                if v is None:
                    continue
                if isinstance(v, dict) and _looks_like_state_dict(v):
                    state = v
                    break
                # nested mapping containing state_dict
                if isinstance(v, dict) and "state_dict" in v and isinstance(v["state_dict"], dict):
                    state = v["state_dict"]
                    break
            # 2) If not found, check if the top-level mapping itself looks like a state_dict
            if state is None and _looks_like_state_dict(ckpt):
                state = ckpt
            # 3) try shallow search: any value that is a dict and looks like state_dict
            if state is None:
                for v in ckpt.values():
                    try:
                        if isinstance(v, dict) and _looks_like_state_dict(v):
                            state = v
                            break
                    except Exception:
                        continue
        else:
            # ckpt itself might be a state-dict-like mapping
            if isinstance(ckpt, dict) and _looks_like_state_dict(ckpt):
                state = ckpt
    except Exception as e:
        _safe_print(f"[CELL12] Error while inspecting checkpoint structure: {type(e).__name__}: {str(e)[:200]}")
        _maybe_traceback(e)
        return False, e

    if state is None:
        return False, "Could not find a model state-dict inside the checkpoint."

    # Instantiate a fresh model
    try:
        model_inst = MemoryOptimizedTATNWithExplanations(tokenizer)
    except Exception as e:
        _safe_print(f"[CELL12] Failed to instantiate model class: {type(e).__name__}: {str(e)[:200]}")
        _maybe_traceback(e)
        return False, e

    # Try to resize embeddings BEFORE loading to reduce mismatch issues
    try:
        mbart = getattr(model_inst, "mbart", None)
        if mbart is not None and hasattr(mbart, "get_input_embeddings"):
            emb = mbart.get_input_embeddings()
            cur = getattr(emb, "num_embeddings", None)
            tok_len = None
            try:
                if tokenizer is None:
                    tok_len = None
                elif hasattr(tokenizer, "vocab_size") and getattr(tokenizer, "vocab_size") is not None:
                    tok_len = int(getattr(tokenizer, "vocab_size"))
                elif hasattr(tokenizer, "__len__"):
                    tok_len = int(len(tokenizer))
                else:
                    tok_len = None
            except Exception:
                tok_len = getattr(tokenizer, "vocab_size", None) if tokenizer is not None else None

            if cur is not None and tok_len is not None and int(cur) != int(tok_len) and int(tok_len) > 0:
                _safe_print(f"[CELL12] Resizing embeddings: {cur} -> {tok_len}")
                try:
                    mbart.resize_token_embeddings(tok_len)
                except Exception as ex:
                    _safe_print(f"[CELL12] Embedding resize attempt failed: {type(ex).__name__}: {str(ex)[:200]}")
                    _maybe_traceback(ex)
    except Exception as e:
        _safe_print(f"[CELL12] Embedding resize warning: {type(e).__name__}: {str(e)[:200]}")
        _maybe_traceback(e)

    # Helper: attempt to load and return missing/unexpected lists (handles both tuple and NamedTuple results)
    def _load_and_report(state_dict: Dict[str, Any]) -> Tuple[bool, List[str], List[str], Optional[Exception]]:
        try:
            res = model_inst.load_state_dict(state_dict, strict=False)
            missing: List[str] = []
            unexpected: List[str] = []
            # new-style IncompatibleKeys
            if hasattr(res, "missing_keys") or hasattr(res, "unexpected_keys"):
                missing = list(getattr(res, "missing_keys", []) or [])
                unexpected = list(getattr(res, "unexpected_keys", []) or [])
            else:
                # old-style tuple (missing, unexpected)
                try:
                    if isinstance(res, (tuple, list)) and len(res) == 2:
                        missing = list(res[0]) or []
                        unexpected = list(res[1]) or []
                except Exception:
                    missing, unexpected = [], []
            return True, missing, unexpected, None
        except Exception as ex:
            return False, [str(ex)], [], ex

    # First attempt: direct load
    try:
        ok, missing, unexpected, exc = _load_and_report(state)
        if not ok:
            raise RuntimeError(f"Primary load_state_dict failed: {missing[:3]}")
        _safe_print(f"[CELL12] Loaded checkpoint (strict=False). Missing keys: {len(missing)} Unexpected keys: {len(unexpected)}")
        if _VERBOSE_LOGGING:
            if missing:
                _safe_print(f"  Missing keys (sample up to 20): {missing[:20]}")
            if unexpected:
                _safe_print(f"  Unexpected keys (sample up to 20): {unexpected[:20]}")
    except Exception as e:
        _safe_print(f"[CELL12] load_state_dict(strict=False) raised: {type(e).__name__}: {str(e)[:200]}")
        _maybe_traceback(e)
        # Retry 1: strip 'module.' prefixes (DataParallel artifact)
        try:
            if isinstance(state, dict):
                new_state = {}
                for k, v in state.items():
                    new_key = k
                    if isinstance(k, str) and k.startswith("module."):
                        new_key = k.replace("module.", "", 1)
                    new_state[new_key] = v
                ok, missing, unexpected, exc = _load_and_report(new_state)
                if ok:
                    _safe_print("[CELL12] Retried loading after stripping 'module.' prefixes.")
                    if _VERBOSE_LOGGING:
                        _safe_print(f"  Missing: {missing[:20]} Unexpected: {unexpected[:20]}")
                else:
                    raise RuntimeError(f"Retry after strip failed: {missing[:3]}")
            else:
                raise RuntimeError("State-dict is not a dict; cannot strip prefixes")
        except Exception as e2:
            _safe_print(f"[CELL12] Retry after stripping prefixes also failed: {type(e2).__name__}: {str(e2)[:200]}")
            _maybe_traceback(e2)
            # Retry 2: try converting byte keys to str if necessary
            try:
                if isinstance(state, dict):
                    conv_state = {}
                    changed = False
                    for k, v in state.items():
                        if isinstance(k, bytes):
                            try:
                                nk = k.decode("utf-8")
                                conv_state[nk] = v
                                changed = True
                            except Exception:
                                conv_state[k] = v
                        else:
                            conv_state[k] = v
                    if changed:
                        ok, missing, unexpected, exc = _load_and_report(conv_state)
                        if ok:
                            _safe_print("[CELL12] Retried loading after decoding byte keys to str.")
                        else:
                            raise RuntimeError(f"Retry after decode keys failed: {missing[:3]}")
                    else:
                        raise RuntimeError("No byte-key conversion possible; load failed previously.")
                else:
                    raise RuntimeError("State-dict is not a dict for byte-key conversion")
            except Exception as e3:
                _safe_print(f"[CELL12] All retry attempts failed: {type(e3).__name__}: {str(e3)[:200]}")
                _maybe_traceback(e3)
                return False, e3

    # Move model to target device and set eval()
    try:
        model_inst.to(_DEVICE)
        model_inst.eval()
    except Exception as e:
        try:
            core = model_inst.module if hasattr(model_inst, "module") else model_inst
            core.to(_DEVICE)
            core.eval()
            model_inst = core
        except Exception:
            _safe_print(f"[CELL12] Failed to move model to device: {type(e).__name__}: {str(e)[:200]}")
            _maybe_traceback(e)
            return False, e

    _safe_print("[CELL12] Checkpoint successfully loaded and model prepared on device.")
    return True, model_inst


# ------------------------------------------------------------------------------
# If a checkpoint exists, prefer loading it (but fall back to trained_model)
if os.path.exists("tatn_kaggle_final.pt") and globals().get("tokenizer", None) is not None:
    succ, model_or_err = try_load_checkpoint("tatn_kaggle_final.pt", globals().get("tokenizer"))
    if succ:
        globals()['trained_model'] = model_or_err
        _safe_print("[CELL12] Checkpoint loaded and will be used for inference testing.")
    else:
        _safe_print("[CELL12] Checkpoint load failed; falling back to trained_model from runtime (if available).")
        if _VERBOSE_LOGGING:
            _maybe_traceback(model_or_err)

# Warmup helper (useful if prototypes empty)
def maybe_run_warmup_if_needed(model, tokenizer, warmup_sents: int = 4000):
    """
    If DSCD prototype stores are empty, optionally run a short discovery warmup to
    populate DSCD buffers and allow prototype clustering to run. Uses dscd_discovery_warmup if present.
    """
    try:
        core = model.module if hasattr(model, "module") else model
        dscd = getattr(core, "dscd", None)
        if dscd is None:
            _safe_print("[CELL12] No DSCD component on the model; skipping warmup.")
            return
        proto_stores = getattr(dscd, "prototype_stores", None)
        if not proto_stores or len(proto_stores) == 0:
            if 'dscd_discovery_warmup' in globals():
                _safe_print("[CELL12] No DSCD prototypes detected - running short warmup to build prototypes...")
                try:
                    dscd_discovery_warmup(model, tokenizer, num_sents=warmup_sents, max_len=globals().get("MAX_LENGTH", 48))
                    _safe_print("[CELL12] Warmup complete.")
                except Exception as e:
                    _safe_print(f"[CELL12] Warmup failed/skipped: {type(e).__name__}: {str(e)[:200]}")
                    _maybe_traceback(e)
            else:
                _safe_print("[CELL12] Warmup helper not available - skipping prototype building.")
        else:
            _safe_print(f"[CELL12] DSCD prototype stores already contain {len(proto_stores)} types; warmup not needed.")
    except Exception as e:
        _safe_print(f"[CELL12] Warmup probe failed: {type(e).__name__}: {str(e)[:200]}")
        _maybe_traceback(e)


# Prepare test sentences -------------------------------------------------------
test_sentences: List[Tuple[str, str, str]] = [
    ("‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§", "I turned off the tap", "‡¶ï‡¶≤ = tap/call"),
    ("‡¶ï‡¶æ‡¶≤ ‡¶Ü‡¶Æ‡¶ø ‡¶¨‡¶á ‡¶ï‡¶ø‡¶®‡¶¨‡•§", "Tomorrow I will buy a book", "‡¶ï‡¶æ‡¶≤ = tomorrow/yesterday"),
    ("‡¶™‡¶æ‡¶§‡¶æ ‡¶ù‡¶∞‡ßá ‡¶™‡¶°‡¶º‡ßá‡¶õ‡ßá‡•§", "The leaf has fallen", "‡¶™‡¶æ‡¶§‡¶æ = leaf/page"),
    ("‡¶§‡¶ø‡¶®‡¶ø ‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï ‡¶ó‡ßá‡¶õ‡ßá‡¶®‡•§", "He went to the bank", "‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï = bank/embankment"),
    ("‡¶Ü‡¶Æ‡¶ø ‡¶≠‡¶æ‡¶≤‡ßã ‡¶Ü‡¶õ‡¶ø‡•§", "I am fine", "Simple (no ambiguity)"),
    ("‡¶∏‡ßá ‡¶ñ‡ßÅ‡¶¨ ‡¶Æ‡¶ø‡¶∑‡ßç‡¶ü‡¶ø ‡¶ï‡¶•‡¶æ ‡¶¨‡¶≤‡ßá‡•§", "She speaks sweetly", "Adjective usage"),
    ("‡¶è‡¶ü‡¶æ ‡¶Ü‡¶Æ‡¶æ‡¶∞ ‡¶¨‡¶á‡•§", "This is my book", "Demonstrative pronoun"),
    ("‡¶§‡ßÅ‡¶Æ‡¶ø ‡¶ï‡¶ø ‡¶Ü‡¶Æ‡¶æ‡¶ï‡ßá ‡¶∏‡¶æ‡¶π‡¶æ‡¶Ø‡ßç‡¶Ø ‡¶ï‡¶∞‡¶§‡ßá ‡¶™‡¶æ‡¶∞‡ßã?", "Can you help me?", "Question form"),
    ("‡¶Ü‡¶ú ‡¶Ü‡¶¨‡¶π‡¶æ‡¶ì‡¶Ø‡¶º‡¶æ ‡¶≠‡¶æ‡¶≤‡ßã‡•§", "Weather description", "Simple"),
    ("‡¶Ü‡¶Æ‡¶∞‡¶æ ‡¶¨‡¶æ‡¶Ç‡¶≤‡¶æ‡¶¶‡ßá‡¶∂‡ßá ‡¶¨‡¶æ‡¶∏ ‡¶ï‡¶∞‡¶ø‡•§", "We live in Bangladesh", "Country name"),
    ("‡¶∏‡ßÇ‡¶∞‡ßç‡¶Ø ‡¶™‡ßÇ‡¶∞‡ßç‡¶¨ ‡¶¶‡¶ø‡¶ï‡ßá ‡¶ì‡¶†‡ßá‡•§", "The sun rises in the east", "Directional"),
    ("‡¶™‡¶æ‡¶ñ‡¶ø ‡¶Ü‡¶ï‡¶æ‡¶∂‡ßá ‡¶â‡¶°‡¶º‡ßá‡•§", "Birds fly in the sky", "Simple present"),
    ("‡¶∏‡ßá ‡¶∏‡ßç‡¶ï‡ßÅ‡¶≤‡ßá ‡¶Ø‡¶æ‡¶ö‡ßç‡¶õ‡ßá‡•§", "She is going to school", "Present continuous"),
]

# Verify prerequisites ---------------------------------------------------------
trained_model_present = ('trained_model' in globals() and globals().get('trained_model') is not None)
tokenizer_available = ('tokenizer' in globals() and globals().get('tokenizer') is not None)
translate_available = ('translate_with_explanations' in globals() and callable(globals().get('translate_with_explanations')))

if not (trained_model_present and tokenizer_available and translate_available):
    _safe_print("\n‚ùå Cannot run extended inference tests. Missing one or more of: trained_model, tokenizer, translate_with_explanations.")
    _safe_print("   Please run the full pipeline (Cells 0-11) or load a model checkpoint and tokenizer.")
else:
    # Ensure prototypes warmup if needed
    try:
        maybe_run_warmup_if_needed(globals().get('trained_model'), globals().get("tokenizer"), warmup_sents=4000)
    except Exception as e:
        _safe_print(f"[CELL12] Warmup invocation failed: {type(e).__name__}: {str(e)[:200]}")
        _maybe_traceback(e)

    # Run tests
    total = len(test_sentences)
    successes = 0
    tests_with_explanations = 0
    total_ambiguous_detected = 0

    _safe_print("\n" + "=" * 80)
    _safe_print("CELL 12: EXTENDED INFERENCE TESTING - START")
    _safe_print("=" * 80)

    for idx, (sent, expected, note) in enumerate(test_sentences, 1):
        _safe_print("\n" + "-" * 70)
        _safe_print(f"Test {idx}/{total}: {note}")
        _safe_print(f"Input: {sent}")
        _safe_print(f"Expected (informal): {expected}")
        try:
            model_for_infer = globals().get('trained_model')
            tok = globals().get('tokenizer')
            if model_for_infer is None or tok is None:
                raise RuntimeError("trained_model or tokenizer missing at inference time")

            # translate_with_explanations can be DP-wrapped; call whatever is present
            try:
                res = translate_with_explanations(model_for_infer, tok, sent)
            except Exception as e:
                _safe_print(f"[CELL12] translate_with_explanations raised: {type(e).__name__}: {str(e)[:200]}")
                _maybe_traceback(e)
                res = None

            if res is None:
                _safe_print("[CELL12] translate_with_explanations returned None or raised; skipping this test.")
                continue

            if not isinstance(res, dict):
                _safe_print(f"[CELL12] Warning: translate_with_explanations returned non-dict: {type(res)}; coercing to dict")
                res = {"translation": str(res)}

            translation = str(res.get("translation", "") or "")
            amb_count = int(res.get("ambiguous_words_detected", 0) or 0)
            explanations = res.get("explanations", []) or []

            _safe_print(f"Translation: {translation}")
            _safe_print(f"Ambiguous words detected (real): {amb_count}")

            if amb_count > 0:
                tests_with_explanations += 1
                total_ambiguous_detected += amb_count
                _safe_print("Explanations (ambiguous tokens):")
                for j, e in enumerate(explanations, 1):
                    try:
                        word = e.get("ambiguous_word", e.get("token", "N/A"))
                        u = float(e.get("uncertainty", 0.0) or 0.0)
                        s = float(e.get("span", 0.0) or 0.0)
                        marker = "üî•" if s > 0.3 else "  "
                        _safe_print(f"  {j}. {marker} '{word}'  U={u:.3f}  S={s:.3f}")
                        _safe_print(f"       {e.get('explanation', '')}")
                    except Exception:
                        if _VERBOSE_LOGGING:
                            traceback.print_exc()
                        continue
            else:
                _safe_print("No real ambiguity detected")

            if translation and translation.strip():
                successes += 1
                _safe_print("Translation produced (non-empty) ‚Üí counted as successful")
            else:
                _safe_print("Translation empty or failed ‚Üí counted as unsuccessful")

        except Exception as e:
            _safe_print(f"Test {idx} failed with exception: {type(e).__name__}: {str(e)[:200]}")
            _maybe_traceback(e)

    # Summary
    _safe_print("\n" + "=" * 80)
    _safe_print("CELL 12: EXTENDED INFERENCE TEST SUMMARY")
    _safe_print("=" * 80)
    _safe_print(f"Total tests: {total}")
    if total > 0:
        _safe_print(f"Successful translations: {successes} ({successes/total*100:.1f}%)")
        _safe_print(f"Tests with explanations: {tests_with_explanations} ({tests_with_explanations/total*100:.1f}%)")
        _safe_print(f"Total ambiguous words detected (real): {total_ambiguous_detected}")
        _safe_print(f"Avg ambiguous words per sentence: {total_ambiguous_detected/total:.2f}")
    else:
        _safe_print("No tests were executed")
    _safe_print("=" * 80)
    _safe_print(f"Real-ambiguity thresholds used: span > 0.3 OR uncertainty > {_REAL_AMB_UNCERTAINTY:.2f}")
    _safe_print("Cell 12 testing complete.")

[CELL12] Loading checkpoint from: tatn_kaggle_final.pt
[CELL12] Resizing embeddings: 128112 -> 128104
[CELL12] Loaded checkpoint (strict=False). Missing keys: 0 Unexpected keys: 0
[CELL12] Checkpoint successfully loaded and model prepared on device.
[CELL12] Checkpoint loaded and will be used for inference testing.
[CELL12] No DSCD prototypes detected - running short warmup to build prototypes...
[WARMUP] Starting DSCD discovery warmup...
[CELL2] Loading up to 4000 samples from local CSV: /kaggle/input/homo-bn-dataset/bn_homograph_complete_dataset.csv
[CELL2] Reading CSV file...
[CELL2] Processing 4000 rows from CSV...


Loading dataset: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 4000/4000 [00:00<00:00, 21458.41it/s]


[CELL2] Loaded 4000 pairs from CSV, skipped 0 rows
[WARMUP] Prototype discovery: word_types=750, total_protos=563, multi_sense=0
[WARMUP] Restored DSCD configuration
[CELL12] Warmup complete.

CELL 12: EXTENDED INFERENCE TESTING - START

----------------------------------------------------------------------
Test 1/13: ‡¶ï‡¶≤ = tap/call
Input: ‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§
Expected (informal): I turned off the tap
Translation: i closed the call.
Ambiguous words detected (real): 0
No real ambiguity detected
Translation produced (non-empty) ‚Üí counted as successful

----------------------------------------------------------------------
Test 2/13: ‡¶ï‡¶æ‡¶≤ = tomorrow/yesterday
Input: ‡¶ï‡¶æ‡¶≤ ‡¶Ü‡¶Æ‡¶ø ‡¶¨‡¶á ‡¶ï‡¶ø‡¶®‡¶¨‡•§
Expected (informal): Tomorrow I will buy a book
Translation: i will buy it tomorrow.
Ambiguous words detected (real): 0
No real ambiguity detected
Translation produced (non-empty) ‚Üí counted as successful

----------------------------------------

In [17]:
# ==============================================================================
# CELL 13 (patched): LARGE-SCALE EVALUATION (2000+ SAMPLES) - OPTIMIZED & HARDENED
# ==============================================================================
# - Batched generation (faster + VRAM-friendly)
# - Safe handling of DataParallel / wrapper models
# - Defensive tokenizer/lang-id handling for forced_bos_token_id
# - Robust metrics imports and fallbacks
# - CSV export, progress reporting, and clear error handling
# - Hardened decoding and many defensive fallbacks for real-world model/tokenizer shapes
# ==============================================================================

import os
import sys
import warnings
import time
import csv
import traceback
from typing import List, Dict, Tuple, Optional, Any

import numpy as np
import torch
from tqdm import tqdm

warnings.filterwarnings("ignore")

# ------------------------------
# Metrics availability detection
# ------------------------------
HAS_COMET = False
HAS_BLEU = False
HAS_CHRF = False

# Attempt COMET imports (optional)
try:
    # many environments won't have comet; guard carefully
    from comet import download_model, load_from_checkpoint  # type: ignore
    HAS_COMET = True
except Exception:
    HAS_COMET = False

# SacreBLEU (BLEU + CHRF)
try:
    import sacrebleu  # type: ignore
    # Validate presence of expected API functions
    if hasattr(sacrebleu, "corpus_bleu"):
        HAS_BLEU = True
    if hasattr(sacrebleu, "corpus_chrf"):
        HAS_CHRF = True
except Exception:
    HAS_BLEU = False
    HAS_CHRF = False

# ------------------------------
# Local safe global fallbacks
# ------------------------------
_DEVICE = globals().get("DEVICE", torch.device("cuda" if torch.cuda.is_available() else "cpu"))
_VERBOSE_LOGGING = bool(globals().get("VERBOSE_LOGGING", False))

# ------------------------------
# Utility helpers
# ------------------------------
def _safe_print(msg: str):
    try:
        print(msg)
    except Exception:
        pass

def _maybe_traceback(exc: Exception):
    if _VERBOSE_LOGGING:
        traceback.print_exc()
    else:
        print("   (set VERBOSE_LOGGING = True in Cell 0 for full traceback)")

def _unwrap_model(model: torch.nn.Module) -> torch.nn.Module:
    """Return core model (unwrap DataParallel/DistributedDataParallel if needed)."""
    return model.module if hasattr(model, "module") else model

def _get_forced_bos_id(tokenizer, core_mbart) -> Optional[int]:
    """Try several tokenizer/model attributes to find an English forced BOS id."""
    forced_id = None
    try:
        if hasattr(tokenizer, "get_lang_id"):
            for code in ("en", "en_XX", "en-XX", "eng"):
                try:
                    lid = tokenizer.get_lang_id(code)
                    if lid is not None:
                        forced_id = lid
                        break
                except Exception:
                    continue
        elif hasattr(tokenizer, "lang_code_to_id"):
            for code in ("en", "en_XX", "en-XX", "eng"):
                try:
                    candidate = tokenizer.lang_code_to_id.get(code, None)
                    if candidate is not None:
                        forced_id = candidate
                        break
                except Exception:
                    continue
    except Exception:
        forced_id = None
    # fallback to mbart config if available
    try:
        if forced_id is None and core_mbart is not None and hasattr(core_mbart, "config"):
            forced_id = getattr(core_mbart.config, "forced_bos_token_id", None)
            if forced_id is None:
                forced_id = getattr(core_mbart.config, "decoder_start_token_id", None)
    except Exception:
        forced_id = None
    return forced_id

# ------------------------------
# Large scale metrics class
# ------------------------------
class LargeScaleEvaluationMetrics:
    """Compute metrics on many samples efficiently, with fallbacks."""
    def __init__(self, device: Optional[torch.device] = None, batch_size: int = 32):
        self.device = device or _DEVICE
        self.batch_size = int(batch_size)
        self.comet_model = None
        self.metrics_available = {"comet": HAS_COMET, "bleu": HAS_BLEU, "chrf": HAS_CHRF}
        _safe_print("\n" + "=" * 80)
        _safe_print("INITIALIZING LARGE-SCALE EVALUATION METRICS")
        _safe_print("=" * 80)
        _safe_print(f"Device: {self.device}")
        _safe_print(f"Batch Size: {self.batch_size}")
        _safe_print(f"Metrics Available: BLEU={HAS_BLEU}, CHRF={HAS_CHRF}, COMET={HAS_COMET}")
        _safe_print("=" * 80 + "\n")

        if HAS_COMET:
            try:
                _safe_print("[EVAL] Loading COMET model (may take time)...")
                try:
                    model_path = download_model("Unbabel/wmt22-comet-da", saving_directory=".comet_cache")
                    self.comet_model = load_from_checkpoint(model_path)
                    _safe_print("[EVAL] ‚úì COMET model loaded")
                except Exception:
                    _safe_print("[EVAL] COMET automatic load failed; disabling COMET for this run.")
                    self.metrics_available["comet"] = False
                    self.comet_model = None
            except Exception:
                self.metrics_available["comet"] = False
                self.comet_model = None

    def compute_bleu_large(self, references: List[str], hypotheses: List[str]) -> Dict[str, Any]:
        if not self.metrics_available["bleu"] or not references or not hypotheses:
            return {"bleu": None, "error": "BLEU unavailable or empty inputs", "num_samples": len(hypotheses)}
        try:
            _safe_print(f"\n[BLEU] Computing BLEU on {len(hypotheses)} samples...")
            start_time = time.time()
            import sacrebleu  # type: ignore
            # sacrebleu expects list-of-reference-lists
            score = sacrebleu.corpus_bleu(hypotheses, [references])
            elapsed = time.time() - start_time
            result = {"bleu": float(score.score), "num_samples": len(hypotheses), "computation_time_sec": elapsed}
            _safe_print(f"[BLEU] ‚úì {score.score:.2f}/100 computed in {elapsed:.2f}s")
            return result
        except Exception as e:
            _safe_print(f"[BLEU] Error computing BLEU: {type(e).__name__}: {str(e)[:200]}")
            _maybe_traceback(e)
            return {"bleu": None, "error": str(e)[:200], "num_samples": len(hypotheses)}

    def compute_chrf_large(self, references: List[str], hypotheses: List[str]) -> Dict[str, Any]:
        if not self.metrics_available["chrf"] or not references or not hypotheses:
            return {"chrf": None, "error": "CHRF unavailable or empty inputs", "num_samples": len(hypotheses)}
        try:
            _safe_print(f"\n[CHRF++] Computing CHRF++ on {len(hypotheses)} samples...")
            start_time = time.time()
            import sacrebleu  # type: ignore
            score = sacrebleu.corpus_chrf(hypotheses, [references], beta=3.0)
            elapsed = time.time() - start_time
            result = {"chrf": float(score.score), "num_samples": len(hypotheses), "computation_time_sec": elapsed}
            _safe_print(f"[CHRF++] ‚úì {score.score:.2f}/100 computed in {elapsed:.2f}s")
            return result
        except Exception as e:
            _safe_print(f"[CHRF++] Error computing CHRF++: {type(e).__name__}: {str(e)[:200]}")
            _maybe_traceback(e)
            return {"chrf": None, "error": str(e)[:200], "num_samples": len(hypotheses)}

    def compute_comet_large(self, source_texts: List[str], references: List[str], hypotheses: List[str]) -> Dict[str, Any]:
        if not self.metrics_available.get("comet") or self.comet_model is None:
            return {"comet": None, "error": "COMET unavailable", "num_samples": len(hypotheses)}
        if not source_texts or not references or not hypotheses:
            return {"comet": None, "error": "Empty inputs", "num_samples": len(hypotheses)}
        try:
            _safe_print(f"\n[COMET] Computing COMET on {len(hypotheses)} samples...")
            start_time = time.time()
            data = [{"src": s, "ref": r, "mt": h} for s, r, h in zip(source_texts, references, hypotheses)]
            try:
                if torch.cuda.is_available():
                    self.comet_model.to(self.device)
            except Exception:
                pass
            with torch.no_grad():
                if hasattr(self.comet_model, "predict"):
                    output = self.comet_model.predict(data, batch_size=self.batch_size, gpus=1 if torch.cuda.is_available() else 0)
                    scores = np.asarray(getattr(output, "scores", []) or [], dtype=np.float32)
                    system_score = getattr(output, "system_score", None)
                else:
                    scores = []
                    for i in range(0, len(data), self.batch_size):
                        batch = data[i : i + self.batch_size]
                        try:
                            out = self.comet_model.predict(batch)
                            scores.extend(getattr(out, "scores", []) or [])
                        except Exception:
                            break
                    scores = np.asarray(scores, dtype=np.float32) if scores else np.array([])
                    system_score = float(np.mean(scores)) if scores.size else None
            elapsed = time.time() - start_time
            result = {
                "comet": float(system_score) if system_score is not None else None,
                "comet_mean": float(np.mean(scores)) if scores.size else None,
                "comet_median": float(np.median(scores)) if scores.size else None,
                "comet_std": float(np.std(scores)) if scores.size else None,
                "comet_min": float(np.min(scores)) if scores.size else None,
                "comet_max": float(np.max(scores)) if scores.size else None,
                "comet_scores": scores.tolist() if scores.size else [],
                "num_samples": len(hypotheses),
                "computation_time_sec": elapsed,
            }
            _safe_print(f"[COMET] ‚úì Computed in {elapsed:.2f}s")
            return result
        except Exception as e:
            _safe_print(f"[COMET] Error computing COMET: {type(e).__name__}: {str(e)[:200]}")
            _maybe_traceback(e)
            return {"comet": None, "error": str(e)[:200], "num_samples": len(hypotheses)}

    def compute_all_metrics_large(self, source_texts: List[str], references: List[str], hypotheses: List[str]) -> Dict[str, Any]:
        results: Dict[str, Any] = {"num_samples": len(hypotheses), "metrics": {}, "timestamp": time.strftime("%Y-%m-%d %H:%M:%S")}
        if self.metrics_available.get("bleu"):
            results["metrics"]["bleu"] = self.compute_bleu_large(references, hypotheses)
        if self.metrics_available.get("chrf"):
            results["metrics"]["chrf"] = self.compute_chrf_large(references, hypotheses)
        if self.metrics_available.get("comet"):
            results["metrics"]["comet"] = self.compute_comet_large(source_texts, references, hypotheses)
        return results

# ------------------------------
# Main evaluation function
# ------------------------------
def evaluate_on_large_dataset(
    model: torch.nn.Module,
    tokenizer,
    dataset: Optional[List[Tuple[str, str]]] = None,
    num_samples: int = 2000,
    batch_size: int = 32,
    save_results: bool = True,
    max_length: int = 512,
) -> Dict[str, Any]:
    """
    Evaluate model on a dataset (default 2000 samples) and compute metrics.
    Uses batched tokenization + generation for efficiency and stability.
    """
    _safe_print("\n" + "=" * 80)
    _safe_print("LARGE-SCALE EVALUATION ON SAMPLES")
    _safe_print("=" * 80 + "\n")

    try:
        # Prepare dataset
        _safe_print(f"[PREP] Preparing dataset (requested {num_samples} samples)...")
        if not dataset:
            if "load_and_preprocess_optimized" in globals():
                try:
                    pairs = load_and_preprocess_optimized(num_samples)
                except Exception as e:
                    _safe_print(f"[PREP] load_and_preprocess_optimized failed: {type(e).__name__}: {str(e)[:200]}")
                    _maybe_traceback(e)
                    sample_pairs = [("‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§", "I stopped the call."), ("‡¶ï‡¶æ‡¶≤ ‡¶Ü‡¶Æ‡¶ø ‡¶¨‡¶á ‡¶ï‡¶ø‡¶®‡¶¨‡•§", "I will buy a book tomorrow.")]
                    pairs = (sample_pairs * ((num_samples // len(sample_pairs)) + 1))[:num_samples]
            else:
                _safe_print("[PREP] No data loader found; using dummy data")
                sample_pairs = [("‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§", "I stopped the call."), ("‡¶ï‡¶æ‡¶≤ ‡¶Ü‡¶Æ‡¶ø ‡¶¨‡¶á ‡¶ï‡¶ø‡¶®‡¶¨‡•§", "I will buy a book tomorrow.")]
                pairs = (sample_pairs * ((num_samples // len(sample_pairs)) + 1))[:num_samples]
        else:
            pairs = dataset

        pairs = pairs[:num_samples]
        _safe_print(f"[PREP] ‚úì Loaded {len(pairs)} samples\n")

        source_texts = [s for s, _ in pairs]
        references = [r for _, r in pairs]
        hypotheses: List[str] = []

        # Unwrap core model and prepare generation function
        core = _unwrap_model(model)
        core.eval()
        try:
            core.to(_DEVICE)
        except Exception:
            pass

        # Determine generation callable
        gen_callable = None
        mbart = getattr(core, "mbart", None)
        if mbart is not None and hasattr(mbart, "generate"):
            gen_callable = mbart.generate
            generation_backend = mbart
        elif hasattr(core, "generate"):
            gen_callable = core.generate
            generation_backend = core
        else:
            raise RuntimeError("No generate() found on model or model.mbart")

        forced_bos = _get_forced_bos_id(tokenizer, mbart)

        # Batch generation
        _safe_print(f"[GEN] Generating predictions in batches (batch_size={batch_size}) ...")
        n = len(source_texts)
        batch_size_gen = max(1, int(batch_size))
        with torch.no_grad():
            for start in tqdm(range(0, n, batch_size_gen), desc="[GEN] Batches", unit="batch"):
                batch_srcs = source_texts[start : start + batch_size_gen]
                try:
                    # Tokenize batch
                    try:
                        try:
                            # set source language token if tokenizer supports this
                            if hasattr(tokenizer, "src_lang"):
                                setattr(tokenizer, "src_lang", "bn")
                        except Exception:
                            pass
                        enc = tokenizer(batch_srcs, return_tensors="pt", padding=True, truncation=True, max_length=max_length)
                        enc = {k: v.to(_DEVICE) for k, v in enc.items() if isinstance(v, torch.Tensor)}
                    except Exception as e:
                        _safe_print(f"[GEN] Batch tokenization failed: {type(e).__name__}: {str(e)[:200]}")
                        _maybe_traceback(e)
                        # fallback per-sentence generation
                        for src in batch_srcs:
                            try:
                                try:
                                    if hasattr(tokenizer, "src_lang"):
                                        setattr(tokenizer, "src_lang", "bn")
                                except Exception:
                                    pass
                                enc1 = tokenizer(src, return_tensors="pt", truncation=True, max_length=max_length)
                                enc1 = {k: v.to(_DEVICE) for k, v in enc1.items() if isinstance(v, torch.Tensor)}
                                gen_kwargs1 = {"max_length": 128, "num_beams": 1, "early_stopping": True}
                                if forced_bos is not None:
                                    gen_kwargs1["forced_bos_token_id"] = int(forced_bos)
                                gen_ids = gen_callable(**enc1, **gen_kwargs1)
                                # decode
                                if isinstance(gen_ids, torch.Tensor):
                                    seqs = gen_ids.cpu().tolist()
                                    # take first sequence
                                    seq = seqs[0] if isinstance(seqs, list) and len(seqs) > 0 else seqs
                                    try:
                                        hyp = tokenizer.batch_decode([seq], skip_special_tokens=True)[0]
                                    except Exception:
                                        hyp = tokenizer.decode(seq, skip_special_tokens=True) if hasattr(tokenizer, "decode") else ""
                                else:
                                    try:
                                        hyp = tokenizer.batch_decode(gen_ids, skip_special_tokens=True)[0]
                                    except Exception:
                                        try:
                                            hyp = tokenizer.decode(gen_ids[0], skip_special_tokens=True)
                                        except Exception:
                                            hyp = ""
                                hypotheses.append(hyp)
                            except Exception:
                                hypotheses.append("")
                        continue

                    # call generate
                    gen_kwargs: Dict[str, Any] = {"max_length": 256, "num_beams": 5, "early_stopping": True}
                    if forced_bos is not None:
                        gen_kwargs["forced_bos_token_id"] = int(forced_bos)

                    # Some generate callables are bound methods; ensure proper device for backend if possible
                    try:
                        if torch.cuda.is_available() and hasattr(generation_backend, "to"):
                            generation_backend.to(_DEVICE)
                    except Exception:
                        pass

                    generated_ids = gen_callable(**enc, **gen_kwargs)

                    # normalize to list-of-seqs for decoding
                    gen_ids_tensor = None
                    if isinstance(generated_ids, torch.Tensor):
                        gen_ids_tensor = generated_ids
                    elif isinstance(generated_ids, (list, tuple)):
                        # sometimes returns list of tensors or list of lists
                        if len(generated_ids) == 0:
                            gen_ids_tensor = torch.empty((0, 0), dtype=torch.long)
                        else:
                            first = generated_ids[0]
                            if isinstance(first, torch.Tensor):
                                try:
                                    gen_ids_tensor = torch.stack(generated_ids, dim=0)
                                except Exception:
                                    # try to convert each to list then pad
                                    seqs = [g.cpu().tolist() if isinstance(g, torch.Tensor) else list(g) for g in generated_ids]
                                    # convert to list-of-lists for tokenizer.batch_decode
                                    gen_ids_tensor = seqs
                            else:
                                # list-of-lists
                                gen_ids_tensor = generated_ids
                    else:
                        # unknown type: try to coerce to list
                        try:
                            gen_ids_tensor = torch.tensor(generated_ids)
                        except Exception:
                            gen_ids_tensor = generated_ids

                    # decode safely
                    try:
                        # tokenizer.batch_decode supports list-of-ids or tensor (batch, seq)
                        if isinstance(gen_ids_tensor, (list, tuple)):
                            # list-of-lists or list-of-tensors: normalize to list-of-lists of ints
                            seqs = []
                            for item in gen_ids_tensor:
                                if isinstance(item, torch.Tensor):
                                    seqs.append(item.cpu().tolist())
                                else:
                                    seqs.append(list(item))
                            batch_hyps = tokenizer.batch_decode(seqs, skip_special_tokens=True)
                        elif isinstance(gen_ids_tensor, torch.Tensor):
                            batch_hyps = tokenizer.batch_decode(gen_ids_tensor.cpu(), skip_special_tokens=True)
                        else:
                            # fallback: try decoding element-wise
                            batch_hyps = []
                            try:
                                for item in gen_ids_tensor:
                                    try:
                                        batch_hyps.append(tokenizer.decode(item, skip_special_tokens=True))
                                    except Exception:
                                        batch_hyps.append("")
                            except Exception:
                                batch_hyps = ["" for _ in range(len(batch_srcs))]
                    except Exception as e:
                        _safe_print(f"[GEN] Decoding failed: {type(e).__name__}: {str(e)[:200]}")
                        _maybe_traceback(e)
                        # fallback per-sequence decode
                        batch_hyps = []
                        if isinstance(gen_ids_tensor, torch.Tensor):
                            seqs = gen_ids_tensor.cpu().tolist()
                        elif isinstance(gen_ids_tensor, (list, tuple)):
                            seqs = [g.cpu().tolist() if isinstance(g, torch.Tensor) else list(g) for g in gen_ids_tensor]
                        else:
                            seqs = []
                        for seq in seqs:
                            try:
                                batch_hyps.append(tokenizer.decode(seq, skip_special_tokens=True))
                            except Exception:
                                batch_hyps.append("")

                    hypotheses.extend(batch_hyps)

                except Exception as e:
                    _safe_print(f"[GEN] Batch generation failed at start={start}: {type(e).__name__}: {str(e)[:200]}")
                    _maybe_traceback(e)
                    # fallback per-sentence
                    for src in batch_srcs:
                        try:
                            try:
                                if hasattr(tokenizer, "src_lang"):
                                    setattr(tokenizer, "src_lang", "bn")
                            except Exception:
                                pass
                            enc1 = tokenizer(src, return_tensors="pt", truncation=True, max_length=max_length)
                            enc1 = {k: v.to(_DEVICE) for k, v in enc1.items() if isinstance(v, torch.Tensor)}
                            gen_kwargs1 = {"max_length": 128, "num_beams": 1, "early_stopping": True}
                            if forced_bos is not None:
                                gen_kwargs1["forced_bos_token_id"] = int(forced_bos)
                            gen_ids = gen_callable(**enc1, **gen_kwargs1)
                            seq = gen_ids[0] if isinstance(gen_ids, (list, tuple)) else gen_ids
                            try:
                                hyp = tokenizer.decode(seq, skip_special_tokens=True)
                            except Exception:
                                try:
                                    hyp = tokenizer.batch_decode([seq], skip_special_tokens=True)[0]
                                except Exception:
                                    hyp = ""
                            hypotheses.append(hyp)
                        except Exception:
                            hypotheses.append("")

        # length alignment
        if len(hypotheses) < len(source_texts):
            hypotheses.extend([""] * (len(source_texts) - len(hypotheses)))

        _safe_print(f"[GEN] ‚úì Generated {len(hypotheses)} predictions\n")

        # Compute metrics
        metrics_computer = LargeScaleEvaluationMetrics(device=_DEVICE, batch_size=batch_size)
        metrics_results = metrics_computer.compute_all_metrics_large(source_texts, references, hypotheses)

        # Summary report
        _safe_print("\n" + "=" * 80)
        _safe_print("FINAL EVALUATION REPORT")
        _safe_print("=" * 80 + "\n")

        _safe_print(f"Dataset: {len(hypotheses)} samples")
        _safe_print(f"Timestamp: {metrics_results.get('timestamp', '')}\n")

        _safe_print("Metric Scores:")
        _safe_print("-" * 80)
        if "bleu" in metrics_results["metrics"]:
            bleu_data = metrics_results["metrics"]["bleu"]
            if bleu_data.get("bleu") is not None:
                _safe_print(f"  BLEU:  {bleu_data['bleu']:>7.2f}/100 (computed in {bleu_data.get('computation_time_sec', 0.0):.1f}s)")
            else:
                _safe_print(f"  BLEU:  ERROR - {bleu_data.get('error', 'Unknown')}")
        if "chrf" in metrics_results["metrics"]:
            chrf_data = metrics_results["metrics"]["chrf"]
            if chrf_data.get("chrf") is not None:
                _safe_print(f"  CHRF++: {chrf_data['chrf']:>7.2f}/100 (computed in {chrf_data.get('computation_time_sec', 0.0):.1f}s)")
            else:
                _safe_print(f"  CHRF++: ERROR - {chrf_data.get('error', 'Unknown')}")
        if "comet" in metrics_results["metrics"]:
            comet_data = metrics_results["metrics"]["comet"]
            if comet_data.get("comet") is not None:
                _safe_print(f"  COMET:  {comet_data['comet']:>7.4f}/1.0 (computed in {comet_data.get('computation_time_sec', 0.0):.1f}s)")
                if comet_data.get("comet_mean") is not None:
                    _safe_print(f"         Mean={comet_data['comet_mean']:.4f}, Median={comet_data['comet_median']:.4f}, Std={comet_data['comet_std']:.4f}")
            else:
                _safe_print(f"  COMET:  ERROR - {comet_data.get('error', 'Unknown')}")
        _safe_print("-" * 80)

        # Save results to CSV
        csv_path = None
        if save_results:
            csv_path = "evaluation_results_2000.csv"
            _safe_print(f"\n[SAVE] Saving results to {csv_path}...")
            try:
                with open(csv_path, "w", newline="", encoding="utf-8") as f:
                    writer = csv.writer(f)
                    writer.writerow(["Index", "Source", "Reference", "Hypothesis"])
                    for idx, (s, r, h) in enumerate(zip(source_texts, references, hypotheses), 1):
                        writer.writerow([idx, s, r, h])
                _safe_print(f"[SAVE] ‚úì Saved {len(hypotheses)} predictions to {csv_path}")
            except Exception as e:
                _safe_print(f"[SAVE] Error saving CSV: {type(e).__name__}: {str(e)[:200]}")
                _maybe_traceback(e)

        # Sample outputs
        _safe_print("\n" + "=" * 80)
        _safe_print("SAMPLE TRANSLATIONS (first 10)")
        _safe_print("=" * 80)
        for i, (s, r, h) in enumerate(zip(source_texts[:10], references[:10], hypotheses[:10]), 1):
            _safe_print(f"\nSample {i}:")
            _safe_print(f"  Source:      {s}")
            _safe_print(f"  Reference:   {r}")
            _safe_print(f"  Hypothesis:  {h}")
        _safe_print("\n" + "=" * 80)

        return {"metrics": metrics_results["metrics"], "num_samples": len(hypotheses), "predictions": list(zip(source_texts, references, hypotheses)), "csv_file": csv_path}

    except Exception as e:
        _safe_print(f"\n[ERROR] Evaluation failed: {type(e).__name__}: {str(e)[:200]}")
        _maybe_traceback(e)
        return {"error": str(e), "metrics": {}}


# Example usage (script mode)
if __name__ == "__main__":
    _safe_print(
        """
    ‚ïî‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïó
    ‚ïë          LARGE-SCALE EVALUATION (2000+ SAMPLES) - HOW TO USE           ‚ïë
    ‚ïö‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïù
    """
    )
    eval_results = evaluate_on_large_dataset(model=globals().get("trained_model"), tokenizer=globals().get("tokenizer"), num_samples=2000, batch_size=32, save_results=True)
    _safe_print(eval_results)
    _safe_print("‚úÖ Cell 13: Large-scale evaluation (2000+ samples) - ready to run")


    ‚ïî‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïó
    ‚ïë          LARGE-SCALE EVALUATION (2000+ SAMPLES) - HOW TO USE           ‚ïë
    ‚ïö‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïù
    

LARGE-SCALE EVALUATION ON SAMPLES

[PREP] Preparing dataset (requested 2000 samples)...
[CELL2] Loading up to 2000 samples from local CSV: /kaggle/input/homo-bn-dataset/bn_homograph_complete_dataset.csv
[CELL2] Reading CSV file...
[CELL2] Processing 2000 rows from CSV...


Loading dataset: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 2000/2000 [00:00<00:00, 21323.63it/s]


[CELL2] Loaded 2000 pairs from CSV, skipped 0 rows
[PREP] ‚úì Loaded 2000 samples

[GEN] Generating predictions in batches (batch_size=32) ...


[GEN] Batches: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 63/63 [06:07<00:00,  5.83s/batch]


[GEN] ‚úì Generated 2000 predictions


INITIALIZING LARGE-SCALE EVALUATION METRICS
Device: cuda
Batch Size: 32
Metrics Available: BLEU=True, CHRF=True, COMET=False


[BLEU] Computing BLEU on 2000 samples...
[BLEU] ‚úì 26.00/100 computed in 0.17s

[CHRF++] Computing CHRF++ on 2000 samples...
[CHRF++] ‚úì 46.54/100 computed in 0.24s

FINAL EVALUATION REPORT

Dataset: 2000 samples
Timestamp: 2025-11-22 18:49:18

Metric Scores:
--------------------------------------------------------------------------------
  BLEU:    26.00/100 (computed in 0.2s)
  CHRF++:   46.54/100 (computed in 0.2s)
--------------------------------------------------------------------------------

[SAVE] Saving results to evaluation_results_2000.csv...
[SAVE] ‚úì Saved 2000 predictions to evaluation_results_2000.csv

SAMPLE TRANSLATIONS (first 10)

Sample 1:
  Source:      ‡¶Ü‡¶Æ‡¶ø ‡¶ï‡¶≤ ‡¶¨‡¶®‡ßç‡¶ß ‡¶ï‡¶∞‡ßá‡¶õ‡¶ø‡•§
  Reference:   i have turned off the tap.
  Hypothesis:  i closed the call.

Sample 2:
  Source:    

In [18]:
# Debug & fallback tokenizer with offsets
import unicodedata, re

def safe_tokenize_with_offsets(tokenizer, text):
    """
    Return (input_ids_tensor, tokens_list, offsets_list).
    - If tokenizer supports return_offsets_mapping, use it (fast tokenizers).
    - Else, fall back to slow-tokenizer path: get tokens and reconstruct character offsets.
    """
    # Try fast tokenizer API first
    try:
        enc = tokenizer(text, return_offsets_mapping=True, return_tensors='pt', add_special_tokens=False)
        ids = enc['input_ids']
        tokens = tokenizer.convert_ids_to_tokens(ids[0])
        offsets = [(int(s), int(e)) for (s, e) in enc['offset_mapping'][0].tolist()]
        return ids, tokens, offsets
    except NotImplementedError:
        pass
    except Exception as e:
        # Other exceptions: fall back to slow approach
        pass

    # Slow fallback: get token strings then greedily align with the original text.
    # This tries to be robust for SentencePiece (‚ñÅ) and BPE (##) style markers.
    tokens = tokenizer.tokenize(text, add_special_tokens=False)
    # Normalize text for matching
    norm_text = unicodedata.normalize('NFC', text)
    offsets = []
    pos = 0

    # Helper to try matching a token string to a substring at or after pos
    def find_token_span(tok, pos):
        # Clean token markers to a comparable surface form
        surface = tok
        # SentencePiece style: leading '‚ñÅ' means a word boundary and stands for a space
        if surface.startswith('‚ñÅ'):
            surface = surface.replace('‚ñÅ', '')
            # we allow matching at current pos or after a whitespace
            # attempt to find after the most recent whitespace or at pos
        # BERT style: '##' is a continuation subword marker
        if surface.startswith('##'):
            surface = surface[2:]
        # Try direct find starting at pos
        idx = norm_text.find(surface, pos)
        if idx != -1:
            return idx, idx + len(surface)
        # If not found, try skipping whitespace characters forward up to some limit
        m = re.search(re.escape(surface), norm_text[pos:])
        if m:
            start = pos + m.start()
            return start, start + len(surface)
        return None

    for tok in tokens:
        # For empty tokens (rare), skip
        if len(tok.strip()) == 0:
            offsets.append((pos, pos))
            continue
        res = find_token_span(tok, pos)
        if res is None:
            # As a last resort try searching from zero (global)
            res = find_token_span(tok, 0)
        if res is None:
            # Give up for this token ‚Äî mark a zero-length span to avoid crashes
            offsets.append((0, 0))
        else:
            start, end = res
            offsets.append((start, end))
            pos = end  # advance scanning position

    # Build input_ids tensor (without offsets API): use tokenizer.encode to get ids
    ids_list = tokenizer.encode(text, add_special_tokens=False)
    import torch
    ids_tensor = torch.tensor([ids_list], dtype=torch.long)
    return ids_tensor, tokens, offsets

In [19]:
# Quick DSCD/TRG debug cell - paste into your notebook and run
# Edit `SENT` if you want to test a different failing sentence.
SENT = "‡¶§‡¶ø‡¶®‡¶ø ‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï ‡¶ó‡ßá‡¶õ‡ßá‡¶®‡•§"  # change to a failing sentence you saw

print("SENTENCE:", SENT)

# 1) Tokenizer tokens + offsets
try:
    enc = tokenizer(SENT, return_offsets_mapping=True, return_tensors="pt", truncation=True)
    toks = tokenizer.convert_ids_to_tokens(enc["input_ids"][0])
    offs = enc["offset_mapping"][0].tolist()
    print("\nTOKENIZER (fast) tokens and offsets:")
    print("tokens =", toks)
    print("offsets =", offs)
except Exception as e:
    print("\nTokenizer fast path failed:", repr(e))
    try:
        # fallback: encode_plus older API
        enc2 = tokenizer.encode_plus(SENT, return_offsets_mapping=True, return_tensors="pt", truncation=True)
        toks = tokenizer.convert_ids_to_tokens(enc2["input_ids"][0])
        offs = enc2["offset_mapping"][0].tolist()
        print("\nTOKENIZER (fallback) tokens and offsets:")
        print("tokens =", toks)
        print("offsets =", offs)
    except Exception as e2:
        print("Tokenizer fallback also failed:", repr(e2))
        print("Please tell me the tokenizer variable name and whether it's a fast tokenizer (use_fast=True).")

# 2) Check reconstruct_word_spans / safe helpers if present
for helper in ("reconstruct_word_spans", "safe_offsets_tokenize", "safe_tokenize_with_offsets"):
    if helper in globals():
        try:
            print(f"\nCalling {helper}(...):")
            res = globals()[helper](tokenizer, SENT)
            print(helper, "->", type(res), repr(res)[:1000])
            break
        except Exception as e:
            print(helper, "exists but raised:", repr(e))
    else:
        # not in globals, try to import from likely module if present
        pass

# 3) Inspect dscd object on model (supports model or model.module)
dscd = None
try:
    if hasattr(model, "module"):
        dscd = getattr(model.module, "dscd", None)
    else:
        dscd = getattr(model, "dscd", None)
    print("\nFound dscd on model?:", dscd is not None)
except Exception as e:
    print("Error checking model.dscd:", repr(e))

# 4) Ask DSCD if it would track the canonical token for the word(s)
# We will attempt a couple of likely canonical forms from the token list above.
cands = []
try:
    # gather candidate tokens from tokenizer output
    cands = [t for t in toks if t and t not in tokenizer.all_special_tokens]
except Exception:
    pass
# include an explicit candidate with the homograph word if not present
if "‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï" not in cands:
    cands.append("‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï")

print("\nCandidate tokens to test should_track_token (sample):", cands[:10])

if dscd is None:
    print("\nNo dscd found on the model instance. If DSCD is a separate object, please provide its variable name.")
else:
    for w in cands[:10]:
        try:
            st = getattr(dscd, "should_track_token", None)
            if st is None:
                print("dscd.should_track_token not found on dscd object.")
                break
            ok = st(w)
            print(f"should_track_token('{w}') -> {ok}")
        except Exception as e:
            print("should_track_token call failed for", w, "->", repr(e))

# 5) Print buffer and prototype store status for the canonical homograph '‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï'
if dscd is not None:
    try:
        buf_keys = list(getattr(dscd, "buffers", {}).keys())[:50]
        proto_keys = list(getattr(dscd, "prototype_stores", {}).keys())[:50]
        print("\nDSCD buffer keys (sample, up to 50):", buf_keys)
        print("DSCD prototype_stores keys (sample, up to 50):", proto_keys)
        print("len buffer for '‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï':", len(dscd.buffers.get("‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï", [])))
        ps = dscd.prototype_stores.get("‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï")
        if ps is None:
            print("No prototype store for '‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï' yet.")
        else:
            try:
                print("prototype store for '‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï' .size():", ps.size())
                print("prototype store counts (if present):", getattr(ps, "counts", None))
            except Exception as e:
                print("Cannot inspect prototype store object:", repr(e))
    except Exception as e:
        print("Error inspecting dscd buffers/prototypes:", repr(e))

# 6) Try a forward with explanations (best-effort). Skip if model or device problems.
try:
    model.eval()
    device = next(model.parameters()).device
    ids = enc["input_ids"].to(device)
    attn = enc["attention_mask"].to(device)
    out = None
    # try a few common forward APIs
    if hasattr(model, "forward_with_explanations"):
        out = model.forward_with_explanations(input_ids=ids, attention_mask=attn)
    elif hasattr(model, "forward") and callable(getattr(model, "forward")):
        out = model(input_ids=ids, attention_mask=attn)
    else:
        out = model(ids, attention_mask=attn)
    print("\nModel forward success. Top-level keys in output (repr):")
    try:
        print(list(out.keys()))
    except Exception:
        print(repr(out)[:1000])

    # attempt to extract dscd outputs or proto_probs
    dscd_out = None
    if isinstance(out, dict):
        for k in ("dscd_outputs", "dscd_out", "explanations", "extra"):
            if k in out:
                print("Found key in model output:", k)
                dscd_out = out[k]
                break
    # also try attribute on model if forward didn't return
    if dscd_out is None:
        dscd_out = getattr(model, "last_dscd_outputs", None)
    print("dscd_out found?:", dscd_out is not None)
    if dscd_out:
        # try to print proto_probs & uncertainties shapes or repr
        for name in ("proto_probs", "uncertainties", "gates", "token_word_map"):
            if isinstance(dscd_out, dict) and name in dscd_out:
                v = dscd_out[name]
                try:
                    print(f"{name}: type={type(v)}, repr slice={repr(v)[:400]}")
                except Exception:
                    print(f"{name}: type={type(v)} (couldn't repr)")
except Exception as e:
    print("\nModel forward/extraction failed (this is OK if your environment can't run the model here):", repr(e))

print("\n--- END DEBUG CELL ---\n")
print("Copy the full printed output and paste it here. I will give the exact one-line fix next.")

SENTENCE: ‡¶§‡¶ø‡¶®‡¶ø ‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï ‡¶ó‡ßá‡¶õ‡ßá‡¶®‡•§

Tokenizer fast path failed: NotImplementedError('return_offset_mapping is not available when using Python tokenizers. To use this feature, change your tokenizer to one deriving from transformers.PreTrainedTokenizerFast. More information on available tokenizers at https://github.com/huggingface/transformers/pull/2674')
Tokenizer fallback also failed: NotImplementedError('return_offset_mapping is not available when using Python tokenizers. To use this feature, change your tokenizer to one deriving from transformers.PreTrainedTokenizerFast. More information on available tokenizers at https://github.com/huggingface/transformers/pull/2674')
Please tell me the tokenizer variable name and whether it's a fast tokenizer (use_fast=True).

Calling reconstruct_word_spans(...):
reconstruct_word_spans -> <class 'tuple'> ({0: '‡¶§‡¶ø‡¶®‡¶ø', 1: '‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç', 2: '‡¶¨‡ßç‡¶Ø‡¶æ‡¶Ç‡¶ï', 3: '‡¶ó', 4: '‡¶ó‡ßá‡¶õ‡ßá‡¶®', 5: '‡¶ó‡ßá‡¶õ‡ßá‡¶®‡