Skip to content

Commit

Permalink
int64 mm tests
Browse files Browse the repository at this point in the history
  • Loading branch information
coquelin77 committed Dec 4, 2020
1 parent 85505c3 commit 48004d8
Showing 1 changed file with 78 additions and 0 deletions.
78 changes: 78 additions & 0 deletions heat/core/linalg/tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,19 @@ def test_matmul(self):
self.assertEqual(ret00.dtype, ht.float)
self.assertEqual(ret00.split, None)

a = ht.ones((n, m), split=None, dtype=ht.int64)
b = ht.ones((j), split=None, dtype=ht.int64)
a[0] = ht.arange(1, m + 1, dtype=ht.int64)
a[:, -1] = ht.arange(1, n + 1, dtype=ht.int64)
ret00 = ht.matmul(a, b)

ret_comp = ht.array((a_torch @ b_torch), split=None)
self.assertTrue(ht.equal(ret00, ret_comp))
self.assertIsInstance(ret00, ht.DNDarray)
self.assertEqual(ret00.shape, (n,))
self.assertEqual(ret00.dtype, ht.int64)
self.assertEqual(ret00.split, None)

# splits 0 None
a = ht.ones((n, m), split=0)
b = ht.ones((j), split=None)
Expand All @@ -395,6 +408,19 @@ def test_matmul(self):
self.assertEqual(ret00.dtype, ht.float)
self.assertEqual(ret00.split, 0)

a = ht.ones((n, m), split=0, dtype=ht.int64)
b = ht.ones((j), split=None, dtype=ht.int64)
a[0] = ht.arange(1, m + 1, dtype=ht.int64)
a[:, -1] = ht.arange(1, n + 1, dtype=ht.int64)
ret00 = ht.matmul(a, b)

ret_comp = ht.array((a_torch @ b_torch), split=None)
self.assertTrue(ht.equal(ret00, ret_comp))
self.assertIsInstance(ret00, ht.DNDarray)
self.assertEqual(ret00.shape, (n,))
self.assertEqual(ret00.dtype, ht.int64)
self.assertEqual(ret00.split, 0)

# splits 1 None
a = ht.ones((n, m), split=1)
b = ht.ones((j), split=None)
Expand All @@ -409,6 +435,19 @@ def test_matmul(self):
self.assertEqual(ret00.dtype, ht.float)
self.assertEqual(ret00.split, 0)

a = ht.ones((n, m), split=1, dtype=ht.int64)
b = ht.ones((j), split=None, dtype=ht.int64)
a[0] = ht.arange(1, m + 1, dtype=ht.int64)
a[:, -1] = ht.arange(1, n + 1, dtype=ht.int64)
ret00 = ht.matmul(a, b)

ret_comp = ht.array((a_torch @ b_torch), split=None)
self.assertTrue(ht.equal(ret00, ret_comp))
self.assertIsInstance(ret00, ht.DNDarray)
self.assertEqual(ret00.shape, (n,))
self.assertEqual(ret00.dtype, ht.int64)
self.assertEqual(ret00.split, 0)

# splits None 0
a = ht.ones((n, m), split=None)
b = ht.ones((j), split=0)
Expand All @@ -423,6 +462,19 @@ def test_matmul(self):
self.assertEqual(ret00.dtype, ht.float)
self.assertEqual(ret00.split, 0)

a = ht.ones((n, m), split=None, dtype=ht.int64)
b = ht.ones((j), split=0, dtype=ht.int64)
a[0] = ht.arange(1, m + 1, dtype=ht.int64)
a[:, -1] = ht.arange(1, n + 1, dtype=ht.int64)
ret00 = ht.matmul(a, b)

ret_comp = ht.array((a_torch @ b_torch), split=None)
self.assertTrue(ht.equal(ret00, ret_comp))
self.assertIsInstance(ret00, ht.DNDarray)
self.assertEqual(ret00.shape, (n,))
self.assertEqual(ret00.dtype, ht.int64)
self.assertEqual(ret00.split, 0)

# splits 0 0
a = ht.ones((n, m), split=0)
b = ht.ones((j), split=0)
Expand All @@ -437,6 +489,19 @@ def test_matmul(self):
self.assertEqual(ret00.dtype, ht.float)
self.assertEqual(ret00.split, 0)

a = ht.ones((n, m), split=0, dtype=ht.int64)
b = ht.ones((j), split=0, dtype=ht.int64)
a[0] = ht.arange(1, m + 1, dtype=ht.int64)
a[:, -1] = ht.arange(1, n + 1, dtype=ht.int64)
ret00 = ht.matmul(a, b)

ret_comp = ht.array((a_torch @ b_torch), split=None)
self.assertTrue(ht.equal(ret00, ret_comp))
self.assertIsInstance(ret00, ht.DNDarray)
self.assertEqual(ret00.shape, (n,))
self.assertEqual(ret00.dtype, ht.int64)
self.assertEqual(ret00.split, 0)

# splits 1 0
a = ht.ones((n, m), split=1)
b = ht.ones((j), split=0)
Expand All @@ -451,6 +516,19 @@ def test_matmul(self):
self.assertEqual(ret00.dtype, ht.float)
self.assertEqual(ret00.split, 0)

a = ht.ones((n, m), split=1, dtype=ht.int64)
b = ht.ones((j), split=0, dtype=ht.int64)
a[0] = ht.arange(1, m + 1, dtype=ht.int64)
a[:, -1] = ht.arange(1, n + 1, dtype=ht.int64)
ret00 = ht.matmul(a, b)

ret_comp = ht.array((a_torch @ b_torch), split=None)
self.assertTrue(ht.equal(ret00, ret_comp))
self.assertIsInstance(ret00, ht.DNDarray)
self.assertEqual(ret00.shape, (n,))
self.assertEqual(ret00.dtype, ht.int64)
self.assertEqual(ret00.split, 0)

with self.assertRaises(NotImplementedError):
a = ht.zeros((3, 3, 3), split=2)
b = a.copy()
Expand Down

0 comments on commit 48004d8

Please sign in to comment.