-
Notifications
You must be signed in to change notification settings - Fork 13
Expand file tree
/
Copy pathbasic_classifier.py
More file actions
83 lines (60 loc) · 2.03 KB
/
basic_classifier.py
File metadata and controls
83 lines (60 loc) · 2.03 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
'''
Created on Feb 6, 2024
@author: immanueltrummer
'''
import argparse
import openai
import pandas as pd
import time
client = openai.OpenAI()
def create_prompt(text):
    """ Build the sentiment classification prompt for one text.

    Args:
        text: text to classify.

    Returns:
        Prompt consisting of the text, the question, and the expected
        answer format, each on its own line, terminated by a colon.
    """
    question = 'Is the sentiment positive or negative?'
    reply_format = 'Answer ("pos"/"neg")'
    # str.format applies the same default formatting to its arguments
    # as an f-string does, so non-string inputs render the same way.
    template = '{}\n{}\n{}:'
    return template.format(text, question, reply_format)
def call_llm(prompt, model='gpt-3.5-turbo', max_attempts=3):
    """ Query large language model and return answer.

    The request is sent through the module-level `client`. Failed
    attempts are reported on standard output and retried after a
    linearly growing pause (2s, 4s, ...).

    Args:
        prompt: input prompt for language model.
        model: name of the chat model to query.
        max_attempts: maximal number of requests before giving up.

    Returns:
        Answer by language model and total number of tokens.

    Raises:
        Exception: if no attempt succeeded; the error of the last
            attempt (if any) is attached as the cause.
    """
    last_error = None
    for attempt in range(1, max_attempts + 1):
        try:
            response = client.chat.completions.create(
                model=model,
                messages=[
                    {'role': 'user', 'content': prompt}
                ],
                temperature=0
            )
            # Extraction stays inside the try block so that a malformed
            # response is retried just like a failed request.
            answer = response.choices[0].message.content
            nr_tokens = response.usage.total_tokens
            return answer, nr_tokens
        except Exception as e:
            # Deliberately broad: any failure of this attempt counts as
            # retryable; it is reported and kept for the final raise.
            last_error = e
            print(f'Exception: {e}')
            # Pausing only helps if another attempt follows.
            if attempt < max_attempts:
                time.sleep(attempt * 2)

    # Chain the last failure so the root cause shows up in the traceback.
    raise Exception('Cannot query OpenAI model!') from last_error
if __name__ == '__main__':

    # Command line interface: one positional argument, the input CSV path.
    cli = argparse.ArgumentParser()
    cli.add_argument('file_path', type=str, help='Path to input file')
    options = cli.parse_args()

    # The loop below reads the columns "text" and "sentiment".
    samples = pd.read_csv(options.file_path)

    # Running totals over all classified rows.
    nr_correct = 0
    nr_tokens = 0

    for _, sample in samples.iterrows():
        # Classify the text of the current row via the language model.
        label, tokens_used = call_llm(create_prompt(sample['text']))

        # Count exact matches between model answer and reference label.
        ground_truth = sample['sentiment']
        if label == ground_truth:
            nr_correct += 1
        nr_tokens += tokens_used

        print(f'Label: {label}; Ground truth: {ground_truth}')

    # Summary statistics after all rows were processed.
    print(f'Number of correct labels:\t{nr_correct}')
    print(f'Number of tokens used :\t{nr_tokens}')