In [None]:
!pip install -U transformers huggingface_hub httpx

In [None]:
from torchvision import datasets as dt
import torchvision as tv
import matplotlib.pyplot as plt
import torch as pt
import numpy as np
import torchaudio
import os
import random
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as F
from transformers import Wav2Vec2Model

In [None]:
import os
from pathlib import Path

def collect_audio_list(root_fpt, root_edge, root_cut):
    """Collect ``(path, label)`` pairs from the three training sources.

    Layouts read:
      * ``root_fpt/{false_no,false_yes,true}/<speaker>/*.wav`` -- label 1 only
        for the ``true`` folder; missing label folders are skipped.
      * ``root_edge/<speaker>/<label_dir>/*.mp3`` -- label 1 when
        ``label_dir`` is ``true`` (case-insensitive), else 0.
      * ``root_cut/*.mp3`` -- always label 0.
    Non-directory entries at the speaker / label-folder levels are ignored.

    Every directory listing is sorted, so the returned order is identical on
    every run and filesystem. A seeded shuffle of the result is then
    reproducible; ``iterdir``/``glob`` order alone is arbitrary.

    Args:
        root_fpt, root_edge, root_cut: root directories (str or Path).

    Returns:
        list of ``(str_path, int_label)`` tuples.
    """
    data = []

    fpt_root = Path(root_fpt)
    for sub in ["false_no", "false_yes", "true"]:
        sub_path = fpt_root / sub
        if not sub_path.exists():
            continue
        label = 1 if sub == "true" else 0
        for speaker_dir in sorted(sub_path.iterdir()):
            if not speaker_dir.is_dir():
                continue
            data.extend((str(wav), label) for wav in sorted(speaker_dir.glob("*.wav")))

    for speaker_dir in sorted(Path(root_edge).iterdir()):
        if not speaker_dir.is_dir():
            continue
        for sub in sorted(speaker_dir.iterdir()):
            if not sub.is_dir():
                continue
            label = 1 if sub.name.lower() == "true" else 0
            data.extend((str(mp3), label) for mp3 in sorted(sub.glob("*.mp3")))

    data.extend((str(mp3), 0) for mp3 in sorted(Path(root_cut).glob("*.mp3")))

    return data

# Kaggle input locations of the three training sources.
root_fpt = "/kaggle/input/voice-fpt-aip491/Data_voices/Data_voices/FPT.AI"
root_edge = "/kaggle/input/voice-fpt-aip491/Data_voices/Data_voices/edge_voices_16k"
root_cut = "/kaggle/input/voice-fpt-aip491/Data_voices/Data_voices/cut_sound"

data_train = collect_audio_list(root_fpt, root_edge, root_cut)

# Seeded shuffle, then the first 10% becomes the validation split.
# The split is only repeatable if collect_audio_list returns files in a stable order.
# NOTE(review): files are split individually, so one speaker directory can
# contribute to both splits; split by speaker if a speaker-disjoint
# validation set is intended.
random.seed(35)
random.shuffle(data_train)

n_total = len(data_train)
n_valid = int(0.1 * n_total)
data_valid, data_train = data_train[:n_valid], data_train[n_valid:]

In [None]:
def collect_test_audio_list(root_test):
    """Collect ``(path, label)`` pairs from the held-out test tree.

    Layout read: ``root_test/<speaker>/<label_dir>/*.{mp3,wav}``. A label
    directory named ``true`` (case-insensitive) gives label 1, any other name
    gives 0. Non-directory entries at the speaker and label levels are skipped.

    Directory listings are sorted so the returned order is the same on every
    run. Within each label directory the ``.mp3`` files come first, then the
    ``.wav`` files, each group in sorted order.

    Prints a one-line class count summary.

    Args:
        root_test: root directory of the test set (str or Path).

    Returns:
        list of ``(str_path, int_label)`` tuples.
    """
    data = []

    for speaker_dir in sorted(Path(root_test).iterdir()):
        if not speaker_dir.is_dir():
            continue
        for sub_dir in sorted(speaker_dir.iterdir()):
            if not sub_dir.is_dir():
                continue
            label = 1 if sub_dir.name.lower() == "true" else 0
            for pattern in ("*.mp3", "*.wav"):
                data.extend((str(f), label) for f in sorted(sub_dir.glob(pattern)))

    n_pos = sum(lbl for _, lbl in data)
    print(f"{n_pos} label=1, {len(data) - n_pos} label=0")
    return data


