Improve cpu training/testing #1506

Open · wants to merge 5 commits into base: 0.x
22 changes: 17 additions & 5 deletions mmaction/apis/train.py
@@ -16,7 +16,7 @@
 from ..datasets import build_dataloader, build_dataset
 from ..utils import (PreciseBNHook, build_ddp, build_dp, default_device,
                      get_root_logger)
-from .test import multi_gpu_test
+from .test import multi_gpu_test, single_gpu_test


 def init_random_seed(seed=None, device=default_device, distributed=True):
@@ -62,6 +62,7 @@ def train_model(model,
                 validate=False,
                 test=dict(test_best=False, test_last=False),
                 timestamp=None,
+                device='cuda',
                 meta=None):
     """Train model entry function.

@@ -131,8 +132,15 @@ def train_model(model,
                 broadcast_buffers=False,
                 find_unused_parameters=find_unused_parameters))
     else:
-        model = build_dp(
-            model, default_device, default_args=dict(device_ids=cfg.gpu_ids))
+        if device == 'cuda':
+            model = build_dp(
+                model,
+                default_device,
+                default_args=dict(device_ids=cfg.gpu_ids))
+        elif device == 'cpu':
+            model = model.cpu()
+        else:
+            raise ValueError(f'unsupported device name {device}.')

     # build runner
     optimizer = build_optimizer(model, cfg.optimizer)
@@ -284,8 +292,12 @@ def train_model(model,
         if ckpt is not None:
             runner.load_checkpoint(ckpt)

-        outputs = multi_gpu_test(runner.model, test_dataloader, tmpdir,
-                                 gpu_collect)
+        if distributed:
+            outputs = multi_gpu_test(runner.model, test_dataloader, tmpdir,
+                                     gpu_collect)
+        else:
+            outputs = single_gpu_test(model, test_dataloader)

         rank, _ = get_dist_info()
         if rank == 0:
             out = osp.join(cfg.work_dir, f'{name}_pred.pkl')
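For context, a minimal sketch of how the new `device` argument would be driven from user code. The config path, work_dir, and omitted options below are illustrative assumptions, not part of this PR:

```python
# Illustrative sketch only; paths and option values are assumptions.
from mmcv import Config
from mmaction.apis import train_model
from mmaction.datasets import build_dataset
from mmaction.models import build_model

cfg = Config.fromfile('configs/recognition/tsn/tsn_r50_1x1x3_100e_kinetics400_rgb.py')
cfg.work_dir = './work_dirs/tsn_cpu_demo'  # hypothetical output dir
cfg.gpu_ids = range(1)

model = build_model(cfg.model)
datasets = [build_dataset(cfg.data.train)]

# With this PR, device='cpu' skips build_dp and keeps the model on CPU;
# device='cuda' preserves the old DataParallel path.
train_model(model, datasets, cfg, distributed=False, device='cpu')
```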
3 changes: 2 additions & 1 deletion mmaction/models/losses/hvu_loss.py
@@ -111,7 +111,8 @@ def _forward(self, cls_score, label, mask, category_mask):
             # there should be at least one sample which contains tags
             # in this category
             if torch.sum(category_mask_i) < 0.5:
-                losses[f'{name}_LOSS'] = torch.tensor(.0).cuda()
+                losses[f'{name}_LOSS'] = torch.tensor(.0).to(
+                    category_loss.device)
                 loss_weights[f'{name}_LOSS'] = .0
                 continue
             category_loss = torch.sum(category_loss * category_mask_i)
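The fix follows the usual device-agnostic pattern: derive the device of any new tensor from a tensor already in scope rather than hard-coding `.cuda()`. A standalone illustration (names are generic, not the loss's actual variables):

```python
import torch

def zero_loss_like(ref: torch.Tensor) -> torch.Tensor:
    # Create the placeholder zero on whatever device `ref` occupies,
    # so the same loss code runs on CPU-only and CUDA builds alike.
    return torch.tensor(0.0, device=ref.device)

per_sample_loss = torch.rand(4)  # stand-in for an unreduced BCE loss
assert zero_loss_like(per_sample_loss).device == per_sample_loss.device
```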
15 changes: 10 additions & 5 deletions mmaction/utils/multigrid/longshortcyclehook.py
@@ -10,23 +10,29 @@
 from mmaction.utils import get_root_logger


-def modify_subbn3d_num_splits(logger, module, num_splits):
+def modify_subbn3d_num_splits(logger, module, num_splits, device='cuda'):
     """Recursively modify the number of splits of subbn3ds in module.
     Inherits the running_mean and running_var from the last subbn.bn.

     Args:
         logger (:obj:`logging.Logger`): The logger to log information.
         module (nn.Module): The module to be modified.
         num_splits (int): The targeted number of splits.
+        device (str | :obj:`torch.device`): The desired device of returned
+            tensor. Default: 'cuda'.
     Returns:
         int: The number of subbn3d modules modified.
     """
     count = 0
     for child in module.children():
         from mmaction.models import SubBatchNorm3D
         if isinstance(child, SubBatchNorm3D):
-            new_split_bn = nn.BatchNorm3d(
-                child.num_features * num_splits, affine=False).cuda()
+            if device == 'cuda':
+                new_split_bn = nn.BatchNorm3d(
+                    child.num_features * num_splits, affine=False).cuda()
+            else:
+                new_split_bn = nn.BatchNorm3d(
+                    child.num_features * num_splits, affine=False).cpu()
             new_state_dict = new_split_bn.state_dict()

             for param_name, param in child.bn.state_dict().items():
@@ -125,8 +131,7 @@ def _update_long_cycle(self, runner):
             dist=True,
             num_gpus=len(self.cfg.gpu_ids),
             drop_last=True,
-            seed=self.cfg.get('seed', None),
-        )
+            seed=self.cfg.get('seed', None))
         runner.data_loader = dataloader
         self.logger.info('Rebuild runner.data_loader')

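Around the changed lines, the hook builds a fresh split BatchNorm and copies running statistics over from the old one. A simplified, self-contained sketch of that inheritance step, assuming stats are tiled once per split (the real hook's copying logic may differ in detail):

```python
import torch.nn as nn

num_features, num_splits = 8, 2
old_bn = nn.BatchNorm3d(num_features)
new_split_bn = nn.BatchNorm3d(num_features * num_splits, affine=False)

new_state = new_split_bn.state_dict()
for name, param in old_bn.state_dict().items():
    # Tile running_mean/running_var across splits so each split starts
    # from the accumulated statistics; affine params are skipped because
    # the non-affine split BN has no such keys.
    if name in new_state and new_state[name].shape != param.shape:
        new_state[name] = param.repeat(num_splits)
new_split_bn.load_state_dict(new_state)
```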
14 changes: 12 additions & 2 deletions tools/test.py
@@ -90,6 +90,11 @@ def parse_args():
         default='none',
         help='job launcher')
     parser.add_argument('--local_rank', type=int, default=0)
+    parser.add_argument(
+        '--device',
+        choices=['cpu', 'cuda'],
+        default='cuda',
+        help='device used for testing')
     parser.add_argument(
         '--onnx',
         action='store_true',
@@ -157,8 +162,13 @@ def inference_pytorch(args, cfg, distributed, data_loader):
         model = fuse_conv_bn(model)

     if not distributed:
-        model = build_dp(
-            model, default_device, default_args=dict(device_ids=cfg.gpu_ids))
+        if args.device == 'cpu':
+            model = model.cpu()
+        else:
+            model = build_dp(
+                model,
+                default_device,
+                default_args=dict(device_ids=cfg.gpu_ids))
         outputs = single_gpu_test(model, data_loader)
     else:
         model = build_ddp(
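With the new flag in place, CPU-only evaluation would presumably be launched as `python tools/test.py ${CONFIG_FILE} ${CHECKPOINT_FILE} --device cpu` (both paths are placeholders); leaving `--device` at its default of `cuda` keeps the existing `build_dp` behavior.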
20 changes: 17 additions & 3 deletions tools/train.py
@@ -41,6 +41,11 @@ def parse_args():
         help=('whether to test the best checkpoint (if applicable) after '
               'training'))
     group_gpus = parser.add_mutually_exclusive_group()
+    group_gpus.add_argument(
+        '--device',
+        choices=['cuda', 'cpu'],
+        default='cuda',
+        help='device used for training')
     group_gpus.add_argument(
         '--gpus',
         type=int,
@@ -119,10 +124,16 @@ def main():
             'Non-distributed training can only use 1 gpu now. We will '
             'use the 1st one in gpu_ids. ')
         cfg.gpu_ids = [args.gpu_ids[0]]
-    elif args.gpus is not None:
+    else:
         warnings.warn('Non-distributed training can only use 1 gpu now. ')
         cfg.gpu_ids = range(1)

+    if args.gpus is None and args.gpu_ids is None:
+        cfg.gpu_ids = range(1)
+
+    if args.device == 'cpu':
+        cfg.gpu_ids = range(1)
+
     # init distributed env first, since logger depends on the dist info.
     if args.launcher == 'none':
         distributed = False
@@ -163,8 +174,10 @@ def main():
     logger.info(f'Config: {cfg.pretty_text}')

     # set random seeds
-    seed = init_random_seed(args.seed, distributed=distributed)
-    seed = seed + dist.get_rank() if args.diff_seed else seed
+    seed = init_random_seed(
+        args.seed, device=args.device, distributed=distributed)
+    if distributed:
+        seed = seed + dist.get_rank() if args.diff_seed else seed
     logger.info(f'Set random seed to {seed}, '
                 f'deterministic: {args.deterministic}')
     set_random_seed(seed, deterministic=args.deterministic)
@@ -215,6 +228,7 @@ def main():
         validate=args.validate,
         test=test_option,
         timestamp=timestamp,
+        device='cpu' if args.device == 'cpu' else 'cuda',
         meta=meta)


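The guarded seed logic above exists because `dist.get_rank()` raises unless a process group has been initialized, which is exactly the case in CPU-only, non-distributed runs. A standalone sketch of the same pattern (function and variable names are illustrative, not the PR's code):

```python
import torch
import torch.distributed as dist

def resolve_seed(seed, distributed, diff_seed):
    # Fall back to a random seed when the user passes none.
    if seed is None:
        seed = torch.randint(0, 2**31, (1,)).item()
    # Offset by rank so each worker draws different augmentation
    # randomness; only touch dist.get_rank() when a process group exists.
    if distributed and diff_seed:
        seed = seed + dist.get_rank()
    return seed
```

Analogously to testing, CPU training would presumably be started with `python tools/train.py ${CONFIG_FILE} --device cpu` (the config path is a placeholder).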