# Lesion Segmentation with SegResNet

## Setup imports

In [None]:
import torch
import numpy as np
import glob
import os
import logging
import time
import matplotlib.pyplot as plt
%matplotlib inline

from monai.config import print_config
from monai.data import ArrayDataset, decollate_batch, DataLoader
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.networks.nets import SegResNet
from monai.transforms import (
    Compose,
    LoadImage,
    EnsureChannelFirst,
    Resize,
    NormalizeIntensity,
    Activations,
    AsDiscrete,
    SqueezeDim,
    SaveImage,
)
from monai.utils import first, set_determinism

# Print MONAI's configuration report so the library versions used for this run
# are recorded in the notebook output.
print_config()

## Set deterministic training for reproducibility

In [5]:
# Fix the random seed through MONAI's determinism helper so that repeated runs
# of the notebook are comparable.
set_determinism(seed=0)

## Setup data directory and data

In [None]:
# Training data location. Images ("brain*.nii.gz") and masks
# ("lesionmask*.nii.gz") are paired purely by their sorted order, so the two
# lists must line up one-to-one.
root_dir = "C:\\LesionSegmentation\\train"
print(root_dir)
images = sorted(glob.glob(os.path.join(root_dir, "brain*.nii.gz")))
segs = sorted(glob.glob(os.path.join(root_dir, "lesionmask*.nii.gz")))

# A count mismatch would silently pair images with the wrong masks further
# down, so stop here with an explicit message instead.
if len(images) != len(segs):
    raise ValueError(
        f"found {len(images)} images but {len(segs)} masks in {root_dir}; "
        "every image needs exactly one mask"
    )

## Setup logging

In [7]:
# Route INFO-level (and above) records from the root logger to "demo.log",
# stored alongside the training data; each record is prefixed with a timestamp.
log_file = os.path.join(root_dir, "demo.log")
logging.basicConfig(
    format="%(asctime)s -  %(message)s",
    level=logging.INFO,
    filename=log_file,
)
logger = logging.getLogger()

## Setup transforms, dataset

In [None]:
# Preprocessing pipelines. Training and validation use the same recipe, but
# every split still gets its own Compose instance.
def _image_pipeline():
    """Build the image pipeline: load, channel-first, resize to 64^3, normalize."""
    return Compose(
        [
            LoadImage(image_only=True),
            EnsureChannelFirst(),
            Resize((64, 64, 64), mode="trilinear"),
            NormalizeIntensity(nonzero=True, channel_wise=True),
        ]
    )


def _mask_pipeline():
    """Build the mask pipeline: load, channel-first, resize to 64^3 (nearest)."""
    return Compose(
        [
            LoadImage(image_only=True),
            EnsureChannelFirst(),
            # Nearest-neighbour resizing, unlike the trilinear mode used for images
            Resize((64, 64, 64), mode="nearest"),
        ]
    )


train_imtrans = _image_pipeline()
train_segtrans = _mask_pipeline()
val_imtrans = _image_pipeline()
val_segtrans = _mask_pipeline()

# NIfTI datasets and loaders: the first 500 sorted image/mask pairs are used
# for training, everything after them for validation.
pin = torch.cuda.is_available()
train_ds = ArrayDataset(images[:500], train_imtrans, segs[:500], train_segtrans)
train_loader = DataLoader(
    train_ds,
    batch_size=5,
    shuffle=True,
    num_workers=8,
    pin_memory=pin,
)
val_ds = ArrayDataset(images[500:], val_imtrans, segs[500:], val_segtrans)
val_loader = DataLoader(
    val_ds,
    batch_size=5,
    num_workers=2,
    pin_memory=pin,
)

# Pull one training batch and report its tensor shapes as a sanity check
im, seg = first(train_loader)
print(im.shape, seg.shape)

## Visualize an example image and mask

In [None]:
# Show slice 30 (last axis) of the first sample in the batch: image on the
# left, its mask on the right.
fig = plt.figure("Example image and mask for training", (12, 6))
ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title("Image")
ax1.imshow(im[0, 0, :, :, 30].detach().cpu(), cmap="gray")
ax1.axis("off")
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title("Mask")
ax2.imshow(seg[0, 0, :, :, 30].detach().cpu(), cmap="gray")
ax2.axis("off")
# Save before showing so the file is written from the fully drawn figure
plt.savefig(os.path.join(root_dir, "image_mask.tif"))
# plt.show must be *called*; the bare name only evaluates to the function object
plt.show()

## Create model

