# In [None]:

import os, glob, random, math, csv
from collections import defaultdict
import numpy as np
import torch, torchaudio
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.amp import autocast, GradScaler
from sklearn.metrics import roc_curve


# ---- Runtime configuration --------------------------------------------------
# Run on the GPU when CUDA is visible, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("DEVICE:", DEVICE, "| Torch:", torch.__version__, "| Torchaudio:", torchaudio.__version__)

# Seed every random number generator this script draws from.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if DEVICE == "cuda":
    torch.cuda.manual_seed_all(SEED)
# Enabled unconditionally, also when running on the CPU.
torch.backends.cudnn.benchmark = True

# Mixed precision: bfloat16 on CUDA devices that support it, float16 otherwise
# (including the CPU fallback). Loss scaling is switched on only for float16.
if DEVICE == "cuda" and torch.cuda.is_bf16_supported():
    AMP_DTYPE = torch.bfloat16
else:
    AMP_DTYPE = torch.float16
USE_SCALER = AMP_DTYPE == torch.float16

# Dataset locations; abort early when either one is missing.
TRAIN_ROOT = "/kaggle/input/voxvn-api491/train_small_wav"
TEST_ROOT = "/kaggle/input/voxvietnam"
assert os.path.exists(TRAIN_ROOT), f"Không thấy train folder: {TRAIN_ROOT}"
assert os.path.exists(TEST_ROOT),  f"Không thấy test folder: {TEST_ROOT}"


def scan_dataset(root, min_files=2):
    """Index ``<root>/<speaker>/*.wav`` into ``{speaker_index: [wav paths]}``.

    Speaker indices are assigned by enumerating the sorted speaker directory
    names that contain at least one ``.wav`` file. Speakers with fewer than
    ``min_files`` recordings are dropped *after* indexing, so the returned
    keys may have gaps.

    Paths are sorted before grouping: ``glob`` returns files in an arbitrary,
    filesystem-dependent order, which would otherwise make the file lists
    (and anything derived from them) differ between runs despite a fixed seed.

    Args:
        root: Directory whose immediate sub-directories are speakers.
        min_files: Minimum number of wav files a speaker needs to be kept.

    Returns:
        A dict, in ascending key order, mapping speaker index to the sorted
        list of that speaker's wav paths. Empty when nothing matches.
    """
    wav_paths = sorted(glob.glob(os.path.join(root, "*", "*.wav")))

    by_speaker = defaultdict(list)
    for path in wav_paths:
        by_speaker[os.path.basename(os.path.dirname(path))].append(path)

    spk2files = {}
    for idx, speaker in enumerate(sorted(by_speaker)):
        files = by_speaker[speaker]
        if len(files) >= min_files:
            spk2files[idx] = files
    return spk2files

# Index the training corpus; keys are the (possibly gappy) speaker indices
# produced by scan_dataset.
spk2files_train = scan_dataset(TRAIN_ROOT)
print(f" Train dataset: {len(spk2files_train)} speakers")

# Re-number the surviving speakers densely (0 .. num_classes - 1) so the
# labels can be used directly as classification targets.
valid_spks = sorted(spk2files_train)
spk_map = dict(zip(valid_spks, range(len(valid_spks))))
train_set = [
    (wav_path, spk_map[raw_idx])
    for raw_idx, wav_paths in spk2files_train.items()
    for wav_path in wav_paths
]
num_classes = len(spk_map)
_label_values = [lbl for _, lbl in train_set]
print(f"Train: {len(train_set)} | Classes: {num_classes} | Label range check → "
      f"{min(_label_values)} … {max(_label_values)}")


