In [1]:
# !pip install torchvision

In [2]:
# !pip uninstall typing_extensions
# !pip install typing_extensions==4.11.0

In [3]:
!pip install typing_extensions>=4.3 --upgrade

In [4]:
# !pip install --upgrade pydantic

In [5]:
!pip install typing_extensions==4.12.2 --upgrade
# pip install typing_extensions==4.7.1 --upgrade

Defaulting to user installation because normal site-packages is not writeable
Collecting typing_extensions==4.12.2
  Using cached typing_extensions-4.12.2-py3-none-any.whl (37 kB)
Installing collected packages: typing_extensions
  Attempting uninstall: typing_extensions
    Found existing installation: typing_extensions 4.13.2
    Uninstalling typing_extensions-4.13.2:
      Successfully uninstalled typing_extensions-4.13.2
Successfully installed typing_extensions-4.12.2


In [6]:
from typing_extensions import TypeIs

**Preprocessing - flipping, resizing, rotation, gamma correction**

In [7]:
# import os
# from torchvision import transforms
# from torch.utils.data import Dataset, DataLoader, random_split
# from PIL import Image

# class CustomDataset(Dataset):
#     def __init__(self, image_dir, mask_dir, transform=None):
#         self.image_dir = image_dir
#         self.mask_dir = mask_dir
#         self.transform = transform
#         self.images = [os.path.join(image_dir, x) for x in os.listdir(image_dir) if x.endswith('.png')]
#         self.masks = [os.path.join(mask_dir, x) for x in os.listdir(mask_dir) if 'Annotation' in x]

#     def __len__(self):
#         return len(self.images)

#     def __getitem__(self, idx):
#         image_path = self.images[idx]
#         mask_path = self.masks[idx]
#         image = Image.open(image_path).convert("RGB")
#         mask = Image.open(mask_path).convert("L")
#         if self.transform:
#             image = self.transform(image)
#             mask = self.transform(mask)
#         return image, mask

# # Define transformations including geometric and intensity-based augmentations
# transform = transforms.Compose([
#     transforms.Resize((256, 256)),  # Resize images to match U-Net expected input
#     transforms.RandomHorizontalFlip(),  # Random horizontal flipping
#     transforms.RandomVerticalFlip(),  # Random vertical flipping
#     transforms.RandomRotation(20),  # Random rotations between -20 to 20 degrees
#     transforms.ColorJitter(brightness=0.2, contrast=0.2),  # Random brightness and contrast adjustments
#     transforms.ToTensor(),
#     transforms.Lambda(lambda x: x.pow(0.5))  # Gamma correction with gamma=0.5
# ])

# # Initialize dataset
# full_dataset = CustomDataset('denoised_training_set', 'masked_annotations', transform=transform)

# # Splitting the dataset into train and validation sets
# train_size = int(0.8 * len(full_dataset))
# validation_size = len(full_dataset) - train_size
# train_dataset, validation_dataset = random_split(full_dataset, [train_size, validation_size])

# # Create separate dataloaders for train and validation datasets
# train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
# validation_loader = DataLoader(validation_dataset, batch_size=10, shuffle=False)

In [8]:
import os
import numpy as np
from torchvision import transforms
from torchvision.transforms import functional as TF
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image

# Custom joint transformation that applies geometric transforms to both image and mask,
# and intensity transforms only to the image.
class JointTransform:
    def __init__(self, resize=(572, 572), rotation=20, hflip_prob=0.5, vflip_prob=0.5,
                 intensity_transforms=None):
        self.resize = resize
        self.rotation = rotation
        self.hflip_prob = hflip_prob
        self.vflip_prob = vflip_prob
        # Intensity transforms should be a torchvision transform applied only to the image.
        # For example, transforms.Compose([transforms.ColorJitter(brightness=0.2, contrast=0.2),
        #                                  transforms.Lambda(lambda x: x.pow(0.5))])
        self.intensity_transforms = intensity_transforms

    def __call__(self, image, mask):
        # 1. Resize both image and mask
        image = TF.resize(image, self.resize)
        mask = TF.resize(mask, self.resize)
        
        # 2. Random horizontal flip
        if np.random.rand() < self.hflip_prob:
            image = TF.hflip(image)
            mask = TF.hflip(mask)
            
        # 3. Random vertical flip
        if np.random.rand() < self.vflip_prob:
            image = TF.vflip(image)
            mask = TF.vflip(mask)
            
        # 4. Random rotation
        angle = np.random.uniform(-self.rotation, self.rotation)
        image = TF.rotate(image, angle, interpolation=Image.BILINEAR)
        # For masks, use nearest neighbor interpolation to preserve label boundaries.
        mask = TF.rotate(mask, angle, interpolation=Image.NEAREST)
        
        # 5. Apply intensity transforms to the image only (if provided)
        if self.intensity_transforms:
            image = self.intensity_transforms(image)
        
        # 6. Convert both image and mask to tensor
        image = TF.to_tensor(image)
        mask = TF.to_tensor(mask)
        # Optionally ensure the mask is binary
        mask = (mask > 0.5).float()
        
        return image, mask

