diff --git a/sigpy/backend.py b/sigpy/backend.py index 7dcd6cad..7d5a2324 100644 --- a/sigpy/backend.py +++ b/sigpy/backend.py @@ -265,17 +265,21 @@ def reduce(self, input, root=0): else: self.mpi_comm.Reduce(cpu_input, None, root=root) - def bcast(self, input, root=0): + def bcast(self, input, root=0, datatype=None): """Broadcast from root to other nodes. Args: input (array): input array. root (int): root node rank. + datatype (int): MPI datatype for broadcasting. """ + if config.mpi4py_enabled: + datatype = MPI.COMPLEX + if self.size > 1: cpu_input = to_device(input, cpu_device) - self.mpi_comm.Bcast(cpu_input, root=root) + self.mpi_comm.Bcast((cpu_input, datatype), root=root) copyto(input, cpu_input) def gatherv(self, input, root=0):