In [15]:
import json
import numpy as np

In [36]:
# WikiHop (QAngaroo v1.1) training split and the official dev split used for validation.
train_path, val_path = (
    '../data/qangaroo_v1.1/wikihop/' + split + '.json' for split in ('train', 'dev')
)

In [37]:
# Load both splits. `with` closes the handles instead of leaking them. The explicit
# encoding matters because the supports contain non-ASCII text (e.g. '\xa0').
with open(val_path, 'r', encoding='utf-8') as val_file:
    dev_data = json.load(val_file)
with open(train_path, 'r', encoding='utf-8') as train_file:
    train_data = json.load(train_file)
print(len(dev_data), len(train_data))

5129 43738


In [38]:
# Peek at one raw dev sample to see the JSON schema the reader below consumes.
dev_data[1]

{'candidates': ['democratic party',
  'military',
  'progressive party',
  'republican party'],
 'annotations': [['follows', 'multiple'],
  ['follows', 'single'],
  ['follows', 'single']],
 'query': 'member_of_political_party thomas l. woolwine',
 'supports': ['James Sunny Jim Rolph, Jr. (August 23, 1869\xa0 June 2, 1934) was an American politician and a member of the Republican Party. He was elected to a single term as the 27th governor of California from January 6, 1931 until his death on June 2, 1934 at the height of the Great Depression. Previously, Rolph had been the 30th mayor of San Francisco from January 8, 1912 until his resignation to become governor. Rolph remains the longest serving mayor in San Francisco history.',
  'The California National Guard is a federally funded California military force, part of the National Guard of the United States. It comprises both Army and Air National Guard components and is the largest national guard force in the United States with a total 

In [47]:
import json
import logging

from typing import Dict, List
from overrides import overrides

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.instance import Instance
from allennlp.data.fields import Field, TextField, ListField, MetadataField, IndexField,ArrayField
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Tokenizer, WordTokenizer
logger = logging.getLogger(__name__) # pylint: disable=invalid-name
from allennlp.nn import util, InitializerApplicator, RegularizerApplicator
from allennlp.modules.matrix_attention import LinearMatrixAttention
import torch
import random

class QangarooReader(DatasetReader):
    """
    Reads a JSON-formatted Qangaroo file and returns a ``Dataset`` of ``Instances``.

    Every instance has ``candidates`` (``ListField[TextField]``), ``query`` (``TextField``),
    ``supports`` (``ListField[TextField]``) and ``metadata`` (``MetadataField`` holding the
    sample id and its annotations, if present). When an answer is given, it also has
    ``answer`` (``TextField``) and ``answer_index`` (``IndexField`` into ``candidates``).
    If ``use_label`` is also true, it has ``supports_labels``: an ``ArrayField`` of 0/1 flags
    marking which supports contain the answer as a contiguous, case-insensitive token span.

    Parameters
    ----------
    tokenizer : ``Tokenizer``, optional (default=``WordTokenizer()``)
        Used for the query, the supports, the candidates and the answer.
    token_indexers : ``Dict[str, TokenIndexer]``, optional
        Used for every ``TextField``. Default is
        ``{"tokens": SingleIdTokenIndexer(namespace="tokens", lowercase_tokens=True)}``.
    lazy : ``bool``, optional (default=``False``)
        Passed through to ``DatasetReader``.
    use_label : ``bool``, optional (default=``True``)
        Compute ``supports_labels``. In ``_read``, also skip samples where no support
        contains the answer.
    """
    def __init__(self,
                 tokenizer: Tokenizer = None,
                 token_indexers: Dict[str, TokenIndexer] = None,
                 lazy: bool = False,
                 use_label: bool = True) -> None:

        super().__init__(lazy)
        self._tokenizer = tokenizer or WordTokenizer()
        # The namespace must be "tokens". The vocabulary, the pretrained-embedding file and the
        # Embedding in this notebook are all keyed on that namespace; the former 'token'
        # namespace left them empty.
        self._token_indexers = token_indexers or {
            'tokens': SingleIdTokenIndexer(namespace='tokens', lowercase_tokens=True)}
        self.use_label = use_label

    @overrides
    def _read(self, file_path: str):
        # if `file_path` is a URL, redirect to the cache
        file_path = cached_path(file_path)

        logger.info("Reading file at %s", file_path)
        with open(file_path, encoding='utf-8') as dataset_file:
            dataset = json.load(dataset_file)

        logger.info('dataset length: %d', len(dataset))
        logger.info("Reading the dataset")
        for sample in dataset:
            instance = self.text_to_instance(sample['candidates'], sample['query'], sample['supports'],
                                             sample['id'], sample['answer'],
                                             sample.get('annotations', [[]]))
            # Skip samples whose answer never appears in any support document.
            # NOTE(review): the same reader is used for validation, so this filter also drops
            # such samples from dev, which makes dev metrics optimistic.
            labels = instance.fields.get('supports_labels')
            if labels is not None and not labels.array.any():
                continue
            yield instance

    @overrides
    def text_to_instance(self,  # type: ignore
                         candidates: List[str],
                         query: str,
                         supports: List[str],
                         _id: str = None,
                         answer: str = None,
                         annotations: List[List[str]] = None) -> Instance:
        """Tokenize one Qangaroo sample into an ``Instance`` (see the class docstring for fields)."""
        # pylint: disable=arguments-differ
        fields: Dict[str, Field] = {}

        candidates_field = ListField([TextField(candidate, self._token_indexers)
                                      for candidate in self._tokenizer.batch_tokenize(candidates)])

        # Queries are "relation_name subject"; split the snake_case relation into words.
        fields['query'] = TextField(self._tokenizer.tokenize(query.replace('_', ' ')), self._token_indexers)

        supports_field = ListField([TextField(support, self._token_indexers)
                                    for support in self._tokenizer.batch_tokenize(supports)])
        fields['supports'] = supports_field

        answer_field = None
        # Without an answer (e.g. at prediction time) there is nothing to index or label.
        if answer is not None:
            answer_field = TextField(self._tokenizer.tokenize(answer), self._token_indexers)
            fields['answer'] = answer_field
            fields['answer_index'] = IndexField(candidates.index(answer), candidates_field)

        fields['candidates'] = candidates_field

        fields['metadata'] = MetadataField({'annotations': annotations, 'id': _id})

        if self.use_label and answer_field is not None:
            answer_tokens = [token.text.lower() for token in answer_field.tokens]
            supports_labels = [int(self._contains_span(support_field, answer_tokens))
                               for support_field in supports_field.field_list]
            fields['supports_labels'] = ArrayField(np.array(supports_labels))
        return Instance(fields)

    @staticmethod
    def _contains_span(field: TextField, span: List[str]) -> bool:
        """Return True if the lower-cased tokens of ``field`` contain ``span`` contiguously."""
        tokens = [token.text.lower() for token in field.tokens]
        span_len = len(span)
        if span_len == 0:
            return False
        # `+ 1` so that a span ending exactly at the last token is still found.
        return any(tokens[start:start + span_len] == span
                   for start in range(len(tokens) - span_len + 1))

In [48]:
# Reader with every default spelled out: WordTokenizer, lower-cased single-id tokens, eager, labelled.
reader = QangarooReader(tokenizer=None, token_indexers=None, lazy=False, use_label=True)

In [49]:
# NOTE(review): despite its name, this reads the small local toy file, not `val_path`.
validation_dataset = reader.read(file_path='./toy_data.json')

10it [00:00, 14.18it/s]


In [57]:
# Inspect the tokenization of the second support document of one toy instance.
instance = validation_dataset[6]
instance.fields['supports'].field_list[1].tokens

[Edward,
 Theodore,
 ",
 Teddy,
 ",
 Riley,
 (,
 born,
 October,
 8,
 ,,
 1967,
 ),
 is,
 an,
 American,
 singer,
 -,
 songwriter,
 ,,
 musician,
 ,,
 keyboardist,
 ,,
 and,
 record,
 producer,
 credited,
 with,
 the,
 creation,
 of,
 the,
 new,
 jack,
 swing,
 genre,
 .,
 Through,
 his,
 production,
 work,
 with,
 Michael,
 Jackson,
 ,,
 Bobby,
 Brown,
 ,,
 Doug,
 E.,
 Fresh,
 ,,
 Today,
 ,,
 Keith,
 Sweat,
 ,,
 Heavy,
 D.,
 ,,
 Usher,
 ,,
 Jane,
 Child,
 ,,
 etc,
 .,
 and,
 membership,
 of,
 the,
 groups,
 Guy,
 and,
 Blackstreet,
 ,,
 Riley,
 is,
 credited,
 with,
 having,
 a,
 massive,
 impact,
 and,
 seminal,
 influence,
 on,
 the,
 formation,
 of,
 contemporary,
 R&B,
 ,,
 hip,
 -,
 hop,
 ,,
 soul,
 and,
 pop,
 since,
 the,
 1980s,
 .]

In [46]:
# Scratch check: str.replace substitutes every occurrence when no count is given (-1 = all).
a = 'aaa'
a.replace('a', 'b', -1)

'bbb'

In [7]:
# Slow: the recorded run needed ~21 minutes to read the full training split.
train_dataset = reader.read(file_path=train_path)

43398it [21:32, 33.57it/s]


In [12]:
import torch

In [None]:
from allennlp.data.vocabulary import Vocabulary
from allennlp.models import Model
from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder
from allennlp.modules.token_embedders import Embedding
from allennlp.modules.seq2seq_encoders import Seq2SeqEncoder, PytorchSeq2SeqWrapper, StackedBidirectionalLstm
from allennlp.nn.util import get_text_field_mask, sequence_cross_entropy_with_logits
from allennlp.training.metrics import CategoricalAccuracy
from allennlp.data.iterators import BucketIterator, BasicIterator
from allennlp.training.trainer import Trainer
import torch
import torch.nn as nn
from allennlp.modules.attention import BilinearAttention

In [None]:
# NOTE(review): the vocabulary is built only from `validation_dataset` (the toy file), not from
# `train_dataset`; rebuild it from the training data before any real training run.
vocab = Vocabulary.from_instances(
    instances=validation_dataset,
    pretrained_files={'tokens': './glove.840B.300d.lower.converted.zip'},
)

In [None]:
# Size of the word namespace used by the token indexer and the embedding.
vocab.get_vocab_size(namespace='tokens')

In [None]:
from allennlp.common import Params

# Build the embedding through `from_params`. The plain Embedding constructor only *records*
# `pretrained_file` and never reads it, which silently left the GloVe vectors unused.
token_embedding = Embedding.from_params(vocab=vocab, params=Params({
    'num_embeddings': vocab.get_vocab_size('tokens'),
    'embedding_dim': 300,
    'pretrained_file': './glove.840B.300d.lower.converted.zip',
    'vocab_namespace': 'tokens',
}))

In [None]:
# Map the "tokens" indexer output of every TextField through the word embedding.
word_embeddings = BasicTextFieldEmbedder(token_embedders={'tokens': token_embedding})

In [None]:
# Tiny unshuffled batches over the toy data, used to debug the model pieces step by step.
iterator = BasicIterator(batch_size=2)
iterator.index_with(vocab)
rwa_iterator = iterator(instances=validation_dataset, num_epochs=1, shuffle=False)

In [None]:
# Take the first batch for step-by-step debugging. `next` raises StopIteration on an empty
# iterator, instead of silently keeping a stale `batch` left over from an earlier run.
batch = next(iter(rwa_iterator))
batch

In [None]:
# Shape of the padded token-id tensor for the supports ListField.
batch['supports']['tokens'].size()

In [None]:
# Embed every text field of the debug batch with the shared word embedder.
embedded_query = word_embeddings(batch['query'])
embedded_candidates = word_embeddings(batch['candidates'])
embedded_supports = word_embeddings(batch['supports'])

In [None]:
# Token-level masks: query (batch, query_len); supports/candidates (batch, num_items, item_len).
query_mask = util.get_text_field_mask(batch['query'])
supports_mask = util.get_text_field_mask(batch['supports'], num_wrapping_dims=1)
# NOTE(review): without `num_wrapping_dims` this appears to yield a per-document mask of shape
# (batch, num_supports) rather than a token mask. Confirm against the AllenNLP version in use.
supports_mask_para = util.get_text_field_mask(batch['supports'])

candidates_mask_seq = util.get_text_field_mask(batch['candidates'], num_wrapping_dims=1)
candidates_mask_para = util.get_text_field_mask(batch['candidates'])
candidates_mask_seq_expand = candidates_mask_seq.view(-1, candidates_mask_seq.size(-1))

# Number of supports per sample, read from the mask. This keeps the cell from depending on
# `sup_len`, which is only defined in a later cell (NameError on Restart & Run All).
num_supports = supports_mask.size(1)
supports_mask_expand = supports_mask.view(-1, supports_mask.size(-1))
# Repeat the query mask once per support so it lines up with the flattened supports.
query_mask_expand = query_mask.unsqueeze(1).expand(query_mask.size(0), num_supports, query_mask.size(1))
query_mask_expand = query_mask_expand.contiguous().view(-1, query_mask_expand.size(-1))

In [None]:
# Shape of the embedded candidates tensor.
embedded_candidates.size()

In [None]:
batch_size, sup_len, seq_len, emb_dim = embedded_supports.size()
# Fold the support / candidate axis into the batch axis so each document is encoded on its own.
embedded_supports_expand = embedded_supports.view(batch_size * sup_len, seq_len, emb_dim)
candidate_len = embedded_candidates.size(2)
embedded_candidates_expand = embedded_candidates.view(-1, candidate_len, emb_dim)
del candidate_len

In [None]:
# Shared encoder, attention and pooling modules used by the cells below. Statement order is
# kept, since it fixes the order in which parameters are randomly initialised.
# NOTE(review): SelfAttentive is defined in a later cell, so this cell raises NameError on
# Restart & Run All. Move the SelfAttentive definition above this cell.
# StackedBidirectionalLstm positional args are presumably (input_size=300, hidden_size=100,
# num_layers=1, recurrent_dropout=0.2, layer_dropout=0.2, use_highway=True). The 200-d
# bidirectional output is what the 200-d attention/fusion layers below expect. Confirm.
phrase_layer = PytorchSeq2SeqWrapper(StackedBidirectionalLstm(300,100,1,0.2, 0.2,True))
attention = LinearMatrixAttention(200,200,'x,y,x*y')
similarity_function_2 = LinearMatrixAttention(200,200,'x,y,x*y')

# Fuses [encoded supports (200) ; co-attention vectors (2 x 200)] -> 200.
co_attention_fusion = nn.Sequential(
                        nn.Linear(600,200,bias=True),
                        nn.ReLU(inplace=True)
                    )

# Fuses the combine_tensors('1,2,1-2,1*2') output (4 x 200) -> 200.
self_attention_fusion = nn.Sequential(
                            nn.Linear(800,200),
                            nn.ReLU(inplace=True)
                        )

# Attention pooling from a sequence of 200-d vectors to a single 200-d vector.
supports_pooling = SelfAttentive(200)
question_pooling = SelfAttentive(200)
candidates_pooling = SelfAttentive(200)

In [None]:
# Encode query, supports and candidates with the shared BiLSTM. The call order is kept,
# because dropout inside the LSTM draws random numbers.
encoded_query = phrase_layer(embedded_query, query_mask)
encoded_supports = phrase_layer(embedded_supports_expand, supports_mask_expand)
encoded_candidates = phrase_layer(embedded_candidates_expand, candidates_mask_seq_expand)

# Repeat the encoded query once per support: (batch*sup_len, query_len, encoding_dim).
encoded_query_expand = (
    encoded_query
    .unsqueeze(1)
    .expand(batch_size, sup_len, *encoded_query.shape[1:])
    .contiguous()
    .view(-1, *encoded_query.shape[1:])
)

In [None]:
# Co-attention between every support document and the query. Supports are flattened, so each
# of the batch_size*passage_num rows pairs one support with its (repeated) query.

# shape: (batch_size*passage_num, passage_length, question_length )
supports_query_similarity = attention(encoded_supports, encoded_query_expand)

# shape: (batch_size*passage_num, passage_length, question_length )
supports_query_attention = util.masked_softmax(supports_query_similarity, query_mask_expand)
# shape: (batch_size*passage_num, passage_length, encoding_dim)
supports_query_vectors = util.weighted_sum(encoded_query_expand, supports_query_attention) 

# shape: (batch_size*passage_num, query_length, passage_length)
query_passage_attention = util.masked_softmax(supports_query_similarity.transpose(1,2), supports_mask_expand)
# shape: (batch_size*passage_num, query_length, encoding_dim)
query_supports_vectors = util.weighted_sum(encoded_supports, query_passage_attention)

# Second-level attention: each support token re-attends over the query->support summaries.
# shape: (batch_size*passage_num, passage_length, encoding_dim)
supports_query_vectors_2 = torch.bmm(supports_query_attention, query_supports_vectors)
# shape: (batch_size*passage_num, passage_length, encoding_dim*2)
supports_query_vectors_final = torch.cat([supports_query_vectors, supports_query_vectors_2], dim=-1)

# Fusion: a simple fusion function for now (Linear + ReLU over [supports ; co-attention]).
# shape: (batch_size*passage_num, passage_length, 200)
supports_coattention_vectors = co_attention_fusion(torch.cat([encoded_supports,supports_query_vectors_final], dim=-1))

In [None]:
# Self-attention over the co-attended support tokens.
# shape: (batch_size*passage_num, passage_length, passage_length)
supports_self_similarity = similarity_function_2(supports_coattention_vectors, supports_coattention_vectors)
# Softmax over the similarity computed just above. It previously read the undefined name
# `sup_sup_similarity`, which raised NameError.
supports_selfattention = util.masked_softmax(supports_self_similarity, supports_mask_expand)
supports_selfatt_vectors = util.weighted_sum(supports_coattention_vectors, supports_selfattention)
# [x ; y ; x - y ; x * y] (4 x 200) fused back down to 200 dims.
support_selfatt_fusion = self_attention_fusion(
    util.combine_tensors('1,2,1-2,1*2', [supports_coattention_vectors, supports_selfatt_vectors]))


In [None]:
import allennlp
from torch.nn import Parameter

In [None]:
class SelfAttentive(allennlp.modules.Seq2VecEncoder):
    """
    Attention pooling: collapses a ``(batch_size, seq_len, dim)`` sequence into ``(batch_size, dim)``.

    A learned ``(dim, 1)`` weight scores each position. The scores are soft-maxed under the
    mask, and the output is the attention-weighted sum of the sequence rows.
    """

    def __init__(self,
                 dim: int,
                ) -> None:
        super().__init__()
        # Kept so that the Seq2VecEncoder dimension accessors can report it.
        self._dim = dim
        self.weight = Parameter(torch.Tensor(dim, 1))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.weight)

    def get_input_dim(self) -> int:
        """Size of the last dimension of the input sequence."""
        return self._dim

    def get_output_dim(self) -> int:
        """Size of the pooled output vector (equal to the input dimension)."""
        return self._dim

    def forward(self, matrix: torch.Tensor, matrix_mask: torch.Tensor) -> torch.Tensor:
        """
        matrix : (batch_size, seq_len, dim)
        matrix_mask : (batch_size, seq_len) mask of real positions, as used by ``masked_softmax``
        returns : (batch_size, dim)
        """
        # (batch_size, seq_len, dim) -> (batch_size, seq_len)
        similarity = torch.matmul(matrix, self.weight).squeeze(-1)
        attention = util.masked_softmax(similarity, matrix_mask)
        return util.weighted_sum(matrix, attention)
        

In [None]:
# Pool each support back to one vector and restore the (batch_size, sup_len, dim) layout.
supports_pooling_vectors = supports_pooling(support_selfatt_fusion, supports_mask_expand).view(
    batch_size, sup_len, -1)

question_pooling_vectors = question_pooling(encoded_query, query_mask)

# Pool each candidate, then regroup the flattened candidates per sample.
candidates_pooling_vectors = candidates_pooling(encoded_candidates, candidates_mask_seq_expand)
candidates_pooling_vectors = candidates_pooling_vectors.view(
    batch_size, -1, candidates_pooling_vectors.size(-1))

In [None]:
# Shapes of the pooled query, supports and candidates fed to the decoder.
print(*(pooled.shape for pooled in (question_pooling_vectors,
                                    supports_pooling_vectors,
                                    candidates_pooling_vectors)))

### SAN decoder (Stochastic Answer Network-style multi-step reasoning)

In [None]:
from allennlp.common.registrable import Registrable

class Decoder(nn.Module, Registrable):
    """
    Abstract, registrable answer decoder.

    Concrete decoders map pooled support, query and candidate vectors (plus an optional
    support mask) to predictions. This base class only fixes the ``forward`` signature.
    """

    def forward(self,
                supports_vectors: torch.FloatTensor,
                query_vectors: torch.FloatTensor,
                candidates_vectors: torch.FloatTensor,
                supports_mask: torch.LongTensor = None):
        # Subclasses (registered via ``Decoder.register``) must provide the implementation.
        raise NotImplementedError()

In [None]:
@Decoder.register("san_decoder")
class SANDecoder(Decoder):
    """
    Multi-step answer decoder in the style of a Stochastic Answer Network (SAN).

    A GRU state, initialised with the pooled query, attends over the pooled supports at every
    step. The attended summary scores the candidates and then updates the GRU state. Per-step
    predictions are combined according to ``reason_type``.
    """

    def __init__(self,
                 support_dim: int,
                 query_dim: int,
                 candidates_dim: int,
                 num_step: int = 1,
                 reason_type: int = 0,
                 reason_dropout_p: float = 0.2,
                 dropout_p: float = 0.4
                 ) -> None:
        """
        Parameters
        ----------
        support_dim : dimension of each pooled support vector (GRU input size).
        query_dim : dimension of the pooled query vector (GRU hidden size).
        candidates_dim : dimension of each pooled candidate vector.
        num_step : number of reasoning steps; must be > 0.
        reason_type : how per-step predictions are combined
                     0: random - mean over steps under a stochastic per-sample step mask
                        (see ``generate_mask``; no steps are dropped in eval mode)
                     1: only last - prediction of the final step
                     2: avg - plain mean over all steps
        reason_dropout_p : probability of dropping a step for ``reason_type == 0`` in training.
        dropout_p : stored on the module but not used by this class.
        """
        super().__init__()

        assert num_step > 0
        assert 0 <= reason_type < 3

        self.num_step = num_step
        self.reason_type = reason_type
        self.dropout_p = dropout_p
        self.reason_dropout_p = reason_dropout_p

        self.supports_predictor = BilinearAttention(query_dim, support_dim, normalize=True)
        self.candidates_predictor = BilinearAttention(support_dim, candidates_dim, normalize=False)

        self.rnn = nn.GRUCell(support_dim, query_dim)
        # Only used as a device/dtype reference when building the step mask.
        self.alpha = Parameter(torch.zeros(1, 1))

    @overrides
    def forward(self,
                supports_vectors: torch.FloatTensor,
                query_vectors: torch.FloatTensor,
                candidates_vectors: torch.FloatTensor,
                supports_mask: torch.LongTensor = None):
        """
        Parameters
        ----------
        supports_vectors: (batch_size, supports_length, supports_dim)
        query_vectors: (batch_size, query_dim)
        candidates_vectors: (batch_size, candidates_length, candidates_dim)
        supports_mask: (batch_size, supports_length), optional

        Returns
        -------
        supports_probability: (batch_size, supports_length) | normalized
        candidates_score: (batch_size, candidates_length) | unnormalized
        """
        h0 = query_vectors
        # Attend over the supports passed in. This previously read the notebook-global
        # `supports_pooling_vectors` and silently ignored the argument.
        memory = supports_vectors
        memory_mask = supports_mask

        supports_probabilities_list = []
        candidates_scores_list = []

        for _ in range(self.num_step):
            # (batch_size, supports_length), normalised over the unmasked supports
            supports_prob = self.supports_predictor(h0, memory, memory_mask)
            # (batch_size, supports_dim)
            x_i = util.weighted_sum(memory, supports_prob)
            # (batch_size, candidates_length), unnormalised
            candidates_score = self.candidates_predictor(x_i, candidates_vectors)

            h0 = self.rnn(x_i, h0)

            supports_probabilities_list.append(supports_prob)
            candidates_scores_list.append(candidates_score)

        # prediction from the final step
        if self.reason_type == 1:
            return supports_probabilities_list[-1], candidates_scores_list[-1]

        # (batch_size, length, num_step)
        supports_probabilities = torch.stack(supports_probabilities_list, 2)
        candidates_scores = torch.stack(candidates_scores_list, 2)

        # stochastic prediction dropout: zero (and rescale) whole steps per sample
        if self.reason_type == 0:
            mask = self.generate_mask(h0.size(0)).unsqueeze(1)
            supports_probabilities = supports_probabilities * mask.expand_as(supports_probabilities)
            candidates_scores = candidates_scores * mask.expand_as(candidates_scores)

        # average over the (possibly masked) steps
        return torch.mean(supports_probabilities, 2), torch.mean(candidates_scores, 2)

    def generate_mask(self, batch_size: int) -> torch.Tensor:
        """
        Sample a ``(batch_size, num_step)`` step mask for stochastic prediction dropout.

        Each step is kept with probability ``1 - p`` (``p = reason_dropout_p`` in training,
        0 in eval). One randomly chosen step per sample is always kept, and kept entries
        are scaled by ``1 / (1 - p)``.
        """
        dropout_p = self.reason_dropout_p if self.training else 0.0

        keep_prob = self.alpha.data.new_full((batch_size, self.num_step), 1 - dropout_p)
        for i in range(keep_prob.size(0)):
            # guarantee at least one surviving step per sample
            keep_prob[i][random.randint(0, keep_prob.size(1) - 1)] = 1
        mask = 1.0 / (1 - dropout_p) * torch.bernoulli(keep_prob)
        mask.requires_grad = False
        return mask

In [None]:
# 200-d supports/query/candidates; reason_type=1 uses only the last reasoning step.
san = SANDecoder(support_dim=200, query_dim=200, candidates_dim=200, reason_type=1)

In [None]:
import pixiedust


In [None]:
# Run the SAN decoder on the debug batch built above. The leftover `%%pixie_debugger`
# cell magic was removed so the cell runs non-interactively on Restart & Run All.
supports_prob, candidates_score = san(supports_pooling_vectors, question_pooling_vectors,
                                      candidates_pooling_vectors, supports_mask_para)

In [None]:
# NOTE(review): leftover smoke-test cell with no analytical purpose; consider deleting it.
print(1)