support training mmrotate on NPU (#806)
chenmin00 committed Apr 13, 2023
1 parent 7755aa5 commit 04405ab
Showing 6 changed files with 110 additions and 19 deletions.
mmrotate/apis/train.py (11 additions, 8 deletions)
@@ -1,15 +1,16 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 # Copied from mmdet, only modified `get_root_logger`.
-import torch
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+import os
+
 from mmcv.runner import (DistSamplerSeedHook, EpochBasedRunner,
                          Fp16OptimizerHook, OptimizerHook, build_optimizer,
                          build_runner)
 from mmdet.core import DistEvalHook, EvalHook
 from mmdet.datasets import (build_dataloader, build_dataset,
                             replace_ImageToTensor)
 
-from mmrotate.utils import compat_cfg, find_latest_checkpoint, get_root_logger
+from mmrotate.utils import (build_ddp, build_dp, compat_cfg,
+                            find_latest_checkpoint, get_root_logger)
 
 
 def train_detector(model,
@@ -51,14 +52,14 @@ def train_detector(model,
         find_unused_parameters = cfg.get('find_unused_parameters', False)
         # Sets the `find_unused_parameters` parameter in
         # torch.nn.parallel.DistributedDataParallel
-        model = MMDistributedDataParallel(
-            model.cuda(),
-            device_ids=[torch.cuda.current_device()],
+        model = build_ddp(
+            model,
+            cfg.device,
+            device_ids=[int(os.environ['LOCAL_RANK'])],
             broadcast_buffers=False,
             find_unused_parameters=find_unused_parameters)
     else:
-        model = MMDataParallel(
-            model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+        model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
 
     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
@@ -77,6 +78,8 @@ def train_detector(model,
 
     # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
+    if fp16_cfg is None and cfg.get('device', None) == 'npu':
+        fp16_cfg = dict(loss_scale='dynamic')
     if fp16_cfg is not None:
         optimizer_config = Fp16OptimizerHook(
             **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
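
Note on the fp16 default above: if the config has no `fp16` block and the resolved device is `npu`, training now falls back to mixed precision with dynamic loss scaling. A standalone sketch of that decision follows; the plain dict stands in for the mmcv `Config`, and the 'npu' value is an assumption.

# Sketch of the fp16 fallback added in mmrotate/apis/train.py above.
cfg = {'device': 'npu'}                      # normally filled in by get_device()
fp16_cfg = cfg.get('fp16', None)             # config carries no explicit fp16 block
if fp16_cfg is None and cfg.get('device', None) == 'npu':
    fp16_cfg = dict(loss_scale='dynamic')    # NPU default: dynamic loss scaling
print(fp16_cfg)                              # {'loss_scale': 'dynamic'}
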
mmrotate/models/dense_heads/rotated_anchor_head.py (1 addition, 1 deletion)
@@ -166,7 +166,7 @@ def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
         # since feature map sizes of all images are the same, we only compute
         # anchors for one time
         multi_level_anchors = self.anchor_generator.grid_priors(
-            featmap_sizes, device)
+            featmap_sizes, device=device)
         anchor_list = [multi_level_anchors for _ in range(num_imgs)]
 
         # for each image, we compute valid flags of multi level anchors
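
The keyword argument above is not cosmetic: in recent mmdet releases `AnchorGenerator.grid_priors` is assumed to take `(featmap_sizes, dtype=torch.float32, device='cuda')`, so a positional second argument would be bound to `dtype` rather than `device`. A toy stand-in (not the real mmdet class) illustrates the binding.

# Toy stand-in used only to show argument binding; the real mmdet signature is
# assumed to be grid_priors(featmap_sizes, dtype=torch.float32, device='cuda').
def grid_priors(featmap_sizes, dtype='torch.float32', device='cuda'):
    return {'featmap_sizes': featmap_sizes, 'dtype': dtype, 'device': device}

print(grid_priors([(8, 8)], 'npu'))         # positional: 'npu' lands in dtype
print(grid_priors([(8, 8)], device='npu'))  # keyword: the device is set as intended
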
mmrotate/utils/__init__.py (2 additions, 1 deletion)
@@ -4,8 +4,9 @@
 from .logger import get_root_logger
 from .misc import find_latest_checkpoint
 from .setup_env import setup_multi_processes
+from .util_distribution import build_ddp, build_dp, get_device
 
 __all__ = [
     'get_root_logger', 'collect_env', 'find_latest_checkpoint', 'compat_cfg',
-    'setup_multi_processes'
+    'setup_multi_processes', 'build_dp', 'build_ddp', 'get_device'
 ]
mmrotate/utils/util_distribution.py (new file, 77 additions)
@@ -0,0 +1,77 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+
+dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel}
+
+ddp_factory = {'cuda': MMDistributedDataParallel}
+
+
+def build_dp(model, device='cuda', dim=0, *args, **kwargs):
+    """Build DataParallel module by device type.
+
+    If device is cuda, return a MMDataParallel module; if device is npu,
+    return a NPUDataParallel module.
+
+    Args:
+        model (:class:`nn.Module`): model to be parallelized.
+        device (str): device type, cuda, cpu or npu. Defaults to cuda.
+        dim (int): Dimension used to scatter the data. Defaults to 0.
+
+    Returns:
+        nn.Module: the model to be parallelized.
+    """
+    if device == 'npu':
+        from mmcv.device.npu import NPUDataParallel
+        dp_factory['npu'] = NPUDataParallel
+        torch.npu.set_device(kwargs['device_ids'][0])
+        torch.npu.set_compile_mode(jit_compile=False)
+        model = model.npu()
+    elif device == 'cuda':
+        model = model.cuda(kwargs['device_ids'][0])
+
+    return dp_factory[device](model, dim=dim, *args, **kwargs)
+
+
+def build_ddp(model, device='cuda', *args, **kwargs):
+    """Build DistributedDataParallel module by device type.
+
+    If device is cuda, return a MMDistributedDataParallel module;
+    if device is npu, return a NPUDistributedDataParallel module.
+
+    Args:
+        model (:class:`nn.Module`): module to be parallelized.
+        device (str): device type, npu or cuda.
+
+    Returns:
+        :class:`nn.Module`: the module to be parallelized.
+
+    References:
+        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
+               DistributedDataParallel.html
+    """
+    assert device in ['cuda', 'npu'], 'Only available for cuda or npu devices.'
+    if device == 'npu':
+        from mmcv.device.npu import NPUDistributedDataParallel
+        torch.npu.set_compile_mode(jit_compile=False)
+        ddp_factory['npu'] = NPUDistributedDataParallel
+        model = model.npu()
+    elif device == 'cuda':
+        model = model.cuda()
+
+    return ddp_factory[device](model, *args, **kwargs)
+
+
+def is_npu_available():
+    """Returns a bool indicating if NPU is currently available."""
+    return hasattr(torch, 'npu') and torch.npu.is_available()
+
+
+def get_device():
+    """Returns the first available device among npu, cuda and cpu."""
+    is_device_available = {
+        'npu': is_npu_available(),
+        'cuda': torch.cuda.is_available(),
+    }
+    device_list = [k for k, v in is_device_available.items() if v]
+    return device_list[0] if len(device_list) >= 1 else 'cpu'
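
A minimal usage sketch of the helpers above, assuming mmrotate at this commit and mmcv-full are importable; on a machine without an NPU it simply resolves to `cuda` or `cpu`.

# Pick whatever accelerator is present and wrap a toy model with the matching
# DataParallel class. Assumes mmrotate at this commit is installed.
import torch.nn as nn

from mmrotate.utils import build_dp, get_device

device = get_device()                    # 'npu', 'cuda' or 'cpu'
model = nn.Conv2d(3, 8, kernel_size=3)
model = build_dp(model, device, device_ids=[0])
print(device, type(model).__name__)

`build_ddp` follows the same pattern but additionally expects an initialized process group and a `LOCAL_RANK`-derived `device_ids`, as the mmrotate/apis/train.py hunk above shows.
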
tools/test.py (16 additions, 8 deletions)
@@ -9,15 +9,15 @@
 import torch
 from mmcv import Config, DictAction
 from mmcv.cnn import fuse_conv_bn
-from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
 from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
                          wrap_fp16_model)
 from mmdet.apis import multi_gpu_test, single_gpu_test
 from mmdet.datasets import build_dataloader, replace_ImageToTensor
 
 from mmrotate.datasets import build_dataset
 from mmrotate.models import build_detector
-from mmrotate.utils import compat_cfg, setup_multi_processes
+from mmrotate.utils import (build_ddp, build_dp, compat_cfg, get_device,
+                            setup_multi_processes)
 
 
 def parse_args():
@@ -209,9 +209,12 @@ def main():
 
     # build the model and load checkpoint
     cfg.model.train_cfg = None
+    cfg.device = get_device()
     model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
+
+    # fp16 setting
     fp16_cfg = cfg.get('fp16', None)
-    if fp16_cfg is not None:
+    if fp16_cfg is not None or cfg.device == 'npu':
         wrap_fp16_model(model)
     checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
     if args.fuse_conv_bn:
@@ -224,14 +227,19 @@
     model.CLASSES = dataset.CLASSES
 
     if not distributed:
-        model = MMDataParallel(model, device_ids=cfg.gpu_ids)
+        model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
         outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
                                   args.show_score_thr)
     else:
-        model = MMDistributedDataParallel(
-            model.cuda(),
-            device_ids=[torch.cuda.current_device()],
-            broadcast_buffers=False)
+        find_unused_parameters = cfg.get('find_unused_parameters', False)
+        # Sets the `find_unused_parameters` parameter in
+        # torch.nn.parallel.DistributedDataParallel
+        model = build_ddp(
+            model,
+            cfg.device,
+            device_ids=[int(os.environ['LOCAL_RANK'])],
+            broadcast_buffers=False,
+            find_unused_parameters=find_unused_parameters)
         outputs = multi_gpu_test(model, data_loader, args.tmpdir,
                                  args.gpu_collect)
 
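
For reference, `device_ids=[int(os.environ['LOCAL_RANK'])]` in the distributed branch above relies on the per-process rank exported by `torch.distributed.launch` / `torchrun`; the same index works for CUDA and NPU, which is what lets `build_ddp` stay device-agnostic. A tiny sketch (the fallback to 0 is only an assumption so it runs standalone):

# Resolve the local device index the way the distributed branch above does.
import os

local_rank = int(os.environ.get('LOCAL_RANK', 0))  # the launcher sets this per process
device_ids = [local_rank]
print(device_ids)                                   # e.g. [0] for the first process
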
tools/train.py (3 additions, 1 deletion)
@@ -18,7 +18,8 @@
 from mmrotate.apis import train_detector
 from mmrotate.datasets import build_dataset
 from mmrotate.models import build_detector
-from mmrotate.utils import collect_env, get_root_logger, setup_multi_processes
+from mmrotate.utils import (collect_env, get_device, get_root_logger,
+                            setup_multi_processes)
 
 
 def parse_args():
@@ -178,6 +179,7 @@ def main():
             CLASSES=datasets[0].CLASSES)
     # add an attribute for visualization convenience
     model.CLASSES = datasets[0].CLASSES
+    cfg.device = get_device()
     train_detector(
         model,
         datasets,
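
Putting the tools/train.py change in context: `cfg.device` is resolved once from whatever hardware is present and then consumed by `train_detector` when wrapping the model. A rough sketch of that flow; the config path is only an example and assumes a standard mmrotate checkout.

# Resolve the device exactly once and hand it to the training entry point.
# The config path below is an assumed example from the mmrotate repo.
from mmcv import Config

from mmrotate.utils import get_device

cfg = Config.fromfile(
    'configs/rotated_retinanet/rotated_retinanet_obb_r50_fpn_1x_dota_le90.py')
cfg.device = get_device()   # 'npu' on Ascend hardware, otherwise 'cuda' or 'cpu'
print(cfg.device)
# train_detector(model, datasets, cfg, ...) then calls build_dp/build_ddp with
# cfg.device, as shown in the mmrotate/apis/train.py diff above.
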
