-
Notifications
You must be signed in to change notification settings - Fork 3.9k
/
test_client.py
146 lines (125 loc) · 5.08 KB
/
test_client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import pytest
from autogen import OpenAIWrapper, config_list_from_json, config_list_openai_aoai
from test_utils import OAI_CONFIG_LIST, KEY_LOC
def version_at_least(version, minimum):
    """Return True if the dotted ``version`` string is at least ``minimum``.

    Only the leading run of numeric dot-separated components is compared, so a
    pre-release suffix such as ``"1.3.0b1"`` is read as ``(1, 3, 0)``. A string
    with no leading number parses as ``()`` and is never new enough for a
    non-empty ``minimum``.

    Args:
        version: A version string such as ``openai.__version__``.
        minimum: Tuple of ints, e.g. ``(1, 1, 0)``.
    """
    import re

    match = re.match(r"\d+(?:\.\d+)*", version)
    parts = tuple(int(piece) for piece in match.group(0).split(".")) if match else ()
    return parts >= tuple(minimum)


# Whether the installed openai client supports the `tools` argument (see the
# tool-calling test's skip condition).
TOOL_ENABLED = False
try:
    # Imported only to probe for the openai>=1 client API.
    from openai import OpenAI
    from openai.types.chat.chat_completion import ChatCompletionMessage
except ImportError:
    skip = True
else:
    skip = False
    import openai

    # Tool calling needs openai>=1.1.0. Compare numerically: string comparison is
    # lexical, so e.g. "1.10.0" < "1.9.0" as strings.
    if version_at_least(openai.__version__, (1, 1, 0)):
        TOOL_ENABLED = True
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_aoai_chat_completion():
    """Smoke-test a chat completion against the Azure gpt-3.5-turbo configs.

    The config list is filtered to ``api_type == "azure"`` and
    ``model == "gpt-3.5-turbo"``. Caching is disabled (``cache_seed=None``) so a
    real request is made. The raw response and the extracted text are printed;
    nothing is asserted.
    """
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo"]},
    )
    client = OpenAIWrapper(config_list=config_list)
    response = client.create(messages=[{"role": "user", "content": "2+2="}], cache_seed=None)
    print(response)
    print(client.extract_text_or_completion_object(response))
# Skip if openai>=1 is missing OR if it is older than 1.1.0 (no tool support).
# The previous `skip and not TOOL_ENABLED` was False whenever openai>=1 was
# installed, so the test ran even on versions without tool support.
@pytest.mark.skipif(skip or not TOOL_ENABLED, reason="openai>=1.1.0 not installed")
def test_oai_tool_calling_extraction():
    """Smoke-test a tool-calling chat completion and print the extracted result.

    A single ``getCurrentWeather`` function tool is offered to the Azure
    gpt-3.5-turbo configs. The raw response and the result of
    ``extract_text_or_completion_object`` are printed; nothing is asserted.
    """
    config_list = config_list_from_json(
        env_or_file=OAI_CONFIG_LIST,
        file_location=KEY_LOC,
        filter_dict={"api_type": ["azure"], "model": ["gpt-3.5-turbo"]},
    )
    client = OpenAIWrapper(config_list=config_list)
    response = client.create(
        messages=[
            {
                "role": "user",
                "content": "What is the weather in San Francisco?",
            },
        ],
        tools=[
            {
                "type": "function",
                "function": {
                    "name": "getCurrentWeather",
                    "description": "Get the weather in location",
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "location": {"type": "string", "description": "The city and state e.g. San Francisco, CA"},
                            "unit": {"type": "string", "enum": ["c", "f"]},
                        },
                        "required": ["location"],
                    },
                },
            }
        ],
    )
    print(response)
    print(client.extract_text_or_completion_object(response))
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_chat_completion():
    """Smoke-test a chat completion using the unfiltered config list.

    Prints the raw response followed by its extracted text/completion object.
    """
    wrapper = OpenAIWrapper(
        config_list=config_list_from_json(env_or_file=OAI_CONFIG_LIST, file_location=KEY_LOC)
    )
    chat_messages = [{"role": "user", "content": "1+1="}]
    result = wrapper.create(messages=chat_messages)
    print(result)
    print(wrapper.extract_text_or_completion_object(result))
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_completion():
    """Smoke-test a legacy text completion with gpt-3.5-turbo-instruct.

    The config list comes from ``config_list_openai_aoai(KEY_LOC)``. The raw
    response and the extracted output are printed in that order.
    """
    wrapper = OpenAIWrapper(config_list=config_list_openai_aoai(KEY_LOC))
    completion = wrapper.create(prompt="1+1=", model="gpt-3.5-turbo-instruct")
    print(completion)
    print(wrapper.extract_text_or_completion_object(completion))
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
@pytest.mark.parametrize(
    "cache_seed, model",
    (
        (None, "gpt-3.5-turbo-instruct"),
        (42, "gpt-3.5-turbo-instruct"),
        (None, "text-ada-001"),
    ),
)
def test_cost(cache_seed, model):
    """Print the ``cost`` attribute of a completion for each (cache_seed, model) pair.

    ``cache_seed`` is passed to the wrapper constructor, not to ``create``.
    """
    wrapper = OpenAIWrapper(config_list=config_list_openai_aoai(KEY_LOC), cache_seed=cache_seed)
    completion = wrapper.create(prompt="1+3=", model=model)
    print(completion.cost)
