class RWLock(object):
    """Readers-writer lock built from two ``threading.Lock`` mutexes.

    Any number of threads may hold the lock for reading at the same time,
    while a writer requires exclusive access. See:
    https://en.wikipedia.org/wiki/Readers%E2%80%93writer_lock#Using_two_mutexes

    Every *acquire* operation is bounded by ``timeout`` (in seconds) and
    raises ``TimeoutError`` when it is exceeded. The default of ``-1`` is
    passed straight to ``threading.Lock.acquire`` and therefore waits
    indefinitely. Release operations are not timed.

    Usage:

        from rwlock import RWLock

        my_obj_rwlock = RWLock(timeout=60.0)

        # When reading from my_obj:
        with my_obj_rwlock.r_lock():
            do_read_only_things_with(my_obj)

        # When writing to my_obj:
        with my_obj_rwlock.w_lock():
            mutate(my_obj)
    """

    def __init__(self, timeout: float = -1) -> None:
        """
        Args:
            timeout: seconds to wait in each acquire before raising
                ``TimeoutError``; ``-1`` waits forever.
        """
        self.timeout = timeout

        # Held by the single writer, or by the group of active readers as a
        # whole (taken by the first reader, dropped by the last one).
        self._w_lock = Lock()
        # Guards ``_num_r``.
        self._num_r_lock = Lock()
        # Number of readers currently holding the lock.
        self._num_r = 0

    # ___________________________________________________________________
    # Reading methods.

    def r_acquire(self) -> None:
        """Acquire the lock for reading.

        Raises:
            TimeoutError: if either the reader counter mutex or (for the first
                reader) the writer mutex cannot be taken within ``timeout``.
        """
        if not self._num_r_lock.acquire(timeout=self.timeout):
            raise TimeoutError(
                f"Timed out waiting for rlock after {self.timeout} seconds"
            )

        try:
            # The first reader locks writers out on behalf of all readers.
            if self._num_r == 0 and not self._w_lock.acquire(timeout=self.timeout):
                raise TimeoutError(
                    f"Timed out waiting for wlock after {self.timeout} seconds"
                )
            self._num_r += 1
        finally:
            self._num_r_lock.release()

    def r_release(self) -> None:
        """Release one read hold previously taken with ``r_acquire``.

        Raises:
            RuntimeError: if no reader currently holds the lock.
        """
        # Deliberately untimed: a release that gives up would leave the read
        # lock held forever. The counter mutex is only held for long by a
        # would-be *first* reader, i.e. while there is nothing to release.
        self._num_r_lock.acquire()
        try:
            # Checked under the mutex (the counter is shared state) and with a
            # real exception so the check survives ``python -O``.
            if self._num_r <= 0:
                raise RuntimeError("r_release called without a matching r_acquire")
            self._num_r -= 1
            if self._num_r == 0:
                # Last reader out lets writers in again.
                self._w_lock.release()
        finally:
            self._num_r_lock.release()

    @contextmanager
    def r_lock(self) -> Generator[None, None, None]:
        """This method is designed to be used via the `with` statement."""
        self.r_acquire()
        try:
            yield
        finally:
            self.r_release()

    # ___________________________________________________________________
    # Writing methods.

    def w_acquire(self) -> None:
        """Acquire the lock for writing (exclusive access).

        Raises:
            TimeoutError: if the writer mutex cannot be taken within
                ``timeout``, e.g. because readers or another writer hold it.
        """
        if not self._w_lock.acquire(timeout=self.timeout):
            raise TimeoutError(
                f"Timed out waiting for wlock after {self.timeout} seconds"
            )

    def w_release(self) -> None:
        """Release the writer mutex taken with ``w_acquire``."""
        self._w_lock.release()

    @contextmanager
    def w_lock(self) -> Generator[None, None, None]:
        """This method is designed to be used via the `with` statement."""
        self.w_acquire()
        try:
            yield
        finally:
            self.w_release()

    def w_locked(self) -> bool:
        """Returns True if the writer mutex is currently held.

        The mutex is held either by a writer or on behalf of one or more
        active readers, so this is also True while the lock is read-locked.
        """
        return self._w_lock.locked()
