[chore] update nightly version (#1064)
* update nightly version

* update wgit to use numpy for load/store

- the new nightly torch version makes torch.save() no longer produce
  deterministic bytes
- this change converts tensors to/from numpy and pickles the numpy data for
  save/load to avoid that issue (a short sketch follows the changed-files
  summary below)

* fixed tests

Co-authored-by: Min Xu <min.xu.public@gmail.com>
min-xu-ai and flying-x committed Aug 25, 2022
1 parent e982b43 commit 15d4cf1
Showing 6 changed files with 66 additions and 34 deletions.
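
To illustrate the determinism point in the commit message above: SHA1 keys are computed over serialized bytes, so the serializer must be byte-stable. A minimal sketch of the pickle-plus-numpy pattern adopted here; the deterministic_bytes helper is illustrative, not part of fairscale, and pickle output is only stable for a fixed protocol and numpy version.

import hashlib
import pickle

import torch


def deterministic_bytes(obj):
    # torch.save() writes a zip archive whose raw bytes can differ across
    # runs on recent nightlies; pickling a plain numpy array is stable.
    if torch.is_tensor(obj):
        obj = obj.cpu().numpy()
    return pickle.dumps(obj)


t = torch.arange(4, dtype=torch.float32)
print(hashlib.sha1(deterministic_bytes(t)).hexdigest())  # same digest every run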
2 changes: 1 addition & 1 deletion .circleci/config.yml
@@ -118,7 +118,7 @@ install_dep_pytorch_nightly: &install_dep_pytorch_nightly
# check if we have restored venv cache (/home/circleci/venv) correctly, if so, just skip
if [ -f /home/circleci/venv/check_version.py ]; then python /home/circleci/venv/check_version.py torch eq 1.12 && exit 0; fi
# start installing
-pip install --pre torch==1.13.0.dev20220625+cu113 torchvision==0.14.0.dev20220625+cu113 --extra-index-url https://download.pytorch.org/whl/nightly/cu113
+pip install --pre torch==1.13.0.dev20220825+cu113 torchvision==0.14.0.dev20220825+cu113 --extra-index-url https://download.pytorch.org/whl/nightly/cu113
pip install --progress-bar off -r requirements-dev.txt
pip install --progress-bar off -r requirements-benchmarks.txt
python -c 'import torch; print("Torch version:", torch.__version__)'
45 changes: 31 additions & 14 deletions fairscale/experimental/wgit/sha1_store.py
@@ -9,6 +9,7 @@
import logging
import os
from pathlib import Path
+import pickle
import shutil
import sys
import tempfile
@@ -19,6 +20,8 @@
import torch
from torch import Tensor

+from fairscale.internal.containers import from_np, to_np
+
from .utils import ExitCode

#
@@ -76,7 +79,7 @@ def _copy_compressed(src: Path, dest: Path, thread: Optional[int], blocksize: in
                break
            destf.write(buf)
    orig, comp = Path(src).stat().st_size, Path(dest).stat().st_size
-    assert orig >= comp, f"Compressed size {comp} > original {orig}"
+    assert orig >= comp or comp < 1 * 1024 * 1024, f"Compressed size {comp} > original {orig} for large data"
    return orig, comp


@@ -127,7 +130,7 @@ class SHA1_Store:
To make things easier for the callers, this class accept input data
as files, state_dict or tensors. This class always returns in-memory
data, not on-disk files. This class doesn't really care or know the actually
-data types. It uses torch.save() and torch.load() to do serialization.
+data types.
A key issue is dealing with content deletion. We use a reference counting
algorithm, which means the caller must have symmetrical add/remove calls
@@ -140,10 +143,8 @@ class SHA1_Store:
addressibility and dependency graphs do not mix well.
We support multicore compression for the data to be store on per-object basis.
-The ``torch.save()`` API uses zip format to store the data, but it appears to
-be uncompressed. Even if it can be made compressed, it is likely a single
-threaded compression. Therefore, we use pgzip to do parallel
-compression/decompression on top of it to use all the cores.
+We use pgzip to do parallel compression/decompression on top of it to use all
+the cores.
Args:
path (Path):
@@ -229,10 +230,14 @@ def add(self, file_or_obj: Union[Path, Tensor, Dict], compress: bool = True, nam
If compress is True, the stored file is also compressed, which is useful for tensors
with a lot of zeros.
+We use pickle and numpy for saving, loading because it is more deterministic
+in terms of serialized bytes. They do lose info on device and dtype of
+tensors. Will handle those later.
Args:
file_or_obj (str or tensor or Dict):
Path to the file to be added to the store or an in-memory object
-that can be handled by torch.save. Note, OrderedDict is used when
+that can be handled by pickle. Note, OrderedDict is used when
you call `state_dict()` on a nn.Module, and it is an instance
of a Dict too. A model's state_dict can be a simple dict because
it may contain both model state_dict and other non-tensor info.
@@ -244,18 +249,30 @@ def add(self, file_or_obj: Union[Path, Tensor, Dict], compress: bool = True, nam
                Default: None
        """
        start = time.time()
+        is_pickle_file = None

-        # Use `isinstance` not type() == Path since pathlib returns OS specific
+        # Use `isinstance` not `type() == Path` since pathlib returns OS specific
        # Path types, which inherit from the Path class.
        if isinstance(file_or_obj, (Path, str)):
            # Make sure it is a valid file.
-            torch.load(cast(Union[Path, str], file_or_obj))
+            try:
+                pickle.load(open(file_or_obj, "rb"))
+                is_pickle_file = True
+            except Exception as e:
+                is_pickle_file = False
+                pass
            file_path = Path(file_or_obj)
            remove_tmp = False
-        elif isinstance(file_or_obj, (Tensor, Dict)):
+
+        if is_pickle_file is False:
+            # Continue to support torch.save()'ed files too by loading it
+            # in memory and the next if condition will pickle it.
+            file_or_obj = torch.load(cast(Union[Path, str], file_or_obj))
+
+        if isinstance(file_or_obj, (Tensor, Dict)):
            # Serialize the object into a tmp file.
            file_path = self._get_tmp_file_path()
-            torch.save(cast(Union[Tensor, Dict], file_or_obj), file_path)
+            pickle.dump(to_np(file_or_obj), open(file_path, "wb"))
            remove_tmp = True
        else:
            assert False, f"incorrect input {type(file_or_obj)}"
@@ -361,12 +378,12 @@ def get(self, sha1: str) -> Union[Tensor, Dict]:
            # uncompress into a temp file and return it.
            tmp = self._get_tmp_file_path()
            _copy_uncompressed(path, tmp, self._pgzip_threads, self._pgzip_block_size)
-            obj = torch.load(tmp)
+            obj = pickle.load(open(tmp, "rb"))
            tmp.unlink()
            return obj
        else:
            # Uncompressed.
-            return torch.load(path)
+            obj = pickle.load(open(path, "rb"))
+            return from_np(obj)

    def delete(self, sha1: str) -> None:
        """Delete a SHA1
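
For context on the _copy_compressed helper above: the store streams pickled bytes through pgzip so compression can use all cores. A minimal sketch of that pattern, assuming the pgzip package's open(..., thread=..., blocksize=...) interface this module already relies on; the function name, file paths, and defaults are illustrative.

import shutil

import pgzip


def copy_compressed(src: str, dest: str, threads: int = 4, blocksize: int = 10 * 1024 * 1024) -> None:
    # Stream the raw file through a multi-threaded gzip writer.
    with open(src, "rb") as srcf, pgzip.open(dest, "wb", thread=threads, blocksize=blocksize) as destf:
        shutil.copyfileobj(srcf, destf)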
26 changes: 22 additions & 4 deletions fairscale/internal/containers.py
@@ -6,17 +6,20 @@
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

+import numpy as np
import torch
from torch.nn.utils.rnn import PackedSequence

"""Useful functions to deal with tensor types with other python container types."""


-def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
-    """Recursively apply to all tensor in different kinds of container types."""
+def apply_to_type(
+    type_fn: Callable, fn: Callable, container: Union[torch.Tensor, np.ndarray, Dict, List, Tuple, Set]
+) -> Any:
+    """Recursively apply to all objects in different kinds of container types that matches a type function."""

-    def _apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
-        if torch.is_tensor(x):
+    def _apply(x: Union[torch.Tensor, np.ndarray, Dict, List, Tuple, Set]) -> Any:
+        if type_fn(x):
            return fn(x)
        elif isinstance(x, OrderedDict):
            od = x.__class__()
@@ -40,6 +43,21 @@ def _apply(x: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
    return _apply(container)


+def apply_to_tensors(fn: Callable, container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
+    """Recursively apply to all tensor in different kinds of container types."""
+    return apply_to_type(torch.is_tensor, fn, container)
+
+
+def to_np(tensor_or_container: Union[torch.Tensor, Dict, List, Tuple, Set]) -> Any:
+    """Convert a tensor or a container to numpy."""
+    return apply_to_type(torch.is_tensor, lambda x: x.cpu().numpy(), tensor_or_container)
+
+
+def from_np(ndarray_or_container: Union[np.ndarray, Dict, List, Tuple, Set]) -> Any:
+    """Convert a ndarray or a container to tensor."""
+    return apply_to_type(lambda x: isinstance(x, np.ndarray), lambda x: torch.from_numpy(x), ndarray_or_container)


def pack_kwargs(*args: Any, **kwargs: Any) -> Tuple[Tuple[str, ...], Tuple[Any, ...]]:
    """
    Turn argument list into separate key list and value list (unpack_kwargs does the opposite)
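
A short usage sketch of the new helpers, assuming a CPU state_dict (per the add() docstring above, device placement is not preserved by the numpy round trip):

import torch
from torch import nn

from fairscale.internal.containers import from_np, to_np

sd = nn.Linear(4, 4).state_dict()  # OrderedDict of tensors
np_sd = to_np(sd)                  # same container structure, numpy values
restored = from_np(np_sd)          # back to torch tensors
assert all(torch.equal(restored[k], sd[k]) for k in sd)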
3 changes: 2 additions & 1 deletion tests/experimental/wgit/test_api.py
@@ -74,7 +74,8 @@ def test_api_add(capsys, repo, per_tensor, gzip):
        json_data = json.load(f)

    sha1_dir_0 = f"{sha1_hash[:2]}/" + f"{sha1_hash[2:]}"
-    assert json_data["SHA1"] == sha1_hash
+    # The sha1 are different because add internally use a different pickle method.
+    assert json_data["SHA1"] != sha1_hash


def test_api_commit(capsys, repo):
3 changes: 2 additions & 1 deletion tests/experimental/wgit/test_cli.py
@@ -64,7 +64,8 @@ def test_cli_add(create_test_dir, capsys):
        json_data = json.load(f)

    sha1_dir_0 = f"{sha1_hash[:2]}/" + f"{sha1_hash[2:]}"
-    assert json_data["SHA1"] == sha1_hash
+    # The sha1 are different because add internally use a different pickle method.
+    assert json_data["SHA1"] != sha1_hash


def test_cli_commit(capsys):
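
Why the two asserts above flipped from == to !=: the SHA1 recorded by the store is now computed from pickled-numpy bytes, while sha1_hash in these tests presumably reflects the original torch.save()'ed checkpoint; different serializers produce different bytes and therefore different digests. A self-contained illustration (values are arbitrary):

import hashlib
import io
import pickle

import torch

t = torch.ones(3)

buf = io.BytesIO()
torch.save(t, buf)  # zip-format bytes, as a checkpoint file would contain
digest_torch = hashlib.sha1(buf.getvalue()).hexdigest()
digest_pickle = hashlib.sha1(pickle.dumps(t.cpu().numpy())).hexdigest()
assert digest_torch != digest_pickle  # different serializers, different bytes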
21 changes: 8 additions & 13 deletions tests/experimental/wgit/test_sha1_store.py
@@ -13,7 +13,6 @@

from fair_dev.testing.testing import objects_are_equal
from fairscale.experimental.wgit.sha1_store import SHA1_Store
-from fairscale.internal import torch_version

# Get the absolute path of the parent at the beginning before any os.chdir(),
# so that we can proper clean it up at any CWD.
@@ -65,7 +64,7 @@ def test_sha1_add_file(sha1_store, compress):
    ]

    for file, size in zip(chkpts, size_list):
-        torch.save(nn.Linear(1, int(size)), file)
+        torch.save(nn.Linear(1, int(size)).state_dict(), file)

    # Add those 5 random files.
    for c in chkpts:
@@ -82,10 +81,8 @@ def test_sha1_add_file(sha1_store, compress):
    # Assert the ref counts are 1,1,1,1,1 and 2
    with sha1_store._readonly_json_ctx:
        json_dict = sha1_store._json_dict
-    if torch_version() >= (1, 9, 0):
-        # torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
-        key = "da3e19590de8f77fcf7a09c888c526b0149863a0"
-        assert key in json_dict.keys() and json_dict[key]["ref_count"] == 2, json_dict
+    key = "3c06179202606573a4982d91c2829a1a675362b3"
+    assert key in json_dict.keys() and json_dict[key]["ref_count"] == 2, json_dict
    json_dict = dict(filter(lambda item: len(item[0]) == SHA1_KEY_STR_LEN, json_dict.items()))
    assert sorted(map(lambda x: x["ref_count"], json_dict.values())) == [1, 1, 1, 1, 1, 2], json_dict

@@ -95,10 +92,10 @@ def test_sha1_add_state_dict(sha1_store, compress):
    os.chdir(TESTING_STORE_DIR)
    # add once
    for i in range(3):
-        sha1_store.add(nn.Linear(10, 10).state_dict(), compress)
+        sha1_store.add(nn.Linear(100, 100).state_dict(), compress)
    # add twice
    for i in range(3):
-        sd = nn.Linear(8, 8).state_dict()
+        sd = nn.Linear(80, 80).state_dict()
        sha1_store.add(sd, compress)
        sha1_store.add(sd, compress)

@@ -111,13 +108,11 @@ def test_sha1_add_state_dict(sha1_store, compress):
@pytest.mark.parametrize("compress", [True, False])
def test_sha1_add_tensor(sha1_store, compress):
    os.chdir(TESTING_STORE_DIR)
-    sha1_store.add(torch.Tensor([1.0, 5.5, 3.4]), compress)
+    sha1_store.add(torch.Tensor([1.0, 5.5, 3.4]).repeat(100), compress)
    with sha1_store._readonly_json_ctx:
        json_dict = sha1_store._json_dict
-    if torch_version() >= (1, 9, 0):
-        # torch 1.8 LTS doesn't produce deterministic checkpoint file from fixed tensors/state_dict.
-        key = "71df4069a03a766eacf9f03eea50968e87eae9f8"
-        assert key in json_dict.keys() and json_dict[key]["ref_count"] == 1, json_dict
+    key = "81cb2a3f823cfb78da8dd390e29e685720974cc7"
+    assert key in json_dict.keys() and json_dict[key]["ref_count"] == 1, json_dict


@pytest.mark.parametrize("compress", [True, False])
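
A hedged round-trip sketch mirroring what these tests exercise; the constructor call is an assumption based on the `path (Path)` docstring argument plus an assumed init=True flag, so check SHA1_Store.__init__ for the exact signature before reusing it.

from pathlib import Path

import torch

from fairscale.experimental.wgit.sha1_store import SHA1_Store

store = SHA1_Store(Path("/tmp/sha1_store_demo"), init=True)  # init flag is assumed
sha1 = store.add(torch.ones(100), compress=False)  # assumed to return the content hash
restored = store.get(sha1)  # in-memory object, per the class docstring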
