
Features/880 binop ben bou #902

Merged: 62 commits, merged on Jan 31, 2022
Showing changes from 61 commits.

Commits (62)
346822a
Support different lshape maps in binary ops
ClaudiaComito Nov 2, 2021
169f062
Adapt documentation
ClaudiaComito Nov 5, 2021
207e843
Typos
ClaudiaComito Nov 5, 2021
cb053e6
Update changelog
ClaudiaComito Nov 5, 2021
ff03641
Debugging tests
ClaudiaComito Nov 5, 2021
1958470
Allow unbalanced and unequally balanced dndarrays for binops
Nov 7, 2021
a5cc6ff
determine promoted_type in the beginning
ben-bou Nov 7, 2021
86172ba
Refine distribution logic for non-distributed operands
ClaudiaComito Nov 8, 2021
bd33252
Change tests order
ClaudiaComito Nov 8, 2021
e41dc58
Typo
ClaudiaComito Nov 8, 2021
e1fb638
typo
ben-bou Nov 8, 2021
871bb42
No broadcasting in split-dimension
ben-bou Nov 8, 2021
55d46bf
Merge branch 'features/880-bin_op_different_distribution' into origin…
ben-bou Nov 8, 2021
c415539
Remove redundant condition
ClaudiaComito Nov 8, 2021
45dc427
explicit broadcasting in split-dimension
ben-bou Nov 8, 2021
363f304
Merge branch 'features/880-bin_op_different_distribution' into origin…
ben-bou Nov 8, 2021
2eaf0c5
debugging
Nov 8, 2021
70ce2bc
improve tests
Nov 15, 2021
2075684
fix empty process bug
Nov 15, 2021
00c31bd
remove debugging output
Nov 15, 2021
ed753a6
empty process handling; run tests
Nov 16, 2021
c341248
typo
Nov 16, 2021
85d376e
move the equalization of distribution to sanitation
Dec 6, 2021
50d9406
beautify
Dec 6, 2021
b6ba31b
wrong condition
Dec 6, 2021
6b3f136
use None-indexing, fix broadcast_shapes
Dec 7, 2021
572c432
Fix broadcast_shapes
Dec 7, 2021
f5fa517
old expand_dims
Dec 7, 2021
c77d086
debug
Dec 20, 2021
7c7160d
allow broadcasting in split dimension
Dec 20, 2021
6cbe0cb
torch.Tensor check is needed because torch.equal is wrongly directed …
Dec 20, 2021
eb7bf04
allow different lshape-map in sanitize-distribution
Dec 20, 2021
f9fce70
docs
Dec 20, 2021
203b9fc
directly use tuple
Jan 18, 2022
39e626f
Fix error in stack
Jan 18, 2022
d573bdc
Fix heat.equal error
Jan 18, 2022
04161e9
Merge branch 'master' into features/880-binop_ben-bou
ClaudiaComito Jan 21, 2022
5d198e4
Modify `stride_tricks.broadcast_shape` to return `torch.broadcast_sha…
ClaudiaComito Jan 21, 2022
120054b
Backward compatibility for `stride_tricks.broadcast_shape`
ClaudiaComito Jan 21, 2022
f85d18e
Debug GPU tests
ClaudiaComito Jan 21, 2022
24045b6
Debugging
ClaudiaComito Jan 21, 2022
b329f32
Debugging GPU tests
ClaudiaComito Jan 21, 2022
7e5ee42
Fix bug in debugging statement
ClaudiaComito Jan 21, 2022
ee504fc
Debugging
ClaudiaComito Jan 21, 2022
ed32233
Fix `broadcast_shape` when one dimension is 0 for torch < 1.8
ClaudiaComito Jan 21, 2022
39527f1
Remove print statements
ClaudiaComito Jan 21, 2022
6627c08
Merge pull request #905 from helmholtz-analytics/features/880-binop_b…
ben-bou Jan 21, 2022
3f8d722
add tests
Jan 22, 2022
82036e8
Merge branch 'master' into features/880-binop_ben-bou
ClaudiaComito Jan 24, 2022
dca94bf
check if balanced when using diff_map. TODO: debug is_balanced; curre…
Jan 24, 2022
a556754
found typo, no need for force check
Jan 24, 2022
4863cd3
code review + docs
Jan 24, 2022
07dd506
fix error
Jan 24, 2022
80549be
docs
Jan 24, 2022
eac41b8
changing Tensor returned by create_lshape_map should not change the l…
Jan 24, 2022
415ee52
Merge branch 'master' into features/880-binop_ben-bou
ben-bou Jan 24, 2022
e70aa39
add tests
Jan 24, 2022
ca23d4c
Merge branch 'features/880-binop_ben-bou' of https://github.com/helmh…
Jan 24, 2022
f05d31f
add tests
Jan 24, 2022
664b27e
Edit error messages for `sanitize_out`
ClaudiaComito Jan 25, 2022
c6a7d9e
Merge branch 'features/880-binop_ben-bou' of github.com:helmholtz-ana…
ClaudiaComito Jan 25, 2022
3ff5a41
adressed Review; mainly docs
Jan 28, 2022
2 changes: 1 addition & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
@@ -37,7 +37,7 @@ These tools only profile the memory used by each process, not the entire functio
- with `split=None` and `split not None`

Python has an embedded profiler: https://docs.python.org/3.9/library/profile.html
Again, this will only provile the performance on each process. Printing the results with many processes
Again, this will only profile the performance on each process. Printing the results with many processes
may be illegible. It may be easiest to save the output of each to a file.
--->
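As a hedged illustration of the per-process profiling advice in the template text above (not part of the diff): the `main()` entry point and the output file names are placeholders, and mpi4py is assumed to be available, as it is a Heat dependency.
```
# Illustrative only: profile each MPI process separately and write one file per rank.
import cProfile
from mpi4py import MPI

rank = MPI.COMM_WORLD.rank

def main():
    # placeholder for the Heat workload to be profiled
    pass

cProfile.run("main()", filename=f"profile_rank_{rank}.out")
```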

6 changes: 4 additions & 2 deletions CHANGELOG.md
@@ -12,8 +12,9 @@
- [#868](https://github.com/helmholtz-analytics/heat/pull/868) Fixed an issue in `__binary_op` where data was falsely distributed if a DNDarray has single element.

## Feature Additions
### Linear Algebra
- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot`

### Arithmetics
- - [#887](https://github.com/helmholtz-analytics/heat/pull/887) Binary operations now support operands of equal shapes, equal `split` axes, but different distribution maps.

## Feature additions
### Communication
@@ -25,6 +26,7 @@
# Feature additions
### Linear Algebra
- [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot`
- [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features `norm`, `vector_norm`, `matrix_norm`
- [#850](https://github.com/helmholtz-analytics/heat/pull/850) New Feature `cross`
- [#877](https://github.com/helmholtz-analytics/heat/pull/877) New feature `det`
2 changes: 1 addition & 1 deletion heat/cluster/tests/test_kmedoids.py
@@ -50,7 +50,7 @@ def create_spherical_dataset(
cluster4 = ht.stack((x - 2 * offset, y - 2 * offset, z - 2 * offset), axis=1)

data = ht.concatenate((cluster1, cluster2, cluster3, cluster4), axis=0)
# Note: enhance when shuffel is available
# Note: enhance when shuffle is available
return data

def test_clusterer(self):
233 changes: 110 additions & 123 deletions heat/core/_operations.py
@@ -51,147 +51,134 @@ def __binary_op(
-------
result: ht.DNDarray
A DNDarray containing the results of element-wise operation.

Warning
-------
If both operands are distributed, they must be distributed along the same dimension, i.e. `t1.split = t2.split`.

MPI communication is necessary when both operands are distributed along the same dimension, but the distribution maps do not match. E.g.:
```
a = ht.ones(10000, split=0)
b = ht.zeros(10000, split=0)
c = a[:-1] + b[1:]
```
In such cases, one of the operands is redistributed OUT-OF-PLACE to match the distribution map of the other operand.
The operand determining the resulting distribution is chosen as follows:
1) split is preferred to no split
2) no (shape)-broadcasting in the split dimension if not necessary
3) t1 is preferred to t2
"""
# Check inputs
if not np.isscalar(t1) and not isinstance(t1, DNDarray):
raise TypeError(
"Only DNDarrays and numeric scalars are supported, but input was {}".format(type(t1))
)
if not np.isscalar(t2) and not isinstance(t2, DNDarray):
raise TypeError(
"Only DNDarrays and numeric scalars are supported, but input was {}".format(type(t2))
)
promoted_type = types.result_type(t1, t2).torch_type()

if np.isscalar(t1):
    # Make inputs DNDarrays
if np.isscalar(t1) and np.isscalar(t2):
try:
t1 = factories.array(t1, device=t2.device if isinstance(t2, DNDarray) else None)
t1 = factories.array(t1)
t2 = factories.array(t2)
except (ValueError, TypeError):
raise TypeError("Data type not supported, input was {}".format(type(t1)))

if np.isscalar(t2):
try:
t2 = factories.array(t2)
except (ValueError, TypeError):
raise TypeError(
"Only numeric scalars are supported, but input was {}".format(type(t2))
)
output_shape = (1,)
output_split = None
output_device = t2.device
output_comm = MPI_WORLD
elif isinstance(t2, DNDarray):
output_shape = t2.shape
output_split = t2.split
output_device = t2.device
output_comm = t2.comm
else:
raise TypeError(
"Only tensors and numeric scalars are supported, but input was {}".format(type(t2))
)

if t1.dtype != t2.dtype:
t1 = t1.astype(t2.dtype)

elif isinstance(t1, DNDarray):
if np.isscalar(t2):
try:
t2 = factories.array(t2, device=t1.device)
output_shape = t1.shape
output_split = t1.split
output_device = t1.device
output_comm = t1.comm
except (ValueError, TypeError):
raise TypeError("Data type not supported, input was {}".format(type(t2)))

elif isinstance(t2, DNDarray):
if t1.split is None:
t1 = factories.array(
t1, split=t2.split, copy=False, comm=t1.comm, device=t1.device, ndmin=-t2.ndim
)
elif t2.split is None:
t2 = factories.array(
t2, split=t1.split, copy=False, comm=t2.comm, device=t2.device, ndmin=-t1.ndim
)
elif t1.split != t2.split:
# It is NOT possible to perform binary operations on tensors with different splits, e.g. split=0
# and split=1
raise NotImplementedError("Not implemented for other splittings")

output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape)
output_split = t1.split
output_device = t1.device
output_comm = t1.comm

if t1.split is not None:
if t1.shape[t1.split] == 1 and t1.comm.is_distributed():
# warnings.warn(
# "Broadcasting requires transferring data of first operator between MPI ranks!"
# )
color = 0 if t1.comm.rank < t2.shape[t1.split] else 1
newcomm = t1.comm.Split(color, t1.comm.rank)
if t1.comm.rank > 0 and color == 0:
t1.larray = torch.zeros(
t1.shape, dtype=t1.dtype.torch_type(), device=t1.device.torch_device
)
newcomm.Bcast(t1)
newcomm.Free()

if t2.split is not None:
if t2.shape[t2.split] == 1 and t2.comm.is_distributed():
# warnings.warn(
# "Broadcasting requires transferring data of second operator between MPI ranks!"
# )
color = 0 if t2.comm.rank < t1.shape[t2.split] else 1
newcomm = t2.comm.Split(color, t2.comm.rank)
if t2.comm.rank > 0 and color == 0:
t2.larray = torch.zeros(
t2.shape, dtype=t2.dtype.torch_type(), device=t2.device.torch_device
)
newcomm.Bcast(t2)
newcomm.Free()

else:
raise TypeError(
"Only tensors and numeric scalars are supported, but input was {}".format(type(t2))
)
else:
raise NotImplementedError("Not implemented for non scalar")

# sanitize output
if out is not None:
sanitation.sanitize_out(out, output_shape, output_split, output_device)

# promoted_type = types.promote_types(t1.dtype, t2.dtype).torch_type()
if t1.split is not None:
if len(t1.lshape) > t1.split and t1.lshape[t1.split] == 0:
result = t1.larray.type(promoted_type)
else:
result = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
"Data type not supported, inputs were {} and {}".format(type(t1), type(t2))
)
elif np.isscalar(t1) and isinstance(t2, DNDarray):
try:
t1 = factories.array(t1, device=t2.device, comm=t2.comm)
except (ValueError, TypeError):
raise TypeError("Data type not supported, input was {}".format(type(t1)))
elif isinstance(t1, DNDarray) and np.isscalar(t2):
try:
t2 = factories.array(t2, device=t1.device, comm=t1.comm)
except (ValueError, TypeError):
raise TypeError("Data type not supported, input was {}".format(type(t2)))

# Make inputs have the same dimensionality
output_shape = stride_tricks.broadcast_shape(t1.shape, t2.shape)
# Broadcasting allows additional empty dimensions on the left side
# TODO simplify this once newaxis-indexing is supported to get rid of the loops
while len(t1.shape) < len(output_shape):
t1 = t1.expand_dims(axis=0)
while len(t2.shape) < len(output_shape):
t2 = t2.expand_dims(axis=0)
# t1 = t1[tuple([None] * (len(output_shape) - t1.ndim))]
# t2 = t2[tuple([None] * (len(output_shape) - t2.ndim))]
# print(t1.lshape, t2.lshape)

def __get_out_params(target, other=None, map=None):
"""
Getter for the output parameters of a binop with target.
Review comment (Contributor): "binop with target" -> "binary operation with target distribution"

If other is provided, it's distribution will be matched to target or, if provided,
Review comment (Contributor): format "other" and "target" as code (`other`, `target`); "it's" -> "its"

redistributed according to map.
Review comment (Contributor): format "map" as code (`map`)


Parameters
----------
target : DNDarray
DNDarray determining the parameters
other : DNDarray
DNDarray to be adapted
map : Tensor
Lshape-Map other should be matched to. Defaults to target's lshape_map
Review comment (Contributor): "Lshape-Map" -> `lshape_map`; format "other" as code (`other`); "target's lshape_map" -> `target.lshape_map`


Returns
-------
Tuple
split, device, comm, balanced, [other]
"""
if other is not None:
if out is None:
other = sanitation.sanitize_distribution(other, target=target, diff_map=map)
return target.split, target.device, target.comm, target.balanced, other
return target.split, target.device, target.comm, target.balanced
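For context, a hedged sketch of what this helper delegates to: `sanitation.sanitize_distribution` (added to the sanitation module in the course of this PR, per the commit history) redistributes the non-target operand out-of-place so that its local chunks match the target's `lshape_map`. The exact public surface is an assumption here; the call pattern mirrors the calls in this diff.
```
import heat as ht
from heat.core import sanitation

# Two DNDarrays with equal global shape and split, but shifted lshape maps
# (on two processes: roughly [5, 4] vs. [4, 5]).
t1 = ht.ones(10, split=0)[:-1]
t2 = ht.ones(10, split=0)[1:]

# Redistribute t2 (out of place) so that its local chunks line up with t1's.
t2_matched = sanitation.sanitize_distribution(t2, target=t1)
print(t2_matched.lshape_map)  # now equal to t1.lshape_map
```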

if t1.split is not None and t1.shape[t1.split] == output_shape[t1.split]: # t1 is "dominant"
output_split, output_device, output_comm, output_balanced, t2 = __get_out_params(t1, t2)
elif t2.split is not None and t2.shape[t2.split] == output_shape[t2.split]: # t2 is "dominant"
output_split, output_device, output_comm, output_balanced, t1 = __get_out_params(t2, t1)
elif t1.split is not None:
# t1 is split but broadcast -> only on one rank; manipulate lshape_map s.t. this rank has 'full' data
lmap = t1.lshape_map
idx = lmap[:, t1.split].nonzero(as_tuple=True)[0]
lmap[idx.item(), t1.split] = output_shape[t1.split]
output_split, output_device, output_comm, output_balanced, t2 = __get_out_params(
t1, t2, map=lmap
)
elif t2.split is not None:

if len(t2.lshape) > t2.split and t2.lshape[t2.split] == 0:
result = t2.larray.type(promoted_type)
else:
result = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
)
else:
result = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
# t2 is split but broadcast -> only on one rank; manipulate lshape_map s.t. this rank has 'full' data
lmap = t2.lshape_map
idx = lmap[:, t2.split].nonzero(as_tuple=True)[0]
lmap[idx.item(), t2.split] = output_shape[t2.split]
output_split, output_device, output_comm, output_balanced, t1 = __get_out_params(
t2, other=t1, map=lmap
)

if not isinstance(result, torch.Tensor):
result = torch.tensor(result, device=output_device.torch_device)
else: # both are not split
output_split, output_device, output_comm, output_balanced = __get_out_params(t1)

if out is not None:
out_dtype = out.dtype
out.larray = result
out._DNDarray__comm = output_comm
out = out.astype(out_dtype)
sanitation.sanitize_out(out, output_shape, output_split, output_device, output_comm)
t1, t2 = sanitation.sanitize_distribution(t1, t2, target=out)
out.larray[:] = operation(
t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs
)
return out
# print(t1.lshape, t2.lshape)

result = operation(t1.larray.type(promoted_type), t2.larray.type(promoted_type), **fn_kwargs)

return DNDarray(
result,
output_shape,
types.heat_type_of(result),
output_split,
output_device,
output_comm,
balanced=None,
device=output_device,
comm=output_comm,
balanced=output_balanced,
)
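Taken together, a hedged end-to-end sketch of what the refactored `__binary_op` enables (editorial illustration, not part of the diff; assumes a run with at least two MPI processes, and the exact local shapes depend on the process count):
```
import heat as ht

# Broadcasting against an unsplit operand: the split operand determines the result's distribution.
a = ht.ones((4, 10), split=1)
b = ht.arange(10, dtype=ht.float32)   # not split, broadcast along dimension 0
c = a * b
print(c.shape, c.split)               # (4, 10) 1

# Equal shapes and splits, different distribution maps: handled without raising;
# one operand is redistributed out-of-place behind the scenes.
x = ht.ones(10000, split=0)
y = ht.zeros(10000, split=0)
z = x[:-1] + y[1:]
print(ht.equal(z, ht.ones(9999, split=0)))  # True
```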

