In [None]:
!pip install -U transformers huggingface_hub httpx

In [None]:
from torchvision import datasets
import torchvision
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torchaudio
import os
import random
from torch.utils.data import Dataset, DataLoader, random_split
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
import copy

# Dataset root. list_all_wavs() treats every sub-directory of this folder as one
# speaker and collects the ".wav" files directly inside it.
root = "/kaggle/input/vietnam-celeb-dataset/full-dataset/data"

In [None]:
import os

def list_all_wavs(root, except_path=()):
    """Collect every ``.wav`` file stored one directory level below ``root``.

    Each sub-directory of ``root`` is scanned (non-recursively) for files whose
    name ends with ``.wav``; plain files sitting directly in ``root`` are ignored.

    Args:
        root: Directory that contains one folder per speaker.
        except_path: Iterable of relative paths (``"<folder>/<file>.wav"``) to
            leave out of the result. ``None`` or an empty iterable excludes nothing.

    Returns:
        List of ``"<folder>/<file>.wav"`` strings, ordered by folder name and
        then by file name.
    """
    # Membership is tested once per file, so make sure the lookup is O(1)
    # even when the caller hands over a list.
    if isinstance(except_path, (set, frozenset)):
        excluded = except_path
    else:
        excluded = set(except_path or ())

    wav_paths = []
    for spk_folder in sorted(os.listdir(root)):
        spk_dir = os.path.join(root, spk_folder)
        if not os.path.isdir(spk_dir):
            continue
        for fname in sorted(os.listdir(spk_dir)):
            if not fname.endswith(".wav"):
                continue
            # Relative paths always use "/" so they match the trial-list files.
            rel_path = f"{spk_folder}/{fname}"
            if rel_path not in excluded:
                wav_paths.append(rel_path)
    return wav_paths


def read_txt_paths(txt_file, valid_paths):
    """Parse a tab-separated trial list, keeping only trials whose files are known.

    Every non-blank line must hold exactly three tab-separated fields:
    label, first path, second path.

    Args:
        txt_file: Path of the trial list to read.
        valid_paths: Container supporting ``in``; a trial is kept only when
            both of its paths are members.

    Returns:
        Tuple ``(data, num_removed, used_paths)`` where ``data`` is a list of
        ``(int_label, path1, path2)`` for the kept trials, ``num_removed`` is the
        number of trials dropped because a path was not in ``valid_paths``, and
        ``used_paths`` is the set of all paths appearing in kept trials.

    Raises:
        ValueError: If a non-blank line does not have exactly three fields, or
            if the label of a kept trial is not an integer.
    """
    data = []
    used_paths = set()
    num_removed = 0

    with open(txt_file, "r") as f:
        for line_no, raw_line in enumerate(f, start=1):
            line = raw_line.strip()
            if not line:
                continue

            parts = line.split("\t")
            if len(parts) != 3:
                # Say where the problem is; the bare line alone is hard to locate.
                raise ValueError(
                    f"{txt_file}:{line_no}: expected 3 tab-separated fields "
                    f"(label, path1, path2), got {len(parts)}: {line!r}"
                )
            label, path1, path2 = parts

            if path1 not in valid_paths or path2 not in valid_paths:
                num_removed += 1
                continue

            data.append((int(label), path1, path2))
            used_paths.update((path1, path2))

    return data, num_removed, used_paths


# Every utterance found on disk, as "<folder>/<file>.wav" relative paths.
a_files = set(list_all_wavs(root, []))
print(len(a_files))

# "-e" trial list: kept trials, number of dropped trials, paths used by kept trials.
e_files, e_r, e_e = read_txt_paths(
    "/kaggle/input/vietnam-celeb-dataset/full-dataset/vietnam-celeb-e.txt",
    a_files,
)
print(len(e_files), e_r)

# "-h" trial list, parsed the same way.
h_files, h_r, h_e = read_txt_paths(
    "/kaggle/input/vietnam-celeb-dataset/full-dataset/vietnam-celeb-h.txt",
    a_files,
)
print(len(h_files), h_r)

# Number of utterances that appear in neither trial list.
print(len(a_files.difference(e_e, h_e)))

print(e_files[:3])

