In [1]:
import logging
import os
import shutil
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import NiftiDataset, create_test_image_3d, list_data_collate
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import \
    Compose, LoadNiftid, AddChanneld, ScaleIntensityRanged, CropForegroundd, \
    RandCropByPosNegLabeld, RandSpatialCropd, RandAffined, Spacingd, Orientationd, ToTensord
from monai.visualize import plot_2d_or_3d_image

from torch import nn
monai.config.print_config()

from typing import Optional, Union

import warnings

from monai.networks import one_hot
from monai.utils import MetricReduction

MONAI version: 0.1.0+626.g63eec3a.dirty
Python version: 3.7.4 (default, Jul 18 2019, 19:34:02)  [GCC 5.4.0]
Numpy version: 1.18.1
Pytorch version: 1.5.0

Optional dependencies:
Pytorch Ignite version: 0.3.0
Nibabel version: 3.1.0
scikit-image version: 0.14.2
Pillow version: 7.0.0
Tensorboard version: 2.1.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



In [10]:
print('EU toy data 4-fold segmentation for micro-challenge: fold-2')
import glob

# Supervised learning data for training and validation
data_dir = '/home/marafath/scratch/eu_challenge/synthetic_data'
train_images = sorted(glob.glob(os.path.join(data_dir, '*_vol.nii.gz')))
train_labels = sorted(glob.glob(os.path.join(data_dir, '*_labels.nii.gz')))
data_dicts = [{'image': image_name, 'label': label_name}
              for image_name, label_name in zip(train_images, train_labels)]

fold = 2
epc = 100
train_files = data_dicts[0:5] #data_dicts[0:48] + data_dicts[72:96]
val_files = data_dicts[48:52] #data_dicts[48:72]

EU toy data 4-fold segmentation for micro-challenge: fold-2


In [11]:
print(len(train_files), len(val_files))

5 4


In [19]:
# Defining Transform
train_transforms = Compose([
    LoadNiftid(keys=['image', 'label']),
    AddChanneld(keys=['image', 'label']),
    Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 1.5), mode=('bilinear', 'nearest')),
    Orientationd(keys=['image', 'label'], axcodes='RAS'),
    ScaleIntensityRanged(keys=['image'], a_min=-1250, a_max=250, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=['image', 'label'], source_key='image'),
    RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', spatial_size=(96, 96, 96), pos=1,
                       neg=1, num_samples=4, image_key='image', image_threshold=0),
    ToTensord(keys=['image', 'label'])
])
val_transforms = Compose([
    LoadNiftid(keys=['image', 'label']),
    AddChanneld(keys=['image', 'label']),
    Spacingd(keys=['image', 'label'], pixdim=(1.5, 1.5, 1.5), mode=('bilinear', 'nearest')),
    Orientationd(keys=['image', 'label'], axcodes='RAS'),
    ScaleIntensityRanged(keys=['image'], a_min=-1250, a_max=250, b_min=0.0, b_max=1.0, clip=True),
    CropForegroundd(keys=['image', 'label'], source_key='image'),
    RandCropByPosNegLabeld(keys=['image', 'label'], label_key='label', spatial_size=(96, 96, 96), pos=1,
                       neg=1, num_samples=4, image_key='image', image_threshold=0),
    ToTensord(keys=['image', 'label'])
])

In [13]:
# Data loader
train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(
        train_ds,
        batch_size=1,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available()
)

val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(
        val_ds, 
        batch_size=1, 
        num_workers=4, 
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available()
)

In [14]:
dice_metric = DiceMetric(include_background=False, to_onehot_y=True, mutually_exclusive = True, sigmoid=False, reduction="mean")

In [8]:
# Defining model and hyperparameters
device = torch.device("cuda:0")
model = monai.networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    out_channels=7,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2
).to(device)
    
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)

loss_function = DC_and_CE_loss({'smooth': 1e-5, 'do_bg': False}, {})
optimizer = torch.optim.Adam(model.parameters(), 1e-2)

In [20]:
# start a typical PyTorch training
epc = 10
val_interval = 1
best_metric = 1e10
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
writer = SummaryWriter()
for epoch in range(epc):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{epc}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = batch_data["image"].to(device), batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_len = len(train_ds) // train_loader.batch_size
        print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            metric_sum = 0.0
            metric_count = 0
            val_images = None
            val_labels = None
            val_outputs = None
            for val_data in val_loader:
                val_images, val_labels = val_data["image"].to(device), val_data["label"].to(device)
                outputs_ = model(val_images)
                loss_ = loss_function(outputs_, val_labels)
                metric_sum += loss_.item()
            metric = metric_sum / len(val_ds)
            metric_values.append(metric)
            if metric < best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "best_metric_model.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current val loss: {:.4f} best val loss: {:.4f} at epoch {}".format(
                    epoch + 1, metric, best_metric, best_metric_epoch
                )
            )
            writer.add_scalar("val_mean_dice", metric, epoch + 1)

----------
epoch 1/10
1/5, train_loss: 0.4586
2/5, train_loss: 0.4025
3/5, train_loss: 0.4509
4/5, train_loss: 0.5433
5/5, train_loss: 0.6791
epoch 1 average loss: 0.5069
saved new best metric model
current epoch: 1 current val loss: 0.4580 best val loss: 0.4580 at epoch 1
----------
epoch 2/10
1/5, train_loss: 0.5453
2/5, train_loss: 0.4328
3/5, train_loss: 0.3945
4/5, train_loss: 0.5480
5/5, train_loss: 0.5280
epoch 2 average loss: 0.4897
saved new best metric model
current epoch: 2 current val loss: 0.4494 best val loss: 0.4494 at epoch 2
----------
epoch 3/10
1/5, train_loss: 0.4621
2/5, train_loss: 0.5617
3/5, train_loss: 0.4819
4/5, train_loss: 0.5317
5/5, train_loss: 0.7715
epoch 3 average loss: 0.5618
saved new best metric model
current epoch: 3 current val loss: 0.4171 best val loss: 0.4171 at epoch 3
----------
epoch 4/10
1/5, train_loss: 0.4356
2/5, train_loss: 0.5337
3/5, train_loss: 0.4557
4/5, train_loss: 0.5702
5/5, train_loss: 0.4497
epoch 4 average loss: 0.4890
current

KeyboardInterrupt: 