def load_voxvietnam_pairs(root):
    """Parse the verification trial list found under ``root``.

    ``test_list_gt.csv`` is preferred; ``test_list.txt`` is the fallback.
    Each record is ``label file1 file2``. Fields may be separated by commas,
    tabs or spaces, and double quotes are dropped. A record may also be split
    over consecutive lines (e.g. one field per line).

    Lines that cannot be assembled into exactly three fields, or whose label
    is not an integer (such as a header row), are reported and skipped. A
    malformed short line only absorbs following *fragment* lines (fewer than
    three fields); it never consumes a following complete record, so one bad
    line cannot take valid trials down with it.

    Args:
        root: Directory containing the trial list.

    Returns:
        A list of ``(file1, file2, label)`` tuples with ``label`` as ``int``.

    Raises:
        AssertionError: If neither list file exists.
    """
    csv_path = os.path.join(root, "test_list_gt.csv")
    txt_path = os.path.join(root, "test_list.txt")
    path = csv_path if os.path.exists(csv_path) else txt_path
    assert os.path.exists(path), f"Không tìm thấy test list: {csv_path} | {txt_path}"

    def _fields(text):
        # Quotes are dropped; commas and tabs act as plain whitespace.
        return text.replace('"', '').replace(',', ' ').replace('\t', ' ').split()

    # utf-8-sig transparently strips a BOM if the file starts with one.
    with open(path, "r", encoding="utf-8-sig") as f:
        rows = [(raw.strip(), _fields(raw)) for raw in f if raw.strip()]

    pairs = []
    pos = 0
    while pos < len(rows):
        raw, parts = rows[pos]
        pos += 1

        if len(parts) < 3:
            # Possibly a record split across lines: pull in following
            # fragments until three fields are collected, stopping in front
            # of any line that is already a complete record by itself.
            parts = list(parts)
            while len(parts) < 3 and pos < len(rows) and len(rows[pos][1]) < 3:
                parts.extend(rows[pos][1])
                pos += 1

        if len(parts) != 3:
            print(f"[WARN] Bỏ qua dòng lỗi: {raw}")
            continue

        label_str, f1, f2 = parts
        try:
            label = int(label_str)
        except ValueError:
            print(f"[WARN] Bỏ qua dòng lỗi (label không hợp lệ): {raw}")
            continue
        pairs.append((f1, f2, label))

    print(f" Loaded {len(pairs)} pairs from {os.path.basename(path)}")
    if pairs:
        print(" Ví dụ:", pairs[0])
    return pairs

# Verification trials as (file1, file2, label) tuples, parsed once from the
# trial list under TEST_ROOT and reused by the per-epoch evaluation.
pairs = load_voxvietnam_pairs(TEST_ROOT)


class VoiceDataset(Dataset):
    """Dataset yielding ``(log-mel spectrogram, label)`` for fixed-length crops.

    Each item is loaded from disk, down-mixed to mono, resampled to ``sr``,
    cropped at a random offset (or right-padded with zeros) to ``max_len``
    seconds, optionally augmented, and converted to a log-mel spectrogram.

    Args:
        samples: Sequence of ``(wav_path, label)`` pairs.
        augment: Enable waveform augmentation and spectrogram masking.
        max_len: Crop length in seconds.
        sr: Target sample rate in Hz.
        n_mels: Number of mel bins.
    """

    # Upper bound on how many consecutive samples are tried when files fail
    # to load, so a broken dataset raises instead of recursing forever.
    MAX_LOAD_ATTEMPTS = 10

    def __init__(self, samples, augment=False, max_len=3.0, sr=16000, n_mels=80):
        self.samples = samples
        self.augment = augment
        self.sr = sr
        self.max_len = max_len
        self.mel_tf = torchaudio.transforms.MelSpectrogram(
            sample_rate=sr, n_fft=400, hop_length=160, n_mels=n_mels
        )

        # NOTE(review): these masks are applied to the *log*-mel below and are
        # called without an explicit mask value — presumably the library
        # default fill is used, which is not "silence" on a log scale; confirm
        # this is intended.
        self.freq_mask = torchaudio.transforms.FrequencyMasking(freq_mask_param=15)
        self.time_mask = torchaudio.transforms.TimeMasking(time_mask_param=35)

    def __len__(self):
        return len(self.samples)

    def _augment_wave(self, wav):
        """Randomly perturb a mono waveform of shape ``(1, T)``.

        Four independent augmentations, each drawn separately:
        gain (p=0.5), additive Gaussian noise at 5-20 dB SNR (p=0.5),
        zero-filled time shift of up to 80 ms (p=0.5), and a high-pass or
        low-pass biquad filter (p=0.4, the two filters equally likely).
        """
        if random.random() < 0.5:
            gain = random.uniform(0.8, 1.2)
            wav = wav * gain

        if random.random() < 0.5:
            snr_db = random.uniform(5, 20)
            sig_power = wav.pow(2).mean().clamp(min=1e-9)
            noise_power = sig_power / (10 ** (snr_db / 10))
            noise = torch.randn_like(wav)
            # Rescale the noise so its measured power equals noise_power.
            noise = noise * (noise_power.sqrt() / noise.pow(2).mean().clamp(min=1e-9).sqrt())
            wav = wav + noise

        if random.random() < 0.5:
            max_shift = int(0.08 * self.sr)
            shift = random.randint(-max_shift, max_shift)
            # Shift while keeping the length: samples pushed past one end are
            # dropped and the other end is filled with zeros.
            if shift > 0:
                wav = torch.cat([torch.zeros(1, shift), wav[:, :-shift]], dim=1)
            elif shift < 0:
                s = -shift
                wav = torch.cat([wav[:, s:], torch.zeros(1, s)], dim=1)

        if random.random() < 0.4:
            if random.random() < 0.5:
                cutoff = random.uniform(200.0, 4000.0)
                wav = torchaudio.functional.highpass_biquad(wav, self.sr, cutoff)
            else:
                cutoff = random.uniform(2000.0, 6000.0)
                wav = torchaudio.functional.lowpass_biquad(wav, self.sr, cutoff)

        return wav

    def _load_features(self, path):
        """Load one file and return its (optionally augmented) log-mel."""
        wav, sr = torchaudio.load(path)
        if wav.shape[0] > 1:
            # Down-mix multi-channel audio to mono.
            wav = torch.mean(wav, dim=0, keepdim=True)
        if sr != self.sr:
            wav = torchaudio.functional.resample(wav, sr, self.sr)

        L = int(self.max_len * self.sr)
        if wav.size(1) > L:
            st = random.randint(0, wav.size(1) - L)
            wav = wav[:, st:st + L]
        else:
            wav = F.pad(wav, (0, L - wav.size(1)))

        if self.augment:
            wav = self._augment_wave(wav)

        # The small offset keeps log() finite on silent frames; squeeze drops
        # the leading channel axis of the mono spectrogram.
        mel = torch.log(self.mel_tf(wav) + 1e-6).squeeze(0)

        if self.augment:
            mel = self.freq_mask(mel)
            mel = self.time_mask(mel)

        return mel, path

    def __getitem__(self, idx):
        """Return ``(mel, label)`` for ``idx``, skipping unreadable files.

        When a sample cannot be loaded, a warning is printed and the next
        sample (wrapping around) is returned instead, together with *its own*
        label. At most ``MAX_LOAD_ATTEMPTS`` consecutive samples are tried.

        Raises:
            IndexError: If ``idx`` is out of range.
            RuntimeError: If every attempted sample failed to load.
        """
        total = len(self.samples)
        attempts = max(1, min(total, self.MAX_LOAD_ATTEMPTS))
        last_err = None
        for _ in range(attempts):
            # Looked up outside the try block so an out-of-range index still
            # surfaces as IndexError.
            path, label = self.samples[idx]
            try:
                mel, _ = self._load_features(path)
                return mel, label
            except Exception as err:  # best-effort: any load/decoding failure
                print(f"[WARN] Lỗi khi đọc {path}: {err}")
                last_err = err
                idx = (idx + 1) % total
        raise RuntimeError(
            f"VoiceDataset: {attempts} consecutive samples failed to load"
        ) from last_err


class Swish(nn.Module):
    """Parameter-free activation computing ``x * sigmoid(x)`` element-wise."""

    def forward(self, x):
        gate = torch.sigmoid(x)
        return x * gate

