Skip to content

Commit

Permalink
Merge pull request #507 from helmholtz-analytics/Bug/492-scalar_split
Browse files Browse the repository at this point in the history
bug distributed scalar
  • Loading branch information
Markus-Goetz committed Apr 3, 2020
2 parents 23d4be7 + 99233e9 commit 95fda46
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
- [#496](https://github.com/helmholtz-analytics/heat/pull/496) New feature: flipud()
- [#498](https://github.com/helmholtz-analytics/heat/pull/498) Feature: flip()
- [#499](https://github.com/helmholtz-analytics/heat/pull/499) Bugfix: MPI datatype mapping: `torch.int16` now maps to `MPI.SHORT` instead of `MPI.SHORT_INT`
- [#507](https://github.com/helmholtz-analytics/heat/pull/507) Bugfix: sanitize_axis changes axis of 0-dim scalars to None
- [#515](https://github.com/helmholtz-analytics/heat/pull/515) ht.var() now returns the unadjusted sample variance by default, Bessel's correction can be applied by setting ddof=1.
- [#519](https://github.com/helmholtz-analytics/heat/pull/519) Bugfix: distributed slicing with empty list or scalar as input; distributed nonzero() of empty (local) tensor.
- [#521](https://github.com/helmholtz-analytics/heat/pull/521) Add documentation for the generic reduce_op in Heat's core
Expand Down
4 changes: 4 additions & 0 deletions heat/core/stride_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ def sanitize_axis(shape, axis):
>>> sanitize_axis((5, 4), 1.0)
TypeError
"""
# scalars are handled like unsplit matrices
if len(shape) == 0:
axis = None

if axis is not None:
if not isinstance(axis, int) and not isinstance(axis, tuple):
raise TypeError("axis must be None or int or tuple, but was {}".format(type(axis)))
Expand Down
1 change: 1 addition & 0 deletions heat/core/tests/test_stride_tricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def test_sanitize_axis(self):
self.assertEqual(ht.core.stride_tricks.sanitize_axis((5, 4, 4), (-2, -3)), (1, 0))
self.assertEqual(ht.core.stride_tricks.sanitize_axis((5, 4), 0), 0)
self.assertEqual(ht.core.stride_tricks.sanitize_axis((5, 4), None), None)
self.assertEqual(ht.core.stride_tricks.sanitize_axis(tuple(), 0), None)

# invalid types
with self.assertRaises(TypeError):
Expand Down

0 comments on commit 95fda46

Please sign in to comment.