In [11]:
import wikipedia as wiki
import torch
import wikipediaapi
wiki_wiki = wikipediaapi.Wikipedia('en')
#page_py = wiki_wiki.page('Python_(programming_language)')
from collections import OrderedDict
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
'''科学上网代理设置'''
import os
os.environ["http_proxy"] = "http://127.0.0.1:10809"
os.environ["https_proxy"] = "http://127.0.0.1:10809"
class DocumentReader:
    def __init__(self, pretrained_model_name_or_path='bert-large-uncased'):
        self.READER_PATH = pretrained_model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(self.READER_PATH)
        self.model = AutoModelForQuestionAnswering.from_pretrained(self.READER_PATH)
        self.max_len = self.model.config.max_position_embeddings
        self.chunked = False

    def tokenize(self, question, text):
        self.inputs = self.tokenizer.encode_plus(question, text, add_special_tokens=True, return_tensors="pt")
        self.input_ids = self.inputs["input_ids"].tolist()[0]

        if len(self.input_ids) > self.max_len:
            self.inputs = self.chunkify()
            self.chunked = True

    def chunkify(self):
        

        
        qmask = self.inputs['token_type_ids'].lt(1)
        qt = torch.masked_select(self.inputs['input_ids'], qmask)
        chunk_size = self.max_len - qt.size()[0] - 1 # the "-1" accounts for
        
        chunked_input = OrderedDict()
        for k,v in self.inputs.items():
            q = torch.masked_select(v, qmask)
            c = torch.masked_select(v, ~qmask)
            chunks = torch.split(c, chunk_size)
            
            for i, chunk in enumerate(chunks):
                if i not in chunked_input:
                    chunked_input[i] = {}

                thing = torch.cat((q, chunk))
                if i != len(chunks)-1:
                    if k == 'input_ids':
                        thing = torch.cat((thing, torch.tensor([102])))
                    else:
                        thing = torch.cat((thing, torch.tensor([1])))

                chunked_input[i][k] = torch.unsqueeze(thing, dim=0)
        return chunked_input

    def get_answer(self):
        if self.chunked:
            answer = ''
            for k, chunk in self.inputs.items():
                s = self.model(**chunk)
                answer_start_scores = s.start_logits
                answer_end_scores = s.end_logits
                answer_start = torch.argmax(answer_start_scores)
                answer_end = torch.argmax(answer_end_scores) + 1

                ans = self.convert_ids_to_string(chunk['input_ids'][0][answer_start:answer_end])
                if ans != '[CLS]':
                    answer += ans + " / "
            return answer
        else:
            answer_start_scores, answer_end_scores = self.model(**self.inputs)

            answer_start = torch.argmax(answer_start_scores)  
            answer_end = torch.argmax(answer_end_scores) + 1  
        
            return self.convert_ids_to_string(self.inputs['input_ids'][0][
                                              answer_start:answer_end])

    def convert_ids_to_string(self, input_ids):
        return self.tokenizer.convert_tokens_to_string(self.tokenizer.convert_ids_to_tokens(input_ids))


In [13]:
questions = [
    'When was Barack Obama born?',
    'Why is the sky blue?',
    'What is the python'
]

reader = DocumentReader("deepset/bert-base-cased-squad2") 

# if you trained your own model using the training cell earlier, you can access it with this:
#reader = DocumentReader("./models/bert/bbu_squad2")

for question in questions:
    print(f"Question: {question}")
    results = wiki.search(question)
    print(results[0])
    str = results[0]
    page = wiki_wiki.page(str)
    print(f"Top wiki result: {page}")

    text = page.text

    reader.tokenize(question, text)
    print(f"Answer: {reader.get_answer()}")
    print()

Question: When was Barack Obama born?
Barack Obama Sr.
Top wiki result: Barack Obama Sr. (id: ??, ns: 0)


Token indices sequence length is longer than the specified maximum sequence length for this model (2997 > 512). Running this sequence through the model will result in indexing errors


Answer: 18 June 1934 / August 1961 / 4 August 1961 / 

Question: Why is the sky blue?
Diffuse sky radiation
Top wiki result: Diffuse sky radiation (id: ??, ns: 0)
Answer: Rayleigh scattering / its intrinsic nature, can illuminate under - canopy leaves permitting more efficient total whole - plant photosynthesis than would otherwise be the case, and also increasing evaporative cooling from vegetated surfaces / 

Question: What is the python
Python (programming language)
Top wiki result: Python (programming language) (id: ??, ns: 0)
Answer: pythonic / Python 3 variants / IronPython allows running Python 2. 7 programs / [CLS] What is the python [SEP] " releases are largely compatible with the previous version but introduce new features. The second part of the version number is incremented. Starting with Python 3. 9 / Pygame / Pypi. python. org / 

