Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

url for api_base and base_url #76

Merged
merged 1 commit into from
Apr 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ chat.print_log()

```python
# 编写处理函数
def msg2chat(msg):
def data2chat(msg):
chat = Chat()
chat.system("你是一个熟练的数字翻译家。")
chat.user(f"请将该数字翻译为罗马数字:{msg}")
Expand All @@ -81,10 +81,8 @@ def msg2chat(msg):

checkpoint = "chat.jsonl" # 缓存文件的名称
msgs = ["1", "2", "3", "4", "5", "6", "7", "8", "9"]
# 处理部分数据
chats = process_chats(msgs[:5], msg2chat, checkpoint)
# 处理所有数据,从上一次处理的位置继续
continue_chats = process_chats(msgs, msg2chat, checkpoint)
# 处理数据,如果 checkpoint 存在,则从上次中断处继续
continue_chats = process_chats(msgs, data2chat, checkpoint)
```

示例3,批量处理数据(异步并行),用不同语言打印 hello,并使用两个协程:
Expand All @@ -103,6 +101,12 @@ async_chat_completion(langs, chkpoint="async_chat.jsonl", nproc=2, data2chat=dat
chats = load_chats("async_chat.jsonl")
```

在 Jupyter Notebook 中运行,需要使用 `await` 关键字和 `wait=True` 参数:

```python
await async_chat_completion(langs, chkpoint="async_chat.jsonl", nproc=2, data2chat=data2chat, wait=True)
```

示例4,使用工具(自定义函数):

```python
Expand Down
10 changes: 7 additions & 3 deletions README_en.md
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,7 @@ def msg2chat(msg):

checkpoint = "chat.jsonl"
msgs = ["%d" % i for i in range(1, 10)]
# process the data
chats = process_chats(msgs[:5], msg2chat, checkpoint)
# process the rest data, and read the cache from the last time
# process the data in batch, if the checkpoint file exists, it will continue from the last checkpoint
continue_chats = process_chats(msgs, msg2chat, checkpoint)
```

Expand All @@ -104,6 +102,12 @@ async_chat_completion(langs, chkpoint="async_chat.jsonl", nproc=2, data2chat=dat
chats = load_chats("async_chat.jsonl")
```

when using `async_chat_completion` in Jupyter notebook, you should use the `await` keyword and the `wait=True` parameter:

```python
await async_chat_completion(langs, chkpoint="async_chat.jsonl", nproc=2, data2chat=data2chat, wait=True)
```

## License

This package is licensed under the MIT license. See the LICENSE file for more details.
Expand Down
45 changes: 31 additions & 14 deletions chattool/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

__author__ = """Rex Wang"""
__email__ = '1073853456@qq.com'
__version__ = '3.1.4'
__version__ = '3.1.5'

import os, sys, requests
from .chattype import Chat, Resp
Expand All @@ -13,7 +13,24 @@
from .asynctool import async_chat_completion
from .functioncall import generate_json_schema, exec_python_code
from typing import Union
import dotenv
import dotenv

raw_env_text = f"""# Description: Env file for ChatTool.
# Current version: {__version__}

# The base url of the API (with suffix /v1)
# This will override OPENAI_API_BASE_URL if both are set.
OPENAI_API_BASE=''

# The base url of the API (without suffix /v1)
OPENAI_API_BASE_URL=''

# Your API key
OPENAI_API_KEY=''

# The default model name
OPENAI_API_MODEL=''
"""

def load_envs(env:Union[None, str, dict]=None):
"""Read the environment variables for the API call"""
Expand All @@ -26,21 +43,21 @@ def load_envs(env:Union[None, str, dict]=None):
os.environ[key] = value
# else: load from environment variables
api_key = os.getenv('OPENAI_API_KEY')
base_url = os.getenv('OPENAI_API_BASE_URL') or "https://api.openai.com"
api_base = os.getenv('OPENAI_API_BASE') or os.path.join(base_url, 'v1')
base_url = request.normalize_url(base_url)
api_base = request.normalize_url(api_base)
model = os.getenv('OPENAI_API_MODEL') or "gpt-3.5-turbo"
base_url = os.getenv('OPENAI_API_BASE_URL') # or "https://api.openai.com"
api_base = os.getenv('OPENAI_API_BASE') # or os.path.join(base_url, 'v1')
model = os.getenv('OPENAI_API_MODEL') # or "gpt-3.5-turbo"
return True

def save_envs(env_file:str):
"""Save the environment variables for the API call"""
global api_key, base_url, model
global api_key, base_url, model, api_base
set_key = lambda key, value: dotenv.set_key(env_file, key, value) if value else None
with open(env_file, "w") as f:
f.write(f"OPENAI_API_KEY={api_key}\n")
f.write(f"OPENAI_API_BASE_URL={base_url}\n")
f.write(f"OPENAI_API_BASE={api_base}\n")
f.write(f"OPENAI_API_MODEL={model}\n")
f.write(raw_env_text)
set_key('OPENAI_API_KEY', api_key)
set_key('OPENAI_API_BASE_URL', base_url)
set_key('OPENAI_API_BASE', api_base)
set_key('OPENAI_API_MODEL', model)
return True

# load the environment variables
Expand Down Expand Up @@ -100,14 +117,14 @@ def debug_log( net_url:str="https://www.baidu.com"
return False

## Check the proxy status
print("\nCheck your proxy:" +\
print("\nCheck your proxy: " +\
"This is not necessary if the base url is already a proxy link.")
proxy_status()

## Base url
print("\nCheck the value OPENAI_API_BASE_URL:")
print(base_url)
print("\nCheck the value OPENAI_API_BASE:" +\
print("\nCheck the value OPENAI_API_BASE: " +\
"This will override OPENAI_API_BASE_URL if both are set.")
print(api_base)

Expand Down
50 changes: 34 additions & 16 deletions chattool/chattype.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@ class Chat():
def __init__( self
, msg:Union[List[Dict], None, str]=None
, api_key:Union[None, str]=None
, chat_url:Union[None, str]=None
, base_url:Union[None, str]=None
, api_base:Union[None, str]=None
, base_url:Union[None, str]=None
, chat_url:Union[None, str]=None
, model:Union[None, str]=None
, functions:Union[None, List[Dict]]=None
, function_call:Union[None, str]=None
Expand All @@ -25,8 +25,9 @@ def __init__( self
Args:
msg (Union[List[Dict], None, str], optional): chat log. Defaults to None.
api_key (Union[None, str], optional): API key. Defaults to None.
chat_url (Union[None, str], optional): base url. Defaults to None. Example: "https://api.openai.com/v1/chat/completions"
base_url (Union[None, str], optional): base url. Defaults to None. Example: "https://api.openai.com"
api_base (Union[None, str], optional): base url with suffix "/v1". Defaults to None. Example: "https://api.openai.com/v1"
base_url (Union[None, str], optional): base url without suffix "/v1". Defaults to None. Example: "https://api.openai.com"
chat_url (Union[None, str], optional): chat completion url. Defaults to None. Example: "https://api.openai.com/v1/chat/completions"
model (Union[None, str], optional): model to use. Defaults to None.
functions (Union[None, List[Dict]], optional): functions to use, each function is a JSON Schema. Defaults to None.
function_call (str, optional): method to call the function. Defaults to None. Choices: ['auto', '$NameOfTheFunction', 'none']
Expand All @@ -45,15 +46,23 @@ def __init__( self
self._chat_log = msg.copy() # avoid changing the original list
else:
raise ValueError("msg should be a list of dict, a string or None")
self._api_key = api_key or chattool.api_key
# try: api_base => base_url => chattool.api_base => chattool.base_url
if api_base is None:
api_base = os.path.join(base_url, 'v1') if base_url is not None else chattool.api_base
base_url = base_url or chattool.base_url
self._base_url = base_url
self._api_base = api_base or os.path.join(base_url, "v1")
self._chat_url = chat_url or self._api_base.rstrip('/') + '/chat/completions'
self._model = model or chattool.model
self.api_key = api_key or chattool.api_key
# chat_url > api_base > base_url > chattool.api_base > chattool.base_url
self.api_base = api_base or chattool.api_base
self.base_url = base_url or chattool.base_url
self.model = model or chattool.model or "gpt-3.5-turbo"
if chat_url:
self.chat_url = chat_url
elif api_base:
self.chat_url = os.path.join(self.api_base, "chat/completions")
elif base_url:
self.chat_url = os.path.join(self.base_url, "v1/chat/completions")
elif chattool.api_base:
self.chat_url = os.path.join(chattool.api_base, "chat/completions")
elif chattool.base_url:
self.chat_url = os.path.join(chattool.base_url, "v1/chat/completions")
else:
self.chat_url = "https://api.openai.com/v1/chat/completions"
if functions is not None:
assert isinstance(functions, list), "functions should be a list of dict"
self._functions, self._function_call = functions, function_call
Expand Down Expand Up @@ -104,6 +113,7 @@ def deepcopy(self):
, functions=self.functions
, function_call=self.function_call
, name2func=self.name2func
, api_base=self.api_base
, base_url=self.base_url)

def save(self, path:str, mode:str='a', index:int=0):
Expand Down Expand Up @@ -333,7 +343,11 @@ def get_valid_models(self, gpt_only:bool=True)->List[str]:
Returns:
List[str]: valid models
"""
model_url = os.path.join(self.api_base, 'models')
model_url = "https://api.openai.com/v1/models"
if self.api_base:
model_url = os.path.join(self.api_base, 'models')
elif self.base_url:
model_url = os.path.join(self.base_url, 'v1/models')
return valid_models(self.api_key, model_url, gpt_only=gpt_only)

# Part5: properties and setters
Expand Down Expand Up @@ -407,13 +421,17 @@ def chat_url(self, chat_url:str):
self._chat_url = chat_url

@base_url.setter
def base_url(self, base_url:str):
def base_url(self, base_url:Union[None, str]):
"""Set base url"""
if base_url:
self.chat_url = base_url.rstrip('/') + '/v1/chat/completions'
self._base_url = base_url

@api_base.setter
def api_base(self, api_base:str):
def api_base(self, api_base:Union[None, str]):
"""Set base url"""
if api_base:
self.chat_url = api_base.rstrip('/') + '/chat/completions'
self._api_base = api_base

@functions.setter
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
with open('README.md') as readme_file:
readme = readme_file.read()

VERSION = '3.1.4'
VERSION = '3.1.5'

requirements = [
'Click>=7.0', 'requests>=2.20', "responses>=0.23", 'aiohttp>=3.8',
Expand Down
15 changes: 1 addition & 14 deletions tests/test_chattool.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,10 @@
from click.testing import CliRunner
import chattool, json, os
from chattool import cli
from chattool import Chat, Resp, findcost, load_envs, save_envs
from chattool import Chat, Resp, findcost
import pytest
testpath = 'tests/testfiles/'

def test_env_file():
save_envs(testpath + "chattool.env")
with open(testpath + "test.env", "w") as f:
f.write("OPENAI_API_KEY=sk-132\n")
f.write("OPENAI_API_BASE_URL=https://api.example.com\n")
f.write("OPENAI_API_MODEL=gpt-3.5-turbo-0301\n")
load_envs(testpath + "test.env")
assert chattool.api_key == "sk-132"
assert chattool.base_url == "https://api.example.com"
assert chattool.model == "gpt-3.5-turbo-0301"
# reset the environment variables
load_envs(testpath + "chattool.env")


def test_command_line_interface():
"""Test the CLI."""
Expand Down
46 changes: 46 additions & 0 deletions tests/test_envs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Test for api_base, base_url, chat_url

import chattool
from chattool import Chat, save_envs, load_envs
testpath = 'tests/testfiles/'

def test_apibase():
api_base, base_url = chattool.api_base, chattool.base_url
chattool.api_base, chattool.base_url = None, None
# chat_url > api_base > base_url > chattool.api_base > chattool.base_url
# chat_url > api_base
chat = Chat(chat_url="https://api.pytest1.com/v1/chat/completions", api_base="https://api.pytest2.com/v1")
assert chat.chat_url == "https://api.pytest1.com/v1/chat/completions"

# api_base > base_url
chat = Chat(api_base="https://api.pytest2.com/v1", base_url="https://api.pytest3.com")
assert chat.chat_url == "https://api.pytest2.com/v1/chat/completions"

# base_url > chattool.api_base
chattool.api_base = "https://api.pytest2.com/v1"
chat = Chat(base_url="https://api.pytest3.com")
assert chat.chat_url == "https://api.pytest3.com/v1/chat/completions"

# chattool.api_base > chattool.base_url
chattool.base_url = "https://api.pytest4.com"
chat = Chat()
assert chat.chat_url == "https://api.pytest2.com/v1/chat/completions"

# base_url > chattool.api_base, chattool.base_url
chat = Chat(base_url="https://api.pytest3.com")
assert chat.chat_url == "https://api.pytest3.com/v1/chat/completions"

chattool.api_base, chattool.base_url = api_base, base_url

def test_env_file():
save_envs(testpath + "chattool.env")
with open(testpath + "test.env", "w") as f:
f.write("OPENAI_API_KEY=sk-132\n")
f.write("OPENAI_API_BASE_URL=https://api.example.com\n")
f.write("OPENAI_API_MODEL=gpt-3.5-turbo-0301\n")
load_envs(testpath + "test.env")
assert chattool.api_key == "sk-132"
assert chattool.base_url == "https://api.example.com"
assert chattool.model == "gpt-3.5-turbo-0301"
# reset the environment variables
load_envs(testpath + "chattool.env")
Loading