# train.py
# Copyright (c) OpenMMLab. All rights reserved.
import os
import random
import warnings
import mmcv
import numpy as np
import torch
import torch.distributed as dist
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
build_runner, get_dist_info)
from mmcv.utils import build_from_cfg
from mmseg import digit_version
from mmseg.core import DistEvalHook, EvalHook, build_optimizer
from mmseg.datasets import build_dataloader, build_dataset
from mmseg.utils import (build_ddp, build_dp, find_latest_checkpoint,
get_root_logger)
def init_random_seed(seed=None, device='cuda'):
    """Pick the random seed for this run.

    An explicitly provided ``seed`` is returned untouched. Otherwise a seed
    is drawn with NumPy. When more than one process is running, rank 0's
    draw is broadcast so that every rank ends up with the same value.

    Args:
        seed (int, Optional): The seed. Default to None.
        device (str): Device on which the broadcast tensor is allocated.
            Default to 'cuda'.

    Returns:
        int: Seed to be used.
    """
    if seed is not None:
        return seed

    # All ranks have to agree on the seed, otherwise subtle bugs appear; see
    # https://github.com/open-mmlab/mmdetection/issues/6339
    rank, world_size = get_dist_info()
    drawn = np.random.randint(2**31)
    if world_size == 1:
        return drawn

    # Only rank 0 contributes its draw; the other ranks receive it.
    payload = drawn if rank == 0 else 0
    shared_seed = torch.tensor(payload, dtype=torch.int32, device=device)
    dist.broadcast(shared_seed, src=0)
    return shared_seed.item()
