In [None]:
# download train + test datasets
train_dataset = {"path": "../datasets/train-dataset", 
                 "url": "http://data.celltrackingchallenge.net/training-datasets/Fluo-N2DH-GOWT1.zip"}
test_dataset = {"path": "../datasets/test-dataset",
                "url": "http://data.celltrackingchallenge.net/test-datasets/Fluo-N2DH-GOWT1.zip"}

# shell commands
!mkdir -p ../datasets
!wget -nc -O {train_dataset["path"]}.zip {train_dataset["url"]}
!wget -nc -O {test_dataset["path"]}.zip {test_dataset["url"]}
!unzip -n {train_dataset["path"]}.zip -d {train_dataset["path"]}
!unzip -n {test_dataset["path"]}.zip -d {test_dataset["path"]}

In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
from torch import nn, optim
from torch.utils import data
from torchvision import transforms
from tqdm.auto import tqdm

# choose the compute device once for the whole notebook: the GPU when one is visible
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print("device:\t" + DEVICE)


In [None]:
# load one raw frame and its segmentation mask; for non-PNG files plt.imread returns
# an integer array whose dtype follows the file's bit depth (no conversion to uint8 here)
X_train = plt.imread(f"{train_dataset['path']}/Fluo-N2DH-GOWT1/01/t000.tif")
y_train = plt.imread(f"{train_dataset['path']}/Fluo-N2DH-GOWT1/01_ST/SEG/man_seg000.tif")

# plot an example image / mask pair side by side
fig, axs = plt.subplots(1, 2)
axs[0].imshow(X_train, cmap="binary")
axs[0].set_title("X_train")
axs[1].imshow(y_train, cmap="CMRmap_r")
axs[1].set_title("y_train")
fig.tight_layout()  # called after the titles exist so the layout accounts for them
plt.show()  # render the figure without echoing the last Text object's repr


In [None]:
class SegmentationDataset(data.Dataset):
    """Pairs of raw images and segmentation masks read from two directories of ``.tif`` files.

    Files are paired by position after sorting each directory's file names, so the
    ``i``-th raw image is matched with the ``i``-th mask.
    """

    def __init__(self, X_path, y_path, preprocess_transforms=None, augmentation_transforms=None):
        """Create a segmentation dataset for pytorch dataloader.

        Parameters
        ----------
        X_path : str
            Path to the directory with raw images.
        y_path : str
            Path to the directory with labeled images.
        preprocess_transforms : torchvision.transforms.Compose, optional
            Transforms to be applied on train/validation/test images and masks, by default None
        augmentation_transforms : torchvision.transforms.Compose, optional
            Random transforms to be applied on train images only, by default None.
            When image and mask are tensors, they are stacked along the channel axis and
            transformed together, so both receive the same random draw.
        """
        self.X_path = X_path
        self.y_path = y_path
        self.preprocess_transforms = preprocess_transforms
        self.augmentation_transforms = augmentation_transforms
        self.X = sorted(Path(self.X_path).glob("*.tif"))
        self.y = sorted(Path(self.y_path).glob("*.tif"))
        assert len(self.X) == len(self.y), "len(X) != len(y)"

    def __len__(self):
        """Number of image/mask pairs."""
        return len(self.y)

    def __getitem__(self, idx):
        """Load the ``idx``-th image/mask pair and apply the transformations.

        Returns
        -------
        X_img, y_img
            The raw image and its binary mask (foreground = 1). They are tensors when
            ``preprocess_transforms`` converts to tensors (e.g. ``ToTensor``), otherwise
            numpy arrays.
        """
        # load images as numpy arrays
        # NOTE(review): assumes raw images fit in 8 bits; higher bit depths would wrap here
        X_img = plt.imread(self.X[idx]).astype(np.uint8)
        # binarize the mask *before* the uint8 cast so instance labels >= 256
        # (possible in 16-bit label images) are not wrapped to 0 / background
        y_img = (plt.imread(self.y[idx]) != 0).astype(np.uint8)

        # apply image transformations
        if self.preprocess_transforms:
            X_img = self.preprocess_transforms(X_img)
            y_img = self.preprocess_transforms(y_img)

        if self.augmentation_transforms:
            X_img, y_img = self._augment_pair(X_img, y_img)

        # re-binarize: interpolating transforms (e.g. Resize) produce fractional values
        y_img[y_img != 0] = 1

        return X_img, y_img

    def _augment_pair(self, X_img, y_img):
        """Apply one random augmentation draw jointly to an image and its mask."""
        if isinstance(X_img, torch.Tensor) and isinstance(y_img, torch.Tensor):
            n_image_channels = X_img.shape[0]
            # stack on the channel axis so a random flip hits image and mask alike
            stacked = self.augmentation_transforms(torch.cat([X_img, y_img], dim=0))
            return stacked[:n_image_channels], stacked[n_image_channels:]
        # non-tensor inputs cannot be stacked; transform each one separately
        # (random transforms are then NOT guaranteed to match between image and mask)
        return self.augmentation_transforms(X_img), self.augmentation_transforms(y_img)


