In [1]:
import os
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from torchvision import models
import segmentation_models_pytorch as smp
import albumentations as A
from albumentations.pytorch import ToTensorV2
from PIL import Image
from tifffile import imread

### Setting up the directories

In [None]:
# Resolve the dataset layout relative to the current working directory:
#   <cwd>/PathologyData/train and <cwd>/PathologyData/test
base_dir = os.getcwd()
data_dir = os.path.join(base_dir, 'PathologyData')
train_dir = os.path.join(data_dir, 'train')
test_dir = os.path.join(data_dir, 'test')

# Echo every resolved location so a wrong working directory is easy to spot.
print("Base directory: ", base_dir)
print("Data directory: ", data_dir)
print("Train directory: ", train_dir)
print("Test directory: ", test_dir)

### Set up the train and test directories for the original and mask images

In [3]:
# Paths to the datasets
# Each split keeps its raw images under 'Original' and its masks under 'Masked'.
train_images_dir, train_masks_dir = (
    os.path.join(train_dir, sub) for sub in ('Original', 'Masked')
)
test_images_dir, test_masks_dir = (
    os.path.join(test_dir, sub) for sub in ('Original', 'Masked')
)

### Reading the Pathology dataset
#### Original images are RGB images stored in .tif format
#### Mask images are greyscale images stored in .out.png format

In [4]:
# Pathology Dataset
class PathologyDataset(Dataset):
    """Paired image/mask dataset for binary segmentation.

    Every regular, non-hidden file ``<stem>.<ext>`` in ``images_dir`` is paired
    with the mask ``<stem>.out.png`` in ``masks_dir``. Images are read with
    ``tifffile.imread`` and coerced to three channels; masks are read as 8-bit
    greyscale and binarised (any non-zero pixel becomes 1).

    Args:
        images_dir: Directory containing the input images.
        masks_dir: Directory containing the ``.out.png`` masks.
        transform: Optional callable invoked as ``transform(image=..., mask=...)``
            that returns a mapping with ``'image'`` and ``'mask'`` entries.
    """

    def __init__(self, images_dir, masks_dir, transform=None):
        self.images_dir = images_dir
        self.masks_dir = masks_dir
        # os.listdir returns entries in arbitrary order and also lists
        # sub-directories and hidden entries (e.g. '.ipynb_checkpoints',
        # '.DS_Store'). Keep only visible regular files and sort them so that
        # index -> sample is stable between runs.
        self.image_files = sorted(
            entry for entry in os.listdir(images_dir)
            if not entry.startswith('.')
            and os.path.isfile(os.path.join(images_dir, entry))
        )
        self.transform = transform

    def __len__(self):
        """Return the number of image files found in ``images_dir``."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load, clean and (optionally) transform the sample at ``idx``.

        Returns:
            ``(image, mask)``: the 3-channel image and the binary ``uint8``
            mask, or whatever ``transform`` produced for them.

        Raises:
            FileNotFoundError: If the image or its mask does not exist.
            RuntimeError: If either file exists but cannot be decoded.
        """
        img_name = self.image_files[idx]
        img_path = os.path.join(self.images_dir, img_name)

        # The mask shares the image's stem but carries the '.out.png' suffix.
        mask_name = os.path.splitext(img_name)[0] + '.out.png'
        mask_path = os.path.join(self.masks_dir, mask_name)

        # Fail loudly: a (None, None) sample cannot be batched by the default
        # DataLoader collate function and would only surface as an unrelated
        # error much later.
        if not os.path.exists(img_path):
            raise FileNotFoundError(f"Image not found: {img_path}")

        if not os.path.exists(mask_path):
            raise FileNotFoundError(f"Mask not found: {mask_path}")

        # Read .tif files using tifffile and .png using PIL
        try:
            image = imread(img_path)
            # The context manager closes the underlying file handle once the
            # pixel data has been copied into the array.
            with Image.open(mask_path) as mask_file:
                mask = np.array(mask_file.convert("L"))
        except Exception as e:
            raise RuntimeError(
                f"Error reading image or mask ({img_path}, {mask_path}): {e}"
            ) from e

        # Convert to 3-channel RGB if necessary
        if image.ndim == 2:  # If grayscale, convert to RGB
            image = np.stack([image] * 3, axis=-1)
        elif image.shape[2] == 4:  # If RGBA, convert to RGB
            image = image[:, :, :3]

        # convert("L") always yields a single-channel (H, W) array, so no
        # channel selection is needed before binarising.
        mask = (mask > 0).astype(np.uint8)

        if self.transform:
            augmented = self.transform(image=image, mask=mask)
            image = augmented['image']
            mask = augmented['mask']

        return image, mask