In [23]:
# Training hyper-parameters
max_epochs = 100
val_interval = 1  # run validation every `val_interval` epochs
lr = 1e-4

# Prefer the GPU but fall back to the CPU so the notebook still runs on
# machines without CUDA. To force another backend, override `device` here,
# e.g. torch.device("mps") or torch.device("cpu").
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Create SegResNet, DiceLoss and Adam optimizer
model = SegResNet(
    blocks_down=(1, 2, 2, 4),
    blocks_up=(1, 1, 1),
    init_filters=16,
    in_channels=1,
    out_channels=1,
    dropout_prob=0.2,
).to(device)

# The network emits raw logits: the loss applies the sigmoid itself
# (sigmoid=True), and post_pred below does the same before thresholding.
loss_function = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, sigmoid=True)
optimizer = torch.optim.Adam(model.parameters(), lr, weight_decay=1e-5)
# Cosine schedule spanning the whole run (the training loop steps it once per epoch)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)

dice_metric = DiceMetric(include_background=True, reduction="mean")

# Post-processing applied before computing Dice: predictions go through
# sigmoid + 0.5 threshold, labels are thresholded at 0.5.
post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
post_label = Compose([AsDiscrete(threshold=0.5)])

# Gradient scaler for CUDA training; it is only created when CUDA is available
if torch.cuda.is_available():
    scaler = torch.cuda.amp.GradScaler()

## Execute training process

In [None]:
# Book-keeping for the best validation score and the per-epoch curves
best_metric = -1
best_metric_epoch = -1
best_metrics_epochs_and_time = [[], [], []]  # [best dice, epoch, elapsed seconds]
epoch_loss_values = []
epoch_metric_values = []  # training Dice, one value per epoch
metric_values = []  # validation Dice, one value per validation run

# The gradient scaler only exists (and only works) for CUDA training, so key
# the branch on the device actually in use rather than on CUDA availability.
# NOTE(review): the scaler is applied without torch.autocast, so the forward
# pass still runs in full precision; wrap the forward/loss in autocast if mixed
# precision was intended.
use_scaler = device.type == "cuda"
steps_per_epoch = len(train_loader)  # counts a trailing partial batch as well

total_start = time.time()
for epoch in range(max_epochs):
    epoch_start = time.time()
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step_start = time.time()
        step += 1
        im, seg = batch_data
        inputs, labels = (
            im.to(device),
            seg.to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        if use_scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        epoch_loss += loss.item()

        # Accumulate the training Dice over every batch of the epoch, not just
        # the last one, so the reported value describes the whole epoch.
        with torch.no_grad():
            train_preds = [post_pred(i) for i in decollate_batch(outputs.detach())]
            train_labels = [post_label(i) for i in decollate_batch(labels)]
            dice_metric(y_pred=train_preds, y=train_labels)

        print(
            f"{step}/{steps_per_epoch}"
            f", train_loss: {loss.item():.4f}"
            f", step time: {(time.time() - step_start):.4f}"
        )
    lr_scheduler.step()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    logger.info(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    metric = dice_metric.aggregate().item()
    epoch_metric_values.append(metric)
    # Clear the accumulated training results; otherwise they would be averaged
    # into the validation Dice computed with the same metric object below.
    dice_metric.reset()

    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                im, seg = val_data
                val_inputs, val_labels = (
                    im.to(device),
                    seg.to(device),
                )
                val_outputs = model(val_inputs)
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                dice_metric(y_pred=val_outputs, y=val_labels)

            metric = dice_metric.aggregate().item()
            metric_values.append(metric)
            dice_metric.reset()

            # Keep the weights of the best-scoring epoch on disk
            if metric > best_metric:
                best_metric = metric
                best_metric_epoch = epoch + 1
                best_metrics_epochs_and_time[0].append(best_metric)
                best_metrics_epochs_and_time[1].append(best_metric_epoch)
                best_metrics_epochs_and_time[2].append(time.time() - total_start)
                torch.save(
                    model.state_dict(),
                    os.path.join(root_dir, "best_metric_model.pth"),
                )
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
                f"\nbest mean dice: {best_metric:.4f}"
                f" at epoch: {best_metric_epoch}"
            )
            logger.info(f"epoch {epoch + 1} mean dice: {metric:.4f}")
    print(f"time consuming of epoch {epoch + 1} is: {(time.time() - epoch_start):.4f}")

## Plot loss and metric

In [None]:
# NOTE: total_start is set when the training cell starts, so this value also
# includes any idle time between the end of training and running this cell.
total_time = time.time() - total_start
print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}, total time: {total_time}.")
logger.info(
    f"best_metric: {best_metric:.4f} at epoch {best_metric_epoch}, "
    f"total time to train: {total_time}"
)

