In [1]:
# Start importing the necessary libraries
import os
import random
import warnings
from typing import Callable, List

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import torchmetrics as metrics
import torchvision.io as tvio
import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as F

%matplotlib inline

RANDOM_STATE = 42

# Filter warnings
warnings.filterwarnings("ignore")

# Set device for acceleration
DEVICE = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)

if DEVICE == "mps":
    torch.mps.empty_cache()
elif DEVICE == "cuda":
    torch.cuda.empty_cache()

print(
    "Using CPU for training and testing as no accelerator is available."
    if DEVICE == "cpu"
    else f"Using {DEVICE} for accleration."
)

Using mps for accleration.


In [2]:
# Define hyperparameters in this cell
BATCH_SIZE = 1
EPOCHS = 100
SHUFFLE = True

In [3]:
# Start defininig datsets and dataloaders
class CarImagesDataset(data.Dataset):
    def __init__(self, path_to_dataset: str = None, data_aug_pipeline=None) -> None:
        self.root_dir = path_to_dataset
        self.files = os.listdir(path_to_dataset)
        self.n_files = len(self.files)
        self.data_aug_pipeline = data_aug_pipeline

    def __len__(self) -> int:
        return self.n_files

    def __getitem__(self, idx: int) -> torch.Tensor:
        total_path = f"{self.root_dir}/{self.files[idx]}"
        image_tensor = tvio.read_image(total_path)

        if self.data_aug_pipeline is not None:
            image_tensor = self.data_aug_pipeline(image_tensor)

        return image_tensor


class SatelliteImageDataset(data.Dataset):
    def __init__(self, path_to_dataset: str = None, data_aug_pipeline=None) -> None:
        self.root_dir = path_to_dataset
        self.files = os.listdir(path_to_dataset)
        self.n_files = len(self.files)
        self.data_aug_pipeline = data_aug_pipeline

    def __len__(self) -> int:
        return self.n_files

    def __getitem__(self, idx: int) -> torch.Tensor:
        total_path = f"{self.root_dir}/{self.files[idx]}"
        image_tensor = tvio.read_image(total_path)

        if self.data_aug_pipeline is not None:
            image_tensor = self.data_aug_pipeline(image_tensor)

        return image_tensor


car_images_train_transform = T.Compose(
    [
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        T.RandomRotation(degrees=45),
        T.RandomResizedCrop(size=(256, 256), scale=(0.8, 1.0)),
        T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        T.RandomPerspective(),
    ]
)

satellite_images_train_transform = T.Compose(
    [
        T.RandomHorizontalFlip(p=0.5),
        T.RandomVerticalFlip(p=0.5),
        T.RandomRotation(degrees=45),
        T.RandomResizedCrop(size=(256, 256), scale=(0.8, 1.0)),
        T.RandomAffine(degrees=0, translate=(0.1, 0.1)),
        T.RandomPerspective(),
    ]
)

# Define the Datasets and DataLoaders

## Training Datasets

# TODO: Add dataset paths
car_images_train_unaugmented = CarImagesDataset()
car_images_train_augmented = CarImagesDataset(
    data_aug_pipeline=car_images_train_transform
)
car_images_train = data.ConcatDataset(
    [car_images_train_unaugmented, car_images_train_augmented]
)

satellite_images_train_unaugmented = SatelliteImageDataset()
satellite_images_train_augmented = SatelliteImageDataset(
    data_aug_pipeline=satellite_images_train_transform
)
satellite_images_train = data.ConcatDataset(
    [satellite_images_train_unaugmented, satellite_images_train_augmented]
)

## Testing Datasets
car_images_test = CarImagesDataset()
satellite_images_test = SatelliteImageDataset()

## Dataloaders
car_images_train_loader = data.DataLoader(
    car_images_train, batch_size=BATCH_SIZE, shuffle=SHUFFLE
)
satellite_images_train_loader = data.DataLoader(
    satellite_images_train, batch_size=BATCH_SIZE, shuffle=SHUFFLE
)

car_images_test_loader = data.DataLoader(
    car_images_test, batch_size=BATCH_SIZE, shuffle=SHUFFLE
)
satellite_images_test_loader = data.DataLoader(
    satellite_images_test, batch_size=BATCH_SIZE, shuffle=SHUFFLE
)

In [4]:
# random_img = random.choice(car_images_train)

# plt.imshow(random_img.permute(1, 2, 0).numpy())
# plt.show()