### Image transformations for the ResNet model
#### Normalize the images with the ImageNet mean and standard deviation used to pre-train the ResNet encoder

In [5]:
# Define transformations
# Training pipeline, in three stages: fixed-size resize, random geometric
# augmentation, then normalisation and conversion to a tensor.
train_transform = A.Compose(
    [A.Resize(height=256, width=256)]
    + [A.HorizontalFlip(p=0.5), A.VerticalFlip(p=0.5), A.RandomRotate90(p=0.5)]
    + [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

# Evaluation pipeline: same resize and normalisation, but no random augmentation.
test_transform = A.Compose(
    [A.Resize(height=256, width=256)]
    + [
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
)

### Creating the dataloaders to train and test

In [6]:
# Dataset and Dataloader
# One dataset per split; only the training split uses the augmenting transform.
train_dataset = PathologyDataset(
    images_dir=train_images_dir,
    masks_dir=train_masks_dir,
    transform=train_transform,
)
test_dataset = PathologyDataset(
    images_dir=test_images_dir,
    masks_dir=test_masks_dir,
    transform=test_transform,
)

# Shuffle the training batches; keep the evaluation order fixed.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=0,
)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=0,
)

### Unet Model with ResNet as the encoder

Using ResNet34 as the encoder (or backbone) in a UNet model for image segmentation offers several benefits:

#### Benefits of Using ResNet34

- **Pre-trained weights:** ResNet34 is often available with pre-trained weights on large datasets like ImageNet. Using these pre-trained weights can significantly improve the performance of the model through transfer learning, especially if the dataset is small.
- **Deep architecture:** ResNet34 is a relatively deep network, which helps in learning complex features. This depth helps the model to understand and capture more intricate details in the images, which is crucial for tasks like segmentation.
- **Residual connections:** The ResNet architecture includes residual connections, which help in mitigating the vanishing gradient problem. This makes training deep networks more effective and allows for the construction of deeper models.
- **Versatility:** ResNet34 has shown great performance across various computer vision tasks, making it a versatile choice for different types of image analysis, including segmentation.

In [7]:
# Define the model
# U-Net whose encoder is a ResNet34 initialised from ImageNet weights.
model = smp.Unet(
    in_channels=3,               # three-channel (RGB) input
    classes=1,                   # one output channel of raw logits
    encoder_name="resnet34",
    encoder_weights="imagenet",
)

BCEWithLogitsLoss combines a sigmoid activation function with the binary cross-entropy loss in a numerically stable way. Logits are the raw scores produced by the model before applying the sigmoid function, whereas probabilities are the output of the sigmoid function (ranging between 0 and 1).

In [8]:
# Loss function and optimizer
# Adam updates every model parameter with a fixed learning rate of 1e-4.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
# Binary cross-entropy computed directly on the model's raw logits.
loss_fn = nn.BCEWithLogitsLoss()

#### Training function

In [9]:
# Training function
def train_epoch(model, dataloader, loss_fn, optimizer, device):
    """Run one optimisation pass over ``dataloader`` and return the mean loss.

    The returned value is the per-sample average: each batch loss is weighted
    by the number of samples in that batch, so a smaller final batch does not
    count as much as a full one. This assumes ``loss_fn`` returns a mean over
    the batch, as ``nn.BCEWithLogitsLoss`` does by default.

    Args:
        model: Network to train; switched to training mode.
        dataloader: Iterable yielding ``(images, masks)`` batches. It does not
            need to implement ``__len__``.
        loss_fn: Callable ``loss_fn(outputs, masks)`` returning a scalar tensor.
        optimizer: Optimizer bound to ``model``'s parameters.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: Sample-weighted mean of the batch losses for this epoch.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    total_samples = 0
    for images, masks in dataloader:
        images = images.to(device)
        # unsqueeze(1) inserts a channel axis and float() casts the mask to
        # the floating-point target type the loss expects.
        masks = masks.to(device).unsqueeze(1).float()

        optimizer.zero_grad()
        outputs = model(images)
        loss = loss_fn(outputs, masks)
        loss.backward()
        optimizer.step()

        batch_samples = images.size(0)
        total_loss += loss.item() * batch_samples
        total_samples += batch_samples

    if total_samples == 0:
        raise ValueError("dataloader produced no samples; cannot compute an average loss")

    return total_loss / total_samples

#### Validation/Test function

In [10]:
# Validation function
def validate_epoch(model, dataloader, loss_fn, device):
    """Evaluate ``model`` on ``dataloader`` and return the mean loss.

    No gradients are computed and no parameters are updated. The returned
    value is the per-sample average: each batch loss is weighted by the number
    of samples in that batch, which assumes ``loss_fn`` returns a mean over
    the batch, as ``nn.BCEWithLogitsLoss`` does by default.

    Args:
        model: Network to evaluate; switched to evaluation mode.
        dataloader: Iterable yielding ``(images, masks)`` batches. It does not
            need to implement ``__len__``.
        loss_fn: Callable ``loss_fn(outputs, masks)`` returning a scalar tensor.
        device: Device the batches are moved to before the forward pass.

    Returns:
        float: Sample-weighted mean of the batch losses.

    Raises:
        ValueError: If ``dataloader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    total_samples = 0
    with torch.no_grad():
        for images, masks in dataloader:
            images = images.to(device)
            # unsqueeze(1) inserts a channel axis and float() casts the mask
            # to the floating-point target type the loss expects.
            masks = masks.to(device).unsqueeze(1).float()

            outputs = model(images)
            loss = loss_fn(outputs, masks)

            batch_samples = images.size(0)
            total_loss += loss.item() * batch_samples
            total_samples += batch_samples

    if total_samples == 0:
        raise ValueError("dataloader produced no samples; cannot compute an average loss")

    return total_loss / total_samples

In [11]:
# Training loop
# Use the GPU when PyTorch can see one; otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Device: ", device)
num_epochs = 5
# Last expression of the cell: moves the model and displays its architecture.
model.to(device)

Device:  cpu


Unet(
  (encoder): ResNetEncoder(
    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
    (layer1): Sequential(
      (0): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (relu): ReLU(inplace=True)
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      )
      (1): BasicBlock(
        (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track

#### Testing the dataset with one image (for debugging)

In [12]:
# Smoke test: fetch the first training sample and show the tensor shapes.
dataset = PathologyDataset(
    images_dir=train_images_dir,
    masks_dir=train_masks_dir,
    transform=train_transform,
)
image, mask = dataset[0]
print(image.shape, mask.shape)

torch.Size([3, 256, 256]) torch.Size([256, 256])


#### the Train-Test run

In [13]:
# Alternate one training pass and one evaluation pass per epoch, reporting both losses.
for epoch in range(num_epochs):
    train_loss = train_epoch(
        model=model,
        dataloader=train_loader,
        loss_fn=loss_fn,
        optimizer=optimizer,
        device=device,
    )
    test_loss = validate_epoch(
        model=model,
        dataloader=test_loader,
        loss_fn=loss_fn,
        device=device,
    )

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_loss:.4f}, test Loss: {test_loss:.4f}')

Epoch 1/5, Train Loss: 0.8259, test Loss: 0.9162
Epoch 2/5, Train Loss: 0.7673, test Loss: 0.7882
Epoch 3/5, Train Loss: 0.7292, test Loss: 0.6857
Epoch 4/5, Train Loss: 0.7015, test Loss: 0.6729
Epoch 5/5, Train Loss: 0.6787, test Loss: 0.6533
