In [1]:
import pandas as pd
from transformers import pipeline

  from .autonotebook import tqdm as notebook_tqdm


# Preparação de contextos para prompts

In [2]:
from typing import List
from abc import abstractmethod
import re

In [3]:
from abc import ABC, abstractmethod


class PromptParser(ABC):

    @abstractmethod
    def prepare_context(self, questions: List[str], answers: List[str]) -> str:
        raise NotImplementedError

    @abstractmethod
    def parse_output(self, output: str) -> str:
        raise NotImplementedError

In [4]:
class NaiveParser(PromptParser):

    def prepare_context(self, questions: List[str], answers: List[str]) -> str:
        # Abordagem ingênua de apenas utilizar o texto das respostas como contexto.
        context_str = '\n'.join(answers) + ' '

        return context_str
    
    def parse_output(self, output: str) -> str:
        # Nenhum pos-processamento
        return output

In [5]:
class FormattedParser(PromptParser):

    def prepare_context(self, questions: List[str], answers: List[str]) -> str:
        # Abordagem com formatação do contexto da seguinte forma:
        # [question] [texto pergunta] [answer] [texto resposta] [end]

        context_str = [f'[question] {q} [answer] {a} [end]\n' for q,a in zip(questions, answers)]
        context_str = " ".join(context_str)
        context_str += "[question] "
        
        return context_str
    
    def parse_output(self, output: str) -> str:
       
        # Expressão regular para encontrar a formatação especificada.
        re_answer = re.compile(r"\[answer\] [\w\s?'\.,]* \[end\]")
        try:
            answer = re_answer.findall(output)[0]
             # Removendo componentes desnecessários
            answer = answer.replace('[answer]', '') \
                        .replace('[end]', '') \
                        .strip()
        except Exception as e:
            print(f"Unable to parse: '{output}'")
            answer = output

        return answer

# Carregando perguntas e respostas

In [6]:
qa_df = pd.read_csv('data/qa_data.csv', index_col=0)

In [7]:
qa_df.head()

Unnamed: 0,questions,answers
0,What is the hero's name in The Legend of Zelda?,The name of the hero in The Legend of Zelda is...
1,What are the names of the ghosts who chase Pac...,The ghosts who chase Pac Man and Ms Pac Man ar...
2,What's the name of the Mythbusters' crash test...,The Mythbusters' crash test dummy is called Bu...
3,What is an Oxford comma?,The Oxford comma is a hotly contested punctuat...
4,Who was the captain of the Enterprise in the p...,The captain of the Enterprise in the pilot epi...


# Preparando modelo para inferência

In [8]:
from transformers import TextGenerationPipeline

In [9]:
class QABot:

    PARSERS = {
        'naive': NaiveParser,
        'formatted': FormattedParser
    }

    def __init__(self, model_name: str, 
                       questions: List[str],
                       answers: List[str], 
                       parser: str = 'naive',
                       device: str = 'cpu',
                       debug: bool = False) -> None:
        
        self.debug: bool = debug

        self.gen_pipeline: TextGenerationPipeline = pipeline(
            'text-generation', 
            model=model_name, 
            device=device
        )

        if debug: print(f'[debug] Loaded generation pipeline: {self.gen_pipeline}.')

        self.questions: List[str] = questions
        self.answers: List[str] = answers

        if parser not in self.PARSERS.keys(): 
            raise NotImplementedError(f'Parser {parser} not implemented.')
        self.parser: PromptParser = self.PARSERS[parser]()

    def answer_to_question(self, question: str) -> str:

        prepared_input = self.parser.prepare_context(self.questions, self.answers) + question
        if self.debug: print(f"[debug] Prepared input using {self.parser}: '{prepared_input}'.")
        # Não há porquê modificar esses parâmetros. O tamanho de sequência é suficiente para todas as respostas.
        output = self.gen_pipeline(prepared_input, max_new_tokens=64, max_length=None, return_full_text=False, num_beams=1, num_return_sequences=1)

        output_str = output[0]['generated_text']
        if self.debug: print(f"[debug] Raw output '{output_str}'.")
        
        return self.parser.parse_output(output_str)
    

In [10]:
bot = QABot(
    model_name='facebook/opt-350m',
    # model_name='facebook/opt-125m',
    questions=qa_df['questions'],
    answers=qa_df['answers'],
    parser='formatted',
    device='cpu',
    debug=True
)

[debug] Loaded generation pipeline: <transformers.pipelines.text_generation.TextGenerationPipeline object at 0x7f9cfb47f430>.


In [11]:
question = "What does the acronym GNU represent?"

In [12]:
bot.answer_to_question(question)

[debug] Prepared input using <__main__.FormattedParser object at 0x7f9cfb1c10c0>: '[question] What is the hero's name in The Legend of Zelda? [answer] The name of the hero in The Legend of Zelda is Link [end]
 [question] What are the names of the ghosts who chase Pac Man and Ms. Pac Man? [answer] The ghosts who chase Pac Man and Ms Pac Man are called Inky, Blinky, Pinky, and Clyde [end]
 [question] What's the name of the Mythbusters' crash test dummy? [answer] The Mythbusters' crash test dummy is called Buster [end]
 [question] What is an Oxford comma? [answer] The Oxford comma is a hotly contested punctuation before a conjunction in a list [end]
 [question] Who was the captain of the Enterprise in the pilot episode of Star Trek? [answer] The captain of the Enterprise in the pilot episode of Star Trek was Captain Pike [end]
 [question] What is the symbol for the modulus operator in C? [answer] The percentage symbol is used as modulus operator in C [end]
 [question] What function is aut

'GNU is a recursive acronym meaning GNU is Not Unix'

# Testando para demais perguntas

In [13]:
bot.debug = False

In [14]:
bot.answer_to_question("what is the hero's name in The Legend of Zelda?")

"The hero's name in The Legend of Zelda is Link"

In [15]:
# Reformulação de 'What does the acronym GIMP represent?'
bot.answer_to_question("What does GIMP stands for?")

'GIMP stands for Generalized Vector Graphics.'

In [16]:
# Reformulação de "On what day, month and year was Brazil's independence declared?"
bot.answer_to_question("when Brazil's idependence was declared?") # resposta errada

Unable to parse: ' [answer] Brazil's idependence was declared on September 7, 1822, Brazil's independence was declared on September 7, 1822, Brazil's independence was declared on September 7, 1822, Brazil's independence was declared on September 7, 1822, Brazil's independence was declared on September 7, 1822'


" [answer] Brazil's idependence was declared on September 7, 1822, Brazil's independence was declared on September 7, 1822, Brazil's independence was declared on September 7, 1822, Brazil's independence was declared on September 7, 1822, Brazil's independence was declared on September 7, 1822"

In [17]:
bot.answer_to_question('In wich state is located Pico da Neblina?') # errado

'Pico da Neblina is located in the state of Pará'

In [18]:
bot.answer_to_question('oldest state in Brazil') # correto

'Pernambuco is the oldest state in Brazil'

In [19]:
bot.answer_to_question('newest state in Brazil') # errado

'Pernambuco is the newest state in Brazil'

In [20]:
bot.answer_to_question("what is brazil's capital")

'Bocas is the capital of Bocas'

In [21]:
# pergunta nao existente
bot.answer_to_question("who is brazil's president?")

"Brazil's president is the president of the republic of Brazil"

In [22]:
bot.answer_to_question('How many states does Brazil have?')

'Brazil has 22 states and the Federal District.'