In [1]:
import importlib

import torch
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

from datasets.hdf5 import get_train_loaders
from unet3d.config import load_config
from unet3d.losses import get_loss_criterion
from unet3d.metrics import get_evaluation_metric
from unet3d.model import get_model
from unet3d.trainer import UNet3DTrainer
from unet3d.utils import get_logger
from unet3d.utils import get_number_of_learnable_parameters

import argparse

import os
import yaml
import pdb

In [2]:
# Set up logging, choose the compute device and load the training configuration.

logger = get_logger('UNet3DTrainer')

# Default to the first GPU when CUDA is available; a 'device' entry in the YAML
# config takes precedence (see below).
DEFAULT_DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Config file
config_file = 'resources/train_config_ce.yaml'
# `yaml.safe_load` builds only plain Python objects (dicts, lists, scalars).
# Calling `yaml.load` without an explicit Loader is deprecated since PyYAML 5.1
# and raises TypeError in PyYAML 6. The `with` block also closes the file handle,
# which the previous `yaml.load(open(...))` left open.
with open(config_file, 'r') as config_stream:
    config = yaml.safe_load(config_stream)

# Get a device to train on: the config value wins over DEFAULT_DEVICE. It is stored
# back as a torch.device so later cells can pass config['device'] straight to `.to()`.
device = config.get('device', DEFAULT_DEVICE)
config['device'] = torch.device(device)

logger.info(config)


2019-08-07 22:44:08,948 [MainThread] INFO UNet3DTrainer - {'manual_seed': 0, 'model': {'name': 'UNet3D', 'in_channels': 1, 'out_channels': 2, 'layer_order': 'crg', 'f_maps': 32, 'num_groups': 8, 'final_sigmoid': False}, 'trainer': {'checkpoint_dir': '3dunet', 'resume': '3dunet/last_checkpoint.pytorch', 'validate_after_iters': 20, 'log_after_iters': 20, 'epochs': 2000, 'iters': 100000, 'eval_score_higher_is_better': True}, 'optimizer': {'learning_rate': 0.0002, 'weight_decay': 0.0001}, 'loss': {'name': 'WeightedCrossEntropyLoss', 'loss_weight': [0.05, 0.95], 'ignore_index': None}, 'eval_metric': {'name': 'MeanIoU', 'ignore_index': None}, 'lr_scheduler': {'name': 'MultiStepLR', 'milestones': [10, 30, 60], 'gamma': 0.2}, 'loaders': {'train_patch': [128, 128, 128], 'train_stride': [64, 64, 64], 'val_patch': [128, 128, 128], 'val_stride': [128, 128, 128], 'raw_internal_path': 'raw', 'label_internal_path': 'label', 'train_path': ['../h5_fractals/0.h5', '../h5_fractals/1.h5', '../h5_fractals/

  del sys.path[0]


In [3]:
def _create_trainer(config, model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders, logger):
    """Build a ``UNet3DTrainer`` according to ``config['trainer']``.

    The mode is chosen in this order:

    * ``resume`` is set      -> ``UNet3DTrainer.from_checkpoint`` (continue training
      from a given checkpoint);
    * ``pre_trained`` is set -> ``UNet3DTrainer.from_pretrained`` (fine-tune a given
      pre-trained model);
    * neither                -> a new ``UNet3DTrainer`` (training from scratch), given
      ``checkpoint_dir``.

    Args:
        config: full experiment configuration. It must contain a ``'trainer'`` section.
            The pre-trained and from-scratch modes also read ``config['device']``.
        model, optimizer, lr_scheduler, loss_criterion, eval_criterion, loaders:
            training components, forwarded unchanged to the trainer.
        logger: logger handed to the trainer.

    Returns:
        The constructed trainer.

    Raises:
        AssertionError: if ``config`` has no ``'trainer'`` section.
        KeyError: if a setting required by the selected mode is missing.
    """
    assert 'trainer' in config, 'Could not find trainer configuration'
    settings = config['trainer']

    def loop_settings():
        # Limits/cadence shared by the pre-trained and from-scratch modes. Built lazily
        # so the resume path never reads these keys (matching the original behavior).
        return dict(
            max_num_epochs=settings['epochs'],
            max_num_iterations=settings['iters'],
            validate_after_iters=settings['validate_after_iters'],
            log_after_iters=settings['log_after_iters'],
            eval_score_higher_is_better=settings['eval_score_higher_is_better'],
        )

    checkpoint_path = settings.get('resume', None)
    if checkpoint_path is not None:
        return UNet3DTrainer.from_checkpoint(checkpoint_path, model,
                                             optimizer, lr_scheduler, loss_criterion,
                                             eval_criterion, loaders,
                                             logger=logger)

    weights_path = settings.get('pre_trained', None)
    if weights_path is not None:
        return UNet3DTrainer.from_pretrained(weights_path, model, optimizer, lr_scheduler, loss_criterion,
                                             eval_criterion, device=config['device'], loaders=loaders,
                                             logger=logger, **loop_settings())

    return UNet3DTrainer(model, optimizer, lr_scheduler, loss_criterion, eval_criterion,
                         config['device'], loaders, settings['checkpoint_dir'],
                         logger=logger, **loop_settings())


def _create_optimizer(config, model):
    assert 'optimizer' in config, 'Cannot find optimizer configuration'
    optimizer_config = config['optimizer']
    learning_rate = optimizer_config['learning_rate']
    weight_decay = optimizer_config['weight_decay']
    optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
    return optimizer


def _create_lr_scheduler(config, optimizer):
    lr_config = config.get('lr_scheduler', None)
    if lr_config is None:
        # use ReduceLROnPlateau as a default scheduler
        return ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=20, verbose=True)
    else:
        class_name = lr_config.pop('name')
        m = importlib.import_module('torch.optim.lr_scheduler')
        clazz = getattr(m, class_name)
        # add optimizer to the config
        lr_config['optimizer'] = optimizer
        return clazz(**lr_config)

In [None]:

# Seed the RNG (when configured), build every training component from `config`, and
# run training. Statement order matters: seeding happens before the model and data
# loaders are created, so (presumably) their initial random state depends on
# `manual_seed` — keep this order when editing.
manual_seed = config.get('manual_seed', None)
if manual_seed is not None:
    logger.info(f'Seed the RNG for all devices with {manual_seed}')
    torch.manual_seed(manual_seed)
    # NOTE(review): only torch's RNG is seeded here; numpy / `random` are not —
    # confirm whether the data pipeline relies on them.
    # Disable cuDNN autotuning and force deterministic kernels for reproducibility.
    # see https://pytorch.org/docs/stable/notes/randomness.html
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Create the model
model = get_model(config)

# Move the model to config['device'] (a torch.device, e.g. 'cuda:0' or 'cpu').
logger.info(f"Sending the model to '{config['device']}'")
model = model.to(config['device'])

# Log the number of learnable parameters
logger.info(f'Number of learnable params {get_number_of_learnable_parameters(model)}')

# Create loss criterion
loss_criterion = get_loss_criterion(config)
# Create evaluation metric
eval_criterion = get_evaluation_metric(config)

# Create data loaders
loaders = get_train_loaders(config)

# Create the optimizer
optimizer = _create_optimizer(config, model)

# Create learning rate adjustment strategy
lr_scheduler = _create_lr_scheduler(config, optimizer)

# Create model trainer: resume, fine-tune or start from scratch, depending on
# config['trainer']['resume'] / ['pre_trained'].
trainer = _create_trainer(config, model=model, optimizer=optimizer, lr_scheduler=lr_scheduler,
                          loss_criterion=loss_criterion, eval_criterion=eval_criterion, loaders=loaders,
                          logger=logger)
# Start training
trainer.fit()

2019-08-07 22:44:08,971 [MainThread] INFO UNet3DTrainer - Seed the RNG for all devices with 0
2019-08-07 22:44:09,467 [MainThread] INFO UNet3DTrainer - Sending the model to 'cuda:0'
2019-08-07 22:44:37,665 [MainThread] INFO UNet3DTrainer - Number of learnable params 4080914
2019-08-07 22:44:37,667 [MainThread] INFO HDF5Dataset - Creating training and validation set loaders...
2019-08-07 22:44:37,667 [MainThread] INFO HDF5Dataset - Slice builder class: SliceBuilder
2019-08-07 22:44:37,668 [MainThread] INFO HDF5Dataset - Loading training set from: ../h5_fractals/0.h5...
2019-08-07 22:44:38,026 [MainThread] INFO HDF5Dataset - Loading training set from: ../h5_fractals/1.h5...
2019-08-07 22:44:38,230 [MainThread] INFO HDF5Dataset - Loading training set from: ../h5_fractals/2.h5...
2019-08-07 22:44:38,379 [MainThread] INFO HDF5Dataset - Loading training set from: ../h5_fractals/3.h5...
2019-08-07 22:44:38,513 [MainThread] INFO HDF5Dataset - Loading training set from: ../h5_fractals/4.h5...
2

  result = self.forward(*input, **kwargs)


2019-08-07 22:44:52,437 [MainThread] INFO UNet3DTrainer - Validating...
2019-08-07 22:44:55,248 [MainThread] INFO UNet3DTrainer - Validation finished. Loss: 0.3198775003353755. Evaluation score: 0.5447047253449758
2019-08-07 22:44:55,256 [MainThread] INFO UNet3DTrainer - Saving last checkpoint to '3dunet/last_checkpoint.pytorch'
2019-08-07 22:44:55,601 [MainThread] INFO UNet3DTrainer - Training stats. Loss: 0.007944676093757153. Evaluation score: 0.5498158931732178
2019-08-07 22:44:55,602 [MainThread] INFO UNet3DTrainer - Logging model parameters and gradients
2019-08-07 22:44:56,619 [MainThread] INFO UNet3DTrainer - Training iteration 5441. Batch 1. Epoch [91/1999]


  return (img - np.min(img)) / np.ptp(img)


2019-08-07 22:44:56,901 [MainThread] INFO UNet3DTrainer - Training iteration 5442. Batch 2. Epoch [91/1999]
2019-08-07 22:44:58,153 [MainThread] INFO UNet3DTrainer - Training iteration 5443. Batch 3. Epoch [91/1999]
2019-08-07 22:44:59,398 [MainThread] INFO UNet3DTrainer - Training iteration 5444. Batch 4. Epoch [91/1999]
2019-08-07 22:45:00,646 [MainThread] INFO UNet3DTrainer - Training iteration 5445. Batch 5. Epoch [91/1999]
2019-08-07 22:45:01,893 [MainThread] INFO UNet3DTrainer - Training iteration 5446. Batch 6. Epoch [91/1999]
2019-08-07 22:45:03,143 [MainThread] INFO UNet3DTrainer - Training iteration 5447. Batch 7. Epoch [91/1999]
2019-08-07 22:45:04,391 [MainThread] INFO UNet3DTrainer - Training iteration 5448. Batch 8. Epoch [91/1999]
2019-08-07 22:45:05,639 [MainThread] INFO UNet3DTrainer - Training iteration 5449. Batch 9. Epoch [91/1999]
2019-08-07 22:45:06,889 [MainThread] INFO UNet3DTrainer - Training iteration 5450. Batch 10. Epoch [91/1999]
2019-08-07 22:45:08,138 [Ma