In [3]:
## TORCH LIBRARY
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, random_split

## MONAI LIBRARY
from monai.networks.nets import BasicUNet as BU
from monai.losses.dice import DiceLoss
from monai.handlers import checkpoint_saver

## OTHER LIBRARIES
import argparse
import logging
import sys
import numpy as np
from pathlib import Path
import os
import itertools
import matplotlib.pyplot as plt
from dicom_contour.contour import *
from functools import reduce
import SimpleITK as sk
import copy

## WEIGHTS & BIASES
import wandb
os.environ["WANDB_CONFIG_DIR"] = "/tmp"
from tqdm import tqdm

## IMPORT OTHER CLASSES
import import_ipynb
import DosePredictionDataset

## IGNORE WARNINGS
import warnings
warnings.filterwarnings('ignore')

Failed to load image Python extension: 


importing Jupyter notebook from DosePredictionDataset.ipynb


In [4]:
## FUNCTION FOR TRAINING
def train(device:str,
          input_dir:list,
          struct_types:list,
          epochs:list,
          learning_rate:list,
          dropout_rate:list,
          batch_size:list,
          out_dim:tuple,
          val_percent: float = 0.4,
          model_save_path: str = "/Users/wangyangwu/Documents/Maastro/NeuralNets/BasicUNet/saved_model/",
          ):
    """Grid-search a 3D BasicUNet over hyperparameters and save the best checkpoint.

    Every combination of (epochs, batch_size, dropout_rate, learning_rate) is
    trained from scratch as its own Weights & Biases run in the "Basic UNet"
    project. After each epoch the mean validation loss per batch is compared
    with the best value seen so far (across all combinations); the weights and
    optimizer state of the best epoch are written to
    ``<model_save_path>/best_model_<struct_types>.pt``.

    Args:
        device: Torch device string the model and batches are moved to.
        input_dir: Passed through to ``DosePredictionDataset.DosePrdictionDataset``.
        struct_types: Structure names; the network gets ``len(struct_types)+1``
            input channels. Also passed to the dataset and used in the file name.
        epochs: Candidate epoch counts.
        learning_rate: Candidate Adam learning rates.
        dropout_rate: Candidate dropout probabilities for the network.
        batch_size: Candidate batch sizes.
        out_dim: Passed through to the dataset.
        val_percent: Fraction of the dataset held out for validation.
        model_save_path: Directory the checkpoint is written to (created if missing).

    Raises:
        ValueError: If the split leaves no training samples.
        RuntimeError: If no epoch was ever validated, so there is nothing to save.
    """
    ## Create datasets for training and validation
    # NOTE(review): class name is spelled "DosePrdictionDataset" exactly as in the
    # original call -- confirm it matches the definition in DosePredictionDataset.ipynb.
    dataset = DosePredictionDataset.DosePrdictionDataset(input_dir, struct_types, out_dim)
    val_num = int(len(dataset) * val_percent)
    train_num = len(dataset) - val_num
    if train_num == 0:
        raise ValueError(f"No training samples: dataset has {len(dataset)} items and val_percent={val_percent}")
    # Fixed seed so every hyperparameter combination sees the same split
    train_set, val_set = random_split(dataset, [train_num, val_num], generator=torch.Generator().manual_seed(42))
    # Login wandb account
    wandb.login()
    # Create combinations of hyperparameters
    param_combinations = list(itertools.product(epochs, batch_size, dropout_rate, learning_rate))
    # Track the best model seen over all combinations / epochs
    best_val_loss = float('inf')
    best_model_state = None
    best_optimizer_state = None
    best_param = None
    # Loss and input normaliser carry no trainable state, so build them once
    loss_fn = DiceLoss()
    # NOTE(review): softmax is applied along dim=2 of both predictions and targets,
    # as in the original code -- confirm this axis is the intended one.
    softmaxer = nn.Softmax(dim=2)
    ### ================================== TRAINING START ===================================== ###
    print("TRAINING STARTS ...")
    for comb in param_combinations:
        n_epochs, n_batch, p_dropout, lr = comb
        # Initiate and configure wandb runner
        run = wandb.init(reinit=True, project="Basic UNet")
        run.config.update({"epoch":n_epochs,
                           "batch_size":n_batch,
                           "dropout_rate":p_dropout,
                           "learning_rate":lr})
        # Create data loaders for training and validation
        train_loader = DataLoader(train_set, shuffle=True, batch_size=n_batch)
        val_loader = DataLoader(val_set, shuffle=False, drop_last=True, batch_size=n_batch)
        if len(val_loader) == 0:
            # drop_last=True leaves nothing to validate on; a loss of 0 would
            # otherwise be mistaken for a perfect model.
            print(f"Skipping {comb}: {val_num} validation samples give no full batch of size {n_batch}")
            run.finish()
            continue
        # Create neural network model on the same device as the batches
        net = BU(spatial_dims=3,
                 in_channels=len(struct_types)+1,
                 out_channels=1,
                 features=(6, 16, 32, 64, 128, 16),
                 dropout=p_dropout).to(device)
        wandb.watch(net, log='all', log_freq=1)
        # Set up optimizer
        optimizer = torch.optim.Adam(net.parameters(), lr=lr)
        # Number of training samples seen so far; used as the wandb step
        sample_count = 0

        for epoch in tqdm(range(n_epochs)):
            ## ========================== TRAINING SECTION =========================== ##
            net.train()
            epoch_loss_train = 0.0
            for images, masks in train_loader:
                images = images.to(device=device, dtype=torch.float32)
                masks = masks.to(device=device, dtype=torch.float32)
                # forward pass
                image_pred = net(images)
                train_loss = loss_fn(softmaxer(image_pred), softmaxer(masks))
                # backward pass
                optimizer.zero_grad()
                train_loss.backward()
                # optimizing
                optimizer.step()
                epoch_loss_train += train_loss.item()
                sample_count += len(images)
                # Log training loss
                wandb.log({"training_loss": train_loss.item()},step=sample_count)
            # Mean per batch, so values are comparable across batch sizes
            epoch_loss_train /= len(train_loader)
            print(f"Training loss after epoch {epoch+1}: {epoch_loss_train}")
            ## ========================== VALIDATION SECTION ========================== ##
            # eval() switches dropout off; no_grad() only disables autograd
            net.eval()
            epoch_loss_val = 0.0
            with torch.no_grad():
                for images, masks in val_loader:
                    images = images.to(device=device, dtype=torch.float32)
                    masks = masks.to(device=device, dtype=torch.float32)
                    image_pred = net(images)
                    val_loss = loss_fn(softmaxer(image_pred), softmaxer(masks))
                    epoch_loss_val += val_loss.item()
            epoch_loss_val /= len(val_loader)
            # One validation point per epoch; logging every batch at the same
            # step would only keep the last batch's value.
            wandb.log({"validation_loss": epoch_loss_val},step=sample_count)
            print(f"Validation loss after epoch {epoch+1}: {epoch_loss_val}")
            if best_val_loss >= epoch_loss_val:
                best_val_loss = epoch_loss_val
                # Snapshot both states now: net and optimizer keep changing afterwards
                best_model_state = copy.deepcopy(net.state_dict())
                best_optimizer_state = copy.deepcopy(optimizer.state_dict())
                best_param = comb
        run.finish()
    print("TRAINING FINISHED")
    if best_model_state is None:
        raise RuntimeError("No model was validated (empty hyperparameter grid, zero epochs, "
                           "or no full validation batch); nothing to save")
    print("Best validation loss is: ", best_val_loss)
    os.makedirs(model_save_path, exist_ok=True)
    torch.save({'model_state_dict': best_model_state,
                'optimizer_state_dict': best_optimizer_state,
                'param': best_param,
                'loss': best_val_loss}, os.path.join(model_save_path, f"best_model_{struct_types}.pt"))

In [5]:
# ## FUNCTION TO FIND THE BEST SIZE OF INPUT AND OUTPUT
# def find_best_outdim(ROI_names:list, patient_path_list:list):
# outdim_list = []
# CT_paths = []
# dose_paths = []
# sizes = []
# #get all CT paths of all patients
# for patient_path in patient_path_list:
#     subfolder_path = []
#     for roots, dirs, files in os.walk(patient_path):
#         subfolder_path.append(roots)
#     CT_paths += [input_path for input_path in subfolder_path if "/CT/" in input_path]
#     dose_paths += [input_path for input_path in subfolder_path if "/RTDOSE/" in input_path]
# #loop through all CT paths and get sizes of masks
# for path in CT_paths:
#     masks = []
#     #store dicom file
#     contour_file = get_contour_file(path)
#     contour_data = dicom.read_file(path + '/' + contour_file)
#     ROI_list = get_roi_names(contour_data)
#     print(ROI_list)
#     target_ROI_index = [ROI_list.index(r) for r in ROI_names]
#     for index in target_ROI_index:
#         images, contours = get_data(path, index=index)
#         #get contour maps
#         contour_slices = [contours[i] for i in range(contours.shape[0])]
#         contour_3d = [fill_contour(c) if c.max()==1 else c for c in contour_slices]
#         contour_3d = np.stack(contour_3d)
#         masks.append(contour_3d)
#     added_mask = reduce(lambda a, b: a+b, masks) 
#     cropped_added_mask = crop_zeros(added_mask)
#     ## filter sizes of noise data
#     size_original = added_mask.shape[0]*added_mask.shape[1]*added_mask.shape[2]
#     size_cropped = cropped_added_mask.shape[0]*cropped_added_mask.shape[1]*cropped_added_mask.shape[2]
#     if size_original/size_cropped>2:
#         sizes.append(cropped_added_mask.shape)
# for path in dose_paths:
#     file_ids = sk.ImageSeriesReader.GetGDCMSeriesIDs(str(path))
#     file_names = sk.ImageSeriesReader.GetGDCMSeriesFileNames(str(path), file_ids[0])
#     series_reader = sk.ImageSeriesReader()
#     series_reader.SetFileNames(file_names)
#     image_data = series_reader.Execute()
#     dose = sk.GetArrayFromImage(image_data)
#     sizes.append(dose[0].shape)
# #get the maximum x,y,z
# x,y,z = 0,0,0
# x = max([t[0] for t in sizes])
# y = max([t[1] for t in sizes])
# z = max([t[2] for t in sizes])
# return (x,y,z)

In [6]:
## FUNCTION TO FIND THE BEST SIZE OF INPUT AND OUTPUT BASED ON LUNG AND BODY CONTOUR
def find_best_outdim(patient_path_list:list):
    """Find the largest (x, y, z) extent needed to hold every patient's data.

    For every folder below a patient path whose path contains "/CT/", the
    "body" and "lungs-gtv" ROIs are filled and cropped to their non-zero
    bounding box; the recorded size takes its first dimension from the lung
    crop and the other two from the body crop. For every folder whose path
    contains "/RTDOSE/", the shape of the first element of the dose array is
    recorded. The result is the element-wise maximum over all recorded sizes.

    Args:
        patient_path_list: Root directories, one per patient.

    Returns:
        Tuple ``(x, y, z)`` with the maximum of each of the first three
        dimensions over all recorded sizes.

    Raises:
        ValueError: If no CT or RTDOSE folder was found at all.
        IndexError: If a CT folder lacks a matching ROI, or a dose folder
            contains no DICOM series.
    """
    CT_paths = []
    dose_paths = []
    sizes = []
    #get all CT and dose paths of all patients
    for patient_path in patient_path_list:
        subfolder_path = [roots for roots, dirs, files in os.walk(patient_path)]
        CT_paths += [input_path for input_path in subfolder_path if "/CT/" in input_path]
        dose_paths += [input_path for input_path in subfolder_path if "/RTDOSE/" in input_path]
    #loop through all CT paths and get sizes of masks
    for path in CT_paths:
        #read the structure set stored next to the CT slices
        contour_file = get_contour_file(path)
        contour_data = dicom.read_file(path + '/' + contour_file)
        ROI_list = get_roi_names(contour_data)
        body_shape = _cropped_roi_shape(path, ROI_list, "body")
        lung_shape = _cropped_roi_shape(path, ROI_list, "lungs-gtv")
        #first dimension from the lungs, the remaining two from the body outline
        sizes.append((lung_shape[0], body_shape[1], body_shape[2]))
    for path in dose_paths:
        file_ids = sk.ImageSeriesReader.GetGDCMSeriesIDs(str(path))
        if not file_ids:
            raise IndexError(f"No DICOM series found in dose folder {path}")
        file_names = sk.ImageSeriesReader.GetGDCMSeriesFileNames(str(path), file_ids[0])
        series_reader = sk.ImageSeriesReader()
        series_reader.SetFileNames(file_names)
        image_data = series_reader.Execute()
        dose = sk.GetArrayFromImage(image_data)
        sizes.append(dose[0].shape)
    if not sizes:
        raise ValueError("No CT or RTDOSE folders found under the given patient paths")
    #get the maximum x,y,z
    x = max(t[0] for t in sizes)
    y = max(t[1] for t in sizes)
    z = max(t[2] for t in sizes)
    return (x,y,z)


