In [55]:
import requests, pathlib, os, zipfile, gdown, time
import torch, torchaudio, librosa, soundfile, math, numpy, random
import torch.nn.functional as F

from pathlib import Path
from torch import nn, Tensor
from tqdm import tqdm 

# Create Train/Val

In [56]:
# import os
# import random

# def split_speakers_with_spoof(base_dir, train_list_file, val_list_file, train_ratio=0.8, seed=42):
#     random.seed(seed)

#     # Lấy toàn bộ speakers có folder spoof (collect every speaker whose spoof folder is non-empty)
#     speakers = []
#     for d in os.listdir(base_dir):
#         spk_path = os.path.join(base_dir, d)
#         spoof_path = os.path.join(spk_path, "spoof")
#         if os.path.isdir(spk_path) and os.path.exists(spoof_path) and len(os.listdir(spoof_path)) > 0:
#             speakers.append(d)

#     speakers.sort()

#     # Shuffle và split (shuffle the speakers, then split them into train/val)
#     random.shuffle(speakers)
#     n_train = int(len(speakers) * train_ratio)
#     train_speakers = speakers[:n_train]
#     val_speakers = speakers[n_train:]

#     print(f"Total speakers with spoof: {len(speakers)} | Train: {len(train_speakers)} | Val: {len(val_speakers)}")

#     def write_list(speaker_subset, output_file):
#         with open(output_file, "w") as f:
#             for speaker in speaker_subset:
#                 spk_path = os.path.join(base_dir, speaker)
#                 for label in ["bonafide", "spoof"]:
#                     label_path = os.path.join(spk_path, label)
#                     if not os.path.exists(label_path):
#                         continue
#                     for wav in os.listdir(label_path):
#                         if wav.endswith(".wav"):
#                             filepath = os.path.join(label_path, wav)
#                             f.write(f"{filepath} {label}\n")

#     # Ghi file train/val (write the train/val list files)
#     write_list(train_speakers, train_list_file)
#     write_list(val_speakers, val_list_file)

# # Ví dụ dùng (usage example)
# split_speakers_with_spoof("/kaggle/input/vlsp-train/home4/vuhl/VSASV-Dataset/vlsp2025/train", "train_list.txt", "val_list.txt", train_ratio=0.8)


# AASIST Model

In [57]:
from typing import Union
import numpy as np

class GraphAttentionLayer(nn.Module):
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)

        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x):
        '''
        x   :(#bs, #node, #dim)
        '''
        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x)

        # projection
        x = self._project(x, att_map)

        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)
        return x

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map(self, x):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)
        att_map = torch.matmul(att_map, self.att_weight)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _apply_BN(self, x):
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out


