### Import Modules

In [None]:
import os
import csv
import nibabel as nib
import numpy as np
import glob
import torch
import torch.nn.functional as F
import time
import logging
import matplotlib.pyplot as plt
import sys
from sklearn.model_selection import train_test_split
from monai.data import Dataset, DataLoader, decollate_batch
from monai.transforms import (
    Compose,
    LoadImage, LoadImaged,
    EnsureChannelFirstd,
    Resize, Resized,
    ScaleIntensityd,
    RandRotate90d,
    RandFlipd,
    Activations,
    AsDiscrete
)
from monai.networks.nets import UNet, VNet, DynUNet, AttentionUnet, ResNet, SegResNet, UNETR, SwinUNETR
from monai.losses import DiceLoss
from monai.metrics import DiceMetric
from monai.utils import first, set_determinism
from monai.visualize.utils import blend_images
%env CUDA_VISIBLE_DEVICES=0

### Define Functions and Classes

In [None]:
def load_image_with_metadata(image_file):
    """Read one image and bundle it with the metadata needed to write predictions back.

    Args:
        image_file: Path of the image file to read.

    Returns:
        dict with keys ``"brain"`` (the loaded image), ``"affine"`` (taken from the
        reader's metadata), ``"original_dim"`` (the loaded image's shape, later used
        to undo resizing) and ``"filename"`` (basename of ``image_file``).
    """
    image, meta = LoadImage(image_only=False)(image_file)
    record = {
        "brain": image,
        "affine": meta["affine"],
        "original_dim": image.shape,
        "filename": os.path.basename(image_file),
    }
    return record

def load_data(data_dir, resize_dim, batch_size, test_size=0.2, inference=False):
    """Build MONAI data loaders for training/validation, or a loader for inference.

    Args:
        data_dir: Directory with a ``Brain`` subfolder (and, for training, a
            ``Lesion`` subfolder) of ``*.nii.gz`` files. Brains and lesions are
            paired by sorted filename order.
        resize_dim: Target spatial size for every volume, or ``None`` to keep the
            size the images were loaded with.
        batch_size: Batch size of the training/validation loaders (unused when
            ``inference`` is True; the test loader always uses batch size 1).
        test_size: Fraction of the pairs held out for validation (training only).
        inference: If True, build a single loader over the ``Brain`` images only.

    Returns:
        ``(train_loader, val_loader)`` when ``inference`` is False, otherwise the
        test loader.

    Raises:
        ValueError: If the number of brain and lesion files differs (training only).
    """
    pin_memory = torch.cuda.is_available()
    train_transforms = [
        LoadImaged(keys=["brain", "lesion"], image_only=True),
        EnsureChannelFirstd(keys=["brain", "lesion"]),
        ScaleIntensityd(keys="brain"),
        RandRotate90d(keys=["brain", "lesion"], prob=0.8, spatial_axes=[0, 2]),
        RandFlipd(keys=["brain", "lesion"], spatial_axis=0, prob=0.5),
    ]
    val_transforms = [
        LoadImaged(keys=["brain", "lesion"], image_only=True),
        EnsureChannelFirstd(keys=["brain", "lesion"]),
        ScaleIntensityd(keys="brain"),
    ]
    # Test images are already loaded by load_image_with_metadata, hence no LoadImaged.
    # Resizing is inserted below only when resize_dim is given; an unconditional
    # Resized here would resize twice and fail outright when resize_dim is None.
    test_transforms = [
        EnsureChannelFirstd(keys="brain"),
        ScaleIntensityd(keys="brain"),
    ]
    if resize_dim is not None:
        # Lesion masks use nearest-neighbour interpolation so label values are preserved.
        trainval_resize = Resized(keys=["brain", "lesion"], spatial_size=resize_dim, mode=["trilinear", "nearest"])
        train_transforms.insert(2, trainval_resize)
        val_transforms.insert(2, trainval_resize)
        test_transforms.insert(1, Resized(keys="brain", spatial_size=resize_dim, mode="trilinear"))

    brains = sorted(glob.glob(os.path.join(data_dir, "Brain", "*.nii.gz")))
    if inference:
        data_dicts = [load_image_with_metadata(brain_name) for brain_name in brains]
        test_ds = Dataset(data=data_dicts, transform=Compose(test_transforms))
        # Batch size 1: batching brings less benefit at inference time than during
        # training, and each volume carries its own shape/affine for write-back.
        return DataLoader(test_ds, batch_size=1, num_workers=0, pin_memory=pin_memory)

    lesions = sorted(glob.glob(os.path.join(data_dir, "Lesion", "*.nii.gz")))
    # Pairing is purely by sorted order, so a missing file would silently misalign
    # every subsequent pair (and zip would truncate); fail loudly instead.
    if len(brains) != len(lesions):
        raise ValueError(
            f"Found {len(brains)} brain files but {len(lesions)} lesion files in {data_dir}"
        )
    data_dicts = [
        {"brain": brain_name, "lesion": lesion_name}
        for brain_name, lesion_name in zip(brains, lesions)
    ]
    train_files, val_files = train_test_split(data_dicts, test_size=test_size, random_state=42)
    train_ds = Dataset(data=train_files, transform=Compose(train_transforms))
    val_ds = Dataset(data=val_files, transform=Compose(val_transforms))
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=pin_memory)
    val_loader = DataLoader(val_ds, batch_size=batch_size, num_workers=0, pin_memory=pin_memory)
    return train_loader, val_loader

