-
Notifications
You must be signed in to change notification settings - Fork 301
Add ROUGE Metric #122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add ROUGE Metric #122
Changes from all commits
62324be
3bc476a
b09302d
cadbd01
3e767ff
b622cfe
38a809f
e3bf503
d25403b
2f9a35c
9b4c1f1
c59aa74
d166ab7
a669d21
632df5d
7586e00
893aab9
ccf33d4
748df81
a793d3d
b8dae75
8050086
da44d22
f8c05aa
f4df42b
723d8e7
b0fe8bc
7250617
3c5b3dc
4fa518a
14e851f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,223 @@ | ||
# Copyright 2022 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""ROUGE metric implementation based on `keras.metrics.Metric`.""" | ||
|
||
|
||
import types | ||
|
||
import tensorflow as tf | ||
from tensorflow import keras | ||
|
||
from keras_nlp.utils.tensor_utils import tensor_to_string_list | ||
|
||
try: | ||
import rouge_score | ||
from rouge_score import rouge_scorer | ||
except ImportError: | ||
rouge_score = None | ||
|
||
|
||
class RougeBase(keras.metrics.Metric): | ||
"""ROUGE metric. | ||
|
||
This class implements two variants of the ROUGE metric - ROUGE-N, | ||
and ROUGE-L. | ||
|
||
Note on input shapes: | ||
For `y_true` and `y_pred`, this class supports scalar values and batch | ||
inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`. | ||
|
||
Args: | ||
variant: string. One of "rougeN", "rougeL". Defaults to | ||
"rouge2". For "rougeN", N lies in the range [1, 9]. | ||
use_stemmer: bool. Whether Porter Stemmer should be used to strip word | ||
suffixes to improve matching. Defaults to False. | ||
dtype: string or tf.dtypes.Dtype. Precision of metric computation. If | ||
not specified, it defaults to tf.float32. | ||
name: string. Name of the metric instance. | ||
**kwargs: Other keyword arguments. | ||
""" | ||
|
||
def __init__( | ||
self, | ||
variant="rouge2", | ||
use_stemmer=False, | ||
dtype=None, | ||
name="rouge", | ||
**kwargs, | ||
): | ||
super().__init__(name=name, dtype=dtype, **kwargs) | ||
|
||
if rouge_score is None: | ||
raise ImportError( | ||
"ROUGE metric requires the `rouge_score` package. " | ||
"Please install it with `pip install rouge-score`." | ||
) | ||
|
||
if not tf.as_dtype(self.dtype).is_floating: | ||
raise ValueError( | ||
"`dtype` must be a floating point type. " | ||
f"Received: dtype={dtype}" | ||
) | ||
|
||
if variant not in tuple( | ||
("rouge" + str(order) for order in range(1, 10)) | ||
) + ("rougeL",): | ||
raise ValueError( | ||
"Invalid variant of ROUGE. Should be one of: rougeN, rougeL, " | ||
"with N ranging from 1 to 9. Received: " | ||
f"variant={variant}" | ||
) | ||
|
||
self.variant = variant | ||
self.use_stemmer = use_stemmer | ||
|
||
# To-do: Add split_summaries and tokenizer options after the maintainers | ||
# of rouge_scorer have released a new version. | ||
self._rouge_scorer = rouge_scorer.RougeScorer( | ||
rouge_types=[self.variant], | ||
use_stemmer=use_stemmer, | ||
) | ||
|
||
self._rouge_precision = self.add_weight( | ||
name="rouge_precision", | ||
initializer="zeros", | ||
dtype=self.dtype, | ||
) | ||
self._rouge_recall = self.add_weight( | ||
name="rouge_recall", | ||
initializer="zeros", | ||
dtype=self.dtype, | ||
) | ||
self._rouge_f1_score = self.add_weight( | ||
name="rouge_f1_score", | ||
initializer="zeros", | ||
dtype=self.dtype, | ||
) | ||
|
||
self._number_of_samples = self.add_weight( | ||
name="number_of_samples", initializer="zeros", dtype=self.dtype | ||
) | ||
|
||
def __new__(cls, *args, **kwargs): | ||
# Temporary workaround for Keras bug with dictionary return types. | ||
# Wraps `result()` with a python dictionary that also supports variable | ||
# assignment. We have to do this with __new__ because the base metric | ||
# class wraps the `results()` method. | ||
# TODO: Remove this snippet of code once the Keras bug is fixed. | ||
obj = super().__new__(cls) | ||
|
||
class MetricDict(dict): | ||
"""A dictionary that supports variable assignment.""" | ||
|
||
pass | ||
|
||
def wrap_result(result_fn): | ||
return tf.__internal__.decorator.make_decorator( | ||
result_fn, lambda obj, *args: MetricDict(result_fn(*args)) | ||
) | ||
|
||
obj.result = types.MethodType(wrap_result(obj.result), obj) | ||
return obj | ||
|
||
def update_state(self, y_true, y_pred, sample_weight=None): | ||
# Three possible shapes for y_true and y_pred: Python string, | ||
# [batch_size] and [batch_size, 1]. In the latter two cases, we have | ||
# strings in the tensor/list. | ||
|
||
def validate_and_fix_rank(inputs, tensor_name): | ||
if not isinstance(inputs, tf.Tensor): | ||
inputs = tf.convert_to_tensor(inputs) | ||
|
||
if inputs.shape.rank == 0: | ||
return inputs[tf.newaxis] | ||
elif inputs.shape.rank == 1: | ||
return inputs | ||
elif inputs.shape.rank == 2: | ||
if inputs.shape[1] != 1: | ||
raise ValueError( | ||
f"{tensor_name} must be of shape `[batch_size, 1]`. " | ||
f"Found shape: {inputs.shape}" | ||
) | ||
else: | ||
return tf.squeeze(inputs, axis=1) | ||
else: | ||
raise ValueError( | ||
f"{tensor_name} must be of rank 0 (scalar input), 1 or 2. " | ||
f"Found rank: {inputs.shape.rank}" | ||
) | ||
|
||
y_true = validate_and_fix_rank(y_true, "y_true") | ||
y_pred = validate_and_fix_rank(y_pred, "y_pred") | ||
|
||
batch_size = tf.shape(y_true)[0] | ||
|
||
def calculate_rouge_score(reference, hypothesis): | ||
reference = tensor_to_string_list(reference) | ||
hypothesis = tensor_to_string_list(hypothesis) | ||
score = self._rouge_scorer.score(reference, hypothesis)[ | ||
self.variant | ||
] | ||
return tf.cast( | ||
tf.constant([score.precision, score.recall, score.fmeasure]), | ||
dtype=self.dtype, | ||
) | ||
|
||
for batch_idx in range(batch_size): | ||
score = tf.py_function( | ||
func=calculate_rouge_score, | ||
inp=[y_true[batch_idx], y_pred[batch_idx]], | ||
Tout=self.dtype, | ||
) | ||
self._rouge_precision.assign_add(score[0]) | ||
self._rouge_recall.assign_add(score[1]) | ||
self._rouge_f1_score.assign_add(score[2]) | ||
|
||
self._number_of_samples.assign_add( | ||
tf.cast(batch_size, dtype=self.dtype) | ||
) | ||
|
||
def result(self): | ||
if self._number_of_samples == 0: | ||
return { | ||
"precision": 0.0, | ||
"recall": 0.0, | ||
"f1_score": 0.0, | ||
} | ||
|
||
rouge_precision = self._rouge_precision / self._number_of_samples | ||
rouge_recall = self._rouge_recall / self._number_of_samples | ||
rouge_f1_score = self._rouge_f1_score / self._number_of_samples | ||
return { | ||
"precision": rouge_precision, | ||
"recall": rouge_recall, | ||
"f1_score": rouge_f1_score, | ||
} | ||
|
||
def reset_state(self): | ||
self._rouge_precision.assign(0.0) | ||
self._rouge_recall.assign(0.0) | ||
self._rouge_f1_score.assign(0.0) | ||
self._number_of_samples.assign(0.0) | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
config.update( | ||
{ | ||
"variant": self.variant, | ||
"use_stemmer": self.use_stemmer, | ||
} | ||
) | ||
return config |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,129 @@ | ||
# Copyright 2022 The KerasNLP Authors | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# https://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""ROUGE-L metric implementation based on `keras.metrics.Metric`.""" | ||
|
||
|
||
from keras_nlp.metrics.rouge_base import RougeBase | ||
|
||
|
||
class RougeL(RougeBase): | ||
"""ROUGE-L metric. | ||
|
||
This class implements the ROUGE-L variant of the ROUGE metric. The ROUGE-L | ||
metric is traditionally used for evaluating summarisation systems. | ||
Succinctly put, ROUGE-L is a score based on the length of the longest | ||
common subsequence present in the reference text and the hypothesis text. | ||
|
||
Note on input shapes: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would just comment on the shapes here (not the types). So just say supports scalar and batch inputs of shape |
||
For `y_true` and `y_pred`, this class supports scalar values and batch | ||
inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`. | ||
|
||
Args: | ||
use_stemmer: bool. Whether Porter Stemmer should be used to strip word | ||
suffixes to improve matching. Defaults to False. | ||
dtype: string or tf.dtypes.Dtype. Precision of metric computation. If | ||
not specified, it defaults to tf.float32. | ||
name: string. Name of the metric instance. | ||
**kwargs: Other keyword arguments. | ||
|
||
Examples: | ||
|
||
1. Various Input Types. | ||
1.1. Python string. | ||
>>> rouge_l = keras_nlp.metrics.RougeL() | ||
>>> y_true = "the tiny little cat was found under the big funny bed" | ||
>>> y_pred = "the cat was under the bed" | ||
>>> rouge_l(y_true, y_pred)["f1_score"] | ||
<tf.Tensor: shape=(), dtype=float32, numpy=0.7058824> | ||
|
||
1.2. rank 1 inputs. | ||
a. Python list. | ||
>>> rouge_l = keras_nlp.metrics.RougeL() | ||
>>> y_true = [ | ||
... "the tiny little cat was found under the big funny bed", | ||
... "i really love contributing to KerasNLP", | ||
... ] | ||
>>> y_pred = [ | ||
... "the cat was under the bed", | ||
... "i love contributing to KerasNLP", | ||
... ] | ||
>>> rouge_l(y_true, y_pred)["f1_score"] | ||
<tf.Tensor: shape=(), dtype=float32, numpy=0.80748665> | ||
|
||
b. Tensor | ||
>>> rouge_l = keras_nlp.metrics.RougeL() | ||
>>> y_true = tf.constant( | ||
... [ | ||
... "the tiny little cat was found under the big funny bed", | ||
... "i really love contributing to KerasNLP", | ||
... ] | ||
... ) | ||
>>> y_pred = tf.constant( | ||
... [ | ||
... "the cat was under the bed", | ||
... "i love contributing to KerasNLP", | ||
... ] | ||
... ) | ||
>>> rouge_l(y_true, y_pred)["f1_score"] | ||
<tf.Tensor: shape=(), dtype=float32, numpy=0.80748665> | ||
|
||
1.3. rank 2 inputs. | ||
>>> rouge_l = keras_nlp.metrics.RougeL() | ||
>>> y_true = tf.constant( | ||
... [ | ||
... ["the tiny little cat was found under the big funny bed"], | ||
... ["i really love contributing to KerasNLP"], | ||
... ] | ||
... ) | ||
>>> y_pred = tf.constant( | ||
... [ | ||
... ["the cat was under the bed"], | ||
... ["i love contributing to KerasNLP"], | ||
... ] | ||
... ) | ||
>>> rouge_l(y_true, y_pred)["f1_score"] | ||
<tf.Tensor: shape=(), dtype=float32, numpy=0.80748665> | ||
|
||
3. Pass the metric to `model.compile()`. | ||
>>> inputs = keras.Input(shape=(), dtype='string') | ||
>>> outputs = tf.strings.lower(inputs) | ||
>>> model = keras.Model(inputs, outputs) | ||
>>> model.compile(metrics=[keras_nlp.metrics.RougeL()]) | ||
>>> x = tf.constant(["HELLO THIS IS FUN"]) | ||
>>> y = tf.constant(["hello this is awesome"]) | ||
>>> metric_dict = model.evaluate(x, y, return_dict=True) | ||
>>> metric_dict["f1_score"] | ||
0.75 | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add some docstring examples! I think the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, I forgot to do this. I've added examples in the new commit! |
||
|
||
def __init__( | ||
self, | ||
use_stemmer=False, | ||
dtype=None, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In docstring it defaults to float32, which mismatches the default value |
||
name="rouge-l", | ||
**kwargs, | ||
): | ||
super().__init__( | ||
variant="rougeL", | ||
use_stemmer=use_stemmer, | ||
dtype=dtype, | ||
name=name, | ||
**kwargs, | ||
) | ||
|
||
def get_config(self): | ||
config = super().get_config() | ||
del config["variant"] | ||
return config |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this class necessary? It seems this is juts an alias to
dict
.Also let's create a TODO here for future cleanup, this code is hard to maintain.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actually, the reason for defining a class is so that we can do
object.var_name
type assignments.If we use a dictionary, this error crops up:
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, definitely the plan would be to remove this code after 2.10 is out!