class HtrgGraphAttentionLayer(nn.Module):
    """Heterogeneous graph attention layer over two node types plus a master node.

    ``x1`` (type 1) and ``x2`` (type 2) are projected by type-specific linear layers
    and concatenated into one fully connected graph. Pairwise attention logits use a
    weight vector that depends on the edge type:
    - ``att_weight11`` for type1-type1 edges;
    - ``att_weight22`` for type2-type2 edges;
    - ``att_weight12`` for both cross-type directions (shared).

    A master node attends over all nodes through its own projections
    (``att_projM``, ``att_weightM``, ``proj_with_attM``, ``proj_without_attM``).

    Args:
        in_dim: feature size of the input nodes (and of ``master`` when one is given).
        out_dim: feature size of the output nodes and of the returned master node.
        temperature (keyword, optional): divisor applied to every attention logit
            before the softmax; defaults to 1.0.
    """
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # type-specific input projections (keep in_dim)
        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)

        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)   # node-node pairs
        self.att_projM = nn.Linear(in_dim, out_dim)  # node-master pairs

        # per-edge-type logit weights (out_dim -> 1); 12 is shared by both cross directions
        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)

        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x1, x2, master=None):
        '''
        x1      :(#bs, #node1, in_dim)
        x2      :(#bs, #node2, in_dim)
        master  :(#bs or 1, 1, in_dim) or None; when None, the mean of the projected
                 nodes (before dropout) is used.
        returns :x1 (#bs, #node1, out_dim), x2 (#bs, #node2, out_dim),
                 master (#bs, 1, out_dim)
        '''
        num_type1 = x1.size(1)
        num_type2 = x2.size(1)

        x1 = self.proj_type1(x1)
        x2 = self.proj_type2(x2)

        # one graph: type-1 nodes occupy [0, num_type1), type-2 nodes follow
        x = torch.cat([x1, x2], dim=1)

        if master is None:
            master = torch.mean(x, dim=1, keepdim=True)

        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x, num_type1, num_type2)

        # directional edge for master node (master reads the nodes; nodes do not read master)
        master = self._update_master(x, master)

        # projection
        x = self._project(x, att_map)

        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)

        # split the graph back into the two node types
        x1 = x.narrow(1, 0, num_type1)
        x2 = x.narrow(1, num_type1, num_type2)

        return x1, x2, master

    def _update_master(self, x, master):
        # attention of the master over all nodes, then attended + residual projection

        att_map = self._derive_att_map_master(x, master)
        master = self._project_master(x, master, att_map)

        return master

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map_master(self, x, master):
        '''
        x           :(#bs, #node, #dim)
        master      :(#bs or 1, 1, #dim)
        out_shape   :(#bs, #node, 1), softmax-normalised over the nodes
        '''
        att_map = x * master
        att_map = torch.tanh(self.att_projM(att_map))

        att_map = torch.matmul(att_map, self.att_weightM)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _derive_att_map(self, x, num_type1, num_type2):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1), softmax-normalised over the 3rd axis (j)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)

        # fill the logit board block-wise, one weight vector per edge type
        att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)

        att_board[:, :num_type1, :num_type1, :] = torch.matmul(
            att_map[:, :num_type1, :num_type1, :], self.att_weight11)
        att_board[:, num_type1:, num_type1:, :] = torch.matmul(
            att_map[:, num_type1:, num_type1:, :], self.att_weight22)
        att_board[:, :num_type1, num_type1:, :] = torch.matmul(
            att_map[:, :num_type1, num_type1:, :], self.att_weight12)
        att_board[:, num_type1:, :num_type1, :] = torch.matmul(
            att_map[:, num_type1:, :num_type1, :], self.att_weight12)

        att_map = att_board

        # att_map = torch.matmul(att_map, self.att_weight12)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        # attention-aggregated neighbours + plain per-node projection
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _project_master(self, x, master, att_map):
        # (#bs, #node, 1) -> (#bs, 1, #node) so the matmul gives (#bs, 1, #dim)

        x1 = self.proj_with_attM(torch.matmul(
            att_map.squeeze(-1).unsqueeze(1), x))
        x2 = self.proj_without_attM(master)

        return x1 + x2

    def _apply_BN(self, x):
        # BatchNorm1d over the feature axis, treating every node as a separate sample
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        # uninitialised tensor, immediately overwritten by Xavier-normal init
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out


class GraphPool(nn.Module):
    """Score-based top-k node pooling.

    Each node gets a sigmoid gate from a linear projection of its (dropout-regularised)
    features. Nodes are rescaled by their gate, and the ``max(int(#node * k), 1)``
    highest-scoring nodes are kept in descending score order.

    Args:
        k: fraction of nodes to keep.
        in_dim: feature size per node.
        p: dropout probability applied before scoring (``<= 0`` disables it).
    """

    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        # gates: (#bs, #node, 1)
        gates = self.sigmoid(self.proj(self.drop(h)))
        return self.top_k_graph(gates, h, self.k)

    def top_k_graph(self, scores, h, k):
        """
        args
        =====
        scores: attention-based weights (#bs, #node, 1)
        h: graph data (#bs, #node, #dim)
        k: ratio of remaining nodes, (float)

        returns
        =====
        h: graph pool applied data (#bs, #node', #dim)
        """
        _, node_count, feat_dim = h.size()
        keep = max(int(node_count * k), 1)
        top_idx = torch.topk(scores, keep, dim=1)[1]
        # Broadcast the per-node indices across the feature axis for gather.
        gathered_idx = top_idx.expand(-1, -1, feat_dim)
        return torch.gather(h * scores, 1, gathered_idx)


