## Setup imports

In [None]:
from monai.utils import first, set_determinism
from monai.transforms import (
    AsDiscrete,
    AsDiscreted,
    EnsureChannelFirstd,
    Compose,
    CropForegroundd,
    LoadImaged,
    Orientationd,
    RandCropByPosNegLabeld,
    SaveImaged,
    ScaleIntensityRanged,
    Spacingd,
    Invertd,
)
from monai.handlers.utils import from_engine
from monai.networks.nets import UNet
from monai.networks.layers import Norm
from monai.metrics import DiceMetric
from monai.losses import DiceLoss
from monai.inferers import sliding_window_inference
from monai.data import CacheDataset, DataLoader, Dataset, decollate_batch
import torch
import matplotlib.pyplot as plt
import shutil
import os
import glob

## Set train/validation/test data filepaths

In [None]:
# Dataset root; each split is expected under <data_dir>/<split>/{image,label}/*.nii.gz
data_dir = "./data"  # TODO: set this to your dataset root


def list_nifti_files(root, split, kind):
    """Return the sorted ``*.nii.gz`` paths in ``<root>/<split>/<kind>``.

    Args:
        root: dataset root directory.
        split: split sub-directory, e.g. ``"train"``, ``"val"`` or ``"test"``.
        kind: sub-folder inside the split, ``"image"`` or ``"label"``.

    Returns:
        A sorted list of matching file paths (empty if none are found).
    """
    return sorted(glob.glob(os.path.join(root, split, kind, "*.nii.gz")))


def pair_image_label(images, labels, split):
    """Zip image and label paths into ``{"image": ..., "label": ...}`` dicts.

    Pairing is positional, so it relies on both lists sorting in the same
    order (e.g. identical file names in the image and label folders).

    Args:
        images: sorted image paths.
        labels: sorted label paths.
        split: split name, used only in the error message.

    Returns:
        A list of dicts, one per image/label pair.

    Raises:
        ValueError: if the counts differ -- a plain ``zip`` would silently
            drop the unmatched files.
    """
    if len(images) != len(labels):
        raise ValueError(
            f"{split} split: found {len(images)} images but {len(labels)} labels"
        )
    return [
        {"image": image_name, "label": label_name}
        for image_name, label_name in zip(images, labels)
    ]


train_images = list_nifti_files(data_dir, "train", "image")
train_labels = list_nifti_files(data_dir, "train", "label")
train_data_dicts = pair_image_label(train_images, train_labels, "train")

val_images = list_nifti_files(data_dir, "val", "image")
val_labels = list_nifti_files(data_dir, "val", "label")
val_data_dicts = pair_image_label(val_images, val_labels, "val")

test_images = list_nifti_files(data_dir, "test", "image")
test_labels = list_nifti_files(data_dir, "test", "label")
test_data_dicts = pair_image_label(test_images, test_labels, "test")

## Setup data augmentation

For data augmentation, here are the basic requirements:

1. `LoadImaged` loads the spleen CT images and labels from NIfTI format files.
1. `EnsureChannelFirstd` reshapes the data into a "channel first" layout.
1. `ScaleIntensityRanged` clips the CT intensities (Hounsfield units, HU) to the range [-57, 164] and rescales them to [0, 1].
1. `CropForegroundd` removes all-zero borders so that the images and labels focus on the valid body area.
1. `RandCropByPosNegLabeld` randomly crops patch samples from the large image based on a positive/negative sampling ratio.  
The centers of negative samples must lie in the valid body area.

You can try more data augmentation techniques to further improve the performance.

In [None]:
# Training pipeline: the preprocessing steps listed above, followed by random
# positive/negative patch sampling. Image and label share every spatial
# transform (same ``keys``) so they stay voxel-aligned.
train_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        # Clip HU to [-57, 164] and map it linearly to [0, 1].
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        # Crop away the all-zero border, using the image to find the foreground.
        CropForegroundd(keys=["image", "label"], source_key="image"),
        # num_samples patches per volume, pos:neg = 1:1. image_threshold=0 keeps
        # negative-patch centers where the scaled image is > 0 (inside the body).
        RandCropByPosNegLabeld(
            keys=["image", "label"],
            label_key="label",
            spatial_size=(96, 96, 96),
            pos=1,
            neg=1,
            num_samples=4,
            image_key="image",
            image_threshold=0,
        ),
    ]
)
# Validation pipeline: same deterministic preprocessing, no random cropping,
# so whole volumes are returned (evaluate them with sliding-window inference).
val_transforms = Compose(
    [
        LoadImaged(keys=["image", "label"]),
        EnsureChannelFirstd(keys=["image", "label"]),
        ScaleIntensityRanged(
            keys=["image"],
            a_min=-57,
            a_max=164,
            b_min=0.0,
            b_max=1.0,
            clip=True,
        ),
        CropForegroundd(keys=["image", "label"], source_key="image"),
    ]
)

In [None]:
import os
import torch
from torch.utils.data import Dataset, DataLoader



