docs: doc string improvements and dead code removal
JuanPedroGHM committed Jun 18, 2024
1 parent 48939e1 commit c083678
Showing 5 changed files with 37 additions and 60 deletions.
34 changes: 14 additions & 20 deletions heat/core/communication.py
@@ -250,12 +250,7 @@ def mpi_type_of(cls, dtype: torch.dtype) -> MPI.Datatype:
         Parameters
         ----------
         dtype : torch.dtype
-            _description_
-        Returns
-        -------
-        MPI.Datatype
-            _description_
+            PyTorch data type
         """
         return cls.__mpi_type_mappings[dtype]

@@ -1465,14 +1460,15 @@ def Alltoallw(
         recvbuf: Union[DNDarray, torch.Tensor, Any],
     ):
         """
-        Generalized All-to-All communication allowing different counts, displacements and datatypes for each partner.
+        Generalized All-to-All communication allowing different counts, displacements and datatypes for each partner. See MPI standard for more information.
         Parameters
         ----------
         sendbuf: Union[DNDarray, torch.Tensor, Any]
-            Buffer address of the send message
+            Buffer address of the send message. The buffer is expected to be a tuple of the form (buffer, (counts, displacements), subarray_params_list), where subarray_params_list is a list of tuples of the form (lshape, subsizes, substarts).
         recvbuf: Union[DNDarray, torch.Tensor, Any]
-            Buffer address where to store the result
+            Buffer address where to store the result. The buffer is expected to be a tuple of the form (buffer, (counts, displacements), subarray_params_list), where subarray_params_list is a list of tuples of the form (lshape, subsizes, substarts).
         """
         # Unpack sendbuffer information
         sendbuf_tensor, (send_counts, send_displs), subarray_params_list = sendbuf
@@ -1530,7 +1526,6 @@ def Alltoallw(
             target_subarray_types.append(MPI.INT)
 
         # Perform the Alltoallw operation
-        # print("Sendbuf: ", stride, (send_counts, send_displs), source_subarray_types)
         self.handle.Alltoallw(
             [sendbuf_ptr, (send_counts, send_displs), source_subarray_types],
             [recvbuf_ptr, (recv_counts, recv_displs), target_subarray_types],
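
To make the message layout above concrete, here is a minimal, self-contained mpi4py sketch (not Heat code; ranks, sizes, and values are invented for illustration) of Alltoallw with per-partner counts, byte displacements, and datatypes, run under MPI with e.g. `mpirun -n 4 python demo.py`:

# Hypothetical standalone example, independent of the wrapper above.
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
size = comm.Get_size()

# Each rank sends one double to every rank and receives one from every rank.
sendbuf = np.arange(size, dtype=np.float64) + 100 * comm.rank
recvbuf = np.empty(size, dtype=np.float64)

counts = [1] * size                                        # one element per partner
displs = [i * MPI.DOUBLE.Get_size() for i in range(size)]  # displacements in bytes
types = [MPI.DOUBLE] * size                                # may differ per partner

comm.Alltoallw([sendbuf, (counts, displs), types],
               [recvbuf, (counts, displs), types])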
@@ -1557,7 +1552,11 @@ def Alltoallw(
     Alltoallw.__doc__ = MPI.Comm.Alltoallw.__doc__
 
     def _create_recursive_vectortype(
-        self, datatype: MPI.Datatype, tensor_stride, subarray_sizes, start
+        self,
+        datatype: MPI.Datatype,
+        tensor_stride: Tuple[int],
+        subarray_sizes: List[int],
+        start: List[int],
     ):
         """
         Create a recursive vector to handle non-contiguous tensor data. The created datatype will be a recursively defined vector datatype that will enable the collection of non-contiguous tensor data in the specified subarray sizes.
@@ -1566,21 +1565,16 @@ def _create_recursive_vectortype(
         ----------
         datatype : MPI.Datatype
             The base datatype to create the recursive vector datatype from.
-        tensor_stride : list
+        tensor_stride : Tuple[int]
             A list of tensor strides for each dimension.
-        subarray_sizes : list
+        subarray_sizes : List[int]
             A list of subarray sizes for each dimension.
-        start: list
+        start: List[int]
             Index of the first element of the subarray in the original array.
-        Returns
-        -------
-        MPI.Datatype
-            The created recursive vector datatype.
         Notes
         -----
-        This function creates a recursive vector datatype by defining vectors out of the previous datatype with specified strides and sizes. The extent of the new datatype is set to the extent of the basic datatype to allow interweaving of data.
+        This function creates a recursive vector datatype by defining vectors out of the previous datatype with specified strides and sizes. The extent (size of the data type in bytes) of the new datatype is set to the extent of the basic datatype to allow interweaving of data.
         Examples
         --------
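
As a rough, hedged illustration of the Notes above (a plain mpi4py sketch, not the Heat implementation): a strided vector type is resized back to the extent of its base type so that consecutive elements of the new type interleave rather than stack end to end:

# Hypothetical sketch; count, blocklength and stride are made-up values.
from mpi4py import MPI

base = MPI.DOUBLE
count, blocklength, stride = 4, 1, 3   # take every 3rd double, 4 times

vector = base.Create_vector(count, blocklength, stride)
_, base_extent = base.Get_extent()     # extent of one double, in bytes

# Shrinking the extent makes the next element of `resized` start one base
# element further along, which permits interleaved, non-contiguous layouts.
resized = vector.Create_resized(0, base_extent)
resized.Commit()
# ... use in communication calls, then release the handle
resized.Free()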
2 changes: 1 addition & 1 deletion heat/core/dndarray.py
@@ -1483,7 +1483,7 @@ def resplit_(self, axis: int = None):
         )
 
         self._axis2axisResplit(
-            self.comm, self.split, self.larray, arr_tiles, axis, recv_buffer, new_tiles
+            self.comm, self.larray, self.split, arr_tiles, recv_buffer, axis, new_tiles
         )
 
         self.__array = recv_buffer
45 changes: 19 additions & 26 deletions heat/core/manipulations.py
@@ -3544,7 +3544,7 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray:
     new_tiles = tiling.SplitTiles(new_arr)
 
     new_arr.larray = _axis2axisResplit(
-        arr.comm, arr.split, arr.larray, arr_tiles, axis, new_arr.larray, new_tiles
+        arr.comm, arr.larray, arr.split, arr_tiles, new_arr.larray, axis, new_tiles
     )
 
     return new_arr
@@ -3558,61 +3558,54 @@ def resplit(arr: DNDarray, axis: int = None) -> DNDarray:
 
 def _axis2axisResplit(
     comm: Communication,
-    source_axis: int,
-    source_array: torch.Tensor,
+    source_larray: torch.Tensor,
+    source_split: int,
     source_tiles: tiling.SplitTiles,
-    target_axis: int,
-    target_array: torch.Tensor,
+    target_larray: torch.Tensor,
+    target_split: int,
     target_tiles: tiling.SplitTiles,
 ) -> torch.Tensor:
     """
-    Resplits the input array along a new axis and performs data exchange using MPI_Alltoallw.
+    Resplits the input array along a new axis and performs data exchange using MPI_Alltoallw. Returns target_larray object with the data after the exchange.
     Parameters
     ----------
     comm : Communication
         The communication object for MPI communication.
-    source_axis : int
-        The axis along which the source array is split.
-    source_array : torch.Tensor
+    source_larray : torch.Tensor
         The source array to be resplit.
+    source_split : int
+        The axis along which the source array is split.
     source_tiles : tiling.SplitTiles
         The tiling object containing the subarray parameters for the source array.
-    target_axis : int
-        The axis along which the target array is split.
-    target_array : torch.Tensor
+    target_larray : torch.Tensor
         The target array to store the resplit data.
+    target_split : int
+        The axis along which the target array is split.
     target_tiles : tiling.SplitTiles
         The tiling object containing the subarray parameters for the target array.
-    Returns
-    -------
-    torch.Tensor
-        The resplit target array.
     """
     # Create subarray types for original local shapes split along the new axis
-    source_subarray_params = source_tiles.get_subarray_params(source_axis, target_axis)
+    source_subarray_params = source_tiles.get_subarray_params(source_split, target_split)
 
     # Create subarray types for resplit local array along the old axis
-    target_subarray_params = target_tiles.get_subarray_params(target_axis, source_axis)
+    target_subarray_params = target_tiles.get_subarray_params(target_split, source_split)
 
     world_size = comm.Get_size()
     counts = [1] * world_size
     displs = [0] * world_size
 
     # Perform the data exchange using MPI_Alltoallw
     comm.Alltoallw(
-        (source_array, (counts.copy(), displs.copy()), source_subarray_params),
-        (target_array, (counts.copy(), displs.copy()), target_subarray_params),
+        (source_larray, (counts.copy(), displs.copy()), source_subarray_params),
+        (target_larray, (counts.copy(), displs.copy()), target_subarray_params),
     )
 
-    # print(f"Source axis: {source_axis}, Source array: {source_array}")
-    # print(f"Target axis: {target_axis}, Target array: {target_array}")
-    return target_array
+    return target_larray
 
 
-DNDarray._axis2axisResplit = lambda self, comm, source_axis, source_array, source_tiles, target_axis, target_array, target_tile: _axis2axisResplit(
-    comm, source_axis, source_array, source_tiles, target_axis, target_array, target_tile
+DNDarray._axis2axisResplit = lambda self, comm, source_larray, source_split, source_tiles, target_larray, target_split, target_tile: _axis2axisResplit(
+    comm, source_larray, source_split, source_tiles, target_larray, target_split, target_tile
 )
 DNDarray._axis2axisResplit.__doc__ = _axis2axisResplit.__doc__
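
For context, the user-facing entry points to this machinery are ht.resplit and the in-place DNDarray.resplit_ exercised in the tests below. A small usage sketch (shapes and process count chosen arbitrarily), run under MPI with e.g. `mpirun -n 2 python demo.py`:

# Hypothetical usage sketch of the public resplit API.
import heat as ht

a = ht.zeros((10, 5, 2), split=2)   # distributed along the last axis
b = ht.resplit(a, axis=1)           # new array, redistributed along axis 1
a.resplit_(axis=1)                  # in-place redistribution of `a`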

7 changes: 0 additions & 7 deletions heat/core/tests/test_dndarray.py
@@ -1108,14 +1108,9 @@ def test_resplit(self):

         # 3D non-contiguous resplit testing (Column mayor ordering)
         torch_array = torch.arange(100, device=self.device.torch_device).reshape((10, 5, 2))
-        print(torch_array.device)
         heat_array = ht.array(torch_array, split=2, order="F")
-        print(heat_array.device)
         heat_array.resplit_(axis=1)
-        print(heat_array.device)
         res = np.arange(100).reshape(10, 5, 2)
-        print(heat_array.device)
-        print(ht.array(res).device)
         self.assertTrue(ht.array(res).device == heat_array.device)
         self.assertTrue(ht.all(heat_array == ht.array(res)))
         self.assertEqual(heat_array.split, 1)
@@ -1127,8 +1122,6 @@ def test_resplit(self):
         res = torch_array.cpu().numpy().transpose((3, 1, 2, 0))
         heat_array = ht.array(torch_array, split=2).transpose((3, 1, 2, 0))
         heat_array.resplit_(axis=1)
-        print(heat_array.device)
-        print(ht.array(res).device)
         self.assertTrue(ht.array(res).device == heat_array.device)
         self.assertTrue(ht.all(heat_array == ht.array(res)))
         self.assertEqual(heat_array.split, 1)
9 changes: 3 additions & 6 deletions heat/core/tiling.py
@@ -331,19 +331,16 @@ def __setitem__(
     def get_subarray_params(
         self, from_axis: int, to_axis: int
     ) -> List[Tuple[List[int], List[int], List[int]]]:
-        """Create subarray types of the local array along a new split axis. For use with alltoallw.
+        """Create subarray types of the local array along a new split axis. For use with Alltoallw.
+        Return type is a list of tuples, each tuple containing the shape of the local array, the shape of the subarray, and the start index of the subarray.
         Parameters
         ----------
         from_axis : int
             Current split axis of global array.
         to_axis : int
             New split axis of of subarrays array.
-        Returns
-        -------
-        List[Tuple[List[int], List[int], List[int]]]
-            List of subarray parameters for all processes. For use with Create_subarray.
         """
         arr = self.__DNDarray
         world_size = arr.comm.Get_size()
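
Each (lshape, subsizes, substarts) tuple returned here is meant for MPI.Datatype.Create_subarray, as the removed Returns section said. A hedged sketch of that step, with invented shape values:

# Hypothetical sketch; the tuple values below are made up for illustration.
from mpi4py import MPI

lshape, subsizes, substarts = [10, 5, 2], [10, 5, 1], [0, 0, 1]

subarray_type = MPI.DOUBLE.Create_subarray(lshape, subsizes, substarts,
                                           order=MPI.ORDER_C)
subarray_type.Commit()
# ... pass as one entry of the datatype list handed to Alltoallw, then free
subarray_type.Free()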
