
[Feature] input_shape of get_model_complexity_info() for multiple tensors #1065

Merged
merged 17 commits into from Apr 23, 2023

Conversation

sjiang95
Contributor

@sjiang95 sjiang95 commented Apr 10, 2023

Motivation

Previously, the input_shape argument of get_model_complexity_info() was used to construct only one input tensor:

if inputs is None:
    inputs = (torch.randn(1, *input_shape), )

This is inconvenient if a custom model requires more than one input, such as:

class mymodel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.l1 = nn.Linear(in_features=5, out_features=6)
        self.l2 = nn.Linear(in_features=7, out_features=6)

    def forward(self, x1, x2):
        out = self.l1(x1) + self.l2(x2)
        return out

This PR adds support for using input_shape to construct multiple input tensors.

Modification

If inputs is None, input_shape is checked.
If input_shape is a tuple of ints, a single input tensor is constructed as before.
If input_shape is a tuple of tuples of ints, multiple input tensors are constructed. For example, given

input_shape = ((5,6), (7,8))

a tuple of tensors

inputs = (torch.randn(1, 5, 6), torch.randn(1, 7, 8))

will be constructed and fed into the model.
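The dispatch described above can be sketched in pure Python as follows. This is an illustrative helper (make_inputs is a hypothetical name, not part of the PR), and it returns the tensor shapes that would be passed to torch.randn rather than real tensors:

```python
def make_inputs(input_shape):
    """Sketch of the input_shape dispatch: a tuple of ints yields one
    shape, a tuple of tuples of ints yields one shape per inner tuple.
    Returns the shapes (batch dim 1 prepended) that would go to
    torch.randn. Hypothetical helper, not the PR's actual code."""
    if all(isinstance(s, int) for s in input_shape):
        return ((1, *input_shape),)
    if all(isinstance(s, tuple) and all(isinstance(i, int) for i in s)
           for s in input_shape):
        return tuple((1, *s) for s in input_shape)
    raise ValueError('"input_shape" should be a tuple of int '
                     'or a tuple of tuples of int.')

print(make_inputs((5, 6)))            # ((1, 5, 6),)
print(make_inputs(((5, 6), (7, 8))))  # ((1, 5, 6), (1, 7, 8))
```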

BC-breaking (Optional)

For better readability, I would also like to rename the variable from input_shape to input_shapes (this change has not been applied yet), since it now supports multiple tensors. However, this would definitely require modifications to some downstream repos, so I would like to discuss it with the reviewers.

Use cases (Optional)

Below is a code snippet that utilizes the added feature.

import torch
import torch.nn as nn
from mmengine.analysis import get_model_complexity_info

class mymodel(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.l1 = nn.Linear(in_features=5, out_features=6)
        self.l2 = nn.Linear(in_features=7, out_features=6)

    def forward(self, x1, x2):
        out = self.l1(x1) + self.l2(x2)
        return out

def main():
    model = mymodel()
    # two tuples of int construct two input tensors
    complexity = get_model_complexity_info(
        model=model, input_shape=((66, 5), (66, 7)))
    print(complexity['flops'])
    print(complexity['params'])

if __name__ == "__main__":
    main()

Checklist

  • Pre-commit or other linting tools are used to fix the potential lint issues.
  • The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
  • If the modification has potential influence on downstream projects, this PR should be tested with downstream projects, like MMDet or MMCls.
  • The documentation has been modified accordingly, like docstring or example tutorials.

Collaborator

@HAOCHENYE HAOCHENYE left a comment


Thanks for your contribution. The overall logic makes sense to me, and you can use mmengine.utils.is_seq_of to simplify it.
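For context, is_seq_of checks that a value is a sequence whose elements all have an expected type. A minimal pure-Python sketch of the idea (not mmengine's exact implementation, which lives in mmengine.utils):

```python
from collections.abc import Sequence

def is_seq_of(seq, expected_type, seq_type=None):
    """Return True if seq is a sequence (or seq_type, if given) whose
    elements are all instances of expected_type. Simplified sketch of
    mmengine.utils.is_seq_of."""
    exp_seq_type = Sequence if seq_type is None else seq_type
    if not isinstance(seq, exp_seq_type):
        return False
    return all(isinstance(item, expected_type) for item in seq)

print(is_seq_of((5, 6), int))             # True
print(is_seq_of(((5, 6), (7, 8)), int))   # False: elements are tuples
print(is_seq_of(((5, 6), (7, 8)), tuple)) # True
```

mmengine's is_tuple_of, used later in this thread, is essentially is_seq_of constrained to seq_type=tuple.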

@sjiang95 sjiang95 requested a review from HAOCHENYE April 11, 2023 01:09
@zhouzaida
Member

Hi @sjiang95, if you want to rename a function parameter, you can use deprecated_api_warning:

def deprecated_api_warning(name_dict: dict,
                           cls_name: Optional[str] = None) -> Callable:
    """A decorator to check if some arguments are deprecated and try to
    replace the deprecated src_arg_name with dst_arg_name.

    Args:
        name_dict (dict):
            key (str): Deprecated argument name.
            val (str): Expected argument name.

    Returns:
        func: New function.
    """

BTW, input_shape is also OK here and we do not need to rename it to input_shapes.
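A minimal sketch of how such a keyword-renaming decorator works (a simplified stand-in, not mmengine's implementation, which also reports the class name via cls_name):

```python
import functools
import warnings

def deprecated_api_warning(name_dict, cls_name=None):
    """Rewrite deprecated keyword arguments to their new names and emit
    a DeprecationWarning. Simplified sketch of the idea behind
    mmengine's decorator."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            for src, dst in name_dict.items():
                if src in kwargs:
                    warnings.warn(
                        f'"{src}" is deprecated, use "{dst}" instead',
                        DeprecationWarning)
                    kwargs[dst] = kwargs.pop(src)
            return func(*args, **kwargs)
        return wrapper
    return decorator

# Hypothetical usage: old callers passing input_shape keep working.
@deprecated_api_warning({'input_shape': 'input_shapes'})
def complexity(model=None, input_shapes=None):
    return input_shapes

print(complexity(input_shape=(5, 6)))  # (5, 6), with a DeprecationWarning
```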

@zhouzaida
Member

It would be helpful to add some unit tests for get_model_complexity_info in tests/test_analysis/test_print_helper.py (a new file needs to be created).

@zhouzaida zhouzaida added this to the 0.7.3 milestone Apr 21, 2023
Previously, the arg `input_shape` of get_model_complexity_info() is used
to construct only one input tensor.
This PR add support of using `input_shape` to construct multiple input
tensors.

Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>
Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>
@sjiang95 sjiang95 requested a review from zhouzaida April 22, 2023 07:30
@sjiang95
Contributor Author

Hi @zhouzaida,

It cannot pass the mypy check due to the possible None type of input_shape, even though that scenario is already handled by the preceding type checks:

if input_shape is None and inputs is None:
    raise ValueError('One of "input_shape" and "inputs" should be set.')
elif input_shape is not None and inputs is not None:
    raise ValueError('"input_shape" and "inputs" cannot be both set.')

Any suggestions to fix this?

@zhouzaida
Copy link
Member

There are two ways to fix the mypy error.

One way is to use an assert statement.

    if inputs is None:
+       assert isinstance(input_shape, tuple)
        if is_tuple_of(input_shape, int):  # tuple of int, construct one tensor
            inputs = (torch.randn(1, *input_shape), )
        elif is_tuple_of(input_shape, tuple) and all([
                is_tuple_of(one_input_shape, int)
                for one_input_shape in input_shape
        ]):  # tuple of tuple of int, construct multiple tensors
            inputs = tuple([
                torch.randn(1, *one_input_shape)
                for one_input_shape in input_shape
            ])
        else:
            raise ValueError(
                '"input_shape" should be either a `tuple of int` (to construct'
                'one input tensor) or a `tuple of tuple of int` (to construct'
                'multiple input tensors).')

Another way is to use # type: ignore.

    if inputs is None:
        if is_tuple_of(input_shape, int):  # tuple of int, construct one tensor
            inputs = (torch.randn(1, *input_shape), )
        elif is_tuple_of(input_shape, tuple) and all([
                is_tuple_of(one_input_shape, int)
-                for one_input_shape in input_shape
+                for one_input_shape in input_shape  # type: ignore
        ]):  # tuple of tuple of int, construct multiple tensors
            inputs = tuple([
                torch.randn(1, *one_input_shape)
-                for one_input_shape in input_shape
+                for one_input_shape in input_shape  # type: ignore
            ])
        else:
            raise ValueError(
                '"input_shape" should be either a `tuple of int` (to construct'
                'one input tensor) or a `tuple of tuple of int` (to construct'
                'multiple input tensors).')
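For context on why the assert fixes the error: mypy narrows an Optional type after an isinstance assert, so the later unpacking is type-safe. A minimal illustration with a hypothetical helper, independent of mmengine:

```python
from typing import Optional

def batch_shape(input_shape: Optional[tuple]) -> tuple:
    # Without this assert, mypy flags the unpacking below because
    # input_shape may be None; the assert narrows its type to `tuple`.
    assert isinstance(input_shape, tuple)
    return (1, *input_shape)

print(batch_shape((5, 6)))  # (1, 5, 6)
```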

Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>
Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>
sjiang95 and others added 3 commits April 23, 2023 16:27
accept suggestions of example block

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
accept relative import

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
accept test unit suggestions

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
@sjiang95 sjiang95 requested a review from zhouzaida April 23, 2023 07:30
Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>
@zhouzaida zhouzaida merged commit fafb476 into open-mmlab:main Apr 23, 2023
16 of 19 checks passed
@sjiang95 sjiang95 deleted the inputshape branch April 23, 2023 08:18