In [None]:
import os
import subprocess
import sys

# Clone the Res2Net reference implementation once. Re-running the cell must be a
# no-op instead of attempting a second clone into an existing directory.
RES2NET_REPO_URL = "https://github.com/Res2Net/Res2Net-PretrainedModels.git"
RES2NET_DIR = "/kaggle/working/Res2Net-PretrainedModels"
if not os.path.isdir(RES2NET_DIR):
    subprocess.run(["git", "clone", RES2NET_REPO_URL, RES2NET_DIR], check=True)

# Make `res2net_v1b` importable without piling up duplicate sys.path entries on re-runs.
if RES2NET_DIR not in sys.path:
    sys.path.append(RES2NET_DIR)

In [None]:
# Exploration helper: list the function/class definitions in res2net_v1b.py to see
# which constructors are available (the notebook imports res2net50_v1b_26w_4s).
!grep -E "def |class " /kaggle/working/Res2Net-PretrainedModels/res2net_v1b.py


In [None]:
import os
import argparse
import random
from pathlib import Path
from typing import List, Tuple

import numpy as np
import soundfile as sf
import librosa

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
import torchaudio.transforms as T

from scipy.optimize import brentq
from scipy.interpolate import interp1d

from sklearn.metrics import roc_curve

from tqdm import tqdm
import glob

from collections import defaultdict
from collections import Counter

from torch.cuda.amp import autocast, GradScaler

from res2net_v1b import res2net50_v1b_26w_4s

In [None]:
# ------------------------- Config / Hyperparams -------------------------
# Audio front-end
DEFAULT_SR = 16000        # target sample rate (Hz); files at other rates are resampled
AUDIO_DURATION = 2.5      # seconds per clip; longer audio is cropped, shorter zero-padded
N_MELS = 128
N_FFT = 1024
HOP_LENGTH = 256

EMBEDDING_SIZE = 256

# Optimisation
EPOCHS = 20
BATCH_SIZE = 64
LR = 1e-3

SPLIT_RATIO = (0.8, 0.2)  # (train, val) fractions for the speaker-level split
MIN_UTTS_PER_SPK = 2      # speakers with fewer files are dropped; a triplet needs 2 per speaker

# Reproducibility: seed every RNG the pipeline draws from (Python `random` for triplet
# sampling/augmentation, NumPy for random crops, torch for init, noise and DataLoader workers).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Data locations (Kaggle input mounts)
TRAIN_ROOT = "/kaggle/input/voxvn-api491/train_small_wav"
TEST_ROOT = "/kaggle/input/voxvietnam"
AUDIO_DIR_TEST = os.path.join(TEST_ROOT, "wav", "wav")

In [None]:
def read_audio(path: str, sr: int = DEFAULT_SR, duration: float = AUDIO_DURATION, random_crop: bool = False):
    """Load an audio file as a fixed-length mono float32 waveform.

    Args:
        path: audio file readable by ``soundfile``.
        sr: target sample rate; audio stored at a different rate is resampled with librosa.
        duration: clip length in seconds; the result has ``int(sr * duration)`` samples.
        random_crop: if True, longer audio is cropped at a uniformly random offset
            (NumPy global RNG); otherwise the first ``duration`` seconds are kept.

    Returns:
        1-D ``np.float32`` array. Audio shorter than ``duration`` is zero-padded at the end.
    """
    wav, orig_sr = sf.read(path)
    # Multi-channel audio is assumed to be (frames, channels); average channels to mono.
    if wav.ndim > 1:
        wav = np.mean(wav, axis=1)

    if orig_sr != sr:
        # librosa >= 0.10 makes orig_sr/target_sr keyword-only, so the positional form
        # raises TypeError there. The keyword form also works on older versions.
        wav = librosa.resample(wav.astype('float32'), orig_sr=orig_sr, target_sr=sr)

    target_len = int(sr * duration)
    if len(wav) >= target_len:
        start = np.random.randint(0, len(wav) - target_len + 1) if random_crop else 0
        wav = wav[start:start + target_len]
    else:
        wav = np.pad(wav, (0, target_len - len(wav)), mode='constant')

    return wav.astype('float32')