class SEBlock(nn.Module):
    """Squeeze-and-excitation gate for channel-first 1-D feature maps.

    The input is average-pooled over its last axis, passed through a
    bottleneck of ``channels // r`` channels (1x1 conv, ReLU, 1x1 conv) and a
    sigmoid; the resulting per-channel weights rescale the input.

    Args:
        channels: Number of channels (size of axis 1 of the input).
        r: Bottleneck reduction ratio.
    """

    def __init__(self, channels, r=8):
        super().__init__()
        bottleneck = channels // r
        stages = [
            nn.AdaptiveAvgPool1d(1),
            nn.Conv1d(channels, bottleneck, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv1d(bottleneck, channels, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.se = nn.Sequential(*stages)

    def forward(self, x):
        # One weight per channel, broadcast along the pooled axis.
        return x * self.se(x)

class ConvModule(nn.Module):
    """Residual convolution sub-block operating on ``(batch, seq, dim)`` input.

    Pipeline: LayerNorm -> pointwise conv to ``2 * dim`` -> GLU (back to
    ``dim``) -> depthwise conv of width ``k`` -> BatchNorm -> Swish ->
    pointwise conv -> squeeze-and-excitation gate, added to the input.

    Args:
        dim: Feature size of the last input axis.
        k: Depthwise kernel width. Padding is ``k // 2``, so the residual sum
            only lines up for odd ``k``.
    """

    def __init__(self, dim, k=15):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.pw1 = nn.Conv1d(dim, 2 * dim, 1)
        self.dw = nn.Conv1d(dim, dim, k, padding=k // 2, groups=dim)
        self.bn = nn.BatchNorm1d(dim)
        self.act = Swish()
        self.pw2 = nn.Conv1d(dim, dim, 1)
        self.se = SEBlock(dim)

    def forward(self, x):
        # Normalise over features, then go channel-first for the 1-D convs.
        h = self.ln(x)
        h = h.transpose(1, 2)
        # GLU halves the channel axis that pw1 doubled.
        h = self.pw1(h)
        h = F.glu(h, dim=1)
        h = self.act(self.bn(self.dw(h)))
        h = self.se(self.pw2(h))
        # Back to the input layout for the residual connection.
        return x + h.transpose(1, 2)

class FeedForwardModule(nn.Module):
    """Pre-norm feed-forward sub-block with a half-weighted residual.

    Computes ``x + 0.5 * FF(LayerNorm(x))`` where ``FF`` is
    Linear(dim -> exp * dim), Swish, Dropout, Linear(exp * dim -> dim), Dropout.

    Args:
        dim: Feature size of the last input axis.
        exp: Expansion factor of the hidden layer.
        drop: Dropout probability used after each linear stage.
    """

    def __init__(self, dim, exp=4, drop=0.1):
        super().__init__()
        hidden = exp * dim
        self.ln = nn.LayerNorm(dim)
        self.ff = nn.Sequential(
            nn.Linear(dim, hidden),
            Swish(),
            nn.Dropout(drop),
            nn.Linear(hidden, dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        update = self.ff(self.ln(x))
        return x + 0.5 * update

class MHSA(nn.Module):
    """Pre-norm multi-head self-attention sub-block with a residual connection.

    Computes ``x + Dropout(Attention(LayerNorm(x)))`` on batch-first input of
    shape ``(batch, seq, dim)``.

    Args:
        dim: Feature size of the last input axis (must be divisible by ``h``).
        h: Number of attention heads.
        drop: Dropout probability, used inside the attention and on its output.
    """

    def __init__(self, dim, h=4, drop=0.1):
        super().__init__()
        self.ln = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, h, dropout=drop, batch_first=True)
        self.do = nn.Dropout(drop)

    def forward(self, x):
        # Normalise once and reuse the result as query, key and value instead
        # of running the same LayerNorm three times.
        normed = self.ln(x)
        # The attention map is never used, so skip computing it.
        y, _ = self.attn(normed, normed, normed, need_weights=False)
        return x + self.do(y)

class ConformerBlock(nn.Module):
    """One encoder block: feed-forward, self-attention, convolution, feed-forward.

    Every stage carries its own residual connection, so the block is a plain
    chain of the four sub-modules applied in that order.

    Args:
        dim: Feature size shared by all sub-modules.
        h: Number of attention heads.
        k: Depthwise kernel width of the convolution sub-module.
        exp: Hidden expansion factor of both feed-forward sub-modules.
        drop: Dropout probability for the feed-forward and attention stages.
    """

    def __init__(self, dim, h=4, k=15, exp=4, drop=0.1):
        super().__init__()
        self.ff1 = FeedForwardModule(dim, exp, drop)
        self.mhsa = MHSA(dim, h, drop)
        self.conv = ConvModule(dim, k)
        self.ff2 = FeedForwardModule(dim, exp, drop)

    def forward(self, x):
        for stage in (self.ff1, self.mhsa, self.conv, self.ff2):
            x = stage(x)
        return x

class MFAConformer(nn.Module):
    """Speaker-embedding network built from a stack of Conformer blocks.

    The mel input is projected to ``dim``, passed through ``L`` blocks, and
    the outputs of *all* blocks are concatenated along the feature axis.
    After LayerNorm, mean and standard deviation over the sequence axis are
    concatenated, projected to ``emb_dim``, batch-normalised and
    L2-normalised.

    Args:
        n_mels: Number of mel bins (size of axis 1 of the input).
        dim: Model width.
        L: Number of Conformer blocks.
        h: Attention heads per block.
        k: Depthwise kernel width per block.
        exp: Feed-forward expansion factor.
        drop: Dropout probability.
        emb_dim: Size of the returned embedding.
    """

    def __init__(self, n_mels=80, dim=224, L=6, h=4, k=15, exp=4,
                 drop=0.1, emb_dim=192):
        super().__init__()

        self.proj = nn.Linear(n_mels, dim)

        self.blocks = nn.ModuleList([
            ConformerBlock(dim, h=h, k=k, exp=exp, drop=drop)
            for _ in range(L)
        ])

        self.ln_mfa = nn.LayerNorm(dim * L)

        # Input is [mean, std], hence twice the aggregated width.
        self.post = nn.Linear(dim * L * 2, emb_dim)
        self.bn = nn.BatchNorm1d(emb_dim)

    def forward(self, x):
        """Map ``(batch, n_mels, frames)`` to unit-norm ``(batch, emb_dim)``."""
        x = x.transpose(1, 2)  # -> (batch, frames, n_mels)
        x = self.proj(x)       # -> (batch, frames, dim)

        feats = []
        for b in self.blocks:
            x = b(x)
            feats.append(x)

        H = torch.cat(feats, dim=-1)  # -> (batch, frames, dim * L)
        H = self.ln_mfa(H)

        mean = H.mean(dim=1)
        # Standard deviation over frames, computed as sqrt(clamped variance).
        # Clamping *before* the square root keeps the gradient finite when a
        # feature is constant over time, and switching to the biased estimate
        # for a single frame avoids the NaN that the unbiased estimate yields
        # there. For more than one frame the value equals
        # ``H.std(dim=1).clamp(min=1e-6)``.
        var = H.var(dim=1, unbiased=H.size(1) > 1)
        std = var.clamp(min=1e-12).sqrt().clamp(min=1e-6)
        pooled = torch.cat([mean, std], dim=1)

        emb = self.bn(self.post(pooled))
        # Last-resort guard: replace any remaining non-finite values.
        emb = torch.nan_to_num(emb)
        return F.normalize(emb, p=2, dim=1)


class AAMSoftmax(nn.Module):
    """Additive angular margin (ArcFace-style) classification head.

    Produces scaled cosine logits between L2-normalised embeddings and
    L2-normalised class weights. For the target class of every sample the
    angle is enlarged by the margin ``m`` before taking the cosine, which
    makes the target harder to satisfy.

    Args:
        emb_dim: Embedding size.
        num_classes: Number of classes.
        s: Scale applied to all logits.
        m: Angular margin in radians.
    """

    def __init__(self, emb_dim, num_classes, s=30.0, m=0.2):
        super().__init__()
        self.W = nn.Parameter(torch.randn(emb_dim, num_classes))
        nn.init.xavier_normal_(self.W)
        self.s = s
        self.m = m

    def forward(self, emb, y):
        """Return ``(batch, num_classes)`` logits for embeddings and int labels.

        Args:
            emb: ``(batch, emb_dim)`` embeddings (re-normalised here).
            y: ``(batch,)`` integer class labels.
        """
        W = F.normalize(self.W, dim=0)      # unit-norm class columns
        x = F.normalize(emb, p=2, dim=1)    # unit-norm embeddings
        # Work in float32 so acos stays accurate under mixed precision.
        cos = (x @ W).float()

        # Clamp strictly inside [-1, 1] so acos and its gradient stay finite.
        theta = torch.acos(torch.clamp(cos, -1 + 1e-7, 1 - 1e-7))
        with_margin = torch.cos(theta + self.m)

        # cos(theta + m) is only decreasing while theta + m <= pi. Beyond that
        # it rises again and the "penalised" logit would exceed the plain
        # cosine, rewarding the worst samples. In that region use the usual
        # linear fallback, which always lies below the plain cosine.
        threshold = math.cos(math.pi - self.m)
        fallback = cos - self.m * math.sin(math.pi - self.m)
        target_logits = torch.where(cos > threshold, with_margin, fallback)

        # Margin applies to each sample's target class only.
        is_target = F.one_hot(y, num_classes=W.size(1)).bool()
        logits = torch.where(is_target, target_logits, cos)
        return logits * self.s


def compute_eer(scores, labels):
    """Approximate the equal error rate, in percent.

    Takes the ROC operating point where the false-negative and false-positive
    rates are closest (NaN gaps ignored) and returns the false-positive rate
    at that point times 100.

    Args:
        scores: Similarity score per trial (higher means "same speaker").
        labels: Ground-truth label per trial, aligned with ``scores``.

    Returns:
        The EER estimate as a Python float in percent.
    """
    false_pos, true_pos, _thresholds = roc_curve(labels, scores)
    false_neg = 1 - true_pos
    gap = np.abs(false_neg - false_pos)
    crossing = np.nanargmin(gap)
    return float(false_pos[crossing] * 100)

def resolve_test_path(root, rel):
    """Locate a trial file given its path relative to the test root.

    Tried in order: ``<root>/<rel>``, then ``<root>/wav/<rel>``. The first one
    that exists is returned. Otherwise ``<root>/wav/wav/<basename of rel>`` is
    returned without checking that it exists.
    """
    candidates = (
        os.path.join(root, rel),
        os.path.join(root, "wav", rel),
    )
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return os.path.join(root, "wav", "wav", os.path.basename(rel))

@torch.no_grad()
def extract_embeddings(model, file_list, sr=16000, n_mels=80, batch_size=4, max_len=3.0):
    """Embed every file in ``file_list`` and return ``{basename: embedding}``.

    Each file is down-mixed to mono, resampled to ``sr``, zero-padded or
    truncated to its first ``max_len`` seconds and converted to a log-mel
    spectrogram. Files that fail to load are reported and skipped. The model
    is switched to eval mode and is NOT switched back.

    The result is keyed by file *basename*. If two different paths share a
    basename the later one overwrites the earlier one; such collisions are
    counted and reported once at the end, because they make trial scores
    silently wrong.

    Args:
        model: Network mapping ``(batch, n_mels, frames)`` to embeddings.
        file_list: Sequence of audio file paths.
        sr: Target sample rate in Hz.
        n_mels: Number of mel bins.
        batch_size: Number of files per forward pass.
        max_len: Length in seconds every utterance is padded/cropped to.

    Returns:
        Dict mapping ``os.path.basename(path)`` to a CPU embedding tensor.
    """
    model.eval()
    mel_tf = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr, n_fft=400, hop_length=160, n_mels=n_mels
    )
    target_len = int(max_len * sr)

    emb_map = {}
    key_owner = {}   # basename -> path that produced the stored embedding
    collisions = 0

    for start in range(0, len(file_list), batch_size):
        batch = file_list[start:start + batch_size]
        mels = []
        valid_paths = []
        for path in batch:
            try:
                wav, sr0 = torchaudio.load(path)
                if wav.shape[0] > 1:
                    wav = torch.mean(wav, 0, keepdim=True)
                if sr0 != sr:
                    wav = torchaudio.functional.resample(wav, sr0, sr)
                # Right-pad short clips, then keep the first target_len samples.
                wav = F.pad(wav, (0, max(0, target_len - wav.size(1))))
                wav = wav[:, :target_len]
                mel = torch.log(mel_tf(wav) + 1e-6).squeeze(0)
                mels.append(mel)
                valid_paths.append(path)
            except Exception as e:  # best-effort: skip unreadable files
                print(f"[WARN] Lỗi khi đọc {path}: {e}")
                continue

        if not mels:
            continue
        # Defensive: equalise frame counts before stacking.
        max_T = max(m.shape[1] for m in mels)
        mels = [F.pad(m, (0, max_T - m.shape[1])) for m in mels]

        mel_batch = torch.stack(mels).to(DEVICE)
        emb_batch = model(mel_batch).cpu()
        for path, emb in zip(valid_paths, emb_batch):
            key = os.path.basename(path)
            if key_owner.get(key, path) != path:
                collisions += 1
            key_owner[key] = path
            emb_map[key] = emb
        if DEVICE == "cuda":
            torch.cuda.empty_cache()

    if collisions:
        print(f"[WARN] {collisions} embedding(s) overwritten: different files share a basename")
    return emb_map


# ---- Training setup ---------------------------------------------------------
BATCH, WORKERS, EPOCHS = 32, 4, 30
train_loader = DataLoader(
    VoiceDataset(train_set, augment=True),
    batch_size=BATCH, shuffle=True,
    num_workers=WORKERS, pin_memory=True
)

model = MFAConformer().to(DEVICE)
head  = AAMSoftmax(emb_dim=192, num_classes=num_classes, s=30.0, m=0.2).to(DEVICE)
opt   = optim.Adam(list(model.parameters()) + list(head.parameters()), lr=2e-3)
steps = len(train_loader)
# The schedule is sized for EPOCHS * steps updates and is stepped per batch.
sched = optim.lr_scheduler.OneCycleLR(opt, max_lr=2e-3, epochs=EPOCHS, steps_per_epoch=steps)
crit  = nn.CrossEntropyLoss()
scaler = GradScaler() if USE_SCALER else None

# ---- Evaluation inputs (loop-invariant) -------------------------------------
# Resolving a trial path touches the filesystem, and the trial list never
# changes, so resolve every distinct relative path once here instead of
# redoing it for every trial in every epoch.
_resolved = {rel: resolve_test_path(TEST_ROOT, rel) for f1, f2, _ in pairs for rel in (f1, f2)}
all_files = sorted(set(_resolved.values()))
# extract_embeddings keys its result by basename, so trials are stored the same way.
eval_trials = [
    (os.path.basename(_resolved[f1]), os.path.basename(_resolved[f2]), label)
    for f1, f2, label in pairs
]

# ---- Train / evaluate -------------------------------------------------------
# NOTE(review): the checkpoint is selected on the same test trials that are
# reported, so best_eer is an optimistic estimate — confirm this is acceptable.
best_eer = 999.0
with open("train_log.csv", "w", newline="") as f:
    writer = csv.writer(f)
    writer.writerow(["epoch","train_loss","train_acc","test_eer"])

    for ep in range(1, EPOCHS+1):
        model.train()
        head.train()
        total_loss = 0.0
        correct = 0
        total = 0

        for mel, lbl in train_loader:
            mel, lbl = mel.to(DEVICE), lbl.to(DEVICE)
            with autocast(device_type="cuda" if DEVICE=="cuda" else "cpu", dtype=AMP_DTYPE):
                emb = model(mel)
                logits = head(emb, lbl)
                loss = crit(logits, lbl)

            opt.zero_grad(set_to_none=True)
            if USE_SCALER:
                # float16 path: scale the loss to avoid gradient underflow.
                scaler.scale(loss).backward()
                scaler.step(opt)
                scaler.update()
            else:
                loss.backward()
                opt.step()
            sched.step()

            total_loss += loss.item() * lbl.size(0)
            # Accuracy is measured on the logits returned by the margin head,
            # i.e. with the margin already applied to the target class.
            with torch.no_grad():
                preds = logits.argmax(1)
            correct    += (preds == lbl).sum().item()
            total      += lbl.size(0)

        train_loss = total_loss / total
        train_acc  = correct / total

        # Score every trial by cosine similarity of the two embeddings.
        emb_map = extract_embeddings(model, all_files, batch_size=4)
        scores, labels = [], []
        for n1, n2, label in eval_trials:
            if n1 not in emb_map or n2 not in emb_map:
                print(f"[WARN] Missing embedding for {n1} or {n2}")
                continue
            sim = F.cosine_similarity(
                emb_map[n1].unsqueeze(0),
                emb_map[n2].unsqueeze(0)
            ).item()
            scores.append(sim)
            labels.append(label)
        # With no scorable trial, report the worst possible EER.
        test_eer = compute_eer(scores, labels) if scores else 100.0

        print(f"Epoch {ep:02d} | loss={train_loss:.4f} acc={train_acc:.4f} | test_EER={test_eer:.2f}%")
        writer.writerow([ep, train_loss, train_acc, test_eer])
        f.flush()  # keep the log readable while training is still running

        if test_eer < best_eer:
            best_eer = test_eer
            torch.save({"model": model.state_dict(), "head": head.state_dict()}, "best_model.pt")
            print(" Saved new best model (by test EER)")
