## Setup imports

In [None]:
from monai.transforms import (
    AsDiscrete,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    ScaleIntensityRanged,
    Spacingd,
)
from monai.metrics import DiceMetric,SurfaceDistanceMetric,HausdorffDistanceMetric,MeanIoU
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
import matplotlib.pyplot as plt
import os
import glob
import wandb
import SimpleITK as sitk

## Set train/validation/test data filepath

In [2]:
data_dir = "/Your/Path/To/Datafolder"
train_images = sorted(glob.glob(os.path.join(data_dir,"train" ,"image", "*.nii.gz")))
train_labels = sorted(glob.glob(os.path.join(data_dir,"train" , "mask", "*.nii.gz")))
train_data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(train_images, train_labels)]

val_images = sorted(glob.glob(os.path.join(data_dir,"val" ,"image", "*.nii.gz")))
val_labels = sorted(glob.glob(os.path.join(data_dir,"val" , "mask", "*.nii.gz")))
val_data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(val_images, val_labels)]

test_images = sorted(glob.glob(os.path.join(data_dir,"test" ,"image", "*.nii.gz")))
test_labels = sorted(glob.glob(os.path.join(data_dir,"test" , "mask", "*.nii.gz")))
test_data_dicts = [{"image": image_name, "label": label_name} for image_name, label_name in zip(test_images, test_labels)]

## Setup data augmentation

For data augmentation, here are the basic requirements:

1. `LoadImaged` loads the spleen CT images and labels from NIfTI format files.
1. `EnsureChannelFirstd` ensures the original data to construct "channel first" shape.
1. `ScaleIntensityRanged` clips the CT's data format, HU value, into a certain range (-57,164) and normalize it to (0,1)
1. `CropForegroundd` removes all zero borders to focus on the valid body area of the images and labels.
1. `RandCropByPosNegLabeld` randomly crop patch samples from big image based on pos / neg ratio.  
The image centers of negative samples must be in valid body area.

You can try more data augmentation techniques to further improve the performance.

In [None]:
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
    ]
)
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
        Orientationd(keys=["image", "label"], axcodes="RAS"),
        Spacingd(keys=["image", "label"], pixdim=(1.5, 1.5, 2.0), mode=("bilinear", "nearest")),
    ]
)

In [4]:
import os
import torch
from torch.utils.data import Dataset, DataLoader



class CT_Dataset(Dataset):
    def __init__(self, dataset_path, transform=None,split='test'):
        self.data=dataset_path
        self.transform=transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data_dict = self.data[idx]
        if self.transform:
            data_dict = self.transform(data_dict)
        return data_dict
        


# here we don't cache any data in case out of memory issue
train_ds = CT_Dataset(train_data_dicts,train_transforms,split='train')
train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=16)
val_ds = CT_Dataset(val_data_dicts,val_transforms,split='val')
val_loader = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=4)
test_ds = CT_Dataset(test_data_dicts,val_transforms,split='test')
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=4)

# Implement a 3D UNet for segmentation task

We give a possible network structure here, and you can modify it for a stronger performance.

In the block ```double_conv```, you can implement the following structure：

| Layer |
|-------|
| Conv3d |
| BatchNorm3d |
| PReLU |
| Conv3d |
| BatchNorm3d |
| PReLU |


In the overall UNet structure, you can implement the following structure. ```conv_down``` and ```conv_up``` refers to the function block you defined above.

| Layer | Input Channel | Output Channel |
|-------|-------------|--------------|
| conv_down1 | 1 | 16 |
| maxpool | 16 | 16 |
| conv_down2 | 16 | 32 |
| maxpool | 32 | 32 |
| conv_down3 | 32 | 64 |
| maxpool | 64 | 64 |
| conv_down4 | 64 | 128 |
| maxpool | 128 | 128 |
| conv_down5 | 128 | 256 |
| upsample | 256 | 256 |
| conv_up4 | 128+256 | 128 |
| upsample | 128 | 128 |
| conv_up3 | 64+128 | 64 |
| upsample | 64 | 32 |
| conv_up4 | 32+64 | 32 |
| upsample | 32 | 32 |
| conv_up4 | 16+32 | 16 |
| conv_out | 16 | 2 |


In [5]:
import torch
import torch.nn as nn

