### Training a basic segmentation algorithm in Pytorch Lightning

In this python notebook, I will outline a barebones Pytorch Lightning (PL) implementation of training a network for the CAMUS echocardiography segmentation challenge.

Pytorch Lightning is a library that is completely built on Pytorch, but it re-organizes pytorch code into something more concise and readable.

Just like in Pytorch, PL starts with defining an architecture as a class-object. In addition, we must also define a training and validation step as a class-function -- such that later we can simply call trainer.fit(). PL has scripted the rest for us, which spares us a lot of boiler-plate coding.

So, we define a model as follows, note that we inherit the pl.LightningModule base class intead of torch.nn.Module:

In [1]:
import os
import torch
from torch import optim, nn, utils, Tensor
from torchvision.transforms import ToTensor
import pytorch_lightning as pl
import torchvision.transforms.functional as TF
import torch.nn.functional as F
import monai
from unet import UNet
from metrics import *

# define the LightningModule
class SegmentationModel(pl.LightningModule):
    def __init__(self, out_path_test='./test_results/', pretrained_weights=None):
        super().__init__()
        self.pretrained_weights = pretrained_weights
        self.model = UNet(n_channels=1, n_classes=4, bilinear=False, scaling=4)
        if pretrained_weights:
            self.model.load_state_dict(torch.load(pretrained_weights))
        self.criterion = nn.CrossEntropyLoss()
        self.out_path_test = out_path_test
        
        
    def forward(self, x, method='train'):
        if method=='train':
            return self.model.forward(x)
        elif method=='ssl':
            return self.model.forward_ssl(x)
        else:
            print('method not recognized')
    

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        if self.current_epoch < 2 and self.pretrained_weights:
            # Freeze model encoder weights
            for param in self.model.encoder.parameters():
                param.requires_grad = False
        else:
            # Unfreeze model encoder weights
            for param in self.model.encoder.parameters():
                param.requires_grad = True

        x, y = batch
        y_hat = self.forward(x) #self.forward(x)
        loss = self.criterion(y_hat, y)
        # Logging (to TensorBoard  by default)
        dice_loss_train =  dice_loss(y_hat, y)

        self.log("loss", {'train': loss.item() } )
        self.log("dice_loss", {'train': dice_loss_train.item() } )

        return loss  
    
    
    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        loss = self.criterion(y_hat, y)
        hs_distances = calculate_hausdorff_distance(y_hat.cpu().detach(),
                                         y.cpu().detach())
        dice_loss_train =  dice_loss(y_hat, y)
        self.log("loss", {'val': loss })
        self.log("dice_loss", {'val': dice_loss_train.item() } )
        self.log("hs_distance region 1", {'val': hs_distances[0].item() } )
        self.log("hs_distance region 2", {'val': hs_distances[1].item() } )
        self.log("hs_distance region 3", {'val': hs_distances[2].item() } )
        self.log("hs_distance region 4", {'val': hs_distances[3].item() } )
        self.log("hs_distance average", {'val': np.mean(hs_distances) } )
        return loss    
    

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=3e-4)
        return optimizer
    
    
    def predict_step(self, sample, sample_idx):
        x, y = sample  # INCOMPATIBLE WITH batch_size > 1

        y_hat = self.forward(x)
        
        # log results as images
        fig, (ax0, ax1, ax2) = plt.subplots(nrows=1, ncols=3, figsize=(16,9));

        ax0.set_title('prediction', fontsize=30)
        ax1.set_title('ground truth', fontsize=30)
        ax2.set_title('image', fontsize=30)

        ax0.imshow(y_hat.cpu().argmax(dim=1)[0], vmax=3); ax0.axis('off')
        ax1.imshow(y[0].cpu(),  vmax=3); ax1.axis('off')
        ax2.imshow(x[0, 0].cpu(), cmap='Greys_r'); ax2.axis('off')
        
        fig.tight_layout()
        
        tensorboard = self.logger.experiment
        tensorboard.add_figure('inference results', fig, sample_idx)
        
        
    def test_step(self, sample, sample_idx):
        x, x_attrs, info = sample # INCOMPATIBLE WITH batch_size > 1
        
        # pad until divisible by 2^n (because of n up- and downsampling steps):
        two_n = 16
        row_pad = (-x.shape[-2]) % two_n
        col_pad = (-x.shape[-1]) % two_n
        x_padded = F.pad(x, (col_pad, 0, row_pad, 0))
        
        y_hat = self.forward(x_padded)[:, :, row_pad:, col_pad:]

        
        # log results to Tensorboard:
        fig, (ax0, ax1) = plt.subplots(nrows=1, ncols=2, figsize=(11,9));

        ax0.set_title('prediction', fontsize=30)
        ax1.set_title('image', fontsize=30)
        
        ax0.imshow(y_hat.cpu().argmax(dim=1)[0], vmax=3); ax0.axis('off')
        ax1.imshow(x[0, 0].cpu(), cmap='Greys_r'); ax1.axis('off')
        
        fig.tight_layout()
        
        tensorboard = self.logger.experiment
        tensorboard.add_figure('test results', fig, sample_idx)
        
        # write output to mhd+raw files (but first convert to sitk):
        mask = TF.resize(y_hat, x_attrs['shape'], InterpolationMode.BICUBIC)
        mask = mask.argmax(dim=1, keepdim=True).type(torch.uint8)[0]
        
        mask_sitk = sitk.GetImageFromArray(mask.cpu().numpy())
        mask_sitk.SetSpacing([x.item() for x in x_attrs['spacing']])

        filename = "_".join(x[0] for x in info[:3]) + '.mhd'
        out_path = os.path.join(self.out_path_test, filename)

        writer = sitk.ImageFileWriter()
        writer.SetFileName(out_path)
        writer.Execute(mask_sitk)

