In [2]:
## TORCH LIBRARY
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, random_split
import torchio as tio

## MONAI LIBRARY
from monai.networks.nets import BasicUNet as BU
from monai.losses.dice import DiceLoss
from monai.handlers import checkpoint_saver

## OTHER LIBRARIES
import argparse
import logging
import sys
import numpy as np
from pathlib import Path
import os
import itertools
import matplotlib.pyplot as plt
from dicom_contour.contour import *
from functools import reduce
import SimpleITK as sk
import copy
import itk

## WEIGHTS & BIASES
import wandb
os.environ["WANDB_CONFIG_DIR"] = "/tmp"
from tqdm import tqdm

## IMPORT OTHER CLASSES
import import_ipynb
import DosePredictionDataset

## IGNORE WARNINGS
import warnings
warnings.filterwarnings('ignore')

Failed to load image Python extension: 


importing Jupyter notebook from DosePredictionDataset.ipynb


In [3]:
## FUNCTION FOR TRAINING
def train(device:str,
          patients_data:list,
          epochs:list,
          learning_rate:list,
          dropout_rate:list,
          batch_size:list,
          in_channels:int,
          val_percent: float = 0.4,
          model_save_path: str = "/Users/wangyangwu/Documents/Maastro/NeuralNets/BasicUNet/saved_model/",
          struct_types: str = None,
          ):
    """Grid-search a 3D BasicUNet over hyperparameter combinations and save the best one.

    The dataset is split once (fixed seed 42) into training and validation
    subsets. For every element of the Cartesian product
    ``epochs x batch_size x dropout_rate x learning_rate`` a new network is
    trained inside its own Weights & Biases run. After each epoch the summed
    validation loss is compared with the best value seen so far across all
    runs; the network, optimizer state and hyperparameters of the best epoch
    are kept and written to disk at the end.

    Args:
        device: Torch device string that the network and batches are moved to.
        patients_data: Data handed to ``DosePredictionDataset.DosePrdictionDataset``.
        epochs: Candidate epoch counts.
        learning_rate: Candidate Adam learning rates.
        dropout_rate: Candidate dropout rates for the network.
        batch_size: Candidate batch sizes.
        in_channels: Number of input channels of the network.
        val_percent: Fraction of the dataset used for validation.
        model_save_path: Directory the checkpoint is written to (created if missing).
        struct_types: Tag used in the checkpoint file name
            ``best_model_{struct_types}.pt``. When omitted, a module-level
            ``struct_types`` variable is used if one exists, else ``"default"``.

    Raises:
        ValueError: If no epoch was run (empty hyperparameter grid or zero epochs),
            so there is no model to save.
    """
    ## Create datasets for training and validation
    dataset = DosePredictionDataset.DosePrdictionDataset(patients_data)
    val_num = int(len(dataset) * val_percent)
    train_num = len(dataset) - val_num
    train_set, val_set = random_split(dataset, [train_num, val_num], generator=torch.Generator().manual_seed(42))
    # Login wandb account
    wandb.login()
    # Create combinations of hyperparameters
    param_combinations = list(itertools.product(epochs, batch_size, dropout_rate, learning_rate))
    # Track the best model across every run and epoch
    best_val_loss = float('inf')
    best_model = None
    best_param = None
    best_optimizer_state = None
    # The loss module and softmax are stateless, so build them once
    dice_loss = DiceLoss()
    # NOTE(review): softmax over dim=2 is applied to predictions AND targets — confirm this is intended
    softmaxer = nn.Softmax(dim=2)
    ### ================================== TRAINING START ===================================== ###
    print("TRAINING STARTS ...")
    for comb in param_combinations:
        # Initiate and configure wandb runner
        run = wandb.init(reinit=True, project="Basic UNet")
        try:
            run.config.update({"epoch":comb[0],
                               "batch_size":comb[1],
                               "dropout_rate":comb[2],
                               "learning_rate":comb[3]})
            # Create data loaders for training and validation
            train_loader = DataLoader(train_set, shuffle=True, batch_size=run.config.batch_size)
            val_loader = DataLoader(val_set, shuffle=False, drop_last=True, batch_size=run.config.batch_size)
            # Create neural network model on the same device as the batches
            net = BU(spatial_dims=3,
                     in_channels=in_channels,
                     out_channels=1,
                     features=(6, 16, 32, 64, 128, 16),
                     dropout=comb[2]).to(device)
            wandb.watch(net, log='all', log_freq=1)
            # Set up optimizer
            optimizer = torch.optim.Adam(net.parameters(), lr=run.config.learning_rate)
            # Number of training samples seen so far; used as the wandb step
            sample_count = 0

            for epoch in tqdm(range(run.config.epoch)):
                # Losses are summed over batches, not averaged
                epoch_loss_train = 0
                epoch_loss_val = 0
                ## ========================== TRAINING SECTION =========================== ##
                net.train()
                for images, masks in train_loader:
                    images = images.to(device=device, dtype=torch.float32)
                    masks = masks.to(device=device, dtype=torch.float32)
                    # forward pass
                    image_pred = net(images)
                    train_loss = dice_loss(softmaxer(image_pred), softmaxer(masks))
                    epoch_loss_train += train_loss.item()
                    # backward pass
                    optimizer.zero_grad()
                    train_loss.backward()
                    # optimizing
                    optimizer.step()
                    sample_count += len(images)
                    # Log training loss
                    wandb.log({"training_loss": train_loss.item()},step=sample_count)
                print(f"Training loss after epoch {epoch+1}: {epoch_loss_train}")
                ## ========================== VALIDATION SECTION ========================== ##
                # Switch off dropout so the validation loss is deterministic
                net.eval()
                with torch.no_grad():
                    for images, masks in val_loader:
                        images = images.to(device=device, dtype=torch.float32)
                        masks = masks.to(device=device, dtype=torch.float32)
                        image_pred = net(images)
                        val_loss = dice_loss(softmaxer(image_pred), softmaxer(masks))
                        epoch_loss_val += val_loss.item()
                        wandb.log({"validation_loss": val_loss.item()},step=sample_count)
                print(f"Validation loss after epoch {epoch+1}: {epoch_loss_val}")
                if best_val_loss >= epoch_loss_val:
                    best_val_loss = epoch_loss_val
                    best_model = copy.deepcopy(net)
                    # Snapshot the optimizer that belongs to this model, not the last run's
                    best_optimizer_state = copy.deepcopy(optimizer.state_dict())
                    best_param = comb
        finally:
            # Close the wandb run even if training raised
            run.finish()
    print("TRAINING FINISHED")
    if best_model is None:
        raise ValueError("No model was trained: the hyperparameter grid is empty or every epoch count is zero.")
    print("Best validation loss is: ", best_val_loss)
    best_param = {'epoch':best_param[0],
                  'batch_size':best_param[1],
                  'dropout_rate':best_param[2],
                  'learning_rate':best_param[3]}
    if struct_types is None:
        # The original code read a global of this name; honour it when it exists
        struct_types = globals().get("struct_types", "default")
    os.makedirs(model_save_path, exist_ok=True)
    torch.save({'model_state_dict': best_model.state_dict(),
                'optimizer_state_dict': best_optimizer_state,
                'param': best_param,
                'loss': best_val_loss}, os.path.join(model_save_path, f"best_model_{struct_types}.pt"))

