Skip to content

Commit

Permalink
Use var.path rather than var.name in legacy h5 saving.
Browse files (browse the repository at this point in the history)
  • Loading branch information
fchollet committed Apr 12, 2024
1 parent 5c60112 commit 2c30d86
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
12 changes: 6 additions & 6 deletions keras/legacy/saving/legacy_h5_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def save_subset_weights_to_hdf5_group(f, weights):
weights: List of weight variables.
"""
weight_values = [backend.convert_to_numpy(w) for w in weights]
weight_names = [w.name.encode("utf8") for w in weights]
weight_names = [str(w.path).encode("utf8") for w in weights]
save_attributes_to_hdf5_group(f, "weight_names", weight_names)
for name, val in zip(weight_names, weight_values):
param_dset = f.create_dataset(name, val.shape, dtype=val.dtype)
Expand All @@ -255,7 +255,7 @@ def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer):
symbolic_weights = getattr(optimizer, "weights")
if symbolic_weights:
weights_group = hdf5_group.create_group("optimizer_weights")
weight_names = [str(w.name).encode("utf8") for w in symbolic_weights]
weight_names = [str(w.path).encode("utf8") for w in symbolic_weights]
save_attributes_to_hdf5_group(
weights_group, "weight_names", weight_names
)
Expand Down Expand Up @@ -464,7 +464,7 @@ def load_weights_from_hdf5_group_by_name(f, model, skip_mismatch=False):
warnings.warn(
f"Skipping loading weights for layer #{k} (named "
f"{layer.name}) due to mismatch in shape for "
f"weight {symbolic_weights[i].name}. "
f"weight {symbolic_weights[i].path}. "
f"Weight expects shape {expected_shape}. "
"Received saved weight "
f"with shape {received_shape}",
Expand All @@ -473,7 +473,7 @@ def load_weights_from_hdf5_group_by_name(f, model, skip_mismatch=False):
continue
raise ValueError(
f"Shape mismatch in layer #{k} (named {layer.name}) "
f"for weight {symbolic_weights[i].name}. "
f"for weight {symbolic_weights[i].path}. "
f"Weight expects shape {expected_shape}. "
"Received saved weight "
f"with shape {received_shape}"
Expand Down Expand Up @@ -513,7 +513,7 @@ def load_weights_from_hdf5_group_by_name(f, model, skip_mismatch=False):
warnings.warn(
"Skipping loading top-level weight for model due "
"to mismatch in shape for "
f"weight {symbolic_weights[i].name}. "
f"weight {symbolic_weights[i].path}. "
f"Weight expects shape {expected_shape}. "
"Received saved weight "
f"with shape {received_shape}",
Expand All @@ -522,7 +522,7 @@ def load_weights_from_hdf5_group_by_name(f, model, skip_mismatch=False):
else:
raise ValueError(
"Shape mismatch in model for top-level weight "
f"{symbolic_weights[i].name}. "
f"{symbolic_weights[i].path}. "
f"Weight expects shape {expected_shape}. "
"Received saved weight "
f"with shape {received_shape}"
Expand Down
16 changes: 16 additions & 0 deletions keras/saving/saving_lib_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,3 +845,19 @@ def call(self, x):
self.assertAllClose(
model_a.dense.kernel.numpy(), model_b.dense.kernel.numpy()
)

def test_legacy_h5_format(self):
temp_filepath = os.path.join(self.get_temp_dir(), "custom_model.h5")

inputs = keras.Input((32,))
x = MyDense(2)(inputs)
outputs = CustomModelX()(x)
model = keras.Model(inputs, outputs)

x = np.random.random((1, 32))
ref_out = model(x)

model.save(temp_filepath)
new_model = keras.saving.load_model(temp_filepath)
out = new_model(x)
self.assertAllClose(ref_out, out, atol=1e-6)

0 comments on commit 2c30d86

Please sign in to comment.