diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py index 48a5d516..228f7348 100644 --- a/torchft/checkpointing.py +++ b/torchft/checkpointing.py @@ -17,9 +17,10 @@ import threading import urllib.request from abc import ABC, abstractmethod +from contextlib import contextmanager from datetime import timedelta from http.server import BaseHTTPRequestHandler -from typing import Generic, List, Optional, TypeVar +from typing import Generator, Generic, List, Optional, TypeVar import torch @@ -87,6 +88,25 @@ def shutdown(self, wait: bool = True) -> None: """ +@contextmanager +def _timed_acquire( + lock: threading.Lock, timeout: timedelta +) -> Generator[None, None, None]: + """ + Acquire a lock with a timeout. + + Args: + lock: the lock to acquire + timeout: the timeout to acquire the lock + """ + if not lock.acquire(timeout=timeout.total_seconds()): + raise TimeoutError(f"timed out acquiring lock after {timeout}") + try: + yield + finally: + lock.release() + + class CheckpointServer(CheckpointTransport[T]): """ This is an HTTP server that can be used to transfer checkpoints @@ -106,6 +126,10 @@ def __init__(self, timeout: timedelta) -> None: self._timeout = timeout self._state_dict: Optional[T] = None + # We don't allow checkpoints until the first send_checkpoint to avoid + # serving the default step=-1 invalid checkpoint. + self.disallow_checkpoint() + ckpt_server = self class RequestHandler(BaseHTTPRequestHandler): @@ -117,7 +141,9 @@ def do_GET(self): # validate socket timeout is actually set assert self.connection.gettimeout() == self.timeout - with ckpt_server._checkpoint_lock: + with _timed_acquire( + ckpt_server._checkpoint_lock, ckpt_server._timeout + ): step = ckpt_server._step if self.path != f"/checkpoint/{step}": diff --git a/torchft/checkpointing_test.py b/torchft/checkpointing_test.py index e2a05e12..31658b43 100644 --- a/torchft/checkpointing_test.py +++ b/torchft/checkpointing_test.py @@ -4,12 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import threading import urllib.error from datetime import timedelta from unittest import TestCase from unittest.mock import MagicMock -from torchft.checkpointing import CheckpointServer +from torchft.checkpointing import CheckpointServer, _timed_acquire class TestCheckpointing(TestCase): @@ -55,3 +56,50 @@ def test_checkpoint_server(self) -> None: ) server.shutdown() + + def test_checkpoint_server_locking(self) -> None: + server = CheckpointServer( + timeout=timedelta(seconds=10), + ) + + # server should start up in a disallowed state this will block incoming + # requests until allow_checkpoint is called + self.assertTrue(server._checkpoint_lock.locked()) + self.assertTrue(server._disallowed) + self.assertEqual(server._step, -1) + + # allow requests + server.allow_checkpoint(1) + + self.assertFalse(server._checkpoint_lock.locked()) + self.assertFalse(server._disallowed) + self.assertEqual(server._step, 1) + + # duplicate allow/disallow is fine + server.allow_checkpoint(2) + self.assertEqual(server._step, 2) + + server.disallow_checkpoint() + server.disallow_checkpoint() + self.assertTrue(server._checkpoint_lock.locked()) + self.assertTrue(server._disallowed) + + server.shutdown() + + def test_timed_acquire(self) -> None: + lock = threading.Lock() + + with _timed_acquire(lock, timedelta(seconds=10)): + self.assertTrue(lock.locked()) + + self.assertFalse(lock.locked()) + + lock.acquire() + + with self.assertRaisesRegex( + TimeoutError, r"timed out acquiring lock after 0.0" + ): + with _timed_acquire(lock, timedelta(seconds=0.0)): + pass + + self.assertTrue(lock.locked())