In [None]:
!pip install -qqq -U wandb --progress-bar off
import wandb
from huggingface_hub import login
from google.colab import userdata

# Authenticate with the Hugging Face Hub and Weights & Biases. Tokens are read
# from Colab's secret store (userdata) rather than hard-coded in the notebook.
login(userdata.get('HF_TOKEN'))

# NOTE(review): nothing visible in this notebook logs to wandb after this
# login — confirm whether it is still needed for a pure evaluation run.
wb_token = userdata.get('wandb')
wandb.login(key=wb_token)

In [None]:
!pip install -q -U git+https://github.com/huggingface/transformers.git --progress-bar off
#!pip install -q -U git+https://github.com/huggingface/accelerate.git --progress-bar off
!pip install datasets evaluate --progress-bar off

In [None]:
from transformers import AutoTokenizer

# Tokenizer of the base checkpoint. It has to match the tokenizer the evaluated
# models were fine-tuned with (TODO: confirm for every entry in models_to_evaluate).
# Other checkpoints tried in earlier runs: "google-t5/t5-base", "google/flan-t5-small",
# "google-t5/t5-small", "google/t5-v1_1-small".
base_model_id = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=base_model_id)

# generate_input delimits the prompt parts with tokenizer.sep_token, so "<s>" is
# registered as that special token. The cell displays the call's return value.
tokenizer.add_special_tokens(dict(sep_token="<s>"))

In [None]:
# Encoder token budget per feature. Prompts that exceed it are split into
# several overflow windows that share `stride` tokens of overlap.
max_length, stride = 512, 128

In [None]:
def generate_input(_question, _context):
    """Build the text prompt fed to the model for one question/context pair.

    The stripped question and context are labelled with "question:" / "context:"
    prefixes, separated by ``tokenizer.sep_token``, and the prompt ends with
    "answer:". All pieces are joined with single spaces.
    """
    pieces = (
        "question:",
        _question.strip(),
        tokenizer.sep_token,
        "context:",
        _context.strip(),
        tokenizer.sep_token,
        "answer:",
    )
    return " ".join(pieces)

def preprocess_mrqa_batch(examples):
    """Turn a batch of SQuAD-style rows into prompt strings and target answers.

    Args:
        examples: batched mapping with "question", "context" and "answers"
            columns. Each "answers" entry is a dict whose "text" key lists the
            gold answer strings.

    Returns:
        (inputs, targets): one prompt per row (built by ``generate_input``) and
        the first gold answer text per row, or "" when the row has none.
    """
    inputs = [
        generate_input(question, context)
        for question, context in zip(examples["question"], examples["context"])
    ]
    # Check the answer *texts*, not the answers dict: the dict always carries its
    # keys, so len(answer) > 0 held even when "text" was empty and [0] then failed.
    targets = [
        answer["text"][0] if answer.get("text") else ""
        for answer in examples["answers"]
    ]
    return inputs, targets


# validation preprocessing
def preprocess_validation(examples):
    """Tokenize a batch of evaluation examples into model features.

    The prompt tokenization returns overflowing windows, so one example can
    produce several feature rows. Every feature records the id of the example it
    came from ("example_id") and a copy of that example's label ids, in which
    padding ids are replaced by -100 so they are ignored by the loss.
    """
    prompts, answer_texts = preprocess_mrqa_batch(examples)

    # Prompts longer than max_length are split into windows overlapping by
    # `stride` tokens; each window becomes its own feature row.
    model_inputs = tokenizer(
        prompts,
        max_length=max_length,
        stride=stride,
        padding="max_length",
        truncation=True,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
    )
    target_encoding = tokenizer(
        text_target=answer_texts,
        max_length=max_length,
        stride=stride,
        padding="max_length",
        truncation=True,
    )

    pad_id = tokenizer.pad_token_id
    masked_label_ids = [
        [-100 if token_id == pad_id else token_id for token_id in sequence]
        for sequence in target_encoding["input_ids"]
    ]

    # feature position -> position of its source example within this batch
    feature_sources = model_inputs.pop("overflow_to_sample_mapping")
    model_inputs["example_id"] = [examples["id"][source] for source in feature_sources]
    # every window of an example is trained/scored against the same label ids
    model_inputs["labels"] = [masked_label_ids[source] for source in feature_sources]
    return model_inputs

