Skip to content

Commit 85f2fb0

Browse files
authored
Merge pull request #20 from meta-llama/stainless_sync_NDrrRs
Sync SDK & CLI for eval_tasks / scoring_functions / datasets
2 parents 0901251 + 9b4b4eb commit 85f2fb0

File tree

118 files changed

+4338
-1635
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

118 files changed

+4338
-1635
lines changed

src/llama_stack_client/_client.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,12 @@
3535
class LlamaStackClient(SyncAPIClient):
3636
agents: resources.AgentsResource
3737
batch_inferences: resources.BatchInferencesResource
38+
datasets: resources.DatasetsResource
39+
eval: resources.EvalResource
3840
inspect: resources.InspectResource
3941
inference: resources.InferenceResource
4042
memory: resources.MemoryResource
4143
memory_banks: resources.MemoryBanksResource
42-
datasets: resources.DatasetsResource
4344
models: resources.ModelsResource
4445
post_training: resources.PostTrainingResource
4546
providers: resources.ProvidersResource
@@ -51,7 +52,7 @@ class LlamaStackClient(SyncAPIClient):
5152
datasetio: resources.DatasetioResource
5253
scoring: resources.ScoringResource
5354
scoring_functions: resources.ScoringFunctionsResource
54-
eval: resources.EvalResource
55+
eval_tasks: resources.EvalTasksResource
5556
with_raw_response: LlamaStackClientWithRawResponse
5657
with_streaming_response: LlamaStackClientWithStreamedResponse
5758