def get_model(model_name, img_size, num_classes=1):
    """Instantiate a single-input-channel 3D segmentation network by name.

    Args:
        model_name: One of ``"UNet"``, ``"VNet"``, ``"AttentionUnet"``,
            ``"SegResNet"``, ``"UNETR"`` or ``"SwinUNETR"``.
        img_size: Spatial input size; only consumed by ``UNETR`` and ``SwinUNETR``.
        num_classes: Number of output channels.

    Returns:
        The freshly constructed network.

    Raises:
        ValueError: If ``model_name`` is not a supported name.
    """
    # Zero-argument factories so that only the requested network is ever built.
    factories = {
        "UNet": lambda: UNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=num_classes,
            channels=(32, 64, 128, 256),
            strides=(2, 2, 2),
            num_res_units=2,
        ),
        "VNet": lambda: VNet(
            spatial_dims=3,
            in_channels=1,
            out_channels=num_classes,
            dropout_dim=3,
            dropout_prob_down=0.2,
            dropout_prob_up=(0.2, 0.2, 0.2),
        ),
        "AttentionUnet": lambda: AttentionUnet(
            spatial_dims=3,
            in_channels=1,
            out_channels=num_classes,
            channels=(32, 64, 128, 256),
            strides=(2, 2, 2),
        ),
        "SegResNet": lambda: SegResNet(
            spatial_dims=3,
            init_filters=32,
            in_channels=1,
            out_channels=num_classes,
            dropout_prob=0.3,
            use_conv_final=True,
            blocks_down=(1, 2, 2),
            blocks_up=(1, 1),
        ),
        "UNETR": lambda: UNETR(
            in_channels=1,
            out_channels=num_classes,
            img_size=img_size,
            feature_size=16,
            hidden_size=768,
            mlp_dim=3072,
            num_heads=12,
            pos_embed='perceptron',
            norm_name='instance',
            res_block=True,
            dropout_rate=0.1,
        ),
        "SwinUNETR": lambda: SwinUNETR(
            img_size=img_size,
            in_channels=1,
            out_channels=num_classes,
            feature_size=24,
            use_checkpoint=True,
        ),
    }
    factory = factories.get(model_name)
    if factory is None:
        raise ValueError(f"Unsupported model name: {model_name}")
    return factory()

