In [None]:
# Fetch the Res2Net reference implementation (provides res2net_v1b) and make it
# importable. Idempotent: the clone is skipped when the checkout already exists,
# and it is cloned into the same absolute directory that is put on sys.path
# (a bare `git clone` lands in whatever the current working directory is).
import os
import subprocess
import sys

RES2NET_DIR = "/kaggle/working/Res2Net-PretrainedModels"
if not os.path.isdir(RES2NET_DIR):
    subprocess.run(
        ["git", "clone", "https://github.com/Res2Net/Res2Net-PretrainedModels.git", RES2NET_DIR],
        check=True,  # fail loudly instead of continuing without the module
    )
if RES2NET_DIR not in sys.path:
    sys.path.append(RES2NET_DIR)


In [None]:
# Exploratory: list the def/class lines of the cloned res2net_v1b.py
# (presumably to locate the backbone constructor imported in the next cell).
# Not required for training.
!grep -E "def |class " /kaggle/working/Res2Net-PretrainedModels/res2net_v1b.py


In [None]:
import os
import argparse
import random
from pathlib import Path
from typing import List, Tuple

import numpy as np
import soundfile as sf
import librosa

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T

from scipy.optimize import brentq
from scipy.interpolate import interp1d

from sklearn.metrics import roc_curve

from tqdm import tqdm
import glob

from collections import defaultdict
from collections import Counter

from torch.cuda.amp import autocast, GradScaler

from res2net_v1b import res2net50_v1b_26w_4s


In [None]:
# ------------------------- Config / Hyperparams -------------------------
# Dataset locations
TXT_E = "/kaggle/input/vietnam-celeb-dataset/vietnam-celeb-e.txt"
TXT_H = "/kaggle/input/vietnam-celeb-dataset/vietnam-celeb-h.txt"
AUDIO_DIR = "/kaggle/input/vietnam-celeb-dataset/full-dataset/data"

# Audio front-end
DEFAULT_SR = 16_000      # target sample rate (Hz)
AUDIO_DURATION = 2.5     # seconds per crop
N_FFT, HOP_LENGTH, N_MELS = 1024, 256, 64

# Model
EMBEDDING_SIZE = 256

# Data split
MIN_UTTS_PER_SPK = 5       # speakers with fewer kept files are dropped
SPLIT_RATIO = (0.8, 0.2)   # per-speaker (train, val) fractions
MAX_PAIRS_FT = 5000        # NOTE(review): not referenced elsewhere in this notebook

# Optimisation
EPOCHS, BATCH_SIZE, LR = 8, 64, 5e-6


In [None]:
def read_audio(path: str, sr: int = DEFAULT_SR, duration: float = AUDIO_DURATION, random_crop: bool = False):
    """Load an audio file as a mono float32 waveform of fixed length.

    Multi-channel audio is averaged to mono and resampled to ``sr`` when the
    file uses a different rate. The result is then cropped, or zero-padded at
    the end, to exactly ``round(sr * duration)`` samples.

    Args:
        path: audio file readable by ``soundfile``.
        sr: target sample rate in Hz.
        duration: target length in seconds.
        random_crop: for files longer than the target, take a uniformly random
            window (drawn with ``np.random``); otherwise take the start of the file.

    Returns:
        1-D ``float32`` NumPy array of length ``round(sr * duration)``.
    """
    wav, orig_sr = sf.read(path)
    if wav.ndim > 1:
        wav = np.mean(wav, axis=1)  # multi-channel -> mono

    if orig_sr != sr:
        # orig_sr / target_sr are keyword-only in librosa >= 0.10; passing them
        # positionally raises TypeError there.
        wav = librosa.resample(wav.astype("float32"), orig_sr=orig_sr, target_sr=sr)

    # round(): sr * duration may not be exactly representable, and int() alone
    # would then truncate the target one sample short.
    target_len = int(round(sr * duration))

    if len(wav) >= target_len:
        if random_crop:
            # random crop position inside the file
            start = np.random.randint(0, len(wav) - target_len + 1)
        else:
            # take the beginning of the file
            start = 0
        wav = wav[start:start + target_len]
    else:
        wav = np.pad(wav, (0, target_len - len(wav)), mode="constant")

    return wav.astype("float32")