class CT_Dataset(Dataset):
    """Map-style dataset over CT image/label pairs.

    ``dataset_path`` may be either
      * a sequence of ``{"image": path, "label": path}`` dicts (such as the
        ``*_data_dicts`` lists built from the data directory), or
      * a dataset root directory (``str`` / ``os.PathLike``), in which case
        pairs are discovered under ``<dataset_path>/<split>/{image,label}/*.nii.gz``
        and paired positionally after sorting.

    Each item is passed through ``transform`` (when given) before it is returned.

    Args:
        dataset_path: list of data dicts or dataset root directory.
        transform: optional callable applied to each data dict.
        split: split name; used to locate files when ``dataset_path`` is a directory.

    Raises:
        ValueError: if a directory's split has different numbers of images and labels.
    """

    def __init__(self, dataset_path, transform=None, split='test'):
        self.transform = transform
        self.split = split
        if isinstance(dataset_path, (str, os.PathLike)):
            images = sorted(glob.glob(os.path.join(dataset_path, split, "image", "*.nii.gz")))
            labels = sorted(glob.glob(os.path.join(dataset_path, split, "label", "*.nii.gz")))
            # zip would silently drop unmatched files; fail loudly instead.
            if len(images) != len(labels):
                raise ValueError(
                    f"{split} split: found {len(images)} images but {len(labels)} labels"
                )
            self.data = [{"image": img, "label": lbl} for img, lbl in zip(images, labels)]
        else:
            # Copy into a list so len()/indexing work for any iterable input.
            self.data = list(dataset_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        if self.transform is not None:
            item = self.transform(item)
        return item
        


# here we don't cache any data in case out of memory issue
# list_data_collate is the collate function monai.data.DataLoader uses; per the
# MONAI docs it flattens the per-volume lists of patches produced by
# RandCropByPosNegLabeld(num_samples > 1) into one batch. The torch DataLoader
# imported in this cell needs it passed explicitly.
from monai.data import list_data_collate

train_ds = CT_Dataset(train_data_dicts, train_transforms, split='train')
train_loader = DataLoader(
    train_ds, batch_size=2, shuffle=True, num_workers=4, collate_fn=list_data_collate
)
val_ds = CT_Dataset(val_data_dicts, val_transforms, split='val')
val_loader = DataLoader(
    val_ds, batch_size=1, shuffle=False, num_workers=4, collate_fn=list_data_collate
)
# The test split is evaluated like validation: deterministic (validation)
# transforms, uncropped volumes one at a time (volumes may differ in size and
# then cannot be stacked), and no shuffling so results map back to files in order.
test_ds = CT_Dataset(test_data_dicts, val_transforms, split='test')
test_loader = DataLoader(
    test_ds, batch_size=1, shuffle=False, num_workers=4, collate_fn=list_data_collate
)

# Implement a 3D UNet for the segmentation task

We give a possible network structure here, and you can modify it for a stronger performance.

In the block ```double_conv```, you can implement the following structure:

| Layer |
|-------|
| Conv3d |
| BatchNorm3d |
| PReLU |
| Conv3d |
| BatchNorm3d |
| PReLU |


For the overall UNet, you can implement the following structure. ```conv_down``` and ```conv_up``` refer to blocks built with the ```double_conv``` function defined above.

| Layer | Input Channel | Output Channel |
|-------|-------------|--------------|
| conv_down1 | 1 | 16 |
| maxpool | 16 | 16 |
| conv_down2 | 16 | 32 |
| maxpool | 32 | 32 |
| conv_down3 | 32 | 64 |
| maxpool | 64 | 64 |
| conv_down4 | 64 | 128 |
| maxpool | 128 | 128 |
| conv_down5 | 128 | 256 |
| upsample | 256 | 256 |
| conv_up4 | 128+256 | 128 |
| upsample | 128 | 128 |
| conv_up3 | 64+128 | 64 |
| upsample | 64 | 64 |
| conv_up2 | 32+64 | 32 |
| upsample | 32 | 32 |
| conv_up1 | 16+32 | 16 |
| conv_out | 16 | 2 |


In [None]:
import torch
import torch.nn as nn

def double_conv(in_channels, out_channels):
    """Return two (Conv3d -> BatchNorm3d -> PReLU) stages as one module.

    The 3x3x3 convolutions use padding 1, so the spatial size is unchanged and
    only the channel count goes from ``in_channels`` to ``out_channels``.
    """
    return nn.Sequential(
        nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_channels),
        nn.PReLU(),
        nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
        nn.BatchNorm3d(out_channels),
        nn.PReLU(),
    )


