From 4058a18a3a082233514c4433d6f6270e66dd8b2e Mon Sep 17 00:00:00 2001
From: Neel Kovelamudi
Date: Wed, 6 Dec 2023 00:09:29 +0000
Subject: [PATCH 1/6] Add initial commit for sharded H5 store

---
 keras/saving/saving_lib.py | 62 ++++++++++++++++++++++++++++++++++++++
 1 file changed, 62 insertions(+)

diff --git a/keras/saving/saving_lib.py b/keras/saving/saving_lib.py
index 7342bcc8f22..0f0dac48c04 100644
--- a/keras/saving/saving_lib.py
+++ b/keras/saving/saving_lib.py
@@ -557,6 +557,68 @@ def close(self):
         self.io_file.close()
 
 
+class ShardedH5IOStore:
+    def __init__(self, root_path, max_size="10GB", archive=None, mode="r"):
+        self.shard_list = []
+        self.root_path = root_path
+        self.mode = mode
+        self.archive = archive
+        self.io_file = None
+        self.max_size = convert_str_bytes_to_int(max_size)
+        self.h5_file = self._create_new_file()
+
+    def _create_new_file(self, root_path):
+        if self.h5_file in self.shard_list:
+            self.root_path = str(root_path).replace(".weights.h5")
+        if self.archive:
+            if self.mode == "w":
+                self.io_file = io.BytesIO()
+            else:
+                self.io_file = self.archive.open(self.root_path, "r")
+            return h5py.File(self.io_file, mode=self.mode)
+        else:
+            return h5py.File(self.root_path, mode=self.mode)
+
+    def make(self, path):
+        if self.current_shard_size > self.max_size:
+            self.close()
+            self.shard_list.append(self.h5_file)
+            self.h5_file = self._create_new_file()
+        if not path:
+            return self.h5_file.create_group("vars")
+        return self.h5_file.create_group(path).create_group("vars")
+
+    def close(self):
+        self.h5_file.close()
+        if self.mode == "w" and self.archive:
+            self.archive.writestr(self.root_path, self.io_file.getvalue())
+        if self.io_file:
+            self.io_file.close()
+
+
+
+def convert_str_bytes_to_int(size):
+    if size.upper().endswith("GB"):
+        return int(size[:-2]) * (10**9)
+    if size.upper().endswith("MB"):
+        return int(size[:-2]) * (10**6)
+    if size.upper().endswith("KB"):
+        return int(size[:-2]) * (10**3)
+    raise ValueError(
+        "Invalid format for `size`. Use an integer followed by the unit "
+        "(GB, MB, or KB). For example, '5GB' or '15MB'."
+    )
+
+
+def dtype_to_bytes(dtype):
+    if "bool" in str(dtype):
+        return 1 / 8
+    bits = re.search(r"[^\d](\d+)$", str(dtype))
+    if bits is None:
+        raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
+    return int(bits.groups()[0]) // 8  # Bit size in bytes
+
+
 class NpzIOStore:
     def __init__(self, root_path, archive=None, mode="r"):
         """Numerical variable store backed by NumPy.savez/load.
From 51ea84890b7dab33b33c94ad91524b2e7751cf34 Mon Sep 17 00:00:00 2001
From: Neel Kovelamudi
Date: Fri, 8 Mar 2024 21:37:55 +0000
Subject: [PATCH 2/6] Finish ShardedH5IOStore initial implementation

---
 keras/saving/saving_lib.py | 66 ++++++++++++++++++++++++++++++++------
 1 file changed, 56 insertions(+), 10 deletions(-)

diff --git a/keras/saving/saving_lib.py b/keras/saving/saving_lib.py
index 0f0dac48c04..6957269017f 100644
--- a/keras/saving/saving_lib.py
+++ b/keras/saving/saving_lib.py
@@ -6,6 +6,8 @@
 import tempfile
 import warnings
 import zipfile
+import re
+import os
 
 import numpy as np
 
@@ -560,33 +562,67 @@ def close(self):
 class ShardedH5IOStore:
     def __init__(self, root_path, max_size="10GB", archive=None, mode="r"):
         self.shard_list = []
+        self.var_shard_map_filename = str(root_path).replace(".weights.h5", ".weights.json")
+        if self.mode == "w" and not os.path.exists(self.var_shard_map_filename):
+            self.var_shard_map = {}
+        else:
+            with open(self.var_shard_map_filename) as map_json:
+                self.var_shard_map = json.load(map_json)
         self.root_path = root_path
         self.mode = mode
         self.archive = archive
         self.io_file = None
         self.max_size = convert_str_bytes_to_int(max_size)
-        self.h5_file = self._create_new_file()
+        self.h5_file = self._create_new_file(root_path)
 
-    def _create_new_file(self, root_path):
-        if self.h5_file in self.shard_list:
-            self.root_path = str(root_path).replace(".weights.h5")
+    def _create_new_file(self, path):
+        if path in self.shard_list:
+            path = resolve_duplicate_filename(str(path), self.shard_list)
+        self.root_path = path
         if self.archive:
             if self.mode == "w":
                 self.io_file = io.BytesIO()
             else:
-                self.io_file = self.archive.open(self.root_path, "r")
+                self.io_file = self.archive.open(path "r")
+            return h5py.File(self.io_file, mode=self.mode)
+        else:
+            return h5py.File(path, mode=self.mode)
+
+    def _change_access_file(self, filename):  # Read-only
+        self.close()
+        if self.archive:
+            self.io_file = self.archive.open(filename, "r")
             return h5py.File(self.io_file, mode=self.mode)
         else:
-            return h5py.File(self.root_path, mode=self.mode)
-
+            return h5py.File(path, mode=self.mode)
+
+
     def make(self, path):
         if self.current_shard_size > self.max_size:
             self.close()
-            self.shard_list.append(self.h5_file)
+            self.shard_list.append(self.h5_file.filename)
             self.h5_file = self._create_new_file()
         if not path:
-            return self.h5_file.create_group("vars")
-        return self.h5_file.create_group(path).create_group("vars")
+            group = self.h5_file.create_group("vars")
+        else:
+            group = self.h5_file.create_group(path).create_group("vars")
+        self.var_shard_map[group.name] = self.root_path
+        return group
+
+    def get(self, path):
+        if not path:
+            return self.h5_file["vars"]
+        if path in self.h5_file and "vars" in self.h5_file[path]:
+            return self.h5_file[path]["vars"]
+
+        # If not found, check shard map and switch files
+        filename = self.var_shard_map.get(path)
+        if filename is not None and self.h5_file.name != filename:
+            new_file = self._change_access_file(filename)
+            if "vars" in new_file[path]:
+                self.h5_file = new_file
+                return self.h5_file[path]["vars"]
+        return {}
 
     def close(self):
         self.h5_file.close()
@@ -610,6 +646,16 @@ def convert_str_bytes_to_int(size):
     )
 
 
+def resolve_duplicate_filename(path, path_list):
+    pattern = re.compile("_\d\.weights\.h5")
+    pre_duplicate = pattern.split(path)[0]  # Check for pre-existing duplicate
+    if not pre_duplicate.endswith(".weights.h5"):
+        match_list = filter(lambda x: x.startswith(pre_duplicate), path_list)
+        if len(match_list) > 1:
+            return pre_duplicate + "_" + str(len(match_list)) + ".weights.h5"
+    return path.replace(".weights.h5", "_1.weights.h5")
+
+
 def dtype_to_bytes(dtype):
     if "bool" in str(dtype):
         return 1 / 8

From a7d29aacb650611a95ff77b61113f65c0ffbb38e Mon Sep 17 00:00:00 2001
From: Neel Kovelamudi
Date: Fri, 8 Mar 2024 22:31:10 +0000
Subject: [PATCH 3/6] Add sharding to saving and loading API logic with
 associated errors

---
 keras/saving/saving_lib.py | 79 ++++++++++++++++++++++++++++++--------
 1 file changed, 63 insertions(+), 16 deletions(-)

diff --git a/keras/saving/saving_lib.py b/keras/saving/saving_lib.py
index 6957269017f..3806da8afd4 100644
--- a/keras/saving/saving_lib.py
+++ b/keras/saving/saving_lib.py
@@ -35,7 +35,7 @@
 _ASSETS_DIRNAME = "assets"
 
 
-def save_model(model, filepath, weights_format="h5"):
+def save_model(model, filepath, weights_format="h5", sharded=False, max_size="10GB"):
     """Save a zip-archive representing a Keras model to the given filepath.
 
     The zip-based archive contains the following structure:
@@ -67,6 +67,12 @@ def save_model(model, filepath, weights_format="h5"):
     )
     if weights_format == "h5" and h5py is None:
         raise ImportError("h5py must be installed in order to save a model.")
+    if weights_format != "h5" and sharded:
+        raise NotImplementedError(
+            "Sharding is currently only supported for the H5 weights format. "
+            "Please pass `sharded=False` or switch to `weights_format='h5'`. "
+            f"Received: weights_format={weights_format}, sharded={sharded}."
+        )
 
     if not model.built:
         warnings.warn(
@@ -99,7 +105,15 @@ def save_model(model, filepath, weights_format="h5"):
             f.write(config_json.encode())
 
         if weights_format == "h5":
-            weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="w")
+            if sharded:
+                weights_store = ShardedH5IOStore(
+                    _VARS_FNAME + ".h5",
+                    archive=zf,
+                    mode="w",
+                    max_size=max_size,
+                )
+            else:
+                weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="w")
         elif weights_format == "npz":
             weights_store = NpzIOStore(
                 _VARS_FNAME + ".npz", archive=zf, mode="w"
@@ -158,7 +172,14 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
         all_filenames = zf.namelist()
 
         if _VARS_FNAME + ".h5" in all_filenames:
-            weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="r")
+            if _VARS_FNAME + ".json" in all_filenames:
+                weights_store = ShardedH5IOStore(
+                    _VARS_FNAME + ".h5",
+                    archive=zf,
+                    mode="r",
+                )
+            else:
+                weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="r")
         elif _VARS_FNAME + ".npz" in all_filenames:
             weights_store = NpzIOStore(
                 _VARS_FNAME + ".npz", archive=zf, mode="r"
@@ -186,7 +207,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
     return model
 
 
-def save_weights_only(model, filepath):
+def save_weights_only(model, filepath, sharded=False, max_size="10GB"):
     """Save only the weights of a model to a target filepath (.weights.h5).
 
     Note: only supports h5 for now.
@@ -199,7 +220,10 @@
             "Invalid `filepath` argument: expected a `.weights.h5` extension. "
" f"Received: filepath={filepath}" ) - weights_store = H5IOStore(filepath, mode="w") + if sharded: + weights_store = ShardedH5IOStore(filepath, mode="w", max_size=max_size) + else: + weights_store = H5IOStore(filepath, mode="w") _save_state( model, weights_store=weights_store, @@ -210,7 +234,7 @@ def save_weights_only(model, filepath): weights_store.close() -def load_weights_only(model, filepath, skip_mismatch=False): +def load_weights_only(model, filepath, sharded=False, skip_mismatch=False): """Load the weights of a model from a filepath (.keras or .weights.h5). Note: only supports h5 for now. @@ -220,12 +244,23 @@ def load_weights_only(model, filepath, skip_mismatch=False): filepath = str(filepath) if filepath.endswith(".weights.h5"): # TODO: download file if h5 filepath is remote - weights_store = H5IOStore(filepath, mode="r") + if sharded: + weights_store = ShardedH5IOStore(filepath, mode="r") + else: + weights_store = H5IOStore(filepath, mode="r") elif filepath.endswith(".keras"): archive = zipfile.ZipFile(filepath, "r") - weights_store = H5IOStore( - _VARS_FNAME + ".h5", archive=archive, mode="r" - ) + all_filenames = archive.namelist() + if _VARS_FNAME + ".json" in all_filenames: + weights_store = ShardedH5IOStore( + _VARS_FNAME + ".h5", + archive=archive, + mode="r", + ) + else: + weights_store = H5IOStore( + _VARS_FNAME + ".h5", archive=archive, mode="r" + ) _load_state( model, @@ -563,11 +598,20 @@ class ShardedH5IOStore: def __init__(self, root_path, max_size="10GB", archive=None, mode="r"): self.shard_list = [] self.var_shard_map_filename = str(root_path).replace(".weights.h5", ".weights.json") - if self.mode == "w" and not os.path.exists(self.var_shard_map_filename): - self.var_shard_map = {} + if not os.path.exists(self.var_shard_map_filename): + if self.mode == "w": + self.var_shard_map = {} + if self.mode =="r": + raise FileNotFoundError( + f"Loading a sharded `.weights.h5` file requires " + "its corresponding sharding map JSON file " + f"{self.var_shard_map_filename} in the same directory. " + "Please ensure all weights files and the sharding map JSON file " + "are in the same directory when loading a sharded weights file." 
+            )
         else:
