In [1]:
## TORCH LIBRARY
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader, random_split

## MONAI LIBRARY
from monai.networks.nets import BasicUNet as BU
from monai.losses.dice import DiceLoss

## OTHER LIBRARIES
import argparse
import logging
import sys
import numpy as np
from pathlib import Path
import os
import itertools
import matplotlib.pyplot as plt

## WEIGHTS & BIASES
import wandb
os.environ["WANDB_CONFIG_DIR"] = "/tmp"
from tqdm import tqdm

## IMPORT OTHER CLASSES
import import_ipynb
import DosePredictionDataset

## IGNORE WARNINGS
import warnings
warnings.filterwarnings('ignore')

  warn(f"Failed to load image Python extension: {e}")


importing Jupyter notebook from DosePredictionDataset.ipynb


In [2]:
## FUNCTION FOR TRAINING
def train_net(device:str,
              input_dir:list,
              struct_types:list,
              param_comb:list,
              batch_size:int,
              val_percent: float = 0.1,
              transform=None
              ):

    ## create dataloaders for training and validation
    dataset = DosePredictionDataset.DosePrdictionDataset(input_dir, 
                                                         struct_types)
    val_num = int(len(dataset) * val_percent)
    train_num = len(dataset) - val_num
    train_set, val_set = random_split(dataset, 
                                      [train_num, val_num], 
                                      generator=torch.Generator().manual_seed(42))    
    loader_args = dict(batch_size=batch_size, pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)
    
    wandb.login() ## login wandb account
    
    ## Experiment with different combinations of hyperparameters 
    for comb in param_comb:     
        with wandb.init(
            project=f"Basic UNet (dropout-{comb[0]} & lr-{comb[1]} & epochs-{comb[2]})",
            config={
                "Architecture": "Basic UNet",
                "epochs": int(comb[2]),
                "batch_size": batch_size,
                "learning_rate": comb[1],
                "dropout": comb[0]
                }):
            config = wandb.config
            
            # create neural network model
            net = BU(spatial_dims=3,
                     in_channels=len(struct_types)+1, 
                     out_channels=1, 
                     features=(6, 16, 32, 64, 128, 16),
                     dropout=config.dropout)
            
            wandb.watch(net, log='all', log_freq=10)
            
            # set up optimizer
            optimizer = torch.optim.Adam(net.parameters(), 
                                      lr=config.learning_rate)

            sample_count = 0
            train_dice_loss = []
            val_dice_loss = []
            softmaxer = nn.Softmax(dim=2) #softmax function to apply to inputs
            
            # Training begins
            for epoch in tqdm(range(config.epochs)):
                epoch_loss_train = 0
                epoch_loss_val = 0
                
                # ======== TRAIN SECTION ========
                net.train()
                for images, masks in train_loader:
                    images = images.to(device=device, dtype=torch.float32)
                    masks = masks.to(device=device, dtype=torch.float32)
                    # forward pass
                    image_pred = net(images)
                    train_loss = DiceLoss().forward(softmaxer(image_pred),
                                                    softmaxer(masks))
                    epoch_loss_train += train_loss.item()
                    # backword pass
                    optimizer.zero_grad()
                    train_loss.backward()
                    # optimizing
                    optimizer.step()
                    sample_count += len(images)
                # logging
                wandb.log({"train_epoch_loss": epoch_loss_train, 
                           "epoch": sample_count/len(train_loader.dataset)}, 
                          step=sample_count)
                print(f"Training loss after epoch {epoch+1}: {epoch_loss_train}")
                train_dice_loss.append(epoch_loss_train)

                # ======== VALIDATION SECTION ========
                with torch.no_grad():
                    for images, masks in val_loader:
                        images = images.to(device=device, dtype=torch.float32)
                        masks = masks.to(device=device, dtype=torch.float32)
                        image_pred = net(images)
                        loss = DiceLoss().forward(softmaxer(image_pred),
                                                  softmaxer(masks))
                        epoch_loss_val += loss
                val_dice_loss.append(epoch_loss_val) 
                wandb.log({"val_epoch_loss": epoch_loss_val}, 
                          step=sample_count)

In [3]:
## MAIN METHOD
def run():
    top_dir = "/Users/wangyangwu/Documents/Maastro/NeuralNets/PROTON"
    pathlist_patients = []
    for folder in os.listdir(top_dir):
        folder_path = os.path.join(top_dir, folder)
        if os.path.isdir(folder_path):
            pathlist_patients.append(folder_path)
            
    dropout_rate = np.arange(0.5, 0.6, 0.1)
    learning_rate = np.array([0.0001])
    epochs = np.array([1])
    param_comb = list(itertools.product(dropout_rate, learning_rate, epochs))
    batch_size = 1
    
    train_net("cpu", pathlist_patients, ["Heart"], param_comb, batch_size) 

In [4]:
# Launch the sweep: run() lists the patient folders, builds the hyperparameter
# grid and calls train_net, which logs every run to Weights & Biases.
run()

[34m[1mwandb[0m: Currently logged in as: [33mwwy[0m (use `wandb login --relogin` to force relogin)


BasicUNet features: (6, 16, 32, 64, 128, 16).


  0%|                                                     | 0/1 [00:02<?, ?it/s]

(104, 108, 143)





VBox(children=(Label(value=' 0.00MB of 0.00MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

[34m[1mwandb[0m: [32m[41mERROR[0m Control-C detected -- Run data was not synced


TypeError: can't convert np.ndarray of type numpy.uint32. The only supported types are: float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.