In [1]:
import os                       # for working with files

import numpy as np              # for numerical computationss
import pandas as pd             # for working with dataframes
import matplotlib.pyplot as plt # for plotting informations on graph and images using tensors
from PIL import Image
import json

import torch                    # Pytorch module 
import torch.nn as nn           # for creating  neural networks
import torch.optim as optim
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader # for dataloaders 
from torch.utils.data import random_split
import torch.nn.functional as F

import speechbrain as sb
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm


In [2]:
from model import SASVModel
from metrics import compute_eer

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


In [4]:
from dataloader import SASVDataset
# Root of the VLSP 2025 SASV training split. Set the SASV_TRAIN_ROOT environment
# variable to point elsewhere instead of editing this machine-specific path.
DATA_ROOT = os.environ.get("SASV_TRAIN_ROOT", "D:/home4/vuhl/VSASV-Dataset/vlsp2025/train")
VAL_FRACTION = 0.1  # 10% of the samples are held out for validation
SPLIT_SEED = 42     # fixed seed so the train/val split is identical across runs

sasv_dataset = SASVDataset(data_root=DATA_ROOT)

total_size = len(sasv_dataset)
val_size = int(VAL_FRACTION * total_size)
train_size = total_size - val_size

# Without an explicit generator, random_split draws from the global RNG, so every
# kernel restart would produce a different split and runs would not be comparable.
train_dataset, val_dataset = random_split(
    sasv_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)
print(len(train_dataset), len(val_dataset))

91231 10136


In [5]:
def custom_collate_fn(batch):
    """Collate a list of SASV samples into padded batch tensors.

    Each sample is a mapping with the keys ``'tst_waveform'``, ``'enr_feature'``,
    ``'tst_feature'`` and ``'label'``. The three tensor fields are right-padded
    with zeros (the ``pad_sequence`` default) to the longest entry in the batch.

    Returns:
        ``(tst_waveforms, enr_features, tst_features, labels)``: the first three
        have the batch on dim 0 and are padded along dim 1; ``labels`` is a
        ``torch.long`` tensor with one entry per sample.
    """
    def _pad(key):
        # Gather one field across the batch and zero-pad it to a common length.
        return pad_sequence([sample[key] for sample in batch], batch_first=True)

    labels = torch.tensor([sample['label'] for sample in batch], dtype=torch.long)
    return _pad('tst_waveform'), _pad('enr_feature'), _pad('tst_feature'), labels

In [6]:
BATCH_SIZE = 4  # shared by the training and validation loaders

train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,                  # new sample order every epoch
    collate_fn=custom_collate_fn,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    collate_fn=custom_collate_fn,  # no shuffle: fixed evaluation order
)

In [7]:
# Hyper-parameters for the "AASIST" architecture, passed to SASVModel under the
# "model_config" key. NOTE(review): the meaning of each field (e.g. nb_samp,
# filts, pool_ratios) is defined by SASVModel/AASIST, not here — confirm there.
aasist_config = {
    "architecture": "AASIST",
    "nb_samp": 64600,
    "first_conv": 128,
    "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
    "gat_dims": [64, 32],
    "pool_ratios": [0.5, 0.7, 0.5, 0.5],
    "temperatures": [2.0, 2.0, 100.0, 100.0],
}
config_dict = {"model_config": aasist_config}

model = SASVModel(config_dict).to(device)

In [None]:
def train_step(model, batch, criterion, optimizer, device):
    """Run one optimisation step on a single training batch.

    Args:
        model: network called as ``model(tst_waveform, enr_feature, tst_feature)``;
            presumably returns logits of shape [B, 2] — confirm against SASVModel.
        batch: 4-tuple ``(tst_waveform, enr_feature, tst_feature, label)`` such as
            the one returned by ``custom_collate_fn``.
        criterion: loss applied to ``(logits, label)``.
        optimizer: optimizer whose ``step`` updates ``model``'s parameters.
        device: device every batch tensor is moved to.

    Returns:
        ``(loss, accuracy)``: the loss as a Python float and the fraction of
        argmax predictions equal to ``label``.
    """
    model.train()

    # Unpack the batch and move every tensor to the target device.
    waveforms, enrol_feats, test_feats, targets = (t.to(device) for t in batch)

    logits = model(waveforms, enrol_feats, test_feats)
    batch_loss = criterion(logits, targets)

    optimizer.zero_grad()
    batch_loss.backward()
    # Cap the global gradient L2 norm at 1.0 before updating the weights.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    optimizer.step()

    hits = (logits.argmax(dim=1) == targets).sum().item()
    return batch_loss.item(), hits / targets.size(0)

