Features/880 binop ben bou (#902)
* Support different lshape maps in binary ops

* Adapt documentation

* Typos

* Update changelog

* Debugging tests

* Allow unbalanced and unequally balanced dndarrays for binops

* determine promoted_type in the beginning

* Refine distribution logic for non-distributed operands

* Change tests order

* Typo

* typo

* No broadcasting in split-dimension

* Remove redundant condition

* explicit broadcasting in split-dimension

* debugging

* improve tests

* fix empty process bug

* remove debugging output

* empty process handling; run tests

* typo

* move the equalization of distribution to sanitation

* beautify

* wrong condition

* use None-indexing, fix broadcast_shapes

* Fix broadcast_shapes

* old expand_dims

* debug

* allow broadcasting in split dimension

* torch.Tensor check is needed because torch.equal is wrongly directed to binop

* allow different lshape-map in sanitize-distribution

* docs

* directly use tuple

* Fix error in stack

* Fix heat.equal error

* Modify `stride_tricks.broadcast_shape` to return `torch.broadcast_shapes(shape_a, shape_b)`

* Backward compatibility for `stride_tricks.broadcast_shape`

* Debug GPU tests

* Debugging

* Debugging GPU tests

* Fix bug in debugging statement

* Debugging

* Fix `broadcast_shape` when one dimension is 0 for torch < 1.8

* Remove print statements

* add tests

* check if balanced when using diff_map. TODO: debug is_balanced; currently only works with force_check

* found typo, no need for force check

* code review + docs

* fix error

* docs

* changing the Tensor returned by create_lshape_map should not change the lshape_map attribute -> return a copy of the lshape_map

* add tests

* add tests

* Edit error messages for `sanitize_out`

* addressed review; mainly docs

Co-authored-by: Claudia Comito <c.comito@fz-juelich.de>
Co-authored-by: Ben Bourgart <b.bourgart@fz-juelich.de>
3 people committed Jan 31, 2022
1 parent 35f39d1 commit dbb8300
Showing 16 changed files with 578 additions and 259 deletions.
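In short, this change lets binary operations accept operands that share the same `split` axis but are distributed differently across processes. A minimal sketch of the use case (it mirrors the example added to the `__binary_op` docstring below and assumes the public heat API with `ht.ones`, `ht.zeros`, and slicing):

```python
import heat as ht

# Both operands are split along axis 0, but the slicing shifts their
# distribution maps against each other.
a = ht.ones(10000, split=0)
b = ht.zeros(10000, split=0)

# One operand is redistributed out of place so the element-wise
# addition can be evaluated with matching local chunks.
c = a[:-1] + b[1:]
print(c.shape)  # (9999,)
```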
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
@@ -37,7 +37,7 @@ These tools only profile the memory used by each process, not the entire function
 - with `split=None` and `split not None`
 Python has an embedded profiler: https://docs.python.org/3.9/library/profile.html
-Again, this will only provile the performance on each process. Printing the results with many processes
+Again, this will only profile the performance on each process. Printing the results with many processes
 my be illegible. It may be easiest to save the output of each to a file.
 --->

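As a side note on the profiling hint in the template: a minimal, standard-library way to run Python's built-in profiler per process and save the stats to a file, as the template suggests. This sketch is not heat-specific; the `work` function and the output file name are illustrative only.

```python
import cProfile
import os
import pstats

def work():
    # Stand-in for the code you actually want to profile.
    return sum(i * i for i in range(100_000))

# Write the stats to a per-process file (keyed by PID here; in an MPI job
# you could use the rank instead) so results do not interleave on stdout.
outfile = f"profile_{os.getpid()}.out"
cProfile.run("work()", filename=outfile)
pstats.Stats(outfile).sort_stats("cumulative").print_stats(5)
```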
6 changes: 4 additions & 2 deletions CHANGELOG.md
@@ -12,8 +12,9 @@
 - [#868](https://github.com/helmholtz-analytics/heat/pull/868) Fixed an issue in `__binary_op` where data was falsely distributed if a DNDarray has single element.

 ## Feature Additions
-### Linear Algebra
-- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot`
+
+### Arithmetics
+- - [#887](https://github.com/helmholtz-analytics/heat/pull/887) Binary operations now support operands of equal shapes, equal `split` axes, but different distribution maps.

 ## Feature additions
 ### Communication
@@ -25,6 +26,7 @@
 # Feature additions
 ### Linear Algebra
 - [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
+- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot`
 - [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features `norm`, `vector_norm`, `matrix_norm`
 - [#850](https://github.com/helmholtz-analytics/heat/pull/850) New Feature `cross`
 - [#877](https://github.com/helmholtz-analytics/heat/pull/877) New feature `det`
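Related to the `stride_tricks.broadcast_shape` commits above: the helper now delegates to `torch.broadcast_shapes` (with a fallback for torch < 1.8). A rough, heat-independent illustration of what that torch call computes:

```python
import torch

# NumPy-style broadcasting rules applied to plain shapes,
# without allocating any tensors.
print(torch.broadcast_shapes((5, 4), (4,)))        # torch.Size([5, 4])
print(torch.broadcast_shapes((1, 10000), (3, 1)))  # torch.Size([3, 10000])
```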
2 changes: 1 addition & 1 deletion heat/cluster/tests/test_kmedoids.py
@@ -50,7 +50,7 @@ def create_spherical_dataset(
 cluster4 = ht.stack((x - 2 * offset, y - 2 * offset, z - 2 * offset), axis=1)

 data = ht.concatenate((cluster1, cluster2, cluster3, cluster4), axis=0)
-# Note: enhance when shuffel is available
+# Note: enhance when shuffle is available
 return data

 def test_clusterer(self):
233 changes: 110 additions & 123 deletions heat/core/_operations.py
@@ -51,147 +51,134 @@ def __binary_op(
-------
result: ht.DNDarray
A DNDarray containing the results of element-wise operation.
Warning
-------
If both operands are distributed, they must be distributed along the same dimension, i.e. `t1.split = t2.split`.
MPI communication is necessary when both operands are distributed along the same dimension, but the distribution maps do not match. E.g.:
```
a = ht.ones(10000, split=0)
b = ht.zeros(10000, split=0)
c = a[:-1] + b[1:]
```
In such cases, one of the operands is redistributed OUT-OF-PLACE to match the distribution map of the other operand.
The operand determining the resulting distribution is chosen as follows:
1) split is preferred to no split
2) no (shape)-broadcasting in the split dimension if not necessary
3) t1 is preferred to t2
"""
# Check inputs
if not np.isscalar(t1) and not isinstance(t1, DNDarray):
raise TypeError(
"Only DNDarrays and numeric scalars are supported, but input was {}".format(type(t1))
)
if not np.isscalar(t2) and not isinstance(t2, DNDarray):
raise TypeError(
"Only DNDarrays and numeric scalars are supported, but input was {}".format(type(t2))
)
promoted_type = types.result_type(t1, t2).torch_type()

if np.isscalar(t1):
# Make inputs Dndarrays
if np.isscalar(t1) and np.isscalar(t2):
try:
t1 = factories.array(t1, device=t2.device if isinstance(t2, DNDarray) else None)
t1 = factories.array(t1)
t2 = factories.array(t2)
except (ValueError, TypeError):
raise TypeError("Data type not supported, input was {}".format(type(t1)))

if np.isscalar(t2):
try:
t2 = factories.array(t2)
except (ValueError, TypeError):
raise TypeError(
"Only numeric scalars are supported, but input was {}".format(type(t2))
)
output_shape = (1,)
output_split = None
output_device = t2.device
output_comm = MPI_WORLD
elif isinstance(t2, DNDarray):
output_shape = t2.shape
output_split = t2.split
output_device = t2.device
output_comm = t2.comm
else:
raise TypeError(
"Only tensors and numeric scalars are supported, but input was {}".format(type(t2))
)

if t1.dtype != t2.dtype:
t1 = t1.astype(t2.dtype)

elif isinstance(t1, DNDarray):
if np.isscalar(t2):
try:
t2 = factories.array(t2, device=t1.device)
output_shape = t1.shape
output_split = t1.split
output_device = t1.device
output_comm = t1.comm
except (ValueError, TypeError):
raise TypeError("Data type not supported, input was {}".format(type(t2)))

elif isinstance(t2, DNDarray):
if t1.split is None:
t1 = factories.array(
t1, split=t2.split, copy=False, comm=t1.comm, device=t1.device, ndmin=-t2.ndim
)
elif t2.split is None:
t2 = factories.array(
t2, split=t1.split, copy=False, comm=t2.comm, device=t2.device, ndmin=-t1.ndim
)
elif t1.split != t2.split:
# It is NOT possible to perform binary operations on tensors with different splits, e.g. split=0
# and split=1
raise NotImplementedError("Not implemented for other splittings")

output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape)
output_split = t1.split
output_device = t1.device
output_comm = t1.comm

if t1.split is not None:
if t1.shape[t1.split] == 1 and t1.comm.is_distributed():
# warnings.warn(
# "Broadcasting requires transferring data of first operator between MPI ranks!"
# )
color = 0 if t1.comm.rank < t2.shape[t1.split] else 1
newcomm = t1.comm.Split(color, t1.comm.rank)
if t1.comm.rank > 0 and color == 0:
t1.larray = torch.zeros(
t1.shape, dtype=t1.dtype.torch_type(), device=t1.device.torch_device
)
newcomm.Bcast(t1)
newcomm.Free()

if t2.split is not None:
if t2.shape[t2.split] == 1 and t2.comm.is_distributed():
# warnings.warn(
# "Broadcasting requires transferring data of second operator between MPI ranks!"
# )
color = 0 if t2.comm.rank < t1.shape[t2.split] else 1
newcomm = t2.comm.Split(color, t2.comm.rank)
if t2.comm.rank > 0 and color == 0:
t2.larray = torch.zeros(
t2.shape, dtype=t2.dtype.torch_type(), device=t2.device.torch_device
)
newcomm.Bcast(t2)
newcomm.Free()

else:
raise TypeError(
"Only tensors and numeric scalars are supported, but input was {}".format(type(t2))
)
else:
raise NotImplementedError("Not implemented for non scalar")

# sanitize output
if out is not None:
sanitation.sanitize_out(out, output_shape, output_split, output_device)

# promoted_type = types.promote_types(t1.dtype, t2.dtype).torch_type()
if t1.split is not None:
if len(t1.lshape) > t1.split and t1.lshape[t1.split] == 0:
result = t1.larray.type(promoted_type)
else:
result = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
"Data type not supported, inputs were {} and {}".format(type(t1), type(t2))
)
elif np.isscalar(t1) and isinstance(t2, DNDarray):
try:
t1 = factories.array(t1, device=t2.device, comm=t2.comm)
except (ValueError, TypeError):
raise TypeError("Data type not supported, input was {}".format(type(t1)))
elif isinstance(t1, DNDarray) and np.isscalar(t2):
try:
t2 = factories.array(t2, device=t1.device, comm=t1.comm)
except (ValueError, TypeError):
raise TypeError("Data type not supported, input was {}".format(type(t2)))

# Make inputs have the same dimensionality
output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape)
# Broadcasting allows additional empty dimensions on the left side
# TODO simplify this once newaxis-indexing is supported to get rid of the loops
while len(t1.shape) < len(output_shape):
t1 = t1.expand_dims(axis=0)
while len(t2.shape) < len(output_shape):
t2 = t2.expand_dims(axis=0)
# t1 = t1[tuple([None] * (len(output_shape) - t1.ndim))]
# t2 = t2[tuple([None] * (len(output_shape) - t2.ndim))]
# print(t1.lshape, t2.lshape)

def __get_out_params(target, other=None, map=None):
"""
Getter for the output parameters of a binary operation with target distribution.
If `other` is provided, its distribution will be matched to `target` or, if provided,
redistributed according to `map`.
Parameters
----------
target : DNDarray
DNDarray determining the parameters
other : DNDarray
DNDarray to be adapted
map : Tensor
lshape_map `other` should be matched to. Defaults to `target.lshape_map`
Returns
-------
Tuple
split, device, comm, balanced, [other]
"""
if other is not None:
if out is None:
other = sanitation.sanitize_distribution(other, target=target, diff_map=map)
return target.split, target.device, target.comm, target.balanced, other
return target.split, target.device, target.comm, target.balanced

if t1.split is not None and t1.shape[t1.split] == output_shape[t1.split]: # t1 is "dominant"
output_split, output_device, output_comm, output_balanced, t2 = __get_out_params(t1, t2)
elif t2.split is not None and t2.shape[t2.split] == output_shape[t2.split]: # t2 is "dominant"
output_split, output_device, output_comm, output_balanced, t1 = __get_out_params(t2, t1)
elif t1.split is not None:
# t1 is split but broadcast -> only on one rank; manipulate lshape_map s.t. this rank has 'full' data
lmap = t1.lshape_map
idx = lmap[:, t1.split].nonzero(as_tuple=True)[0]
lmap[idx.item(), t1.split] = output_shape[t1.split]
output_split, output_device, output_comm, output_balanced, t2 = __get_out_params(
t1, t2, map=lmap
)
elif t2.split is not None:

if len(t2.lshape) > t2.split and t2.lshape[t2.split] == 0:
result = t2.larray.type(promoted_type)
else:
result = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
)
else:
result = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
# t2 is split but broadcast -> only on one rank; manipulate lshape_map s.t. this rank has 'full' data
lmap = t2.lshape_map
idx = lmap[:, t2.split].nonzero(as_tuple=True)[0]
lmap[idx.item(), t2.split] = output_shape[t2.split]
output_split, output_device, output_comm, output_balanced, t1 = __get_out_params(
t2, other=t1, map=lmap
)

if not isinstance(result, torch.Tensor):
result = torch.tensor(result, device=output_device.torch_device)
else: # both are not split
output_split, output_device, output_comm, output_balanced = __get_out_params(t1)

if out is not None:
out_dtype = out.dtype
out.larray = result
out._DNDarray__comm = output_comm
out = out.astype(out_dtype)
sanitation.sanitize_out(out, output_shape, output_split, output_device, output_comm)
t1, t2 = sanitation.sanitize_distribution(t1, t2, target=out)
out.larray[:] = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
)
return out
# print(t1.lshape, t2.lshape)

result = operation(t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs)

return DNDarray(
result,
output_shape,
types.heat_type_of(result),
output_split,
output_device,
output_comm,
balanced=None,
device=output_device,
comm=output_comm,
balanced=output_balanced,
)


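The rewritten `__binary_op` above chooses one operand as the distribution target (split is preferred to no split, broadcasting in the split dimension is avoided when possible, `t1` is preferred to `t2`) and redistributes the other via `sanitation.sanitize_distribution`. A small sketch of the user-facing behaviour these rules imply; the expected values follow from the docstring rules and are assumptions, not output copied from the test suite:

```python
import heat as ht

# A split operand determines the result's distribution over a non-split one.
a = ht.ones((4, 5), split=0)
b = ht.zeros((4, 5), split=None)
print((a + b).split)  # expected: 0

# When one operand is broadcast along the split axis (size 1 there), the
# operand with full extent in that dimension (here `a`) sets the distribution.
c = ht.ones((1, 5), split=0)
print((a + c).split)  # expected: 0
print((a + c).shape)  # expected: (4, 5)
```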