# Inference and evaluation

## Set-up

In [None]:
# Directory names, resolved against the working directory or, in Colab, the Drive project folder.
DATA_DIR, MODEL_DIR = "data", "models"

In [None]:
# Detect Google Colab. Only a failed import means "not in Colab"; any error while
# mounting Drive is surfaced instead of silently falling back to local paths.
try:
    from google.colab import drive
except ImportError:
    # Running locally: data and models live relative to the working directory.
    FULL_DATA_DIR = DATA_DIR
    FULL_MODEL_DIR = MODEL_DIR

    IN_COLAB = False
else:
    drive.mount('/content/drive')
    FULL_DATA_DIR = f'/content/drive/My Drive/mbr-reranking/{DATA_DIR}'
    FULL_MODEL_DIR = f'/content/drive/My Drive/mbr-reranking/{MODEL_DIR}'

    IN_COLAB = True

In [None]:
try:
    import sentencepiece
except:
    !pip install sentencepiece
    import sentencepiece

try:
    import evaluate
except:
    !pip install evaluate
    import evaluate

!pip install git+https://github.com/google-research/bleurt.git
!pip install unbabel-comet

In [None]:
import locale
def getpreferredencoding(do_setlocale=True):
    """Always report UTF-8; ``do_setlocale`` is accepted for API compatibility and ignored."""
    return "UTF-8"


# Override the process-wide locale lookup so every caller sees UTF-8.
# NOTE(review): presumably a workaround for a non-UTF-8 locale in the notebook runtime — confirm.
locale.getpreferredencoding = getpreferredencoding

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import XLMRobertaTokenizer, XLMRobertaModel
from transformers import AdamW

from tqdm import tqdm

import evaluate
# Metric objects shared by the scoring code (ScoredSentences). evaluate.load is
# called without a config name, so each metric's default checkpoint is used.
# Note: the global name `bleurt` is the evaluate metric, not the bleurt package.
bleurt = evaluate.load('bleurt')
comet_metric = evaluate.load('comet')
# Sentence-level BLEU from NLTK.
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
# smoothie = SmoothingFunction().method1

# Use the GPU when torch can see one; this string is also passed to torch.autocast.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [None]:
import time

class Timer:
    """Context manager that prints the time spent inside its ``with`` block.

    Uses ``time.perf_counter`` (monotonic, high resolution) so the measurement
    is not affected by system clock adjustments. After the block exits, the
    duration in seconds is also available as ``elapsed``. Exceptions raised in
    the block propagate unchanged.
    """

    def __enter__(self):
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.end_time = time.perf_counter()
        self.elapsed = self.end_time - self.start_time
        print(f"Elapsed time: {self.elapsed} seconds")
        # Implicitly returns None, so any exception from the block is re-raised.

## Model

In [None]:
# Settings for this evaluation notebook.
config = {
    # Checkpoint file name, loaded from FULL_MODEL_DIR into RegressionModel.
    "model_path": "model.pt",
    # Order of the reranker input: True -> source + sep_token + hypothesis,
    # False -> hypothesis + sep_token + source. Presumably must match the order
    # used when the checkpoint was trained — confirm.
    "model_has_original_first": True,

    # Batch size of the DataLoader used for reranker inference.
    "valid_batch_size": 128,
    # Compute and report the BLEURT-based MBR consensus. When True, only the
    # first tenth of the sentences is loaded for evaluation.
    "compute_mbr_consensus": True,
}

In [None]:
# Mean pooling
class MeanPooling(nn.Module):
    """Average token embeddings over the positions where ``attention_mask`` is 1.

    Padding positions (mask 0) contribute to neither the sum nor the count.
    """

    def __init__(self):
        super().__init__()

    def forward(self, hidden_states, attention_mask):
        """Return the masked mean of ``hidden_states`` over axis 1.

        Args:
            hidden_states: expected (batch, seq, hidden); the mask is broadcast
                along the last axis and the sum runs over axis 1.
            attention_mask: (batch, seq) tensor of 0/1 values.

        Returns:
            (batch, hidden) tensor. A row whose mask is all zeros yields zeros
            instead of NaN.
        """
        mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
        summed = torch.sum(hidden_states * mask, 1)
        # Clamp the token count so an all-padding row gives 0 rather than 0/0 = NaN.
        # Rows with at least one real token have count >= 1 and are unaffected.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

