# Training
Provided by Prof. Lee

# Install Packages

In [None]:
!pip install monai
!pip install torchvision
!pip install -U Setuptools
!pip install git+https://github.com/qubvel/segmentation_models.pytorch
!pip install adabelief-pytorch==0.2.0

In [None]:
!pip install ipywidgets widgetsnbextension
!jupyter nbextension enable --py widgetsnbextension

# Import Packages

In [None]:
import logging
import os
import sys
import tempfile
import glob
import time
import matplotlib.pyplot as plt
import numpy as np

import setuptools
import torch
import torchvision
from PIL import Image
from torch.utils.data import DataLoader #只有dataloader用torch的
from torch.utils.tensorboard import SummaryWriter

import monai
from monai.data import create_test_image_2d, list_data_collate, decollate_batch
from monai.metrics import DiceMetric
from monai.transforms import (
    Activations,
    AddChanneld,
    AsDiscrete,
    Compose,  
    LoadImaged,
    RandRotate90d,
    ScaleIntensityd,
    EnsureTyped,
    EnsureType,
    AsChannelFirstd,
    Resized,
    SaveImage,
    Resize,
)

# Check MONAI configurations

In [None]:
# Print the MONAI version together with the versions of its optional dependencies.
monai.config.print_config()
# Route log records at INFO level and above to stdout so they show up in the cell output.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)

# Process Data

## -Set the Main Data Folder
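Data Folder Structure (training folders are required; the test folders are read by the inference cells, and `Predict` is where predictions are written)

    YOUR PATH
        ├Train_Images
        ├Train_Annotations_png
        ├Public_Image
        ├Private_Image
        └Predict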

In [None]:
# Set the Data folder: edit the default below or set the DATA_PATH environment variable.
# Later cells build paths by string concatenation (e.g. data_path + "Public_Image/"),
# so the value must end with a path separator; one is appended if missing.
data_path = os.environ.get("DATA_PATH", "{YOUR PATH}")
if not data_path.endswith(("/", os.sep)):
    data_path += "/"

## -obtain train data and validation data list

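Split by index: after sorting, the first `val_num` image/annotation pairs form the validation set and the remaining pairs form the training set.

Each pair is stored as a dictionary `{"img": image_path, "seg": annotation_path}`.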

In [None]:
# Number of validation data: after sorting, the first `val_num` pairs are held out for validation.
val_num = 225

# Load train files. `glob` is imported as a module, so the function is `glob.glob`.
# os.path.join works whether or not data_path ends with a separator.
train_images = sorted(glob.glob(os.path.join(data_path, "Train_Images", "*.jpg")))
train_segs = sorted(glob.glob(os.path.join(data_path, "Train_Annotations_png", "*.png")))

# Images and annotations are paired by position after sorting, so both folders must hold
# the same number of files (zip would otherwise silently drop the unmatched tail).
assert train_images, f"no .jpg files found in {os.path.join(data_path, 'Train_Images')}; check data_path"
assert len(train_images) == len(train_segs), (
    f"{len(train_images)} images but {len(train_segs)} annotations"
)

# Training data
train_files = [{"img": img, "seg": seg} for img, seg in zip(train_images[val_num:], train_segs[val_num:])]
print(f" {len(train_images[val_num:])} train_images and {len(train_segs[val_num:])} train_segs")

# validation data
val_files = [{"img": img, "seg": seg} for img, seg in zip(train_images[:val_num], train_segs[:val_num])]
print(f" {len(train_images[:val_num])} val_images and {len(train_segs[:val_num])} val_segs")

# Define Transforms for Images and Segmentations

In [None]:
# define transforms for image and segmentation
# Training and validation use the same deterministic pipeline (no augmentation),
# so it is built by one helper and instantiated twice.
def _build_seg_transforms():
    """Return a new Compose that loads an image/mask pair and prepares it for the model.

    Steps: load both files; add a channel dim to the mask; move the image's channel
    axis first; min-max rescale image and mask intensities; resize both to 800 x 800;
    convert to tensors.
    """
    pipeline = [
        LoadImaged(keys=["img", "seg"]),
        AddChanneld(keys=["seg"]),
        AsChannelFirstd(keys=["img"]),
        ScaleIntensityd(keys=["img", "seg"]),
        Resized(keys=["img", "seg"], spatial_size=[800, 800]),
        EnsureTyped(keys=["img", "seg"]),
    ]
    return Compose(pipeline)


train_transforms = _build_seg_transforms()
val_transforms = _build_seg_transforms()

# Check and visualize the transform results

In [None]:
# Dataset over the training pairs, used only to sanity-check the transform output below.
check_ds = monai.data.Dataset(transform=train_transforms, data=train_files)

In [None]:
# Load a single batch of 8 transformed samples to check tensor shapes and
# eyeball image/mask alignment (the transforms apply no cropping or augmentation).
check_loader = DataLoader(check_ds, batch_size=8, num_workers=12, collate_fn=list_data_collate)
check_data = monai.utils.misc.first(check_loader)
print(check_data["img"].shape, check_data["seg"].shape)

# One row per sample: image (left) and mask (right). Loop over the actual batch length
# so a training set with fewer than 8 pairs does not raise IndexError.
n_show = min(8, check_data["img"].shape[0])
plt.figure("visualize", (16, 64))
for i in range(n_show):
    plt.subplot(8, 2, 2 * i + 1)
    plt.imshow(check_data["img"][i].permute(1, 2, 0))
    plt.subplot(8, 2, 2 * i + 2)
    plt.imshow(check_data["seg"][i].permute(1, 2, 0))
plt.show()

# Create DataLoader for train and validation data

In [None]:
# create a training data loader
train_ds = monai.data.Dataset(data=train_files, transform=train_transforms)
# Batches of 8 full-size (800 x 800 after Resized) image/mask pairs, reshuffled every epoch.
train_loader = DataLoader(
    train_ds,
    batch_size=8,
    shuffle=True,
    num_workers=8,
    collate_fn=list_data_collate,
    pin_memory=torch.cuda.is_available(),
)

# create a validation data loader (fixed order, no shuffling)
val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(
    val_ds,
    batch_size=4,
    num_workers=4,
    collate_fn=list_data_collate,
    # pinned host memory speeds up .to(device) copies when a GPU is present, as for training
    pin_memory=torch.cuda.is_available(),
)

# Define metric and post-processing

In [None]:
# Mean Dice over foreground only (background excluded); the NaN count is not returned.
dice_metric = DiceMetric(reduction="mean", include_background=False, get_not_nans=False)
# Raw logits -> probabilities (sigmoid) -> binary mask (threshold 0.5).
post_trans = Compose(
    [
        EnsureType(),
        Activations(sigmoid=True),
        AsDiscrete(threshold=0.5),
    ]
)

# Set Environment
Select the GPU if CUDA is available; otherwise fall back to the CPU.

In [None]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# Create Visualize Function

In [None]:
def visualize(**images):
    """Plot the given images side by side in a single row.

    Each keyword becomes a panel title: underscores are replaced by spaces and the
    result is title-cased (``ground_truth_mask`` -> "Ground Truth Mask"). Panels are
    drawn in keyword order, without axis ticks, using a gray colormap.
    """
    count = len(images)
    plt.figure(figsize=(16, 16))
    for position, (label, picture) in enumerate(images.items(), start=1):
        plt.subplot(1, count, position)
        plt.xticks([])
        plt.yticks([])
        plt.title(label.replace('_', ' ').title())
        plt.imshow(picture, cmap='gray')
    plt.show()

# Build Model
>Decoder : PAN
>
>Encoder : tu-tf_efficientnetv2_s_in21ft1k

In [None]:
import segmentation_models_pytorch as smp

In [None]:
# Auxiliary classification head attached to the encoder (smp `aux_params`).
aux_params = {
    "pooling": "avg",     # one of 'avg', 'max'
    "dropout": 0.4,       # dropout ratio, default is None
    "activation": None,   # activation function, default is None
    "classes": 1,         # define number of output labels
}

# PAN decoder on a timm EfficientNetV2-S encoder ("tu-" selects a timm-universal encoder in smp).
encoder_name = 'tu-tf_efficientnetv2_s_in21ft1k'
model_name = f"{encoder_name}_PAN"  # also used in the checkpoint file name
model = smp.PAN(encoder_name, aux_params=aux_params).to(device)

## -set optimizer & loss function
Choose AdaBelief

In [None]:
from adabelief_pytorch import AdaBelief

# AdaBelief with decoupled weight decay and rectification disabled.
optimizer = AdaBelief(
    model.parameters(),
    lr=1e-4,
    eps=1e-16,
    betas=(0.9, 0.98523),
    weight_decay=1e-4,
    weight_decouple=True,
    rectify=False,
)
# Dice loss on raw logits; the loss applies the sigmoid itself.
loss_function = monai.losses.DiceLoss(sigmoid=True)

## -start training

In [None]:
#### start a typical PyTorch training
# Train for `total_epochs` epochs; every `val_interval` epochs evaluate on the validation
# set and save the weights with the lowest mean validation Dice loss.
total_epochs = 64
val_interval = 1
best_metric = float("inf")  # lowest mean validation Dice loss so far
best_metric_epoch = -1      # 1-based epoch of the saved model (-1 = none saved yet)
epoch_loss_values = list()  # mean training loss per epoch
metric_values = list()      # mean validation Dice score per validation round
writer = SummaryWriter()
train_loss = []             # mean training loss per epoch
val_loss = []               # mean validation Dice loss per validation round

# Batches per epoch, counting a final partial batch. Also used for the TensorBoard step
# index (epoch_len * epoch + step), so consecutive epochs never share a step number.
epoch_len = len(train_loader)

for epoch in range(total_epochs):
    print("-" * 10)
    print(f"epoch {epoch + 1}/{total_epochs}")

    # start training
    model.train()
    epoch_loss = 0.0
    step = 0
    time_start = time.time()

    # Iterate the loader itself: each item is the collated dict with "img" and "seg" keys.
    for step, batch_data in enumerate(train_loader, start=1):
        inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
        optimizer.zero_grad()
        # The model returns two outputs (mask logits, aux_params head); only the mask enters the loss.
        outputs, _ = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        elapsed = time.time() - time_start  # measured after the step, so it includes it
        print(f"{step}/{epoch_len} RUN   Use {elapsed:.3f}s   ||{'-' * step + '>>' + '-' * (epoch_len - step)}||", end='\r')
        writer.add_scalar("train_loss", loss.item(), epoch_len * epoch + step)

    epoch_loss /= max(step, 1)  # guard against an empty loader
    epoch_loss_values.append(epoch_loss)
    local_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
    train_loss.append(epoch_loss)
    print("\n", f"{local_time} epoch {epoch + 1} training average loss: {epoch_loss:.4f}")
    print('1 Epoch Finished')

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            val_step = 0
            loss_val = 0.0

            for val_data in val_loader:
                val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)
                val_outputs, _ = model(val_images)
                # Bring the logits to the label resolution so loss and metric compare equal shapes.
                # Resize sees the (B, 1, H, W) batch as channel-first data with spatial dims
                # (1, H, W); the leading -1 keeps that singleton dim unchanged.
                val_outputs = Resize([-1, *val_labels.shape[-2:]])(val_outputs)
                # The per-batch loss has its own name so the `val_loss` history list survives.
                batch_val_loss = loss_function(val_outputs, val_labels)
                loss_val += batch_val_loss.item()
                # Logits -> binary masks (sigmoid + 0.5 threshold), one tensor per sample.
                val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]

                # current validation: show the first sample of the 17th validation batch
                if val_step == 16:
                    print("val loss", batch_val_loss)
                    visualize(
                        image=val_images[0].cpu().permute(1, 2, 0),
                        ground_truth_mask=val_labels[0].cpu().permute(1, 2, 0),
                        predicted_mask=val_outputs[0].cpu().permute(1, 2, 0),
                    )
                val_step += 1

                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)

            # mean Dice loss over the validation batches
            val_loss_ave = loss_val / max(val_step, 1)
            print("val_loss = ", val_loss_ave)

            # aggregate the final mean dice result, then reset for the next validation round
            mean_dice = dice_metric.aggregate().item()
            print("val_mean_dice = ", mean_dice)
            dice_metric.reset()
            metric_values.append(mean_dice)
            val_loss.append(val_loss_ave)
            writer.add_scalar("val_loss", val_loss_ave, epoch + 1)
            writer.add_scalar("val_mean_dice", mean_dice, epoch + 1)

            if val_loss_ave < best_metric:
                best_metric = val_loss_ave
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), f"best_{total_epochs}_epochs_{model_name}.pth")
                print("saved new best metric model")
            print(
                "current epoch: {} current val mean dice loss: {:.4f} best val mean dice loss: {:.4f} at epoch {}".format(
                    epoch + 1, val_loss_ave, best_metric, best_metric_epoch
                )
            )

# Flush pending TensorBoard events to disk.
writer.close()
print(f"training completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

# Reproduce Validation Data

## -load previous model

In [None]:
# Load the checkpoint the training loop writes (same file name as its torch.save call);
# fall back to "<encoder_name>.pth" when that file is not present.
checkpoint_path = f"best_{total_epochs}_epochs_{model_name}.pth"
if not os.path.exists(checkpoint_path):
    checkpoint_path = f"{encoder_name}.pth"
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
model.load_state_dict(torch.load(checkpoint_path, map_location=device))

## -reproduce validation data

In [None]:
# Re-run the validation set through the loaded model and report the mean Dice score.
model.eval()
dice_metric.reset()
with torch.no_grad():
    for val_data in val_loader:
        val_images, val_labels = val_data["img"].to(device), val_data["seg"].to(device)

        val_outputs, _ = model(val_images)  # forward; the second (aux head) output is unused
        # Resize logits to the label resolution so the metric compares equal shapes
        # (the batch dim is treated as channel; -1 keeps the singleton dim unchanged).
        val_outputs = Resize([-1, *val_labels.shape[-2:]])(val_outputs)
        val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]

        # compute metric for current iteration and show the running mean so far
        dice_metric(y_pred=val_outputs, y=val_labels)
        print(dice_metric.aggregate())

    # aggregate the final mean dice result
    metric = dice_metric.aggregate().item()
    print("metric = ", metric)
    # reset the status for next validation round
    dice_metric.reset()

# Try Public Data

In [None]:
# load data: the sorted public test images, each wrapped as {"img": path} for the MONAI dataset.
tempdir = data_path + "Public_Image/"
test_images = sorted(glob.glob(os.path.join(tempdir, "*.jpg")))
print(f" {len(test_images)} test_images")

test_files = [dict(img=path) for path in test_images]

In [None]:
# define transform: same image preprocessing as training/validation, without masks.
image_keys = ("img",)
test_transforms = Compose([
    LoadImaged(keys=image_keys),
    AsChannelFirstd(keys=image_keys),
    ScaleIntensityd(keys=image_keys),
    Resized(keys=image_keys, spatial_size=[800, 800]),
    EnsureTyped(keys=image_keys),
])
test_ds = monai.data.Dataset(data=test_files, transform=test_transforms)
# One image per batch: the inference cell saves test_outputs[0] for each step.
test_loader = DataLoader(test_ds, collate_fn=list_data_collate, batch_size=1)

In [None]:
# Same sorted file list as the loading cell; indexed by batch position to recover file stems.
pub_data = sorted(glob.glob(tempdir + "*.jpg"))
pub_data[0].rsplit("/", 1)[-1].partition(".")[0]

In [None]:
# Predict a binary mask for every public image and save it as a PNG under <data_path>/Predict/.
output_dir = os.path.join(data_path, "Predict/")
original_size = [1716, 942]  # assumed original image size in MONAI spatial order — TODO confirm

model.eval()
with torch.no_grad():
    for i, test_data in enumerate(test_loader):
        # Own name so the `test_images` file list from the loading cell is not overwritten.
        batch_images = test_data["img"].to(device)

        test_outputs, _ = model(batch_images)  # forward; the aux head output is unused
        # Resize back to the original size; the batch dim is treated as channel, -1 keeps dim 1.
        test_outputs = Resize([-1, *original_size])(test_outputs)
        # Same post-processing as validation: sigmoid + 0.5 threshold -> binary mask,
        # instead of writing clipped raw logits.
        pred_mask = post_trans(test_outputs[0].cpu())

        stem = pub_data[i].split('/')[-1].split('.')[0]
        saverPD = SaveImage(output_dir=output_dir, output_ext=".png", output_postfix=stem, scale=255, separate_folder=False)
        saverPD(pred_mask)

# Try Private Data

In [None]:
# load data: the sorted private test images, each wrapped as {"img": path} for the MONAI dataset.
tempdir = data_path + "Private_Image/"
test_images = sorted(glob.glob(os.path.join(tempdir, "*.jpg")))
print(f" {len(test_images)} test_images")

test_files = [dict(img=path) for path in test_images]

In [None]:
# define transform: same image preprocessing as training/validation, without masks.
image_keys = ("img",)
test_transforms = Compose([
    LoadImaged(keys=image_keys),
    AsChannelFirstd(keys=image_keys),
    ScaleIntensityd(keys=image_keys),
    Resized(keys=image_keys, spatial_size=[800, 800]),
    EnsureTyped(keys=image_keys),
])
test_ds = monai.data.Dataset(data=test_files, transform=test_transforms)
# One image per batch: the inference cell saves test_outputs[0] for each step.
test_loader = DataLoader(test_ds, collate_fn=list_data_collate, batch_size=1)

In [None]:
# Same sorted file list as the loading cell; indexed by batch position to recover file stems.
pri_data = sorted(glob.glob(tempdir + "*.jpg"))
pri_data[0].rsplit("/", 1)[-1].partition(".")[0]

In [None]:
# Predict a binary mask for every private image and save it as a PNG under <data_path>/Predict/.
output_dir = os.path.join(data_path, "Predict/")
original_size = [1716, 942]  # assumed original image size in MONAI spatial order — TODO confirm

model.eval()
with torch.no_grad():
    for i, test_data in enumerate(test_loader):
        # Own name so the `test_images` file list from the loading cell is not overwritten.
        batch_images = test_data["img"].to(device)

        test_outputs, _ = model(batch_images)  # forward; the aux head output is unused
        # Resize back to the original size; the batch dim is treated as channel, -1 keeps dim 1.
        test_outputs = Resize([-1, *original_size])(test_outputs)
        # Same post-processing as validation: sigmoid + 0.5 threshold -> binary mask,
        # instead of writing clipped raw logits.
        pred_mask = post_trans(test_outputs[0].cpu())

        stem = pri_data[i].split('/')[-1].split('.')[0]
        saverPD = SaveImage(output_dir=output_dir, output_ext=".png", output_postfix=stem, scale=255, separate_folder=False)
        saverPD(pred_mask)

# Change File Name

In [None]:
# Rename each saved prediction by dropping everything up to the first "_" in its file name.
# Assumes SaveImage wrote "<prefix>_<source stem>.png" (the postfix is the source stem set in
# the inference cells) — TODO confirm the prefix itself contains no "_".
# Glob the PNG files *inside* Predict/ (globbing the directory path only returns the directory).
predict = sorted(glob.glob(os.path.join(data_path, "Predict", "*.png")))
print(len(predict))

for pred in predict:
    # os.path.split keeps absolute paths absolute (joining split("/") parts drops the leading "/").
    folder, filename = os.path.split(pred)
    os.rename(pred, os.path.join(folder, filename.split("_", 1)[-1]))