In [None]:
import numpy as np

# source: https://github.com/huggingface/transformers/tree/main/examples/pytorch/question-answering
def _column(table, key):
    """Return column `key` of a Dataset / dict of lists, or of a sequence of row dicts."""
    try:
        return table[key]
    except TypeError:  # a plain list/tuple of row dicts cannot be indexed by a string
        return [row[key] for row in table]


def postprocess_qa_predictions(examples, features, predictions):
    """Decode generated token ids into one answer string per example.

    Args:
        examples: the raw eval examples; only the "id" column is read.
        features: the tokenized features; only the "example_id" column is read.
            Row i of `features` must correspond to row i of `predictions`.
        predictions: generated token id sequences, one per feature, or a tuple
            whose first element holds them.

    Returns:
        dict mapping example id -> decoded prediction string, in example order.

    Raises:
        KeyError: if an example has no feature.
    """
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    # -100 is used as padding and cannot be decoded; swap it for the pad id.
    pad_id = tokenizer.pad_token_id
    predictions = [np.where(p != -100, p, pad_id) for p in predictions]
    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)

    # Read the id column directly: iterating Dataset rows would materialize every
    # column (input_ids, offset_mapping, labels, ...) of every feature.
    # When an example spans several overflow windows, the last window's
    # prediction is kept (same choice as the upstream HF seq2seq QA script).
    last_feature_of = {
        example_id: feature_index
        for feature_index, example_id in enumerate(_column(features, "example_id"))
    }
    return {
        example_id: decoded_preds[last_feature_of[example_id]]
        for example_id in _column(examples, "id")
    }

In [None]:
from datasets import load_dataset
import evaluate

# Split of the MRQA sample to score ("validation" is the other split used).
split = "test"
mrqa_eval = load_dataset("enriquesaou/mrqa-squadded-sample", split=split)

# Tokenize in batches. A long context can yield several feature rows per
# example, so the raw columns are dropped in favour of the tokenized ones.
eval_set = mrqa_eval.map(
    preprocess_validation,
    batched=True,
    remove_columns=mrqa_eval.column_names,
)

# `example_id` and `offset_mapping` are bookkeeping for postprocessing only;
# the model-facing copy drops them and yields torch tensors.
eval_set_for_model = eval_set.remove_columns(["example_id", "offset_mapping"])
eval_set_for_model.set_format("torch")

In [None]:
# Display a summary of the model-ready dataset as a sanity check.
eval_set_for_model

In [None]:
# source: https://github.com/mrqa/MRQA-Shared-Task-2019/blob/master/mrqa_official_eval.py

import string
import re
import json
import gzip
from collections import Counter

def normalize_answer(s):
    """Canonicalize an answer string before comparing it with another.

    Steps, in order: lower-case, drop ASCII punctuation, replace the articles
    a/an/the with spaces, and collapse whitespace runs to single spaces.
    """
    punctuation = set(string.punctuation)
    text = s.lower()
    text = "".join(ch for ch in text if ch not in punctuation)
    text = re.sub(r'\b(a|an|the)\b', ' ', text)
    return " ".join(text.split())


def f1_score(prediction, ground_truth):
    """Token-overlap F1 between the normalized prediction and ground truth.

    Returns 0 when the two share no tokens, otherwise the harmonic mean of
    token precision and recall.
    """
    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if not overlap:
        return 0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def exact_match_score(prediction, ground_truth):
    """True iff prediction and ground truth are identical after normalize_answer."""
    normalized_prediction = normalize_answer(prediction)
    normalized_truth = normalize_answer(ground_truth)
    return normalized_prediction == normalized_truth


def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    """Best `metric_fn` score of `prediction` against any acceptable ground truth.

    Raises ValueError (from max) if `ground_truths` is empty.
    """
    return max(metric_fn(prediction, truth) for truth in ground_truths)


def read_predictions(prediction_file):
    """Load and return the JSON object stored in `prediction_file`.

    Presumably a mapping of question id -> predicted answer (verify against producer).
    """
    with open(prediction_file) as handle:
        return json.load(handle)


