Skip to content

Commit

Permalink
Merge branch 'master' into bug/866-op-nested
Browse files Browse the repository at this point in the history
  • Loading branch information
coquelin77 committed Oct 8, 2021
2 parents 9d5cf2e + c570047 commit 18a8297
Show file tree
Hide file tree
Showing 12 changed files with 350 additions and 20 deletions.
22 changes: 21 additions & 1 deletion .github/PULL_REQUEST_TEMPLATE.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,28 @@ i.e.
- Documentation update
--->

## Due Diligence
## Memory requirements
<!--- Compare memory requirements to previous implementation / relevant torch operations if applicable:
- in distributed and non-distributed mode
- with `split=None` and `split not None`
This can be done using https://github.com/pythonprofilers/memory_profiler for CPU memory measurements,
GPU measuremens can be done with https://pytorch.org/docs/master/generated/torch.cuda.max_memory_allocated.html.
These tools only profile the memory used by each process, not the entire function.
--->

## Performance
<!--- Compare performance to previous implementation / relevant torch operations if applicable:
- in distributed and non-distributed mode
- with `split=None` and `split not None`
Python has an embedded profiler: https://docs.python.org/3.9/library/profile.html
Again, this will only provile the performance on each process. Printing the results with many processes
my be illegible. It may be easiest to save the output of each to a file.
--->


## Due Diligence
- [ ] All split configurations tested
- [ ] Multiple dtypes tested in relevant functions
- [ ] Documentation updated (if needed)
Expand Down
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Pending additions

