# Imports

In [None]:
from collections import defaultdict
import json


def open_file(path, mode, encoding="utf-8"):
  """Open `path` with `mode`, decoding text modes as UTF-8 by default.

  The evaluated languages include non-Latin scripts (Arabic, Farsi, Hebrew,
  Armenian, ...). Relying on the platform's locale-dependent default encoding
  can therefore fail to decode the input files, so text modes get an explicit
  encoding. Binary modes ("b" in `mode`) accept no encoding and are opened
  unchanged.

  Args:
    path: Path of the file to open.
    mode: Mode string forwarded to the builtin `open`.
    encoding: Text encoding used for non-binary modes.

  Returns:
    The opened file object; use it as a context manager.
  """
  if "b" in mode:
    return open(path, mode)
  return open(path, mode, encoding=encoding)

# Load Input File

In [None]:
# A path and filename with "{}" in place of the language variety name (from
# LANGUAGES below).
input_format_str = '{}.jsonl' # @param {type: 'string'}

LANGUAGES = [
    'tajik',
    'farsi',
    'arabic_iraq',
    'arabic_jordan',
    'azerbaijani',
    'armenian',
    'arabic_egypt',
    'turkish',
    'hebrew',
    'arabic_algeria',
]

# Maps each language name to the list of examples parsed from its JSONL file
# (one JSON value per non-blank line).
examples_by_language = defaultdict(list)
for language in LANGUAGES:
  input_filename = input_format_str.format(language)
  with open_file(input_filename, 'r') as f:
    for line in f:
      # Skip blank / whitespace-only lines (e.g. an extra newline at the end of
      # the file); json.loads would reject them with a JSONDecodeError.
      if not line.strip():
        continue
      examples_by_language[language].append(json.loads(line))
  print(f'Loaded {len(examples_by_language[language])} inputs for {language}.')

# Define Metrics

In [None]:
def validate_predicted_answer_and_indices(
    predicted_answer: str,
    predicted_answer_byte_start_index: int | None,
    predicted_answer_byte_end_index: int | None,
    article_text_bytes: bytes,
) -> None:
  provided_indices = [
      predicted_answer_byte_start_index is not None,
      predicted_answer_byte_end_index is not None,
  ]
  assert all(provided_indices) or not any(
      provided_indices
  ), "Cannot provide only one predicted answer start/end index!"
  if all(provided_indices):
    assert (
        predicted_answer
        == article_text_bytes[
            predicted_answer_byte_start_index:predicted_answer_byte_end_index
        ]
    ), "predicted_answer must be calculated from provided start/end indices!"


def exact_match_example(
    predicted_answer: str,
    gold_answer: str,
    answer_type: str,
    article_text: str,
    predicted_answer_byte_start_index: int | None,
    predicted_answer_byte_end_index: int | None,
) -> bool:
  """Per-example Exact Match (EM) score.

  Credit is awarded if the predicted answer is exactly the same as the gold
  answer, with some considerations for non-minimal-span answers.
  """

  if answer_type == "no_answer":
    return predicted_answer.lower().strip() == "no answer"
  elif answer_type == "yes_no":
    return predicted_answer.lower().strip() == gold_answer.lower().strip()
  elif answer_type == "minimal_span":
    validate_predicted_answer_and_indices(
        predicted_answer,
        predicted_answer_byte_start_index,
        predicted_answer_byte_end_index,
        article_text.encode("utf-8"),
    )
    return predicted_answer.strip() == gold_answer.strip()
  else:
    raise NotImplementedError(f"Unknown answer type: {answer_type}")


def exact_match(
    inputs: list[dict],
    article_text_key: str,
    gold_answer_key: str,
    predicted_answer_key: str,
    predicted_answer_byte_start_index_key: str | None,
    predicted_answer_byte_end_index_key: str | None,
    answer_types_key: str,
) -> float:
  """Exact Match (EM) metric on eval set. See exact_match_example."""
  correct_count = 0
  all_count = 0

  for item in inputs:
    predicted_answer = item[predicted_answer_key].strip()
    answer_types = item[answer_types_key]
    gold_answers = item[gold_answer_key]
    article_text = item[article_text_key]
    predicted_answer_byte_start_index = (
        item[predicted_answer_byte_start_index_key]
        if predicted_answer_byte_start_index_key is not None
        else None
    )
    predicted_answer_byte_end_index = (
        item[predicted_answer_byte_end_index_key]
        if predicted_answer_byte_end_index_key is not None
        else None
    )
    if any([
        exact_match_example(
            predicted_answer,
            gold_answer,
            answer_type,
            article_text,
            predicted_answer_byte_start_index,
            predicted_answer_byte_end_index,
        )
        for gold_answer, answer_type in zip(gold_answers, answer_types)
    ]):
      correct_count += 1
    all_count += 1

  return correct_count / all_count


