In [102]:
import os
import sys
import pandas as pd
import torch
import torchaudio
import torch.nn.functional as F
import torchaudio.transforms as T
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio, display
import json
import numpy as np
import math
import seaborn as sns
sns.set_theme()

from typing import Optional

from datetime import timedelta

from src.utils import (
    create_dataset, plot_spectrogram,
    RandomClip, extract_logmel
)
from src.datasets import VoxCelebDataModule
from src.models import (
    SEBlock, SpeakerRecognitionModel, ResNetBlock, build_efficientnetv2,
    SEResNetBlock, conv1x1, conv3x3
)
from torch import nn
from sklearn.decomposition import PCA

from src.resnetse import ResNetSE, SEBasicBlock, ResNetSEV2

from src.losses import SubCenterAAMSoftmaxLoss
from sklearn.cluster import KMeans
from sklearn.metrics import roc_curve, accuracy_score
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

In [132]:

class SelfAttentionPooling(nn.Module):
    """Self Attention Pooling (SAP) layer, following [1], with GELU used as
    the non-linearity in place of tanh.

    Given a rank-4 input ``x`` whose dim 2 has size ``n_mels`` (assumed to be
    ``(batch, channels, freq, time)`` as produced by a 2-D conv backbone),
    every position along the last axis gets a scalar score per
    ``(batch, channel)``. The scores are softmax-normalised along that last
    axis and used as weights for a weighted sum, which removes the axis.

    Parameters
    ----------
    n_mels : int
        Size of dim 2 of the input; despite the name this is the size of the
        feature map's frequency axis at the pooling point, not the number of
        mel bins of the raw spectrogram.

    References
    ----------
        [1] W. Cai, J. Chen and M. Li, "Exploring the Encoding 
        Layer and Loss Function in End-to-End Speaker and 
        Language Recognition System", 2018,
        https://arxiv.org/abs/1804.05160
    """
    def __init__(
        self,
        n_mels
    ) -> None:
        super().__init__()
        # Construction order is kept so parameter initialisation consumes the
        # global RNG in the same sequence.
        self.linear = nn.Linear(n_mels, n_mels)
        self.attention = nn.Parameter(torch.FloatTensor(size=(n_mels, 1)))
        self.gelu = nn.GELU()
        # On a 4-D tensor Softmax2d normalises over dim -3.
        self.softmax = nn.Softmax2d()

        # The parameter above is allocated uninitialised; fill it here.
        nn.init.xavier_normal_(self.attention)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool ``x`` over its last axis.

        Parameters
        ----------
        x : torch.Tensor
            Rank-4 tensor ``(B, C, F, T)`` with ``F == n_mels``.

        Returns
        -------
        torch.Tensor
            Tensor ``(B, C, F)``: attention-weighted sum over ``T``.
        """
        # (B, C, F, T) -> (B, T, C, F): put T where Softmax2d normalises
        # (dim -3 once the score axis is appended) and F last for the Linear.
        frames_first = x.permute(0, 3, 1, 2)
        scores = self.gelu(self.linear(frames_first)) @ self.attention  # (B, T, C, 1)
        # Normalise over T, then move T back to the end: (B, C, 1, T).
        weights = self.softmax(scores).permute(0, 2, 3, 1)
        return (x * weights).sum(dim=-1)

class ResNet34SE(SpeakerRecognitionModel):
    """SE-ResNet34 speaker-embedding model with self-attention pooling.

    The backbone follows the ResNet34 layout of [1] (a 7x7 stride-2 stem,
    max pooling, then stages of 3, 4, 6 and 3 residual blocks) and is a
    simplified and slightly modified version of the official PyTorch Vision
    ResNet: every block is an ``SEResNetBlock`` and the activation is GELU.
    The last feature map is reduced over its final axis by
    ``SelfAttentionPooling``, flattened, and projected to
    ``self.embeddings_dim`` by a linear layer.

    Parameters
    ----------
    n_mels : int, default 80
        Size of the input's frequency axis. The input is expected as
        ``(batch, 1, n_mels, frames)``. ``n_mels`` determines the size of the
        frequency axis left after the five stride-2 reductions, which sizes
        the attention-pooling and embedding layers.
    **kwargs
        Forwarded unchanged to ``SpeakerRecognitionModel`` (e.g.
        ``num_classes``).

    Raises
    ------
    ValueError
        If ``n_mels`` is smaller than 1.

    References
    ----------
        [1] K. He, X. Zhang, S. Ren and J. Sun, "Deep 
        Residual Learning for Image Recognition", 2016 IEEE 
        Conference on Computer Vision and Pattern Recognition 
        (CVPR), 2016, pp. 770-778.
    """

    # Stride-2 reductions applied to the frequency axis: conv1, maxpool and
    # the first block of conv3_x, conv4_x and conv5_x.
    _NUM_FREQ_HALVINGS = 5

    def __init__(self, n_mels=80, **kwargs) -> None:
        if n_mels < 1:
            raise ValueError(f"n_mels must be >= 1, got {n_mels}")
        super().__init__(**kwargs)

        self.current_channels = 64
        # Frequency size reaching the attention pooling, derived from n_mels
        # so the pooling/embedding layers match the backbone output for any
        # mel resolution (n_mels=80 gives 80 -> 40 -> 20 -> 10 -> 5 -> 3).
        self.attn_expansion = self._downsampled_size(
            n_mels, self._NUM_FREQ_HALVINGS
        )

        # Layers are created in the original order so weight initialisation
        # consumes the global RNG identically.
        self.conv1 = nn.Conv2d(
            1,
            self.current_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.current_channels)
        self.gelu = nn.GELU()
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2_x = self._make_sequence(64, num_blocks=3)
        self.conv3_x = self._make_sequence(128, num_blocks=4, stride=2)
        self.conv4_x = self._make_sequence(256, num_blocks=6, stride=2)
        self.conv5_x = self._make_sequence(512, num_blocks=3, stride=2)
        # Not used by forward(); kept so existing references to `.pool` work.
        self.pool = nn.AdaptiveAvgPool2d((1, 1))
        self.attn_pool = SelfAttentionPooling(self.attn_expansion)

        # After _make_sequence calls, current_channels is the conv5_x width.
        self.embeddings = nn.Linear(
            self.current_channels * self.attn_expansion,
            self.embeddings_dim
        )

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # kaiming_normal_ has no GELU gain; the ReLU gain is used as
                # an approximation.
                nn.init.kaiming_normal_(
                    m.weight,
                    mode="fan_out",
                    nonlinearity="relu"
                )
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        self._set_optimizers()
        self._set_hyperparams()

    @staticmethod
    def _downsampled_size(size: int, num_halvings: int) -> int:
        """Return ``size`` after ``num_halvings`` round-up halvings.

        Each halving maps ``n`` to ``ceil(n / 2)``, which is what the stem's
        conv (k=7, s=2, p=3) and max pool (k=3, s=2, p=1) do. NOTE(review):
        this assumes the stride-2 SEResNetBlock / conv1x1 shortcut behave the
        same way (3x3 pad 1, 1x1 pad 0) — confirm against src.models.
        """
        for _ in range(num_halvings):
            size = (size + 1) // 2
        return size

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute speaker embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Batch shaped ``(batch, 1, n_mels, frames)``.

        Returns
        -------
        torch.Tensor
            Embeddings shaped ``(batch, self.embeddings_dim)``.
        """
        # Stem: 7x7 stride-2 conv + BN + GELU, then stride-2 max pooling.
        x = self.maxpool(self.gelu(self.bn1(self.conv1(x))))

        # Residual stages; conv3_x..conv5_x start with a stride-2 block.
        x = self.conv2_x(x)
        x = self.conv3_x(x)
        x = self.conv4_x(x)
        x = self.conv5_x(x)

        # (batch, 512, F, T) -> (batch, 512, F) via attention over the last
        # axis, then flatten to (batch, 512 * F) for the projection.
        x = self.attn_pool(x)
        x = torch.flatten(x, start_dim=1)
        return self.embeddings(x)

    def _make_sequence(
        self,
        out_channels: int,
        num_blocks: int,
        stride: int = 1
    ):
        """Build one residual stage of ``num_blocks`` SE-ResNet blocks.

        Only the first block applies ``stride`` and changes the channel
        count; the rest run at stride 1 with ``out_channels`` in and out.
        Side effect: sets ``self.current_channels`` to ``out_channels``.
        """
        downsample = None

        # downsample when we increase the dimension, using
        # the 1x1 convolution option, as described in the
        # ResNet paper.
        if stride != 1 or self.current_channels != out_channels:
            downsample = nn.Sequential(
                conv1x1(self.current_channels, out_channels, stride),
                nn.BatchNorm2d(out_channels),
            )

        layers = [
            SEResNetBlock(
                in_channels=self.current_channels,
                out_channels=out_channels,
                stride=stride,
                downsample=downsample,
                se_ratio=8
            )
        ]
        self.current_channels = out_channels
        layers.extend(
            SEResNetBlock(
                in_channels=out_channels,
                out_channels=out_channels,
                se_ratio=8
            )
            for _ in range(1, num_blocks)
        )

        return nn.Sequential(*layers)