# Model
class RegressionModel(nn.Module):
    """Pretrained encoder + mean pooling + linear head producing one score per input.

    The submodule attribute names (``pretrained_model``, ``regression_head``,
    ``pooling``) are the keys under which weights are stored in the state dict,
    so they must stay stable for saved checkpoints to load.
    """

    def __init__(self, pretrained_model):
        super().__init__()
        self.pretrained_model = pretrained_model
        hidden_size = pretrained_model.config.hidden_size
        self.regression_head = torch.nn.Linear(hidden_size, 1)
        self.pooling = MeanPooling()
        # When True, forward() runs the encoder and pooling without autograd.
        self.pretrained_frozen = False

    def _pooled_embedding(self, input_ids, attention_mask):
        """Encode the batch and mean-pool the encoder's last hidden state."""
        encoded = self.pretrained_model(input_ids, attention_mask=attention_mask)
        return self.pooling(encoded.last_hidden_state, attention_mask)

    def forward(self, input_ids, attention_mask):
        """Return one regression score per input row (last dimension of size 1)."""
        if not self.pretrained_frozen:
            pooled = self._pooled_embedding(input_ids, attention_mask)
        else:
            with torch.no_grad():
                pooled = self._pooled_embedding(input_ids, attention_mask)
        return self.regression_head(pooled)

    def _set_pretrained_trainable(self, trainable):
        """Record the frozen state and toggle ``requires_grad`` on every encoder parameter."""
        self.pretrained_frozen = not trainable
        for weight in self.pretrained_model.parameters():
            weight.requires_grad = trainable

    def freeze_pretrained(self):
        """Stop updating the encoder; the regression head stays trainable."""
        self._set_pretrained_trainable(False)

    def unfreeze_pretrained(self):
        """Make the encoder trainable again."""
        self._set_pretrained_trainable(True)

In [None]:
# Backbone and tokenizer; the architecture must match the checkpoint because
# load_state_dict below is strict.
tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
pretrained_model = XLMRobertaModel.from_pretrained('xlm-roberta-base')

# Attach the regression head and restore the fine-tuned weights. Loading onto the
# CPU first lets a GPU-saved checkpoint open on any machine.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
model = RegressionModel(pretrained_model)
model.load_state_dict(
    torch.load(f'{FULL_MODEL_DIR}/{config["model_path"]}', map_location='cpu')
)

# Move to the selected device and switch to inference mode.
model = model.to(device).eval()

## Reranker inference

In [None]:
from torch.utils.data import Sampler
import random
import math

class RegressionDataset(Dataset):
    """Tokenized (source, hypothesis) pairs, served shortest-first.

    ``generated`` holds ``len(generated) // len(original)`` consecutive
    hypotheses per source sentence. Each pair is joined with the tokenizer's
    separator token, in the order chosen by ``config["model_has_original_first"]``.
    Items come out ordered by increasing number of non-padding tokens, and each
    carries its position in ``generated`` under ``"index"``.
    """

    def __init__(self, original, generated):
        per_source = len(generated) // len(original)
        sep = tokenizer.sep_token
        source_first = config["model_has_original_first"]

        texts = []
        for pos, hypothesis in enumerate(generated):
            source = original[pos // per_source]
            texts.append(source + sep + hypothesis if source_first else hypothesis + sep + source)

        self.encodings = tokenizer(texts, truncation=True, padding=True)

        # Real (non-padding) token count per text; the stable sort keeps ties in input order.
        lengths = [sum(mask) for mask in self.encodings["attention_mask"]]
        self.sorted_indices = sorted(range(len(texts)), key=lengths.__getitem__)

    def __getitem__(self, index):
        row = self.sorted_indices[index]
        item = {name: torch.tensor(values[row]) for name, values in self.encodings.items()}
        item['index'] = row
        return item

    def __len__(self):
        return len(self.sorted_indices)

In [None]:
def truncate_batch(input_ids, attention_mask):
    """Drop trailing columns beyond the longest sequence in the batch.

    Keeps the first ``max(attention_mask.sum(dim=1))`` columns of both tensors.
    With right padding, this removes only padding.
    """
    longest = int(attention_mask.sum(dim=1).max())
    keep = slice(0, longest)
    return input_ids[:, keep], attention_mask[:, keep]

In [None]:
def compute_reranker_scores(original, generated):
    """Score every hypothesis in ``generated`` with the reranker ``model``.

    Args:
        original: source sentences.
        generated: hypotheses, ``len(generated) // len(original)`` consecutive
            entries per source sentence (see ``RegressionDataset``).

    Returns:
        List of floats aligned with ``generated``: entry i is the score of
        ``generated[i]``.
    """
    dataset = RegressionDataset(original, generated)
    # No shuffling: the dataset yields items shortest-first, so each batch holds
    # similarly sized inputs and can be trimmed to its own longest sequence.
    dataloader = DataLoader(dataset, batch_size=config["valid_batch_size"])

    scores_flattened = [-np.inf] * len(generated)
    with torch.no_grad(), torch.autocast(device_type=device, enabled=True):
        for batch in tqdm(dataloader, total=len(dataloader)):
            # The tokenizer padded every text to the dataset-wide maximum length;
            # drop this batch's all-padding tail columns before running the model.
            # Assumes right padding (the tokenizer default) — confirm if changed.
            input_ids, attention_mask = truncate_batch(batch['input_ids'], batch['attention_mask'])
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)

            # squeeze(-1), not squeeze(): a final batch of one item must stay 1-D
            # so it can be zipped with its indices.
            outputs = model(input_ids, attention_mask=attention_mask).squeeze(-1)

            # One device->host transfer per batch instead of one per element.
            for idx, score in zip(batch['index'].tolist(), outputs.float().cpu().tolist()):
                scores_flattened[idx] = score

    return scores_flattened

