feat: Add max_tokens to BaseGenerator params (#4168)
* Add max_tokens to BaseGenerator params

* Make mypy happy

* Rebase and resolve conflicts

* Fix signature issues

* Update lg

* Add a mocked unit test method

* end-of-file-fixer corrected file

* Convert to unit test

* Mark test as integration

* make the test unit

---------

Co-authored-by: agnieszka-m <amarzec13@gmail.com>
Co-authored-by: Massimiliano Pippi <mpippi@gmail.com>
3 people committed May 18, 2023
1 parent 401520b commit 5d7ee2e
Showing 5 changed files with 65 additions and 18 deletions.
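For orientation, here is a minimal sketch of how the new parameter is meant to flow from a pipeline call down to the generator node; it mirrors the unit test added in this commit, but the API key, query, document text, and token limit below are placeholders rather than values from the diff.

```python
# Illustrative sketch only; placeholder values, modeled on the unit test added in this commit.
from haystack import Pipeline
from haystack.schema import Document
from haystack.nodes.answer_generator import OpenAIAnswerGenerator

generator = OpenAIAnswerGenerator(api_key="YOUR_API_KEY", model="text-babbage-001", top_k=1)
pipeline = Pipeline()
pipeline.add_node(component=generator, name="generator", inputs=["Query"])

docs = [Document(content="New York is a cool and amazing city to live in the United States of America.")]

# max_tokens, introduced by this commit, caps the length of the generated answer per call.
result = pipeline.run(
    query="What is New York City like?",
    documents=docs,
    params={"generator": {"max_tokens": 32}},
)
```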
36 changes: 26 additions & 10 deletions haystack/nodes/answer_generator/base.py
@@ -20,27 +20,38 @@ def __init__(self, progress_bar: bool = True):
self.progress_bar = progress_bar

@abstractmethod
def predict(self, query: str, documents: List[Document], top_k: Optional[int]) -> Dict:
def predict(self, query: str, documents: List[Document], top_k: Optional[int], max_tokens: Optional[int]) -> Dict:
"""
Abstract method to generate answers.
:param query: Query
:param documents: Related documents (e.g. coming from a retriever) that the answer shall be conditioned on.
:param top_k: Number of returned answers
:param documents: Related documents (for example, coming from a retriever) the answer should be based on.
:param top_k: Number of returned answers.
:param max_tokens: The maximum number of tokens the generated answer can have.
:return: Generated answers plus additional infos in a dict
"""
pass

def run(self, query: str, documents: List[Document], top_k: Optional[int] = None, labels: Optional[MultiLabel] = None, add_isolated_node_eval: bool = False): # type: ignore
def run( # type: ignore
self,
query: str,
documents: List[Document],
top_k: Optional[int] = None,
labels: Optional[MultiLabel] = None,
add_isolated_node_eval: bool = False,
max_tokens: Optional[int] = None,
): # type: ignore
if documents:
results = self.predict(query=query, documents=documents, top_k=top_k)
results = self.predict(query=query, documents=documents, top_k=top_k, max_tokens=max_tokens)
else:
results = {"answers": []}

# run evaluation with "perfect" labels as node inputs to calculate "upper bound" metrics for just this node
if add_isolated_node_eval and labels is not None:
relevant_documents = list({label.document.id: label.document for label in labels.labels}.values())
results_label_input = self.predict(query=query, documents=relevant_documents, top_k=top_k)
results_label_input = self.predict(
query=query, documents=relevant_documents, top_k=top_k, max_tokens=max_tokens
)
results["answers_isolated"] = results_label_input["answers"]

return results, "output_1"
@@ -51,8 +62,11 @@ def run_batch( # type: ignore
documents: Union[List[Document], List[List[Document]]],
top_k: Optional[int] = None,
batch_size: Optional[int] = None,
max_tokens: Optional[int] = None,
):
results = self.predict_batch(queries=queries, documents=documents, top_k=top_k, batch_size=batch_size)
results = self.predict_batch(
queries=queries, documents=documents, top_k=top_k, batch_size=batch_size, max_tokens=max_tokens
)
return results, "output_1"

def _flatten_docs(self, documents: List[Document]):
@@ -92,6 +106,7 @@ def predict_batch(
documents: Union[List[Document], List[List[Document]]],
top_k: Optional[int] = None,
batch_size: Optional[int] = None,
max_tokens: Optional[int] = None,
):
"""
Generate the answer to the input queries. The generation will be conditioned on the supplied documents.
@@ -110,10 +125,11 @@ def predict_batch(
and the Answers will be aggregated per query-Document pair.
:param queries: List of queries.
:param documents: Related documents (e.g. coming from a retriever) that the answer shall be conditioned on.
:param documents: Related documents (for example, coming from a retriever) the answer should be based on.
Can be a single list of Documents or a list of lists of Documents.
:param top_k: Number of returned answers per query.
:param batch_size: Not applicable.
:param max_tokens: The maximum number of tokens the generated answer can have.
:return: Generated answers plus additional infos in a dict like this:
```python
@@ -142,7 +158,7 @@ def predict_batch(
for doc in documents:
if not isinstance(doc, Document):
raise HaystackError(f"doc was of type {type(doc)}, but expected a Document.")
preds = self.predict(query=query, documents=[doc], top_k=top_k)
preds = self.predict(query=query, documents=[doc], top_k=top_k, max_tokens=max_tokens)
results["answers"].append(preds["answers"])
pb.update(1)
pb.close()
@@ -158,7 +174,7 @@ def predict_batch(
for query, cur_docs in zip(queries, documents):
if not isinstance(cur_docs, list):
raise HaystackError(f"cur_docs was of type {type(cur_docs)}, but expected a list of Documents.")
preds = self.predict(query=query, documents=cur_docs, top_k=top_k)
preds = self.predict(query=query, documents=cur_docs, top_k=top_k, max_tokens=max_tokens)
results["answers"].append(preds["answers"])
pb.update(1)
pb.close()
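The practical consequence of the `base.py` change is that every `BaseGenerator` subclass must now accept `max_tokens` in `predict()`, while `run()` and `run_batch()` forward the value automatically. A minimal sketch of a conforming subclass follows; the class itself is hypothetical and not part of this commit.

```python
# Hypothetical stub illustrating the updated abstract contract.
from typing import Dict, List, Optional

from haystack.nodes.answer_generator.base import BaseGenerator
from haystack.schema import Document


class StubGenerator(BaseGenerator):
    def predict(
        self, query: str, documents: List[Document], top_k: Optional[int], max_tokens: Optional[int]
    ) -> Dict:
        # A real generator would cap its output at max_tokens tokens;
        # run() and run_batch() pass the argument straight through to this method.
        return {"query": query, "answers": []}
```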
4 changes: 3 additions & 1 deletion haystack/nodes/answer_generator/openai.py
@@ -170,6 +170,7 @@ def predict(
query: str,
documents: List[Document],
top_k: Optional[int] = None,
max_tokens: Optional[int] = None,
timeout: Union[float, Tuple[float, float]] = OPENAI_TIMEOUT,
):
"""
@@ -193,6 +194,7 @@ def predict(
:param query: The query you want to provide. It's a string.
:param documents: List of Documents in which to search for the Answer.
:param top_k: The maximum number of Answers to return.
:param max_tokens: The maximum number of tokens the generated Answer can have.
:param timeout: How many seconds to wait for the server to send data before giving up,
as a float, or a :ref:`(connect timeout, read timeout) <timeouts>` tuple.
Defaults to 10 seconds.
@@ -208,7 +210,7 @@
payload = {
"model": self.model,
"prompt": prompt,
"max_tokens": self.max_tokens,
"max_tokens": max_tokens or self.max_tokens,
"stop": self.stop_words,
"n": top_k,
"temperature": self.temperature,
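In the OpenAI generator, the request payload now prefers the per-call value over the instance default via `max_tokens or self.max_tokens`. A small illustration of that precedence (the values are stand-ins, not from the diff), including the edge case that an explicit `0` falls back to the default because `or` treats falsy values as missing:

```python
# Illustrative stand-ins for self.max_tokens and the new max_tokens argument.
instance_default = 50   # configured on the generator at construction time
per_call = 13           # passed to predict()/run() for a single request

print(per_call or instance_default)  # 13 -> the per-call value takes precedence
print(None or instance_default)      # 50 -> an omitted argument falls back to the default
print(0 or instance_default)         # 50 -> an explicit 0 also falls back, since 0 is falsy
```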
14 changes: 10 additions & 4 deletions haystack/nodes/answer_generator/transformers.py
@@ -214,14 +214,17 @@ def _prepare_passage_embeddings(self, docs: List[Document], embeddings: numpy.nd

return embeddings_in_tensor.to(self.devices[0])

def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
def predict(
self, query: str, documents: List[Document], top_k: Optional[int] = None, max_tokens: Optional[int] = None
) -> Dict:
"""
Generate the answer to the input query. The generation will be conditioned on the supplied documents.
These documents can for example be retrieved via the Retriever.
:param query: Query
:param documents: Related documents (e.g. coming from a retriever) that the answer shall be conditioned on.
:param top_k: Number of returned answers
:param max_tokens: Maximum number of tokens to generate
:return: Generated answers plus additional infos in a dict like this:
```python
@@ -279,7 +282,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] =
doc_scores=doc_scores,
num_return_sequences=top_k,
num_beams=self.num_beams,
max_length=self.max_length,
max_length=max_tokens or self.max_length,
min_length=self.min_length,
n_docs=len(flat_docs_dict["content"]),
)
@@ -430,14 +433,17 @@ def _get_converter(cls, model_name_or_path: str) -> Optional[Callable]:
model_name_or_path = "yjernite/bart_eli5"
return cls._model_input_converters.get(model_name_or_path)

def predict(self, query: str, documents: List[Document], top_k: Optional[int] = None) -> Dict:
def predict(
self, query: str, documents: List[Document], top_k: Optional[int] = None, max_tokens: Optional[int] = None
) -> Dict:
"""
Generate the answer to the input query. The generation will be conditioned on the supplied documents.
These documents can be retrieved via the Retriever or supplied directly via the predict method.
:param query: Query
:param documents: Related documents (e.g. coming from a retriever) that the answer shall be conditioned on.
:param top_k: Number of returned answers
:param max_tokens: Maximum number of tokens in the generated answer
:return: Generated answers
"""
@@ -474,7 +480,7 @@ def predict(self, query: str, documents: List[Document], top_k: Optional[int] =
input_ids=query_and_docs_encoded["input_ids"],
attention_mask=query_and_docs_encoded["attention_mask"],
min_length=self.min_length,
max_length=self.max_length,
max_length=max_tokens or self.max_length,
do_sample=True if self.num_beams == 1 else False,
early_stopping=True,
num_beams=self.num_beams,
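For the Transformers-based generators, a per-call `max_tokens` now overrides the configured `max_length` passed to the underlying `generate()` call. A hedged usage sketch for `Seq2SeqGenerator` follows; it assumes the usual `model_name_or_path` constructor argument and uses the `yjernite/bart_eli5` model mentioned in the diff, and running it downloads that model.

```python
# Sketch only: downloads a model and runs generation; values are illustrative.
from haystack.nodes.answer_generator import Seq2SeqGenerator
from haystack.schema import Document

generator = Seq2SeqGenerator(model_name_or_path="yjernite/bart_eli5")
docs = [Document(content="New York is a cool and amazing city to live in the United States of America.")]

# max_tokens caps the generated answer by overriding max_length for this call only.
result = generator.predict(
    query="What is New York City like?",
    documents=docs,
    top_k=1,
    max_tokens=64,
)
print(result["answers"])
```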
2 changes: 1 addition & 1 deletion test/conftest.py
@@ -329,7 +329,7 @@ def embed_documents(self, documents: List[Document]):


class MockSeq2SegGenerator(BaseGenerator):
def predict(self, query: str, documents: List[Document], top_k: Optional[int]) -> Dict:
def predict(self, query: str, documents: List[Document], top_k: Optional[int], max_tokens: Optional[int]) -> Dict:
pass


27 changes: 25 additions & 2 deletions test/nodes/test_generator.py
@@ -1,9 +1,9 @@
import sys
from typing import List
from unittest.mock import patch
from unittest.mock import patch, create_autospec

import pytest

from haystack import Pipeline
from haystack.schema import Document
from haystack.nodes.answer_generator import Seq2SeqGenerator, OpenAIAnswerGenerator, RAGenerator
from haystack.pipelines import GenerativeQAPipeline
@@ -218,3 +218,26 @@ def test_build_prompt_within_max_length():

assert len(prompt_docs) == 1
assert prompt_docs[0] == documents[0]


@pytest.mark.unit
def test_openai_answer_generator_pipeline_max_tokens():
"""
Tests that the max_tokens parameter is passed to the generator component in the pipeline.
"""
question = "What is New York City like?"
mocked_response = "Forget NYC, I was generated by the mock method."
nyc_docs = [Document(content="New York is a cool and amazing city to live in the United States of America.")]
pipeline = Pipeline()

# mock load_openai_tokenizer to avoid accessing the internet to init tiktoken
with patch("haystack.nodes.answer_generator.openai.load_openai_tokenizer"):
openai_generator = OpenAIAnswerGenerator(api_key="fake_api_key", model="text-babbage-001", top_k=1)

pipeline.add_node(component=openai_generator, name="generator", inputs=["Query"])
openai_generator.run = create_autospec(openai_generator.run)
openai_generator.run.return_value = ({"answers": mocked_response}, "output_1")

result = pipeline.run(query=question, documents=nyc_docs, params={"generator": {"max_tokens": 3}})
assert result["answers"] == mocked_response
openai_generator.run.assert_called_with(query=question, documents=nyc_docs, max_tokens=3)
