# 1. Build an image segmentation model using pytorch

Use CUDA for training and inferencing:

In [1]:
import os
from typing import List, Tuple

from PIL import Image

import torch
import torch.nn as nn
import torchvision.transforms as T
from torch.utils.data import Dataset

In [2]:
device: torch.device = torch.device('cuda')

## Datasets:

First, we need to create the `BirdSoundDataset` class:

In [3]:
class BirdSoundDataset(Dataset):

    def __init__(self, dataroot: str, resolution: Tuple[int, int]):
        super().__init__()
        self.dataroot: str = dataroot
        self.resolution: Tuple[int, int] = resolution
        self.image_folder: str = f'{self.dataroot}/images'
        self.mask_folder: str = f'{self.dataroot}/masks'
        self.image_filenames: List[str] = sorted(os.listdir(self.image_folder))
        self.mask_filenames: List[str] = sorted(os.listdir(self.mask_folder))
        assert len(self.image_filenames) == len(self.mask_filenames)

        self.__transformer = T.Compose([
            T.ToTensor(),
            T.Grayscale(num_output_channels=1),
            T.Resize(size=self.resolution),
        ])

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        image_filename: str = self.image_filenames[idx]
        mask_filename: str = self.mask_filenames[idx]
        assert image_filename == mask_filename
        image_path: str = f'{self.image_folder}/{image_filename}'
        mask_path: str = f'{self.mask_folder}/{mask_filename}'
        image_tensor: torch.Tensor = self.__transformer(Image.open(image_path))
        mask_tensor: torch.Tensor = self.__transformer(Image.open(mask_path))
        mask_tensor: torch.Tensor = (mask_tensor != 0).to(dtype=torch.int8)
        return image_tensor, mask_tensor
    
    def __len__(self) -> int:
        return len(self.image_filenames)

Create the train, validation, test datasets:

In [4]:
train_dataset = BirdSoundDataset(dataroot='data/train', resolution=(128, 512))
val_dataset = BirdSoundDataset(dataroot='data/valid', resolution=(128, 512))
test_dataset = BirdSoundDataset(dataroot='data/test', resolution=(128, 512))

Load the datasets to dataloaders

In [5]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(train_dataset, batch_size=len(train_dataset))
val_dataloader = DataLoader(val_dataset, batch_size=len(val_dataset))
test_dataloader = DataLoader(test_dataset, batch_size=len(test_dataset))

Let's inspect the data shapes:

Training set:

In [6]:
sample_train_images, sample_train_masks = next(iter(train_dataloader))
print(sample_train_images.shape)
print(sample_train_masks.shape)

torch.Size([1000, 1, 128, 512])
torch.Size([1000, 1, 128, 512])


Validation set:

In [7]:
sample_val_images, sample_val_masks = next(iter(val_dataloader))
print(sample_val_images.shape)
print(sample_val_masks.shape)

torch.Size([200, 1, 128, 512])
torch.Size([200, 1, 128, 512])


Test set:

In [8]:
sample_test_images, sample_test_masks = next(iter(test_dataloader))
print(sample_test_images.shape)
print(sample_test_masks.shape)

torch.Size([300, 1, 128, 512])
torch.Size([300, 1, 128, 512])


As can be seen, we treat the input images as grayscale (`n_channels=1`), the output mask has the same resolution with the input images, it also has 1 channel that encodes the groundtruth mask.

In [9]:
# Delete unnecessary variables to save memory:
del sample_train_images, sample_train_masks, sample_val_images, sample_val_masks, sample_test_images, sample_test_masks

## Utilities:

We also need some utility classes to control the training and inferencing process. 

First, create the `Accumulator` to keep track of the loss and metrics:

In [10]:
import os
import pathlib
import time
from typing import Optional, Dict, TextIO, Any, Tuple, NamedTuple
from collections import defaultdict
import datetime as dt
import copy
import inspect

import torch
import torch.nn as nn
from torch.optim import Optimizer

In [11]:
class Accumulator:
    """
    A utility class for accumulating values for multiple metrics.
    """

    def __init__(self) -> None:
        self.__records: defaultdict[str, float] = defaultdict(float)

    def add(self, **kwargs: Any) -> None:
        """
        Add values to the accumulator.

        Parameters:
            - **kwargs: named metric and the value is the amount to add.
        """
        metric: str
        value: float
        for metric, value in kwargs.items():
            # Each keyword argument represents a metric name and its value to be added
            self.__records[metric] += value
    
    def reset(self) -> None:
        """
        Reset the accumulator by clearing all recorded metrics.
        """
        self.__records.clear()

    def __getitem__(self, key: str) -> float:
        """
        Retrieve a record by key.

        Parameters:
            - key (str): The record key name.

        Returns:
            - float: The record value.
        """
        return self.__records[key]

