# Coding exercise for Wang Lab
### Implementation of nnU-Net for the segmentation of the pancreas and pancreatic lesions on CT scans, with classification of the pancreatic lesions
### by Leo Chen
### August/September 2024

In [1]:
### IMPORTS
import os
import glob
#import util

import numpy as np
import pandas as pd
import random
import math
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as data
from torch.autograd import Variable
import torch.nn.init as init

from collections import defaultdict
from collections import Counter

from datetime import datetime

import SimpleITK as sitk
#import nibabel as nib

import json
import shutil

In [3]:
# check if cuda is working
# (in a notebook only the value of the last expression is echoed as the cell output)
torch.cuda.is_available()      # whether CUDA can be used at all
torch.cuda.device_count()      # number of CUDA devices visible to torch
torch.cuda.current_device()    # index of the currently selected CUDA device

0

In [11]:
### GLOBAL VARIABLES

# using GPU?
# NOTE(review): hard-coded to False; nothing visible in this notebook reads it - confirm it is still needed
gpu = False


# directories where the files are
traindir = r'C:\Users\Leo\Documents\UHN-MedImg3D-ML-quiz\train'
valdir = r'C:\Users\Leo\Documents\UHN-MedImg3D-ML-quiz\validation'
testdir = r'C:\Users\Leo\Documents\UHN-MedImg3D-ML-quiz\test'

csvpath = r'C:\Users\Leo\OneDrive\Documents\GitHub\WangLabQuiz\csv files\trainval_metadata.csv'   # csv with the image dimensions, image and mask file paths

# root folder of the nnU-Net raw datasets
# NOTE(review): this path is under OneDrive\Documents while the train/validation/test
# folders above are under Documents - confirm the difference is intentional
nnUNet_raw_dir = r'C:\Users\Leo\OneDrive\Documents\UHN-MedImg3D-ML-quiz\nnUnet_raw'

In [12]:
### FUNCTIONS FOR SITK and IMAGE AUGMENTATION

def rotateImage(original, anglex, angley, anglez, interpolate='linear'):
    """Rotate a 3D SimpleITK image about its physical center.

    The image is resampled onto its own grid through an Euler rotation, and the
    direction cosines and origin of the result are then overwritten with values
    derived from the same rotation.

    Args:
        original: SimpleITK image to rotate.
        anglex: roll in degrees (twisting the body like a rolling pin,
            turning in dance).
        angley: yaw in degrees (rotating the body like a propeller blade,
            like break dancing).
        anglez: pitch in degrees (tilt along the superior/inferior axis,
            i.e. Trendelenburg).
        interpolate: 'linear' or 'NN' (nearest neighbour).

    Returns:
        The rotated SimpleITK image, on the same grid size and spacing as
        ``original``.

    Raises:
        ValueError: if ``interpolate`` is neither 'linear' nor 'NN'.
    """

    if interpolate == 'linear':
        interpolator = sitk.sitkLinear
    elif interpolate == 'NN':
        interpolator = sitk.sitkNearestNeighbor
    else:
        # previously fell through and failed later with an unbound local
        raise ValueError("interpolate must be 'linear' or 'NN', got %r" % (interpolate,))

    # degrees -> radians
    radx = anglex * math.pi / 180
    rady = angley * math.pi / 180
    radz = anglez * math.pi / 180

    origin = np.array(original.GetOrigin())
    # GetSize / TransformContinuousIndexToPhysicalPoint are methods of the image,
    # not functions of the sitk module
    pixelcenter = [extent / 2. for extent in original.GetSize()]
    physicalcenter = original.TransformContinuousIndexToPhysicalPoint(pixelcenter)

    transform = sitk.Euler3DTransform()
    transform.SetCenter(physicalcenter)
    transform.SetRotation(radz, rady, radx)    # note the order is z, y, x

    # NOTE(review): the direction/origin bookkeeping below is kept exactly as the
    # original author wrote it - confirm against a known rotation before relying on it
    unitvecs = np.transpose(np.reshape(original.GetDirection(), (-1, 3)))
    matrix = np.reshape(transform.GetMatrix(), (-1, 3))
    inverse = np.linalg.inv(matrix)

    # the transform matrix is actually mapping backwards: post to pre
    # therefore the forward transformation is the inverse matrix
    transformedunitvecs = inverse @ unitvecs   # new i, j, k are columns
    newdirection = transformedunitvecs.flatten('F')    # flatten by column

    center = np.array(physicalcenter)
    neworigin = (matrix @ (origin - center)) + center

    rotatedImage = sitk.Resample(original, original, transform, interpolator)
    # hand plain Python floats to SimpleITK rather than numpy arrays
    rotatedImage.SetDirection(newdirection.tolist())
    rotatedImage.SetOrigin(neworigin.tolist())

    return rotatedImage