In [None]:
def assign_labels_from_ids(file_list):
    """Number speakers in order of first appearance and label every path.

    The speaker key of a path is the text before its first ``"/"``.

    Args:
        file_list: Iterable of relative path strings.

    Returns:
        Tuple ``(data, spk2label)``: ``data`` is a list of ``(label, path)``
        in input order and ``spk2label`` maps each speaker key to its integer label.
    """
    speaker_ids = {}
    labelled = []

    for rel_path in file_list:
        speaker = rel_path.split("/", 1)[0]
        # len() is evaluated before the insert, so an unseen speaker receives
        # the next free index and a known one keeps the index it already has.
        speaker_id = speaker_ids.setdefault(speaker, len(speaker_ids))
        labelled.append((speaker_id, rel_path))

    return labelled, speaker_ids

# Training pool: every utterance on disk except those used by either trial list.
all_files, all_spk = assign_labels_from_ids(
    list_all_wavs(root, e_e.union(h_e))
)
print(len(all_files), len(all_spk))
print(all_files[:5])

In [None]:
class SpeakerClassificationFeatureDataset(Dataset):
    """Fixed-length mono waveforms paired with a speaker label.

    Args:
        root: Directory that the relative paths in ``file_list`` are joined onto.
        file_list: Iterable of ``(label, relative_path)`` pairs.
        fixed_len: Number of samples every returned waveform has.
        one_hot: If true, labels are returned as a float one-hot vector of
            length ``num_speakers``; otherwise the raw label is returned.
        augment: If true, random gain and additive Gaussian noise are each
            applied with probability 0.5.
        sample_rate: Target sampling rate; files at another rate are resampled.
        si: Stored on the instance; not used by this class.
    """

    def __init__(self, root, file_list, fixed_len=64000, one_hot=True,
                 augment=False, sample_rate=16000, si=16):

        self.root = root
        self.fixed_len = fixed_len
        self.one_hot = one_hot
        self.augment = augment
        self.sample_rate = sample_rate
        self.si = si

        self.data = [(label, os.path.join(root, path)) for label, path in file_list]
        # Derived from the materialised list so a one-shot iterable still works;
        # an empty list yields 0 speakers instead of raising.
        self.num_speakers = max((label for label, _ in self.data), default=-1) + 1
        # Resample modules keyed by (source_rate, target_rate), built on first use.
        self._resamplers = {}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Load item ``idx`` and return ``(waveform, label)``."""
        label, path = self.data[idx]

        wav, sr = torchaudio.load(path)
        wav = wav.mean(dim=0)  # average over dim 0 (channels) -> 1-D signal

        if sr != self.sample_rate:
            wav = self._resample(wav, sr)

        if self.augment:
            if random.random() < 0.5:
                wav = wav * random.uniform(0.9, 1.1)  # random gain
            if random.random() < 0.5:
                wav = wav + torch.randn_like(wav) * 0.005  # additive noise

        wav = self._fix_length(wav)

        if self.one_hot:
            lbl = torch.zeros(self.num_speakers)
            lbl[label] = 1.0
        else:
            lbl = label

        return wav, lbl

    def _resample(self, wav, sr):
        """Resample a 1-D waveform from ``sr`` to ``self.sample_rate``."""
        key = (sr, self.sample_rate)
        resampler = self._resamplers.get(key)
        if resampler is None:
            # Building the transform is the expensive part; reuse it afterwards.
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            self._resamplers[key] = resampler
        return resampler(wav.unsqueeze(0)).squeeze(0)

    def _fix_length(self, wav):
        """Centre-crop or right-pad with zeros so ``wav`` has ``fixed_len`` samples."""
        L = wav.size(0)
        if L > self.fixed_len:
            start = (L - self.fixed_len) // 2
            return wav[start:start + self.fixed_len]
        elif L < self.fixed_len:
            return F.pad(wav, (0, self.fixed_len - L))
        return wav


class SiameseSpeakerFeatureDataset(Dataset):
    """Pairs of fixed-length mono waveforms with a per-pair label.

    Args:
        root: Directory that the relative paths in ``file_list`` are joined onto.
        file_list: Indexable sequence of ``(label, rel_path1, rel_path2)``.
        fixed_len: Number of samples every returned waveform has.
        sample_rate: Target sampling rate; files at another rate are resampled.
        si: Stored on the instance; not used by this class.
    """

    def __init__(self, root, file_list, fixed_len=64000, sample_rate=16000, si=16):
        self.root = root
        self.file_list = file_list
        self.fixed_len = fixed_len
        self.sample_rate = sample_rate
        self.si = si
        # Resample modules keyed by (source_rate, target_rate), built on first use.
        self._resamplers = {}

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(wav1, wav2, label)`` with the label as a float32 scalar tensor."""
        label, rel1, rel2 = self.file_list[idx]
        wav1 = self._load(rel1)
        wav2 = self._load(rel2)
        return wav1, wav2, torch.tensor(label, dtype=torch.float32)

    def _load(self, rel_path):
        """Load one file as a mono waveform at ``sample_rate`` with ``fixed_len`` samples."""
        wav, sr = torchaudio.load(os.path.join(self.root, rel_path))
        wav = wav.mean(dim=0)  # average over dim 0 (channels) -> 1-D signal

        if sr != self.sample_rate:
            key = (sr, self.sample_rate)
            resampler = self._resamplers.get(key)
            if resampler is None:
                # Building the transform is the expensive part; reuse it afterwards.
                resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
                self._resamplers[key] = resampler
            wav = resampler(wav.unsqueeze(0)).squeeze(0)

        return self._fix_length(wav)

    def _fix_length(self, wav):
        """Centre-crop or right-pad with zeros so ``wav`` has ``fixed_len`` samples."""
        L = wav.size(0)
        if L > self.fixed_len:
            start = (L - self.fixed_len) // 2
            return wav[start:start + self.fixed_len]
        elif L < self.fixed_len:
            return F.pad(wav, (0, self.fixed_len - L))
        return wav

