In [10]:
import os
import random
import torch
import numpy as np
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torch.distributions import Beta

import pytorch_lightning as pl
import torchvision
from resnest.torch import resnest50
from efficientnet_pytorch import EfficientNet
from sklearn.metrics import accuracy_score
from src.dataset import SpectrogramDataset
import src.configuration as C
from src.metric import LWLRAP
from src.conformer import ConformerBlock
import pytorch_lightning as pl
from torchlibrosa.stft import Spectrogram, LogmelFilterBank
from torchlibrosa.augmentation import SpecAugmentation
from pytorch_lightning.metrics import F1

class ConformerSED(nn.Module):
    """Sound-event-detection model: CNN front-end -> Conformer -> attention pooling.

    The input has 3 channels (first Conv2d) and is laid out as
    (batch_size, channels, freq, frames); the time length is read from dim 3.
    The head predicts 24 classes (``AttBlockV2(128, 24)``) at clip, segment
    and frame level.
    """

    def __init__(self):
        super().__init__()

        self.att_block = AttBlockV2(128, 24, activation="sigmoid")

        # NOTE(review): the conv stack below halves the time axis seven times
        # (two (2,2) pools + five (1,2) pools, i.e. roughly 1/128 before the
        # valid-conv shrinkage), so a ratio of 30 looks inconsistent with it;
        # the remainder is filled by pad_framewise_output. Confirm the value.
        self.interpolate_ratio = 30  # Downsampled ratio
        self.convblock = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=(3,3)), nn.ReLU(), nn.MaxPool2d((2,2)),
            nn.Conv2d(16, 32, kernel_size=(3,3)), nn.ReLU(), nn.MaxPool2d((2,2)),
            nn.Conv2d(32, 64, kernel_size=(3,3)), nn.ReLU(), nn.MaxPool2d((1,2)),
            nn.Conv2d(64, 128, kernel_size=(3,3)), nn.ReLU(), nn.MaxPool2d((1,2)),
            nn.Conv2d(128, 128, kernel_size=(3,3)), nn.ReLU(), nn.MaxPool2d((1,2)),
            nn.Conv2d(128, 128, kernel_size=(3,3)), nn.ReLU(), nn.MaxPool2d((1,2)),
            nn.Conv2d(128, 128, kernel_size=(3,3)), nn.ReLU(), nn.MaxPool2d((1,2))
        )
        self.conformerblock = ConformerBlock(dim=128)
        # forward() applies fc1 to the per-frame 128-dim features; without this
        # layer every forward pass raised AttributeError. 128 -> 128 keeps the
        # width that att_block (in_features=128) expects.
        self.fc1 = nn.Linear(128, 128, bias=True)

    def forward(self, input):
        """Run the model on a batch.

        Args:
          input: tensor laid out as (batch_size, channels, freq, frames).
        Returns:
          dict with "clipwise_output", "framewise_output", "segmentwise_output",
          "logit", "framewise_logit" and "segmentwise_logit".
        """
        frames_num = input.size(3)

        # (batch_size, 128, freq', frames') after the conv stack
        x = self.convblock(input)
        # NOTE(review): ConformerBlock receives this 4-D tensor unchanged.
        # Conformer blocks commonly expect (batch, seq, dim); confirm that
        # src.conformer accepts (batch, channels, freq, frames). If it does not,
        # average over freq first and feed (batch, frames', 128).
        x = self.conformerblock(x)

        # Average out the frequency axis -> (batch_size, channels, frames)
        x = torch.mean(x, dim=2)

        # Smoothing along the last (time) axis; kernel 3 / stride 1 / padding 1
        # keeps the length unchanged.
        x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        x = x1 + x2

        x = F.dropout(x, p=0.5, training=self.training)
        x = x.transpose(1, 2)  # (batch, channels, frames) -> (batch, frames, channels)
        x = F.relu_(self.fc1(x))
        x = x.transpose(1, 2)  # (batch, frames, channels) -> (batch, channels, frames)
        x = F.dropout(x, p=0.5, training=self.training)

        (clipwise_output, norm_att, segmentwise_output) = self.att_block(x)
        # Class scores before the sigmoid, computed once and reused below.
        cla_logit = self.att_block.cla(x)  # (batch, classes, frames)
        # Clip-level logit: attention-weighted sum of the un-activated scores.
        logit = torch.sum(norm_att * cla_logit, dim=2)
        segmentwise_logit = cla_logit.transpose(1, 2)  # (batch, frames, classes)
        segmentwise_output = segmentwise_output.transpose(1, 2)  # (batch, frames, classes)

        # Upsample along time by repetition, then bring the time length to
        # frames_num by padding with the last frame.
        framewise_output = interpolate(segmentwise_output, self.interpolate_ratio)
        framewise_output = pad_framewise_output(framewise_output, frames_num)

        framewise_logit = interpolate(segmentwise_logit, self.interpolate_ratio)
        framewise_logit = pad_framewise_output(framewise_logit, frames_num)

        output_dict = {
            "clipwise_output": clipwise_output,
            "framewise_output": framewise_output,
            "segmentwise_output": segmentwise_output,
            "logit": logit,
            "framewise_logit": framewise_logit,
            "segmentwise_logit": segmentwise_logit
        }

        return output_dict