In [None]:
def waveform_to_features(
    wav: torch.Tensor,
    sr=DEFAULT_SR,
    n_fft=N_FFT,
    hop_length=HOP_LENGTH,
    n_mels=N_MELS,
    normalize=True
):
    """Turn a 1-D waveform into a log-mel spectrogram of shape ``(1, n_mels, frames)``.

    ``wav`` may be a torch tensor or a NumPy array (a NumPy array is copied to float32).
    The power mel spectrogram is converted to decibels. With ``normalize`` set, it is
    standardised to zero mean / unit std over the whole utterance.
    """
    signal = torch.tensor(wav, dtype=torch.float32) if isinstance(wav, np.ndarray) else wav

    to_mel = T.MelSpectrogram(
        sample_rate=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        power=2.0
    )
    to_db = T.AmplitudeToDB()

    # Add a leading channel axis for the transforms, then drop it again.
    log_mel = to_db(to_mel(signal.unsqueeze(0))).squeeze(0)

    if normalize:
        centre = log_mel.mean()
        spread = log_mel.std() + 1e-6
        log_mel = (log_mel - centre) / spread

    return log_mel.unsqueeze(0)

In [None]:
def scan_dataset(root, min_utts_per_spk=MIN_UTTS_PER_SPK):
    """Index ``root/<speaker_dir>/*.wav`` files by integer speaker id.

    Speaker ids are positions in the sorted list of speaker directory names.
    Returns a dict ``{speaker_id: [wav paths]}``, keeping only speakers that have at
    least ``min_utts_per_spk`` files (threshold taken from the config by default).
    """
    wav_paths = glob.glob(os.path.join(root, "*", "*.wav"))

    def speaker_of(wav_path):
        return os.path.basename(os.path.dirname(wav_path))

    speaker_names = sorted({speaker_of(p) for p in wav_paths})
    name_to_id = dict(zip(speaker_names, range(len(speaker_names))))

    grouped = defaultdict(list)
    for wav_path in wav_paths:
        grouped[name_to_id[speaker_of(wav_path)]].append(wav_path)

    return {spk_id: files for spk_id, files in grouped.items() if len(files) >= min_utts_per_spk}


spk2files = scan_dataset(TRAIN_ROOT)
# Sort before shuffling so the split depends only on the RNG state, not on the
# (filesystem-dependent) order in which glob returned the files.
valid_spks = sorted(spk2files.keys())
random.shuffle(valid_spks)
# Speaker-level split: no speaker appears in both train and val.
split_idx = int(len(valid_spks) * SPLIT_RATIO[0])

spk2files_train = {k: spk2files[k] for k in valid_spks[:split_idx]}
spk2files_val = {k: spk2files[k] for k in valid_spks[split_idx:]}

# Triplet sampling draws negatives from *other* speakers, so each side needs >= 2 speakers.
assert len(spk2files_train) >= 2 and len(spk2files_val) >= 2, (
    f"Speaker split too small: {len(spk2files_train)} train / {len(spk2files_val)} val speakers"
)

In [None]:
# Pick the first file of the first training speaker and load it as a quick sanity check.
train_speaker_ids = list(spk2files_train)
sample_file = spk2files_train[train_speaker_ids[0]][0]
wav = read_audio(sample_file)

In [None]:
# Inspect the log-mel features of the sample: shape and (normalised) statistics.
features = waveform_to_features(wav)
feat_mean, feat_std = features.mean().item(), features.std().item()
print("Feature shape:", features.shape)
print("Mean / Std:", feat_mean, feat_std)