In [2]:
# init the model
pretrained_weights = './models/vicreg_encoder.pth' # Choose either this or None
#pretrained_weights = None
model = SegmentationModel(pretrained_weights=pretrained_weights)

As a next step, we need to define a dataset class. This is one is identical to a pytorch dataset object: we simply define how we want to load our data samples, and make them retrievable by defining indices. 

To organize the data, I recurse through the dataset directories, and put the relevant info in a Pandas dataframe (df). Every row in the df represents one image, and calling their row index will retrieve the image and mask data with the __getitem__ method.

In [3]:
from torch.utils.data import Dataset
import pandas as pd
import SimpleITK as sitk
from torchvision.transforms.functional import resize, center_crop
from torchvision.transforms import InterpolationMode

from sklearn.model_selection import train_test_split

class CamusDataset(Dataset):
    def __init__(self, data_path, image_size=(512, 512)):
        super().__init__()
        self.root = data_path

        self.data_list = []
        self.image_size = image_size

        for root, dirs, files in os.walk(self.root):
            for file in files:
                if file.split('_')[-1] == 'gt.mhd':
                    sample = file.split('_')[:3] # [patient, view, ED/ES]
                    self.data_list.append(sample)
        self.df = pd.DataFrame(self.data_list, columns=['patient', 'view', 'ED/ES'])

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = list(self.df.loc[idx])
        path = os.path.join(self.root, row[0], "_".join(row))
        image_sitk = sitk.ReadImage(f'{path}.mhd', sitk.sitkFloat32)

        # get pixel spacing to correct aspect ratio
        spacing = image_sitk.GetSpacing()
        aspect_ratio = spacing[1]/spacing[0]

        # convert to numpy
        image = sitk.GetArrayFromImage(image_sitk) / 255
        mask = sitk.GetArrayFromImage(sitk.ReadImage(f'{path}_gt.mhd', sitk.sitkFloat32))

        # compute aspect ratio of pixel(mm) and image(pixels)
        pixel_aspect = spacing[1] / spacing[0]
        image_aspect = image_sitk.GetHeight() / image_sitk.GetWidth()

        # preprocess image and mask
        image, mask = torch.Tensor(image), torch.Tensor(mask)
        size =  (self.image_size[0], int(image.shape[2]*image_aspect*pixel_aspect))

        image  = resize(image, size, interpolation=InterpolationMode.BICUBIC)
        mask = resize(mask, size, interpolation=InterpolationMode.NEAREST)

        image, mask = center_crop(image, self.image_size), center_crop(mask, self.image_size)
        mask = mask.squeeze()

        return image, mask.to(torch.long)



Now we instantiate the dataset object, and split the data into a training and validation set.
The dataloaders control how the data will be batched.

In [4]:
import matplotlib.pyplot as plt
import numpy as np
from torch.utils.data.sampler import SubsetRandomSampler

# init dataset object
dataset = CamusDataset(data_path=r"../data/training", image_size=(512, 512))

# split into train and validation set, by splitting indices:

indices = np.arange(len(dataset))
train_indices, val_indices = train_test_split(indices, random_state=42)


# init samplers: 
train_sampler = SubsetRandomSampler(train_indices)
val_sampler = SubsetRandomSampler(val_indices)

# init loaders: (set num_workers to 8 * number of gpus, 0 for debugging)
train_loader = utils.data.DataLoader(dataset, sampler=train_sampler, batch_size=5, num_workers=0)
val_loader = utils.data.DataLoader(dataset, sampler=val_sampler, batch_size=5, num_workers=0)

