In [None]:
# Mount Google Drive at /content/drive (Colab only: relies on the google.colab package)
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Check nvcc version
!nvcc -V
# Check GCC version
!gcc --version

# Install PyTorch (torch 1.11.0, torchvision 0.12.0, torchaudio 0.11.0, all "+cu113" builds)
!pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html
# Install MMCV (2.0.0rc4, taken from the wheel index for cu113 / torch1.11.0)
!pip install mmcv==2.0.0rc4 -f https://download.openmmlab.com/mmcv/dist/cu113/torch1.11.0/index.html
# Install MMSegmentation
# NOTE(review): not version-pinned, unlike torch and mmcv above; consider pinning for reproducibility
!pip install mmsegmentation
# Install MMEngine
# NOTE(review): not version-pinned either
!pip install mmengine

# Install ftfy
# NOTE(review): not version-pinned either
!pip install ftfy

In [None]:
# Check PyTorch installation: print the installed version and whether CUDA is available
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())

# Check MMSegmentation installation: print the installed version
import mmseg
print(mmseg.__version__)

In [None]:
# Required imports definition

# Standard Imports
import os
import numpy as np
import torch

from mmseg.apis import init_model, inference_model
import matplotlib.pyplot as plt

from PIL import Image
from tqdm import tqdm

from mmengine import Config

from mmseg.registry import DATASETS
from mmseg.datasets import BaseSegDataset

# PreProcessing Class Imports
from skimage import util
import skimage.exposure as exposure
from scipy import ndimage

from mmcv.transforms import BaseTransform, TRANSFORMS

# Interpolation Imports
from scipy.ndimage import distance_transform_edt

# Inference Imports
import torch.nn.functional as F

In [None]:
# Include the name of your group.
# It is used further down to name the results folder (results_<group_name>).
group_name = 'MS1'

In [None]:
#################### INSTRUCTIONS FOR FOLDER MANAGEMENT ########################
#
# Put both the config file and the checkpoint in the same folder as your notebook
# Name the config file config.py
# Name the checkpoint checkpoint.pth
#
# N.B.: Leave this cell as it is; we will fill in the required paths
#

# Placeholder path: folder that contains config.py and checkpoint.pth
submission_folder = '...'
print(f'submission folder: {submission_folder}')

# Checkpoint and config are looked up under the fixed names given in the instructions above
checkpoint_path = os.path.join(submission_folder, 'checkpoint.pth')
cfg_path = os.path.join(submission_folder, 'config.py')

# Placeholder path: folder with the test images. The inference loop expects one
# sub-folder per patient, each containing "adc", "hbv" and "t2w" folders of PNG slices
test_img_folder = '...'

# Predicted masks are written under <submission_folder>/results_<group_name>
results_folder = os.path.join(submission_folder, f'results_{group_name}')

# Create the results folder; exist_ok=True makes re-running this cell harmless
os.makedirs(results_folder, exist_ok=True)

In [None]:
# Class names of the segmentation task, in label order:
# index 0 is the background, index 1 is the tumor
classes = ('background',
           'tumor')

# One RGB colour per class, in the same order as `classes`
paletteProstate = [
    (0, 0, 0),       # background - black
    (255, 255, 255), # tumor - white
]

# force=True overwrites a previous registration under the same name, so this cell
# can be re-run in the same kernel without the registry raising because
# 'ProstateMRI' is already registered
@DATASETS.register_module(force=True)
class ProstateMRI(BaseSegDataset):
    """Two-class (background / tumor) segmentation dataset stored as PNG files.

    Both the images and the segmentation maps use the '.png' suffix, and label 0
    is kept as a regular class (``reduce_zero_label=False``). Every other keyword
    argument is forwarded unchanged to ``BaseSegDataset``.
    """

    # Class names and palette exposed to MMSegmentation
    METAINFO = dict(classes = classes, palette = paletteProstate)

    def __init__(self, **kwargs):
        super().__init__(img_suffix='.png',
                        seg_map_suffix='.png',
                        reduce_zero_label = False,
                        **kwargs)

In [None]:
# Load the config file and print it to check that it is the expected one
cfg = Config.fromfile(cfg_path)
print(f'Config:\n{cfg.pretty_text}')

# Use the first GPU when CUDA is available; otherwise fall back to the CPU so that
# the model can still be built on a runtime without a GPU (inference is just slower)
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Init the model from the config and the checkpoint
model = init_model(cfg, checkpoint_path, device)