def flipImage(original):
    """Return a SimpleITK image mirrored along its first image axis.

    The first axis is the left/right axis according to the original author;
    the second and third axes are left untouched.
    """
    axes_to_flip = [True, False, False]
    return sitk.Flip(original, axes_to_flip)

def flipslice(original):
    """Return a numpy slice (2d image) with the order of its first axis reversed.

    NOTE(review): the original comment called this "reverses x indices"; which
    anatomical direction the first array axis corresponds to depends on how the
    caller laid out the slice - confirm.
    """
    return np.flip(original, axis=0)

def bbox_3D(img):
    """Find the bounding box of the non-zero region of a 3D numpy array.

    Args:
        img: 3D array; any truthy element counts as foreground.

    Returns:
        ``[rmin, rmax, cmin, cmax, zmin, zmax]`` with inclusive indices, where
        r (row) runs along array axis 1, c (column) along axis 2 and z along
        axis 0.  If there is no foreground (or the array is not 3D) the tuple
        ``(-1, -1, -1, -1, -1, -1)`` is returned instead.
    """
    try:
        z = np.any(img, axis=(1, 2))    # occupancy along axis 0 (z)
        c = np.any(img, axis=(0, 1))    # occupancy along axis 2 (c = column)
        r = np.any(img, axis=(0, 2))    # occupancy along axis 1 (r = row)

        # first and last occupied index along each axis
        rmin, rmax = np.where(r)[0][[0, -1]]
        cmin, cmax = np.where(c)[0][[0, -1]]
        zmin, zmax = np.where(z)[0][[0, -1]]
    except (IndexError, ValueError):
        # IndexError: nothing is set, so there is no first/last index to pick;
        # numpy's axis error (an IndexError/ValueError subclass): input is not 3D
        return -1, -1, -1, -1, -1, -1

    # row min max, column min max, z min max
    return [rmin, rmax, cmin, cmax, zmin, zmax]


def bbox_2D(img):
    """Find the bounding box of the non-zero region of a 2D numpy array.

    Args:
        img: 2D array; any truthy element counts as foreground.

    Returns:
        ``(rmin, rmax, cmin, cmax)`` with inclusive indices (r = row, array
        axis 0; c = column, array axis 1).
        If no elements exist, then returns ``(-1, -1, -1, -1)``.
    """

    try:
        c = np.any(img, axis=0)    # which columns contain foreground
        r = np.any(img, axis=1)    # which rows contain foreground

        # first and last occupied index along each axis
        rmin, rmax = np.where(r)[0][[0, -1]]
        cmin, cmax = np.where(c)[0][[0, -1]]
    except (IndexError, ValueError):
        # IndexError: nothing is set; numpy's axis error (an IndexError/ValueError
        # subclass): input has fewer than two dimensions
        return -1, -1, -1, -1

    return rmin, rmax, cmin, cmax


