In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import monai
import torch

### Data list
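Copy your annotated data into the `mydata` folder: images go in `mydata/img/` and segmentations in `mydata/seg/`. An image and its segmentation must have the same filename.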

In [None]:
!cp ../data/spleen_img/spleen* ./mydata/img/
!cp ../data/spleen_seg/spleen* ./mydata/seg/

In [None]:
# Build one {'img': path, 'seg': path} record per case. An image and its
# segmentation are paired by identical filename under data_dir/img/ and data_dir/seg/.
keys = ['img', 'seg']
data_dir = './mydata/'
# Sort for a reproducible order (os.listdir order is arbitrary) and skip hidden
# entries such as .ipynb_checkpoints or .DS_Store.
fns = sorted(fn for fn in os.listdir(os.path.join(data_dir, 'img')) if not fn.startswith('.'))
fnames = [{key: os.path.join(data_dir, key, fn) for key in keys} for fn in fns]

In [None]:
# Sanity-check the pairing before loading anything: there must be at least one
# case, and every image must have a same-named segmentation file.
assert fnames, f"no images found in {data_dir}img/"
missing_segs = [record['seg'] for record in fnames if not os.path.exists(record['seg'])]
assert not missing_segs, f"missing segmentations: {missing_segs}"
len(fnames), fnames[:3]

### Preview

In [None]:
# Load the first case (image and label) with no preprocessing, to inspect its raw shape.
load_raw = monai.transforms.LoadImaged(keys=keys)
sample = load_raw(fnames[0])
sample['img'].shape

In [None]:
from IPython.display import clear_output
import time

# Step through the raw sample slice by slice along its last axis: image on the
# left, label on the right. Fixed intensity windows keep the brightness
# comparable from one slice to the next.
vmax = sample['img'].max()
vmin = sample['img'].min()
seg_max = sample['seg'].max()
for i in range(sample['img'].shape[-1]):
    fig, (ax_img, ax_seg) = plt.subplots(1, 2, figsize=(8, 4))
    ax_img.imshow(sample['img'][..., i], cmap='gray', vmax=vmax, vmin=vmin)
    ax_img.set_title(f'image, slice {i}')
    ax_seg.imshow(sample['seg'][..., i], cmap='gray', vmin=0, vmax=seg_max)
    ax_seg.set_title(f'label, slice {i}')
    plt.show()
    plt.close(fig)  # one figure is created per slice; release it explicitly
    clear_output(wait=True)  # replace the previous slice instead of stacking outputs

### Transforms

In [None]:
# Preprocessing + augmentation pipeline applied to each {'img', 'seg'} record.
# `device` is defined here because ToDeviced needs it before the network cell runs.
spatial_size = [128, 128, 16]
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
trans = monai.transforms.Compose([
    monai.transforms.LoadImaged(keys),
    monai.transforms.AddChanneld(keys),
    monai.transforms.EnsureTyped(keys),
    monai.transforms.ToDeviced(keys, device=device),
    # Intensity transforms touch the image only; the label keeps its discrete values.
    monai.transforms.NormalizeIntensityd('img'),
    # Nearest-neighbour for the label so resizing does not invent fractional label values.
    monai.transforms.Resized(keys, spatial_size=spatial_size, mode=('trilinear', 'nearest')),
    monai.transforms.RandScaleIntensityd('img', factors=0.1, prob=0.5),
    # A dictionary random transform applies one random draw to all keys,
    # so image and label are flipped together.
    monai.transforms.RandFlipd(keys, prob=0.5, spatial_axis=0),
])

In [None]:
# Push the first case through the full pipeline and report the resulting shapes.
data = trans(fnames[0])
img_shape, seg_shape = data['img'].shape, data['seg'].shape
img_shape, seg_shape

### Dataset and DataLoader

In [None]:
# CacheDataset precomputes the deterministic prefix of `trans` (everything before
# the first random transform) for all cases, so only the random augmentations
# are re-run every epoch.
ds = monai.data.CacheDataset(data=fnames, transform=trans, cache_rate=1.0)
for data in ds:
    print(data['img'].shape, data['seg'].shape)

In [None]:
# Shuffle so that each training epoch visits the cases in a different order.
dl = torch.utils.data.DataLoader(ds, batch_size=1, shuffle=True)
for data in dl:
    print(data['img'].shape, data['seg'].shape)

### Network, loss and optimizer

In [None]:
# 3D UNet with one input channel (the image) and one output channel, which
# DiceLoss(sigmoid=True) in the next cell treats as foreground logits.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

net = monai.networks.nets.UNet(
    spatial_dims=3,  # 2 or 3 for a 2D or 3D network
    in_channels=1,  # number of input channels
    out_channels=1,  # number of output channels
    channels=[16, 32, 64, 128],  # channel counts for layers
    strides=[2, 2, 2]  # strides for mid layers
).to(device)

# Smoke-test the forward pass on one batch drawn explicitly from the loader
# (rather than a loop variable left over from another cell); no gradients needed.
sample_batch = next(iter(dl))
with torch.no_grad():
    sample_out = net(sample_batch['img'].to(device))
sample_out.shape

In [None]:
# Dice loss computed on sigmoid-activated network outputs, optimised with Adam.
learning_rate = 5e-4
loss_function = monai.losses.DiceLoss(sigmoid=True)
optimizer = torch.optim.Adam(params=net.parameters(), lr=learning_rate)

### Training

In [None]:
max_epochs = 200
epoch_loss_values = []  # mean training loss per epoch, plotted below

# Fail early with a clear message instead of a NameError on the undefined step count.
assert len(dl) > 0, "DataLoader is empty; check the dataset cells above"

net.train()
for epoch in range(max_epochs):
    epoch_loss = 0.0
    n_steps = 0
    for batch_data in dl:
        # .to() is a no-op when the transforms already placed tensors on `device`.
        inputs = batch_data["img"].to(device)
        labels = batch_data["seg"].to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        n_steps += 1
    epoch_loss /= n_steps
    epoch_loss_values.append(epoch_loss)
    # One summary line per epoch instead of one per step keeps the output readable.
    print(f'Epoch {epoch + 1}/{max_epochs}, mean train_loss: {epoch_loss:.4f}')

In [None]:
# Mean training loss per epoch; the x axis is 1-based to match the printed epoch numbers.
fig, ax = plt.subplots(figsize=(16, 4))
ax.plot(range(1, len(epoch_loss_values) + 1), epoch_loss_values)
ax.set(title='Training loss', xlabel='Epoch', ylabel='Mean Dice loss');