diff --git a/tests/test_points.py b/tests/test_points.py index a9fb7c5..6d9bf38 100644 --- a/tests/test_points.py +++ b/tests/test_points.py @@ -33,6 +33,10 @@ def test_create_points(vtk_points, np_points): assert points == vtk_points assert points == np_points +def test_compare_points(points, np_points): + points2 = Points(np_points) + assert points == points2 + def test_xyz_points(points, np_points): assert np.array_equal(points.x, np_points[:, 0]) assert np.array_equal(points.y, np_points[:, 1]) @@ -127,4 +131,4 @@ def test_div(points, np_points): assert points == np_points assert points.GetPoint(0) == (.0, .1, .2) assert points.GetPoint(1) == (.1, .2, .3) - assert points.GetPoint(2) == (.2, .3, .4) \ No newline at end of file + assert points.GetPoint(2) == (.2, .3, .4) diff --git a/vtky/Points.py b/vtky/Points.py index 7567d30..b4ba73f 100644 --- a/vtky/Points.py +++ b/vtky/Points.py @@ -35,6 +35,8 @@ def __init__(self, array, array_name='Points'): def __eq__(self, other): if isinstance(other, vtk.vtkPoints): return self._base_array == other.GetData() + if isinstance(other, Points): + return self._base_array == other.xyz return self._base_array == other def _do_operation(self, other, operation):