# Held-out test set; expected layout is <root>/<speaker>/<label_dir>/*.{mp3,wav}.
root_test = "/kaggle/input/test-aip419/Datatest"
data_test = collect_test_audio_list(root_test=root_test)


In [None]:
class SpeakerClassificationDataset(Dataset):
    """Map-style dataset of fixed-length mono waveforms and their labels.

    Each item of ``data`` is a ``(path, label)`` pair. ``__getitem__``:
      1. loads the file with ``torchaudio.load``;
      2. resamples to ``sample_rate`` when the file's native rate differs;
      3. down-mixes to mono by averaging channels;
      4. truncates / right-zero-pads to exactly ``fixed_len`` samples.

    Args:
        data: list of ``(path, label)`` tuples; ``label`` is an int index.
        fixed_len: number of samples in every returned waveform.
        sample_rate: target sample rate (Hz) that waveforms are resampled to.
        one_hot: if True, return a float32 one-hot vector instead of the int label.
        num_speakers: length of the one-hot vector; required when ``one_hot``.

    Raises:
        ValueError: if ``one_hot`` is True and ``num_speakers`` is not given.
    """

    def __init__(self, data, fixed_len=40000, sample_rate=16000, one_hot=False,
                 num_speakers=None):
        if one_hot and num_speakers is None:
            # Previously this failed later with AttributeError inside __getitem__.
            raise ValueError("num_speakers is required when one_hot=True")
        self.data = data
        self.fixed_len = fixed_len
        self.sample_rate = sample_rate
        self.one_hot = one_hot
        self.num_speakers = num_speakers

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        path, spk_id = self.data[idx]
        wav, sr = torchaudio.load(path)  # (channels, time), native rate `sr`
        if sr != self.sample_rate:
            # Honour the requested rate. Otherwise clips stored at another rate
            # (e.g. some mp3s) would reach the model mis-timed.
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)
        wav = wav.mean(dim=0)  # mono -> (time,)
        wav = self._fix_length(wav)

        if self.one_hot:
            label = pt.zeros(self.num_speakers, dtype=pt.float32)
            label[spk_id] = 1.0
        else:
            label = spk_id

        return wav, label

    def _fix_length(self, wav):
        """Truncate or right-zero-pad a 1-D waveform to ``self.fixed_len`` samples."""
        n = wav.size(0)
        if n > self.fixed_len:
            return wav[:self.fixed_len]
        if n < self.fixed_len:
            return F.pad(wav, (0, self.fixed_len - n))
        return wav

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm
import numpy as np
from sklearn.metrics import roc_curve

class Wav2Vec2ID(nn.Module):
    """Pretrained wav2vec2-base encoder followed by mean pooling and an MLP head.

    Pipeline: encoder ``last_hidden_state`` ([B, T', H]) -> mean over the time
    axis -> Linear(H, hidden_size) + BatchNorm1d + ReLU + Dropout(0.2) ->
    Linear(hidden_size, num_classes) logits.

    NOTE(review): no attention mask is passed to the encoder, so zero-padded
    input regions contribute to the pooled mean.

    Args:
        hidden_size: width of the hidden projection.
        num_classes: number of output logits.
        freeze_encoder: if True, every encoder parameter gets requires_grad=False.
    """

    def __init__(self, hidden_size=150, num_classes=900, freeze_encoder=True):
        super().__init__()
        from transformers import Wav2Vec2Model

        self.encoder = Wav2Vec2Model.from_pretrained("facebook/wav2vec2-base")
        if freeze_encoder:
            self.encoder.requires_grad_(False)

        encoder_dim = self.encoder.config.hidden_size
        self.fc_hidden = nn.Sequential(
            nn.Linear(encoder_dim, hidden_size),
            nn.BatchNorm1d(hidden_size),
            nn.ReLU(),
            nn.Dropout(0.2),
        )
        self.fc_out = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        """Return logits of shape [B, num_classes] for a batch of waveforms ``x``."""
        frames = self.encoder(x).last_hidden_state  # [B, T', H]
        embedding = self.fc_hidden(frames.mean(dim=1))
        return self.fc_out(embedding)