def train_one_epoch(model, device, train_loader, optimizer, criterion, scaler, metric, post_pred):
    """Run one training epoch and return ``(mean_batch_loss, aggregated_metric)``.

    Args:
        model: Network to train; switched to train mode.
        device: Device each batch is moved to.
        train_loader: Iterable of dicts with ``"brain"`` and ``"lesion"`` tensors.
        optimizer: Optimizer stepping ``model``'s parameters.
        criterion: Loss ``criterion(outputs, labels)``.
        scaler: GradScaler for mixed precision, or ``None`` for full precision.
        metric: Cumulative metric (reset here, aggregated at the end of the epoch).
        post_pred: Transform turning one decollated output into a prediction mask.

    Raises:
        ValueError: If ``train_loader`` yields no batches.
    """
    model.train()
    metric.reset()
    use_amp = scaler is not None
    epoch_loss = 0.0
    num_batches = 0
    for batch_data in train_loader:
        images = batch_data["brain"].to(device)
        labels = batch_data["lesion"].to(device)
        optimizer.zero_grad()
        # The forward pass must run under autocast together with the loss; with only
        # the loss inside, the network itself never ran in mixed precision.
        with torch.cuda.amp.autocast(enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)
        if use_amp:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
        # Metric bookkeeping needs no autograd graph.
        with torch.no_grad():
            preds = [post_pred(i) for i in decollate_batch(outputs.detach())]
            metric(y_pred=preds, y=labels)
    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")
    epoch_metric = metric.aggregate().item()
    return epoch_loss / num_batches, epoch_metric

def validate_one_epoch(model, device, val_loader, metric, post_pred):
    """Evaluate ``model`` on ``val_loader`` without gradients and return the metric as a float.

    Args:
        model: Network to evaluate; switched to eval mode.
        device: Device each batch is moved to.
        val_loader: Iterable of dicts with ``"brain"`` and ``"lesion"`` tensors.
        metric: Cumulative metric, reset before the loop and aggregated after it.
        post_pred: Transform turning one decollated output into a prediction mask.
    """
    model.eval()
    metric.reset()
    with torch.no_grad():
        for batch in val_loader:
            brain = batch["brain"].to(device)
            lesion = batch["lesion"].to(device)
            predictions = list(map(post_pred, decollate_batch(model(brain))))
            metric(y_pred=predictions, y=lesion)
    return metric.aggregate().item()

class EarlyStopping:
    """Signal when a higher-is-better metric has stopped improving.

    Call the instance once per validation with the latest metric value. A value
    counts as an improvement only if it exceeds ``best_score`` by more than
    ``delta``; after ``patience`` consecutive calls without improvement,
    ``early_stop`` becomes True.

    Attributes:
        patience: Consecutive non-improving calls tolerated before stopping.
        delta: Minimum increase over ``best_score`` that counts as an improvement.
        best_score: Best value seen so far (``None`` before the first call).
        early_stop: True once ``patience`` has been exhausted.
        counter: Current number of consecutive non-improving calls.
    """

    def __init__(self, patience=30, delta=0):
        self.patience = patience
        self.delta = delta
        self.best_score = None
        self.early_stop = False
        self.counter = 0

    def __call__(self, metric):
        """Record ``metric`` and update ``best_score``, ``counter`` and ``early_stop``."""
        score = metric
        # Only a genuine improvement may replace best_score. Letting any score within
        # ``delta`` below (or equal to) the best overwrite it and reset the counter
        # meant a flat or slowly declining metric never triggered early stopping.
        if self.best_score is None or score > self.best_score + self.delta:
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