def cropImage(image, threshold, xshift, yshift):
    """Crop a SimpleITK image to the bounding box of voxels above a threshold.

    Removes the border whose values are all <= ``threshold`` (e.g. black
    space).  The crop window can additionally be moved by ``xshift`` /
    ``yshift`` pixels (random shifts for augmentation); a shift that would push
    the window outside the image is reset to 0.

    Args:
        image: 3D SimpleITK image.
        threshold: voxels strictly greater than this value are kept.
        xshift: shift of the crop window along x, in pixels.
        yshift: shift of the crop window along y, in pixels.

    Returns:
        The cropped SimpleITK image; spacing and direction are copied from
        ``image`` and the origin is moved to the first kept voxel.

    Raises:
        ValueError: if no voxel is above ``threshold``.
    """
    # numpy view of the image, indexed [z, y, x]
    npy = sitk.GetArrayFromImage(image)

    # GET METADATA
    direction = image.GetDirection()
    spacing = image.GetSpacing()

    # CALCULATE BOUNDING BOX OF BODY (removes black space)
    mask = npy > threshold
    # bbox_3D reports rows (array axis 1 = y) first, then columns (axis 2 = x), then z
    ymin, ymax, xmin, xmax, zmin, zmax = (int(v) for v in bbox_3D(mask))
    if zmin < 0:
        # bbox_3D signals "nothing found" with -1 in every slot
        raise ValueError('cropImage: no voxel is above the threshold, nothing to crop')

    # check to make sure shifts do not extend outside boundaries of image
    # (the max indices are inclusive, so the last valid index is shape - 1)
    if xmin + xshift < 0 or xmax + xshift > npy.shape[2] - 1:
        xshift = 0

    if ymin + yshift < 0 or ymax + yshift > npy.shape[1] - 1:
        yshift = 0

    xstart, ystart = int(xmin + xshift), int(ymin + yshift)
    xstop, ystop = int(xmax + xshift) + 1, int(ymax + yshift) + 1

    # CROP IMAGE (+1 because the bounding-box maxima are inclusive)
    newnpy = npy[zmin:zmax + 1, ystart:ystop, xstart:xstop]

    newimage = sitk.GetImageFromArray(newnpy)
    # SimpleITK indices are ordered (x, y, z) and must be plain Python ints
    topleft = [xstart, ystart, zmin]
    neworigin = image.TransformIndexToPhysicalPoint(topleft)

    newimage.SetOrigin(neworigin)
    newimage.SetDirection(direction)
    newimage.SetSpacing(spacing)

    return newimage


def squareImage(image):
    """Make a SimpleITK image square (width = height) by padding with zeros.

    The shorter in-plane axis is padded on both sides (the extra pixel of an
    odd difference goes to the right/bottom) and the origin is moved so the
    original voxels keep their physical position.  The pixel type of the input
    array is preserved.

    Args:
        image: 3D SimpleITK image.

    Returns:
        The padded SimpleITK image with spacing and direction copied from
        ``image``.
    """
    numcols, numrows, numslices = image.GetSize()
    npy = sitk.GetArrayFromImage(image)    # indexed [slice, row, column]

    if numcols < numrows:    # pad columns
        numzerostopad = numrows - numcols
        leftpad = numzerostopad // 2
        rightpad = numzerostopad - leftpad

        pad_width = ((0, 0), (0, 0), (leftpad, rightpad))
        topleft = [-leftpad, 0, 0]
    else:    # pad rows (no-op padding when already square)
        numzerostopad = numcols - numrows
        toppad = numzerostopad // 2
        botpad = numzerostopad - toppad

        pad_width = ((0, 0), (toppad, botpad), (0, 0))
        topleft = [0, -toppad, 0]

    # np.pad keeps npy's dtype; concatenating with float64 zero blocks would
    # silently turn e.g. an integer CT volume or label mask into float64
    newnpy = np.pad(npy, pad_width, mode='constant', constant_values=0)

    # physical position of the new first voxel (index may be negative)
    neworigin = image.TransformIndexToPhysicalPoint(topleft)

    paddedimg = sitk.GetImageFromArray(newnpy)
    paddedimg.SetOrigin(neworigin)
    paddedimg.SetDirection(image.GetDirection())
    paddedimg.SetSpacing(image.GetSpacing())

    return paddedimg

def resampleImage(image, finalsize, interpolation='linear'):
    """Resample a SimpleITK image to finalsize x finalsize pixels in-plane.

    The number of slices, the slice spacing, the origin and the direction are
    kept; the in-plane spacing is rescaled so the new grid spans the same
    physical extent.  The scale factor is computed from the second size
    component only.

    interpolation* = 'linear' or 'NN' (nearest neighbor); any other value
    leaves the resample filter's interpolator at its default.
    """
    _, rows, depth = image.GetSize()
    old_spacing = image.GetSpacing()

    # in-plane spacing stretched so (finalsize - 1) steps cover (rows - 1) old steps
    inplane = (rows - 1) * np.array(old_spacing[0:2]) / (finalsize - 1)
    newspacing = np.append(inplane, old_spacing[2])

    # empty target grid: depth x finalsize x finalsize
    reference = sitk.GetImageFromArray(np.zeros([depth, finalsize, finalsize]))
    reference.SetOrigin(image.GetOrigin())
    reference.SetDirection(image.GetDirection())
    reference.SetSpacing(newspacing)

    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(reference)
    if interpolation in ('linear', 'NN'):
        chosen = sitk.sitkLinear if interpolation == 'linear' else sitk.sitkNearestNeighbor
        resampler.SetInterpolator(chosen)

    return resampler.Execute(image)


