# In [None]:

import os, glob, random, json, math, gc, time, shutil, itertools
import numpy as np
import torch, torchaudio
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from collections import defaultdict
from sklearn.metrics import roc_curve, accuracy_score
from torch.utils.data import DataLoader, Dataset
from torch.amp import autocast, GradScaler
from torchaudio import pipelines as TAP

# Allow reduced-precision (TF32) float32 matmuls on GPUs that support them.
torch.set_float32_matmul_precision('high')

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
# Seed Python, NumPy and torch RNGs (the data split below re-seeds `random` separately).
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
# Let cuDNN autotune kernels: faster, but not bit-for-bit reproducible.
torch.backends.cudnn.benchmark = True

# Dataset layout: audio files under <DATA_ROOT>/data, plus the E and H trial lists.
DATA_ROOT = "/kaggle/input/vietnam-celeb-dataset/full-dataset"
AUDIO_DIR = os.path.join(DATA_ROOT, "data")
TXT_E     = os.path.join(DATA_ROOT, "vietnam-celeb-e.txt")
TXT_H     = os.path.join(DATA_ROOT, "vietnam-celeb-h.txt")
# Fail fast if the audio directory is missing (message: "audio directory not found").
assert os.path.exists(AUDIO_DIR), f"Không thấy thư mục audio: {AUDIO_DIR}"
print("DEVICE:", DEVICE)
print("AUDIO_DIR:", AUDIO_DIR)

FAST_MODE           = False        # smaller EER subsets / cohorts for quicker runs
PROFILE             = "M"          # model-size preset, key into CFG
USE_WAVLM_LARGE     = False        # WavLM Large instead of WavLM Base as the SSL front-end

# Presets: Conformer width / heads / FFN size / depth, embedding size, conv kernel k,
# batch size, optimizer steps per epoch, and training crop length in seconds.
CFG = {
  "S":  dict(d_model=256, nhead=4, d_ff=1024, n_layers=6,  emb_dim=192, k=31, bs=14, steps=6000, max_seconds=2.0),
  "M":  dict(d_model=320, nhead=5, d_ff=1280, n_layers=8,  emb_dim=192, k=31, bs=12, steps=6000, max_seconds=2.0),
  "L":  dict(d_model=384, nhead=6, d_ff=1536, n_layers=10, emb_dim=256, k=41, bs=10, steps=5500, max_seconds=2.0),
  "XL": dict(d_model=512, nhead=8, d_ff=2048, n_layers=12, emb_dim=256, k=41, bs=6,  steps=5000, max_seconds=1.6),
}
P = CFG[PROFILE]

TARGET_SR           = 16000
MAX_SECONDS         = P["max_seconds"]
MAX_SECONDS_VERIFY  = 4.0       # crop length (s) for verification embeddings
BATCH_SIZE          = P["bs"]
NUM_WORKERS         = 2
MIN_UTTS_PER_SPK    = 5         # speakers with fewer kept utterances are dropped
SPLIT_RATIO         = (0.80, 0.10, 0.10)  # per-speaker train / val / test fractions


# Epoch counts: linear probing (E_LP), fine-tune stage A (E_A), fine-tune stage B (E_B, early-stopped).
E_LP = 3
E_A  = 4
E_B  = 30
STEPS_PER_EPOCH_TRAIN = P["steps"]


# EER evaluation frequency and the maximum number of trials sampled per list.
EER_EVAL_EVERY_LP = 1
EER_EVAL_EVERY_FT = 1
MAX_PAIRS_LP      = 2500 if not FAST_MODE else 1500
MAX_PAIRS_FT      = 5000 if not FAST_MODE else 3000
EER_CROP_MODE     = "tta"     # "center" = one center crop; anything else = center + random crop(s)
USE_SNORM         = True


# Score-normalization cohort sizes (initial / rebuilt after LP) and the top-M scores used.
COHORT_K_INIT  = 6000 if not FAST_MODE else 2000
COHORT_K_FT    = 10000 if not FAST_MODE else 4000
SNORM_TOPM     = 600
EARLY_STOP_PATIENCE = 2  # FT-B epochs without val-acc or EER improvement before stopping

print("HYPERPARAMS:", dict(TARGET_SR=TARGET_SR, BATCH_SIZE=BATCH_SIZE, FAST_MODE=FAST_MODE))


def _norm_rel(p: str) -> str:
    p = p.strip().replace("\\", "/").lstrip("./")
    if p.startswith("data/"): p = p.split("data/", 1)[1]
    return p