-            with open(self.var_shard_map_filename) as map_json:
-                self.var_shard_map = json.load(map_json)
+            with open(self.var_shard_map_filename, "r") as map_file:
+                self.var_shard_map = json.load(map_file)
         self.root_path = root_path
         self.mode = mode
         self.archive = archive
@@ -626,8 +670,11 @@ def get(self, path):
 
     def close(self):
         self.h5_file.close()
-        if self.mode == "w" and self.archive:
-            self.archive.writestr(self.root_path, self.io_file.getvalue())
+        if self.mode == "w":
+            with open(self.var_shard_map_filename, "w") as map_file:
+                map_file.write(json.dumps(self.var_shard_map))
+            if self.archive:
+                self.archive.writestr(self.root_path, self.io_file.getvalue())
         if self.io_file:
             self.io_file.close()

From de397ed2773fd79996139e725aa5ffe38bf17d03 Mon Sep 17 00:00:00 2001
From: Neel Kovelamudi
Date: Sat, 9 Mar 2024 02:01:26 +0000
Subject: [PATCH 4/6] Fix sharding API and add size check

---
 keras/models/model.py           |  4 +++-
 keras/saving/saving_api.py      |  9 ++++++---
 keras/saving/saving_lib.py      | 28 +++++++++++++++++---------
 keras/saving/saving_lib_test.py | 11 +++++++++++
 4 files changed, 39 insertions(+), 13 deletions(-)

