Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
62324be
Add rough class for RougeL
abheesht17 Apr 16, 2022
3bc476a
Fix typos
abheesht17 Apr 16, 2022
b09302d
Correct logic
abheesht17 Apr 16, 2022
cadbd01
Add examples
abheesht17 Apr 16, 2022
3e767ff
Small doc-string changes
abheesht17 Apr 16, 2022
b622cfe
Add alpha example
abheesht17 Apr 16, 2022
38a809f
Small doc-string change
abheesht17 Apr 16, 2022
e3bf503
Fix doc-string
abheesht17 Apr 16, 2022
d25403b
Fix f-string
abheesht17 Apr 17, 2022
2f9a35c
Minor doc-string edit
abheesht17 Apr 17, 2022
9b4c1f1
Minor doc-string edit - 2
abheesht17 Apr 17, 2022
c59aa74
Address review comments - I
abheesht17 Apr 20, 2022
d166ab7
Minor change in examples
abheesht17 Apr 20, 2022
a669d21
Merge branch 'keras-team:master' into rouge-l
abheesht17 May 22, 2022
632df5d
Use the rouge_score package
abheesht17 May 23, 2022
7586e00
Fix rouge_score import
abheesht17 May 23, 2022
893aab9
Add rouge-score to test deps list
abheesht17 May 24, 2022
ccf33d4
Address review comments - II
abheesht17 May 28, 2022
748df81
Address review comments - III
abheesht17 Jun 3, 2022
a793d3d
Fix model.compile error in doc-string
abheesht17 Jun 3, 2022
b8dae75
Rename rouge.py to rouge_base.py
abheesht17 Jun 5, 2022
8050086
Address review comments - IV
abheesht17 Jun 7, 2022
da44d22
Address review comments - IV
abheesht17 Jun 7, 2022
f8c05aa
Return dict from ROUGE
abheesht17 Jun 10, 2022
f4df42b
Fix doc-strings
abheesht17 Jun 10, 2022
723d8e7
Truncate doc-string example output
abheesht17 Jun 10, 2022
b0fe8bc
Remove ROUGE-LSum from doc-string
abheesht17 Jun 16, 2022
7250617
Small doc-string changes
abheesht17 Jun 16, 2022
3c5b3dc
Add TODO comment for dict return bug
abheesht17 Jun 17, 2022
4fa518a
Address review comments - V
abheesht17 Jun 17, 2022
14e851f
Fix doc-string
abheesht17 Jun 17, 2022
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions keras_nlp/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@
# limitations under the License.

from keras_nlp.metrics.perplexity import Perplexity
from keras_nlp.metrics.rouge_l import RougeL
from keras_nlp.metrics.rouge_n import RougeN
223 changes: 223 additions & 0 deletions keras_nlp/metrics/rouge_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ROUGE metric implementation based on `keras.metrics.Metric`."""


import types

import tensorflow as tf
from tensorflow import keras

from keras_nlp.utils.tensor_utils import tensor_to_string_list

try:
import rouge_score
from rouge_score import rouge_scorer
except ImportError:
rouge_score = None


class RougeBase(keras.metrics.Metric):
"""ROUGE metric.

This class implements two variants of the ROUGE metric - ROUGE-N,
and ROUGE-L.

Note on input shapes:
For `y_true` and `y_pred`, this class supports scalar values and batch
inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`.

Args:
variant: string. One of "rougeN", "rougeL". Defaults to
"rouge2". For "rougeN", N lies in the range [1, 9].
use_stemmer: bool. Whether Porter Stemmer should be used to strip word
suffixes to improve matching. Defaults to False.
dtype: string or tf.dtypes.Dtype. Precision of metric computation. If
not specified, it defaults to tf.float32.
name: string. Name of the metric instance.
**kwargs: Other keyword arguments.
"""

def __init__(
    self,
    variant="rouge2",
    use_stemmer=False,
    dtype=None,
    name="rouge",
    **kwargs,
):
    """Validates the configuration and creates the metric state.

    Args:
        variant: string. One of "rougeN" (N in [1, 9]) or "rougeL".
        use_stemmer: bool. Forwarded to `rouge_scorer.RougeScorer`.
        dtype: string or tf.dtypes.Dtype. Forwarded to the base metric; the
            resolved `self.dtype` must be a floating point type.
        name: string. Name of the metric instance.
        **kwargs: Other keyword arguments forwarded to the base metric.

    Raises:
        ImportError: If the optional `rouge_score` package is not installed.
        ValueError: If the resolved dtype is not floating point, or if
            `variant` is not one of the supported ROUGE variants.
    """
    super().__init__(name=name, dtype=dtype, **kwargs)

    # `rouge_score` is an optional dependency; the module-level import
    # leaves it as `None` when the package is missing.
    if rouge_score is None:
        raise ImportError(
            "ROUGE metric requires the `rouge_score` package. "
            "Please install it with `pip install rouge-score`."
        )

    resolved_dtype = tf.as_dtype(self.dtype)
    if not resolved_dtype.is_floating:
        raise ValueError(
            "`dtype` must be a floating point type. "
            f"Received: dtype={dtype}"
        )

    # rouge1 ... rouge9, followed by rougeL.
    supported_variants = [f"rouge{order}" for order in range(1, 10)]
    supported_variants.append("rougeL")
    if variant not in supported_variants:
        raise ValueError(
            "Invalid variant of ROUGE. Should be one of: rougeN, rougeL, "
            "with N ranging from 1 to 9. Received: "
            f"variant={variant}"
        )

    self.variant = variant
    self.use_stemmer = use_stemmer

    # To-do: Add split_summaries and tokenizer options after the maintainers
    # of rouge_scorer have released a new version.
    self._rouge_scorer = rouge_scorer.RougeScorer(
        rouge_types=[self.variant],
        use_stemmer=use_stemmer,
    )

    def zero_accumulator(weight_name):
        # Every piece of state is a zero-initialised weight of the
        # metric's dtype.
        return self.add_weight(
            name=weight_name,
            initializer="zeros",
            dtype=self.dtype,
        )

    # Running sums of the per-sample scores.
    self._rouge_precision = zero_accumulator("rouge_precision")
    self._rouge_recall = zero_accumulator("rouge_recall")
    self._rouge_f1_score = zero_accumulator("rouge_f1_score")

    # Number of samples accumulated so far (kept in the metric's dtype).
    self._number_of_samples = zero_accumulator("number_of_samples")

def __new__(cls, *args, **kwargs):
# Temporary workaround for Keras bug with dictionary return types.
# Wraps `result()` with a python dictionary that also supports variable
# assignment. We have to do this with __new__ because the base metric
# class wraps the `results()` method.
# TODO: Remove this snippet of code once the Keras bug is fixed.
obj = super().__new__(cls)

class MetricDict(dict):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this class necessary? It seems this is just an alias to dict.

Also let's create a TODO here for future cleanup, this code is hard to maintain.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, the reason for defining a class is so that we can do object.var_name type assignments.
If we use a dictionary, this error crops up:

    def replica_local_fn(*args, **kwargs):
      """Updates the state of the metric in a replica-local context."""
      if any(
          isinstance(arg, keras_tensor.KerasTensor)
          for arg in tf.nest.flatten((args, kwargs))):
        update_op = None
      else:
        update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
      update_ops = []
      if update_op is not None:
        update_ops.append(update_op)
      with tf.control_dependencies(update_ops):
        result_t = self.result()  # pylint: disable=not-callable
    
        # We are adding the metric object as metadata on the result tensor.
        # This is required when we want to use a metric with `add_metric` API on
        # a Model/Layer in graph mode. This metric instance will later be used
        # to reset variable state after each epoch of training.
        # Example:
        #   model = Model()
        #   mean = Mean()
        #   model.add_metric(mean(values), name='mean')
>       result_t._metric_obj = self  # pylint: disable=protected-access
E       AttributeError: 'dict' object has no attribute '_metric_obj'

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, definitely the plan would be to remove this code after 2.10 is out!

"""A dictionary that supports variable assignment."""

pass

def wrap_result(result_fn):
return tf.__internal__.decorator.make_decorator(
result_fn, lambda obj, *args: MetricDict(result_fn(*args))
)

obj.result = types.MethodType(wrap_result(obj.result), obj)
return obj

def update_state(self, y_true, y_pred, sample_weight=None):
    """Accumulates ROUGE scores for a batch of reference/hypothesis pairs.

    Args:
        y_true: Reference text(s). A scalar string, or a tensor/list of
            strings of shape `(batch_size,)` or `(batch_size, 1)`.
        y_pred: Hypothesis text(s), with the same accepted shapes as
            `y_true`.
        sample_weight: Accepted for API compatibility; it is not used by
            this method.

    Raises:
        ValueError: If an input has rank greater than 2, or has rank 2 with
            a second dimension different from 1.
    """

    def to_rank_one(inputs, tensor_name):
        # Normalise any of the three supported shapes to `(batch_size,)`.
        if not isinstance(inputs, tf.Tensor):
            inputs = tf.convert_to_tensor(inputs)

        rank = inputs.shape.rank
        if rank == 0:
            # Scalar input: treat it as a batch of one.
            return inputs[tf.newaxis]
        if rank == 1:
            return inputs
        if rank != 2:
            raise ValueError(
                f"{tensor_name} must be of rank 0 (scalar input), 1 or 2. "
                f"Found rank: {inputs.shape.rank}"
            )
        if inputs.shape[1] != 1:
            raise ValueError(
                f"{tensor_name} must be of shape `[batch_size, 1]`. "
                f"Found shape: {inputs.shape}"
            )
        return tf.squeeze(inputs, axis=1)

    y_true = to_rank_one(y_true, "y_true")
    y_pred = to_rank_one(y_pred, "y_pred")

    num_samples = tf.shape(y_true)[0]

    def score_pair(reference, hypothesis):
        # Runs through `tf.py_function` because the scorer works on Python
        # strings rather than tensors.
        reference = tensor_to_string_list(reference)
        hypothesis = tensor_to_string_list(hypothesis)
        all_scores = self._rouge_scorer.score(reference, hypothesis)
        variant_score = all_scores[self.variant]
        packed = tf.constant(
            [
                variant_score.precision,
                variant_score.recall,
                variant_score.fmeasure,
            ]
        )
        return tf.cast(packed, dtype=self.dtype)

    for sample_idx in range(num_samples):
        # `sample_scores` holds [precision, recall, f-measure].
        sample_scores = tf.py_function(
            func=score_pair,
            inp=[y_true[sample_idx], y_pred[sample_idx]],
            Tout=self.dtype,
        )
        self._rouge_precision.assign_add(sample_scores[0])
        self._rouge_recall.assign_add(sample_scores[1])
        self._rouge_f1_score.assign_add(sample_scores[2])

    self._number_of_samples.assign_add(
        tf.cast(num_samples, dtype=self.dtype)
    )

def result(self):
    """Returns the mean precision, recall and F1 score accumulated so far.

    Returns:
        A dictionary with the keys "precision", "recall" and "f1_score".
        Each value is the corresponding running sum divided by the number
        of samples seen, as a tensor. When no samples have been accumulated
        every value is zero.
    """
    sample_count = self._number_of_samples

    # `divide_no_nan` returns 0 where the denominator is 0, so the empty
    # state needs no Python-level branch on a variable comparison. Every
    # entry is therefore always a tensor, instead of Python floats in the
    # empty case and tensors otherwise.
    return {
        "precision": tf.math.divide_no_nan(
            self._rouge_precision, sample_count
        ),
        "recall": tf.math.divide_no_nan(self._rouge_recall, sample_count),
        "f1_score": tf.math.divide_no_nan(
            self._rouge_f1_score, sample_count
        ),
    }

def reset_state(self):
    """Resets the accumulated score sums and the sample count to zero."""
    accumulators = (
        self._rouge_precision,
        self._rouge_recall,
        self._rouge_f1_score,
        self._number_of_samples,
    )
    for accumulator in accumulators:
        accumulator.assign(0.0)

def get_config(self):
    """Returns the base metric config extended with the ROUGE options.

    Returns:
        The dictionary produced by the parent `get_config()`, with the
        "variant" and "use_stemmer" entries added.
    """
    config = super().get_config()
    config["variant"] = self.variant
    config["use_stemmer"] = self.use_stemmer
    return config
129 changes: 129 additions & 0 deletions keras_nlp/metrics/rouge_l.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""ROUGE-L metric implementation based on `keras.metrics.Metric`."""


from keras_nlp.metrics.rouge_base import RougeBase


class RougeL(RougeBase):
"""ROUGE-L metric.

This class implements the ROUGE-L variant of the ROUGE metric. The ROUGE-L
metric is traditionally used for evaluating summarisation systems.
Succinctly put, ROUGE-L is a score based on the length of the longest
common subsequence present in the reference text and the hypothesis text.

Note on input shapes:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would just comment on the shapes here (not the types). So just say supports scalar and batch inputs of shape (), (batch_size,) and (batch_size, 1).

For `y_true` and `y_pred`, this class supports scalar values and batch
inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`.

Args:
use_stemmer: bool. Whether Porter Stemmer should be used to strip word
suffixes to improve matching. Defaults to False.
dtype: string or tf.dtypes.Dtype. Precision of metric computation. If
not specified, it defaults to tf.float32.
name: string. Name of the metric instance.
**kwargs: Other keyword arguments.

Examples:

1. Various Input Types.
1.1. Python string.
>>> rouge_l = keras_nlp.metrics.RougeL()
>>> y_true = "the tiny little cat was found under the big funny bed"
>>> y_pred = "the cat was under the bed"
>>> rouge_l(y_true, y_pred)["f1_score"]
<tf.Tensor: shape=(), dtype=float32, numpy=0.7058824>

1.2. rank 1 inputs.
a. Python list.
>>> rouge_l = keras_nlp.metrics.RougeL()
>>> y_true = [
... "the tiny little cat was found under the big funny bed",
... "i really love contributing to KerasNLP",
... ]
>>> y_pred = [
... "the cat was under the bed",
... "i love contributing to KerasNLP",
... ]
>>> rouge_l(y_true, y_pred)["f1_score"]
<tf.Tensor: shape=(), dtype=float32, numpy=0.80748665>

b. Tensor
>>> rouge_l = keras_nlp.metrics.RougeL()
>>> y_true = tf.constant(
... [
... "the tiny little cat was found under the big funny bed",
... "i really love contributing to KerasNLP",
... ]
... )
>>> y_pred = tf.constant(
... [
... "the cat was under the bed",
... "i love contributing to KerasNLP",
... ]
... )
>>> rouge_l(y_true, y_pred)["f1_score"]
<tf.Tensor: shape=(), dtype=float32, numpy=0.80748665>

