In [50]:
## TORCH LIBRARY
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, random_split

## MONAI LIBRARY
from monai.networks.nets import BasicUNet as BU
from monai.losses.dice import DiceLoss

## OTHER LIBRARIES
import argparse
import logging
import sys
import numpy as np
from pathlib import Path
import os
import itertools
import matplotlib.pyplot as plt
from dicom_contour.contour import *
from functools import reduce

## WEIDGHTS & BIASES
import wandb
os.environ["WANDB_CONFIG_DIR"] = "/tmp"
from tqdm import tqdm

## IMPORT OTHER CLASSES
import import_ipynb
import DosePredictionDataset

## IGNORE WARNINGS
import warnings
warnings.filterwarnings('ignore')

In [2]:
## FUNCTION FOR TRAINING
def train_net(device:str,
              input_dir:list,
              struct_types:list,
              param_comb:list,
              batch_size:int,
              out_dim:tuple,
              val_percent: float = 0.1,
              ):

    ## create dataloaders for training and validation
    dataset = DosePredictionDataset.DosePrdictionDataset(input_dir, 
                                                         struct_types,
                                                         out_dim)
    val_num = int(len(dataset) * val_percent)
    train_num = len(dataset) - val_num
    train_set, val_set = random_split(dataset, 
                                      [train_num, val_num], 
                                      generator=torch.Generator().manual_seed(42))    
    loader_args = dict(batch_size=batch_size, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)
    
    wandb.login() ## login wandb account
    
    ## Experiment with different combinations of hyperparameters 
    for comb in param_comb:     
        with wandb.init(
            project=f"Basic UNet (dropout-{comb[0]} & lr-{comb[1]} & epochs-{comb[2]})",
            config={
                "Architecture": "Basic UNet",
                "epochs": int(comb[2]),
                "batch_size": batch_size,
                "learning_rate": comb[1],
                "dropout": comb[0]
                }):
            config = wandb.config
            
            # create neural network model
            net = BU(spatial_dims=3,
                     in_channels=len(struct_types)+1, 
                     out_channels=1, 
                     features=(6, 16, 32, 64, 128, 16),
                     dropout=config.dropout)
            
            wandb.watch(net, log='all', log_freq=1)
            
            # set up optimizer
            optimizer = torch.optim.Adam(net.parameters(), 
                                      lr=config.learning_rate)

            sample_count = 0
            train_dice_loss = []
            val_dice_loss = []
            softmaxer = nn.Softmax(dim=2) #softmax function to apply to inputs
            
            # Training begins
            print("Training begins:")
            for epoch in tqdm(range(config.epochs)):
                epoch_loss_train = 0
                epoch_loss_val = 0

                # ======== TRAIN SECTION ========
                net.train()
                for images, masks in train_loader:
                    images = images.to(device=device, dtype=torch.float32)
                    masks = masks.to(device=device, dtype=torch.float32)
                    print("input loaded:")
                    # forward pass
                    image_pred = net(images)
                    train_loss = DiceLoss().forward(softmaxer(image_pred),
                                                    softmaxer(masks))
                    epoch_loss_train += train_loss.item()
                    print("foward finished")
                    # backword pass
                    optimizer.zero_grad()
                    train_loss.backward()
                    # optimizing
                    optimizer.step()
                    sample_count += len(images)
                # logging
                wandb.log({"train_epoch_loss": epoch_loss_train, 
                           "epoch": sample_count/len(train_loader.dataset)}, 
                          step=sample_count)
                print(f"Training loss after epoch {epoch+1}: {epoch_loss_train}")
                train_dice_loss.append(epoch_loss_train)

                # ======== VALIDATION SECTION ========
                with torch.no_grad():
                    for images, masks in val_loader:
                        images = images.to(device=device, dtype=torch.float32)
                        masks = masks.to(device=device, dtype=torch.float32)
                        image_pred = net(images)
                        loss = DiceLoss().forward(softmaxer(image_pred),
                                                  softmaxer(masks))
                        epoch_loss_val += loss
                val_dice_loss.append(epoch_loss_val) 
                wandb.log({"val_epoch_loss": epoch_loss_val}, 
                          step=sample_count)