def read_pairs_to_set(txt_path: str) -> set:
    """Collect every utterance path referenced by a trial list.

    Each non-empty, non-``#`` line is split on whitespace; when it has at least
    three fields, fields 1 and 2 are normalized with ``_norm_rel`` and added.
    A missing file yields an empty set.
    """
    referenced = set()
    if not os.path.exists(txt_path):
        return referenced
    with open(txt_path, "r", encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if line and not line.startswith("#"):
                fields = line.split()
                if len(fields) >= 3:
                    referenced.add(_norm_rel(fields[1]))
                    referenced.add(_norm_rel(fields[2]))
    return referenced

# Every utterance referenced by either trial list (E or H) is banned from training,
# so verification EER is measured on audio the model never trained on.
E_files = read_pairs_to_set(TXT_E)
H_files = read_pairs_to_set(TXT_H)
ban = E_files | H_files
print("Files in E:", len(E_files), "| Files in H:", len(H_files), "| BAN total:", len(ban))

EXTS = (".wav", ".flac", ".mp3", ".m4a", ".ogg")
# glob() yields files in filesystem order, which differs between machines; sort so the
# seeded per-speaker shuffle used for the train/val/test split is reproducible.
all_abs = sorted(p for p in glob.glob(os.path.join(AUDIO_DIR, "**", "*"), recursive=True)
                 if os.path.splitext(p)[1].lower() in EXTS)
print("Found audio in data/:", len(all_abs))

# Keep files whose path relative to AUDIO_DIR is not referenced by E or H.
kept_abs = [p for p in all_abs
            if os.path.relpath(p, AUDIO_DIR).replace("\\", "/") not in ban]
print("Kept for training (data \\ (E ∪ H)):", len(kept_abs))
assert len(kept_abs) > 0

# Group kept files by speaker; the speaker id is the name of the file's parent directory.
spk2files_all = defaultdict(list)
for p in kept_abs:
    sid = os.path.basename(os.path.dirname(p))
    spk2files_all[sid].append(p)

# Drop speakers with too few utterances to split into train/val/test.
spk2files = {s: fs for s, fs in spk2files_all.items() if len(fs) >= MIN_UTTS_PER_SPK}
speakers = sorted(spk2files.keys())
# Class indices follow the sorted order of speaker ids.
spk2idx = {s: i for i, s in enumerate(speakers)}
idx2spk = {i: s for s, i in spk2idx.items()}
num_classes = len(speakers)
print("Speakers kept:", num_classes)
assert num_classes >= 2

# Per-speaker split of each speaker's utterances into train / val / test
# (closed-set: the same speakers are used as classes in every split).
SPLIT_TR, SPLIT_VA, _ = SPLIT_RATIO
train_pairs, val_pairs, test_pairs = [], [], []
random.seed(123)  # fixed split seed; the result also depends on each speaker's file order
for sid, files in spk2files.items():
    files = files[:] ; random.shuffle(files)
    n = len(files)
    n_tr = max(3, int(SPLIT_TR * n))
    n_va = max(1, int(SPLIT_VA * n))
    tr = files[:n_tr]; va = files[n_tr:n_tr + n_va]; te = files[n_tr + n_va:]
    # Ensure at least one test utterance: take it from val first, then from train.
    if len(te) == 0 and len(va) > 0: te = [va.pop()]
    if len(te) == 0 and len(tr) > 1: te = [tr.pop()]
    train_pairs += [(p, spk2idx[sid]) for p in tr]
    val_pairs   += [(p, spk2idx[sid]) for p in va]
    test_pairs  += [(p, spk2idx[sid]) for p in te]

print("Split sizes (train/val/test):", len(train_pairs), len(val_pairs), len(test_pairs))
# Sanity check: no split should contain files referenced by the E / H trial lists.
to_rel = lambda p: os.path.relpath(p, AUDIO_DIR).replace("\\", "/")
train_rel = {to_rel(p) for p, _ in train_pairs}
val_rel   = {to_rel(p) for p, _ in val_pairs}
test_rel  = {to_rel(p) for p, _ in test_pairs}
print("train ∩ E:", len(train_rel & E_files), "| train ∩ H:", len(train_rel & H_files))
print(" val  ∩ E:", len(val_rel   & E_files), "|  val  ∩ H:", len(val_rel   & H_files))
print(" test ∩ E:", len(test_rel  & E_files), "| test  ∩ H:", len(test_rel  & H_files))


# Crop lengths in samples: training crops and verification crops.
MAX_LEN        = int(MAX_SECONDS * TARGET_SR)
MAX_LEN_VERIFY = int(MAX_SECONDS_VERIFY * TARGET_SR)
_resamp = {}  # Resample transforms cached by source sample rate (used by load_wav)

def load_wav(path, target_sr=TARGET_SR, max_len=MAX_LEN, crop_train=True):
    """Load an audio file as a mono waveform at ``target_sr``, cropped to ``max_len``.

    Multi-channel audio is averaged to one channel. Resampling reuses a transform
    cached in ``_resamp`` (keyed by the source rate only). Clips longer than
    ``max_len`` samples are cropped at a random offset when ``crop_train`` is True,
    otherwise centrally; shorter clips are returned unpadded.
    """
    signal, rate = torchaudio.load(path)
    if signal.shape[0] > 1:
        signal = signal.mean(0, keepdim=True)
    if rate != target_sr:
        resampler = _resamp.get(rate)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(rate, target_sr)
            _resamp[rate] = resampler
        signal = resampler(signal)
    total = signal.shape[-1]
    if total <= max_len:
        return signal
    if crop_train:
        offset = torch.randint(0, total - max_len + 1, (1,)).item()
    else:
        offset = (total - max_len) // 2
    return signal[:, offset:offset + max_len]

def aug_wav(wav):
    """Randomly augment a training waveform; the output keeps the input length.

    Each step is applied independently with its own probability:
      * gain of -8..+6 dB (p=0.6), clamped to [-1, 1];
      * additive Gaussian noise at 5..20 dB SNR (p=0.6), clamped to [-1, 1];
      * a biquad low-pass (2.5-4.5 kHz) or high-pass (100-250 Hz) filter (p=0.4);
      * a time shift of up to ±80 ms with zero fill (p=0.5).
    The shift magnitude is limited to ``T - 1`` samples: a shift longer than the
    clip would otherwise replace it with ``shift`` zeros and change its length.
    The result is finally truncated to MAX_LEN samples.
    """
    if torch.rand(1).item() < 0.6:
        g = 10 ** (float(torch.empty(1).uniform_(-8, 6)) / 20)
        wav = (wav * g).clamp(-1, 1)
    if torch.rand(1).item() < 0.6:
        snr_db = float(torch.empty(1).uniform_(5, 20))
        sig_pwr = wav.pow(2).mean().clamp_min(1e-9)
        noise_pwr = sig_pwr / (10 ** (snr_db / 10))
        noise = torch.randn_like(wav)
        # Rescale unit noise so its power matches the requested SNR.
        noise = noise * (noise_pwr.sqrt() / (noise.pow(2).mean().clamp_min(1e-9).sqrt()))
        wav = (wav + noise).clamp(-1, 1)
    if torch.rand(1).item() < 0.4:
        if torch.rand(1).item() < 0.5:
            cutoff = float(torch.empty(1).uniform_(2500, 4500))
            wav = torchaudio.functional.lowpass_biquad(wav, TARGET_SR, cutoff)
        else:
            cutoff = float(torch.empty(1).uniform_(100, 250))
            wav = torchaudio.functional.highpass_biquad(wav, TARGET_SR, cutoff)
    if torch.rand(1).item() < 0.5:
        max_shift = int(0.08 * TARGET_SR)
        shift = int(torch.randint(-max_shift, max_shift + 1, (1,)).item())
        n = wav.shape[-1]
        # Keep at least one original sample so the clip length is preserved.
        shift = max(-(n - 1), min(n - 1, shift)) if n > 0 else 0
        if shift > 0:
            wav = torch.cat([wav.new_zeros(tuple(wav.shape[:-1]) + (shift,)), wav[..., :-shift]], dim=-1)
        elif shift < 0:
            sh = -shift
            wav = torch.cat([wav[..., sh:], wav.new_zeros(tuple(wav.shape[:-1]) + (sh,))], dim=-1)
    if wav.shape[-1] > MAX_LEN: wav = wav[..., :MAX_LEN]
    return wav

def maybe_mixup(wav, y, pool_pairs, alpha=(0.2,0.8)):
    """With probability 0.3, blend ``wav`` with a random utterance from ``pool_pairs``.

    The second clip (loaded with a random crop) is zero-padded or truncated to
    ``wav``'s length and mixed as ``a * wav + (1 - a) * other``, clamped to [-1, 1].
    The label ``y`` is returned unchanged (hard label, no soft targets), so the
    labeled clip must stay the dominant component: ``a`` is drawn uniformly from
    ``alpha`` and then folded to ``max(a, 1 - a)``. Without the fold, ``a`` as low
    as 0.2 made the *other* speaker four times louder than the labeled one.
    Returns ``(wav, y)``.
    """
    if torch.rand(1).item() >= 0.3: return wav, y
    p2, _ = random.choice(pool_pairs)
    w2 = load_wav(p2, crop_train=True)
    n = wav.shape[-1]
    if w2.shape[-1] < n:
        w2 = torch.cat([w2, w2.new_zeros(1, n - w2.shape[-1])], dim=-1)
    else:
        w2 = w2[..., :n]
    a = float(torch.empty(1).uniform_(*alpha))
    a = max(a, 1.0 - a)
    wav = (a * wav + (1 - a) * w2).clamp(-1, 1)
    return wav, y

# SpecAugment settings used by spec_augment_ on projected SSL features:
# "freq" masks cover feature channels, "time" masks cover frames; *_PARAM is the
# maximum mask width passed to torchaudio, *_MASKS the number of masks applied.
FREQ_MASK_PARAM = 18
TIME_MASK_PARAM = 32
FREQ_MASKS = 2
TIME_MASKS = 2

def spec_augment_(feat):
    """Apply SpecAugment-style masking to a (batch, time, channels) feature tensor.

    The tensor is viewed as (batch, channels, time) so torchaudio's FrequencyMasking
    masks feature channels and TimeMasking masks frames; FREQ_MASKS and TIME_MASKS
    masks are applied respectively. Returns the masked tensor in the input layout.
    """
    mask_channels = torchaudio.transforms.FrequencyMasking(freq_mask_param=FREQ_MASK_PARAM)
    mask_frames = torchaudio.transforms.TimeMasking(time_mask_param=TIME_MASK_PARAM)
    spec = feat.transpose(1, 2)
    for _ in range(FREQ_MASKS):
        spec = mask_channels(spec)
    for _ in range(TIME_MASKS):
        spec = mask_frames(spec)
    return spec.transpose(1, 2)


class VNCelebID(Dataset):
    """Speaker-ID dataset over ``(path, speaker_index)`` pairs.

    Training items are randomly cropped, augmented with ``aug_wav`` and possibly
    mixed with a clip drawn from ``pool_pairs`` (defaults to ``pairs``); evaluation
    items are center-cropped and returned without augmentation.
    """

    def __init__(self, pairs, train=True, pool_pairs=None):
        self.pairs = pairs
        self.train = train
        self.pool_pairs = pairs if pool_pairs is None else pool_pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, i):
        path, label = self.pairs[i]
        audio = load_wav(path, crop_train=self.train)
        if not self.train:
            return audio, label
        audio = aug_wav(audio)
        audio, label = maybe_mixup(audio, label, self.pool_pairs)
        return audio, label

