In [2]:
# !pip install torch_optimizer
import wandb

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import Dataset, DataLoader
from torch.autograd import Variable
from torch.optim import Adam
import torchaudio
from torchmetrics import ROC
from torchinfo import summary
from torch_optimizer import AdaBound

import numpy as np
import os
import gc
from typing import Union
import soundfile as sf
import sys
import time
from copy import deepcopy
import math
import matplotlib.pyplot as plt

from IPython.display import clear_output, Audio
clear_output()

device = 'cuda:1' if torch.cuda.is_available() else 'cpu'

# Progress bar

In [2]:
def progressbar(it, prefix="", size=60, out=sys.stdout): # Python3.6+
    count = len(it)
    start = time.time()
    def show(j):
        x = int(size*j/count)
        remaining = ((time.time() - start) / j) * (count - j)
        passing = time.time() - start
        mins_pas, sec_pass = divmod(passing, 60)
        time_pas = f"{int(mins_pas):02}:{sec_pass:05.2f}"
        
        
        mins, sec = divmod(remaining, 60)
        time_str = f"{int(mins):02}:{sec:05.2f}"
        
        
        print(f"{prefix}[{u'█'*x}{('.'*(size-x))}] {j}/{count} time {time_pas} / {time_str}", end='\r', file=out, flush=True)
        
    for i, item in enumerate(it):
        yield item
        show(i+1)
    print("\n", flush=True, file=out)

# Model

## RawNet

### sinc conv

In [4]:
class SincConv(nn.Module):
    """Sinc-based convolution
    Parameters
    ----------
    in_channels : `int`
        Number of input channels. Must be 1.
    out_channels : `int`
        Number of filters.
    kernel_size : `int`
        Filter length.
    sample_rate : `int`, optional
        Sample rate. Defaults to 16000.
    Usage
    -----
    See `torch.nn.Conv1d`
    Reference
    ---------
    Mirco Ravanelli, Yoshua Bengio,
    "Speaker Recognition from raw waveform with SincNet".
    https://arxiv.org/abs/1808.00158
    """
 
    @staticmethod
    def to_mel(hz):
        return 2595 * np.log10(1 + hz / 700)
 
    @staticmethod
    def to_hz(mel):
        return 700 * (10 ** (mel / 2595) - 1)
 
    def __init__(self, out_channels, kernel_size, sample_rate=64000, in_channels=1,
                 stride=1, padding=0, dilation=1, bias=False, groups=1, min_low_hz=0, min_band_hz=0):
 
        super(SincConv,self).__init__()
 
        if in_channels != 1:
            #msg = (f'SincConv only support one input channel '
            #       f'(here, in_channels = {in_channels:d}).')
            msg = "SincConv only support one input channel (here, in_channels = {%i})" % (in_channels)
            raise ValueError(msg)
 
        self.out_channels = out_channels
        self.kernel_size = kernel_size
 
        # Forcing the filters to be odd (i.e, perfectly symmetrics)
        if kernel_size%2==0:
            self.kernel_size=self.kernel_size+1
 
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
 
        if bias:
            raise ValueError('SincConv does not support bias.')
        if groups > 1:
            raise ValueError('SincConv does not support groups.')
 
        self.sample_rate = sample_rate
        self.min_low_hz = min_low_hz
        self.min_band_hz = min_band_hz
 
        # initialize filterbanks such that they are equally spaced in Mel scale
        low_hz = 0
        high_hz = self.sample_rate / 2 - (self.min_low_hz + self.min_band_hz)
 
        # In the future we will set high hz as band_hz + low + min_band_hz + min_low_hz
        # Where band_hz is (high_hz - low_hz). Therefore, it is reasonable to
        # do diff and do not set high_hz as sr/2
 
        mel = np.linspace(self.to_mel(low_hz),
                          self.to_mel(high_hz),
                          self.out_channels + 1)
        hz = self.to_hz(mel)
 
 
        # filter lower frequency (out_channels, 1)
        self.low_hz_ = nn.Parameter(torch.Tensor(hz[:-1]).view(-1, 1)) # learnable f1 from the paper
 
        # filter frequency band (out_channels, 1)
        self.band_hz_ = nn.Parameter(torch.Tensor(np.diff(hz)).view(-1, 1)) # learnable f2 (f2 = f1+diff) from the paper
 
        # len(g) = kernel_size
        # It is symmetric, therefore we will do computations only with left part, while creating g.
 
        # Hamming window
        #self.window_ = torch.hamming_window(self.kernel_size)
        n_lin=torch.linspace(0, (self.kernel_size/2)-1, steps=int((self.kernel_size/2))) # computing only half of the window
        self.window_=0.54-0.46*torch.cos(2*math.pi*n_lin/self.kernel_size);
 
        # self.window is eq. (8)
 
 
        # (1, kernel_size/2)
        n = (self.kernel_size - 1) / 2.0
        self.n_ = 2*math.pi*torch.arange(-n, 0).view(1, -1) / self.sample_rate # Due to symmetry, I only need half of the time axes
 
        # self.n_ = 2 * pi * n / sr
 
 
    def forward(self, waveforms):
        """
        Parameters
        ----------
        waveforms : `torch.Tensor` (batch_size, 1, n_samples)
            Batch of waveforms.
        Returns
        -------
        features : `torch.Tensor` (batch_size, out_channels, n_samples_out)
            Batch of sinc filters activations.
        """
 
        self.n_ = self.n_.to(waveforms.device)
 
        # print('self.n_', self.n_)
        # print('--------------------')
 
        self.window_ = self.window_.to(waveforms.device)
 
        low = self.min_low_hz  + torch.abs(self.low_hz_) # eq. (5) + make sure low >= min_low_hz
 
        high = torch.clamp(low + self.min_band_hz + torch.abs(self.band_hz_),self.min_low_hz,self.sample_rate/2) # eq. (6) + make sure band has length >= min_band_hz
        band=(high-low)[:,0] # g[0] / 2
 
        # print('band', band)
        # print('low', low)
        # print('high', high)
        # print('--------------------')
 
        f_times_t_low = torch.matmul(low, self.n_) # 2 * pi * n * freq / sr
        f_times_t_high = torch.matmul(high, self.n_)
 
        # print('times_t_low', f_times_t_low)
        # print('times_t_high', f_times_t_high)
        # print('--------------------')
 
        # 2*f2*sinc(2*pi*f2*n) - 2*f1*sinc(2*pi*f1*n)
        # 2*f2*sin(2*pi*f2*n) / (2 * pi * f2 * n) - 2*f1*sin(2*pi*f1*n) / (2 * pi * f1 * n)
        # sin(2*pi*f2*n) / (pi n) - sin(2*pi*f1*n) / (pi n)
 
        # (2 / sr) * sin(f_times_t_high) / self.n_ -  (2 / sr) * sin(f_times_t_low) / self.n_
        # (1/ sr) * (sin(f_times_t_high) - sin(f_times_t_low)) / (self.n_ / 2)
 
        # sr * correct eq. (4)
 
        # because self.n_ = 2 * pi * n / sr
 
        band_pass_left=((torch.sin(f_times_t_high)-torch.sin(f_times_t_low))/(self.n_/2))*self.window_ # Equivalent of Eq.4 of the reference paper (SPEAKER RECOGNITION FROM RAW WAVEFORM WITH SINCNET). I just have expanded the sinc and simplified the terms. This way I avoid several useless computations.
        band_pass_center = 2*band.view(-1,1) # g[0] = 2 * (f2 - f1) = 2 * band, w[0] = 1
        band_pass_right= torch.flip(band_pass_left,dims=[1]) # g[n] = g[-n]
 
        # print('band_pass_left', band_pass_left)
        # print('band_pass_center', band_pass_center)
        # print('---------------')
 
 
        band_pass=torch.cat([band_pass_left,band_pass_center,band_pass_right],dim=1) # create full g[n]
 
 
        band_pass = band_pass / (2*band[:,None]) # normalize so the max is 1
 
        # band_pass_left = sr * correct (4)
        # center = freq (not scaled via division) = sr * scaled_freq
        # thus, after normalization we will divide all by sr and get normalized correct(4) + normalized center
 
 
        self.filters = (band_pass).view(
            self.out_channels, 1, self.kernel_size)
 
        return F.conv1d(waveforms, self.filters, stride=self.stride,
                        padding=self.padding, dilation=self.dilation,
                         bias=None, groups=1) # x[n] * g[n]
 


