# In [None]:

import os, glob, random, csv, time
from collections import defaultdict
import numpy as np
import torch, torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
from sklearn.metrics import roc_curve


# ---- Runtime configuration ------------------------------------------------
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("DEVICE:", DEVICE, " | Torch:", torch.__version__, "| Torchaudio:", torchaudio.__version__)

# One shared seed for the Python, NumPy and PyTorch random generators.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
# Let cuDNN auto-tune its convolution algorithms.
torch.backends.cudnn.benchmark = True

# Mixed-precision settings read by the training loop further down.
AMP_DTYPE = torch.float16
USE_SCALER = True

# Seed the CUDA generators as well; a failure here is reported, not fatal.
if DEVICE == "cuda":
    try:
        torch.cuda.manual_seed(SEED)
        torch.cuda.manual_seed_all(SEED)
    except RuntimeError as e:
        print(" CUDA manual_seed_all() failed, continue with CPU seeds:", str(e))

# ---- Dataset locations ----------------------------------------------------
# Both folders must exist before anything else runs.
TRAIN_ROOT = "/kaggle/input/voxvn-api491/train_small_wav"
TEST_ROOT  = "/kaggle/input/voxvietnam"
assert os.path.exists(TRAIN_ROOT), "Không thấy train folder!"
assert os.path.exists(TEST_ROOT),  "Không thấy test folder!"


def scan_dataset(root):
    """Index the ``<root>/<speaker>/<utterance>.wav`` tree by speaker.

    Speakers are numbered by the sorted order of their directory names,
    counting every directory that holds at least one ``.wav`` file.
    Speakers with fewer than two utterances are then dropped, so the
    returned keys may have gaps.

    Args:
        root: Directory whose immediate sub-directories are speakers.

    Returns:
        A dict mapping speaker index (int) to the sorted list of that
        speaker's wav paths, in increasing index order. Empty when no
        file matches.
    """
    # glob order depends on the filesystem; sorting keeps the file lists
    # (and everything derived from them) identical from run to run.
    all_files = sorted(glob.glob(os.path.join(root, "*", "*.wav")))

    files_by_speaker = defaultdict(list)
    for f in all_files:
        files_by_speaker[os.path.basename(os.path.dirname(f))].append(f)

    # Speaker names are unique keys, so sorting the items orders by name only.
    return {
        idx: files
        for idx, (_, files) in enumerate(sorted(files_by_speaker.items()))
        if len(files) >= 2
    }

# Index the training corpus, then renumber the surviving speakers 0..N-1 so
# that the labels form a contiguous range of class ids.
spk2files_train_raw = scan_dataset(TRAIN_ROOT)
spk_list = sorted(spk2files_train_raw)
spk_map = dict(zip(spk_list, range(len(spk_list))))
# Flat list of (wav_path, class_id) pairs, one per utterance.
train_set = [
    (wav_path, spk_map[raw_id])
    for raw_id, wav_paths in spk2files_train_raw.items()
    for wav_path in wav_paths
]
num_classes = len(spk_map)

print(f" Train dataset: {num_classes} speakers | {len(train_set)} utterances")
# Sanity check: labels should span 0 .. num_classes-1.
labels_dbg = [sample[1] for sample in train_set]
print(f" Label range: {min(labels_dbg)} … {max(labels_dbg)} (num_classes={num_classes})")


def load_voxvietnam_pairs(root):
    """Read the verification trial list stored under *root*.

    ``test_list_gt.csv`` is used when it exists, otherwise
    ``test_list.txt``. Each useful line holds three fields: an integer
    label followed by two file names, separated by commas, tabs and/or
    spaces; double quotes are stripped. Blank lines, lines without exactly
    three fields, and lines whose first field is not an integer (e.g. a
    CSV header) are skipped.

    Args:
        root: Directory containing the trial list.

    Returns:
        List of ``(file1, file2, label)`` tuples in file order.

    Raises:
        AssertionError: If neither list file exists.
    """
    csv_path = os.path.join(root, "test_list_gt.csv")
    txt_path = os.path.join(root, "test_list.txt")
    path = csv_path if os.path.exists(csv_path) else txt_path
    assert os.path.exists(path), f"Không tìm thấy test list: {path}"

    pairs = []
    # utf-8-sig transparently drops a leading byte-order mark.
    with open(path, "r", encoding="utf-8-sig") as f:
        for raw in f:
            # split() with no argument already trims the line and treats
            # tabs/spaces alike, so only quotes and commas need rewriting.
            parts = raw.replace('"', '').replace(',', ' ').split()
            if len(parts) != 3:
                continue
            try:
                label = int(parts[0])
            except ValueError:
                # Non-numeric label (header row or malformed line): skip it.
                continue
            pairs.append((parts[1], parts[2], label))
    print(f" Loaded {len(pairs)} pairs from {os.path.basename(path)}")
    if pairs:
        print(" Ví dụ:", pairs[0])
    return pairs