@pytest.mark.skipif(skip, reason="openai>=1 not installed")
def test_usage_summary():
    """Exercise the wrapper's usage bookkeeping: record, print, update, clear.

    The checks are:
    1. An uncached request populates both the actual and total summaries.
    2. Updating with the same response (``use_cache=True``) doubles the total cost.
    3. Clearing resets both summaries to None.
    4. A request with ``cache_seed=42`` populates only the total summary.
    """
    wrapper = OpenAIWrapper(config_list=config_list_openai_aoai(KEY_LOC))
    request = {"prompt": "1+3=", "model": "gpt-3.5-turbo-instruct"}

    # Uncached request: counted in both summaries.
    fresh = wrapper.create(**request, cache_seed=None)
    assert wrapper.actual_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"
    assert wrapper.total_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"

    wrapper.print_usage_summary()

    # Replaying the same response as a cache hit adds its cost to the total again.
    wrapper._update_usage_summary(fresh, use_cache=True)
    expected_total = fresh.cost * 2
    assert wrapper.total_usage_summary["total_cost"] == expected_total, "total_cost should be equal to response.cost * 2"

    wrapper.clear_usage_summary()
    assert wrapper.actual_usage_summary is None, "actual_usage_summary should be None"
    assert wrapper.total_usage_summary is None, "total_usage_summary should be None"

    # With cache_seed=42 the test expects the total to grow while the actual summary stays empty.
    wrapper.create(**request, cache_seed=42)
    assert wrapper.total_usage_summary["total_cost"] > 0, "total_cost should be greater than 0"
    assert wrapper.actual_usage_summary is None, "No actual cost should be recorded"
if __name__ == "__main__":
    # Run the tests directly (outside pytest). skipif/parametrize decorators are
    # not applied here, so parametrized tests must be given their arguments.
    test_aoai_chat_completion()
    test_oai_tool_calling_extraction()
    test_chat_completion()
    test_completion()
    # test_cost requires (cache_seed, model); calling it bare raises TypeError.
    # Keep these in sync with test_cost's @pytest.mark.parametrize list.
    for cache_seed, model in [
        (None, "gpt-3.5-turbo-instruct"),
        (42, "gpt-3.5-turbo-instruct"),
        (None, "text-ada-001"),
    ]:
        test_cost(cache_seed, model)
    test_usage_summary()