class CONV(nn.Module):
    """Fixed sinc band-pass filterbank applied to raw waveforms.

    ``out_channels`` band-pass filters are built once in ``__init__``. Their cut-off
    frequencies are spaced uniformly on the mel scale over [0, sample_rate / 2], and
    each filter is windowed with a Hamming window.

    ``band_pass`` is a plain tensor attribute, not an ``nn.Parameter`` or a registered
    buffer. Consequently the filters:
    - are not trained;
    - are not moved by ``model.to(...)``;
    - are not stored in ``state_dict``.

    ``forward`` copies them to the input's device on every call.
    """

    @staticmethod
    def to_mel(hz):
        # Hz -> mel: 2595 * log10(1 + hz / 700)
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        # inverse of to_mel
        return 700 * (10**(mel / 2595) - 1)

    def __init__(self,
                 out_channels,
                 kernel_size,
                 sample_rate=16000,
                 in_channels=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 groups=1,
                 mask=False):
        """
        Args:
            out_channels: number of band-pass filters.
            kernel_size: filter length; bumped to the next odd value if even.
            sample_rate: sampling rate (Hz) used to place the cut-off frequencies.
            in_channels: must be 1 (ValueError otherwise).
            stride, padding, dilation: forwarded to ``F.conv1d``.
            bias: must be False (ValueError otherwise).
            groups: must be <= 1 (ValueError otherwise).
            mask: stored as ``self.mask`` but not read by ``forward``, which takes
                its own ``mask`` argument.
        """
        super().__init__()
        if in_channels != 1:

            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
                in_channels)
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Forcing the filters to be odd (i.e, perfectly symmetrics)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.mask = mask
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        # Band edges: out_channels + 1 points uniformly spaced in mel between the
        # mel values of 0 Hz and sample_rate / 2 (NFFT only sets the grid density).
        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)

        self.mel = filbandwidthsf
        # symmetric time support: -(K-1)/2 ... (K-1)/2
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            # ideal band-pass = low-pass(fmax) - low-pass(fmin), via normalised sinc
            hHigh = (2*fmax/self.sample_rate) * \
                np.sinc(2*fmax*self.hsupp/self.sample_rate)
            hLow = (2*fmin/self.sample_rate) * \
                np.sinc(2*fmin*self.hsupp/self.sample_rate)
            hideal = hHigh - hLow

            self.band_pass[i, :] = Tensor(np.hamming(
                self.kernel_size)) * Tensor(hideal)

    def forward(self, x, mask=False):
        """Filter ``x`` with the sinc filterbank.

        Args:
            x: input for ``F.conv1d`` with a single input channel, e.g.
                (#bs, 1, #samples).
            mask: if True, zero a random contiguous run of ``int(U(0, 20))`` filters
                (frequency-masking augmentation). The draws use NumPy's global RNG
                and then Python's ``random``.

        Returns:
            ``F.conv1d`` output with ``out_channels`` channels.
        """
        band_pass_filter = self.band_pass.clone().to(x.device)
        if mask:
            A = np.random.uniform(0, 20)
            A = int(A)
            A0 = random.randint(0, band_pass_filter.shape[0] - A)
            band_pass_filter[A0:A0 + A, :] = 0
        else:
            band_pass_filter = band_pass_filter

        # cached on the module as (out_channels, 1, kernel_size) conv weights
        self.filters = (band_pass_filter).view(self.out_channels, 1,
                                               self.kernel_size)

        return F.conv1d(x,
                        self.filters,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        bias=None,
                        groups=1)


