In [None]:
import glob
import numpy as np

import imageio.v2 as image

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

In [None]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU and say so.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
    print("Warning: CUDA not found. Using CPU")

In [None]:
class OasisDataset(Dataset):
    """
    Map-style PyTorch dataset pairing OASIS image slices with their labels.

    Both inputs are NumPy arrays wrapped with ``torch.from_numpy`` (no copy).
    Images arrive channels-last and are exposed channels-first; labels are
    stored exactly as given.

    Attributes (N - number of samples, C - channels, H - height, W - width):
        images (torch.Tensor): Images permuted from (N, H, W, C) to (N, C, H, W).
        labels (torch.Tensor): Labels with the layout of the input array,
            e.g. (N, H, W, C) or (N, H, W).
    """

    def __init__(self, images, labels):
        # Wrap first, then move the trailing channel axis in front of H and W.
        image_tensor = torch.from_numpy(images)
        self.images = image_tensor.permute(0, 3, 1, 2)
        # Labels keep whatever axis order the caller supplied.
        self.labels = torch.from_numpy(labels)

    def __len__(self):
        # Sample count is the size of the leading axis.
        return self.images.shape[0]

    def __getitem__(self, idx):
        sample = self.images[idx]
        target = self.labels[idx]
        return sample, target
    
def load_training(path):
    """
    Load every PNG found directly inside ``path`` into a single float32 array.

    Files are read in sorted filename order. ``glob.glob`` yields entries in
    arbitrary, filesystem-dependent order, so without sorting the result could
    not be lined up index-by-index with data gathered by a separate directory
    scan (such as the segmentation masks).
    NOTE(review): pairing by sorted name assumes image and mask filenames sort
    the same way - confirm against the dataset's naming scheme.

    Args:
        path (str): Directory containing the ``*.png`` image slices.

    Returns:
        np.ndarray: float32 array stacking the images along a new first axis.
        An empty array of shape ``(0,)`` is returned when no PNG is found.
    """
    filenames = sorted(glob.glob(path + "/*.png"))
    # Build the stacked array once and reuse it for both the log line and the return.
    images = np.array([image.imread(name) for name in filenames], dtype=np.float32)
    print("train_X shape:", images.shape)
    return images

def process_training(data_set):
    """
    Convert a stack of images to float32, rescale if needed, and add a channel axis.

    The data is divided by 255 only when its maximum exceeds 1.0; otherwise the
    values are left as they are. A trailing axis of length one is appended so a
    (N, H, W) input becomes (N, H, W, 1).
    """
    pixels = data_set.astype(np.float32)
    # Values above 1.0 are taken to mean the data still needs scaling by 255.
    needs_rescale = pixels.max() > 1.0
    normalised = pixels / 255.0 if needs_rescale else pixels
    return normalised[:, :, :, None]  # (N, H, W) -> (N, H, W, 1)

def load_labels(path, class_values=None):
    """
    Load PNG label masks from ``path`` and map pixel values to integer class IDs.

    Files are read in sorted filename order, because ``glob.glob`` returns
    entries in arbitrary order and the masks must line up index-by-index with
    images loaded by a separate directory scan.

    One pixel-value -> class-ID table is shared by every mask: the ID of a pixel
    value is its rank among the sorted class values. Ranking the values of each
    mask separately would give shifted IDs to any slice in which a class is
    absent.

    Args:
        path (str): Directory containing the ``*.png`` masks.
        class_values (sequence, optional): Pixel values that denote the classes.
            When omitted, the sorted union of all pixel values found in the
            directory is used. Pass it explicitly to keep the mapping identical
            across directories that may not each contain every class.

    Returns:
        np.ndarray: uint8 array stacking one class-ID mask per file along a new
        first axis. An empty array of shape ``(0,)`` is returned when no PNG is
        found.

    Raises:
        ValueError: If ``class_values`` is given and a mask contains a pixel
            value that is not listed in it.
    """
    masks = [image.imread(name) for name in sorted(glob.glob(path + "/*.png"))]

    if class_values is not None:
        class_values = np.unique(np.asarray(class_values))
        for mask in masks:
            if not np.isin(mask, class_values).all():
                raise ValueError("label mask contains a pixel value not listed in class_values")
    elif masks:
        # Sorted union of every pixel value seen anywhere in the directory.
        class_values = np.unique(np.concatenate([np.unique(mask) for mask in masks]))
    else:
        class_values = np.array([], dtype=np.uint8)

    # class_values is sorted, so the insertion index of a pixel value is its class ID.
    label_maps = [np.searchsorted(class_values, mask).astype(np.uint8) for mask in masks]
    labels = np.array(label_maps, dtype=np.uint8)
    print("train_y shape:", labels.shape)
    return labels

