# Technical Challenge: Surgical Tool Segmentation

This notebook implements semantic image segmentation of surgical tool parts. Semantic segmentation is a computer vision task that involves labelling every pixel in an image with a corresponding class.

**Objective:** Generate pixel-level masks for RGB frames of surgical videos, segmenting both prominent surgical tools and thin, small objects such as surgical clips, suturing threads and needles.
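To make the task concrete: a segmentation mask is a 2D integer array holding one class label per pixel. The toy 4x6 mask below is invented for illustration and is not taken from the dataset. It also hints at why thin objects are hard, since they cover only a handful of pixels.

```python
import numpy as np

# Toy 4x6 mask (invented values): 0 = background, 3 = tool shaft, 5 = thread
toy_mask = np.array([
    [0, 0, 0, 0, 0, 0],
    [0, 3, 3, 3, 0, 5],
    [0, 3, 3, 3, 5, 0],
    [0, 0, 0, 5, 0, 0],
], dtype=np.uint8)

# Count how many pixels belong to each class label
labels, counts = np.unique(toy_mask, return_counts=True)
pixel_share = {int(label): int(count) for label, count in zip(labels, counts)}
print(pixel_share)  # {0: 15, 3: 6, 5: 3}
```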

## Install and Import Libraries

In [None]:
from google.colab import drive
import os
import glob
import pandas as pd
import random
import shutil
import time

import tensorflow as tf
from tensorflow.keras.utils import Sequence
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
from tensorflow.keras.metrics import MeanIoU

import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.patches as mpatches

# Image I/O, array handling and dataset splitting
import cv2
import numpy as np
from sklearn.model_selection import train_test_split

## Configurations

In [None]:
# --- Mount Google Drive to access the project files and dataset ---
drive.mount('/content/drive')

# --- Configuration ---
VAL_SPLIT = 0.2
BATCH_SIZE = 16
SEED = 42 # A random seed for reproducibility
# Per the README, all videos are 1080i (1920x1080). Frames are resized to 480x640 (rows x cols)
# for faster processing; note that this does not preserve the original 16:9 aspect ratio.
IMG_ROWS, IMG_COLS = 480, 640

data_root = '/content'
train_path = os.path.join(data_root, 'train_dataset')
test_path = os.path.join(data_root, 'test_dataset')

input_shape = (IMG_ROWS, IMG_COLS, 3)

class_names = [
    "Background",
    "Tool clasper",
    "Tool wrist",
    "Tool shaft",
    "Suturing needle",
    "Thread",
    "Suction tool",
    "Needle holder",
    "Clamps",
    "Catheter"
]

NUM_CLASSES = len(class_names)

In [None]:
# %%time
# # --- Copy the compressed pre-processed data from my google drive to the local disk ---
# train_zip_path = "/content/drive/MyDrive/Colab Notebooks/Surgical_Tool_Segmentation/data/train_dataset.zip"
# !cp "{train_zip_path}" "/content/"
# print("Train zip file copied.")

# test_zip_path = "/content/drive/MyDrive/Colab Notebooks/Surgical_Tool_Segmentation/data/test_dataset.zip"
# !cp "{test_zip_path}" "/content/"
# print("Test zip file copied.")

In [None]:
# %%time
# # --- Unzip the data on the local disk ---
# !unzip -q "/content/train_dataset.zip" -d "/content/train_dataset"
# !unzip -q "/content/test_dataset.zip" -d "/content/test_dataset"

# print("Data unzipped.")

## Load and Inspect Data

The dataset is already split into training data (13,043 frames from 40 videos, roughly 80%) and test data (3,252 frames from 10 videos, roughly 20%). Each video contributes between 101 and 706 extracted frames. The data is provided in the following directory structure, where the `rgb` folder contains the input frames and the `segmentation` folder contains the ground-truth segmentation masks.

```
data/
├── train_dataset
│   ├── video_XX
│   │   ├── segmentation
│   │   └── rgb
│   └── ...
└── test_dataset
    ├── video_XX
    │   ├── segmentation
    │   └── rgb
    └── ...
```
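A minimal sketch of how frames and masks could be paired under this layout. It assumes that a frame and its mask share the same file name inside `rgb` and `segmentation`, which the source does not confirm; `pair_frames` is a hypothetical helper, not part of the notebook.

```python
import os

def pair_frames(video_dir):
    """Return sorted (image_path, mask_path) pairs for one video directory.

    Assumption (hypothetical): a frame and its mask share the same file name.
    """
    rgb_dir = os.path.join(video_dir, "rgb")
    seg_dir = os.path.join(video_dir, "segmentation")
    if not (os.path.isdir(rgb_dir) and os.path.isdir(seg_dir)):
        return []
    # Keep only names present in both folders so every image has a mask
    common = sorted(set(os.listdir(rgb_dir)) & set(os.listdir(seg_dir)))
    return [(os.path.join(rgb_dir, n), os.path.join(seg_dir, n)) for n in common]
```

Pairing on the intersection of file names means a frame without a mask is skipped silently rather than misaligned with a neighbouring mask.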

In [None]:
# --- Get a list of video directories ---
all_paths = glob.glob(f'/content/drive/MyDrive/Colab Notebooks/Surgical_Tool_Segmentation/data/*_dataset/video_*')
all_video_dirs = [path for path in all_paths if os.path.isdir(path)]
all_video_dirs.sort() # Sort the list for consistent processing order

# --- Count Files and Collect Data ---
summary_data = []
print(f"Generating summary of available data...\n")

for video_dir in all_video_dirs:
    video_name = os.path.basename(video_dir)
    segmentation_path = os.path.join(video_dir, 'segmentation')
    images_path = os.path.join(video_dir, 'rgb')

    segmentation_count = len(os.listdir(segmentation_path)) if os.path.exists(segmentation_path) else 0
    images_count = len(os.listdir(images_path)) if os.path.exists(images_path) else 0

    summary_data.append({
        'Video': video_name,
        'Segmentation_Masks': segmentation_count,
        'RGB_Images': images_count,
    })

# --- Display the Summary Table ---
summary_df = pd.DataFrame(summary_data)
# Group rows by the base video name (e.g. 'video_01'); naming the key gives the output column a readable label
base_video = summary_df['Video'].str.extract(r'(video_\d+)')[0].rename('base_video')
merged_df = summary_df.groupby(base_video).sum(numeric_only=True).reset_index()

print(f"Min frames per video: {merged_df.Segmentation_Masks.min()}")
print(f"Max frames per video: {merged_df.Segmentation_Masks.max()}")
# NOTE: head(40)/tail(10) assume the 40 training videos sort before the 10 test videos
print(f"Total num of training frames: {merged_df.head(40)['Segmentation_Masks'].sum()}")
print(f"Total num of testing frames: {merged_df.tail(10)['Segmentation_Masks'].sum()}\n")
display(merged_df)

## Data Pipeline

We will implement a custom data generator (`SurgicalToolGenerator`) which will be responsible for:
1.  Locating corresponding image and mask pairs.
2.  Loading them in batches to avoid loading the entire dataset into memory.
3.  Applying resizing, normalisation, and data augmentation on the fly.

The segmentation masks are provided at the same resolution as the video frames, with the grayscale value of each pixel corresponding to one of the following semantic classes:

| Label            | Class Name      |
| -------------    | -------------   |
| 0                | Background      |
| 1                | Tool clasper    |
| 2                | Tool wrist      |
| 3                | Tool shaft      |
| 4                | Suturing needle |
| 5                | Thread          |
| 6                | Suction tool    |
| 7                | Needle holder   |
| 8                | Clamps          |
| 9                | Catheter        |
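Because the mask values are class labels rather than intensities, resizing a mask has to use nearest-neighbour sampling; any method that averages neighbouring pixels can invent labels that were never there. The NumPy sketch below illustrates this on a toy mask, using strided indexing as a stand-in for nearest-neighbour downsampling and 2x2 block averaging as a stand-in for an averaging interpolator. Neither is the exact OpenCV routine used in the generator.

```python
import numpy as np

# Toy 4x4 mask (invented): labels 1 (tool clasper), 5 (thread), 9 (catheter)
mask = np.array([
    [0, 0, 9, 9],
    [0, 0, 9, 9],
    [1, 5, 0, 0],
    [1, 5, 0, 0],
], dtype=np.uint8)

# Nearest-neighbour style downsampling: keep every second pixel
nearest = mask[::2, ::2]

# Averaging downsampling: mean over each 2x2 block
averaged = mask.reshape(2, 2, 2, 2).mean(axis=(1, 3))

print(nearest)   # contains only labels already present in the mask
print(averaged)  # bottom-left block averages labels 1 and 5 to 3.0
```

In this toy case the averaged value 3 happens to be the label for "Tool shaft", so the boundary between clasper and thread would be turned into a class that does not occur in the mask at all.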

In [None]:
def apply_augmentation(image, mask):
    """Applies a set of augmentations to an image and its corresponding mask."""
    # Paired Augmentations (applied to both image and mask)
    # Horizontal Flip
    if random.random() > 0.5:
        image = cv2.flip(image, 1)
        mask = cv2.flip(mask, 1)

    # Image-Only Augmentations
    # Brightness (contrast is left unchanged)
    if random.random() > 0.5:
        # Adjust brightness by a random factor
        brightness_factor = random.uniform(0.8, 1.2)
        image = np.clip(image * brightness_factor, 0, 255).astype(np.uint8)

    return image, mask

class SurgicalToolGenerator(Sequence):
    """Custom data generator for surgical tool segmentation."""
    def __init__(self, image_paths, mask_paths, batch_size, image_size, augment=False):
        super().__init__()
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.batch_size = batch_size
        self.image_size = image_size
        self.augment = augment

    def __len__(self):
        """Number of full batches per epoch (a trailing partial batch is dropped)."""
        return len(self.image_paths) // self.batch_size

    def __getitem__(self, index):
        """Generate one batch of data."""
        batch_image_paths = self.image_paths[index*self.batch_size:(index+1)*self.batch_size]
        batch_mask_paths = self.mask_paths[index*self.batch_size:(index+1)*self.batch_size]

        batch_images = np.zeros((self.batch_size, self.image_size[0], self.image_size[1], 3), dtype=np.float32)
        batch_masks = np.zeros((self.batch_size, self.image_size[0], self.image_size[1]), dtype=np.uint8)

        for i, (img_path, mask_path) in enumerate(zip(batch_image_paths, batch_mask_paths)):
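            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread returns None instead of raising, so fail early with a clear message
                raise FileNotFoundError(f"Could not read image {img_path}")
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)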
            mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)

            if mask is None:
                print(f"WARNING: Could not read mask {mask_path}. Filling with zeros.")
                mask = np.zeros((10, 10), dtype=np.uint8)

            if self.augment:
                img, mask = apply_augmentation(img, mask)

            img_resized = cv2.resize(img, (self.image_size[1], self.image_size[0]))
            mask_resized = cv2.resize(mask, (self.image_size[1], self.image_size[0]), interpolation=cv2.INTER_NEAREST)

            batch_images[i] = img_resized / 255.0
            batch_masks[i] = mask_resized

        return batch_images, batch_masks

    def on_epoch_end(self):
        """Shuffle image/mask pairs after each epoch (training generator only, i.e. augment=True)."""
        if self.augment:
            combined = list(zip(self.image_paths, self.mask_paths))
            random.shuffle(combined)
            unzipped = list(zip(*combined))
            self.image_paths = list(unzipped[0])
            self.mask_paths = list(unzipped[1])

In [None]:
# --- Prepare file paths for generators ---
def get_paths(dataset_path):
    # Sorting both lists keeps images and masks aligned, assuming matching file names
    image_paths = sorted(glob.glob(os.path.join(dataset_path, 'all_rgb', '*.png')))
    mask_paths = sorted(glob.glob(os.path.join(dataset_path, 'all_segmentation', '*.png')))
    return image_paths, mask_paths

train_val_images, train_val_masks = get_paths(train_path)
test_images, test_masks = get_paths(test_path)

# Split the training data into training and validation sets
train_images, val_images, train_masks, val_masks = train_test_split(
    train_val_images, train_val_masks, test_size=VAL_SPLIT, random_state=SEED
)

# --- Instantiate the generators ---
train_generator = SurgicalToolGenerator(
    image_paths=train_images,
    mask_paths=train_masks,
    batch_size=BATCH_SIZE,
    image_size=(IMG_ROWS, IMG_COLS),
    augment=True
)

validation_generator = SurgicalToolGenerator(
    image_paths=val_images,
    mask_paths=val_masks,
    batch_size=BATCH_SIZE,
    image_size=(IMG_ROWS, IMG_COLS),
    augment=False
)

test_generator = SurgicalToolGenerator(
    image_paths=test_images,
    mask_paths=test_masks,
    batch_size=BATCH_SIZE,
    image_size=(IMG_ROWS, IMG_COLS),
    augment=False
)

print(f"Found {len(train_images)} images for training.")
print(f"Found {len(val_images)} images for validation.")
print(f"Found {len(test_images)} images for testing.")

### Visualise Sample Images and Segmentation Masks

In [None]:
def visualise_from_generator(generator, num_examples, class_names):
    """Visualises images and masks from a data generator."""
    # Get a colormap for the segmentation mask
    colours = plt.get_cmap('tab10', len(class_names))
    custom_cmap = mcolors.ListedColormap(colours.colors)

    # Create the plot
    # squeeze=False keeps `ax` two-dimensional even when num_examples == 1
    fig, ax = plt.subplots(num_examples, 2, figsize=(10, 5 * num_examples), squeeze=False)
    fig.suptitle('Visualising Example Images and Segmentation Masks', fontsize=24, fontweight='bold')

    # Get a single batch from the generator
    sample_image_batch, sample_mask_batch = generator[0] # Get the first batch

    for i in range(num_examples):
        # Pick a sample from the batch
        image = sample_image_batch[i]
        mask = sample_mask_batch[i]

        ax[i, 0].imshow(image)
        ax[i, 0].set_title(f'Image #{i+1}', fontsize=16)
        ax[i, 1].imshow(mask, cmap=custom_cmap, vmin=0, vmax=len(class_names)-1)
        ax[i, 1].set_title(f'Segmentation Mask #{i+1}', fontsize=16)

    # Add a legend to the figure
    legend_patches = [mpatches.Patch(color=colours.colors[i], label=class_names[i]) for i in range(len(class_names))]
    fig.legend(handles=legend_patches, bbox_to_anchor=(1.05, 0.7), loc='upper left', fontsize=12)

    # Clean up the plot
    for axis in ax.flat:
        axis.axis('off')

    plt.tight_layout(rect=[0, 0, 0.85, 0.96]) # Adjust layout to make space for the legend
    plt.show()

In [None]:
# Visualise a few examples from the training generator to check augmentations
visualise_from_generator(train_generator, num_examples=3, class_names=class_names)

## Build the Model

We will use a **U-Net** architecture, a widely used and highly effective choice for biomedical image segmentation. Its key features are:

1.  An **Encoder** (contracting path) that captures context by downsampling the image and extracting features at different scales.
2.  A **Decoder** (expansive path) that uses transposed convolutions to upsample the feature maps, enabling precise localisation.
3.  **Skip Connections** that merge feature maps from the encoder path with the corresponding decoder path. This is crucial as it allows the network to combine high-level contextual features with fine-grained spatial information, resulting in accurate segmentation masks.
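The three components above put a constraint on the input size. The sketch below traces the spatial size of the feature maps through four 2x2 pooling stages, which matches the depth of the encoder built in this notebook, and checks that the input divides evenly so that the upsampled decoder features line up with the skip connections. `unet_feature_sizes` is a hypothetical helper written for illustration.

```python
def unet_feature_sizes(height, width, depth=4):
    """Spatial size of the feature maps at each level of a U-Net encoder.

    Assumes 'same'-padded convolutions and 2x2 pooling with stride 2.
    """
    factor = 2 ** depth
    if height % factor or width % factor:
        raise ValueError(f"height and width must be divisible by {factor}")
    return [(height // 2 ** level, width // 2 ** level) for level in range(depth + 1)]

sizes = unet_feature_sizes(480, 640)
print(sizes)  # [(480, 640), (240, 320), (120, 160), (60, 80), (30, 40)]
```

With the 480x640 input used here, the bottleneck therefore operates on 30x40 feature maps.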

### Encoder

The encoder consists of repeated blocks of convolutions followed by max-pooling to downsample the image.
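As a small illustration of the downsampling step, 2x2 max pooling keeps the largest activation in each non-overlapping 2x2 window and halves both spatial dimensions. The NumPy version below is a sketch for a single-channel map with even height and width; it is not the Keras layer itself.

```python
import numpy as np

def max_pool_2x2(feature_map):
    """2x2 max pooling with stride 2 on a (H, W) array with even H and W."""
    h, w = feature_map.shape
    # Split rows and columns into 2-pixel blocks, then take the max of each block
    return feature_map.reshape(h // 2, 2, w // 2, 2).max(axis=(1, 3))

x = np.array([
    [1, 2, 0, 1],
    [3, 4, 1, 0],
    [0, 0, 5, 6],
    [0, 9, 7, 8],
])
pooled = max_pool_2x2(x)
print(pooled)  # rows: [4, 1] and [9, 8]
```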

In [None]:
# Encoder Utilities

def conv2d_block(input_tensor, n_filters, kernel_size=3):
  '''
  Adds 2 convolutional layers with Batch Normalisation and ReLU activation.
  This is the fundamental building block of the U-Net.

  Args:
    input_tensor (tensor): The input tensor.
    n_filters (int): Number of filters for the convolutional layers.
    kernel_size (int): Kernel size for the convolution.

  Returns:
    x (tensor): Tensor of output features.
  '''
  x = input_tensor
  for _ in range(2):
    x = tf.keras.layers.Conv2D(filters=n_filters,
                              kernel_size=(kernel_size, kernel_size),
                              kernel_initializer='he_normal',
                              padding='same')(x)
    # Batch normalisation helps stabilise and accelerate training.
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation('relu')(x)

  return x


def encoder_block(inputs, n_filters=64, pool_size=(2,2), dropout=0.3):
  '''
  Defines one downsampling block of the encoder.
  It consists of a convolutional block followed by max-pooling and dropout.

  Returns:
    f: The output features of the convolution block (for the skip connection).
    p: The max-pooled features to be passed to the next block.
  '''
  f = conv2d_block(inputs, n_filters=n_filters)
  p = tf.keras.layers.MaxPooling2D(pool_size=pool_size)(f)
  p = tf.keras.layers.Dropout(dropout)(p)

  return f, p


def encoder(inputs):
  '''
  Defines the complete encoder (downsampling path) of the U-Net.

  Returns:
    p4: The output features from the final encoder block (to be passed to the bottleneck).
    (f1, f2, f3, f4): A tuple of feature maps from each encoder block for the skip connections.
  '''
  f1, p1 = encoder_block(inputs, n_filters=64, dropout=0.3)
  f2, p2 = encoder_block(p1, n_filters=128, dropout=0.3)
  f3, p3 = encoder_block(p2, n_filters=256, dropout=0.3)
  f4, p4 = encoder_block(p3, n_filters=512, dropout=0.3)

  return p4, (f1, f2, f3, f4)

### Bottleneck


A bottleneck block sits at the bottom of the "U" shape, between the encoder and decoder. It applies further convolutions to the feature map with the highest level of abstraction.

In [None]:
def bottleneck(inputs):
  '''
  This function defines the bottleneck convolutions that link the encoder and decoder.
  '''
  bottle_neck = conv2d_block(inputs, n_filters=1024)
  return bottle_neck

### Decoder

The decoder upsamples the feature maps back to the original image size. At each level, it merges the upsampled features with the corresponding high-resolution features from the encoder via a skip connection. This is the key mechanism that allows U-Net to produce highly detailed segmentation masks.
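The merge step can be sketched with plain arrays. In the snippet below, pixel repetition stands in for the learned transposed convolution (which in the actual model also changes the number of channels), and the channel counts are arbitrary placeholders chosen to keep the example small.

```python
import numpy as np

# Toy tensors in (height, width, channels) layout; values are random placeholders
rng = np.random.default_rng(0)
decoder_features = rng.random((30, 40, 8))  # low-resolution features from the level below
encoder_features = rng.random((60, 80, 4))  # matching high-resolution encoder output

# Stand-in for the transposed convolution: repeat each pixel in a 2x2 block
upsampled = decoder_features.repeat(2, axis=0).repeat(2, axis=1)

# Skip connection: stack both feature maps along the channel axis
merged = np.concatenate([upsampled, encoder_features], axis=-1)
print(merged.shape)  # (60, 80, 12)
```

The concatenation only works because both inputs have the same height and width, which is why the upsampling factor has to mirror the pooling factor of the encoder.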

In [None]:
# Decoder Utilities

def decoder_block(inputs, conv_output, n_filters=64, kernel_size=3, strides=2, dropout=0.3):
  '''
  Defines one upsampling block of the decoder.
  It uses a transposed convolution to upsample, concatenates with the skip connection,
  and then applies a standard convolutional block.
  '''
  u = tf.keras.layers.Conv2DTranspose(n_filters, kernel_size, strides=strides, padding='same')(inputs)
  c = tf.keras.layers.concatenate([u, conv_output])
  c = tf.keras.layers.Dropout(dropout)(c)
  c = conv2d_block(c, n_filters, kernel_size=3)
  return c


def decoder(inputs, convs, output_channels):
  '''
  Defines the complete decoder (upsampling path) of the U-Net.

  Args:
    inputs (tensor): Input features from the bottleneck.
    convs (tuple): Feature maps from the encoder for skip connections.
    output_channels (int): Number of classes for the final segmentation map.
  '''
  f1, f2, f3, f4 = convs

  # Upsampling block 1: from 1024 to 512 filters
  c6 = decoder_block(inputs, f4, n_filters=512)
  # Upsampling block 2: from 512 to 256 filters
  c7 = decoder_block(c6, f3, n_filters=256)
  # Upsampling block 3: from 256 to 128 filters
  c8 = decoder_block(c7, f2, n_filters=128)
  # Upsampling block 4: from 128 to 64 filters
  c9 = decoder_block(c8, f1, n_filters=64)

  # Final output layer
  outputs = tf.keras.layers.Conv2D(output_channels, (1, 1), activation='softmax')(c9)

  return outputs

### Putting It All Together

In [None]:
def unet(input_shape, output_channels):
  '''
  Defines the complete U-Net model by connecting the encoder, bottleneck, and decoder.
  '''
  inputs = tf.keras.layers.Input(shape=input_shape)
  encoder_output, convs = encoder(inputs)
  bottle_neck = bottleneck(encoder_output)
  outputs = decoder(bottle_neck, convs, output_channels=output_channels)
  model = tf.keras.Model(inputs=inputs, outputs=outputs)
  return model

# Instantiate the model
model = unet(input_shape, NUM_CLASSES)

# See the resulting model architecture
model.summary()

## Compile and Train the Model

Now, we compile and train the model. Key choices here are:

* **Loss Function:** We use `sparse_categorical_crossentropy` since our ground-truth masks are integer-encoded (0, 1, 2...) for each pixel, while our model outputs a probability distribution over the classes for each pixel (thanks to the softmax activation).

* **Metric:** We use `MeanIoU` (Mean Intersection over Union), which provides a meaningful measure of how well the predicted segmentation masks overlap with the ground-truth masks.

* **Callbacks:** We use several callbacks to improve the training process:
    * `ModelCheckpoint`: Saves the best version of the model based on validation loss.
    * `EarlyStopping`: Halts training if the validation loss does not improve for a set number of epochs, which limits overfitting.
    * `ReduceLROnPlateau`: Decreases the learning rate when the validation loss stops improving, which can help the optimiser settle into a better minimum.
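To show what the loss and the metric measure, here is a simplified NumPy sketch on a toy 2x2 prediction with three classes. It is not the Keras implementation: the Keras metric accumulates a confusion matrix across batches, whereas `mean_iou` below works on a single array. Both function names and the toy values are illustrative.

```python
import numpy as np

def sparse_categorical_crossentropy(y_true, y_prob, eps=1e-7):
    """Mean negative log-probability assigned to the true class of each pixel."""
    y_prob = np.clip(y_prob, eps, 1.0)
    picked = np.take_along_axis(y_prob, y_true[..., None], axis=-1)[..., 0]
    return float(-np.log(picked).mean())

def mean_iou(y_true, y_pred, num_classes):
    """Mean IoU over the classes that occur in y_true or y_pred."""
    ious = []
    for c in range(num_classes):
        t, p = (y_true == c), (y_pred == c)
        union = np.logical_or(t, p).sum()
        if union:  # skip classes absent from both arrays
            ious.append(np.logical_and(t, p).sum() / union)
    return float(np.mean(ious))

# Integer-encoded ground truth and per-pixel class probabilities (toy values)
y_true = np.array([[0, 0], [1, 2]])
y_prob = np.array([
    [[0.8, 0.1, 0.1], [0.6, 0.3, 0.1]],
    [[0.2, 0.7, 0.1], [0.5, 0.2, 0.3]],
])
y_pred = y_prob.argmax(axis=-1)  # [[0, 0], [1, 0]]

loss = sparse_categorical_crossentropy(y_true, y_prob)
miou = mean_iou(y_true, y_pred, num_classes=3)
print(round(loss, 4), round(miou, 4))
```

Three of the four pixels are classified correctly, yet the mean IoU is only 5/9 because class 2 is missed entirely and scores an IoU of 0. This is why mean IoU is more informative than pixel accuracy when small classes matter.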

In [None]:
# MeanIoU is a stateful metric, so it is passed to compile() directly. Wrapping it in a plain
# function would keep accumulating its state across epochs and across training and validation.
# sparse_y_true=True: masks hold integer labels; sparse_y_pred=False: predictions are per-class probabilities.
iou_metric = MeanIoU(num_classes=NUM_CLASSES, sparse_y_true=True, sparse_y_pred=False, name='mean_iou')

model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss='sparse_categorical_crossentropy',
    metrics=[iou_metric]
)

In [None]:
checkpoint_path = 'surgical_tool_segmentation_best.h5'

# Callback to save the best model
model_checkpoint = ModelCheckpoint(checkpoint_path,
                                 monitor='val_loss',
                                 save_best_only=True,
                                 mode='min',
                                 verbose=1)

# Callback to stop training early if validation loss stops improving
early_stopper = EarlyStopping(patience=5,
                              monitor='val_loss',
                              mode='min',
                              verbose=1)

# Callback to reduce learning rate on plateau
reduce_lr = ReduceLROnPlateau(monitor='val_loss',
                              factor=0.2,
                              patience=3,
                              min_lr=0.00001,
                              verbose=1)

N_EPOCHS = 50

model_history = model.fit(train_generator,
                          epochs=N_EPOCHS,
                          validation_data=validation_generator,
                          callbacks=[model_checkpoint, early_stopper, reduce_lr]
                          )