[Feature] Support model complexity computation #779

Merged
merged 52 commits into from
Feb 20, 2023
Changes from 49 commits
Commits (52)
2bfe5b9
[Feature] Add support model complexity computation
tonysy Nov 30, 2022
a8528a1
[Fix] fix lint error
tonysy Dec 1, 2022
57c1c67
[Feature] update print_helper
tonysy Dec 26, 2022
461fc74
Update docstring
tonysy Dec 26, 2022
fbb4193
update api, docs, fix lint
tonysy Feb 1, 2023
473b881
fix lint
tonysy Feb 1, 2023
2c29faa
update doc and add test
tonysy Feb 14, 2023
a0635ba
update docstring
tonysy Feb 14, 2023
86f5b8a
update docstring
tonysy Feb 14, 2023
330d3c7
update test
tonysy Feb 14, 2023
6e38a2a
Update mmengine/analysis/print_helper.py
tonysy Feb 15, 2023
5c8ebd0
Update mmengine/analysis/print_helper.py
tonysy Feb 15, 2023
57a310c
Update mmengine/analysis/print_helper.py
tonysy Feb 15, 2023
96fd397
Update mmengine/analysis/print_helper.py
tonysy Feb 15, 2023
e88b611
Update mmengine/analysis/print_helper.py
tonysy Feb 15, 2023
05b4ea4
Update mmengine/analysis/print_helper.py
tonysy Feb 15, 2023
30971b7
Update mmengine/analysis/print_helper.py
tonysy Feb 15, 2023
acb135b
Update mmengine/analysis/print_helper.py
tonysy Feb 15, 2023
e8bfc87
Update mmengine/analysis/print_helper.py
tonysy Feb 15, 2023
02e657b
Update mmengine/analysis/print_helper.py
tonysy Feb 15, 2023
a143671
Update mmengine/analysis/print_helper.py
tonysy Feb 15, 2023
41c2960
Update mmengine/analysis/print_helper.py
tonysy Feb 15, 2023
e6377c7
Update mmengine/analysis/print_helper.py
tonysy Feb 15, 2023
6326a15
Update mmengine/analysis/print_helper.py
tonysy Feb 15, 2023
5f0db43
Update mmengine/analysis/complexity_analysis.py
tonysy Feb 15, 2023
2c0318a
Update docs/en/advanced_tutorials/model_analysis.md
tonysy Feb 15, 2023
c1fcca7
Update docs/en/advanced_tutorials/model_analysis.md
tonysy Feb 15, 2023
10cbf5d
update docs
tonysy Feb 15, 2023
9c6e6fb
update docs
tonysy Feb 15, 2023
9eb7c12
update docs and docstring
tonysy Feb 15, 2023
20831ff
update docs
tonysy Feb 15, 2023
e7f97e8
update test with mmlogger
tonysy Feb 16, 2023
7be6d0f
Update mmengine/analysis/complexity_analysis.py
tonysy Feb 17, 2023
9cccbfb
Update tests/test_analysis/test_activation_count.py
tonysy Feb 17, 2023
0b0136f
Apply suggestions from code review
tonysy Feb 17, 2023
36c6154
update test according to review
tonysy Feb 17, 2023
4fac94a
Apply suggestions from code review
tonysy Feb 17, 2023
1c08a64
fix lint
tonysy Feb 17, 2023
a995c76
fix test
tonysy Feb 17, 2023
630aceb
Apply suggestions from code review
zhouzaida Feb 17, 2023
90b233b
fix API document
zhouzaida Feb 17, 2023
068d14a
Update analysis.rst
zhouzaida Feb 17, 2023
9491f41
rename variables
zhouzaida Feb 17, 2023
16e70e5
minor refinement
zhouzaida Feb 17, 2023
9a2f0d7
Apply suggestions from code review
zhouzaida Feb 19, 2023
42f5ff8
fix lint
zhouzaida Feb 19, 2023
c4cd248
replace tabulate with existing rich
zhouzaida Feb 20, 2023
b7aca58
Apply suggestions from code review
zhouzaida Feb 20, 2023
01e50ca
indent
zhouzaida Feb 20, 2023
a9998f7
Update mmengine/analysis/complexity_analysis.py
zhouzaida Feb 20, 2023
6812b62
Update mmengine/analysis/complexity_analysis.py
zhouzaida Feb 20, 2023
862673a
Update mmengine/analysis/complexity_analysis.py
zhouzaida Feb 20, 2023
171 changes: 171 additions & 0 deletions docs/en/advanced_tutorials/model_analysis.md
@@ -0,0 +1,171 @@
# Model Complexity Analysis

We provide a tool for analyzing the complexity of a network. The tool borrows from the implementation of [fvcore](https://github.com/facebookresearch/fvcore), and we plan to support more custom operators in the future. Currently, it provides interfaces to compute the parameters, activations, and FLOPs of a given model, and it can print the related information layer by layer, either as a table or following the network structure. The analysis provides both operator-level and module-level FLOP counts. Please refer to [Flop Count](https://github.com/facebookresearch/fvcore/blob/main/docs/flop_count.md) if you are interested in the details of how the FLOPs of an operator are measured accurately.

## What's FLOPs

FLOP is not a well-defined metric in complexity analysis, so we follow [detectron2](https://detectron2.readthedocs.io/en/latest/modules/fvcore.html#fvcore.nn.FlopCountAnalysis) and count one fused multiply-add as one FLOP.
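Under this convention, the FLOP count of a fully connected layer is simply its multiply-add (MAC) count. A minimal sketch in plain Python:

```python
# One fused multiply-add (MAC) counts as one FLOP under this convention.
# For a linear layer nn.Linear(in_features=10, out_features=10) applied to
# a single sample, each of the 10 outputs needs 10 multiply-adds:
in_features, out_features = 10, 10
flops = in_features * out_features
print(flops)  # 100
```

This matches the per-layer value of `100` reported for each `nn.Linear(10, 10)` in the example outputs below.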

## What's Activation

Activation measures the size of the feature map produced by a layer.

For example, given an input with shape `inputs = torch.randn((1, 3, 10, 10))` and a convolutional layer `conv = nn.Conv2d(in_channels=3, out_channels=10, kernel_size=1)`:

Feeding `inputs` into `conv` produces an `output` with shape `(1, 10, 10, 10)`, so the activation count of this `conv` layer is `1000 = 10 * 10 * 10`.
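In other words, the activation count is the number of elements in the layer's output tensor. A minimal sketch:

```python
from math import prod

# Activation count = number of elements in the output feature map.
output_shape = (1, 10, 10, 10)  # output of the conv layer above
activations = prod(output_shape)
print(activations)  # 1000
```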

Let's start with the following examples.

## Usage Example 1: Model built with native nn.Module

### Code

```python
import torch
from torch import nn
from mmengine.analysis import get_model_complexity_info
# return a dict of analysis results, including:
# ['flops', 'flops_str', 'activations', 'activations_str',
#  'params', 'params_str', 'out_table', 'out_arch']


class InnerNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)

    def forward(self, x):
        return self.fc1(self.fc2(x))


class TestNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 10)
        self.fc2 = nn.Linear(10, 10)
        self.inner = InnerNet()

    def forward(self, x):
        return self.fc1(self.fc2(self.inner(x)))


input_shape = (1, 10)
model = TestNet()

analysis_results = get_model_complexity_info(model, input_shape)

print(analysis_results['out_table'])
print(analysis_results['out_arch'])

print("Model Flops:{}".format(analysis_results['flops_str']))
print("Model Parameters:{}".format(analysis_results['params_str']))
```

### Description of Results

The returned `analysis_results` is a dict that contains the following keys:

- `flops`: the total number of FLOPs, e.g., 10000
- `flops_str`: the FLOP count as a human-readable string, e.g., 1.0G, 100M
- `params`: the total number of parameters, e.g., 10000
- `params_str`: the parameter count as a human-readable string, e.g., 1.0G, 100M
- `activations`: the total number of activations, e.g., 10000
- `activations_str`: the activation count as a human-readable string, e.g., 1.0G, 100M
- `out_table`: the related information formatted as a table

```
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━┳━━━━━━━━━━━━━━┓
┃ module ┃ #parameters or shape ┃ #flops ┃ #activations ┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━╇━━━━━━━━━━━━━━┩
│ model │ 0.44K │ 0.4K │ 40 │
│ fc1 │ 0.11K │ 100 │ 10 │
│ fc1.weight │ (10, 10) │ │ │
│ fc1.bias │ (10,) │ │ │
│ fc2 │ 0.11K │ 100 │ 10 │
│ fc2.weight │ (10, 10) │ │ │
│ fc2.bias │ (10,) │ │ │
│ inner │ 0.22K │ 0.2K │ 20 │
│ inner.fc1 │ 0.11K │ 100 │ 10 │
│ inner.fc1.weight │ (10, 10) │ │ │
│ inner.fc1.bias │ (10,) │ │ │
│ inner.fc2 │ 0.11K │ 100 │ 10 │
│ inner.fc2.weight │ (10, 10) │ │ │
│ inner.fc2.bias │ (10,) │ │ │
└─────────────────────┴──────────────────────┴────────┴──────────────┘
```

- `out_arch`: the related information formatted following the network structure

```bash
TestNet(
#params: 0.44K, #flops: 0.4K, #acts: 40
(fc1): Linear(
in_features=10, out_features=10, bias=True
#params: 0.11K, #flops: 100, #acts: 10
)
(fc2): Linear(
in_features=10, out_features=10, bias=True
#params: 0.11K, #flops: 100, #acts: 10
)
(inner): InnerNet(
#params: 0.22K, #flops: 0.2K, #acts: 20
(fc1): Linear(
in_features=10, out_features=10, bias=True
#params: 0.11K, #flops: 100, #acts: 10
)
(fc2): Linear(
in_features=10, out_features=10, bias=True
#params: 0.11K, #flops: 100, #acts: 10
)
)
)
```
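The numbers in the table and architecture printouts above can be reproduced by hand. A small sanity check in plain Python:

```python
# Each nn.Linear(10, 10) layer in TestNet:
params_per_linear = 10 * 10 + 10   # weight + bias = 110 ("0.11K")
flops_per_linear = 10 * 10         # 100 MACs for a (1, 10) input
acts_per_linear = 10               # 10 output features

# TestNet contains four such layers (fc1, fc2, inner.fc1, inner.fc2):
print(4 * params_per_linear)  # 440 ("0.44K")
print(4 * flops_per_linear)   # 400 ("0.4K")
print(4 * acts_per_linear)    # 40
```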

## Usage Example 2: Model built with mmengine

### Code

```python
import torch.nn.functional as F
import torchvision
from mmengine.model import BaseModel
from mmengine.analysis import get_model_complexity_info


class MMResNet50(BaseModel):
    def __init__(self):
        super().__init__()
        self.resnet = torchvision.models.resnet50()

    def forward(self, imgs, labels=None, mode='tensor'):
        x = self.resnet(imgs)
        if mode == 'loss':
            return {'loss': F.cross_entropy(x, labels)}
        elif mode == 'predict':
            return x, labels
        elif mode == 'tensor':
            return x


input_shape = (3, 224, 224)
model = MMResNet50()

analysis_results = get_model_complexity_info(model, input_shape)

print("Model Flops:{}".format(analysis_results['flops_str']))
print("Model Parameters:{}".format(analysis_results['params_str']))
```

### Output

```bash
Model Flops:4.145G
Model Parameters:25.557M
```

## Interface

We provide more options to customize the output:

- `model`: (`nn.Module`) the model to be analyzed
- `input_shape`: (tuple) the shape of the input, e.g., `(3, 224, 224)`
- `inputs`: (optional, `torch.Tensor`) if given, `input_shape` will be ignored
- `show_table`: (bool) whether to return the statistics in the form of a table, default: `True`
- `show_arch`: (bool) whether to return the statistics in the form of the network structure, default: `True`
30 changes: 30 additions & 0 deletions docs/en/api/analysis.rst
@@ -0,0 +1,30 @@
.. role:: hidden
:class: hidden-section

mmengine.analysis
===================================

.. contents:: mmengine.analysis
:depth: 2
:local:
:backlinks: top

.. currentmodule:: mmengine.analysis

.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst

ActivationAnalyzer
FlopAnalyzer

.. autosummary::
:toctree: generated
:nosignatures:

activation_count
flop_count
parameter_count
parameter_count_table
get_model_complexity_info
2 changes: 2 additions & 0 deletions docs/en/index.rst
@@ -53,6 +53,7 @@ You can switch between Chinese and English documents in the lower-left corner of
advanced_tutorials/manager_mixin.md
advanced_tutorials/cross_library.md
advanced_tutorials/test_time_augmentation.md
advanced_tutorials/model_analysis.md

.. toctree::
:maxdepth: 1
@@ -79,6 +80,7 @@ You can switch between Chinese and English documents in the lower-left corner of
:maxdepth: 2
:caption: API Reference

mmengine.analysis <api/analysis>
mmengine.registry <api/registry>
mmengine.config <api/config>
mmengine.runner <api/runner>
30 changes: 30 additions & 0 deletions docs/zh_cn/api/analysis.rst
@@ -0,0 +1,30 @@
.. role:: hidden
:class: hidden-section

mmengine.analysis
===================================

.. contents:: mmengine.analysis
:depth: 2
:local:
:backlinks: top

.. currentmodule:: mmengine.analysis

.. autosummary::
:toctree: generated
:nosignatures:
:template: classtemplate.rst

ActivationAnalyzer
FlopAnalyzer

.. autosummary::
:toctree: generated
:nosignatures:

activation_count
flop_count
parameter_count
parameter_count_table
get_model_complexity_info
1 change: 1 addition & 0 deletions docs/zh_cn/index.rst
@@ -81,6 +81,7 @@
:maxdepth: 2
:caption: API 文档

mmengine.analysis <api/analysis>
mmengine.registry <api/registry>
mmengine.config <api/config>
mmengine.runner <api/runner>
10 changes: 10 additions & 0 deletions mmengine/analysis/__init__.py
@@ -0,0 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .complexity_analysis import (ActivationAnalyzer, FlopAnalyzer,
activation_count, flop_count,
parameter_count, parameter_count_table)
from .print_helper import get_model_complexity_info

__all__ = [
'FlopAnalyzer', 'ActivationAnalyzer', 'flop_count', 'activation_count',
'parameter_count', 'parameter_count_table', 'get_model_complexity_info'
]