In [None]:

# Install the pinned SpeechBrain release whose `speechbrain.pretrained` API is used below.
# NOTE(review): torchaudio is NOT pinned; later cells import `torchaudio.sox_effects`, which
# newer torchaudio releases deprecate (and plan to remove). Consider pinning torchaudio to the
# version matching the preinstalled torch build. TODO confirm on the target image.
!pip -q install speechbrain==0.5.15 torchaudio --progress-bar off
# System ffmpeg, presumably so torchaudio can decode the compressed formats scanned later (mp3/m4a/ogg).
!apt -y -qq update && apt -y -qq install ffmpeg


In [None]:

import os, glob, random, json, math, gc, time, errno, shutil
import numpy as np
import torch, torchaudio
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from sklearn.metrics import accuracy_score, roc_curve
from collections import defaultdict
from torch.utils.data import DataLoader, Dataset
from torch.amp import autocast, GradScaler


# Device selection and global seeding of Python `random`, NumPy and torch.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
random.seed(SEED); np.random.seed(SEED); torch.manual_seed(SEED)
# cuDNN autotuning: faster convolutions, but results may not be bit-reproducible run to run.
torch.backends.cudnn.benchmark = True


# Dataset layout on the Kaggle input mount.
DATA_ROOT = "/kaggle/input/vietnam-celeb-dataset/full-dataset"
AUDIO_DIR = os.path.join(DATA_ROOT, "data")                     # root of the audio tree scanned below
TXT_E     = os.path.join(DATA_ROOT, "vietnam-celeb-e.txt")      # verification pair list "E"
TXT_H     = os.path.join(DATA_ROOT, "vietnam-celeb-h.txt")      # verification pair list "H"

# Fail fast if the audio directory is missing (message: "audio directory not found").
assert os.path.exists(AUDIO_DIR), f"Không thấy thư mục audio: {AUDIO_DIR}"
print("DEVICE:", DEVICE)
print("AUDIO_DIR:", AUDIO_DIR)


TARGET_SR         = 16000                  # all audio is resampled to this rate
MAX_SECONDS       = 2.5                    # max training clip length; longer clips are cropped
BATCH_SIZE        = 12
NUM_WORKERS       = 2                      # DataLoader workers for the training loader
MIN_UTTS_PER_SPK  = 5                      # speakers with fewer kept utterances are dropped
SPLIT_RATIO       = (0.80, 0.10, 0.10)     # per-speaker train/val/test fractions (test = remainder)


# NOTE(review): the two EER_EVAL_EVERY_* values are not referenced in the visible code;
# EER is computed every epoch.
EER_EVAL_EVERY_LP = 1
EER_EVAL_EVERY_FT = 1
MAX_PAIRS_LP      = 2000                   # pairs sampled per EER evaluation during linear probing
MAX_PAIRS_FT      = 5000                   # pairs sampled per EER evaluation during fine-tuning

print("HYPERPARAMS:",
      dict(TARGET_SR=TARGET_SR, BATCH_SIZE=BATCH_SIZE, MIN_UTTS_PER_SPK=MIN_UTTS_PER_SPK,
           SPLIT_RATIO=SPLIT_RATIO))


In [None]:
def _norm_rel(p: str) -> str:
    p = p.strip().replace("\\", "/").lstrip("./")
    if p.startswith("data/"):
        p = p.split("data/", 1)[1]
    return p 

def read_pairs_to_set(txt_path: str) -> set:
    """Collect every normalized file path mentioned in a verification pair list.

    Each non-blank, non-'#' line is split on whitespace. Lines with at least three
    fields contribute their 2nd and 3rd fields (the two utterance paths), passed
    through ``_norm_rel``. A missing file yields an empty set.
    """
    if not os.path.exists(txt_path):
        return set()
    collected = set()
    with open(txt_path, "r", encoding="utf-8") as handle:
        for raw in handle:
            fields = raw.strip().split()
            # Blank lines split to []; comment lines start with '#'; short lines carry no pair.
            if not fields or fields[0].startswith("#") or len(fields) < 3:
                continue
            collected.update((_norm_rel(fields[1]), _norm_rel(fields[2])))
    return collected


# Every file referenced by the E/H verification lists is banned from training to avoid leakage.
E_files = read_pairs_to_set(TXT_E)
H_files = read_pairs_to_set(TXT_H)
ban = E_files | H_files
print("Files in E:", len(E_files), "| Files in H:", len(H_files), "| BAN total:", len(ban))


EXTS = (".wav", ".flac", ".mp3", ".m4a", ".ogg")
# glob() yields files in filesystem order, which is not guaranteed to be stable across
# machines or runs. Sorting makes everything downstream reproducible: speaker grouping,
# the seeded split shuffle, and cohort sampling by index.
all_abs = sorted(
    p for p in glob.glob(os.path.join(AUDIO_DIR, "**", "*"), recursive=True)
    if os.path.splitext(p)[1].lower() in EXTS
)
print("Found audio in data/:", len(all_abs))

# Keep only files whose AUDIO_DIR-relative POSIX path appears in neither E nor H.
kept_abs = [
    p for p in all_abs
    if os.path.relpath(p, AUDIO_DIR).replace("\\", "/") not in ban
]

print("Kept for training (data \\ (E ∪ H)):", len(kept_abs))
# Message: "No files left to train on after removing E/H."
assert len(kept_abs) > 0, "Không còn file nào để train sau khi loại E/H."


# Group kept files by speaker; the speaker id is the name of each file's parent directory.
spk2files_all = defaultdict(list)
for wav_path in kept_abs:
    spk2files_all[os.path.basename(os.path.dirname(wav_path))].append(wav_path)

# Drop under-represented speakers, then assign contiguous class indices in sorted id order.
spk2files = {spk: utts for spk, utts in spk2files_all.items() if len(utts) >= MIN_UTTS_PER_SPK}
speakers = sorted(spk2files)
spk2idx = {spk: idx for idx, spk in enumerate(speakers)}
idx2spk = dict(enumerate(speakers))
num_classes = len(speakers)

print(f"Speakers kept (>= {MIN_UTTS_PER_SPK} utts):", num_classes)
# Message: "Too few speakers after filtering — lower MIN_UTTS_PER_SPK."
assert num_classes >= 2, "Quá ít speaker sau khi lọc — hãy hạ MIN_UTTS_PER_SPK."


# Closed-set per-speaker split: each speaker's files are shuffled and cut into
# train/val/test, so every speaker class appears in the training set.
SPLIT_TR, SPLIT_VA, _ = SPLIT_RATIO
train_pairs, val_pairs, test_pairs = [], [], []
# A dedicated RNG gives the same shuffles as `random.seed(123)` + `random.shuffle`
# without reseeding the global `random` module for later cells.
_split_rng = random.Random(123)

