[Feature] input_shape of get_model_complexity_info() for multiple tensors #1065

Merged · 17 commits · Apr 23, 2023
Changes from 10 commits
49 changes: 43 additions & 6 deletions mmengine/analysis/print_helper.py
@@ -12,6 +12,7 @@
from rich.table import Table
from torch import nn

from ..utils import is_tuple_of
from .complexity_analysis import (ActivationAnalyzer, FlopAnalyzer,
parameter_count)

@@ -675,19 +676,41 @@ def complexity_stats_table(

def get_model_complexity_info(
model: nn.Module,
input_shape: Optional[tuple] = None,
inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], None] = None,
input_shape: Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...],
None] = None,
inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], Tuple[Any, ...],
None] = None,
show_table: bool = True,
show_arch: bool = True,
):
"""Interface to get the complexity of a model.

The parameter `inputs` is fed to the forward method of the model.
If `inputs` is not specified, `input_shape` is required and
it will be used to construct a dummy input fed to the model.
If the forward method of the model requires two or more inputs, `inputs`
should be a tuple of tensors, or `input_shape` should be a tuple
of tuples in which each inner tuple is used to construct a dummy input.

Examples:
>>> # the forward method of the model accepts only one input
>>> input_shape = (3, 224, 224)
>>> # the following tuple containing one dummy tensor will be constructed
>>> inputs = (torch.randn(1, 3, 224, 224), )
>>>
>>> # the forward method of the model accepts two or more inputs
>>> input_shape = ((3, 224, 224), (3, 10))
>>> # the following tuple of dummy tensors will be constructed
>>> inputs = (torch.randn(1, 3, 224, 224), torch.randn(1, 3, 10))

Args:
model (nn.Module): The model to analyze.
input_shape (tuple, optional): The input shape of the model.
If inputs is not specified, the input_shape should be set.
input_shape (Union[Tuple[int, ...], Tuple[Tuple[int, ...], ...], None]):
The input shape of the model.
If "inputs" is not specified, "input_shape" should be set.
Defaults to None.
inputs (torch.Tensor or tuple[torch.Tensor, ...], optional]):
inputs (torch.Tensor, tuple[torch.Tensor, ...] or Tuple[Any, ...],\
optional):
The input tensor(s) of the model. If not given, the input tensor(s)
will be generated automatically with the given input_shape.
Defaults to None.
@@ -705,7 +728,21 @@ def get_model_complexity_info(
raise ValueError('"input_shape" and "inputs" cannot be both set.')

if inputs is None:
inputs = (torch.randn(1, *input_shape), )
if is_tuple_of(input_shape, int): # tuple of int, construct one tensor
inputs = (torch.randn(1, *input_shape), )
elif is_tuple_of(input_shape, tuple) and all([
is_tuple_of(one_input_shape, int)
for one_input_shape in input_shape
]): # tuple of tuple of int, construct multiple tensors
inputs = tuple([
torch.randn(1, *one_input_shape)
for one_input_shape in input_shape
])
else:
raise ValueError(
'"input_shape" should be either a `tuple of int` (to construct '
'one input tensor) or a `tuple of tuple of int` (to construct '
'multiple input tensors).')

flop_handler = FlopAnalyzer(model, inputs)
activation_handler = ActivationAnalyzer(model, inputs)
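For reference, here is a minimal usage sketch of the updated interface. `DummyTwoInputNet` and the chosen feature sizes are hypothetical and exist only to illustrate the single-shape and multi-shape call patterns; they are not part of this PR.

import torch
from torch import nn

from mmengine.analysis.print_helper import get_model_complexity_info


class DummyTwoInputNet(nn.Module):
    """Toy model whose forward takes two tensors (illustration only)."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features=32, out_features=8)
        self.fc2 = nn.Linear(in_features=10, out_features=8)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        return self.fc1(x1) + self.fc2(x2)


# Single input: `input_shape` stays a plain tuple of ints, as before.
single = get_model_complexity_info(
    nn.Linear(32, 8), input_shape=(32, ), show_table=False, show_arch=False)

# Multiple inputs: `input_shape` becomes a tuple of tuples, one inner tuple
# per argument of `forward` (a batch dim of 1 is prepended to each).
multi = get_model_complexity_info(
    DummyTwoInputNet(),
    input_shape=((32, ), (10, )),
    show_table=False,
    show_arch=False)

print(single['flops'], single['params'])
print(multi['flops'], multi['params'])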
147 changes: 147 additions & 0 deletions tests/test_analysis/test_print_helper.py
@@ -0,0 +1,147 @@
# Copyright (c) OpenMMLab. All rights reserved.
import unittest

import torch
import torch.nn as nn

from mmengine.analysis.complexity_analysis import FlopAnalyzer, parameter_count
from mmengine.analysis.print_helper import get_model_complexity_info
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import TORCH_VERSION


class NetAcceptOneTensor(nn.Module):

def __init__(self) -> None:
super().__init__()
self.l1 = nn.Linear(in_features=5, out_features=6)

def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.l1(x)
return out


class NetAcceptTwoTensors(nn.Module):

def __init__(self) -> None:
super().__init__()
self.l1 = nn.Linear(in_features=5, out_features=6)
self.l2 = nn.Linear(in_features=7, out_features=6)

def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
out = self.l1(x1) + self.l2(x2)
return out


class NetAcceptOneTensorNOneScalar(nn.Module):

def __init__(self) -> None:
super().__init__()
self.l1 = nn.Linear(in_features=5, out_features=6)
self.l2 = nn.Linear(in_features=5, out_features=6)

def forward(self, x1: torch.Tensor, r) -> torch.Tensor:
out = r * self.l1(x1) + (1 - r) * self.l2(x1)
return out


class TestGetModelComplexityInfo(unittest.TestCase):
"""Unittest for function get_model_complexity_info()

Test use cases of various `input_shape` and `inputs` combinations.
"""

def setUp(self) -> None:
"""Create test elements (tensors, scalars, etc.) once for all."""
self.t1 = torch.randn(1, 9, 5)
self.shape1 = (9, 5)
self.t2 = torch.randn(1, 9, 7)
self.shape2 = (9, 7)
self.scalar = 0.3

def test_oneTensor(self) -> None:
"""Test a network that accept one tensor as input."""
model = NetAcceptOneTensor()
input = self.t1
dict_complexity = get_model_complexity_info(model=model, inputs=input)
self.assertEqual(dict_complexity['flops'],
FlopAnalyzer(model=model, inputs=input).total())
self.assertEqual(dict_complexity['params'],
parameter_count(model=model)[''])

def test_oneShape(self) -> None:
"""Test a network that accept one tensor as input."""
model = NetAcceptOneTensor()
input_shape = self.shape1
dict_complexity = get_model_complexity_info(
model=model, input_shape=input_shape)
self.assertEqual(
dict_complexity['flops'],
FlopAnalyzer(model=model,
inputs=(torch.randn(1, *input_shape), )).total())
self.assertEqual(dict_complexity['params'],
parameter_count(model=model)[''])

def test_twoTensors(self) -> None:
"""Test a network that accept two tensors as input."""
model = NetAcceptTwoTensors()
input1 = self.t1
input2 = self.t2
dict_complexity = get_model_complexity_info(
model=model, inputs=(input1, input2))
self.assertEqual(
dict_complexity['flops'],
FlopAnalyzer(model=model, inputs=(input1, input2)).total())
self.assertEqual(dict_complexity['params'],
parameter_count(model=model)[''])

def test_twoShapes(self) -> None:
"""Test a network that accept two tensors as input."""
model = NetAcceptTwoTensors()
input_shape1 = self.shape1
input_shape2 = self.shape2
dict_complexity = get_model_complexity_info(
model=model, input_shape=(input_shape1, input_shape2))
self.assertEqual(
dict_complexity['flops'],
FlopAnalyzer(
model=model,
inputs=(torch.randn(1, *input_shape1),
torch.randn(1, *input_shape2))).total())
self.assertEqual(dict_complexity['params'],
parameter_count(model=model)[''])

def test_oneTensorNOneScalar(self) -> None:
"""Test a network that accept one tensor and one scalar as input."""
model = NetAcceptOneTensorNOneScalar()
input = self.t1
# For pytorch<1.9, a scalar input is not acceptable to torch.jit,
# so wrap it in a `torch.tensor`. See https://github.com/pytorch/pytorch/blob/cd9dd653e98534b5d3a9f2576df2feda40916f1d/torch/csrc/jit/python/python_arg_flatten.cpp#L90. # noqa: E501
scalar = torch.tensor([self.scalar]) if digit_version(
TORCH_VERSION) < digit_version('1.9.0') else self.scalar
dict_complexity = get_model_complexity_info(
model=model, inputs=(input, scalar))
self.assertEqual(
dict_complexity['flops'],
FlopAnalyzer(model=model, inputs=(input, scalar)).total())
self.assertEqual(dict_complexity['params'],
parameter_count(model=model)[''])

def test_provideBothInputsNInputshape(self) -> None:
"""The function `get_model_complexity_info()` should throw `ValueError`
when both `inputs` and `input_shape` are specified."""
model = NetAcceptOneTensor()
input = self.t1
input_shape = self.shape1
self.assertRaises(
ValueError,
get_model_complexity_info,
model=model,
inputs=input,
input_shape=input_shape)

def test_provideNoneOfInputsNInputshape(self) -> None:
"""The function `get_model_complexity_info()` should throw `ValueError`
when neither `inputs` nor `input_shape` is specified."""
model = NetAcceptOneTensor()
self.assertRaises(ValueError, get_model_complexity_info, model=model)
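
As a follow-up idea, the new validation branch could also be exercised directly. The sketch below is a hypothetical extra test, not part of this PR's test file, assuming the PR is applied: a malformed `input_shape` that is neither a tuple of ints nor a tuple of tuples of ints should hit the new else-branch and raise `ValueError`.

import unittest

import torch.nn as nn

from mmengine.analysis.print_helper import get_model_complexity_info


class TestMalformedInputShape(unittest.TestCase):
    """Hypothetical extra check: a malformed `input_shape` should raise."""

    def test_mixed_shape_tuple(self) -> None:
        model = nn.Linear(in_features=5, out_features=6)
        # ((9, 5), 7) is neither a `tuple of int` nor a `tuple of tuple of
        # int`, so the new validation raises ValueError.
        self.assertRaises(
            ValueError,
            get_model_complexity_info,
            model=model,
            input_shape=((9, 5), 7))


if __name__ == '__main__':
    unittest.main()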