def double_conv(in_channels, out_channels):
    return nn.Sequential(
    nn.Conv3d(in_channels, out_channels, 3, padding=1),
    nn.BatchNorm3d(out_channels),
    nn.PReLU(),
    nn.Conv3d(out_channels, out_channels, 3, padding=1),
    nn.BatchNorm3d(out_channels),
    nn.PReLU())

class UNet(nn.Module):

    def __init__(self, n_classes):
        super().__init__()
                
        self.conv_down1 = double_conv(1, 16)
        self.conv_down2 = double_conv(16, 32)
        self.conv_down3 = double_conv(32, 64)
        self.conv_down4 = double_conv(64, 128)       
        self.conv_down5 = double_conv(128, 256)

        self.maxpool = nn.MaxPool3d(2)
        self.upsample = nn.Upsample(scale_factor=2, mode='trilinear', align_corners=True)        
        
        self.conv_up4 = double_conv(128 + 256, 128)
        self.conv_up3 = double_conv(64 + 128, 64)
        self.conv_up2 = double_conv(64 + 32, 32)
        self.conv_up1 = double_conv(32 + 16, 16)
        
        self.last_conv = nn.Conv3d(16, n_classes, kernel_size=1)
        
    def forward(self, x):
        conv1 = self.conv_down1(x)  
        x = self.maxpool(conv1)    
        
        conv2 = self.conv_down2(x) 
        x = self.maxpool(conv2)    
        
        conv3 = self.conv_down3(x)  
        x = self.maxpool(conv3)

        conv4 = self.conv_down4(x)  
        x = self.maxpool(conv4)      
        
        x = self.conv_down5(x)
        x = self.upsample(x)  
        
        x = torch.cat([x, conv4], dim=1) 

        x = self.conv_up4(x) 
        x = self.upsample(x) 
        
        x = torch.cat([x, conv3], dim=1) 
        
        x = self.conv_up3(x) 
        x = self.upsample(x) 
        
        x = torch.cat([x, conv2], dim=1) 

        x = self.conv_up2(x)
        x = self.upsample(x)   
        
        x = torch.cat([x, conv1], dim=1)
        
        x = self.conv_up1(x)
        
        out = self.last_conv(x)
        
        return out

## Create Model, Loss, Optimizer

In [6]:
# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer
device = torch.device("cuda:1")
model = UNet(n_classes=2).to(device)
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
dice_metric = DiceMetric(include_background=False, reduction="mean")
asd_metric = SurfaceDistanceMetric(include_background=False, distance_metric="euclidean")
hausdorff_metric = HausdorffDistanceMetric(include_background=False,percentile=95)
jaccard_metric = MeanIoU(include_background=False)

## Define your training/val/test loop

In [7]:
wandb.init(project="homework2", name="Trial1")

max_epochs = 600
val_interval = 50
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
dice_values = []
asd_values = []
hausdorff_values = []
jaccard_values = []
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])

for epoch in range(max_epochs):
    print("-" * 10)
    model.train()
    epoch_loss = 0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs, labels = (
            batch_data[0]["image"].to(device),
            batch_data[0]["label"].to(device),
        )
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= step
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")
    wandb.log({'train_loss': epoch_loss}, step= epoch)

    # ONE POSSIBLE VALIDATION METHOD
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs, val_labels = (
                    val_data["image"].to(device),
                    val_data["label"].to(device),
                )
                roi_size = (160, 160, 160)
                sw_batch_size = 4
                val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
                val_outputs = [post_pred(val_output) for val_output in val_outputs]
                val_labels = [post_label(val_label) for val_label in val_labels]
                # compute metric for current iteration
                dice_metric(y_pred=val_outputs, y=val_labels)
                asd_metric(y_pred=val_outputs, y=val_labels)
                hausdorff_metric(y_pred=val_outputs, y=val_labels)
                jaccard_metric(y_pred=val_outputs, y=val_labels)

            # aggregate the final mean dice result
            dice = dice_metric.aggregate().item()
            asd = asd_metric.aggregate().item()
            hausdorff = hausdorff_metric.aggregate().item()
            jaccard = jaccard_metric.aggregate().item()
            # reset the status for next validation round
            dice_metric.reset()
            asd_metric.reset()
            hausdorff_metric.reset()
            jaccard_metric.reset()

            dice_values.append(dice)
            asd_values.append(asd)
            hausdorff_values.append(hausdorff)
            jaccard_values.append(jaccard)
            wandb.log({'val_dice': dice}, step= epoch)
            wandb.log({'val_asd': asd}, step= epoch)
            wandb.log({'val_hausdorff': hausdorff}, step= epoch)
            wandb.log({'val_jaccard': jaccard}, step= epoch)
            
            if dice > best_metric:
                best_metric = dice
                best_metric_epoch = epoch + 1
                torch.save(model.state_dict(), "./best_metric_model.pth")
                print("saved new best metric model")
            print(
                f"current epoch: {epoch + 1} current mean dice: {dice:.4f}"
                f"\nbest mean dice: {best_metric:.4f} "
                f"at epoch: {best_metric_epoch}"
            )



Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.


