diff --git a/CHANGELOG.md b/CHANGELOG.md index 37ea979835..ebe3a2d5b5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -27,6 +27,8 @@ - [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()` - [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features `norm`, `vector_norm`, `matrix_norm` - [#850](https://github.com/helmholtz-analytics/heat/pull/850) New Feature `cross` +- [#877](https://github.com/helmholtz-analytics/heat/pull/877) New feature `det` + ### Logical - [#862](https://github.com/helmholtz-analytics/heat/pull/862) New feature `signbit` ### Manipulations diff --git a/heat/core/dndarray.py b/heat/core/dndarray.py index d968ce2fe6..2580d09090 100644 --- a/heat/core/dndarray.py +++ b/heat/core/dndarray.py @@ -65,7 +65,7 @@ def __init__( array: torch.Tensor, gshape: Tuple[int, ...], dtype: datatype, - split: int, + split: Union[int, None], device: Device, comm: Communication, balanced: bool, diff --git a/heat/core/linalg/basics.py b/heat/core/linalg/basics.py index 39589f6967..a1d339af35 100644 --- a/heat/core/linalg/basics.py +++ b/heat/core/linalg/basics.py @@ -26,6 +26,7 @@ __all__ = [ "cross", + "det", "dot", "matmul", "matrix_norm", @@ -155,6 +156,92 @@ def cross( return ret +def det(a: DNDarray) -> DNDarray: + """ + Returns the determinant of a square matrix. + + Parameters + ---------- + a : DNDarray + A square matrix or a stack of matrices. Shape = (...,M,M) + + Raises + ------ + RuntimeError + If the dtype of 'a' is not floating-point. + RuntimeError + If `a.ndim < 2` or if the length of the last two dimensions is not the same. + + Examples + -------- + >>> a = ht.array([[-2,-1,2],[2,1,4],[-3,3,-1]]) + >>> ht.linalg.det(a) + DNDarray(54., dtype=ht.float64, device=cpu:0, split=None) + """ + sanitation.sanitize_in(a) # pragma: no cover + + if a.ndim < 2: + raise RuntimeError("DNDarray must be at least two-dimensional.") + + m, n = a.shape[-2:] + if m != n: + raise RuntimeError("Last two dimensions of the DNDarray must be square.") + + if types.heat_type_is_exact(a.dtype): + raise RuntimeError("dtype of DNDarray must be floating-point.") + + # no split in the square matrices + if not a.is_distributed() or a.split < a.ndim - 2: + data = torch.linalg.det(a.larray) + sp = None if not a.is_distributed() else a.split + return DNDarray( + data, + a.shape[:-2], + types.heat_type_of(data), + split=sp, + device=a.device, + comm=a.comm, + balanced=a.balanced, + ) + + acopy = a.copy() + acopy = manipulations.reshape(acopy, (-1, m, m), new_split=a.split - a.ndim + 3) + adet = factories.ones(acopy.shape[0], dtype=a.dtype, device=a.device) + + for k in range(adet.shape[0]): + m = 0 + for i in range(n): + # partial pivoting + if np.isclose(acopy[k, i, i].item(), 0): + abord = True + for j in range(i + 1, n): + if not np.isclose(acopy[k, j, i].item(), 0): + if a.split == a.ndim - 2: # split=0 on square matrix + acopy[k, i, :], acopy[k, j, :] = acopy[k, j, :], acopy[k, i, :].copy() + else: # split=1 + acopy.larray[k, i, :], acopy.larray[k, j, :] = ( + acopy.larray[k, j, :], + acopy.larray[k, i, :].clone(), + ) + abord = False + m += 1 + break + if abord: + adet[k] = 0 + break + + adet[k] *= acopy[k, i, i] + z = acopy[k, i + 1 :, i, None].larray / acopy[k, i, i].item() + acopy[k, i + 1 :, :].larray -= z * acopy[k, i, :].larray + + if m % 2 != 0: + adet[k] = -adet[k] + + adet = manipulations.reshape(adet, a.shape[:-2]) + + return adet + + def dot(a: DNDarray, b: DNDarray, out: Optional[DNDarray] = None) -> Union[DNDarray, float]: """ Returns the dot product of two ``DNDarrays``. diff --git a/heat/core/linalg/tests/test_basics.py b/heat/core/linalg/tests/test_basics.py index ff034da3af..fd32f99502 100644 --- a/heat/core/linalg/tests/test_basics.py +++ b/heat/core/linalg/tests/test_basics.py @@ -90,6 +90,84 @@ def test_cross(self): with self.assertRaises(ValueError): ht.cross(ht.eye(3, split=0), ht.eye(3, split=0), axis=0) + def test_det(self): + # (3,3) with pivoting + ares = ht.array(54.0) + a = ht.array([[-2.0, -1, 2], [2, 1, 4], [-3, 3, -1]], split=0, dtype=ht.double) + adet = ht.linalg.det(a) + + self.assertTupleEqual(adet.shape, ares.shape) + self.assertIsNone(adet.split) + self.assertEqual(adet.dtype, a.dtype) + self.assertEqual(adet.device, a.device) + self.assertTrue(ht.equal(adet, ares)) + + a = ht.array([[-2.0, -1, 2], [2, 1, 4], [-3, 3, -1]], split=1, dtype=ht.double) + adet = ht.linalg.det(a) + + self.assertTupleEqual(adet.shape, ares.shape) + self.assertIsNone(adet.split) + self.assertEqual(adet.dtype, a.dtype) + self.assertEqual(adet.device, a.device) + self.assertTrue(ht.equal(adet, ares)) + + # det==0 + ares = ht.array(0.0) + a = ht.array([[0, 0, 0], [2, 1, 4], [-3, 3, -1]], dtype=ht.float64, split=0) + adet = ht.linalg.det(a) + + self.assertTupleEqual(adet.shape, ares.shape) + self.assertIsNone(adet.split) + self.assertEqual(adet.dtype, a.dtype) + self.assertEqual(adet.device, a.device) + self.assertTrue(ht.equal(adet, ares)) + + # (3,2,2) + ares = ht.array([-2.0, -3.0, -8.0]) + + a = ht.array([[[1.0, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]]) + adet = ht.linalg.det(a) + + self.assertTupleEqual(adet.shape, ares.shape) + self.assertIsNone(adet.split) + self.assertEqual(adet.dtype, a.dtype) + self.assertEqual(adet.device, a.device) + self.assertTrue(ht.allclose(adet, ares)) + + a = ht.array([[[1.0, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]], split=0) + adet = ht.linalg.det(a) + + self.assertTupleEqual(adet.shape, ares.shape) + self.assertEqual(adet.split, a.split if a.is_distributed() else None) + self.assertEqual(adet.dtype, a.dtype) + self.assertEqual(adet.device, a.device) + self.assertTrue(ht.allclose(adet, ares)) + + a = ht.array([[[1.0, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]], split=1) + adet = ht.linalg.det(a) + + self.assertTupleEqual(adet.shape, ares.shape) + self.assertIsNone(adet.split) + self.assertEqual(adet.dtype, a.dtype) + self.assertEqual(adet.device, a.device) + self.assertTrue(ht.allclose(adet, ares)) + + a = ht.array([[[1.0, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]], split=2) + adet = ht.linalg.det(a) + + self.assertTupleEqual(adet.shape, ares.shape) + self.assertIsNone(adet.split) + self.assertEqual(adet.dtype, a.dtype) + self.assertEqual(adet.device, a.device) + self.assertTrue(ht.allclose(adet, ares)) + + with self.assertRaises(RuntimeError): + ht.linalg.det(ht.array([1, 2, 3], split=0)) + with self.assertRaises(RuntimeError): + ht.linalg.det(ht.zeros((2, 2, 3), split=1)) + with self.assertRaises(RuntimeError): + ht.linalg.det(ht.zeros((2, 2), dtype=ht.int, split=0)) + def test_dot(self): # ONLY TESTING CORRECTNESS! ALL CALLS IN DOT ARE PREVIOUSLY TESTED # cases to test: