6 changes: 5 additions & 1 deletion keras_nlp/src/models/preprocessor.py
@@ -89,6 +89,7 @@ def presets(cls):
def from_preset(
cls,
preset,
load_task_extras=False,
**kwargs,
):
"""Instantiate a `keras_nlp.models.Preprocessor` from a model preset.
@@ -112,6 +113,9 @@ def from_preset(
Args:
preset: string. A built in preset identifier, a Kaggle Models
handle, a Hugging Face handle, or a path to a local directory.
load_task_extras: bool. If `True`, load the saved task preprocessing
configuration from a `preprocessing.json`. You might use this to
restore the sequence length a model was fine-tuned with.

Examples:
```python
@@ -138,7 +142,7 @@ def from_preset(
# Detect the correct subclass if we need to.
if cls.backbone_cls != backbone_cls:
cls = find_subclass(preset, cls, backbone_cls)
return loader.load_preprocessor(cls, **kwargs)
return loader.load_preprocessor(cls, load_task_extras, **kwargs)

def save_to_preset(self, preset_dir):
"""Save preprocessor to a preset directory.
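For context, a minimal usage sketch of the new preprocessor-level argument. This is a sketch only: it assumes `BertTextClassifierPreprocessor` is exported under `keras_nlp.models`, and the local preset path is hypothetical.

```python
import keras_nlp

# Library defaults: build the preprocessor from the tokenizer assets only.
preprocessor = keras_nlp.models.BertTextClassifierPreprocessor.from_preset(
    "bert_tiny_en_uncased"
)

# Opt in to the saved preprocessing config, e.g. to restore the sequence
# length a model was fine-tuned with (hypothetical saved preset directory).
preprocessor = keras_nlp.models.BertTextClassifierPreprocessor.from_preset(
    "./my_finetuned_preset",
    load_task_extras=True,
)
```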
53 changes: 25 additions & 28 deletions keras_nlp/src/models/preprocessor_test.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import pathlib

import pytest
from absl.testing import parameterized
@@ -31,10 +32,11 @@
RobertaTextClassifierPreprocessor,
)
from keras_nlp.src.tests.test_case import TestCase
from keras_nlp.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
SentencePieceTokenizer,
)
from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
from keras_nlp.src.utils.preset_utils import check_config_class
from keras_nlp.src.utils.preset_utils import load_json


class TestPreprocessor(TestCase):
@@ -80,45 +82,40 @@ def test_from_preset_errors(self):
# TODO: Add more tests when we add a model that has `preprocessor.json`.

@parameterized.parameters(
(
AlbertTextClassifierPreprocessor,
"albert_base_en_uncased",
"sentencepiece",
),
(RobertaTextClassifierPreprocessor, "roberta_base_en", "bytepair"),
(BertTextClassifierPreprocessor, "bert_tiny_en_uncased", "wordpiece"),
(AlbertTextClassifierPreprocessor, "albert_base_en_uncased"),
(RobertaTextClassifierPreprocessor, "roberta_base_en"),
(BertTextClassifierPreprocessor, "bert_tiny_en_uncased"),
)
@pytest.mark.large
def test_save_to_preset(self, cls, preset_name, tokenizer_type):
def test_save_to_preset(self, cls, preset_name):
save_dir = self.get_temp_dir()
preprocessor = cls.from_preset(preset_name)
preprocessor = cls.from_preset(preset_name, sequence_length=100)
tokenizer = preprocessor.tokenizer
preprocessor.save_to_preset(save_dir)
# Save a backbone so the preset is valid.
backbone = cls.backbone_cls.from_preset(preset_name, load_weights=False)
backbone.save_to_preset(save_dir)

if tokenizer_type == "bytepair":
if isinstance(tokenizer, BytePairTokenizer):
vocab_filename = "vocabulary.json"
expected_assets = [
"vocabulary.json",
"merges.txt",
]
elif tokenizer_type == "sentencepiece":
expected_assets = ["vocabulary.json", "merges.txt"]
elif isinstance(tokenizer, SentencePieceTokenizer):
vocab_filename = "vocabulary.spm"
expected_assets = ["vocabulary.spm"]
else:
vocab_filename = "vocabulary.txt"
expected_assets = ["vocabulary.txt"]

# Check existence of vocab file.
vocab_path = os.path.join(
save_dir, os.path.join(TOKENIZER_ASSET_DIR, vocab_filename)
)
path = pathlib.Path(save_dir)
vocab_path = path / TOKENIZER_ASSET_DIR / vocab_filename
self.assertTrue(os.path.exists(vocab_path))

# Check assets.
self.assertEqual(
set(preprocessor.tokenizer.file_assets),
set(expected_assets),
)
self.assertEqual(set(tokenizer.file_assets), set(expected_assets))

# Check config class.
preprocessor_config = load_json(save_dir, PREPROCESSOR_CONFIG_FILE)
self.assertEqual(cls, check_config_class(preprocessor_config))
# Check restore.
restored = cls.from_preset(save_dir, load_task_extras=True)
self.assertEqual(preprocessor.get_config(), restored.get_config())
restored = cls.from_preset(save_dir, load_task_extras=False)
self.assertNotEqual(preprocessor.get_config(), restored.get_config())
16 changes: 11 additions & 5 deletions keras_nlp/src/models/task.py
@@ -146,6 +146,7 @@ def from_preset(
cls,
preset,
load_weights=True,
load_task_extras=False,
Member:

I was wondering what's the reason behind setting the default to False here! 🤔 Since it's loading a task, loading the head weights by default might not be a bad idea!

Member Author:

My biggest reason for wanting to flip the default is to have parallel quickstarts for vision and text models:

classifier = TextClassifier.from_preset(
    "bert_base_en",
    num_classes=2,
)
classifier.fit(text_dataset)

classifier = ImageClassifier.from_preset(
    "res_net_50",
    num_classes=2,
)
classifier.fit(image_dataset)

I think those snippets are important, they will be front and center. We should avoid needing to introduce more concepts there.

This also flips the arg to be more explicit (explicit is good I think). Passing true will now error if a task.json does not exist, or if it is for the wrong class. But we definitely cannot keep this strict behavior if we flip the default. It would break our current quickstart!

**kwargs,
):
"""Instantiate a `keras_nlp.models.Task` from a model preset.
@@ -171,9 +172,13 @@
Args:
preset: string. A built in preset identifier, a Kaggle Models
handle, a Hugging Face handle, or a path to a local directory.
load_weights: bool. If `True`, the weights will be loaded into the
model architecture. If `False`, the weights will be randomly
initialized.
load_weights: bool. If `True`, saved weights will be loaded into
the model architecture. If `False`, all weights will be
randomly initialized.
load_task_extras: bool. If `True`, load the saved task configuration
from a `task.json` and any task specific weights from
`task.weights`. You might use this to load a classification
head for a model that has been saved with it.

Examples:
```python
@@ -201,13 +206,14 @@
# Detect the correct subclass if we need to.
if cls.backbone_cls != backbone_cls:
cls = find_subclass(preset, cls, backbone_cls)
return loader.load_task(cls, load_weights, **kwargs)
return loader.load_task(cls, load_weights, load_task_extras, **kwargs)

def load_task_weights(self, filepath):
"""Load only the tasks specific weights not in the backbone."""
if not str(filepath).endswith(".weights.h5"):
raise ValueError(
"The filename must end in `.weights.h5`. Received: filepath={filepath}"
"The filename must end in `.weights.h5`. "
f"Received: filepath={filepath}"
)
backbone_layer_ids = set(id(w) for w in self.backbone._flatten_layers())
keras.saving.load_weights(
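Likewise, a hedged sketch of the task-level flag described in the docstring and the review thread above; the saved preset directory is hypothetical.

```python
import keras_nlp

# Default (`load_task_extras=False`): build the task from the backbone with a
# freshly initialized classification head.
classifier = keras_nlp.models.TextClassifier.from_preset(
    "bert_tiny_en_uncased",
    num_classes=2,
)

# Opt in to the saved `task.json` and task-specific weights; per this change,
# this raises if the preset was not saved with a matching task config.
classifier = keras_nlp.models.TextClassifier.from_preset(
    "./my_saved_classifier",  # hypothetical directory from `save_to_preset()`
    load_task_extras=True,
)
```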
49 changes: 30 additions & 19 deletions keras_nlp/src/models/task_test.py
@@ -13,6 +13,7 @@
# limitations under the License.

import os
import pathlib

import keras
import pytest
@@ -109,23 +110,16 @@ def test_summary_without_preprocessor(self):
@pytest.mark.large
def test_save_to_preset(self):
save_dir = self.get_temp_dir()
model = TextClassifier.from_preset(
"bert_tiny_en_uncased", num_classes=2
)
model.save_to_preset(save_dir)
task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
task.save_to_preset(save_dir)

# Check existence of files.
self.assertTrue(os.path.exists(os.path.join(save_dir, CONFIG_FILE)))
self.assertTrue(
os.path.exists(os.path.join(save_dir, MODEL_WEIGHTS_FILE))
)
self.assertTrue(os.path.exists(os.path.join(save_dir, METADATA_FILE)))
self.assertTrue(
os.path.exists(os.path.join(save_dir, TASK_CONFIG_FILE))
)
self.assertTrue(
os.path.exists(os.path.join(save_dir, TASK_WEIGHTS_FILE))
)
path = pathlib.Path(save_dir)
self.assertTrue(os.path.exists(path / CONFIG_FILE))
self.assertTrue(os.path.exists(path / MODEL_WEIGHTS_FILE))
self.assertTrue(os.path.exists(path / METADATA_FILE))
self.assertTrue(os.path.exists(path / TASK_CONFIG_FILE))
self.assertTrue(os.path.exists(path / TASK_WEIGHTS_FILE))

# Check the task config (`task.json`).
task_config = load_json(save_dir, TASK_CONFIG_FILE)
@@ -138,13 +132,30 @@ def test_save_to_preset(self):
self.assertEqual(BertTextClassifier, check_config_class(task_config))

# Try loading the model from preset directory.
restored_model = TextClassifier.from_preset(save_dir)
restored_task = TextClassifier.from_preset(
save_dir, load_task_extras=True
)

# Check the model output.
data = ["the quick brown fox.", "the slow brown fox."]
ref_out = model.predict(data)
new_out = restored_model.predict(data)
self.assertAllEqual(ref_out, new_out)
ref_out = task.predict(data)
new_out = restored_task.predict(data)
self.assertAllClose(ref_out, new_out)

# Load without head weights.
restored_task = TextClassifier.from_preset(
save_dir, load_task_extras=False, num_classes=2
)
data = ["the quick brown fox.", "the slow brown fox."]
# Full output unequal.
ref_out = task.predict(data)
new_out = restored_task.predict(data)
self.assertNotAllClose(ref_out, new_out)
# Backbone output equal.
data = task.preprocessor(data)
ref_out = task.backbone.predict(data)
new_out = restored_task.backbone.predict(data)
self.assertAllClose(ref_out, new_out)

@pytest.mark.large
def test_none_preprocessor(self):
49 changes: 37 additions & 12 deletions keras_nlp/src/utils/preset_utils.py
@@ -656,7 +656,7 @@ def load_tokenizer(self, cls, **kwargs):
"""Load a tokenizer layer from the preset."""
raise NotImplementedError

def load_task(self, cls, load_weights, **kwargs):
def load_task(self, cls, load_weights, load_task_extras, **kwargs):
"""Load a task model from the preset.

By default, we create a task from a backbone and preprocessor with
@@ -671,11 +671,12 @@ def load_task(self, cls, load_weights, **kwargs):
)
if "preprocessor" not in kwargs:
kwargs["preprocessor"] = self.load_preprocessor(
cls.preprocessor_cls
cls.preprocessor_cls,
load_task_extras=load_task_extras,
)
return cls(**kwargs)

def load_preprocessor(self, cls, **kwargs):
def load_preprocessor(self, cls, load_task_extras, **kwargs):
"""Load a prepocessor layer from the preset.

By default, we create a preprocessor from a tokenizer with default
@@ -704,35 +705,59 @@ def load_tokenizer(self, cls, **kwargs):
tokenizer.load_preset_assets(self.preset)
return tokenizer

def load_task(self, cls, load_weights, **kwargs):
def load_task(self, cls, load_weights, load_task_extras, **kwargs):
# If there is no `task.json` or it's for the wrong class delegate to the
# super class loader.
if not load_task_extras:
return super().load_task(
cls, load_weights, load_task_extras, **kwargs
)
if not check_file_exists(self.preset, TASK_CONFIG_FILE):
return super().load_task(cls, load_weights, **kwargs)
raise ValueError(
"Saved preset has no `task.json`, cannot load the task config "
"from a file. Call `from_preset()` with "
"`load_task_extras=False` to load the task from a backbone "
"with library defaults."
)
task_config = load_json(self.preset, TASK_CONFIG_FILE)
if not issubclass(check_config_class(task_config), cls):
return super().load_task(cls, load_weights, **kwargs)
raise ValueError(
f"Saved `task.json`does not match calling cls {cls}. Call "
"`from_preset()` with `load_task_extras=False` to load the "
"task from a backbone with library defaults."
)
# We found a `task.json` with a complete config for our class.
task = load_serialized_object(task_config, **kwargs)
if task.preprocessor is not None:
task.preprocessor.tokenizer.load_preset_assets(self.preset)
if load_weights:
jax_memory_cleanup(task)
if check_file_exists(self.preset, TASK_WEIGHTS_FILE):
jax_memory_cleanup(task)
task_weights = get_file(self.preset, TASK_WEIGHTS_FILE)
task.load_task_weights(task_weights)
else:
jax_memory_cleanup(task.backbone)
backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
task.backbone.load_weights(backbone_weights)
return task

def load_preprocessor(self, cls, **kwargs):
# If there is no `preprocessing.json` or it's for the wrong class,
# delegate to the super class loader.
def load_preprocessor(self, cls, load_task_extras, **kwargs):
if not load_task_extras:
return super().load_preprocessor(cls, load_task_extras, **kwargs)
if not check_file_exists(self.preset, PREPROCESSOR_CONFIG_FILE):
return super().load_preprocessor(cls, **kwargs)
raise ValueError(
"Saved preset has no `preprocessor.json`, cannot load the task "
"preprocessing config from a file. Call `from_preset()` with "
"`load_task_extras=False` to load the preprocessor with "
"library defaults."
)
preprocessor_json = load_json(self.preset, PREPROCESSOR_CONFIG_FILE)
if not issubclass(check_config_class(preprocessor_json), cls):
return super().load_preprocessor(cls, **kwargs)
raise ValueError(
f"Saved `preprocessor.json`does not match calling cls {cls}. "
"Call `from_preset()` with `load_task_extras=False` to "
"load the the preprocessor with library defaults."
)
# We found a `preprocessing.json` with a complete config for our class.
preprocessor = load_serialized_object(preprocessor_json, **kwargs)
preprocessor.tokenizer.load_preset_assets(self.preset)
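To make the new strictness concrete, a small sketch of what a caller sees when requesting task extras from a preset saved without them; the directory name is hypothetical and the error text is abridged.

```python
import keras_nlp

# Hypothetical preset directory saved without a `task.json`.
preset_dir = "./preset_without_task_json"

try:
    keras_nlp.models.TextClassifier.from_preset(preset_dir, load_task_extras=True)
except ValueError as e:
    # "Saved preset has no `task.json`, cannot load the task config from a
    # file. Call `from_preset()` with `load_task_extras=False` ..."
    print(e)

# With the default, the loader falls back to library defaults and builds the
# task head from scratch.
classifier = keras_nlp.models.TextClassifier.from_preset(preset_dir, num_classes=2)
```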