In [11]:

def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero its bias if it has one.

    Args:
      layer: a module exposing a ``weight`` tensor. A missing attribute
        ``bias`` or ``bias is None`` leaves the bias untouched.
    """
    nn.init.xavier_uniform_(layer.weight)

    layer_bias = getattr(layer, "bias", None)
    if layer_bias is not None:
        layer_bias.data.fill_(0.)


def init_bn(bn):
    """Reset a batch-norm layer to the identity affine transform.

    Sets ``bn.bias`` to 0 and then ``bn.weight`` to 1, in place.
    """
    for param, value in ((bn.bias, 0.), (bn.weight, 1.0)):
        param.data.fill_(value)


def init_weights(model):
    """Initialise one module according to a substring of its class name.

    Operates on a single module, so it can be passed to ``nn.Module.apply``.
      * ``*Conv2d*``    -> xavier-uniform weight (gain sqrt(2)), zero bias
      * ``*BatchNorm*`` -> weight ~ N(1, 0.02), zero bias
      * ``*GRU*``       -> orthogonal init of every parameter with rank > 1
      * ``*Linear*``    -> weight ~ N(0, 0.01), zero bias
    Other modules are left untouched. Bias-free layers (``bias=False``) and
    non-affine BatchNorm (``affine=False``) are handled instead of crashing.
    """
    def zero_bias(module):
        # Layers built without a bias expose ``bias = None``.
        bias = getattr(module, "bias", None)
        if bias is not None:
            bias.data.fill_(0)

    classname = model.__class__.__name__
    if "Conv2d" in classname:
        nn.init.xavier_uniform_(model.weight, gain=np.sqrt(2))
        zero_bias(model)
    elif "BatchNorm" in classname:
        if model.weight is not None:  # None when affine=False
            model.weight.data.normal_(1.0, 0.02)
        zero_bias(model)
    elif "GRU" in classname:
        # Orthogonal init for the weight matrices; 1-D biases are left as-is.
        for weight in model.parameters():
            if weight.dim() > 1:
                nn.init.orthogonal_(weight.data)
    elif "Linear" in classname:
        model.weight.data.normal_(0, 0.01)
        zero_bias(model)

def interpolate(x: torch.Tensor, ratio: int):
    """Upsample along time by repeating every time step ``ratio`` times.

    Compensates for the temporal resolution lost to downsampling in a CNN.
    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, how many copies of each time step to emit
    Returns:
      upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    # Unpacking keeps the strict 3-D input requirement.
    _n_batch, _n_steps, _n_classes = x.shape
    # Copies of each step are consecutive: [t0, t0, ..., t1, t1, ...].
    return torch.repeat_interleave(x, ratio, dim=1)

# Fit the time axis to the number of input frames by repeating the last frame.
def pad_framewise_output(framewise_output: torch.Tensor, frames_num: int):
    """Pad (or trim) framewise_output along time to exactly ``frames_num`` frames.

    Missing frames are filled with copies of the last frame. If the input is
    already longer than ``frames_num`` (e.g. an upsampling ratio that
    overshoots), the trailing extra frames are dropped instead of raising.
    Args:
      framewise_output: (batch_size, time_steps, classes_num)
      frames_num: int, target number of frames
    Outputs:
      output: (batch_size, frames_num, classes_num)
    """
    n_missing = frames_num - framewise_output.shape[1]
    if n_missing < 0:
        # A negative repeat count would raise; keep the first frames_num frames
        # and return a fresh tensor like the padding path does.
        return framewise_output[:, :frames_num, :].clone()

    # Last frame tiled n_missing times along time (empty when n_missing == 0).
    pad = framewise_output[:, -1:, :].repeat(1, n_missing, 1)

    # (batch_size, frames_num, classes_num)
    output = torch.cat((framewise_output, pad), dim=1)

    return output


