Skip to content

Commit

Permalink
Address review comments-II, make shape changes
Browse files Browse the repository at this point in the history
  • Loading branch information
abheesht17 committed Jul 7, 2022
1 parent 0217b71 commit 0b6ebfa
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 9 deletions.
27 changes: 18 additions & 9 deletions keras_nlp/metrics/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ class Bleu(keras.metrics.Metric):
"""BLEU metric.
This class implements the BLEU metric. BLEU is generally used to evaluate
machine translation systems. by default, this implementation replicates
machine translation systems. By default, this implementation replicates
SacreBLEU, but user-defined tokenizers can be passed to deal with other
languages.
Expand All @@ -63,13 +63,15 @@ class Bleu(keras.metrics.Metric):
n-grams so as to not give a high score to a (reference, prediction) pair
with redundant, repeated tokens. Secondly, BLEU score tends to reward
shorter predictions more, which is why a brevity penalty is applied to
penalise short predictions.
penalise short predictions. For more details, see the following article:
https://cloud.google.com/translate/automl/docs/evaluate#bleu.
Note on input shapes:
For `y_true` and `y_pred`, this class supports the following shapes:
If `y_pred` is a scalar value, `y_true` has to be a 1D dense tensor.
For batched inputs, if `y_pred` is a 1D dense tensor, `y_true` has to be
a dense/ragged tensor with shape `(batch_size, None)`.
`y_pred` can be a scalar (of shape `()`), or a dense tensor of shape
`(batch_size,)` or `(batch_size, 1)`. `y_true` can either be a dense tensor
of shape `(num_references,)`, or a ragged tensor of shapes
`(batch_size, None)` or `(batch_size, None, 1)`. This is because every
sample can have multiple references.
Args:
tokenizer: callable. A function that takes a string `tf.RaggedTensor`
Expand Down Expand Up @@ -171,7 +173,7 @@ def _get_ngrams(self, segment, max_order):
https://github.com/tensorflow/nmt/blob/master/nmt/scripts/bleu.py.
Args:
segment: string. Text segment from which n-grams will be
segment: list. Text segment from which n-grams will be
extracted.
max_order: int. Maximum length in tokens of the n-grams returned
by this methods.
Expand Down Expand Up @@ -286,10 +288,17 @@ def validate_and_fix_rank(inputs, tensor_name, base_rank=0):
return inputs[tf.newaxis]
elif inputs.shape.rank == base_rank + 1:
return inputs
elif inputs.shape.rank == base_rank + 2:
if tf.shape(inputs)[-1] != 1:
raise ValueError(
f"{tensor_name} is of rank {input.shape.rank}. The "
f"last dimension must be of size 1."
)
return tf.squeeze(inputs, axis=-1)
else:
raise ValueError(
f"{tensor_name} must be of rank {base_rank} or {base_rank+1}. "
f"Found rank: {inputs.shape.rank}"
f"{tensor_name} must be of rank {base_rank}, {base_rank+1} "
f"or {base_rank+2}. Found rank: {inputs.shape.rank}"
)

def calculate_bleu_score(references, translation):
Expand Down
32 changes: 32 additions & 0 deletions keras_nlp/metrics/bleu_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ def test_1d_list_input(self):
bleu_val = bleu(y_true, y_pred)
self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3)

def test_2d_list_input(self):
bleu = Bleu()
y_true = [
[["He eats a sweet apple."]],
[["Silicon Valley is one of my favourite shows!"]],
]
y_pred = [
["He He He eats sweet apple which is a fruit."],
["I love Silicon Valley, it's one of my favourite shows."],
]

bleu_val = bleu(y_true, y_pred)
self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3)

def test_1d_tensor_input(self):
bleu = Bleu()
y_true = tf.ragged.constant(
Expand All @@ -70,6 +84,24 @@ def test_1d_tensor_input(self):
bleu_val = bleu(y_true, y_pred)
self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3)

def test_2d_tensor_input(self):
bleu = Bleu()
y_true = tf.constant(
[
[["He eats a sweet apple."]],
[["Silicon Valley is one of my favourite shows!"]],
]
)
y_pred = tf.constant(
[
["He He He eats sweet apple which is a fruit."],
["I love Silicon Valley, it's one of my favourite shows."],
]
)

bleu_val = bleu(y_true, y_pred)
self.assertAlmostEqual(bleu_val.numpy(), 0.243, delta=1e-3)

def test_custom_tokenizer(self):
byte_tokenizer = ByteTokenizer()
bleu = Bleu(tokenizer=byte_tokenizer)
Expand Down

0 comments on commit 0b6ebfa

Please sign in to comment.