In [11]:
def process_patients_data(patient_path_list:list, ROI_names:list):
    """Build ``[input, dose]`` pairs for every CT / RTDOSE folder pair found.

    Each patient folder is walked recursively; directories whose path contains
    ``/CT/`` or ``/RTDOSE/`` are collected and paired positionally. For every
    pair the CT array and the requested structure contours are sliced so they
    start at the dose origin, trimmed along the first axis using the "Body1"
    contour, stacked (contours first, CT last) into a float32 tensor, and
    finally passed through ``crop`` together with the dose array so that all
    patients share the largest observed shape.

    Args:
        patient_path_list: Patient root folders to scan.
        ROI_names: Structure names forwarded to ``get_struct_contours``.

    Returns:
        A list with one ``[cropped_inputs, cropped_dose]`` entry per CT/dose pair.

    Raises:
        ValueError: If no CT/dose pair was found (``max`` of an empty list).
    """
    ct_dirs, dose_dirs = [], []
    for patient_dir in patient_path_list:
        walked_dirs = [root for root, _subdirs, _files in os.walk(patient_dir)]
        ct_dirs.extend(folder for folder in walked_dirs if "/CT/" in folder)
        dose_dirs.extend(folder for folder in walked_dirs if "/RTDOSE/" in folder)

    input_volumes, dose_volumes, volume_shapes = [], [], []
    for ct_dir, dose_dir in zip(ct_dirs, dose_dirs):
        # ---- CT image, its origin/spacing and voxel array ----
        ct_image = extract_images(ct_dir)
        ct_origin = np.array(list(ct_image.GetOrigin()))
        voxel_spacing = np.array(ct_image.GetSpacing())
        ct_volume = sk.GetArrayFromImage(ct_image)

        # ---- dose image and the voxel offset between both origins ----
        dose_image = extract_images(dose_dir)
        dose_origin = np.array(dose_image.GetOrigin()[:3])
        voxel_offset = np.ceil((dose_origin - ct_origin) / voxel_spacing)
        # Reversed so the offsets follow the array axis order used for slicing
        start = np.absolute(voxel_offset)[::-1].astype(int)

        def drop_leading(volume):
            # Discard everything in front of the dose origin on all three axes
            return volume[start[0]:, start[1]:, start[2]:]

        ct_aligned = drop_leading(ct_volume)

        # ---- requested structures, aligned the same way ----
        roi_aligned = [drop_leading(mask) for mask in get_struct_contours(ct_dir, ROI_names)]

        # ---- walk back from the last slice while the body contour is non-empty ----
        body_aligned = drop_leading(get_struct_contours(ct_dir, ["Body1"])[0])
        cut = body_aligned.shape[0] - 1
        while cut > 0 and np.sum(body_aligned[cut] > 0):
            cut -= 1

        # ---- trim along the first axis and stack the channels (CT goes last) ----
        ct_aligned = ct_aligned[:cut, :, :]
        channels = [mask[:cut, :, :] for mask in roi_aligned]
        channels.append(ct_aligned)
        input_volumes.append(torch.from_numpy(np.stack(channels)).type(torch.float32))
        dose_volumes.append(sk.GetArrayFromImage(dose_image))

        # ---- remember the spatial size of this patient ----
        volume_shapes.append((ct_aligned.shape[0], ct_aligned.shape[1], ct_aligned.shape[2]))

    # Largest extent per axis over all patients
    max_outdim = (
        max([shape[0] for shape in volume_shapes]),
        max([shape[1] for shape in volume_shapes]),
        max([shape[2] for shape in volume_shapes]),
    )
    print(f"Best input size is: {max_outdim}")
    return [
        [crop(inputs, max_outdim), crop(dose, max_outdim, True)]
        for inputs, dose in zip(input_volumes, dose_volumes)
    ]

