In [1]:
from mpi4py import MPI
import torch


In [2]:
# Set up the MPI context shared by the cells below: the world communicator
# spans every process this program was launched with.
comm = MPI.COMM_WORLD
# rank: this process's ID within the communicator; size: total process count
rank, size = comm.Get_rank(), comm.Get_size()

In [4]:
# Report which process this is, the PyTorch build in use, and the
# communicator size, one message per line.
for message in (
    f"Hello from rank {rank} of {size}",
    f"PyTorch version: {torch.__version__}",
    f'size of the mpi comm = {size}',
):
    print(message)

Hello from rank 0 of 1
PyTorch version: 2.3.1
size of the mpi comm = 1


In [5]:



# Scatter a 1-D tensor from rank 0 across all MPI processes, scale each local
# chunk by a scalar, and gather the scaled chunks back on rank 0.

# Define the size of the tensor and scalar for computation
tensor_size = 16
scalar = 2.0

# Every rank receives a chunk of `tensor_size // size` elements, so the tensor
# must split evenly across the communicator. Check this up front: the condition
# is identical on every rank, so all processes stop together with a clear
# message instead of entering the collective calls with mismatched buffers
# (this also rejects size > tensor_size, where chunk_size would be 0).
if tensor_size % size != 0:
    raise ValueError(
        f"tensor_size ({tensor_size}) must be divisible by the number of "
        f"MPI processes ({size})"
    )

# Only rank 0 owns the full tensor (values 1..tensor_size); the other ranks
# pass None as the send buffer to Scatter.
full_tensor = (
    torch.arange(1, tensor_size + 1, dtype=torch.float32) if rank == 0 else None
)

# Each process gets a chunk of the tensor; allocate its receive buffer
chunk_size = tensor_size // size
local_chunk = torch.zeros(chunk_size, dtype=torch.float32)

# Scatter the tensor from rank 0 to all processes
comm.Scatter(full_tensor, local_chunk, root=0)

print(f"Rank {rank} received chunk: {local_chunk}")

# Each process performs a computation on its local chunk (multiply by a scalar).
# The multiplication is in place, so local_chunk now holds the scaled values.
local_chunk *= scalar

# Gather the results back at rank 0; only the root allocates a receive buffer,
# the other ranks pass None.
gathered_tensor = (
    torch.zeros(tensor_size, dtype=torch.float32) if rank == 0 else None
)

comm.Gather(local_chunk, gathered_tensor, root=0)

if rank == 0:
    print(f"Gathered tensor after multiplication: {gathered_tensor}")


Rank 0 received chunk: tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12., 13., 14.,
        15., 16.])
Gathered tensor after multiplication: tensor([ 2.,  4.,  6.,  8., 10., 12., 14., 16., 18., 20., 22., 24., 26., 28.,
        30., 32.])
