Skip to content

Commit

Permalink
[feat]: add per-tensor add to repo (#1033)
Browse files Browse the repository at this point in the history
* formatting change, no logical change

* formatting and name change, no logical change

* [refactor] sha1_store's path arg

- make sha1_store's path arg directly the path, not its parent
- this is because sha1_store is not like a .git or a .wgit dir, which is
  nested inside another "working" dir. It is simply a store, which
  is using a given dir.
- updated repo and tests as well.

* remove a test warning due to deprecated API from torch

* [refactor] change how dot_wgit_dir_path is used

- it should only be assigned in __init__.
- we use it in error checking in the rest APIs.

* simplify the init a bit

* refactor the sanity check

* moved some functions, no code change

* [feat] added per-tensor add to the repo

* enabled gzip compression on add

* fix a unit test

* add a note

* make sha1 store work on general dict

* handle general state_dict from a model, not just a module's one-level OrderedDict

* formatting

Co-authored-by: Min Xu <min.xu.public@gmail.com>
  • Loading branch information
min-xu-ai and flying-x committed Jul 19, 2022
1 parent d0ad08c commit 4d58a29
Show file tree
Hide file tree
Showing 7 changed files with 292 additions and 217 deletions.
9 changes: 6 additions & 3 deletions fair_dev/testing/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
return all(objects_are_equal(x, y, raise_exception) for x, y in zip(a, b))
elif torch.is_tensor(a):
try:
# assert_allclose doesn't strictly test shape, dtype and device
# assert_close doesn't strictly test shape, dtype and device
shape_dtype_device_match = a.size() == b.size() and a.dtype == b.dtype and a.device == b.device
if not shape_dtype_device_match:
if raise_exception:
Expand All @@ -513,8 +513,11 @@ def objects_are_equal(a: Any, b: Any, raise_exception: bool = False, dict_key: O
raise AssertionError(msg)
else:
return False
# assert_allclose.
torch.testing.assert_allclose(a, b)
# assert_close.
if torch_version() < (1, 12, 0):
torch.testing.assert_allclose(a, b)
else:
torch.testing.assert_close(a, b)
return True
except (AssertionError, RuntimeError) as e:
if raise_exception:
Expand Down

0 comments on commit 4d58a29

Please sign in to comment.