Add BLEU Score #222

Merged (18 commits, Jul 11, 2022)
Conversation

@abheesht17 (Collaborator) commented Jun 13, 2022:

Resolves #65

Notebooks:

Edit (commit e90bef3):

We have decided to go ahead with a `tf.py_function` implementation. Yet to add unit tests, but otherwise, it's done.
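For anyone skimming, the general pattern is roughly the following (a minimal sketch, not the PR's actual code; `_corpus_bleu_py` and `bleu_op` are illustrative names):

```python
import tensorflow as tf

def _corpus_bleu_py(references, translation):
    # Illustrative stub: runs as eager Python inside the graph, so it can
    # use collections.Counter, Python loops, etc. The real implementation
    # computes modified n-gram precisions and the brevity penalty here.
    return tf.constant(0.0, dtype=tf.float32)

def bleu_op(references, translation):
    # tf.py_function lets the pure-Python BLEU computation participate in
    # a compiled tf.function graph (e.g. inside model.evaluate()).
    return tf.py_function(
        _corpus_bleu_py, inp=[references, translation], Tout=tf.float32
    )
```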

Notebooks:

@mattdangerw (Member) left a comment:

Thanks for all the work on this! Left a few questions and a few comments.

Feel free to start adding unit tests!

inputs of shapes `()`, `(batch_size,)` and `(batch_size, 1)`.

Args:
tokenizer: callable. A function that takes a string `tf.RaggedTensor`
mattdangerw (Member):

What happens if you pass a tokenizer layer here? Will that work? Say the byte tokenizer, for simplicity.

abheesht17 (Collaborator, author):

Hmmm, it won't work with byte tokeniser because we use tensor_to_string_list in the code. Do you want me to change that?

mattdangerw (Member):

I think we should either support our tokenizers or rename this argument to something else.

Tokenizer means something specific in our library now; if we use that name but don't support our tokenizer class, that is a bad look.

abheesht17 (Collaborator, author):

We do support our tokenisers. I've added a unit test here:

def test_custom_tokenizer(self):
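For illustration, such a test might look roughly like this; the exact strings and assertion are made up, but `ByteTokenizer` is the layer discussed above and is imported in the test file:

```python
def test_custom_tokenizer(self):
    # Sketch: pass a KerasNLP tokenizer layer instead of the default.
    bleu = Bleu(tokenizer=ByteTokenizer())
    y_true = tf.constant([["the quick brown fox"]])  # one reference
    y_pred = tf.constant(["the quick brown fox"])
    bleu.update_state(y_true, y_pred)
    # Identical strings should yield a perfect score.
    self.assertAllClose(bleu.result(), 1.0)
```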

class Bleu(keras.metrics.Metric):
"""BLEU metric.

This class implements the BLEU metric. BLEU is generally used to evaluate
mattdangerw (Member):

We should probably mention more prominently that this replicates SacreBLEU by default, but can be used with other tokenizers, e.g. for other languages.

)

def _get_ngrams(segment, max_order):
"""Extracts all n-grams upto a given maximum order from an input
mattdangerw (Member):

The summary sentence for a docstring should always fit on a single line.

ngram_counts[ngram] += 1
return ngram_counts

def corpus_bleu(
mattdangerw (Member):

Why is this public? Do we expect this to be called explicitly? If so, can you show the use case?

considered. Defaults to 4.
smooth: bool. Whether to apply Lin et al. 2004 smoothing to the BLEU
score. Defaults to False.
variant: string. Either `"corpus_bleu"` or `"sentence_bleu"`. The former
mattdangerw (Member):

It seems like corpus BLEU is the better option here? I see that SacreBLEU exposes methods for both of these, but does not seem to document the sentence one. Hugging Face looks like it might not even have an option for this (is that true?).

I guess I'm wondering if it might make sense to not even expose this, and wait till someone asks for the sentence option.

abheesht17 (Collaborator, author):

After doing a survey, this is what I found:

Conclusion: I think the expectation is that, if users want to compute the Sentence BLEU score, they can do so by passing one sample at a time, and average over the returned scores.

Some additional notes: HF provides two options with all its metrics:

  • .compute() - user can pass one sample at a time, get the BLEU scores, and average over them for computing Sentence BLEU.
  • .add_batch() - will compute the Corpus BLEU score across all samples across batches.

We use Keras metrics similarly to the add_batch() function. So, if users want to compute the Sentence BLEU score, they will have to re-initialise the metric for every sample. PyTorch Ignite metrics also work like add_batch(), which is why they provide a macro/micro-averaging option. So I am wondering whether HF and NLTK skip an explicit macro-averaging option because their users can simply average the per-sample BLEU scores themselves, whereas with Ignite the user can't do that without re-initialising before every sample, which is why the option has been provided.
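In Keras-metric terms, the two usage patterns from the survey would look roughly like this (a sketch; `batches` and `samples` are placeholder iterables):

```python
# Corpus BLEU, analogous to HF's .add_batch(): one metric instance
# accumulates statistics across every batch before result() is called.
bleu = Bleu()
for y_true_batch, y_pred_batch in batches:
    bleu.update_state(y_true_batch, y_pred_batch)
corpus_score = float(bleu.result())

# Sentence BLEU by macro-averaging: re-initialise the metric per sample.
scores = []
for references, hypothesis in samples:
    metric = Bleu()
    metric.update_state([references], [hypothesis])
    scores.append(float(metric.result()))
sentence_score = sum(scores) / len(scores)
```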

max_order: int. The maximum n-gram order to use. For example, if
`max_order` is set to 3, unigrams, bigrams, and trigrams will be
considered. Defaults to 4.
smooth: bool. Whether to apply Lin et al. 2004 smoothing to the BLEU
mattdangerw (Member):

Can we describe this better? "Lin et al. 2004" with a period in the middle of the docstring does not read very well. Also, please add it to the references section.
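For reference, the smoothing in the tensorflow/nmt script this port follows amounts to add-one smoothing of each modified n-gram precision (sketch; `_smooth_precision` is an illustrative name):

```python
def _smooth_precision(matches, possible, smooth):
    # Lin & Och (2004) smoothing as in tensorflow/nmt's bleu.py: +1 on
    # both counts keeps an order with zero matches from zeroing out the
    # geometric mean of the n-gram precisions.
    if smooth:
        return (matches + 1.0) / (possible + 1.0)
    return matches / possible if possible > 0 else 0.0
```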

reference_length += min(len(r) for r in references)
translation_length += len(translation)

merged_ref_ngram_counts = collections.Counter()
mattdangerw (Member):

So we've confirmed at this point that the Counter approach is much more efficient, correct? Can we open an issue to remove the py_function from BLEU as a follow-up?

Describe a bit of the slowdown we saw when trying to do that today. Probably not urgent, but maybe someday we could collaborate with tf.text to ship an efficient non-Python version of this op.

abheesht17 (Collaborator, author):

Right! Opened an issue for this: #247 :)

f"Received: variant={variant}"
)

def default_tokenizer(inputs):
mattdangerw (Member):

There are a lot of nested functions that are quite long, all over. Not very readable.

Can we either pull them out of the class entirely, or make them class methods?

abheesht17 (Collaborator, author):

Made a few of them class methods.

mattdangerw (Member):

Maybe make this one a private method called _tokenizer that does:

if self.tokenizer:
    return self.tokenizer(x)

for pattern, replacement...

That way saving self.tokenizer in the config would work as expected.
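Spelled out, the suggestion is something like the sketch below (`self._patterns` stands in for the elided regex list and is a hypothetical attribute):

```python
def _tokenizer(self, x):
    # Use the user-provided tokenizer if one was passed at construction.
    if self.tokenizer:
        return self.tokenizer(x)
    # Otherwise apply the sacrebleu-style default, built from TF graph
    # ops so it still works under tf.function.
    for pattern, replacement in self._patterns:
        x = tf.strings.regex_replace(x, pattern, replacement)
    return tf.strings.split(x)
```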

@abheesht17 (Collaborator, author) left a comment:

Replied to a few comments :)


@abheesht17 (Collaborator, author) left a comment:

Addressed most comments. Thanks for the review, @mattdangerw!


@chenmoneygithub (Contributor) left a comment:

Thank you! Dropped some comments on style & interface design. Will take another pass on the implementation body.

"""BLEU metric.

This class implements the BLEU metric. BLEU is generally used to evaluate
machine translation systems. by default, this implementation replicates
chenmoneygithub (Contributor):

nit: capital case, "By default"


Note on input shapes:
For `y_true` and `y_pred`, this class supports the following shapes:
If `y_pred` is a scalar value, `y_true` has to be a 1D dense tensor.
chenmoneygithub (Contributor):

This requirement on shape match is a little weird. Suppose both y_pred and y_true are scalars; can we just convert y_true for them?

abheesht17 (Collaborator, author):

Hmmm, the idea right now is that one sample can have multiple references. So, basically, a translation can have multiple reference sentences. That's why the rank of y_true = rank of y_pred + 1.
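To make the rank relationship concrete (illustrative values only):

```python
# Unbatched: one translation, two references.
y_pred = tf.constant("the cat sat on the mat")           # shape ()
y_true = tf.constant(
    ["the cat sat on the mat", "a cat sat on the mat"]
)                                                        # shape (2,)

# Batched: a ragged y_true allows a different number of references
# per sample.
y_pred = tf.constant(["hello there", "general kenobi"])  # (batch_size,)
y_true = tf.ragged.constant(
    [["hello there", "hi there"], ["general kenobi"]]
)                                                        # (batch_size, None)
```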

Args:
tokenizer: callable. A function that takes a string `tf.RaggedTensor`
(of any shape), and tokenizes the strings in the tensor. This
function should use TensorFlow graph ops. If the tokenizer is not
chenmoneygithub (Contributor):

Is it necessary? If people are not interested in using model.evaluate(), can they just run it in pure eager mode?

abheesht17 (Collaborator, author):

True, but we call the tokeniser after converting the inputs to tensors. So, we have to use TF ops here such as tf.strings.regex_replace().

)

def _get_ngrams(self, segment, max_order):
"""Extracts all n-grams upto a given maximum order from an input segment.
chenmoneygithub (Contributor):

nit: up to

mattdangerw (Member):

remember to fix this!

https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py.

Args:
segment: string. Text segment from which n-grams will be
chenmoneygithub (Contributor):

Is this a string or a tensor of split tokens?

abheesht17 (Collaborator, author):

Yeah, this should be a list.
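For reference, the helper in the linked tensorflow/nmt script takes a list of tokens and returns a Counter of n-grams, which is why `string` in the docstring is misleading:

```python
import collections

def _get_ngrams(segment, max_order):
    """Extracts all n-grams up to a given maximum order from a token list."""
    ngram_counts = collections.Counter()
    for order in range(1, max_order + 1):
        for i in range(0, len(segment) - order + 1):
            ngram_counts[tuple(segment[i : i + order])] += 1
    return ngram_counts
```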

@abheesht17 (Collaborator, author) left a comment:

Thank you, @chenmoneygithub, for the review comments. I've addressed them!



@mattdangerw (Member) left a comment:

Thanks!


Args:
inputs: Input tensor, or dict/list/tuple of input tensors.
*args: Additional positional arguments.
mattdangerw (Member):

remove *args and **kwargs

from keras_nlp.tokenizers import ByteTokenizer


class BleuTest(tf.test.TestCase):
mattdangerw (Member):

Can we add a test with model.compile to make sure we are OK in function compilation? I think we did that for ROUGE.

abheesht17 (Collaborator, author):

My bad - I forgot to add it.
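Roughly what that test would look like, mirroring the ROUGE one (a sketch: the identity model, a `keras` import in the test file, and the `"bleu"` metric-name key are assumptions):

```python
def test_model_compile(self):
    # Sketch: run the metric through compile()/evaluate() to exercise
    # function compilation with an identity model over string inputs.
    inputs = keras.Input(shape=(), dtype="string")
    model = keras.Model(inputs, inputs)
    model.compile(metrics=[Bleu()])

    x = tf.constant(["the quick brown fox jumped over the lazy dog"])
    y = tf.constant([["the quick brown fox jumped over the lazy dog"]])
    output = model.evaluate(x, y, return_dict=True)
    self.assertAllClose(output["bleu"], 1.0)  # metric name assumed
```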

)

def update_state(self, y_true, y_pred, sample_weight=None):
def validate_and_fix_rank(inputs, tensor_name, base_rank=0):
mattdangerw (Member):

Let's just split this into a separate private method on the layer; the code can stay unchanged.

abheesht17 (Collaborator, author):

I'll remove calculate_bleu_score from here. Let's keep validate_and_fix_rank inside this function for homogeneity (since we have done the same for ROUGE and Edit Distance)?

mattdangerw (Member):

sg!

possible_matches,
translation_length,
reference_length,
) = tf.py_function(
mattdangerw (Member):

Given that you are converting everything to numpy, would this do better as a numpy_function call?

abheesht17 (Collaborator, author):

Hmmm, I think it won't work because we have ragged tensors? NumPy has no support for ragged matrices, I think.

abheesht17 (Collaborator, author) commented Jul 8, 2022:

I think it fails to convert ragged input tensors to numpy arrays:

File "/home/abheesht/repos/keras-nlp/keras_nlp/metrics/bleu.py", line 224, in _corpus_bleu
    for (references, translation) in zip(
TypeError: in user code:

    File "/home/abheesht/repos/keras-nlp/keras_nlp/metrics/bleu.py", line 349, in update_state  *
        (
    File "/home/abheesht/repos/keras-nlp/keras_nlp/metrics/bleu.py", line 302, in _calculate_bleu_score
        ) = self._corpus_bleu(
    File "/home/abheesht/repos/keras-nlp/keras_nlp/metrics/bleu.py", line 224, in _corpus_bleu
        for (references, translation) in zip(

    TypeError: iteration over a 0-d array

mattdangerw (Member):

sg, let's stick with py_function for now
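A minimal sketch of the difference under discussion: tf.py_function passes composite tensors like RaggedTensor through intact, while tf.numpy_function first converts its inputs to NumPy arrays, which is consistent with the 0-d array TypeError above.

```python
import tensorflow as tf

def count_rows(ragged):
    # Inside tf.py_function this runs eagerly, and `ragged` arrives as a
    # genuine RaggedTensor that can be inspected row by row.
    return tf.constant(int(ragged.nrows()))

x = tf.ragged.constant([["a", "b"], ["c"]])
n = tf.py_function(count_rows, inp=[x], Tout=tf.int32)  # works

# tf.numpy_function(count_rows, [x], tf.int32) would fail here, since a
# ragged input has no ndarray equivalent to convert to.
```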

self.dtype,
self.dtype,
self.dtype,
self.dtype,
mattdangerw (Member):

If you remove the trailing comma, will this format to one line?


https://cloud.google.com/translate/automl/docs/evaluate#bleu.

Note on input shapes:
`y_pred` can be a scalar (of shape `()`), or a dense tensor of shape
mattdangerw (Member):

How common is it to call BLEU with only a single reference translation (is that unusual, or the normal case)? If it's usual, could we consider supporting the case where y_pred and y_true have the same shape? That might lead to simpler overall usage.

Open question; I'm not sure if that is true.

abheesht17 (Collaborator, author) commented Jul 8, 2022:

Well, it is called often with one reference...but all implementations take input the way I've mentioned in the doc-string, whether it be SacreBLEU, or HF, or even NLTK.

@abheesht17 (Collaborator, author) left a comment:

Thanks, @mattdangerw, for the comments!

@mattdangerw (Member) left a comment:

This looks great! Definitely a complex metric, so I expect we will have to keep iterating, but I think we are in a good spot to land. Last few comments.

`y_true` should be a tensor of shape `(num_references,)`. For batched
inputs, `y_pred` should be a tensor of shape `(batch_size,)`,
and `y_true` should be a tensor of shape `(batch_size, num_references)`. In
case of batched inputs, `y_true` can also be of shape `(batch_size, None)`
mattdangerw (Member):

a ragged tensor with shape `(batch_size, None)`


@mattdangerw (Member):

LGTM from me! @chenmoneygithub feel free to merge whenever you are satisfied here!

@chenmoneygithub merged commit 13a9f2d into keras-team:master on Jul 11, 2022.