In [None]:
# Cache of (MelSpectrogram, AmplitudeToDB) modules keyed by front-end settings.
# Building the transforms (window + mel filterbank) on every call is wasted
# work; the modules are stateless at inference so one instance can be reused.
_FEATURE_TRANSFORMS = {}


def _feature_transforms(sr, n_fft, hop_length, n_mels):
    """Return the cached (MelSpectrogram, AmplitudeToDB) pair for these settings."""
    key = (sr, n_fft, hop_length, n_mels)
    if key not in _FEATURE_TRANSFORMS:
        _FEATURE_TRANSFORMS[key] = (
            T.MelSpectrogram(
                sample_rate=sr,
                n_fft=n_fft,
                hop_length=hop_length,
                n_mels=n_mels,
                power=2.0,
            ),
            T.AmplitudeToDB(),
        )
    return _FEATURE_TRANSFORMS[key]


def waveform_to_features(
    wav: "torch.Tensor | np.ndarray",
    sr=DEFAULT_SR,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS,
    normalize=True
):
    """Convert a 1-D waveform into a (optionally standardised) log-mel spectrogram.

    Args:
        wav: 1-D waveform as a NumPy array or torch tensor (CPU).
        sr, n_fft, hop_length, n_mels: mel-spectrogram settings.
        normalize: if True, standardise to zero mean / unit std over the whole
            spectrogram (``1e-6`` added to the std to avoid division by zero).

    Returns:
        float tensor of shape ``[1, n_mels, frames]``.
    """
    if isinstance(wav, np.ndarray):
        wav = torch.tensor(wav, dtype=torch.float32)

    mel_transform, to_db = _feature_transforms(sr, n_fft, hop_length, n_mels)

    mel_spec = mel_transform(wav.unsqueeze(0))  # [1, n_mels, T]
    log_mel = to_db(mel_spec).squeeze(0)        # [n_mels, T]

    if normalize:
        log_mel = (log_mel - log_mel.mean()) / (log_mel.std() + 1e-6)

    return log_mel.unsqueeze(0)  # [1, n_mels, T]


In [None]:
def _norm_rel(p: str) -> str:
    p = p.strip().replace("\\", "/").lstrip("./")
    if p.startswith("data/"):
        p = p.split("data/", 1)[1]
    return p

def read_pairs_to_set(txt_path: str) -> set:
    """Collect every audio path referenced by a trial-pair list.

    Each non-blank, non-``#`` line is read as ``<label> <path_a> <path_b> ...``.
    Both paths are normalised with ``_norm_rel`` and added to the result.
    Lines with fewer than three whitespace-separated fields are ignored, and a
    missing file yields an empty set rather than an error.
    """
    referenced = set()
    if not os.path.exists(txt_path):
        return referenced

    with open(txt_path, "r", encoding="utf-8") as handle:
        for raw in handle:
            line = raw.strip()
            if line == "" or line.startswith("#"):
                continue
            fields = line.split()
            if len(fields) < 3:
                continue
            referenced.update((_norm_rel(fields[1]), _norm_rel(fields[2])))
    return referenced

E_files = read_pairs_to_set(TXT_E)
H_files = read_pairs_to_set(TXT_H)
# Files referenced by either trial list are excluded from the train/val pool below.
ban = E_files | H_files
print("Files in E:", len(E_files), "| Files in H:", len(H_files), "| BAN total:", len(ban))

EXTS = (".wav", ".flac", ".mp3", ".m4a", ".ogg")
# sorted(): glob yields files in filesystem order, which differs between
# machines. A fixed order keeps the seeded train/val split reproducible.
all_abs = sorted(
    p for p in glob.glob(os.path.join(AUDIO_DIR, "**", "*"), recursive=True)
    if os.path.splitext(p)[1].lower() in EXTS
)
print("Found audio in data/:", len(all_abs))

# Normalise with the same helper used for the trial lists so both sides compare equal.
kept_abs = [p for p in all_abs
            if _norm_rel(os.path.relpath(p, AUDIO_DIR)) not in ban]
print("Kept files after ban filter:", len(kept_abs))

# Speaker id = name of the directory that directly contains the file.
spk2files_all = defaultdict(list)
for p in kept_abs:
    spk2files_all[os.path.basename(os.path.dirname(p))].append(p)