- [#867](https://github.com/helmholtz-analytics/heat/pull/867) Upgraded to support torch 1.9.0
- [#876](https://github.com/helmholtz-analytics/heat/pull/876) Make examples work (Lasso and kNN)

## Bug Fixes
- [#826](https://github.com/helmholtz-analytics/heat/pull/826) Fixed `__setitem__` handling of distributed `DNDarray` values which have a different shape in the split dimension
Expand All @@ -9,6 +10,8 @@
- [#868](https://github.com/helmholtz-analytics/heat/pull/868) Fixed an issue in `__binary_op` where data was falsely distributed if a DNDarray has single element.

## Feature Additions
### Linear Algebra
- [#842](https://github.com/helmholtz-analytics/heat/pull/842) New feature: `vdot`

### Communication
- [#868](https://github.com/helmholtz-analytics/heat/pull/868) New `MPICommunication` method `Split`
Expand All @@ -27,6 +30,8 @@
- [#854](https://github.com/helmholtz-analytics/heat/pull/854) New Feature: `moveaxis`
### Random
- [#858](https://github.com/helmholtz-analytics/heat/pull/858) New Feature: `standard_normal`, `normal`
### Rounding
- [#827](https://github.com/helmholtz-analytics/heat/pull/827) New feature: `sign`, `sgn`

# v1.1.1
- [#864](https://github.com/helmholtz-analytics/heat/pull/864) Dependencies: constrain `torchvision` version range to match supported `pytorch` version range.
Expand Down
2 changes: 1 addition & 1 deletion examples/classification/demo_knn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from heat.classification.kneighborsclassifier import KNeighborsClassifier

# Load dataset from hdf5 file
X = ht.load_hdf5("../../heat/datasets/data/iris.h5", dataset="data", split=0)
X = ht.load_hdf5("../../heat/datasets/iris.h5", dataset="data", split=0)

# Generate keys for the iris.h5 dataset
keys = []
Expand Down
9 changes: 6 additions & 3 deletions examples/lasso/demo.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import numpy as np
import torch
import sys
import os

sys.path.append("../../")
# Fix python path if run from terminal
curdir = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, os.path.abspath(os.path.join(curdir, "../../")))

import heat as ht
from matplotlib import pyplot as plt
Expand All @@ -14,8 +17,8 @@
diabetes = datasets.load_diabetes()

# load diabetes dataset from hdf5 file
X = ht.load_hdf5("../../heat/datasets/data/diabetes.h5", dataset="x", split=0)
y = ht.load_hdf5("../../heat/datasets/data/diabetes.h5", dataset="y", split=0)
X = ht.load_hdf5("../../heat/datasets/diabetes.h5", dataset="x", split=0)
y = ht.load_hdf5("../../heat/datasets/diabetes.h5", dataset="y", split=0)

# normalize dataset #DoTO this goes into the lasso fit routine soon as issue #106 is solved
X = X / ht.sqrt((ht.mean(X ** 2, axis=0)))
Expand Down
37 changes: 37 additions & 0 deletions heat/core/linalg/basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
"transpose",
"tril",
"triu",
"vdot",
"vecdot",
"vector_norm",
]
Expand Down Expand Up @@ -1917,6 +1918,42 @@ def triu(m: DNDarray, k: int = 0) -> DNDarray:
DNDarray.triu.__doc__ = triu.__doc__


def vdot(x1: DNDarray, x2: DNDarray) -> DNDarray:
"""
Computes the dot product of two vectors. Higher-dimensional arrays will be flattened.
Parameters
----------
x1 : DNDarray
first input array. If it's complex, it's complex conjugate will be used.
x2 : DNDarray
second input array.
Raises
------
ValueError
If the number of elements is inconsistent.
See Also
--------
dot
Return the dot product without using the complex conjugate.
Examples
--------
>>> a = ht.array([1+1j, 2+2j])
>>> b = ht.array([1+2j, 3+4j])
>>> ht.vdot(a,b)
DNDarray([(17+3j)], dtype=ht.complex64, device=cpu:0, split=None)
>>> ht.vdot(b,a)
DNDarray([(17-3j)], dtype=ht.complex64, device=cpu:0, split=None)
"""
x1 = manipulations.flatten(x1)
x2 = manipulations.flatten(x2)

return arithmetics.sum(arithmetics.multiply(complex_math.conjugate(x1), x2))


def vecdot(
x1: DNDarray, x2: DNDarray, axis: Optional[int] = None, keepdim: Optional[bool] = None
) -> DNDarray:
Expand Down
15 changes: 15 additions & 0 deletions heat/core/linalg/tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1767,6 +1767,21 @@ def test_triu(self):
if result.comm.rank == result.shape[0] - 1:
self.assertTrue(result.larray[0, -1] == 1)

def test_vdot(self):
a = ht.array([[1 + 1j, 2 + 2j], [3 + 3j, 4 + 4j]], split=0)
b = ht.array([[1 + 2j, 3 + 4j], [5 + 6j, 7 + 8j]], split=0)

vdot = ht.vdot(a, b)
self.assertEqual(vdot.dtype, a.dtype)
self.assertEqual(vdot.split, None)
self.assertTrue(ht.equal(vdot, ht.array([110 + 10j])))

vdot = ht.vdot(b, a)
self.assertTrue(ht.equal(vdot, ht.array([110 - 10j])))

with self.assertRaises(ValueError):
ht.vdot(ht.array([1, 2, 3]), ht.array([[1, 2], [3, 4]]))

def test_vecdot(self):
a = ht.array([1, 1, 1])
b = ht.array([1, 2, 3])
Expand Down
2 changes: 2 additions & 0 deletions heat/core/manipulations.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,8 @@ def flatten(a: DNDarray) -> DNDarray:
>>> ht.flatten(a)
DNDarray([1, 2, 3, 4, 5, 6, 7, 8], dtype=ht.int64, device=cpu:0, split=None)
"""
sanitation.sanitize_in(a)

if a.split is None:
return factories.array(
torch.flatten(a.larray), dtype=a.dtype, is_split=None, device=a.device, comm=a.comm
Expand Down
98 changes: 97 additions & 1 deletion heat/core/rounding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,19 @@
from . import sanitation
from . import types

__all__ = ["abs", "absolute", "ceil", "clip", "fabs", "floor", "modf", "round", "trunc"]
__all__ = [
"abs",
"absolute",
"ceil",
"clip",
"fabs",
"floor",
"modf",
"round",
"sgn",
"sign",
"trunc",
]


def abs(
Expand Down Expand Up @@ -328,6 +340,90 @@ def round(
DNDarray.round.__doc__ = round.__doc__


def sgn(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
"""
Returns an indication of the sign of a number, element-wise. The definition for complex values is equivalent to :math:`x / |x|`.
Parameters
----------
x : DNDarray
Input array
out : DNDarray, optional
A location in which to store the results.
See Also
--------
:func:`sign`
Equivalent function on non-complex arrays. The definition for complex values is equivalent to :math:`x / \\sqrt{x \\cdot x}`
Examples
--------
>>> a = ht.array([-1, -0.5, 0, 0.5, 1])
>>> ht.sign(a)
DNDarray([-1., -1., 0., 1., 1.], dtype=ht.float32, device=cpu:0, split=None)
>>> ht.sgn(ht.array([5-2j, 3+4j]))
DNDarray([(0.9284766912460327-0.3713906705379486j), (0.6000000238418579+0.800000011920929j)], dtype=ht.complex64, device=cpu:0, split=None)
"""
return _operations.__local_op(torch.sgn, x, out)


def sign(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
"""
Returns an indication of the sign of a number, element-wise. The definition for complex values is equivalent to :math:`x / \\sqrt{x \\cdot x}`.
Parameters
----------
x : DNDarray
Input array
out : DNDarray, optional
A location in which to store the results.
See Also
--------
:func:`sgn`
Equivalent function on non-complex arrays. The definition for complex values is equivalent to :math:`x / |x|`.
Examples
--------
>>> a = ht.array([-1, -0.5, 0, 0.5, 1])
>>> ht.sign(a)
DNDarray([-1., -1., 0., 1., 1.], dtype=ht.float32, device=cpu:0, split=None)
>>> ht.sign(ht.array([5-2j, 3+4j]))
DNDarray([(1+0j), (1+0j)], dtype=ht.complex64, device=cpu:0, split=None)
"""
# special case for complex values
if types.heat_type_is_complexfloating(x.dtype):
sanitation.sanitize_in(x)
if out is not None:
sanitation.sanitize_out(out, x.shape, x.split, x.device)
out.larray.copy_(x.larray)
data = out.larray
else:
data = torch.clone(x.larray)
# NOTE remove when min version >= 1.9
if "1.7" in torch.__version__ or "1.8" in torch.__version__:
pos = data != 0
else: # pragma: no cover
indices = torch.nonzero(data)
pos = torch.split(indices, 1, 1)
data[pos] = x.larray[pos] / torch.sqrt(torch.square(x.larray[pos]))

if out is not None:
out.__dtype = types.heat_type_of(data)
return out
return DNDarray(
data,
gshape=x.shape,
dtype=types.heat_type_of(data),
split=x.split,
device=x.device,
comm=x.comm,
balanced=x.balanced,
)

return _operations.__local_op(torch.sign, x, out)


def trunc(x: DNDarray, out: Optional[DNDarray] = None) -> DNDarray:
"""
Return the trunc of the input, element-wise.
Expand Down
Loading

0 comments on commit 18a8297

Please sign in to comment.