In [133]:
# Seed the global torch RNG before construction so weight initialisation (and
# any random tensors drawn afterwards) is reproducible on Restart & Run All.
torch.manual_seed(0)
resnet = ResNet34SE(n_mels=80, num_classes=5)

In [134]:
# Dummy batch shaped (batch=4, channel=1, n_mels=80, frames=601); 80 matches
# the model's default n_mels, 601 is presumably the number of time frames.
a = torch.randn(4, 1, 80, 601)

In [135]:
# Shape check only: no_grad avoids building an autograd graph through the
# whole network.
with torch.no_grad():
    res = resnet(a)
assert res.shape == (a.shape[0], resnet.embeddings_dim)
res.shape

torch.Size([4, 512])

In [117]:
# Size the standalone pooling layer from the model so it matches the
# frequency axis of the backbone output instead of repeating a magic number.
attn = SelfAttentionPooling(n_mels=resnet.attn_expansion)

In [118]:
# `res` is the 2-D (batch, embedding) output, so flattening from dim 2 fails;
# inspect the 4-D map produced by the convolutional backbone instead.
with torch.no_grad():
    feature_map = nn.Sequential(
        resnet.conv1, resnet.bn1, resnet.gelu, resnet.maxpool,
        resnet.conv2_x, resnet.conv3_x, resnet.conv4_x, resnet.conv5_x,
    )(a)
torch.flatten(feature_map, start_dim=2).shape

torch.Size([4, 512, 57])

In [119]:
# `attn` expects the rank-4 backbone feature map, not the 2-D embedding
# `res`, so recompute that map here and pool it.
with torch.no_grad():
    feature_map = nn.Sequential(
        resnet.conv1, resnet.bn1, resnet.gelu, resnet.maxpool,
        resnet.conv2_x, resnet.conv3_x, resnet.conv4_x, resnet.conv5_x,
    )(a)
    pooled = attn(feature_map)
pooled.shape

torch.Size([4, 512, 3])