# U-Net analysis

The following is a u-net analysis based off the templates for torchio found at https://colab.research.google.com/github/fepegar/torchio-notebooks/blob/main/notebooks/TorchIO_tutorial.ipynb#scrollTo=p-V_kHC5BvST

We will train a [3D U-Net](https://link.springer.com/chapter/10.1007/978-3-319-46723-8_49) to perform the landmark detection of the corneas and rhabdoms

Install some pypi packages

In [40]:
!pip install --quiet --upgrade pip
!pip install --quiet unet==0.7.7
!pip install --quiet torchio==0.18.33

Install pytorch following recommendations at https://pytorch.org/

For me, this was running the following command:

```bash
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
```

Import packages

In [41]:
import enum
import os
import time
import random
import multiprocessing
from pathlib import Path
import copy
from datetime import datetime

import torch
import torchvision
import torchio as tio
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
import torch.nn
import monai

import numpy as np
from unet import UNet
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

from IPython import display
from tqdm.notebook import tqdm

seed = 42

random.seed(seed)
torch.manual_seed(seed)
%config InlineBackend.figure_format = 'retina'
num_workers = multiprocessing.cpu_count()
plt.rcParams['figure.figsize'] = 12, 6

print('TorchIO version:', tio.__version__)

TorchIO version: 0.18.33


# Define DataModule
We will use a LightningDataModule to handle our data.

In [42]:
# from multiprocessing import Manager
# class SubjectsDataset(tio.SubjectsDataset):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         # Only changes.
#         manager = Manager()
#         self._subjects = manager.list(self._subjects)


class DataModule(pl.LightningDataModule):
    def __init__(
        self,
        batch_size,
        train_val_ratio,
        train_images_dir,
        train_labels_dir,
        test_images_dir,
        patch_size,
        samples_per_volume,
        max_length,
        num_workers = 4,
    ):
        super().__init__()
        self.batch_size = batch_size
        self.train_val_ratio = train_val_ratio
        self.train_images_dir = train_images_dir
        self.train_labels_dir = train_labels_dir
        self.test_images_dir = test_images_dir
        self.patch_size = patch_size
        self.samples_per_volume = samples_per_volume
        self.num_workers = num_workers
        self.max_length = max_length

    def get_max_shape(self, subjects):
        import numpy as np
        dataset = tio.SubjectsDataset(subjects)
        shapes = np.array([s.spatial_shape for s in dataset])
        return shapes.max(axis=0)

    def prepare_data(self):
        self.subjects = []

        # find all the .nii files
        images = []
        labels = []
        for file in os.listdir(self.train_images_dir):
            if file.endswith('.nii.gz'):
                spltstr = file.split('-')
                images.append([spltstr[0], spltstr[-1]])
        for file in os.listdir(self.train_labels_dir):
            if file.endswith('.nii.gz'):
                spltstr = file.split('-')
                labels.append([spltstr[0], spltstr[-1]])
            
        breakpoint()
        filenames = sorted(list(set(images) & set(labels)))
        print(f'Found {len(filenames)} labelled images for analysis')

        # now add them to a list of subjects
        for filename in filenames:
            subject = tio.Subject(
                image=tio.ScalarImage(self.train_images_dir + filename + '.nii', check_nans=True),
                label_corneas=tio.Image(self.train_labels_dir + filename + '_corneas.nii', type=tio.LABEL, check_nans=True),
                label_rhabdoms=tio.Image(self.train_labels_dir + filename + '_rhabdoms.nii', type=tio.LABEL, check_nans=True),
                filename=filename
            )
            self.subjects.append(subject)

        # collect test images
        self.test_subjects = []
        for file in os.listdir(self.test_images_dir):
            if file.endswith('.nii'):
                subject = tio.Subject(image=tio.ScalarImage(self.test_images_dir + file))
                self.test_subjects.append(subject)
        
    def get_preprocessing_transform(self):
        preprocess = tio.Compose([
            tio.ToCanonical(),
            tio.HistogramStandardization({'image': 'landmarks.npy'}, masking_method=tio.ZNormalization.mean),
            tio.ZNormalization(masking_method=tio.ZNormalization.mean),
            tio.EnsureShapeMultiple(8) # for the u-net
        ])
        return preprocess
    
    def get_augmentation_transform(self):
        augment = tio.Compose([
            # tio.RandomMotion(p=0.2),
            # tio.RandomNoise(p=0.5),
            tio.RandomFlip(),
            # tio.OneOf({
            #     tio.RandomAffine(): 0.8,
            #     tio.RandomElasticDeformation(): 0.2,
            # }, p=0.8),
        ])
        return augment

    def get_sampler(self):
        self.sampler = tio.UniformSampler(self.patch_size)
    
    def setup(self, stage=None):
        num_subjects = len(self.subjects)
        num_train_subjects = int(round(num_subjects * self.train_val_ratio))
        num_val_subjects = num_subjects - num_train_subjects
        splits = num_train_subjects, num_val_subjects
        train_subjects, val_subjects = random_split(self.subjects, splits)

        self.preprocess = self.get_preprocessing_transform()
        augment = self.get_augmentation_transform()
        self.transform = tio.Compose([self.preprocess, augment])

        self.train_set = tio.SubjectsDataset(train_subjects, transform=self.transform)
        self.val_set = tio.SubjectsDataset(val_subjects, transform=self.preprocess)
        self.test_set = tio.SubjectsDataset(self.test_subjects, transform=self.preprocess)

        self.get_sampler()

        # self.train_queue = tio.Queue(
        #     self.train_set,
        #     self.max_length,
        #     self.samples_per_volume,
        #     self.sampler,
        #     num_workers=num_workers
        # )
        # self.val_queue = tio.Queue(
        #     self.val_set,
        #     self.max_length,
        #     self.samples_per_volume,
        #     self.sampler,
        #     num_workers=num_workers
        # )
        # self.test_queue = tio.Queue(
        #     self.test_set,
        #     self.max_length,
        #     self.samples_per_volume,
        #     self.sampler,
        #     num_workers=num_workers
        # )

    
    def train_dataloader(self):
        return DataLoader(self.train_set, batch_size=self.batch_size, num_workers=1)
        # return DataLoader(self.train_queue, batch_size=self.batch_size, num_workers=1)

    def val_dataloader(self):
        return DataLoader(self.val_set, batch_size=self.batch_size, num_workers=1)
        # return DataLoader(self.val_queue, batch_size=self.batch_size, num_workers=1)

    def test_dataloader(self):
        return DataLoader(self.test_set, batch_size=self.batch_size, num_workers=1)
        # return DataLoader(self.test_queue, batch_size=self.batch_size, num_workers=1)


Now let's initialise the data module

In [43]:
data = DataModule(
    batch_size=1,
    train_val_ratio=0.8,
    train_images_dir='./dataset/crab_images/',
    train_labels_dir='./dataset/crab_labels/',
    test_images_dir='./dataset/crab_test/',
    patch_size=64,
    samples_per_volume=4,
    max_length=40,
    num_workers=8
)

Let's see how many images we are working with. Normally, we won't need to run the below code as PyTorch Lightning will automatically call prepare_data and setup for us. 

In [44]:
data.prepare_data()
data.setup()
print('Training:  ', len(data.train_set))
print('Validation: ', len(data.val_set))
print('Test:      ', len(data.test_set))

[['flammula_20180307', '31.nii.gz']]


Exception: 

# Lightning model

We are using a standard Pytorch lightning model to define the logic for training and validation steps.

Read this for ideas: https://docs.monai.io/en/latest/highlights.html#network-architectures

# TODO:
- add in support for multiple labels by concatenating them into two channels. E.g. with:

```
def prepare_batch(self, batch):
    return batch['image'][tio.DATA], concatenate(batch['label_corneas'][tio.DATA], batch['label_rhabdoms'][tio.DATA])

```
or (probably better)

making prepare_data load in the Subject 'label_corneas' and 'label_rhabdoms' as 'label' with two channels

- add in support to have a sliding window sampler that can then reconstruct the test image

- add in support to have a continuous output. THis is probably by using a different loss function (i.e. MSE)

In [24]:
class Model(pl.LightningModule):
    def __init__(self, net, criterion, learning_rate, optimizer_class):
        super().__init__()
        self.lr = learning_rate
        self.net = net
        self.criterion = criterion
        self.optimizer_class = optimizer_class
    
    def configure_optimizers(self):
        optimizer = self.optimizer_class(self.parameters(), lr=self.lr)
        return optimizer
    
    def prepare_batch(self, batch):
        # print('Make some logic here to concatenate the two types of labels into two channels')
        # breakpoint()
        return batch['image'][tio.DATA], batch['label_corneas'][tio.DATA]
    
    def infer_batch(self, batch):
        x, y = self.prepare_batch(batch)
        y_hat = self.net(x)
        return y_hat, y

    def training_step(self, batch, batch_idx):
        y_hat, y = self.infer_batch(batch)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        return loss
    
    def validation_step(self, batch, batch_idx):
        y_hat, y = self.infer_batch(batch)
        loss = self.criterion(y_hat, y)
        self.log('val_loss', loss)
        return loss

In [25]:
# TODO: check what 16bit precision is
unet = monai.networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    out_channels=1,
    channels=(8, 16, 32, 64),
    strides=(2, 2, 2),
)

model = Model(
    net=unet,
    criterion=torch.nn.MSELoss(),
    learning_rate=1e-2,
    optimizer_class=torch.optim.AdamW,
)
early_stopping = pl.callbacks.early_stopping.EarlyStopping(
    monitor='val_loss',
)
trainer = pl.Trainer(
    gpus=1,
    precision=16,
    callbacks=[early_stopping],
)
trainer.logger._default_hp_metric = False

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.


# Train
Let's start fitting the model!

In [28]:
train_dataloader = data.train_dataloader()
one_batch = next(iter(train_dataloader))



KeyboardInterrupt: 

In [26]:
# run this in terminal: tensorboard --logdir lightning_logs

In [27]:
start = datetime.now()
print('Training started at', start)
trainer.fit(model=model, datamodule=data)
print('Training duration:', datetime.now() - start)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type    | Params
--------------------------------------
0 | net       | UNet    | 121 K 
1 | criterion | MSELoss | 0     
--------------------------------------
121 K     Trainable params
0         Non-trainable params
121 K     Total params
0.487     Total estimated model params size (MB)


Training started at 2021-12-01 15:39:21.135394
Validation sanity check:   0%|          | 0/1 [00:00<?, ?it/s]

KeyboardInterrupt: 