-
Notifications
You must be signed in to change notification settings - Fork 1
/
reverse_dictionary.py
executable file
·155 lines (125 loc) · 4.73 KB
/
reverse_dictionary.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
#!/usr/bin/env python3
"""
Code for the MAS.S68 (Generative AI for Constructive Communication) programming workshop
Reverse dictionary (description-to-word guesser) using the OpenAI GPT-3 API
Prompts and evaluation data are in data/train.jsonl and data/test.jsonl respectively
and were produced by ./pull_rd_data.py
"""
import argparse
import json
import os
import re
import openai
import streamlit as st
# Don't forget to set your OPENAI_API_KEY environment variable.
# Or set it here directly (but don't check it into a git repo.)
# os.getenv returns None when the variable is unset, so this line does not
# itself fail when the key is missing.
openai.api_key = os.getenv("OPENAI_API_KEY")
def call_gpt3_api(prompt, model):
    """Request a completion for ``prompt`` from ``model`` and return the raw response.

    The request is sent with ``temperature=0``, ``max_tokens=256`` and both
    the frequency and presence penalties set to 0.
    """
    request_options = {
        "model": model,
        "prompt": prompt,
        "temperature": 0,
        "max_tokens": 256,
        "frequency_penalty": 0,
        "presence_penalty": 0,
    }
    return openai.Completion.create(**request_options)
def definition_to_zero_shot_prompt(definition):
    """Build the example-free prompt that asks for words matching ``definition``."""
    question_template = 'What are some words that mean "%s"?\n\n'
    return question_template % definition
def definition_to_few_shot_prompt(definition, examples):
    """Build a few-shot prompt: instructions followed by a bulleted list.

    Each entry of ``examples`` (a list of dicts with ``"definition"`` and
    ``"word"`` keys) becomes one ``- "<definition>": <word>`` bullet. The
    queried ``definition`` is added as the last bullet with an empty word,
    leaving the slot for the model to fill in.
    """
    header = "Please find words that mean the same thing as the given definition. For example:\n\n"
    # The query itself is rendered like an example whose word is still blank.
    query_entry = {"definition": definition, "word": ""}

    bullets = []
    for entry in examples + [query_entry]:
        bullets.append('- "%s": %s' % (entry["definition"], entry["word"]))
    return header + "\n".join(bullets)
def response_to_completion_text(openai_response):
    """Return the ``"text"`` field of the first entry under ``"choices"``."""
    first_choice = openai_response["choices"][0]
    return first_choice["text"]
def completion_text_to_words(s):
    """Split a completion string into its alphabetic words.

    Every maximal run of ASCII letters in ``s`` counts as one word; digits,
    punctuation and whitespace only act as separators.

    Args:
        s: The completion text to split.

    Returns:
        A list of the words in order of appearance, or an empty list when
        ``s`` contains no letters.
    """
    # findall already yields [] when nothing matches, so no separate
    # empty-result branch is needed.
    return re.findall(r"[A-Za-z]+", s)
def get_words_for_definition(definition, examples_to_use=None):
    """Query the model for words matching ``definition`` and return them as a list.

    A few-shot prompt is built when ``examples_to_use`` is non-empty;
    otherwise the zero-shot prompt is used. The model name is read from the
    global ``args.model``.
    """
    if examples_to_use:
        prompt = definition_to_few_shot_prompt(definition, examples_to_use)
    else:
        prompt = definition_to_zero_shot_prompt(definition)

    raw_text = response_to_completion_text(call_gpt3_api(prompt, args.model))
    return completion_text_to_words(raw_text)