class Residual_block(nn.Module):
    """Pre-activation 2-D residual block used by the AASIST encoder.

    Main path:
        [BN -> SELU] (skipped when ``first``) -> conv(2x3) -> BN -> SELU -> conv(2x3)

    Shortcut:
        identity, or a (1x3) conv when the channel counts differ.

    The two paths are summed and max-pooled with a (1, 3) window, which reduces
    the last axis by 3.

    Args:
        nb_filts: ``[in_channels, out_channels]``.
        first: if True, skip the leading BN/SELU (the caller is expected to
            normalise the input itself).

    Note:
        Earlier versions called ``self.conv1(x)``, which silently discarded the
        BN/SELU pre-activation, so ``bn1`` was never trained. Checkpoints trained
        with that behaviour will produce different outputs with this block.
    """

    def __init__(self, nb_filts, first=False):
        super().__init__()
        self.first = first

        # Creation order is kept stable (initialisation RNG order / state_dict layout).
        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
        self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)

        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        # padding (0, 1) undoes the +1 height added by conv1's (1, 1) padding
        self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)

        else:
            self.downsample = False
        self.mp = nn.MaxPool2d((1, 3))  # self.mp = nn.MaxPool2d((1,4))

    def forward(self, x):
        identity = x
        if self.first:
            out = x
        else:
            # bn1 returns a fresh tensor, so the in-place SELU never touches `x`.
            out = self.selu(self.bn1(x))
        # Feed the pre-activated tensor (previously `self.conv1(x)` dropped it).
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.selu(out)
        out = self.conv2(out)
        if self.downsample:
            identity = self.conv_downsample(identity)

        out += identity
        out = self.mp(out)
        return out


class AASIST_Model(nn.Module):
    """AASIST spoofing countermeasure.

    Pipeline: sinc front-end -> residual 2-D encoder -> spectral and temporal graph
    attention -> two heterogeneous-graph inference branches -> linear classifier.

    ``forward(x)`` takes raw waveforms of shape (#bs, #samples) and returns
    ``(last_hidden, logits)``:
    - ``last_hidden`` has ``5 * gat_dims[1]`` features (T max/avg, S max/avg, master);
    - ``logits`` has 2 classes.
    """

    def __init__(self):
        super().__init__()
        d_args = {
        "architecture": "AASIST",
        "nb_samp": 64600,
        "first_conv": 128,
        "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
        "gat_dims": [64, 32],
        "pool_ratios": [0.5, 0.7, 0.5, 0.5],
        "temperatures": [2.0, 2.0, 100.0, 100.0]
        }

        self.d_args = d_args
        filts = d_args["filts"]
        gat_dims = d_args["gat_dims"]
        pool_ratios = d_args["pool_ratios"]
        temperatures = d_args["temperatures"]

        self.conv_time = CONV(out_channels=filts[0],
                              kernel_size=d_args["first_conv"],
                              in_channels=1)
        self.first_bn = nn.BatchNorm2d(num_features=1)

        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        self.encoder = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))

        # Positional embedding for the spectral graph. The node count (23) must match
        # the number of spectral nodes the encoder produces (e_S.size(1) in forward).
        self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
        # Learnable master nodes for the two inference branches.
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))

        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[1])

        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def forward(self, x, Freq_aug=False):
        """
        Args:
            x: raw waveforms, (#bs, #samples).
            Freq_aug: if True, randomly zero a band of sinc filters (training augmentation).

        Returns:
            (last_hidden, output): (#bs, 5 * gat_dims[1]) features and (#bs, 2) logits.
        """
        x = x.unsqueeze(1)
        x = self.conv_time(x, mask=Freq_aug)
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)

        # get embeddings using encoder
        # (#bs, #filt, #spec, #seq)
        e = self.encoder(x)

        # spectral GAT (GAT-S)
        e_S, _ = torch.max(torch.abs(e), dim=3)  # max along time
        e_S = e_S.transpose(1, 2) + self.pos_S

        gat_S = self.GAT_layer_S(e_S)
        out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)

        # temporal GAT (GAT-T)
        e_T, _ = torch.max(torch.abs(e), dim=2)  # max along freq
        e_T = e_T.transpose(1, 2)

        gat_T = self.GAT_layer_T(e_T)
        out_T = self.pool_T(gat_T)

        # Learnable master nodes, expanded to the batch. The original passed the
        # un-expanded (1, 1, D) parameters and relied on broadcasting; the values
        # and gradients are identical, and these expansions were previously dead code.
        master1 = self.master1.expand(x.size(0), -1, -1)
        master2 = self.master2.expand(x.size(0), -1, -1)

        # inference 1
        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
            out_T, out_S, master=master1)

        out_S1 = self.pool_hS1(out_S1)
        out_T1 = self.pool_hT1(out_T1)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
            out_T1, out_S1, master=master1)
        out_T1 = out_T1 + out_T_aug
        out_S1 = out_S1 + out_S_aug
        master1 = master1 + master_aug

        # inference 2
        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
            out_T, out_S, master=master2)
        out_S2 = self.pool_hS2(out_S2)
        out_T2 = self.pool_hT2(out_T2)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
            out_T2, out_S2, master=master2)
        out_T2 = out_T2 + out_T_aug
        out_S2 = out_S2 + out_S_aug
        master2 = master2 + master_aug

        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        # element-wise max over the two inference branches
        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)

        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)

        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)

        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)

        last_hidden = self.drop(last_hidden)
        output = self.out_layer(last_hidden)

        return last_hidden, output

    def pad(self, x, max_len=64600):  # 64600 samples ~= 4 s at 16 kHz
        """Truncate, or repeat-pad (tile), a 1-D waveform to exactly ``max_len`` samples.

        Raises:
            ValueError: if ``x`` is empty (there is nothing to repeat).
        """
        x_len = x.shape[0]
        if x_len >= max_len:
            return x[:max_len]
        if x_len == 0:
            raise ValueError("cannot pad an empty waveform")
        # need to pad: tile enough copies, then cut to max_len
        num_repeats = int(max_len / x_len) + 1
        padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0]
        return padded_x

    def getScore(self, path_file_test: Path | str):
        """Score one audio file with the model in eval mode.

        The file is read with soundfile and tile-padded/truncated to 64600 samples.
        The model is left in eval mode afterwards.

        Returns:
            NumPy scalar: softmax probability of class index 1. With the training
            labels used in this notebook (bonafide -> 0, spoof -> 1) this is the
            probability that the file is *spoof*; use ``1 - score`` for a bonafide
            score.
        """
        x, _ = soundfile.read(path_file_test)
        x_inp = Tensor(self.pad(x)).unsqueeze(0)
        # Use the device the model lives on; the original read an undefined global `device`.
        x_inp = x_inp.to(next(self.parameters()).device)
        self.eval()
        with torch.no_grad():
            _, output = self.forward(x_inp)
        scores = F.softmax(output, dim=1)
        return scores[0][1].detach().cpu().numpy()