def projectImage(reference, moving, interpolate = 'linear'):
    """Resample the SimpleITK image *moving onto the grid of *reference.

    interpolate* = 'linear' or 'NN' (nearest neighbor); any other value leaves
    the resample filter's interpolator at its default.
    """
    resampler = sitk.ResampleImageFilter()
    resampler.SetReferenceImage(reference)    # output grid is taken from *reference

    if interpolate in ('linear', 'NN'):
        chosen = sitk.sitkLinear if interpolate == 'linear' else sitk.sitkNearestNeighbor
        resampler.SetInterpolator(chosen)

    return resampler.Execute(moving)


def resampleImageToVoxelSize(image, voxelx, voxely, voxelz, interpolation='linear'):
    """Resample a SimpleITK *image* to spacing *[voxelx, voxely, voxelz] in mm.

    Preserves the original physical size of the image (up to rounding of the
    new voxel counts), as well as its origin, direction and pixel type.

    Args:
        image: SimpleITK image to resample.
        voxelx: new voxel width.
        voxely: new voxel height.
        voxelz: new spacing along the third axis (usually slice thickness).
        interpolation: 'linear' or 'NN' (nearest neighbour).

    Returns:
        The resampled SimpleITK image; voxels outside the input are set to 0.

    Raises:
        ValueError: if ``interpolation`` is neither 'linear' nor 'NN'.
    """

    if interpolation == 'linear':
        interpolator = sitk.sitkLinear
    elif interpolation == 'NN':
        interpolator = sitk.sitkNearestNeighbor
    else:
        # previously fell through and failed later with an unbound local
        raise ValueError("interpolation must be 'linear' or 'NN', got %r" % (interpolation,))

    original_spacing = image.GetSpacing()
    original_size = image.GetSize()

    new_spacing = [voxelx, voxely, voxelz]
    # new dimension will be original size * original spacing / new spacing
    # based on physical distance formula:
    #    original size (pixel) * original spacing (mm / pixel) = new size (pixel) * new spacing (mm / pixel)
    new_size = [int(round(osz*ospc/nspc)) for osz,ospc,nspc in zip(original_size, original_spacing, new_spacing)]

    # creates new image (identity transform, default pixel value 0)
    new_image = sitk.Resample(image, new_size, sitk.Transform(), interpolator,
                         image.GetOrigin(), new_spacing, image.GetDirection(), 0,
                         image.GetPixelID())

    return new_image


def windowImage(image, window_width, window_center, output_min=0, output_max=255):
    """Normalize a SimpleITK *image* (CT scan) based on a window specification.

    Intensities inside [center - width/2, center + width/2] are mapped linearly
    onto [output_min, output_max]; intensities outside the window are clamped
    to output_min / output_max.
    (example, abdominal soft tissue window is W = 400, C = 50, i.e. -150 to 250)

    Args:
        image: SimpleITK image to window.
        window_width: width of the intensity window.
        window_center: center of the intensity window.
        output_min: output value for the lower window edge (default 0).
        output_max: output value for the upper window edge (default 255).

    Returns:
        The windowed SimpleITK image.
    """

    window_min = window_center - window_width / 2
    window_max = window_center + window_width / 2

    # output_min / output_max are honoured as passed; they used to be
    # overwritten with 0 and 255 here, which made the parameters dead
    windowed_image = sitk.IntensityWindowing(image, window_min, window_max, output_min, output_max)

    return windowed_image
    


### dataset.json

In [18]:
# root of the nnU-Net raw datasets, and the pancreas-lesion dataset folder inside it
raw_folder = r'C:\Users\Leo\Documents\UHN-MedImg3D-ML-quiz\nnUnet_raw'
data_folder = rf'{raw_folder}\Dataset002_PancreasLesion'


In [17]:
from nnunetv2.training.dataloading.data_loader_3d_classify import nnUNetDataLoader3Dclassify
from nnunetv2.training.dataloading.data_loader_2d_classify import nnUNetDataLoader2Dclassify
from nnunetv2.utilities.plans_handling.plans_handler import PlansManager, ConfigurationManager
from nnunetv2.training.dataloading.nnunet_dataset_classify import nnUNetDatasetClassify
from nnunetv2.utilities.label_handling.label_handling import LabelManager
from batchgenerators.utilities.file_and_folder_operations import load_json, join
from nnunetv2.training.dataloading.base_data_loader_classify import nnUNetDataLoaderBaseClassify
from nnunetv2.training.dataloading.nnunet_dataset_classify import nnUNetDatasetClassify

