Skip to content

Commit

Permalink
Consolidate test_forward and test_forward2
Browse files Browse the repository at this point in the history
  • Loading branch information
Erotemic committed Nov 18, 2019
1 parent c5d248b commit 7aa143d
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 213 deletions.
105 changes: 105 additions & 0 deletions tests/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,111 @@ def test_retina_ghm_forward():
batch_results.append(result)


def test_cascade_empty_forward():
try:
from torchvision import _C as C # NOQA
except ImportError:
import pytest
raise pytest.skip('requires torchvision on cpu')

model, train_cfg, test_cfg = _get_detector_cfg(
'cascade_rcnn_r50_fpn_1x.py')
model['pretrained'] = None
# torchvision roi align supports CPU
model['bbox_roi_extractor']['roi_layer']['use_torchvision'] = True

from mmdet.models import build_detector
detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)

input_shape = (1, 3, 256, 256)

# Test forward train with an empty truth batch
mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')
gt_bboxes = mm_inputs['gt_bboxes']
gt_labels = mm_inputs['gt_labels']
losses = detector.forward(
imgs,
img_metas,
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
return_loss=True)
assert isinstance(losses, dict)
from mmdet.apis.train import parse_losses
total_loss = float(parse_losses(losses)[0].item())
assert total_loss > 0

# Test forward train with a non-empty truth batch
mm_inputs = _demo_mm_inputs(input_shape, num_items=[10])
imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')
gt_bboxes = mm_inputs['gt_bboxes']
gt_labels = mm_inputs['gt_labels']
losses = detector.forward(
imgs,
img_metas,
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
return_loss=True)
assert isinstance(losses, dict)
from mmdet.apis.train import parse_losses
total_loss = float(parse_losses(losses)[0].item())
assert total_loss > 0


def test_faster_rcnn_empty_forward():
try:
from torchvision import _C as C # NOQA
except ImportError:
import pytest
raise pytest.skip('requires torchvision on cpu')

model, train_cfg, test_cfg = _get_detector_cfg('faster_rcnn_r50_fpn_1x.py')
model['pretrained'] = None
# torchvision roi align supports CPU
model['bbox_roi_extractor']['roi_layer']['use_torchvision'] = True

from mmdet.models import build_detector
detector = build_detector(model, train_cfg=train_cfg, test_cfg=test_cfg)

input_shape = (1, 3, 256, 256)

# Test forward train with an empty truth batch
mm_inputs = _demo_mm_inputs(input_shape, num_items=[0])
imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')
gt_bboxes = mm_inputs['gt_bboxes']
gt_labels = mm_inputs['gt_labels']
losses = detector.forward(
imgs,
img_metas,
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
return_loss=True)
assert isinstance(losses, dict)
from mmdet.apis.train import parse_losses
total_loss = float(parse_losses(losses)[0].item())
assert total_loss > 0

# Test forward train with a non-empty truth batch
mm_inputs = _demo_mm_inputs(input_shape, num_items=[10])
imgs = mm_inputs.pop('imgs')
img_metas = mm_inputs.pop('img_metas')
gt_bboxes = mm_inputs['gt_bboxes']
gt_labels = mm_inputs['gt_labels']
losses = detector.forward(
imgs,
img_metas,
gt_bboxes=gt_bboxes,
gt_labels=gt_labels,
return_loss=True)
assert isinstance(losses, dict)
from mmdet.apis.train import parse_losses
total_loss = float(parse_losses(losses)[0].item())
assert total_loss > 0


def _demo_mm_inputs(
input_shape=(1, 3, 300, 300), num_items=None, num_classes=10):
"""
Expand Down
213 changes: 0 additions & 213 deletions tests/test_forward2.py

This file was deleted.

0 comments on commit 7aa143d

Please sign in to comment.