<a href="https://colab.research.google.com/github/ssktotoro/neuro/blob/tutorial_branch/Neuro%20UNet%20Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Neuro UNet / MeshNet Tutorial

Authors: Kevin Wang, Alex Fedorov, [Sergey Kolesnikov](https://github.com/Scitator)

[![Catalyst logo](https://raw.githubusercontent.com/catalyst-team/catalyst-pics/master/pics/catalyst_logo.png)](https://github.com/catalyst-team/catalyst)

### Colab setup

First, change the runtime type to GPU. <br/>
To do so, click `Runtime` -> `Change runtime type`, select `Python 3` and `GPU`, then click `Save`. <br/>
After that you can click `Runtime` -> `Run all` and follow the tutorial.

## Requirements

Download and install the latest versions of catalyst and other libraries required for this tutorial.

In [1]:
%%bash 
git clone https://github.com/ssktotoro/neuro.git -b tutorial_branch
git -C neuro pull
pip install -r neuro/requirements/requirements.txt



In [2]:
from typing import Callable, List, Tuple

import os
import torch
import catalyst
from catalyst import utils

print(f"torch: {torch.__version__}, catalyst: {catalyst.__version__}")

os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # "" - CPU, "0" - 1 GPU, "0,1" - MultiGPU

SEED = 42
utils.set_global_seed(SEED)
utils.prepare_cudnn(deterministic=True)

torch: 1.8.0+cu101, catalyst: 20.10.1


# Dataset

We'll be using the Mindboggle 101 dataset for a multiclass 3D segmentation task.
Once you have registered with OSF, the dataset can be downloaded with the following `osfclient` command.

`osf -p 9ahyp clone .`

Otherwise, you can download it with the Catalyst utility `download-gdrive`, which fetches a copy from the Catalyst Google Drive.

`usage: download-gdrive {FILE_ID} {FILENAME}`

In [3]:
cd neuro

/content/neuro


In [4]:
%%bash
# -p avoids an error when the directories already exist (e.g. on a re-run)
mkdir -p Mindboggle_data data/Mindboggle_101/
osf -p 9ahyp clone Mindboggle_data/
cp -r Mindboggle_data/osfstorage/Mindboggle101_volumes/ data/Mindboggle_101/
# unpack every archive into data/Mindboggle_101, then delete the archives
find data/Mindboggle_101 -name '*.tar.gz' | xargs -I {} tar zxvf {} -C data/Mindboggle_101
find data/Mindboggle_101 -name '*.tar.gz' | xargs -I {} rm {}

NKI-TRT-20_volumes/
NKI-TRT-20_volumes/NKI-TRT-20-7/
NKI-TRT-20_volumes/NKI-TRT-20-9/
NKI-TRT-20_volumes/NKI-TRT-20-8/
NKI-TRT-20_volumes/NKI-TRT-20-6/
NKI-TRT-20_volumes/NKI-TRT-20-1/
NKI-TRT-20_volumes/NKI-TRT-20-20/
NKI-TRT-20_volumes/NKI-TRT-20-18/
NKI-TRT-20_volumes/NKI-TRT-20-11/
NKI-TRT-20_volumes/NKI-TRT-20-16/
NKI-TRT-20_volumes/NKI-TRT-20-17/
NKI-TRT-20_volumes/NKI-TRT-20-10/
NKI-TRT-20_volumes/NKI-TRT-20-19/
NKI-TRT-20_volumes/NKI-TRT-20-4/
NKI-TRT-20_volumes/NKI-TRT-20-3/
NKI-TRT-20_volumes/NKI-TRT-20-2/
NKI-TRT-20_volumes/NKI-TRT-20-5/
NKI-TRT-20_volumes/NKI-TRT-20-15/
NKI-TRT-20_volumes/NKI-TRT-20-12/
NKI-TRT-20_volumes/NKI-TRT-20-13/
NKI-TRT-20_volumes/NKI-TRT-20-14/
NKI-TRT-20_volumes/NKI-TRT-20-14/labels.DKT31.manual.nii.gz
NKI-TRT-20_volumes/NKI-TRT-20-14/t1weighted_brain.MNI152.nii.gz
NKI-TRT-20_volumes/NKI-TRT-20-14/t1weighted.MNI152.nii.gz
NKI-TRT-20_volumes/NKI-TRT-20-14/labels.DKT31.manual+aseg.nii.gz
NKI-TRT-20_volumes/NKI-TRT-20-14/labels.DKT31.manual+aseg.MNI1


Run the data preparation script, which limits the labels to the DKT human labels (60 labels).

`usage: python neuro/scripts/prepare_data.py data/Mindboggle_101/ {N_labels}`
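A segmentation model with 60 output classes needs targets that are class indices in a fixed, contiguous range. The sketch below illustrates one generic way to restrict a label volume to a chosen set of label IDs and renumber them. It is not the logic of `prepare_data.py`: the function names and the label IDs are hypothetical, and a flat list stands in for a 3D volume.

```python
from typing import Dict, Iterable, List


def build_label_map(kept_labels: Iterable[int]) -> Dict[int, int]:
    """Map each kept label ID to a contiguous class index starting at 1.

    Index 0 is reserved for everything that is not kept (assumption).
    """
    unique_sorted = sorted(set(kept_labels))
    return {label: index for index, label in enumerate(unique_sorted, start=1)}


def remap(voxels: List[int], label_map: Dict[int, int]) -> List[int]:
    """Replace every voxel label by its class index, or 0 if it is not kept."""
    return [label_map.get(voxel, 0) for voxel in voxels]


# Hypothetical label IDs, chosen only for illustration.
label_map = build_label_map([1002, 1003, 2002])
print(remap([0, 1002, 2002, 9999, 1003], label_map))  # [0, 1, 3, 0, 2]
```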

In [5]:
%%bash 

python neuro/scripts/prepare_data.py data/Mindboggle_101/ 60

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100


Import the Catalyst and PyTorch utilities used for training.

In [6]:
import collections

import torch
from torch.nn import functional as F
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
from torchvision import transforms
from torchvision.transforms import ToTensor

from catalyst.callbacks import CheckpointCallback, SchedulerCallback
from catalyst.callbacks.logging import TensorboardLogger
from catalyst.contrib.utils.pandas import read_csv_data
from catalyst.data import Augmentor, ReaderCompose
from catalyst.dl import SupervisedRunner

Here we import `BrainDataset`, which reads T1 scans and their labels and extracts 38x38x38 patches from them: randomly placed patches for training and non-overlapping patches for validation. More detail can be found in `brain_dataset.py` and `generator_coords.py`.
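To make the two sampling modes concrete, here is a minimal sketch of how patch coordinates could be generated for a 256x256x256 volume. It is an illustration only and does not reproduce `generator_coords.py`; the function names are hypothetical, and the grid version simply skips voxels that do not fit into a full patch.

```python
import itertools
import random
from typing import List, Sequence, Tuple

Coords = List[Tuple[int, int]]  # one (start, stop) pair per axis


def random_patch(volume_shape: Sequence[int], patch_shape: Sequence[int],
                 rng: random.Random) -> Coords:
    """Pick one randomly placed patch that lies fully inside the volume."""
    coords = []
    for size, patch in zip(volume_shape, patch_shape):
        start = rng.randint(0, size - patch)  # inclusive upper bound
        coords.append((start, start + patch))
    return coords


def grid_patches(volume_shape: Sequence[int],
                 patch_shape: Sequence[int]) -> List[Coords]:
    """Enumerate non-overlapping patches on a regular grid."""
    starts_per_axis = [
        range(0, size - patch + 1, patch)
        for size, patch in zip(volume_shape, patch_shape)
    ]
    return [
        [(start, start + patch) for start, patch in zip(starts, patch_shape)]
        for starts in itertools.product(*starts_per_axis)
    ]


rng = random.Random(42)
print(random_patch([256, 256, 256], [38, 38, 38], rng))
print(len(grid_patches([256, 256, 256], [38, 38, 38])))  # 6 * 6 * 6 = 216
```

Random placement gives the training loop a practically unlimited supply of distinct patches, while the fixed grid makes validation repeatable from epoch to epoch.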

In [8]:
cd training/

/content/neuro/training


In [10]:
from brain_dataset import BrainDataset
from reader import NiftiReader_Image, NiftiReader_Mask
from callbacks import CustomDiceCallback
from model import UNet, MeshNet

In [11]:
open_fn = ReaderCompose(                                                                                                                                                                            
    readers=[                                                                                                                                                                                       
        NiftiReader_Image(input_key="images", output_key="images"),                                                                                                                                 
        NiftiReader_Mask(input_key="nii_labels", output_key="targets"),
    ]
)

In [12]:

def get_loaders(
    random_state: int,
    volume_shape: List[int],
    subvolume_shape: List[int],
    in_csv_train: str = None,
    in_csv_valid: str = None,
    in_csv_infer: str = None,
    batch_size: int = 16,
    num_workers: int = 2,
) -> Tuple[collections.OrderedDict, collections.OrderedDict]:
    """Build the train/valid loaders and a separate inference loader."""
    # NOTE: random_state is accepted for interface compatibility but is not used.
    _, df_train, df_valid, df_infer = read_csv_data(
        in_csv_train=in_csv_train,
        in_csv_valid=in_csv_valid,
        in_csv_infer=in_csv_infer,
    )

    def make_dataset(list_data, mode: str) -> BrainDataset:
        return BrainDataset(
            shared_dict={},
            list_data=list_data,
            list_shape=volume_shape,
            list_sub_shape=subvolume_shape,
            open_fn=open_fn,
            mode=mode,
            input_key="images",
            output_key="targets",
        )

    train_dataset = make_dataset(df_train, mode="train")
    valid_dataset = make_dataset(df_valid, mode="valid")
    test_dataset = make_dataset(df_infer, mode="valid")

    # Training draws random subvolumes with replacement, 128 per dataset item.
    train_sampler = RandomSampler(
        data_source=train_dataset,
        replacement=True,
        num_samples=len(train_dataset) * 128,
    )
    # Validation and inference each walk their own dataset in order.
    valid_sampler = SequentialSampler(data_source=valid_dataset)
    test_sampler = SequentialSampler(data_source=test_dataset)

    train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size,
                              sampler=train_sampler,
                              num_workers=num_workers, pin_memory=True)
    valid_loader = DataLoader(dataset=valid_dataset, batch_size=batch_size,
                              sampler=valid_sampler,
                              num_workers=num_workers, pin_memory=True, drop_last=True)
    test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size,
                             sampler=test_sampler,
                             num_workers=num_workers, pin_memory=True, drop_last=True)

    train_loaders = collections.OrderedDict(train=train_loader, valid=valid_loader)
    infer_loaders = collections.OrderedDict(infer=test_loader)

    return train_loaders, infer_loaders

In [13]:
train_loaders, infer_loaders = get_loaders(
    random_state=0,
    volume_shape=[256, 256, 256],
    subvolume_shape=[38, 38, 38],
    in_csv_train="../data/dataset_train.csv",
    in_csv_valid="../data/dataset_valid.csv",
    in_csv_infer="../data/dataset_infer.csv",
)

# Model Training

We'll train the model for 30 epochs.

An Adam optimizer with a cosine annealing schedule, starting at a learning rate of 0.01, is used for this experiment.
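As a rough illustration of what the cosine annealing schedule does to the learning rate, the sketch below evaluates the closed-form expression given in the PyTorch documentation for `CosineAnnealingLR`. It assumes a minimum learning rate of 0 (the PyTorch default for `eta_min`) and a period of 30 epochs; the function name is hypothetical and no optimizer is involved.

```python
import math


def cosine_annealing_lr(epoch: int, base_lr: float = 0.01,
                        t_max: int = 30, eta_min: float = 0.0) -> float:
    """Closed-form cosine annealing: decay from base_lr to eta_min over t_max epochs."""
    cosine = math.cos(math.pi * epoch / t_max)
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + cosine)


for epoch in (0, 15, 30):
    print(epoch, round(cosine_annealing_lr(epoch), 6))
# 0 0.01
# 15 0.005
# 30 0.0
```

The rate falls slowly at first, fastest around the midpoint, and flattens out again near the end of training.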

`CrossEntropyLoss` is the criterion (loss function) to be minimized.
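For intuition about the loss values, the following sketch computes the cross-entropy for a single voxel in plain Python. It is a simplified stand-in, not the PyTorch implementation: the training loss additionally averages this quantity over all voxels and all samples in a batch, and the function name here is hypothetical.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Cross-entropy for one voxel: -log(softmax(logits)[target])."""
    largest = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = largest + math.log(sum(math.exp(z - largest) for z in logits))
    return log_sum_exp - logits[target]


# A model that scores all 60 classes equally is maximally uncertain.
print(round(cross_entropy([0.0] * 60, target=3), 3))  # 4.094, i.e. ln(60)
```

A loss near ln(60) ≈ 4.09 is therefore what an uninformed 60-class model would produce, which is a useful reference point when reading the first loss values in the training log.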

In [14]:
cd ..

/content/neuro


In [15]:
class CustomRunner(catalyst.dl.Runner):

    def predict_batch(self, batch):
        # model inference step; batches are dicts keyed by "images"/"targets"
        return self.model(batch["images"].to(self.device))

    def _handle_batch(self, batch):
        # model train/valid step
        x, y = batch["images"], batch["targets"]
        y_hat = self.model(x)

        # voxel-wise cross-entropy between predicted logits and label indices
        loss = F.cross_entropy(y_hat, y)
        self.batch_metrics.update({"loss": loss})

        # update the weights only on the training loader
        if self.is_train_loader:
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

In [None]:
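# Relax the deterministic cuDNN setting applied during setup.
torch.backends.cudnn.deterministic = False
meshnet = MeshNet(n_channels=1, n_classes=60)

num_epochs = 30
logdir = "logs/meshnet"

optimizer = torch.optim.Adam(meshnet.parameters(), lr=0.01, weight_decay=0.0001)
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)

callbacks = [
    TensorboardLogger(),
    SchedulerCallback(reduced_metric='loss'),
    CustomDiceCallback(),
    CheckpointCallback(),
]
runner = CustomRunner()
# `train_loaders` is the first value returned by get_loaders() above.
# The scheduler and callbacks are passed in so that they take effect.
runner.train(
    model=meshnet,
    optimizer=optimizer,
    scheduler=scheduler,
    callbacks=callbacks,
    loaders=train_loaders,
    num_epochs=num_epochs,
    logdir=logdir,
    verbose=True,
)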



1/30 * Epoch (train):   0% 0/640 [00:00<?, ?it/s]




1/30 * Epoch (train):   0% 0/640 [00:17<?, ?it/s, loss=4.239]
1/30 * Epoch (train):   0% 1/640 [00:17<3:08:03, 17.66s/it, loss=4.239]
1/30 * Epoch (train):   0% 1/640 [00:18<3:08:03, 17.66s/it, loss=4.035]
1/30 * Epoch (train):   0% 2/640 [00:18<2:12:37, 12.47s/it, loss=4.035]
1/30 * Epoch (train):   0% 2/640 [00:32<2:12:37, 12.47s/it, loss=3.795]
1/30 * Epoch (train):   0% 3/640 [00:32<2:17:44, 12.97s/it, loss=3.795]
1/30 * Epoch (train):   0% 3/640 [00:32<2:17:44, 12.97s/it, loss=3.530]
1/30 * Epoch (train):   1% 4/640 [00:32<1:37:29,  9.20s/it, loss=3.530]
1/30 * Epoch (train):   1% 4/640 [00:48<1:37:29,  9.20s/it, loss=3.357]
1/30 * Epoch (train):   1% 5/640 [00:48<1:57:11, 11.07s/it, loss=3.357]
1/30 * Epoch (train):   1% 5/640 [00:48<1:57:11, 11.07s/it, loss=3.137]
1/30 * Epoch (train):   1% 6/640 [00:48<1:23:06,  7.87s/it, loss=3.137]

1/30 * Epoch (train):   1% 6/640 [01:02<1:23:06,  7.87s/it, 