import numpy as np

In [43]:


# Preprocessed-data folder to read cases from: switch between the 3d_fullres and 2d plans
# by (un)commenting one of the two assignments.
#folder = r'C:\Users\Leo\Documents\UHN-MedImg3D-ML-quiz\nnUNet_preprocessed\Dataset002_PancreasLesion\nnUNetPlans_3d_fullres'
folder = r'C:\Users\Leo\Documents\UHN-MedImg3D-ML-quiz\nnUNet_preprocessed\Dataset002_PancreasLesion\nnUNetPlans_2d'
# NOTE(review): meaning of the positional arguments (None, 0) is not visible here -
# presumably "all cases" and a loading threshold as in nnU-Net's stock dataset class; confirm
ds = nnUNetDatasetClassify(folder, None, 0)


# dataset description and plans file of the preprocessed dataset
nnUNet_preprocessed = r'C:\Users\Leo\Documents\UHN-MedImg3D-ML-quiz\nnUNet_preprocessed\Dataset002_PancreasLesion'
dataset_json = load_json(r'C:\Users\Leo\Documents\UHN-MedImg3D-ML-quiz\nnUNet_preprocessed\Dataset002_PancreasLesion\dataset.json')

plans = load_json(join(nnUNet_preprocessed, 'ClassifyinnUNetPlans.json'))

plans_manager = PlansManager(plans)

# label manager built from the plans and the dataset description; passed to the data loaders
label_manager = plans_manager.get_label_manager(dataset_json)



In [3]:
# Instantiate the base classify data loader directly so its helper methods
# (get_indices, get_bbox, ...) can be stepped through by hand in later cells.
# NOTE(review): the patch size is 3D; make sure `ds` points at a 3D configuration when this runs
dl = nnUNetDataLoaderBaseClassify(data=ds,
                                  batch_size=3,
                                  patch_size=[64, 128, 192],
                                  final_patch_size= [64, 128, 192],
                                  label_manager=label_manager
                                  )

print(dl)

<nnunetv2.training.dataloading.base_data_loader_classify.nnUNetDataLoaderBaseClassify object at 0x000002AFAAC2AA80>


In [4]:


# 3D classify data loader with the same settings as the base loader above;
# used further down to generate a complete training batch
dl3D = nnUNetDataLoader3Dclassify(data=ds,
                                  batch_size=3,
                                  patch_size=[64, 128, 192],
                                  final_patch_size= [64, 128, 192],
                                  label_manager=label_manager
                                  )

In [44]:
# Manual walk-through of the first steps of batch generation for ONE case.
selected_keys = dl.get_indices()
# preallocate memory for data and seg
# (data_all / seg_all / lesion_all are only allocated here; nothing in this cell fills them)
data_all = np.zeros(dl.data_shape, dtype=np.float32)
seg_all = np.zeros(dl.seg_shape, dtype=np.int16)
#lesion_all = np.zeros((len(selected_keys), 3))          # probability classes, one hot
lesion_all = np.zeros(len(selected_keys))

case_properties = []

for j, i in enumerate(selected_keys):     # 'i' is the key (quiz_2_413) and 'j' is the index 0-2
    # oversampling foreground will improve stability of model training, especially if many patches are empty
    # (Lung for example)
    #force_fg = self.get_do_oversample(j)
    force_fg = True

    data, seg, properties, lesion = dl._data.load_case(i)
    case_properties.append(properties)

    # If we are doing the cascade then the segmentation from the previous stage will already have been loaded by
    # self._data.load_case(i) (see nnUNetDataset.load_case)
    shape = data.shape[1:]
    dim = len(shape)
    bbox_lbs, bbox_ubs = dl.get_bbox(shape, force_fg, properties['class_locations'])

    # Only the first case is inspected.  Stop here so that 'i' and 'j' keep
    # referring to the case that was actually loaded; letting the loop run on
    # left 'i' naming the LAST key while data/seg/properties belonged to the first.
    break

In [47]:
# 'i' is the loop variable left over from the case-loading loop above
print(i)
# lower / upper corners returned by dl.get_bbox for the loaded case
print(bbox_lbs, bbox_ubs)

# location tables stored under labels 1 and 2
# (presumably 1 = pancreas, 2 = lesion, going by the variable names - confirm against dataset.json)
pancreas_locs = properties['class_locations'][1]
lesion_locs = properties['class_locations'][2]

