diff --git a/proto/torchft.proto b/proto/torchft.proto index 67a42c03..7e248e54 100644 --- a/proto/torchft.proto +++ b/proto/torchft.proto @@ -72,30 +72,32 @@ service LighthouseService { message ManagerQuorumRequest { int64 rank = 1; int64 step = 2; - string checkpoint_server_addr = 3; + string checkpoint_metadata = 3; bool shrink_only = 4; } message ManagerQuorumResponse { int64 quorum_id = 1; - string address = 2; - string store_address = 3; + string recover_src_manager_address = 2; + optional int64 recover_src_rank = 3; + repeated int64 recover_dst_ranks = 4; + string store_address = 5; // These are information for the replicas which are at the max step. - int64 max_step = 4; - optional int64 max_rank = 5; - int64 max_world_size = 6; + int64 max_step = 6; + optional int64 max_rank = 7; + int64 max_world_size = 8; // These are information for all replicas including behind replicas. - int64 replica_rank = 7; - int64 replica_world_size = 8; - bool heal = 9; + int64 replica_rank = 9; + int64 replica_world_size = 10; + bool heal = 11; } -message CheckpointAddressRequest { +message CheckpointMetadataRequest { int64 rank = 1; } -message CheckpointAddressResponse { - string checkpoint_server_address = 1; +message CheckpointMetadataResponse { + string checkpoint_metadata = 1; } message ShouldCommitRequest { @@ -114,7 +116,7 @@ message KillResponse {} service ManagerService { rpc Quorum (ManagerQuorumRequest) returns (ManagerQuorumResponse); - rpc CheckpointAddress(CheckpointAddressRequest) returns (CheckpointAddressResponse); + rpc CheckpointMetadata(CheckpointMetadataRequest) returns (CheckpointMetadataResponse); rpc ShouldCommit(ShouldCommitRequest) returns (ShouldCommitResponse); rpc Kill(KillRequest) returns (KillResponse); } diff --git a/src/lib.rs b/src/lib.rs index d9a124b7..529532da 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -27,7 +27,7 @@ pub mod torchftpb { } use crate::torchftpb::manager_service_client::ManagerServiceClient; -use crate::torchftpb::{CheckpointAddressRequest, ManagerQuorumRequest, ShouldCommitRequest}; +use crate::torchftpb::{CheckpointMetadataRequest, ManagerQuorumRequest, ShouldCommitRequest}; use pyo3::prelude::*; #[pyclass] @@ -113,15 +113,15 @@ impl ManagerClient { py: Python<'_>, rank: i64, step: i64, - checkpoint_server_addr: String, + checkpoint_metadata: String, shrink_only: bool, timeout: Duration, - ) -> Result<(i64, i64, i64, String, String, i64, Option, i64, bool), StatusError> { + ) -> Result { py.allow_threads(move || { let mut request = tonic::Request::new(ManagerQuorumRequest { rank: rank, step: step, - checkpoint_server_addr: checkpoint_server_addr, + checkpoint_metadata: checkpoint_metadata, shrink_only: shrink_only, }); @@ -131,28 +131,30 @@ impl ManagerClient { let response = self.runtime.block_on(self.client.clone().quorum(request))?; let resp = response.into_inner(); - Ok(( - resp.quorum_id, - resp.replica_rank, - resp.replica_world_size, - resp.address, - resp.store_address, - resp.max_step, - resp.max_rank, - resp.max_world_size, - resp.heal, - )) + Ok(QuorumResult { + quorum_id: resp.quorum_id, + replica_rank: resp.replica_rank, + replica_world_size: resp.replica_world_size, + recover_src_manager_address: resp.recover_src_manager_address, + recover_src_rank: resp.recover_src_rank, + recover_dst_ranks: resp.recover_dst_ranks, + store_address: resp.store_address, + max_step: resp.max_step, + max_rank: resp.max_rank, + max_world_size: resp.max_world_size, + heal: resp.heal, + }) }) } - fn checkpoint_address( + fn checkpoint_metadata( &self, py: 
Python<'_>, rank: i64, timeout: Duration, ) -> Result { py.allow_threads(move || { - let mut request = tonic::Request::new(CheckpointAddressRequest { rank: rank }); + let mut request = tonic::Request::new(CheckpointMetadataRequest { rank: rank }); // This timeout is processed on the server side so we also enable // keep alives to detect server health. @@ -160,9 +162,9 @@ impl ManagerClient { let response = self .runtime - .block_on(self.client.clone().checkpoint_address(request))?; + .block_on(self.client.clone().checkpoint_metadata(request))?; let resp = response.into_inner(); - Ok(resp.checkpoint_server_address) + Ok(resp.checkpoint_metadata) }) } @@ -194,6 +196,41 @@ impl ManagerClient { } } +#[pyclass(get_all, set_all)] +struct QuorumResult { + quorum_id: i64, + replica_rank: i64, + replica_world_size: i64, + recover_src_manager_address: String, + recover_src_rank: Option, + recover_dst_ranks: Vec, + store_address: String, + max_step: i64, + max_rank: Option, + max_world_size: i64, + heal: bool, +} + +#[pymethods] +impl QuorumResult { + #[new] + fn new() -> Self { + Self { + quorum_id: 0, + replica_rank: 0, + replica_world_size: 1, + recover_src_manager_address: "".to_string(), + recover_src_rank: None, + recover_dst_ranks: Vec::new(), + store_address: "".to_string(), + max_step: 0, + max_rank: None, + max_world_size: 1, + heal: false, + } + } +} + fn reset_python_signals(py: Python<'_>) -> PyResult<()> { // clear python signal handlers // signal.signal(signal.SIGINT, signal.SIG_DFL) @@ -319,6 +356,7 @@ fn torchft(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; m.add_function(wrap_pyfunction!(lighthouse_main, m)?)?; Ok(()) diff --git a/src/manager.rs b/src/manager.rs index 982500a4..931e9954 100644 --- a/src/manager.rs +++ b/src/manager.rs @@ -26,7 +26,7 @@ use crate::torchftpb::lighthouse_service_client::LighthouseServiceClient; use crate::torchftpb::manager_service_client::ManagerServiceClient; use crate::torchftpb::{ manager_service_server::{ManagerService, ManagerServiceServer}, - CheckpointAddressRequest, CheckpointAddressResponse, KillRequest, KillResponse, + CheckpointMetadataRequest, CheckpointMetadataResponse, KillRequest, KillResponse, LighthouseHeartbeatRequest, LighthouseQuorumRequest, ManagerQuorumRequest, ManagerQuorumResponse, Quorum, QuorumMember, ShouldCommitRequest, ShouldCommitResponse, }; @@ -38,7 +38,7 @@ use log::{info, warn}; use std::{println as info, println as warn}; struct ManagerState { - checkpoint_servers: HashMap, + checkpoint_metadata: HashMap, channel: broadcast::Sender, participants: HashSet, @@ -104,7 +104,7 @@ impl Manager { world_size: world_size, heartbeat_interval: heartbeat_interval, state: Mutex::new(ManagerState { - checkpoint_servers: HashMap::new(), + checkpoint_metadata: HashMap::new(), channel: tx, participants: HashSet::new(), @@ -237,8 +237,8 @@ impl ManagerService for Arc { // save checkpoint server info for healing process // TODO: make separate call to set? 
state - .checkpoint_servers - .insert(req.rank, req.checkpoint_server_addr.clone()); + .checkpoint_metadata + .insert(req.rank, req.checkpoint_metadata.clone()); // TODO check step state.participants.insert(rank); @@ -266,81 +266,28 @@ impl ManagerService for Arc { .await .map_err(|e| Status::internal(e.to_string()))?; - let mut participants = quorum.participants.clone(); - participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id)); - - let replica_rank = participants.iter().enumerate().find_map(|(i, p)| { - if p.replica_id == self.replica_id { - Some(i) - } else { - None - } - }); - if replica_rank.is_none() { - return Err(Status::not_found(format!( - "replica {} not participating in returned quorum", - self.replica_id - ))); - } - - let max_step = participants.iter().map(|p| p.step).max().unwrap(); - let max_participants: Vec<&QuorumMember> = - participants.iter().filter(|p| p.step == max_step).collect(); - - let primary = max_participants[rank as usize % max_participants.len()]; - - let mut max_rank = None; - for (i, p) in max_participants.iter().enumerate() { - if p.replica_id == self.replica_id { - max_rank = Some(i as i64); - break; - } - } - - // Decide whether we should be healing: - // 1. if we're not at the max step - // 2. if everyone is at the first step and we're not the primary - let heal = max_step != req.step || max_step == 0 && primary.replica_id != self.replica_id; - if heal { - info!( - "healing is required step={}, max_step={}", - req.step, max_step - ); - } - - let reply = ManagerQuorumResponse { - quorum_id: quorum.quorum_id, - // address is used for looking up the checkpoint server address. - address: primary.address.clone(), - store_address: primary.store_address.clone(), - max_step: max_step, - max_rank: max_rank, - max_world_size: max_participants.len() as i64, - replica_rank: replica_rank.unwrap() as i64, - replica_world_size: participants.len() as i64, - heal: heal, - }; - info!("returning quorum for rank {}", rank); + let reply = compute_quorum_results(&self.replica_id, rank, &quorum)?; + Ok(Response::new(reply)) } - async fn checkpoint_address( + async fn checkpoint_metadata( &self, - request: Request, - ) -> Result, Status> { + request: Request, + ) -> Result, Status> { let state = self.state.lock().await; let req = request.into_inner(); - let address = state - .checkpoint_servers + let metadata = state + .checkpoint_metadata .get(&req.rank) .ok_or_else(|| Status::invalid_argument("rank not found"))?; - let reply = CheckpointAddressResponse { - checkpoint_server_address: address.clone(), + let reply = CheckpointMetadataResponse { + checkpoint_metadata: metadata.clone(), }; Ok(Response::new(reply)) } @@ -407,6 +354,131 @@ impl ManagerService for Arc { } } +fn compute_quorum_results( + replica_id: &str, + rank: i64, + quorum: &Quorum, +) -> Result { + let mut participants = quorum.participants.clone(); + participants.sort_by(|a, b| a.replica_id.cmp(&b.replica_id)); + + // Compute the rank of the replica in the returned quorum. + let replica_rank = participants + .iter() + .enumerate() + .find_map(|(i, p)| { + if p.replica_id == replica_id { + Some(i) + } else { + None + } + }) + .ok_or_else(|| { + Status::not_found(format!( + "replica {} not participating in returned quorum", + replica_id + )) + })?; + + let step = participants[replica_rank].step; + + // Compute the details for workers at max step. 
+ let max_step = participants.iter().map(|p| p.step).max().unwrap(); + let max_participants: Vec<&QuorumMember> = + participants.iter().filter(|p| p.step == max_step).collect(); + let max_rank = max_participants.iter().enumerate().find_map(|(i, p)| { + if p.replica_id == replica_id { + Some(i as i64) + } else { + None + } + }); + + // The primary TCPStore to use for this rank. + let primary_rank = rank as usize % max_participants.len(); + let primary = max_participants[primary_rank]; + + // Compute recovery assignments + + // Nodes are recovering if: + // 1. not at the max step + // 2. max_step == 0 and not the primary replica + let all_recover_dst_ranks: Vec = participants + .iter() + .enumerate() + .filter_map(|(i, p)| { + if p.step != max_step || max_step == 0 && primary.replica_id != p.replica_id { + Some(i) + } else { + None + } + }) + .collect(); + let all_recover_dst_ranks_set = all_recover_dst_ranks.iter().collect::>(); + let up_to_date_ranks: Vec = participants + .iter() + .enumerate() + .filter_map(|(i, _p)| { + if !all_recover_dst_ranks_set.contains(&i) { + Some(i) + } else { + None + } + }) + .collect(); + + // This is a map of rank to the ranks that are recovering from that node. + let mut recovery_assignments: HashMap> = HashMap::new(); + // The rank of the node that this rank is recovering from. + let mut recover_src_rank: Option = None; + for (i, recovering_rank) in all_recover_dst_ranks.iter().enumerate() { + let up_to_date_idx = (i + rank as usize) % up_to_date_ranks.len(); + let recovering_recover_src_rank = up_to_date_ranks[up_to_date_idx]; + if !recovery_assignments.contains_key(&recovering_recover_src_rank) { + recovery_assignments.insert(recovering_recover_src_rank, Vec::new()); + } + recovery_assignments + .get_mut(&recovering_recover_src_rank) + .unwrap() + .push(*recovering_rank as i64); + if *recovering_rank == replica_rank { + recover_src_rank = Some(recovering_recover_src_rank as i64); + } + } + + let heal = recover_src_rank.is_some(); + if heal { + info!( + "healing is required step={}, max_step={}, recover_src_rank={}", + step, + max_step, + recover_src_rank.unwrap() + ); + } + + let recover_src_manager_address = match recover_src_rank { + Some(r) => participants[r as usize].address.clone(), + None => "".to_string(), + }; + + Ok(ManagerQuorumResponse { + quorum_id: quorum.quorum_id, + // address is used for looking up the checkpoint server address. 
+ recover_src_manager_address: recover_src_manager_address, + recover_src_rank: recover_src_rank, + recover_dst_ranks: recovery_assignments + .get(&replica_rank) + .map_or_else(Vec::new, |v| v.clone()), + store_address: primary.store_address.clone(), + max_step: max_step, + max_rank: max_rank, + max_world_size: max_participants.len() as i64, + replica_rank: replica_rank as i64, + replica_world_size: participants.len() as i64, + heal: heal, + }) +} + #[cfg(test)] mod tests { use super::*; @@ -506,7 +578,7 @@ mod tests { let mut request = tonic::Request::new(ManagerQuorumRequest { rank: 0, step: 123, - checkpoint_server_addr: "addr".to_string(), + checkpoint_metadata: "addr".to_string(), shrink_only: false, }); request.set_timeout(Duration::from_secs(10)); @@ -516,7 +588,7 @@ mod tests { lighthouse_fut.abort(); assert_eq!(resp.quorum_id, 1); - assert_eq!(resp.address, manager.address()); + assert_eq!(resp.recover_src_manager_address, "".to_string()); assert_eq!(resp.store_address, "store_addr".to_string()); assert_eq!(resp.max_step, 123); assert_eq!(resp.max_rank, Some(0)); @@ -565,7 +637,7 @@ mod tests { let mut request = tonic::Request::new(ManagerQuorumRequest { rank: 0, step: 0, - checkpoint_server_addr: "addr".to_string(), + checkpoint_metadata: "addr".to_string(), shrink_only: false, }); request.set_timeout(Duration::from_secs(10)); @@ -597,4 +669,183 @@ mod tests { Ok(()) } + + #[tokio::test] + async fn test_checkpoint_metadata() -> Result<()> { + let lighthouse = Lighthouse::new(LighthouseOpt { + bind: "[::]:0".to_string(), + join_timeout_ms: 100, + min_replicas: 1, + quorum_tick_ms: 100, + heartbeat_timeout_ms: 5000, + }) + .await?; + let lighthouse_fut = tokio::spawn(lighthouse.clone().run()); + + let manager = Manager::new( + "rep_id".to_string(), + lighthouse.address(), + "localhost".to_string(), + "[::]:0".to_string(), + "store_addr".to_string(), + 1, // world size + Duration::from_millis(100), // heartbeat interval + Duration::from_secs(10), // connect timeout + ) + .await?; + let manager_fut = tokio::spawn(manager.clone().run()); + + let mut client = manager_client_new(manager.address(), Duration::from_secs(10)).await?; + + let request = tonic::Request::new(CheckpointMetadataRequest { rank: 0 }); + let resp = client.checkpoint_metadata(request).await; + assert!(resp.err().unwrap().to_string().contains("rank not found")); + + { + let mut state = manager.state.lock().await; + + state.checkpoint_metadata.insert(0, "addr".to_string()); + } + + let request = tonic::Request::new(CheckpointMetadataRequest { rank: 0 }); + let resp = client.checkpoint_metadata(request).await?.into_inner(); + assert_eq!(resp.checkpoint_metadata, "addr".to_string()); + + manager_fut.abort(); + lighthouse_fut.abort(); + + Ok(()) + } + + #[tokio::test] + async fn test_compute_quorum_results_first_step() -> Result<()> { + let quorum = Quorum { + quorum_id: 1, + participants: vec![ + QuorumMember { + replica_id: "replica_0".to_string(), + address: "addr_0".to_string(), + store_address: "store_addr_0".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_1".to_string(), + address: "addr_1".to_string(), + store_address: "store_addr_1".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + ], + created: None, + }; + + // rank 0 + + let results = compute_quorum_results("replica_0", 0, &quorum)?; + assert!(!results.heal); + assert_eq!(results.replica_rank, 0); + assert_eq!(results.recover_src_rank, None); + assert_eq!(results.recover_dst_ranks, 
vec![1]); + + let results = compute_quorum_results("replica_1", 0, &quorum)?; + assert!(results.heal); + assert_eq!(results.replica_rank, 1); + assert_eq!(results.recover_src_rank, Some(0)); + assert_eq!(results.recover_dst_ranks, Vec::::new()); + + // rank 1 assignments should be offset from rank 0 above and the primary + + let results = compute_quorum_results("replica_1", 1, &quorum)?; + assert!(!results.heal); + assert_eq!(results.replica_rank, 1); + assert_eq!(results.recover_src_rank, None); + assert_eq!(results.recover_dst_ranks, vec![0]); + + Ok(()) + } + + #[tokio::test] + async fn test_compute_quorum_results_recovery() -> Result<()> { + let quorum = Quorum { + quorum_id: 1, + participants: vec![ + QuorumMember { + replica_id: "replica_0".to_string(), + address: "addr_0".to_string(), + store_address: "store_addr_0".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_1".to_string(), + address: "addr_1".to_string(), + store_address: "store_addr_1".to_string(), + step: 1, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_2".to_string(), + address: "addr_2".to_string(), + store_address: "store_addr_2".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_3".to_string(), + address: "addr_3".to_string(), + store_address: "store_addr_3".to_string(), + step: 1, + world_size: 1, + shrink_only: false, + }, + QuorumMember { + replica_id: "replica_4".to_string(), + address: "addr_4".to_string(), + store_address: "store_addr_4".to_string(), + step: 0, + world_size: 1, + shrink_only: false, + }, + ], + created: None, + }; + + // rank 0 + + let results = compute_quorum_results("replica_0", 0, &quorum)?; + assert!(results.heal); + assert_eq!(results.recover_src_manager_address, "addr_1".to_string()); + assert_eq!(results.replica_rank, 0); + assert_eq!(results.recover_src_rank, Some(1)); + assert!(results.recover_dst_ranks.is_empty()); + + let results = compute_quorum_results("replica_1", 0, &quorum)?; + assert!(!results.heal); + assert_eq!(results.recover_src_manager_address, "".to_string()); + assert_eq!(results.replica_rank, 1); + assert_eq!(results.recover_src_rank, None); + assert_eq!(results.recover_dst_ranks, vec![0, 4]); + + let results = compute_quorum_results("replica_3", 0, &quorum)?; + assert!(!results.heal); + assert_eq!(results.replica_rank, 3); + assert_eq!(results.recover_src_rank, None); + assert_eq!(results.recover_dst_ranks, vec![2]); + + // rank 1 assignments should be offset from rank 0 above + + let results = compute_quorum_results("replica_1", 1, &quorum)?; + assert!(!results.heal); + assert_eq!(results.replica_rank, 1); + assert_eq!(results.recover_src_rank, None); + assert_eq!(results.recover_dst_ranks, vec![2]); + + Ok(()) + } } diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py index aaad8437..48a5d516 100644 --- a/torchft/checkpointing.py +++ b/torchft/checkpointing.py @@ -16,9 +16,10 @@ import socket import threading import urllib.request +from abc import ABC, abstractmethod from datetime import timedelta from http.server import BaseHTTPRequestHandler -from typing import Callable, Generic, TypeVar +from typing import Generic, List, Optional, TypeVar import torch @@ -29,7 +30,64 @@ T = TypeVar("T") -class CheckpointServer(Generic[T]): +class CheckpointTransport(Generic[T], ABC): + @abstractmethod + def metadata(self) -> str: + """ + Returns a string that will be used by the remote CheckpointTransport to fetch the 
checkpoint. + """ + ... + + @abstractmethod + def send_checkpoint( + self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta + ) -> None: + """ + Sends the checkpoint, only called when there is a rank that is behind. + + This may be async. + + Args: + dst_ranks: the ranks to send to + step: the step number to send + state_dict: the state dict to send + timeout: the timeout to wait for the checkpoint to be sent + """ + ... + + def disallow_checkpoint(self) -> None: + """ + Called after send_checkpoint to wait for the checkpoint to be sent. + + Once this returns, the state_dict may be mutated so no further data should be sent. + """ + ... + + @abstractmethod + def recv_checkpoint( + self, src_rank: int, metadata: str, step: int, timeout: timedelta + ) -> T: + """ + Receives the checkpoint from the given rank. + + Args: + src_rank: the rank to receive the checkpoint from + metadata: the metadata returned by the remote CheckpointTransport + step: the step number to receive + timeout: the timeout to wait for the checkpoint + """ + ... + + def shutdown(self, wait: bool = True) -> None: + """ + Called to shutdown the checkpoint transport. + + Args: + wait: whether to wait for the transport to shutdown + """ + + +class CheckpointServer(CheckpointTransport[T]): """ This is an HTTP server that can be used to transfer checkpoints between workers. @@ -41,11 +99,12 @@ class CheckpointServer(Generic[T]): state_dict: a callable that returns the state dict to be transferred """ - def __init__(self, state_dict: Callable[[], T], timeout: timedelta) -> None: + def __init__(self, timeout: timedelta) -> None: self._checkpoint_lock = threading.Lock() self._disallowed = False self._step = -1 self._timeout = timeout + self._state_dict: Optional[T] = None ckpt_server = self @@ -74,9 +133,9 @@ def do_GET(self): self.send_header("Content-type", "application/octet-stream") self.end_headers() - sd = state_dict() + state_dict = ckpt_server._state_dict - torch.save(sd, self.wfile) + torch.save(state_dict, self.wfile) except Exception as e: logger.exception( f"Exception in checkpoint server when handling {self.path=}: {e}", @@ -117,7 +176,7 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T: def address(self) -> str: """ - Returns the HTTP address to fetch a checkpoint from this server at the current step. + Returns the HTTP address to fetch a checkpoint from this server. Step must be appended to the end of the address. Format: http://host:port/checkpoint/1234 @@ -125,7 +184,7 @@ def address(self) -> str: an HTTP address """ port = self._server.socket.getsockname()[1] - return f"http://{socket.gethostname()}:{port}/checkpoint/{self._step}" + return f"http://{socket.gethostname()}:{port}/checkpoint/" def _serve(self) -> None: try: @@ -156,8 +215,28 @@ def allow_checkpoint(self, step: int) -> None: self._disallowed = False self._checkpoint_lock.release() - def shutdown(self) -> None: + def shutdown(self, wait: bool = True) -> None: """ Shutdown the server. """ - self._server.shutdown() + if not wait: + # hack for nonblocking shutdown of socketserver threads + # pyre-fixme[16]: no attribute `__shutdown_request`. 
+ self._server.__shutdown_request = True + if wait: + self._server.shutdown() + self._thread.join() + + def metadata(self) -> str: + return self.address() + + def send_checkpoint( + self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta + ) -> None: + self._state_dict = state_dict + self.allow_checkpoint(step) + + def recv_checkpoint( + self, src_rank: int, metadata: str, step: int, timeout: timedelta + ) -> T: + return self.load_from_address(f"{metadata}{step}", timeout) diff --git a/torchft/checkpointing_test.py b/torchft/checkpointing_test.py index 983c429c..e2a05e12 100644 --- a/torchft/checkpointing_test.py +++ b/torchft/checkpointing_test.py @@ -18,26 +18,40 @@ def test_checkpoint_server(self) -> None: state_dict_fn = MagicMock() state_dict_fn.return_value = expected server = CheckpointServer( - state_dict=state_dict_fn, timeout=timedelta(seconds=10), ) - server.disallow_checkpoint() - server.allow_checkpoint(1234) + server.send_checkpoint( + dst_ranks=[], + step=1234, + state_dict=expected, + timeout=timedelta(seconds=10), + ) - addr = server.address() + metadata = server.metadata() - out = CheckpointServer.load_from_address(addr, timeout=timedelta(seconds=10)) + out = server.recv_checkpoint( + src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10) + ) self.assertEqual(out, expected) # test timeout with self.assertRaisesRegex(urllib.error.URLError, r"urlopen error"): - CheckpointServer.load_from_address(addr, timeout=timedelta(seconds=0.0)) + server.recv_checkpoint( + src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=0.0) + ) # test mismatch case - server.allow_checkpoint(2345) + server.send_checkpoint( + dst_ranks=[], + step=2345, + state_dict=expected, + timeout=timedelta(seconds=10), + ) with self.assertRaisesRegex(urllib.error.HTTPError, r"Error 400"): - CheckpointServer.load_from_address(addr, timeout=timedelta(seconds=10)) + server.recv_checkpoint( + src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10) + ) server.shutdown() diff --git a/torchft/manager.py b/torchft/manager.py index d9ff366d..dc5ab302 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -38,7 +38,7 @@ import torch from torch.distributed import ReduceOp, TCPStore -from torchft.checkpointing import CheckpointServer +from torchft.checkpointing import CheckpointServer, CheckpointTransport from torchft.futures import future_timeout from torchft.torchft import Manager as _Manager, ManagerClient @@ -104,6 +104,7 @@ def __init__( port: Optional[int] = None, hostname: str = socket.gethostname(), heartbeat_interval: timedelta = timedelta(milliseconds=100), + checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None, ) -> None: """ Args: @@ -139,6 +140,8 @@ def __init__( lighthouse_addr: if rank==0, the address of the lighthouse server replica_id: if rank==0, the replica_id for this group hostname: if rank==0, the hostname to advertise to the lighthouse server + checkpoint_transport: the checkpoint transport to use for + transfering checkpoints to recovering replicas """ self._load_state_dict = load_state_dict self._state_dict = state_dict @@ -156,15 +159,15 @@ def __init__( world_size = world_size or int(os.environ["WORLD_SIZE"]) self._min_replica_size = min_replica_size - def _manager_state_dict() -> Dict[str, T]: - return { - "user": state_dict(), - "torchft": cast(T, self.state_dict()), - } + self._user_state_dict = state_dict - self._ckpt_server = CheckpointServer[Dict[str, T]]( - _manager_state_dict, - timeout=timeout, + if 
checkpoint_transport is None: + checkpoint_transport = CheckpointServer[Dict[str, T]]( + timeout=timeout, + ) + + self._checkpoint_transport: CheckpointTransport[Dict[str, T]] = ( + checkpoint_transport ) self._executor = ThreadPoolExecutor( max_workers=1, thread_name_prefix="async_quorum" @@ -223,14 +226,14 @@ def _manager_state_dict() -> Dict[str, T]: self._participating_rank: Optional[int] = None self._participating_world_size: int = 0 - def shutdown(self) -> None: + def shutdown(self, wait: bool = True) -> None: """ Shutdown the manager and checkpoint server. """ - self._ckpt_server.shutdown() + self._checkpoint_transport.shutdown(wait=wait) if self._manager is not None: self._manager.shutdown() - self._executor.shutdown() + self._executor.shutdown(wait=wait) def allreduce(self, tensor: torch.Tensor) -> torch.futures.Future[torch.Tensor]: """ @@ -386,7 +389,6 @@ def start_quorum( self._errored = None self._healing = False - self._ckpt_server.allow_checkpoint(self._step) # TODO: we should really be wrapping this whole section in a try-except # block to allow gracefully recovering from issues in PG setup and quorum. @@ -422,24 +424,24 @@ def wait_quorum(self) -> None: def _async_quorum( self, allow_heal: bool, shrink_only: bool, quorum_timeout: timedelta ) -> None: - ( - quorum_id, - replica_rank, - replica_world_size, - address, - store_address, - max_step, - max_rank, - max_world_size, - heal, - ) = self._client.quorum( + quorum = self._client.quorum( rank=self._rank, step=self._step, - checkpoint_server_addr=self._ckpt_server.address(), + checkpoint_metadata=self._checkpoint_transport.metadata(), shrink_only=shrink_only, timeout=quorum_timeout, ) + quorum_id = quorum.quorum_id + replica_rank = quorum.replica_rank + replica_world_size = quorum.replica_world_size + recover_src_manager_address = quorum.recover_src_manager_address + store_address = quorum.store_address + max_step = quorum.max_step + max_rank = quorum.max_rank + max_world_size = quorum.max_world_size + heal = quorum.heal + # When using async quorum we need to take the recovered workers. # When not using async quorum we need to take the max world size as all # workers will be healthy. 
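
Commentary, not part of the patch: with the hunk above, ManagerClient.quorum returns a structured QuorumResult instead of the old 9-tuple, and the opaque checkpoint_metadata string replaces checkpoint_server_addr. A minimal Python sketch of the new call shape follows; the manager address and timeout values are hypothetical, and the signatures are the ones declared in torchft/torchft.pyi later in this diff (a running manager is required for the RPC to succeed).

from datetime import timedelta

from torchft.checkpointing import CheckpointServer
from torchft.torchft import ManagerClient

# The transport supplies the opaque metadata string sent along with the quorum request.
transport = CheckpointServer(timeout=timedelta(seconds=10))

# Hypothetical manager address; quorum() blocks until the quorum is formed or times out.
client = ManagerClient("localhost:29512", connect_timeout=timedelta(seconds=10))
quorum = client.quorum(
    rank=0,
    step=0,
    checkpoint_metadata=transport.metadata(),
    shrink_only=False,
    timeout=timedelta(seconds=60),
)

# Structured fields replace the positional unpacking used before this patch.
print(quorum.quorum_id, quorum.replica_rank, quorum.replica_world_size)
print(quorum.recover_src_rank, quorum.recover_dst_ranks, quorum.heal)

transport.shutdown(wait=False)
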
@@ -470,29 +472,54 @@ def _async_quorum( self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size) self._quorum_id = quorum_id - # See manager.rs for healing conditions - if heal and allow_heal: - self._healing = True - self._logger.info( - f"healing required, fetching checkpoint server address from {address=} {max_step=}" - ) - primary_client = ManagerClient( - address, connect_timeout=self._connect_timeout - ) - checkpoint_server_address = primary_client.checkpoint_address( - self._rank, timeout=self._timeout - ) + if allow_heal: + if quorum.recover_dst_ranks: + self._logger.info( + f"peers need recovery from us {quorum.recover_dst_ranks}" + ) + self._checkpoint_transport.send_checkpoint( + dst_ranks=quorum.recover_dst_ranks, + step=max_step, + state_dict=self._manager_state_dict(), + timeout=self._timeout, + ) - self._logger.info(f"fetching checkpoint from {checkpoint_server_address=}") - self._pending_state_dict = CheckpointServer.load_from_address( - checkpoint_server_address, timeout=self._timeout - ) - self.load_state_dict(self._pending_state_dict["torchft"]) - # we apply the user state dict only when safe from the main thread + # See manager.rs for healing conditions + if heal: + self._healing = True + self._logger.info( + f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}" + ) + primary_client = ManagerClient( + recover_src_manager_address, connect_timeout=self._connect_timeout + ) + checkpoint_metadata = primary_client.checkpoint_metadata( + self._rank, timeout=self._timeout + ) + recover_src_rank = quorum.recover_src_rank + assert ( + recover_src_rank is not None + ), "must have a recover rank when healing" - # This isn't strictly needed as loading the state_dict above should - # restore the correct step but it makes writing tests simpler. - self._step = max_step + self._logger.info( + f"fetching checkpoint from {recover_src_rank=} with {checkpoint_metadata=}" + ) + + # we apply the user state dict only when safe from the main thread + # save it for now + self._pending_state_dict = self._checkpoint_transport.recv_checkpoint( + src_rank=recover_src_rank, + metadata=checkpoint_metadata, + step=max_step, + timeout=self._timeout, + ) + + # pyre-fixme[6]: got object + self.load_state_dict(self._pending_state_dict["torchft"]) + + # This isn't strictly needed as loading the state_dict above should + # restore the correct step but it makes writing tests simpler. + self._step = max_step def _apply_pending_state_dict(self) -> None: assert self._healing, "must be in healing state" @@ -553,7 +580,7 @@ def should_commit(self, timeout: Optional[timedelta] = None) -> bool: f"should_commit={should_commit} enough_replicas={enough_replicas}, errored={self._errored}" ) - self._ckpt_server.disallow_checkpoint() + self._checkpoint_transport.disallow_checkpoint() # decide whether we're in a healthy state to increase the step count if should_commit: @@ -574,6 +601,12 @@ def load_state_dict(self, state_dict: Dict[str, int]) -> None: self._step = state_dict["step"] self._batches_committed = state_dict["batches_committed"] + def _manager_state_dict(self) -> Dict[str, object]: + return { + "user": self._user_state_dict(), + "torchft": self.state_dict(), + } + def state_dict(self) -> Dict[str, int]: """ Get the state dict for this manager. 
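
Commentary, not part of the patch: the rewritten _async_quorum above splits recovery into a send side (replicas listed in quorum.recover_dst_ranks publish their state via send_checkpoint) and a receive side (a healing replica pulls from recover_src_rank using the metadata string obtained over the CheckpointMetadata RPC). The self-contained sketch below exercises that handshake with CheckpointServer as the CheckpointTransport; in the real flow the two sides run in different replicas and the metadata travels over the RPC, here both sides share one process for brevity (mirroring torchft/checkpointing_test.py).

from datetime import timedelta

from torchft.checkpointing import CheckpointServer

transport = CheckpointServer(timeout=timedelta(seconds=10))

# Up-to-date replica: quorum.recover_dst_ranks was non-empty, so publish state for step 20.
transport.send_checkpoint(
    dst_ranks=[1],
    step=20,
    state_dict={"user": {}, "torchft": {"step": 20}},
    timeout=timedelta(seconds=10),
)
metadata = transport.metadata()  # what the CheckpointMetadata RPC would return remotely

# Recovering replica: recover_src_rank identified who to pull from.
state = transport.recv_checkpoint(
    src_rank=0, metadata=metadata, step=20, timeout=timedelta(seconds=10)
)
assert state["torchft"]["step"] == 20

transport.disallow_checkpoint()  # done sending; the state_dict may be mutated again
transport.shutdown(wait=True)
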
diff --git a/torchft/manager_integ_test.py b/torchft/manager_integ_test.py index d6e7bdee..0721b17e 100644 --- a/torchft/manager_integ_test.py +++ b/torchft/manager_integ_test.py @@ -159,7 +159,7 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-fixme[6]: Incompatible parameter type **runner.manager_args, ) - stack.callback(manager.shutdown) + stack.callback(lambda: manager.shutdown(wait=False)) m: nn.Module = DistributedDataParallel(manager, MyModel()) optimizer: optim.Optimizer = OptimizerWrapper( @@ -223,7 +223,7 @@ def state_dict() -> Dict[str, Dict[str, object]]: # pyre-fixme[6]: Incompatible parameter type **runner.manager_args, ) - stack.callback(manager.shutdown) + stack.callback(lambda: manager.shutdown(wait=False)) m: nn.Module = MyModel() optimizer: optim.Optimizer = optim.Adam(m.parameters()) @@ -460,7 +460,7 @@ def test_quorum_timeout(self) -> None: port=19530, use_async_quorum=False, ) - stack.callback(manager.shutdown) + stack.callback(lambda: manager.shutdown(wait=False)) with self.assertElapsedLessThan(1.0): with self.assertRaisesRegex( diff --git a/torchft/manager_test.py b/torchft/manager_test.py index 97b891cc..01adddf1 100644 --- a/torchft/manager_test.py +++ b/torchft/manager_test.py @@ -5,6 +5,7 @@ # LICENSE file in the root directory of this source tree. from datetime import timedelta +from typing import Optional from unittest import TestCase from unittest.mock import MagicMock, create_autospec, patch @@ -13,7 +14,7 @@ from torchft.manager import MANAGER_ADDR_KEY, REPLICA_ID_KEY, Manager, WorldSizeMode from torchft.process_group import ProcessGroup, _DummyWork -from torchft.torchft import ManagerClient +from torchft.torchft import QuorumResult def mock_should_commit( @@ -25,13 +26,19 @@ def mock_should_commit( class TestManager(TestCase): store: TCPStore # pyre-fixme[13]: never initialized load_state_dict: MagicMock # pyre-fixme[13]: never initialized + manager: Optional[Manager] # pyre-fixme[13]: never initialized + + def tearDown(self) -> None: + manager = self.manager + if manager is not None: + manager.shutdown(wait=False) def _create_manager( self, use_async_quorum: bool = True, min_replica_size: int = 2, world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC, - timeout: timedelta = timedelta(seconds=60), + timeout: timedelta = timedelta(seconds=10), ) -> Manager: pg = create_autospec(ProcessGroup) self.store = TCPStore( @@ -58,6 +65,7 @@ def _create_manager( world_size_mode=world_size_mode, timeout=timeout, ) + self.manager = manager return manager @patch("torchft.manager.ManagerClient", autospec=True) @@ -92,17 +100,18 @@ def test_quorum_happy(self, client_mock: MagicMock) -> None: manager = self._create_manager() client_mock().should_commit = mock_should_commit - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world - "manager address", - f"localhost:{self.store.port}", - 1, # max_step - 1, # max_rank - 2, # max_world_size - False, # heal - ) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 1 + quorum.max_rank = 1 + quorum.max_world_size = 2 + quorum.heal = False + + client_mock().quorum.return_value = quorum self.assertEqual(manager._quorum_id, -1) self.assertEqual(manager.current_step(), 0) @@ -127,21 +136,30 @@ def test_quorum_heal_sync(self, client_mock: MagicMock) -> None: manager = 
self._create_manager(use_async_quorum=False) client_mock().should_commit = mock_should_commit - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world - "manager address", - f"localhost:{self.store.port}", - 20, # max_step - None, # max_rank - 2, # max_world_size - True, # heal - ) - # forcible increment checkpoint server to compute correct address - manager._ckpt_server.allow_checkpoint(manager.current_step()) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.recover_src_rank = 0 + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 20 + quorum.max_rank = None + quorum.max_world_size = 2 + quorum.heal = True - client_mock().checkpoint_address.return_value = manager._ckpt_server.address() + client_mock().quorum.return_value = quorum + + # forcible increment checkpoint server to compute correct address + manager._checkpoint_transport.send_checkpoint( + dst_ranks=[], + step=quorum.max_step, + state_dict=manager._manager_state_dict(), + timeout=timedelta(seconds=10), + ) + client_mock().checkpoint_metadata.return_value = ( + manager._checkpoint_transport.metadata() + ) self.assertEqual(manager._quorum_id, -1) self.assertEqual(manager.current_step(), 0) @@ -169,21 +187,30 @@ def test_quorum_heal_async_not_enough_participants( manager = self._create_manager(use_async_quorum=True, min_replica_size=2) client_mock().should_commit = mock_should_commit - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world - "manager address", - f"localhost:{self.store.port}", - 20, # max_step - None, # max_rank - 1, # max_world_size - True, # heal - ) - # forcible increment checkpoint server to compute correct address - manager._ckpt_server.allow_checkpoint(manager.current_step()) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.recover_src_rank = 0 + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 20 + quorum.max_rank = None + quorum.max_world_size = 1 + quorum.heal = True + + client_mock().quorum.return_value = quorum - client_mock().checkpoint_address.return_value = manager._ckpt_server.address() + # forcible increment checkpoint server to compute correct address + manager._checkpoint_transport.send_checkpoint( + dst_ranks=[], + step=quorum.max_step, + state_dict=manager._manager_state_dict(), + timeout=timedelta(seconds=10), + ) + client_mock().checkpoint_metadata.return_value = ( + manager._checkpoint_transport.metadata() + ) self.assertEqual(manager._quorum_id, -1) self.assertEqual(manager.current_step(), 0) @@ -212,6 +239,7 @@ def test_quorum_heal_async_not_enough_participants( self.assertEqual(self.load_state_dict.call_count, 1) # failed to commit so no step + quorum.heal = False manager.start_quorum() self.assertEqual(manager.current_step(), 20) self.assertEqual(manager.batches_committed(), 0) @@ -221,21 +249,30 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None: manager = self._create_manager(use_async_quorum=True, min_replica_size=1) client_mock().should_commit = mock_should_commit - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world - "manager address", - f"localhost:{self.store.port}", - 20, # max_step - None, # max_rank - 1, # max_world_size - True, 
# heal - ) - # forceable increment checkpoint server to compute correct address - manager._ckpt_server.allow_checkpoint(manager.current_step()) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.recover_src_rank = 0 + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 20 + quorum.max_rank = None + quorum.max_world_size = 1 + quorum.heal = True - client_mock().checkpoint_address.return_value = manager._ckpt_server.address() + client_mock().quorum.return_value = quorum + + # forceable increment checkpoint server to compute correct address + manager._checkpoint_transport.send_checkpoint( + dst_ranks=[], + step=quorum.max_step, + state_dict=manager._manager_state_dict(), + timeout=timedelta(seconds=10), + ) + client_mock().checkpoint_metadata.return_value = ( + manager._checkpoint_transport.metadata() + ) self.assertEqual(manager._quorum_id, -1) self.assertEqual(manager.current_step(), 0) @@ -261,6 +298,8 @@ def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None: self.assertEqual(self.load_state_dict.call_count, 1) + # healed + quorum.heal = False manager.start_quorum() self.assertEqual(manager.current_step(), 21) self.assertEqual(manager.batches_committed(), 1) @@ -270,17 +309,18 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None: manager = self._create_manager() client_mock().should_commit = mock_should_commit - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world - "manager address", - f"localhost:{self.store.port}", - 1, # max_step - 1, # max_rank - 2, # max_world_size - False, # heal - ) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 1 + quorum.max_rank = 1 + quorum.max_world_size = 2 + quorum.heal = False + + client_mock().quorum.return_value = quorum self.assertEqual(manager._quorum_id, -1) self.assertEqual(manager.current_step(), 0) @@ -308,17 +348,8 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None: manager._pg.allreduce.side_effect = None # inject failure when worked waited - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world - "manager address", - f"localhost:{self.store.port}", - 2, # max_step - 1, # max_rank - 2, # max_world_size - False, # heal - ) + quorum.max_step = 2 + manager.start_quorum() self.assertFalse(manager._errored) @@ -336,17 +367,7 @@ def test_allreduce_error(self, client_mock: MagicMock) -> None: manager._pg.allreduce.reset_mock(return_value=True) # recover on next step - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world_size - "manager address", - f"localhost:{self.store.port}", - 3, # max_step - 1, # max_rank - 2, # max_world_size - False, # heal - ) + quorum.max_step = 3 manager.start_quorum() manager.allreduce(torch.tensor([1.0])).wait() @@ -362,17 +383,18 @@ def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None: ) client_mock().should_commit = mock_should_commit - client_mock().quorum.return_value = ( - 123, # quorum_id - rank, # replica_rank - 3, # replica_world - "manager address", - f"localhost:{self.store.port}", - 1, # max_step - rank, # max_rank - 3, # max_world_size - False, # heal - ) + quorum = QuorumResult() + 
quorum.quorum_id = 123 + quorum.replica_rank = rank + quorum.replica_world_size = 3 + quorum.recover_src_manager_address = "manager address" + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 1 + quorum.max_rank = rank + quorum.max_world_size = 3 + quorum.heal = False + + client_mock().quorum.return_value = quorum self.assertEqual(manager._quorum_id, -1) self.assertEqual(manager.current_step(), 0) @@ -395,17 +417,18 @@ def test_quorum_no_healing(self, client_mock: MagicMock) -> None: ) client_mock().should_commit = mock_should_commit - client_mock().quorum.return_value = ( - 123, # quorum_id - 0, # replica_rank - 3, # replica_world - "manager address", - f"localhost:{self.store.port}", - 1, # max_step - None, # max_rank - 2, # max_world_size - True, # heal - ) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 0 + quorum.replica_world_size = 3 + quorum.recover_src_manager_address = "manager address" + quorum.recover_src_rank = 1 + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 1 + quorum.max_rank = None + quorum.max_world_size = 2 + quorum.heal = True + client_mock().quorum.return_value = quorum self.assertEqual(manager._quorum_id, -1) self.assertEqual(manager.current_step(), 0) @@ -492,17 +515,18 @@ def test_manager_numerics(self, client_mock: MagicMock) -> None: def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None: manager = self._create_manager(use_async_quorum=False) - client_mock().quorum.return_value = ( - 123, # quorum_id - 1, # replica_rank - 2, # replica_world - "manager address", - f"localhost:{self.store.port}", - 1, # max_step - 1, # max_rank - 2, # max_world_size - False, # heal - ) + quorum = QuorumResult() + quorum.quorum_id = 123 + quorum.replica_rank = 1 + quorum.replica_world_size = 2 + quorum.recover_src_manager_address = "manager address" + quorum.store_address = f"localhost:{self.store.port}" + quorum.max_step = 1 + quorum.max_rank = 1 + quorum.max_world_size = 2 + quorum.heal = False + + client_mock().quorum.return_value = quorum manager.start_quorum(timeout=timedelta(seconds=12)) self.assertEqual( diff --git a/torchft/torchft.pyi b/torchft/torchft.pyi index fbd02937..2c8c6cd9 100644 --- a/torchft/torchft.pyi +++ b/torchft/torchft.pyi @@ -1,5 +1,5 @@ from datetime import timedelta -from typing import Optional, Tuple +from typing import List, Optional class ManagerClient: def __init__(self, addr: str, connect_timeout: timedelta) -> None: ... @@ -7,11 +7,11 @@ class ManagerClient: self, rank: int, step: int, - checkpoint_server_addr: str, + checkpoint_metadata: str, shrink_only: bool, timeout: timedelta, - ) -> Tuple[int, int, int, str, str, int, Optional[int], int, bool]: ... - def checkpoint_address(self, rank: int, timeout: timedelta) -> str: ... + ) -> QuorumResult: ... + def checkpoint_metadata(self, rank: int, timeout: timedelta) -> str: ... def should_commit( self, rank: int, @@ -20,6 +20,19 @@ class ManagerClient: timeout: timedelta, ) -> bool: ... +class QuorumResult: + quorum_id: int + replica_rank: int + replica_world_size: int + recover_src_manager_address: str + recover_src_rank: Optional[int] + recover_dst_ranks: List[int] + store_address: str + max_step: int + max_rank: Optional[int] + max_world_size: int + heal: bool + class Manager: def __init__( self,
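
Commentary, not part of the patch: the round-robin assignment in compute_quorum_results (src/manager.rs) can be restated in a few lines of Python. The sketch below mirrors that logic and reproduces the expectations of test_compute_quorum_results_recovery; the function name and the flat steps list are illustrative only, and rank is the caller's local group rank.

from typing import Dict, List


def assign_recovery(steps: List[int], rank: int) -> Dict[int, List[int]]:
    """Map each up-to-date replica rank to the replica ranks it should recover."""
    max_step = max(steps)
    max_ranks = [i for i, s in enumerate(steps) if s == max_step]
    primary = max_ranks[rank % len(max_ranks)]
    # Recovering: behind max_step, or everyone is at step 0 and not the primary.
    recovering = [
        i for i, s in enumerate(steps)
        if s != max_step or (max_step == 0 and i != primary)
    ]
    up_to_date = [i for i in range(len(steps)) if i not in recovering]
    assignments: Dict[int, List[int]] = {}
    for i, dst in enumerate(recovering):
        src = up_to_date[(i + rank) % len(up_to_date)]
        assignments.setdefault(src, []).append(dst)
    return assignments


# Mirrors test_compute_quorum_results_recovery: replicas 0/2/4 are at step 0,
# replicas 1/3 are at step 1. With local rank 0, replica_1 serves [0, 4] and
# replica_3 serves [2].
assert assign_recovery([0, 1, 0, 1, 0], rank=0) == {1: [0, 4], 3: [2]}
# With local rank 1 the sources rotate, matching the "offset from rank 0" asserts.
assert assign_recovery([0, 1, 0, 1, 0], rank=1) == {3: [0, 4], 1: [2]}

Because the offset includes the caller's local rank, different ranks within the same replica spread their recovery sources across the up-to-date replicas instead of all pulling from the same one, which is what the rank-1 assertions in the tests check.
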