We also need a `EarlyStopping` class to terminate the training process when some evaluation metric stops improving:

In [12]:
class EarlyStopping:
    """
    A simple early stopping utility to terminate training when a monitored metric stops improving.

    Attributes:
        - patience (int): The number of epochs with no improvement after which training will be stopped.
        - tolerance (float): The minimum change in the monitored metric to qualify as an improvement,
        - considering the direction of the metric being monitored.
        - bestscore (float): The best score seen so far.
    """
    
    def __init__(self, patience: int, tolerance: float = 0.) -> None:
        """
        Initializes the EarlyStopping instance.
        
        Parameters:
            - patience (int): Number of epochs with no improvement after which training will be stopped.
            - tolerance (float): The minimum change in the monitored metric to qualify as an improvement. 
            Defaults to 0.
        """
        self.patience: int = patience
        self.tolerance: float = tolerance
        self.bestscore: float = float('inf')
        self.__counter: int = 0

    def __call__(self, value: float) -> None:
        """
        Update the state of the early stopping mechanism based on the new metric value.

        Parameters:
            - value (float): The latest value of the monitored metric.
        """
        # Improvement or within tolerance, reset counter
        if value <= self.bestscore + self.tolerance:
            self.bestscore: float = value
            self.__counter: int = 0

        # No improvement, increment counter
        else:
            self.__counter += 1

    def __bool__(self) -> bool:
        """
        Determine if the training process should be stopped early.

        Returns:
            - bool: True if training should be stopped (patience exceeded), otherwise False.
        """
        return self.__counter >= self.patience

We should also create a `Timer` class to time the training and inferencing process:

In [13]:
class Timer:

    """
    A class used to time the duration of epochs and batches.
    """
    def __init__(self) -> None:
        """
        Initialize the Timer.
        """
        self.__epoch_starts: Dict[int, float] = dict()
        self.__epoch_ends: Dict[int, float] = dict()
        self.__batch_starts: Dict[int, Dict[int, float]] = defaultdict(dict)
        self.__batch_ends: Dict[int, Dict[int, float]] = defaultdict(dict)

    def start_epoch(self, epoch: int) -> None:
        """
        Start timing an epoch.

        Parameters:
            epoch (int): The epoch number.
        """
        self.__epoch_starts[epoch] = time.time()

    def end_epoch(self, epoch: int) -> None:
        """
        End timing an epoch.

        Parameters:
            - epoch (int): The epoch number.
        """
        self.__epoch_ends[epoch] = time.time()

    def start_batch(self, epoch: int, batch: Optional[int] = None) -> None:
        """
        Start timing a batch.

        Parameters:
            - epoch (int): The epoch number.
            - batch (int, optional): The batch number. If not provided, the next batch number is used.
        """
        if batch is None:
            if self.__batch_starts[epoch]:
                batch: int = max(self.__batch_starts[epoch].keys()) + 1
            else:
                batch: int = 1
        self.__batch_starts[epoch][batch] = time.time()
    
    def end_batch(self, epoch: int, batch: Optional[int] = None) -> None:
        """
        End timing a batch.

        Parameters:
            - epoch (int): The epoch number.
            - batch (int, optional): The batch number. If not provided, the last started batch number is used.
        """
        if batch is None:
            if self.__batch_starts[epoch]:
                batch: int = max(self.__batch_starts[epoch].keys())
            else:
                raise RuntimeError(f"no batch has started")
        self.__batch_ends[epoch][batch] = time.time()
    
    def time_epoch(self, epoch: int) -> float:
        """
        Get the duration of an epoch.

        Parameters:
            - epoch (int): The epoch number.

        Returns:
            - float: The duration of the epoch in seconds.
        """
        result: float = self.__epoch_ends[epoch] - self.__epoch_starts[epoch]
        if result > 0:
            return result
        else:
            raise RuntimeError(f"epoch {epoch} ends before starts")
    
    def time_batch(self, epoch: int, batch: int) -> float:
        """
        Get the duration of a batch.

        Parameters:
            - epoch (int): The epoch number.
            - batch (int): The batch number.

        Returns:
            - float: The duration of the batch in seconds.
        """
        result: float = self.__batch_ends[epoch][batch] - self.__batch_starts[epoch][batch]
        if result > 0:
            return result
        else:
            raise RuntimeError(f"batch {batch} in epoch {epoch} ends before starts")

A `Logger` class is implemented to log the messages to the standard output and direct it to a log file:

