Improve CPU training/testing
dreamerlin committed Mar 10, 2022
1 parent 748aca2 commit 6228ccd
Showing 5 changed files with 43 additions and 12 deletions.
18 changes: 14 additions & 4 deletions mmaction/apis/train.py
@@ -16,7 +16,7 @@
OmniSourceRunner)
from ..datasets import build_dataloader, build_dataset
from ..utils import PreciseBNHook, get_root_logger
from .test import multi_gpu_test
from .test import multi_gpu_test, single_gpu_test


def init_random_seed(seed=None, device='cuda', distributed=True):
@@ -62,6 +62,7 @@ def train_model(model,
validate=False,
test=dict(test_best=False, test_last=False),
timestamp=None,
device='gpu',
meta=None):
"""Train model entry function.
@@ -128,7 +129,12 @@ def train_model(model,
broadcast_buffers=False,
find_unused_parameters=find_unused_parameters)
else:
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
if device == 'cuda':
model = MMDataParallel(model, device_ids=cfg.gpu_ids)
elif device == 'cpu':
model = model.cpu()
else:
raise ValueError(f'unsupported device name {device}.')

# build runner
optimizer = build_optimizer(model, cfg.optimizer)
@@ -280,8 +286,12 @@ def train_model(model,
if ckpt is not None:
runner.load_checkpoint(ckpt)

outputs = multi_gpu_test(runner.model, test_dataloader, tmpdir,
gpu_collect)
if distributed:
outputs = multi_gpu_test(runner.model, test_dataloader, tmpdir,
gpu_collect)
else:
outputs = single_gpu_test(model, test_dataloader)

rank, _ = get_dist_info()
if rank == 0:
out = osp.join(cfg.work_dir, f'{name}_pred.pkl')
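
In short, the non-distributed path in train_model now dispatches on a device argument: for 'cuda' the model is still wrapped in MMDataParallel, for 'cpu' it is simply moved to host memory, and the post-training test step falls back to single_gpu_test instead of multi_gpu_test. A minimal sketch of that dispatch, assuming mmcv is installed; the helper name wrap_non_distributed is illustrative, not part of the commit:

from mmcv.parallel import MMDataParallel


def wrap_non_distributed(model, device='cuda', gpu_ids=(0, )):
    """Place the model for single-machine training, mirroring the new branch."""
    if device == 'cuda':
        # unchanged GPU path: data-parallel wrapper over the configured GPUs
        return MMDataParallel(model, device_ids=list(gpu_ids))
    if device == 'cpu':
        # new CPU path: no wrapper, just move the parameters to the CPU
        return model.cpu()
    raise ValueError(f'unsupported device name {device}.')

The same distributed flag later decides between multi_gpu_test and single_gpu_test when the test_last/test_best checkpoints are evaluated.
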
3 changes: 2 additions & 1 deletion mmaction/models/losses/hvu_loss.py
@@ -111,7 +111,8 @@ def _forward(self, cls_score, label, mask, category_mask):
# there should be at least one sample which contains tags
# in this category
if torch.sum(category_mask_i) < 0.5:
losses[f'{name}_LOSS'] = torch.tensor(.0).cuda()
losses[f'{name}_LOSS'] = torch.tensor(.0).to(
category_loss.device)
loss_weights[f'{name}_LOSS'] = .0
continue
category_loss = torch.sum(category_loss * category_mask_i)
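
The single change here makes the placeholder zero loss device-agnostic: instead of a hard-coded .cuda(), the tensor now follows category_loss, so the loss also works on CPU-only machines. A tiny illustration with standalone tensors (not the loss class itself):

import torch

category_loss = torch.zeros(3)  # on CPU here, on a CUDA device during GPU training
zero_loss = torch.tensor(.0).to(category_loss.device)  # always lands next to category_loss
# torch.tensor(.0).cuda()  # the old code: raises on machines without CUDA
assert zero_loss.device == category_loss.device
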
15 changes: 10 additions & 5 deletions mmaction/utils/multigrid/longshortcyclehook.py
@@ -10,23 +10,29 @@
from mmaction.utils import get_root_logger


def modify_subbn3d_num_splits(logger, module, num_splits):
def modify_subbn3d_num_splits(logger, module, num_splits, device='cuda'):
"""Recursively modify the number of splits of subbn3ds in module.
Inherits the running_mean and running_var from last subbn.bn.
Args:
logger (:obj:`logging.Logger`): The logger to log information.
module (nn.Module): The module to be modified.
num_splits (int): The targeted number of splits.
device (str | :obj:`torch.device`): The desired device of returned
tensor. Default: 'cuda'.
Returns:
int: The number of subbn3d modules modified.
"""
count = 0
for child in module.children():
from mmaction.models import SubBatchNorm3D
if isinstance(child, SubBatchNorm3D):
new_split_bn = nn.BatchNorm3d(
child.num_features * num_splits, affine=False).cuda()
if device == 'cuda':
new_split_bn = nn.BatchNorm3d(
child.num_features * num_splits, affine=False).cuda()
else:
new_split_bn = nn.BatchNorm3d(
child.num_features * num_splits, affine=False).cpu()
new_state_dict = new_split_bn.state_dict()

for param_name, param in child.bn.state_dict().items():
@@ -125,8 +131,7 @@ def _update_long_cycle(self, runner):
dist=True,
num_gpus=len(self.cfg.gpu_ids),
drop_last=True,
seed=self.cfg.get('seed', None),
)
seed=self.cfg.get('seed', None))
runner.data_loader = dataloader
self.logger.info('Rebuild runner.data_loader')

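modify_subbn3d_num_splits now takes the target device instead of unconditionally calling .cuda() on the replacement BatchNorm3d. A condensed sketch of just that branch, with the SubBatchNorm3D state-dict bookkeeping omitted; the helper name make_split_bn is illustrative:

import torch.nn as nn


def make_split_bn(num_features, num_splits, device='cuda'):
    """Build the aggregated BN that replaces a SubBatchNorm3D's split BN."""
    bn = nn.BatchNorm3d(num_features * num_splits, affine=False)
    # .to(device) behaves like the explicit .cuda()/.cpu() branches in the diff
    return bn.to(device)

Using .to(device) would collapse the two branches; the commit keeps them spelled out, which is equivalent for the two supported values.
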
10 changes: 9 additions & 1 deletion tools/test.py
@@ -90,6 +90,11 @@ def parse_args():
default='none',
help='job launcher')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
'--device',
choices=['cpu', 'cuda'],
default='cuda',
help='device used for testing')
parser.add_argument(
'--onnx',
action='store_true',
@@ -157,7 +162,10 @@ def inference_pytorch(args, cfg, distributed, data_loader):
model = fuse_conv_bn(model)

if not distributed:
model = MMDataParallel(model, device_ids=[0])
if args.device == 'cpu':
model = model.cpu()
else:
model = MMDataParallel(model, device_ids=[0])
outputs = single_gpu_test(model, data_loader)
else:
model = MMDistributedDataParallel(
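
The new --device flag keeps 'cuda' as its default, so existing GPU workflows are untouched and CPU-only testing is opt-in. A toy argparse check of that behaviour (a stand-in parser, not the full tools/test.py one):

import argparse

parser = argparse.ArgumentParser(description='--device demo')
parser.add_argument(
    '--device', choices=['cpu', 'cuda'], default='cuda',
    help='device used for testing')

print(parser.parse_args([]).device)                    # cuda, the unchanged default
print(parser.parse_args(['--device', 'cpu']).device)   # cpu
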
9 changes: 8 additions & 1 deletion tools/train.py
@@ -40,6 +40,11 @@ def parse_args():
help=('whether to test the best checkpoint (if applicable) after '
'training'))
group_gpus = parser.add_mutually_exclusive_group()
group_gpus.add_argument(
'--device',
choices=['cuda', 'cpu'],
default='cuda',
help='device used for training')
group_gpus.add_argument(
'--gpus',
type=int,
@@ -158,7 +163,8 @@ def main():
logger.info(f'Config: {cfg.pretty_text}')

# set random seeds
seed = init_random_seed(args.seed, distributed=distributed)
seed = init_random_seed(
args.seed, device=args.device, distributed=distributed)
logger.info(f'Set random seed to {seed}, '
f'deterministic: {args.deterministic}')
set_random_seed(seed, deterministic=args.deterministic)
@@ -209,6 +215,7 @@ def main():
validate=args.validate,
test=test_option,
timestamp=timestamp,
device='cpu' if args.device == 'cpu' else 'cuda',
meta=meta)


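tools/train.py now forwards the chosen device both to init_random_seed and to train_model; since --device is added to the same mutually exclusive group as --gpus, the two flags cannot be combined. The body of init_random_seed is not part of this diff, but the usual reason the device matters is that, under distributed training, the seed picked by rank 0 is broadcast to the other ranks through a tensor that must live on a device the backend supports (CUDA for NCCL, CPU for Gloo). A plausible sketch of that pattern, not the repository's actual implementation:

import numpy as np
import torch
import torch.distributed as dist


def init_random_seed_sketch(seed=None, device='cuda', distributed=True):
    """Pick a seed on rank 0 and share the same value with every rank."""
    if seed is not None:
        return seed
    seed = np.random.randint(2**31)
    if not distributed:
        return seed
    # the broadcast buffer must be created on the requested device
    buf = torch.tensor(seed if dist.get_rank() == 0 else 0,
                       dtype=torch.int32, device=device)
    dist.broadcast(buf, src=0)
    return int(buf.item())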
