# 3D Segmentation with UNet

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Project-MONAI/tutorials/blob/master/3d_segmentation/unet_segmentation_3d_ignite.ipynb)

## Setup environment

In [1]:
!python -c "import monai" || pip install -q "monai-weekly[ignite, nibabel, tensorboard]"

## Setup imports

In [2]:
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import glob
import logging
import os
import shutil
import sys
import tempfile

import nibabel as nib
import numpy as np
from monai.config import print_config
from monai.data import ArrayDataset, create_test_image_3d
from monai.handlers import (
    MeanDice,
    StatsHandler,
    TensorBoardImageHandler,
    TensorBoardStatsHandler,
)
from monai.apps import download_and_extract
from monai.losses import DiceLoss
from monai.networks.nets import UNet
from monai.transforms import (
    Activations,
    AddChannel,
    AsDiscrete,
    Compose,
    Spacing,
    LoadImage,
    RandSpatialCrop,
    Resize,
    ScaleIntensity,
    ToTensor,
)
from monai.utils import first

import ignite
import torch

print_config()

MONAI version: 0.3.0+90.g59918c4
Python version: 3.7.9 (default, Aug 31 2020, 12:42:55)  [GCC 7.3.0]
OS version: Linux (4.15.0-140-generic)
Numpy version: 1.19.4
Pytorch version: 1.6.0
MONAI flags: HAS_EXT = False, USE_COMPILED = False

Optional dependencies:
Pytorch Ignite version: 0.4.2
Nibabel version: 3.2.0
scikit-image version: 0.17.2
Pillow version: 8.0.1
Tensorboard version: 2.4.0
gdown version: 3.12.2
TorchVision version: 0.8.1
ITK version: 5.1.1
tqdm version: 4.53.0
lmdb version: 1.0.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



## Setup data directory

You can specify a directory with the `MONAI_DATA_DIRECTORY` environment variable.  
This allows you to save results and reuse downloads.  
If not specified a temporary directory will be used.

In [3]:
directory = os.environ.get("MONAI_DATA_DIRECTORY")
root_dir = tempfile.mkdtemp() if directory is None else directory
print(root_dir)

/tmp/tmp3vdfa64m


In [4]:
path_to_target = "/home/alien/dls/nifti_segmentation_nhk/new_result/mri_shoulder/"
path_to_file_list = glob.glob(path_to_target + "train/" + "*.nii" )

#for path_to_file in path_to_file_list:
#    print(path_to_file)
    
train_images = sorted(glob.glob(os.path.join(path_to_target, "train", "*.nii")))
tgt_images = sorted(glob.glob(os.path.join(path_to_target, "train_gt", "*.nii.gz")))
val_images = sorted(glob.glob(os.path.join(path_to_target, "val", "*.nii")))
valgt_images = sorted(glob.glob(os.path.join(path_to_target, "val_gt", "*.nii.gz")))

## Setup logging

In [5]:
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

## Setup transforms, dataset

In [6]:
# Define transforms for image and segmentation
imtrans = Compose(
    [
        LoadImage(image_only=True),
        ScaleIntensity(),
        AddChannel(),
        RandSpatialCrop((128, 128, 128), random_size=False),
        ToTensor(),
    ]
)
segtrans = Compose(
    [
        LoadImage(image_only=True),
        AddChannel(),
        RandSpatialCrop((128, 128, 128), random_size=False),
        ToTensor(),
    ]
)

# Define nifti dataset, dataloader
ds = ArrayDataset(train_images, imtrans, tgt_images, segtrans)
loader = torch.utils.data.DataLoader(
    ds, batch_size=2, num_workers=2, pin_memory=torch.cuda.is_available()
)
im, seg = first(loader)
print(im.shape, seg.shape)


torch.Size([2, 1, 128, 128, 128]) torch.Size([2, 1, 128, 128, 128])


In [7]:
print(np.unique(seg))

[0. 1. 2. 3. 4.]


## Create Model, Loss, Optimizer

In [8]:
# Create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda:0")
net = UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2).to(device)

loss = DiceLoss(sigmoid=True)
lr = 1e-3
opt = torch.optim.Adam(net.parameters(), lr)

## Create supervised_trainer using ignite

