Skip to content

Commit

Permalink
Merge pull request #72 from cubenlp/rex/fix-async
Browse files Browse the repository at this point in the history
fix async response
  • Loading branch information
RexWzh committed Mar 8, 2024
2 parents a6afa7f + dad777d commit 2cd06bb
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 41 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ pip install chattool --upgrade

```bash
export OPENAI_API_KEY="sk-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
export OPENAI_API_BASEL="https://api.example.com/v1"
export OPENAI_API_BASE="https://api.example.com/v1"
export OPENAI_API_BASE_URL="https://api.example.com" # 可选
```

Win 在系统中设置环境变量。

注:环境变量中,`OPENAI_API_BASE` 优先于 `OPENAI_API_BASE_URL`,二者选其一即可。

### 示例

示例1,模拟多轮对话:
Expand Down
7 changes: 4 additions & 3 deletions chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '1073853456@qq.com'
__version__ = '3.1.0'
__version__ = '3.1.1'

import os, sys, requests
from .chattool import Chat, Resp
Expand All @@ -27,10 +27,10 @@ def load_envs(env:Union[None, str, dict]=None):
# else: load from environment variables
api_key = os.getenv('OPENAI_API_KEY')
base_url = os.getenv('OPENAI_API_BASE_URL') or "https://api.openai.com"
api_base = os.getenv('OPENAI_API_BASE', os.path.join(base_url, 'v1'))
api_base = os.getenv('OPENAI_API_BASE') or os.path.join(base_url, 'v1')
base_url = request.normalize_url(base_url)
api_base = request.normalize_url(api_base)
model = os.getenv('OPENAI_API_MODEL', "gpt-3.5-turbo")
model = os.getenv('OPENAI_API_MODEL') or "gpt-3.5-turbo"
return True

def save_envs(env_file:str):
Expand Down Expand Up @@ -91,6 +91,7 @@ def debug_log( net_url:str="https://www.baidu.com"
Returns:
bool: True if the debug is finished.
"""
print("Current version:", __version__)
# Network test
try:
requests.get(net_url, timeout=timeout)
Expand Down
7 changes: 6 additions & 1 deletion chattool/asynctool.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,12 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
api_key = chattool.api_key
assert api_key is not None, "API key is not provided!"
if chat_url is None:
chat_url = os.path.join(chattool.base_url, "v1/chat/completions")
if chattool.api_base:
chat_url = os.path.join(chattool.api_base, "chat/completions")
elif chattool.base_url:
chat_url = os.path.join(chattool.base_url, "v1/chat/completions")
else:
raise Exception("chat_url is not provided!")
chat_url = chattool.request.normalize_url(chat_url)
# run async process
assert nproc > 0, "nproc must be greater than 0!"
Expand Down
39 changes: 27 additions & 12 deletions chattool/chattool.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,20 +251,24 @@ async def async_stream_responses(self, timeout:int=0, textonly:bool=False):
if not line: break
# strip the prefix of `data: {...}`
strline = line.decode().lstrip('data:').strip()
if strline == '[DONE]': break
# skip empty line
if not strline: continue
# read the json string
line = json.loads(strline)
# wrap the response
resp = Resp(line)
# stop if the response is finished
if resp.finish_reason == 'stop': break
# deal with the message
if 'content' not in resp.delta: continue
if textonly:
yield resp.delta_content
else:
yield resp
try:
# wrap the response
resp = Resp(json.loads(strline))
# stop if the response is finished
if resp.finish_reason == 'stop': break
# deal with the message
if 'content' not in resp.delta: continue
if textonly:
yield resp.delta_content
else:
yield resp
except Exception as e:
print(f"Error: {e}, line: {strline}")
break

# Part3: function call
def iswaiting(self):
Expand Down Expand Up @@ -353,7 +357,8 @@ def get_valid_models(self, gpt_only:bool=True)->List[str]:
Returns:
List[str]: valid models
"""
return valid_models(self.api_key, self.base_url, gpt_only=gpt_only)
model_url = os.path.join(self.api_base, 'models')
return valid_models(self.api_key, model_url, gpt_only=gpt_only)

# Part5: properties and setters
@property
Expand Down Expand Up @@ -395,6 +400,11 @@ def base_url(self):
"""Get base url"""
return self._base_url

@property
def api_base(self):
"""Get base url"""
return self._api_base

@property
def functions(self):
"""Get functions"""
Expand Down Expand Up @@ -424,6 +434,11 @@ def chat_url(self, chat_url:str):
def base_url(self, base_url:str):
"""Set base url"""
self._base_url = base_url

@api_base.setter
def api_base(self, api_base:str):
"""Set base url"""
self._api_base = api_base

@functions.setter
def functions(self, functions:List[Dict]):
Expand Down
11 changes: 5 additions & 6 deletions chattool/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def chat_completion( api_key:str
raise Exception(response.text)
return response.json()

def valid_models(api_key:str, base_url:str, gpt_only:bool=True):
def valid_models(api_key:str, model_url:str, gpt_only:bool=True):
"""Get valid models
Request url: https://api.openai.com/v1/models
Expand All @@ -97,14 +97,13 @@ def valid_models(api_key:str, base_url:str, gpt_only:bool=True):
"Authorization": "Bearer " + api_key,
"Content-Type": "application/json"
}
models_url = normalize_url(os.path.join(base_url, "v1/models"))
models_response = requests.get(models_url, headers=headers)
if models_response.status_code == 200:
data = models_response.json()
model_response = requests.get(normalize_url(model_url), headers=headers)
if model_response.status_code == 200:
data = model_response.json()
model_list = [model.get("id") for model in data.get("data")]
return [model for model in model_list if "gpt" in model] if gpt_only else model_list
else:
raise Exception(models_response.text)
raise Exception(model_response.text)

def loadfile(api_key:str, base_url:str, file:str, purpose:str='fine-tune'):
"""Upload a file that can be used across various endpoints/features.
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '3.1.0'
VERSION = '3.1.1'

requirements = [
'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',
Expand Down
12 changes: 1 addition & 11 deletions tests/__init__.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,9 @@
"""Unit test package for chattool."""

from chattool import Chat
from chattool import Chat, debug_log
import os

if not os.path.exists('tests'):
os.mkdir('tests')
if not os.path.exists('tests/testfiles'):
os.mkdir('tests/testfiles')

def test_simple():
# set api_key in the environment variable
chat = Chat()
chat.user("Hello!")
chat.getresponse()
chat.print_log()
assert chat.chat_log[0] == {"role": "user", "content": "Hello!"}
assert len(chat.chat_log) == 2

File renamed without changes.
13 changes: 11 additions & 2 deletions tests/test_async.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import chattool, time, os
from chattool import Chat, process_chats
from chattool import Chat, process_chats, debug_log
from chattool.asynctool import async_chat_completion
import asyncio, pytest

Expand All @@ -10,6 +10,15 @@
]
testpath = 'tests/testfiles/'

def test_simple():
# set api_key in the environment variable
debug_log()
chat = Chat()
chat.user("Hello!")
chat.getresponse()
assert chat.chat_log[0] == {"role": "user", "content": "Hello!"}
assert len(chat.chat_log) == 2

def test_apikey():
assert chattool.api_key.startswith("sk-")

Expand Down Expand Up @@ -45,7 +54,7 @@ def test_async_process():
chkpoint = testpath + "test_async.jsonl"
t = time.time()
resp = async_chat_completion(chatlogs[:1], chkpoint, clearfile=True, nproc=3)
resp = async_chat_completion(chatlogs, chkpoint, clearfile=True, nproc=3)
resp = async_chat_completion(chatlogs, chkpoint, nproc=3)
assert all(resp)
print(f"Time elapsed: {time.time() - t:.2f}s")

Expand Down
12 changes: 8 additions & 4 deletions tests/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
create_finetune_job, list_finetune_job, retrievejob,
listevents, canceljob, deletemodel
)
import pytest, chattool
api_key, base_url = chattool.api_key, chattool.base_url
import pytest, chattool, os
api_key, base_url, api_base = chattool.api_key, chattool.base_url, chattool.api_base
testpath = 'tests/testfiles/'

def test_valid_models():
models = valid_models(api_key, base_url, gpt_only=False)
if chattool.api_base:
model_url = os.path.join(chattool.api_base, 'models')
else:
model_url = os.path.join(chattool.base_url, 'v1/models')
models = valid_models(api_key, model_url, gpt_only=False)
assert len(models) >= 1
models = valid_models(api_key, base_url, gpt_only=True)
models = valid_models(api_key, model_url, gpt_only=True)
assert len(models) >= 1
assert 'gpt-3.5-turbo' in models

Expand Down

0 comments on commit 2cd06bb

Please sign in to comment.