In [65]:
## FUNCTION TO FIND THE BEST SIZE OF INPUT AND OUTPUT
def find_best_outdim(ROI_names:list, patient_path_list:list):
    outdim_list = []
    CT_paths = [] 
    sizes = []
    #get all CT paths of all patients
    for patient_path in patient_path_list:
        subfolder_path = []
        for roots, dirs, files in os.walk(patient_path):
            subfolder_path.append(roots)
        CT_paths += [input_path for input_path in subfolder_path if "/CT/" in input_path]
    #loop through all CT paths and get sizes of masks
    for path in CT_paths:
        masks = []
        #store dicom file
        contour_file = get_contour_file(path)
        contour_data = dicom.read_file(path + '/' + contour_file)
        ROI_list = get_roi_names(contour_data)
        target_ROI_index = [ROI_list.index(r) for r in ROI_names]
        images, contours = get_data(path, index=target_ROI_index[0])
        for index in target_ROI_index:
            images, contours = get_data(path, index=10)
            #get contour maps
            contour_slices = [contours[i] for i in range(contours.shape[0])]
            contour_3d = [fill_contour(c) if c.max()==1 else c for c in contour_slices]
            contour_3d = np.stack(contour_3d)
            masks.append(contour_3d)
        added_mask = reduce(lambda a, b: a+b, masks) 
        cropped_added_mask = crop_zeros(added_mask)
        sizes.append(cropped_added_mask.shape)
    #get the maximum x,y,z
    x,y,z = 0,0,0
    x = max([t[0] for t in sizes])
    y = max([t[1] for t in sizes])
    z = max([t[2] for t in sizes])
    return (x,y,z)

In [66]:
## FUNCTION TO CROP ZEROS IN AN ARRAY
def crop_zeros(ndarray):
    valid_data_coords = np.argwhere(ndarray)
    begin_nd_corners = valid_data_coords.min(axis=0)
    finish_nd_corners = valid_data_coords.max(axis=0) + 1
    ndslice = tuple(slice(begin, finish) for (begin, finish) in zip(begin_nd_corners, finish_nd_corners))
    return ndarray[ndslice]

In [67]:
## MAIN METHOD
def run():
    top_dir = "/Users/wangyangwu/Documents/Maastro/NeuralNets/sample2"
    pathlist_patients = []
    for folder in os.listdir(top_dir):
        folder_path = os.path.join(top_dir, folder)
        if os.path.isdir(folder_path):
            pathlist_patients.append(folder_path)
            
    ROI_names = ["Heart", "Lungs-GTV"]       
    out_dim = find_best_outdim(ROI_names,pathlist_patients)
            
    dropout_rate = np.arange(0.5)
    learning_rate = np.array([0.1])
    epochs = np.array([1])
    param_comb = list(itertools.product(dropout_rate, learning_rate, epochs))
    batch_size = 1
    
    train_net("cpu", pathlist_patients, ROI_names, param_comb, batch_size, out_dim) 

In [68]:
run()

BasicUNet features: (6, 16, 32, 64, 128, 16).
Training begins:


  0%|                                                     | 0/1 [00:00<?, ?it/s]

SC (166, 512, 586)
SM: (166, 512, 586)
X: torch.Size([3, 32, 89, 107]) y: torch.Size([1, 32, 89, 107])
input loaded:
foward finished


100%|█████████████████████████████████████████████| 1/1 [00:07<00:00,  7.24s/it]

Training loss after epoch 1: 0.969220757484436





VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,▁
train_epoch_loss,▁
val_epoch_loss,▁

0,1
epoch,1.0
train_epoch_loss,0.96922
val_epoch_loss,0.0