@@ -85,7 +86,6 @@ def __init__(
8586
base_url = os.environ.get("LLAMA_STACK_CLIENT_BASE_URL")
8687
if base_url is None:
8788
base_url = f"http://any-hosted-llama-stack.com"
88-
8989
if provider_data:
9090
if default_headers is None:
9191
default_headers = {}
@@ -104,11 +104,12 @@ def __init__(
104104

105105
self.agents = resources.AgentsResource(self)
106106
self.batch_inferences = resources.BatchInferencesResource(self)
107+
self.datasets = resources.DatasetsResource(self)
108+
self.eval = resources.EvalResource(self)
107109
self.inspect = resources.InspectResource(self)
108110
self.inference = resources.InferenceResource(self)
109111
self.memory = resources.MemoryResource(self)
110112
self.memory_banks = resources.MemoryBanksResource(self)
111-
self.datasets = resources.DatasetsResource(self)
112113
self.models = resources.ModelsResource(self)
113114
self.post_training = resources.PostTrainingResource(self)
114115
self.providers = resources.ProvidersResource(self)
@@ -120,7 +121,7 @@ def __init__(
120121
self.datasetio = resources.DatasetioResource(self)
121122
self.scoring = resources.ScoringResource(self)
122123
self.scoring_functions = resources.ScoringFunctionsResource(self)
123-
self.eval = resources.EvalResource(self)
124+
self.eval_tasks = resources.EvalTasksResource(self)
124125
self.with_raw_response = LlamaStackClientWithRawResponse(self)
125126
self.with_streaming_response = LlamaStackClientWithStreamedResponse(self)
126127

@@ -224,11 +225,12 @@ def _make_status_error(
224225
class AsyncLlamaStackClient(AsyncAPIClient):
225226
agents: resources.AsyncAgentsResource
226227
batch_inferences: resources.AsyncBatchInferencesResource
228+
datasets: resources.AsyncDatasetsResource
229+
eval: resources.AsyncEvalResource
227230
inspect: resources.AsyncInspectResource
228231
inference: resources.AsyncInferenceResource
229232
memory: resources.AsyncMemoryResource
230233
memory_banks: resources.AsyncMemoryBanksResource
231-
datasets: resources.AsyncDatasetsResource
232234
models: resources.AsyncModelsResource
233235
post_training: resources.AsyncPostTrainingResource
234236
providers: resources.AsyncProvidersResource
@@ -240,7 +242,7 @@ class AsyncLlamaStackClient(AsyncAPIClient):
240242
datasetio: resources.AsyncDatasetioResource
241243
scoring: resources.AsyncScoringResource
242244
scoring_functions: resources.AsyncScoringFunctionsResource
243-
eval: resources.AsyncEvalResource
245+
eval_tasks: resources.AsyncEvalTasksResource
244246
with_raw_response: AsyncLlamaStackClientWithRawResponse
245247
with_streaming_response: AsyncLlamaStackClientWithStreamedResponse
246248

@@ -293,11 +295,12 @@ def __init__(
293295

294296
self.agents = resources.AsyncAgentsResource(self)
295297
self.batch_inferences = resources.AsyncBatchInferencesResource(self)
298+
self.datasets = resources.AsyncDatasetsResource(self)
299+
self.eval = resources.AsyncEvalResource(self)
296300
self.inspect = resources.AsyncInspectResource(self)
297301
self.inference = resources.AsyncInferenceResource(self)
298302
self.memory = resources.AsyncMemoryResource(self)
299303
self.memory_banks = resources.AsyncMemoryBanksResource(self)
300-
self.datasets = resources.AsyncDatasetsResource(self)
301304
self.models = resources.AsyncModelsResource(self)
302305
self.post_training = resources.AsyncPostTrainingResource(self)
303306
self.providers = resources.AsyncProvidersResource(self)
@@ -309,7 +312,7 @@ def __init__(
309312
self.datasetio = resources.AsyncDatasetioResource(self)
310313
self.scoring = resources.AsyncScoringResource(self)
311314
self.scoring_functions = resources.AsyncScoringFunctionsResource(self)
312-
self.eval = resources.AsyncEvalResource(self)
315+
self.eval_tasks = resources.AsyncEvalTasksResource(self)
313316
self.with_raw_response = AsyncLlamaStackClientWithRawResponse(self)
314317
self.with_streaming_response = AsyncLlamaStackClientWithStreamedResponse(self)
315318

@@ -414,11 +417,12 @@ class LlamaStackClientWithRawResponse:
414417
def __init__(self, client: LlamaStackClient) -> None:
415418
self.agents = resources.AgentsResourceWithRawResponse(client.agents)
416419
self.batch_inferences = resources.BatchInferencesResourceWithRawResponse(client.batch_inferences)
420+
self.datasets = resources.DatasetsResourceWithRawResponse(client.datasets)
421+
self.eval = resources.EvalResourceWithRawResponse(client.eval)
417422
self.inspect = resources.InspectResourceWithRawResponse(client.inspect)
418423
self.inference = resources.InferenceResourceWithRawResponse(client.inference)
419424
self.memory = resources.MemoryResourceWithRawResponse(client.memory)
420425
self.memory_banks = resources.MemoryBanksResourceWithRawResponse(client.memory_banks)
421-
self.datasets = resources.DatasetsResourceWithRawResponse(client.datasets)
422426
self.models = resources.ModelsResourceWithRawResponse(client.models)
423427
self.post_training = resources.PostTrainingResourceWithRawResponse(client.post_training)
424428
self.providers = resources.ProvidersResourceWithRawResponse(client.providers)
@@ -432,18 +436,19 @@ def __init__(self, client: LlamaStackClient) -> None:
432436
self.datasetio = resources.DatasetioResourceWithRawResponse(client.datasetio)
433437
self.scoring = resources.ScoringResourceWithRawResponse(client.scoring)
434438
self.scoring_functions = resources.ScoringFunctionsResourceWithRawResponse(client.scoring_functions)
435-
self.eval = resources.EvalResourceWithRawResponse(client.eval)
439+
self.eval_tasks = resources.EvalTasksResourceWithRawResponse(client.eval_tasks)
436440

437441

438442
class AsyncLlamaStackClientWithRawResponse:
439443
def __init__(self, client: AsyncLlamaStackClient) -> None:
440444
self.agents = resources.AsyncAgentsResourceWithRawResponse(client.agents)
441445
self.batch_inferences = resources.AsyncBatchInferencesResourceWithRawResponse(client.batch_inferences)
446+
self.datasets = resources.AsyncDatasetsResourceWithRawResponse(client.datasets)
447+
self.eval = resources.AsyncEvalResourceWithRawResponse(client.eval)
442448
self.inspect = resources.AsyncInspectResourceWithRawResponse(client.inspect)
443449
self.inference = resources.AsyncInferenceResourceWithRawResponse(client.inference)
444450
self.memory = resources.AsyncMemoryResourceWithRawResponse(client.memory)
445451
self.memory_banks = resources.AsyncMemoryBanksResourceWithRawResponse(client.memory_banks)
446-
self.datasets = resources.AsyncDatasetsResourceWithRawResponse(client.datasets)
447452
self.models = resources.AsyncModelsResourceWithRawResponse(client.models)
448453
self.post_training = resources.AsyncPostTrainingResourceWithRawResponse(client.post_training)
449454
self.providers = resources.AsyncProvidersResourceWithRawResponse(client.providers)
@@ -457,18 +462,19 @@ def __init__(self, client: AsyncLlamaStackClient) -> None:
457462
self.datasetio = resources.AsyncDatasetioResourceWithRawResponse(client.datasetio)
458463
self.scoring = resources.AsyncScoringResourceWithRawResponse(client.scoring)
459464
self.scoring_functions = resources.AsyncScoringFunctionsResourceWithRawResponse(client.scoring_functions)
460-
self.eval = resources.AsyncEvalResourceWithRawResponse(client.eval)
465+
self.eval_tasks = resources.AsyncEvalTasksResourceWithRawResponse(client.eval_tasks)
461466

462467

463468
class LlamaStackClientWithStreamedResponse:
464469
def __init__(self, client: LlamaStackClient) -> None:
465470
self.agents = resources.AgentsResourceWithStreamingResponse(client.agents)
466471
self.batch_inferences = resources.BatchInferencesResourceWithStreamingResponse(client.batch_inferences)
472+
self.datasets = resources.DatasetsResourceWithStreamingResponse(client.datasets)
473+
self.eval = resources.EvalResourceWithStreamingResponse(client.eval)
467474
self.inspect = resources.InspectResourceWithStreamingResponse(client.inspect)
468475
self.inference = resources.InferenceResourceWithStreamingResponse(client.inference)
469476
self.memory = resources.MemoryResourceWithStreamingResponse(client.memory)
470477
self.memory_banks = resources.MemoryBanksResourceWithStreamingResponse(client.memory_banks)
471-
self.datasets = resources.DatasetsResourceWithStreamingResponse(client.datasets)
472478
self.models = resources.ModelsResourceWithStreamingResponse(client.models)
473479
self.post_training = resources.PostTrainingResourceWithStreamingResponse(client.post_training)
474480
self.providers = resources.ProvidersResourceWithStreamingResponse(client.providers)
@@ -482,18 +488,19 @@ def __init__(self, client: LlamaStackClient) -> None:
482488
self.datasetio = resources.DatasetioResourceWithStreamingResponse(client.datasetio)
483489
self.scoring = resources.ScoringResourceWithStreamingResponse(client.scoring)
484490
self.scoring_functions = resources.ScoringFunctionsResourceWithStreamingResponse(client.scoring_functions)
485-
self.eval = resources.EvalResourceWithStreamingResponse(client.eval)
491+
self.eval_tasks = resources.EvalTasksResourceWithStreamingResponse(client.eval_tasks)
486492

487493

488494
class AsyncLlamaStackClientWithStreamedResponse:
489495
def __init__(self, client: AsyncLlamaStackClient) -> None:
490496
self.agents = resources.AsyncAgentsResourceWithStreamingResponse(client.agents)
491497
self.batch_inferences = resources.AsyncBatchInferencesResourceWithStreamingResponse(client.batch_inferences)
498+
self.datasets = resources.AsyncDatasetsResourceWithStreamingResponse(client.datasets)
499+
self.eval = resources.AsyncEvalResourceWithStreamingResponse(client.eval)
492500
self.inspect = resources.AsyncInspectResourceWithStreamingResponse(client.inspect)
493501
self.inference = resources.AsyncInferenceResourceWithStreamingResponse(client.inference)
494502
self.memory = resources.AsyncMemoryResourceWithStreamingResponse(client.memory)
495503
self.memory_banks = resources.AsyncMemoryBanksResourceWithStreamingResponse(client.memory_banks)
496-
self.datasets = resources.AsyncDatasetsResourceWithStreamingResponse(client.datasets)
497504
self.models = resources.AsyncModelsResourceWithStreamingResponse(client.models)
498505
self.post_training = resources.AsyncPostTrainingResourceWithStreamingResponse(client.post_training)
499506
self.providers = resources.AsyncProvidersResourceWithStreamingResponse(client.providers)
@@ -507,7 +514,7 @@ def __init__(self, client: AsyncLlamaStackClient) -> None:
507514
self.datasetio = resources.AsyncDatasetioResourceWithStreamingResponse(client.datasetio)
508515
self.scoring = resources.AsyncScoringResourceWithStreamingResponse(client.scoring)
509516
self.scoring_functions = resources.AsyncScoringFunctionsResourceWithStreamingResponse(client.scoring_functions)
510-
self.eval = resources.AsyncEvalResourceWithStreamingResponse(client.eval)
517+
self.eval_tasks = resources.AsyncEvalTasksResourceWithStreamingResponse(client.eval_tasks)
511518

512519

513520
Client = LlamaStackClient

src/llama_stack_client/_compat.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing import TYPE_CHECKING, Any, Union, Generic, TypeVar, Callable, cast, overload
44
from datetime import date, datetime
5-
from typing_extensions import Self
5+
from typing_extensions import Self, Literal
66

77
import pydantic
88
from pydantic.fields import FieldInfo
@@ -137,9 +137,11 @@ def model_dump(
137137
exclude_unset: bool = False,
138138
exclude_defaults: bool = False,
139139
warnings: bool = True,
140+
mode: Literal["json", "python"] = "python",
140141
) -> dict[str, Any]:
141-
if PYDANTIC_V2:
142+
if PYDANTIC_V2 or hasattr(model, "model_dump"):
142143
return model.model_dump(
144+
mode=mode,
143145
exclude=exclude,
144146
exclude_unset=exclude_unset,
145147
exclude_defaults=exclude_defaults,

src/llama_stack_client/_models.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
PropertyInfo,
3838
is_list,
3939
is_given,
40+
json_safe,
4041
lru_cache,
4142
is_mapping,
4243
parse_date,
@@ -279,8 +280,8 @@ def model_dump(
279280
Returns:
280281
A dictionary representation of the model.
281282
"""
282-
if mode != "python":
283-
raise ValueError("mode is only supported in Pydantic v2")
283+
if mode not in {"json", "python"}:
284+
raise ValueError("mode must be either 'json' or 'python'")
284285
if round_trip != False:
285286
raise ValueError("round_trip is only supported in Pydantic v2")
286287
if warnings != True:
@@ -289,7 +290,7 @@ def model_dump(
289290
raise ValueError("context is only supported in Pydantic v2")
290291
if serialize_as_any != False:
291292
raise ValueError("serialize_as_any is only supported in Pydantic v2")
292-
return super().dict( # pyright: ignore[reportDeprecated]
293+
dumped = super().dict( # pyright: ignore[reportDeprecated]
293294
include=include,
294295
exclude=exclude,
295296
by_alias=by_alias,
@@ -298,6 +299,8 @@ def model_dump(
298299
exclude_none=exclude_none,
299300
)
300301

302+
return cast(dict[str, Any], json_safe(dumped)) if mode == "json" else dumped
303+
301304
@override
302305
def model_dump_json(
303306
self,

src/llama_stack_client/_utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
is_list as is_list,
77
is_given as is_given,
88
is_tuple as is_tuple,
9+
json_safe as json_safe,
910
lru_cache as lru_cache,
1011
is_mapping as is_mapping,
1112
is_tuple_t as is_tuple_t,

src/llama_stack_client/_utils/_transform.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,11 @@ def _transform_recursive(
173173
# Iterable[T]
174174
or (is_iterable_type(stripped_type) and is_iterable(data) and not isinstance(data, str))
175175
):
176+
# dicts are technically iterable, but it is an iterable on the keys of the dict and is not usually
177+
# intended as an iterable, so we don't transform it.
178+
if isinstance(data, dict):
179+
return cast(object, data)
180+
176181
inner_type = extract_type_arg(stripped_type, 0)
177182
return [_transform_recursive(d, annotation=annotation, inner_type=inner_type) for d in data]
178183

@@ -186,7 +191,7 @@ def _transform_recursive(
186191
return data
187192

188193
if isinstance(data, pydantic.BaseModel):
189-
return model_dump(data, exclude_unset=True)
194+
return model_dump(data, exclude_unset=True, mode="json")
190195

191196
annotated_type = _get_annotated_type(annotation)
192197
if annotated_type is None:
@@ -324,7 +329,7 @@ async def _async_transform_recursive(
324329
return data
325330

326331
if isinstance(data, pydantic.BaseModel):
327-
return model_dump(data, exclude_unset=True)
332+
return model_dump(data, exclude_unset=True, mode="json")
328333

329334
annotated_type = _get_annotated_type(annotation)
330335
if annotated_type is None:

src/llama_stack_client/_utils/_utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
overload,
1717
)
1818
from pathlib import Path
19+
from datetime import date, datetime
1920
from typing_extensions import TypeGuard
2021

2122
import sniffio
@@ -395,3 +396,19 @@ def lru_cache(*, maxsize: int | None = 128) -> Callable[[CallableT], CallableT]:
395396
maxsize=maxsize,
396397
)
397398
return cast(Any, wrapper) # type: ignore[no-any-return]
399+
400+
401+
def json_safe(data: object) -> object:
    """Recursively convert *data* into JSON-serialisable primitives.

    Mirrors `pydantic` v2's `model_dump(mode="json")`: mappings and
    sequences are walked recursively, and `datetime`/`date` values are
    rendered as ISO-8601 strings. Anything else is returned unchanged.
    """
    if is_mapping(data):
        # Convert both keys and values so nested dates inside keys are
        # handled the same way as values.
        return {json_safe(key): json_safe(value) for key, value in data.items()}

    # Strings and byte strings are iterable but must stay scalar.
    if isinstance(data, (str, bytes, bytearray)):
        if isinstance(data, (datetime, date)):  # pragma: no cover - unreachable, kept for symmetry
            return data.isoformat()
        return data

    if is_iterable(data):
        return [json_safe(element) for element in data]

    if isinstance(data, (datetime, date)):
        return data.isoformat()

    return data
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
from .datasets import DatasetsParser
8+
9+
__all__ = ["DatasetsParser"]
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
import argparse
8+
9+
from llama_stack_client.lib.cli.subcommand import Subcommand
10+
from .list import DatasetsList
11+
12+
13+
class DatasetsParser(Subcommand):
    """Top-level `datasets` command group for the llama-stack-client CLI."""

    @classmethod
    def create(cls, subparsers: argparse._SubParsersAction):
        # Register the `datasets` group; invoking it without a subcommand
        # just prints the group's help text.
        dataset_parser = subparsers.add_parser(
            "datasets",
            help="Manage datasets",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        dataset_parser.set_defaults(func=lambda _args: dataset_parser.print_help())

        # Attach each concrete dataset subcommand to the group.
        subcommand_group = dataset_parser.add_subparsers(title="subcommands")
        DatasetsList(subcommand_group)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the terms described in the LICENSE file in
5+
# the root directory of this source tree.
6+
7+
import argparse
8+
9+
from llama_stack_client import LlamaStackClient
10+
from llama_stack_client.lib.cli.common.utils import print_table_from_response
11+
from llama_stack_client.lib.cli.configure import get_config
12+
from llama_stack_client.lib.cli.subcommand import Subcommand
13+
14+
15+
class DatasetsList(Subcommand):
    """CLI subcommand: show the datasets available on a distribution endpoint."""

    def __init__(self, subparsers: argparse._SubParsersAction):
        super().__init__()
        self.parser = subparsers.add_parser(
            "list",
            prog="llama-stack-client datasets list",
            description="Show available datasets on distribution endpoint",
            formatter_class=argparse.RawTextHelpFormatter,
        )
        self._add_arguments()
        self.parser.set_defaults(func=self._run_datasets_list_cmd)

    def _add_arguments(self):
        self.parser.add_argument(
            "--endpoint",
            type=str,
            help="Llama Stack distribution endpoint",
        )

    def _run_datasets_list_cmd(self, args: argparse.Namespace):
        """Fetch the dataset list from the endpoint and print it as a table.

        The explicitly supplied ``--endpoint`` flag wins over the saved
        configuration; the original ordering silently ignored the CLI flag
        whenever a configured endpoint existed. ``get_config()`` may return
        ``None`` when the user has never run ``configure``, so the lookup is
        guarded to avoid an ``AttributeError``.
        """
        args.endpoint = args.endpoint or (get_config() or {}).get("endpoint")

        client = LlamaStackClient(
            base_url=args.endpoint,
        )

        headers = ["identifier", "provider_id", "metadata", "type"]

        datasets_list_response = client.datasets.list()
        # An empty/None response means there is nothing to render.
        if datasets_list_response:
            print_table_from_response(datasets_list_response, headers)

0 commit comments

Comments
 (0)