Skip to content

Commit

Permalink
Update docstring
Browse files Browse the repository at this point in the history
  • Loading branch information
tonysy committed Dec 26, 2022
1 parent 05c2a54 commit a591787
Showing 1 changed file with 109 additions and 89 deletions.
198 changes: 109 additions & 89 deletions mmengine/analysis/complexity_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,60 +48,68 @@


class FlopAnalyzer(JitModelAnalysis):
"""Provides access to per-submodule model flop count obtained by
tracing a model with pytorch's jit tracing functionality. By default,
comes with standard flop counters for a few common operators.
Note that:
1. Flop is not a well-defined concept. We just produce our best
"""Provides access to per-submodule model flop count obtained by tracing a
model with pytorch's jit tracing functionality.
By default, comes with standard flop counters for a few common operators.
Note:
- Flop is not a well-defined concept. We just produce our best
estimate.
2. We count one fused multiply-add as one flop.
- We count one fused multiply-add as one flop.
Handles for additional operators may be added, or the default ones
overwritten, using the ``.set_op_handle(name, func)`` method.
See the method documentation for details.
Flop counts can be obtained as:
* ``.total(module_name="")``: total flop count for the module
* ``.by_operator(module_name="")``: flop counts for the module, as a
Counter over different operator types
* ``.by_module()``: Counter of flop counts for all submodules
* ``.by_module_and_operator()``: dictionary indexed by descendant of
Counters over different operator types
- ``.total(module_name="")``: total flop count for the module
- ``.by_operator(module_name="")``: flop counts for the module, as a
Counter over different operator types
- ``.by_module()``: Counter of flop counts for all submodules
- ``.by_module_and_operator()``: dictionary indexed by descendant of
Counters over different operator types
An operator is treated as within a module if it is executed inside the
module's ``__call__`` method. Note that this does not include calls to
other methods of the module or explicit calls to ``module.forward(...)``.
Modified from
https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/flop_count.py
Example usage:
>>> import torch.nn as nn
>>> import torch
>>> class TestModel(nn.Module):
... def __init__(self):
... super().__init__()
... self.fc = nn.Linear(in_features=1000, out_features=10)
... self.conv = nn.Conv2d(
... in_channels=3, out_channels=10, kernel_size=1
... )
... self.act = nn.ReLU()
... def forward(self, x):
... return self.fc(self.act(self.conv(x)).flatten(1))
>>> model = TestModel()
>>> inputs = (torch.randn((1,3,10,10)),)
>>> flops = FlopAnalyzer(model, inputs)
>>> flops.total()
13000
>>> flops.total("fc")
10000
>>> flops.by_operator()
Counter({"addmm" : 10000, "conv" : 3000})
>>> flops.by_module()
Counter({"" : 13000, "fc" : 10000, "conv" : 3000, "act" : 0})
>>> flops.by_module_and_operator()
{"" : Counter({"addmm" : 10000, "conv" : 3000}),
"fc" : Counter({"addmm" : 10000}),
"conv" : Counter({"conv" : 3000}),
"act" : Counter()
}
Args:
model (nn.Module): The model to analyze.
inputs (Union[Tensor, Tuple[Tensor, ...]]): The input to the model.
Examples:
>>> import torch.nn as nn
>>> import torch
>>> class TestModel(nn.Module):
... def __init__(self):
... super().__init__()
... self.fc = nn.Linear(in_features=1000, out_features=10)
... self.conv = nn.Conv2d(
... in_channels=3, out_channels=10, kernel_size=1
... )
... self.act = nn.ReLU()
... def forward(self, x):
... return self.fc(self.act(self.conv(x)).flatten(1))
>>> model = TestModel()
>>> inputs = (torch.randn((1,3,10,10)),)
>>> flops = FlopAnalyzer(model, inputs)
>>> flops.total()
13000
>>> flops.total("fc")
10000
>>> flops.by_operator()
Counter({"addmm" : 10000, "conv" : 3000})
>>> flops.by_module()
Counter({"" : 13000, "fc" : 10000, "conv" : 3000, "act" : 0})
>>> flops.by_module_and_operator()
{"" : Counter({"addmm" : 10000, "conv" : 3000}),
"fc" : Counter({"addmm" : 10000}),
"conv" : Counter({"conv" : 3000}),
"act" : Counter()
}
"""

def __init__(
Expand All @@ -117,54 +125,61 @@ def __init__(

class ActivationAnalyzer(JitModelAnalysis):
"""Provides access to per-submodule model activation count obtained by
tracing a model with pytorch's jit tracing functionality. By default, comes
with standard activation counters for convolutional and dot-product
operators. Handles for additional operators may be added, or the default
ones overwritten, using the ``.set_op_handle(name, func)`` method. See the
method documentation for details. Activation counts can be obtained as:
* ``.total(module_name="")``: total activation count for a module
* ``.by_operator(module_name="")``: activation counts for the module, as a
Counter over different operator types
* ``.by_module()``: Counter of activation counts for all submodules
* ``.by_module_and_operator()``: dictionary indexed by descendant of
Counters over different operator types
tracing a model with pytorch's jit tracing functionality.
By default, comes with standard activation counters for convolutional and
dot-product operators. Handles for additional operators may be added, or
the default ones overwritten, using the ``.set_op_handle(name, func)``
method. See the method documentation for details. Activation counts can be
obtained as:
- ``.total(module_name="")``: total activation count for a module
- ``.by_operator(module_name="")``: activation counts for the module,
as a Counter over different operator types
- ``.by_module()``: Counter of activation counts for all submodules
- ``.by_module_and_operator()``: dictionary indexed by descendant of
Counters over different operator types
An operator is treated as within a module if it is executed inside the
module's ``__call__`` method. Note that this does not include calls to
other methods of the module or explicit calls to ``module.forward(...)``.
Modified from
https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/activation_count.py
Example usage:
>>> import torch.nn as nn
>>> import torch
>>> class TestModel(nn.Module):
... def __init__(self):
... super().__init__()
... self.fc = nn.Linear(in_features=1000, out_features=10)
... self.conv = nn.Conv2d(
... in_channels=3, out_channels=10, kernel_size=1
... )
... self.act = nn.ReLU()
... def forward(self, x):
... return self.fc(self.act(self.conv(x)).flatten(1))
>>> model = TestModel()
>>> inputs = (torch.randn((1,3,10,10)),)
>>> acts = ActivationAnalyzer(model, inputs)
>>> acts.total()
1010
>>> acts.total("fc")
10
>>> acts.by_operator()
Counter({"conv" : 1000, "addmm" : 10})
>>> acts.by_module()
Counter({"" : 1010, "fc" : 10, "conv" : 1000, "act" : 0})
>>> acts.by_module_and_operator()
{"" : Counter({"conv" : 1000, "addmm" : 10}),
"fc" : Counter({"addmm" : 10}),
"conv" : Counter({"conv" : 1000}),
"act" : Counter()
}
Args:
model (nn.Module): The model to analyze.
inputs (Union[Tensor, Tuple[Tensor, ...]]): The input to the model.
Examples:
>>> import torch.nn as nn
>>> import torch
>>> class TestModel(nn.Module):
... def __init__(self):
... super().__init__()
... self.fc = nn.Linear(in_features=1000, out_features=10)
... self.conv = nn.Conv2d(
... in_channels=3, out_channels=10, kernel_size=1
... )
... self.act = nn.ReLU()
... def forward(self, x):
... return self.fc(self.act(self.conv(x)).flatten(1))
>>> model = TestModel()
>>> inputs = (torch.randn((1,3,10,10)),)
>>> acts = ActivationAnalyzer(model, inputs)
>>> acts.total()
1010
>>> acts.total("fc")
10
>>> acts.by_operator()
Counter({"conv" : 1000, "addmm" : 10})
>>> acts.by_module()
Counter({"" : 1010, "fc" : 10, "conv" : 1000, "act" : 0})
>>> acts.by_module_and_operator()
{"" : Counter({"conv" : 1000, "addmm" : 10}),
"fc" : Counter({"addmm" : 10}),
"conv" : Counter({"conv" : 1000}),
"act" : Counter()
}
"""

def __init__(
Expand Down Expand Up @@ -198,6 +213,7 @@ def flop_count(
convolution and matmul and einsum. The key is operator name and
the value is a function that takes (inputs, outputs) of the op.
We count one Multiply-Add as one FLOP.
Returns:
tuple[defaultdict, Counter]: A dictionary that records the number of
gflops for each operation and a Counter that records the number of
Expand Down Expand Up @@ -231,6 +247,7 @@ def activation_count(
handlers for extra ops, or overwrite the existing handlers for
convolution and matmul. The key is operator name and the value
is a function that takes (inputs, outputs) of the op.
Returns:
tuple[defaultdict, Counter]: A dictionary that records the number of
activation (mega) for each operation and a Counter that records the
Expand All @@ -253,12 +270,13 @@ def parameter_count(model: nn.Module) -> typing.DefaultDict[str, int]:
https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/parameter_count.py
Args:
model: a torch module
model (nn.Module): the model to count parameters.
Returns:
dict (str-> int): the key is either a parameter name or a module name.
The value is the number of elements in the parameter, or in all
parameters of the module. The key "" corresponds to the total
number of parameters of the model.
The value is the number of elements in the parameter, or in all
parameters of the module. The key "" corresponds to the total
number of parameters of the model.
"""
r = defaultdict(int) # type: typing.DefaultDict[str, int]
for name, prm in model.named_parameters():
Expand Down Expand Up @@ -311,10 +329,12 @@ def parameter_count_table(model: nn.Module, max_depth: int = 3) -> str:
| backbone.bottom_up.res4 | 7.1M |
| backbone.bottom_up.res5 | 14.9M |
| ...... | ..... |
Args:
model: a torch module
model (nn.Module): the model to count parameters.
max_depth (int): maximum depth to recursively print submodules or
parameters
Returns:
str: the table to be printed
"""
Expand Down

0 comments on commit a591787

Please sign in to comment.