In [None]:
# Definition of pre-processing functions

# Function to find black borders in T2W channel
def findBorders(t2w):
    """Locate the non-black region of the T2W channel.

    A pixel counts as non-black when its value is greater than 0.

    Returns:
        (rmin, rmax, cmin, cmax): index of the first and last row, and of the
        first and last column, that contain at least one non-black pixel.
        When no row (or column) is non-black, the pair falls back to
        (0, number of rows) (respectively (0, number of columns)).
    """
    nonblack = t2w > 0

    # Indices of the rows / columns holding at least one non-black pixel
    row_idx = np.nonzero(np.any(nonblack, axis=1))[0]
    col_idx = np.nonzero(np.any(nonblack, axis=0))[0]

    if row_idx.size > 0:
        rmin, rmax = row_idx[0], row_idx[-1]
    else:
        rmin, rmax = 0, t2w.shape[0]

    if col_idx.size > 0:
        cmin, cmax = col_idx[0], col_idx[-1]
    else:
        cmin, cmax = 0, t2w.shape[1]

    return rmin, rmax, cmin, cmax

# Function to crop and resize images according to black borders in T2W channel
def cropResizeImages(adc, hbv, t2w):
    """Crop the three channels around the non-black T2W region and resize them to 256x256.

    The crop window is a square centred on the non-black region found by
    ``findBorders(t2w)``; its side is the smaller extent of that region, made even.
    When that side differs from the image height by less than 5 pixels the crop
    is skipped and the full images are resized instead.

    Args:
        adc, hbv, t2w: the three channels of one slice. They are handed to PIL
            as mode 'L' images, so they are assumed to be 2-D 8-bit arrays of
            identical shape (TODO confirm against the caller).

    Returns:
        (adc, hbv, t2w, cropMetadata): the three uint8 channels after
        cropping/resizing, plus a dictionary describing what was done.
        'flagCrop' tells whether a crop took place; the crop coordinates and
        'cropped_shape' are present only when it did.
    """
    resizeShape = (256, 256)

    def _resize(channel):
        # Lanczos resampling of a single channel to the target size
        resized = Image.fromarray(channel, mode='L').resize(resizeShape, Image.LANCZOS)
        return np.array(resized, dtype=np.uint8)

    original_shape = t2w.shape
    row_first, row_last, col_first, col_last = findBorders(t2w)

    # Side of the square crop: smaller extent of the non-black region, forced to be even
    cropDimension = min(row_last - row_first, col_last - col_first)
    if cropDimension % 2:
        cropDimension -= 1

    # The black border is too thin to be worth cropping: resize the whole images
    if abs(t2w.shape[0] - cropDimension) < 5:
        adc_resized = _resize(adc)
        hbv_resized = _resize(hbv)
        t2w_resized = _resize(t2w)
        cropMetadata = {
            'flagCrop': False,
            'original_shape': original_shape,
            'resizeShape': resizeShape,
            'cropDimension': t2w.shape[0],
        }
        return adc_resized, hbv_resized, t2w_resized, cropMetadata

    # Centre of the non-black region
    center_x = (col_first + col_last) // 2
    center_y = (row_first + row_last) // 2

    # Crop window, clamped to the image bounds
    half_side = cropDimension // 2
    start_x = max(center_x - half_side, 0)
    start_y = max(center_y - half_side, 0)
    end_x = min(start_x + cropDimension, t2w.shape[1])
    end_y = min(start_y + cropDimension, t2w.shape[0])

    # The same window is applied to all three channels
    adc_cropped = adc[start_y:end_y, start_x:end_x]
    hbv_cropped = hbv[start_y:end_y, start_x:end_x]
    t2w_cropped = t2w[start_y:end_y, start_x:end_x]

    cropMetadata = {
        'start_x': start_x,
        'start_y': start_y,
        'end_x': end_x,
        'end_y': end_y,
        'cropDimension': cropDimension,
        'cropped_shape': t2w_cropped.shape,
        'center_x': center_x,
        'center_y': center_y,
        'original_shape': original_shape,
        'resizeShape': resizeShape,
        'flagCrop': True,
    }

    return _resize(adc_cropped), _resize(hbv_cropped), _resize(t2w_cropped), cropMetadata

