diff --git a/keras_nlp/models/f_net/f_net_masked_lm.py b/keras_nlp/models/f_net/f_net_masked_lm.py index 65a4f43efa..67bc50f74d 100644 --- a/keras_nlp/models/f_net/f_net_masked_lm.py +++ b/keras_nlp/models/f_net/f_net_masked_lm.py @@ -24,6 +24,7 @@ ) from keras_nlp.models.f_net.f_net_presets import backbone_presets from keras_nlp.models.task import Task +from keras_nlp.utils.keras_utils import is_xla_compatible from keras_nlp.utils.python_utils import classproperty @@ -135,6 +136,12 @@ def __init__( # All references to `self` below this line self.backbone = backbone self.preprocessor = preprocessor + self.compile( + loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True), + optimizer=keras.optimizers.Adam(5e-5), + weighted_metrics=keras.metrics.SparseCategoricalAccuracy(), + jit_compile=is_xla_compatible(self), + ) @classproperty def backbone_cls(cls): diff --git a/keras_nlp/models/f_net/f_net_masked_lm_test.py b/keras_nlp/models/f_net/f_net_masked_lm_test.py index dbf12202ec..3e7e7c9c99 100644 --- a/keras_nlp/models/f_net/f_net_masked_lm_test.py +++ b/keras_nlp/models/f_net/f_net_masked_lm_test.py @@ -87,6 +87,9 @@ def setUp(self): def test_valid_call_masked_lm(self): self.masked_lm(self.preprocessed_batch) + def test_fnet_masked_lm_fit_default_compile(self): + self.masked_lm.fit(self.raw_dataset) + @parameterized.named_parameters( ("jit_compile_false", False), ("jit_compile_true", True) )