# Iterate speakers and their files in sorted order so the seeded shuffles do not depend
# on filesystem/glob ordering (which would make the split differ between runs).
for sid in speakers:
    files = sorted(spk2files[sid])
    _split_rng.shuffle(files)
    n = len(files)
    n_tr = max(3, int(SPLIT_TR * n))   # at least 3 training utterances
    n_va = max(1, int(SPLIT_VA * n))   # at least 1 validation utterance (before borrowing below)
    tr = files[:n_tr]
    va = files[n_tr:n_tr + n_va]
    te = files[n_tr + n_va:]
    # Guarantee a non-empty test set: borrow from val first, then from train.
    # NOTE(review): for n == 5 this empties val for that speaker (4/1/0 -> 4/0/1).
    if len(te) == 0 and len(va) > 0: te = [va.pop()]
    if len(te) == 0 and len(tr) > 1: te = [tr.pop()]

    label = spk2idx[sid]
    train_pairs += [(p, label) for p in tr]
    val_pairs   += [(p, label) for p in va]
    test_pairs  += [(p, label) for p in te]

print("Split sizes (train/val/test):", len(train_pairs), len(val_pairs), len(test_pairs))


def to_rel(p):
    """Return `p` relative to AUDIO_DIR with forward slashes (the form used by the E/H sets)."""
    return os.path.relpath(p, AUDIO_DIR).replace("\\", "/")

train_rel, val_rel, test_rel = (
    {to_rel(path) for path, _ in split} for split in (train_pairs, val_pairs, test_pairs)
)

# Leakage report: E/H files were banned earlier, so every count is expected to be 0.
for rel_set, label_e, label_h in (
    (train_rel, "train ∩ E:", "| train ∩ H:"),
    (val_rel,   " val  ∩ E:", "|  val  ∩ H:"),
    (test_rel,  " test ∩ E:", "| test  ∩ H:"),
):
    print(label_e, len(rel_set & E_files), label_h, len(rel_set & H_files))


In [None]:
from torchaudio import sox_effects
import torch.fft as fft  


# Maximum training clip length in samples (MAX_SECONDS at TARGET_SR).
MAX_LEN = int(MAX_SECONDS * TARGET_SR)
# Lazily filled cache of torchaudio Resample transforms used by load_wav.
_resamp = {}

def load_wav(path, target_sr=TARGET_SR, max_len=MAX_LEN, crop_train=True):
    """Load `path` as a mono [1, T] tensor at `target_sr`, cropped to at most `max_len` samples.

    - Multi-channel audio is averaged into one channel.
    - Resample transforms are cached in the module-level `_resamp` dict, keyed by
      ``(source_rate, target_rate)``. Keying by the source rate alone would reuse a
      transform built for a different `target_sr`.
    - Clips longer than `max_len` get a uniformly random crop when `crop_train` is
      True, otherwise a center crop.
    - Shorter clips are returned unpadded; the training collate function pads them.
    """
    wav, sr = torchaudio.load(path)
    if wav.shape[0] > 1:
        wav = wav.mean(0, keepdim=True)
    if sr != target_sr:
        key = (sr, target_sr)
        if key not in _resamp:
            _resamp[key] = torchaudio.transforms.Resample(sr, target_sr)
        wav = _resamp[key](wav)
    if wav.shape[-1] > max_len:
        if crop_train:
            start = torch.randint(0, wav.shape[-1] - max_len + 1, (1,)).item()
        else:
            start = (wav.shape[-1] - max_len) // 2
        wav = wav[:, start:start + max_len]
    return wav

def _apply_sox(wav, sr, effects):
    """Apply a SoX effect chain to a [C, T] waveform on a best-effort basis.

    Returns the processed waveform truncated to MAX_LEN samples. For mono input, the
    output is downmixed back to mono. If SoX is unavailable or the chain fails, the
    input is returned unchanged, since augmentation is optional.
    """
    try:
        out, _ = sox_effects.apply_effects_tensor(wav, sr, effects)
    except Exception:
        # Deliberate best effort: a failing augmentation must not kill the DataLoader worker.
        return wav
    # Some SoX effects (e.g. `reverb` with its default stereo depth) may emit several
    # channels from a mono input. Downmix so the later torch.cat in the shift augmentation
    # and the torch.stack in the collate function keep seeing [1, T].
    if wav.shape[0] == 1 and out.shape[0] > 1:
        out = out.mean(0, keepdim=True)
    if out.shape[-1] > MAX_LEN:
        out = out[..., :MAX_LEN]
    return out

def _spectral_gate(wav, floor=0.30):
    W = torch.hann_window(400, device=wav.device, dtype=wav.dtype)
    spec = torch.stft(wav, 400, 160, window=W, return_complex=True)
    mag = spec.abs()
    noise = mag.median(dim=-1, keepdim=True).values
    mask = (mag > noise).float() * (1 - floor) + floor
    spec_d = spec * mask
    rec = torch.istft(spec_d, 400, 160, window=W, length=wav.shape[-1])
    return rec

def maybe_denoise(wav):
    """With probability 0.30, apply `_spectral_gate(wav, floor=0.30)`; otherwise return `wav`.

    Any exception raised by the gate is swallowed and the input is returned unchanged.
    """
    if torch.rand(1).item() >= 0.30:
        return wav
    try:
        return _spectral_gate(wav, floor=0.30)
    except Exception:
        # Best effort: presumably e.g. clips too short for the STFT; keep the original signal.
        return wav

def aug_wav(wav):
    """Randomly augment a training waveform shaped [C, T] (normally [1, T]).

    Each step is applied independently with the given probability, in this order:
      * spectral-gate denoise via maybe_denoise (30 %)
      * random gain in [-6, +6] dB, clamped to [-1, 1] (50 %)
      * additive white or 3 kHz low-passed noise at 5-20 dB SNR (50 %)
      * SoX tempo change in [0.90, 1.10] (35 %)
      * SoX reverb with random reverberance / room scale (35 %)
      * random biquad: low-pass 3-4 kHz or high-pass 120-200 Hz (40 %)
      * time shift of up to +/-80 ms with zero fill, preserving length (50 %)
      * hard clipping at 0.70-0.95, then rescaling by 1/threshold (15 %)

    SoX outputs are downmixed back to mono when the input is mono. The result is
    truncated to MAX_LEN samples.
    """
    in_channels = wav.shape[0]

    def _match_channels(x):
        # SoX may return extra channels (e.g. reverb on mono). Downmix so the shift below
        # and batching in the collate function still see [1, T].
        if in_channels == 1 and x.shape[0] > 1:
            return x.mean(0, keepdim=True)
        return x

    wav = maybe_denoise(wav)

    if torch.rand(1).item() < 0.50:
        g = 10 ** (float(torch.empty(1).uniform_(-6, 6)) / 20)
        wav = (wav * g).clamp(-1, 1)

    if torch.rand(1).item() < 0.50:
        snr_db = float(torch.empty(1).uniform_(5, 20))
        sig_pwr = wav.pow(2).mean().clamp_min(1e-9)
        noise_pwr = sig_pwr / (10 ** (snr_db / 10))
        use_white = torch.rand(1).item() < 0.5
        noise = torch.randn_like(wav)
        if not use_white:
            noise = torchaudio.functional.lowpass_biquad(noise, TARGET_SR, 3000)
        # Rescale the noise so its power matches the target SNR.
        noise = noise * (noise_pwr.sqrt() / (noise.pow(2).mean().clamp_min(1e-9).sqrt()))
        wav = (wav + noise).clamp(-1, 1)

    if torch.rand(1).item() < 0.35:
        tempo = float(torch.empty(1).uniform_(0.90, 1.10))
        wav = _match_channels(_apply_sox(wav, TARGET_SR, [["tempo", f"{tempo}"]]))

    if torch.rand(1).item() < 0.35:
        rev = str(int(torch.empty(1).uniform_(10, 30).item()))
        room = str(int(torch.empty(1).uniform_(10, 30).item()))
        wav = _match_channels(_apply_sox(wav, TARGET_SR, [["reverb", rev, "50", room]]))

    if torch.rand(1).item() < 0.40:
        if torch.rand(1).item() < 0.5:
            cutoff = float(torch.empty(1).uniform_(3000, 4000))
            wav = torchaudio.functional.lowpass_biquad(wav, TARGET_SR, cutoff)
        else:
            cutoff = float(torch.empty(1).uniform_(120, 200))
            wav = torchaudio.functional.highpass_biquad(wav, TARGET_SR, cutoff)

    if torch.rand(1).item() < 0.50:
        max_shift = int(0.08 * TARGET_SR)
        shift = int(torch.randint(-max_shift, max_shift + 1, (1,)).item())
        # Clamp to the signal length. Otherwise slicing with [:-shift] past the start
        # returns an empty tensor and the zero fill would *grow* very short clips.
        length = wav.shape[-1]
        shift = max(-length, min(length, shift))
        if shift != 0:
            fill = wav.new_zeros(tuple(wav.shape[:-1]) + (abs(shift),))
            if shift > 0:
                # Delay: prepend silence and drop the tail.
                wav = torch.cat([fill, wav[..., :length - shift]], dim=-1)
            else:
                # Advance: drop the head and append silence.
                wav = torch.cat([wav[..., -shift:], fill], dim=-1)

    if torch.rand(1).item() < 0.15:
        thr = float(torch.empty(1).uniform_(0.7, 0.95))
        wav = wav.clamp(-thr, thr) / thr

    if wav.shape[-1] > MAX_LEN:
        wav = wav[..., :MAX_LEN]
    return wav


In [None]:
from speechbrain.pretrained import EncoderClassifier

# Pretrained ECAPA-TDNN speaker encoder (VoxCeleb recipe) from SpeechBrain, cached under /kaggle/working.
# NOTE(review): `.to(DEVICE)` moves the modules; SpeechBrain's from_hparams also accepts
# run_opts={"device": ...}, which may matter for its own device bookkeeping — confirm.
ecapa = EncoderClassifier.from_hparams(
    source="speechbrain/spkrec-ecapa-voxceleb",
    savedir="/kaggle/working/ecapa_vox"
).to(DEVICE)

# Linear speaker-ID classifier over the embedding. The 192 embedding size is hard-coded
# here and again in build_cohort.
head = nn.Linear(192, len(speakers)).to(DEVICE)

def forward_logits(batch_wav):
    """Return speaker-ID logits for a padded waveform batch shaped [B, 1, T].

    The ECAPA front-end (features, mean/var norm, embedding) always runs in float32
    with CUDA autocast disabled. Only the linear `head` runs under whatever autocast
    context the caller opened.

    Every utterance gets relative length 1.0, so zero padding added by the collate
    function is treated as signal.
    """
    device = next(ecapa.mods.embedding_model.parameters()).device
    signals = batch_wav.squeeze(1).contiguous().to(device, dtype=torch.float32)
    rel_lens = torch.ones(signals.size(0), device=device)

    with autocast('cuda', enabled=False):
        feats = ecapa.mods.compute_features(signals)
        if feats.device != device:
            feats = feats.to(device)
        normed = ecapa.mods.mean_var_norm(feats, rel_lens)
        emb = ecapa.mods.embedding_model(normed, rel_lens)
        # Flatten a singleton middle axis ([B, 1, D] -> [B, D]) if the model returns one.
        if emb.dim() == 3 and emb.size(1) == 1:
            emb = emb.squeeze(1)

    return head(emb)
def forward_logits_tta(batch_wav, n_crops=3, seg_len=None):
    """Test-time-augmented logits: average `forward_logits` over several time crops.

    Args:
        batch_wav: padded waveforms shaped [B, 1, T].
        n_crops: one center crop plus ``n_crops - 1`` random crops (shared across the batch).
        seg_len: crop length in samples. Defaults to MAX_LEN_VERIFY, which is defined in
            the verification cell and resolved at call time.

    Returns:
        Logits [B, num_classes], averaged over the crops.

    When ``T <= seg_len``, every crop would be the whole input. The model is then run
    once instead of ``n_crops`` identical times. The former zero-padding branch could
    never trigger, because a crop is never shorter than both T and the segment length,
    so it has been dropped.
    """
    T = batch_wav.shape[-1]
    seg = min(MAX_LEN_VERIFY if seg_len is None else seg_len, T)
    if seg >= T:
        return forward_logits(batch_wav)

    start = (T - seg) // 2
    crops = [batch_wav[..., start:start + seg]]
    for _ in range(max(0, n_crops - 1)):
        st = int(torch.randint(0, T - seg + 1, (1,)).item())
        crops.append(batch_wav[..., st:st + seg])

    logits_sum = sum(forward_logits(c) for c in crops)
    return logits_sum / float(len(crops))


In [None]:

import os, gc, time, numpy as np, torch, torchaudio, shutil
import torch.nn.functional as F
from sklearn.metrics import roc_curve
from torch.amp import autocast, GradScaler
from torch import optim

torch.backends.cudnn.benchmark = True
# Classification loss shared by the linear-probe and fine-tuning loops.
ce = nn.CrossEntropyLoss()

# Fallbacks for a partially initialized kernel; the values mirror the config cell.
# NOTE(review): this cell still needs ecapa, head, kept_abs, the splits, etc. from earlier
# cells, so these guards do not make it runnable on its own.
if 'TARGET_SR' not in globals(): TARGET_SR = 16000
if 'TXT_E' not in globals(): TXT_E = None
if 'TXT_H' not in globals(): TXT_H = None
if 'MAX_PAIRS_LP' not in globals(): MAX_PAIRS_LP = 2000
if 'MAX_PAIRS_FT' not in globals(): MAX_PAIRS_FT = 5000
# Message: "AUDIO_DIR missing from an earlier cell."
if 'AUDIO_DIR' not in globals(): raise RuntimeError("Thiếu AUDIO_DIR từ cell trước.")