# Function to restore the mask to the original size
def restoreSizeMask(mask, cropMetadata):
    """Bring a predicted mask back to the size of the original slice.

    Args:
        mask: 2-D uint8 mask at the network resolution.
        cropMetadata: dictionary returned by ``cropResizeImages`` for the same
            slice. 'flagCrop' and 'original_shape' are always read; the crop
            coordinates and 'cropped_shape' are read only when 'flagCrop' is True.

    Returns:
        A uint8 array of shape ``cropMetadata['original_shape']``. When the slice
        was cropped, the mask is resized to the crop window and pasted at the
        crop position; everything outside the window is 0.
    """
    def _resize_nearest(arr, shape):
        # NumPy shapes are (height, width) whereas PIL's resize() expects
        # (width, height): swap them, otherwise non-square masks come back
        # transposed (or cannot be pasted into the crop window at all)
        height, width = shape[0], shape[1]
        resized = Image.fromarray(arr, mode='L').resize((width, height), Image.NEAREST)
        return np.array(resized, dtype=np.uint8)

    # The image was not cropped: a plain resize is enough
    if not cropMetadata['flagCrop']:
        return _resize_nearest(mask, cropMetadata['original_shape'])

    # The image was cropped: resize the mask to the crop window, then paste it
    # into an all-background mask of the original size
    mask_restored = np.zeros(cropMetadata['original_shape'], dtype=np.uint8)
    mask_cropDim = _resize_nearest(mask, cropMetadata['cropped_shape'])
    mask_restored[cropMetadata['start_y']:cropMetadata['end_y'],
                  cropMetadata['start_x']:cropMetadata['end_x']] = mask_cropDim

    return mask_restored

# Function to apply pre-processing to the images
def minMaxScaling(img):
    """Stretch the intensities of ``img`` to the full 0-255 range.

    The minimum maps to 0 and the maximum to 255; the result is converted to
    uint8 (fractional parts are dropped by the cast). An image whose minimum
    equals its maximum is returned unchanged.
    """
    lowest = np.min(img)
    highest = np.max(img)

    # Constant image: nothing to stretch (and the division below would be by zero)
    if highest == lowest:
        return img

    normalized = (img - lowest) / (highest - lowest)
    return np.array(normalized * 255, dtype=np.uint8)


In [None]:
# Definition of post-processing function

# Function to apply interpolation between two masks
def interp_mask(mask1, mask2):
    """Interpolate a mask halfway between two tumoral masks.

    Each input is converted to a boolean array and turned into a signed distance
    map (Euclidean distance transform of the mask minus that of its complement:
    positive inside the mask, negative outside). A pixel belongs to the result
    when the sum of the two signed distances is strictly positive.

    Returns:
        A uint8 array of the same shape as the inputs, with 1 inside the
        interpolated mask and 0 elsewhere.
    """
    def _signed_distance(mask):
        inside = np.array(mask, dtype=bool)
        return distance_transform_edt(inside) - distance_transform_edt(~inside)

    combined = _signed_distance(mask1) + _signed_distance(mask2)
    return np.array(combined > 0, dtype=np.uint8)

In [None]:
# Settings shared by every slice, defined once instead of at every iteration
logitsScaleFactor = 0.5  # factor applied to the logits before the softmax
threshold = 0.6          # a pixel is labelled as tumor when its tumor probability exceeds this value
maxBlackSlices = 3       # maximum number of black slices between two tumoral slices to interpolate: with more than 3 black slices it could be a bad segmentation, so we do not interpolate

# Get the list of patients (sorted so that the processing order is reproducible)
patientList = sorted(os.listdir(test_img_folder))

