In [None]:
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# @noautodeps
# pyre-ignore-all-errors
import json
import logging
import sys

import cloudpickle
from example_actors.compute_world_size_actor import ComputeWorldSizeActor
from slurm.utils import create_slurm_job, cleanup_job


# Log line layout: "<logger name> <timestamp> <LEVEL> <message>".
_LOG_FORMAT = "%(name)s %(asctime)s %(levelname)s %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"

# force=True drops any handlers already attached to the root logger (e.g. ones
# installed by the notebook kernel) so this configuration actually takes effect.
logging.basicConfig(
    format=_LOG_FORMAT,
    datefmt=_LOG_DATEFMT,
    level=logging.INFO,
    force=True,
)


# Module-level logger used by main().
logger: logging.Logger = logging.getLogger(__name__)


async def main(
    num_nodes: int = 2,
    gpus_per_node: int = 8,
    mesh_name: str = "mesh0",
    master_port: int = 29500,
) -> None:
    """Run ComputeWorldSizeActor on a SLURM-allocated process mesh and log results.

    Creates a SLURM job, spawns processes on the mesh named *mesh_name*,
    spawns a ``ComputeWorldSizeActor`` on them, calls ``compute_world_size``
    on every rank and logs the per-rank results as JSON. The SLURM job is
    always cleaned up, even if any step fails.

    Args:
        num_nodes: Number of nodes requested from ``create_slurm_job``.
        gpus_per_node: GPUs per node; also passed as the ``"gpus"`` count to
            ``spawn_procs``.
        mesh_name: Name passed to ``create_slurm_job``; the job state is
            expected to expose the mesh under this attribute name (the
            original code read it as ``job_state.mesh0``).
        master_port: Port passed to the actor as ``master_port``.
    """
    # Create the SLURM job. If this raises there is nothing to clean up,
    # which is why it sits outside the try/finally.
    slurm_job = create_slurm_job(mesh_name, num_nodes, gpus_per_node)

    try:
        # Get job state and create the process mesh.
        job_state = slurm_job.state()
        # Look the mesh up by name so it cannot drift out of sync with
        # ``mesh_name`` (previously hard-coded as ``job_state.mesh0``).
        mesh = getattr(job_state, mesh_name)
        proc_mesh = mesh.spawn_procs({"gpus": gpus_per_node})

        # Spawn the actor on every process in the mesh.
        actor = proc_mesh.spawn("compute_world_size_actor", ComputeWorldSizeActor)

        logger.info("computing world size...")
        # This is redundant but is here for example's sake.
        values = await actor.compute_world_size.call(
            master_addr=mesh.hosts[0],
            master_port=master_port,
        )

        # flatten("rank") yields (point, value) pairs; key the results by rank.
        values_by_rank = {f"rank_{p.rank}": v for p, v in values.flatten("rank")}

        logger.info(
            f"""computed world_sizes:
    {'-'*40}
    {json.dumps(values_by_rank, indent=2)}
    {'-'*40}"""
        )
    finally:
        await cleanup_job(slurm_job)


if __name__ == "__main__":
    cloudpickle.register_pickle_by_value(sys.modules[ComputeWorldSizeActor.__module__])

    await main()