<a href="https://colab.research.google.com/github/ssktotoro/neuro/blob/tutorial_branch/Neuro_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neuro UNet/MeshNet Tutorial

Authors: Kevin Wang, Alex Fedorov, [Sergey Kolesnikov](https://github.com/Scitator)

[![Catalyst logo](https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master/pics/catalyst_logo.png)](https://github.com/catalyst-team/catalyst)

### Colab setup

First of all, do not forget to change the runtime type to GPU. <br/>
To do so, click `Runtime` -> `Change runtime type`, select `Python 3` and `GPU`, then click `Save`. <br/>
After that you can click `Runtime` -> `Run all` and watch the tutorial.
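
As a quick sanity check after switching the runtime, a minimal sketch like the one below reports which device PyTorch will use. The helper name `pick_device` is hypothetical and not part of the tutorial code.

```python
import torch


def pick_device() -> torch.device:
    """Return the CUDA device when a GPU is visible, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


device = pick_device()
print(f"Running on: {device}")
```

If this prints `cpu` on Colab, the runtime type was most likely not switched to GPU.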

## Requirements

Download and install the latest versions of catalyst and other libraries required for this tutorial.

In [2]:
%%bash 
git clone https://github.com/ssktotoro/neuro.git -b tutorial_branch
pip install -r neuro/requirements/requirements.txt




fatal: destination path 'neuro' already exists and is not an empty directory.


In [5]:
import collections
import os
from collections import OrderedDict
from typing import List

import nibabel as nib
import numpy as np
import pandas as pd
import torch
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR, OneCycleLR
from torch.utils.data import DataLoader, SequentialSampler
from tqdm import tqdm

import catalyst
from catalyst import metrics, utils
from catalyst.callbacks import CheckpointCallback
from catalyst.contrib.utils.pandas import dataframe_to_list
from catalyst.data import BatchPrefetchLoaderWrapper, ReaderCompose
from catalyst.dl import LRFinder, Runner
from catalyst.metrics.functional._segmentation import dice


print(f"torch: {torch.__version__}, catalyst: {catalyst.__version__}")

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # "" - CPU, "0" - 1 GPU, "0,1" - MultiGPU

SEED = 42
utils.set_global_seed(SEED)
utils.prepare_cudnn(deterministic=True)

torch: 1.8.1+cu101, catalyst: 21.03.1


# Dataset

We'll be using the Mindboggle 101 dataset for a multiclass 3D segmentation task.
After registering with OSF, you can download the dataset with the following `osfclient` command:

`osf -p 9ahyp clone .`

Alternatively, you can download it with the Catalyst utility `download-gdrive`, which fetches a copy from the Catalyst Google Drive.

`usage: download-gdrive {FILE_ID} {FILENAME}`

In [6]:
cd neuro

/content/neuro


In [7]:
%%bash
mkdir Mindboggle_data 
mkdir -p data/Mindboggle_101/
osf -p 9ahyp clone Mindboggle_data/
cp -r Mindboggle_data/osfstorage/Mindboggle101_volumes/ data/Mindboggle_101/
find data/Mindboggle_101 -name '*.tar.gz'| xargs -i tar zxvf {} -C data/Mindboggle_101
find data/Mindboggle_101 -name '*.tar.gz'| xargs -i rm {}

MMRR-21_volumes/
MMRR-21_volumes/MMRR-21-12/
MMRR-21_volumes/MMRR-21-15/
MMRR-21_volumes/MMRR-21-14/
MMRR-21_volumes/MMRR-21-13/
MMRR-21_volumes/MMRR-21-7/
MMRR-21_volumes/MMRR-21-9/
MMRR-21_volumes/MMRR-21-8/
MMRR-21_volumes/MMRR-21-6/
MMRR-21_volumes/MMRR-21-1/
MMRR-21_volumes/MMRR-21-18/
MMRR-21_volumes/MMRR-21-20/
MMRR-21_volumes/MMRR-21-16/
MMRR-21_volumes/MMRR-21-11/
MMRR-21_volumes/MMRR-21-10/
MMRR-21_volumes/MMRR-21-17/
MMRR-21_volumes/MMRR-21-21/
MMRR-21_volumes/MMRR-21-19/
MMRR-21_volumes/MMRR-21-4/
MMRR-21_volumes/MMRR-21-3/
MMRR-21_volumes/MMRR-21-2/
MMRR-21_volumes/MMRR-21-5/
MMRR-21_volumes/MMRR-21-5/labels.DKT31.manual.nii.gz
MMRR-21_volumes/MMRR-21-5/t1weighted_brain.MNI152.nii.gz
MMRR-21_volumes/MMRR-21-5/t1weighted.MNI152.nii.gz
MMRR-21_volumes/MMRR-21-5/labels.DKT31.manual+aseg.nii.gz
MMRR-21_volumes/MMRR-21-5/labels.DKT31.manual+aseg.MNI152.nii.gz
MMRR-21_volumes/MMRR-21-5/labels.DKT31.manual.MNI152.nii.gz
MMRR-21_volumes/MMRR-21-5/t1weighted.nii.gz
MMRR-21_volumes/


Run the data preparation script, which limits the labels to the 31 DKT cortical labels. More labels can be used if desired.

`usage: python ../neuro/scripts/prepare_data.py ../data/Mindboggle_101 {N_labels}`
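
To illustrate the general idea of restricting a label volume to a subset of classes, here is a minimal NumPy sketch. It is not taken from `prepare_data.py`; the function `limit_labels`, the toy label IDs, and the choice to map everything else to 0 are all assumptions made for illustration.

```python
import numpy as np


def limit_labels(label_volume: np.ndarray, keep_ids) -> np.ndarray:
    """Map the IDs in `keep_ids` to 1..len(keep_ids); all other voxels become 0."""
    remapped = np.zeros_like(label_volume)
    for new_id, old_id in enumerate(keep_ids, start=1):
        remapped[label_volume == old_id] = new_id
    return remapped


# Toy 2x2x2 "volume" with hypothetical label IDs.
toy = np.array([[[0, 1002], [1003, 2002]], [[1002, 9999], [0, 1003]]])
print(limit_labels(toy, keep_ids=[1002, 1003]))
```

Remapping to a contiguous range keeps the class indices compatible with one-hot encoding and `CrossEntropyLoss`, which expect labels in `0..C-1`.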

In [8]:
%%bash 

python neuro/scripts/prepare_data.py data/Mindboggle_101/ 31

0
1
2
...
100


### **Create the relevant Dataloaders for the specified train, validation, and inference BrainDatasets.**

A BrainDataset consists of T1 scans plus the prepared, limited labels.

Training/validation batches: NxNxN subvolumes, with their corresponding labels, taken at random positions drawn from a normal distribution over the volume space.
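
The random sampling described above can be sketched roughly as follows. This is an illustration, not the code in `generator_coords.py`: the function names, the choice of mean and standard deviation, and the clipping step are assumptions.

```python
import numpy as np


def sample_subvolume_corners(volume_shape, subvolume_shape, n_subvolumes, rng):
    """Draw subvolume corner coordinates from a normal distribution centred in the volume."""
    volume_shape = np.asarray(volume_shape)
    subvolume_shape = np.asarray(subvolume_shape)
    max_corner = volume_shape - subvolume_shape  # largest valid start index per axis
    mean = max_corner / 2.0
    std = max_corner / 6.0  # assumption: most draws land inside the volume before clipping
    corners = rng.normal(mean, std, size=(n_subvolumes, 3))
    # Round to voxel indices and clip so every subvolume lies fully inside the volume.
    return np.clip(np.rint(corners), 0, max_corner).astype(int)


def crop(volume, corner, subvolume_shape):
    """Cut one subvolume out of a 3D array."""
    z, y, x = corner
    dz, dy, dx = subvolume_shape
    return volume[z:z + dz, y:y + dy, x:x + dx]


rng = np.random.default_rng(42)
volume = np.zeros((64, 64, 64), dtype=np.float32)
corners = sample_subvolume_corners(volume.shape, (38, 38, 38), n_subvolumes=4, rng=rng)
patches = [crop(volume, c, (38, 38, 38)) for c in corners]
print(corners.shape, patches[0].shape)
```

The same corner coordinates would be applied to the label volume so that each image patch keeps its matching targets.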

Inference batches: non-overlapping NxNxN subvolumes across the existing volume space, with their corresponding labels.
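
A non-overlapping grid of subvolume corners can be generated as in the sketch below. The function `grid_corners` is hypothetical, and the tutorial's own coordinate generator may handle volume borders differently; here any remainder that does not fit a full subvolume is simply left uncovered.

```python
import itertools

import numpy as np


def grid_corners(volume_shape, subvolume_shape):
    """Return corner coordinates of non-overlapping subvolumes that fit inside the volume."""
    starts_per_axis = [
        range(0, dim - sub + 1, sub)
        for dim, sub in zip(volume_shape, subvolume_shape)
    ]
    return np.array(list(itertools.product(*starts_per_axis)))


corners = grid_corners((256, 256, 256), (38, 38, 38))
print(corners.shape)
```

With a 256-voxel axis and 38-voxel subvolumes, six subvolumes fit per axis, so this particular sketch yields 6 x 6 x 6 corners.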

More detail can be found in `brain_dataset.py` and `generator_coords.py`.

In [9]:
cd training/

/content/neuro/training


In [10]:
from brain_dataset import BrainDataset                                                                                                                                                                    
from reader import NiftiFixedVolumeReader, NiftiReader                                                                                                                                                    
from model import MeshNet, UNet

In [19]:
def get_loaders(
    random_state: int,
    volume_shape: List[int],
    subvolume_shape: List[int],
    in_csv_train: str = None,
    in_csv_valid: str = None,
    in_csv_infer: str = None,
    batch_size: int = 16,
    num_workers: int = 2,
) -> tuple:
    """Build the train/valid and inference loaders from CSV file lists.

    `random_state` is accepted but not used inside this function.
    """
    datasets = {}
    # Read the T1 volume and the label volume listed in each CSV row.
    open_fn = ReaderCompose(
        [
            NiftiFixedVolumeReader(input_key="images", output_key="images"),
            NiftiReader(input_key="nii_labels", output_key="targets"),
        ]
    )

    for mode, source in zip(
        ("train", "validation", "infer"),
        (in_csv_train, in_csv_valid, in_csv_infer),
    ):
        # Inference uses more subvolumes per volume than training/validation.
        n_subvolumes = 512 if mode == "infer" else 128

        if source is not None and len(source) > 0:
            dataset = BrainDataset(
                list_data=dataframe_to_list(pd.read_csv(source)),
                list_shape=volume_shape,
                list_sub_shape=subvolume_shape,
                open_fn=open_fn,
                n_subvolumes=n_subvolumes,
                mode=mode,
                input_key="images",
                output_key="targets",
            )
            # Register the dataset only when its CSV was actually provided.
            datasets[mode] = {"dataset": dataset}

    def worker_init_fn(worker_id):
        # Offset the NumPy seed per worker so workers do not draw identical samples.
        np.random.seed(np.random.get_state()[1][0] + worker_id)

    train_loader = DataLoader(
        dataset=datasets["train"]["dataset"], batch_size=batch_size,
        shuffle=True, worker_init_fn=worker_init_fn,
        num_workers=num_workers, pin_memory=True,
    )
    valid_loader = DataLoader(
        dataset=datasets["validation"]["dataset"], batch_size=batch_size,
        shuffle=True, worker_init_fn=worker_init_fn,
        num_workers=num_workers, pin_memory=True, drop_last=True,
    )
    test_loader = DataLoader(
        dataset=datasets["infer"]["dataset"], batch_size=batch_size,
        worker_init_fn=worker_init_fn,
        num_workers=num_workers, pin_memory=True, drop_last=True,
    )

    train_loaders = collections.OrderedDict()
    infer_loaders = collections.OrderedDict()
    train_loaders["train"] = BatchPrefetchLoaderWrapper(train_loader)
    train_loaders["valid"] = BatchPrefetchLoaderWrapper(valid_loader)
    infer_loaders["infer"] = BatchPrefetchLoaderWrapper(test_loader)

    return train_loaders, infer_loaders

In [15]:
cd ../

/content/neuro


In [20]:
  volume_shape = [256, 256, 256]                                                                                                                                                                            
  subvolume_shape = [38, 38, 38]                                                                                                                                                                            
  train_loaders, infer_loaders = get_loaders(0, volume_shape, subvolume_shape,                                                                                                                              
                                             "./data/dataset_train.csv",                                                                                                                                    
                                             "./data/dataset_valid.csv",                                                                                                                                    
                                             "./data/dataset_infer.csv")                                                                                                                                   

# Model Training

We'll train the model for 5 epochs as a demonstration, although we typically train for 30 epochs.

* Optimizer and scheduler: Adam with a cosine annealing schedule, starting at a learning rate of 0.01
* Batch metric: Dice
* Loss: CrossEntropyLoss
* Logger: TensorBoard
* Checkpointing: `CheckpointCallback`
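
The optimizer and schedule listed above can be set up roughly as follows. This is a minimal sketch: the `Conv3d` stand-in model and `T_max=100` are assumptions chosen only to make the example self-contained, not values from the tutorial.

```python
import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

# Stand-in model; the tutorial trains MeshNet or UNet instead.
model = torch.nn.Conv3d(1, 31, kernel_size=3, padding=1)

optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
# T_max is the number of scheduler steps over which the learning rate decays;
# 100 is an arbitrary choice for this illustration.
scheduler = CosineAnnealingLR(optimizer, T_max=100)

lrs = []
for _ in range(100):
    optimizer.step()  # in real training this follows loss.backward()
    scheduler.step()
    lrs.append(scheduler.get_last_lr()[0])

print(f"first lr: {lrs[0]:.5f}, last lr: {lrs[-1]:.5f}")
```

When the scheduler is stepped once per batch, `T_max` is usually set to the total number of training batches so that the decay spans the whole run.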

For training and validation, we sample each volume using the subvolumes specified in our Dataset.
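
To make the Dice batch metric and the cross-entropy loss concrete, the sketch below computes both on random 3D data in plain PyTorch. It is an illustrative soft Dice, not Catalyst's `dice` implementation; the function name `macro_dice`, the `eps` smoothing term, and the tensor sizes are assumptions.

```python
import torch
import torch.nn.functional as F


def macro_dice(logits: torch.Tensor, targets: torch.Tensor, eps: float = 1e-7) -> torch.Tensor:
    """Soft Dice averaged over classes.

    logits:  (batch, classes, D, H, W) raw model outputs
    targets: (batch, D, H, W) integer class labels
    """
    num_classes = logits.shape[1]
    probs = F.softmax(logits, dim=1)
    one_hot = F.one_hot(targets, num_classes).permute(0, 4, 1, 2, 3).float()
    dims = (0, 2, 3, 4)  # sum over batch and spatial axes, keep the class axis
    intersection = (probs * one_hot).sum(dims)
    denominator = probs.sum(dims) + one_hot.sum(dims)
    per_class = (2 * intersection + eps) / (denominator + eps)
    return per_class.mean()


logits = torch.randn(2, 31, 8, 8, 8)
targets = torch.randint(0, 31, (2, 8, 8, 8))
loss = F.cross_entropy(logits, targets)
print(f"loss: {loss.item():.3f}, macro dice: {macro_dice(logits, targets).item():.3f}")
```

Note that `F.cross_entropy` takes the integer label volume directly, while the Dice computation needs one-hot targets with the class axis moved to position 1.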

In [17]:
 class CustomRunner(Runner):                                                                                                                                                                               
                                                                                                                                                                                                            
      def get_loaders(self, stage: str) -> "OrderedDict[str, DataLoader]":                                                                                                                                  
          """Returns the loaders for a given stage."""                                                                                                                                                      
          return self._loaders
                                                                                                                                                                                                            
      def predict_batch(self, batch):                                                                                                                                                                       
          # model inference step                                                                                                                                                                            
          batch = batch[0]                                                                                                                                                                                  
          return self.model(batch['images'].float().to(self.device)), batch['coords']                                                                                                                       
                                                                                                                                                                                                            
      def on_loader_start(self, runner):                                                                                                                                                                    
          super().on_loader_start(runner)                                                                                                                                                                   
          self.meters = {                                                                                                                                                                                   
              key: metrics.AdditiveValueMetric(compute_on_call=False)                                                                                                                                       
              for key in ["loss", "macro_dice"]                                                                                                                                                             
          }                                                                                                                                                                                                 
                                                                                                                                                                                                            
      def handle_batch(self, batch):                                                                                                                                                                        
                                                                                                                                                                                                            
          # model train/valid step                                                                                                                                                                          
          batch = batch[0]                                                                                                                                                                                  
          x, y = batch['images'].float(), batch['targets']                                                                                                                                                  
                                                                                                                                                                                                            
          if self.is_train_loader:                                                                                                                                                                          
              self.optimizer.zero_grad()                                                                                                                                                                    
                                                                                                                                                                                                            
          y_hat = self.model(x)                                                                                                                                                                             
          loss = F.cross_entropy(y_hat, y)                                                                                                                                                                  
                                                                                                                                                                                                            
          if self.is_train_loader:                                                                                                                                                                          
              loss.backward()                                                                                                                                                                               
              self.optimizer.step()                                                                                                                                                                         
              # `scheduler` is looked up as a global, so it must be defined before training starts.
              scheduler.step()
                                                                                                                                                                                                            
          one_hot_targets = (                                                                                                                                                                               
              torch.nn.functional.one_hot(y, 31)                                                                                                                                                            
              .permute(0, 4, 1, 2, 3)                                                                                                                                                                       
              .cuda()                                                                                                                                                                                       
              )                                                                                                                                                                                             
                                                                                                                                                                                                            
          logits_softmax = F.softmax(y_hat, dim=1)  # normalize over the class axis
          macro_dice = dice(logits_softmax, one_hot_targets, mode='macro')                                                                                                                                  
                                                                                                                                                                                                            
          self.batch_metrics.update({"loss": loss,                                                                                                                                                          
                                     'macro_dice': macro_dice})                                                                                                                                             
                                                                                                                                                                                                            
          for key in ["loss", "macro_dice"]:                                                                                                                                                                
              self.meters[key].update(self.batch_metrics[key].item(), self.batch_size)                                                                                                                      
                                                                                                                                                                                                            
      def on_loader_end(self, runner):                                                                                                                                                                      
          for key in ["loss", "macro_dice"]:                                                                                                                                                                
              self.loader_metrics[key] = self.meters[key].compute()[0]                                                                                                                                      
          super().on_loader_end(runner)                                                                                                                                                                     
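
The macro Dice metric logged by the runner above can be illustrated in plain PyTorch. The sketch below is not Catalyst's `dice` implementation; `macro_dice_sketch` and its `eps` smoothing term are hypothetical names chosen for illustration, and the toy shapes are assumptions.

```python
import torch
import torch.nn.functional as F


def macro_dice_sketch(logits, targets, n_classes, eps=1e-7):
    """Illustrative macro Dice: per-class Dice averaged over classes.

    logits:  float tensor of shape (batch, n_classes, D, H, W)
    targets: int64 tensor of shape (batch, D, H, W) holding class indices
    """
    probs = F.softmax(logits, dim=1)  # normalize over the class axis
    one_hot = F.one_hot(targets, n_classes).permute(0, 4, 1, 2, 3).float()

    # Sum over every axis except the class axis.
    reduce_dims = (0, 2, 3, 4)
    intersection = (probs * one_hot).sum(dim=reduce_dims)
    denominator = probs.sum(dim=reduce_dims) + one_hot.sum(dim=reduce_dims)

    per_class = (2 * intersection + eps) / (denominator + eps)
    return per_class.mean()


# Toy check: logits that strongly favour the correct class give a Dice close to 1.
targets = torch.arange(2 * 4 * 4 * 4).reshape(2, 4, 4, 4) % 3
logits = 20.0 * F.one_hot(targets, 3).permute(0, 4, 1, 2, 3).float()
score = macro_dice_sketch(logits, targets, n_classes=3)
```

Averaging per-class scores gives small structures the same weight as large ones, which is usually the reason to prefer a macro average for multi-class segmentation.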
                                           

In [None]:
  n_classes = 31                                                                                                                                                                                            
  n_epochs = 30                                                                                                                                                                                            
  meshnet = MeshNet(n_channels=1, n_classes=n_classes)                                                                                                                                                      
                                                                                                                                                                                                            
  logdir = "logs/meshnet_mindboggle"                                                                                                                                                                        
                                                                                                                                                                                                            
  optimizer = torch.optim.Adam(meshnet.parameters(), lr=0.02)                                                                                                                                               
                                                                                                                                                                                                            
                                                                                                                                                                                                            
  scheduler = OneCycleLR(optimizer, max_lr=.02,                                                                                                                                                             
                         epochs=n_epochs, steps_per_epoch=len(train_loaders['train']))                                                                                                                                                                                                                                                                                                                                                                                                                                                                                   
                                                                                                                                                                                                            
  runner = CustomRunner()                                                                                                                                                                                   
  runner.train(model=meshnet, optimizer=optimizer, loaders=train_loaders,                                                                                                                                   
               num_epochs=n_epochs, scheduler=scheduler,                                                                                                                                                    
               callbacks=[CheckpointCallback(logdir=logdir)], logdir=logdir, verbose=True)
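
Because `OneCycleLR` is built from `epochs` and `steps_per_epoch`, it expects exactly one `step()` per training batch. The following sketch shows the resulting learning-rate curve on a stand-in linear model; the toy sizes and the model are assumptions, not the tutorial's values.

```python
import torch
from torch.optim.lr_scheduler import OneCycleLR

n_epochs, steps_per_epoch, max_lr = 3, 10, 0.02  # toy values for illustration

model = torch.nn.Linear(4, 2)  # stand-in for the segmentation network
optimizer = torch.optim.Adam(model.parameters(), lr=max_lr)
scheduler = OneCycleLR(optimizer, max_lr=max_lr,
                       epochs=n_epochs, steps_per_epoch=steps_per_epoch)

lrs = []
for _ in range(n_epochs * steps_per_epoch):
    lrs.append(optimizer.param_groups[0]["lr"])  # rate used for this batch
    optimizer.zero_grad()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    loss.backward()
    optimizer.step()
    scheduler.step()  # one scheduler step per batch, not per epoch
```

The recorded rates start well below `max_lr`, rise to it during the warm-up phase, and then anneal to a value far below the starting rate.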



# Model Evaluation

For each brain volume, the label of every voxel is decided by a majority vote over all subvolume predictions that cover it, and the resulting segmentation is used to compute a Dice score.
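
A minimal NumPy sketch of voxel-wise majority voting, assuming a toy 4x4x4 volume and hand-written subvolume predictions (the `predictions` list and its values are hypothetical):

```python
import numpy as np

n_classes = 3
volume_shape = (4, 4, 4)  # toy volume; real volumes are much larger

# One vote counter per class and voxel.
votes = np.zeros((n_classes,) + volume_shape, dtype=np.int64)

# Hypothetical subvolume predictions: (start corner, predicted class indices).
predictions = [
    ((0, 0, 0), np.full((2, 2, 2), 1)),
    ((0, 0, 0), np.full((2, 2, 2), 1)),
    ((0, 0, 0), np.full((2, 2, 2), 2)),
    ((1, 1, 1), np.full((2, 2, 2), 2)),
]

for (x, y, z), labels in predictions:
    dx, dy, dz = labels.shape
    for c in range(n_classes):
        # Every voxel covered by this subvolume casts one vote for its class.
        votes[c, x:x + dx, y:y + dy, z:z + dz] += (labels == c)

# The class with the most votes wins. Ties go to the lowest class index,
# and voxels that received no votes default to class 0.
segmentation = votes.argmax(axis=0)
```

Here voxel `(0, 0, 0)` receives two votes for class 1 and one for class 2, so it is labelled 1, while voxel `(2, 2, 2)` is covered only by the last subvolume and is labelled 2.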


Each volume is first partitioned into a regular grid of subvolumes that together cover the whole volume, which guarantees at least one prediction for every voxel. Additional overlapping subvolumes are then sampled from the brain region until N subvolumes (512 in this case) are available for prediction.
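
The two-stage sampling can be sketched as follows. This is a simplified illustration, not the tutorial's coordinate generator: `grid_coords` and `random_coords` are hypothetical helpers, the 256-cubed volume and 64-cubed subvolume sizes are assumptions, and the random corners are drawn uniformly from the whole volume rather than from the brain region only.

```python
import numpy as np


def grid_coords(volume_shape, sub_shape):
    """Corners of non-overlapping subvolumes that tile the whole volume."""
    ranges = [range(0, v, s) for v, s in zip(volume_shape, sub_shape)]
    return [(x, y, z) for x in ranges[0] for y in ranges[1] for z in ranges[2]]


def random_coords(volume_shape, sub_shape, n, rng):
    """Corners of n randomly placed subvolumes (these may overlap)."""
    highs = [v - s + 1 for v, s in zip(volume_shape, sub_shape)]
    return [tuple(int(rng.integers(0, h)) for h in highs) for _ in range(n)]


volume_shape, sub_shape = (256, 256, 256), (64, 64, 64)  # assumed sizes
n_subvolumes = 512

rng = np.random.default_rng(0)
grid = grid_coords(volume_shape, sub_shape)  # 4 * 4 * 4 = 64 corners
coords = grid + random_coords(volume_shape, sub_shape,
                              n_subvolumes - len(grid), rng)
```

The grid alone already covers every voxel once; the random subvolumes add extra, overlapping votes on top of that.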

With only 70 training volumes (700+ is typical) and minimal augmentation, this approach reaches a mean Dice score of 0.6688 per brain.

In [None]:
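  def voxel_majority_predict_from_subvolumes(loader, n_classes, segmentations):
      # `segmentations` maps subject index -> vote counters of shape
      # (n_classes, *volume_shape). `runner` is the trained runner defined above.
      if segmentations is None:
          segmentations = {}
          for subject in range(loader.dataset.subjects):
              segmentations[subject] = torch.zeros(
                  tuple(np.insert(loader.volume_shape, 0, n_classes)),
                  dtype=torch.uint8).cpu()

      prediction_n = 0
      for inference in tqdm(runner.predict_loader(loader=loader)):
          coords = inference[1].cpu()
          # Most likely class for every voxel of each subvolume in the batch.
          _, predicted = torch.max(F.log_softmax(inference[0].cpu(), dim=1), 1)
          for j in range(predicted.shape[0]):
              c_j = coords[j][0]
              # Assumes the loader yields n_subvolumes consecutive items per subject.
              subj_id = prediction_n // loader.dataset.n_subvolumes
              # Add one vote for the predicted class at each covered voxel.
              for c in range(n_classes):
                  segmentations[subj_id][c, c_j[0, 0]:c_j[0, 1],
                                         c_j[1, 0]:c_j[1, 1],
                                         c_j[2, 0]:c_j[2, 1]] += (predicted[j] == c)
              prediction_n += 1

      # Replace the vote counters with the winning class per voxel.
      for i in segmentations.keys():
          segmentations[i] = torch.max(segmentations[i], 0)[1]
      return segmentations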

In [None]:
  segmentations = {}                                                                                                                                                                                        
  for subject in range(infer_loaders['infer'].dataset.subjects):                                                                                                                                            
      segmentations[subject] = torch.zeros(tuple(np.insert(volume_shape, 0, n_classes)), dtype=torch.uint8) 

In [None]:
  import pandas as pd

  segmentations = voxel_majority_predict_from_subvolumes(infer_loaders['infer'],
                                                         n_classes, segmentations)
  subject_metrics = []
  for subject, subject_data in enumerate(tqdm(infer_loaders['infer'].dataset.data)):
      # Ground-truth labels, one-hot encoded with the class axis moved to dim 1.
      seg_labels = nib.load(subject_data['nii_labels']).get_fdata()
      segmentation_labels = torch.nn.functional.one_hot(
          torch.from_numpy(seg_labels).to(torch.int64), n_classes).permute(0, 3, 1, 2)
      # Majority-vote prediction in the same layout.
      predicted_labels = torch.nn.functional.one_hot(
          segmentations[subject], n_classes).permute(0, 3, 1, 2)

      # Per-class Dice and its macro average for this subject.
      inference_dice = dice(predicted_labels, segmentation_labels).detach().numpy()
      macro_inference_dice = dice(
          predicted_labels, segmentation_labels, mode='macro').detach().numpy()
      subject_metrics.append((inference_dice, macro_inference_dice))

  per_class_df = pd.DataFrame([metric[0] for metric in subject_metrics])
  macro_df = pd.DataFrame([metric[1] for metric in subject_metrics])
  print(per_class_df, macro_df)
  print(macro_df.mean())
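
The `permute(0, 3, 1, 2)` calls in the evaluation cell are easier to follow on a toy tensor. `one_hot` appends the class axis last, so a `(D, H, W)` label volume becomes `(D, H, W, C)`; the permute moves the classes to the second axis, which appears to be the layout the Dice metric expects, with depth slices acting as the batch dimension. The shapes below are made up for illustration.

```python
import torch
import torch.nn.functional as F

n_classes = 3
# Toy (D, H, W) label volume with class indices 0..2.
labels = torch.arange(2 * 4 * 5).reshape(2, 4, 5) % n_classes

one_hot = F.one_hot(labels, n_classes)         # (D, H, W, C): classes come last
channels_second = one_hot.permute(0, 3, 1, 2)  # (D, C, H, W): classes on axis 1

# Taking the argmax over the class axis recovers the original labels.
recovered = channels_second.argmax(dim=1)
```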