In [None]:
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score
import time

def train_one_epo(loader, model, optimizer, criterion, device):
    """Run one training epoch over ``loader``.

    Labels are cast to float and reshaped to [B, 1] to match single-logit
    outputs. Predictions are ``sigmoid(logits).round()``, i.e. a 0.5 threshold.

    Returns:
        (mean_loss, accuracy): loss averaged over samples in ``loader.dataset``,
        and training accuracy over the epoch.
    """
    model.train()
    running_loss = 0
    predicted, targets = [], []

    for inputs, target in tqdm(loader, desc="Train", leave=False):
        inputs = inputs.to(device)
        target = target.to(device).float().unsqueeze(1)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, target)
        batch_loss.backward()
        optimizer.step()

        # Weight by batch size so the final division gives a per-sample mean.
        running_loss += batch_loss.item() * inputs.size(0)
        predicted.extend(pt.sigmoid(logits).round().cpu().detach().numpy())
        targets.extend(target.cpu().numpy())

    mean_loss = running_loss / len(loader.dataset)
    return mean_loss, accuracy_score(targets, predicted)


def valid_at_epo(loader, model, criterion, device, threshold= 0.5):
    """Evaluate ``model`` on ``loader`` at a fixed decision threshold.

    Args:
        loader: DataLoader yielding ``(waveforms, int_labels)``.
        model: module returning one logit per sample ([B, 1]).
        criterion: loss on (logits, float labels [B, 1]).
        device: torch device (or string) to run on.
        threshold: probability cut-off for the positive class.

    Returns:
        (mean_loss, accuracy, roc_auc, f1, avg_forward_time).
        roc_auc is NaN when ``roc_auc_score`` raises (e.g. a single class).
        avg_forward_time is the mean wall-clock seconds of one batch forward pass.
    """
    model.eval()
    # CUDA kernels are launched asynchronously. Synchronize around the forward
    # call so the timer measures the computation, not just the launch.
    on_cuda = pt.device(device).type == "cuda"
    total_loss = 0
    all_preds, all_labels = [], []
    all_probs = []
    batch_times = []

    with pt.no_grad():
        for xb, yb in tqdm(loader, desc="Valid", leave=False):
            xb, yb = xb.to(device), yb.to(device).float().unsqueeze(1)

            if on_cuda:
                pt.cuda.synchronize(device)
            start = time.perf_counter()
            outputs = model(xb)
            if on_cuda:
                pt.cuda.synchronize(device)
            batch_times.append(time.perf_counter() - start)

            loss = criterion(outputs, yb)
            total_loss += loss.item() * xb.size(0)

            probs = pt.sigmoid(outputs).cpu().numpy()
            preds = (probs >= threshold).astype(int)

            all_probs.extend(probs)
            all_preds.extend(preds)
            all_labels.extend(yb.cpu().numpy())

    acc = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds)
    try:
        auc = roc_auc_score(all_labels, all_probs)
    except ValueError:
        auc = float('nan')
    avg_forward_time = sum(batch_times) / len(batch_times) if batch_times else 0.0
    return total_loss / len(loader.dataset), acc, auc, f1, avg_forward_time

def valid_at_epo_t(loader, model, criterion, device, step=0.01):
    """Evaluate ``model`` and choose the threshold maximising F1 * accuracy.

    One pass over ``loader`` collects sigmoid probabilities. The function then
    sweeps the thresholds ``i * step`` for ``i = 0 .. int(1 / step)`` and keeps
    the first threshold with the largest F1 * accuracy product.

    NOTE(review): if ``loader`` is the test set, the threshold is tuned on the
    same data it is scored on, so the returned accuracy / F1 are optimistic.
    Tune on validation data and apply the threshold to test instead.

    Args:
        loader: DataLoader yielding ``(waveforms, int_labels)``.
        model: module returning one logit per sample ([B, 1]).
        criterion: loss on (logits, float labels [B, 1]).
        device: torch device (or string) to run on.
        step: spacing of the threshold sweep.

    Returns:
        (mean_loss, best_acc, auc, best_f1, best_thr, avg_forward_time).
        If no threshold gives a positive product (e.g. F1 is 0 everywhere),
        best_thr is 0.5 and best_acc / best_f1 are the metrics measured at 0.5.
        auc is NaN when ``roc_auc_score`` raises (e.g. a single class).
    """
    model.eval()
    # CUDA kernels are launched asynchronously. Synchronize around the forward
    # call so the timer measures the computation, not just the launch.
    on_cuda = pt.device(device).type == "cuda"
    total_loss = 0
    all_labels, all_probs = [], []
    batch_times = []

    with pt.no_grad():
        for xb, yb in tqdm(loader, desc="Test", leave=False):
            xb, yb = xb.to(device), yb.to(device).float().unsqueeze(1)

            if on_cuda:
                pt.cuda.synchronize(device)
            start = time.perf_counter()
            outputs = model(xb)
            if on_cuda:
                pt.cuda.synchronize(device)
            batch_times.append(time.perf_counter() - start)

            loss = criterion(outputs, yb)
            total_loss += loss.item() * xb.size(0)

            probs = pt.sigmoid(outputs).cpu().numpy()
            all_probs.extend(probs)
            all_labels.extend(yb.cpu().numpy())

    # Flatten once. The old code rebuilt a tensor from a list of arrays for
    # every threshold.
    probs_arr = np.asarray(all_probs, dtype=np.float32).reshape(-1)
    labels_arr = np.asarray(all_labels).reshape(-1)

    thresholds = [i * step for i in range(int(1 / step) + 1)]
    best_s, best_f1, best_thr, best_acc = 0, 0, 0.5, 0

    for thr in thresholds:
        preds = (probs_arr >= thr).astype(int)
        f1 = f1_score(labels_arr, preds)
        acc = accuracy_score(labels_arr, preds)
        if f1 * acc > best_s:
            best_f1, best_thr, best_acc = f1, thr, acc
            best_s = f1 * acc

    if best_s == 0:
        # Nothing beat zero, so best_thr is still the 0.5 default. Report the
        # metrics actually obtained there instead of placeholder zeros.
        preds = (probs_arr >= best_thr).astype(int)
        best_f1 = f1_score(labels_arr, preds)
        best_acc = accuracy_score(labels_arr, preds)

    try:
        auc = roc_auc_score(labels_arr, probs_arr)
    except ValueError:
        auc = float("nan")

    avg_forward_time = sum(batch_times) / len(batch_times) if batch_times else 0.0

    return total_loss / len(loader.dataset), best_acc, auc, best_f1, best_thr, avg_forward_time

In [None]:
batch_s = 110

# Every split uses the same preprocessing: 40000 samples at 16 kHz (2.5 s).
tra_ds = SpeakerClassificationDataset(data_train, fixed_len=40000, sample_rate=16000)
# The classifier head contains BatchNorm1d, which raises in train mode on a
# 1-sample batch. Drop the final batch only when it would have exactly one sample.
tra_dl = DataLoader(tra_ds, batch_size=batch_s, shuffle=True,
                    drop_last=len(tra_ds) % batch_s == 1)

val_ds = SpeakerClassificationDataset(data_valid, fixed_len=40000, sample_rate=16000)
val_dl = DataLoader(val_ds, batch_size=batch_s, shuffle=False)

