In [1]:
import json
import logging

from typing import Dict, List
from overrides import overrides

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.instance import Instance
from allennlp.data.fields import Field, TextField, ListField, MetadataField, IndexField,ArrayField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Tokenizer, WordTokenizer

# Notebook-level logger. The level is lowered to DEBUG so that the INFO progress
# messages emitted by the reader below pass this logger's own level check.
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
logger.setLevel(logging.DEBUG)
# NOTE(review): this looks like a leftover smoke test, but it may be load-bearing.
# Per the stdlib docs, the module-level logging.debug() calls logging.basicConfig()
# when the root logger has no handlers, which installs a stream handler. That is
# presumably why the "INFO:__main__:..." lines show up in the cell output further
# down - confirm before removing, or replace with an explicit logging.basicConfig().
logging.debug("test")

from allennlp.nn import util, InitializerApplicator, RegularizerApplicator
from allennlp.modules.matrix_attention import LinearMatrixAttention
import torch
import random
import numpy as np


class QangarooReader(DatasetReader):
    """
    Reads a JSON-formatted Qangaroo file and yields ``Instances`` with the following fields:
    ``candidates``, a ``ListField[TextField]``, ``query``, a ``TextField``, ``supports``, a
    ``ListField[TextField]``, ``answer``, a ``TextField``, and ``answer_index``, an ``IndexField``
    pointing into ``candidates``.
    We also add a ``MetadataField`` (``metadata``) that stores the instance's ID and annotations.

    Two optional fields are derived by lower-cased token matching against the supports:

    - ``supports_labels`` (only when ``use_label`` is true): an ``ArrayField`` with one 0/1 entry
      per support, 1 when the tokenized answer occurs in that support.
    - ``mentions`` (only when ``use_mention`` is true): a ``MetadataField`` holding, for every
      candidate, a list of ``[support_index, start, end]`` token spans (``end`` exclusive) where
      the tokenized candidate occurs.

    Parameters
    ----------
    tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``)
        We use this ``Tokenizer`` for the candidates, the query, the supports and the answer.
        See :class:`Tokenizer`.
    token_indexers : ``Dict[str, TokenIndexer]``, optional
        Used for every ``TextField`` created by this reader.  See :class:`TokenIndexer`.
        Default is ``{"tokens": SingleIdTokenIndexer('token', True)}``.
    lazy : ``bool``, optional (default=``False``)
        Passed through to :class:`DatasetReader`.
    use_label : ``bool``, optional (default=``True``)
        Add the ``supports_labels`` field, and make ``_read`` skip every instance whose answer
        is not found in any support.
    use_mention : ``bool``, optional (default=``True``)
        Add the ``mentions`` field.
    """
    def __init__(self,
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 lazy: bool = False,
                 use_label: bool = True,
                 use_mention: bool = True) -> None:

        super().__init__(lazy)
        self._tokenizer = tokenizer or WordTokenizer()
        # NOTE(review): the positional arguments are presumably namespace='token' and
        # lowercase_tokens=True. The namespace is 'token' (singular), which is not the
        # dictionary key 'tokens' used here - confirm that this is intended.
        self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer('token', True)}
        self.use_label = use_label
        self.use_mention = use_mention

    @overrides
    def _read(self, file_path: str):
        """
        Yield one ``Instance`` per sample of the JSON list stored at ``file_path``.

        Every sample must provide ``candidates``, ``query``, ``supports``, ``id`` and ``answer``;
        ``annotations`` is optional and defaults to ``[[]]``.  When ``use_label`` is set,
        instances whose ``supports_labels`` are all zero are skipped.
        """
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)

        logger.info("Reading file at %s", file_path)
        with open(file_path) as dataset_file:
            dataset = json.load(dataset_file)

        logger.info('dataset length: %d', len(dataset))
        logger.info("Reading the dataset")
        for sample in dataset:
            instance = self.text_to_instance(sample['candidates'], sample['query'], sample['supports'],
                                             sample['id'], sample['answer'],
                                             sample.get('annotations', [[]]))
            # No support contains the answer: drop the instance.
            if self.use_label and max(instance.fields['supports_labels'].array) == 0:
                continue
            yield instance

    @staticmethod
    def _lowercased_tokens(text_field) -> List[str]:
        """Return the lower-cased text of every token held by ``text_field``."""
        return [token.text.lower() for token in text_field.tokens]

    @staticmethod
    def _phrase_starts(tokens: List[str], phrase: List[str]) -> List[int]:
        """
        Return, in ascending order, every index ``i`` such that ``tokens[i:i + len(phrase)]``
        spells ``phrase``.  As in the rest of this reader, two token windows are considered
        equal when their space-joined texts are equal.  An empty ``phrase`` matches nothing.
        """
        width = len(phrase)
        if width == 0:
            return []
        target = ' '.join(phrase)
        # `+ 1` so that the last full-width window (a phrase ending the sequence) is tested too.
        return [start for start in range(len(tokens) - width + 1)
                if ' '.join(tokens[start:start + width]) == target]

    @overrides
    def text_to_instance(self, # type: ignore
                         candidates: List[str],
                         query: str,
                         supports: List[str],
                         _id: str = None,
                         answer: str = None,
                         annotations: List[List[str]] = None) -> Instance:
        """
        Build a single ``Instance``.

        ``answer`` is tokenized and looked up in ``candidates`` unconditionally, so despite its
        ``None`` default it has to be a string that is an element of ``candidates``
        (``candidates.index`` raises ``ValueError`` otherwise).  ``_id`` and ``annotations`` are
        only stored in the ``metadata`` field.
        """
        # pylint: disable=arguments-differ
        fields: Dict[str, Field] = {}

        candidates_field = ListField([TextField(candidate, self._token_indexers)
                                      for candidate in self._tokenizer.batch_tokenize(candidates)])

        # Underscores in the query are replaced by spaces before tokenizing.
        fields['query'] = TextField(self._tokenizer.tokenize(query.replace('_', ' ')), self._token_indexers)

        fields['supports'] = ListField([TextField(support, self._token_indexers)
                                        for support in self._tokenizer.batch_tokenize(supports)])

        fields['answer'] = TextField(self._tokenizer.tokenize(answer), self._token_indexers)

        fields['answer_index'] = IndexField(candidates.index(answer), candidates_field)

        fields['candidates'] = candidates_field

        fields['metadata'] = MetadataField({'annotations': annotations, 'id': _id})

        # Lower-case every support exactly once; the label pass and the mention pass
        # (the latter once per candidate) all scan these same token lists.
        lowered_supports = ([self._lowercased_tokens(support_field)
                             for support_field in fields['supports']]
                            if self.use_label or self.use_mention else [])

        if self.use_label:
            answer_tokens = self._lowercased_tokens(fields['answer'])
            # 1 for a support that contains the answer phrase, 0 otherwise.
            supports_labels = [1 if self._phrase_starts(support_tokens, answer_tokens) else 0
                               for support_tokens in lowered_supports]
            fields['supports_labels'] = ArrayField(np.array(supports_labels))

        if self.use_mention:
            all_mentions = []
            for candidate_field in fields['candidates']:
                candidate_tokens = self._lowercased_tokens(candidate_field)
                width = len(candidate_tokens)
                # [support index, first token, one past the last token] for every occurrence.
                all_mentions.append([[support_idx, start, start + width]
                                     for support_idx, support_tokens in enumerate(lowered_supports)
                                     for start in self._phrase_starts(support_tokens, candidate_tokens)])
            # Only attached when requested; `all_mentions` does not exist otherwise.
            fields['mentions'] = MetadataField(all_mentions)

        return Instance(fields)

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [None]:
# Reader with all defaults: default tokenizer/indexers, use_label=True, use_mention=True.
# With use_label on, QangarooReader._read skips instances whose answer is not found in
# any support, so a loaded dataset can be shorter than the logged "dataset length".
reader = QangarooReader()
train_path = './data/qangaroo_v1.1/wikihop/train.json'
val_path = './data/qangaroo_v1.1/wikihop/dev.json'