# Verification trials as (file1, file2, label) tuples, loaded once from TEST_ROOT.
pairs = load_voxvietnam_pairs(TEST_ROOT)

def resolve_test_path(root, rel):
    """Locate a test utterance on disk.

    The candidates are tried in this order: ``root/rel``, ``root/wav/rel``
    and ``root/wav/wav/<basename of rel>``.

    Returns:
        The first candidate path that exists, or ``None`` when none does.
    """
    candidates = (
        os.path.join(root, rel),
        os.path.join(root, "wav", rel),
        os.path.join(root, "wav", "wav", os.path.basename(rel)),
    )
    return next((path for path in candidates if os.path.exists(path)), None)

class VoiceDataset(Dataset):
    """Fixed-length log-mel features for speaker classification.

    Each item is ``(mel, label)``: ``mel`` is the log-mel spectrogram of a
    ``max_len``-second mono excerpt resampled to ``sr`` and transposed so
    that time comes first (presumably ``(frames, n_mels)`` given
    torchaudio's ``(channels, n_mels, frames)`` layout); ``label`` is
    passed through unchanged from ``samples``.

    Args:
        samples: Sequence of ``(wav_path, label)`` pairs.
        augment: When true, randomly add Gaussian noise and rescale the gain.
        sr: Target sample rate in Hz.
        max_len: Excerpt length in seconds.
        n_mels: Number of mel bands.
    """

    def __init__(self, samples, augment=False, sr=16000, max_len=3.0, n_mels=80):
        self.samples = samples
        self.augment = augment
        self.sr = sr
        self.max_len = max_len
        self.mel_tf = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr, n_fft=400, hop_length=160, n_mels=n_mels)

    def __len__(self):
        return len(self.samples)

    def _load_mel(self, path):
        """Load *path* and return its (optionally augmented) log-mel features."""
        wav, sr0 = torchaudio.load(path)
        if wav.shape[0] > 1:
            # Down-mix multi-channel audio to mono.
            wav = wav.mean(0, keepdim=True)
        if sr0 != self.sr:
            wav = torchaudio.functional.resample(wav, sr0, self.sr)
        L = int(self.max_len * self.sr)
        if wav.size(1) > L:
            # Long clip: take a random crop of exactly L samples.
            st = random.randint(0, wav.size(1) - L)
            wav = wav[:, st:st + L]
        else:
            # Short clip: zero-pad on the right up to L samples.
            wav = F.pad(wav, (0, L - wav.size(1)))
        if self.augment:
            if random.random() < 0.5:
                wav = wav + 0.005 * torch.randn_like(wav)
            if random.random() < 0.5:
                wav = wav * random.uniform(0.8, 1.2)
        # The small constant keeps log() finite on silent frames.
        return torch.log(self.mel_tf(wav) + 1e-6).squeeze(0).transpose(0, 1)

    def __getitem__(self, idx):
        """Return ``(mel, label)`` for *idx*, falling back to later samples.

        If a file cannot be read or processed, a warning is printed and the
        next sample (wrapping around) is returned with its own label.

        Raises:
            IndexError: If *idx* is out of range.
            RuntimeError: If every sample in the dataset fails to load.
        """
        n = len(self.samples)
        cur = idx
        last_err = None
        # Bounded retry: at most one attempt per sample instead of unbounded
        # recursion, which ended in RecursionError when every file was bad.
        for _ in range(max(n, 1)):
            # Kept outside the try so an out-of-range index still raises
            # IndexError, as sequence-style iteration expects.
            path, label = self.samples[cur]
            try:
                return self._load_mel(path), label
            except Exception as e:
                print("[WARN] lỗi đọc:", path, e)
                last_err = e
                cur = (cur + 1) % n
        raise RuntimeError(
            f"VoiceDataset: none of the {n} samples could be loaded"
        ) from last_err


