Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Support calculating FLOPs of segmentors #2706

Merged
merged 3 commits into from
Mar 10, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 86 additions & 22 deletions tools/analysis_tools/get_flops.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,23 @@
# Copyright (c) OpenMMLab. All rights reserved.
import argparse
import tempfile
from pathlib import Path

from mmcv.cnn import get_model_complexity_info
from mmengine import Config
import torch
from mmengine import Config, DictAction
from mmengine.logging import MMLogger
from mmengine.model import revert_sync_batchnorm
from mmengine.registry import init_default_scope

from mmseg.models import build_segmentor
from mmseg.models import BaseSegmentor
from mmseg.registry import MODELS
from mmseg.structures import SegDataSample

try:
from mmengine.analysis import get_model_complexity_info
from mmengine.analysis.print_helper import _format_size
except ImportError:
raise ImportError('Please upgrade mmengine >= 0.6.0 to use this script.')


def parse_args():
Expand All @@ -17,43 +30,94 @@ def parse_args():
nargs='+',
default=[2048, 1024],
help='input image size')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
args = parser.parse_args()
return args


def main():
def inference(args: argparse.Namespace, logger: MMLogger) -> dict:
config_name = Path(args.config)

args = parse_args()
if not config_name.exists():
logger.error(f'Config file {config_name} does not exist')

cfg: Config = Config.fromfile(config_name)
cfg.work_dir = tempfile.TemporaryDirectory().name
cfg.log_level = 'WARN'
if args.cfg_options is not None:
cfg.merge_from_dict(args.cfg_options)

init_default_scope(cfg.get('scope', 'mmseg'))

if len(args.shape) == 1:
input_shape = (3, args.shape[0], args.shape[0])
elif len(args.shape) == 2:
input_shape = (3, ) + tuple(args.shape)
else:
raise ValueError('invalid input shape')
result = {}

cfg = Config.fromfile(args.config)
cfg.model.pretrained = None
model = build_segmentor(
cfg.model,
train_cfg=cfg.get('train_cfg'),
test_cfg=cfg.get('test_cfg')).cuda()
model: BaseSegmentor = MODELS.build(cfg.model)
if hasattr(model, 'auxiliary_head'):
model.auxiliary_head = None
if torch.cuda.is_available():
model.cuda()
model = revert_sync_batchnorm(model)
result['ori_shape'] = input_shape[-2:]
result['pad_shape'] = input_shape[-2:]
data_batch = {
'inputs': [torch.rand(input_shape)],
'data_samples': [SegDataSample(metainfo=result)]
}
data = model.data_preprocessor(data_batch)
model.eval()
if cfg.model.decode_head.type in ['MaskFormerHead', 'Mask2FormerHead']:
# TODO: Support MaskFormer and Mask2Former
raise NotImplementedError('MaskFormer and Mask2Former are not '
'supported yet.')
outputs = get_model_complexity_info(
model,
input_shape,
inputs=data['inputs'],
show_table=False,
show_arch=False)
result['flops'] = _format_size(outputs['flops'])
result['params'] = _format_size(outputs['params'])
result['compute_type'] = 'direct: randomly generate a picture'
return result

if hasattr(model, 'forward_dummy'):
model.forward = model.forward_dummy
else:
raise NotImplementedError(
'FLOPs counter is currently not currently supported with {}'.
format(model.__class__.__name__))

flops, params = get_model_complexity_info(model, input_shape)
def main():

args = parse_args()
logger = MMLogger.get_instance(name='MMLogger')

result = inference(args, logger)
split_line = '=' * 30
print('{0}\nInput shape: {1}\nFlops: {2}\nParams: {3}\n{0}'.format(
split_line, input_shape, flops, params))
ori_shape = result['ori_shape']
pad_shape = result['pad_shape']
flops = result['flops']
params = result['params']
compute_type = result['compute_type']

if pad_shape != ori_shape:
print(f'{split_line}\nUse size divisor set input shape '
f'from {ori_shape} to {pad_shape}')
print(f'{split_line}\nCompute type: {compute_type}\n'
f'Input shape: {pad_shape}\nFlops: {flops}\n'
f'Params: {params}\n{split_line}')
print('!!!Please be cautious if you use the results in papers. '
'You may need to check if all ops are supported and verify that the '
'flops computation is correct.')
'You may need to check if all ops are supported and verify '
'that the flops computation is correct.')


if __name__ == '__main__':
Expand Down