In [1]:
from __future__ import absolute_import, division, print_function

import os
import random
import json
import time
from io import open
import math
import collections
import numpy as np
import torch
import pickle
from operator import itemgetter

from torch.utils.data import (DataLoader, TensorDataset, SequentialSampler)

import transformers
#from IR.mix_LM import Config as IRConfig, ICT
from IR.mix_LM import Config as IRConfig, BertMultiTask_ICT as BertMultiTask

from MRC.tokenization import FullTokenizer
from MRC.modeling_mrc import Config as MRCConfig, MixTraining
from MRC.korquad_utils import read_squad_examples, convert_examples_to_features

RawResult = collections.namedtuple("RawResult", ["unique_id", "start_logits", "end_logits"])

def write_predictions(all_examples, all_features, all_results, n_best_size,
                      max_answer_length):
    """Write final predictions to the json file and log-odds of null if needed."""

    example_index_to_features = collections.defaultdict(list)
    for feature in all_features:
        example_index_to_features[feature.example_index].append(feature)

    unique_id_to_result = {}
    for result in all_results:
        unique_id_to_result[result.unique_id] = result
    
    #print(unique_id_to_result)
    _PrelimPrediction = collections.namedtuple(  # pylint: disable=invalid-name
        "PrelimPrediction",
        ["feature_index", "start_index", "end_index", "start_logit", "end_logit"])

    all_predictions = collections.OrderedDict()
    all_nbest_json = collections.OrderedDict()

    for (example_index, example) in enumerate(all_examples):
        features = example_index_to_features[example_index]

        prelim_predictions = []
        for (feature_index, feature) in enumerate(features):
            result = unique_id_to_result[feature.unique_id]
            start_indexes = _get_best_indexes(result.start_logits, n_best_size)
            end_indexes = _get_best_indexes(result.end_logits, n_best_size)
            # if we could have irrelevant answers, get the min score of irrelevant
            for start_index in start_indexes:
                for end_index in end_indexes:
                    # We could hypothetically create invalid predictions, e.g., predict
                    # that the start of the span is in the question. We throw out all
                    # invalid predictions.
                    if start_index >= len(feature.tokens):
                        continue
                    if end_index >= len(feature.tokens):
                        continue
                    if start_index not in feature.token_to_orig_map:
                        continue
                    if end_index not in feature.token_to_orig_map:
                        continue
                    if not feature.token_is_max_context.get(start_index, False):
                        continue
                    if end_index < start_index:
                        continue
                    length = end_index - start_index + 1
                    if length > max_answer_length:
                        continue
                    prelim_predictions.append(
                        _PrelimPrediction(
                            feature_index=feature_index,
                            start_index=start_index,
                            end_index=end_index,
                            start_logit=result.start_logits[start_index],
                            end_logit=result.end_logits[end_index]))

        prelim_predictions = sorted(
            prelim_predictions,
            key=lambda x: (x.start_logit + x.end_logit),
            reverse=True)
        
        _NbestPrediction = collections.namedtuple(  # pylint: disable=invalid-name
            "NbestPrediction", ["text",
                                "start_logit",
                                "end_logit",
                                "start_index"])

        seen_predictions = {}
        nbest = []
        for pred in prelim_predictions:
            if len(nbest) >= n_best_size:
                break
            feature = features[pred.feature_index]

            orig_doc_start = feature.token_to_orig_map[pred.start_index]
            orig_doc_end = feature.token_to_orig_map[pred.end_index]
            try:
                orig_text = feature.paragraph_text[feature.span_offset[orig_doc_start].start:(feature.span_offset[orig_doc_end].end)+1]
            except IndexError:
                print('index error')
                continue
                
            if orig_text.endswith(" "):
                orig_text = orig_text[:-1]
            final_text = orig_text
            if final_text in seen_predictions:
                continue

            seen_predictions[final_text] = True

            nbest.append(
                _NbestPrediction(
                    text=final_text,
                    start_logit=pred.start_logit,
                    end_logit=pred.end_logit,
                    start_index=feature.span_offset[orig_doc_start].start
                ))

        # In very rare edge cases we could have no valid predictions. So we
        # just create a nonce prediction in this case to avoid failure.
        if not nbest:
            nbest.append(
                _NbestPrediction(text="empty", start_logit=0.0, end_logit=0.0, start_index=0))

        assert len(nbest) >= 1

        total_scores = []
        best_non_null_entry = None
        for entry in nbest:
            total_scores.append(entry.start_logit + entry.end_logit)

        probs = _compute_softmax(total_scores)

        nbest_json = []
        for (i, entry) in enumerate(nbest):
            output = collections.OrderedDict()
            output["text"] = entry.text
            output["probability"] = probs[i]
            output["start_logit"] = entry.start_logit
            output["end_logit"] = entry.end_logit
            output["start_index"] = entry.start_index
            nbest_json.append(output)

        assert len(nbest_json) >= 1

        answer = nbest_json[0]["text"]
        all_predictions[example.qas_id] = answer
        all_nbest_json[example.qas_id] = nbest_json

    return all_predictions, all_nbest_json

