12 changes: 5 additions & 7 deletions torchft/checkpointing.py
@@ -20,11 +20,10 @@
from contextlib import contextmanager
from datetime import timedelta
from http.server import BaseHTTPRequestHandler
-from typing import Generator, Generic, List, Optional, TypeVar
-
-import torch
+from typing import Generator, Generic, List, Optional, TypeVar, cast

from torchft.http import _IPv6HTTPServer
+from torchft.serialization import streaming_load, streaming_save

logger: logging.Logger = logging.getLogger(__name__)

@@ -161,7 +160,7 @@ def do_GET(self):

state_dict = ckpt_server._state_dict

-torch.save(state_dict, self.wfile)
+streaming_save(state_dict, self.wfile)
except Exception as e:
logger.exception(
f"Exception in checkpoint server when handling {self.path=}: {e}",
@@ -198,9 +197,8 @@ def load_from_address(cls, address: str, timeout: timedelta) -> T:
data = f.read()

reader = io.BytesIO(data)
-# We have to set weights_only to False as there are some non-tensor
-# states like lr_scheduler.
-return torch.load(reader, weights_only=False)
+state_dict = streaming_load(reader)
+return cast(T, state_dict)

def address(self) -> str:
"""
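For orientation, a minimal sketch of the consumer side of this change. The class name CheckpointServer and the placeholder address are assumptions (they sit outside these hunks); load_from_address and the address() method are the ones shown above.

from datetime import timedelta

# Assumption: the class owning do_GET/load_from_address above is
# torchft.checkpointing.CheckpointServer.
from torchft.checkpointing import CheckpointServer

# Placeholder URL; in practice this comes from the serving rank's address() method.
address = "http://peer-host:12345/checkpoint"

# Fetch the peer's checkpoint over HTTP; the response body is now deserialized
# with streaming_load instead of torch.load.
state_dict = CheckpointServer.load_from_address(address, timeout=timedelta(seconds=60))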
5 changes: 5 additions & 0 deletions torchft/checkpointing_test.py
@@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import threading
+import unittest
import urllib.error
from datetime import timedelta
from unittest import TestCase
@@ -103,3 +104,7 @@ def test_timed_acquire(self) -> None:
pass

self.assertTrue(lock.locked())


+if __name__ == "__main__":
+    unittest.main()
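The new __main__ guard makes the module directly runnable, e.g. `python torchft/checkpointing_test.py`, in addition to the usual unittest/pytest discovery.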
100 changes: 100 additions & 0 deletions torchft/serialization.py
@@ -0,0 +1,100 @@
import pickle
from dataclasses import dataclass
from io import BufferedIOBase
from typing import Dict, Tuple

import torch


@dataclass
class _Entry:
    key: str
    dtype: object
    is_storage: bool
    length: int


class _InMemoryStateDict:
    """In-memory stand-in for the record reader/writer that
    torch.serialization._save/_load expect, so checkpoints can be streamed
    to/from a file-like object instead of a zip archive."""

    def __init__(self) -> None:
        self.records: Dict[str, Tuple[object, int]] = {}

    def write_record(self, key: str, data: object, length: int) -> None:
        self.records[key] = (data, length)

    def write_to(self, f: BufferedIOBase) -> None:
        # Write a pickled index of all records, then the raw payload of each
        # record in the same order.
        entries = []
        for key, (data, length) in self.records.items():
            entries.append(
                _Entry(
                    key=key,
                    is_storage=isinstance(data, torch.UntypedStorage),
                    dtype=type(data),
                    length=length,
                )
            )

        pickle.dump(entries, f)

        for key, (data, length) in self.records.items():
            if isinstance(data, bytes):
                f.write(data)
            elif isinstance(data, str):
                f.write(data.encode("utf-8"))
            elif isinstance(data, torch.UntypedStorage):
                data._write_file(f, False, False, 1)
            else:
                raise TypeError(f"unknown type: {type(data)}")

    def read_from(self, f: BufferedIOBase) -> None:
        # Read the pickled index, then exactly entry.length bytes per record,
        # re-wrapping storage payloads as untyped storages.
        entries = pickle.load(f)

        for entry in entries:
            data = f.read(entry.length)
            if entry.is_storage:
                storage = torch.frombuffer(
                    data,
                    dtype=torch.uint8,
                ).untyped_storage()

                self.records[entry.key] = (
                    storage,
                    entry.length,
                )
            else:
                self.records[entry.key] = (data, entry.length)

    def has_record(self, key: str) -> bool:
        return key in self.records

    def get_record(self, key: str) -> object:
        return self.records[key][0]

    def get_storage_from_record(
        self, key: str, _length: int, _type: int
    ) -> torch.Tensor:
        # _load expects the record contents back as a uint8 tensor.
        return torch.tensor(self.records[key][0], dtype=torch.uint8)

    def serialization_id(self) -> str:
        return "torchft"


def streaming_save(obj: object, f: BufferedIOBase) -> None:
    out = _InMemoryStateDict()
    torch.serialization._save(
        obj,
        zip_file=out,
        pickle_module=pickle,
        pickle_protocol=2,
        _disable_byteorder_record=False,
    )
    out.write_to(f)


def streaming_load(f: BufferedIOBase) -> object:
    out = _InMemoryStateDict()
    out.read_from(f)
    return torch.serialization._load(
        zip_file=out,
        map_location=None,
        pickle_module=pickle,
    )
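A quick round-trip sketch of the two entry points above, mirroring what the tests below exercise; the state dict contents are illustrative.

import io

import torch

from torchft.serialization import streaming_load, streaming_save

# Illustrative state mixing a tensor with plain Python values, similar to the
# optimizer/lr_scheduler state a checkpoint carries.
state = {"step": 3, "weights": torch.arange(4, dtype=torch.float32)}

buf = io.BytesIO()
streaming_save(state, buf)  # writes a pickled entry index followed by raw record bytes
buf.seek(0)

restored = streaming_load(buf)
torch.testing.assert_close(restored["weights"], state["weights"])
assert restored["step"] == state["step"]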
106 changes: 106 additions & 0 deletions torchft/serialization_test.py
@@ -0,0 +1,106 @@
from io import BytesIO
from typing import cast
from unittest import TestCase

import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, DTensor, distribute_tensor

from torchft.serialization import streaming_load, streaming_save


class MyClass:
    def __init__(self, a: int) -> None:
        self.a = a

    def __eq__(self, other: "MyClass") -> bool:
        return self.a == other.a


class TestCheckpointingSerialization(TestCase):
    def test_scalar_tensor(self) -> None:
        tensor = torch.tensor(42, dtype=torch.int32)
        state_dict = {"scalar": tensor}
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_strided_tensor(self) -> None:
        base_tensor = torch.arange(16, dtype=torch.float32).reshape(4, 4)
        strided_tensor = base_tensor[::2, ::2]
        state_dict = {"strided": strided_tensor}
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_tensor_with_offset(self) -> None:
        base_tensor = torch.arange(10, dtype=torch.float64)
        offset_tensor = base_tensor[2:]
        state_dict = {"offset": offset_tensor}
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_nested_tensors(self) -> None:
        tensor1 = torch.tensor([1, 2, 3], dtype=torch.int32)
        tensor2 = torch.tensor([[1.5, 2.5], [3.5, 4.5]], dtype=torch.float64)
        state_dict = {"nested": {"tensor1": tensor1, "tensor2": tensor2}}
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_various_data_types(self) -> None:
        tensor_float32 = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
        tensor_int16 = torch.tensor([1, 2, 3], dtype=torch.int16)
        tensor_bool = torch.tensor([True, False, True], dtype=torch.bool)
        state_dict = {
            "float32": tensor_float32,
            "int16": tensor_int16,
            "bool": tensor_bool,
        }
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        torch.testing.assert_close(result, state_dict)

    def test_dtensor(self) -> None:
        dist.init_process_group(
            backend="gloo", rank=0, world_size=1, store=dist.HashStore()
        )

        device_mesh = DeviceMesh("cpu", 1)
        tensor = torch.randn(4, 4, device="cuda")
        dtensor = distribute_tensor(tensor, device_mesh, [])
        state_dict = dtensor
        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = cast(DTensor, streaming_load(file))
        torch.testing.assert_close(result.to_local(), state_dict.to_local())

    def test_python_object(self) -> None:
        state_dict = {
            "obj": MyClass(42),
        }

        file = BytesIO()
        streaming_save(state_dict, file)
        file.seek(0)

        result = streaming_load(file)
        self.assertEqual(result, state_dict)