/
completion.py
59 lines (48 loc) · 1.92 KB
/
completion.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import string
import time

import openai
# Take the key from the environment instead of hard-coding a secret here.
# The pre-1.0 openai client already reads OPENAI_API_KEY at import time; the
# old `openai.api_key = ""` silently overwrote that with a blank key.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
def getCompletionsOf(readFrom, writeTo, modelName, heading, numRows, *,
                     retryDelay=10, requestDelay=1, maxRetries=None):
    """Query a chat model with prompts from a file and save one-word answers.

    Reads up to ``numRows`` prompts (one per line) from ``readFrom`` and sends
    each one as a single user message to ``modelName``. The call goes through
    ``openai.ChatCompletion.create`` (the pre-1.0 openai-python interface).
    Every reply is reduced to a single lowercase word (see
    ``_extractFirstWord``). The results are written to ``writeTo``:
    ``heading`` on the first line, then one answer per line in prompt order.

    Args:
        readFrom: Path of a UTF-8 text file with one prompt per line.
        writeTo: Path of the output file; overwritten if it already exists.
        modelName: Model identifier passed to the API.
        heading: First line written to the output (e.g. a CSV column name).
        numRows: Maximum number of prompts to read. If the file is shorter,
            only the prompts that exist are processed (no empty prompts are
            sent).
        retryDelay: Seconds to wait after a failed request before retrying.
        requestDelay: Seconds to pause after each completed request so a
            long run is not throttled.
        maxRetries: Retries allowed per prompt before the last exception is
            re-raised. ``None`` (the default) retries indefinitely.
    """
    prompts = _readPrompts(readFrom, numRows)

    completions = []
    for requestNumber, prompt in enumerate(prompts, start=1):
        message = _requestCompletion(
            modelName, prompt, requestNumber, retryDelay, maxRetries
        )
        completions.append(_extractFirstWord(message))
        print("Completed request " + str(requestNumber))
        # Sleep briefly to avoid being booted during requesting.
        time.sleep(requestDelay)

    # The output is written only after every request has succeeded, so a run
    # that raises part-way leaves any existing output file untouched.
    with open(writeTo, "w", encoding="utf-8") as file:
        file.write(heading + "\n")
        file.writelines(completion + "\n" for completion in completions)


def _readPrompts(readFrom, numRows):
    """Return up to ``numRows`` whitespace-stripped lines from ``readFrom``.

    Stops at end of file instead of padding with empty prompts. Returns an
    empty list when ``numRows`` is zero or negative.
    """
    with open(readFrom, "r", encoding="utf-8") as file:
        # range() is zipped first so no extra line is consumed past numRows.
        return [line.strip() for _, line in zip(range(numRows), file)]


def _requestCompletion(modelName, prompt, requestNumber, retryDelay, maxRetries):
    """Send ``prompt`` to ``modelName`` and return the reply's message text.

    Any exception raised by the API call is printed and the request is
    retried after ``retryDelay`` seconds. Once more than ``maxRetries``
    attempts have failed, the last exception is re-raised; ``None`` means
    retry forever. A malformed response (missing keys) is not retried.
    """
    failures = 0
    while True:
        try:
            response = openai.ChatCompletion.create(
                model=modelName,
                messages=[{"role": "user", "content": prompt}],
            )
        except Exception as e:
            failures += 1
            print("Exception on request " + str(requestNumber) + ": " + str(e))
            if maxRetries is not None and failures > maxRetries:
                raise
            time.sleep(retryDelay)
        else:
            return response["choices"][0]["message"]["content"]


def _extractFirstWord(message):
    """Reduce a model reply to a single lowercase word.

    Normally takes the first whitespace-separated token. A two-word reply
    that starts with "the" in any case (e.g. "The cat.") yields the second
    word instead. Leading and trailing ASCII punctuation is stripped. An empty
    or missing reply yields "" rather than raising IndexError.
    """
    tokens = (message or "").split()
    if not tokens:
        return ""
    if len(tokens) == 2 and tokens[0].lower() == "the":
        word = tokens[1]
    else:
        word = tokens[0]
    return word.strip(string.punctuation).lower()