Refactor evaluation metrics to support decoded generated text metrics like BLEU and ROUGE. #3539

Merged: 9 commits merged into master on Aug 25, 2023

Conversation

@justinxzhao (Collaborator) commented Aug 17, 2023

Adding new decoded text metrics

New evaluation metrics for text output features can be added to the metrics registry under the RESPONSE keyword, like so:

@register_metric("bleu", [TEXT], MAXIMIZE, RESPONSE)
class BLEUScoreMetric(BLEUScore, LudwigMetric):
    def __init__(self, **kwargs):
        super().__init__()

New metrics added:

  • BLEU
  • ROUGE
  • WordErrorRate
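
Following the same pattern, the remaining metrics can be registered like so (a sketch only; the exact class names and the MINIMIZE objective for word_error_rate are assumptions, not necessarily the code in this PR):

from torchmetrics.text import ROUGEScore, WordErrorRate

@register_metric("rouge", [TEXT], MAXIMIZE, RESPONSE)
class ROUGEScoreMetric(ROUGEScore, LudwigMetric):
    def __init__(self, **kwargs):
        super().__init__()

@register_metric("word_error_rate", [TEXT], MINIMIZE, RESPONSE)
class WordErrorRateMetric(WordErrorRate, LudwigMetric):
    def __init__(self, **kwargs):
        super().__init__()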

Changes to TextOutputFeature.update_metrics()

Since we need to pass additional decoded inputs to text features, we explicitly define update_metrics() for TextOutputFeature instead of relying on the base OutputFeature implementation. The decoded texts are optional arguments; if they are not provided, decoded-text metrics are skipped.

def update_metrics(
    self,
    targets: Tensor,
    predictions: Dict[str, Tensor],
    decoded_targets: Optional[List[str]] = None,
    decoded_predictions: Optional[List[str]] = None,
) -> None:
    """Updates metrics with the given targets and predictions.

    If decoded_targets and decoded_predictions are provided, as they are for LLM model types, then additional
    response-based metrics like BLEU and ROUGE are also computed.

    Args:
        targets: Tensor with target values for this output feature.
        predictions: Dict of tensors returned by predictions().
        decoded_targets: Optional list of decoded target texts.
        decoded_predictions: Optional list of decoded predicted texts.
    """
    for metric_name, metric_fn in self._metric_functions.items():
        prediction_key = get_metric_tensor_input(metric_name)
        if prediction_key != RESPONSE:
            # Non-RESPONSE metrics operate on tensors and don't use decoded texts.
            metric_fn = metric_fn.to(predictions[prediction_key].device)
            metric_fn.update(predictions[prediction_key].detach(), targets)
            continue

        if decoded_targets is not None and decoded_predictions is not None:
            # RESPONSE metrics can only be computed when decoded texts are provided.
            # We assume that RESPONSE metrics are calculated one example at a time.
            for i in range(len(decoded_predictions)):
                metric_fn.update(decoded_predictions[i], decoded_targets[i])

Some torchmetrics text metrics return a dictionary of sub-metrics. These are individually unpacked in model.get_metrics().
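
For example, torchmetrics' ROUGEScore.compute() returns a dict with keys like rouge1_fmeasure and rougeL_recall, and each entry becomes its own metric. The unpacking amounts to something like the following (a sketch, not the exact get_metrics() code):

from typing import Any, Dict


def flatten_metric_result(metric_name: str, computed: Any, metrics: Dict[str, float]) -> None:
    # Dict-valued results (e.g. ROUGE) contribute one entry per sub-metric;
    # scalar results keep the registered metric name.
    if isinstance(computed, dict):
        for sub_name, sub_value in computed.items():
            metrics[sub_name] = float(sub_value)
    else:
        metrics[metric_name] = float(computed)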

Changes to the LLM class

The LLM class has two separate methods for updating evaluation metrics: 1) update_metrics() for zero-shot/few-shot inference, and 2) update_metrics_finetune_llm() for fine-tuned LLMs. These appear to have slightly different treatments for aligning tensors and handling pad tokens.

Once predictions and targets are computed, they are decoded and passed to TextOutputFeature.update_metrics().
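
Conceptually, the hand-off looks like this (a sketch with a hypothetical helper, not the exact llm.py code; the batch_decode arguments are assumptions):

from typing import Dict, List

from torch import Tensor


def update_text_metrics_with_decoding(
    text_output_feature,  # a TextOutputFeature instance
    tokenizer,  # the HuggingFace tokenizer owned by the LLM model
    targets: Tensor,
    predictions: Dict[str, Tensor],
    target_ids: Tensor,
    predicted_ids: Tensor,
) -> None:
    # Decode token-id tensors into strings, then pass both the tensors and the
    # decoded texts so that RESPONSE metrics (BLEU, ROUGE, WER) can be updated.
    decoded_targets: List[str] = tokenizer.batch_decode(target_ids, skip_special_tokens=True)
    decoded_predictions: List[str] = tokenizer.batch_decode(predicted_ids, skip_special_tokens=True)
    text_output_feature.update_metrics(targets, predictions, decoded_targets, decoded_predictions)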

Refactoring the MetricsPrintedTable

This PR also replaces the MetricsPrintedTable with a simpler, stateless method. Instead of:

printed_table = MetricsPrintedTable(output_features)
printed_table.add_metrics_to_printed_table(train_metrics_log, TRAIN)
printed_table.add_metrics_to_printed_table(validation_metrics_log, VALIDATION)
printed_table.add_metrics_to_printed_table(test_metrics_log, TEST)
printed_table.log_info()

We consolidate printed metrics into one call:

print_metrics_table(
    output_features,
    progress_tracker.train_metrics,
    progress_tracker.validation_metrics,
    progress_tracker.test_metrics,
)
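
For reference, a stateless implementation might look roughly like this (a sketch; the actual print_metrics_table() operates on Ludwig output features and progress-tracker metric logs, so its signature differs):

from typing import Dict

from tabulate import tabulate


def print_metrics_table_sketch(
    train_metrics: Dict[str, float],
    validation_metrics: Dict[str, float],
    test_metrics: Dict[str, float],
) -> None:
    # Transposed layout: one row per metric, one column per split.
    metric_names = sorted(set(train_metrics) | set(validation_metrics) | set(test_metrics))
    rows = [
        [name, train_metrics.get(name, ""), validation_metrics.get(name, ""), test_metrics.get(name, "")]
        for name in metric_names
    ]
    print(tabulate(rows, headers=["", "train", "validation", "test"], tablefmt="fancy_grid", floatfmt=".4f"))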

I've also transposed the metrics table so that splits are columns and metrics are rows. This looks substantially better now that there are many additional metrics.

╒═══════════════════════╤════════════╤══════════════╤════════╕
│                       │      train │ validation   │ test   │
╞═══════════════════════╪════════════╪══════════════╪════════╡
│ bleu                  │     0.0000 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ char_error_rate       │     0.5089 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ loss                  │     2.3335 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ next_token_perplexity │ 30315.8379 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ perplexity            │ 37736.5156 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ rouge1_fmeasure       │     0.5648 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ rouge1_precision      │     0.4919 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ rouge1_recall         │     0.6630 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ rouge2_fmeasure       │     0.2991 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ rouge2_precision      │     0.2602 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ rouge2_recall         │     0.3516 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ rougeL_fmeasure       │     0.4167 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ rougeL_precision      │     0.3629 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ rougeL_recall         │     0.4891 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ rougeLsum_fmeasure    │     0.5093 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ rougeLsum_precision   │     0.4435 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ rougeLsum_recall      │     0.5978 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ sequence_accuracy     │     0.0000 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ token_accuracy        │     0.2716 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ word_error_rate       │     1.1882 │              │        │
├───────────────────────┼────────────┼──────────────┼────────┤
│ combined_loss         │     2.3335 │              │        │
╘═══════════════════════╧════════════╧══════════════╧════════╛

Combined loss is also merged into this table as the last row instead of being printed as a separate table, which frequently confused users.

╒═══════════════╤═════════╤══════════════╤════════╕
│               │   train │   validation │   test │
╞═══════════════╪═════════╪══════════════╪════════╡
│ accuracy      │  0.8157 │       0.6966 │ 0.8090 │
├───────────────┼─────────┼──────────────┼────────┤
│ loss          │  0.4619 │       0.5039 │ 0.4488 │
├───────────────┼─────────┼──────────────┼────────┤
│ precision     │  0.8274 │       0.6250 │ 0.7818 │
├───────────────┼─────────┼──────────────┼────────┤
│ recall        │  0.6680 │       0.4545 │ 0.6615 │
├───────────────┼─────────┼──────────────┼────────┤
│ roc_auc       │  0.8471 │       0.7706 │ 0.8592 │
├───────────────┼─────────┼──────────────┼────────┤
│ specificity   │  0.9105 │       0.8393 │ 0.8938 │
├───────────────┼─────────┼──────────────┼────────┤
│ combined_loss │  0.4619 │       0.5039 │ 0.4488 │
╘═══════════════╧═════════╧══════════════╧════════╛

github-actions bot commented Aug 17, 2023

Unit Test Results

6 files ±0, 6 suites ±0, 1h 12m 34s ⏱️ (−8m 2s)
34 tests ±0: 29 passed ✔️ ±0, 5 skipped 💤 ±0, 0 failed ±0
88 runs ±0: 72 passed ✔️ ±0, 16 skipped 💤 ±0, 0 failed ±0

Results for commit 95d419c. ± Comparison against base commit 8d4c96b.

@arnavgarg1 (Contributor) left a comment:

LGTM, nice work @justinxzhao! I just left some comments asking for additional code comments to clarify a few pieces, but this is a good change overall.

_predictions,
_decoded_predictions,
) = realign_target_and_prediction_tensors_for_inference(targets, predictions, of_name, self.tokenizer)
of_obj.update_metrics(_targets[of_name], _predictions[of_name], _decoded_targets, _decoded_predictions)
A collaborator commented:

Do we strictly need to pass all this stuff? Also, isn't there a more elegant way than checking for a text feature and giving update_metrics() a different signature only in that case? It's not a great abstraction, and it requires us to check for TextOutputFeature everywhere we call update_metrics() (it may be happening only here, but it's still a red flag that something is wrong with the abstraction). One option could be to accept kwargs in the superclass update_metrics() and pass the additional things through kwargs. It would still need an if to determine the inputs, but at least TextOutputFeature would not need a different signature for update_metrics().
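
For illustration, the kwargs option would look roughly like this (a sketch of the suggestion above, not code from this PR):

from typing import Dict, List, Optional

from torch import Tensor


class OutputFeature:
    def update_metrics(self, targets: Tensor, predictions: Dict[str, Tensor], **kwargs) -> None:
        # The base implementation ignores extra keyword arguments, so every
        # call site can use the same signature without isinstance checks.
        ...


class TextOutputFeature(OutputFeature):
    def update_metrics(self, targets: Tensor, predictions: Dict[str, Tensor], **kwargs) -> None:
        decoded_targets: Optional[List[str]] = kwargs.get("decoded_targets")
        decoded_predictions: Optional[List[str]] = kwargs.get("decoded_predictions")
        # ...update tensor-based metrics, then RESPONSE metrics only when the
        # decoded texts were supplied...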

@justinxzhao (Collaborator, Author) commented Aug 25, 2023:

Having TextOutputFeature own its tokenizer (for decoding predictions) would enable the original update_metrics() signature to work, though TextOutputFeature would still need to override the base class's implementation of update_metrics() to decode texts with the tokenizer and use the decoded texts to compute RESPONSE metrics.

However, constructing TextOutputFeature with the tokenizer would require a more invasive change.

  • Output feature construction code paths are still fully unified under BaseModel.build_outputs().
  • We would need to change the build_outputs() signature to take an additional training_set_metadata argument (which contains tokenizer information), so that it can be passed through the output feature constructors. TextOutputFeature is the only output feature that would need training_set_metadata in this way.
  • Output feature construction call sites would need to be updated as well -- there are 3 (ecd.py, gbm.py, and llm.py).

@justinxzhao (Collaborator, Author) commented Aug 25, 2023:

This refactoring seems worthwhile to do if other output features need training set metadata. For now, keeping the ownership of the tokenizer at the model level and relying on the model to decode predictions and provide them for TextOutputFeature's metrics calculations seems like a cleaner, less invasive change -- WDYT?

@justinxzhao (Collaborator, Author):

I made one simplification, which is to do decoding within update_metrics() instead of outside at the model level. This should reduce some duplication.

A collaborator replied:

sounds good to me!

@justinxzhao merged commit e5513c3 into master on Aug 25, 2023
13 of 16 checks passed
@justinxzhao deleted the bleu_rouge branch on August 25, 2023 at 20:18