Skip to content

Commit

Permalink
fix valid model options
Browse files Browse the repository at this point in the history
  • Loading branch information
RexWzh committed Mar 8, 2024
1 parent 4ef16e9 commit dad777d
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 14 deletions.
7 changes: 6 additions & 1 deletion chattool/asynctool.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,12 @@ def async_chat_completion( msgs:Union[List[List[Dict]], str]
api_key = chattool.api_key
assert api_key is not None, "API key is not provided!"
if chat_url is None:
chat_url = os.path.join(chattool.base_url, "v1/chat/completions")
if chattool.api_base:
chat_url = os.path.join(chattool.api_base, "chat/completions")
elif chattool.base_url:
chat_url = os.path.join(chattool.base_url, "v1/chat/completions")
else:
raise Exception("chat_url is not provided!")
chat_url = chattool.request.normalize_url(chat_url)
# run async process
assert nproc > 0, "nproc must be greater than 0!"
Expand Down
13 changes: 12 additions & 1 deletion chattool/chattool.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,8 @@ def get_valid_models(self, gpt_only:bool=True)->List[str]:
Returns:
List[str]: valid models
"""
return valid_models(self.api_key, self.base_url, gpt_only=gpt_only)
model_url = os.path.join(self.api_base, 'models')
return valid_models(self.api_key, model_url, gpt_only=gpt_only)

# Part5: properties and setters
@property
Expand Down Expand Up @@ -399,6 +400,11 @@ def base_url(self):
"""Get base url"""
return self._base_url

@property
def api_base(self):
"""Get base url"""
return self._api_base

@property
def functions(self):
"""Get functions"""
Expand Down Expand Up @@ -428,6 +434,11 @@ def chat_url(self, chat_url:str):
def base_url(self, base_url:str):
"""Set base url"""
self._base_url = base_url

@api_base.setter
def api_base(self, api_base:str):
"""Set base url"""
self._api_base = api_base

@functions.setter
def functions(self, functions:List[Dict]):
Expand Down
11 changes: 5 additions & 6 deletions chattool/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def chat_completion( api_key:str
raise Exception(response.text)
return response.json()

def valid_models(api_key:str, base_url:str, gpt_only:bool=True):
def valid_models(api_key:str, model_url:str, gpt_only:bool=True):
"""Get valid models
Request url: https://api.openai.com/v1/models
Expand All @@ -97,14 +97,13 @@ def valid_models(api_key:str, base_url:str, gpt_only:bool=True):
"Authorization": "Bearer " + api_key,
"Content-Type": "application/json"
}
models_url = normalize_url(os.path.join(base_url, "v1/models"))
models_response = requests.get(models_url, headers=headers)
if models_response.status_code == 200:
data = models_response.json()
model_response = requests.get(normalize_url(model_url), headers=headers)
if model_response.status_code == 200:
data = model_response.json()
model_list = [model.get("id") for model in data.get("data")]
return [model for model in model_list if "gpt" in model] if gpt_only else model_list
else:
raise Exception(models_response.text)
raise Exception(model_response.text)

def loadfile(api_key:str, base_url:str, file:str, purpose:str='fine-tune'):
"""Upload a file that can be used across various endpoints/features.
Expand Down
4 changes: 2 additions & 2 deletions tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@

def test_simple():
# set api_key in the environment variable
debug_log()
chat = Chat()
chat.user("Hello!")
chat.getresponse()
debug_log()
assert chat.chat_log[0] == {"role": "user", "content": "Hello!"}
assert len(chat.chat_log) == 2

Expand Down Expand Up @@ -54,7 +54,7 @@ def test_async_process():
chkpoint = testpath + "test_async.jsonl"
t = time.time()
resp = async_chat_completion(chatlogs[:1], chkpoint, clearfile=True, nproc=3)
resp = async_chat_completion(chatlogs, chkpoint, clearfile=True, nproc=3)
resp = async_chat_completion(chatlogs, chkpoint, nproc=3)
assert all(resp)
print(f"Time elapsed: {time.time() - t:.2f}s")

Expand Down
12 changes: 8 additions & 4 deletions tests/test_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,18 @@
create_finetune_job, list_finetune_job, retrievejob,
listevents, canceljob, deletemodel
)
import pytest, chattool
api_key, base_url = chattool.api_key, chattool.base_url
import pytest, chattool, os
api_key, base_url, api_base = chattool.api_key, chattool.base_url, chattool.api_base
testpath = 'tests/testfiles/'

def test_valid_models():
models = valid_models(api_key, base_url, gpt_only=False)
if chattool.api_base:
model_url = os.path.join(chattool.api_base, 'models')
else:
model_url = os.path.join(chattool.base_url, 'v1/models')
models = valid_models(api_key, model_url, gpt_only=False)
assert len(models) >= 1
models = valid_models(api_key, base_url, gpt_only=True)
models = valid_models(api_key, model_url, gpt_only=True)
assert len(models) >= 1
assert 'gpt-3.5-turbo' in models

Expand Down

0 comments on commit dad777d

Please sign in to comment.