In [2]:
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# @noautodeps
# pyre-ignore-all-errors
import logging
import os
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.optim as optim

from monarch.actor import Actor, current_rank, endpoint
from monarch.job import SlurmJob
from monarch.utils import setup_env_for_distributed
from torch.nn.parallel import DistributedDataParallel as DDP


# Configure the root logger: emit INFO and above, one record per line in the
# form "<logger name> <timestamp> <level> <message>". force=True replaces any
# handlers already attached to the root logger, so this configuration applies
# even if logging was set up earlier in the process.
logging.basicConfig(
    force=True,
    level=logging.INFO,
    datefmt="%Y-%m-%d %H:%M:%S",
    format="%(name)s %(asctime)s %(levelname)s %(message)s",
)


# Module-level logger, named after this module.
logger: logging.Logger = logging.getLogger(__name__)


class ToyModel(nn.Module):
    """Tiny two-layer network used as the workload for the DDP demo.

    Maps 10 input features to 5 outputs through a 10-unit hidden layer
    followed by a ReLU.
    """

    def __init__(self) -> None:
        super().__init__()
        # Layers are created in the same order they are applied in forward().
        self.net1 = nn.Linear(in_features=10, out_features=10)
        self.relu = nn.ReLU()
        self.net2 = nn.Linear(in_features=10, out_features=5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply net1 -> ReLU -> net2 to ``x`` and return the result."""
        hidden = self.relu(self.net1(x))
        return self.net2(hidden)


class DDPActor(Actor):
    """Actor that wraps the basic DistributedDataParallel example from PyTorch.

    Each step of the tutorial (process-group setup, one training step,
    process-group teardown) is exposed as an Actor endpoint with only light
    modifications.

    Adapted from: https://docs.pytorch.org/tutorials/intermediate/ddp_tutorial.html#basic-use-case
    """

    def __init__(self) -> None:
        # Rank reported by current_rank(); reused as the torch.distributed
        # rank in setup() and as the prefix of every message this actor prints.
        self.rank = current_rank().rank

    def _rprint(self, msg: str) -> None:
        """Print ``msg`` prefixed with this actor's rank."""
        print(f"{self.rank=} {msg}")

    @endpoint
    async def setup(self) -> None:
        """Initialize the PyTorch distributed process group.

        Reads the world size from the ``WORLD_SIZE`` environment variable
        (presumably populated by ``setup_env_for_distributed`` -- confirm
        against the caller).

        Raises:
            KeyError: if ``WORLD_SIZE`` is not set in the environment.
        """
        self._rprint("Initializing torch distributed")

        world_size = int(os.environ["WORLD_SIZE"])
        # Initialize the process group with the gloo backend.
        dist.init_process_group("gloo", rank=self.rank, world_size=world_size)
        self._rprint("Finished initializing torch distributed")

    @endpoint
    async def cleanup(self) -> None:
        """Tear down the PyTorch distributed process group.

        Safe to call when no process group exists (for example when
        ``setup`` never completed, or ``cleanup`` is called twice): in that
        case only the message is printed.
        """
        self._rprint("Cleaning up torch distributed")
        # destroy_process_group() errors out when no group is initialized,
        # so only call it when there is something to destroy.
        if dist.is_initialized():
            dist.destroy_process_group()

    @endpoint
    async def demo_basic(self) -> None:
        """Run one forward/backward/optimizer step of a DDP-wrapped ToyModel.

        The device index is taken from the ``LOCAL_RANK`` environment
        variable.

        Raises:
            KeyError: if ``LOCAL_RANK`` is not set in the environment.
        """
        self._rprint("Running basic DDP example")

        # Create the model and move it to the device with index LOCAL_RANK
        # (the per-node device index, not the global rank).
        local_rank = int(os.environ["LOCAL_RANK"])
        self._rprint(f"{local_rank=}")
        model = ToyModel().to(local_rank)
        ddp_model = DDP(model, device_ids=[local_rank])

        loss_fn = nn.MSELoss()
        optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

        # Inputs and labels are generated on CPU and then placed on the same
        # device as the model.
        inputs = torch.randn(20, 10).to(local_rank)
        labels = torch.randn(20, 5).to(local_rank)

        optimizer.zero_grad()
        outputs = ddp_model(inputs)
        loss_fn(outputs, labels).backward()
        optimizer.step()

        self._rprint("Finished running basic DDP example")


async def main() -> None:
    """Run the DDP example end to end on a SLURM allocation.

    Creates a SlurmJob with one mesh of ``num_nodes`` nodes, spawns one
    process per GPU on it, spawns a DDPActor on those processes, and calls
    its ``setup``, ``demo_basic`` and ``cleanup`` endpoints in order. The
    SLURM job is killed on exit whether or not the example succeeded.
    """
    num_nodes = 2
    gpus_per_node = 4
    mesh_name = "mesh0"

    # Create SLURM job
    slurm_job = SlurmJob(
        meshes={mesh_name: num_nodes},
        job_name="monarch_example",
        gpus_per_node=gpus_per_node,
        time_limit="06:00:00",
    )

    try:
        # Get job state and create process mesh. The host mesh is looked up
        # through mesh_name (rather than a hard-coded attribute) so that the
        # name registered with SlurmJob above and the name used here cannot
        # drift apart.
        job_state = slurm_job.state()
        host_mesh = getattr(job_state, mesh_name)
        proc_mesh = host_mesh.spawn_procs({"gpus": gpus_per_node})

        # Spawn DDP actor
        ddp_actor = proc_mesh.spawn("ddp_actor", DDPActor)

        # Setup distributed environment
        await setup_env_for_distributed(proc_mesh)

        # Run DDP example
        await ddp_actor.setup.call()
        await ddp_actor.demo_basic.call()
        await ddp_actor.cleanup.call()

        print("DDP example completed successfully!")

    finally:
        # Cancel the SLURM job, releasing all reserved nodes back to the cluster
        slurm_job.kill()
        logger.info("Job terminated successfully")


if __name__ == "__main__":
    # NOTE(review): top-level `await` is only valid where the host already
    # provides a running event loop (e.g. an IPython/Jupyter cell, which this
    # appears to be). Run as a plain script, this line is a SyntaxError and
    # would need asyncio.run(main()) instead.
    await main()

Found cached job at path: .monarch/job_state.pkl
Error checking job 7758 status: slurm_load_jobs error: Invalid job id specified

SLURM job 7758 not found in queue
Cached job cannot run this spec, removing cache
Cancelled SLURM job 7758
Applying current job
Submitting SLURM job with 2 nodes
SLURM job 9140 submitted. Logs will be written to: /home/mreso/monarch/examples/slurm_9140_monarch_example_2125347.out
Saving job to cache at .monarch/job_state.pkl
Job has started, connecting to current state
SLURM job 9140 is running on 2 nodes: ['slurm-compute-node-011', 'slurm-compute-node-012']


[36m>>> Aggregated Logs (2025-11-18 06:39:02) >>>[0m
[33m[1 similar log lines][0m [6] self.rank=6 Initializing torch distributed
[36m<<< Aggregated Logs (2025-11-18 06:39:05) <<<[0m

[36m>>> Aggregated Logs (2025-11-18 06:39:05) >>>[0m
[33m[7 similar log lines][0m [7] self.rank=7 Initializing torch distributed
[33m[8 similar log lines][0m [7] [Gloo] Rank 7 is connected to 7 peer ranks. Expected number of connected peer ranks is : 7
[33m[8 similar log lines][0m [7] self.rank=7 Finished initializing torch distributed
[33m[8 similar log lines][0m [7] self.rank=7 Running basic DDP example
[33m[8 similar log lines][0m [7] self.rank=7 local_rank=3
[36m<<< Aggregated Logs (2025-11-18 06:39:08) <<<[0m



Cancelled SLURM job 9140
__main__ 2025-11-18 06:39:11 INFO Job terminated successfully


DDP example completed successfully!


[36m>>> Aggregated Logs (2025-11-18 06:39:08) >>>[0m
[33m[8 similar log lines][0m [5] self.rank=5 Finished running basic DDP example
[33m[8 similar log lines][0m [1] self.rank=1 Cleaning up torch distributed
[36m<<< Aggregated Logs (2025-11-18 06:39:11) <<<[0m



[-]E1118 06:39:42.498521 2125347 hyperactor/src/channel/net.rs:876] error_msg:session tcp:172.27.60.61:22222.1936278980149669291: failed to deliver message within timeout
[-]E1118 06:39:42.498624 2125347 hyperactor/src/mailbox.rs:344] name:undelivered_message_attempt, sender:tcp:172.27.55.7:43767,mesh_root_client_proc,client[0], dest:tcp:172.27.60.61:22222,service,agent[0][15904432331645146645<hyperactor_mesh::resource::GetState<hyperactor_mesh::v1::host_mesh::mesh_agent::ProcState>>], error:broken link: failed to enqueue in MailboxClient when processing buffer: channel closed, return_handle:tcp:172.27.55.7:43767,mesh_root_client_proc,client[0]<hyperactor::mailbox::undeliverable::Undeliverable<hyperactor::mailbox::MessageEnvelope>>
[-]E1118 06:39:42.498634 2125347 hyperactor/src/mailbox.rs:344] name:undelivered_message_attempt, sender:tcp:172.27.55.7:43767,mesh_root_client_proc,client[0], dest:tcp:172.27.60.61:22222,service,agent[0][15904432331645146645<hyperactor_mesh::resource::GetSt