[Fix] Fix counter mapping bug #331

Merged · 2 commits · Oct 24, 2022
@@ -1,8 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
import sys
from functools import partial
-from typing import Dict
+from typing import Dict, List

+import mmcv
import torch
import torch.nn as nn

@@ -497,9 +498,29 @@ def add_flops_params_counter_variable_or_reset(module):
        module.__params__ = 0


-def get_counter_type(module):
-    """Get counter type of the module based on the module class name."""
-    return module.__class__.__name__ + 'Counter'
+def get_counter_type(module) -> str:
+    """Get counter type of the module based on the module class name.
+
+    If the current module counter_type is not in TASK_UTILS._module_dict,
+    it will search the base classes of the module to see if it matches any
+    base class counter_type.
+
+    Returns:
+        str: Counter type (or the base counter type) of the current module.
+    """
+    counter_type = module.__class__.__name__ + 'Counter'
+    if counter_type not in TASK_UTILS._module_dict.keys():
+        old_counter_type = counter_type
+        assert nn.Module in module.__class__.mro()
+        for base_cls in module.__class__.mro():
+            if base_cls in get_modules_list():
+                counter_type = base_cls.__name__ + 'Counter'
+                from mmengine import MMLogger
+                logger = MMLogger.get_current_instance()
+                logger.warning(f'`{old_counter_type}` not in op_counters. '
+                               f'Using `{counter_type}` instead.')
+                break
+    return counter_type


def is_supported_instance(module):
@@ -518,3 +539,54 @@ def remove_flops_params_counter_hook_function(module):
        del module.__flops__
    if hasattr(module, '__params__'):
        del module.__params__
+
+
+def get_modules_list() -> List:
+    return [
+        # convolutions
+        nn.Conv1d,
+        nn.Conv2d,
+        nn.Conv3d,
+        mmcv.cnn.bricks.Conv2d,
+        mmcv.cnn.bricks.Conv3d,
+        # activations
+        nn.ReLU,
+        nn.PReLU,
+        nn.ELU,
+        nn.LeakyReLU,
+        nn.ReLU6,
+        # poolings
+        nn.MaxPool1d,
+        nn.AvgPool1d,
+        nn.AvgPool2d,
+        nn.MaxPool2d,
+        nn.MaxPool3d,
+        nn.AvgPool3d,
+        mmcv.cnn.bricks.MaxPool2d,
+        mmcv.cnn.bricks.MaxPool3d,
+        nn.AdaptiveMaxPool1d,
+        nn.AdaptiveAvgPool1d,
+        nn.AdaptiveMaxPool2d,
+        nn.AdaptiveAvgPool2d,
+        nn.AdaptiveMaxPool3d,
+        nn.AdaptiveAvgPool3d,
+        # normalizations
+        nn.BatchNorm1d,
+        nn.BatchNorm2d,
+        nn.BatchNorm3d,
+        nn.GroupNorm,
+        nn.InstanceNorm1d,
+        nn.InstanceNorm2d,
+        nn.InstanceNorm3d,
+        nn.LayerNorm,
+        # FC
+        nn.Linear,
+        mmcv.cnn.bricks.Linear,
+        # Upscale
+        nn.Upsample,
+        nn.UpsamplingNearest2d,
+        nn.UpsamplingBilinear2d,
+        # Deconvolution
+        nn.ConvTranspose2d,
+        mmcv.cnn.bricks.ConvTranspose2d,
+    ]
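
The effect of the new fallback is easiest to see in isolation. Below is a minimal sketch of the MRO walk, with a plain set standing in for `TASK_UTILS._module_dict` and a trimmed version of `get_modules_list()`; `resolve_counter_type`, `registered_counters` and `supported_modules` are illustrative names, not part of this PR. `Conv2dAdaptivePadding` from mmcv subclasses `nn.Conv2d`, so it resolves to `Conv2dCounter` instead of a missing `Conv2dAdaptivePaddingCounter`.

```python
# Standalone sketch of the MRO-based fallback (illustrative names; a plain
# set stands in for TASK_UTILS._module_dict, and the module list is trimmed).
import torch.nn as nn
from mmcv.cnn.bricks import Conv2dAdaptivePadding

registered_counters = {'Conv2dCounter', 'LinearCounter', 'ReLUCounter'}
supported_modules = [nn.Conv2d, nn.Linear, nn.ReLU]


def resolve_counter_type(module) -> str:
    counter_type = module.__class__.__name__ + 'Counter'
    if counter_type not in registered_counters:
        # Walk the class's MRO and fall back to the first base class with a
        # registered counter, e.g. Conv2dAdaptivePadding -> nn.Conv2d.
        for base_cls in module.__class__.mro():
            if base_cls in supported_modules:
                counter_type = base_cls.__name__ + 'Counter'
                break
    return counter_type


print(resolve_counter_type(Conv2dAdaptivePadding(3, 32, 3)))  # Conv2dCounter
print(resolve_counter_type(nn.Conv2d(3, 32, 3)))              # Conv2dCounter
```

In the real helper the same walk also emits the warning from the diff, roughly: `Conv2dAdaptivePaddingCounter` not in op_counters. Using `Conv2dCounter` instead.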
@@ -4,6 +4,7 @@

import pytest
import torch
+from mmcv.cnn.bricks import Conv2dAdaptivePadding
from torch import Tensor
from torch.nn import Conv2d, Module, Parameter

@@ -124,8 +125,17 @@ def test_estimate(self) -> None:
        flops_count = results['flops']
        params_count = results['params']

-        self.assertGreater(flops_count, 0)
-        self.assertGreater(params_count, 0)
+        self.assertEqual(flops_count, 44.158)
+        self.assertEqual(params_count, 0.001)
+
+        fool_conv2d = Conv2dAdaptivePadding(3, 32, 3)
+        results = estimator.estimate(
+            model=fool_conv2d, flops_params_cfg=flops_params_cfg)
+        flops_count = results['flops']
+        params_count = results['params']
+
+        self.assertEqual(flops_count, 44.958)
+        self.assertEqual(params_count, 0.001)

    def test_register_module(self) -> None:
        fool_add_constant = FoolConvModule()
@@ -151,6 +161,17 @@ def test_disable_sepc_counter(self) -> None:
        self.assertLess(rest_flops_count, 45.158)
        self.assertLess(rest_params_count, 0.701)

+        fool_conv2d = Conv2dAdaptivePadding(3, 32, 3)
+        flops_params_cfg = dict(
+            input_shape=(1, 3, 224, 224), disabled_counters=['Conv2dCounter'])
+        rest_results = estimator.estimate(
+            model=fool_conv2d, flops_params_cfg=flops_params_cfg)
+        rest_flops_count = rest_results['flops']
+        rest_params_count = rest_results['params']
+
+        self.assertEqual(rest_flops_count, 0)
+        self.assertEqual(rest_params_count, 0)
+
    def test_estimate_spec_module(self) -> None:
        fool_add_constant = FoolConvModule()
        flops_params_cfg = dict(
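
As a sanity check on the new expected values in test_estimate, the 44.958 MFLOPs / 0.001 M params asserted for Conv2dAdaptivePadding(3, 32, 3) on a (1, 3, 224, 224) input can be reproduced by hand, assuming the Conv2d counter charges in_channels * k * k multiply-adds per output element plus one add per bias term, and that adaptive padding keeps the 224 x 224 spatial size:

```python
# Back-of-the-envelope check of the values asserted in test_estimate above
# (assumes the Conv2d counter's per-output-element cost; not part of the PR).
out_h = out_w = 224          # adaptive ('same') padding keeps the input size
out_ch, in_ch, k = 32, 3, 3  # Conv2dAdaptivePadding(3, 32, 3)

conv_ops = out_h * out_w * out_ch * (in_ch * k * k)  # 43,352,064
bias_ops = out_h * out_w * out_ch                    # 1,605,632
params = out_ch * in_ch * k * k + out_ch             # 896

print(round((conv_ops + bias_ops) / 1e6, 3))  # 44.958 (MFLOPs)
print(round(params / 1e6, 3))                 # 0.001 (M params)
```

The new test_disable_sepc_counter case follows from the same mapping: once Conv2dAdaptivePadding resolves to Conv2dCounter, listing 'Conv2dCounter' in disabled_counters zeroes out both counts.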