print(f"Total speakers found: {len(spk2files_all)}")

dropped_speakers = {s: fs for s, fs in spk2files_all.items() if len(fs) < MIN_UTTS_PER_SPK}
spk2files = {s: fs for s, fs in spk2files_all.items() if len(fs) >= MIN_UTTS_PER_SPK}
print(f"Dropped speakers (<{MIN_UTTS_PER_SPK} utts): {len(dropped_speakers)}")
print(f"Kept speakers: {len(spk2files)}")

In [None]:
SPLIT_TR, SPLIT_VA = SPLIT_RATIO
random.seed(123)

# Per-speaker train/val split. Speakers and their files are visited in sorted
# order so the seeded shuffle gives the same split however spk2files was built.
spk2files_tr, spk2files_va = {}, {}
for sid in sorted(spk2files):
    files = sorted(spk2files[sid])
    random.shuffle(files)
    n = len(files)
    n_tr = max(1, int(round(SPLIT_TR * n)))
    if n >= 2:
        # Leave at least one file for validation; train and val never share a file.
        n_tr = min(n_tr, n - 1)
    spk2files_tr[sid] = files[:n_tr]
    va_files = files[n_tr:]
    if va_files:  # single-file speakers get no validation entry
        spk2files_va[sid] = va_files

# Map speaker id -> integer label
spk2id = {sid: idx for idx, sid in enumerate(sorted(spk2files.keys()))}

# Build the datasets: lists of (file, integer label)
train_set = [(f, spk2id[sid]) for sid, fs in spk2files_tr.items() for f in fs]
val_set   = [(f, spk2id[sid]) for sid, fs in spk2files_va.items() for f in fs]

print(f"Train samples: {len(train_set)}, Val samples: {len(val_set)}")
print(f"Num speakers (classes): {len(spk2id)}")


In [None]:
from collections import Counter
import os