def _get_best_indexes(logits, n_best_size):
    """Get the n-best logits from a list."""
    index_and_score = sorted(enumerate(logits), key=lambda x: x[1], reverse=True)

    best_indexes = []
    for i in range(len(index_and_score)):
        if i >= n_best_size:
            break
        best_indexes.append(index_and_score[i][0])
    return best_indexes


def _compute_softmax(scores):
    """Compute softmax probability over raw logits."""
    if not scores:
        return []

    max_score = None
    for score in scores:
        if max_score is None or score > max_score:
            max_score = score

    exp_scores = []
    total_sum = 0.0
    for score in scores:
        x = math.exp(score - max_score)
        exp_scores.append(x)
        total_sum += x

    probs = []
    for score in exp_scores:
        probs.append(score / total_sum)
    return probs


class OpenQA(object):
    def __init__(self):
        
        self.top_k = 10
        self.mrc_batch_size = 64
        self.seed = 42
        self.query_len = 64
        self.max_seq_length = 384
        
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n_gpu = torch.cuda.device_count()
        random.seed(self.seed)
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(self.seed)

        self.doc_embedding_vectors = torch.load("pre_compute_vector/pre-compute.vec").to(self.device)
        with open("pre_compute_vector/paragraph.pkl", "rb") as f:
            self.paragraphs = pickle.load(f)

        ir_config = IRConfig("IR/mix_LM_config.json")
        mrc_config = MRCConfig("MRC/config.json")

        self.tokenizer = transformers.ElectraTokenizer.from_pretrained("monologg/koelectra-base-v3-discriminator")
        self.mrc_tokenizer = FullTokenizer('MRC/vocab_wiki_news_32k_0227.txt', do_basic_tokenize=True)

        ir_config.vocab_size = self.tokenizer.vocab_size
        ir_config.max_seq_length = self.max_seq_length
        
        self.ir_model = BertMultiTask(ir_config, train_mode='ict')
        self.ir_model.load_state_dict(torch.load('IR/model.bin'))  # IR 파인튜닝 모델 로딩
        #self.ir_model.load_state_dict(torch.load('checkpoint/pe_5e-05_126_5_4.bin'))  # IR 파인튜닝 모델 로딩
        

        self.ir_model.to(self.device)
        
        self.mrc_model = MixTraining(mrc_config)
        self.mrc_model.load_state_dict(torch.load('MRC/model.bin')) # MRC 파인튜닝 모델 로딩
        self.mrc_model.to(self.device)

        self.ir_model.eval()
        self.mrc_model.eval()

    def _ir_preproc(self, query):
        query_tokens = self.tokenizer.tokenize(query)

        max_tokens_for_que = self.query_len - 2

        if len(query_tokens) > max_tokens_for_que:
            query_tokens = query_tokens[0:max_tokens_for_que]
        
        q_tokens = []
        q_tokens.append(self.tokenizer.cls_token)
        for i in range(len(query_tokens)):
            q_tokens.append(query_tokens[i])
        q_tokens.append(self.tokenizer.sep_token)

        q_input_ids = [self.tokenizer.convert_tokens_to_ids(t) for t in q_tokens]
        q_input_mask = [1] * len(q_input_ids)

        while len(q_input_ids) < self.query_len:
            q_input_ids.append(0)
            q_input_mask.append(0)

        assert len(q_input_ids) == self.query_len
        assert len(q_input_mask) == self.query_len
        
        q_input_ids = torch.tensor(q_input_ids, dtype=torch.long)
        q_input_mask = torch.tensor(q_input_mask, dtype=torch.long)

        return q_input_ids, q_input_mask

    def __call__(self, query, doc_cat):
        
        # Information Retrieval Step
        input_ids, input_mask = self._ir_preproc(query)

        input_ids = input_ids.unsqueeze(0).to(self.device)
        input_mask = input_mask.unsqueeze(0).to(self.device)

        with torch.no_grad():
            q_encode = self.ir_model.bert(input_ids = input_ids, attention_mask = input_mask)
            q_encode_pooled = self.ir_model.decoder(q_encode)

        score = torch.matmul(q_encode_pooled, self.doc_embedding_vectors.t())
        sorted_score = torch.argsort(score, dim=1, descending=True)

        top_score = sorted_score[:, :self.top_k].squeeze(0)
        
        ir_paragraph = []
        paragraph_offset = [0]
        for i, s in enumerate(top_score):
            ir_paragraph.append(self.paragraphs[s][1])
            offset_end = len(ir_paragraph[i]) + 1
            paragraph_offset.append(paragraph_offset[i] + offset_end)

        if doc_cat:
            input_paragraph = [' '.join(ir_paragraph)]
        else:
            input_paragraph = ir_paragraph
            
        #print(input_paragraph)

        # Machine Reading Comprehension Step
        eval_examples = read_squad_examples(query, input_paragraph, f_tokenizer=self.mrc_tokenizer)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=self.mrc_tokenizer,
            max_seq_length=512,
            doc_stride=256,
            max_query_length=64)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
        
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=self.mrc_batch_size)

        all_results = []
        for input_ids, input_mask, segment_ids, example_indices in eval_dataloader:
            input_ids = input_ids.to(self.device)
            input_mask = input_mask.to(self.device)
            segment_ids = segment_ids.to(self.device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits = self.mrc_model(input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))

        answer, nbest = write_predictions(eval_examples, eval_features, all_results,
                                          n_best_size=5, max_answer_length=30)
        if doc_cat:
            nbest = nbest[0]
        else:
            nbest = [best[0] for best in nbest.values()]

        output = []
        for i, best in enumerate(nbest):
            for j, start in enumerate(paragraph_offset):
                if best["start_index"] < start:
                    break
            output.append({
                'answer': best["text"],
                'probability': best["probability"],
                'paragraph': ir_paragraph[j-1 if doc_cat else i]
            })
        output = sorted(output, key=itemgetter('probability'), reverse=True)

        return output

