fix test_homography_warper_gradcheck
edgarriba committed Sep 5, 2018
1 parent ae8d9e1 commit c716a2d
Showing 3 changed files with 15 additions and 6 deletions.
7 changes: 7 additions & 0 deletions test/test_functional.py
@@ -27,6 +27,7 @@ def test_convert_points_to_homogeneous_gradcheck(self):
# evaluate function gradient
res = gradcheck(dgm.convert_points_to_homogeneous, (points,),
raise_exception=True)
+self.assertTrue(res)

def test_convert_points_from_homogeneous(self):
# generate input data
@@ -49,6 +50,7 @@ def test_convert_points_from_homogeneous_gradcheck(self):
# evaluate function gradient
res = gradcheck(dgm.convert_points_from_homogeneous, (points,),
raise_exception=True)
+self.assertTrue(res)

def test_inverse(self):
# generate input data
@@ -72,6 +74,7 @@ def test_inverse_gradcheck(self):

# evaluate function gradient
res = gradcheck(dgm.inverse, (homographies,), raise_exception=True)
+self.assertTrue(res)

def test_transform_points(self):
# generate input data
@@ -107,6 +110,7 @@ def test_transform_points_gradcheck(self):
# evaluate function gradient
res = gradcheck(dgm.transform_points, (dst_homo_src, points_src,),
raise_exception=True)
+self.assertTrue(res)

def test_pi(self):
self.assertAlmostEqual(dgm.pi.item(), 3.141592, places=4)
@@ -130,6 +134,7 @@ def test_rad2deg_gradcheck(self):
# evaluate function gradient
res = gradcheck(dgm.rad2deg, (utils.tensor_to_gradcheck_var(x_rad),),
raise_exception=True)
+self.assertTrue(res)

def test_deg2rad(self):
# generate input data
@@ -150,6 +155,7 @@ def test_deg2rad_gradcheck(self):
# evaluate function gradient
res = gradcheck(dgm.deg2rad, (utils.tensor_to_gradcheck_var(x_deg),),
raise_exception=True)
+self.assertTrue(res)

@unittest.skip("")
def test_inverse_pose(self):
@@ -179,6 +185,7 @@ def test_inverse_pose_gradcheck(self):
# evaluate function gradient
res = gradcheck(dgm.inverse_pose, (dst_pose_src,),
raise_exception=True)
+self.assertTrue(res)

if __name__ == '__main__':
unittest.main()
10 changes: 6 additions & 4 deletions test/test_homography_warper.py
@@ -49,7 +49,7 @@ def test_homography_warper(self):
def test_homography_warper_gradcheck(self):
# generate input data
batch_size = 1
-height, width = 128, 128
+height, width = 16, 16 # small patch, otherwise the test takes forever
eye_size = 3 # identity 3x3

# create checkerboard
@@ -60,14 +60,16 @@ def test_homography_warper_gradcheck(self):

# create base homography
dst_homo_src = utils.create_eye_batch(batch_size, eye_size)
-dst_homo_src = utils.tensor_to_gradcheck_var(dst_homo_src) # to var
+dst_homo_src = utils.tensor_to_gradcheck_var(
+dst_homo_src, requires_grad=False) # to var

# instantiate warper
warper = dgm.HomographyWarper(width, height)

# evaluate function gradient
-res = gradcheck(warper, (patch_src, dst_homo_src,), eps=1e-2,
-atol=1e-2, raise_exception=True)
+res = gradcheck(warper, (patch_src, dst_homo_src,),
+raise_exception=True)
+self.assertTrue(res)


if __name__ == '__main__':
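Putting the hunks above together, the fixed warper test reads roughly as follows. This is only a sketch of the post-commit behaviour, not the exact file: it assumes `dgm` is the torchgeometry package under test, substitutes a random patch for the checkerboard helper that the hunk does not show, and drops the surrounding test-class boilerplate.

import torch
from torch.autograd import gradcheck

import torchgeometry as dgm
import utils  # the test/utils.py helpers shown in the last file of this commit


def check_homography_warper_grad(batch_size=1, channels=1, height=16, width=16):
    # source patch as a float64 leaf that requires grad
    # (a random patch stands in for the checkerboard used in the real test)
    patch_src = utils.tensor_to_gradcheck_var(
        torch.rand(batch_size, channels, height, width))

    # identity 3x3 homography; converted with requires_grad=False so its
    # gradient is deliberately not checked
    dst_homo_src = utils.tensor_to_gradcheck_var(
        utils.create_eye_batch(batch_size, 3), requires_grad=False)

    warper = dgm.HomographyWarper(width, height)
    return gradcheck(warper, (patch_src, dst_homo_src,), raise_exception=True)

gradcheck perturbs every input element to build a numerical Jacobian, so its cost grows quickly with patch size; that is why the patch shrinks from 128x128 to 16x16 and the loosened eps/atol overrides are dropped in favour of the float64 defaults.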
4 changes: 2 additions & 2 deletions test/utils.py
@@ -27,12 +27,12 @@ def create_random_homography(batch_size, eye_size, std_val=1e-3):
return eye + std.uniform_(-std_val, std_val)


-def tensor_to_gradcheck_var(tensor):
+def tensor_to_gradcheck_var(tensor, dtype=torch.float64, requires_grad=True):
"""Converts the input tensor to a valid variable to check the gradient.
`gradcheck` needs 64-bit floating point and requires gradient.
"""
assert torch.is_tensor(tensor), type(tensor)
-return tensor.requires_grad_(True).type(torch.DoubleTensor)
+return tensor.requires_grad_(requires_grad).type(dtype)


def compute_mse(x, y):
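The docstring above states the two gradcheck preconditions the helper enforces: float64 inputs and, by default, requires_grad=True. A minimal standalone sketch of how the updated helper is meant to be used; the example functions and shapes are arbitrary and not from the repository.

import torch
from torch.autograd import gradcheck


def tensor_to_gradcheck_var(tensor, dtype=torch.float64, requires_grad=True):
    # same helper as above: gradcheck wants float64 tensors that require grad
    assert torch.is_tensor(tensor), type(tensor)
    return tensor.requires_grad_(requires_grad).type(dtype)


# compare analytical vs. numerical gradients of sin on a float64 input
x = tensor_to_gradcheck_var(torch.rand(4, 3))
assert gradcheck(torch.sin, (x,), raise_exception=True)

# an input converted with requires_grad=False is passed through to the
# function but its gradient is not checked
w = tensor_to_gradcheck_var(torch.rand(3, 3), requires_grad=False)
assert gradcheck(lambda a, b: a @ b, (x, w), raise_exception=True)

This second pattern is how the warper test above uses the new requires_grad flag: gradients are checked with respect to the image patch but not the identity homography.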
