Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Improvement] Remove dependency package warnings #583

Merged
merged 7 commits into from
Feb 1, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions demo/demo_spatiotemporal_det.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import os
import os.path as osp
import shutil
import warnings

import cv2
import mmcv
Expand All @@ -13,11 +12,20 @@
from tqdm import tqdm

from mmaction.models import build_detector
from mmaction.utils import import_module_error_func

try:
from mmdet.apis import inference_detector, init_detector
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use demo_spatiotempoal_det')

@import_module_error_func('mmdet')
def inference_detector(*args, **kwargs):
pass

@import_module_error_func('mmdet')
def init_detector(*args, **kwargs):
pass


try:
import moviepy.editor as mpy
Expand Down
2 changes: 1 addition & 1 deletion docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

- Refactor EvalHook ([#395](https://github.com/open-mmlab/mmaction2/pull/395))
- Refactor AVA hook ([#567](https://github.com/open-mmlab/mmaction2/pull/567))
- Refactor unit test structure ([#433](https://github.com/open-mmlab/mmaction2/pull/433))
- Add repo citation ([#545](https://github.com/open-mmlab/mmaction2/pull/545))
- Add dataset size of Kinetics400 ([#503](https://github.com/open-mmlab/mmaction2/pull/503))
- Add lazy operation docs ([#504](https://github.com/open-mmlab/mmaction2/pull/504))
Expand All @@ -31,6 +30,7 @@
- Add config tag in dataset README ([#540](https://github.com/open-mmlab/mmaction2/pull/540))
- Add solution for markdownlint installation issue ([#497](https://github.com/open-mmlab/mmaction2/pull/497))
- Add dataset overview in readthedocs ([#548](https://github.com/open-mmlab/mmaction2/pull/548))
- Modify the trigger mode of the warnings of missing mmdet ([583](https://github.com/open-mmlab/mmaction2/pull/583))
- Refactor config structure ([#488](https://github.com/open-mmlab/mmaction2/pull/488), [#572](https://github.com/open-mmlab/mmaction2/pull/572))
- Refactor unittest structure ([#433](https://github.com/open-mmlab/mmaction2/pull/433))

Expand Down
7 changes: 3 additions & 4 deletions mmaction/core/bbox/assigners/max_iou_assigner_ava.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
import warnings

import torch

from mmaction.utils import import_module_error_class

try:
from mmdet.core.bbox import AssignResult, MaxIoUAssigner
from mmdet.core.bbox.builder import BBOX_ASSIGNERS
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use AssignResult, MaxIoUAssigner '
'and BBOX_ASSIGNERS')
mmdet_imported = False

if mmdet_imported:
Expand Down Expand Up @@ -135,5 +133,6 @@ def assign_wrt_overlaps(self, overlaps, gt_labels=None):

else:
# define an empty class, so that can be imported
@import_module_error_class('mmdet')
class MaxIoUAssignerAVA:
pass
9 changes: 6 additions & 3 deletions mmaction/core/evaluation/recall.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import warnings

import numpy as np
import torch

from mmaction.utils import import_module_error_func

try:
from mmdet.core import bbox_overlaps
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use bbox_overlaps')

@import_module_error_func('mmdet')
def bbox_overlaps(*args, **kwargs):
pass


def _recalls(all_ious, proposal_nums, thrs):
Expand Down
3 changes: 0 additions & 3 deletions mmaction/models/backbones/resnet3d.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import warnings

import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, NonLocal3d, build_activation_layer,
Expand All @@ -15,7 +13,6 @@
from mmdet.models.builder import SHARED_HEADS as MMDET_SHARED_HEADS
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use MMDET_SHARED_HEADS')
mmdet_imported = False


Expand Down
3 changes: 0 additions & 3 deletions mmaction/models/backbones/resnet3d_slowfast.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import warnings

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, kaiming_init
Expand All @@ -14,7 +12,6 @@
from mmdet.models import BACKBONES as MMDET_BACKBONES
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use MMDET_BACKBONES')
mmdet_imported = False


Expand Down
3 changes: 0 additions & 3 deletions mmaction/models/backbones/resnet3d_slowonly.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import warnings

from ..registry import BACKBONES
from .resnet3d_slowfast import ResNet3dPathway

try:
from mmdet.models.builder import BACKBONES as MMDET_BACKBONES
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use MMDET_BACKBONES')
mmdet_imported = False


Expand Down
8 changes: 5 additions & 3 deletions mmaction/models/backbones/resnet_tin.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import warnings

import torch
import torch.nn as nn

from mmaction.utils import import_module_error_func
from ..registry import BACKBONES
from .resnet_tsm import ResNetTSM

try:
from mmcv.ops import tin_shift
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmcv-full to support "tin_shift"')

@import_module_error_func('mmcv-full')
def tin_shift(*args, **kwargs):
pass


def linear_sampler(data, offset):
Expand Down
6 changes: 2 additions & 4 deletions mmaction/models/builder.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import warnings

import torch.nn as nn
from mmcv.utils import Registry, build_from_cfg

from mmaction.utils import import_module_error_func
from .registry import BACKBONES, HEADS, LOCALIZERS, LOSSES, NECKS, RECOGNIZERS

try:
from mmdet.models.builder import DETECTORS, build_detector
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use DETECTORS, build_detector')

# Define an empty registry and building func, so that can import
DETECTORS = Registry('detector')

@import_module_error_func('mmdet')
def build_detector(cfg, train_cfg, test_cfg):
pass

Expand Down
3 changes: 0 additions & 3 deletions mmaction/models/heads/bbox_head.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import warnings

import torch
import torch.nn as nn
import torch.nn.functional as F
Expand All @@ -10,7 +8,6 @@
from mmdet.models.builder import HEADS as MMDET_HEADS
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use MMDET_HEADS')
mmdet_imported = False


Expand Down
6 changes: 2 additions & 4 deletions mmaction/models/heads/roi_head.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import warnings

import numpy as np

from mmaction.core.bbox import bbox2result
from mmaction.utils import import_module_error_class

try:
from mmdet.core.bbox import bbox2roi
from mmdet.models import HEADS as MMDET_HEADS
from mmdet.models.roi_heads import StandardRoIHead
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use bbox2roi, MMDET_HEADS '
'and StandardRoIHead')
mmdet_imported = False

if mmdet_imported:
Expand Down Expand Up @@ -86,5 +83,6 @@ def simple_test_bboxes(self,
return det_bboxes, det_labels
else:
# Just define an empty class, so that __init__ can import it.
@import_module_error_class('mmdet')
class AVARoIHead:
pass
15 changes: 11 additions & 4 deletions mmaction/models/roi_extractors/single_straight3d.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,25 @@
import warnings

import torch
import torch.nn as nn

from mmaction.utils import import_module_error_class

try:
from mmcv.ops import RoIAlign, RoIPool
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmcv-full to use RoIAlign and RoIPool')

@import_module_error_class('mmcv-full')
class RoIAlign(nn.Module):
pass

@import_module_error_class('mmcv-full')
class RoIPool(nn.Module):
pass


try:
from mmdet.models import ROI_EXTRACTORS
mmdet_imported = True
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use ROI_EXTRACTORS')
mmdet_imported = False


Expand Down
4 changes: 3 additions & 1 deletion mmaction/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from .collect_env import collect_env
from .decorators import import_module_error_class, import_module_error_func
from .gradcam_utils import GradCAM
from .logger import get_root_logger
from .misc import get_random_string, get_shm_dir, get_thread_id
from .precise_bn import PreciseBNHook

__all__ = [
'get_root_logger', 'collect_env', 'get_random_string', 'get_thread_id',
'get_shm_dir', 'GradCAM', 'PreciseBNHook'
'get_shm_dir', 'GradCAM', 'PreciseBNHook', 'import_module_error_class',
'import_module_error_func'
]
33 changes: 33 additions & 0 deletions mmaction/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from types import MethodType


def import_module_error_func(module_name):
"""When a function is imported incorrectly due to a missing module, raise
an import error when the function is called."""

def decorate(func):

def new_func(*args, **kwargs):
raise ImportError(
f'Please install {module_name} to use {func.__name__}.')
return func(*args, **kwargs)

return new_func

return decorate


def import_module_error_class(module_name):
"""When a class is imported incorrectly due to a missing module, raise an
import error when the class is instantiated."""

def decorate(cls):

def import_error_init(*args, **kwargs):
raise ImportError(
f'Please install {module_name} to use {cls.__name__}.')

cls.__init__ = MethodType(import_error_init, cls)
return cls

return decorate
11 changes: 9 additions & 2 deletions tests/test_utils/test_bbox.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
import os.path as osp
import warnings
from abc import abstractproperty

import numpy as np
import torch

from mmaction.core.bbox import bbox2result, bbox_target
from mmaction.datasets import AVADataset
from mmaction.utils import import_module_error_func

try:
from mmdet.core.bbox import build_assigner, build_sampler
except (ImportError, ModuleNotFoundError):
warnings.warn('Please install mmdet to use build_assigner, build_sampler')

@import_module_error_func('mmdet')
def build_assigner(*args, **kwargs):
pass

@import_module_error_func('mmdet')
def build_sampler(*args, **kwargs):
pass


def test_assigner_sampler():
Expand Down
32 changes: 32 additions & 0 deletions tests/test_utils/test_decorator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import pytest

from mmaction.utils import import_module_error_class, import_module_error_func


def test_import_module_error_class():

@import_module_error_class('mmdet')
class ExampleClass:
pass

with pytest.raises(ImportError):
ExampleClass()

@import_module_error_class('mmdet')
class ExampleClass:

def __init__(self, a, b=3):
self.c = a + b

with pytest.raises(ImportError):
ExampleClass(4)


def test_import_module_error_func():

@import_module_error_func('_add')
def ExampleFunc(a, b):
return a + b

with pytest.raises(ImportError):
ExampleFunc(3, 4)