Skip to content

Commit

Permalink
Add datatype for MPI Bcast.
Browse files Browse the repository at this point in the history
  • Loading branch information
sidward committed Dec 17, 2021
1 parent 815e17d commit 85739ee
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions sigpy/backend.py
Expand Up @@ -265,17 +265,21 @@ def reduce(self, input, root=0):
else:
self.mpi_comm.Reduce(cpu_input, None, root=root)

def bcast(self, input, root=0):
def bcast(self, input, root=0, datatype=None):
"""Broadcast from root to other nodes.
Args:
input (array): input array.
root (int): root node rank.
datatype (int): MPI datatype for broadcasting.
"""
if config.mpi4py_enabled:
datatype = MPI.COMPLEX

if self.size > 1:
cpu_input = to_device(input, cpu_device)
self.mpi_comm.Bcast(cpu_input, root=root)
self.mpi_comm.Bcast((cpu_input, datatype), root=root)
copyto(input, cpu_input)

def gatherv(self, input, root=0):
Expand Down

0 comments on commit 85739ee

Please sign in to comment.