In [27]:
# Config
# Global RNG seed, for reproducibility.
seed = 42
# Share of subjects assigned to the training split.
training_split_ratio = 0.75
num_epochs = 5

# Compute-or-download switches: a flag left at False means the corresponding
# result (histogram landmarks, whole-image model, patch-based model) is
# downloaded instead of being computed in this notebook.
compute_histograms = train_whole_images = train_patches = False

Install the following via the terminal:
```bash
sudo apt install tree
```

Install some PyPI packages:

In [28]:
!pip install --quiet --upgrade pip
!pip install --quiet unet==0.7.7
!pip install --quiet torchio==0.18.33

Install PyTorch following the recommendations at https://pytorch.org/

For me, this was running the following command:

```bash
pip install torch==1.9.0+cu111 torchvision==0.10.0+cu111 torchaudio==0.9.0 -f https://download.pytorch.org/whl/torch_stable.html
```

Import packages

In [29]:
import enum
import os
import time
import random
import multiprocessing
from pathlib import Path
import copy
from datetime import datetime

import torch
import torchvision
import torchio as tio
import torch.nn.functional as F
import pytorch_lightning as pl
from torch.utils.data import random_split, DataLoader
import torch.nn
import monai

import numpy as np
from unet import UNet
from scipy import stats
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()

from IPython import display
from tqdm.notebook import tqdm

random.seed(seed)
torch.manual_seed(seed)
%config InlineBackend.figure_format = 'retina'
num_workers = multiprocessing.cpu_count()
plt.rcParams['figure.figsize'] = 12, 6

print('TorchIO version:', tio.__version__)

TorchIO version: 0.18.33


# Data
## Viewing dataset files
We will use the tree program to view our file structure

In [30]:
# Disabled cell: uncomment to print the directory layout (directories only, `-d`)
# of the dataset folder using the `tree` command-line tool.
# NOTE: `dataset_dir_name` is also needed by the dataset-building cell below, so
# uncomment the assignment even if you skip the `tree` call.
# dataset_dir_name = 'dataset'

# !tree -d {dataset_dir_name}

## Making the subjects dataset
`torchio.SubjectsDataset` is a TorchIO dataset class that lets you easily modify the subjects with transforms (e.g. to generate new training data via data augmentation) and load them efficiently into the model with a `DataLoader`.

It receives as input a list of torchio.Subject instances and an optional torchio.transforms.Transform.

The inputs to the subject class are instances of torchio.Image, such as torchio.ScalarImage (for scalars) or torchio.LabelMap (for categories). The image class will be used by the transforms to decide whether or not to perform the operation. For example, spatial transforms must apply to both, but intensity transforms must apply to scalar images only.

In [31]:
# Disabled cell: builds `dataset`, a tio.SubjectsDataset whose subjects hold an
# image plus `_corneas` and `_rhabdoms` label maps. An image is kept when some
# label file in `label_dir` has the same stem after the `_corneas`/`_rhabdoms`
# suffix is stripped.
# NOTE(review): that match only proves that ONE of the two label files exists,
# yet each subject below opens both `<name>_corneas.nii` and
# `<name>_rhabdoms.nii`; a subject with only one label would presumably fail to
# load. TODO intersect the corneas and rhabdoms name sets separately.
# `images` (every .nii in image_dir, paired or not) is reused by the landmark
# training cell further down.
# image_dir = f'{dataset_dir_name}/crab_images_10/'
# label_dir = f'{dataset_dir_name}/crab_labels_10/'

# # find all the .nii files
# images = []
# labels = []

# for file in os.listdir(image_dir):
#     if file.endswith('.nii'):
#         images.append(image_dir + file)
        
# for file in os.listdir(label_dir):
#     if file.endswith('.nii'):
#         labels.append(label_dir + file)


# # find the matching pairs by their filename
# images_p = []
# for img in images:
#     images_p.append(Path(img).stem)
    
# labels_p = []
# for label in labels:
#     name = (Path(label).stem).replace('_corneas', '').replace('_rhabdoms', '')
#     labels_p.append(name)
    