# Compute EER (equal error rate)

In [58]:
from sklearn.metrics import roc_curve
import numpy as np

def compute_eer(y_true, y_score):
    """Equal error rate from binary labels and detection scores.

    The ROC is built with ``pos_label=1``. The operating point is the threshold where
    ``|FNR - FPR|`` is smallest (NaNs ignored), and the FPR at that point is reported
    as the EER.

    Returns:
        (eer, eer_threshold)
    """
    false_pos, true_pos, cutoffs = roc_curve(y_true, y_score, pos_label=1)
    miss_rate = 1 - true_pos
    closest = np.nanargmin(np.absolute(miss_rate - false_pos))
    return false_pos[closest], cutoffs[closest]


# Dataset & DataLoader

In [59]:
import torch
from torch.utils.data import Dataset
import soundfile as sf
import os
import numpy as np
from collections import defaultdict
import random

class AASISTDataset(Dataset):
    """Waveform dataset driven by a ``<wav path> <label>`` list file.

    Label mapping: ``bonafide`` -> 0, any other label -> 1 (spoof).

    Each waveform is averaged to mono if it has several channels, then brought to
    exactly ``max_len`` samples:
    - longer clips are cropped (random crop in training, centre crop when ``eval_mode``);
    - shorter clips are zero-padded at the end.

    Attributes:
        data: list of ``(path, label)`` in list-file order (blank lines skipped).
        speaker_to_indices: ``speaker_id -> {"bonafide": [...], "spoof": [...]}``,
            indices into ``data``; consumed by SpeakerPairBatchSampler.
    """

    def __init__(self, root_dir, list_file, max_len=64600, eval_mode=False):
        """
        root_dir: folder that relative paths in ``list_file`` are joined to
            (absolute paths are used as-is by ``os.path.join``).
        list_file: text file with one ``<wav path> <label>`` pair per line.
        max_len: samples per item (4 s @ 16 kHz = 64000; 64600 is used for margin).
        eval_mode: if True, crop from the centre (deterministic) for validation/test.
        """
        self.root_dir = root_dir
        self.max_len = max_len
        self.eval_mode = eval_mode
        self.data = []
        self.speaker_to_indices = defaultdict(lambda: {"bonafide": [], "spoof": []})

        with open(list_file, "r") as f:
            for line in f:
                if not line.strip():
                    continue  # tolerate blank lines (e.g. a trailing empty line)
                path, label_str = line.split()
                label = 0 if label_str == "bonafide" else 1
                # Index into self.data rather than the raw line number.
                idx = len(self.data)
                self.data.append((path, label))

                # NOTE(review): speaker id = first path component. This is correct for
                # relative "speaker/label/file.wav" entries; every absolute path would
                # map to the speaker "".
                speaker_id = path.split("/")[0]
                group = "bonafide" if label == 0 else "spoof"
                self.speaker_to_indices[speaker_id][group].append(idx)

    def __len__(self):
        return len(self.data)

    def _fit_length(self, waveform):
        """Crop (random, or centre when ``eval_mode``) or end-zero-pad to ``max_len``."""
        L = len(waveform)
        if L > self.max_len:
            if self.eval_mode:
                start = (L - self.max_len) // 2
            else:
                start = random.randint(0, L - self.max_len)
            return waveform[start:start + self.max_len]
        if L < self.max_len:
            return np.pad(waveform, (0, self.max_len - L))
        return waveform

    def __getitem__(self, idx):
        rel_path, label = self.data[idx]
        full_path = os.path.join(self.root_dir, rel_path)
        waveform, _ = sf.read(full_path)
        # soundfile returns (frames, channels) for multi-channel files; the model and the
        # 1-D padding below expect mono.
        if waveform.ndim > 1:
            waveform = waveform.mean(axis=1)
        return torch.FloatTensor(self._fit_length(waveform)), label


