In [7]:
import logging
import os
import sys
import tempfile
from glob import glob
import numpy as np

import torch
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import create_test_image_2d, list_data_collate, decollate_batch
from monai.inferers import sliding_window_inference
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AddChanneld,
    AsDiscrete,
    Compose,
    LoadImaged,
    RandCropByPosNegLabeld,
    RandRotate90d,
    ScaleIntensityd,
    EnsureTyped,
    EnsureType,
    Resized,
    SqueezeDimd,
    RandFlipd,
    NormalizeIntensityd,
    CropForegroundd,
    ScaleIntensityRanged,
    RandAffined,
)
from monai.visualize import plot_2d_or_3d_image


In [8]:
# Dump the MONAI / PyTorch / optional-dependency versions so the run is traceable.
monai.config.print_config()

# Send INFO-level (and above) log records to stdout so they show up in the cell output.
logging.basicConfig(
    level=logging.INFO,
    stream=sys.stdout,
)

MONAI version: 0.7.dev2133
Numpy version: 1.21.1
Pytorch version: 1.9.0+cu111
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 20ffa3f987fad60a8428ec635fb0b4f6ccca9747

Optional dependencies:
Pytorch Ignite version: 0.4.6
Nibabel version: 3.2.1
scikit-image version: NOT INSTALLED or UNKNOWN VERSION.
Pillow version: 8.3.1
Tensorboard version: 2.6.0
gdown version: 3.13.0
TorchVision version: 0.10.0+cu111
tqdm version: 4.62.0
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: NOT INSTALLED or UNKNOWN VERSION.
pandas version: 1.3.2
einops version: 0.3.0

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



In [9]:
 # define transforms for image and segmentation
# Dictionary-based preprocessing pipelines. Every transform operates on the
# "img" and "seg" entries of a sample dict, so spatial augmentations must
# always list BOTH keys - otherwise image and label drift out of alignment.
# Note: main() currently builds its own local copies of these pipelines.
train_transforms = Compose(
    [
        LoadImaged(keys=["img", "seg"]),
        # drop the last axis of the loaded arrays (dim=-1)
        SqueezeDimd(keys=["img", "seg"], dim=-1),
        AddChanneld(keys=["img", "seg"]),
        ScaleIntensityd(keys=["img", "seg"]),
        # crop both entries to the foreground bounding box computed from "img"
        CropForegroundd(keys=["img", "seg"], source_key="img"),
        # draw 8 patches of 192x192 per input, balanced between positive and
        # negative label centres (pos=1, neg=1)
        RandCropByPosNegLabeld(
            keys=["img", "seg"],
            label_key="seg",
            spatial_size=[192, 192],
            pos=1,
            neg=1,
            num_samples=8,
        ),
        RandRotate90d(keys=["img", "seg"], prob=0.10, spatial_axes=[0, 1]),
        RandFlipd(
            keys=["img", "seg"],
            spatial_axis=[0],
            prob=0.10,
        ),
        # FIX: this flip previously listed keys=["img", "img"], so the label
        # was never flipped along axis 1 and no longer matched its image.
        RandFlipd(
            keys=["img", "seg"],
            spatial_axis=[1],
            prob=0.10,
        ),
        EnsureTyped(keys=["img", "seg"]),
    ]
)

# Validation uses the same deterministic preprocessing, without any random
# cropping or augmentation.
val_transforms = Compose(
    [
        LoadImaged(keys=["img", "seg"]),
        SqueezeDimd(keys=["img", "seg"], dim=-1),
        AddChanneld(keys=["img", "seg"]),
        ScaleIntensityd(keys=["img", "seg"]),
        CropForegroundd(keys=["img", "seg"], source_key="img"),
        EnsureTyped(keys=["img", "seg"]),
    ]
)


