Skip to content

Commit

Permalink
Merge pull request #1509 from helmholtz-analytics/features/allow-ones…
Browse files Browse the repository at this point in the history
…ided-halo

Support one-sided halo for DNDarrays
  • Loading branch information
FOsterfeld committed Jun 5, 2024
2 parents 5cdd986 + 63aa686 commit 9b059f8
Showing 1 changed file with 23 additions and 15 deletions.
38 changes: 23 additions & 15 deletions heat/core/dndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,14 +384,18 @@ def __prephalo(self, start, end) -> torch.Tensor:

return self.__array[ix].clone().contiguous()

def get_halo(self, halo_size: int) -> torch.Tensor:
def get_halo(self, halo_size: int, prev: bool = True, next: bool = True) -> torch.Tensor:
"""
Fetch halos of size ``halo_size`` from neighboring ranks and save them in ``self.halo_next/self.halo_prev``.
Parameters
----------
halo_size : int
Size of the halo.
prev : bool, optional
If True, fetch the halo from the previous rank. Default: True.
next : bool, optional
If True, fetch the halo from the next rank. Default: True.
"""
if not isinstance(halo_size, int):
raise TypeError(
Expand Down Expand Up @@ -433,25 +437,29 @@ def get_halo(self, halo_size: int) -> torch.Tensor:
req_list = []

# exchange data with next populated process
if rank != last_rank:
self.comm.Isend(a_next, next_rank)
res_prev = torch.zeros(
a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_prev, source=next_rank))
if prev:
if rank != last_rank:
self.comm.Isend(a_next, next_rank)
if rank != first_rank:
res_prev = torch.zeros(
a_prev.size(), dtype=a_prev.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_prev, source=prev_rank))

if rank != first_rank:
self.comm.Isend(a_prev, prev_rank)
res_next = torch.zeros(
a_next.size(), dtype=a_next.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_next, source=prev_rank))
if next:
if rank != first_rank:
self.comm.Isend(a_prev, prev_rank)
if rank != last_rank:
res_next = torch.zeros(
a_next.size(), dtype=a_next.dtype, device=self.device.torch_device
)
req_list.append(self.comm.Irecv(res_next, source=next_rank))

for req in req_list:
req.Wait()

self.__halo_next = res_prev
self.__halo_prev = res_next
self.__halo_next = res_next
self.__halo_prev = res_prev
self.__ishalo = True

def __cat_halo(self) -> torch.Tensor:
Expand Down

0 comments on commit 9b059f8

Please sign in to comment.