def train_model(model_dir, model, device, train_loader, val_loader, logger,
        criterion, metric, post_pred, max_epochs=100, learning_rate=1e-4, weight_decay=1e-5, val_interval=1, es_patience=30):
    """Train ``model`` with Adam, a cosine LR schedule, early stopping and best checkpointing.

    Every ``val_interval`` epochs the model is validated. Whenever the validation
    metric improves, the weights are saved to ``<model_dir>/BestMetricModel.pth``
    and snapshotted in memory. A GradScaler (mixed precision) is used when CUDA is
    available.

    Args:
        model_dir: Directory the best checkpoint is written to.
        model: Network to train (assumed to already live on ``device``).
        device: Device batches are moved to.
        train_loader, val_loader: Training and validation loaders.
        logger: Logger receiving per-epoch progress messages.
        criterion: Training loss.
        metric: Cumulative metric used for both training and validation.
        post_pred: Transform turning one output into a prediction mask.
        max_epochs: Maximum number of epochs (also the cosine schedule's T_max).
        learning_rate, weight_decay: Adam hyper-parameters.
        val_interval: Validate every this many epochs.
        es_patience: Patience passed to ``EarlyStopping``.

    Returns:
        ``(model, epoch_loss_values, epoch_metric_values, metric_values)``: the model
        holding the best validated weights (if any validation ran), the per-epoch
        training losses, the per-epoch training metrics and the per-validation metrics.
    """
    import copy  # local import: snapshot the best weights

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_epochs)
    scaler = torch.cuda.amp.GradScaler() if torch.cuda.is_available() else None
    start_time = time.time()
    best_metric = -1
    best_metric_epoch = -1
    best_model_state = None
    early_stopping = EarlyStopping(patience=es_patience, delta=0)
    epoch_loss_values, epoch_metric_values, metric_values = [], [], []
    for epoch in range(max_epochs):
        epoch_start_time = time.time()
        epoch_loss, epoch_metric = train_one_epoch(model, device, train_loader, optimizer, criterion, scaler, metric, post_pred)
        epoch_loss_values.append(epoch_loss)
        epoch_metric_values.append(epoch_metric)
        val_metric = None  # stays None on epochs without validation
        if (epoch + 1) % val_interval == 0:
            val_metric = validate_one_epoch(model, device, val_loader, metric, post_pred)
            metric_values.append(val_metric)
            if val_metric > best_metric:
                best_metric = val_metric
                best_metric_epoch = epoch + 1
                # state_dict() returns references to the live parameter tensors; without
                # a copy, later optimizer steps overwrite the "best" snapshot.
                best_model_state = copy.deepcopy(model.state_dict())
                torch.save(model.state_dict(), os.path.join(model_dir, "BestMetricModel.pth"))
                logger.info(f"Best DSC: {best_metric:.4f} at epoch {best_metric_epoch}")
            early_stopping(val_metric)
        epoch_end_time = time.time()
        val_text = f"{val_metric:.4f}" if val_metric is not None else "N/A"
        logger.info(f"Epoch {epoch + 1} completed for {(epoch_end_time - epoch_start_time)/60:.2f} mins - Training loss: {epoch_loss:.4f}, Training DSC: {epoch_metric:.4f}, Validation DSC: {val_text}")
        if early_stopping.early_stop:
            logger.info(f"Early stopping triggered at epoch {epoch + 1}")
            print(f"; Early stopping triggered at epoch {epoch + 1}", end="")
            break
        lr_scheduler.step()
        sys.stdout.write(f"\rEpoch {epoch + 1}/{max_epochs} completed")
        sys.stdout.flush()
    total_time = time.time() - start_time
    logger.info(f"Best DSC: {best_metric:.3f} at epoch {best_metric_epoch}; Total time consumed: {total_time/60:.2f} mins")
    print(f"\nBest DSC: {best_metric:.3f} at epoch {best_metric_epoch}; Total time consumed: {total_time/60:.2f} mins")
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    return model, epoch_loss_values, epoch_metric_values, metric_values