fig = plt.figure("Performance in training", (12, 6))

# Left panel: mean training loss per epoch
ax1 = fig.add_subplot(1, 2, 1)
ax1.set_title("Loss")
x = list(range(1, len(epoch_loss_values) + 1))
y = epoch_loss_values
ax1.plot(x, y, color="red")
ax1.set_xlabel("Epoch")
ax1.set_ylabel("Loss")

# Right panel: training vs. validation Dice. Each x-axis is derived from the
# series it belongs to; validation points sit at multiples of val_interval.
ax2 = fig.add_subplot(1, 2, 2)
ax2.set_title("DSC")
x1 = list(range(1, len(epoch_metric_values) + 1))
x2 = [val_interval * (i + 1) for i in range(len(metric_values))]
y1 = epoch_metric_values
y2 = metric_values
ax2.plot(x1, y1, color="red")
ax2.plot(x2, y2, color="blue")
ax2.set_xlabel("Epoch")
ax2.set_ylabel("DSC")
ax2.legend(["Train", "Validation"])
plt.savefig(os.path.join(root_dir, "performance.tif"))
# plt.show must be *called*; the bare name only evaluates to the function object
plt.show()

## Check best model output

In [None]:
# Restore the best checkpoint; map_location lets a checkpoint saved on one
# device be loaded when this cell runs on another (e.g. GPU-trained, CPU-viewed).
model.load_state_dict(
    torch.load(os.path.join(root_dir, "best_metric_model.pth"), map_location=device)
)
model.eval()

# Select one validation sample to evaluate and visualize the model output.
# Load it once and reuse it below instead of re-reading the files per access.
sample_idx = 6
val_image, val_mask = val_ds[sample_idx]
with torch.no_grad():
    val_input = val_image.unsqueeze(0).to(device)  # add the batch dimension
    val_output = model(val_input)
    val_output = post_pred(val_output[0])

# Slice 30 (last axis) of the image, its ground-truth mask and the prediction
fig = plt.figure("Actual vs. Predicted", (12, 6))
ax1 = fig.add_subplot(1, 3, 1)
ax1.set_title("Image")
ax1.imshow(val_image[0, :, :, 30].detach().cpu(), cmap="gray")
ax1.axis("off")
ax2 = fig.add_subplot(1, 3, 2)
ax2.set_title("Actual mask")
ax2.imshow(val_mask[0, :, :, 30].detach().cpu())
ax2.axis("off")
ax3 = fig.add_subplot(1, 3, 3)
ax3.set_title("Predicted mask")
ax3.imshow(val_output[0, :, :, 30].detach().cpu())
ax3.axis("off")
plt.savefig(os.path.join(root_dir, "actual_predicted.tif"))
# plt.show must be *called*; the bare name only evaluates to the function object
plt.show()

## Apply best model

In [None]:
# Define NIfTI dataset for the unlabeled test images; they go through the same
# preprocessing as the validation images.
test_root_dir = "C:\\LesionSegmentation\\test"
test_images = sorted(glob.glob(os.path.join(test_root_dir, "brain*.nii.gz")))
test_ds = ArrayDataset(test_images, val_imtrans)

# Apply the best model and save predictions. map_location lets a checkpoint
# saved on one device be loaded when inference runs on another.
model.load_state_dict(
    torch.load(os.path.join(root_dir, "best_metric_model.pth"), map_location=device)
)
model.eval()
# Sigmoid + 0.5 threshold, then drop the leading (channel) dimension
post_testpred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5), SqueezeDim(dim=0)])
# Write "<name>_seg.nii.gz" files as uint8 directly into the test folder
saver = SaveImage(output_dir=test_root_dir, output_ext=".nii.gz", output_postfix="seg",
                  separate_folder=False, print_log=False,
                  resample=True, mode='nearest', output_dtype=np.uint8)

with torch.no_grad():
    for idx in range(len(test_ds)):
        test_input = test_ds[idx].unsqueeze(0).to(device)  # add the batch dimension
        test_output = model(test_input)[0]
        test_output = post_testpred(test_output)

        # To save as NIfTI files
        saver(test_output)

        # To save Numpy array files
        # fname = os.path.join(test_root_dir, os.path.split(test_images[idx])[1].replace(".nii.gz", "_seg.npy"))
        # with open(fname, 'wb') as f:
        #     np.save(f, test_output)