# Joint classification

Author: Edoardo De Matteis

This notebook was used to train the model for the birdcall classification task, i.e. a system that, given an audio window, predicts which bird is singing. We do not assume that a bird is always singing, therefore the model must also be able to predict the case in which no bird is present.

## Dependencies

In [None]:
!pip install torchaudio
!pip install torchinfo
!pip install pytorch_lightning
!pip install wandb -qqq

In [None]:
from pathlib import Path
from typing import Dict, Optional, Tuple, Union, List, Any

import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchmetrics
import pytorch_lightning as pl
from pytorch_lightning import Callback, seed_everything
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    LearningRateMonitor
)
from torchinfo import summary
import json
import wandb

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

In [None]:
%cd /content/gdrive/MyDrive/Colab\ Notebooks/Birdcalls

In [None]:
%load_ext autoreload
%autoreload 2

Let's define some basic functions that will help us:

In [None]:
def build_callbacks(callbacks: Dict) -> List[Callback]:
    """
    Build the PyTorch Lightning callbacks described by a configuration dictionary.
    :param callbacks: Configuration dictionary. Each of the optional keys "lr_monitor", "early_stopping" and
    "model_checkpoints" enables the corresponding callback and holds its settings. The top-level keys
    "monitor_metric" and "monitor_metric_mode" are read whenever early stopping or checkpointing is enabled.
    :return: The instantiated callbacks, in the order: learning rate monitor, early stopping, checkpointing.
    """
    out_callbacks: List[Callback] = []

    if "lr_monitor" in callbacks:
        lr_cfg = callbacks["lr_monitor"]
        out_callbacks.append(
            LearningRateMonitor(
                logging_interval=lr_cfg["logging_interval"],
                log_momentum=lr_cfg["log_momentum"],
            )
        )

    if "early_stopping" in callbacks:
        stop_cfg = callbacks["early_stopping"]
        out_callbacks.append(
            EarlyStopping(
                monitor=callbacks["monitor_metric"],
                mode=callbacks["monitor_metric_mode"],
                patience=stop_cfg["patience"],
                verbose=stop_cfg["verbose"],
            )
        )

    if "model_checkpoints" in callbacks:
        ckpt_cfg = callbacks["model_checkpoints"]
        # This function used to read the misspelled key "verbsose". Prefer the correct
        # spelling and fall back to the old one so existing configurations keep working.
        verbose = ckpt_cfg["verbose"] if "verbose" in ckpt_cfg else ckpt_cfg["verbsose"]
        out_callbacks.append(
            ModelCheckpoint(
                monitor=callbacks["monitor_metric"],
                mode=callbacks["monitor_metric_mode"],
                save_top_k=ckpt_cfg["save_top_k"],
                verbose=verbose,
            )
        )

    return out_callbacks

def cnn_size(
    input: Tuple[int, int],
    kernel: Union[int, Tuple[int, int]],
    padding: Union[int, Tuple[int, int]] = 0,
    stride: Union[int, Tuple[int, int]] = 1,
) -> Tuple[int, int]:
    """
    Compute the spatial size produced by a convolutional layer.
    :param input: Size of the input image along its two spatial axes.
    :param kernel: Kernel size; a single int is used for both axes.
    :param padding: Padding size; a single int is used for both axes.
    :param stride: Stride; a single int is used for both axes.
    :return: The output size along the two axes, truncated to integers.
    """

    def as_pair(value):
        # A scalar setting applies to both axes.
        return (value, value) if isinstance(value, int) else value

    k, p, s = as_pair(kernel), as_pair(padding), as_pair(stride)

    # Standard convolution arithmetic, evaluated independently on each axis.
    raw = [(input[axis] - k[axis] + 2 * p[axis]) / s[axis] + 1 for axis in (0, 1)]
    return int(raw[0]), int(raw[1])

def load_vocab(path: Union[str, Path]) -> Dict:
    """
    Load vocabulary from a JSON file.
    :param path: Path to file.
    :return: Dictionary object i.e. the vocabulary.
    """
    # The context manager closes the file even when parsing fails; JSON text is read as UTF-8
    # instead of relying on the platform default encoding.
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)