def process_labels(seg_data, n_classes=4):
    """
    Convert integer class-ID masks into one-hot encoded masks.

    Channel ``k`` of the output is 1 exactly where the input equals ``k``. The
    class ID itself selects the channel, so a mask in which some class is absent
    still places the remaining classes in their own channels.

    Args:
        seg_data (np.ndarray): Integer class IDs, e.g. of shape (N, H, W).
        n_classes (int): Number of output channels. Defaults to 4.

    Returns:
        np.ndarray: uint8 array of shape ``seg_data.shape + (n_classes,)``.

    Raises:
        ValueError: If ``seg_data`` holds an ID outside ``[0, n_classes)``.
    """
    seg_data = np.asarray(seg_data)
    if seg_data.size and (seg_data.min() < 0 or seg_data.max() >= n_classes):
        raise ValueError(
            f"class IDs must lie in [0, {n_classes}); found range "
            f"[{seg_data.min()}, {seg_data.max()}]"
        )

    onehot_Y = np.zeros(seg_data.shape + (n_classes,), dtype=np.uint8)
    # Fill one channel at a time: the only temporary is a boolean mask per class.
    for class_id in range(n_classes):
        onehot_Y[..., class_id] = seg_data == class_id
    print("Labels shape:", onehot_Y.shape)
    return onehot_Y

In [None]:
# Training split: image slices and their segmentation masks live in sibling folders.
train_X = process_training(
    load_training("oasis/keras_png_slices_data/keras_png_slices_train")
)
train_Y = process_labels(
    load_labels("oasis/keras_png_slices_data/keras_png_slices_seg_train")
)
train_dataset = OasisDataset(images=train_X, labels=train_Y)
train_loader = DataLoader(dataset=train_dataset, batch_size=16, shuffle=True)

# Test split: same pipeline, but batches are served in a fixed order.
test_X = process_training(
    load_training("oasis/keras_png_slices_data/keras_png_slices_test")
)
test_Y = process_labels(
    load_labels("oasis/keras_png_slices_data/keras_png_slices_seg_test")
)
test_dataset = OasisDataset(images=test_X, labels=test_Y)
test_loader = DataLoader(dataset=test_dataset, batch_size=16, shuffle=False)

