# In [None]:

import os, glob, random, gc, shutil
from collections import defaultdict
import numpy as np
import torch, torchaudio
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
from sklearn.metrics import roc_curve

# ---- Runtime device and global seeding -----------------------------------
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SEED = 42
# Seed every RNG the script draws from: Python `random` (crops, shuffles,
# augmentation), NumPy, and torch (weight init, noise).
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# cuDNN autotuning: faster convolutions, but may pick nondeterministic kernels.
torch.backends.cudnn.benchmark = True
print(" DEVICE:", DEVICE)


# ---- Dataset layout --------------------------------------------------------
DATA_ROOT = "/kaggle/input/vietnam-celeb-dataset/full-dataset"
AUDIO_DIR = os.path.join(DATA_ROOT, "data")                # audio files live here
TXT_E = os.path.join(DATA_ROOT, "vietnam-celeb-e.txt")     # trial list "E"
TXT_H = os.path.join(DATA_ROOT, "vietnam-celeb-h.txt")     # trial list "H"

# File extensions (compared lower-cased) that count as audio.
EXTS = (".wav", ".flac", ".mp3", ".m4a", ".ogg")
# Number of Mel bands: feature extractor output and model input width.
N_MELS = 80

def _norm_rel(p: str) -> str:
    """Normalise a trial-list path to a forward-slash path under the audio root.

    Surrounding whitespace is stripped and backslashes become ``/``. Any leading
    ``./`` segments and leading slashes are removed, and a single leading
    ``data/`` component is dropped.

    >>> _norm_rel("./data/id00001/00001.wav")
    'id00001/00001.wav'
    """
    p = p.strip().replace("\\", "/")
    # Strip "./" and "/" prefixes explicitly. str.lstrip("./") removes a
    # *set of characters*, which would also eat the dot of ".name" or "../".
    while p.startswith(("./", "/")):
        p = p[1:] if p.startswith("/") else p[2:]
    if p.startswith("data/"):
        p = p[len("data/"):]
    return p

def read_pairs(txt_path):
    """Parse a verification trial list into ``(label, rel_a, rel_b)`` tuples.

    Each line is whitespace-split. Lines with at least three fields whose first
    field parses as an ``int`` are kept, and both paths are normalised with
    ``_norm_rel``. Blank lines, short lines and lines with a non-integer label
    are skipped. A missing file yields an empty list.
    """
    pairs = []
    if not os.path.exists(txt_path):
        return pairs
    # "utf-8-sig" transparently drops a leading BOM. With plain "utf-8" the BOM
    # stays glued to the first label, int() fails and that trial is lost.
    with open(txt_path, "r", encoding="utf-8-sig") as f:
        for ln in f:
            parts = ln.strip().split()
            if len(parts) < 3:
                continue
            try:
                lab = int(parts[0])
            except ValueError:
                continue  # header or malformed label
            pairs.append((lab, _norm_rel(parts[1]), _norm_rel(parts[2])))
    return pairs

# Trial lists for the two evaluation sets; each entry is (label, rel_a, rel_b).
E_pairs = read_pairs(TXT_E)
H_pairs = read_pairs(TXT_H)
# Every file referenced by either trial list is excluded from training.
E_files = {p for _, a, b in E_pairs for p in [a,b]}
H_files = {p for _, a, b in H_pairs for p in [a,b]}
ban = E_files | H_files
# NOTE(review): only the exact trial files are banned; other utterances of the
# same speakers remain in the training pool. If evaluation speakers should be
# disjoint from training speakers, ban by speaker directory too — confirm
# against the dataset's protocol.

# Recursively collect files under AUDIO_DIR whose lower-cased extension is in EXTS.
all_abs = [p for p in glob.glob(os.path.join(AUDIO_DIR, "**", "*"), recursive=True)
           if os.path.splitext(p)[1].lower() in EXTS]
# Keep files whose AUDIO_DIR-relative, forward-slash path is not a trial file.
# (Assumes trial-list paths, after _norm_rel, are relative to AUDIO_DIR.)
kept_abs = [p for p in all_abs
            if os.path.relpath(p, AUDIO_DIR).replace("\\", "/") not in ban]
print(f" Total audio: {len(all_abs)} | Train usable: {len(kept_abs)}")


