## Monarch + TorchTitan on SLURM
This example notebook demonstrates how you can easily run and iterate on a distributed training job with Monarch and TorchTitan.

#### Prerequisites
Please make sure your environment is set up for this notebook:
1. Install Monarch nightly: https://github.com/meta-pytorch/monarch/blob/main/scripts/install_nightly.py
2. Install Titan nightly: https://github.com/pytorch/torchtitan?tab=readme-ov-file#nightly-builds
3. Ensure you have a valid Titan model config in the script directory (e.g., https://github.com/pytorch/torchtitan/blob/main/torchtitan/models/llama3/train_configs/debug_model.toml)

### 1. Create your SLURM job
Configure parameters for your cluster:
- num_nodes: Number of nodes to allocate (default: 2)
- gpus_per_node: Number of GPUs per node (default: 8)
- mesh_name: Name for the mesh (default: "mesh0")
- time_limit: Maximum job duration (default: "06:00:00")

In [None]:
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

from slurm.utils import create_slurm_job, cleanup_job

num_nodes = 2  # number of nodes to allocate; adjust for your system
gpus_per_node = 8  # number of GPUs per node; adjust for your hardware
mesh_name = "mesh0"  # name for the mesh

# Create a SLURM job with N nodes; the returned handle is reused by the
# training and cleanup cells below.
slurm_job = create_slurm_job(
    mesh_name,
    num_nodes,
    gpus_per_node,
    # time_limit="06:00:00",  # optional
)

### 2. Define your Titan and cluster parameters

In [None]:
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

from torchtitan.train import Trainer
from torchtitan.config import ConfigManager, JobConfig
from monarch.actor import Actor, current_rank, endpoint
from torchtitan.tools.logging import init_logger, logger
import torch
from dataclasses import dataclass
import os
from monarch.utils import setup_env_for_distributed


@dataclass
class RunParams:
    """
    Parameters for your cluster and training job, adjust as needed.

    Every attribute is type-annotated so that it is a genuine dataclass field
    (un-annotated assignments are plain class attributes that ``@dataclass``
    ignores for ``__init__``/``__repr__``/``__eq__``). The defaults remain
    readable directly on the class, e.g. ``RunParams.num_nodes``.
    """
    # Number of training steps (forwarded as --training.steps)
    training_steps: int = 50
    # Model config file name, resolved relative to the notebook directory
    model_config: str = "debug_model.toml"
    # Dataset name (forwarded as --training.dataset)
    dataset: str = "c4"
    # Cluster shape; defaults come from the values set in the SLURM job cell
    num_nodes: int = num_nodes
    gpus_per_node: int = gpus_per_node


class TrainerActor(Actor):
    """
    A simple wrapper class which executes a TorchTitan trainer in a Monarch actor.
    """

    def __init__(self, job_config: JobConfig) -> None:
        """
        Store the job config and derive a per-rank log prefix.

        Args:
            job_config: TorchTitan job configuration handed to the Trainer.
        """
        self.job_config = job_config
        rank = current_rank().rank
        self.uid = f"[trainer_{rank}]"

    @endpoint
    async def start_training(self) -> None:
        """
        Build a Trainer from the stored config, run training, and clean up.

        Cleanup always runs, whether training succeeds or fails: the trainer
        is closed if it was constructed, and the process group is destroyed
        if one was initialized. Any exception from construction or training
        propagates to the caller after cleanup.
        """
        init_logger()
        trainer: Trainer | None = None

        try:
            trainer = Trainer(self.job_config)
            logger.info(f"{self.uid} initialized successfully and starting training")
            trainer.train()
        finally:
            try:
                # trainer stays None when the Trainer constructor itself failed
                if trainer is not None:
                    trainer.close()
            finally:
                # Destroying a process group that was never initialized raises
                # and would mask the original error, so check first.
                if torch.distributed.is_initialized():
                    torch.distributed.destroy_process_group()
                logger.info(f"{self.uid} trainer cleaned up")

def make_job_config() -> JobConfig:
    """
    Assemble the TorchTitan command-line arguments from RunParams and parse
    them into a JobConfig.

    File paths (model config, tokenizer assets) are resolved against the first
    entry of the notebook's `_dh` global.

    Returns:
        The JobConfig produced by ConfigManager.parse_args.
    """
    # Shard data-parallel across every GPU in the allocation
    shard_degree = RunParams.num_nodes * RunParams.gpus_per_node
    dump_folder = "./outputs"
    base_dir = globals()['_dh'][0]

    # (flag, value) pairs, kept in the order they are passed to the parser
    flag_values = (
        ("--job.config_file", os.path.join(base_dir, RunParams.model_config)),
        ("--model.hf_assets_path", os.path.join(base_dir, "tokenizer")),
        ("--comm.trace_buf_size", "0"),
        ("--metrics.log_freq", "1"),
        ("--parallelism.data_parallel_shard_degree", str(shard_degree)),
        ("--activation_checkpoint.mode", "full"),
        ("--comm.train_timeout_seconds", "60"),
        ("--training.steps", str(RunParams.training_steps)),
        ("--training.dataset", RunParams.dataset),
        ("--job.dump_folder", dump_folder),
    )
    cli_args = [token for pair in flag_values for token in pair]
    # Value-less switch goes last
    cli_args.append("--metrics.enable_tensorboard")

    return ConfigManager().parse_args(cli_args)

### 3. Execute your training job
You can make adjustments and run this on the existing SLURM allocations as many times as you would like!

In [3]:
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

async def main():
    """
    Run one training job on the existing SLURM allocation.

    Builds the job config, spawns one process per GPU on the mesh named by
    `mesh_name`, sets up the distributed environment, spawns a TrainerActor
    on each process, and awaits training. Failures are logged with their
    traceback and not re-raised.
    """
    job_config = make_job_config()

    try:
        # 1. Get job state and create process mesh
        job_state = slurm_job.state()
        # Look the mesh up by its configured name so that changing `mesh_name`
        # in the SLURM job cell does not break this cell.
        host_mesh = getattr(job_state, mesh_name)
        proc_mesh = host_mesh.spawn_procs({"gpus": RunParams.gpus_per_node})

        # 2. Configure remote logging behavior
        await proc_mesh.logging_option(
            stream_to_client=True,
            # aggregate_window_sec=None  # Uncomment to disable log batching
        )

        # 3. Setup environment for torch.distributed
        await setup_env_for_distributed(proc_mesh)

        # 4. Spawn TrainerActor on each GPU
        trainer = proc_mesh.spawn("trainer_actor", TrainerActor, job_config)

        # 5. Execute the training job
        await trainer.start_training.call()

        logger.info("Training completed successfully!")

    except Exception as e:
        # logger.exception logs at ERROR level and appends the traceback,
        # which logger.error would drop.
        logger.exception(f"Training workflow failed: {e}")


# Run the workflow. The coroutine is awaited directly at the top level of the
# cell (no asyncio.run wrapper).
if __name__ == "__main__":
    await main()

Found cached job at path: .monarch/job_state.pkl
SLURM job 7748 not found in queue
Cached job cannot run this spec, removing cache
Cancelled SLURM job 7748
Applying current job
Submitting SLURM job with 2 nodes
SLURM job 7749 submitted. Logs will be written to: /home/mreso/monarch/examples/slurm_7749_monarch_example_1780323.out
Saving job to cache at .monarch/job_state.pkl
Job has started, connecting to current state
SLURM job 7749 is running on 2 nodes: ['slurm-compute-node-074', 'slurm-compute-node-077']


[36m>>> Aggregated Logs (2025-11-15 00:42:28) >>>[0m
[33m[1 similar log lines][0m [6] [titan] 2025-11-15 00:42:38,346 - root - INFO - Starting job: Llama 3 debug training
[36m<<< Aggregated Logs (2025-11-15 00:42:38) <<<[0m



[36m>>> Aggregated Logs (2025-11-15 00:42:28) >>>[0m
[33m[8 similar log lines][0m [2]   sliced_mesh_layout = self._get_slice_mesh_layout(mesh_dim_names)
[36m<<< Aggregated Logs (2025-11-15 00:42:41) <<<[0m



[36m>>> Aggregated Logs (2025-11-15 00:42:38) >>>[0m
[33m[7 similar log lines][0m [4] [titan] 2025-11-15 00:42:38,347 - root - INFO - Starting job: Llama 3 debug training
[33m[8 similar log lines][0m [2] [titan] 2025-11-15 00:42:38,121 - root - INFO - Building 1-D device mesh with ['dp_shard'], [8]
[33m[8 similar log lines][0m [2] [titan] 2025-11-15 00:42:38,126 - root - INFO - [GC] Initial GC collection took 0.00 seconds
[36m<<< Aggregated Logs (2025-11-15 00:42:41) <<<[0m

[36m>>> Aggregated Logs (2025-11-15 00:42:41) >>>[0m
[33m[8 similar log lines][0m [5] [titan] 2025-11-15 00:42:43,362 - root - INFO - Loading tokenizer from tokenizer.json
[33m[8 similar log lines][0m [5] [titan] 2025-11-15 00:42:43,365 - root - INFO - Preparing c4 dataset from allenai/c4
[36m<<< Aggregated Logs (2025-11-15 00:42:44) <<<[0m

[36m>>> Aggregated Logs (2025-11-15 00:42:44) >>>[0m
[33m[8 similar log lines][0m [6] [titan] 2025-11-15 00:42:46,502 - root - INFO - Building llama3 debu

[36m>>> Aggregated Logs (2025-11-15 00:42:41) >>>[0m
[33m[8 similar log lines][0m [1]   sliced_mesh_layout = self._get_slice_mesh_layout(mesh_dim_names)
[36m<<< Aggregated Logs (2025-11-15 00:42:47) <<<[0m



[36m<<< Aggregated Logs (2025-11-15 00:42:47) <<<[0m



[36m>>> Aggregated Logs (2025-11-15 00:42:47) >>>[0m
[33m[1 similar log lines][0m [7]   sliced_mesh_layout = self._get_slice_mesh_layout(mesh_dim_names)
[36m<<< Aggregated Logs (2025-11-15 00:42:50) <<<[0m



[36m>>> Aggregated Logs (2025-11-15 00:42:47) >>>[0m
[33m[275 similar log lines][0m [5] [titan] 2025-11-15 00:42:50,809 - root - INFO - [31mstep:  1  [32mloss:  8.0601  [38;2;180;60;0mgrad_norm:  1.4225  [38;2;54;234;195mmemory:  0.68GiB(0.37%)  [34mtps: 3,867  [36mtflops: 0.28  [35mmfu: 0.01%[39m
[33m[8 similar log lines][0m [5] [titan] 2025-11-15 00:42:50,810 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:00
[36m<<< Aggregated Logs (2025-11-15 00:42:53) <<<[0m



[36m>>> Aggregated Logs (2025-11-15 00:42:50) >>>[0m
[33m[7 similar log lines][0m [5]   sliced_mesh_layout = self._get_slice_mesh_layout(mesh_dim_names)
[36m<<< Aggregated Logs (2025-11-15 00:42:53) <<<[0m



[36m>>> Aggregated Logs (2025-11-15 00:42:53) >>>[0m
[33m[5 similar log lines][0m [4] [titan] 2025-11-15 00:42:53,801 - root - INFO - [31mstep: 35  [32mloss:  2.8679  [38;2;180;60;0mgrad_norm:  0.2479  [38;2;54;234;195mmemory:  0.69GiB(0.37%)  [34mtps: 189,334  [36mtflops: 13.55  [35mmfu: 0.60%[39m
[36m<<< Aggregated Logs (2025-11-15 00:42:53) <<<[0m

[36m>>> Aggregated Logs (2025-11-15 00:42:53) >>>[0m
[33m[120 similar log lines][0m [1] [titan] 2025-11-15 00:42:52,697 - root - INFO - [31mstep: 36  [32mloss:  2.8734  [38;2;180;60;0mgrad_norm:  0.2411  [38;2;54;234;195mmemory:  0.69GiB(0.37%)  [34mtps: 192,839  [36mtflops: 13.80  [35mmfu: 0.61%[39m
[33m[8 similar log lines][0m [2] [titan] 2025-11-15 00:42:53,894 - root - INFO - [GC] Performing periodic GC collection took 0.04 seconds
[33m[7 similar log lines][0m [2] [titan] 2025-11-15 00:42:54,017 - root - INFO - Training completed
[33m[1 similar log lines][0m [0] [titan] 2025-11-15 00:42:54,018 - root - 

root 2025-11-15 00:42:57 INFO Training completed successfully!


[36m>>> Aggregated Logs (2025-11-15 00:42:56) >>>[0m
[33m[1 similar log lines][0m [0] [titan] 2025-11-15 00:42:56,018 - root - INFO - Training completed
[33m[1 similar log lines][0m [0] [titan] 2025-11-15 00:42:56,291 - root - INFO - [trainer_0] trainer cleaned up
[36m<<< Aggregated Logs (2025-11-15 00:42:59) <<<[0m



### 4. Clean up the SLURM job
Once you're done experimenting, free up the allocation.

In [4]:
# (c) Meta Platforms, Inc. and affiliates. Confidential and proprietary.

# Free the SLURM allocation created in step 1.
await cleanup_job(slurm_job)

Cancelled SLURM job 7749
slurm.utils 2025-11-15 00:42:59 INFO Job terminated successfully
