Commit
add prefix for EvalHook
dreamerlin committed Oct 21, 2020
1 parent 68a351d commit 09957f3
Showing 4 changed files with 40 additions and 37 deletions.
4 changes: 2 additions & 2 deletions mmaction/apis/train.py
@@ -4,7 +4,7 @@
build_optimizer)
from mmcv.runner.hooks import Fp16OptimizerHook

-from ..core import DistEvalHook, EvalHook
+from ..core import DistEpochEvalHook, EpochEvalHook
from ..datasets import build_dataloader, build_dataset
from ..utils import get_root_logger

@@ -102,7 +102,7 @@ def train_model(model,
dataloader_setting = dict(dataloader_setting,
**cfg.data.get('val_dataloader', {}))
val_dataloader = build_dataloader(val_dataset, **dataloader_setting)
-eval_hook = DistEvalHook if distributed else EvalHook
+eval_hook = DistEpochEvalHook if distributed else EpochEvalHook
runner.register_hook(eval_hook(val_dataloader, **eval_cfg))

if cfg.resume_from:
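In `train_model`, the selection between the distributed and non-distributed hook is unchanged; only the names change. A minimal sketch of the registration step (the `register_eval_hook` helper is hypothetical, for illustration; only the hook selection and `register_hook` call are confirmed by the hunk above):

```python
def register_eval_hook(runner, val_dataloader, distributed, eval_cfg):
    """Attach the epoch-based eval hook matching the run mode (sketch)."""
    from mmaction.core import DistEpochEvalHook, EpochEvalHook

    # Pick the distributed variant when training with multiple processes.
    eval_hook = DistEpochEvalHook if distributed else EpochEvalHook
    runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
```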
12 changes: 6 additions & 6 deletions mmaction/core/evaluation/__init__.py
@@ -4,12 +4,12 @@
mean_average_precision, mean_class_accuracy,
pairwise_temporal_iou, softmax, top_k_accuracy)
from .eval_detection import ActivityNetDetection
-from .eval_hooks import DistEvalHook, EvalHook
+from .eval_hooks import DistEpochEvalHook, EpochEvalHook

__all__ = [
-'DistEvalHook', 'EvalHook', 'top_k_accuracy', 'mean_class_accuracy',
-'confusion_matrix', 'mean_average_precision', 'get_weighted_score',
-'average_recall_at_avg_proposals', 'pairwise_temporal_iou',
-'average_precision_at_temporal_iou', 'ActivityNetDetection', 'softmax',
-'interpolated_precision_recall'
+'DistEpochEvalHook', 'EpochEvalHook', 'top_k_accuracy',
+'mean_class_accuracy', 'confusion_matrix', 'mean_average_precision',
+'get_weighted_score', 'average_recall_at_avg_proposals',
+'pairwise_temporal_iou', 'average_precision_at_temporal_iou',
+'ActivityNetDetection', 'softmax', 'interpolated_precision_recall'
]
10 changes: 5 additions & 5 deletions mmaction/core/evaluation/eval_hooks.py
@@ -9,11 +9,11 @@
from mmaction.utils import get_root_logger


-class EvalHook(Hook):
-"""Non-Distributed evaluation hook.
+class EpochEvalHook(Hook):
+"""Non-Distributed evaluation hook based on epochs.
Notes:
-If new arguments are added for EvalHook, tools/test.py,
+If new arguments are added for EpochEvalHook, tools/test.py,
tools/eval_metric.py may be effected.
This hook will regularly perform evaluation in a given interval when
@@ -176,8 +176,8 @@ def evaluate(self, runner, results):
return None


-class DistEvalHook(EvalHook):
-"""Distributed evaluation hook.
+class DistEpochEvalHook(EpochEvalHook):
+"""Distributed evaluation hook based on epochs.
This hook will regularly perform evaluation in a given interval when
performing in distributed environment.
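Combined with the test changes below, typical construction of the renamed hooks looks like this (a sketch; only keyword arguments that the updated tests exercise are shown, and the dummy loader mirrors the one used in the tests):

```python
import torch
from torch.utils.data import DataLoader

from mmaction.core import DistEpochEvalHook, EpochEvalHook

# Dummy loader, as in tests/test_eval_hook.py.
val_loader = DataLoader(torch.ones((5, 2)))

# Evaluate after every epoch; keep the best checkpoint by 'acc'.
eval_hook = EpochEvalHook(
    val_loader, interval=1, save_best=True, key_indicator='acc')

# Distributed variant; rule='less' treats a lower metric value as better.
dist_eval_hook = DistEpochEvalHook(
    val_loader, rule='less', key_indicator='acc')
```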
51 changes: 27 additions & 24 deletions tests/test_eval_hook.py
@@ -11,7 +11,7 @@
from mmcv.utils import get_logger
from torch.utils.data import DataLoader, Dataset

-from mmaction.core import DistEvalHook, EvalHook
+from mmaction.core import DistEpochEvalHook, EpochEvalHook


class ExampleDataset(Dataset):
@@ -99,7 +99,7 @@ def test_eval_hook():
sampler=None,
num_workers=0,
shuffle=False)
-EvalHook(data_loader, save_best='True')
+EpochEvalHook(data_loader, save_best='True')

with pytest.raises(TypeError):
# dataloader must be a pytorch DataLoader
@@ -112,7 +112,7 @@ def test_eval_hook():
num_worker=0,
shuffle=False)
]
-EvalHook(data_loader)
+EpochEvalHook(data_loader)

with pytest.raises(ValueError):
# when `save_best` is True, `key_indicator` should not be None
@@ -123,7 +123,7 @@ def test_eval_hook():
sampler=None,
num_workers=0,
shuffle=False)
-EvalHook(data_loader, key_indicator=None)
+EpochEvalHook(data_loader, key_indicator=None)

with pytest.raises(KeyError):
# rule must be in keys of rule_map
@@ -134,7 +134,7 @@ def test_eval_hook():
sampler=None,
num_workers=0,
shuffle=False)
-EvalHook(data_loader, save_best=False, rule='unsupport')
+EpochEvalHook(data_loader, save_best=False, rule='unsupport')

with pytest.raises(ValueError):
# key_indicator must be valid when rule_map is None
@@ -145,7 +145,7 @@ def test_eval_hook():
sampler=None,
num_workers=0,
shuffle=False)
-EvalHook(data_loader, key_indicator='unsupport')
+EpochEvalHook(data_loader, key_indicator='unsupport')

optimizer_cfg = dict(
type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0001)
@@ -156,7 +156,7 @@ def test_eval_hook():
optimizer = build_optimizer(model, optimizer_cfg)

data_loader = DataLoader(test_dataset, batch_size=1)
-eval_hook = EvalHook(data_loader, save_best=False)
+eval_hook = EpochEvalHook(data_loader, save_best=False)
with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(
@@ -176,7 +176,7 @@ def test_eval_hook():
loader = DataLoader(EvalDataset(), batch_size=1)
model = ExampleModel()
data_loader = DataLoader(EvalDataset(), batch_size=1)
-eval_hook = EvalHook(
+eval_hook = EpochEvalHook(
data_loader, interval=1, save_best=True, key_indicator='acc')

with tempfile.TemporaryDirectory() as tmpdir:
@@ -200,7 +200,7 @@ def test_eval_hook():
assert best_json['key_indicator'] == 'acc'

data_loader = DataLoader(EvalDataset(), batch_size=1)
-eval_hook = EvalHook(
+eval_hook = EpochEvalHook(
data_loader,
interval=1,
save_best=True,
@@ -227,7 +227,7 @@ def test_eval_hook():
assert best_json['key_indicator'] == 'score'

data_loader = DataLoader(EvalDataset(), batch_size=1)
-eval_hook = EvalHook(data_loader, rule='less', key_indicator='acc')
+eval_hook = EpochEvalHook(data_loader, rule='less', key_indicator='acc')
with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(
@@ -249,7 +249,7 @@ def test_eval_hook():
assert best_json['key_indicator'] == 'acc'

data_loader = DataLoader(EvalDataset(), batch_size=1)
-eval_hook = EvalHook(data_loader, key_indicator='acc')
+eval_hook = EpochEvalHook(data_loader, key_indicator='acc')
with tempfile.TemporaryDirectory() as tmpdir:
logger = get_logger('test_eval')
runner = EpochBasedRunner(
@@ -272,7 +272,7 @@ def test_eval_hook():

resume_from = osp.join(tmpdir, 'latest.pth')
loader = DataLoader(ExampleDataset(), batch_size=1)
-eval_hook = EvalHook(data_loader, key_indicator='acc')
+eval_hook = EpochEvalHook(data_loader, key_indicator='acc')
runner = EpochBasedRunner(
model=model,
batch_processor=None,
@@ -295,46 +295,49 @@ def test_eval_hook():

@patch('mmaction.apis.single_gpu_test', MagicMock)
@patch('mmaction.apis.multi_gpu_test', MagicMock)
-@pytest.mark.parametrize('EvalHookParam', (EvalHook, DistEvalHook))
-def test_start_param(EvalHookParam):
+@pytest.mark.parametrize('EpochEvalHookParam',
+(EpochEvalHook, DistEpochEvalHook))
+def test_start_param(EpochEvalHookParam):
# create dummy data
dataloader = DataLoader(torch.ones((5, 2)))

# 0.1. dataloader is not a DataLoader object
with pytest.raises(TypeError):
-EvalHookParam(dataloader=MagicMock(), interval=-1)
+EpochEvalHookParam(dataloader=MagicMock(), interval=-1)

# 0.2. negative interval
with pytest.raises(ValueError):
-EvalHookParam(dataloader, interval=-1)
+EpochEvalHookParam(dataloader, interval=-1)

# 1. start=None, interval=1: perform evaluation after each epoch.
runner = _build_demo_runner()
-evalhook = EvalHookParam(dataloader, interval=1, save_best=False)
+evalhook = EpochEvalHookParam(dataloader, interval=1, save_best=False)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner.run([dataloader], [('train', 1)], 2)
assert evalhook.evaluate.call_count == 2 # after epoch 1 & 2

# 2. start=1, interval=1: perform evaluation after each epoch.
runner = _build_demo_runner()
-evalhook = EvalHookParam(dataloader, start=1, interval=1, save_best=False)
+evalhook = EpochEvalHookParam(
+dataloader, start=1, interval=1, save_best=False)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner.run([dataloader], [('train', 1)], 2)
assert evalhook.evaluate.call_count == 2 # after epoch 1 & 2

# 3. start=None, interval=2: perform evaluation after epoch 2, 4, 6, etc
runner = _build_demo_runner()
-evalhook = EvalHookParam(dataloader, interval=2, save_best=False)
+evalhook = EpochEvalHookParam(dataloader, interval=2, save_best=False)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner.run([dataloader], [('train', 1)], 2)
assert evalhook.evaluate.call_count == 1 # after epoch 2

# 4. start=1, interval=2: perform evaluation after epoch 1, 3, 5, etc
runner = _build_demo_runner()
-evalhook = EvalHookParam(dataloader, start=1, interval=2, save_best=False)
+evalhook = EpochEvalHookParam(
+dataloader, start=1, interval=2, save_best=False)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner.run([dataloader], [('train', 1)], 3)
@@ -343,15 +346,15 @@ def test_start_param(EvalHookParam):
# 5. start=0/negative, interval=1: perform evaluation after each epoch and
# before epoch 1.
runner = _build_demo_runner()
-evalhook = EvalHookParam(dataloader, start=0, save_best=False)
+evalhook = EpochEvalHookParam(dataloader, start=0, save_best=False)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner.run([dataloader], [('train', 1)], 2)
assert evalhook.evaluate.call_count == 3 # before epoch1 and after e1 & e2

runner = _build_demo_runner()
with pytest.warns(UserWarning):
-evalhook = EvalHookParam(dataloader, start=-2, save_best=False)
+evalhook = EpochEvalHookParam(dataloader, start=-2, save_best=False)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner.run([dataloader], [('train', 1)], 2)
Expand All @@ -360,7 +363,7 @@ def test_start_param(EvalHookParam):
# 6. resuming from epoch i, start = x (x<=i), interval =1: perform
# evaluation after each epoch and before the first epoch.
runner = _build_demo_runner()
-evalhook = EvalHookParam(dataloader, start=1, save_best=False)
+evalhook = EpochEvalHookParam(dataloader, start=1, save_best=False)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner._epoch = 2
Expand All @@ -370,7 +373,7 @@ def test_start_param(EvalHookParam):
# 7. resuming from epoch i, start = i+1/None, interval =1: perform
# evaluation after each epoch.
runner = _build_demo_runner()
-evalhook = EvalHookParam(dataloader, start=2, save_best=False)
+evalhook = EpochEvalHookParam(dataloader, start=2, save_best=False)
evalhook.evaluate = MagicMock()
runner.register_hook(evalhook)
runner._epoch = 1
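The numbered scenarios in `test_start_param` pin down when evaluation fires for a given `start` and `interval`. A plausible reconstruction of that trigger predicate, assuming 1-based epochs (resume handling and the negative-`start` warning from scenarios 5-7 are not modeled here; this is not the hook's actual code):

```python
def should_evaluate(epoch, start, interval):
    """Return True if evaluation should run once epoch `epoch` has finished.

    `epoch` is 1-based; "before the first epoch" is modeled as epoch 0.
    """
    if start is None:
        # Scenarios 1 and 3: evaluate every `interval` epochs from epoch 1.
        return epoch % interval == 0
    if epoch < start:
        return False
    # Scenarios 2, 4 and 5: evaluate at start, start+interval, ...
    return (epoch - start) % interval == 0


# Scenario 4: start=1, interval=2 -> evaluation after epochs 1 and 3.
assert [e for e in range(1, 4) if should_evaluate(e, 1, 2)] == [1, 3]
```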
