Features/481 halos #541

Merged
merged 22 commits on Apr 24, 2020
112 changes: 112 additions & 0 deletions heat/core/dndarray.py
@@ -72,6 +72,9 @@ def __init__(self, array, gshape, dtype, split, device, comm):
        self.__split = split
        self.__device = device
        self.__comm = comm
        self.__ishalo = False
        self.__halo_next = None
        self.__halo_prev = None

        # handle inconsistencies between torch and heat devices
        if (
@@ -81,6 +84,14 @@
        ):
            self.__array = self.__array.to(devices.sanitize_device(self.__device).torch_device)

    @property
    def halo_next(self):
        """Return the halo received from the next rank, or None if none is stored."""
        return self.__halo_next

    @property
    def halo_prev(self):
        """Return the halo received from the previous rank, or None if none is stored."""
        return self.__halo_prev

    @property
    def comm(self):
        return self.__comm
@@ -248,6 +259,107 @@ def strides(self):
    def T(self):
        return linalg.transpose(self, axes=None)

    @property
    def array_with_halos(self):
        """Return the local array concatenated with the stored halos along the split dimension."""
        return self.__cat_halo()

    def __prephalo(self, start, end):
        """
        Extracts the halo indexed by start, end from self.array in the direction of self.split

        Parameters
        ----------
        start : int
            start index of the halo extracted from self.array
        end : int
            end index of the halo extracted from self.array

        Returns
        -------
        halo : torch.Tensor
            The halo extracted from self.array
        """
        ix = [slice(None, None, None)] * len(self.shape)
        try:
            ix[self.split] = slice(start, end)
        except IndexError:
            print("Indices out of bounds")

        return self.__array[ix].clone().contiguous()

    def get_halo(self, halo_size):
        """
        Fetch halos of size 'halo_size' from neighboring ranks and save them in self.halo_next/self.halo_prev
        in case they are not already stored. If 'halo_size' differs from the size of already stored halos,
        they are overwritten.

        Parameters
        ----------
        halo_size : int
            Size of the halo.
        """
        if not isinstance(halo_size, int):
            raise TypeError(
                "halo_size needs to be of Python type int, {} given".format(type(halo_size))
            )
        if halo_size < 0:
            raise ValueError(
                "halo_size needs to be a positive Python integer, {} given".format(halo_size)
            )

        if self.comm.is_distributed() and self.split is not None:
            min_chunksize = self.shape[self.split] // self.comm.size
            if halo_size > min_chunksize:
                raise ValueError(
                    "halo_size {} needs to be smaller than the chunk size {}".format(
                        halo_size, min_chunksize
                    )
                )

            a_prev = self.__prephalo(0, halo_size)
            a_next = self.__prephalo(-halo_size, None)

            res_prev = None
            res_next = None

            req_list = list()

            # exchange with the next rank: send our trailing slices,
            # receive its leading slices into res_prev
            if self.comm.rank != self.comm.size - 1:
                self.comm.Isend(a_next, self.comm.rank + 1)
                res_prev = torch.zeros(a_prev.size(), dtype=a_prev.dtype)
                req_list.append(self.comm.Irecv(res_prev, source=self.comm.rank + 1))

            # exchange with the previous rank: send our leading slices,
            # receive its trailing slices into res_next
            if self.comm.rank != 0:
                self.comm.Isend(a_prev, self.comm.rank - 1)
                res_next = torch.zeros(a_next.size(), dtype=a_next.dtype)
                req_list.append(self.comm.Irecv(res_next, source=self.comm.rank - 1))

            for req in req_list:
                req.wait()

            self.__halo_next = res_prev
            self.__halo_prev = res_next
            self.__ishalo = True

    def __cat_halo(self):
        """
        Concatenate the stored halos with the local array along the split dimension.

        Returns
        -------
        array + halos : torch.Tensor
            The local array extended by self.halo_prev and self.halo_next, where present.
        """
        return torch.cat(
            [tensor for tensor in (self.__halo_prev, self.__array, self.__halo_next) if tensor is not None],
            self.split,
        )

    def abs(self, out=None, dtype=None):
        """
        Calculate the absolute value element-wise.
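For illustration, a minimal usage sketch of the new halo API (not part of the PR itself; assumes a working heat installation and an MPI launch such as mpirun -np 2 python halo_example.py):

import heat as ht

# DNDarray distributed along dimension 1; each rank holds a contiguous block of columns
x = ht.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]], split=1)

# fetch one slice from each neighboring rank into the halo buffers
x.get_halo(1)

# local chunk concatenated with the received halos along the split dimension
print(x.comm.rank, x.array_with_halos.shape)

With two processes, each rank reports a local shape of (2, 4): its three own columns plus one halo column from its single neighbor.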
95 changes: 95 additions & 0 deletions heat/core/tests/test_dndarray.py
@@ -35,6 +35,101 @@ def test_and(self):
            ht.equal(int16_tensor & int16_vector, ht.bitwise_and(int16_tensor, int16_vector))
        )

    def test_gethalo(self):
        data_np = np.array([[1, 2, 3, 4, 5, 6], [7, 8, 9, 10, 11, 12]])
        data = ht.array(data_np, split=1)

        if data.comm.size == 2:

            halo_next = torch.tensor(np.array([[4, 5], [10, 11]]))
            halo_prev = torch.tensor(np.array([[2, 3], [8, 9]]))

            data.get_halo(2)

            data_with_halos = data.array_with_halos
            self.assertEqual(data_with_halos.shape, (2, 5))

            if data.comm.rank == 0:
                self.assertTrue(torch.equal(data.halo_next, halo_next))
                self.assertEqual(data.halo_prev, None)
            if data.comm.rank == 1:
                self.assertTrue(torch.equal(data.halo_prev, halo_prev))
                self.assertEqual(data.halo_next, None)

            self.assertEqual(data.array_with_halos.shape, (2, 5))
            # exception on wrong argument type in get_halo
            with self.assertRaises(TypeError):
                data.get_halo("wrong_type")
            # exception on wrong argument in get_halo
            with self.assertRaises(ValueError):
                data.get_halo(-99)
            # exception for too large halos
            with self.assertRaises(ValueError):
                data.get_halo(4)

            data_np = np.array([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], [7.0, 8.0, 9.0, 10.0, 11.0, 12.0]])
            data = ht.array(data_np, split=1)

            halo_next = torch.tensor(np.array([[4.0, 5.0], [10.0, 11.0]]))
            halo_prev = torch.tensor(np.array([[2.0, 3.0], [8.0, 9.0]]))

            data.get_halo(2)

            if data.comm.rank == 0:
                self.assertTrue(np.isclose(((data.halo_next - halo_next) ** 2).mean().item(), 0.0))
                self.assertEqual(data.halo_prev, None)
            if data.comm.rank == 1:
                self.assertTrue(np.isclose(((data.halo_prev - halo_prev) ** 2).mean().item(), 0.0))
                self.assertEqual(data.halo_next, None)

            data = ht.ones((10, 2), split=0)

            halo_next = torch.tensor(np.array([[1.0, 1.0], [1.0, 1.0]]))
            halo_prev = torch.tensor(np.array([[1.0, 1.0], [1.0, 1.0]]))

            data.get_halo(2)

            if data.comm.rank == 0:
                self.assertTrue(np.isclose(((data.halo_next - halo_next) ** 2).mean().item(), 0.0))
                self.assertEqual(data.halo_prev, None)
            if data.comm.rank == 1:
                self.assertTrue(np.isclose(((data.halo_prev - halo_prev) ** 2).mean().item(), 0.0))
                self.assertEqual(data.halo_next, None)

        if data.comm.size == 3:

            halo_1 = torch.tensor(np.array([[2], [8]]))
            halo_2 = torch.tensor(np.array([[3], [9]]))
            halo_3 = torch.tensor(np.array([[4], [10]]))
            halo_4 = torch.tensor(np.array([[5], [11]]))

            data.get_halo(1)

            data_with_halos = data.array_with_halos

            if data.comm.rank == 0:
                self.assertTrue(torch.equal(data.halo_next, halo_2))
                self.assertEqual(data.halo_prev, None)
                self.assertEqual(data_with_halos.shape, (2, 3))
            if data.comm.rank == 1:
                self.assertTrue(torch.equal(data.halo_prev, halo_1))
                self.assertTrue(torch.equal(data.halo_next, halo_4))
                self.assertEqual(data_with_halos.shape, (2, 4))
            if data.comm.rank == 2:
                self.assertEqual(data.halo_next, None)
                self.assertTrue(torch.equal(data.halo_prev, halo_3))
                self.assertEqual(data_with_halos.shape, (2, 3))

            # exception on wrong argument type in get_halo
            with self.assertRaises(TypeError):
                data.get_halo("wrong_type")
            # exception on wrong argument in get_halo
            with self.assertRaises(ValueError):
                data.get_halo(-99)
            # exception for too large halos
            with self.assertRaises(ValueError):
                data.get_halo(4)

    def test_astype(self):
        data = ht.float32([[1, 2, 3], [4, 5, 6]], device=ht_device)
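Note that test_gethalo only exercises the halo exchange when the suite is launched with two or three MPI processes, since all assertions sit behind the comm.size guards above; a single-process run skips them. A plausible invocation, assuming the project's usual MPI-based test setup, is mpirun -np 3 python -m unittest heat.core.tests.test_dndarray.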
9 changes: 6 additions & 3 deletions heat/core/tiling.py
@@ -376,9 +376,12 @@ def __init__(self, arr, tiles_per_proc=2):
        # then the local data needs to be redistributed to fit the full diagonal on as many
        # processes as possible
        # if any(lshape_map[..., arr.split] == 1):
-        last_diag_pr, col_per_proc_list, col_inds, tile_columns = self.__adjust_lshape_sp0_1tile(
-            arr, col_inds, lshape_map, tiles_per_proc
-        )
+        (
+            last_diag_pr,
+            col_per_proc_list,
+            col_inds,
+            tile_columns,
+        ) = self.__adjust_lshape_sp0_1tile(arr, col_inds, lshape_map, tiles_per_proc)
        # re-test for empty processes and remove empty rows
        empties = torch.where(lshape_map[..., 0] == 0)[0]
        if empties.numel() > 0: