Skip to content

Commit

Permalink
[Enhance] Complement type hint of get_model_complexity_info() (#1064)
Browse files Browse the repository at this point in the history
* Complement type hint of get_model_complexity_info()

The type of `inputs` should be one of `torch.Tensor`,
`tuple[torch.Tensor, ...]` and `None`.

Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>

* Update print_helper.py

---------

Signed-off-by: Shengjiang QUAN <qsj287068067@126.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
  • Loading branch information
sjiang95 and zhouzaida committed Apr 10, 2023
1 parent b2ad221 commit 5da24ed
Showing 1 changed file with 10 additions and 7 deletions.
17 changes: 10 additions & 7 deletions mmengine/analysis/print_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
# https://github.com/facebookresearch/fvcore/blob/main/fvcore/nn/print_model_statistics.py

from collections import defaultdict
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union

import torch
from rich import box
Expand Down Expand Up @@ -675,19 +675,22 @@ def complexity_stats_table(

def get_model_complexity_info(
model: nn.Module,
input_shape: tuple = None,
inputs: Optional[torch.Tensor] = None,
input_shape: Optional[tuple] = None,
inputs: Union[torch.Tensor, Tuple[torch.Tensor, ...], None] = None,
show_table: bool = True,
show_arch: bool = True,
):
"""Interface to get the complexity of a model.
Args:
model (nn.Module): The model to analyze.
input_shape (tuple): The input shape of the model.
inputs (torch.Tensor, optional): The input tensor of the model.
If not given the input tensor will be generated automatically
with the given input_shape.
input_shape (tuple, optional): The input shape of the model.
If inputs is not specified, the input_shape should be set.
Defaults to None.
inputs (torch.Tensor or tuple[torch.Tensor, ...], optional]):
The input tensor(s) of the model. If not given the input tensor
will be generated automatically with the given input_shape.
Defaults to None.
show_table (bool): Whether to show the complexity table.
Defaults to True.
show_arch (bool): Whether to show the complexity arch.
Expand Down

0 comments on commit 5da24ed

Please sign in to comment.