In [None]:
import os
import sys
import pandas as pd
import torch
import torchaudio
import torch.nn.functional as F
import torchaudio.transforms as T
import librosa
import matplotlib.pyplot as plt
from IPython.display import Audio, display
import json
import numpy as np
import math
import random
import seaborn as sns
sns.set_theme()

from typing import Optional

from datetime import timedelta

from src.utils import (
    create_dataset, plot_spectrogram,
    RandomClip, extract_logmel
)
from src.datasets import VoxCelebDataModule
from src.models import (
    SEBlock, SpeakerRecognitionModel, ResNetBlock, build_efficientnetv2,
    SEResNetBlock, conv1x1, conv3x3, ResNet34SE, ResNet20,
    SelfAttentivePooling
)
from torch import nn
from sklearn.decomposition import PCA

from src.resnetse import ResNetSE, SEBasicBlock, ResNetSEV2

from src.losses import SubCenterAAMSoftmaxLoss
from sklearn.cluster import KMeans
from sklearn.metrics import roc_curve, accuracy_score
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split
from sklearn.neural_network import MLPClassifier

In [None]:
class ResNet34SEV2(SpeakerRecognitionModel):
    """Squeeze-and-excitation ResNet-34-style speaker-embedding model.

    The residual layout follows ResNet-34 [1] (3, 4, 6 and 3 blocks per
    stage), adapted from the official PyTorch Vision ResNet, with these
    modifications:

    * every residual block is an ``SEResNetBlock`` with ``se_ratio=8``;
    * stage widths are halved (32, 64, 128, 256 channels);
    * the stem uses a 3x3 stride-1 convolution and the stem max-pool is
      defined but not applied in ``forward``;
    * the global-pool / classifier head is replaced by
      ``SelfAttentivePooling`` followed by a linear projection to
      ``self.embeddings_dim``.

    Parameters
    ----------
    n_mels : int
        Number of mel bins (height of the input spectrogram). Used to size
        the attentive-pooling layer and the embedding projection.
    **kwargs
        Forwarded unchanged to ``SpeakerRecognitionModel.__init__``.

    References
    ----------
        [1] K. He, X. Zhang, S. Ren and J. Sun, "Deep 
        Residual Learning for Image Recognition", 2016 IEEE 
        Conference on Computer Vision and Pattern Recognition 
        (CVPR), 2016, pp. 770-778.
    """
    def __init__(self, n_mels=80, **kwargs) -> None:
        super().__init__(**kwargs)

        # Stem width; ``_make_sequence`` advances this as each stage is built.
        self.current_channels = 32
        # Height of the frequency axis after the three stride-2 stages
        # (conv3_x, conv4_x, conv5_x). Each of those stages opens with a
        # 1x1 stride-2 downsample conv, which yields ceil(H / 2) rows, and
        # the residual add forces the main branch to match. Three stages
        # therefore give ceil(n_mels / 8) rows. Flooring (int(n_mels / 8))
        # mismatched the pooled feature size whenever n_mels was not a
        # multiple of 8.
        # NOTE(review): assumes conv1x1 uses no padding and that
        # SelfAttentivePooling(k) emits ``channels * k`` features per
        # example — confirm against src.models.
        self.attn_expansion = math.ceil(n_mels / 8)

        self.conv1 = nn.Conv2d(
            1, 
            self.current_channels, 
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(self.current_channels)
        self.relu = nn.ReLU()
        # Defined but intentionally NOT applied in ``forward``. Applying it
        # would halve the frequency axis once more and invalidate
        # ``attn_expansion``.
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.conv2_x = self._make_sequence(32, num_blocks=3)
        self.conv3_x = self._make_sequence(64, num_blocks=4, stride=2)
        self.conv4_x = self._make_sequence(128, num_blocks=6, stride=2)
        self.conv5_x = self._make_sequence(256, num_blocks=3, stride=2)
        self.attn_pool = SelfAttentivePooling(self.attn_expansion)

        # ``current_channels`` is 256 here (set by the last _make_sequence).
        self.embeddings = nn.Linear(
            self.current_channels * self.attn_expansion, 
            self.embeddings_dim
        )

        # Xavier-normal init for conv weights; BatchNorm starts as the
        # identity affine map (weight 1, bias 0).
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.xavier_normal_(m.weight)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Base-class setup hooks (defined in SpeakerRecognitionModel).
        self._set_optimizers()
        self._set_hyperparams()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of spectrograms to speaker embeddings.

        Parameters
        ----------
        x : torch.Tensor
            Input batch with a single channel (``conv1`` takes 1 input
            channel); the notebook feeds ``(batch, 1, n_mels, time)``.

        Returns
        -------
        torch.Tensor
            Embeddings of shape ``(batch, self.embeddings_dim)``, assuming
            ``attn_pool`` keeps the batch axis first.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        # No max-pool here: ``attn_expansion`` assumes only the three
        # stride-2 stages below reduce the frequency axis.

        x = self.conv2_x(x)
        x = self.conv3_x(x)
        x = self.conv4_x(x)
        x = self.conv5_x(x)
        x = self.attn_pool(x)
        x = torch.flatten(x, start_dim=1)
        x = self.embeddings(x)

        return x

    def _make_sequence(
        self,
        out_channels: int,
        num_blocks: int,
        stride: int = 1
    ) -> nn.Sequential:
        """Build one residual stage of ``num_blocks`` SE blocks.

        Only the first block uses ``stride`` and, when the stride or the
        channel count changes, a 1x1-conv + BatchNorm projection shortcut.
        As a side effect, ``self.current_channels`` is set to
        ``out_channels`` so the next stage chains correctly.

        Parameters
        ----------
        out_channels : int
            Output channels of every block in the stage.
        num_blocks : int
            Number of ``SEResNetBlock`` modules in the stage.
        stride : int, optional
            Stride of the first block (default 1).

        Returns
        -------
        nn.Sequential
            The stacked blocks.
        """
        downsample = None

        # downsample when we increase the dimension, using
        # the 1x1 convolution option, as described in the
        # ResNet paper.
        if stride != 1 or self.current_channels != out_channels:
            downsample = nn.Sequential(
                conv1x1(self.current_channels, out_channels, stride),
                nn.BatchNorm2d(out_channels),
            )

        layers = [
            SEResNetBlock(
                in_channels=self.current_channels, 
                out_channels=out_channels, 
                stride=stride, 
                downsample=downsample,
                se_ratio=8
            )
        ]
        self.current_channels = out_channels
        for _ in range(1, num_blocks):
            layers.append(
                SEResNetBlock(
                    in_channels=self.current_channels,
                    out_channels=out_channels,
                    se_ratio=8
                )
            )

        return nn.Sequential(*layers)

In [None]:
# Model for 80-bin log-mel input with a 5-class head (presumably 5 speakers).
resnet34 = ResNet34SEV2(num_classes=5, n_mels=80)

In [None]:
# Dummy batch: 4 examples, 1 channel, 80 mel bins, 301 frames on the last axis.
a = torch.randn(4, 1, 80, 301)

In [None]:
# Shape check only: run without autograd so no computation graph is built.
with torch.no_grad():
    embeddings = resnet34(a)
assert embeddings.shape == (a.shape[0], resnet34.embeddings_dim), embeddings.shape
embeddings.shape