Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use MMCV registry #220

Merged
merged 4 commits into from Jul 26, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
7 changes: 3 additions & 4 deletions mmtrack/models/__init__.py
@@ -1,8 +1,8 @@
from .aggregators import * # noqa: F401,F403
from .backbones import * # noqa: F401,F403
from .builder import (AGGREGATORS, MODELS, MOTION, REID, TRACKERS,
build_aggregator, build_detector, build_model,
build_motion, build_reid, build_tracker)
build_aggregator, build_model, build_motion, build_reid,
build_tracker)
from .losses import * # noqa: F401,F403
from .mot import * # noqa: F401,F403
from .motion import * # noqa: F401,F403
Expand All @@ -14,6 +14,5 @@

__all__ = [
'AGGREGATORS', 'MODELS', 'TRACKERS', 'MOTION', 'REID', 'build_model',
'build_tracker', 'build_motion', 'build_aggregator', 'build_reid',
'build_detector'
'build_tracker', 'build_motion', 'build_aggregator', 'build_reid'
]
62 changes: 14 additions & 48 deletions mmtrack/models/builder.py
@@ -1,71 +1,37 @@
import torch.nn as nn
from mmcv.utils import Registry, build_from_cfg
from mmdet.models import DETECTORS
from mmcv.cnn import MODELS as MMCV_MODELS
from mmcv.utils import Registry

MODELS = Registry('model')
TRACKERS = Registry('tracker')
MOTION = Registry('motion')
REID = Registry('reid')
AGGREGATORS = Registry('aggregator')


def build(cfg, registry, default_args=None):
"""Build a module.

Args:
cfg (dict, list[dict]): The config of modules, is is either a dict
or a list of configs.
registry (:obj:`Registry`): A registry the module belongs to.
default_args (dict, optional): Default arguments to build the module.
Defaults to None.

Returns:
nn.Module: A built nn module.
"""
if isinstance(cfg, list):
modules = [
build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
]
try:
return nn.Sequential(*modules)
except: # noqa: E722
return modules
else:
return build_from_cfg(cfg, registry, default_args)
MODELS = Registry('models', parent=MMCV_MODELS)
TRACKERS = MODELS
MOTION = MODELS
REID = MODELS
AGGREGATORS = MODELS


def build_tracker(cfg):
"""Build tracker."""
return build(cfg, TRACKERS)
return TRACKERS.build(cfg)


def build_motion(cfg):
"""Build motion model."""
return build(cfg, MOTION)
return MOTION.build(cfg)


def build_reid(cfg):
"""Build motion model."""
return build(cfg, REID)
return REID.build(cfg)


def build_aggregator(cfg):
"""Build aggregator model."""
return build(cfg, AGGREGATORS)


def build_detector(cfg, train_cfg=None, test_cfg=None):
"""Build detector."""
if train_cfg is None and test_cfg is None:
return build(cfg, DETECTORS)
else:
return build(cfg, DETECTORS,
dict(train_cfg=train_cfg, test_cfg=test_cfg))
return AGGREGATORS.build(cfg)


def build_model(cfg, train_cfg=None, test_cfg=None):
"""Build model."""
if train_cfg is None and test_cfg is None:
return build(cfg, MODELS)
return MODELS.build(cfg)
else:
return build(cfg, MODELS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
return MODELS.build(cfg, MODELS,
dict(train_cfg=train_cfg, test_cfg=test_cfg))
4 changes: 2 additions & 2 deletions mmtrack/models/mot/deep_sort.py
@@ -1,8 +1,8 @@
from mmdet.core import bbox2result
from mmdet.models import build_detector

from mmtrack.core import track2result
from ..builder import (MODELS, build_detector, build_motion, build_reid,
build_tracker)
from ..builder import MODELS, build_motion, build_reid, build_tracker
from .base import BaseMultiObjectTracker


Expand Down
4 changes: 2 additions & 2 deletions mmtrack/models/mot/tracktor.py
@@ -1,8 +1,8 @@
from mmdet.core import bbox2result
from mmdet.models import build_detector

from mmtrack.core import track2result
from ..builder import (MODELS, build_detector, build_motion, build_reid,
build_tracker)
from ..builder import MODELS, build_motion, build_reid, build_tracker
from ..motion import CameraMotionCompensation, LinearMotion
from .base import BaseMultiObjectTracker

Expand Down
3 changes: 2 additions & 1 deletion mmtrack/models/vid/dff.py
@@ -1,9 +1,10 @@
import torch
from addict import Dict
from mmdet.core import bbox2result
from mmdet.models import build_detector

from mmtrack.core.motion import flow_warp_feats
from ..builder import MODELS, build_detector, build_motion
from ..builder import MODELS, build_motion
from .base import BaseVideoDetector


Expand Down
3 changes: 2 additions & 1 deletion mmtrack/models/vid/fgfa.py
@@ -1,9 +1,10 @@
import torch
from addict import Dict
from mmdet.core import bbox2result
from mmdet.models import build_detector

from mmtrack.core import flow_warp_feats
from ..builder import MODELS, build_aggregator, build_detector, build_motion
from ..builder import MODELS, build_aggregator, build_motion
from .base import BaseVideoDetector


Expand Down
3 changes: 2 additions & 1 deletion mmtrack/models/vid/selsa.py
@@ -1,7 +1,8 @@
import torch
from addict import Dict
from mmdet.models import build_detector

from ..builder import MODELS, build_detector
from ..builder import MODELS
from .base import BaseVideoDetector


Expand Down
1 change: 1 addition & 0 deletions requirements/runtime.txt
Expand Up @@ -7,3 +7,4 @@ numpy
seaborn
six
terminaltables
tqdm
2 changes: 1 addition & 1 deletion tools/train.py
Expand Up @@ -68,7 +68,7 @@ def main():

if cfg.get('USE_MMDET', False):
from mmdet.apis import train_detector as train_model
from mmtrack.models import build_detector as build_model
from mmdet.models import build_detector as build_model
if 'detector' in cfg.model:
cfg.model = cfg.model.detector
elif cfg.get('USE_MMCLS', False):
Expand Down