<a href="https://colab.research.google.com/github/mans00rahmed/Automatic-Subtitle-File-Creation/blob/main/train_cnn14decisionlevel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train your model

In [None]:
import cv2
import os
import random
import time
import warnings

import librosa
import numpy as np
import pandas as pd

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as data

#from catalyst.dl import SupervisedRunner, State, CallbackOrder, Callback, CheckpointCallback

In [5]:

def seed_all(s:int=42) -> None: 
    """Seed every random number generator used during training.

    Seeds Python's ``random`` module, NumPy and PyTorch (CPU and all CUDA
    devices) and configures cuDNN for reproducible runs.

    Args:
        s: Seed value shared by all generators.
    """
    random.seed(s)
    np.random.seed(s)
    # Exported for child processes; the hash seed of the running interpreter
    # was already fixed when it started.
    os.environ["PYTHONHASHSEED"] = str(s)
    torch.manual_seed(s)
    # Seed every visible GPU, not only the current one (no-op without CUDA).
    torch.cuda.manual_seed_all(s)  # type: ignore
    torch.backends.cudnn.deterministic = True  # type: ignore
    # The auto-tuner may pick a different convolution algorithm on each run,
    # so it has to be disabled for reproducible results.
    torch.backends.cudnn.benchmark = False  # type: ignore

def seed_dataset(s:int=42) -> None:
    """Re-seed Python's built-in ``random`` generator with ``s``.

    Only the ``random`` module is touched; NumPy and PyTorch are left alone.
    """
    random.seed(a=s)
    

class TrainingDirs(object):
    """
    Resolve the working directories used for training.

    The dataset is looked up under ``<parent of the current directory>/data/<dsname>``.

    Attributes:
        dataset_folder: Folder of the dataset.
        train_folder: ``dataset_folder / 'train'`` when ``pre_test`` is true,
            otherwise the dataset folder itself.
        test_folder: ``dataset_folder / 'test'`` when ``pre_test`` is true,
            otherwise ``None``.
    """
    
    def __init__(self, dsname: str, pre_test: bool) -> None:
        """
        Args:
            dsname: Dataset folder, joined onto the ``data`` input root.
            pre_test: Whether the dataset already has 'train' and 'test' sub-folders.

        Raises:
            AssertionError: If the dataset folder does not exist.
        """
        # Imported here so the class does not depend on a star import executed
        # in a later cell to provide ``Path``.
        from pathlib import Path

        super().__init__()
        ROOT = Path(os.path.split(os.path.abspath(""))[0])
        INPUT_ROOT = ROOT / 'data'
        TARGET_AUDIO_DIR = INPUT_ROOT / dsname
        print(f"Working with dataset under {TARGET_AUDIO_DIR}.")
        assert os.path.exists(TARGET_AUDIO_DIR), "Input dataset folder does not exist."
        
        self.dataset_folder = TARGET_AUDIO_DIR
        self.train_folder = TARGET_AUDIO_DIR / 'train' if pre_test else TARGET_AUDIO_DIR
        self.test_folder = TARGET_AUDIO_DIR / 'test' if pre_test else None

In [6]:
import os
# Absolute path of the parent of the current working directory.
ROOT_PATH_ABS = os.path.dirname(os.path.abspath(""))

In [7]:
from fastai.imports import *

In [8]:

class ModelConfig(object):
    """Hyper-parameters of the sound event detection (SED) model."""

    # Audio front-end settings plus the number of output classes.
    sed_model_config = dict(
        sample_rate=32000,
        window_size=1024,
        hop_size=320,
        mel_bins=64,
        fmin=50,
        fmax=14000,
        classes_num=527,
    )
    
class DatasetConfig(object):
    """
    Settings used when building a dataset from source material.

    Attributes:
        dataset_clip_time(int): Length of each dataset clip in seconds. Default 2.
        dataset_sample_rate(int): Sample rate of the dataset clips. Default 32000.
        dataset_audio_format(str): Audio file format of the clips. Default "wav".
        sub_encoding(str): Subtitle text encoding; "utf-8" is recommended for
            Aegisub subtitle support.
    """

    # Clip duration, in seconds.
    dataset_clip_time = 2

    # Sample rate of the clips.
    dataset_sample_rate = 32000

    # File format of the clips.
    dataset_audio_format = "wav"

    # Subtitle encoding (recommended value).
    sub_encoding = "utf-8"
    

In [4]:
from csrc.utils import TrainingDirs as TD
from csrc.configurations import DatasetConfig as DC
from csrc.configurations import ModelConfig as MC
from csrc.utils import seed_dataset, seed_all 

ModuleNotFoundError: ignored

## Train configurations

In [9]:
# Debugging aid: ask CUDA to launch kernels synchronously so that an error is
# reported at the call that caused it.

os.environ.update({"CUDA_LAUNCH_BLOCKING": "1"})

In [19]:
### The folder of your dataset (handed to TrainingDirs).
DATASET = '/content/drive/My Drive/Auto-Subtitle-File-Generation/ASFG-train/data/standard-p2-32khz'

### Test split settings.
### BUILD_TEST: if False, the held-out test files are merged back into the training pool
###             before the train/validation split.
### PREBUILD_TEST: whether the dataset folder already contains 'train' and 'test' sub-folders.
### TEST_RATIO: 1 / TEST_RATIO of the files are held out when no test folder was prepared.
BUILD_TEST = False
PREBUILD_TEST = False
TEST_RATIO = 5

### 1 / VALID_RATIO of the training pool is used as the validation set.
VALID_RATIO = 5
### Whether to shuffle the training file list before the validation split.
SHUFFLE = True

### Clip length that will be used for training.
### Defaults to the audio clip length of the dataset.
PERIOD = DatasetConfig.dataset_clip_time
print(f"Training clip length (seconds): {PERIOD}")

### Batch size for training. For example: 8gb GPU for 5s clips - batch size 32.
BS = 48

### Training epochs.
EPOCHS = 30

### Weights file path used for training.
### By default located under the weights folder.
WEIGHTS_PATH = "./weights/Cnn14_DecisionLevelAtt_mAP0.425.pth"

### Default path to store your model.
LOG_DIR = "./train/logs/sp3-3/"

Training clip length (sencods): 2


