From bf6651115e517f0be4c27eff2d755c42917cebc7 Mon Sep 17 00:00:00 2001 From: Jas Date: Tue, 22 Sep 2020 10:40:22 +0800 Subject: [PATCH] fix assertion (#142) Co-authored-by: lizz --- mmpose/core/post_processing/post_transforms.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/mmpose/core/post_processing/post_transforms.py b/mmpose/core/post_processing/post_transforms.py index 5dad39e410..1eae3d38b6 100644 --- a/mmpose/core/post_processing/post_transforms.py +++ b/mmpose/core/post_processing/post_transforms.py @@ -95,6 +95,7 @@ def transform_preds(coords, center, scale, output_size): Args: coords (np.ndarray[K, ndims]): if ndims=2, corrds are predicted keypoint location. + if ndims=4, corrds are composed of (x, y, tags, scores) if ndims=5, corrds are composed of (x, y, tags, flipped_tags, scores) center (np.ndarray[2, ]): Center of the bounding box (x, y). @@ -105,7 +106,7 @@ def transform_preds(coords, center, scale, output_size): Returns: np.ndarray: Predicted coordinates in the images. """ - assert coords.shape[1] == 2 or coords.shape[1] == 5 + assert coords.shape[1] in (2, 4, 5) assert len(center) == 2 assert len(scale) == 2 assert len(output_size) == 2