Skip to content

Commit

Permalink
[fix] original size computation (#1037)
Browse files Browse the repository at this point in the history
* flip per_tensor's default

* fixed original size computation

Co-authored-by: Min Xu <min.xu.public@gmail.com>
  • Loading branch information
min-xu-ai and flying-x committed Jul 22, 2022
1 parent 2e544bd commit 16fba4c
Show file tree
Hide file tree
Showing 5 changed files with 20 additions and 11 deletions.
7 changes: 6 additions & 1 deletion fairscale/experimental/wgit/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,11 @@ def main(argv: List[str] = None) -> None:
metavar="FILE_PATH",
help="add a file to the staged changeset (default: none)",
)
add_parser.add_argument(
"--no_per_tensor",
action="store_true",
help="Disable per-tensor adding of a file",
)

commit_parser = subparsers.add_parser("commit", description="Commits the staged changes")
commit_parser.add_argument("commit", action="store_true", help="Commit the staged changes")
Expand Down Expand Up @@ -76,7 +81,7 @@ def main(argv: List[str] = None) -> None:

if args.command == "add":
repo = Repo(Path.cwd())
repo.add(args.add)
repo.add(args.add, per_tensor=not args.no_per_tensor)

if args.command == "status":
repo = Repo(Path.cwd())
Expand Down
4 changes: 2 additions & 2 deletions fairscale/experimental/wgit/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def _sanity_check(self) -> None:
def add(
self,
in_file_path: str,
per_tensor: bool = False,
per_tensor: bool = True,
gzip: bool = True,
sparsify: bool = False,
sparsify_policy: Any = None,
Expand All @@ -209,7 +209,7 @@ def add(
Add a file in a per-tensor fashion. This enables more deduplication
due to tensors being identical. Deduplication cannot be disabled
completely because we use a content addressable SHA1_Store class.
Default: False
Default: True
gzip (bool, optional):
Enable gzip based lossless compression on the object being added.
Default: True
Expand Down
5 changes: 4 additions & 1 deletion fairscale/experimental/wgit/sha1_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,10 @@ def add(self, file_or_obj: Union[Path, Tensor, Dict], compress: bool = True, nam

# Update the sizes for this entry.
entry = _get_json_entry(self._json_dict[sha1_hash])
o_diff = orig_size if ref_count == 1 else entry[ENTRY_OS_KEY]
assert (
ref_count == 1 or entry[ENTRY_OS_KEY] % (ref_count - 1) == 0
), f"incorrect size: {entry[ENTRY_OS_KEY]} and {ref_count}"
o_diff = orig_size if ref_count == 1 else (entry[ENTRY_OS_KEY] // (ref_count - 1))
d_diff = orig_size if ref_count == 1 else 0
c_diff = comp_size if ref_count == 1 else 0
entry[ENTRY_OS_KEY] += o_diff
Expand Down
11 changes: 6 additions & 5 deletions tests/experimental/wgit/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,8 @@ def test_api_commit(capsys, repo):
assert line[0].rstrip().split()[-1] == commit_msg


def test_api_status(capsys, repo):
@pytest.mark.parametrize("per_tensor", [True, False])
def test_api_status(capsys, repo, per_tensor):
# delete the repo and initialize a new one:
shutil.rmtree(".wgit")
repo = Repo(Path.cwd(), init=True)
Expand All @@ -96,7 +97,7 @@ def test_api_status(capsys, repo):

# check status before after a file is added but not committed
chkpt0 = f"checkpoint_{random.randint(0, 1)}.pt"
repo.add(chkpt0)
repo.add(chkpt0, per_tensor=per_tensor)
out = repo.status()
key_list = list(repo._get_metdata_files().keys())
assert out == {key_list[0]: RepoStatus.CHANGES_ADDED_NOT_COMMITED}
Expand All @@ -107,18 +108,18 @@ def test_api_status(capsys, repo):
assert out == {key_list[0]: RepoStatus.CLEAN}

# check status after a new change has been made to the file
torch.save(nn.Linear(1, int(15e5)), chkpt0)
torch.save(nn.Linear(1, int(15e5)).state_dict(), chkpt0)
out = repo.status()
assert out == {key_list[0]: RepoStatus.CHANGES_NOT_ADDED}

# add the new changes made to weigit
repo.add(chkpt0)
repo.add(chkpt0, per_tensor=per_tensor)
out = repo.status()
assert out == {key_list[0]: RepoStatus.CHANGES_ADDED_NOT_COMMITED}

# check status after a new different file is added to be tracked by weigit
chkpt3 = "checkpoint_3.pt"
repo.add(chkpt3)
repo.add(chkpt3, per_tensor=per_tensor)
key_list = list(repo._get_metdata_files().keys())
out = repo.status()
assert out == {
Expand Down
4 changes: 2 additions & 2 deletions tests/experimental/wgit/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def create_test_dir():
# create random checkpoints
size_list = [30e5, 35e5, 40e5]
for i, size in enumerate(size_list):
torch.save(nn.Linear(1, int(size)), f"checkpoint_{i}.pt")
torch.save(nn.Linear(1, int(size)).state_dict(), f"checkpoint_{i}.pt")

# Test init.
cli.main(["init"])
Expand All @@ -53,7 +53,7 @@ def test_cli_init(create_test_dir, capsys):

def test_cli_add(create_test_dir, capsys):
chkpt0 = "checkpoint_0.pt"
cli.main(["add", chkpt0])
cli.main(["add", "--no_per_tensor", chkpt0])

sha1_store = SHA1_Store(
Path.cwd().joinpath(".wgit", "sha1_store"),
Expand Down

0 comments on commit 16fba4c

Please sign in to comment.