def plot_metric_values(model_dir, epoch_loss_values, epoch_metric_values, metric_values, val_interval=1):
    """Plot the training loss and the training/validation DSC curves side by side.

    The figure is saved as ``<model_dir>/Performance.png`` at 300 dpi. Validation
    points are placed at epochs ``val_interval, 2 * val_interval, ...``.
    """
    loss_epochs = list(range(1, len(epoch_loss_values) + 1))
    train_dsc_epochs = list(range(1, len(epoch_metric_values) + 1))
    val_dsc_epochs = [val_interval * n for n in range(1, len(metric_values) + 1)]

    _, (loss_ax, dsc_ax) = plt.subplots(1, 2, figsize=(8, 5))

    loss_ax.plot(loss_epochs, epoch_loss_values, label='Training Loss', color='red')
    loss_ax.set_title('Training Loss')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')

    dsc_ax.plot(train_dsc_epochs, epoch_metric_values, label='Training DSC', color='red')
    dsc_ax.plot(val_dsc_epochs, metric_values, label='Validation DSC', color='blue')
    dsc_ax.set_title('Training DSC vs. Validation DSC')
    dsc_ax.set_xlabel('Epoch')
    dsc_ax.set_ylabel('DSC')
    dsc_ax.legend()

    plt.tight_layout()
    plt.savefig(os.path.join(model_dir, "Performance.png"), dpi=300)

def apply_best_model(model_dir, model, device, test_loader, post_pred, pred_dir, resize_dim):
    """Load ``<model_dir>/BestMetricModel.pth``, predict every test volume and save NIfTI masks.

    Each prediction is resized back to the image's original shape (when
    ``resize_dim`` is not None) and written to ``pred_dir`` under the input's
    filename, using the input's affine.

    Args:
        model_dir: Directory holding the best checkpoint.
        model: Network whose state dict matches the checkpoint.
        device: Device for inference; also used as ``map_location`` when loading.
        test_loader: Batch-size-1 loader yielding ``"brain"``, ``"affine"``,
            ``"original_dim"`` and ``"filename"`` entries.
        post_pred: Transform turning one network output into a prediction mask.
        pred_dir: Output directory (created if missing).
        resize_dim: Size the inputs were resized to, or None if they were not.
    """
    # map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
    checkpoint = torch.load(os.path.join(model_dir, "BestMetricModel.pth"), map_location=device)
    model.load_state_dict(checkpoint)
    model.eval()
    os.makedirs(pred_dir, exist_ok=True)
    with torch.no_grad():
        for data in test_loader:
            brain = data["brain"].to(device)
            affine = data["affine"][0]
            original_dim = tuple(t.item() for t in data["original_dim"])
            filename = data["filename"][0]
            # [0] drops the batch axis; the result is channel-first.
            output = post_pred(model(brain)[0])
            if resize_dim is not None:
                output = Resize(spatial_size=original_dim, mode="nearest")(output)
            # Remove only a singleton channel axis; a bare squeeze() would also drop
            # any spatial axis of size 1 and corrupt the saved geometry.
            mask = output.squeeze(0).detach().cpu().numpy()
            nib.save(nib.Nifti1Image(mask, affine), os.path.join(pred_dir, filename))