### Dataset

In [None]:
class JointDataset(Dataset):
    """
    Dataset of spectrograms and their targets for the joint classification task.

    In this notebook instances are populated through `JointDataset.load`, which reads precomputed tensors.
    """

    def __init__(
        self,
        csv_path: Union[str, Path, None],
        online: bool,
        debug: int,
        load: bool,
        **kwargs
    ):
        """
        :param csv_path: Path of the training CSV file.
        :param online: If true tensors are computed on-the-fly by the dataloader, otherwise they are all precomputed.
        :param debug: Defines the size of the reduced dataset (it is shuffled beforehand) we want to use, any number
        below or equal to 0 means that we keep the whole dataset.
        :param load: If true we do not compute anything and will load values from a file.
        :param kwargs: Ignored.

        Note that this constructor only stores `online`; `csv_path`, `debug` and `load` are accepted but unused here.
        """
        super(JointDataset, self).__init__()

        self.online = online

        # Start empty so that len() and indexing are well defined on a dataset that has not been
        # populated yet (previously these attributes were only annotated, never assigned).
        # JointDataset.load() overwrites all three.
        self.len: int = 0
        self.spectrograms: torch.Tensor = torch.empty(0)
        self.targets: torch.Tensor = torch.empty(0)

    @staticmethod
    def load(
        spectrograms_path: Union[str, Path], targets_path: Union[str, Path], **kwargs
    ):
        """
        Load a dataset whose spectrograms and targets are read from .pt files.
        :param spectrograms_path: Path of the spectrograms tensor file.
        :param targets_path: Path of the targets tensor file.
        :param kwargs: Ignored.
        :return: A JointDataset object with populated tensors.
        """
        ds = JointDataset(csv_path=None, online=False, debug=-1, load=True)

        ds.spectrograms = torch.load(spectrograms_path)
        ds.targets = torch.load(targets_path)
        # The number of targets defines the dataset length.
        ds.len = len(ds.targets)

        return ds

    def __len__(self):
        """
        :return: Length of the dataset.
        """
        return self.len

    def __getitem__(self, item):
        """
        :param item: Index of the item to retrieve.
        :return: The item-th entry, as a dictionary.
        """
        if self.online:
            # NOTE(review): none of these per-row attributes is assigned anywhere in this class,
            # so the online path presumably relies on code that is not part of this notebook.
            return {
                "row_id": self.row_id[item],
                "site": self.site[item],
                "audio_id": self.audio_id[item],
                "seconds": self.seconds[item],
                "birds": self.birds[item],
            }

        return {
            "spectrograms": self.spectrograms[item],
            "targets": self.targets[item],
        }

In [None]:
class JointDataModule(pl.LightningDataModule):
    """
    Data module providing the train, validation and test dataloaders of the joint task.

    It derives from `pl.LightningDataModule` (not `pl.LightningModule`) since it only owns datasets and
    dataloaders and is the object handed to the trainer through its `datamodule` argument.
    """

    def __init__(
        self,
        num_workers: Dict,
        batch_size: Dict,
        shuffle: Dict,
        **kwargs,
    ):
        """
        :param num_workers: Number of workers for each of the "train", "val" and "test" splits.
        :param batch_size: Batch size for each of the "train", "val" and "test" splits.
        :param shuffle: Whether to shuffle each of the "train", "val" and "test" splits.
        :param kwargs: Ignored.
        """
        super().__init__()
        # NOTE(review): `num_workers` is stored but not forwarded to the DataLoaders, as in the
        # previous implementation; confirm whether that is intended before wiring it in.
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.shuffle = shuffle

        # These attributes will be populated after self.setup() call.
        self.train_ds: Optional[Dataset] = None
        self.val_ds: Optional[Dataset] = None
        self.test_ds: Optional[Dataset] = None

    def setup(self, stage: Optional[str] = None) -> None:
        """
        Load the precomputed datasets needed by the given stage.
        :param stage: "fit" loads the train and validation sets, "test" loads the test set, None loads all of them.
        """
        if stage is None or stage == "fit":
            # Train
            self.train_ds = JointDataset.load(
                spectrograms_path=TRAIN_SPECTROGRAMS,
                targets_path=TRAIN_TARGETS
            )

            # Val
            self.val_ds = JointDataset.load(
                spectrograms_path=VAL_SPECTROGRAMS,
                targets_path=VAL_TARGETS
            )

        if stage is None or stage == "test":
            # Test
            self.test_ds = JointDataset.load(
                spectrograms_path=TEST_SPECTROGRAMS,
                targets_path=TEST_TARGETS
            )

    def _make_dataloader(self, dataset: Optional[Dataset], split: str) -> DataLoader:
        """
        Build the dataloader of one split from its per-split settings.
        :param dataset: Dataset to wrap.
        :param split: One of "train", "val" or "test"; used as key in the settings dictionaries.
        :return: The dataloader.
        """
        return DataLoader(
            dataset=dataset,
            batch_size=self.batch_size[split],
            shuffle=self.shuffle[split],
        )

    def train_dataloader(
        self,
    ) -> Union[DataLoader, List[DataLoader], Dict[str, DataLoader]]:
        """
        :return: Dataloader over the training set.
        """
        return self._make_dataloader(self.train_ds, "train")

    def val_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        """
        :return: Dataloader over the validation set.
        """
        return self._make_dataloader(self.val_ds, "val")

    def test_dataloader(self) -> Union[DataLoader, List[DataLoader]]:
        """
        :return: Dataloader over the test set.
        """
        return self._make_dataloader(self.test_ds, "test")

### Model

For the full model definitions, refer to the project modules Birdcalls.src.pl_module and Birdcalls.our.models.


In [None]:
class CNNRes(nn.Module):
    """
    Residual block made of three pairs of convolutions, each pair wrapped by a skip connection.

    All six convolutions keep the number of channels unchanged.
    """

    def __init__(self, in_channels: int, kernel_size: int):
        """
        :param in_channels: Number of channels of the input (and of the output).
        :param kernel_size: Side of the square convolutional kernels.
        """
        super(CNNRes, self).__init__()

        # This padding keeps the spatial size unchanged, which the skip connections require.
        # It only does so for odd kernel sizes, hence odd kernels should be used.
        pad = kernel_size // 2

        # The layers are registered as conv1 ... conv6.
        for index in range(1, 7):
            setattr(
                self,
                f"conv{index}",
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=in_channels,
                    kernel_size=(kernel_size, kernel_size),
                    padding=pad,
                ),
            )

    def forward(self, res):
        """
        :param res: Input feature map.
        :return: Feature map after the three residual stages.
        """
        # Each stage uses its own pair of convolutions; previously conv1 was applied six times
        # and conv2 ... conv6 were never used.
        stages = (
            (self.conv1, self.conv2),
            (self.conv3, self.conv4),
            (self.conv5, self.conv6),
        )
        for first, second in stages:
            res = second(first(res)) + res
        return res

In [None]:
class Extraction(nn.Module):
    """
    Feature extraction backbone: residual blocks interleaved with strided convolutions, then average pooling.
    """

    def __init__(
        self,
        image_shape: Tuple[int, int, int],
        res_kernels: List,
        pool: int,
    ):
        """
        :param image_shape: Shape of the input image; only its first entry (the channels) is read here.
        :param res_kernels: One kernel size per level; each level is a residual block followed by a convolution.
        :param pool: Kernel size of the final average pooling.
        """
        super(Extraction, self).__init__()
        self.level = len(res_kernels)

        # Every level appends a residual block and a stride-2 convolution that doubles the channels.
        stages = nn.Sequential()
        channels = image_shape[0]
        for index, kernel_size in enumerate(res_kernels, start=1):
            stages.add_module(
                name=f"CNNRes{index}",
                module=CNNRes(in_channels=channels, kernel_size=kernel_size),
            )
            stages.add_module(
                name=f"CNN{index}",
                module=nn.Conv2d(
                    in_channels=channels,
                    out_channels=2 * channels,
                    kernel_size=kernel_size,
                    stride=(2, 2),
                ),
            )
            channels = 2 * channels
        self.res = stages

        # Output pooling.
        self.pool = nn.AvgPool2d(kernel_size=pool)

    def forward(self, xb):
        """
        :param xb: Batch of input images.
        :return: The pooled feature maps.
        """
        return self.pool(self.res(xb))

In [None]:
class Classification(nn.Module):
    """
    Classifier made of the extraction backbone, a bidirectional GRU and a linear output layer.
    """

    def __init__(self, out_features: int, **kwargs):
        """
        :param out_features: Number of output logits.
        :param kwargs: Ignored.
        """
        super().__init__()

        # Backbone with two levels (kernels 3 and 5) on single-channel 128x313 inputs.
        self.ext = Extraction(image_shape=(1, 128, 313), res_kernels=[3, 5], pool=1)

        # input_size is hard-coded: it has to equal the flattened size of the backbone output.
        self.gru = nn.GRU(
            input_size=9120,
            hidden_size=512,
            num_layers=1,
            bidirectional=True,
            dropout=0,
        )

        # The GRU is bidirectional, hence 2 * 512 input features.
        self.fc = nn.Linear(in_features=1024, out_features=out_features)

    def forward(self, xb):
        """
        :param xb: Batch of input images.
        :return: The logits, one row per batch entry.
        """
        # Feature extraction backbone.
        features = self.ext(xb)

        # Flatten every sample and put a length-one sequence axis in front of the batch axis.
        batch, channels, dim_a, dim_b = features.shape
        flat = features.reshape(batch, 1, channels * dim_a * dim_b)
        sequence = flat.transpose(0, 1)

        # Prediction head.
        hidden, _ = self.gru(sequence)
        return self.fc(hidden.squeeze(0))


In [None]:
class JointClassification(pl.LightningModule):
    """
    Lightning module that trains and evaluates the `Classification` model with a cross-entropy loss,
    tracking accuracy, precision and recall separately for the train, validation and test stages.
    """

    def __init__(self, out_features: int, **kwargs):
        """
        :param out_features: Number of classes predicted by the model.
        :param kwargs: Extra hyperparameters, stored in `self.hparams`; `configure_optimizers` reads `optim`.
        """
        super(JointClassification, self).__init__()
        self.save_hyperparameters()
        self.model = Classification(out_features=out_features)

        self.loss = nn.CrossEntropyLoss()

        accuracy = torchmetrics.Accuracy()
        self.train_accuracy = accuracy.clone()
        self.val_accuracy = accuracy.clone()
        self.test_accuracy = accuracy.clone()

        precision = torchmetrics.Precision()
        self.train_precision = precision.clone()
        self.val_precision = precision.clone()
        self.test_precision = precision.clone()

        recall = torchmetrics.Recall()
        self.train_recall = recall.clone()
        self.val_recall = recall.clone()
        self.test_recall = recall.clone()

        # Predictions and targets of the test batches, gathered to build one confusion matrix.
        self._test_preds: List[torch.Tensor] = []
        self._test_targets: List[torch.Tensor] = []

    def forward(self, xb):
        """
        :param xb: Batch of spectrograms.
        :return: The logits and the index of the highest logit of each entry.
        """
        logits = self.model(xb)
        preds = torch.argmax(logits, dim=-1)
        return logits, preds

    def step(self, x: torch.Tensor, y: torch.Tensor):
        """
        Run the model on a batch and compute its loss.
        :param x: Batch of spectrograms.
        :param y: Targets of the batch.
        :return: Dictionary with the keys "logits", "preds" and "loss".
        """
        logits, preds = self(x)
        loss = self.loss(logits, y)
        return {"logits": logits, "preds": preds, "loss": loss}

    def _evaluate(self, batch: Any, stage: str):
        """
        Run `step` on a batch and update the metrics of the given stage.
        :param batch: Dictionary with the keys "spectrograms" and "targets".
        :param stage: One of "train", "val" or "test".
        :return: The output of `step` and the stage metrics keyed by the name they are logged under.
        """
        targets = batch["targets"]
        out_step = self.step(x=batch["spectrograms"], y=targets)

        metrics = {
            f"{stage}_acc": getattr(self, f"{stage}_accuracy"),
            f"{stage}_prec": getattr(self, f"{stage}_precision"),
            f"{stage}_rec": getattr(self, f"{stage}_recall"),
        }
        for metric in metrics.values():
            metric(out_step["preds"], targets)

        return out_step, metrics

    def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """
        :param batch: Dictionary with the keys "spectrograms" and "targets".
        :param batch_idx: Index of the batch.
        :return: The training loss.
        """
        out_step, metrics = self._evaluate(batch, "train")

        # The metric objects themselves are logged (rather than the result of `.compute()`), so that
        # Lightning takes care of computing and resetting them; logging `.compute()` reported values
        # accumulated over every batch seen so far, across epochs.
        self.log_dict({"train_loss": out_step["loss"], **metrics})
        return out_step["loss"]

    def validation_step(self, batch: Any, batch_idx: int):
        """
        :param batch: Dictionary with the keys "spectrograms" and "targets".
        :param batch_idx: Index of the batch.
        :return: The validation loss.
        """
        out_step, metrics = self._evaluate(batch, "val")

        # See training_step for why the metric objects are logged directly.
        self.log_dict({"val_loss": out_step["loss"], **metrics})
        return out_step["loss"]

    def test_step(self, batch: Any, batch_idx: int):
        """
        :param batch: Dictionary with the keys "spectrograms" and "targets".
        :param batch_idx: Index of the batch.
        """
        out_step, metrics = self._evaluate(batch, "test")

        # See training_step for why the metric objects are logged directly.
        self.log_dict(metrics)

        # Keep the batch results; the confusion matrix is built once in on_test_epoch_end instead of
        # being logged (and overwritten) at every batch.
        self._test_preds.append(out_step["preds"].detach().cpu())
        self._test_targets.append(batch["targets"].detach().cpu())

    def on_test_epoch_end(self) -> None:
        """
        Log to W&B a single confusion matrix covering all the test batches.
        """
        if not self._test_preds:
            return

        preds = torch.cat(self._test_preds).numpy()
        targets = torch.cat(self._test_targets).numpy()
        self._test_preds.clear()
        self._test_targets.clear()

        # Get the list of classes, ordered by index; "nocall" is appended as the last class.
        ordered = sorted(load_vocab(BIRD2IDX).items(), key=lambda item: int(item[1]))
        classes = [c for c, _ in ordered] + ["nocall"]

        self.logger.experiment.log(
            {
                "conf_mat": wandb.plot.confusion_matrix(
                    probs=None,
                    preds=preds,
                    y_true=targets,
                    class_names=classes,
                )
            }
        )

    def configure_optimizers(self):
        """
        Build the optimizer, and optionally the learning rate scheduler, from `self.hparams.optim`.
        :return: Dictionary with the key "optimizer" and, if a scheduler is used, the key "lr_scheduler".
        """
        optim_cfg = self.hparams.optim
        opt_cfg = optim_cfg["optimizer"]

        opt = opt_cfg["fn"](
            params=self.parameters(),
            lr=opt_cfg["lr"],
            betas=opt_cfg["betas"],
            eps=opt_cfg["eps"],
            weight_decay=opt_cfg["weight_decay"],
        )

        if not optim_cfg["use_lr_scheduler"]:
            return {"optimizer": opt}

        sched_cfg = optim_cfg["lr_scheduler"]
        scheduler = sched_cfg["fn"](
            optimizer=opt,
            T_0=sched_cfg["T_0"],
            T_mult=sched_cfg["T_mult"],
            eta_min=sched_cfg["eta_min"],
            last_epoch=sched_cfg["last_epoch"],
            verbose=sched_cfg["verbose"],
        )
        return {"optimizer": opt, "lr_scheduler": scheduler}

## Training

Environment and setup variables.

In [None]:
# Project directory inside the mounted Google Drive, and the folder of the precomputed tensors.
_BIRDCALLS_ROOT = Path("/content/gdrive/My Drive/Colab Notebooks/Birdcalls")
_PRECOMPUTED = _BIRDCALLS_ROOT / "out" / "precomputed"

