# Joint classification

Author: Edoardo De Matteis

## Dependencies

### Libraries

In [None]:
!pip install torchaudio
!pip install torchinfo
!pip install pytorch_lightning
!pip install wandb -qqq

In [None]:
from pathlib import Path
from typing import Dict, Optional, Tuple, Union, List, Any

import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning import Callback, seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    LearningRateMonitor
)
from torchinfo import summary

import wandb

### Filesystem

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
%cd /content/gdrive/MyDrive/Colab\ Notebooks/Birdcalls

In [None]:
%load_ext autoreload
%autoreload 2

### Utilities

In [None]:
def build_callbacks(callbacks: Dict) -> List[Callback]:
    """
    Instantiate the PyTorch Lightning callbacks described by a configuration dictionary.

    A callback is built only when its section key is present. The resulting list is ordered as:
    learning rate monitor, early stopping, model checkpoint.
    :param callbacks: Configuration dictionary. The optional sections "lr_monitor", "early_stopping" and
    "model_checkpoints" hold the options of each callback; the top-level keys "monitor_metric" and
    "monitor_metric_mode" are read whenever early stopping or checkpointing is enabled.
    :return: The instantiated callbacks, an empty list when no section is present.
    """
    out_callbacks: List[Callback] = []

    if "lr_monitor" in callbacks:
        lr_monitor_cfg = callbacks["lr_monitor"]
        out_callbacks.append(
            LearningRateMonitor(
                logging_interval=lr_monitor_cfg["logging_interval"],
                log_momentum=lr_monitor_cfg["log_momentum"],
            )
        )

    if "early_stopping" in callbacks:
        early_stopping_cfg = callbacks["early_stopping"]
        out_callbacks.append(
            EarlyStopping(
                monitor=callbacks["monitor_metric"],
                mode=callbacks["monitor_metric_mode"],
                patience=early_stopping_cfg["patience"],
                verbose=early_stopping_cfg["verbose"],
            )
        )

    if "model_checkpoints" in callbacks:
        checkpoints_cfg = callbacks["model_checkpoints"]
        out_callbacks.append(
            ModelCheckpoint(
                monitor=callbacks["monitor_metric"],
                mode=callbacks["monitor_metric_mode"],
                save_top_k=checkpoints_cfg["save_top_k"],
                # The key used to be misspelled ("verbsose"), which raised a KeyError
                # as soon as checkpointing was enabled.
                verbose=checkpoints_cfg["verbose"],
            )
        )

    return out_callbacks

def cnn_size(
    input: Tuple[int, int],
    kernel: Union[int, Tuple[int, int]],
    padding: Union[int, Tuple[int, int]] = 0,
    stride: Union[int, Tuple[int, int]] = 1,
) -> Tuple[int, int]:
    """
    Compute the spatial size produced by one convolutional layer.

    Each axis is computed as (size - kernel + 2 * padding) / stride + 1, truncated to an integer.
    :param input: Size of the input image, one value per axis.
    :param kernel: Kernel size; a single integer is used on both axes.
    :param padding: Padding size; a single integer is used on both axes.
    :param stride: Stride; a single integer is used on both axes.
    :return: The output size, one value per axis.
    """

    def as_pair(value):
        # A scalar is shorthand for "the same value on both axes".
        return (value, value) if isinstance(value, int) else value

    kernel = as_pair(kernel)
    padding = as_pair(padding)
    stride = as_pair(stride)

    sizes = [
        (input[axis] - kernel[axis] + 2 * padding[axis]) / stride[axis] + 1
        for axis in (0, 1)
    ]
    return int(sizes[0]), int(sizes[1])

### Dataset