In [5]:
def get_struct_contours(path:str, ROI_names:list):
    """Return one stacked contour volume per requested ROI found in ``path``.

    The contour file located by ``get_contour_file`` is read and its ROI names
    are matched case-insensitively against ``ROI_names``. The special request
    "Body1" additionally matches an ROI called "Body". For every match the
    contour slices returned by ``get_data`` are stacked into one array; slices
    whose maximum equals 1 are passed through ``fill_contour`` first.

    Args:
        path: Folder containing the images and the contour file.
        ROI_names: Names of the structures to extract.

    Returns:
        A list of stacked contour arrays, ordered by ``ROI_names`` and then by
        the ROI order of the file.
    """
    structure_file = get_contour_file(path)
    structure_set = dicom.read_file(path + '/' + structure_file)
    available_rois = get_roi_names(structure_set)

    # Index (first occurrence) of every ROI that matches a requested name
    matched_indices = [
        available_rois.index(roi)
        for wanted in ROI_names
        for roi in available_rois
        if (wanted.lower() == "body1" and roi.lower() == "body") or wanted.lower() == roi.lower()
    ]

    volumes = []
    for roi_index in matched_indices:
        _images, contours = get_data(path, index=roi_index)
        filled_slices = []
        for slice_no in range(contours.shape[0]):
            contour_slice = contours[slice_no]
            if contour_slice.max() == 1:
                contour_slice = fill_contour(contour_slice)
            filled_slices.append(contour_slice)
        volumes.append(np.stack(filled_slices))
    return volumes