# Longer window (4 s) used when embedding utterances for verification/EER and TTA.
MAX_SECONDS_VERIFY = 4.0
MAX_LEN_VERIFY     = int(MAX_SECONDS_VERIFY * TARGET_SR)

def resolve_rel(rel, root=None):
    """Map a path from a pair list to a file path under `root` (default: AUDIO_DIR).

    Before joining, the path is normalized:
    - surrounding whitespace is stripped;
    - backslashes become '/';
    - leading "./" segments and leading slashes are removed;
    - a leading "data/" component is dropped.

    Unlike ``str.lstrip("./")``, names that start with a dot are preserved.
    """
    if root is None:
        root = AUDIO_DIR
    rel = rel.strip().replace("\\", "/")
    while rel.startswith("./") or rel.startswith("/"):
        rel = rel[2:] if rel.startswith("./") else rel[1:]
    if rel.startswith("data/"):
        rel = rel[len("data/"):]
    return os.path.join(root, rel)

@torch.no_grad()
def load_wav_verify(path, target_sr=TARGET_SR, max_len=MAX_LEN_VERIFY):
    """Load audio as mono [1, T] at `target_sr`, center-cropped to at most `max_len` samples.

    NOTE(review): not called anywhere in the visible notebook; embed_path loads audio itself.
    """
    signal, rate = torchaudio.load(path)
    signal = signal.mean(0, keepdim=True) if signal.shape[0] > 1 else signal
    if rate != target_sr:
        signal = torchaudio.transforms.Resample(rate, target_sr)(signal)
    surplus = signal.shape[-1] - max_len
    if surplus > 0:
        offset = surplus // 2
        signal = signal[:, offset:offset + max_len]
    return signal  # [1,T]