class ConvBlock(nn.Module):
    """Two (3x3 conv -> BatchNorm -> ReLU) stages followed by 2-D pooling.

    The convolutions use stride 1 and padding 1, so only the pooling in
    ``forward`` changes the spatial size.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        def conv3x3(n_in: int, n_out: int) -> nn.Conv2d:
            # bias=False: each conv feeds a BatchNorm, whose shift term makes a
            # conv bias redundant.
            return nn.Conv2d(
                in_channels=n_in,
                out_channels=n_out,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False)

        # Registration order conv1, conv2, bn1, bn2 keeps parameter/state_dict order.
        self.conv1 = conv3x3(in_channels, out_channels)
        self.conv2 = conv3x3(out_channels, out_channels)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Xavier-initialise both convs, then reset both BatchNorms to identity."""
        for conv in (self.conv1, self.conv2):
            init_layer(conv)
        for bn in (self.bn1, self.bn2):
            init_bn(bn)

    def forward(self, input, pool_size=(2, 2), pool_type='avg'):
        """Apply both conv stages, then pool.

        Args:
          input: 4-D tensor accepted by nn.Conv2d.
          pool_size: kernel size for the pooling step.
          pool_type: 'max', 'avg' or 'avg+max' (sum of both poolings).
        Raises:
          Exception: for any other ``pool_type``.
        """
        hidden = F.relu_(self.bn1(self.conv1(input)))
        hidden = F.relu_(self.bn2(self.conv2(hidden)))

        if pool_type == 'max':
            return F.max_pool2d(hidden, kernel_size=pool_size)
        if pool_type == 'avg':
            return F.avg_pool2d(hidden, kernel_size=pool_size)
        if pool_type == 'avg+max':
            averaged = F.avg_pool2d(hidden, kernel_size=pool_size)
            maxed = F.max_pool2d(hidden, kernel_size=pool_size)
            return averaged + maxed
        raise Exception('Incorrect argument!')


class AttBlock(nn.Module):
    """Attention pooling over time; attention logits are clamped to [-10, 10].

    Args:
      in_features: input channels (n_in).
      out_features: number of output classes.
      activation: "linear" or "sigmoid", applied to the per-step class scores.
      temperature: stored on the module but not used by forward().
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation="linear",
                 temperature=1.0):
        super().__init__()

        self.activation = activation
        self.temperature = temperature
        self.att = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        self.cla = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

        # NOTE(review): bn_att is created and initialised but not used in forward().
        self.bn_att = nn.BatchNorm1d(out_features)
        self.init_weights()

    def init_weights(self):
        """Xavier-initialise both 1x1 convs and reset bn_att to identity."""
        init_layer(self.att)
        init_layer(self.cla)
        init_bn(self.bn_att)

    def forward(self, x):
        """Pool class scores over time with softmax attention.

        Args:
          x: (n_samples, n_in, n_time)
        Returns:
          (pooled, norm_att, cla): pooled is (n_samples, out_features);
          norm_att and cla are (n_samples, out_features, n_time), with norm_att
          summing to 1 over n_time.
        """
        norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation ("linear" or "sigmoid").

        Raises:
          ValueError: for any other activation name.
        """
        if self.activation == 'linear':
            return x
        elif self.activation == 'sigmoid':
            return torch.sigmoid(x)
        # Falling through would return None and only fail later in forward()
        # with an opaque TypeError; report the bad setting directly.
        raise ValueError(f"Unsupported activation: {self.activation!r}")


class AttBlockV2(nn.Module):
    """Attention pooling over time; attention logits are squashed with tanh.

    Args:
      in_features: input channels (n_in).
      out_features: number of output classes.
      activation: "linear" or "sigmoid", applied to the per-step class scores.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation="linear"):
        super().__init__()

        self.activation = activation
        self.att = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        self.cla = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

        self.init_weights()

    def init_weights(self):
        """Xavier-initialise both 1x1 convs and zero their biases."""
        init_layer(self.att)
        init_layer(self.cla)

    def forward(self, x):
        """
        Args:
          x: (n_samples, n_in, n_time)
        Outputs:
          x: (n_samples, classes_num) attention-weighted class scores
          norm_att: (n_samples, classes_num, n_time) attention weights, summing
            to 1 over n_time
          cla: (n_samples, classes_num, n_time) per-step class scores after
            the activation
        """
        # Project to class count, squash to [-1, 1] with tanh, softmax over time.
        norm_att = torch.softmax(torch.tanh(self.att(x)), dim=-1)
        # Separate 1x1 conv of the same shape as att, followed by the activation.
        cla = self.nonlinear_transform(self.cla(x))
        # Element-wise product summed over time -> (n_samples, classes_num)
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation ("linear" or "sigmoid").

        Raises:
          ValueError: for any other activation name.
        """
        if self.activation == 'linear':
            return x
        elif self.activation == 'sigmoid':
            return torch.sigmoid(x)
        # Falling through would return None and only fail later in forward()
        # with an opaque TypeError; report the bad setting directly.
        raise ValueError(f"Unsupported activation: {self.activation!r}")


In [12]:
# Instantiate the Conformer SED model; its layer tree is displayed in the next cell.
model = ConformerSED()

In [13]:
# Display the module hierarchy of the model (rendered as the cell output).
model

ConformerSED(
  (att_block): AttBlockV2(
    (att): Conv1d(128, 24, kernel_size=(1,), stride=(1,))
    (cla): Conv1d(128, 24, kernel_size=(1,), stride=(1,))
  )
  (convblock): Sequential(
    (0): Conv2d(3, 16, kernel_size=(3, 3), stride=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=(2, 2), stride=(2, 2), padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False)
    (9): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1))
    (10): ReLU()
    (11): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), padding=0, dilation=1, ceil_mode=False)
    (12): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1))
    (13): ReLU()
    (14): MaxPool2d(kernel_size=(1, 2), stride=(1, 2), 