Finally, we define a trainer, which takes care of the rest. We train by simply calling trainer.fit(): 

Run ```tensorboard --logdir=lightning_logs --samples_per_plugin images=200``` in your (anaconda/bash) terminal to track the loss over time.

In [5]:
# train the model: cc 0.26 after 1 epoch
trainer = pl.Trainer(max_epochs=25, gpus=1)#, limit_train_batches=10)
trainer.fit(model, train_loader, val_loader)
# Save the model
model_name = 'unet.pth' if pretrained_weights is None else 'ssl_unet_pretrained.pth'
trainer.save_checkpoint('./models/'+model_name)

  f"Setting `Trainer(gpus={gpus!r})` is deprecated in v1.7 and will be removed"
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | model     | UNet             | 1.9 M 
1 | criterion | CrossEntropyLoss | 0     
-----------------------------------------------
1.9 M     Trainable params
0         Non-trainable params
1.9 M     Total params
7.769     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=25` reached.


Let's display some (validation) results in Tensorboard:

In [6]:
# automatically auto-loads the best weights from the previous run
val_loader_log = utils.data.DataLoader(dataset, sampler=val_sampler, batch_size=1, num_workers=0)
outputs = trainer.predict(dataloaders=val_loader_log)

  + f" You can pass `.{fn}(ckpt_path='best')` to use the best model or"
Restoring states from the checkpoint path at /home/tin/Documents/GitHub/CAMUS-challenge/gino_baseline/lightning_logs/version_0/checkpoints/epoch=24-step=6750.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /home/tin/Documents/GitHub/CAMUS-challenge/gino_baseline/lightning_logs/version_0/checkpoints/epoch=24-step=6750.ckpt


Predicting: 270it [00:00, ?it/s]



Finally, if we're happy with the results, we should run our network on the test set, and save the resulting masks in a format that's compatible with the CAMUS challenge website (The website allows 4 test submissions).


First we need to adjust the dataset class, as the ```__getitem__``` method shouldn't try to load a ground truth mask.

In [7]:
class CamusTestSet(Dataset):
    def __init__(self, data_path, image_size=(512, 512)):
        super().__init__()
        self.root = data_path
        
        self.data_list = []
        self.image_size = image_size
        
        for root, dirs, files in os.walk(self.root):
            for file in files:
                suffix = file.split('_')[-1]
                if suffix in ['ED.mhd', 'ES.mhd']:
                    sample = file.split('.')[0].split('_')[:3] # [patient, view, ED/ES]
                    self.data_list.append(sample)
        self.df = pd.DataFrame(self.data_list, columns=['patient', 'view', 'ED/ES'])

        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        row = list(self.df.loc[idx])
        path = os.path.join(self.root, row[0], "_".join(row))
        
        image_sitk = sitk.ReadImage(f'{path}.mhd', sitk.sitkFloat32)
        image = sitk.GetArrayFromImage(image_sitk) / 255
        
        # get pixel spacing to correct aspect ratio
        spacing = image_sitk.GetSpacing()
        pixel_aspect = spacing[1]/spacing[0]
        image_aspect = image_sitk.GetHeight() / image_sitk.GetWidth()
        
        # preprocess image
        image = torch.Tensor(image)
        size =  (self.image_size[0], int(image.shape[2]*image_aspect*pixel_aspect))
        image  = resize(image, size, interpolation=InterpolationMode.BICUBIC)
        
        image_attrs = dict(
            shape = [image_sitk.GetHeight(), image_sitk.GetWidth()],
            spacing = spacing
        )
        return image, image_attrs, row

Now we are ready to test the model. Earlier, in the model Class definition, I defined the ```test_step```, where it writes a predicted mask to an .mhd file format, which is compatible with the Challenge submission website.

In [8]:
test_set = CamusTestSet(data_path=r"C:\Users\Tadija\Desktop\data\testing")
test_loader = utils.data.DataLoader(test_set, batch_size=1, num_workers=0)
trainer.test(ckpt_path="best", dataloaders=test_loader);

Restoring states from the checkpoint path at /home/tin/Documents/GitHub/CAMUS-challenge/gino_baseline/lightning_logs/version_0/checkpoints/epoch=24-step=6750.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Loaded model weights from checkpoint at /home/tin/Documents/GitHub/CAMUS-challenge/gino_baseline/lightning_logs/version_0/checkpoints/epoch=24-step=6750.ckpt
  f"Total length of `{dataloader.__class__.__name__}` across ranks is zero."
