Add BLEU Score #222

Merged · 18 commits · Jul 11, 2022

Changes from 13 commits
1 change: 1 addition & 0 deletions keras_nlp/metrics/__init__.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from keras_nlp.metrics.bleu import Bleu
from keras_nlp.metrics.edit_distance import EditDistance
from keras_nlp.metrics.perplexity import Perplexity
from keras_nlp.metrics.rouge_l import RougeL
385 changes: 385 additions & 0 deletions keras_nlp/metrics/bleu.py
@@ -0,0 +1,385 @@
# Copyright 2022 The KerasNLP Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BLEU metric implementation."""

import collections
import math

import tensorflow as tf
from tensorflow import keras

from keras_nlp.utils.tensor_utils import tensor_to_list
from keras_nlp.utils.tensor_utils import tensor_to_string_list

REPLACE_SUBSTRINGS = [
    ("<skipped>", ""),
    ("-\n", ""),
    ("\n", " "),
    ("&quot;", '"'),
    ("&amp;", "&"),
    ("&lt;", "<"),
    ("&gt;", ">"),
]


REGEX_PATTERNS = [
    # language-dependent part (assuming Western languages)
    (r"([\{-\~\[-\` -\&\(-\+\:-\@\/])", r" \1 "),
    # tokenize period and comma unless preceded by a digit
    (r"([^0-9])([\.,])", r"\1 \2 "),
    # tokenize period and comma unless followed by a digit
    (r"([\.,])([^0-9])", r" \1 \2"),
    # tokenize dash when preceded by a digit
    (r"([0-9])(-)", r"\1 \2 "),
    # If the last character is "." or ",", add a space.
    (r"[\.,]$", r" \0 "),
    # one space only between words
    (r"\s+", r" "),
]
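
To illustrate what these patterns do, here is a minimal pure-Python sketch (not part of the diff; the sample sentence is invented, and Python's `re` writes RE2's whole-match reference `\0` as `\g<0>`):

import re

# The same patterns as REGEX_PATTERNS above, with \0 spelled \g<0>.
patterns = [
    (r"([\{-\~\[-\` -\&\(-\+\:-\@\/])", r" \1 "),
    (r"([^0-9])([\.,])", r"\1 \2 "),
    (r"([\.,])([^0-9])", r" \1 \2"),
    (r"([0-9])(-)", r"\1 \2 "),
    (r"[\.,]$", r" \g<0> "),
    (r"\s+", r" "),
]

text = 'He said, "It costs 1,000.50 dollars, not 2-3!"'
for pattern, replacement in patterns:
    text = re.sub(pattern, replacement, text)
print(text.split())
# ['He', 'said', ',', '"', 'It', 'costs', '1,000.50',
#  'dollars', ',', 'not', '2', '-', '3', '!', '"']

Note that punctuation becomes its own token while numbers like "1,000.50" stay intact, which is the tokenizer_13a behaviour described below.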


class Bleu(keras.metrics.Metric):
"""BLEU metric.

This class implements the BLEU metric. BLEU is generally used to evaluate
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should probably mention more prominently that this will replicate sacrebleu by default, but can be used with other tokenizers e.g. for other languages.

machine translation systems. by default, this implementation replicates
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: capital case By default

SacreBLEU, but user-defined tokenizers can be passed to deal with other
languages.

For BLEU score, we count the number of matching n-grams in the candidate
translation and the reference text. We find the "clipped count" of matching
mattdangerw marked this conversation as resolved.
Show resolved Hide resolved
n-grams so as to not give a high score to a (reference, prediction) pair
with redundant, repeated tokens. Secondly, BLEU score tends to reward
shorter predictions more, which is why a brevity penalty is applied to
penalise short predictions.

    Note on input shapes:
    For `y_true` and `y_pred`, this class supports the following shapes:
    If `y_pred` is a scalar value, `y_true` has to be a 1D dense tensor.
    For batched inputs, if `y_pred` is a 1D dense tensor, `y_true` has to
    be a dense/ragged tensor with shape `(batch_size, None)`.

[Review] Contributor: This requirement on shape match is a little weird.
Suppose both y_pred and y_true are scalar, can we just convert y_true for
them?

[Review] Collaborator (author): Hmm, the idea right now is that one sample
can have multiple references. So, basically, a translation can have multiple
reference sentences. That's why the rank of y_true = rank of y_pred + 1.

    Args:
        tokenizer: callable. A function that takes a string
            `tf.RaggedTensor` (of any shape), and tokenizes the strings in
            the tensor. This function should use TensorFlow graph ops. If
            the tokenizer is not specified, the default tokenizer is used.
            The default tokenizer replicates the behaviour of SacreBLEU's
            `"tokenizer_13a"` tokenizer
            (https://github.com/mjpost/sacrebleu/blob/v2.1.0/sacrebleu/tokenizers/tokenizer_13a.py).
        max_order: int. The maximum n-gram order to use. For example, if
            `max_order` is set to 3, unigrams, bigrams, and trigrams will
            be considered. Defaults to 4.
        smooth: bool. Whether to apply Lin et al. 2004 smoothing to the
            BLEU score. Adds 1 to the matched n-gram count (i.e.,
            numerator) and 1 to the total n-gram count (i.e., denominator)
            for every order while calculating precision. Defaults to False.
        dtype: string or tf.dtypes.Dtype. Precision of metric computation.
            If not specified, it defaults to tf.float32.
        name: string. Name of the metric instance.
        **kwargs: Other keyword arguments.

    References:
        - [Papineni et al., 2002](https://aclanthology.org/P02-1040/)
        - [SacreBLEU](https://github.com/mjpost/sacrebleu)
        - [Lin et al., 2004](https://aclanthology.org/P04-1077/)
    """

[Review] Member (on `tokenizer`): What happens if you pass a tokenizer layer
here, will that work? Say byte tokenizer for simplicity.

[Review] Collaborator (author): Hmm, it won't work with a byte tokeniser
because we use tensor_to_string_list in the code. Do you want me to change
that?

[Review] Member: I think we should either support our tokenizers or rename
this argument to something else. Tokenizer means something specific in our
library now; if we use that name but don't support our tokenizer class, that
is a bad look.

[Review] Collaborator (author): We do support our tokenisers. I've added a
unit test here: `def test_custom_tokenizer(self):`

[Review] Contributor (on "should use TensorFlow graph ops"): Is it
necessary? If people are not interested in using model.evaluate(), can they
just run it in pure eager mode?

[Review] Collaborator (author): True, but we call the tokeniser after
converting the inputs to tensors. So, we have to use TF ops here such as
tf.strings.regex_replace().

[Review] Member (on `smooth`): Can we describe this better? "Lin et al.
2004" with a period in the middle of the docstring does not read very well.
Also please add to the reference section.

    def __init__(
        self,
        tokenizer=None,
        max_order=4,
        smooth=False,
        dtype=None,
        name="bleu",
        **kwargs,
    ):
        super().__init__(name=name, dtype=dtype, **kwargs)

        if not tf.as_dtype(self.dtype).is_floating:
            raise ValueError(
                "`dtype` must be a floating point type. "
                f"Received: dtype={dtype}"
            )

        def default_tokenizer(inputs):
            """
            Default tokenizer. Replicates the behaviour of SacreBLEU's
            default tokenizer, namely, `tokenizer_13a`.
            """
            for pattern, replacement in REPLACE_SUBSTRINGS + REGEX_PATTERNS:
                inputs = tf.strings.regex_replace(
                    input=inputs,
                    pattern=pattern,
                    rewrite=replacement,
                    replace_global=True,
                    name=None,
                )
            inputs = tf.strings.split(inputs)
            return inputs

[Review] Member: There are a lot of nested functions that are quite long all
over. Not very readable. Can we either pull them out of the class entirely,
or make them class methods?

[Review] Collaborator (author): Made a few of them class methods.

[Review] Member: Maybe make this one a private method called _tokenizer that
does:

    if self.tokenizer:
        return self.tokenizer(x)

    for pattern, replacement...

That way saving self.tokenizer in the config would work as expected.
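
For concreteness, the reviewer's `_tokenizer` suggestion might be fleshed out roughly as follows (a hypothetical sketch of the suggestion, not what this revision of the diff does):

    def _tokenizer(self, inputs):
        """Apply the user tokenizer if given, else the 13a-style default."""
        if self.tokenizer:
            return self.tokenizer(inputs)
        for pattern, replacement in REPLACE_SUBSTRINGS + REGEX_PATTERNS:
            inputs = tf.strings.regex_replace(
                input=inputs,
                pattern=pattern,
                rewrite=replacement,
                replace_global=True,
            )
        return tf.strings.split(inputs)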

        if tokenizer is None:
            self.tokenizer = default_tokenizer
        else:
            self.tokenizer = tokenizer
        self.max_order = max_order
        self.smooth = smooth

        self._matches = self.add_weight(
            shape=(self.max_order,),
            name="bleu_matches",
            initializer="zeros",
            dtype=self.dtype,
        )
        self._possible_matches = self.add_weight(
            shape=(self.max_order,),
            name="bleu_possible_matches",
            initializer="zeros",
            dtype=self.dtype,
        )
        self._translation_length = self.add_weight(
            name="bleu_translation_length",
            initializer="zeros",
            dtype=self.dtype,
        )
        self._reference_length = self.add_weight(
            name="bleu_reference_length",
            initializer="zeros",
            dtype=self.dtype,
        )
        self._bleu = self.add_weight(
            name="bleu",
            initializer="zeros",
            dtype=self.dtype,
        )

    def _get_ngrams(self, segment, max_order):
        """Extracts all n-grams up to a given maximum order from an input
        segment.

        Uses Python ops. Inspired from
        https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py.

        Args:
            segment: list. Tokenized text segment from which n-grams will
                be extracted.
            max_order: int. Maximum length in tokens of the n-grams
                returned by this method.
        """
        ngram_counts = collections.Counter()
        for order in range(1, max_order + 1):
            for i in range(0, len(segment) - order + 1):
                ngram = tuple(segment[i : i + order])
                ngram_counts[ngram] += 1
        return ngram_counts

[Review] Contributor: nit: up to

[Review] Member: remember to fix this!

[Review] Contributor (on `segment`): Is this a string or tensor of split
tokens?

[Review] Collaborator (author): Yeah, this should be list.
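
To see how these Counter objects produce the "clipped counts" used in `_corpus_bleu` below, a small standalone sketch (token lists invented):

import collections

def get_ngrams(segment, max_order):
    # Mirrors _get_ngrams above: count every n-gram of order 1..max_order.
    ngram_counts = collections.Counter()
    for order in range(1, max_order + 1):
        for i in range(len(segment) - order + 1):
            ngram_counts[tuple(segment[i : i + order])] += 1
    return ngram_counts

reference = "the cat is on the mat".split()
translation = "the the the cat the".split()

ref_counts = get_ngrams(reference, 2)
trans_counts = get_ngrams(translation, 2)

# Counter & takes elementwise minima: "the" appears 4 times in the
# translation but only twice in the reference, so its clipped count is 2.
overlap = trans_counts & ref_counts
print(overlap[("the",)])        # 2
print(overlap[("the", "cat")])  # 1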

    def _corpus_bleu(
        self,
        reference_corpus,
        translation_corpus,
        matches_by_order,
        possible_matches_by_order,
        translation_length,
        reference_length,
        max_order=4,
        smooth=False,
    ):
        """Corpus BLEU implementation using Python ops.

        Computes BLEU score of translated segments against one or more
        references. Inspired from
        https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py.

        Args:
            reference_corpus: list of lists of references for each
                translation. Each reference should be tokenized into a list
                of tokens.
            translation_corpus: list of translations to score. Each
                translation should be tokenized into a list of tokens.
            matches_by_order: list of floats containing the initial number
                of matches for each order.
            possible_matches_by_order: list of floats containing the
                initial number of possible matches for each order.
            translation_length: float. Initial number of tokens in all the
                translations.
            reference_length: float. Initial number of tokens in all the
                references.
            max_order: int. Maximum n-gram order to use when computing
                BLEU score.
            smooth: boolean. Whether or not to apply Lin et al. 2004
                smoothing.
        """
        for (references, translation) in zip(
            reference_corpus, translation_corpus
        ):
            reference_length += min(len(r) for r in references)
            translation_length += len(translation)

            merged_ref_ngram_counts = collections.Counter()
            for reference in references:
                merged_ref_ngram_counts |= self._get_ngrams(
                    reference, max_order
                )
            translation_ngram_counts = self._get_ngrams(
                translation, max_order
            )
            overlap = translation_ngram_counts & merged_ref_ngram_counts
            for ngram in overlap:
                matches_by_order[len(ngram) - 1] += overlap[ngram]
            for order in range(1, max_order + 1):
                possible_matches = len(translation) - order + 1
                if possible_matches > 0:
                    possible_matches_by_order[order - 1] += possible_matches

        precisions = [0] * max_order
        for i in range(0, max_order):
            if smooth:
                precisions[i] = (matches_by_order[i] + 1.0) / (
                    possible_matches_by_order[i] + 1.0
                )
            else:
                if possible_matches_by_order[i] > 0:
                    precisions[i] = (
                        float(matches_by_order[i])
                        / possible_matches_by_order[i]
                    )
                else:
                    precisions[i] = 0.0

        if min(precisions) > 0:
            p_log_sum = sum(
                (1.0 / max_order) * math.log(p) for p in precisions
            )
            geo_mean = math.exp(p_log_sum)
        else:
            geo_mean = 0

        ratio = float(translation_length) / reference_length

        if ratio > 1.0:
            bp = 1.0
        else:
            bp = math.exp(1 - 1.0 / ratio)

        bleu = geo_mean * bp

        return (
            bleu,
            matches_by_order,
            possible_matches_by_order,
            translation_length,
            reference_length,
        )
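
As a worked instance of the arithmetic above (all numbers invented): with `max_order=2`, clipped matches of [3, 1] out of [5, 4] possible matches, a 5-token translation and a 6-token reference give:

import math

matches_by_order = [3, 1]
possible_matches_by_order = [5, 4]
translation_length, reference_length = 5, 6
max_order = 2

# Modified n-gram precisions: 3/5 = 0.6 and 1/4 = 0.25.
precisions = [
    m / p for m, p in zip(matches_by_order, possible_matches_by_order)
]

# Geometric mean of the precisions: sqrt(0.6 * 0.25) ~= 0.3873.
geo_mean = math.exp(sum(math.log(p) / max_order for p in precisions))

# Brevity penalty: the translation (5 tokens) is shorter than the
# reference (6 tokens), so ratio < 1 and bp = exp(1 - 1/ratio) ~= 0.8187.
ratio = translation_length / reference_length
bp = 1.0 if ratio > 1.0 else math.exp(1 - 1.0 / ratio)

print(round(geo_mean * bp, 4))  # 0.3171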

    def update_state(self, y_true, y_pred, sample_weight=None):
        def validate_and_fix_rank(inputs, tensor_name, base_rank=0):
            if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
                inputs = tf.convert_to_tensor(inputs)

            if inputs.shape.rank == base_rank:
                return inputs[tf.newaxis]
            elif inputs.shape.rank == base_rank + 1:
                return inputs
            else:
                raise ValueError(
                    f"{tensor_name} must be of rank {base_rank} or "
                    f"{base_rank + 1}. Found rank: {inputs.shape.rank}"
                )

[Review] Member (on `validate_and_fix_rank`): let's just split this into a
separate private method on the layer, code can stay unchanged.

[Review] Collaborator (author): I'll remove calculate_bleu_score from here.
Let's keep validate_and_fix_rank inside this function for homogeneity (since
we have done the same for ROUGE and Edit Distance)?

[Review] Member: sg!

        def calculate_bleu_score(references, translation):
            if references.dtype == tf.string:
                references = tensor_to_string_list(references)
                translation = tensor_to_string_list(translation)
            else:
                references = tensor_to_list(references)
                translation = tensor_to_list(translation)

            matches = self._matches.numpy().tolist()
            possible_matches = self._possible_matches.numpy().tolist()
            translation_length = self._translation_length.numpy()
            reference_length = self._reference_length.numpy()

            (
                bleu_score,
                matches,
                possible_matches,
                translation_length,
                reference_length,
            ) = self._corpus_bleu(
                reference_corpus=references,
                translation_corpus=translation,
                matches_by_order=matches,
                possible_matches_by_order=possible_matches,
                translation_length=translation_length,
                reference_length=reference_length,
                max_order=self.max_order,
                smooth=self.smooth,
            )
            return (
                tf.constant(bleu_score, dtype=self.dtype),
                tf.constant(matches, dtype=self.dtype),
                tf.constant(possible_matches, dtype=self.dtype),
                tf.constant(translation_length, dtype=self.dtype),
                tf.constant(reference_length, dtype=self.dtype),
            )

        y_true = validate_and_fix_rank(y_true, "y_true", 1)
        y_pred = validate_and_fix_rank(y_pred, "y_pred", 0)

        # Tokenize the inputs.
        y_true = self.tokenizer(y_true)
        y_pred = self.tokenizer(y_pred)

        (
            bleu_score,
            matches,
            possible_matches,
            translation_length,
            reference_length,
        ) = tf.py_function(
            func=calculate_bleu_score,
            inp=[y_true, y_pred],
            Tout=[
                self.dtype,
                self.dtype,
                self.dtype,
                self.dtype,
                self.dtype,
            ],
        )

[Review] Member (on `tf.py_function`): given that you are converting
everything to numpy, would this do better as a numpy_function call?

[Review] Collaborator (author): Hmm, I think it won't work because we have
ragged tensors? Numpy has no support for ragged matrices, I think.

[Review] Collaborator (author) @abheesht17, Jul 8, 2022: I think it fails to
convert ragged input tensors to numpy arrays:

    File "/home/abheesht/repos/keras-nlp/keras_nlp/metrics/bleu.py", line 224, in _corpus_bleu
        for (references, translation) in zip(
    TypeError: in user code:

        File "/home/abheesht/repos/keras-nlp/keras_nlp/metrics/bleu.py", line 349, in update_state  *
            (
        File "/home/abheesht/repos/keras-nlp/keras_nlp/metrics/bleu.py", line 302, in _calculate_bleu_score
            ) = self._corpus_bleu(
        File "/home/abheesht/repos/keras-nlp/keras_nlp/metrics/bleu.py", line 224, in _corpus_bleu
            for (references, translation) in zip(

        TypeError: iteration over a 0-d array

[Review] Member: sg, let's stick with py_function for now.

[Review] Member (on the `Tout` list): If you remove the trailing comma will
this format to one line?

        self._matches.assign(matches)
        self._possible_matches.assign(possible_matches)
        self._translation_length.assign(translation_length)
        self._reference_length.assign(reference_length)
        self._bleu.assign(bleu_score)

    def result(self):
        return self._bleu

    def reset_state(self):
        self._matches.assign(
            tf.zeros(shape=(self.max_order,), dtype=self.dtype)
        )
        self._possible_matches.assign(
            tf.zeros(shape=(self.max_order,), dtype=self.dtype)
        )
        self._translation_length.assign(0.0)
        self._reference_length.assign(0.0)
        self._bleu.assign(0.0)

    def get_config(self):
        config = super().get_config()
        config.update(
            {
                "max_order": self.max_order,
                "smooth": self.smooth,
            }
        )
        return config
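
A quick sketch of what this config does and does not capture (hypothetical usage, not part of the diff): `max_order` and `smooth` round-trip, but a custom `tokenizer` is not serialized, which is what the earlier `_tokenizer` review suggestion is getting at.

from keras_nlp.metrics import Bleu

metric = Bleu(max_order=3, smooth=True)
restored = Bleu.from_config(metric.get_config())
print(restored.max_order, restored.smooth)  # 3 True
# A `tokenizer` passed to __init__ would be dropped on restore, since
# get_config() above only saves `max_order` and `smooth` (plus the
# base Metric's `name` and `dtype`).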