-
Notifications
You must be signed in to change notification settings - Fork 323
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature] get_model_complexity_info() supports multiple inputs (#1065)
- Loading branch information
Showing
2 changed files
with
148 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
|
||
import pytest | ||
import torch | ||
import torch.nn as nn | ||
|
||
from mmengine.analysis.complexity_analysis import FlopAnalyzer, parameter_count | ||
from mmengine.analysis.print_helper import get_model_complexity_info | ||
from mmengine.utils import digit_version | ||
from mmengine.utils.dl_utils import TORCH_VERSION | ||
|
||
|
||
class NetAcceptOneTensor(nn.Module): | ||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
self.l1 = nn.Linear(in_features=5, out_features=6) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
out = self.l1(x) | ||
return out | ||
|
||
|
||
class NetAcceptTwoTensors(nn.Module): | ||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
self.l1 = nn.Linear(in_features=5, out_features=6) | ||
self.l2 = nn.Linear(in_features=7, out_features=6) | ||
|
||
def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor: | ||
out = self.l1(x1) + self.l2(x2) | ||
return out | ||
|
||
|
||
class NetAcceptOneTensorAndOneScalar(nn.Module): | ||
|
||
def __init__(self) -> None: | ||
super().__init__() | ||
self.l1 = nn.Linear(in_features=5, out_features=6) | ||
self.l2 = nn.Linear(in_features=5, out_features=6) | ||
|
||
def forward(self, x1: torch.Tensor, r) -> torch.Tensor: | ||
out = r * self.l1(x1) + (1 - r) * self.l2(x1) | ||
return out | ||
|
||
|
||
def test_get_model_complexity_info(): | ||
input1 = torch.randn(1, 9, 5) | ||
input_shape1 = (9, 5) | ||
input2 = torch.randn(1, 9, 7) | ||
input_shape2 = (9, 7) | ||
scalar = 0.3 | ||
|
||
# test a network that accepts one tensor as input | ||
model = NetAcceptOneTensor() | ||
complexity_info = get_model_complexity_info(model=model, inputs=input1) | ||
flops = FlopAnalyzer(model=model, inputs=input1).total() | ||
params = parameter_count(model=model)[''] | ||
assert complexity_info['flops'] == flops | ||
assert complexity_info['params'] == params | ||
|
||
complexity_info = get_model_complexity_info( | ||
model=model, input_shape=input_shape1) | ||
flops = FlopAnalyzer( | ||
model=model, inputs=(torch.randn(1, *input_shape1), )).total() | ||
assert complexity_info['flops'] == flops | ||
|
||
# test a network that accepts two tensors as input | ||
model = NetAcceptTwoTensors() | ||
complexity_info = get_model_complexity_info( | ||
model=model, inputs=(input1, input2)) | ||
flops = FlopAnalyzer(model=model, inputs=(input1, input2)).total() | ||
params = parameter_count(model=model)[''] | ||
assert complexity_info['flops'] == flops | ||
assert complexity_info['params'] == params | ||
|
||
complexity_info = get_model_complexity_info( | ||
model=model, input_shape=(input_shape1, input_shape2)) | ||
inputs = (torch.randn(1, *input_shape1), torch.randn(1, *input_shape2)) | ||
flops = FlopAnalyzer(model=model, inputs=inputs).total() | ||
assert complexity_info['flops'] == flops | ||
|
||
# test a network that accepts one tensor and one scalar as input | ||
model = NetAcceptOneTensorAndOneScalar() | ||
# For pytorch<1.9, a scalar input is not acceptable for torch.jit, | ||
# wrap it to `torch.tensor`. See https://github.com/pytorch/pytorch/blob/cd9dd653e98534b5d3a9f2576df2feda40916f1d/torch/csrc/jit/python/python_arg_flatten.cpp#L90. # noqa: E501 | ||
scalar = torch.tensor([ | ||
scalar | ||
]) if digit_version(TORCH_VERSION) < digit_version('1.9.0') else scalar | ||
complexity_info = get_model_complexity_info( | ||
model=model, inputs=(input1, scalar)) | ||
flops = FlopAnalyzer(model=model, inputs=(input1, scalar)).total() | ||
params = parameter_count(model=model)[''] | ||
assert complexity_info['flops'] == flops | ||
assert complexity_info['params'] == params | ||
|
||
# `get_model_complexity_info()` should throw `ValueError` | ||
# when neithor `inputs` nor `input_shape` is specified | ||
with pytest.raises(ValueError, match='should be set'): | ||
get_model_complexity_info(model) | ||
|
||
# `get_model_complexity_info()` should throw `ValueError` | ||
# when both `inputs` and `input_shape` are specified | ||
model = NetAcceptOneTensor() | ||
with pytest.raises(ValueError, match='cannot be both set'): | ||
get_model_complexity_info( | ||
model, inputs=input1, input_shape=input_shape1) |