In [10]:
def _list_image_label_pairs(data_dir, split_name):
    """Collect sorted ``img*.nii.gz`` / ``label*.nii.gz`` pairs from ``data_dir``.

    Images and labels are paired purely by sorted order, so the two lists must
    be non-empty and of equal length for the pairing to be meaningful.

    Args:
        data_dir: directory that is searched (non-recursively) for the files.
        split_name: human-readable name of the split, used in error messages.

    Returns:
        A list of ``{"img": path, "seg": path}`` dictionaries.

    Raises:
        FileNotFoundError: if no image or no label file matches the patterns.
        ValueError: if the number of images and labels differ.
    """
    images = sorted(glob(os.path.join(data_dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(data_dir, "label*.nii.gz")))
    print(f"image size:{len(images)}")
    print(f"label size:{len(segs)}")
    if not images or not segs:
        raise FileNotFoundError(
            f"no {split_name} data found in {data_dir!r}: "
            f"{len(images)} 'img*.nii.gz' and {len(segs)} 'label*.nii.gz' files"
        )
    if len(images) != len(segs):
        # zip() would silently truncate and very likely mispair image/label
        raise ValueError(
            f"{split_name} data in {data_dir!r} is unbalanced: "
            f"{len(images)} images vs {len(segs)} labels"
        )
    return [{"img": img, "seg": seg} for img, seg in zip(images, segs)]


def main(traindir, valdir, max_epochs=5000, val_interval=2):
    """Train a 2D UNet for binary segmentation and validate it periodically.

    Loads image/label pairs from ``traindir`` and ``valdir``, optionally
    resumes from ``model_weights_dict.pth`` in the current working directory,
    trains with Dice loss and Adam, logs to TensorBoard, and stores the
    best-scoring weights in ``best6_metric_model_segmentation2d_dict.pth``.

    Args:
        traindir: directory holding the training ``img*.nii.gz`` and
            ``label*.nii.gz`` files.
        valdir: directory holding the validation files (same naming scheme).
        max_epochs: number of training epochs to run.
        val_interval: run validation every ``val_interval`` epochs.

    Raises:
        FileNotFoundError: if either directory contains no matching files.
        ValueError: if image and label counts differ in a directory, or if
            ``val_interval`` is smaller than 1.
    """
    if val_interval < 1:
        raise ValueError(f"val_interval must be >= 1, got {val_interval}")

    # Fail early (before any model/loader is built) on missing or unpaired data.
    train_files = _list_image_label_pairs(traindir, "training")
    val_files = _list_image_label_pairs(valdir, "validation")

    # define transforms for image and segmentation
    train_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            SqueezeDimd(keys=["img", "seg"], dim=-1),
            AddChanneld(keys=["img", "seg"]),
            ScaleIntensityd(keys=["img", "seg"]),
            CropForegroundd(keys=["img", "seg"], source_key="img"),
            RandCropByPosNegLabeld(
                keys=["img", "seg"],
                label_key="seg",
                spatial_size=[192, 192],
                pos=1,
                neg=1,
                num_samples=8,
            ),
            RandRotate90d(keys=["img", "seg"], prob=0.10, spatial_axes=[0, 1]),
            RandFlipd(keys=["img", "seg"], spatial_axis=[0], prob=0.10),
            RandFlipd(keys=["img", "seg"], spatial_axis=[1], prob=0.10),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )
    val_transforms = Compose(
        [
            LoadImaged(keys=["img", "seg"]),
            SqueezeDimd(keys=["img", "seg"], dim=-1),
            AddChanneld(keys=["img", "seg"]),
            ScaleIntensityd(keys=["img", "seg"]),
            CropForegroundd(keys=["img", "seg"], source_key="img"),
            EnsureTyped(keys=["img", "seg"]),
        ]
    )

    # Sanity check: pull one batch through the training pipeline and show its
    # shape. batch_size=2 inputs x num_samples=8 crops gives 16 patches.
    check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    check_loader = DataLoader(check_ds, batch_size=2,
                              num_workers=4, collate_fn=list_data_collate)
    check_data = monai.utils.misc.first(check_loader)
    print(check_data["img"].shape, check_data["seg"].shape)

    # create a training data loader
    train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
    train_loader = DataLoader(
        train_ds,
        batch_size=2,
        shuffle=True,
        num_workers=4,
        collate_fn=list_data_collate,
        pin_memory=torch.cuda.is_available(),
    )

    # create a validation data loader
    val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
    val_loader = DataLoader(val_ds, batch_size=1,
                            num_workers=4, collate_fn=list_data_collate)
    dice_metric = DiceMetric(include_background=True,
                             reduction="mean", get_not_nans=False)
    # logits -> sigmoid -> thresholded binary mask
    post_trans = Compose([EnsureType(), Activations(
        sigmoid=True), AsDiscrete(threshold_values=True)])

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = monai.networks.nets.UNet(
        dimensions=2,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)

    # Resume from earlier weights when available. map_location lets a
    # checkpoint saved on GPU be loaded on a CPU-only machine (and vice versa).
    weight_file = "model_weights_dict.pth"
    if os.path.exists(weight_file):
        model.load_state_dict(torch.load(weight_file, map_location=device))
        print("load weight")

    # start a typical PyTorch training
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    # Number of optimizer steps per epoch; len(loader) also counts a final
    # partial batch, which keeps the TensorBoard global step strictly increasing.
    steps_per_epoch = len(train_loader)
    writer = SummaryWriter()
    try:
        for epoch in range(max_epochs):
            print("-" * 10)
            print(f"epoch {epoch + 1}/{max_epochs}")
            model.train()
            epoch_loss = 0
            step = 0
            for batch_data in train_loader:
                step += 1
                inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = loss_function(outputs, labels)
                loss.backward()
                optimizer.step()
                step_loss = loss.item()
                epoch_loss += step_loss
                writer.add_scalar("train_loss", step_loss, steps_per_epoch * epoch + step)
            epoch_loss /= step
            epoch_loss_values.append(epoch_loss)
            print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

            if (epoch + 1) % val_interval == 0:
                model.eval()
                with torch.no_grad():
                    val_images = None
                    val_labels = None
                    val_outputs = None
                    roi_size = (192, 192)
                    sw_batch_size = 4
                    for val_data in val_loader:
                        val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                        val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                        val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                        # compute metric for current iteration
                        dice_metric(y_pred=val_outputs, y=val_labels)
                    # aggregate the final mean dice result
                    metric = dice_metric.aggregate().item()
                    # reset the status for next validation round
                    dice_metric.reset()
                    metric_values.append(metric)
                    if metric > best_metric:
                        best_metric = metric
                        best_metric_epoch = epoch + 1
                        torch.save(model.state_dict(), "best6_metric_model_segmentation2d_dict.pth")
                        print("saved new best metric model")
                    writer.add_scalar("val_mean_dice", metric, epoch + 1)
                    # log the last validation sample (image, label, prediction) to TensorBoard
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

        print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
    finally:
        # Always flush/close the event file, even when training is interrupted
        # (e.g. KeyboardInterrupt from stopping the notebook cell).
        writer.close()


