
[Feature] Support model complexity computation #779

Merged
52 commits merged on Feb 20, 2023

Conversation

tonysy
Copy link
Collaborator

@tonysy tonysy commented Nov 30, 2022

Implementation of model complexity analysis.

The usage is similar to that of mmcv.cnn.get_model_complexity_info:

from mmengine.analysis import get_model_complexity_info

# return a dict of analysis results, including:
# ['flops', 'flops_str', 'activations', 'activations_str', 'params', 'params_str', 'out_table', 'out_str']
analysis_results = get_model_complexity_info(model, input_shape)

print(analysis_results['flops_str'])
print(analysis_results['params_str'])
  • Documentation
  • Unit testing
  • Cls Example
  • Det Example
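
For reference, a minimal self-contained sketch of the new API (the toy model below is illustrative only and is not part of this PR):

import torch.nn as nn

from mmengine.analysis import get_model_complexity_info

# a toy model for illustration; any nn.Module works
model = nn.Sequential(
    nn.Conv2d(3, 16, 3, padding=1),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(16, 10),
)
model.eval()

# input_shape is (C, H, W) without the batch dimension
analysis_results = get_model_complexity_info(model, (3, 224, 224))
print(analysis_results['flops_str'])   # human-readable FLOPs string
print(analysis_results['params_str'])  # human-readable parameter count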

@zhouzaida
Copy link
Member

Lint failed.

@vansin
Copy link
Collaborator

vansin commented Dec 12, 2022

Succeeds with a one-stage detector, but fails with a two-stage detector in mmdet 3.x. @ZwwWayne @RangiLyu

get_flops.py

# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import numpy as np
import torch
from mmengine.config import Config, DictAction

from mmdet.registry import MODELS
from mmdet.utils import register_all_modules

try:
    from mmengine.analysis import get_model_complexity_info
except ImportError:
    raise ImportError(
        'Please install a version of mmengine that provides '
        'mmengine.analysis')


def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[1280, 800],
        help='input image size')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--size-divisor',
        type=int,
        default=32,
        help='Pad the input image to the minimum size that is divisible '
        'by size_divisor; -1 means do not pad the image.')
    args = parser.parse_args()
    return args


def main():
    register_all_modules()
    args = parse_args()

    if len(args.shape) == 1:
        h = w = args.shape[0]
    elif len(args.shape) == 2:
        h, w = args.shape
    else:
        raise ValueError('invalid input shape')
    ori_shape = (3, h, w)
    divisor = args.size_divisor
    if divisor > 0:
        h = int(np.ceil(h / divisor)) * divisor
        w = int(np.ceil(w / divisor)) * divisor

    input_shape = (3, h, w)

    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)

    model = MODELS.build(cfg.model)
    # if torch.cuda.is_available():
    #     model.cuda()
    model.eval()

    flops, activations, params, complexity_table, complexity_str = \
        get_model_complexity_info(
            model, input_shape, show_table=True, show_str=True)
    split_line = '=' * 30

    if divisor > 0 and \
            input_shape != ori_shape:
        print(f'{split_line}\nInput shape padded by size divisor '
              f'from {ori_shape} to {input_shape}\n')
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')

    print(activations)
    print(complexity_table)
    # print(complexity_str)
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()
(mmroate) ➜  mmdetection git:(dev-3.x) ✗ python tools/analysis_tools/get_flops.py configs/yolo/yolov3_d53_8xb8-320-273e_coco.py 
12/12 13:20:29 - mmengine - WARNING - Unsupported operator aten::leaky_relu_ encountered 72 time(s)
12/12 13:20:29 - mmengine - WARNING - Unsupported operator aten::add encountered 23 time(s)
12/12 13:20:29 - mmengine - WARNING - The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
bbox_head.loss_cls, bbox_head.loss_conf, bbox_head.loss_wh, bbox_head.loss_xy, data_preprocessor
12/12 13:20:30 - mmengine - WARNING - Unsupported operator aten::batch_norm encountered 72 time(s)
12/12 13:20:30 - mmengine - WARNING - Unsupported operator aten::leaky_relu_ encountered 72 time(s)
12/12 13:20:30 - mmengine - WARNING - Unsupported operator aten::add encountered 23 time(s)
12/12 13:20:30 - mmengine - WARNING - Unsupported operator aten::upsample_nearest2d encountered 2 time(s)
12/12 13:20:30 - mmengine - WARNING - The following submodules of the model were never called during the trace of the graph. They may be unused, or they were accessed by direct calls to .forward() or via other python methods. In the latter case they will have zeros for statistics, though their statistics will still contribute to their parent calling module.
bbox_head.loss_cls, bbox_head.loss_conf, bbox_head.loss_wh, bbox_head.loss_xy, data_preprocessor
==============================
Input shape: (3, 1280, 800)
Flops: 0.195T
Params: 61.949M
==============================
0.232G

| module                           | #parameters or shape   | #flops     | #activations   |
|:---------------------------------|:-----------------------|:-----------|:---------------|
| model                            | 61.949M                | 0.195T     | 0.232G         |
|  backbone                        |  40.585M               |  0.145T    |  0.194G        |
|   backbone.conv1                 |   0.928K               |   0.95G    |   32.768M      |
|    backbone.conv1.conv           |    0.864K              |    0.885G  |    32.768M     |
|    backbone.conv1.bn             |    64                  |    65.536M |    0           |
|   backbone.conv_res_block1       |   39.232K              |   10.043G  |   40.96M       |
|    backbone.conv_res_block1.conv |    18.56K              |    4.751G  |    16.384M     |
|    backbone.conv_res_block1.res0 |    20.672K             |    5.292G  |    24.576M     |
|   backbone.conv_res_block2       |   0.239M               |   15.27G   |   32.768M      |
|    backbone.conv_res_block2.conv |    73.984K             |    4.735G  |    8.192M      |
|    backbone.conv_res_block2.res0 |    82.304K             |    5.267G  |    12.288M     |
|    backbone.conv_res_block2.res1 |    82.304K             |    5.267G  |    12.288M     |
|   backbone.conv_res_block3       |   2.923M               |   46.768G  |   53.248M      |
|    backbone.conv_res_block3.conv |    0.295M              |    4.727G  |    4.096M      |
|    backbone.conv_res_block3.res0 |    0.328M              |    5.255G  |    6.144M      |
|    backbone.conv_res_block3.res1 |    0.328M              |    5.255G  |    6.144M      |
|    backbone.conv_res_block3.res2 |    0.328M              |    5.255G  |    6.144M      |
|    backbone.conv_res_block3.res3 |    0.328M              |    5.255G  |    6.144M      |
|    backbone.conv_res_block3.res4 |    0.328M              |    5.255G  |    6.144M      |
|    backbone.conv_res_block3.res5 |    0.328M              |    5.255G  |    6.144M      |
|    backbone.conv_res_block3.res6 |    0.328M              |    5.255G  |    6.144M      |
|    backbone.conv_res_block3.res7 |    0.328M              |    5.255G  |    6.144M      |
|   backbone.conv_res_block4       |   11.679M              |   46.715G  |   26.624M      |
|    backbone.conv_res_block4.conv |    1.181M              |    4.723G  |    2.048M      |
|    backbone.conv_res_block4.res0 |    1.312M              |    5.249G  |    3.072M      |
|    backbone.conv_res_block4.res1 |    1.312M              |    5.249G  |    3.072M      |
|    backbone.conv_res_block4.res2 |    1.312M              |    5.249G  |    3.072M      |
|    backbone.conv_res_block4.res3 |    1.312M              |    5.249G  |    3.072M      |
|    backbone.conv_res_block4.res4 |    1.312M              |    5.249G  |    3.072M      |
|    backbone.conv_res_block4.res5 |    1.312M              |    5.249G  |    3.072M      |
|    backbone.conv_res_block4.res6 |    1.312M              |    5.249G  |    3.072M      |
|    backbone.conv_res_block4.res7 |    1.312M              |    5.249G  |    3.072M      |
|   backbone.conv_res_block5       |   25.704M              |   25.704G  |   7.168M       |
|    backbone.conv_res_block5.conv |    4.721M              |    4.721G  |    1.024M      |
|    backbone.conv_res_block5.res0 |    5.246M              |    5.246G  |    1.536M      |
|    backbone.conv_res_block5.res1 |    5.246M              |    5.246G  |    1.536M      |
|    backbone.conv_res_block5.res2 |    5.246M              |    5.246G  |    1.536M      |
|    backbone.conv_res_block5.res3 |    5.246M              |    5.246G  |    1.536M      |
|  neck                            |  14.71M                |  33.871G   |  25.856M       |
|   neck.detect1                   |   11.017M              |   11.017G  |   3.584M       |
|    neck.detect1.conv1            |    0.525M              |    0.525G  |    0.512M      |
|    neck.detect1.conv2            |    4.721M              |    4.721G  |    1.024M      |
|    neck.detect1.conv3            |    0.525M              |    0.525G  |    0.512M      |
|    neck.detect1.conv4            |    4.721M              |    4.721G  |    1.024M      |
|    neck.detect1.conv5            |    0.525M              |    0.525G  |    0.512M      |
|   neck.conv1                     |   0.132M               |   0.132G   |   0.256M       |
|    neck.conv1.conv               |    0.131M              |    0.131G  |    0.256M      |
|    neck.conv1.bn                 |    0.512K              |    0.512M  |    0           |
|   neck.detect2                   |   2.822M               |   11.287G  |   7.168M       |
|    neck.detect2.conv1            |    0.197M              |    0.788G  |    1.024M      |
|    neck.detect2.conv2            |    1.181M              |    4.723G  |    2.048M      |
|    neck.detect2.conv3            |    0.132M              |    0.526G  |    1.024M      |
|    neck.detect2.conv4            |    1.181M              |    4.723G  |    2.048M      |
|    neck.detect2.conv5            |    0.132M              |    0.526G  |    1.024M      |
|   neck.conv2                     |   33.024K              |   0.132G   |   0.512M       |
|    neck.conv2.conv               |    32.768K             |    0.131G  |    0.512M      |
|    neck.conv2.bn                 |    0.256K              |    1.024M  |    0           |
|   neck.detect3                   |   0.706M               |   11.301G  |   14.336M      |
|    neck.detect3.conv1            |    49.408K             |    0.791G  |    2.048M      |
|    neck.detect3.conv2            |    0.295M              |    4.727G  |    4.096M      |
|    neck.detect3.conv3            |    33.024K             |    0.528G  |    2.048M      |
|    neck.detect3.conv4            |    0.295M              |    4.727G  |    4.096M      |
|    neck.detect3.conv5            |    33.024K             |    0.528G  |    2.048M      |
|  bbox_head                       |  6.654M                |  15.998G   |  12.523M       |
|   bbox_head.convs_bridge         |   6.197M               |   14.17G   |   7.168M       |
|    bbox_head.convs_bridge.0      |    4.721M              |    4.721G  |    1.024M      |
|    bbox_head.convs_bridge.1      |    1.181M              |    4.723G  |    2.048M      |
|    bbox_head.convs_bridge.2      |    0.295M              |    4.727G  |    4.096M      |
|   bbox_head.convs_pred           |   0.458M               |   1.828G   |   5.355M       |
|    bbox_head.convs_pred.0        |    0.261M              |    0.261G  |    0.255M      |
|    bbox_head.convs_pred.1        |    0.131M              |    0.522G  |    1.02M       |
|    bbox_head.convs_pred.2        |    65.535K             |    1.044G  |    4.08M       |
!!!Please be cautious if you use the results in papers. You may need to check if all ops are supported and verify that the flops computation is correct.
(mmroate) ➜  mmdetection git:(dev-3.x) ✗ 
(mmroate) ➜  mmdetection git:(dev-3.x) ✗ python tools/analysis_tools/get_flops.py configs/faster_rcnn/faster-rcnn_r101_fpn_1x_coco.py      
Traceback (most recent call last):
  File "tools/analysis_tools/get_flops.py", line 92, in <module>
    main()
  File "tools/analysis_tools/get_flops.py", line 73, in main
    flops, activations, params, complexity_table, complexity_str = get_model_complexity_info(model, input_shape, show_table=True, show_str=True)
  File "/home/ubuntu/mmroate-1.x/mmengine/mmengine/analysis/print_helper.py", line 668, in get_model_complexity_info
    flops = _format_size(flop_handler.total())
  File "/home/ubuntu/mmroate-1.x/mmengine/mmengine/analysis/jit_analysis.py", line 259, in total
    stats = self._analyze()
  File "/home/ubuntu/mmroate-1.x/mmengine/mmengine/analysis/jit_analysis.py", line 550, in _analyze
    graph = _get_scoped_trace_graph(self._model, self._inputs,
  File "/home/ubuntu/mmroate-1.x/mmengine/mmengine/analysis/jit_analysis.py", line 189, in _get_scoped_trace_graph
    graph, _ = _get_trace_graph(module, inputs)
  File "/home/ubuntu/miniconda3/envs/mmroate/lib/python3.8/site-packages/torch/jit/_trace.py", line 1166, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/ubuntu/miniconda3/envs/mmroate/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/miniconda3/envs/mmroate/lib/python3.8/site-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/home/ubuntu/miniconda3/envs/mmroate/lib/python3.8/site-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/home/ubuntu/miniconda3/envs/mmroate/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1120, in _call_impl
    result = forward_call(*input, **kwargs)
  File "/home/ubuntu/miniconda3/envs/mmroate/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1090, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/ubuntu/mmroate-1.x/mmdetection/mmdet/models/detectors/base.py", line 96, in forward
    return self._forward(inputs, data_samples)
  File "/home/ubuntu/mmroate-1.x/mmdetection/mmdet/models/detectors/two_stage.py", line 131, in _forward
    rpn_results_list = self.rpn_head.predict(
  File "/home/ubuntu/mmroate-1.x/mmdetection/mmdet/models/dense_heads/base_dense_head.py", line 191, in predict
    batch_img_metas = [
TypeError: 'NoneType' object is not iterable
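
The trace fails because _forward() reaches rpn_head.predict() with batch_data_samples=None, so iterating over it in base_dense_head.predict raises the TypeError above. The mmdet example later in this thread works around this by building a real data batch and binding data_samples onto forward before tracing. A condensed sketch (keyword names follow the mmengine API added in this PR and should be treated as an assumption, not a final recipe):

from functools import partial

from mmengine.analysis import get_model_complexity_info
from mmengine.runner import Runner

# assumes `cfg` and `model` have already been built as in the script above
data_loader = Runner.build_dataloader(cfg.val_dataloader)
data_batch = next(iter(data_loader))
data = model.data_preprocessor(data_batch)

# bind real data_samples so rpn_head.predict() sees valid image metas during tracing
model.forward = partial(model.forward, data_samples=data['data_samples'])

analysis_results = get_model_complexity_info(model, inputs=data['inputs'])
print(analysis_results['flops_str'])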

@CLAassistant
Copy link

CLAassistant commented Dec 14, 2022

CLA assistant check
All committers have signed the CLA.

@HAOCHENYE HAOCHENYE added this to the 0.6.0 milestone Jan 12, 2023
@tonysy tonysy requested a review from C1rN09 as a code owner February 1, 2023 17:10
@tonysy
Copy link
Collaborator Author

tonysy commented Feb 1, 2023

Usage in mmcls

  • get_flops.py
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

from mmengine.analysis import get_model_complexity_info
from mmengine import Config

from mmcls.models import build_classifier


def parse_args():
    parser = argparse.ArgumentParser(description='Get model flops and params')
    parser.add_argument('config', help='config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[224, 224],
        help='input image size')
    args = parser.parse_args()
    return args


def main():

    args = parse_args()

    if len(args.shape) == 1:
        input_shape = (3, args.shape[0], args.shape[0])
    elif len(args.shape) == 2:
        input_shape = (3, ) + tuple(args.shape)
    else:
        raise ValueError('invalid input shape')

    cfg = Config.fromfile(args.config)
    model = build_classifier(cfg.model)
    model.eval()

    if hasattr(model, 'extract_feat'):
        model.forward = model.extract_feat
    else:
        raise NotImplementedError(
            'FLOPs counter is currently not supported with {}'.
            format(model.__class__.__name__))

    analysis_results = get_model_complexity_info(model, input_shape)


    flops = analysis_results['flops_str']
    params = analysis_results['params_str']
    activations = analysis_results['activations_str']

    out_table = analysis_results['out_table']
    out_arch = analysis_results['out_arch']
    print(out_table)
    print(out_arch)

    split_line = '=' * 30
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n'
          f'Activation: {activations}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()
  • CLI command example
    python tools/analysis_tools/get_flops.py configs/deit/deit-small_pt-4xb256_in1k.py

@tonysy
Copy link
Collaborator Author

tonysy commented Feb 1, 2023

Usage in mmdet

  • get_flops.py:
# Copyright (c) OpenMMLab. All rights reserved.
import argparse

import numpy as np
import torch
from mmengine.config import Config, DictAction

from mmdet.registry import MODELS
from functools import partial
from mmdet.utils import register_all_modules
from mmengine.runner import Runner
from mmengine.logging import MMLogger

from mmengine.analysis import get_model_complexity_info


def parse_args():
    parser = argparse.ArgumentParser(description='Train a detector')
    parser.add_argument('config', help='train config file path')
    parser.add_argument(
        '--shape',
        type=int,
        nargs='+',
        default=[1280, 800],
        help='input image size')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config, the key-value pair '
        'in xxx=yyy format will be merged into config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--size-divisor',
        type=int,
        default=32,
        help='Pad the input image to the minimum size that is divisible '
        'by size_divisor; -1 means do not pad the image.')
    args = parser.parse_args()
    return args


def main():
    register_all_modules()
    args = parse_args()

    if len(args.shape) == 1:
        h = w = args.shape[0]
    elif len(args.shape) == 2:
        h, w = args.shape
    else:
        raise ValueError('invalid input shape')
    ori_shape = (3, h, w)
    divisor = args.size_divisor
    if divisor > 0:
        h = int(np.ceil(h / divisor)) * divisor
        w = int(np.ceil(w / divisor)) * divisor

    input_shape = (3, h, w)

    try:
        cfg = Config.fromfile(args.config)
        if args.cfg_options is not None:
            cfg.merge_from_dict(args.cfg_options)

        model = MODELS.build(cfg.model)
        if torch.cuda.is_available():
            model.cuda()
        model.eval()
        
        analysis_results = get_model_complexity_info(model, input_shape)
        flops = analysis_results['flops_str']
        activations = analysis_results['activations_str']
        params = analysis_results['params_str']

    except Exception:
        logger = MMLogger.get_instance(name='MMLogger')
        logger.warning(
            'Failed to compute FLOPs directly, trying again with a data batch')
        cfg = Config.fromfile(args.config)
        data_loader = Runner.build_dataloader(cfg.val_dataloader)
        data_batch = next(iter(data_loader))
        model = MODELS.build(cfg.model)
        _forward = model.forward
        data = model.data_preprocessor(data_batch)

        model.forward = partial(_forward, data_samples=data['data_samples'])

        analysis_results = get_model_complexity_info(
            model, input_shape, data['inputs'])

        flops = analysis_results['flops_str']
        activations = analysis_results['activations_str']
        params = analysis_results['params_str']
        
    print(analysis_results['out_table'])
    print(analysis_results['out_arch'])

    split_line = '=' * 30

    if divisor > 0 and \
            input_shape != ori_shape:
        print(f'{split_line}\nInput shape padded by size divisor '
              f'from {ori_shape} to {input_shape}\n')
    print(f'{split_line}\nInput shape: {input_shape}\n'
          f'Flops: {flops}\nParams: {params}\n{split_line}')
    print('!!!Please be cautious if you use the results in papers. '
          'You may need to check if all ops are supported and verify that the '
          'flops computation is correct.')


if __name__ == '__main__':
    main()
  • CLI example:
    python tools/analysis_tools/get_flops.py configs/fcos/fcos_r101-caffe_fpn_gn-head-1x_coco.py
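
Running the same script on a two-stage config (the path from the earlier report) exercises the fallback branch that traces with a real data batch:
    python tools/analysis_tools/get_flops.py configs/faster_rcnn/faster-rcnn_r101_fpn_1x_coco.py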

@codecov
Copy link

codecov bot commented Feb 14, 2023

Codecov Report

❗ No coverage uploaded for pull request base (main@30fe410).
Patch has no changes to coverable lines.

❗ Current head 6f64061 differs from pull request most recent head 862673a. Consider uploading reports for the commit 862673a to get more accurate results

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #779   +/-   ##
=======================================
  Coverage        ?   76.82%           
=======================================
  Files           ?      138           
  Lines           ?    10791           
  Branches        ?     2154           
=======================================
  Hits            ?     8290           
  Misses          ?     2143           
  Partials        ?      358           
| Flag      | Coverage Δ        |
|:----------|:------------------|
| unittests | 76.82% <0.00%> (?) |

Flags with carried forward coverage won't be shown.


☔ View full report at Codecov.

@zhouzaida
Copy link
Member

mmengine/analysis/print_helper.py imports tabulate, but tabulate is not listed in runtime.txt. I suggest replacing tabulate with rich.

The unit tests did not fail only because tabulate gets installed as a dependency of openmim.
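
A minimal sketch of how the complexity table could be rendered with rich instead of tabulate (the column names mirror the current tabulate output; this is illustrative, not the final implementation):

from rich.console import Console
from rich.table import Table

table = Table(title='Model complexity')
table.add_column('module')
table.add_column('#parameters or shape', justify='right')
table.add_column('#flops', justify='right')
table.add_column('#activations', justify='right')

# rows would come from the per-module statistics collected during analysis
table.add_row('model', '61.949M', '0.195T', '0.232G')
table.add_row('  backbone', '40.585M', '0.145T', '0.194G')

Console().print(table)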