In [None]:
class TripletSpeakerDataset(Dataset):
    """Random (anchor, positive, negative) feature triplets for triplet-loss training.

    ``spk2files`` maps a speaker id to a list of audio paths. Each item is three tensors
    built by ``_load_feature``: anchor and positive are two different files of one
    speaker, and the negative is a file of a different speaker.

    ``__getitem__`` ignores ``idx`` and draws a fresh random triplet on every call, so
    ``__len__`` (total number of files) only sets how many triplets make up an epoch.
    """

    def __init__(self, spk2files, sr=DEFAULT_SR, duration=AUDIO_DURATION,
                 n_mels=N_MELS,
                 augment=True, transform=None, random_crop=True, debug=False):
        self.spk2files = spk2files
        self.speakers = list(spk2files.keys())
        # Only speakers with >= 2 files can supply a distinct anchor/positive pair.
        self.anchor_speakers = [s for s in self.speakers if len(spk2files[s]) >= 2]
        self.sr = sr
        self.duration = duration
        self.n_mels = n_mels
        self.augment = augment
        self.transform = transform
        self.random_crop = random_crop
        self.debug = debug  # currently unused

    def __len__(self):
        return sum(len(v) for v in self.spk2files.values())

    def _load_feature(self, path):
        """Load ``path`` and return a ``(3, 224, 224)`` normalised log-mel tensor."""
        wav = read_audio(path, sr=self.sr, duration=self.duration, random_crop=self.random_crop)

        if self.augment:
            wav = augment_waveform(wav, self.sr)

        features = waveform_to_features(
            wav, sr=self.sr,
            n_mels=self.n_mels,
            normalize=True
        )

        # Tile the single mel channel to 3 channels and resize to the 224x224 image
        # size the ImageNet-style backbone is built around.
        features = features.repeat(3, 1, 1)

        features = F.interpolate(
            features.unsqueeze(0),
            size=(224, 224),
            mode='bilinear',
            align_corners=False
        ).squeeze(0)

        if self.augment and random.random() < 0.3:
            features = spec_augment(features)

        if self.transform:
            features = self.transform(features)

        return features

    def __getitem__(self, idx):
        """Return a random triplet; ``idx`` is ignored.

        Raises:
            ValueError: if no speaker has two files or there is only one speaker,
                because no valid triplet can be drawn. Previously this recursed
                forever or failed with an opaque IndexError.
        """
        if not self.anchor_speakers:
            raise ValueError("TripletSpeakerDataset: no speaker has at least 2 files")
        if len(self.speakers) < 2:
            raise ValueError("TripletSpeakerDataset: need at least 2 speakers to draw a negative")

        spk = random.choice(self.anchor_speakers)
        anchor_path, pos_path = random.sample(self.spk2files[spk], 2)

        neg_spk = random.choice([s for s in self.speakers if s != spk])
        neg_path = random.choice(self.spk2files[neg_spk])

        anchor = self._load_feature(anchor_path)
        positive = self._load_feature(pos_path)
        negative = self._load_feature(neg_path)

        return anchor, positive, negative


In [None]:
from torch.utils.data import DataLoader

# Smoke test: pull one un-augmented batch through the data pipeline and check its shape.
# next(iter(...)) fails loudly on an empty loader, whereas a `for ...: break` loop
# would silently leave the batch variables undefined.
train_ds = TripletSpeakerDataset(spk2files_train, random_crop=True, augment=False)
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True)
a, p, n = next(iter(train_loader))
assert a.shape[1:] == p.shape[1:] == n.shape[1:] == (3, 224, 224), (a.shape, p.shape, n.shape)

In [None]:
def augment_waveform(wav, sr):
    """Randomly perturb a NumPy waveform (applied with probability 0.5).

    When applied, picks one of:
      * "noise": add Gaussian noise with std drawn uniformly from [0.001, 0.01];
      * "gain":  scale by a gain drawn uniformly from [-6, 6] dB.
    The result is clipped to [-1, 1] and returned as a new float32 array. When not
    applied, the input object is returned unchanged. ``sr`` is accepted but unused.
    """
    if random.random() >= 0.5:
        return wav

    kind = random.choice(["noise", "gain"])
    # torch.tensor copies, so the caller's array is never modified.
    signal = torch.tensor(wav, dtype=torch.float32)

    if kind == "noise":
        signal = signal + torch.randn_like(signal) * random.uniform(0.001, 0.01)
    elif kind == "gain":
        gain_db = random.uniform(-6, 6)
        signal = signal * 10 ** (gain_db / 20)

    return signal.clamp(-1.0, 1.0).numpy()

