# Import Packages

In [1]:
import logging
import os
import sys
import tempfile
import time
from glob import glob

import cv2

import torchvision
import timm
import mmcv

import torch
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torch.optim.lr_scheduler import ReduceLROnPlateau
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
import monai
from monai.data import create_test_image_2d, list_data_collate, decollate_batch
from monai.inferers import sliding_window_inference, SimpleInferer
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AddChanneld,
    AsDiscrete,
    Compose,
    LoadImaged,
    RandCropByPosNegLabeld,
    RandAdjustContrastd,
    RandRotate90d,
    ScaleIntensityd,
    EnsureTyped,
    EnsureType,
    AsChannelFirstd,
    AsChannelLast,
    Resized,
    RandScaleCropd,
    RandRotated,
    Rotated,
    SaveImage,
    ThresholdIntensity,
    RandBiasField,
    ThresholdIntensityd
)
from monai.visualize import plot_2d_or_3d_image

# Check MONAI configurations

In [2]:
# Record the library versions this notebook ran with (see the output below).
monai.config.print_config()
# Route INFO-level log records to stdout so they appear in the cell output.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)


MONAI version: 0.8.1
Numpy version: 1.20.3
Pytorch version: 1.11.0
MONAI flags: HAS_EXT = False, USE_COMPILED = False
MONAI rev id: 71ff399a3ea07aef667b23653620a290364095b1

Optional dependencies:
Pytorch Ignite version: NOT INSTALLED or UNKNOWN VERSION.
Nibabel version: NOT INSTALLED or UNKNOWN VERSION.
scikit-image version: 0.18.3
Pillow version: 8.4.0
Tensorboard version: 2.8.0
gdown version: NOT INSTALLED or UNKNOWN VERSION.
TorchVision version: 0.12.0
tqdm version: 4.62.3
lmdb version: NOT INSTALLED or UNKNOWN VERSION.
psutil version: 5.8.0
pandas version: 1.3.4
einops version: NOT INSTALLED or UNKNOWN VERSION.
transformers version: NOT INSTALLED or UNKNOWN VERSION.
mlflow version: NOT INSTALLED or UNKNOWN VERSION.

For details about installing the optional dependencies, please visit:
    https://docs.monai.io/en/latest/installation.html#installing-the-recommended-dependencies



# Process VGH Data

In [3]:
# Root folder of the segmentation dataset. Set the SEG_DATA_PATH environment
# variable to point at the data on another machine instead of editing this cell.
data_path = os.environ.get("SEG_DATA_PATH", "D:/nycu_deep_learning/SEG_Train_Datasets/")

## Obtain the training data list

In [4]:
# Collect training images (*.jpg) and their masks (*.png).
# Images and masks are later paired by position after sorting, so both folders
# need file names that sort into the same order.
train_image_dir = os.path.join(data_path, "Train_Images")
train_images = sorted(glob(os.path.join(train_image_dir, "*.jpg")))

train_mask_dir = os.path.join(data_path, "mask_img")
train_segs = sorted(glob(os.path.join(train_mask_dir, "*.png")))

# Fail fast with a readable message rather than a confusing error in the split cell.
assert train_images, f"no *.jpg images found in {train_image_dir}"
assert len(train_images) == len(train_segs), (
    f"found {len(train_images)} images but {len(train_segs)} masks"
)


In [5]:
def _pair_files(images, segs):
    """Build MONAI-style dicts pairing each image path with the mask at the same position."""
    return [{"img": img, "seg": seg} for img, seg in zip(images, segs)]


# Hold out 12% of the labelled pairs for validation; the fixed random_state keeps
# the split identical across runs so validation scores stay comparable.
train_images, valid_images, train_segs, valid_segs = train_test_split(
    train_images, train_segs, train_size=0.88, random_state=2
)

print(f" {len(train_images)} train_images and {len(train_segs)} train_segs")
train_files = _pair_files(train_images, train_segs)

print(f" {len(valid_images)} valid_images and {len(valid_segs)} valid_segs")
val_files = _pair_files(valid_images, valid_segs)

 926 train_images and 926 train_segs
 127 valid_images and 127 valid_segs


# Define Transforms for Image and Segmentation