[34m[1mwandb[0m: Currently logged in as: [33measonqin[0m. Use [1m`wandb login --relogin`[0m to force relogin


----------
epoch 1 average loss: 0.6310
----------
epoch 2 average loss: 0.5678
----------
epoch 3 average loss: 0.5718
----------
epoch 4 average loss: 0.5675
----------
epoch 5 average loss: 0.6023
----------
epoch 6 average loss: 0.5545
----------
epoch 7 average loss: 0.5521
----------
epoch 8 average loss: 0.5581
----------
epoch 9 average loss: 0.5490
----------
epoch 10 average loss: 0.5607
----------
epoch 11 average loss: 0.5478
----------
epoch 12 average loss: 0.5415
----------
epoch 13 average loss: 0.5343
----------
epoch 14 average loss: 0.5380
----------
epoch 15 average loss: 0.5145
----------
epoch 16 average loss: 0.5282
----------
epoch 17 average loss: 0.5182
----------
epoch 18 average loss: 0.5228
----------
epoch 19 average loss: 0.5281
----------
epoch 20 average loss: 0.5211
----------
epoch 21 average loss: 0.5140
----------
epoch 22 average loss: 0.5311
----------
epoch 23 average loss: 0.5092
----------
epoch 24 average loss: 0.5352
----------
epoch 25 avera

## Inference and Report performance on Test Set

In [8]:
def save_results(save_image,save_path):
    save_image=save_image.cpu().numpy()
    sitk_img = sitk.GetImageFromArray(save_image)
    sitk.WriteImage(sitk_img, save_path)

In [10]:
# ONE POSSIBLE TEST METHOD

from tqdm import tqdm
model.load_state_dict(torch.load("./best_metric_model.pth"))
model.eval()
with torch.no_grad():
    for test_data in tqdm(test_loader):
        test_inputs, test_labels = (
            test_data["image"].to(device),
            test_data["label"].to(device),
        )
        roi_size = (160, 160, 160)
        sw_batch_size = 4
        test_outputs = sliding_window_inference(test_inputs, roi_size, sw_batch_size, model)
        # print(test_outputs.shape)
        test_outputs = [post_pred(test_output) for test_output in test_outputs]
        test_labels = [post_label(test_label) for test_label in test_labels]
        # print(test_outputs[0].shape)
        # compute metric for current iteration
        dice_metric(y_pred=test_outputs, y=test_labels)
        asd_metric(y_pred=test_outputs, y=test_labels)
        hausdorff_metric(y_pred=test_outputs, y=test_labels)
        jaccard_metric(y_pred=test_outputs, y=test_labels)
        
    # aggregate the final mean dice result
    dice = dice_metric.aggregate().item()
    asd = asd_metric.aggregate().item()
    hausdorff = hausdorff_metric.aggregate().item()
    jaccard = jaccard_metric.aggregate().item()
    # reset the status for next validation round
    dice_metric.reset()
    asd_metric.reset()
    hausdorff_metric.reset()
    jaccard_metric.reset()


    

print("Test cases results:")
# print average results
print(f"Average Dice: {dice:.4f}")
print(f"Average ASD: {asd:.4f}")
print(f"Average Hausdorff: {hausdorff:.4f}")
print(f"Average Jaccard: {jaccard:.4f}")

    


  0%|          | 0/8 [00:00<?, ?it/s]

100%|██████████| 8/8 [00:17<00:00,  2.18s/it]

Test cases results:
Average Dice: 0.9414
Average ASD: 3.8363
Average Hausdorff: 35.9547
Average Jaccard: 0.8899





## Acknowledgement

We acknowledge the awesome coding samples provided by MONAI. You are also welcome to check out the other great coding samples provided by them to get familiar with dealing with Medical Imaging.