def val_step(model, batch, criterion, device):
    """Evaluate ``model`` on one validation batch without updating its weights.

    Args:
        model: network called as ``model(tst_waveform, enr_feature, tst_feature)``;
            its output is treated as 2-class logits (softmax over dim 1, class 1 used
            as the "target" score).
        batch: 4-tuple ``(tst_waveform, enr_feature, tst_feature, label)``.
        criterion: loss applied to ``(logits, label)``.
        device: device every batch tensor is moved to.

    Returns:
        ``(loss, accuracy, eer)``: the loss as a float, the fraction of argmax
        predictions equal to ``label``, and the equal error rate of this batch's
        class-1 softmax scores. ``eer`` is ``nan`` when the batch contains a single
        class: EER needs both target and non-target trials, and small batches
        can easily contain only one. Even when defined, a per-batch EER over a
        handful of trials is very noisy; prefer an EER over the whole split.
    """
    model.eval()

    with torch.no_grad():
        # Unpack the batch and move it to the target device.
        tst_waveform, enr_feature, tst_feature, label = batch
        tst_waveform = tst_waveform.to(device)
        enr_feature = enr_feature.to(device)
        tst_feature = tst_feature.to(device)
        label = label.to(device)

        output = model(tst_waveform, enr_feature, tst_feature)

        loss = criterion(output, label)

        predicted = torch.argmax(output, dim=1)
        correct = (predicted == label).sum().item()
        accuracy = correct / label.size(0)

        # Softmax probability of class 1 (the "target" class) is the score.
        scores = F.softmax(output, dim=1)[:, 1]

        # Move to CPU/NumPy for the EER computation.
        labels_np = label.cpu().numpy()
        if np.unique(labels_np).size < 2:
            eer = float("nan")  # undefined with only one class present
        else:
            eer = compute_eer(scores.cpu().numpy(), labels_np)

    return loss.item(), accuracy, eer

In [None]:
def _validate_epoch(model, val_loader, criterion, device, desc):
    """Run one no-grad pass over ``val_loader``; return ``(mean loss, accuracy, EER)``.

    Each batch is forwarded exactly once. Loss and accuracy are averaged per
    sample (weighted by batch size), so a short final batch does not skew them.
    EER is computed once over the class-1 softmax scores of the whole split. All
    three values are ``nan`` for an empty loader. EER alone is ``nan`` when the
    split contains a single class, since EER is undefined in that case.
    """
    model.eval()
    loss_sum = 0.0
    correct = 0
    seen = 0
    score_chunks = []
    label_chunks = []

    with torch.no_grad():
        for tst_waveform, enr_feature, tst_feature, label in tqdm(val_loader, desc=desc, leave=False):
            label = label.to(device)
            output = model(
                tst_waveform.to(device),
                enr_feature.to(device),
                tst_feature.to(device),
            )
            batch_size = label.size(0)
            # Assumes criterion averages over the batch (CrossEntropyLoss default),
            # so multiply back to a per-batch sum before averaging over samples.
            loss_sum += criterion(output, label).item() * batch_size
            correct += (output.argmax(dim=1) == label).sum().item()
            seen += batch_size
            # Softmax probability of class 1 is used as the "target" score.
            score_chunks.append(F.softmax(output, dim=1)[:, 1].cpu())
            label_chunks.append(label.cpu())

    if seen == 0:
        return float("nan"), float("nan"), float("nan")

    scores = torch.cat(score_chunks).numpy()
    labels = torch.cat(label_chunks).numpy()
    eer = compute_eer(scores, labels) if np.unique(labels).size > 1 else float("nan")
    return loss_sum / seen, correct / seen, eer