In [6]:
# Image/mask preprocessing and augmentation pipelines.
# threshold_value / cval_value are only referenced by the disabled
# ThresholdIntensityd step; they are kept so that step can be switched back on.
threshold_value = 0.1
cval_value = 0.2


def _load_and_scale_steps():
    """Return new instances of the steps both pipelines start with.

    Loads image and mask, adds a channel axis to the mask, moves the image's
    last axis to the front, and rescales both with ScaleIntensityd defaults.
    """
    return [
        LoadImaged(keys=["img", "seg"]),
        AddChanneld(keys=["seg"]),
        AsChannelFirstd(keys=["img"]),
        ScaleIntensityd(keys=["img", "seg"]),
    ]


# Training: random crop / rotation / contrast augmentation, then a fixed resize.
train_transforms = Compose(
    _load_and_scale_steps()
    + [
        RandScaleCropd(keys=["img", "seg"], roi_scale=0.5, max_roi_scale=1.5),
        RandRotated(keys=["img", "seg"], range_x=3.14),
        RandAdjustContrastd(keys=["img"], prob=0.1, gamma=(0.5, 4.5)),
        # Disabled: RandCropByPosNegLabeld(keys=["img", "seg"], label_key="seg",
        #     spatial_size=[800, 800], pos=1, neg=1, num_samples=4)
        RandRotate90d(keys=["img", "seg"], prob=0.5),
        Resized(keys=["img", "seg"], spatial_size=[1600, 800]),
        EnsureTyped(keys=["img", "seg"]),
        # Disabled: ThresholdIntensityd(keys=["img"], threshold=threshold_value,
        #     above=False, cval=cval_value)
    ]
)

# Validation: deterministic; resized to 1696x928 (the training loop evaluates it
# with a sliding window).
val_transforms = Compose(
    _load_and_scale_steps()
    + [
        Resized(keys=["img", "seg"], spatial_size=[1696, 928]),
        EnsureTyped(keys=["img", "seg"]),
        # Disabled: ThresholdIntensityd(keys=["img"], threshold=threshold_value,
        #     above=False, cval=cval_value)
    ]
)

# Check and visualize the transform results

In [7]:
# # define dataset, data loader
# check_ds = monai.data.Dataset(data=train_files, transform=train_transforms)

In [8]:
# # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
# batch = 4
# check_loader = DataLoader(check_ds, batch_size=batch, num_workers=12, collate_fn=list_data_collate)
# check_data = monai.utils.misc.first(check_loader)
# print(check_data["img"].shape, check_data["seg"].shape)


# import matplotlib.pyplot as plt

# plt.figure("visualize",(16,64))
# for i in range(batch):
#     plt.subplot(8,2,2*i+1)    
#     plt.imshow(check_data["img"][i].permute(1,2,0))
#     plt.subplot(8,2,2*i+2)
#     plt.imshow(check_data["seg"][i].permute(1,2,0))

# Create DataLoader for train and validation data

In [9]:
# Training loader: batches of 4 randomly augmented image/mask pairs, shuffled
# every epoch. (RandCropByPosNegLabeld is disabled in train_transforms, so each
# file yields exactly one sample.)
train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
train_loader = DataLoader(
    train_ds,
    batch_size=4,
    shuffle=True,
    num_workers=8,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available(),
)

# Validation loader: deterministic transforms, fixed order; pinned memory for
# faster host-to-GPU copies, same as the training loader.
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(
    val_ds,
    batch_size=2,
    num_workers=4,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available(),
)

# Define metric and post-processing

In [10]:
# Mean Dice over binarised masks; the single output channel is the foreground,
# so it must be included.
dice_metric = DiceMetric(
    include_background=True,
    reduction="mean",
    get_not_nans=False,
)

# Model logits -> sigmoid probabilities -> {0, 1} mask at 0.5.
post_trans = Compose([
    EnsureType(),
    Activations(sigmoid=True),
    AsDiscrete(threshold=0.5),
])

# Ground-truth masks are already intensity-scaled; they only need binarising.
post_trans_label = Compose([
    EnsureType(),
    AsDiscrete(threshold=0.5),
])

# Build Model

In [11]:
# Build a 2D DynUNet, a Dice+Focal loss, Adam, and an LR scheduler that halves the
# learning rate when validation Dice stops improving.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Six levels: the first stride of 1 keeps full resolution, and each following
# stride of 2 downsamples; the decoder mirrors strides[1:].
unet_kernel_size = (3, 3, 3, 3, 3, 3)
unet_strides = (1, 2, 2, 2, 2, 2)
model = monai.networks.nets.DynUNet(
    spatial_dims=2,
    in_channels=3,
    out_channels=1,
    kernel_size=unet_kernel_size,
    strides=unet_strides,
    # MONAI's default would start at 32; 36 is what the saved checkpoints use,
    # so it must stay 36 for load_state_dict to match.
    filters=(36, 64, 128, 256, 512, 1024),
    # One transposed-conv kernel per decoder stage; DynUNet documents that these
    # should equal strides[1:] (5 entries for 6 levels).
    upsample_kernel_size=unet_strides[1:],
    res_block=True,
    trans_bias=True,
).to(device)

# Alternatives tried earlier: DiceLoss, DiceCELoss, GeneralizedDiceLoss (all sigmoid).
loss_function = monai.losses.DiceFocalLoss(
    include_background=True,
    sigmoid=True,
    squared_pred=True,
    lambda_dice=1.0,
    lambda_focal=1.0,
    gamma=5,
)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-3)
# The scheduler is stepped with validation Dice (higher is better), hence mode="max".
scheduler = ReduceLROnPlateau(optimizer, mode="max", patience=8, factor=0.5, min_lr=5e-5)

In [12]:
# Wrap for multi-GPU data parallelism. Checkpoints in this notebook are saved
# and loaded from the wrapped model's state_dict, so keep this wrap in place.
model = torch.nn.parallel.DataParallel(model)

In [13]:
# Resume from a previous best checkpoint. map_location places the tensors on the
# current device, so a checkpoint saved on GPU also loads on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
model.load_state_dict(
    torch.load("best_metric_model_segmentation2d_dict_17.pth", map_location=device)
)

<All keys matched successfully>

# Create Visualization Function

In [14]:
def visualize(**images):
    """Plot images side by side in one row.

    Each keyword argument becomes one panel. The keyword, with underscores
    replaced by spaces and title-cased, is the panel title, e.g.
    ``visualize(ground_truth_mask=m)`` gives a panel titled "Ground Truth Mask".
    Each value is drawn with ``imshow`` using a gray colormap.

    Args:
        **images: panel name -> image accepted by ``matplotlib`` ``imshow``
            (for example an (H, W) or (H, W, C) array or CPU tensor).
    """
    if not images:
        return  # nothing to draw; avoid showing an empty figure
    fig, axes = plt.subplots(1, len(images), figsize=(16, 16), squeeze=False)
    for ax, (name, image) in zip(axes[0], images.items()):
        ax.set_xticks([])
        ax.set_yticks([])
        ax.set_title(name.replace("_", " ").title())
        ax.imshow(image, cmap="gray")
    plt.show()
    # Release the figure; this function may be called once per validation round.
    plt.close(fig)

# Define Training Parameters and Start Training

In [15]:
# Standard PyTorch training loop: one optimisation pass per epoch, then every
# `val_interval` epochs a sliding-window validation that drives the LR scheduler
# and best-model checkpointing.
total_epochs = 250
val_interval = 1
# Set True to plot image / ground truth / prediction for the first validation
# batch of every validation round.
show_val_examples = False
roi_size = (1600, 800)  # sliding-window size; equals the training Resized size
sw_batch_size = 4       # windows evaluated per forward pass

