diff --git a/keras_nlp/src/models/preprocessor.py b/keras_nlp/src/models/preprocessor.py
index 686d010bf8..4ab8523918 100644
--- a/keras_nlp/src/models/preprocessor.py
+++ b/keras_nlp/src/models/preprocessor.py
@@ -89,6 +89,7 @@ def presets(cls):
     def from_preset(
         cls,
         preset,
+        load_task_extras=False,
         **kwargs,
     ):
         """Instantiate a `keras_nlp.models.Preprocessor` from a model preset.
@@ -112,6 +113,9 @@ def from_preset(
         Args:
             preset: string. A built in preset identifier, a Kaggle Models
                 handle, a Hugging Face handle, or a path to a local directory.
+            load_task_extras: bool. If `True`, load the saved task preprocessing
+                configuration from a `preprocessor.json`. You might use this to
+                restore the sequence length a model was fine-tuned with.
 
         Examples:
         ```python
@@ -138,7 +142,7 @@ def from_preset(
         # Detect the correct subclass if we need to.
         if cls.backbone_cls != backbone_cls:
             cls = find_subclass(preset, cls, backbone_cls)
-        return loader.load_preprocessor(cls, **kwargs)
+        return loader.load_preprocessor(cls, load_task_extras, **kwargs)
 
     def save_to_preset(self, preset_dir):
         """Save preprocessor to a preset directory.
diff --git a/keras_nlp/src/models/preprocessor_test.py b/keras_nlp/src/models/preprocessor_test.py
index 2baa891842..ef89ca3f56 100644
--- a/keras_nlp/src/models/preprocessor_test.py
+++ b/keras_nlp/src/models/preprocessor_test.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+import pathlib
 
 import pytest
 from absl.testing import parameterized
@@ -31,10 +32,11 @@
     RobertaTextClassifierPreprocessor,
 )
 from keras_nlp.src.tests.test_case import TestCase
-from keras_nlp.src.utils.preset_utils import PREPROCESSOR_CONFIG_FILE
+from keras_nlp.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
+from keras_nlp.src.tokenizers.sentence_piece_tokenizer import (
+    SentencePieceTokenizer,
+)
 from keras_nlp.src.utils.preset_utils import TOKENIZER_ASSET_DIR
-from keras_nlp.src.utils.preset_utils import check_config_class
-from keras_nlp.src.utils.preset_utils import load_json
 
 
 class TestPreprocessor(TestCase):
@@ -80,27 +82,24 @@ def test_from_preset_errors(self):
 
     # TODO: Add more tests when we added a model that has `preprocessor.json`.
     @parameterized.parameters(
-        (
-            AlbertTextClassifierPreprocessor,
-            "albert_base_en_uncased",
-            "sentencepiece",
-        ),
-        (RobertaTextClassifierPreprocessor, "roberta_base_en", "bytepair"),
-        (BertTextClassifierPreprocessor, "bert_tiny_en_uncased", "wordpiece"),
+        (AlbertTextClassifierPreprocessor, "albert_base_en_uncased"),
+        (RobertaTextClassifierPreprocessor, "roberta_base_en"),
+        (BertTextClassifierPreprocessor, "bert_tiny_en_uncased"),
     )
     @pytest.mark.large
-    def test_save_to_preset(self, cls, preset_name, tokenizer_type):
+    def test_save_to_preset(self, cls, preset_name):
         save_dir = self.get_temp_dir()
-        preprocessor = cls.from_preset(preset_name)
+        preprocessor = cls.from_preset(preset_name, sequence_length=100)
+        tokenizer = preprocessor.tokenizer
         preprocessor.save_to_preset(save_dir)
+        # Save a backbone so the preset is valid.
+        backbone = cls.backbone_cls.from_preset(preset_name, load_weights=False)
+        backbone.save_to_preset(save_dir)
 
-        if tokenizer_type == "bytepair":
+        if isinstance(tokenizer, BytePairTokenizer):
             vocab_filename = "vocabulary.json"
-            expected_assets = [
-                "vocabulary.json",
-                "merges.txt",
-            ]
-        elif tokenizer_type == "sentencepiece":
+            expected_assets = ["vocabulary.json", "merges.txt"]
+        elif isinstance(tokenizer, SentencePieceTokenizer):
             vocab_filename = "vocabulary.spm"
             expected_assets = ["vocabulary.spm"]
         else:
@@ -108,17 +107,15 @@ def test_save_to_preset(self, cls, preset_name, tokenizer_type):
             expected_assets = ["vocabulary.txt"]
 
         # Check existence of vocab file.
-        vocab_path = os.path.join(
-            save_dir, os.path.join(TOKENIZER_ASSET_DIR, vocab_filename)
-        )
+        path = pathlib.Path(save_dir)
+        vocab_path = path / TOKENIZER_ASSET_DIR / vocab_filename
         self.assertTrue(os.path.exists(vocab_path))
 
         # Check assets.
-        self.assertEqual(
-            set(preprocessor.tokenizer.file_assets),
-            set(expected_assets),
-        )
+        self.assertEqual(set(tokenizer.file_assets), set(expected_assets))
 
-        # Check config class.
-        preprocessor_config = load_json(save_dir, PREPROCESSOR_CONFIG_FILE)
-        self.assertEqual(cls, check_config_class(preprocessor_config))
+        # Check restore.
+        restored = cls.from_preset(save_dir, load_task_extras=True)
+        self.assertEqual(preprocessor.get_config(), restored.get_config())
+        restored = cls.from_preset(save_dir, load_task_extras=False)
+        self.assertNotEqual(preprocessor.get_config(), restored.get_config())
diff --git a/keras_nlp/src/models/task.py b/keras_nlp/src/models/task.py
index d5aa8eb0b8..f42cc3c3d0 100644
--- a/keras_nlp/src/models/task.py
+++ b/keras_nlp/src/models/task.py
@@ -146,6 +146,7 @@ def from_preset(
         cls,
         preset,
         load_weights=True,
+        load_task_extras=False,
         **kwargs,
     ):
         """Instantiate a `keras_nlp.models.Task` from a model preset.
@@ -171,9 +172,13 @@ def from_preset(
         Args:
             preset: string. A built in preset identifier, a Kaggle Models
                 handle, a Hugging Face handle, or a path to a local directory.
-            load_weights: bool. If `True`, the weights will be loaded into the
-                model architecture. If `False`, the weights will be randomly
-                initialized.
+            load_weights: bool. If `True`, saved weights will be loaded into
+                the model architecture. If `False`, all weights will be
+                randomly initialized.
+            load_task_extras: bool. If `True`, load the saved task configuration
+                from a `task.json` and any task specific weights from
+                `task.weights.h5`. You might use this to load a classification
+                head for a model that has been saved with it.
 
         Examples:
         ```python
@@ -201,13 +206,14 @@ def from_preset(
         # Detect the correct subclass if we need to.
         if cls.backbone_cls != backbone_cls:
             cls = find_subclass(preset, cls, backbone_cls)
-        return loader.load_task(cls, load_weights, **kwargs)
+        return loader.load_task(cls, load_weights, load_task_extras, **kwargs)
 
     def load_task_weights(self, filepath):
         """Load only the tasks specific weights not in the backbone."""
         if not str(filepath).endswith(".weights.h5"):
             raise ValueError(
-                "The filename must end in `.weights.h5`. Received: filepath={filepath}"
+                "The filename must end in `.weights.h5`. "
+                f"Received: filepath={filepath}"
             )
         backbone_layer_ids = set(id(w) for w in self.backbone._flatten_layers())
         keras.saving.load_weights(
diff --git a/keras_nlp/src/models/task_test.py b/keras_nlp/src/models/task_test.py
index 7622dc5070..6b7b5e25cf 100644
--- a/keras_nlp/src/models/task_test.py
+++ b/keras_nlp/src/models/task_test.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
+import pathlib
 
 import keras
 import pytest
@@ -109,23 +110,16 @@ def test_summary_without_preprocessor(self):
     @pytest.mark.large
     def test_save_to_preset(self):
         save_dir = self.get_temp_dir()
-        model = TextClassifier.from_preset(
-            "bert_tiny_en_uncased", num_classes=2
-        )
-        model.save_to_preset(save_dir)
+        task = TextClassifier.from_preset("bert_tiny_en_uncased", num_classes=2)
+        task.save_to_preset(save_dir)
 
         # Check existence of files.
-        self.assertTrue(os.path.exists(os.path.join(save_dir, CONFIG_FILE)))
-        self.assertTrue(
-            os.path.exists(os.path.join(save_dir, MODEL_WEIGHTS_FILE))
-        )
-        self.assertTrue(os.path.exists(os.path.join(save_dir, METADATA_FILE)))
-        self.assertTrue(
-            os.path.exists(os.path.join(save_dir, TASK_CONFIG_FILE))
-        )
-        self.assertTrue(
-            os.path.exists(os.path.join(save_dir, TASK_WEIGHTS_FILE))
-        )
+        path = pathlib.Path(save_dir)
+        self.assertTrue(os.path.exists(path / CONFIG_FILE))
+        self.assertTrue(os.path.exists(path / MODEL_WEIGHTS_FILE))
+        self.assertTrue(os.path.exists(path / METADATA_FILE))
+        self.assertTrue(os.path.exists(path / TASK_CONFIG_FILE))
+        self.assertTrue(os.path.exists(path / TASK_WEIGHTS_FILE))
 
         # Check the task config (`task.json`).
         task_config = load_json(save_dir, TASK_CONFIG_FILE)
@@ -138,13 +132,30 @@ def test_save_to_preset(self):
         self.assertEqual(BertTextClassifier, check_config_class(task_config))
 
         # Try loading the model from preset directory.
-        restored_model = TextClassifier.from_preset(save_dir)
+        restored_task = TextClassifier.from_preset(
+            save_dir, load_task_extras=True
+        )
 
         # Check the model output.
         data = ["the quick brown fox.", "the slow brown fox."]
-        ref_out = model.predict(data)
-        new_out = restored_model.predict(data)
-        self.assertAllEqual(ref_out, new_out)
+        ref_out = task.predict(data)
+        new_out = restored_task.predict(data)
+        self.assertAllClose(ref_out, new_out)
+
+        # Load without head weights.
+        restored_task = TextClassifier.from_preset(
+            save_dir, load_task_extras=False, num_classes=2
+        )
+        data = ["the quick brown fox.", "the slow brown fox."]
+        # Full output unequal.
+        ref_out = task.predict(data)
+        new_out = restored_task.predict(data)
+        self.assertNotAllClose(ref_out, new_out)
+        # Backbone output equal.
+        data = task.preprocessor(data)
+        ref_out = task.backbone.predict(data)
+        new_out = restored_task.backbone.predict(data)
+        self.assertAllClose(ref_out, new_out)
 
     @pytest.mark.large
     def test_none_preprocessor(self):
diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py
index f5a6dc62ce..65cdd7802e 100644
--- a/keras_nlp/src/utils/preset_utils.py
+++ b/keras_nlp/src/utils/preset_utils.py
@@ -656,7 +656,7 @@ def load_tokenizer(self, cls, **kwargs):
         """Load a tokenizer layer from the preset."""
         raise NotImplementedError
 
-    def load_task(self, cls, load_weights, **kwargs):
+    def load_task(self, cls, load_weights, load_task_extras, **kwargs):
         """Load a task model from the preset.
 
         By default, we create a task from a backbone and preprocessor with
@@ -671,11 +671,12 @@ def load_task(self, cls, load_weights, **kwargs):
         )
         if "preprocessor" not in kwargs:
             kwargs["preprocessor"] = self.load_preprocessor(
-                cls.preprocessor_cls
+                cls.preprocessor_cls,
+                load_task_extras=load_task_extras,
             )
         return cls(**kwargs)
 
-    def load_preprocessor(self, cls, **kwargs):
+    def load_preprocessor(self, cls, load_task_extras, **kwargs):
         """Load a prepocessor layer from the preset.
 
         By default, we create a preprocessor from a tokenizer with default
@@ -704,35 +705,59 @@ def load_tokenizer(self, cls, **kwargs):
         tokenizer.load_preset_assets(self.preset)
         return tokenizer
 
-    def load_task(self, cls, load_weights, **kwargs):
+    def load_task(self, cls, load_weights, load_task_extras, **kwargs):
         # If there is no `task.json` or it's for the wrong class delegate to the
         # super class loader.
+        if not load_task_extras:
+            return super().load_task(
+                cls, load_weights, load_task_extras, **kwargs
+            )
         if not check_file_exists(self.preset, TASK_CONFIG_FILE):
-            return super().load_task(cls, load_weights, **kwargs)
+            raise ValueError(
+                "Saved preset has no `task.json`, cannot load the task config "
+                "from a file. Call `from_preset()` with "
+                "`load_task_extras=False` to load the task from a backbone "
+                "with library defaults."
+            )
         task_config = load_json(self.preset, TASK_CONFIG_FILE)
         if not issubclass(check_config_class(task_config), cls):
-            return super().load_task(cls, load_weights, **kwargs)
+            raise ValueError(
+                f"Saved `task.json` does not match calling cls {cls}. Call "
+                "`from_preset()` with `load_task_extras=False` to load the "
+                "task from a backbone with library defaults."
+            )
         # We found a `task.json` with a complete config for our class.
         task = load_serialized_object(task_config, **kwargs)
         if task.preprocessor is not None:
             task.preprocessor.tokenizer.load_preset_assets(self.preset)
         if load_weights:
-            jax_memory_cleanup(task)
             if check_file_exists(self.preset, TASK_WEIGHTS_FILE):
+                jax_memory_cleanup(task)
                 task_weights = get_file(self.preset, TASK_WEIGHTS_FILE)
                 task.load_task_weights(task_weights)
+            else:
+                jax_memory_cleanup(task.backbone)
             backbone_weights = get_file(self.preset, MODEL_WEIGHTS_FILE)
             task.backbone.load_weights(backbone_weights)
         return task
 
-    def load_preprocessor(self, cls, **kwargs):
-        # If there is no `preprocessing.json` or it's for the wrong class,
-        # delegate to the super class loader.
+    def load_preprocessor(self, cls, load_task_extras, **kwargs):
+        if not load_task_extras:
+            return super().load_preprocessor(cls, load_task_extras, **kwargs)
         if not check_file_exists(self.preset, PREPROCESSOR_CONFIG_FILE):
-            return super().load_preprocessor(cls, **kwargs)
+            raise ValueError(
+                "Saved preset has no `preprocessor.json`, cannot load the task "
+                "preprocessing config from a file. Call `from_preset()` with "
+                "`load_task_extras=False` to load the preprocessor with "
+                "library defaults."
+            )
         preprocessor_json = load_json(self.preset, PREPROCESSOR_CONFIG_FILE)
         if not issubclass(check_config_class(preprocessor_json), cls):
-            return super().load_preprocessor(cls, **kwargs)
+            raise ValueError(
+                f"Saved `preprocessor.json` does not match calling cls {cls}. "
+                "Call `from_preset()` with `load_task_extras=False` to "
+                "load the preprocessor with library defaults."
+            )
         # We found a `preprocessing.json` with a complete config for our class.
         preprocessor = load_serialized_object(preprocessor_json, **kwargs)
         preprocessor.tokenizer.load_preset_assets(self.preset)
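
A minimal usage sketch of the `load_task_extras` flag introduced in this diff, modeled on the `task_test.py` changes above. The public `keras_nlp.models.TextClassifier` import path, the preset name, and the `./my_classifier` directory are illustrative assumptions, not part of the patch.

```python
import keras_nlp

# Build a classifier with a task head and save the full preset (backbone
# weights, `task.json`, `task.weights.h5`, and `preprocessor.json`) to a
# local directory.
classifier = keras_nlp.models.TextClassifier.from_preset(
    "bert_tiny_en_uncased", num_classes=2
)
classifier.save_to_preset("./my_classifier")

# Default behavior (load_task_extras=False): the task is rebuilt from the
# backbone with library defaults, so the classification head is freshly
# initialized and `num_classes` must be passed again.
fresh = keras_nlp.models.TextClassifier.from_preset(
    "./my_classifier", num_classes=2
)

# New behavior (load_task_extras=True): restore the saved task config, task
# head weights, and preprocessor config exactly as written by
# `save_to_preset()`.
restored = keras_nlp.models.TextClassifier.from_preset(
    "./my_classifier", load_task_extras=True
)
```

Note the behavior change in `preset_utils.py`: with `load_task_extras=True`, a preset that lacks `task.json` or `preprocessor.json` now raises a `ValueError` instead of silently falling back to library defaults.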