In [2]:
import os
import sys
import pandas as pd
import torch
import torchaudio
import torch.nn.functional as F
import torchaudio.transforms as T
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio, display
import json
import numpy as np
import math
import seaborn as sns
sns.set_theme()

from typing import Optional

from datetime import timedelta

from src.utils import (
    create_dataset, plot_spectrogram,
    RandomClip, extract_logmel
)
from src.datasets import VoxCelebDataModule
from src.models import SEBlock, SpeakerRecognitionModel, ResNetBlock, build_efficientnetv2
from torch import nn
from sklearn.decomposition import PCA

from src.resnetse import ResNetSE, SEBasicBlock, ResNetSEV2

from src.losses import SubCenterAAMSoftmaxLoss
from sklearn.cluster import KMeans
from sklearn.metrics import roc_curve, accuracy_score
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

In [94]:
class SERes2Block(nn.Module):
    """Variant of the SE-Res2Block described in [1].

    We modified the architecture to follow more closely the Res2Block
    described in [2], using 2d convolutions. We also inverted the order of
    ReLU and batch normalization (here every stage is conv -> BN -> ReLU).

    The block applies a 1x1 convolution, splits the result along the
    channel dimension into ``scale`` equal groups ``x_0 ... x_{s-1}`` and
    computes

        y_0 = x_0
        y_1 = K_1(x_1)
        y_i = K_i(x_i + y_{i-1})    for i >= 2

    where each ``K_i`` (i >= 1) is a dilated 3x3 convolution whose padding
    equals its dilation, so the spatial size is preserved. The groups are
    concatenated, passed through BN -> ReLU -> 1x1 conv -> BN -> ReLU and
    an SE block, and finally added to the block input (residual
    connection).

    Parameters
    ----------
    n_channels : int
        Number of input and output channels. Must be a positive multiple
        of ``scale``.
    scale : int
        Number of channel groups (the Res2Net "scale"). Must be >= 1.
    dilation : int
        Dilation (and padding) of the 3x3 group convolutions.

    Raises
    ------
    ValueError
        If ``scale < 1`` or ``n_channels`` is not a positive multiple of
        ``scale`` (otherwise the channel split would not line up with the
        per-group convolutions and batch norms).

    References
    ----------
        [1] B. Desplanques et al., "ECAPA-TDNN: Emphasized 
        Channel Attention, Propagation and Aggregation TDNN 
        Based Speaker Verification", Proc. Interspeech 
        2020, 2020, pp. 3830-3834.

        [2] S.-H. Gao et al., "Res2Net: A New Multi-Scale 
        Backbone Architecture", IEEE Transactions on Pattern 
        Analysis and Machine Intelligence, vol. 43, no. 2, 
        2021, pp. 652-662.
    """
    def __init__(
        self,
        n_channels: int,
        scale: int,
        dilation: int
    ) -> None:
        super(SERes2Block, self).__init__()
        # Fail fast: with a non-divisible split torch.split would yield an
        # extra, smaller chunk and the concatenated groups would no longer
        # match bn2's channel count, failing later with a cryptic error.
        if scale < 1 or n_channels < scale or n_channels % scale != 0:
            raise ValueError(
                f"n_channels ({n_channels}) must be a positive multiple "
                f"of scale ({scale})"
            )
        self.scale = scale
        self.conv1 = nn.Conv2d(n_channels, n_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(n_channels)
        self.relu1 = nn.ReLU()

        # One dilated 3x3 conv per group except the first; padding equal to
        # the dilation keeps the spatial size unchanged.
        conv_ls = [
            nn.Conv2d(
                n_channels // scale, 
                n_channels // scale, 
                kernel_size=3,
                padding=dilation,
                dilation=dilation
            )
            for _ in range(scale - 1)
        ]

        # K[0] is the identity: the first group is passed through unchanged.
        self.K = nn.ModuleList([nn.Identity()] + conv_ls)
        self.bn2 = nn.BatchNorm2d(n_channels)
        self.relu2 = nn.ReLU()
        self.conv2 = nn.Conv2d(n_channels, n_channels, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(n_channels)
        self.relu3 = nn.ReLU()
        self.se = SEBlock(n_channels)

        # Also re-initialises any Conv2d/Linear inside the SE block; guard
        # the bias since such layers may have been built with bias=False.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block and return ``se(features) + x`` (residual).

        ``x`` must have ``n_channels`` channels in dim 1 so that the
        residual addition lines up; assumed layout (batch, n_channels, H, W)
        — TODO confirm against callers.
        """
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)

        # Split into `scale` equal channel groups.
        out_ls = torch.split(out, out.size(1) // self.scale, dim=1)
        y_ls = []

        for idx in range(self.scale):
            out_split = out_ls[idx]
            k_fun = self.K[idx]
            if idx <= 1:
                # y_0 = x_0 (identity) and y_1 = K_1(x_1): no previous
                # group output is added.
                y_ls.append(k_fun(out_split))
            else:
                # Hierarchical connection: y_i = K_i(x_i + y_{i-1}).
                prev = y_ls[idx - 1]
                y_ls.append(k_fun(out_split + prev))

        out = torch.cat(y_ls, dim=1)
        out = self.bn2(out)
        out = self.relu2(out)
        out = self.conv2(out)
        out = self.bn3(out)
        out = self.relu3(out)
        out = self.se(out)

        return out + x

class Var_SERes2Block(SERes2Block):
    """SE-Res2Block variant without hierarchical group connections.

    All layers are inherited from ``SERes2Block.__init__``; only the
    forward pass differs. Each channel group produced by the first 1x1
    convolution is transformed by its own operator in ``self.K``
    independently, i.e. group ``i`` is NOT summed with the output of group
    ``i - 1`` before its convolution (as the parent class does).
    """
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Entry stage: 1x1 conv -> BN -> ReLU.
        hidden = self.relu1(self.bn1(self.conv1(x)))

        # Split into equal channel groups and transform each one on its own.
        groups = torch.split(hidden, hidden.size(1) // self.scale, dim=1)
        hidden = torch.cat(
            [op(group) for op, group in zip(self.K, groups)],
            dim=1,
        )

        # Exit stage: BN -> ReLU -> 1x1 conv -> BN -> ReLU -> SE.
        hidden = self.relu2(self.bn2(hidden))
        hidden = self.relu3(self.bn3(self.conv2(hidden)))
        hidden = self.se(hidden)

        # Residual connection around the whole block.
        return hidden + x


class AttentiveStatPooling(nn.Module):
    """Attentive statistics pooling layer, as described in [1].

    A small bottleneck network (``self.seq``) scores every element of the
    input. The scores are softmax-normalised along dim 2 and used as weights
    for a weighted mean and a weighted standard deviation along that same
    dimension. Since the paper worked with MFCC instead of spectrograms, the
    weighted statistics are additionally averaged over the last dimension
    so that one more dimension is removed.

    We also provide a variant of the scoring network built from 1x1
    convolutions (``conv=True``, acting on dim 1) instead of linear layers
    (``conv=False``, acting on the last dimension).

    Parameters
    ----------
    in_features : int
        Size of the dimension the scoring network acts on (the channel
        dimension when ``conv=True``).
    latent_features : int
        Bottleneck width of the scoring network.
    conv : bool, default False
        Use 1x1 ``Conv2d`` layers instead of ``Linear`` layers.
    eps : float, default 1e-6
        Lower bound applied to the variance before the square root, so that
        rounding error (or a constant input) cannot produce NaN.

    References
    ----------
        [1] B. Desplanques et al., "ECAPA-TDNN: Emphasized 
        Channel Attention, Propagation and Aggregation TDNN 
        Based Speaker Verification", Proc. Interspeech 
        2020, 2020, pp. 3830-3834. 
    """
    def __init__(
        self, 
        in_features: int, 
        latent_features: int,
        conv: bool = False,
        eps: float = 1e-6
    ) -> None:
        super(AttentiveStatPooling, self).__init__()
        self.eps = eps
        if conv:
            self.seq = nn.Sequential(
                nn.Conv2d(in_features, latent_features, kernel_size=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(latent_features, in_features, kernel_size=1)
            )
        else:
            # NOTE(review): nn.Linear acts on the last dimension, so this
            # branch expects x.size(-1) == in_features — confirm the layout
            # used by callers of conv=False.
            self.seq = nn.Sequential(
                nn.Linear(in_features, latent_features),
                nn.ReLU(inplace=True),
                nn.Linear(latent_features, in_features)
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` into concatenated [mean, std] statistics.

        Parameters
        ----------
        x : torch.Tensor
            4-D feature map (batch, C, H, W). Weights are normalised and
            the statistics taken over H (dim 2), then averaged over W.
            NOTE(review): which of H/W is frequency vs. time depends on the
            caller — confirm.

        Returns
        -------
        torch.Tensor
            ``(batch, 2 * C)`` tensor: weighted means followed by weighted
            standard deviations.
        """
        # Normalise the scores over dim 2 -- the same dimension the weighted
        # sums run over -- so the weights along it sum to 1 and the sums
        # below are a proper weighted mean and second moment. Normalising
        # over any other dimension makes E[x^2] - E[x]^2 meaningless and
        # possibly negative (-> NaN after the square root).
        attn_weights = F.softmax(self.seq(x), dim=2)
        mean = torch.sum(attn_weights * x, dim=2).mean(-1)
        second_moment = torch.sum(attn_weights * x ** 2, dim=2).mean(-1)
        # Var = E[x^2] - E[x]^2 can still dip slightly below 0 from
        # floating-point rounding; clamp before the square root.
        var = (second_moment - mean ** 2).clamp(min=self.eps)
        std = torch.sqrt(var)

        return torch.cat([mean, std], dim=1)

class Var_ECAPA(SpeakerRecognitionModel):
    """Variant of the ECAPA-TDNN model described in [1].

    We omit the last batch normalization layer and adopt a different
    SE-Res2Block (without hierarchical group connections) and attentive
    statistics pooling layer. We also work on spectrograms instead of MFCC,
    using 2d convolutions throughout.

    Parameters
    ----------
    n_channels : int, default 248
        Channel width of every convolutional stage. Should be a multiple
        of ``scale``, since each SE-Res2Block splits it into ``scale``
        equal groups.
    scale : int, default 8
        Number of channel groups in each SE-Res2Block.
    **kwargs
        Forwarded to ``SpeakerRecognitionModel``. NOTE(review): the base
        class is expected to set ``embeddings_dim`` and ``num_classes``,
        which are read below — confirm against src.models.

    References
    ----------
        [1] B. Desplanques et al., "ECAPA-TDNN: Emphasized 
        Channel Attention, Propagation and Aggregation TDNN 
        Based Speaker Verification", Proc. Interspeech 
        2020, 2020, pp. 3830-3834. 
    """
    def __init__(
        self, 
        n_channels: int = 248,
        scale: int = 8, 
        **kwargs
    ) -> None:
        super(Var_ECAPA, self).__init__(**kwargs)

        # Stem on the single-channel input. NOTE(review): kernel_size=5 with
        # padding=1 shrinks each spatial dim by 2 (padding=2 would keep it)
        # — confirm this is intended.
        self.conv1 = nn.Conv2d(1, n_channels, kernel_size=5, padding=1)
        self.bn1 = nn.BatchNorm2d(n_channels)
        self.relu1 = nn.ReLU()
        # Three stacked SE-Res2Blocks with increasing dilation.
        self.se1 = Var_SERes2Block(n_channels, scale, dilation=2)
        self.se2 = Var_SERes2Block(n_channels, scale, dilation=3)
        self.se3 = Var_SERes2Block(n_channels, scale, dilation=4)
        # Multi-layer feature aggregation over the concatenated block
        # outputs. A 1x1 kernel needs no padding; padding would add a border
        # of bias-only values that the attentive pooling below would treat
        # as real positions.
        self.conv2 = nn.Conv2d(n_channels * 3, n_channels, kernel_size=1)
        self.attn_pool = AttentiveStatPooling(n_channels, n_channels // 10, conv=True)
        # The pooling layer concatenates [mean, std] -> 2 * n_channels.
        self.bn2 = nn.BatchNorm1d(n_channels * 2)
        self.embeddings = nn.Linear(n_channels * 2, self.embeddings_dim)
        # Classification head; not applied in forward(), which returns
        # embeddings.
        self.clf = nn.Linear(self.embeddings_dim, self.num_classes)

        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_uniform_(m.weight)
                # Nested modules (e.g. inside the SE blocks) may have no bias.
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        self._set_optimizers()
        self._set_hyperparams()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute speaker embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Single-channel input batch, assumed (batch, 1, n_mels, n_frames)
            — TODO confirm against the data module.

        Returns
        -------
        torch.Tensor
            Embeddings of shape (batch, self.embeddings_dim). ``self.clf``
            is not applied here.
        """
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out_se1 = self.se1(out)
        out_se2 = self.se2(out_se1)
        out_se3 = self.se3(out_se2)
        # Aggregate the outputs of all three blocks along the channel dim.
        out = torch.cat([out_se1, out_se2, out_se3], dim=1)
        out = self.conv2(out)
        out = self.attn_pool(out)
        out = self.bn2(out)
        out = self.embeddings(out)

        return out

In [96]:
# Seed the torch RNG so the weight initialisation (and the random dummy
# input drawn in the next cell) is reproducible on Restart & Run All.
torch.manual_seed(0)
ecapa = Var_ECAPA(num_classes=8)

In [97]:
# Random dummy batch for a forward-pass shape check. Presumably laid out as
# (batch, channel, n_mels, n_frames) — TODO confirm against the data module.
a = torch.randn(4, 1, 80, 301)

In [99]:
# Smoke-test the forward pass without side effects. eval() stops BatchNorm
# from folding this random batch into its running statistics, and no_grad()
# avoids building an autograd graph. The previous train/eval mode is restored.
was_training = ecapa.training
ecapa.eval()
try:
    with torch.no_grad():
        res = ecapa(a)
finally:
    ecapa.train(was_training)
res.shape

tensor(False)