def get_data_loaders(
    X_path, y_path, batch_size=20, seed=2023, augmentation_transforms=True, image_size=572
):
    """Helper function to initialize pytorch dataloaders.

    Both splits come from the same directories. A seeded ``random_split`` assigns 80 % of
    the image/mask pairs to training and 20 % to validation. Augmentation (random
    horizontal/vertical flips) is applied to the training split only.

    Parameters
    ----------
    X_path : str
        Path to the directory with raw images.
    y_path : str
        Path to the directory with labeled images.
    batch_size : int, optional
        Number of samples / batch to load, by default 20
    seed : int, optional
        Seed of the generator used for the train/validation split, by default 2023
    augmentation_transforms : bool, optional
        Whether to apply augmentation transformations to the training set, by default True
    image_size : int, optional
        Side length images and masks are resized to, by default 572 (U-Net input size).

    Returns
    -------
    train_loader : torch.utils.data.DataLoader
        Shuffled dataloader for the training split.
    val_loader : torch.utils.data.DataLoader
        Dataloader for the validation split (fixed order).
    """
    # random number generator for the split
    rng = torch.Generator().manual_seed(seed)

    # transformations for train/validation/test images (and masks)
    # NOTE: masks go through the same Resize as the images
    preprocess = transforms.Compose(
        [transforms.ToTensor(), transforms.Resize([image_size, image_size], antialias=True)]
    )
    augmentation = (
        transforms.Compose([transforms.RandomHorizontalFlip(), transforms.RandomVerticalFlip()])
        if augmentation_transforms
        else None
    )

    # two datasets over the same files, so only the training split is augmented
    train_source = SegmentationDataset(X_path, y_path, preprocess, augmentation)
    val_source = SegmentationDataset(X_path, y_path, preprocess, None)

    # split indices once with the seeded generator and apply them to both datasets;
    # random_split only uses len(), so this is the same permutation it would draw
    # for the dataset itself
    train_split, val_split = data.random_split(
        range(len(train_source)), [0.8, 0.2], generator=rng
    )
    train_subset = data.Subset(train_source, train_split.indices)
    val_subset = data.Subset(val_source, val_split.indices)

    train_loader = data.DataLoader(train_subset, batch_size, shuffle=True)
    # validation metrics do not depend on sample order, so no shuffling here
    val_loader = data.DataLoader(val_subset, batch_size, shuffle=False)

    return train_loader, val_loader


In [None]:
def test_get_data_loaders(X_path, y_path):
    """Test that data loaders are working as expected.

    Checks each split is non-empty and batched with the requested batch size. Also checks
    that its first batch has matching image/mask shapes.
    """
    batch_size = 37
    train_loader, val_loader = get_data_loaders(X_path, y_path, batch_size)

    for split_name, loader in (("training", train_loader), ("validation", val_loader)):
        n_samples = len(loader.dataset)
        assert n_samples > 0, f"{split_name} set is empty"
        assert loader.batch_size == batch_size, "batch_size error"
        # the last batch may be partial, so the batch count is a ceiling division
        assert len(loader) == (n_samples + batch_size - 1) // batch_size, "batch_size error"

        X, y = next(iter(loader))
        assert X.shape == y.shape, f"{split_name} set error"
        assert X.shape[0] <= batch_size, "batch_size error"


