diff --git a/CHANGELOG.md b/CHANGELOG.md
index 69fcbd0312..f373a718d5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,12 +1,14 @@
 # Pending additions
-
 ## Bug Fixes
 - [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension
+- [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `_reduce_op` when `axis` and `keepdim` were both set.
+- [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `min` and `max` where `DNDarray`s with empty processes could not be computed.
 
 ## Feature Additions
 ### Linear Algebra
 - [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
+- [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features: `norm`, `vector_norm`, `matrix_norm`
 
 ### Manipulations
 - [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll`
 - [#853](https://github.com/helmholtz-analytics/heat/pull/853) New feature: `swapaxes`
diff --git a/heat/core/_operations.py b/heat/core/_operations.py
index 7d155505d4..8a8e8ea835 100644
--- a/heat/core/_operations.py
+++ b/heat/core/_operations.py
@@ -419,7 +419,10 @@ def __reduce_op(
         else:
             output_shape = x.gshape
         for dim in axis:
-            partial = partial_op(partial, dim=dim, keepdim=True)
+            if not (
+                partial.shape.numel() == 0 and partial_op.__name__ in ("local_max", "local_min")
+            ):  # skip empty local tensors: max/min have no neutral element
+                partial = partial_op(partial, dim=dim, keepdim=True)
             output_shape = output_shape[:dim] + (1,) + output_shape[dim + 1 :]
         if not keepdim and not len(partial.shape) == 1:
             gshape_losedim = tuple(x.gshape[dim] for dim in range(len(x.gshape)) if dim not in axis)
@@ -439,7 +442,7 @@ def __reduce_op(
         balanced = True
         if x.comm.is_distributed():
             x.comm.Allreduce(MPI.IN_PLACE, partial, reduction_op)
-    elif axis is not None:
+    elif axis is not None and not keepdim:
         down_dims = len(tuple(dim for dim in axis if dim < x.split))
         split -= down_dims
         balanced = x.balanced
diff --git a/heat/core/linalg/basics.py b/heat/core/linalg/basics.py
index a031ec2c95..5c6225d7a0 100644
--- a/heat/core/linalg/basics.py
+++ b/heat/core/linalg/basics.py
@@ -6,20 +6,25 @@
 import torch
 import warnings
 
-from typing import List, Callable, Union, Optional
+from typing import List, Callable, Union, Optional, Tuple, Literal
 
 from ..communication import MPI
 from .. import arithmetics
+from .. import complex_math
+from .. import constants
 from .. import exponential
 from ..dndarray import DNDarray
 from .. import factories
 from .. import manipulations
+from .. import rounding
 from .. import sanitation
+from .. import statistics
 from .. import types
 
 __all__ = [
     "dot",
     "matmul",
+    "matrix_norm",
     "norm",
     "outer",
     "projection",
@@ -28,6 +33,7 @@
     "tril",
     "triu",
     "vecdot",
+    "vector_norm",
 ]
@@ -79,9 +85,9 @@ def dot(a: DNDarray, b: DNDarray, out: Optional[DNDarray] = None) -> Union[DNDar
         a.comm.Allreduce(MPI.IN_PLACE, ret, MPI.SUM)
 
         if out is not None:
-            out = ret.item()
+            out = DNDarray(ret, (), types.heat_type_of(ret), None, a.device, a.comm, True)
             return out
-        return ret.item()
+        return DNDarray(ret, (), types.heat_type_of(ret), None, a.device, a.comm, True)
     elif a.ndim <= 2 and b.ndim <= 2:
         # 2. If both a and b are 2-D arrays, it is matrix multiplication, but using matmul or a @ b is preferred.
        ret = matmul(a, b)
@@ -768,24 +774,277 @@ def matmul(a: DNDarray, b: DNDarray, allow_resplit: bool = False) -> DNDarray:
 DNDarray.__matmul__ = lambda self, other: matmul(self, other)
 
 
-def norm(a: DNDarray) -> float:
+def matrix_norm(
+    x: DNDarray,
+    axis: Optional[Tuple[int, int]] = None,
+    keepdims: bool = False,
+    ord: Optional[Union[int, Literal["fro", "nuc"]]] = None,
+) -> DNDarray:
     """
-    Return the vector norm (Frobenius norm) of vector ``a``.
+    Computes the matrix norm of an array.
 
     Parameters
     ----------
-    a : DNDarray
+    x : DNDarray
+        Input array
+    axis : tuple, optional
+        The two axes of the matrix. If `None`, `x` must be a matrix. Default: `None`
+    keepdims : bool, optional
+        Retains the reduced dimensions when `True`. Default: `False`
+    ord : int, 'fro', 'nuc', optional
+        The matrix norm order to compute. If `None`, the Frobenius norm (`'fro'`) is used. Default: `None`
+
+    See Also
+    --------
+    norm
+        Computes the vector or matrix norm of an array.
+    vector_norm
+        Computes the vector norm of an array.
+
+    Notes
+    -----
+    The following norms are supported:
+
+    =====  ============================
+    ord    norm for matrices
+    =====  ============================
+    None   Frobenius norm
+    'fro'  Frobenius norm
+    'nuc'  nuclear norm
+    inf    max(sum(abs(x), axis=1))
+    -inf   min(sum(abs(x), axis=1))
+    1      max(sum(abs(x), axis=0))
+    -1     min(sum(abs(x), axis=0))
+    =====  ============================
+
+    The following matrix norms are currently **not** supported:
+
+    =====  ============================
+    ord    norm for matrices
+    =====  ============================
+    2      largest singular value
+    -2     smallest singular value
+    =====  ============================
+
+    Raises
+    ------
+    TypeError
+        If `axis` is not a 2-tuple.
+    ValueError
+        If an invalid matrix norm is given or `x` is a vector.
+
+    Examples
+    --------
+    >>> ht.matrix_norm(ht.array([[1,2],[3,4]]))
+    DNDarray([[5.4772]], dtype=ht.float64, device=cpu:0, split=None)
+    >>> ht.matrix_norm(ht.array([[1,2],[3,4]]), keepdims=True, ord=-1)
+    DNDarray([[4.]], dtype=ht.float64, device=cpu:0, split=None)
+    """
+    sanitation.sanitize_in(x)
+
+    if x.ndim < 2:
+        raise ValueError("Cannot compute a matrix norm of a vector.")
+
+    if axis is None:
+        if x.ndim > 2:
+            raise ValueError("Cannot infer axis on arrays with more than two dimensions.")
+        else:
+            axis = (0, 1)
+
+    if (not isinstance(axis, tuple)) or len(axis) != 2:
+        raise TypeError("'axis' must be a 2-tuple.")
+
+    row_axis, col_axis = axis
+
+    if ord == 1:
+        if col_axis > row_axis and not keepdims:
+            col_axis -= 1
+        return statistics.max(
+            arithmetics.sum(rounding.abs(x), axis=row_axis, keepdim=keepdims),
+            axis=col_axis,
+            keepdim=keepdims,
+        )
+    elif ord == -1:
+        if col_axis > row_axis and not keepdims:
+            col_axis -= 1
+        return statistics.min(
+            arithmetics.sum(rounding.abs(x), axis=row_axis, keepdim=keepdims),
+            axis=col_axis,
+            keepdim=keepdims,
+        )
+    elif ord == 2:
+        raise NotImplementedError("The largest singular value can't be computed yet.")
+    elif ord == -2:
+        raise NotImplementedError("The smallest singular value can't be computed yet.")
+    elif ord == constants.inf:
+        if row_axis > col_axis and not keepdims:
+            row_axis -= 1
+        return statistics.max(
+            arithmetics.sum(rounding.abs(x), axis=col_axis, keepdim=keepdims),
+            axis=row_axis,
+            keepdim=keepdims,
+        )
+    elif ord == -constants.inf:
+        if row_axis > col_axis and not keepdims:
+            row_axis -= 1
+        return statistics.min(
+            arithmetics.sum(rounding.abs(x), axis=col_axis, keepdim=keepdims),
+            axis=row_axis,
+            keepdim=keepdims,
+        )
+    elif ord in [None, "fro"]:
+        return exponential.sqrt(
+            arithmetics.sum((complex_math.conj(x) * x).real, axis=axis, keepdim=keepdims)
+        )
+    elif ord == "nuc":
+        raise NotImplementedError("The nuclear norm can't be computed yet.")
+    else:
+        raise ValueError("Invalid norm order for matrices.")
+
+
+def norm(
+    x: DNDarray,
+    axis: Optional[Union[int, Tuple[int, int]]] = None,
+    keepdims: bool = False,
+    ord: Optional[Union[int, float, Literal["fro", "nuc"]]] = None,
+) -> DNDarray:
+    """
+    Return the vector or matrix norm of an array.
+
+    Parameters
+    ----------
+    x : DNDarray
-        Input vector
-    """  # noqa: D402
-    if not isinstance(a, DNDarray):
-        raise TypeError("a must be of type ht.DNDarray, but was {}".format(type(a)))
+        Input array
+    axis : int, tuple, optional
+        Axes along which to compute the norm. If an integer, the vector norm is used. If a 2-tuple, the matrix norm is used.
+        If `None`, it is inferred from the dimension of the array. Default: `None`
+    keepdims : bool, optional
+        Retains the reduced dimensions when `True`. Default: `False`
+    ord : int, float, inf, -inf, 'fro', 'nuc', optional
+        The norm order to compute (see Notes). Default: `None`
+
+    See Also
+    --------
+    vector_norm
+        Computes the vector norm of an array.
+    matrix_norm
+        Computes the matrix norm of an array.
+
+    Notes
+    -----
+    The following norms are supported:
+
+    =====  ============================  ==========================
+    ord    norm for matrices             norm for vectors
+    =====  ============================  ==========================
+    None   Frobenius norm                L2-norm (Euclidean)
+    'fro'  Frobenius norm                --
+    'nuc'  nuclear norm                  --
+    inf    max(sum(abs(x), axis=1))      max(abs(x))
+    -inf   min(sum(abs(x), axis=1))      min(abs(x))
+    0      --                            sum(x != 0)
+    1      max(sum(abs(x), axis=0))      L1-norm (Manhattan)
+    -1     min(sum(abs(x), axis=0))      1./sum(1./abs(x))
+    2      --                            L2-norm (Euclidean)
+    -2     --                            1./sqrt(sum(1./abs(x)**2))
+    other  --                            sum(abs(x)**ord)**(1./ord)
+    =====  ============================  ==========================
+
+    The following matrix norms are currently **not** supported:
+
+    =====  ============================
+    ord    norm for matrices
+    =====  ============================
+    2      largest singular value
+    -2     smallest singular value
+    =====  ============================
+
-    d = a ** 2
+    Raises
+    ------
+    ValueError
+        If `axis` has more than 2 elements.
+
-    for i in range(len(a.shape) - 1, -1, -1):
-        d = arithmetics.sum(d, axis=i)
+    Examples
+    --------
+    >>> from heat import linalg as LA
+    >>> a = ht.arange(9, dtype=ht.float) - 4
+    >>> a
+    DNDarray([-4., -3., -2., -1.,  0.,  1.,  2.,  3.,  4.], dtype=ht.float32, device=cpu:0, split=None)
+    >>> b = a.reshape((3, 3))
+    >>> b
+    DNDarray([[-4., -3., -2.],
+              [-1.,  0.,  1.],
+              [ 2.,  3.,  4.]], dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(a)
+    DNDarray(7.7460, dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(b)
+    DNDarray(7.7460, dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(b, ord='fro')
+    DNDarray(7.7460, dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(a, ord=float('inf'))
+    DNDarray([4.], dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(b, ord=ht.inf)
+    DNDarray([9.], dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(a, ord=-ht.inf)
+    DNDarray([0.], dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(b, ord=-ht.inf)
+    DNDarray([2.], dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(a, ord=1)
+    DNDarray([20.], dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(b, ord=1)
+    DNDarray([7.], dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(a, ord=-1)
+    DNDarray([0.], dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(b, ord=-1)
+    DNDarray([6.], dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(a, ord=2)
+    DNDarray(7.7460, dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(a, ord=-2)
+    DNDarray([0.], dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(a, ord=3)
+    DNDarray([5.8480], dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(a, ord=-3)
+    DNDarray([0.], dtype=ht.float32, device=cpu:0, split=None)
+    >>> c = ht.array([[ 1, 2, 3],
+                      [-1, 1, 4]])
+    >>> LA.norm(c, axis=0)
+    DNDarray([1.4142, 2.2361, 5.0000], dtype=ht.float64, device=cpu:0, split=None)
+    >>> LA.norm(c, axis=1)
+    DNDarray([3.7417, 4.2426], dtype=ht.float64, device=cpu:0, split=None)
+    >>> LA.norm(c, axis=1, ord=1)
+    DNDarray([6., 6.], dtype=ht.float64, device=cpu:0, split=None)
+    >>> m = ht.arange(8).reshape(2,2,2)
+    >>> LA.norm(m, axis=(1,2))
+    DNDarray([ 3.7417, 11.2250], dtype=ht.float32, device=cpu:0, split=None)
+    >>> LA.norm(m[0, :, :]), LA.norm(m[1, :, :])
+    (DNDarray(3.7417, dtype=ht.float32, device=cpu:0, split=None), DNDarray(11.2250, dtype=ht.float32, device=cpu:0, split=None))
+    """
+    sanitation.sanitize_in(x)
 
-    return exponential.sqrt(d).item()
+    ndim = x.ndim
+
+    if axis is None:
+        if ord is None or (ord == 2 and ndim == 1) or (ord == "fro" and ndim == 2):
+            x = x.flatten()
+            if types.issubdtype(x.dtype, types.complex):
+                sqnorm = dot(x.real, x.real) + dot(x.imag, x.imag)
+            else:
+                sqnorm = dot(x, x)
+            ret = exponential.sqrt(sqnorm)
+            if keepdims:
+                ret = ret.reshape(ndim * [1])
+            return ret
+        elif ndim == 2:
+            return matrix_norm(x, axis, keepdims, ord)
+        else:
+            return vector_norm(x, axis, keepdims, ord)
+
+    if isinstance(axis, int) or len(axis) == 1:
+        return vector_norm(x, axis, keepdims, ord)
+    elif len(axis) == 2:
+        return matrix_norm(x, axis, keepdims, ord)
+    else:
+        raise ValueError("Improper number of dimensions to norm.")
 
 
-DNDarray.norm: Callable[[DNDarray], float] = lambda self: norm(self)
+DNDarray.norm: Callable[[DNDarray], DNDarray] = lambda self: norm(self)
@@ -1691,3 +1950,95 @@ def vecdot(
         axis = m.ndim - 1
 
     return arithmetics.sum(m, axis=axis, keepdim=keepdim)
+
+
+def vector_norm(
+    x: DNDarray,
+    axis: Optional[Union[int, Tuple[int]]] = None,
+    keepdims: bool = False,
+    ord: Optional[Union[int, float]] = None,
+) -> DNDarray:
+    """
+    Computes the vector norm of an array.
+
+    Parameters
+    ----------
+    x : DNDarray
+        Input array
+    axis : int, tuple, optional
+        Axis along which to compute the vector norm. If `None`, `x` must be a vector. Default: `None`
+    keepdims : bool, optional
+        Retains the reduced dimension when `True`. Default: `False`
+    ord : int, float, optional
+        The norm order to compute. If `None`, the Euclidean norm (`2`) is used. Default: `None`
+
+    See Also
+    --------
+    norm
+        Computes the vector or matrix norm of an array.
+    matrix_norm
+        Computes the matrix norm of an array.
+
+    Notes
+    -----
+    The following norms are supported:
+
+    =====  ==========================
+    ord    norm for vectors
+    =====  ==========================
+    None   L2-norm (Euclidean)
+    inf    max(abs(x))
+    -inf   min(abs(x))
+    0      sum(x != 0)
+    1      L1-norm (Manhattan)
+    -1     1./sum(1./abs(x))
+    2      L2-norm (Euclidean)
+    -2     1./sqrt(sum(1./abs(x)**2))
+    other  sum(abs(x)**ord)**(1./ord)
+    =====  ==========================
+
+    Raises
+    ------
+    TypeError
+        If `axis` is not an integer or a 1-tuple.
+    ValueError
+        If an invalid vector norm is given.
+ + Examples + -------- + >>> ht.vector_norm(ht.array([1,2,3,4])) + DNDarray([5.4772], dtype=ht.float64, device=cpu:0, split=None) + >>> ht.vector_norm(ht.array([[1,2],[3,4]]), axis=0, ord=1) + DNDarray([[4., 6.]], dtype=ht.float64, device=cpu:0, split=None) + """ + sanitation.sanitize_in(x) + + if axis is None: + pass + elif isinstance(axis, tuple): + if len(axis) > 1: + raise TypeError("'axis' must be an integer or 1-tuple for vectors.") + else: + try: + axis = int(axis) + except Exception: + raise TypeError("'axis' must be an integer or 1-tuple for vectors.") + + if ord == constants.INF: + return statistics.max(rounding.abs(x), axis=axis, keepdim=keepdims) + elif ord == -constants.INF: + return statistics.min(rounding.abs(x), axis=axis, keepdim=keepdims) + elif ord == 0: + return arithmetics.sum(x != 0, axis=axis, keepdim=keepdims).astype(types.float) + elif ord == 1: + return arithmetics.sum(rounding.abs(x), axis=axis, keepdim=keepdims) + elif ord is None or ord == 2: + s = (complex_math.conj(x) * x).real + return exponential.sqrt(arithmetics.sum(s, axis=axis, keepdim=keepdims)) + elif isinstance(ord, str): + raise ValueError("Norm order {} is invalid for vectors".format(ord)) + else: + ret = arithmetics.pow(rounding.abs(x), ord) + ret = arithmetics.sum(ret, axis=axis, keepdim=keepdims) + ret = arithmetics.pow(ret, 1.0 / ord) + return ret diff --git a/heat/core/linalg/solver.py b/heat/core/linalg/solver.py index 9fd23fd343..e35ab25656 100644 --- a/heat/core/linalg/solver.py +++ b/heat/core/linalg/solver.py @@ -128,7 +128,7 @@ def lanczos( V[:, 0] = v0 for i in range(1, int(m)): beta = ht.norm(w) - if abs(beta) < 1e-10: + if ht.abs(beta) < 1e-10: # print("Lanczos breakdown in iteration {}".format(i)) # Lanczos Breakdown, pick a random vector to continue vr = ht.random.rand(n, dtype=A.dtype, split=V.split) diff --git a/heat/core/linalg/tests/test_basics.py b/heat/core/linalg/tests/test_basics.py index 466798cd4d..93a5948f36 100644 --- a/heat/core/linalg/tests/test_basics.py +++ b/heat/core/linalg/tests/test_basics.py @@ -534,28 +534,122 @@ def test_matmul(self): b = a.copy() a @ b - def test_norm(self): - a = ht.arange(9, dtype=ht.float32, split=0) - 4 - self.assertTrue( - ht.allclose(ht.linalg.norm(a), ht.float32(np.linalg.norm(a.numpy())).item(), atol=1e-5) - ) - a.resplit_(axis=None) - self.assertTrue( - ht.allclose(ht.linalg.norm(a), ht.float32(np.linalg.norm(a.numpy())).item(), atol=1e-5) - ) - - b = ht.array([[-4.0, -3.0, -2.0], [-1.0, 0.0, 1.0], [2.0, 3.0, 4.0]], split=0) - self.assertTrue( - ht.allclose(ht.linalg.norm(b), ht.float32(np.linalg.norm(b.numpy())).item(), atol=1e-5) - ) - b.resplit_(axis=1) - self.assertTrue( - ht.allclose(ht.linalg.norm(b), ht.float32(np.linalg.norm(b.numpy())).item(), atol=1e-5) - ) - + def test_matrix_norm(self): + a = ht.arange(9, dtype=ht.float) - 4 + b = a.reshape((3, 3)) + b0 = a.reshape((3, 3), new_split=0) + b1 = a.reshape((3, 3), new_split=1) + + # different ord + mn = ht.linalg.matrix_norm(b, ord="fro") + self.assertEqual(mn.split, b.split) + self.assertEqual(mn.dtype, b.dtype) + self.assertEqual(mn.device, b.device) + self.assertTrue(ht.allclose(mn, ht.array(7.745966692414834))) + + mn = ht.linalg.matrix_norm(b0, ord=1) + self.assertEqual(mn.split, b.split) + self.assertEqual(mn.dtype, b.dtype) + self.assertEqual(mn.device, b.device) + self.assertEqual(mn.item(), 7.0) + + mn = ht.linalg.matrix_norm(b0, ord=-1) + self.assertEqual(mn.split, b.split) + self.assertEqual(mn.dtype, b.dtype) + self.assertEqual(mn.device, b.device) + 
self.assertEqual(mn.item(), 6.0)
+
+        mn = ht.linalg.matrix_norm(b1)
+        self.assertEqual(mn.split, b.split)
+        self.assertEqual(mn.dtype, b.dtype)
+        self.assertEqual(mn.device, b.device)
+        self.assertTrue(ht.allclose(mn, ht.array(7.745966692414834)))
+
+        # higher dimension + different dtype
+        m = ht.arange(8).reshape(2, 2, 2)
+        mn = ht.linalg.matrix_norm(m, axis=(2, 1), ord=ht.inf)
+        self.assertEqual(mn.split, m.split)
+        self.assertEqual(mn.dtype, ht.float)
+        self.assertEqual(mn.device, m.device)
+        self.assertTrue(ht.equal(mn, ht.array([4.0, 12.0])))
+
+        mn = ht.linalg.matrix_norm(m, axis=(2, 1), ord=-ht.inf)
+        self.assertEqual(mn.split, m.split)
+        self.assertEqual(mn.dtype, ht.float)
+        self.assertEqual(mn.device, m.device)
+        self.assertTrue(ht.equal(mn, ht.array([2.0, 10.0])))
+
+        # too many axes to infer
+        with self.assertRaises(ValueError):
+            ht.linalg.matrix_norm(ht.ones((2, 2, 2)))
+        # bad axis
         with self.assertRaises(TypeError):
-            c = np.arange(9) - 4
-            ht.linalg.norm(c)
+            ht.linalg.matrix_norm(ht.ones((2, 2)), axis=1)
+        with self.assertRaises(TypeError):
+            ht.linalg.matrix_norm(ht.ones((2, 2)), axis=(1, 2, 3))
+        # bad array
+        with self.assertRaises(ValueError):
+            ht.linalg.matrix_norm(ht.array([1, 2, 3]))
+        # bad ord
+        with self.assertRaises(ValueError):
+            ht.linalg.matrix_norm(ht.ones((2, 2)), ord=3)
+        # Not implemented yet; SVD needed
+        with self.assertRaises(NotImplementedError):
+            ht.linalg.matrix_norm(ht.ones((2, 2)), ord=2)
+        with self.assertRaises(NotImplementedError):
+            ht.linalg.matrix_norm(ht.ones((2, 2)), ord=-2)
+        with self.assertRaises(NotImplementedError):
+            ht.linalg.matrix_norm(ht.ones((2, 2)), ord="nuc")
+
+    def test_norm(self):
+        a = ht.arange(9, dtype=ht.float) - 4
+        a0 = ht.array([1 + 1j, 2 - 2j, 0 + 1j, 2 + 1j], dtype=ht.complex64, split=0)
+        b = a.reshape((3, 3))
+        b0 = a.reshape((3, 3), new_split=0)
+        b1 = a.reshape((3, 3), new_split=1)
+
+        # vectors
+        gn = ht.linalg.norm(a, axis=0, ord=1)
+        self.assertEqual(gn.split, a.split)
+        self.assertEqual(gn.dtype, a.dtype)
+        self.assertEqual(gn.device, a.device)
+        self.assertEqual(gn.item(), 20.0)
+
+        # complex type
+        gn = ht.linalg.norm(a0, keepdims=True)
+        self.assertEqual(gn.split, None)
+        self.assertEqual(gn.dtype, ht.float)
+        self.assertEqual(gn.device, a0.device)
+        self.assertEqual(gn.item(), 4.0)
+
+        # matrices
+        gn = ht.linalg.norm(b, ord="fro")
+        self.assertEqual(gn.split, None)
+        self.assertEqual(gn.dtype, b.dtype)
+        self.assertEqual(gn.device, b.device)
+        self.assertTrue(ht.allclose(gn, ht.array(7.745966692414834)))
+
+        gn = ht.linalg.norm(b0, ord=ht.inf)
+        self.assertEqual(gn.split, None)
+        self.assertEqual(gn.dtype, b0.dtype)
+        self.assertEqual(gn.device, b0.device)
+        self.assertEqual(gn.item(), 9.0)
+
+        gn = ht.linalg.norm(b1, axis=(0,), ord=-ht.inf, keepdims=True)
+        self.assertEqual(gn.split, b1.split)
+        self.assertEqual(gn.dtype, b1.dtype)
+        self.assertEqual(gn.device, b1.device)
+        self.assertTrue(ht.equal(gn, ht.array([[1.0, 0.0, 1.0]])))
+
+        # higher dimension + different dtype
+        gn = ht.linalg.norm(ht.ones((3, 3, 3), dtype=ht.int), axis=(-2, -1))
+        self.assertEqual(gn.split, None)
+        self.assertEqual(gn.dtype, ht.float)
+        self.assertTrue(ht.equal(gn, ht.array([3.0, 3.0, 3.0])))
+
+        # bad axis
+        with self.assertRaises(ValueError):
+            ht.linalg.norm(ht.ones(2), axis=(0, 1, 2))
 
     def test_outer(self):
         # test outer, a and b local, different dtypes
@@ -1690,3 +1784,81 @@ def test_vecdot(self):
         self.assertEqual(c.dtype, ht.float32)
         self.assertEqual(c.device, a.device)
         self.assertTrue(ht.equal(c, ht.array([[8, 8, 8, 8]])))
+
+    def test_vector_norm(self):
+        a = ht.arange(9, dtype=ht.float) - 4
+        a_split = ht.arange(9, dtype=ht.float, split=0) - 4
+        b = a.reshape((3, 3))
+        b0 = ht.reshape(a, (3, 3), new_split=0)
+        b1 = ht.reshape(a, (3, 3), new_split=1)
+
+        # vector infinity norm
+        vn = ht.vector_norm(a, ord=ht.inf)
+        self.assertEqual(vn.split, a.split)
+        self.assertEqual(vn.dtype, a.dtype)
+        self.assertEqual(vn.device, a.device)
+        self.assertEqual(vn.item(), 4.0)
+
+        # vector 0-norm
+        vn = ht.vector_norm(a, ord=0)
+        self.assertEqual(vn.split, a.split)
+        self.assertEqual(vn.dtype, a.dtype)
+        self.assertEqual(vn.device, a.device)
+        self.assertEqual(vn.item(), 8.0)
+
+        # split vector, -infinity norm
+        vn = ht.vector_norm(a_split, ord=-ht.inf)
+        self.assertEqual(vn.split, a.split)
+        self.assertEqual(vn.dtype, a.dtype)
+        self.assertEqual(vn.device, a.device)
+        self.assertEqual(vn.item(), 0.0)
+
+        # matrix 1-norm, no axis
+        vn = ht.vector_norm(b, ord=1)
+        self.assertEqual(vn.split, b.split)
+        self.assertEqual(vn.dtype, b.dtype)
+        self.assertEqual(vn.device, b.device)
+        self.assertEqual(vn.item(), 20.0)
+
+        # split matrix, L2-norm along axis
+        vn = ht.vector_norm(b0, axis=1, ord=2)
+        self.assertEqual(vn.split, 0)
+        self.assertEqual(vn.dtype, b0.dtype)
+        self.assertEqual(vn.device, b0.device)
+        self.assertTrue(ht.allclose(vn, ht.array([5.38516481, 1.41421356, 5.38516481], split=0)))
+
+        # split matrix, 3-norm along axis with keepdims
+        vn = ht.vector_norm(b1, axis=1, keepdims=True, ord=3)
+        self.assertEqual(vn.split, None)
+        self.assertEqual(vn.dtype, b1.dtype)
+        self.assertEqual(vn.device, b1.device)
+        self.assertTrue(
+            ht.allclose(vn, ht.array([[4.62606501], [1.25992105], [4.62606501]], split=None))
+        )
+
+        # different dtype
+        vn = ht.linalg.vector_norm(ht.full((4, 4, 4), 1 + 1j, dtype=ht.complex64), axis=0, ord=4)
+        self.assertEqual(vn.split, None)
+        self.assertEqual(vn.dtype, ht.float)
+        self.assertTrue(
+            ht.equal(
+                vn,
+                ht.array(
+                    [
+                        [2.0, 2.0, 2.0, 2.0],
+                        [2.0, 2.0, 2.0, 2.0],
+                        [2.0, 2.0, 2.0, 2.0],
+                        [2.0, 2.0, 2.0, 2.0],
+                    ]
+                ),
+            )
+        )
+
+        # bad ord
+        with self.assertRaises(ValueError):
+            ht.vector_norm(ht.array([1, 2, 3]), ord="fro")
+        # bad axis
+        with self.assertRaises(TypeError):
+            ht.vector_norm(ht.array([1, 2, 3]), axis=(1, 2))
+        with self.assertRaises(TypeError):
+            ht.vector_norm(ht.array([1, 2, 3]), axis="r")
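
End-to-end usage of the API introduced by this patch, as a minimal sketch: the variable names (`a`, `b`, `fro`, `rows`, `mars`) are illustrative, the input values mirror the doctests above, and the commented results are taken from the patch's test expectations. For the distributed case it assumes an MPI launch such as `mpirun -n 2 python example.py`; on a single process the same code runs unchanged.

```python
import heat as ht

# 3x3 matrix, distributed across processes along its rows (split dimension 0)
a = ht.arange(9, dtype=ht.float) - 4
b = a.reshape((3, 3), new_split=0)

# norm() dispatches on `axis`: axis=None on a 2-D array -> Frobenius norm
fro = ht.linalg.norm(b)                         # DNDarray(7.7460, ...)

# an integer axis dispatches to vector_norm(): Euclidean norm of each row
rows = ht.linalg.vector_norm(b, axis=1, ord=2)  # [5.3852, 1.4142, 5.3852]

# a 2-tuple axis (or a plain 2-D array) dispatches to matrix_norm():
# ord=ht.inf is the maximum absolute row sum
mars = ht.linalg.matrix_norm(b, ord=ht.inf)     # 9.0

print(fro, rows, mars)
```

Note the deliberate design choice visible in the signatures: unlike `numpy.linalg.norm(x, ord, axis, keepdims)`, the new functions take `axis` as the second positional argument, so norm orders must be passed as `ord=...` (the doctests above follow this convention).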