In [14]:
class Logger:

    """
    A class used to log the training process.

    This class provides methods to log messages to a file and the console. 
    """
    def __init__(
        self, 
        logfile: str = f"./.logs/{dt.datetime.now().strftime('%Y%m%d%H%M%S')}"
    ) -> None:
    
        """
        Initialize the logger.

        Parameters:
            - logfile (str, optional): The path to the logfile. 
            Defaults to a file in the .logs directory with the current timestamp.
        """
        self.logfile: pathlib.Path = pathlib.Path(logfile)
        os.makedirs(name=self.logfile.parent, exist_ok=True)
        self._file: TextIO = open(self.logfile, mode='w')

    def log(
        self, 
        epoch: int, 
        n_epochs: int, 
        batch: Optional[int] = None, 
        n_batches: Optional[int] = None, 
        took: Optional[float] = None, 
        **kwargs: Any,
    ) -> None:
        """
        Log a message to console and a log file

        Parameters:
            - epoch (int): The current epoch.
            - n_epochs (int): The total number of epochs.
            - batch (int, optional): The current batch. Defaults to None.
            - n_batches (int, optional): The total number of batches. Defaults to None.
            - took (float, optional): The time it took to process the batch or epoch. Defaults to None.
            - **kwargs: Additional metrics to log.
        """
        suffix: str = ', '.join([f'{metric}: {value:.3e}' for metric, value in kwargs.items()])
        prefix: str = f'Epoch {epoch}/{n_epochs} | '
        if batch is not None:
            prefix += f'Batch {batch}/{n_batches} | '
        if took is not None:
            prefix += f'Took {took:.2f}s | '
        logstring: str = prefix + suffix
        print(logstring)
        self._file.write(logstring + '\n')

    def __del__(self) -> None:
        """
        Close the logfile at garbage collected.
        """
        self._file.close()

The `CheckpointSaver` class is used to regularly save the checkpoints to disk:

In [15]:
class CheckpointSaver:
    """
    A class used to save PyTorch model and optimizer checkpoints.
    """
    def __init__(
        self, 
        model: nn.Module, 
        optimizer: Optimizer,
        dirpath: str,
    ) -> None:
        """
        Initialize the CheckPointSaver.

        Parameters:
            - dirpath (os.PathLike): The directory where the checkpoints to save.
            - model (nn.Module): The class object of the model
            - optimizer_classname (Optimizer): The class object of the optimizer
        """
        self.dirpath: pathlib.Path = pathlib.Path(dirpath)
        # For model reconstruction
        self.model_classname: str = model.__class__.__name__
        signature: inspect.Signature = inspect.signature(model.__init__)
        self.model_kwargs: Dict[str, Any] = {
            p: getattr(model, p) for p in signature.parameters.keys() if p != 'self'
        }
        # For optimizer reconstruction
        self.optimizer_classname: str = optimizer.__class__.__name__
        # ensure the dirpath exists in the file system
        os.makedirs(name=self.dirpath, exist_ok=True)

    def save(
        self, 
        model_states: Dict[str, Any],
        optimizer_states: Dict[str, Any],
        filename: str
    ) -> None:
        """
        Save checkpoint to a .pt file.

        Parameters:
            - model_states (Dict[str, torch.Tensor]): The output of model.state_dict()
            - optimizer_states (Dict[str, Any]): The output of optimizer.state_dict()
            - filename (str): the checkpoint file name
        """
        torch.save(
            obj={
                'model': {
                    'classname' : self.model_classname,
                    'kwargs'    : self.model_kwargs,
                    'states'    : copy.deepcopy(model_states),
                },
                'optimizer': {
                    'classname' : self.optimizer_classname,
                    'states'    : copy.deepcopy(optimizer_states),
                }
            },
            f=os.path.join(self.dirpath, filename)
        )

The `CheckpointLoader` class is used to load the checkpoints back to RAM/VRAM for further training or inferencing:

In [16]:
class CheckpointLoader:
    """
    A class used to load PyTorch model and optimizer checkpoints.
    """
    def __init__(self, checkpoint_path: str) -> None:
        """
        Initialize the CheckpointLoader.

        Parameters:
            - checkpoint_path (str): The path to the checkpoint file.
        """
        self.checkpoint_path: str = checkpoint_path
        self.__checkpoint: Dict[str, Any] = torch.load(checkpoint_path, weights_only=False)

        # Model metadata
        self.model_classname: str = self.__checkpoint['model']['classname']
        self.model_kwargs: Dict[str, Any] = self.__checkpoint['model']['kwargs']
        
        # Optimizer metadata
        self.optimizer_classname: str = self.__checkpoint['optimizer']['classname']

    def load(self, scope: Dict[str, Any]) -> Tuple[nn.Module, Optimizer]:
        """
        Load the model and optimizer from the checkpoint.

        Parameters:
            - scope (Dict[str, Any]): The namespace to look up the model and optimizer object. 
                It's typically the dictionary output of `globals()` or `locals()`
        
        Returns:
            - Tuple[nn.Module, Optimizer]: The model and optimizer loaded from the checkpoint.
        """
        # Check caller's namespace for model object
        if self.model_classname not in scope.keys():
            raise ImportError(
                f'{self.model_classname} is not found in the current namespace, you might need to import it first.'
            )
        
        # Check caller's namespace for optimizer object
        if self.optimizer_classname not in scope.keys():
            raise ImportError(
                f'{self.optimizer_classname} is not found in the current namespace, you might need to import it first.'
            )
        
        # Instantiate model and optimizer
        model = eval(self.model_classname, scope)(**self.model_kwargs)
        optimizer = eval(self.optimizer_classname, scope)(params=model.parameters())

        # Load model from model state_dict and check for compatibility
        model_states: Dict[str, Any] = self.__checkpoint['model']['states']
        model_incompatible_keys: NamedTuple = model.load_state_dict(model_states)   # inplace update
        if model_incompatible_keys.missing_keys:  # List[str]
            raise RuntimeError(f'Missing keys from the loaded model checkpoint: {model_incompatible_keys.missing_keys}')
        if model_incompatible_keys.unexpected_keys: # List[str]
            raise RuntimeError(f'Unexpected keys found in the loaded model checkpoint: {model_incompatible_keys.unexpected_keys}')
        
        # Load optimizer from optimizer state_dict, it's always compatible
        optimizer_states: Dict[str, Any] = self.__checkpoint['optimizer']['states']
        optimizer.load_state_dict(optimizer_states) # `load_state_dict` of optimizers always returns None, inplace update

        return model, optimizer

## Custom Unet Model:

To deal with the image segmentation problem, we can leverage the standard Unet architecture where we can make sure the output's resolution matches the input's resolution. In this research, I added a skip connection in each of the unit block to facilitate smooth backpropagation and avoid vanishing gradient descent.

First, we need to define the `UnitBlock`:

In [17]:
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

