
Conversation

abheesht17
Collaborator

@abheesht17 abheesht17 commented Apr 16, 2022

Resolves #67

@mattdangerw, @chenmoneygithub, this PR is now ready for review :)

@abheesht17 abheesht17 mentioned this pull request Apr 16, 2022
@abheesht17 abheesht17 changed the title [WIP] Add Rouge-L Metric Add Rouge-L Metric Apr 16, 2022
@mattdangerw
Member

Thank you! This looks awesome!

Before diving into review, could you share a colab showing the expected end-to-end use case for this, with some string translations and references?

You should be able to add these lines to the top of the colab.

!git clone --branch your-branch-name your-remote-url
!cd keras-nlp && pip install . -q

It looks like we would need people to tokenize themselves before calling into this metric, which I think is OK (and maybe even preferable). But we should be clear what our expected end-to-end flow is. See as reference...

https://colab.sandbox.google.com/github/huggingface/notebooks/blob/master/course/videos/rouge_metric.ipynb

Maybe use TextVectorization for now to show how this could be done at a word level?
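
Something like this, perhaps (untested sketch; it assumes the branch is installed as above and that the metric consumes token IDs):

    import tensorflow as tf
    import keras_nlp  # installed from the branch, as above

    # Word-level tokenization with TextVectorization (toy data for illustration).
    vectorizer = tf.keras.layers.TextVectorization(output_mode="int")
    vectorizer.adapt(["the quick brown fox", "the fox jumped over the lazy dog"])

    # Ragged token-ID batches (the draft metric consumes token IDs).
    references = tf.RaggedTensor.from_tensor(vectorizer(["the quick brown fox"]))
    hypotheses = tf.RaggedTensor.from_tensor(vectorizer(["the fox jumped"]))

    rouge_l = keras_nlp.metrics.RougeL()
    rouge_l.update_state(references, hypotheses)
    print(rouge_l.result())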

It would also be helpful to compare this argument list and docstring to the ones we expect for the Rouge-N variant. Could you show what the full arg list and usage for Rouge-N could look like? I know we are still figuring out how to implement it, so that part could be a markdown code block, not actually runnable.
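
For example (purely hypothetical, just to anchor the discussion; arg names borrowed from the current Rouge-L draft):

    # Hypothetical Rouge-N signature, for comparison (not runnable).
    rouge_n = keras_nlp.metrics.RougeN(
        order=2,                 # n-gram order to match
        alpha=0.5,               # precision/recall weighting for the F1 score
        metric_type="f1_score",
        mask_token_ids=None,
        dtype=None,
        name="rouge_n",
    )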

@abheesht17
Collaborator Author

Sure thing! Will share a notebook ASAP :)

@abheesht17
Collaborator Author

Contributor

@chenmoneygithub chenmoneygithub left a comment

Generally looks good! Left some comments.

1.1. `mask_token_ids` not provided.
>>> tf.random.set_seed(42)
>>> rouge_l = keras_nlp.metrics.RougeL(name="rouge_l")
>>> references = tf.random.uniform(
Contributor

Prefer using a well-defined example rather than random data, so that users can manually calculate the F1 score.

Collaborator Author

Done!

alpha=0.5,
metric_type="f1_score",
mask_token_ids=None,
dtype=None,
Contributor

In the docstring it defaults to float32, which mismatches the default value of None. Please fix it.


def update_state(self, y_true, y_pred, sample_weight=None):
    # Both y_true and y_pred have shape: [batch_size, seq_len]. Note that
    # they can also be ragged tensors with shape [num_samples, (seq_len)].
Contributor

Rename num_samples to batch_size; we should be consistent with the naming.

Collaborator Author

My bad. Changed!

return config


def rouge_l(y_true, y_pred, alpha=0.5):
Contributor

We do not actually need to make this a standalone util. Just write

f1_scores, precisions, recalls = tf_text.metrics.rouge_l(
    y_pred, y_true, alpha=alpha
)

in the RougeL class.

Collaborator Author

I was following the format of metrics given in Keras. For example, check this out: https://github.com/keras-team/keras/blob/master/keras/metrics/metrics.py#L220, https://github.com/keras-team/keras/blob/master/keras/metrics/metrics.py#L3331.

The class helps in aggregating the ROUGE score (the user can iterate over the dataset, and the class will return the avg. ROUGE score).

The function allows string inputs, I think.

I've changed it to this for the time being:

f1_scores, precisions, recalls = tf_text.metrics.rouge_l(
    y_pred, y_true, alpha=alpha
)

but let me know which one is more appropriate and I'll change it to that. Thanks!
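
To illustrate the streaming use case the class enables (rough sketch; token-ID inputs and this branch installed are assumed):

    import tensorflow as tf
    import keras_nlp  # assuming this branch is installed

    rouge_l = keras_nlp.metrics.RougeL()

    # Toy token-ID batches; in practice these would come from iterating a dataset.
    batches = [
        (tf.ragged.constant([[1, 2, 3, 4]]), tf.ragged.constant([[1, 2, 3]])),
        (tf.ragged.constant([[5, 6]]), tf.ragged.constant([[5, 6, 7]])),
    ]
    for y_true, y_pred in batches:
        rouge_l.update_state(y_true, y_pred)

    print(rouge_l.result())  # average ROUGE-L over all batches seen so far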

Contributor

Thanks! My guess is there was some need to directly call the metric computing function, which I am not sure still applies. Let's keep it simple for now, and we can add the util if we find it is necessary.

Collaborator Author

Great! Let me know if further changes are required.



class RougeLTest(tf.test.TestCase):
def test_vars_after_initializing_class(self):
Contributor

This is a bit verbose; we can rename it to test_initialization().

Collaborator Author

@abheesht17 abheesht17 left a comment

@chenmoneygithub, thank you for the review comments! I've addressed them.


Contributor

@chenmoneygithub chenmoneygithub left a comment

Thanks!

@mattdangerw
Member

@abheesht17 thanks for the colab! That is super helpful. Thoughts...

I am still somewhat torn on whether we might want to somehow accept strings. The Hugging Face implementation is certainly a little more usable on the face of it, as is the rouge-score package it is based on. Though it seems like they will need to bake in a lot of assumptions (including language stemming!).

This brings up a lot of questions. Curious your thoughts here.

  • If you wanted to report rouge score, say for a paper, would you need to recreate the tokenization and stemming exactly like the package I linked, for comparability with other models?
  • How could you do that from our package?
  • Do people regularly report rouge with other tokenizers?
  • How do people handle other languages with ROUGE generally (especially ones without whitespace splitting for tokens)?

Second, a more minor point: why do you need to expand_dims on the first metric? Could we remove that requirement, or is that baked into keras.metrics somehow?

@chenmoneygithub chenmoneygithub self-requested a review April 25, 2022 22:01
@chenmoneygithub chenmoneygithub dismissed their stale review April 25, 2022 22:01

changes requested

@chenmoneygithub
Contributor

@abheesht17 Hi, we had some discussions around this, and also reached out to the Google research team for their insights. Briefly, we will have the ROUGE metric work on string inputs, and by default provide a standard tokenizer so that different works report ROUGE scores based on the same calculation mechanism.

To make our work compatible with existing ROUGE score reporting, we can make this metric depend on the rouge_score package. We still want to deliver Keras metrics so that ROUGE can be easily integrated into the training flow. To do that, one possible way is to use tf.py_function() to wrap the rouge function call from the existing rouge_score package.
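
Roughly along these lines (sketch only; scalar string inputs assumed):

    import tensorflow as tf
    from rouge_score import rouge_scorer

    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=False)

    def rouge_l_f1(y_true, y_pred):
        # Inside tf.py_function the arguments are eager string tensors,
        # so we can pull out Python strings and call the pure-Python package.
        score = scorer.score(
            y_true.numpy().decode("utf-8"), y_pred.numpy().decode("utf-8")
        )["rougeL"]
        return score.fmeasure

    f1 = tf.py_function(
        rouge_l_f1,
        [tf.constant("the cat sat"), tf.constant("the cat sat down")],
        tf.float32,
    )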

We can leave tokenizer customization as a TODO until we know how to make our tokenizers compatible with the rouge_score package.

Sorry about the inconvenience!

@abheesht17
Collaborator Author

Hello, @chenmoneygithub! Thank you! Sorry for not replying earlier to @mattdangerw's questions.

So, essentially, the gist is:

  1. We will still keep ROUGE-L as a subclass of keras.metrics.Metric.
  2. We will use Google Research's ROUGE package.
  3. We will take string inputs.
  4. For future work, we can figure out how to make our tokenisers compatible with the package's tokenisers and give the user an option to pass an arg for that.

I had a brief look at the tokenisers provided by the package. Their default tokeniser merely returns a list of string tokens:

>>> df = DefaultTokenizer()
>>> df.tokenize("hello, this is fun")
['hello', 'this', 'is', 'fun']

If you see this file, we need to do two things to implement a custom tokeniser:

  1. The tokeniser should be a subclass of the Tokenizer class present in the ROUGE package.
  2. It should have a tokenize() method which returns a list of string tokens.

I don't think (1) is very important; after all, in the internal ROUGE package implementation, they will just call the tokenize method. See here: https://github.com/google-research/google-research/blob/master/rouge/rouge_scorer.py#L73 and https://github.com/google-research/google-research/blob/master/rouge/rouge_scorer.py#L125.
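
In other words, something as minimal as this should be accepted (hypothetical sketch):

    # Hypothetical tokenizer for rouge_score: all the scorer really needs is
    # a tokenize() method that returns a list of string tokens.
    class WhitespaceTokenizer:
        def tokenize(self, text):
            return text.lower().split()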

Currently, our tokenisers return a RaggedTensor on passing a string:

String output.
>>> vocab = ["[UNK]", "the", "qu", "##ick", "br", "##own", "fox", "."]
>>> inputs = "The quick brown fox."
>>> tokenizer = keras_nlp.tokenizers.WordPieceTokenizer(
...     vocabulary=vocab, dtype="string")
>>> tokenizer(inputs)
<tf.RaggedTensor [[b'the', b'qu', b'##ick', b'br', b'##own', b'fox', b'.']]>

So, I guess in the tokenizer class in KerasNLP, we can simply add an option, return_list, and if this is True, we can convert the RaggedTensor output to a list. I'll try this out.
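
Continuing the snippet above, the conversion itself would be straightforward (sketch; return_list is a proposed arg, not an existing one):

    ragged = tokenizer(inputs)  # <tf.RaggedTensor [[b'the', b'qu', ...]]>
    # to_list() yields nested Python lists of bytes; decode to get str tokens.
    tokens = [t.decode("utf-8") for t in ragged.to_list()[0]]
    # ['the', 'qu', '##ick', 'br', '##own', 'fox', '.']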

@abheesht17
Collaborator Author

abheesht17 commented Apr 26, 2022

Small update:

The functionality for split_summaries and providing a custom tokeniser was only recently added to the rouge-score package. The authors have not released a version on PyPI with the above two functionalities. So, I'm just going for a basic implementation right now.

See these two commits:

  1. google-research/google-research@61ce9f0 (custom tokeniser)
  2. google-research/google-research@ed3e2bc (split summaries)

@abheesht17
Collaborator Author

@mattdangerw, @chenmoneygithub, how do I convert a string Tensor to a Python string (graph ops)? The rouge-score package does not accept string tensors (I tried passing a tensor with a single string).

>>> scorer.score(tf.constant("hello"), tf.constant("bye"))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\abheesht\anaconda3\envs\keras_nlp\lib\site-packages\rouge_score\rouge_scorer.py", line 88, in score
    target_tokens = tokenize.tokenize(target, self._stemmer)
  File "C:\Users\abheesht\anaconda3\envs\keras_nlp\lib\site-packages\rouge_score\tokenize.py", line 42, in tokenize
    text = text.lower()
  File "C:\Users\abheesht\anaconda3\envs\keras_nlp\lib\site-packages\tensorflow\python\framework\ops.py", line 513, in __getattr__
    self.__getattribute__(name)
AttributeError: 'tensorflow.python.framework.ops.EagerTensor' object has no attribute 'lower'

I tried a bunch of things; none of them seem to work. str(tensor_var) works in eager mode, but won't work in graph mode. tf.strings.as_string did not work either.

>>> tf.strings.as_string(tf.constant("weifwf"))
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\abheesht\anaconda3\envs\keras_nlp\lib\site-packages\tensorflow\python\ops\gen_string_ops.py", line 74, in as_string
    _ops.raise_from_not_ok_status(e, name)
  File "C:\Users\abheesht\anaconda3\envs\keras_nlp\lib\site-packages\tensorflow\python\framework\ops.py", line 7186, in raise_from_not_ok_status     
    raise core._status_to_exception(e) from None  # pylint: disable=protected-access
tensorflow.python.framework.errors_impl.InvalidArgumentError: Value for attr 'T' of string is not in the list of allowed values: float, double, int32, uint8, int16, int8, int64, bfloat16, uint16, half, uint32, uint64, complex64, complex128, bool, variant
        ; NodeDef: {{node AsString}}; Op<name=AsString; signature=input:T -> output:string; attr=T:type,allowed=[DT_FLOAT, DT_DOUBLE, DT_INT32, DT_UINT8, DT_INT16, DT_INT8, DT_INT64, DT_BFLOAT16, DT_UINT16, DT_HALF, DT_UINT32, DT_UINT64, DT_COMPLEX64, DT_COMPLEX128, DT_BOOL, DT_VARIANT]; attr=precision:int,default=-1; attr=scientific:bool,default=false; attr=shortest:bool,default=false; attr=width:int,default=-1; attr=fill:string,default=""> [Op:AsString]

@abheesht17 abheesht17 changed the title Add Rouge-L Metric Add Rouge Metric May 23, 2022
@abheesht17 abheesht17 changed the title Add Rouge Metric Add ROUGE Metric May 23, 2022
@abheesht17
Collaborator Author

abheesht17 commented May 23, 2022

@mattdangerw, @chenmoneygithub, I've made the required changes. Apologies once again for the delay!

Still confused about the graph ops stuff though. Since we are using tf.py_function, we can't really pass an object of this class to model.compile. I tried it out with @tf.function and it fails, which means it won't work with model.compile.

P.S. Haven't added examples in the doc-string yet.
P.P.S. It's nice to be back! :)

Member

@mattdangerw mattdangerw left a comment

Thanks! This looks good to me.

Made a colab to play around, and it looks like there is at least one shape issue. This metric seems to only support shape (batch_size) inputs, but we should also support (batch_size, 1).

https://colab.sandbox.google.com/gist/mattdangerw/104626168c0bce36f12679b2dd38ce23/rouge-test.ipynb

We should also discuss whether we want to make this two metrics (maybe with a common base class, unsure).


if rouge_score is None:
    raise ImportError(
        "ROUGE metric requires the `rouge_score` package."
Member

space after period

if rouge_score is None:
    raise ImportError(
        "ROUGE metric requires the `rouge_score` package."
        "Please install it with `pip install rouge_score`."
Member

rouge_score -> rouge-score

    score = score.recall
else:
    score = score.fmeasure
return score
Member

A Keras metric can just return a dict of scalars, I believe. Should we just return a dict here, and from the metric overall?

Collaborator Author

Ah, I was not aware of this. Will return a dictionary.

Collaborator Author

@mattdangerw, looks like this doesn't work; an error is thrown when I return a dictionary from result(). Have a look at this snippet of code:

>>> import keras_nlp
>>> y_true = "hey, this is great fun"
>>> y_pred = "great fun indeed"
>>> rouge = keras_nlp.metrics.RougeN(order=2)
>>> rouge(y_true, y_pred)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/metrics/base_metric.py", line 200, in __call__
    return distributed_training_utils.call_replica_local_fn(
  File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/distribute/distributed_training_utils.py", line 60, in call_replica_local_fn
    return fn(*args, **kwargs)
  File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/keras/metrics/base_metric.py", line 196, in replica_local_fn
    result_t._metric_obj = self  # pylint: disable=protected-access
AttributeError: 'dict' object has no attribute '_metric_obj'

However, this works:

>>> import keras_nlp
>>> y_true = "hey, this is great fun"
>>> y_pred = "great fun indeed"
>>> rouge = keras_nlp.metrics.RougeN(order=2)
>>> rouge.update_state(y_true, y_pred)
>>> rouge.result()
{'rouge_n_precision': <tf.Tensor: shape=(), dtype=float32, numpy=0.5>, 'rouge_n_recall': <tf.Tensor: shape=(), dtype=float32, numpy=0.25>, 'rouge_n_f1_score': <tf.Tensor: shape=(), dtype=float32, numpy=0.33333334>}

Reverting back to the metric_type implementation.

def result(self):
    if self._number_of_samples == 0:
        return 0.0
    rouge_l_score = self._rouge_score / self._number_of_samples
Member

Why is this variable called rouge_l_score when it could be ROUGE-L or ROUGE-N? Seems like a confusing name.

Collaborator Author

Yeah, this is a typo. Corrected in the latest commit!

ROUGE-L and ROUGE-LSum.

Args:
variant: string. One of "rougeN", "rougeL", "rougeLsum". Defaults to
Member

This feels too tricky, particularly with the order of RougeN being a hidden parameter of the string passed to this argument.

What about making a separate RougeL and RougeN metric class?

Collaborator Author

Discussed offline with Matt. We are going ahead with separate classes for ROUGE-N and ROUGE-L!

Member

@mattdangerw mattdangerw left a comment

Looking good! Left some comments. Most of the comments on RougeL will also apply to RougeN.

Making a private base class for common logic will make refactoring simpler.

Still trying to get answers on the best return type for this metric.

# strings in the tensor/list.

# Check if input is a raw string/list.
if isinstance(y_true, str):
Member

Why not just

if not isinstance(y_true, tf.Tensor):
    y_true = tf.convert_to_tensor(y_true)
if not isinstance(y_pred, tf.Tensor):
    y_pred = tf.convert_to_tensor(y_pred)

I don't think we should do the rank coercion in this test case. That would seem to support scalar inputs only when they are not tensors, but not when they are tensors, which is weird behavior.

Convert to tensor first, then fix rank.


def update_state(self, y_true, y_pred, sample_weight=None):
    # Three possible shapes for y_true and y_pred: Python string,
    # [batch_size] and [batch_size, 1]. In the latter two cases, we have
Member

It seems like we also support scalar inputs; is that true?

We should probably move some of this discussion on supported shape into the docstring.

Collaborator Author

Yep, I've written "Python string" on line number 96.

Sure, will move it to the doc-string!


# If the shape of y_true and y_pred is [batch_size, 1], squeeze it to
# [batch_size].
if y_true.shape.rank == 2:
Member

This would fail if the shape is [batch_size, 2] right now, in a not very helpful way. The friendliest thing here might be to check whether we have a supported shape (rank 0, rank 1, or rank 2 with shape[-1] == 1), and if not, give a friendly error message.
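
Rough sketch of what I mean (names are placeholders):

    def validate_and_fix_rank(inputs, tensor_name):
        # Placeholder helper: accept rank 0, rank 1, or rank 2 with shape[-1] == 1.
        if not isinstance(inputs, (tf.Tensor, tf.RaggedTensor)):
            inputs = tf.convert_to_tensor(inputs)
        if inputs.shape.rank == 0:
            return inputs[tf.newaxis]
        elif inputs.shape.rank == 1:
            return inputs
        elif inputs.shape.rank == 2 and inputs.shape[1] == 1:
            return tf.squeeze(inputs, axis=1)
        raise ValueError(
            f"{tensor_name} must be of rank 0 (scalar), rank 1, or rank 2 "
            f"with shape [batch_size, 1]. Found rank: {inputs.shape.rank}"
        )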

Collaborator Author

Right! Changes made 👍🏼

name="rouge_l_test",
)

config = rouge.get_config()
Member

This feels fragile; we are essentially testing what the base metric class has in its config. Currently that is only dtype and name, but it could change.

Maybe just assert the contents of the config you expect, metric_type and use_stemmer.

from keras_nlp.metrics import RougeL


class RougeLTest(tf.test.TestCase):
Member

We should test this with a model (which maybe just passes through inputs), and a batched tf.data.Dataset.

Contributor

We should test passing this to model.compile()'s metrics arg.
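
e.g., something like (sketch; exact assertions up to you):

    import tensorflow as tf
    from tensorflow import keras

    from keras_nlp.metrics import RougeL

    class RougeLTest(tf.test.TestCase):
        def test_metric_in_model_compile(self):
            # Pass-through model: the outputs are the string inputs themselves.
            inputs = keras.Input(shape=(), dtype="string")
            model = keras.Model(inputs, inputs)
            model.compile(metrics=[RougeL()])

            x = tf.constant(["the tiny little cat was found under the bed"])
            y = tf.constant(["the cat was under the bed"])
            model.evaluate(x, y, return_dict=True)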

tf.cast(batch_size, dtype=self.dtype)
)

def result(self):
Member

Open question: how are these ROUGE scores usually aggregated when reported in a paper? We may want to do a little bit of a dive into this, to understand what the critical user journeys are when reporting an aggregate score.

Looking at the aggregation code in the package, they have a lot more there...
https://github.com/google-research/google-research/blob/master/rouge/scoring.py#L61

OK if we don't need all that, but we should make sure we understand what people will want when using this metric.

Collaborator Author

Hmmm, good point. I went through some examples online. In particular, I went through PyTorch Ignite's ROUGE metric. Have a look: https://pytorch.org/ignite/_modules/ignite/metrics/nlp/rouge.html#Rouge (_BaseRouge class).

They take the average:

    def compute(self) -> Mapping:
        if self._num_examples == 0:
            raise NotComputableError("Rouge metric must have at least one example before be computed")

        return {
            f"{self._metric_name()}-P": float(self._precision / self._num_examples),
            f"{self._metric_name()}-R": float(self._recall / self._num_examples),
            f"{self._metric_name()}-F": float(self._fmeasure / self._num_examples),
        }

Member

Sounds good to me! We can always see if people open up issues. Thanks for checking!

rouge_score = None


class RougeL(keras.metrics.Metric):
Member

I think it would still be a good idea to do some code sharing between these two metrics.

What if we do this...

  • Move everything back into rouge.py and rouge_test.py.
  • Add a base class RougeBase that contains most of the logic.
  • In `__init__.py`, only export RougeL and RougeN.

That would be similar to how core Keras handles Conv2D and Conv3D, for example.

Contributor

+1

Collaborator Author

@abheesht17 abheesht17 Jun 3, 2022

Done! However, I have kept two separate files for unit tests, rouge_n_test.py and rouge_l_test.py, since RougeN and RougeL are what will eventually be exposed to the user. Let me know if you want only one test script (for RougeBase).

not specified, it defaults to tf.float32.
name: string. Name of the metric instance.
**kwargs: Other keyword arguments.
"""
Member

Add some docstring examples! I think the >>> style, with actual output, would be useful in this case.

Collaborator Author

Yeah, I forgot to do this. I've added examples in the new commit!

Contributor

@chenmoneygithub chenmoneygithub left a comment

Thanks for the PR!


between the reference text and the hypothesis text.

Args:
order: The order of n-grams which are to be matched. It should lie in
Contributor

Just curious: is this [1, 9] a requirement from the rouge-score package?

Collaborator Author

Yep!

>>> from rouge_score import rouge_scorer
>>> rg = rouge_scorer.RougeScorer(rouge_types=["rouge10"])
>>> rg.score("hey", "hey, hello")
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/abheesht/python_envs/keras_nlp/lib/python3.8/site-packages/rouge_score/rouge_scorer.py", line 119, in score
    raise ValueError("Invalid rouge type: %s" % rouge_type)
ValueError: Invalid rouge type: rouge10

rouge_score = None


class RougeN(keras.metrics.Metric):
Contributor

Reading through both class implementations, most code can be shared between them, so it looks doable to me to have a RougeBase class.

Collaborator Author

@abheesht17 abheesht17 left a comment

@mattdangerw, @chenmoneygithub, thanks for the comments! I've addressed them.


Member

@mattdangerw mattdangerw left a comment

Looks good! Left some minor comments. The main thing we need to dig into now is what is going wrong when returning a dict. This will probably require some looking into the base metric class in core Keras.

Thanks!


class RougeBase(keras.metrics.Metric):
    """ROUGE metric.
    This class implements all the variants of the ROUGE metric - ROUGE-N,
Member

white space before and after this paragraph

use_stemmer: bool. Whether Porter Stemmer should be used to strip word
suffixes to improve matching. Defaults to False.
dtype: string or tf.dtypes.Dtype. Precision of metric computation. If
not specified, it defaults to tf.float32.
Member

fix alignment of this line

def __init__(
    self,
    variant="rouge2",
    metric_type="f1_score",
Member

We still need to figure out why returning a dict is not working, and see if there is a bug that needs to be fixed in core Keras or elsewhere.

We should not ship an API signature we don't want because of a bug we need to fix!

("rouge" + str(order) for order in range(1, 10))
) + (
"rougeL",
"rougeLsum",
Member

rougeLsum we aren't supporting right now, correct?

# [batch_size] and [batch_size, 1]. In the latter two cases, we have
# strings in the tensor/list.

def validate_and_fix_rank(input_, tensor_name):
Member

Generally, a trailing underscore is not a naming pattern we follow.

Just call this inputs?

Succinctly put, ROUGE-L is a score based on the length of the longest
common subsequence present in the reference text and the hypothesis text.

Note on input shapes:
Member

I would just comment on the shapes here (not the types). So just say supports scalar and batch inputs of shape (), (batch_size,) and (batch_size, 1).

Member

@mattdangerw mattdangerw left a comment

Thanks! This looks good to me pending the dict discussion.

# class wraps the `results()` method.
obj = super().__new__(cls)

class MetricDict(dict):
Contributor

Is this class necessary? It seems this is just an alias to dict.

Also, let's create a TODO here for future cleanup; this code is hard to maintain.

Collaborator Author

Actually, the reason for defining a class is so that we can do object.var_name-style attribute assignments.
If we use a plain dictionary, this error crops up:

    def replica_local_fn(*args, **kwargs):
      """Updates the state of the metric in a replica-local context."""
      if any(
          isinstance(arg, keras_tensor.KerasTensor)
          for arg in tf.nest.flatten((args, kwargs))):
        update_op = None
      else:
        update_op = self.update_state(*args, **kwargs)  # pylint: disable=not-callable
      update_ops = []
      if update_op is not None:
        update_ops.append(update_op)
      with tf.control_dependencies(update_ops):
        result_t = self.result()  # pylint: disable=not-callable
    
        # We are adding the metric object as metadata on the result tensor.
        # This is required when we want to use a metric with `add_metric` API on
        # a Model/Layer in graph mode. This metric instance will later be used
        # to reset variable state after each epoch of training.
        # Example:
        #   model = Model()
        #   mean = Mean()
        #   model.add_metric(mean(values), name='mean')
>       result_t._metric_obj = self  # pylint: disable=protected-access
E       AttributeError: 'dict' object has no attribute '_metric_obj'
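
A quick illustration of the difference:

    # A plain dict has no __dict__, so attribute assignment fails...
    d = {}
    try:
        d._metric_obj = None
    except AttributeError as e:
        print(e)  # 'dict' object has no attribute '_metric_obj'

    # ...but instances of a trivial dict subclass accept attributes just fine.
    class MetricDict(dict):
        pass

    md = MetricDict()
    md._metric_obj = None  # works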

Member

Yeah, definitely the plan would be to remove this code after 2.10 is out!

Member

@mattdangerw mattdangerw left a comment

Approving! Thanks!

Just a few more minor comments.

for metric_type, expected_val in zip(
    self.metric_types, [1, 0.689, 0.807]
):
    self.assertAlmostEqual(
Member

Can we assert a whole dict structure in here? If so, I would find that a lot more readable than the way it's done here. (Here and elsewhere.)

assertAlmostEqual(rouge_output, {
    "rouge-l_precision": x,
    "rouge-l_recall": y,
    "rouge-l_f1_score": z,
})

Collaborator Author

Right - I can make a custom function for this!
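
Something like this, perhaps (hypothetical helper):

    def assertDictAlmostEqual(self, actual, expected, delta=1e-3):
        # Hypothetical helper: compare keys exactly and values approximately.
        self.assertEqual(set(actual.keys()), set(expected.keys()))
        for key in expected:
            self.assertAlmostEqual(float(actual[key]), expected[key], delta=delta)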

rouge_recall = self._rouge_recall / self._number_of_samples
rouge_f1_score = self._rouge_f1_score / self._number_of_samples
return {
    f"{self.name}_precision": rouge_precision,
Member

Quick note: I think we should remove the f"{self.name}_" part. We would like to make a change to core Keras to actually join the metric name when reporting metrics in a dict.

So if the metric is called "rouge-2", we would join the name when returning the metric dict, giving "rouge-2/recall" or something like that.

Collaborator Author

Should I remove it now, or later when the fix for the bug has been released?

Collaborator Author

Removed it for now. Let me know if you want to revert it back to f"{self.name}_".

@mattdangerw mattdangerw merged commit 0e3d12e into keras-team:master Jun 17, 2022