Skip to content

Commit

Permalink
Support mmtrack with NPU backend. (#876)
Browse files Browse the repository at this point in the history
  • Loading branch information
luomaoling committed Apr 25, 2023
1 parent 5b47f18 commit e79491e
Show file tree
Hide file tree
Showing 5 changed files with 106 additions and 14 deletions.
16 changes: 10 additions & 6 deletions mmtrack/apis/train.py
@@ -1,16 +1,17 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os

import numpy as np
import torch
import torch.distributed as dist
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
build_optimizer, get_dist_info)
from mmcv.utils import build_from_cfg
from mmdet.datasets import build_dataset

from mmtrack.core import DistEvalHook, EvalHook
from mmtrack.datasets import build_dataloader
from mmtrack.utils import get_root_logger
from mmtrack.utils import build_ddp, build_dp, get_root_logger


def init_random_seed(seed=None, device='cuda'):
Expand Down Expand Up @@ -103,13 +104,14 @@ def train_model(model,
logger.info('set find_unused_parameters = True in DDP')
# Sets the `find_unused_parameters` parameter in
# torch.nn.parallel.DistributedDataParallel
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
model = build_ddp(
model,
cfg.device,
device_ids=[int(os.environ['LOCAL_RANK'])],
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)

# build runner
optimizer = build_optimizer(model, cfg.optimizer)
Expand All @@ -124,6 +126,8 @@ def train_model(model,

# fp16 setting
fp16_cfg = cfg.get('fp16', None)
if fp16_cfg is None and cfg.get('device', None) == 'npu':
fp16_cfg = dict(loss_scale='dynamic')
optimizer_config = cfg.optimizer_config
if 'type' not in cfg.optimizer_config:
optimizer_config.type = 'Fp16OptimizerHook' \
Expand Down
5 changes: 4 additions & 1 deletion mmtrack/utils/__init__.py
@@ -1,5 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .collect_env import collect_env
from .logger import get_root_logger
from .util_distribution import build_ddp, build_dp, get_device

__all__ = ['collect_env', 'get_root_logger']
__all__ = [
'collect_env', 'get_root_logger', 'build_ddp', 'build_dp', 'get_device'
]
71 changes: 71 additions & 0 deletions mmtrack/utils/util_distribution.py
@@ -0,0 +1,71 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel

# Registry mapping a device-type string to the DataParallel wrapper class
# used by ``build_dp``. The 'npu' entry is not present at import time; it is
# added by ``build_dp`` the first time it is called with ``device='npu'``.
dp_factory = {'cuda': MMDataParallel, 'cpu': MMDataParallel}

# Registry mapping a device-type string to the DistributedDataParallel
# wrapper class used by ``build_ddp``. The 'npu' entry is added by
# ``build_ddp`` the first time it is called with ``device='npu'``.
ddp_factory = {'cuda': MMDistributedDataParallel}


def build_dp(model, device='cuda', dim=0, *args, **kwargs):
    """Wrap ``model`` in the DataParallel class registered for ``device``.

    For ``'cuda'`` the model is first moved to the GPU given by
    ``device_ids[0]``. For ``'npu'`` the ``NPUDataParallel`` class is
    imported from ``mmcv.device.npu`` and registered in ``dp_factory``,
    ``torch.npu.set_device`` is called with ``device_ids[0]``,
    ``torch.npu.set_compile_mode(jit_compile=False)`` is called, and the
    model is moved with ``model.npu()``. For any other device string the
    model is passed to the wrapper without being moved.

    Args:
        model (:class:`nn.Module`): model to be parallelized.
        device (str): device type, cuda, cpu or npu. Defaults to cuda.
        dim (int): Dimension used to scatter the data. Defaults to 0.
        *args: Extra positional arguments forwarded to the wrapper class.
        **kwargs: Extra keyword arguments forwarded to the wrapper class.
            For cuda and npu, ``device_ids`` is read from here, so it must
            be supplied as a keyword argument.

    Returns:
        nn.Module: the model wrapped by the class looked up in
        ``dp_factory`` for ``device``.
    """
    if device == 'cuda':
        model = model.cuda(kwargs['device_ids'][0])
    elif device == 'npu':
        # Imported lazily so that mmcv's NPU support is only required
        # when an NPU is actually requested.
        from mmcv.device.npu import NPUDataParallel
        dp_factory['npu'] = NPUDataParallel
        primary_id = kwargs['device_ids'][0]
        torch.npu.set_device(primary_id)
        torch.npu.set_compile_mode(jit_compile=False)
        model = model.npu()

    wrapper_cls = dp_factory[device]
    return wrapper_cls(model, dim=dim, *args, **kwargs)


def build_ddp(model, device='cuda', *args, **kwargs):
    """Wrap ``model`` in the DistributedDataParallel class for ``device``.

    For ``'cuda'`` the model is moved with ``model.cuda()`` and wrapped by
    the class registered under ``'cuda'`` in ``ddp_factory``. For ``'npu'``
    the ``NPUDistributedDataParallel`` class is imported from
    ``mmcv.device.npu`` and registered in ``ddp_factory``,
    ``torch.npu.set_compile_mode(jit_compile=False)`` is called, and the
    model is moved with ``model.npu()``.

    Args:
        model (:class:`nn.Module`): module to be parallelized.
        device (str): device type, npu or cuda.
        *args: Extra positional arguments forwarded to the wrapper class.
        **kwargs: Extra keyword arguments forwarded to the wrapper class.

    Returns:
        :class:`nn.Module`: the module wrapped by the class looked up in
        ``ddp_factory`` for ``device``.

    References:
        .. [1] https://pytorch.org/docs/stable/generated/torch.nn.parallel.
               DistributedDataParallel.html
    """
    assert device in ['cuda', 'npu'], 'Only available for cuda or npu devices.'

    if device == 'cuda':
        model = model.cuda()
    elif device == 'npu':
        # Imported lazily so that mmcv's NPU support is only required
        # when an NPU is actually requested.
        from mmcv.device.npu import NPUDistributedDataParallel
        torch.npu.set_compile_mode(jit_compile=False)
        ddp_factory['npu'] = NPUDistributedDataParallel
        model = model.npu()

    wrapper_cls = ddp_factory[device]
    return wrapper_cls(model, *args, **kwargs)


def is_npu_available():
    """Report whether an NPU can be used through ``torch.npu``.

    Returns:
        ``False`` when ``torch`` has no ``npu`` attribute; otherwise the
        result of ``torch.npu.is_available()``.
    """
    if not hasattr(torch, 'npu'):
        return False
    return torch.npu.is_available()


def get_device():
    """Pick the device type to run on: 'npu', 'cuda' or 'cpu'.

    Both back-ends are probed. NPU takes precedence over CUDA when both
    report as available, and 'cpu' is returned when neither does.

    Returns:
        str: 'npu', 'cuda' or 'cpu'.
    """
    npu_ready = is_npu_available()
    cuda_ready = torch.cuda.is_available()

    if npu_ready:
        return 'npu'
    if cuda_ready:
        return 'cuda'
    return 'cpu'
21 changes: 16 additions & 5 deletions tools/test.py
Expand Up @@ -8,13 +8,13 @@
import torch
from mmcv import Config, DictAction
from mmcv.cnn import fuse_conv_bn
from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
wrap_fp16_model)
from mmdet.apis import set_random_seed

from mmtrack.core import setup_multi_processes
from mmtrack.datasets import build_dataset
from mmtrack.utils import build_ddp, build_dp, get_device


def parse_args():
Expand Down Expand Up @@ -156,6 +156,9 @@ def main():
dist=distributed,
shuffle=False)

