Skip to content

Commit

Permalink
Merge pull request #877 from helmholtz-analytics/feature/337-determinant
Browse files Browse the repository at this point in the history
Feature/337 determinant
  • Loading branch information
coquelin77 committed Jan 24, 2022
2 parents 293d873 + 593ad7f commit 2867fe9
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 1 deletion.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
- [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
- [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features `norm`, `vector_norm`, `matrix_norm`
- [#850](https://github.com/helmholtz-analytics/heat/pull/850) New Feature `cross`
- [#877](https://github.com/helmholtz-analytics/heat/pull/877) New feature `det`

### Logical
- [#862](https://github.com/helmholtz-analytics/heat/pull/862) New feature `signbit`
### Manipulations
Expand Down
2 changes: 1 addition & 1 deletion heat/core/dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def __init__(
array: torch.Tensor,
gshape: Tuple[int, ...],
dtype: datatype,
split: int,
split: Union[int, None],
device: Device,
comm: Communication,
balanced: bool,
Expand Down
87 changes: 87 additions & 0 deletions heat/core/linalg/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

__all__ = [
"cross",
"det",
"dot",
"matmul",
"matrix_norm",
Expand Down Expand Up @@ -155,6 +156,92 @@ def cross(
return ret


def det(a: DNDarray) -> DNDarray:
"""
Returns the determinant of a square matrix.
Parameters
----------
a : DNDarray
A square matrix or a stack of matrices. Shape = (...,M,M)
Raises
------
RuntimeError
If the dtype of 'a' is not floating-point.
RuntimeError
If `a.ndim < 2` or if the length of the last two dimensions is not the same.
Examples
--------
>>> a = ht.array([[-2,-1,2],[2,1,4],[-3,3,-1]])
>>> ht.linalg.det(a)
DNDarray(54., dtype=ht.float64, device=cpu:0, split=None)
"""
sanitation.sanitize_in(a) # pragma: no cover

if a.ndim < 2:
raise RuntimeError("DNDarray must be at least two-dimensional.")

m, n = a.shape[-2:]
if m != n:
raise RuntimeError("Last two dimensions of the DNDarray must be square.")

if types.heat_type_is_exact(a.dtype):
raise RuntimeError("dtype of DNDarray must be floating-point.")

# no split in the square matrices
if not a.is_distributed() or a.split < a.ndim - 2:
data = torch.linalg.det(a.larray)
sp = None if not a.is_distributed() else a.split
return DNDarray(
data,
a.shape[:-2],
types.heat_type_of(data),
split=sp,
device=a.device,
comm=a.comm,
balanced=a.balanced,
)

acopy = a.copy()
acopy = manipulations.reshape(acopy, (-1, m, m), new_split=a.split - a.ndim + 3)
adet = factories.ones(acopy.shape[0], dtype=a.dtype, device=a.device)

for k in range(adet.shape[0]):
m = 0
for i in range(n):
# partial pivoting
if np.isclose(acopy[k, i, i].item(), 0):
abord = True
for j in range(i + 1, n):
if not np.isclose(acopy[k, j, i].item(), 0):
if a.split == a.ndim - 2: # split=0 on square matrix
acopy[k, i, :], acopy[k, j, :] = acopy[k, j, :], acopy[k, i, :].copy()
else: # split=1
acopy.larray[k, i, :], acopy.larray[k, j, :] = (
acopy.larray[k, j, :],
acopy.larray[k, i, :].clone(),
)
abord = False
m += 1
break
if abord:
adet[k] = 0
break

adet[k] *= acopy[k, i, i]
z = acopy[k, i + 1 :, i, None].larray / acopy[k, i, i].item()
acopy[k, i + 1 :, :].larray -= z * acopy[k, i, :].larray

if m % 2 != 0:
adet[k] = -adet[k]

adet = manipulations.reshape(adet, a.shape[:-2])

return adet


def dot(a: DNDarray, b: DNDarray, out: Optional[DNDarray] = None) -> Union[DNDarray, float]:
"""
Returns the dot product of two ``DNDarrays``.
Expand Down
78 changes: 78 additions & 0 deletions heat/core/linalg/tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,84 @@ def test_cross(self):
with self.assertRaises(ValueError):
ht.cross(ht.eye(3, split=0), ht.eye(3, split=0), axis=0)

def test_det(self):
# (3,3) with pivoting
ares = ht.array(54.0)
a = ht.array([[-2.0, -1, 2], [2, 1, 4], [-3, 3, -1]], split=0, dtype=ht.double)
adet = ht.linalg.det(a)

self.assertTupleEqual(adet.shape, ares.shape)
self.assertIsNone(adet.split)
self.assertEqual(adet.dtype, a.dtype)
self.assertEqual(adet.device, a.device)
self.assertTrue(ht.equal(adet, ares))

a = ht.array([[-2.0, -1, 2], [2, 1, 4], [-3, 3, -1]], split=1, dtype=ht.double)
adet = ht.linalg.det(a)

self.assertTupleEqual(adet.shape, ares.shape)
self.assertIsNone(adet.split)
self.assertEqual(adet.dtype, a.dtype)
self.assertEqual(adet.device, a.device)
self.assertTrue(ht.equal(adet, ares))

# det==0
ares = ht.array(0.0)
a = ht.array([[0, 0, 0], [2, 1, 4], [-3, 3, -1]], dtype=ht.float64, split=0)
adet = ht.linalg.det(a)

self.assertTupleEqual(adet.shape, ares.shape)
self.assertIsNone(adet.split)
self.assertEqual(adet.dtype, a.dtype)
self.assertEqual(adet.device, a.device)
self.assertTrue(ht.equal(adet, ares))

# (3,2,2)
ares = ht.array([-2.0, -3.0, -8.0])

a = ht.array([[[1.0, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]])
adet = ht.linalg.det(a)

self.assertTupleEqual(adet.shape, ares.shape)
self.assertIsNone(adet.split)
self.assertEqual(adet.dtype, a.dtype)
self.assertEqual(adet.device, a.device)
self.assertTrue(ht.allclose(adet, ares))

a = ht.array([[[1.0, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]], split=0)
adet = ht.linalg.det(a)

self.assertTupleEqual(adet.shape, ares.shape)
self.assertEqual(adet.split, a.split if a.is_distributed() else None)
self.assertEqual(adet.dtype, a.dtype)
self.assertEqual(adet.device, a.device)
self.assertTrue(ht.allclose(adet, ares))

a = ht.array([[[1.0, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]], split=1)
adet = ht.linalg.det(a)

self.assertTupleEqual(adet.shape, ares.shape)
self.assertIsNone(adet.split)
self.assertEqual(adet.dtype, a.dtype)
self.assertEqual(adet.device, a.device)
self.assertTrue(ht.allclose(adet, ares))

a = ht.array([[[1.0, 2], [3, 4]], [[1, 2], [2, 1]], [[1, 3], [3, 1]]], split=2)
adet = ht.linalg.det(a)

self.assertTupleEqual(adet.shape, ares.shape)
self.assertIsNone(adet.split)
self.assertEqual(adet.dtype, a.dtype)
self.assertEqual(adet.device, a.device)
self.assertTrue(ht.allclose(adet, ares))

with self.assertRaises(RuntimeError):
ht.linalg.det(ht.array([1, 2, 3], split=0))
with self.assertRaises(RuntimeError):
ht.linalg.det(ht.zeros((2, 2, 3), split=1))
with self.assertRaises(RuntimeError):
ht.linalg.det(ht.zeros((2, 2), dtype=ht.int, split=0))

def test_dot(self):
# ONLY TESTING CORRECTNESS! ALL CALLS IN DOT ARE PREVIOUSLY TESTED
# cases to test:
Expand Down

0 comments on commit 2867fe9

Please sign in to comment.