In [9]:
# Create trainer
trainer = ignite.engine.create_supervised_trainer(
    net, opt, loss, device, False
)

## Setup event handlers for checkpointing and logging

In [10]:
# optional section for checkpoint and tensorboard logging
# adding checkpoint handler to save models (network
# params and optimizer stats) during training
log_dir = os.path.join(root_dir, "logs")
checkpoint_handler = ignite.handlers.ModelCheckpoint(
    log_dir, "net", n_saved=10, require_empty=False
)
trainer.add_event_handler(
    event_name=ignite.engine.Events.EPOCH_COMPLETED,
    handler=checkpoint_handler,
    to_save={"net": net, "opt": opt},
)

# StatsHandler prints loss at every iteration
# and print metrics at every epoch,
# we don't set metrics for trainer here, so just print
# loss, user can also customize print functions
# and can use output_transform to convert
# engine.state.output if it's not a loss value
train_stats_handler = StatsHandler(name="trainer")
train_stats_handler.attach(trainer)

# TensorBoardStatsHandler plots loss at every iteration
# and plots metrics at every epoch, same as StatsHandler
train_tensorboard_stats_handler = TensorBoardStatsHandler(log_dir=log_dir)
train_tensorboard_stats_handler.attach(trainer)

## Add Validation every N epochs

In [11]:
# optional section for model validation during training
validation_every_n_epochs = 1
# Set parameters for validation
metric_name = "Mean_Dice"
# add evaluation metric to the evaluator engine
val_metrics = {metric_name: MeanDice()}
post_pred = Compose(
    [Activations(sigmoid=True), AsDiscrete(threshold_values=True)]
)
post_label = AsDiscrete(threshold_values=True)
# Ignite evaluator expects batch=(img, seg) and
# returns output=(y_pred, y) at every iteration,
# user can add output_transform to return other values
evaluator = ignite.engine.create_supervised_evaluator(
    net,
    val_metrics,
    device,
    True,
    output_transform=lambda x, y, y_pred: (post_pred(y_pred), post_label(y)),
)

# create a validation data loader
val_imtrans = Compose(
    [
        LoadImage(image_only=True),
        ScaleIntensity(),
        AddChannel(),
        Resize((128, 128, 128)),
        ToTensor(),
    ]
)
val_segtrans = Compose(
    [
        LoadImage(image_only=True),
        AddChannel(),
        Resize((128, 128, 128)),
        ToTensor(),
    ]
)
val_ds = ArrayDataset(val_images[:6], val_imtrans, valgt_images[:6], val_segtrans)
val_loader = torch.utils.data.DataLoader(
    val_ds, batch_size=1, num_workers=2, pin_memory=torch.cuda.is_available()
)


@trainer.on(
    ignite.engine.Events.EPOCH_COMPLETED(every=validation_every_n_epochs)
)
def run_validation(engine):
    evaluator.run(val_loader)