### res blocK

In [5]:
class Res_block(nn.Module):
    def __init__(self, nb_filts: int, first=False):
        super().__init__()
        # for first res block in net
        self.first = first

        if not self.first:
            self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
            
        self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(1, 1),
                               stride=1)
        self.selu = nn.SELU(inplace=True)

        self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
        self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
                               out_channels=nb_filts[1],
                               kernel_size=(2, 3),
                               padding=(0, 1),
                               stride=1)

        if nb_filts[0] != nb_filts[1]:
            self.downsample = True
            self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
                                             out_channels=nb_filts[1],
                                             padding=(0, 1),
                                             kernel_size=(1, 3),
                                             stride=1)

        else:
            self.downsample = False
            
        self.mp = nn.MaxPool2d((1, 3))

    def forward(self, x: Tensor) -> Tensor:
        # original sample save
        original_x = x 
        
        if not self.first:
            out = self.bn1(x)
            out = self.selu(out)
            
        else:
            out = x
            
        out = self.conv1(x)
        
        out = self.bn2(out)
        out = self.selu(out)
        out = self.conv2(out)
        
        # resizing original sample in case of size diff with out
        if self.downsample:
            original_x = self.conv_downsample(original_x)

        # adding original sample in the end of res block
        out += original_x
        out = self.mp(out)
        return out


### encoder

In [6]:
class Encoder(nn.Module):
    def __init__(self, d_args):
        super().__init__()
        
        # list of some args of original model. Full list: https://github.com/clovaai/aasist/blob/main/config/AASIST.conf
        self.d_args = d_args
        filts = d_args["filts"]

        self.sinc_conv = SincConv(out_channels=filts[0],
                                  kernel_size=d_args["first_conv"],
        )
        
        self.first_bn = nn.BatchNorm2d(num_features=1)
        self.selu = nn.SELU(inplace=True)

        self.res_encoder = nn.Sequential(
            Res_block(nb_filts=filts[1], first=True),
            Res_block(nb_filts=filts[2]),
            Res_block(nb_filts=filts[3]),
            Res_block(nb_filts=filts[4]),
            Res_block(nb_filts=filts[4]),
            Res_block(nb_filts=filts[4])
        )
        
    def forward(self, x):
        x = x.unsqueeze(1)
        
        x = self.sinc_conv(x)
        x = x.unsqueeze(dim=1)
        
        x = F.max_pool2d(torch.abs(x), (3, 3))
        x = self.first_bn(x)
        x = self.selu(x)

        
        e = self.res_encoder(x) # [batch, filter, spec,sequence]
        # print(e.shape)
        return e


