diff --git a/keras_nlp/metrics/__init__.py b/keras_nlp/metrics/__init__.py index 7152a97032..55ade6dc8a 100644 --- a/keras_nlp/metrics/__init__.py +++ b/keras_nlp/metrics/__init__.py @@ -13,3 +13,5 @@ # limitations under the License. from keras_nlp.metrics.perplexity import Perplexity +from keras_nlp.metrics.rouge_l import RougeL +from keras_nlp.metrics.rouge_n import RougeN diff --git a/keras_nlp/metrics/rouge_base.py b/keras_nlp/metrics/rouge_base.py new file mode 100644 index 0000000000..22d4adf3b8 --- /dev/null +++ b/keras_nlp/metrics/rouge_base.py @@ -0,0 +1,223 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ROUGE metric implementation based on `keras.metrics.Metric`.""" + + +import types + +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.utils.tensor_utils import tensor_to_string_list + +try: + import rouge_score + from rouge_score import rouge_scorer +except ImportError: + rouge_score = None + + +class RougeBase(keras.metrics.Metric): + """ROUGE metric. + + This class implements two variants of the ROUGE metric - ROUGE-N, + and ROUGE-L. + + Note on input shapes: + For `y_true` and `y_pred`, this class supports scalar values and batch + inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`. + + Args: + variant: string. One of "rougeN", "rougeL". Defaults to + "rouge2". For "rougeN", N lies in the range [1, 9]. + use_stemmer: bool. Whether Porter Stemmer should be used to strip word + suffixes to improve matching. Defaults to False. + dtype: string or tf.dtypes.Dtype. Precision of metric computation. If + not specified, it defaults to tf.float32. + name: string. Name of the metric instance. + **kwargs: Other keyword arguments. + """ + + def __init__( + self, + variant="rouge2", + use_stemmer=False, + dtype=None, + name="rouge", + **kwargs, + ): + super().__init__(name=name, dtype=dtype, **kwargs) + + if rouge_score is None: + raise ImportError( + "ROUGE metric requires the `rouge_score` package. " + "Please install it with `pip install rouge-score`." + ) + + if not tf.as_dtype(self.dtype).is_floating: + raise ValueError( + "`dtype` must be a floating point type. " + f"Received: dtype={dtype}" + ) + + if variant not in tuple( + ("rouge" + str(order) for order in range(1, 10)) + ) + ("rougeL",): + raise ValueError( + "Invalid variant of ROUGE. Should be one of: rougeN, rougeL, " + "with N ranging from 1 to 9. Received: " + f"variant={variant}" + ) + + self.variant = variant + self.use_stemmer = use_stemmer + + # To-do: Add split_summaries and tokenizer options after the maintainers + # of rouge_scorer have released a new version. + self._rouge_scorer = rouge_scorer.RougeScorer( + rouge_types=[self.variant], + use_stemmer=use_stemmer, + ) + + self._rouge_precision = self.add_weight( + name="rouge_precision", + initializer="zeros", + dtype=self.dtype, + ) + self._rouge_recall = self.add_weight( + name="rouge_recall", + initializer="zeros", + dtype=self.dtype, + ) + self._rouge_f1_score = self.add_weight( + name="rouge_f1_score", + initializer="zeros", + dtype=self.dtype, + ) + + self._number_of_samples = self.add_weight( + name="number_of_samples", initializer="zeros", dtype=self.dtype + ) + + def __new__(cls, *args, **kwargs): + # Temporary workaround for Keras bug with dictionary return types. + # Wraps `result()` with a python dictionary that also supports variable + # assignment. We have to do this with __new__ because the base metric + # class wraps the `results()` method. + # TODO: Remove this snippet of code once the Keras bug is fixed. + obj = super().__new__(cls) + + class MetricDict(dict): + """A dictionary that supports variable assignment.""" + + pass + + def wrap_result(result_fn): + return tf.__internal__.decorator.make_decorator( + result_fn, lambda obj, *args: MetricDict(result_fn(*args)) + ) + + obj.result = types.MethodType(wrap_result(obj.result), obj) + return obj + + def update_state(self, y_true, y_pred, sample_weight=None): + # Three possible shapes for y_true and y_pred: Python string, + # [batch_size] and [batch_size, 1]. In the latter two cases, we have + # strings in the tensor/list. + + def validate_and_fix_rank(inputs, tensor_name): + if not isinstance(inputs, tf.Tensor): + inputs = tf.convert_to_tensor(inputs) + + if inputs.shape.rank == 0: + return inputs[tf.newaxis] + elif inputs.shape.rank == 1: + return inputs + elif inputs.shape.rank == 2: + if inputs.shape[1] != 1: + raise ValueError( + f"{tensor_name} must be of shape `[batch_size, 1]`. " + f"Found shape: {inputs.shape}" + ) + else: + return tf.squeeze(inputs, axis=1) + else: + raise ValueError( + f"{tensor_name} must be of rank 0 (scalar input), 1 or 2. " + f"Found rank: {inputs.shape.rank}" + ) + + y_true = validate_and_fix_rank(y_true, "y_true") + y_pred = validate_and_fix_rank(y_pred, "y_pred") + + batch_size = tf.shape(y_true)[0] + + def calculate_rouge_score(reference, hypothesis): + reference = tensor_to_string_list(reference) + hypothesis = tensor_to_string_list(hypothesis) + score = self._rouge_scorer.score(reference, hypothesis)[ + self.variant + ] + return tf.cast( + tf.constant([score.precision, score.recall, score.fmeasure]), + dtype=self.dtype, + ) + + for batch_idx in range(batch_size): + score = tf.py_function( + func=calculate_rouge_score, + inp=[y_true[batch_idx], y_pred[batch_idx]], + Tout=self.dtype, + ) + self._rouge_precision.assign_add(score[0]) + self._rouge_recall.assign_add(score[1]) + self._rouge_f1_score.assign_add(score[2]) + + self._number_of_samples.assign_add( + tf.cast(batch_size, dtype=self.dtype) + ) + + def result(self): + if self._number_of_samples == 0: + return { + "precision": 0.0, + "recall": 0.0, + "f1_score": 0.0, + } + + rouge_precision = self._rouge_precision / self._number_of_samples + rouge_recall = self._rouge_recall / self._number_of_samples + rouge_f1_score = self._rouge_f1_score / self._number_of_samples + return { + "precision": rouge_precision, + "recall": rouge_recall, + "f1_score": rouge_f1_score, + } + + def reset_state(self): + self._rouge_precision.assign(0.0) + self._rouge_recall.assign(0.0) + self._rouge_f1_score.assign(0.0) + self._number_of_samples.assign(0.0) + + def get_config(self): + config = super().get_config() + config.update( + { + "variant": self.variant, + "use_stemmer": self.use_stemmer, + } + ) + return config diff --git a/keras_nlp/metrics/rouge_l.py b/keras_nlp/metrics/rouge_l.py new file mode 100644 index 0000000000..f6969a85f6 --- /dev/null +++ b/keras_nlp/metrics/rouge_l.py @@ -0,0 +1,129 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ROUGE-L metric implementation based on `keras.metrics.Metric`.""" + + +from keras_nlp.metrics.rouge_base import RougeBase + + +class RougeL(RougeBase): + """ROUGE-L metric. + + This class implements the ROUGE-L variant of the ROUGE metric. The ROUGE-L + metric is traditionally used for evaluating summarisation systems. + Succinctly put, ROUGE-L is a score based on the length of the longest + common subsequence present in the reference text and the hypothesis text. + + Note on input shapes: + For `y_true` and `y_pred`, this class supports scalar values and batch + inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`. + + Args: + use_stemmer: bool. Whether Porter Stemmer should be used to strip word + suffixes to improve matching. Defaults to False. + dtype: string or tf.dtypes.Dtype. Precision of metric computation. If + not specified, it defaults to tf.float32. + name: string. Name of the metric instance. + **kwargs: Other keyword arguments. + + Examples: + + 1. Various Input Types. + 1.1. Python string. + >>> rouge_l = keras_nlp.metrics.RougeL() + >>> y_true = "the tiny little cat was found under the big funny bed" + >>> y_pred = "the cat was under the bed" + >>> rouge_l(y_true, y_pred)["f1_score"] + + + 1.2. rank 1 inputs. + a. Python list. + >>> rouge_l = keras_nlp.metrics.RougeL() + >>> y_true = [ + ... "the tiny little cat was found under the big funny bed", + ... "i really love contributing to KerasNLP", + ... ] + >>> y_pred = [ + ... "the cat was under the bed", + ... "i love contributing to KerasNLP", + ... ] + >>> rouge_l(y_true, y_pred)["f1_score"] + + + b. Tensor + >>> rouge_l = keras_nlp.metrics.RougeL() + >>> y_true = tf.constant( + ... [ + ... "the tiny little cat was found under the big funny bed", + ... "i really love contributing to KerasNLP", + ... ] + ... ) + >>> y_pred = tf.constant( + ... [ + ... "the cat was under the bed", + ... "i love contributing to KerasNLP", + ... ] + ... ) + >>> rouge_l(y_true, y_pred)["f1_score"] + + + 1.3. rank 2 inputs. + >>> rouge_l = keras_nlp.metrics.RougeL() + >>> y_true = tf.constant( + ... [ + ... ["the tiny little cat was found under the big funny bed"], + ... ["i really love contributing to KerasNLP"], + ... ] + ... ) + >>> y_pred = tf.constant( + ... [ + ... ["the cat was under the bed"], + ... ["i love contributing to KerasNLP"], + ... ] + ... ) + >>> rouge_l(y_true, y_pred)["f1_score"] + + + 3. Pass the metric to `model.compile()`. + >>> inputs = keras.Input(shape=(), dtype='string') + >>> outputs = tf.strings.lower(inputs) + >>> model = keras.Model(inputs, outputs) + >>> model.compile(metrics=[keras_nlp.metrics.RougeL()]) + >>> x = tf.constant(["HELLO THIS IS FUN"]) + >>> y = tf.constant(["hello this is awesome"]) + >>> metric_dict = model.evaluate(x, y, return_dict=True) + >>> metric_dict["f1_score"] + 0.75 + """ + + def __init__( + self, + use_stemmer=False, + dtype=None, + name="rouge-l", + **kwargs, + ): + super().__init__( + variant="rougeL", + use_stemmer=use_stemmer, + dtype=dtype, + name=name, + **kwargs, + ) + + def get_config(self): + config = super().get_config() + del config["variant"] + return config diff --git a/keras_nlp/metrics/rouge_l_test.py b/keras_nlp/metrics/rouge_l_test.py new file mode 100644 index 0000000000..d130e12190 --- /dev/null +++ b/keras_nlp/metrics/rouge_l_test.py @@ -0,0 +1,228 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RougeL.""" + +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.metrics import RougeL + + +class RougeLTest(tf.test.TestCase): + def setUp(self): + super().setUp() + + def assertDictAlmostEqual(d1, d2, delta=1e-3, typecast_to_numpy=True): + for key, val in d1.items(): + if typecast_to_numpy: + val = val.numpy() + self.assertAlmostEqual(val, d2[key], delta=delta) + + def assertDictAllValuesNotEqual(d1, d2): + for key, val in d1.items(): + self.assertNotEqual(val, d2[key]) + + self.assertDictAlmostEqual = assertDictAlmostEqual + self.assertDictAllValuesNotEqual = assertDictAllValuesNotEqual + + def test_initialization(self): + rouge = RougeL() + result = rouge.result() + + self.assertDictEqual( + result, {"precision": 0.0, "recall": 0.0, "f1_score": 0.0} + ) + + def test_string_input(self): + rouge = RougeL(use_stemmer=False) + y_true = "the tiny little cat was found under the big funny bed" + y_pred = "the cat was under the bed" + + rouge_val = rouge(y_true, y_pred) + self.assertDictAlmostEqual( + rouge_val, {"precision": 1.0, "recall": 0.545, "f1_score": 0.706} + ) + + def test_string_list_input(self): + rouge = RougeL(use_stemmer=False) + y_true = [ + "the tiny little cat was found under the big funny bed", + "i really love contributing to KerasNLP", + ] + y_pred = [ + "the cat was under the bed", + "i love contributing to KerasNLP", + ] + + rouge_val = rouge(y_true, y_pred) + self.assertDictAlmostEqual( + rouge_val, {"precision": 1.0, "recall": 0.689, "f1_score": 0.807} + ) + + def test_tensor_input(self): + rouge = RougeL(use_stemmer=False) + y_true = tf.constant( + [ + "the tiny little cat was found under the big funny bed", + "i really love contributing to KerasNLP", + ] + ) + y_pred = tf.constant( + ["the cat was under the bed", "i love contributing to KerasNLP"] + ) + + rouge_val = rouge(y_true, y_pred) + self.assertDictAlmostEqual( + rouge_val, {"precision": 1.0, "recall": 0.689, "f1_score": 0.807} + ) + + def test_rank_2_input(self): + rouge = RougeL(use_stemmer=False) + y_true = tf.constant( + [ + ["the tiny little cat was found under the big funny bed"], + ["i really love contributing to KerasNLP"], + ] + ) + y_pred = tf.constant( + [["the cat was under the bed"], ["i love contributing to KerasNLP"]] + ) + + rouge_val = rouge(y_true, y_pred) + self.assertDictAlmostEqual( + rouge_val, {"precision": 1.0, "recall": 0.689, "f1_score": 0.807} + ) + + def test_model_compile(self): + inputs = keras.Input(shape=(), dtype="string") + outputs = tf.strings.lower(inputs) + model = keras.Model(inputs, outputs) + + model.compile(metrics=[RougeL()]) + + x = tf.constant(["HELLO THIS IS FUN"]) + y = tf.constant(["hello this is awesome"]) + + output = model.evaluate(x, y, return_dict=True) + del output["loss"] + self.assertDictAlmostEqual( + output, + {"precision": 0.75, "recall": 0.75, "f1_score": 0.75}, + typecast_to_numpy=False, + ) + + def test_reset_state(self): + rouge = RougeL() + y_true = tf.constant( + ["hey, this is great fun", "i love contributing to KerasNLP"] + ) + y_pred = tf.constant( + [ + "great fun indeed", + "KerasNLP is awesome, i love contributing to it", + ] + ) + + rouge.update_state(y_true, y_pred) + rouge_val = rouge.result() + self.assertDictAllValuesNotEqual( + rouge_val, {"precision": 0.0, "recall": 0.0, "f1_score": 0.0} + ) + + rouge.reset_state() + rouge_val = rouge.result() + self.assertDictEqual( + rouge_val, {"precision": 0.0, "recall": 0.0, "f1_score": 0.0} + ) + + def test_update_state(self): + rouge = RougeL() + y_true_1 = tf.constant( + [ + "the tiny little cat was found under the big funny bed", + "i really love contributing to KerasNLP", + ] + ) + y_pred_1 = tf.constant( + ["the cat was under the bed", "i love contributing to KerasNLP"] + ) + + rouge.update_state(y_true_1, y_pred_1) + rouge_val = rouge.result() + self.assertDictAlmostEqual( + rouge_val, {"precision": 1.0, "recall": 0.689, "f1_score": 0.807} + ) + + y_true_2 = tf.constant(["what is your favourite show"]) + y_pred_2 = tf.constant(["my favourite show is silicon valley"]) + + rouge.update_state(y_true_2, y_pred_2) + rouge_val = rouge.result() + self.assertDictAlmostEqual( + rouge_val, {"precision": 0.778, "recall": 0.593, "f1_score": 0.66} + ) + + def test_merge_state(self): + rouge_1 = RougeL() + rouge_2 = RougeL() + + y_true_1 = tf.constant( + [ + "the tiny little cat was found under the big funny bed", + "i really love contributing to KerasNLP", + ] + ) + y_pred_1 = tf.constant( + ["the cat was under the bed", "i love contributing to KerasNLP"] + ) + + y_true_2 = tf.constant(["what is your favourite show"]) + y_pred_2 = tf.constant(["my favourite show is silicon valley"]) + + y_true_3 = tf.constant(["lorem ipsum dolor sit amet"]) + y_pred_3 = tf.constant(["lorem ipsum is simply dummy text"]) + + rouge_1.update_state(y_true_1, y_pred_1) + rouge_1.update_state(y_true_2, y_pred_2) + rouge_val = rouge_1.result() + self.assertDictAlmostEqual( + rouge_val, {"precision": 0.778, "recall": 0.593, "f1_score": 0.66} + ) + + rouge_2.update_state(y_true_3, y_pred_3) + rouge_val = rouge_2.result() + self.assertDictAlmostEqual( + rouge_val, {"precision": 0.333, "recall": 0.4, "f1_score": 0.364} + ) + + merged_rouge = RougeL() + merged_rouge.merge_state([rouge_1, rouge_2]) + rouge_val = merged_rouge.result() + self.assertDictAlmostEqual( + rouge_val, {"precision": 0.667, "recall": 0.545, "f1_score": 0.586} + ) + + def test_get_config(self): + rouge = RougeL( + use_stemmer=True, + dtype=tf.float32, + name="rouge_l_test", + ) + + config = rouge.get_config() + expected_config_subset = { + "use_stemmer": True, + } + self.assertEqual(config, {**config, **expected_config_subset}) diff --git a/keras_nlp/metrics/rouge_n.py b/keras_nlp/metrics/rouge_n.py new file mode 100644 index 0000000000..4bfe532ee2 --- /dev/null +++ b/keras_nlp/metrics/rouge_n.py @@ -0,0 +1,163 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""ROUGE-N metric implementation based on `keras.metrics.Metric`.""" + + +from keras_nlp.metrics.rouge_base import RougeBase + + +class RougeN(RougeBase): + """ROUGE-N metric. + + This class implements the ROUGE-N variant of the ROUGE metric. The ROUGE-N + metric is traditionally used for evaluating summarisation systems. + Succinctly put, ROUGE-N is a score based on the number of matching n-grams + between the reference text and the hypothesis text. + + Note on input shapes: + For `y_true` and `y_pred`, this class supports scalar values and batch + inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`. + + Args: + order: The order of n-grams which are to be matched. It should lie in + range [1, 9]. Defaults to 2. + use_stemmer: bool. Whether Porter Stemmer should be used to strip word + suffixes to improve matching. Defaults to False. + dtype: string or tf.dtypes.Dtype. Precision of metric computation. If + not specified, it defaults to tf.float32. + name: string. Name of the metric instance. + **kwargs: Other keyword arguments. + + Examples: + + 1. Various Input Types. + 1.1. Python string. + >>> rouge_n = keras_nlp.metrics.RougeN(order=2) + >>> y_true = "the tiny little cat was found under the big funny bed" + >>> y_pred = "the cat was under the bed" + >>> rouge_n(y_true, y_pred)["f1_score"] + + + 1.2. rank 1 inputs. + a. Python list. + >>> rouge_n = keras_nlp.metrics.RougeN(order=2) + >>> y_true = [ + ... "the tiny little cat was found under the big funny bed", + ... "i really love contributing to KerasNLP", + ... ] + >>> y_pred = [ + ... "the cat was under the bed", + ... "i love contributing to KerasNLP", + ... ] + >>> rouge_n(y_true, y_pred)["f1_score"] + + + b. Tensor. + >>> rouge_n = keras_nlp.metrics.RougeN(order=2) + >>> y_true = tf.constant( + ... [ + ... "the tiny little cat was found under the big funny bed", + ... "i really love contributing to KerasNLP", + ... ] + ... ) + >>> y_pred = tf.constant( + ... [ + ... "the cat was under the bed", + ... "i love contributing to KerasNLP", + ... ] + ... ) + >>> rouge_n(y_true, y_pred)["f1_score"] + + + 1.3. rank 2 inputs. + >>> rouge_n = keras_nlp.metrics.RougeN(order=2) + >>> y_true = tf.constant( + ... [ + ... ["the tiny little cat was found under the big funny bed"], + ... ["i really love contributing to KerasNLP"], + ... ] + ... ) + >>> y_pred = tf.constant( + ... [ + ... ["the cat was under the bed"], + ... ["i love contributing to KerasNLP"], + ... ] + ... ) + >>> rouge_n(y_true, y_pred)["f1_score"] + + + 2. Consider trigrams for calculating ROUGE-N. + >>> rouge_n = keras_nlp.metrics.RougeN(order=3) + >>> y_true = tf.constant( + ... [ + ... "the tiny little cat was found under the big funny bed", + ... "i really love contributing to KerasNLP", + ... ] + ... ) + >>> y_pred = tf.constant( + ... [ + ... "the cat was under the bed", + ... "i love contributing to KerasNLP", + ... ] + ... ) + >>> rouge_n(y_true, y_pred)["f1_score"] + + + 3. Pass the metric to `model.compile()`. + >>> inputs = keras.Input(shape=(), dtype='string') + >>> outputs = tf.strings.lower(inputs) + >>> model = keras.Model(inputs, outputs) + >>> model.compile(metrics=[keras_nlp.metrics.RougeN()]) + >>> x = tf.constant(["HELLO THIS IS FUN"]) + >>> y = tf.constant(["hello this is awesome"]) + >>> metric_dict = model.evaluate(x, y, return_dict=True) + >>> metric_dict["f1_score"] + 0.6666666865348816 + """ + + def __init__( + self, + order=2, + use_stemmer=False, + dtype=None, + name="rouge-n", + **kwargs, + ): + if order not in range(1, 10): + raise ValueError( + "Invalid `order` value. Should lie in the range [1, 9]." + f"Received order={order}" + ) + + super().__init__( + variant=f"rouge{order}", + use_stemmer=use_stemmer, + dtype=dtype, + name=name, + **kwargs, + ) + + self.order = order + + def get_config(self): + config = super().get_config() + del config["variant"] + + config.update( + { + "order": self.order, + } + ) + return config diff --git a/keras_nlp/metrics/rouge_n_test.py b/keras_nlp/metrics/rouge_n_test.py new file mode 100644 index 0000000000..2183afe3fe --- /dev/null +++ b/keras_nlp/metrics/rouge_n_test.py @@ -0,0 +1,254 @@ +# Copyright 2022 The KerasNLP Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Tests for RougeN.""" + +import tensorflow as tf +from tensorflow import keras + +from keras_nlp.metrics import RougeN + + +class RougeNTest(tf.test.TestCase): + def setUp(self): + super().setUp() + + def assertDictAlmostEqual(d1, d2, delta=1e-3, typecast_to_numpy=True): + for key, val in d1.items(): + if typecast_to_numpy: + val = val.numpy() + self.assertAlmostEqual(val, d2[key], delta=delta) + + def assertDictAllValuesNotEqual(d1, d2): + for key, val in d1.items(): + self.assertNotEqual(val, d2[key]) + + self.assertDictAlmostEqual = assertDictAlmostEqual + self.assertDictAllValuesNotEqual = assertDictAllValuesNotEqual + + def test_initialization(self): + rouge = RougeN() + result = rouge.result() + + self.assertDictEqual( + result, {"precision": 0.0, "recall": 0.0, "f1_score": 0.0} + ) + + def test_string_input(self): + rouge = RougeN(order=2, use_stemmer=False) + y_true = "the tiny little cat was found under the big funny bed" + y_pred = "the cat was under the bed" + + rouge_val = rouge(y_true, y_pred) + self.assertDictAlmostEqual( + rouge_val, {"precision": 0.4, "recall": 0.2, "f1_score": 0.267} + ) + + def test_string_list_input(self): + rouge = RougeN(order=2, use_stemmer=False) + y_true = [ + "the tiny little cat was found under the big funny bed", + "i really love contributing to KerasNLP", + ] + y_pred = [ + "the cat was under the bed", + "i love contributing to KerasNLP", + ] + + rouge_val = rouge(y_true, y_pred) + self.assertDictAlmostEqual( + rouge_val, {"precision": 0.575, "recall": 0.4, "f1_score": 0.467} + ) + + def test_tensor_input(self): + rouge = RougeN(order=2, use_stemmer=False) + y_true = tf.constant( + [ + "the tiny little cat was found under the big funny bed", + "i really love contributing to KerasNLP", + ] + ) + y_pred = tf.constant( + ["the cat was under the bed", "i love contributing to KerasNLP"] + ) + + rouge_val = rouge(y_true, y_pred) + self.assertDictAlmostEqual( + rouge_val, {"precision": 0.575, "recall": 0.4, "f1_score": 0.467} + ) + + def test_rank_2_input(self): + rouge = RougeN(order=2, use_stemmer=False) + y_true = tf.constant( + [ + ["the tiny little cat was found under the big funny bed"], + ["i really love contributing to KerasNLP"], + ] + ) + y_pred = tf.constant( + [["the cat was under the bed"], ["i love contributing to KerasNLP"]] + ) + + rouge_val = rouge(y_true, y_pred) + self.assertDictAlmostEqual( + rouge_val, {"precision": 0.575, "recall": 0.4, "f1_score": 0.467} + ) + + def test_model_compile(self): + inputs = keras.Input(shape=(), dtype="string") + outputs = tf.strings.lower(inputs) + model = keras.Model(inputs, outputs) + + model.compile(metrics=[RougeN()]) + + x = tf.constant(["HELLO THIS IS FUN"]) + y = tf.constant(["hello this is awesome"]) + + output = model.evaluate(x, y, return_dict=True) + del output["loss"] + self.assertDictAlmostEqual( + output, + {"precision": 0.667, "recall": 0.667, "f1_score": 0.667}, + typecast_to_numpy=False, + ) + + def test_incorrect_order(self): + with self.assertRaises(ValueError): + _ = RougeN(order=10) + + def test_different_order(self): + rouge = RougeN(order=3, use_stemmer=False) + y_true = tf.constant( + [ + "the tiny little cat was found under the big funny bed", + "i really love contributing to KerasNLP", + ] + ) + y_pred = tf.constant( + ["the cat was under the bed", "i love contributing to KerasNLP"] + ) + + rouge_val = rouge(y_true, y_pred) + self.assertDictAlmostEqual( + rouge_val, + {"precision": 0.333, "recall": 0.25, "f1_score": 0.286}, + typecast_to_numpy=False, + ) + + def test_reset_state(self): + rouge = RougeN() + y_true = tf.constant( + ["hey, this is great fun", "i love contributing to KerasNLP"] + ) + y_pred = tf.constant( + [ + "great fun indeed", + "KerasNLP is awesome, i love contributing to it", + ] + ) + + rouge.update_state(y_true, y_pred) + rouge_val = rouge.result() + self.assertDictAllValuesNotEqual( + rouge_val, {"precision": 0.0, "recall": 0.0, "f1_score": 0.0} + ) + + rouge.reset_state() + rouge_val = rouge.result() + self.assertDictEqual( + rouge_val, {"precision": 0.0, "recall": 0.0, "f1_score": 0.0} + ) + + def test_update_state(self): + rouge = RougeN() + y_true_1 = tf.constant( + [ + "the tiny little cat was found under the big funny bed", + "i really love contributing to KerasNLP", + ] + ) + y_pred_1 = tf.constant( + ["the cat was under the bed", "i love contributing to KerasNLP"] + ) + + rouge.update_state(y_true_1, y_pred_1) + rouge_val = rouge.result() + self.assertDictAlmostEqual( + rouge_val, {"precision": 0.575, "recall": 0.4, "f1_score": 0.467} + ) + + y_true_2 = tf.constant(["what is your favourite show"]) + y_pred_2 = tf.constant(["my favourite show is silicon valley"]) + + rouge.update_state(y_true_2, y_pred_2) + rouge_val = rouge.result() + self.assertDictAlmostEqual( + rouge_val, {"precision": 0.45, "recall": 0.35, "f1_score": 0.385} + ) + + def test_merge_state(self): + rouge_1 = RougeN() + rouge_2 = RougeN() + + y_true_1 = tf.constant( + [ + "the tiny little cat was found under the big funny bed", + "i really love contributing to KerasNLP", + ] + ) + y_pred_1 = tf.constant( + ["the cat was under the bed", "i love contributing to KerasNLP"] + ) + + y_true_2 = tf.constant(["what is your favourite show"]) + y_pred_2 = tf.constant(["my favourite show is silicon valley"]) + + y_true_3 = tf.constant(["lorem ipsum dolor sit amet"]) + y_pred_3 = tf.constant(["lorem ipsum is simply dummy text"]) + + rouge_1.update_state(y_true_1, y_pred_1) + rouge_1.update_state(y_true_2, y_pred_2) + rouge_val = rouge_1.result() + self.assertDictAlmostEqual( + rouge_val, {"precision": 0.45, "recall": 0.35, "f1_score": 0.385} + ) + + rouge_2.update_state(y_true_3, y_pred_3) + rouge_val = rouge_2.result() + self.assertDictAlmostEqual( + rouge_val, {"precision": 0.2, "recall": 0.25, "f1_score": 0.222} + ) + + merged_rouge = RougeN() + merged_rouge.merge_state([rouge_1, rouge_2]) + rouge_val = merged_rouge.result() + self.assertDictAlmostEqual( + rouge_val, {"precision": 0.388, "recall": 0.325, "f1_score": 0.344} + ) + + def test_get_config(self): + rouge = RougeN( + order=5, + use_stemmer=True, + dtype=tf.float32, + name="rouge_n_test", + ) + + config = rouge.get_config() + expected_config_subset = { + "order": 5, + "use_stemmer": True, + } + + self.assertEqual(config, {**config, **expected_config_subset}) diff --git a/setup.py b/setup.py index 371a3e08f5..fac287ae7c 100644 --- a/setup.py +++ b/setup.py @@ -48,6 +48,7 @@ "isort", "pytest", "pytest-cov", + "rouge-score", ], "examples": [ "datasets", # For GLUE in BERT example.