<a href="https://colab.research.google.com/github/dolmani38/QA/blob/main/Korean_QA_on_Wiki.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Building a QA System with BERT on Wikipedia

https://qa.fastforwardlabs.com/pytorch/hugging%20face/wikipedia/bert/transformers/2020/05/19/Getting_Started_with_QA.html

위의 내용을 한국어 QA로 변경

지식 base = wiki+네이버(view,kin,news)

In [2]:
!pip install wikipedia==1.4.0
!pip install sentence-transformers==1.1.1
!pip install transformers==4.6.0



##한국어 SQUAD 모델의 사용
https://huggingface.co/monologg/koelectra-base-v3-finetuned-korquad

##영어 + 한국어 STS 모델의 사용

#Extending Sentence Embeddings Models to New Languages

**Available Pre-trained Models**

*   **distiluse-base-multilingual-cased**: Supported languages: Arabic, Chinese, Dutch, English, French, German, Italian, Korean, Polish, Portuguese, Russian, Spanish, Turkish. Model is based on DistilBERT-multi-lingual.
*   **xlm-r-base-en-ko-nli-ststb**: Supported languages: English, Korean. Performance on Korean STSbenchmark: 81.47
*   **xlm-r-large-en-ko-nli-ststb**: Supported languages: English, Korean. Performance on Korean STSbenchmark: 84.05 --> 이거 사용!


참조:https://github.com/UKPLab/sentence-transformers/blob/master/docs/pretrained-models/multilingual-models.md



In [37]:
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

SQUAD_MODEL = "monologg/koelectra-base-v3-finetuned-korquad"
STS_MODEL = "xlm-r-large-en-ko-nli-ststb"

tokenizer = AutoTokenizer.from_pretrained(SQUAD_MODEL)
model = AutoModelForQuestionAnswering.from_pretrained(SQUAD_MODEL)
#reader = DocumentReader(SQUAD_MODEL) 
embedder = SentenceTransformer(STS_MODEL)

In [4]:
# Print iterations progress
class ProgressBar:

  def __init__(self,total=20, prefix = '', suffix = '', decimals = 1, length = 20, fill = '█', printEnd = "\r"):
    self.total = total
    self.prefix = prefix
    self.suffix = suffix
    self.decimals = decimals
    self.length = length
    self.fill = fill
    self.printEnd = printEnd
    self.ite = 0

  def printProgress(self,iteration, text):
      self.ite += iteration
      percent = ("{0:." + str(self.decimals) + "f}").format(100 * (self.ite / float(self.total)))

      filledLength = int(self.length * self.ite // self.total)
      bar = self.fill * filledLength + '-' * (self.length - filledLength)
      print(f'\r{self.prefix} |{bar}| {percent}% {self.suffix}  {text}', end="", flush=True)
      # Print New Line on Complete
      if self.ite == self.total: 
          print()

In [46]:
class AnswerVoter:
  def __init__(self, threshold_score=3,max_rank=5):
    self.answes = {}
    self.threshold_score= threshold_score
    self.max_rank = max_rank

  def add_ans(self,ans,score,src,pb):
    key = ans.replace(' ','')
    if score > self.threshold_score:
      #print(' --- Candidate answer:',ans,score)
      pb.printProgress(0,'Candidate answer:' + str(ans) + ' ' + str(score))
      if key in self.answes:
        self.answes[key][1] += score
        if src in self.answes[key][2]:
          pass
        else:
          self.answes[key][2].append(src)
      else:
        self.answes[key] = [ans,score,[src]]

  def get_ans(self):
    answer = []
    sorted_x = sorted(self.answes.items(), key=lambda kv: kv[1][1],reverse=True)
    for i in range(min(self.max_rank,len(sorted_x))):
      answer.append(sorted_x[i])

    return answer

  def print(self):
    answer = self.get_ans()
    for ans in answer:
      print('Answer:',ans[1][0], ' score:',ans[1][1], ' source:',ans[1][2])
      #print('Answer:',ans)

In [70]:
import wikipedia as wiki
import pprint as pp
from collections import OrderedDict
import torch

class DocumentReader:
    def __init__(self, _tokenizer, _model):
        #self.READER_PATH = pretrained_model_name_or_path
        self.tokenizer = _tokenizer #AutoTokenizer.from_pretrained(self.READER_PATH)
        self.model = _model #AutoModelForQuestionAnswering.from_pretrained(self.READER_PATH)
        self.max_len = self.model.config.max_position_embeddings
        self.chunked = False

        self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
        #self.tokenizer.to(self.device)
        self.model.to(self.device)

    def tokenize(self, question, context):
        self.inputs = self.tokenizer.encode_plus(question, context, max_length=512, truncation=True, add_special_tokens=True, return_tensors="pt")
        self.inputs.to(self.device)
        self.input_ids = self.inputs["input_ids"].tolist()[0]

        if len(self.input_ids) > self.max_len:
            self.inputs = self.chunkify()
            self.chunked = True
            print('input_ids:',len(self.input_ids),'max_len:',self.max_len)

    def chunkify(self):
        """ 
        Break up a long article into chunks that fit within the max token
        requirement for that Transformer model. 

        Calls to BERT / RoBERTa / ALBERT require the following format:
        [CLS] question tokens [SEP] context tokens [SEP].
        """

        # create question mask based on token_type_ids
        # value is 0 for question tokens, 1 for context tokens
        qmask = self.inputs['token_type_ids'].lt(1)
        qt = torch.masked_select(self.inputs['input_ids'], qmask)
        chunk_size = self.max_len - qt.size()[0] - 1 # the "-1" accounts for
        # having to add an ending [SEP] token to the end

        # create a dict of dicts; each sub-dict mimics the structure of pre-chunked model input
        chunked_input = OrderedDict()
        for k,v in self.inputs.items():
            q = torch.masked_select(v, qmask)
            c = torch.masked_select(v, ~qmask)
            chunks = torch.split(c, chunk_size)
            
            for i, chunk in enumerate(chunks):
                if i not in chunked_input:
                    chunked_input[i] = {}

                thing = torch.cat((q, chunk))
                if i != len(chunks)-1:
                    if k == 'input_ids':
                        thing = torch.cat((thing, torch.tensor([102])))
                    else:
                        thing = torch.cat((thing, torch.tensor([1])))

                chunked_input[i][k] = torch.unsqueeze(thing, dim=0)
        return chunked_input.to(self.device)

    def get_answer(self,answer,src,pb):
        if self.chunked:
            
            for k, chunk in self.inputs.items():
                #chunk.to(self.device)
                outputs = self.model(**chunk)

                answer_start = torch.argmax(outputs[0])
                answer_end = torch.argmax(outputs[1]) + 1
                ans = self.convert_ids_to_string(chunk['input_ids'][0][answer_start:answer_end])
                if ans.startswith(('[CLS]','[SEP]',' ','°')) or ans=='':
                    #raise Exception('No Answer')
                    pass
                else:
                    score = float(torch.max(outputs[0]))
                    #print(ans,score)
                    answer.add_ans(ans,score,src,pb)
                    #answer += ans + ', '
                    #print(ans,torch.max(answer_start_scores),torch.max(answer_end_scores))
                    #break
        else:
            outputs = self.model(**self.inputs)

            #print(outputs[0])

            answer_start = torch.argmax(outputs[0])  # get the most likely beginning of answer with the argmax of the score
            answer_end = torch.argmax(outputs[1]) + 1  # get the most likely end of answer with the argmax of the score
        
            ans = self.convert_ids_to_string(self.inputs['input_ids'][0][
                                              answer_start:answer_end])
            
            if ans in ['',' ','  ']:
              pass
            else:
              score = float(torch.max(outputs[0]))
              #print(ans,score)
              answer.add_ans(ans,score,src,pb)
        #if len(answer) == 0:
        #  raise Error("No Answer") 
        return answer
        
    def convert_ids_to_string(self, input_ids):
        return self.tokenizer.convert_tokens_to_string(self.tokenizer.convert_ids_to_tokens(input_ids))

In [71]:
import sys
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
import wikipedia as wiki
import pprint as pp
from collections import OrderedDict
import scipy
import requests
from bs4 import BeautifulSoup

class Korean_QA_on_Wiki:
  def __init__(self, document_reader,sentence_embedder):
    self.reader = document_reader
    self.embedder = sentence_embedder
    wiki.set_lang('ko')

  def __search_from_wiki(self,question,max_rank):
    results = wiki.search(question,results=max_rank)
    contents = []
    for result in results:
      try:
        page = wiki.page(result)
        #print(f"Top wiki result: {page}")
        text = page.content
        contents.append((text,page))
      except Exception as ex:
        print(ex)
    return contents

  def __search_from_naver(self,question,max_rank):
    contents = []
    url = 'https://search.naver.com/search.naver'
    for w in ['view','kin','news','kdic']:
      params = {'query': question,'where': w,}
      response = requests.get(url, params=params)
      html = response.text
      #뷰티풀소프의 인자값 지정
      soup = BeautifulSoup(html, 'html.parser')
      #쪼개기
      #title_list = soup.find_all('a', href=True)
      title_list = soup.select('.api_txt_lines')
      #print(title_list)
      tmp = []
      for tag in title_list:
        tmp.append(tag.text)
      contents.append((''.join(tmp),url + '?where='+w))
      tmp.clear()
    #print(contents)      
    return contents



  def question(self, questions, max_rank = 5):
    answers = {}
    for question in questions:
        print(f"Question: {question}")
        
        contents = []
        contents.extend(self.__search_from_naver(question,max_rank))
        contents.extend(self.__search_from_wiki(question,max_rank))
        src_count = len(contents)
        pb = ProgressBar(total=src_count+1,prefix='Searching answers...')
        pb.printProgress(0,str(src_count)+' of sources')
        #print('-- Source count : ', len(contents))
        answer = AnswerVoter()
        if self.embedder is None:
            for context, src in contents:
                #text = contents[idx][0]
                pb.printProgress(+1,src)
                #print('-- source :', contents[idx][1])
                try:
                    self.reader.tokenize(question, context)
                    self.reader.get_answer(answer,src,pb)
                    #answer.add_src(contents[idx][1])
                    #t = (self.reader.get_answer(answer),contents[idx][1])
                    #print(f"Answer: {t[0]}", f" from {t[1]}")
                    #answer_list.append(t)
                except Exception as ex:
                    #pb.printProgress(0,sys.exc_info()[0])
                    print(ex,sys.exc_info())
                    #pass    
        else:
            corpus_embeddings = self.embedder.encode([a for (a,b) in contents],show_progress_bar=False) 
            query_embeddings = self.embedder.encode([question])
            distances = scipy.spatial.distance.cdist(query_embeddings, corpus_embeddings, "cosine")[0]

            results = zip(range(len(distances)), distances)
            results = sorted(results, key=lambda x: x[1])
            

            for idx, distance in results:
                context = contents[idx][0]
                pb.printProgress(+1,contents[idx][1])
                #print('-- source :', contents[idx][1])
                try:
                    self.reader.tokenize(question, context)
                    self.reader.get_answer(answer,contents[idx][1],pb)
                    #answer.add_src(contents[idx][1])
                    #t = (self.reader.get_answer(answer),contents[idx][1])
                    #print(f"Answer: {t[0]}", f" from {t[1]}")
                    #answer_list.append(t)
                    
                except Exception as ex:
                    #pb.printProgress(0,sys.exc_info()[0])
                    print(ex,sys.exc_info())
                    #pass    
        answers[question] = answer.get_ans()
        pb.printProgress(+1,"완료")
        answer.print()
        print(' ')
    return answers


In [72]:
kqaw = Korean_QA_on_Wiki(DocumentReader(tokenizer,model), embedder)

In [73]:
answers = kqaw.question(["아브라함은 자식이 몇명인가?"])

Question: 아브라함은 자식이 몇명인가?
Searching answers... |████████████████████| 100.0%   완료
Answer: 13명  score: 10.882498741149902  source: [<WikipediaPage '소저너 트루스'>]
Answer: 테월데  score: 9.216580390930176  source: [<WikipediaPage '에티오피아 이름'>]
Answer: 6명  score: 6.823949337005615  source: [<WikipediaPage '아브라함'>]
Answer: 45명 40명  score: 6.245438098907471  source: ['https://search.naver.com/search.naver?where=view']
Answer: 몇 명  score: 3.558769941329956  source: ['https://search.naver.com/search.naver?where=news']
 


In [74]:
answers = kqaw.question(["베트남에서 가장 인기있는 관광지는 어디인가요?"])

Question: 베트남에서 가장 인기있는 관광지는 어디인가요?
Searching answers... |████████████████████| 100.0%   완료
Answer: 냐짱  score: 9.128152847290039  source: [<WikipediaPage '냐짱'>]
Answer: 축구  score: 8.47502613067627  source: [<WikipediaPage '태국의 문화'>]
Answer: 다낭  score: 8.257136344909668  source: ['https://search.naver.com/search.naver?where=view']
Answer: 하롱 베이  score: 6.724521160125732  source: ['https://search.naver.com/search.naver?where=kin']
 


In [76]:
answers = kqaw.question(["베트남 다낭에서 맛집 한국식당 이름은?"])

Question: 베트남 다낭에서 맛집 한국식당 이름은?
Searching answers... |████████████████████| 100.0%   완료
Answer: 홍대  score: 7.922461986541748  source: ['https://search.naver.com/search.naver?where=kin']
Answer: 배베식당  score: 4.599730014801025  source: ['https://search.naver.com/search.naver?where=view']
 


In [79]:
answers = kqaw.question(["코로나19는 언제 종식되는가?"])

Question: 코로나19는 언제 종식되는가?
Searching answers... |████████████████████| 100.0%   완료
Answer: 2021년  score: 9.98672866821289  source: ['https://search.naver.com/search.naver?where=view']
Answer: 6월 15일  score: 9.97450065612793  source: [<WikipediaPage '뉴질랜드의 코로나19 범유행'>]
Answer: 2021년 9월 중순  score: 6.855931758880615  source: ['https://search.naver.com/search.naver?where=kin']
 


In [80]:
answers = kqaw.question(["아시아나IDT의 대표이사는 누구인가?"])

Question: 아시아나IDT의 대표이사는 누구인가?
Searching answers... |████████████████████| 100.0%   완료
Answer: 박세창  score: 10.907037734985352  source: ['https://search.naver.com/search.naver?where=view']
 


In [82]:
answers = kqaw.question(["아시아나항공은 어디에 매각되는가?"])

Question: 아시아나항공은 어디에 매각되는가?
Searching answers... |████████████████████| 100.0%   완료
Answer: 금호아시아나그룹  score: 8.167031288146973  source: ['https://search.naver.com/search.naver?where=kin']
Answer: 대한항공  score: 6.548533916473389  source: ['https://search.naver.com/search.naver?where=news']
Answer: HDC현대산업개발  score: 6.190684795379639  source: ['https://search.naver.com/search.naver?where=kdic']
 


In [None]:
answers = kqaw.question(["북한에서 실질적인 권력자는 누구인가?",
                           "세계에서 가장 넓은 호수는?",
                           "오로라가 가장 잘 보이는 곳은?",
                           "심장이 죄어오듯이 아프면 의심되는 병은 무엇인가?",
                           "항문에서 피가 나는 병은 무엇인가?",
                           "김재규는 박정희를 왜 죽였는가?",
                           "케네디를 죽인 암살범은 누구인가?",
                           "술 취하지 않는 방법은?",
                           "사람을 사랑해서 생기는 병은?",
                           "부모는 자식을 왜 사랑하는가?",
                           "나의 와이프는 나를 사랑하는가?",
                           "신은 존재 하는가?",
                           "사람의 인생에서 가장 소중한 것은 무엇인가?",
                           "바람난 여자는 다시 돌아올 수 있는가?",
                           "위가 쓰리고 아플 때 어떤 약을 복용해야 하는가?",
                           "눈알이 빠지면 어떻게 되는가?"])

Question: 북한에서 실질적인 권력자는 누구인가?
Searching answers... |████████████████████| 100.0%   완료
Answer: 김일성  score: 37.01249718666077  source: ['https://search.naver.com/search.naver?where=kdic', <WikipediaPage '김일성'>]
Answer: 김정은  score: 11.77053427696228  source: ['https://search.naver.com/search.naver?where=news', 'https://search.naver.com/search.naver?where=view']
 
Question: 세계에서 가장 넓은 호수는?
Searching answers... |████████████████████| 100.0%   완료
Answer: 카스피해  score: 29.744450569152832  source: ['https://search.naver.com/search.naver?where=kin', 'https://search.naver.com/search.naver?where=kdic', <WikipediaPage '호수'>]
Answer: 티티카카 호  score: 11.382185935974121  source: [<WikipediaPage '남아메리카'>]
 
Question: 오로라가 가장 잘 보이는 곳은?
Searching answers... |████████████████████| 100.0%   완료
Answer: 계란형 지대  score: 32.557658195495605  source: ['https://search.naver.com/search.naver?where=kin', 'https://search.naver.com/search.naver?where=kdic']
Answer: 남극및 북극 양극지방  score: 19.048751831054688  source: ['htt

In [None]:
answers = kqaw.question(["아시아나항공은 어디에 매각될 것인가?",
                         "박세창은 어느 회사의 사장인가?"])

Question: 아시아나항공은 어디에 매각될 것인가?
Searching answers... |████████████████████| 100.0%   완료
Answer: HDC현대산업개발  score: 15.071530818939209  source: ['https://search.naver.com/search.naver?where=kdic']
Answer: 한진상사  score: 11.714262008666992  source: ['https://search.naver.com/search.naver?where=kdic']
Answer: 산업은행  score: 5.711127758026123  source: ['https://search.naver.com/search.naver?where=news']
Answer: LG그룹  score: 4.931703090667725  source: [<WikipediaPage '문화방송'>]
 
Question: 박세창은 어느 회사의 사장인가?
Searching answers... |████████████████████| 100.0%   완료
Answer: 아시아나IDT  score: 61.1358003616333  source: ['https://search.naver.com/search.naver?where=view', 'https://search.naver.com/search.naver?where=news']
Answer: 금호타이어  score: 16.81703233718872  source: [<WikipediaPage '금호석유화학'>]
Answer: 금호아시아나그룹  score: 9.963479042053223  source: ['https://search.naver.com/search.naver?where=news']
 


In [None]:
answers = kqaw.question(["아시아나항공 사장의 이름은?",
                         "금호건설의 사장은 누구인가?"])

Question: 아시아나항공 사장의 이름은?
Searching answers... |████████████████████| 100.0%   완료
Answer: 한창수  score: 20.9244384765625  source: ['https://search.naver.com/search.naver?where=view', 'https://search.naver.com/search.naver?where=news']
Answer: 윤영두  score: 11.341591835021973  source: ['https://search.naver.com/search.naver?where=kin']
 
Question: 금호건설의 사장은 누구인가?
Searching answers... |████████████████████| 100.0%   완료
Answer: 박삼구  score: 10.64574146270752  source: ['https://search.naver.com/search.naver?where=kdic']
Answer: 이서형  score: 10.147002220153809  source: ['https://search.naver.com/search.naver?where=view']
 


In [None]:
# 한진중공업 매각 우선협상대상자
answers = kqaw.question(["한진중공업 매각 우선협상대상자는 어디인가?"])

Question: 한진중공업 매각 우선협상대상자는 어디인가?
Searching answers... |████████████████████| 100.0%   완료
Answer: 동부건설  score: 17.863832473754883  source: ['https://search.naver.com/search.naver?where=view']
Answer: 동부건설 컨소  score: 11.456985473632812  source: ['https://search.naver.com/search.naver?where=view']
 


In [None]:
answers = kqaw.question(["아파트 값은 계속 오를 것인가?",
                         "코로나는 언제 종식 될 것인가?"])

Question: 아파트 값은 계속 오를 것인가?
Searching answers... |████████████████████| 100.0%   완료
 
Question: 코로나는 언제 종식 될 것인가?
Searching answers... |████████████████████| 100.0%   완료
Answer: 2020년  score: 7.0133490562438965  source: ['https://search.naver.com/search.naver?where=kdic']
Answer: 2021년 9월 중순  score: 4.102957248687744  source: ['https://search.naver.com/search.naver?where=kin']
 


In [None]:
answers = kqaw.question(["단백질의 화학식 구성은 어떻게 되는가?"])

Question: 단백질의 화학식 구성은 어떻게 되는가?
Searching answers... |████████████████████| 100.0%   완료
Answer: HO2CCH2NH2  score: 10.998641014099121  source: [<WikipediaPage '글라이신'>]
Answer: 글리코실화  score: 10.18725872039795  source: [<WikipediaPage '세린'>]
 


In [None]:
answers = kqaw.question(["우리나라 특허의 권리보장 기간은 몇년인가?"])

Question: 우리나라 특허의 권리보장 기간은 몇년인가?
Searching answers... |████████████████████| 100.0%   완료
Answer: 20년  score: 31.303051948547363  source: ['https://search.naver.com/search.naver?where=view', 'https://search.naver.com/search.naver?where=kin', 'https://search.naver.com/search.naver?where=kdic']
 


In [None]:
answers = kqaw.question(["발열 마른기침 피로감 등의 증상을 보이면 어떤 병이 의심되는가?"])

Question: 발열 마른기침 피로감 등의 증상을 보이면 어떤 병이 의심되는가?
Searching answers... |████████████████████| 100.0%   완료
Answer: 코로나 19  score: 15.271196842193604  source: ['https://search.naver.com/search.naver?where=kin', 'https://search.naver.com/search.naver?where=view']
Answer: 광견병  score: 11.497452735900879  source: ['https://search.naver.com/search.naver?where=kdic']
 


In [None]:
answers = kqaw.question(["흉부통증과 호흡곤란, 쉰목소리, 가끔 피가 썩인 가래도 있습니다. 어떤 병일까요?"])

Question: 흉부통증과 호흡곤란, 쉰목소리, 가끔 피가 썩인 가래도 있습니다. 어떤 병일까요?
Searching answers... |████████████████████| 100.0%   완료
Answer: 폐암  score: 12.317249298095703  source: ['https://search.naver.com/search.naver?where=kin', 'https://search.naver.com/search.naver?where=view']
 


In [None]:
answers = kqaw.question(["똥을 싸고 나면 휴지에 피가 뭍습니다. 의심되는 병은 무엇인가요?"])

Question: 똥을 싸고 나면 휴지에 피가 뭍습니다. 의심되는 병은 무엇인가요?
Searching answers... |████████████████████| 100.0%   완료
Answer: 치열  score: 16.448171615600586  source: ['https://search.naver.com/search.naver?where=kin']
Answer: 궤양성 대장염  score: 9.296835899353027  source: ['https://search.naver.com/search.naver?where=kin']
 


In [None]:
answers = kqaw.question(["신경망 알고리즘의 활성화 함수에는 어떤 것이 있나요?"])

Question: 신경망 알고리즘의 activation function에는 어떤 것이 있나요?
Searching answers... |████████████████████| 100.0%   완료
Answer: 활성화 함수  score: 7.86516809463501  source: ['https://search.naver.com/search.naver?where=view']
Answer: 활성함수  score: 6.889105796813965  source: ['https://search.naver.com/search.naver?where=view']
 


In [None]:
answers = kqaw.question(["현존하는 인공지능 중 가장 성능이 우수한 것은 무엇입니까?"])

Question: 현존하는 인공지능 중 가장 성능이 우수한 것은 무엇입니까?
Searching answers... |████████████████████| 100.0%   완료
Answer: Global Hawk  score: 17.53031873703003  source: ['https://search.naver.com/search.naver?where=kdic', <WikipediaPage '무인 항공기'>]
Answer: 슈퍼 컴퓨터  score: 10.809731483459473  source: ['https://search.naver.com/search.naver?where=kin']
 


In [None]:
answers = kqaw.question(["이세돌을 이긴 것은 무엇입니까?"])

Question: 이세돌을 이긴 것은 무엇입니까?
Searching answers... |████████████████████| 100.0%   완료
Answer: 알파고  score: 60.73257637023926  source: ['https://search.naver.com/search.naver?where=view', <WikipediaPage '인공지능'>, 'https://search.naver.com/search.naver?where=kin', <WikipediaPage '알파고'>]
Answer: 깔끔함  score: 8.735248565673828  source: [<WikipediaPage '인공지능'>]
 


In [None]:
answers = kqaw.question(["피타고라스는 어느 나라 사람인가?"])

Question: 피타고라스는 어느 나라 사람인가?
Searching answers... |████████████████████| 100.0%   완료
Answer: 그리스  score: 40.09191274642944  source: ['https://search.naver.com/search.naver?where=kin', 'https://search.naver.com/search.naver?where=view', 'https://search.naver.com/search.naver?where=news', 'https://search.naver.com/search.naver?where=kdic', <WikipediaPage '수학자'>]
Answer: 이집트  score: 6.741414546966553  source: ['https://search.naver.com/search.naver?where=view']
 


In [None]:
answers = kqaw.question(["아이즈원 멤버 인원수는?"])

Question: 아이즈원 멤버 인원수는?
Searching answers... |████████████████████| 100.0%   완료
Answer: 12명  score: 11.964734077453613  source: ['https://search.naver.com/search.naver?where=news']
Answer: 4명  score: 11.074978828430176  source: ['https://search.naver.com/search.naver?where=kin']
Answer: 8명  score: 10.20055103302002  source: ['https://search.naver.com/search.naver?where=kin']
Answer: 300명  score: 6.7700066566467285  source: ['https://search.naver.com/search.naver?where=kdic']
Answer: 12  score: 6.522476673126221  source: ['https://search.naver.com/search.naver?where=view']
 


In [None]:
answers = kqaw.question(["트와이스 중에 가장 인기 있는 사람은?"])

Question: 트와이스 중에 가장 인기 있는 사람은?
Searching answers... |████████████████████| 100.0%   완료
Answer: 지효  score: 16.386046409606934  source: ['https://search.naver.com/search.naver?where=news', <WikipediaPage '청하 (가수)'>]
Answer: 가는 세월  score: 10.251775741577148  source: ['https://search.naver.com/search.naver?where=kdic']
 


In [None]:
answers = kqaw.question(["벤츠 자동차가 처음 발명된 년도는?"])

Question: 벤츠 자동차가 처음 발명된 년도는?
Searching answers... |████████████████████| 100.0%   완료
Answer: 1886년  score: 33.59215593338013  source: ['https://search.naver.com/search.naver?where=kin', 'https://search.naver.com/search.naver?where=kdic', <WikipediaPage '자동차'>, <WikipediaPage '만하임'>]
Answer: 1883년  score: 11.766658782958984  source: ['https://search.naver.com/search.naver?where=kin']
 


In [None]:
answers = kqaw.question(["교통사고 대비를 위해 들어야 하는 보험은 무엇인가?"])

Question: 교통사고 대비를 위해 들어야 하는 보험은 무엇인가?
Searching answers... |████████████████████| 100.0%   완료
Answer: 자동차보험  score: 10.861504554748535  source: ['https://search.naver.com/search.naver?where=news']
Answer: 운전자보험  score: 10.007701873779297  source: ['https://search.naver.com/search.naver?where=view', 'https://search.naver.com/search.naver?where=kin']
 


In [None]:
answers = kqaw.question(["너의 이름은 무엇이니?"])

In [None]:
answers = kqaw.question(["너는 남자니 여자니?"])

#KorQuad2.0 학습.

https://github.com/huggingface/datasets

https://huggingface.co/transformers/custom_datasets.html#qa-squad



In [None]:

if True:
    from google.colab import drive
    drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
!pip install datasets



In [None]:
from datasets import list_datasets, load_dataset, list_metrics, load_metric, load_from_disk

# Print all the available datasets
print(list_datasets())

# Load a dataset and print the first example in the training set
squad_dataset = load_dataset('squad_kor_v2',cache_dir ='/content/drive/MyDrive/korQuAD2.1/dataset')


['acronym_identification', 'ade_corpus_v2', 'adversarial_qa', 'aeslc', 'afrikaans_ner_corpus', 'ag_news', 'ai2_arc', 'air_dialogue', 'ajgt_twitter_ar', 'allegro_reviews', 'allocine', 'alt', 'amazon_polarity', 'amazon_reviews_multi', 'amazon_us_reviews', 'ambig_qa', 'amttl', 'anli', 'app_reviews', 'aqua_rat', 'aquamuse', 'ar_cov19', 'ar_res_reviews', 'ar_sarcasm', 'arabic_billion_words', 'arabic_pos_dialect', 'arabic_speech_corpus', 'arcd', 'arsentd_lev', 'art', 'arxiv_dataset', 'ascent_kb', 'aslg_pc12', 'asnq', 'asset', 'assin', 'assin2', 'atomic', 'autshumato', 'babi_qa', 'banking77', 'bbaw_egyptian', 'bbc_hindi_nli', 'bc2gm_corpus', 'best2009', 'bianet', 'bible_para', 'big_patent', 'billsum', 'bing_coronavirus_query_set', 'biomrc', 'blended_skill_talk', 'blimp', 'blog_authorship_corpus', 'bn_hate_speech', 'bookcorpus', 'bookcorpusopen', 'boolq', 'bprec', 'break_data', 'brwac', 'bsd_ja_en', 'bswac', 'c3', 'c4', 'cail2018', 'caner', 'capes', 'catalonia_independence', 'cawac', 'cbt', 'c

Reusing dataset squad_kor_v2 (/content/drive/MyDrive/korQuAD2.1/dataset/squad_kor_v2/squad_kor_v2/2.1.0/8e4ee4e5757761cf13f00b2d4e4cef2e842c0ea3c57050fec9fafc8fec60e128)


ValueError: ignored

In [None]:
import json
from pathlib import Path

def read_squad(path):
    path = Path(path)
    with open(path, 'rb') as f:
        squad_dict = json.load(f)

    contexts = []
    questions = []
    answers = []
    for group in squad_dict['data']:
        for passage in group['paragraphs']:
            context = passage['context']
            for qa in passage['qas']:
                question = qa['question']
                for answer in qa['answers']:
                    contexts.append(context)
                    questions.append(question)
                    answers.append(answer)

    return contexts, questions, answers

train_contexts, train_questions, train_answers = read_squad('/content/drive/MyDrive/korQuAD2.1/dataset/squad_kor_v2/squad_kor_v2/2.1.0/8e4ee4e5757761cf13f00b2d4e4cef2e842c0ea3c57050fec9fafc8fec60e128/squad_kor_v2-train.arrow')
val_contexts, val_questions, val_answers = read_squad('/content/drive/MyDrive/korQuAD2.1/dataset/squad_kor_v2/squad_kor_v2/2.1.0/8e4ee4e5757761cf13f00b2d4e4cef2e842c0ea3c57050fec9fafc8fec60e128/squad_kor_v2-validation.arrow')

In [None]:
squad_dataset['train']

Dataset({
    features: ['id', 'title', 'context', 'question', 'answer', 'url', 'raw_html'],
    num_rows: 83486
})

In [None]:
!pip install transformers

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/d8/b2/57495b5309f09fa501866e225c84532d1fd89536ea62406b2181933fb418/transformers-4.5.1-py3-none-any.whl (2.1MB)
[K     |████████████████████████████████| 2.1MB 4.2MB/s 
[?25hCollecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/75/ee/67241dc87f266093c533a2d4d3d69438e57d7a90abb216fa076e7d475d4a/sacremoses-0.0.45-py3-none-any.whl (895kB)
[K     |████████████████████████████████| 901kB 19.5MB/s 
Collecting tokenizers<0.11,>=0.10.1
[?25l  Downloading https://files.pythonhosted.org/packages/ae/04/5b870f26a858552025a62f1649c20d29d2672c02ff3c3fb4c688ca46467a/tokenizers-0.10.2-cp37-cp37m-manylinux2010_x86_64.whl (3.3MB)
[K     |████████████████████████████████| 3.3MB 22.1MB/s 
Installing collected packages: sacremoses, tokenizers, transformers
Successfully installed sacremoses-0.0.45 tokenizers-0.10.2 transformers-4.5.1


In [None]:
train_contexts = squad_dataset['train']['context']
train_questions = squad_dataset['train']['question']
train_answers = squad_dataset['train']['answer']

In [None]:
val_contexts = squad_dataset['validation']['context']
val_questions = squad_dataset['validation']['question']
val_answers = squad_dataset['validation']['answer']

In [None]:
def add_end_idx(answers, contexts):
    for answer, context in zip(answers, contexts):
        gold_text = answer['text']
        start_idx = answer['answer_start']
        end_idx = start_idx + len(gold_text)

        # sometimes squad answers are off by a character or two – fix this
        if context[start_idx:end_idx] == gold_text:
            answer['answer_end'] = end_idx
        elif context[start_idx-1:end_idx-1] == gold_text:
            answer['answer_start'] = start_idx - 1
            answer['answer_end'] = end_idx - 1     # When the gold label is off by one character
        elif context[start_idx-2:end_idx-2] == gold_text:
            answer['answer_start'] = start_idx - 2
            answer['answer_end'] = end_idx - 2     # When the gold label is off by two characters

add_end_idx(train_answers, train_contexts)
add_end_idx(val_answers, val_contexts)

In [None]:
from transformers import DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained('monologg/kobert')

train_encodings = tokenizer(train_contexts, train_questions, truncation=True, padding=True)
val_encodings = tokenizer(val_contexts, val_questions, truncation=True, padding=True)

In [None]:
def add_token_positions(encodings, answers):
    start_positions = []
    end_positions = []
    for i in range(len(answers)):
        start_positions.append(encodings.char_to_token(i, answers[i]['answer_start']))
        end_positions.append(encodings.char_to_token(i, answers[i]['answer_end'] - 1))

        # if start position is None, the answer passage has been truncated
        if start_positions[-1] is None:
            start_positions[-1] = tokenizer.model_max_length
        if end_positions[-1] is None:
            end_positions[-1] = tokenizer.model_max_length

    encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

add_token_positions(train_encodings, train_answers)
add_token_positions(val_encodings, val_answers)