In [20]:
# Random seeding.
# Changing the seed changes the order produced by `random.shuffle` when the
# file list is shuffled for the train/validation split.

SEED = 42
seed_all(SEED)
# seed_dataset re-seeds Python's `random` module, which seed_all has already
# seeded with the same value.
seed_dataset(SEED)

## Process

In [21]:
# Resolve the dataset, training and test folders for this run.

dirs = TrainingDirs(DATASET, PREBUILD_TEST)
DATASET_FOLDER, TRAIN_FOLDER, TEST_FOLDER = (
    dirs.dataset_folder,
    dirs.train_folder,
    dirs.test_folder,
)

### We are training, so the training folder is the working folder.
### Without a dedicated test folder, test files are read from the training folder too.
TRAIN_WORKING_FOLDER = TRAIN_FOLDER
TEST_WORKING_FOLDER = TEST_FOLDER or TRAIN_FOLDER

print(f"FOLDER_FOR_TRAINING: {TRAIN_WORKING_FOLDER}")
print(f"FOLDER_FOR_TEST: {TEST_WORKING_FOLDER}")

Working with dataset under /content/drive/My Drive/Auto-Subtitle-File-Generation/ASFG-train.


AssertionError: ignored

In [None]:
# Train/test split. If no test folder was prepared in advance, hold out the tail of the index-sorted training files as the test set.

def sort_index(x):
    """Sort key: the integer that precedes the first "-" in file name ``x``."""
    prefix, _, _ = x.partition("-")
    return int(prefix)

if not TEST_FOLDER:
    # No prepared test folder: hold out the last 1 / TEST_RATIO of the
    # index-sorted files as the test set.
    all_files = os.listdir(TRAIN_FOLDER)
    all_files.sort(key=sort_index)
    test_index = len(all_files) // TEST_RATIO
    # Split from the front. Negative slicing breaks when test_index == 0:
    # `all_files[-0:]` is the whole list and `all_files[:-0]` is empty, which
    # would put every file in the test set and none in the training set.
    split_at = len(all_files) - test_index
    train_files = all_files[:split_at]
    test_files = all_files[split_at:]
else:
    # Prepared split: take the folders as they are.
    train_files = os.listdir(TRAIN_FOLDER)
    test_files = os.listdir(TEST_FOLDER)

print(f"Files for training: {len(train_files)}")
print(f"Files for testing: {len(test_files)}")

Files for training: 15228
Files for testing: 3806


In [None]:
# Train/Validation split

if SHUFFLE:
    random.shuffle(train_files)

if not BUILD_TEST:
    # No separate test set is kept: the held-out files rejoin the pool.
    # NOTE(review): they are appended after the shuffle, so the validation
    # slice taken from the tail below consists of these files rather than a
    # random sample - confirm this is intended.
    train_files.extend(test_files)

# The last 1 / VALID_RATIO of the pool becomes the validation set.
valid_idx = len(train_files) // VALID_RATIO
# Split from the front: with valid_idx == 0 the negative slices `[-0:]` and
# `[:-0]` would move every file into validation and leave nothing to train on.
split_at = len(train_files) - valid_idx
valid_files = train_files[split_at:]
train_files = train_files[:split_at]

print(f"Files for training: {len(train_files)}")
print(f"Files for validation: {len(valid_files)}")
print(f"Validation file samples: {valid_files[:5]}")

Files for training: 15228
Files for validation: 3806
Validation file samples: ['2181-dallas-buyers-club-eng-1.wav', '2181-mission-impossible-iv-1.wav', '2181-the-dark-knight-rises-eng-0.wav', '2181-the-kingdom-of-heaven-eng-1.wav', '2181-the-kings-speech-eng-0.wav']


## Dataset

In [None]:
from csrc.dataset import PANNsDataset

## Transformer

In [None]:
from csrc.transformers import BaseAug

## Set up dataloader 

In [None]:
loaders = dict(
    # Training set: augmented, shuffled, incomplete last batch dropped.
    train=data.DataLoader(
        PANNsDataset(
            train_files,
            training_folder=TRAIN_WORKING_FOLDER,
            test_folder=TEST_WORKING_FOLDER,
            waveform_transforms=BaseAug,
        ),
        batch_size=BS,
        shuffle=True,
        num_workers=0,  # 0 for windows system.
        pin_memory=True,
        drop_last=True,
    ),
    # Validation set: no augmentation, fixed order, every sample kept.
    valid=data.DataLoader(
        PANNsDataset(
            valid_files,
            training_folder=TRAIN_WORKING_FOLDER,
            test_folder=TEST_WORKING_FOLDER,
            waveform_transforms=None,
        ),
        batch_size=BS,
        shuffle=False,
        num_workers=0,
        pin_memory=True,
        drop_last=False,
    ),
)

## Model

In [None]:
import librosa

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

"""
Preprocessors for building models.
"""

