[feat]: add size and names metadata to sha1 store (#1036)
* additional metadata, step 1

* add gzip option to repo::add

* add repo:add's return value and some refactoring and todo

* added size metadata to sha1_store

* added names metadata to sha1_store

Co-authored-by: Min Xu <min.xu.public@gmail.com>
min-xu-ai and flying-x committed Jul 21, 2022
1 parent 4d58a29 commit 2e544bd
Showing 5 changed files with 282 additions and 72 deletions.
180 changes: 136 additions & 44 deletions fairscale/experimental/wgit/repo.py
@@ -21,6 +21,40 @@
SHA1_STORE_DIR_NAME = "sha1_store"


# These are on-disk keys. Don't modify for backward compatibility.
SHA1_KEY = "SHA1"
LAST_MODIFIED_TS_KEY = "last_modified_time_stamp"
REL_PATH_KEY = "file_path" # this will be removed from the json since it is redundant.


class RepoStatus(Enum):
"""Collections of Repo Statuses"""

CLEAN = 1
CHANGES_NOT_ADDED = 2
CHANGES_ADDED_NOT_COMMITED = 3
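A minimal sketch of how these values surface to callers (the `repo` instance and the file name are hypothetical, not part of this diff):

    # status() returns Dict[str, RepoStatus], keyed by tracked file path.
    statuses = repo.status()
    # e.g. {"checkpoints/model.pt": RepoStatus.CHANGES_ADDED_NOT_COMMITED}
    if all(s == RepoStatus.CLEAN for s in statuses.values()):
        print("working tree clean")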


@dataclass
class SizeInfo:
"""Size info for a file or the repo in bytes.
Dedup cannot be disabled, so the deduped size is always present.
Both sparsified and gzipped are optional. If both are enabled, they are
applied in the following order:
sparsify -> gzip
Therefore, original >= deduped >= sparsified >= gzipped
"""

original: int
deduped: int
sparsified: int
gzipped: int
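For illustration, a minimal sketch with made-up byte counts that respects the ordering above:

    info = SizeInfo(original=1_000_000, deduped=800_000, sparsified=500_000, gzipped=200_000)
    assert info.original >= info.deduped >= info.sparsified >= info.gzipped
    ratio = info.original / info.gzipped  # overall saving from dedup + sparsify + gzip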


@dataclass
class _SHA1_Tensor:
"""Representing a tensor using sha1(s) from SHA1 store.
@@ -155,61 +189,130 @@ def _sanity_check(self) -> None:
sys.stderr.write("fatal: no wgit repo exists!\n")
sys.exit(1)

def add(self, in_file_path: str, per_tensor: bool = False) -> None:
def add(
self,
in_file_path: str,
per_tensor: bool = False,
gzip: bool = True,
sparsify: bool = False,
sparsify_policy: Any = None,
) -> Optional[Dict[Any, Any]]:
"""Add a file to the wgit repo.
This could be a new file or a modified file. Adding an unmodified, existing file
is allowed but it is a noop.
Args:
in_file_path (str):
Path to the file to be added.
per_tensor (bool):
Add a file in a per-tensor fashion.
per_tensor (bool, optional):
Add a file in a per-tensor fashion. This enables more deduplication
due to tensors being identical. Deduplication cannot be disabled
completely because we use a content addressable SHA1_Store class.
Default: False
gzip (bool, optional):
Enable gzip based lossless compression on the object being added.
Default: True
sparsify (bool, optional):
Enable sparsification of the tensors, which modifies the values of
some or all tensors, i.e. lossy compression.
Default: False
sparsify_policy (Any):
TODO (Min): need to add a callback function to control which tensors
and how to sparsify.
Default: None
Returns:
(Dict, optional)
None if the content was added without lossy modification.
Otherwise, returns a state_dict containing the modified tensors to be
loaded back into the model; the returned tensors are dense, not
SST or DST tensors.
"""
self._sanity_check()

# create the corresponding metadata file
if sparsify and not per_tensor:
raise ValueError("Only support sparsity when per_tensor is true")

# Create the corresponding metadata file or load it if the file is
# not a newly added file.
file_path = Path(in_file_path)
rel_file_path = self._rel_file_path(file_path)
metadata_file = self._process_metadata_file(rel_file_path)

# add the file to the sha1_store
# Add the file to the sha1_store.
ret_state_dict = None
file_path_or_state_dict: Union[Path, Dict] = file_path
# TODO (Min): We don't add parent sha1 tracking to sha1 store because
# de-duplication & dependency tracking can create cycles.
# We need to figure out a way to handle deletion.
sha1_dict = {}
# TODO (Min): We don't detect changes and compute delta on a modified file
# yet. Need to figure out a method for delta tracking.
if per_tensor:

def fn(element: Any) -> Any:
"""Callback on each leaf object for _recursive_apply_to_elements below."""
if isinstance(element, Tensor):
# TODO (Min): here we will optionally do SST/DST and add those
# tensors with sparsity.
sha1 = self._sha1_store.add(element, compress=True)
if sparsify:
# TODO (Min): here we will optionally do SST/DST and add those
# tensors with sparsity.
# Remember to update ret_state_dict
raise NotImplementedError()
sha1 = self._sha1_store.add(element, compress=gzip)
return _SHA1_Tensor(is_dense=True, dense_sha1=sha1)
else:
return element

state_dict = torch.load(file_path)
_recursive_apply_to_elements(state_dict, fn)
sha1_dict = {"__sha1_full__": self._sha1_store.add(state_dict)}
else:
sha1_dict = {"__sha1_full__": self._sha1_store.add(file_path)}
file_path_or_state_dict = state_dict

# Add this top-level object.
sha1 = self._sha1_store.add(file_path_or_state_dict, compress=gzip)

# write metadata to the metadata-file
self._write_metadata(metadata_file, file_path, sha1_dict)
self._write_metadata(metadata_file, file_path, sha1)
self._pygit.add() # add to the .wgit/.git repo

return ret_state_dict
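A minimal usage sketch of the new signature (the `repo` instance and the checkpoint path are hypothetical, not part of this diff):

    new_sd = repo.add("checkpoints/model.pt", per_tensor=True, gzip=True)
    # With sparsify=False (the default) no lossy changes are made, so None is returned.
    assert new_sd is None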

def commit(self, message: str) -> None:
"""Commits staged changes to the repo.
Args:
message (str):
The commit message
The commit message to be added.
"""
self._sanity_check()

# TODO (Min): make commit message a json for better handling of metadata like step count,
# LR, sparsity level, etc.
self._pygit.commit(message)
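A hedged sketch of the JSON-style commit message the TODO above points toward (the keys are hypothetical):

    import json
    repo.commit(json.dumps({"step": 1000, "lr": 1e-4, "sparsity": 0.5}))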

def status(self) -> Dict:
def size_info(self, path: Optional[str] = None) -> SizeInfo:
"""Get size info for a file or the whole repo.
For the whole repo, just call size_info from sha1_store.
For a file, open its metadata to find the sha1; for a per_tensor
state_dict, collect size_info over all of its objects.
TODO (Min): with delta encoding and deduplication between objects, it is
not clear this can be computed precisely.
Args:
path (str, optional):
File path for the query. If None, return whole repo's info.
Default: None
Returns:
(SizeInfo):
The dataclass that contains the size info.
"""
raise NotImplementedError()
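Once implemented, a query might look like this (hypothetical usage; the method currently raises NotImplementedError):

    info = repo.size_info("checkpoints/model.pt")  # or repo.size_info() for the whole repo
    print(info.original, info.deduped, info.sparsified, info.gzipped)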

def status(self) -> Dict[str, RepoStatus]:
"""Show the state of the weigit working tree.
State can be
@@ -218,6 +321,8 @@ def status(self) -> Dict:
3. clean and tracking files after a change has been committed,
or clean with an empty repo.
TODO (Min): this needs to return repo status and dirty files and untracked
files too.
Returns:
(dict):
A dict keyed with files and their status.
@@ -250,6 +355,8 @@ def log(self, file: str) -> None:
"""
self._sanity_check()

# TODO (Min): this should return a list of sha1 for the history as well as
# each commit's message, which could be a dict from json commit msg.
if file:
print(f"wgit log of the file: {file}")
else:
@@ -263,25 +370,22 @@ def checkout(self, sha1: str) -> None:
The sha1 hash of the file version to checkout.
"""
self._sanity_check()
raise NotImplementedError

def compression(self) -> None:
"""Not Implemented: Compression functionalities"""
self._sanity_check()
raise NotImplementedError
raise NotImplementedError()

def checkout_by_steps(self) -> None:
"""Not Implemented: Checkout by steps"""
"""Not Implemented: Checkout by step count of the train process"""
self._sanity_check()
raise NotImplementedError
raise NotImplementedError()

def _get_metdata_files(self) -> Dict:
def _get_metdata_files(self) -> Dict[str, bool]:
"""Walk the directories that contain the metadata files and check the
status of those files, whether they have been modified or not.
The returned Dict[str, bool] maps a file path (as a string) to whether that file has been modified.
"""
metadata_d = dict()
for file in self._dot_wgit_dir_path.iterdir(): # iterate over the .wgit directory
# exlude all the .wgit files and directory
# exclude all the .wgit files and directory
if file.name not in {"sha1_store", ".git", ".gitignore"}:
# perform a directory walk on the metadata_file directories to find the metadata files
for path in file.rglob("*"):
Expand All @@ -297,11 +401,7 @@ def _is_metadata_file(self, file: Path) -> bool:
try:
with open(file) as f:
metadata = json.load(f)
is_metadata = set(metadata.keys()) == {
"SHA1",
"file_path",
"last_modified_time_stamp",
} # TODO: Consider storing the keys as a class attribute, instead of hard coding.
is_metadata = set(metadata.keys()) == {SHA1_KEY, LAST_MODIFIED_TS_KEY, REL_PATH_KEY}
except json.JSONDecodeError:
return False # not a json file, so not valid metadata file
return is_metadata
@@ -315,8 +415,8 @@ def _is_file_modified(self, file: Path) -> bool:
# Get the last modified timestamp recorded by weigit and the current modified
# timestamp. If not the same, then file has been modified since last weigit
# updated metadata.
last_mod_timestamp = data["last_modified_time_stamp"]
curr_mod_timestamp = Path(data["file_path"]).stat().st_mtime
last_mod_timestamp = data[LAST_MODIFIED_TS_KEY]
curr_mod_timestamp = Path(data[REL_PATH_KEY]).stat().st_mtime
return not curr_mod_timestamp == last_mod_timestamp

def _process_metadata_file(self, metadata_fname: Path) -> Path:
@@ -334,13 +434,13 @@ def _process_metadata_file(self, metadata_fname: Path) -> Path:
ref_data = json.load(f)
return metadata_file

def _write_metadata(self, metadata_file: Path, file_path: Path, sha1_dict: Dict) -> None:
def _write_metadata(self, metadata_file: Path, file_path: Path, sha1: str) -> None:
"""Write metadata to the metadata file"""
change_time = Path(file_path).stat().st_mtime
metadata = {
"SHA1": sha1_dict,
"file_path": str(file_path),
"last_modified_time_stamp": change_time,
SHA1_KEY: sha1,
LAST_MODIFIED_TS_KEY: change_time,
REL_PATH_KEY: str(file_path),
}
with open(metadata_file, "w", encoding="utf-8") as f:
json.dump(metadata, f, ensure_ascii=False, indent=4)
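For reference, a sketch of the per-file metadata this writes (values are hypothetical; note that the SHA1 entry is now a plain hex string rather than the old dict):

    # json.load(open(metadata_file)) would yield something like:
    # {
    #     "SHA1": "2c26b46b68ffc68ff99b453c1d30413413422d70",
    #     "last_modified_time_stamp": 1658402100.0,
    #     "file_path": "checkpoints/model.pt"
    # }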
@@ -356,11 +456,3 @@ def _rel_file_path(self, filepath: Path) -> Path:
pass
# return the relative part (path not common to cwd)
return Path(*filepath.parts[i:])


class RepoStatus(Enum):
"""Collections of Repo Statuses"""

CLEAN = 1
CHANGES_NOT_ADDED = 2
CHANGES_ADDED_NOT_COMMITED = 3