tes_ds = SpeakerClassificationDataset(data_test, fixed_len=40000, sample_rate=16000)
tes_dl = DataLoader(tes_ds, batch_size=batch_s, shuffle=False)

In [None]:
device = pt.device("cuda" if pt.cuda.is_available() else "cpu")

# Seed torch so the head initialisation and DataLoader shuffling are
# reproducible. This matches the seeded random.shuffle used for the
# train/validation split.
pt.manual_seed(35)

# Start with the wav2vec2 encoder frozen; only the head trains at first.
model = Wav2Vec2ID(num_classes=1, freeze_encoder=True).to(device)
optimizer = pt.optim.Adam(model.parameters(), lr=2e-4)
# One logit per clip -> binary cross-entropy on raw logits.
criterion = pt.nn.BCEWithLogitsLoss()

# Per-epoch metric history.
# Prefixes: tr = train, va = validation, te = test at threshold 0.5,
#           te_b = test at the best swept threshold.
# Suffixes: lo/ac/au/f1/ti/th = loss/accuracy/AUC/F1/forward-time/threshold.
list_name = []  # NOTE(review): not used in the cells shown
list_tr_lo = list()
list_tr_ac = list()
list_va_lo = list()
list_va_ac = list()
list_va_au = list()
list_va_f1 = list()
list_va_ti = list()
list_te_ac = list()
list_te_au = list()
list_te_f1 = list()
list_te_b_ac = list()
list_te_b_au = list()
list_te_b_f1 = list()
list_te_b_th = list()

In [None]:
# Checkpoint directory; loop-invariant, so create it once up front.
os.makedirs("w2v", exist_ok=True)

for tn in range(25):

    print(tn)

    if tn == 1:
        # After one epoch training only the head, unfreeze the encoder and
        # re-create Adam with a lower learning rate for full fine-tuning.
        for p in model.encoder.parameters():
            p.requires_grad = True

        optimizer = pt.optim.Adam(model.parameters(), lr=4e-5)

    train_loss, train_acc = train_one_epo(tra_dl, model, optimizer, criterion, device)
    valid_loss, valid_acc, valid_auc, valid_f1, valid_time = valid_at_epo(val_dl, model, criterion, device, 0.5)
    _, test_acc, test_auc, test_f1, _ = valid_at_epo(tes_dl, model, criterion, device, 0.5)
    # NOTE(review): the "best" threshold is tuned on the test set itself, so
    # test_b_* are optimistic; tune on val_dl and apply to tes_dl instead.
    _, test_b_acc, test_b_auc, test_b_f1, test_b_thr, _ = valid_at_epo_t(tes_dl, model, criterion, device)

    list_tr_lo.append(train_loss)
    list_tr_ac.append(train_acc)

    list_va_lo.append(valid_loss)
    list_va_ac.append(valid_acc)
    list_va_au.append(valid_auc)
    list_va_f1.append(valid_f1)
    list_va_ti.append(valid_time)

    list_te_ac.append(test_acc)
    list_te_au.append(test_auc)
    list_te_f1.append(test_f1)

    list_te_b_ac.append(test_b_acc)
    list_te_b_au.append(test_b_auc)
    list_te_b_f1.append(test_b_f1)
    list_te_b_th.append(test_b_thr)

    print(f"  Train -> loss: {train_loss:.4f}, acc: {train_acc:.4f}")
    print(f"  Valid -> loss: {valid_loss:.4f}, acc: {valid_acc:.4f}, auc: {valid_auc:.4f}, f1: {valid_f1:.4f}, time: {valid_time:.4f}s")
    print(f"  Test -> acc: {test_acc:.4f}, auc: {test_auc:.4f}, f1: {test_f1:.4f}")
    print(f"  Test best  -> acc: {test_b_acc:.4f}, auc: {test_b_auc:.4f}, f1: {test_b_f1:.4f}, th: {test_b_thr:.2f}")
    # One full state_dict per epoch: w2v/<epoch>.pt
    pt.save(model.state_dict(), os.path.join("w2v", f"{tn}.pt"))