In [None]:
import torch
from torch import nn, Tensor
import math, torch, torchaudio
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from typing import Union
import numpy as np

# Model

## ResNet48 Architecture

In [3]:
class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a shortcut.

    The shortcut is the identity unless the block changes the spatial
    resolution (``stride != 1``) or the channel count, in which case it is a
    strided 1x1 conv followed by BatchNorm. ReLU follows the first BN and the
    residual sum.

    Attribute names are the state_dict keys of the pretrained checkpoint, so
    they (and their construction order) must stay as they are.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        # Main path.
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, 3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Projection shortcut only when the identity would not line up.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            self.downsample = None

    def forward(self, x):
        residual = self.relu(self.bn1(self.conv1(x)))
        residual = self.bn2(self.conv2(residual))
        shortcut = x if self.downsample is None else self.downsample(x)
        residual += shortcut
        return self.relu(residual)

class StatsPooling(nn.Module):
    """Statistics pooling over the last (time) axis.

    Input ``[B, C, F, T]`` is flattened to ``[B, C*F, T]``. The per-row mean
    and (torch-default, unbiased) std over time are concatenated into
    ``[B, 2*C*F]``. After stage 4 of ResNet48_ASV with 80 mel bins the input
    is ``[B, 256, 10, T/8]`` and the output ``[B, 5120]``.
    """

    def forward(self, x):
        batch, channels, freq, frames = x.size()
        flat = x.view(batch, channels * freq, frames)   # [B, C*F, T]
        mean_t = flat.mean(dim=2)                        # [B, C*F]
        std_t = flat.std(dim=2)                          # [B, C*F]
        return torch.cat([mean_t, std_t], dim=1)         # [B, 2*C*F]

class ResNet48_ASV(nn.Module):
    """ResNet speaker-embedding extractor for log-mel inputs.

    The layer count is 48: the stem conv, 46 convs in 23 ``BasicBlock`` s
    (6 + 8 + 6 + 3) and the final ``fc``. The 1x1 shortcut convs are not
    counted.

    Pipeline: 3x3 stem conv (96 ch) -> stages of 96/128/160/256 channels with
    strides 1/2/2/2 -> ``StatsPooling`` over time -> ``fc`` -> L2
    normalisation.

    Args:
        embedding_dim: size of the output embedding.
        num_speakers: if truthy, a linear ``classifier`` head of this size is
            created. ``forward`` never uses it; presumably it exists for
            training / loading checkpoints that contain a head.
        n_mels: number of mel bins (input height). It determines the input
            size of ``fc``. The default 80 gives the original 5120 features.
    """

    def __init__(self, embedding_dim=256, num_speakers=None, n_mels=80):
        super().__init__()
        # Conv0 (stem): keeps the spatial size.
        self.conv1 = nn.Conv2d(1, 96, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(96)
        self.relu = nn.ReLU(inplace=True)

        # Residual stages
        self.layer1 = self._make_layer(96, 96, 6, stride=1)    # ResBlock-1
        self.layer2 = self._make_layer(96, 128, 8, stride=2)   # ResBlock-2
        self.layer3 = self._make_layer(128, 160, 6, stride=2)  # ResBlock-3
        self.layer4 = self._make_layer(160, 256, 3, stride=2)  # ResBlock-4

        # Pooling + FC. Stats pooling emits mean and std for each of the
        # 256 * F' rows, where F' is the mel axis after the stride-2 stages
        # (80 -> 40 -> 20 -> 10, i.e. 5120 features for the default n_mels).
        self.pooling = StatsPooling()
        self.fc = nn.Linear(2 * 256 * self._reduced_freq(n_mels), embedding_dim)

        # Classifier (optional, not used by forward)
        if num_speakers:
            self.classifier = nn.Linear(embedding_dim, num_speakers)
        else:
            self.classifier = None

    @staticmethod
    def _reduced_freq(n_mels):
        """Length of the mel axis after the four residual stages."""
        size = n_mels
        for stride in (1, 2, 2, 2):
            # 3x3 conv with padding 1: out = floor((in - 1) / stride) + 1
            size = (size - 1) // stride + 1
        return size

    def _make_layer(self, in_c, out_c, num_blocks, stride):
        """Stack ``num_blocks`` BasicBlocks; only the first one strides / changes width."""
        layers = [BasicBlock(in_c, out_c, stride)]
        for _ in range(1, num_blocks):
            layers.append(BasicBlock(out_c, out_c))
        return nn.Sequential(*layers)

    def forward(self, input_values, labels=None):
        """Embed a batch of log-mel spectrograms.

        Args:
            input_values: ``[B, 1, n_mels, T]`` features.
            labels: ignored (accepted for call compatibility).

        Returns:
            ``[B, embedding_dim]`` L2-normalised embeddings.
        """
        # Backbone
        x = self.relu(self.bn1(self.conv1(input_values)))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        # StatsPooling
        x = self.pooling(x)

        # Dense -> embedding
        embeddings = F.normalize(self.fc(x), dim=1)  # [B, embedding_dim]
        return embeddings

## AASIST Architecture

In [None]:
class GraphAttentionLayer(nn.Module):
    """Homogeneous graph attention layer used by AASIST (GAT-S / GAT-T).

    For every ordered node pair (i, j) the attention logit is
    ``tanh(att_proj(x_i * x_j)) @ att_weight``. Logits are divided by the
    temperature and softmax-normalised over dim -2. The output is
    ``proj_with_att(A @ x) + proj_without_att(x)``, followed by BatchNorm1d
    (node axis flattened into the batch) and SELU. Input dropout (p=0.2) is
    applied first.

    Keyword args:
        temperature: attention temperature (default 1.0).
    """

    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()
        # Attention-map parameters.
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)

        # Output projections with / without neighbourhood aggregation.
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        self.bn = nn.BatchNorm1d(out_dim)
        self.input_drop = nn.Dropout(p=0.2)
        self.act = nn.SELU(inplace=True)

        self.temp = kwargs.get("temperature", 1.)

    def forward(self, x):
        """x: (#bs, #node, in_dim) -> (#bs, #node, out_dim)."""
        dropped = self.input_drop(x)
        attention = self._derive_att_map(dropped)
        projected = self._project(dropped, attention)
        return self.act(self._apply_BN(projected))

    def _pairwise_mul_nodes(self, x):
        """Element-wise product of every node pair.

        x: (#bs, #node, #dim) -> (#bs, #node, #node, #dim), where
        out[b, i, j] = x[b, i] * x[b, j].
        """
        n = x.size(1)
        rows = x.unsqueeze(2).expand(-1, -1, n, -1)
        return rows * rows.transpose(1, 2)

    def _derive_att_map(self, x):
        """x: (#bs, #node, #dim) -> attention map (#bs, #node, #node, 1)."""
        pair_feats = torch.tanh(self.att_proj(self._pairwise_mul_nodes(x)))
        logits = torch.matmul(pair_feats, self.att_weight) / self.temp
        return F.softmax(logits, dim=-2)

    def _project(self, x, att_map):
        """Attention-aggregated projection plus a plain (skip) projection."""
        aggregated = torch.matmul(att_map.squeeze(-1), x)
        return self.proj_with_att(aggregated) + self.proj_without_att(x)

    def _apply_BN(self, x):
        """BatchNorm1d over the last dim, with all leading dims flattened."""
        shape = x.size()
        return self.bn(x.view(-1, shape[-1])).view(shape)

    def _init_new_params(self, *size):
        """Create a Xavier-normal initialised Parameter of the given size."""
        param = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(param)
        return param


