1 change: 1 addition & 0 deletions examples/monarch/README.md
@@ -35,6 +35,7 @@ You can also override the resource configuration manually:
- TrainingActor: Individual trainer processes
- ReplicaActor: Manages groups of trainers
- OrchestrationManager: Top-level orchestration and failure recovery
- FailureController: Optional, periodically injects random failures into trainer processes

##### FAILURE RECOVERY
- Automatic retry with configurable delays (PER_ATTEMPT_DELAY)
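The FailureController named above is implemented in utils/failure.py, which is not shown in this view of the diff. The snippet below is only a hedged sketch of the kind of loop `FailureController.execute_failures` could run, built from the pieces visible in train_distributed.py (the `ReplicaActor.inject_failure` endpoint and the `Failure` type); the interval, the `Failure` member, and the `replica.actor` attribute name are assumptions.

```python
# Hedged sketch only; NOT the actual utils/failure.py, which this diff does not show.
import asyncio
import enum
import random


class Failure(enum.Enum):
    # Placeholder value; the real Failure type is defined in utils/failure.py.
    KILL_PROCESS = "kill_process"


async def execute_failures_sketch(replicas, scheduler, interval_s: float = 120.0):
    # `scheduler` mirrors FailureController.execute_failures(replicas, scheduler)
    # as called from OrchestrationManager below, but this sketch does not use it.
    while True:
        await asyncio.sleep(interval_s)
        if not replicas:
            continue
        # Pick one replica at random and ask it to fail one of its trainers.
        replica = random.choice(list(replicas.values()))
        try:
            # inject_failure is the ReplicaActor endpoint added in this PR;
            # the attribute name `actor` on Replica is assumed here.
            await replica.actor.inject_failure.call(Failure.KILL_PROCESS)
        except Exception as e:
            print(f"Failure injection failed: {e}")
```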
Empty file added examples/monarch/__init__.py
158 changes: 100 additions & 58 deletions examples/monarch/train_distributed.py
@@ -24,54 +24,61 @@
from torchtitan.config import ConfigManager, JobConfig
from torchtitan.tools.logging import init_logger, logger
from torchtitan.train import Trainer
from utils.failure import Failure, FailureActor, FailureController


# ==== Allocation boilerplate - much of this will be upstreamed into Monarch ====
class MonarchSlurm:
# Cluster Configuration - update these values for your specific cluster
machine: str = "aws_g5.12xlarge"
machine_memory: int = 186777
machine: str = "gpu.xlarge"
machine_memory: int = 2062607
job_name_prefix: str = "monarch-torchft"

job_handles: Dict[str, str] = {}
def __init__(self):
self.job_handles: Dict[str, str] = {}
atexit.register(self.kill_jobs)

@classmethod
def get_config(cls, mesh_name: str, nodes_per_mesh: int) -> Config:
def get_config(self, mesh_name: str, nodes_per_mesh: int) -> Config:
mesh = [f"{mesh_name}:{nodes_per_mesh}:{MonarchSlurm.machine}"]
appdef = hyperactor.host_mesh(meshes=mesh)
# to enable relative import of utils on actors
current_dir = os.path.dirname(os.path.abspath(__file__))
env = {"PYTHONPATH": current_dir}

appdef = hyperactor.host_mesh(meshes=mesh, env=env)

for role in appdef.roles:
role.resource.memMB = MonarchSlurm.machine_memory

return Config(scheduler="slurm", appdef=appdef)

@classmethod
async def get_or_create_job(cls, mesh_name: str, nodes_per_mesh: int = 1) -> None:
config = cls.get_config(mesh_name, nodes_per_mesh)
async def get_or_create_job(self, mesh_name: str, nodes_per_mesh: int = 1) -> None:
config = self.get_config(mesh_name, nodes_per_mesh)
job_name = f"{MonarchSlurm.job_name_prefix}-{mesh_name}"
server_spec = await commands.get_or_create(job_name, config, force_restart=True)
cls.job_handles[mesh_name] = server_spec.name
self.job_handles[mesh_name] = server_spec.name

@classmethod
def kill_jobs(cls):
for mesh_name, job_handle in cls.job_handles.items():
try:
logger.info(f"Destroying job for mesh {mesh_name}")
commands.kill(f"slurm:///{job_handle}")
except Exception as e:
logger.warning(f"Failed to destroy job for {mesh_name}: {e}")
def kill_jobs(self):
for mesh_name in self.job_handles.keys():
self.kill_job(mesh_name)

def kill_job(self, mesh_name: str):
try:
job_handle = self.job_handles[mesh_name]
logger.info(f"Destroying job for mesh {mesh_name}")
commands.kill(f"slurm:///{job_handle}")
except Exception as e:
logger.warning(f"Failed to destroy job for {mesh_name}: {e}")

@classmethod
def proc_mesh(
cls,
self,
mesh_name: str,
num_hosts: int = 1,
num_gpus: int = 8,
) -> ProcMesh:
allocator = RemoteAllocator(
world_id=MonarchSlurm.job_name_prefix,
initializer=TorchXRemoteAllocInitializer(
f"slurm:///{cls.job_handles[mesh_name]}"
f"slurm:///{self.job_handles[mesh_name]}"
),
)
alloc = allocator.allocate(
@@ -94,7 +101,7 @@ def start_lighthouse(self) -> str:
from torchft.coordination import LighthouseServer

self.lighthouse = LighthouseServer(
bind="[::]:0", min_replicas=1, join_timeout_ms=10000
bind="[::]:0", min_replicas=1, join_timeout_ms=60000
)
return self.lighthouse.address()

@@ -140,6 +147,7 @@ class JobSpec:
replica_count: int
hosts_per_replica: int
gpus_per_node: int
with_failures: bool
lighthouse_address: str = ""


@@ -154,16 +162,15 @@ class Replica:
# This does not currently benefit from being an actor, but will once
# Monarch supervision APIs are fleshed out.
class ReplicaActor(Actor):
def __init__(
self,
spec: JobSpec,
replica_id: int,
) -> None:
def __init__(self, spec: JobSpec, replica_id: int, scheduler: MonarchSlurm) -> None:
self.spec = deepcopy(spec)
self.replica_id = replica_id

self.uid = f"[replica_{replica_id}]"
self.spec.job_config.fault_tolerance.replica_id = self.replica_id
self.scheduler = scheduler

self.failure_actors: FailureActor | None = None

@endpoint
async def start_replica(self) -> None:
@@ -172,14 +179,12 @@ async def start_replica(self) -> None:

trainers_proc_mesh: ProcMesh | None = None
try:
trainers_proc_mesh = MonarchSlurm.proc_mesh(
trainers_proc_mesh = self.scheduler.proc_mesh(
f"replica_{self.replica_id}",
self.spec.hosts_per_replica,
self.spec.gpus_per_node,
)
await trainers_proc_mesh.logging_option(
stream_to_client=True, aggregate_window_sec=None
)
await trainers_proc_mesh.logging_option(stream_to_client=True)
await setup_env_for_distributed(trainers_proc_mesh)

training_actors = trainers_proc_mesh.spawn(
@@ -189,6 +194,10 @@ async def start_replica(self) -> None:
self.replica_id,
)

self.failure_actors = trainers_proc_mesh.spawn(
"failure_actors", FailureActor
)

logger.info(f"{self.uid} Starting trainers")
await training_actors.start_training.call(self.spec.lighthouse_address)
await trainers_proc_mesh.stop()
@@ -197,13 +206,29 @@
await trainers_proc_mesh.stop()
raise e

@endpoint
async def inject_failure(self, failure_type: Failure):
if self.failure_actors:
try:
logger.info(
f"{self.uid} Injecting failure ({failure_type}) into random trainer"
)

await self.failure_actors.fail.choose(failure_type)
Member: .choose picks an arbitrary trainer?

Member Author: yup, choose will send it to one random trainer in the replica mesh
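For context on the two invocation styles used in this file, a minimal sketch (assumption: a Monarch-style actor mesh where `.call` fans out to every actor in the mesh and `.choose`, as the author confirms above, routes to a single arbitrary actor; the endpoint names are the ones visible in this diff):

```python
# Sketch only, based on the endpoints shown in this diff; not part of the PR.
async def illustrate_endpoint_routing(
    training_actors, failure_actors, lighthouse_address, failure_type
):
    # .call fans out: every trainer actor in the replica's mesh runs start_training.
    await training_actors.start_training.call(lighthouse_address)

    # .choose targets one: a single, arbitrarily chosen trainer receives the failure.
    await failure_actors.fail.choose(failure_type)
```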

except Exception as e:
error_msg = f"{self.uid} Injected failure: {e}"
logger.error(error_msg)
else:
error_msg = f"{self.uid} No failure actors available"
logger.error(error_msg)


# delay before re-creating proc mesh on existing job. change as needed.
PROC_ATTEMPT_DELAY = 10
PROC_ATTEMPT_DELAY = 0
# proc attempts before getting a new scheduler allocation. change as needed.
PROC_ATTEMPTS = 2
PROC_ATTEMPTS = 4
# attempts before failing training on replica. change as needed.
MAX_ATTEMPT = PROC_ATTEMPTS * 2
MAX_ATTEMPT = PROC_ATTEMPTS * 4


class OrchestrationManager:
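The three retry constants above drive the escalation in `_spin_up_replica` later in this file. The exact guard is collapsed in this diff, so the helper below is only an illustrative reading of how the constants are presumably meant to compose, following their inline comments: retries restart immediately with the new PROC_ATTEMPT_DELAY of 0, a fresh Slurm allocation is requested every PROC_ATTEMPTS failed attempts, and the replica gives up at MAX_ATTEMPT.

```python
# Illustrative only; not part of the PR. Assumes the guard in _spin_up_replica
# follows the inline comments on the constants above.
PROC_ATTEMPT_DELAY = 0           # seconds before re-creating a proc mesh on the same job
PROC_ATTEMPTS = 4                # proc-mesh attempts before requesting a new allocation
MAX_ATTEMPT = PROC_ATTEMPTS * 4  # attempts before the replica stops retrying


def retry_plan(attempt_number: int) -> dict:
    """Hypothetical helper showing how the constants interact."""
    if attempt_number >= MAX_ATTEMPT:
        return {"action": "give_up"}
    return {
        "action": "retry",
        # Every PROC_ATTEMPTS failures, kill the Slurm job and get a fresh allocation.
        "new_allocation": attempt_number > 0 and attempt_number % PROC_ATTEMPTS == 0,
        # First attempt starts immediately; later attempts wait PROC_ATTEMPT_DELAY seconds.
        "delay_s": 0 if not attempt_number else PROC_ATTEMPT_DELAY,
    }
```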
@@ -213,32 +238,41 @@ def __init__(self, spec: JobSpec) -> None:
self.lighthouse_actor: LighthouseActor | None = None
self.lighthouse_mesh: ProcMesh | None = None

self.scheduler = MonarchSlurm()

async def start_training(self) -> None:
logger.info(
f"[Controller] Creating training system with {self.spec.replica_count} replicas"
)

for replica_id in range(self.spec.replica_count):
await MonarchSlurm.get_or_create_job(
await self.scheduler.get_or_create_job(
f"replica_{replica_id}", self.spec.hosts_per_replica
)

mesh_futures = {}
for i in range(self.spec.replica_count):
mesh_futures[i] = asyncio.create_task(self._run_replica(i, 0))

failure_future = None
if self.spec.with_failures:
failure_future = asyncio.create_task(
FailureController.execute_failures(self.replicas, self.scheduler)
)

await asyncio.gather(*mesh_futures.values(), return_exceptions=True)

if failure_future:
failure_future.cancel()

async def start_lighthouse(self) -> None:
if self.spec.remote_lighthouse:
await MonarchSlurm.get_or_create_job("lighthouse")
self.lighthouse_mesh = MonarchSlurm.proc_mesh("lighthouse", num_gpus=1)
await self.scheduler.get_or_create_job("lighthouse")
self.lighthouse_mesh = self.scheduler.proc_mesh("lighthouse", num_gpus=1)
else:
self.lighthouse_mesh = this_host().spawn_procs({"gpus": 1})

await self.lighthouse_mesh.logging_option(
stream_to_client=True, aggregate_window_sec=None
)
await self.lighthouse_mesh.logging_option(stream_to_client=True)
self.lighthouse_actor = self.lighthouse_mesh.spawn(
"lighthouse_actor", LighthouseActor
)
@@ -274,7 +308,8 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> None:
logger.info(
f"[Controller] Replica {replica_id} has failed {attempt_number} times. Getting new allocation."
)
await MonarchSlurm.get_or_create_job(
self.scheduler.kill_job(f"replica_{replica_id}")
await self.scheduler.get_or_create_job(
f"replica_{replica_id}", self.spec.hosts_per_replica
)
delay = 0 if not attempt_number else PROC_ATTEMPT_DELAY
@@ -287,10 +322,7 @@ async def _spin_up_replica(self, replica_id: int, attempt_number: int = 0) -> None:
await replica_proc_mesh.logging_option(aggregate_window_sec=None)

replica_actor = replica_proc_mesh.spawn(
"replica_actor",
ReplicaActor,
self.spec,
replica_id,
"replica_actor", ReplicaActor, self.spec, replica_id, self.scheduler
)

replica = Replica(replica_id, replica_proc_mesh, replica_actor, attempt_number)
@@ -301,8 +333,8 @@ async def _teardown(self, replica_id: int) -> None:
try:
replica = self.replicas[replica_id]
await replica.proc_mesh.stop()
del replica.proc_mesh
del self.replicas[replica_id]
del replica.proc_mesh
except Exception as e:
logger.error(f"[Controller] Failed to _teardown replica {replica_id}: {e}")

@@ -339,20 +371,25 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--model-config",
type=str,
default=os.path.join(script_dir, "debug_model.toml"),
help=f"Path to model configuration file (default: {os.path.join(script_dir, 'debug_model.toml')})",
default="debug_model.toml",
help=f"Relative path to model configuration file (default: {os.path.join(script_dir, 'debug_model.toml')})",
)
parser.add_argument(
"--dataset-path",
type=str,
default=os.path.join(script_dir, "c4_test"),
help=f"Path to training dataset (default: {os.path.join(script_dir, 'c4_test')})",
default="c4_test",
help=f"Relative path to training dataset (default: {os.path.join(script_dir, 'c4_test')})",
)
parser.add_argument(
"--tokenizer-path",
type=str,
default=os.path.join(script_dir, "tokenizer"),
help=f"Path to tokenizer (default: {os.path.join(script_dir, 'tokenizer')})",
default="debug_tokenizer",
help=f"Relative path to tokenizer (default: {os.path.join(script_dir, 'debug_tokenizer')})",
)
parser.add_argument(
"--with-failures",
action="store_true",
help="Enable the failure injector utility (default: False)",
)

return parser.parse_args()
@@ -362,32 +399,37 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
data_parallel_shard_degree = args.gpu_per_node * args.host_per_replica

output_path = "./outputs"
training_dataset = "c4_test"
training_dataset = args.dataset_path.split("/")[-1]

script_dir = os.path.dirname(os.path.abspath(__file__))
default_args = [
"--job.config_file",
args.model_config,
os.path.join(script_dir, args.model_config),
"--model.tokenizer_path",
args.tokenizer_path,
os.path.join(script_dir, args.tokenizer_path),
"--comm.trace_buf_size",
"0",
"--metrics.log_freq",
"1",
"--fault_tolerance.enable",
"--fault_tolerance.group_size",
str(args.replica_count),
"--fault_tolerance.process_group",
"nccl",
"--fault_tolerance.process_group_timeout_ms",
"60000",
"--parallelism.data_parallel_shard_degree",
str(data_parallel_shard_degree),
"--activation_checkpoint.mode",
"full",
"--comm.train_timeout_seconds",
"60",
"300",
"--training.steps",
str(args.training_steps),
"--training.dataset",
training_dataset,
"--training.dataset_path",
args.dataset_path,
os.path.join(script_dir, args.dataset_path),
"--job.dump_folder",
output_path,
"--metrics.enable_tensorboard",
Expand All @@ -402,6 +444,7 @@ def make_job_spec(args: argparse.Namespace) -> JobSpec:
replica_count=args.replica_count,
hosts_per_replica=args.host_per_replica,
gpus_per_node=args.gpu_per_node,
with_failures=args.with_failures,
)


@@ -414,7 +457,6 @@ async def main() -> None:
args = parse_args()
job_spec = make_job_spec(args)

atexit.register(MonarchSlurm.kill_jobs)
orchestrator = OrchestrationManager(job_spec)
try:
await orchestrator.start_lighthouse()