In [53]:
from collections import Counter, defaultdict
import logging
import re
import json
import jieba
import numpy as np
from tqdm import tqdm
from sklearn.model_selection import train_test_split

In [2]:
# Notebook logger. Handlers are attached only once: re-running this cell would
# otherwise add another StreamHandler and print every message twice.
logger = logging.getLogger('les2')
logger.setLevel(logging.INFO)
if not logger.handlers:
    console_handler = logging.StreamHandler()
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

In [3]:
# Raw question/answer articles (one JSON document holding every article).
with open('./data/question.json', encoding='utf-8') as f:
    train_set = json.load(f)

In [4]:
# The training set contains 20000 articles.
len(train_set)

20000

In [5]:
def precision_recall_f1(prediction, ground_truth):
    """Token-overlap precision, recall and F1 of ``prediction`` vs ``ground_truth``.

    Each argument is either a list of tokens or a string, which is split on
    whitespace. A token shared by both sides counts
    min(count in prediction, count in ground truth) times.

    :return: (precision, recall, f1); (0, 0, 0) when nothing overlaps
    """
    pred_toks = prediction if isinstance(prediction, list) else prediction.split()
    gold_toks = ground_truth if isinstance(ground_truth, list) else ground_truth.split()

    overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if not overlap:
        return 0, 0, 0

    prec = overlap / len(pred_toks)
    rec = overlap / len(gold_toks)
    return prec, rec, (2 * prec * rec) / (prec + rec)

In [6]:
def recall(prediction, ground_truth):
    """Token-overlap recall of ``prediction`` against ``ground_truth``."""
    _, rec, _ = precision_recall_f1(prediction, ground_truth)
    return rec


def f1_score(prediction, ground_truth):
    """Token-overlap F1 of ``prediction`` against ``ground_truth``."""
    *_, f1 = precision_recall_f1(prediction, ground_truth)
    return f1

In [7]:
def metric_max_over_ground_truths(metric_fn, prediction, ground_truth):
    """Return ``metric_fn(prediction, ground_truth)`` unchanged.

    Despite its name, this scores against a single ground truth only; there
    is no maximum taken over several references.
    """
    return metric_fn(prediction, ground_truth)

In [8]:
# Find the most related paragraph for each answer and the answer's position in it.
def _most_related_para(paragraphs, answer_tokens):
    """Index of the paragraph with the highest answer-token recall.

    Ties go to the shorter paragraph (then the earlier one). Returns -1 only
    when ``paragraphs`` is empty.
    """
    best_idx = -1
    best_len = float('inf')  # no sentinel length that a real paragraph could exceed
    best_score = 0
    for p_idx, para_tokens in enumerate(paragraphs):
        score = metric_max_over_ground_truths(recall, para_tokens, answer_tokens)
        if score > best_score or (score == best_score and len(para_tokens) < best_len):
            best_idx, best_len, best_score = p_idx, len(para_tokens), score
    return best_idx


def _best_answer_span(para_tokens, answer_tokens):
    """Span of ``para_tokens`` with the highest F1 against ``answer_tokens``.

    :return: ([start, end] inclusive, joined span text, f1);
             ([-1, -1], None, 0) when no span overlaps the answer
    """
    answer_vocab = set(answer_tokens)
    best_span = [-1, -1]
    best_text = None
    best_score = 0
    for start in range(len(para_tokens)):
        # A useful span must start on a token that occurs in the answer.
        if para_tokens[start] not in answer_vocab:
            continue
        # Try the longest span first; an earlier span keeps ties.
        for end in range(len(para_tokens) - 1, start - 1, -1):
            span_tokens = para_tokens[start:end + 1]
            score = metric_max_over_ground_truths(f1_score, span_tokens, answer_tokens)
            if score == 0:
                break
            if score > best_score:
                best_span = [start, end]
                best_score = score
                best_text = ''.join(span_tokens)
    return best_span, best_text, best_score


