Skip to content

Commit

Permalink
Merge pull request #76 from cubenlp/rex/dev
Browse files Browse the repository at this point in the history
url for api_base and base_url
  • Loading branch information
RexWzh committed Apr 5, 2024
2 parents c02bfe1 + fbb49d5 commit df26473
Show file tree
Hide file tree
Showing 7 changed files with 129 additions and 53 deletions.
14 changes: 9 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ chat.print_log()

```python
# 编写处理函数
def msg2chat(msg):
def data2chat(msg):
chat = Chat()
chat.system("你是一个熟练的数字翻译家。")
chat.user(f"请将该数字翻译为罗马数字:{msg}")
Expand All @@ -81,10 +81,8 @@ def msg2chat(msg):

checkpoint = "chat.jsonl" # 缓存文件的名称
msgs = ["1", "2", "3", "4", "5", "6", "7", "8", "9"]
# 处理部分数据
chats = process_chats(msgs[:5], msg2chat, checkpoint)
# 处理所有数据,从上一次处理的位置继续
continue_chats = process_chats(msgs, msg2chat, checkpoint)
# 处理数据,如果 checkpoint 存在,则从上次中断处继续
continue_chats = process_chats(msgs, data2chat, checkpoint)
```

示例3,批量处理数据(异步并行),用不同语言打印 hello,并使用两个协程:
Expand All @@ -103,6 +101,12 @@ async_chat_completion(langs, chkpoint="async_chat.jsonl", nproc=2, data2chat=dat
chats = load_chats("async_chat.jsonl")
```

在 Jupyter Notebook 中运行,需要使用 `await` 关键字和 `wait=True` 参数:

```python
await async_chat_completion(langs, chkpoint="async_chat.jsonl", nproc=2, data2chat=data2chat, wait=True)
```

示例4,使用工具(自定义函数):

```python
Expand Down
10 changes: 7 additions & 3 deletions README_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ def msg2chat(msg):

checkpoint = "chat.jsonl"
msgs = ["%d" % i for i in range(1, 10)]
# process the data
chats = process_chats(msgs[:5], msg2chat, checkpoint)
# process the rest data, and read the cache from the last time
# process the data in batch, if the checkpoint file exists, it will continue from the last checkpoint
continue_chats = process_chats(msgs, msg2chat, checkpoint)
```

Expand All @@ -104,6 +102,12 @@ async_chat_completion(langs, chkpoint="async_chat.jsonl", nproc=2, data2chat=dat
chats = load_chats("async_chat.jsonl")
```

when using `async_chat_completion` in Jupyter notebook, you should use the `await` keyword and the `wait=True` parameter:

```python
await async_chat_completion(langs, chkpoint="async_chat.jsonl", nproc=2, data2chat=data2chat, wait=True)
```

## License

This package is licensed under the MIT license. See the LICENSE file for more details.
Expand Down
45 changes: 31 additions & 14 deletions chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '1073853456@qq.com'
__version__ = '3.1.4'
__version__ = '3.1.5'

import os, sys, requests
from .chattype import Chat, Resp
Expand All @@ -13,7 +13,24 @@
from .asynctool import async_chat_completion
from .functioncall import generate_json_schema, exec_python_code
from typing import Union
import dotenv
import dotenv

raw_env_text = f"""# Description: Env file for ChatTool.
# Current version: {__version__}
# The base url of the API (with suffix /v1)
# This will override OPENAI_API_BASE_URL if both are set.
OPENAI_API_BASE=''
# The base url of the API (without suffix /v1)
OPENAI_API_BASE_URL=''
# Your API key
OPENAI_API_KEY=''
# The default model name
OPENAI_API_MODEL=''
"""

def load_envs(env:Union[None, str, dict]=None):
"""Read the environment variables for the API call"""
Expand All @@ -26,21 +43,21 @@ def load_envs(env:Union[None, str, dict]=None):
os.environ[key] = value
# else: load from environment variables
api_key = os.getenv('OPENAI_API_KEY')
base_url = os.getenv('OPENAI_API_BASE_URL') or "https://api.openai.com"
api_base = os.getenv('OPENAI_API_BASE') or os.path.join(base_url, 'v1')
base_url = request.normalize_url(base_url)
api_base = request.normalize_url(api_base)
model = os.getenv('OPENAI_API_MODEL') or "gpt-3.5-turbo"
base_url = os.getenv('OPENAI_API_BASE_URL') # or "https://api.openai.com"
api_base = os.getenv('OPENAI_API_BASE') # or os.path.join(base_url, 'v1')
model = os.getenv('OPENAI_API_MODEL') # or "gpt-3.5-turbo"
return True

def save_envs(env_file:str):
"""Save the environment variables for the API call"""
global api_key, base_url, model
global api_key, base_url, model, api_base
set_key = lambda key, value: dotenv.set_key(env_file, key, value) if value else None
with open(env_file, "w") as f:
f.write(f"OPENAI_API_KEY={api_key}\n")
f.write(f"OPENAI_API_BASE_URL={base_url}\n")
f.write(f"OPENAI_API_BASE={api_base}\n")
f.write(f"OPENAI_API_MODEL={model}\n")
f.write(raw_env_text)
set_key('OPENAI_API_KEY', api_key)
set_key('OPENAI_API_BASE_URL', base_url)
set_key('OPENAI_API_BASE', api_base)
set_key('OPENAI_API_MODEL', model)
return True

# load the environment variables
Expand Down Expand Up @@ -100,14 +117,14 @@ def debug_log( net_url:str="https://www.baidu.com"
return False

## Check the proxy status
print("\nCheck your proxy:" +\
print("\nCheck your proxy: " +\
"This is not necessary if the base url is already a proxy link.")
proxy_status()

## Base url
print("\nCheck the value OPENAI_API_BASE_URL:")
print(base_url)
print("\nCheck the value OPENAI_API_BASE:" +\
print("\nCheck the value OPENAI_API_BASE: " +\
"This will override OPENAI_API_BASE_URL if both are set.")
print(api_base)

Expand Down
50 changes: 34 additions & 16 deletions chattool/chattype.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ class Chat():
def __init__( self
, msg:Union[List[Dict], None, str]=None
, api_key:Union[None, str]=None
, chat_url:Union[None, str]=None
, base_url:Union[None, str]=None
, api_base:Union[None, str]=None
, base_url:Union[None, str]=None
, chat_url:Union[None, str]=None
, model:Union[None, str]=None
, functions:Union[None, List[Dict]]=None
, function_call:Union[None, str]=None
Expand All @@ -25,8 +25,9 @@ def __init__( self
Args:
msg (Union[List[Dict], None, str], optional): chat log. Defaults to None.
api_key (Union[None, str], optional): API key. Defaults to None.
chat_url (Union[None, str], optional): base url. Defaults to None. Example: "https://api.openai.com/v1/chat/completions"
base_url (Union[None, str], optional): base url. Defaults to None. Example: "https://api.openai.com"
api_base (Union[None, str], optional): base url with suffix "/v1". Defaults to None. Example: "https://api.openai.com/v1"
base_url (Union[None, str], optional): base url without suffix "/v1". Defaults to None. Example: "https://api.openai.com"
chat_url (Union[None, str], optional): chat completion url. Defaults to None. Example: "https://api.openai.com/v1/chat/completions"
model (Union[None, str], optional): model to use. Defaults to None.
functions (Union[None, List[Dict]], optional): functions to use, each function is a JSON Schema. Defaults to None.
function_call (str, optional): method to call the function. Defaults to None. Choices: ['auto', '$NameOfTheFunction', 'none']
Expand All @@ -45,15 +46,23 @@ def __init__( self
self._chat_log = msg.copy() # avoid changing the original list
else:
raise ValueError("msg should be a list of dict, a string or None")
self._api_key = api_key or chattool.api_key
# try: api_base => base_url => chattool.api_base => chattool.base_url
if api_base is None:
api_base = os.path.join(base_url, 'v1') if base_url is not None else chattool.api_base
base_url = base_url or chattool.base_url
self._base_url = base_url
self._api_base = api_base or os.path.join(base_url, "v1")
self._chat_url = chat_url or self._api_base.rstrip('/') + '/chat/completions'
self._model = model or chattool.model
self.api_key = api_key or chattool.api_key
# chat_url > api_base > base_url > chattool.api_base > chattool.base_url
self.api_base = api_base or chattool.api_base
self.base_url = base_url or chattool.base_url
self.model = model or chattool.model or "gpt-3.5-turbo"
if chat_url:
self.chat_url = chat_url
elif api_base:
self.chat_url = os.path.join(self.api_base, "chat/completions")
elif base_url:
self.chat_url = os.path.join(self.base_url, "v1/chat/completions")
elif chattool.api_base:
self.chat_url = os.path.join(chattool.api_base, "chat/completions")
elif chattool.base_url:
self.chat_url = os.path.join(chattool.base_url, "v1/chat/completions")
else:
self.chat_url = "https://api.openai.com/v1/chat/completions"
if functions is not None:
assert isinstance(functions, list), "functions should be a list of dict"
self._functions, self._function_call = functions, function_call
Expand Down Expand Up @@ -104,6 +113,7 @@ def deepcopy(self):
, functions=self.functions
, function_call=self.function_call
, name2func=self.name2func
, api_base=self.api_base
, base_url=self.base_url)

def save(self, path:str, mode:str='a', index:int=0):
Expand Down Expand Up @@ -333,7 +343,11 @@ def get_valid_models(self, gpt_only:bool=True)->List[str]:
Returns:
List[str]: valid models
"""
model_url = os.path.join(self.api_base, 'models')
model_url = "https://api.openai.com/v1/models"
if self.api_base:
model_url = os.path.join(self.api_base, 'models')
elif self.base_url:
model_url = os.path.join(self.base_url, 'v1/models')
return valid_models(self.api_key, model_url, gpt_only=gpt_only)

# Part5: properties and setters
Expand Down Expand Up @@ -407,13 +421,17 @@ def chat_url(self, chat_url:str):
self._chat_url = chat_url

@base_url.setter
def base_url(self, base_url:str):
def base_url(self, base_url:Union[None, str]):
"""Set base url"""
if base_url:
self.chat_url = base_url.rstrip('/') + '/v1/chat/completions'
self._base_url = base_url

@api_base.setter
def api_base(self, api_base:str):
def api_base(self, api_base:Union[None, str]):
"""Set base url"""
if api_base:
self.chat_url = api_base.rstrip('/') + '/chat/completions'
self._api_base = api_base

@functions.setter
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '3.1.4'
VERSION = '3.1.5'

requirements = [
'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',
Expand Down
15 changes: 1 addition & 14 deletions tests/test_chattool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,10 @@
from click.testing import CliRunner
import chattool, json, os
from chattool import cli
from chattool import Chat, Resp, findcost, load_envs, save_envs
from chattool import Chat, Resp, findcost
import pytest
testpath = 'tests/testfiles/'

def test_env_file():
save_envs(testpath + "chattool.env")
with open(testpath + "test.env", "w") as f:
f.write("OPENAI_API_KEY=sk-132\n")
f.write("OPENAI_API_BASE_URL=https://api.example.com\n")
f.write("OPENAI_API_MODEL=gpt-3.5-turbo-0301\n")
load_envs(testpath + "test.env")
assert chattool.api_key == "sk-132"
assert chattool.base_url == "https://api.example.com"
assert chattool.model == "gpt-3.5-turbo-0301"
# reset the environment variables
load_envs(testpath + "chattool.env")


def test_command_line_interface():
"""Test the CLI."""
Expand Down
46 changes: 46 additions & 0 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Test for api_base, base_url, chat_url

import chattool
from chattool import Chat, save_envs, load_envs
testpath = 'tests/testfiles/'

def test_apibase():
api_base, base_url = chattool.api_base, chattool.base_url
chattool.api_base, chattool.base_url = None, None
# chat_url > api_base > base_url > chattool.api_base > chattool.base_url
# chat_url > api_base
chat = Chat(chat_url="https://api.pytest1.com/v1/chat/completions", api_base="https://api.pytest2.com/v1")
assert chat.chat_url == "https://api.pytest1.com/v1/chat/completions"

# api_base > base_url
chat = Chat(api_base="https://api.pytest2.com/v1", base_url="https://api.pytest3.com")
assert chat.chat_url == "https://api.pytest2.com/v1/chat/completions"

# base_url > chattool.api_base
chattool.api_base = "https://api.pytest2.com/v1"
chat = Chat(base_url="https://api.pytest3.com")
assert chat.chat_url == "https://api.pytest3.com/v1/chat/completions"

# chattool.api_base > chattool.base_url
chattool.base_url = "https://api.pytest4.com"
chat = Chat()
assert chat.chat_url == "https://api.pytest2.com/v1/chat/completions"

# base_url > chattool.api_base, chattool.base_url
chat = Chat(base_url="https://api.pytest3.com")
assert chat.chat_url == "https://api.pytest3.com/v1/chat/completions"

chattool.api_base, chattool.base_url = api_base, base_url

def test_env_file():
save_envs(testpath + "chattool.env")
with open(testpath + "test.env", "w") as f:
f.write("OPENAI_API_KEY=sk-132\n")
f.write("OPENAI_API_BASE_URL=https://api.example.com\n")
f.write("OPENAI_API_MODEL=gpt-3.5-turbo-0301\n")
load_envs(testpath + "test.env")
assert chattool.api_key == "sk-132"
assert chattool.base_url == "https://api.example.com"
assert chattool.model == "gpt-3.5-turbo-0301"
# reset the environment variables
load_envs(testpath + "chattool.env")

0 comments on commit df26473

Please sign in to comment.