class GradCAM3D:
    """Grad-CAM saliency maps for 3D (5-D batched) segmentation networks.

    The quantity explained is the negated loss ``-criterion(output, lesion)``. The
    map therefore highlights activations of ``target_layer`` that lower the loss,
    i.e. that support the given lesion mask. This follows Grad-CAM's convention of
    explaining evidence *for* the target.

    Hooks are attached to ``target_layer`` only for the duration of ``__call__``,
    so creating many instances (e.g. by re-running a notebook cell) does not
    accumulate hooks on the model.

    Args:
        model: Network to explain; switched to eval mode.
        target_layer: Sub-module of ``model`` whose output activations and output
            gradients are combined.
        criterion: Callable ``criterion(output, lesion)`` returning a scalar loss.
    """

    def __init__(self, model, target_layer, criterion):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None  # gradient w.r.t. target_layer's output, set by save_gradient
        self.activations = None  # target_layer's output, set by save_activation
        self.model.eval()
        self.loss = criterion

    def save_activation(self, _module, _input, output):
        """Forward hook: keep a detached copy of the target layer's output."""
        self.activations = output.detach()

    def save_gradient(self, _module, _grad_input, grad_output):
        """Backward hook: keep the gradient w.r.t. the target layer's output."""
        self.gradients = grad_output[0].detach()

    def __call__(self, brain, lesion):
        """Return a Grad-CAM map scaled to [0, 1] with the spatial size of ``brain``.

        Args:
            brain: 5-D input batch ``(N, C, X, Y, Z)``.
            lesion: Target passed to ``criterion`` together with the model output.

        Returns:
            ``cam`` with the leading batch axis squeezed (shape ``(1, X, Y, Z)`` for N=1).
        """
        handles = [
            self.target_layer.register_forward_hook(self.save_activation),
            self.target_layer.register_full_backward_hook(self.save_gradient),
        ]
        try:
            self.model.zero_grad()
            output = self.model(brain)
            # Differentiate the negated loss so positive weights mark supporting features.
            score = -self.loss(output, lesion)
            score.backward()
        finally:
            for handle in handles:
                handle.remove()
        weights = torch.mean(self.gradients, dim=(2, 3, 4), keepdim=True)
        cam = torch.sum(weights * self.activations, dim=1, keepdim=True)
        cam = F.relu(cam)
        cam = F.interpolate(cam, size=brain.shape[2:], mode='trilinear', align_corners=False)
        cam = cam - torch.min(cam)
        peak = torch.max(cam)
        # An all-zero map would otherwise become NaN (0 / 0).
        if peak > 0:
            cam = cam / peak
        return cam.squeeze(0)

def apply_gradcam_to_sample(model, val_loader, target_layers, device, criterion, sample_index, slice_index):
    """Show brain, lesion and a Grad-CAM overlay for one validation sample.

    One figure row is drawn per entry of ``target_layers``.

    Args:
        model: Trained network.
        val_loader: Loader whose ``dataset[sample_index]`` is a dict with
            channel-first ``"brain"`` and ``"lesion"`` tensors.
        target_layers: Mapping of display name to a sub-module of ``model``.
        device: Device the sample is moved to.
        criterion: Loss handed to ``GradCAM3D``.
        sample_index: Index into ``val_loader.dataset``.
        slice_index: Index along the last spatial axis of the displayed slice.
    """
    # Fetch (and transform) the sample once instead of once per key.
    sample = val_loader.dataset[sample_index]
    brain = sample["brain"].to(device)
    lesion = sample["lesion"].to(device)
    brain_slice = torch.rot90(brain[0, :, :, slice_index], k=1, dims=(0, 1)).detach().cpu()
    lesion_slice = torch.rot90(lesion[0, :, :, slice_index], k=1, dims=(0, 1)).detach().cpu()
    batch_brain = brain.unsqueeze(0)
    batch_lesion = lesion.unsqueeze(0)
    # squeeze=False keeps ``axes`` 2-D even when only one target layer is given.
    _, axes = plt.subplots(len(target_layers), 3, figsize=(12, 4 * len(target_layers)), squeeze=False)
    for i, (layer_name, target_layer) in enumerate(target_layers.items()):
        grad_cam = GradCAM3D(model, target_layer, criterion)
        cam = grad_cam(batch_brain, batch_lesion)
        cam_slice = torch.rot90(cam[0, :, :, slice_index], k=1, dims=(0, 1)).detach().cpu()
        axes[i, 0].imshow(brain_slice, cmap='gray')
        axes[i, 0].set_title("Brain")
        axes[i, 0].axis('off')
        axes[i, 1].imshow(lesion_slice, cmap='gray')
        axes[i, 1].set_title("Lesion")
        axes[i, 1].axis('off')
        axes[i, 2].imshow(brain_slice, cmap='gray')
        axes[i, 2].imshow(cam_slice, cmap='jet', alpha=0.5)
        axes[i, 2].set_title(f"GradCAM: {layer_name}")
        axes[i, 2].axis('off')
    plt.tight_layout()
    plt.show()

