In [1]:
import pandas as pd
from transformers import pipeline

  from .autonotebook import tqdm as notebook_tqdm


# Preparação de contextos para prompts

In [2]:
from typing import List
from abc import abstractmethod
import re

In [3]:
from abc import ABC, abstractmethod


class PromptParser(ABC):
    """Strategy interface for turning known Q/A pairs into a prompt and back.

    Concrete subclasses decide how the stored question/answer pairs are laid
    out as prompt context, and how an answer is pulled out of the text that
    the language model generates.
    """

    @abstractmethod
    def prepare_context(self, questions: List[str], answers: List[str]) -> str:
        """Return the prompt prefix built from parallel questions/answers."""
        raise NotImplementedError()

    @abstractmethod
    def parse_output(self, output: str) -> str:
        """Return the answer text extracted from the raw generated ``output``."""
        raise NotImplementedError()

In [4]:
class NaiveParser(PromptParser):
    """Naive prompt strategy: the context is only the answers, one per line.

    The questions are ignored when building the context, and the model output
    is handed back untouched.
    """

    def prepare_context(self, questions: List[str], answers: List[str]) -> str:
        # Only the answer texts are used; the trailing space separates the
        # context from the question the caller appends right after it.
        joined_answers = '\n'.join(answers)
        return f'{joined_answers} '

    def parse_output(self, output: str) -> str:
        # No post-processing: the raw generated text is the answer.
        return output

In [5]:
class FormattedParser(PromptParser):
    """Few-shot prompt strategy with an explicit question/answer layout.

    Every known pair is rendered as ``Question: [question] Answer: [answer] [end]``
    and the context ends with an open ``Question: `` slot, so the caller can
    append the new question and let the model generate ``Answer: ... [end]``.
    """

    def prepare_context(self, questions: List[str], answers: List[str]) -> str:
        """Render the known pairs in the ``Question/Answer/[end]`` layout.

        Args:
            questions: Known questions, paired positionally with ``answers``.
            answers: Known answers, paired positionally with ``questions``.

        Returns:
            The rendered examples followed by an open ``"Question: "`` slot.
        """
        examples = [f'Question: {q} Answer: {a} [end]\n' for q, a in zip(questions, answers)]
        # Joining with ' ' puts a leading space on every line after the first;
        # kept as-is so the prompt the model sees does not change.
        return " ".join(examples) + "Question: "

    def parse_output(self, output: str) -> str:
        """Extract the first answer from the generated continuation.

        The continuation is expected to look like
        ``[rest of question] Answer: [answer] [end] ...``. The answer is the
        text after the first ``Answer:``, cut at the first ``[end]`` or
        ``Question:`` marker (whichever comes first), with whitespace stripped.

        This replaces a strict character-class regex that (a) skipped ahead to
        the answer of a model-invented follow-up question whenever the real
        answer contained punctuation such as ``-``, ``:`` or ``!``, and
        (b) raised ``IndexError`` when generation stopped before ``[end]``.

        Args:
            output: Text generated by the model (prompt not included).

        Returns:
            The extracted answer. If no ``Answer:`` marker is present, the
            output up to the first marker, stripped.
        """
        _, found, answer = output.partition('Answer:')
        if not found:
            # The model skipped the marker; best effort on the raw text.
            answer = output
        # Keep only the first answer: stop at its [end] marker or at the start
        # of a new question the model went on to generate.
        for stop_marker in ('[end]', 'Question:'):
            answer = answer.split(stop_marker, 1)[0]
        return answer.strip()

# Carregando perguntas e respostas

In [6]:
# Load the known question/answer pairs. The CSV's first column has no header in
# the file (it shows up as 'Unnamed: 0') and is used as the row index.
qa_df = pd.read_csv('data/qa_data.csv', index_col=0)
# QABot is fed these two columns; fail here rather than with a KeyError later.
assert {'questions', 'answers'} <= set(qa_df.columns), \
    "data/qa_data.csv must contain 'questions' and 'answers' columns"

In [7]:
# Preview the first rows of the Q/A table.
qa_df.head()

Unnamed: 0,questions,answers
0,What is the hero's name in The Legend of Zelda?,"Despite most people's believes, he's called Link"
1,What are the names of the ghosts who chase Pac...,"Inky, Blinky, Pinky, and Clyde"
2,What's the name of the Mythbusters' crash test...,The Mythbusters' crash test dummy is called Bu...
3,What is an Oxford comma?,The hotly contested punctuation before a conju...
4,Who was the captain of the Enterprise in the p...,The captain of the Enterprise in the pilot epi...


# Preparando modelo para inferência

In [8]:
from transformers import TextGenerationPipeline

In [9]:
class QABot:
    """Question-answering bot built on a Hugging Face text-generation pipeline.

    The stored question/answer pairs are rendered into a prompt prefix by a
    ``PromptParser``; each new question is appended to that prefix, the model
    generates a continuation, and the same parser extracts the answer from it.
    """

    # Prompt strategies selectable through the ``parser`` constructor argument.
    PARSERS = {
        'naive': NaiveParser,
        'formatted': FormattedParser
    }

    def __init__(
        self,
        model_name: str,
        questions: List[str],
        answers: List[str],
        parser: str = 'naive',
        device: str = 'cpu',
        debug: bool = False,
    ) -> None:
        """Select the prompt strategy and load the generation model.

        Args:
            model_name: Model identifier passed to ``pipeline('text-generation', ...)``.
            questions: Known questions, paired positionally with ``answers``.
            answers: Known answers, paired positionally with ``questions``.
            parser: Key of ``PARSERS`` choosing the prompt strategy.
            device: Device forwarded to the pipeline (default ``'cpu'``).
            debug: When true, print the loaded pipeline, each prompt and each
                raw model output.

        Raises:
            NotImplementedError: If ``parser`` is not a key of ``PARSERS``.
        """
        # Validate the cheap argument first: an unknown parser name should fail
        # immediately, not after the (slow) model download/load below.
        if parser not in self.PARSERS:
            raise NotImplementedError(f'Parser {parser} not implemented.')

        self.debug: bool = debug

        self.gen_pipeline: TextGenerationPipeline = pipeline(
            'text-generation',
            model=model_name,
            device=device
        )
        if debug:
            print(f'[debug] Loaded generation pipeline: {self.gen_pipeline}.')

        self.questions: List[str] = questions
        self.answers: List[str] = answers
        self.parser: PromptParser = self.PARSERS[parser]()

    def answer_to_question(self, question: str) -> str:
        """Answer ``question`` using the stored Q/A pairs as prompt context.

        Args:
            question: Question text, appended verbatim after the context.

        Returns:
            The answer extracted by the configured parser from the generated text.
        """
        prepared_input = self.parser.prepare_context(self.questions, self.answers) + question
        if self.debug:
            print(f"[debug] Prepared input using {self.parser}: '{prepared_input}'.")
        # Author's note: no need to tune these; 32 new tokens are enough for
        # every answer. return_full_text=False hands only the continuation (not
        # the prompt) to the parser.
        output = self.gen_pipeline(
            prepared_input,
            max_new_tokens=32,
            max_length=None,
            return_full_text=False,
            num_beams=1,
            num_return_sequences=1,
        )

        output_str = output[0]['generated_text']
        if self.debug:
            print(f"[debug] Raw output '{output_str}'.")

        return self.parser.parse_output(output_str)

In [10]:
# Build the bot over every stored Q/A pair with the 'formatted' prompt layout.
# `.tolist()` hands QABot plain lists, matching its `List[str]` annotations;
# debug=True prints the pipeline, each prepared prompt and each raw output.
bot = QABot(
    model_name='facebook/opt-350m',
    questions=qa_df['questions'].tolist(),
    answers=qa_df['answers'].tolist(),
    parser='formatted',
    device='cpu',
    debug=True
)

[debug] Loaded generation pipeline: <transformers.pipelines.text_generation.TextGenerationPipeline object at 0x7fb49fdf33d0>.


In [11]:
# Example query; `bot` was built with debug=True, so answering it prints the
# prepared prompt and the raw model output.
question = "What does the acronym GNU represent?"

In [12]:
# One generation; while bot.debug is set it also prints the prompt and raw output.
bot.answer_to_question(question)

[debug] Prepared input using <__main__.FormattedParser object at 0x7fb49fd2d0f0>: 'Question: What is the hero's name in The Legend of Zelda? Answer: Despite most people's believes, he's called Link [end]
 Question: What are the names of the ghosts who chase Pac Man and Ms. Pac Man? Answer: Inky, Blinky, Pinky, and Clyde [end]
 Question: What's the name of the Mythbusters' crash test dummy? Answer: The Mythbusters' crash test dummy is called Buster [end]
 Question: What is an Oxford comma? Answer: The hotly contested punctuation before a conjunction in a list [end]
 Question: Who was the captain of the Enterprise in the pilot episode of Star Trek? Answer: The captain of the Enterprise in the pilot episode was Captain Pike [end]
 Question: What is the symbol for the modulus operator in C? Answer: The percentage symbol is used as modulus operator in C [end]
 Question: What function is automatically called at the beginning of a C++ program? Answer: The main function [end]
 Question: Which 

'GNU is a recursive acronym meaning GNU is Not Unix'

# Testando para demais perguntas

In [17]:
# Silence the [debug] prints for the remaining queries.
bot.debug = False

In [14]:
# Same question as the first row of the Q/A data, but starting with a lowercase 'what'.
bot.answer_to_question("what is the hero's name in The Legend of Zelda?")

"The hero's name in The Legend of Zelda is Link"

In [18]:
# Rephrasing of 'What does the acronym GIMP represent?'
bot.answer_to_question("What does GIMP stands for?")

'GIMP stands for Generalized Vector Graphics.'

In [19]:
# Rephrasing of "On what day, month and year was Brazil's independence declared?"
bot.answer_to_question("when Brazil's idependence was declared?") # wrong answer

'Brazil was declared to be a republic in 1848.'

In [22]:
bot.answer_to_question('In wich state is located Pico da Neblina?') # wrong

'Pernambuco'

In [23]:
bot.answer_to_question('oldest state in Brazil') # correct

'Pernambuco'

In [25]:
bot.answer_to_question('newest Brazilian state') # wrong

'Brasília'

In [27]:
# Informal phrasing without a question mark.
bot.answer_to_question("what is brazil's capital")

'Brasília'

In [29]:
# non-existent question (presumably not covered by the Q/A data)
bot.answer_to_question("who is brazil's president?")

'BÚNICO'

In [34]:
# Count question; unlike several cells above, the result is not labeled correct/wrong.
bot.answer_to_question('How many states does Brazil have?')

'Brazil has 22 states and the Federal District.'