def read_batch_of_queries(filename):
    """Reads a set of reverse dictionary labeled data as a list of parsed records.

    Args:
        filename: Path to a JSON Lines file holding one JSON value per line.

    Returns:
        A list with the parsed value of every non-blank line, in file order.
        An empty or whitespace-only file yields an empty list.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # The context manager closes the handle even when a line fails to parse.
    with open(filename, "r", encoding="utf-8") as query_file:
        # Blank lines are skipped so stray or trailing newlines never reach
        # json.loads, which rejects empty input.
        return [json.loads(line) for line in query_file if line.strip()]
def get_example_queries_for_prompt(filename):
    """Return the leading ``args.num_prompt_examples`` records read from ``filename``."""
    all_queries = read_batch_of_queries(filename)
    example_count = args.num_prompt_examples
    return all_queries[:example_count]
def run_batch_of_queries(evaluation_queries_filename, prompt_example_queries_filename):
    """Evaluate the reverse dictionary on a file of labeled queries.

    Each evaluation record is annotated in place with ``gpt3_words`` (the
    words produced for its ``definition``) and ``gpt3_is_correct`` (whether
    the first of those words equals the record's ``word``, ignoring case),
    and is then printed as one JSON line. A final line reports the accuracy.

    Args:
        evaluation_queries_filename: JSONL file of records to evaluate; each
            record is read for its ``"definition"`` and ``"word"`` keys.
        prompt_example_queries_filename: JSONL file supplying the few-shot
            examples; only read when ``args.num_prompt_examples`` is
            positive.
    """
    evaluation_queries = read_batch_of_queries(evaluation_queries_filename)
    num_queries = len(evaluation_queries)
    if num_queries == 0:
        # Nothing to score; the accuracy ratio below would divide by zero.
        print("Accuracy = 0 / 0 (no evaluation queries)")
        return

    if args.num_prompt_examples > 0:
        prompt_example_queries = get_example_queries_for_prompt(
            prompt_example_queries_filename
        )
    else:
        # None makes get_words_for_definition fall back to the zero-shot prompt.
        prompt_example_queries = None

    num_correct = 0
    for record in evaluation_queries:
        definition = record["definition"]
        record["gpt3_words"] = get_words_for_definition(
            definition, examples_to_use=prompt_example_queries
        )
        # Only the first returned word is scored against the expected word.
        record["gpt3_is_correct"] = (
            len(record["gpt3_words"]) > 0
            and record["gpt3_words"][0].lower() == record["word"].lower()
        )
        if record["gpt3_is_correct"]:
            num_correct += 1
        print(json.dumps(record))

    print(
        "Accuracy = %d / %d = %f"
        % (num_correct, num_queries, num_correct / num_queries)
    )
def streamlit_app():
    """Render a one-field Streamlit form and show the words found for the entered description."""
    with st.form("main_form"):
        description = st.text_input("Enter description of the word you're looking for")
        was_submitted = st.form_submit_button("Submit")
        if was_submitted:
            # No examples are passed, so the lookup uses the zero-shot prompt.
            matching_words = get_words_for_definition(description)
            st.write(matching_words)
if __name__ == "__main__":
    # Command-line entry point. ``args`` is intentionally left as a module
    # global: get_words_for_definition, get_example_queries_for_prompt and
    # run_batch_of_queries all read it.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--query", help="Look up words matching this description", default=None
    )
    parser.add_argument(
        "--eval",
        help="Run an evaluation on the given file of queries (in jsonl format)",
        default=None,
    )
    parser.add_argument(
        "--model", help="Which OpenAI model to use", default="text-curie-001"
    )
    parser.add_argument(
        "--num_prompt_examples",
        help="The number of examples from data/train.jsonl to include in the prompt. If 0, use a separate 0-shot prompt.",
        # Let argparse convert and validate the value so that a non-integer
        # argument produces a usage error rather than a traceback.
        type=int,
        default=0,
    )
    args = parser.parse_args()

    if args.query:
        # Mirror run_batch_of_queries: the training file is only read when
        # few-shot examples were actually requested, so a zero-shot lookup
        # does not depend on data/train.jsonl being present.
        if args.num_prompt_examples > 0:
            prompt_examples = get_example_queries_for_prompt("data/train.jsonl")
        else:
            prompt_examples = None
        print(get_words_for_definition(args.query, prompt_examples))
    elif args.eval:
        run_batch_of_queries(args.eval, "data/train.jsonl")
    else:
        # Neither --query nor --eval was given: serve the interactive UI.
        streamlit_app()