diff --git a/keras/saving/saving_api.py b/keras/saving/saving_api.py index 888219cc3a1..c8c2a85b156 100644 --- a/keras/saving/saving_api.py +++ b/keras/saving/saving_api.py @@ -52,6 +52,8 @@ def save_model(model, filepath, overwrite=True, **kwargs): """ include_optimizer = kwargs.pop("include_optimizer", True) save_format = kwargs.pop("save_format", False) + sharded = kwargs.pop("sharded", False) + shard_size = kwargs.pop("shard_size", None) if save_format: if str(filepath).endswith((".h5", ".hdf5")) or str(filepath).endswith( ".keras" @@ -97,7 +99,12 @@ def save_model(model, filepath, overwrite=True, **kwargs): return if str(filepath).endswith(".keras"): - saving_lib.save_model(model, filepath) + saving_lib.save_model( + model, + filepath, + sharded=sharded, + shard_size=shard_size, + ) elif str(filepath).endswith((".h5", ".hdf5")): legacy_h5_format.save_model_to_hdf5( model, filepath, overwrite, include_optimizer @@ -204,17 +211,18 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True): def load_weights(model, filepath, skip_mismatch=False, **kwargs): + sharded = kwargs.pop("sharded", False) if str(filepath).endswith(".keras"): if kwargs: raise ValueError(f"Invalid keyword arguments: {kwargs}") saving_lib.load_weights_only( - model, filepath, skip_mismatch=skip_mismatch + model, filepath, sharded=sharded, skip_mismatch=skip_mismatch ) elif str(filepath).endswith(".weights.h5"): if kwargs: raise ValueError(f"Invalid keyword arguments: {kwargs}") saving_lib.load_weights_only( - model, filepath, skip_mismatch=skip_mismatch + model, filepath, sharded=sharded, skip_mismatch=skip_mismatch ) elif str(filepath).endswith(".h5") or str(filepath).endswith(".hdf5"): by_name = kwargs.pop("by_name", False) diff --git a/keras/saving/saving_lib.py b/keras/saving/saving_lib.py index 8e0f29a0cc8..fdccd32bc9f 100644 --- a/keras/saving/saving_lib.py +++ b/keras/saving/saving_lib.py @@ -3,6 +3,8 @@ import datetime import io import json +import os +import re import tempfile import warnings import zipfile @@ -35,7 +37,9 @@ _ASSETS_DIRNAME = "assets" -def save_model(model, filepath, weights_format="h5"): +def save_model( + model, filepath, weights_format="h5", sharded=False, shard_size=None +): """Save a zip-archive representing a Keras model to the given filepath. The zip-based archive contains the following structure: @@ -67,6 +71,12 @@ def save_model(model, filepath, weights_format="h5"): ) if weights_format == "h5" and h5py is None: raise ImportError("h5py must be installed in order to save a model.") + if weights_format != "h5" and sharded: + raise NotImplementedError( + "Sharding is only currently supported in the H5 weights format. " + "Please pass `sharded=False` or switch to `weights_format=h5`. " + f"Received: weights_format={weights_format}, sharded={sharded}." + ) if not model.built: warnings.warn( @@ -99,7 +109,18 @@ def save_model(model, filepath, weights_format="h5"): f.write(config_json.encode()) if weights_format == "h5": - weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="w") + if sharded: + max_size = shard_size if shard_size is not None else "10GB" + weights_store = ShardedH5IOStore( + _VARS_FNAME + ".h5", + archive=zf, + mode="w", + max_size=max_size, + ) + else: + weights_store = H5IOStore( + _VARS_FNAME + ".h5", archive=zf, mode="w" + ) elif weights_format == "npz": weights_store = NpzIOStore( _VARS_FNAME + ".npz", archive=zf, mode="w" @@ -158,7 +179,16 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True): all_filenames = zf.namelist() if _VARS_FNAME + ".h5" in all_filenames: - weights_store = H5IOStore(_VARS_FNAME + ".h5", archive=zf, mode="r") + if _VARS_FNAME + ".json" in all_filenames: + weights_store = ShardedH5IOStore( + _VARS_FNAME + ".h5", + archive=zf, + mode="r", + ) + else: + weights_store = H5IOStore( + _VARS_FNAME + ".h5", archive=zf, mode="r" + ) elif _VARS_FNAME + ".npz" in all_filenames: weights_store = NpzIOStore( _VARS_FNAME + ".npz", archive=zf, mode="r" @@ -193,7 +223,7 @@ def load_model(filepath, custom_objects=None, compile=True, safe_mode=True): return model -def save_weights_only(model, filepath): +def save_weights_only(model, filepath, sharded=False, shard_size=None): """Save only the weights of a model to a target filepath (.weights.h5). Note: only supports h5 for now. @@ -206,7 +236,11 @@ def save_weights_only(model, filepath): "Invalid `filepath` argument: expected a `.weights.h5` extension. " f"Received: filepath={filepath}" ) - weights_store = H5IOStore(filepath, mode="w") + if sharded: + max_size = shard_size if shard_size is not None else "10GB" + weights_store = ShardedH5IOStore(filepath, mode="w", max_size=max_size) + else: + weights_store = H5IOStore(filepath, mode="w") _save_state( model, weights_store=weights_store, @@ -217,7 +251,7 @@ def save_weights_only(model, filepath): weights_store.close() -def load_weights_only(model, filepath, skip_mismatch=False): +def load_weights_only(model, filepath, sharded=False, skip_mismatch=False): """Load the weights of a model from a filepath (.keras or .weights.h5). Note: only supports h5 for now. @@ -227,12 +261,23 @@ def load_weights_only(model, filepath, skip_mismatch=False): filepath = str(filepath) if filepath.endswith(".weights.h5"): # TODO: download file if h5 filepath is remote - weights_store = H5IOStore(filepath, mode="r") + if sharded: + weights_store = ShardedH5IOStore(filepath, mode="r") + else: + weights_store = H5IOStore(filepath, mode="r") elif filepath.endswith(".keras"): archive = zipfile.ZipFile(filepath, "r") - weights_store = H5IOStore( - _VARS_FNAME + ".h5", archive=archive, mode="r" - ) + all_filenames = archive.namelist() + if _VARS_FNAME + ".json" in all_filenames: + weights_store = ShardedH5IOStore( + _VARS_FNAME + ".h5", + archive=archive, + mode="r", + ) + else: + weights_store = H5IOStore( + _VARS_FNAME + ".h5", archive=archive, mode="r" + ) failed_trackables = set() error_msgs = {} @@ -617,6 +662,138 @@ def close(self): self.io_file.close() +class ShardedH5IOStore: + def __init__(self, root_path, max_size="10GB", archive=None, mode="r"): + self.shard_list = [] + self.root_path = root_path + self.mode = mode + self.archive = archive + self.io_file = None + self.max_size = convert_str_bytes_to_int(max_size) + self.current_shard_size = 0 + + self.var_shard_map_filename = str(root_path).replace( + ".weights.h5", ".weights.json" + ) + if not os.path.exists(self.var_shard_map_filename): + if self.mode == "w": + self.var_shard_map = {} + if self.mode == "r": + raise FileNotFoundError( + f"Loading a sharded `.weights.h5` file requires " + "its corresponding sharding map JSON file " + f"{self.var_shard_map_filename} in the same directory. " + "Please ensure all weights files and the sharding map " + "JSON file are in the same directory when loading a " + "sharded weights file." + ) + else: + with open(self.var_shard_map_filename, "r") as map_file: + self.var_shard_map = json.load(map_file) + + self.h5_file = self._create_new_file(root_path) + + def _create_new_file(self, path): + if path in self.shard_list: + path = resolve_duplicate_filename(str(path), self.shard_list) + self.root_path = path + if self.archive: + if self.mode == "w": + self.io_file = io.BytesIO() + else: + self.io_file = self.archive.open(path, "r") + return h5py.File(self.io_file, mode=self.mode) + else: + return h5py.File(path, mode=self.mode) + + def _change_access_file(self, filename): # Read-only + self.close() + if self.archive: + self.io_file = self.archive.open(filename, "r") + return h5py.File(self.io_file, mode=self.mode) + else: + return h5py.File(filename, mode=self.mode) + + def make(self, path): + def _get_size(key): + if isinstance(self.h5_file[key], h5py.Dataset): + self.current_shard_size += self.h5_file[key].nbytes + + self.current_shard_size = 0 + self.h5_file.visit(_get_size) + if self.current_shard_size > self.max_size: + self.shard_list.append(self.h5_file.filename) + self.close() + self.h5_file = self._create_new_file(self.root_path) + if not path: + group = self.h5_file.create_group("vars") + else: + group = self.h5_file.create_group(path).create_group("vars") + self.var_shard_map[group.name] = self.root_path + return group + + def get(self, path): + if not path: + return self.h5_file["vars"] + if path in self.h5_file and "vars" in self.h5_file[path]: + return self.h5_file[path]["vars"] + + # If not found, check shard map and switch files + filename = self.var_shard_map.get(path) or self.var_shard_map.get( + "/" + path + "/vars" + ) + if filename is not None and self.h5_file.name != filename: + new_file = self._change_access_file(filename) + if "vars" in new_file[path]: + self.h5_file = new_file + return self.h5_file[path]["vars"] + return {} + + def close(self): + self.h5_file.close() + if self.mode == "w": + with open(self.var_shard_map_filename, "w") as map_file: + map_file.write(json.dumps(self.var_shard_map)) + if self.archive: + self.archive.writestr(self.root_path, self.io_file.getvalue()) + if self.io_file: + self.io_file.close() + + +def convert_str_bytes_to_int(size): + if size.upper().endswith("GB"): + return int(size[:-2]) * (10**9) + if size.upper().endswith("MB"): + return int(size[:-2]) * (10**6) + if size.upper().endswith("KB"): + return int(size[:-2]) * (10**3) + raise ValueError( + "Invalid format for `size`. Use an integer followed by the unit " + "(GB, MB, or KB). For example, '5GB' or '15MB'." + ) + + +def resolve_duplicate_filename(path, path_list): + pattern = re.compile(r"_\d\.weights\.h5") + pre_duplicate = pattern.split(path)[0] # Check for pre-existing duplicate + if not pre_duplicate.endswith(".weights.h5"): + match_list = list( + filter(lambda x: x.startswith(pre_duplicate), path_list) + ) + if len(match_list) > 1: + return pre_duplicate + "_" + str(len(match_list)) + ".weights.h5" + return path.replace(".weights.h5", "_1.weights.h5") + + +def dtype_to_bytes(dtype): + if "bool" in str(dtype): + return 1 / 8 + bits = re.search(r"[^\d](\d+)$", str(dtype)) + if bits is None: + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") + return int(bits.groups()[0]) // 8 # Bit size in bytes + + class H5Entry: """Leaf entry in a H5IOStore.""" diff --git a/keras/saving/saving_lib_test.py b/keras/saving/saving_lib_test.py index c77c8e28615..9cd9c14c0b8 100644 --- a/keras/saving/saving_lib_test.py +++ b/keras/saving/saving_lib_test.py @@ -754,6 +754,70 @@ def call(self, inputs): return self.first_layer(self.second_layer(inputs)) +def _get_large_model(): + model = keras.Sequential( + [ + keras.layers.Input(shape=[28, 28, 1], dtype="float32"), + keras.layers.Conv2D( + filters=12, + kernel_size=3, + padding="same", + name="conv1", + use_bias=False, + ), # no bias necessary before batch norm + keras.layers.BatchNormalization( + scale=False, center=True + ), # no batch norm scaling necessary before "relu" + keras.layers.Activation("relu"), # activation after batch norm + keras.layers.Conv2D( + filters=24, + kernel_size=6, + padding="same", + name="conv2", + use_bias=False, + strides=2, + ), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + keras.layers.Conv2D( + filters=32, + kernel_size=6, + padding="same", + name="conv3", + use_bias=False, + strides=2, + ), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + keras.layers.Flatten(), + keras.layers.Dense(200, name="dense1", use_bias=False), + keras.layers.BatchNormalization(scale=False, center=True), + keras.layers.Activation("relu"), + keras.layers.Dropout(0.4), # Dropout on dense layer only + keras.layers.Dense(10, name="dense2", activation="softmax"), + ] + ) + model.compile( + optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"] + ) + return model + + +class LargeModelTest(testing.TestCase): + def test_model_sharding(self): + model = _get_large_model() + temp_filepath = os.path.join(self.get_temp_dir(), "mymodel.weights.h5") + ref_input = np.random.random((1, 28, 28, 1)) + ref_output = model.predict(ref_input) + saving_lib.save_weights_only( + model, temp_filepath, sharded=True, shard_size="1MB" + ) + + model = _get_large_model() + model.load_weights(temp_filepath, sharded=True) + self.assertAllClose(model.predict(ref_input), ref_output, atol=1e-6) + + class SavingBattleTest(testing.TestCase): def test_custom_object_without_from_config(self): temp_filepath = os.path.join( @@ -809,7 +873,9 @@ def dense(self): def call(self, x): return self.dense(x) - temp_filepath = "normal_model.weights.h5" + temp_filepath = os.path.join( + self.get_temp_dir(), "normal_model.weights.h5" + ) model_a = NormalModel() model_a(np.random.random((2, 2))) model_a.save_weights(temp_filepath)