## CapsNet

### chanelWiseStats

In [7]:
class ChanelWiseStats(nn.Module):
    """
    The class that computes mean and standart deviation
    in input data acrocc channels
    """
    def __init__(self):
        super(ChanelWiseStats, self).__init__()
    
    def forward(self, x):
        x = x.view(x.data.shape[0], x.data.shape[1],
                   x.data.shape[2]*x.data.shape[3])
        
        mean = torch.mean(x, 2)
        std = torch.std(x,2)
        
        return torch.stack((mean, std), dim=1)
    

class View(nn.Module):
    """
    Auxiliary class
    """
    def __init__(self, *shape):
        super(View, self).__init__()
        self.shape = shape

    def forward(self, input):
        return input.view(self.shape)

### Spatial attention

In [8]:
class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x,1)[0].unsqueeze(1), torch.mean(x,1).unsqueeze(1)), dim=1)


class SpatialAttention(nn.Module):
    def __init__(self):
        super(SpatialAttention,self).__init__()
        kernel_size=7
        self.compress = ChannelPool()
        self.spatial = nn.Sequential(
            nn.Conv2d(2,1, kernel_size=7, stride=1, padding=(kernel_size-1)//2),
            nn.BatchNorm2d(1, eps=1e-5, momentum=0.1, affine=True),
            nn.ReLU()
        )
    
    def forward(self,x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = F.sigmoid(x_out)
        return x*scale
    

### Primary capsules

In [9]:
class PrimaryCapsules(nn.Module):
    """
    This class create capsules and makes
    forward propagation through them
    """
    def __init__(self, num_capsules=10):
        super(PrimaryCapsules, self).__init__()

        self.num_capsules = num_capsules

        self.capsules = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(64),
                nn.ReLU(),
                SpatialAttention(),
                nn.Conv2d(64,16, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm2d(16),
                nn.ReLU(),
                SpatialAttention(),
                ChanelWiseStats(),
                nn.Conv1d(2,8, kernel_size=5, stride=2, padding=2),
                nn.BatchNorm1d(8),
                nn.Conv1d(8, 1, kernel_size=3, stride=1, padding=1),
                nn.BatchNorm1d(1),
                View(-1, 8))
            
            for _ in range(num_capsules)]
            )

    
    def forward(self, x):
        results = [capsule(x) for capsule in self.capsules]
        result = torch.stack(results, dim=-1)
        return result

In [32]:
summary(PrimaryCapsules(), (1,64,23,29))

Layer (type:depth-idx)                        Output Shape              Param #
PrimaryCapsules                               [1, 8, 10]                --
├─ModuleList: 1-1                             --                        --
│    └─Sequential: 2-1                        [1, 8]                    --
│    │    └─Conv2d: 3-1                       [1, 64, 23, 29]           36,928
│    │    └─BatchNorm2d: 3-2                  [1, 64, 23, 29]           128
│    │    └─ReLU: 3-3                         [1, 64, 23, 29]           --
│    │    └─SpatialAttention: 3-4             [1, 64, 23, 29]           101
│    │    └─Conv2d: 3-5                       [1, 16, 23, 29]           9,232
│    │    └─BatchNorm2d: 3-6                  [1, 16, 23, 29]           32
│    │    └─ReLU: 3-7                         [1, 16, 23, 29]           --
│    │    └─SpatialAttention: 3-8             [1, 16, 23, 29]           101
│    │    └─ChanelWiseStats: 3-9              [1, 2, 16]                --
│    │    

### Routing mechanism

In [10]:
class RoutingMechanism(nn.Module):
    def __init__(self,
                 gpu_id,
                 num_input_capsules,
                 num_output_capsules,
                 data_in,
                 data_out,
                 num_iterations=2):
        super(RoutingMechanism, self).__init__()

        self.gpu_id = gpu_id
        self.num_iterations = num_iterations
        self.route_weights = nn.Parameter(torch.randn(
            num_output_capsules, num_input_capsules,
            data_out, data_in
        ))
    
    def squash(self, x, dim):
        squared_norm = (x ** 2).sum(dim=dim, keepdim=True)
        scale = squared_norm / (1 + squared_norm)
        return scale * x / (torch.sqrt(squared_norm))
    
    def forward(self, x, random, dropout, random_size=0.01):
        # x[batch, data, in_caps]

        x = x.transpose(2, 1)
        # x[batch, in_caps, data]

        if random:
            noise = Variable(random_size * torch.randn(*self.route_weights.size()))
            if self.gpu_id >=0:
                noise = noise.cuda(self.gpu_id)
            route_weights = self.route_weights + noise     #w_ji + rand(size(w_ji))
        else:
            route_weights = self.route_weights

        priors = route_weights[:, None, :, :, :] @ x[None, :, :, :, None]

        #route_weights[out_caps, 1,     in_caps, data_out, data_in]
        #x            [1,        batch, in_caps, data_in, 1]
        #priors       [out_caps, batch, in_caps, data_out, 1]
        priors = self.squash(priors.transpose(1, 0), dim=3)  #sqush(w_ij u_i)
        # priors[batch, out_caps, in_caps, data_out, 1]

        if dropout > 0.0:
            drop = Variable(torch.FloatTensor(*priors.size())).bernoulli(1.0-dropout)
            if self.gpu_id >= 0:
                drop = drop.cuda(self.gpu_id)
            priors = priors * drop


        logits = Variable(torch.zeros(*priors.size()))       #initialization b_ij = 
        #logits[batch,out_caps,in_caps, data_out,1]

        if self.gpu_id >= 0:
            logits = logits.cuda(self.gpu_id)

        num_iterations = self.num_iterations

        for i in range(num_iterations):                       #for r iterations do
            probs = F.softmax(logits, dim=2)                  #a_j = softmax(b_j)
            outputs = self.squash((probs*priors).sum(dim=2, keepdim=True), dim=3) #v_j=squash(a_ji * u_ji)

            
            logits = priors * outputs                         #b_ij = v_j*u_ji
        
        #outputs[b, out_caps, 1, data_out, 1]
        outputs = outputs.squeeze()

        if len(outputs.shape) == 3:
            outputs = outputs.transpose(2,1).contiguous()
        else:
            outputs = outputs.unsqueeze_(dim=0).transpose(2,1).contiguous()

        return outputs