In [None]:
class DoubleConv(nn.Module):
    """
    Two stacked (3x3 convolution -> batch norm -> ReLU) stages, as used in UNet.

    The first stage maps ``in_ch`` to ``out_ch`` channels and the second keeps
    ``out_ch``. Padding of 1 preserves the spatial size, and the convolutions
    carry no bias term.
    """

    def __init__(self, in_ch: int, out_ch: int):
        super().__init__()
        stages = []
        stage_in = in_ch
        for _ in range(2):
            stages.extend([
                nn.Conv2d(stage_in, out_ch, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
            # Every stage after the first consumes the block's output width.
            stage_in = out_ch
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x):
        out = self.double_conv(x)
        return out

In [None]:
class UNet(nn.Module):
    """
    UNet for 2D segmentation with a five-level encoder and a mirrored decoder.

    With the default ``base_filters=32`` the encoder widths are
    1 -> 32 -> 64 -> 128 -> 256 -> 512. Each decoder level upsamples with a
    stride-2 transposed convolution, concatenates the matching encoder feature
    map (skip connection) and applies a ``DoubleConv``.

    The network returns raw per-class logits; no softmax is applied here.

    Input height and width do not have to be multiples of 16: when pooling
    rounded an odd size down, the upsampled map is zero-padded to the size of
    its skip connection before concatenation. Inputs smaller than 16 pixels per
    side are not supported, since they cannot be halved four times.

    Args:
        n_classes (int): Number of output channels (one logit map per class).
        input_channels (int): Number of channels in the input image.
        base_filters (int): Width of the first encoder level; each deeper level
            doubles it.
    """
    def __init__(self, n_classes: int = 4, input_channels: int = 1, base_filters: int = 32):
        super().__init__()
        f = base_filters
        # Encoder (channel counts in the comments assume base_filters=32)
        self.inc = DoubleConv(input_channels, f)                            # 1 -> 32
        self.down1 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(f, f*2))     # 32 -> 64
        self.down2 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(f*2, f*4))   # 64 -> 128
        self.down3 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(f*4, f*8))   # 128 -> 256
        self.down4 = nn.Sequential(nn.MaxPool2d(2), DoubleConv(f*8, f*16))  # 256 -> 512

        # Decoder: each DoubleConv sees upsampled + skip channels, hence twice the width
        self.up1 = nn.ConvTranspose2d(f*16, f*8, kernel_size=2, stride=2)   # 512 -> 256
        self.dec1 = DoubleConv(f*16, f*8)

        self.up2 = nn.ConvTranspose2d(f*8, f*4, kernel_size=2, stride=2)    # 256 -> 128
        self.dec2 = DoubleConv(f*8, f*4)

        self.up3 = nn.ConvTranspose2d(f*4, f*2, kernel_size=2, stride=2)    # 128 -> 64
        self.dec3 = DoubleConv(f*4, f*2)

        self.up4 = nn.ConvTranspose2d(f*2, f, kernel_size=2, stride=2)      # 64 -> 32
        self.dec4 = DoubleConv(f*2, f)

        # Final 1x1 conv to produce logits for each class
        self.outc = nn.Conv2d(f, n_classes, kernel_size=1)

    @staticmethod
    def _match_size(upsampled: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Zero-pad ``upsampled`` so its (H, W) equals that of ``skip``."""
        diff_h = skip.size(2) - upsampled.size(2)
        diff_w = skip.size(3) - upsampled.size(3)
        if diff_h == 0 and diff_w == 0:
            # Sizes already agree (always the case for multiples of 16).
            return upsampled
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        return nn.functional.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )

    def _decode(self, up: nn.Module, dec: nn.Module,
                x: torch.Tensor, skip: torch.Tensor) -> torch.Tensor:
        """Run one decoder level: upsample, merge with the skip map, convolve."""
        upsampled = self._match_size(up(x), skip)
        return dec(torch.cat([upsampled, skip], dim=1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits of shape (N, n_classes, H, W) for input (N, C, H, W)."""
        # Encoder
        x1 = self.inc(x)     # (N,   f,    H, W)
        x2 = self.down1(x1)  # (N,  2f,  H/2, W/2)
        x3 = self.down2(x2)  # (N,  4f,  H/4, W/4)
        x4 = self.down3(x3)  # (N,  8f,  H/8, W/8)
        x5 = self.down4(x4)  # (N, 16f, H/16, W/16)

        # Decoder with skip connections, deepest level first
        d1 = self._decode(self.up1, self.dec1, x5, x4)
        d2 = self._decode(self.up2, self.dec2, d1, x3)
        d3 = self._decode(self.up3, self.dec3, d2, x2)
        d4 = self._decode(self.up4, self.dec4, d3, x1)

        logits = self.outc(d4)  # (N, n_classes, H, W)
        return logits