Skip to content

Commit

Permalink
[feat] add sha1_store get function (#1027)
Browse files Browse the repository at this point in the history
Co-authored-by: Min Xu <min.xu.public@gmail.com>
  • Loading branch information
min-xu-ai and flying-x committed Jul 14, 2022
1 parent 68af57d commit 073618d
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
10 changes: 9 additions & 1 deletion fairscale/experimental/wgit/sha1_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,15 @@ def get(self, sha1: str) -> Union[Tensor, OrderedDict]:
(Tensor or OrderedDict):
In-memory object.
"""
raise NotImplementedError()
path = self._sha1_to_dir(sha1).joinpath(sha1)
if not path.exists():
# This is potentially valid case for the caller, we need to inform the
# the caller about it.
raise ValueError(f"Try to get SHA1 {sha1} but it is not found")
# Directly return the object after loading it. This could be throw an
# exception but that indicates some internal error since we should never
# have stored the (invalid) object in the first place with the add() API.
return torch.load(path)

def delete(self, sha1: str) -> None:
"""Delete a SHA1
Expand Down
28 changes: 28 additions & 0 deletions tests/experimental/wgit/test_sha1_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import torch
from torch import nn

from fair_dev.testing.testing import objects_are_equal
from fairscale.experimental.wgit.sha1_store import SHA1_Store
from fairscale.internal import torch_version

Expand Down Expand Up @@ -111,3 +112,30 @@ def test_sha1_add_tensor(sha1_store):
# torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
key = "71df4069a03a766eacf9f03eea50968e87eae9f8"
assert key in json_dict.keys() and json_dict[key] == 1, json_dict


def test_sha1_get(sha1_store):
os.chdir(PARENT_DIR)

# Add a file, a state dict and a tensor.
file = "test_get.pt"
torch.save(nn.Linear(100, 100).state_dict(), file)
state_dict = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 20)).state_dict()
tensor = torch.ones(20, 30)

# Check that we can get them back.
file_sha1 = sha1_store.add(file)
sd = sha1_store.get(file_sha1)
assert objects_are_equal(sd, torch.load(file))

sd_sha1 = sha1_store.add(state_dict)
sd = sha1_store.get(sd_sha1)
assert objects_are_equal(sd, state_dict)

tensor_sha1 = sha1_store.add(tensor)
tensor_got = sha1_store.get(tensor_sha1)
assert objects_are_equal(tensor_got, tensor)

# Make sure invalid sha1 cause exceptions.
with pytest.raises(ValueError):
sha1_store.get(tensor_sha1[:-1])

0 comments on commit 073618d

Please sign in to comment.