## net

In [11]:
class CapsuleNet(nn.Module):
    # во время оценки отключить рандом и дропаут, модель перевести в эвал
    def __init__(self, num_class, gpu_id, d_args, num_capsules=3, num_iterations=3):
        super(CapsuleNet, self).__init__()
        
        self.num_class = num_class
        self.extractor = Encoder(d_args)
        self.fea_ext = PrimaryCapsules(num_capsules=num_capsules)
        self.routing_stats = RoutingMechanism(gpu_id=gpu_id, 
                                              num_input_capsules=num_capsules, 
                                              num_output_capsules=2,
                                              data_in=8,
                                              data_out=4,
                                              num_iterations=num_iterations)
    
    def forward(self, x, random=True, dropout=0.05, random_size=0.01):
        z = self.extractor(x)
        z = self.fea_ext(z)
        z = self.routing_stats(z, random, dropout, random_size=0.01)
#         classes = F.softmax(z, dim=-1)
        class_ = z.detach()
        class_ = class_.mean(dim=1)
        return z, class_

# Capsule Cross Entropy

In [12]:
class CapsuleLoss(nn.Module):
    def __init__(self, gpu_id, weight):
        super(CapsuleLoss, self).__init__()
        self.weight = weight
        self.ce = nn.CrossEntropyLoss(weight=self.weight)

        if gpu_id>=0:
            self.weight.cuda(gpu_id)
            self.ce.cuda(gpu_id)
    def forward(self, classes, labels):
        loss_t = self.ce(classes[:,0,:], labels)

        for i in range(classes.size(1) - 1):
            loss_t = loss_t + self.ce(classes[:,i+1,:], labels)

        return loss_t


# Dataset and Dataloader

In [11]:
class ASVspoof2019(Dataset):
    def __init__(self,IDs,dir_path, labels):
        self.IDs = IDs
        self.labels = labels
        self.dir_path = dir_path
        self.cut = 64600

    def __getitem__(self, index):
        path_to_flac = f"{self.dir_path}/flac/{self.IDs[index]}.flac"
        audio, rate = sf.read(path_to_flac)
        x_pad = self.pad_random(audio, self.cut)
        x_inp = Tensor(x_pad)
        return (x_inp, torch.tensor(self.labels[index]), rate)
    
    def __len__(self):
        return len(self.IDs)

    def pad_random(self, x, max_len=64600):
        x_len = x.shape[0]

        if x_len >= max_len:
            stt = np.random.randint(x_len-max_len)
            return x[stt:stt+max_len]

        num_repeats = int(max_len / x_len) + 1
        padded_x = np.tile(x, (num_repeats))[:max_len]
        return padded_x
    
class ASVspoof2019_dev_eval(Dataset):
    def __init__(self,IDs,dir_path, labels):
        self.IDs = IDs
        self.labels = labels
        self.dir_path = dir_path
        self.cut = 64600

    def __getitem__(self, index):
        path_to_flac = f"{self.dir_path}/flac/{self.IDs[index]}.flac"
        audio, rate = sf.read(path_to_flac)
        x_pad = self.pad_random(audio, self.cut)
        x_inp = Tensor(x_pad)
        return x_inp, self.IDs[index], torch.tensor(self.labels[index])
    
    def __len__(self):
        return len(self.IDs)

    def pad_random(self, x, max_len=64600):
        x_len = x.shape[0]

        if x_len > max_len:
            stt = np.random.randint(x_len-max_len)
            return x[stt:stt+max_len]

        num_repeats = int(max_len / x_len) + 1
        padded_x = np.tile(x, (num_repeats))[:max_len]
        return padded_x
 
def get_data_for_dataset(path):
    ids_list = []
    label_list = []
    with open(path,"r") as file:
        for line in file:
            line = line.split()
            id, label = line[1], line[-1]
            ids_list.append(id)
            label = 1 if label == "bonafide" else 0
            label_list.append(label)
    return ids_list, label_list

train_label_path = "LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.train.trn.txt"
train_path_flac = "LA/ASVspoof2019_LA_train"
train_IDs, train_labels = get_data_for_dataset(train_label_path)
 
