In [1]:
import os
import sys
import torch
import torchvision
from tqdm import tqdm_notebook as tqdm
from transformers import BertModel, BertTokenizer

# Resolve the notebook's working directory with plain Python rather than the
# IPython-only `%pwd` magic, so the cell also runs when exported as a script.
# `w_dir` stays a plain `str` because a later cell concatenates onto it.
w_dir = os.getcwd()
# Parent directory of the working directory (displayed as the cell output).
work_dir = os.path.dirname(w_dir)
work_dir

I1223 02:14:15.970679 139666190767936 file_utils.py:39] PyTorch version 1.1.0 available.


'/work'

In [2]:
# Make the project checkout importable. Guarded so that re-running this cell
# does not keep appending duplicate entries to sys.path.
_fgc_pkg_dir = w_dir+'/fgc_support_retri'
if _fgc_pkg_dir not in sys.path:
    sys.path.append(_fgc_pkg_dir)

In [3]:
from fgc_support_retri.ser_extractor import *
from fgc_support_retri.utils import read_fgc, read_hotpot
from fgc_support_retri.eval import evalaluate_f1

In [8]:
class SER_sent_extract_V2:
    """Supporting-evidence sentence extractor built on ``BertSentenceSupModel_V2``.

    On construction the tokenizer and model are created from
    ``config.BERT_EMBEDDING``, the fine-tuned weights are loaded from a fixed
    checkpoint under ``config.TRAINED_MODELS``, and the model is placed on the
    CPU in evaluation mode.

    Attributes set by ``__init__``: ``tokenizer``, ``model``, ``bert_indexer``
    and ``device``.
    """

    def __init__(self):
        """Load tokenizer, model and checkpoint weights onto the CPU."""
        device = torch.device("cpu")
        bert_model_name = config.BERT_EMBEDDING
        bert_tokenizer = BertTokenizer.from_pretrained(bert_model_name)
        # NOTE(review): this is the V1 indexer, but `predict` builds its own
        # BertSentV2Idx and never reads `self.bert_indexer`. Kept unchanged so
        # the attribute stays available; confirm whether V2 was intended here.
        bert_indexer = BertSentV1Idx(bert_tokenizer)
        model = BertSentenceSupModel_V2.from_pretrained(bert_model_name)
        model_path = config.TRAINED_MODELS / '20191219_test2' / 'model_epoch20_eval_recall_0.524_f1_0.465.m'
        # map_location keeps the load on `device` regardless of where the
        # checkpoint tensors were saved from.
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.to(device)
        model.eval()

        self.tokenizer = bert_tokenizer
        self.model = model
        self.bert_indexer = bert_indexer
        self.device = device

    def predict(self, items, threshold=0.4):
        """Run the model on each item and collect its predictions.

        Each item is wrapped in its own single-element ``SerSentenceDataset``
        and collated into one batch, so the model is invoked once per item.

        Args:
            items: Iterable of mutable mapping items accepted by
                ``SerSentenceDataset``. Each item is modified in place: the
                sorted prediction is stored under ``'sup_prediction'``.
            threshold: Forwarded to ``self.model.predict`` as ``threshold``.
                Defaults to 0.4, the value that was previously hard-coded.

        Returns:
            A list with one sorted prediction per item, in input order.
            (Presumably these are supporting-sentence indices; verify against
            ``BertSentenceSupModel_V2.predict``.)
        """
        # Built once: the transform only depends on the tokenizer, not the item.
        transform = torchvision.transforms.Compose([BertSentV2Idx(self.tokenizer)])
        tensor_keys = ('input_ids', 'token_type_ids', 'attention_mask', 'tf_type', 'idf_type')

        predictions = []
        # Inference only, so gradient tracking is disabled for the whole loop.
        with torch.no_grad():
            for item in tqdm(items):
                item_dataset = SerSentenceDataset([item], transform=transform)
                batch = bert_sentV2_collate([sample for sample in item_dataset])
                for key in tensor_keys:
                    batch[key] = batch[key].to(self.device)
                prediction = self.model.predict(batch, threshold=threshold)
                prediction.sort()
                item['sup_prediction'] = prediction
                predictions.append(prediction)

        return predictions

In [9]:
# Load the FGC dev split from the path configured as `config.FGC_DEV`.
# The recorded output of this cell shows "no gold supporting evidence" printed
# together with the affected item for some questions.
fgc_items = read_fgc(config.FGC_DEV)