def get_target_layers(model, layer_indices=(-1,)):
    """Select Conv2d/Conv3d sub-modules of ``model`` as Grad-CAM target layers.

    Convolutions are enumerated in ``model.modules()`` order.

    Args:
        model: Network to search.
        layer_indices: Indices into that list of convolutions; negative values
            count from the end. The last convolution is labelled ``"Final Conv"``,
            other negative indices ``"Conv k from End"`` and non-negative ones
            ``"Conv i"``.

    Returns:
        dict mapping display name to the selected layer, in ``layer_indices`` order.

    Raises:
        IndexError: If an index is out of range (including when ``model`` has no
            Conv2d/Conv3d layers at all).
    """
    conv_layers = [
        layer for layer in model.modules()
        if isinstance(layer, (torch.nn.Conv2d, torch.nn.Conv3d))
    ]
    n_conv = len(conv_layers)
    target_layers = {}
    for index in layer_indices:
        # Reject out-of-range indices instead of silently reusing a previous
        # iteration's layer_index / layer_name.
        if not -n_conv <= index < n_conv:
            raise IndexError(f"layer index {index} out of range for {n_conv} conv layers")
        layer_index = index % n_conv
        if layer_index == n_conv - 1:
            layer_name = "Final Conv"
        elif index < 0:
            layer_name = f"Conv {abs(index)} from End"
        else:
            layer_name = f"Conv {index}"
        target_layers[layer_name] = conv_layers[layer_index]
    return target_layers

### Prepare Inputs

In [None]:
# --- Experiment configuration ---
data_dir = "LesionSegmentation_2mm"
model_dir_prefix = "LesionSegmentation"
model_name = "SegResNet"  # supported: UNet, VNet, AttentionUnet, SegResNet, UNETR, SwinUNETR
num_classes = 1  # single output channel: binary segmentation
resize_dim = (64, 64, 64)  # spatial size volumes are resized to
test_size = 0.2  # fraction of training pairs held out for validation
batch_size = 5
max_epochs = 100
learning_rate = 1e-4
weight_decay = 1e-5
val_interval = 1  # validate every N epochs
es_patience = 30  # early-stopping patience

# --- Output directory, device and file logging ---
model_dir = "_".join((model_dir_prefix, model_name))
os.makedirs(model_dir, exist_ok=True)
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print('Device:', device)
log_file = os.path.join(model_dir, "Prediction.log")
logging.basicConfig(filename=log_file, level=logging.INFO, format="%(message)s")
logger = logging.getLogger()

### Read Data

In [None]:
set_determinism(seed=0)
train_loader, val_loader = load_data(os.path.join(data_dir, "train"), resize_dim, batch_size, test_size)

# Check data shape
tr = first(train_loader)
print('\nData shape for training:')
for key, value in tr.items():
    print(f'\u2022 {key}: {tuple(value.shape)} \u00D7 {len(train_loader)}')
    # Spatial size of the network input (last three axes); used by UNETR/SwinUNETR.
    img_size = tuple(value.shape[-3:])
vl = first(val_loader)
print('\nData shape for validation:')
for key, value in vl.items():
    print(f'\u2022 {key}: {tuple(value.shape)} \u00D7 {len(val_loader)}')

# Visualize one slice of the first training sample
_, axs = plt.subplots(1, 3, figsize=(12, 5))
brain = tr["brain"][0].detach().cpu()
lesion = tr["lesion"][0].detach().cpu()
slice_index = 29
brain_slice = torch.rot90(brain[0, :, :, slice_index], k=1, dims=(0, 1))
lesion_slice = torch.rot90(lesion[0, :, :, slice_index], k=1, dims=(0, 1))
# blend_images returns an RGB volume with 3 leading channels. Keep all three and
# move them last for imshow; indexing channel 0 showed only the red channel.
blended = blend_images(brain, lesion, alpha=0.5)
blended_slice = torch.rot90(blended[:, :, :, slice_index], k=1, dims=(1, 2)).permute(1, 2, 0)
axs[0].imshow(brain_slice, cmap='gray')
axs[0].set_title("Brain")
axs[0].axis('off')
axs[1].imshow(lesion_slice, cmap='gray')
axs[1].set_title("Lesion")
axs[1].axis('off')
axs[2].imshow(blended_slice)
axs[2].set_title("Lesion-overlaid Brain")
axs[2].axis('off')
plt.tight_layout()
plt.show()

