# Neuro UNet/MeshNet Tutorial

Authors: Kevin Wang, Alex Fedorov, [Sergey Kolesnikov](https://github.com/Scitator)

[![Catalyst logo](https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master/pics/catalyst_logo.png)](https://github.com/catalyst-team/catalyst)

### Colab setup

First of all, do not forget to change the runtime type to GPU.
To do so, click `Runtime` -> `Change runtime type` -> select `Python 3` and `GPU` -> click `Save`.
After that, you can click `Runtime` -> `Run all` and watch the tutorial.

## Requirements

Download and install the latest versions of catalyst and other libraries required for this tutorial.

In [1]:
from typing import Callable, List, Tuple

import os
import torch
import catalyst
from catalyst import utils

print(f"torch: {torch.__version__}, catalyst: {catalyst.__version__}")

# os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # "" - CPU, "0" - 1 GPU, "0,1" - MultiGPU

SEED = 42
utils.set_global_seed(SEED)
utils.prepare_cudnn(deterministic=True)

torch: 1.4.0+cu100, catalyst: 20.10.1


# Dataset

We'll use the Mindboggle-101 dataset for a multiclass 3D segmentation task.
After registering with OSF, you can download the dataset with the following `osfclient` command:

`osf -p 9ahyp clone .`

Otherwise, you can download a copy from the Catalyst Google Drive with the Catalyst utility `download-gdrive`:

`usage: download-gdrive {FILE_ID} {FILENAME}`

In [7]:
mkdir Mindboggle_data 

In [19]:
%%bash 

osf -p 9ahyp clone Mindboggle_data/


Copy the volumes into `../data/Mindboggle_101/`, the directory the preparation script reads from, and extract the archives there.

In [None]:
%%bash

mkdir -p ../data/Mindboggle_101
cp -r Mindboggle_data/osfstorage/Mindboggle101_volumes/. ../data/Mindboggle_101/
find ../data/Mindboggle_101 -name '*.tar.gz' | xargs -I {} tar zxvf {} -C ../data/Mindboggle_101
find ../data/Mindboggle_101 -name '*.tar.gz' -delete

Run the data preparation script, which limits the segmentation to 30 labels.

`usage: python ../neuro/scripts/prepare_data.py ../data/Mindboggle_101 {N_labels}`
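The internals of `prepare_data.py` are not shown in this tutorial. One common way to restrict a multi-label atlas to a fixed set of classes, sketched here under that assumption, is to map each kept label id to a contiguous index and send every other voxel to background `0`, because `CrossEntropyLoss` expects class indices in `0..C-1`. The function name `remap_labels` and the example label ids are hypothetical.

```python
from typing import Sequence

import numpy as np


def remap_labels(mask: np.ndarray, keep_labels: Sequence[int]) -> np.ndarray:
    """Hypothetical helper: map keep_labels[i] -> i + 1, everything else -> 0."""
    out = np.zeros(mask.shape, dtype=np.int64)
    for new_id, old_id in enumerate(keep_labels, start=1):
        out[mask == old_id] = new_id
    return out


# Tiny example with arbitrary label ids
mask = np.array([[0, 1002, 2002], [1002, 24, 2002]])
print(remap_labels(mask, [1002, 2002]))
# [[0 1 2]
#  [1 0 2]]
```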

In [None]:
%%bash 

python ../neuro/scripts/prepare_data.py ../data/Mindboggle_101/

Import the Catalyst and PyTorch utilities used for training.

In [2]:
import torch
import collections

from multiprocessing import Manager
from catalyst.contrib.utils.pandas import read_csv_data
from torch.utils.data import RandomSampler
from torch.utils.data import DataLoader
from torchvision import transforms
from catalyst.data import Augmentor, ReaderCompose

Next we import `BrainDataset`, which reads T1 scans and their label volumes and samples 38x38x38-voxel patches from them: random patches for training and non-overlapping patches for validation. More details can be found in `brain_dataset.py` and `generator_coords.py`.
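A minimal sketch of the two sampling modes, assuming training corners are drawn uniformly and validation patches are tiled from the origin; the function names are hypothetical and this is not the code in `generator_coords.py`. With a 256³ volume and 38³ patches, `256 // 38 = 6` patches fit per axis, so this tiling yields 216 non-overlapping patches and leaves a 28-voxel border uncovered along each axis.

```python
import itertools
from typing import List, Sequence, Tuple

import numpy as np


def random_patch_corner(volume_shape: Sequence[int], patch_shape: Sequence[int],
                        rng: np.random.Generator) -> Tuple[int, ...]:
    """Uniformly random corner such that the whole patch lies inside the volume."""
    return tuple(int(rng.integers(0, v - p + 1)) for v, p in zip(volume_shape, patch_shape))


def grid_patch_corners(volume_shape: Sequence[int],
                       patch_shape: Sequence[int]) -> List[Tuple[int, ...]]:
    """Corners of non-overlapping patches tiled from the origin.

    The last ``v % p`` voxels along each axis are not covered in this sketch.
    """
    axes = [range(0, v - p + 1, p) for v, p in zip(volume_shape, patch_shape)]
    return list(itertools.product(*axes))


def extract_patch(volume: np.ndarray, corner: Sequence[int],
                  patch_shape: Sequence[int]) -> np.ndarray:
    """Crop a patch of ``patch_shape`` starting at ``corner``."""
    return volume[tuple(slice(c, c + p) for c, p in zip(corner, patch_shape))]


volume = np.zeros((256, 256, 256), dtype=np.uint8)  # stand-in for a T1 scan
patch_shape = (38, 38, 38)

corners = grid_patch_corners(volume.shape, patch_shape)
print(len(corners))  # 216 == 6 ** 3

rng = np.random.default_rng(0)
corner = random_patch_corner(volume.shape, patch_shape, rng)
print(extract_patch(volume, corner, patch_shape).shape)  # (38, 38, 38)
```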

In [3]:
from brain_dataset import BrainDataset
from reader import NiftiReader_Image, NiftiReader_Mask

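The transforms for `BrainDataset` are simple.

They convert the T1 scans from NumPy arrays to float32 PyTorch tensors and convert the corresponding label masks to tensors with the same dtype as the NumPy array (`torch.from_numpy` does not cast), so the mask reader must produce integer labels that `CrossEntropyLoss` accepts as class indices.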
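The dtype behaviour of these two conversions can be checked on small stand-in arrays. The arrays below are hypothetical (the int16 mask in particular is only an example of a non-int64 label dtype); the point is that `.float()` casts while plain `torch.from_numpy` keeps the NumPy dtype and shares memory with the array.

```python
import numpy as np
import torch

# Hypothetical patches standing in for what the NIfTI readers return
image = np.random.rand(38, 38, 38)  # float64 intensities
mask = np.random.randint(0, 30, size=(38, 38, 38)).astype(np.int16)

image_t = torch.from_numpy(image).float()  # cast to a new float32 tensor
mask_t = torch.from_numpy(mask)            # stays int16, shares memory with `mask`
mask_long = mask_t.long()                  # int64 class indices for CrossEntropyLoss

print(image_t.dtype, mask_t.dtype, mask_long.dtype)
# torch.float32 torch.int16 torch.int64

mask[0, 0, 0] = 7
print(mask_t[0, 0, 0].item())  # 7, because from_numpy does not copy
```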

In [4]:
def get_transforms(self, stage: str = None, mode: str = None):
    """Build the per-sample transforms for BrainDataset.

    Args:
        self: unused; kept so calls like get_transforms(None, mode="train") still work
        stage (str): experiment stage (unused)
        mode (str): "train" or "valid"; both modes use the same transforms
    """
    if mode not in ("train", "valid"):
        raise ValueError(f"Unknown mode: {mode!r}")
    # T1 scan: numpy array -> float32 tensor
    image_to_tensor = Augmentor(
        dict_key="images",
        augment_fn=lambda x: torch.from_numpy(x).float(),
    )
    # Label mask: numpy array -> tensor with the same dtype (no cast)
    mask_to_tensor = Augmentor(
        dict_key="targets",
        augment_fn=lambda x: torch.from_numpy(x),
    )
    return transforms.Compose([image_to_tensor, mask_to_tensor])

In [5]:
# Read the T1 volume into "images" and the label volume into "targets"
open_fn = ReaderCompose(
    readers=[
        NiftiReader_Image(input_key="images", output_key="images"),
        NiftiReader_Mask(input_key="nii_labels", output_key="targets"),
    ]
)

In [6]:
def get_loaders(
    random_state: int,
    volume_shape: List[int],
    subvolume_shape: List[int],
    in_csv_train: str = None,
    in_csv_valid: str = None,
    in_csv_infer: str = None,
    batch_size: int = 16,
    num_workers: int = 10,
) -> dict:
    """Build train/valid loaders of 3D patches (random for train, grid for valid).

    ``random_state`` is accepted for compatibility but is not used here.
    """
    df, df_train, df_valid, df_infer = read_csv_data(
        in_csv_train=in_csv_train,
        in_csv_valid=in_csv_valid,
        in_csv_infer=in_csv_infer,
    )

    def make_dataset(list_data, mode: str) -> BrainDataset:
        return BrainDataset(
            shared_dict={},
            list_data=list_data,
            list_shape=volume_shape,
            list_sub_shape=subvolume_shape,
            open_fn=open_fn,
            dict_transform=get_transforms(None, mode=mode),
            n_samples=100,
            mode=mode,
            input_key="images",
            output_key="targets",
        )

    train_dataset = make_dataset(df_train, "train")
    valid_dataset = make_dataset(df_valid, "valid")

    # 80 * 128 patches per training epoch, drawn with replacement
    train_sampler = RandomSampler(train_dataset, replacement=True, num_samples=80 * 128)
    # 20 * 216 patches per validation epoch (216 == (256 // 38) ** 3)
    valid_sampler = RandomSampler(valid_dataset, replacement=True, num_samples=20 * 216)

    loaders = collections.OrderedDict()
    loaders["train"] = DataLoader(
        train_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=num_workers, pin_memory=True,
    )
    loaders["valid"] = DataLoader(
        valid_dataset, batch_size=batch_size, sampler=valid_sampler,
        num_workers=num_workers, pin_memory=True, drop_last=True,
    )
    return loaders

In [7]:
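if __name__ == "__main__":
    loaders = get_loaders(
        random_state=0,
        volume_shape=[256, 256, 256],
        subvolume_shape=[38, 38, 38],
        in_csv_train="../data/dataset_train.csv",
        in_csv_valid="../data/dataset_valid.csv",
        in_csv_infer="../data/dataset_infer.csv",
    )
    train_dataloader = loaders["train"]
    # Fetch one batch to check that reading and patch sampling work
    batch = next(iter(train_dataloader))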
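A quick consistency check on the sampler sizes: `RandomSampler` with `replacement=True` draws exactly `num_samples` indices per epoch, so with `batch_size=16` the training loader runs 640 iterations and the validation loader 270, matching the progress bars in the training log. The factor 216 equals `(256 // 38) ** 3`, the number of non-overlapping 38³ patches that fit in a 256³ volume; reading 80 and 20 as the numbers of training and validation volumes is an assumption.

```python
import math

batch_size = 16
train_samples = 80 * 128
valid_samples = 20 * 216

train_iters = math.ceil(train_samples / batch_size)  # drop_last=False: partial batch kept
valid_iters = valid_samples // batch_size             # drop_last=True: partial batch dropped
print(train_iters, valid_iters)  # 640 270

patches_per_axis = 256 // 38
print(patches_per_axis ** 3)  # 216
```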

# Model

We'll train the UNet defined in `model.py`.

In [8]:
from model import UNet

unet = UNet(n_channels=1, n_classes=30)

# Model Training

We'll train the model for 30 epochs.

We use the Adam optimizer (weight decay 1e-4) with a cosine annealing schedule starting from a learning rate of 0.01.
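To see what this schedule does, here is a small standalone sketch that steps PyTorch's `CosineAnnealingLR` on a dummy parameter and compares the learning rates with the closed form η_t = η₀(1 + cos(πt / T_max)) / 2 (with the default `eta_min=0`). Stepping once per epoch is an assumption about how the scheduler is driven during training.

```python
import math

import torch
from torch.optim.lr_scheduler import CosineAnnealingLR

base_lr, t_max = 0.01, 30
params = [torch.nn.Parameter(torch.zeros(1))]  # dummy parameter, no real model
optimizer = torch.optim.Adam(params, lr=base_lr, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=t_max)

lrs = []
for epoch in range(t_max):
    lrs.append(optimizer.param_groups[0]["lr"])  # LR used during this epoch
    optimizer.step()   # training for one epoch would happen here
    scheduler.step()   # one scheduler step per epoch

# Closed-form cosine annealing with eta_min = 0
expected = [base_lr * (1 + math.cos(math.pi * t / t_max)) / 2 for t in range(t_max)]
print(f"{lrs[0]:.4f} {lrs[15]:.4f} {lrs[-1]:.6f}")
```

Because `T_max` equals the number of epochs, the learning rate decays from 0.01 toward zero over the run rather than restarting.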

The criterion (loss function) to be minimized is `CrossEntropyLoss`.
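For volumetric segmentation, `CrossEntropyLoss` takes raw logits of shape `(N, C, D, H, W)` and int64 class indices of shape `(N, D, H, W)`, applies log-softmax over the class axis `dim=1`, and averages over every voxel. The sketch below uses random tensors (the shapes mirror this tutorial's 30 classes and 38³ patches) and checks the loss against that definition; `dim=1` is also the axis to pass explicitly to any `softmax` call on these tensors, since omitting it triggers the implicit-dimension warning seen in the training log.

```python
import torch
from torch import nn

torch.manual_seed(0)
n_classes, batch_size, patch = 30, 2, 38

# Logits as a segmentation network would produce them: (N, C, D, H, W)
logits = torch.randn(batch_size, n_classes, patch, patch, patch)
# One integer class index per voxel: (N, D, H, W), dtype int64
targets = torch.randint(0, n_classes, (batch_size, patch, patch, patch))

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, targets)  # scalar, averaged over all voxels

# Same quantity by hand: log-softmax over classes, pick each voxel's target class
manual = -torch.log_softmax(logits, dim=1).gather(1, targets.unsqueeze(1)).mean()
print(loss.shape, torch.allclose(loss, manual, atol=1e-5))  # torch.Size([]) True
```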

In [None]:
from torch.nn import CrossEntropyLoss
from torch.optim.lr_scheduler import CosineAnnealingLR
from catalyst.dl import SupervisedRunner
from catalyst.callbacks.logging import TensorboardLogger
from catalyst.callbacks import SchedulerCallback, CheckpointCallback
from custom_metrics import CustomDiceCallback

num_epochs = 30
logdir = "logs/unet"

optimizer = torch.optim.Adam(unet.parameters(), lr=0.01, weight_decay=0.0001)
# Cosine annealing over the full run (T_max = number of epochs)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

# The loaders yield batches with "images" and "targets" keys
runner = SupervisedRunner(input_key="images", input_target_key="targets", output_key="logits")

callbacks = [
    TensorboardLogger(),
    SchedulerCallback(reduced_metric="loss"),
    CustomDiceCallback(),
    CheckpointCallback(),
]

runner.train(
    model=unet,
    criterion=CrossEntropyLoss(),
    optimizer=optimizer,
    scheduler=scheduler,
    loaders=loaders,
    callbacks=callbacks,
    logdir=logdir,
    num_epochs=num_epochs,
    verbose=True,
)

Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.
1/30 * Epoch (train): 100% 640/640 [21:51<00:00,  2.05s/it, dice=0.952, loss=0.050]   
1/30 * Epoch (valid): 100% 270/270 [10:13<00:00,  2.27s/it, dice=0.970, loss=0.030]    
[2020-11-03 13:43:46,953] 
1/30 * Epoch 1 (_base): lr=0.0099 | momentum=0.9000
1/30 * Epoch 1 (train): dice=0.9164 | loss=0.3170
1/30 * Epoch 1 (valid): dice=0.9828 | loss=0.0175
2/30 * Epoch (train): 100% 640/640 [22:36<00:00,  2.12s/it, dice=0.968, loss=0.034]
2/30 * Epoch (valid): 100% 270/270 [10:22<00:00,  2.31s/it, dice=0.996, loss=0.004]
[2020-11-03 14:16:47,057] 
2/30 * Epoch 2 (_base): lr=0.0098 | momentum=0.9000
2/30 * Epoch 2 (train): dice=0.9314 | loss=0.2609
2/30 * Epoch 2 (valid): dice=0.9854 | loss=0.0161
3/30 * Epoch (train): 100% 640/640 [22:30<00:00,  2.11s/it, dice=0.984, loss=0.016]
3/30 * Epoch (valid): 100% 270/270 [10:12<00:00,  2.27s/it, dice=0.993, loss=0.007]
[2020-11-03 14:49:32,799] 
3/30 * Epoch 3 (_base): lr=0.0096 | momentum=0.9000
3/30 * Epoch 3 (train): dice=0.9407 | loss=0.2247
3/