Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion keras/src/layers/preprocessing/index_lookup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from keras.src import backend
from keras.src.layers.layer import Layer
from keras.src.saving import serialization_lib
from keras.src.utils import argument_validation
from keras.src.utils import numerical_utils
from keras.src.utils import tf_utils
Expand Down Expand Up @@ -178,7 +179,12 @@ def __init__(
self.vocabulary_dtype = tf.as_dtype(vocabulary_dtype).name
self._frozen_vocab_size = kwargs.pop("vocabulary_size", None)

self.input_vocabulary = vocabulary
# Remember original `vocabulary` as `input_vocabulary` for serialization
# via `get_config`. However, if `vocabulary` is a file path or a URL, we
# serialize the vocabulary as an asset and clear the original path/URL.
self.input_vocabulary = (
vocabulary if not isinstance(vocabulary, str) else None
)
self.input_idf_weights = idf_weights

# We set this hidden attr to
Expand Down Expand Up @@ -382,6 +388,18 @@ def set_vocabulary(self, vocabulary, idf_weights=None):
)

if isinstance(vocabulary, str):
if serialization_lib.in_safe_mode():
raise ValueError(
"Requested the loading of a vocabulary file outside of the "
"model archive. This carries a potential risk of loading "
"arbitrary and sensitive files and thus it is disallowed "
"by default. If you trust the source of the artifact, you "
"can override this error by passing `safe_mode=False` to "
"the loading function, or calling "
"`keras.config.enable_unsafe_deserialization(). "
f"Vocabulary file: '{vocabulary}'"
)

if not tf.io.gfile.exists(vocabulary):
raise ValueError(
f"Vocabulary file {vocabulary} does not exist."
Expand Down
38 changes: 38 additions & 0 deletions keras/src/layers/preprocessing/string_lookup_test.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
import os

import numpy as np
import pytest
from tensorflow import data as tf_data

from keras.src import backend
from keras.src import layers
from keras.src import models
from keras.src import saving
from keras.src import testing
from keras.src.ops import convert_to_tensor

Expand All @@ -19,6 +23,40 @@ def test_config(self):
mask_token="[MASK]",
)
self.run_class_serialization_test(layer)
self.assertEqual(layer.get_config()["vocabulary"], ["a", "b", "c"])

def test_vocabulary_file(self):
temp_dir = self.get_temp_dir()
vocab_path = os.path.join(temp_dir, "vocab.txt")
with open(vocab_path, "w") as file:
file.write("a\nb\nc\n")

layer = layers.StringLookup(
output_mode="int",
vocabulary=vocab_path,
oov_token="[OOV]",
mask_token="[MASK]",
name="index",
)
self.assertEqual(
[str(v) for v in layer.get_vocabulary()],
["[MASK]", "[OOV]", "a", "b", "c"],
)
self.assertIsNone(layer.get_config().get("vocabulary", None))

# Make sure vocabulary comes from the archive, not the original file.
os.remove(vocab_path)

model = models.Sequential([layer])
model_path = os.path.join(temp_dir, "test_model.keras")
model.save(model_path)

reloaded_model = saving.load_model(model_path)
reloaded_layer = reloaded_model.get_layer("index")
self.assertEqual(
[str(v) for v in reloaded_layer.get_vocabulary()],
["[MASK]", "[OOV]", "a", "b", "c"],
)

def test_adapt_flow(self):
layer = layers.StringLookup(
Expand Down