def pad_batch(batch):
    """Collate ``(waveform, label)`` items into a zero-padded batch.

    Waveforms of shape (1, samples) are right-padded with zeros to the longest one.
    Returns ``(wavs, labels)`` with shapes (B, 1, T_max) and (B,) (dtype long).
    """
    waves, labels = zip(*batch)
    longest = max(w.shape[-1] for w in waves)
    padded = [
        w if w.shape[-1] >= longest
        else torch.cat([w, torch.zeros(1, longest - w.shape[-1])], dim=-1)
        for w in waves
    ]
    return torch.stack(padded, dim=0), torch.tensor(labels, dtype=torch.long)

SAFE_BATCH = BATCH_SIZE
# Training loader: shuffled, augmented items (VNCelebID train=True), zero-padded per batch by pad_batch.
dl_tr = DataLoader(
    VNCelebID(train_pairs, True, pool_pairs=train_pairs),
    batch_size=SAFE_BATCH, shuffle=True,
    num_workers=min(NUM_WORKERS, 2),
    collate_fn=pad_batch,
    pin_memory=False, persistent_workers=False,
    prefetch_factor=2
)
# Validation / test loaders: center-cropped, unaugmented, loaded in the main process.
dl_va = DataLoader(VNCelebID(val_pairs, False),  batch_size=SAFE_BATCH, shuffle=False, num_workers=0, collate_fn=pad_batch)
dl_te = DataLoader(VNCelebID(test_pairs, False), batch_size=SAFE_BATCH, shuffle=False, num_workers=0, collate_fn=pad_batch)
print("Dataloaders ready. BATCH_SIZE:", SAFE_BATCH)

def iterate_steps(dloader, steps=None):
    """Yield batches from ``dloader``.

    With ``steps=None`` one pass is made. Otherwise exactly ``steps`` batches are
    yielded, starting a fresh ``iter(dloader)`` whenever it is exhausted.

    ``itertools.cycle`` is deliberately not used: it stores a copy of every element
    from the first pass, which for a DataLoader means holding thousands of audio
    batches in memory and then replaying the same shuffled/augmented batches
    instead of drawing new ones. An empty ``dloader`` ends iteration immediately
    rather than looping forever (or raising inside the generator).
    """
    if steps is None:
        yield from dloader
        return
    produced = 0
    while produced < steps:
        empty_pass = True
        for batch in dloader:
            empty_pass = False
            yield batch
            produced += 1
            if produced >= steps:
                return
        if empty_pass:
            return


class AAMSoftmax(nn.Module):
    """Additive angular margin (ArcFace-style) classification head.

    ``forward(emb, y)`` returns ``s * cos(theta_j)`` for non-target classes and
    ``s * cos(theta_y + m)`` for the target class ``y``, where ``theta_j`` is the
    angle between ``emb`` and the L2-normalized weight column of class ``j``.
    ``emb`` is used as given (callers pass L2-normalized embeddings).

    The cosine/margin math runs in float32 with autocast disabled: in fp16 the
    clamp bound ``1 - 1e-7`` rounds to exactly 1.0, so ``acos`` could receive 1.0
    and produce infinite gradients. When ``theta_y + m`` would exceed pi (where
    ``cos(theta + m)`` starts increasing again), the usual ArcFace fallback
    ``cos(theta) - m * sin(m)`` keeps the target logit decreasing with the angle.
    """
    def __init__(self, emb_dim, num_classes, s=30.0, m=0.3):
        super().__init__()
        self.W = nn.Parameter(torch.randn(emb_dim, num_classes))
        nn.init.xavier_normal_(self.W)
        self.s = s; self.m = m

    def forward(self, emb, y):
        with torch.autocast(device_type=emb.device.type, enabled=False):
            W = F.normalize(self.W.float(), dim=0)
            cosine = emb.float() @ W
            theta = torch.acos(torch.clamp(cosine, -1 + 1e-7, 1 - 1e-7))
            target = torch.cos(theta + self.m)
            # Beyond theta = pi - m the margin would raise the logit instead of lowering it.
            fallback = cosine - self.m * math.sin(self.m)
            target = torch.where(theta + self.m <= math.pi, target, fallback)
            y_onehot = F.one_hot(y, num_classes=W.size(1)).to(cosine.dtype)
            logits = cosine * (1 - y_onehot) + target * y_onehot
            return logits * self.s

class FFN(nn.Module):
    """Pre-norm feed-forward module added back to the input at half weight.

    Returns ``x + 0.5 * net(x)`` with
    ``net = LayerNorm -> Linear(d_model, d_ff) -> SiLU -> Dropout -> Linear(d_ff, d_model) -> Dropout``.
    """

    def __init__(self, d_model, d_ff, p=0.1):
        super().__init__()
        stages = [
            nn.LayerNorm(d_model),
            nn.Linear(d_model, d_ff),
            nn.SiLU(),
            nn.Dropout(p),
            nn.Linear(d_ff, d_model),
            nn.Dropout(p),
        ]
        self.net = nn.Sequential(*stages)
        self.scale = 0.5

    def forward(self, x):
        update = self.net(x)
        return x + self.scale * update

class MHSA(nn.Module):
    """Pre-norm multi-head self-attention with a residual connection.

    ``key_padding_mask`` (True marks a padded frame) is forwarded unchanged to
    ``nn.MultiheadAttention``; inputs are (batch, time, d_model).
    """

    def __init__(self, d_model, nhead, p=0.1):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=p, batch_first=True)
        self.dropout = nn.Dropout(p)

    def forward(self, x, key_padding_mask=None):
        normed = self.ln(x)
        attended = self.attn(normed, normed, normed,
                             key_padding_mask=key_padding_mask, need_weights=False)[0]
        return x + self.dropout(attended)