class HtrgGraphAttentionLayer(nn.Module):
    """Heterogeneous graph attention layer with a master node (used by AASIST).

    Two node sets are projected by their own linear layers, concatenated along
    the node axis and attended jointly: x1 is type 1 and x2 is type 2.

    Attention logits use a separate weight vector per edge type:
    - type1-type1 uses ``att_weight11``;
    - type2-type2 uses ``att_weight22``;
    - both cross-type directions share ``att_weight12``.

    A master node attends over all nodes with its own parameters
    (``att_projM``, ``att_weightM``, ``proj_with_attM``, ``proj_without_attM``).

    Args:
        in_dim: feature size of x1, x2 and master.
        out_dim: feature size of the returned x1, x2 and master.
        **kwargs: optional ``temperature`` (default 1.0). All attention
            logits are divided by it before the softmax.
    """
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # per-type input projections (kept at in_dim)
        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)

        # attention map (att_projM is the master-node variant)
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_projM = nn.Linear(in_dim, out_dim)

        # one logit weight vector per edge type; 12 is shared by 1->2 and 2->1
        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)

        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        # master-node projections
        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x1, x2, master=None):
        '''
        x1      :(#bs, #node1, in_dim)
        x2      :(#bs, #node2, in_dim)
        master  :(#bs or 1, 1, in_dim); if None, the mean of the projected
                 nodes (computed before input dropout) is used
        returns :x1 (#bs, #node1, out_dim), x2 (#bs, #node2, out_dim),
                 master (#bs, 1, out_dim)
        '''
        num_type1 = x1.size(1)
        num_type2 = x2.size(1)

        x1 = self.proj_type1(x1)
        x2 = self.proj_type2(x2)

        # one joint graph: type-1 nodes first, then type-2 nodes
        x = torch.cat([x1, x2], dim=1)

        if master is None:
            master = torch.mean(x, dim=1, keepdim=True)

        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x, num_type1, num_type2)

        # directional edge for master node
        master = self._update_master(x, master)

        # projection
        x = self._project(x, att_map)

        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)

        # split the joint graph back into the two node types
        x1 = x.narrow(1, 0, num_type1)
        x2 = x.narrow(1, num_type1, num_type2)

        return x1, x2, master

    def _update_master(self, x, master):
        """Attend from the master node over all nodes of x and project it to out_dim."""

        att_map = self._derive_att_map_master(x, master)
        master = self._project_master(x, master, att_map)

        return master

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map_master(self, x, master):
        '''
        Attention of the master node over every node.
        x           :(#bs, #node, #dim)
        master      :(#bs or 1, 1, #dim), broadcast against x
        out_shape   :(#bs, #node, 1), softmax-normalised over the nodes
        '''
        att_map = x * master
        att_map = torch.tanh(self.att_projM(att_map))

        att_map = torch.matmul(att_map, self.att_weightM)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _derive_att_map(self, x, num_type1, num_type2):
        '''
        x           :(#bs, #node, #dim), type-1 nodes first
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)

        # Fill the logit board block-wise, with one weight vector per edge type.
        att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)

        # type1 -> type1
        att_board[:, :num_type1, :num_type1, :] = torch.matmul(
            att_map[:, :num_type1, :num_type1, :], self.att_weight11)
        # type2 -> type2
        att_board[:, num_type1:, num_type1:, :] = torch.matmul(
            att_map[:, num_type1:, num_type1:, :], self.att_weight22)
        # cross-type edges, both directions share att_weight12
        att_board[:, :num_type1, num_type1:, :] = torch.matmul(
            att_map[:, :num_type1, num_type1:, :], self.att_weight12)
        att_board[:, num_type1:, :num_type1, :] = torch.matmul(
            att_map[:, num_type1:, :num_type1, :], self.att_weight12)

        att_map = att_board

        # (previous single-weight variant, kept for reference)
        # att_map = torch.matmul(att_map, self.att_weight12)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        """Attention-aggregated projection plus a plain (skip) projection of x."""
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _project_master(self, x, master, att_map):
        """New master: projected attention-weighted sum of nodes plus projected old master."""

        # (#bs, 1, #node) @ (#bs, #node, #dim) -> (#bs, 1, #dim)
        x1 = self.proj_with_attM(torch.matmul(
            att_map.squeeze(-1).unsqueeze(1), x))
        x2 = self.proj_without_attM(master)

        return x1 + x2

    def _apply_BN(self, x):
        """BatchNorm1d over the last dim, with all leading dims flattened."""
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        """Create a Xavier-normal initialised Parameter of the given size."""
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out


class GraphPool(nn.Module):
    """Score-gated top-k node pooling (graph pooling used by AASIST).

    Each node gets a gate ``sigmoid(proj(dropout(h)))``. Nodes are rescaled
    by their gate and only the ``max(int(#node * k), 1)`` highest-scoring
    nodes are kept.

    Args:
        k: fraction of nodes to keep.
        in_dim: node feature size.
        p: dropout probability applied before scoring (<= 0 disables it).
    """

    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        if p > 0:
            self.drop = nn.Dropout(p=p)
        else:
            self.drop = nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        """h: (#bs, #node, #dim) -> (#bs, #node', #dim)."""
        gate = self.sigmoid(self.proj(self.drop(h)))
        return self.top_k_graph(gate, h, self.k)

    def top_k_graph(self, scores, h, k):
        """Keep the top-scoring nodes of ``h`` after scaling them by ``scores``.

        Args:
            scores: attention-based weights (#bs, #node, 1).
            h: graph data (#bs, #node, #dim).
            k: ratio of remaining nodes (float).

        Returns:
            Pooled graph (#bs, #node', #dim).
        """
        _, total_nodes, feat_dim = h.size()
        keep = max(int(total_nodes * k), 1)
        _, top_idx = torch.topk(scores, keep, dim=1)
        gather_idx = top_idx.expand(-1, -1, feat_dim)
        return torch.gather(h * scores, 1, gather_idx)


class CONV(nn.Module):
    """Fixed (non-learnable) sinc band-pass filter bank: AASIST's raw-waveform front end.

    ``__init__`` builds ``out_channels`` band-pass filters once. Their band
    edges are equally spaced on the mel scale between 0 Hz and
    ``sample_rate / 2``. Each filter is an ideal sinc band-pass multiplied by
    a Hamming window.

    The filters are stored in the plain tensor ``band_pass``, which is not a
    Parameter or buffer. It is therefore not part of the state_dict and is
    copied to the input's device on every forward call.
    """

    @staticmethod
    def to_mel(hz):
        """Hz -> mel: 2595 * log10(1 + hz / 700)."""
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        """Inverse of ``to_mel``."""
        return 700 * (10**(mel / 2595) - 1)

    def __init__(self,
                 out_channels,
                 kernel_size,
                 sample_rate=16000,
                 in_channels=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 groups=1,
                 mask=False):
        """
        Args:
            out_channels: number of band-pass filters.
            kernel_size: filter length; an even value is bumped to the next odd one.
            sample_rate: sample rate in Hz used to place the band edges.
            in_channels: must be 1.
            stride, padding, dilation: forwarded to ``F.conv1d``.
            bias, groups: unsupported; any non-default value raises ValueError.
            mask: stored as ``self.mask`` only; ``forward`` uses its own
                ``mask`` argument.

        Raises:
            ValueError: for ``in_channels != 1``, ``bias=True`` or ``groups > 1``.
        """
        super().__init__()
        if in_channels != 1:

            msg = "SincConv only support one input channel (here, in_channels = %i)" % (
                in_channels)
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Forcing the filters to be odd (i.e, perfectly symmetrics)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.mask = mask
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        # Band edges: out_channels + 1 points equally spaced in mel between
        # the mel values of 0 Hz and Nyquist (sampled on a 512-point FFT grid).
        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)

        self.mel = filbandwidthsf
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            # ideal band-pass = low-pass(fmax) - low-pass(fmin)
            hHigh = (2*fmax/self.sample_rate) * \
                np.sinc(2*fmax*self.hsupp/self.sample_rate)
            hLow = (2*fmin/self.sample_rate) * \
                np.sinc(2*fmin*self.hsupp/self.sample_rate)
            hideal = hHigh - hLow

            self.band_pass[i, :] = Tensor(np.hamming(
                self.kernel_size)) * Tensor(hideal)

    def forward(self, x, mask=False):
        """Filter ``x`` with the band-pass bank.

        Args:
            x: single-input-channel signal for ``F.conv1d``, presumably
                (#bs, 1, #samples).
            mask: if True, zero a random contiguous run of 0-19 filters
                (frequency-masking augmentation). The run length is clamped
                to the number of filters. Randomness comes from NumPy's global
                RNG (seed with ``np.random.seed``).

        Returns:
            ``F.conv1d`` output with ``out_channels`` channels.
        """
        band_pass_filter = self.band_pass.clone().to(x.device)
        if mask:
            n_filters = band_pass_filter.shape[0]
            A = np.random.uniform(0, 20)
            # Clamp so the start-index range below is never empty, even when
            # there are fewer than 20 filters.
            A = min(int(A), n_filters)
            # The old code called random.randint without importing `random`
            # (NameError). np.random.randint excludes its upper bound, so the
            # "+ 1" gives the same inclusive range [0, n_filters - A].
            A0 = np.random.randint(0, n_filters - A + 1)
            band_pass_filter[A0:A0 + A, :] = 0

        self.filters = (band_pass_filter).view(self.out_channels, 1,
                                               self.kernel_size)

        return F.conv1d(x,
                        self.filters,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        bias=None,
                        groups=1)


class Residual_block(nn.Module):
    """AASIST encoder block: two (2, 3) convs, a shortcut and a (1, 3) max-pool.

    ``conv1`` pads one extra row in height and ``conv2`` (no height padding)
    removes it again, so the height is preserved. The width is divided by 3
    by the final max-pool.

    Args:
        nb_filts: ``[in_channels, out_channels]``.
        first: if True, ``bn1`` is not created and the pre-activation branch
            is skipped.

    NOTE(review): ``forward`` feeds ``x``, not the BN/SELU pre-activation,
    into ``conv1``. The pre-activation is computed and discarded; in training
    mode ``bn1`` still updates its running statistics. This presumably
    matches the code the pretrained checkpoint was trained with, so it is
    intentionally left unchanged.
    """

    def __init__(self, nb_filts, first=False):
        super().__init__()
        in_ch, out_ch = nb_filts[0], nb_filts[1]
        self.first = first

        if not first:
            self.bn1 = nn.BatchNorm2d(num_features=in_ch)
        self.conv1 = nn.Conv2d(in_channels=in_ch,
                               out_channels=out_ch,
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)

        self.bn2 = nn.BatchNorm2d(num_features=out_ch)
        self.conv2 = nn.Conv2d(in_channels=out_ch,
                               out_channels=out_ch,
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        # Shortcut needs a 1x3 conv only when the channel count changes.
        self.downsample = in_ch != out_ch
        if self.downsample:
            self.conv_downsample = nn.Conv2d(in_channels=in_ch,
                                             out_channels=out_ch,
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)
        self.mp = nn.MaxPool2d((1, 3))

    def forward(self, x):
        if not self.first:
            # Pre-activation branch; its result is NOT consumed below.
            _ = self.selu(self.bn1(x))
        out = self.selu(self.bn2(self.conv1(x)))
        out = self.conv2(out)
        shortcut = self.conv_downsample(x) if self.downsample else x
        out += shortcut
        return self.mp(out)


class AASIST_Model(nn.Module):
    def __init__(self):
        super().__init__()
        d_args = {
        "architecture": "AASIST",
        "nb_samp": 64600,
        "first_conv": 128,
        "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
        "gat_dims": [64, 32],
        "pool_ratios": [0.5, 0.7, 0.5, 0.5],
        "temperatures": [2.0, 2.0, 100.0, 100.0]
        }

        self.d_args = d_args
        filts = d_args["filts"]
        gat_dims = d_args["gat_dims"]
        pool_ratios = d_args["pool_ratios"]
        temperatures = d_args["temperatures"]

        self.conv_time = CONV(out_channels=filts[0],
                              kernel_size=d_args["first_conv"],
                              in_channels=1)
        self.first_bn = nn.BatchNorm2d(num_features=1)

        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        self.encoder = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))

        self.pos_S = nn.Parameter(torch.randn(1, 23, filts[-1][-1]))
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))

        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[1])

        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def forward(self, x, Freq_aug=False):

        x = x.unsqueeze(1)
        x = self.conv_time(x, mask=Freq_aug)
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)

        # get embeddings using encoder
        # (#bs, #filt, #spec, #seq)
        e = self.encoder(x)

        # spectral GAT (GAT-S)
        e_S, _ = torch.max(torch.abs(e), dim=3)  # max along time
        e_S = e_S.transpose(1, 2) + self.pos_S

        gat_S = self.GAT_layer_S(e_S)
        out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)

        # temporal GAT (GAT-T)
        e_T, _ = torch.max(torch.abs(e), dim=2)  # max along freq
        e_T = e_T.transpose(1, 2)

        gat_T = self.GAT_layer_T(e_T)
        out_T = self.pool_T(gat_T)

        # learnable master node
        master1 = self.master1.expand(x.size(0), -1, -1)
        master2 = self.master2.expand(x.size(0), -1, -1)

        # inference 1
        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
            out_T, out_S, master=self.master1)

        out_S1 = self.pool_hS1(out_S1)
        out_T1 = self.pool_hT1(out_T1)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
            out_T1, out_S1, master=master1)
        out_T1 = out_T1 + out_T_aug
        out_S1 = out_S1 + out_S_aug
        master1 = master1 + master_aug

        # inference 2
        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
            out_T, out_S, master=self.master2)
        out_S2 = self.pool_hS2(out_S2)
        out_T2 = self.pool_hT2(out_T2)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
            out_T2, out_S2, master=master2)
        out_T2 = out_T2 + out_T_aug
        out_S2 = out_S2 + out_S_aug
        master2 = master2 + master_aug

        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)

        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)

        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)

        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)

        last_hidden = self.drop(last_hidden)
        output = self.out_layer(last_hidden)

        return last_hidden, output

    def pad(self, x, max_len=64600): # 64600 samples = 4 giây
      x_len = x.shape[0]
      if x_len >= max_len:
          return x[:max_len]
      # need to pad
      num_repeats = int(max_len / x_len) + 1
      padded_x = np.tile(x, (1, num_repeats))[:, :max_len][0]
      return padded_x

    def getScore(self, path_file_test: Path | str):
      x, _ = soundfile.read(path_file_test)
      x_pad = self.pad(x)
      x_inp = Tensor(x_pad)
      x_inp = torch.unsqueeze(x_inp, 0).to(device)
      self.eval()
      with torch.no_grad():
            last_hidden, output = self.forward(x_inp)
      scores = F.softmax(output) ### có 2 score ??? lấy thế nào --> lấy cái điểm đánh giá x không phải spoof, là real

#       print(scores[0])
      return scores[0][1].detach().cpu().numpy()

# Load Model

In [None]:
from safetensors.torch import load_file
import torch

# Run on the GPU when one is visible to this kernel, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [13]:
# ResNet48 speaker embedder: read the safetensors weights, build the model on
# `device`, load the weights (strict key matching) and switch to inference
# mode. The last line displays the module structure.
state_dict = load_file("/kaggle/input/model-final/final_resnet48asv_50k.safetensors")
r48 = ResNet48_ASV().to(device)
r48.load_state_dict(state_dict)
r48.eval()

ResNet48_ASV(
  (conv1): Conv2d(1, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(96, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(96, 96, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bia

In [None]:
# Anti-spoofing (CM) model. map_location lets a checkpoint that was saved from
# a GPU load on a CPU-only kernel (and places tensors directly on `device`).
aasist = AASIST_Model().to(device)
aasist.load_state_dict(
    torch.load("/kaggle/input/model-final/aasist.pth", map_location=device)
)
aasist.eval()

# Data Processing

In [14]:
!pip install webrtcvad

Collecting webrtcvad
  Downloading webrtcvad-2.0.10.tar.gz (66 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.2/66.2 kB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: webrtcvad
  Building wheel for webrtcvad (setup.py) ... [?25l[?25hdone
  Created wheel for webrtcvad: filename=webrtcvad-2.0.10-cp311-cp311-linux_x86_64.whl size=73499 sha256=728961e2df65f88d731e61183ed649dbdf171bfcebeaecb2afffe9371f0635e9
  Stored in directory: /root/.cache/pip/wheels/94/65/3f/292d0b656be33d1c801831201c74b5f68f41a2ae465ff2ee2f
Successfully built webrtcvad
Installing collected packages: webrtcvad
Successfully installed webrtcvad-2.0.10


## Remove Silence


In [15]:
import webrtcvad
def remove_silence(waveform, sample_rate=16000, frame_duration_ms=30, aggressiveness=2):
    """Drop non-speech frames from a mono waveform using WebRTC VAD.

    Args:
        waveform: single-channel float tensor, e.g. [1, T] or [T], nominally
            in [-1, 1]. It is copied to CPU for the VAD.
        sample_rate: sample rate in Hz. webrtcvad accepts 8000, 16000,
            32000 or 48000.
        frame_duration_ms: VAD frame length. webrtcvad accepts 10, 20 or 30.
        aggressiveness: VAD mode, 0 (least) to 3 (most aggressive); default 2.

    Returns:
        A float32 tensor [1, T'] of the concatenated frames classified as
        speech; a trailing partial frame is always dropped. If no frame is
        classified as speech, the input ``waveform`` is returned unchanged.
    """
    vad = webrtcvad.Vad(aggressiveness)
    waveform_np = waveform.detach().cpu().squeeze().numpy()
    # Clip before the int16 cast: resampling can overshoot +-1.0 slightly, and
    # an out-of-range value would silently wrap around (1.01 -> -32442).
    waveform_int16 = (np.clip(waveform_np, -1.0, 1.0) * 32767).astype(np.int16)
    frame_length = int(sample_rate * frame_duration_ms / 1000)
    frames = [waveform_int16[i:i+frame_length] for i in range(0, len(waveform_int16), frame_length)]

    voiced_frames = []
    for frame in frames:
        # webrtcvad needs full-length frames of 16-bit mono PCM bytes.
        if len(frame) == frame_length and vad.is_speech(frame.tobytes(), sample_rate):
            voiced_frames.append(frame)

    if voiced_frames:
        voiced_waveform = np.concatenate(voiced_frames).astype(np.float32) / 32767
        return torch.tensor(voiced_waveform, dtype=torch.float32).unsqueeze(0)
    return waveform

In [16]:
def extract_mfbe(waveform, sample_rate=16000, n_mels=80):
    """Log mel-filterbank energies (MFBE) of ``waveform``.

    Computes a power (|STFT|^2) mel spectrogram with a 400-sample window and
    FFT and a 160-sample hop (25 ms / 10 ms at 16 kHz), then takes
    ``log(x + 1e-6)``. Leading dims are kept, e.g. [1, T] -> [1, n_mels, frames].
    """
    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=400,
        win_length=400,
        hop_length=160,
        n_mels=n_mels,
        power=2.0,
    )
    eps = 1e-6  # keeps log() finite on silent bins
    return torch.log(to_mel(waveform) + eps)

# Get Embedding

## ASV Embed Function


In [22]:
import torch
import torch.nn.functional as F
import torchaudio
import numpy as np

def get_asv_score(
    wav_path, model, device=None,
    crop_duration=2, sr=16000, hop_ratio=0.5, full_dur=5,
    sliding_weight=0.7
):
    """Compute one L2-normalised speaker embedding for an audio file.

    Despite its name, this returns an embedding, not a score.

    Pipeline:
    1. Load, resample to ``sr``, mix to mono, then ``remove_silence``.
    2. Cut ``crop_duration``-second crops with hop ``hop_ratio * crop``. An
       extra end-aligned crop covers the tail; a signal shorter than one crop
       is zero-padded to a single crop.
    3. Run ``extract_mfbe`` on each crop, then one batched forward pass.
    4. L2-normalise each crop embedding, average them, and L2-normalise again.

    Args:
        wav_path: path readable by ``torchaudio.load``.
        model: embedder taking [S, 1, n_mels, frames]. It may return the
            embedding or a tuple/list whose first item is the embedding.
            ``aug=False`` is passed only if its ``forward`` accepts it.
        device: device for the batch. Defaults to the device of the model's
            first parameter (CPU if it has none) instead of assuming CUDA.
        crop_duration: crop length in seconds.
        sr: target sample rate in Hz.
        hop_ratio: hop as a fraction of the crop length.
        full_dur, sliding_weight: unused; kept for call compatibility.

    Returns:
        np.float32 array [D].

    Raises:
        RuntimeError: if a feature tensor has an unexpected shape.
    """
    import inspect

    if device is None:
        first_param = next(iter(model.parameters()), None) if hasattr(model, "parameters") else None
        device = first_param.device if first_param is not None else "cpu"

    # 1) load -> resample -> mono -> drop silence
    waveform, sr_orig = torchaudio.load(wav_path)            # [C, T]
    if sr_orig != sr:
        waveform = torchaudio.functional.resample(waveform, sr_orig, sr)
    if waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)        # [1, T]
    waveform = remove_silence(waveform, sample_rate=sr)
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)                     # [1, T]
    waveform = waveform.float()

    # 2) fixed-length crops
    nb_samp = int(crop_duration * sr)
    hop = max(1, int(nb_samp * hop_ratio))
    T = waveform.shape[1]
    crops = []
    if T <= nb_samp:
        pad = nb_samp - T
        crops.append(F.pad(waveform, (0, pad)))              # [1, nb_samp]
    else:
        for s in range(0, T - nb_samp + 1, hop):
            crops.append(waveform[:, s:s+nb_samp])           # [1, nb_samp]
        if (T - nb_samp) % hop != 0:
            crops.append(waveform[:, -nb_samp:])             # cover the tail

    # 3) log-mel features for ALL crops -> batch [S, 1, F, Tm]
    feats = []
    for c in crops:
        f = extract_mfbe(c, sample_rate=sr)                  # [F, Tm] or [1, F, Tm]
        if f.dim() == 2:
            f = f.unsqueeze(0)                               # -> [1, F, Tm]
        elif f.dim() == 3 and f.size(0) == 1:
            pass                                             # already [1, F, Tm]
        else:
            raise RuntimeError(f"Unexpected feature shape: {tuple(f.shape)}")
        feats.append(f)

    # trim to a common frame count so the crops can be stacked
    Tm = min(fe.shape[2] for fe in feats)                    # feats are [1, F, T]
    feats = [fe[:, :, :Tm] for fe in feats]                  # keep C=1
    batch = torch.stack(feats, dim=0).to(device)             # [S, 1, F, Tm]

    # 4) one forward pass. Decide up front whether `aug` may be passed:
    # catching TypeError around the call would also swallow TypeErrors raised
    # inside the model and silently rerun it without the keyword.
    try:
        params = inspect.signature(getattr(model, "forward", model)).parameters.values()
        accepts_aug = any(
            (p.name == "aug" and p.kind is not inspect.Parameter.POSITIONAL_ONLY)
            or p.kind is inspect.Parameter.VAR_KEYWORD
            for p in params
        )
    except (TypeError, ValueError):  # signature not introspectable
        accepts_aug = False

    model.eval()
    with torch.inference_mode():
        out = model(batch, aug=False) if accepts_aug else model(batch)
        emb_batch = out[0] if isinstance(out, (tuple, list)) else out  # [S, D]

    # 5) L2 per crop -> mean -> L2
    emb = F.normalize(emb_batch, dim=-1).mean(dim=0)
    emb = F.normalize(emb, dim=-1)
    return emb.detach().cpu().numpy().astype(np.float32)


## CM Embed Function

In [None]:
import torch
import torch.nn.functional as F
import torchaudio
import numpy as np

def get_cm_score(wav_path, model, device=None,
                 nb_samp=64600, use_overlay=True, hop_ratio=0.5):
    """Countermeasure (anti-spoofing) embedding and score for an audio file.

    The file is loaded, resampled to 16 kHz and mixed to mono. Then either:
    - ``use_overlay=False``: one center crop (or symmetric zero-pad) of
      ``nb_samp`` samples;
    - ``use_overlay=True``: overlapping ``nb_samp`` windows with hop
      ``hop_ratio * nb_samp``, plus an end-aligned window for the tail.
      Inputs shorter than one window are symmetrically zero-padded.
    All crops are scored in a single batch.

    Args:
        wav_path: path readable by ``torchaudio.load``.
        model: takes [S, T] waveforms. It returns ``(embedding, logits)`` or
            just an embedding.
        device: device for the batch. Defaults to the device of the model's
            first parameter (CPU if it has none).
        nb_samp: crop length in samples.
        use_overlay: choose between the two crop modes above.
        hop_ratio: hop as a fraction of ``nb_samp`` (overlay mode only).

    Returns:
        ``(emb, cm_score)``.
        ``emb`` is the L2 -> mean -> L2 pooled embedding, an np.float32 [D].
        ``cm_score`` is the probability of class index 1, from the mean
        per-crop log-odds. For AASIST_Model, the author's note in ``getScore``
        says index 1 is the bona fide / real class (confirm for other models).
        It is 0.0 when the model returns no logits.
    """
    if device is None:
        first_param = next(iter(model.parameters()), None) if hasattr(model, "parameters") else None
        device = first_param.device if first_param is not None else "cpu"

    wav, sr = torchaudio.load(wav_path)             # [C, T]
    if sr != 16000:
        wav = torchaudio.functional.resample(wav, sr, 16000)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)         # [1, T]

    T = wav.shape[1]
    crops = []

    if not use_overlay:
        # CENTER-CROP (pad both sides evenly if too short)
        if T < nb_samp:
            pad = nb_samp - T
            left = pad // 2; right = pad - left
            wav = F.pad(wav, (left, right))
        else:
            start = (T - nb_samp) // 2
            wav = wav[:, start:start+nb_samp]
        crops = [wav]                                # 1 crop
    else:
        # overlapping windows, scores pooled below
        hop = max(1, int(nb_samp * hop_ratio))
        if T <= nb_samp:
            pad = nb_samp - T
            left = pad // 2; right = pad - left
            crops = [F.pad(wav, (left, right))]
        else:
            for s in range(0, T - nb_samp + 1, hop):
                crops.append(wav[:, s:s+nb_samp])
            if (T - nb_samp) % hop != 0:
                crops.append(wav[:, -nb_samp:])      # cover the tail

    # AASIST takes [B, T]; add .unsqueeze(1) for models that expect [B, 1, T].
    batch = torch.stack([c.squeeze(0) for c in crops], dim=0).to(device)  # [S, T]
    model.eval()
    with torch.inference_mode():
        out = model(batch)                           # -> (emb, logits) or emb
        if isinstance(out, (tuple, list)):
            emb_batch, logits = out[0], out[1]      # [S, D], [S, C] (or [S])
        else:
            emb_batch, logits = out, None

    # ===== CM score: mean per-crop log-odds, then sigmoid =====
    cm_score = None
    if logits is not None:
        if logits.dim() == 2 and logits.size(1) >= 2:
            # Log-odds of class 1 vs. all other classes; for a 2-way head this
            # is logits[:, 1] - logits[:, 0]. sigmoid(logits[:, 1]) alone (the
            # previous formula) ignores the other logits and is not a
            # probability, since softmax logits are only defined up to a shift.
            # For a single crop the result equals softmax(logits)[0, 1].
            others = torch.cat([logits[:, :1], logits[:, 2:]], dim=1)
            log_odds = logits[:, 1] - torch.logsumexp(others, dim=1)
            cm_score = torch.sigmoid(log_odds.mean()).item()
        else:
            # single-logit head: already binary log-odds
            cm_score = torch.sigmoid(logits.mean()).item()

    # ===== Embedding pooling: L2 -> mean -> L2 =====
    emb = F.normalize(emb_batch, dim=-1).mean(dim=0)
    emb = F.normalize(emb, dim=-1).detach().cpu().numpy().astype(np.float32)

    return emb, cm_score if cm_score is not None else 0.0

## Extract Embeddings for the Training Set

In [5]:
import os
import numpy as np
from tqdm import tqdm

# 1. Load danh sách speaker đã train
# 1. Speakers already used for training (first field of each line of
# train.txt); their files are skipped when extracting embeddings below.
train_speakers = set()
with open("/kaggle/input/log-training-eca-aasist/train.txt", "r", encoding="utf-8") as f:
    for line in f:
        fields = line.split()
        if not fields:  # blank / whitespace-only line used to raise IndexError
            continue
        train_speakers.add(fields[0])

In [6]:
# Dataset root (files live under <root_dir>/<speaker_id>/...) and the output
# folders for ASV (ResNet48) and CM (AASIST) embeddings.
root_dir = "/kaggle/input/vsasv-train/vlsp_train/home4/vuhl/VSASV-Dataset/vlsp2025/train"
out_root_asv, out_root_cm = "r48_embeddings_train", "aasist_embeddings_train"

In [None]:
# Create both embedding output roots (no-op if they already exist).
for _out_dir in (out_root_asv, out_root_cm):
    os.makedirs(_out_dir, exist_ok=True)

In [18]:
from tqdm import tqdm
import os

# Recursively collect .wav/.flac files under root_dir, stored as paths relative
# to root_dir. The first path component is treated as the speaker ID, and any
# speaker listed in train_speakers is excluded.
all_wavs = []
for dirpath, dirnames, filenames in tqdm(os.walk(root_dir),
                                        desc="Scanning folders",
                                        unit="dir"):
    rel_path = os.path.relpath(dirpath, root_dir)  # e.g. "id00016/bonafide"
    if rel_path.split(os.sep)[0] in train_speakers:
        continue
    all_wavs.extend(
        os.path.join(rel_path, fname)
        for fname in filenames
        if fname.lower().endswith((".wav", ".flac"))
    )


Scanning folders: 2245dir [00:40, 55.25dir/s]


In [8]:
from tqdm import tqdm
import os


# 3. Walk root_dir recursively (with a tqdm progress bar) and keep the relative
#    paths of .wav/.flac files whose speaker (first path component) is not in
#    train_speakers.
# NOTE(review): this recomputes the same all_wavs as the "Scanning folders" cell
# above (only the progress-bar label differs); one of the two can likely go.
all_wavs = []
for dirpath, dirnames, filenames in tqdm(os.walk(root_dir),
                                        desc="Scanning directories",
                                        unit="dir"):
    rel_path = os.path.relpath(dirpath, root_dir)  # e.g. "id00016/bonafide"
    speaker_id = rel_path.split(os.sep)[0]
    if speaker_id not in train_speakers:
        all_wavs += [os.path.join(rel_path, fname)
                     for fname in filenames
                     if fname.lower().endswith((".wav", ".flac"))]


Scanning directories: 2245dir [00:40, 55.09dir/s] 


In [23]:
# For each collected file, run get_asv_score with the r48 model and save the
# returned embedding as <out_root_asv>/<relative path without extension>.npy,
# which mirrors the input folder layout.
for wav_rel in tqdm(all_wavs, desc="Extract ASV Score"):
    full_path = os.path.join(root_dir, wav_rel)
    if os.path.exists(full_path):
        embedding = get_asv_score(full_path, model=r48, device="cuda")
        out_path = os.path.join(out_root_asv, os.path.splitext(wav_rel)[0] + ".npy")
        # Make sure the per-speaker / per-label sub-directories exist.
        os.makedirs(os.path.dirname(out_path), exist_ok=True)
        np.save(out_path, embedding)
    else:
        print(f"[WARN] File not found: {full_path}")


print("Done! Đã lưu embedding theo rel_path trong thư mục r48_embeddings_train/")


Extract ASV Score: 100%|██████████| 18717/18717 [21:35<00:00, 14.44it/s]

Done! Đã lưu embedding theo rel_path trong thư mục r48_embeddings_train/





In [None]:
import os, json
import numpy as np
from tqdm import tqdm

# Maps each utterance's relative path (without extension) -> CM score.
cm_scores = {}

for wav_rel in tqdm(all_wavs, desc="Extract CM Score"):
    full_path = os.path.join(root_dir, wav_rel)
    if not os.path.exists(full_path):
        print(f"[WARN] File not found: {full_path}")
        continue

    # get_cm_score returns (embedding, score).
    emb, score = get_cm_score(full_path, model=aasist, device="cuda")

    # 4.1 The output path mirrors the input's relative layout under out_root_cm.
    rel_no_ext, _ = os.path.splitext(wav_rel)
    out_path = os.path.join(out_root_cm, rel_no_ext + ".npy")

    # 4.2 Create this file's own parent directory. dirname(out_root_cm) would be
    # "" for a relative root, and os.makedirs("") raises FileNotFoundError.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    # 4.3 Save the embedding as a flat float32 vector to its per-file path.
    # Saving to out_root_cm instead would overwrite one shared .npy every iteration.
    emb = np.asarray(emb, dtype=np.float32).reshape(-1)
    np.save(out_path, emb)

    # 4.4 Record the score, keyed by the relative path without extension.
    cm_scores[rel_no_ext] = float(score)

# 5) Dump all scores to JSON.
json_path = "aasist_scores_train.json"
with open(json_path, "w", encoding="utf-8") as f:
    json.dump(cm_scores, f, ensure_ascii=False, indent=2, sort_keys=True)

print(f"Done! Saved embeddings under '{out_root_cm}' and CM scores to '{json_path}'.")

In [24]:
import zipfile

# Pack every file under OUTPUT_DIR into ZIP_NAME, storing paths relative to
# OUTPUT_DIR. Each source file is deleted as soon as it has been written to the
# archive (presumably to limit disk usage in /kaggle/working).
OUTPUT_DIR = "/kaggle/working/r48_embeddings_train"
ZIP_NAME = OUTPUT_DIR + ".zip"

with zipfile.ZipFile(ZIP_NAME, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
    for dirpath, _subdirs, filenames in os.walk(OUTPUT_DIR):
        for name in filenames:
            src = os.path.join(dirpath, name)
            archive.write(src, arcname=os.path.relpath(src, OUTPUT_DIR))
            os.remove(src)  # the file now lives only inside the archive

In [None]:
import zipfile

# Pack every file under OUTPUT_DIR into ZIP_NAME, storing paths relative to
# OUTPUT_DIR. Each source file is deleted as soon as it has been written to the
# archive (presumably to limit disk usage in /kaggle/working).
OUTPUT_DIR = "/kaggle/working/aasist_embeddings_train"
ZIP_NAME = OUTPUT_DIR + ".zip"

with zipfile.ZipFile(ZIP_NAME, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
    for dirpath, _subdirs, filenames in os.walk(OUTPUT_DIR):
        for name in filenames:
            src = os.path.join(dirpath, name)
            archive.write(src, arcname=os.path.relpath(src, OUTPUT_DIR))
            os.remove(src)  # the file now lives only inside the archive

## Public Test Embed

In [25]:
# Public-test audio root and the trial-list file that references it.
root_dir, list_path = (
    "/kaggle/input/public-test-vsasv/public_test/home4/vuhl/VSASV-Dataset/vlsp2025/",
    "/kaggle/input/public-test-vsasv/public_test_vlsp.txt",
)

In [26]:
# Parse the public-test trial list. Each non-empty line has exactly three
# whitespace-separated fields; the first two (enrollment path, test path) are
# kept and the third is ignored. Blank lines (e.g. a trailing empty line) are
# skipped instead of failing the tuple unpack with ValueError.
trials = []
with open(list_path, "r") as f:
    for line in f:
        fields = line.split()
        if not fields:
            continue
        enroll, test, _ = fields
        trials.append((enroll, test))

In [27]:
# Unique set of every audio path that appears on either side of a trial.
all_wavs = {
    wav
    for enroll_path, test_path in trials
    for wav in (enroll_path, test_path)
}

In [None]:
from tqdm import tqdm
import os
pub_cm_dir = "aasist_embedding_pub"
os.makedirs(pub_cm_dir, exist_ok=True)

# Run get_cm_score with the aasist model on every unique trial file:
#   - the embedding is saved as aasist_embedding_pub/<basename without ext>.npy
#   - the score is stored in cm_score, keyed by the trial's relative path.
# NOTE(review): embeddings are keyed by basename only, so two trial paths that
# share a basename in different folders would overwrite each other. Confirm
# that basenames are unique in this split.
cm_score = {}
for wav_path in tqdm(all_wavs, desc="Extract CM Score"):
    full_path = os.path.join(root_dir, wav_path)
    if os.path.exists(full_path):
        embedding, score = get_cm_score(full_path, aasist)
        stem = os.path.splitext(os.path.basename(wav_path))[0]
        np.save(f"{pub_cm_dir}/{stem}.npy", embedding)
        cm_score[wav_path] = score
    else:
        print(f"File not found {full_path}")

In [None]:
import json
# Persist the public-test CM scores (trial path -> score) as indented JSON.
with open("aasist_score_pub.json", mode="w") as json_file:
    json.dump(cm_score, json_file, indent=2)

In [None]:
import zipfile

# Pack every file under OUTPUT_DIR into ZIP_NAME, storing paths relative to
# OUTPUT_DIR. Each source file is deleted as soon as it has been written to the
# archive (presumably to limit disk usage in /kaggle/working).
OUTPUT_DIR = "/kaggle/working/aasist_embedding_pub"
ZIP_NAME = OUTPUT_DIR + ".zip"

with zipfile.ZipFile(ZIP_NAME, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
    for dirpath, _subdirs, filenames in os.walk(OUTPUT_DIR):
        for name in filenames:
            src = os.path.join(dirpath, name)
            archive.write(src, arcname=os.path.relpath(src, OUTPUT_DIR))
            os.remove(src)  # the file now lives only inside the archive

In [29]:
from tqdm import tqdm
import os
pub_asv_dir = "r48_embedding_pub"
os.makedirs(pub_asv_dir, exist_ok=True)


# Run get_asv_score with the r48 model on every unique trial file and save the
# result as r48_embedding_pub/<basename without extension>.npy.
for wav_path in tqdm(all_wavs, desc="Extract ASV Score"):
    full_path = os.path.join(root_dir, wav_path)
    if os.path.exists(full_path):
        embedding = get_asv_score(full_path, r48)
        stem = os.path.splitext(os.path.basename(wav_path))[0]
        np.save(f"{pub_asv_dir}/{stem}.npy", embedding)
    else:
        print(f"File not found {full_path}")

Extract CM Score: 100%|██████████| 73614/73614 [53:52<00:00, 22.77it/s]  


In [30]:
import zipfile

# Pack every file under OUTPUT_DIR into ZIP_NAME, storing paths relative to
# OUTPUT_DIR. Each source file is deleted as soon as it has been written to the
# archive (presumably to limit disk usage in /kaggle/working).
OUTPUT_DIR = "/kaggle/working/r48_embedding_pub"
ZIP_NAME = OUTPUT_DIR + ".zip"

with zipfile.ZipFile(ZIP_NAME, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
    for dirpath, _subdirs, filenames in os.walk(OUTPUT_DIR):
        for name in filenames:
            src = os.path.join(dirpath, name)
            archive.write(src, arcname=os.path.relpath(src, OUTPUT_DIR))
            os.remove(src)  # the file now lives only inside the archive

## Private Test Embed

In [31]:
# Private-test audio root and the trial-list file that references it.
root_dir, list_path = (
    "/kaggle/input/private-test-vlsp/private_test/home4/vuhl/VSASV-Dataset/vlsp2025/",
    "/kaggle/input/private-test-vlsp/private_test_vlsp.txt",
)

In [33]:
# Parse the private-test trial list. Each non-empty line has exactly two
# whitespace-separated fields (enrollment path, test path). Blank lines
# (e.g. a trailing empty line) are skipped instead of failing the tuple
# unpack with ValueError.
trials = []
with open(list_path, "r") as f:
    for line in f:
        fields = line.split()
        if not fields:
            continue
        enroll, test = fields
        trials.append((enroll, test))

In [34]:
# Unique set of every audio path that appears on either side of a trial.
all_wavs = {
    wav
    for enroll_path, test_path in trials
    for wav in (enroll_path, test_path)
}

In [None]:
from tqdm import tqdm
import os
private_cm_dir = "aasist_embedding_private"
os.makedirs(private_cm_dir, exist_ok=True)

# Run get_cm_score with the aasist model on every unique trial file:
#   - the embedding is saved as aasist_embedding_private/<basename without ext>.npy
#   - the score is stored in cm_score, keyed by the trial's relative path.
# NOTE(review): embeddings are keyed by basename only; confirm that basenames
# are unique across folders in this split, or files will overwrite each other.
cm_score = {}
for wav_path in tqdm(all_wavs, desc="Extract CM Score"):
    full_path = os.path.join(root_dir, wav_path)
    if os.path.exists(full_path):
        embedding, score = get_cm_score(full_path, aasist)
        stem = os.path.splitext(os.path.basename(wav_path))[0]
        np.save(f"{private_cm_dir}/{stem}.npy", embedding)
        cm_score[wav_path] = score
    else:
        print(f"File not found {full_path}")

In [None]:
import json
# Persist the private-test CM scores (trial path -> score) as indented JSON.
with open("aasist_score_private.json", mode="w") as json_file:
    json.dump(cm_score, json_file, indent=2)

In [None]:
import zipfile

# Pack every file under OUTPUT_DIR into ZIP_NAME, storing paths relative to
# OUTPUT_DIR. Each source file is deleted as soon as it has been written to the
# archive (presumably to limit disk usage in /kaggle/working).
OUTPUT_DIR = "/kaggle/working/aasist_embedding_private"
ZIP_NAME = OUTPUT_DIR + ".zip"

with zipfile.ZipFile(ZIP_NAME, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
    for dirpath, _subdirs, filenames in os.walk(OUTPUT_DIR):
        for name in filenames:
            src = os.path.join(dirpath, name)
            archive.write(src, arcname=os.path.relpath(src, OUTPUT_DIR))
            os.remove(src)  # the file now lives only inside the archive

In [35]:
from tqdm import tqdm
import os
private_asv_dir = "r48_embedding_private"
os.makedirs(private_asv_dir, exist_ok=True)


# Run get_asv_score with the r48 model on every unique trial file and save the
# result as r48_embedding_private/<basename without extension>.npy.
for wav_path in tqdm(all_wavs, desc="Extract ASV Score"):
    full_path = os.path.join(root_dir, wav_path)
    if os.path.exists(full_path):
        embedding = get_asv_score(full_path, r48)
        stem = os.path.splitext(os.path.basename(wav_path))[0]
        np.save(f"{private_asv_dir}/{stem}.npy", embedding)
    else:
        print(f"File not found {full_path}")

Extract CM Score: 100%|██████████| 103417/103417 [1:20:02<00:00, 21.54it/s]


In [None]:
from tqdm import tqdm
import os
private_cm_dir = "aasist_embedding_private"
os.makedirs(private_cm_dir, exist_ok=True)

# Re-run get_cm_score with the aasist model over all private-test trial files,
# saving embeddings by basename and collecting scores keyed by trial path.
# NOTE(review): this cell repeats the earlier private-test CM extraction
# verbatim. Re-running it regenerates aasist_embedding_private/*.npy after the
# zip cell above has already archived and deleted them. No later cell in this
# section re-zips those files or re-dumps cm_score to JSON.
cm_score = {}
for wav_path in tqdm(all_wavs, desc="Extract CM Score"):
    full_path = os.path.join(root_dir, wav_path)
    if os.path.exists(full_path):
        embedding, score = get_cm_score(full_path, aasist)
        stem = os.path.splitext(os.path.basename(wav_path))[0]
        np.save(f"{private_cm_dir}/{stem}.npy", embedding)
        cm_score[wav_path] = score
    else:
        print(f"File not found {full_path}")

In [36]:
import zipfile

# Pack every file under OUTPUT_DIR into ZIP_NAME, storing paths relative to
# OUTPUT_DIR. Each source file is deleted as soon as it has been written to the
# archive (presumably to limit disk usage in /kaggle/working).
OUTPUT_DIR = "/kaggle/working/r48_embedding_private"
ZIP_NAME = OUTPUT_DIR + ".zip"

with zipfile.ZipFile(ZIP_NAME, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
    for dirpath, _subdirs, filenames in os.walk(OUTPUT_DIR):
        for name in filenames:
            src = os.path.join(dirpath, name)
            archive.write(src, arcname=os.path.relpath(src, OUTPUT_DIR))
            os.remove(src)  # the file now lives only inside the archive