diff --git a/keras_nlp/models/task.py b/keras_nlp/models/task.py
index 9df4d972ad..1b20181318 100644
--- a/keras_nlp/models/task.py
+++ b/keras_nlp/models/task.py
@@ -15,6 +15,7 @@
 import os
 
+import tensorflow as tf
 from tensorflow import keras
 
 from keras_nlp.utils.keras_utils import print_msg
@@ -33,6 +34,53 @@ def __init__(self, *args, **kwargs):
         self._preprocessor = None
         super().__init__(*args, **kwargs)
 
+    def _check_for_loss_mismatch(self):
+        """Check for a softmax/from_logits mismatch after compile.
+
+        We cannot handle this in the general case, but we can handle this for
+        the extremely common case of a single `SparseCategoricalCrossentropy`
+        loss, and a `None` or `"softmax"` activation.
+        """
+        # Only handle a single loss.
+        if tf.nest.is_nested(self.loss):
+            return
+        # Only handle tasks with activation.
+        if not hasattr(self, "activation"):
+            return
+
+        loss = keras.losses.get(self.loss)
+        activation = keras.activations.get(self.activation)
+        if isinstance(loss, keras.losses.SparseCategoricalCrossentropy):
+            from_logits = loss.get_config()["from_logits"]
+        elif loss == keras.losses.sparse_categorical_crossentropy:
+            from_logits = False
+        else:
+            # Only handle sparse categorical crossentropy.
+            return
+
+        is_softmax = activation == keras.activations.softmax
+        is_linear = activation == keras.activations.linear
+        if is_softmax and from_logits:
+            raise ValueError(
+                "The `loss` passed to `compile()` expects logit output, but "
+                "the model is configured to output softmax probabilities "
+                "(`activation='softmax'`). This will not converge! Pass "
+                "`from_logits=False` to your loss, e.g. "
+                "`loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False)`."
+            )
+        if is_linear and not from_logits:
+            raise ValueError(
+                "The `loss` passed to `compile()` expects softmax probability "
+                "output, but the model is configured to output logits "
+                "(`activation=None`). This will not converge! Pass "
+                "`from_logits=True` to your loss, e.g. "
+                "`loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True)`."
+            )
+
+    def compile(self, *args, **kwargs):
+        super().compile(*args, **kwargs)
+        self._check_for_loss_mismatch()
+
     def preprocess_samples(self, x, y=None, sample_weight=None):
         return self.preprocessor(x, y=y, sample_weight=sample_weight)
 
diff --git a/keras_nlp/models/task_test.py b/keras_nlp/models/task_test.py
index 72831a2c84..81d55bd090 100644
--- a/keras_nlp/models/task_test.py
+++ b/keras_nlp/models/task_test.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 import tensorflow as tf
 from tensorflow import keras
+from tensorflow.keras.losses import SparseCategoricalCrossentropy
 
 from keras_nlp.models.preprocessor import Preprocessor
 from keras_nlp.models.task import Task
@@ -31,11 +32,12 @@ def __init__(self, **kwargs):
 
 
 class SimpleTask(Task):
-    def __init__(self, preprocessor=None, **kwargs):
+    def __init__(self, preprocessor=None, activation=None, **kwargs):
         inputs = keras.Input(shape=(5,))
         outputs = keras.layers.Dense(5)(inputs)
         super().__init__(inputs, outputs, **kwargs)
         self.preprocessor = preprocessor
+        self.activation = keras.activations.get(activation)
 
 
 class TestTask(tf.test.TestCase):
@@ -51,3 +53,28 @@ def test_summary_without_preprocessor(self):
         summary = []
         model.summary(print_fn=lambda x: summary.append(x))
         self.assertNotRegex("\n".join(summary), "Preprocessor:")
+
+    def test_mismatched_loss(self):
+        # Logit output.
+        model = SimpleTask(activation=None)
+        model.compile(loss=SparseCategoricalCrossentropy(from_logits=True))
+        # Non-standard losses should not throw.
+        model.compile(loss="mean_squared_error")
+        with self.assertRaises(ValueError):
+            model.compile(loss="sparse_categorical_crossentropy")
+        with self.assertRaises(ValueError):
+            model.compile(loss=SparseCategoricalCrossentropy(from_logits=False))
+
+        # Probability output.
+        model = SimpleTask(activation="softmax")
+        model.compile(loss=SparseCategoricalCrossentropy(from_logits=False))
+        model.compile(loss="sparse_categorical_crossentropy")
+        # Non-standard losses should not throw.
+        model.compile(loss="mean_squared_error")
+        with self.assertRaises(ValueError):
+            model.compile(loss=SparseCategoricalCrossentropy(from_logits=True))
+
+        # Non-standard activations should not throw.
+        model = SimpleTask(activation="tanh")
+        model.compile(loss=SparseCategoricalCrossentropy(from_logits=True))
+        model.compile(loss=SparseCategoricalCrossentropy(from_logits=False))
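
Reviewer note (not part of the diff): a minimal sketch of how the new compile-time check surfaces to users. It mirrors the `SimpleTask` stub from `task_test.py` above; `DemoTask` is a hypothetical name introduced only for illustration.

    # Sketch only: DemoTask mirrors the SimpleTask test stub above.
    from tensorflow import keras

    from keras_nlp.models.task import Task


    class DemoTask(Task):
        def __init__(self, activation=None, **kwargs):
            inputs = keras.Input(shape=(5,))
            outputs = keras.layers.Dense(5)(inputs)
            super().__init__(inputs, outputs, **kwargs)
            self.activation = keras.activations.get(activation)


    # Logit output with a probability-expecting loss: the new check raises.
    model = DemoTask(activation=None)
    try:
        model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False))
    except ValueError as e:
        print(e)  # message tells the user to pass from_logits=True

    # The matching configuration compiles cleanly.
    model.compile(loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True))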