# filenames = sorted(list(set(images_p) & set(labels_p)))

# print(f'Found {len(filenames)} labelled images for analysis')

# # now add them to a list of subjects
# subjects_list = []

# for filename in filenames:
#     subject = tio.Subject(
#         image=tio.ScalarImage(image_dir + filename + '.nii', check_nans=True),
#         label_corneas=tio.Image(label_dir + filename + '_corneas.nii', type=tio.LABEL, check_nans=True),
#         label_rhabdoms=tio.Image(label_dir + filename + '_rhabdoms.nii', type=tio.LABEL, check_nans=True),
#         filename=filename
#     )
#     subjects_list.append(subject)

# # and finally create a SubjectsDataset
# dataset = tio.SubjectsDataset(subjects_list)
# print('Created a SubjectsDataset')
# print(f'Dataset size: {len(dataset)} subjects')

Let's take a look at one of the subjects in the dataset

In [32]:
# Disabled cell: plots the first subject of `dataset`, then prints the subject
# and its image and two label maps. Requires the dataset-building cell above.
# one_subject = dataset[0]
# one_subject.plot()
# print(one_subject)
# print(one_subject.image)
# print(one_subject.label_corneas)
# print(one_subject.label_rhabdoms)

# Explore transforms
### Intensity transforms
We will use the `HistogramStandardization` and `ZNormalization` transforms to standardize and normalize our image intensities.

The images were acquired on different micro-CT scanners and prepared with different fixation and staining methods. We will apply some normalization techniques so that their intensities are similarly distributed and lie within similar ranges.

#### Histogram Standardization
This method is an implementation of [New variants of a method of MRI scale standardization](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.204.102&rep=rep1&type=pdf).
The main problem that this method intends to fix is that absolute intensity values do not have a fixed meaning. Intensity values can be highly dependent on fixation, staining and scanning procedures.
This method aims to give the same tissue the same intensity values across transformed images. It does this by defining landmarks across all images and then deforming each image's histogram to match the trained mean histogram.

To remove the effect of the background (usually zeros in micro-CT scans) and only consider tissue in the foreground, we will use a masking function that keeps only voxels whose values are greater than the image mean.

Below, we check how well thresholding at the mean separates the foreground from the background.

In [33]:
# Disabled cell: visual check of the mean-threshold foreground mask. On a deep
# copy of the first subject (so `dataset` itself is untouched), every voxel above
# the image mean is set to the image maximum, and the copy is then plotted.
# image_thresholded = copy.deepcopy(dataset[0])
# data = image_thresholded.image.data
# data[data > data.float().mean()] = data.max()
# image_thresholded.plot()

It looks like it does quite a good job. Now let's calculate the landmarks for the foreground.

In [34]:
# Disabled cell: trains HistogramStandardization landmarks over every path in
# `images`, using tio.ZNormalization.mean as the masking function (voxels above
# the image mean), and saves them to landmarks.npy.
# NOTE(review): this cell does not consult `compute_histograms` from the config
# cell, whose comment says results are downloaded when that flag is False.
# TODO gate this training on the flag.
# landmarks = tio.HistogramStandardization.train(
#     images,
#     output_path='landmarks.npy',
#     masking_function=tio.ZNormalization.mean
# )
# np.set_printoptions(suppress=True, precision=3)
# print('\nTrained landmarks:', landmarks)

## Augmentation
We will use a variety of augmentation methods:
- RandomAnisotropy
- RandomBlur
- RandomNoise
- OneOf
    - RandomAffine
    - RandomElasticDeformation
- OneOf
    - RandomMotion
    - RandomGhosting (perhaps not applicable to micro-CT?)

In [39]:
# Disabled cell: opens napari in 3-D mode (ndisplay=3) to compare an
# intensity-transformed image with the original first subject.
# NOTE(review): `znormed` is not defined in any cell shown here, and napari is not
# installed by the pip cell. TODO define `znormed` and install napari before enabling.
# import napari
# viewer = napari.view_image(znormed.image.numpy(), name='transformed intensity image', ndisplay=3)
# viewer.add_image(dataset[0].image.numpy(), name='original image')