In [1]:
import os
import sys
import torch
import torchvision
from tqdm import tqdm_notebook as tqdm
from transformers import BertModel, BertTokenizer

w_dir = %pwd
work_dir = os.path.dirname(w_dir)
work_dir

I1211 05:31:07.330024 140436334675776 file_utils.py:39] PyTorch version 1.1.0 available.


'/work'

In [2]:
sys.path.append(w_dir+'/fgc_support_retri')

In [7]:
import config
from sup_model import BertForMultiHopQuestionAnswering, BertSupTagModel
from utils import read_fgc
from fgc_preprocess import BertSpanTagIdx, SerContextDataset, bert_context_collate

In [73]:
class SER_context_extract:
    def __init__(self):
#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        device = torch.device("cpu")
        bert_model_name = 'bert-base-chinese'
        bert_encoder = BertModel.from_pretrained(bert_model_name)
        bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        model = BertSupTagModel(bert_encoder, device)
        model_path = config.TRAINED_MODELS / '20191210_BertSupTag'/ 'model_epoch5_loss_0.360.m' 
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()
        
        self.indexer = BertSpanTagIdx(bert_tokenizer)
        self.model = model
        self.device = device
        
    def predict(self, context_sents, question):
        sample = self.indexer({'QTEXT': question, 'SENTS': context_sents})
        item = bert_context_collate([sample])
        with torch.no_grad():
            input_ids = item['input_ids'].to(self.device)
            token_type_ids = item['token_type_ids'].to(self.device)
            attention_mask = item['attention_mask'].to(self.device)
            logits = self.model(input_ids=input_ids, 
                                token_type_ids=token_type_ids,
                                attention_mask=attention_mask, 
                                mode=BertSupTagModel.ForwardMode.EVAL)
            tag_list = logits[0].cpu().numpy()
            
            sep_positions = [None] * len(sample['sentence_position'])
            for position, sid in sample['sentence_position'].items():
                sep_positions[sid] = position
            
            prediction = []
            for tid, tag in enumerate(tag_list):
                if tag == 1:
                    for sid in range(len(sep_positions)-1):
                        if sep_positions[sid] < tid < sep_positions[sid+1]:
                            prediction.append(sid)
        return prediction 
                

In [74]:
extractor = SER_context_extract()

I1211 06:51:13.323674 140436334675776 configuration_utils.py:152] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json from cache at /root/.cache/torch/transformers/8a3b1cfe5da58286e12a0f5d7d182b8d6eca88c08e26c332ee3817548cf7e60a.0c16faba8be66db3f02805c912e4cf94d3c9cffc1f12fa1a39906f9270f76d33
I1211 06:51:13.326623 140436334675776 configuration_utils.py:169] Model config {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_decoder": false,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "num_labels": 2,
  "output_attentions": false,
  "output_hidden_states": false,
  "output_past": true,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "po

In [75]:
def evalaluate(fgc_items):
    tp = 0
    gol_t = 0
    pre_t = 0
    for data in tqdm(fgc_items):
        prediction = extractor.predict(data['SENTS'], data['QTEXT'])
        gold = data['SUP_EVIDENCE']
        pred = prediction
        
        gol_t += len(gold)
        pre_t += len(pred)
        for g in gold:
            if g in pred:
                tp += 1
        data['prediction'] = prediction
                
    precision = tp / pre_t
    recall = tp / gol_t

    f1 = 2*precision*recall / (precision+recall)

    print("precision = {}".format(precision))
    print("recall = {}".format(recall))
    print("f1 = {}".format(f1))
        

In [78]:
fgc_train_items = read_fgc(config.FGC_DEV)

no gold supporting evidence
{'QID': 'D032Q10', 'QTYPE': '进阶题', 'QTEXT': '第二次簽訂的北美貿易協定從簽署至生效過了幾日?', 'SENTS': [{'text': '第二次签订的北美贸易协定从签署至生效过了几日?', 'start': 0, 'end': 23}], 'ANSWER': [{'ATEXT': '資訊不足無法判定', 'ATOKEN': [{'text': '资讯不足无法判定', 'start': -1}], 'ATEXT_CN': '资讯不足无法判定'}], 'ATYPE': 'Date-Duration', 'AMODE': 'Date-Duration', 'ASPAN': [], 'SHINT': [], 'QTEXT_CN': '第二次签订的北美贸易协定从签署至生效过了几日?'}
no gold supporting evidence
{'QID': 'D049Q04', 'QTYPE': '申论', 'QTEXT': '「雅婷逐字稿」的命名起源為何?', 'SENTS': [{'text': '「雅婷逐字稿」的命名起源为何?', 'start': 0, 'end': 15}], 'ANSWER': [{'ATEXT': '', 'ATOKEN': [{'text': '', 'start': 0}], 'ATEXT_CN': ''}], 'ATYPE': 'Event', 'AMODE': 'Single-Span-Extraction', 'ASPAN': [], 'SHINT': [], 'QTEXT_CN': '「雅婷逐字稿」的命名起源为何?'}
no gold supporting evidence
{'QID': 'D086Q03', 'QTYPE': '申论', 'QTEXT': '不可再生能源的意義是什麼？', 'SENTS': [{'text': '不可再生能源的意义是什么？', 'start': 0, 'end': 13}], 'ANSWER': [{'ATEXT': '', 'ATOKEN': [{'text': '', 'start': 0}], 'ATEXT_CN': ''}], 'ATYPE': 'Object', 'AMODE': 'Sing

In [79]:
evalaluate(fgc_train_items)

HBox(children=(IntProgress(value=0, max=213), HTML(value='')))


precision = 0.32172701949860727
recall = 0.4762886597938144
f1 = 0.38403990024937656


In [70]:
fgc_train_items[1]

{'QID': 'D001Q02',
 'SENTS': [{'text': '苏轼（1037年1月8日－1101年8月24日），', 'start': 0, 'end': 25},
  {'text': '眉州眉山（今四川省眉山市）人，', 'start': 25, 'end': 40},
  {'text': '北宋时著名的文学家、政治家、艺术家、医学家。', 'start': 40, 'end': 62},
  {'text': '字子瞻，一字和仲，', 'start': 62, 'end': 71},
  {'text': '号东坡居士、铁冠道人。', 'start': 71, 'end': 82},
  {'text': '嘉佑二年进士，', 'start': 82, 'end': 89},
  {'text': '累官至端明殿学士兼翰林学士，', 'start': 89, 'end': 103},
  {'text': '礼部尚书。南宋理学方炽时，', 'start': 103, 'end': 116},
  {'text': '加赐谥号文忠，', 'start': 116, 'end': 123},
  {'text': '复追赠太师。', 'start': 123, 'end': 129},
  {'text': '有《东坡先生大全集》及《东坡乐府》词集传世，', 'start': 129, 'end': 151},
  {'text': '宋人王宗稷收其作品，', 'start': 151, 'end': 161},
  {'text': '编有《苏文忠公全集》。', 'start': 161, 'end': 172},
  {'text': '\n其散文、诗、词、赋均有成就，', 'start': 172, 'end': 187},
  {'text': '且善书法和绘画，', 'start': 187, 'end': 195},
  {'text': '是文学艺术史上的通才，', 'start': 195, 'end': 206},
  {'text': '也是公认韵文散文造诣皆比较杰出的大家。', 'start': 206, 'end': 225},
  {'text': '苏轼的散文为唐宋四家（韩愈、柳宗元、欧苏）之末，', 'start'