In [60]:
class AASISTValDataset(Dataset):
    """Deterministic dataset for validation/testing.

    Reads ``<wav path> <label>`` lines and maps ``bonafide`` -> 0 and any other label
    -> 1. Each waveform is centre-cropped, or zero-padded at the end, to exactly
    ``max_len`` samples.
    """

    def __init__(self, root_dir, list_file, max_len=64600):
        self.root_dir = root_dir
        self.max_len = max_len
        with open(list_file, "r") as handle:
            entries = [line.strip().split() for line in handle]
        self.data = [(path, 0 if tag == "bonafide" else 1) for path, tag in entries]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        rel_path, label = self.data[idx]
        samples, _ = sf.read(os.path.join(self.root_dir, rel_path))
        excess = len(samples) - self.max_len
        if excess > 0:
            # centre crop
            offset = excess // 2
            samples = samples[offset:offset + self.max_len]
        elif excess < 0:
            # pad with zeros at the end
            samples = np.pad(samples, (0, -excess))
        return torch.FloatTensor(samples), label


In [61]:
# Audio root that list-file paths are joined to, then the train / validation list files.
root_dir, train_path, val_path = (
    "/kaggle/input/vsasv-augment/vlsp2025/vlsp2025/train",
    "/kaggle/input/cm-training/train_list.txt",
    "/kaggle/input/cm-training/val_list.txt",
)

# Batch Sampler

In [62]:
import random
from torch.utils.data import Sampler

