Permalink
Cannot retrieve contributors at this time
Name already in use
A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
openai-cookbook/transition_guides_for_deprecated_API_endpoints/answers_functionality_example.py /
Go to fileThis commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
304 lines (256 sloc)
9.31 KB
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from transformers import GPT2TokenizerFast
import openai

# GPT-2 BPE tokenizer, used throughout this module to count prompt tokens so
# the constructed prompt stays within the model's context window.
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

# Maximum total tokens (prompt + completion) allowed for the target models.
MAX_TOKENS_LIMIT = 2048
# Instruction prepended to every constructed answers prompt.
ANSWERS_INSTRUCTION = "Please answer the question according to the above context.\n"
# Template used to wrap a block of context text inside `===` delimiters.
CONTEXT_TEMPLATE = "===\nContext: {context}\n===\n"
def extract_instruction(instruction):
    """Format the optional ``instruction`` text for prompt stuffing.

    Returns the stripped instruction followed by a blank line, or an empty
    string when no instruction was supplied.
    """
    return "" if instruction is None else f"{instruction.strip()}\n\n"
def semantic_search(
    search_model, query_for_search, file_id=None, max_documents=None, examples=None
):
    """Rank documents against a query via the (deprecated) Search endpoint.

    :param search_model: Name of the search model, e.g. "ada".
    :param query_for_search: The query string to rank documents against.
    :param file_id: ID of an uploaded file to search over (not implemented here).
    :param max_documents: If set, keep only the top-scoring N documents.
    :param examples: A list of {"text": ...} or {"text": ..., "label": ...}.
    :return:
        a list of semantic search result dict of documents sorted by "score":
        [
            {
                "document": ...,
                "object": "search_result",
                "score": ...,
                "text": ...,
            },
            ...
        ]
    """
    # Exactly one of `examples` / `file_id` must be provided (xor).
    assert (examples is None) ^ (file_id is None)

    if file_id is not None:
        # This is where you'd do an elastic search call. Since there isn't an example of this
        # we can query, we'll raise an error.
        # The return value from this would be a list of examples
        raise NotImplementedError()

    # This isn't quite accurate since Search is also being deprecated. See our search guide for more
    # information.
    response = openai.Search.create(
        model=search_model,
        documents=[example["text"] for example in examples],
        query=query_for_search,
    )

    # Index results by document id, then order ids by descending score.
    by_doc_id = {entry["document"]: entry for entry in response["data"]}
    ranked_ids = sorted(
        by_doc_id, key=lambda doc_id: by_doc_id[doc_id]["score"], reverse=True
    )
    if max_documents:
        ranked_ids = ranked_ids[:max_documents]
    return [by_doc_id[doc_id] for doc_id in ranked_ids]
def select_by_length(
    sorted_doc_infos,
    max_token_len,
    lambda_fn=None,
):
    """Greedily select documents until the token budget would be exceeded.

    Walks ``sorted_doc_infos`` in order (best score first) and keeps each
    document whose formatted text still fits under ``max_token_len`` total
    tokens. Documents that don't fit are skipped, but later, shorter ones may
    still be taken.

    :param sorted_doc_infos: A list of semantic search result dicts sorted by "score".
    :param max_token_len: The maximum token length for selected documents.
    :param lambda_fn: A function that takes a search result dict and returns a
        formatted example string for context stuffing; defaults to the raw "text".
    :return: A tuple of (
        the concatenation of the selected documents' formatted text,
        the list of selected document info dicts, lowest score first
    )
    """
    if not sorted_doc_infos:
        return "", []

    picked_indices = []
    used_tokens = 0
    formatted_by_index = {}
    for idx, doc_info in enumerate(sorted_doc_infos):
        text = lambda_fn(doc_info) if lambda_fn else doc_info["text"]
        token_count = len(tokenizer.encode(text))
        if used_tokens + token_count < max_token_len:
            used_tokens += token_count
            picked_indices.append(idx)
            formatted_by_index[idx] = text

    # Reverse so the top-ranked documents sit closest to the prompt (at the end).
    picked_indices.reverse()
    context = "".join(formatted_by_index[idx] for idx in picked_indices)
    return context, [sorted_doc_infos[idx] for idx in picked_indices]