class DFTBase(nn.Module):
    """Base class providing the dense DFT and inverse-DFT matrices."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def _index_products(n):
        """Return the (n, n) integer matrix whose entry [i, j] is i * j."""
        idx = np.arange(n)
        return np.outer(idx, idx)

    def dft_matrix(self, n):
        """Return the (n, n) complex matrix with entries exp(-2*pi*1j/n) ** (i * j)."""
        root = np.exp(-2 * np.pi * 1j / n)
        return np.power(root, self._index_products(n))

    def idft_matrix(self, n):
        """Return the (n, n) complex matrix with entries exp(2*pi*1j/n) ** (i * j)."""
        root = np.exp(2 * np.pi * 1j / n)
        return np.power(root, self._index_products(n))
    
    
class STFT(DFTBase):
    def __init__(self, n_fft=2048, hop_length=None, win_length=None, 
        window="hann", center=True, pad_mode="reflect", freeze_parameters=True):
        """Implementation of STFT with Conv1d. The function has the same output 
        of librosa.core.stft

        Args:
          n_fft: int, FFT size; also the Conv1d kernel size
          hop_length: int or None, stride between frames; defaults to win_length // 4
          win_length: int or None, analysis window length; defaults to n_fft
          window: window specification passed to librosa.filters.get_window
          center: bool, pad the signal by n_fft // 2 on both sides before framing
          pad_mode: "constant" or "reflect", padding mode used when center is True
          freeze_parameters: bool, exclude the DFT kernels from gradient updates
        """
        super(STFT, self).__init__()

        assert pad_mode in ["constant", "reflect"]

        self.n_fft = n_fft
        self.center = center
        self.pad_mode = pad_mode

        # By default, use the entire frame
        if win_length is None:
            win_length = n_fft

        # Set the default hop, if it's not already specified
        if hop_length is None:
            hop_length = int(win_length // 4)

        fft_window = librosa.filters.get_window(window, win_length, fftbins=True)

        # Pad the window out to n_fft size. `size` is passed by keyword because
        # recent librosa releases made it keyword-only; the keyword form is
        # accepted by older releases as well.
        fft_window = librosa.util.pad_center(fft_window, size=n_fft)

        # DFT matrix
        self.W = self.dft_matrix(n_fft)

        # Only the non-redundant half of the spectrum is computed.
        out_channels = n_fft // 2 + 1

        self.conv_real = nn.Conv1d(in_channels=1, out_channels=out_channels, 
            kernel_size=n_fft, stride=hop_length, padding=0, dilation=1, 
            groups=1, bias=False)

        self.conv_imag = nn.Conv1d(in_channels=1, out_channels=out_channels, 
            kernel_size=n_fft, stride=hop_length, padding=0, dilation=1, 
            groups=1, bias=False)

        # The convolution kernels are the windowed DFT basis vectors.
        self.conv_real.weight.data = torch.Tensor(
            np.real(self.W[:, 0 : out_channels] * fft_window[:, None]).T)[:, None, :]
        # (n_fft // 2 + 1, 1, n_fft)

        self.conv_imag.weight.data = torch.Tensor(
            np.imag(self.W[:, 0 : out_channels] * fft_window[:, None]).T)[:, None, :]
        # (n_fft // 2 + 1, 1, n_fft)

        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input):
        """input: (batch_size, data_length)
        Returns:
          real: (batch_size, 1, time_steps, n_fft // 2 + 1)
          imag: (batch_size, 1, time_steps, n_fft // 2 + 1)
        """

        x = input[:, None, :]   # (batch_size, channels_num, data_length)

        if self.center:
            x = F.pad(x, pad=(self.n_fft // 2, self.n_fft // 2), mode=self.pad_mode)

        real = self.conv_real(x)
        imag = self.conv_imag(x)
        # (batch_size, n_fft // 2 + 1, time_steps)

        real = real[:, None, :, :].transpose(2, 3)
        imag = imag[:, None, :, :].transpose(2, 3)
        # (batch_size, 1, time_steps, n_fft // 2 + 1)

        return real, imag
    
    
class Spectrogram(nn.Module):
    def __init__(self, n_fft=2048, hop_length=None, win_length=None, 
        window="hann", center=True, pad_mode="reflect", power=2.0, 
        freeze_parameters=True):
        """Calculate spectrogram using pytorch. The STFT is implemented with 
        Conv1d. The function has the same output of librosa.core.stft

        Args:
          n_fft, hop_length, win_length, window, center, pad_mode: forwarded to STFT
          power: float, exponent of the magnitude (2.0 gives a power spectrogram)
          freeze_parameters: bool, forwarded to STFT; when True the STFT kernels
            are excluded from gradient updates
        """
        super(Spectrogram, self).__init__()

        self.power = power

        # Honour the caller's choice instead of always freezing the kernels.
        self.stft = STFT(n_fft=n_fft, hop_length=hop_length, 
            win_length=win_length, window=window, center=center, 
            pad_mode=pad_mode, freeze_parameters=freeze_parameters)

    def forward(self, input):
        """input: (batch_size, data_length)
        Returns:
          spectrogram: (batch_size, 1, time_steps, n_fft // 2 + 1)
        """

        # Call the module rather than .forward() so registered hooks still run.
        (real, imag) = self.stft(input)
        # (batch_size, 1, time_steps, n_fft // 2 + 1)

        # Squared magnitude, i.e. the power spectrogram.
        spectrogram = real ** 2 + imag ** 2

        # Any other exponent is derived from the squared magnitude.
        if self.power != 2.0:
            spectrogram = spectrogram ** (self.power / 2.0)

        return spectrogram

    
class LogmelFilterBank(nn.Module):
    def __init__(self, sr=32000, n_fft=2048, n_mels=64, fmin=50, fmax=14000, is_log=True, 
        ref=1.0, amin=1e-10, top_db=80.0, freeze_parameters=True):
        """Log-mel spectrogram in pytorch, with the mel filter bank taken from
        librosa.filters.mel.

        The settings matter for the task at hand. For human voice recognition
        in this project:
        1. n_mels: 128 and 64 are both reasonable; which performs better is undecided.
        2. fmin and fmax: should enclose the range of the human voice (100-10000 Hz).
        3. sample rate: must be the same across the whole project.
        """
        super().__init__()

        self.is_log, self.ref, self.amin, self.top_db = is_log, ref, amin, top_db

        mel_basis = librosa.filters.mel(
            sr=sr, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)

        # Stored transposed: (n_fft // 2 + 1, mel_bins)
        self.melW = nn.Parameter(torch.Tensor(mel_basis.T))

        if freeze_parameters:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, input):
        """Project a spectrogram onto the mel filter bank.

        input: (..., n_fft // 2 + 1); the last axis is multiplied with melW.
        Output: same leading axes with mel_bins as the last axis, converted to
        dB when is_log is set.
        """
        mel_spectrogram = torch.matmul(input, self.melW)
        return self.power_to_db(mel_spectrogram) if self.is_log else mel_spectrogram

    def power_to_db(self, input):
        """Power to dB; pytorch counterpart of librosa.core.power_to_db.

        Raises:
          ValueError: if top_db is negative.
        """
        # Floor the power at amin so the logarithm stays finite.
        floored = torch.clamp(input, min=self.amin, max=np.inf)
        log_spec = 10.0 * torch.log10(floored)
        log_spec -= 10.0 * np.log10(np.maximum(self.amin, self.ref))

        if self.top_db is None:
            return log_spec

        if self.top_db < 0:
            raise ValueError("top_db must be non-negative")

        # Limit the dynamic range to top_db below the maximum of the tensor.
        lowest = log_spec.max().item() - self.top_db
        return torch.clamp(log_spec, min=lowest, max=np.inf)
    
class DropStripes(nn.Module):
    def __init__(self, dim, drop_width, stripes_num):
        """Drop stripes. 
        Args:
          dim: int, dimension along which to drop
          drop_width: int, maximum width of stripes to drop
          stripes_num: int, how many stripes to drop
        """
        super(DropStripes, self).__init__()

        assert dim in [2, 3]    # dim 2: time; dim 3: frequency

        self.dim = dim
        self.drop_width = drop_width
        self.stripes_num = stripes_num

    def forward(self, input):
        """Zero random stripes of every sample while in training mode.

        input: (batch_size, channels, time_steps, freq_bins)
        Returns the same tensor; in training mode it is modified in place.
        """

        assert input.ndimension() == 4

        # Augmentation is a training-only operation.
        if not self.training:
            return input

        total_width = input.shape[self.dim]

        # Iterating yields views, so the stripes are written into `input`.
        for sample in input:
            self.transform_slice(sample, total_width)

        return input


    def transform_slice(self, e, total_width):
        """Zero `stripes_num` random stripes of `e` in place.

        e: (channels, time_steps, freq_bins)
        total_width: int, size of `e` along the dropped axis
        """

        # A stripe cannot be wider than the axis. Without this cap a short
        # input makes `total_width - distance` non-positive and randint raises.
        max_width = min(self.drop_width, total_width)
        if max_width <= 0:
            # Nothing can be dropped; randint(0, 0) would raise.
            return

        for _ in range(self.stripes_num):
            distance = int(torch.randint(low=0, high=max_width, size=(1,))[0])
            bgn = int(torch.randint(low=0, high=total_width - distance, size=(1,))[0])

            if self.dim == 2:
                e[:, bgn : bgn + distance, :] = 0
            elif self.dim == 3:
                e[:, :, bgn : bgn + distance] = 0


class SpecAugmentation(nn.Module):
    def __init__(self, time_drop_width, time_stripes_num, freq_drop_width, 
        freq_stripes_num):
        """SpecAugment-style masking: stripes are dropped along the time axis
        and then along the frequency axis.

        [ref] Park, D.S., Chan, W., Zhang, Y., Chiu, C.C., Zoph, B., Cubuk, E.D. 
        and Le, Q.V., 2019. Specaugment: A simple data augmentation method 
        for automatic speech recognition. arXiv preprint arXiv:1904.08779.

        Args:
          time_drop_width: int, maximum width of a time stripe
          time_stripes_num: int, number of time stripes
          freq_drop_width: int, maximum width of a frequency stripe
          freq_stripes_num: int, number of frequency stripes
        """
        super().__init__()

        # dim 2 is the time axis, dim 3 the frequency axis.
        self.time_dropper = DropStripes(
            dim=2, drop_width=time_drop_width, stripes_num=time_stripes_num)
        self.freq_dropper = DropStripes(
            dim=3, drop_width=freq_drop_width, stripes_num=freq_stripes_num)

    def forward(self, input):
        """Apply the time dropper, then the frequency dropper."""
        return self.freq_dropper(self.time_dropper(input))


"""
Building models.
"""

def init_layer(layer):
    """Xavier-uniform initialise ``layer.weight`` and zero its bias, if it has one."""
    nn.init.xavier_uniform_(layer.weight)

    bias = getattr(layer, "bias", None)
    if bias is not None:
        bias.data.fill_(0.)


def init_bn(bn):
    """Reset a batch-norm layer to the identity affine transform (bias 0, weight 1)."""
    bn.bias.data.zero_()
    bn.weight.data.fill_(1.0)


def init_weights(model):
    """Initialise a single module according to its class name.

    Intended for ``nn.Module.apply``. Handled classes (matched by substring of
    the class name): Conv2d, BatchNorm, GRU and Linear; anything else is left
    untouched. Missing biases / affine parameters (``None``) are skipped.
    """
    classname = model.__class__.__name__
    bias = getattr(model, "bias", None)
    if classname.find("Conv2d") != -1:
        nn.init.xavier_uniform_(model.weight, gain=np.sqrt(2))
        # Convolutions built with bias=False have no bias to reset.
        if bias is not None:
            bias.data.fill_(0)
    elif classname.find("BatchNorm") != -1:
        # With affine=False both parameters are None.
        if model.weight is not None:
            model.weight.data.normal_(1.0, 0.02)
        if bias is not None:
            bias.data.fill_(0)
    elif classname.find("GRU") != -1:
        # Only the weight matrices are touched; 1-D bias vectors are skipped.
        for weight in model.parameters():
            if len(weight.size()) > 1:
                nn.init.orthogonal_(weight.data)
    elif classname.find("Linear") != -1:
        model.weight.data.normal_(0, 0.01)
        if bias is not None:
            bias.data.zero_()


def do_mixup(x: torch.Tensor, mixup_lambda: torch.Tensor):
    """Blend every even-indexed sample (0, 2, 4, ...) with the odd-indexed
    sample that follows it (1, 3, 5, ...).
    Args:
      x: (batch_size * 2, ...)
      mixup_lambda: (batch_size * 2,)
    Returns:
      out: (batch_size, ...)
    """
    # Swapping the batch axis to the end lets the 1-D weights broadcast.
    even_part = x[0::2].transpose(0, -1) * mixup_lambda[0::2]
    odd_part = x[1::2].transpose(0, -1) * mixup_lambda[1::2]
    blended = even_part + odd_part
    return blended.transpose(0, -1)


class Mixup(object):
    def __init__(self, mixup_alpha, random_seed=1234):
        """Generator of mixup coefficients.

        Args:
          mixup_alpha: float, both shape parameters of the Beta distribution
          random_seed: int, seed of the private NumPy random state
        """
        self.mixup_alpha = mixup_alpha
        self.random_state = np.random.RandomState(random_seed)

    def get_lambda(self, batch_size):
        """Draw mixup coefficients, one Beta sample per pair of samples.

        Each draw ``lam`` contributes the pair ``(lam, 1 - lam)``.
        Args:
          batch_size: int
        Returns:
          mixup_lambdas: float32 tensor with two entries per pair
        """
        coefficients = []
        for _ in range(0, batch_size, 2):
            lam = self.random_state.beta(self.mixup_alpha, self.mixup_alpha, 1)[0]
            coefficients.extend((lam, 1. - lam))

        return torch.from_numpy(np.array(coefficients, dtype=np.float32))


def interpolate(x: torch.Tensor, ratio: int):
    """Upsample along the time axis by repeating every frame ``ratio`` times.
    Compensates for the time resolution lost to downsampling in a CNN.
    Args:
      x: (batch_size, time_steps, classes_num)
      ratio: int, ratio to interpolate
    Returns:
      upsampled: (batch_size, time_steps * ratio, classes_num)
    """
    n_batch, n_steps, n_classes = x.shape
    # Add a repeat axis after time, tile it, then fold it back into time.
    tiled = x.unsqueeze(2).repeat(1, 1, ratio, 1)
    return tiled.reshape(n_batch, n_steps * ratio, n_classes)


def pad_framewise_output(framewise_output: torch.Tensor, frames_num: int):
    """Extend framewise_output along the frame axis to ``frames_num`` frames
    by repeating its last frame.
    Args:
      framewise_output: (batch_size, frames_num, classes_num)
      frames_num: int, number of frames to pad
    Outputs:
      output: (batch_size, frames_num, classes_num)
    """
    missing = frames_num - framewise_output.shape[1]

    # Tensor used for padding: the final frame, repeated `missing` times.
    last_frame = framewise_output[:, -1:, :]
    tail = last_frame.repeat(1, missing, 1)

    # (batch_size, frames_num, classes_num)
    return torch.cat((framewise_output, tail), dim=1)


class ConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by batch norm and ReLU, then pooling."""

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        def conv3x3(channels_in: int) -> nn.Conv2d:
            # Size-preserving convolution; no bias since batch norm follows.
            return nn.Conv2d(
                in_channels=channels_in,
                out_channels=out_channels,
                kernel_size=(3, 3),
                stride=(1, 1),
                padding=(1, 1),
                bias=False)

        self.conv1 = conv3x3(in_channels)
        self.conv2 = conv3x3(out_channels)

        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Initialise both convolutions, then both batch-norm layers."""
        for conv in (self.conv1, self.conv2):
            init_layer(conv)
        for bn in (self.bn1, self.bn2):
            init_bn(bn)

    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
        """Run the block and pool the result.

        pool_type is one of "max", "avg" or "avg+max" (sum of both poolings);
        any other value raises an Exception.
        """
        x = F.relu_(self.bn1(self.conv1(input)))
        x = F.relu_(self.bn2(self.conv2(x)))

        if pool_type == "avg":
            return F.avg_pool2d(x, kernel_size=pool_size)
        if pool_type == "max":
            return F.max_pool2d(x, kernel_size=pool_size)
        if pool_type == "avg+max":
            averaged = F.avg_pool2d(x, kernel_size=pool_size)
            maximum = F.max_pool2d(x, kernel_size=pool_size)
            return averaged + maximum

        raise Exception("Incorrect argument!")


class AttBlock(nn.Module):
    """Attention pooling over the time axis.

    Two 1x1 convolutions produce per-frame attention scores and per-frame
    class outputs; the clip-level output is the attention-weighted sum of the
    class outputs over time.
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation="linear",
                 temperature=1.0):
        """
        Args:
            in_features: number of input channels.
            out_features: number of output classes.
            activation: "linear" or "sigmoid", applied to the class outputs.
            temperature: stored on the module; not used by ``forward``.
        """
        super().__init__()

        self.activation = activation
        self.temperature = temperature
        self.att = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        self.cla = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

        # Not used by forward(); kept so the module's parameter names stay
        # unchanged for existing checkpoints.
        self.bn_att = nn.BatchNorm1d(out_features)
        self.init_weights()

    def init_weights(self):
        init_layer(self.att)
        init_layer(self.cla)
        init_bn(self.bn_att)

    def forward(self, x):
        """
        Args:
            x: (n_samples, n_in, n_time)
        Returns:
            (pooled output, attention weights normalised over time, per-frame class outputs)
        """
        # Scores are clamped to [-10, 10] before the softmax over time.
        norm_att = torch.softmax(torch.clamp(self.att(x), -10, 10), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation.

        Raises:
            ValueError: if ``activation`` is neither "linear" nor "sigmoid".
        """
        if self.activation == "linear":
            return x
        elif self.activation == "sigmoid":
            return torch.sigmoid(x)
        # Fail loudly instead of returning None, which would only surface later
        # as an obscure error when the result is multiplied.
        raise ValueError(f"Unsupported activation: {self.activation!r}")

class AttBlockV2(nn.Module):
    """Attention pooling over the time axis with tanh-bounded attention scores."""

    def __init__(self,
                 in_features: int,
                 out_features: int,
                 activation="linear"):
        """
        Args:
            in_features: number of input channels.
            out_features: number of output classes.
            activation: "linear" or "sigmoid", applied to the class outputs.
        """
        super().__init__()

        self.activation = activation
        self.att = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)
        self.cla = nn.Conv1d(
            in_channels=in_features,
            out_channels=out_features,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=True)

        self.init_weights()

    def init_weights(self):
        init_layer(self.att)
        init_layer(self.cla)

    def forward(self, x):
        """
        Args:
            x: (n_samples, n_in, n_time)
        Returns:
            (pooled output, attention weights normalised over time, per-frame class outputs)
        """
        # tanh bounds the scores to [-1, 1] before the softmax over time.
        norm_att = torch.softmax(torch.tanh(self.att(x)), dim=-1)
        cla = self.nonlinear_transform(self.cla(x))
        x = torch.sum(norm_att * cla, dim=2)
        return x, norm_att, cla

    def nonlinear_transform(self, x):
        """Apply the configured activation.

        Raises:
            ValueError: if ``activation`` is neither "linear" nor "sigmoid".
        """
        if self.activation == "linear":
            return x
        elif self.activation == "sigmoid":
            return torch.sigmoid(x)
        # Fail loudly instead of returning None, which would only surface later
        # as an obscure error when the result is multiplied.
        raise ValueError(f"Unsupported activation: {self.activation!r}")


class PANNsCNN14Att(nn.Module):
    """CNN14 with attention pooling: waveform in, framewise and clipwise
    predictions out."""

    def __init__(self, sample_rate: int, window_size: int, hop_size: int,
                 mel_bins: int, fmin: int, fmax: int, classes_num: int):
        super().__init__()

        # Time downsampling ratio of the conv stack (five 2x poolings).
        self.interpolate_ratio = 32

        # Waveform -> spectrogram.
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window="hann",
            center=True,
            pad_mode="reflect",
            freeze_parameters=True)

        # Spectrogram -> log-mel features (no top_db limit).
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=1.0,
            amin=1e-10,
            top_db=None,
            freeze_parameters=True)

        # Training-time masking of time and frequency stripes.
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2)

        # Batch norm over the mel-bin axis.
        self.bn0 = nn.BatchNorm2d(mel_bins)

        self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.att_block = AttBlock(2048, classes_num, activation="sigmoid")

        self.init_weight()

    def init_weight(self):
        init_bn(self.bn0)
        init_layer(self.fc1)

    def forward(self, input, mixup_lambda=None):
        """
        Input: (batch_size, data_length)

        Returns a dict with the keys "framewise_output", "logit" and
        "clipwise_output".
        """
        x = self.spectrogram_extractor(input)  # (batch_size, 1, time_steps, freq_bins)
        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

        frames_num = x.shape[2]

        # bn0 normalises per mel bin, so that axis is moved to the channel slot.
        x = self.bn0(x.transpose(1, 3)).transpose(1, 3)

        if self.training:
            x = self.spec_augmenter(x)
            # Mixup on spectrogram
            if mixup_lambda is not None:
                x = do_mixup(x, mixup_lambda)

        # Five blocks halve both axes; the last one keeps the resolution.
        conv_stages = (
            (self.conv_block1, (2, 2)),
            (self.conv_block2, (2, 2)),
            (self.conv_block3, (2, 2)),
            (self.conv_block4, (2, 2)),
            (self.conv_block5, (2, 2)),
            (self.conv_block6, (1, 1)),
        )
        for block, pool_size in conv_stages:
            x = block(x, pool_size=pool_size, pool_type="avg")
            x = F.dropout(x, p=0.2, training=self.training)

        # Collapse the frequency axis.
        x = torch.mean(x, dim=3)

        # Smooth over time with the sum of max and average pooling.
        pooled_max = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        pooled_avg = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        x = pooled_max + pooled_avg
        x = F.dropout(x, p=0.5, training=self.training)

        # fc1 acts on the channel axis, which must be last for nn.Linear.
        x = F.relu_(self.fc1(x.transpose(1, 2))).transpose(1, 2)
        x = F.dropout(x, p=0.5, training=self.training)

        (clipwise_output, norm_att, segmentwise_output) = self.att_block(x)
        # Attention-pooled class scores taken before the sigmoid.
        logit = torch.sum(norm_att * self.att_block.cla(x), dim=2)
        segmentwise_output = segmentwise_output.transpose(1, 2)

        # Get framewise output
        framewise_output = interpolate(segmentwise_output,
                                       self.interpolate_ratio)
        framewise_output = pad_framewise_output(framewise_output, frames_num)

        return {
            "framewise_output": framewise_output,
            "logit": logit,
            "clipwise_output": clipwise_output
        }
    

In [None]:
from csrc.models import AttBlock, PANNsCNN14Att

## Loss

In [None]:

import torch
import torch.nn as nn
import torch.nn.functional as F


class PANNsLoss(nn.Module):
    """Binary cross-entropy on the sanitised "clipwise_output" of the model."""

    def __init__(self):
        super().__init__()

        self.bce = nn.BCELoss()

    def forward(self, input, target):
        """
        Args:
            input: dict holding the model outputs; "clipwise_output" is read.
            target: labels, converted to float.
        """
        probs = input["clipwise_output"]

        # NaN and +/-inf entries are replaced by 0 ...
        not_finite = ~torch.isfinite(probs)
        probs = torch.where(not_finite, torch.zeros_like(probs), probs)
        # ... and the rest is limited to the probability range BCELoss accepts.
        probs = torch.clamp(probs, 0.0, 1.0)

        return self.bce(probs, target.float())
    
    
class FocalLoss(nn.Module):
    """Focal loss computed from logits."""

    def __init__(self, gamma=2):
        super().__init__()
        self.gamma = gamma

    def forward(self, logit, target):
        """Return the mean focal loss; for 2-D input the per-element losses
        are summed over dim 1 before averaging."""
        target = target.float()

        # Binary cross-entropy with logits in its numerically stable form.
        max_val = (-logit).clamp(min=0)
        log_term = ((-max_val).exp() + (-logit - max_val).exp()).log()
        bce = logit - logit * target + max_val + log_term

        # Modulating factor, evaluated in log space and raised to gamma.
        signed_target = target * 2.0 - 1.0
        invprobs = F.logsigmoid(-logit * signed_target)
        loss = (invprobs * self.gamma).exp() * bce

        if loss.dim() == 2:
            loss = loss.sum(dim=1)
        return loss.mean()


class ImprovedFocalLoss(nn.Module):
    """Weighted sum of a clip-level focal loss and an auxiliary focal loss on
    the per-class maximum over frames."""

    def __init__(self, weights=[1, 1]):
        super().__init__()

        self.focal = FocalLoss()
        # weights[0] scales the clip-level term, weights[1] the auxiliary term.
        self.weights = weights

    def forward(self, input, target):
        """
        Args:
            input: dict of model outputs; "logit" and "framewise_logit" are read.
            target: labels, converted to float.
        """
        clip_logit = input["logit"]
        target = target.float()

        # NOTE(review): PANNsCNN14Att.forward in this notebook returns
        # "framewise_output", "logit" and "clipwise_output" but no
        # "framewise_logit" - confirm which model this loss is meant for.
        frame_logit = input["framewise_logit"]
        max_over_frames = frame_logit.max(dim=1)[0]

        clip_term = self.focal(clip_logit, target)
        frame_term = self.focal(max_over_frames, target)

        clip_weight, frame_weight = self.weights[0], self.weights[1]
        return clip_weight * clip_term + frame_weight * frame_term


class ImprovedPANNsLoss(nn.Module):
    """Clip-level loss plus an auxiliary loss on the per-class maximum of the
    framewise output."""

    def __init__(self, output_key="logit", weights=[1, 1]):
        """
        Args:
            output_key: key of the model output used for the clip-level term.
                "logit" selects BCEWithLogitsLoss, any other key BCELoss.
            weights: ``weights[0]`` scales the clip-level term,
                ``weights[1]`` the auxiliary term.
        """
        super().__init__()

        self.output_key = output_key
        if output_key == "logit":
            self.normal_loss = nn.BCEWithLogitsLoss()
        else:
            self.normal_loss = nn.BCELoss()

        # "framewise_output" holds probabilities (the attention block applies a
        # sigmoid), so the auxiliary term always needs plain BCE - applying
        # BCEWithLogitsLoss to it would squash the values a second time.
        self.bce = nn.BCELoss()

        self.weights = weights

    def forward(self, input, target):
        """
        Args:
            input: dict of model outputs; ``output_key`` and "framewise_output"
                are read.
            target: labels, converted to float.
        Returns:
            The weighted sum of the clip-level and the auxiliary loss.
        """
        input_ = input[self.output_key]
        target = target.float()

        framewise_output = input["framewise_output"]
        # Strongest frame per class, used as a second clip-level prediction.
        clipwise_output_with_max, _ = framewise_output.max(dim=1)

        normal_loss = self.normal_loss(input_, target)
        auxiliary_loss = self.bce(clipwise_output_with_max, target)

        return self.weights[0] * normal_loss + self.weights[1] * auxiliary_loss

In [None]:
from csrc.losses import ImprovedPANNsLoss

## Callbacks

In [None]:
from sklearn.metrics import f1_score, average_precision_score, precision_score


class F1Callback(Callback):
    """Metric callback reporting the macro F1 score per batch and per loader.

    Predicted and true classes are the argmax over axis 1 of the model output
    and of the targets.
    """

    def __init__(self,
                 input_key: str = "targets",
                 output_key: str = "logits",
                 model_output_key: str = "clipwise_output",
                 prefix: str = "macro_f1"):
        super().__init__(CallbackOrder.Metric)

        self.input_key = input_key
        self.output_key = output_key
        self.model_output_key = model_output_key
        self.prefix = prefix

    def on_loader_start(self, state: IRunner):
        # Start every loader with empty accumulators.
        self.prediction: List[np.ndarray] = []
        self.target: List[np.ndarray] = []

    def on_batch_end(self, state: IRunner):
        batch = state.batch
        targ = batch[self.input_key].detach().cpu().numpy()
        clipwise_output = batch[self.output_key][self.model_output_key].detach().cpu().numpy()

        # Kept for the loader-level score.
        self.prediction.append(clipwise_output)
        self.target.append(targ)

        state.batch_metrics[self.prefix] = f1_score(
            targ.argmax(axis=1), clipwise_output.argmax(axis=1), average="macro")

    def on_loader_end(self, state: IRunner):
        y_pred = np.concatenate(self.prediction, axis=0).argmax(axis=1)
        y_true = np.concatenate(self.target, axis=0).argmax(axis=1)
        score = f1_score(y_true, y_pred, average="macro")
        state.loader_metrics[self.prefix] = score

        # NOTE(review): this replaces the whole per-loader dict, so entries
        # written by other callbacks under the same key are lost - confirm
        # this is intended.
        loader_name = "valid" if state.is_valid_loader else "train"
        state.epoch_metrics[loader_name] = {self.prefix: score}
            
class PrecisionCallback(Callback):
    """Metric callback that reports precision per batch and per loader.

    Predicted and true class indices are obtained by taking ``argmax`` over
    axis 1 of the selected model output and of the target array, then scored
    with ``precision_score`` using its default averaging.

    Args:
        input_key: key in ``state.batch`` under which the targets are stored.
        output_key: key in ``state.batch`` under which the model output
            mapping is stored.
        model_output_key: key inside the model output mapping whose tensor
            is scored.
        prefix: name under which the score is written to the metric dicts.
    """

    def __init__(self,
                 input_key: str = "targets",
                 output_key: str = "logits",
                 model_output_key: str = "clipwise_output",
                 prefix: str = "precision"):
        super().__init__(CallbackOrder.Metric)

        self.input_key = input_key
        self.output_key = output_key
        self.model_output_key = model_output_key
        self.prefix = prefix

    def on_loader_start(self, state: IRunner):
        """Reset the per-loader accumulators of predictions and targets."""
        self.prediction: List[np.ndarray] = []
        self.target: List[np.ndarray] = []

    def on_batch_end(self, state: IRunner):
        """Accumulate the batch and store its precision in ``batch_metrics``."""
        targ = state.batch[self.input_key].detach().cpu().numpy()
        out = state.batch[self.output_key]

        clipwise_output = out[self.model_output_key].detach().cpu().numpy()

        # Keep the raw arrays so the loader-level score is computed over the
        # whole loader instead of averaging per-batch scores.
        self.prediction.append(clipwise_output)
        self.target.append(targ)

        y_pred = clipwise_output.argmax(axis=1)
        y_true = targ.argmax(axis=1)

        score = precision_score(y_true, y_pred)
        state.batch_metrics[self.prefix] = score

    def on_loader_end(self, state: IRunner):
        """Compute precision over everything seen by the loader and publish it."""
        y_pred = np.concatenate(self.prediction, axis=0).argmax(axis=1)
        y_true = np.concatenate(self.target, axis=0).argmax(axis=1)
        score = precision_score(y_true, y_pred)
        state.loader_metrics[self.prefix] = score

        loader_name = "valid" if state.is_valid_loader else "train"
        # Merge into the existing per-loader dict instead of replacing it:
        # assigning a fresh dict would discard the values that other metric
        # callbacks have already stored under the same loader name.
        state.epoch_metrics.setdefault(loader_name, {})[self.prefix] = score
            
            
class mAPCallback(Callback):
    """Metric callback that reports mean average precision per batch and loader.

    ``average_precision_score`` is called with ``average=None`` on the raw
    target and output arrays; NaN scores are replaced via ``np.nan_to_num``
    before the mean over the returned scores is taken.

    Args:
        input_key: key in ``state.batch`` under which the targets are stored.
        output_key: key in ``state.batch`` under which the model output
            mapping is stored.
        model_output_key: key inside the model output mapping whose tensor
            is scored.
        prefix: name under which the score is written to the metric dicts.
    """

    def __init__(self,
                 input_key: str = "targets",
                 output_key: str = "logits",
                 model_output_key: str = "clipwise_output",
                 prefix: str = "mAP"):
        super().__init__(CallbackOrder.Metric)
        self.input_key = input_key
        self.output_key = output_key
        self.model_output_key = model_output_key
        self.prefix = prefix

    def on_loader_start(self, state: IRunner):
        """Reset the per-loader accumulators of predictions and targets."""
        self.prediction: List[np.ndarray] = []
        self.target: List[np.ndarray] = []

    def on_batch_end(self, state: IRunner):
        """Accumulate the batch and store its mAP in ``batch_metrics``."""
        targ = state.batch[self.input_key].detach().cpu().numpy()
        out = state.batch[self.output_key]

        clipwise_output = out[self.model_output_key].detach().cpu().numpy()

        self.prediction.append(clipwise_output)
        self.target.append(targ)

        # Replace NaN/inf values before scoring the batch.
        targ = np.nan_to_num(targ)
        clipwise_output = np.nan_to_num(clipwise_output)
        score = average_precision_score(targ, clipwise_output, average=None)
        score = np.nan_to_num(score).mean()
        state.batch_metrics[self.prefix] = score

    def on_loader_end(self, state: IRunner):
        """Compute mAP over everything seen by the loader and publish it."""
        # Apply the same NaN/inf sanitisation as on_batch_end; without it a
        # single batch containing NaN would break only the loader-level
        # score while every per-batch score was computed normally.
        y_pred = np.nan_to_num(np.concatenate(self.prediction, axis=0))
        y_true = np.nan_to_num(np.concatenate(self.target, axis=0))
        score = average_precision_score(y_true, y_pred, average=None)
        score = np.nan_to_num(score).mean()
        state.loader_metrics[self.prefix] = score

        loader_name = "valid" if state.is_valid_loader else "train"
        # Merge into the existing per-loader dict instead of replacing it:
        # assigning a fresh dict would discard the values that other metric
        # callbacks have already stored under the same loader name.
        state.epoch_metrics.setdefault(loader_name, {})[self.prefix] = score


NameError: ignored

In [None]:
from csrc.callbacks import F1Callback, mAPCallback, PrecisionCallback

## Training Configurations

In [None]:
# device
# Use the first CUDA device when one is available; otherwise fall back to the
# CPU so this cell still runs on a machine without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# model
model = PANNsCNN14Att(**MC.sed_model_config)
# Load the checkpoint onto the CPU regardless of the device it was saved
# from; the model is moved to `device` below, after its weights are restored.
weights = torch.load(WEIGHTS_PATH, map_location="cpu")
model.load_state_dict(weights["model"])
# Replace the attention block only after the checkpoint has been loaded, so
# the new block starts from freshly initialised weights.
# NOTE(review): presumably 2048 input features and 2 output classes —
# confirm against the AttBlock signature.
model.att_block = AttBlock(2048, 2, activation="sigmoid")
model.att_block.init_weights()
model.to(device)

# optimizer
optimizer = optim.Adam(model.parameters(), lr=0.001)

# scheduler
# NOTE(review): T_max=10 while the recorded run below reports 30 epochs —
# confirm this annealing period is intended.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

# loss
loss = ImprovedPANNsLoss().to(device)

# callbacks
callbacks = [
    F1Callback(),
    mAPCallback(),
    PrecisionCallback(),
    CheckpointCallback(save_n_best=3),
]

## Training

In [None]:
# Silence every Python warning for the rest of the session so the training
# progress output stays readable. This is process-wide, not scoped to this cell.
warnings.simplefilter("ignore")

# Runner that drives the training loop on `device`.
# NOTE(review): "waveform" and "targets" presumably name the batch entries fed
# to the model and used as ground truth — confirm against the dataset's
# __getitem__ output.
runner = SupervisedRunner(
    device=device,
    input_key="waveform",
    input_target_key="targets")

# Start training with the model, loss, optimizer, scheduler and callbacks
# defined in the configuration cell. `loaders`, `EPOCHS` and `LOG_DIR` must be
# defined by earlier cells.
runner.train(
    model=model,
    criterion=loss,
    loaders=loaders,
    optimizer=optimizer,
    scheduler=scheduler,
    num_epochs=EPOCHS,
    verbose=True,
    logdir=LOG_DIR,
    callbacks=callbacks,
    # The recorded run output lists `epoch_precision` among the epoch metrics;
    # it is selected as the main metric and treated as higher-is-better.
    main_metric="epoch_precision",
    minimize_metric=False,
    # Mixed-precision training is left disabled.
    # fp16=True,
)

1/30 * Epoch (train): 100% 317/317 [12:17<00:00,  2.33s/it, loss=1.165, mAP=0.835, macro_f1=0.723, precision=0.913]
1/30 * Epoch (valid): 100% 80/80 [00:20<00:00,  3.94it/s, loss=1.796, mAP=0.500, macro_f1=0.364, precision=0.000e+00]
[2021-02-12 10:42:02,018] 
1/30 * Epoch 1 (_base): lr=0.0010 | momentum=0.9000
1/30 * Epoch 1 (train): epoch_mAP=0.8601 | epoch_macro_f1=0.8238 | epoch_precision=0.8784 | loss=1.1902 | mAP=0.8735 | macro_f1=0.8127 | precision=0.8863
1/30 * Epoch 1 (valid): epoch_mAP=0.9348 | epoch_macro_f1=0.8901 | epoch_precision=0.8869 | loss=0.8952 | mAP=0.8876 | macro_f1=0.8527 | precision=0.8097
2/30 * Epoch (train): 100% 317/317 [11:50<00:00,  2.24s/it, loss=1.044, mAP=0.918, macro_f1=0.838, precision=0.875]
2/30 * Epoch (valid): 100% 80/80 [00:16<00:00,  4.77it/s, loss=1.368, mAP=0.500, macro_f1=1.000, precision=0.000e+00]
[2021-02-12 10:54:13,367] 
2/30 * Epoch 2 (_base): lr=0.0009 | momentum=0.9000
2/30 * Epoch 2 (train): epoch_mAP=0.8825 | epoch_macro_f1=0.8561 |