From 85739eee0a2db4c14074b670e63e36de0cc5a8f2 Mon Sep 17 00:00:00 2001 From: Siddharth Srinivasan Date: Thu, 16 Dec 2021 16:08:02 -0800 Subject: [PATCH] Add datatype for MPI Bcast. --- sigpy/backend.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/sigpy/backend.py b/sigpy/backend.py index 7dcd6cad..7d5a2324 100644 --- a/sigpy/backend.py +++ b/sigpy/backend.py @@ -265,17 +265,21 @@ def reduce(self, input, root=0): else: self.mpi_comm.Reduce(cpu_input, None, root=root) - def bcast(self, input, root=0): + def bcast(self, input, root=0, datatype=None): """Broadcast from root to other nodes. Args: input (array): input array. root (int): root node rank. + datatype (int): MPI datatype for broadcasting. """ + if config.mpi4py_enabled: + datatype = MPI.COMPLEX + if self.size > 1: cpu_input = to_device(input, cpu_device) - self.mpi_comm.Bcast(cpu_input, root=root) + self.mpi_comm.Bcast((cpu_input, datatype), root=root) copyto(input, cpu_input) def gatherv(self, input, root=0):