In [None]:
import copy
import os
# Pin this process to physical GPU 0. Set before torch is imported so the value
# is already in place by the time CUDA is initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# NOTE(review): presumably mimics the variable a distributed launcher would set,
# for code that reads LOCAL_RANK even in a single-process run — confirm the reader.
os.environ['LOCAL_RANK'] = '0'
import os.path as osp
import time

import mmcv
import torch
from mmcv.cnn.utils import revert_sync_batchnorm
from mmcv.utils import Config, get_git_hash

from mmseg import __version__
from mmseg.apis import set_random_seed, train_segmentor
from mmseg.datasets import build_dataset
from mmseg.models import build_segmentor
from mmseg.utils import (get_device, get_root_logger, setup_multi_processes)

# Imported only for side effects: neither name is used anywhere below.
# Presumably this registers the PVT / PVTv2 backbones with mmseg's model
# registry so the config can reference them by `type` — TODO confirm.
import pvt, pvtv2

In [2]:
# Experiment config (judging by its filename: FPN head, PVTv2-B5 backbone,
# CASIA at 512px, 320k-iteration schedule). Resolved relative to the notebook's
# working directory.
cfg_path = '../my_configs/fpn_pvtv2_b5_casia512_320k.py'
cfg = Config.fromfile(filename=cfg_path)

In [None]:
# Turn on cuDNN autotuning only when the config explicitly requests it.
if cfg.get('cudnn_benchmark', False):
    torch.backends.cudnn.benchmark = True

# Override the config's work_dir so logs and checkpoints land in a
# per-experiment folder.
cfg.work_dir = '../work_dirs/casia512_320k/fpn_pvtv2_b5'
mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
# Snapshot the resolved config (including the work_dir override above) next to
# the logs/checkpoints so the run can be reproduced from work_dir alone.
# cfg.dump on a .py target writes cfg.pretty_text, which is logged below anyway.
cfg.dump(osp.join(cfg.work_dir, osp.basename(cfg_path)))

# One timestamped log file per run.
timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
setup_multi_processes(cfg)

# Run-level metadata handed to train_segmentor; later cells add the seed.
# `exp_name` lets mmcv's text logger label the experiment in its output.
meta = dict()
meta['exp_name'] = osp.basename(cfg_path)

# Single-process, single-GPU training in this notebook.
distributed = False

logger.info(f'Distributed training: {distributed}')
logger.info(f'Config:\n{cfg.pretty_text}')

In [4]:
# Single source of truth for the RNG seed: the value that seeds the RNGs is the
# same one stored in cfg and in the run metadata.
seed = 0
deterministic = False
set_random_seed(seed, deterministic=deterministic)
cfg.seed = seed
meta['seed'] = seed
logger.info(f'Set random seed to {seed}, deterministic: {deterministic}')

# One GPU: CUDA_VISIBLE_DEVICES is pinned to '0' in the imports cell.
cfg.gpu_ids = range(1)
cfg.device = get_device()

In [None]:
# Build the segmentor described by cfg.model; train/test cfgs are optional
# top-level keys, so cfg.get yields None when they are absent.
model = build_segmentor(cfg.model,
                        train_cfg=cfg.get('train_cfg'),
                        test_cfg=cfg.get('test_cfg'))
model.init_weights()
# SyncBN layers require a distributed process group; in a non-distributed run
# convert them back to ordinary BatchNorm.
model = model if distributed else revert_sync_batchnorm(model)
logger.info(model)


In [None]:
# Training data always comes first; train_segmentor consumes this list.
train_set = build_dataset(cfg.data.train)
datasets = [train_set]

# A two-phase workflow also runs the val split as a workflow phase; that split
# is built with the *training* pipeline instead of its own.
if len(cfg.workflow) == 2:
    val_dataset = copy.deepcopy(cfg.data.val)
    val_dataset.pipeline = cfg.data.train.pipeline
    datasets.append(build_dataset(val_dataset))

# Expose the class names on the model (e.g. for visualisation).
model.CLASSES = train_set.CLASSES

In [7]:
# Embed provenance in checkpoints: mmseg version + short git hash, the full
# config text, and the dataset's class names and palette.
if cfg.checkpoint_config is not None:
    cfg.checkpoint_config.meta = dict(
        mmseg_version=f'{__version__}+{get_git_hash()[:7]}',
        config=cfg.pretty_text,
        CLASSES=datasets[0].CLASSES,
        PALETTE=datasets[0].PALETTE)
    # checkpoint_config.meta only reaches checkpoints written by the periodic
    # checkpoint hook. Checkpoints saved elsewhere (presumably the evaluation
    # hook's "best" checkpoint) carry only the runner-level `meta` passed to
    # train_segmentor, so merge the same fields there too (as upstream mmseg's
    # tools/train.py does). Safe to re-run: update is idempotent.
    meta.update(cfg.checkpoint_config.meta)

In [None]:
# Launch training. validate=True asks train_segmentor to evaluate during
# training (presumably using cfg.evaluation / cfg.data.val — confirm).
train_kwargs = dict(distributed=distributed,
                    validate=True,
                    timestamp=timestamp,
                    meta=meta)
train_segmentor(model, datasets, cfg, **train_kwargs)