[feat] add compression and tests to sha1 store (#1032)
Co-authored-by: Min Xu <min.xu.public@gmail.com>
min-xu-ai and flying-x committed Jul 18, 2022
1 parent c8327e1 commit d0ad08c
Showing 4 changed files with 114 additions and 28 deletions.
15 changes: 15 additions & 0 deletions fairscale/experimental/wgit/__init__.py
@@ -3,8 +3,23 @@
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import sys
from typing import List

# Check for user requirements before we import our code.
try:
import pygit2
except ImportError:
print("Error: please pip install pygit2 module to use wgit")
sys.exit(1)

try:
import pgzip
except ImportError:
print("Error: please pip install pgzip module to use wgit")
sys.exit(1)


from .repo import Repo
from .signal_sparsity import Algo, SignalSparsity
from .version import __version_tuple__
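As an aside, the fail-fast guard above generalizes naturally if more optional dependencies are added later. A minimal sketch (illustrative only, not part of this commit), assuming importlib is used to probe for the modules:

# Illustrative sketch: probe each optional dependency and fail fast,
# mirroring the message printed by the guard above.
import importlib.util
import sys

for mod in ("pygit2", "pgzip"):
    if importlib.util.find_spec(mod) is None:
        print(f"Error: please pip install {mod} module to use wgit")
        sys.exit(1)
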
81 changes: 72 additions & 9 deletions fairscale/experimental/wgit/sha1_store.py
@@ -12,8 +12,9 @@
import sys
import tempfile
import time
from typing import Any, Dict, Union, cast
from typing import Any, Dict, Optional, Union, cast

import pgzip
import torch
from torch import Tensor

@@ -25,6 +26,7 @@

# Const string keys for json file. Do not change for backward compatibilities.
RF_KEY = "ref_count"
COMP_KEY = "compressed"


def _get_json_entry(d: Dict[str, Any]) -> Dict[str, Any]:
@@ -38,6 +40,28 @@ def _get_json_entry(d: Dict[str, Any]) -> Dict[str, Any]:
return d


def _copy_compressed(src: Path, dest: Path, thread: Optional[int], blocksize: int) -> None:
"""Helper to copy a file and compress it at the same time."""
with open(str(src), "rb") as srcf:
with pgzip.open(str(dest), "wb", compresslevel=5, thread=thread, blocksize=blocksize) as destf:
while True:
buf = srcf.read(blocksize)
if len(buf) == 0:
break
destf.write(buf)


def _copy_uncompressed(src: Path, dest: Path, thread: Optional[int], blocksize: int) -> None:
"""Helper to copy a file and uncompress it at the same time."""
with open(str(dest), "wb") as destf:
with pgzip.open(str(src), "rb", thread=thread, blocksize=blocksize) as srcf:
while True:
buf = srcf.read(blocksize)
if len(buf) == 0:
break
destf.write(buf)


class SHA1_Store:
"""
This class represents a SHA1 checksum based storage dir for state_dict
@@ -61,6 +85,12 @@ class SHA1_Store:
to delete in a version tracking graph. The lesson here is that content
addressability and dependency graphs do not mix well.

We support multicore compression for the data to be stored, on a per-object basis.
The ``torch.save()`` API uses the zip format to store the data, but it appears to
be uncompressed. Even if it could be made to compress, it would likely be
single-threaded compression. Therefore, we use pgzip on top of it to do parallel
compression/decompression and make use of all the cores.

Args:
parent_path (Path):
The parent path in which a SHA1_Store will be created.
@@ -75,16 +105,29 @@ class SHA1_Store:
sha1_buf_size (int):
Buffer size used for checksumming. Default: 100MB.
tmp_dir (str):
Dir for temporary files if input is an in-memory object.
Dir for temporary files if input is an in-memory object or output data needs
to be decompressed first.
pgzip_threads (int, optional):
Number of threads (cores) used in compression. Default: None to use all cores.
pgzip_block_size (int):
Per-thread block size for compression. Default: 10MB.
"""

def __init__(
self, parent_path: Path, init: bool = False, sha1_buf_size: int = 100 * 1024 * 1024, tmp_dir: str = ""
self,
parent_path: Path,
init: bool = False,
sha1_buf_size: int = 100 * 1024 * 1024,
tmp_dir: str = "",
pgzip_threads: Optional[int] = None,
pgzip_block_size: int = 10 * 1024 * 1024,
) -> None:
"""Create or wrap (if already exists) a sha1_store."""
self._path = parent_path.joinpath(SHA1_STORE_DIR_NAME)
self._ref_file_path = self._path.joinpath("ref_count.json")
self._sha1_buf_size = sha1_buf_size
self._pgzip_threads = pgzip_threads
self._pgzip_block_size = pgzip_block_size
self._json_dict: Dict[str, Any] = {"created_on": time.ctime()}

# Initialize the sha1_store if not exist and init==True.
@@ -121,7 +164,7 @@ def _store_json_dict(self) -> None:
with open(self._ref_file_path, "w", encoding="utf-8") as f:
json.dump(self._json_dict, f, ensure_ascii=False, indent=4)

def add(self, file_or_obj: Union[Path, Tensor, OrderedDict]) -> str:
def add(self, file_or_obj: Union[Path, Tensor, OrderedDict], compress: bool = False) -> str:
"""
Adds a file/object to the internal sha1_store and the sha1 references
accordingly.
@@ -130,6 +173,9 @@ def add(self, file_or_obj: Union[Path, Tensor, OrderedDict]) -> str:
in <file_or_obj> is moved within the sha1_store and the reference file is updated.
If the input is an object, it will be stored in the self._tmp_dir and then moved.
If compress is True, the stored file is also compressed, which is useful for tensors
with a lot of zeros.
Args:
file_or_obj (str or tensor or OrderedDict):
Path to the file to be added to the sha1_store or an in-memory object
@@ -155,7 +201,7 @@ def add(self, file_or_obj: Union[Path, Tensor, OrderedDict]) -> str:
sha1_hash = self._get_sha1_hash(file_path)

# Add reference.
ref_count = self._add_ref(sha1_hash, True)
ref_count = self._add_ref(sha1_hash, True, compress)

if ref_count == 1:
# First time adding
@@ -172,12 +218,15 @@ def add(self, file_or_obj: Union[Path, Tensor, OrderedDict]) -> str:
# Transfer the file to the internal sha1_store
repo_fpath = repo_fdir.joinpath(sha1_hash)
try:
shutil.copy2(file_path, repo_fpath)
if compress:
_copy_compressed(file_path, repo_fpath, self._pgzip_threads, self._pgzip_block_size)
else:
shutil.copy2(file_path, repo_fpath)
except BaseException as error:
# Something went wrong, perhaps out of space, or race condition due to lack of locking.
# TODO (Min): properly handle the error and recover when we learn more here.
sys.stderr.write(f"An exception occurred: {repr(error)}\n")
ref_count = self._add_ref(sha1_hash, False)
ref_count = self._add_ref(sha1_hash, False, compress)

# Clean up if needed.
if remove_tmp:
@@ -210,7 +259,18 @@ def get(self, sha1: str) -> Union[Tensor, OrderedDict]:
#
# TODO (Min): we could also keep a stats in the meta data on how many
# times the object is read. Will add if that's needed.
return torch.load(path)
self._load_json_dict()
if self._json_dict[sha1][COMP_KEY]:
# Compressed. Because pgzip doesn't support tell() yet, we need to
# uncompress into a temp file and return it.
tmp = self._get_tmp_file_path()
_copy_uncompressed(path, tmp, self._pgzip_threads, self._pgzip_block_size)
obj = torch.load(tmp)
tmp.unlink()
return obj
else:
# Uncompressed.
return torch.load(path)

def delete(self, sha1: str) -> None:
"""Delete a SHA1
@@ -282,7 +342,7 @@ def _sha1_to_dir(self, sha1: str) -> Path:
part1, part2 = sha1[:2], sha1[2:4]
return self._path.joinpath(part1, part2)

def _add_ref(self, current_sha1_hash: str, inc: bool) -> int:
def _add_ref(self, current_sha1_hash: str, inc: bool, compressed: bool) -> int:
"""
Update the reference count.
@@ -312,6 +372,9 @@ def _add_ref(self, current_sha1_hash: str, inc: bool) -> int:
entry[RF_KEY] += 1 if inc else -1
assert entry[RF_KEY] >= 0, "negative ref count"

# Update compressed flag.
entry[COMP_KEY] = compressed

self._json_dict[current_sha1_hash] = entry
self._store_json_dict()

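To make the new compression path concrete, here is a minimal usage sketch (illustrative only, not part of this commit). It assumes a writable parent directory and that pygit2 and pgzip are installed; the constructor arguments, add(..., compress=True), get(), and delete() follow the API shown in the diff above:

from pathlib import Path

import torch

from fairscale.experimental.wgit.sha1_store import SHA1_Store

# Hypothetical location for the demo store.
parent = Path("/tmp/sha1_store_demo")
parent.mkdir(parents=True, exist_ok=True)

# Create the store; pgzip_threads=None (the default) would use all cores.
store = SHA1_Store(parent, init=True, pgzip_threads=4)

# Zero-heavy tensors are the motivating case for compress=True.
sparse_tensor = torch.zeros(1000, 1000)
sha1 = store.add(sparse_tensor, compress=True)

# get() decompresses into a temp file and torch.load()s it back.
restored = store.get(sha1)
assert torch.equal(restored, sparse_tensor)

# One add allows one delete; a second delete would raise ValueError.
store.delete(sha1)
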
5 changes: 4 additions & 1 deletion requirements-dev.txt
@@ -34,5 +34,8 @@ numpy == 1.22.0
# For layerwise gradient scaler
sklearn >= 0.0

# For weigit
# For weigit. These are actually user requirements, not developer requirements.
# However, due to the experimental nature of weigit, we don't expose them to the
# general users of fairscale yet. We check for them in weigit's init code.
pygit2==1.9.2
pgzip==0.3.1
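
The pinned pgzip package provides the parallel gzip API used by the new _copy_compressed/_copy_uncompressed helpers. A standalone round-trip sketch (illustrative only; the file names are hypothetical, and the pgzip.open arguments mirror those used in sha1_store.py):

import pgzip
import torch

BLOCKSIZE = 10 * 1024 * 1024  # 10MB per-thread blocks, as in sha1_store.py

# Save a zero-heavy tensor, then compress the file with all cores (thread=None).
torch.save(torch.zeros(1000, 1000), "dense.pt")
with open("dense.pt", "rb") as src:
    with pgzip.open("dense.pt.gz", "wb", compresslevel=5, thread=None, blocksize=BLOCKSIZE) as dst:
        dst.write(src.read())

# Decompress back to a regular file and load it.
with pgzip.open("dense.pt.gz", "rb", thread=None, blocksize=BLOCKSIZE) as src:
    with open("roundtrip.pt", "wb") as dst:
        dst.write(src.read())

assert torch.equal(torch.load("roundtrip.pt"), torch.zeros(1000, 1000))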
41 changes: 23 additions & 18 deletions tests/experimental/wgit/test_sha1_store.py
@@ -47,7 +47,8 @@ def teardown():
return sha1_store


def test_sha1_add_file(sha1_store):
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_add_file(sha1_store, compress):
os.chdir(PARENT_DIR)

# Create random checkpoints
@@ -65,15 +66,15 @@ def test_sha1_add_file(sha1_store):

# Add those 5 random files.
for c in chkpts:
sha1_store.add(c)
sha1_store.add(c, compress)

# Add a fixed data twice.
module = nn.Linear(100, 100, bias=False)
module.weight.data = torch.zeros(100, 100)
zeros_file = "zeros.pt"
torch.save(module.state_dict(), zeros_file)
sha1_store.add(zeros_file)
sha1_store.add(zeros_file)
sha1_store.add(zeros_file, compress)
sha1_store.add(zeros_file, compress)

# Assert the ref counts are 1,1,1,1,1 and 2
sha1_store._load_json_dict()
@@ -86,26 +87,28 @@
assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 1, 1, 2], json_dict


def test_sha1_add_state_dict(sha1_store):
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_add_state_dict(sha1_store, compress):
os.chdir(PARENT_DIR)
# add once
for i in range(3):
sha1_store.add(nn.Linear(10, 10).state_dict())
sha1_store.add(nn.Linear(10, 10).state_dict(), compress)
# add twice
for i in range(3):
sd = nn.Linear(8, 8).state_dict()
sha1_store.add(sd)
sha1_store.add(sd)
sha1_store.add(sd, compress)
sha1_store.add(sd, compress)

sha1_store._load_json_dict()
json_dict = sha1_store._json_dict
del json_dict["created_on"]
assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 2, 2, 2], json_dict


def test_sha1_add_tensor(sha1_store):
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_add_tensor(sha1_store, compress):
os.chdir(PARENT_DIR)
sha1_store.add(torch.Tensor([1.0, 5.5, 3.4]))
sha1_store.add(torch.Tensor([1.0, 5.5, 3.4]), compress)
sha1_store._load_json_dict()
json_dict = sha1_store._json_dict
if torch_version() >= (1, 9, 0):
@@ -114,7 +117,8 @@ def test_sha1_add_tensor(sha1_store):
assert key in json_dict.keys() and json_dict[key]["ref_count"] == 1, json_dict


def test_sha1_get(sha1_store):
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_get(sha1_store, compress):
"""Testing the get() API: normal and exception cases."""
os.chdir(PARENT_DIR)

@@ -125,15 +129,15 @@ def test_sha1_get(sha1_store):
tensor = torch.ones(20, 30)

# Check that we can get them back.
file_sha1 = sha1_store.add(file)
file_sha1 = sha1_store.add(file, compress)
sd = sha1_store.get(file_sha1)
assert objects_are_equal(sd, torch.load(file))

sd_sha1 = sha1_store.add(state_dict)
sd_sha1 = sha1_store.add(state_dict, compress)
sd = sha1_store.get(sd_sha1)
assert objects_are_equal(sd, state_dict)

tensor_sha1 = sha1_store.add(tensor)
tensor_sha1 = sha1_store.add(tensor, compress)
tensor_got = sha1_store.get(tensor_sha1)
assert objects_are_equal(tensor_got, tensor)

@@ -142,22 +146,23 @@
sha1_store.get(tensor_sha1[:-1])


def test_sha1_delete(sha1_store):
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_delete(sha1_store, compress):
"""Testing the delete() API: with ref counting behavior."""
os.chdir(PARENT_DIR)

# Add once and delete, second delete should throw an exception.
tensor = torch.ones(30, 50)
sha1 = sha1_store.add(tensor)
sha1 = sha1_store.add(tensor, compress)
sha1_store.delete(sha1)
with pytest.raises(ValueError):
sha1_store.delete(sha1)

# Add multiple times and delete should match that.
state_dict = nn.Sequential(nn.Linear(10, 10), nn.Linear(10, 20)).state_dict()
sha1 = sha1_store.add(state_dict)
sha1 = sha1_store.add(state_dict, compress)
for i in range(3):
new_sha1 = sha1_store.add(state_dict)
new_sha1 = sha1_store.add(state_dict, compress)
assert sha1 == new_sha1, f"{sha1} vs. {new_sha1}"
for i in range(4):
sha1_store.delete(sha1)