train_dataset = ASVspoof2019(train_IDs,train_path_flac,train_labels)
# train_loader = DataLoader(
#     train_dataset,
#     batch_size=batch_size,
#     shuffle=True,
#     num_workers=2
# )


dev_label_path = "LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.dev.trl.txt"
dev_path_flac = "LA/ASVspoof2019_LA_dev"
dev_IDs, dev_labels = get_data_for_dataset(dev_label_path)

dev_dataset = ASVspoof2019_dev_eval(dev_IDs, dev_path_flac, dev_labels)
# dev_loader = DataLoader(
#     dev_dataset,
#     batch_size=batch_size,
#     shuffle=False,
#     num_workers=2
# )


eval_label_path = "LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.eval.trl.txt"
eval_path_flac = "LA/ASVspoof2019_LA_eval"
eval_IDs, eval_labels = get_data_for_dataset(eval_label_path)

eval_dataset = ASVspoof2019_dev_eval(eval_IDs, eval_path_flac, eval_labels)
# eval_loader = DataLoader(
#     eval_dataset,
#     batch_size=32,
#     shuffle = False,
#     num_workers =2
#     )

# EER & t-DCf

In [12]:
def compute_det_curve(bonafide_scores, spoof_scores):
    """
    function, that comuputes FRR and FAR with their thresholds

    args:
        bonafide_scores: score for bonafide speech
        spoof_scores: score for spoofed speech
    output:
        frr: false rejection rate
        far: false acceptance rate
        threshlods: thresholds for frr and far
    todo:
        rewrite to torch
        create tests
    """
    # number of scores
    n_scores = bonafide_scores.size + spoof_scores.size

    # bona fide scores and spoof scores
    all_scores = np.concatenate((bonafide_scores, spoof_scores))

    # label of bona fide score is 1
    # label of spoof score is 0
    labels = np.concatenate((np.ones(bonafide_scores.size), np.zeros(spoof_scores.size)))

    # indexes of sorted scores in all scores
    indices = np.argsort(all_scores, kind='mergesort')
    # sort labels based on scores
    labels = labels[indices]

    # Compute false rejection and false acceptance rates

    # tar cumulative value
    tar_trial_sums = np.cumsum(labels)
    nontarget_trial_sums = spoof_scores.size - (np.arange(1, n_scores + 1) - tar_trial_sums)

    # false rejection rates
    frr = np.concatenate((np.atleast_1d(0), tar_trial_sums / bonafide_scores.size)) 

    # false acceptance rates 
    far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums / spoof_scores.size))

    # Thresholds are the sorted scores
    thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices]))  

    return frr, far, thresholds


def compute_eer(bonafide_scores, spoof_scores):
    """ 
    Returns equal error rate (EER) and the corresponding threshold.
    args:
        bonafide_scores: score for bonafide speech
        spoof_scores: score for spoofed speech
    output:
        eer: equal error rate
        threshold: index, where frr=far
    todo:
        rewrite to torch
        create tests
    """
    frr, far, thresholds = compute_det_curve(bonafide_scores, spoof_scores)

    # absolute differense between frr and far
    abs_diffs = np.abs(frr - far)

    # index of minimal absolute difference
    min_index = np.argmin(abs_diffs)

    # equal error rate
    eer = np.mean((frr[min_index], far[min_index]))
    return eer, thresholds[min_index]

def produce_evaluation_file(data_loader,
                            model,
                            device,
                            loss_fn,
                            save_path,
                            trial_path,
                            random = False,
                            dropout = 0):
    """
    Create file, that need to give in function calculcate_t-DCF_EER
    args:
        data_loader: loader, that gives batch to model
        model: model, that calculate what we need
        device: device for data, model
        save_path: path where file shoud be saved
        trial_path: path from LA CM protocols
    todo:
        this function must return result: tensor of uid, src, key, score
    """

    # turning model into evaluation mode
    model.eval()

    # read file ASVspoof2019.LA.cm.<dev/train/eval>.trl.txt
    with open(trial_path, "r") as file_trial:
        trial_lines = file_trial.readlines()
    
    # list of utterance id and list of score for appropiate uid
    fname_list = []
    score_list = []
    current_loss = 0
    # inference
    for batch_x, utt_id, batch_y in progressbar(data_loader, prefix='computing cm score'):
        batch_x, batch_y = batch_x.to(device), batch_y.to(device)
        with torch.no_grad():
            
            # first is hidden layer, second is result
            classes, batch_out = model.forward(batch_x, random=random, dropout=dropout)

            # 1 - for bonafide speech class
            batch_score = (batch_out[:, 1]).data.cpu().numpy().ravel()
            loss = loss_fn(classes, batch_y)
            current_loss += loss.item() / len(data_loader)

        # add outputs
        fname_list.extend(utt_id)
        score_list.extend(batch_score.tolist())
    assert len(trial_lines) == len(fname_list) == len(score_list)

    # saving results
    with open(save_path, "w") as fh:

        # fn - uid, sco - score, trl - trial_lines
        for fn, sco, trl in zip(fname_list, score_list, trial_lines):

            # first - id of speaker
            # utt_id - utterance id
            # third - "-"
            # src - type of spoof if exist
            # key - spoof or bonafide
            _, utt_id, _, src, key = trl.strip().split(' ')
            assert fn == utt_id
            # format: utterance id - type of spoof attack - key - score
            fh.write("{} {} {} {}\n".format(utt_id, src, key, sco))
    print("Scores saved to {}".format(save_path))
    
    return current_loss
    