best_metric = 0
best_metric_epoch = -1
epoch_loss_values = list()
metric_values = list()
# Batches per epoch, counting the final partial batch. Using
# len(train_ds) // batch_size undercounts, so the global TensorBoard step of the
# last batch of one epoch collided with the first batch of the next.
epoch_len = len(train_loader)
writer = SummaryWriter()
try:
    for epoch in range(total_epochs):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{total_epochs}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()

            # Training Dice on binarised predictions; no autograd graph is needed.
            with torch.no_grad():
                pred_masks = [post_trans(i) for i in decollate_batch(outputs.detach())]
                true_masks = [post_trans_label(i) for i in decollate_batch(labels)]
                dice_metric(y_pred=pred_masks, y=true_masks)

            epoch_loss += loss.item()
            writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

        train_metric = dice_metric.aggregate().item()
        dice_metric.reset()
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        writer.add_scalar("train_dice", train_metric, epoch + 1)
        local_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        print(f"{local_time} epoch {epoch + 1} average loss: {epoch_loss:.4f} dice score:{train_metric:.4f}")

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                val_images = val_labels = val_outputs = None
                show_val = show_val_examples  # plot only the first batch
                for val_data in val_loader:
                    val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                    val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
                    val_labels = [post_trans_label(i) for i in decollate_batch(val_labels)]
                    # compute metric for current iteration
                    dice_metric(y_pred=val_outputs, y=val_labels)

                    if show_val:
                        visualize(
                            image=val_images[0].cpu().permute(1, 2, 0),
                            ground_truth_mask=val_labels[0].cpu().permute(1, 2, 0),
                            predicted_mask=val_outputs[0].cpu().permute(1, 2, 0),
                        )
                        show_val = False

                # aggregate the final mean dice result, then reset for the next round
                metric = dice_metric.aggregate().item()
                dice_metric.reset()
                scheduler.step(metric)
                print('epoch:', epoch + 1, 'learning rate:', optimizer.param_groups[0]['lr'])
                metric_values.append(metric)
                if metric > best_metric:
                    best_metric = metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), "best_metric_model_segmentation2d_dict.pth")
                    print("saved new best metric model")
                print(
                    "current epoch: {} current val mean dice score: {:.4f} best val mean dice score: {:.4f} at epoch {}".format(
                        epoch + 1, metric, best_metric, best_metric_epoch
                    )
                )
                writer.add_scalar("val_mean_dice score", metric, epoch + 1)
                # plot the last model output as GIF image in TensorBoard with the corresponding image and label
                if val_images is not None:
                    plot_2d_or_3d_image(val_images, epoch + 1, writer, index=0, tag="image")
                    plot_2d_or_3d_image(val_labels, epoch + 1, writer, index=0, tag="label")
                    plot_2d_or_3d_image(val_outputs, epoch + 1, writer, index=0, tag="output")

    print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")
finally:
    # Flush TensorBoard events even if training is interrupted (e.g. KeyboardInterrupt).
    writer.close()

# NOTE: the file name says "40_epoches" regardless of total_epochs; it is kept
# unchanged so existing references to this file keep working.
torch.save(model.state_dict(), "Final_model_40_epoches_segmentation2d_dict.pth")

----------
epoch 1/250




2022-05-30 15:08:51 epoch 1 average loss: 0.1433 dice score:0.8781594038009644
epoch: 1 learning rate: 0.005
saved new best metric model
current epoch: 1 current val mean dice score: 0.8606 best val mean dice score: 0.8606 at epoch 1
----------
epoch 2/250
2022-05-30 15:18:44 epoch 2 average loss: 0.1367 dice score:0.8825605511665344
epoch: 2 learning rate: 0.005
saved new best metric model
current epoch: 2 current val mean dice score: 0.8625 best val mean dice score: 0.8625 at epoch 2
----------
epoch 3/250
2022-05-30 15:24:07 epoch 3 average loss: 0.1332 dice score:0.8770321011543274
epoch: 3 learning rate: 0.005
current epoch: 3 current val mean dice score: 0.8562 best val mean dice score: 0.8625 at epoch 2
----------
epoch 4/250
2022-05-30 15:29:30 epoch 4 average loss: 0.1281 dice score:0.8772559762001038
epoch: 4 learning rate: 0.005
saved new best metric model
current epoch: 4 current val mean dice score: 0.8627 best val mean dice score: 0.8627 at epoch 4
----------
epoch 5/250


KeyboardInterrupt: 

In [None]:
# Average training loss per epoch (x-axis is the 1-based epoch number).
fig, ax = plt.subplots()
ax.plot(range(1, len(epoch_loss_values) + 1), epoch_loss_values)
ax.set(title="Average training loss per epoch", xlabel="epoch", ylabel="loss");

In [None]:
# Validation mean Dice; validation runs every `val_interval` epochs, so the
# x-axis is the epoch at which each value was measured.
fig, ax = plt.subplots()
ax.plot([val_interval * (i + 1) for i in range(len(metric_values))], metric_values)
ax.set(title="Validation mean Dice", xlabel="epoch", ylabel="Dice");