[Fix] Fix counter mapping bug #331
Changes from 1 commit
@@ -1,8 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import sys
 import warnings
 from functools import partial
-from typing import Dict
+from typing import Dict, List

+import mmcv
 import torch
 import torch.nn as nn
@@ -409,6 +411,12 @@ def add_flops_params_counter_hook_function(module):

     else:
         counter_type = get_counter_type(module)
+        if counter_type not in TASK_UTILS._module_dict.keys():
+            old_counter_type = counter_type
+            counter_type = \
+                module.__class__.__base__.__name__ + 'Counter'
+            warnings.warn(f'`{old_counter_type}` not in '
+                          f'op_counters. Using `{counter_type}`')
         if (disabled_counters is None
                 or counter_type not in disabled_counters):
             counter = TASK_UTILS.build(

Review comment (on the `warnings.warn` lines): I suggest using

Reply: Done.
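In plain terms: when a module's exact counter (e.g. `Conv2dAdaptivePaddingCounter`) is not registered, the hook falls back to the counter named after the module's base class and emits a warning. A self-contained sketch of that mapping logic, with a plain dict standing in for mmrazor's `TASK_UTILS` registry (`OP_COUNTERS`, `resolve_counter_type`, and `WrappedConv2d` are simplified stand-ins, not the PR's actual names):

```python
import warnings

import torch.nn as nn

# Plain dict standing in for mmrazor's TASK_UTILS registry.
OP_COUNTERS = {'Conv2dCounter': 'registered-counter'}


def get_counter_type(module) -> str:
    return type(module).__name__ + 'Counter'


def resolve_counter_type(module) -> str:
    """Fall back to the base class's counter when the exact one is missing."""
    counter_type = get_counter_type(module)
    if counter_type not in OP_COUNTERS:
        old_counter_type = counter_type
        counter_type = module.__class__.__base__.__name__ + 'Counter'
        warnings.warn(f'`{old_counter_type}` not in op_counters. '
                      f'Using `{counter_type}`')
    return counter_type


class WrappedConv2d(nn.Conv2d):
    """Toy subclass playing the role of mmcv's Conv2dAdaptivePadding."""


print(resolve_counter_type(WrappedConv2d(3, 32, 3)))
# Warns, then prints 'Conv2dCounter'.
```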
@@ -503,9 +511,13 @@ def get_counter_type(module):


 def is_supported_instance(module):
-    """Judge whether the module is in TASK_UTILS registry or not."""
+    """Judge whether the module can be counted or not."""
     if get_counter_type(module) in TASK_UTILS._module_dict.keys():
         return True
+    else:
+        for op in get_modules_list():
+            if issubclass(module.__class__.__base__, op):
+                return True
     return False

Review comment (on the `issubclass` line): multiple inheritance should also be handled.

Reply: Use
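The reviewer's concern is that `module.__class__.__base__` only inspects the first direct parent, so a module mixing in several bases can be missed. One way to cover the full inheritance chain is to test the class itself with `issubclass`, which walks the whole MRO; a minimal sketch of the idea (not necessarily the PR's final code; `PaddingMixin` and `FancyConv` are made-up names):

```python
import torch.nn as nn


class PaddingMixin:
    """Made-up mixin, only here to create multiple inheritance."""


class FancyConv(PaddingMixin, nn.Conv2d):
    """__base__ is PaddingMixin, so a single-base check misses nn.Conv2d."""


m = FancyConv(3, 8, 3)
print(issubclass(m.__class__.__base__, nn.Conv2d))  # False: first base only
print(issubclass(type(m), nn.Conv2d))               # True: checks the full MRO
```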
@@ -518,3 +530,54 @@ def remove_flops_params_counter_hook_function(module):
         del module.__flops__
     if hasattr(module, '__params__'):
         del module.__params__
+
+
+def get_modules_list() -> List:
+    return [
+        # convolutions
+        nn.Conv1d,
+        nn.Conv2d,
+        nn.Conv3d,
+        mmcv.cnn.bricks.Conv2d,
+        mmcv.cnn.bricks.Conv3d,
+        # activations
+        nn.ReLU,
+        nn.PReLU,
+        nn.ELU,
+        nn.LeakyReLU,
+        nn.ReLU6,
+        # poolings
+        nn.MaxPool1d,
+        nn.AvgPool1d,
+        nn.AvgPool2d,
+        nn.MaxPool2d,
+        nn.MaxPool3d,
+        nn.AvgPool3d,
+        mmcv.cnn.bricks.MaxPool2d,
+        mmcv.cnn.bricks.MaxPool3d,
+        nn.AdaptiveMaxPool1d,
+        nn.AdaptiveAvgPool1d,
+        nn.AdaptiveMaxPool2d,
+        nn.AdaptiveAvgPool2d,
+        nn.AdaptiveMaxPool3d,
+        nn.AdaptiveAvgPool3d,
+        # normalizations
+        nn.BatchNorm1d,
+        nn.BatchNorm2d,
+        nn.BatchNorm3d,
+        nn.GroupNorm,
+        nn.InstanceNorm1d,
+        nn.InstanceNorm2d,
+        nn.InstanceNorm3d,
+        nn.LayerNorm,
+        # FC
+        nn.Linear,
+        mmcv.cnn.bricks.Linear,
+        # Upscale
+        nn.Upsample,
+        nn.UpsamplingNearest2d,
+        nn.UpsamplingBilinear2d,
+        # Deconvolution
+        nn.ConvTranspose2d,
+        mmcv.cnn.bricks.ConvTranspose2d,
+    ]
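The point of checking `__base__` against this list is that wrapper modules such as mmcv's `Conv2dAdaptivePadding`, which subclasses `nn.Conv2d` directly, become countable even though no counter is registered under their own class name. A quick standalone check of that relationship (requires torch and mmcv installed):

```python
import torch.nn as nn
from mmcv.cnn.bricks import Conv2dAdaptivePadding

m = Conv2dAdaptivePadding(3, 32, 3)
print(m.__class__.__base__)                         # <class 'torch.nn.modules.conv.Conv2d'>
print(issubclass(m.__class__.__base__, nn.Conv2d))  # True -> falls back to Conv2dCounter
```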
@@ -4,6 +4,7 @@

 import pytest
 import torch
+from mmcv.cnn.bricks import Conv2dAdaptivePadding
 from torch import Tensor
 from torch.nn import Conv2d, Module, Parameter
@@ -127,6 +128,15 @@ def test_estimate(self) -> None:
         self.assertGreater(flops_count, 0)
         self.assertGreater(params_count, 0)

+        fool_conv2d = Conv2dAdaptivePadding(3, 32, 3)
+        results = estimator.estimate(
+            model=fool_conv2d, flops_params_cfg=flops_params_cfg)
+        flops_count = results['flops']
+        params_count = results['params']
+
+        self.assertGreater(flops_count, 0)
+        self.assertGreater(params_count, 0)
+
     def test_register_module(self) -> None:
         fool_add_constant = FoolConvModule()
         flops_params_cfg = dict(input_shape=(1, 3, 224, 224))

Review comment (on the assertions): add more strict check constraints for

Reply: Done.
@@ -151,6 +161,17 @@ def test_disable_sepc_counter(self) -> None:
         self.assertLess(rest_flops_count, 45.158)
         self.assertLess(rest_params_count, 0.701)

+        fool_conv2d = Conv2dAdaptivePadding(3, 32, 3)
+        flops_params_cfg = dict(
+            input_shape=(1, 3, 224, 224), disabled_counters=['Conv2dCounter'])
+        rest_results = estimator.estimate(
+            model=fool_conv2d, flops_params_cfg=flops_params_cfg)
+        rest_flops_count = rest_results['flops']
+        rest_params_count = rest_results['params']
+
+        self.assertEqual(rest_flops_count, 0)
+        self.assertEqual(rest_params_count, 0)
+
     def test_estimate_spec_module(self) -> None:
         fool_add_constant = FoolConvModule()
         flops_params_cfg = dict(
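For context, here is roughly how the new tests drive the estimator (a trimmed sketch: the `ResourceEstimator` import path and zero-argument construction are assumptions, since the test file's setup is not shown in this diff):

```python
from mmcv.cnn.bricks import Conv2dAdaptivePadding

# Assumed import path and constructor; the diff only shows the test bodies.
from mmrazor.models.task_modules import ResourceEstimator

estimator = ResourceEstimator()
model = Conv2dAdaptivePadding(3, 32, 3)

# With the fallback fix, the wrapper is counted via Conv2dCounter...
results = estimator.estimate(
    model=model, flops_params_cfg=dict(input_shape=(1, 3, 224, 224)))
assert results['flops'] > 0 and results['params'] > 0

# ...so disabling Conv2dCounter zeroes both numbers, as the test asserts.
rest_results = estimator.estimate(
    model=model,
    flops_params_cfg=dict(
        input_shape=(1, 3, 224, 224),
        disabled_counters=['Conv2dCounter']))
assert rest_results['flops'] == 0 and rest_results['params'] == 0
```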
Review comment: Move the logic into `get_counter_type`, and then you can also refactor `is_supported_instance`.

Reply: Done.