# Train a Retinal Vessel Segmentation Model (with MPI parallelism)

In [None]:
!pip install ColossalAI deepspeed


Import dependencies

In [1]:
from datetime import datetime
import os
import os.path as osp

import torch
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torchvision.models as models
import colossalai
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
import argparse
from colossalai.trainer import Trainer

from dataloaders.vessel import RetinalVesselSegmentation
from dataloaders import custom_transforms as tr
from networks.unet.unet_model import UNet

## Set training parameters
This notebook uses fundus-image datasets from four domains (CHASEDB1, DRIVE, HRF and STARE). The datasets are available upon request.

In [6]:
# Domain indices used for training and testing. Each entry is passed through
# int(), so digit strings such as "012" are accepted as well as int lists.
datasetTrain = [0, 1, 2]
datasetTest = [3]
# Dataset directory; override with the RVS_DATA_DIR environment variable.
data_dir = os.environ.get('RVS_DATA_DIR', '/work/zhangyq/RVS')
config = './configs/Fundus/FundusConfigBase.py'
# Load the config into the global context so gpc.config is usable below.
gpc.load_config(config)
splitidTrain = [int(x) for x in datasetTrain]
splitidTest = [int(x) for x in datasetTest]



## Initialize the distributed environment (supports OpenMPI, SLURM and torch launchers)

In [3]:
# Pick the ColossalAI launcher that matches how this process was started.
# The rendezvous host/port default to the values this notebook was run with and
# can be overridden via the RVS_MASTER_HOST / RVS_MASTER_PORT environment variables.
RENDEZVOUS_PORT = int(os.environ.get('RVS_MASTER_PORT', 11455))

if 'OMPI_COMM_WORLD_RANK' in os.environ:
    # Started by mpirun (OpenMPI). All ranks must agree on the rendezvous host;
    # 'gpu01' is the host hard-coded for the original cluster.
    colossalai.launch_from_openmpi(config=config,
                                   host=os.environ.get('RVS_MASTER_HOST', 'gpu01'),
                                   port=RENDEZVOUS_PORT,
                                   backend='nccl')
elif 'SLURM_PROCID' in os.environ:
    # Started by srun. NOTE(review): 'localhost' presumably only works for
    # single-node jobs; set RVS_MASTER_HOST to the first node for multi-node runs.
    colossalai.launch_from_slurm(config=config,
                                 host=os.environ.get('RVS_MASTER_HOST', 'localhost'),
                                 port=RENDEZVOUS_PORT,
                                 backend='nccl')
elif 'WORLD_SIZE' in os.environ:
    # Started by torchrun / torch.distributed.launch. launch_from_torch reads the
    # rendezvous address from MASTER_ADDR / MASTER_PORT itself and does not
    # accept host/port arguments.
    colossalai.launch_from_torch(config=config, backend='nccl')
else:
    # Plain interactive session: a single process, rank 0 in a world of size 1.
    colossalai.launch(
        config=config,
        host=os.environ.get('RVS_MASTER_HOST', 'localhost'),
        port=RENDEZVOUS_PORT,
        rank=0,
        world_size=1,
        backend='nccl')

colossalai - torch.distributed.distributed_c10d - 2021-12-11 03:35:14,151 INFO: Added key: store_based_barrier_key:1 to store for rank: 0
colossalai - torch.distributed.distributed_c10d - 2021-12-11 03:35:14,153 INFO: Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 1 nodes.
colossalai - torch.distributed.distributed_c10d - 2021-12-11 03:35:14,154 INFO: Added key: store_based_barrier_key:2 to store for rank: 0
colossalai - torch.distributed.distributed_c10d - 2021-12-11 03:35:14,155 INFO: Rank 0: Completed store-based barrier for key:store_based_barrier_key:2 with 1 nodes.
colossalai - torch.distributed.distributed_c10d - 2021-12-11 03:35:14,157 INFO: Added key: store_based_barrier_key:3 to store for rank: 0
colossalai - torch.distributed.distributed_c10d - 2021-12-11 03:35:14,158 INFO: Rank 0: Completed store-based barrier for key:store_based_barrier_key:3 with 1 nodes.
colossalai - root - 2021-12-11 03:35:14,166 INFO: process rank 0 is bound to device 0
co



## Set up dataset and augmentation

In [7]:
# Crop size shared by both transform pipelines, and batch size shared by both loaders.
CROP_SIZE = 256
BATCH_SIZE = 8

# Training pipeline: tr.RandomScaleCrop, then normalization and tensor conversion.
# Other custom_transforms augmentations currently disabled: RandomCrop(512),
# RandomRotate, RandomFlip, elastic_transform, add_salt_pepper_noise,
# adjust_light, eraser.
composed_transforms_tr = transforms.Compose([
    tr.RandomScaleCrop(CROP_SIZE),
    tr.Normalize_tf(),
    tr.ToTensor(),
])

# Evaluation pipeline.
# NOTE(review): tr.RandomCrop makes validation crops differ between runs; consider
# a deterministic crop/resize if custom_transforms offers one — TODO.
composed_transforms_ts = transforms.Compose([
    tr.RandomCrop(CROP_SIZE),
    tr.Normalize_tf(),
    tr.ToTensor(),
])

domain = RetinalVesselSegmentation(base_dir=data_dir, phase='train', splitid=splitidTrain,
                                   transform=composed_transforms_tr)
domain_val = RetinalVesselSegmentation(base_dir=data_dir, phase='test', splitid=splitidTest,
                                       transform=composed_transforms_ts)

# Settings common to both loaders. NOTE(review): no distributed sampler is attached,
# so under multi-process launches each rank may iterate the whole dataset unless
# colossalai.initialize wraps these loaders — TODO confirm.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=1, pin_memory=True)
train_loader = DataLoader(domain, shuffle=True, **loader_kwargs)
val_loader = DataLoader(domain_val, shuffle=False, **loader_kwargs)


==> Loading train data from: /work/zhangyq/RVS/CHASEDB1/train
==> Loading train data from: /work/zhangyq/RVS/DRIVE/train
==> Loading train data from: /work/zhangyq/RVS/HRF/train
img_num: 55
key STARE has no data
20 images in CHASEDB1
20 images in DRIVE
15 images in HRF
-----Total number of images in train: 55
==> Loading test data from: /work/zhangyq/RVS/STARE
img_num: 10
key CHASEDB1 has no data
key DRIVE has no data
key HRF has no data
10 images in STARE
-----Total number of images in test: 10


## Set up model

In [8]:
# UNet(3, 2): presumably 3 input channels (RGB fundus images) and 2 output
# channels — TODO confirm against networks/unet/unet_model.py.
unet = UNet(3, 2)
model = unet.cuda()  # move parameters to the current CUDA device before building the engine

## Initialize Engine and Trainer

In [9]:
# Adam optimizer. lr and weight_decay come from the loaded config's `optimizer`
# section when present, so editing the config file actually changes training;
# the fallbacks equal the values previously hard-coded here.
_optim_cfg = gpc.config.get('optimizer', {})
optim = torch.optim.Adam(
    model.parameters(),
    lr=_optim_cfg.get('lr', 0.001),
    weight_decay=_optim_cfg.get('weight_decay', 0),
    betas=(0.9, 0.99),
)
def batch_data_process_func(sample):
    """Flatten a multi-domain batch into a single ``(image, label)`` pair.

    ``sample`` is iterated as a sequence of per-domain dicts, each holding
    ``'image'`` and ``'label'`` tensors. These are concatenated along dim 0 so
    the engine receives one combined batch.

    Returns:
        ``(image, label)`` tensors, or ``(None, None)`` if ``sample`` is empty.
    """
    domains = list(sample)
    if not domains:
        return None, None
    # One torch.cat over all domains, instead of re-concatenating the growing
    # accumulator on every iteration (which re-copies it each time).
    image = torch.cat([d['image'] for d in domains], 0)
    label = torch.cat([d['label'] for d in domains], 0)
    return image, label
logger = get_dist_logger('root')

# Default (non-pipelined) schedule; batch_data_process_func turns each loader
# batch into the (image, label) pair the engine consumes.
schedule = colossalai.engine.schedule.NonPipelineSchedule()
schedule.batch_data_process_func = batch_data_process_func

# NOTE(review): BCELoss expects predictions already in [0, 1]. If UNet emits raw
# logits, nn.BCEWithLogitsLoss is the numerically stable choice — TODO confirm.
criterion = nn.BCELoss()

# TODO: an LR schedule (colossalai.nn.lr_scheduler.CosineAnnealingLR(optim, 1000))
# was drafted but is disabled; the returned lr_scheduler is not used in this notebook.
engine, train_dataloader, test_dataloader, lr_scheduler = colossalai.initialize(
    model=model,
    optimizer=optim,
    criterion=criterion,
    train_dataloader=train_loader,
    test_dataloader=val_loader,
    verbose=True,
)
logger.info("engine is built", ranks=[0])

trainer = Trainer(engine=engine, schedule=schedule, logger=logger)
logger.info("trainer is built", ranks=[0])