@torch.no_grad()
def embed_path(path, n_crops=3):
    """Return an L2-normalized speaker embedding (NumPy vector) for one audio file.

    The file is downmixed to mono and resampled to TARGET_SR.

    - Long utterances (more than MAX_LEN_VERIFY samples) are embedded as the mean of one
      center crop plus ``n_crops - 1`` random crops of that length. Each crop embedding
      is L2-normalized; the mean is normalized again.
    - Short utterances are zero-padded to MAX_LEN_VERIFY once, with their true relative
      length passed to the model.

    Previously, short clips were padded and given relative length 1.0, so the padding
    counted as speech in normalization and pooling. The same padded clip was also
    embedded ``n_crops`` times.
    """
    model_device = next(ecapa.mods.embedding_model.parameters()).device
    wav, sr = torchaudio.load(path)
    if wav.shape[0] > 1:
        wav = wav.mean(0, keepdim=True)
    if sr != TARGET_SR:
        wav = torchaudio.transforms.Resample(sr, TARGET_SR)(wav)

    L = wav.shape[-1]
    seg = MAX_LEN_VERIFY
    if L <= seg:
        # A single padded window suffices: every crop of a short clip would be identical.
        pad = torch.zeros(1, seg - L, dtype=wav.dtype)
        crops = [torch.cat([wav, pad], dim=-1)]
        # Relative length of the real signal in [0, 1]; the ones-vector used for full
        # windows implies this convention — confirm. Clamped to ~one 10 ms hop (assumed
        # feature hop) so a near-empty file still keeps at least one frame.
        rel_len = min(1.0, max(L, TARGET_SR // 100) / seg)
    else:
        st = (L - seg) // 2
        crops = [wav[:, st:st + seg]]
        for _ in range(n_crops - 1):
            s = torch.randint(0, L - seg + 1, (1,)).item()
            crops.append(wav[:, s:s + seg])
        rel_len = 1.0

    embs = []
    for c in crops:
        x = c.to(model_device, dtype=torch.float32)  # [1, T]
        wav_lens = torch.full((1,), rel_len, device=model_device)
        with autocast('cuda', enabled=False):
            feats = ecapa.mods.compute_features(x)
            feats = ecapa.mods.mean_var_norm(feats, wav_lens)
            emb = ecapa.mods.embedding_model(feats, wav_lens)
            if emb.dim() == 3 and emb.size(1) == 1:
                emb = emb.squeeze(1)
            emb = F.normalize(emb, p=2, dim=-1)
        embs.append(emb.squeeze(0).detach().cpu().numpy())

    e = np.mean(np.stack(embs, 0), axis=0)
    return e / (np.linalg.norm(e) + 1e-9)
@torch.no_grad()
def build_cohort(paths, K=3000, seed=7):
    """Embed a random subset of up to `K` files for cohort score normalization.

    Indices are sampled without replacement with ``np.random.default_rng(seed)``.
    Paths that do not exist on disk are skipped. Returns the stacked embeddings from
    embed_path, or an empty float32 array of shape (0, 192) when nothing was embedded.
    """
    if paths is None or len(paths) == 0:
        return np.zeros((0, 192), dtype=np.float32)
    rng = np.random.default_rng(seed)
    picks = rng.choice(len(paths), size=min(K, len(paths)), replace=False)
    vectors = [embed_path(paths[i]) for i in picks if os.path.exists(paths[i])]
    if not vectors:
        return np.zeros((0, 192), dtype=np.float32)
    return np.stack(vectors, 0)

# Build the score-normalization cohort once; this is skipped if COHORT already exists in the kernel.
# NOTE(review): this runs before the linear-probe/fine-tuning loops later in this cell,
# so the cohort holds embeddings from the *pretrained* backbone and is never refreshed.
# After fine-tuning, trial embeddings and cohort embeddings come from different models.
# Consider rebuilding COHORT after training (or before each EER pass).
if 'COHORT' not in globals() or COHORT is None:
    print("Building cohort for S-Norm ...")
    COHORT = build_cohort(kept_abs, K=3000)
    print("Cohort size:", COHORT.shape)

def s_norm_score(e1, e2, cohort=None, topM=100):
    """Dot-product score of two embeddings, adjusted by cohort means.

    Computes ``e1 @ e2 - 0.5 * (m1 + m2)``, where ``m_i`` is the mean of the ``topM``
    highest cohort similarities of ``e_i``. This is a mean-only (adaptive) variant of
    S-norm: cohort scores are not divided by their standard deviation.

    Args:
        e1, e2: 1-D embedding vectors (expected L2-normalized, as embed_path returns).
        cohort: [N, D] cohort embeddings. ``None`` (default) uses the global ``COHORT``
            looked up *at call time*, so a rebuilt cohort is picked up. The old
            ``cohort=COHORT`` default froze the value at definition time. If no cohort
            is available or it is empty, the raw dot product is returned.
        topM: number of highest cohort scores averaged per side.
    """
    if cohort is None:
        cohort = globals().get("COHORT")
    raw = float(e1 @ e2)
    if cohort is None or getattr(cohort, "shape", (0,))[0] == 0:
        return raw
    c1 = cohort @ e1
    c2 = cohort @ e2
    m1 = np.sort(c1)[-min(topM, c1.shape[0]):].mean()
    m2 = np.sort(c2)[-min(topM, c2.shape[0]):].mean()
    return raw - 0.5 * (m1 + m2)
    
def read_pairs(txt_path):
    """Parse a verification pair list into ``[(label, path_a, path_b), ...]``.

    Blank lines, '#' comments and lines with fewer than three fields are skipped.
    The first field is parsed as an int label. A ``None`` or missing path yields [].
    """
    if txt_path is None or not os.path.exists(txt_path):
        return []
    with open(txt_path, "r", encoding="utf-8") as fh:
        rows = [line.split() for line in fh]
    return [
        (int(fields[0]), fields[1], fields[2])
        for fields in rows
        if len(fields) >= 3 and not fields[0].startswith("#")
    ]


@torch.no_grad()
def compute_eer_for_pairs(txt_path, max_pairs=None):
    """Equal error rate of the current model on a verification pair list.

    Args:
        txt_path: pair-list file with lines ``label path_a path_b``; paths are resolved
            with resolve_rel.
        max_pairs: if set and the list is longer, score a fixed random subset of this
            many pairs, chosen with ``np.random.default_rng(123)``.

    Returns:
        ``(eer, threshold, n_scored, n_skipped)``.
        - ``eer``/``threshold`` are None when no pair could be scored, or when the scored
          pairs contain a single label (EER is undefined and roc_curve yields NaNs).
        - Pairs with a missing file are counted in ``n_skipped``.

    Each distinct file is embedded once per call and reused across its pairs.
    Previously, a file appearing in many pairs was re-embedded, with fresh random
    crops, every time.
    """
    pairs = read_pairs(txt_path)
    if len(pairs) == 0:
        return None, None, 0, 0
    if (max_pairs is not None) and (len(pairs) > max_pairs):
        rng = np.random.default_rng(123)
        idxs = rng.choice(len(pairs), size=max_pairs, replace=False)
        pairs = [pairs[i] for i in idxs]

    emb_cache = {}

    def _embed(path):
        # Per-call memo only: embeddings change as the model trains.
        if path not in emb_cache:
            emb_cache[path] = embed_path(path)
        return emb_cache[path]

    scores, labels, skipped = [], [], 0
    for lab, a_rel, b_rel in pairs:
        pa = resolve_rel(a_rel); pb = resolve_rel(b_rel)
        if not (os.path.exists(pa) and os.path.exists(pb)):
            skipped += 1
            continue
        scores.append(s_norm_score(_embed(pa), _embed(pb)))
        labels.append(lab)

    if len(scores) == 0:
        return None, None, 0, skipped

    labels = np.array(labels); scores = np.array(scores)
    if np.unique(labels).size < 2:
        return None, None, len(scores), skipped

    fpr, tpr, th = roc_curve(labels, scores, pos_label=1)
    fnr = 1 - tpr
    # EER: operating point where miss rate and false-alarm rate are closest.
    i = np.nanargmin(np.abs(fnr - fpr))
    eer = float((fnr[i] + fpr[i]) / 2.0)
    thr = float(th[i])
    return eer, thr, len(scores), skipped

def run_epoch(dloader, train=True, opt=None, scaler=None, use_amp=True, empty_every=3):
    """Run one pass of the ECAPA backbone + linear head over `dloader`.

    Args:
        dloader: yields ``(wavs [B, 1, T], labels [B])`` batches.
        train: if True, back-propagate and step `opt`; otherwise evaluate under no_grad.
        opt: optimizer; parameters are only updated when it is given.
        scaler: GradScaler used for the backward pass when `use_amp` is True.
        use_amp: run forward + loss under CUDA autocast (only when DEVICE == "cuda").
        empty_every: every N iterations call ``torch.cuda.empty_cache()`` (CUDA only)
            and ``gc.collect()``; 0 disables this cleanup.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen.

    Notes:
        * With AMP, gradients are unscaled (``scaler.unscale_``) before clipping.
          Clipping the *scaled* gradients to norm 5.0 effectively shrank every update
          by the loss scale.
        * The backbone enters train mode only if some of its parameters are trainable.
          In the frozen linear-probe stage this keeps BatchNorm running statistics
          fixed instead of drifting toward augmented training data.
          NOTE(review): with partial unfreezing, BN in frozen blocks still updates.
        * Evaluation runs under ``torch.no_grad()`` so no autograd graph is built.
    """
    backbone_trainable = any(p.requires_grad for p in ecapa.parameters())
    if train:
        if backbone_trainable:
            ecapa.train()
        else:
            ecapa.eval()
        head.train()
    else:
        ecapa.eval();  head.eval()

    # requires_grad does not change during an epoch, so the clip list is built once.
    params = [p for p in ecapa.parameters() if p.requires_grad] + list(head.parameters())
    amp_on = use_amp and DEVICE == "cuda"

    tot, correct, loss_sum = 0, 0, 0.0
    with torch.set_grad_enabled(train):
        for it, (wavs, ys) in enumerate(dloader):
            ys = ys.to(DEVICE, non_blocking=True)
            if train and opt is not None:
                opt.zero_grad(set_to_none=True)

            if amp_on:
                with autocast('cuda'):
                    logits = forward_logits(wavs)
                    loss = ce(logits, ys)
            else:
                logits = forward_logits(wavs)
                loss = ce(logits, ys)

            if train and opt is not None:
                if use_amp and scaler is not None:
                    scaler.scale(loss).backward()
                    # Unscale in place so clipping sees the true gradient norm.
                    scaler.unscale_(opt)
                    torch.nn.utils.clip_grad_norm_(params, 5.0)
                    scaler.step(opt); scaler.update()
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(params, 5.0)
                    opt.step()

            preds = logits.argmax(1)
            bs = ys.size(0)
            correct += (preds == ys).sum().item()
            tot += bs
            loss_sum += float(loss.item()) * bs

            del wavs, ys, logits, preds, loss
            # Periodic cleanup; a full gc.collect() on every iteration is needlessly slow.
            if empty_every > 0 and it % empty_every == 0:
                if DEVICE == "cuda":
                    torch.cuda.empty_cache()
                gc.collect()

    return loss_sum / max(1, tot), correct / max(1, tot)

# Best-so-far validation accuracy and the checkpoint written whenever it improves.
best_va_acc   = -1.0
best_path_val = "/kaggle/working/ecapa_id_best_by_val.pt"

# Best-by-EER checkpoints (one per pair list) plus an alias that tracks the lowest EER seen
# on either list.
best_path_eer_e = "/kaggle/working/ecapa_id_best_by_eer_E.pt"
best_path_eer_h = "/kaggle/working/ecapa_id_best_by_eer_H.pt"
alias_best      = "/kaggle/working/ecapa_id_best.pt"

# EERs start at 1.0 (100 %), so the first measured EER always counts as an improvement.
best_eer_e, best_thr_e = 1.0, None
best_eer_h, best_thr_h = 1.0, None
GLOBAL_EER_BEST = 1.0

# Gradient scaler for mixed-precision training; None on CPU (plain fp32 backward).
scaler = GradScaler('cuda') if DEVICE == "cuda" else None
# Build the ID dataloaders unless an earlier cell already created them.
if 'dl_tr' not in globals() or 'dl_va' not in globals() or 'dl_te' not in globals():
    class VNCelebID(Dataset):
        """Map (path, label) pairs to (waveform [1, T], label).

        In train mode, clips are randomly cropped and augmented; otherwise center-cropped.
        """
        def __init__(self, pairs, train=True):
            self.pairs = pairs
            self.train = train
        def __len__(self):
            return len(self.pairs)
        def __getitem__(self, i):
            path, y = self.pairs[i]
            wav = load_wav(path, crop_train=self.train)
            if self.train:
                wav = aug_wav(wav)
            return wav, y

    def pad_batch(batch):
        """Right-pad every [C, T] waveform with zeros to the batch max length and stack to [B, C, T]."""
        wavs, ys = zip(*batch)
        maxlen = max(w.shape[-1] for w in wavs)
        out = []
        for w in wavs:
            if w.shape[-1] < maxlen:
                pad = w.new_zeros(w.shape[0], maxlen - w.shape[-1])
                w = torch.cat([w, pad], dim=-1)
            out.append(w)
        wavs = torch.stack(out, dim=0)
        ys = torch.tensor(ys, dtype=torch.long)
        return wavs, ys

    SAFE_BATCH = min(BATCH_SIZE, 12)
    train_workers = min(NUM_WORKERS, 2)
    dl_tr = DataLoader(
        VNCelebID(train_pairs, True),
        batch_size=SAFE_BATCH, shuffle=True,
        num_workers=train_workers,
        collate_fn=pad_batch,
        pin_memory=False,
        persistent_workers=False,
        # prefetch_factor is only valid with worker processes; DataLoader raises
        # ValueError if it is set while num_workers == 0.
        prefetch_factor=2 if train_workers > 0 else None
    )
    dl_va = DataLoader(
        VNCelebID(val_pairs, False),
        batch_size=SAFE_BATCH, shuffle=False,
        num_workers=0,
        collate_fn=pad_batch,
        pin_memory=False
    )
    dl_te = DataLoader(
        VNCelebID(test_pairs, False),
        batch_size=SAFE_BATCH, shuffle=False,
        num_workers=0,
        collate_fn=pad_batch,
        pin_memory=False
    )
    print("Dataloaders (bootstrap) ready. BATCH_SIZE:", SAFE_BATCH)

# Linear-probe stage: freeze the whole backbone and train only the classification head.
for p in ecapa.parameters(): p.requires_grad = False
opt1 = optim.AdamW(head.parameters(), lr=2e-3, weight_decay=1e-4)
E1 = 3  # linear-probe epochs

for ep in range(1, E1+1):
    tr_loss, tr_acc = run_epoch(dl_tr, True, opt1, scaler, use_amp=True)
    va_loss, va_acc = run_epoch(dl_va, False, None, None, use_amp=False)

    # EER on a fixed subsample of each verification list (None when the list is unavailable).
    # NOTE(review): E/H files are excluded from training, but these EERs also drive
    # checkpoint selection, so they are not an untouched final test metric.
    eer_e, thr_e, n_e, _ = compute_eer_for_pairs(TXT_E, max_pairs=MAX_PAIRS_LP) if TXT_E else (None, None, 0, 0)
    eer_h, thr_h, n_h, _ = compute_eer_for_pairs(TXT_H, max_pairs=MAX_PAIRS_LP) if TXT_H else (None, None, 0, 0)

    # Save whenever validation accuracy improves.
    if va_acc > best_va_acc:
        best_va_acc = va_acc
        torch.save({"ecapa": ecapa.state_dict(),
                    "head": head.state_dict(),
                    "spk2idx": spk2idx}, best_path_val)

    # Save the best-by-EER checkpoint for list E, and refresh the global alias if it is the best overall.
    if (eer_e is not None) and (eer_e < best_eer_e):
        best_eer_e, best_thr_e = eer_e, thr_e
        torch.save({"ecapa": ecapa.state_dict(),
                    "head": head.state_dict(),
                    "spk2idx": spk2idx,
                    "best_eer": best_eer_e,
                    "best_thr": best_thr_e,
                    "subset_pairs": n_e,
                    "split": "E"}, best_path_eer_e)
        if best_eer_e < GLOBAL_EER_BEST:
            shutil.copy(best_path_eer_e, alias_best)
            GLOBAL_EER_BEST = best_eer_e
            print(f" [alias] ecapa_id_best.pt -> E (EER={best_eer_e*100:.2f}%)")

    # Same for list H.
    if (eer_h is not None) and (eer_h < best_eer_h):
        best_eer_h, best_thr_h = eer_h, thr_h
        torch.save({"ecapa": ecapa.state_dict(),
                    "head": head.state_dict(),
                    "spk2idx": spk2idx,
                    "best_eer": best_eer_h,
                    "best_thr": best_thr_h,
                    "subset_pairs": n_h,
                    "split": "H"}, best_path_eer_h)
        if best_eer_h < GLOBAL_EER_BEST:
            shutil.copy(best_path_eer_h, alias_best)
            GLOBAL_EER_BEST = best_eer_h
            print(f" [alias] ecapa_id_best.pt -> H (EER={best_eer_h*100:.2f}%)")

    msg = f"[LP] Ep{ep:02d} | tr_acc={tr_acc:.4f} | va_acc={va_acc:.4f}"
    if eer_e is not None: msg += f" | EER_E={eer_e*100:.2f}%({n_e})"
    if eer_h is not None: msg += f" | EER_H={eer_h*100:.2f}%({n_h})"
    print(msg)

# Periodic snapshots during the fine-tuning stages (every SAVE_EVERY epochs of a stage).
SAVE_EVERY = 3
SNAP_DIR = "/kaggle/working/snaps"
os.makedirs(SNAP_DIR, exist_ok=True)

def save_snapshot(tag):
    """Save current backbone/head weights, the label map and best-so-far metrics to SNAP_DIR.

    The file is named ``ecapa_id_snap_{tag}.pt``; the saved path is printed.
    """
    snapshot = dict(
        ecapa=ecapa.state_dict(),
        head=head.state_dict(),
        spk2idx=spk2idx,
        tag=tag,
        best_va_acc=best_va_acc,
        best_eer_e=best_eer_e,
        best_eer_h=best_eer_h,
    )
    target = os.path.join(SNAP_DIR, f"ecapa_id_snap_{tag}.pt")
    torch.save(snapshot, target)
    print(" [snapshot] saved ->", target)

def set_trainable(stage="last_only"):
    """Choose which ECAPA parameters are trainable; the linear head is always trainable.

    stage='last_only': unfreeze the last entry of ``embedding_model.blocks`` plus every
        parameter whose name contains one of the pooling/output prefixes below.
    stage='all': unfreeze the whole backbone.

    Prints the number of trainable parameters.

    The previous allow-list named only ``lin``/``attention`` for the output side. If the
    encoder's final projection is called ``fc`` (assumed for SpeechBrain's ECAPA_TDNN —
    confirm via the printed count or ``named_parameters()``), it stayed frozen in
    'last_only'.
    """
    try:
        last_idx = len(ecapa.mods.embedding_model.blocks) - 1
    except Exception:
        last_idx = 4

    allow_subs = [
        f"mods.embedding_model.blocks.{last_idx}.",
        "mods.embedding_model.lin",
        "mods.embedding_model.asp",        # substring match: also covers names like "asp_bn"
        "mods.embedding_model.attention",
        "mods.embedding_model.fc.",        # final embedding projection (assumed name)
    ]
    for n, p in ecapa.named_parameters():
        p.requires_grad = True if stage == "all" else any(s in n for s in allow_subs)
    for p in head.parameters():
        p.requires_grad = True

    trainable = total = 0
    for p in list(ecapa.parameters()) + list(head.parameters()):
        total += p.numel()
        if p.requires_grad:
            trainable += p.numel()
    print(f"[set_trainable:{stage}] trainable params = {trainable:,}/{total:,}")

def get_param_groups(lr_backbone, lr_head, wd_backbone=1e-5, wd_head=1e-4):
    """Build optimizer parameter groups from the currently trainable parameters.

    Returns up to two groups, backbone first and head second, each with its own
    ``lr`` and ``weight_decay``. A group is omitted when it has no trainable parameters.
    """
    specs = (
        (ecapa, lr_backbone, wd_backbone),
        (head, lr_head, wd_head),
    )
    groups = []
    for module, lr, wd in specs:
        trainable = [p for p in module.parameters() if p.requires_grad]
        if trainable:
            groups.append({"params": trainable, "lr": lr, "weight_decay": wd})
    return groups

def _ckpt_payload(**extra):
    """Checkpoint dict holding current backbone/head weights, the label map, and `extra` metadata."""
    payload = {"ecapa": ecapa.state_dict(),
               "head": head.state_dict(),
               "spk2idx": spk2idx}
    payload.update(extra)
    return payload

def run_ft_epoch(opt, use_amp=True, tag="FT"):
    """One fine-tuning epoch: train, validate, score EER on E/H, and save improved checkpoints.

    Updates the module-level best_* trackers and GLOBAL_EER_BEST.
    - best_path_val is written when validation accuracy improves by more than 1e-4.
    - best_path_eer_e / best_path_eer_h are written when the respective EER improves.
    - An improved EER checkpoint is copied to alias_best when it beats GLOBAL_EER_BEST.

    EER checkpoints now also store ``subset_pairs`` (number of scored pairs), matching the
    linear-probe stage's checkpoints.

    Returns:
        This epoch's validation accuracy.
    """
    global best_va_acc, best_eer_e, best_thr_e, best_eer_h, best_thr_h, GLOBAL_EER_BEST

    tr_loss, tr_acc = run_epoch(dl_tr, True, opt, scaler, use_amp=use_amp)
    va_loss, va_acc = run_epoch(dl_va, False, None, None, use_amp=False)

    eer_e, thr_e, n_e, _ = compute_eer_for_pairs(TXT_E, max_pairs=MAX_PAIRS_FT) if TXT_E else (None,None,0,0)
    eer_h, thr_h, n_h, _ = compute_eer_for_pairs(TXT_H, max_pairs=MAX_PAIRS_FT) if TXT_H else (None,None,0,0)

    if va_acc > best_va_acc + 1e-4:
        best_va_acc = va_acc
        torch.save(_ckpt_payload(), best_path_val)

    if (eer_e is not None) and (eer_e < best_eer_e):
        best_eer_e, best_thr_e = eer_e, thr_e
        torch.save(_ckpt_payload(best_eer=best_eer_e, best_thr=best_thr_e,
                                 subset_pairs=n_e, split="E"), best_path_eer_e)
        if best_eer_e < GLOBAL_EER_BEST:
            shutil.copy(best_path_eer_e, alias_best)
            GLOBAL_EER_BEST = best_eer_e
            print(f" [alias] ecapa_id_best.pt -> E (EER={best_eer_e*100:.2f}%)")

    if (eer_h is not None) and (eer_h < best_eer_h):
        best_eer_h, best_thr_h = eer_h, thr_h
        torch.save(_ckpt_payload(best_eer=best_eer_h, best_thr=best_thr_h,
                                 subset_pairs=n_h, split="H"), best_path_eer_h)
        if best_eer_h < GLOBAL_EER_BEST:
            shutil.copy(best_path_eer_h, alias_best)
            GLOBAL_EER_BEST = best_eer_h
            print(f" [alias] ecapa_id_best.pt -> H (EER={best_eer_h*100:.2f}%)")

    msg = f"[{tag}] tr_acc={tr_acc:.4f} | va_acc={va_acc:.4f}"
    if eer_e is not None: msg += f" | EER_E={eer_e*100:.2f}%({n_e})"
    if eer_h is not None: msg += f" | EER_H={eer_h*100:.2f}%({n_h})"
    print(msg)
    return va_acc

# Stage A: unfreeze the last backbone block (plus set_trainable's allow-listed layers) and the head.
set_trainable("last_only")
optA = optim.AdamW(get_param_groups(lr_backbone=5e-5, lr_head=8e-4))
E_A = 2
for ep in range(1, E_A+1):
    print(f"\n[Stage A] Epoch {ep}/{E_A}")
    _ = run_ft_epoch(optA, use_amp=True, tag=f"FT-A{ep:02d}")
    # NOTE(review): with E_A = 2 and SAVE_EVERY = 3 this snapshot branch never fires.
    if ep % SAVE_EVERY == 0:
        save_snapshot(f"FT-A{ep:02d}")

# Stage B: unfreeze the whole backbone with a small backbone LR and cosine-anneal over the stage.
set_trainable("all")
optB   = optim.AdamW(get_param_groups(lr_backbone=1e-5, lr_head=3e-4))

E_B = 8
# The scheduler steps once per epoch, so T_max equals the number of Stage-B epochs and the
# LR anneals all the way down over the stage. The previous T_max=10 with 8 epochs stopped
# the cosine early.
schedB = optim.lr_scheduler.CosineAnnealingLR(optB, T_max=E_B)

for ep in range(1, E_B+1):
    print(f"\n[Stage B] Epoch {ep}/{E_B}")
    _ = run_ft_epoch(optB, use_amp=True, tag=f"FT-B{ep:02d}")
    schedB.step()
    if ep % SAVE_EVERY == 0:
        save_snapshot(f"FT-B{ep:02d}")

# Release memory held after training.
gc.collect()
if DEVICE == "cuda":
    torch.cuda.empty_cache()

# Summary of the checkpoints written during training ("sau FT 2-stage" = "after 2-stage FT").
# An EER still at its initial 100.00% with thr None means that list was never scored.
print("\n== Checkpoint summary (sau FT 2-stage) ==")
print("Best by VAL :", best_path_val, "| best_va_acc =", f"{best_va_acc:.4f}")
print("Best by EER-E:", best_path_eer_e, "| best_eer_e =", f"{best_eer_e*100:.2f}%", "| thr =", best_thr_e)
print("Best by EER-H:", best_path_eer_h, "| best_eer_h =", f"{best_eer_h*100:.2f}%", "| thr =", best_thr_h)
# "ưu tiên EER" = "EER preferred".
print("Alias (ưu tiên EER):", alias_best, "| current_best_eer =", f"{min(best_eer_e, best_eer_h)*100:.2f}%")


In [None]:
import os, shutil, json, math, torch
from sklearn.metrics import accuracy_score

# Checkpoints written by the training cell: best by EER on list E, best by EER on list H,
# and best by validation accuracy.
ckpt_e = "/kaggle/working/ecapa_id_best_by_eer_E.pt"
ckpt_h = "/kaggle/working/ecapa_id_best_by_eer_H.pt"
ckpt_v = "/kaggle/working/ecapa_id_best_by_val.pt"

def safe_load_meta(path):
    """Load a checkpoint and return ``(checkpoint, best_eer, best_thr)``.

    A missing file gives ``(None, math.inf, None)``. A checkpoint without EER metadata
    gives ``math.inf`` and ``None`` for the missing fields.
    """
    if os.path.exists(path):
        obj = torch.load(path, map_location=DEVICE)
        return obj, obj.get("best_eer", math.inf), obj.get("best_thr", None)
    return None, math.inf, None

# Read both EER checkpoints; a missing file reports EER = inf.
obj_e, eer_e, thr_e = safe_load_meta(ckpt_e)
obj_h, eer_h, thr_h = safe_load_meta(ckpt_h)

# Prefer the EER checkpoint with the lower EER (ties go to E). Fall back to the
# best-by-validation checkpoint only when neither EER checkpoint exists.
# NOTE(review): this compares EERs measured on two different pair lists. Nothing checks
# that these files come from the current run rather than an earlier one in /kaggle/working.
chosen_path, ckpt, chosen_eer, chosen_thr, chosen_split = None, None, None, None, None
if (obj_e is not None) or (obj_h is not None):
    if eer_e <= eer_h:
        chosen_path, ckpt, chosen_eer, chosen_thr, chosen_split = ckpt_e, obj_e, eer_e, thr_e, "E"
    else:
        chosen_path, ckpt, chosen_eer, chosen_thr, chosen_split = ckpt_h, obj_h, eer_h, thr_h, "H"
else:
    # Message: "No EER or VAL checkpoint found."
    assert os.path.exists(ckpt_v), "Không tìm thấy bất kỳ checkpoint EER hay VAL nào."
    chosen_path = ckpt_v
    ckpt = torch.load(ckpt_v, map_location=DEVICE)
    chosen_eer = None
    chosen_thr = None
    chosen_split = "VAL"

# Restore the chosen weights and the label map saved alongside them.
ecapa.load_state_dict(ckpt["ecapa"])
head.load_state_dict(ckpt["head"])
spk2idx = ckpt["spk2idx"]
idx2spk = {i: s for s, i in spk2idx.items()}

# Point the ecapa_id_best.pt alias at the checkpoint chosen here.
alias_path = "/kaggle/working/ecapa_id_best.pt"
shutil.copy(chosen_path, alias_path)

msg = f"Loaded: {chosen_path} -> aliased to {alias_path}"
if chosen_eer is not None:
    msg += f" | chosen_split={chosen_split} | EER={chosen_eer*100:.2f}% | thr={chosen_thr}"
print(msg)

def eval_loader(dloader):
    """Top-1 speaker-ID accuracy of backbone + head on `dloader` (single pass, no TTA).

    NOTE(review): not called in the visible notebook; eval_loader_tta is used instead.
    """
    ecapa.eval()
    head.eval()
    truth, guesses = [], []
    with torch.no_grad():
        for batch_wavs, batch_ys in dloader:
            guesses.extend(forward_logits(batch_wavs).argmax(1).cpu().numpy().tolist())
            truth.extend(batch_ys.numpy().tolist())
    return accuracy_score(truth, guesses)

def eval_loader_tta(dloader, n_crops=3):
    """Top-1 speaker-ID accuracy on `dloader`, using logits from forward_logits_tta with `n_crops` crops."""
    ecapa.eval()
    head.eval()
    truth, guesses = [], []
    with torch.no_grad():
        for batch_wavs, batch_ys in dloader:
            scores = forward_logits_tta(batch_wavs, n_crops=n_crops)
            guesses.extend(scores.argmax(1).cpu().numpy().tolist())
            truth.extend(batch_ys.numpy().tolist())
    return accuracy_score(truth, guesses)
# Final closed-set speaker-ID accuracy of the chosen checkpoint (3-crop TTA).
val_acc  = eval_loader_tta(dl_va,  n_crops=3)
test_acc = eval_loader_tta(dl_te,  n_crops=3)

print(f"VAL acc = {val_acc:.4f} | TEST acc = {test_acc:.4f}")

# Persist the class-index -> speaker-id map (JSON turns the int keys into strings).
with open("/kaggle/working/spk_map.json", "w", encoding="utf-8") as f:
    json.dump(idx2spk, f, ensure_ascii=False, indent=2)

# Record which checkpoint was chosen and, when it was EER-based, its EER and threshold.
# NOTE(review): the threshold was tuned on the same E/H lists used for model selection,
# so it is an optimistic operating point.
meta = {"chosen_path": chosen_path, "split": chosen_split}
if chosen_eer is not None:
    meta.update({"best_eer": float(chosen_eer), "best_thr": float(chosen_thr)})
with open("/kaggle/working/train_meta.json", "w", encoding="utf-8") as f:
    json.dump(meta, f, ensure_ascii=False, indent=2)

# Copy whichever artifacts exist into one folder and zip it for download.
export_dir = "/kaggle/working/export_id_model"
os.makedirs(export_dir, exist_ok=True)
for p in [
    "/kaggle/working/ecapa_id_best.pt",
    "/kaggle/working/ecapa_id_best_by_val.pt",
    "/kaggle/working/ecapa_id_best_by_eer_E.pt",
    "/kaggle/working/ecapa_id_best_by_eer_H.pt",
    "/kaggle/working/spk_map.json",
    "/kaggle/working/train_meta.json",
]:
    if os.path.exists(p):
        shutil.copy(p, export_dir)

# make_archive appends ".zip" to the base name given here.
zip_path = "/kaggle/working/ecapa_id_model"
shutil.make_archive(zip_path, 'zip', export_dir)
print("Saved zip:", zip_path + ".zip")
