Skip to content

Commit

Permalink
fix 2d kpts
Browse files Browse the repository at this point in the history
  • Loading branch information
xiexinch committed Sep 20, 2023
1 parent b296a2a commit 693f087
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 3 deletions.
8 changes: 5 additions & 3 deletions mmpose/datasets/transforms/converting.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ def transform(self, results: dict) -> dict:
keypoints = np.zeros((num_instances, self.num_keypoints, 3))
keypoints_visible = np.zeros((num_instances, self.num_keypoints))
key = 'keypoints_3d' if 'keypoints_3d' in results else 'keypoints'
c = results[key].shape[-1]

flip_indices = results.get('flip_indices', None)

Expand All @@ -103,7 +104,7 @@ def transform(self, results: dict) -> dict:

# Interpolate keypoints if pairs of source indexes provided
if self.interpolation:
keypoints[:, self.target_index] = 0.5 * (
keypoints[:, self.target_index, :c] = 0.5 * (
results[key][:, self.source_index] +
results[key][:, self.source_index2])
keypoints_visible[:, self.target_index] = results[
Expand All @@ -118,8 +119,9 @@ def transform(self, results: dict) -> dict:
flip_indices = flip_indices[:len(self.source_index)]
# Otherwise just copy from the source index
else:
keypoints[:, self.target_index] = results[key][:,
self.source_index]
keypoints[:,
self.target_index, :c] = results[key][:,
self.source_index]
keypoints_visible[:, self.target_index] = results[
'keypoints_visible'][:, self.source_index]

Expand Down
27 changes: 27 additions & 0 deletions tests/test_datasets/test_transforms/test_converting.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,33 @@ def test_transform(self):
self.data_info['keypoints_visible'][:,
source_index]).all())

# check 3d keypoint
self.data_info['keypoints_3d'] = np.random.random((4, 17, 3))
self.data_info['target_idx'] = [-1]
mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
transform = KeypointConverter(num_keypoints=5, mapping=mapping)
results = transform(self.data_info.copy())

# check shape
self.assertEqual(results['keypoints_3d'].shape[0],
self.data_info['keypoints_3d'].shape[0])
self.assertEqual(results['keypoints_3d'].shape[1], 5)
self.assertEqual(results['keypoints_3d'].shape[2], 3)
self.assertEqual(results['keypoints_visible'].shape[0],
self.data_info['keypoints_visible'].shape[0])
self.assertEqual(results['keypoints_visible'].shape[1], 5)

# check value
for source_index, target_index in mapping:
self.assertTrue(
(results['keypoints_3d'][:, target_index] ==
self.data_info['keypoints_3d'][:, source_index]).all())
self.assertEqual(results['keypoints_visible'].ndim, 3)
self.assertEqual(results['keypoints_visible'].shape[2], 2)
self.assertTrue(
(results['keypoints_visible'][:, target_index, 0] ==
self.data_info['keypoints_visible'][:, source_index]).all())

def test_transform_sigmas(self):

mapping = [(3, 0), (6, 1), (16, 2), (5, 3)]
Expand Down

0 comments on commit 693f087

Please sign in to comment.