def spec_augment(mel):
    """SpecAugment-style masking: zero one random frequency band and one random time span.

    The last two axes are treated as (freq, time). The function therefore works on a
    2-D ``(freq, time)`` spectrogram and on a ``(channels, freq, time)`` stack such as
    the 3-channel tensors built by the dataset, and every channel gets the same masks.
    (Indexing dims 0/1 instead would wipe a whole channel of a 3-D input and mask the
    frequency axis as if it were time.) Each mask is between 1 and ``max(1, size // 8)``
    bins wide. Returns a masked copy; the input is left untouched.
    """
    mel = mel.clone()
    n_freq, n_time = mel.size(-2), mel.size(-1)
    freq_mask = random.randint(1, max(1, n_freq // 8))
    time_mask = random.randint(1, max(1, n_time // 8))
    f0 = random.randint(0, n_freq - freq_mask)
    t0 = random.randint(0, n_time - time_mask)
    mel[..., f0:f0 + freq_mask, :] = 0
    mel[..., t0:t0 + time_mask] = 0
    return mel

In [None]:
class SEModule(nn.Module):
    """Squeeze-and-excitation gate over the channel axis of a 4-D tensor.

    Global-average-pools ``x`` to one value per channel, then passes it through a
    1x1-conv bottleneck (``channels -> channels // reduction -> channels``) with a
    ReLU in between. A sigmoid squashes the result, which rescales ``x`` channel-wise.
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = channels // reduction
        # Attribute names and creation order are kept stable so state_dict keys
        # (fc1.*, fc2.*) and parameter initialisation order stay the same.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, hidden, 1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(hidden, channels, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        squeezed = self.avg_pool(x)
        gate = self.sigmoid(self.fc2(self.relu(self.fc1(squeezed))))
        return x * gate

class SERes2NetEmbedding(nn.Module):
    """Speaker-embedding network: Res2Net-50 v1b trunk + SE gate + projection head.

    The backbone's ``fc`` is replaced with ``nn.Identity`` so that its forward returns
    the pooled feature vector. That vector is reshaped to ``(batch, 2048, 1, 1)`` (so the
    trunk must emit 2048-d features), gated by an ``SEModule``, flattened, and projected
    to ``embedding_size`` by Linear -> BatchNorm1d -> ReLU -> Dropout(0.3).

    NOTE(review): because the head ends in ReLU, every embedding is non-negative, so
    cosine similarities between embeddings can never be negative. Consider revisiting
    this if verification scores look compressed.
    """

    def __init__(self, embedding_size=EMBEDDING_SIZE, pretrained=True):
        super().__init__()
        feat_dim = 2048

        trunk = res2net50_v1b_26w_4s(pretrained=pretrained)
        trunk.fc = nn.Identity()  # keep pooled features, drop the classification layer
        self.backbone = trunk

        self.se_block = SEModule(feat_dim, reduction=16)

        self.embedding = nn.Sequential(
            nn.Linear(feat_dim, embedding_size),
            nn.BatchNorm1d(embedding_size),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
        )

    def forward(self, x, return_emb=False, normalize_emb=False):
        """Embed a ``(batch, channels, H, W)`` input.

        ``return_emb`` is accepted for call-site compatibility but ignored; the embedding
        is always returned. With ``normalize_emb`` set, it is L2-normalised per row.
        """
        # Single-channel inputs are tiled to the 3 channels the backbone expects.
        if x.shape[1] == 1:
            x = x.repeat(1, 3, 1, 1)

        pooled = self.backbone.forward(x)
        gated = self.se_block(pooled.view(pooled.size(0), 2048, 1, 1))
        emb = self.embedding(gated.view(gated.size(0), -1))

        return F.normalize(emb, p=2, dim=1) if normalize_emb else emb

# NOTE(review): this builds the model with pretrained=True (the Res2Net constructor
# presumably fetches pretrained weights), but the training-driver cell rebuilds `model`
# with pretrained=False before training. This instance is never trained.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SERes2NetEmbedding(embedding_size=256, pretrained=True).to(device)

In [None]:
def load_voxvietnam_pairs(root):
    """Parse the VoxVietnam trial list into ``(path1, path2, label)`` tuples.

    Reads ``root/test_list_gt.csv`` if it exists, otherwise ``root/test_list.txt``.
    Quotes are stripped, and commas/tabs/spaces all act as separators, so CSV and
    whitespace-separated lists parse the same way. Each record must have exactly three
    tokens ``label path1 path2`` with an integer label. A record with fewer than three
    tokens is assumed to have spilled onto the following lines, and tokens are pulled
    from those lines until there are at least three. Records with the wrong token count
    or a non-integer label (e.g. a header row) are reported and skipped.

    Raises:
        AssertionError: if neither trial-list file exists.
    """
    csv_path = os.path.join(root, "test_list_gt.csv")
    txt_path = os.path.join(root, "test_list.txt")
    path = csv_path if os.path.exists(csv_path) else txt_path
    assert os.path.exists(path), f"Không tìm thấy test list: {csv_path} | {txt_path}"

    def _tokens(text):
        # Normalise quotes/commas/tabs away so every supported layout splits identically.
        cleaned = text.strip().replace('"', '').replace(',', ' ').replace('\t', ' ')
        return cleaned.split()

    pairs = []
    with open(path, "r", encoding="utf-8-sig") as f:
        for raw in f:
            parts = _tokens(raw)
            if not parts:
                continue
            # Only a *short* record can continue on the next lines. A record with too
            # many tokens is simply malformed and must not consume the valid lines after it.
            while len(parts) < 3:
                nxt = f.readline()
                if not nxt:
                    break
                parts.extend(_tokens(nxt))
            if len(parts) != 3:
                print(f"Bỏ qua dòng lỗi: {raw.strip()}")
                continue

            label_str, f1, f2 = parts
            try:
                label = int(label_str)
            except ValueError:
                print(f"Bỏ qua dòng lỗi (label không hợp lệ): {raw.strip()}")
                continue
            pairs.append((f1, f2, label))

    return pairs

In [None]:
def compute_eer_test(model, pairs, audio_dir=AUDIO_DIR_TEST, device=None):
    """Score verification trials by cosine similarity and return the equal error rate.

    Args:
        model: embedding network, called as ``model(x, return_emb=True, normalize_emb=True)``
            on a single ``(1, 3, 224, 224)`` feature tensor.
        pairs: iterable of ``(path1, path2, label)``. Paths are joined onto ``audio_dir``,
            and ``label`` is the true class passed to ``sklearn.metrics.roc_curve``
            (presumably 1 = same speaker, 0 = different; confirm against the trial list).
        audio_dir: directory the trial paths are relative to.
        device: torch device for inference; defaults to CUDA when available.

    Returns:
        ``(eer, threshold)``: the EER, estimated at the ROC point where |FNR - FPR| is
        smallest, and the score threshold at that point.

    Raises:
        ValueError: if no trial could be scored, or the scored trials have only one label.

    Trials whose audio fails to load or embed are reported and skipped. Each file is
    embedded once and cached. The model is switched to eval mode for scoring and
    restored to its previous train/eval mode afterwards.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    was_training = model.training
    model.eval()
    embeddings_cache = {}

    def get_embedding(path):
        if path in embeddings_cache:
            return embeddings_cache[path]

        full_path = os.path.join(audio_dir, path)
        wav = read_audio(full_path)  # read the (deterministically cropped) waveform
        feat = waveform_to_features(wav, normalize=True)
        # Same 3-channel / 224x224 layout the training dataset produces.
        if feat.size(0) == 1:
            feat = feat.repeat(3, 1, 1)
        feat = F.interpolate(feat.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False).squeeze(0)
        feat = feat.unsqueeze(0).to(device)

        with torch.no_grad():
            emb = model(feat, return_emb=True, normalize_emb=True)
        emb = emb.cpu().numpy().flatten()
        embeddings_cache[path] = emb
        return emb

    scores, targets = [], []
    try:
        for f1, f2, label in pairs:
            try:
                emb1 = get_embedding(f1)
                emb2 = get_embedding(f2)
                sim = np.dot(emb1, emb2) / (np.linalg.norm(emb1) * np.linalg.norm(emb2) + 1e-6)
                scores.append(sim)
                targets.append(label)
            except Exception as e:
                # Best-effort: one unreadable file should not abort the whole evaluation.
                print(f"Lỗi khi xử lý {f1}, {f2}: {e}")
                continue
    finally:
        # Leave the model in whatever mode the caller had it in.
        model.train(was_training)

    if not scores:
        raise ValueError("compute_eer_test: no trial pair could be scored")
    if len(set(targets)) < 2:
        raise ValueError("compute_eer_test: EER needs both positive and negative trials")

    scores = np.array(scores)
    targets = np.array(targets)

    fpr, tpr, thresholds = roc_curve(targets, scores)
    fnr = 1 - tpr
    eer_idx = np.nanargmin(np.abs(fnr - fpr))
    eer = (fpr[eer_idx] + fnr[eer_idx]) / 2
    thr = thresholds[eer_idx]

    print(f"EER: {eer:.4f}, Threshold: {thr:.4f}")
    return eer, thr


In [None]:
def _triplet_batch_loss(model, criterion, batch, device, amp_enabled):
    """Move one (anchor, positive, negative) batch to ``device`` and return its triplet loss."""
    anchor, positive, negative = (t.to(device, non_blocking=True) for t in batch)
    with autocast(enabled=amp_enabled):
        emb_a = model(anchor, return_emb=True, normalize_emb=True)
        emb_p = model(positive, return_emb=True, normalize_emb=True)
        emb_n = model(negative, return_emb=True, normalize_emb=True)
        return criterion(emb_a, emb_p, emb_n)


def train_triplet_with_eer(model, spk2files_tr, spk2files_va,
                            pairs=None, test_root=TEST_ROOT,
                            epochs=EPOCHS, batch_size=BATCH_SIZE, lr=LR,
                            weight_decay=1e-4, margin=0.3,
                            device=None):
    """Train ``model`` with triplet loss, reporting val loss and test-set EER each epoch.

    Args:
        model: embedding network, called as ``model(x, return_emb=True, normalize_emb=True)``.
        spk2files_tr / spk2files_va: ``{speaker_id: [paths]}`` for the train / val splits.
        pairs: trial list ``(path1, path2, label)`` for EER. If None, it is loaded from
            ``test_root`` and the "wav/" prefix is stripped, mirroring the driver cell.
        test_root: test-set root. Trial audio is read from ``test_root/wav/wav``.
        epochs, batch_size, lr, weight_decay: optimisation settings (Adam + cosine LR decay).
        margin: triplet-loss margin on L2-normalised embeddings.
        device: torch device (or device string); defaults to CUDA when available.

    Side effects:
        Writes ``triplet_se_res2net_epoch{N}.pth`` after every epoch, containing the
        model, optimizer, AMP scaler and LR scheduler state.

    Validation loss is stochastic, because the triplet dataset samples random triplets
    on every access.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = torch.device(device)
    model.to(device)
    # torch.cuda.amp only does anything on CUDA; disable it elsewhere to avoid warnings.
    amp_enabled = device.type == "cuda"

    train_ds = TripletSpeakerDataset(spk2files_tr, random_crop=True, augment=True)
    val_ds = TripletSpeakerDataset(spk2files_va, random_crop=False, augment=False)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
                            num_workers=2, pin_memory=True)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-6)

    criterion = nn.TripletMarginLoss(margin=margin, p=2)
    scaler = GradScaler(enabled=amp_enabled)

    # Derive the trial-audio directory from test_root (same layout as AUDIO_DIR_TEST)
    # so a non-default test_root is honoured consistently.
    audio_dir = os.path.join(test_root, "wav", "wav")
    if pairs is None:
        pairs = [(f1.replace("wav/", ""), f2.replace("wav/", ""), label)
                 for f1, f2, label in load_voxvietnam_pairs(test_root)]

    for epoch in range(1, epochs + 1):
        model.train()
        total_loss, num_batches = 0.0, 0
        for batch in train_loader:
            optimizer.zero_grad()
            loss = _triplet_batch_loss(model, criterion, batch, device, amp_enabled)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            total_loss += loss.item()
            num_batches += 1
        train_loss = total_loss / max(1, num_batches)

        model.eval()
        val_loss, val_batches = 0.0, 0
        with torch.no_grad():
            for batch in val_loader:
                val_loss += _triplet_batch_loss(model, criterion, batch, device, amp_enabled).item()
                val_batches += 1
        val_loss /= max(1, val_batches)

        scheduler.step()
        current_lr = scheduler.get_last_lr()[0]

        print(f"Epoch {epoch}/{epochs} | "
              f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | LR: {current_lr:.6f}")

        try:
            eer, thr = compute_eer_test(model, pairs, audio_dir=audio_dir, device=device)
            print(f"Epoch {epoch}/{epochs} | EER-Test: {eer:.4f} | Threshold: {thr:.4f}")
        except Exception as e:
            # EER is diagnostic only; a failed evaluation must not stop training.
            print(f"Skipped EER-Test computation due to error: {e}")

        save_path = f"triplet_se_res2net_epoch{epoch}.pth"
        torch.save({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "scaler_state": scaler.state_dict(),
            "scheduler_state": scheduler.state_dict(),
        }, save_path)
        print(f"Saved checkpoint: {save_path}")


In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(review): the backbone is built with pretrained=False here, while the earlier
# model cell used pretrained=True. Confirm which initialisation is intended.
model = SERes2NetEmbedding(embedding_size=EMBEDDING_SIZE, pretrained=False).to(device)

# Strip "wav/" from trial paths before they are joined onto AUDIO_DIR_TEST.
pairs = [
    (first.replace("wav/", ""), second.replace("wav/", ""), label)
    for first, second, label in load_voxvietnam_pairs(TEST_ROOT)
]

train_triplet_with_eer(
    model,
    spk2files_tr=spk2files_train,
    spk2files_va=spk2files_val,
    pairs=pairs,
    device=device,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    lr=LR,
    weight_decay=1e-4,
    margin=0.3,
)