def stats(dataset, name):
    """Print a per-speaker utterance-count summary for a list of ``(file, label)`` pairs.

    Reports the total sample count, the number of distinct labels, and the
    min / middle / max utterances per label. The "Median" is the element at
    index ``len // 2`` of the sorted counts, i.e. the upper middle for even counts.
    """
    per_label = Counter(label for _, label in dataset)
    if not per_label:
        print(f"{name} - Empty dataset")
        return

    counts = sorted(per_label.values())
    middle = counts[len(counts) // 2]
    print(f"{name} - Total samples: {len(dataset)} | Speakers (classes): {len(per_label)}")
    print(f"   Utterances per speaker -> Min: {counts[0]}, Median: {middle}, Max: {counts[-1]}")

# Summarise the speaker / utterance distribution of each split.
stats(train_set, "Train")
stats(val_set, "Val")


In [None]:
def augment_waveform(wav, sr):
    if random.random() < 0.5:
        aug_type = random.choice(["noise", "gain"])

        wav_tensor = torch.tensor(wav, dtype=torch.float32)

        if aug_type == "noise":
            noise = torch.randn_like(wav_tensor) * random.uniform(0.001, 0.01)
            wav_tensor = wav_tensor + noise
        elif aug_type == "gain":
            gain = random.uniform(-6, 6)
            wav_tensor = wav_tensor * (10 ** (gain / 20))

        wav = wav_tensor.numpy()

    return wav

def spec_augment(mel):
    """SpecAugment-style masking on the last two axes of ``mel`` (frequency, time).

    A frequency-mask width is drawn in ``[0, F // 8]`` and a time-mask width in
    ``[0, T // 8]``, each placed at a uniformly random offset (Python ``random``).
    The selected frequency rows and time columns are zeroed across all leading
    dimensions. Works for ``[F, T]`` and ``[C, F, T]`` inputs; the input tensor
    is not modified (a clone is masked and returned).
    """
    mel = mel.clone()
    n_freq, n_time = mel.size(-2), mel.size(-1)
    freq_mask = random.randint(0, n_freq // 8)
    time_mask = random.randint(0, n_time // 8)
    f0 = random.randint(0, n_freq - freq_mask)
    t0 = random.randint(0, n_time - time_mask)
    # Index the last two axes explicitly. Indexing dims 0/1 on a [1, F, T]
    # tensor would mask only frequency and never time.
    mel[..., f0:f0 + freq_mask, :] = 0
    mel[..., t0:t0 + time_mask] = 0
    return mel

In [None]:
class TripletSpeakerDataset(Dataset):
    """Randomly sampled (anchor, positive, negative) feature triplets.

    The ``idx`` passed to ``__getitem__`` is ignored: every call draws a new
    triplet with Python's ``random`` module. Anchor and positive are two
    distinct files of one speaker; the negative is a file of a different
    speaker. ``__len__`` (total number of files) only sets how many triplets
    make up one DataLoader epoch.

    Each file is turned into a ``[1, 224, 224]`` tensor (before ``transform``)
    by ``_load_feature``.

    Args:
        spk2files: mapping speaker id -> list of audio file paths.
        sr, duration: forwarded to ``read_audio``.
        n_mels: number of mel bands for ``waveform_to_features``.
        augment: enable waveform augmentation and (p=0.3) SpecAugment masking.
        transform: optional callable applied to each feature tensor last.
        random_crop: random crop offset in ``read_audio`` (else start of file).
        debug: stored but currently unused.

    Raises:
        ValueError: if no speaker has at least two files (no anchor/positive
            pair possible), or fewer than two speakers have any file (no negative).
    """

    def __init__(self, spk2files, sr=DEFAULT_SR, duration=AUDIO_DURATION,
                 n_mels=N_MELS,
                 augment=True, transform=None, random_crop=True, debug=False):
        self.spk2files = spk2files
        self.speakers = list(spk2files.keys())
        # Precomputed sampling pools: sampling anchors from all speakers and
        # retrying on <2 files recursed without bound when many speakers had a
        # single file (typical for the validation split).
        self._anchor_speakers = [s for s in self.speakers if len(spk2files[s]) >= 2]
        self._neg_speakers = [s for s in self.speakers if len(spk2files[s]) >= 1]
        if not self._anchor_speakers:
            raise ValueError("TripletSpeakerDataset needs at least one speaker with >= 2 files")
        if len(self._neg_speakers) < 2:
            raise ValueError("TripletSpeakerDataset needs at least two speakers with audio files")
        self.sr = sr
        self.duration = duration
        self.n_mels = n_mels
        self.augment = augment
        self.transform = transform
        self.random_crop = random_crop
        self.debug = debug

    def __len__(self):
        return sum(len(v) for v in self.spk2files.values())

    def _load_feature(self, path):
        """Load one file and return its resized (and optionally augmented) log-mel tensor."""
        wav = read_audio(path, sr=self.sr, duration=self.duration, random_crop=self.random_crop)

        if self.augment:
            wav = augment_waveform(wav, self.sr)

        features = waveform_to_features(
            wav, sr=self.sr,
            n_mels=self.n_mels,
            normalize=True
        )  # [1, n_mels, T]

        # Resize to the 224x224 input size the backbone is used with.
        features = F.interpolate(
            features.unsqueeze(0),
            size=(224, 224),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)  # [1, 224, 224]

        if self.augment and random.random() < 0.3:
            features = spec_augment(features)

        if self.transform:
            features = self.transform(features)

        return features

    def __getitem__(self, idx):
        # idx is intentionally unused: triplets are sampled at random.
        spk = random.choice(self._anchor_speakers)
        anchor_path, pos_path = random.sample(self.spk2files[spk], 2)

        # Rejection sampling terminates: __init__ guarantees another speaker exists.
        neg_spk = random.choice(self._neg_speakers)
        while neg_spk == spk:
            neg_spk = random.choice(self._neg_speakers)
        neg_path = random.choice(self.spk2files[neg_spk])

        anchor = self._load_feature(anchor_path)
        positive = self._load_feature(pos_path)
        negative = self._load_feature(neg_path)

        return anchor, positive, negative


In [None]:
# The triplet datasets are built directly from spk2files_tr / spk2files_va, so
# only report their sizes here. train_set / val_set are deliberately not
# rebound: they keep the integer speaker labels assigned in the split cell.
n_train_files = sum(len(files) for files in spk2files_tr.values())
n_val_files = sum(len(files) for files in spk2files_va.values())

print("Train samples:", n_train_files)
print("Val samples:", n_val_files)


In [None]:
class SEModule(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Each channel is squeezed to its spatial mean. The resulting
    ``[B, C, 1, 1]`` descriptor goes through a 1x1-conv bottleneck
    (C -> C // reduction -> C) with a ReLU in between, and the input channels
    are rescaled by the sigmoid of the result.

    The sub-module attribute names (``avg_pool``, ``fc1``, ``relu``, ``fc2``,
    ``sigmoid``) determine the ``state_dict`` keys, so they must stay as they
    are for saved checkpoints to load.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, hidden, 1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(hidden, channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        descriptor = self.avg_pool(x)
        excitation = self.fc2(self.relu(self.fc1(descriptor)))
        return x * self.sigmoid(excitation)

class SERes2NetEmbedding(nn.Module):
    """Speaker-embedding network: Res2Net-50 v1b backbone + SE gate + projection head.

    Args:
        embedding_size: output embedding dimension.
        pretrained: forwarded to ``res2net50_v1b_26w_4s`` to load its pretrained weights.

    ``forward(x, return_emb=False, normalize_emb=False)``:
        x: image-like batch ``[B, C, H, W]``; single-channel input is replicated
            to 3 channels.
        return_emb: accepted for call-site compatibility; ignored (the
            embedding is always returned).
        normalize_emb: L2-normalise each embedding when True.
        Returns ``[B, embedding_size]``.
    """

    def __init__(self, embedding_size=256, pretrained=True):
        super().__init__()

        self.backbone = res2net50_v1b_26w_4s(pretrained=pretrained)
        # Replace the classifier with Identity. The backbone is then assumed to
        # emit pooled [B, 2048] features (the view in forward() relies on it).
        self.backbone.fc = nn.Identity()

        self.se_block = SEModule(2048, reduction=16)

        # NOTE(review): the ReLU before the optional L2 normalisation confines
        # embeddings to the non-negative orthant (cosine similarity >= 0).
        # Left unchanged so existing checkpoints keep their meaning.
        self.embedding = nn.Sequential(
            nn.Linear(2048, embedding_size),
            nn.BatchNorm1d(embedding_size),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )

    def forward(self, x, return_emb=False, normalize_emb=False):
        if x.size(1) == 1:
            x = x.repeat(1, 3, 1, 1)

        # Call the module rather than .forward() so registered hooks run.
        feat = self.backbone(x)
        feat = self.se_block(feat.view(feat.size(0), 2048, 1, 1))
        emb = self.embedding(feat.flatten(1))

        if normalize_emb:
            emb = F.normalize(emb, p=2, dim=1)

        return emb

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Embedding width comes from the config cell. NOTE: the resume cell at the end
# of the notebook rebuilds `model` from a checkpoint; this instance (with
# pretrained backbone weights) matters only when training from scratch.
model = SERes2NetEmbedding(embedding_size=EMBEDDING_SIZE, pretrained=True).to(device)


In [None]:
def _read_trial_pairs(txt_path):
    """Parse a trial list into ``(path_a, path_b, label)`` tuples.

    Blank lines, ``#`` comments and lines with fewer than three fields are
    skipped. ``label`` is 1 when the first field is exactly ``"1"``, else 0.
    """
    pairs = []
    with open(txt_path, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line or line.startswith("#"):
                continue
            parts = line.split()
            if len(parts) < 3:
                continue
            pairs.append((parts[1], parts[2], 1 if parts[0] == "1" else 0))
    return pairs


def _embed_eval_file(model, path, device):
    """Embed one utterance with the deterministic evaluation pipeline.

    Uses ``read_audio`` with its defaults (start-of-file crop, no augmentation)
    and a normalised log-mel resized to 224x224 and replicated to 3 channels.
    Returns the model's L2-normalised embedding for a batch of one.
    """
    wav = read_audio(path)
    feat = waveform_to_features(wav, normalize=True)  # [1, n_mels, T]
    feat = F.interpolate(
        feat.unsqueeze(0),
        size=(224, 224),
        mode="bilinear",
        align_corners=False
    )  # [1, 1, 224, 224]
    if feat.size(1) == 1:
        feat = feat.repeat(1, 3, 1, 1)
    with torch.no_grad():
        return model(feat.to(device), return_emb=True, normalize_emb=True)


def compute_eer(model, txt_path, audio_dir, device=None, verbose=True):
    """Equal error rate of ``model`` on a verification trial list.

    Args:
        model: embedding model; put in eval mode by this function.
        txt_path: trial list with ``<label> <path_a> <path_b>`` lines.
        audio_dir: directory the (``_norm_rel``-normalised) trial paths are relative to.
        device: device for inference; defaults to CUDA when available.
        verbose: print how many pairs were scored / missing / failed.

    Returns:
        ``(eer, threshold)``: the EER and the cosine-score threshold at the ROC
        point where ``|FNR - FPR|`` is smallest. Both are NaN when the scored
        pairs do not contain both target and non-target trials.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.eval()
    pairs = _read_trial_pairs(txt_path)

    # Trial lists reuse utterances across many pairs, so embed each file once.
    emb_cache = {}
    first_error = None

    def embed(path):
        nonlocal first_error
        if path not in emb_cache:
            try:
                emb_cache[path] = _embed_eval_file(model, path, device)
            except Exception as exc:  # unreadable / corrupt audio: skip its pairs
                emb_cache[path] = None
                if first_error is None:
                    first_error = f"{path}: {exc!r}"
        return emb_cache[path]

    scores, labels = [], []
    n_missing = 0
    n_failed = 0
    for f1, f2, label in pairs:
        # Same normalisation as the ban list, so "data/..." prefixes resolve.
        path1 = os.path.join(audio_dir, _norm_rel(f1))
        path2 = os.path.join(audio_dir, _norm_rel(f2))

        if not (os.path.exists(path1) and os.path.exists(path2)):
            n_missing += 1
            continue

        emb1, emb2 = embed(path1), embed(path2)
        if emb1 is None or emb2 is None:
            n_failed += 1
            continue

        scores.append(torch.cosine_similarity(emb1, emb2).item())
        labels.append(label)

    if verbose:
        msg = (f"[compute_eer] {os.path.basename(txt_path)}: scored {len(scores)}/{len(pairs)} pairs "
               f"(missing: {n_missing}, failed: {n_failed})")
        if first_error is not None:
            msg += f" | first error: {first_error}"
        print(msg)

    if len(set(labels)) < 2:
        # roc_curve needs both classes. Returning 0.0 here would read as a perfect EER.
        return float("nan"), float("nan")

    scores = np.asarray(scores)
    labels = np.asarray(labels)
    fpr, tpr, thresholds = roc_curve(labels, scores)
    fnr = 1 - tpr
    eer_idx = np.nanargmin(np.abs(fnr - fpr))
    eer = (fpr[eer_idx] + fnr[eer_idx]) / 2
    return float(eer), float(thresholds[eer_idx])


In [None]:
def _triplet_batch_loss(model, criterion, batch, device, use_amp):
    """Move one (anchor, positive, negative) batch to ``device`` and return the triplet loss.

    Each of the three tensors goes through ``model`` separately with
    L2-normalised outputs. Autocast is enabled only when ``use_amp`` is True.
    """
    anchor, positive, negative = (t.to(device, non_blocking=True) for t in batch)
    with autocast(enabled=use_amp):
        emb_a = model(anchor, return_emb=True, normalize_emb=True)
        emb_p = model(positive, return_emb=True, normalize_emb=True)
        emb_n = model(negative, return_emb=True, normalize_emb=True)
        loss = criterion(emb_a, emb_p, emb_n)
    return loss


def train_triplet_with_eer(model, spk2files_tr, spk2files_va,
                           epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR,
                           weight_decay=1e-4, margin=0.3,
                           device=None,
                           TXT_E=TXT_E, TXT_H=TXT_H, AUDIO_DIR=AUDIO_DIR):
    """Fine-tune ``model`` with a triplet-margin loss and report EER every epoch.

    Each epoch performs these steps:

    * one pass over randomly sampled training triplets (AMP on CUDA);
    * computation of the validation triplet loss;
    * a cosine-annealing LR step;
    * EER on both trial lists;
    * a checkpoint saved as ``triplet_se_res2net_epoch{epoch}.pth`` in the
      current directory.

    Args:
        model: embedding model (``forward(x, return_emb, normalize_emb)``).
        spk2files_tr, spk2files_va: speaker -> files mappings for train / val triplets.
        epochs, batch_size, lr, weight_decay: optimisation settings (Adam).
        margin: ``TripletMarginLoss`` margin.
        device: training device; defaults to CUDA when available.
        TXT_E, TXT_H, AUDIO_DIR: trial lists and audio root for ``compute_eer``.

    Returns:
        List of per-epoch dicts with keys ``epoch``, ``train_loss``,
        ``val_loss``, ``lr``, ``eer_e`` and ``eer_h``. The EER values are None
        when evaluation raised.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(device)  # also accept strings such as "cuda:0"
    model.to(device)
    # AMP, GradScaler and pinned memory only help (and are only supported) on CUDA.
    use_cuda = device.type == "cuda"

    train_ds = TripletSpeakerDataset(spk2files_tr, random_crop=True, augment=True)
    val_ds   = TripletSpeakerDataset(spk2files_va, random_crop=False, augment=False)

    # A trailing batch of exactly one triplet makes BatchNorm1d raise in train
    # mode ("Expected more than 1 value per channel"). Drop it only in that case.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=use_cuda,
                              drop_last=(len(train_ds) % batch_size == 1))
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=2, pin_memory=use_cuda)

    optimizer = torch.optim.Adam(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=epochs, eta_min=1e-6
    )
    criterion = nn.TripletMarginLoss(margin=margin, p=2)
    scaler = GradScaler(enabled=use_cuda)

    history = []
    for epoch in range(1, epochs + 1):
        model.train()
        total_loss = 0.0
        num_batches = 0
        for batch in train_loader:
            optimizer.zero_grad()
            loss = _triplet_batch_loss(model, criterion, batch, device, use_cuda)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
            num_batches += 1
        train_loss = total_loss / max(1, num_batches)

        model.eval()
        val_total = 0.0
        val_batches = 0
        with torch.no_grad():
            for batch in val_loader:
                val_total += _triplet_batch_loss(model, criterion, batch, device, use_cuda).item()
                val_batches += 1
        val_loss = val_total / max(1, val_batches)

        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]

        print(f"Epoch {epoch}/{epochs} | "
              f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}")

        eer_e = eer_h = None
        try:
            eer_e, thr_e = compute_eer(model, TXT_E, audio_dir=AUDIO_DIR, device=device)
            eer_h, thr_h = compute_eer(model, TXT_H, audio_dir=AUDIO_DIR, device=device)
            print(f"Epoch {epoch} | "
                  f"EER_E: {eer_e:.4f}, Thr_E: {thr_e:.4f} | "
                  f"EER_H: {eer_h:.4f}, Thr_H: {thr_h:.4f}")
        except Exception as e:
            print(f"[WARN] Skipped EER computation due to error: {e}")

        save_path = f"triplet_se_res2net_epoch{epoch}.pth"
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scaler_state": scaler.state_dict(),
            "scheduler_state": scheduler.state_dict(),  # needed to resume the LR schedule
        }, save_path)
        print(f"Saved checkpoint: {save_path}")

        history.append({
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
            "lr": current_lr,
            "eer_e": eer_e,
            "eer_h": eer_h,
        })

    return history


In [None]:
# Resume fine-tuning from a previously saved triplet checkpoint.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# pretrained=False: the backbone weights come from the checkpoint below.
model = SERes2NetEmbedding(embedding_size=EMBEDDING_SIZE, pretrained=False).to(device)

ckpt_path = "/kaggle/input/triplet/pytorch/default/1/triplet_se_res2net_epoch7.pth"
ckpt = torch.load(ckpt_path, map_location=device)
model.load_state_dict(ckpt["model_state"])
print(f"Loaded checkpoint from {ckpt_path}")

# NOTE: only the model weights are restored. train_triplet_with_eer builds a
# fresh optimizer / scheduler / scaler and counts epochs from 1 again.
train_triplet_with_eer(
    model,
    spk2files_tr=spk2files_tr,
    spk2files_va=spk2files_va,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LR,
    weight_decay=1e-4,
    margin=0.3,
    device=device
)