def train_model(model, train_loader, val_loader, criterion, optimizer, device, num_epochs, scheduler=None):
    """Train ``model`` for ``num_epochs`` epochs, validating after each epoch.

    Each epoch runs ``train_step`` on every batch of ``train_loader``. It then
    makes one pass over ``val_loader`` that yields validation loss, accuracy and
    an EER computed over the whole validation split.

    Args:
        model: network called as ``model(tst_waveform, enr_feature, tst_feature)``.
        train_loader: iterable of ``(tst_waveform, enr_feature, tst_feature, label)``
            batches used for training.
        val_loader: iterable of batches in the same format, used for validation.
        criterion: loss applied to ``(logits, label)``; assumed to average over the batch.
        optimizer: optimizer passed through to ``train_step``.
        device: device tensors are moved to.
        num_epochs: number of epochs to run.
        scheduler: optional LR scheduler stepped once per epoch. A
            ``ReduceLROnPlateau`` is stepped with the mean validation loss; any
            other scheduler is stepped without arguments.

    Returns:
        A list with one dict per epoch, with the keys ``train_loss``, ``train_acc``,
        ``val_loss``, ``val_acc``, ``val_eer`` and ``lr``.
    """
    history = []
    for epoch in range(num_epochs):
        tag = f"{epoch + 1}/{num_epochs}"
        print(f"\nEpoch [{tag}]")

        # ========== Training ==========
        model.train()
        train_loss_sum = 0.0
        train_acc_sum = 0.0
        train_seen = 0

        train_bar = tqdm(train_loader, desc=f"Epoch {tag} [train]", leave=False)
        for batch in train_bar:
            loss, acc = train_step(model, batch, criterion, optimizer, device)
            # Weight by batch size (labels are the last batch element) so a
            # smaller final batch is not over-counted.
            batch_size = batch[-1].size(0)
            train_loss_sum += loss * batch_size
            train_acc_sum += acc * batch_size
            train_seen += batch_size
            train_bar.set_postfix(loss=loss)

        avg_train_loss = train_loss_sum / train_seen if train_seen else float("nan")
        avg_train_acc = train_acc_sum / train_seen if train_seen else float("nan")

        # ========== Validation ==========
        avg_val_loss, avg_val_acc, val_eer = _validate_epoch(
            model, val_loader, criterion, device, desc=f"Epoch {tag} [val]"
        )

        # ========== LR scheduling ==========
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(avg_val_loss)
            else:
                scheduler.step()
        current_lr = optimizer.param_groups[0]["lr"]

        print(f"Train Loss: {avg_train_loss:.4f} | Train Acc: {avg_train_acc:.4f}")
        print(f"Val Loss:   {avg_val_loss:.4f} | Val Acc:   {avg_val_acc:.4f} | Val EER: {val_eer:.4f} | LR: {current_lr:.2e}")

        history.append({
            "train_loss": avg_train_loss,
            "train_acc": avg_train_acc,
            "val_loss": avg_val_loss,
            "val_acc": avg_val_acc,
            "val_eer": val_eer,
            "lr": current_lr,
        })
    return history

In [12]:
from torch.optim.lr_scheduler import ReduceLROnPlateau

optimizer = torch.optim.AdamW(model.parameters(), lr=0.0001, weight_decay=0.001)

# Halve the LR after 2 epochs without improvement of a minimised metric.
# `verbose` is intentionally omitted: it is deprecated in recent PyTorch releases
# (inspect optimizer.param_groups[0]["lr"] or scheduler.get_last_lr() instead).
# NOTE(review): nothing in this cell calls scheduler.step(), so as written the
# scheduler never changes the learning rate. It must be stepped once per epoch
# with the validation loss to take effect.
scheduler = ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=2)
criterion = nn.CrossEntropyLoss()

train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    num_epochs=10,
)


Epoch [1/10]


                                                                            

KeyboardInterrupt: 