quiz_1_034
[21, -8, -12] [85, 120, 180]


In [49]:
# first row of the label-1 location table (four integers per row)
print(pancreas_locs[0])

[ 0 37 51 13]


In [48]:
# shapes of the two location tables and of the loaded image / segmentation arrays
print(pancreas_locs.shape)
print(lesion_locs.shape)
print(data.shape)
print(seg.shape)

(10000, 4)
(10000, 4)
(1, 75, 113, 169)
(1, 75, 113, 169)


In [5]:
# draw one training batch from the 3D loader
# NOTE(review): binding the result to `dict` shadows Python's builtin dict for the rest of
# the session (any later dict(...) call will fail); consider renaming it, e.g. to `batch`,
# together with every cell that reads it
dict = dl3D.generate_train_batch()

In [6]:
# list the entry names of the generated batch (`value` is not used)
for key, value in dict.items():
    print(key)

data
target
keys
lesion_class


In [10]:
# pull the four batch entries into their own names
# (this rebinds `data`, which earlier held the single loaded case)
data = dict['data']
target = dict['target']
keys = dict['keys']
lesion_class = dict['lesion_class']

In [8]:
# case identifiers in the batch and the lesion class stored for each of them
print(keys)
print(lesion_class)

['quiz_1_516' 'quiz_1_025' 'quiz_0_313']
[1. 1. 0.]


In [14]:
# batch array shapes: image data and segmentation target
print(data.shape)
print(target.shape)

(3, 1, 64, 128, 192)
(3, 1, 64, 128, 192)
1.0


In [12]:
# first sample, first channel of the batch: indices 0, 0 select along the first two axes,
# the trailing ':' keeps the remaining axes whole
# (this rebinds `seg`, which earlier held the single loaded case's segmentation)
img = data[0, 0, :]
seg = target[0, 0, :]

In [3]:
# properties file (.pkl) of one preprocessed case
pickle_path = r'C:\Users\Leo\Documents\UHN-MedImg3D-ML-quiz\nnUNet_preprocessed\Dataset002_PancreasLesion\nnUNetPlans_3d_fullres\quiz_0_041.pkl'

# NOTE(review): the variable name `pickle` would shadow the stdlib module of the same name
# if that module is ever imported in this notebook; read_pickle unpickles arbitrary objects,
# so only use it on trusted files
pickle = pd.read_pickle(pickle_path)

In [4]:
# dump the whole properties object loaded from the .pkl file
print(pickle)

{'sitk_stuff': {'spacing': (0.7049999833106995, 0.7049999833106995, 0.801025390625), 'origin': (-164.0625, -180.4687042236328, 1647.5), 'direction': (1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0)}, 'spacing': [0.801025390625, 0.7049999833106995, 0.7049999833106995], 'shape_before_cropping': (106, 116, 161), 'bbox_used_for_cropping': [[0, 106], [0, 116], [0, 161]], 'shape_after_cropping_and_before_resampling': (106, 116, 161), 'class_locations': {1: array([[ 0, 33, 45, 55],
       [ 0, 29, 25, 22],
       [ 0, 24, 25, 38],
       ...,
       [ 0, 27, 55, 34],
       [ 0,  8, 38, 70],
       [ 0, 35, 44, 28]], dtype=int64), 2: array([[ 0, 30, 39, 51],
       [ 0, 34, 43, 31],
       [ 0, 34, 43, 28],
       ...,
       [ 0, 30, 41, 37],
       [ 0, 29, 36, 27],
       [ 0, 28, 53, 38]], dtype=int64)}}


In [11]:
# per-label location tables stored in the case properties
class_locations = pickle['class_locations']


In [12]:
# show the location table of every label
print(class_locations)


{1: array([[ 0, 33, 45, 55],
       [ 0, 29, 25, 22],
       [ 0, 24, 25, 38],
       ...,
       [ 0, 27, 55, 34],
       [ 0,  8, 38, 70],
       [ 0, 35, 44, 28]], dtype=int64), 2: array([[ 0, 30, 39, 51],
       [ 0, 34, 43, 31],
       [ 0, 34, 43, 28],
       ...,
       [ 0, 30, 41, 37],
       [ 0, 29, 36, 27],
       [ 0, 28, 53, 38]], dtype=int64)}


In [13]:
# labels whose location table is not empty
eligible_classes_or_regions = [
    label
    for label, locations in class_locations.items()
    if len(locations) > 0
]

print(eligible_classes_or_regions)

[1, 2]
