Features/37 minimum #324
@@ -8,6 +8,7 @@
from . import operations
from . import dndarray
from . import types
from . import stride_tricks


__all__ = [
@@ -16,6 +17,7 @@
    'max',
    'mean',
    'min',
    'minimum',
    'std',
    'var'
]
@@ -524,6 +526,7 @@ def min(x, axis=None, out=None, keepdim=None):
           [ 7.],
           [10.]])
    """

    def local_min(*args, **kwargs):
        result = torch.min(*args, **kwargs)
        if isinstance(result, tuple):

@@ -533,6 +536,139 @@ def local_min(*args, **kwargs):
    return operations.__reduce_op(x, local_min, MPI.MIN, axis=axis, out=out, keepdim=keepdim)


def minimum(x1, x2, out=None, **kwargs):
    '''
    Compares two tensors and returns a new tensor containing the element-wise minima.
    If one of the elements being compared is a NaN, then that element is returned.
    TODO: Check this: If both elements are NaNs then the first is returned.
    The latter distinction is important for complex NaNs, which are defined as at least one of the real or
    imaginary parts being a NaN. The net effect is that NaNs are propagated.

    Parameters:
    -----------
    x1, x2 : ht.DNDarray
        The tensors containing the elements to be compared. They must have the same shape, or shapes that
        can be broadcast to a single shape.
        For broadcasting semantics, see: https://pytorch.org/docs/stable/notes/broadcasting.html
    out : ht.DNDarray or None, optional
        A location into which the result is stored. If provided, it must have a shape that the inputs
        broadcast to. If not provided or None, a freshly-allocated tensor is returned.

    Returns:
    --------
    minimum: ht.DNDarray
        Element-wise minimum of the two input tensors.

    Examples:
    ---------
    >>> import heat as ht
    >>> import torch
    >>> torch.manual_seed(1)
    <torch._C.Generator object at 0x105c50b50>

    >>> a = ht.random.randn(3, 4)
    >>> a
    tensor([[-0.1955, -0.9656,  0.4224,  0.2673],
            [-0.4212, -0.5107, -1.5727, -0.1232],
            [ 3.5870, -1.8313,  1.5987, -1.2770]])

    >>> b = ht.random.randn(3, 4)
    >>> b
    tensor([[ 0.8310, -0.2477, -0.8029,  0.2366],
            [ 0.2857,  0.6898, -0.6331,  0.8795],
            [-0.6842,  0.4533,  0.2912, -0.8317]])

    >>> ht.minimum(a, b)
    tensor([[-0.1955, -0.9656, -0.8029,  0.2366],
            [-0.4212, -0.5107, -1.5727, -0.1232],
            [-0.6842, -1.8313,  0.2912, -1.2770]])

    >>> c = ht.random.randn(1, 4)
    >>> c
    tensor([[-1.6428,  0.9803, -0.0421, -0.8206]])

    >>> ht.minimum(a, c)
    tensor([[-1.6428, -0.9656, -0.0421, -0.8206],
            [-1.6428, -0.5107, -1.5727, -0.8206],
            [-1.6428, -1.8313, -0.0421, -1.2770]])

    >>> b[0, 1] = ht.nan
    >>> b
    tensor([[ 0.8310,     nan, -0.8029,  0.2366],
            [ 0.2857,  0.6898, -0.6331,  0.8795],
            [-0.6842,  0.4533,  0.2912, -0.8317]])
    >>> ht.minimum(a, b)
    tensor([[-0.1955,     nan, -0.8029,  0.2366],
            [-0.4212, -0.5107, -1.5727, -0.1232],
            [-0.6842, -1.8313,  0.2912, -1.2770]])

    >>> d = ht.random.randn(3, 4, 5)
    >>> ht.minimum(a, d)
    ValueError: operands could not be broadcast, input shapes (3, 4) (3, 4, 5)
    '''
    # perform sanitation
    if not isinstance(x1, dndarray.DNDarray) or not isinstance(x2, dndarray.DNDarray):
        raise TypeError('expected x1 and x2 to be a ht.DNDarray, but were {}, {}'.format(type(x1), type(x2)))
    if out is not None and not isinstance(out, dndarray.DNDarray):
        raise TypeError('expected out to be None or an ht.DNDarray, but was {}'.format(type(out)))

    # apply split semantics: align both operands on a common split so that the
    # local torch.min below sees matching process-local chunks
    if x1.split is not None or x2.split is not None:
        if x1.split is None:
            x1.resplit(x2.split)
        if x2.split is None:
            x2.resplit(x1.split)
        if x1.split != x2.split:
            # resplit the smaller operand (less data to redistribute);
            # on equally-sized operands, fall back to the lower split axis
            if np.prod(x1.gshape) < np.prod(x2.gshape):
                x1.resplit(x2.split)
            elif np.prod(x2.gshape) < np.prod(x1.gshape):
                x2.resplit(x1.split)
            elif x1.split < x2.split:
                x2.resplit(x1.split)
            else:
                x1.resplit(x2.split)
    split = x1.split

    # locally: apply torch.min(x1, x2)
    output_lshape = stride_tricks.broadcast_shape(x1.lshape, x2.lshape)
    lresult = factories.empty(output_lshape)
    lresult._DNDarray__array = torch.min(x1._DNDarray__array, x2._DNDarray__array)
    lresult._DNDarray__dtype = types.promote_types(x1.dtype, x2.dtype)
    lresult._DNDarray__split = split

    if x1.split is not None or x2.split is not None:
        if x1.comm.is_distributed():  # assuming x1.comm = x2.comm
            output_gshape = stride_tricks.broadcast_shape(x1.gshape, x2.gshape)
            result = factories.empty(output_gshape)
            x1.comm.Allgather(lresult, result)
            # TODO: adopt Allgatherv() as soon as it is fixed, Issue #233
            result._DNDarray__dtype = lresult._DNDarray__dtype
            result._DNDarray__split = split

            if out is not None:
                if out.shape != output_gshape:
                    raise ValueError('Expecting output buffer of shape {}, got {}'.format(output_gshape, out.shape))
                out._DNDarray__array = result._DNDarray__array
                out._DNDarray__dtype = result._DNDarray__dtype
                out._DNDarray__split = split
                out._DNDarray__device = x1.device
                out._DNDarray__comm = x1.comm

                return out
            return result

    if out is not None:
        if out.shape != output_lshape:
            raise ValueError('Expecting output buffer of shape {}, got {}'.format(output_lshape, out.shape))
        out._DNDarray__array = lresult._DNDarray__array
        out._DNDarray__dtype = lresult._DNDarray__dtype
        out._DNDarray__split = split
        out._DNDarray__device = x1.device
        out._DNDarray__comm = x1.comm

    return lresult
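# A hypothetical session illustrating the split alignment above, assuming at
# least two MPI processes (names and values are illustrative, not a recorded
# run from this PR):
#
#   >>> a = ht.random.randn(4, 4, split=0)   # row-distributed operand
#   >>> b = ht.random.randn(4, 4)            # split=None, resplit to 0 above
#   >>> ht.minimum(a, b).split               # result carries the aligned split
#   0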


def mpi_argmax(a, b, _):
    lhs = torch.from_numpy(np.frombuffer(a, dtype=np.float64))
    rhs = torch.from_numpy(np.frombuffer(b, dtype=np.float64))

@@ -554,7 +690,6 @@ def mpi_argmax(a, b, _):
def mpi_argmin(a, b, _):
    lhs = torch.from_numpy(np.frombuffer(a, dtype=np.float64))
    rhs = torch.from_numpy(np.frombuffer(b, dtype=np.float64))

    # extract the values and minimal indices from the buffers (first half are values, second are indices)
    values = torch.stack((lhs.chunk(2)[0], rhs.chunk(2)[0],), dim=1)
    indices = torch.stack((lhs.chunk(2)[1], rhs.chunk(2)[1],), dim=1)

@@ -569,6 +704,15 @@ def mpi_argmin(a, b, _):

MPI_ARGMIN = MPI.Op.Create(mpi_argmin, commute=True)

# # TODO: implement mpi_minimum

Review comment: Does this need to be here? This could be put in an issue.

Reply: Sorry about this, it's a leftover from the attempted implementation through __reduce_op(), I removed it already.
# def mpi_minimum(a, b, _):
#     lhs = torch.from_numpy(np.frombuffer(a, dtype=np.float64))
#     rhs = torch.from_numpy(np.frombuffer(b, dtype=np.float64))

#     # do something

#     rhs.copy_(result)

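# For reference, a hypothetical completion of the commented-out op above,
# mirroring the mpi_argmin buffer convention; a sketch only, since plain
# MPI.MIN already covers the element-wise case and the block was removed
# from this PR:
#
# def mpi_minimum(a, b, _):
#     lhs = torch.from_numpy(np.frombuffer(a, dtype=np.float64))
#     rhs = torch.from_numpy(np.frombuffer(b, dtype=np.float64))
#     # element-wise minimum of both buffers, written back into the
#     # right-hand side, which MPI uses as the accumulator
#     rhs.copy_(torch.min(lhs, rhs))
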
def std(x, axis=None, bessel=True):
    """
Review comment: Is it possible to combine this `if out` and the other one?

Reply: I'm not sure how, as one of them only writes out in the distributed case, the other one in the non-distributed case. If you can think of something, let me know.
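One way the two `if out` blocks might eventually be folded together, sketched here as a hypothetical module-level helper (the name `__write_out` and its signature are illustrative, not part of this PR): both blocks perform the same check-and-copy and differ only in the source tensor and the expected shape.

    def __write_out(out, source, expected_shape, split, device, comm):
        # shared check-and-copy: 'source' is the gathered global result in the
        # distributed branch and the local result otherwise
        if out.shape != expected_shape:
            raise ValueError('Expecting output buffer of shape {}, got {}'.format(expected_shape, out.shape))
        out._DNDarray__array = source._DNDarray__array
        out._DNDarray__dtype = source._DNDarray__dtype
        out._DNDarray__split = split
        out._DNDarray__device = device
        out._DNDarray__comm = comm
        return out

minimum() could then call __write_out(out, result, output_gshape, split, x1.device, x1.comm) in the distributed branch and __write_out(out, lresult, output_lshape, split, x1.device, x1.comm) before the final return.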