diff --git a/keras_nlp/models/__init__.py b/keras_nlp/models/__init__.py index 1abfc0dc8..033a9dc87 100644 --- a/keras_nlp/models/__init__.py +++ b/keras_nlp/models/__init__.py @@ -20,6 +20,7 @@ ) from keras_nlp.models.albert.albert_preprocessor import AlbertPreprocessor from keras_nlp.models.albert.albert_tokenizer import AlbertTokenizer +from keras_nlp.models.backbone import Backbone from keras_nlp.models.bart.bart_backbone import BartBackbone from keras_nlp.models.bart.bart_preprocessor import BartPreprocessor from keras_nlp.models.bart.bart_seq_2_seq_lm import BartSeq2SeqLM @@ -130,6 +131,7 @@ from keras_nlp.models.roberta.roberta_tokenizer import RobertaTokenizer from keras_nlp.models.t5.t5_backbone import T5Backbone from keras_nlp.models.t5.t5_tokenizer import T5Tokenizer +from keras_nlp.models.task import Task from keras_nlp.models.whisper.whisper_audio_feature_extractor import ( WhisperAudioFeatureExtractor, ) diff --git a/keras_nlp/models/backbone.py b/keras_nlp/models/backbone.py index 867616da6..bfdc8207a 100644 --- a/keras_nlp/models/backbone.py +++ b/keras_nlp/models/backbone.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import config from keras_nlp.backend import keras from keras_nlp.utils.preset_utils import check_preset_class @@ -20,7 +21,7 @@ from keras_nlp.utils.python_utils import format_docstring -@keras.saving.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.models.Backbone") class Backbone(keras.Model): def __init__(self, *args, dtype=None, **kwargs): super().__init__(*args, **kwargs) diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py index 0656d2194..9957f6546 100644 --- a/keras_nlp/models/task.py +++ b/keras_nlp/models/task.py @@ -16,6 +16,7 @@ from rich import markup from rich import table as rich_table +from keras_nlp.api_export import keras_nlp_export from keras_nlp.backend import config from keras_nlp.backend import keras from keras_nlp.utils.keras_utils import print_msg @@ -26,7 +27,7 @@ from keras_nlp.utils.python_utils import format_docstring -@keras.saving.register_keras_serializable(package="keras_nlp") +@keras_nlp_export("keras_nlp.models.Task") class Task(PipelineModel): """Base class for Task models."""