class SpeakerPairBatchSampler(Sampler):
    """Batch sampler yielding one bonafide + one spoof utterance per sampled speaker.

    Each batch draws ``speakers_per_batch`` distinct speakers (clamped to the number of
    eligible speakers). For each speaker it emits one random bonafide index followed by
    one random spoof index.

    Only speakers that have both classes are eligible, so every batch has the same
    size: ``2 * min(speakers_per_batch, #eligible)``.

    Args:
        dataset: object exposing ``speaker_to_indices``
            (``speaker -> {"bonafide": [...], "spoof": [...]}``), e.g. AASISTDataset.
        speakers_per_batch: speakers per batch (must be >= 1).
        seed: seed of this sampler's private RNG.
        infinite: if True, iteration never stops; otherwise a single batch is yielded.

    Raises:
        ValueError: if no speaker has both classes, or ``speakers_per_batch < 1``.
    """

    def __init__(self, dataset, speakers_per_batch=16, seed=42, infinite=True):
        if speakers_per_batch < 1:
            raise ValueError("speakers_per_batch must be >= 1")
        self.dataset = dataset
        self.speakers = list(dataset.speaker_to_indices.keys())
        # Filter up front. Previously speakers missing a class were skipped inside the
        # loop, which shrank batches and could even yield an empty batch.
        self.eligible_speakers = [
            spk for spk in self.speakers
            if dataset.speaker_to_indices[spk]["bonafide"]
            and dataset.speaker_to_indices[spk]["spoof"]
        ]
        if not self.eligible_speakers:
            raise ValueError("No speaker has both bonafide and spoof utterances.")
        self.speakers_per_batch = speakers_per_batch
        self.seed = seed
        self.infinite = infinite
        # Private RNG. The original called random.seed(seed), which reset the global
        # RNG used by random crops and frequency masking elsewhere.
        self.rng = random.Random(seed)

    def __iter__(self):
        k = min(self.speakers_per_batch, len(self.eligible_speakers))
        while True:
            batch_indices = []
            for spk in self.rng.sample(self.eligible_speakers, k):
                groups = self.dataset.speaker_to_indices[spk]
                batch_indices.append(self.rng.choice(groups["bonafide"]))
                batch_indices.append(self.rng.choice(groups["spoof"]))
            yield batch_indices
            if not self.infinite:
                break

    def __len__(self):
        # A finite sampler yields exactly one batch. For an infinite one this is only a
        # nominal "epoch" length (the sampler itself never ends).
        if not self.infinite:
            return 1
        return max(1, len(self.eligible_speakers) // self.speakers_per_batch)


In [63]:
from torch.utils.data import DataLoader

# Train dataset: random crop (eval_mode=False)
# Training split: random crops (eval_mode=False).
train_dataset = AASISTDataset(root_dir=root_dir, list_file=train_path, eval_mode=False)

# Validation split: deterministic centre crops.
val_dataset = AASISTValDataset(root_dir=root_dir, list_file=val_path)

# 16 speakers per batch, one bonafide + one spoof file each (32 files per batch);
# infinite=True so the training loader never runs dry.
train_sampler = SpeakerPairBatchSampler(dataset=train_dataset, speakers_per_batch=16, infinite=True)


In [64]:
# The batch sampler decides batch composition (speaker pairs), so no batch_size/shuffle here.
train_loader = DataLoader(
    train_dataset,
    batch_sampler=train_sampler,
    num_workers=4,
    pin_memory=True
)

# Every validation item is exactly max_len samples (centre crop / zero pad), so it can
# be batched; the default batch_size=1 made validation needlessly slow. shuffle=False
# keeps score order aligned with the list file. persistent_workers avoids respawning
# the workers on every per-epoch validation pass.
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
    persistent_workers=True
)

# Train function

In [69]:
import itertools
import torch
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm

def _endless_batches(loader):
    """Yield batches from ``loader`` forever, starting a fresh pass whenever it ends.

    Unlike ``itertools.cycle`` this keeps no copy of past batches, so memory stays bounded.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            raise ValueError("train_loader yielded no batches")


def train(model,
          train_loader: DataLoader,
          val_loader: DataLoader,
          num_epochs: int = 30,
          steps_per_epoch: int = 2000,   # optimizer steps per "epoch"
          lr: float = 1e-4,
          save_path: str = "./best_model.pth",
          global_step = 0,
          device: str = "cuda" if torch.cuda.is_available() else "cpu"):
    """Train ``model`` with AdamW + StepLR and keep the checkpoint with the lowest val EER.

    Each epoch runs ``steps_per_epoch`` optimizer steps on batches drawn from
    ``train_loader``, which may be infinite (e.g. driven by SpeakerPairBatchSampler).
    It then makes a full pass over ``val_loader``. The EER is computed from the softmax
    probability of class 1 against the labels.

    Args:
        model: module whose forward returns ``(hidden, logits)`` with 2 logits.
        train_loader, val_loader: loaders yielding ``(waveforms, labels)``.
        num_epochs: number of train/validate cycles.
        steps_per_epoch: optimizer steps per epoch.
        lr: initial learning rate (StepLR halves it every 10 epochs).
        save_path: file the best ``state_dict`` is written to.
        global_step: starting value of the step counter used for W&B step logging.
        device: device to train on.

    Requires the module-level ``wandb`` to be imported and logged in before the call.
    """

    # Initialise the W&B run
    wandb.init(project="aasist", config={
        "epochs": num_epochs,
        "steps_per_epoch": steps_per_epoch,
        "learning_rate": lr,
        "optimizer": "AdamW",
        "scheduler": "StepLR",
        "architecture": "AASIST",
    })

    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    criterion = torch.nn.CrossEntropyLoss()

    best_eer = float("inf")  # so the first epoch always writes a checkpoint

    # A single long-lived batch stream for the whole run. The original rebuilt
    # itertools.cycle(train_loader) every epoch: cycle() stores every batch it yields
    # (unbounded host-memory growth), and each rebuild respawned the workers.
    batches = _endless_batches(train_loader)

    for epoch in range(1, num_epochs + 1):
        model.train()
        train_loss = 0.0
        correct = 0
        total = 0

        loop = tqdm(range(steps_per_epoch), desc=f"[Epoch {epoch}/{num_epochs}] Training", leave=False)

        for _ in loop:
            x, y = next(batches)
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            optimizer.zero_grad()
            _, output = model(x)

            loss = criterion(output, y)
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * x.size(0)
            preds = torch.argmax(output, dim=1)
            correct += (preds == y).sum().item()
            total += x.size(0)

            loop.set_postfix(loss=loss.item())
            if global_step % 500 == 0 and global_step != 0:
                wandb.log({
                    "train/loss_step": loss.item(),
                    "train/lr": scheduler.get_last_lr()[0],
                    "step": global_step
                })

            global_step += 1

        train_acc = correct / total
        avg_train_loss = train_loss / total

        # Validation
        model.eval()
        all_labels, all_scores = [], []

        val_loop = tqdm(val_loader, desc=f"[Epoch {epoch}/{num_epochs}] Validation", leave=False)

        with torch.no_grad():
            for x, y in val_loop:
                x = x.to(device)
                _, output = model(x)
                # Probability of class 1. With the datasets' mapping (bonafide=0,
                # spoof=1) this is the *spoof* score, which matches compute_eer's
                # pos_label=1.
                scores = F.softmax(output, dim=1)[:, 1]
                all_labels.extend(y.cpu().numpy())
                all_scores.extend(scores.cpu().numpy())

        eer, eer_threshold = compute_eer(np.array(all_labels), np.array(all_scores))

        print(f"[Epoch {epoch}] Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.4f} | Val EER: {eer:.4f}")

        wandb.log({
            "epoch": epoch,
            "train_loss": avg_train_loss,
            "train_acc": train_acc,
            "val_eer": eer,
            "lr": scheduler.get_last_lr()[0]
        })

        if eer < best_eer:
            best_eer = eer
            torch.save(model.state_dict(), save_path)
            print(f"âœ… New best model saved at epoch {epoch} with EER={eer:.4f}")

        scheduler.step()

    # Shut down the training loader's workers held by the batch stream.
    batches.close()
    print("ðŸŽ‰ Finished Training.")
    wandb.finish()


In [66]:
!pip install wandb



In [67]:
import os, wandb

# Never hardcode API keys in a notebook. Provide the key through the environment
# (e.g. a Kaggle secret exported as WANDB_KEY or WANDB_API_KEY) or type it when prompted.
# NOTE: the key previously committed in this cell is exposed in the notebook history;
# revoke and rotate it.
wandb_key = os.environ.get("WANDB_KEY") or os.environ.get("WANDB_API_KEY")
if not wandb_key:
    from getpass import getpass
    wandb_key = getpass("Weights & Biases API key: ")
wandb.login(key=wandb_key)




True

In [70]:
# Fresh AASIST model trained from scratch; the best-EER checkpoint goes to train()'s save_path.
model = AASIST_Model()
train(model, train_loader=train_loader, val_loader=val_loader, num_epochs=50)


                                                                                    

KeyboardInterrupt: 

