-
Notifications
You must be signed in to change notification settings - Fork 61
Closed
Labels
bug — Something isn't working
Milestone
Description
What happened?
When using a Bcast from MPICommunication in the following code example, it works the first time and the buffers are correct, but broadcasting a second time and beyond does not send the buffer correctly. This happens in the _shuffle method of the DistributedSampler. When using the plain mpi4py Bcast of COMM_WORLD it works every time.
Code snippet triggering the error
from functools import reduce
from typing import Iterator, List
import mpi4py
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Sampler, Dataset, DataLoader, RandomSampler
import os
import torch.multiprocessing as mp
import random
import heat as ht
from heat import DNDarray
import torch.utils.data as data
from heat.core.communication import MPICommunication
class DistributedDataset(Dataset):
    """PyTorch ``Dataset`` wrapping a Heat DNDarray.

    Stores the dndarray and serves items through its process-local ``larray``
    tensor, so each rank only ever hands out the chunk of data it stores.
    """

    def __init__(self, dndarray: DNDarray):
        self.dndarray = dndarray

    def __len__(self) -> int:
        return len(self.dndarray)

    def __getitem__(self, index):
        return self.dndarray.larray[index]

    def __getitems__(self, indices):
        # Batch fetch: look the local tensor up once, then index per item.
        local = self.dndarray.larray
        return tuple(local[i] for i in indices)
class DistributedSampler(Sampler):
    """
    A DistributedSampler for usage in PyTorch with Heat arrays. Uses the nature of the Heat
    DNDarray to serve the locally stored data on the larray. Shuffling is done by shuffling
    the global indices; the yielded indices correspond to indices of the local larray tensor.
    """

    def __init__(self, dataset: DistributedDataset, shuffle: bool = False, seed: int = None) -> None:
        """
        Parameters
        ----------
        dataset : DistributedDataset
            Dataset to be shuffled
        shuffle : bool, optional
            If the underlying DNDarray should be shuffled, by default False
        seed : int, optional
            seed for shuffling, by default None
        """
        self.dataset = dataset
        self.dndarray = dataset.dndarray
        self.shuffle = shuffle
        self.set_seed(seed)

    @staticmethod
    def _in_slice(idx: int, a_slice: slice) -> bool:
        """Check if the given index is inside the given slice.

        Parameters
        ----------
        idx : int
            Index to check
        a_slice : slice
            Slice to check; ``start`` and ``stop`` must not be None

        Returns
        -------
        bool
            Whether the index is in the slice
        """
        if idx < a_slice.start or idx >= a_slice.stop:
            return False
        # a step of None means every element, i.e. an effective step of 1
        step = a_slice.step if a_slice.step else 1
        return (idx - a_slice.start) % step == 0

    def _shuffle(self) -> None:
        """Shuffle the rows of the underlying dndarray across processes.

        Rank 0 draws a global permutation and broadcasts it; every process then
        exchanges row data via ``Alltoallv`` so that afterwards each process
        holds the rows of the permuted array belonging to its local chunk.
        """
        dtype = self.dndarray.dtype.torch_type()
        # NOTE(review): the reported bug reproduces with heat's communicator
        # but not with plain mpi4py COMM_WORLD (see commented Bcast below).
        comm: MPICommunication = ht.MPI_WORLD  # self.dndarray.comm
        rank: int = comm.rank
        world_size: int = comm.size
        N: int = self.dndarray.gshape[0]
        # Rank 0 decides the permutation; every other rank receives it.
        if rank == 0:
            indices = torch.randperm(N, dtype=torch.int32)
        else:
            indices = torch.empty(N, dtype=torch.int32)
        # mpi4py.MPI.COMM_WORLD.Bcast(indices, root=0)
        comm.Bcast(indices)
        print(rank, indices)
        # indice_buffers[r]: global row indices whose data must end up on rank r
        indice_buffers: List[List[int]] = [list() for _ in range(world_size)]
        # Global-row slice owned by each rank under heat's split=0 chunking.
        rank_slices: List[slice] = [
            comm.chunk((N,), split=0, rank=i)[-1][0] for i in range(world_size)
        ]
        local_slice: slice = rank_slices[rank]
        # Now figure out which rank needs to send what to each rank and what this rank will receive
        for i, idx in enumerate(indices):
            idx = idx.item()
            # data_send_rank: rank currently storing row `idx` (source of the data)
            for data_send_rank, tslice in enumerate(rank_slices):
                if not self._in_slice(idx, tslice):
                    continue
                break
            # data_recv_rank: rank that owns position `i` after the permutation
            for data_recv_rank, tslice in enumerate(rank_slices):
                if not self._in_slice(i, tslice):
                    continue
                break
            if data_recv_rank == rank:
                indice_buffers[rank].append(idx)
            elif data_send_rank == rank:
                indice_buffers[data_recv_rank].append(idx)
        torch_send_buffers: List[torch.Tensor] = list()
        # Elements per row: product of all trailing (non-split) dimensions.
        row_length: int = reduce(lambda a, b: a * b, self.dndarray.gshape[1:], 1)
        local_recv_buffer: torch.Tensor = torch.empty(
            len(indice_buffers[rank]) * row_length, dtype=dtype
        )
        for current_rank in range(world_size):
            if current_rank == rank:
                # Keep only rows this rank actually stores locally.
                send_indice = [
                    idx for idx in indice_buffers[current_rank] if self._in_slice(idx, local_slice)
                ]
            else:
                send_indice = indice_buffers[current_rank]
            if len(send_indice) == 1:
                send_indice = tuple(send_indice)  # issue#1816
            buf = self.dndarray[send_indice].larray
            torch_send_buffers.append(buf)
        send_elems = [torch.flatten(elem) for elem in torch_send_buffers]
        send_counts = torch.tensor([len(elem) for elem in send_elems])
        send_displs = torch.zeros(world_size)
        send_displs[1:] = torch.cumsum(send_counts[:-1], dim=0)
        # Count how many of our incoming rows originate from each rank.
        recv_counts = torch.zeros(world_size)
        for idx in indice_buffers[rank]:
            for i, tslice in enumerate(rank_slices):
                if not self._in_slice(idx, tslice):
                    continue
                recv_counts[i] += 1
                break
        recv_counts *= row_length
        recv_displs = torch.zeros(world_size)
        recv_displs[1:] = torch.cumsum(recv_counts[:-1], dim=0)
        comm.Alltoallv(
            (torch.cat(send_elems).contiguous(), send_counts, send_displs),
            (local_recv_buffer, recv_counts, recv_displs),
        )
        arr = local_recv_buffer.reshape(-1, *self.dndarray.gshape[1:])
        self.dndarray.larray = arr

    def set_seed(self, value: int) -> None:
        """Set the seed used for ``torch.randperm`` in ``_shuffle``.

        Parameters
        ----------
        value : int, optional
            seed to set; ``None`` (the ``__init__`` default) records the value
            but leaves the global torch RNG state untouched
        """
        self._seed = value
        # BUGFIX: ``torch.manual_seed(None)`` raises a TypeError, yet None is
        # the documented default for ``seed`` in ``__init__`` — only seed the
        # RNG when an actual value was supplied.
        if value is not None:
            torch.manual_seed(value)

    def __iter__(self) -> Iterator[int]:
        """Optionally reshuffle across processes, then iterate local row indices."""
        if self.shuffle:
            self._shuffle()
        self.indices = list(range(len(self.dndarray.larray)))
        return iter(self.indices)

    def __len__(self) -> int:
        return len(self.dndarray.larray)