class ConvModule(nn.Module):
    """Conformer convolution module with a residual connection.

    LayerNorm -> pointwise conv (2x channels) -> GLU -> depthwise conv (kernel ``k``)
    -> BatchNorm -> SiLU -> pointwise conv -> dropout, added back to the input.
    Input and output are (batch, time, d_model).

    ``mask`` (batch, time; True = padded frame) zeroes padded frames *before* the
    depthwise convolution as well as in the output, so padding values cannot leak
    into valid frames through the ``k``-wide temporal kernel.
    """
    def __init__(self, d_model, k=31, p=0.1):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.pw1 = nn.Conv1d(d_model, 2*d_model, kernel_size=1)
        self.dw  = nn.Conv1d(d_model, d_model, kernel_size=k, padding=k//2, groups=d_model)
        self.bn  = nn.BatchNorm1d(d_model)
        self.swish = nn.SiLU()
        self.pw2 = nn.Conv1d(d_model, d_model, kernel_size=1)
        self.dropout = nn.Dropout(p)

    def forward(self, x, mask=None):
        y = self.ln(x).transpose(1, 2)          # (B, C, T) for Conv1d
        y = F.glu(self.pw1(y), dim=1)
        if mask is not None:
            # Zero padded frames before the temporal (depthwise) convolution.
            y = y.masked_fill(mask.unsqueeze(1), 0.0)
        y = self.dw(y); y = self.bn(y); y = self.swish(y); y = self.pw2(y)
        y = y.transpose(1, 2)
        y = self.dropout(y)
        if mask is not None:
            y = y.masked_fill(mask.unsqueeze(-1), 0.0)
        return x + y

class ConformerBlock(nn.Module):
    """One Conformer block: FFN -> self-attention -> convolution -> FFN -> LayerNorm.

    ``key_padding_mask`` (True = padded frame) is shared by the attention and
    convolution sub-modules. Inputs and outputs are (batch, time, d_model).
    """

    def __init__(self, d_model, nhead, d_ff, k=31, p=0.1):
        super().__init__()
        self.ff1 = FFN(d_model, d_ff, p)
        self.mha = MHSA(d_model, nhead, p)
        self.conv = ConvModule(d_model, k, p)
        self.ff2 = FFN(d_model, d_ff, p)
        self.ln = nn.LayerNorm(d_model)

    def forward(self, x, key_padding_mask=None):
        h = self.ff1(x)
        h = self.mha(h, key_padding_mask)
        h = self.conv(h, key_padding_mask)
        return self.ln(self.ff2(h))

def lengths_to_mask(lengths, max_len=None):
    """Build a (batch, max_len) boolean padding mask from a 1-D tensor of lengths.

    Entry ``[b, t]`` is True when ``t >= lengths[b]`` (a padded position).
    ``max_len`` defaults to the largest length.
    """
    width = int(lengths.max().item()) if max_len is None else int(max_len)
    positions = torch.arange(width, device=lengths.device)
    return positions[None, :] >= lengths[:, None]

class MFAConformerEmbedder(nn.Module):
    """Speaker embedder: SSL front-end -> projection -> Conformer stack -> stats pooling.

    The SSL model comes from ``bundle.get_model()`` and is frozen at construction.
    All feature tensors returned by ``ssl.extract_features`` are concatenated on the
    feature axis and projected to ``d_model`` (LazyLinear infers the input width on
    the first forward). The outputs of all Conformer blocks are concatenated
    (multi-layer feature aggregation), pooled over time into mean and standard
    deviation, and mapped by a 2-layer MLP to an L2-normalized ``emb_dim`` vector.
    """

    def __init__(self, bundle, d_model, nhead, d_ff, n_layers, emb_dim, p=0.1, k=31):
        super().__init__()
        self.ssl = bundle.get_model()
        for p_ in self.ssl.parameters(): p_.requires_grad = False
        self.ms_proj = nn.LazyLinear(d_model)
        self.layers  = nn.ModuleList([ConformerBlock(d_model, nhead, d_ff, k=k, p=p) for _ in range(n_layers)])
        # Wider MLP hidden layer when the pooled statistics vector is large.
        hidden_fc = 1024 if (2 * n_layers * d_model >= 8192) else 512
        self.emb_head = nn.Sequential(
            nn.Linear(2 * d_model * n_layers, hidden_fc), nn.ReLU(),
            nn.Linear(hidden_fc, emb_dim)
        )
        self.n_layers = n_layers
        self.d_model  = d_model

    def forward(self, wav, apply_specaug=True):
        """Embed a waveform batch of shape (B, T) or (B, 1, T); returns (B, emb_dim) unit vectors.

        SpecAugment is applied to the projected features only in training mode and
        only when ``apply_specaug`` is True.
        """
        if wav.dim() == 3:
            wav = wav.squeeze(1)

        with autocast('cuda', enabled=(DEVICE=='cuda')):
            feats_list, lengths = self.ssl.extract_features(wav)  # list of (B, T, Ci) or a tensor
            hs = torch.cat(feats_list, dim=-1) if isinstance(feats_list, list) else feats_list
            hs = self.ms_proj(hs)  # (B, T, d_model)

        hs = hs.float()  # leave the AMP region in float32

        # No input lengths are passed to extract_features, so `lengths` is presumably
        # None here and every frame counts as valid — TODO confirm if lengths are added.
        T = hs.size(1)
        if lengths is None:
            lens = torch.full((hs.size(0),), T, device=hs.device, dtype=torch.long)
        else:
            lens = lengths.to(hs.device, dtype=torch.long).clamp_max(T)
        pad_mask = lengths_to_mask(lens, max_len=T)

        # Light SpecAugment on SSL features (train only)
        if self.training and apply_specaug:
            hs = spec_augment_(hs)

        h = hs
        hidden = []
        for layer in self.layers:
            h = layer(h, key_padding_mask=pad_mask)
            hidden.append(h)
        ms = torch.cat(hidden, dim=-1)  # (B, T, n_layers * d_model)

        if pad_mask.any():
            valid = (~pad_mask).float().unsqueeze(-1)
            sumv  = (ms * valid).sum(dim=1)
            cnt   = valid.sum(dim=1).clamp_min(1.0)
            mean  = sumv / cnt
            # Unbiased (N - 1) variance and a 1e-5 floor on the std, matching
            # `ms.std(dim=1).clamp_min(1e-5)` below, so pooled statistics do not
            # depend on whether the batch happens to contain padding.
            var   = ((ms - mean.unsqueeze(1))**2 * valid).sum(dim=1) / (cnt - 1.0).clamp_min(1.0)
            std   = var.clamp_min(1e-10).sqrt()
        else:
            mean = ms.mean(dim=1)
            std  = ms.std(dim=1).clamp_min(1e-5)
        stats = torch.cat([mean, std], dim=-1)

        emb = self.emb_head(stats)
        return F.normalize(emb, p=2, dim=-1)


# Build the embedder on a WavLM bundle (Large or Base) plus the two classifier heads.
bundle = TAP.WAVLM_LARGE if USE_WAVLM_LARGE else TAP.WAVLM_BASE
mfa = MFAConformerEmbedder(
    bundle=bundle,
    d_model=P["d_model"], nhead=P["nhead"], d_ff=P["d_ff"],
    n_layers=P["n_layers"], emb_dim=P["emb_dim"], k=P["k"]
).to(DEVICE)
head_ce  = nn.Linear(P["emb_dim"], num_classes).to(DEVICE)  # plain softmax head (LP / hybrid loss)
head_aam = AAMSoftmax(emb_dim=P["emb_dim"], num_classes=num_classes, s=30.0, m=0.3).to(DEVICE)  # margin head

@torch.no_grad()
def logits_aam_infer(emb):
    """Margin-free AAM logits for prediction: ``s * (emb @ normalize(W))`` over all classes."""
    weight = F.normalize(head_aam.W, dim=0)
    scores = emb @ weight
    return scores * head_aam.s

def forward_logits_ce(batch_wav):
    """Embed a waveform batch with ``mfa`` and return the softmax head's logits."""
    embeddings = mfa(batch_wav.to(DEVICE, dtype=torch.float32))
    return head_ce(embeddings)

def forward_logits_aam(batch_wav, y):
    """Return ``(margin AAM logits for labels y, embeddings)`` for a waveform batch."""
    embeddings = mfa(batch_wav.to(DEVICE, dtype=torch.float32))
    logits = head_aam(embeddings, y)
    return logits, embeddings

@torch.no_grad()
def embed_batch(batch_wav):
    """Return L2-normalized embeddings for a waveform batch as a NumPy array.

    ``mfa`` is switched to eval mode for the forward pass (no dropout, no
    SpecAugment, BatchNorm uses running statistics) and then restored to its
    previous mode. Otherwise embeddings computed while the model is still in
    training mode — e.g. the initial S-norm cohort, built right after the model
    is constructed — are noisy, and every single-utterance forward updates the
    BatchNorm running statistics.
    """
    was_training = mfa.training
    mfa.eval()
    try:
        e = mfa(batch_wav.to(DEVICE, dtype=torch.float32), apply_specaug=False)
    finally:
        mfa.train(was_training)
    return e.detach().cpu().numpy()


def resolve_rel(rel):
    """Map a trial-list path to an absolute path under AUDIO_DIR.

    Delegates normalization to ``_norm_rel`` so trial paths are resolved exactly
    the way they were normalized when building the training ban list.
    """
    return os.path.join(AUDIO_DIR, _norm_rel(rel))

EMBED_CACHE = {}  # embeddings cached by embed_path; keys include MODEL_REV and the path
MODEL_REV = 0     # incremented after each training epoch so stale cache keys are not hit

@torch.no_grad()
def embed_path(path, n_crops=2):
    """Return a unit-norm embedding for one audio file (cached in EMBED_CACHE).

    The file is downmixed to mono and resampled to TARGET_SR, reusing the cached
    resamplers in ``_resamp`` instead of building a new one per file. Clips no
    longer than MAX_LEN_VERIFY samples are zero-padded to that length and embedded
    once. Longer clips use a center crop; unless EER_CROP_MODE is "center", another
    ``n_crops - 1`` random crops are added (``n_crops`` was previously ignored;
    the default of 2 keeps the old center + one random crop). Crop embeddings
    are computed in one batch, averaged and re-normalized.
    """
    key = (MODEL_REV, EER_CROP_MODE, path, n_crops)
    if key in EMBED_CACHE: return EMBED_CACHE[key]
    wav, sr = torchaudio.load(path)
    if wav.shape[0] > 1: wav = wav.mean(0, keepdim=True)
    if sr != TARGET_SR:
        if sr not in _resamp: _resamp[sr] = torchaudio.transforms.Resample(sr, TARGET_SR)
        wav = _resamp[sr](wav)
    L = wav.shape[-1]; seg = MAX_LEN_VERIFY
    if L <= seg:
        # Every crop of a short clip would be identical: embed the padded clip once.
        crops = [torch.cat([wav, torch.zeros(1, seg - L)], dim=-1)]
    else:
        st = (L - seg) // 2
        crops = [wav[:, st:st + seg]]
        if EER_CROP_MODE != "center":
            for _ in range(max(1, n_crops) - 1):
                r = int(torch.randint(0, L - seg + 1, (1,)).item())
                crops.append(wav[:, r:r + seg])
    embs = embed_batch(torch.stack(crops, dim=0))  # (n_crops, emb_dim)
    e = embs.mean(axis=0); e = e / (np.linalg.norm(e) + 1e-9)
    EMBED_CACHE[key] = e
    return e

def read_pairs(txt_path):
    """Parse a trial list into ``(label, path_a, path_b)`` tuples.

    Blank lines, ``#`` comments and lines with fewer than three fields are
    skipped; the first field is parsed with ``int``. Returns ``[]`` when
    ``txt_path`` is None or does not exist.
    """
    if txt_path is None or not os.path.exists(txt_path):
        return []
    trials = []
    with open(txt_path, "r", encoding="utf-8") as fh:
        for raw in fh:
            line = raw.strip()
            if not line or line.startswith("#"):
                continue
            fields = line.split()
            if len(fields) >= 3:
                trials.append((int(fields[0]), fields[1], fields[2]))
    return trials

print("Building cohort (initial) ...")
def build_cohort(paths, K=2000, seed=7):
    """Embed up to ``K`` files sampled from ``paths`` to serve as an S-norm cohort.

    Indices are drawn without replacement from a NumPy generator seeded with
    ``seed``; files that no longer exist are skipped. Returns the stacked
    embeddings, or an empty float32 array of shape (0, emb_dim) if none exist.
    """
    rng = np.random.default_rng(seed)
    if not paths:
        return np.zeros((0, P["emb_dim"]), dtype=np.float32)
    picks = rng.choice(len(paths), size=min(K, len(paths)), replace=False)
    rows = [embed_path(paths[i]) for i in picks if os.path.exists(paths[i])]
    if not rows:
        return np.zeros((0, P["emb_dim"]), dtype=np.float32)
    return np.stack(rows, 0)

# Initial S-norm cohort, embedded with the freshly constructed (not yet trained) model.
COHORT = build_cohort(kept_abs, K=COHORT_K_INIT)
print("Cohort size:", COHORT.shape)

def s_norm_score(e1, e2, cohort=None, topM=SNORM_TOPM):
    """Cosine score of two unit-norm embeddings with cohort mean normalization.

    Returns ``e1·e2 - 0.5 * (m1 + m2)``, where ``m1`` / ``m2`` are the means of the
    ``topM`` highest cohort scores of ``e1`` / ``e2`` (a mean-only adaptive
    normalization; cohort standard deviations are not used). Falls back to the
    raw cosine when USE_SNORM is off or the cohort is empty.

    ``cohort=None`` means "the current global COHORT". The former default
    ``cohort=COHORT`` was evaluated once at definition time, so the cohort
    rebuilt after linear probing was never used for scoring.
    """
    if cohort is None:
        cohort = COHORT
    raw = float(e1 @ e2)
    if (not USE_SNORM) or (cohort is None) or getattr(cohort, "shape", (0,))[0] == 0:
        return raw
    k = min(topM, cohort.shape[0])
    c1 = cohort @ e1; c2 = cohort @ e2
    # np.partition selects the top-k scores without sorting the whole cohort.
    m1 = np.partition(c1, -k)[-k:].mean()
    m2 = np.partition(c2, -k)[-k:].mean()
    return raw - 0.5 * (m1 + m2)

@torch.no_grad()
def compute_eer_for_pairs(txt_path, max_pairs=None):
    """Equal error rate of S-normed scores on a trial list.

    Reads ``(label, path_a, path_b)`` trials (label 1 is the positive class),
    optionally subsamples ``max_pairs`` of them with a fixed seed, skips trials
    whose files are missing, and scores the rest with ``embed_path`` +
    ``s_norm_score``. The EER is read at the ROC point where the false-negative
    and false-positive rates are closest.

    Returns ``(eer, threshold, n_scored, n_skipped)``. ``eer`` and ``threshold``
    are None when no trial could be scored, or when the scored trials contain a
    single label (EER is undefined and roc_curve would give NaN rates).
    """
    pairs = read_pairs(txt_path)
    if len(pairs) == 0: return None, None, 0, 0
    if (max_pairs is not None) and (len(pairs) > max_pairs):
        rng = np.random.default_rng(123)
        idxs = rng.choice(len(pairs), size=max_pairs, replace=False)
        pairs = [pairs[i] for i in idxs]
    scores, labels, skipped = [], [], 0
    for lab, a_rel, b_rel in pairs:
        pa = resolve_rel(a_rel); pb = resolve_rel(b_rel)
        if not (os.path.exists(pa) and os.path.exists(pb)):
            skipped += 1; continue
        ea = embed_path(pa); eb = embed_path(pb)
        scores.append(s_norm_score(ea, eb)); labels.append(lab)
    if len(scores) == 0: return None, None, 0, skipped
    labels = np.array(labels); scores = np.array(scores)
    if np.unique(labels).size < 2:
        return None, None, len(scores), skipped
    fpr, tpr, th = roc_curve(labels, scores, pos_label=1)
    fnr = 1 - tpr; i = np.nanargmin(np.abs(fnr - fpr))
    eer = float((fnr[i] + fpr[i]) / 2.0); thr = float(th[i])
    return eer, thr, len(scores), skipped


ce = nn.CrossEntropyLoss()
# A disabled GradScaler is a pass-through (scale() returns the loss unchanged,
# unscale_() does nothing, step() just calls optimizer.step()), so the training
# loops also run on CPU instead of failing on `None.scale(...)`.
scaler = GradScaler('cuda', enabled=(DEVICE == "cuda"))

def run_epoch_lp(dloader, opt, steps=None):
    """Train ``mfa`` and the softmax head ``head_ce`` for one linear-probing epoch.

    Runs ``steps`` optimizer steps (one pass over ``dloader`` when None) under
    mixed precision. Gradients are unscaled with ``scaler.unscale_(opt)`` before
    clipping, so the 5.0 max-norm applies to the real gradients rather than the
    loss-scaled ones. Returns (mean loss, accuracy) over the processed samples,
    measured on each batch before its update.
    """
    mfa.train(); head_ce.train(); head_aam.eval()
    clip_params = list(mfa.parameters()) + list(head_ce.parameters())
    tot, correct, loss_sum = 0, 0, 0.0
    for it, (wavs, ys) in enumerate(iterate_steps(dloader, steps)):
        ys = ys.to(DEVICE, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        with autocast('cuda', enabled=(DEVICE=='cuda')):
            logits = forward_logits_ce(wavs)
            loss = ce(logits, ys)
        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(clip_params, 5.0)
        scaler.step(opt); scaler.update()
        preds = logits.argmax(1); bs = ys.size(0)
        correct += (preds == ys).sum().item(); tot += bs
        loss_sum += float(loss.item()) * bs
        if DEVICE == "cuda" and (it % 50 == 0): torch.cuda.empty_cache()
    return loss_sum / max(1, tot), correct / max(1, tot)

def run_epoch_hybrid(dloader, opt, lambda_aam=0.7, steps=None):
    """Train with ``(1 - lambda_aam) * CE(head_ce) + lambda_aam * CE(head_aam)``.

    Both heads share the same embeddings (SpecAugment enabled). Gradients are
    unscaled with ``scaler.unscale_(opt)`` before clipping to max-norm 5.0, so the
    threshold applies to the real gradients, not the loss-scaled ones.
    Returns (mean loss, softmax-head accuracy) over the processed samples.
    """
    mfa.train(); head_ce.train(); head_aam.train()
    clip_params = list(mfa.parameters()) + list(head_ce.parameters()) + list(head_aam.parameters())
    tot, correct, loss_sum = 0, 0, 0.0
    for it, (wavs, ys) in enumerate(iterate_steps(dloader, steps)):
        ys = ys.to(DEVICE, non_blocking=True)
        opt.zero_grad(set_to_none=True)
        x = wavs.to(DEVICE, dtype=torch.float32)
        with autocast('cuda', enabled=(DEVICE=='cuda')):
            e = mfa(x, apply_specaug=True)
            logits_ce  = head_ce(e)
            logits_aam = head_aam(e, ys)
            loss = (1 - lambda_aam) * ce(logits_ce, ys) + lambda_aam * ce(logits_aam, ys)
        scaler.scale(loss).backward()
        scaler.unscale_(opt)
        torch.nn.utils.clip_grad_norm_(clip_params, 5.0)
        scaler.step(opt); scaler.update()
        preds = logits_ce.argmax(1)
        bs = ys.size(0)
        correct += (preds == ys).sum().item(); tot += bs
        loss_sum += float(loss.item()) * bs
        if DEVICE == "cuda" and (it % 50 == 0): torch.cuda.empty_cache()
    return loss_sum / max(1, tot), correct / max(1, tot)

def run_epoch_aam(dloader, opt=None, train=True, steps=None):
    """Train (``train=True``) or evaluate the embedder with the AAM head.

    Training runs ``steps`` optimizer steps with SpecAugment; gradients are
    unscaled with ``scaler.unscale_(opt)`` before clipping to max-norm 5.0 so the
    threshold applies to the real gradients. Evaluation makes one pass with the
    models in eval mode and reports a loss of 0. Accuracy always uses the
    margin-free logits from ``logits_aam_infer``.
    Returns (mean loss, accuracy) over the processed samples.
    """
    if train: mfa.train(); head_aam.train(); head_ce.eval()
    else:     mfa.eval();  head_aam.eval();  head_ce.eval()
    clip_params = (list(mfa.parameters()) + list(head_aam.parameters())) if train else []
    tot, correct, loss_sum = 0, 0, 0.0
    for it, (wavs, ys) in enumerate(iterate_steps(dloader, steps if train else None)):
        ys = ys.to(DEVICE, non_blocking=True)
        if train: opt.zero_grad(set_to_none=True)
        x = wavs.to(DEVICE, dtype=torch.float32)
        with autocast('cuda', enabled=(DEVICE=='cuda')):
            e = mfa(x, apply_specaug=train)
            logits_train = head_aam(e, ys)
            loss = ce(logits_train, ys) if train else torch.tensor(0., device=x.device)
        if train:
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(clip_params, 5.0)
            scaler.step(opt); scaler.update()
        with torch.no_grad():
            preds = logits_aam_infer(e).argmax(1)
        bs = ys.size(0)
        correct += (preds == ys).sum().item(); tot += bs
        loss_sum += float(loss.item()) * bs
        if DEVICE == "cuda" and (it % 50 == 0): torch.cuda.empty_cache()
    return loss_sum / max(1, tot), correct / max(1, tot)

def _set_requires(module, flag: bool):
    for p in module.parameters(): p.requires_grad = flag

def _unfreeze_ssl_last_k(ssl_module, k=8):  
    try:
        enc = ssl_module.encoder
        layers = list(enc.transformer.layers)
        L = len(layers)
        for i, layer in enumerate(layers):
            req = (i >= L - k)
            for p in layer.parameters(): p.requires_grad = req
        print(f"Unfroze last-{k} WavLM layers.")
    except Exception:
        _set_requires(ssl_module, True)
        print("Fallback: unfroze ALL SSL layers.")

def set_trainable_lp():
    """Linear-probing stage: freeze the SSL model and the AAM head; train the
    projection, Conformer stack, embedding head and softmax head."""
    _set_requires(mfa.ssl, False)
    for part in (mfa.ms_proj, *mfa.layers, mfa.emb_head, head_ce):
        _set_requires(part, True)
    _set_requires(head_aam, False)

def set_trainable_ftA():
    """Fine-tune stage A: SSL model frozen; everything downstream plus both heads trainable."""
    _set_requires(mfa.ssl, False)
    for part in (mfa.ms_proj, *mfa.layers, mfa.emb_head, head_ce, head_aam):
        _set_requires(part, True)

def set_trainable_ftB():
    """Fine-tune stage B: unfreeze the last 8 SSL encoder layers, train the
    downstream modules and the AAM head, and freeze the softmax head."""
    _unfreeze_ssl_last_k(mfa.ssl, k=8)
    for part in (mfa.ms_proj, *mfa.layers, mfa.emb_head):
        _set_requires(part, True)
    _set_requires(head_ce, False)
    _set_requires(head_aam, True)


# Checkpoint files: best by validation accuracy, best by EER on E and on H,
# plus an alias that tracks the best EER over both lists.
best_va_acc   = -1.0
best_path_val = "/kaggle/working/mfa_conf_best_by_val.pt"
best_path_eer_e = "/kaggle/working/mfa_conf_best_by_eer_E.pt"
best_path_eer_h = "/kaggle/working/mfa_conf_best_by_eer_H.pt"
alias_best      = "/kaggle/working/mfa_conf_best.pt"

# Best EER so far per list (1.0 = 100%, i.e. nothing recorded yet) and its threshold.
best_eer_e, best_thr_e = 1.0, None
best_eer_h, best_thr_h = 1.0, None
GLOBAL_EER_BEST = 1.0  # best EER across both lists; decides when alias_best is refreshed

def save_state(path, split=None, eer=None, thr=None):
    """Save the embedder, both heads and the speaker map to ``path``.

    When ``split`` is given, the EER, threshold and split name are stored too.
    """
    payload = {
        "mfa": mfa.state_dict(),
        "head_ce": head_ce.state_dict(),
        "head_aam": head_aam.state_dict(),
        "spk2idx": spk2idx,
    }
    if split is not None:
        payload["best_eer"] = eer
        payload["best_thr"] = thr
        payload["split"] = split
    torch.save(payload, path)

def eval_eer_checkpoint(tag="LP", max_pairs_E=MAX_PAIRS_LP, max_pairs_H=MAX_PAIRS_LP):
    """Clear the embedding cache, compute EER on the E and H lists, and print it.

    Returns ``(eer_e, thr_e, eer_h, thr_h)``; an entry is None when its EER
    could not be computed.
    """
    global EMBED_CACHE
    EMBED_CACHE = {}
    unavailable = (None, None, 0, 0)
    eer_e, thr_e, n_e, _ = compute_eer_for_pairs(TXT_E, max_pairs=max_pairs_E) if TXT_E else unavailable
    eer_h, thr_h, n_h, _ = compute_eer_for_pairs(TXT_H, max_pairs=max_pairs_H) if TXT_H else unavailable
    parts = [f"[{tag}]"]
    if eer_e is not None:
        parts.append(f" EER_E={eer_e*100:.2f}%({n_e})")
    if eer_h is not None:
        parts.append(f" | EER_H={eer_h*100:.2f}%({n_h})")
    print("".join(parts))
    return eer_e, thr_e, eer_h, thr_h

def maybe_update_best(eer_e, thr_e, eer_h, thr_h):
    """Checkpoint the model for each trial list whose EER improved.

    E and H are handled in turn: when the new EER beats that list's stored best,
    the EER and threshold are recorded and the model is saved to the list's
    checkpoint; if it also beats the best EER across both lists, that checkpoint
    is copied to ``alias_best``. Returns True if any per-list best was updated.
    """
    global best_eer_e, best_thr_e, best_eer_h, best_thr_h, GLOBAL_EER_BEST
    updated = False
    if (eer_e is not None) and (eer_e < best_eer_e):
        best_eer_e, best_thr_e = eer_e, thr_e
        save_state(best_path_eer_e, split="E", eer=best_eer_e, thr=best_thr_e)
        updated = True
        # Refresh the overall-best alias only if E now holds the best EER seen on either list.
        if best_eer_e < GLOBAL_EER_BEST:
            shutil.copy(best_path_eer_e, alias_best)
            GLOBAL_EER_BEST = best_eer_e
            print(f" [alias] mfa_conf_best.pt -> E (EER={best_eer_e*100:.2f}%)")
    if (eer_h is not None) and (eer_h < best_eer_h):
        best_eer_h, best_thr_h = eer_h, thr_h
        save_state(best_path_eer_h, split="H", eer=best_eer_h, thr=best_thr_h)
        updated = True
        if best_eer_h < GLOBAL_EER_BEST:
            shutil.copy(best_path_eer_h, alias_best)
            GLOBAL_EER_BEST = best_eer_h
            print(f" [alias] mfa_conf_best.pt -> H (EER={best_eer_h*100:.2f}%)")
    return updated



# Linear-probing stage: SSL and AAM head frozen; optimize the remaining trainable parameters.
set_trainable_lp()
opt_lp = optim.AdamW(
    [p for p in list(mfa.parameters()) + list(head_ce.parameters()) if p.requires_grad],
    lr=2e-3, weight_decay=1e-4
)

def build_warmup_cosine(optimizer, warmup_steps, total_steps, min_lr=1e-6, base_lr=2e-3):
    """LambdaLR with linear warmup followed by cosine decay.

    The LR multiplier rises linearly to 1 over ``warmup_steps`` scheduler steps,
    then follows ``0.5 * (1 + cos(pi * progress))`` with a floor of
    ``min_lr / base_lr`` (the floor equals ``min_lr`` only if ``base_lr`` matches
    the optimizer's LR). One unit is one ``scheduler.step()`` call.
    """
    floor = min_lr / base_lr
    warm_len = max(1, warmup_steps)
    decay_len = max(1, total_steps - warmup_steps)

    def lr_lambda(step):
        if step >= warmup_steps:
            progress = (step - warmup_steps) / decay_len
            return max(floor, 0.5 * (1.0 + math.cos(math.pi * progress)))
        return (step + 1) / warm_len

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# The LP loop below calls sched_lp.step() once per *epoch*, so the schedule is sized
# in epochs. Sizing it in iterations (E_LP * STEPS_PER_EPOCH_TRAIN, 5% warmup) while
# stepping per epoch left the multiplier at (epoch + 1) / warmup_steps, i.e. roughly
# 0.1-0.3% of the base LR for the entire linear-probing stage.
total_steps_lp = E_LP * STEPS_PER_EPOCH_TRAIN  # optimizer iterations in the LP stage
sched_lp = build_warmup_cosine(opt_lp, warmup_steps=0, total_steps=E_LP, base_lr=2e-3)

MODEL_REV = 0
step_ctr = 0
for ep in range(1, E_LP+1):
    tr_loss, tr_acc = run_epoch_lp(dl_tr, opt_lp, steps=STEPS_PER_EPOCH_TRAIN)
    step_ctr += STEPS_PER_EPOCH_TRAIN
    # The LR scheduler is advanced once per epoch, not per optimizer step.
    sched_lp.step()


    # Validation accuracy with the softmax head (eval mode; at most 300 batches in FAST_MODE).
    mfa.eval(); head_ce.eval()
    with torch.no_grad():
        tot, correct = 0, 0
        for wavs, ys in iterate_steps(dl_va, steps=300 if FAST_MODE else None):
            ys = ys.to(DEVICE, non_blocking=True)
            logits = forward_logits_ce(wavs)
            correct += (logits.argmax(1) == ys).sum().item()
            tot += ys.size(0)
        va_acc = correct / max(1, tot)
    if va_acc > best_va_acc: best_va_acc = va_acc; save_state(best_path_val)
    if ep % EER_EVAL_EVERY_LP == 0:
        eer_e, thr_e, eer_h, thr_h = eval_eer_checkpoint(tag=f"LP-Ep{ep}", max_pairs_E=MAX_PAIRS_LP, max_pairs_H=MAX_PAIRS_LP)
        maybe_update_best(eer_e, thr_e, eer_h, thr_h)
    # New model revision: embedding cache keys from this epoch will not be reused.
    MODEL_REV += 1
    print(f"[LP] Ep{ep:02d} | tr_acc={tr_acc:.4f} | va_acc={va_acc:.4f}")


# Re-embed a larger S-norm cohort with the linear-probed model, dropping cached embeddings.
print("Rebuilding cohort for S-Norm after LP ...")
EMBED_CACHE = {}; COHORT = build_cohort(kept_abs, K=COHORT_K_FT)
print("Cohort size (rebuild):", COHORT.shape)


# Fine-tune stage A: SSL frozen; train everything else with the hybrid CE + AAM loss.
set_trainable_ftA()
paramsA = [p for p in list(mfa.parameters())+list(head_ce.parameters())+list(head_aam.parameters()) if p.requires_grad]
optA = optim.AdamW(paramsA, lr=5e-5, weight_decay=1e-5)
total_steps_fta = E_A * STEPS_PER_EPOCH_TRAIN
# Stepped once per epoch; with T_max=max(10, E_A) the cosine does not reach its
# minimum within E_A epochs when E_A < 10.
schedA = torch.optim.lr_scheduler.CosineAnnealingLR(optA, T_max=max(10, E_A))

for ep in range(1, E_A+1):
    tr_loss, tr_acc = run_epoch_hybrid(dl_tr, optA, lambda_aam=0.7, steps=STEPS_PER_EPOCH_TRAIN)
    # Validation accuracy from margin-free AAM logits.
    _, va_acc = run_epoch_aam(dl_va, opt=None, train=False)
    if va_acc > best_va_acc + 1e-4: best_va_acc = va_acc; save_state(best_path_val)
    eer_e, thr_e, eer_h, thr_h = eval_eer_checkpoint(tag=f"FT-A{ep:02d}", max_pairs_E=MAX_PAIRS_FT, max_pairs_H=MAX_PAIRS_FT)
    maybe_update_best(eer_e, thr_e, eer_h, thr_h)
    schedA.step(); MODEL_REV += 1
    print(f"[FT-A] Ep{ep:02d} | tr_acc={tr_acc:.4f} | va_acc={va_acc:.4f}")


# Fine-tune stage B: last 8 SSL encoder layers unfrozen, softmax head frozen, AAM loss only.
set_trainable_ftB()
paramsB = [p for p in list(mfa.parameters())+list(head_aam.parameters()) if p.requires_grad]
optB   = optim.AdamW(paramsB, lr=8e-6, weight_decay=1e-5)
schedB = optim.lr_scheduler.CosineAnnealingLR(optB, T_max=max(E_B, 20))

# Early stopping: stop after EARLY_STOP_PATIENCE consecutive epochs in which neither
# validation accuracy nor any per-list EER improved.
no_improve = 0
for ep in range(1, E_B+1):
    tr_loss, tr_acc = run_epoch_aam(dl_tr, optB, train=True, steps=STEPS_PER_EPOCH_TRAIN)
    _, va_acc = run_epoch_aam(dl_va, opt=None, train=False)
    improved = False
    if va_acc > best_va_acc + 2e-4:
        best_va_acc = va_acc
        save_state(best_path_val)
        improved = True

    eer_e, thr_e, eer_h, thr_h = eval_eer_checkpoint(
        tag=f"FT-B{ep:02d}",
        max_pairs_E=MAX_PAIRS_FT,
        max_pairs_H=MAX_PAIRS_FT
    )
    if maybe_update_best(eer_e, thr_e, eer_h, thr_h):
        improved = True

    schedB.step(); MODEL_REV += 1
    no_improve = 0 if improved else (no_improve + 1)
    print(f"[FT-B] Ep{ep:02d} | va_acc={va_acc:.4f} | best_eer={min(best_eer_e, best_eer_h)*100:.2f}% | no_improve={no_improve}")
    if no_improve >= EARLY_STOP_PATIENCE:
        print("Early stop FT-B due to no improvement.")
        break

# Release memory held by Python objects and the CUDA caching allocator.
gc.collect()
if DEVICE == "cuda": torch.cuda.empty_cache()

print("\n== Checkpoint summary (MFA-Conformer) ==")
print("Best by VAL :", best_path_val, "| best_va_acc =", f"{best_va_acc:.4f}")
print("Best by EER-E:", best_path_eer_e, "| best_eer_e =", f"{best_eer_e*100:.2f}%", "| thr =", best_thr_e)
print("Best by EER-H:", best_path_eer_h, "| best_eer_h =", f"{best_eer_h*100:.2f}%", "| thr =", best_thr_h)
# "ưu tiên EER" = "EER preferred": the alias points at the best-EER checkpoint.
print("Alias (ưu tiên EER):", alias_best, "| current_best_eer =", f"{min(best_eer_e, best_eer_h)*100:.2f}%")


@torch.no_grad()
def forward_logits_tta_aam_infer(batch_wav, n_crops=2):
    """Average margin-free AAM logits over test-time crops of a (B, 1, T) batch.

    Uses a center crop plus ``n_crops - 1`` random crops of MAX_LEN_VERIFY samples
    (each random offset is shared by the whole batch); crops shorter than
    MAX_LEN_VERIFY are right-padded with zeros. When the batch fits in one crop
    every crop would be identical, so it is embedded once. ``n_crops`` used to be
    ignored (always two crops); the default of 2 keeps that behavior.
    The caller is responsible for putting the models in eval mode.
    """
    B, _, T = batch_wav.shape
    seg = min(MAX_LEN_VERIFY, T)
    if T <= seg:
        crops = [batch_wav]
    else:
        st = (T - seg) // 2
        crops = [batch_wav[..., st:st + seg]]
        for _ in range(max(1, n_crops) - 1):
            r = int(torch.randint(0, T - seg + 1, (1,)).item())
            crops.append(batch_wav[..., r:r + seg])
    logits_sum = 0
    for c in crops:
        if c.shape[-1] < MAX_LEN_VERIFY:
            pad = torch.zeros(B, 1, MAX_LEN_VERIFY - c.shape[-1], device=c.device, dtype=c.dtype)
            c = torch.cat([c, pad], dim=-1)
        x = c.to(DEVICE, dtype=torch.float32)
        e = mfa(x, apply_specaug=False)
        logits_sum = logits_sum + logits_aam_infer(e)
    return logits_sum / float(len(crops))

def eval_loader_aam_tta(dloader, n_crops=2):
    """Closed-set speaker-ID accuracy over ``dloader`` from TTA-averaged AAM logits."""
    mfa.eval(); head_aam.eval()
    truth, predicted = [], []
    with torch.no_grad():
        for wavs, ys in dloader:
            logits = forward_logits_tta_aam_infer(wavs, n_crops=n_crops)
            predicted.extend(logits.argmax(1).cpu().tolist())
            truth.extend(ys.tolist())
    return accuracy_score(truth, predicted)

def safe_load_meta(path):
    """Load a checkpoint if it exists.

    Returns ``(checkpoint, best_eer, best_thr)``. A missing file gives
    ``(None, math.inf, None)``; a checkpoint without EER metadata reports
    ``math.inf`` and ``None`` for those fields.
    """
    if os.path.exists(path):
        ckpt = torch.load(path, map_location=DEVICE)
        return ckpt, ckpt.get("best_eer", math.inf), ckpt.get("best_thr", None)
    return None, math.inf, None

# Choose the checkpoint with the lower stored EER (E wins ties); if neither EER
# checkpoint exists, fall back to the best-by-validation-accuracy checkpoint.
obj_e, eer_e, thr_e = safe_load_meta(best_path_eer_e)
obj_h, eer_h, thr_h = safe_load_meta(best_path_eer_h)
if (obj_e is not None and eer_e <= eer_h):
    ckpt = obj_e; chosen_path = best_path_eer_e; chosen_split="E"; chosen_eer=eer_e; chosen_thr=thr_e
elif obj_h is not None:
    ckpt = obj_h; chosen_path = best_path_eer_h; chosen_split="H"; chosen_eer=eer_h; chosen_thr=thr_h
else:
    ckpt = torch.load(best_path_val, map_location=DEVICE)
    chosen_path = best_path_val; chosen_split="VAL"; chosen_eer=None; chosen_thr=None

# Restore weights and the speaker map stored with the chosen checkpoint.
mfa.load_state_dict(ckpt["mfa"]); head_ce.load_state_dict(ckpt["head_ce"]); head_aam.load_state_dict(ckpt["head_aam"])
spk2idx = ckpt["spk2idx"]; idx2spk = {i: s for s, i in spk2idx.items()}

shutil.copy(chosen_path, alias_best)
msg = f"Loaded: {chosen_path} -> aliased to {alias_best}"
if chosen_eer is not None:
    msg += f" | chosen_split={chosen_split} | EER={chosen_eer*100:.2f}% | thr={chosen_thr}"
print(msg)

# Closed-set accuracy of the restored model with AAM logits and 2-crop TTA.
val_acc  = eval_loader_aam_tta(dl_va,  n_crops=2)
test_acc = eval_loader_aam_tta(dl_te,  n_crops=2)
print(f"VAL acc (AAM+TTA) = {val_acc:.4f} | TEST acc (AAM+TTA) = {test_acc:.4f}")


# Write the index -> speaker-id map (JSON turns the int keys into strings).
with open("/kaggle/working/spk_map.json", "w", encoding="utf-8") as f:
    json.dump(idx2spk, f, ensure_ascii=False, indent=2)
# Metadata about the chosen checkpoint; assumes chosen_thr is set whenever
# chosen_eer is (both come from the same checkpoint) — TODO confirm.
meta = {"chosen_path": chosen_path, "split": chosen_split}
if chosen_eer is not None: meta.update({"best_eer": float(chosen_eer), "best_thr": float(chosen_thr)})
with open("/kaggle/working/train_meta_mfa_conf.json", "w", encoding="utf-8") as f:
    json.dump(meta, f, ensure_ascii=False, indent=2)

# Collect checkpoints and metadata into one directory and zip it.
export_dir = "/kaggle/working/export_id_model_mfa_conf"
os.makedirs(export_dir, exist_ok=True)
for p in [alias_best, best_path_val, best_path_eer_e, best_path_eer_h,
          "/kaggle/working/spk_map.json", "/kaggle/working/train_meta_mfa_conf.json"]:
    if os.path.exists(p): shutil.copy(p, export_dir)

zip_path = "/kaggle/working/mfa_conformer_model_improved"
shutil.make_archive(zip_path, 'zip', export_dir)
print("Saved zip:", zip_path + ".zip")
