Skip to content

[Bug]: Broadcasting after the first time does not send buffer correctly. #1830

@Berkant03

Description

@Berkant03

What happened?

When using an Bcast from MPICommunication in the following code example, it does work the first and the buffers are correct, but broadcasting the second time and beyond does not send the buffer correctly. This happens in the _shuffle Method of the DistributedSampler. When using the normal mpi4py Bcast of COMM_WORLD it works all the time.

Code snippet triggering the error

from functools import reduce
from typing import Iterator, List
import mpi4py
import numpy as np
import torch
import torch.distributed as dist
from torch.utils.data import Sampler, Dataset, DataLoader, RandomSampler
import os
import torch.multiprocessing as mp
import random
import heat as ht
from heat import DNDarray
import torch.utils.data as data

from heat.core.communication import MPICommunication


class DistributedDataset(Dataset):
    """
    A DistributedDataset for usage in PyTorch. Saves the dndarray and the larray tensor. Uses the larray tensor
    for the distribution and getting the items.
    """
    def __init__(self, dndarray: DNDarray):
        self.dndarray = dndarray

    def __len__(self) -> int:
        return len(self.dndarray)

    def __getitem__(self, index):
        return self.dndarray.larray[index]

    def __getitems__(self, indices):
        return tuple(self.dndarray.larray[index] for index in indices)


class DistributedSampler(Sampler):
    """
    A DistributedSampler for usage in PyTorch with Heat Arrays. Uses the nature of the Heat DNDArray
    to give the locally stored data on the larray. Shuffling is done by shuffling the indices.
    The given Indices corrospond to the index of the larray tensor.
    """
    def __init__(self, dataset: DistributedDataset, shuffle: bool = False, seed: int = None) -> None:
        """
        Parameters
        ----------
        dataset : DistributedDataset
            Dataset to be shuffled
        shuffle : bool, optional
            If the underlying DNDarray should be shuffled, by default False
        seed : int, optional
            seed for shuffling, by default None
        """
        self.dataset = dataset
        self.dndarray = dataset.dndarray
        self.shuffle = shuffle
        self.set_seed(seed)

    @staticmethod
    def _in_slice(idx: int, a_slice: slice) -> bool:
        """Check if the given index is inside the given slice

        Parameters
        ----------
        idx : int
            Index to check
        a_slice : slice
            Slice to check

        Returns
        -------
        bool
            Wether index is in slice
        """
        if idx < a_slice.start or idx >= a_slice.stop:
            return False
        step = a_slice.step if a_slice.step else 1
        if (idx - a_slice.start) % step == 0:
            return True
        else:
            return False

    def _shuffle(self) -> None:
        """Shuffles the given dndarray at creation across processes."""
        dtype = self.dndarray.dtype.torch_type()
        comm: MPICommunication = ht.MPI_WORLD  #self.dndarray.comm
        rank: int = comm.rank
        world_size: int = comm.size
        N: int = self.dndarray.gshape[0]

        if rank == 0:
            indices = torch.randperm(N, dtype=torch.int32)
        else:
            indices = torch.empty(N, dtype=torch.int32)
        # mpi4py.MPI.COMM_WORLD.Bcast(indices, root=0)
        comm.Bcast(indices)
        print(rank, indices)

        indice_buffers: List[List[int]] = [list() for _ in range(world_size)]
        rank_slices: List[slice] = [
            comm.chunk((N,), split=0, rank=i)[-1][0] for i in range(world_size)
        ]
        local_slice: slice = rank_slices[rank]

        # Now figure out which rank needs to send what to each rank and what this rank will receive
        for i, idx in enumerate(indices):
            idx = idx.item()
            for data_send_rank, tslice in enumerate(rank_slices):
                if not self._in_slice(idx, tslice):
                    continue
                break
            for data_recv_rank, tslice in enumerate(rank_slices):
                if not self._in_slice(i, tslice):
                    continue
                break
            if data_recv_rank == rank:
                indice_buffers[rank].append(idx)
            elif data_send_rank == rank:
                indice_buffers[data_recv_rank].append(idx)

        torch_send_buffers: List[torch.Tensor] = list()
        row_length: int = reduce(lambda a, b: a * b, self.dndarray.gshape[1:], 1)
        local_recv_buffer: torch.Tensor = torch.empty(
            len(indice_buffers[rank]) * row_length, dtype=dtype
        )

        for current_rank in range(world_size):
            if current_rank == rank:
                send_indice = [
                    idx for idx in indice_buffers[current_rank] if self._in_slice(idx, local_slice)
                ]
            else:
                send_indice = indice_buffers[current_rank]

            if len(send_indice) == 1:
                send_indice = tuple(send_indice)  # issue#1816

            buf = self.dndarray[send_indice].larray
            torch_send_buffers.append(buf)

        send_elems = [torch.flatten(elem) for elem in torch_send_buffers]
        send_counts = torch.tensor([len(elem) for elem in send_elems])
        send_displs = torch.zeros(world_size)
        send_displs[1:] = torch.cumsum(send_counts[:-1], dim=0)

        recv_counts = torch.zeros(world_size)
        for idx in indice_buffers[rank]:
            for i, tslice in enumerate(rank_slices):
                if not self._in_slice(idx, tslice):
                    continue
                recv_counts[i] += 1
                break
        recv_counts *= row_length

        recv_displs = torch.zeros(world_size)
        recv_displs[1:] = torch.cumsum(recv_counts[:-1], dim=0)

        comm.Alltoallv(
            (torch.cat(send_elems).contiguous(), send_counts, send_displs),
            (local_recv_buffer, recv_counts, recv_displs),
        )

        arr = local_recv_buffer.reshape(-1, *self.dndarray.gshape[1:])
        self.dndarray.larray = arr

    def set_seed(self, value: int) -> None:
        """Sets the seed for the torch.randperm

        Parameters
        ----------
        value : int
            seed to set
        """
        self._seed = value
        torch.manual_seed(value)

    def __iter__(self) -> Iterator[int]:
        if self.shuffle:
            self._shuffle()
        self.indices = list(range(len(self.dndarray.larray)))
        return iter(self.indices)

    def __len__(self) -> int:
        return len(self.dndarray.larray)