def setup():
    """Initialize the torch.distributed process group from Open MPI env vars.

    Rank and world size are taken from the OMPI_COMM_WORLD_* variables that
    mpirun sets; master address/port fall back to localhost defaults.
    """
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
    # Re-export under the names init_method="env://" expects.
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = os.environ.get("MASTER_ADDR", "127.0.0.1")
    os.environ["MASTER_PORT"] = os.environ.get("MASTER_PORT", "29500")
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method="env://")
def cleanup():
    """Tear down the torch.distributed process group created by ``setup``."""
    dist.destroy_process_group()
def main():
    """Reproduction driver: shuffle a small distributed array repeatedly.

    Builds a 5x5 DNDarray split along axis 0, shuffles it three times (the
    reported bug appears from the second shuffle on), then iterates the
    sampler and the dataloader, printing per-rank results.
    """
    dndarray = ht.arange(25, split=0).reshape(5, 5).resplit(0)
    dataset = DistributedDataset(dndarray)
    rank = ht.MPI_WORLD.rank
    sampler = DistributedSampler(dataset, shuffle=True, seed=42)
    dataloader = DataLoader(dataset, batch_size=1, sampler=sampler)
    for _ in range(3):
        sampler._shuffle()
        print(dataset.dndarray)
    print(f"Process {rank} gets indices: {list(iter(sampler))}")
    for batch in dataloader:
        print(f"Process {rank}, Batch: {batch}")
if __name__ == "__main__":
setup()
main()
cleanup()

Error message or erroneous outcome
❯ mpirun -np 2 python test.py
0 tensor([2, 4, 3, 0, 1], dtype=torch.int32)
1 tensor([2, 4, 3, 0, 1], dtype=torch.int32)
...
1 tensor([455, 0, 0, 0, 32], dtype=torch.int32)
0 tensor([1, 4, 3, 2, 0], dtype=torch.int32)

Version
Branch of PR #1807
Python version
3.11
PyTorch version
2.5
MPI version
mpirun (Open MPI) 4.1.6

Reactions are currently unavailable
Metadata
Metadata
Assignees
Labels
bug — Something isn't working
Type
Projects
Status
Done