diff --git a/.dev_scripts/benchmark_valid_flops.py b/.dev_scripts/benchmark_valid_flops.py new file mode 100644 index 000000000..a02e9056a --- /dev/null +++ b/.dev_scripts/benchmark_valid_flops.py @@ -0,0 +1,282 @@ +import logging +import re +import tempfile +from argparse import ArgumentParser +from collections import OrderedDict +from functools import partial +from pathlib import Path + +import numpy as np +import pandas as pd +import torch +from mmengine import Config, DictAction +from mmengine.analysis import get_model_complexity_info +from mmengine.analysis.print_helper import _format_size +from mmengine.fileio import FileClient +from mmengine.logging import MMLogger +from mmengine.model import revert_sync_batchnorm +from mmengine.registry import init_default_scope +from mmengine.runner import Runner +from modelindex.load_model_index import load +from rich.console import Console +from rich.table import Table +from rich.text import Text +from tqdm import tqdm + +from mmocr.registry import MODELS + +console = Console() +MMOCR_ROOT = Path(__file__).absolute().parents[1] + + +def parse_args(): + parser = ArgumentParser(description='Valid all models in model-index.yml') + parser.add_argument( + '--shape', + type=int, + nargs='+', + default=[1280, 800], + help='input image size') + parser.add_argument( + '--checkpoint_root', + help='Checkpoint file root path. If set, load checkpoint before test.') + parser.add_argument('--img', default='demo/demo.jpg', help='Image file') + parser.add_argument('--models', nargs='+', help='models name to inference') + parser.add_argument( + '--batch-size', + type=int, + default=1, + help='The batch size during the inference.') + parser.add_argument( + '--flops', action='store_true', help='Get Flops and Params of models') + parser.add_argument( + '--flops-str', + action='store_true', + help='Output FLOPs and params counts in a string form.') + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') + parser.add_argument( + '--size_divisor', + type=int, + default=32, + help='Pad the input image, the minimum size that is divisible ' + 'by size_divisor, -1 means do not pad the image.') + args = parser.parse_args() + return args + + +def inference(config_file, checkpoint, work_dir, args, exp_name): + logger = MMLogger.get_instance(name='MMLogger') + logger.warning('if you want test flops, please make sure torch>=1.12') + cfg = Config.fromfile(config_file) + cfg.work_dir = work_dir + cfg.load_from = checkpoint + cfg.log_level = 'WARN' + cfg.experiment_name = exp_name + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + init_default_scope(cfg.get('default_scope', 'mmocr')) + + # forward the model + result = {'model': config_file.stem} + + if args.flops: + + if len(args.shape) == 1: + h = w = args.shape[0] + elif len(args.shape) == 2: + h, w = args.shape + else: + raise ValueError('invalid input shape') + divisor = args.size_divisor + if divisor > 0: + h = int(np.ceil(h / divisor)) * divisor + w = int(np.ceil(w / divisor)) * divisor + + input_shape = (3, h, w) + result['resolution'] = input_shape + + try: + cfg = Config.fromfile(config_file) + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + model = MODELS.build(cfg.model) + input = torch.rand(1, *input_shape) + if torch.cuda.is_available(): + model.cuda() + input = input.cuda() + model = revert_sync_batchnorm(model) + inputs = (input, ) + model.eval() + outputs = get_model_complexity_info( + model, input_shape, inputs, show_table=False, show_arch=False) + flops = outputs['flops'] + params = outputs['params'] + activations = outputs['activations'] + result['Get Types'] = 'Random input' + except: # noqa 772 + logger = MMLogger.get_instance(name='MMLogger') + logger.warning( + 'Direct get flops failed, try to get flops with data') + cfg = Config.fromfile(config_file) + data_loader = Runner.build_dataloader(cfg.val_dataloader) + data_batch = next(iter(data_loader)) + model = MODELS.build(cfg.model) + if torch.cuda.is_available(): + model = model.cuda() + model = revert_sync_batchnorm(model) + model.eval() + _forward = model.forward + data = model.data_preprocessor(data_batch) + del data_loader + model.forward = partial( + _forward, data_samples=data['data_samples']) + outputs = get_model_complexity_info( + model, + input_shape, + data['inputs'], + show_table=False, + show_arch=False) + flops = outputs['flops'] + params = outputs['params'] + activations = outputs['activations'] + result['Get Types'] = 'Dataloader' + + if args.flops_str: + flops = _format_size(flops) + params = _format_size(params) + activations = _format_size(activations) + + result['flops'] = flops + result['params'] = params + + return result + + +def show_summary(summary_data, args): + table = Table(title='Validation Benchmark Regression Summary') + table.add_column('Model') + table.add_column('Validation') + table.add_column('Resolution (c, h, w)') + if args.flops: + table.add_column('Flops', justify='right', width=11) + table.add_column('Params', justify='right') + + for model_name, summary in summary_data.items(): + row = [model_name] + valid = summary['valid'] + color = 'green' if valid == 'PASS' else 'red' + row.append(f'[{color}]{valid}[/{color}]') + if valid == 'PASS': + row.append(str(summary['resolution'])) + if args.flops: + row.append(str(summary['flops'])) + row.append(str(summary['params'])) + 
table.add_row(*row) + + console.print(table) + table_data = { + x.header: [Text.from_markup(y).plain for y in x.cells] + for x in table.columns + } + table_pd = pd.DataFrame(table_data) + table_pd.to_csv('./mmocr_flops.csv') + + +# Sample test whether the inference code is correct +def main(args): + model_index_file = MMOCR_ROOT / 'model-index.yml' + model_index = load(str(model_index_file)) + model_index.build_models_with_collections() + models = OrderedDict({model.name: model for model in model_index.models}) + + logger = MMLogger( + 'validation', + logger_name='validation', + log_file='benchmark_test_image.log', + log_level=logging.INFO) + + if args.models: + patterns = [ + re.compile(pattern.replace('+', '_')) for pattern in args.models + ] + filter_models = {} + for k, v in models.items(): + k = k.replace('+', '_') + if any([re.match(pattern, k) for pattern in patterns]): + filter_models[k] = v + if len(filter_models) == 0: + print('No model found, please specify models in:') + print('\n'.join(models.keys())) + return + models = filter_models + + summary_data = {} + tmpdir = tempfile.TemporaryDirectory() + for model_name, model_info in tqdm(models.items()): + + if model_info.config is None: + continue + + model_info.config = model_info.config.replace('%2B', '+') + config = Path(model_info.config) + + try: + config.exists() + except: # noqa 722 + logger.error(f'{model_name}: {config} not found.') + continue + + logger.info(f'Processing: {model_name}') + + http_prefix = 'https://download.openmmlab.com/mmocr/' + if args.checkpoint_root is not None: + root = args.checkpoint_root + if 's3://' in args.checkpoint_root: + from petrel_client.common.exception import AccessDeniedError + file_client = FileClient.infer_client(uri=root) + checkpoint = file_client.join_path( + root, model_info.weights[len(http_prefix):]) + try: + exists = file_client.exists(checkpoint) + except AccessDeniedError: + exists = False + else: + checkpoint = Path(root) / model_info.weights[len(http_prefix):] + exists = checkpoint.exists() + if exists: + checkpoint = str(checkpoint) + else: + print(f'WARNING: {model_name}: {checkpoint} not found.') + checkpoint = None + else: + checkpoint = None + + try: + # build the model from a config file and a checkpoint file + result = inference(MMOCR_ROOT / config, checkpoint, tmpdir.name, + args, model_name) + result['valid'] = 'PASS' + except Exception: # noqa 722 + import traceback + logger.error(f'"{config}" :\n{traceback.format_exc()}') + result = {'valid': 'FAIL'} + + summary_data[model_name] = result + + tmpdir.cleanup() + show_summary(summary_data, args) + + +if __name__ == '__main__': + args = parse_args() + main(args) diff --git a/docs/en/user_guides/useful_tools.md b/docs/en/user_guides/useful_tools.md index 9ba6a49dd..27c028704 100644 --- a/docs/en/user_guides/useful_tools.md +++ b/docs/en/user_guides/useful_tools.md @@ -29,7 +29,7 @@ python tools/visualizations/browse_dataset.py \ | -t, --task | `auto`, `textdet`, `textrecog` | Specify the task type of the dataset. If `auto`, the task type will be inferred from the config. If the script is unable to infer the task type, you need to specify it manually. Defaults to `auto`. | | -n, --show-number | int | The number of samples to visualized. If not specified, display all images in the dataset. | | -i, --show-interval | float | Interval of visualization (s), defaults to 2. | -| --cfg-options | float | Override configs. [Example](./config.md#command-line-modification) | +| --cfg-options | str | Override configs. 
[Example](./config.md#command-line-modification) |

 #### Examples

@@ -110,19 +110,15 @@ python tools/analysis_tools/offline_eval.py configs/textdet/psenet/psenet_r50_fp

 In addition, based on this tool, users can also convert predictions obtained from other libraries into MMOCR-supported formats, then use MMOCR's built-in metrics to evaluate them.

-| ARGS          | Type  | Description                                                        |
-| ------------- | ----- | ------------------------------------------------------------------ |
-| config        | str   | (required) Path to the config.                                     |
-| pkl_results   | str   | (required) The saved predictions.                                  |
-| --cfg-options | float | Override configs. [Example](./config.md#command-line-modification) |
+| ARGS          | Type | Description                                                        |
+| ------------- | ---- | ------------------------------------------------------------------ |
+| config        | str  | (required) Path to the config.                                     |
+| pkl_results   | str  | (required) The saved predictions.                                  |
+| --cfg-options | str  | Override configs. [Example](./config.md#command-line-modification) |

 ### Calculate FLOPs and the Number of Parameters

-We provide a method to calculate the FLOPs and the number of parameters, first we install the dependencies using the following command.
-
-```shell
-pip install fvcore
-```
+We provide a method to calculate the FLOPs and the number of parameters.

 The usage of the script to calculate FLOPs and the number of parameters is as follows.

@@ -130,10 +126,11 @@ The usage of the script to calculate FLOPs and the number of parameters is as fo
 python tools/analysis_tools/get_flops.py ${config} --shape ${IMAGE_SHAPE}
 ```

-| ARGS    | Type | Description                                                                                |
-| ------- | ---- | ------------------------------------------------------------------------------------------ |
-| config  | str  | (required) Path to the config.                                                             |
-| --shape | int  | Image size to use when calculating FLOPs, such as `--shape 320 320`. Default is `640 640` |
+| ARGS          | Type          | Description                                                                                                                                                                          |
+| ------------- | ------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
+| config        | str           | (required) Path to the config.                                                                                                                                                       |
+| --shape       | int * \[1-3\] | Image size to use when calculating FLOPs, such as `--shape 320 320`. It can accept 1 to 3 arguments, representing `H&W`, `H, W` and `C, H, W` respectively (C = 3 by default). Default is `640 640` |
+| --cfg-options | str           | Override configs.
[Example](./config.md#command-line-modification) | For example, you can run the following command to get FLOPs and the number of parameters of `dbnet_resnet18_fpnc_100k_synthtext.py`: @@ -144,51 +141,13 @@ python tools/analysis_tools/get_flops.py configs/textdet/dbnet/dbnet_resnet18_fp The output is as follows: ```shell -input shape is (1, 3, 1024, 1024) -| module | #parameters or shape | #flops | -| :------------------------ | :------------------- | :------ | -| model | 12.341M | 63.955G | -| backbone | 11.177M | 38.159G | -| backbone.conv1 | 9.408K | 2.466G | -| backbone.conv1.weight | (64, 3, 7, 7) | | -| backbone.bn1 | 0.128K | 83.886M | -| backbone.bn1.weight | (64,) | | -| backbone.bn1.bias | (64,) | | -| backbone.layer1 | 0.148M | 9.748G | -| backbone.layer1.0 | 73.984K | 4.874G | -| backbone.layer1.1 | 73.984K | 4.874G | -| backbone.layer2 | 0.526M | 8.642G | -| backbone.layer2.0 | 0.23M | 3.79G | -| backbone.layer2.1 | 0.295M | 4.853G | -| backbone.layer3 | 2.1M | 8.616G | -| backbone.layer3.0 | 0.919M | 3.774G | -| backbone.layer3.1 | 1.181M | 4.842G | -| backbone.layer4 | 8.394M | 8.603G | -| backbone.layer4.0 | 3.673M | 3.766G | -| backbone.layer4.1 | 4.721M | 4.837G | -| neck | 0.836M | 14.887G | -| neck.lateral_convs | 0.246M | 2.013G | -| neck.lateral_convs.0.conv | 16.384K | 1.074G | -| neck.lateral_convs.1.conv | 32.768K | 0.537G | -| neck.lateral_convs.2.conv | 65.536K | 0.268G | -| neck.lateral_convs.3.conv | 0.131M | 0.134G | -| neck.smooth_convs | 0.59M | 12.835G | -| neck.smooth_convs.0.conv | 0.147M | 9.664G | -| neck.smooth_convs.1.conv | 0.147M | 2.416G | -| neck.smooth_convs.2.conv | 0.147M | 0.604G | -| neck.smooth_convs.3.conv | 0.147M | 0.151G | -| det_head | 0.329M | 10.909G | -| det_head.binarize | 0.164M | 10.909G | -| det_head.binarize.0 | 0.147M | 9.664G | -| det_head.binarize.1 | 0.128K | 20.972M | -| det_head.binarize.3 | 16.448K | 1.074G | -| det_head.binarize.4 | 0.128K | 83.886M | -| det_head.binarize.6 | 0.257K | 67.109M | -| det_head.threshold | 0.164M | | -| det_head.threshold.0 | 0.147M | | -| det_head.threshold.1 | 0.128K | | -| det_head.threshold.3 | 16.448K | | -| det_head.threshold.4 | 0.128K | | -| det_head.threshold.6 | 0.257K | | + +============================== +Compute type: Random input +Input shape: torch.Size([1024, 1024]) +Flops: 63.737G +Params: 12.341M +============================== !!!Please be cautious if you use the results in papers. You may need to check if all ops are supported and verify that the flops computation is correct. 
+ ``` diff --git a/docs/zh_cn/user_guides/useful_tools.md b/docs/zh_cn/user_guides/useful_tools.md index f72322cef..c6b688b98 100644 --- a/docs/zh_cn/user_guides/useful_tools.md +++ b/docs/zh_cn/user_guides/useful_tools.md @@ -29,7 +29,7 @@ python tools/visualizations/browse_dataset.py \ | -t, --task | `auto`, `textdet`, `textrecog` | 用于指定可视化数据集的任务类型。`auto`:自动模式,将依据给定的配置文件自动选择合适的任务类型,如果无法自动获取任务类型,则需要用户手动指定为 `textdet` 文本检测任务 或 `textrecog` 文本识别任务。默认采用 `auto` 自动模式。 | | -n, --show-number | int | 指定需要可视化的样本数量。若该参数缺省则默认将可视化全部图片。 | | -i, --show-interval | float | 可视化图像间隔时间,默认为 2 秒。 | -| --cfg-options | float | 用于覆盖配置文件中的参数,详见[示例](./config.md#command-line-modification)。 | +| --cfg-options | str | 用于覆盖配置文件中的参数,详见[示例](./config.md#command-line-modification)。 | #### 用法示例 @@ -110,11 +110,11 @@ python tools/analysis_tools/offline_eval.py configs/textdet/psenet/psenet_r50_fp 此外,基于此工具,用户也可以将其他算法库获取的预测结果转换成 MMOCR 支持的格式,从而使用 MMOCR 内置的评估指标来对其他算法库的模型进行评测。 -| 参数 | 类型 | 说明 | -| ------------- | ----- | ---------------------------------------------------------------- | -| config | str | (必须)配置文件路径。 | -| pkl_results | str | (必须)预先保存的预测结果文件。 | -| --cfg-options | float | 用于覆写配置文件中的指定参数。[示例](./config.md#命令行修改配置) | +| 参数 | 类型 | 说明 | +| ------------- | ---- | ---------------------------------------------------------------- | +| config | str | (必须)配置文件路径。 | +| pkl_results | str | (必须)预先保存的预测结果文件。 | +| --cfg-options | str | 用于覆写配置文件中的指定参数。[示例](./config.md#命令行修改配置) | ### 计算 FLOPs 和参数量 @@ -130,10 +130,11 @@ pip install fvcore python tools/analysis_tools/get_flops.py ${config} --shape ${IMAGE_SHAPE} ``` -| 参数 | 类型 | 说明 | -| ------- | ------ | ------------------------------------------------------------------ | -| config | str | (必须) 配置文件路径。 | -| --shape | int\*2 | 计算 FLOPs 使用的图片尺寸,如 `--shape 320 320`。 默认为 `640 640` | +| 参数 | 类型 | 说明 | +| ------------- | ------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------- | +| config | str | (必须) 配置文件路径。 | +| --shape | int * \[1-3\] | 计算 FLOPs 使用的图片尺寸,例如 `--shape 320 320`。它可以接受 1 到 3 个参数,分别表示 `H&W`,`H, W` 和 `C, H, W`(C = 3 为默认值)。默认为 `640 640` | +| --cfg-options | str | 用于覆写配置文件中的指定参数。[示例](./config.md#命令行修改配置) | 获取 `dbnet_resnet18_fpnc_100k_synthtext.py` FLOPs 和参数量的示例命令如下。 @@ -144,51 +145,13 @@ python tools/analysis_tools/get_flops.py configs/textdet/dbnet/dbnet_resnet18_fp 输出如下: ```shell -input shape is (1, 3, 1024, 1024) -| module | #parameters or shape | #flops | -| :------------------------ | :------------------- | :------ | -| model | 12.341M | 63.955G | -| backbone | 11.177M | 38.159G | -| backbone.conv1 | 9.408K | 2.466G | -| backbone.conv1.weight | (64, 3, 7, 7) | | -| backbone.bn1 | 0.128K | 83.886M | -| backbone.bn1.weight | (64,) | | -| backbone.bn1.bias | (64,) | | -| backbone.layer1 | 0.148M | 9.748G | -| backbone.layer1.0 | 73.984K | 4.874G | -| backbone.layer1.1 | 73.984K | 4.874G | -| backbone.layer2 | 0.526M | 8.642G | -| backbone.layer2.0 | 0.23M | 3.79G | -| backbone.layer2.1 | 0.295M | 4.853G | -| backbone.layer3 | 2.1M | 8.616G | -| backbone.layer3.0 | 0.919M | 3.774G | -| backbone.layer3.1 | 1.181M | 4.842G | -| backbone.layer4 | 8.394M | 8.603G | -| backbone.layer4.0 | 3.673M | 3.766G | -| backbone.layer4.1 | 4.721M | 4.837G | -| neck | 0.836M | 14.887G | -| neck.lateral_convs | 0.246M | 2.013G | -| neck.lateral_convs.0.conv | 16.384K | 1.074G | -| neck.lateral_convs.1.conv | 32.768K | 0.537G | -| neck.lateral_convs.2.conv | 65.536K | 0.268G | -| 
neck.lateral_convs.3.conv | 0.131M | 0.134G | -| neck.smooth_convs | 0.59M | 12.835G | -| neck.smooth_convs.0.conv | 0.147M | 9.664G | -| neck.smooth_convs.1.conv | 0.147M | 2.416G | -| neck.smooth_convs.2.conv | 0.147M | 0.604G | -| neck.smooth_convs.3.conv | 0.147M | 0.151G | -| det_head | 0.329M | 10.909G | -| det_head.binarize | 0.164M | 10.909G | -| det_head.binarize.0 | 0.147M | 9.664G | -| det_head.binarize.1 | 0.128K | 20.972M | -| det_head.binarize.3 | 16.448K | 1.074G | -| det_head.binarize.4 | 0.128K | 83.886M | -| det_head.binarize.6 | 0.257K | 67.109M | -| det_head.threshold | 0.164M | | -| det_head.threshold.0 | 0.147M | | -| det_head.threshold.1 | 0.128K | | -| det_head.threshold.3 | 16.448K | | -| det_head.threshold.4 | 0.128K | | -| det_head.threshold.6 | 0.257K | | + +============================== +Compute type: Random input +Input shape: torch.Size([1024, 1024]) +Flops: 63.737G +Params: 12.341M +============================== !!!Please be cautious if you use the results in papers. You may need to check if all ops are supported and verify that the flops computation is correct. + ``` diff --git a/tools/analysis_tools/get_flops.py b/tools/analysis_tools/get_flops.py index caa97203a..43351bd6f 100644 --- a/tools/analysis_tools/get_flops.py +++ b/tools/analysis_tools/get_flops.py @@ -1,54 +1,153 @@ # Copyright (c) OpenMMLab. All rights reserved. import argparse +import tempfile +from functools import partial +from pathlib import Path import torch -from fvcore.nn import FlopCountAnalysis, flop_count_table -from mmengine import Config +from mmengine.config import Config, DictAction +from mmengine.logging import MMLogger +from mmengine.model import revert_sync_batchnorm from mmengine.registry import init_default_scope +from mmengine.runner import Runner from mmocr.registry import MODELS +try: + from mmengine.analysis import get_model_complexity_info + from mmengine.analysis.print_helper import _format_size +except ImportError: + raise ImportError('Please upgrade to mmengine>=0.6.0') + def parse_args(): - parser = argparse.ArgumentParser(description='Train a detector') + parser = argparse.ArgumentParser(description='Get model flops') parser.add_argument('config', help='train config file path') parser.add_argument( '--shape', type=int, nargs='+', default=[640, 640], - help='input image size') + help='Image size to use when calculating FLOPs, e.g. ' + '`--shape 320 320`. It can accept 1 to 3 arguments, representing' + ' `H&W`, `H, W` and `C, H, W` respectively (C = 3 by default).'), + parser.add_argument( + '--cfg-options', + nargs='+', + action=DictAction, + help='override some settings in the used config, the key-value pair ' + 'in xxx=yyy format will be merged into config file. If the value to ' + 'be overwritten is a list, it should be like key="[a,b]" or key=a,b ' + 'It also allows nested list/tuple values, e.g. 
key="[(a,b),(c,d)]" ' + 'Note that the quotation marks are necessary and that no white space ' + 'is allowed.') args = parser.parse_args() return args -def main(): +def inference(args, logger): - args = parse_args() + config_name = Path(args.config) + if not config_name.exists(): + logger.error(f'{config_name} not found.') + + cfg = Config.fromfile(args.config) + cfg.work_dir = tempfile.TemporaryDirectory().name + cfg.log_level = 'WARN' + if args.cfg_options is not None: + cfg.merge_from_dict(args.cfg_options) + + init_default_scope(cfg.get('default_scope', 'mmocr')) + c = 3 if len(args.shape) == 1: h = w = args.shape[0] elif len(args.shape) == 2: h, w = args.shape + elif len(args.shape) == 3: + c, h, w = args.shape else: - raise ValueError('invalid input shape, please use --shape h w') + raise ValueError('invalid input shape') + result = {} + + # Supports two ways to calculate flops, + # 1. randomly generate a picture + # 2. load a picture from the dataset + try: + model = MODELS.build(cfg.model) + if torch.cuda.is_available(): + model.cuda() + model = revert_sync_batchnorm(model) + data_batch = {'inputs': [torch.rand(c, h, w)], 'batch_samples': [None]} + data = model.data_preprocessor(data_batch) + result['ori_shape'] = (h, w) + result['pad_shape'] = data['inputs'].shape[-2:] + model.eval() + outputs = get_model_complexity_info( + model, + None, + inputs=data['inputs'], + show_table=False, + show_arch=False) + flops = outputs['flops'] + params = outputs['params'] + result['compute_type'] = 'Random input' + + except TypeError: + logger.warning( + 'Failed to directly get FLOPs, try to get flops with real data') + data_loader = Runner.build_dataloader(cfg.val_dataloader) + data_batch = next(iter(data_loader)) + model = MODELS.build(cfg.model) + if torch.cuda.is_available(): + model = model.cuda() + model = revert_sync_batchnorm(model) + model.eval() + _forward = model.forward + data = model.data_preprocessor(data_batch) + result['ori_shape'] = data['data_samples'][0].ori_shape + result['pad_shape'] = data['data_samples'][0].pad_shape + + del data_loader + model.forward = partial(_forward, data_samples=data['data_samples']) + outputs = get_model_complexity_info( + model, + None, + inputs=data['inputs'], + show_table=False, + show_arch=False) + flops = outputs['flops'] + params = outputs['params'] + result['compute_type'] = 'Dataloader' + + flops = _format_size(flops) + params = _format_size(params) + result['flops'] = flops + result['params'] = params + + return result - input_shape = (1, 3, h, w) - - cfg = Config.fromfile(args.config) - init_default_scope(cfg.get('default_scope', 'mmocr')) - model = MODELS.build(cfg.model) - - flops = FlopCountAnalysis(model, torch.ones(input_shape)) - - # params = parameter_count_table(model) - flops_data = flop_count_table(flops) - - print(flops_data) +def main(): + args = parse_args() + logger = MMLogger.get_instance(name='MMLogger') + result = inference(args, logger) + split_line = '=' * 30 + ori_shape = result['ori_shape'] + pad_shape = result['pad_shape'] + flops = result['flops'] + params = result['params'] + compute_type = result['compute_type'] + + if pad_shape != ori_shape: + print(f'{split_line}\nUse size divisor set input shape ' + f'from {ori_shape} to {pad_shape}') + print(f'{split_line}\nCompute type: {compute_type}\n' + f'Input shape: {pad_shape}\nFlops: {flops}\n' + f'Params: {params}\n{split_line}') print('!!!Please be cautious if you use the results in papers. 
' - 'You may need to check if all ops are supported and verify that the ' - 'flops computation is correct.') + 'You may need to check if all ops are supported and verify ' + 'that the flops computation is correct.') if __name__ == '__main__':
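Note for reviewers: the measurement above relies only on the public `mmengine.analysis` API that this patch switches to. A minimal, self-contained sketch of the same call pattern, assuming `mmengine>=0.6.0` and using a toy `torch.nn` model in place of a model built from an MMOCR config:

```python
# Standalone sketch of the mmengine-based FLOPs counting used in this PR.
# The toy ConvNet is illustrative only; the real tools build models from
# MMOCR configs via MODELS.build(cfg.model).
import torch
from torch import nn

from mmengine.analysis import get_model_complexity_info
from mmengine.analysis.print_helper import _format_size

model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.Conv2d(16, 1, 1),
)
model.eval()

# Same call pattern as in get_flops.py / benchmark_valid_flops.py: pass an
# input shape (or ready-made tensors via `inputs=`) and read the 'flops' /
# 'params' entries from the returned dict.
outputs = get_model_complexity_info(
    model, (3, 64, 64), show_table=False, show_arch=False)
print('Flops:', _format_size(outputs['flops']))
print('Params:', _format_size(outputs['params']))
```

When the direct call is not possible for a given model (the `TypeError` branch above), both new scripts fall back to a real batch drawn from `cfg.val_dataloader`, which is why the printed summary reports the compute type as either `Random input` or `Dataloader`.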