def setup():
    backend = "nccl" if torch.cuda.is_available() else "gloo"

    rank = int(os.environ["OMPI_COMM_WORLD_RANK"]) 
    world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])

    master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
    master_port = os.environ.get("MASTER_PORT", "29500")

    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = master_addr
    os.environ["MASTER_PORT"] = master_port

    dist.init_process_group(backend=backend, rank=rank, world_size=world_size, init_method="env://")


def cleanup():
    dist.destroy_process_group()


def main():
    dataset = DistributedDataset(ht.arange(25, split=0).reshape(5, 5).resplit(0))
    
    rank = ht.MPI_WORLD.rank
    sampler = DistributedSampler(dataset, shuffle=True, seed=42)
    dataloader = DataLoader(dataset, batch_size=1, sampler=sampler)

    # print(rank, dataset.dndarray.larray)
    sampler._shuffle()
    print(dataset.dndarray)
    # print(rank, dataset.dndarray.larray)
    sampler._shuffle()
    print(dataset.dndarray)
    sampler._shuffle()
    print(dataset.dndarray)
    # print(rank, dataset.dndarray)

    print(f"Process {rank} gets indices: {list(iter(sampler))}")

    for batch in dataloader:
        print(f"Process {rank}, Batch: {batch}")

if __name__ == "__main__":
    setup()
    main()
    cleanup()

Error message or erroneous outcome

❯ mpirun -np 2 python test.py
0 tensor([2, 4, 3, 0, 1], dtype=torch.int32)
1 tensor([2, 4, 3, 0, 1], dtype=torch.int32)
...

1 tensor([455,   0,   0,   0,  32], dtype=torch.int32)
0 tensor([1, 4, 3, 2, 0], dtype=torch.int32)

Version

Branch of PR #1807

Python version

3.11

PyTorch version

2.5

MPI version

mpirun (Open MPI) 4.1.6

Metadata

Metadata

Assignees

No one assigned

    Labels

    bugSomething isn't working

    Type

    Projects

    Status

    Done

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions