From 32fc430783d428e6460f17cd92602d719df6baf1 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Wed, 11 Sep 2024 19:15:13 -0700 Subject: [PATCH 1/2] Take two of #1812, simpler classifier head loading Let's get rid of `load_task_extras`, which is a bad and confusing name. Instead, we will adopt some behavior that is specific to classifiers, but a lot simpler. ```python # Random head. classifier = ImageClassifier.from_preset("resnet50", num_classes=2) # Pretrained head. classifier = ImageClassifier.from_preset("resnet50") # Error, must provide num_classes. classifier = TextClassifier.from_preset("bert_base_en") ``` --- keras_nlp/src/models/preprocessor.py | 6 +-- keras_nlp/src/models/preprocessor_test.py | 4 +- keras_nlp/src/models/task.py | 12 +++--- keras_nlp/src/models/task_test.py | 10 ++--- keras_nlp/src/models/text_classifier.py | 6 +++ keras_nlp/src/utils/preset_utils.py | 46 +++++++---------------- 6 files changed, 31 insertions(+), 53 deletions(-) diff --git a/keras_nlp/src/models/preprocessor.py b/keras_nlp/src/models/preprocessor.py index 6d674abbdc..20bec93b9a 100644 --- a/keras_nlp/src/models/preprocessor.py +++ b/keras_nlp/src/models/preprocessor.py @@ -126,7 +126,6 @@ def presets(cls): def from_preset( cls, preset, - load_task_extras=False, **kwargs, ): """Instantiate a `keras_nlp.models.Preprocessor` from a model preset. @@ -150,9 +149,6 @@ def from_preset( Args: preset: string. A built-in preset identifier, a Kaggle Models handle, a Hugging Face handle, or a path to a local directory. - load_task_extras: bool. If `True`, load the saved task preprocessing - configuration from a `preprocessing.json`. You might use this to - restore the sequence length a model was fine-tuned with. Examples: ```python @@ -179,7 +175,7 @@ def from_preset( # Detect the correct subclass if we need to. if cls.backbone_cls != backbone_cls: cls = find_subclass(preset, cls, backbone_cls) - return loader.load_preprocessor(cls, load_task_extras, **kwargs) + return loader.load_preprocessor(cls, **kwargs) def save_to_preset(self, preset_dir): """Save preprocessor to a preset directory. diff --git a/keras_nlp/src/models/preprocessor_test.py b/keras_nlp/src/models/preprocessor_test.py index b4e3af8efe..42de5c22b6 100644 --- a/keras_nlp/src/models/preprocessor_test.py +++ b/keras_nlp/src/models/preprocessor_test.py @@ -121,7 +121,5 @@ def test_save_to_preset(self, cls, preset_name): self.assertEqual(set(tokenizer.file_assets), set(expected_assets)) # Check restore. - restored = cls.from_preset(save_dir, load_task_extras=True) + restored = cls.from_preset(save_dir) self.assertEqual(preprocessor.get_config(), restored.get_config()) - restored = cls.from_preset(save_dir, load_task_extras=False) - self.assertNotEqual(preprocessor.get_config(), restored.get_config()) diff --git a/keras_nlp/src/models/task.py b/keras_nlp/src/models/task.py index 3748d0b67e..5fce9841e2 100644 --- a/keras_nlp/src/models/task.py +++ b/keras_nlp/src/models/task.py @@ -139,7 +139,6 @@ def from_preset( cls, preset, load_weights=True, - load_task_extras=False, **kwargs, ): """Instantiate a `keras_nlp.models.Task` from a model preset. @@ -168,10 +167,6 @@ def from_preset( load_weights: bool. If `True`, saved weights will be loaded into the model architecture. If `False`, all weights will be randomly initialized. - load_task_extras: bool. If `True`, load the saved task configuration - from a `task.json` and any task specific weights from - `task.weights`. You might use this to load a classification - head for a model that has been saved with it. Examples: ```python @@ -199,7 +194,12 @@ def from_preset( # Detect the correct subclass if we need to. if cls.backbone_cls != backbone_cls: cls = find_subclass(preset, cls, backbone_cls) - return loader.load_task(cls, load_weights, load_task_extras, **kwargs) + # Specifically for classifiers, we never load task weights if + # num_classes is supplied. We handle this in the task base class because + # it is the some logic for classifiers regardless of modality (text, + # images, audio). + load_task_weights = "num_classes" not in kwargs + return loader.load_task(cls, load_weights, load_task_weights, **kwargs) def load_task_weights(self, filepath): """Load only the tasks specific weights not in the backbone.""" diff --git a/keras_nlp/src/models/task_test.py b/keras_nlp/src/models/task_test.py index 64c3336e85..9564caa711 100644 --- a/keras_nlp/src/models/task_test.py +++ b/keras_nlp/src/models/task_test.py @@ -138,9 +138,7 @@ def test_save_to_preset(self): self.assertEqual(BertTextClassifier, check_config_class(task_config)) # Try loading the model from preset directory. - restored_task = TextClassifier.from_preset( - save_dir, load_task_extras=True - ) + restored_task = TextClassifier.from_preset(save_dir) # Check the model output. data = ["the quick brown fox.", "the slow brown fox."] @@ -148,10 +146,8 @@ def test_save_to_preset(self): new_out = restored_task.predict(data) self.assertAllClose(ref_out, new_out) - # Load without head weights. - restored_task = TextClassifier.from_preset( - save_dir, load_task_extras=False, num_classes=2 - ) + # Load without head different head weights. + restored_task = TextClassifier.from_preset(save_dir, num_classes=2) data = ["the quick brown fox.", "the slow brown fox."] # Full output unequal. ref_out = task.predict(data) diff --git a/keras_nlp/src/models/text_classifier.py b/keras_nlp/src/models/text_classifier.py index 79eedca52c..6e72bde807 100644 --- a/keras_nlp/src/models/text_classifier.py +++ b/keras_nlp/src/models/text_classifier.py @@ -32,6 +32,12 @@ class TextClassifier(Task): All `TextClassifier` tasks include a `from_preset()` constructor which can be used to load a pre-trained config and weights. + Some classification presets (but not all), include classification head + weights in a `task.weights.h5`. For these presets, you can omit passing + `num_classes` to re-create the save classification head. For all presets, if + `num_classes` is passed as a kwarg to `from_preset()`, the classification + head will be randomly initialized. + Example: ```python # Load a BERT classifier with pre-trained weights. diff --git a/keras_nlp/src/utils/preset_utils.py b/keras_nlp/src/utils/preset_utils.py index 3a3451ca32..34c96de777 100644 --- a/keras_nlp/src/utils/preset_utils.py +++ b/keras_nlp/src/utils/preset_utils.py @@ -673,7 +673,7 @@ def load_image_converter(self, cls, **kwargs): """Load an image converter layer from the preset.""" raise NotImplementedError - def load_task(self, cls, load_weights, load_task_extras, **kwargs): + def load_task(self, cls, load_weights, load_task_weights, **kwargs): """Load a task model from the preset. By default, we create a task from a backbone and preprocessor with @@ -689,11 +689,10 @@ def load_task(self, cls, load_weights, load_task_extras, **kwargs): if "preprocessor" not in kwargs: kwargs["preprocessor"] = self.load_preprocessor( cls.preprocessor_cls, - load_task_extras=load_task_extras, ) return cls(**kwargs) - def load_preprocessor(self, cls, load_task_extras, **kwargs): + def load_preprocessor(self, cls, **kwargs): """Load a prepocessor layer from the preset. By default, we create a preprocessor from a tokenizer with default @@ -738,33 +737,25 @@ def load_image_converter(self, cls, **kwargs): converter_config = load_json(self.preset, IMAGE_CONVERTER_CONFIG_FILE) return load_serialized_object(converter_config, **kwargs) - def load_task(self, cls, load_weights, load_task_extras, **kwargs): + def load_task(self, cls, load_weights, load_task_weights, **kwargs): # If there is no `task.json` or it's for the wrong class delegate to the # super class loader. - if not load_task_extras: - return super().load_task( - cls, load_weights, load_task_extras, **kwargs - ) if not check_file_exists(self.preset, TASK_CONFIG_FILE): - raise ValueError( - "Saved preset has no `task.json`, cannot load the task config " - "from a file. Call `from_preset()` with " - "`load_task_extras=False` to load the task from a backbone " - "with library defaults." + return super().load_task( + cls, load_weights, load_task_weights, **kwargs ) task_config = load_json(self.preset, TASK_CONFIG_FILE) if not issubclass(check_config_class(task_config), cls): - raise ValueError( - f"Saved `task.json`does not match calling cls {cls}. Call " - "`from_preset()` with `load_task_extras=False` to load the " - "task from a backbone with library defaults." + return super().load_task( + cls, load_weights, load_task_weights, **kwargs ) # We found a `task.json` with a complete config for our class. task = load_serialized_object(task_config, **kwargs) if task.preprocessor is not None: task.preprocessor.tokenizer.load_preset_assets(self.preset) if load_weights: - if check_file_exists(self.preset, TASK_WEIGHTS_FILE): + has_task_weights = check_file_exists(self.preset, TASK_WEIGHTS_FILE) + if has_task_weights and load_task_weights: jax_memory_cleanup(task) task_weights = get_file(self.preset, TASK_WEIGHTS_FILE) task.load_task_weights(task_weights) @@ -774,23 +765,14 @@ def load_task(self, cls, load_weights, load_task_extras, **kwargs): task.backbone.load_weights(backbone_weights) return task - def load_preprocessor(self, cls, load_task_extras, **kwargs): - if not load_task_extras: - return super().load_preprocessor(cls, load_task_extras, **kwargs) + def load_preprocessor(self, cls, **kwargs): + # If there is no `preprocessing.json` or it's for the wrong class, + # delegate to the super class loader. if not check_file_exists(self.preset, PREPROCESSOR_CONFIG_FILE): - raise ValueError( - "Saved preset has no `preprocessor.json`, cannot load the task " - "preprocessing config from a file. Call `from_preset()` with " - "`load_task_extras=False` to load the preprocessor with " - "library defaults." - ) + return super().load_preprocessor(cls, **kwargs) preprocessor_json = load_json(self.preset, PREPROCESSOR_CONFIG_FILE) if not issubclass(check_config_class(preprocessor_json), cls): - raise ValueError( - f"Saved `preprocessor.json`does not match calling cls {cls}. " - "Call `from_preset()` with `load_task_extras=False` to " - "load the the preprocessor with library defaults." - ) + return super().load_preprocessor(cls, **kwargs) # We found a `preprocessing.json` with a complete config for our class. preprocessor = load_serialized_object(preprocessor_json, **kwargs) preprocessor.tokenizer.load_preset_assets(self.preset) From 6d423863d08d722ffdd3a4397e9878a89da10558 Mon Sep 17 00:00:00 2001 From: Matt Watson Date: Thu, 12 Sep 2024 16:00:23 -0700 Subject: [PATCH 2/2] address review comments --- keras_nlp/src/models/task.py | 2 +- keras_nlp/src/models/task_test.py | 2 +- keras_nlp/src/models/text_classifier.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/keras_nlp/src/models/task.py b/keras_nlp/src/models/task.py index 5fce9841e2..703bee764c 100644 --- a/keras_nlp/src/models/task.py +++ b/keras_nlp/src/models/task.py @@ -196,7 +196,7 @@ def from_preset( cls = find_subclass(preset, cls, backbone_cls) # Specifically for classifiers, we never load task weights if # num_classes is supplied. We handle this in the task base class because - # it is the some logic for classifiers regardless of modality (text, + # it is the same logic for classifiers regardless of modality (text, # images, audio). load_task_weights = "num_classes" not in kwargs return loader.load_task(cls, load_weights, load_task_weights, **kwargs) diff --git a/keras_nlp/src/models/task_test.py b/keras_nlp/src/models/task_test.py index 9564caa711..2eba398452 100644 --- a/keras_nlp/src/models/task_test.py +++ b/keras_nlp/src/models/task_test.py @@ -146,7 +146,7 @@ def test_save_to_preset(self): new_out = restored_task.predict(data) self.assertAllClose(ref_out, new_out) - # Load without head different head weights. + # Load classifier head with random weights. restored_task = TextClassifier.from_preset(save_dir, num_classes=2) data = ["the quick brown fox.", "the slow brown fox."] # Full output unequal. diff --git a/keras_nlp/src/models/text_classifier.py b/keras_nlp/src/models/text_classifier.py index 6e72bde807..6f62f52210 100644 --- a/keras_nlp/src/models/text_classifier.py +++ b/keras_nlp/src/models/text_classifier.py @@ -32,9 +32,9 @@ class TextClassifier(Task): All `TextClassifier` tasks include a `from_preset()` constructor which can be used to load a pre-trained config and weights. - Some classification presets (but not all), include classification head - weights in a `task.weights.h5`. For these presets, you can omit passing - `num_classes` to re-create the save classification head. For all presets, if + Some, but not all, classification presets include classification head + weights in a `task.weights.h5` file. For these presets, you can omit passing + `num_classes` to restore the saved classification head. For all presets, if `num_classes` is passed as a kwarg to `from_preset()`, the classification head will be randomly initialized.