colossalai - root - 2021-12-11 03:36:09,685 INFO: 
{'hooks': [{'type': 'LogMetricByEpochHook'},
           {'type': 'LogTimingByEpochHook'},
           {'type': 'LogMemoryByEpochHook'},
           {'type': 'Accuracy2DHook'},
           {'type': 'LossHook'}],
 'logging': {'root_path': './logs'},
 'num_epochs': 60,
 'optimizer': {'lr': 0.001, 'type': 'Adam', 'weight_decay': 0},
 'parallel': {'data': {'size': 1},
              'pipeline': {'size': 1},
              'tensor': {'mode': None, 'size': 1}},
 'schedule': {'num_microbatches': 8}}

colossalai - root - 2021-12-11 03:36:09,686 INFO: cuDNN benchmark = True, deterministic = False
colossalai - root - 2021-12-11 03:36:09,691 INFO: engine is built
colossalai - root - 2021-12-11 03:36:09,692 INFO: trainer is built


## Train it!

In [10]:
# Train for gpc.config.num_epochs epochs, evaluating on the held-out domain every
# VAL_EVERY_N_EPOCHS epochs (epochs 0, 2, 4, ... in the log below).
VAL_EVERY_N_EPOCHS = 2

logger.info("start training", ranks=[0])
# NOTE(review): the config's `hooks` section (loss/metric/timing loggers) is not
# passed via `hooks=` here — confirm whether Trainer applies it on its own.
trainer.fit(
    train_dataloader=train_dataloader,
    test_dataloader=test_dataloader,
    epochs=gpc.config.num_epochs,
    test_interval=VAL_EVERY_N_EPOCHS,
    display_progress=True,
)

colossalai - root - 2021-12-11 03:36:10,397 INFO: start training
colossalai - root - 2021-12-11 03:36:10,398 INFO: Lower value means higher priority for calling hook function
[Epoch 0 train]: 100%|██████████| 8/8 [00:08<00:00,  1.08s/it]
[Epoch 0 val]: 100%|██████████| 2/2 [00:02<00:00,  1.31s/it]
[Epoch 1 train]: 100%|██████████| 8/8 [00:01<00:00,  5.09it/s]
[Epoch 2 train]: 100%|██████████| 8/8 [00:01<00:00,  5.05it/s]
[Epoch 2 val]: 100%|██████████| 2/2 [00:00<00:00, 15.81it/s]
[Epoch 3 train]: 100%|██████████| 8/8 [00:01<00:00,  5.19it/s]
[Epoch 4 train]: 100%|██████████| 8/8 [00:01<00:00,  5.33it/s]
[Epoch 4 val]: 100%|██████████| 2/2 [00:00<00:00, 15.30it/s]
[Epoch 5 train]: 100%|██████████| 8/8 [00:01<00:00,  5.14it/s]
[Epoch 6 train]: 100%|██████████| 8/8 [00:01<00:00,  5.35it/s]
[Epoch 6 val]: 100%|██████████| 2/2 [00:00<00:00, 15.73it/s]
[Epoch 7 train]: 100%|██████████| 8/8 [00:01<00:00,  5.34it/s]
[Epoch 8 train]: 100%|██████████| 8/8 [00:01<00:00,  5.44it/s]
[Epoch 8 val]:

# Special note
Parallel runs can be achieved with the command below:

In [1]:
# Run the equivalent training script on 4 MPI ranks; `time` reports the wall-clock cost.
# --datasetTrain / --datasetTest take domain indices as digit strings
# (presumably "012" = domains 0, 1, 2 — confirm against train.py's argument parsing).
!time mpirun -np 4 python train.py --config configs/Fundus/FundusConfig1d.py --datasetTrain 012 --datasetTest 3 --data-dir ~/RVS

colossalai - torch.distributed.distributed_c10d - 2021-12-11 03:42:32,461 INFO: Added key: store_based_barrier_key:1 to store for rank: 3
colossalai - torch.distributed.distributed_c10d - 2021-12-11 03:42:33,376 INFO: Added key: store_based_barrier_key:1 to store for rank: 1
colossalai - torch.distributed.distributed_c10d - 2021-12-11 03:42:33,449 INFO: Added key: store_based_barrier_key:1 to store for rank: 2
colossalai - torch.distributed.distributed_c10d - 2021-12-11 03:42:33,455 INFO: Added key: store_based_barrier_key:1 to store for rank: 0
colossalai - torch.distributed.distributed_c10d - 2021-12-11 03:42:33,456 INFO: Rank 0: Completed store-based barrier for key:store_based_barrier_key:1 with 4 nodes.
colossalai - torch.distributed.distributed_c10d - 2021-12-11 03:42:33,456 INFO: Added key: store_based_barrier_key:2 to store for rank: 0
colossalai - torch.distributed.distributed_c10d - 2021-12-11 03:42:33,458 INFO: Rank 1: Completed store-based barrier for key:store_based_barrie

However, achieving the best performance still requires adapting the model to take full advantage of ColossalAI's capabilities.