In [20]:
# Define Model Architecture here
class ConvBlock(nn.Module):
    def __init__(self, in_channels: int, out_channels: int):
        super(ConvBlock, self).__init__()

        self.layers = nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size=3,
                padding_mode="reflect",
                bias=False,
            ),
            nn.ReLU(),
            # nn.BatchNorm2d(out_channels),
            nn.Conv2d(
                out_channels,
                out_channels,
                kernel_size=3,
                padding_mode="reflect",
                bias=False,
            ),
            nn.ReLU(),
            # nn.BatchNorm2d(out_channels),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


class UNet(nn.Module):
    def __init__(self, in_channels: int, out_channels: int = 2) -> None:
        super(UNet, self).__init__()

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        self.encoder = nn.ModuleList(
            [
                ConvBlock(in_channels, 64),
                ConvBlock(64, 128),
                ConvBlock(128, 256),
                ConvBlock(256, 512),
            ]
        )

        self.base = nn.ModuleList(
            [
                nn.Conv2d(512, 1024, kernel_size=3, padding_mode="reflect"),
                nn.ReLU(),
                nn.Conv2d(1024, 1024, kernel_size=3, padding_mode="reflect"),
                nn.ReLU(),
            ]
        )

        self.decoder = nn.ModuleList()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # TODO: Add the skip connections and add them to the decoder
        copied_feature_maps = []

        for layer in self.encoder:
            x = layer(x)
            copied_feature_maps.append(x)
            x = self.pool(x)

        print(copied_feature_maps[-1].shape)

        for layer in self.base:
            x = layer(x)

        for layer in self.decoder:
            x = layer(x)
            # TODO: IMPLEMENT CROPPING AND CONCATENATION

        return x


# test the block and model
test_image = torch.randint(0, 256, (1, 572, 572), dtype=torch.float32)
test_batch = torch.randint(0, 256, (3, 1, 572, 572), dtype=torch.float32)

block = UNet(1, 2)

conv_out = block(test_image)
conv_out.shape

torch.Size([512, 64, 64])


torch.Size([1024, 28, 28])

