Skip to content
Permalink
Browse files

fix tests for tf

  • Loading branch information
nguigs authored and ninamiolane committed Mar 22, 2020
1 parent 3f1ec06 commit 3ea3e47934de399eaccecd8aa438ff3e113ec1f6
Showing with 7 additions and 3 deletions.
  1. +4 −0 geomstats/backend/tensorflow/linalg.py
  2. +3 −3 tests/test_backends.py
@@ -78,3 +78,7 @@ def qr_aux(x, mode):
dtype=(tf.float32, tf.float32))

return qr


def powerm(x, power):
return expm(power * logm(x))
@@ -161,7 +161,7 @@ def test_expm_and_logm_vectorization(self):

self.assertAllClose(result, expected)

@geomstats.tests.np_only
@geomstats.tests.np_and_tf_only
def test_powerm_diagonal(self):
power = .5
point = gs.array([[1., 0., 0.],
@@ -174,7 +174,7 @@ def test_powerm_diagonal(self):

self.assertAllClose(result, expected)

@geomstats.tests.np_only
@geomstats.tests.np_and_tf_only
def test_powerm(self):
power = 2.4
point = gs.array([[1., 0., 0.],
@@ -196,7 +196,7 @@ def test_powerm_vectorization(self):
[0., 2.5, 1.5],
[0., 1.5, 2.5]]])
result = gs.linalg.powerm(points, power)
result = gs.linalg.powerm(result, 1 / power)
result = gs.linalg.powerm(result, 1. / power)
expected = points

self.assertAllClose(result, expected)

0 comments on commit 3ea3e47

Please sign in to comment.
You can’t perform that action at this time.