In [11]:
if __name__ == "__main__":
    # Both data folders are resolved relative to the current working directory.
    root_dir = os.getcwd()
    train_dir, val_dir = (
        os.path.join(root_dir, sub_path)
        for sub_path in ("data/imagesTr", "data/imagesTs")
    )
    # Echo the resolved locations (one per line) before starting the run.
    print(train_dir, val_dir, sep="\n")
    main(train_dir, val_dir)

/home/xindong/project/monai-test/unet-segment2d/data/imagesTr
/home/xindong/project/monai-test/unet-segment2d/data/imagesTs
image size:80
label size:80
image size:24
label size:24




torch.Size([16, 1, 192, 192]) torch.Size([16, 1, 192, 192])
load weight
----------
epoch 1/5000




epoch 1 average loss: 0.2694
----------
epoch 2/5000




epoch 2 average loss: 0.2696




saved new best metric model
----------
epoch 3/5000




epoch 3 average loss: 0.2789
----------
epoch 4/5000




epoch 4 average loss: 0.2864




saved new best metric model
----------
epoch 5/5000




epoch 5 average loss: 0.2612
----------
epoch 6/5000




epoch 6 average loss: 0.2688




----------
epoch 7/5000




epoch 7 average loss: 0.2648
----------
epoch 8/5000




epoch 8 average loss: 0.2687




----------
epoch 9/5000




epoch 9 average loss: 0.2617
----------
epoch 10/5000




epoch 10 average loss: 0.2716




saved new best metric model
----------
epoch 11/5000




epoch 11 average loss: 0.2514
----------
epoch 12/5000




epoch 12 average loss: 0.2523




----------
epoch 13/5000




epoch 13 average loss: 0.2730
----------
epoch 14/5000




epoch 14 average loss: 0.2773




----------
epoch 15/5000




epoch 15 average loss: 0.2731
----------
epoch 16/5000




epoch 16 average loss: 0.2990




----------
epoch 17/5000




epoch 17 average loss: 0.2802
----------
epoch 18/5000




epoch 18 average loss: 0.2724




----------
epoch 19/5000




epoch 19 average loss: 0.2745
----------
epoch 20/5000




epoch 20 average loss: 0.2700




----------
epoch 21/5000




epoch 21 average loss: 0.2618
----------
epoch 22/5000




epoch 22 average loss: 0.2609




----------
epoch 23/5000




epoch 23 average loss: 0.2725
----------
epoch 24/5000




epoch 24 average loss: 0.2747




----------
epoch 25/5000




epoch 25 average loss: 0.2643
----------
epoch 26/5000




epoch 26 average loss: 0.2792




----------
epoch 27/5000




epoch 27 average loss: 0.2793
----------
epoch 28/5000




epoch 28 average loss: 0.2463




----------
epoch 29/5000




epoch 29 average loss: 0.2675
----------
epoch 30/5000




epoch 30 average loss: 0.2758




----------
epoch 31/5000




epoch 31 average loss: 0.2721
----------
epoch 32/5000




epoch 32 average loss: 0.2613




----------
epoch 33/5000




epoch 33 average loss: 0.2653
----------
epoch 34/5000




epoch 34 average loss: 0.2697




----------
epoch 35/5000




epoch 35 average loss: 0.2727
----------
epoch 36/5000




epoch 36 average loss: 0.2808




----------
epoch 37/5000




epoch 37 average loss: 0.2702
----------
epoch 38/5000




epoch 38 average loss: 0.2599




----------
epoch 39/5000




epoch 39 average loss: 0.2628
----------
epoch 40/5000




epoch 40 average loss: 0.2725




----------
epoch 41/5000




epoch 41 average loss: 0.2600
----------
epoch 42/5000




epoch 42 average loss: 0.2732




KeyboardInterrupt: 