24 changes: 14 additions & 10 deletions keras_nlp/layers/modeling/reversible_embedding.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import numpy as np
-
 from keras_nlp.api_export import keras_nlp_export
 from keras_nlp.backend import keras
 from keras_nlp.backend import ops
@@ -140,18 +138,24 @@ def get_config(self):
         )
         return config
 
+    def save_own_variables(self, store):
+        if not self.built:
+            return
+        super().save_own_variables(store)
+        # Before Keras 3.2, the reverse weight is saved in the super() call.
+        # After Keras 3.2, the reverse weight must be saved manually.
+        if len(store.keys()) < len(self.weights):
+            # Store the reverse embedding as the last weight.
+            store[str(len(store.keys()))] = self.reverse_embeddings
+
     def load_own_variables(self, store):
         if not self.built:
             self.build()
-        self.embeddings.assign(store["0"])
+        super().load_own_variables(store)
         if not self.tie_weights:
-            # Handle the case where saved weights are tied, but the layer
-            # weights untied. We can simply assign the embedding weights to both
-            # variables in this case.
-            if len(store.keys()) == 1:
-                self.reverse_embeddings.assign(np.transpose(store["0"]))
-            else:
-                self.reverse_embeddings.assign(store["1"])
+            # Last weight in the store is the reverse embedding weights.
+            key = str(len(store.keys()) - 1)
+            self.reverse_embeddings.assign(store[key])
 
     def compute_output_spec(self, inputs, reverse=False):
         output_shape = list(inputs.shape)
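The comments in save_own_variables above capture the convention this change settles on: regardless of Keras version, the reverse embedding always ends up as the last key in the variable store. A minimal sketch of that round trip, assuming Keras 3.x with this patch applied (the constructor arguments mirror the tests; everything else is illustrative, not part of the PR):

    from keras_nlp.layers import ReversibleEmbedding

    # Build an untied layer so the reverse embedding is a separate weight.
    layer = ReversibleEmbedding(input_dim=100, output_dim=32, tie_weights=False)
    layer.build()

    store = {}
    layer.save_own_variables(store)
    # The embedding sits at key "0"; the reverse embedding is appended as
    # the last key ("1" here), whether or not super() already saved it.
    print(sorted(store.keys()))  # ["0", "1"]

    # Loading follows the same convention: build if needed, restore via
    # super(), then pull the reverse embedding from the last key.
    fresh = ReversibleEmbedding(input_dim=100, output_dim=32, tie_weights=False)
    fresh.load_own_variables(store)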
41 changes: 22 additions & 19 deletions keras_nlp/layers/modeling/reversible_embedding_test.py
@@ -44,6 +44,28 @@ def test_layer_behaviors_tied(self, tie_weights):
             expected_num_trainable_weights=1 if tie_weights else 2,
         )
 
+    @parameterized.named_parameters(
+        ("tie_weights", True),
+        ("untie_weights", False),
+    )
+    def test_saving(self, tie_weights):
+        input_data = random.randint(minval=0, maxval=100, shape=(4, 10))
+        model = keras.Sequential(
+            [
+                ReversibleEmbedding(
+                    input_dim=100,
+                    output_dim=32,
+                    tie_weights=tie_weights,
+                )
+            ]
+        )
+        path = os.path.join(self.get_temp_dir(), "model.keras")
+        model_output = model(input_data)
+        model.save(path, save_format="keras_v3")
+        restored_model = keras.models.load_model(path)
+        restored_output = restored_model(input_data)
+        self.assertAllClose(model_output, restored_output)
+
     def test_correctness(self):
         layer = ReversibleEmbedding(input_dim=3, output_dim=2)
         layer.build()
@@ -57,25 +79,6 @@ def test_correctness(self):
         out = layer(np.array(([[1.0, 1.0]])), reverse=True)
         self.assertAllClose(out, np.array([[0.0, 4.0, 6.0]]))
 
-    def test_tied_checkpoint_untied_weights(self):
-        embedding = ReversibleEmbedding(100, 16, tie_weights=True)
-        inputs = keras.Input(shape=(10,), dtype="int32")
-        hidden_states = embedding(inputs)
-        outputs = embedding(hidden_states, reverse=True)
-        tied_model = keras.Model(inputs, outputs)
-        path = os.path.join(self.get_temp_dir(), "checkpoint.weights.h5")
-        tied_model.save_weights(path)
-
-        embedding = ReversibleEmbedding(100, 16, tie_weights=False)
-        inputs = keras.Input(shape=(10,), dtype="int32")
-        hidden_states = embedding(inputs)
-        outputs = embedding(hidden_states, reverse=True)
-        untied_model = keras.Model(inputs, outputs)
-        untied_model.load_weights(path)
-
-        input_data = ops.ones(shape=(4, 10), dtype="int32")
-        self.assertAllClose(untied_model(input_data), tied_model(input_data))
-
     def test_reverse_dtype(self):
         embedding = ReversibleEmbedding(100, 16, reverse_dtype="float32")
         input_data = ops.ones(shape=(4, 10, 16))
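The retained test_correctness context above pins down what reverse=True computes: a matmul against the transpose of the embedding matrix, projecting hidden states back onto the vocabulary. A small illustrative sketch with toy values chosen to reproduce the asserted output (the embeddings assignment below is inferred from the test's expected result, since that setup line is elided from the diff):

    import numpy as np
    from keras_nlp.layers import ReversibleEmbedding

    layer = ReversibleEmbedding(input_dim=3, output_dim=2, tie_weights=True)
    layer.build()
    layer.embeddings.assign(np.array([[0.0, 0.0], [2.0, 2.0], [3.0, 3.0]]))

    # reverse=True multiplies by the transposed embeddings:
    # [[1, 1]] @ [[0, 2, 3], [0, 2, 3]].T-free view == [[0, 4, 6]]
    logits = layer(np.array([[1.0, 1.0]]), reverse=True)
    print(np.asarray(logits))  # [[0. 4. 6.]]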