Skip to content

Commit

Permalink
Introduce translation evaluation recipe
Browse files Browse the repository at this point in the history
  • Loading branch information
cbalioglu committed Jul 8, 2024
1 parent 3969a96 commit be51660
Show file tree
Hide file tree
Showing 5 changed files with 472 additions and 0 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
"psutil~=5.9",
"pyyaml~=6.0",
"rich~=13.7",
"sacrebleu~=2.4",
"tiktoken~=0.7",
"torcheval~=0.0.6",
"tqdm~=4.62",
Expand Down
82 changes: 82 additions & 0 deletions src/fairseq2/metrics/bleu.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

from typing import Iterable, Optional, Sequence, final

import torch
from sacrebleu import corpus_bleu
from sacrebleu.metrics.bleu import BLEU, MAX_NGRAM_ORDER
from torch import Tensor
from torcheval.metrics import Metric
from typing_extensions import Self

from fairseq2.typing import Device, override


@final
class BleuMetric(Metric[Tensor]):
"""Computes the BLEU score."""

sys_len: Tensor
ref_len: Tensor
valid_ngrams: Tensor
total_ngrams: Tensor

def __init__(self, *, device: Optional[Device] = None) -> None:
super().__init__(device=device)

self._add_state("sys_len", torch.zeros((), device=device, dtype=torch.float64))
self._add_state("ref_len", torch.zeros((), device=device, dtype=torch.float64))

self._add_state("valid_ngrams", torch.zeros((MAX_NGRAM_ORDER,), device=device, dtype=torch.float64)) # fmt: skip
self._add_state("total_ngrams", torch.zeros((MAX_NGRAM_ORDER,), device=device, dtype=torch.float64)) # fmt: skip

@override
@torch.inference_mode()
def update(self, refs: Sequence[str], hyps: Sequence[str]) -> Self:
"""
:param refs:
The reference strings.
:param hyps:
The hypothesis strings.
"""
device = self.sys_len.device

bleu = corpus_bleu(hyps, [refs])

self.sys_len += bleu.sys_len
self.ref_len += bleu.ref_len

self.valid_ngrams += torch.tensor(bleu.counts, device=device)
self.total_ngrams += torch.tensor(bleu.totals, device=device)

return self

@override
@torch.inference_mode()
def compute(self) -> Tensor:
valid_ngrams = [int(v) for v in self.valid_ngrams.tolist()]
total_ngrams = [int(t) for t in self.total_ngrams.tolist()]

bleu = BLEU.compute_bleu(
valid_ngrams, total_ngrams, int(self.sys_len), int(self.ref_len)
)

return torch.tensor(bleu.score, device=self.sys_len.device)

@override
@torch.inference_mode()
def merge_state(self, metrics: Iterable[BleuMetric]) -> Self:
for metric in metrics:
self.sys_len += metric.sys_len.to(self.device)
self.ref_len += metric.ref_len.to(self.device)

self.valid_ngrams += metric.valid_ngrams.to(self.device)
self.total_ngrams += metric.total_ngrams.to(self.device)

return self
1 change: 1 addition & 0 deletions src/fairseq2/metrics/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class _MetricFormatter:
# fmt: off
"ctc_loss": _MetricFormatter("CTC Loss", 100, format_as_float),
"nll_loss": _MetricFormatter("NLL Loss", 100, format_as_float),
"bleu": _MetricFormatter("BLEU", 200, format_as_float),
"uer": _MetricFormatter("Unit Error Rate (UER)", 200, format_as_float),
"wer": _MetricFormatter("Word Error Rate (WER)", 200, format_as_float),
"gradient_norm": _MetricFormatter("Gradient Norm", 300, format_as_float),
Expand Down
17 changes: 17 additions & 0 deletions src/fairseq2/recipes/transformer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,10 @@
# LICENSE file in the root directory of this source tree.

from fairseq2.recipes.cli import Cli, RecipeCommandHandler
from fairseq2.recipes.transformer.eval import (
load_transformer_evaluator,
transformer_eval_presets,
)
from fairseq2.recipes.transformer.translate import (
load_text_translator,
text_translate_presets,
Expand All @@ -16,6 +20,19 @@ def _setup_transformer_cli(cli: Cli) -> None:
"transformer", help="Transformer-based machine translation recipes"
)

# Eval
eval_handler = RecipeCommandHandler(
loader=load_transformer_evaluator,
preset_configs=transformer_eval_presets,
default_preset="nllb_dense_600m",
)

group.add_command(
name="eval",
handler=eval_handler,
help="evaluate a machine translation model",
)

# Translate
text_translate_handler = RecipeCommandHandler(
loader=load_text_translator,
Expand Down
Loading

0 comments on commit be51660

Please sign in to comment.