# smoke-test the loaders on sequence 01 of the training data (raw frames + 01_ST/SEG masks)
test_get_data_loaders(
    *(f"{train_dataset['path']}/Fluo-N2DH-GOWT1/{subdir}" for subdir in ("01", "01_ST/SEG"))
)


In [None]:
class UNet(nn.Module):
    """U-Net: Convolutional Networks for Biomedical Image Segmentation (https://arxiv.org/abs/1505.04597).

    Uses unpadded 3x3 convolutions as in the paper, so the decoder output is smaller than
    the input. ``forward`` resizes it back to the input's spatial size with
    nearest-neighbour interpolation.

    Parameters
    ----------
    n_classes : int, optional
        Number of output channels (one logit map per class), by default 2.
    in_channels : int, optional
        Number of channels of the input images, by default 1.

    Notes
    -----
    The activations of the last forward pass are stored as attributes (``e1`` ... ``e9``,
    ``d1`` ... ``d9`` and the ``*_crop`` tensors) so their shapes can be inspected. These
    references keep the tensors alive until the next forward call.
    """

    def __init__(self, n_classes=2, in_channels=1):
        super().__init__()
        self.n_classes = n_classes

        # Encoder
        # -------
        self.maxpool2x2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.downconv1 = self.double_conv(in_channels, 64)
        self.downconv2 = self.double_conv(64, 128)
        self.downconv3 = self.double_conv(128, 256)
        self.downconv4 = self.double_conv(256, 512)
        self.downconv5 = self.double_conv(512, 1024)

        # Decoder
        # -------
        self.uptrans1 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.upconv1 = self.double_conv(1024, 512)
        self.uptrans2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.upconv2 = self.double_conv(512, 256)
        self.uptrans3 = nn.ConvTranspose2d(256, 128, kernel_size=2, stride=2)
        self.upconv3 = self.double_conv(256, 128)
        self.uptrans4 = nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
        self.upconv4 = self.double_conv(128, 64)

        # Output Layer
        # ------------
        self.output = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, img):
        """Forward pass of the U-Net model.

        Parameters
        ----------
        img : torch.Tensor
            Input batch of shape (batch, in_channels, H, W). H and W must be large enough
            that every decoder stage stays non-empty (572 x 572, as in the paper, works).

        Returns
        -------
        torch.Tensor
            Logits of shape (batch, n_classes, H, W).
        """

        # Encoder
        # -------
        self.e1 = self.downconv1(img)  # ---> e1
        self.e2 = self.maxpool2x2(self.e1)
        self.e3 = self.downconv2(self.e2)  # ---> e3
        self.e4 = self.maxpool2x2(self.e3)
        self.e5 = self.downconv3(self.e4)  # ---> e5
        self.e6 = self.maxpool2x2(self.e5)
        self.e7 = self.downconv4(self.e6)  # ---> e7
        self.e8 = self.maxpool2x2(self.e7)
        self.e9 = self.downconv5(self.e8)

        # Decoder
        # -------
        self.d1 = self.uptrans1(self.e9)
        self.e7_crop = self.crop_img(self.e7, self.d1)  # <--- e7
        self.d2 = self.upconv1(torch.cat([self.d1, self.e7_crop], 1))
        self.d3 = self.uptrans2(self.d2)
        self.e5_crop = self.crop_img(self.e5, self.d3)  # <--- e5
        self.d4 = self.upconv2(torch.cat([self.d3, self.e5_crop], 1))
        self.d5 = self.uptrans3(self.d4)
        self.e3_crop = self.crop_img(self.e3, self.d5)  # <--- e3
        self.d6 = self.upconv3(torch.cat([self.d5, self.e3_crop], 1))
        self.d7 = self.uptrans4(self.d6)
        self.e1_crop = self.crop_img(self.e1, self.d7)  # <--- e1
        self.d8 = self.upconv4(torch.cat([self.d7, self.e1_crop], 1))

        # Output Layer
        # ------------
        self.d9 = self.output(self.d8)
        # !! edit from U-Net paper: resize the (smaller) output back to the input size !!
        return nn.functional.interpolate(self.d9, size=tuple(img.shape[-2:]), mode="nearest")

    def double_conv(self, input_channels, output_channels):
        """Two unpadded 3x3 convolutions, each followed by a ReLU (shrinks H and W by 4)."""
        conv2d = nn.Sequential(
            nn.Conv2d(input_channels, output_channels, kernel_size=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(output_channels, output_channels, kernel_size=3),
            nn.ReLU(inplace=True),
        )
        return conv2d

    def crop_img(self, input_img, output_img):
        """Center-crop ``input_img`` to the height and width of ``output_img``.

        Height and width are handled independently, so non-square maps work too. The
        result always has exactly the target size, including when the size difference
        is odd (the extra pixel is dropped at the bottom/right).
        """
        target_h, target_w = output_img.shape[-2:]
        top = (input_img.shape[-2] - target_h) // 2
        left = (input_img.shape[-1] - target_w) // 2
        return input_img[:, :, top : top + target_h, left : left + target_w]


In [None]:
def test_unet_shape():
    """Check the shape of every stored U-Net activation for a single 572 x 572 input."""
    img = torch.rand((1, 1, 572, 572))
    n_classes = 3
    unet = UNet(n_classes=n_classes)
    output = unet(img)

    # (attribute name, expected shape), checked in forward-pass order
    expected_shapes = [
        # Encoder
        ("e1", (1, 64, 568, 568)),
        ("e2", (1, 64, 284, 284)),
        ("e3", (1, 128, 280, 280)),
        ("e4", (1, 128, 140, 140)),
        ("e5", (1, 256, 136, 136)),
        ("e6", (1, 256, 68, 68)),
        ("e7", (1, 512, 64, 64)),
        ("e8", (1, 512, 32, 32)),
        ("e9", (1, 1024, 28, 28)),
        # Decoder
        ("d1", (1, 512, 56, 56)),
        ("e7_crop", (1, 512, 56, 56)),
        ("d2", (1, 512, 52, 52)),
        ("d3", (1, 256, 104, 104)),
        ("e5_crop", (1, 256, 104, 104)),
        ("d4", (1, 256, 100, 100)),
        ("d5", (1, 128, 200, 200)),
        ("e3_crop", (1, 128, 200, 200)),
        ("d6", (1, 128, 196, 196)),
        ("d7", (1, 64, 392, 392)),
        ("e1_crop", (1, 64, 392, 392)),
        ("d8", (1, 64, 388, 388)),
        # Output Layer
        ("d9", (1, n_classes, 388, 388)),
    ]
    for attr_name, shape in expected_shapes:
        assert getattr(unet, attr_name).shape == shape, attr_name

    assert output.shape == (1, n_classes, 572, 572), "output"


# run the shape check when the cell executes; it builds UNet(n_classes=3) and runs
# one forward pass on a (1, 1, 572, 572) random input
test_unet_shape()


In [None]:
def train_unet(unet, train_loader, optimizer, loss_function, device="cuda"):
    """Run one training epoch over ``train_loader``, updating ``unet`` in place.

    Only the first output channel (``y_hat[:, :1]``) is compared with the mask, so the
    loss sees a single logit map however many channels the model emits.

    Parameters
    ----------
    unet : torch.nn.Module
        Model to train.
    train_loader : iterable of (X, y)
        Training batches, e.g. a torch.utils.data.DataLoader.
    optimizer : torch.optim.Optimizer
        Optimizer over ``unet``'s parameters.
    loss_function : callable
        Called as ``loss_function(prediction, target)``; must return a scalar tensor.
    device : str, optional
        Device the batches are moved to, by default "cuda".

    Returns
    -------
    torch.nn.Module
        The same ``unet`` object after one epoch of parameter updates.
    """
    # training-mode behaviour for layers such as dropout / batch norm
    unet.train()

    for batch_X, batch_y in train_loader:
        batch_X = batch_X.to(device)
        batch_y = batch_y.to(device)

        optimizer.zero_grad()  # gradients must not accumulate across batches
        logits = unet(batch_X)  # assumes (batch, channels, H, W) as UNet.forward returns
        batch_loss = loss_function(logits[:, :1, :, :], batch_y)
        batch_loss.backward()
        optimizer.step()

    return unet


def test_unet(unet, test_loader, loss_function, device="cuda"):
    """Calculate validation/test accuracy for the U-Net model against labeled images.

    Parameters
    ----------
    unet : torch.nn.Module
        Trained U-Net model.
    test_loader : torch.utils.data.DataLoader
        The data loader for the validation/test set.
    loss_function : callable
        The loss function to use for parameter optimization.
    device : str, optional
        The device to use for testing. Default is "cuda".

    Returns
    -------
    float
        The average loss across the validation/test sets.
    """
    unet.eval()  # set model to evaluation mode (gradients not retained)
    batch_loss = 0
    with torch.no_grad():
        for X, y in test_loader:
            X, y = X.to(device), y.to(device)  # mv to cpu or cuda
            y_hat = unet(X)
            # sum of batch loss
            batch_loss += loss_function(y_hat[:, :1, :, :], y, reduction="sum").item()

    batch_loss /= len(test_loader.dataset)  # average loss across batches
    return batch_loss


def main(args, unet, train_loader, val_loader, device="cuda"):
    """Train and validate U-Net model.

    Each epoch runs ``train_unet`` once. It then evaluates the model with ``test_unet`` on
    the training set and on the validation set.

    Parameters
    ----------
    args : dict {n_epochs: <int>, learning_rate: <float>}
        A dictionary specifying number of epochs and learning rate.
    unet : torch.nn.Module
        The U-Net model to be trained and validated. It is moved to ``device`` and
        trained in place (``nn.Module.to`` modifies the module itself).
    train_loader : torch.utils.data.DataLoader
        The data loader for the training set.
    val_loader : torch.utils.data.DataLoader
        The data loader for the validation set.
    device : str, optional
        The device to use for training and validation, by default "cuda".

    Returns
    -------
    tuple of (list of float, list of float)
        ``(train_losses, val_losses)``: the training-set and validation-set loss after
        each epoch. The trained weights live in the caller's ``unet`` object.
    """
    unet = unet.to(device)  # mv model to "cpu" or "cuda"
    optimizer = optim.Adam(unet.parameters(), lr=args["learning_rate"])
    # binary cross-entropy on logits, in functional form: it has the same default (mean)
    # behaviour as nn.BCEWithLogitsLoss(), but also accepts a `reduction=` keyword. A loss
    # module instance's forward() rejects that keyword, so evaluation code requesting
    # reduction="sum" needs the functional form.
    loss_function = nn.functional.binary_cross_entropy_with_logits

    train_losses, val_losses = [], []
    for _ in tqdm(range(args["n_epochs"])):
        unet = train_unet(unet, train_loader, optimizer, loss_function, device)
        train_losses.append(test_unet(unet, train_loader, loss_function, device))
        val_losses.append(test_unet(unet, val_loader, loss_function, device))

    return train_losses, val_losses


In [None]:
# !! train model !!
# raw frames of sequence 01 and the matching 01_ST/SEG masks
Xy_paths = tuple(
    f"{train_dataset['path']}/Fluo-N2DH-GOWT1/{subdir}" for subdir in ("01", "01_ST/SEG")
)
train_loader, val_loader = get_data_loaders(*Xy_paths, batch_size=10)

n_classes = 2
args = {"n_epochs": 1, "learning_rate": 1e-4}
unet = UNet(n_classes)
train_losses, val_losses = main(args, unet, train_loader, val_loader, device=DEVICE)