@contextmanager
def _time(desc: str) -> Generator[None, None, None]:
    """
    Log at INFO level how long the wrapped ``with`` body took.

    The duration is logged even when the body raises, so slow failures stay
    visible in the logs; the exception then propagates unchanged.

    Args:
        desc: label used as the prefix of the log line
    """
    start = time.perf_counter()
    try:
        yield
    finally:
        end = time.perf_counter()
        logger.info(f"{desc} took {end - start}s")
Args: - state_dict: a callable that returns the state dict to be transferred + timeout: the timeout for HTTP requests + num_chunks: the number of chunks to split the checkpoint into (0 for no chunking) """ - def __init__(self, timeout: timedelta) -> None: - self._checkpoint_lock = threading.Lock() + def __init__(self, timeout: timedelta, num_chunks: int) -> None: + self._checkpoint_lock = RWLock(timeout=timeout.total_seconds()) self._disallowed = False self._step = -1 self._timeout = timeout self._state_dict: Optional[T] = None + self._num_chunks = num_chunks + self._stream: Optional[torch.cuda.Stream] = ( + torch.cuda.Stream() if torch.cuda.is_available() else None + ) + + # staged checkpoint information + self._spec: Optional[TreeSpec] = None + self._chunks: Optional[List[List[object]]] = None # We don't allow checkpoints until the first send_checkpoint to avoid # serving the default step=-1 invalid checkpoint. @@ -78,37 +79,56 @@ def do_GET(self): # validate socket timeout is actually set assert self.connection.gettimeout() == self.timeout - with _timed_acquire( - ckpt_server._checkpoint_lock, ckpt_server._timeout - ): + with ckpt_server._checkpoint_lock.r_lock(): step = ckpt_server._step - if self.path != f"/checkpoint/{step}": - self.send_response(400) - self.send_header("Content-type", "text/plain") - self.end_headers() - self.err( - f"invalid checkpoint requested, serving {step} but got {self.path}" + parts = self.path.split("/") + assert len(parts) == 4 + if parts[1] != "checkpoint": + self.send_error( + 400, + f"invalid url format, expected /checkpoint/step/key but got {self.path}", + ) + return + + step = int(parts[2]) + if step != ckpt_server._step: + self.send_error( + 400, + f"invalid checkpoint requested, serving {ckpt_server._step} but got {step=}", ) return - self.send_response(200) - self.send_header("Content-type", "application/octet-stream") - self.end_headers() + key = parts[3] + if key == "full": + self.send_response(200) + 
self.send_header("Content-type", "application/octet-stream") + self.end_headers() + + state_dict = ckpt_server._state_dict + + _streaming_save(state_dict, self.wfile) + return + + if key == "metadata": + self.send_response(200) + self.send_header("Content-type", "application/octet-stream") + self.end_headers() + + _streaming_save(ckpt_server._spec, self.wfile) + else: + chunk = ckpt_server._chunks[int(key)] - state_dict = ckpt_server._state_dict + self.send_response(200) + self.send_header("Content-type", "application/octet-stream") + self.end_headers() - torch.save(state_dict, self.wfile) + _streaming_save(chunk, self.wfile) except Exception as e: logger.exception( f"Exception in checkpoint server when handling {self.path=}: {e}", ) - self.send_response(500, str(e)) - self.end_headers() - - def err(self, msg: str) -> None: - logger.error(msg) - self.wfile.write(msg.encode()) + self.send_error(500, str(e)) server_address = ("", 0) self._server = _IPv6HTTPServer(server_address, RequestHandler) @@ -122,22 +142,23 @@ def err(self, msg: str) -> None: self._thread.start() @classmethod - def load_from_address(cls, address: str, timeout: timedelta) -> T: + def _load_from_address(cls, address: str, timeout: timedelta) -> object: """ Loads a checkpoint from the given address. Args: address: the HTTP address to load the checkpoint from """ - logger.info(f"fetching checkpoint from {address}") - - with urllib.request.urlopen(address, timeout=timeout.total_seconds()) as f: - data = f.read() + msg = f"fetching checkpoint from {address}" + logger.info(msg) - reader = io.BytesIO(data) - # We have to set weights_only to False as there are some non-tensor - # states like lr_scheduler. - return torch.load(reader, weights_only=False) + with _time(msg), urllib.request.urlopen( + address, timeout=timeout.total_seconds() + ) as f: + # We have to set weights_only to False as there are some non-tensor + # states like lr_scheduler. 
+ # pyre-fixme[16]: needs torch>=2.7 + return cast(T, _streaming_load(f, weights_only=False)) def address(self) -> str: """ @@ -165,7 +186,7 @@ def disallow_checkpoint(self) -> None: """ if not self._disallowed: self._disallowed = True - self._checkpoint_lock.acquire() + self._checkpoint_lock.w_acquire() def allow_checkpoint(self, step: int) -> None: """ @@ -178,7 +199,7 @@ def allow_checkpoint(self, step: int) -> None: if self._disallowed: self._disallowed = False - self._checkpoint_lock.release() + self._checkpoint_lock.w_release() def shutdown(self, wait: bool = True) -> None: """ @@ -198,10 +219,80 @@ def metadata(self) -> str: def send_checkpoint( self, dst_ranks: List[int], step: int, state_dict: T, timeout: timedelta ) -> None: - self._state_dict = state_dict + values, spec = tree_flatten(state_dict) + + with ( + torch.cuda.stream(self._stream) + if self._stream is not None + else nullcontext() + ): + with _time("transferring state_dict to CPU"): + values = _to_cpu(values, pin_memory=False) + if self._stream is not None: + self._stream.synchronize() + + # Unflatten so non-chunked transfer uses CPU tensors + self._state_dict = tree_unflatten(values, spec) + + # Save spec for chunked + self._spec = spec + self._chunks = _split_chunks(values, self._num_chunks) + self.allow_checkpoint(step) def recv_checkpoint( self, src_rank: int, metadata: str, step: int, timeout: timedelta ) -> T: - return self.load_from_address(f"{metadata}{step}", timeout) + base_url = f"{metadata}{step}" + if self._num_chunks == 0: + return cast(T, self._load_from_address(f"{base_url}/full", timeout)) + else: + urls = [f"{base_url}/metadata"] + [ + f"{base_url}/{i}" for i in range(self._num_chunks) + ] + + with ThreadPoolExecutor(max_workers=len(urls)) as executor: + futures = [ + executor.submit(self._load_from_address, url, timeout) + for url in urls + ] + + spec, *chunks = [future.result() for future in futures] + spec = cast(TreeSpec, spec) + chunks = cast(List[List[object]], chunks) + 
+ values = _merge_chunks(chunks, self._num_chunks) + + return tree_unflatten(values, spec) + + +def _to_cpu(values: List[T], pin_memory: bool) -> List[T]: + out = [] + for v in values: + if isinstance(v, torch.Tensor): + if v.device.type == "cuda": + if pin_memory: + cpu = torch.empty(*tuple(v.size()), dtype=v.dtype, pin_memory=True) + cpu.copy_(v, non_blocking=True) + out.append(cpu) + else: + out.append(v.cpu()) + else: + out.append(v) + else: + out.append(v) + return out + + +def _split_chunks(values: List[T], num_chunks: int) -> List[List[T]]: + return [values[i::num_chunks] for i in range(num_chunks)] + + +def _merge_chunks(chunks: List[List[T]], num_chunks: int) -> List[T]: + max_len = max(len(lst) for lst in chunks) + output_list = [] + for i in range(max_len): + for lst in chunks: + if i < len(lst): + output_list.append(lst[i]) + return output_list diff --git a/torchft/checkpointing/http_transport_bench.py b/torchft/checkpointing/http_transport_bench.py new file mode 100644 index 00000000..4e52193c --- /dev/null +++ b/torchft/checkpointing/http_transport_bench.py @@ -0,0 +1,55 @@ +import logging +import sys +from datetime import timedelta +from typing import List + +import torch + +from torchft.checkpointing.http_transport import HTTPTransport, _time + +logger: logging.Logger = logging.getLogger(__name__) + + +def main(argv: List[str]) -> None: + import argparse + + logging.basicConfig(level=logging.INFO) + + parser = argparse.ArgumentParser() + parser.add_argument("--num-chunks", type=int, default=0) + parser.add_argument("--device", type=str, default="cpu") + parser.add_argument("--chunk-size", type=int, default=3_000_000) # 3MB + parser.add_argument("--total-size", type=int, default=12_000_000_000) # 12GB + args = parser.parse_args(argv) + + device = torch.device(args.device) + num_chunks: int = args.num_chunks + CHUNK_SIZE = args.chunk_size + TOTAL_SIZE = args.total_size + + transport = HTTPTransport(timedelta(seconds=60), num_chunks=num_chunks) + 
metadata = transport.metadata() + + logger.info(f"creating state_dict... {CHUNK_SIZE=} {TOTAL_SIZE=}") + + with _time("create state_dict"): + state_dict = {} + for i in range(0, TOTAL_SIZE, CHUNK_SIZE): + state_dict[f"chunk/{i}"] = torch.zeros( + CHUNK_SIZE // 4, dtype=torch.float32, device=device + ) + + logger.info(f"fetching from {metadata=} {device=} {num_chunks=} {len(state_dict)=}") + + transport.send_checkpoint( + dst_ranks=[0], step=1, state_dict=state_dict, timeout=timedelta(seconds=60) + ) + + with _time("fetching checkpoint"): + transport.recv_checkpoint( + src_rank=1, metadata=metadata, step=1, timeout=timedelta(seconds=60) + ) + + +if __name__ == "__main__": + main(sys.argv[1:]) diff --git a/torchft/checkpointing/http_transport_test.py b/torchft/checkpointing/http_transport_test.py index 26ac26f9..6c297730 100644 --- a/torchft/checkpointing/http_transport_test.py +++ b/torchft/checkpointing/http_transport_test.py @@ -4,22 +4,47 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-import threading import urllib.error from datetime import timedelta +from typing import Any, Dict from unittest import TestCase from unittest.mock import MagicMock -from torchft.checkpointing.http_transport import HTTPTransport, _timed_acquire - - -class TestCheckpointing(TestCase): - def test_checkpoint_server(self) -> None: - expected = {"state": "dict"} +import torch +from parameterized import parameterized + +from torchft.checkpointing.http_transport import HTTPTransport +from torchft.checkpointing.http_transport_bench import main as bench_main + + +class TestHTTPTransport(TestCase): + def assertStateDictEqual(self, a: Dict[str, object], b: Dict[str, object]) -> None: + for k, v1 in a.items(): + v2 = b[k] + if isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor): + torch.testing.assert_close(v1.cpu(), v2.cpu()) + else: + self.assertEqual(v1, v2) + + @parameterized.expand( + [ + ("no chunks", 0), + ("chunked", 3), + ] + ) + def test_checkpoint_server(self, name: str, num_chunks: int) -> None: + expected: Dict[str, object] = { + "state": "dict", + "tensor": torch.rand(5, 2), + "cuda": torch.rand( + 2, 3, device="cuda" if torch.cuda.is_available() else "cpu" + ), + } state_dict_fn = MagicMock() state_dict_fn.return_value = expected server = HTTPTransport( timeout=timedelta(seconds=10), + num_chunks=num_chunks, ) server.send_checkpoint( @@ -34,7 +59,7 @@ def test_checkpoint_server(self) -> None: out = server.recv_checkpoint( src_rank=0, metadata=metadata, step=1234, timeout=timedelta(seconds=10) ) - self.assertEqual(out, expected) + self.assertStateDictEqual(out, expected) # test timeout with self.assertRaisesRegex(urllib.error.URLError, r"urlopen error"): @@ -50,7 +75,9 @@ def test_checkpoint_server(self) -> None: timeout=timedelta(seconds=10), ) - with self.assertRaisesRegex(urllib.error.HTTPError, r"Error 400"): + with self.assertRaisesRegex( + urllib.error.HTTPError, r"Error 400.*serving 2345 but got step=1234" + ): server.recv_checkpoint( src_rank=0, 
metadata=metadata, step=1234, timeout=timedelta(seconds=10) ) @@ -60,18 +87,19 @@ def test_checkpoint_server(self) -> None: def test_checkpoint_server_locking(self) -> None: server = HTTPTransport( timeout=timedelta(seconds=10), + num_chunks=0, ) # server should start up in a disallowed state this will block incoming # requests until allow_checkpoint is called - self.assertTrue(server._checkpoint_lock.locked()) + self.assertTrue(server._checkpoint_lock.w_locked()) self.assertTrue(server._disallowed) self.assertEqual(server._step, -1) # allow requests server.allow_checkpoint(1) - self.assertFalse(server._checkpoint_lock.locked()) + self.assertFalse(server._checkpoint_lock.w_locked()) self.assertFalse(server._disallowed) self.assertEqual(server._step, 1) @@ -81,25 +109,17 @@ def test_checkpoint_server_locking(self) -> None: server.disallow_checkpoint() server.disallow_checkpoint() - self.assertTrue(server._checkpoint_lock.locked()) + self.assertTrue(server._checkpoint_lock.w_locked()) self.assertTrue(server._disallowed) server.shutdown() - def test_timed_acquire(self) -> None: - lock = threading.Lock() - - with _timed_acquire(lock, timedelta(seconds=10)): - self.assertTrue(lock.locked()) - - self.assertFalse(lock.locked()) - - lock.acquire() - - with self.assertRaisesRegex( - TimeoutError, r"timed out acquiring lock after 0.0" - ): - with _timed_acquire(lock, timedelta(seconds=0.0)): - pass - - self.assertTrue(lock.locked()) + def test_benchmark(self) -> None: + bench_main( + [ + "--chunk-size=10", + "--num-chunks=0", + "--total-size=100", + "--device=cpu", + ] + ) diff --git a/torchft/checkpointing/rwlock_test.py b/torchft/checkpointing/rwlock_test.py new file mode 100644 index 00000000..49af8f6f --- /dev/null +++ b/torchft/checkpointing/rwlock_test.py @@ -0,0 +1,52 @@ +import pytest + +from torchft.checkpointing._rwlock import RWLock + + +def test_w_locked() -> None: + lock = RWLock() + + with lock.w_lock(): + assert lock.w_locked() + assert not lock.w_locked() + + 
+def test_w_lock_timeout() -> None: + lock = RWLock(timeout=0.01) + + lock.r_acquire() + lock.r_acquire() + + with pytest.raises(TimeoutError): + lock.w_acquire() + + with pytest.raises(TimeoutError): + with lock.w_lock(): + pass + + lock.r_release() + with pytest.raises(TimeoutError): + lock.w_acquire() + + lock.r_release() + with lock.w_lock(): + pass + lock.w_acquire() + + +def test_r_lock_timeout() -> None: + lock = RWLock(timeout=0.01) + + lock.w_acquire() + + with pytest.raises(TimeoutError): + lock.r_acquire() + + with pytest.raises(TimeoutError): + with lock.r_lock(): + pass + + lock.w_release() + with lock.r_lock(): + pass + lock.r_acquire() diff --git a/torchft/manager.py b/torchft/manager.py index 5113f0ee..a3556897 100644 --- a/torchft/manager.py +++ b/torchft/manager.py @@ -162,6 +162,7 @@ def __init__( if checkpoint_transport is None: checkpoint_transport = HTTPTransport[Dict[str, T]]( timeout=timeout, + num_chunks=0, ) self._checkpoint_transport: CheckpointTransport[Dict[str, T]] = (