From c60a5018453eb6e25e73b369621253a1fbe20820 Mon Sep 17 00:00:00 2001 From: Grant Linville Date: Tue, 17 Dec 2024 17:03:06 -0500 Subject: [PATCH] enhance: get more information about models Signed-off-by: Grant Linville --- gptscript/gptscript.py | 8 +++++--- gptscript/openai.py | 28 ++++++++++++++++++++++++++++ tests/test_gptscript.py | 8 ++++---- 3 files changed, 37 insertions(+), 7 deletions(-) create mode 100644 gptscript/openai.py diff --git a/gptscript/gptscript.py b/gptscript/gptscript.py index 392f70b..49d2e54 100644 --- a/gptscript/gptscript.py +++ b/gptscript/gptscript.py @@ -11,6 +11,7 @@ from gptscript.datasets import DatasetElementMeta, DatasetElement, DatasetMeta from gptscript.fileinfo import FileInfo from gptscript.frame import RunFrame, CallFrame, PromptFrame, Program +from gptscript.openai import Model from gptscript.opts import GlobalOptions from gptscript.prompt import PromptResponse from gptscript.run import Run, RunBasicCommand, Options @@ -164,16 +165,17 @@ async def _run_basic_command(self, sub_command: str, request_body: Any = None): async def version(self) -> str: return await self._run_basic_command("version") - async def list_models(self, providers: list[str] = None, credential_overrides: list[str] = None) -> list[str]: + async def list_models(self, providers: list[str] = None, credential_overrides: list[str] = None) -> list[Model]: if self.opts.DefaultModelProvider != "": if providers is None: providers = [] providers.append(self.opts.DefaultModelProvider) - return (await self._run_basic_command( + res = await self._run_basic_command( "list-models", {"providers": providers, "credentialOverrides": credential_overrides} - )).split("\n") + ) + return [Model(**model) for model in json.loads(res)] async def list_credentials(self, contexts: List[str] = None, all_contexts: bool = False) -> list[Credential] | str: if contexts is None: diff --git a/gptscript/openai.py b/gptscript/openai.py new file mode 100644 index 0000000..c1115b6 --- /dev/null +++ b/gptscript/openai.py @@ -0,0 +1,28 @@ +from pydantic import BaseModel, conlist +from typing import Any, Dict, Optional + + +class Permission(BaseModel): + created: int + id: str + object: str + allow_create_engine: bool + allow_sampling: bool + allow_logprobs: bool + allow_search_indices: bool + allow_view: bool + allow_fine_tuning: bool + organization: str + group: Any + is_blocking: bool + + +class Model(BaseModel): + created: Optional[int] + id: str + object: str + owned_by: str + permission: Optional[conlist(Permission)] + root: Optional[str] + parent: Optional[str] + metadata: Optional[Dict[str, str]] diff --git a/tests/test_gptscript.py b/tests/test_gptscript.py index f83607f..9378be1 100644 --- a/tests/test_gptscript.py +++ b/tests/test_gptscript.py @@ -126,8 +126,8 @@ async def test_list_models_from_provider(gptscript): ) assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list" for model in models: - assert model.startswith("claude-3-"), "Unexpected model name" - assert model.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name" + assert model.id.startswith("claude-3-"), "Unexpected model name" + assert model.id.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name" @pytest.mark.asyncio @@ -140,8 +140,8 @@ async def test_list_models_from_default_provider(): ) assert isinstance(models, list) and len(models) > 1, "Expected list_models to return a list" for model in models: - assert model.startswith("claude-3-"), "Unexpected model name" - assert model.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name" + assert model.id.startswith("claude-3-"), "Unexpected model name" + assert model.id.endswith("from github.com/gptscript-ai/claude3-anthropic-provider"), "Unexpected model name" finally: g.close()