In [6]:
# TODO: Integrate metrics into train function
def train(
    model: nn.Module,
    optimizer: optim.Optimizer,
    criterion: nn.Module,
    train_loader: data.DataLoader,
    val_loader: data.DataLoader = None,
    epochs: int = 15,
    device: str = "cpu",
    metrics_to_monitor: str | List[Callable] | None = None,
):
    _epoch_wise_train_loss = []
    _epoch_wise_val_loss = []

    _epoch_wise_train_accuracy = []
    _epoch_wise_val_accuracy = []

    if metrics_to_monitor is None:
        metrics_to_monitor = []

    try:
        _n_batches = len(train_loader)
        _max_char_epoch, _max_char_batch = len(str(epochs)), len(str(_n_batches))

        # BATCH TRAIN FORMAT STRING
        def _batch_train_message(i_epoch, i_batch):
            return f"Epoch {i_epoch:>{_max_char_epoch}}/{epochs} Batch: [{i_batch:>{_max_char_batch}}/{_n_batches}]"

        # CREATE EPOCH TRAIN FORMAT STRING
        def _epoch_train_message(
            i_epoch, i_batch, t_loss, t_accuracy, v_loss=None, v_accuracy=None
        ):
            return (
                f"Epoch {i_epoch:>{_max_char_epoch}}/{epochs} Batch: [{i_batch:>{_max_char_batch}}/{_n_batches}] Train Loss: {t_loss:.4f} Train Accuracy: {t_accuracy:.4f}"
                if v_loss is None or v_accuracy is None
                else f"Epoch {i_epoch:>{_max_char_epoch}}/{epochs} Batch: [{i_batch:>{_max_char_batch}}/{_n_batches}] Train Loss: {t_loss:.4f} Train Accuracy: {t_accuracy:.4f} Val Loss: {v_loss:.4f} Val Accuracy: {v_accuracy:.4f}"
            )

        # Move model to device
        model.to(device)
        print(f"Model moved to {device}.")

        # Train the model
        print("++++++++++ MODEL TRAINING STARTS ++++++++++")

        for epoch in range(1, epochs + 1):
            _batch_wise_train_loss = []
            _batch_wise_accuracy = []

            # Run batches
            for batch_idx, (data_img, labels) in enumerate(train_loader, 1):
                model.train()
                print(_batch_train_message(epoch, batch_idx), end="\r")

                # Move data to device
                data_img, labels = data_img.to(device), labels.to(device)

                # Zero the gradients
                optimizer.zero_grad()

                # Forward pass
                outputs = model(data_img)

                # Get accuracy
                batch_accuracy = torch.sum(
                    torch.argmax(outputs, dim=1) == torch.argmax(labels, dim=1)
                ).item() / len(labels)
                _batch_wise_accuracy.append(batch_accuracy)

                loss = criterion(outputs, labels)
                _batch_wise_train_loss.append(loss.item())

                loss.backward()
                optimizer.step()

                del data_img, labels, outputs, loss

            _t_loss = torch.mean(torch.tensor(_batch_wise_train_loss))
            _t_accuracy = torch.mean(torch.tensor(_batch_wise_accuracy))

            _epoch_wise_train_loss.append(_t_loss.item())
            _epoch_wise_train_accuracy.append(_t_accuracy)

            # Validation
            if val_loader:
                model.eval()

                with torch.no_grad():
                    _batch_wise_val_loss = []
                    _batch_wise_accuracy = []

                    for data_img, labels in val_loader:
                        data_img, labels = data_img.to(device), labels.to(device)
                        outputs = model(data_img)

                        accuracy = torch.sum(
                            torch.argmax(outputs, dim=1) == torch.argmax(labels, dim=1)
                        ).item() / len(labels)
                        _batch_wise_accuracy.append(accuracy)

                        loss = criterion(outputs, labels)
                        _batch_wise_val_loss.append(loss.item())
                        _, predictions = torch.max(outputs, 1)

                        del data_img, labels, outputs, loss, predictions

                    _v_loss = torch.mean(torch.tensor(_batch_wise_val_loss))
                    _epoch_wise_val_loss.append(_v_loss.item())

                    _v_accuracy = torch.mean(torch.tensor(_batch_wise_accuracy))
                    _epoch_wise_val_accuracy.append(_v_accuracy.item())

            print(
                _epoch_train_message(
                    epoch, batch_idx, _t_loss, _t_accuracy, _v_loss, _v_accuracy
                )
            )

    except RuntimeError as re:
        print("++++++++++ MODEL TRAINING ENDS ++++++++++")
        print("Some error occurred. Training stopped.")
        print(re)
        return
    except KeyboardInterrupt:
        print("\n")
        print("++++++++++ MODEL TRAINING ENDS ++++++++++")
        print("Training interrupted.")

    print("++++++++++ MODEL TRAINING ENDS ++++++++++")
    print("Training completed.")
    return {
        "train_loss": _epoch_wise_train_loss,
        "val_loss": _epoch_wise_val_loss,
        "train_accuracy": _epoch_wise_train_accuracy,
        "val_accuracy": _epoch_wise_val_accuracy,
    }


def test(model, test_loader, criterion, device="cpu"):
    model.to(device)

    _n_test_batches = len(test_loader)
    _max_char_batch = len(str(_n_test_batches))

    _test_accuracy, _test_loss = [], []

    for batch_idx, (data_img, labels) in enumerate(test_loader, 1):
        # Set model to eval mode
        model.eval()

        # Log the batch testing
        print(f"Batch: [{batch_idx:>{_max_char_batch}}/{_n_test_batches}]", end="\r")

        # Move data and label to device
        data_img, labels = data_img.to(device), labels.to(device)

        # Forward pass
        outputs = model(data_img)

        # Get accuracy
        _accuracy = torch.sum(
            torch.argmax(outputs) == torch.argmax(labels)
        ).item() / len(labels)

        _test_loss.append(criterion(outputs, labels).item())
        _test_accuracy.append(_accuracy)

    # Finally move model back to CPU so that other models can use the GPU
    model.to("cpu")

    return {
        "test_loss": torch.mean(torch.tensor(_test_loss)).item(),
        "test_accuracy": torch.mean(torch.tensor(_test_accuracy)).item(),
    }

In [7]:
# model = UNet(1, 1)  # TODO: Add proper parameters
# optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.99)
# loss = nn.CrossEntropyLoss() # TODO: Add correct loss function

# train(
#     model,
#     optimizer,
#     loss,
#     car_images_train_loader,
#     car_images_test_loader,
#     epochs=EPOCHS,
#     device=DEVICE,
#     metrics_to_monitor=[metrics.Dice(num_classes=2, average="micro")],
# )