In [None]:
class JointDataset(Dataset):
    """
    Dataset of spectrograms and their targets.

    The constructor only records the ``online`` flag; the tensors and the length are populated by
    :meth:`load`.
    """

    def __init__(
        self,
        csv_path: Union[str, Path, None],
        online: bool,
        debug: int,
        load: bool,
        **kwargs
    ):
        """
        :param csv_path: Path of the training CSV file (not read by this constructor).
        :param online: If true items are returned as raw metadata fields to be processed on-the-fly,
        otherwise precomputed tensors are returned.
        :param debug: Size of the reduced dataset to use, any number below or equal to 0 means the whole
        dataset (not used by this constructor).
        :param load: If true nothing is computed and values are loaded from a file (not used by this
        constructor).
        :param kwargs: Ignored.
        """
        super().__init__()

        self.online = online

        # Declared only: these attributes get their values in JointDataset.load().
        self.len: int
        self.spectrograms: torch.Tensor
        self.targets: torch.Tensor

    @staticmethod
    def load(
        spectrograms_path: Union[str, Path], targets_path: Union[str, Path], **kwargs
    ):
        """
        Build an offline dataset whose spectrograms and targets are read from .pt files.
        :param spectrograms_path: Path of the spectrograms tensor file.
        :param targets_path: Path of the targets tensor file.
        :param kwargs: Ignored.
        :return: A JointDataset object with populated tensors.
        """
        dataset = JointDataset(csv_path=None, online=False, debug=-1, load=True)

        dataset.spectrograms = torch.load(spectrograms_path)
        dataset.targets = torch.load(targets_path)
        # The number of targets defines the dataset length.
        dataset.len = len(dataset.targets)

        return dataset

    def __len__(self):
        """
        :return: Length of the dataset.
        """
        return self.len

    def __getitem__(self, item):
        """
        :param item: Index of the item to retrieve.
        :return: The item-th entry, as a dictionary.
        """
        if not self.online:
            return {
                "spectrograms": self.spectrograms[item],
                "targets": self.targets[item],
            }

        # NOTE(review): none of these attributes is set in this class, they have to be
        # provided by whoever builds an online dataset.
        online_fields = ("row_id", "site", "audio_id", "seconds", "birds")
        return {field: getattr(self, field)[item] for field in online_fields}

### Datamodule