In [2]:
model = OpenQA()

In [10]:
query = input("Query: ")
result = model(query, True)
for i in result:
    if i['probability'] > 0.1:
        print('정답 :', i['answer']+'(%0.2f%%)' %(i['probability']*100))
        print(i['paragraph'])
        print("\n")

Query: 구글 본사의 위치는 어디인가요?
정답 : 미국 캘리포니아 주 마운틴뷰(86.63%)
구글 와이파이 ( Google WiFi ) 는 미국 캘리포니아 주 마운틴뷰에 배치된 자치 무선망 ( Municipal wireless network ) 이다 . 온전히 구글에 의해 투자되었으며 마운틴뷰 가로등에 주로 설치되어 있다 . 구글은 2010년까지 이 서비스를 무료로 유지하겠다고 언급했다 . 최초 서비스는 마운틴뷰 기지에서 2014년 5월 3일 구글에 의해 종료되었으며 새로운 공공 실외 와이파이를 제공하였다 .




In [4]:
# 위키 2018년 덤프 중 약 20만개에서 리스트와 표를 제외한 텍스트로 구성함
# 샘플 질문
2018년 취임한 대한민국 대통령은?
탄소와 산소가 반응하면 무엇이 되나요?
지구의 반지름은 얼마 입니까?
너 자신을 알라고 한 철학자는 누구인가?
방사선을 최초로 발견한 사람은?
애플의 창시자는?
구글 본사의 위치는 어디인가요?

Object `대통령은` not found.
Object `되나요` not found.
Object `입니까` not found.
Object `누구인가` not found.
Object `사람은` not found.
Object `창시자는` not found.
Object `어디인가요` not found.