def read_answers(gold_file):
    """Collect question id -> answers from a gzipped JSON-lines gold file.

    Each line is a JSON object with a "qas" list of {"id", "answers"} entries;
    the first line is skipped when it is a header record (has a "header" key).
    """
    answers = {}
    with gzip.open(gold_file, 'rb') as stream:
        for line_no, raw_line in enumerate(stream):
            record = json.loads(raw_line)
            if line_no == 0 and 'header' in record:
                continue
            answers.update((qa['id'], qa['answers']) for qa in record['qas'])
    return answers


def evaluate_predictions(answers, predictions, skip_no_answer=False):
    """Average exact-match and F1 (both in percent) of predictions vs. gold answers.

    Args:
        answers: mapping question id -> list of acceptable gold answer strings.
        predictions: mapping question id -> predicted answer string.
        skip_no_answer: if True, questions missing from `predictions` are left
            out of the average; otherwise each one scores 0 and a message is printed.

    Returns:
        {'exact_match': float, 'f1': float}. Both are 0.0 when no question is
        counted (empty `answers`, or every question skipped).
    """
    f1 = exact_match = total = 0

    for qid, ground_truths in answers.items():
        if qid not in predictions:
            if not skip_no_answer:
                message = 'Unanswered question %s will receive score 0.' % qid
                print(message)
                total += 1
            continue
        total += 1
        prediction = predictions[qid]
        exact_match += metric_max_over_ground_truths(
            exact_match_score, prediction, ground_truths)
        f1 += metric_max_over_ground_truths(
            f1_score, prediction, ground_truths)

    if total == 0:
        # Nothing was scored: avoid ZeroDivisionError and report zeros.
        return {'exact_match': 0.0, 'f1': 0.0}

    exact_match = 100.0 * exact_match / total
    f1 = 100.0 * f1 / total

    return {'exact_match': exact_match, 'f1': f1}

In [None]:
# Hub checkpoints scored on `split`; append more ids to compare several models.
models_to_evaluate = [
    "enriquesaou/flan-t5-base-mrqa-16",
]

In [None]:
import torch
from transformers import T5ForConditionalGeneration
from tqdm.auto import tqdm

# Generation settings for the evaluation run.
EVAL_BATCH_SIZE = 32
MAX_NEW_TOKENS = 16
# Set to True to print every prediction next to its gold answers and its own scores.
SHOW_PER_EXAMPLE = False

# Use the GPU when present for speed, but still run on a CPU-only runtime.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Gold answers do not depend on the model, so build the id -> answer texts map once.
answers = {
    example_id: example_answers["text"]
    for example_id, example_answers in zip(mrqa_eval["id"], mrqa_eval["answers"])
}

for model_id in models_to_evaluate:
    # NOTE(review): predictions are decoded with the base_model_id tokenizer;
    # confirm every evaluated checkpoint was trained with that same tokenizer.
    model_for_eval = T5ForConditionalGeneration.from_pretrained(model_id).to(device)

    dataloader = torch.utils.data.DataLoader(eval_set_for_model, batch_size=EVAL_BATCH_SIZE)
    outputs = []
    for batch in tqdm(dataloader, desc=model_id):
        generated = model_for_eval.generate(
            input_ids=batch["input_ids"].to(device),
            attention_mask=batch["attention_mask"].to(device),
            max_new_tokens=MAX_NEW_TOKENS,
        )
        # Move each batch to the CPU right away instead of accumulating device memory.
        outputs.extend(generated.cpu())

    # Map each feature's generated ids back to one answer string per example.
    all_predictions = postprocess_qa_predictions(
        examples=mrqa_eval,
        features=eval_set,
        predictions=outputs)

    metrics = evaluate_predictions(answers, predictions=all_predictions)
    print(model_id, split, json.dumps(metrics))

    if SHOW_PER_EXAMPLE:
        for qid, gold in answers.items():
            single = {qid: all_predictions[qid]} if qid in all_predictions else {}
            print(single.get(qid), gold)
            print(json.dumps(evaluate_predictions({qid: gold}, single)))