def obtain_asv_error_rates(tar_asv, non_asv, spoof_asv, asv_thresholds):
    """
    Calculate false alarm rate and miss rate for asv scores

    args:
        tar_asv: scores for asv targets
        non_asv: scores for asv nontargets
        spoof_asv: scores for asv spoofed
        asv_threshold: threshold for asv EER between targets and non_targets
    returns:
        Pfa_asv: false alarm rate for asv
        Pmiss_asv: false miss rate for asv
        Pmiss_spoof_asv: rate of rejection spoofs in asv
    todo:
        rewrite to torch
    """
    Pfa_asv = sum(non_asv >= asv_thresholds) / non_asv.size
    Pmiss_asv = sum(tar_asv < asv_thresholds) / tar_asv.size

    if spoof_asv.size == 0:
        Pmiss_spoof_asv = None
    else:
        Pmiss_spoof_asv = np.sum(spoof_asv < asv_thresholds) / spoof_asv.size

    return Pfa_asv, Pmiss_asv, Pmiss_spoof_asv

def compute_tDCF(bonafide_score_cm,spoof_score_cm, Pfa_asv,
    Pmiss_asv, Pmiss_spoof_asv, cost_model):
    """
    This function computes min t-DCF value
    
    args:
        bonafide_score_cm: score for bonafide speech from CM system
        spoof_score_cm: score for spoofed speech from CM systn
        Pfa_asv: false alarm rate from asv system
        Pmiss_asv: miss rate from asv sustem
        Pmiss_spoof_asv: miss rate for spoof utterance from asv system
        cost_model: dict of parameters for t-DCF
    output:
        t-DCF: computed value
        CM_threshold: threshold for EER between Pmiss_cm and Pfa_cm
    todo:
        rewrite to torch
    """

    # obtain miss and false alarm rate of cm
    Pmiss_cm, Pfa_cm, CM_thresholds = compute_det_curve(
        bonafide_score_cm, spoof_score_cm
    )

    # Constants
    C1 = cost_model['Ptar'] * (cost_model['Cmiss_cm'] - cost_model['Cmiss_asv'] * Pmiss_asv) - \
        cost_model['Pnon'] * cost_model['Cfa_asv'] * Pfa_asv
    
    C2 = cost_model['Cfa_cm'] * cost_model['Pspoof'] * (1 - Pmiss_spoof_asv)

    # obtain t-DCF curve for all thresholds
    tDCF = C1 * Pmiss_cm + C2 * Pfa_cm

    # normalized t-DCF
    tDCFnorm = tDCF / np.minimum(C1, C2)

    return tDCFnorm, CM_thresholds


def calculate_eer_tdcf(cm_scores_file, asv_score_file, output_file, printout=True):
    """
    Function cimputes tdcf, eer for CM sustem, and also compute
    EER of each type of attack and write them into file
    args:
        cm_scores_file: file from produce_evaluation file
        asv_score_file: file from organizers
        ouput_file: file where information of each type of attack for eval dataset will be
        printout: print this file or not
    output:
        EER * 100: percentage of equal error rate for CM system
        min_tDCF: value of t-DCF for CM system
    todo:
        rewrite into torch
        return array instead of create file
    """
    # cm data from file
    cm_data = np.genfromtxt(cm_scores_file, dtype=str)

    # type of spoof attack
    cm_sources = cm_data[:,1]

    # spoof or bonafide speech
    cm_keys = cm_data[:, 2]

    # score for utterance
    cm_scores = cm_data[:, 3].astype(np.float64)

    # score for bonafide speech
    bona_cm = cm_scores[cm_keys == 'bonafide']

    # score for spoofed utterance
    spoof_cm = cm_scores[cm_keys == 'spoof']

    # equal error rate
    EER, _ = compute_eer(bona_cm, spoof_cm)




    # fix parameters for t-DCF
    cost_model = {
        'Pspoof': 0.05,
        'Ptar': 0.9405,
        'Pnon': 0.0095,
        'Cmiss': 1,
        'Cfa': 10, ###########
        'Cmiss_asv': 1,
        'Cfa_asv': 10,
        'Cmiss_cm': 1,
        'Cfa_cm' : 10,
    }

    # load organizers' ASV scores
    asv_data = np.genfromtxt(asv_score_file, dtype=str)

    # keys: target, non-target, spoof
    asv_keys = asv_data[:, 1]

    # score for each utterance
    asv_scores = asv_data[:, 2].astype(np.float64)


    # target, non-target and spoof scores from the ASV scores
    tar_asv = asv_scores[asv_keys == 'target']
    non_asv = asv_scores[asv_keys == 'nontarget']
    spoof_asv = asv_scores[asv_keys == 'spoof']
    
    #EER of the standalone systems and fix ASV operation point to
    eer_asv, asv_threshold = compute_eer(tar_asv, non_asv)

    # generate attack types from A07 to A19
    attack_types = [f'A{_id:02d}' for _id in range(7,20)]

    # compute eer for each type of attack
    if printout:
        spoof_cm_breakdown = {
            attack_type: cm_scores[cm_sources == attack_type]
            for attack_type in attack_types
        }

        eer_cm_breakdown = {
            attack_type: compute_eer(bona_cm, spoof_cm_breakdown[attack_type])[0]
            for attack_type in attack_types
        }
    [Pfa_asv, Pmiss_asv, Pmiss_spoof_asv] = obtain_asv_error_rates(
        tar_asv,
        non_asv,
        spoof_asv,
        asv_threshold
    )

    # Compute t-DCF
    tDCF_curve, CM_thresholds = compute_tDCF(
        bona_cm,
        spoof_cm,
        Pfa_asv,
        Pmiss_asv,
        Pmiss_spoof_asv,
        cost_model
    )

    # Minimum t-DCF
    min_tDCF_index = np.argmin(tDCF_curve)
    min_tDCF = tDCF_curve[min_tDCF_index]
    # write results into file
    if printout:
        with open(output_file, 'w') as f_res:
            f_res.write('\nCM SYSTEM\n')
            f_res.write("""\tEER\t\t= {:8.9f} % 
            (Equal error rate for countermeasure)\n""".format(EER*100)
            )
            f_res.write('\nTANDEM\n')
            f_res.write('\tmin-tDCF\t\t= {:8.9f}\n'.format(min_tDCF))
            f_res.write('\nBREAKDOWN CM SYSTEM\n')
            for attack_type in attack_types:
                _eer = eer_cm_breakdown[attack_type] * 100
                f_res.write(
                    f'\tEER {attack_type}\t\t= {_eer:8.9f} % (Equal error rate for {attack_type})\n'
                )
        os.system(f"cat {output_file}")
    return EER * 100, min_tDCF