# Loop over the test patients
for patient in tqdm(patientList):

  # Skip stray entries of the test folder that are not patient folders
  adcFolder = os.path.join(test_img_folder, patient, "adc")
  if not os.path.isdir(adcFolder):
    continue

  # Create patient folder in the results folder
  automaticFolder = os.path.join(results_folder, patient, "automatic")
  os.makedirs(automaticFolder, exist_ok=True)

  # Get the list of patient's slices. Only PNG files are kept, so that the number
  # of slices used by the post-processing is not inflated by other files
  sliceList = [name for name in os.listdir(adcFolder) if name.endswith('.png')]

  # Initialization of elements needed for post-processing
  tumorVolume = 0
  numTumoralSlices = 0
  hasTumor = False
  listTumoralSlices = []

  # Loop over the slices (the loop variable is not called `slice`, which would
  # shadow the Python builtin for the rest of the notebook)
  for sliceName in sliceList:

    ###############################
    #######      Load      ########
    ###############################

    adc = np.array(Image.open(os.path.join(test_img_folder, patient, "adc", sliceName)))
    hbv = np.array(Image.open(os.path.join(test_img_folder, patient, "hbv", sliceName)))
    t2w = np.array(Image.open(os.path.join(test_img_folder, patient, "t2w", sliceName)))

    ###############################
    ####### Pre-Processing ########
    ###############################

    # Crop
    adc, hbv, t2w, cropMetadata = cropResizeImages(adc, hbv, t2w)

    # Min-Max Scaling
    adc = minMaxScaling(adc)
    hbv = minMaxScaling(hbv)
    t2w = minMaxScaling(t2w)

    # Create multimodal RGB image
    stacked = np.stack((adc, hbv, t2w), axis=-1)

    ###############################
    #######    Inference   ########
    ###############################

    result = inference_model(model, stacked)

    # Logits extraction: channel 0 is the background, channel 1 is the tumor
    seg_logits = logitsScaleFactor * result.seg_logits.data[:2]

    # Softmax Function Application (probability extraction)
    seg_probs = F.softmax(seg_logits, dim=0).cpu().numpy()

    # Manual Probability Thresholding
    pred_label = (seg_probs[1] > threshold).astype(np.uint8)

    ###############################
    ### Computation of elements ###
    ###       needed for        ###
    ###     post-processing     ###
    ###############################

    if np.sum(pred_label) > 0:
      hasTumor = True
      numTumoralSlices += 1
      tumorVolume += np.sum(pred_label.astype(bool))
      listTumoralSlices.append(sliceName)

    ###############################
    #######   Save results  #######
    ###############################

    # Restore the original size
    pred_label = restoreSizeMask(pred_label, cropMetadata)

    # Save the result
    Image.fromarray(pred_label, mode='L').save(os.path.join(automaticFolder, sliceName))


  ###############################
  ####### Post-Processing #######
  ###############################

  # Check if the patient has a tumor: if so, apply post-processing techniques
  if hasTumor:
    # Extraction of the first slice of the patient to get the shape of the slices
    exampleSliceAdc = np.array(Image.open(os.path.join(test_img_folder, patient, "adc", sliceList[0])))
    shapeSlices = exampleSliceAdc.shape

    # FIRST, SECOND, SECOND-LAST AND LAST SLICE CHECK: if any of them is tumoral, remove it
    # from the list of tumoral slices and save a black mask instead.
    # Slices are assumed to be named slice_0.png ... slice_<N-1>.png
    slicesToCheck = [f'slice_{i}.png' for i in [0, 1, len(sliceList) - 2, len(sliceList) - 1]]
    for actCheckingSlice in slicesToCheck:
      if actCheckingSlice in listTumoralSlices:
        # Create a black mask
        blackMask = Image.fromarray(np.zeros((shapeSlices[0], shapeSlices[1]), dtype=np.uint8)).convert('L')
        # Save the black mask
        blackMask.save(os.path.join(automaticFolder, actCheckingSlice))
        # Remove the slice from the list of tumoral slices
        listTumoralSlices.remove(actCheckingSlice)

    # INTERPOLATION: if there are black slices between two tumoral slices, replace all of them
    # with the interpolation of the tumoral slices before and after.
    # Slices are called slice_n.png: extract "n" FIRST and sort afterwards, because sorting the
    # file names as strings would put slice_10.png before slice_2.png
    listTumoralSlices = sorted(int(name.split('.')[0].split('_')[1]) for name in listTumoralSlices)

    # Loop over the pairs of consecutive tumoral slices
    for topIdx, bottomIdx in zip(listTumoralSlices[:-1], listTumoralSlices[1:]):
      # Number of black slices between the two tumoral slices
      numBlackSlices = bottomIdx - topIdx - 1
      if 0 < numBlackSlices <= maxBlackSlices:
        # Get the top and bottom tumoral slices (called mask1 and mask2 in the function and in the report).
        # They are the same for every black slice of the gap, so they are loaded only once
        top = np.array(Image.open(os.path.join(automaticFolder, f"slice_{topIdx}.png")))
        bottom = np.array(Image.open(os.path.join(automaticFolder, f"slice_{bottomIdx}.png")))
        # Interpolate the tumoral slices
        interpolated = Image.fromarray(interp_mask(top, bottom), mode='L')
        # Save the interpolated mask in place of every black slice of the gap
        for j in range(1, numBlackSlices + 1):
          interpolated.save(os.path.join(automaticFolder, f"slice_{topIdx + j}.png"))