# Group files by speaker. The speaker id is the name of each file's parent
# directory (presumably data/<speaker>/<file> — TODO confirm).
spk2files_all = defaultdict(list)
for p in kept_abs:
    sid = os.path.basename(os.path.dirname(p))
    spk2files_all[sid].append(p)
# Drop speakers with fewer than 10 usable files.
spk2files = {s: fs for s, fs in spk2files_all.items() if len(fs) >= 10}
# Sorting the ids gives a stable speaker -> class-index mapping.
speakers = sorted(spk2files.keys())
spk2idx = {s: i for i, s in enumerate(speakers)}
num_classes = len(spk2idx)
print(f" Speakers for training: {num_classes}")


# Per speaker: shuffle the file list in place (global `random` state) and keep
# the first floor(90%) as (path, class_index) training samples.
# NOTE(review): the remaining ~10% per speaker is never used afterwards — no
# validation split is built from it. Either use it for validation or train on all.
train_pairs = []
for sid, files in spk2files.items():
    random.shuffle(files)
    n = len(files)
    n_tr = int(0.9 * n)
    train_pairs += [(p, spk2idx[sid]) for p in files[:n_tr]]
print(f"Train samples: {len(train_pairs)}")


# Mel-spectrogram front end used by both VoiceDataset and embed_file. Settings:
# 16 kHz input, 400-point FFT (window length defaults to n_fft, i.e. 25 ms),
# 160-sample hop (10 ms), N_MELS bands. Callers take log(x + 1e-6) themselves.
MEL_TRANSFORM = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,
    n_fft=400,
    hop_length=160,
    n_mels=N_MELS,
)

def is_valid_audio(path):
    """Return True if *path* decodes and is not (near-)silent.

    The file is loaded, down-mixed to mono, resampled to 16 kHz, and accepted
    when its mean absolute amplitude exceeds 1e-4. Any error while loading or
    resampling yields False, so one corrupt file cannot abort a long scan.
    """
    try:
        wav, sr = torchaudio.load(path)
        if wav.shape[0] > 1:
            wav = wav.mean(0, keepdim=True)
        if sr != 16000:
            wav = torchaudio.functional.resample(wav, sr, 16000)
        return wav.abs().mean().item() > 1e-4
    except Exception:
        # Catch ``Exception`` rather than using a bare ``except`` so that
        # Ctrl-C / SystemExit still stop the scan instead of marking every
        # remaining file invalid.
        return False

# Drop unreadable or (near-)silent files up front so training never loads them.
train_pairs = [pair for pair in train_pairs if is_valid_audio(pair[0])]
print(f" After cleaning: {len(train_pairs)} valid samples")

class VoiceDataset(Dataset):
    """Fixed-length log-Mel dataset over ``(path, label)`` pairs.

    Each item is loaded, down-mixed to mono and resampled to ``sr``. It is then
    randomly cropped (if longer) or right-zero-padded (if shorter) to
    ``max_len`` seconds. With ``augment`` it optionally gets Gaussian noise and
    a random gain. Returns ``(log_mel, label)``, where ``log_mel`` is
    ``log(MEL_TRANSFORM(wav) + 1e-6)`` with the channel axis squeezed and
    NaN/inf replaced.
    """

    def __init__(self, pairs, augment=False, sr=16000, max_len=3.0):
        self.pairs = pairs
        self.augment = augment
        self.sr = sr
        self.max_len = max_len

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        path, label = self.pairs[idx]
        audio, native_sr = torchaudio.load(path)
        if audio.shape[0] > 1:
            audio = audio.mean(0, keepdim=True)  # down-mix to mono
        if native_sr != self.sr:
            audio = torchaudio.functional.resample(audio, native_sr, self.sr)

        target = int(self.max_len * self.sr)
        excess = audio.size(1) - target
        if excess > 0:
            # Random crop of exactly `target` samples.
            start = random.randint(0, excess)
            audio = audio[:, start:start + target]
        else:
            # Right-pad with zeros up to `target` samples.
            audio = F.pad(audio, (0, -excess))

        if self.augment:
            if random.random() < 0.5:
                audio += 0.005 * torch.randn_like(audio)
            if random.random() < 0.5:
                audio *= random.uniform(0.8, 1.2)

        log_mel = torch.log(MEL_TRANSFORM(audio) + 1e-6).squeeze(0)
        return torch.nan_to_num(log_mel), label