# Read the toy file first (presumably a quick smoke test - confirm), then the dev and
# train splits. These names are notebook globals used by later cells.
toy_data = reader.read('./toy_data.json')
validation_dataset = reader.read(val_path)
train_dataset = reader.read(train_path)

0it [00:00, ?it/s]INFO:__main__:Reading file at ./toy_data.json
INFO:__main__:dataset length: 10
INFO:__main__:Reading the dataset
10it [00:00, 10.23it/s]
0it [00:00, ?it/s]INFO:__main__:Reading file at ./data/qangaroo_v1.1/wikihop/dev.json
INFO:__main__:dataset length: 5129
INFO:__main__:Reading the dataset
5099it [06:55, 12.26it/s]
0it [00:00, ?it/s]INFO:__main__:Reading file at ./data/qangaroo_v1.1/wikihop/train.json
INFO:__main__:dataset length: 43738
INFO:__main__:Reading the dataset
9567it [11:04, 10.60it/s]

In [None]:
from tqdm import trange
def test_mentions(dataset):
    for idx in  trange(len(dataset)):
        instance = dataset[idx]
        mentions = instance.fields['mentions'].metadata
        candidates = instance.fields['candidates']
        answer_index = instance.fields['answer_index']
        for idx, candidate_mentions in enumerate(mentions):
            candidate = candidates[idx].tokens    
            candidate = [token.text.lower() for token in candidate]    
            candidate = ' '.join(candidate)
            for mention in candidate_mentions:
                support_idx, s_idx, e_idx = mention
                support = instance.fields['supports'][support_idx].tokens
                mention_item = support[s_idx:e_idx]
                mention_item = [token.text.lower() for token in mention_item]    
                mention_item = ' '.join(mention_item)
                assert mention_item == candidate
# Run the mention consistency check over the dev split loaded in the previous cell.
test_mentions(validation_dataset)