Skip to content

Commit

Permalink
[Improvement] Decouple dependency (#254)
Browse files Browse the repository at this point in the history
* add dependency placeholder

* update requirements

* decouple mmcls dependencies

Co-authored-by: huangpengsheng <huangpengsheng@sensetime.com>
  • Loading branch information
fpshuang and huangpengsheng committed Aug 30, 2022
1 parent e3390ce commit f45e2bd
Show file tree
Hide file tree
Showing 15 changed files with 76 additions and 19 deletions.
10 changes: 8 additions & 2 deletions mmrazor/models/architectures/backbones/searchable_mobilenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,21 @@
import copy
from typing import Dict, List, Optional, Sequence, Tuple, Union

from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.utils import make_divisible
from mmcv.cnn import ConvModule
from mmengine.model import Sequential
from torch import Tensor
from torch.nn.modules.batchnorm import _BatchNorm

from mmrazor.registry import MODELS

try:
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.utils import make_divisible
except ImportError:
from mmrazor.utils import get_placeholder
BaseBackbone = get_placeholder('mmcls')
make_divisible = get_placeholder('mmcls')


@MODELS.register_module()
class SearchableMobileNet(BaseBackbone):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Dict, List, Optional, Sequence, Tuple, Union

import torch.nn as nn
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcv.cnn import ConvModule
from mmengine.model import ModuleList, Sequential
from mmengine.model.weight_init import constant_init, normal_init
Expand All @@ -12,6 +11,12 @@

from mmrazor.registry import MODELS

try:
from mmcls.models.backbones.base_backbone import BaseBackbone
except ImportError:
from mmrazor.utils import get_placeholder
BaseBackbone = get_placeholder('mmcls')


@MODELS.register_module()
class SearchableShuffleNetV2(BaseBackbone):
Expand Down
13 changes: 10 additions & 3 deletions mmrazor/models/architectures/heads/darts_subnet_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,21 @@
from typing import List, Tuple

import torch
from mmcls.evaluation import Accuracy
from mmcls.models.heads import LinearClsHead
from mmcls.structures import ClsDataSample
from torch import nn

from mmrazor.models.utils import add_prefix
from mmrazor.registry import MODELS

try:
from mmcls.evaluation import Accuracy
from mmcls.models.heads import LinearClsHead
from mmcls.structures import ClsDataSample
except ImportError:
from mmrazor.utils import get_placeholder
Accuracy = get_placeholder('mmcls')
LinearClsHead = get_placeholder('mmcls')
ClsDataSample = get_placeholder('mmcls')


@MODELS.register_module()
class DartsSubnetClsHead(LinearClsHead):
Expand Down
7 changes: 6 additions & 1 deletion mmrazor/models/losses/kd_soft_ce_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,15 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcls.models.losses.cross_entropy_loss import soft_cross_entropy

from mmrazor.registry import MODELS

try:
from mmcls.models.losses.cross_entropy_loss import soft_cross_entropy
except ImportError:
from mmrazor.utils import get_placeholder
soft_cross_entropy = get_placeholder('mmcls')


@MODELS.register_module()
class KDSoftCELoss(nn.Module):
Expand Down
7 changes: 6 additions & 1 deletion mmrazor/models/ops/efficientnet_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
from typing import Dict, Optional

import torch.nn as nn
from mmcls.models.utils import SELayer
from mmcv.cnn import ConvModule

from mmrazor.registry import MODELS
from .base import BaseOP

try:
from mmcls.models.utils import SELayer
except ImportError:
from mmrazor.utils import get_placeholder
SELayer = get_placeholder('mmcls')


@MODELS.register_module()
class ConvBnAct(BaseOP):
Expand Down
7 changes: 6 additions & 1 deletion mmrazor/models/ops/mobilenet_series.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,18 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcls.models.utils import SELayer
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import DropPath

from mmrazor.registry import MODELS
from .base import BaseOP

try:
from mmcls.models.utils import SELayer
except ImportError:
from mmrazor.utils import get_placeholder
SELayer = get_placeholder('mmcls')


@MODELS.register_module()
class MBBlock(BaseOP):
Expand Down
7 changes: 6 additions & 1 deletion mmrazor/models/ops/shufflenet_series.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcls.models.utils import channel_shuffle
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule

from mmrazor.registry import MODELS
from .base import BaseOP

try:
from mmcls.models.utils import channel_shuffle
except ImportError:
from mmrazor.utils import get_placeholder
channel_shuffle = get_placeholder('mmcls')


@MODELS.register_module()
class ShuffleBlock(BaseOP):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcls.models import ImageClassifier

from mmrazor.registry import TASK_UTILS

try:
from mmcls.models import ImageClassifier
except ImportError:
from mmrazor.utils import get_placeholder
ImageClassifier = get_placeholder('mmcls')


@TASK_UTILS.register_module()
class ImageClassifierPseudoLoss:
Expand Down
3 changes: 2 additions & 1 deletion mmrazor/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .misc import find_latest_checkpoint
from .placeholder import get_placeholder
from .setup_env import register_all_modules, setup_multi_processes
from .typing import (FixMutable, MultiMutatorsRandomSubnet,
SingleMutatorRandomSubnet, SupportRandomSubnet,
Expand All @@ -8,5 +9,5 @@
__all__ = [
'find_latest_checkpoint', 'setup_multi_processes', 'register_all_modules',
'FixMutable', 'ValidFixMutable', 'SingleMutatorRandomSubnet',
'MultiMutatorsRandomSubnet', 'SupportRandomSubnet'
'MultiMutatorsRandomSubnet', 'SupportRandomSubnet', 'get_placeholder'
]
20 changes: 20 additions & 0 deletions mmrazor/utils/placeholder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
def get_placeholder(string: str) -> object:
"""Get placeholder instance which can avoid raising errors when down-stream
dependency is not installed properly.
Args:
string (str): the dependency's name, i.e. `mmcls`
Raises:
ImportError: raise it when the dependency is not installed properly.
Returns:
object: PlaceHolder instance.
"""

class PlaceHolder:

def __init__(self, *args, **kwargs) -> None:
raise ImportError(
f'`{string}` is not installed properly, plz check.')

return PlaceHolder
1 change: 0 additions & 1 deletion requirements/mminstall.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1 @@
mmcls
mmcv-full>=2.0.0rc0
2 changes: 0 additions & 2 deletions requirements/optional.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,2 @@
albumentations>=0.3.2
mmdet
mmsegmentation
timm
1 change: 0 additions & 1 deletion requirements/readthedocs.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
mmcls
mmcv>=1.3.8
ordered_set
torch
Expand Down
1 change: 0 additions & 1 deletion requirements/runtime.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
mmcls
ordered_set
typing_extensions;python_version<"3.8"
2 changes: 0 additions & 2 deletions requirements/tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ codecov
flake8
interrogate
isort==4.3.21
mmdet
mmsegmentation
pytest
xdoctest >= 0.10.0
yapf

0 comments on commit f45e2bd

Please sign in to comment.