BATCH = 16
# drop_last=True: the embedding head (MFAConformer.post) ends in BatchNorm1d,
# which raises in training mode on a batch of size 1. Without it, the final
# partial batch is size 1 whenever len(train_pairs) % BATCH == 1.
# pin_memory speeds host->GPU copies and is only useful on CUDA.
dl_tr = DataLoader(
    VoiceDataset(train_pairs, True),
    batch_size=BATCH,
    shuffle=True,
    num_workers=2,
    drop_last=True,
    pin_memory=(DEVICE == "cuda"),
)

class Swish(nn.Module):
    """Element-wise Swish (SiLU) activation: ``x * sigmoid(x)``."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

class ConvModule(nn.Module):
    """Conformer convolution sub-layer with a residual connection.

    Operates on ``(batch, time, dim)`` input: LayerNorm -> pointwise Conv1d to
    ``2*dim`` -> GLU back to ``dim`` -> depthwise Conv1d (kernel ``k``) ->
    BatchNorm -> Swish -> pointwise Conv1d. The result is added to the input.
    """

    def __init__(self, dim, k=15):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.pw1 = nn.Conv1d(dim, 2*dim, 1)
        # padding="same" keeps the time length for any k. The former
        # padding=k//2 is identical for odd k, but with even k the output
        # grew by one frame and the residual add below failed.
        self.dw  = nn.Conv1d(dim, dim, k, padding="same", groups=dim)
        self.bn  = nn.BatchNorm1d(dim)
        self.act = Swish()
        self.pw2 = nn.Conv1d(dim, dim, 1)

    def forward(self, x):
        y = self.ln(x).transpose(1, 2)       # -> (batch, dim, time) for Conv1d
        y = F.glu(self.pw1(y), dim=1)        # gate 2*dim channels back to dim
        y = self.pw2(self.act(self.bn(self.dw(y)))).transpose(1, 2)
        return x + y

class FeedForwardModule(nn.Module):
    """Pre-norm feed-forward sub-layer with a half-weighted residual.

    Computes ``x + 0.5 * ff(ln(x))``, where ``ff`` is
    Linear(dim -> exp*dim) -> Swish -> Dropout -> Linear(exp*dim -> dim) -> Dropout.
    """

    def __init__(self, dim, exp=8, drop=0.2):
        super().__init__()
        hidden = exp * dim
        self.ln = nn.LayerNorm(dim)
        # Keep this exact Sequential layout: checkpoints use keys ff.0.* / ff.3.*.
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden),
            Swish(),
            nn.Dropout(drop),
            nn.Linear(hidden, dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        update = self.ff(self.ln(x))
        return x + 0.5 * update

class MHSA(nn.Module):
    """Pre-norm multi-head self-attention sub-layer with a residual connection.

    Args:
        dim: feature width of the ``(batch, time, dim)`` input (batch_first).
        h: number of attention heads.
        drop: dropout on the attention weights and on the attention output.
    """

    def __init__(self, dim, h=4, drop=0.2):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, h, dropout=drop, batch_first=True)
        self.do = nn.Dropout(drop)

    def forward(self, x):
        # LayerNorm is computed once and shared by query, key and value.
        z = self.ln(x)
        # The attention weights are discarded, so don't ask for them.
        y, _ = self.attn(z, z, z, need_weights=False)
        return x + self.do(y)

class ConformerBlock(nn.Module):
    """One Conformer layer over ``(batch, time, dim)`` features.

    Applies, in order: feed-forward -> multi-head self-attention -> convolution
    module -> feed-forward. Each sub-module adds its own residual connection.
    NOTE(review): there is no LayerNorm after ``ff2``, unlike the reference
    Conformer layer. Adding one would change the state_dict.
    """

    def __init__(self, dim, h=4, k=15, exp=8, drop=0.2):
        super().__init__()
        # Construction order fixes how much init RNG each sub-module consumes,
        # and the state_dict key order. Keep: ff1, mhsa, conv, ff2.
        self.ff1 = FeedForwardModule(dim, exp, drop)
        self.mhsa = MHSA(dim, h, drop)
        self.conv = ConvModule(dim, k)
        self.ff2 = FeedForwardModule(dim, exp, drop)

    def forward(self, x):
        for stage in (self.ff1, self.mhsa, self.conv, self.ff2):
            x = stage(x)
        return x

class MFAConformer(nn.Module):
    """Multi-scale feature aggregation (MFA) Conformer speaker encoder.

    Input ``x`` is ``(batch, n_mels, time)``; it is transposed before the
    ``n_mels -> dim`` projection. The outputs of all ``L`` Conformer blocks are
    concatenated on the feature axis and layer-normalised. They are then pooled
    over time into a mean and a standard deviation (clamped at 1e-6).
    Linear + BatchNorm1d maps the ``2*dim*L`` pooled vector to ``emb_dim``, and
    the result is L2-normalised per row.
    """

    def __init__(self, n_mels=N_MELS, dim=256, L=6, h=4, k=15, exp=8, drop=0.2, emb_dim=192):
        super().__init__()
        # Keep construction order (proj, blocks, ln_mfa, post) for identical
        # init RNG consumption and checkpoint compatibility.
        self.proj = nn.Linear(n_mels, dim)
        self.blocks = nn.ModuleList([ConformerBlock(dim, h, k, exp, drop) for _ in range(L)])
        self.ln_mfa = nn.LayerNorm(dim * L)
        self.post = nn.Sequential(nn.Linear(dim * L * 2, emb_dim), nn.BatchNorm1d(emb_dim))

    def forward(self, x):
        z = self.proj(x.transpose(1, 2))          # (batch, time, dim)
        per_layer = []
        for block in self.blocks:
            z = block(z)
            per_layer.append(z)
        agg = self.ln_mfa(torch.cat(per_layer, dim=-1))   # (batch, time, dim*L)
        mu = agg.mean(1)
        sigma = agg.std(1).clamp(min=1e-6)
        emb = torch.nan_to_num(self.post(torch.cat([mu, sigma], 1)))
        return F.normalize(emb, p=2, dim=1)

def compute_eer(scores, labels):
    """Return the equal error rate (EER), in percent, of binary trial scores.

    Builds the ROC curve with ``roc_curve(labels, scores)`` and picks the
    operating point where |FNR - FPR| is smallest (NaNs ignored). It reports the
    mean of FPR and FNR there, times 100. With single-class ``labels`` the ROC
    rates are undefined, and ``nanargmin`` may raise.
    """
    false_pos, true_pos, _thresholds = roc_curve(labels, scores)
    false_neg = 1 - true_pos
    best = np.nanargmin(np.abs(false_neg - false_pos))
    return float((false_pos[best] + false_neg[best]) / 2 * 100)

def safe_cosine(a, b):
    """Cosine similarity of two 1-D tensors as a Python float, or None.

    NaN/inf entries are zeroed and both vectors are L2-normalised before the
    similarity is taken. None is returned if the score is still not finite.
    """
    def _unit(vec):
        return F.normalize(torch.nan_to_num(vec), p=2, dim=0)

    score = F.cosine_similarity(_unit(a), _unit(b), dim=0).item()
    if np.isfinite(score):
        return float(score)
    return None

@torch.no_grad()
def embed_file(p_abs, model, max_samples=48000):
    """Return ``model``'s embedding for one audio file, or None on failure.

    The file is down-mixed to mono, resampled to 16 kHz, and zero-padded or
    truncated to ``max_samples``. The default of 48000 (3 s) matches
    VoiceDataset's default crop. It is then turned into ``log(MEL + 1e-6)``
    features and run through ``model`` in float32, with autocast disabled so an
    enclosing AMP region cannot downcast it. Pass ``max_samples=None`` to embed
    the whole utterance.

    This function does not call ``model.eval()``; callers are expected to.
    """
    try:
        wav, sr = torchaudio.load(p_abs)
        if wav.shape[0] > 1:
            wav = wav.mean(0, keepdim=True)
        if sr != 16000:
            wav = torchaudio.functional.resample(wav, sr, 16000)
        if max_samples is not None:
            wav = F.pad(wav, (0, max(0, max_samples - wav.size(1))))[:, :max_samples]
        mel = torch.nan_to_num(torch.log(MEL_TRANSFORM(wav) + 1e-6).squeeze(0))
        with torch.amp.autocast("cuda", enabled=False):
            emb = model(mel.unsqueeze(0).to(DEVICE, dtype=torch.float32)).cpu().squeeze(0)
        return torch.nan_to_num(emb)
    except Exception:
        # Best-effort: unreadable files (or a failing forward pass) are skipped
        # by the caller. ``Exception`` rather than a bare ``except`` keeps
        # Ctrl-C / SystemExit working during long evaluations.
        return None

def eval_eer(pairs, model):
    """Compute the EER (%) of ``model`` on ``(label, rel_a, rel_b)`` trials.

    Every distinct file referenced by the trials is embedded once via
    ``embed_file(os.path.join(AUDIO_DIR, rel), model)``. Trials are skipped
    when either file failed to embed or the cosine score is not finite.

    Returns None when fewer than 5 trials survive, or when the surviving trials
    do not contain at least two distinct labels (EER is undefined then).
    """
    needed = sorted({a for _, a, _ in pairs} | {b for _, _, b in pairs})
    emb_map = {}
    for rel in needed:
        # Key by the trial-list path itself. Round-tripping through
        # os.path.join/relpath could normalise the string (or use "\\" on
        # Windows), and then the lookups below would silently miss.
        e = embed_file(os.path.join(AUDIO_DIR, rel), model)
        if e is not None:
            emb_map[rel] = e

    scores, labels = [], []
    for lab, a, b in pairs:
        if a not in emb_map or b not in emb_map:
            continue
        s = safe_cosine(emb_map[a], emb_map[b])
        if s is not None:
            scores.append(s)
            labels.append(lab)

    if len(scores) < 5 or len(set(labels)) < 2:
        return None
    return compute_eer(scores, labels)

# ---- Model, classification head, optimisation ------------------------------
model = MFAConformer(n_mels=N_MELS).to(DEVICE)
# Speaker-ID head on the 192-d embedding (MFAConformer's default emb_dim);
# it is only used for the training loss.
head = nn.Linear(192, num_classes).to(DEVICE)
opt = torch.optim.Adam(list(model.parameters()) + list(head.parameters()), lr=1e-4, weight_decay=1e-5)
crit = nn.CrossEntropyLoss(label_smoothing=0.1)
# float16 autocast and loss scaling only apply on CUDA. On CPU both are
# disabled explicitly instead of relying on "CUDA not available" fallbacks.
USE_AMP = DEVICE == "cuda"
scaler = GradScaler("cuda", enabled=USE_AMP)

EPOCHS = 30
steps = len(dl_tr)
sched = torch.optim.lr_scheduler.OneCycleLR(opt, max_lr=1e-4, epochs=EPOCHS, steps_per_epoch=steps)
best_eer = 999.0  # sentinel: no EER_E recorded yet

os.makedirs("/kaggle/working/mfa_conf_paper", exist_ok=True)

for ep in range(1, EPOCHS+1):
    model.train(); head.train()
    total, correct = 0, 0
    for mel, y in dl_tr:
        mel, y = mel.to(DEVICE), y.to(DEVICE)
        with autocast("cuda", dtype=torch.float16, enabled=USE_AMP):
            emb = model(mel)
            logits = head(emb)
            loss = crit(logits, y)
        opt.zero_grad(set_to_none=True)
        scaler.scale(loss).backward()
        scaler.step(opt); scaler.update()
        # Step the schedule every iteration (even if the scaler skipped the
        # optimizer step) so it stays aligned with EPOCHS * steps.
        sched.step()
        total += y.size(0)
        correct += (logits.argmax(1) == y).sum().item()
    acc = correct / total
    print(f"[Ep{ep:02d}] train_acc={acc:.4f}")

    model.eval()
    eer_e = eval_eer(E_pairs, model)
    eer_h = eval_eer(H_pairs, model)
    # eval_eer returns None when no EER is available. 0.0 is a valid (perfect)
    # EER, so test against None rather than truthiness.
    show_e = "N/A" if eer_e is None else eer_e
    show_h = "N/A" if eer_h is None else eer_h
    print(f"  ↳ Eval EER_E={show_e} | EER_H={show_h}")
    if eer_e is not None and eer_e < best_eer:
        best_eer = eer_e
        torch.save({"model": model.state_dict(), "head": head.state_dict(),
                    "EER_E": eer_e, "EER_H": eer_h},
                   "/kaggle/working/mfa_conf_paper/best.pt")
        print(" Saved new best model")

gc.collect()
if DEVICE == "cuda": torch.cuda.empty_cache()
shutil.make_archive("/kaggle/working/mfa_conf_paper_export", "zip", "/kaggle/working/mfa_conf_paper")
if best_eer < 999.0:
    print(f" Done. Best EER_E ≈ {best_eer:.2f}%")
else:
    print(" Done. No EER_E was computed, so no checkpoint was saved.")