def answers(
    examples,
    question,
    model,
    examples_context,
    file_id=None,
    documents=None,
    logit_bias=None,
    max_rerank=200,
    max_tokens=16,
    alternative_question=None,
    search_model="ada",
    temperature=0.0,
    logprobs=0,
    stop=None,
    n=1,
):
    """
    Given a prompt, a question, a list of (question, answer) pairs as examples, and
    a list of documents for context, it tries to include all the QA examples and top
    relevant context documents.
    The constructed prompt for the final completion call:
    ```
    Please answer the question according to the above context.
    ===
    Context: {{ the context for example QA pairs. }}
    ===
    Q: example 1 question
    A: example 1 answer
    ---
    Q: example 2 question
    A: example 2 answer
    ===
    Context: {{ a list of relevant documents sorted via search(question, documents) }}
    ===
    Q: question
    A:
    ```
    :param examples: A list of (question, answer) pairs used as few-shot examples.
    :param question: The question to answer.
    :param model: The completion engine used to generate the answer.
    :param examples_context: Context text shared by the example QA pairs.
    :param file_id: ID of an uploaded file to search over (mutually exclusive
        with `documents`; currently raises NotImplementedError downstream).
    :param documents: A list of document strings for context stuffing.
    :param logit_bias: Optional token-bias dict passed to the completion call.
    :param max_rerank: Maximum number of documents to rerank via search.
    :param max_tokens: Maximum tokens to generate for the answer.
    :param alternative_question: Optional alternate query used only for search.
    :param search_model: Model used for semantic search reranking.
    :raises Exception: If neither or both of `documents` / `file_id` are given,
        or if the longest document plus the question exceeds the token limit.
    :return: A dict with a structure like:
    {
        "answers": [
            "Beijing",
            "Beijing, China"
        ],
        "completion": "xxx-xxx",  # ID of the underlying completion call.
        "object": "answer",
        "selected_documents": [
            {
                "document": ..., # document index, same as in search/ results.
                "object": "search_result",
                "text": ...,
            },
            ...
        ],
    }
    """
    examples = examples if examples else []
    example_prompts = [f"Q: {x}\nA: {y}" for x, y in examples]
    prompt = f"Q: {question}\nA:"

    # Append all the QA examples into the prompt.
    if examples_context:
        examples_context = CONTEXT_TEMPLATE.format(context=examples_context)
    instruction = (
        ANSWERS_INSTRUCTION + examples_context + "\n---\n".join(example_prompts) + "\n"
    )
    logit_bias = logit_bias if logit_bias is not None else {}

    # Exactly one of `documents` / `file_id` must be provided.
    if file_id is None and documents is None:
        raise Exception("Please submit at least one of `documents` or `file`.")
    if file_id is not None and documents is not None:
        raise Exception("Please submit only one of `documents` or `file`.")

    instruction = extract_instruction(instruction)

    # Token accounting for every fixed piece of the final prompt.
    n_instruction_tokens = len(tokenizer.encode(instruction))
    n_prompt_tokens = len(tokenizer.encode(prompt))
    n_query_tokens = len(tokenizer.encode(question))
    n_context_tokens = len(tokenizer.encode(CONTEXT_TEMPLATE.format(context="")))

    if documents is not None:
        documents = [doc.strip() + " " for doc in documents]
        n_docs_tokens = [len(tokenizer.encode(doc)) for doc in documents]

    # Except all the required content, how many tokens left for context stuffing.
    leftover_token_len = MAX_TOKENS_LIMIT - (
        n_instruction_tokens + n_context_tokens + n_prompt_tokens + max_tokens
    )

    sorted_doc_infos = []
    question_for_search = (
        alternative_question if alternative_question is not None else question
    )
    if file_id is not None:
        # Fix: `semantic_search` returns a single list; the original code
        # tuple-unpacked it into two names, which would raise ValueError if
        # the file path were ever implemented.
        sorted_doc_infos = semantic_search(
            search_model,
            question_for_search,
            file_id=file_id,
            max_documents=max_rerank,
        )
    elif len(documents) == 0:
        # If no context document is provided, do nothing.
        pass
    elif min(n_docs_tokens) >= leftover_token_len:
        # If there is no room for adding any context doc.
        pass
    elif (max_rerank is None or max_rerank >= len(documents)) and sum(
        n_docs_tokens
    ) < leftover_token_len:
        # If the total length of docs is short enough to be added all.
        selected_indices = list(range(len(documents)))
        sorted_doc_infos = [
            {"document": i, "text": documents[i]} for i in selected_indices
        ]
    elif n_query_tokens + max(n_docs_tokens) >= MAX_TOKENS_LIMIT:
        # If the prompt and the longest document together go above the limit.
        total_tokens = n_query_tokens + max(n_docs_tokens)
        raise Exception(
            f"The longest document and prompt pair together contains {total_tokens} "
            f"tokens, above the limit {MAX_TOKENS_LIMIT} for semantic search. Please consider "
            f"shortening the prompt or the longest document."
        )
    else:
        # If we can add some context documents but not all of them, we should
        # query search endpoint to rank docs by score.
        sorted_doc_infos = semantic_search(
            search_model,
            question_for_search,
            examples=[{"text": doc} for doc in documents],
            max_documents=max_rerank,
        )

    # Select documents w.r.t. the context length limitation.
    context, sorted_doc_infos = select_by_length(
        sorted_doc_infos,
        leftover_token_len,
        lambda_fn=lambda x: x["text"].strip() + " ",
    )

    # Add instruction before the context and the prompt after the context.
    if context:
        context = CONTEXT_TEMPLATE.format(context=context.strip())
    full_prompt = instruction + context + prompt

    completion_result = openai.Completion.create(
        engine=model,
        prompt=full_prompt,
        logit_bias=logit_bias,
        temperature=temperature,
        n=n,
        max_tokens=max_tokens,
        stop=stop,
        logprobs=logprobs,
    )

    completion_result["selected_documents"] = sorted_doc_infos

    result = dict(
        object="answer",
        selected_documents=completion_result.pop("selected_documents"),
        completion=completion_result["id"],
    )

    # Strip the "A:" echo and anything after a follow-up "Q:" from each choice.
    result["answers"] = [
        item["text"].replace("A:", "").split("Q:")[0].strip()
        for item in completion_result["choices"]
    ]
    return result
# Demo: replicate the deprecated /answers endpoint with a live Completion call.
# NOTE(review): this runs at import time and requires valid OpenAI credentials.
print(
    answers(
        examples=[
            ["What is the capital of Washington", "Olympia"],
            ["What is the capital of Oregon", "Salem"],
        ],
        question="What is the capital of China?",
        examples_context="I am a bot that names country capitals",
        documents=["I am a bot that names country capitals"],
        model="davinci",
        search_model="ada",
        alternative_question="different test",
        max_tokens=16,
        stop=["\n\n"],
    )
)