<a href="https://colab.research.google.com/github/ssktotoro/neuro/blob/tutorial_branch/Neuro%20UNet%20Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neuro UNet/MeshNet Tutorial

Authors: Kevin Wang, Alex Fedorov, [Sergey Kolesnikov](https://github.com/Scitator)

[![Catalyst logo](https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master/pics/catalyst_logo.png)](https://github.com/catalyst-team/catalyst)

### Colab setup

First of all, do not forget to change the runtime type to GPU. <br/>
To do so, click `Runtime` -> `Change runtime type` -> select `Python 3` and `GPU` -> click `Save`. <br/>
After that you can click `Runtime` -> `Run all` and watch the tutorial.

## Requirements

Clone the tutorial repository and install the pinned versions of Catalyst and the other libraries listed in `requirements.txt`.

In [1]:
%%bash
git clone https://github.com/ssktotoro/neuro.git -b tutorial_branch
# The clone lives in ./neuro, so pull from there (useful when re-running this cell)
git -C neuro pull
pip install -r neuro/requirements/requirements.txt


Cloning into 'neuro'...
ERROR: google-colab 1.0.0 has requirement requests~=2.23.0, but you'll have requests 2.22.0 which is incompatible.
ERROR: datascience 0.10.6 has requirement folium==0.2.1, but you'll have folium 0.8.3 which is incompatible.


In [2]:
from typing import Callable, List, Tuple
import numpy as np
import os
import torch
import catalyst
from catalyst import utils

print(f"torch: {torch.__version__}, catalyst: {catalyst.__version__}")

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # "" - CPU, "0" - 1 GPU, "0,1" - MultiGPU

SEED = 42
utils.set_global_seed(SEED)
utils.prepare_cudnn(deterministic=True)

torch: 1.8.0+cu101, catalyst: 20.10.1


# Dataset

We'll be using the Mindboggle 101 dataset for a multiclass 3D segmentation task.
After registering with OSF, you can download the dataset with the following `osfclient` command:

`osf -p 9ahyp clone .`

Alternatively, you can download a copy hosted on the Catalyst Google Drive with the Catalyst utility `download-gdrive`:

`usage: download-gdrive {FILE_ID} {FILENAME}`

In [3]:
cd neuro

/content/neuro


In [None]:
%%bash
mkdir -p Mindboggle_data data/Mindboggle_101/
osf -p 9ahyp clone Mindboggle_data/
cp -r Mindboggle_data/osfstorage/Mindboggle101_volumes/ data/Mindboggle_101/
# Extract every archive into data/Mindboggle_101, then delete the archives
find data/Mindboggle_101 -name '*.tar.gz' | xargs -I {} tar zxvf {} -C data/Mindboggle_101
find data/Mindboggle_101 -name '*.tar.gz' | xargs -I {} rm {}

Run the data preparation script, which limits the labels to the 60 DKT human labels.

`usage: python neuro/scripts/prepare_data.py data/Mindboggle_101 {N_labels}`

In [None]:
%%bash 

python neuro/scripts/prepare_data.py data/Mindboggle_101/ 60
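`prepare_data.py` itself is not shown here. As a rough illustration of what limiting labels involves, the sketch below remaps a label volume through a lookup table: IDs in a kept list become contiguous class indices, and every other ID falls back to class 0. The helper name `limit_labels`, the convention that the first kept ID is background, and the example label IDs are assumptions for illustration, not the script's actual code.

```python
import numpy as np


def limit_labels(label_volume: np.ndarray, kept_labels) -> np.ndarray:
    """Hypothetical helper: remap raw label IDs to contiguous class indices.

    kept_labels[i] -> class i; any ID not in kept_labels -> class 0.
    Assumes kept_labels[0] is the background ID and all IDs are non-negative ints.
    """
    kept = np.asarray(kept_labels, dtype=np.int64)
    size = int(max(label_volume.max(), kept.max())) + 1
    lut = np.zeros(size, dtype=np.int64)  # unknown IDs default to class 0
    lut[kept] = np.arange(len(kept))      # kept IDs -> 0..K-1
    return lut[label_volume]              # fancy indexing applies the table voxel-wise


# Toy volume with two illustrative cortical IDs and one ID outside the kept set
toy = np.array([[[0, 1002], [2002, 9999]]])
print(limit_labels(toy, [0, 1002, 2002]).tolist())  # [[[0, 1], [2, 0]]]
```

A lookup table keeps the remapping a single vectorized indexing operation, which matters for 256³ volumes.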

Import the Catalyst and PyTorch utilities used for training.

In [None]:
import collections

import torch
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

from catalyst.callbacks import CheckpointCallback, SchedulerCallback
from catalyst.callbacks.logging import TensorboardLogger
from catalyst.contrib.utils.pandas import read_csv_data
from catalyst.data import ReaderCompose

### Create the DataLoaders for the train, validation, and inference BrainDatasets

Each BrainDataset pairs T1 scans with the prepared, label-limited segmentation masks.

* Training/validation batches: randomly sampled NxNxN subvolumes, with locations drawn from a normal distribution over the volume, together with their corresponding labels.
* Inference batches: non-overlapping NxNxN subvolumes taken across the volume, together with their corresponding labels.

More details are in `brain_dataset.py` and `generator_coords.py`.
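The two sampling schemes can be sketched with NumPy. This is an illustration under stated assumptions, not the code in `generator_coords.py`: the function names are hypothetical, the normal distribution is assumed to be centered on the volume with a standard deviation of a quarter of each side, and sampled boxes are clipped so they stay inside the volume. Each box is stored as `[start, end)` per axis, the layout the inference code below slices with.

```python
import itertools

import numpy as np


def sample_subvolume_coords(volume_shape, subvolume_shape, n_samples, rng=None, std_frac=0.25):
    """Hypothetical sketch: boxes whose centers are drawn from a normal
    distribution centered on the volume (std = std_frac * side length, an assumption).

    Returns an int array of shape (n_samples, 3, 2) holding [start, end) per axis.
    """
    rng = np.random.default_rng() if rng is None else rng
    vol = np.asarray(volume_shape)
    sub = np.asarray(subvolume_shape)
    centers = rng.normal(loc=vol / 2, scale=std_frac * vol, size=(n_samples, 3))
    # Turn centers into start indices, clipped so each box lies inside the volume
    starts = np.clip(np.round(centers - sub / 2), 0, vol - sub).astype(int)
    return np.stack([starts, starts + sub], axis=-1)


def grid_subvolume_coords(volume_shape, subvolume_shape):
    """Hypothetical sketch: non-overlapping boxes in raster order.

    Voxels past the last full box along an axis are not covered by this version.
    """
    starts_per_axis = [range(0, v - s + 1, s) for v, s in zip(volume_shape, subvolume_shape)]
    boxes = [
        [[x, x + subvolume_shape[0]], [y, y + subvolume_shape[1]], [z, z + subvolume_shape[2]]]
        for x, y, z in itertools.product(*starts_per_axis)
    ]
    return np.array(boxes, dtype=int)


train_boxes = sample_subvolume_coords([256, 256, 256], [38, 38, 38], n_samples=4,
                                      rng=np.random.default_rng(0))
infer_boxes = grid_subvolume_coords([256, 256, 256], [38, 38, 38])
print(train_boxes.shape, infer_boxes.shape)  # (4, 3, 2) (216, 3, 2)
```

With the tutorial's shapes (256³ volumes, 38³ subvolumes), 38 does not divide 256, so this simple grid yields 6 boxes per axis (216 in total) and leaves the last 28 voxels along each axis uncovered; the real generator may handle the remainder differently.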

In [None]:
cd training/

In [None]:
from brain_dataset import BrainDataset
from reader import NiftiReader_Image, NiftiReader_Mask
from callbacks import CustomDiceCallback
from model import UNet, MeshNet
from custom_metrics import custom_dice_metric

In [None]:
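# Load each row's T1 volume ("images" column) and label volume
# ("nii_labels" column, stored under the "targets" key)
open_fn = ReaderCompose(
    readers=[
        NiftiReader_Image(input_key="images", output_key="images"),
        NiftiReader_Mask(input_key="nii_labels", output_key="targets"),
    ]
)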

In [None]:

from typing import List, Optional, Tuple


def get_loaders(
    random_state: int,
    volume_shape: List[int],
    subvolume_shape: List[int],
    in_csv_train: Optional[str] = None,
    in_csv_valid: Optional[str] = None,
    in_csv_infer: Optional[str] = None,
    batch_size: int = 16,
    num_workers: int = 2,
) -> Tuple[collections.OrderedDict, collections.OrderedDict]:
    """Build the train/valid loaders and the inference loader."""
    # read_csv_data returns the full dataframe plus per-split lists of records
    _, df_train, df_valid, df_infer = read_csv_data(
        in_csv_train=in_csv_train,
        in_csv_valid=in_csv_valid,
        in_csv_infer=in_csv_infer,
    )

    def make_dataset(records, mode):
        # "train"/"valid" sample random subvolumes, "infer" walks a
        # non-overlapping grid (see brain_dataset.py)
        return BrainDataset(
            shared_dict={},
            list_data=records,
            list_shape=volume_shape,
            list_sub_shape=subvolume_shape,
            open_fn=open_fn,
            mode=mode,
            input_key="images",
            output_key="targets",
        )

    train_dataset = make_dataset(df_train, "train")
    valid_dataset = make_dataset(df_valid, "valid")
    infer_dataset = make_dataset(df_infer, "infer")

    # Draw 16x the dataset length per epoch, with replacement; the generator
    # makes the sampling order reproducible from random_state
    generator = torch.Generator().manual_seed(random_state)
    train_sampler = RandomSampler(train_dataset, replacement=True,
                                  num_samples=len(train_dataset) * 16, generator=generator)
    valid_sampler = RandomSampler(valid_dataset, replacement=True,
                                  num_samples=len(valid_dataset) * 16, generator=generator)
    # Inference must visit every subvolume exactly once, in order
    infer_sampler = SequentialSampler(infer_dataset)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=train_sampler,
                              num_workers=num_workers, pin_memory=True)
    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, sampler=valid_sampler,
                              num_workers=num_workers, pin_memory=True, drop_last=True)
    # drop_last=False: dropping the final partial batch would leave holes
    # in the reassembled segmentation
    infer_loader = DataLoader(infer_dataset, batch_size=batch_size, sampler=infer_sampler,
                              num_workers=num_workers, pin_memory=True, drop_last=False)

    train_loaders = collections.OrderedDict()
    train_loaders["train"] = train_loader
    train_loaders["valid"] = valid_loader
    infer_loaders = collections.OrderedDict()
    infer_loaders["infer"] = infer_loader

    return train_loaders, infer_loaders
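A quick standalone check of what the samplers in `get_loaders` yield, using a toy four-item list in place of a BrainDataset: with `replacement=True` and `num_samples=len(dataset) * 16`, an epoch draws 16 times as many indices as there are items (repeats allowed), while `SequentialSampler` visits each index once, in order. The toy list and the seeded generator are illustrative choices.

```python
import torch
from torch.utils.data import RandomSampler, SequentialSampler

toy_items = list(range(4))  # stands in for a dataset of length 4
toy_generator = torch.Generator().manual_seed(0)  # reproducible draws

train_like = RandomSampler(toy_items, replacement=True,
                           num_samples=len(toy_items) * 16, generator=toy_generator)
train_indices = list(train_like)
infer_indices = list(SequentialSampler(toy_items))

print(len(train_indices))  # 64
print(infer_indices)       # [0, 1, 2, 3]
```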

In [None]:
train_loaders, infer_loaders = get_loaders(
    random_state=0,
    volume_shape=[256, 256, 256],
    subvolume_shape=[38, 38, 38],
    in_csv_train="../data/dataset_train.csv",
    in_csv_valid="../data/dataset_valid.csv",
    in_csv_infer="../data/dataset_infer.csv",
)

# Model Training

We'll train the model for 1 epoch for demonstration, although we typically train for 30 epochs.

* Optimizer: Adam with an initial learning rate of 0.01 and weight decay of 0.0001
* Scheduler: cosine annealing (`CosineAnnealingLR`, `T_max=30`)
* Batch metric: Dice
* Loss: cross-entropy
* Logger: TensorBoard
* Checkpointing: `CheckpointCallback`

For training and validation, the model is fed the randomly sampled subvolumes defined by our BrainDataset rather than whole volumes.
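The Dice metric listed above comes from `custom_dice_metric` in `custom_metrics.py`, which is not shown here. For reference, a common formulation is the soft multiclass Dice, Dice_c = 2·Σ(p·g) / (Σp + Σg) per class c, computed from softmax probabilities p and one-hot targets g. The sketch below is that generic formulation (the function name is hypothetical); the notebook's metric may differ, for example by using hard argmax labels or averaging classes differently.

```python
import torch
import torch.nn.functional as F


def soft_dice_per_class(logits: torch.Tensor, targets: torch.Tensor,
                        num_classes: int, eps: float = 1e-7) -> torch.Tensor:
    """Soft multiclass Dice computed from raw logits.

    logits: (B, C, D, H, W) float; targets: (B, D, H, W) int64 class labels.
    Returns a (C,) tensor with one Dice score per class.
    """
    probs = torch.softmax(logits, dim=1)
    # (B, D, H, W) -> (B, D, H, W, C) -> (B, C, D, H, W) to line up with probs
    one_hot = F.one_hot(targets, num_classes).permute(0, 4, 1, 2, 3).float()
    dims = (0, 2, 3, 4)  # sum over batch and spatial axes, keep the class axis
    intersection = (probs * one_hot).sum(dims)
    denominator = probs.sum(dims) + one_hot.sum(dims)
    return (2 * intersection + eps) / (denominator + eps)


# Confident, correct logits should give a Dice close to 1 for every class
toy_labels = torch.randint(0, 3, (2, 4, 4, 4))
toy_logits = F.one_hot(toy_labels, 3).permute(0, 4, 1, 2, 3).float() * 50
print(soft_dice_per_class(toy_logits, toy_labels, num_classes=3))
```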

In [None]:
cd ..

In [None]:
class CustomRunner(catalyst.dl.Runner):

    @torch.no_grad()
    def predict_batch(self, batch):
        # Inference step: move images to the model's device explicitly and pass
        # the subvolume coords through so the full volume can be reassembled
        self.model.eval()
        return self.model(batch["images"].to(self.device)), batch["coords"]

    def _handle_batch(self, batch):
        # Train/valid step
        x, y = batch["images"], batch["targets"]
        y_hat = self.model(x)

        loss = F.cross_entropy(y_hat, y)
        # Dice is only logged, so compute it on detached logits
        dice = custom_dice_metric(
            y_hat.detach().float(), y, num_classes=60, activation="Softmax"
        )
        self.batch_metrics.update({"loss": loss, "dice": dice})

        if self.is_train_loader:
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
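Stripped of Catalyst, one training iteration of `_handle_batch` amounts to the plain PyTorch step below. A single `Conv3d` layer and random tensors stand in for MeshNet and a real batch (both are assumptions for illustration); the point is the shapes: `F.cross_entropy` takes `(B, C, D, H, W)` logits and `(B, D, H, W)` integer labels and averages the loss over every voxel.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

n_toy_classes = 4
toy_model = nn.Conv3d(1, n_toy_classes, kernel_size=3, padding=1)  # stand-in for MeshNet
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.01)

toy_x = torch.randn(2, 1, 8, 8, 8)                     # (B, 1, D, H, W) images
toy_y = torch.randint(0, n_toy_classes, (2, 8, 8, 8))  # (B, D, H, W) labels

toy_model.train()
y_hat = toy_model(toy_x)              # (B, C, D, H, W) logits
loss = F.cross_entropy(y_hat, toy_y)  # mean over all B*D*H*W voxels
loss.backward()                       # same three calls as the is_train_loader branch
toy_optimizer.step()
toy_optimizer.zero_grad()

print(tuple(y_hat.shape), loss.item())
```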

In [None]:
n_classes = 60
# Override the deterministic cuDNN mode requested by utils.prepare_cudnn() above
torch.backends.cudnn.deterministic = False
meshnet = MeshNet(n_channels=1, n_classes=n_classes)

logdir = "logs/meshnet"

optimizer = torch.optim.Adam(meshnet.parameters(), lr=0.01, weight_decay=0.0001)
scheduler = CosineAnnealingLR(optimizer, T_max=30)

runner = CustomRunner()
# SchedulerCallback steps the cosine schedule each epoch; _handle_batch
# only steps the optimizer
runner.train(
    model=meshnet,
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=train_loaders,
    num_epochs=1,
    callbacks=[SchedulerCallback(), TensorboardLogger(), CheckpointCallback()],
    logdir=logdir,
    verbose=True,
)
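To see the learning rates `CosineAnnealingLR(optimizer, T_max=30)` would produce over a full 30-epoch run, the sketch below steps a separate scheduler on a dummy parameter (so the notebook's `optimizer` and `scheduler` are untouched) and compares the result with the closed form lr_t = 0.01 · (1 + cos(π·t / 30)) / 2, which holds because `eta_min` defaults to 0.

```python
import math

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

eta_max, t_max = 0.01, 30
dummy = torch.nn.Parameter(torch.zeros(1))
demo_optimizer = torch.optim.Adam([dummy], lr=eta_max)
demo_scheduler = CosineAnnealingLR(demo_optimizer, T_max=t_max)  # eta_min defaults to 0

lrs = []
for epoch in range(t_max):
    lrs.append(demo_optimizer.param_groups[0]["lr"])  # lr used during this epoch
    demo_optimizer.step()  # no-op here (no gradients); keeps the recommended call order
    demo_scheduler.step()

# Closed form with eta_min = 0: lr_t = eta_max * (1 + cos(pi * t / T_max)) / 2
closed_form = [eta_max * (1 + math.cos(math.pi * t / t_max)) / 2 for t in range(t_max)]
print(lrs[0], lrs[15], lrs[-1])
```

With `num_epochs=1`, training only ever uses the initial rate of 0.01; the decay matters for the usual 30-epoch run.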

In [None]:
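def voxel_majority_predict_from_subvolumes(loader, volume_shape, n_classes):
    """Reassemble whole-volume segmentations from subvolume predictions.

    Assumes the inference loader uses a SequentialSampler and the dataset yields
    all len(loader.dataset.coords) subvolumes of subject 0, then of subject 1,
    and so on. Uses the global `runner` trained above.
    """
    n_subvolumes = len(loader.dataset.coords)
    votes = {}  # subject id -> (n_classes, *volume_shape) uint8 counts, ~1 GB for 60 x 256^3
    sample_idx = 0  # position of the current subvolume in the sequential stream

    for logits, coords in runner.predict_loader(loader=loader):
        labels = logits.argmax(dim=1).cpu()  # (B, D, H, W) predicted class per voxel
        for j in range(labels.shape[0]):
            subj_id = sample_idx // n_subvolumes
            sample_idx += 1
            if subj_id not in votes:
                votes[subj_id] = torch.zeros((n_classes, *volume_shape), dtype=torch.uint8)
            # [start, end) per axis; reshape tolerates a leading singleton dim
            (x0, x1), (y0, y1), (z0, z1) = coords[j].reshape(3, 2).tolist()
            one_hot = F.one_hot(labels[j], n_classes).permute(3, 0, 1, 2).to(torch.uint8)
            votes[subj_id][:, x0:x1, y0:y1, z0:z1] += one_hot

    # Majority vote: the class with the most votes wins at each voxel
    return {subj: v.max(dim=0).indices for subj, v in votes.items()}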
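The voting logic can be checked in isolation on a toy volume: split an 8×8×8 label volume into non-overlapping 4×4×4 boxes, pretend each box was predicted perfectly, accumulate one-hot votes the way `voxel_majority_predict_from_subvolumes` does, and confirm that the argmax over classes recovers the original labels. The sizes and class count are arbitrary choices for illustration, and no model or dataset from the notebook is used.

```python
import torch
import torch.nn.functional as F

n_toy_classes, vol, sub = 5, 8, 4
truth = torch.randint(0, n_toy_classes, (vol, vol, vol))  # toy "ground truth" labels

votes = torch.zeros((n_toy_classes, vol, vol, vol), dtype=torch.uint8)
for x0 in range(0, vol, sub):
    for y0 in range(0, vol, sub):
        for z0 in range(0, vol, sub):
            box = (slice(x0, x0 + sub), slice(y0, y0 + sub), slice(z0, z0 + sub))
            pred = truth[box]  # pretend the model labelled this subvolume perfectly
            one_hot = F.one_hot(pred, n_toy_classes).permute(3, 0, 1, 2).to(torch.uint8)
            votes[(slice(None),) + box] += one_hot  # one vote per voxel per box

reassembled = votes.max(dim=0).indices  # majority vote along the class axis
print(torch.equal(reassembled, truth))  # True
```

With non-overlapping boxes every voxel receives exactly one vote, so the "majority" is simply that box's prediction; the vote counts only matter when subvolumes overlap.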


In [None]:
n_classes = 60
infer_loader = infer_loaders['infer']

segmentations = voxel_majority_predict_from_subvolumes(infer_loader, [256, 256, 256], n_classes)