diff --git a/keras/models/model.py b/keras/models/model.py
index ee89bb42493..04c0483b03f 100644
--- a/keras/models/model.py
+++ b/keras/models/model.py
@@ -305,6 +305,8 @@ def save(self, filepath, overwrite=True, **kwargs):
         """
         include_optimizer = kwargs.pop("include_optimizer", True)
         save_format = kwargs.pop("save_format", None)
+        sharded = kwargs.pop("sharded", False)
+        shard_size = kwargs.pop("shard_size", None)
         if kwargs:
             raise ValueError(
                 "The following argument(s) are not supported: "
@@ -336,7 +338,7 @@ def save(self, filepath, overwrite=True, **kwargs):
             if not proceed:
                 return
         if str(filepath).endswith(".keras"):
-            saving_lib.save_model(self, filepath)
+            saving_lib.save_model(self, filepath, sharded=sharded, shard_size=shard_size)
         elif str(filepath).endswith((".h5", ".hdf5")):
             # Deprecation warnings
             warnings.warn(
diff --git a/keras/saving/saving_api.py b/keras/saving/saving_api.py
index d0fb5697e3e..9c46edf6382 100644
--- a/keras/saving/saving_api.py
+++ b/keras/saving/saving_api.py
@@ -52,6 +52,8 @@ def save_model(model, filepath, overwrite=True, **kwargs):
     """
     include_optimizer = kwargs.pop("include_optimizer", True)
     save_format = kwargs.pop("save_format", False)
+    sharded = kwargs.pop("sharded", False)
+    shard_size = kwargs.pop("shard_size", None)
     if save_format:
         if str(filepath).endswith((".h5", ".hdf5")) or str(filepath).endswith(
             ".keras"
         ):
@@ -94,7 +96,7 @@ def save_model(model, filepath, overwrite=True, **kwargs):
             proceed = io_utils.ask_to_proceed_with_overwrite(filepath)
             if not proceed:
                 return
-        saving_lib.save_model(model, filepath)
+        saving_lib.save_model(model, filepath, sharded, shard_size)
     elif str(filepath).endswith((".h5", ".hdf5")):
         legacy_h5_format.save_model_to_hdf5(
             model, filepath, overwrite, include_optimizer
@@ -201,17 +203,18 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
 
 
 def load_weights(model, filepath, skip_mismatch=False, **kwargs):
+    sharded = kwargs.pop("sharded", False)
     if str(filepath).endswith(".keras"):
         if kwargs:
             raise ValueError(f"Invalid keyword arguments: {kwargs}")
         saving_lib.load_weights_only(
-            model, filepath, skip_mismatch=skip_mismatch
+            model, filepath, sharded=sharded, skip_mismatch=skip_mismatch
         )
     elif str(filepath).endswith(".weights.h5"):
         if kwargs:
             raise ValueError(f"Invalid keyword arguments: {kwargs}")
         saving_lib.load_weights_only(
-            model, filepath, skip_mismatch=skip_mismatch
+            model, filepath, sharded=sharded, skip_mismatch=skip_mismatch
         )
     elif str(filepath).endswith(".h5") or str(filepath).endswith(".hdf5"):
         by_name = kwargs.pop("by_name", False)
diff --git a/keras/saving/saving_lib.py b/keras/saving/saving_lib.py
index 3806da8afd4..7572b014f10 100644
--- a/keras/saving/saving_lib.py
+++ b/keras/saving/saving_lib.py
@@ -35,7 +35,7 @@
 _ASSETS_DIRNAME = "assets"
 
 
-def save_model(model, filepath, weights_format="h5", sharded=False, max_size="10GB"):
+def save_model(model, filepath, weights_format="h5", sharded=False, shard_size=None):
     """Save a zip-archive representing a Keras model to the given filepath.
 
     The zip-based archive contains the following structure:
@@ -106,6 +106,7 @@
         if weights_format == "h5":
             if sharded:
+                max_size = shard_size if shard_size is not None else "10GB"
                 weights_store = ShardedH5IOStore(
                     _VARS_FNAME + ".h5",
                     archive=zf,
@@ -207,7 +208,7 @@
     return model
 
 
-def save_weights_only(model, filepath, sharded=False, max_size="10GB"):
+def save_weights_only(model, filepath, sharded=False, shard_size=None):
     """Save only the weights of a model to a target filepath (.weights.h5).
 
     Note: only supports h5 for now.
@@ -221,6 +222,7 @@
             f"Received: filepath={filepath}"
         )
     if sharded:
+        max_size = shard_size if shard_size is not None else "10GB"
         weights_store = ShardedH5IOStore(filepath, mode="w", max_size=max_size)
     else:
         weights_store = H5IOStore(filepath, mode="w")
@@ -597,6 +599,13 @@ class ShardedH5IOStore:
     def __init__(self, root_path, max_size="10GB", archive=None, mode="r"):
         self.shard_list = []
+        self.root_path = root_path
+        self.mode = mode
+        self.archive = archive
+        self.io_file = None
+        self.max_size = convert_str_bytes_to_int(max_size)
+        self.current_shard_size = 0
+
         self.var_shard_map_filename = str(root_path).replace(".weights.h5", ".weights.json")
         if not os.path.exists(self.var_shard_map_filename):
             if self.mode == "w":
                 self.var_shard_map = {}
@@ -612,11 +621,7 @@
         else:
             with open(self.var_shard_map_filename, "r") as map_file:
                 self.var_shard_map = json.load(map_file)
-        self.root_path = root_path
-        self.mode = mode
-        self.archive = archive
-        self.io_file = None
-        self.max_size = convert_str_bytes_to_int(max_size)
+
         self.h5_file = self._create_new_file(root_path)
 
     def _create_new_file(self, path):
@@ -627,7 +632,7 @@
             if self.mode == "w":
                 self.io_file = io.BytesIO()
             else:
-                self.io_file = self.archive.open(path "r")
+                self.io_file = self.archive.open(path, "r")
             return h5py.File(self.io_file, mode=self.mode)
         else:
             return h5py.File(path, mode=self.mode)
@@ -640,8 +645,13 @@ def _change_access_file(self, filename):  # Read-only
         else:
             return h5py.File(path, mode=self.mode)
 
-
     def make(self, path):
+        def _get_size(key):
+            if isinstance(self.h5_file[key], h5py.Dataset):
+                self.current_shard_size += self.h5_file[key].nbytes
+
+        self.current_shard_size = 0
+        self.h5_file.visit(_get_size)
         if self.current_shard_size > self.max_size:
             self.close()
             self.shard_list.append(self.h5_file.filename)
diff --git a/keras/saving/saving_lib_test.py b/keras/saving/saving_lib_test.py
index 20871df6aa6..3a95ed947be 100644
--- a/keras/saving/saving_lib_test.py
+++ b/keras/saving/saving_lib_test.py
@@ -520,6 +520,17 @@ def test_partial_load(self):
             np.array(new_model.layers[2].kernel), new_layer_kernel_value
         )
 
+    def test_model_sharding(self):
+        model = _get_basic_functional_model()
+        temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.weights.h5")
+        ref_input = np.random.random((2, 4))
+        ref_output = model.predict(ref_input)
+        saving_lib.save_weights_only(model, temp_filepath, sharded=True, shard_size="50")
+
+        model = _get_basic_functional_model()
+        model.load_weights(temp_filepath, sharded=True)
+        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)
+

From 933c387284bba0da4abc59c89317707336d58bc0 Mon Sep 17 00:00:00 2001
From: Neel Kovelamudi
Date: Mon, 11 Mar 2024 18:57:54 +0000
Subject: [PATCH 5/6] Add large model test and debug sharding algorithm

---
 keras/saving/saving_lib.py      | 10 +++----
 keras/saving/saving_lib_test.py | 53 ++++++++++++++++++++++++++-------
 2 files changed, 47 insertions(+), 16 deletions(-)

diff --git a/keras/saving/saving_lib.py b/keras/saving/saving_lib.py
index 7572b014f10..4ab0febeffc 100644
--- a/keras/saving/saving_lib.py
+++ b/keras/saving/saving_lib.py
@@ -643,7 +643,7 @@ def _change_access_file(self, filename):  # Read-only
             self.io_file = self.archive.open(filename, "r")
             return h5py.File(self.io_file, mode=self.mode)
         else:
-            return h5py.File(path, mode=self.mode)
+            return h5py.File(filename, mode=self.mode)
 
     def make(self, path):
         def _get_size(key):
@@ -653,9 +653,9 @@ def _get_size(key):
         self.current_shard_size = 0
         self.h5_file.visit(_get_size)
         if self.current_shard_size > self.max_size:
-            self.close()
             self.shard_list.append(self.h5_file.filename)
+            self.close()
             self.h5_file = self._create_new_file(self.root_path)
         if not path:
             group = self.h5_file.create_group("vars")
         else:
@@ -670,7 +670,7 @@ def get(self, path):
             return self.h5_file[path]["vars"]
 
         # If not found, check shard map and switch files
-        filename = self.var_shard_map.get(path)
+        filename = self.var_shard_map.get(path) or self.var_shard_map.get("/" + path +"/vars")
         if filename is not None and self.h5_file.name != filename:
             new_file = self._change_access_file(filename)
             if "vars" in new_file[path]:
@@ -707,7 +707,7 @@ def resolve_duplicate_filename(path, path_list):
     pattern = re.compile("_\d\.weights\.h5")
     pre_duplicate = pattern.split(path)[0]  # Check for pre-existing duplicate
     if not pre_duplicate.endswith(".weights.h5"):
-        match_list = filter(lambda x: x.startswith(pre_duplicate), path_list)
+        match_list = list(filter(lambda x: x.startswith(pre_duplicate), path_list))
         if len(match_list) > 1:
             return pre_duplicate + "_" + str(len(match_list)) + ".weights.h5"
     return path.replace(".weights.h5", "_1.weights.h5")
diff --git a/keras/saving/saving_lib_test.py b/keras/saving/saving_lib_test.py
index 3a95ed947be..c5b6c4310aa 100644
--- a/keras/saving/saving_lib_test.py
+++ b/keras/saving/saving_lib_test.py
@@ -520,17 +520,6 @@ def test_partial_load(self):
             np.array(new_model.layers[2].kernel), new_layer_kernel_value
         )
 
-    def test_model_sharding(self):
-        model = _get_basic_functional_model()
-        temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.weights.h5")
-        ref_input = np.random.random((2, 4))
-        ref_output = model.predict(ref_input)
-        saving_lib.save_weights_only(model, temp_filepath, sharded=True, shard_size="50")
-
-        model = _get_basic_functional_model()
-        model.load_weights(temp_filepath, sharded=True)
-        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)
-
 
 @pytest.mark.requires_trainable_backend
 class SavingAPITest(testing.TestCase):
@@ -699,6 +688,48 @@ def call(self, inputs):
         return self.first_layer(self.second_layer(inputs))
 
 