class SEBlock(nn.Module):
    """Squeeze-and-excitation style channel gate for 1-D feature maps.

    The input is averaged over its last axis, passed through a two-layer
    1x1-convolution bottleneck (``c -> c // r -> c``) ending in a sigmoid,
    and the resulting per-channel gate multiplies the input.

    Args:
        c: Number of channels.
        r: Bottleneck reduction ratio.
    """

    def __init__(self, c, r=8):
        super().__init__()
        bottleneck = c // r
        self.pool = nn.AdaptiveAvgPool1d(1)
        squeeze = nn.Conv1d(c, bottleneck, kernel_size=1)
        excite = nn.Conv1d(bottleneck, c, kernel_size=1)
        self.fc = nn.Sequential(squeeze, nn.ReLU(), excite, nn.Sigmoid())

    def forward(self, x):
        """Return *x* scaled channel-wise by the learned gate."""
        gate = self.fc(self.pool(x))
        return x * gate

class Res2Block(nn.Module):
    """Hierarchical multi-scale 1-D convolution block.

    The channel axis is cut into ``scale`` equal groups. The first group
    passes through untouched; every later group is added to the previous
    group's output and run through its own dilated convolution. The groups
    are concatenated again, then ReLU and batch normalisation are applied.

    Args:
        c: Number of input/output channels; must be divisible by ``scale``.
        scale: Number of channel groups.
        k: Convolution kernel size.
        d: Convolution dilation.

    Raises:
        ValueError: If ``scale`` is not positive or ``c`` is not a multiple
            of it.
    """

    def __init__(self, c=512, scale=8, k=3, d=2):
        super().__init__()
        if scale < 1 or c % scale != 0:
            # Uneven groups would not match the per-group conv width and
            # would otherwise only fail later, inside forward().
            raise ValueError(f"c ({c}) must be divisible by scale ({scale})")
        w = c // scale
        self.scale = scale
        # Length-preserving padding for a dilated kernel; equals ``d`` for
        # the default k=3. Exact whenever d * (k - 1) is even.
        pad = d * (k - 1) // 2
        self.convs = nn.ModuleList([
            nn.Conv1d(w, w, k, padding=pad, dilation=d, bias=False)
            for _ in range(scale - 1)
        ])
        self.bn = nn.BatchNorm1d(c)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply the block to *x*, split along dim 1 into ``scale`` groups."""
        splits = torch.chunk(x, self.scale, 1)
        out = [splits[0]]
        for conv, part in zip(self.convs, splits[1:]):
            # Each group also sees the previous group's output.
            out.append(conv(part + out[-1]))
        out = torch.cat(out, 1)
        return self.bn(self.relu(out))

class ECAPA_TDNN(nn.Module):
    """ECAPA-TDNN style speaker-embedding network.

    A 1-D convolution stem is followed by three residual stages (Res2Block,
    1x1 projection, SE gate) with dilations 2, 3 and 4. The three stage
    outputs are concatenated on the channel axis, pooled over time with
    attention-weighted mean and standard deviation, and projected to an
    L2-normalised embedding.

    Args:
        n_mels: Feature size of each input frame.
        c: Channel width of the convolutional trunk.
        emb: Embedding dimension.
    """

    def __init__(self, n_mels=80, c=512, emb=192):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Conv1d(n_mels, c, kernel_size=5, padding=2, bias=False),
            nn.ReLU(), nn.BatchNorm1d(c))
        self.blocks, self.proj, self.se = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
        for d in [2, 3, 4]:
            self.blocks.append(Res2Block(c, scale=8, d=d))
            self.proj.append(nn.Conv1d(c, c, 1, bias=False))
            self.se.append(SEBlock(c, 8))
        # Per-channel attention over time; the softmax is taken along dim 2.
        self.attn = nn.Sequential(
            nn.Conv1d(c*3, 128, 1), nn.ReLU(), nn.BatchNorm1d(128),
            nn.Conv1d(128, c*3, 1), nn.Softmax(dim=2))
        self.fc = nn.Linear((c*3)*2, emb)

    def forward(self, x):
        """Embed a batch of feature sequences.

        Args:
            x: Tensor whose dims 1 and 2 are swapped before the first
                convolution, i.e. time-major input ``(batch, frames, n_mels)``.

        Returns:
            ``(batch, emb)`` tensor with unit L2 norm per row.
        """
        h = self.layer1(x.transpose(1, 2))
        feats = []
        for blk, prj, se in zip(self.blocks, self.proj, self.se):
            h = se(prj(blk(h))) + h  # residual connection around each stage
            feats.append(h)
        cat = torch.cat(feats, 1)
        w = self.attn(cat)
        mean = torch.sum(cat * w, 2)
        var = torch.sum((cat ** 2) * w, 2) - mean ** 2
        # E[x^2] - E[x]^2 can dip below zero through floating-point
        # cancellation; clamp it so sqrt never produces NaN.
        std = torch.sqrt(var.clamp(min=0.0) + 1e-6)
        out = torch.cat([mean, std], 1)
        emb = self.fc(out)
        return F.normalize(emb, p=2, dim=1)

class AAMSoftmaxHead(nn.Module):
    """Additive angular margin (AAM / ArcFace-style) classification head.

    Logits are cosine similarities between the L2-normalised embedding and
    the L2-normalised class weights, multiplied by ``s``. When labels are
    given, the target-class angle is increased by ``m`` before scaling.

    Args:
        emb: Embedding dimension.
        n_class: Number of classes.
        s: Logit scale.
        m: Angular margin in radians.
    """

    def __init__(self, emb, n_class, s=30.0, m=0.2):
        super().__init__()
        self.W = nn.Parameter(torch.randn(n_class, emb))
        nn.init.xavier_uniform_(self.W)
        self.s, self.m = s, m

    def forward(self, x, y=None):
        """Return scaled cosine logits of shape ``(batch, n_class)``.

        Args:
            x: ``(batch, emb)`` embeddings.
            y: Optional ``(batch,)`` integer class indices. When provided
                the margin is applied to each row's target logit.
        """
        x = F.normalize(x, p=2, dim=1)
        W = F.normalize(self.W, p=2, dim=1)
        logits = F.linear(x, W)
        if y is not None:
            eps = 1e-6
            idx = torch.arange(x.size(0), device=x.device)
            # d/dx acos(x) is infinite at +/-1, which turns into NaN
            # gradients. Clamp strictly inside the interval, in float32 so
            # that eps is not rounded away when logits are half precision.
            cos_y = logits[idx, y].float().clamp(-1.0 + eps, 1.0 - eps)
            theta = torch.acos(cos_y)
            logits[idx, y] = torch.cos(theta + self.m).to(logits.dtype)
        return logits * self.s


def compute_eer(scores, labels):
    """Return the equal-error rate of *scores* against *labels*, in percent.

    The ROC operating point with the smallest gap between the
    false-negative and false-positive rates is selected, and the
    false-positive rate at that point is reported.
    """
    false_pos, true_pos, _thresholds = roc_curve(labels, scores)
    gap = np.abs((1 - true_pos) - false_pos)
    crossing = np.nanargmin(gap)
    return float(false_pos[crossing] * 100)

@torch.no_grad()
def extract_embeddings(model, files, sr=16000, batch_size=4, max_len=3.0, n_mels=80):
    """Compute one embedding per audio file.

    Each file is down-mixed to mono, resampled to ``sr``, zero-padded or
    truncated to ``max_len`` seconds and converted to log-mel features
    before being fed to ``model`` in batches. Files that fail to load are
    reported and left out of the result.

    Args:
        model: Network mapping a feature batch to an embedding batch; it is
            switched to eval mode and left there.
        files: Sequence of audio file paths.
        sr: Target sample rate in Hz.
        batch_size: Number of files per forward pass.
        max_len: Excerpt length in seconds.
        n_mels: Number of mel bands; must match what ``model`` expects.

    Returns:
        Dict mapping each file's base name to its CPU embedding tensor.
        Files sharing a base name overwrite one another.
    """
    model.eval()
    mel_tf = torchaudio.transforms.MelSpectrogram(sr, n_fft=400, hop_length=160, n_mels=n_mels)
    n_samples = int(max_len * sr)
    emb_map = {}
    for i in range(0, len(files), batch_size):
        # Keep the paths that actually loaded next to their features, so a
        # failed file can never shift embeddings onto the wrong names.
        loaded, mels = [], []
        for f in files[i:i + batch_size]:
            try:
                wav, sr0 = torchaudio.load(f)
                if wav.shape[0] > 1:
                    wav = wav.mean(0, keepdim=True)
                if sr0 != sr:
                    wav = torchaudio.functional.resample(wav, sr0, sr)
                # Right-pad short clips, then cut everything to n_samples.
                wav = F.pad(wav, (0, max(0, n_samples - wav.size(1))))[:, :n_samples]
                mel = torch.log(mel_tf(wav) + 1e-6).squeeze(0).transpose(0, 1)
            except Exception as e:
                print("[WARN] extract:", f, e)
                continue
            loaded.append(f)
            mels.append(mel)
        if not mels:
            continue
        # Pad along the time axis so the batch can be stacked.
        max_T = max(m.shape[0] for m in mels)
        mels = [F.pad(m, (0, 0, 0, max_T - m.shape[0])) for m in mels]
        mel_batch = torch.stack(mels).to(DEVICE)
        emb_batch = model(mel_batch).cpu()
        for fpath, emb in zip(loaded, emb_batch):
            emb_map[os.path.basename(fpath)] = emb
        if DEVICE == "cuda":
            # Release cached GPU blocks between batches.
            torch.cuda.empty_cache()
    return emb_map


# ---- Training set-up ------------------------------------------------------
BATCH, EPOCHS = 32, 40
train_loader = DataLoader(
    VoiceDataset(train_set, augment=True),
    batch_size=BATCH, shuffle=True, num_workers=4, pin_memory=True
)

model = ECAPA_TDNN().to(DEVICE)
head  = AAMSoftmaxHead(192, num_classes).to(DEVICE)
# Built once: shared by the optimizer and by gradient clipping.
train_params = list(model.parameters()) + list(head.parameters())
opt   = optim.AdamW(train_params, lr=3e-4, weight_decay=1e-4)
sch   = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=EPOCHS, eta_min=1e-5)
crit  = nn.CrossEntropyLoss()
scaler = GradScaler(enabled=USE_SCALER)

# ---- Evaluation trials ----------------------------------------------------
# The files on disk do not change between epochs, so resolve every trial
# path once here instead of on every epoch.
resolved = {p: resolve_test_path(TEST_ROOT, p) for f1, f2, _ in pairs for p in (f1, f2)}
all_files = sorted({p for p in resolved.values() if p})
# Trials whose two files both exist, keyed the way extract_embeddings keys
# its result (by base name).
eval_trials = [
    (os.path.basename(resolved[f1]), os.path.basename(resolved[f2]), lab)
    for f1, f2, lab in pairs
    if resolved[f1] and resolved[f2]
]

# ---- Train / evaluate -----------------------------------------------------
# NOTE(review): the "best" checkpoint is selected on the same trial list
# that is reported as test EER.
best_eer = 999.0
with open("train_log.csv", "w", newline="") as f:
    writer = csv.writer(f); writer.writerow(["epoch","loss","acc","eer"])
    for ep in range(1, EPOCHS+1):
        model.train(); head.train()
        tot_loss=0.0; correct=0; total=0
        t0 = time.time()
        for mel, lbl in train_loader:
            mel, lbl = mel.to(DEVICE), lbl.to(DEVICE)
            # DEVICE is always either "cuda" or "cpu".
            with autocast(device_type=DEVICE, dtype=AMP_DTYPE):
                emb = model(mel)
                logits = head(emb, lbl)
                loss = crit(logits, lbl)
            if torch.isnan(loss) or torch.isinf(loss):
                print(" NaN/Inf loss — skip batch")
                continue
            opt.zero_grad(set_to_none=True)
            if USE_SCALER:
                scaler.scale(loss).backward()
                # Unscale first so clipping sees the true gradient norm.
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(train_params, 5.0)
                scaler.step(opt); scaler.update()
            else:
                loss.backward(); opt.step()
            tot_loss += loss.item() * lbl.size(0)
            # Accuracy is measured on the margin-penalised logits.
            correct  += (logits.argmax(1) == lbl).sum().item()
            total    += lbl.size(0)
        sch.step()
        if total:
            tr_loss, tr_acc = tot_loss/total, correct/total
        else:
            # Every batch was skipped: report NaN instead of dividing by zero.
            print(" No valid batch in this epoch")
            tr_loss, tr_acc = float("nan"), 0.0

        # Score every trial by cosine similarity of its two embeddings.
        emb_map = extract_embeddings(model, all_files)
        scores, labels = [], []
        for n1, n2, lab in eval_trials:
            if n1 not in emb_map or n2 not in emb_map: continue
            sim = F.cosine_similarity(emb_map[n1].unsqueeze(0), emb_map[n2].unsqueeze(0)).item()
            scores.append(sim); labels.append(lab)
        eer = compute_eer(scores, labels) if scores else 100.0

        t1 = time.time()
        print(f"Epoch {ep:02d} | loss={tr_loss:.4f} acc={tr_acc:.4f} | test_EER={eer:.2f}% | {t1-t0:.1f}s")
        writer.writerow([ep, tr_loss, tr_acc, eer]); f.flush()

        if eer < best_eer:
            best_eer = eer
            torch.save({"model": model.state_dict(), "head": head.state_dict()}, "best_model.pt")
            print(" Saved new best model")

print(f" Done | Best EER={best_eer:.2f}%")
