In [None]:
# Install required packages
!pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113
!pip install fairseq==0.12.2 sacremoses sacrebleu>=1.4.12 unbabel-comet

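## Implementation of the objective function
To use it, merge this code into `rl_criterion.py` from https://github.com/afeena/fairseq_easy_extend/. Names used below but not imported in this notebook (e.g. `register_criterion`, `RLCriterionConfig`, `FairseqCriterion`, `encoders`, `Namespace`, `torch`) are expected to come from that file.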

In [None]:
# extra imports
from sacrebleu import sentence_bleu, sentence_chrf
import torch.nn.functional as F

from comet import download_model, load_from_checkpoint

In [None]:
@register_criterion("rl_loss", dataclass=RLCriterionConfig)
class RLCriterion(FairseqCriterion):
    """REINFORCE-style criterion rewarding sampled outputs with a sentence metric.

    At every target position one token is sampled from the model's output
    distribution. The sampled sentence is detokenized and scored against the
    reference with the configured sentence-level metric. That score is then
    used as the reward for every sampled token of the sentence.

    ``sentence_level_metric`` (case-insensitive) selects the metric:
    ``"bleu"`` / ``"chrf"`` (sacrebleu ``sentence_bleu`` / ``sentence_chrf``)
    or ``"comet"`` (``Unbabel/wmt22-comet-da``). Rewards are used as-is, so
    their scale differs between metrics.
    """

    def __init__(self, task, sentence_level_metric):
        super().__init__(task)
        self.metric = sentence_level_metric.lower()
        if self.metric == "bleu":
            self.metric_func = sentence_bleu
        elif self.metric == "chrf":
            self.metric_func = sentence_chrf
        elif self.metric == "comet":
            # The COMET checkpoint is fetched and loaded once, at construction.
            model_path = download_model("Unbabel/wmt22-comet-da")
            model = load_from_checkpoint(model_path)
            self.metric_func = model.predict
        else:
            # ValueError subclasses Exception, so existing handlers still work.
            raise ValueError(
                f"RL metric not yet implemented: {sentence_level_metric!r} "
                "(expected 'bleu', 'chrf' or 'comet')"
            )
        # Moses detokenizer used to turn dictionary strings back into text.
        self.tokenizer = encoders.build_tokenizer(Namespace(tokenizer="moses"))
        self.tgt_dict = task.target_dictionary
        self.src_dict = task.source_dictionary

    def forward(self, model, sample, reduce=True):
        """Compute the loss for the given sample.

        Returns a tuple with three elements:
        1) the loss
        2) the sample size (always 1 here; see the note below)
        3) logging outputs to display while training

        ``reduce`` is accepted for interface compatibility; the loss returned
        by ``_compute_loss`` is always a scalar mean.
        """
        nsentences, ntokens = sample["nsentences"], sample["ntokens"]
        # B x T
        src_tokens = sample["net_input"]["src_tokens"]
        src_lengths = sample["net_input"]["src_lengths"]
        tgt_tokens, prev_output_tokens = sample["target"], sample["prev_target"]
        outputs = model(src_tokens, src_lengths, prev_output_tokens, tgt_tokens)
        # Train on the token predictions only, not on the length predictions.
        outs = outputs["word_ins"].get("out", None)

        loss, reward = self._compute_loss(outs, tgt_tokens, src_tokens)

        # NOTE: _compute_loss already averages over tokens, so sample_size is
        # not needed as a gradient denominator; it is 1 and used for logging.
        sample_size = 1
        logging_output = {
            "loss": loss.detach(),
            "nll_loss": loss.detach(),
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
            "reward": reward.detach(),
        }

        return loss, sample_size, logging_output

    def decode(self, toks, escape_unk=False, dict="tgt"):
        """Convert one sentence of token ids into a detokenized string.

        Args:
            toks: 1-D tensor of token ids for a single sentence.
            escape_unk: render unknown tokens as ``UNKNOWNTOKENINREF`` when
                True (use for references) and ``UNKNOWNTOKENINHYP`` otherwise
                (use for hypotheses), so they never match each other.
            dict: ``"tgt"`` selects the target dictionary; any other value
                selects the source dictionary. (The name shadows the builtin
                but is kept for backward compatibility.)

        Padding ids are removed first: ``Dictionary.string`` does not skip
        them, and batch padding must not leak into the scored text.
        """
        dictionary = self.tgt_dict if dict == "tgt" else self.src_dict
        with torch.no_grad():
            toks = toks[toks.ne(dictionary.pad())]
            s = dictionary.string(
                toks.int().cpu(),
                "@@ ",
                # The default unknown string in fairseq is `<unk>`, but
                # this is tokenized by sacrebleu as `< unk >`, inflating
                # BLEU scores. Instead, we use a somewhat more verbose
                # alternative that is unlikely to appear in the real
                # reference, but doesn't get split into multiple tokens.
                unk_string=(
                    "UNKNOWNTOKENINREF" if escape_unk else "UNKNOWNTOKENINHYP"
                ),
            )
            return self.tokenizer.decode(s)

    def compute_reward(self, sample_idx, targets, src_tokens):
        """Score every sampled sentence against its reference.

        Args:
            sample_idx: B x T sampled token ids, one hypothesis per row.
            targets: B x T reference token ids.
            src_tokens: source token ids per sentence (only used by COMET).

        Returns:
            A 1-D CPU float tensor holding one reward per sentence.
        """
        with torch.no_grad():
            hyps = [self.decode(sent) for sent in sample_idx]
            # References use the REF unk string so an unknown token in the
            # hypothesis cannot count as a match for one in the reference.
            refs = [self.decode(sent, escape_unk=True) for sent in targets]
            if self.metric == "comet":
                batch = [
                    {"src": self.decode(src, dict="src"), "mt": hyp, "ref": ref}
                    for src, hyp, ref in zip(src_tokens, hyps, refs)
                ]
                scores = self.metric_func(
                    batch, batch_size=64, progress_bar=False
                ).scores
            else:
                scores = [
                    self.metric_func(hyp, [ref]).score
                    for hyp, ref in zip(hyps, refs)
                ]
            return torch.tensor(scores, dtype=torch.float)

    def _compute_loss(self, outputs, targets, src_tokens, masks=None):
        """REINFORCE loss with the sentence reward broadcast to each token.

        Args:
            outputs: batch x len x vocab logits.
            targets: batch x len reference token ids.
            src_tokens: source token ids, forwarded to ``compute_reward``.
            masks: optional batch x len boolean mask of positions to train
                on; defaults to the non-padding positions of ``targets``.

        Returns:
            ``(loss, reward)``: the mean of ``-log p(sampled token) * reward``
            over the masked-in tokens, and the mean reward over those tokens.
        """
        pad = self.tgt_dict.pad()
        masks = targets.ne(pad) if masks is None else masks.bool()
        bsz, seq_len, vocab_size = outputs.size()

        with torch.no_grad():
            # Sampling is not differentiated, so build the probabilities
            # outside the autograd graph, in fp32 for a stable softmax.
            probs = F.softmax(outputs.float(), dim=-1).view(-1, vocab_size)
            sample_idx = torch.multinomial(probs, 1, replacement=True).view(
                bsz, seq_len
            )
            # Tokens sampled at masked-out (padding) positions are not part of
            # the hypothesis: turn them into padding, which decode() drops.
            scored_idx = sample_idx.masked_fill(~masks, pad)
            reward = self.compute_reward(scored_idx, targets, src_tokens)

        # compute_reward returns a CPU tensor; move it to the mask's device
        # before boolean indexing, then give each token its sentence's reward.
        reward = reward.to(outputs.device).unsqueeze(1).expand(bsz, seq_len)
        reward, sample_idx = reward[masks], sample_idx[masks]
        # log_softmax on logits is numerically more stable than log(probs).
        log_probs = F.log_softmax(outputs[masks].float(), dim=-1)
        # Log-probability of each sampled token at its position.
        log_probs_of_samples = log_probs.gather(1, sample_idx.unsqueeze(1)).squeeze(1)
        loss = -log_probs_of_samples * reward
        return loss.mean(), reward.mean()