diff --git a/torchft/checkpointing.py b/torchft/checkpointing.py
index c3168b20..a225e3b1 100644
--- a/torchft/checkpointing.py
+++ b/torchft/checkpointing.py
@@ -20,11 +20,10 @@
 from contextlib import contextmanager
 from datetime import timedelta
 from http.server import BaseHTTPRequestHandler
-from typing import Generator, Generic, List, Optional, TypeVar
-
-import torch
+from typing import Generator, Generic, List, Optional, TypeVar, cast
 
 from torchft.http import _IPv6HTTPServer
+from torchft.serialization import streaming_load, streaming_save
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -161,7 +160,7 @@ def do_GET(self):
 
                         state_dict = ckpt_server._state_dict
 
-                        torch.save(state_dict, self.wfile)
+                        streaming_save(state_dict, self.wfile)
                 except Exception as e:
                     logger.exception(
                         f"Exception in checkpoint server when handling {self.path=}: {e}",
@@ -198,9 +197,8 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
             data = f.read()
 
         reader = io.BytesIO(data)
-        # We have to set weights_only to False as there are some non-tensor
-        # states like lr_scheduler.
-        return torch.load(reader, weights_only=False)
+        state_dict = streaming_load(reader)
+        return cast(T, state_dict)
 
     def address(self) -> str:
         """
diff --git a/torchft/checkpointing_test.py b/torchft/checkpointing_test.py
index 31658b43..cd904288 100644
--- a/torchft/checkpointing_test.py
+++ b/torchft/checkpointing_test.py
@@ -5,6 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import threading
+import unittest
 import urllib.error
 from datetime import timedelta
 from unittest import TestCase
@@ -103,3 +104,7 @@ def test_timed_acquire(self) -> None:
                 pass
 
         self.assertTrue(lock.locked())
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/torchft/serialization.py b/torchft/serialization.py
new file mode 100644
index 00000000..4ec9bc89
--- /dev/null
+++ b/torchft/serialization.py
@@ -0,0 +1,184 @@
+"""
+Streaming serialization
+=======================
+
+Save and load checkpoints through binary file-like objects.
+
+The stream layout is a pickled list of ``_Entry`` headers followed by the raw
+payload of every record, concatenated in header order.
+"""
+
+import pickle
+from dataclasses import dataclass
+from io import BufferedIOBase
+from typing import Dict, List, Tuple, Union
+
+import torch
+
+
+@dataclass
+class _Entry:
+    """Header describing one record in the payload section of the stream."""
+
+    # Record name as given to ``write_record``.
+    key: str
+    # Python type of the record as it was given to ``write_record``.
+    dtype: object
+    # True when the record is a ``torch.UntypedStorage``.
+    is_storage: bool
+    # Number of payload bytes that belong to this record.
+    length: int
+
+
+class _InMemoryStateDict:
+    """
+    Record store passed as the ``zip_file`` argument of
+    ``torch.serialization._save`` and ``torch.serialization._load``.
+    """
+
+    def __init__(self) -> None:
+        # key -> (record data, length in bytes)
+        self.records: Dict[str, Tuple[object, int]] = {}
+
+    def write_record(self, key: str, data: object, length: int) -> None:
+        """Store ``data`` under ``key`` together with its declared length."""
+        self.records[key] = (data, length)
+
+    def write_to(self, f: BufferedIOBase) -> None:
+        """
+        Write the pickled header list and then every record payload to ``f``.
+
+        Raises:
+            TypeError: if a record is not ``bytes``, ``str`` or
+                ``torch.UntypedStorage``; raised before anything is written.
+        """
+        entries: List[_Entry] = []
+        payloads: List[Union[bytes, torch.UntypedStorage]] = []
+        for key, (data, length) in self.records.items():
+            if isinstance(data, torch.UntypedStorage):
+                payload: Union[bytes, torch.UntypedStorage] = data
+            elif isinstance(data, bytes):
+                payload = data
+                length = len(payload)
+            elif isinstance(data, str):
+                # Encode before building the header: the encoded size can
+                # differ from the declared length for non-ASCII text, and a
+                # wrong length would misalign every later record on read.
+                payload = data.encode("utf-8")
+                length = len(payload)
+            else:
+                raise TypeError(f"unknown type: {type(data)}")
+
+            entries.append(
+                _Entry(
+                    key=key,
+                    is_storage=isinstance(data, torch.UntypedStorage),
+                    dtype=type(data),
+                    length=length,
+                )
+            )
+            payloads.append(payload)
+
+        pickle.dump(entries, f)
+
+        for payload in payloads:
+            if isinstance(payload, torch.UntypedStorage):
+                payload._write_file(f, False, False, 1)
+            else:
+                f.write(payload)
+
+    def read_from(self, f: BufferedIOBase) -> None:
+        """
+        Populate ``records`` from a stream produced by ``write_to``.
+
+        Raises:
+            EOFError: if ``f`` ends before a record's payload is complete.
+        """
+        # SECURITY: pickle.load can execute arbitrary code contained in the
+        # stream; only read from trusted sources.
+        entries = pickle.load(f)
+
+        for entry in entries:
+            data = f.read(entry.length)
+            if len(data) != entry.length:
+                raise EOFError(
+                    f"truncated record {entry.key!r}: expected "
+                    f"{entry.length} bytes, got {len(data)}"
+                )
+
+            if not entry.is_storage:
+                self.records[entry.key] = (data, entry.length)
+                continue
+
+            if entry.length == 0:
+                # torch.frombuffer raises on an empty buffer, so the storage
+                # of a zero-element tensor is created directly.
+                storage = torch.empty(0, dtype=torch.uint8).untyped_storage()
+            else:
+                storage = torch.frombuffer(
+                    data,
+                    dtype=torch.uint8,
+                ).untyped_storage()
+
+            self.records[entry.key] = (storage, entry.length)
+
+    def has_record(self, key: str) -> bool:
+        """Return whether a record named ``key`` is stored."""
+        return key in self.records
+
+    def get_record(self, key: str) -> object:
+        """Return the data stored under ``key``."""
+        return self.records[key][0]
+
+    def get_storage_from_record(
+        self, key: str, _length: int, _type: int
+    ) -> torch.Tensor:
+        """
+        Return the record ``key`` as a ``uint8`` tensor; ``_length`` and
+        ``_type`` are accepted but unused.
+        """
+        return torch.tensor(self.records[key][0], dtype=torch.uint8)
+
+    def serialization_id(self) -> str:
+        """Return the fixed identifier ``"torchft"``."""
+        return "torchft"
+
+
+def streaming_save(obj: object, f: BufferedIOBase) -> None:
+    """
+    Serialize ``obj`` with ``torch.serialization._save`` and write it to ``f``.
+
+    Args:
+        obj: the object to serialize
+        f: the binary file-like object to write to
+    """
+    out = _InMemoryStateDict()
+    torch.serialization._save(
+        obj,
+        zip_file=out,
+        pickle_module=pickle,
+        pickle_protocol=2,
+        _disable_byteorder_record=False,
+    )
+    out.write_to(f)
+
+
+def streaming_load(f: BufferedIOBase) -> object:
+    """
+    Load an object written by ``streaming_save`` from ``f``.
+
+    Unpickles data from ``f``, so ``f`` must come from a trusted source.
+
+    Args:
+        f: the binary file-like object to read from
+
+    Returns:
+        the deserialized object
+    """
+    out = _InMemoryStateDict()
+    out.read_from(f)
+    return torch.serialization._load(
+        zip_file=out,
+        map_location=None,
+        pickle_module=pickle,
+    )
diff --git a/torchft/serialization_test.py b/torchft/serialization_test.py
new file mode 100644
index 00000000..440ecad6
--- /dev/null
+++ b/torchft/serialization_test.py
@@ -0,0 +1,120 @@
+import unittest
+from io import BytesIO
+from typing import cast
+from unittest import TestCase
+
+import torch
+import torch.distributed as dist
+from torch.distributed.tensor import DeviceMesh, DTensor, distribute_tensor
+
+from torchft.serialization import streaming_load, streaming_save
+
+
+class MyClass:
+    """Plain Python object used to exercise non-tensor round trips."""
+
+    def __init__(self, a: int) -> None:
+        self.a = a
+
+    def __eq__(self, other: object) -> bool:
+        # Defer to the other operand for foreign types instead of raising
+        # AttributeError on a missing ``a`` attribute.
+        if not isinstance(other, MyClass):
+            return NotImplemented
+        return self.a == other.a
+
+
+def _roundtrip(obj: object) -> object:
+    """Save ``obj`` into an in-memory buffer and load it back."""
+    file = BytesIO()
+    streaming_save(obj, file)
+    file.seek(0)
+    return streaming_load(file)
+
+
+class TestCheckpointingSerialization(TestCase):
+    def test_scalar_tensor(self) -> None:
+        tensor = torch.tensor(42, dtype=torch.int32)
+        state_dict = {"scalar": tensor}
+
+        result = _roundtrip(state_dict)
+        torch.testing.assert_close(result, state_dict)
+
+    def test_empty_tensor(self) -> None:
+        # A zero-element tensor is backed by a zero-length storage.
+        state_dict = {"empty": torch.empty(0, dtype=torch.float32)}
+
+        result = _roundtrip(state_dict)
+        torch.testing.assert_close(result, state_dict)
+
+    def test_strided_tensor(self) -> None:
+        base_tensor = torch.arange(16, dtype=torch.float32).reshape(4, 4)
+        strided_tensor = base_tensor[::2, ::2]
+        state_dict = {"strided": strided_tensor}
+
+        result = _roundtrip(state_dict)
+        torch.testing.assert_close(result, state_dict)
+
+    def test_tensor_with_offset(self) -> None:
+        base_tensor = torch.arange(10, dtype=torch.float64)
+        offset_tensor = base_tensor[2:]
+        state_dict = {"offset": offset_tensor}
+
+        result = _roundtrip(state_dict)
+        torch.testing.assert_close(result, state_dict)
+
+    def test_nested_tensors(self) -> None:
+        tensor1 = torch.tensor([1, 2, 3], dtype=torch.int32)
+        tensor2 = torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.float64)
+        state_dict = {"nested": {"tensor1": tensor1, "tensor2": tensor2}}
+
+        result = _roundtrip(state_dict)
+        torch.testing.assert_close(result, state_dict)
+
+    def test_various_data_types(self) -> None:
+        tensor_float32 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
+        tensor_int16 = torch.tensor([1, 2, 3], dtype=torch.int16)
+        tensor_bool = torch.tensor([True, False, True], dtype=torch.bool)
+        state_dict = {
+            "float32": tensor_float32,
+            "int16": tensor_int16,
+            "bool": tensor_bool,
+        }
+
+        result = _roundtrip(state_dict)
+        torch.testing.assert_close(result, state_dict)
+
+    def test_truncated_stream(self) -> None:
+        file = BytesIO()
+        streaming_save({"tensor": torch.arange(8)}, file)
+        truncated = BytesIO(file.getvalue()[:-1])
+
+        with self.assertRaises(EOFError):
+            streaming_load(truncated)
+
+    def test_dtensor(self) -> None:
+        dist.init_process_group(
+            backend="gloo", rank=0, world_size=1, store=dist.HashStore()
+        )
+        # Tear the process group down even when an assertion fails.
+        self.addCleanup(dist.destroy_process_group)
+
+        device_mesh = DeviceMesh("cpu", 1)
+        # Created on CPU to match the mesh, so CUDA is not required.
+        tensor = torch.randn(4, 4)
+        dtensor = distribute_tensor(tensor, device_mesh, [])
+
+        result = cast(DTensor, _roundtrip(dtensor))
+        torch.testing.assert_close(result.to_local(), dtensor.to_local())
+
+    def test_python_object(self) -> None:
+        state_dict = {
+            "obj": MyClass(42),
+        }
+
+        result = _roundtrip(state_dict)
+        self.assertEqual(result, state_dict)
+
+
+if __name__ == "__main__":
+    unittest.main()