In [18]:
class UnitBlock(nn.Module):

    def __init__(
        self, 
        in_channels: int, 
        out_channels: int, 
    ):
        super().__init__()
        self.in_channels: int = in_channels
        self.out_channels: int = out_channels

        self.conv1 = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, 
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.batchnorm1 = nn.BatchNorm2d(num_features=out_channels)
        self.conv2 = nn.Conv2d(
            in_channels=out_channels, out_channels=out_channels, 
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.batchnorm2 = nn.BatchNorm2d(num_features=out_channels)

        self.conv_connection = nn.Conv2d(
            in_channels=in_channels, out_channels=out_channels, 
            kernel_size=1, stride=1, bias=False,
        )
        self.batchnorm_connection = nn.BatchNorm2d(num_features=out_channels)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert input.ndim == 4
        batch_size, in_channels, in_height, in_width = input.shape
        assert self.in_channels == in_channels

        identity: torch.Tensor = input

        # Local linear transformation
        output: torch.Tensor = self.conv1(input)
        assert output.shape == (batch_size, self.out_channels, in_height, in_width)
        output = self.batchnorm1(output)
        output = F.relu(output)

        output = self.conv2(output)
        assert output.shape == (batch_size, self.out_channels, in_height, in_width)
        output = self.batchnorm2(output)

        # Skip Connection
        identity = self.conv_connection(identity)
        identity = self.batchnorm_connection(identity)

        # Merge
        assert output.shape == identity.shape
        output: torch.Tensor = output + identity
        output = F.relu(output)
        
        assert output.shape == (batch_size, self.out_channels, in_height, in_width)
        return output

Now, we construct the `Encoder`:

In [19]:
class Encoder(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.encoder1 = UnitBlock(in_channels=1, out_channels=64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder2 = UnitBlock(in_channels=64, out_channels=128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.encoder3 = UnitBlock(in_channels=128, out_channels=256)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        
        self.encoder4 = UnitBlock(in_channels=256, out_channels=512)


    def forward(self, input: torch.Tensor) -> Tuple[
        torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
    ]:
        enc1: torch.Tensor = self.encoder1(input)
        enc2: torch.Tensor = self.encoder2(self.pool1(enc1))
        enc3: torch.Tensor = self.encoder3(self.pool2(enc2))
        enc4: torch.Tensor = self.encoder4(self.pool3(enc3))
        return enc1, enc2, enc3, enc4

And the `BottleNeck`:

In [20]:
class BottleNeck(nn.Module):

    def __init__(self):
        super().__init__()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bottleneck = UnitBlock(in_channels=512, out_channels=1024)
        self.upconv = nn.ConvTranspose2d(
            in_channels=1024, out_channels=512, 
            kernel_size=2, stride=2, bias=False,
        )
    
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert input.ndim == 4
        batch_size, in_channels, height, width = input.shape
        assert in_channels == 512
        output: torch.Tensor = self.pool(input)
        output: torch.Tensor = self.bottleneck(output)
        output: torch.Tensor = self.upconv(output)
        assert output.shape == (batch_size, 512, height, width)
        return output

Finally, the `Decoder`:

In [21]:
class Decoder(nn.Module):

    def __init__(self):
        super().__init__()
        self.decoder1 = UnitBlock(in_channels=1024, out_channels=512)
        self.upconv1 = nn.ConvTranspose2d(
            in_channels=512, out_channels=256, 
            kernel_size=2, stride=2, bias=False,
        )
        self.decoder2 = UnitBlock(in_channels=512, out_channels=256)
        self.upconv2 = nn.ConvTranspose2d(
            in_channels=256, out_channels=128, 
            kernel_size=2, stride=2, bias=False,
        )
        self.decoder3 = UnitBlock(in_channels=256, out_channels=128)
        self.upconv3 = nn.ConvTranspose2d(
            in_channels=128, out_channels=64, 
            kernel_size=2, stride=2, bias=False, 
        )
        self.decoder4 = UnitBlock(in_channels=128, out_channels=64)
        self.prediction_head = nn.Conv2d(
            in_channels=64, out_channels=1, 
            kernel_size=1, stride=1, bias=False,
        )
    
    def forward(
        self,
        input: torch.Tensor, 
        encoder_outputs: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], 
    ) -> torch.Tensor:
        assert input.ndim == 4
        batch_size, in_channels, in_height, in_width = input.shape
        assert in_channels == 512
        enc1, enc2, enc3, enc4 = encoder_outputs

        assert enc4.shape == input.shape
        output: torch.Tensor = self.decoder1(torch.cat(tensors=[enc4, input], dim=1))
        output: torch.Tensor = self.upconv1(output)

        assert enc3.shape == output.shape
        output: torch.Tensor = self.decoder2(torch.cat(tensors=[enc3, output], dim=1))
        output: torch.Tensor = self.upconv2(output)

        assert enc2.shape == output.shape
        output: torch.Tensor = self.decoder3(torch.cat(tensors=[enc2, output], dim=1))
        output: torch.Tensor = self.upconv3(output)

        assert enc1.shape == output.shape
        output: torch.Tensor = self.decoder4(torch.cat(tensors=[enc1, output], dim=1))
        output: torch.Tensor = self.prediction_head(output)

        return output

The `UNet` model makes use of all the modules above:

In [22]:
class UNet(nn.Module):

    def __init__(
        self, 
        encoder: Encoder, 
        bottleneck: BottleNeck, 
        decoder: Decoder = 64, 
    ) -> None:
        
        super().__init__()
        self.encoder: Encoder = encoder
        self.bottleneck: BottleNeck = bottleneck
        self.decoder: Decoder = decoder

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert input.ndim == 4
        batch_size, in_channels, in_height, in_width = input.shape
        enc1, enc2, enc3, enc4 = self.encoder(input)
        output: torch.Tensor = self.bottleneck(enc4)
        output: torch.Tensor = self.decoder(input=output, encoder_outputs=[enc1, enc2, enc3, enc4])
        assert output.shape == (batch_size, 1, in_height, in_width)
        return output

Let's test the `Unet` model on random data:

In [23]:
net = UNet(encoder=Encoder(), bottleneck=BottleNeck(), decoder=Decoder()).to(device=device)
x = torch.rand(8, 1, 128, 512, device=device)
y = net(x)
print(x.shape)
print(y.shape)

torch.Size([8, 1, 128, 512])
torch.Size([8, 1, 128, 512])


In [24]:
# Delete unnecessary variables to save memory:
del x, y

The output of the model is as expected.

# 2. Train your model using [Bird sound datasets](https://yuad-my.sharepoint.com/personal/youshan_zhang_yu_edu/_layouts/15/onedrive.aspx?id=%2Fpersonal%2Fyoushan%5Fzhang%5Fyu%5Fedu%2FDocuments%2FBird%5FSound%5FDataset&ga=1)

## Loss function and evaluation metrics:

In order to train the model, we need a to prepare a loss function and a evaluation metric. In this project, I choose the Soft Dice Loss function and the IoU metric:

In [25]:
class SoftDiceLoss(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, groundtruths: torch.Tensor) -> torch.Tensor:
        assert torch.all((groundtruths == 0) | (groundtruths == 1))
        groundtruths: torch.Tensor = groundtruths.float()
        probabilities: torch.Tensor = torch.sigmoid(input=logits)
        numerator: torch.Tensor = 2. * (probabilities * groundtruths).sum() + 1.
        denorminator: torch.Tensor = (probabilities + groundtruths).sum() + 1.
        return 1. - numerator / denorminator

In [26]:
class IOU(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, logits: torch.Tensor, groundtruths: torch.Tensor) -> torch.Tensor:
        assert torch.all((groundtruths == 0) | (groundtruths == 1))
        groundtruths: torch.Tensor = groundtruths.int()
        probabilities: torch.Tensor = torch.sigmoid(input=logits)
        predictions: torch.Tensor = (probabilities > 0.5).int()
        intersection: torch.Tensor = predictions & groundtruths
        union: torch.Tensor = predictions | groundtruths
        return intersection.sum() / union.sum()

We implement the training process in the `Trainer` class:

In [27]:
from typing import List, Optional

import torch
import torch.nn as nn
from torch.utils.data import Dataset, Subset, DataLoader
from torch.optim import Optimizer

In [28]:
class Trainer:

    def __init__(
        self, 
        model: nn.Module,
        optimizer: Optimizer,
        train_dataset: BirdSoundDataset,
        val_dataset: BirdSoundDataset,
        train_batch_size: int,
        val_batch_size: int,
        device: torch.device,
    ):
        self.model: nn.Module = model.to(device=device)
        self.optimizer: Optimizer = optimizer
        self.train_dataset: BirdSoundDataset = train_dataset
        self.val_dataset: BirdSoundDataset = val_dataset
        self.train_batch_size: int = train_batch_size
        self.val_batch_size: int = val_batch_size
        self.device: torch.device = device

        self.train_dataloader = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True)
        self.val_dataloader = DataLoader(dataset=val_dataset, batch_size=val_batch_size, shuffle=False)
        self.loss_function: nn.Module = SoftDiceLoss()
        self.evaluation_metric: nn.Module = IOU()

    def train(
        self, 
        n_epochs: int,
        patience: int,
        tolerance: float,
        checkpoint_path: Optional[str] = None,
        save_frequency: int = 5,
    ) -> None:
        
        train_accumulator = Accumulator()
        early_stopping = EarlyStopping(patience, tolerance)
        timer = Timer()
        logger = Logger()
        checkpoint_saver = CheckpointSaver(
            model=self.model,
            optimizer=self.optimizer,
            dirpath=checkpoint_path,
        )
        self.model.train()
        
        # loop through each epoch
        for epoch in range(1, n_epochs + 1):
            timer.start_epoch(epoch)
            # Loop through each batch
            for batch, (batch_image, batch_groundtruth) in enumerate(self.train_dataloader, start=1):
                timer.start_batch(epoch, batch)
                assert batch_image.ndim == 4
                batch_size, n_channels, height, width = batch_image.shape
                batch_image: torch.Tensor = batch_image.to(device=self.device)
                batch_groundtruth: torch.Tensor = batch_groundtruth.to(device=self.device)
                self.optimizer.zero_grad()
                batch_prediction: torch.Tensor = self.model(input=batch_image)
                assert batch_prediction.shape == batch_groundtruth.shape
                dice_loss = self.loss_function(
                    logits=batch_prediction, groundtruths=batch_groundtruth,
                )
                dice_loss.backward()
                self.optimizer.step()

                # Accumulate the train metrics
                with torch.no_grad():
                    iou = self.evaluation_metric(
                        logits=batch_prediction, groundtruths=batch_groundtruth,
                    )

                train_accumulator.add(
                    dice_loss=dice_loss.item(), 
                    iou=iou.item(),
                )
                timer.end_batch(epoch=epoch)
                logger.log(
                    epoch=epoch, n_epochs=n_epochs, 
                    batch=batch, n_batches=len(self.train_dataloader), 
                    took=timer.time_batch(epoch, batch), 
                    train_dice_loss=train_accumulator['dice_loss'] / batch,  
                    train_iou=train_accumulator['iou'] / batch, 
                )
        
            # Ragularly save checkpoint
            if checkpoint_path is not None and epoch % save_frequency == 0:
                checkpoint_saver.save(
                    model_states=self.model.state_dict(), 
                    optimizer_states=self.optimizer.state_dict(),
                    filename=f'epoch{epoch}.pt',
                )
            
            # Reset metric records for next epoch
            train_accumulator.reset()

            # Evaluate
            val_dice_loss, val_iou = self.evaluate()
            timer.end_epoch(epoch)
            logger.log(
                epoch=epoch, n_epochs=n_epochs, 
                took=timer.time_epoch(epoch), 
                val_dice_loss=val_dice_loss, val_iou=val_iou,
            )
            print('=' * 20)

            early_stopping(val_dice_loss)
            if early_stopping:
                print('Early Stopped')
                break

        # Always save last checkpoint
        if checkpoint_path:
            checkpoint_saver.save(
                self.model, 
                filename=f'epoch{epoch}.pt', 
                optimizer_states=self.optimizer.state_dict(),
            )

    def evaluate(self) -> float:
        val_accumulator = Accumulator()
        self.model.eval()
        with torch.no_grad():
            # Loop through each batch
            for batch, (batch_image, batch_groundtruth) in enumerate(self.val_dataloader, start=1):
                assert batch_image.ndim == 4
                batch_size, n_channels, height, width = batch_image.shape
                batch_image: torch.Tensor = batch_image.to(device=self.device)
                batch_groundtruth: torch.Tensor = batch_groundtruth.to(device=self.device)
                batch_prediction: torch.Tensor = self.model(input=batch_image)
                assert batch_prediction.shape == batch_groundtruth.shape
                
                dice_loss = self.loss_function(
                    logits=batch_prediction, groundtruths=batch_groundtruth,
                )
                iou = self.evaluation_metric(
                    logits=batch_prediction, groundtruths=batch_groundtruth,
                )
                # Accumulate the val metrics
                val_accumulator.add(
                    val_dice_loss=dice_loss.item(),
                    val_iou=iou.item(),
                )

        # Compute the aggregate metrics
        val_dice_loss: float = val_accumulator['val_dice_loss'] / batch
        val_iou: float = val_accumulator['val_iou'] / batch
        return val_dice_loss, val_iou

Now train the model:

In [29]:
from torch.optim import Adam

optimizer = Adam(params=net.parameters(), lr=0.001)

trainer = Trainer(
    model=net, optimizer=optimizer,
    train_dataset=train_dataset, val_dataset=val_dataset,
    train_batch_size=16, val_batch_size=4,
    device=device,
)
trainer.train(
    n_epochs=100, patience=5,
    tolerance=0., checkpoint_path='./checkpoints/',
    save_frequency=5,
)

Epoch 1/100 | Batch 1/63 | Took 0.77s | train_dice_loss: 8.559e-01, train_iou: 5.983e-02
Epoch 1/100 | Batch 2/63 | Took 0.64s | train_dice_loss: 8.313e-01, train_iou: 1.433e-01
Epoch 1/100 | Batch 3/63 | Took 0.64s | train_dice_loss: 8.120e-01, train_iou: 1.768e-01
Epoch 1/100 | Batch 4/63 | Took 0.64s | train_dice_loss: 7.779e-01, train_iou: 2.116e-01
Epoch 1/100 | Batch 5/63 | Took 0.64s | train_dice_loss: 7.475e-01, train_iou: 2.477e-01
Epoch 1/100 | Batch 6/63 | Took 0.64s | train_dice_loss: 7.286e-01, train_iou: 2.602e-01
Epoch 1/100 | Batch 7/63 | Took 0.64s | train_dice_loss: 7.172e-01, train_iou: 2.671e-01
Epoch 1/100 | Batch 8/63 | Took 0.64s | train_dice_loss: 6.958e-01, train_iou: 2.882e-01
Epoch 1/100 | Batch 9/63 | Took 0.64s | train_dice_loss: 6.721e-01, train_iou: 3.051e-01
Epoch 1/100 | Batch 10/63 | Took 0.64s | train_dice_loss: 6.615e-01, train_iou: 3.141e-01
Epoch 1/100 | Batch 11/63 | Took 0.67s | train_dice_loss: 6.504e-01, train_iou: 3.275e-01
Epoch 1/100 | Batch

# 3.Evaluate your model using the test images

We need a plotting function to plot the input images, the groundtruth masks, and predicted masks:

In [30]:
import os
from typing import List

import datetime as dt
import matplotlib.pyplot as plt
import torch

In [38]:
def plot_predictions(
    images: torch.Tensor,
    groundtruths: torch.Tensor,
    predictions: torch.Tensor,
    notes: List[str],
) -> None:

    assert groundtruths.shape == predictions.shape
    assert groundtruths.ndim == 4   # (batch_size, n_channels, height, width)
    assert groundtruths.shape[1] == 1, 'Expect n_channels to be 1'
    assert notes is None or len(notes) == groundtruths.shape[0]

    os.makedirs(f"./results", exist_ok=True)

    images = images.to(device=torch.device('cpu'))
    groundtruths = groundtruths.to(device=torch.device('cpu'))
    predictions = predictions.to(device=torch.device('cpu'))

    # Ensure that the plot respect the tensor's shape

    for idx in range(predictions.shape[0]):
        image: torch.Tensor = images[idx]
        groundtruth: torch.Tensor = groundtruths[idx]
        prediction: torch.Tensor = predictions[idx]
        fig, axs = plt.subplots(3, 1, figsize=(12, 10))
        axs[0].imshow(
            image.squeeze(dim=0),
            cmap='gray',
        )
        axs[0].set_xticks([])
        axs[0].set_yticks([])
        axs[0].set_title(f'$image$', fontsize=20)
        axs[1].imshow(
            groundtruth.squeeze(dim=0),
            cmap='gray',
        )
        axs[1].set_xticks([])
        axs[1].set_yticks([])
        axs[1].set_title(f'$groundtruth$', fontsize=20)
        prediction: torch.Tensor = (torch.sigmoid(input=prediction) > 0.5).int()
        axs[2].imshow(
            prediction.squeeze(dim=0),
            cmap='gray',
        )
        axs[2].set_xticks([])
        axs[2].set_yticks([])
        axs[2].set_title(f'$prediction - {notes[idx]}$', fontsize=20)
        fig.subplots_adjust(hspace=0.1)
        fig.tight_layout()
        timestamp: dt.datetime = dt.datetime.now()
        fig.savefig(
            f"./results/{timestamp.strftime('%Y%m%d%H%M%S')}"
            f"{timestamp.microsecond // 1000:03d}.png"
        )
        plt.close(fig)

We need a `Predictor` class to make predictions on test dataset:

In [39]:
class Predictor:

    def __init__(self, model: nn.Module, device: torch.device) -> None:
        self.model: nn.Module = model.to(device=device)
        self.device: torch.device = device
        self.loss_function: nn.Module = SoftDiceLoss()
        self.evaluation_metric: nn.Module = IOU()

    def predict(self, dataset: BirdSoundDataset) -> float:
        self.model.eval()
        dataloader = DataLoader(dataset, batch_size=1, shuffle=False) # sample-level method, not batch-level

        batch_images: List[torch.Tensor] = []
        batch_groundtruths: List[torch.Tensor] = []
        batch_predictions: List[torch.Tensor] = []
        iou_values: List[float] = []
        metric_notes: List[str] = []

        with torch.no_grad():
            # Loop through each batch
            for batch_image, batch_groundtruth in dataloader:
                assert batch_image.ndim == 4
                batch_image: torch.Tensor = batch_image.to(device=self.device)
                batch_groundtruth: torch.Tensor = batch_groundtruth.to(device=self.device)
                batch_prediction: torch.Tensor = self.model(input=batch_image)
                assert batch_prediction.shape == batch_groundtruth.shape

                dice_loss = self.loss_function(
                    logits=batch_prediction, groundtruths=batch_groundtruth,
                ).item()
                iou = self.evaluation_metric(
                    logits=batch_prediction, groundtruths=batch_groundtruth,
                ).item()
                batch_images.append(batch_image)
                batch_groundtruths.append(batch_groundtruth)
                batch_predictions.append(batch_prediction)
                iou_values.append(iou)
                metric_notes.append(f'Dice Loss: {dice_loss:.4f}, IoU: {iou:.4f}')

            images = torch.cat(tensors=batch_images, dim=0)
            predictions = torch.cat(tensors=batch_predictions, dim=0)
            groundtruths = torch.cat(tensors=batch_groundtruths, dim=0)
            assert predictions.shape == groundtruths.shape
            # Plot the prediction
            plot_predictions(
                images=images,
                groundtruths=groundtruths, 
                predictions=predictions, 
                notes=metric_notes, 
            )

        return sum(iou_values) / len(iou_values) * 100

We have everything ready. Now, we make predictions on test dataset. Since the training process early stopped at `epoch 19`, we will load the checkpoint at this epoch to do the predictions:

In [41]:
model, optimizer = CheckpointLoader(r'checkpoints/epoch19.pt').load(scope=globals())
predictor = Predictor(model=model, device=device)
test_iou: float = predictor.predict(dataset=test_dataset)
print(test_iou)

67.19253231709203


# 4. Your IoU score should be higher than 60

As can be seen, our model achieved the IoU of `67.19%` on the test dataset. Let's look at several specific predictions:

<img src="https://raw.githubusercontent.com/hiepdang-ml/dnn_project_three/master/results/20240802201944718.png" width="1000"/>

<br>
<br>

<img src="https://raw.githubusercontent.com/hiepdang-ml/dnn_project_three/master/results/20240802201941657.png" width="1000"/>

<br>
<br>

<img src="https://raw.githubusercontent.com/hiepdang-ml/dnn_project_three/master/results/20240802201949954.png" width="1000"/>

<br>
<br>

<img src="https://raw.githubusercontent.com/hiepdang-ml/dnn_project_three/master/results/20240802202036598.png" width="1000"/>

<br>
<br>

<img src="https://raw.githubusercontent.com/hiepdang-ml/dnn_project_three/master/results/20240802202052500.png" width="1000"/>

# 5. Write a 3-page report using LaTex and upload your paper to ResearchGate or Arxiv, and put your paper link here.


Source code: https://github.com/hiepdang-ml/dnn_project_three

Paper link: 

Model Weight: https://drive.google.com/drive/folders/1U02WWSGn-dKJ7MBtXNb9FLBhy9fJjcj2?usp=share_link

# 6. Grading rubric

(1). Code ------- 20 points (you also need to upload your final model as a pt file, and add paper link)

(2). Grammer ---- 20 points

(3). Introduction & related work --- 10 points

(4). Method  ---- 20 points

(5). Results ---- 20 points     

(6). Discussion - 10 points

---