/
listing1.py
73 lines (57 loc) · 1.83 KB
/
listing1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
'''
Created on Nov 5, 2023
@author: immanueltrummer
'''
import argparse
import openai
import re
import time
# Module-level OpenAI client; call_llm sends its chat completion requests
# through this object.
client = openai.OpenAI()
def create_prompt(question):
    """ Build the prompt asking the model to translate a question into SQL.

    The prompt consists of a 'Database:' header, the definition of the
    games table, the translation instruction, and finally the question,
    all separated by newlines.

    Args:
        question: question about data in natural language.

    Returns:
        prompt for question translation.
    """
    # The table definition is spread over several prompt lines.
    schema_lines = [
        'create table games(rank int, name text, platform text,',
        'year int, genre text, publisher text, americasales numeric,',
        'eusales numeric, japansales numeric, othersales numeric,',
        'globalsales numeric);',
    ]
    prompt_lines = ['Database:', *schema_lines]
    prompt_lines.append('Translate this question into SQL query:')
    prompt_lines.append(question)
    return '\n'.join(prompt_lines)
def call_llm(prompt):
    """ Query large language model and return answer.

    The prompt is sent as a single user message to the 'gpt-4o' model via
    the module-level client. Up to three attempts are made; after the
    first and second failed attempt the function pauses for two and four
    seconds, respectively, before trying again.

    Args:
        prompt: input prompt for language model.

    Returns:
        Answer by language model.

    Raises:
        Exception: if all attempts fail; the error of the last attempt
            is attached as the cause.
    """
    max_attempts = 3
    last_error = None
    for nr_attempt in range(1, max_attempts + 1):
        try:
            response = client.chat.completions.create(
                model='gpt-4o',
                messages=[
                    {'role':'user', 'content':prompt}
                    ]
                )
            return response.choices[0].message.content
        except Exception as error:
            # Catch Exception rather than everything: a bare except would
            # also swallow KeyboardInterrupt and SystemExit, so the retry
            # loop could not be aborted by the user.
            last_error = error
            # Pausing only makes sense if another attempt follows.
            if nr_attempt < max_attempts:
                time.sleep(nr_attempt * 2)
    # Chain the last error so the root cause stays visible in tracebacks.
    raise Exception('Cannot query OpenAI model!') from last_error
def extract_query(answer):
    """ Extract the first SQL query from a model answer.

    The query is expected inside a Markdown code block opened by ```sql
    and closed by ```. The match is non-greedy, so if the answer contains
    several code blocks, only the content of the first one is returned
    (instead of everything up to the last closing fence).

    Args:
        answer: text generated by the language model (may be None).

    Returns:
        Text between the fences of the first SQL block, including
        surrounding whitespace, or None if there is no such block.
    """
    if not answer:
        return None
    match = re.search(r'```sql(.*?)```', answer, re.DOTALL)
    return match.group(1) if match else None


if __name__ == '__main__':

    # The question to translate is the only command line argument.
    parser = argparse.ArgumentParser()
    parser.add_argument('question', type=str, help='A question about game sales')
    args = parser.parse_args()

    prompt = create_prompt(args.question)
    print('--- Prompt ---')
    print(prompt)

    answer = call_llm(prompt)
    print('--- Answer ---')
    print(answer)

    query = extract_query(answer)
    if query is None:
        # Exit with a clear message (and non-zero status) rather than an
        # IndexError when the model did not answer with an SQL code block.
        raise SystemExit('Cannot extract SQL query from answer!')
    print('--- Extracted query ---')
    print(query)