Skip to content

Commit

Permalink
[Improvement] Remove dependency package warnings (#583)
Browse files Browse the repository at this point in the history
* fix mmdet warnings

* add unit test

* remove mmcv-full uninstall warnings

* add changelog

* remove rebundent changelog of refactor unittest

* fix conflict

* fix warnings of mmcv-full uninstalled
  • Loading branch information
congee524 committed Feb 1, 2021
1 parent 5c1093e commit 509118c
Show file tree
Hide file tree
Showing 16 changed files with 117 additions and 40 deletions.
12 changes: 10 additions & 2 deletions demo/demo_spatiotemporal_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import os.path as osp
import shutil
import warnings

import cv2
import mmcv
Expand All @@ -13,11 +12,20 @@
from tqdm import tqdm

from mmaction.models import build_detector
from mmaction.utils import import_module_error_func

try:
from mmdet.apis import inference_detector, init_detector
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use demo_spatiotempoal_det')

@import_module_error_func('mmdet')
def inference_detector(*args, **kwargs):
pass

@import_module_error_func('mmdet')
def init_detector(*args, **kwargs):
pass


try:
import moviepy.editor as mpy
Expand Down
2 changes: 1 addition & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

- Refactor EvalHook ([#395](https://github.com/open-mmlab/mmaction2/pull/395))
- Refactor AVA hook ([#567](https://github.com/open-mmlab/mmaction2/pull/567))
- Refactor unit test structure ([#433](https://github.com/open-mmlab/mmaction2/pull/433))
- Add repo citation ([#545](https://github.com/open-mmlab/mmaction2/pull/545))
- Add dataset size of Kinetics400 ([#503](https://github.com/open-mmlab/mmaction2/pull/503))
- Add lazy operation docs ([#504](https://github.com/open-mmlab/mmaction2/pull/504))
Expand All @@ -31,6 +30,7 @@
- Add config tag in dataset README ([#540](https://github.com/open-mmlab/mmaction2/pull/540))
- Add solution for markdownlint installation issue ([#497](https://github.com/open-mmlab/mmaction2/pull/497))
- Add dataset overview in readthedocs ([#548](https://github.com/open-mmlab/mmaction2/pull/548))
- Modify the trigger mode of the warnings of missing mmdet ([583](https://github.com/open-mmlab/mmaction2/pull/583))
- Refactor config structure ([#488](https://github.com/open-mmlab/mmaction2/pull/488), [#572](https://github.com/open-mmlab/mmaction2/pull/572))
- Refactor unittest structure ([#433](https://github.com/open-mmlab/mmaction2/pull/433))

Expand Down
7 changes: 3 additions & 4 deletions mmaction/core/bbox/assigners/max_iou_assigner_ava.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import warnings

import torch

from mmaction.utils import import_module_error_class

try:
from mmdet.core.bbox import AssignResult, MaxIoUAssigner
from mmdet.core.bbox.builder import BBOX_ASSIGNERS
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use AssignResult, MaxIoUAssigner '
'and BBOX_ASSIGNERS')
mmdet_imported = False

if mmdet_imported:
Expand Down Expand Up @@ -135,5 +133,6 @@ def assign_wrt_overlaps(self, overlaps, gt_labels=None):

else:
# define an empty class, so that can be imported
@import_module_error_class('mmdet')
class MaxIoUAssignerAVA:
pass
9 changes: 6 additions & 3 deletions mmaction/core/evaluation/recall.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import warnings

import numpy as np
import torch

from mmaction.utils import import_module_error_func

try:
from mmdet.core import bbox_overlaps
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use bbox_overlaps')

@import_module_error_func('mmdet')
def bbox_overlaps(*args, **kwargs):
pass


def _recalls(all_ious, proposal_nums, thrs):
Expand Down
3 changes: 0 additions & 3 deletions mmaction/models/backbones/resnet3d.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import warnings

import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, NonLocal3d, build_activation_layer,
Expand All @@ -15,7 +13,6 @@
from mmdet.models.builder import SHARED_HEADS as MMDET_SHARED_HEADS
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use MMDET_SHARED_HEADS')
mmdet_imported = False


Expand Down
3 changes: 0 additions & 3 deletions mmaction/models/backbones/resnet3d_slowfast.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import warnings

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, kaiming_init
Expand All @@ -14,7 +12,6 @@
from mmdet.models import BACKBONES as MMDET_BACKBONES
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use MMDET_BACKBONES')
mmdet_imported = False


Expand Down
3 changes: 0 additions & 3 deletions mmaction/models/backbones/resnet3d_slowonly.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import warnings

from ..registry import BACKBONES
from .resnet3d_slowfast import ResNet3dPathway

try:
from mmdet.models.builder import BACKBONES as MMDET_BACKBONES
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use MMDET_BACKBONES')
mmdet_imported = False


Expand Down
8 changes: 5 additions & 3 deletions mmaction/models/backbones/resnet_tin.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import warnings

import torch
import torch.nn as nn

from mmaction.utils import import_module_error_func
from ..registry import BACKBONES
from .resnet_tsm import ResNetTSM

try:
from mmcv.ops import tin_shift
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmcv-full to support "tin_shift"')

@import_module_error_func('mmcv-full')
def tin_shift(*args, **kwargs):
pass


def linear_sampler(data, offset):
Expand Down
6 changes: 2 additions & 4 deletions mmaction/models/builder.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import warnings

import torch.nn as nn
from mmcv.utils import Registry, build_from_cfg

from mmaction.utils import import_module_error_func
from .registry import BACKBONES, HEADS, LOCALIZERS, LOSSES, NECKS, RECOGNIZERS

try:
from mmdet.models.builder import DETECTORS, build_detector
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use DETECTORS, build_detector')

# Define an empty registry and building func, so that can import
DETECTORS = Registry('detector')

@import_module_error_func('mmdet')
def build_detector(cfg, train_cfg, test_cfg):
pass

Expand Down
3 changes: 0 additions & 3 deletions mmaction/models/heads/bbox_head.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -10,7 +8,6 @@
from mmdet.models.builder import HEADS as MMDET_HEADS
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use MMDET_HEADS')
mmdet_imported = False


Expand Down
6 changes: 2 additions & 4 deletions mmaction/models/heads/roi_head.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import warnings

import numpy as np

from mmaction.core.bbox import bbox2result
from mmaction.utils import import_module_error_class

try:
from mmdet.core.bbox import bbox2roi
from mmdet.models import HEADS as MMDET_HEADS
from mmdet.models.roi_heads import StandardRoIHead
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use bbox2roi, MMDET_HEADS '
'and StandardRoIHead')
mmdet_imported = False

if mmdet_imported:
Expand Down Expand Up @@ -86,5 +83,6 @@ def simple_test_bboxes(self,
return det_bboxes, det_labels
else:
# Just define an empty class, so that __init__ can import it.
@import_module_error_class('mmdet')
class AVARoIHead:
pass
15 changes: 11 additions & 4 deletions mmaction/models/roi_extractors/single_straight3d.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
import warnings

import torch
import torch.nn as nn

from mmaction.utils import import_module_error_class

try:
from mmcv.ops import RoIAlign, RoIPool
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmcv-full to use RoIAlign and RoIPool')

@import_module_error_class('mmcv-full')
class RoIAlign(nn.Module):
pass

@import_module_error_class('mmcv-full')
class RoIPool(nn.Module):
pass


try:
from mmdet.models import ROI_EXTRACTORS
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use ROI_EXTRACTORS')
mmdet_imported = False


Expand Down
4 changes: 3 additions & 1 deletion mmaction/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from .collect_env import collect_env
from .decorators import import_module_error_class, import_module_error_func
from .gradcam_utils import GradCAM
from .logger import get_root_logger
from .misc import get_random_string, get_shm_dir, get_thread_id
from .precise_bn import PreciseBNHook

__all__ = [
'get_root_logger', 'collect_env', 'get_random_string', 'get_thread_id',
'get_shm_dir', 'GradCAM', 'PreciseBNHook'
'get_shm_dir', 'GradCAM', 'PreciseBNHook', 'import_module_error_class',
'import_module_error_func'
]
33 changes: 33 additions & 0 deletions mmaction/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from types import MethodType


def import_module_error_func(module_name):
"""When a function is imported incorrectly due to a missing module, raise
an import error when the function is called."""

def decorate(func):

def new_func(*args, **kwargs):
raise ImportError(
f'Please install {module_name} to use {func.__name__}.')
return func(*args, **kwargs)

return new_func

return decorate


def import_module_error_class(module_name):
"""When a class is imported incorrectly due to a missing module, raise an
import error when the class is instantiated."""

def decorate(cls):

def import_error_init(*args, **kwargs):
raise ImportError(
f'Please install {module_name} to use {cls.__name__}.')

cls.__init__ = MethodType(import_error_init, cls)
return cls

return decorate
11 changes: 9 additions & 2 deletions tests/test_utils/test_bbox.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import os.path as osp
import warnings
from abc import abstractproperty

import numpy as np
import torch

from mmaction.core.bbox import bbox2result, bbox_target
from mmaction.datasets import AVADataset
from mmaction.utils import import_module_error_func

try:
from mmdet.core.bbox import build_assigner, build_sampler
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use build_assigner, build_sampler')

@import_module_error_func('mmdet')
def build_assigner(*args, **kwargs):
pass

@import_module_error_func('mmdet')
def build_sampler(*args, **kwargs):
pass


def test_assigner_sampler():
Expand Down
32 changes: 32 additions & 0 deletions tests/test_utils/test_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest

from mmaction.utils import import_module_error_class, import_module_error_func


def test_import_module_error_class():

@import_module_error_class('mmdet')
class ExampleClass:
pass

with pytest.raises(ImportError):
ExampleClass()

@import_module_error_class('mmdet')
class ExampleClass:

def __init__(self, a, b=3):
self.c = a + b

with pytest.raises(ImportError):
ExampleClass(4)


def test_import_module_error_func():

@import_module_error_func('_add')
def ExampleFunc(a, b):
return a + b

with pytest.raises(ImportError):
ExampleFunc(3, 4)

0 comments on commit 509118c

Please sign in to comment.