class UNet(nn.Module):
    """3D UNet following the layer table above.

    Encoder: five ``double_conv`` stages (16, 32, 64, 128, 256 channels)
    separated by 2x max-pooling. Decoder: each stage upsamples, concatenates
    the matching encoder features (skip connection) and applies
    ``double_conv``. A final 1x1x1 convolution maps to per-class logits.

    Args:
        in_channels: number of input channels (default 1).
        out_channels: number of output classes (default 2).
    """

    def __init__(self, in_channels=1, out_channels=2):
        super().__init__()
        self.conv_down1 = double_conv(in_channels, 16)
        self.conv_down2 = double_conv(16, 32)
        self.conv_down3 = double_conv(32, 64)
        self.conv_down4 = double_conv(64, 128)
        self.conv_down5 = double_conv(128, 256)
        self.maxpool = nn.MaxPool3d(kernel_size=2)
        self.conv_up4 = double_conv(128 + 256, 128)
        self.conv_up3 = double_conv(64 + 128, 64)
        self.conv_up2 = double_conv(32 + 64, 32)
        self.conv_up1 = double_conv(16 + 32, 16)
        self.conv_out = nn.Conv3d(16, out_channels, kernel_size=1)

    @staticmethod
    def _up_and_concat(x, skip):
        """Upsample ``x`` to ``skip``'s spatial size and concatenate on channels."""
        # Interpolating to the exact skip size (instead of a fixed x2) keeps the
        # concatenation valid when an input dimension is not divisible by 16.
        x = nn.functional.interpolate(
            x, size=skip.shape[2:], mode="trilinear", align_corners=False
        )
        return torch.cat([skip, x], dim=1)

    def forward(self, x):
        """Map ``x`` (N, in_channels, D, H, W) to logits (N, out_channels, D, H, W)."""
        c1 = self.conv_down1(x)
        c2 = self.conv_down2(self.maxpool(c1))
        c3 = self.conv_down3(self.maxpool(c2))
        c4 = self.conv_down4(self.maxpool(c3))
        c5 = self.conv_down5(self.maxpool(c4))
        u4 = self.conv_up4(self._up_and_concat(c5, c4))
        u3 = self.conv_up3(self._up_and_concat(u4, c3))
        u2 = self.conv_up2(self._up_and_concat(u3, c2))
        u1 = self.conv_up1(self._up_and_concat(u2, c1))
        return self.conv_out(u1)

## Create Model, Loss, Optimizer

In [None]:
# standard PyTorch program style: create UNet, DiceLoss and Adam optimizer
# Use the GPU when one is available, otherwise the CPU. (A hard-coded
# "cuda:3" fails on machines with fewer than four GPUs; set
# CUDA_VISIBLE_DEVICES to choose a specific card.)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# UNet here is the network class defined in the cell above.
model = UNet().to(device)
# The network emits 2-channel logits. Labels are assumed to be single-channel
# class-index maps (TODO confirm for your data), so the loss one-hot encodes them
# and applies softmax to the predictions.
loss_function = DiceLoss(to_onehot_y=True, softmax=True)
optimizer = torch.optim.Adam(model.parameters(), 1e-4)
# Dice on the foreground class only, averaged over samples.
dice_metric = DiceMetric(include_background=False, reduction="mean")

## Define your training/val/test loop

In [None]:
max_epochs = 600
val_interval = 2
best_metric = -1
best_metric_epoch = -1
epoch_loss_values = []
metric_values = []
post_pred = Compose([AsDiscrete(argmax=True, to_onehot=2)])
post_label = Compose([AsDiscrete(to_onehot=2)])
# Sliding-window settings for validating whole (uncropped) volumes.
# roi_size should match the training patch size -- TODO keep in sync with the transforms.
roi_size = (96, 96, 96)
sw_batch_size = 4

for epoch in range(max_epochs):
    # the steps are similar to HW1. Please pay attention to the difference in the segmentation task.
    print("-" * 10)
    print(f"epoch {epoch + 1}/{max_epochs}")

    # ---- training ----
    model.train()
    epoch_loss = 0.0
    step = 0
    for batch_data in train_loader:
        step += 1
        inputs = batch_data["image"].to(device)
        labels = batch_data["label"].to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
    epoch_loss /= max(step, 1)  # avoid ZeroDivisionError on an empty loader
    epoch_loss_values.append(epoch_loss)
    print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # ---- validation every val_interval epochs ----
    if (epoch + 1) % val_interval == 0:
        model.eval()
        with torch.no_grad():
            for val_data in val_loader:
                val_inputs = val_data["image"].to(device)
                val_labels = val_data["label"].to(device)
                # Predict the full volume tile by tile with roi_size windows.
                val_outputs = sliding_window_inference(val_inputs, roi_size, sw_batch_size, model)
                # Per sample: argmax + one-hot for predictions, one-hot for labels,
                # so both match the 2-channel format DiceMetric compares.
                val_outputs = [post_pred(i) for i in decollate_batch(val_outputs)]
                val_labels = [post_label(i) for i in decollate_batch(val_labels)]
                dice_metric(y_pred=val_outputs, y=val_labels)
            # Mean Dice over the whole validation set, then clear the accumulator.
            metric = dice_metric.aggregate().item()
            dice_metric.reset()

        metric_values.append(metric)
        if metric > best_metric:
            best_metric = metric
            best_metric_epoch = epoch + 1
            torch.save(model.state_dict(), "best_metric_model.pth")
            print("saved new best metric model")
        print(
            f"current epoch: {epoch + 1} current mean dice: {metric:.4f}"
            f"\nbest mean dice: {best_metric:.4f} at epoch: {best_metric_epoch}"
        )

print(f"train completed, best_metric: {best_metric:.4f} at epoch: {best_metric_epoch}")

## Inference and Performance Report on the Test Set