Skip to content

Commit

Permalink
[feat] add sha1_store delete function (#1028)
Browse files Browse the repository at this point in the history
Co-authored-by: Min Xu <min.xu.public@gmail.com>
  • Loading branch information
min-xu-ai and flying-x committed Jul 14, 2022
1 parent 073618d commit c75d189
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 9 deletions.
60 changes: 55 additions & 5 deletions fairscale/experimental/wgit/sha1_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,20 @@
# for backward compatibility reasons.
SHA1_STORE_DIR_NAME = "sha1_store"

# Const string keys for json file. Do not change for backward compatibilities.
RF_KEY = "ref_count"


def _get_json_entry(d: Dict[str, Any]) -> Dict[str, Any]:
"""Get a dict from a json entry.
This fills in any missing entries in case we load an older version
json file from the disk.
"""
if RF_KEY not in d.keys():
d[RF_KEY] = 0
return d


class SHA1_Store:
"""
Expand Down Expand Up @@ -181,6 +195,9 @@ def get(self, sha1: str) -> Union[Tensor, OrderedDict]:
Returns:
(Tensor or OrderedDict):
In-memory object.
Throws:
ValueError if sha1 is not found.
"""
path = self._sha1_to_dir(sha1).joinpath(sha1)
if not path.exists():
Expand All @@ -190,6 +207,9 @@ def get(self, sha1: str) -> Union[Tensor, OrderedDict]:
# Directly return the object after loading it. This could be throw an
# exception but that indicates some internal error since we should never
# have stored the (invalid) object in the first place with the add() API.
#
# TODO (Min): we could also keep a stats in the meta data on how many
# times the object is read. Will add if that's needed.
return torch.load(path)

def delete(self, sha1: str) -> None:
    """Delete an object from the store, decrementing its reference count.

    The object's on-disk file is only removed once its ref count drops to
    zero; until then, a delete just decrements the count in the json file.

    Args:
        sha1 (str):
            SHA1 of the object to delete.
    Throws:
        ValueError if sha1 is not found.
    """
    obj_path = self._sha1_to_dir(sha1).joinpath(sha1)
    if not obj_path.exists():
        # This is potentially a valid case for the caller, we need to inform the
        # the caller about it.
        raise ValueError(f"Try to delete SHA1 {sha1} but it is not found")

    self._load_json_dict()

    assert sha1 in self._json_dict.keys(), "internal error: sha1 not found in json"
    entry = _get_json_entry(self._json_dict[sha1])

    assert entry[RF_KEY] > 0, f"ref count {entry[RF_KEY]} should be positive"
    entry[RF_KEY] -= 1

    if entry[RF_KEY] == 0:
        # Last reference is gone: remove the object file and drop its
        # bookkeeping entry. We may leave behind an empty dir, which is OK.
        obj_path.unlink()
        del self._json_dict[sha1]
    else:
        # Still referenced elsewhere; persist the decremented count.
        self._json_dict[sha1] = entry

    self._store_json_dict()

def _get_sha1_hash(self, file_path: Union[str, Path]) -> str:
"""Return the sha1 hash of a file
Expand Down Expand Up @@ -257,12 +303,16 @@ def _add_ref(self, current_sha1_hash: str, inc: bool) -> int:

# Init the entry if needed.
if current_sha1_hash not in self._json_dict:
self._json_dict[current_sha1_hash] = 0
entry = {}
else:
entry = self._json_dict[current_sha1_hash]
entry = _get_json_entry(entry)

# Update the ref count.
self._json_dict[current_sha1_hash] += 1 if inc else -1
assert self._json_dict[current_sha1_hash] >= 0, "negative ref count"
entry[RF_KEY] += 1 if inc else -1
assert entry[RF_KEY] >= 0, "negative ref count"

self._json_dict[current_sha1_hash] = entry
self._store_json_dict()

return self._json_dict[current_sha1_hash]
return entry[RF_KEY]
32 changes: 28 additions & 4 deletions tests/experimental/wgit/test_sha1_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,9 @@ def test_sha1_add_file(sha1_store):
if torch_version() >= (1, 9, 0):
# torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
key = "da3e19590de8f77fcf7a09c888c526b0149863a0"
assert key in json_dict.keys() and json_dict[key] == 2, json_dict
assert key in json_dict.keys() and json_dict[key]["ref_count"] == 2, json_dict
del json_dict["created_on"]
assert sorted(json_dict.values()) == [1, 1, 1, 1, 1, 2], json_dict
assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 1, 1, 2], json_dict


def test_sha1_add_state_dict(sha1_store):
Expand All @@ -100,7 +100,7 @@ def test_sha1_add_state_dict(sha1_store):
sha1_store._load_json_dict()
json_dict = sha1_store._json_dict
del json_dict["created_on"]
assert sorted(json_dict.values()) == [1, 1, 1, 2, 2, 2], json_dict
assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 2, 2, 2], json_dict


def test_sha1_add_tensor(sha1_store):
Expand All @@ -111,10 +111,11 @@ def test_sha1_add_tensor(sha1_store):
if torch_version() >= (1, 9, 0):
# torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
key = "71df4069a03a766eacf9f03eea50968e87eae9f8"
assert key in json_dict.keys() and json_dict[key] == 1, json_dict
assert key in json_dict.keys() and json_dict[key]["ref_count"] == 1, json_dict


def test_sha1_get(sha1_store):
"""Testing the get() API: normal and exception cases."""
os.chdir(PARENT_DIR)

# Add a file, a state dict and a tensor.
Expand All @@ -139,3 +140,26 @@ def test_sha1_get(sha1_store):
# Make sure invalid sha1 cause exceptions.
with pytest.raises(ValueError):
sha1_store.get(tensor_sha1[:-1])


def test_sha1_delete(sha1_store):
    """Testing the delete() API: with ref counting behavior."""
    os.chdir(PARENT_DIR)

    # A single add is balanced by a single delete; one more delete must raise.
    sha1 = sha1_store.add(torch.ones(30, 50))
    sha1_store.delete(sha1)
    with pytest.raises(ValueError):
        sha1_store.delete(sha1)

    # Adding the same object repeatedly only bumps the ref count, so the
    # matching number of deletes must succeed before the next one raises.
    state_dict = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 20)).state_dict()
    sha1 = sha1_store.add(state_dict)
    for _ in range(3):
        new_sha1 = sha1_store.add(state_dict)
        assert sha1 == new_sha1, f"{sha1} vs. {new_sha1}"
    for _ in range(4):
        sha1_store.delete(sha1)
    with pytest.raises(ValueError):
        sha1_store.delete(sha1)

0 comments on commit c75d189

Please sign in to comment.