def set_random_seed(seed, deterministic=False):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices).

    Args:
        seed (int): Seed to be used.
        deterministic (bool): If True, also switch cuDNN to deterministic
            mode, i.e. set `torch.backends.cudnn.deterministic` to True and
            `torch.backends.cudnn.benchmark` to False. When False, the cuDNN
            flags are left untouched. Default: False.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed,
               torch.cuda.manual_seed_all)
    for apply_seed in seeders:
        apply_seed(seed)

    if not deterministic:
        return
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
def train_segmentor(model,
                    dataset,
                    cfg,
                    distributed=False,
                    validate=False,
                    timestamp=None,
                    meta=None):
    """Launch segmentor training.

    Builds one training data loader per dataset and wraps ``model`` with the
    DDP or DP wrapper. It then builds the optimizer and runner from ``cfg``
    and registers the training, optional evaluation and user-defined hooks.
    Finally it resumes or loads a checkpoint if configured and starts
    ``runner.run``.

    Args:
        model: The segmentor model to train.
        dataset (Dataset | list | tuple): Training dataset(s). A single
            dataset is wrapped into a list; one data loader is built per
            entry.
        cfg: Training config, accessed both via attributes and ``.get``.
            Fields read here include ``log_level``, ``gpu_ids``, ``seed``,
            ``data``, ``device``, ``optimizer``, ``runner`` (or
            ``total_iters`` when ``runner`` is missing), ``work_dir``,
            ``lr_config``, ``optimizer_config``, ``checkpoint_config``,
            ``log_config``, ``workflow``, ``resume_from``, ``load_from`` and,
            optionally, ``find_unused_parameters``, ``momentum_config``,
            ``evaluation``, ``custom_hooks`` and ``auto_resume``.
        distributed (bool): Whether to train with the DDP wrapper.
            Default: False.
        validate (bool): Whether to register an evaluation hook on
            ``cfg.data.val``. Default: False.
        timestamp (str, optional): Assigned to ``runner.timestamp`` so the
            .log and .log.json filenames match.
        meta (dict, optional): Passed through to the runner.

    Note:
        ``cfg`` is updated in place in three cases:
        ``cfg.runner`` is filled in when it is missing;
        ``cfg.optimizer_config`` gets an FP16 default on NPU when unset;
        ``cfg.resume_from`` is set when ``auto_resume`` finds a checkpoint.
        ``cfg.evaluation`` is NOT modified.
    """
    logger = get_root_logger(cfg.log_level)

    # prepare data loaders
    dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
    # The default loader config
    loader_cfg = dict(
        # cfg.gpus will be ignored if distributed
        num_gpus=len(cfg.gpu_ids),
        dist=distributed,
        seed=cfg.seed,
        drop_last=True)
    # The overall dataloader settings: every key of cfg.data except the
    # dataset definitions and the per-split dataloader overrides.
    loader_cfg.update({
        k: v
        for k, v in cfg.data.items() if k not in [
            'train', 'val', 'test', 'train_dataloader', 'val_dataloader',
            'test_dataloader'
        ]
    })
    # The specific dataloader settings (train overrides win)
    train_loader_cfg = {**loader_cfg, **cfg.data.get('train_dataloader', {})}
    data_loaders = [build_dataloader(ds, **train_loader_cfg) for ds in dataset]

    # put model on devices
    if distributed:
        find_unused_parameters = cfg.get('find_unused_parameters', False)
        # Sets the `find_unused_parameters` parameter in the DDP wrapper.
        # The device id comes from the LOCAL_RANK environment variable
        # (KeyError if the launcher did not export it).
        model = build_ddp(
            model,
            cfg.device,
            device_ids=[int(os.environ['LOCAL_RANK'])],
            broadcast_buffers=False,
            find_unused_parameters=find_unused_parameters)
    else:
        if not torch.cuda.is_available():
            assert digit_version(mmcv.__version__) >= digit_version('1.4.4'), \
                'Please use MMCV >= 1.4.4 for CPU training!'
        model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)

    # build runner
    optimizer = build_optimizer(model, cfg.optimizer)

    if cfg.get('runner') is None:
        cfg.runner = {'type': 'IterBasedRunner', 'max_iters': cfg.total_iters}
        warnings.warn(
            'config is now expected to have a `runner` section, '
            'please set `runner` in your config.', UserWarning)

    runner = build_runner(
        cfg.runner,
        default_args=dict(
            model=model,
            batch_processor=None,
            optimizer=optimizer,
            work_dir=cfg.work_dir,
            logger=logger,
            meta=meta))

    if cfg.device == 'npu' and not cfg.optimizer_config:
        # On NPU, default to an FP16 optimizer hook with dynamic loss
        # scaling when the config does not provide an optimizer hook.
        cfg.optimizer_config = dict(
            type='Fp16OptimizerHook', loss_scale='dynamic')

    # register hooks
    runner.register_training_hooks(cfg.lr_config, cfg.optimizer_config,
                                   cfg.checkpoint_config, cfg.log_config,
                                   cfg.get('momentum_config', None))
    if distributed:
        # when distributed training by epoch, using `DistSamplerSeedHook` to
        # set the different seed to distributed sampler for each epoch, it
        # will shuffle dataset at each epoch and avoid overfitting.
        if isinstance(runner, EpochBasedRunner):
            runner.register_hook(DistSamplerSeedHook())

    # an ugly workaround to make the .log and .log.json filenames the same
    runner.timestamp = timestamp

    # register eval hooks
    if validate:
        val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
        # The specific dataloader settings
        val_loader_cfg = {
            **loader_cfg,
            'samples_per_gpu': 1,
            'shuffle': False,  # Not shuffle by default
            **cfg.data.get('val_dataloader', {}),
        }
        val_dataloader = build_dataloader(val_dataset, **val_loader_cfg)
        # Work on a shallow copy so injecting `by_epoch` does not write back
        # into cfg.evaluation; an explicit `evaluation=None` is treated like
        # an absent section instead of failing on item assignment.
        eval_cfg = dict(cfg.get('evaluation') or {})
        eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
        eval_hook = DistEvalHook if distributed else EvalHook
        # In this PR (https://github.com/open-mmlab/mmcv/pull/1193), the
        # priority of IterTimerHook has been modified from 'NORMAL' to 'LOW'.
        runner.register_hook(
            eval_hook(val_dataloader, **eval_cfg), priority='LOW')

    # user-defined hooks
    custom_hooks = cfg.get('custom_hooks', None)
    if custom_hooks:
        assert isinstance(custom_hooks, list), \
            f'custom_hooks expect list type, but got {type(custom_hooks)}'
        for hook_cfg in custom_hooks:
            assert isinstance(hook_cfg, dict), \
                'Each item in custom_hooks expects dict type, but got ' \
                f'{type(hook_cfg)}'
            # Copy before popping so the config entry keeps its `priority`.
            hook_cfg = hook_cfg.copy()
            priority = hook_cfg.pop('priority', 'NORMAL')
            hook = build_from_cfg(hook_cfg, HOOKS)
            runner.register_hook(hook, priority=priority)

    # Auto-resume only applies when no explicit resume_from was given.
    if cfg.resume_from is None and cfg.get('auto_resume'):
        resume_from = find_latest_checkpoint(cfg.work_dir)
        if resume_from is not None:
            cfg.resume_from = resume_from
    # resume_from takes precedence over load_from.
    if cfg.resume_from:
        if cfg.device == 'npu':
            runner.resume(cfg.resume_from, map_location='npu')
        else:
            runner.resume(cfg.resume_from)
    elif cfg.load_from:
        runner.load_checkpoint(cfg.load_from)
    runner.run(data_loaders, cfg.workflow)