norm implementation #846

Merged · 15 commits · Aug 20, 2021
CHANGELOG.md (3 additions & 1 deletion)
@@ -1,12 +1,14 @@
# Pending additions

## Bug Fixes
- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values that have a different shape in the split dimension.
- [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `_reduce_op` when `axis` and `keepdim` were both set.
- [#846](https://github.com/helmholtz-analytics/heat/pull/846) Fixed an issue in `min` and `max` where `DNDarray`s with empty processes could not be computed.
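
Taken together, the two `_reduce_op` fixes mean that reductions with `axis` and `keepdim` return the documented shape, and that `min`/`max` no longer fail when a process holds an empty slice of the split array. A minimal sketch of the repaired behaviour (expected shapes inferred from the `keepdim` semantics; run with e.g. `mpirun -n 4`):

```python
import heat as ht

# Distributed array; with enough processes, some may hold no local data.
x = ht.random.randn(4, 5, split=0)

# axis together with keepdim now keeps the reduced axis as size 1 ...
print(ht.max(x, axis=0, keepdim=True).shape)  # expected: (1, 5)
# ... and min/max work even when a process's local slice is empty.
print(ht.min(x, axis=1, keepdim=True).shape)  # expected: (4, 1)
```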

## Feature Additions

### Linear Algebra
- [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
- [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features: `norm`, `vector_norm`, `matrix_norm`
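
A short usage sketch of the new routines. The namespaces and defaults assumed here (`ht.linalg.*`, Frobenius as the default matrix norm) follow the NumPy-style convention and should be checked against the merged docs:

```python
import heat as ht

v = ht.array([3.0, 4.0], split=0)
print(ht.linalg.vector_norm(v))   # Euclidean norm: 5.0

m = ht.array([[1.0, 2.0], [3.0, 4.0]], split=0)
print(ht.linalg.matrix_norm(m))   # Frobenius norm: sqrt(30), assumed default
print(ht.linalg.norm(m))          # general entry point
```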
### Manipulations
- [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll`
- [#853](https://github.com/helmholtz-analytics/heat/pull/853) New feature: `swapaxes`
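
For completeness, a sketch of the two manipulation routines (NumPy-style signatures assumed):

```python
import heat as ht

a = ht.arange(6, split=0).reshape((2, 3))
print(ht.roll(a, 1, axis=1))   # cyclic shift of the columns by one
print(ht.swapaxes(a, 0, 1))    # exchange axes 0 and 1 (transpose)
```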
heat/core/_operations.py (5 additions & 2 deletions)
@@ -419,7 +419,10 @@ def __reduce_op(
         else:
             output_shape = x.gshape
         for dim in axis:
-            partial = partial_op(partial, dim=dim, keepdim=True)
+            if not (
+                partial.shape.numel() == 0 and partial_op.__name__ in ("local_max", "local_min")
+            ):  # no neutral element for max/min
+                partial = partial_op(partial, dim=dim, keepdim=True)
             output_shape = output_shape[:dim] + (1,) + output_shape[dim + 1 :]
         if not keepdim and not len(partial.shape) == 1:
             gshape_losedim = tuple(x.gshape[dim] for dim in range(len(x.gshape)) if dim not in axis)
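
The guard exists because a process that holds no slice of the split array has an empty local tensor. Summation is still well defined there (neutral element 0), but `max`/`min` have no neutral element, so reducing an empty tensor with them fails at the PyTorch level. A minimal reproduction of the underlying behaviour:

```python
import torch

empty = torch.empty((0, 3))

# sum over an empty dimension is well defined (neutral element 0):
print(torch.sum(empty, dim=0, keepdim=True))  # tensor([[0., 0., 0.]])

# max/min have no neutral element, so PyTorch raises instead:
try:
    torch.max(empty, dim=0, keepdim=True)
except (RuntimeError, IndexError) as err:
    print(err)  # e.g. "max(): Expected reduction dim 0 to have non-zero size"
```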
@@ -439,7 +442,7 @@ def __reduce_op(
         balanced = True
         if x.comm.is_distributed():
             x.comm.Allreduce(MPI.IN_PLACE, partial, reduction_op)
-        elif axis is not None:
+        elif axis is not None and not keepdim:
             down_dims = len(tuple(dim for dim in axis if dim < x.split))
             split -= down_dims
             balanced = x.balanced
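
The added `not keepdim` condition matters because the bookkeeping below it (`split -= down_dims`) compensates for reduced axes disappearing from the shape. With `keepdim=True` no axis disappears, so the split index must not be shifted. A sketch of the intended post-fix behaviour (illustrative values, assuming a distributed run):

```python
import heat as ht

x = ht.ones((4, 5), split=1)  # distributed along axis 1

# Without keepdim, axis 0 vanishes and the split axis shifts from 1 to 0:
print(ht.sum(x, axis=0).split)                # expected: 0
# With keepdim, all axes survive, so the split axis stays at 1:
print(ht.sum(x, axis=0, keepdim=True).split)  # expected: 1
```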