def f1_example(
    gold_answer: str,
    predicted_answer: str,
    predicted_answer_byte_start_index: int | None,
    predicted_answer_byte_end_index: int | None,
    article_text: str,
    answer_type: str,
    gold_answer_start_byte_index: int,
    gold_answer_end_byte_index: int,
) -> float:
  """Per-example F1 score."""
  article_text_bytes = article_text.encode("utf-8")
  predicted_answer_bytes = predicted_answer.strip().encode("utf-8")

  if answer_type == "no_answer":
    return 1.0 if predicted_answer.lower().strip() == "no answer" else 0.0
  elif answer_type == "yes_no":
    return (
        1.0 if predicted_answer.lower().strip() == gold_answer.lower() else 0.0
    )
  elif answer_type == "minimal_span":
    validate_predicted_answer_and_indices(
        predicted_answer,
        predicted_answer_byte_start_index,
        predicted_answer_byte_end_index,
        article_text_bytes,
    )
    predicted_start = (
        article_text_bytes.find(predicted_answer_bytes)
        if predicted_answer_byte_start_index is None
        else predicted_answer_byte_start_index
    )
    if predicted_start == -1:
      return 0.0
    else:
      predicted_end = (
          predicted_start + len(predicted_answer_bytes)
          if predicted_answer_byte_end_index is None
          else predicted_answer_byte_end_index
      )
      predicted_indices = set(range(predicted_start, predicted_end))
      gold_indices = set(
          range(gold_answer_start_byte_index, gold_answer_end_byte_index)
      )
      tp2 = 2 * len(
          predicted_indices.intersection(gold_indices)
      )  # True positives * 2
      fp = len(predicted_indices - gold_indices)  # False positives
      fn = len(gold_indices - predicted_indices)  # False negatives
      return tp2 / (tp2 + fp + fn)  # F1 = (2*TP) / ((2*TP) + FP + FN)
  else:
    raise NotImplementedError(f"Unknown answer type: {answer_type}")


# F1 metric on eval set.
def f1(
    inputs: list[dict],
    gold_answer_key: str,
    gold_answer_start_byte_indices_key: str,
    gold_answer_end_byte_indices_key: str,
    predicted_answer_key: str,
    predicted_answer_byte_start_index_key: str | None,
    predicted_answer_byte_end_index_key: str | None,
    article_text_key: str,
    answer_types_key: str,
) -> float:
  scores = []
  for item in inputs:
    gold_answers = item[gold_answer_key]
    article_text = item[article_text_key]
    predicted_answer = item[predicted_answer_key]
    predicted_answer_byte_start_index = (
        item[predicted_answer_byte_start_index_key]
        if predicted_answer_byte_start_index_key is not None
        else None
    )
    predicted_answer_byte_end_index = (
        item[predicted_answer_byte_end_index_key]
        if predicted_answer_byte_end_index_key is not None
        else None
    )
    answer_types = item[answer_types_key]
    gold_start_byte_indices = item[gold_answer_start_byte_indices_key]
    gold_end_byte_indices = item[gold_answer_end_byte_indices_key]
    best_f1 = 0.0
    for i, gold_answer in enumerate(gold_answers):
      this_f1 = f1_example(
          gold_answer,
          predicted_answer,
          predicted_answer_byte_start_index,
          predicted_answer_byte_end_index,
          article_text,
          answer_types[i],
          gold_start_byte_indices[i],
          gold_end_byte_indices[i],
      )
      if this_f1 > best_f1:
        best_f1 = this_f1
    scores.append(best_f1)

  return sum(scores) / len(scores)

# Evaluate the Inputs

In [None]:
# If your model predictions already contain byte start/end indices, set this to
# False. Otherwise, the first occurrence of the predicted string in the article
# will be used to infer the predicted indices.
INFER_PREDICTED_INDICES = True # @param {type: "boolean"}

gold_answer_key = "answer_texts"  # @param ["answer_texts"] {allow-input: true}
gold_answer_start_byte_indices_key = "answer_start_byte_indices" # @param ["answer_start_byte_indices"] {allow-input: true}
gold_answer_end_byte_indices_key = "answer_end_byte_indices" # @param ["answer_end_byte_indices"] {allow-input: true}
predicted_answer_key = "generated_answer" # @param ["generated_answer"] {allow-input: true}
predicted_answer_byte_start_index_key = "generated_answer_byte_start_index" # @param ["generated_answer_byte_start_index"] {allow-input: true}
predicted_answer_byte_end_index_key = "generated_answer_byte_end_index" # @param ["generated_answer_byte_end_index"] {allow-input: true}
article_text_key = "article_plaintext_trimmed" # @param ["article_plaintext_trimmed", "article_plaintext"] {allow-input: true}
answer_types_key = "answer_types"  # @param ["answer_types"] {allow-input: true}

# Resolve the predicted-index keys once, outside the loop. Passing None tells
# the metric functions not to read predicted indices from the examples.
resolved_start_index_key = (
    None if INFER_PREDICTED_INDICES else predicted_answer_byte_start_index_key
)
resolved_end_index_key = (
    None if INFER_PREDICTED_INDICES else predicted_answer_byte_end_index_key
)

# Maps each language to {"exact_match": float, "f1": float}.
metrics_by_language = defaultdict(dict)
for language in LANGUAGES:
  print(f"{language}:")
  examples = examples_by_language[language]
  language_metrics = metrics_by_language[language]

  # Exact Match
  language_metrics["exact_match"] = exact_match(
      examples,
      article_text_key,
      gold_answer_key,
      predicted_answer_key,
      resolved_start_index_key,
      resolved_end_index_key,
      answer_types_key,
  )

  # F1
  language_metrics["f1"] = f1(
      examples,
      gold_answer_key,
      gold_answer_start_byte_indices_key,
      gold_answer_end_byte_indices_key,
      predicted_answer_key,
      resolved_start_index_key,
      resolved_end_index_key,
      article_text_key,
      answer_types_key,
  )

  print(language_metrics)
  print()