6 changes: 1 addition & 5 deletions keras_nlp/src/models/preprocessor.py
@@ -126,7 +126,6 @@ def presets(cls):
     def from_preset(
         cls,
         preset,
-        load_task_extras=False,
         **kwargs,
     ):
         """Instantiate a `keras_nlp.models.Preprocessor` from a model preset.
@@ -150,9 +149,6 @@ def from_preset(
         Args:
             preset: string. A built-in preset identifier, a Kaggle Models
                 handle, a Hugging Face handle, or a path to a local directory.
-            load_task_extras: bool. If `True`, load the saved task preprocessing
-                configuration from a `preprocessing.json`. You might use this to
-                restore the sequence length a model was fine-tuned with.

         Examples:
         ```python
@@ -179,7 +175,7 @@ def from_preset(
         # Detect the correct subclass if we need to.
         if cls.backbone_cls != backbone_cls:
             cls = find_subclass(preset, cls, backbone_cls)
-        return loader.load_preprocessor(cls, load_task_extras, **kwargs)
+        return loader.load_preprocessor(cls, **kwargs)

    def save_to_preset(self, preset_dir):
        """Save preprocessor to a preset directory.
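With the flag removed, the call site simplifies: any saved `preprocessing.json` is picked up automatically when present. A minimal sketch of the new behavior, assuming the public `bert_base_en_uncased` preset (KerasNLP resolves the matching preprocessor subclass for you):

```python
import keras_nlp

# Saved preprocessing config (e.g. a fine-tuned sequence length) is now
# restored automatically whenever the preset provides one.
preprocessor = keras_nlp.models.Preprocessor.from_preset("bert_base_en_uncased")

# Config overrides still work as ordinary constructor kwargs.
preprocessor = keras_nlp.models.Preprocessor.from_preset(
    "bert_base_en_uncased",
    sequence_length=256,
)
```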
4 changes: 1 addition & 3 deletions keras_nlp/src/models/preprocessor_test.py
@@ -121,7 +121,5 @@ def test_save_to_preset(self, cls, preset_name):
         self.assertEqual(set(tokenizer.file_assets), set(expected_assets))

         # Check restore.
-        restored = cls.from_preset(save_dir, load_task_extras=True)
+        restored = cls.from_preset(save_dir)
         self.assertEqual(preprocessor.get_config(), restored.get_config())
-        restored = cls.from_preset(save_dir, load_task_extras=False)
-        self.assertNotEqual(preprocessor.get_config(), restored.get_config())
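The simplified test encodes the new contract: a saved preprocessor restores its own config with no flag. A hedged round-trip sketch outside the test harness (the preprocessor class name and directory are illustrative; use the class matching your installed version):

```python
import keras_nlp

# Customize a preprocessor and save it as a preset.
preprocessor = keras_nlp.models.BertPreprocessor.from_preset(
    "bert_tiny_en_uncased",
    sequence_length=128,
)
preprocessor.save_to_preset("./my_preset")

# Restoring now round-trips the saved config, custom sequence length included.
restored = keras_nlp.models.BertPreprocessor.from_preset("./my_preset")
assert restored.get_config() == preprocessor.get_config()
```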
12 changes: 6 additions & 6 deletions keras_nlp/src/models/task.py
@@ -139,7 +139,6 @@ def from_preset(
         cls,
         preset,
         load_weights=True,
-        load_task_extras=False,
         **kwargs,
     ):
         """Instantiate a `keras_nlp.models.Task` from a model preset.
@@ -168,10 +167,6 @@
             load_weights: bool. If `True`, saved weights will be loaded into
                 the model architecture. If `False`, all weights will be
                 randomly initialized.
-            load_task_extras: bool. If `True`, load the saved task configuration
-                from a `task.json` and any task specific weights from
-                `task.weights`. You might use this to load a classification
-                head for a model that has been saved with it.

         Examples:
         ```python
@@ -199,7 +194,12 @@
         # Detect the correct subclass if we need to.
         if cls.backbone_cls != backbone_cls:
             cls = find_subclass(preset, cls, backbone_cls)
-        return loader.load_task(cls, load_weights, load_task_extras, **kwargs)
+        # Specifically for classifiers, we never load task weights if
+        # num_classes is supplied. We handle this in the task base class because
+        # it is the same logic for classifiers regardless of modality (text,
+        # images, audio).
+        load_task_weights = "num_classes" not in kwargs
+        return loader.load_task(cls, load_weights, load_task_weights, **kwargs)

    def load_task_weights(self, filepath):
        """Load only the task-specific weights not in the backbone."""
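The new `num_classes` gate replaces the flag entirely: head weights load only when the caller leaves the head shape alone. A sketch of the two paths, assuming the `bert_tiny_en_uncased_sst2` preset, which ships a saved classification head:

```python
import keras_nlp

# No `num_classes`: `load_task_weights` is True, so the saved head in
# `task.weights.h5` is restored alongside the backbone.
classifier = keras_nlp.models.TextClassifier.from_preset("bert_tiny_en_uncased_sst2")

# `num_classes` supplied: `load_task_weights` is False, so the backbone still
# loads pretrained weights while the head is rebuilt and randomly initialized.
classifier = keras_nlp.models.TextClassifier.from_preset(
    "bert_tiny_en_uncased_sst2",
    num_classes=4,
)
```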
10 changes: 3 additions & 7 deletions keras_nlp/src/models/task_test.py
@@ -138,20 +138,16 @@ def test_save_to_preset(self):
         self.assertEqual(BertTextClassifier, check_config_class(task_config))

         # Try loading the model from preset directory.
-        restored_task = TextClassifier.from_preset(
-            save_dir, load_task_extras=True
-        )
+        restored_task = TextClassifier.from_preset(save_dir)

         # Check the model output.
         data = ["the quick brown fox.", "the slow brown fox."]
         ref_out = task.predict(data)
         new_out = restored_task.predict(data)
         self.assertAllClose(ref_out, new_out)

-        # Load without head weights.
-        restored_task = TextClassifier.from_preset(
-            save_dir, load_task_extras=False, num_classes=2
-        )
+        # Load classifier head with random weights.
+        restored_task = TextClassifier.from_preset(save_dir, num_classes=2)
         data = ["the quick brown fox.", "the slow brown fox."]
         # Full output unequal.
         ref_out = task.predict(data)
6 changes: 6 additions & 0 deletions keras_nlp/src/models/text_classifier.py
@@ -32,6 +32,12 @@ class TextClassifier(Task):
     All `TextClassifier` tasks include a `from_preset()` constructor which can be
     used to load a pre-trained config and weights.

+    Some, but not all, classification presets include classification head
+    weights in a `task.weights.h5` file. For these presets, you can omit passing
+    `num_classes` to restore the saved classification head. For all presets, if
+    `num_classes` is passed as a kwarg to `from_preset()`, the classification
+    head will be randomly initialized.
+
     Example:
     ```python
     # Load a BERT classifier with pre-trained weights.
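In user terms, the new docstring paragraph means fine-tuning no longer needs any loading flag. A short usage sketch under those assumptions (preset name and toy data are illustrative):

```python
import keras_nlp

# Pass `num_classes` to get a freshly initialized head sized for your task,
# regardless of whether the preset saved one.
classifier = keras_nlp.models.TextClassifier.from_preset(
    "bert_base_en_uncased",
    num_classes=2,
)
# Task models carry their preprocessor, so `fit()` accepts raw strings.
classifier.fit(
    x=["what an amazing movie!", "a total waste of my time"],
    y=[1, 0],
    batch_size=2,
)
```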
46 changes: 14 additions & 32 deletions keras_nlp/src/utils/preset_utils.py
@@ -673,7 +673,7 @@ def load_image_converter(self, cls, **kwargs):
         """Load an image converter layer from the preset."""
         raise NotImplementedError

-    def load_task(self, cls, load_weights, load_task_extras, **kwargs):
+    def load_task(self, cls, load_weights, load_task_weights, **kwargs):
         """Load a task model from the preset.

         By default, we create a task from a backbone and preprocessor with
@@ -689,11 +689,10 @@ def load_task(self, cls, load_weights, load_task_extras, **kwargs):
         if "preprocessor" not in kwargs:
             kwargs["preprocessor"] = self.load_preprocessor(
                 cls.preprocessor_cls,
-                load_task_extras=load_task_extras,
             )
         return cls(**kwargs)

-    def load_preprocessor(self, cls, load_task_extras, **kwargs):
+    def load_preprocessor(self, cls, **kwargs):
         """Load a preprocessor layer from the preset.

         By default, we create a preprocessor from a tokenizer with default
@@ -738,33 +737,25 @@ def load_image_converter(self, cls, **kwargs):
         converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE)
         return load_serialized_object(converter_config, **kwargs)

-    def load_task(self, cls, load_weights, load_task_extras, **kwargs):
+    def load_task(self, cls, load_weights, load_task_weights, **kwargs):
         # If there is no `task.json` or it's for the wrong class delegate to the
         # super class loader.
-        if not load_task_extras:
-            return super().load_task(
-                cls, load_weights, load_task_extras, **kwargs
-            )
         if not check_file_exists(self.preset, TASK_CONFIG_FILE):
-            raise ValueError(
-                "Saved preset has no `task.json`, cannot load the task config "
-                "from a file. Call `from_preset()` with "
-                "`load_task_extras=False` to load the task from a backbone "
-                "with library defaults."
+            return super().load_task(
+                cls, load_weights, load_task_weights, **kwargs
             )
         task_config = load_json(self.preset, TASK_CONFIG_FILE)
         if not issubclass(check_config_class(task_config), cls):
-            raise ValueError(
-                f"Saved `task.json`does not match calling cls {cls}. Call "
-                "`from_preset()` with `load_task_extras=False` to load the "
-                "task from a backbone with library defaults."
+            return super().load_task(
+                cls, load_weights, load_task_weights, **kwargs
             )
+        # We found a `task.json` with a complete config for our class.
         task = load_serialized_object(task_config, **kwargs)
         if task.preprocessor is not None:
             task.preprocessor.tokenizer.load_preset_assets(self.preset)
         if load_weights:
-            if check_file_exists(self.preset, TASK_WEIGHTS_FILE):
+            has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE)
+            if has_task_weights and load_task_weights:
                 jax_memory_cleanup(task)
                 task_weights = get_file(self.preset, TASK_WEIGHTS_FILE)
                 task.load_task_weights(task_weights)
@@ -774,23 +765,14 @@ def load_task(self, cls, load_weights, load_task_extras, **kwargs):
             task.backbone.load_weights(backbone_weights)
         return task

-    def load_preprocessor(self, cls, load_task_extras, **kwargs):
-        if not load_task_extras:
-            return super().load_preprocessor(cls, load_task_extras, **kwargs)
+    def load_preprocessor(self, cls, **kwargs):
         # If there is no `preprocessing.json` or it's for the wrong class,
         # delegate to the super class loader.
         if not check_file_exists(self.preset, PREPROCESSOR_CONFIG_FILE):
-            raise ValueError(
-                "Saved preset has no `preprocessor.json`, cannot load the task "
-                "preprocessing config from a file. Call `from_preset()` with "
-                "`load_task_extras=False` to load the preprocessor with "
-                "library defaults."
-            )
+            return super().load_preprocessor(cls, **kwargs)
         preprocessor_json = load_json(self.preset, PREPROCESSOR_CONFIG_FILE)
         if not issubclass(check_config_class(preprocessor_json), cls):
-            raise ValueError(
-                f"Saved `preprocessor.json`does not match calling cls {cls}. "
-                "Call `from_preset()` with `load_task_extras=False` to "
-                "load the the preprocessor with library defaults."
-            )
+            return super().load_preprocessor(cls, **kwargs)
+        # We found a `preprocessing.json` with a complete config for our class.
         preprocessor = load_serialized_object(preprocessor_json, **kwargs)
         preprocessor.tokenizer.load_preset_assets(self.preset)
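One behavioral shift worth calling out in this file: the missing-file and wrong-class branches now fall back to building the task or preprocessor from the backbone with library defaults, where they previously raised `ValueError` when `load_task_extras=True`. A hedged sketch of what that enables (the class/preset pairing is an illustrative assumption, not from the diff):

```python
import keras_nlp

# `bert_tiny_en_uncased_sst2` stores a classifier `task.json`. Requesting a
# different task type no longer errors: the loader sees the class mismatch,
# delegates to the backbone-based path, and builds the MaskedLM-specific
# layers fresh with library defaults.
masked_lm = keras_nlp.models.BertMaskedLM.from_preset("bert_tiny_en_uncased_sst2")
```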