### Train Model

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_model(model_name, img_size, num_classes).to(device)
print(f"Selected model: {model_name}")
num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Number of trainable parameters: {num_params}")

criterion = DiceLoss(smooth_nr=0, smooth_dr=1e-5, squared_pred=True, sigmoid=True)
metric = DiceMetric(include_background=True, reduction="mean")
post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
# es_patience must be passed explicitly, otherwise the configured value is ignored.
model, epoch_loss_values, epoch_metric_values, metric_values = train_model(model_dir, model, device, train_loader, val_loader, logger,
    criterion, metric, post_pred, max_epochs, learning_rate, weight_decay, val_interval, es_patience=es_patience)
plot_metric_values(model_dir, epoch_loss_values, epoch_metric_values, metric_values, val_interval)

# Visualize outcome for one validation sample
sample_index = 5
slice_index = 32
model.eval()
metric.reset()
with torch.no_grad():
    brain = val_loader.dataset[sample_index]["brain"].to(device)
    lesion = val_loader.dataset[sample_index]["lesion"].to(device)
    output = model(brain.unsqueeze(0))
    output = post_pred(output).squeeze(0)
    # DiceMetric expects batch-first (B, C, ...) tensors; without the batch axis the
    # first spatial axis would be treated as channels and the DSC would be wrong.
    metric(y_pred=output.unsqueeze(0), y=lesion.unsqueeze(0))
sample_metric = metric.aggregate().item()
_, axs = plt.subplots(1, 3, figsize=(12, 5))
brain_slice = torch.rot90(brain[0, :, :, slice_index], k=1, dims=(0, 1))
lesion_slice = torch.rot90(lesion[0, :, :, slice_index], k=1, dims=(0, 1))
output_slice = torch.rot90(output[0, :, :, slice_index], k=1, dims=(0, 1))
axs[0].imshow(brain_slice.detach().cpu(), cmap="gray")
axs[0].set_title("Brain")
axs[0].axis('off')
axs[1].imshow(lesion_slice.detach().cpu(), cmap="gray")
axs[1].set_title("Lesion")
axs[1].axis('off')
axs[2].imshow(output_slice.detach().cpu(), cmap="gray")
axs[2].set_title(f"Predicted Lesion: DSC = {sample_metric:.3f}")
axs[2].axis('off')
plt.tight_layout()
plt.show()

### GradCAM

In [None]:
sample_index = 5
slice_index = 32
if model_name == "SegResNet":
    # Hand-picked SegResNet stages: deepest encoder block, first decoder block, final conv.
    target_layers = {
        "Encoder Last": model.down_layers[-1][-1],
        "Decoder First": model.up_layers[0][0],
        "Final Conv": model.conv_final[-1],
    }
else:
    # Other architectures lack the SegResNet attributes above; use the last two convolutions.
    target_layers = get_target_layers(model, [-2, -1])
apply_gradcam_to_sample(model, val_loader, target_layers, device, criterion, sample_index, slice_index)

### Inference

In [None]:
# Predict every test brain with the best checkpoint and write the masks to <model_dir>/Prediction.
pred_dir = os.path.join(model_dir, "Prediction")
test_loader = load_data(
    os.path.join(data_dir, "test"),
    resize_dim,
    batch_size=None,
    test_size=None,
    inference=True,
)
apply_best_model(model_dir, model, device, test_loader, post_pred, pred_dir, resize_dim)