def siamese_collate_fn(batch):
    """Collate ``(wav1, wav2, label)`` samples into three batched tensors.

    Returns:
        ``(first, second, labels)``: the first and second waveforms stacked
        along a new leading dimension, and the labels as a float tensor.
    """
    first_wavs, second_wavs, targets = zip(*batch)
    return (
        torch.stack(first_wavs),
        torch.stack(second_wavs),
        torch.tensor(targets).float(),
    )

def collate_fn_classification(batch):
    """Collate ``(waveform, label)`` samples into a batch.

    Returns:
        ``(wavs, labels)``: waveforms stacked along a new leading dimension
        and the labels as an int64 tensor.
    """
    waveforms = []
    targets = []
    for sample in batch:
        waveforms.append(sample[0])
        targets.append(sample[1])
    return torch.stack(waveforms, dim=0), torch.tensor(targets, dtype=torch.long)

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from sklearn.metrics import roc_curve

class Wav2Vec2ID(nn.Module):
    """Pretrained wav2vec 2.0 encoder with a small embedding head and a classifier.

    Args:
        hidden_size: Width of the embedding produced by the head.
        num_classes: Output size of the final linear classifier.
        freeze_encoder: If true, gradients are disabled for every encoder parameter.
        pretrained_name: Checkpoint identifier passed to
            ``Wav2Vec2Model.from_pretrained``.

    NOTE(review): freezing only switches off gradients. Calling ``train()`` on
    this module still puts the encoder into training mode as well, so any
    train-time-only behaviour inside the encoder stays active — confirm that
    this is intended.
    """

    def __init__(self, hidden_size=1024, num_classes=900, freeze_encoder=True,
                 pretrained_name="facebook/wav2vec2-base"):
        super().__init__()
        # Imported here so that defining the class does not require transformers.
        from transformers import Wav2Vec2Model
        self.encoder = Wav2Vec2Model.from_pretrained(pretrained_name)
        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False
        self.fc_hidden = nn.Sequential(
            nn.Linear(self.encoder.config.hidden_size, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Dropout(p=0.2)
        )
        self.fc_out = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return ``(emb, logits)`` for a batch of waveforms ``x``.

        ``emb`` is the head output computed from the encoder's last hidden
        state averaged over dim 1; ``logits`` is the classifier output on ``emb``.
        """
        out = self.encoder(x).last_hidden_state
        pooled = out.mean(dim=1)  # average over dim 1 (all positions, unmasked)
        emb = self.fc_hidden(pooled)
        logits = self.fc_out(emb)
        return emb, logits


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AAMSoftmaxLoss(nn.Module):
    """Additive angular margin softmax loss with its own class-weight matrix.

    Args:
        num_classes: Number of classes (rows of the weight matrix).
        emb_dim: Dimensionality of the incoming embeddings.
        margin: Angle in radians added to the target-class angle.
        scale: Factor applied to all logits before cross-entropy.

    ``margin`` and ``scale`` are plain attributes read on every forward pass,
    so they can be reassigned between calls.
    """

    def __init__(self, num_classes, emb_dim=256, margin=0.3, scale=30):
        super().__init__()
        self.num_classes = num_classes
        self.margin = margin
        self.scale = scale
        # Xavier init overwrites the values, so allocate without sampling first.
        self.weight = nn.Parameter(torch.empty(num_classes, emb_dim))
        nn.init.xavier_normal_(self.weight)

    def forward(self, emb, labels):
        """Return ``(loss, logits)``.

        Args:
            emb: Float tensor ``[B, emb_dim]``.
            labels: Int64 tensor ``[B]`` with values in ``[0, num_classes)``.

        Returns:
            ``loss`` is the mean cross-entropy over the batch; ``logits`` is
            ``[B, num_classes]`` and already contains the margin on the target
            class and the scale.
        """
        import math  # function-scope import keeps this definition self-contained

        emb = F.normalize(emb)
        W = F.normalize(self.weight)

        cosine = F.linear(emb, W)  # [B, num_classes]

        margin = float(self.margin)
        # Clamp keeps acos (and its gradient) finite at exactly +/-1.
        theta = torch.acos(torch.clamp(cosine, -1.0 + 1e-7, 1.0 - 1e-7))
        shifted = theta + margin
        # cos() starts rising again past pi, which would turn the margin into a
        # reward for the worst-aligned samples. Beyond that point use the
        # standard linear penalty, which keeps decreasing with the cosine.
        target_logit = torch.where(
            shifted < math.pi,
            torch.cos(shifted),
            cosine - margin * math.sin(margin),
        )

        is_target = F.one_hot(labels, num_classes=self.num_classes).bool()
        # Out-of-place multiply: the scaled tensor is returned to the caller.
        logits = torch.where(is_target, target_logit, cosine) * self.scale

        loss = F.cross_entropy(logits, labels)
        return loss, logits

def train_one_epoch_am(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    ``model(x)`` must return ``(emb, _)`` and ``criterion(emb, labels)`` must
    return ``(loss, logits)``.

    Returns:
        ``(avg_loss, acc)``: the sample-weighted mean loss and the fraction of
        samples whose arg-max over the criterion's logits equals the label.
    """
    model.train()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for wavs, targets in tqdm(loader, desc="Train", leave=False):
        wavs = wavs.to(device)
        targets = targets.to(device)
        batch_size = wavs.size(0)

        optimizer.zero_grad()

        emb, _ = model(wavs)
        loss, logits = criterion(emb, targets)

        loss.backward()
        optimizer.step()

        # The criterion's loss is a batch mean; weight it by the batch size
        # so a smaller final batch does not skew the epoch average.
        loss_sum += loss.item() * batch_size
        n_correct += (logits.argmax(dim=1) == targets).sum().item()
        n_seen += batch_size

    return loss_sum / n_seen, n_correct / n_seen

def validate_one_epoch_am(model, loader, device, normalize_emb=True):
    """Score every pair in ``loader`` and return the equal error rate.

    ``model(x)`` must return ``(emb, _)``. Each pair is scored with the cosine
    similarity of its two embeddings.

    Args:
        model: Embedding model; switched to eval mode.
        loader: Iterable of ``(anchor, pair, label)`` batches.
        device: Device the batches are moved to.
        normalize_emb: If true, embeddings are L2-normalised along dim 1
            before scoring.

    Returns:
        The mean of the false-positive and false-negative rates at the ROC
        point where their absolute difference is smallest.
    """
    model.eval()
    score_chunks = []
    label_chunks = []

    with torch.no_grad():
        for anchor, pair, label in tqdm(loader, desc="Valid", leave=False):
            anchor, pair, label = anchor.to(device), pair.to(device), label.to(device)

            emb_a, _ = model(anchor)
            emb_b, _ = model(pair)

            if normalize_emb:
                emb_a = F.normalize(emb_a, dim=1)
                emb_b = F.normalize(emb_b, dim=1)

            score_chunks.append(F.cosine_similarity(emb_a, emb_b).cpu())
            label_chunks.append(label.cpu())

    scores = torch.cat(score_chunks).numpy()
    targets = torch.cat(label_chunks).numpy()

    fpr, tpr, _ = roc_curve(targets, scores)
    fnr = 1 - tpr
    # The operating point closest to FPR == FNR approximates the EER.
    best = np.nanargmin(np.abs(fnr - fpr))
    return (fpr[best] + fnr[best]) / 2.0


In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

batch_s = 64

# Training set: integer labels (one_hot=False) to match collate_fn_classification,
# with the dataset's random gain/noise augmentation switched on.
tra_ds = SpeakerClassificationFeatureDataset(
    root,
    all_files,
    fixed_len=40000,
    one_hot=False,
    augment=True,
)
tra_dl = DataLoader(
    tra_ds,
    batch_size=batch_s,
    shuffle=True,
    collate_fn=collate_fn_classification,
)

# Verification pairs from the "-e" trial list.
et_ds = SiameseSpeakerFeatureDataset(root, e_files, fixed_len=40000)
et_dl = DataLoader(
    et_ds,
    batch_size=batch_s,
    shuffle=True,
    collate_fn=siamese_collate_fn,
)

# Verification pairs from the "-h" trial list.
ht_ds = SiameseSpeakerFeatureDataset(root, h_files, fixed_len=40000)
ht_dl = DataLoader(
    ht_ds,
    batch_size=batch_s,
    shuffle=True,
    collate_fn=siamese_collate_fn,
)

# First epoch index the training loop starts from.
start_epoch = 0

In [None]:
# NOTE(review): 905 is hard-coded for both the model and the loss; labels must be
# smaller than this value. Presumably it equals len(all_spk) — confirm.
model = Wav2Vec2ID(hidden_size=512, num_classes=905, freeze_encoder=True)
model = model.to(device)

# emb_dim must equal the model's hidden_size; margin/scale here are only the
# starting values and are reassigned by the training loop.
criterion = AAMSoftmaxLoss(num_classes=905, emb_dim=512, margin=0.025, scale=2)
criterion = criterion.to(device)

# The loss owns a trainable weight matrix, so its parameters are optimised too.
optimizer = torch.optim.Adam(
    [*model.parameters(), *criterion.parameters()],
    lr=2e-4,
)

In [None]:
for epoch in range(start_epoch, 25):

    # Margin and scale grow linearly with the (1-based) epoch number.
    criterion.margin = (epoch + 1)*0.025
    criterion.scale = (epoch + 1)*2

    tr_loss, tr_acc = train_one_epoch_am(model, tra_dl, criterion, optimizer, device)

    # Evaluate after the first epoch and after every second epoch.
    if epoch == 0 or (epoch + 1)%2 == 0:
        e_eer = validate_one_epoch_am(model, et_dl, device)
        h_eer = validate_one_epoch_am(model, ht_dl, device)
    else:
        # Not evaluated this epoch: NaN prints as "nan%" rather than a
        # misleading perfect "0.00%".
        e_eer = float("nan")
        h_eer = float("nan")

    print(f"Epoch {epoch+1}: Train Loss={tr_loss:.4f}, Train Acc={tr_acc:.4f}, E EER={e_eer*100:.2f}%, H EER={h_eer*100:.2f}%")

    # Human-readable summary of the optimiser hyper-parameters.
    optimizer_info = {
        'param_groups': [
            {k: v for k, v in group.items() if k in ['lr', 'betas', 'weight_decay']}
            for group in optimizer.param_groups
        ]
    }
    # File name uses the 0-based epoch index.
    ckpt_path = f"wav2vec2_vnc_{epoch}.pt"
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "criterion_state": criterion.state_dict(),
        "optimizer_info": optimizer_info,
        # Full optimiser state so a resumed run keeps Adam's moment estimates.
        "optimizer_state": optimizer.state_dict(),
    }, ckpt_path)