no gold supporting evidence
{'QID': 'D032Q10', 'QTYPE': '进阶题', 'QTEXT': '第二次簽訂的北美貿易協定從簽署至生效過了幾日?', 'SENTS': [{'text': '第二次签订的北美贸易协定从签署至生效过了几日?', 'start': 0, 'end': 23}], 'ANSWER': [{'ATEXT': '資訊不足無法判定', 'ATOKEN': [{'text': '资讯不足无法判定', 'start': -1}], 'ATEXT_CN': '资讯不足无法判定'}], 'ATYPE': 'Date-Duration', 'AMODE': 'Date-Duration', 'ASPAN': [], 'SHINT': [], 'QTEXT_CN': '第二次签订的北美贸易协定从签署至生效过了几日?'}
no gold supporting evidence
{'QID': 'D049Q04', 'QTYPE': '申论', 'QTEXT': '「雅婷逐字稿」的命名起源為何?', 'SENTS': [{'text': '「雅婷逐字稿」的命名起源为何?', 'start': 0, 'end': 15}], 'ANSWER': [{'ATEXT': '', 'ATOKEN': [{'text': '', 'start': 0}], 'ATEXT_CN': ''}], 'ATYPE': 'Event', 'AMODE': 'Single-Span-Extraction', 'ASPAN': [], 'SHINT': [], 'QTEXT_CN': '「雅婷逐字稿」的命名起源为何?'}
no gold supporting evidence
{'QID': 'D086Q03', 'QTYPE': '申论', 'QTEXT': '不可再生能源的意義是什麼？', 'SENTS': [{'text': '不可再生能源的意义是什么？', 'start': 0, 'end': 13}], 'ANSWER': [{'ATEXT': '', 'ATOKEN': [{'text': '', 'start': 0}], 'ATEXT_CN': ''}], 'ATYPE': 'Object', 'AMODE': 'Sing

In [10]:
# Build the extractor (loads tokenizer, model and checkpoint weights).
extractor = SER_sent_extract_V2()

# Predict for every dev item; `predict` also stores each prediction on the
# item itself under 'sup_prediction'.
predictions = extractor.predict(fgc_items)

# Score the predictions against the dev items and print precision, recall and
# F1, one value per line.
precision, recall, f1 = evalaluate_f1(fgc_items, predictions)
print(precision, recall, f1, sep='\n')

I1223 02:53:41.984415 139666190767936 tokenization_utils.py:375] loading file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-vocab.txt from cache at /root/.cache/torch/transformers/8a0c070123c1f794c42a29c6904beb7c1b8715741e235bee04aca2c7636fc83f.9b42061518a39ca00b8b52059fd2bede8daa613f8a8671500e518a8c29de8c00
I1223 02:53:42.901609 139666190767936 configuration_utils.py:152] loading configuration file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-chinese-config.json from cache at /root/.cache/torch/transformers/8a3b1cfe5da58286e12a0f5d7d182b8d6eca88c08e26c332ee3817548cf7e60a.0c16faba8be66db3f02805c912e4cf94d3c9cffc1f12fa1a39906f9270f76d33
I1223 02:53:42.904864 139666190767936 configuration_utils.py:169] Model config {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "finetuning_task": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "is_

HBox(children=(IntProgress(value=0, max=213), HTML(value='')))


0.4184514003294893
0.5237113402061856
0.46520146520146516


In [7]:
# Replace each item's sentence records with compact {index: text} dicts so the
# sentence index is visible when an item is displayed.
#
# This mutates `fgc_items` in place. Entries that no longer carry a 'text' key
# were already converted by a previous run of this cell and are passed through
# unchanged, so re-running the cell is safe instead of raising KeyError.
for item in fgc_items:
    # Pop and re-insert so 'SENTS' ends up as the last key of the item.
    original_sents = item.pop('SENTS')
    item['SENTS'] = [
        {s_i: s['text']} if 'text' in s else s
        for s_i, s in enumerate(original_sents)
    ]


In [21]:
# Display one dev item (index 10) for manual inspection; its recorded output
# includes both the 'SUP_EVIDENCE' and the 'sup_prediction' fields.
fgc_items[10]

{'QID': 'D011Q02',
 'SUP_EVIDENCE': [9, 16],
 'QTEXT': '形成於北大西洋的熱帶氣旋，又被稱為什麼？',
 'ANS': '颶風',
 'ASPAN': [{'text': '北大西洋', 'start': 417, 'end': 421},
  {'text': '飓风', 'start': 424, 'end': 426},
  {'text': '热带气旋', 'start': 247, 'end': 251},
  {'text': '台风', 'start': 281, 'end': 283},
  {'text': '台风', 'start': 413, 'end': 415},
  {'text': '北太平洋', 'start': 404, 'end': 408}],
 'sup_prediction': [16],
 'SENTS': [{0: '台风（英语：Typhoon，香港天文台缩写T.；日语：台风/たいふう/taifū；韩语：태풍/台风/taepung）是赤道以北，'},
  {1: '国际换日线以西，'},
  {2: '亚太国家或地区对热带气旋的一个分级。'},
  {3: '在气象学上，'},
  {4: '按世界气象组织定义，'},
  {5: '热带气旋中心持续风速达到12级（即64节或以上、32.7m/s或以上，又或者118km/hr或以上）称为飓风（Hurricane）或其他在地近义字。'},
  {6: '西北太平洋地区采用之近义字乃台风。'},
  {7: '\n广义上，「台风」这个词并非一种热带气旋强度。'},
  {8: '在台湾、日本等地，'},
  {9: '将中心持续风速每秒17.2米或以上的热带气旋（包括世界气象组织定义中的热带风暴、强烈热带风暴和台风）均称台风。'},
  {10: '在非正式场合，'},
  {11: '「台风」甚至直接泛指热带气旋本身。'},
  {12: '当西北太平洋的热带气旋达到热带风暴的强度，'},
  {13: '区域专责气象中心（RSMC）日本气象厅会对其编号及命名，'},
  {14: '名称由世界气象组织台风委员会的14个国家和地区提供。'},
  {15: '\n但在不同海洋上也各自有地区性的名称，'},
  {16: 