def find_fake_answer(sample):
    """Locate each answer inside the article and record where it was found.

    For every entry of ``sample['questions']`` this stores:
      - 'most_related_para': index into 'segmented_article_content' of the
        paragraph with the highest recall of the answer tokens (-1 if none)
      - 'answer_spans': [start, end] of the best-F1 span in that paragraph,
        or [-1, -1]
      - 'fake_answers': that span joined into a string, or None
      - 'match_scores': the span's F1 (0 when nothing matched)

    ``sample`` is mutated in place and also returned.
    """
    paragraphs = sample['segmented_article_content']
    for question in sample['questions']:
        answer_tokens = question['segmented_answer']
        para_idx = _most_related_para(paragraphs, answer_tokens)
        question['most_related_para'] = para_idx
        # Guard an article with no paragraphs instead of indexing [-1].
        para_tokens = paragraphs[para_idx] if para_idx >= 0 else []
        span, fake_answer, score = _best_answer_span(para_tokens, answer_tokens)
        question['answer_spans'] = span
        question['fake_answers'] = fake_answer
        question['match_scores'] = score
    return sample

In [9]:
def clean_data(sample):
    """Tokenise one article's title, body, questions and answers with jieba.

    Adds to ``sample`` (mutated in place and returned):
      - 'segmented_article_title': title tokens (whitespace removed first)
      - 'segmented_article_content': list of paragraph token lists. Index 0 is
        the title; the rest are the whitespace-separated paragraphs of the body,
        or its '。'-separated sentences when the body is one block over 200 chars.
        Empty fragments are skipped.
      - per question: 'segmented_question' and 'segmented_answer' (whitespace
        removed before tokenising)
    """
    whitespace = r'\u3000+|\s+|\t+'

    def _strip_ws(text):
        # Drop every whitespace run, including the full-width space U+3000.
        return ''.join(re.split(whitespace, text.strip()))

    sample['segmented_article_title'] = list(jieba.cut(_strip_ws(sample['article_title'])))

    paragraphs = re.split(whitespace, sample['article_content'].strip())
    if len(paragraphs) == 1 and len(paragraphs[0]) > 200:
        # One long unbroken block: fall back to sentence-level segments.
        paragraphs = sample['article_content'].strip().split('。')
    # Skip empty fragments (e.g. after a trailing '。') so no empty paragraph is produced.
    segmented = [list(jieba.cut(para.strip(), cut_all=False))
                 for para in paragraphs if para.strip()]
    # The title is always paragraph 0.
    segmented.insert(0, sample['segmented_article_title'])
    sample['segmented_article_content'] = segmented

    # re.split (not str.split) so the pattern is treated as a regex, not a literal separator.
    for question in sample['questions']:
        question['segmented_question'] = list(jieba.cut(_strip_ws(question['question'])))
        question['segmented_answer'] = list(jieba.cut(_strip_ws(question['answer'])))
    return sample

In [12]:
def store_prerpocess_data(num_parts=200,
                          part_pattern='./data/preprocessed_%d.json',
                          output_path='./data/preprocessed.json'):
    """Concatenate the preprocessed shard files into one JSON list.

    Reads ``part_pattern % i`` for i = 1..num_parts, in order. Each shard must
    hold a JSON list. The concatenation is written to ``output_path``.

    :param num_parts: number of shard files (default: the original 200 shards)
    :param part_pattern: %-style path template taking the 1-based shard index
    :param output_path: destination of the merged JSON list
    """
    preprocess_data = []
    for part_idx in range(1, num_parts + 1):
        with open(part_pattern % part_idx, 'r', encoding='utf-8') as f:
            preprocess_data.extend(json.load(f))
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(preprocess_data, f)

In [10]:
# Data deduplication: start by loading the merged preprocessed samples.
with open('./data/preprocessed.json', mode='r', encoding='utf-8') as f:
    data_preprocessed = json.load(fp=f)

In [12]:
# How many samples share each article title.
title = Counter(sample['article_title'] for sample in data_preprocessed)

In [15]:
# Number of distinct article titles.
len(title)

18476

In [69]:
# Titles that occur more than once, with their counts.
{t: n for t, n in title.items() if n >= 2}

