Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] input_shape of get_model_complexity_info() for multiple tensors #1065

Merged
merged 17 commits into from
Apr 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
46 changes: 40 additions & 6 deletions mmengine/analysis/print_helper.py
Expand Up @@ -12,6 +12,7 @@
from rich.table import Table
from torch import nn

from mmengine.utils import is_tuple_of
from .complexity_analysis import (ActivationAnalyzer, FlopAnalyzer,
parameter_count)

Expand Down Expand Up @@ -675,19 +676,38 @@ def complexity_stats_table(

def get_model_complexity_info(
model: nn.Module,
input_shape: Optional[tuple] = None,
inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], None] = None,
input_shape: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...],
None] = None,
inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], Tuple[Any, ...],
None] = None,
show_table: bool = True,
show_arch: bool = True,
):
"""Interface to get the complexity of a model.

The parameter `inputs` are fed to the forward method of model.
If `inputs` is not specified, the `input_shape` is required and
it will be used to construct the dummy input fed to model.
If the forward of model requires two or more inputs, the `inputs`
should be a tuple of tensor or the `input_shape` should be a tuple
of tuple which each element will be constructed into a dumpy input.

Examples:
>>> # the forward of model accepts only one input
>>> input_shape = (3, 224, 224)
>>> get_model_complexity_info(model, input_shape=input_shape)
>>> # the forward of model accepts two or more inputs
>>> input_shape = ((3, 224, 224), (3, 10))
>>> get_model_complexity_info(model, input_shape=input_shape)

Args:
model (nn.Module): The model to analyze.
input_shape (tuple, optional): The input shape of the model.
If inputs is not specified, the input_shape should be set.
input_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...]], None]):
The input shape of the model.
If "inputs" is not specified, the "input_shape" should be set.
Defaults to None.
inputs (torch.Tensor or tuple[torch.Tensor, ...], optional]):
inputs (torch.Tensor, tuple[torch.Tensor, ...] or Tuple[Any, ...],\
optional]):
The input tensor(s) of the model. If not given the input tensor
will be generated automatically with the given input_shape.
Defaults to None.
Expand All @@ -705,7 +725,21 @@ def get_model_complexity_info(
raise ValueError('"input_shape" and "inputs" cannot be both set.')

if inputs is None:
inputs = (torch.randn(1, *input_shape), )
if is_tuple_of(input_shape, int): # tuple of int, construct one tensor
inputs = (torch.randn(1, *input_shape), )
elif is_tuple_of(input_shape, tuple) and all([
is_tuple_of(one_input_shape, int)
for one_input_shape in input_shape # type: ignore
]): # tuple of tuple of int, construct multiple tensors
inputs = tuple([
torch.randn(1, *one_input_shape)
for one_input_shape in input_shape # type: ignore
])
else:
raise ValueError(
'"input_shape" should be either a `tuple of int` (to construct'
'one input tensor) or a `tuple of tuple of int` (to construct'
'multiple input tensors).')

flop_handler = FlopAnalyzer(model, inputs)
activation_handler = ActivationAnalyzer(model, inputs)
Expand Down
108 changes: 108 additions & 0 deletions tests/test_analysis/test_print_helper.py
@@ -0,0 +1,108 @@
# Copyright (c) OpenMMLab. All rights reserved.

import pytest
import torch
import torch.nn as nn

from mmengine.analysis.complexity_analysis import FlopAnalyzer, parameter_count
from mmengine.analysis.print_helper import get_model_complexity_info
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION


class NetAcceptOneTensor(nn.Module):

def __init__(self) -> None:
super().__init__()
self.l1 = nn.Linear(in_features=5, out_features=6)

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.l1(x)
return out


class NetAcceptTwoTensors(nn.Module):

def __init__(self) -> None:
super().__init__()
self.l1 = nn.Linear(in_features=5, out_features=6)
self.l2 = nn.Linear(in_features=7, out_features=6)

def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
out = self.l1(x1) + self.l2(x2)
return out


class NetAcceptOneTensorAndOneScalar(nn.Module):

def __init__(self) -> None:
super().__init__()
self.l1 = nn.Linear(in_features=5, out_features=6)
self.l2 = nn.Linear(in_features=5, out_features=6)

def forward(self, x1: torch.Tensor, r) -> torch.Tensor:
out = r * self.l1(x1) + (1 - r) * self.l2(x1)
return out


def test_get_model_complexity_info():
input1 = torch.randn(1, 9, 5)
input_shape1 = (9, 5)
input2 = torch.randn(1, 9, 7)
input_shape2 = (9, 7)
scalar = 0.3

# test a network that accepts one tensor as input
model = NetAcceptOneTensor()
complexity_info = get_model_complexity_info(model=model, inputs=input1)
flops = FlopAnalyzer(model=model, inputs=input1).total()
params = parameter_count(model=model)['']
assert complexity_info['flops'] == flops
assert complexity_info['params'] == params

complexity_info = get_model_complexity_info(
model=model, input_shape=input_shape1)
flops = FlopAnalyzer(
model=model, inputs=(torch.randn(1, *input_shape1), )).total()
assert complexity_info['flops'] == flops

# test a network that accepts two tensors as input
model = NetAcceptTwoTensors()
complexity_info = get_model_complexity_info(
model=model, inputs=(input1, input2))
flops = FlopAnalyzer(model=model, inputs=(input1, input2)).total()
params = parameter_count(model=model)['']
assert complexity_info['flops'] == flops
assert complexity_info['params'] == params

complexity_info = get_model_complexity_info(
model=model, input_shape=(input_shape1, input_shape2))
inputs = (torch.randn(1, *input_shape1), torch.randn(1, *input_shape2))
flops = FlopAnalyzer(model=model, inputs=inputs).total()
assert complexity_info['flops'] == flops

# test a network that accepts one tensor and one scalar as input
model = NetAcceptOneTensorAndOneScalar()
# For pytorch<1.9, a scalar input is not acceptable for torch.jit,
# wrap it to `torch.tensor`. See https://github.com/pytorch/pytorch/blob/cd9dd653e98534b5d3a9f2576df2feda40916f1d/torch/csrc/jit/python/python_arg_flatten.cpp#L90. # noqa: E501
scalar = torch.tensor([
scalar
]) if digit_version(TORCH_VERSION) < digit_version('1.9.0') else scalar
complexity_info = get_model_complexity_info(
model=model, inputs=(input1, scalar))
flops = FlopAnalyzer(model=model, inputs=(input1, scalar)).total()
params = parameter_count(model=model)['']
assert complexity_info['flops'] == flops
assert complexity_info['params'] == params

# `get_model_complexity_info()` should throw `ValueError`
# when neithor `inputs` nor `input_shape` is specified
with pytest.raises(ValueError, match='should be set'):
get_model_complexity_info(model)

# `get_model_complexity_info()` should throw `ValueError`
# when both `inputs` and `input_shape` are specified
model = NetAcceptOneTensor()
with pytest.raises(ValueError, match='cannot be both set'):
get_model_complexity_info(
model, inputs=input1, input_shape=input_shape1)