def _cropped_roi_shape(ct_path, roi_names, keyword):
    """Shape of the filled, zero-cropped mask of the first ROI whose name contains ``keyword``."""
    matches = [i for i, roi in enumerate(roi_names) if keyword in roi.lower()]
    if not matches:
        # Same exception type as the bare [0] lookup it replaces, but with context
        raise IndexError(f"No ROI containing '{keyword}' in {ct_path}; available ROIs: {roi_names}")
    _, contours = get_data(ct_path, index=matches[0])
    #fill only slices that actually hold a contour (max == 1)
    filled = [fill_contour(c) if c.max()==1 else c for c in contours]
    return crop_zeros(np.stack(filled)).shape

In [7]:
## FUNCTION TO CROP ZEROS IN AN ARRAY
def crop_zeros(ndarray):
    """Crop an array to the bounding box of its non-zero entries.

    Args:
        ndarray: N-dimensional numpy array.

    Returns:
        The sub-array spanning, along every axis, the first to the last index
        that holds a non-zero value. If the input has no non-zero entries, an
        empty array of the same rank is returned (every dimension has length 0).
    """
    valid_data_coords = np.argwhere(ndarray)
    if len(valid_data_coords) == 0:
        # Nothing to keep: min()/max() below would fail on an empty coordinate list
        return ndarray[tuple(slice(0, 0) for _ in range(valid_data_coords.shape[1]))]
    begin_nd_corners = valid_data_coords.min(axis=0)
    # +1 because slice ends are exclusive
    finish_nd_corners = valid_data_coords.max(axis=0) + 1
    ndslice = tuple(slice(begin, finish) for (begin, finish) in zip(begin_nd_corners, finish_nd_corners))
    return ndarray[ndslice]

In [9]:
# Root folder holding one sub-folder per patient.
# NOTE(review): absolute, machine-specific path -- point this at your own data directory.
top_dir = "/Users/wangyangwu/Documents/Maastro/NeuralNets/sample"
# os.listdir returns entries in arbitrary order; sorting keeps the patient list
# (and presumably the seeded train/validation split built from it) reproducible.
pathlist_patients = sorted(
    os.path.join(top_dir, folder)
    for folder in os.listdir(top_dir)
    if os.path.isdir(os.path.join(top_dir, folder))
)
print("Looking for the best output dimension ...")
out_dim = find_best_outdim(pathlist_patients)
print("Best dimension is : ",out_dim)

Looking for the best output dimension ...
Best dimension is :  (117, 295, 511)


In [None]:
# Structures fed to the network as extra input channels
ROI_names = ["Heart"]
# Hyperparameter grid: every combination of these lists is trained once
epochs = [1]
batch_size = [1]
dropout_rate = [0.5]
learning_rate = [0.1]
# Run the grid search on CPU using the patients and output size found above
train(
    device="cpu",
    input_dir=pathlist_patients,
    struct_types=ROI_names,
    epochs=epochs,
    learning_rate=learning_rate,
    dropout_rate=dropout_rate,
    batch_size=batch_size,
    out_dim=out_dim,
)