Features/torch proxy #856

Merged
merged 6 commits on Aug 23, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -6,6 +6,9 @@

## Feature Additions

### DNDarray
- [#856](https://github.com/helmholtz-analytics/heat/pull/856) New `DNDarray` method `__torch_proxy__`

### Linear Algebra
- [#840](https://github.com/helmholtz-analytics/heat/pull/840) New feature: `vecdot()`
- [#846](https://github.com/helmholtz-analytics/heat/pull/846) New features `norm`, `vector_norm`, `matrix_norm`
17 changes: 14 additions & 3 deletions heat/core/dndarray.py
@@ -191,7 +191,9 @@ def size(self) -> int:
"""
Number of total elements of the ``DNDarray``
"""
return torch.prod(torch.tensor(self.gshape, device=self.device.torch_device)).item()
return torch.prod(
torch.tensor(self.gshape, dtype=torch.int, device=self.device.torch_device)
).item()

@property
def gnbytes(self) -> int:
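
A note on the `size` change above: for a 0-dimensional (scalar) `DNDarray`, `self.gshape` is the empty tuple, and `torch.tensor(())` defaults to `float32`, so the product over its elements comes back as the float `1.0` rather than the int `1`. Pinning `dtype=torch.int` keeps `size` an integer in that case as well; presumably that is the motivation for the fix. A minimal sketch of the underlying torch behavior:

```python
import torch

gshape = ()  # global shape of a 0-dimensional (scalar) array

# without an explicit dtype the empty shape tensor is float32,
# so the product over its (zero) elements is the float 1.0
print(torch.prod(torch.tensor(gshape)).item())                   # 1.0

# with an integer dtype the empty product is the integer 1
print(torch.prod(torch.tensor(gshape, dtype=torch.int)).item())  # 1
```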
@@ -731,7 +733,7 @@ def __getitem__(self, key: Union[int, Tuple[int, ...], List[int, ...]]) -> DNDarray:
key = tuple(key)

# assess final global shape
self_proxy = torch.ones((1,)).as_strided(self.gshape, [0] * self.ndim)
self_proxy = self.__torch_proxy__()
gout_full = list(self_proxy[key].shape)

# ellipsis
@@ -1430,7 +1432,7 @@ def __setitem__(
chunk_start = chunk_slice[self.split].start
chunk_end = chunk_slice[self.split].stop

self_proxy = torch.ones((1,)).as_strided(self.gshape, [0] * self.ndim)
self_proxy = self.__torch_proxy__()

# if the value is a DNDarray, the divisions need to be balanced:
# this means that we need to know how much data is where for both DNDarrays
@@ -1612,6 +1614,15 @@ def tolist(self, keepsplit: bool = False) -> List:

return self.__array.tolist()

def __torch_proxy__(self) -> torch.Tensor:
"""
Return a 1-element `torch.Tensor` strided to match the global shape of `self`.
Used internally for sanitation purposes.
"""
return torch.ones((1,), dtype=torch.int8, device=self.larray.device).as_strided(
self.gshape, [0] * self.ndim
)

@staticmethod
def __xitem_get_key_start_stop(
rank: int,
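
The trick behind `__torch_proxy__`: `as_strided` with an all-zero stride vector makes a single `int8` element appear under the full global shape, because every index maps back to the same one byte of storage. Shape-sensitive torch operations (indexing, broadcasting) can then be evaluated against the global shape without materializing or communicating any data, which is exactly how `__getitem__` and `__setitem__` above resolve `self_proxy[key]`. A minimal standalone sketch, using plain torch without heat:

```python
import torch

gshape = (4, 7, 6)  # stand-in for the global shape of a distributed array

# one int8 element, viewed with stride 0 along every axis
proxy = torch.ones((1,), dtype=torch.int8).as_strided(gshape, [0] * len(gshape))

print(tuple(proxy.shape))  # (4, 7, 6)
print(proxy.storage().size() * proxy.storage().element_size())  # 1 (byte)

# shape-sensitive operations now run against the *global* shape,
# e.g. resolving the output shape of an indexing expression
print(proxy[1:3, ..., None].shape)  # torch.Size([2, 7, 6, 1])
```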
4 changes: 2 additions & 2 deletions heat/core/manipulations.py
@@ -3656,7 +3656,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray:
except AttributeError:
x = factories.array(x).reshape(1)

x_proxy = torch.ones((1,)).as_strided(x.gshape, [0] * x.ndim)
x_proxy = x.__torch_proxy__()

# torch-proof args/kwargs:
# torch `reps`: int or sequence of ints; numpy `reps`: can be array-like
@@ -3722,7 +3722,7 @@ def tile(x: DNDarray, reps: Sequence[int, ...]) -> DNDarray:
trans_axes[0], trans_axes[x.split] = x.split, 0
reps[0], reps[x.split] = reps[x.split], reps[0]
x = linalg.transpose(x, trans_axes)
x_proxy = torch.ones((1,)).as_strided(x.gshape, [0] * x.ndim)
x_proxy = x.__torch_proxy__()
out_gshape = tuple(x_proxy.repeat(reps).shape)

local_x = x.larray
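
`tile` applies the same device: `x_proxy.repeat(reps).shape` yields the tiled global output shape, with torch supplying the numpy-style semantics (`len(reps) > x.ndim` prepends axes). Note that `repeat` copies, so this briefly allocates an `int8` tensor of the tiled shape, which is still far cheaper than tiling the actual distributed data. A sketch with a hypothetical helper `tiled_gshape` (not part of the PR):

```python
import torch

def tiled_gshape(gshape, reps):
    """Resolve the global output shape of a tile operation via a stride-0 proxy."""
    proxy = torch.ones((1,), dtype=torch.int8).as_strided(gshape, [0] * len(gshape))
    return tuple(proxy.repeat(reps).shape)

print(tiled_gshape((4, 7, 6), (2, 1, 1)))  # (8, 7, 6)
print(tiled_gshape((), (2, 1)))            # (2, 1): scalar x, as in the new test below
```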
16 changes: 16 additions & 0 deletions heat/core/tests/test_dndarray.py
@@ -1460,6 +1460,22 @@ def test_tolist(self):
]
self.assertListEqual(a.tolist(keepsplit=True), res)

def test_torch_proxy(self):
scalar_array = ht.array(1)
scalar_proxy = scalar_array.__torch_proxy__()
self.assertTrue(scalar_proxy.ndim == 0)
scalar_proxy_nbytes = scalar_proxy.storage().size() * scalar_proxy.storage().element_size()
self.assertTrue(scalar_proxy_nbytes == 1)

dndarray = ht.zeros((4, 7, 6), split=1)
dndarray_proxy = dndarray.__torch_proxy__()
self.assertTrue(dndarray_proxy.ndim == dndarray.ndim)
self.assertTrue(tuple(dndarray_proxy.shape) == dndarray.gshape)
dndarray_proxy_nbytes = (
dndarray_proxy.storage().size() * dndarray_proxy.storage().element_size()
)
self.assertTrue(dndarray_proxy_nbytes == 1)

def test_xor(self):
int16_tensor = ht.array([[1, 1], [2, 2]], dtype=ht.int16)
int16_vector = ht.array([[3, 4]], dtype=ht.int16)
10 changes: 9 additions & 1 deletion heat/core/tests/test_manipulations.py
@@ -3276,14 +3276,22 @@ def test_tile(self):
self.assertTrue((np_tiled == ht_tiled.numpy()).all())
self.assertTrue(ht_tiled.dtype is x.dtype)

# test scalar x
# test scalar DNDarray x
x = ht.array(9.0)
reps = (2, 1)
ht_tiled = ht.tile(x, reps)
np_tiled = np.tile(x.numpy(), reps)
self.assertTrue((np_tiled == ht_tiled.numpy()).all())
self.assertTrue(ht_tiled.dtype is x.dtype)

# test scalar x
x = 10
reps = (2, 1)
ht_tiled = ht.tile(x, reps)
np_tiled = np.tile(x, reps)
self.assertTrue((np_tiled == ht_tiled.numpy()).all())
self.assertTrue(ht_tiled.dtype is ht.int64)

# test distributed tile along split axis
# len(reps) > x.ndim
split = 1