# Dog Heart Vertebral Heart Size Point Detection 
# 1. Build an object detection model using PyTorch

First, I create two `Dataset` classes: `LabeledDogHeartDataset` (for labeled data) and `UnlabeledDogHeartDataset` (for unlabeled data). Because the two classes share some functionality, both inherit from `BaseDogHeartDataset`:

In [2]:
import os
from typing import List, Tuple, Dict, Literal

from PIL import Image
from scipy.io import loadmat

import numpy as np
import torch
import torchvision.transforms as T
from torch.utils.data import Dataset


class BaseDogHeartDataset(Dataset):

    def __init__(
        self, 
        dataroot: str, 
        image_resolution: Tuple[int, int], 
        has_labels: bool,
    ):
        super().__init__()
        self.dataroot: str = dataroot
        self.image_resolution: Tuple[int, int] = image_resolution
        self.image_folder: str = os.path.join(dataroot, 'Images')
        self.image_filenames: List[str] = sorted(os.listdir(self.image_folder))
        self.has_labels: bool = has_labels
        if self.has_labels:
            self.point_folder: str = os.path.join(dataroot, 'Labels')
            self.point_filenames: List[str] = sorted(os.listdir(self.point_folder))

    def __len__(self) -> int:
        return len(self.image_filenames)

    def transform(self, input: Image.Image) -> torch.Tensor:
        transformer = T.Compose([
            T.ToTensor(),
            T.Resize(size=self.image_resolution),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        return transformer(input)


class LabeledDogHeartDataset(BaseDogHeartDataset):

    def __init__(self, dataroot: str, image_resolution: Tuple[int, int]):
        super().__init__(dataroot, image_resolution, has_labels=True)

    # implement
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, str, str]:
        # Load the image and its labeled points
        image_path: str = os.path.join(self.image_folder, self.image_filenames[idx])
        point_path: str = os.path.join(self.point_folder, self.point_filenames[idx])
        image: Image.Image = Image.open(image_path).convert("RGB")
        
        width_original, height_original = image.size
        image_tensor: torch.Tensor = self.transform(input=image)
        height_new, width_new = image_tensor.shape[1], image_tensor.shape[2]
        
        mat: Dict[Literal['six_points', 'VHS'], np.ndarray] = loadmat(file_name=point_path)
        six_points: torch.Tensor = torch.as_tensor(mat['six_points'], dtype=torch.float32)
        # Rescale the points from the original image size to the resized image
        six_points[:, 0] = width_new / width_original * six_points[:, 0]
        six_points[:, 1] = height_new / height_original * six_points[:, 1]
        # Normalize each axis by its own length so both coordinates lie in [0, 1]
        six_points[:, 0] = six_points[:, 0] / width_new
        six_points[:, 1] = six_points[:, 1] / height_new

        vhs: torch.Tensor = torch.as_tensor(mat['VHS'], dtype=torch.float32).reshape(-1)
        return image_tensor, six_points, vhs, image_path, point_path


class UnlabeledDogHeartDataset(BaseDogHeartDataset):

    def __init__(self, dataroot: str, image_resolution: Tuple[int, int]):
        super().__init__(dataroot, image_resolution, has_labels=False)

    # implement
    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, str]:
        # Load images
        image_path: str = os.path.join(self.image_folder, self.image_filenames[idx])
        image: Image.Image = Image.open(image_path).convert("RGB")
        image_tensor: torch.Tensor = self.transform(input=image)
        return image_tensor, image_path
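
The point handling in `LabeledDogHeartDataset.__getitem__` can be checked by hand. Below is a minimal sketch in plain Python, assuming a square target resolution such as the `(512, 512)` used in this notebook. The helper name `rescale_and_normalize` and the sample coordinates are hypothetical and not part of the dataset code:

```python
from typing import List, Tuple


def rescale_and_normalize(
    points: List[Tuple[float, float]],
    original_size: Tuple[int, int],  # (width, height) of the raw image
    new_side: int,                   # side length of the square resized image
) -> List[Tuple[float, float]]:
    width_original, height_original = original_size
    result = []
    for x, y in points:
        # Step 1: move the point into the resized image's pixel grid
        x_resized = new_side / width_original * x
        y_resized = new_side / height_original * y
        # Step 2: divide by the side length so coordinates fall in [0, 1]
        result.append((x_resized / new_side, y_resized / new_side))
    return result


# Hypothetical point at the centre of a 1000 x 800 image
normalized = rescale_and_normalize([(500.0, 400.0)], original_size=(1000, 800), new_side=512)
print(normalized)
```

For a square target, the two steps reduce to dividing x by the original width and y by the original height, so the centre of the image maps to roughly `(0.5, 0.5)` regardless of the resize resolution.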


Create dataset instances:

In [3]:
train_dataset = LabeledDogHeartDataset(dataroot='Dog_Heart_VHS/train', image_resolution=(512, 512))
val_dataset = LabeledDogHeartDataset(dataroot='Dog_Heart_VHS/validation', image_resolution=(512, 512))

test_dataset = UnlabeledDogHeartDataset(dataroot='Dog_Heart_VHS/test', image_resolution=(512, 512))

## Model Architecture:

In this project, I built a `Vision Transformer (ViT)` from scratch (https://arxiv.org/abs/2010.11929). The architecture is shown in the following figure:

<div style="background-color:white; width:1000px">
    <img src="https://raw.githubusercontent.com/hiepdang-ml/dnn_project_two/master/assets/architecture.png"/>
</div>

First, I build the `PatchPositionEmbedding` layer:

In [4]:
from typing import Tuple, List

import torch
import torch.nn as nn
import torch.nn.functional as F


class PatchPositionEmbedding(nn.Module):

    def __init__(
        self, 
        in_channels: int, 
        patch_size: int, 
        embedding_dim: int, 
        image_size: Tuple[int, int],
    ):
        super().__init__()
        self.in_channels: int = in_channels
        self.patch_size: int = patch_size
        self.embedding_dim: int = embedding_dim
        self.image_size: Tuple[int, int] = image_size
        self.n_hpatches: int = image_size[0] // patch_size
        self.n_wpatches: int = image_size[1] // patch_size
        self.n_patches: int = self.n_hpatches * self.n_wpatches
        self.projector = nn.Conv2d(
            in_channels=in_channels, out_channels=embedding_dim,
            kernel_size=patch_size, stride=patch_size,
        )

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert input.ndim == 4  # (batch_size, n_channels, height, width)
        batch_size: int = input.shape[0]
        output: torch.Tensor = self.projector(input)
        assert output.shape == (batch_size, self.embedding_dim, self.n_hpatches, self.n_wpatches)
        output: torch.Tensor = output.flatten(start_dim=2, end_dim=-1)
        assert output.shape == (batch_size, self.embedding_dim, self.n_patches)
        return output.permute(0, 2, 1)
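
Because the convolution uses the same value for `kernel_size` and `stride`, each output position sees exactly one non-overlapping patch, so the layer acts as a per-patch linear projection. The resulting token count can be worked out in plain Python. `patch_grid` is a hypothetical helper for illustration only, and the numbers assume `512 x 512` RGB images with a patch size of 32:

```python
from typing import Tuple


def patch_grid(image_size: Tuple[int, int], patch_size: int, in_channels: int) -> Tuple[int, int, int]:
    """Return (n_hpatches, n_wpatches, values_per_patch) for non-overlapping patches."""
    n_hpatches = image_size[0] // patch_size
    n_wpatches = image_size[1] // patch_size
    # Each patch covers patch_size x patch_size pixels in every input channel
    values_per_patch = in_channels * patch_size * patch_size
    return n_hpatches, n_wpatches, values_per_patch


n_h, n_w, values_per_patch = patch_grid(image_size=(512, 512), patch_size=32, in_channels=3)
print(n_h * n_w, values_per_patch)  # 256 3072
```

Each of the 256 tokens is therefore a projection of 3072 raw pixel values down (or up) to `embedding_dim` features.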

The Transformer encoder is a stack of `TransformerBlock` modules:

In [5]:
class TransformerBlock(nn.Module):

    def __init__(
        self, 
        embedding_dim: int, 
        n_heads: int, 
        dropout: float,
    ):
        super().__init__()
        self.embedding_dim: int = embedding_dim
        self.n_heads: int = n_heads

        assert embedding_dim % n_heads == 0, f'embedding_dim must be divisible by n_heads'
        self.head_embedding_dim: int = self.embedding_dim // self.n_heads
        
        self.qkv = nn.Linear(in_features=embedding_dim, out_features=embedding_dim * 3)
        self.attention = nn.MultiheadAttention(
            embed_dim=embedding_dim, num_heads=n_heads, 
            dropout=dropout, batch_first=False,
        )
        self.projector1 = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.projector2 = nn.Linear(in_features=embedding_dim, out_features=embedding_dim)
        self.dropout = nn.Dropout(p=dropout)
        self.layer_norm1 = nn.LayerNorm(normalized_shape=embedding_dim)
        self.layer_norm2 = nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert input.ndim == 3
        assert input.shape[2] == self.embedding_dim
        batch_size: int = input.shape[0]
        n_patches: int = input.shape[1]

        residual: torch.Tensor = input.clone()
        
        # LayerNorm
        input: torch.Tensor = self.layer_norm1(input)
        
        # Multihead Attention
        qkv: torch.Tensor = self.qkv(input)
        assert qkv.shape == (batch_size, n_patches, self.embedding_dim * 3)
        qkv: torch.Tensor = qkv.reshape(batch_size, n_patches, 3, self.embedding_dim)
        qkv: torch.Tensor = qkv.permute(2, 1, 0, 3)
        assert qkv.shape == (3, n_patches, batch_size, self.embedding_dim)
        queries: torch.Tensor = qkv[0]
        keys: torch.Tensor = qkv[1]
        values: torch.Tensor = qkv[2]
        output, _ = self.attention(query=queries, key=keys, value=values)
        assert output.shape == (n_patches, batch_size, self.embedding_dim)
        output: torch.Tensor = output.permute(1, 0, 2)
        output = F.gelu(self.projector1(output))

        # Residual Connection
        output = residual + output
        residual: torch.Tensor = output.clone()
        # LayerNorm
        output = self.layer_norm2(output)
        # MLP
        output = F.gelu(self.projector2(output))
        # Residual Connection
        output = residual + output
        assert output.shape == (batch_size, n_patches, self.embedding_dim)
        return output


In [6]:
class TransformerEncoder(nn.Module):

    def __init__(
        self, 
        embedding_dim: int, 
        n_heads: int, 
        depth: int, 
        dropout: float
    ):
        super().__init__()
        self.embedding_dim: int = embedding_dim
        self.n_heads: int = n_heads
        self.depth: int = depth
        self.dropout: float = dropout

        self.blocks = nn.Sequential(
            *[TransformerBlock(embedding_dim, n_heads, dropout) for _ in range(depth)]
        )
        self.layer_norm = nn.LayerNorm(normalized_shape=embedding_dim)

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert input.ndim == 3
        assert input.shape[2] == self.embedding_dim
        batch_size: int = input.shape[0]
        n_patches: int = input.shape[1]

        output: torch.Tensor = self.blocks(input)
        output: torch.Tensor = self.layer_norm(output)
        assert output.shape == (batch_size, n_patches, self.embedding_dim)
        return output

We also need an `OrthogonalLayer` to ensure `AB` is perpendicular to `CD`:

In [7]:
class OrthogonalLayer(nn.Module):

    def __init__(self):
        super().__init__()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        batch_size: int = input.shape[0]
        assert input.shape == (batch_size, 6, 2)
        s: torch.Tensor = - (input[:, 0, 0] - input[:, 1, 0]) / (input[:, 0, 1] - input[:, 1, 1])
        y3: torch.Tensor = s * (input[:, 3, 0] - input[:, 2, 0]) + input[:, 2, 1]
        output = input.clone()
        output[:, 3, 1] = y3
        assert output.shape == input.shape
        return output
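
The geometry behind `OrthogonalLayer` can be verified on a single set of points. The sketch below restates the same two formulas in plain Python; the function name `adjust_d` and the coordinates are hypothetical and chosen only for illustration:

```python
from typing import Tuple

Point = Tuple[float, float]


def adjust_d(a: Point, b: Point, c: Point, d: Point) -> Point:
    """Move D vertically so that CD becomes perpendicular to AB."""
    # Slope of the line perpendicular to AB (undefined when A and B share the same y)
    s = -(a[0] - b[0]) / (a[1] - b[1])
    # Keep D's x coordinate and place it on the line through C with slope s
    return d[0], s * (d[0] - c[0]) + c[1]


a, b, c, d = (0.2, 0.3), (0.6, 0.5), (0.3, 0.6), (0.5, 0.35)
d_new = adjust_d(a, b, c, d)
ab = (b[0] - a[0], b[1] - a[1])
cd = (d_new[0] - c[0], d_new[1] - c[1])
dot = ab[0] * cd[0] + ab[1] * cd[1]
print(d_new, dot)
```

The dot product of `AB` and `CD` comes out as zero up to floating-point error, which is the perpendicularity condition. The slope is undefined when A and B have the same y coordinate, so a guard on the denominator may be worth considering.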

Lastly, we stack all of the above modules to form the Vision Transformer model:

In [8]:
class VisionTransformer(nn.Module):

    def __init__(
        self, 
        in_channels: int, 
        patch_size: int, 
        embedding_dim: int, 
        image_size: Tuple[int, int], 
        depth: int, 
        n_heads: int, 
        dropout: float, 
    ):
        super().__init__()
        self.in_channels: int = in_channels
        self.out_channels: int = 12
        self.patch_size: int = patch_size
        self.embedding_dim: int = embedding_dim
        self.image_size: Tuple[int, int] = image_size
        self.depth: int = depth
        self.n_heads: int = n_heads
        self.dropout: float = dropout

        self.patch_embedding = PatchPositionEmbedding(in_channels, patch_size, embedding_dim, image_size)
        self.encoder = TransformerEncoder(embedding_dim, n_heads, depth, dropout)
        self.orthogonalizer = OrthogonalLayer()

        scale_pos: float = self.patch_embedding.n_patches * embedding_dim
        self.pos_embedding = nn.Parameter(
            data=torch.rand(1, self.patch_embedding.n_patches, embedding_dim) / scale_pos
        )
        self.mlp_head = nn.Sequential(*[
            nn.Linear(in_features=self.patch_embedding.n_patches * embedding_dim, out_features=1024), nn.ReLU(), nn.Dropout(p=0.1),
            nn.Linear(in_features=1024, out_features=512), nn.ReLU(), nn.Dropout(p=0.1),
            nn.Linear(in_features=512, out_features=self.out_channels),
        ])

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        assert input.ndim == 4
        batch_size, n_channels, image_height, image_width = input.shape
        output: torch.Tensor = self.patch_embedding(input)
        assert output.shape == (batch_size, self.patch_embedding.n_patches, self.embedding_dim)
        output: torch.Tensor = output + self.pos_embedding
        output: torch.Tensor = self.encoder(output)
        assert output.shape == (batch_size, self.patch_embedding.n_patches, self.embedding_dim)
        output: torch.Tensor = output.flatten(start_dim=1, end_dim=-1)
        output: torch.Tensor = self.mlp_head(output)
        assert output.shape == (batch_size, self.out_channels)
        output: torch.Tensor = output.reshape(batch_size, 6, 2)
        return self.orthogonalizer(output)

Let's test on random input:

In [9]:
net = VisionTransformer(
    in_channels=3, patch_size=32, 
    embedding_dim=2048, image_size=(512, 512),
    depth=12, n_heads=16, dropout=0.,
)
x = torch.rand(8, 3, 512, 512)
y = net(x)

print(f'Input shape: {x.shape}')
print(f'Output shape: {y.shape}')

Input shape: torch.Size([8, 3, 512, 512])
Output shape: torch.Size([8, 6, 2])
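
With these hyperparameters, a large share of the weights sits in the first layer of `mlp_head`, because it takes all flattened patch tokens as input. A rough count in plain Python, where `linear_params` is a hypothetical helper that counts the weights and biases of a fully connected layer:

```python
def linear_params(in_features: int, out_features: int) -> int:
    """Weights plus biases of a fully connected layer."""
    return in_features * out_features + out_features


n_patches = (512 // 32) * (512 // 32)   # 256 tokens
flattened = n_patches * 2048            # input width of the first head layer
head_params = (
    linear_params(flattened, 1024)
    + linear_params(1024, 512)
    + linear_params(512, 12)
)
print(f'{head_params:,}')  # 537,402,892
```

The head alone holds roughly 537 million parameters, which helps explain the memory requirements of this configuration.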


# 2. Train your model using [Dog VHS Dataset](https://yuad-my.sharepoint.com/:f:/g/personal/youshan_zhang_yu_edu/ErguFJBE4y9KqzEdWWNlXzMBkTbsBaNX9l856SyvQauwJg?e=L3JOuN)

Before training a `VisionTransformer` model on the `LabeledDogHeartDataset` datasets, we need to define some utility classes to control and monitor the training process. 

First, we need an `Accumulator` to keep track of the performance metrics:

In [10]:
import os
import pathlib
import time
from typing import Optional, Dict, TextIO, Any, List, Tuple
from collections import defaultdict
import datetime as dt

import matplotlib.pyplot as plt

import torch
import torch.nn as nn


class Accumulator:
    """
    A utility class for accumulating values for multiple metrics.
    """

    def __init__(self) -> None:
        self.__records: defaultdict[str, float] = defaultdict(float)

    def add(self, **kwargs: Any) -> None:
        """
        Add values to the accumulator.

        Parameters:
            - **kwargs: named metric and the value is the amount to add.
        """
        metric: str
        value: float
        for metric, value in kwargs.items():
            # Each keyword argument represents a metric name and its value to be added
            self.__records[metric] += value
    
    def reset(self) -> None:
        """
        Reset the accumulator by clearing all recorded metrics.
        """
        self.__records.clear()

    def __getitem__(self, key: str) -> float:
        """
        Retrieve a record by key.

        Parameters:
            - key (str): The record key name.

        Returns:
            - float: The record value.
        """
        return self.__records[key]

An `EarlyStopping` mechanism to avoid overfitting:

In [11]:
class EarlyStopping:
    """
    A simple early stopping utility to terminate training when a monitored metric stops improving.

    Attributes:
        - patience (int): The number of epochs with no improvement after which training will be stopped.
        - tolerance (float): How far above the best score a new value may be while still 
        resetting the patience counter (the monitored metric is a loss, so lower is better).
        - bestscore (float): The best score seen so far.
    """
    
    def __init__(self, patience: int, tolerance: float = 0.) -> None:
        """
        Initializes the EarlyStopping instance.
        
        Parameters:
            - patience (int): Number of epochs with no improvement after which training will be stopped.
            - tolerance (float): How far above the best score a new value may be while still 
            resetting the patience counter. Defaults to 0.
        """
        self.patience: int = patience
        self.tolerance: float = tolerance
        self.bestscore: float = float('inf')
        self.__counter: int = 0

    def __call__(self, value: float) -> None:
        """
        Update the state of the early stopping mechanism based on the new metric value.

        Parameters:
            - value (float): The latest value of the monitored metric.
        """
        # Improvement or within tolerance, reset counter
        if value <= self.bestscore + self.tolerance:
            # Keep the running minimum so the best score cannot drift upward
            self.bestscore: float = min(self.bestscore, value)
            self.__counter: int = 0

        # No improvement, increment counter
        else:
            self.__counter += 1

    def __bool__(self) -> bool:
        """
        Determine if the training process should be stopped early.

        Returns:
            - bool: True if training should be stopped (patience exceeded), otherwise False.
        """
        return self.__counter >= self.patience
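
To see how the patience counter behaves, the rule can be traced on a short list of validation losses. This is a simplified stand-alone restatement rather than the class itself: `stopping_epoch` is a hypothetical helper, the loss values are made up, and the sketch assumes the best score is kept as the running minimum:

```python
from typing import List, Optional


def stopping_epoch(losses: List[float], patience: int, tolerance: float = 0.0) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None if it never does."""
    best = float('inf')
    counter = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss <= best + tolerance:
            # Good enough: remember the lowest loss and reset the counter
            best = min(best, loss)
            counter = 0
        else:
            # No improvement this epoch
            counter += 1
        if counter >= patience:
            return epoch
    return None


# Hypothetical validation losses: the best value (0.7) is reached at epoch 2
print(stopping_epoch([0.9, 0.7, 0.8, 0.75, 0.72, 0.6], patience=3))  # 5
```

Training stops at epoch 5 after three consecutive epochs above the best loss of 0.7, so the lower loss at epoch 6 is never reached. A larger `patience` trades extra training time for a chance to reach such late improvements.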

A `Timer` to monitor the running time:

In [12]:
class Timer:

    """
    A class used to time the duration of epochs and batches.
    """
    def __init__(self) -> None:
        """
        Initialize the Timer.
        """
        self.__epoch_starts: Dict[int, float] = dict()
        self.__epoch_ends: Dict[int, float] = dict()
        self.__batch_starts: Dict[int, Dict[int, float]] = defaultdict(dict)
        self.__batch_ends: Dict[int, Dict[int, float]] = defaultdict(dict)

    def start_epoch(self, epoch: int) -> None:
        """
        Start timing an epoch.

        Parameters:
            - epoch (int): The epoch number.
        """
        self.__epoch_starts[epoch] = time.time()

    def end_epoch(self, epoch: int) -> None:
        """
        End timing an epoch.

        Parameters:
            - epoch (int): The epoch number.
        """
        self.__epoch_ends[epoch] = time.time()

    def start_batch(self, epoch: int, batch: Optional[int] = None) -> None:
        """
        Start timing a batch.

        Parameters:
            - epoch (int): The epoch number.
            - batch (int, optional): The batch number. If not provided, the next batch number is used.
        """
        if batch is None:
            if self.__batch_starts[epoch]:
                batch: int = max(self.__batch_starts[epoch].keys()) + 1
            else:
                batch: int = 1
        self.__batch_starts[epoch][batch] = time.time()
    
    def end_batch(self, epoch: int, batch: Optional[int] = None) -> None:
        """
        End timing a batch.

        Parameters:
            - epoch (int): The epoch number.
            - batch (int, optional): The batch number. If not provided, the last started batch number is used.
        """
        if batch is None:
            if self.__batch_starts[epoch]:
                batch: int = max(self.__batch_starts[epoch].keys())
            else:
                raise RuntimeError(f"no batch has started")
        self.__batch_ends[epoch][batch] = time.time()
    
    def time_epoch(self, epoch: int) -> float:
        """
        Get the duration of an epoch.

        Parameters:
            - epoch (int): The epoch number.

        Returns:
            - float: The duration of the epoch in seconds.
        """
        result: float = self.__epoch_ends[epoch] - self.__epoch_starts[epoch]
        if result > 0:
            return result
        else:
            raise RuntimeError(f"epoch {epoch} ends before starts")
    
    def time_batch(self, epoch: int, batch: int) -> float:
        """
        Get the duration of a batch.

        Parameters:
            - epoch (int): The epoch number.
            - batch (int): The batch number.

        Returns:
            - float: The duration of the batch in seconds.
        """
        result: float = self.__batch_ends[epoch][batch] - self.__batch_starts[epoch][batch]
        if result > 0:
            return result
        else:
            raise RuntimeError(f"batch {batch} in epoch {epoch} ends before starts")

A `Logger` to log the training process:

In [13]:
class Logger:

    """
    A class used to log the training process.

    This class provides methods to log messages to a file and the console. 
    """
    def __init__(self, logfile: Optional[str] = None) -> None:
        """
        Initialize the logger.

        Parameters:
            - logfile (str, optional): The path to the logfile. 
            Defaults to a file in the .logs directory with the current timestamp.
        """
        # Build the default path here: a default argument is evaluated only once,
        # when the class is defined, so every Logger would share the same timestamp
        if logfile is None:
            logfile = f".logs/{dt.datetime.now().strftime('%Y%m%d%H%M%S')}"
        self.logfile: pathlib.Path = pathlib.Path(logfile)
        os.makedirs(name=self.logfile.parent, exist_ok=True)
        self._file: TextIO = open(self.logfile, mode='w')

    def log(
        self, 
        epoch: int, 
        n_epochs: int, 
        batch: Optional[int] = None, 
        n_batches: Optional[int] = None, 
        took: Optional[float] = None, 
        **kwargs: Any,
    ) -> None:
        """
        Log a message to console and a log file

        Parameters:
            - epoch (int): The current epoch.
            - n_epochs (int): The total number of epochs.
            - batch (int, optional): The current batch. Defaults to None.
            - n_batches (int, optional): The total number of batches. Defaults to None.
            - took (float, optional): The time it took to process the batch or epoch. Defaults to None.
            - **kwargs: Additional metrics to log.
        """
        suffix: str = ', '.join([f'{metric}: {value:.3e}' for metric, value in kwargs.items()])
        prefix: str = f'Epoch {epoch}/{n_epochs} | '
        if batch is not None:
            prefix += f'Batch {batch}/{n_batches} | '
        if took is not None:
            prefix += f'Took {took:.2f}s | '
        logstring: str = prefix + suffix
        print(logstring)
        self._file.write(logstring + '\n')

    def __del__(self) -> None:
        """
        Close the logfile when the logger is garbage collected.
        """
        self._file.close()

A `CheckPointSaver` to regularly save model checkpoints during training:

In [14]:
class CheckPointSaver:
    """
    A class used to save PyTorch model checkpoints.

    Attributes:
        - dirpath (pathlib.Path): The directory where the checkpoints are saved.
    """

    def __init__(self, dirpath: str) -> None:
        """
        Initialize the CheckPointSaver.

        Parameters:
            - dirpath (str): The directory where the checkpoints are saved.
        """
        self.dirpath: pathlib.Path = pathlib.Path(dirpath)
        os.makedirs(name=self.dirpath, exist_ok=True)

    def save(self, model: nn.Module, filename: str) -> None:
        """
        Save checkpoint to a .pt file.

        Parameters:
            - model (nn.Module): The PyTorch model to save.
            - filename (str): the checkpoint file name
        """
        torch.save(obj=model, f=os.path.join(self.dirpath, filename))

A `compute_vhs` function to compute the VHS for a batch of six-point predictions:

In [15]:
from typing import Optional

import os
from PIL import Image
import matplotlib.pyplot as plt

import torch


def compute_vhs(points: torch.Tensor) -> torch.Tensor:
    assert points.shape[1:] == (6, 2), 'Each sample in points should be in shape (6, 2)'
    batch_size: int = points.shape[0]
    AB = torch.norm(points[:, 1] - points[:, 0], dim=1)
    CD = torch.norm(points[:, 3] - points[:, 2], dim=1)
    EF = torch.norm(points[:, 5] - points[:, 4], dim=1)
    vhs = 6 * (AB + CD) / EF
    return vhs.reshape(batch_size, 1)
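
The formula in `compute_vhs` is `VHS = 6 * (AB + CD) / EF`, where `AB`, `CD` and `EF` are the lengths of the three segments. A worked example in plain Python makes the arithmetic concrete; `vhs_from_points` is a hypothetical single-sample helper and the coordinates are made up so that the lengths are easy to check:

```python
import math
from typing import List, Tuple


def vhs_from_points(points: List[Tuple[float, float]]) -> float:
    """VHS = 6 * (AB + CD) / EF for points ordered A, B, C, D, E, F."""
    a, b, c, d, e, f = points
    ab = math.dist(a, b)
    cd = math.dist(c, d)
    ef = math.dist(e, f)
    return 6 * (ab + cd) / ef


# Hypothetical points chosen so that AB = 5, CD = 4 and EF = 6
points = [(0, 0), (3, 4), (0, 0), (0, 4), (0, 0), (6, 0)]
vhs = vhs_from_points(points)
print(vhs)  # 9.0

# Scaling every coordinate by the same factor leaves the ratio unchanged
doubled = [(2 * x, 2 * y) for x, y in points]
print(vhs_from_points(doubled))  # 9.0
```

The ratio is only preserved under uniform scaling. If x and y are scaled by different factors, the segment lengths change unevenly and the VHS changes with them, which matters when points are normalized per axis on non-square images.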

A `plot_predictions` function to plot the predicted points against the ground-truth points:

In [16]:
def plot_predictions(
    image_path: str, 
    gt_points: Optional[torch.Tensor], 
    pred_points: torch.Tensor,
):
    assert pred_points.shape == (6, 2), 'points should be in shape (6, 2)'
    # Make sure all tensors are in CPU
    pred_points = pred_points.to(device='cpu')
    if gt_points is not None:
        assert gt_points.shape == (6, 2), 'points should be in shape (6, 2)'
        gt_points = gt_points.to(device='cpu')

    # Load image
    image: Image.Image = Image.open(image_path).convert("RGB")
    plt.imshow(image)
    
    # Scale points
    pred_points = pred_points * torch.tensor(image.size, dtype=pred_points.dtype)
    if gt_points is not None:
        gt_points = gt_points * torch.tensor(image.size, dtype=gt_points.dtype)

    # Draw points    
    plt.scatter(x=pred_points[:, 0], y=pred_points[:, 1], color='red', label='Prediction')
    if gt_points is not None:
        plt.scatter(x=gt_points[:, 0], y=gt_points[:, 1], color='green', label='Groundtruth')

    # Draw lines
    for p1, p2 in [(0, 1), (2, 3), (4, 5)]:
        plt.plot(
            [pred_points[p1, 0], pred_points[p2, 0]], [pred_points[p1, 1], pred_points[p2, 1]], 
            color='r', linestyle='--',
        )
        if gt_points is not None:
            plt.plot(
                [gt_points[p1, 0], gt_points[p2, 0]], [gt_points[p1, 1], gt_points[p2, 1]], 
                color='g', linestyle='-'
            )

    # Report the VHS in figure title
    filename: str = os.path.basename(image_path)
    title = f'{filename}\n'
    if gt_points is not None:
        title += f'Groundtruth VHS: {compute_vhs(gt_points.unsqueeze(0)).item():.4f}, '

    title += f'Predicted VHS: {compute_vhs(pred_points.unsqueeze(0)).item():.4f}'
    plt.title(title)

    # Set legend
    plt.legend(loc='upper right')
    # Fit plot margins
    plt.subplots_adjust(left=0.01, right=0.99, bottom=0.05, top=0.9)
    # Save file
    os.makedirs('results', exist_ok=True)
    plt.savefig(f'results/{filename}')
    plt.close()

With all the necessary utilities, now we can define a `Trainer` class:

In [17]:
import os
from typing import List, Tuple, Optional

import datetime as dt
import pandas as pd

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Optimizer, Adam


class Trainer:

    def __init__(
        self, 
        model: nn.Module,
        train_dataset: Dataset,
        val_dataset: Dataset,
        optimizer: Optimizer,
        train_batch_size: int,
        val_batch_size: int,
        device: torch.device,
    ):
        self.model = model.to(device=device)
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.optimizer = optimizer
        self.train_batch_size = train_batch_size
        self.val_batch_size = val_batch_size
        self.device = device

        self.train_dataloader = DataLoader(dataset=train_dataset, batch_size=train_batch_size, shuffle=True)
        self.val_dataloader = DataLoader(dataset=val_dataset, batch_size=val_batch_size, shuffle=False)
        self.loss_function = nn.MSELoss(reduction='mean')

    def train(
        self, 
        n_epochs: int,
        patience: int,
        tolerance: float,
        checkpoint_path: Optional[str] = None,
    ) -> None:
        
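        train_metrics = Accumulator()
        early_stopping = EarlyStopping(patience, tolerance)
        timer = Timer()
        logger = Logger()
        # Only create a saver when a checkpoint directory is given
        checkpoint_saver = CheckPointSaver(dirpath=checkpoint_path) if checkpoint_path else None

        # loop through each epoch
        for epoch in range(1, n_epochs + 1):
            # evaluate() switches the model to eval mode, so re-enable train mode every epoch
            self.model.train()
            timer.start_epoch(epoch)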
            # Loop through each batch
            for batch, (batch_images, batch_sixpoints, *_) in enumerate(self.train_dataloader, start=1):
                timer.start_batch(epoch, batch)
                batch_images: torch.Tensor = batch_images.to(device=self.device)
                batch_sixpoints: torch.Tensor = batch_sixpoints.to(device=self.device)
                self.optimizer.zero_grad()
                pred_targets: torch.Tensor = self.model(input=batch_images)
                print(f'gt_targets {batch_sixpoints[-1]}')
                print(f'pred_targets {pred_targets[-1]}')
                mse_loss = self.loss_function(input=pred_targets, target=batch_sixpoints)
                mse_loss.backward()
                self.optimizer.step()

                # Accumulate the metrics
                train_metrics.add(mse_loss=mse_loss.item())
                timer.end_batch(epoch=epoch)
                logger.log(
                    epoch=epoch, n_epochs=n_epochs, 
                    batch=batch, n_batches=len(self.train_dataloader), 
                    took=timer.time_batch(epoch, batch), 
                    train_mse_loss=train_metrics['mse_loss'] / batch, 
                )

            # Regularly save checkpoint
            if checkpoint_path and epoch % 5 == 0:
                checkpoint_saver.save(self.model, filename=f'epoch{epoch}.pt')
            
            # Reset metric records for next epoch
            train_metrics.reset()
            
            # Evaluate
            val_mse_loss = self.evaluate()
            timer.end_epoch(epoch)
            logger.log(
                epoch=epoch, n_epochs=n_epochs, 
                took=timer.time_epoch(epoch), 
                val_mse_loss=val_mse_loss,
            )
            print('=' * 20)

            early_stopping(val_mse_loss)
            if early_stopping:
                print('Early Stopped')
                break

        # Save last checkpoint
        if checkpoint_path:
            checkpoint_saver.save(self.model, filename=f'epoch{epoch}.pt')

    def evaluate(self) -> float:
        val_metrics = Accumulator()
        self.model.eval()
        with torch.no_grad():
            # Loop through each batch
            for batch, (batch_images, batch_sixpoints, *_) in enumerate(self.val_dataloader, start=1):
                batch_images: torch.Tensor = batch_images.to(device=self.device)
                batch_sixpoints: torch.Tensor = batch_sixpoints.to(device=self.device)
                pred_targets: torch.Tensor = self.model(input=batch_images)
                mse_loss = self.loss_function(input=pred_targets, target=batch_sixpoints)
                # Accumulate the val_metrics
                val_metrics.add(mse_loss=mse_loss.item())

        # Compute the aggregate metrics
        return val_metrics['mse_loss'] / batch

We also need a `Predictor` that uses a trained model to make six-point predictions on any dataset:

In [18]:
class Predictor:

    def __init__(self, model: nn.Module, device: torch.device) -> None:
        self.model: nn.Module = model.to(device=device)
        self.device: torch.device = device

    def predict(self, dataset: Dataset, need_plots: bool) -> pd.DataFrame:
        self.model.eval()
        dataloader = DataLoader(dataset, batch_size=8, shuffle=False)

        image_paths: List[str] = []
        point_predictions: List[torch.Tensor] = []
        vhs_predictions: List[torch.Tensor] = []

        if isinstance(dataloader.dataset, UnlabeledDogHeartDataset):
            with torch.no_grad():
                # Loop through each batch
                for batch_images, batch_image_paths in dataloader:
                    batch_images: torch.Tensor = batch_images.to(device=self.device)
                    pred_points: torch.Tensor = self.model(input=batch_images)
                    pred_vhs: torch.Tensor = compute_vhs(points=pred_points)

                    image_paths.extend(batch_image_paths)
                    point_predictions.append(pred_points)
                    vhs_predictions.append(pred_vhs)

                point_predictions = torch.cat(tensors=point_predictions, dim=0).to(device=self.device)
                vhs_predictions: torch.Tensor = torch.cat(tensors=vhs_predictions, dim=0).reshape(-1).to(device=self.device)
                if need_plots:
                    assert point_predictions.shape[0] == len(image_paths)
                    for i in range(len(image_paths)):
                        image_path: str = image_paths[i]
                        point_prediction: torch.Tensor = point_predictions[i]
                        plot_predictions(
                            image_path=image_path, gt_points=None, pred_points=point_prediction,
                        )

        elif isinstance(dataloader.dataset, LabeledDogHeartDataset):
            point_groundtruths: List[torch.Tensor] = []

            with torch.no_grad():
                # Loop through each batch
                for batch_images, batch_gt_six_points, _, batch_image_paths, _ in dataloader:
                    batch_images: torch.Tensor = batch_images.to(device=self.device)
                    batch_gt_six_points: torch.Tensor = batch_gt_six_points.to(device=self.device)
                    pred_points: torch.Tensor = self.model(input=batch_images)
                    pred_vhs: torch.Tensor = compute_vhs(points=pred_points)

                    image_paths.extend(batch_image_paths)
                    point_predictions.append(pred_points)
                    vhs_predictions.append(pred_vhs)
                    point_groundtruths.append(batch_gt_six_points)

                vhs_predictions: torch.Tensor = torch.cat(tensors=vhs_predictions, dim=0).reshape(-1).to(device=self.device)
                point_predictions = torch.cat(tensors=point_predictions, dim=0).to(device=self.device)
                point_groundtruths = torch.cat(tensors=point_groundtruths, dim=0).to(device=self.device)
                if need_plots:
                    assert (
                        point_predictions.shape[0] 
                        == point_groundtruths.shape[0] 
                        == vhs_predictions.shape[0]
                        == len(image_paths) 
                    )
                    for i in range(len(image_paths)):
                        image_path: str = image_paths[i]
                        point_prediction: torch.Tensor = point_predictions[i]
                        point_groundtruth: torch.Tensor = point_groundtruths[i]
                        plot_predictions(
                            image_path=image_path, gt_points=point_groundtruth, pred_points=point_prediction,
                        )
        
        else:
            raise ValueError('Invalid dataset')

        prediction_table = pd.DataFrame(
            data={
                'image': [os.path.basename(image_path) for image_path in image_paths], 
                'label': vhs_predictions.cpu().numpy().tolist(),
            }
        )
        prediction_table.to_csv(
            f'{dt.datetime.now().strftime(r"%Y%m%d%H%M%S")}.csv', 
            header=False, 
            index=False
        )
        return prediction_table

## Training:

Because I `built and trained a large-scale ViT model from scratch`, I had to `ssh` into a GPU cloud instance hosting two `RTX A6000` GPUs, each with `48GB` of VRAM. Training ran in parallel on both GPUs with synchronous data transfer between them, effectively leveraging vertical scaling. The training batch size is set to `128`, which is large enough to keep both GPUs busy and to amortize the overhead of data transfer.

Although vertical scaling on a multi-GPU cloud instance is efficient, `it is not manageable in a Jupyter notebook environment`, which is intended only for experimental prototyping.

While the cell below was not directly run in this notebook, the resulting model’s `checkpoint` can be found here: https://drive.google.com/file/d/12moljBHBecJP5S0T2bi7NCAVr09IT524/view?usp=share_link

In [None]:
device: torch.device = torch.device('cuda')
learning_rate: float = 1e-7

net = VisionTransformer(
    in_channels=3, patch_size=32, 
    embedding_dim=4096, image_size=(512, 512),
    depth=16, n_heads=32, dropout=0.,
)

net = nn.DataParallel(module=net).to(device=device)

trainer = Trainer(
    model=net, 
    train_dataset=train_dataset, val_dataset=val_dataset, 
    optimizer=Adam(params=net.parameters(), lr=learning_rate),
    train_batch_size=128, val_batch_size=32,
    device=device,
)
trainer.train(
    n_epochs=1000, 
    patience=100, tolerance=0., 
    checkpoint_path='.checkpoints'
)
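Because the network is wrapped in `nn.DataParallel` before training, a `state_dict` taken from the wrapped model usually carries a `module.` prefix on every key, which an unwrapped `VisionTransformer` would not accept as-is. The sketch below shows one way to strip that prefix. It is an assumption that the shared checkpoint stores such a `state_dict`; the helper name `strip_module_prefix` and the toy key names are hypothetical, and plain lists stand in for tensors so the example runs without a GPU.

```python
from collections import OrderedDict
from typing import Any, Mapping


def strip_module_prefix(
    state_dict: Mapping[str, Any], prefix: str = 'module.'
) -> 'OrderedDict[str, Any]':
    """Return a copy of `state_dict` with a leading `prefix` removed from each key."""
    cleaned: 'OrderedDict[str, Any]' = OrderedDict()
    for key, value in state_dict.items():
        # Only strip the prefix when it is actually present
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        cleaned[new_key] = value
    return cleaned


# Toy stand-in for a real state_dict (hypothetical keys; values would be tensors)
wrapped = OrderedDict([
    ('module.head.weight', [0.1, 0.2]),
    ('module.head.bias', [0.0]),
])
print(list(strip_module_prefix(wrapped).keys()))
```

The cleaned mapping could then be passed to `load_state_dict` on an unwrapped model, provided the checkpoint really has this layout.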

With the trained model, inference can be run as follows:

In [None]:
device: torch.device = torch.device('cuda')
predictor = Predictor(model=net, device=device)

predictor.predict(dataset=test_dataset, need_plots=False)

# 3. Evaluate your model using the test images with the [software](https://github.com/YoushanZhang/Dog-Cardiomegaly_VHS)

The prediction script above generates a `.csv` file in the format expected by the evaluation software. The result can be verified below:

<div style="background-color:white; width:700px">
    <img src="https://raw.githubusercontent.com/hiepdang-ml/dnn_project_two/master/assets/predictions.png"/>
</div>

The `.csv` can also be downloaded at: https://github.com/hiepdang-ml/dnn_project_two/blob/master/20240723209958.csv
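Since the file is written with `header=False` and `index=False`, each row is just an image filename followed by its predicted VHS. Before uploading, a quick sanity check along these lines may help; the helper name `parse_prediction_csv` is hypothetical and the VHS values in the sample are made up for illustration.

```python
import csv
import io
from typing import List, Tuple


def parse_prediction_csv(text: str) -> List[Tuple[str, float]]:
    """Parse header-less `image,label` rows and fail loudly on malformed lines."""
    rows: List[Tuple[str, float]] = []
    for line_number, row in enumerate(csv.reader(io.StringIO(text)), start=1):
        if len(row) != 2:
            raise ValueError(f'line {line_number}: expected 2 columns, got {len(row)}')
        filename, vhs = row
        rows.append((filename, float(vhs)))  # float() raises on non-numeric labels
    return rows


# Made-up sample in the same two-column layout (not real predictions)
sample = '1420.png,9.87\n1479.png,10.42\n'
for filename, vhs in parse_prediction_csv(sample):
    print(filename, vhs)
```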

# 4. Your results should achieve 85%. VHS = 6(AB+CD)/EF

## (10 points, accuracy < 75% --> 0 points)

`EarlyStopping` stopped the training process at epoch `480`. So far, the ViT model has only achieved `83%` accuracy, short of the `85%` target.
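For reference, the training cell passes `patience=100` and `tolerance=0.` to the trainer. A common reading of those two parameters is sketched below; this is a generic patience rule under that assumption, not the notebook's own `EarlyStopping` class, and the names `PatienceEarlyStopping` and `stopping_epoch` are hypothetical. The loss values are made up.

```python
from typing import Iterable, Optional


class PatienceEarlyStopping:
    """Stop when validation loss has not improved for `patience` consecutive epochs."""

    def __init__(self, patience: int, tolerance: float = 0.0):
        self.patience = patience
        self.tolerance = tolerance
        self.best_loss = float('inf')
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        # An epoch counts as an improvement only if it beats the best loss
        # by more than `tolerance`
        if val_loss < self.best_loss - self.tolerance:
            self.best_loss = val_loss
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def stopping_epoch(
    val_losses: Iterable[float], patience: int, tolerance: float = 0.0
) -> Optional[int]:
    """Return the 1-based epoch at which training would stop, or None."""
    stopper = PatienceEarlyStopping(patience, tolerance)
    for epoch, loss in enumerate(val_losses, start=1):
        if stopper.step(loss):
            return epoch
    return None


losses = [1.0, 0.8, 0.7, 0.75, 0.72, 0.71, 0.9]  # made-up validation losses
print(stopping_epoch(losses, patience=3))  # best loss at epoch 3, stop at epoch 6
```

Under this rule, a stop at epoch `480` with a patience of `100` would mean the last improvement in validation loss came around epoch `380`, although the notebook's implementation may count epochs differently.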

Further improvement is required in later work.
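To make the formula in the heading concrete, here is a small standalone illustration of `VHS = 6(AB+CD)/EF` on hand-picked coordinates. It is not the notebook's `compute_vhs`; the function name `vhs_from_points` is hypothetical, and the point order `A, B, C, D, E, F` is an assumption about how the six key points are laid out.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def vhs_from_points(points: Sequence[Point]) -> float:
    """Compute VHS = 6 * (AB + CD) / EF from six (x, y) points ordered A..F."""
    if len(points) != 6:
        raise ValueError(f'expected 6 points, got {len(points)}')
    a, b, c, d, e, f = points
    ab = math.dist(a, b)  # first heart axis
    cd = math.dist(c, d)  # second heart axis
    ef = math.dist(e, f)  # vertebral reference length
    if ef == 0:
        raise ValueError('E and F coincide, so the reference length is zero')
    return 6.0 * (ab + cd) / ef


# Hand-picked coordinates: AB = 5, CD = 5, EF = 6, so VHS = 6 * 10 / 6
points = [(0, 0), (3, 4), (0, 0), (0, 5), (0, 0), (6, 0)]
print(vhs_from_points(points))  # 10.0
```

Because VHS is a ratio of lengths, scaling all six points by the same factor leaves it unchanged.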

# 5. Show the comparison between predictions and ground truth

### `1420.png`

<div style="background-color:white; width:700px">
    <img src="https://raw.githubusercontent.com/hiepdang-ml/dnn_project_two/master/results/1420.png"/>
</div>


### `1479.png`

<div style="background-color:white; width:700px">
    <img src="https://raw.githubusercontent.com/hiepdang-ml/dnn_project_two/master/results/1479.png"/>
</div>


### `1530.png`

<div style="background-color:white; width:700px">
    <img src="https://raw.githubusercontent.com/hiepdang-ml/dnn_project_two/master/results/1530.png"/>
</div>

# 6. Write a three-page report using LaTeX and upload your paper to ResearchGate or arXiv, and put your paper link here.


Paper link: https://www.researchgate.net/publication/382491046_Vision_Transformer-Based_Approach_for_Accurate_Cardiomegaly_Detection_in_Canines

Source code: https://github.com/hiepdang-ml/dnn_project_two

# 7. Grading rubric

(1). Code ------- 20 points (you also need to upload your final model as a pt file, prediction CSV file and add paper link)

(2). Grammar ---- 20 points

(3). Introduction & related work --- 10 points

(4). Method  ---- 20 points

(5). Results ---- 20 points

(6). Discussion - 10 points

# 8. Bonus points (10 points if your accuracy is higher than 87.3%)

---