# Mount Drive

In [None]:
from google.colab import drive
# Mount Google Drive so that files written under output_dir
# ("/content/drive/MyDrive/...", set in the config cell) persist across sessions.
drive.mount('/content/drive')

# Installations

In [None]:
%%capture
# Colab dependency setup; %%capture hides the pip install log.
# NOTE(review): only `datasets` is pinned — consider pinning the other packages
# too so a fresh runtime reproduces this environment.
# NOTE(review): `accelerate` is not imported anywhere in the visible notebook.
!pip install datasets==3.6.0
!pip install torchcontrib
!pip install accelerate
!pip install torchmetrics

In [None]:
import numpy as np
from dataclasses import dataclass, field
from typing import List, Union, Tuple, Dict
import os
import torch
from torch.utils.data import Dataset, DataLoader
import torchaudio
from datasets import load_dataset
from torch.nn.utils.rnn import pad_sequence
import torch.nn as nn
from torchcontrib.optim import SWA
from pathlib import Path
import soundfile as sf
from torch import Tensor
from torchmetrics.classification import BinaryEER
from torchmetrics.functional.classification import binary_roc
import csv
from pathlib import Path
from datasets import concatenate_datasets, DatasetDict

In [None]:
from huggingface_hub import notebook_login
# Interactive Hugging Face login. Presumably needed so load_dataset() can access
# config_info['dataset_name'] (e.g. a gated/private repo) — confirm.
notebook_login() 

# Config

In [None]:

# Central experiment configuration. Boolean-like options are stored as the
# strings "True"/"False" and parsed with str_to_bool() where they are read.
config_info = {
    "dataset_name": "", # Hugging Face dataset repo id passed to load_dataset(); must be set before running
    "num_epochs": 10,  # training epochs; also copied into optim_config["epochs"] for the cosine schedule
    "loss": "CCE",  # NOTE(review): not read anywhere in the visible notebook
    "eval_all_best": "True",  # NOTE(review): not read anywhere in the visible notebook
    "eval_output": "eval_scores_using_best_dev_model.txt",  # NOTE(review): not read in the visible notebook
    "cudnn_deterministic_toggle": "True",  # read by set_seed() when CUDA is available
    "cudnn_benchmark_toggle": "False",  # read by set_seed() when CUDA is available
    "freq_aug": "False",  # if true, training zeroes a random block of sinc filters (CONV mask)
    "model_config": {
        "architecture": "AASIST",  # informational; Model() does not read this key
        "nb_samp": 62801 , # 64600 in the original comment (presumably the reference AASIST length); NOTE(review): Model() never reads nb_samp
        "first_conv": 128,  # sinc-conv kernel size (CONV bumps even sizes to the next odd one)
        "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],  # filts[0]: number of sinc filters; the rest: residual-block [in, out] channels
        "gat_dims": [64, 32],  # [GAT-S/GAT-T output dim, heterogeneous-GAT output dim]
        "pool_ratios": [0.5, 0.7, 0.5, 0.5],  # indices 0 (GAT-S), 1 (GAT-T), 2 (heterogeneous pools) are used; 3 is unused
        "temperatures": [2.0, 2.0, 100.0, 100.0]  # indices 0 (GAT-S), 1 (GAT-T), 2 (heterogeneous layers) are used; 3 is unused
    },
    "optim_config": {
        "optimizer": "adam",  # 'adam' or 'sgd' (see _get_optimizer)
        "amsgrad": "False",  # parsed with str_to_bool
        "base_lr": 0.0001,
        "lr_min": 0.000005,  # floor of the cosine / SGDR schedules
        "betas": [0.9, 0.999],
        "weight_decay": 0.0001,
        "scheduler": "cosine"  # 'multistep' | 'sgdr' | 'cosine' | 'keras_decay'; anything else -> no scheduler
    }
}

# Aliases to the nested dicts (not copies): mutating optim_config later also
# mutates config_info["optim_config"].
model_config = config_info['model_config']
optim_config = config_info['optim_config']

# ============== Dirs ==============

# Destination of the best checkpoint and the per-epoch metrics CSV.
output_dir = "/content/drive/MyDrive/Colab Notebooks/fakevoices"

# Dataset

In [None]:
class DatasetAFAD(Dataset):
    """Map-style torch Dataset over one split of a Hugging Face audio dataset.

    Each item is ``(waveform, label)``:
      * ``waveform`` - float32 tensor built from ``sample["audio"]["array"]``;
      * ``label`` - ``1`` if the sample's label is ``"real"``, else ``0``.

    The label column may hold strings or integer ``ClassLabel`` ids. Integer ids
    are converted to their class name through the split's
    ``features["label"].int2str`` when that exists, so both encodings map the
    same way (an integer id would otherwise never equal ``"real"`` and every
    sample would silently become ``0``).

    Args:
        hf_dataset: any indexable, sized collection of sample dicts (e.g. a
            ``datasets.Dataset``).
    """

    def __init__(self, hf_dataset):
        self.dataset = hf_dataset
        # Resolve the id -> name decoder once. Plain sequences (no .features)
        # and string-typed label columns leave it as None.
        features = getattr(hf_dataset, "features", None)
        label_feature = features.get("label") if isinstance(features, dict) else None
        self._int2str = getattr(label_feature, "int2str", None)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]

        waveform = torch.tensor(sample["audio"]["array"]).float()

        label = sample["label"]
        if not isinstance(label, str) and self._int2str is not None:
            label = self._int2str(int(label))

        return waveform, 1 if label == "real" else 0

In [None]:
def speaker_disjoint_dataset(hf_dataset, seed=42, train_frac=0.7, val_frac=0.15):
    """Re-split a Hugging Face dataset so no speaker appears in two splits.

    The available ``train``/``validation``/``test`` splits are concatenated (in
    that order), the unique ``speaker_id`` values are shuffled with a fixed
    seed, and *speakers* (not utterances) are assigned to train / validation /
    test by the given fractions; test receives all remaining speakers.

    Args:
        hf_dataset: Hugging Face dataset repo id passed to ``load_dataset``.
        seed: seed of the speaker shuffle (default 42).
        train_frac: fraction of speakers assigned to train (default 0.7).
        val_frac: fraction of speakers assigned to validation (default 0.15).

    Returns:
        ``DatasetDict`` with ``train``, ``validation`` and ``test`` splits.
    """
    # trust_remote_code=True lets the dataset repo run its own loading script;
    # only use it with trusted repositories.
    dataset = load_dataset(hf_dataset, trust_remote_code=True)
    # Keep the train/validation/test order: it fixes the order returned by
    # unique() and therefore the result of the seeded shuffle.
    splits = [dataset[name] for name in ("train", "validation", "test") if name in dataset]
    merged_dataset = concatenate_datasets(splits)

    unique_speakers = merged_dataset.unique("speaker_id")

    # A private RandomState(seed) yields exactly the permutation that
    # np.random.seed(seed) + np.random.permutation(...) produced, without
    # reseeding the global NumPy RNG as a hidden side effect.
    shuffled_speakers = np.random.RandomState(seed).permutation(unique_speakers)

    n_speakers = len(shuffled_speakers)
    n_train = int(train_frac * n_speakers)
    n_val = int(val_frac * n_speakers)

    # Sets give O(1) membership tests; `x in ndarray` scans the whole array per row.
    train_speakers = set(shuffled_speakers[:n_train].tolist())
    val_speakers = set(shuffled_speakers[n_train:n_train + n_val].tolist())
    test_speakers = set(shuffled_speakers[n_train + n_val:].tolist())

    def _keep(speakers):
        # input_columns hands only speaker_id to the predicate, so the audio
        # column is not decoded for every row during filtering.
        return merged_dataset.filter(lambda spk: spk in speakers, input_columns="speaker_id")

    new_dataset = DatasetDict({
        'train': _keep(train_speakers),
        'validation': _keep(val_speakers),
        'test': _keep(test_speakers)
    })

    print(f"\nTrain samples: {len(new_dataset['train'])}")
    print(f"Val samples: {len(new_dataset['validation'])}")
    print(f"Test samples: {len(new_dataset['test'])}")
    return new_dataset


In [None]:
if not config_info['dataset_name']:
    # Fail fast with an actionable message instead of an opaque load_dataset("") error.
    raise ValueError("config_info['dataset_name'] is empty: set it to the Hugging Face dataset id to load.")
speaker_disjoint = speaker_disjoint_dataset(config_info['dataset_name'])

# Collate function

In [None]:
def collate_fn(batch):
    """
    Collate ``(waveform, label)`` pairs into one padded batch.

    Waveforms are right-padded with zeros to the longest one in the batch
    (two clips -> e.g. torch.Size([2, 62801])); labels become an int64 tensor
    (e.g. tensor([0, 0])).
    """
    wav_list, label_list = map(list, zip(*batch))
    wav_batch = pad_sequence(wav_list, batch_first=True)
    label_tensor = torch.tensor(label_list, dtype=torch.long)
    return wav_batch, label_tensor

# Dataloaders

In [None]:
def data_loaders(batch_size=8, dataset=None):
    """Build the train / dev / test DataLoaders over the speaker-disjoint splits.

    Args:
        batch_size: samples per batch for all three loaders (default 8).
        dataset: mapping with ``train``/``validation``/``test`` splits; defaults
            to the notebook-level ``speaker_disjoint``.

    Returns:
        ``(train_loader, dev_loader, test_loader)``. Each loader yields
        ``(padded_waveforms, labels)`` as produced by ``collate_fn``: waveforms
        zero-padded to the longest clip in the batch and an int64 label tensor
        (1 = "real", 0 = otherwise, per ``DatasetAFAD``). Only the train loader
        shuffles and drops the last incomplete batch.
    """
    splits = speaker_disjoint if dataset is None else dataset
    # Pinned host memory only speeds up host->GPU copies; on CPU runtimes it just warns.
    pin = torch.cuda.is_available()

    def _loader(split_name, is_train):
        # One construction path for all splits; only training shuffles/drops.
        return DataLoader(DatasetAFAD(splits[split_name]),
                          batch_size=batch_size,
                          shuffle=is_train,
                          drop_last=is_train,
                          pin_memory=pin,
                          collate_fn=collate_fn)

    train_loader = _loader('train', True)
    dev_loader = _loader('validation', False)
    test_loader = _loader('test', False)
    return train_loader, dev_loader, test_loader

train_loader, dev_loader, test_loader = data_loaders()


# Utils

In [None]:
import os
import random
import sys

import numpy as np
import torch


def str_to_bool(val):
    """Convert a string representation of truth to ``True`` or ``False``.

    Adapted from ``distutils.util.strtobool`` (which returned 1/0); this
    version returns real booleans and also accepts values that are already
    ``bool``, so config entries may be written as ``"True"`` or ``True``.

    True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
    are 'n', 'no', 'f', 'false', 'off', and '0' (case-insensitive). Raises
    ValueError if 'val' is anything else.

    >>> str_to_bool('YES')
    True
    >>> str_to_bool('FALSE')
    False
    >>> str_to_bool(True)
    True
    """
    if isinstance(val, bool):
        return val
    val = val.lower()
    if val in ('y', 'yes', 't', 'true', 'on', '1'):
        return True
    if val in ('n', 'no', 'f', 'false', 'off', '0'):
        return False
    raise ValueError('invalid truth value {}'.format(val))


def cosine_annealing(step, total_steps, lr_max, lr_min):
    """Half-cosine interpolation from ``lr_max`` (step 0) to ``lr_min`` (step ``total_steps``).

    value = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * step / total_steps))
    """
    progress = step / total_steps
    cos_term = 1 + np.cos(progress * np.pi)
    return lr_min + 0.5 * (lr_max - lr_min) * cos_term


def keras_decay(step, decay=0.0001):
    """Keras-style inverse-time decay factor ``1 / (1 + decay * step)``.

    Equals 1.0 at step 0; usable directly as a ``LambdaLR`` multiplier.
    """
    denominator = 1. + step * decay
    return 1. / denominator


class SGDRScheduler(torch.optim.lr_scheduler._LRScheduler):
    """SGD with warm restarts (SGDR) learning-rate scheduler.

    Within a cycle of ``Ti`` calls to ``step()``, every param group's LR follows
    a half cosine from its base LR down to ``eta_min``. When a cycle ends the
    LR restarts at the base LR and the next cycle is ``T_mul`` times longer.

    Args:
        optimizer: wrapped optimizer.
        T0: length of the first cycle, counted in ``step()`` calls.
        T_mul: multiplicative growth of the cycle length after each restart.
        eta_min: LR reached at the end of each cycle.
        last_epoch: index of the last step (-1 to start fresh).
    """
    def __init__(self, optimizer, T0, T_mul, eta_min, last_epoch=-1):
        # State must exist before super().__init__(): PyTorch's scheduler
        # constructor performs an initial step, which calls get_lr().
        self.Ti = T0
        self.T_mul = T_mul
        self.eta_min = eta_min

        self.last_restart = 0

        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        # NOTE(review): this mutates restart state, so it assumes get_lr() is
        # called exactly once per step() (as the base class does) — calling it
        # from elsewhere would advance the restart bookkeeping.
        T_cur = self.last_epoch - self.last_restart
        if T_cur >= self.Ti:
            # cycle finished: restart and lengthen the next cycle
            self.last_restart = self.last_epoch
            self.Ti = self.Ti * self.T_mul
            T_cur = 0

        return [
            self.eta_min + (base_lr - self.eta_min) *
            (1 + np.cos(np.pi * T_cur / self.Ti)) / 2
            for base_lr in self.base_lrs
        ]


def _get_optimizer(model_parameters, optim_config):
    """Defines optimizer according to the given config"""
    optimizer_name = optim_config['optimizer']

    if optimizer_name == 'sgd':
        optimizer = torch.optim.SGD(model_parameters,
                                    lr=optim_config['base_lr'],
                                    momentum=optim_config['momentum'],
                                    weight_decay=optim_config['weight_decay'],
                                    nesterov=optim_config['nesterov'])
    elif optimizer_name == 'adam':
        optimizer = torch.optim.Adam(model_parameters,
                                     lr=optim_config['base_lr'],
                                     betas=optim_config['betas'],
                                     weight_decay=optim_config['weight_decay'],
                                     amsgrad=str_to_bool(
                                         optim_config['amsgrad']))
    else:
        print('Un-known optimizer', optimizer_name)
        sys.exit()

    return optimizer


def _get_scheduler(optimizer, optim_config):
    """
    Defines learning rate scheduler according to the given config
    """
    if optim_config['scheduler'] == 'multistep':
        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=optim_config['milestones'],
            gamma=optim_config['lr_decay'])

    elif optim_config['scheduler'] == 'sgdr':
        scheduler = SGDRScheduler(optimizer, optim_config['T0'],
                                  optim_config['Tmult'],
                                  optim_config['lr_min'])

    elif optim_config['scheduler'] == 'cosine':
        total_steps = optim_config['epochs'] * \
            optim_config['steps_per_epoch']

        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer,
            lr_lambda=lambda step: cosine_annealing(
                step,
                total_steps,
                1,  # since lr_lambda computes multiplicative factor
                optim_config['lr_min'] / optim_config['base_lr']))

    elif optim_config['scheduler'] == 'keras_decay':
        scheduler = torch.optim.lr_scheduler.LambdaLR(
            optimizer, lr_lambda=lambda step: keras_decay(step))
    else:
        scheduler = None
    return scheduler


def create_optimizer(model_parameters, optim_config):
    """Return ``(optimizer, scheduler)`` built from ``optim_config``.

    The scheduler is ``None`` when ``_get_scheduler`` does not recognise the
    configured scheduler name.
    """
    opt = _get_optimizer(model_parameters, optim_config)
    return opt, _get_scheduler(opt, optim_config)


def seed_worker(worker_id):
    """
    DataLoader ``worker_init_fn``: seed NumPy and ``random`` from the torch seed.

    Inside a DataLoader worker, torch assigns each worker its own seed, so
    deriving the other RNGs from it gives every worker a distinct, reproducible
    stream. ``worker_id`` itself is not used.
    """
    seed = torch.initial_seed() % (1 << 32)
    for seeder in (np.random.seed, random.seed):
        seeder(seed)


def set_seed(seed, config = None):
    """
    Seed Python ``random``, NumPy and torch (CPU and every CUDA device).

    When CUDA is available, the cuDNN ``deterministic`` / ``benchmark`` flags
    are also set from ``config["cudnn_deterministic_toggle"]`` and
    ``config["cudnn_benchmark_toggle"]`` (string booleans, parsed with
    ``str_to_bool``).

    Raises:
        ValueError: if ``config`` is None.
    """
    if config is None:
        raise ValueError("config should not be None")

    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = str_to_bool(config["cudnn_deterministic_toggle"])
    cudnn.benchmark = str_to_bool(config["cudnn_benchmark_toggle"])

# Model, Device, Optim, Scheduler

In [None]:
# MODEL

"""
AASIST
Copyright (c) 2021-present NAVER Corp.
MIT license
"""

import random
from typing import Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor


class GraphAttentionLayer(nn.Module):
    """Graph attention layer over a fully connected node set (AASIST GAT).

    For nodes ``x`` of shape (#bs, #node, in_dim):
      1. attention logits come from pairwise element-wise node products,
         projected by ``att_proj`` + tanh and scored by ``att_weight``;
      2. logits are divided by the temperature and softmax-normalised over the
         neighbour axis;
      3. output = SELU(BN(proj_with_att(att @ x) + proj_without_att(x))),
         shape (#bs, #node, out_dim).

    Args:
        in_dim: input node feature size.
        out_dim: output node feature size.
        temperature (kwarg, default 1.0): attention logits are divided by it
            before the softmax; larger values give flatter attention.
    """
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # attention map: project pairwise products, then score them with a
        # learned (out_dim, 1) vector
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_weight = self._init_new_params(out_dim, 1)

        # project: attention-aggregated branch + direct (skip) branch
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        # batch norm (over all nodes flattened together, see _apply_BN)
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x):
        '''
        x   :(#bs, #node, #dim)
        out :(#bs, #node, out_dim)
        '''
        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x)

        # projection
        x = self._project(x, att_map)

        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)
        return x

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map: out[b, i, j, :] = x[b, i, :] * x[b, j, :]
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map(self, x):
        '''
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)
        att_map = torch.matmul(att_map, self.att_weight)

        # apply temperature
        att_map = att_map / self.temp

        # dim=-2 is the neighbour axis j: each node i's weights sum to 1
        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        # att_map.squeeze(-1) is (#bs, #node, #node); row i of the product is
        # the attention-weighted sum of all node features
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _apply_BN(self, x):
        # BatchNorm1d expects (N, C): flatten batch and node axes, then restore
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        # uninitialised tensor, immediately filled with Xavier-normal values
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out


class HtrgGraphAttentionLayer(nn.Module):
    """Heterogeneous graph attention layer over two node types plus a master node.

    Nodes of type 1 (``x1``) and type 2 (``x2``) are first projected by
    type-specific linears, concatenated, and attended jointly. Separate scoring
    vectors are used for type1-type1 (``att_weight11``), type2-type2
    (``att_weight22``) and cross-type (``att_weight12``, used for both
    directions) node pairs. A master node is updated from attention over all
    nodes (``att_weightM``).

    Args:
        in_dim: input feature size of both node types and the master node.
        out_dim: output feature size.
        temperature (kwarg, default 1.0): attention logits are divided by it
            before each softmax.

    ``forward`` returns ``(x1, x2, master)`` with feature size ``out_dim``.
    """
    def __init__(self, in_dim, out_dim, **kwargs):
        super().__init__()

        # type-specific input projections
        self.proj_type1 = nn.Linear(in_dim, in_dim)
        self.proj_type2 = nn.Linear(in_dim, in_dim)

        # attention map
        self.att_proj = nn.Linear(in_dim, out_dim)
        self.att_projM = nn.Linear(in_dim, out_dim)

        self.att_weight11 = self._init_new_params(out_dim, 1)
        self.att_weight22 = self._init_new_params(out_dim, 1)
        self.att_weight12 = self._init_new_params(out_dim, 1)
        self.att_weightM = self._init_new_params(out_dim, 1)

        # project
        self.proj_with_att = nn.Linear(in_dim, out_dim)
        self.proj_without_att = nn.Linear(in_dim, out_dim)

        # master-node projections
        self.proj_with_attM = nn.Linear(in_dim, out_dim)
        self.proj_without_attM = nn.Linear(in_dim, out_dim)

        # batch norm
        self.bn = nn.BatchNorm1d(out_dim)

        # dropout for inputs
        self.input_drop = nn.Dropout(p=0.2)

        # activate
        self.act = nn.SELU(inplace=True)

        # temperature
        self.temp = 1.
        if "temperature" in kwargs:
            self.temp = kwargs["temperature"]

    def forward(self, x1, x2, master=None):
        '''
        x1      :(#bs, #node1, #dim)
        x2      :(#bs, #node2, #dim)
        master  :(#bs or 1, 1, #dim); defaults to the mean over all nodes
        returns :x1 (#bs, #node1, out_dim), x2 (#bs, #node2, out_dim),
                 master (#bs, 1, out_dim)
        '''
        num_type1 = x1.size(1)
        num_type2 = x2.size(1)

        x1 = self.proj_type1(x1)
        x2 = self.proj_type2(x2)

        # type-1 nodes first, then type-2 nodes
        x = torch.cat([x1, x2], dim=1)

        if master is None:
            master = torch.mean(x, dim=1, keepdim=True)

        # apply input dropout
        x = self.input_drop(x)

        # derive attention map
        att_map = self._derive_att_map(x, num_type1, num_type2)

        # directional edge for master node (reads the node features before
        # they are updated below)
        master = self._update_master(x, master)

        # projection
        x = self._project(x, att_map)

        # apply batch norm
        x = self._apply_BN(x)
        x = self.act(x)

        # split back into the two node types
        x1 = x.narrow(1, 0, num_type1)
        x2 = x.narrow(1, num_type1, num_type2)

        return x1, x2, master

    def _update_master(self, x, master):

        att_map = self._derive_att_map_master(x, master)
        master = self._project_master(x, master, att_map)

        return master

    def _pairwise_mul_nodes(self, x):
        '''
        Calculates pairwise multiplication of nodes.
        - for attention map: out[b, i, j, :] = x[b, i, :] * x[b, j, :]
        x           :(#bs, #node, #dim)
        out_shape   :(#bs, #node, #node, #dim)
        '''

        nb_nodes = x.size(1)
        x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
        x_mirror = x.transpose(1, 2)

        return x * x_mirror

    def _derive_att_map_master(self, x, master):
        '''
        Attention of the master node over all nodes.
        x           :(#bs, #node, #dim)
        master      :(#bs or 1, 1, #dim), broadcast over nodes
        out_shape   :(#bs, #node, 1), softmax-normalised over nodes
        '''
        att_map = x * master
        att_map = torch.tanh(self.att_projM(att_map))

        att_map = torch.matmul(att_map, self.att_weightM)

        # apply temperature
        att_map = att_map / self.temp

        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _derive_att_map(self, x, num_type1, num_type2):
        '''
        x           :(#bs, #node, #dim), type-1 nodes first
        out_shape   :(#bs, #node, #node, 1)
        '''
        att_map = self._pairwise_mul_nodes(x)
        # size: (#bs, #node, #node, #dim_out)
        att_map = torch.tanh(self.att_proj(att_map))
        # size: (#bs, #node, #node, 1)

        att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)

        # score each block of node pairs with its own vector; note the
        # type2 -> type1 block reuses att_weight12 (there is no att_weight21)
        att_board[:, :num_type1, :num_type1, :] = torch.matmul(
            att_map[:, :num_type1, :num_type1, :], self.att_weight11)
        att_board[:, num_type1:, num_type1:, :] = torch.matmul(
            att_map[:, num_type1:, num_type1:, :], self.att_weight22)
        att_board[:, :num_type1, num_type1:, :] = torch.matmul(
            att_map[:, :num_type1, num_type1:, :], self.att_weight12)
        att_board[:, num_type1:, :num_type1, :] = torch.matmul(
            att_map[:, num_type1:, :num_type1, :], self.att_weight12)

        att_map = att_board

        # att_map = torch.matmul(att_map, self.att_weight12)

        # apply temperature
        att_map = att_map / self.temp

        # normalise over the neighbour axis j
        att_map = F.softmax(att_map, dim=-2)

        return att_map

    def _project(self, x, att_map):
        x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
        x2 = self.proj_without_att(x)

        return x1 + x2

    def _project_master(self, x, master, att_map):

        # (#bs, 1, #node) @ (#bs, #node, #dim) -> attention-pooled node summary
        x1 = self.proj_with_attM(torch.matmul(
            att_map.squeeze(-1).unsqueeze(1), x))
        x2 = self.proj_without_attM(master)

        return x1 + x2

    def _apply_BN(self, x):
        # BatchNorm1d expects (N, C): flatten batch and node axes, then restore
        org_size = x.size()
        x = x.view(-1, org_size[-1])
        x = self.bn(x)
        x = x.view(org_size)

        return x

    def _init_new_params(self, *size):
        # uninitialised tensor, immediately filled with Xavier-normal values
        out = nn.Parameter(torch.FloatTensor(*size))
        nn.init.xavier_normal_(out)
        return out


class GraphPool(nn.Module):
    """Score-based graph pooling that keeps the top-``k`` fraction of nodes.

    Each node is scored as ``sigmoid(Linear(dropout(h)))``; node features are
    scaled by their score and the ``max(int(#node * k), 1)`` highest-scoring
    nodes are kept, in descending score order.

    Args:
        k: fraction of nodes to keep.
        in_dim: node feature size.
        p: dropout probability on the scoring path (``p <= 0`` disables it).
    """
    def __init__(self, k: float, in_dim: int, p: Union[float, int]):
        super().__init__()
        self.k = k
        self.sigmoid = nn.Sigmoid()
        self.proj = nn.Linear(in_dim, 1)
        if p > 0:
            self.drop = nn.Dropout(p=p)
        else:
            self.drop = nn.Identity()
        self.in_dim = in_dim

    def forward(self, h):
        node_scores = self.sigmoid(self.proj(self.drop(h)))
        return self.top_k_graph(node_scores, h, self.k)

    def top_k_graph(self, scores, h, k):
        """
        args
        =====
        scores: attention-based weights (#bs, #node, 1)
        h: graph data (#bs, #node, #dim)
        k: ratio of remaining nodes, (float)

        returns
        =====
        h: score-weighted features of the kept nodes (#bs, #node', #dim)
        """
        _, total_nodes, feat_dim = h.size()
        keep = max(int(total_nodes * k), 1)
        top_idx = torch.topk(scores, keep, dim=1).indices.expand(-1, -1, feat_dim)
        return torch.gather(h * scores, 1, top_idx)


class CONV(nn.Module):
    """Fixed (non-learned) sinc band-pass filterbank applied as a 1-D convolution.

    ``out_channels`` filters are built once in ``__init__``. Their band edges
    are spaced uniformly on the mel scale between 0 Hz and ``sample_rate / 2``.
    Each filter is the difference of two ideal low-pass sinc kernels (cut-offs
    ``fmax`` and ``fmin``) multiplied by a Hamming window. An even
    ``kernel_size`` is bumped to the next odd value.

    The filters live in ``self.band_pass`` as a plain tensor (not an
    ``nn.Parameter`` or registered buffer), so they are not trained, not part
    of ``state_dict()``, and are copied to the input's device on every forward.

    Raises:
        ValueError: if ``in_channels != 1``, ``bias`` is True or ``groups > 1``.
    """
    @staticmethod
    def to_mel(hz):
        # Hz -> mel: 2595 * log10(1 + hz / 700)
        return 2595 * np.log10(1 + hz / 700)

    @staticmethod
    def to_hz(mel):
        # inverse of to_mel
        return 700 * (10**(mel / 2595) - 1)

    def __init__(self,
                 out_channels,
                 kernel_size,
                 sample_rate=16000,
                 in_channels=1,
                 stride=1,
                 padding=0,
                 dilation=1,
                 bias=False,
                 groups=1,
                 mask=False):
        super().__init__()
        if in_channels != 1:

            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (
                in_channels)
            raise ValueError(msg)
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.sample_rate = sample_rate

        # Forcing the filters to be odd (i.e, perfectly symmetrics)
        if kernel_size % 2 == 0:
            self.kernel_size = self.kernel_size + 1
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        # NOTE(review): stored but never consulted; forward() uses its own `mask` argument
        self.mask = mask
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')

        # frequency grid from 0 Hz to Nyquist, used only to find the mel range
        NFFT = 512
        f = int(self.sample_rate / 2) * np.linspace(0, 1, int(NFFT / 2) + 1)
        fmel = self.to_mel(f)
        fmelmax = np.max(fmel)
        fmelmin = np.min(fmel)
        # out_channels + 1 band edges, uniform in mel, converted back to Hz
        filbandwidthsmel = np.linspace(fmelmin, fmelmax, self.out_channels + 1)
        filbandwidthsf = self.to_hz(filbandwidthsmel)

        self.mel = filbandwidthsf
        # symmetric tap indices -(K-1)/2 .. (K-1)/2
        self.hsupp = torch.arange(-(self.kernel_size - 1) / 2,
                                  (self.kernel_size - 1) / 2 + 1)
        self.band_pass = torch.zeros(self.out_channels, self.kernel_size)
        for i in range(len(self.mel) - 1):
            fmin = self.mel[i]
            fmax = self.mel[i + 1]
            # ideal low-pass impulse responses at fmax and fmin
            hHigh = (2*fmax/self.sample_rate) * \
                np.sinc(2*fmax*self.hsupp/self.sample_rate)
            hLow = (2*fmin/self.sample_rate) * \
                np.sinc(2*fmin*self.hsupp/self.sample_rate)
            # their difference is the band-pass [fmin, fmax]
            hideal = hHigh - hLow

            self.band_pass[i, :] = Tensor(np.hamming(
                self.kernel_size)) * Tensor(hideal)

    def forward(self, x, mask=False):
        """Filter ``x`` with the filterbank.

        ``x`` is expected as (#bs, 1, #samples) (conv1d with a single input
        channel). If ``mask`` is True, a random contiguous block of ``A``
        filters, ``A = int(uniform(0, 20))`` (0..19), is zeroed for this call
        only (frequency-masking augmentation). With the default stride/padding/
        dilation the output is (#bs, out_channels, #samples - kernel_size + 1).
        """
        # clone so masking never modifies the stored filterbank
        band_pass_filter = self.band_pass.clone().to(x.device)
        if mask:
            A = np.random.uniform(0, 20)
            A = int(A)
            A0 = random.randint(0, band_pass_filter.shape[0] - A)
            band_pass_filter[A0:A0 + A, :] = 0
        else:
            band_pass_filter = band_pass_filter

        self.filters = (band_pass_filter).view(self.out_channels, 1,
                                               self.kernel_size)

        return F.conv1d(x,
                        self.filters,
                        stride=self.stride,
                        padding=self.padding,
                        dilation=self.dilation,
                        bias=None,
                        groups=1)


class Residual_block(nn.Module):
    """Pre-activation 2-D residual block used by the AASIST encoder.

    Main path: [BN -> SELU] (skipped when ``first``) -> Conv(2x3, pad 1x1)
    -> BN -> SELU -> Conv(2x3, pad 0x1). Shortcut: identity, or a Conv(1x3,
    pad 0x1) when the channel count changes. The sum is max-pooled with a
    (1, 3) kernel.

    The two convs change the height by +1 and then -1, so only the last axis
    is downsampled (by 3, via the max-pool).

    Args:
        nb_filts: ``[in_channels, out_channels]``.
        first: if True, the input goes straight into ``conv1`` (no BN/SELU).

    NOTE: ``forward`` now feeds the BN/SELU pre-activation into ``conv1``. A
    checkpoint trained while ``conv1`` consumed the raw input is not
    numerically equivalent under this forward.
    """
    def __init__(self, nb_filts, first=False):
        super().__init__()
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
        self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)

        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)

        else:
            self.downsample = False
        self.mp = nn.MaxPool2d((1, 3))  # self.mp = nn.MaxPool2d((1,4))

    def forward(self, x):
        identity = x
        if not self.first:
            out = self.bn1(x)
            out = self.selu(out)  # in place on the BN output, x is untouched
        else:
            out = x
        # conv1 must consume the pre-activated tensor; feeding `x` here would
        # make bn1/SELU dead computation (bn1 would only update running stats).
        out = self.conv1(out)

        out = self.bn2(out)
        out = self.selu(out)
        out = self.conv2(out)

        if self.downsample:
            identity = self.conv_downsample(identity)

        out += identity
        out = self.mp(out)
        return out


class Model(nn.Module):
    """AASIST anti-spoofing model: raw waveform -> 2-class logits.

    Pipeline: fixed sinc filterbank (``CONV``) -> |.| and 3x3 max-pool ->
    BN/SELU -> 6 residual blocks -> spectral (GAT-S) and temporal (GAT-T)
    graph attention with pooling -> two parallel heterogeneous stacks
    (``HtrgGraphAttentionLayer``, each with its own learnable master node) ->
    element-wise max of the two branches -> readout
    ``[T_max, T_avg, S_max, S_avg, master]`` -> ``Linear(5 * gat_dims[1], 2)``.

    Args:
        d_args: model config with ``first_conv`` (sinc kernel size),
            ``filts`` (``filts[0]`` sinc filters, then residual-block
            ``[in, out]`` channel pairs), ``gat_dims``, ``pool_ratios`` and
            ``temperatures``.

    ``forward(x, Freq_aug=False)`` takes waveforms shaped (#bs, #samples) and
    returns ``(last_hidden, logits)`` shaped (#bs, 5 * gat_dims[1]) and (#bs, 2).
    """
    def __init__(self, d_args):
        super().__init__()

        self.d_args = d_args
        filts = d_args["filts"]
        gat_dims = d_args["gat_dims"]
        pool_ratios = d_args["pool_ratios"]
        temperatures = d_args["temperatures"]

        self.conv_time = CONV(out_channels=filts[0],
                              kernel_size=d_args["first_conv"],
                              in_channels=1)
        self.first_bn = nn.BatchNorm2d(num_features=1)

        self.drop = nn.Dropout(0.5, inplace=True)
        self.drop_way = nn.Dropout(0.2, inplace=True)
        self.selu = nn.SELU(inplace=True)

        self.encoder = nn.Sequential(
            nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
            nn.Sequential(Residual_block(nb_filts=filts[2])),
            nn.Sequential(Residual_block(nb_filts=filts[3])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])),
            nn.Sequential(Residual_block(nb_filts=filts[4])))

        # Number of spectral nodes: the sinc filterbank yields filts[0] bands,
        # the 3x3 max-pool in forward() divides that axis by 3 (floor), and the
        # residual blocks keep it unchanged. For filts[0] = 70 this is 23, the
        # previously hard-coded value, so parameter shapes are unchanged.
        n_spectral_nodes = filts[0] // 3
        self.pos_S = nn.Parameter(torch.randn(1, n_spectral_nodes, filts[-1][-1]))
        self.master1 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
        self.master2 = nn.Parameter(torch.randn(1, 1, gat_dims[0]))

        self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[0])
        self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
                                               gat_dims[0],
                                               temperature=temperatures[1])

        self.HtrgGAT_layer_ST11 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])
        self.HtrgGAT_layer_ST12 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST21 = HtrgGraphAttentionLayer(
            gat_dims[0], gat_dims[1], temperature=temperatures[2])

        self.HtrgGAT_layer_ST22 = HtrgGraphAttentionLayer(
            gat_dims[1], gat_dims[1], temperature=temperatures[2])

        self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3)
        self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3)
        self.pool_hS1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT1 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.pool_hS2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)
        self.pool_hT2 = GraphPool(pool_ratios[2], gat_dims[1], 0.3)

        self.out_layer = nn.Linear(5 * gat_dims[1], 2)

    def forward(self, x, Freq_aug=False):

        # (#bs, #samples) -> (#bs, 1, #samples) for the 1-D sinc convolution
        x = x.unsqueeze(1)
        x = self.conv_time(x, mask=Freq_aug)
        # treat (#filters, #time) as a single-channel 2-D map
        x = x.unsqueeze(dim=1)
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)

        # get embeddings using encoder
        # (#bs, #filt, #spec, #seq)
        e = self.encoder(x)

        # spectral GAT (GAT-S)
        e_S, _ = torch.max(torch.abs(e), dim=3)  # max along time
        e_S = e_S.transpose(1, 2) + self.pos_S

        gat_S = self.GAT_layer_S(e_S)
        out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)

        # temporal GAT (GAT-T)
        e_T, _ = torch.max(torch.abs(e), dim=2)  # max along freq
        e_T = e_T.transpose(1, 2)

        gat_T = self.GAT_layer_T(e_T)
        out_T = self.pool_T(gat_T)

        # learnable master nodes, expanded to the batch size and passed to the
        # heterogeneous layers (numerically identical to broadcasting the
        # (1, 1, D) parameters, but the batch dimension is explicit)
        master1 = self.master1.expand(x.size(0), -1, -1)
        master2 = self.master2.expand(x.size(0), -1, -1)

        # inference 1
        out_T1, out_S1, master1 = self.HtrgGAT_layer_ST11(
            out_T, out_S, master=master1)

        out_S1 = self.pool_hS1(out_S1)
        out_T1 = self.pool_hT1(out_T1)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST12(
            out_T1, out_S1, master=master1)
        out_T1 = out_T1 + out_T_aug
        out_S1 = out_S1 + out_S_aug
        master1 = master1 + master_aug

        # inference 2
        out_T2, out_S2, master2 = self.HtrgGAT_layer_ST21(
            out_T, out_S, master=master2)
        out_S2 = self.pool_hS2(out_S2)
        out_T2 = self.pool_hT2(out_T2)

        out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST22(
            out_T2, out_S2, master=master2)
        out_T2 = out_T2 + out_T_aug
        out_S2 = out_S2 + out_S_aug
        master2 = master2 + master_aug

        out_T1 = self.drop_way(out_T1)
        out_T2 = self.drop_way(out_T2)
        out_S1 = self.drop_way(out_S1)
        out_S2 = self.drop_way(out_S2)
        master1 = self.drop_way(master1)
        master2 = self.drop_way(master2)

        # combine the two branches element-wise
        out_T = torch.max(out_T1, out_T2)
        out_S = torch.max(out_S1, out_S2)
        master = torch.max(master1, master2)

        # readout: max(|.|) and mean over nodes for each node type
        T_max, _ = torch.max(torch.abs(out_T), dim=1)
        T_avg = torch.mean(out_T, dim=1)

        S_max, _ = torch.max(torch.abs(out_S), dim=1)
        S_avg = torch.mean(out_S, dim=1)

        last_hidden = torch.cat(
            [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)

        last_hidden = self.drop(last_hidden)
        output = self.out_layer(last_hidden)

        return last_hidden, output


In [None]:
def get_model(config: Dict, device: Union[str, torch.device]):
    """Instantiate ``Model(config)`` on ``device`` and print its parameter count.

    Args:
        config: model config dict (e.g. ``config_info['model_config']``).
        device: target device, either a string such as ``"cuda"``/``"cpu"`` or
            a ``torch.device``.

    Returns:
        The model, already moved to ``device``.
    """
    m = Model(config).to(device)
    # numel() also works for non-contiguous tensors, unlike view(-1)
    nb_params = sum(param.numel() for param in m.parameters())
    print("no. model params:{}".format(nb_params))
    return m

device = "cuda" if torch.cuda.is_available() else "cpu"
model = get_model(config=model_config, device=device)

# The cosine scheduler's horizon is epochs * steps_per_epoch optimizer steps.
# optim_config aliases config_info["optim_config"], so config_info is updated too.
# NOTE(review): set_seed() is defined but not called in the visible notebook,
# so model initialisation and shuffling are not seeded.
optim_config.update(steps_per_epoch=len(train_loader),
                    epochs=config_info["num_epochs"])
optimizer, scheduler = create_optimizer(model.parameters(), optim_config)
# NOTE(review): SWA wraps the same optimizer; whether it is stepped/swapped is
# decided in the training loop (not fully visible here) — confirm it is used.
optimizer_swa = SWA(optimizer)


# Train Epoch

In [None]:
def train_epoch(
    trn_loader: DataLoader,
    model,
    optim: Union[torch.optim.SGD, torch.optim.Adam],
    device: torch.device,
    scheduler: torch.optim.lr_scheduler,
    ):

    """Train ``model`` for one epoch; return the batch-size-weighted mean loss.

    Loss is class-weighted cross-entropy with weight 0.1 for label 0 and 0.9
    for label 1 ("real" under DatasetAFAD's mapping).

    Scheduler stepping follows ``config_info["optim_config"]["scheduler"]``:
      * ``cosine`` / ``keras_decay`` - stepped after every optimizer step
        (their horizon is counted in batches);
      * ``multistep`` / ``sgdr`` - stepped once at the end of the epoch
        (assumes their milestones / cycle lengths are in epochs — confirm);
      * ``scheduler=None`` is accepted for any name.

    Returns:
        Mean training loss, or NaN if the loader yielded no batches (e.g. fewer
        samples than ``batch_size`` with ``drop_last=True``).

    Raises:
        ValueError: if a scheduler object is given for any other name; checked
            before any training step is taken.
    """
    scheduler_name = config_info["optim_config"]["scheduler"]
    per_step = scheduler_name in ("cosine", "keras_decay")
    per_epoch = scheduler_name in ("multistep", "sgdr")
    if scheduler is not None and not (per_step or per_epoch):
        raise ValueError("scheduler error, got:{}".format(scheduler))

    running_loss = 0.0
    num_total = 0
    model.train()

    weight = torch.FloatTensor([0.1, 0.9]).to(device)
    criterion = nn.CrossEntropyLoss(weight=weight)
    # parse once per epoch rather than once per batch
    freq_aug = str_to_bool(config_info["freq_aug"])

    for batch_x, batch_y in trn_loader:
        batch_size = batch_x.size(0)
        num_total += batch_size
        batch_x = batch_x.to(device)
        batch_y = batch_y.view(-1).type(torch.int64).to(device)
        _, batch_out = model(batch_x, Freq_aug=freq_aug)
        batch_loss = criterion(batch_out, batch_y)
        running_loss += batch_loss.item() * batch_size

        optim.zero_grad()
        batch_loss.backward()
        optim.step()

        if per_step and scheduler is not None:
            scheduler.step()

    if per_epoch and scheduler is not None:
        scheduler.step()

    if num_total == 0:
        return float("nan")
    return running_loss / num_total

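# Eval function
- Computes the mean cross-entropy loss, the equal error rate (EER) and the decision threshold at the EER, using TorchMetrics (`BinaryEER`, `binary_roc`).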


In [None]:

@torch.no_grad()
def evaluate(loader, model, device):
    """Evaluate ``model`` on ``loader``: mean CE loss, EER and EER threshold.

    The positive class is label 1 ("real" under DatasetAFAD's mapping) and the
    detection score is its softmax probability. The same score feeds both the
    EER (``BinaryEER``) and the ROC used to choose the threshold, so ``thr`` is
    the probability cut-off at (approximately) the reported EER.

    Returns:
        dict with
          * ``loss`` - sample-weighted mean of the *unweighted* cross-entropy
            (not directly comparable to the class-weighted training loss);
          * ``eer``  - equal error rate;
          * ``thr``  - P(label 1) threshold where |FPR - FNR| is smallest.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()

    criterion = nn.CrossEntropyLoss()

    eer_metric = BinaryEER(thresholds=None).to(device)

    total_loss = 0.0
    num_samples = 0

    all_probs_pos = []
    all_labels = []

    for wav, y in loader:
        wav = wav.to(device)
        y = y.to(device)

        _, logits = model(wav, Freq_aug=False)

        loss = criterion(logits, y)
        bs = y.size(0)
        total_loss += loss.item() * bs
        num_samples += bs

        # One score for both EER and threshold: the raw class-1 logit is not
        # monotonic in P(class 1) (which depends on logit1 - logit0), so mixing
        # the two would make `thr` inconsistent with `eer`.
        probs = F.softmax(logits, dim=1)[:, 1]
        eer_metric.update(probs, y)

        all_probs_pos.append(probs.cpu())
        all_labels.append(y.cpu())

    if num_samples == 0:
        raise ValueError("evaluate() received an empty loader")

    eer = float(eer_metric.compute().item())

    # threshold at the ROC point where FPR and FNR are closest
    probs_all = torch.cat(all_probs_pos)
    labels_all = torch.cat(all_labels)
    fpr, tpr, thresholds = binary_roc(probs_all, labels_all)
    fnr = 1.0 - tpr
    idx = torch.argmin(torch.abs(fpr - fnr))
    thr = float(thresholds[idx].item())

    avg_loss = total_loss / num_samples
    return {"loss": avg_loss, "eer": eer, "thr": thr}

# Training Loop

In [None]:

def train():
    """Run the configured number of epochs, keeping the lowest-dev-EER checkpoint.

    Relies on the notebook-level ``model``, ``optimizer``, ``scheduler``,
    ``train_loader``, ``dev_loader``, ``device``, ``config_info`` and ``output_dir``.
    After each epoch it evaluates on ``dev_loader``. When the dev EER improves, it
    overwrites ``AASIST3.pth`` in ``output_dir``. It also appends one row to
    ``AASIST3_training_results.csv``; the header is written only if the file did not
    exist yet. The CSV is flushed and fsync'ed after every row, so completed epochs
    are on disk even if the run is interrupted.
    """
    out_root = Path(output_dir)
    best_model_path = out_root / "AASIST3.pth"
    metrics_path = out_root / "AASIST3_training_results.csv"
    metrics_path.parent.mkdir(parents=True, exist_ok=True)

    columns = [
        "epoch",
        "train_loss",
        "eval_loss",
        "eval_eer",
        "eval_threshold",
        "is_best",
        "best_eer_so_far",
    ]
    needs_header = not metrics_path.exists()
    best_eer = float("inf")

    with metrics_path.open("a", newline="") as csv_file:
        csv_writer = csv.DictWriter(csv_file, fieldnames=columns)
        if needs_header:
            csv_writer.writeheader()

        for epoch in range(config_info['num_epochs']):
            print(f"Start training epoch{epoch:03d}")

            train_loss = train_epoch(train_loader, model, optimizer, device, scheduler)
            dev_metrics = evaluate(dev_loader, model, device)
            eval_loss = dev_metrics['loss']
            eval_eer = dev_metrics['eer']
            eval_thr = dev_metrics['thr']

            print(
                f"Training Loss: {train_loss:.4f}, "
                f"Eval Loss: {eval_loss:.4f}, "
                f"Eval EER: {eval_eer:.4f}, "
                f"Eval Threshold: {eval_thr:.4f}"
            )

            # Lower EER is better; only the best checkpoint is kept on disk.
            is_best = eval_eer < best_eer
            if is_best:
                best_eer = eval_eer
                torch.save(model.state_dict(), best_model_path)
                print(f"✅ Best model updated at epoch {epoch:03d} (EER={best_eer:.4f})")

            row = {
                "epoch": epoch,
                "train_loss": f"{train_loss:.6f}",
                "eval_loss": f"{eval_loss:.6f}",
                "eval_eer": f"{eval_eer:.6f}",
                "eval_threshold": f"{eval_thr:.6f}",
                "is_best": int(is_best),
                "best_eer_so_far": f"{best_eer:.6f}",
            }
            csv_writer.writerow(row)
            csv_file.flush()
            os.fsync(csv_file.fileno())

    print(f"Training finished. Best model saved at {best_model_path} with EER={best_eer:.4f}")
    print(f"Epoch metrics saved to: {metrics_path}")

In [None]:
# Launch training; the best checkpoint and the per-epoch metrics CSV are written under `output_dir`.
train()

# Evaluation

In [None]:
class DatasetAFADTEST(Dataset):
    """Map-style test dataset yielding ``(waveform, utterance_id, label)`` triples.

    Wraps an indexable dataset whose rows look like
    ``{"audio": {"array": ..., "path": ...}, "label": ...}``.

    - The waveform is returned as a float32 tensor.
    - The utterance id is the file stem of ``audio["path"]``.
    - The label is 1 when the row's label equals ``"real"`` and 0 otherwise.
    """

    def __init__(self, hf_dataset):
        self.dataset = hf_dataset

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        row = self.dataset[idx]
        audio = row["audio"]
        waveform = torch.tensor(audio["array"]).float()
        utt_id = Path(audio["path"]).stem
        is_real = row["label"] == "real"
        return waveform, utt_id, (1 if is_real else 0)

In [None]:
def collate_fn_test(batch):
    """Collate ``(waveform, utt_id, label)`` triples for the test DataLoader.

    Waveforms are right-padded with zeros to the longest one in the batch. For
    example, two clips give a ``(2, max_len)`` tensor such as ``torch.Size([2, 62801])``.

    Returns:
        ``(padded, ids, labels)``: ``ids`` is a tuple of utterance ids in batch order
        and ``labels`` is a ``torch.long`` tensor of shape ``(batch,)``.
    """
    wavs, utt_ids, raw_labels = zip(*batch)
    padded = pad_sequence(list(wavs), batch_first=True)
    label_tensor = torch.tensor(raw_labels, dtype=torch.long)
    return padded, utt_ids, label_tensor

In [None]:
def get_label_predictions(dataset, split, split_text, results_file):
  """
  This function computes the prediction scores for the rel/fake labels.

  """
  d = load_dataset(dataset, split=split) # split is added with ad-hoc

  evall_dataset = DatasetAFADTEST(d)
  TEST_LOADER = DataLoader(evall_dataset, batch_size=8, shuffle=False, drop_last=False, pin_memory=True, collate_fn=collate_fn_test)

  with open(split_text, "r") as f_trl:
      trial_lines = f_trl.readlines()

  fname_list = []
  score_list = []
  pred_list = []
  prob_real_list = []
  for batch_x, utt_id, _ in TEST_LOADER:
      batch_x = batch_x.to(device)
      with torch.no_grad():
          _, batch_out = model(batch_x)
          batch_score = (batch_out[:, 1]).data.cpu().numpy().ravel()
          logits = batch_out[1] if (isinstance(batch_out, tuple) and len(batch_out) == 2) else batch_out

          if logits.size(-1) == 2:
              probs = torch.softmax(logits, dim=-1)
              p_real = probs[:, 1]
              pred = (p_real > 0.5).long()
          elif logits.size(-1) == 1:
              p_real = torch.sigmoid(logits.squeeze(1))
              pred = (p_real > 0.5).long()
          else:
              raise ValueError("Unexpected logits shape. Expect [B,2] or [B,1].")

      fname_list.extend(utt_id)
      score_list.extend(batch_score.tolist())
      pred_list.extend(pred.cpu().numpy().tolist())
      prob_real_list.extend(p_real.cpu().numpy().tolist())
  print(len(trial_lines))
  print(len(fname_list))
  print(len(score_list))
  print(len(pred_list))
  assert len(trial_lines) == len(fname_list) == len(score_list)
  with open(results_file, "w") as fh:
      for fn, p, pro_real, sco, trl in zip(fname_list, pred_list, prob_real_list, score_list, trial_lines):
          speaker_id, sentence, _, label, gender, tts = trl.strip().split('\t')
          fh.write("{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\n".format(speaker_id, sentence, fn, label, gender, tts, sco, p, pro_real))

## 1. AfAD Test

In [None]:
model.load_state_dict(torch.load(os.path.join(output_dir, "AASIST3.pth")))
model.to(device)
model.eval()

In [None]:
# get EER
e = evaluate(test_loader, model, device)
print(e)
with open(os.path.join(output_dir, "results", "AASIST3", "AASIST3_AFAD_EER.txt"), "w") as file:
    file.write(str(e))

In [None]:
# get scores
print("\nGetting scores...")
get_label_predictions(
    speaker_disjoint['test'],
    os.path.join(output_dir,"results", "AASIST3", "AASIST3_test.txt"),
    os.path.join(output_dir,"results", "AASIST3", "AASIST3_AFAD_test_scores.txt")
)

## 2. TTS model

In [None]:
# One test subset per TTS system: rows whose `tts` is that system or 'none', with a real/fake label.
# NOTE(review): 'none' presumably marks the bona fide (non-TTS) clips — confirm against the dataset.
test_tts_eleven = speaker_disjoint['test'].filter(
    lambda ex: ex['tts'] in ('eleven_multilingual_v2', 'none') and ex['label'] in ('real', 'fake')
)
test_tts_openai = speaker_disjoint['test'].filter(
    lambda ex: ex['tts'] in ('gpt-4o-mini-tts', 'none') and ex['label'] in ('real', 'fake')
)
test_tts_minimax = speaker_disjoint['test'].filter(
    lambda ex: ex['tts'] in ('speech-2.5-hd-preview', 'none') and ex['label'] in ('real', 'fake')
)
test_tts_resemble = speaker_disjoint['test'].filter(
    lambda ex: ex['tts'] in ('reseamble-AI', 'none') and ex['label'] in ('real', 'fake')
)

In [None]:
# Wrap each filtered split in DatasetAFAD, in the order eleven / openai / minimax / resemble.
test_tts_eleven_d, test_tts_openai_d, test_tts_minimax_d, test_tts_resemble_d = (
    DatasetAFAD(subset)
    for subset in (test_tts_eleven, test_tts_openai, test_tts_minimax, test_tts_resemble)
)

In [None]:
# Deterministic (unshuffled, nothing dropped) evaluation loaders, one per TTS subset.
eleven_loader, openai_loader, minimax_loader, resemble_loader = (
    DataLoader(ds, batch_size=8, shuffle=False, drop_last=False, pin_memory=True, collate_fn=collate_fn)
    for ds in (test_tts_eleven_d, test_tts_openai_d, test_tts_minimax_d, test_tts_resemble_d)
)

In [None]:
# get EER: elevenlabs
eleven = evaluate(eleven_loader, model, device)
print(eleven)
Path(output_dir, "results", "AASIST3", "AASIST3_AFAD_test_eleven_eer.txt").write_text(str(eleven))

In [None]:
# get EER: openai
opena = evaluate(openai_loader, model, device)
print(opena)
Path(output_dir, "results", "AASIST3", "AASIST3_AFAD_test_openai_eer.txt").write_text(str(opena))

In [None]:
# get EER: minimax
minimax = evaluate(minimax_loader, model, device)
print(minimax)
Path(output_dir, "results", "AASIST3", "AASIST3_AFAD_test_minimax_eer.txt").write_text(str(minimax))

In [None]:
# get EER: resemble
resemble = evaluate(resemble_loader, model, device)
print(resemble)
Path(output_dir, "results", "AASIST3", "AASIST3_AFAD_test_resemble_eer.txt").write_text(str(resemble))

## 3. ad-hoc sets

In [None]:
def get_adhoc_eer(dataset, split, results_file):
    """Evaluate the current model on one ad-hoc split and save the metrics as text.

    Args:
        dataset: Hugging Face dataset name/path passed to ``load_dataset``.
        split: split to load (e.g. ``"fishaudio"``).
        results_file: file name (not a path) written under ``<output_dir>/results/AASIST3``.

    The metrics dict returned by ``evaluate`` is printed and written with ``str()``.
    """
    hf_split = load_dataset(dataset, split=split)
    eval_loader = DataLoader(DatasetAFAD(hf_split), batch_size=8, shuffle=False, drop_last=False, pin_memory=True, collate_fn=collate_fn)
    metrics = evaluate(eval_loader, model, device)
    print(metrics)

    # Create the results directory if needed so a fresh output_dir does not fail the write.
    results_dir = os.path.join(output_dir, "results", "AASIST3")
    os.makedirs(results_dir, exist_ok=True)
    with open(os.path.join(results_dir, results_file), "w") as file:
        file.write(str(metrics))

## 1. Fish

In [None]:
# get eer
print("\nGetting eer...")
get_adhoc_eer(dataset="elsayedissa/AFAD", split="fishaudio", results_file="AASIST3_fish_eer.txt")

In [None]:
# get scores
print("\nGetting scores...")
get_label_predictions(
    dataset=config_info['dataset_name'],
    split="fishaudio",
    split_text=os.path.join(output_dir, "results", "AASIST3", "Fish_test.txt"),
    results_file=os.path.join(output_dir, "results", "AASIST3", "AASIST3_fish_scores.txt"),
)

## 2. XTTS


In [None]:
# get eer
print("\nGetting eer...")
# get_adhoc_eer takes (dataset, split, results_file) and writes into results/AASIST3 itself.
get_adhoc_eer("elsayedissa/AFAD", "xtts", "AASIST3_xtts_eer.txt")

In [None]:
# get scores
print("\nGetting scores...")
get_label_predictions(
    dataset=config_info['dataset_name'],
    split="xtts",
    split_text=os.path.join(output_dir, "results", "AASIST3", "XTTS_test.txt"),
    results_file=os.path.join(output_dir, "results", "AASIST3", "AASIST3_xtts_scores.txt"),
)

## 3. MMS

In [None]:
# get eer
print("\nGetting eer...")
# get_adhoc_eer takes (dataset, split, results_file) and writes into results/AASIST3 itself.
get_adhoc_eer("elsayedissa/AFAD", "mms", "AASIST3_mms_eer.txt")

In [None]:
# get scores
print("\nGetting scores...")
get_label_predictions(
    dataset=config_info['dataset_name'],
    split="mms",
    split_text=os.path.join(output_dir, "results", "AASIST3", "MMS_test.txt"),
    results_file=os.path.join(output_dir, "results", "AASIST3", "AASIST3_mms_scores.txt"),
)

## 4. T5

In [None]:
# get eer
print("\nGetting eer...")
# get_adhoc_eer takes (dataset, split, results_file) and writes into results/AASIST3 itself.
get_adhoc_eer("elsayedissa/AFAD", "speecht5", "AASIST3_t5_eer.txt")

In [None]:
# get scores
print("\nGetting scores...")
get_label_predictions(
    dataset=config_info['dataset_name'],
    split="speecht5",
    split_text=os.path.join(output_dir, "results", "AASIST3", "T5_test.txt"),
    results_file=os.path.join(output_dir, "results", "AASIST3", "AASIST3_T5_scores.txt"),
)