Add a codex backtranslation example to improve SQL queries (#58)

* Add a codex backtranslation example to improve SQL queries
* Boris update ft example (#57): update the fine-tune example to show the new CLI outputs
* Model specification for search (#60)
* Catch chunked encoding errors and retry (#63)
* Add batch suggestion logic to prepare_data for fine_tunes and custom Q&A answers logic (#62): add batch suggestion logic to prepare_data for fine_tunes, and add an example of how to create a rudimentary answers endpoint with a custom Q&A model

Co-authored-by: Madeleine Thompson <madeleine@openai.com>
Co-authored-by: hallacy <hallacy@openai.com>
1 parent c79fefc, commit 7febb75.
Showing 7 changed files with 465 additions and 94 deletions.
@@ -0,0 +1,187 @@
import openai
from smokey import Smokey
from typing import List, Tuple, Union


def get_candidates(
    prompt: str,
    stop: List[str],
    temperature: float,
    priming_prefix: str,
    engine: str,
    n: int = 5,
) -> List[str]:
    """
    Generate n candidate completions for the prompt at the given temperature.
    :param prompt: The prompt to complete.
    :param stop: A list of tokens that indicate the end of the generation.
    :param temperature: The temperature of the generation.
    :param priming_prefix: The prefix to use for the priming.
    :param engine: The engine to use for the generation.
    :param n: The number of completions to generate.
    :return: A list of completions.
    """
    response = openai.Completion.create(
        engine=engine,
        prompt=prompt,
        temperature=temperature,
        max_tokens=150,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        stop=stop,
        n=n,
    )
    # Re-attach the priming prefix (e.g. "SELECT"): the completions start
    # right after it, so each candidate needs it back to be a full query.
    responses = [priming_prefix + choice.text for choice in response.choices]
    return responses


def rindex(lst: List, value: str) -> int:
    """
    Return the index of the last occurrence of a value in a list.
    :param lst: The list to search in.
    :param value: The value to search for.
    :return: The index of the last occurrence of the value.
    """
    try:
        # Find the value in the reversed list, then map that position back
        # to an index in the original list.
        return len(lst) - lst[::-1].index(value) - 1
    except ValueError:
        raise ValueError(f"Answer start token `{value}` not found in the eval template")


def eval_candidate(
    candidate_answer: str,
    original_instruction: str,
    eval_template: str,
    answer_start_token: str,
    engine: str,
) -> float:
    """
    Evaluate a candidate answer by calculating the average log probability
    of the original instruction, given the candidate answer with a specific
    evaluation template, aimed at reconstructing the original instruction.
    :param candidate_answer: The candidate answer to evaluate.
    :param original_instruction: The original instruction.
    :param eval_template: The template to use for the evaluation.
    :param answer_start_token: The token to use to indicate the start of the answer.
    :param engine: The engine to use for the evaluation.
    :return: The evaluation of the candidate answer.
    """
    response = openai.Completion.create(
        engine=engine,
        prompt=eval_template.format(candidate_answer, original_instruction),
        temperature=0,
        max_tokens=0,
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        logprobs=1,
        echo=True,
    )

    # Average the log probabilities of the tokens that follow the answer
    # start token, i.e. of the reconstructed instruction.
    answer_start = rindex(
        response["choices"][0]["logprobs"]["tokens"], answer_start_token
    )
    logprobs = response["choices"][0]["logprobs"]["token_logprobs"][answer_start + 1 :]
    return sum(logprobs) / len(logprobs)


def backtranslation(
    prompt_template: str,
    additional_info: str,
    instruction: str,
    eval_template: str,
    priming_prefix: str = "SELECT",
    stop1: List[str] = ["#", ";"],
    answer_start_token: str = "--",
    n: int = 5,
    temperature: float = 0.5,
    return_all_results: bool = False,
    engine: str = "davinci-codex",
) -> Union[str, List[Tuple[str, float]]]:
    """
    Generate a number of SQL queries given a natural language instruction,
    and pick the best one based on the average log probability of explaining the
    candidate SQL query with the exact original instruction, when prompted for
    a natural language explanation of the candidate SQL query.
    :param prompt_template: The template to use for the prompt to generate SQL.
    :param additional_info: Additional information to include in the prompt
        (SQL tables, and their properties).
    :param instruction: The instruction in natural language.
    :param eval_template: The template to use for the evaluation.
    :param priming_prefix: The prefix to use for the priming of the SQL query.
    :param stop1: A list of tokens that indicate the end of the generation.
    :param answer_start_token: The token to use to indicate the start of the
        natural answer.
    :param n: The number of candidates to generate.
    :param temperature: The temperature of the generation.
    :param return_all_results: Whether to return all results or just the best one.
    :param engine: The engine to use for the generation and evaluation.
    :return: The best SQL query, or a list of all scored generated SQL queries.
    """
    prompt_template = prompt_template.format(
        additional_info, instruction, priming_prefix
    )
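    # With main()'s default templates and table definitions, the formatted
    # prompt looks like:
    #
    #   ### Postgres SQL tables, with their properties:
    #   #
    #   # Employee(id, name, department_id)
    #   # Department(id, name, address)
    #   # Salary_Payments(id, employee_id, amount, date)
    #   #
    #   ### <natural language instruction>
    #   SELECT
    #
    # Ending the prompt with the priming prefix steers the model toward
    # completing a SQL query instead of writing more commentary.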
    candidates = []
    responses = get_candidates(
        prompt_template, stop1, temperature, priming_prefix, engine=engine, n=n
    )
    for i in range(n):
        quality = eval_candidate(
            responses[i],
            instruction,
            eval_template,
            answer_start_token,
            engine=engine,
        )
        candidates.append((responses[i], quality))

    candidates.sort(key=lambda x: x[1], reverse=True)
    if return_all_results:
        return candidates
    return candidates[0][0]


def main(
    nl_query: str = "Return the name of each department that had more than 10 employees in June 2021",
    eval_template: str = "{};\n-- Explanation of the above query in human readable format\n-- {}",
    table_definitions: str = "# Employee(id, name, department_id)\n# Department(id, name, address)\n# Salary_Payments(id, employee_id, amount, date)\n",
    prompt_template: str = "### Postgres SQL tables, with their properties:\n#\n{}#\n### {}\n{}",
    n: int = 3,
    temperature: float = 0.3,
    engine: str = "davinci-codex",
):
    """
    Generate a number of SQL queries given a natural language instruction,
    and pick the best one based on the highest backtranslation score.
    :param nl_query: The natural language query.
    :param eval_template: The template to use for the evaluation.
    :param table_definitions: The definitions of the tables used in the query.
    :param prompt_template: The template to use for the prompt to generate SQL.
    :param n: The number of candidates to generate.
    :param temperature: The temperature of the generation.
    :param engine: The engine to use for the generation and evaluation.
    :return: The best SQL query, or a list of all scored generated SQL queries.
    """
    result = backtranslation(
        prompt_template,
        table_definitions,
        nl_query,
        eval_template,
        priming_prefix="SELECT",
        temperature=temperature,
        n=n,
        engine=engine,
    )
    print(result)


if __name__ == "__main__":
    Smokey(main)
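To try this example end to end, a minimal sketch: call main() directly from Python instead of going through the Smokey wrapper (which, as used here, appears to turn main() into a command-line interface; that is an assumption about smokey, not something shown in this diff). The openai package must be configured with an API key.

    # Hypothetical direct call; the arguments mirror main()'s defaults.
    main(
        nl_query="Return the name of each department that had more than 10 employees in June 2021",
        n=3,
        temperature=0.3,
    )

The script prints only the best-scoring SQL query; pass return_all_results=True through to backtranslation() to inspect every candidate alongside its backtranslation score.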
@@ -0,0 +1,142 @@
import argparse

import openai


def create_context(
    question, search_file_id, max_len=1800, search_model="ada", max_rerank=10
):
    """
    Create a context for a question by finding the most similar context from the search file.
    :param question: The question
    :param search_file_id: The file id of the search file
    :param max_len: The maximum length of the returned context (in tokens)
    :param search_model: The search model to use
    :param max_rerank: The maximum number of documents to rerank
    :return: The context
    """
    results = openai.Engine(search_model).search(
        search_model=search_model,
        query=question,
        max_rerank=max_rerank,
        file=search_file_id,
        return_metadata=True,
    )
    returns = []
    cur_len = 0
    for result in results["data"]:
        # The metadata field is assumed to hold the token count of each
        # entry; the extra 4 tokens budget for the "\n\n###\n\n" separator.
        cur_len += int(result["metadata"]) + 4
        if cur_len > max_len:
            break
        returns.append(result["text"])
    return "\n\n###\n\n".join(returns)


def answer_question(
    search_file_id="<SEARCH_FILE_ID>",
    fine_tuned_qa_model="<FT_QA_MODEL_ID>",
    question="Which country won the European Football championship in 2021?",
    max_len=1800,
    search_model="ada",
    max_rerank=10,
    debug=False,
    stop_sequence=["\n", "."],
    max_tokens=100,
):
    """
    Answer a question based on the most similar context from the search file, using your fine-tuned model.
    :param search_file_id: The file id of the search file
    :param fine_tuned_qa_model: The fine-tuned QA model
    :param question: The question
    :param max_len: The maximum length of the returned context (in tokens)
    :param search_model: The search model to use
    :param max_rerank: The maximum number of documents to rerank
    :param debug: Whether to output debug information
    :param stop_sequence: The stop sequence for the Q&A model
    :param max_tokens: The maximum number of tokens to return
    :return: The answer
    """
    context = create_context(
        question,
        search_file_id,
        max_len=max_len,
        search_model=search_model,
        max_rerank=max_rerank,
    )
    if debug:
        print("Context:\n" + context)
        print("\n\n")
    try:
        response = openai.Completion.create(
            model=fine_tuned_qa_model,
            prompt=f"Answer the question based on the context below\n\nText: {context}\n\n---\n\nQuestion: {question}\nAnswer:",
            temperature=0,
            max_tokens=max_tokens,
            top_p=1,
            frequency_penalty=0,
            presence_penalty=0,
            stop=stop_sequence,
        )
        return response["choices"][0]["text"]
    except Exception as e:
        print(e)
        return ""


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Rudimentary functionality of the answers endpoint with a fine-tuned Q&A model.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--search_file_id", help="Search file id", required=True, type=str
    )
    parser.add_argument(
        "--fine_tuned_qa_model", help="Fine-tuned QA model id", required=True, type=str
    )
    parser.add_argument(
        "--question", help="Question to answer", required=True, type=str
    )
    parser.add_argument(
        "--max_len",
        help="Maximum length of the returned context (in tokens)",
        default=1800,
        type=int,
    )
    parser.add_argument(
        "--search_model", help="Search model to use", default="ada", type=str
    )
    parser.add_argument(
        "--max_rerank",
        help="Maximum number of documents to rerank for the search",
        default=10,
        type=int,
    )
    parser.add_argument(
        "--debug", help="Print debug information (context used)", action="store_true"
    )
    parser.add_argument(
        "--stop_sequence",
        help="Stop sequences for the Q&A model",
        default=["\n", "."],
        nargs="+",
        type=str,
    )
    parser.add_argument(
        "--max_tokens",
        help="Maximum number of tokens to return",
        default=100,
        type=int,
    )
    args = parser.parse_args()
    response = answer_question(
        search_file_id=args.search_file_id,
        fine_tuned_qa_model=args.fine_tuned_qa_model,
        question=args.question,
        max_len=args.max_len,
        search_model=args.search_model,
        max_rerank=args.max_rerank,
        debug=args.debug,
        stop_sequence=args.stop_sequence,
        max_tokens=args.max_tokens,
    )
    print(f"Answer:{response}")
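A sketch of the equivalent direct call from Python, bypassing argparse; the two IDs are placeholders (matching the defaults above) for a real uploaded search file and fine-tuned model:

    # Hypothetical direct call with placeholder IDs.
    answer = answer_question(
        search_file_id="<SEARCH_FILE_ID>",
        fine_tuned_qa_model="<FT_QA_MODEL_ID>",
        question="Which country won the European Football championship in 2021?",
        debug=True,  # also print the retrieved context
    )
    print(f"Answer:{answer}")

Setting debug=True prints the retrieved context before the answer, which helps verify that the search file returns relevant passages before blaming the fine-tuned model for a bad answer.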