This differs from lung-seg-demo in that it showcases some additional MONAI features. The following features are shown here but not in lung-seg-demo:
- PersistentDataset (seen [here](#PersistentDatasets))
- Sliding window inference (seen [here](#Sliding-Window-Inference))
    - Also, the transform `RandSpatialCropD` is added to the training transforms, to help train the network on image patches. (See that [here](#Creating-Transforms).)
- Metrics (Dice metric, seen [here](#Defining-a-metric))
- Engines and event handlers (seen in the alternative approach to training [here](#Engines-appraoch-to-training))
    - Automatic mixed precision (AMP) is also used there.

# Table of Contents
* [Montgomery and Shenzhen Datasets](#Montgomery-and-Shenzhen-Datasets)
* [Preparing Data Lists](#Preparing-Data-Lists)
* [Creating Transforms](#Creating-Transforms)
* [PersistentDatasets](#PersistentDatasets)
* [Previewing](#Previewing)
* [Define the Segmentation Network](#Define-the-Segmentation-Network)
* [Defining the Loss Function](#Defining-the-Loss-Function)
* [Defining a metric](#Defining-a-metric)
* [Previewing Segmentation Network Outputs](#Previewing-Segmentation-Network-Outputs)
* [Training](#Training)
* [Saving](#Saving)
* [Plotting](#Plotting)
* [Inference](#Inference)
	* [Sliding Window Inference](#Sliding-Window-Inference)
* [Engines approach to training](#Engines-appraoch-to-training)


# Montgomery and Shenzhen Datasets

- [Article about both datasets](https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4256233/)
- Montgomery contains 138 chest X-rays: 80 healthy, 58 with tuberculosis. Lung segmentations are available.
- Shenzhen contains 662 chest X-rays: 326 healthy, 336 with tuberculosis. Lung segmentations are available.
- [Get both here](https://openi.nlm.nih.gov/faq?it=xg#collection). Look for "tuberculosis collection".

### Shenzhen

[The readme](NLM-ChinaCXRSet-ReadMe.pdf).

- 336 cases with manifestation of tuberculosis, and
- 326 normal cases.

- Format: PNG
- Image size varies for each X-ray. It is approximately 3K x 3K.

- Image file names are coded as `CHNCXR_#####_0/1.png`, where ‘0’ represents the normal and ‘1’ represents the abnormal lung.

Lung segmentations can be obtained separately [here](https://www.kaggle.com/yoctoman/shcxr-lung-mask). They were drawn manually by "students and teachers of Computer Engineering Department, Faculty of Informatics and Computer Engineering, National Technical University of Ukraine "Igor Sikorsky Kyiv Polytechnic Institute", Kyiv, Ukraine", so the annotators were not necessarily medical experts.

### Montgomery

[The readme](NLM-MontgomeryCXRSet-ReadMe.pdf).

- 58 cases with manifestation of tuberculosis, and 80 normal cases.
- Image file names are coded as `MCUCXR_#####_0/1.png`, where ‘0’ represents the normal and ‘1’ represents the abnormal lung. These classes are worth keeping in mind so that the train/validation/test split can preserve their proportions.

---

- Format: PNG
- Matrix size is 4020 x 4892, or 4892 x 4020.
- The pixel spacing in vertical and horizontal directions is 0.0875 mm.
- Number of gray levels is 12 bits.

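As a quick worked example using only the numbers above: at 0.0875 mm per pixel, a 4020 x 4892 matrix spans roughly 35 cm x 43 cm, and 12 bits allow 2^12 = 4096 gray levels. The snippet just does that arithmetic, and also shows the effective pixel spacing after resizing such an image to 256 x 256 (the `image_size` used later in this notebook). The variable names are illustrative. Note that the two directions end up with different spacings, since a resize to a square shape does not preserve the aspect ratio.

```python
# Worked arithmetic for the Montgomery specifications listed above
pixel_spacing_mm = 0.0875      # vertical and horizontal spacing, from the readme
matrix_size = (4020, 4892)     # one of the two orientations given in the readme
bits_per_pixel = 12

# Physical extent of the image in millimetres
extent_mm = tuple(n * pixel_spacing_mm for n in matrix_size)
print(extent_mm)

gray_levels = 2 ** bits_per_pixel
print(gray_levels)  # 4096

# Effective spacing after resizing to target_size x target_size
target_size = 256
resized_spacing_mm = tuple(e / target_size for e in extent_mm)
print(resized_spacing_mm)
```
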
---

Segmentation:
> We manually generated the “gold standard” segmentations for the chest X-ray under the supervision of a radiologist. We used the following conventions for outlining the lung boundaries: Both posterior and anterior ribs are readily visible in the CXRs; the part of the lung behind the heart is excluded. We follow anatomical landmarks such as the boundary of the heart, aortic arc/line, and pericardium line; and sharp costophrenic angle that follow the diaphragm boundary. We draw an inferred boundary when the pathology is severe and affects the morphological appearance of the lungs. The lung boundaries (left and right) are in binary image format and have the same file name as chest Xrays (e.g. `…/left/MCUCXR_#####_0/1.png` or `…/right/MCUCXR_#####_0/1.png`).

# Preparing Data Lists

We start by creating a list of data items. Each data item will be a dictionary of associated data: image filepath, label filepath, image origin, and presence of pathology.

In [None]:
import os, glob

# Adjust the paths here based on where you downloaded the data
data_base_path = '/home/ebrahim/data/chest_xrays'
montgomery_imgs_path = os.path.join(data_base_path, 'MontgomerySet/CXR_png')
montgomery_segs_path_left = os.path.join(data_base_path, 'MontgomerySet/ManualMask/leftMask')
montgomery_segs_path_right = os.path.join(data_base_path, 'MontgomerySet/ManualMask/rightMask')
shenzhen_imgs_path = os.path.join(data_base_path, 'ChinaSet_AllFiles/CXR')
shenzhen_segs_path = os.path.join(data_base_path, 'ChinaSet_AllFiles/CXR_segs')

# We use glob to get lists of png filepaths
montgomery_imgs = glob.glob(os.path.join(montgomery_imgs_path, '*.png'))
montgomery_segs_left = glob.glob(os.path.join(montgomery_segs_path_left, '*.png'))
montgomery_segs_right = glob.glob(os.path.join(montgomery_segs_path_right, '*.png'))

shenzhen_imgs = glob.glob(os.path.join(shenzhen_imgs_path, '*.png'))
shenzhen_segs = glob.glob(os.path.join(shenzhen_segs_path, '*.png'))

# Here we map filepaths to image IDs; this will allow us to associate images
# to their corresponding labels.
file_path_to_ID = lambda p : os.path.basename(p)[7:11]
montgomery_img_ids = list(map(file_path_to_ID,montgomery_imgs))
montgomery_seg_ids_left = list(map(file_path_to_ID,montgomery_segs_left))
montgomery_seg_ids_right = list(map(file_path_to_ID,montgomery_segs_right))
shenzhen_img_ids = list(map(file_path_to_ID,shenzhen_imgs))
shenzhen_seg_ids = list(map(file_path_to_ID,shenzhen_segs))

# This function uses filename to extract whether tuberculosis is present.
# While we are not necessarily interested in classifying images for tuberculosis here,
# we still want to track which images have it so that our selection of validation data
# can be made representative of the total population.
file_path_to_abnormality = lambda p : bool(int(os.path.basename(p)[12]))

# Finally, we define a list of data items.
# This will be the input into the MONAI training and inference pipeline.
# Each data item is a dictionary containing some associated data-- in this case
# filepaths pointing to images and to associated segmentations, as well as a boolean
# parameter indicating the presence of tuberculosis.
data = []
for img in montgomery_imgs: 
    img_id = file_path_to_ID(img)
    seg_left = montgomery_segs_left[montgomery_seg_ids_left.index(img_id)]
    seg_right = montgomery_segs_right[montgomery_seg_ids_right.index(img_id)]
    tuberculosis = file_path_to_abnormality(img)
    data.append({
        'image' : img,
        'mo_seg_left' : seg_left, # mo for montgomery
        'mo_seg_right' : seg_right,
        'tuberculosis' : tuberculosis,
        'id' : 'montgomery:'+img_id,
        'source' : "montgomery"
    })
skipped_no_seg = 0
skipped_bad = 0
for img in shenzhen_imgs:
    img_id = file_path_to_ID(img)
    if img_id not in shenzhen_seg_ids:
        skipped_no_seg += 1
        continue
    seg = shenzhen_segs[shenzhen_seg_ids.index(img_id)]
    tuberculosis = file_path_to_abnormality(img)
    data.append({
        'image' : img,
        'sh_seg' : seg, # sh for shenzhen
        'tuberculosis' : tuberculosis,
        'id' : 'shenzhen:'+img_id,
        'source' : "shenzhen"
    })
if skipped_no_seg>0:
    print(f"{skipped_no_seg} of the shenzhen images do not have an associated segmentation, and they were skipped.")

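The slicing in `file_path_to_ID` and `file_path_to_abnormality` relies on fixed character positions: a 7-character prefix such as `MCUCXR_`, a 4-character ID, an underscore, then the 0/1 flag. A regular expression makes that assumption explicit and fails loudly on unexpected names instead of silently producing a wrong ID. `parse_cxr_filename` is a hypothetical helper, not part of the original code. The 4-digit ID width is inferred from the `[7:11]` slice, and the example paths are made up for illustration.

```python
import os
import re

# Hypothetical pattern: <PREFIX>_<4-digit id>_<0|1>.png, optionally with a suffix such as
# "_something" before ".png" (segmentation files may carry one). The 4-digit width mirrors [7:11].
CXR_NAME = re.compile(r"^(MCUCXR|CHNCXR)_(\d{4})_([01])(?:\.png|_.*\.png)$")

def parse_cxr_filename(path):
    """Return (source_prefix, image_id, has_tuberculosis) for a CXR file path."""
    name = os.path.basename(path)
    match = CXR_NAME.match(name)
    if match is None:
        raise ValueError(f"Unexpected file name: {name}")
    prefix, image_id, flag = match.groups()
    return prefix, image_id, flag == "1"

print(parse_cxr_filename("/data/MontgomerySet/CXR_png/MCUCXR_0001_0.png"))
# ('MCUCXR', '0001', False)
```
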
In [None]:
import monai
import matplotlib.pyplot as plt
import numpy as np
import torch
from segmentation_post_processing import SegmentationPostProcessing
from segmentation_model_lib import * # This contains some custom monai transforms

# Fixing the random seed is useful for making results reproducible
monai.utils.misc.set_determinism(seed=9274)

When we select validation data, we take care to choose a subset of images that is representative of the total population. We do this by passing a class label for each data item into the `classes` parameter of `partition_dataset_classes`. Each label is the pair (presence of tuberculosis, data source), so the split preserves the proportions of every combination of those two attributes.

In [None]:
data_train, data_valid = monai.data.utils.partition_dataset_classes(
    data,
    classes = list(map(lambda d : (d['tuberculosis'],d['source']), data)),
    ratios = (8,2)
)

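To see why stratifying by `(tuberculosis, source)` keeps proportions, here is a pure-Python sketch of a per-class split. `stratified_split` is a hypothetical helper, not MONAI's `partition_dataset_classes`, though the idea is the same: group items by class, then split each group by the given ratios. This sketch shuffles within each class using a fixed seed. Whether to shuffle is a design choice, and `partition_dataset_classes` exposes `shuffle` and `seed` parameters for it.

```python
import random
from collections import defaultdict

def stratified_split(items, classes, ratios=(8, 2), seed=0):
    """Split items into len(ratios) parts, keeping each class's proportions.

    Hypothetical helper for illustration; not MONAI's implementation.
    """
    by_class = defaultdict(list)
    for item, cls in zip(items, classes):
        by_class[cls].append(item)

    rng = random.Random(seed)
    total = sum(ratios)
    parts = [[] for _ in ratios]
    for members in by_class.values():
        members = members[:]
        rng.shuffle(members)
        start = 0
        for k, r in enumerate(ratios):
            # The last part takes the remainder, so no item is dropped
            if k == len(ratios) - 1:
                end = len(members)
            else:
                end = start + round(len(members) * r / total)
            parts[k].extend(members[start:end])
            start = end
    return parts

# Toy data mirroring the Montgomery counts: 58 tuberculosis, 80 normal
toy = [{"tuberculosis": i < 58, "source": "montgomery"} for i in range(138)]
train, valid = stratified_split(toy, [(d["tuberculosis"], d["source"]) for d in toy])
print(len(train), len(valid))                  # 110 28
print(sum(d["tuberculosis"] for d in valid))   # 12
```

Both parts keep roughly the same 58:80 tuberculosis-to-normal ratio as the full set.
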
# Creating Transforms

MONAI has a large and powerful collection of transforms to draw upon. We use many of them here, along with some custom-made transforms, which are easy to create by subclassing MONAI's transform classes.

*Randomizable transforms* are the non-deterministic ones: they can produce a different output each time they are given the same input. Keep in mind that a caching dataset can only cache the output of the transforms that come before the first randomizable transform in a chain; everything from that point on is recomputed on every access. So randomizable transforms should be placed towards the end of the transform chain to the extent possible.

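The caching rule can be illustrated with a toy pipeline in plain Python (not MONAI's code): split the chain at the first randomizable transform, cache the deterministic prefix once per item, and rerun the remainder on every access. The `is_random` flag is a hypothetical stand-in for MONAI's `Randomizable` base class.

```python
import random

class Step:
    """A toy transform: a function plus a flag saying whether it is random."""
    def __init__(self, fn, is_random=False):
        self.fn = fn
        self.is_random = is_random

def split_at_first_random(steps):
    """Return (deterministic_prefix, remainder) of a transform chain."""
    for i, step in enumerate(steps):
        if step.is_random:
            return steps[:i], steps[i:]
    return steps, []

class ToyCachedDataset:
    def __init__(self, items, steps):
        self.items = items
        self.prefix, self.suffix = split_at_first_random(steps)
        self.cache = {}
        self.prefix_calls = 0

    def __getitem__(self, index):
        if index not in self.cache:
            x = self.items[index]
            for step in self.prefix:
                x = step.fn(x)
            self.prefix_calls += 1
            self.cache[index] = x
        x = self.cache[index]
        for step in self.suffix:  # everything after the first random step reruns every time
            x = step.fn(x)
        return x

rng = random.Random(0)
steps = [
    Step(lambda x: x * 10),                            # deterministic ("load")
    Step(lambda x: x + rng.random(), is_random=True),  # random ("augment")
    Step(lambda x: round(x, 3)),                       # deterministic, but after a random step: not cached
]
ds = ToyCachedDataset([1, 2, 3], steps)
a, b = ds[0], ds[0]
print(ds.prefix_calls)  # 1: the prefix ran once for item 0 despite two accesses
```

The last step is deterministic but still reruns on every access, which is why random transforms belong as late in the chain as possible.
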
Notice that the transforms used below have a 'D' at the end of their names. This marks them as *MapTransforms* (dictionary-based transforms): they expect data items to be *dictionaries*, and they operate on some of the *values* of those dictionaries. In contrast, a regular *Transform* operates on data items directly. For a MapTransform, we must specify the keys whose values the transform should operate on.

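In plain Python, the dictionary-transform idea looks roughly like this. `ToyMapTransform` is a hypothetical stand-in, not MONAI's `MapTransform`. Its `allow_missing_keys` argument mirrors the option used in the transform chain below, where Montgomery and Shenzhen items carry different segmentation keys.

```python
class ToyMapTransform:
    """Apply `fn` to the values of selected keys in a data-item dictionary."""
    def __init__(self, keys, fn, allow_missing_keys=False):
        self.keys = keys
        self.fn = fn
        self.allow_missing_keys = allow_missing_keys

    def __call__(self, data):
        out = dict(data)  # shallow copy so the input item is left unchanged
        for key in self.keys:
            if key not in out:
                if self.allow_missing_keys:
                    continue
                raise KeyError(key)
            out[key] = self.fn(out[key])
        return out

upper = ToyMapTransform(keys=["image", "sh_seg"], fn=str.upper, allow_missing_keys=True)
item = {"image": "mcucxr_0001_0.png", "mo_seg_left": "left.png", "id": "montgomery:0001"}
print(upper(item))
# {'image': 'MCUCXR_0001_0.PNG', 'mo_seg_left': 'left.png', 'id': 'montgomery:0001'}
```

Only the `image` value changes: `sh_seg` is absent from this Montgomery-style item and is skipped, and keys that were not listed pass through untouched.
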
In [None]:
image_size = 256

keys_to_delete = ['mo_seg_left', 'mo_seg_right', 'sh_seg']
keys_to_delete += [k+"_meta_dict" for k in keys_to_delete] + [k+"_transforms" for k in keys_to_delete]

# Base transform chain that is common to training and validation
load_and_union_masks = [
    monai.transforms.LoadImageD(reader='itkreader',keys = ['image']), # A few shenzhen images get mysteriously value-inverted with readers other than itkreader
    monai.transforms.LambdaD(keys=['image'], func = rgb_to_grayscale), # A few of the shenzhen imgs are randomly RGB encoded rather than grayscale colormap
    monai.transforms.LoadImageD(keys = ['mo_seg_left', 'mo_seg_right', 'sh_seg'], dtype="int8", allow_missing_keys=True),
    monai.transforms.TransposeD(keys = ['image', 'mo_seg_left', 'mo_seg_right', 'sh_seg'], indices = (1,0), allow_missing_keys=True),
    monai.transforms.AddChannelD(keys = ['image']),
    UnionMasksD(keys = ['mo_seg_left', 'mo_seg_right'], keyList=['mo_seg_left', 'mo_seg_right'], newKeyName='label'),
    UnionMasksD(keys = ['sh_seg',], keyList=['sh_seg'], newKeyName='label'), # using for one-hot conversion, not "union"
    monai.transforms.DeleteItemsD(keys = keys_to_delete),
    monai.transforms.ToTensorD(keys = ['image', 'label'])
]

# Transform for validation
transform_valid = monai.transforms.Compose([
    *load_and_union_masks,
    monai.transforms.ResizeD( # Resize to the final image size used by the network
        keys = ['image', 'label'],
        spatial_size=(image_size,image_size),
        mode = ['bilinear', 'nearest'],
        align_corners = [False, None]
    ),
])

# Transform for training
transform_train = monai.transforms.Compose([
    *load_and_union_masks,
    monai.transforms.ResizeD( # Standardize image size (e.g. good to do this before the spatial crop, for consistency)
        keys = ['image', 'label'],
        spatial_size=(512,512),
        mode = ['bilinear', 'nearest'],
        align_corners = [False, None]
    ),
    monai.transforms.RandZoomD(
        keys = ['image', 'label'],
        mode = ['bilinear', 'nearest'],
        align_corners = [False, None],
        prob=1.,
        padding_mode="constant",
        min_zoom = 0.7,
        max_zoom=1.3,
    ),
    monai.transforms.RandRotateD(
        keys = ['image', 'label'],
        mode = ['bilinear', 'nearest'],
        align_corners = [False, None],
        prob=1.,
        range_x = np.pi/8,
        padding_mode="zeros",
    ),
    monai.transforms.RandGaussianSmoothD(
        keys = ['image'],
        prob = 0.4
    ),
    monai.transforms.RandAdjustContrastD(
        keys = ['image'],
        prob=0.4,
    ),
    monai.transforms.RandSpatialCropD(
        keys = ['image', 'label'],
        roi_size = 64, # minimum roi_size, since random_size is enabled; max_roi_size defaults to the image size
        random_size = True
    ),
    monai.transforms.ResizeD( # Resize to desired final output size
        keys = ['image', 'label'],
        spatial_size=(image_size,image_size),
        mode = ['bilinear', 'nearest'],
        align_corners = [False, None]
    ),
])

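`RandSpatialCropD` with `random_size=True` draws a crop whose side lengths lie between `roi_size` and the image size, at a random location. The dictionary version applies the same crop to both `image` and `label`. The numpy sketch below shows the idea on a single channel-first array. `random_size_crop` is a hypothetical helper, and MONAI's exact sampling (for example, how sizes and positions are drawn) may differ.

```python
import numpy as np

def random_size_crop(image, min_size, rng):
    """Crop a random-size, random-position window from a (C, H, W) array."""
    _, height, width = image.shape
    # Side lengths drawn independently between min_size and the full extent
    crop_h = int(rng.integers(min_size, height + 1))
    crop_w = int(rng.integers(min_size, width + 1))
    # Top-left corner drawn so the window stays inside the image
    top = int(rng.integers(0, height - crop_h + 1))
    left = int(rng.integers(0, width - crop_w + 1))
    return image[:, top:top + crop_h, left:left + crop_w]

rng = np.random.default_rng(0)
image = np.zeros((1, 512, 512), dtype=np.float32)  # 512 x 512, as after the first ResizeD
patch = random_size_crop(image, 64, rng)
print(patch.shape)  # a (1, h, w) patch with 64 <= h, w <= 512
```

In the training chain, the final `ResizeD` then brings every patch back to `image_size`, so the network always sees inputs of the same shape but at varying zoom levels.
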
# PersistentDatasets

We take our lists of data items and our transforms and make datasets out of them. A *Persistent*Dataset runs the deterministic part of its transform chain the first time each data item is accessed and stores the result, so later accesses to the transformed data are faster. Unlike a *Cache*Dataset, which keeps its cache in RAM, a PersistentDataset keeps its cache on disk.

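The on-disk caching idea can be sketched with the standard library: hash the data item, and on the first access run the transform and pickle the result into the cache directory under that hash. `ToyPersistentCache` is a hypothetical illustration, not MONAI's implementation. One consequence of keying on the data item is worth keeping in mind: if the deterministic transforms change, old files in the cache directory may still be served, so clearing `cache_dir` after editing the transform chain is a sensible precaution.

```python
import hashlib
import json
import os
import pickle
import tempfile

class ToyPersistentCache:
    def __init__(self, transform, cache_dir):
        self.transform = transform
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        self.misses = 0

    def _path(self, item):
        # Key on a stable serialization of the data item (not on the transform)
        key = hashlib.md5(json.dumps(item, sort_keys=True).encode()).hexdigest()
        return os.path.join(self.cache_dir, key + ".pkl")

    def __call__(self, item):
        path = self._path(item)
        if os.path.exists(path):
            with open(path, "rb") as f:
                return pickle.load(f)
        self.misses += 1
        result = self.transform(item)
        with open(path, "wb") as f:
            pickle.dump(result, f)
        return result

with tempfile.TemporaryDirectory() as cache_dir:
    cache = ToyPersistentCache(lambda d: {**d, "loaded": True}, cache_dir)
    item = {"image": "MCUCXR_0001_0.png", "id": "montgomery:0001"}
    first = cache(item)
    second = cache(item)  # served from disk
    n_files = len(os.listdir(cache_dir))
print(cache.misses, n_files)  # 1 1
```
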
In [None]:
cache_dir = './cache'
dataset_train = monai.data.PersistentDataset(data_train, transform_train, cache_dir)
dataset_valid = monai.data.PersistentDataset(data_valid, transform_valid, cache_dir)

# Previewing

Here we preview some training images. If you just created `cache_dir` above, it should start out empty; as the previews below query the dataset, you will see `cache_dir` fill up.

In [None]:
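def preview(data_item, show_bdry = False, overlay_seg = True, figsize = (7,7)):
    """Show an image, optionally tinting the lung label and outlining its boundary.
    Note: show_bdry=True uses `bdry`, which is defined further below."""
    fig = plt.figure(figsize=figsize)
    im = data_item['image'].expand((3,)+data_item['image'].shape[1:]) # grayscale -> RGB
    im = im/im.max()
    seg = data_item['label'].float()
    if overlay_seg:
        im[1,:,:] *= 1-0.3*seg[1,:,:] # darken the green channel inside the lung label
    if show_bdry:
        seg_bdry = bdry(seg[1])
        mask = (seg_bdry == 1.)
        im[0,mask], im[1,mask], im[2,mask] = 1,0,0 # R, G, B
    im = np.transpose(im,axes=(1,2,0)) # channel-first -> channel-last for imshow
    plt.imshow(im, cmap='bone')
    plt.show()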

In [None]:
import random
i = random.choice(range(len(dataset_train)))
d = dataset_train[i]
preview(d, show_bdry=False, overlay_seg=False, figsize=(12,12))
d['image'].shape

# Define the Segmentation Network

In [None]:
spatial_dims = 2
image_channels = 1
seg_channels = 2 # background (channel 0), lung (channel 1)
seg_net_channel_seq = (16, 32, 64, 128, 256)
stride_seq = (2,2,2,2)
dropout_seg_net = 0.5
num_res_units = 2

seg_net = monai.networks.nets.UNet(
    spatial_dims = spatial_dims,
    in_channels = image_channels,
    out_channels = seg_channels, 
    channels = seg_net_channel_seq,
    strides = stride_seq,
    dropout = dropout_seg_net,
    num_res_units = num_res_units
)

num_params = sum(p.numel() for p in seg_net.parameters())
print(f"seg_net has {num_params} parameters")

# Defining the Loss Function

In [None]:
dice_loss = monai.losses.DiceLoss(
    to_onehot_y = False, # The labels we pass in are already in one-hot form
    softmax = True, # seg_net outputs raw logits (no final softmax), so the loss applies softmax itself
)

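For intuition about `softmax=True`, here is a simplified soft Dice loss in numpy. Softmax over the channel axis turns logits into per-pixel class probabilities, and the loss is one minus 2|X∩Y| / (|X| + |Y|), computed per image and channel and then averaged. `soft_dice_loss` is a hypothetical sketch; MONAI's `DiceLoss` has its own smoothing terms (`smooth_nr`, `smooth_dr`) and further options, so its values can differ slightly.

```python
import numpy as np

def softmax(logits, axis):
    shifted = logits - logits.max(axis=axis, keepdims=True)  # for numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)

def soft_dice_loss(logits, target_one_hot, eps=1e-5):
    """Mean soft Dice loss for (B, C, H, W) logits and a one-hot target of the same shape."""
    probs = softmax(logits, axis=1)
    spatial_axes = (2, 3)
    intersection = (probs * target_one_hot).sum(axis=spatial_axes)
    denominator = probs.sum(axis=spatial_axes) + target_one_hot.sum(axis=spatial_axes)
    dice = (2 * intersection + eps) / (denominator + eps)  # shape (B, C)
    return 1.0 - dice.mean()

# Toy 4 x 4 target: "lung" (channel 1) on the left half, background (channel 0) elsewhere
target = np.zeros((1, 2, 4, 4))
target[0, 1, :, :2] = 1
target[0, 0] = 1 - target[0, 1]
logits = 20.0 * (2 * target - 1)  # large positive logit on the true class
print(soft_dice_loss(logits, target))   # close to 0: confident and correct
print(soft_dice_loss(-logits, target))  # close to 1: confident and wrong
```
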
In [None]:
# Test drive
data_item = dataset_train[42]
with torch.no_grad():
    seg_pred = seg_net(data_item['image'].unsqueeze(0)) # shape is (1,2,256,256), which is (B,N,H,W)

dice_loss(
    seg_pred, # input: logits for each class
    data_item['label'].unsqueeze(0), # target, one-hot
)

# Defining a metric

When evaluating a model, [metrics](https://docs.monai.io/en/stable/metrics.html) provide more functionality than plain functions or loss functions. Here we will use the [mean Dice](https://docs.monai.io/en/stable/metrics.html#mean-dice) metric, which is
- an [IterationMetric](https://docs.monai.io/en/stable/metrics.html#iterationmetric): besides a batched tensor of model outputs, it can also take a _list_ of tensors (for example, a decollated batch);
- a [CumulativeMetric](https://docs.monai.io/en/stable/metrics.html#cumulative): results can be appended to it batch by batch and then aggregated when ready.

The real benefit of a CumulativeMetric is that the metric can be computed in a distributed manner across a large dataset. For more information, read about [distributed processes in PyTorch](https://pytorch.org/tutorials/beginner/dist_overview.html), and then check out [this MONAI example](https://github.com/Project-MONAI/tutorials/blob/master/modules/compute_metric.py).

We aren't doing any distributed computing in this notebook, but we will use the mean Dice metric during validation for demonstration.

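The cumulative pattern (call once per batch to append results, then `aggregate()`, then `reset()` before the next round) can be mimicked in a few lines. `ToyCumulativeDice` is hypothetical and only computes hard Dice on binary one-hot masks; its method names mirror how `dice_metric` is used in this notebook, but it is not MONAI's class. It also does not handle channels that are empty in both prediction and label (that would divide by zero).

```python
import numpy as np

class ToyCumulativeDice:
    def __init__(self):
        self.reset()

    def reset(self):
        self._scores = []  # one row of per-channel Dice scores per image

    def __call__(self, y_pred, y):
        """y_pred, y: binary one-hot arrays of shape (B, C, H, W)."""
        intersection = (y_pred * y).sum(axis=(2, 3))
        total = y_pred.sum(axis=(2, 3)) + y.sum(axis=(2, 3))
        self._scores.extend(2 * intersection / total)  # appends B rows of shape (C,)

    def aggregate(self):
        return float(np.mean(self._scores))

# Two tiny batches of 2 images, 2 channels, 2 x 2 pixels
y = np.zeros((2, 2, 2, 2))
y[:, 1, 0, :] = 1          # channel 1 ("lung") on the first row
y[:, 0] = 1 - y[:, 1]      # channel 0 (background) elsewhere

metric = ToyCumulativeDice()
metric(y, y)       # first batch: perfect predictions
metric(1 - y, y)   # second batch: every pixel assigned to the wrong class
print(metric.aggregate())  # 0.5
```
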
In [None]:
dice_metric = monai.metrics.DiceMetric()

# In order to use dice_metric, we need to discretize the output of seg_net and express it in one-hot form.
# The following transform is convenient for this. It turns logits of shape (num_seg_classes,H,W)
# into discrete one-hot labels of the same shape (num_seg_classes,H,W).
logits_to_discrete_one_hot = monai.transforms.AsDiscrete(argmax=True, to_onehot=2)

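What `AsDiscrete(argmax=True, to_onehot=2)` does to a single (C, H, W) logit map can be written out in numpy, which may help when reading shapes: take the argmax over the channel axis, then expand the class indices back into one channel per class. `logits_to_one_hot_np` is a hypothetical equivalent for illustration only.

```python
import numpy as np

def logits_to_one_hot_np(logits, num_classes):
    """(C, H, W) logits -> (num_classes, H, W) one-hot array of 0s and 1s."""
    class_index = logits.argmax(axis=0)                           # (H, W) integer labels
    one_hot = np.eye(num_classes, dtype=np.float32)[class_index]  # (H, W, num_classes)
    return np.moveaxis(one_hot, -1, 0)                            # back to channel-first

# Two classes on a 1 x 2 image: pixel 0 favours class 0, pixel 1 favours class 1
logits = np.array([[[2.0, -1.0]], [[0.5, 3.0]]])  # shape (2, 1, 2)
one_hot = logits_to_one_hot_np(logits, 2)
print(one_hot.shape)  # (2, 1, 2)
```
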
# Previewing Segmentation Network Outputs

Here we define a convenience function for previewing model outputs.

In [None]:
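binary_mask = lambda x : (x!=0).astype('float')
# Boundary pixels: wherever the mask changes value along rows or columns (prepend=0 also marks the top/left edges)
bdry = lambda s : binary_mask((np.abs(np.diff(s, axis=0, prepend=0)) + np.abs(np.diff(s, axis=1, prepend=0)))!=0)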

def preview_seg_net(data_item, figsize=(15,10), print_score = True, show_heatmap = False, show_bdry=False, show_post_processing=0):
    """
    Preview seg net prediciton
    
    Args:
        data_item: A data item to input into seg_net.
        figsize: figure size to be used at each matplotlib plotting call
        print_score: show Dice score
        show_heatmap: whether to show class probability image
        show_bdry: whether to draw the boundry
        show_post_processing: 0 to not show it,
            1 to show post processed result,
            2 to show post processed result and intermediate steps
    """
    
    seg_net.eval()
    
    with torch.no_grad():
        im_device = data_item['image'].to(next(seg_net.parameters()).device.type)
        seg_pred = seg_net(im_device.unsqueeze(0))[0].cpu()
        _, max_indices = seg_pred.max(dim=0)
        seg_pred_mask = (max_indices==1).type(torch.uint8)

        f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=figsize)

        im = data_item['image'].expand((3,)+data_item['image'].shape[1:])
        im = im/im.max()

        seg_true = data_item['label'].float()
        im_true = im.clone()
        im_true[1,:,:] *= 1-0.4*seg_true[1,:,:]
        if show_bdry:
            seg_true_bdry = bdry(seg_true[1])
            mask = (seg_true_bdry == 1.)
            im_true[0,mask], im_true[1,mask], im_true[2,mask] = 1,0,0 # R, G, B
        im_true = np.transpose(im_true,axes=(1,2,0))
        ax1.imshow(im_true, cmap='bone')
        ax1.set_title("true seg overlay")
        ax1.axis('off')

        ax2.imshow(max_indices)
        ax2.set_title("predicted seg")
        ax2.axis('off')

        im_pred = im.clone()
        im_pred[1,:,:] *= 1-0.4*seg_pred_mask
        if show_bdry:
            seg_pred_bdry = bdry(seg_pred_mask)
            mask = (seg_pred_bdry == 1.)
            im_pred[0,mask], im_pred[1,mask], im_pred[2,mask] = 1,0,0 # R, G, B
        im_pred = np.transpose(im_pred,axes=(1,2,0))
        ax3.imshow(im_pred, cmap='bone')
        ax3.set_title("predicted seg overlay")
        ax3.axis('off')

        plt.show();
        
        if show_heatmap:
            f, ax1 = plt.subplots(1, 1, figsize=figsize)
            ax1.imshow(seg_pred.softmax(dim=0)[1])
            ax1.axis('off')
            print("predicted seg class probability maps:")
            plt.show()
        
        if show_post_processing!=0:
            plt.figure(figsize = figsize)
            seg_post_process = SegmentationPostProcessing()
            seg_pred_processed = seg_post_process(seg_pred_mask)
            im_pred = im.clone()
            im_pred[1,:,:] *= 1-0.4*(seg_pred_processed==1)
            im_pred[0,:,:] *= 1-0.4*(seg_pred_processed==2)
            if show_bdry:
                seg_pred_bdry1 = bdry(seg_pred_processed==1)
                seg_pred_bdry2 = bdry(seg_pred_processed==2)
                mask1 = (seg_pred_bdry1 == 1.)
                mask2 = (seg_pred_bdry2 == 1.)
                im_pred[0,mask1], im_pred[1,mask1], im_pred[2,mask1] = 1,0,0 # R, G, B
                im_pred[0,mask2], im_pred[1,mask2], im_pred[2,mask2] = 0,1,0 # R, G, B
            im_pred = np.transpose(im_pred,axes=(1,2,0))
            plt.imshow(im_pred, cmap='bone')
            plt.title("post-processed segmentation overlay")
            plt.axis('off')
            plt.show()
            if show_post_processing>1:
                seg_post_process.preview_intermediate_steps()

        if print_score:
            dice_metric.reset()
            dice_metric(
                logits_to_discrete_one_hot(seg_pred).unsqueeze(0),
                data_item['label'].unsqueeze(0),
            )
            score = dice_metric.aggregate()
            print(f"Dice score: {score.item():.3f}")

In [None]:
# Sanity check: the discretized, one-hot prediction should have shape (1, 2, H, W)
a = logits_to_discrete_one_hot(seg_pred[0]).unsqueeze(0)
a.shape

In [None]:
# Try seg_net on a random image.
preview_seg_net(random.choice(dataset_train), show_bdry=True);

In [None]:
# Here's a good sanity check. Ground truth label tensors should have discrete values;
# let's make sure that's still the case after all the transforms are applied:
dataset_train[0]['label'].unique()

# Training

Define *dataloaders*, which draw their data from datasets and collate it into batches.

In [None]:
dataloader_train = monai.data.DataLoader(
    dataset_train,
    batch_size=16,
    num_workers=8,
    shuffle=True,
    collate_fn = list_data_collate_no_meta # (It's normally not necessary to define a custom collate_fn)
)

dataloader_valid = monai.data.DataLoader(
    dataset_valid,
    batch_size=64,
    num_workers=8,
    shuffle=False,
    collate_fn = list_data_collate_no_meta
)

Do the initial setup for training.

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

learning_rate = 1e-3
optimizer = torch.optim.Adam(seg_net.parameters(), learning_rate)

epoch_number = 0
training_losses = [] 
validation_scores = []
preview_index = random.choice(range(len(dataset_valid)))
best_validation_score = float('-inf')
best_validation_epoch = -1

validate_every = 5

Finally, the training loop! We save the model with the best validation score.

If you've already done some training and run the "`CHECKPOINT CELL; SAVE`" cell [below](#Saving), then you can skip the training loop and uncomment the "`CHECKPOINT CELL; LOAD`" cell.

You can also skip this training loop and try the alternative Ignite-based workflow at the end of the notebook.

In [None]:
max_epochs = 20

seg_net.to(device)

# shift things to always validate on last epoch
validate_this_epoch = lambda epoch_number : epoch_number%validate_every==(max_epochs-1)%validate_every

while epoch_number < max_epochs:
    
    print(f"Epoch {epoch_number+1}/{max_epochs} ...")
    
    if validate_this_epoch(epoch_number):
        preview_seg_net(dataset_valid[preview_index], figsize=(6,6), print_score=False);
    
    seg_net.train()
    losses = []
    for batch in dataloader_train:
        imgs = batch['image'].to(device)
        true_segs = batch['label'].to(device)

        optimizer.zero_grad()
        predicted_segs = seg_net(imgs)
        loss = dice_loss(predicted_segs, true_segs)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
    
    training_loss = np.mean(losses)
    training_losses.append([epoch_number, training_loss])
    
    print(f"\ttraining loss: {training_loss}")

    if validate_this_epoch(epoch_number):
    
        seg_net.eval()
        dice_metric.reset() # We will aggregate the dice metric
        with torch.no_grad():
            for batch in dataloader_valid:
                imgs = batch['image'].to(device)
                true_segs = batch['label'].to(device)
                predicted_segs = seg_net(imgs)
                dice_metric(
                    [logits_to_discrete_one_hot(predicted_seg) 
                        for predicted_seg in monai.data.decollate_batch(predicted_segs)],
                    true_segs
                )
            validation_score = dice_metric.aggregate().item()

        print(f"\tvalidation mean dice score: {validation_score}")
        
        validation_scores.append([epoch_number, validation_score])
        
        if validation_score > best_validation_score:
            best_validation_score = validation_score
            torch.save(seg_net.state_dict(),f'seg_net_bestval.pth')
            best_validation_epoch = epoch_number
    
    epoch_number +=1

del imgs, true_segs, predicted_segs, loss
torch.cuda.empty_cache()

seg_net.load_state_dict(torch.load('seg_net_bestval.pth'))
print(f"Loaded model state from the best validation score, which was during epoch {best_validation_epoch+1}.")

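To see which epochs `validate_this_epoch` selects, the same rule can be evaluated on its own. The function `validation_epochs` is a hypothetical helper that repeats the lambda's logic without touching the notebook's variables. With `max_epochs = 20` and `validate_every = 5`, validation runs on zero-based epochs 4, 9, 14 and 19, so the final epoch is always included.

```python
def validation_epochs(max_epochs, validate_every):
    """Zero-based epochs on which the training loop above runs validation."""
    last = (max_epochs - 1) % validate_every  # offset chosen so the last epoch is validated
    return [e for e in range(max_epochs) if e % validate_every == last]

print(validation_epochs(20, 5))  # [4, 9, 14, 19]
print(validation_epochs(12, 5))  # [1, 6, 11]
```
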
# Saving

In [None]:
# CHECKPOINT CELL; LOAD

# Uncomment the lines below to load a previous post-training state.

# run_id_load = '0023'
# run_id = None
# load_path = f'model{run_id_load}.pth'

# model_dict = torch.load(load_path, weights_only=False) # recent PyTorch needs weights_only=False to unpickle the full model object

# seg_net = model_dict['model']
# learning_rate = model_dict['learning_rate']
# optimizer = torch.optim.Adam(seg_net.parameters(), learning_rate)
# optimizer.load_state_dict(model_dict['optimizer_state_dict'])
# training_losses = model_dict['training_losses']
# validation_scores = model_dict['validation_scores']
# epoch_number = model_dict['epoch_number']
# best_validation_score = model_dict['best_validation_score']
# best_validation_epoch = model_dict['best_validation_epoch']
# image_size = model_dict['image_size']

In [None]:
run_id = '0029' # Set a new run ID each time
save_path = f'model{run_id}.pth'
if (os.path.exists(save_path)):
    del run_id, save_path
    raise Exception("Please change run_id so you don't overwrite things.")

In [None]:
# CHECKPOINT CELL; SAVE

torch.save(
    {
        'model': seg_net.cpu(), # note: .cpu() also moves seg_net itself to the CPU
        'optimizer_state_dict': optimizer.state_dict(),
        'learning_rate': learning_rate,
        'training_losses': training_losses,
        'validation_scores': validation_scores,
        'epoch_number': epoch_number,
        'best_validation_score': best_validation_score,
        'best_validation_epoch': best_validation_epoch,
        'image_size': image_size,
    }, 
    save_path
)

# Plotting

In [None]:
def plot_against_epoch_numbers(epoch_value_pairs, label):
    array = np.array(epoch_value_pairs)
    plt.plot(array[:,0], array[:,1], label=label)

plot_against_epoch_numbers(training_losses, label="training (Dice loss)")
plot_against_epoch_numbers(validation_scores, label="validation (mean Dice score)")
plt.legend()
plt.xlabel('epoch')
plt.ylabel('Dice loss / Dice score')
plt.title('seg net training')
if run_id is not None: plt.savefig(f'seg_net_losses{run_id}.png')
plt.show()

# Inference

Run the cell below to try the segmentation model on a random validation image.

We also show some post-processing, done separately using ITK. It ensures that we have two contiguous lung regions with no holes, and it separates them into left and right lungs.

In [None]:
# Try on a random validation image
data_item_index = random.choice(range(len(dataset_valid)))
data_item = dataset_valid[data_item_index]
print(data_item_index, data_item['id'])
with torch.no_grad():
    preview_seg_net(data_item, show_heatmap=False, show_bdry=True, show_post_processing=1);

In [None]:
# Model evaluation: the best mean Dice score reached on the validation set
print(best_validation_score)

##  Sliding Window Inference

*Sliding window inference* is a way to run our network on smaller pieces of an image and then stitch the outputs together into a full result. This can be useful when memory constraints prevent processing a whole image at once, for example when a model is deployed to machines with fewer computational resources than were available during training.

There are a couple of ways to do sliding window inference:
1. MONAI has the general concept of an [inferer](https://docs.monai.io/en/stable/inferers.html). An inferer represents the code that takes a PyTorch module (like `seg_net`) and an input, and yields a final output. For custom applications, a typical approach would be to derive from `Inferer` and implement your own `__call__`. There are also a few pre-made inferers, one of which is [SlidingWindowInferer](https://docs.monai.io/en/stable/inferers.html#slidingwindowinferer).
1. One can skip the inferer concept and directly use the function `monai.inferers.sliding_window_inference`. This function is used internally by `monai.inferers.SlidingWindowInferer` anyway, so the result is the same. For an example, check out the [spleen segmentation tutorial](https://github.com/Project-MONAI/tutorials/blob/master/3d_segmentation/spleen_segmentation_3d.ipynb).

We'll do (1) here. It's overkill in this example, since `seg_net` does not need much memory to operate on a full image, but let's see it in action.

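Before using the inferer, the core idea can be sketched in numpy: slide a window over the image with some overlap, run the model on each window, add each output into a full-size buffer, and divide by the number of windows that covered each pixel. `sliding_window_predict` is a hypothetical helper that weights every window uniformly, whereas the `mode='gaussian'` option used with `SlidingWindowInferer` weights window centers more heavily. The sketch also assumes the image is at least as large as the window.

```python
import numpy as np

def window_starts(length, window, stride):
    """Start offsets so that windows of size `window` cover [0, length)."""
    starts = list(range(0, length - window + 1, stride))
    if starts[-1] + window < length:  # make sure the last window reaches the edge
        starts.append(length - window)
    return starts

def sliding_window_predict(image, model, window, overlap=0.25):
    """Run `model` on (C, window, window) patches of a (C, H, W) image and average overlaps."""
    _, height, width = image.shape
    stride = max(1, int(window * (1 - overlap)))
    output_sum = None
    counts = np.zeros((height, width))
    for top in window_starts(height, window, stride):
        for left in window_starts(width, window, stride):
            patch_out = model(image[:, top:top + window, left:left + window])
            if output_sum is None:
                output_sum = np.zeros((patch_out.shape[0], height, width))
            output_sum[:, top:top + window, left:left + window] += patch_out
            counts[top:top + window, left:left + window] += 1
    return output_sum / counts

# With a per-pixel "model", stitching must reproduce the model applied to the whole image
image = np.random.default_rng(0).normal(size=(1, 100, 70))
model = lambda x: np.concatenate([x, -x])  # 1 input channel -> 2 output channels
result = sliding_window_predict(image, model, window=32)
print(np.allclose(result, model(image)))  # True
```

A real network is not per-pixel: each window only sees local context, which is why the stitched result can degrade when the window gets small.
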
In [None]:
# Play with the roi_size, which is the sliding window size, and run the cell below to see the effect
# Use multiples of 32, to stick to the expected number of strided convolutions
inferer = monai.inferers.SlidingWindowInferer(roi_size = 96, mode='gaussian')

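The size constraint comes from the strided convolutions: `stride_seq = (2,2,2,2)` halves the spatial size four times, so a window side should be divisible by the product of the strides, 2^4 = 16, for the downsampling and upsampling paths to line up. Multiples of 32 satisfy this, as do other multiples of 16. `valid_roi_size` is a hypothetical convenience for checking candidate `roi_size` values.

```python
import math

def valid_roi_size(size, strides=(2, 2, 2, 2)):
    """True if `size` is divisible by the product of the strides, so no rounding occurs."""
    return size % math.prod(strides) == 0

print([s for s in range(32, 129, 8) if valid_roi_size(s)])  # [32, 48, 64, 80, 96, 112, 128]
```
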
In [None]:
data_item_index = random.choice(range(len(dataset_valid)))
data_item = dataset_valid[data_item_index]
print(data_item_index, data_item['id'])

seg_net.eval()
img = data_item['image'].to(next(seg_net.parameters()).device.type)

with torch.no_grad():
    seg_output = inferer(img.unsqueeze(0),seg_net).cpu()

seg_mask = seg_output.argmax(dim=1)[0]

f, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize = (15,10))

# Normalized RGB copy of the image, kept on the CPU so it can be combined with the CPU masks below
im = data_item['image'].expand((3,)+data_item['image'].shape[1:])
im = im/im.max()

seg_true = data_item['label'].float()
im_true = im.clone()
im_true[1,:,:] *= 1-0.4*seg_true[1,:,:]
im_true = np.transpose(im_true,axes=(1,2,0))
ax1.imshow(im_true, cmap='bone')
ax1.set_title("true seg overlay")
ax1.axis('off')

ax2.imshow(seg_mask)
ax2.set_title("predicted seg")
ax2.axis('off')

im_pred = im.clone()
im_pred[1,:,:] *= 1-0.4*seg_mask
im_pred = np.transpose(im_pred,axes=(1,2,0))
ax3.imshow(im_pred, cmap='bone')
ax3.set_title("predicted seg overlay")
ax3.axis('off')

plt.show()

plt.figure(figsize = (10,10))
seg_post_process = SegmentationPostProcessing()
seg_pred_processed = seg_post_process(seg_mask.type(torch.uint8)) # same mask dtype as in preview_seg_net
im_pred = im.clone()
im_pred[1,:,:] *= 1-0.4*(seg_pred_processed==1)
im_pred[0,:,:] *= 1-0.4*(seg_pred_processed==2)
im_pred = np.transpose(im_pred,axes=(1,2,0))
plt.imshow(im_pred, cmap='bone')
plt.title("post-processed segmentation overlay")
plt.axis('off')
plt.show()

You may notice that the network struggles when the sliding window size becomes too small. Some of this can be alleviated by adjusting the `RandSpatialCropD` transform that is used during training, but that can only go so far. It's not surprising that the network is able to perform better when it has more global context to use.

# Engines appraoch to training

[Engines](https://docs.monai.io/en/stable/engines.html) and [event handlers](https://docs.monai.io/en/stable/handlers.html) are MONAI's window into the [PyTorch Ignite](https://pytorch.org/ignite/index.html) way of doing things. Instead of explicitly defining a `for` loop to do our training, we construct an ignite _engine_ called `trainer` below. Calling `trainer.run()` actually executes the training process. In this case the trainer is a MONAI `SupervisedTrainer`, which has a lot of sensible defaults for "image and label" style training.

Normally when using Ignite we define event handler callables and associate those callables to various events on our Ignite engine. In the MONAI `SupervisedTrainer`, we pass the handlers into a parameter `train_handlers`. What we actually have to pass is a list of objects that have an `attach` method, which is what actually associates callables to events. In the cell below we create some objects that have an `attach` method. These objects will do the same sort of work that we were previously doing inside an explicit `for` loop.

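The handler pattern can be mimicked with a toy engine in plain Python. The engine keeps a mapping from event names to callbacks and fires them during `run`, and a handler object's `attach` method registers its callbacks. `ToyEngine` and `EpochMeanLoss` are hypothetical and much simpler than Ignite's `Engine`; only the `add_event_handler`/`attach` shape is meant to resemble the handlers in the next cell.

```python
from collections import defaultdict

class ToyEngine:
    def __init__(self, process_fn):
        self.process_fn = process_fn       # called once per batch, returns a loss
        self.handlers = defaultdict(list)  # event name -> list of callbacks
        self.epoch = 0
        self.output = None

    def add_event_handler(self, event, fn):
        self.handlers[event].append(fn)

    def fire(self, event):
        for fn in self.handlers[event]:
            fn(self)

    def run(self, data, max_epochs):
        for epoch in range(1, max_epochs + 1):
            self.epoch = epoch
            self.fire("EPOCH_STARTED")
            for batch in data:
                self.output = self.process_fn(batch)
                self.fire("ITERATION_COMPLETED")
            self.fire("EPOCH_COMPLETED")

class EpochMeanLoss:
    """Handler object: `attach` wires its methods to engine events."""
    def __init__(self):
        self.history = []
        self.losses = []

    def attach(self, engine):
        engine.add_event_handler("EPOCH_STARTED", self.start)
        engine.add_event_handler("ITERATION_COMPLETED", self.record)
        engine.add_event_handler("EPOCH_COMPLETED", self.finish)

    def start(self, engine):
        self.losses = []

    def record(self, engine):
        self.losses.append(engine.output)

    def finish(self, engine):
        self.history.append((engine.epoch, sum(self.losses) / len(self.losses)))

tracker = EpochMeanLoss()
engine = ToyEngine(process_fn=lambda batch: float(sum(batch)))  # stand-in "loss" per batch
tracker.attach(engine)
engine.run(data=[[1, 2], [3, 4]], max_epochs=2)
print(tracker.history)  # [(1, 5.0), (2, 5.0)]
```
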
Instead of creating custom handlers, one could also look to the [large collection of ready-to-use handlers provided by MONAI](https://docs.monai.io/en/stable/handlers.html). The reader is encouraged to browse through this collection, but for now I find it easier to see what's going on by making our own event handlers.

In [None]:
# Defining event handlers

import os

import monai
import numpy as np
import torch
from ignite.engine import Events

# An event handler that keeps track of average training losses
class TrainingLossCollector:
    def __init__(self):
        self.losses = [] # iteration losses for the current epoch
        self.training_losses = [] # overall running list of losses by epoch, each averaged over a full epoch
    def attach(self, engine):
        engine.add_event_handler(monai.engines.utils.IterationEvents.LOSS_COMPLETED, self.addIterationLoss)
        engine.add_event_handler(Events.EPOCH_STARTED, self.onEpochStart)
        engine.add_event_handler(Events.EPOCH_COMPLETED, self.onEpochEnd)
    def addIterationLoss(self, engine):
        self.losses.append(engine.state.output['loss'].item())
    def onEpochStart(self, engine):
        self.losses=[]
    def onEpochEnd(self, engine):
        epoch_mean_loss = np.mean(self.losses)
        self.training_losses.append([engine.state.epoch, epoch_mean_loss])
        print(f"Epoch {engine.state.epoch} training loss: {epoch_mean_loss}")

# An event handler that runs validation, keeps track of validation losses, and keeps track of the best model.
# Instead of writing a custom handler here, we could have constructed a monai.engines.SupervisedEvaluator
# and then created a validation handler automatically by using monai.handlers.ValidationHandler.
# To get the exact functionality I wanted, I found it more straightforward to just write the handler.
class CustomValidationHandler:
    
    best_segnet_filename = 'CustomValidationHandler_seg_net_bestval.pth'
    
    def __init__(self, seg_net, dice_metric, dataloader_valid, device):
        self.seg_net = seg_net
        self.dice_metric = dice_metric
        self.dataloader_valid = dataloader_valid
        self.device = device
        self.best_validation_score = float('-inf')
        self.best_validation_epoch = -1
        self.validation_scores = []
    
    def attach(self, engine):
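        # Validate once before training, every `validate_every` epochs, and once at the end.
        # Note: `validate_every` is read from the notebook's global scope.
        engine.add_event_handler(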
            Events.EPOCH_COMPLETED(every=validate_every) | Events.COMPLETED | Events.STARTED,
            self.validate
        )
        engine.add_event_handler(
            Events.COMPLETED,
            self.load_best_seg_net
        )
        
    def validate(self, engine):
        self.seg_net.eval()
        self.dice_metric.reset()
        with torch.no_grad():
            for batch in self.dataloader_valid:
                imgs = batch['image'].to(self.device)
                true_segs = batch['label'].to(self.device)
                predicted_segs = self.seg_net(imgs)
                self.dice_metric(
                    [logits_to_discrete_one_hot(predicted_seg) 
                        for predicted_seg in monai.data.decollate_batch(predicted_segs)],
                    true_segs
                )
            validation_score = self.dice_metric.aggregate().item()

        print(f"Epoch {engine.state.epoch} validation mean dice score: {validation_score}")
        
        self.validation_scores.append([engine.state.epoch, validation_score])
        
        if validation_score > self.best_validation_score:
            self.best_validation_score = validation_score
            torch.save(self.seg_net.state_dict(), self.best_segnet_filename)
            self.best_validation_epoch = engine.state.epoch
    
    def load_best_seg_net(self, engine):
        if os.path.exists(self.best_segnet_filename):
            self.seg_net.load_state_dict(torch.load(self.best_segnet_filename))
            print(f"Loaded the model state with the best validation score ({self.best_validation_score:.4f}, epoch {self.best_validation_epoch}).")

# An event handler that shows a preview of a specific data item every few epochs,
# so we have something nice to look at while watching the training take place.
class ImageDisplayHandler:
    def __init__(self, data_item, preview_every):
        """Preview a specific data item every preview_every epochs."""
        self.data_item = data_item
        self.preview_every = preview_every
    def attach(self, engine):
        engine.add_event_handler(Events.EPOCH_COMPLETED(every=self.preview_every), self)
    def __call__(self, engine):
        preview_seg_net(self.data_item, figsize=(6,6), print_score=False)
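
The composite event in `CustomValidationHandler.attach` decides when validation runs. The pure-Python sketch below lists those trigger points for given `max_epochs` and `validate_every`; the helper `validation_points` is hypothetical and only mimics the logic of `Events.EPOCH_COMPLETED(every=validate_every) | Events.COMPLETED | Events.STARTED` without using Ignite. It makes one consequence visible: when `max_epochs` is a multiple of `validate_every`, the last epoch is validated twice, which appends a duplicate entry to `validation_scores`.

```python
from typing import List, Tuple


def validation_points(max_epochs: int, validate_every: int) -> List[Tuple[str, int]]:
    """Hypothetical helper mimicking
    EPOCH_COMPLETED(every=validate_every) | COMPLETED | STARTED.

    Returns (event name, engine.state.epoch) pairs in firing order.
    """
    points = [("STARTED", 0)]  # state.epoch is still 0 before the first epoch
    for epoch in range(1, max_epochs + 1):
        if epoch % validate_every == 0:
            points.append(("EPOCH_COMPLETED", epoch))
    points.append(("COMPLETED", max_epochs))  # fires once after the last epoch
    return points


print(validation_points(max_epochs=10, validate_every=5))
# [('STARTED', 0), ('EPOCH_COMPLETED', 5), ('EPOCH_COMPLETED', 10), ('COMPLETED', 10)]
```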

In [None]:
# Create the handler objects

validation_handler = CustomValidationHandler(seg_net, dice_metric, dataloader_valid, device)
image_display_handler = ImageDisplayHandler(dataset_valid[preview_index], 5)
training_loss_collector = TrainingLossCollector()

In [None]:
# Creating the training engine

seg_net.to(device)

trainer = monai.engines.SupervisedTrainer(
    device=device,
    max_epochs=10,
    train_data_loader=dataloader_train,
    network=seg_net,
    optimizer=optimizer,
    loss_function=dice_loss,
    train_handlers=[training_loss_collector, validation_handler, image_display_handler],
    amp=True,
)

Notice that [automatic mixed precision](https://docs.monai.io/en/stable/highlights.html#auto-mixed-precision-amp) was enabled simply by setting `amp=True` in the engine parameters. With plain PyTorch this takes more work: the forward pass and loss computation go inside an autocast context, and the loss is scaled with a gradient scaler before the backward pass (see the [PyTorch docs on AMP](https://pytorch.org/docs/stable/notes/amp_examples.html), or [this MONAI tutorial](https://github.com/Project-MONAI/tutorials/blob/master/acceleration/automatic_mixed_precision.ipynb), which does AMP explicitly without using an engine).

Now we can start training:

In [None]:
trainer.run()

The following cell defines the variables that the rest of the notebook expects, so you can return to where you were [above](#Saving) and continue from there, having skipped the explicit training `for` loop.

In [None]:
training_losses = training_loss_collector.training_losses
validation_scores = validation_handler.validation_scores
best_validation_score = validation_handler.best_validation_score