<a href="https://colab.research.google.com/github/halnegheimish/ForcedInvalidation/blob/main/notebooks/naqanet_FI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This code is based on https://github.com/allenai/allennlp-reading-comprehension/blob/master/allennlp_rc/models/naqanet.py, with changes to support Forced Invalidation.

In [None]:
!pip install allennlp==2.1.0 allennlp-models==2.1.0

In [None]:
!pip install spacy-transformers==1.2.1

In [None]:
!gdown --fuzzy --folder https://drive.google.com/drive/folders/1m-Q7M-yvwHSI11peuMkwssphQ6J4q6-N?usp=share_link

In [None]:
!apt install -y subversion

In [None]:
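# Note: GitHub ended Subversion support in January 2024, so this checkout may fail.
# In that case, clone the repository with git and copy data/eval/drop/ to ./drop.
!svn checkout https://github.com/halnegheimish/ForcedInvalidation/trunk/data/eval/drop/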


In [None]:
import os
seed = 42


In [None]:
serialization_dir = f"naqanet_invalid_seed_{seed}"
os.makedirs(serialization_dir, exist_ok=True)

In [None]:
augmented_train_data_path = 'drop/drop_aug_train_ngrams.json'
augmented_dev_data_path = 'drop/drop_aug_val_ngrams.json'

# Code
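
The reader defined below expects the usual DROP JSON layout, plus an extra `invalid` key on each answer annotation. The snippet is a minimal, self-contained sketch of that layout and of the priority order the reader uses to decide the answer type (spans, then number, then date, then invalid). The toy passage, questions, IDs, the `"1"` value for `invalid`, and the helper name `detect_answer_type` are illustrative assumptions and are not taken from the dataset files.

```python
from typing import Any, Dict, Optional

# Toy record in the DROP layout; every value here is made up for illustration.
toy_dataset: Dict[str, Any] = {
    "toy_passage_0": {
        "passage": "The Bears scored 3 points in the first quarter and 7 in the second.",
        "qa_pairs": [
            {
                "query_id": "toy_q_0",
                "question": "How many points did the Bears score in the first half?",
                "answer": {
                    "number": "10",
                    "date": {"day": "", "month": "", "year": ""},
                    "spans": [],
                    "invalid": "",
                },
            },
            {
                "query_id": "toy_q_1",
                # A scrambled question, used as a stand-in for an invalid input (assumption).
                "question": "points Bears many the did How score half first the in?",
                "answer": {
                    "number": "",
                    "date": {"day": "", "month": "", "year": ""},
                    "spans": [],
                    "invalid": "1",  # assumed encoding of the invalid label
                },
            },
        ],
    }
}


def detect_answer_type(annotation: Dict[str, Any]) -> Optional[str]:
    """Hypothetical helper mirroring the reader's priority order."""
    if annotation.get("spans"):
        return "spans"
    if annotation.get("number"):
        return "number"
    if any(annotation.get("date", {}).values()):
        return "date"
    if annotation.get("invalid"):
        return "invalid"
    return None


answer_types = [
    detect_answer_type(qa["answer"])
    for qa in toy_dataset["toy_passage_0"]["qa_pairs"]
]
print(answer_types)
```

Because "invalid" is checked last, an annotation that carries both a regular answer and an invalid flag would be treated as a regular answer in this sketch.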

In [None]:
#naqanet imports
from typing import Any, Dict, List, Optional
import logging

import torch

from allennlp.data import Vocabulary
from allennlp.models.model import Model
from allennlp.modules import Highway
from allennlp.nn.activations import Activation
from allennlp.modules.feedforward import FeedForward
from allennlp.modules import Seq2SeqEncoder, TextFieldEmbedder
from allennlp.modules.matrix_attention.matrix_attention import MatrixAttention
from allennlp.nn import util, InitializerApplicator, RegularizerApplicator
from allennlp.nn.util import masked_softmax

from allennlp_models.rc.models.utils import (
    get_best_span,
    replace_masked_values_with_big_negative_number,
)
from allennlp_models.rc.metrics.drop_em_and_f1 import DropEmAndF1


logger = logging.getLogger(__name__)

#train util imports
import numpy as np
import random


#set random seeds
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
random.seed(seed)

#based on original drop reader https://github.com/allenai/allennlp-models/blob/main/allennlp_models/rc/dataset_readers/drop.py
import itertools
import json
import logging
import string
from collections import defaultdict
from typing import Dict, List, Union, Tuple, Any

from overrides import overrides
from word2number.w2n import word_to_num

from allennlp.common.file_utils import cached_path
from allennlp.data.fields import (
    Field,
    TextField,
    MetadataField,
    LabelField,
    ListField,
    SequenceLabelField,
    SpanField,
    IndexField,
)
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import SingleIdTokenIndexer, TokenIndexer
from allennlp.data.tokenizers import Token, Tokenizer, SpacyTokenizer

from allennlp_models.rc.dataset_readers.utils import (
    IGNORED_TOKENS,
    STRIPPED_CHARACTERS,
    make_reading_comprehension_instance,
    split_tokens_by_hyphen,
)

logger = logging.getLogger(__name__)


WORD_NUMBER_MAP = {
    "zero": 0,
    "one": 1,
    "two": 2,
    "three": 3,
    "four": 4,
    "five": 5,
    "six": 6,
    "seven": 7,
    "eight": 8,
    "nine": 9,
    "ten": 10,
    "eleven": 11,
    "twelve": 12,
    "thirteen": 13,
    "fourteen": 14,
    "fifteen": 15,
    "sixteen": 16,
    "seventeen": 17,
    "eighteen": 18,
    "nineteen": 19,
}

class DropShuffReaderInvalid(DatasetReader):
    """
    Reads a JSON-formatted DROP dataset file and returns instances in a few different possible
    formats.  The input format is complicated; see the test fixture for an example of what it looks
    like.  The output formats all contain a question ``TextField``, a passage ``TextField``, and
    some kind of answer representation.  Because DROP has instances with several different kinds of
    answers, this dataset reader allows you to filter out questions that do not have answers of a
    particular type (e.g., remove questions that have numbers as answers, if your model can only
    give passage spans as answers).  We typically return all possible ways of arriving at a given
    answer string, and expect models to marginalize over these possibilities.  This variant of the
    reader additionally handles an "invalid" answer type, used for Forced Invalidation.
    # Parameters
    tokenizer : `Tokenizer`, optional (default=`SpacyTokenizer()`)
        We use this `Tokenizer` for both the question and the passage.  See :class:`Tokenizer`.
        Default is `SpacyTokenizer()`.
    token_indexers : `Dict[str, TokenIndexer]`, optional
        We similarly use this for both the question and the passage.  See :class:`TokenIndexer`.
        Default is `{"tokens": SingleIdTokenIndexer()}`.
    passage_length_limit : `int`, optional (default=`None`)
        If specified, we will cut the passage if the length of passage exceeds this limit.
    question_length_limit : `int`, optional (default=`None`)
        If specified, we will cut the question if the length of question exceeds this limit.
    skip_when_all_empty: `List[str]`, optional (default=`None`)
        In some cases, such as when preparing training examples, you may want to skip examples
        that have no gold labels. You can specify the conditions under which examples should be
        skipped. Currently, you can put "passage_span", "question_span", "addition_subtraction",
        "counting", or "invalid" in this list, to tell the reader to skip an example when none of
        the listed label types is found. If not specified, we will keep all the examples.
    instance_format: `str`, optional (default=`"drop"`)
        We try to be generous in providing a few different formats for the instances in DROP,
        in terms of the `Fields` that we return for each `Instance`, to allow for several
        different kinds of models.  "drop" format will do processing to detect numbers and
        various ways those numbers can be arrived at from the passage, and return `Fields`
        related to that.  "bert" format only allows passage spans as answers, and provides a
        "question_and_passage" field with the two pieces of text joined as BERT expects.
        "squad" format provides the same fields that our BiDAF and other SQuAD models expect.
    relaxed_span_match_for_finding_labels : `bool`, optional (default=`True`)
        DROP dataset contains multi-span answers, and the date-type answers are usually hard to
        find exact span matches for, also.  In order to use as many examples as possible
        to train the model, we may not want a strict match for such cases when finding the gold
        span labels. If this argument is true, we will treat every span in the multi-span
        answers as correct, and every token in the date answer as correct, too.  Because models
        trained on DROP typically marginalize over all possible answer positions, this is just
        being a little more generous in what is being marginalized.  Note that this will not
        affect evaluation.
    """

    def __init__(
        self,
        tokenizer: Tokenizer = None,
        token_indexers: Dict[str, TokenIndexer] = None,
        passage_length_limit: int = None,
        question_length_limit: int = None,
        skip_when_all_empty: List[str] = None,
        instance_format: str = "drop",
        relaxed_span_match_for_finding_labels: bool = True,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)
        self._tokenizer = tokenizer or SpacyTokenizer()
        self._token_indexers = token_indexers or {"tokens": SingleIdTokenIndexer()}
        self.passage_length_limit = passage_length_limit
        self.question_length_limit = question_length_limit
        self.skip_when_all_empty = skip_when_all_empty if skip_when_all_empty is not None else []
        for item in self.skip_when_all_empty:
            assert item in [
                "passage_span",
                "question_span",
                "addition_subtraction",
                "counting",
                "invalid",
            ], f"Unsupported skip type: {item}"
        self.instance_format = instance_format
        self.relaxed_span_match_for_finding_labels = relaxed_span_match_for_finding_labels

    @overrides
    def _read(self, file_path: str):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path, extract_archive=True)
        logger.info("Reading file at %s", file_path)
        with open(file_path) as dataset_file:
            dataset = json.load(dataset_file)
        logger.info("Reading the dataset")
        kept_count, skip_count = 0, 0
        for passage_id, passage_info in dataset.items():
            passage_text = passage_info["passage"]
            passage_tokens = self._tokenizer.tokenize(passage_text)
            passage_tokens = split_tokens_by_hyphen(passage_tokens)
            for question_answer in passage_info["qa_pairs"]:
                question_id = question_answer["query_id"]
                question_text = question_answer["question"].strip()
                

                answer_annotations = []
                if "answer" in question_answer:
                    answer_annotations.append(question_answer["answer"])
                if "validated_answers" in question_answer:
                    answer_annotations += question_answer["validated_answers"]

                instance = self.text_to_instance(
                    question_text,
                    passage_text,
                    question_id,
                    passage_id,
                    answer_annotations,
                    passage_tokens,
                )
                if instance is not None:
                    kept_count += 1
                    yield instance
                else:
                    skip_count += 1
        print(f"Skipped {skip_count} questions, kept {kept_count} questions.")
        logger.info(f"Skipped {skip_count} questions, kept {kept_count} questions.")

    @overrides
    def text_to_instance(
        self,  # type: ignore
        question_text: str,
        passage_text: str,
        question_id: str = None,
        passage_id: str = None,
        answer_annotations: List[Dict] = None,
        passage_tokens: List[Token] = None,
    ) -> Union[Instance, None]:

        if not passage_tokens:
            passage_tokens = self._tokenizer.tokenize(passage_text)
            passage_tokens = split_tokens_by_hyphen(passage_tokens)
        question_tokens = self._tokenizer.tokenize(question_text)
        question_tokens = split_tokens_by_hyphen(question_tokens)

        if self.passage_length_limit is not None:
            passage_tokens = passage_tokens[: self.passage_length_limit]
        if self.question_length_limit is not None:
            question_tokens = question_tokens[: self.question_length_limit]

        answer_type: str = None
        answer_texts: List[str] = []
        if answer_annotations:

            answer_type, answer_texts = self.extract_answer_info_from_annotation(
                answer_annotations[0]
            )

        # Tokenize the answer text in order to find the matched span based on token
        tokenized_answer_texts = []
        for answer_text in answer_texts:
            answer_tokens = self._tokenizer.tokenize(answer_text)
            answer_tokens = split_tokens_by_hyphen(answer_tokens)
            tokenized_answer_texts.append(" ".join(token.text for token in answer_tokens))

        if self.instance_format == "squad":
            valid_passage_spans = (
                self.find_valid_spans(passage_tokens, tokenized_answer_texts)
                if tokenized_answer_texts
                else []
            )
            if not valid_passage_spans:
                if "passage_span" in self.skip_when_all_empty:
                    return None
                else:
                    valid_passage_spans.append((len(passage_tokens) - 1, len(passage_tokens) - 1))
            return make_reading_comprehension_instance(
                question_tokens,
                passage_tokens,
                self._token_indexers,
                passage_text,
                valid_passage_spans,
                answer_texts,
                additional_metadata={
                    "original_passage": passage_text,
                    "original_question": question_text,
                    "passage_id": passage_id,
                    "question_id": question_id,
                    "valid_passage_spans": valid_passage_spans,
                    "answer_annotations": answer_annotations,
                },
            )
        elif self.instance_format == "bert":
            question_concat_passage_tokens = question_tokens + [Token("[SEP]")] + passage_tokens
            valid_passage_spans = []
            for span in self.find_valid_spans(passage_tokens, tokenized_answer_texts):
                # This span is for `question + [SEP] + passage`.
                valid_passage_spans.append(
                    (span[0] + len(question_tokens) + 1, span[1] + len(question_tokens) + 1)
                )
            if not valid_passage_spans:
                if "passage_span" in self.skip_when_all_empty:
                    return None
                else:
                    valid_passage_spans.append(
                        (
                            len(question_concat_passage_tokens) - 1,
                            len(question_concat_passage_tokens) - 1,
                        )
                    )
            answer_info = {
                "answer_texts": answer_texts,  
                "answer_passage_spans": valid_passage_spans,
            }
            return self.make_bert_drop_instance(
                question_tokens,
                passage_tokens,
                question_concat_passage_tokens,
                self._token_indexers,
                passage_text,
                answer_info,
                additional_metadata={
                    "original_passage": passage_text,
                    "original_question": question_text,
                    "passage_id": passage_id,
                    "question_id": question_id,
                    "answer_annotations": answer_annotations,
                },
            )

        elif self.instance_format == "drop":
            numbers_in_passage = []
            number_indices = []
            for token_index, token in enumerate(passage_tokens):
                number = self.convert_word_to_number(token.text)
                if number is not None:
                    numbers_in_passage.append(number)
                    number_indices.append(token_index)
            # hack to guarantee minimal length of padded number
            numbers_in_passage.append(0)
            number_indices.append(-1)
            numbers_as_tokens = [Token(str(number)) for number in numbers_in_passage]

            valid_passage_spans = (
                self.find_valid_spans(passage_tokens, tokenized_answer_texts)
                if tokenized_answer_texts
                else []
            )
            valid_question_spans = (
                self.find_valid_spans(question_tokens, tokenized_answer_texts)
                if tokenized_answer_texts
                else []
            )


            target_numbers = []
            for answer_text in answer_texts:
                number = self.convert_word_to_number(answer_text)
                if number is not None:
                    target_numbers.append(number)
            valid_signs_for_add_sub_expressions: List[List[int]] = []
            valid_counts: List[int] = []
            invalid_answer: List[int] = []
            if answer_type in ["number", "date"]:
                valid_signs_for_add_sub_expressions = self.find_valid_add_sub_expressions(
                    numbers_in_passage, target_numbers
                )
            if answer_type in ["number"]:
                # Currently we only support count number 0 ~ 9
                numbers_for_count = list(range(10))
                valid_counts = self.find_valid_counts(numbers_for_count, target_numbers)

            # Forced Invalidation: the gold "invalid" label is the integer parsed from the annotation.
            if answer_type == "invalid":
                invalid_answer = target_numbers
            
            type_to_answer_map = {
                "passage_span": valid_passage_spans,
                "question_span": valid_question_spans,
                "addition_subtraction": valid_signs_for_add_sub_expressions,
                "counting": valid_counts,
                "invalid": invalid_answer,
            }

            if self.skip_when_all_empty and not any(
                type_to_answer_map[skip_type] for skip_type in self.skip_when_all_empty
            ):
                return None

            

            answer_info = {
                "answer_texts": answer_texts,
                "answer_passage_spans": valid_passage_spans,
                "answer_question_spans": valid_question_spans,
                "signs_for_add_sub_expressions": valid_signs_for_add_sub_expressions,
                "counts": valid_counts,
                "invalid": invalid_answer,
            }

            return self.make_marginal_drop_instance(
                question_tokens,
                passage_tokens,
                numbers_as_tokens,
                number_indices,
                self._token_indexers,
                passage_text,
                answer_info,
                additional_metadata={
                    "original_passage": passage_text,
                    "original_question": question_text,
                    "original_numbers": numbers_in_passage,
                    "passage_id": passage_id,
                    "question_id": question_id,
                    "answer_info": answer_info,
                    "answer_annotations": answer_annotations,
                },
            )
        else:
            raise ValueError(
                f'Expect the instance format to be "drop", "squad" or "bert", '
                f"but got {self.instance_format}"
            )

   

    @staticmethod
    def make_marginal_drop_instance(
        question_tokens: List[Token],
        passage_tokens: List[Token],
        number_tokens: List[Token],
        number_indices: List[int],
        token_indexers: Dict[str, TokenIndexer],
        passage_text: str,
        answer_info: Dict[str, Any] = None,
        additional_metadata: Dict[str, Any] = None,
    ) -> Instance:
        additional_metadata = additional_metadata or {}
        fields: Dict[str, Field] = {}
        passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens]
        question_offsets = [(token.idx, token.idx + len(token.text)) for token in question_tokens]

        passage_field = TextField(passage_tokens, token_indexers)
        question_field = TextField(question_tokens, token_indexers)

        fields["passage"] = passage_field
        fields["question"] = question_field

        number_index_fields: List[Field] = [
            IndexField(index, passage_field) for index in number_indices
        ]
        fields["number_indices"] = ListField(number_index_fields)
      
        numbers_in_passage_field = TextField(number_tokens, token_indexers)
        metadata = {
            "original_passage": passage_text,
            "passage_token_offsets": passage_offsets,
            "question_token_offsets": question_offsets,
            "question_tokens": [token.text for token in question_tokens],
            "passage_tokens": [token.text for token in passage_tokens],
            "number_tokens": [token.text for token in number_tokens],
            "number_indices": number_indices,
        }
        if answer_info:
            metadata["answer_texts"] = answer_info["answer_texts"]

            passage_span_fields: List[Field] = [
                SpanField(span[0], span[1], passage_field)
                for span in answer_info["answer_passage_spans"]
            ]
            if not passage_span_fields:
                passage_span_fields.append(SpanField(-1, -1, passage_field))
            fields["answer_as_passage_spans"] = ListField(passage_span_fields)

            question_span_fields: List[Field] = [
                SpanField(span[0], span[1], question_field)
                for span in answer_info["answer_question_spans"]
            ]
            if not question_span_fields:
                question_span_fields.append(SpanField(-1, -1, question_field))
            fields["answer_as_question_spans"] = ListField(question_span_fields)

            add_sub_signs_field: List[Field] = []
            for signs_for_one_add_sub_expression in answer_info["signs_for_add_sub_expressions"]:
                add_sub_signs_field.append(
                    SequenceLabelField(signs_for_one_add_sub_expression, numbers_in_passage_field)
                )
            if not add_sub_signs_field:
                add_sub_signs_field.append(
                    SequenceLabelField([0] * len(number_tokens), numbers_in_passage_field)
                )
            fields["answer_as_add_sub_expressions"] = ListField(add_sub_signs_field)

            count_fields: List[Field] = [
                LabelField(count_label, skip_indexing=True) for count_label in answer_info["counts"]
            ]
            if not count_fields:
                count_fields.append(LabelField(-1, skip_indexing=True))
            fields["answer_as_counts"] = ListField(count_fields)

            
            invalid_fields: List[Field] = [
                LabelField(invalid_label, skip_indexing=True) for invalid_label in answer_info["invalid"]
            ]
            if not invalid_fields:
                invalid_fields.append(LabelField(-1, skip_indexing=True))
            fields["answer_invalid"] = ListField(invalid_fields)

        metadata.update(additional_metadata)
        fields["metadata"] = MetadataField(metadata)
        return Instance(fields)

    @staticmethod
    def make_bert_drop_instance(
        question_tokens: List[Token],
        passage_tokens: List[Token],
        question_concat_passage_tokens: List[Token],
        token_indexers: Dict[str, TokenIndexer],
        passage_text: str,
        answer_info: Dict[str, Any] = None,
        additional_metadata: Dict[str, Any] = None,
    ) -> Instance:
        additional_metadata = additional_metadata or {}
        fields: Dict[str, Field] = {}
        passage_offsets = [(token.idx, token.idx + len(token.text)) for token in passage_tokens]

        # This is separate so we can reference it later with a known type.
        passage_field = TextField(passage_tokens, token_indexers)
        question_field = TextField(question_tokens, token_indexers)
        fields["passage"] = passage_field
        fields["question"] = question_field
        question_and_passage_field = TextField(question_concat_passage_tokens, token_indexers)
        fields["question_and_passage"] = question_and_passage_field

        metadata = {
            "original_passage": passage_text,
            "passage_token_offsets": passage_offsets,
            "question_tokens": [token.text for token in question_tokens],
            "passage_tokens": [token.text for token in passage_tokens],
        }

        if answer_info:
            metadata["answer_texts"] = answer_info["answer_texts"]

            passage_span_fields: List[Field] = [
                SpanField(span[0], span[1], question_and_passage_field)
                for span in answer_info["answer_passage_spans"]
            ]
            if not passage_span_fields:
                passage_span_fields.append(SpanField(-1, -1, question_and_passage_field))
            fields["answer_as_passage_spans"] = ListField(passage_span_fields)

        metadata.update(additional_metadata)
        fields["metadata"] = MetadataField(metadata)
        return Instance(fields)

    @staticmethod
    def extract_answer_info_from_annotation(
        answer_annotation: Dict[str, Any]
    ) -> Tuple[str, List[str]]:
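        answer_type = None
        # Use .get() so annotations lacking a key (e.g. original DROP data without "invalid")
        # do not raise a KeyError.
        if answer_annotation.get("spans"):
            answer_type = "spans"
        elif answer_annotation.get("number"):
            answer_type = "number"
        elif any(answer_annotation.get("date", {}).values()):
            answer_type = "date"
        elif answer_annotation.get("invalid"):
            answer_type = "invalid"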

        answer_content = answer_annotation[answer_type] if answer_type is not None else None

        answer_texts: List[str] = []
        if answer_type is None:  # No answer
            pass
        elif answer_type == "spans":
            # answer_content is a list of string in this case
            answer_texts = answer_content
        elif answer_type == "date":
            # answer_content is a dict with "month", "day", "year" as the keys
            date_tokens = [
                answer_content[key]
                for key in ["month", "day", "year"]
                if key in answer_content and answer_content[key]
            ]
            answer_texts = date_tokens
        elif answer_type == "number":
            # answer_content is a string of number
            answer_texts = [answer_content]

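        elif answer_type == "invalid":
            # answer_content must be a string that parses as an integer, because it is later
            # passed through convert_word_to_number to build the "invalid" label
            answer_texts = [answer_content]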

        return answer_type, answer_texts

    @staticmethod
    def convert_word_to_number(word: str, try_to_include_more_numbers=False):
        """
        Currently we only support limited types of conversion.
        """
        if try_to_include_more_numbers:
            # strip all punctuation from the sides of the word, except for the negative sign
            punctuation = string.punctuation.replace("-", "")
            word = word.strip(punctuation)
            # some words may contain a comma as a delimiter
            word = word.replace(",", "")
            # word2num will convert hundred, thousand ... to number, but we skip it.
            if word in ["hundred", "thousand", "million", "billion", "trillion"]:
                return None
            try:
                number = word_to_num(word)
            except ValueError:
                try:
                    number = int(word)
                except ValueError:
                    try:
                        number = float(word)
                    except ValueError:
                        number = None
            return number
        else:
            no_comma_word = word.replace(",", "")
            if no_comma_word in WORD_NUMBER_MAP:
                number = WORD_NUMBER_MAP[no_comma_word]
            else:
                try:
                    number = int(no_comma_word)
                except ValueError:
                    number = None
            return number

    @staticmethod
    def find_valid_spans(
        passage_tokens: List[Token], answer_texts: List[str]
    ) -> List[Tuple[int, int]]:
        normalized_tokens = [
            token.text.lower().strip(STRIPPED_CHARACTERS) for token in passage_tokens
        ]
        word_positions: Dict[str, List[int]] = defaultdict(list)
        for i, token in enumerate(normalized_tokens):
            word_positions[token].append(i)
        spans = []
        for answer_text in answer_texts:
            answer_tokens = answer_text.lower().strip(STRIPPED_CHARACTERS).split()
            num_answer_tokens = len(answer_tokens)
            if answer_tokens[0] not in word_positions:
                continue
            for span_start in word_positions[answer_tokens[0]]:
                span_end = span_start  # span_end is _inclusive_
                answer_index = 1
                while answer_index < num_answer_tokens and span_end + 1 < len(normalized_tokens):
                    token = normalized_tokens[span_end + 1]
                    if answer_tokens[answer_index].strip(STRIPPED_CHARACTERS) == token:
                        answer_index += 1
                        span_end += 1
                    elif token in IGNORED_TOKENS:
                        span_end += 1
                    else:
                        break
                if num_answer_tokens == answer_index:
                    spans.append((span_start, span_end))
        return spans

    @staticmethod
    def find_valid_add_sub_expressions(
        numbers: List[int], targets: List[int], max_number_of_numbers_to_consider: int = 2
    ) -> List[List[int]]:
        valid_signs_for_add_sub_expressions = []
        for number_of_numbers_to_consider in range(2, max_number_of_numbers_to_consider + 1):
            possible_signs = list(itertools.product((-1, 1), repeat=number_of_numbers_to_consider))
            for number_combination in itertools.combinations(
                enumerate(numbers), number_of_numbers_to_consider
            ):
                indices = [it[0] for it in number_combination]
                values = [it[1] for it in number_combination]
                for signs in possible_signs:
                    eval_value = sum(sign * value for sign, value in zip(signs, values))
                    if eval_value in targets:
                        labels_for_numbers = [0] * len(numbers)  
                        for index, sign in zip(indices, signs):
                            labels_for_numbers[index] = (
                                1 if sign == 1 else 2
                            )  # 1 for positive, 2 for negative
                        valid_signs_for_add_sub_expressions.append(labels_for_numbers)
        return valid_signs_for_add_sub_expressions

    @staticmethod
    def find_valid_counts(count_numbers: List[int], targets: List[int]) -> List[int]:
        valid_indices = []
        for index, number in enumerate(count_numbers):
            if number in targets:
                valid_indices.append(index)
        return valid_indices

from allennlp.nn.util import dist_reduce_sum
from allennlp_models.rc.tools.squad import metric_max_over_ground_truths
from allennlp_models.rc.tools.drop import (
    get_metrics as drop_em_and_f1
)

def custom_answer_json_to_strings(answer: Dict[str, Any]) -> Tuple[Tuple[str, ...], str]:
    """
    Takes an answer JSON blob from the DROP data release and converts it into strings used for
    evaluation.
    """
    if "number" in answer and answer["number"]:
        return tuple([str(answer["number"])]), "number"
    elif "spans" in answer and answer["spans"]:
        return tuple(answer["spans"]), "span" if len(answer["spans"]) == 1 else "spans"
    elif "invalid" in answer and answer["invalid"]:
        return tuple([str(answer["invalid"])]), 'invalid'
    elif "date" in answer:
        return (
            tuple(
                [
                    "{0} {1} {2}".format(
                        answer["date"]["day"], answer["date"]["month"], answer["date"]["year"]
                    )
                ]
            ),
            "date",
        )
    else:
        raise ValueError(
            f"Answer type not found, should be one of number, spans, invalid or date at: {json.dumps(answer)}"
        )
        
class CustomDropEmAndF1(DropEmAndF1):


    def __call__(self, prediction: Union[str, List], ground_truths: List):  # type: ignore
        """
        Parameters
        ----------
        prediction: ``Union[str, List]``
            The predicted answer from the model evaluated. This could be a string, or a list of string
            when multiple spans are predicted as answer.
        ground_truths: ``List``
            All the ground truth answer annotations.
        """
        # If you wanted to split this out by answer type, you could look at [1] here and group by
        # that, instead of only keeping [0].
        ground_truth_answer_strings = [
            custom_answer_json_to_strings(annotation)[0] for annotation in ground_truths
        ]
        exact_match, f1_score = metric_max_over_ground_truths(
            drop_em_and_f1, prediction, ground_truth_answer_strings
        )

        # Converting to int here, since we want to count the number of exact matches.
        self._total_em += dist_reduce_sum(int(exact_match))
        self._total_f1 += dist_reduce_sum(f1_score)
        self._count += dist_reduce_sum(1)
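
# Illustrative sketch (not part of the original model code): the model defined
# below marginalizes over answer abilities, i.e.
#   log p(answer) = logsumexp_a [ log p(a) + log p(answer | a) ].
# The helper name and its list-based inputs are assumptions made for this
# example only; the real computation runs on batched tensors in `forward`.
import math


def _sketch_marginal_log_likelihood(ability_log_probs, answer_log_likelihoods):
    """Combine per-ability log-likelihoods into one marginal log-likelihood."""
    joint = [a + l for a, l in zip(ability_log_probs, answer_log_likelihoods)]
    peak = max(joint)
    # Subtract the maximum before exponentiating for numerical stability.
    return peak + math.log(sum(math.exp(j - peak) for j in joint))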

class FINumericallyAugmentedQaNet(Model):
    """
    This class augments the QANet model with some rudimentary numerical reasoning abilities, as
    published in the original DROP paper.
    The main idea here is that instead of just predicting a passage span after the QANet
    modeling layers, we add several different "answer abilities": predicting a span from the
    question, predicting a count, predicting an arithmetic expression, or predicting whether
    the input contains invalid sequences (Forced Invalidation).  Near the end of the QANet
    model, we have a variable that predicts what kind of answer type we need, and each branch
    has separate modeling logic to predict that answer type.  We then marginalize over all
    possible ways of getting to the right answer through each of these answer types.
    """

    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        num_highway_layers: int,
        phrase_layer: Seq2SeqEncoder,
        matrix_attention_layer: MatrixAttention,
        modeling_layer: Seq2SeqEncoder,
        dropout_prob: float = 0.1,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
        answering_abilities: Optional[List[str]] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        if answering_abilities is None:
            self.answering_abilities = [
                "passage_span_extraction",
                "question_span_extraction",
                "addition_subtraction",
                "counting",
                "invalid",
            ]
        else:
            self.answering_abilities = answering_abilities

        text_embed_dim = text_field_embedder.get_output_dim()
        encoding_in_dim = phrase_layer.get_input_dim()
        encoding_out_dim = phrase_layer.get_output_dim()
        modeling_in_dim = modeling_layer.get_input_dim()
        modeling_out_dim = modeling_layer.get_output_dim()

        self._text_field_embedder = text_field_embedder

        self._embedding_proj_layer = torch.nn.Linear(text_embed_dim, encoding_in_dim)
        self._highway_layer = Highway(encoding_in_dim, num_highway_layers)

        self._encoding_proj_layer = torch.nn.Linear(encoding_in_dim, encoding_in_dim, bias=True)
        self._phrase_layer = phrase_layer

        self._matrix_attention = matrix_attention_layer

        self._modeling_proj_layer = torch.nn.Linear(
            encoding_out_dim * 4, modeling_in_dim, bias=True
        )
        self._modeling_layer = modeling_layer

        self._passage_weights_predictor = torch.nn.Linear(modeling_out_dim, 1)
        self._question_weights_predictor = torch.nn.Linear(encoding_out_dim, 1)

        if len(self.answering_abilities) > 1:
            self._answer_ability_predictor = FeedForward(
                modeling_out_dim + encoding_out_dim,
                activations=[Activation.by_name("relu")(), Activation.by_name("linear")()],
                hidden_dims=[modeling_out_dim, len(self.answering_abilities)],
                num_layers=2,
                dropout=dropout_prob,
            )

        if "passage_span_extraction" in self.answering_abilities:
            self._passage_span_extraction_index = self.answering_abilities.index(
                "passage_span_extraction"
            )
            self._passage_span_start_predictor = FeedForward(
                modeling_out_dim * 2,
                activations=[Activation.by_name("relu")(), Activation.by_name("linear")()],
                hidden_dims=[modeling_out_dim, 1],
                num_layers=2,
            )
            self._passage_span_end_predictor = FeedForward(
                modeling_out_dim * 2,
                activations=[Activation.by_name("relu")(), Activation.by_name("linear")()],
                hidden_dims=[modeling_out_dim, 1],
                num_layers=2,
            )

        if "question_span_extraction" in self.answering_abilities:
            self._question_span_extraction_index = self.answering_abilities.index(
                "question_span_extraction"
            )
            self._question_span_start_predictor = FeedForward(
                modeling_out_dim * 2,
                activations=[Activation.by_name("relu")(), Activation.by_name("linear")()],
                hidden_dims=[modeling_out_dim, 1],
                num_layers=2,
            )
            self._question_span_end_predictor = FeedForward(
                modeling_out_dim * 2,
                activations=[Activation.by_name("relu")(), Activation.by_name("linear")()],
                hidden_dims=[modeling_out_dim, 1],
                num_layers=2,
            )

        if "addition_subtraction" in self.answering_abilities:
            self._addition_subtraction_index = self.answering_abilities.index(
                "addition_subtraction"
            )
            self._number_sign_predictor = FeedForward(
                modeling_out_dim * 3,
                activations=[Activation.by_name("relu")(), Activation.by_name("linear")()],
                hidden_dims=[modeling_out_dim, 3],
                num_layers=2,
            )

        if "counting" in self.answering_abilities:
            self._counting_index = self.answering_abilities.index("counting")
            self._count_number_predictor = FeedForward(
                modeling_out_dim,
                activations=[Activation.by_name("relu")(), Activation.by_name("linear")()],
                hidden_dims=[modeling_out_dim, 10],
                num_layers=2,
            )

        if "invalid" in self.answering_abilities:
            self._invalid_index = self.answering_abilities.index("invalid")
            
            self._invalid_predictor = FeedForward(
                modeling_out_dim + encoding_out_dim, 
                activations=[Activation.by_name("relu")(), Activation.by_name("linear")()],
                hidden_dims=[modeling_out_dim, 2],
                num_layers=2,
            )

        self._drop_metrics = CustomDropEmAndF1()
        self._dropout = torch.nn.Dropout(p=dropout_prob)

        initializer(self)


    def forward( 
        self,
        question: Dict[str, torch.LongTensor],
        passage: Dict[str, torch.LongTensor],
        number_indices: torch.LongTensor,
        answer_as_passage_spans: torch.LongTensor = None,
        answer_as_question_spans: torch.LongTensor = None,
        answer_as_add_sub_expressions: torch.LongTensor = None,
        answer_as_counts: torch.LongTensor = None,
        answer_invalid: torch.LongTensor = None, 
        metadata: List[Dict[str, Any]] = None,
    ) -> Dict[str, torch.Tensor]:

        question_mask = util.get_text_field_mask(question)
        passage_mask = util.get_text_field_mask(passage)
        embedded_question = self._dropout(self._text_field_embedder(question))
        embedded_passage = self._dropout(self._text_field_embedder(passage))
        embedded_question = self._highway_layer(self._embedding_proj_layer(embedded_question))
        embedded_passage = self._highway_layer(self._embedding_proj_layer(embedded_passage))

        batch_size = embedded_question.size(0)

        projected_embedded_question = self._encoding_proj_layer(embedded_question)
        projected_embedded_passage = self._encoding_proj_layer(embedded_passage)

        encoded_question = self._dropout(
            self._phrase_layer(projected_embedded_question, question_mask)
        )
        encoded_passage = self._dropout(
            self._phrase_layer(projected_embedded_passage, passage_mask)
        )

        # Shape: (batch_size, passage_length, question_length)
        passage_question_similarity = self._matrix_attention(encoded_passage, encoded_question)
        # Shape: (batch_size, passage_length, question_length)
        passage_question_attention = masked_softmax(
            passage_question_similarity, question_mask, memory_efficient=True
        )
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_question_vectors = util.weighted_sum(encoded_question, passage_question_attention)

        # Shape: (batch_size, question_length, passage_length)
        question_passage_attention = masked_softmax(
            passage_question_similarity.transpose(1, 2), passage_mask, memory_efficient=True
        )

        # Shape: (batch_size, passage_length, passage_length)
        passage_attention_over_attention = torch.bmm(
            passage_question_attention, question_passage_attention
        )
        # Shape: (batch_size, passage_length, encoding_dim)
        passage_passage_vectors = util.weighted_sum(
            encoded_passage, passage_attention_over_attention
        )

        # Shape: (batch_size, passage_length, encoding_dim * 4)
        merged_passage_attention_vectors = self._dropout(
            torch.cat(
                [
                    encoded_passage,
                    passage_question_vectors,
                    encoded_passage * passage_question_vectors,
                    encoded_passage * passage_passage_vectors,
                ],
                dim=-1,
            )
        )

        # The recurrent modeling layers. Since these layers share the same parameters,
        # we don't construct them conditioned on answering abilities.
        modeled_passage_list = [self._modeling_proj_layer(merged_passage_attention_vectors)]
        for _ in range(4):
            modeled_passage = self._dropout(
                self._modeling_layer(modeled_passage_list[-1], passage_mask)
            )
            modeled_passage_list.append(modeled_passage)
        # Pop the first one, which is input
        modeled_passage_list.pop(0)

        # The first modeling layer is used to calculate the vector representation of passage
        passage_weights = self._passage_weights_predictor(modeled_passage_list[0]).squeeze(-1)
        passage_weights = masked_softmax(passage_weights, passage_mask)
        passage_vector = util.weighted_sum(modeled_passage_list[0], passage_weights)
        # The vector representation of question is calculated based on the unmatched encoding,
        # because we may want to infer the answer ability only based on the question words.
        question_weights = self._question_weights_predictor(encoded_question).squeeze(-1)
        question_weights = masked_softmax(question_weights, question_mask)
        question_vector = util.weighted_sum(encoded_question, question_weights)

        if len(self.answering_abilities) > 1:
            # Shape: (batch_size, number_of_abilities)
            answer_ability_logits = self._answer_ability_predictor(
                torch.cat([passage_vector, question_vector], -1)
            )
            answer_ability_log_probs = torch.nn.functional.log_softmax(answer_ability_logits, -1)
            best_answer_ability = torch.argmax(answer_ability_log_probs, 1)

        if "invalid" in self.answering_abilities:
            # Shape: (batch_size, 2)
            invalid_logits = self._invalid_predictor(
                torch.cat([passage_vector, question_vector], -1)
            )
            invalid_log_probs = torch.nn.functional.log_softmax(invalid_logits, -1)

            best_is_invalid_pred = torch.argmax(invalid_log_probs, -1)
            best_invalid_log_prob = torch.gather(
                invalid_log_probs, 1, best_is_invalid_pred.unsqueeze(-1)
            ).squeeze(-1)
            if len(self.answering_abilities) > 1:
                best_invalid_log_prob += answer_ability_log_probs[:, self._invalid_index]


        if "counting" in self.answering_abilities:
            # Shape: (batch_size, 10)
            count_number_logits = self._count_number_predictor(passage_vector)
            count_number_log_probs = torch.nn.functional.log_softmax(count_number_logits, -1)
            # Info about the best count number prediction
            # Shape: (batch_size,)
            best_count_number = torch.argmax(count_number_log_probs, -1)
            best_count_log_prob = torch.gather(
                count_number_log_probs, 1, best_count_number.unsqueeze(-1)
            ).squeeze(-1)
            if len(self.answering_abilities) > 1:
                best_count_log_prob += answer_ability_log_probs[:, self._counting_index]

        if "passage_span_extraction" in self.answering_abilities:
            # Shape: (batch_size, passage_length, modeling_dim * 2)
            passage_for_span_start = torch.cat(
                [modeled_passage_list[0], modeled_passage_list[1]], dim=-1
            )
            # Shape: (batch_size, passage_length)
            passage_span_start_logits = self._passage_span_start_predictor(
                passage_for_span_start
            ).squeeze(-1)
            # Shape: (batch_size, passage_length, modeling_dim * 2)
            passage_for_span_end = torch.cat(
                [modeled_passage_list[0], modeled_passage_list[2]], dim=-1
            )
            # Shape: (batch_size, passage_length)
            passage_span_end_logits = self._passage_span_end_predictor(
                passage_for_span_end
            ).squeeze(-1)
            # Shape: (batch_size, passage_length)
            passage_span_start_log_probs = util.masked_log_softmax(
                passage_span_start_logits, passage_mask
            )
            passage_span_end_log_probs = util.masked_log_softmax(
                passage_span_end_logits, passage_mask
            )

            # Info about the best passage span prediction
            passage_span_start_logits = replace_masked_values_with_big_negative_number(
                passage_span_start_logits, passage_mask
            )
            passage_span_end_logits = replace_masked_values_with_big_negative_number(
                passage_span_end_logits, passage_mask
            )
            # Shape: (batch_size, 2)
            best_passage_span = get_best_span(passage_span_start_logits, passage_span_end_logits)
            # Shape: (batch_size,)
            best_passage_start_log_probs = torch.gather(
                passage_span_start_log_probs, 1, best_passage_span[:, 0].unsqueeze(-1)
            ).squeeze(-1)
            best_passage_end_log_probs = torch.gather(
                passage_span_end_log_probs, 1, best_passage_span[:, 1].unsqueeze(-1)
            ).squeeze(-1)
            # Shape: (batch_size,)
            best_passage_span_log_prob = best_passage_start_log_probs + best_passage_end_log_probs
            if len(self.answering_abilities) > 1:
                best_passage_span_log_prob += answer_ability_log_probs[
                    :, self._passage_span_extraction_index
                ]

        if "question_span_extraction" in self.answering_abilities:
            # Shape: (batch_size, question_length, encoding_dim + modeling_dim)
            encoded_question_for_span_prediction = torch.cat(
                [
                    encoded_question,
                    passage_vector.unsqueeze(1).repeat(1, encoded_question.size(1), 1),
                ],
                -1,
            )
            question_span_start_logits = self._question_span_start_predictor(
                encoded_question_for_span_prediction
            ).squeeze(-1)
            # Shape: (batch_size, question_length)
            question_span_end_logits = self._question_span_end_predictor(
                encoded_question_for_span_prediction
            ).squeeze(-1)
            question_span_start_log_probs = util.masked_log_softmax(
                question_span_start_logits, question_mask
            )
            question_span_end_log_probs = util.masked_log_softmax(
                question_span_end_logits, question_mask
            )

            # Info about the best question span prediction
            question_span_start_logits = replace_masked_values_with_big_negative_number(
                question_span_start_logits, question_mask
            )
            question_span_end_logits = replace_masked_values_with_big_negative_number(
                question_span_end_logits, question_mask
            )
            # Shape: (batch_size, 2)
            best_question_span = get_best_span(question_span_start_logits, question_span_end_logits)
            # Shape: (batch_size,)
            best_question_start_log_probs = torch.gather(
                question_span_start_log_probs, 1, best_question_span[:, 0].unsqueeze(-1)
            ).squeeze(-1)
            best_question_end_log_probs = torch.gather(
                question_span_end_log_probs, 1, best_question_span[:, 1].unsqueeze(-1)
            ).squeeze(-1)
            # Shape: (batch_size,)
            best_question_span_log_prob = (
                best_question_start_log_probs + best_question_end_log_probs
            )
            if len(self.answering_abilities) > 1:
                best_question_span_log_prob += answer_ability_log_probs[
                    :, self._question_span_extraction_index
                ]

        if "addition_subtraction" in self.answering_abilities:
            # Shape: (batch_size, # of numbers in the passage)
            number_indices = number_indices.squeeze(-1)
            number_mask = number_indices != -1
            clamped_number_indices = util.replace_masked_values(number_indices, number_mask, 0)
            encoded_passage_for_numbers = torch.cat(
                [modeled_passage_list[0], modeled_passage_list[3]], dim=-1
            )
            # Shape: (batch_size, # of numbers in the passage, modeling_dim * 2)
            encoded_numbers = torch.gather(
                encoded_passage_for_numbers,
                1,
                clamped_number_indices.unsqueeze(-1).expand(
                    -1, -1, encoded_passage_for_numbers.size(-1)
                ),
            )
            # Shape: (batch_size, # of numbers in the passage, modeling_dim * 3)
            encoded_numbers = torch.cat(
                [
                    encoded_numbers,
                    passage_vector.unsqueeze(1).repeat(1, encoded_numbers.size(1), 1),
                ],
                -1,
            )

            # Shape: (batch_size, # of numbers in the passage, 3)
            number_sign_logits = self._number_sign_predictor(encoded_numbers)
            number_sign_log_probs = torch.nn.functional.log_softmax(number_sign_logits, -1)

            # Shape: (batch_size, # of numbers in passage).
            best_signs_for_numbers = torch.argmax(number_sign_log_probs, -1)
            # For padding numbers, the best sign is masked as 0 (not included).
            best_signs_for_numbers = util.replace_masked_values(
                best_signs_for_numbers, number_mask, 0
            )
            # Shape: (batch_size, # of numbers in passage)
            best_signs_log_probs = torch.gather(
                number_sign_log_probs, 2, best_signs_for_numbers.unsqueeze(-1)
            ).squeeze(-1)
            # the probs of the masked positions should be 1 so that it will not affect the joint probability
            # TODO: this is not quite right, since if there are many numbers in the passage,
            # TODO: the joint probability would be very small.
            best_signs_log_probs = util.replace_masked_values(best_signs_log_probs, number_mask, 0)
            # Shape: (batch_size,)
            best_combination_log_prob = best_signs_log_probs.sum(-1)
            if len(self.answering_abilities) > 1:
                best_combination_log_prob += answer_ability_log_probs[
                    :, self._addition_subtraction_index
                ]

        output_dict = {}

        # If answer is given, compute the loss.
        if (
            answer_as_passage_spans is not None
            or answer_as_question_spans is not None
            or answer_as_add_sub_expressions is not None
            or answer_as_counts is not None
            or answer_invalid is not None
        ):

            log_marginal_likelihood_list = []

            for answering_ability in self.answering_abilities:
                if answering_ability == "passage_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_passage_span_starts = answer_as_passage_spans[:, :, 0]
                    gold_passage_span_ends = answer_as_passage_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_passage_span_mask = gold_passage_span_starts != -1
                    clamped_gold_passage_span_starts = util.replace_masked_values(
                        gold_passage_span_starts, gold_passage_span_mask, 0
                    )
                    clamped_gold_passage_span_ends = util.replace_masked_values(
                        gold_passage_span_ends, gold_passage_span_mask, 0
                    )
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_span_starts = torch.gather(
                        passage_span_start_log_probs, 1, clamped_gold_passage_span_starts
                    )
                    log_likelihood_for_passage_span_ends = torch.gather(
                        passage_span_end_log_probs, 1, clamped_gold_passage_span_ends
                    )
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_passage_spans = (
                        log_likelihood_for_passage_span_starts
                        + log_likelihood_for_passage_span_ends
                    )
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_passage_spans = (
                        replace_masked_values_with_big_negative_number(
                            log_likelihood_for_passage_spans,
                            gold_passage_span_mask,
                        )
                    )
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_passage_span = util.logsumexp(
                        log_likelihood_for_passage_spans
                    )
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_passage_span)

                elif answering_ability == "question_span_extraction":
                    # Shape: (batch_size, # of answer spans)
                    gold_question_span_starts = answer_as_question_spans[:, :, 0]
                    gold_question_span_ends = answer_as_question_spans[:, :, 1]
                    # Some spans are padded with index -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    gold_question_span_mask = gold_question_span_starts != -1
                    clamped_gold_question_span_starts = util.replace_masked_values(
                        gold_question_span_starts, gold_question_span_mask, 0
                    )
                    clamped_gold_question_span_ends = util.replace_masked_values(
                        gold_question_span_ends, gold_question_span_mask, 0
                    )
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_question_span_starts = torch.gather(
                        question_span_start_log_probs, 1, clamped_gold_question_span_starts
                    )
                    log_likelihood_for_question_span_ends = torch.gather(
                        question_span_end_log_probs, 1, clamped_gold_question_span_ends
                    )
                    # Shape: (batch_size, # of answer spans)
                    log_likelihood_for_question_spans = (
                        log_likelihood_for_question_span_starts
                        + log_likelihood_for_question_span_ends
                    )
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_question_spans = (
                        replace_masked_values_with_big_negative_number(
                            log_likelihood_for_question_spans,
                            gold_question_span_mask,
                        )
                    )
                    # Shape: (batch_size, )

                    log_marginal_likelihood_for_question_span = util.logsumexp(
                        log_likelihood_for_question_spans
                    )
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_question_span)

                elif answering_ability == "addition_subtraction":
                    # The padded add-sub combinations use 0 as the signs for all numbers, and we mask them here.
                    # Shape: (batch_size, # of combinations)
                    gold_add_sub_mask = answer_as_add_sub_expressions.sum(-1) > 0
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    gold_add_sub_signs = answer_as_add_sub_expressions.transpose(1, 2)
                    # Shape: (batch_size, # of numbers in the passage, # of combinations)
                    log_likelihood_for_number_signs = torch.gather(
                        number_sign_log_probs, 2, gold_add_sub_signs
                    )
                    # the log likelihood of the masked positions should be 0
                    # so that it will not affect the joint probability
                    log_likelihood_for_number_signs = util.replace_masked_values(
                        log_likelihood_for_number_signs, number_mask.unsqueeze(-1), 0
                    )
                    # Shape: (batch_size, # of combinations)
                    log_likelihood_for_add_subs = log_likelihood_for_number_signs.sum(1)
                    # For those padded combinations, we set their log probabilities to be very small negative value
                    log_likelihood_for_add_subs = replace_masked_values_with_big_negative_number(
                        log_likelihood_for_add_subs, gold_add_sub_mask
                    )
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_add_sub = util.logsumexp(
                        log_likelihood_for_add_subs
                    )
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_add_sub)

                elif answering_ability == "counting":
                    # Count answers are padded with label -1,
                    # so we clamp those paddings to 0 and then mask after `torch.gather()`.
                    # Shape: (batch_size, # of count answers)
                    gold_count_mask = answer_as_counts != -1
                    # Shape: (batch_size, # of count answers)
                    clamped_gold_counts = util.replace_masked_values(
                        answer_as_counts, gold_count_mask, 0
                    )
                    log_likelihood_for_counts = torch.gather(
                        count_number_log_probs, 1, clamped_gold_counts
                    )
                    # For those padded spans, we set their log probabilities to be very small negative value
                    log_likelihood_for_counts = replace_masked_values_with_big_negative_number(
                        log_likelihood_for_counts, gold_count_mask
                    )
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_count = util.logsumexp(log_likelihood_for_counts)
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_count)
                    
                elif answering_ability == "invalid":
                    # Invalid labels are padded with -1, so we clamp the padding to 0
                    # and mask it out again after `torch.gather()`.
                    gold_invalid_mask = answer_invalid != -1
                    # Shape: (batch_size, # of invalid labels)
                    clamped_gold_invalid = util.replace_masked_values(
                        answer_invalid, gold_invalid_mask, 0
                    )
                    log_likelihood_for_invalid = torch.gather(
                        invalid_log_probs, 1, clamped_gold_invalid
                    )
                    # For padded labels, we set the log probability to a very large negative value
                    log_likelihood_for_invalid = replace_masked_values_with_big_negative_number(
                        log_likelihood_for_invalid, gold_invalid_mask
                    )
                    # Shape: (batch_size, )
                    log_marginal_likelihood_for_invalid = util.logsumexp(log_likelihood_for_invalid)
                    log_marginal_likelihood_list.append(log_marginal_likelihood_for_invalid)


                else:
                    raise ValueError(f"Unsupported answering ability: {answering_ability}")

            if len(self.answering_abilities) > 1:
                # Add the ability log-probabilities when there is more than one ability
                all_log_marginal_likelihoods = torch.stack(log_marginal_likelihood_list, dim=-1)
                all_log_marginal_likelihoods = (
                    all_log_marginal_likelihoods + answer_ability_log_probs
                )
                marginal_log_likelihood = util.logsumexp(all_log_marginal_likelihoods)
            else:
                marginal_log_likelihood = log_marginal_likelihood_list[0]

            output_dict["loss"] = -marginal_log_likelihood.mean()

        # Compute the metrics and add the tokenized input to the output.
        if metadata is not None:
            output_dict["question_id"] = []
            output_dict["answer"] = []
            question_tokens = []
            passage_tokens = []
            for i in range(batch_size):
                question_tokens.append(metadata[i]["question_tokens"])
                passage_tokens.append(metadata[i]["passage_tokens"])

                if len(self.answering_abilities) > 1:
                    predicted_ability_str = self.answering_abilities[
                        best_answer_ability[i].detach().cpu().numpy()
                    ]
                else:
                    predicted_ability_str = self.answering_abilities[0]

                answer_json: Dict[str, Any] = {}

                # We did not consider multi-mention answers here
                if predicted_ability_str == "passage_span_extraction":
                    answer_json["answer_type"] = "passage_span"
                    passage_str = metadata[i]["original_passage"]
                    offsets = metadata[i]["passage_token_offsets"]
                    predicted_span = tuple(best_passage_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    predicted_answer = passage_str[start_offset:end_offset]
                    answer_json["value"] = predicted_answer
                    answer_json["spans"] = [(start_offset, end_offset)]
                elif predicted_ability_str == "question_span_extraction":
                    answer_json["answer_type"] = "question_span"
                    question_str = metadata[i]["original_question"]
                    offsets = metadata[i]["question_token_offsets"]
                    predicted_span = tuple(best_question_span[i].detach().cpu().numpy())
                    start_offset = offsets[predicted_span[0]][0]
                    end_offset = offsets[predicted_span[1]][1]
                    predicted_answer = question_str[start_offset:end_offset]
                    answer_json["value"] = predicted_answer
                    answer_json["spans"] = [(start_offset, end_offset)]
                elif (
                    predicted_ability_str == "addition_subtraction"
                ):  # plus_minus combination answer
                    answer_json["answer_type"] = "arithmetic"
                    original_numbers = metadata[i]["original_numbers"]
                    sign_remap = {0: 0, 1: 1, 2: -1}
                    predicted_signs = [
                        sign_remap[it] for it in best_signs_for_numbers[i].detach().cpu().numpy()
                    ]
                    result = sum(
                        [sign * number for sign, number in zip(predicted_signs, original_numbers)]
                    )
                    predicted_answer = str(result)
                    offsets = metadata[i]["passage_token_offsets"]
                    number_indices = metadata[i]["number_indices"]
                    number_positions = [offsets[index] for index in number_indices]
                    answer_json["numbers"] = []
                    for offset, value, sign in zip(
                        number_positions, original_numbers, predicted_signs
                    ):
                        answer_json["numbers"].append(
                            {"span": offset, "value": value, "sign": sign}
                        )
                    if number_indices[-1] == -1:
                        # There is a dummy 0 number at position -1 added in some cases; we are
                        # removing that here.
                        answer_json["numbers"].pop()
                    answer_json["value"] = result
                elif predicted_ability_str == "counting":
                    answer_json["answer_type"] = "count"
                    predicted_count = best_count_number[i].detach().cpu().numpy()
                    predicted_answer = str(predicted_count)
                    answer_json["count"] = predicted_count
                elif predicted_ability_str == "invalid":
                    answer_json["answer_type"] = "invalid"
                    predicted_is_invalid = best_is_invalid_pred[i].detach().cpu().numpy()
                    predicted_answer = str(predicted_is_invalid)
                    answer_json["invalid"] = predicted_is_invalid
                else:
                    raise ValueError(f"Unsupported answer ability: {predicted_ability_str}")

                output_dict["question_id"].append(metadata[i]["question_id"])
                output_dict["answer"].append(answer_json)
                answer_annotations = metadata[i].get("answer_annotations", [])
                if answer_annotations:
                    self._drop_metrics(predicted_answer, answer_annotations)
            # This is used for the demo.
            output_dict["passage_question_attention"] = passage_question_attention
            output_dict["question_tokens"] = question_tokens
            output_dict["passage_tokens"] = passage_tokens
        return output_dict

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        exact_match, f1_score = self._drop_metrics.get_metric(reset)
        return {"em": exact_match, "f1": f1_score}

    default_predictor = "reading_comprehension"
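
The decoding branches above map model indices back to answers. The following is a minimal, self-contained sketch of the span and arithmetic cases on toy inputs. The helper names `decode_span` and `decode_arithmetic`, and the example passage and offsets, are hypothetical and not part of the model code.

```python
from typing import List, Tuple


def decode_span(text: str, offsets: List[Tuple[int, int]], span: Tuple[int, int]) -> str:
    """Map an inclusive (start_token, end_token) span to a substring via character offsets."""
    start_char = offsets[span[0]][0]  # first character of the start token
    end_char = offsets[span[1]][1]    # one past the last character of the end token
    return text[start_char:end_char]


def decode_arithmetic(numbers: List[float], sign_indices: List[int]) -> float:
    """Sum the numbers after remapping class indices: 0 -> ignore, 1 -> add, 2 -> subtract."""
    sign_remap = {0: 0, 1: 1, 2: -1}
    return sum(sign_remap[idx] * number for idx, number in zip(sign_indices, numbers))


# Toy example (assumed data, for illustration only).
passage = "The Bears scored 21 points"
token_offsets = [(0, 3), (4, 9), (10, 16), (17, 19), (20, 26)]

span_answer = decode_span(passage, token_offsets, (3, 4))
arithmetic_answer = decode_arithmetic([21, 7, 3], [1, 2, 0])
print(span_answer, arithmetic_answer)
```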

# main

In [None]:
#from training util
from allennlp.data import Instance, Vocabulary, Batch, DataLoader
from allennlp.data.dataset_readers import DatasetReader

#additional imports: DROP reader and token indexers
from allennlp_models.rc.dataset_readers.drop import DropReader
from allennlp.data.token_indexers import TokenCharactersIndexer, SingleIdTokenIndexer

from typing import Any, Dict, Iterable, Optional, Union, Tuple, Set, List



##data reader
dataset_reader= DropShuffReaderInvalid(token_indexers= {'tokens':SingleIdTokenIndexer(lowercase_tokens=True), 'token_characters': TokenCharactersIndexer(min_padding_length=5)}, \
                           passage_length_limit= 400, question_length_limit=50, \
                           skip_when_all_empty=["passage_span", "question_span", "addition_subtraction", "counting", 'invalid'], \
                           instance_format="drop")

validation_dataset_reader= DropShuffReaderInvalid(token_indexers= {'tokens':SingleIdTokenIndexer(lowercase_tokens=True), 'token_characters': TokenCharactersIndexer(min_padding_length=5)}, \
                           passage_length_limit= 1000, question_length_limit=100, \
                           skip_when_all_empty=[], \
                           instance_format="drop")

import random
from allennlp.data.samplers.bucket_batch_sampler import BucketBatchSampler
from allennlp.data.data_loaders.multiprocess_data_loader import MultiProcessDataLoader

dataloaders={}
#batch size: 16 for training (8 is used in evaluate_dataset below)

logger.info("Reading augmented training data from %s", augmented_train_data_path)
dataloaders['train_o']=MultiProcessDataLoader(reader= dataset_reader, \
                                              data_path=augmented_train_data_path, batch_sampler= BucketBatchSampler(batch_size=16), #8
                                              cuda_device=torch.cuda.current_device())


logger.info("Reading augmented dev data from %s", augmented_dev_data_path)
dataloaders['dev_o']=MultiProcessDataLoader(reader= validation_dataset_reader, \
                                            data_path=augmented_dev_data_path, batch_sampler= BucketBatchSampler(batch_size=16), #8
                                            cuda_device=torch.cuda.current_device())


##vocab
from allennlp.common.checks import ConfigurationError

vocab_dir = os.path.join(serialization_dir, "vocabulary")

#os.listdir returns a list (never None), so test for a non-empty listing
if os.path.isdir(vocab_dir) and os.listdir(vocab_dir):
  raise ConfigurationError(
      "The 'vocabulary' directory in the provided serialization directory is non-empty"
  )

datasets_for_vocab_creation=None

instances: Iterable[Instance] = (
        instance
        for key, data_loader in dataloaders.items()
        if datasets_for_vocab_creation is None or key in datasets_for_vocab_creation
        for instance in data_loader.iter_instances()
    )
vocab = Vocabulary.from_instances(min_count= {"token_characters": 200}, \
                  pretrained_files= {"tokens": "https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.lower.converted.zip"}, \
                  only_include_pretrained_words= True,
                  instances=instances)

logger.info(f"writing the vocabulary to {vocab_dir}.")
vocab.save_to_files(vocab_dir)
logger.info("done creating vocab")

##imports for model construction
from allennlp.nn.regularizers.regularizer_applicator import RegularizerApplicator
from allennlp.nn.regularizers.regularizers import L2Regularizer
from allennlp_models.rc.modules.seq2seq_encoders.qanet_encoder import QaNetEncoder
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
from allennlp.modules.matrix_attention.linear_matrix_attention import LinearMatrixAttention
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding, TokenCharactersEncoder

logger.info("indexing dataloaders with vocab")

dataloaders['train_o'].index_with(vocab)
dataloaders['dev_o'].index_with(vocab)

##model creation
model=FINumericallyAugmentedQaNet(
        vocab=vocab,
        text_field_embedder=  BasicTextFieldEmbedder(
        {"tokens": Embedding(embedding_dim=300, trainable=False, \
          pretrained_file="https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.lower.converted.zip",
           vocab=vocab), 
         "token_characters": TokenCharactersEncoder(embedding=Embedding(embedding_dim=64, 
                                                                         vocab=vocab), 
                        encoder=CnnEncoder(embedding_dim= 64,
                        num_filters= 200,
                        ngram_filter_sizes=[5]))}),
        num_highway_layers=2, 
        phrase_layer= QaNetEncoder(input_dim=128,
            hidden_dim= 128,
            attention_projection_dim=128,
            feedforward_hidden_dim=128,
            num_blocks= 1,
            num_convs_per_block= 4,
            conv_kernel_size= 7,
            num_attention_heads= 8,
            dropout_prob= 0.1,
            layer_dropout_undecayed_prob= 0.1,
            attention_dropout_prob= 0) ,
        matrix_attention_layer= LinearMatrixAttention(tensor_1_dim= 128,
            tensor_2_dim= 128,
            combination= "x,y,x*y"),
        modeling_layer= QaNetEncoder(input_dim=128,
            hidden_dim= 128,
            attention_projection_dim=128,
            feedforward_hidden_dim=128,
            num_blocks= 6,
            num_convs_per_block= 2,
            conv_kernel_size= 5,
            num_attention_heads= 8,
            dropout_prob= 0.1,
            layer_dropout_undecayed_prob= 0.1,
            attention_dropout_prob= 0) ,
        dropout_prob = 0.1,
        regularizer= RegularizerApplicator(regexes=[(".*", L2Regularizer(alpha= 1e-07))]) ,
        answering_abilities= [
            "passage_span_extraction",
            "question_span_extraction",
            "addition_subtraction",
            "counting",
            "invalid"]
    )

from allennlp.training.trainer import Trainer, GradientDescentTrainer
from allennlp.training.optimizers import AdamOptimizer, SgdOptimizer
from allennlp.training.moving_average import ExponentialMovingAverage

adam=AdamOptimizer(model_parameters=model.named_parameters(), lr=5e-4, 
              betas=[0.8, 0.999], eps= 1e-07)

model.cuda()

trainer= GradientDescentTrainer(model=model, optimizer=adam, data_loader= dataloaders['train_o'], 
                                patience=10, validation_metric="+f1", validation_data_loader=dataloaders['dev_o'], 
                                num_epochs=50, serialization_dir=serialization_dir, grad_norm=5,
                                moving_average=ExponentialMovingAverage(model.named_parameters(), decay=0.9999)
                              )

trainer.train()
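
The trainer above keeps an exponential moving average of the model parameters with `decay=0.9999`. As a rough illustration of the general technique, the sketch below applies the basic update rule, shadow = decay * shadow + (1 - decay) * value, to a single scalar. It is not AllenNLP's implementation, and the name `ema_update` is hypothetical.

```python
def ema_update(shadow: float, value: float, decay: float) -> float:
    """One exponential-moving-average step for a single scalar parameter."""
    return decay * shadow + (1.0 - decay) * value


# With a small decay the average tracks the new values quickly.
shadow = 0.0
for step_value in [1.0, 1.0, 1.0]:
    shadow = ema_update(shadow, step_value, decay=0.5)
print(shadow)

# With decay=0.9999 (as in the trainer call) a single update barely moves the average.
slow = ema_update(2.0, 4.0, decay=0.9999)
print(slow)
```

With a decay this close to 1, the averaged weights change slowly, so the average effectively spans roughly the last 1 / (1 - decay) = 10,000 updates.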



# Eval

In [None]:
path= serialization_dir

In [None]:
from allennlp.data import Instance, Vocabulary, Batch, DataLoader
from allennlp.data.dataset_readers import DatasetReader
from allennlp_models.rc.dataset_readers.drop import DropReader
from allennlp.data.token_indexers import TokenCharactersIndexer, SingleIdTokenIndexer

from typing import Any, Dict, Iterable, Optional, Union, Tuple, Set, List
import os
from allennlp.nn.regularizers.regularizer_applicator import RegularizerApplicator
from allennlp.nn.regularizers.regularizers import L2Regularizer
from allennlp_models.rc.modules.seq2seq_encoders.qanet_encoder import QaNetEncoder
from allennlp.modules.seq2vec_encoders.cnn_encoder import CnnEncoder
from allennlp.modules.matrix_attention.linear_matrix_attention import LinearMatrixAttention
from allennlp.modules.text_field_embedders.basic_text_field_embedder import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding, TokenCharactersEncoder
import random
from allennlp.data.samplers.bucket_batch_sampler import BucketBatchSampler
from allennlp.data.data_loaders.multiprocess_data_loader import MultiProcessDataLoader


In [None]:
vocab= Vocabulary.from_files(path+'/vocabulary')
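
`Vocabulary.from_files` expects the directory written during training. A small guard, sketched here with the standard library only, can confirm that the directory exists and is non-empty before loading. The helper name `has_saved_vocabulary` and the file name `tokens.txt` are assumptions used for illustration.

```python
import os
import tempfile


def has_saved_vocabulary(vocab_dir: str) -> bool:
    """Return True if vocab_dir exists and contains at least one entry."""
    # os.listdir returns a (possibly empty) list, never None, so test its truthiness.
    return os.path.isdir(vocab_dir) and bool(os.listdir(vocab_dir))


with tempfile.TemporaryDirectory() as tmp:
    vocab_dir = os.path.join(tmp, "vocabulary")
    os.makedirs(vocab_dir)
    empty_result = has_saved_vocabulary(vocab_dir)

    # Simulate a saved namespace file (hypothetical name).
    with open(os.path.join(vocab_dir, "tokens.txt"), "w") as handle:
        handle.write("the\n")
    filled_result = has_saved_vocabulary(vocab_dir)

    missing_result = has_saved_vocabulary(os.path.join(tmp, "does_not_exist"))

print(empty_result, filled_result, missing_result)
```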

In [None]:
pretrained_dict=torch.load(path+'/best.th')


In [None]:
validation_dataset_reader= DropShuffReaderInvalid(token_indexers= {'tokens':SingleIdTokenIndexer(lowercase_tokens=True), 'token_characters': TokenCharactersIndexer(min_padding_length=5)}, \
                           passage_length_limit= 1000, question_length_limit=100, \
                           skip_when_all_empty=[], \
                           instance_format="drop")

In [None]:
dataset_reader=validation_dataset_reader

In [None]:
model=FINumericallyAugmentedQaNet(
        vocab=vocab,
        text_field_embedder=  BasicTextFieldEmbedder(
        {"tokens": Embedding(embedding_dim=300, trainable=False, \
          pretrained_file="https://allennlp.s3.amazonaws.com/datasets/glove/glove.840B.300d.lower.converted.zip",
           vocab=vocab), #num_embeddings=vocab.get_vocab_size("tokens")),
         "token_characters": TokenCharactersEncoder(embedding=Embedding(embedding_dim=64, 
                                                                         vocab=vocab), 
                        encoder=CnnEncoder(embedding_dim= 64,
                        num_filters= 200,
                        ngram_filter_sizes=[5]))}),
        num_highway_layers=2, 
        phrase_layer= QaNetEncoder(input_dim=128,
            hidden_dim= 128,
            attention_projection_dim=128,
            feedforward_hidden_dim=128,
            num_blocks= 1,
            num_convs_per_block= 4,
            conv_kernel_size= 7,
            num_attention_heads= 8,
            dropout_prob= 0.1,
            layer_dropout_undecayed_prob= 0.1,
            attention_dropout_prob= 0) ,
        matrix_attention_layer= LinearMatrixAttention(tensor_1_dim= 128,
            tensor_2_dim= 128,
            combination= "x,y,x*y"),
        modeling_layer= QaNetEncoder(input_dim=128,
            hidden_dim= 128,
            attention_projection_dim=128,
            feedforward_hidden_dim=128,
            num_blocks= 6,
            num_convs_per_block= 2,
            conv_kernel_size= 5,
            num_attention_heads= 8,
            dropout_prob= 0.1,
            layer_dropout_undecayed_prob= 0.1,
            attention_dropout_prob= 0) ,
        dropout_prob = 0.1,
        regularizer= RegularizerApplicator(regexes=[(".*", L2Regularizer(alpha= 1e-07))]) ,
        answering_abilities= [
            "passage_span_extraction",
            "question_span_extraction",
            "addition_subtraction",
            "counting",
            "invalid"]
    )

In [None]:
model.load_state_dict(pretrained_dict)

<All keys matched successfully>

In [None]:
model.cuda()

SiameseNumericallyAugmentedQaNet(
  (_text_field_embedder): BasicTextFieldEmbedder(
    (token_embedder_tokens): Embedding()
    (token_embedder_token_characters): TokenCharactersEncoder(
      (_embedding): TimeDistributed(
        (_module): Embedding()
      )
      (_encoder): TimeDistributed(
        (_module): CnnEncoder(
          (_activation): ReLU()
          (conv_layer_0): Conv1d(64, 200, kernel_size=(5,), stride=(1,))
        )
      )
    )
  )
  (_embedding_proj_layer): Linear(in_features=500, out_features=128, bias=True)
  (_highway_layer): Highway(
    (_layers): ModuleList(
      (0): Linear(in_features=128, out_features=256, bias=True)
      (1): Linear(in_features=128, out_features=256, bias=True)
    )
  )
  (_encoding_proj_layer): Linear(in_features=128, out_features=128, bias=True)
  (_phrase_layer): QaNetEncoder(
    (_encoder_blocks): ModuleList(
      (0): QaNetEncoderBlock(
        (_conv_norm_layers): ModuleList(
          (0): LayerNorm((128,), eps=1e-05, e

In [None]:
from allennlp.training.util import evaluate

In [None]:
dataset_reader= validation_dataset_reader

In [None]:
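prediction_dir='predictions/'
def evaluate_dataset(model, data_path, model_name=''):
  """Evaluate `model` on `data_path` and write its predictions under prediction_dir/model_name."""
  logger.info("Reading data from %s", data_path)
  #strip the extension (e.g. ".json") rather than assuming it is five characters long
  file_name=os.path.splitext(os.path.basename(data_path))[0]+"_preds.json"
  output_dir=os.path.join(prediction_dir, model_name)
  os.makedirs(output_dir, exist_ok=True)  #make sure the output directory exists
  predictions_path=os.path.join(output_dir, file_name)
  dataloader=MultiProcessDataLoader(reader= dataset_reader, \
                                              data_path=data_path, batch_sampler= BucketBatchSampler(batch_size=8), 
                                              cuda_device=torch.cuda.current_device())
  dataloader.index_with(vocab)
  return evaluate(model, dataloader, cuda_device=torch.cuda.current_device(), \
                  predictions_output_file=predictions_path)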

In [None]:
original_num_data_path='drop/drop_dataset_num.json'
shuffled_1g_dev_data_path='drop/drop_dataset_num_sh_q_1gram.json'
shuffled_2g_dev_data_path='drop/drop_dataset_num_sh_q_2gram.json'
shuffled_3g_dev_data_path='drop/drop_dataset_num_sh_q_3gram.json'
sh_p_3g= 'drop/drop_dataset_num_sh_p_3g.json'
sh_p_2g= 'drop/drop_dataset_num_sh_p_2g.json'
sh_p_1g= 'drop/drop_dataset_num_sh_p_1g.json'
dev= 'drop/drop_dataset_dev.json'
contrast_sets= 'drop/drop_contrast_sets_test.json'
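
The file names suggest that questions (`sh_q`) or passages (`sh_p`) were shuffled at the 1-, 2-, or 3-gram level. The procedure used to build these files is not shown in this notebook, so the following is only a sketch of one plausible n-gram shuffle, assuming whitespace tokenization and non-overlapping consecutive chunks. The function `shuffle_ngrams` and the example question are hypothetical.

```python
import random
from typing import List


def shuffle_ngrams(text: str, n: int, seed: int = 42) -> str:
    """Split text into consecutive n-token chunks and shuffle the chunk order."""
    tokens = text.split()
    chunks: List[List[str]] = [tokens[i:i + n] for i in range(0, len(tokens), n)]
    rng = random.Random(seed)  # local RNG so the result is reproducible
    rng.shuffle(chunks)
    return " ".join(token for chunk in chunks for token in chunk)


question = "How many yards was the longest field goal"
shuffled_1g = shuffle_ngrams(question, n=1)
shuffled_3g = shuffle_ngrams(question, n=3)
print(shuffled_1g)
print(shuffled_3g)
```

Larger values of `n` keep more local word order intact, which matches the intuition that 3-gram shuffles are less destructive than 1-gram shuffles.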

In [None]:
result= evaluate_dataset(model, original_num_data_path,  f'naqanet_ngrams_final_seed_{seed}' )
print('results of numset original')
print(result)

loading instances: 6849it [00:50, 135.80it/s]


Skipped 0 questions, kept 6849 questions.


em: 0.47, f1: 0.48, loss: inf ||: : 857it [01:36,  8.91it/s]


results of numset original
{'em': 0.46795152577018545, 'f1': 0.47773835596437486, 'loss': inf}


In [None]:
result= evaluate_dataset(model, dev,  f'naqanet_ngrams_final_seed_{seed}' )
print('results of  original devset')
print(result)

loading instances: 9536it [01:07, 140.83it/s]


Skipped 0 questions, kept 9536 questions.


em: 0.45, f1: 0.49, loss: inf ||: : 1192it [02:13,  8.94it/s]

results of  original devset
{'em': 0.4540687919463087, 'f1': 0.48862730704698026, 'loss': inf}





In [None]:
result= evaluate_dataset(model, contrast_sets,  f'naqanet_ngrams_final_seed_{seed}' )
print('results of  original contrast sets')
print(result)

loading instances: 947it [00:09, 95.22it/s] 


Skipped 0 questions, kept 947 questions.


em: 0.28, f1: 0.34, loss: inf ||: : 119it [00:13,  8.63it/s]

results of  original contrast sets
{'em': 0.27666314677930304, 'f1': 0.34233368532206965, 'loss': inf}





In [None]:
result= evaluate_dataset(model, shuffled_1g_dev_data_path,  f'naqanet_ngrams_final_seed_{seed}')
print('results of  shuffled q 1g')
print(result)

loading instances: 6849it [00:48, 140.18it/s]


Skipped 0 questions, kept 6849 questions.


em: 0.07, f1: 0.07, loss: inf ||: : 857it [01:34,  9.10it/s]


results of  shuffled q 1g
{'em': 0.06760110965104395, 'f1': 0.06782595999415973, 'loss': inf}


In [None]:
result= evaluate_dataset(model, shuffled_2g_dev_data_path,  f'naqanet_ngrams_final_seed_{seed}' )
print('results of  shuffled q 2g')
print(result)

loading instances: 6849it [00:48, 140.03it/s]


Skipped 0 questions, kept 6849 questions.


em: 0.08, f1: 0.08, loss: inf ||: : 857it [01:34,  9.04it/s]


results of  shuffled q 2g
{'em': 0.07563147904803622, 'f1': 0.07684041465907433, 'loss': inf}


In [None]:
result= evaluate_dataset(model, shuffled_3g_dev_data_path,  f'naqanet_ngrams_final_seed_{seed}')
print('results of  shuffled q 3g')
print(result)

loading instances: 6849it [00:48, 140.44it/s]


Skipped 0 questions, kept 6849 questions.


em: 0.09, f1: 0.09, loss: inf ||: : 857it [01:34,  9.11it/s]


results of  shuffled q 3g
{'em': 0.08745802306906118, 'f1': 0.08944955467951524, 'loss': inf}


In [None]:
result= evaluate_dataset(model, sh_p_1g,  f'naqanet_ngrams_final_seed_{seed}' )
print('results of  shuffled p 1g')
print(result)

loading instances: 6849it [00:48, 140.22it/s]


Skipped 0 questions, kept 6849 questions.


em: 0.01, f1: 0.01, loss: inf ||: : 857it [01:33,  9.12it/s]


results of  shuffled p 1g
{'em': 0.012994597751496568, 'f1': 0.013067601109651043, 'loss': inf}


In [None]:
result= evaluate_dataset(model, sh_p_2g,  f'naqanet_ngrams_final_seed_{seed}' )
print('results of  shuffled p 2g')
print(result)

loading instances: 6849it [00:49, 138.71it/s]


Skipped 0 questions, kept 6849 questions.


em: 0.01, f1: 0.01, loss: inf ||: : 857it [01:34,  9.11it/s]


results of  shuffled p 2g
{'em': 0.012994597751496568, 'f1': 0.013067601109651043, 'loss': inf}


In [None]:
result= evaluate_dataset(model, sh_p_3g,  f'naqanet_ngrams_final_seed_{seed}' )
print('results of  shuffled p 3g')
print(result)

loading instances: 6849it [01:10, 96.48it/s] 


Skipped 0 questions, kept 6849 questions.


em: 0.02, f1: 0.02, loss: inf ||: : 857it [01:33,  9.14it/s]


results of  shuffled p 3g
{'em': 0.017082785808147174, 'f1': 0.01715578916630165, 'loss': inf}
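
The metrics printed above can be collected for a side-by-side comparison. The sketch below copies the reported F1 scores for the `drop_dataset_num` variants and expresses each shuffled score as a fraction of the unshuffled one. The dictionary keys are shorthand introduced here, and the values are taken from the cell outputs above.

```python
# F1 scores copied from the evaluation outputs above.
f1_scores = {
    "original": 0.47773835596437486,
    "sh_q_1gram": 0.06782595999415973,
    "sh_q_2gram": 0.07684041465907433,
    "sh_q_3gram": 0.08944955467951524,
    "sh_p_1g": 0.013067601109651043,
    "sh_p_2g": 0.013067601109651043,
    "sh_p_3g": 0.01715578916630165,
}

baseline = f1_scores["original"]

# Fraction of the unshuffled F1 that each shuffled variant retains.
retained = {
    name: score / baseline
    for name, score in f1_scores.items()
    if name != "original"
}

for name, fraction in retained.items():
    print(f"{name}: F1 = {f1_scores[name]:.4f} ({fraction:.1%} of the unshuffled F1)")
```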
