48 changes: 48 additions & 0 deletions keras_nlp/models/task.py
@@ -15,6 +15,7 @@

import os

import tensorflow as tf
from tensorflow import keras

from keras_nlp.utils.keras_utils import print_msg
@@ -33,6 +34,53 @@ def __init__(self, *args, **kwargs):
self._preprocessor = None
super().__init__(*args, **kwargs)

def _check_for_loss_mismatch(self):
"""Check for a softmax/from_logits mismatch after compile.

We cannot handle this in the general case, but we can handle this for
the extremely common case of a single `SparseCategoricalCrossentropy`
loss, and a `None` or `"softmax"` activation.
"""
# Only handle a single loss.
if tf.nest.is_nested(self.loss):
return
# Only handle tasks with activation.
if not hasattr(self, "activation"):
return

loss = keras.losses.get(self.loss)
activation = keras.activations.get(self.activation)
if isinstance(loss, keras.losses.SparseCategoricalCrossentropy):
from_logits = loss.get_config()["from_logits"]
elif loss == keras.losses.sparse_categorical_crossentropy:
from_logits = False
else:
# Only handle sparse categorical crossentropy.
return

is_softmax = activation == keras.activations.softmax
is_linear = activation == keras.activations.linear
if is_softmax and from_logits:
raise ValueError(
"The `loss` passed to `compile()` expects logit output, but "
"the model is configured to output softmax probabilities "
"(`activation='softmax'`). This will not converge! Pass "
"`from_logits=False` to your loss, e.g. "
"`loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False)`. "
)
if is_linear and not from_logits:
raise ValueError(
"The `loss` passed to `compile()` expects softmax probability "
"output, but the model is configured to output logits "
"(`activation=None`). This will not converge! Pass "
"`from_logits=True` to your loss, e.g. "
"`loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True)`. "
)

def compile(self, *args, **kwargs):
super().compile(*args, **kwargs)
self._check_for_loss_mismatch()

def preprocess_samples(self, x, y=None, sample_weight=None):
return self.preprocessor(x, y=y, sample_weight=sample_weight)

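For illustration, here is a minimal sketch of how the new check behaves from a user's point of view. The `ToyTask` subclass is hypothetical (modeled on the `SimpleTask` helper in the test file below), not part of this change:

```python
from tensorflow import keras

from keras_nlp.models.task import Task


class ToyTask(Task):
    """Tiny functional-model task that records its output activation."""

    def __init__(self, activation=None, **kwargs):
        inputs = keras.Input(shape=(5,))
        outputs = keras.layers.Dense(5, activation=activation)(inputs)
        super().__init__(inputs, outputs, **kwargs)
        # `_check_for_loss_mismatch()` looks at this attribute.
        self.activation = keras.activations.get(activation)


# Softmax output combined with `from_logits=True` is flagged at compile time.
model = ToyTask(activation="softmax")
try:
    model.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True)
    )
except ValueError as e:
    print(e)  # "The `loss` passed to `compile()` expects logit output, ..."

# A consistent configuration compiles cleanly.
model.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=False)
)
```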
29 changes: 28 additions & 1 deletion keras_nlp/models/task_test.py
@@ -13,6 +13,7 @@
# limitations under the License.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.losses import SparseCategoricalCrossentropy

from keras_nlp.models.preprocessor import Preprocessor
from keras_nlp.models.task import Task
@@ -31,11 +32,12 @@ def __init__(self, **kwargs):


class SimpleTask(Task):
def __init__(self, preprocessor=None, **kwargs):
def __init__(self, preprocessor=None, activation=None, **kwargs):
inputs = keras.Input(shape=(5,))
outputs = keras.layers.Dense(5)(inputs)
super().__init__(inputs, outputs, **kwargs)
self.preprocessor = preprocessor
self.activation = keras.activations.get(activation)


class TestTask(tf.test.TestCase):
@@ -51,3 +53,28 @@ def test_summary_without_preprocessor(self):
summary = []
model.summary(print_fn=lambda x: summary.append(x))
self.assertNotRegex("\n".join(summary), "Preprocessor:")

def test_mismatched_loss(self):
# Logit output.
model = SimpleTask(activation=None)
model.compile(loss=SparseCategoricalCrossentropy(from_logits=True))
# Non-standard losses should not throw.
model.compile(loss="mean_squared_error")
with self.assertRaises(ValueError):
model.compile(loss="sparse_categorical_crossentropy")
with self.assertRaises(ValueError):
model.compile(loss=SparseCategoricalCrossentropy(from_logits=False))

# Probability output.
model = SimpleTask(activation="softmax")
model.compile(loss=SparseCategoricalCrossentropy(from_logits=False))
model.compile(loss="sparse_categorical_crossentropy")
# Non-standard losses should not throw.
model.compile(loss="mean_squared_error")
with self.assertRaises(ValueError):
model.compile(loss=SparseCategoricalCrossentropy(from_logits=True))

# Non-standard activations should not throw.
model = SimpleTask(activation="tanh")
model.compile(loss=SparseCategoricalCrossentropy(from_logits=True))
model.compile(loss=SparseCategoricalCrossentropy(from_logits=False))
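As a side note (my own illustration, not part of the diff), the check relies on how `keras.losses.get` normalizes identifiers: a string resolves to the bare loss function, which defaults to `from_logits=False`, while a `Loss` instance carries the flag in its config. A quick sketch:

```python
from tensorflow import keras

# A string identifier resolves to the bare function, whose `from_logits`
# argument defaults to False ...
loss_fn = keras.losses.get("sparse_categorical_crossentropy")
assert loss_fn == keras.losses.sparse_categorical_crossentropy

# ... while an instance records its own `from_logits` setting in its config.
loss_obj = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
assert loss_obj.get_config()["from_logits"]
```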