1.3. rank 2 inputs.
>>> rouge_l = keras_nlp.metrics.RougeL()
>>> y_true = tf.constant(
... [
... ["the tiny little cat was found under the big funny bed"],
... ["i really love contributing to KerasNLP"],
... ]
... )
>>> y_pred = tf.constant(
... [
... ["the cat was under the bed"],
... ["i love contributing to KerasNLP"],
... ]
... )
>>> rouge_l(y_true, y_pred)["f1_score"]
<tf.Tensor: shape=(), dtype=float32, numpy=0.80748665>

2. Pass the metric to `model.compile()`.
>>> inputs = keras.Input(shape=(), dtype='string')
>>> outputs = tf.strings.lower(inputs)
>>> model = keras.Model(inputs, outputs)
>>> model.compile(metrics=[keras_nlp.metrics.RougeL()])
>>> x = tf.constant(["HELLO THIS IS FUN"])
>>> y = tf.constant(["hello this is awesome"])
>>> metric_dict = model.evaluate(x, y, return_dict=True)
>>> metric_dict["f1_score"]
0.75
"""
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add some docstring examples! I think the >>> style with actual output, would be useful in this case.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I forgot to do this. I've added examples in the new commit!


def __init__(
self,
use_stemmer=False,
dtype=None,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In docstring it defaults to float32, which mismatches the default value None. Please fix it.

name="rouge-l",
**kwargs,
):
super().__init__(
variant="rougeL",
use_stemmer=use_stemmer,
dtype=dtype,
name=name,
**kwargs,
)

def get_config(self):
    """Returns the parent config without the "variant" entry.

    `RougeL.__init__` always passes `variant="rougeL"` to the parent and
    does not accept a `variant` argument, so the key is dropped from the
    config.
    """
    base_config = super().get_config()
    # Like `del`, `pop` without a default raises `KeyError` if the key is
    # missing.
    base_config.pop("variant")
    return base_config
Loading