# Add stats event handler to print validation stats via evaluator
val_stats_handler = StatsHandler(
    name="evaluator",
    # no need to print loss value, so disable per iteration output
    output_transform=lambda x: None,
    # fetch global epoch number from trainer
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_stats_handler.attach(evaluator)

# add handler to record metrics to TensorBoard at every validation epoch
val_tensorboard_stats_handler = TensorBoardStatsHandler(
    log_dir=log_dir,
    # no need to plot loss value, so disable per iteration output
    output_transform=lambda x: None,
    # fetch global epoch number from trainer
    global_epoch_transform=lambda x: trainer.state.epoch,
)
val_tensorboard_stats_handler.attach(evaluator)

# add handler to draw the first image and the corresponding
# label and model output in the last batch
# here we draw the 3D output as GIF format along Depth
# axis, at every validation epoch
val_tensorboard_image_handler = TensorBoardImageHandler(
    log_dir=log_dir,
    batch_transform=lambda batch: (batch[0], batch[1]),
    output_transform=lambda output: output[0],
    global_iter_transform=lambda x: trainer.state.epoch,
)
evaluator.add_event_handler(
    event_name=ignite.engine.Events.EPOCH_COMPLETED,
    handler=val_tensorboard_image_handler,
)

<ignite.engine.events.RemovableEventHandle at 0x7fa3de082590>

## Run training loop

In [12]:
# create a training data loader
train_ds = ArrayDataset(train_images[:21], imtrans, tgt_images[:21], segtrans)
train_loader = torch.utils.data.DataLoader(
    train_ds,
    batch_size=2,
    shuffle=True,
    num_workers=2,
    pin_memory=torch.cuda.is_available(),
)

max_epochs = 50
state = trainer.run(train_loader, max_epochs)

INFO:ignite.engine.engine.Engine:Engine run starting with max_epochs=50.
INFO:trainer:Epoch: 1/50, Iter: 1/11 -- Loss: 0.9496 
INFO:trainer:Epoch: 1/50, Iter: 2/11 -- Loss: 0.8604 
INFO:trainer:Epoch: 1/50, Iter: 3/11 -- Loss: 0.5827 
INFO:trainer:Epoch: 1/50, Iter: 4/11 -- Loss: 0.6340 
INFO:trainer:Epoch: 1/50, Iter: 5/11 -- Loss: 0.5287 
INFO:trainer:Epoch: 1/50, Iter: 6/11 -- Loss: 0.5382 
INFO:trainer:Epoch: 1/50, Iter: 7/11 -- Loss: 0.7991 
INFO:trainer:Epoch: 1/50, Iter: 8/11 -- Loss: 0.5254 
INFO:trainer:Epoch: 1/50, Iter: 9/11 -- Loss: 0.5878 
INFO:trainer:Epoch: 1/50, Iter: 10/11 -- Loss: 0.5806 
INFO:trainer:Epoch: 1/50, Iter: 11/11 -- Loss: 0.6297 
INFO:ignite.engine.engine.Engine:Engine run starting with max_epochs=1.
INFO:evaluator:Epoch[1] Metrics -- Mean_Dice: 0.3397 
INFO:evaluator:Epoch[1] Metrics -- Mean_Dice: 0.3397 
INFO:ignite.engine.engine.Engine:Epoch[1] Complete. Time taken: 00:00:10
INFO:ignite.engine.engine.Engine:Engine run complete. Time taken: 00:00:10
INF

INFO:trainer:Epoch: 9/50, Iter: 4/11 -- Loss: 0.4970 
INFO:trainer:Epoch: 9/50, Iter: 5/11 -- Loss: 0.4349 
INFO:trainer:Epoch: 9/50, Iter: 6/11 -- Loss: 0.3717 
INFO:trainer:Epoch: 9/50, Iter: 7/11 -- Loss: 0.4042 
INFO:trainer:Epoch: 9/50, Iter: 8/11 -- Loss: 0.4874 
INFO:trainer:Epoch: 9/50, Iter: 9/11 -- Loss: 0.5293 
INFO:trainer:Epoch: 9/50, Iter: 10/11 -- Loss: 0.5326 
INFO:trainer:Epoch: 9/50, Iter: 11/11 -- Loss: 0.0822 
INFO:ignite.engine.engine.Engine:Engine run starting with max_epochs=1.
INFO:evaluator:Epoch[9] Metrics -- Mean_Dice: 0.6117 
INFO:evaluator:Epoch[9] Metrics -- Mean_Dice: 0.6117 
INFO:ignite.engine.engine.Engine:Epoch[1] Complete. Time taken: 00:00:10
INFO:ignite.engine.engine.Engine:Engine run complete. Time taken: 00:00:10
INFO:ignite.engine.engine.Engine:Epoch[9] Complete. Time taken: 00:00:40
INFO:trainer:Epoch: 10/50, Iter: 1/11 -- Loss: 0.8746 
INFO:trainer:Epoch: 10/50, Iter: 2/11 -- Loss: 0.9117 
INFO:trainer:Epoch: 10/50, Iter: 3/11 -- Loss: 0.4045 


INFO:trainer:Epoch: 17/50, Iter: 7/11 -- Loss: 0.7153 
INFO:trainer:Epoch: 17/50, Iter: 8/11 -- Loss: 0.3958 
INFO:trainer:Epoch: 17/50, Iter: 9/11 -- Loss: 0.3535 
INFO:trainer:Epoch: 17/50, Iter: 10/11 -- Loss: 0.3328 
INFO:trainer:Epoch: 17/50, Iter: 11/11 -- Loss: 0.0244 
INFO:ignite.engine.engine.Engine:Engine run starting with max_epochs=1.
INFO:evaluator:Epoch[17] Metrics -- Mean_Dice: 0.6543 
INFO:evaluator:Epoch[17] Metrics -- Mean_Dice: 0.6543 
INFO:ignite.engine.engine.Engine:Epoch[1] Complete. Time taken: 00:00:10
INFO:ignite.engine.engine.Engine:Engine run complete. Time taken: 00:00:10
INFO:ignite.engine.engine.Engine:Epoch[17] Complete. Time taken: 00:00:39
INFO:trainer:Epoch: 18/50, Iter: 1/11 -- Loss: 0.9656 
INFO:trainer:Epoch: 18/50, Iter: 2/11 -- Loss: 0.9071 
INFO:trainer:Epoch: 18/50, Iter: 3/11 -- Loss: 0.4434 
INFO:trainer:Epoch: 18/50, Iter: 4/11 -- Loss: 0.4003 
INFO:trainer:Epoch: 18/50, Iter: 5/11 -- Loss: 0.4217 
INFO:trainer:Epoch: 18/50, Iter: 6/11 -- Los

INFO:trainer:Epoch: 25/50, Iter: 9/11 -- Loss: 0.3639 
INFO:trainer:Epoch: 25/50, Iter: 10/11 -- Loss: 0.3802 
INFO:trainer:Epoch: 25/50, Iter: 11/11 -- Loss: 0.2712 
INFO:ignite.engine.engine.Engine:Engine run starting with max_epochs=1.
INFO:evaluator:Epoch[25] Metrics -- Mean_Dice: 0.7000 
INFO:evaluator:Epoch[25] Metrics -- Mean_Dice: 0.7000 
INFO:ignite.engine.engine.Engine:Epoch[1] Complete. Time taken: 00:00:10
INFO:ignite.engine.engine.Engine:Engine run complete. Time taken: 00:00:10
INFO:ignite.engine.engine.Engine:Epoch[25] Complete. Time taken: 00:00:38
INFO:trainer:Epoch: 26/50, Iter: 1/11 -- Loss: 0.9608 
INFO:trainer:Epoch: 26/50, Iter: 2/11 -- Loss: 0.9235 
INFO:trainer:Epoch: 26/50, Iter: 3/11 -- Loss: 0.4493 
INFO:trainer:Epoch: 26/50, Iter: 4/11 -- Loss: 0.3706 
INFO:trainer:Epoch: 26/50, Iter: 5/11 -- Loss: 0.4234 
INFO:trainer:Epoch: 26/50, Iter: 6/11 -- Loss: 0.3719 
INFO:trainer:Epoch: 26/50, Iter: 7/11 -- Loss: 0.5037 
INFO:trainer:Epoch: 26/50, Iter: 8/11 -- Los

INFO:trainer:Epoch: 33/50, Iter: 11/11 -- Loss: 0.3648 
INFO:ignite.engine.engine.Engine:Engine run starting with max_epochs=1.
INFO:evaluator:Epoch[33] Metrics -- Mean_Dice: 0.6977 
INFO:evaluator:Epoch[33] Metrics -- Mean_Dice: 0.6977 
INFO:ignite.engine.engine.Engine:Epoch[1] Complete. Time taken: 00:00:10
INFO:ignite.engine.engine.Engine:Engine run complete. Time taken: 00:00:10
INFO:ignite.engine.engine.Engine:Epoch[33] Complete. Time taken: 00:00:39
INFO:trainer:Epoch: 34/50, Iter: 1/11 -- Loss: 0.9149 
INFO:trainer:Epoch: 34/50, Iter: 2/11 -- Loss: 0.7360 
INFO:trainer:Epoch: 34/50, Iter: 3/11 -- Loss: 0.4662 
INFO:trainer:Epoch: 34/50, Iter: 4/11 -- Loss: 0.5717 
INFO:trainer:Epoch: 34/50, Iter: 5/11 -- Loss: 0.3500 
INFO:trainer:Epoch: 34/50, Iter: 6/11 -- Loss: 0.3444 
INFO:trainer:Epoch: 34/50, Iter: 7/11 -- Loss: 0.5242 
INFO:trainer:Epoch: 34/50, Iter: 8/11 -- Loss: 0.5087 
INFO:trainer:Epoch: 34/50, Iter: 9/11 -- Loss: 0.3097 
INFO:trainer:Epoch: 34/50, Iter: 10/11 -- Los

INFO:evaluator:Epoch[41] Metrics -- Mean_Dice: 0.7144 
INFO:evaluator:Epoch[41] Metrics -- Mean_Dice: 0.7144 
INFO:ignite.engine.engine.Engine:Epoch[1] Complete. Time taken: 00:00:10
INFO:ignite.engine.engine.Engine:Engine run complete. Time taken: 00:00:10
INFO:ignite.engine.engine.Engine:Epoch[41] Complete. Time taken: 00:00:40
INFO:trainer:Epoch: 42/50, Iter: 1/11 -- Loss: 0.8586 
INFO:trainer:Epoch: 42/50, Iter: 2/11 -- Loss: 0.6678 
INFO:trainer:Epoch: 42/50, Iter: 3/11 -- Loss: 0.3419 
INFO:trainer:Epoch: 42/50, Iter: 4/11 -- Loss: 0.5939 
INFO:trainer:Epoch: 42/50, Iter: 5/11 -- Loss: 0.4258 
INFO:trainer:Epoch: 42/50, Iter: 6/11 -- Loss: 0.3554 
INFO:trainer:Epoch: 42/50, Iter: 7/11 -- Loss: 0.4496 
INFO:trainer:Epoch: 42/50, Iter: 8/11 -- Loss: 0.3756 
INFO:trainer:Epoch: 42/50, Iter: 9/11 -- Loss: 0.4094 
INFO:trainer:Epoch: 42/50, Iter: 10/11 -- Loss: 0.3297 
INFO:trainer:Epoch: 42/50, Iter: 11/11 -- Loss: -0.1327 
INFO:ignite.engine.engine.Engine:Engine run starting with ma

INFO:ignite.engine.engine.Engine:Epoch[1] Complete. Time taken: 00:00:10
INFO:ignite.engine.engine.Engine:Engine run complete. Time taken: 00:00:10
INFO:ignite.engine.engine.Engine:Epoch[49] Complete. Time taken: 00:00:40
INFO:trainer:Epoch: 50/50, Iter: 1/11 -- Loss: 0.8052 
INFO:trainer:Epoch: 50/50, Iter: 2/11 -- Loss: 0.6392 
INFO:trainer:Epoch: 50/50, Iter: 3/11 -- Loss: 0.3857 
INFO:trainer:Epoch: 50/50, Iter: 4/11 -- Loss: 0.0654 
INFO:trainer:Epoch: 50/50, Iter: 5/11 -- Loss: 0.3411 
INFO:trainer:Epoch: 50/50, Iter: 6/11 -- Loss: 0.5155 
INFO:trainer:Epoch: 50/50, Iter: 7/11 -- Loss: 0.5918 
INFO:trainer:Epoch: 50/50, Iter: 8/11 -- Loss: 0.3620 
INFO:trainer:Epoch: 50/50, Iter: 9/11 -- Loss: 0.2516 
INFO:trainer:Epoch: 50/50, Iter: 10/11 -- Loss: 0.3298 
INFO:trainer:Epoch: 50/50, Iter: 11/11 -- Loss: -0.0199 
INFO:ignite.engine.engine.Engine:Engine run starting with max_epochs=1.
INFO:evaluator:Epoch[50] Metrics -- Mean_Dice: 0.7363 
INFO:evaluator:Epoch[50] Metrics -- Mean_Di

## Visualizing Tensorboard logs

In [31]:
%load_ext tensorboard
%tensorboard --logdir=logdir

The tensorboard extension is already loaded. To reload it, use:
  %reload_ext tensorboard


Expected training curve on TensorBoard:
![image.png](attachment:image.png)

## Cleanup data directory

Remove directory if a temporary was used.

In [32]:
if directory is None:
    shutil.rmtree(root_dir)