# test summary

In [15]:
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mslenser0[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [22]:
wandb.init(
    project="CapsNet",
    config = {
        'batch_size': 32,
        'd_args': {
            "nb_samp": 64600, 
            "first_conv": 128,
            "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
        },
        'num_class': 2,
        'gpu_id': 1,
        'num_capsules': 30,
        'epoches': 40,
        'opt': 'AdaBound',
        'lr': 0.0001,
        'weight_decay': 0.00001,
        'random': True,
        'dropout': 0.05,
        'random_size': 0.01,
        'num_iterations': 2,
        'gamma': 0.5,
        'step_size': 10
    }
)

VBox(children=(Label(value='0.026 MB of 0.033 MB uploaded\r'), FloatProgress(value=0.7781200058114195, max=1.0…

0,1
dev_eer,▄▇█▃▃▄▁▄▃▂▂▂▃▂▂▄▁▄▃
dev_loss,█▂█▂▂▆▁▁▁▂▁▁▁▁▁▁▁▁▁
dev_tdcf,▄█▆▃▃▄▁▅▃▂▂▂▄▃▂▄▁▅▃
train_loss,▇█▃▃▅▂▂▂▂▂▁▁▁▁▁▁▁▁▁

0,1
dev_eer,3.76758
dev_loss,9.26193
dev_tdcf,0.12549
train_loss,9.33968


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011112069245427847, max=1.0…

In [23]:
model = CapsuleNet(num_class=wandb.config['num_class'],
                   gpu_id=wandb.config['gpu_id'],
                   d_args=wandb.config['d_args'],
                   num_capsules=wandb.config['num_capsules'],
                   num_iterations=wandb.config['num_iterations']).to(device)

In [14]:
model_parameters = filter(lambda p: p.requires_grad, model.parameters())
params = sum([np.prod(p.size()) for p in model_parameters])
params

1612724

In [13]:
d_args= {
            "nb_samp": 64600, 
            "first_conv": 128,
            "filts": [70, [1, 32], [32, 32], [32, 64], [64, 64]],
        }

model = CapsuleNet(num_class=2,
                   gpu_id=-1,
                   d_args=d_args,
                   num_capsules=30,
                   num_iterations=2).to(device)

summary(model, (1,64600), mode='train', col_names=(
                "input_size",
                "output_size",
                "num_params",
                "mult_adds",
                "trainable",
            ), depth=4)

Layer (type:depth-idx)                             Input Shape               Output Shape              Param #                   Mult-Adds                 Trainable
CapsuleNet                                         [1, 64600]                [1, 4, 2]                 --                        --                        True
├─Encoder: 1-1                                     [1, 64600]                [1, 64, 23, 29]           --                        --                        True
│    └─SincConv: 2-1                               [1, 1, 64600]             [1, 70, 64472]            140                       --                        True
│    └─BatchNorm2d: 2-2                            [1, 1, 23, 21490]         [1, 1, 23, 21490]         2                         2                         True
│    └─SELU: 2-3                                   [1, 1, 23, 21490]         [1, 1, 23, 21490]         --                        --                        --
│    └─Sequential: 2-4               

In [25]:
torch.rand(1,1,32).view(-1,32).shape

torch.Size([1, 32])

# Train

In [27]:
!nvidia-smi

Fri Feb 23 10:59:22 2024       
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 545.23.08              Driver Version: 545.23.08    CUDA Version: 12.3     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|   0  NVIDIA A100-SXM4-40GB          Off | 00000000:4C:00.0 Off |                    0 |
| N/A   37C    P0              76W / 400W |  28442MiB / 40960MiB |     13%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+
|   1  NVIDIA A100-SXM4-40GB          Off | 00000000:88:00.0 Off |  

In [26]:
del model
del optimizer
del loss_fn
gc.collect()

3929

In [28]:
batch_size = wandb.config['batch_size']
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=2
)
dev_loader = DataLoader(
    dev_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=2
)


d_args = wandb.config['d_args']

device = 'cuda:1' if torch.cuda.is_available else 'cpu'
model = CapsuleNet(num_class=wandb.config['num_class'],
                   gpu_id=wandb.config['gpu_id'],
                   d_args=d_args,
                   num_capsules=wandb.config['num_capsules'],
                   num_iterations=wandb.config['num_iterations']).to(device)

optimizer = AdaBound(model.parameters(),
                     lr=wandb.config['lr'],
                     weight_decay=wandb.config['weight_decay'])

scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                            step_size=wandb.config['step_size'], 
                                            gamma=wandb.config['gamma'])

loss_fn = CapsuleLoss(gpu_id=wandb.config['gpu_id'], weight=torch.FloatTensor([0.1,0.9]))

epoches = wandb.config['epoches']

best_score = 2
best_state = None


for epoch in range(epoches):
    # train part
    train_loss = 0
    prefix = '%s / %s, best_score %s ' % (epoch + 1, epoches, best_score)
    for data, label, _ in progressbar(train_loader, prefix=prefix):
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        classes, class_ = model.forward(data,
                                        random=wandb.config['random'],
                                        dropout=wandb.config['dropout'],
                                        random_size=wandb.config['random_size'])

        loss = loss_fn(classes, label)
        train_loss += loss.item() / len(train_loader)
        loss.backward()
        optimizer.step()
    scheduler.step()
        
    # val_part
    dev_loss = produce_evaluation_file(dev_loader, model, device, loss_fn, "pruduced_file.txt", dev_label_path)
    eer, tdcf = calculate_eer_tdcf('pruduced_file.txt',
                              "LA/ASVspoof2019_LA_asv_scores/ASVspoof2019.LA.asv.dev.gi.trl.scores.txt",
                              None,
                              printout=False)

    if best_score > eer:
        best_score = eer
        best_state = deepcopy(model.state_dict())
        path = 'best_checkpoint' + str(best_score) + ".pth"
        torch.save(best_state, path)

    clear_output()

    metrics = {
        "train_loss": train_loss,
        "dev_loss": dev_loss,
        "dev_eer": eer,
        "dev_tdcf": tdcf
    }
    wandb.log(metrics)

wandb.finish()
    

        

4 / 40, best_score 1.8837882093302667 [████████....................................................] 110/794 time 00:57.44 / 05:57.18

IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



16 / 40, best_score 1.492453609745629 [███████.....................................................] 104/794 time 00:54.34 / 06:00.52

KeyboardInterrupt: 

In [None]:
for epoch in range(26):
    # train part
    train_loss = 0
    prefix = '%s / %s, best_score %s ' % (epoch + 1, epoches, best_score)
    for data, label, _ in progressbar(train_loader, prefix=prefix):
        data, label = data.to(device), label.to(device)
        optimizer.zero_grad()
        classes, class_ = model.forward(data,
                                        random=wandb.config['random'],
                                        dropout=wandb.config['dropout'],
                                        random_size=wandb.config['random_size'])

        loss = loss_fn(classes, label)
        train_loss += loss.item() / len(train_loader)
        loss.backward()
        optimizer.step()
    scheduler.step()
        
    # val_part
    dev_loss = produce_evaluation_file(dev_loader, model, device, loss_fn, "pruduced_file.txt", dev_label_path)
    eer, tdcf = calculate_eer_tdcf('pruduced_file.txt',
                              "LA/ASVspoof2019_LA_asv_scores/ASVspoof2019.LA.asv.dev.gi.trl.scores.txt",
                              None,
                              printout=False)

    if best_score > eer:
        best_score = eer
        best_state = deepcopy(model.state_dict())
        path = 'best_checkpoint' + str(best_score) + ".pth"
        torch.save(best_state, path)

    clear_output()

    metrics = {
        "train_loss": train_loss,
        "dev_loss": dev_loss,
        "dev_eer": eer,
        "dev_tdcf": tdcf
    }
    wandb.log(metrics)

7 / 40, best_score 1.492453609745629 [████████████████████████████████████████....................] 531/794 time 04:39.65 / 02:18.51

# Test

In [34]:
model.load_state_dict(torch.load("best_checkpoint1.2491311420651727.pth"))

<All keys matched successfully>

In [35]:
loss_fn = CapsuleLoss(gpu_id=1, weight=torch.FloatTensor([0.1,0.9])).to('cuda:1')

In [36]:
eval_label_path = "LA/ASVspoof2019_LA_cm_protocols/ASVspoof2019.LA.cm.eval.trl.txt"
eval_path_flac = "LA/ASVspoof2019_LA_eval"
eval_IDs, eval_labels = get_data_for_dataset(eval_label_path)

eval_dataset = ASVspoof2019_dev_eval(eval_IDs, eval_path_flac, eval_labels)
eval_loader = DataLoader(
    eval_dataset,
    batch_size=32,
    shuffle = False,
    num_workers =2
    )

In [37]:
dev_loss = produce_evaluation_file(eval_loader, model, device, loss_fn, "pruduced_file_eval.txt", eval_label_path)
eer, tdcf = calculate_eer_tdcf('pruduced_file_eval.txt',
                              "LA/ASVspoof2019_LA_asv_scores/ASVspoof2019.LA.asv.eval.gi.trl.scores.txt",
                              None,
                              printout=False)

computing cm score[████████████████████████████████████████████████████████████] 2227/2227 time 03:31.75 / 00:00.00

Scores saved to pruduced_file_eval.txt


In [38]:
eer, tdcf

(2.2701871020649453, 0.07593021271811795)