In [None]:
class JointDataModule(pl.LightningDataModule):
    """
    Data module providing the train and validation dataloaders of the joint dataset.

    It derives from ``pl.LightningDataModule`` (not ``pl.LightningModule``) because it is handed to the
    trainer through the ``datamodule`` argument.
    """

    def __init__(
        self,
        num_workers: Dict,
        batch_size: Dict,
        shuffle: Dict,
        **kwargs,
    ):
        """
        :param num_workers: Number of dataloader workers for each split, keyed by "train" and "val".
        :param batch_size: Batch size for each split, keyed by "train" and "val".
        :param shuffle: Whether to shuffle each split, keyed by "train" and "val".
        :param kwargs: Ignored.
        """
        super().__init__()
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.shuffle = shuffle

        # These attributes will be populated after self.setup() call.
        self.train_ds: Optional[Dataset] = None
        self.val_ds: Optional[Dataset] = None
        self.test_ds: Optional[Dataset] = None

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Load the precomputed train and validation tensors.
        :param stage: Datasets are loaded when this is None or "fit", any other stage is a no-op.
        """
        if stage is None or stage == "fit":
            # The path constants are the ones defined in the "Environmental variables" cell of this
            # notebook; the previously referenced *_JOINT_DEBUG_* names are not defined anywhere.
            # Train
            self.train_ds = JointDataset.load(
                spectrograms_path=TRAIN_BIRDCALLS_DEBUG_SPECTROGRAMS,
                targets_path=TRAIN_BIRDCALLS_DEBUG_TARGETS,
            )

            # Val
            self.val_ds = JointDataset.load(
                spectrograms_path=VAL_BIRDCALLS_DEBUG_SPECTROGRAMS,
                targets_path=VAL_BIRDCALLS_DEBUG_TARGETS,
            )

    def _build_dataloader(self, dataset: Optional[Dataset], split: str) -> DataLoader:
        """
        Create the dataloader of one split from the per-split settings given to the constructor.
        :param dataset: Dataset to wrap.
        :param split: Key of the split, "train" or "val".
        :return: The configured dataloader.
        """
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size[split],
            shuffle=self.shuffle[split],
            # The workers setting was stored but never forwarded to the dataloader.
            num_workers=self.num_workers[split],
        )

    def train_dataloader(
        self,
    ) -> Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]:
        """
        :return: The dataloader over the training dataset.
        """
        return self._build_dataloader(dataset=self.train_ds, split="train")

    def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        """
        :return: The dataloader over the validation dataset.
        """
        return self._build_dataloader(dataset=self.val_ds, split="val")

### Model

In [None]:
class CNNAtt(nn.Module):
    def __init__(
        self,
        image_shape: Tuple[int, int, int],
        channels: List,
        kernels: List,
        paddings: List,
        strides: List,
        num_heads: int,
    ):
        """
        A layer with some CNNs and attention.

        Three identical stacks of convolutions produce query, key and value; two attention layers are
        then applied, one along each image dimension.
        :param image_shape: The shape of the input image (channels, width, height).
        :param channels: The number of output channels of each convolution, the last one has to be 1.
        :param kernels: Kernel size of each convolution.
        :param paddings: Padding of each convolution.
        :param strides: Stride of each convolution.
        :param num_heads: Number of heads to use in the attention layers; it has to divide both sizes of
        the feature image, so 1 is always a valid choice.
        """
        super(CNNAtt, self).__init__()
        # Only the spatial size is needed here, get_seq() fixes the number of input channels to 1.
        _, w, h = image_shape

        # Query
        self.cnn_q = self.get_seq(
            name="CNNQuery",
            channels=channels,
            kernels=kernels,
            paddings=paddings,
            strides=strides,
        )

        # Key
        self.cnn_k = self.get_seq(
            name="CNNKey",
            channels=channels,
            kernels=kernels,
            paddings=paddings,
            strides=strides,
        )

        # Value
        self.cnn_v = self.get_seq(
            name="CNNValue",
            channels=channels,
            kernels=kernels,
            paddings=paddings,
            strides=strides,
        )

        # Measure output sizes.
        seq, embed = self.count_outputs(
            input=(w, h), kernels=kernels, paddings=paddings, strides=strides
        )

        # Attention layers, one for each image dimension.
        self.att1 = nn.MultiheadAttention(
            embed_dim=embed, num_heads=num_heads, batch_first=True
        )
        self.att2 = nn.MultiheadAttention(
            embed_dim=seq, num_heads=num_heads, batch_first=True
        )

    def forward(self, xb):
        """
        :param xb: Batch of images; the convolutions expect a single input channel.
        :return: The two attention maps concatenated along the channel dimension (2 channels).
        """
        # Drop only the channel dimension (size 1 by construction). A bare squeeze() would also drop
        # the batch dimension for batches of one element and break the transpositions below.
        q = self.cnn_q(xb).squeeze(1)
        k = self.cnn_k(xb).squeeze(1)
        v = self.cnn_v(xb).squeeze(1)

        # Both layers are batch first; in att2 the last two dimensions are swapped so that attention
        # runs along the other image dimension.
        att1, _ = self.att1(q, k, v)
        att2, _ = self.att2(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2))

        # Add one dimension and concatenate as a 2-channels image.
        att1 = att1.unsqueeze(1)
        att2 = att2.transpose(1, 2).unsqueeze(1)
        out = torch.cat((att1, att2), dim=1)

        return out

    def get_seq(
        self,
        name: str,
        channels: List,
        kernels: List,
        paddings: List,
        strides: List,
    ):
        """
        Build the stack of convolutions of a basic block. Each stack is used as query, key or value in an
        attention layer.
        :param name: Prefix of the name given to each convolution, followed by its index.
        :param channels: Channels of each CNN layer, the last one has to be 1.
        :param kernels: Kernel size for each CNN.
        :param paddings: Paddings for each CNN.
        :param strides: Strides for each CNN.
        :return: A nn.Sequential object.
        """
        assert channels[-1] == 1, "The attention layer can accept one channel only!."

        # By default audio data has only one channel.
        in_channels = 1
        seq = nn.Sequential()

        for n, (out_channels, kernel, padding, stride) in enumerate(
            zip(channels, kernels, paddings, strides)
        ):
            cnn = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel,
                padding=padding,
                stride=stride,
            )
            seq.add_module(name=f"{name}{n}", module=cnn)
            # The next convolution consumes what this one produces.
            in_channels = out_channels

        return seq

    def count_outputs(
        self,
        input: Tuple[int, int],
        kernels: List,
        paddings: List,
        strides: List,
    ):
        """
        Compute the size of an image after a sequence of convolutions.
        :param input: Input image dimension.
        :param kernels: List of kernels to apply.
        :param paddings: List of paddings.
        :param strides: List of strides.
        :return: The size of the feature image.
        """
        image = input
        for kernel, padding, stride in zip(kernels, paddings, strides):
            image = cnn_size(input=image, kernel=kernel, padding=padding, stride=stride)
        return image

In [None]:
class CNNRes(nn.Module):
    """
    Residual block made of six convolutions that keep the number of channels unchanged.
    """

    def __init__(self, in_channels: int, kernel_size: int):
        """
        :param in_channels: Number of channels of the input, preserved by every convolution.
        :param kernel_size: Side of the square kernel used by every convolution.
        """
        super(CNNRes, self).__init__()

        # This padding keeps the spatial size unchanged only for odd kernels: with an even kernel each
        # convolution grows the image by one pixel and the residual sums no longer match.
        pad = kernel_size // 2

        # conv1 ... conv6 share the same configuration but have their own weights.
        for index in range(1, 7):
            conv = nn.Conv2d(
                in_channels=in_channels,
                out_channels=in_channels,
                kernel_size=(kernel_size, kernel_size),
                padding=pad,
            )
            setattr(self, f"conv{index}", conv)

    def forward(self, res):
        """
        :param res: Input feature image.
        :return: The feature image after the three pairs of convolutions.
        """
        # Each convolution is now applied exactly once; previously conv1 was reused at every step and
        # conv2 ... conv6 were never part of the computation.
        out = self.conv1(res)
        res = self.conv2(out) + res
        out = self.conv3(res)
        res = self.conv4(out) + res
        out = self.conv5(res)
        # NOTE(review): the last pair has no skip connection, this is kept as it was — confirm it is
        # intended.
        res = self.conv6(out)
        return res

In [None]:
class Extraction(nn.Module):
    """
    Feature extraction backbone.

    Only the residual stage and the output pooling are active. The attention stage is currently
    disabled, so the ``att_*`` and ``pool_att`` arguments are accepted but not used.
    """

    def __init__(
        self,
        image_shape: Tuple[int, int, int],
        att_channels: List,
        att_kernels: List,
        att_paddings: List,
        att_strides: List,
        att_num_heads: int,
        pool_att: int,
        res_kernels: List,
        pool: int,
    ):
        """
        :param image_shape: Shape of the input image, its first entry is the number of channels.
        :param att_channels: Channels of the attention stage (unused while the stage is disabled).
        :param att_kernels: Kernels of the attention stage (unused while the stage is disabled).
        :param att_paddings: Paddings of the attention stage (unused while the stage is disabled).
        :param att_strides: Strides of the attention stage (unused while the stage is disabled).
        :param att_num_heads: Heads of the attention stage (unused while the stage is disabled).
        :param pool_att: Pooling after the attention stage (unused while the stage is disabled).
        :param res_kernels: Kernel size of each residual level, one level per entry.
        :param pool: Kernel size of the output average pooling.
        """
        super().__init__()
        self.level = len(res_kernels)

        # Disabled attention stage, kept here as a description of the intended design:
        # two CNNAtt blocks applied to the input and concatenated (4 channels = 1 output channel
        # * 2 attentions per block * 2 blocks), a 1x1 convolution back to 1 channel whose output is
        # summed to the input as a residual, then an average pooling of size pool_att.

        # Every level is a residual block followed by a strided convolution that halves the
        # resolution and doubles the number of channels.
        self.res = nn.Sequential()
        width = image_shape[0]

        for level, kernel in enumerate(res_kernels, start=1):
            self.res.add_module(
                name=f"CNNRes{level}",
                module=CNNRes(in_channels=width, kernel_size=kernel),
            )
            self.res.add_module(
                name=f"CNN{level}",
                module=nn.Conv2d(
                    in_channels=width,
                    out_channels=2 * width,
                    kernel_size=kernel,
                    stride=(2, 2),
                ),
            )
            width *= 2

        # Output pooling.
        self.pool = nn.AvgPool2d(kernel_size=pool)

    def forward(self, xb):
        """
        :param xb: Batch of input images.
        :return: The pooled features of the residual stage.
        """
        # The residual stage reads the raw input while the attention stage is disabled; it has to
        # read the attention output instead once that stage is restored.
        features = self.res(xb)
        return self.pool(features)

The architecture of the network is fixed to avoid having too much clutter in the hyperparameters section.

In [None]:
class Classification(nn.Module):
    """
    Classifier built on top of a pretrained torchvision ResNet-50.
    """

    def __init__(self, out_features: int, **kwargs):
        """
        :param out_features: Number of classes to predict.
        :param kwargs: Ignored.
        """
        super().__init__()

        # An earlier head (Extraction backbone followed by a bidirectional GRU) is disabled in favour
        # of the pretrained network below.

        # 1x1 convolution turning the single channel input into a 3-channels image for the ResNet.
        self.cnn = nn.Conv2d(in_channels=1, out_channels=3, kernel_size=1)

        # ResNet50; all its parameters are left trainable.
        self.resnet = torchvision.models.resnet50(pretrained=True)

        # Maps the 1000 outputs of the ResNet to the requested number of classes.
        self.fc = nn.Linear(in_features=1000, out_features=out_features)

    def forward(self, xb):
        """
        :param xb: Batch of single channel images.
        :return: The logits, one value per class.
        """
        features = self.resnet(self.cnn(xb))
        return self.fc(features)


### Module

In [None]:
class JointClassification(pl.LightningModule):
    """
    Lightning module training the Classification network with a cross entropy loss.
    """

    def __init__(self, out_features: int, **kwargs):
        """
        :param out_features: Number of classes to predict.
        :param kwargs: Extra hyperparameters saved in self.hparams; configure_optimizers() reads the
        "optim" entry.
        """
        super(JointClassification, self).__init__()
        self.save_hyperparameters()
        self.model = Classification(out_features=out_features)

        self.loss = nn.CrossEntropyLoss()

        # One independent metric per phase, so that their states are never mixed.
        metric = torchmetrics.Accuracy()
        self.train_accuracy = metric.clone()
        self.val_accuracy = metric.clone()
        self.test_accuracy = metric.clone()

    def forward(self, xb):
        """
        :param xb: Batch of spectrograms.
        :return: The logits and the predicted class indices.
        """
        logits = self.model(xb)
        preds = torch.argmax(logits, dim=-1)
        return logits, preds

    def step(self, x: torch.Tensor, y: torch.Tensor):
        """
        Shared forward pass and loss computation.
        :param x: Batch of spectrograms.
        :param y: Batch of targets.
        :return: A dictionary with "logits", "preds" and "loss".
        """
        logits, preds = self(x)
        loss = self.loss(logits, y)
        return {"logits": logits, "preds": preds, "loss": loss}

    def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """
        :param batch: Dictionary with the "spectrograms" and "targets" of the batch.
        :param batch_idx: Index of the batch.
        :return: The training loss.
        """
        targets = batch["targets"]
        specs = batch["spectrograms"]
        out_step = self.step(x=specs, y=targets)

        self.train_accuracy(out_step["preds"], targets)
        # The metric object itself is logged so that Lightning computes and resets it; calling
        # compute() by hand never reset the state and the accuracy leaked across epochs.
        self.log_dict(
            {"train_loss": out_step["loss"], "train_acc": self.train_accuracy}
        )
        return out_step["loss"]

    def validation_step(self, batch: Any, batch_idx: int):
        """
        :param batch: Dictionary with the "spectrograms" and "targets" of the batch.
        :param batch_idx: Index of the batch.
        :return: The validation loss.
        """
        targets = batch["targets"]
        specs = batch["spectrograms"]
        out_step = self.step(x=specs, y=targets)

        self.val_accuracy(out_step["preds"], targets)
        # See training_step() for why the metric object is logged instead of compute().
        self.log_dict(
            {"val_loss": out_step["loss"], "val_acc": self.val_accuracy}
        )
        return out_step["loss"]

    def configure_optimizers(self):
        """
        Build the optimizer, and the scheduler when enabled, from self.hparams.optim.
        :return: A dictionary with the "optimizer" and, if "use_lr_scheduler" is set, the "lr_scheduler".
        """
        optim_cfg = self.hparams.optim
        optimizer_cfg = optim_cfg["optimizer"]

        opt = optimizer_cfg["fn"](
            params=self.parameters(),
            lr=optimizer_cfg["lr"],
            betas=optimizer_cfg["betas"],
            eps=optimizer_cfg["eps"],
            weight_decay=optimizer_cfg["weight_decay"],
        )

        if not optim_cfg["use_lr_scheduler"]:
            return {"optimizer": opt}

        scheduler_cfg = optim_cfg["lr_scheduler"]
        scheduler = scheduler_cfg["fn"](
            optimizer=opt,
            T_0=scheduler_cfg["T_0"],
            T_mult=scheduler_cfg["T_mult"],
            eta_min=scheduler_cfg["eta_min"],
            last_epoch=scheduler_cfg["last_epoch"],
            verbose=scheduler_cfg["verbose"],
        )
        return {"optimizer": opt, "lr_scheduler": scheduler}

## Training

### Environmental variables

In [None]:
# Root of the project on the mounted Google Drive.
PROJECT_ROOT = Path("/content/gdrive/My Drive/Colab Notebooks/Birdcalls")

# Training split of the debug datasets. The spectrograms and targets files live in the "joint"
# directory of the split.
TRAIN_BIRDCALLS_DEBUG = PROJECT_ROOT / "out/debug_datasets/train/birdcalls"
TRAIN_BIRDCALLS_DEBUG_SPECTROGRAMS = PROJECT_ROOT / "out/debug_datasets/train/joint/spectrograms.pt"
TRAIN_BIRDCALLS_DEBUG_TARGETS = PROJECT_ROOT / "out/debug_datasets/train/joint/targets.pt"

# Validation split of the debug datasets, it points to the "balanced" tensor files.
VAL_BIRDCALLS_DEBUG = PROJECT_ROOT / "out/debug_datasets/val/birdcalls"
VAL_BIRDCALLS_DEBUG_SPECTROGRAMS = PROJECT_ROOT / "out/debug_datasets/val/joint/spectrograms_balanced.pt"
VAL_BIRDCALLS_DEBUG_TARGETS = PROJECT_ROOT / "out/debug_datasets/val/joint/targets_balanced.pt"

Input dictionaries

In [None]:
# Dataloader settings, one value per split.
num_workers = dict(train=12, val=12)
batch_size = dict(train=8, val=8)
shuffle = dict(train=True, val=False)

# Optimizer: "fn" is the class to instantiate, the other entries are its keyword arguments.
optimizer = dict(
    fn=torch.optim.Adam,
    lr=1e-4,
    betas=[0.9, 0.999],
    eps=1e-08,
    weight_decay=0,
)

# The scheduler below is built only when this flag is set.
use_lr_scheduler = False

lr_scheduler = dict(
    fn=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    T_0=10,
    T_mult=2,
    eta_min=0,
    last_epoch=-1,
    verbose=True,
)

# Everything related to the optimization, grouped in a single dictionary.
optim = dict(
    optimizer=optimizer,
    use_lr_scheduler=use_lr_scheduler,
    lr_scheduler=lr_scheduler,
)

# Callbacks
lr_monitor = dict(logging_interval="step", log_momentum=False)

early_stopping = dict(patience=42, verbose=False)

model_checkpoints = dict(save_top_k=2, verbose=False)

callbacks = {
    "monitor_metric": "train_loss",
    "monitor_metric_mode": "min",
    # Uncomment a section to enable the corresponding callback.
    # "lr_monitor": lr_monitor,
    # "model_checkpoints": model_checkpoints,
    # "early_stopping": early_stopping
}

# Trainer
train = dict(
    deterministic=True,
    random_seed=42,
    val_check_interval=1.0,
    progress_bar_refresh_rate=20,
    fast_dev_run=True,  # True for debug purposes.
    gpus=-1 if torch.cuda.is_available() else 0,
    precision=32,
    max_steps=100,
    max_epochs=25,
    accumulate_grad_batches=1,
    num_sanity_val_steps=2,
    gradient_clip_val=10.0,
)

In [None]:
# Call this only once!
if train["deterministic"]:
    seed_everything(train["random_seed"])

### Wandb

In [None]:
wandb.login()

### Run

In [None]:
# Data module built from the per-split dataloader settings defined above.
datamodule = JointDataModule(num_workers=num_workers,
                        batch_size=batch_size,
                        shuffle=shuffle)

# The optimization settings are passed as an extra hyperparameter, the model predicts 397 classes.
model = JointClassification(optim=optim, out_features=397)
# Callbacks are currently disabled, they would be built with:
# callbacks=build_callbacks(callbacks=callbacks),


# Weights & Biases logger; the config records the hyperparameters of this run together with the
# torchinfo summary of the model.
wandb_logger = WandbLogger(
    project="Birdcalls classification",
    config={
        "batch_size": batch_size['train'],
        "learning_rate": optimizer['lr'],
        "optimizer": optimizer['fn'],
        "betas": optimizer["betas"],
        "eps": optimizer["eps"],
        "weight_decay": optimizer["weight_decay"],
        "lr_scheduler": use_lr_scheduler,
        "T_0": lr_scheduler["T_0"],
        "T_mult": lr_scheduler["T_mult"],
        "eta_min": lr_scheduler["eta_min"],
        "last_epoch": lr_scheduler["last_epoch"],
        "dataset": "Bird CLEF 2021",
        "summary": summary(model),
        }
)

# Only "deterministic", "gpus" and "max_epochs" are read from the train dictionary here, its other
# entries are not passed to the trainer.
trainer = pl.Trainer(
        logger=wandb_logger,
        deterministic=train["deterministic"],
        gpus=train["gpus"],
        max_epochs=train["max_epochs"],
        # No callback is attached while this line stays commented out.
        # callbacks=callbacks
    )

In [None]:
print(summary(model))
print(model)

Fit

In [None]:
trainer.fit(model=model, datamodule=datamodule)

Validation

In [None]:
trainer.validate(model=model, datamodule=datamodule)

Quit W&B

In [None]:
wandb.finish()