{'沙特成功拦截一枚导弹 疑来自也门胡塞武装': 2,
 '美国驻黑山大使馆遭自杀式爆炸袭击 未造成严重损坏': 2,
 '王毅与坦桑尼亚外长马希加举行会谈': 3,
 '特雷莎梅：突访伊拉克劳军 造访中东搞军售': 2,
 '国家和军队多部门联合部署全面规范现役士兵职业技能鉴定': 2,
 '“J警报”称朝鲜疑将发弹道导弹!? NHK为误报致歉': 2,
 '伊拉克西部一咖啡馆遭爆炸袭击至少10人死亡': 2,
 '阿富汗首都发生巨大爆炸造成近400人伤亡': 2,
 '俄战略轰炸机向叙境内极端组织目标发射导弹': 3,
 '美国称“确信”叙利亚发动化武袭击 正研究化学制剂': 2,
 '习近平就阿富汗发生汽车炸弹袭击事件向加尼总统致慰问电': 2,
 '雪中送炭之举': 2,
 '叙政府谴责美空袭 以战机对叙目标发动导弹攻击': 2,
 '美三艘航母将齐聚西太搞军演 外媒：展示战力震慑朝鲜': 2,
 '朝鲜驻新大使馆：朝美会谈将顺利举行': 2,
 '王毅同朝鲜外相李勇浩举行会谈': 7,
 '美军核航母编队已到关岛 酝酿加大南海巡航力度': 2,
 '为防武力夺权 传中国武警部队将由中央军委直辖': 2,
 '外交部副部长孔铉佑同韩方谈朝鲜半岛局势': 3,
 '韩日首脑将在平昌举行会谈 或磋商慰安妇等议题': 2,
 '中国海军第二十七批护航编队结束对突尼斯访问': 2,
 '阿富汗西部遭塔利班袭击 造成至少20名警察身亡': 2,
 '沙特拦截两枚胡塞武装发射的导弹 击落两架无人机': 2,
 '一艘载有中国船员船只在马来西亚附近海域倾覆': 2,
 '中国将派雪豹特种部队打击叙境内“东突”分子？外交部回应': 2,
 '美军泄密变性士兵进哈佛 气走中情局新老领导': 2,
 '国防部回应美军机在吉布提受到中国驻吉基地激光照射': 2,
 '性能突出！美国海军接收第15艘弗吉尼亚级核潜艇': 2,
 '愤怒的抗议一边儿去 驻日美军新基地强行施工': 2,
 '简氏称东风26亮相阅兵式表明该导弹或已部署': 2,
 '俄罗斯检查站遭自杀式爆炸袭击 致警员1死2伤': 2,
 '习近平同冈比亚总统巴罗举行会谈': 7,
 '俄强化对招募恐怖分子罪行惩罚力度 最高可判无期徒刑': 2,
 '俄国防部网站遭到密集黑客攻击': 2,
 '八艘无人驾驶潜艇或将联合搜

In [21]:
# Deduplicate by article title, keeping the first sample seen for each title.
# The loop variable is deliberately not named `title`: that name holds the
# title Counter from an earlier cell, and clobbering it breaks re-running that cell.
title_set = set()
data_qc = []
for sample in data_preprocessed:
    article_title = sample['article_title']
    if article_title not in title_set:
        title_set.add(article_title)
        data_qc.append(sample)

In [22]:
# Number of samples left after title deduplication.
len(data_qc)

18476

In [24]:
# Spot-check one deduplicated sample.
data_qc[100]

{'article_content': '据中航工业网站消息，20日，中航工业成发与中航空天发动机研究院就短垂项目加工合作举行签约仪式。短距起飞/垂直降落飞机推进系统项目(简称短垂项目)是针对提高海军两栖作战能力，填补该类作战武器装备空白而进行的探索项目。有人士推断，中国此次开发的短垂项目使用的应为一型喷气式涡扇发动机。中航工业成发与中航空天发动机研究院签署的风扇部件，很可能是类似美国F-35B垂直/短距起降战斗机上F135-PW-600发动机所采用的升力风扇部件。据了解，垂直起降技术是从50年代末期开始发展的一项航空技术。虽然起步相对较早，但是受制于飞控、发动机动力、材料等多种因素导致成功服役的战机相对较少。目前性能比较好的垂直起降战斗机是美国的F-35当中的F-35B型号。但尚未装备部队。美国在垂直起降技术方面已经超越了创始国英国，拥有目前世界最先进的垂直起降技术。虽然中国落后其他发达国家半个世纪，但是现在努力也为时未晚。因此，这个绝对是激动人心的好事情。最早的垂直起降飞行器设想是几位英国人在1975年提出的，设想是由一对装有螺桨推进器的可旋转机翼组成的飞行器，在进行垂直起飞和降落时机翼与地面为垂直状态，依靠螺桨产生的升力进行上升与下降，当飞行器上升到空中时，机翼旋转90度，变成与普通飞机一样的方式前进(类似美国的鱼鹰)，按照该设想，飞行器的动力由一台蒸气引擎提供。中国也曾研制过垂直短距起降飞机。1969年3月2日，中苏边境发生了珍宝岛事件，中国一下子被推到了战争的边缘。随后，中国全国都开始了备战准备。此时，林彪指示空军和三机部要根据准备打仗要求，除了大搞运输机和直升飞机外，还要尽快研制既能打又便于藏的作战飞机包括垂直起降飞机，并要求空军在最短时间内拿出意见和方案来。2015-3-26 参考消息网',
 'article_id': '48621',
 'article_title': '中国开始研制垂直起降战机 提高海军两栖作战能力',
 'article_type': '防务快讯',
 'questions': [{'answer': '美国的F-35当中的F-35B型号',
   'answer_spans': [172, 183],
   'fake_answers': '美国的F-35当中的F-35B型号',
   'match_scores': 1.0

In [26]:
# Persist the deduplicated samples.
with open('./data/preprocessed_qc.json', mode='w', encoding='utf-8') as f:
    json.dump(data_qc, fp=f)

In [40]:
def train_test_split(dataset, train_percent=0.9, seed=None):
    """Randomly split ``dataset`` into a training list and a test list.

    NOTE: this definition shadows ``sklearn.model_selection.train_test_split``
    imported at the top of the notebook.

    :param dataset: indexable sequence of samples
    :param train_percent: fraction in [0, 1] of samples put in the training list
        (the count is truncated with int())
    :param seed: optional int. When given, a private ``RandomState(seed)`` makes
        the split reproducible. When None, numpy's global RNG is used.
    :return: (train_set, test_set) as lists
    :raises ValueError: if ``train_percent`` is outside [0, 1]
    """
    if not 0 <= train_percent <= 1:
        raise ValueError('train_percent must be in [0, 1], got %r' % (train_percent,))
    rng = np.random if seed is None else np.random.RandomState(seed)
    order = np.arange(len(dataset))
    rng.shuffle(order)

    train_size = int(len(dataset) * train_percent)
    train_set = [dataset[i] for i in order[:train_size]]
    test_set = [dataset[i] for i in order[train_size:]]
    return train_set, test_set

In [48]:
# Random 90/10 split of the deduplicated samples.
trainset, testset = train_test_split(data_qc, train_percent=0.9)

In [137]:
class LESDataset(object):
    """Reading-comprehension samples built from the preprocessed LES JSON files.

    Each loaded sample is a flat dict with keys 'question' (question tokens),
    'passage' (tokens of the question's most related paragraph) and
    'answer_span' ([start, end] inclusive token indices into 'passage').
    """

    def __init__(self, max_p_len, max_q_len, vocab, train_file=None, test_file=None):
        """
        :param max_p_len: maximum passage length in tokens; longer passages are truncated
        :param max_q_len: maximum question length in tokens; longer questions are truncated
        :param vocab: vocabulary exposing ``token2id`` and ``unk_token`` (e.g. Vocab)
        :param train_file: optional path of a preprocessed training JSON file
        :param test_file: optional path of a preprocessed test JSON file
        """
        self.max_p_len = max_p_len
        self.max_q_len = max_q_len
        self.vocab = vocab
        # Define both sets up front so a missing file gives an empty set
        # instead of an AttributeError later.
        self.train_set = []
        self.test_set = []
        if train_file:
            self.train_set = self._load_dataset(train_file)
        if test_file:
            self.test_set = self._load_dataset(test_file)

    def _load_dataset(self, data_path, train=True):
        """
        Load a preprocessed JSON file and flatten it into one sample per question.

        :param data_path: JSON list of articles with 'segmented_article_content'
            and per-question 'segmented_question', 'most_related_para' and
            'answer_spans' (presumably produced by clean_data + find_fake_answer)
        :param train: when True, drop pairs without a matched span ([-1, -1]) and
            pairs whose span ends at or beyond ``max_p_len``, since that label
            would point past the truncated passage
        :return: list of dicts with keys 'question', 'passage', 'answer_span'
        """
        with open(data_path, 'r', encoding='utf-8') as f:
            data_set = json.load(f)
        data = []
        for sample in data_set:
            for qa_pairs in sample['questions']:
                span = qa_pairs['answer_spans']
                if train and (span[0] == -1 or span[1] >= self.max_p_len):
                    continue
                data.append({'question': qa_pairs['segmented_question'],
                             'passage': sample['segmented_article_content'][qa_pairs['most_related_para']],
                             'answer_span': span})
        return data

    def _get_data(self, set_name):
        """Return the flattened samples of ``set_name`` ('train' or 'test')."""
        if set_name == 'train':
            return self.train_set
        if set_name == 'test':
            return self.test_set
        raise ValueError('No data set named {}'.format(set_name))

    def word_iter(self, set_name):
        """Yield every question token, then every passage token, of each sample in the set."""
        for sample in self._get_data(set_name):
            for word in sample['question']:
                yield word
            for word in sample['passage']:
                yield word

    def gen_mini_batches(self, set_name, batch_size, pad_id=0, shuffle=True):
        """Yield padded mini-batches (dicts built by _one_mini_batch) over a set.

        :param shuffle: shuffle sample order with numpy's global RNG first
        """
        data = self._get_data(set_name)
        data_size = len(data)
        indices = np.arange(data_size)
        if shuffle:
            np.random.shuffle(indices)
        for batch_start in np.arange(0, data_size, batch_size):
            batch_indices = indices[batch_start:batch_start + batch_size]
            batch_data = [data[i] for i in batch_indices]
            yield self._one_mini_batch(batch_data, pad_id)

    def _one_mini_batch(self, batch_data_raw, pad_id):
        """Convert raw samples to id lists, lengths and span labels, then pad."""
        batch_data = {'question_token_ids': [],
                      'question_length': [],
                      'passage_token_ids': [],
                      'passage_length': [],
                      'start_id': [],
                      'end_id': []}
        for qa_pairs in batch_data_raw:
            batch_data['question_token_ids'].append(self.convert_to_ids(qa_pairs['question']))
            # Lengths are capped at the truncation limits so they never exceed
            # the padded width produced by _dynamic_padding.
            batch_data['question_length'].append(min(len(qa_pairs['question']), self.max_q_len))
            batch_data['passage_token_ids'].append(self.convert_to_ids(qa_pairs['passage']))
            batch_data['passage_length'].append(min(len(qa_pairs['passage']), self.max_p_len))
            batch_data['start_id'].append(qa_pairs['answer_span'][0])
            batch_data['end_id'].append(qa_pairs['answer_span'][1])
        return self._dynamic_padding(batch_data, pad_id)

    def _dynamic_padding(self, batch_data, pad_id):
        """Pad/truncate ids to the longest item in the batch, capped at max_p_len / max_q_len."""
        pad_p_len = min(self.max_p_len, max(batch_data['passage_length']))
        pad_q_len = min(self.max_q_len, max(batch_data['question_length']))
        batch_data['passage_token_ids'] = [(ids + [pad_id] * (pad_p_len - len(ids)))[:pad_p_len]
                                           for ids in batch_data['passage_token_ids']]
        batch_data['question_token_ids'] = [(ids + [pad_id] * (pad_q_len - len(ids)))[:pad_q_len]
                                            for ids in batch_data['question_token_ids']]
        return batch_data

    def convert_to_ids(self, tokens):
        """Map tokens to vocabulary ids.

        Each token is looked up as-is, then lowercased; tokens found in neither
        form get the id of ``vocab.unk_token`` instead of raising KeyError.
        """
        token2id = self.vocab.token2id
        unk_id = token2id[self.vocab.unk_token]
        ids = []
        for token in tokens:
            if token in token2id:
                ids.append(token2id[token])
            else:
                ids.append(token2id.get(token.lower(), unk_id))
        return ids
        
        

In [138]:
class Vocab(object):
    """Token <-> id mapping with per-token counts and optional embeddings.

    Ids are assigned in insertion order, so the pad token '<blank>' gets id 0
    and the unknown token '<unk>' gets id 1.
    """

    def __init__(self, filename=None, lower=False):
        """
        :param filename: accepted for interface compatibility; currently unused
        :param lower: lowercase tokens before adding / looking them up
        """
        self.id2token = {}
        self.token2id = {}
        self.token_cnt = defaultdict(int)
        self.lower = lower

        self.embed_dim = None
        self.embeddings = None

        self.pad_token = '<blank>'
        self.unk_token = '<unk>'

        self.initial_tokens = []
        self.initial_tokens.extend([self.pad_token, self.unk_token])

        for token in self.initial_tokens:
            self.add(token)

    def size(self):
        """Number of distinct tokens, including the pad and unk tokens."""
        return len(self.token2id)

    def get_id(self, token):
        """Id of ``token`` (normalised as in add()), or the unk id if it is unknown."""
        token = token.lower() if self.lower else token
        return self.token2id.get(token, self.token2id[self.unk_token])

    def add(self, token, cnt=True):
        """Add ``token`` if new and return its id; bump its count when ``cnt`` is True."""
        token = token.lower() if self.lower else token

        if token in self.token2id:
            idx = self.token2id[token]
        else:
            idx = len(self.token2id)
            self.token2id[token] = idx
            self.id2token[idx] = token
        if cnt:
            self.token_cnt[token] += 1

        return idx

    def randomly_init_embeddings(self, embed_dim):
        """Set ``embeddings`` to uniform [0, 1) values of shape (size, embed_dim).

        The pad and unk rows are zeroed.
        """
        self.embed_dim = embed_dim
        self.embeddings = np.random.rand(len(self.token2id), embed_dim)

        for token in [self.pad_token, self.unk_token]:
            self.embeddings[self.token2id[token]] = np.zeros([embed_dim])

In [139]:
# Training dataset: passages capped at 300 tokens, questions at 60.
# Uses a path relative to the notebook, like the other cells, instead of an
# absolute local Windows path.
# NOTE(review): `vocab` is not created in any visible cell, and trainset.json
# is not written here (the split above stays in memory). Both must exist
# before this cell runs — TODO confirm where they are produced.
les_dataset = LESDataset(300, 60, vocab, train_file='./data/trainset.json')

In [140]:
# Lazily iterate over shuffled, padded training batches of 32 samples.
batches = les_dataset.gen_mini_batches('train', batch_size=32)

In [141]:
# Take the first (shuffled) mini-batch to inspect its contents.
batch = next(batches)

In [146]:
# Padded question token ids of the first batch (0 is the default pad id).
batch['question_token_ids']

[[275,
  538,
  27,
  301,
  2738,
  2985,
  15973,
  15974,
  93,
  168,
  587,
  201,
  1081,
  55281,
  881,
  34,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [107,
  24144,
  762,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [14165,
  27,
  27464,
  4745,
  681,
  682,
  303,
  41,
  1277,
  4351,
  23,
  995,
  5350,
  27,
  955,
  23967,
  891,
  53,
  1282,
  27,
  405,
  1914,
  149,
  877,
  573,
  536,
  125,
  13320,
  37,
  27465,
  1226,
  560,
  614,
  10,
  930,
  735,
  152,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0],
 [321,
  1664,
  22746,
  121,
  692,
  302,
  1679,
  3707,
  692,
  305,
  5218,
  27,
  165,
  1393,
  33185,
  23,
  2320,
  910,
  934,
  27,
  4637,
  2320,