## Data

In [None]:
class ScoredSentences:
    """Candidate translations for a list of source sentences, plus their scores.

    ``sentences`` is flat, with ``len(sentences) // len(originals)`` consecutive
    candidates per source sentence. After ``compute_scores`` the metric
    attributes (``bleurt``, ``comet``, ``bleu``, ``reranker_score``) are arrays
    of shape (num_sources, samples_per_sentence).
    """

    def __init__(self, sentences, originals, references, max_sentences=None):
        """
        Args:
            sentences: flat list of candidate translations.
            originals: source sentences.
            references: reference translations. ``compute_scores`` pairs source i
                with ``references[i]``, i.e. it assumes one reference per source.
            max_sentences: if given, keep only the first ``max_sentences``
                sources together with their candidates and references.
        """
        # Per-source counts come from the full lists; truncation keeps the ratios.
        self.samples_per_sentence = len(sentences) // len(originals)
        self.references_per_sentence = len(references) // len(originals)
        self.mbr_consensus_samples_per_sentence = self.samples_per_sentence

        if max_sentences is not None:
            sentences = sentences[:(max_sentences*self.samples_per_sentence)]
            references = references[:(max_sentences*self.references_per_sentence)]
            originals = originals[:max_sentences]

        self.sentences_flat = sentences
        self.originals = originals
        self.references = references
        self.bleurt = np.array([])
        self.comet = np.array([])
        self.reranker_score = np.array([])
        self.bleu = np.array([])
        # Zero-filled placeholder so argmax-based selection code still runs when the
        # utilities were never computed; the flag records whether they are real.
        self.mbr_expected_utility = np.zeros((len(self.originals), self.mbr_consensus_samples_per_sentence))
        self.has_mbr_expected_utility = False

    def _as_grid(self, flat_scores):
        """Reshape a flat per-candidate score list to (num_sources, samples_per_sentence)."""
        return np.array(flat_scores).reshape((len(self.originals), self.samples_per_sentence))

    def compute_scores(self):
        """Score every candidate with BLEURT, COMET, sentence BLEU and the reranker."""
        # Expand per-source lists so they align one-to-one with the flat candidates.
        source_of = [i // self.samples_per_sentence for i in range(len(self.sentences_flat))]
        references = [self.references[s] for s in source_of]
        sources = [self.originals[s] for s in source_of]

        # BLEURT
        print("Computing BLEURT")
        with Timer():
            bleurt_flat = bleurt.compute(predictions=self.sentences_flat, references=references)["scores"]
            self.bleurt = self._as_grid(bleurt_flat)

        # COMET
        print("Computing COMET")
        with Timer():
            comet_flat = comet_metric.compute(predictions=self.sentences_flat, references=references, sources=sources)["scores"]
            self.comet = self._as_grid(comet_flat)

        # BLEU. NLTK's sentence_bleu expects token lists; raw strings would make it
        # count character n-grams. Tokenize on whitespace and use smoothing method1,
        # so a sentence without a matching higher-order n-gram gets a small
        # non-zero score instead of 0.
        print("Computing BLEU")
        with Timer():
            smoothing = SmoothingFunction().method1
            bleu_flat = [
                sentence_bleu(references=[reference.split()], hypothesis=sentence.split(),
                              smoothing_function=smoothing)
                for reference, sentence in zip(references, self.sentences_flat)
            ]
            self.bleu = self._as_grid(bleu_flat)

        # Rerank score
        print("Computing reranker score")
        with Timer():
            score_flat = compute_reranker_scores(self.originals, self.sentences_flat)
            self.reranker_score = self._as_grid(score_flat)

    def compute_mbr_expected_utility(self):
        """Estimate each candidate's MBR expected utility with BLEURT.

        For every source, each of its first ``mbr_consensus_samples_per_sentence``
        candidates is scored against all of them (itself included) as
        pseudo-references, and its utility is the mean of those scores.
        """
        k = self.mbr_consensus_samples_per_sentence
        self.mbr_expected_utility = np.zeros((len(self.originals), k))
        for i in tqdm(range(len(self.originals))):
            start = i * self.samples_per_sentence
            pool = self.sentences_flat[start:start + k]
            # Row-major pairs: hypothesis j is scored against every pseudo-reference.
            hyps = [hyp for hyp in pool for _ in pool]
            refs = [ref for _ in pool for ref in pool]

            bleurt_i = bleurt.compute(predictions=hyps, references=refs)["scores"]
            self.mbr_expected_utility[i] = np.array(bleurt_i).reshape((k, k)).mean(axis=-1)
        self.has_mbr_expected_utility = True

    def print_reranker_metrics(self):
        """Print how closely the reranker (and the MBR utility, if computed) tracks BLEURT."""
        target = self.bleurt.reshape((-1,))
        predicted = self.reranker_score.reshape((-1,))
        print("Reranker MSE   ", np.mean((predicted - target)**2))
        print("Reranker 1-corr", 1-np.corrcoef((predicted, target))[0,1])
        # The utility array is zero-filled by default, so its length cannot tell
        # whether it was computed; rely on the explicit flag instead.
        if self.has_mbr_expected_utility:
            utility = self.mbr_expected_utility.reshape((-1,))
            print("MBR.E.U. MSE   ", np.mean((utility - target)**2))
            print("MBR.E.U. 1-corr", 1-np.corrcoef((utility, target))[0,1])
        print()

def read_from_file(filepath):
    """Return the lines of ``FULL_DATA_DIR/filepath`` (UTF-8), each stripped of surrounding whitespace."""
    path = f"{FULL_DATA_DIR}/{filepath}"
    with open(path, 'r', encoding='utf-8') as handle:
        return [raw.strip() for raw in handle]

# Load the data

# Evaluation split to read ("dev" or "test").
split = "test"
assert split in ("dev", "test")

originals = read_from_file(f"{split}.deu")
references = read_from_file(f"{split}.eng")

# MBR consensus makes one BLEURT call per source over all candidate pairs, so when
# it is enabled only the first tenth of the sentences is kept (presumably to bound cost).
max_sentences = len(originals) // 10 if config["compute_mbr_consensus"] else None

generated = ScoredSentences(read_from_file(f"sampled/{split}.eng"), originals, references, max_sentences=max_sentences)
generated_cold = ScoredSentences(read_from_file(f"sampled/{split}-cold.eng"), originals, references, max_sentences=max_sentences)
beamsearch = ScoredSentences(read_from_file(f"beams/{split}.eng"), originals, references, max_sentences=max_sentences)

# Candidates per source sentence in the temperature-1 sample pool.
mult = generated.samples_per_sentence

## Score computation

In [None]:
# Score every candidate set (both sample pools and the beam-search outputs).
for scored in (generated, generated_cold, beamsearch):
    scored.compute_scores()

In [None]:
# The MBR consensus only selects among sampled candidates, so beam search is skipped.
if config["compute_mbr_consensus"]:
    for scored in (generated, generated_cold):
        scored.compute_mbr_expected_utility()

In [None]:
# Reranker-vs-BLEURT agreement for each candidate set.
for scored in (generated, generated_cold, beamsearch):
    scored.print_reranker_metrics()

## Examples of selected translations

In [None]:
def print_samples(samples, max_examples=None, beams=None):
    """Print the oracle, best and worst reranked candidates for each source sentence.

    The *oracle* is the candidate with the highest BLEURT; *best* and *worst*
    are the candidates with the highest and lowest reranker score.

    Args:
        samples: ``ScoredSentences`` with ``bleurt`` and ``reranker_score`` computed.
        max_examples: print at most this many sentences; ``None`` prints all.
        beams: beam-search ``ScoredSentences`` shown for comparison (its first
            candidate per sentence); ``None`` uses the global ``beamsearch``.
    """
    if beams is None:
        beams = beamsearch

    selections = zip(np.argmax(samples.bleurt, axis=-1),
                     np.argmax(samples.reranker_score, axis=-1),
                     np.argmin(samples.reranker_score, axis=-1))
    for i, (oracle_idx, best_idx, worst_idx) in enumerate(selections):
        if max_examples is not None and i >= max_examples:
            break
        # Candidate j of source i sits at i * samples_per_sentence + j; use each
        # object's own stride (sample pools and beam outputs may differ).
        base = i * samples.samples_per_sentence
        print("Original: ", samples.originals[i])
        print("Reference:", samples.references[i])
        print("Beam srch:", beams.sentences_flat[i * beams.samples_per_sentence])
        print()
        print("Oracle:   ", samples.sentences_flat[base + oracle_idx])
        print("Best:     ", samples.sentences_flat[base + best_idx])
        print("Worst:    ", samples.sentences_flat[base + worst_idx])
        print()
        beam_score = beams.bleurt[i,0]
        oracle_score = samples.bleurt[i,oracle_idx]
        best_score = samples.bleurt[i,best_idx]
        print(f"beam: {beam_score:.3f}, best: {best_score:.3f} <= {oracle_score:.3f} (predicted {samples.reranker_score[i][best_idx]:.3f} >= {samples.reranker_score[i][oracle_idx]:.3f})")
        print()
        print(30*"-")
        print()

In [None]:
# Temperature-1 samples.
print_samples(samples=generated)

In [None]:
# Temperature-0.7 samples.
print_samples(samples=generated_cold)

## Main evaluation

In [None]:
def print_metrics(samples, display_mbr_consensus=None, beams=None):
    """Compare reranker selections with the oracle, the MBR consensus and beam search.

    For every source sentence, three candidates are picked from ``samples``:
    the *oracle* (highest BLEURT), *our selection* (highest reranker score) and
    the *MBR consensus* (highest MBR expected utility). Their BLEURT, COMET and
    BLEU scores are compared with the first beam-search output per sentence.

    Args:
        samples: ``ScoredSentences`` with scores computed.
        display_mbr_consensus: whether to print the MBR-consensus lines.
            ``None`` (default) reads ``config["compute_mbr_consensus"]`` at call
            time, rather than freezing it when the function is defined.
        beams: beam-search ``ScoredSentences`` to compare against; ``None``
            uses the global ``beamsearch``.
    """
    if display_mbr_consensus is None:
        display_mbr_consensus = config["compute_mbr_consensus"]
    if beams is None:
        beams = beamsearch

    num_sentences = len(samples.originals)
    if num_sentences == 0:
        print("No sentences to evaluate.")
        return

    def scaled(value):
        # Percentage of sentences for counts, 100 x mean for summed scores.
        return f"{100 * value / num_sentences:.02f}"

    # The three selections are shared by every metric below.
    oracle_idx = np.argmax(samples.bleurt, axis=-1)
    best_idx = np.argmax(samples.reranker_score, axis=-1)
    mbr_idx = np.argmax(samples.mbr_expected_utility, axis=-1)
    rows = np.arange(num_sentences)

    print(f"Our selection is the MBR oracle   on {scaled(np.sum(oracle_idx == best_idx))}% of cases")
    if display_mbr_consensus:
        print(f"Our selection is the MBR cons.    on {scaled(np.sum(mbr_idx == best_idx))}% of cases")
        print(f"The MBR cons. is the MBR oracle   on {scaled(np.sum(mbr_idx == oracle_idx))}% of cases")
    print()

    scores = [("bleurt", samples.bleurt, beams.bleurt), ("comet", samples.comet, beams.comet), ("bleu", samples.bleu, beams.bleu)]

    for (score_name, samples_score, beamsearch_score) in scores:
        # Per-sentence score of each selection (row i, chosen column).
        beam_scores = beamsearch_score[:num_sentences, 0]
        oracle_scores = samples_score[rows, oracle_idx]
        best_scores = samples_score[rows, best_idx]
        mbr_scores = samples_score[rows, mbr_idx]

        print(f"===[{score_name}]===")
        print()
        print(f"The MBR oracle is better than beam on {scaled(np.sum(oracle_scores >= beam_scores))}% of cases")
        print(f"Our selection  is better than beam on {scaled(np.sum(best_scores >= beam_scores))}% of cases")
        if display_mbr_consensus:
            print(f"The MBR cons.  is better than beam on {scaled(np.sum(mbr_scores >= beam_scores))}% of cases")
            print(f"Our selection  is better than MBR  on {scaled(np.sum(best_scores >= mbr_scores))}% of cases")
        print(f"Beam       average score = {scaled(beam_scores.sum())}")
        print(f"MBR oracle average score = {scaled(oracle_scores.sum())}")
        print(f"Our selec. average score = {scaled(best_scores.sum())}")
        if display_mbr_consensus:
            print(f"MBR cons.  average score = {scaled(mbr_scores.sum())}")
        print()
        print()
        print()

In [None]:
# Main comparison for both sampling temperatures.
for banner, scored in (("\n\n\n temp = 1\n\n\n", generated), ("\n\n\n temp = 0.7\n\n\n", generated_cold)):
    print(banner)
    print_metrics(scored)

## Evaluation with different hypothesis sample sizes

We evaluate how our translations compare for different numbers of samples per source sentence. When `compute_mbr_consensus` is enabled, the MBR consensus is also recomputed over each reduced sample pool.

In [None]:
def subsample(samples, num_samples):
    """Return a ``ScoredSentences`` restricted to the first ``num_samples`` candidates per source.

    Per-candidate scores (BLEURT, COMET, BLEU, reranker) are sliced from
    ``samples`` rather than recomputed. The MBR expected utility is sliced too,
    but those values were averaged over the full candidate pool; call
    ``compute_mbr_expected_utility`` on the result to get utilities for the
    smaller pool.

    Args:
        samples: scored ``ScoredSentences`` to restrict.
        num_samples: candidates to keep per source; capped at
            ``samples.samples_per_sentence``.

    Raises:
        ValueError: if ``num_samples`` is negative.
    """
    if num_samples < 0:
        raise ValueError(f"num_samples must be non-negative, got {num_samples}")
    per_source = samples.samples_per_sentence
    num_sentences = len(samples.originals)
    num_samples = min(num_samples, per_source)
    # Build the flattened candidate list in a single linear pass.
    sentences_flat = [
        samples.sentences_flat[i * per_source + j]
        for i in range(num_sentences)
        for j in range(num_samples)
    ]
    result = ScoredSentences(sentences_flat, samples.originals, samples.references)
    result.bleurt = samples.bleurt[:, :num_samples]
    result.bleu = samples.bleu[:, :num_samples]
    result.comet = samples.comet[:, :num_samples]
    result.reranker_score = samples.reranker_score[:, :num_samples]
    result.mbr_expected_utility = samples.mbr_expected_utility[:, :num_samples]
    return result

In [None]:
# For each pool size S, restrict both sample sets to their first S candidates per
# sentence. `subsample` already carries over the per-candidate scores (BLEURT,
# COMET, BLEU, reranker), so re-running compute_scores would only repeat that
# expensive work. Only the MBR expected utility, which depends on which
# candidates form the pool, is recomputed.
for S in [1, 5, 10]:
    print(f"\n\n\n\nS = {S}\n\n\n\n")

    for banner, pool in (("\n\n\n temp = 1\n\n\n", generated), ("\n\n\n temp = 0.7\n\n\n", generated_cold)):
        print(banner)
        g = subsample(pool, S)
        if config["compute_mbr_consensus"]:
            g.compute_mbr_expected_utility()
        print_metrics(g)

## End

If running in Google Colab, release (unassign) the runtime:

In [None]:
if IN_COLAB:
    import time
    from google.colab import runtime

    # Wait 15 s before releasing the VM.
    # NOTE(review): presumably lets pending output / Drive syncing finish — confirm.
    time.sleep(15)
    runtime.unassign()