# Define intensity-only transformations for the image.
intensity_transforms = transforms.Compose([
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.Lambda(lambda x: TF.adjust_gamma(x, 0.5))  # Gamma correction with gamma=0.5
])

# Create the joint transformation instance
joint_transform = JointTransform(resize=(572, 572), rotation=20,
                                 hflip_prob=0.5, vflip_prob=0.5,
                                 intensity_transforms=intensity_transforms)

class CustomDataset(Dataset):
    def __init__(self, image_dir, mask_dir, joint_transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.joint_transform = joint_transform
        # Sort the file lists to ensure alignment between images and masks
        self.images = sorted([os.path.join(image_dir, x) for x in os.listdir(image_dir) if x.endswith('.png')])
        self.masks = sorted([os.path.join(mask_dir, x) for x in os.listdir(mask_dir) if 'Annotation' in x])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image_path = self.images[idx]
        mask_path = self.masks[idx]
        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")
        
        if self.joint_transform:
            image, mask = self.joint_transform(image, mask)
        return image, mask

# Initialize dataset with the joint transform
full_dataset = CustomDataset('denoised_training_set', 'masked_annotations', joint_transform=joint_transform)

# Splitting the dataset into train and validation sets
# train_size = int(0.8 * len(full_dataset))
# validation_size = len(full_dataset) - train_size
# train_dataset, validation_dataset = random_split(full_dataset, [train_size, validation_size])

# # Create separate dataloaders for train and validation datasets
# train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
# validation_loader = DataLoader(validation_dataset, batch_size=10, shuffle=False)

**Checking alignment and order of train image and the corresponding masking**

In [9]:
for img_path, mask_path in zip(sorted(os.listdir('denoised_training_set')), 
                               sorted(os.listdir('masked_annotations'))):
    if img_path.endswith('.png') and 'Annotation' in mask_path:
        base_img = os.path.splitext(img_path)[0]
        base_mask = os.path.splitext(mask_path)[0].replace('_Annotation', '')
        assert base_img == base_mask, f"Mismatch: {base_img} vs {base_mask}"


**Unet with resnet101 as backbone**

In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models

class ResConv(nn.Module):
    """ Convolution block for U-Net with repeated convolutions and ReLU activations. """
    def __init__(self, in_ch, out_ch):
        super(ResConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.conv(x)

class UpConv(nn.Module):
    """ Upsampling block for U-Net, using bilinear interpolation and convolution. """
    def __init__(self, in_ch, out_ch):
        super(UpConv, self).__init__()
        self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.conv = ResConv(in_ch, out_ch)

    def forward(self, from_down, from_up):
        from_up = self.up(from_up)
        diffY = from_down.size()[2] - from_up.size()[2]
        diffX = from_down.size()[3] - from_up.size()[3]
        from_up = F.pad(from_up, [diffX // 2, diffX - diffX // 2,
                                  diffY // 2, diffY - diffY // 2])
        x = torch.cat([from_down, from_up], dim=1)
        return self.conv(x)

class UNetResNet101(nn.Module):
    def __init__(self, n_classes=1):
        super(UNetResNet101, self).__init__()
        base_model = models.resnet101(pretrained=True)
        self.base_layers = list(base_model.children())
        
        # Extract layers from ResNet101
        self.layer0 = nn.Sequential(*self.base_layers[:3])  # conv1, bn1, relu
        self.maxpool = self.base_layers[3]
        self.layer1 = self.base_layers[4]  # Output: 256 channels
        self.layer2 = self.base_layers[5]  # Output: 512 channels
        self.layer3 = self.base_layers[6]  # Output: 1024 channels
        self.layer4 = self.base_layers[7]  # Output: 2048 channels

        # Decoder (make sure the channel numbers match the skip connection outputs)
        self.up4 = UpConv(2048 + 1024, 1024)  # Concatenating x3 (1024) and x4 (2048) -> 3072 channels
        self.up3 = UpConv(1024 + 512, 512)    # Concatenating x2 (512) and previous output (1024) -> 1536 channels
        self.up2 = UpConv(512 + 256, 256)     # Concatenating x1 (256) and previous output (512) -> 768 channels
        self.up1 = UpConv(256 + 64, 64)       # Concatenating x0 (64) and previous output (256) -> 320 channels

        # Final upsampling and output convolution to match input size
        self.final_up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
        self.final_conv = nn.Conv2d(64, n_classes, kernel_size=1)

    def forward(self, x):
        # Encoder: get intermediate features for skip connections
        x0 = self.layer0(x)       # Early features, e.g., 64 channels
        x1 = self.maxpool(x0)
        x1 = self.layer1(x1)      # 256 channels
        x2 = self.layer2(x1)      # 512 channels
        x3 = self.layer3(x2)      # 1024 channels
        x4 = self.layer4(x3)      # 2048 channels

        # Decoder: use skip connections from intermediate features
        x = self.up4(x3, x4)      # Upsample: x3 (from_down) + x4 (from_up)
        x = self.up3(x2, x)       # Upsample: x2 + output of previous block
        x = self.up2(x1, x)       # Upsample: x1 + output of previous block
        x = self.up1(x0, x)       # Upsample: x0 + output of previous block

        x = self.final_up(x)      # Final upsampling to the original size
        x = self.final_conv(x)
        return torch.sigmoid(x)

**Dice loss + Binary cross-entropy loss**

In [11]:
import torch.nn.functional as F

class DiceBCELoss(nn.Module):
    def __init__(self):
        super(DiceBCELoss, self).__init__()

    def forward(self, inputs, targets, smooth=1):
        # Flatten label and prediction tensors
        inputs = inputs.view(-1)
        targets = targets.view(-1)
        
        intersection = (inputs * targets).sum()                            
        dice_loss = 1 - (2. * intersection + smooth) / (inputs.sum() + targets.sum() + smooth)  
        BCE = F.binary_cross_entropy(inputs, targets, reduction='mean')
        Dice_BCE = BCE + dice_loss
        
        return Dice_BCE

**Training Loop**

In [12]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cuda:0


In [13]:
# import torch.optim as optim
# from torch.optim.lr_scheduler import ReduceLROnPlateau
# import numpy as np
# from sklearn.metrics import f1_score

# # Assuming UNetResNet101 and DiceBCELoss are already imported
# model = UNetResNet101().to(device)  # Ensure your model is the one with ResNet-101
# loss_function = DiceBCELoss()
# optimizer = optim.Adam(model.parameters(), lr=1e-4)
# scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)

# num_epochs = 50  # Set the number of epochs you want to train for

# for epoch in range(num_epochs):
#     # Training loop
#     model.train()
#     running_loss = 0.0
#     for images, masks in train_loader:
#         images, masks = images.to(device), masks.to(device)
#         optimizer.zero_grad()
#         outputs = model(images)
#         loss = loss_function(outputs, masks)
#         loss.backward()
#         optimizer.step()
#         running_loss += loss.item()

#     # Compute training metrics in evaluation mode
#     model.eval()
#     train_loss = 0.0
#     all_train_preds = []
#     all_train_targets = []
#     with torch.no_grad():
#         for images, masks in train_loader:
#             images, masks = images.to(device), masks.to(device)
#             outputs = model(images)
#             loss = loss_function(outputs, masks)
#             train_loss += loss.item()
#             # Threshold outputs and targets at 0.5 to obtain binary predictions
#             preds = (outputs > 0.5).float()
#             binary_masks = (masks > 0.5).float()
#             all_train_preds.append(preds.cpu().numpy().flatten())
#             all_train_targets.append(binary_masks.cpu().numpy().flatten())
#     all_train_preds = np.concatenate(all_train_preds)
#     all_train_targets = np.concatenate(all_train_targets)
#     train_f1 = f1_score(all_train_targets, all_train_preds)

#     # Compute validation metrics
#     val_loss = 0.0
#     all_val_preds = []
#     all_val_targets = []
#     with torch.no_grad():
#         for images, masks in validation_loader:
#             images, masks = images.to(device), masks.to(device)
#             outputs = model(images)
#             loss = loss_function(outputs, masks)
#             val_loss += loss.item()
#             preds = (outputs > 0.5).float()
#             binary_masks = (masks > 0.5).float()
#             all_val_preds.append(preds.cpu().numpy().flatten())
#             all_val_targets.append(binary_masks.cpu().numpy().flatten())
#     all_val_preds = np.concatenate(all_val_preds)
#     all_val_targets = np.concatenate(all_val_targets)
#     val_f1 = f1_score(all_val_targets, all_val_preds)

#     # Print epoch summary with both training and validation metrics
#     print(f'Epoch {epoch+1}/{num_epochs}, '
#           f'Train Loss: {running_loss/len(train_loader):.4f}, Train F1: {train_f1:.4f}, '
#           f'Validation Loss: {val_loss/len(validation_loader):.4f}, Validation F1: {val_f1:.4f}')
    
#     # Adjust learning rate based on the validation loss
#     scheduler.step(val_loss/len(validation_loader))

# # Save the trained model
# torch.save(model.state_dict(), 'unet_resnet101_model.pth')

In [14]:
import os
import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import numpy as np
from sklearn.metrics import f1_score
from sklearn.model_selection import KFold
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import cv2

# Assume full_dataset, device, UNetResNet101, DiceBCELoss, and joint_transform are defined already.
# full_dataset = CustomDataset('denoised_training_set', 'masked_annotations', joint_transform=joint_transform)

# Number of folds and training epochs
num_folds = 5
num_epochs = 30
patience = 5

# Set up KFold splitter
indices = np.arange(len(full_dataset))
kfold = KFold(n_splits=num_folds, shuffle=True, random_state=42)

# Folder to save overlay images (optional)
overlay_folder = 'overlay_images'
os.makedirs(overlay_folder, exist_ok=True)

# Helper function for overlay (as before)
def overlay_mask_on_image(image, mask, color=(255, 0, 0), alpha=0.4):
    color_mask = np.zeros_like(image)
    color_mask[mask == 255] = color
    overlay = cv2.addWeighted(image, 1 - alpha, color_mask, alpha, 0)
    return overlay

# Variables to track the best model overall (based on validation F1 score)
best_val_f1_overall = 0.0
best_model_state = None

# Begin cross-validation loop
for fold, (train_idx, val_idx) in enumerate(kfold.split(indices)):
    print(f"Starting fold {fold+1}/{num_folds}")
    
    # Create Subset datasets for current fold
    train_subset = Subset(full_dataset, train_idx)
    val_subset = Subset(full_dataset, val_idx)
    
    # Create dataloaders for current fold
    train_loader = DataLoader(train_subset, batch_size=10, shuffle=True)
    validation_loader = DataLoader(val_subset, batch_size=10, shuffle=False)
    
    # Reinitialize the model, loss, optimizer, and scheduler for each fold
    model = UNetResNet101().to(device)
    loss_function = DiceBCELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-4)
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=10, verbose=True)
    
    best_fold_val_f1 = -1.0
    early_stop_counter = 0
    
    for epoch in range(num_epochs):
        # Training phase
        model.train()
        running_loss = 0.0
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_function(outputs, masks)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        train_loss_epoch = running_loss / len(train_loader)
        
        # Evaluate training set for F1 score
        model.eval()
        all_train_preds = []
        all_train_targets = []
        with torch.no_grad():
            for images, masks in train_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                preds = (outputs > 0.5).float()
                binary_masks = (masks > 0.5).float()
                all_train_preds.append(preds.cpu().numpy().flatten())
                all_train_targets.append(binary_masks.cpu().numpy().flatten())
        all_train_preds = np.concatenate(all_train_preds)
        all_train_targets = np.concatenate(all_train_targets)
        train_f1 = f1_score(all_train_targets, all_train_preds)
        
        # Evaluate validation set
        val_loss = 0.0
        all_val_preds = []
        all_val_targets = []
        with torch.no_grad():
            for images, masks in validation_loader:
                images, masks = images.to(device), masks.to(device)
                outputs = model(images)
                loss = loss_function(outputs, masks)
                val_loss += loss.item()
                preds = (outputs > 0.5).float()
                binary_masks = (masks > 0.5).float()
                all_val_preds.append(preds.cpu().numpy().flatten())
                all_val_targets.append(binary_masks.cpu().numpy().flatten())
        all_val_preds = np.concatenate(all_val_preds)
        all_val_targets = np.concatenate(all_val_targets)
        val_f1 = f1_score(all_val_targets, all_val_preds)
        val_loss_epoch = val_loss / len(validation_loader)
        
        print(f"Fold {fold+1}, Epoch {epoch+1}/{num_epochs}, "
              f"Train Loss: {train_loss_epoch:.4f}, Train F1: {train_f1:.4f}, "
              f"Val Loss: {val_loss_epoch:.4f}, Val F1: {val_f1:.4f}")
        
        # Adjust learning rate based on validation loss
        scheduler.step(val_loss_epoch)
        
        # If current validation F1 is best so far, update best model state across folds/epochs
        if val_f1 > best_val_f1_overall:
            best_val_f1_overall = val_f1
            best_model_state = model.state_dict()
            
        # Early stopping for current fold: check if current validation F1 improved over best_fold_val_f1
        if val_f1 > best_fold_val_f1:
            best_fold_val_f1 = val_f1
            early_stop_counter = 0
        else:
            early_stop_counter += 1
        
        if early_stop_counter >= patience:
            print(f"Early stopping triggered in fold {fold+1} at epoch {epoch+1}")
            break
        
        # Optionally, every 10 epochs, save an overlay visualization from validation
        if (epoch + 1) % 10 == 0:
            with torch.no_grad():
                for images, _ in validation_loader:
                    images = images.to(device)
                    outputs = model(images)
                    preds = (outputs > 0.5).float()
                    img_tensor = images[0].cpu()  # first image from the batch
                    pred_tensor = preds[0].cpu()  # corresponding prediction
                    img_np = img_tensor.permute(1, 2, 0).numpy()
                    img_uint8 = (img_np * 255).astype(np.uint8)
                    mask_np = pred_tensor.squeeze().numpy()
                    mask_uint8 = (mask_np * 255).astype(np.uint8)
                    overlay_img = overlay_mask_on_image(img_uint8, mask_uint8, color=(255, 0, 0), alpha=0.4)
                    # Convert to BGR for cv2.imwrite
                    overlay_bgr = cv2.cvtColor(overlay_img, cv2.COLOR_RGB2BGR)
                    overlay_save_path = os.path.join(overlay_folder, f'fold{fold+1}_epoch{epoch+1}.png')
                    cv2.imwrite(overlay_save_path, overlay_bgr)
                    print(f"Saved overlay image at: {overlay_save_path}")
                    break  # Process only one batch for overlay visualization

# After all folds, save the best model
if best_model_state is not None:
    torch.save(best_model_state, 'best_unet_resnet101_model.pth')
    print(f"Best model saved with validation F1: {best_val_f1_overall:.4f}")
else:
    print("No model was saved.")

Starting fold 1/5




Fold 1, Epoch 1/30, Train Loss: 0.4800, Train F1: 0.9544, Val Loss: 0.3736, Val F1: 0.9460
Fold 1, Epoch 2/30, Train Loss: 0.3216, Train F1: 0.9647, Val Loss: 0.3113, Val F1: 0.9586
Fold 1, Epoch 3/30, Train Loss: 0.2732, Train F1: 0.9697, Val Loss: 0.2671, Val F1: 0.9653
Fold 1, Epoch 4/30, Train Loss: 0.2332, Train F1: 0.9719, Val Loss: 0.2222, Val F1: 0.9696
Fold 1, Epoch 5/30, Train Loss: 0.2105, Train F1: 0.9652, Val Loss: 0.2209, Val F1: 0.9627
Fold 1, Epoch 6/30, Train Loss: 0.1897, Train F1: 0.9686, Val Loss: 0.1940, Val F1: 0.9656
Fold 1, Epoch 7/30, Train Loss: 0.1707, Train F1: 0.9733, Val Loss: 0.1641, Val F1: 0.9718
Fold 1, Epoch 8/30, Train Loss: 0.1579, Train F1: 0.9730, Val Loss: 0.1610, Val F1: 0.9701
Fold 1, Epoch 9/30, Train Loss: 0.1454, Train F1: 0.9748, Val Loss: 0.1508, Val F1: 0.9714
Fold 1, Epoch 10/30, Train Loss: 0.1305, Train F1: 0.9765, Val Loss: 0.1397, Val F1: 0.9726
Saved overlay image at: overlay_images/fold1_epoch10.png
Fold 1, Epoch 11/30, Train Loss:

Fold 4, Epoch 5/30, Train Loss: 0.3518, Train F1: 0.9704, Val Loss: 0.3014, Val F1: 0.9723
Fold 4, Epoch 6/30, Train Loss: 0.3160, Train F1: 0.9718, Val Loss: 0.2811, Val F1: 0.9737
Fold 4, Epoch 7/30, Train Loss: 0.2832, Train F1: 0.9711, Val Loss: 0.2610, Val F1: 0.9717
Fold 4, Epoch 8/30, Train Loss: 0.2532, Train F1: 0.9710, Val Loss: 0.2298, Val F1: 0.9725
Fold 4, Epoch 9/30, Train Loss: 0.2314, Train F1: 0.9758, Val Loss: 0.2161, Val F1: 0.9768
Fold 4, Epoch 10/30, Train Loss: 0.2162, Train F1: 0.9735, Val Loss: 0.1959, Val F1: 0.9749
Saved overlay image at: overlay_images/fold4_epoch10.png
Fold 4, Epoch 11/30, Train Loss: 0.2103, Train F1: 0.9721, Val Loss: 0.1974, Val F1: 0.9723
Fold 4, Epoch 12/30, Train Loss: 0.1915, Train F1: 0.9726, Val Loss: 0.1687, Val F1: 0.9745
Fold 4, Epoch 13/30, Train Loss: 0.1796, Train F1: 0.9748, Val Loss: 0.1599, Val F1: 0.9754
Fold 4, Epoch 14/30, Train Loss: 0.1622, Train F1: 0.9769, Val Loss: 0.1503, Val F1: 0.9768
Early stopping triggered in 

In [15]:
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model = UNetResNet101(n_classes=1).to(device)
# dummy_input = torch.rand(1, 3, 224, 224).to(device)  # Adjust input size as necessary
# output = model(dummy_input)
# print("Output size:", output.size())

**Running Model on test data**

In [16]:
import os
import torch
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import cv2
import matplotlib.pyplot as plt

# Device configuration
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Define a test dataset class
class TestDataset(Dataset):
    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        self.images = sorted([os.path.join(image_dir, x) for x in os.listdir(image_dir) if x.endswith('.png')])
    
    def __len__(self):
        return len(self.images)
    
    def __getitem__(self, idx):
        image_path = self.images[idx]
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, image_path

# Define deterministic transforms for test data
test_transform = transforms.Compose([
    transforms.Resize((572, 572)),
    transforms.Lambda(lambda x: TF.adjust_gamma(x, 0.5)),
    transforms.ToTensor()
])

# Set the directory where your test images are stored
test_dir = 'denoised_test_set'  # Adjust this to your actual test directory path

# Create the test dataset and dataloader
test_dataset = TestDataset(test_dir, transform=test_transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

# Instantiate your model (assumed defined in the notebook)
model = UNetResNet101(n_classes=1).to(device)

# Load the saved model weights
model.load_state_dict(torch.load('best_unet_resnet101_model.pth', map_location=device))
model.eval()

# Create an output directory for the segmentation results
output_dir = 'output_segmentations'
os.makedirs(output_dir, exist_ok=True)

# Inference loop: run the model on each test image and save the segmentation mask
for image, image_path in test_loader:
    image = image.to(device)
    with torch.no_grad():
        output = model(image)  # Get the probability map from the model (sigmoid already applied)
        # Threshold the probability map to obtain a binary segmentation mask
        seg_mask = (output > 0.5).float()
    
    # Convert tensor to NumPy array and scale to 0-255 for visualization/saving
    seg_mask_np = seg_mask.cpu().numpy().squeeze() * 255

    # Generate output filename based on input image name
    base_name = os.path.basename(image_path[0])
    output_filename = os.path.join(output_dir, f"seg_{base_name}")
    
    # Save the segmentation mask image using OpenCV
    cv2.imwrite(output_filename, seg_mask_np.astype('uint8'))
    
    # Optionally, display the segmentation mask using matplotlib
#     plt.imshow(seg_mask_np, cmap='gray')
#     plt.title(f"Segmentation: {base_name}")
#     plt.axis('off')
#     plt.show()
    
    print(f"Saved segmentation for {base_name} at {output_filename}")

Saved segmentation for 500_HC.png at output_segmentations/seg_500_HC.png
Saved segmentation for 501_HC.png at output_segmentations/seg_501_HC.png
Saved segmentation for 502_HC.png at output_segmentations/seg_502_HC.png
Saved segmentation for 503_HC.png at output_segmentations/seg_503_HC.png
Saved segmentation for 504_HC.png at output_segmentations/seg_504_HC.png
Saved segmentation for 505_HC.png at output_segmentations/seg_505_HC.png
Saved segmentation for 506_HC.png at output_segmentations/seg_506_HC.png
Saved segmentation for 507_2HC.png at output_segmentations/seg_507_2HC.png
Saved segmentation for 507_HC.png at output_segmentations/seg_507_HC.png
Saved segmentation for 508_HC.png at output_segmentations/seg_508_HC.png
Saved segmentation for 509_HC.png at output_segmentations/seg_509_HC.png
Saved segmentation for 510_HC.png at output_segmentations/seg_510_HC.png
Saved segmentation for 511_HC.png at output_segmentations/seg_511_HC.png
Saved segmentation for 512_HC.png at output_segme

Saved segmentation for 592_HC.png at output_segmentations/seg_592_HC.png
Saved segmentation for 593_HC.png at output_segmentations/seg_593_HC.png
Saved segmentation for 594_HC.png at output_segmentations/seg_594_HC.png
Saved segmentation for 595_HC.png at output_segmentations/seg_595_HC.png
Saved segmentation for 596_HC.png at output_segmentations/seg_596_HC.png
Saved segmentation for 597_HC.png at output_segmentations/seg_597_HC.png
Saved segmentation for 598_HC.png at output_segmentations/seg_598_HC.png
Saved segmentation for 599_HC.png at output_segmentations/seg_599_HC.png
Saved segmentation for 600_HC.png at output_segmentations/seg_600_HC.png
Saved segmentation for 601_HC.png at output_segmentations/seg_601_HC.png
Saved segmentation for 602_HC.png at output_segmentations/seg_602_HC.png
Saved segmentation for 603_HC.png at output_segmentations/seg_603_HC.png
Saved segmentation for 604_HC.png at output_segmentations/seg_604_HC.png
Saved segmentation for 605_HC.png at output_segment

Saved segmentation for 686_HC.png at output_segmentations/seg_686_HC.png
Saved segmentation for 687_HC.png at output_segmentations/seg_687_HC.png
Saved segmentation for 688_HC.png at output_segmentations/seg_688_HC.png
Saved segmentation for 689_HC.png at output_segmentations/seg_689_HC.png
Saved segmentation for 690_2HC.png at output_segmentations/seg_690_2HC.png
Saved segmentation for 690_HC.png at output_segmentations/seg_690_HC.png
Saved segmentation for 691_HC.png at output_segmentations/seg_691_HC.png
Saved segmentation for 692_2HC.png at output_segmentations/seg_692_2HC.png
Saved segmentation for 692_HC.png at output_segmentations/seg_692_HC.png
Saved segmentation for 693_HC.png at output_segmentations/seg_693_HC.png
Saved segmentation for 694_HC.png at output_segmentations/seg_694_HC.png
Saved segmentation for 695_HC.png at output_segmentations/seg_695_HC.png
Saved segmentation for 696_HC.png at output_segmentations/seg_696_HC.png
Saved segmentation for 697_HC.png at output_seg

**Morphological Opening and Closing + Canny edge Detector**

In [17]:
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt

# Define input and output folders
input_folder = 'output_segmentations'
output_folder = 'output_edges'
os.makedirs(output_folder, exist_ok=True)

# Define a structuring element for the morphological operations
kernel = np.ones((5, 5), np.uint8)

# Loop over all segmented images in the input folder
for filename in os.listdir(input_folder):
    if filename.endswith('.png'):
        img_path = os.path.join(input_folder, filename)
        # Read the segmentation image in grayscale
        seg_img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        
        # Apply morphological opening (erosion followed by dilation) to remove small artifacts
        opened = cv2.morphologyEx(seg_img, cv2.MORPH_OPEN, kernel)
        # Then apply morphological closing (dilation followed by erosion) to fill small holes
        closed = cv2.morphologyEx(opened, cv2.MORPH_CLOSE, kernel)
        
        # Apply Canny edge detector to extract the contour
        # Adjust thresholds as necessary (here, 50 and 150 are example values)
        edges = cv2.Canny(closed, 50, 150)
        
        # Save the edge-detected image
        output_path = os.path.join(output_folder, filename)
        cv2.imwrite(output_path, edges)
        
        # Optionally, display the original segmentation, post-morphology, and edge image side by side
#         plt.figure(figsize=(12, 4))
#         plt.subplot(1, 3, 1)
#         plt.imshow(seg_img, cmap='gray')
#         plt.title('Original Segmentation')
#         plt.axis('off')
        
#         plt.subplot(1, 3, 2)
#         plt.imshow(closed, cmap='gray')
#         plt.title('After Morphological Ops')
#         plt.axis('off')
        
#         plt.subplot(1, 3, 3)
#         plt.imshow(edges, cmap='gray')
#         plt.title('Canny Edges')
#         plt.axis('off')
        
#         plt.show()
        print(f"Processed and saved edge image for: {filename}")

Processed and saved edge image for: seg_608_HC.png
Processed and saved edge image for: seg_646_HC.png
Processed and saved edge image for: seg_517_HC.png
Processed and saved edge image for: seg_631_HC.png
Processed and saved edge image for: seg_669_HC.png
Processed and saved edge image for: seg_676_HC.png
Processed and saved edge image for: seg_665_HC.png
Processed and saved edge image for: seg_648_2HC.png
Processed and saved edge image for: seg_639_HC.png
Processed and saved edge image for: seg_592_HC.png
Processed and saved edge image for: seg_570_3HC.png
Processed and saved edge image for: seg_528_HC.png
Processed and saved edge image for: seg_560_HC.png
Processed and saved edge image for: seg_593_HC.png
Processed and saved edge image for: seg_520_HC.png
Processed and saved edge image for: seg_584_HC.png
Processed and saved edge image for: seg_610_HC.png
Processed and saved edge image for: seg_567_2HC.png
Processed and saved edge image for: seg_707_HC.png
Processed and saved edge ima

Processed and saved edge image for: seg_596_HC.png
Processed and saved edge image for: seg_682_HC.png
Processed and saved edge image for: seg_638_HC.png
Processed and saved edge image for: seg_617_HC.png
Processed and saved edge image for: seg_664_HC.png
Processed and saved edge image for: seg_615_HC.png
Processed and saved edge image for: seg_527_HC.png
Processed and saved edge image for: seg_688_HC.png
Processed and saved edge image for: seg_617_2HC.png
Processed and saved edge image for: seg_705_HC.png
Processed and saved edge image for: seg_565_HC.png
Processed and saved edge image for: seg_706_HC.png
Processed and saved edge image for: seg_640_HC.png
Processed and saved edge image for: seg_637_HC.png
Processed and saved edge image for: seg_630_2HC.png
Processed and saved edge image for: seg_674_HC.png
Processed and saved edge image for: seg_643_HC.png
Processed and saved edge image for: seg_554_HC.png
Processed and saved edge image for: seg_609_2HC.png
Processed and saved edge ima

**Ellipse fiting**

In [18]:
    import csv
    import cv2
    import numpy as np
    import os
    import math

    # --- Step 1: Load the pixel size information from the CSV file ---
    # Assume your CSV file 'test_set_pixel_size.csv' has at least these columns:
    # filename,pixel_size_mm
    pixel_size_dict = {}
    with open('test_set_pixel_size.csv', mode='r') as f:
        reader = csv.DictReader(f)
        for row in reader:
            # Adjust the column names if they are different in your CSV file.
            filename_csv = row['filename']
            pixel_size_mm = float(row['pixel size(mm)'])
            pixel_size_dict[filename_csv] = pixel_size_mm

    # --- Step 2: Process the edge images and fit ellipses ---
    edges_folder = 'output_edges'
    csv_output = 'ellipse_results.csv'
    header = ["filename", "center_x_mm", "center_y_mm", "semi_axes_a_mm", "semi_axes_b_mm", "angle_rad"]
    rows = []

    for filename in sorted(os.listdir(edges_folder)):
        if filename.endswith('.png'):
            filepath = os.path.join(edges_folder, filename)
            # Read the edge image in grayscale
            edge_img = cv2.imread(filepath, cv2.IMREAD_GRAYSCALE)

            # Find contours in the edge image
            contours, _ = cv2.findContours(edge_img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            if len(contours) == 0:
                print(f"No contours found in {filename}. Skipping.")
                continue

            # Choose the largest contour (assumed to be the head contour)
            largest_contour = max(contours, key=cv2.contourArea)

            # cv2.fitEllipse requires at least 5 points
            if len(largest_contour) < 5:
                print(f"Not enough points for ellipse fitting in {filename}. Skipping.")
                continue

            # Fit ellipse to the largest contour
            ellipse = cv2.fitEllipse(largest_contour)
            # ellipse returns ((center_x, center_y), (full_axis_length_a, full_axis_length_b), angle_in_degrees)
            center, axes, angle = ellipse
            # Compute semi-axes (the axes given are the full lengths)
            semi_a = axes[0] / 2.0  # semi-major axis in pixels
            semi_b = axes[1] / 2.0  # semi-minor axis in pixels

            # --- Step 3: Look up the pixel conversion factor for this image ---
            # The filenames in the CSV are expected to be like "001_HC.png"
            # and our edge images are named "seg_001_HC.png". Remove the "seg_" prefix.
            base_filename = filename.replace("seg_", "", 1)
            if base_filename in pixel_size_dict:
                pixel_to_mm = pixel_size_dict[base_filename]
            else:
                print(f"Pixel size for {base_filename} not found in CSV. Skipping.")
                continue

            # Convert measurements from pixels to millimeters
            center_x_mm = center[0] * pixel_to_mm
            center_y_mm = center[1] * pixel_to_mm
            semi_a_mm = semi_a * pixel_to_mm
            semi_b_mm = semi_b * pixel_to_mm

            # Convert angle from degrees to radians
            angle_rad = math.radians(angle)

            # Append the result: filename, center_x_mm, center_y_mm, semi_axes_a_mm, semi_axes_b_mm, angle_rad
            rows.append([base_filename, center_x_mm, center_y_mm, semi_a_mm, semi_b_mm, angle_rad])

    # --- Step 4: Write results to a CSV file ---
    with open(csv_output, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(header)
        writer.writerows(rows)

    print(f"CSV file '{csv_output}' saved with {len(rows)} rows.")

CSV file 'ellipse_results.csv' saved with 250 rows.