cfg.device = get_device() if cfg.get('device',
None) is None else cfg.device

# build the model and load checkpoint
if cfg.get('test_cfg', False):
model = build_model(
Expand All @@ -179,18 +182,26 @@ def main():
model = fuse_conv_bn(model)

if not distributed:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
model = build_dp(model, cfg.device, device_ids=cfg.gpu_ids)
outputs = single_gpu_test(
model,
data_loader,
args.show,
args.show_dir,
show_score_thr=args.show_score_thr)
else:
model = MMDistributedDataParallel(
model.cuda(),
device_ids=[torch.cuda.current_device()],
model = build_ddp(
model,
cfg.device,
device_ids=[int(os.environ['LOCAL_RANK'])],
broadcast_buffers=False)

# In multi_gpu_test, if tmpdir is None, some tensors
# will init on cuda by default, and no device choice supported.
# Init a tmpdir to avoid error on npu here.
if cfg.device == 'npu' and args.tmpdir is None:
args.tmpdir = './npu_tmpdir'

outputs = multi_gpu_test(model, data_loader, args.tmpdir,
args.gpu_collect)

Expand Down
7 changes: 5 additions & 2 deletions tools/train.py
Expand Up @@ -17,7 +17,7 @@
from mmtrack.apis import init_random_seed
from mmtrack.core import setup_multi_processes
from mmtrack.datasets import build_dataset
from mmtrack.utils import collect_env, get_root_logger
from mmtrack.utils import collect_env, get_device, get_root_logger


def parse_args():
Expand Down Expand Up @@ -165,16 +165,19 @@ def main():

# set random seeds. Force setting fixed seed and deterministic=True in SOT
# configs
cfg.device = get_device() if cfg.get('device',
None) is None else cfg.device
if args.seed is not None:
cfg.seed = args.seed
elif cfg.get('seed', None) is None:
cfg.seed = init_random_seed()
cfg.seed = init_random_seed(device=cfg.device)
cfg.seed = cfg.seed + dist.get_rank() if args.diff_seed else cfg.seed

deterministic = True if args.deterministic else cfg.get(
'deterministic', False)
logger.info(f'Set random seed to {cfg.seed}, '
f'deterministic: {deterministic}')

set_random_seed(cfg.seed, deterministic=deterministic)
meta['seed'] = cfg.seed

Expand Down

0 comments on commit e79491e

Please sign in to comment.