# Spectrograms and targets of each split.
TRAIN_SPECTROGRAMS = _PRECOMPUTED / "train" / "joint" / "spectrograms.pt"
TRAIN_TARGETS = _PRECOMPUTED / "train" / "joint" / "targets.pt"

VAL_SPECTROGRAMS = _PRECOMPUTED / "val" / "joint" / "spectrograms.pt"
VAL_TARGETS = _PRECOMPUTED / "val" / "joint" / "targets.pt"

TEST_SPECTROGRAMS = _PRECOMPUTED / "test" / "joint" / "spectrograms.pt"
TEST_TARGETS = _PRECOMPUTED / "test" / "joint" / "targets.pt"

# Vocabulary file, read through load_vocab.
BIRD2IDX = _BIRDCALLS_ROOT / "out" / "vocabs" / "bird2idx.json"

In [None]:
# Dataloader settings, one entry per split.
num_workers = dict(train=12, val=12, test=12)
batch_size = dict(train=8, val=8, test=8)
shuffle = dict(train=True, val=False, test=False)

# Optimizer: "fn" is the optimizer class, the other entries are its arguments.
optimizer = dict(
    fn=torch.optim.Adam,
    lr=1e-4,
    betas=[0.9, 0.999],
    eps=1e-08,
    weight_decay=0,
)

# The scheduler below is only instantiated when this flag is true.
use_lr_scheduler = False

# Learning rate scheduler: "fn" is the scheduler class, the other entries are its arguments.
lr_scheduler = dict(
    fn=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    T_0=10,
    T_mult=2,
    eta_min=0,
    last_epoch=-1,
    verbose=True,
)

# Everything the model needs to configure its optimization.
optim = dict(
    optimizer=optimizer,
    use_lr_scheduler=use_lr_scheduler,
    lr_scheduler=lr_scheduler,
)

# Trainer settings.
train = dict(
    deterministic=True,
    random_seed=42,
    val_check_interval=1.0,
    progress_bar_refresh_rate=20,
    fast_dev_run=False,  # True for debug purposes.
    gpus=-1 if torch.cuda.is_available() else 0,
    precision=32,
    max_steps=100,
    max_epochs=20,
    accumulate_grad_batches=1,
    num_sanity_val_steps=2,
    gradient_clip_val=10.0,
)

In [None]:
# Seed the random number generators when a deterministic run is requested.
# Run this cell only once, before the datamodule and the model are created.
if train["deterministic"]:
    seed_everything(train["random_seed"])

W&B login.

In [None]:
wandb.login()

Let's set up the trainer so that we can run it.

In [None]:
# Data and model.
datamodule = JointDataModule(
    num_workers=num_workers,
    batch_size=batch_size,
    shuffle=shuffle,
)

model = JointClassification(optim=optim, out_features=398)

# Hyperparameters recorded in the configuration of the W&B run.
wandb_config = {
    "batch_size": batch_size["train"],
    "learning_rate": optimizer["lr"],
    "optimizer": optimizer["fn"],
    **{key: optimizer[key] for key in ("betas", "eps", "weight_decay")},
    "lr_scheduler": use_lr_scheduler,
    **{key: lr_scheduler[key] for key in ("T_0", "T_mult", "eta_min", "last_epoch")},
    "dataset": "Bird CLEF 2021",
    "summary": summary(model),
}

wandb_logger = WandbLogger(project="Joint classification", config=wandb_config)

# NOTE(review): only these three entries of `train` are forwarded to the trainer; the remaining
# ones (e.g. "max_steps", "gradient_clip_val") are currently not used here.
trainer_kwargs = {key: train[key] for key in ("deterministic", "gpus", "max_epochs")}
trainer = pl.Trainer(logger=wandb_logger, **trainer_kwargs)

In [None]:
print(summary(model))
print(model)

Fit

In [None]:
trainer.fit(model=model, datamodule=datamodule)

Validation

In [None]:
trainer.validate(model=model, datamodule=datamodule)

Test

In [None]:
trainer.test(model=model, datamodule=datamodule)

Quit W&B

In [None]:
wandb.finish()