+def _get_large_model():
+    model = keras.Sequential(
+        [
+            keras.layers.Input(shape=[28, 28, 1], dtype="float32"),
+            keras.layers.Conv2D(filters=12, kernel_size=3, padding='same', name="conv1", use_bias=False),  # no bias necessary before batch norm
+            keras.layers.BatchNormalization(scale=False, center=True),  # no batch norm scaling necessary before "relu"
+            keras.layers.Activation('relu'),  # activation after batch norm
+
+            keras.layers.Conv2D(filters=24, kernel_size=6, padding='same', name="conv2", use_bias=False, strides=2),
+            keras.layers.BatchNormalization(scale=False, center=True),
+            keras.layers.Activation('relu'),
+
+            keras.layers.Conv2D(filters=32, kernel_size=6, padding='same', name="conv3", use_bias=False, strides=2),
+            keras.layers.BatchNormalization(scale=False, center=True),
+            keras.layers.Activation('relu'),
+
+            keras.layers.Flatten(),
+            keras.layers.Dense(200, name="dense1", use_bias=False),
+            keras.layers.BatchNormalization(scale=False, center=True),
+            keras.layers.Activation('relu'),
+            keras.layers.Dropout(0.4),  # Dropout on dense layer only
+
+            keras.layers.Dense(10, name="dense2", activation='softmax')
+        ]
+    )
+    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])
+    return model
+
+
+class LargeModelTest(testing.TestCase):
+    def test_model_sharding(self):
+        model = _get_large_model()
+        temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.weights.h5")
+        ref_input = np.random.random((1, 28, 28, 1))
+        ref_output = model.predict(ref_input)
+        saving_lib.save_weights_only(model, temp_filepath, sharded=True, shard_size="10KB")
+
+        model = _get_large_model()
+        model.load_weights(temp_filepath, sharded=True)
+        self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6)
+
+
 class SavingBattleTest(testing.TestCase):
     def test_custom_object_without_from_config(self):
         temp_filepath = os.path.join(

From cfbb761e9538352848afc82dd27eec5a77769dde Mon Sep 17 00:00:00 2001
From: Neel Kovelamudi
Date: Mon, 11 Mar 2024 20:56:14 +0000
Subject: [PATCH 6/6] Fix formatting

---
 keras/saving/saving_api.py      |  7 +++-
 keras/saving/saving_lib.py      | 38 +++++++++++------
 keras/saving/saving_lib_test.py | 74 ++++++++++++++++++++++-----------
 3 files changed, 80 insertions(+), 39 deletions(-)

diff --git a/keras/saving/saving_api.py b/keras/saving/saving_api.py
index 5122c88adc4..c8c2a85b156 100644
--- a/keras/saving/saving_api.py
+++ b/keras/saving/saving_api.py
@@ -99,7 +99,12 @@ def save_model(model, filepath, overwrite=True, **kwargs):
             return
 
     if str(filepath).endswith(".keras"):
-        saving_lib.save_model(model, filepath, sharded, shard_size)
+        saving_lib.save_model(
+            model,
+            filepath,
+            sharded=sharded,
+            shard_size=shard_size,
+        )
     elif str(filepath).endswith((".h5", ".hdf5")):
         legacy_h5_format.save_model_to_hdf5(
             model, filepath, overwrite, include_optimizer
diff --git a/keras/saving/saving_lib.py b/keras/saving/saving_lib.py
index 0c8073f0bfe..fdccd32bc9f 100644
--- a/keras/saving/saving_lib.py
+++ b/keras/saving/saving_lib.py
@@ -3,11 +3,11 @@
 import datetime
 import io
 import json
+import os
+import re
 import tempfile
 import warnings
 import zipfile
-import re
-import os
 
 import ml_dtypes
 import numpy as np
@@ -37,7 +37,9 @@
 _ASSETS_DIRNAME = "assets"
 
 
-def save_model(model, filepath, weights_format="h5", sharded=False, shard_size=None):
+def save_model(
+    model, filepath, weights_format="h5", sharded=False, shard_size=None
+):
     """Save a zip-archive representing a Keras model to the given filepath.
 
     The zip-based archive contains the following structure:
@@ -116,7 +118,9 @@ def save_model(
                     max_size=max_size,
                 )
             else:
-                weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="w")
+                weights_store = H5IOStore(
+                    _VARS_FNAME + ".h5", archive=zf, mode="w"
+                )
         elif weights_format == "npz":
             weights_store = NpzIOStore(
                 _VARS_FNAME + ".npz", archive=zf, mode="w"
@@ -182,7 +186,9 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True):
                     mode="r",
                 )
             else:
-                weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="r")
+                weights_store = H5IOStore(
+                    _VARS_FNAME + ".h5", archive=zf, mode="r"
+                )
         elif _VARS_FNAME + ".npz" in all_filenames:
             weights_store = NpzIOStore(
                 _VARS_FNAME + ".npz", archive=zf, mode="r"
@@ -666,17 +672,20 @@ def __init__(self, root_path, max_size="10GB", archive=None, mode="r"):
         self.max_size = convert_str_bytes_to_int(max_size)
         self.current_shard_size = 0
 
-        self.var_shard_map_filename = str(root_path).replace(".weights.h5", ".weights.json")
+        self.var_shard_map_filename = str(root_path).replace(
+            ".weights.h5", ".weights.json"
+        )
         if not os.path.exists(self.var_shard_map_filename):
             if self.mode == "w":
                 self.var_shard_map = {}
-            if self.mode =="r":
+            if self.mode == "r":
                 raise FileNotFoundError(
                     f"Loading a sharded `.weights.h5` file requires "
                     "its corresponding sharding map JSON file "
                     f"{self.var_shard_map_filename} in the same directory. "
-                    "Please ensure all weights files and the sharding map JSON file "
-                    "are in the same directory when loading a sharded weights file."
+                    "Please ensure all weights files and the sharding map "
+                    "JSON file are in the same directory when loading a "
+                    "sharded weights file."
) else: with open(self.var_shard_map_filename, "r") as map_file: @@ -730,7 +739,9 @@ def get(self, path): return self.h5_file[path]["vars"] # If not found, check shard map and switch files - filename = self.var_shard_map.get(path) or self.var_shard_map.get("/" + path +"/vars") + filename = self.var_shard_map.get(path) or self.var_shard_map.get( + "/" + path + "/vars" + ) if filename is not None and self.h5_file.name != filename: new_file = self._change_access_file(filename) if "vars" in new_file[path]: @@ -749,7 +760,6 @@ def close(self): self.io_file.close() - def convert_str_bytes_to_int(size): if size.upper().endswith("GB"): return int(size[:-2]) * (10**9) @@ -764,10 +774,12 @@ def convert_str_bytes_to_int(size): def resolve_duplicate_filename(path, path_list): - pattern = re.compile("_\d\.weights\.h5") + pattern = re.compile(r"_\d\.weights\.h5") pre_duplicate = pattern.split(path)[0] # Check for pre-existing duplicate if not pre_duplicate.endswith(".weights.h5"): - match_list = list(filter(lambda x: x.startswith(pre_duplicate), path_list)) + match_list = list( + filter(lambda x: x.startswith(pre_duplicate), path_list) + ) if len(match_list) > 1: return pre_duplicate + "_" + str(len(match_list)) + ".weights.h5" return path.replace(".weights.h5", "_1.weights.h5") diff --git a/keras/saving/saving_lib_test.py b/keras/saving/saving_lib_test.py index e78d364a5c1..9cd9c14c0b8 100644 --- a/keras/saving/saving_lib_test.py +++ b/keras/saving/saving_lib_test.py @@ -756,30 +756,50 @@ def call(self, inputs): def _get_large_model(): model = keras.Sequential( - [ - keras.layers.Input(shape=[28, 28, 1], dtype="float32"), - keras.layers.Conv2D(filters=12, kernel_size=3, padding='same', name="conv1", use_bias=False), # no bias necessary before batch norm - keras.layers.BatchNormalization(scale=False, center=True), # no batch norm scaling necessary before "relu" - keras.layers.Activation('relu'), # activation after batch norm - - keras.layers.Conv2D(filters=24, kernel_size=6, padding='same', name="conv2", use_bias=False, strides=2), - keras.layers.BatchNormalization(scale=False, center=True), - keras.layers.Activation('relu'), - - keras.layers.Conv2D(filters=32, kernel_size=6, padding='same', name="conv3", use_bias=False, strides=2), - keras.layers.BatchNormalization(scale=False, center=True), - keras.layers.Activation('relu'), - - keras.layers.Flatten(), - keras.layers.Dense(200, name="dense1", use_bias=False), - keras.layers.BatchNormalization(scale=False, center=True), - keras.layers.Activation('relu'), - keras.layers.Dropout(0.4), # Dropout on dense layer only - - keras.layers.Dense(10, name="dense2", activation='softmax') - ] + [ + keras.layers.Input(shape=[28, 28, 1], dtype="float32"), + keras.layers.Conv2D( + filters=12, + kernel_size=3, + padding="same", + name="conv1", + use_bias=False, + ), # no bias necessary before batch norm + keras.layers.BatchNormalization( + scale=False, center=True + ), # no batch norm scaling necessary before "relu" + keras.layers.Activation("relu"), # activation after batch norm + keras.layers.Conv2D( + filters=24, + kernel_size=6, + padding="same", + name="conv2", + use_bias=False, + strides=2, + ), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + keras.layers.Conv2D( + filters=32, + kernel_size=6, + padding="same", + name="conv3", + use_bias=False, + strides=2, + ), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + keras.layers.Flatten(), + keras.layers.Dense(200, 
name="dense1", use_bias=False), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + keras.layers.Dropout(0.4), # Dropout on dense layer only + keras.layers.Dense(10, name="dense2", activation="softmax"), + ] + ) + model.compile( + optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"] ) - model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"]) return model @@ -789,7 +809,9 @@ def test_model_sharding(self): temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.weights.h5") ref_input = np.random.random((1, 28, 28, 1)) ref_output = model.predict(ref_input) - saving_lib.save_weights_only(model, temp_filepath, sharded=True, shard_size="10KB") + saving_lib.save_weights_only( + model, temp_filepath, sharded=True, shard_size="1MB" + ) model = _get_large_model() model.load_weights(temp_filepath, sharded=True) @@ -851,7 +873,9 @@ def dense(self): def call(self, x): return self.dense(x) - temp_filepath = "normal_model.weights.h5" + temp_filepath = os.path.join( + self.get_temp_dir(), "normal_model.weights.h5" + ) model_a = NormalModel() model_a(np.random.random((2, 2))) model_a.save_weights(temp_filepath)