Add node to use OpenAI's GPT-3 for QA #2605

Merged
merged 17 commits on Jul 8, 2022
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
@@ -30,6 +30,7 @@ env:
--ignore=test/nodes/test_connector.py
--ignore=test/nodes/test_summarizer_translation.py
--ignore=test/nodes/test_summarizer.py
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

jobs:

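The new `OPENAI_API_KEY` entry maps the repository secret into the test environment so the integration test added below can reach the live API. A minimal sketch of the consuming pattern (the module-level guard shown here is illustrative; the actual fixture and test appear at the end of this diff):

```python
import os

import pytest

# CI exposes the repository secret as OPENAI_API_KEY; tests read it from the
# environment and skip themselves when it is absent (e.g. on forks without secrets).
api_key = os.environ.get("OPENAI_API_KEY", "")
if not api_key:
    pytest.skip("No API key provided in environment variables.", allow_module_level=True)
```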
83 changes: 83 additions & 0 deletions haystack/json-schemas/haystack-pipeline-master.schema.json
@@ -127,6 +127,9 @@
{
"$ref": "#/definitions/MultihopEmbeddingRetrieverComponent"
},
{
"$ref": "#/definitions/OpenAIAnswerGeneratorComponent"
},
{
"$ref": "#/definitions/PDFToTextConverterComponent"
},
@@ -3228,6 +3231,86 @@
],
"additionalProperties": false
},
"OpenAIAnswerGeneratorComponent": {
"type": "object",
"properties": {
"name": {
"title": "Name",
"description": "Custom name for the component. Helpful for visualization and debugging.",
"type": "string"
},
"type": {
"title": "Type",
"description": "Haystack Class name for the component.",
"type": "string",
"const": "OpenAIAnswerGenerator"
},
"params": {
"title": "Parameters",
"type": "object",
"properties": {
"api_key": {
"title": "Api Key",
"type": "string"
},
"model": {
"title": "Model",
"default": "text-curie-001",
"type": "string"
},
"max_tokens": {
"title": "Max Tokens",
"default": 7,
"type": "integer"
},
"top_k": {
"title": "Top K",
"default": 5,
"type": "integer"
},
"temperature": {
"title": "Temperature",
"default": 0,
"type": "integer"
},
"presence_penalty": {
"title": "Presence Penalty",
"default": -2.0,
"type": "number"
},
"frequency_penalty": {
"title": "Frequency Penalty",
"default": -2.0,
"type": "number"
},
"examples_context": {
"title": "Examples Context",
"type": "string"
},
"examples": {
"title": "Examples",
"type": "array",
"items": {}
},
"stop_words": {
"title": "Stop Words",
"type": "array",
"items": {}
}
},
"required": [
"api_key"
],
"additionalProperties": false,
"description": "Each parameter can reference other components defined in the same YAML file."
}
},
"required": [
"type",
"name"
],
"additionalProperties": false
},
"PDFToTextConverterComponent": {
"type": "object",
"properties": {
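The schema entry above is what validates an `OpenAIAnswerGenerator` component inside a pipeline YAML. A minimal sketch of checking a component definition against it with the `jsonschema` package (the component values are placeholders, and validating the sub-schema directly is an illustrative shortcut):

```python
import json

from jsonschema import validate  # pip install jsonschema

# Load the master pipeline schema extended by this PR.
with open("haystack/json-schemas/haystack-pipeline-master.schema.json") as f:
    schema = json.load(f)

# A component dict as it would appear under `components:` in a pipeline YAML.
component = {
    "name": "Generator",
    "type": "OpenAIAnswerGenerator",
    "params": {"api_key": "sk-placeholder", "model": "text-curie-001", "top_k": 3},
}

# Raises jsonschema.ValidationError if, for example, `api_key` is missing or an
# unknown parameter is passed (additionalProperties is false in the definition).
validate(instance=component, schema=schema["definitions"]["OpenAIAnswerGeneratorComponent"])
```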
2 changes: 1 addition & 1 deletion haystack/nodes/__init__.py
@@ -2,7 +2,7 @@

from haystack.nodes.base import BaseComponent

from haystack.nodes.answer_generator import BaseGenerator, RAGenerator, Seq2SeqGenerator
from haystack.nodes.answer_generator import BaseGenerator, RAGenerator, Seq2SeqGenerator, OpenAIAnswerGenerator
from haystack.nodes.document_classifier import BaseDocumentClassifier, TransformersDocumentClassifier
from haystack.nodes.evaluator import EvalDocuments, EvalAnswers
from haystack.nodes.extractor import EntityExtractor, simplify_ner_for_qa
1 change: 1 addition & 0 deletions haystack/nodes/answer_generator/__init__.py
@@ -1,2 +1,3 @@
from haystack.nodes.answer_generator.base import BaseGenerator
from haystack.nodes.answer_generator.transformers import RAGenerator, Seq2SeqGenerator
from haystack.nodes.answer_generator.openai import OpenAIAnswerGenerator
189 changes: 189 additions & 0 deletions haystack/nodes/answer_generator/openai.py
@@ -0,0 +1,189 @@
from typing import Optional, List, Tuple
import json
import logging
import requests

from transformers import GPT2TokenizerFast

from haystack.nodes.answer_generator import BaseGenerator
from haystack import Document


logger = logging.getLogger(__name__)


class OpenAIAnswerGenerator(BaseGenerator):
"""
Uses the GPT-3 models from the OpenAI API to generate answers based on supplied documents (e.g. from any retriever
in Haystack).

    To use this node, you need an API key from an active OpenAI account (you can sign up for one
    [here](https://openai.com/api/)).
"""

def __init__(
self,
api_key: str,
model: str = "text-curie-001",
max_tokens: int = 7,
top_k: int = 5,
temperature: int = 0,
presence_penalty: float = -2.0,
frequency_penalty: float = -2.0,
examples_context: Optional[str] = None,
examples: Optional[List] = None,
stop_words: Optional[List] = None,
):

"""
:param api_key: Your API key from OpenAI
        :param model: ID of the engine to use for generating the answer. You can select one of `"text-ada-001"`,
                      `"text-babbage-001"`, `"text-curie-001"`, or `"text-davinci-002"`
                      (listed from least to most capable and from cheapest to most expensive). Refer to the
                      [OpenAI Documentation](https://beta.openai.com/docs/models/gpt-3) for more information about the
                      models.
:param max_tokens: The maximum number of tokens allowed for the generated answer.
:param top_k: Number of generated answers.
        :param temperature: The sampling temperature to use. Higher values make the model take more risks;
                            a value of 0 (argmax sampling) works best for scenarios with a well-defined answer.
:param presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear
in the text so far, increasing the model's likelihood to talk about new topics.
:param frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing
frequency in the text so far, decreasing the model's likelihood to repeat the same line
verbatim.
:param examples_context: A text snippet containing the contextual information used to generate the answers for
the examples you provide.
                                 If not supplied, the default from the OpenAI docs is used:
"In 2017, U.S. life expectancy was 78.6 years."
:param examples: List of (question, answer) pairs that will help steer the model towards the tone and answer
format you'd like. We recommend adding 2 to 3 examples.
                         If not supplied, the default from the OpenAI docs is used:
[["What is human life expectancy in the United States?", "78 years."]]
:param stop_words: Up to 4 sequences where the API will stop generating further tokens. The returned text will
not contain the stop sequence.
                           If not supplied, the default from the OpenAI docs is used: ["\n", "<|endoftext|>"]
"""
super().__init__()
if not examples_context:
examples_context = "In 2017, U.S. life expectancy was 78.6 years."
if not examples:
examples = [["What is human life expectancy in the United States?", "78 years."]]
if not stop_words:
stop_words = ["\n", "<|endoftext|>"]

self.model = model
self.max_tokens = max_tokens
self.top_k = top_k
self.temperature = temperature
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.examples_context = examples_context
self.examples = examples
self.api_key = api_key
self.stop_words = stop_words
self._tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

if "davinci" in self.model:
self.MAX_TOKENS_LIMIT = 4000
else:
self.MAX_TOKENS_LIMIT = 2048

def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None):
"""
        Use the loaded QA model to generate answers for a query based on the supplied list of Documents.

Returns dictionaries containing answers.
Be aware that OpenAI doesn't return scores for those answers.

Example:
```python
|{
| 'query': 'Who is the father of Arya Stark?',
| 'answers':[Answer(
| 'answer': 'Eddard,',
| 'score': None,
| ),...
| ]
|}
```

:param query: Query string
:param documents: List of Document in which to search for the answer
:param top_k: The maximum number of answers to return
:return: Dict containing query and answers
"""
if top_k is None:
top_k = self.top_k

# convert input to OpenAI format
prompt, input_docs = self._build_prompt(query=query, documents=documents)

# get answers from OpenAI API
url = "https://api.openai.com/v1/completions"

payload = {
"model": self.model,
"prompt": prompt,
"max_tokens": self.max_tokens,
"stop": self.stop_words,
"n": top_k,
"temperature": self.temperature,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
}

headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
response = requests.request("POST", url, headers=headers, data=json.dumps(payload))

res = json.loads(response.text)
generated_answers = [ans["text"] for ans in res["choices"]]
answers = self._create_answers(generated_answers, input_docs)
result = {"query": query, "answers": answers}

return result

def _build_prompt(self, query: str, documents: List[Document]) -> Tuple[str, List[Document]]:
"""
        Builds the prompt that the GPT-3 model uses to generate an answer.
"""
example_context = f"===\nContext: {self.examples_context}\n===\n"
example_prompts = "\n---\n".join([f"Q: {question}\nA: {answer}" for question, answer in self.examples])
instruction = "Please answer the question according to the above context.\n" + example_context + example_prompts
instruction = f"{instruction.strip()}\n\n"

qa_prompt = f"Q: {query}\nA:"

n_instruction_tokens = len(self._tokenizer.encode(instruction + qa_prompt + "===\nContext: \n===\n"))
n_docs_tokens = [len(self._tokenizer.encode(doc.content)) for doc in documents]
leftover_token_len = self.MAX_TOKENS_LIMIT - n_instruction_tokens

        # Add as many Documents to the context as fit within the model's token limit
input_docs = []
input_docs_content = []
skipped_docs = 0
for doc, doc_token_len in zip(documents, n_docs_tokens):
if doc_token_len <= leftover_token_len:
input_docs.append(doc)
input_docs_content.append(doc.content)
leftover_token_len -= doc_token_len
else:
skipped_docs += 1

if len(input_docs) == 0:
logger.warning(
f"Skipping all of the provided Documents, as none of them fits the maximum token limit of "
f"{self.MAX_TOKENS_LIMIT}. The generated answers will therefore not be conditioned on any context."
)
elif skipped_docs >= 1:
logger.warning(
f"Skipping {skipped_docs} of the provided Documents, as using them would exceed the maximum token "
f"limit of {self.MAX_TOKENS_LIMIT}."
)

        # Top-ranked documents go at the end of the context, closest to the question
context = " ".join(reversed(input_docs_content))
context = f"===\nContext: {context}\n===\n"

full_prompt = instruction + context + qa_prompt

return full_prompt, input_docs
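
With the node in place, a minimal end-to-end usage sketch (the document content and expected answer are placeholders; only `api_key` is required, all other constructor arguments have defaults):

```python
import os

from haystack import Document
from haystack.nodes import OpenAIAnswerGenerator

generator = OpenAIAnswerGenerator(
    api_key=os.environ["OPENAI_API_KEY"],
    model="text-curie-001",
    top_k=1,
    max_tokens=20,
)

# predict() packs as many Documents as fit into the model's token limit into the
# prompt (top-ranked Documents closest to the question) and calls the OpenAI
# Completions API; OpenAI returns no scores, so Answer.score is None.
docs = [Document(content="Carla lives in Berlin and works as an engineer.")]
result = generator.predict(query="Who lives in Berlin?", documents=docs, top_k=1)

print(result["answers"][0].answer)  # e.g. "Carla"
```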
7 changes: 6 additions & 1 deletion test/conftest.py
@@ -49,7 +49,7 @@

from haystack.document_stores import BaseDocumentStore, DeepsetCloudDocumentStore, InMemoryDocumentStore

from haystack.nodes import BaseReader, BaseRetriever
from haystack.nodes import BaseReader, BaseRetriever, OpenAIAnswerGenerator
from haystack.nodes.answer_generator.transformers import Seq2SeqGenerator
from haystack.nodes.answer_generator.transformers import RAGenerator
from haystack.nodes.ranker import SentenceTransformersRanker
@@ -514,6 +514,11 @@ def rag_generator():
return RAGenerator(model_name_or_path="facebook/rag-token-nq", generator_type="token", max_length=20)


@pytest.fixture
def openai_generator():
return OpenAIAnswerGenerator(api_key=os.environ.get("OPENAI_API_KEY", ""), model="text-babbage-001", top_k=1)


@pytest.fixture
def question_generator():
return QuestionGenerator(model_name_or_path="valhalla/t5-small-e2e-qg")
11 changes: 11 additions & 0 deletions test/nodes/test_generator.py
@@ -1,3 +1,4 @@
import os
import sys
from typing import List

@@ -124,3 +125,13 @@ def __call__(self, some_invalid_para: str, another_invalid_param: str) -> None:
with pytest.raises(Exception) as exception_info:
output = pipeline.run(query=query, params={"top_k": 1})
assert "does not have a valid __call__ method signature" in str(exception_info.value)


@pytest.mark.integration
def test_openai_answer_generator(openai_generator, docs):
if "OPENAI_API_KEY" in os.environ:
prediction = openai_generator.predict(query="Who lives in Berlin?", documents=docs, top_k=1)
assert len(prediction["answers"]) == 1
assert "Carla" in prediction["answers"][0].answer
else:
pytest.skip("No API key provided in environment variables.")
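
For reference, a sketch of running this integration test locally (standard pytest marker and keyword selection; the key value is a placeholder):

```python
# Shell equivalent:
#   OPENAI_API_KEY=<your-key> pytest -m integration -k openai test/nodes/test_generator.py
import sys

import pytest

# pytest.main accepts the same arguments as the CLI; -m selects tests marked
# `integration`, and -k narrows the run to the new OpenAI test.
sys.exit(pytest.main(["-m", "integration", "-k", "openai", "test/nodes/test_generator.py"]))
```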