In [6]:
## FUNCTION TO EXTRACT CT AND DOSE IMAGES
def extract_images(folder_path:str):
    """Read the first DICOM series found in ``folder_path`` with SimpleITK.

    Args:
        folder_path: Directory that holds the DICOM series.

    Returns:
        The image produced by ``ImageSeriesReader.Execute()`` for the first
        series ID reported by GDCM.
    """
    series_dir = str(Path(folder_path))
    first_series_id = sk.ImageSeriesReader.GetGDCMSeriesIDs(series_dir)[0]
    reader = sk.ImageSeriesReader()
    reader.SetFileNames(
        sk.ImageSeriesReader.GetGDCMSeriesFileNames(series_dir, first_series_id)
    )
    return reader.Execute()

In [7]:
## FUNCTION TO CROP INPUT IMAGES
def crop(images:object, out_dim:tuple, if_dose:bool=False):
    """Crop or pad ``images`` to ``out_dim`` with torchio's ``CropOrPad``.

    Args:
        images: Volume passed to the ``CropOrPad`` transform.
        out_dim: Target shape handed to ``CropOrPad``.
        if_dose: When true the padding value is 0, otherwise -1024
            (presumably the CT value of air — confirm).

    Returns:
        The transformed volume.
    """
    fill_value = 0 if if_dose else -1024
    return tio.CropOrPad(target_shape=out_dim, padding_mode=fill_value)(images)

In [8]:
def get_patient_list(top_dir):
    """Return the immediate sub-directories of ``top_dir`` in sorted order.

    ``os.listdir`` yields entries in arbitrary, filesystem-dependent order.
    Sorting makes the patient order stable, so anything built on top of it
    behaves the same on every machine.

    Args:
        top_dir: Directory whose sub-folders each hold one patient.

    Returns:
        A list of ``os.path.join(top_dir, name)`` paths for every entry that is
        a directory, sorted by entry name. Plain files are skipped.
    """
    candidate_paths = (os.path.join(top_dir, entry) for entry in sorted(os.listdir(top_dir)))
    return [folder_path for folder_path in candidate_paths if os.path.isdir(folder_path)]

In [9]:
## define parameters
# Folder whose sub-directories are scanned for patient data
top_dir = "/Users/wangyangwu/Documents/Maastro/NeuralNets/sample"
# Structures used as input channels
ROI_names = ["Heart"]
# Hyperparameter grids; train() tries every combination of these lists
epochs, batch_size = [1], [1]
learning_rate, dropout_rate = [0.1], [0.5]
# One channel per structure plus one for the CT volume
in_channel = 1 + len(ROI_names)

In [10]:
# Collect the patient folders, then turn them into [input, dose] pairs
pathlist_patients = get_patient_list(top_dir=top_dir)
patients_data = process_patients_data(patient_path_list=pathlist_patients,
                                      ROI_names=ROI_names)

Best input size is: (110, 440, 542)


In [None]:
# Run the hyperparameter search on the CPU
train(device="cpu",
      patients_data=patients_data,
      epochs=epochs,
      learning_rate=learning_rate,
      dropout_rate=dropout_rate,
      batch_size=batch_size,
      in_channels=in_channel)



TRAINING STARTS ...


VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

BasicUNet features: (6, 16, 32, 64, 128, 16).


  0%|                                                     | 0/1 [00:00<?, ?it/s]