diff --git a/keras_nlp/models/albert/albert_masked_lm.py b/keras_nlp/models/albert/albert_masked_lm.py index ca7f21dd2e..140b23ee54 100644 --- a/keras_nlp/models/albert/albert_masked_lm.py +++ b/keras_nlp/models/albert/albert_masked_lm.py @@ -27,6 +27,7 @@ ) from keras_nlp.models.albert.albert_presets import backbone_presets from keras_nlp.models.task import Task +from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -142,6 +143,13 @@ def __init__(self, backbone, preprocessor=None, **kwargs): self.backbone = backbone self.preprocessor = preprocessor + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + weighted_metrics=keras.metrics.SparseCategoricalAccuracy(), + jit_compile=is_xla_compatible(self), + ) + @classproperty def backbone_cls(cls): return AlbertBackbone diff --git a/keras_nlp/models/albert/albert_masked_lm_test.py b/keras_nlp/models/albert/albert_masked_lm_test.py index 557b1bd411..528ae0c5f3 100644 --- a/keras_nlp/models/albert/albert_masked_lm_test.py +++ b/keras_nlp/models/albert/albert_masked_lm_test.py @@ -114,6 +114,9 @@ def test_albert_masked_lm_predict_no_preprocessing(self, jit_compile): self.masked_lm_no_preprocessing.compile(jit_compile=jit_compile) self.masked_lm_no_preprocessing.predict(self.preprocessed_batch) + def test_albert_masked_lm_fit_default_compile(self): + self.masked_lm.fit(self.raw_dataset) + @parameterized.named_parameters( ("jit_compile_false", False), ("jit_compile_true", True) ) diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py index 79aca77fac..37f3edc15c 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm.py @@ -28,6 +28,7 @@ ) from keras_nlp.models.distil_bert.distil_bert_presets import backbone_presets from keras_nlp.models.task import Task +from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -143,6 +144,13 @@ def __init__( self.backbone = backbone self.preprocessor = preprocessor + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + weighted_metrics=keras.metrics.SparseCategoricalAccuracy(), + jit_compile=is_xla_compatible(self), + ) + @classproperty def backbone_cls(cls): return DistilBertBackbone diff --git a/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py b/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py index 76a5774731..995a993123 100644 --- a/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py +++ b/keras_nlp/models/distil_bert/distil_bert_masked_lm_test.py @@ -89,6 +89,9 @@ def test_distilbert_masked_lm_predict_no_preprocessing(self, jit_compile): self.masked_lm_no_preprocessing.compile(jit_compile=jit_compile) self.masked_lm_no_preprocessing.predict(self.preprocessed_batch) + def test_distil_bert_masked_lm_fit_default_compile(self): + self.masked_lm.fit(self.raw_dataset) + @parameterized.named_parameters( ("jit_compile_false", False), ("jit_compile_true", True) ) diff --git a/keras_nlp/models/roberta/roberta_masked_lm.py b/keras_nlp/models/roberta/roberta_masked_lm.py index dba87ea8ab..43447be4c1 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm.py +++ b/keras_nlp/models/roberta/roberta_masked_lm.py @@ -26,6 +26,7 @@ ) from keras_nlp.models.roberta.roberta_presets import backbone_presets from keras_nlp.models.task import Task +from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -141,6 +142,13 @@ def __init__( self.backbone = backbone self.preprocessor = preprocessor + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + weighted_metrics=keras.metrics.SparseCategoricalAccuracy(), + jit_compile=is_xla_compatible(self), + ) + @classproperty def backbone_cls(cls): return RobertaBackbone diff --git a/keras_nlp/models/roberta/roberta_masked_lm_test.py b/keras_nlp/models/roberta/roberta_masked_lm_test.py index 4026dc9af2..a80c83086b 100644 --- a/keras_nlp/models/roberta/roberta_masked_lm_test.py +++ b/keras_nlp/models/roberta/roberta_masked_lm_test.py @@ -103,6 +103,9 @@ def test_roberta_masked_lm_predict_no_preprocessing(self, jit_compile): self.masked_lm_no_preprocessing.compile(jit_compile=jit_compile) self.masked_lm_no_preprocessing.predict(self.preprocessed_batch) + def test_roberta_masked_lm_fit_default_compile(self): + self.masked_lm.fit(self.raw_dataset) + @parameterized.named_parameters( ("jit_compile_false", False), ("jit_compile_true", True) )