From 84dfa74e78294b448a42f97e033154b3c29fb60f Mon Sep 17 00:00:00 2001
From: Xi Yan
Date: Wed, 30 Oct 2024 16:43:31 -0700
Subject: [PATCH 1/2] sync stainless evals

---
 src/llama_stack/lib/.keep | 4 +
 src/llama_stack_client/_base_client.py | 2 +-
 src/llama_stack_client/_client.py | 72 +-
 src/llama_stack_client/_compat.py | 2 +-
 src/llama_stack_client/_models.py | 10 +-
 src/llama_stack_client/_types.py | 6 +-
 src/llama_stack_client/resources/__init__.py | 108 +--
 .../resources/agents/steps.py | 4 +
 .../resources/agents/turn.py | 4 +
 src/llama_stack_client/resources/datasetio.py | 201 ++++++
 src/llama_stack_client/resources/datasets.py | 163 ++---
 .../resources/eval/__init__.py | 33 +
 src/llama_stack_client/resources/eval/eval.py | 326 +++++++++
 src/llama_stack_client/resources/eval/job.py | 354 +++++++++
 .../resources/evaluate/evaluate.py | 118 +--
 .../resources/evaluate/jobs.py | 250 ++-----
 src/llama_stack_client/resources/inference.py | 8 +
 .../resources/post_training/jobs.py | 360 +++++++++-
 .../resources/post_training/post_training.py | 38 +-
 src/llama_stack_client/resources/scoring.py | 290 ++++++++
 .../resources/scoring_functions.py | 356 +++++++++
 src/llama_stack_client/types/__init__.py | 39 +-
 .../types/agent_create_params.py | 76 +-
 .../types/agent_create_response.py | 1 -
 .../types/agents/session.py | 61 +-
 .../types/agents/session_create_response.py | 1 -
 .../types/agents/step_retrieve_params.py | 2 +
 .../types/agents/step_retrieve_response.py | 122 +---
 src/llama_stack_client/types/agents/turn.py | 121 +---
 .../types/agents/turn_create_response.py | 120 +---
 .../types/agents/turn_retrieve_params.py | 2 +
 .../types/completion_response.py | 17 +-
 .../types/dataset_list_response.py | 38 +
 .../types/dataset_register_params.py | 46 ++
 .../types/dataset_retrieve_params.py | 2 +-
 .../types/dataset_retrieve_response.py | 38 +
 .../datasetio_get_rows_paginated_params.py | 21 +
 src/llama_stack_client/types/eval/__init__.py | 9 +
 .../types/eval/job_cancel_params.py | 15 +
 .../types/eval/job_result_params.py | 15 +
 .../types/eval/job_result_response.py | 19 +
 .../types/eval/job_status.py | 7 +
 .../types/eval/job_status_params.py | 15 +
 .../types/eval_evaluate_batch_params.py | 259 +++++++
 .../types/eval_evaluate_params.py | 259 +++++++
 .../types/eval_evaluate_response.py | 19 +
 .../types/evaluate/__init__.py | 8 +-
 .../types/evaluate/evaluate_response.py | 19 +
 .../types/evaluate/job_artifacts_response.py | 1 -
 .../types/evaluate/job_cancel_params.py | 2 +-
 .../types/evaluate/job_logs_response.py | 1 -
 .../types/evaluate/job_result_params.py | 15 +
 .../types/evaluate/job_status.py | 7 +
 .../types/evaluate/job_status_params.py | 2 +-
 .../types/evaluate/job_status_response.py | 1 -
 .../types/evaluate_evaluate_batch_params.py | 259 +++++++
 .../types/evaluate_evaluate_params.py | 259 +++++++
 .../types/evaluation_job.py | 1 -
 src/llama_stack_client/types/health_info.py | 1 -
 .../types/inference_chat_completion_params.py | 28 +-
 .../inference_chat_completion_response.py | 17 +-
 .../types/inference_completion_params.py | 31 +-
 .../types/inference_completion_response.py | 11 +-
 src/llama_stack_client/types/job.py | 10 +
 .../types/memory_bank_list_response.py | 56 +-
 .../types/memory_bank_register_params.py | 60 +-
 .../types/memory_bank_retrieve_response.py | 56 +-
 .../types/memory_retrieval_step.py | 7 +-
 .../types/model_list_response.py | 23 +
 .../types/model_retrieve_response.py | 23 +
 .../types/paginated_rows_result.py | 15 +
 .../types/post_training_job.py | 1 -
 ...ost_training_preference_optimize_params.py | 5 +-
 ...st_training_supervised_fine_tune_params.py | 5 +-
 src/llama_stack_client/types/provider_info.py | 1 -
 .../types/run_shield_response.py | 16 +-
 .../types/score_batch_response.py | 19 +
 .../types/score_response.py | 17 +
 .../types/scoring_fn_def_with_provider.py | 84 +++
 .../scoring_fn_def_with_provider_param.py | 84 +++
 .../scoring_function_def_with_provider.py | 98 +++
 ...coring_function_def_with_provider_param.py | 98 +++
 .../types/scoring_function_list_response.py | 84 +++
 .../types/scoring_function_register_params.py | 16 +
 .../types/scoring_function_retrieve_params.py | 15 +
 .../scoring_function_retrieve_response.py | 84 +++
 .../types/scoring_score_batch_params.py | 20 +
 .../types/scoring_score_params.py | 18 +
 .../types/shared/__init__.py | 5 +
 .../types/shared/graph_memory_bank_def.py | 15 +
 .../types/shared/key_value_memory_bank_def.py | 15 +
 .../types/shared/keyword_memory_bank_def.py | 15 +
 .../types/shared/safety_violation.py | 16 +
 .../types/shared/vector_memory_bank_def.py | 22 +
 .../types/shared_params/__init__.py | 4 +
 .../shared_params/graph_memory_bank_def.py | 15 +
 .../key_value_memory_bank_def.py | 15 +
 .../shared_params/keyword_memory_bank_def.py | 15 +
 .../shared_params/vector_memory_bank_def.py | 21 +
 .../types/shield_call_step.py | 15 +-
 .../types/shield_list_response.py | 19 +
 .../types/shield_retrieve_response.py | 19 +
 .../synthetic_data_generation_response.py | 33 +-
 .../types/tool_execution_step.py | 19 +-
 src/llama_stack_client/types/tool_response.py | 21 +
 tests/api_resources/agents/test_session.py | 285 ++++
 tests/api_resources/agents/test_steps.py | 26 +-
 tests/api_resources/agents/test_turn.py | 636 +++++++++++++++++
 tests/api_resources/agents/test_turns.py | 118 +--
 tests/api_resources/eval/__init__.py | 1 +
 tests/api_resources/eval/test_job.py | 259 +++++++
 tests/api_resources/evaluate/test_jobs.py | 199 ++++--
 tests/api_resources/memory/test_documents.py | 24 +
 tests/api_resources/post_training/test_job.py | 427 +++++++++++
 .../api_resources/post_training/test_jobs.py | 24 +
 tests/api_resources/test_batch_inferences.py | 675 ++++++++++++++++++
 tests/api_resources/test_datasetio.py | 112 +++
 tests/api_resources/test_datasets.py | 303 ++++----
 tests/api_resources/test_eval.py | 318 +++++++++
 tests/api_resources/test_evaluate.py | 319 +++++++++
 tests/api_resources/test_evaluations.py | 78 --
 tests/api_resources/test_inference.py | 419 ++++-------
 tests/api_resources/test_inspect.py | 86 +++
 tests/api_resources/test_memory.py | 538 +-------------
 tests/api_resources/test_memory_banks.py | 237 ++++--
 tests/api_resources/test_models.py | 224 ++++--
 tests/api_resources/test_post_training.py | 168 +----
 tests/api_resources/test_providers.py | 86 +++
 tests/api_resources/test_reward_scoring.py | 18 +-
 tests/api_resources/test_routes.py | 86 +++
 tests/api_resources/test_safety.py | 18 +-
 tests/api_resources/test_scoring.py | 202 ++++++
 tests/api_resources/test_scoring_functions.py | 438 ++++++++++++
 tests/api_resources/test_shields.py | 224 ++++--
 .../test_synthetic_data_generation.py | 18 +-
 tests/api_resources/test_telemetry.py | 50 +-
 tests/conftest.py | 14 +-
 tests/test_client.py | 344 ++++++++-
 tests/test_models.py | 2 +-
 tests/test_response.py | 50 ++
 140 files changed, 10172 insertions(+), 2749 deletions(-)
 create mode 100644 src/llama_stack/lib/.keep
 create mode 100644 src/llama_stack_client/resources/datasetio.py
 create mode 100644 src/llama_stack_client/resources/eval/__init__.py
 create mode 100644 src/llama_stack_client/resources/eval/eval.py
 create mode 100644 src/llama_stack_client/resources/eval/job.py
 create mode 100644 src/llama_stack_client/resources/scoring.py
 create mode 100644 src/llama_stack_client/resources/scoring_functions.py
 create mode 100644 src/llama_stack_client/types/dataset_list_response.py
 create mode 100644 src/llama_stack_client/types/dataset_register_params.py
 create mode 100644 src/llama_stack_client/types/dataset_retrieve_response.py
 create mode 100644 src/llama_stack_client/types/datasetio_get_rows_paginated_params.py
 create mode 100644 src/llama_stack_client/types/eval/__init__.py
 create mode 100644 src/llama_stack_client/types/eval/job_cancel_params.py
 create mode 100644 src/llama_stack_client/types/eval/job_result_params.py
 create mode 100644 src/llama_stack_client/types/eval/job_result_response.py
 create mode 100644 src/llama_stack_client/types/eval/job_status.py
 create mode 100644 src/llama_stack_client/types/eval/job_status_params.py
 create mode 100644 src/llama_stack_client/types/eval_evaluate_batch_params.py
 create mode 100644 src/llama_stack_client/types/eval_evaluate_params.py
 create mode 100644 src/llama_stack_client/types/eval_evaluate_response.py
 create mode 100644 src/llama_stack_client/types/evaluate/evaluate_response.py
 create mode 100644 src/llama_stack_client/types/evaluate/job_result_params.py
 create mode 100644 src/llama_stack_client/types/evaluate/job_status.py
 create mode 100644 src/llama_stack_client/types/evaluate_evaluate_batch_params.py
 create mode 100644 src/llama_stack_client/types/evaluate_evaluate_params.py
 create mode 100644 src/llama_stack_client/types/job.py
 create mode 100644 src/llama_stack_client/types/model_list_response.py
 create mode 100644 src/llama_stack_client/types/model_retrieve_response.py
 create mode 100644 src/llama_stack_client/types/paginated_rows_result.py
 create mode 100644 src/llama_stack_client/types/score_batch_response.py
 create mode 100644 src/llama_stack_client/types/score_response.py
 create mode 100644 src/llama_stack_client/types/scoring_fn_def_with_provider.py
 create mode 100644 src/llama_stack_client/types/scoring_fn_def_with_provider_param.py
 create mode 100644 src/llama_stack_client/types/scoring_function_def_with_provider.py
 create mode 100644 src/llama_stack_client/types/scoring_function_def_with_provider_param.py
 create mode 100644 src/llama_stack_client/types/scoring_function_list_response.py
 create mode 100644 src/llama_stack_client/types/scoring_function_register_params.py
 create mode 100644 src/llama_stack_client/types/scoring_function_retrieve_params.py
 create mode 100644 src/llama_stack_client/types/scoring_function_retrieve_response.py
 create mode 100644 src/llama_stack_client/types/scoring_score_batch_params.py
 create mode 100644 src/llama_stack_client/types/scoring_score_params.py
 create mode 100644 src/llama_stack_client/types/shared/graph_memory_bank_def.py
 create mode 100644 src/llama_stack_client/types/shared/key_value_memory_bank_def.py
 create mode 100644 src/llama_stack_client/types/shared/keyword_memory_bank_def.py
 create mode 100644 src/llama_stack_client/types/shared/safety_violation.py
 create mode 100644 src/llama_stack_client/types/shared/vector_memory_bank_def.py
 create mode 100644 src/llama_stack_client/types/shared_params/graph_memory_bank_def.py
 create mode 100644 src/llama_stack_client/types/shared_params/key_value_memory_bank_def.py
 create mode 100644 src/llama_stack_client/types/shared_params/keyword_memory_bank_def.py
 create mode 100644
src/llama_stack_client/types/shared_params/vector_memory_bank_def.py create mode 100644 src/llama_stack_client/types/shield_list_response.py create mode 100644 src/llama_stack_client/types/shield_retrieve_response.py create mode 100644 src/llama_stack_client/types/tool_response.py create mode 100644 tests/api_resources/agents/test_session.py create mode 100644 tests/api_resources/agents/test_turn.py create mode 100644 tests/api_resources/eval/__init__.py create mode 100644 tests/api_resources/eval/test_job.py create mode 100644 tests/api_resources/post_training/test_job.py create mode 100644 tests/api_resources/test_batch_inferences.py create mode 100644 tests/api_resources/test_datasetio.py create mode 100644 tests/api_resources/test_eval.py create mode 100644 tests/api_resources/test_evaluate.py create mode 100644 tests/api_resources/test_inspect.py create mode 100644 tests/api_resources/test_providers.py create mode 100644 tests/api_resources/test_routes.py create mode 100644 tests/api_resources/test_scoring.py create mode 100644 tests/api_resources/test_scoring_functions.py diff --git a/src/llama_stack/lib/.keep b/src/llama_stack/lib/.keep new file mode 100644 index 00000000..5e2c99fd --- /dev/null +++ b/src/llama_stack/lib/.keep @@ -0,0 +1,4 @@ +File generated from our OpenAPI spec by Stainless. + +This directory can be used to store custom files to expand the SDK. +It is ignored by Stainless code generation and its content (other than this keep file) won't be touched. \ No newline at end of file diff --git a/src/llama_stack_client/_base_client.py b/src/llama_stack_client/_base_client.py index 9fa79520..9640b524 100644 --- a/src/llama_stack_client/_base_client.py +++ b/src/llama_stack_client/_base_client.py @@ -1575,7 +1575,7 @@ async def _request( except Exception as err: log.debug("Encountered Exception", exc_info=True) - if retries_taken > 0: + if remaining_retries > 0: return await self._retry_request( input_options, cast_to, diff --git a/src/llama_stack_client/_client.py b/src/llama_stack_client/_client.py index a0592e93..a8489b0f 100644 --- a/src/llama_stack_client/_client.py +++ b/src/llama_stack_client/_client.py @@ -48,22 +48,23 @@ class LlamaStackClient(SyncAPIClient): agents: resources.AgentsResource batch_inferences: resources.BatchInferencesResource - datasets: resources.DatasetsResource - evaluate: resources.EvaluateResource - evaluations: resources.EvaluationsResource inspect: resources.InspectResource inference: resources.InferenceResource memory: resources.MemoryResource memory_banks: resources.MemoryBanksResource + datasets: resources.DatasetsResource models: resources.ModelsResource post_training: resources.PostTrainingResource providers: resources.ProvidersResource - reward_scoring: resources.RewardScoringResource routes: resources.RoutesResource safety: resources.SafetyResource shields: resources.ShieldsResource synthetic_data_generation: resources.SyntheticDataGenerationResource telemetry: resources.TelemetryResource + datasetio: resources.DatasetioResource + scoring: resources.ScoringResource + scoring_functions: resources.ScoringFunctionsResource + eval: resources.EvalResource with_raw_response: LlamaStackClientWithRawResponse with_streaming_response: LlamaStackClientWithStreamedResponse @@ -110,22 +111,23 @@ def __init__( self.agents = resources.AgentsResource(self) self.batch_inferences = resources.BatchInferencesResource(self) - self.datasets = resources.DatasetsResource(self) - self.evaluate = resources.EvaluateResource(self) - self.evaluations = 
resources.EvaluationsResource(self) self.inspect = resources.InspectResource(self) self.inference = resources.InferenceResource(self) self.memory = resources.MemoryResource(self) self.memory_banks = resources.MemoryBanksResource(self) + self.datasets = resources.DatasetsResource(self) self.models = resources.ModelsResource(self) self.post_training = resources.PostTrainingResource(self) self.providers = resources.ProvidersResource(self) - self.reward_scoring = resources.RewardScoringResource(self) self.routes = resources.RoutesResource(self) self.safety = resources.SafetyResource(self) self.shields = resources.ShieldsResource(self) self.synthetic_data_generation = resources.SyntheticDataGenerationResource(self) self.telemetry = resources.TelemetryResource(self) + self.datasetio = resources.DatasetioResource(self) + self.scoring = resources.ScoringResource(self) + self.scoring_functions = resources.ScoringFunctionsResource(self) + self.eval = resources.EvalResource(self) self.with_raw_response = LlamaStackClientWithRawResponse(self) self.with_streaming_response = LlamaStackClientWithStreamedResponse(self) @@ -229,22 +231,23 @@ def _make_status_error( class AsyncLlamaStackClient(AsyncAPIClient): agents: resources.AsyncAgentsResource batch_inferences: resources.AsyncBatchInferencesResource - datasets: resources.AsyncDatasetsResource - evaluate: resources.AsyncEvaluateResource - evaluations: resources.AsyncEvaluationsResource inspect: resources.AsyncInspectResource inference: resources.AsyncInferenceResource memory: resources.AsyncMemoryResource memory_banks: resources.AsyncMemoryBanksResource + datasets: resources.AsyncDatasetsResource models: resources.AsyncModelsResource post_training: resources.AsyncPostTrainingResource providers: resources.AsyncProvidersResource - reward_scoring: resources.AsyncRewardScoringResource routes: resources.AsyncRoutesResource safety: resources.AsyncSafetyResource shields: resources.AsyncShieldsResource synthetic_data_generation: resources.AsyncSyntheticDataGenerationResource telemetry: resources.AsyncTelemetryResource + datasetio: resources.AsyncDatasetioResource + scoring: resources.AsyncScoringResource + scoring_functions: resources.AsyncScoringFunctionsResource + eval: resources.AsyncEvalResource with_raw_response: AsyncLlamaStackClientWithRawResponse with_streaming_response: AsyncLlamaStackClientWithStreamedResponse @@ -291,22 +294,23 @@ def __init__( self.agents = resources.AsyncAgentsResource(self) self.batch_inferences = resources.AsyncBatchInferencesResource(self) - self.datasets = resources.AsyncDatasetsResource(self) - self.evaluate = resources.AsyncEvaluateResource(self) - self.evaluations = resources.AsyncEvaluationsResource(self) self.inspect = resources.AsyncInspectResource(self) self.inference = resources.AsyncInferenceResource(self) self.memory = resources.AsyncMemoryResource(self) self.memory_banks = resources.AsyncMemoryBanksResource(self) + self.datasets = resources.AsyncDatasetsResource(self) self.models = resources.AsyncModelsResource(self) self.post_training = resources.AsyncPostTrainingResource(self) self.providers = resources.AsyncProvidersResource(self) - self.reward_scoring = resources.AsyncRewardScoringResource(self) self.routes = resources.AsyncRoutesResource(self) self.safety = resources.AsyncSafetyResource(self) self.shields = resources.AsyncShieldsResource(self) self.synthetic_data_generation = resources.AsyncSyntheticDataGenerationResource(self) self.telemetry = resources.AsyncTelemetryResource(self) + self.datasetio = 
resources.AsyncDatasetioResource(self) + self.scoring = resources.AsyncScoringResource(self) + self.scoring_functions = resources.AsyncScoringFunctionsResource(self) + self.eval = resources.AsyncEvalResource(self) self.with_raw_response = AsyncLlamaStackClientWithRawResponse(self) self.with_streaming_response = AsyncLlamaStackClientWithStreamedResponse(self) @@ -411,17 +415,14 @@ class LlamaStackClientWithRawResponse: def __init__(self, client: LlamaStackClient) -> None: self.agents = resources.AgentsResourceWithRawResponse(client.agents) self.batch_inferences = resources.BatchInferencesResourceWithRawResponse(client.batch_inferences) - self.datasets = resources.DatasetsResourceWithRawResponse(client.datasets) - self.evaluate = resources.EvaluateResourceWithRawResponse(client.evaluate) - self.evaluations = resources.EvaluationsResourceWithRawResponse(client.evaluations) self.inspect = resources.InspectResourceWithRawResponse(client.inspect) self.inference = resources.InferenceResourceWithRawResponse(client.inference) self.memory = resources.MemoryResourceWithRawResponse(client.memory) self.memory_banks = resources.MemoryBanksResourceWithRawResponse(client.memory_banks) + self.datasets = resources.DatasetsResourceWithRawResponse(client.datasets) self.models = resources.ModelsResourceWithRawResponse(client.models) self.post_training = resources.PostTrainingResourceWithRawResponse(client.post_training) self.providers = resources.ProvidersResourceWithRawResponse(client.providers) - self.reward_scoring = resources.RewardScoringResourceWithRawResponse(client.reward_scoring) self.routes = resources.RoutesResourceWithRawResponse(client.routes) self.safety = resources.SafetyResourceWithRawResponse(client.safety) self.shields = resources.ShieldsResourceWithRawResponse(client.shields) @@ -429,23 +430,24 @@ def __init__(self, client: LlamaStackClient) -> None: client.synthetic_data_generation ) self.telemetry = resources.TelemetryResourceWithRawResponse(client.telemetry) + self.datasetio = resources.DatasetioResourceWithRawResponse(client.datasetio) + self.scoring = resources.ScoringResourceWithRawResponse(client.scoring) + self.scoring_functions = resources.ScoringFunctionsResourceWithRawResponse(client.scoring_functions) + self.eval = resources.EvalResourceWithRawResponse(client.eval) class AsyncLlamaStackClientWithRawResponse: def __init__(self, client: AsyncLlamaStackClient) -> None: self.agents = resources.AsyncAgentsResourceWithRawResponse(client.agents) self.batch_inferences = resources.AsyncBatchInferencesResourceWithRawResponse(client.batch_inferences) - self.datasets = resources.AsyncDatasetsResourceWithRawResponse(client.datasets) - self.evaluate = resources.AsyncEvaluateResourceWithRawResponse(client.evaluate) - self.evaluations = resources.AsyncEvaluationsResourceWithRawResponse(client.evaluations) self.inspect = resources.AsyncInspectResourceWithRawResponse(client.inspect) self.inference = resources.AsyncInferenceResourceWithRawResponse(client.inference) self.memory = resources.AsyncMemoryResourceWithRawResponse(client.memory) self.memory_banks = resources.AsyncMemoryBanksResourceWithRawResponse(client.memory_banks) + self.datasets = resources.AsyncDatasetsResourceWithRawResponse(client.datasets) self.models = resources.AsyncModelsResourceWithRawResponse(client.models) self.post_training = resources.AsyncPostTrainingResourceWithRawResponse(client.post_training) self.providers = resources.AsyncProvidersResourceWithRawResponse(client.providers) - self.reward_scoring = 
resources.AsyncRewardScoringResourceWithRawResponse(client.reward_scoring) self.routes = resources.AsyncRoutesResourceWithRawResponse(client.routes) self.safety = resources.AsyncSafetyResourceWithRawResponse(client.safety) self.shields = resources.AsyncShieldsResourceWithRawResponse(client.shields) @@ -453,23 +455,24 @@ def __init__(self, client: AsyncLlamaStackClient) -> None: client.synthetic_data_generation ) self.telemetry = resources.AsyncTelemetryResourceWithRawResponse(client.telemetry) + self.datasetio = resources.AsyncDatasetioResourceWithRawResponse(client.datasetio) + self.scoring = resources.AsyncScoringResourceWithRawResponse(client.scoring) + self.scoring_functions = resources.AsyncScoringFunctionsResourceWithRawResponse(client.scoring_functions) + self.eval = resources.AsyncEvalResourceWithRawResponse(client.eval) class LlamaStackClientWithStreamedResponse: def __init__(self, client: LlamaStackClient) -> None: self.agents = resources.AgentsResourceWithStreamingResponse(client.agents) self.batch_inferences = resources.BatchInferencesResourceWithStreamingResponse(client.batch_inferences) - self.datasets = resources.DatasetsResourceWithStreamingResponse(client.datasets) - self.evaluate = resources.EvaluateResourceWithStreamingResponse(client.evaluate) - self.evaluations = resources.EvaluationsResourceWithStreamingResponse(client.evaluations) self.inspect = resources.InspectResourceWithStreamingResponse(client.inspect) self.inference = resources.InferenceResourceWithStreamingResponse(client.inference) self.memory = resources.MemoryResourceWithStreamingResponse(client.memory) self.memory_banks = resources.MemoryBanksResourceWithStreamingResponse(client.memory_banks) + self.datasets = resources.DatasetsResourceWithStreamingResponse(client.datasets) self.models = resources.ModelsResourceWithStreamingResponse(client.models) self.post_training = resources.PostTrainingResourceWithStreamingResponse(client.post_training) self.providers = resources.ProvidersResourceWithStreamingResponse(client.providers) - self.reward_scoring = resources.RewardScoringResourceWithStreamingResponse(client.reward_scoring) self.routes = resources.RoutesResourceWithStreamingResponse(client.routes) self.safety = resources.SafetyResourceWithStreamingResponse(client.safety) self.shields = resources.ShieldsResourceWithStreamingResponse(client.shields) @@ -477,23 +480,24 @@ def __init__(self, client: LlamaStackClient) -> None: client.synthetic_data_generation ) self.telemetry = resources.TelemetryResourceWithStreamingResponse(client.telemetry) + self.datasetio = resources.DatasetioResourceWithStreamingResponse(client.datasetio) + self.scoring = resources.ScoringResourceWithStreamingResponse(client.scoring) + self.scoring_functions = resources.ScoringFunctionsResourceWithStreamingResponse(client.scoring_functions) + self.eval = resources.EvalResourceWithStreamingResponse(client.eval) class AsyncLlamaStackClientWithStreamedResponse: def __init__(self, client: AsyncLlamaStackClient) -> None: self.agents = resources.AsyncAgentsResourceWithStreamingResponse(client.agents) self.batch_inferences = resources.AsyncBatchInferencesResourceWithStreamingResponse(client.batch_inferences) - self.datasets = resources.AsyncDatasetsResourceWithStreamingResponse(client.datasets) - self.evaluate = resources.AsyncEvaluateResourceWithStreamingResponse(client.evaluate) - self.evaluations = resources.AsyncEvaluationsResourceWithStreamingResponse(client.evaluations) self.inspect = 
resources.AsyncInspectResourceWithStreamingResponse(client.inspect) self.inference = resources.AsyncInferenceResourceWithStreamingResponse(client.inference) self.memory = resources.AsyncMemoryResourceWithStreamingResponse(client.memory) self.memory_banks = resources.AsyncMemoryBanksResourceWithStreamingResponse(client.memory_banks) + self.datasets = resources.AsyncDatasetsResourceWithStreamingResponse(client.datasets) self.models = resources.AsyncModelsResourceWithStreamingResponse(client.models) self.post_training = resources.AsyncPostTrainingResourceWithStreamingResponse(client.post_training) self.providers = resources.AsyncProvidersResourceWithStreamingResponse(client.providers) - self.reward_scoring = resources.AsyncRewardScoringResourceWithStreamingResponse(client.reward_scoring) self.routes = resources.AsyncRoutesResourceWithStreamingResponse(client.routes) self.safety = resources.AsyncSafetyResourceWithStreamingResponse(client.safety) self.shields = resources.AsyncShieldsResourceWithStreamingResponse(client.shields) @@ -501,6 +505,10 @@ def __init__(self, client: AsyncLlamaStackClient) -> None: client.synthetic_data_generation ) self.telemetry = resources.AsyncTelemetryResourceWithStreamingResponse(client.telemetry) + self.datasetio = resources.AsyncDatasetioResourceWithStreamingResponse(client.datasetio) + self.scoring = resources.AsyncScoringResourceWithStreamingResponse(client.scoring) + self.scoring_functions = resources.AsyncScoringFunctionsResourceWithStreamingResponse(client.scoring_functions) + self.eval = resources.AsyncEvalResourceWithStreamingResponse(client.eval) Client = LlamaStackClient diff --git a/src/llama_stack_client/_compat.py b/src/llama_stack_client/_compat.py index 162a6fbe..d89920d9 100644 --- a/src/llama_stack_client/_compat.py +++ b/src/llama_stack_client/_compat.py @@ -133,7 +133,7 @@ def model_json(model: pydantic.BaseModel, *, indent: int | None = None) -> str: def model_dump( model: pydantic.BaseModel, *, - exclude: IncEx = None, + exclude: IncEx | None = None, exclude_unset: bool = False, exclude_defaults: bool = False, warnings: bool = True, diff --git a/src/llama_stack_client/_models.py b/src/llama_stack_client/_models.py index d386eaa3..42551b76 100644 --- a/src/llama_stack_client/_models.py +++ b/src/llama_stack_client/_models.py @@ -176,7 +176,7 @@ def __str__(self) -> str: # Based on https://github.com/samuelcolvin/pydantic/issues/1168#issuecomment-817742836. 
@classmethod @override - def construct( + def construct( # pyright: ignore[reportIncompatibleMethodOverride] cls: Type[ModelT], _fields_set: set[str] | None = None, **values: object, @@ -248,8 +248,8 @@ def model_dump( self, *, mode: Literal["json", "python"] | str = "python", - include: IncEx = None, - exclude: IncEx = None, + include: IncEx | None = None, + exclude: IncEx | None = None, by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, @@ -303,8 +303,8 @@ def model_dump_json( self, *, indent: int | None = None, - include: IncEx = None, - exclude: IncEx = None, + include: IncEx | None = None, + exclude: IncEx | None = None, by_alias: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, diff --git a/src/llama_stack_client/_types.py b/src/llama_stack_client/_types.py index 8294dfce..a20fdd0e 100644 --- a/src/llama_stack_client/_types.py +++ b/src/llama_stack_client/_types.py @@ -16,7 +16,7 @@ Optional, Sequence, ) -from typing_extensions import Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable +from typing_extensions import Set, Literal, Protocol, TypeAlias, TypedDict, override, runtime_checkable import httpx import pydantic @@ -193,7 +193,9 @@ def get(self, __key: str) -> str | None: ... # Note: copied from Pydantic # https://github.com/pydantic/pydantic/blob/32ea570bf96e84234d2992e1ddf40ab8a565925a/pydantic/main.py#L49 -IncEx: TypeAlias = "set[int] | set[str] | dict[int, Any] | dict[str, Any] | None" +IncEx: TypeAlias = Union[ + Set[int], Set[str], Mapping[int, Union["IncEx", Literal[True]]], Mapping[str, Union["IncEx", Literal[True]]] +] PostParser = Callable[[Any], Any] diff --git a/src/llama_stack_client/resources/__init__.py b/src/llama_stack_client/resources/__init__.py index 59144a3e..f5cdbbeb 100644 --- a/src/llama_stack_client/resources/__init__.py +++ b/src/llama_stack_client/resources/__init__.py @@ -1,5 +1,13 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+from .eval import ( + EvalResource, + AsyncEvalResource, + EvalResourceWithRawResponse, + AsyncEvalResourceWithRawResponse, + EvalResourceWithStreamingResponse, + AsyncEvalResourceWithStreamingResponse, +) from .agents import ( AgentsResource, AsyncAgentsResource, @@ -48,6 +56,14 @@ InspectResourceWithStreamingResponse, AsyncInspectResourceWithStreamingResponse, ) +from .scoring import ( + ScoringResource, + AsyncScoringResource, + ScoringResourceWithRawResponse, + AsyncScoringResourceWithRawResponse, + ScoringResourceWithStreamingResponse, + AsyncScoringResourceWithStreamingResponse, +) from .shields import ( ShieldsResource, AsyncShieldsResource, @@ -64,13 +80,13 @@ DatasetsResourceWithStreamingResponse, AsyncDatasetsResourceWithStreamingResponse, ) -from .evaluate import ( - EvaluateResource, - AsyncEvaluateResource, - EvaluateResourceWithRawResponse, - AsyncEvaluateResourceWithRawResponse, - EvaluateResourceWithStreamingResponse, - AsyncEvaluateResourceWithStreamingResponse, +from .datasetio import ( + DatasetioResource, + AsyncDatasetioResource, + DatasetioResourceWithRawResponse, + AsyncDatasetioResourceWithRawResponse, + DatasetioResourceWithStreamingResponse, + AsyncDatasetioResourceWithStreamingResponse, ) from .inference import ( InferenceResource, @@ -96,14 +112,6 @@ TelemetryResourceWithStreamingResponse, AsyncTelemetryResourceWithStreamingResponse, ) -from .evaluations import ( - EvaluationsResource, - AsyncEvaluationsResource, - EvaluationsResourceWithRawResponse, - AsyncEvaluationsResourceWithRawResponse, - EvaluationsResourceWithStreamingResponse, - AsyncEvaluationsResourceWithStreamingResponse, -) from .memory_banks import ( MemoryBanksResource, AsyncMemoryBanksResource, @@ -120,14 +128,6 @@ PostTrainingResourceWithStreamingResponse, AsyncPostTrainingResourceWithStreamingResponse, ) -from .reward_scoring import ( - RewardScoringResource, - AsyncRewardScoringResource, - RewardScoringResourceWithRawResponse, - AsyncRewardScoringResourceWithRawResponse, - RewardScoringResourceWithStreamingResponse, - AsyncRewardScoringResourceWithStreamingResponse, -) from .batch_inferences import ( BatchInferencesResource, AsyncBatchInferencesResource, @@ -136,6 +136,14 @@ BatchInferencesResourceWithStreamingResponse, AsyncBatchInferencesResourceWithStreamingResponse, ) +from .scoring_functions import ( + ScoringFunctionsResource, + AsyncScoringFunctionsResource, + ScoringFunctionsResourceWithRawResponse, + AsyncScoringFunctionsResourceWithRawResponse, + ScoringFunctionsResourceWithStreamingResponse, + AsyncScoringFunctionsResourceWithStreamingResponse, +) from .synthetic_data_generation import ( SyntheticDataGenerationResource, AsyncSyntheticDataGenerationResource, @@ -158,24 +166,6 @@ "AsyncBatchInferencesResourceWithRawResponse", "BatchInferencesResourceWithStreamingResponse", "AsyncBatchInferencesResourceWithStreamingResponse", - "DatasetsResource", - "AsyncDatasetsResource", - "DatasetsResourceWithRawResponse", - "AsyncDatasetsResourceWithRawResponse", - "DatasetsResourceWithStreamingResponse", - "AsyncDatasetsResourceWithStreamingResponse", - "EvaluateResource", - "AsyncEvaluateResource", - "EvaluateResourceWithRawResponse", - "AsyncEvaluateResourceWithRawResponse", - "EvaluateResourceWithStreamingResponse", - "AsyncEvaluateResourceWithStreamingResponse", - "EvaluationsResource", - "AsyncEvaluationsResource", - "EvaluationsResourceWithRawResponse", - "AsyncEvaluationsResourceWithRawResponse", - "EvaluationsResourceWithStreamingResponse", - 
"AsyncEvaluationsResourceWithStreamingResponse", "InspectResource", "AsyncInspectResource", "InspectResourceWithRawResponse", @@ -200,6 +190,12 @@ "AsyncMemoryBanksResourceWithRawResponse", "MemoryBanksResourceWithStreamingResponse", "AsyncMemoryBanksResourceWithStreamingResponse", + "DatasetsResource", + "AsyncDatasetsResource", + "DatasetsResourceWithRawResponse", + "AsyncDatasetsResourceWithRawResponse", + "DatasetsResourceWithStreamingResponse", + "AsyncDatasetsResourceWithStreamingResponse", "ModelsResource", "AsyncModelsResource", "ModelsResourceWithRawResponse", @@ -218,12 +214,6 @@ "AsyncProvidersResourceWithRawResponse", "ProvidersResourceWithStreamingResponse", "AsyncProvidersResourceWithStreamingResponse", - "RewardScoringResource", - "AsyncRewardScoringResource", - "RewardScoringResourceWithRawResponse", - "AsyncRewardScoringResourceWithRawResponse", - "RewardScoringResourceWithStreamingResponse", - "AsyncRewardScoringResourceWithStreamingResponse", "RoutesResource", "AsyncRoutesResource", "RoutesResourceWithRawResponse", @@ -254,4 +244,28 @@ "AsyncTelemetryResourceWithRawResponse", "TelemetryResourceWithStreamingResponse", "AsyncTelemetryResourceWithStreamingResponse", + "DatasetioResource", + "AsyncDatasetioResource", + "DatasetioResourceWithRawResponse", + "AsyncDatasetioResourceWithRawResponse", + "DatasetioResourceWithStreamingResponse", + "AsyncDatasetioResourceWithStreamingResponse", + "ScoringResource", + "AsyncScoringResource", + "ScoringResourceWithRawResponse", + "AsyncScoringResourceWithRawResponse", + "ScoringResourceWithStreamingResponse", + "AsyncScoringResourceWithStreamingResponse", + "ScoringFunctionsResource", + "AsyncScoringFunctionsResource", + "ScoringFunctionsResourceWithRawResponse", + "AsyncScoringFunctionsResourceWithRawResponse", + "ScoringFunctionsResourceWithStreamingResponse", + "AsyncScoringFunctionsResourceWithStreamingResponse", + "EvalResource", + "AsyncEvalResource", + "EvalResourceWithRawResponse", + "AsyncEvalResourceWithRawResponse", + "EvalResourceWithStreamingResponse", + "AsyncEvalResourceWithStreamingResponse", ] diff --git a/src/llama_stack_client/resources/agents/steps.py b/src/llama_stack_client/resources/agents/steps.py index ab038157..80f5db43 100644 --- a/src/llama_stack_client/resources/agents/steps.py +++ b/src/llama_stack_client/resources/agents/steps.py @@ -49,6 +49,7 @@ def retrieve( self, *, agent_id: str, + session_id: str, step_id: str, turn_id: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, @@ -83,6 +84,7 @@ def retrieve( query=maybe_transform( { "agent_id": agent_id, + "session_id": session_id, "step_id": step_id, "turn_id": turn_id, }, @@ -117,6 +119,7 @@ async def retrieve( self, *, agent_id: str, + session_id: str, step_id: str, turn_id: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, @@ -151,6 +154,7 @@ async def retrieve( query=await async_maybe_transform( { "agent_id": agent_id, + "session_id": session_id, "step_id": step_id, "turn_id": turn_id, }, diff --git a/src/llama_stack_client/resources/agents/turn.py b/src/llama_stack_client/resources/agents/turn.py index 44a04b6e..f63d356b 100644 --- a/src/llama_stack_client/resources/agents/turn.py +++ b/src/llama_stack_client/resources/agents/turn.py @@ -185,6 +185,7 @@ def retrieve( self, *, agent_id: str, + session_id: str, turn_id: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 
@@ -218,6 +219,7 @@ def retrieve( query=maybe_transform( { "agent_id": agent_id, + "session_id": session_id, "turn_id": turn_id, }, turn_retrieve_params.TurnRetrieveParams, @@ -380,6 +382,7 @@ async def retrieve( self, *, agent_id: str, + session_id: str, turn_id: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. @@ -413,6 +416,7 @@ async def retrieve( query=await async_maybe_transform( { "agent_id": agent_id, + "session_id": session_id, "turn_id": turn_id, }, turn_retrieve_params.TurnRetrieveParams, diff --git a/src/llama_stack_client/resources/datasetio.py b/src/llama_stack_client/resources/datasetio.py new file mode 100644 index 00000000..92dafdbf --- /dev/null +++ b/src/llama_stack_client/resources/datasetio.py @@ -0,0 +1,201 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import httpx + +from ..types import datasetio_get_rows_paginated_params +from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._utils import ( + maybe_transform, + strip_not_given, + async_maybe_transform, +) +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from .._base_client import make_request_options +from ..types.paginated_rows_result import PaginatedRowsResult + +__all__ = ["DatasetioResource", "AsyncDatasetioResource"] + + +class DatasetioResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> DatasetioResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers + """ + return DatasetioResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> DatasetioResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return DatasetioResourceWithStreamingResponse(self) + + def get_rows_paginated( + self, + *, + dataset_id: str, + rows_in_page: int, + filter_condition: str | NotGiven = NOT_GIVEN, + page_token: str | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> PaginatedRowsResult: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._get( + "/datasetio/get_rows_paginated", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + { + "dataset_id": dataset_id, + "rows_in_page": rows_in_page, + "filter_condition": filter_condition, + "page_token": page_token, + }, + datasetio_get_rows_paginated_params.DatasetioGetRowsPaginatedParams, + ), + ), + cast_to=PaginatedRowsResult, + ) + + +class AsyncDatasetioResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncDatasetioResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers + """ + return AsyncDatasetioResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncDatasetioResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return AsyncDatasetioResourceWithStreamingResponse(self) + + async def get_rows_paginated( + self, + *, + dataset_id: str, + rows_in_page: int, + filter_condition: str | NotGiven = NOT_GIVEN, + page_token: str | NotGiven = NOT_GIVEN, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> PaginatedRowsResult: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._get( + "/datasetio/get_rows_paginated", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform( + { + "dataset_id": dataset_id, + "rows_in_page": rows_in_page, + "filter_condition": filter_condition, + "page_token": page_token, + }, + datasetio_get_rows_paginated_params.DatasetioGetRowsPaginatedParams, + ), + ), + cast_to=PaginatedRowsResult, + ) + + +class DatasetioResourceWithRawResponse: + def __init__(self, datasetio: DatasetioResource) -> None: + self._datasetio = datasetio + + self.get_rows_paginated = to_raw_response_wrapper( + datasetio.get_rows_paginated, + ) + + +class AsyncDatasetioResourceWithRawResponse: + def __init__(self, datasetio: AsyncDatasetioResource) -> None: + self._datasetio = datasetio + + self.get_rows_paginated = async_to_raw_response_wrapper( + datasetio.get_rows_paginated, + ) + + +class DatasetioResourceWithStreamingResponse: + def __init__(self, datasetio: DatasetioResource) -> None: + self._datasetio = datasetio + + self.get_rows_paginated = to_streamed_response_wrapper( + datasetio.get_rows_paginated, + ) + + +class AsyncDatasetioResourceWithStreamingResponse: + def __init__(self, datasetio: AsyncDatasetioResource) -> None: + self._datasetio = datasetio + + self.get_rows_paginated = async_to_streamed_response_wrapper( + datasetio.get_rows_paginated, + ) diff --git a/src/llama_stack_client/resources/datasets.py b/src/llama_stack_client/resources/datasets.py index 0373e92c..c3927e7f 100644 --- a/src/llama_stack_client/resources/datasets.py +++ b/src/llama_stack_client/resources/datasets.py @@ -2,14 +2,11 @@ from __future__ import annotations +from typing import Optional + import httpx -from ..types import ( - TrainEvalDataset, - dataset_create_params, - dataset_delete_params, - dataset_retrieve_params, -) +from ..types import dataset_register_params, dataset_retrieve_params from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven from .._utils import ( maybe_transform, @@ -25,8 +22,8 @@ async_to_streamed_response_wrapper, ) from .._base_client import make_request_options -from ..types.train_eval_dataset import TrainEvalDataset -from ..types.train_eval_dataset_param import TrainEvalDatasetParam +from ..types.dataset_list_response import DatasetListResponse +from ..types.dataset_retrieve_response import DatasetRetrieveResponse __all__ = ["DatasetsResource", "AsyncDatasetsResource"] @@ -51,11 +48,10 @@ def with_streaming_response(self) -> DatasetsResourceWithStreamingResponse: """ return DatasetsResourceWithStreamingResponse(self) - def create( + def retrieve( self, *, - dataset: TrainEvalDatasetParam, - uuid: str, + dataset_identifier: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. 
# The extra values given here take precedence over values defined on the client or passed to this method. @@ -63,7 +59,7 @@ def create( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> Optional[DatasetRetrieveResponse]: """ Args: extra_headers: Send extra headers @@ -74,30 +70,27 @@ def create( timeout: Override the client-level default timeout for this request, in seconds """ - extra_headers = {"Accept": "*/*", **(extra_headers or {})} extra_headers = { **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), **(extra_headers or {}), } - return self._post( - "/datasets/create", - body=maybe_transform( - { - "dataset": dataset, - "uuid": uuid, - }, - dataset_create_params.DatasetCreateParams, - ), + return self._get( + "/datasets/get", options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform( + {"dataset_identifier": dataset_identifier}, dataset_retrieve_params.DatasetRetrieveParams + ), ), - cast_to=NoneType, + cast_to=DatasetRetrieveResponse, ) - def retrieve( + def list( self, *, - dataset_uuid: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -105,7 +98,7 @@ def retrieve( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> TrainEvalDataset: + ) -> DatasetListResponse: """ Args: extra_headers: Send extra headers @@ -116,26 +109,23 @@ def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} extra_headers = { **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), **(extra_headers or {}), } return self._get( - "/datasets/get", + "/datasets/list", options=make_request_options( - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - query=maybe_transform({"dataset_uuid": dataset_uuid}, dataset_retrieve_params.DatasetRetrieveParams), + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=TrainEvalDataset, + cast_to=DatasetListResponse, ) - def delete( + def register( self, *, - dataset_uuid: str, + dataset_def: dataset_register_params.DatasetDef, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
@@ -160,8 +150,8 @@ def delete( **(extra_headers or {}), } return self._post( - "/datasets/delete", - body=maybe_transform({"dataset_uuid": dataset_uuid}, dataset_delete_params.DatasetDeleteParams), + "/datasets/register", + body=maybe_transform({"dataset_def": dataset_def}, dataset_register_params.DatasetRegisterParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -189,11 +179,10 @@ def with_streaming_response(self) -> AsyncDatasetsResourceWithStreamingResponse: """ return AsyncDatasetsResourceWithStreamingResponse(self) - async def create( + async def retrieve( self, *, - dataset: TrainEvalDatasetParam, - uuid: str, + dataset_identifier: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -201,7 +190,7 @@ async def create( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> None: + ) -> Optional[DatasetRetrieveResponse]: """ Args: extra_headers: Send extra headers @@ -212,30 +201,27 @@ async def create( timeout: Override the client-level default timeout for this request, in seconds """ - extra_headers = {"Accept": "*/*", **(extra_headers or {})} extra_headers = { **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), **(extra_headers or {}), } - return await self._post( - "/datasets/create", - body=await async_maybe_transform( - { - "dataset": dataset, - "uuid": uuid, - }, - dataset_create_params.DatasetCreateParams, - ), + return await self._get( + "/datasets/get", options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform( + {"dataset_identifier": dataset_identifier}, dataset_retrieve_params.DatasetRetrieveParams + ), ), - cast_to=NoneType, + cast_to=DatasetRetrieveResponse, ) - async def retrieve( + async def list( self, *, - dataset_uuid: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
@@ -243,7 +229,7 @@ async def retrieve( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> TrainEvalDataset: + ) -> DatasetListResponse: """ Args: extra_headers: Send extra headers @@ -254,28 +240,23 @@ async def retrieve( timeout: Override the client-level default timeout for this request, in seconds """ + extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} extra_headers = { **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), **(extra_headers or {}), } return await self._get( - "/datasets/get", + "/datasets/list", options=make_request_options( - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - query=await async_maybe_transform( - {"dataset_uuid": dataset_uuid}, dataset_retrieve_params.DatasetRetrieveParams - ), + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=TrainEvalDataset, + cast_to=DatasetListResponse, ) - async def delete( + async def register( self, *, - dataset_uuid: str, + dataset_def: dataset_register_params.DatasetDef, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -300,8 +281,10 @@ async def delete( **(extra_headers or {}), } return await self._post( - "/datasets/delete", - body=await async_maybe_transform({"dataset_uuid": dataset_uuid}, dataset_delete_params.DatasetDeleteParams), + "/datasets/register", + body=await async_maybe_transform( + {"dataset_def": dataset_def}, dataset_register_params.DatasetRegisterParams + ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), @@ -313,14 +296,14 @@ class DatasetsResourceWithRawResponse: def __init__(self, datasets: DatasetsResource) -> None: self._datasets = datasets - self.create = to_raw_response_wrapper( - datasets.create, - ) self.retrieve = to_raw_response_wrapper( datasets.retrieve, ) - self.delete = to_raw_response_wrapper( - datasets.delete, + self.list = to_raw_response_wrapper( + datasets.list, + ) + self.register = to_raw_response_wrapper( + datasets.register, ) @@ -328,14 +311,14 @@ class AsyncDatasetsResourceWithRawResponse: def __init__(self, datasets: AsyncDatasetsResource) -> None: self._datasets = datasets - self.create = async_to_raw_response_wrapper( - datasets.create, - ) self.retrieve = async_to_raw_response_wrapper( datasets.retrieve, ) - self.delete = async_to_raw_response_wrapper( - datasets.delete, + self.list = async_to_raw_response_wrapper( + datasets.list, + ) + self.register = async_to_raw_response_wrapper( + datasets.register, ) @@ -343,14 +326,14 @@ class DatasetsResourceWithStreamingResponse: def __init__(self, datasets: DatasetsResource) -> None: self._datasets = datasets - self.create = to_streamed_response_wrapper( - datasets.create, - ) self.retrieve = to_streamed_response_wrapper( datasets.retrieve, ) - self.delete = to_streamed_response_wrapper( - datasets.delete, + self.list = to_streamed_response_wrapper( + datasets.list, + ) + self.register = to_streamed_response_wrapper( + datasets.register, ) @@ -358,12 +341,12 @@ class AsyncDatasetsResourceWithStreamingResponse: def __init__(self, datasets: AsyncDatasetsResource) -> None: self._datasets = 
datasets - self.create = async_to_streamed_response_wrapper( - datasets.create, - ) self.retrieve = async_to_streamed_response_wrapper( datasets.retrieve, ) - self.delete = async_to_streamed_response_wrapper( - datasets.delete, + self.list = async_to_streamed_response_wrapper( + datasets.list, + ) + self.register = async_to_streamed_response_wrapper( + datasets.register, ) diff --git a/src/llama_stack_client/resources/eval/__init__.py b/src/llama_stack_client/resources/eval/__init__.py new file mode 100644 index 00000000..771fdbb4 --- /dev/null +++ b/src/llama_stack_client/resources/eval/__init__.py @@ -0,0 +1,33 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from .job import ( + JobResource, + AsyncJobResource, + JobResourceWithRawResponse, + AsyncJobResourceWithRawResponse, + JobResourceWithStreamingResponse, + AsyncJobResourceWithStreamingResponse, +) +from .eval import ( + EvalResource, + AsyncEvalResource, + EvalResourceWithRawResponse, + AsyncEvalResourceWithRawResponse, + EvalResourceWithStreamingResponse, + AsyncEvalResourceWithStreamingResponse, +) + +__all__ = [ + "JobResource", + "AsyncJobResource", + "JobResourceWithRawResponse", + "AsyncJobResourceWithRawResponse", + "JobResourceWithStreamingResponse", + "AsyncJobResourceWithStreamingResponse", + "EvalResource", + "AsyncEvalResource", + "EvalResourceWithRawResponse", + "AsyncEvalResourceWithRawResponse", + "EvalResourceWithStreamingResponse", + "AsyncEvalResourceWithStreamingResponse", +] diff --git a/src/llama_stack_client/resources/eval/eval.py b/src/llama_stack_client/resources/eval/eval.py new file mode 100644 index 00000000..1e4ff7eb --- /dev/null +++ b/src/llama_stack_client/resources/eval/eval.py @@ -0,0 +1,326 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Dict, List, Union, Iterable + +import httpx + +from .job import ( + JobResource, + AsyncJobResource, + JobResourceWithRawResponse, + AsyncJobResourceWithRawResponse, + JobResourceWithStreamingResponse, + AsyncJobResourceWithStreamingResponse, +) +from ...types import eval_evaluate_params, eval_evaluate_batch_params +from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from ..._utils import ( + maybe_transform, + strip_not_given, + async_maybe_transform, +) +from ..._compat import cached_property +from ..._resource import SyncAPIResource, AsyncAPIResource +from ..._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from ...types.job import Job +from ..._base_client import make_request_options +from ...types.eval_evaluate_response import EvalEvaluateResponse + +__all__ = ["EvalResource", "AsyncEvalResource"] + + +class EvalResource(SyncAPIResource): + @cached_property + def job(self) -> JobResource: + return JobResource(self._client) + + @cached_property + def with_raw_response(self) -> EvalResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers + """ + return EvalResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> EvalResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. 
+ + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return EvalResourceWithStreamingResponse(self) + + def evaluate( + self, + *, + candidate: eval_evaluate_params.Candidate, + input_rows: Iterable[Dict[str, Union[bool, float, str, Iterable[object], object, None]]], + scoring_functions: List[str], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> EvalEvaluateResponse: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._post( + "/eval/evaluate", + body=maybe_transform( + { + "candidate": candidate, + "input_rows": input_rows, + "scoring_functions": scoring_functions, + }, + eval_evaluate_params.EvalEvaluateParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=EvalEvaluateResponse, + ) + + def evaluate_batch( + self, + *, + candidate: eval_evaluate_batch_params.Candidate, + dataset_id: str, + scoring_functions: List[str], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Job: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._post( + "/eval/evaluate_batch", + body=maybe_transform( + { + "candidate": candidate, + "dataset_id": dataset_id, + "scoring_functions": scoring_functions, + }, + eval_evaluate_batch_params.EvalEvaluateBatchParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Job, + ) + + +class AsyncEvalResource(AsyncAPIResource): + @cached_property + def job(self) -> AsyncJobResource: + return AsyncJobResource(self._client) + + @cached_property + def with_raw_response(self) -> AsyncEvalResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. 
+ + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers + """ + return AsyncEvalResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncEvalResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return AsyncEvalResourceWithStreamingResponse(self) + + async def evaluate( + self, + *, + candidate: eval_evaluate_params.Candidate, + input_rows: Iterable[Dict[str, Union[bool, float, str, Iterable[object], object, None]]], + scoring_functions: List[str], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> EvalEvaluateResponse: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._post( + "/eval/evaluate", + body=await async_maybe_transform( + { + "candidate": candidate, + "input_rows": input_rows, + "scoring_functions": scoring_functions, + }, + eval_evaluate_params.EvalEvaluateParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=EvalEvaluateResponse, + ) + + async def evaluate_batch( + self, + *, + candidate: eval_evaluate_batch_params.Candidate, + dataset_id: str, + scoring_functions: List[str], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Job: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._post( + "/eval/evaluate_batch", + body=await async_maybe_transform( + { + "candidate": candidate, + "dataset_id": dataset_id, + "scoring_functions": scoring_functions, + }, + eval_evaluate_batch_params.EvalEvaluateBatchParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=Job, + ) + + +class EvalResourceWithRawResponse: + def __init__(self, eval: EvalResource) -> None: + self._eval = eval + + self.evaluate = to_raw_response_wrapper( + eval.evaluate, + ) + self.evaluate_batch = to_raw_response_wrapper( + eval.evaluate_batch, + ) + + @cached_property + def job(self) -> JobResourceWithRawResponse: + return JobResourceWithRawResponse(self._eval.job) + + +class AsyncEvalResourceWithRawResponse: + def __init__(self, eval: AsyncEvalResource) -> None: + self._eval = eval + + self.evaluate = async_to_raw_response_wrapper( + eval.evaluate, + ) + self.evaluate_batch = async_to_raw_response_wrapper( + eval.evaluate_batch, + ) + + @cached_property + def job(self) -> AsyncJobResourceWithRawResponse: + return AsyncJobResourceWithRawResponse(self._eval.job) + + +class EvalResourceWithStreamingResponse: + def __init__(self, eval: EvalResource) -> None: + self._eval = eval + + self.evaluate = to_streamed_response_wrapper( + eval.evaluate, + ) + self.evaluate_batch = to_streamed_response_wrapper( + eval.evaluate_batch, + ) + + @cached_property + def job(self) -> JobResourceWithStreamingResponse: + return JobResourceWithStreamingResponse(self._eval.job) + + +class AsyncEvalResourceWithStreamingResponse: + def __init__(self, eval: AsyncEvalResource) -> None: + self._eval = eval + + self.evaluate = async_to_streamed_response_wrapper( + eval.evaluate, + ) + self.evaluate_batch = async_to_streamed_response_wrapper( + eval.evaluate_batch, + ) + + @cached_property + def job(self) -> AsyncJobResourceWithStreamingResponse: + return AsyncJobResourceWithStreamingResponse(self._eval.job) diff --git a/src/llama_stack_client/resources/eval/job.py b/src/llama_stack_client/resources/eval/job.py new file mode 100644 index 00000000..4824db87 --- /dev/null +++ b/src/llama_stack_client/resources/eval/job.py @@ -0,0 +1,354 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
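+#
+# A hedged usage sketch for the job endpoints defined below (illustrative only, not part
+# of the generated surface). It assumes a configured `LlamaStackClient` bound to `client`,
+# that the eval resources are exposed as `client.eval` / `client.eval.job`, that the
+# returned `Job` model carries a `job_id` field, and that `candidate` was built per
+# `eval_evaluate_batch_params.Candidate`; the dataset and scoring-function ids are placeholders:
+#
+#     import time
+#
+#     eval_job = client.eval.evaluate_batch(
+#         candidate=candidate,
+#         dataset_id="my-eval-dataset",
+#         scoring_functions=["meta-reference::equality"],
+#     )
+#     while client.eval.job.status(job_id=eval_job.job_id) == "in_progress":  # literal assumed from JobStatus
+#         time.sleep(5)
+#     result = client.eval.job.result(job_id=eval_job.job_id)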
+ +from __future__ import annotations + +from typing import Optional + +import httpx + +from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven +from ..._utils import ( + maybe_transform, + strip_not_given, + async_maybe_transform, +) +from ..._compat import cached_property +from ..._resource import SyncAPIResource, AsyncAPIResource +from ..._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from ...types.eval import job_cancel_params, job_result_params, job_status_params +from ..._base_client import make_request_options +from ...types.eval.job_status import JobStatus +from ...types.eval.job_result_response import JobResultResponse + +__all__ = ["JobResource", "AsyncJobResource"] + + +class JobResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> JobResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers + """ + return JobResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> JobResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return JobResourceWithStreamingResponse(self) + + def cancel( + self, + *, + job_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> None: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._post( + "/eval/job/cancel", + body=maybe_transform({"job_id": job_id}, job_cancel_params.JobCancelParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NoneType, + ) + + def result( + self, + *, + job_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> JobResultResponse: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._get( + "/eval/job/result", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform({"job_id": job_id}, job_result_params.JobResultParams), + ), + cast_to=JobResultResponse, + ) + + def status( + self, + *, + job_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Optional[JobStatus]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._get( + "/eval/job/status", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform({"job_id": job_id}, job_status_params.JobStatusParams), + ), + cast_to=JobStatus, + ) + + +class AsyncJobResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncJobResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers + """ + return AsyncJobResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncJobResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return AsyncJobResourceWithStreamingResponse(self) + + async def cancel( + self, + *, + job_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> None: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._post( + "/eval/job/cancel", + body=await async_maybe_transform({"job_id": job_id}, job_cancel_params.JobCancelParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NoneType, + ) + + async def result( + self, + *, + job_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> JobResultResponse: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._get( + "/eval/job/result", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform({"job_id": job_id}, job_result_params.JobResultParams), + ), + cast_to=JobResultResponse, + ) + + async def status( + self, + *, + job_id: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Optional[JobStatus]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._get( + "/eval/job/status", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform({"job_id": job_id}, job_status_params.JobStatusParams), + ), + cast_to=JobStatus, + ) + + +class JobResourceWithRawResponse: + def __init__(self, job: JobResource) -> None: + self._job = job + + self.cancel = to_raw_response_wrapper( + job.cancel, + ) + self.result = to_raw_response_wrapper( + job.result, + ) + self.status = to_raw_response_wrapper( + job.status, + ) + + +class AsyncJobResourceWithRawResponse: + def __init__(self, job: AsyncJobResource) -> None: + self._job = job + + self.cancel = async_to_raw_response_wrapper( + job.cancel, + ) + self.result = async_to_raw_response_wrapper( + job.result, + ) + self.status = async_to_raw_response_wrapper( + job.status, + ) + + +class JobResourceWithStreamingResponse: + def __init__(self, job: JobResource) -> None: + self._job = job + + self.cancel = to_streamed_response_wrapper( + job.cancel, + ) + self.result = to_streamed_response_wrapper( + job.result, + ) + self.status = to_streamed_response_wrapper( + job.status, + ) + + +class AsyncJobResourceWithStreamingResponse: + def __init__(self, job: AsyncJobResource) -> None: + self._job = job + + self.cancel = async_to_streamed_response_wrapper( + job.cancel, + ) + self.result = async_to_streamed_response_wrapper( + job.result, + ) + self.status = async_to_streamed_response_wrapper( + job.status, + ) diff --git a/src/llama_stack_client/resources/evaluate/evaluate.py b/src/llama_stack_client/resources/evaluate/evaluate.py index d6e543a1..b22e9523 100644 --- a/src/llama_stack_client/resources/evaluate/evaluate.py +++ b/src/llama_stack_client/resources/evaluate/evaluate.py @@ -2,8 +2,7 @@ from __future__ import annotations -from typing import List -from typing_extensions import Literal +from typing import Dict, List, Union, Iterable import httpx @@ -15,7 +14,7 @@ JobsResourceWithStreamingResponse, AsyncJobsResourceWithStreamingResponse, ) -from ...types import evaluate_summarization_params, evaluate_question_answering_params +from ...types import evaluate_evaluate_params, evaluate_evaluate_batch_params from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._utils import ( maybe_transform, @@ -30,8 +29,9 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from ...types.job import Job from ..._base_client import make_request_options -from ...types.evaluation_job import EvaluationJob +from ...types.evaluate.evaluate_response import EvaluateResponse __all__ = ["EvaluateResource", "AsyncEvaluateResource"] @@ -60,10 +60,12 @@ def with_streaming_response(self) -> EvaluateResourceWithStreamingResponse: """ return EvaluateResourceWithStreamingResponse(self) - def question_answering( + def evaluate( self, *, - metrics: List[Literal["em", "f1"]], + candidate: 
evaluate_evaluate_params.Candidate, + input_rows: Iterable[Dict[str, Union[bool, float, str, Iterable[object], object, None]]], + scoring_functions: List[str], x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -71,7 +73,7 @@ def question_answering( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> EvaluationJob: + ) -> EvaluateResponse: """ Args: extra_headers: Send extra headers @@ -87,20 +89,27 @@ def question_answering( **(extra_headers or {}), } return self._post( - "/evaluate/question_answering/", + "/eval/evaluate", body=maybe_transform( - {"metrics": metrics}, evaluate_question_answering_params.EvaluateQuestionAnsweringParams + { + "candidate": candidate, + "input_rows": input_rows, + "scoring_functions": scoring_functions, + }, + evaluate_evaluate_params.EvaluateEvaluateParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=EvaluationJob, + cast_to=EvaluateResponse, ) - def summarization( + def evaluate_batch( self, *, - metrics: List[Literal["rouge", "bleu"]], + candidate: evaluate_evaluate_batch_params.Candidate, + dataset_id: str, + scoring_functions: List[str], x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -108,7 +117,7 @@ def summarization( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> EvaluationJob: + ) -> Job: """ Args: extra_headers: Send extra headers @@ -124,12 +133,19 @@ def summarization( **(extra_headers or {}), } return self._post( - "/evaluate/summarization/", - body=maybe_transform({"metrics": metrics}, evaluate_summarization_params.EvaluateSummarizationParams), + "/eval/evaluate_batch", + body=maybe_transform( + { + "candidate": candidate, + "dataset_id": dataset_id, + "scoring_functions": scoring_functions, + }, + evaluate_evaluate_batch_params.EvaluateEvaluateBatchParams, + ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=EvaluationJob, + cast_to=Job, ) @@ -157,10 +173,12 @@ def with_streaming_response(self) -> AsyncEvaluateResourceWithStreamingResponse: """ return AsyncEvaluateResourceWithStreamingResponse(self) - async def question_answering( + async def evaluate( self, *, - metrics: List[Literal["em", "f1"]], + candidate: evaluate_evaluate_params.Candidate, + input_rows: Iterable[Dict[str, Union[bool, float, str, Iterable[object], object, None]]], + scoring_functions: List[str], x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
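For reference, a minimal sketch of the renamed row-level call on the sync client (the async resource below mirrors it). It assumes the resource is mounted as `client.evaluate` and that `candidate` was built per `evaluate_evaluate_params.Candidate`; the row keys and scoring-function id are placeholders whose exact shape depends on the scoring functions in use:

    rows = [
        {"input_query": "What is 2 + 2?", "expected_answer": "4", "generated_answer": "4"},
    ]
    response = client.evaluate.evaluate(
        candidate=candidate,
        input_rows=rows,
        scoring_functions=["meta-reference::equality"],
    )
    print(response)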
@@ -168,7 +186,7 @@ async def question_answering( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> EvaluationJob: + ) -> EvaluateResponse: """ Args: extra_headers: Send extra headers @@ -184,20 +202,27 @@ async def question_answering( **(extra_headers or {}), } return await self._post( - "/evaluate/question_answering/", + "/eval/evaluate", body=await async_maybe_transform( - {"metrics": metrics}, evaluate_question_answering_params.EvaluateQuestionAnsweringParams + { + "candidate": candidate, + "input_rows": input_rows, + "scoring_functions": scoring_functions, + }, + evaluate_evaluate_params.EvaluateEvaluateParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=EvaluationJob, + cast_to=EvaluateResponse, ) - async def summarization( + async def evaluate_batch( self, *, - metrics: List[Literal["rouge", "bleu"]], + candidate: evaluate_evaluate_batch_params.Candidate, + dataset_id: str, + scoring_functions: List[str], x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -205,7 +230,7 @@ async def summarization( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> EvaluationJob: + ) -> Job: """ Args: extra_headers: Send extra headers @@ -221,14 +246,19 @@ async def summarization( **(extra_headers or {}), } return await self._post( - "/evaluate/summarization/", + "/eval/evaluate_batch", body=await async_maybe_transform( - {"metrics": metrics}, evaluate_summarization_params.EvaluateSummarizationParams + { + "candidate": candidate, + "dataset_id": dataset_id, + "scoring_functions": scoring_functions, + }, + evaluate_evaluate_batch_params.EvaluateEvaluateBatchParams, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), - cast_to=EvaluationJob, + cast_to=Job, ) @@ -236,11 +266,11 @@ class EvaluateResourceWithRawResponse: def __init__(self, evaluate: EvaluateResource) -> None: self._evaluate = evaluate - self.question_answering = to_raw_response_wrapper( - evaluate.question_answering, + self.evaluate = to_raw_response_wrapper( + evaluate.evaluate, ) - self.summarization = to_raw_response_wrapper( - evaluate.summarization, + self.evaluate_batch = to_raw_response_wrapper( + evaluate.evaluate_batch, ) @cached_property @@ -252,11 +282,11 @@ class AsyncEvaluateResourceWithRawResponse: def __init__(self, evaluate: AsyncEvaluateResource) -> None: self._evaluate = evaluate - self.question_answering = async_to_raw_response_wrapper( - evaluate.question_answering, + self.evaluate = async_to_raw_response_wrapper( + evaluate.evaluate, ) - self.summarization = async_to_raw_response_wrapper( - evaluate.summarization, + self.evaluate_batch = async_to_raw_response_wrapper( + evaluate.evaluate_batch, ) @cached_property @@ -268,11 +298,11 @@ class EvaluateResourceWithStreamingResponse: def __init__(self, evaluate: EvaluateResource) -> None: self._evaluate = evaluate - self.question_answering = to_streamed_response_wrapper( - evaluate.question_answering, + self.evaluate = to_streamed_response_wrapper( + evaluate.evaluate, ) - self.summarization = 
to_streamed_response_wrapper( - evaluate.summarization, + self.evaluate_batch = to_streamed_response_wrapper( + evaluate.evaluate_batch, ) @cached_property @@ -284,11 +314,11 @@ class AsyncEvaluateResourceWithStreamingResponse: def __init__(self, evaluate: AsyncEvaluateResource) -> None: self._evaluate = evaluate - self.question_answering = async_to_streamed_response_wrapper( - evaluate.question_answering, + self.evaluate = async_to_streamed_response_wrapper( + evaluate.evaluate, ) - self.summarization = async_to_streamed_response_wrapper( - evaluate.summarization, + self.evaluate_batch = async_to_streamed_response_wrapper( + evaluate.evaluate_batch, ) @cached_property diff --git a/src/llama_stack_client/resources/evaluate/jobs.py b/src/llama_stack_client/resources/evaluate/jobs.py index 55e0dbb6..55219e62 100644 --- a/src/llama_stack_client/resources/evaluate/jobs.py +++ b/src/llama_stack_client/resources/evaluate/jobs.py @@ -2,6 +2,8 @@ from __future__ import annotations +from typing import Optional + import httpx from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven @@ -19,11 +21,9 @@ async_to_streamed_response_wrapper, ) from ..._base_client import make_request_options -from ...types.evaluate import job_logs_params, job_cancel_params, job_status_params, job_artifacts_params -from ...types.evaluation_job import EvaluationJob -from ...types.evaluate.job_logs_response import JobLogsResponse -from ...types.evaluate.job_status_response import JobStatusResponse -from ...types.evaluate.job_artifacts_response import JobArtifactsResponse +from ...types.evaluate import job_cancel_params, job_result_params, job_status_params +from ...types.evaluate.job_status import JobStatus +from ...types.evaluate.evaluate_response import EvaluateResponse __all__ = ["JobsResource", "AsyncJobsResource"] @@ -48,82 +48,10 @@ def with_streaming_response(self) -> JobsResourceWithStreamingResponse: """ return JobsResourceWithStreamingResponse(self) - def list( - self, - *, - x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> EvaluationJob: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} - extra_headers = { - **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), - **(extra_headers or {}), - } - return self._get( - "/evaluate/jobs", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=EvaluationJob, - ) - - def artifacts( - self, - *, - job_uuid: str, - x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. 
- extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> JobArtifactsResponse: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - extra_headers = { - **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), - **(extra_headers or {}), - } - return self._get( - "/evaluate/job/artifacts", - options=make_request_options( - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - query=maybe_transform({"job_uuid": job_uuid}, job_artifacts_params.JobArtifactsParams), - ), - cast_to=JobArtifactsResponse, - ) - def cancel( self, *, - job_uuid: str, + job_id: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -148,18 +76,18 @@ def cancel( **(extra_headers or {}), } return self._post( - "/evaluate/job/cancel", - body=maybe_transform({"job_uuid": job_uuid}, job_cancel_params.JobCancelParams), + "/eval/job/cancel", + body=maybe_transform({"job_id": job_id}, job_cancel_params.JobCancelParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=NoneType, ) - def logs( + def result( self, *, - job_uuid: str, + job_id: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -167,7 +95,7 @@ def logs( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> JobLogsResponse: + ) -> EvaluateResponse: """ Args: extra_headers: Send extra headers @@ -183,21 +111,21 @@ def logs( **(extra_headers or {}), } return self._get( - "/evaluate/job/logs", + "/eval/job/result", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout, - query=maybe_transform({"job_uuid": job_uuid}, job_logs_params.JobLogsParams), + query=maybe_transform({"job_id": job_id}, job_result_params.JobResultParams), ), - cast_to=JobLogsResponse, + cast_to=EvaluateResponse, ) def status( self, *, - job_uuid: str, + job_id: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
@@ -205,7 +133,7 @@ def status( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> JobStatusResponse: + ) -> Optional[JobStatus]: """ Args: extra_headers: Send extra headers @@ -221,15 +149,15 @@ def status( **(extra_headers or {}), } return self._get( - "/evaluate/job/status", + "/eval/job/status", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout, - query=maybe_transform({"job_uuid": job_uuid}, job_status_params.JobStatusParams), + query=maybe_transform({"job_id": job_id}, job_status_params.JobStatusParams), ), - cast_to=JobStatusResponse, + cast_to=JobStatus, ) @@ -253,82 +181,10 @@ def with_streaming_response(self) -> AsyncJobsResourceWithStreamingResponse: """ return AsyncJobsResourceWithStreamingResponse(self) - async def list( - self, - *, - x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. - extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> EvaluationJob: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} - extra_headers = { - **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), - **(extra_headers or {}), - } - return await self._get( - "/evaluate/jobs", - options=make_request_options( - extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout - ), - cast_to=EvaluationJob, - ) - - async def artifacts( - self, - *, - job_uuid: str, - x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, - # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. - # The extra values given here take precedence over values defined on the client or passed to this method. 
- extra_headers: Headers | None = None, - extra_query: Query | None = None, - extra_body: Body | None = None, - timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> JobArtifactsResponse: - """ - Args: - extra_headers: Send extra headers - - extra_query: Add additional query parameters to the request - - extra_body: Add additional JSON properties to the request - - timeout: Override the client-level default timeout for this request, in seconds - """ - extra_headers = { - **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), - **(extra_headers or {}), - } - return await self._get( - "/evaluate/job/artifacts", - options=make_request_options( - extra_headers=extra_headers, - extra_query=extra_query, - extra_body=extra_body, - timeout=timeout, - query=await async_maybe_transform({"job_uuid": job_uuid}, job_artifacts_params.JobArtifactsParams), - ), - cast_to=JobArtifactsResponse, - ) - async def cancel( self, *, - job_uuid: str, + job_id: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -353,18 +209,18 @@ async def cancel( **(extra_headers or {}), } return await self._post( - "/evaluate/job/cancel", - body=await async_maybe_transform({"job_uuid": job_uuid}, job_cancel_params.JobCancelParams), + "/eval/job/cancel", + body=await async_maybe_transform({"job_id": job_id}, job_cancel_params.JobCancelParams), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=NoneType, ) - async def logs( + async def result( self, *, - job_uuid: str, + job_id: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -372,7 +228,7 @@ async def logs( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> JobLogsResponse: + ) -> EvaluateResponse: """ Args: extra_headers: Send extra headers @@ -388,21 +244,21 @@ async def logs( **(extra_headers or {}), } return await self._get( - "/evaluate/job/logs", + "/eval/job/result", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout, - query=await async_maybe_transform({"job_uuid": job_uuid}, job_logs_params.JobLogsParams), + query=await async_maybe_transform({"job_id": job_id}, job_result_params.JobResultParams), ), - cast_to=JobLogsResponse, + cast_to=EvaluateResponse, ) async def status( self, *, - job_uuid: str, + job_id: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
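The async resource mirrors the sync surface. A minimal end-to-end sketch with the async client, where the base URL, the job id, and the `JobStatus` literal are assumptions:

    import asyncio

    from llama_stack_client import AsyncLlamaStackClient

    async def check(job_id: str) -> None:
        async with AsyncLlamaStackClient(base_url="http://localhost:5000") as client:  # assumed local server
            status = await client.evaluate.jobs.status(job_id=job_id)
            print("status:", status)
            if status == "completed":  # literal assumed from types/evaluate/job_status.py
                print(await client.evaluate.jobs.result(job_id=job_id))

    asyncio.run(check("my-eval-job"))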
@@ -410,7 +266,7 @@ async def status( extra_query: Query | None = None, extra_body: Body | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, - ) -> JobStatusResponse: + ) -> Optional[JobStatus]: """ Args: extra_headers: Send extra headers @@ -426,15 +282,15 @@ async def status( **(extra_headers or {}), } return await self._get( - "/evaluate/job/status", + "/eval/job/status", options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout, - query=await async_maybe_transform({"job_uuid": job_uuid}, job_status_params.JobStatusParams), + query=await async_maybe_transform({"job_id": job_id}, job_status_params.JobStatusParams), ), - cast_to=JobStatusResponse, + cast_to=JobStatus, ) @@ -442,17 +298,11 @@ class JobsResourceWithRawResponse: def __init__(self, jobs: JobsResource) -> None: self._jobs = jobs - self.list = to_raw_response_wrapper( - jobs.list, - ) - self.artifacts = to_raw_response_wrapper( - jobs.artifacts, - ) self.cancel = to_raw_response_wrapper( jobs.cancel, ) - self.logs = to_raw_response_wrapper( - jobs.logs, + self.result = to_raw_response_wrapper( + jobs.result, ) self.status = to_raw_response_wrapper( jobs.status, @@ -463,17 +313,11 @@ class AsyncJobsResourceWithRawResponse: def __init__(self, jobs: AsyncJobsResource) -> None: self._jobs = jobs - self.list = async_to_raw_response_wrapper( - jobs.list, - ) - self.artifacts = async_to_raw_response_wrapper( - jobs.artifacts, - ) self.cancel = async_to_raw_response_wrapper( jobs.cancel, ) - self.logs = async_to_raw_response_wrapper( - jobs.logs, + self.result = async_to_raw_response_wrapper( + jobs.result, ) self.status = async_to_raw_response_wrapper( jobs.status, @@ -484,17 +328,11 @@ class JobsResourceWithStreamingResponse: def __init__(self, jobs: JobsResource) -> None: self._jobs = jobs - self.list = to_streamed_response_wrapper( - jobs.list, - ) - self.artifacts = to_streamed_response_wrapper( - jobs.artifacts, - ) self.cancel = to_streamed_response_wrapper( jobs.cancel, ) - self.logs = to_streamed_response_wrapper( - jobs.logs, + self.result = to_streamed_response_wrapper( + jobs.result, ) self.status = to_streamed_response_wrapper( jobs.status, @@ -505,17 +343,11 @@ class AsyncJobsResourceWithStreamingResponse: def __init__(self, jobs: AsyncJobsResource) -> None: self._jobs = jobs - self.list = async_to_streamed_response_wrapper( - jobs.list, - ) - self.artifacts = async_to_streamed_response_wrapper( - jobs.artifacts, - ) self.cancel = async_to_streamed_response_wrapper( jobs.cancel, ) - self.logs = async_to_streamed_response_wrapper( - jobs.logs, + self.result = async_to_streamed_response_wrapper( + jobs.result, ) self.status = async_to_streamed_response_wrapper( jobs.status, diff --git a/src/llama_stack_client/resources/inference.py b/src/llama_stack_client/resources/inference.py index 895cf578..23b4f68a 100644 --- a/src/llama_stack_client/resources/inference.py +++ b/src/llama_stack_client/resources/inference.py @@ -61,6 +61,7 @@ def chat_completion( messages: Iterable[inference_chat_completion_params.Message], model: str, logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, sampling_params: SamplingParams | NotGiven = NOT_GIVEN, stream: bool | NotGiven = NOT_GIVEN, tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, @@ -108,6 +109,7 @@ def chat_completion( "messages": messages, "model": model, "logprobs": 
logprobs, + "response_format": response_format, "sampling_params": sampling_params, "stream": stream, "tool_choice": tool_choice, @@ -131,6 +133,7 @@ def completion( content: inference_completion_params.Content, model: str, logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, sampling_params: SamplingParams | NotGiven = NOT_GIVEN, stream: bool | NotGiven = NOT_GIVEN, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, @@ -164,6 +167,7 @@ def completion( "content": content, "model": model, "logprobs": logprobs, + "response_format": response_format, "sampling_params": sampling_params, "stream": stream, }, @@ -247,6 +251,7 @@ async def chat_completion( messages: Iterable[inference_chat_completion_params.Message], model: str, logprobs: inference_chat_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_chat_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, sampling_params: SamplingParams | NotGiven = NOT_GIVEN, stream: bool | NotGiven = NOT_GIVEN, tool_choice: Literal["auto", "required"] | NotGiven = NOT_GIVEN, @@ -294,6 +299,7 @@ async def chat_completion( "messages": messages, "model": model, "logprobs": logprobs, + "response_format": response_format, "sampling_params": sampling_params, "stream": stream, "tool_choice": tool_choice, @@ -317,6 +323,7 @@ async def completion( content: inference_completion_params.Content, model: str, logprobs: inference_completion_params.Logprobs | NotGiven = NOT_GIVEN, + response_format: inference_completion_params.ResponseFormat | NotGiven = NOT_GIVEN, sampling_params: SamplingParams | NotGiven = NOT_GIVEN, stream: bool | NotGiven = NOT_GIVEN, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, @@ -350,6 +357,7 @@ async def completion( "content": content, "model": model, "logprobs": logprobs, + "response_format": response_format, "sampling_params": sampling_params, "stream": stream, }, diff --git a/src/llama_stack_client/resources/post_training/jobs.py b/src/llama_stack_client/resources/post_training/jobs.py index 43fdd782..840b2e70 100644 --- a/src/llama_stack_client/resources/post_training/jobs.py +++ b/src/llama_stack_client/resources/post_training/jobs.py @@ -4,8 +4,12 @@ import httpx -from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ..._utils import strip_not_given +from ..._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven +from ..._utils import ( + maybe_transform, + strip_not_given, + async_maybe_transform, +) from ..._compat import cached_property from ..._resource import SyncAPIResource, AsyncAPIResource from ..._response import ( @@ -15,7 +19,11 @@ async_to_streamed_response_wrapper, ) from ..._base_client import make_request_options +from ...types.post_training import job_logs_params, job_cancel_params, job_status_params, job_artifacts_params from ...types.post_training_job import PostTrainingJob +from ...types.post_training.post_training_job_status import PostTrainingJobStatus +from ...types.post_training.post_training_job_artifacts import PostTrainingJobArtifacts +from ...types.post_training.post_training_job_log_stream import PostTrainingJobLogStream __all__ = ["JobsResource", "AsyncJobsResource"] @@ -74,6 +82,156 @@ def list( cast_to=PostTrainingJob, ) + def artifacts( + self, + *, + job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available 
via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> PostTrainingJobArtifacts: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._get( + "/post_training/job/artifacts", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform({"job_uuid": job_uuid}, job_artifacts_params.JobArtifactsParams), + ), + cast_to=PostTrainingJobArtifacts, + ) + + def cancel( + self, + *, + job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> None: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._post( + "/post_training/job/cancel", + body=maybe_transform({"job_uuid": job_uuid}, job_cancel_params.JobCancelParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NoneType, + ) + + def logs( + self, + *, + job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> PostTrainingJobLogStream: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._get( + "/post_training/job/logs", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform({"job_uuid": job_uuid}, job_logs_params.JobLogsParams), + ), + cast_to=PostTrainingJobLogStream, + ) + + def status( + self, + *, + job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> PostTrainingJobStatus: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._get( + "/post_training/job/status", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform({"job_uuid": job_uuid}, job_status_params.JobStatusParams), + ), + cast_to=PostTrainingJobStatus, + ) + class AsyncJobsResource(AsyncAPIResource): @cached_property @@ -129,6 +287,156 @@ async def list( cast_to=PostTrainingJob, ) + async def artifacts( + self, + *, + job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> PostTrainingJobArtifacts: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._get( + "/post_training/job/artifacts", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform({"job_uuid": job_uuid}, job_artifacts_params.JobArtifactsParams), + ), + cast_to=PostTrainingJobArtifacts, + ) + + async def cancel( + self, + *, + job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> None: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._post( + "/post_training/job/cancel", + body=await async_maybe_transform({"job_uuid": job_uuid}, job_cancel_params.JobCancelParams), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NoneType, + ) + + async def logs( + self, + *, + job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> PostTrainingJobLogStream: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._get( + "/post_training/job/logs", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform({"job_uuid": job_uuid}, job_logs_params.JobLogsParams), + ), + cast_to=PostTrainingJobLogStream, + ) + + async def status( + self, + *, + job_uuid: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> PostTrainingJobStatus: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._get( + "/post_training/job/status", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform({"job_uuid": job_uuid}, job_status_params.JobStatusParams), + ), + cast_to=PostTrainingJobStatus, + ) + class JobsResourceWithRawResponse: def __init__(self, jobs: JobsResource) -> None: @@ -137,6 +445,18 @@ def __init__(self, jobs: JobsResource) -> None: self.list = to_raw_response_wrapper( jobs.list, ) + self.artifacts = to_raw_response_wrapper( + jobs.artifacts, + ) + self.cancel = to_raw_response_wrapper( + jobs.cancel, + ) + self.logs = to_raw_response_wrapper( + jobs.logs, + ) + self.status = to_raw_response_wrapper( + jobs.status, + ) class AsyncJobsResourceWithRawResponse: @@ -146,6 +466,18 @@ def __init__(self, jobs: AsyncJobsResource) -> None: self.list = async_to_raw_response_wrapper( jobs.list, ) + self.artifacts = async_to_raw_response_wrapper( + jobs.artifacts, + ) + self.cancel = async_to_raw_response_wrapper( + jobs.cancel, + ) + self.logs = async_to_raw_response_wrapper( + jobs.logs, + ) + self.status = async_to_raw_response_wrapper( + jobs.status, + ) class JobsResourceWithStreamingResponse: @@ -155,6 +487,18 @@ def __init__(self, jobs: JobsResource) -> None: self.list = to_streamed_response_wrapper( jobs.list, ) + self.artifacts = to_streamed_response_wrapper( + jobs.artifacts, + ) + self.cancel = to_streamed_response_wrapper( + jobs.cancel, + ) + self.logs = to_streamed_response_wrapper( + jobs.logs, + ) + self.status = to_streamed_response_wrapper( + jobs.status, + ) class AsyncJobsResourceWithStreamingResponse: 
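For completeness, a sketch of the new job-introspection calls on the sync client. The accessor name is an assumption (the `JobsResource` above may be wired under a different attribute of `client.post_training`), and the job UUID is a placeholder:

    from llama_stack_client import LlamaStackClient

    client = LlamaStackClient(base_url="http://localhost:5000")  # assumed local server

    status = client.post_training.jobs.status(job_uuid="job-1234")
    artifacts = client.post_training.jobs.artifacts(job_uuid="job-1234")
    logs = client.post_training.jobs.logs(job_uuid="job-1234")
    print(status, artifacts, logs)
    # client.post_training.jobs.cancel(job_uuid="job-1234") would stop the run instead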
@@ -164,3 +508,15 @@ def __init__(self, jobs: AsyncJobsResource) -> None: self.list = async_to_streamed_response_wrapper( jobs.list, ) + self.artifacts = async_to_streamed_response_wrapper( + jobs.artifacts, + ) + self.cancel = async_to_streamed_response_wrapper( + jobs.cancel, + ) + self.logs = async_to_streamed_response_wrapper( + jobs.logs, + ) + self.status = async_to_streamed_response_wrapper( + jobs.status, + ) diff --git a/src/llama_stack_client/resources/post_training/post_training.py b/src/llama_stack_client/resources/post_training/post_training.py index e05521d9..c54cfb4e 100644 --- a/src/llama_stack_client/resources/post_training/post_training.py +++ b/src/llama_stack_client/resources/post_training/post_training.py @@ -15,10 +15,7 @@ JobResourceWithStreamingResponse, AsyncJobResourceWithStreamingResponse, ) -from ...types import ( - post_training_preference_optimize_params, - post_training_supervised_fine_tune_params, -) +from ...types import post_training_preference_optimize_params, post_training_supervised_fine_tune_params from ..._types import NOT_GIVEN, Body, Query, Headers, NotGiven from ..._utils import ( maybe_transform, @@ -35,7 +32,6 @@ ) from ..._base_client import make_request_options from ...types.post_training_job import PostTrainingJob -from ...types.train_eval_dataset_param import TrainEvalDatasetParam __all__ = ["PostTrainingResource", "AsyncPostTrainingResource"] @@ -69,14 +65,14 @@ def preference_optimize( *, algorithm: Literal["dpo"], algorithm_config: post_training_preference_optimize_params.AlgorithmConfig, - dataset: TrainEvalDatasetParam, + dataset_id: str, finetuned_model: str, hyperparam_search_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], job_uuid: str, logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], optimizer_config: post_training_preference_optimize_params.OptimizerConfig, training_config: post_training_preference_optimize_params.TrainingConfig, - validation_dataset: TrainEvalDatasetParam, + validation_dataset_id: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
@@ -105,14 +101,14 @@ def preference_optimize( { "algorithm": algorithm, "algorithm_config": algorithm_config, - "dataset": dataset, + "dataset_id": dataset_id, "finetuned_model": finetuned_model, "hyperparam_search_config": hyperparam_search_config, "job_uuid": job_uuid, "logger_config": logger_config, "optimizer_config": optimizer_config, "training_config": training_config, - "validation_dataset": validation_dataset, + "validation_dataset_id": validation_dataset_id, }, post_training_preference_optimize_params.PostTrainingPreferenceOptimizeParams, ), @@ -127,14 +123,14 @@ def supervised_fine_tune( *, algorithm: Literal["full", "lora", "qlora", "dora"], algorithm_config: post_training_supervised_fine_tune_params.AlgorithmConfig, - dataset: TrainEvalDatasetParam, + dataset_id: str, hyperparam_search_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], job_uuid: str, logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], model: str, optimizer_config: post_training_supervised_fine_tune_params.OptimizerConfig, training_config: post_training_supervised_fine_tune_params.TrainingConfig, - validation_dataset: TrainEvalDatasetParam, + validation_dataset_id: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -163,14 +159,14 @@ def supervised_fine_tune( { "algorithm": algorithm, "algorithm_config": algorithm_config, - "dataset": dataset, + "dataset_id": dataset_id, "hyperparam_search_config": hyperparam_search_config, "job_uuid": job_uuid, "logger_config": logger_config, "model": model, "optimizer_config": optimizer_config, "training_config": training_config, - "validation_dataset": validation_dataset, + "validation_dataset_id": validation_dataset_id, }, post_training_supervised_fine_tune_params.PostTrainingSupervisedFineTuneParams, ), @@ -210,14 +206,14 @@ async def preference_optimize( *, algorithm: Literal["dpo"], algorithm_config: post_training_preference_optimize_params.AlgorithmConfig, - dataset: TrainEvalDatasetParam, + dataset_id: str, finetuned_model: str, hyperparam_search_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], job_uuid: str, logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], optimizer_config: post_training_preference_optimize_params.OptimizerConfig, training_config: post_training_preference_optimize_params.TrainingConfig, - validation_dataset: TrainEvalDatasetParam, + validation_dataset_id: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. 
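To make the signature change above concrete: preference_optimize and supervised_fine_tune now reference registered datasets by identifier (dataset_id / validation_dataset_id) instead of accepting inline TrainEvalDatasetParam objects. A hedged sketch of a call under the new signature follows; apart from the parameter names visible in the diff, every value (model name, identifiers, and the placeholder config dicts) is assumed for illustration only.

    # Assumes `client` is a configured LlamaStackClient and that "alpaca-train" /
    # "alpaca-val" were registered beforehand via the datasets API.
    job = client.post_training.supervised_fine_tune(
        algorithm="lora",
        algorithm_config={},        # placeholder; real fields per AlgorithmConfig
        dataset_id="alpaca-train",            # previously: dataset=TrainEvalDatasetParam(...)
        validation_dataset_id="alpaca-val",   # previously: validation_dataset=TrainEvalDatasetParam(...)
        model="Llama3.1-8B-Instruct",
        job_uuid="sft-job-0001",
        hyperparam_search_config={},
        logger_config={},
        optimizer_config={},        # placeholder; real fields per OptimizerConfig
        training_config={},         # placeholder; real fields per TrainingConfig
    )
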
@@ -246,14 +242,14 @@ async def preference_optimize( { "algorithm": algorithm, "algorithm_config": algorithm_config, - "dataset": dataset, + "dataset_id": dataset_id, "finetuned_model": finetuned_model, "hyperparam_search_config": hyperparam_search_config, "job_uuid": job_uuid, "logger_config": logger_config, "optimizer_config": optimizer_config, "training_config": training_config, - "validation_dataset": validation_dataset, + "validation_dataset_id": validation_dataset_id, }, post_training_preference_optimize_params.PostTrainingPreferenceOptimizeParams, ), @@ -268,14 +264,14 @@ async def supervised_fine_tune( *, algorithm: Literal["full", "lora", "qlora", "dora"], algorithm_config: post_training_supervised_fine_tune_params.AlgorithmConfig, - dataset: TrainEvalDatasetParam, + dataset_id: str, hyperparam_search_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], job_uuid: str, logger_config: Dict[str, Union[bool, float, str, Iterable[object], object, None]], model: str, optimizer_config: post_training_supervised_fine_tune_params.OptimizerConfig, training_config: post_training_supervised_fine_tune_params.TrainingConfig, - validation_dataset: TrainEvalDatasetParam, + validation_dataset_id: str, x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. # The extra values given here take precedence over values defined on the client or passed to this method. @@ -304,14 +300,14 @@ async def supervised_fine_tune( { "algorithm": algorithm, "algorithm_config": algorithm_config, - "dataset": dataset, + "dataset_id": dataset_id, "hyperparam_search_config": hyperparam_search_config, "job_uuid": job_uuid, "logger_config": logger_config, "model": model, "optimizer_config": optimizer_config, "training_config": training_config, - "validation_dataset": validation_dataset, + "validation_dataset_id": validation_dataset_id, }, post_training_supervised_fine_tune_params.PostTrainingSupervisedFineTuneParams, ), diff --git a/src/llama_stack_client/resources/scoring.py b/src/llama_stack_client/resources/scoring.py new file mode 100644 index 00000000..ce77012f --- /dev/null +++ b/src/llama_stack_client/resources/scoring.py @@ -0,0 +1,290 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Dict, List, Union, Iterable + +import httpx + +from ..types import scoring_score_params, scoring_score_batch_params +from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven +from .._utils import ( + maybe_transform, + strip_not_given, + async_maybe_transform, +) +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from .._base_client import make_request_options +from ..types.score_response import ScoreResponse +from ..types.score_batch_response import ScoreBatchResponse + +__all__ = ["ScoringResource", "AsyncScoringResource"] + + +class ScoringResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> ScoringResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. 
+ + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers + """ + return ScoringResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> ScoringResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return ScoringResourceWithStreamingResponse(self) + + def score( + self, + *, + input_rows: Iterable[Dict[str, Union[bool, float, str, Iterable[object], object, None]]], + scoring_functions: List[str], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ScoreResponse: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._post( + "/scoring/score", + body=maybe_transform( + { + "input_rows": input_rows, + "scoring_functions": scoring_functions, + }, + scoring_score_params.ScoringScoreParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ScoreResponse, + ) + + def score_batch( + self, + *, + dataset_id: str, + save_results_dataset: bool, + scoring_functions: List[str], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ScoreBatchResponse: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._post( + "/scoring/score_batch", + body=maybe_transform( + { + "dataset_id": dataset_id, + "save_results_dataset": save_results_dataset, + "scoring_functions": scoring_functions, + }, + scoring_score_batch_params.ScoringScoreBatchParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ScoreBatchResponse, + ) + + +class AsyncScoringResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncScoringResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers + """ + return AsyncScoringResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncScoringResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return AsyncScoringResourceWithStreamingResponse(self) + + async def score( + self, + *, + input_rows: Iterable[Dict[str, Union[bool, float, str, Iterable[object], object, None]]], + scoring_functions: List[str], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ScoreResponse: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._post( + "/scoring/score", + body=await async_maybe_transform( + { + "input_rows": input_rows, + "scoring_functions": scoring_functions, + }, + scoring_score_params.ScoringScoreParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ScoreResponse, + ) + + async def score_batch( + self, + *, + dataset_id: str, + save_results_dataset: bool, + scoring_functions: List[str], + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ScoreBatchResponse: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._post( + "/scoring/score_batch", + body=await async_maybe_transform( + { + "dataset_id": dataset_id, + "save_results_dataset": save_results_dataset, + "scoring_functions": scoring_functions, + }, + scoring_score_batch_params.ScoringScoreBatchParams, + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ScoreBatchResponse, + ) + + +class ScoringResourceWithRawResponse: + def __init__(self, scoring: ScoringResource) -> None: + self._scoring = scoring + + self.score = to_raw_response_wrapper( + scoring.score, + ) + self.score_batch = to_raw_response_wrapper( + scoring.score_batch, + ) + + +class AsyncScoringResourceWithRawResponse: + def __init__(self, scoring: AsyncScoringResource) -> None: + self._scoring = scoring + + self.score = async_to_raw_response_wrapper( + scoring.score, + ) + self.score_batch = async_to_raw_response_wrapper( + scoring.score_batch, + ) + + +class ScoringResourceWithStreamingResponse: + def __init__(self, scoring: ScoringResource) -> None: + self._scoring = scoring + + self.score = to_streamed_response_wrapper( + scoring.score, + ) + self.score_batch = to_streamed_response_wrapper( + scoring.score_batch, + ) + + +class AsyncScoringResourceWithStreamingResponse: + def __init__(self, scoring: AsyncScoringResource) -> None: + self._scoring = scoring + + self.score = async_to_streamed_response_wrapper( + scoring.score, + ) + self.score_batch = async_to_streamed_response_wrapper( + scoring.score_batch, + ) diff --git 
a/src/llama_stack_client/resources/scoring_functions.py b/src/llama_stack_client/resources/scoring_functions.py new file mode 100644 index 00000000..dd063d29 --- /dev/null +++ b/src/llama_stack_client/resources/scoring_functions.py @@ -0,0 +1,356 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Optional + +import httpx + +from ..types import ( + ScoringFnDefWithProvider, + scoring_function_register_params, + scoring_function_retrieve_params, +) +from .._types import NOT_GIVEN, Body, Query, Headers, NoneType, NotGiven +from .._utils import ( + maybe_transform, + strip_not_given, + async_maybe_transform, +) +from .._compat import cached_property +from .._resource import SyncAPIResource, AsyncAPIResource +from .._response import ( + to_raw_response_wrapper, + to_streamed_response_wrapper, + async_to_raw_response_wrapper, + async_to_streamed_response_wrapper, +) +from .._base_client import make_request_options +from ..types.scoring_fn_def_with_provider import ScoringFnDefWithProvider +from ..types.scoring_fn_def_with_provider_param import ScoringFnDefWithProviderParam + +__all__ = ["ScoringFunctionsResource", "AsyncScoringFunctionsResource"] + + +class ScoringFunctionsResource(SyncAPIResource): + @cached_property + def with_raw_response(self) -> ScoringFunctionsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers + """ + return ScoringFunctionsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> ScoringFunctionsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return ScoringFunctionsResourceWithStreamingResponse(self) + + def retrieve( + self, + *, + name: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Optional[ScoringFnDefWithProvider]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._get( + "/scoring_functions/get", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=maybe_transform({"name": name}, scoring_function_retrieve_params.ScoringFunctionRetrieveParams), + ), + cast_to=ScoringFnDefWithProvider, + ) + + def list( + self, + *, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ScoringFnDefWithProvider: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._get( + "/scoring_functions/list", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ScoringFnDefWithProvider, + ) + + def register( + self, + *, + function_def: ScoringFnDefWithProviderParam, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> None: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return self._post( + "/scoring_functions/register", + body=maybe_transform( + {"function_def": function_def}, scoring_function_register_params.ScoringFunctionRegisterParams + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NoneType, + ) + + +class AsyncScoringFunctionsResource(AsyncAPIResource): + @cached_property + def with_raw_response(self) -> AsyncScoringFunctionsResourceWithRawResponse: + """ + This property can be used as a prefix for any HTTP method call to return the + the raw response object instead of the parsed content. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#accessing-raw-response-data-eg-headers + """ + return AsyncScoringFunctionsResourceWithRawResponse(self) + + @cached_property + def with_streaming_response(self) -> AsyncScoringFunctionsResourceWithStreamingResponse: + """ + An alternative to `.with_raw_response` that doesn't eagerly read the response body. + + For more information, see https://www.github.com/stainless-sdks/llama-stack-python#with_streaming_response + """ + return AsyncScoringFunctionsResourceWithStreamingResponse(self) + + async def retrieve( + self, + *, + name: str, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Optional[ScoringFnDefWithProvider]: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._get( + "/scoring_functions/get", + options=make_request_options( + extra_headers=extra_headers, + extra_query=extra_query, + extra_body=extra_body, + timeout=timeout, + query=await async_maybe_transform( + {"name": name}, scoring_function_retrieve_params.ScoringFunctionRetrieveParams + ), + ), + cast_to=ScoringFnDefWithProvider, + ) + + async def list( + self, + *, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. 
+ extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> ScoringFnDefWithProvider: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "application/jsonl", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._get( + "/scoring_functions/list", + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=ScoringFnDefWithProvider, + ) + + async def register( + self, + *, + function_def: ScoringFnDefWithProviderParam, + x_llama_stack_provider_data: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> None: + """ + Args: + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + extra_headers = {"Accept": "*/*", **(extra_headers or {})} + extra_headers = { + **strip_not_given({"X-LlamaStack-ProviderData": x_llama_stack_provider_data}), + **(extra_headers or {}), + } + return await self._post( + "/scoring_functions/register", + body=await async_maybe_transform( + {"function_def": function_def}, scoring_function_register_params.ScoringFunctionRegisterParams + ), + options=make_request_options( + extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout + ), + cast_to=NoneType, + ) + + +class ScoringFunctionsResourceWithRawResponse: + def __init__(self, scoring_functions: ScoringFunctionsResource) -> None: + self._scoring_functions = scoring_functions + + self.retrieve = to_raw_response_wrapper( + scoring_functions.retrieve, + ) + self.list = to_raw_response_wrapper( + scoring_functions.list, + ) + self.register = to_raw_response_wrapper( + scoring_functions.register, + ) + + +class AsyncScoringFunctionsResourceWithRawResponse: + def __init__(self, scoring_functions: AsyncScoringFunctionsResource) -> None: + self._scoring_functions = scoring_functions + + self.retrieve = async_to_raw_response_wrapper( + scoring_functions.retrieve, + ) + self.list = async_to_raw_response_wrapper( + scoring_functions.list, + ) + self.register = async_to_raw_response_wrapper( + scoring_functions.register, + ) + + +class ScoringFunctionsResourceWithStreamingResponse: + def __init__(self, scoring_functions: ScoringFunctionsResource) -> None: + self._scoring_functions = scoring_functions + + self.retrieve = to_streamed_response_wrapper( + scoring_functions.retrieve, + ) + self.list = to_streamed_response_wrapper( + scoring_functions.list, + ) + self.register = to_streamed_response_wrapper( + scoring_functions.register, + ) + + +class 
AsyncScoringFunctionsResourceWithStreamingResponse: + def __init__(self, scoring_functions: AsyncScoringFunctionsResource) -> None: + self._scoring_functions = scoring_functions + + self.retrieve = async_to_streamed_response_wrapper( + scoring_functions.retrieve, + ) + self.list = async_to_streamed_response_wrapper( + scoring_functions.list, + ) + self.register = async_to_streamed_response_wrapper( + scoring_functions.register, + ) diff --git a/src/llama_stack_client/types/__init__.py b/src/llama_stack_client/types/__init__.py index 30573e02..e5a02c3e 100644 --- a/src/llama_stack_client/types/__init__.py +++ b/src/llama_stack_client/types/__init__.py @@ -2,6 +2,7 @@ from __future__ import annotations +from .job import Job as Job from .trace import Trace as Trace from .shared import ( ToolCall as ToolCall, @@ -11,15 +12,23 @@ SystemMessage as SystemMessage, SamplingParams as SamplingParams, BatchCompletion as BatchCompletion, + SafetyViolation as SafetyViolation, CompletionMessage as CompletionMessage, + GraphMemoryBankDef as GraphMemoryBankDef, ToolResponseMessage as ToolResponseMessage, + VectorMemoryBankDef as VectorMemoryBankDef, + KeywordMemoryBankDef as KeywordMemoryBankDef, + KeyValueMemoryBankDef as KeyValueMemoryBankDef, ) from .route_info import RouteInfo as RouteInfo from .health_info import HealthInfo as HealthInfo from .provider_info import ProviderInfo as ProviderInfo -from .evaluation_job import EvaluationJob as EvaluationJob +from .tool_response import ToolResponse as ToolResponse +from .inference_step import InferenceStep as InferenceStep +from .score_response import ScoreResponse as ScoreResponse +from .token_log_probs import TokenLogProbs as TokenLogProbs +from .shield_call_step import ShieldCallStep as ShieldCallStep from .post_training_job import PostTrainingJob as PostTrainingJob -from .train_eval_dataset import TrainEvalDataset as TrainEvalDataset from .agent_create_params import AgentCreateParams as AgentCreateParams from .agent_delete_params import AgentDeleteParams as AgentDeleteParams from .completion_response import CompletionResponse as CompletionResponse @@ -27,41 +36,51 @@ from .memory_query_params import MemoryQueryParams as MemoryQueryParams from .route_list_response import RouteListResponse as RouteListResponse from .run_shield_response import RunShieldResponse as RunShieldResponse +from .tool_execution_step import ToolExecutionStep as ToolExecutionStep +from .eval_evaluate_params import EvalEvaluateParams as EvalEvaluateParams from .memory_insert_params import MemoryInsertParams as MemoryInsertParams +from .score_batch_response import ScoreBatchResponse as ScoreBatchResponse +from .scoring_score_params import ScoringScoreParams as ScoringScoreParams from .agent_create_response import AgentCreateResponse as AgentCreateResponse -from .dataset_create_params import DatasetCreateParams as DatasetCreateParams -from .dataset_delete_params import DatasetDeleteParams as DatasetDeleteParams +from .dataset_list_response import DatasetListResponse as DatasetListResponse +from .memory_retrieval_step import MemoryRetrievalStep as MemoryRetrievalStep from .model_register_params import ModelRegisterParams as ModelRegisterParams from .model_retrieve_params import ModelRetrieveParams as ModelRetrieveParams +from .paginated_rows_result import PaginatedRowsResult as PaginatedRowsResult +from .eval_evaluate_response import EvalEvaluateResponse as EvalEvaluateResponse from .provider_list_response import ProviderListResponse as ProviderListResponse from .shield_register_params 
import ShieldRegisterParams as ShieldRegisterParams from .shield_retrieve_params import ShieldRetrieveParams as ShieldRetrieveParams +from .dataset_register_params import DatasetRegisterParams as DatasetRegisterParams from .dataset_retrieve_params import DatasetRetrieveParams as DatasetRetrieveParams from .model_def_with_provider import ModelDefWithProvider as ModelDefWithProvider -from .reward_scoring_response import RewardScoringResponse as RewardScoringResponse from .query_documents_response import QueryDocumentsResponse as QueryDocumentsResponse from .safety_run_shield_params import SafetyRunShieldParams as SafetyRunShieldParams from .shield_def_with_provider import ShieldDefWithProvider as ShieldDefWithProvider -from .train_eval_dataset_param import TrainEvalDatasetParam as TrainEvalDatasetParam +from .dataset_retrieve_response import DatasetRetrieveResponse as DatasetRetrieveResponse from .memory_bank_list_response import MemoryBankListResponse as MemoryBankListResponse +from .eval_evaluate_batch_params import EvalEvaluateBatchParams as EvalEvaluateBatchParams +from .scoring_score_batch_params import ScoringScoreBatchParams as ScoringScoreBatchParams from .telemetry_get_trace_params import TelemetryGetTraceParams as TelemetryGetTraceParams from .telemetry_log_event_params import TelemetryLogEventParams as TelemetryLogEventParams from .inference_completion_params import InferenceCompletionParams as InferenceCompletionParams from .inference_embeddings_params import InferenceEmbeddingsParams as InferenceEmbeddingsParams from .memory_bank_register_params import MemoryBankRegisterParams as MemoryBankRegisterParams from .memory_bank_retrieve_params import MemoryBankRetrieveParams as MemoryBankRetrieveParams -from .reward_scoring_score_params import RewardScoringScoreParams as RewardScoringScoreParams -from .evaluate_summarization_params import EvaluateSummarizationParams as EvaluateSummarizationParams +from .scoring_fn_def_with_provider import ScoringFnDefWithProvider as ScoringFnDefWithProvider from .inference_completion_response import InferenceCompletionResponse as InferenceCompletionResponse from .memory_bank_retrieve_response import MemoryBankRetrieveResponse as MemoryBankRetrieveResponse from .model_def_with_provider_param import ModelDefWithProviderParam as ModelDefWithProviderParam from .shield_def_with_provider_param import ShieldDefWithProviderParam as ShieldDefWithProviderParam +from .rest_api_execution_config_param import RestAPIExecutionConfigParam as RestAPIExecutionConfigParam from .inference_chat_completion_params import InferenceChatCompletionParams as InferenceChatCompletionParams +from .scoring_function_register_params import ScoringFunctionRegisterParams as ScoringFunctionRegisterParams +from .scoring_function_retrieve_params import ScoringFunctionRetrieveParams as ScoringFunctionRetrieveParams from .batch_inference_completion_params import BatchInferenceCompletionParams as BatchInferenceCompletionParams -from .evaluation_text_generation_params import EvaluationTextGenerationParams as EvaluationTextGenerationParams -from .evaluate_question_answering_params import EvaluateQuestionAnsweringParams as EvaluateQuestionAnsweringParams from .inference_chat_completion_response import InferenceChatCompletionResponse as InferenceChatCompletionResponse +from .scoring_fn_def_with_provider_param import ScoringFnDefWithProviderParam as ScoringFnDefWithProviderParam from .synthetic_data_generation_response import SyntheticDataGenerationResponse as SyntheticDataGenerationResponse +from 
.datasetio_get_rows_paginated_params import DatasetioGetRowsPaginatedParams as DatasetioGetRowsPaginatedParams from .batch_inference_chat_completion_params import ( BatchInferenceChatCompletionParams as BatchInferenceChatCompletionParams, ) diff --git a/src/llama_stack_client/types/agent_create_params.py b/src/llama_stack_client/types/agent_create_params.py index af43f7ae..bba57077 100644 --- a/src/llama_stack_client/types/agent_create_params.py +++ b/src/llama_stack_client/types/agent_create_params.py @@ -7,22 +7,18 @@ from .._utils import PropertyInfo from .shared_params.sampling_params import SamplingParams +from .rest_api_execution_config_param import RestAPIExecutionConfigParam __all__ = [ "AgentCreateParams", "AgentConfig", "AgentConfigTool", "AgentConfigToolSearchToolDefinition", - "AgentConfigToolSearchToolDefinitionRemoteExecution", "AgentConfigToolWolframAlphaToolDefinition", - "AgentConfigToolWolframAlphaToolDefinitionRemoteExecution", "AgentConfigToolPhotogenToolDefinition", - "AgentConfigToolPhotogenToolDefinitionRemoteExecution", "AgentConfigToolCodeInterpreterToolDefinition", - "AgentConfigToolCodeInterpreterToolDefinitionRemoteExecution", "AgentConfigToolFunctionCallToolDefinition", "AgentConfigToolFunctionCallToolDefinitionParameters", - "AgentConfigToolFunctionCallToolDefinitionRemoteExecution", "AgentConfigToolMemoryToolDefinition", "AgentConfigToolMemoryToolDefinitionMemoryBankConfig", "AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0", @@ -42,18 +38,6 @@ class AgentCreateParams(TypedDict, total=False): x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] -class AgentConfigToolSearchToolDefinitionRemoteExecution(TypedDict, total=False): - method: Required[Literal["GET", "POST", "PUT", "DELETE"]] - - url: Required[str] - - body: Dict[str, Union[bool, float, str, Iterable[object], object, None]] - - headers: Dict[str, Union[bool, float, str, Iterable[object], object, None]] - - params: Dict[str, Union[bool, float, str, Iterable[object], object, None]] - - class AgentConfigToolSearchToolDefinition(TypedDict, total=False): api_key: Required[str] @@ -65,19 +49,7 @@ class AgentConfigToolSearchToolDefinition(TypedDict, total=False): output_shields: List[str] - remote_execution: AgentConfigToolSearchToolDefinitionRemoteExecution - - -class AgentConfigToolWolframAlphaToolDefinitionRemoteExecution(TypedDict, total=False): - method: Required[Literal["GET", "POST", "PUT", "DELETE"]] - - url: Required[str] - - body: Dict[str, Union[bool, float, str, Iterable[object], object, None]] - - headers: Dict[str, Union[bool, float, str, Iterable[object], object, None]] - - params: Dict[str, Union[bool, float, str, Iterable[object], object, None]] + remote_execution: RestAPIExecutionConfigParam class AgentConfigToolWolframAlphaToolDefinition(TypedDict, total=False): @@ -89,19 +61,7 @@ class AgentConfigToolWolframAlphaToolDefinition(TypedDict, total=False): output_shields: List[str] - remote_execution: AgentConfigToolWolframAlphaToolDefinitionRemoteExecution - - -class AgentConfigToolPhotogenToolDefinitionRemoteExecution(TypedDict, total=False): - method: Required[Literal["GET", "POST", "PUT", "DELETE"]] - - url: Required[str] - - body: Dict[str, Union[bool, float, str, Iterable[object], object, None]] - - headers: Dict[str, Union[bool, float, str, Iterable[object], object, None]] - - params: Dict[str, Union[bool, float, str, Iterable[object], object, None]] + remote_execution: RestAPIExecutionConfigParam class 
AgentConfigToolPhotogenToolDefinition(TypedDict, total=False): @@ -111,19 +71,7 @@ class AgentConfigToolPhotogenToolDefinition(TypedDict, total=False): output_shields: List[str] - remote_execution: AgentConfigToolPhotogenToolDefinitionRemoteExecution - - -class AgentConfigToolCodeInterpreterToolDefinitionRemoteExecution(TypedDict, total=False): - method: Required[Literal["GET", "POST", "PUT", "DELETE"]] - - url: Required[str] - - body: Dict[str, Union[bool, float, str, Iterable[object], object, None]] - - headers: Dict[str, Union[bool, float, str, Iterable[object], object, None]] - - params: Dict[str, Union[bool, float, str, Iterable[object], object, None]] + remote_execution: RestAPIExecutionConfigParam class AgentConfigToolCodeInterpreterToolDefinition(TypedDict, total=False): @@ -135,7 +83,7 @@ class AgentConfigToolCodeInterpreterToolDefinition(TypedDict, total=False): output_shields: List[str] - remote_execution: AgentConfigToolCodeInterpreterToolDefinitionRemoteExecution + remote_execution: RestAPIExecutionConfigParam class AgentConfigToolFunctionCallToolDefinitionParameters(TypedDict, total=False): @@ -148,18 +96,6 @@ class AgentConfigToolFunctionCallToolDefinitionParameters(TypedDict, total=False required: bool -class AgentConfigToolFunctionCallToolDefinitionRemoteExecution(TypedDict, total=False): - method: Required[Literal["GET", "POST", "PUT", "DELETE"]] - - url: Required[str] - - body: Dict[str, Union[bool, float, str, Iterable[object], object, None]] - - headers: Dict[str, Union[bool, float, str, Iterable[object], object, None]] - - params: Dict[str, Union[bool, float, str, Iterable[object], object, None]] - - class AgentConfigToolFunctionCallToolDefinition(TypedDict, total=False): description: Required[str] @@ -173,7 +109,7 @@ class AgentConfigToolFunctionCallToolDefinition(TypedDict, total=False): output_shields: List[str] - remote_execution: AgentConfigToolFunctionCallToolDefinitionRemoteExecution + remote_execution: RestAPIExecutionConfigParam class AgentConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0(TypedDict, total=False): diff --git a/src/llama_stack_client/types/agent_create_response.py b/src/llama_stack_client/types/agent_create_response.py index be253645..65d2275f 100644 --- a/src/llama_stack_client/types/agent_create_response.py +++ b/src/llama_stack_client/types/agent_create_response.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
- from .._models import BaseModel __all__ = ["AgentCreateResponse"] diff --git a/src/llama_stack_client/types/agents/session.py b/src/llama_stack_client/types/agents/session.py index 10d38a5c..23c915fc 100644 --- a/src/llama_stack_client/types/agents/session.py +++ b/src/llama_stack_client/types/agents/session.py @@ -2,65 +2,18 @@ from typing import List, Union, Optional from datetime import datetime -from typing_extensions import Literal, TypeAlias +from typing_extensions import TypeAlias from .turn import Turn from ..._models import BaseModel +from ..shared.graph_memory_bank_def import GraphMemoryBankDef +from ..shared.vector_memory_bank_def import VectorMemoryBankDef +from ..shared.keyword_memory_bank_def import KeywordMemoryBankDef +from ..shared.key_value_memory_bank_def import KeyValueMemoryBankDef -__all__ = [ - "Session", - "MemoryBank", - "MemoryBankVectorMemoryBankDef", - "MemoryBankKeyValueMemoryBankDef", - "MemoryBankKeywordMemoryBankDef", - "MemoryBankGraphMemoryBankDef", -] +__all__ = ["Session", "MemoryBank"] - -class MemoryBankVectorMemoryBankDef(BaseModel): - chunk_size_in_tokens: int - - embedding_model: str - - identifier: str - - provider_id: str - - type: Literal["vector"] - - overlap_size_in_tokens: Optional[int] = None - - -class MemoryBankKeyValueMemoryBankDef(BaseModel): - identifier: str - - provider_id: str - - type: Literal["keyvalue"] - - -class MemoryBankKeywordMemoryBankDef(BaseModel): - identifier: str - - provider_id: str - - type: Literal["keyword"] - - -class MemoryBankGraphMemoryBankDef(BaseModel): - identifier: str - - provider_id: str - - type: Literal["graph"] - - -MemoryBank: TypeAlias = Union[ - MemoryBankVectorMemoryBankDef, - MemoryBankKeyValueMemoryBankDef, - MemoryBankKeywordMemoryBankDef, - MemoryBankGraphMemoryBankDef, -] +MemoryBank: TypeAlias = Union[VectorMemoryBankDef, KeyValueMemoryBankDef, KeywordMemoryBankDef, GraphMemoryBankDef] class Session(BaseModel): diff --git a/src/llama_stack_client/types/agents/session_create_response.py b/src/llama_stack_client/types/agents/session_create_response.py index 13d5a35f..6adcf0b2 100644 --- a/src/llama_stack_client/types/agents/session_create_response.py +++ b/src/llama_stack_client/types/agents/session_create_response.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from ..._models import BaseModel __all__ = ["SessionCreateResponse"] diff --git a/src/llama_stack_client/types/agents/step_retrieve_params.py b/src/llama_stack_client/types/agents/step_retrieve_params.py index cccdc19e..065c067f 100644 --- a/src/llama_stack_client/types/agents/step_retrieve_params.py +++ b/src/llama_stack_client/types/agents/step_retrieve_params.py @@ -12,6 +12,8 @@ class StepRetrieveParams(TypedDict, total=False): agent_id: Required[str] + session_id: Required[str] + step_id: Required[str] turn_id: Required[str] diff --git a/src/llama_stack_client/types/agents/step_retrieve_response.py b/src/llama_stack_client/types/agents/step_retrieve_response.py index 4946d2ff..77376b4f 100644 --- a/src/llama_stack_client/types/agents/step_retrieve_response.py +++ b/src/llama_stack_client/types/agents/step_retrieve_response.py @@ -1,123 +1,17 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
-from typing import Dict, List, Union, Optional -from datetime import datetime -from typing_extensions import Literal, TypeAlias - -from pydantic import Field as FieldInfo +from typing import Union +from typing_extensions import TypeAlias from ..._models import BaseModel -from ..shared.tool_call import ToolCall -from ..shared.image_media import ImageMedia -from ..shared.completion_message import CompletionMessage - -__all__ = [ - "StepRetrieveResponse", - "Step", - "StepInferenceStep", - "StepToolExecutionStep", - "StepToolExecutionStepToolResponse", - "StepToolExecutionStepToolResponseContent", - "StepToolExecutionStepToolResponseContentUnionMember2", - "StepShieldCallStep", - "StepShieldCallStepViolation", - "StepMemoryRetrievalStep", - "StepMemoryRetrievalStepInsertedContext", - "StepMemoryRetrievalStepInsertedContextUnionMember2", -] - - -class StepInferenceStep(BaseModel): - inference_model_response: CompletionMessage = FieldInfo(alias="model_response") - - step_id: str - - step_type: Literal["inference"] - - turn_id: str - - completed_at: Optional[datetime] = None - - started_at: Optional[datetime] = None - - -StepToolExecutionStepToolResponseContentUnionMember2: TypeAlias = Union[str, ImageMedia] - -StepToolExecutionStepToolResponseContent: TypeAlias = Union[ - str, ImageMedia, List[StepToolExecutionStepToolResponseContentUnionMember2] -] - - -class StepToolExecutionStepToolResponse(BaseModel): - call_id: str - - content: StepToolExecutionStepToolResponseContent - - tool_name: Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str] - - -class StepToolExecutionStep(BaseModel): - step_id: str - - step_type: Literal["tool_execution"] - - tool_calls: List[ToolCall] - - tool_responses: List[StepToolExecutionStepToolResponse] - - turn_id: str - - completed_at: Optional[datetime] = None - - started_at: Optional[datetime] = None - - -class StepShieldCallStepViolation(BaseModel): - metadata: Dict[str, Union[bool, float, str, List[object], object, None]] - - violation_level: Literal["info", "warn", "error"] - - user_message: Optional[str] = None - - -class StepShieldCallStep(BaseModel): - step_id: str - - step_type: Literal["shield_call"] - - turn_id: str - - completed_at: Optional[datetime] = None - - started_at: Optional[datetime] = None - - violation: Optional[StepShieldCallStepViolation] = None - - -StepMemoryRetrievalStepInsertedContextUnionMember2: TypeAlias = Union[str, ImageMedia] - -StepMemoryRetrievalStepInsertedContext: TypeAlias = Union[ - str, ImageMedia, List[StepMemoryRetrievalStepInsertedContextUnionMember2] -] - - -class StepMemoryRetrievalStep(BaseModel): - inserted_context: StepMemoryRetrievalStepInsertedContext - - memory_bank_ids: List[str] - - step_id: str - - step_type: Literal["memory_retrieval"] - - turn_id: str - - completed_at: Optional[datetime] = None - - started_at: Optional[datetime] = None +from ..inference_step import InferenceStep +from ..shield_call_step import ShieldCallStep +from ..tool_execution_step import ToolExecutionStep +from ..memory_retrieval_step import MemoryRetrievalStep +__all__ = ["StepRetrieveResponse", "Step"] -Step: TypeAlias = Union[StepInferenceStep, StepToolExecutionStep, StepShieldCallStep, StepMemoryRetrievalStep] +Step: TypeAlias = Union[InferenceStep, ToolExecutionStep, ShieldCallStep, MemoryRetrievalStep] class StepRetrieveResponse(BaseModel): diff --git a/src/llama_stack_client/types/agents/turn.py b/src/llama_stack_client/types/agents/turn.py index f7a25db3..457f3a95 100644 --- 
a/src/llama_stack_client/types/agents/turn.py +++ b/src/llama_stack_client/types/agents/turn.py @@ -1,129 +1,24 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from typing import Dict, List, Union, Optional +from typing import List, Union, Optional from datetime import datetime -from typing_extensions import Literal, TypeAlias - -from pydantic import Field as FieldInfo +from typing_extensions import TypeAlias from ..._models import BaseModel -from ..shared.tool_call import ToolCall +from ..inference_step import InferenceStep +from ..shield_call_step import ShieldCallStep from ..shared.attachment import Attachment -from ..shared.image_media import ImageMedia from ..shared.user_message import UserMessage +from ..tool_execution_step import ToolExecutionStep +from ..memory_retrieval_step import MemoryRetrievalStep from ..shared.completion_message import CompletionMessage from ..shared.tool_response_message import ToolResponseMessage -__all__ = [ - "Turn", - "InputMessage", - "Step", - "StepInferenceStep", - "StepToolExecutionStep", - "StepToolExecutionStepToolResponse", - "StepToolExecutionStepToolResponseContent", - "StepToolExecutionStepToolResponseContentUnionMember2", - "StepShieldCallStep", - "StepShieldCallStepViolation", - "StepMemoryRetrievalStep", - "StepMemoryRetrievalStepInsertedContext", - "StepMemoryRetrievalStepInsertedContextUnionMember2", -] +__all__ = ["Turn", "InputMessage", "Step"] InputMessage: TypeAlias = Union[UserMessage, ToolResponseMessage] - -class StepInferenceStep(BaseModel): - inference_model_response: CompletionMessage = FieldInfo(alias="model_response") - - step_id: str - - step_type: Literal["inference"] - - turn_id: str - - completed_at: Optional[datetime] = None - - started_at: Optional[datetime] = None - - -StepToolExecutionStepToolResponseContentUnionMember2: TypeAlias = Union[str, ImageMedia] - -StepToolExecutionStepToolResponseContent: TypeAlias = Union[ - str, ImageMedia, List[StepToolExecutionStepToolResponseContentUnionMember2] -] - - -class StepToolExecutionStepToolResponse(BaseModel): - call_id: str - - content: StepToolExecutionStepToolResponseContent - - tool_name: Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str] - - -class StepToolExecutionStep(BaseModel): - step_id: str - - step_type: Literal["tool_execution"] - - tool_calls: List[ToolCall] - - tool_responses: List[StepToolExecutionStepToolResponse] - - turn_id: str - - completed_at: Optional[datetime] = None - - started_at: Optional[datetime] = None - - -class StepShieldCallStepViolation(BaseModel): - metadata: Dict[str, Union[bool, float, str, List[object], object, None]] - - violation_level: Literal["info", "warn", "error"] - - user_message: Optional[str] = None - - -class StepShieldCallStep(BaseModel): - step_id: str - - step_type: Literal["shield_call"] - - turn_id: str - - completed_at: Optional[datetime] = None - - started_at: Optional[datetime] = None - - violation: Optional[StepShieldCallStepViolation] = None - - -StepMemoryRetrievalStepInsertedContextUnionMember2: TypeAlias = Union[str, ImageMedia] - -StepMemoryRetrievalStepInsertedContext: TypeAlias = Union[ - str, ImageMedia, List[StepMemoryRetrievalStepInsertedContextUnionMember2] -] - - -class StepMemoryRetrievalStep(BaseModel): - inserted_context: StepMemoryRetrievalStepInsertedContext - - memory_bank_ids: List[str] - - step_id: str - - step_type: Literal["memory_retrieval"] - - turn_id: str - - completed_at: Optional[datetime] = None - - started_at: 
Optional[datetime] = None - - -Step: TypeAlias = Union[StepInferenceStep, StepToolExecutionStep, StepShieldCallStep, StepMemoryRetrievalStep] +Step: TypeAlias = Union[InferenceStep, ToolExecutionStep, ShieldCallStep, MemoryRetrievalStep] class Turn(BaseModel): diff --git a/src/llama_stack_client/types/agents/turn_create_response.py b/src/llama_stack_client/types/agents/turn_create_response.py index 2ab93b4d..596a3e5e 100644 --- a/src/llama_stack_client/types/agents/turn_create_response.py +++ b/src/llama_stack_client/types/agents/turn_create_response.py @@ -1,16 +1,17 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. from typing import Dict, List, Union, Optional -from datetime import datetime from typing_extensions import Literal, TypeAlias from pydantic import Field as FieldInfo from .turn import Turn from ..._models import BaseModel +from ..inference_step import InferenceStep from ..shared.tool_call import ToolCall -from ..shared.image_media import ImageMedia -from ..shared.completion_message import CompletionMessage +from ..shield_call_step import ShieldCallStep +from ..tool_execution_step import ToolExecutionStep +from ..memory_retrieval_step import MemoryRetrievalStep __all__ = [ "TurnCreateResponse", @@ -22,16 +23,6 @@ "EventPayloadAgentTurnResponseStepProgressPayloadToolCallDeltaContent", "EventPayloadAgentTurnResponseStepCompletePayload", "EventPayloadAgentTurnResponseStepCompletePayloadStepDetails", - "EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsInferenceStep", - "EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsToolExecutionStep", - "EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsToolExecutionStepToolResponse", - "EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsToolExecutionStepToolResponseContent", - "EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsToolExecutionStepToolResponseContentUnionMember2", - "EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsShieldCallStep", - "EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsShieldCallStepViolation", - "EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsMemoryRetrievalStep", - "EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsMemoryRetrievalStepInsertedContext", - "EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsMemoryRetrievalStepInsertedContextUnionMember2", "EventPayloadAgentTurnResponseTurnStartPayload", "EventPayloadAgentTurnResponseTurnCompletePayload", ] @@ -70,109 +61,8 @@ class EventPayloadAgentTurnResponseStepProgressPayload(BaseModel): tool_response_text_delta: Optional[str] = None -class EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsInferenceStep(BaseModel): - inference_model_response: CompletionMessage = FieldInfo(alias="model_response") - - step_id: str - - step_type: Literal["inference"] - - turn_id: str - - completed_at: Optional[datetime] = None - - started_at: Optional[datetime] = None - - -EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsToolExecutionStepToolResponseContentUnionMember2: TypeAlias = Union[ - str, ImageMedia -] - -EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsToolExecutionStepToolResponseContent: TypeAlias = Union[ - str, - ImageMedia, - List[EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsToolExecutionStepToolResponseContentUnionMember2], -] - - -class EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsToolExecutionStepToolResponse(BaseModel): - call_id: str - - content: 
EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsToolExecutionStepToolResponseContent - - tool_name: Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str] - - -class EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsToolExecutionStep(BaseModel): - step_id: str - - step_type: Literal["tool_execution"] - - tool_calls: List[ToolCall] - - tool_responses: List[EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsToolExecutionStepToolResponse] - - turn_id: str - - completed_at: Optional[datetime] = None - - started_at: Optional[datetime] = None - - -class EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsShieldCallStepViolation(BaseModel): - metadata: Dict[str, Union[bool, float, str, List[object], object, None]] - - violation_level: Literal["info", "warn", "error"] - - user_message: Optional[str] = None - - -class EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsShieldCallStep(BaseModel): - step_id: str - - step_type: Literal["shield_call"] - - turn_id: str - - completed_at: Optional[datetime] = None - - started_at: Optional[datetime] = None - - violation: Optional[EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsShieldCallStepViolation] = None - - -EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsMemoryRetrievalStepInsertedContextUnionMember2: TypeAlias = ( - Union[str, ImageMedia] -) - -EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsMemoryRetrievalStepInsertedContext: TypeAlias = Union[ - str, - ImageMedia, - List[EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsMemoryRetrievalStepInsertedContextUnionMember2], -] - - -class EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsMemoryRetrievalStep(BaseModel): - inserted_context: EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsMemoryRetrievalStepInsertedContext - - memory_bank_ids: List[str] - - step_id: str - - step_type: Literal["memory_retrieval"] - - turn_id: str - - completed_at: Optional[datetime] = None - - started_at: Optional[datetime] = None - - EventPayloadAgentTurnResponseStepCompletePayloadStepDetails: TypeAlias = Union[ - EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsInferenceStep, - EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsToolExecutionStep, - EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsShieldCallStep, - EventPayloadAgentTurnResponseStepCompletePayloadStepDetailsMemoryRetrievalStep, + InferenceStep, ToolExecutionStep, ShieldCallStep, MemoryRetrievalStep ] diff --git a/src/llama_stack_client/types/agents/turn_retrieve_params.py b/src/llama_stack_client/types/agents/turn_retrieve_params.py index 7f3349a7..b8f16db9 100644 --- a/src/llama_stack_client/types/agents/turn_retrieve_params.py +++ b/src/llama_stack_client/types/agents/turn_retrieve_params.py @@ -12,6 +12,8 @@ class TurnRetrieveParams(TypedDict, total=False): agent_id: Required[str] + session_id: Required[str] + turn_id: Required[str] x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/completion_response.py b/src/llama_stack_client/types/completion_response.py index d1922d8c..c8d97e79 100644 --- a/src/llama_stack_client/types/completion_response.py +++ b/src/llama_stack_client/types/completion_response.py @@ -1,18 +1,17 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
-from typing import Dict, List, Optional +from typing import List, Optional +from typing_extensions import Literal from .._models import BaseModel -from .shared.completion_message import CompletionMessage +from .token_log_probs import TokenLogProbs -__all__ = ["CompletionResponse", "Logprob"] - - -class Logprob(BaseModel): - logprobs_by_token: Dict[str, float] +__all__ = ["CompletionResponse"] class CompletionResponse(BaseModel): - completion_message: CompletionMessage + content: str + + stop_reason: Literal["end_of_turn", "end_of_message", "out_of_tokens"] - logprobs: Optional[List[Logprob]] = None + logprobs: Optional[List[TokenLogProbs]] = None diff --git a/src/llama_stack_client/types/dataset_list_response.py b/src/llama_stack_client/types/dataset_list_response.py new file mode 100644 index 00000000..993bf13c --- /dev/null +++ b/src/llama_stack_client/types/dataset_list_response.py @@ -0,0 +1,38 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union +from typing_extensions import Literal, TypeAlias + +from .._models import BaseModel + +__all__ = ["DatasetListResponse", "DatasetSchema", "DatasetSchemaType"] + + +class DatasetSchemaType(BaseModel): + type: Literal["string"] + + +DatasetSchema: TypeAlias = Union[ + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, +] + + +class DatasetListResponse(BaseModel): + dataset_schema: Dict[str, DatasetSchema] + + identifier: str + + metadata: Dict[str, Union[bool, float, str, List[object], object, None]] + + provider_id: str + + url: str diff --git a/src/llama_stack_client/types/dataset_register_params.py b/src/llama_stack_client/types/dataset_register_params.py new file mode 100644 index 00000000..c7228c18 --- /dev/null +++ b/src/llama_stack_client/types/dataset_register_params.py @@ -0,0 +1,46 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
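The reworked CompletionResponse above carries the generated text directly instead of a nested completion message. A minimal sketch of consuming it follows; the `client.inference.completion(...)` call and its `model`/`content` arguments are assumptions based on the params types in this patch.

# Minimal sketch: inspect the reshaped CompletionResponse.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")

response = client.inference.completion(
    model="Llama3.1-8B-Instruct",        # example model identifier
    content="Write a haiku about typed clients.",
)

print(response.content)      # generated text, now a plain string
print(response.stop_reason)  # "end_of_turn" | "end_of_message" | "out_of_tokens"
if response.logprobs:        # list of TokenLogProbs when logprobs were requested
    print(response.logprobs[0])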
+ +from __future__ import annotations + +from typing import Dict, Union, Iterable +from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict + +from .._utils import PropertyInfo + +__all__ = ["DatasetRegisterParams", "DatasetDef", "DatasetDefDatasetSchema", "DatasetDefDatasetSchemaType"] + + +class DatasetRegisterParams(TypedDict, total=False): + dataset_def: Required[DatasetDef] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + + +class DatasetDefDatasetSchemaType(TypedDict, total=False): + type: Required[Literal["string"]] + + +DatasetDefDatasetSchema: TypeAlias = Union[ + DatasetDefDatasetSchemaType, + DatasetDefDatasetSchemaType, + DatasetDefDatasetSchemaType, + DatasetDefDatasetSchemaType, + DatasetDefDatasetSchemaType, + DatasetDefDatasetSchemaType, + DatasetDefDatasetSchemaType, + DatasetDefDatasetSchemaType, + DatasetDefDatasetSchemaType, + DatasetDefDatasetSchemaType, +] + + +class DatasetDef(TypedDict, total=False): + dataset_schema: Required[Dict[str, DatasetDefDatasetSchema]] + + identifier: Required[str] + + metadata: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] + + provider_id: Required[str] + + url: Required[str] diff --git a/src/llama_stack_client/types/dataset_retrieve_params.py b/src/llama_stack_client/types/dataset_retrieve_params.py index a6cc0ff1..2a577aa8 100644 --- a/src/llama_stack_client/types/dataset_retrieve_params.py +++ b/src/llama_stack_client/types/dataset_retrieve_params.py @@ -10,6 +10,6 @@ class DatasetRetrieveParams(TypedDict, total=False): - dataset_uuid: Required[str] + dataset_identifier: Required[str] x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/dataset_retrieve_response.py b/src/llama_stack_client/types/dataset_retrieve_response.py new file mode 100644 index 00000000..2c7400d1 --- /dev/null +++ b/src/llama_stack_client/types/dataset_retrieve_response.py @@ -0,0 +1,38 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union +from typing_extensions import Literal, TypeAlias + +from .._models import BaseModel + +__all__ = ["DatasetRetrieveResponse", "DatasetSchema", "DatasetSchemaType"] + + +class DatasetSchemaType(BaseModel): + type: Literal["string"] + + +DatasetSchema: TypeAlias = Union[ + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, + DatasetSchemaType, +] + + +class DatasetRetrieveResponse(BaseModel): + dataset_schema: Dict[str, DatasetSchema] + + identifier: str + + metadata: Dict[str, Union[bool, float, str, List[object], object, None]] + + provider_id: str + + url: str diff --git a/src/llama_stack_client/types/datasetio_get_rows_paginated_params.py b/src/llama_stack_client/types/datasetio_get_rows_paginated_params.py new file mode 100644 index 00000000..82bf9690 --- /dev/null +++ b/src/llama_stack_client/types/datasetio_get_rows_paginated_params.py @@ -0,0 +1,21 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
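The new DatasetRegisterParams wraps everything in a single `dataset_def` object. A hedged sketch of registering a dataset with it follows; the `client.datasets.register(...)` accessor, the provider id, and the `"string"` column types are assumptions drawn from the parameter names above.

# Hypothetical registration call built from the DatasetDef fields above.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")

client.datasets.register(
    dataset_def={
        "identifier": "my-eval-set",
        "provider_id": "localfs",                 # example provider
        "url": "file:///data/my_eval_set.jsonl",  # example location
        "dataset_schema": {
            "input_query": {"type": "string"},
            "expected_answer": {"type": "string"},
        },
        "metadata": {"owner": "qa-team"},
    },
)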
+ +from __future__ import annotations + +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo + +__all__ = ["DatasetioGetRowsPaginatedParams"] + + +class DatasetioGetRowsPaginatedParams(TypedDict, total=False): + dataset_id: Required[str] + + rows_in_page: Required[int] + + filter_condition: str + + page_token: str + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/eval/__init__.py b/src/llama_stack_client/types/eval/__init__.py new file mode 100644 index 00000000..95e97450 --- /dev/null +++ b/src/llama_stack_client/types/eval/__init__.py @@ -0,0 +1,9 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from .job_status import JobStatus as JobStatus +from .job_cancel_params import JobCancelParams as JobCancelParams +from .job_result_params import JobResultParams as JobResultParams +from .job_status_params import JobStatusParams as JobStatusParams +from .job_result_response import JobResultResponse as JobResultResponse diff --git a/src/llama_stack_client/types/eval/job_cancel_params.py b/src/llama_stack_client/types/eval/job_cancel_params.py new file mode 100644 index 00000000..337f6803 --- /dev/null +++ b/src/llama_stack_client/types/eval/job_cancel_params.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo + +__all__ = ["JobCancelParams"] + + +class JobCancelParams(TypedDict, total=False): + job_id: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/eval/job_result_params.py b/src/llama_stack_client/types/eval/job_result_params.py new file mode 100644 index 00000000..694e12f8 --- /dev/null +++ b/src/llama_stack_client/types/eval/job_result_params.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo + +__all__ = ["JobResultParams"] + + +class JobResultParams(TypedDict, total=False): + job_id: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/eval/job_result_response.py b/src/llama_stack_client/types/eval/job_result_response.py new file mode 100644 index 00000000..78c7620b --- /dev/null +++ b/src/llama_stack_client/types/eval/job_result_response.py @@ -0,0 +1,19 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
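DatasetioGetRowsPaginatedParams above pairs a required `rows_in_page` with an optional `page_token`, which suggests a simple pagination loop. The `client.datasetio.get_rows_paginated(...)` accessor and the `rows`/`next_page_token` attributes on the result are assumptions (the latter based on the PaginatedRowsResult type added elsewhere in this patch).

# Sketch of paging through a dataset with the params defined above.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")

page = client.datasetio.get_rows_paginated(dataset_id="my-eval-set", rows_in_page=100)
while True:
    for row in page.rows:             # each row is a plain dict of column values
        print(row)
    if not page.next_page_token:      # assumed attribute name
        break
    page = client.datasetio.get_rows_paginated(
        dataset_id="my-eval-set",
        rows_in_page=100,
        page_token=page.next_page_token,
    )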
+ +from typing import Dict, List, Union + +from ..._models import BaseModel + +__all__ = ["JobResultResponse", "Scores"] + + +class Scores(BaseModel): + aggregated_results: Dict[str, Union[bool, float, str, List[object], object, None]] + + score_rows: List[Dict[str, Union[bool, float, str, List[object], object, None]]] + + +class JobResultResponse(BaseModel): + generations: List[Dict[str, Union[bool, float, str, List[object], object, None]]] + + scores: Dict[str, Scores] diff --git a/src/llama_stack_client/types/eval/job_status.py b/src/llama_stack_client/types/eval/job_status.py new file mode 100644 index 00000000..22bc685c --- /dev/null +++ b/src/llama_stack_client/types/eval/job_status.py @@ -0,0 +1,7 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing_extensions import Literal, TypeAlias + +__all__ = ["JobStatus"] + +JobStatus: TypeAlias = Literal["completed", "in_progress"] diff --git a/src/llama_stack_client/types/eval/job_status_params.py b/src/llama_stack_client/types/eval/job_status_params.py new file mode 100644 index 00000000..01070e2a --- /dev/null +++ b/src/llama_stack_client/types/eval/job_status_params.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo + +__all__ = ["JobStatusParams"] + + +class JobStatusParams(TypedDict, total=False): + job_id: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/eval_evaluate_batch_params.py b/src/llama_stack_client/types/eval_evaluate_batch_params.py new file mode 100644 index 00000000..ded8d391 --- /dev/null +++ b/src/llama_stack_client/types/eval_evaluate_batch_params.py @@ -0,0 +1,259 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
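Since JobStatus above is just "completed" or "in_progress", a polling helper is the natural consumer of the job status and job result params. The `client.eval.job.status(...)` and `client.eval.job.result(...)` accessors are assumptions based on the eval/job resource added in this patch.

# Hedged sketch: poll an eval job until it completes, then fetch results.
import time

from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")


def wait_for_eval_job(job_id: str, poll_seconds: float = 5.0):
    while True:
        status = client.eval.job.status(job_id=job_id)  # JobStatus, per the alias above
        if status == "completed":
            break
        time.sleep(poll_seconds)
    result = client.eval.job.result(job_id=job_id)      # JobResultResponse
    return result.generations, result.scores            # scores: Dict[str, Scores]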
+ +from __future__ import annotations + +from typing import Dict, List, Union, Iterable +from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict + +from .._utils import PropertyInfo +from .shared_params.system_message import SystemMessage +from .shared_params.sampling_params import SamplingParams +from .rest_api_execution_config_param import RestAPIExecutionConfigParam + +__all__ = [ + "EvalEvaluateBatchParams", + "Candidate", + "CandidateModelCandidate", + "CandidateAgentCandidate", + "CandidateAgentCandidateConfig", + "CandidateAgentCandidateConfigTool", + "CandidateAgentCandidateConfigToolSearchToolDefinition", + "CandidateAgentCandidateConfigToolWolframAlphaToolDefinition", + "CandidateAgentCandidateConfigToolPhotogenToolDefinition", + "CandidateAgentCandidateConfigToolCodeInterpreterToolDefinition", + "CandidateAgentCandidateConfigToolFunctionCallToolDefinition", + "CandidateAgentCandidateConfigToolFunctionCallToolDefinitionParameters", + "CandidateAgentCandidateConfigToolMemoryToolDefinition", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfig", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfig", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigType", +] + + +class EvalEvaluateBatchParams(TypedDict, total=False): + candidate: Required[Candidate] + + dataset_id: Required[str] + + scoring_functions: Required[List[str]] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + + +class CandidateModelCandidate(TypedDict, total=False): + model: Required[str] + + sampling_params: Required[SamplingParams] + + type: Required[Literal["model"]] + + system_message: SystemMessage + + +class CandidateAgentCandidateConfigToolSearchToolDefinition(TypedDict, total=False): + api_key: Required[str] + + engine: Required[Literal["bing", "brave"]] + + type: Required[Literal["brave_search"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolWolframAlphaToolDefinition(TypedDict, total=False): + api_key: Required[str] + + type: Required[Literal["wolfram_alpha"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolPhotogenToolDefinition(TypedDict, total=False): + type: Required[Literal["photogen"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolCodeInterpreterToolDefinition(TypedDict, total=False): + enable_inline_code_execution: Required[bool] + + type: Required[Literal["code_interpreter"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolFunctionCallToolDefinitionParameters(TypedDict, total=False): + param_type: Required[str] + + default: 
Union[bool, float, str, Iterable[object], object, None] + + description: str + + required: bool + + +class CandidateAgentCandidateConfigToolFunctionCallToolDefinition(TypedDict, total=False): + description: Required[str] + + function_name: Required[str] + + parameters: Required[Dict[str, CandidateAgentCandidateConfigToolFunctionCallToolDefinitionParameters]] + + type: Required[Literal["function_call"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0(TypedDict, total=False): + bank_id: Required[str] + + type: Required[Literal["vector"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1(TypedDict, total=False): + bank_id: Required[str] + + keys: Required[List[str]] + + type: Required[Literal["keyvalue"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2(TypedDict, total=False): + bank_id: Required[str] + + type: Required[Literal["keyword"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3(TypedDict, total=False): + bank_id: Required[str] + + entities: Required[List[str]] + + type: Required[Literal["graph"]] + + +CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfig: TypeAlias = Union[ + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0, + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1, + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2, + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3, +] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0(TypedDict, total=False): + sep: Required[str] + + type: Required[Literal["default"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1(TypedDict, total=False): + model: Required[str] + + template: Required[str] + + type: Required[Literal["llm"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigType(TypedDict, total=False): + type: Required[Literal["custom"]] + + +CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfig: TypeAlias = Union[ + CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0, + CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1, + CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigType, +] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinition(TypedDict, total=False): + max_chunks: Required[int] + + max_tokens_in_context: Required[int] + + memory_bank_configs: Required[Iterable[CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfig]] + + query_generator_config: Required[CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfig] + + type: Required[Literal["memory"]] + + input_shields: List[str] + + output_shields: List[str] + + +CandidateAgentCandidateConfigTool: TypeAlias = Union[ + CandidateAgentCandidateConfigToolSearchToolDefinition, + CandidateAgentCandidateConfigToolWolframAlphaToolDefinition, + CandidateAgentCandidateConfigToolPhotogenToolDefinition, + CandidateAgentCandidateConfigToolCodeInterpreterToolDefinition, + CandidateAgentCandidateConfigToolFunctionCallToolDefinition, + CandidateAgentCandidateConfigToolMemoryToolDefinition, +] 
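With the candidate and tool types in place, kicking off a batch evaluation reduces to passing a candidate, a registered dataset id, and scoring function identifiers. The sketch below rests on assumptions: the `client.eval.evaluate_batch(...)` accessor, the returned `Job` carrying a `job_id`, the scoring function name, and the exact SamplingParams fields (only `strategy` is used here) are inferred from the types in this patch rather than shown verbatim.

# Hypothetical batch-evaluation kickoff using a ModelCandidate.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")

model_candidate = {
    "type": "model",
    "model": "Llama3.1-8B-Instruct",            # example model identifier
    "sampling_params": {"strategy": "greedy"},  # assumed SamplingParams shape
}

job = client.eval.evaluate_batch(
    candidate=model_candidate,
    dataset_id="my-eval-set",
    scoring_functions=["meta-reference::equality"],  # example scoring fn id
)
print(job.job_id)  # poll this id with the job status/result sketch above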
+ + +class CandidateAgentCandidateConfig(TypedDict, total=False): + enable_session_persistence: Required[bool] + + instructions: Required[str] + + max_infer_iters: Required[int] + + model: Required[str] + + input_shields: List[str] + + output_shields: List[str] + + sampling_params: SamplingParams + + tool_choice: Literal["auto", "required"] + + tool_prompt_format: Literal["json", "function_tag", "python_list"] + """ + `json` -- Refers to the json format for calling tools. The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + """ + + tools: Iterable[CandidateAgentCandidateConfigTool] + + +class CandidateAgentCandidate(TypedDict, total=False): + config: Required[CandidateAgentCandidateConfig] + + type: Required[Literal["agent"]] + + +Candidate: TypeAlias = Union[CandidateModelCandidate, CandidateAgentCandidate] diff --git a/src/llama_stack_client/types/eval_evaluate_params.py b/src/llama_stack_client/types/eval_evaluate_params.py new file mode 100644 index 00000000..5f786316 --- /dev/null +++ b/src/llama_stack_client/types/eval_evaluate_params.py @@ -0,0 +1,259 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Dict, List, Union, Iterable +from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict + +from .._utils import PropertyInfo +from .shared_params.system_message import SystemMessage +from .shared_params.sampling_params import SamplingParams +from .rest_api_execution_config_param import RestAPIExecutionConfigParam + +__all__ = [ + "EvalEvaluateParams", + "Candidate", + "CandidateModelCandidate", + "CandidateAgentCandidate", + "CandidateAgentCandidateConfig", + "CandidateAgentCandidateConfigTool", + "CandidateAgentCandidateConfigToolSearchToolDefinition", + "CandidateAgentCandidateConfigToolWolframAlphaToolDefinition", + "CandidateAgentCandidateConfigToolPhotogenToolDefinition", + "CandidateAgentCandidateConfigToolCodeInterpreterToolDefinition", + "CandidateAgentCandidateConfigToolFunctionCallToolDefinition", + "CandidateAgentCandidateConfigToolFunctionCallToolDefinitionParameters", + "CandidateAgentCandidateConfigToolMemoryToolDefinition", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfig", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfig", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigType", +] + + +class EvalEvaluateParams(TypedDict, total=False): + candidate: Required[Candidate] + + input_rows: Required[Iterable[Dict[str, Union[bool, float, str, Iterable[object], object, None]]]] + + scoring_functions: 
Required[List[str]] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + + +class CandidateModelCandidate(TypedDict, total=False): + model: Required[str] + + sampling_params: Required[SamplingParams] + + type: Required[Literal["model"]] + + system_message: SystemMessage + + +class CandidateAgentCandidateConfigToolSearchToolDefinition(TypedDict, total=False): + api_key: Required[str] + + engine: Required[Literal["bing", "brave"]] + + type: Required[Literal["brave_search"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolWolframAlphaToolDefinition(TypedDict, total=False): + api_key: Required[str] + + type: Required[Literal["wolfram_alpha"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolPhotogenToolDefinition(TypedDict, total=False): + type: Required[Literal["photogen"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolCodeInterpreterToolDefinition(TypedDict, total=False): + enable_inline_code_execution: Required[bool] + + type: Required[Literal["code_interpreter"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolFunctionCallToolDefinitionParameters(TypedDict, total=False): + param_type: Required[str] + + default: Union[bool, float, str, Iterable[object], object, None] + + description: str + + required: bool + + +class CandidateAgentCandidateConfigToolFunctionCallToolDefinition(TypedDict, total=False): + description: Required[str] + + function_name: Required[str] + + parameters: Required[Dict[str, CandidateAgentCandidateConfigToolFunctionCallToolDefinitionParameters]] + + type: Required[Literal["function_call"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0(TypedDict, total=False): + bank_id: Required[str] + + type: Required[Literal["vector"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1(TypedDict, total=False): + bank_id: Required[str] + + keys: Required[List[str]] + + type: Required[Literal["keyvalue"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2(TypedDict, total=False): + bank_id: Required[str] + + type: Required[Literal["keyword"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3(TypedDict, total=False): + bank_id: Required[str] + + entities: Required[List[str]] + + type: Required[Literal["graph"]] + + +CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfig: TypeAlias = Union[ + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0, + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1, + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2, + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3, +] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0(TypedDict, total=False): + sep: Required[str] + + type: Required[Literal["default"]] + + +class 
CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1(TypedDict, total=False): + model: Required[str] + + template: Required[str] + + type: Required[Literal["llm"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigType(TypedDict, total=False): + type: Required[Literal["custom"]] + + +CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfig: TypeAlias = Union[ + CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0, + CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1, + CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigType, +] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinition(TypedDict, total=False): + max_chunks: Required[int] + + max_tokens_in_context: Required[int] + + memory_bank_configs: Required[Iterable[CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfig]] + + query_generator_config: Required[CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfig] + + type: Required[Literal["memory"]] + + input_shields: List[str] + + output_shields: List[str] + + +CandidateAgentCandidateConfigTool: TypeAlias = Union[ + CandidateAgentCandidateConfigToolSearchToolDefinition, + CandidateAgentCandidateConfigToolWolframAlphaToolDefinition, + CandidateAgentCandidateConfigToolPhotogenToolDefinition, + CandidateAgentCandidateConfigToolCodeInterpreterToolDefinition, + CandidateAgentCandidateConfigToolFunctionCallToolDefinition, + CandidateAgentCandidateConfigToolMemoryToolDefinition, +] + + +class CandidateAgentCandidateConfig(TypedDict, total=False): + enable_session_persistence: Required[bool] + + instructions: Required[str] + + max_infer_iters: Required[int] + + model: Required[str] + + input_shields: List[str] + + output_shields: List[str] + + sampling_params: SamplingParams + + tool_choice: Literal["auto", "required"] + + tool_prompt_format: Literal["json", "function_tag", "python_list"] + """ + `json` -- Refers to the json format for calling tools. The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + """ + + tools: Iterable[CandidateAgentCandidateConfigTool] + + +class CandidateAgentCandidate(TypedDict, total=False): + config: Required[CandidateAgentCandidateConfig] + + type: Required[Literal["agent"]] + + +Candidate: TypeAlias = Union[CandidateModelCandidate, CandidateAgentCandidate] diff --git a/src/llama_stack_client/types/eval_evaluate_response.py b/src/llama_stack_client/types/eval_evaluate_response.py new file mode 100644 index 00000000..d5734ed2 --- /dev/null +++ b/src/llama_stack_client/types/eval_evaluate_response.py @@ -0,0 +1,19 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from typing import Dict, List, Union + +from .._models import BaseModel + +__all__ = ["EvalEvaluateResponse", "Scores"] + + +class Scores(BaseModel): + aggregated_results: Dict[str, Union[bool, float, str, List[object], object, None]] + + score_rows: List[Dict[str, Union[bool, float, str, List[object], object, None]]] + + +class EvalEvaluateResponse(BaseModel): + generations: List[Dict[str, Union[bool, float, str, List[object], object, None]]] + + scores: Dict[str, Scores] diff --git a/src/llama_stack_client/types/evaluate/__init__.py b/src/llama_stack_client/types/evaluate/__init__.py index 17233c60..c7dfea29 100644 --- a/src/llama_stack_client/types/evaluate/__init__.py +++ b/src/llama_stack_client/types/evaluate/__init__.py @@ -2,10 +2,8 @@ from __future__ import annotations -from .job_logs_params import JobLogsParams as JobLogsParams +from .job_status import JobStatus as JobStatus +from .evaluate_response import EvaluateResponse as EvaluateResponse from .job_cancel_params import JobCancelParams as JobCancelParams -from .job_logs_response import JobLogsResponse as JobLogsResponse +from .job_result_params import JobResultParams as JobResultParams from .job_status_params import JobStatusParams as JobStatusParams -from .job_status_response import JobStatusResponse as JobStatusResponse -from .job_artifacts_params import JobArtifactsParams as JobArtifactsParams -from .job_artifacts_response import JobArtifactsResponse as JobArtifactsResponse diff --git a/src/llama_stack_client/types/evaluate/evaluate_response.py b/src/llama_stack_client/types/evaluate/evaluate_response.py new file mode 100644 index 00000000..11aa6820 --- /dev/null +++ b/src/llama_stack_client/types/evaluate/evaluate_response.py @@ -0,0 +1,19 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union + +from ..._models import BaseModel + +__all__ = ["EvaluateResponse", "Scores"] + + +class Scores(BaseModel): + aggregated_results: Dict[str, Union[bool, float, str, List[object], object, None]] + + score_rows: List[Dict[str, Union[bool, float, str, List[object], object, None]]] + + +class EvaluateResponse(BaseModel): + generations: List[Dict[str, Union[bool, float, str, List[object], object, None]]] + + scores: Dict[str, Scores] diff --git a/src/llama_stack_client/types/evaluate/job_artifacts_response.py b/src/llama_stack_client/types/evaluate/job_artifacts_response.py index 288f78f0..e39404cf 100644 --- a/src/llama_stack_client/types/evaluate/job_artifacts_response.py +++ b/src/llama_stack_client/types/evaluate/job_artifacts_response.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
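The EvaluateResponse/EvalEvaluateResponse models above make the row-wise variant easy to consume: `generations` lines up with the submitted `input_rows`, and each entry in `scores` carries `score_rows` plus `aggregated_results`. A hedged sketch follows; the `client.eval.evaluate(...)` accessor and the scoring function name are assumed, and the candidate dict reuses the shape sketched earlier.

# Sketch: evaluate a handful of rows inline and read the scores back.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")

response = client.eval.evaluate(
    input_rows=[
        {"input_query": "What is 2 + 2?", "expected_answer": "4"},
    ],
    candidate={
        "type": "model",
        "model": "Llama3.1-8B-Instruct",
        "sampling_params": {"strategy": "greedy"},  # assumed SamplingParams shape
    },
    scoring_functions=["meta-reference::equality"],  # example scoring fn id
)

for name, scores in response.scores.items():
    print(name, scores.aggregated_results)  # per-scorer aggregate metrics
    print(scores.score_rows[0])             # per-row scores, aligned with input_rows
print(response.generations[0])              # model output for the first row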
- from ..._models import BaseModel __all__ = ["JobArtifactsResponse"] diff --git a/src/llama_stack_client/types/evaluate/job_cancel_params.py b/src/llama_stack_client/types/evaluate/job_cancel_params.py index 9321c3ba..337f6803 100644 --- a/src/llama_stack_client/types/evaluate/job_cancel_params.py +++ b/src/llama_stack_client/types/evaluate/job_cancel_params.py @@ -10,6 +10,6 @@ class JobCancelParams(TypedDict, total=False): - job_uuid: Required[str] + job_id: Required[str] x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/evaluate/job_logs_response.py b/src/llama_stack_client/types/evaluate/job_logs_response.py index 276f582f..ec036719 100644 --- a/src/llama_stack_client/types/evaluate/job_logs_response.py +++ b/src/llama_stack_client/types/evaluate/job_logs_response.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from ..._models import BaseModel __all__ = ["JobLogsResponse"] diff --git a/src/llama_stack_client/types/evaluate/job_result_params.py b/src/llama_stack_client/types/evaluate/job_result_params.py new file mode 100644 index 00000000..694e12f8 --- /dev/null +++ b/src/llama_stack_client/types/evaluate/job_result_params.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, Annotated, TypedDict + +from ..._utils import PropertyInfo + +__all__ = ["JobResultParams"] + + +class JobResultParams(TypedDict, total=False): + job_id: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/evaluate/job_status.py b/src/llama_stack_client/types/evaluate/job_status.py new file mode 100644 index 00000000..22bc685c --- /dev/null +++ b/src/llama_stack_client/types/evaluate/job_status.py @@ -0,0 +1,7 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing_extensions import Literal, TypeAlias + +__all__ = ["JobStatus"] + +JobStatus: TypeAlias = Literal["completed", "in_progress"] diff --git a/src/llama_stack_client/types/evaluate/job_status_params.py b/src/llama_stack_client/types/evaluate/job_status_params.py index f1f8b204..01070e2a 100644 --- a/src/llama_stack_client/types/evaluate/job_status_params.py +++ b/src/llama_stack_client/types/evaluate/job_status_params.py @@ -10,6 +10,6 @@ class JobStatusParams(TypedDict, total=False): - job_uuid: Required[str] + job_id: Required[str] x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/evaluate/job_status_response.py b/src/llama_stack_client/types/evaluate/job_status_response.py index 241dde63..9405d17e 100644 --- a/src/llama_stack_client/types/evaluate/job_status_response.py +++ b/src/llama_stack_client/types/evaluate/job_status_response.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from ..._models import BaseModel __all__ = ["JobStatusResponse"] diff --git a/src/llama_stack_client/types/evaluate_evaluate_batch_params.py b/src/llama_stack_client/types/evaluate_evaluate_batch_params.py new file mode 100644 index 00000000..f729a91a --- /dev/null +++ b/src/llama_stack_client/types/evaluate_evaluate_batch_params.py @@ -0,0 +1,259 @@ +# File generated from our OpenAPI spec by Stainless. 
See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Dict, List, Union, Iterable +from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict + +from .._utils import PropertyInfo +from .shared_params.system_message import SystemMessage +from .shared_params.sampling_params import SamplingParams +from .rest_api_execution_config_param import RestAPIExecutionConfigParam + +__all__ = [ + "EvaluateEvaluateBatchParams", + "Candidate", + "CandidateModelCandidate", + "CandidateAgentCandidate", + "CandidateAgentCandidateConfig", + "CandidateAgentCandidateConfigTool", + "CandidateAgentCandidateConfigToolSearchToolDefinition", + "CandidateAgentCandidateConfigToolWolframAlphaToolDefinition", + "CandidateAgentCandidateConfigToolPhotogenToolDefinition", + "CandidateAgentCandidateConfigToolCodeInterpreterToolDefinition", + "CandidateAgentCandidateConfigToolFunctionCallToolDefinition", + "CandidateAgentCandidateConfigToolFunctionCallToolDefinitionParameters", + "CandidateAgentCandidateConfigToolMemoryToolDefinition", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfig", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfig", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigType", +] + + +class EvaluateEvaluateBatchParams(TypedDict, total=False): + candidate: Required[Candidate] + + dataset_id: Required[str] + + scoring_functions: Required[List[str]] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + + +class CandidateModelCandidate(TypedDict, total=False): + model: Required[str] + + sampling_params: Required[SamplingParams] + + type: Required[Literal["model"]] + + system_message: SystemMessage + + +class CandidateAgentCandidateConfigToolSearchToolDefinition(TypedDict, total=False): + api_key: Required[str] + + engine: Required[Literal["bing", "brave"]] + + type: Required[Literal["brave_search"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolWolframAlphaToolDefinition(TypedDict, total=False): + api_key: Required[str] + + type: Required[Literal["wolfram_alpha"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolPhotogenToolDefinition(TypedDict, total=False): + type: Required[Literal["photogen"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolCodeInterpreterToolDefinition(TypedDict, total=False): + enable_inline_code_execution: Required[bool] + + type: Required[Literal["code_interpreter"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolFunctionCallToolDefinitionParameters(TypedDict, total=False): + 
param_type: Required[str] + + default: Union[bool, float, str, Iterable[object], object, None] + + description: str + + required: bool + + +class CandidateAgentCandidateConfigToolFunctionCallToolDefinition(TypedDict, total=False): + description: Required[str] + + function_name: Required[str] + + parameters: Required[Dict[str, CandidateAgentCandidateConfigToolFunctionCallToolDefinitionParameters]] + + type: Required[Literal["function_call"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0(TypedDict, total=False): + bank_id: Required[str] + + type: Required[Literal["vector"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1(TypedDict, total=False): + bank_id: Required[str] + + keys: Required[List[str]] + + type: Required[Literal["keyvalue"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2(TypedDict, total=False): + bank_id: Required[str] + + type: Required[Literal["keyword"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3(TypedDict, total=False): + bank_id: Required[str] + + entities: Required[List[str]] + + type: Required[Literal["graph"]] + + +CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfig: TypeAlias = Union[ + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0, + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1, + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2, + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3, +] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0(TypedDict, total=False): + sep: Required[str] + + type: Required[Literal["default"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1(TypedDict, total=False): + model: Required[str] + + template: Required[str] + + type: Required[Literal["llm"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigType(TypedDict, total=False): + type: Required[Literal["custom"]] + + +CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfig: TypeAlias = Union[ + CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0, + CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1, + CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigType, +] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinition(TypedDict, total=False): + max_chunks: Required[int] + + max_tokens_in_context: Required[int] + + memory_bank_configs: Required[Iterable[CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfig]] + + query_generator_config: Required[CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfig] + + type: Required[Literal["memory"]] + + input_shields: List[str] + + output_shields: List[str] + + +CandidateAgentCandidateConfigTool: TypeAlias = Union[ + CandidateAgentCandidateConfigToolSearchToolDefinition, + CandidateAgentCandidateConfigToolWolframAlphaToolDefinition, + CandidateAgentCandidateConfigToolPhotogenToolDefinition, + CandidateAgentCandidateConfigToolCodeInterpreterToolDefinition, + CandidateAgentCandidateConfigToolFunctionCallToolDefinition, + 
CandidateAgentCandidateConfigToolMemoryToolDefinition, +] + + +class CandidateAgentCandidateConfig(TypedDict, total=False): + enable_session_persistence: Required[bool] + + instructions: Required[str] + + max_infer_iters: Required[int] + + model: Required[str] + + input_shields: List[str] + + output_shields: List[str] + + sampling_params: SamplingParams + + tool_choice: Literal["auto", "required"] + + tool_prompt_format: Literal["json", "function_tag", "python_list"] + """ + `json` -- Refers to the json format for calling tools. The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + """ + + tools: Iterable[CandidateAgentCandidateConfigTool] + + +class CandidateAgentCandidate(TypedDict, total=False): + config: Required[CandidateAgentCandidateConfig] + + type: Required[Literal["agent"]] + + +Candidate: TypeAlias = Union[CandidateModelCandidate, CandidateAgentCandidate] diff --git a/src/llama_stack_client/types/evaluate_evaluate_params.py b/src/llama_stack_client/types/evaluate_evaluate_params.py new file mode 100644 index 00000000..e2daff58 --- /dev/null +++ b/src/llama_stack_client/types/evaluate_evaluate_params.py @@ -0,0 +1,259 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import Dict, List, Union, Iterable +from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict + +from .._utils import PropertyInfo +from .shared_params.system_message import SystemMessage +from .shared_params.sampling_params import SamplingParams +from .rest_api_execution_config_param import RestAPIExecutionConfigParam + +__all__ = [ + "EvaluateEvaluateParams", + "Candidate", + "CandidateModelCandidate", + "CandidateAgentCandidate", + "CandidateAgentCandidateConfig", + "CandidateAgentCandidateConfigTool", + "CandidateAgentCandidateConfigToolSearchToolDefinition", + "CandidateAgentCandidateConfigToolWolframAlphaToolDefinition", + "CandidateAgentCandidateConfigToolPhotogenToolDefinition", + "CandidateAgentCandidateConfigToolCodeInterpreterToolDefinition", + "CandidateAgentCandidateConfigToolFunctionCallToolDefinition", + "CandidateAgentCandidateConfigToolFunctionCallToolDefinitionParameters", + "CandidateAgentCandidateConfigToolMemoryToolDefinition", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfig", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfig", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1", + "CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigType", +] + + +class EvaluateEvaluateParams(TypedDict, total=False): + candidate: Required[Candidate] + + input_rows: Required[Iterable[Dict[str, Union[bool, 
float, str, Iterable[object], object, None]]]] + + scoring_functions: Required[List[str]] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] + + +class CandidateModelCandidate(TypedDict, total=False): + model: Required[str] + + sampling_params: Required[SamplingParams] + + type: Required[Literal["model"]] + + system_message: SystemMessage + + +class CandidateAgentCandidateConfigToolSearchToolDefinition(TypedDict, total=False): + api_key: Required[str] + + engine: Required[Literal["bing", "brave"]] + + type: Required[Literal["brave_search"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolWolframAlphaToolDefinition(TypedDict, total=False): + api_key: Required[str] + + type: Required[Literal["wolfram_alpha"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolPhotogenToolDefinition(TypedDict, total=False): + type: Required[Literal["photogen"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolCodeInterpreterToolDefinition(TypedDict, total=False): + enable_inline_code_execution: Required[bool] + + type: Required[Literal["code_interpreter"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolFunctionCallToolDefinitionParameters(TypedDict, total=False): + param_type: Required[str] + + default: Union[bool, float, str, Iterable[object], object, None] + + description: str + + required: bool + + +class CandidateAgentCandidateConfigToolFunctionCallToolDefinition(TypedDict, total=False): + description: Required[str] + + function_name: Required[str] + + parameters: Required[Dict[str, CandidateAgentCandidateConfigToolFunctionCallToolDefinitionParameters]] + + type: Required[Literal["function_call"]] + + input_shields: List[str] + + output_shields: List[str] + + remote_execution: RestAPIExecutionConfigParam + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0(TypedDict, total=False): + bank_id: Required[str] + + type: Required[Literal["vector"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1(TypedDict, total=False): + bank_id: Required[str] + + keys: Required[List[str]] + + type: Required[Literal["keyvalue"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2(TypedDict, total=False): + bank_id: Required[str] + + type: Required[Literal["keyword"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3(TypedDict, total=False): + bank_id: Required[str] + + entities: Required[List[str]] + + type: Required[Literal["graph"]] + + +CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfig: TypeAlias = Union[ + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember0, + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember1, + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember2, + CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfigUnionMember3, +] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0(TypedDict, total=False): + sep: Required[str] + 
+ type: Required[Literal["default"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1(TypedDict, total=False): + model: Required[str] + + template: Required[str] + + type: Required[Literal["llm"]] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigType(TypedDict, total=False): + type: Required[Literal["custom"]] + + +CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfig: TypeAlias = Union[ + CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember0, + CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigUnionMember1, + CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfigType, +] + + +class CandidateAgentCandidateConfigToolMemoryToolDefinition(TypedDict, total=False): + max_chunks: Required[int] + + max_tokens_in_context: Required[int] + + memory_bank_configs: Required[Iterable[CandidateAgentCandidateConfigToolMemoryToolDefinitionMemoryBankConfig]] + + query_generator_config: Required[CandidateAgentCandidateConfigToolMemoryToolDefinitionQueryGeneratorConfig] + + type: Required[Literal["memory"]] + + input_shields: List[str] + + output_shields: List[str] + + +CandidateAgentCandidateConfigTool: TypeAlias = Union[ + CandidateAgentCandidateConfigToolSearchToolDefinition, + CandidateAgentCandidateConfigToolWolframAlphaToolDefinition, + CandidateAgentCandidateConfigToolPhotogenToolDefinition, + CandidateAgentCandidateConfigToolCodeInterpreterToolDefinition, + CandidateAgentCandidateConfigToolFunctionCallToolDefinition, + CandidateAgentCandidateConfigToolMemoryToolDefinition, +] + + +class CandidateAgentCandidateConfig(TypedDict, total=False): + enable_session_persistence: Required[bool] + + instructions: Required[str] + + max_infer_iters: Required[int] + + model: Required[str] + + input_shields: List[str] + + output_shields: List[str] + + sampling_params: SamplingParams + + tool_choice: Literal["auto", "required"] + + tool_prompt_format: Literal["json", "function_tag", "python_list"] + """ + `json` -- Refers to the json format for calling tools. The json format takes the + form like { "type": "function", "function" : { "name": "function_name", + "description": "function_description", "parameters": {...} } } + + `function_tag` -- This is an example of how you could define your own user + defined format for making tool calls. The function_tag format looks like this, + (parameters) + + The detailed prompts for each of these formats are added to llama cli + """ + + tools: Iterable[CandidateAgentCandidateConfigTool] + + +class CandidateAgentCandidate(TypedDict, total=False): + config: Required[CandidateAgentCandidateConfig] + + type: Required[Literal["agent"]] + + +Candidate: TypeAlias = Union[CandidateModelCandidate, CandidateAgentCandidate] diff --git a/src/llama_stack_client/types/evaluation_job.py b/src/llama_stack_client/types/evaluation_job.py index c8f291b9..5c0b51f7 100644 --- a/src/llama_stack_client/types/evaluation_job.py +++ b/src/llama_stack_client/types/evaluation_job.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
- from .._models import BaseModel __all__ = ["EvaluationJob"] diff --git a/src/llama_stack_client/types/health_info.py b/src/llama_stack_client/types/health_info.py index 5a8a6c80..f410c8d2 100644 --- a/src/llama_stack_client/types/health_info.py +++ b/src/llama_stack_client/types/health_info.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from .._models import BaseModel __all__ = ["HealthInfo"] diff --git a/src/llama_stack_client/types/inference_chat_completion_params.py b/src/llama_stack_client/types/inference_chat_completion_params.py index 27a099e5..d7ae066b 100644 --- a/src/llama_stack_client/types/inference_chat_completion_params.py +++ b/src/llama_stack_client/types/inference_chat_completion_params.py @@ -12,7 +12,16 @@ from .shared_params.completion_message import CompletionMessage from .shared_params.tool_response_message import ToolResponseMessage -__all__ = ["InferenceChatCompletionParams", "Message", "Logprobs", "Tool", "ToolParameters"] +__all__ = [ + "InferenceChatCompletionParams", + "Message", + "Logprobs", + "ResponseFormat", + "ResponseFormatUnionMember0", + "ResponseFormatUnionMember1", + "Tool", + "ToolParameters", +] class InferenceChatCompletionParams(TypedDict, total=False): @@ -22,6 +31,8 @@ class InferenceChatCompletionParams(TypedDict, total=False): logprobs: Logprobs + response_format: ResponseFormat + sampling_params: SamplingParams stream: bool @@ -53,6 +64,21 @@ class Logprobs(TypedDict, total=False): top_k: int +class ResponseFormatUnionMember0(TypedDict, total=False): + json_schema: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] + + type: Required[Literal["json_schema"]] + + +class ResponseFormatUnionMember1(TypedDict, total=False): + bnf: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] + + type: Required[Literal["grammar"]] + + +ResponseFormat: TypeAlias = Union[ResponseFormatUnionMember0, ResponseFormatUnionMember1] + + class ToolParameters(TypedDict, total=False): param_type: Required[str] diff --git a/src/llama_stack_client/types/inference_chat_completion_response.py b/src/llama_stack_client/types/inference_chat_completion_response.py index fbb8c0f5..ce9b330e 100644 --- a/src/llama_stack_client/types/inference_chat_completion_response.py +++ b/src/llama_stack_client/types/inference_chat_completion_response.py @@ -1,33 +1,28 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
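The new `response_format` parameter above accepts either a `json_schema` or a `grammar` ("bnf") variant. A minimal sketch of constraining a chat completion to a JSON schema follows, assuming the `client.inference.chat_completion(...)` accessor, the `model` argument, and the user-message dict shape carried over from the existing params.

# Hedged sketch: request schema-constrained JSON output.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")

response = client.inference.chat_completion(
    model="Llama3.1-8B-Instruct",   # example model identifier
    messages=[
        {"role": "user", "content": "Give me a city and its country as JSON."},
    ],
    response_format={
        "type": "json_schema",
        "json_schema": {
            "type": "object",
            "properties": {
                "city": {"type": "string"},
                "country": {"type": "string"},
            },
            "required": ["city", "country"],
        },
    },
)
print(response.completion_message.content)  # non-streaming ChatCompletionResponse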
-from typing import Dict, List, Union, Optional +from typing import List, Union, Optional from typing_extensions import Literal, TypeAlias from .._models import BaseModel +from .token_log_probs import TokenLogProbs from .shared.tool_call import ToolCall from .shared.completion_message import CompletionMessage __all__ = [ "InferenceChatCompletionResponse", "ChatCompletionResponse", - "ChatCompletionResponseLogprob", "ChatCompletionResponseStreamChunk", "ChatCompletionResponseStreamChunkEvent", "ChatCompletionResponseStreamChunkEventDelta", "ChatCompletionResponseStreamChunkEventDeltaToolCallDelta", "ChatCompletionResponseStreamChunkEventDeltaToolCallDeltaContent", - "ChatCompletionResponseStreamChunkEventLogprob", ] -class ChatCompletionResponseLogprob(BaseModel): - logprobs_by_token: Dict[str, float] - - class ChatCompletionResponse(BaseModel): completion_message: CompletionMessage - logprobs: Optional[List[ChatCompletionResponseLogprob]] = None + logprobs: Optional[List[TokenLogProbs]] = None ChatCompletionResponseStreamChunkEventDeltaToolCallDeltaContent: TypeAlias = Union[str, ToolCall] @@ -44,16 +39,12 @@ class ChatCompletionResponseStreamChunkEventDeltaToolCallDelta(BaseModel): ] -class ChatCompletionResponseStreamChunkEventLogprob(BaseModel): - logprobs_by_token: Dict[str, float] - - class ChatCompletionResponseStreamChunkEvent(BaseModel): delta: ChatCompletionResponseStreamChunkEventDelta event_type: Literal["start", "complete", "progress"] - logprobs: Optional[List[ChatCompletionResponseStreamChunkEventLogprob]] = None + logprobs: Optional[List[TokenLogProbs]] = None stop_reason: Optional[Literal["end_of_turn", "end_of_message", "out_of_tokens"]] = None diff --git a/src/llama_stack_client/types/inference_completion_params.py b/src/llama_stack_client/types/inference_completion_params.py index f83537b2..d6f9275b 100644 --- a/src/llama_stack_client/types/inference_completion_params.py +++ b/src/llama_stack_client/types/inference_completion_params.py @@ -2,14 +2,22 @@ from __future__ import annotations -from typing import List, Union -from typing_extensions import Required, Annotated, TypeAlias, TypedDict +from typing import Dict, List, Union, Iterable +from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict from .._utils import PropertyInfo from .shared_params.image_media import ImageMedia from .shared_params.sampling_params import SamplingParams -__all__ = ["InferenceCompletionParams", "Content", "ContentUnionMember2", "Logprobs"] +__all__ = [ + "InferenceCompletionParams", + "Content", + "ContentUnionMember2", + "Logprobs", + "ResponseFormat", + "ResponseFormatUnionMember0", + "ResponseFormatUnionMember1", +] class InferenceCompletionParams(TypedDict, total=False): @@ -19,6 +27,8 @@ class InferenceCompletionParams(TypedDict, total=False): logprobs: Logprobs + response_format: ResponseFormat + sampling_params: SamplingParams stream: bool @@ -33,3 +43,18 @@ class InferenceCompletionParams(TypedDict, total=False): class Logprobs(TypedDict, total=False): top_k: int + + +class ResponseFormatUnionMember0(TypedDict, total=False): + json_schema: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] + + type: Required[Literal["json_schema"]] + + +class ResponseFormatUnionMember1(TypedDict, total=False): + bnf: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] + + type: Required[Literal["grammar"]] + + +ResponseFormat: TypeAlias = Union[ResponseFormatUnionMember0, ResponseFormatUnionMember1] diff --git 
a/src/llama_stack_client/types/inference_completion_response.py b/src/llama_stack_client/types/inference_completion_response.py index 181bcc7f..2f64e1b7 100644 --- a/src/llama_stack_client/types/inference_completion_response.py +++ b/src/llama_stack_client/types/inference_completion_response.py @@ -1,22 +1,19 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from typing import Dict, List, Union, Optional +from typing import List, Union, Optional from typing_extensions import Literal, TypeAlias from .._models import BaseModel +from .token_log_probs import TokenLogProbs from .completion_response import CompletionResponse -__all__ = ["InferenceCompletionResponse", "CompletionResponseStreamChunk", "CompletionResponseStreamChunkLogprob"] - - -class CompletionResponseStreamChunkLogprob(BaseModel): - logprobs_by_token: Dict[str, float] +__all__ = ["InferenceCompletionResponse", "CompletionResponseStreamChunk"] class CompletionResponseStreamChunk(BaseModel): delta: str - logprobs: Optional[List[CompletionResponseStreamChunkLogprob]] = None + logprobs: Optional[List[TokenLogProbs]] = None stop_reason: Optional[Literal["end_of_turn", "end_of_message", "out_of_tokens"]] = None diff --git a/src/llama_stack_client/types/job.py b/src/llama_stack_client/types/job.py new file mode 100644 index 00000000..25c33c4c --- /dev/null +++ b/src/llama_stack_client/types/job.py @@ -0,0 +1,10 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + + +from .._models import BaseModel + +__all__ = ["Job"] + + +class Job(BaseModel): + job_id: str diff --git a/src/llama_stack_client/types/memory_bank_list_response.py b/src/llama_stack_client/types/memory_bank_list_response.py index 3a9e9eda..ea3cd036 100644 --- a/src/llama_stack_client/types/memory_bank_list_response.py +++ b/src/llama_stack_client/types/memory_bank_list_response.py @@ -1,56 +1,14 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
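Both inference responses now reuse the shared TokenLogProbs model for their logprobs lists, and a streamed completion chunk still carries delta and stop_reason. A rough sketch of draining such a stream; how the stream object is obtained is not shown in this hunk and is assumed here:

# Hedged sketch: consume chunks typed as CompletionResponseStreamChunk.
def drain_completion_stream(stream) -> str:
    parts = []
    for chunk in stream:                   # CompletionResponseStreamChunk
        parts.append(chunk.delta)          # incremental text
        if chunk.logprobs:                 # Optional[List[TokenLogProbs]]
            pass                           # token-level log probabilities; fields not shown in this patch
        if chunk.stop_reason is not None:  # "end_of_turn" | "end_of_message" | "out_of_tokens"
            break
    return "".join(parts)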
-from typing import Union, Optional -from typing_extensions import Literal, TypeAlias +from typing import Union +from typing_extensions import TypeAlias -from .._models import BaseModel - -__all__ = [ - "MemoryBankListResponse", - "VectorMemoryBankDef", - "KeyValueMemoryBankDef", - "KeywordMemoryBankDef", - "GraphMemoryBankDef", -] - - -class VectorMemoryBankDef(BaseModel): - chunk_size_in_tokens: int - - embedding_model: str - - identifier: str - - provider_id: str - - type: Literal["vector"] - - overlap_size_in_tokens: Optional[int] = None - - -class KeyValueMemoryBankDef(BaseModel): - identifier: str - - provider_id: str - - type: Literal["keyvalue"] - - -class KeywordMemoryBankDef(BaseModel): - identifier: str - - provider_id: str - - type: Literal["keyword"] - - -class GraphMemoryBankDef(BaseModel): - identifier: str - - provider_id: str - - type: Literal["graph"] +from .shared.graph_memory_bank_def import GraphMemoryBankDef +from .shared.vector_memory_bank_def import VectorMemoryBankDef +from .shared.keyword_memory_bank_def import KeywordMemoryBankDef +from .shared.key_value_memory_bank_def import KeyValueMemoryBankDef +__all__ = ["MemoryBankListResponse"] MemoryBankListResponse: TypeAlias = Union[ VectorMemoryBankDef, KeyValueMemoryBankDef, KeywordMemoryBankDef, GraphMemoryBankDef diff --git a/src/llama_stack_client/types/memory_bank_register_params.py b/src/llama_stack_client/types/memory_bank_register_params.py index cdba8b2a..1952e389 100644 --- a/src/llama_stack_client/types/memory_bank_register_params.py +++ b/src/llama_stack_client/types/memory_bank_register_params.py @@ -3,18 +3,15 @@ from __future__ import annotations from typing import Union -from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict +from typing_extensions import Required, Annotated, TypeAlias, TypedDict from .._utils import PropertyInfo +from .shared_params.graph_memory_bank_def import GraphMemoryBankDef +from .shared_params.vector_memory_bank_def import VectorMemoryBankDef +from .shared_params.keyword_memory_bank_def import KeywordMemoryBankDef +from .shared_params.key_value_memory_bank_def import KeyValueMemoryBankDef -__all__ = [ - "MemoryBankRegisterParams", - "MemoryBank", - "MemoryBankVectorMemoryBankDef", - "MemoryBankKeyValueMemoryBankDef", - "MemoryBankKeywordMemoryBankDef", - "MemoryBankGraphMemoryBankDef", -] +__all__ = ["MemoryBankRegisterParams", "MemoryBank"] class MemoryBankRegisterParams(TypedDict, total=False): @@ -23,47 +20,4 @@ class MemoryBankRegisterParams(TypedDict, total=False): x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] -class MemoryBankVectorMemoryBankDef(TypedDict, total=False): - chunk_size_in_tokens: Required[int] - - embedding_model: Required[str] - - identifier: Required[str] - - provider_id: Required[str] - - type: Required[Literal["vector"]] - - overlap_size_in_tokens: int - - -class MemoryBankKeyValueMemoryBankDef(TypedDict, total=False): - identifier: Required[str] - - provider_id: Required[str] - - type: Required[Literal["keyvalue"]] - - -class MemoryBankKeywordMemoryBankDef(TypedDict, total=False): - identifier: Required[str] - - provider_id: Required[str] - - type: Required[Literal["keyword"]] - - -class MemoryBankGraphMemoryBankDef(TypedDict, total=False): - identifier: Required[str] - - provider_id: Required[str] - - type: Required[Literal["graph"]] - - -MemoryBank: TypeAlias = Union[ - MemoryBankVectorMemoryBankDef, - MemoryBankKeyValueMemoryBankDef, - MemoryBankKeywordMemoryBankDef, - 
MemoryBankGraphMemoryBankDef, -] +MemoryBank: TypeAlias = Union[VectorMemoryBankDef, KeyValueMemoryBankDef, KeywordMemoryBankDef, GraphMemoryBankDef] diff --git a/src/llama_stack_client/types/memory_bank_retrieve_response.py b/src/llama_stack_client/types/memory_bank_retrieve_response.py index c90a3742..294ca9c8 100644 --- a/src/llama_stack_client/types/memory_bank_retrieve_response.py +++ b/src/llama_stack_client/types/memory_bank_retrieve_response.py @@ -1,56 +1,14 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from typing import Union, Optional -from typing_extensions import Literal, TypeAlias +from typing import Union +from typing_extensions import TypeAlias -from .._models import BaseModel - -__all__ = [ - "MemoryBankRetrieveResponse", - "VectorMemoryBankDef", - "KeyValueMemoryBankDef", - "KeywordMemoryBankDef", - "GraphMemoryBankDef", -] - - -class VectorMemoryBankDef(BaseModel): - chunk_size_in_tokens: int - - embedding_model: str - - identifier: str - - provider_id: str - - type: Literal["vector"] - - overlap_size_in_tokens: Optional[int] = None - - -class KeyValueMemoryBankDef(BaseModel): - identifier: str - - provider_id: str - - type: Literal["keyvalue"] - - -class KeywordMemoryBankDef(BaseModel): - identifier: str - - provider_id: str - - type: Literal["keyword"] - - -class GraphMemoryBankDef(BaseModel): - identifier: str - - provider_id: str - - type: Literal["graph"] +from .shared.graph_memory_bank_def import GraphMemoryBankDef +from .shared.vector_memory_bank_def import VectorMemoryBankDef +from .shared.keyword_memory_bank_def import KeywordMemoryBankDef +from .shared.key_value_memory_bank_def import KeyValueMemoryBankDef +__all__ = ["MemoryBankRetrieveResponse"] MemoryBankRetrieveResponse: TypeAlias = Union[ VectorMemoryBankDef, KeyValueMemoryBankDef, KeywordMemoryBankDef, GraphMemoryBankDef, None diff --git a/src/llama_stack_client/types/memory_retrieval_step.py b/src/llama_stack_client/types/memory_retrieval_step.py index 6e247bb9..9ffc988b 100644 --- a/src/llama_stack_client/types/memory_retrieval_step.py +++ b/src/llama_stack_client/types/memory_retrieval_step.py @@ -6,11 +6,12 @@ from .._models import BaseModel from .shared.image_media import ImageMedia -from .shared.content_array import ContentArray -__all__ = ["MemoryRetrievalStep", "InsertedContext"] +__all__ = ["MemoryRetrievalStep", "InsertedContext", "InsertedContextUnionMember2"] -InsertedContext: TypeAlias = Union[str, ImageMedia, ContentArray] +InsertedContextUnionMember2: TypeAlias = Union[str, ImageMedia] + +InsertedContext: TypeAlias = Union[str, ImageMedia, List[InsertedContextUnionMember2]] class MemoryRetrievalStep(BaseModel): diff --git a/src/llama_stack_client/types/model_list_response.py b/src/llama_stack_client/types/model_list_response.py new file mode 100644 index 00000000..6fb1a940 --- /dev/null +++ b/src/llama_stack_client/types/model_list_response.py @@ -0,0 +1,23 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union + +from .._models import BaseModel + +__all__ = ["ModelListResponse", "ProviderConfig"] + + +class ProviderConfig(BaseModel): + config: Dict[str, Union[bool, float, str, List[object], object, None]] + + provider_type: str + + +class ModelListResponse(BaseModel): + llama_model: object + """ + The model family and SKU of the model along with other parameters corresponding + to the model. 
+ """ + + provider_config: ProviderConfig diff --git a/src/llama_stack_client/types/model_retrieve_response.py b/src/llama_stack_client/types/model_retrieve_response.py new file mode 100644 index 00000000..48c54d01 --- /dev/null +++ b/src/llama_stack_client/types/model_retrieve_response.py @@ -0,0 +1,23 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union + +from .._models import BaseModel + +__all__ = ["ModelRetrieveResponse", "ProviderConfig"] + + +class ProviderConfig(BaseModel): + config: Dict[str, Union[bool, float, str, List[object], object, None]] + + provider_type: str + + +class ModelRetrieveResponse(BaseModel): + llama_model: object + """ + The model family and SKU of the model along with other parameters corresponding + to the model. + """ + + provider_config: ProviderConfig diff --git a/src/llama_stack_client/types/paginated_rows_result.py b/src/llama_stack_client/types/paginated_rows_result.py new file mode 100644 index 00000000..15eb53e2 --- /dev/null +++ b/src/llama_stack_client/types/paginated_rows_result.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union, Optional + +from .._models import BaseModel + +__all__ = ["PaginatedRowsResult"] + + +class PaginatedRowsResult(BaseModel): + rows: List[Dict[str, Union[bool, float, str, List[object], object, None]]] + + total_count: int + + next_page_token: Optional[str] = None diff --git a/src/llama_stack_client/types/post_training_job.py b/src/llama_stack_client/types/post_training_job.py index 1195facc..8cd98126 100644 --- a/src/llama_stack_client/types/post_training_job.py +++ b/src/llama_stack_client/types/post_training_job.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
- from .._models import BaseModel __all__ = ["PostTrainingJob"] diff --git a/src/llama_stack_client/types/post_training_preference_optimize_params.py b/src/llama_stack_client/types/post_training_preference_optimize_params.py index 805e6cfb..eb44c6ac 100644 --- a/src/llama_stack_client/types/post_training_preference_optimize_params.py +++ b/src/llama_stack_client/types/post_training_preference_optimize_params.py @@ -6,7 +6,6 @@ from typing_extensions import Literal, Required, Annotated, TypedDict from .._utils import PropertyInfo -from .train_eval_dataset_param import TrainEvalDatasetParam __all__ = ["PostTrainingPreferenceOptimizeParams", "AlgorithmConfig", "OptimizerConfig", "TrainingConfig"] @@ -16,7 +15,7 @@ class PostTrainingPreferenceOptimizeParams(TypedDict, total=False): algorithm_config: Required[AlgorithmConfig] - dataset: Required[TrainEvalDatasetParam] + dataset_id: Required[str] finetuned_model: Required[str] @@ -30,7 +29,7 @@ class PostTrainingPreferenceOptimizeParams(TypedDict, total=False): training_config: Required[TrainingConfig] - validation_dataset: Required[TrainEvalDatasetParam] + validation_dataset_id: Required[str] x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/post_training_supervised_fine_tune_params.py b/src/llama_stack_client/types/post_training_supervised_fine_tune_params.py index 084e1ede..9230e631 100644 --- a/src/llama_stack_client/types/post_training_supervised_fine_tune_params.py +++ b/src/llama_stack_client/types/post_training_supervised_fine_tune_params.py @@ -6,7 +6,6 @@ from typing_extensions import Literal, Required, Annotated, TypeAlias, TypedDict from .._utils import PropertyInfo -from .train_eval_dataset_param import TrainEvalDatasetParam __all__ = [ "PostTrainingSupervisedFineTuneParams", @@ -24,7 +23,7 @@ class PostTrainingSupervisedFineTuneParams(TypedDict, total=False): algorithm_config: Required[AlgorithmConfig] - dataset: Required[TrainEvalDatasetParam] + dataset_id: Required[str] hyperparam_search_config: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] @@ -38,7 +37,7 @@ class PostTrainingSupervisedFineTuneParams(TypedDict, total=False): training_config: Required[TrainingConfig] - validation_dataset: Required[TrainEvalDatasetParam] + validation_dataset_id: Required[str] x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/provider_info.py b/src/llama_stack_client/types/provider_info.py index ca532f5f..a1a79f79 100644 --- a/src/llama_stack_client/types/provider_info.py +++ b/src/llama_stack_client/types/provider_info.py @@ -1,7 +1,6 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. - from .._models import BaseModel __all__ = ["ProviderInfo"] diff --git a/src/llama_stack_client/types/run_shield_response.py b/src/llama_stack_client/types/run_shield_response.py index 9e9e6330..1dbdf5a0 100644 --- a/src/llama_stack_client/types/run_shield_response.py +++ b/src/llama_stack_client/types/run_shield_response.py @@ -1,20 +1,12 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
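The post-training params now reference registered datasets by ID rather than embedding TrainEvalDatasetParam objects. Only the renamed keys are sketched below; the other required fine-tuning arguments are unchanged and omitted:

# Hedged sketch: the two renamed fields, with illustrative values.
dataset_kwargs = {
    "dataset_id": "my-train-dataset",            # previously: dataset=<TrainEvalDatasetParam>
    "validation_dataset_id": "my-eval-dataset",  # previously: validation_dataset=<TrainEvalDatasetParam>
}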
-from typing import Dict, List, Union, Optional -from typing_extensions import Literal +from typing import Optional from .._models import BaseModel +from .shared.safety_violation import SafetyViolation -__all__ = ["RunShieldResponse", "Violation"] - - -class Violation(BaseModel): - metadata: Dict[str, Union[bool, float, str, List[object], object, None]] - - violation_level: Literal["info", "warn", "error"] - - user_message: Optional[str] = None +__all__ = ["RunShieldResponse"] class RunShieldResponse(BaseModel): - violation: Optional[Violation] = None + violation: Optional[SafetyViolation] = None diff --git a/src/llama_stack_client/types/score_batch_response.py b/src/llama_stack_client/types/score_batch_response.py new file mode 100644 index 00000000..876bf062 --- /dev/null +++ b/src/llama_stack_client/types/score_batch_response.py @@ -0,0 +1,19 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union, Optional + +from .._models import BaseModel + +__all__ = ["ScoreBatchResponse", "Results"] + + +class Results(BaseModel): + aggregated_results: Dict[str, Union[bool, float, str, List[object], object, None]] + + score_rows: List[Dict[str, Union[bool, float, str, List[object], object, None]]] + + +class ScoreBatchResponse(BaseModel): + results: Dict[str, Results] + + dataset_id: Optional[str] = None diff --git a/src/llama_stack_client/types/score_response.py b/src/llama_stack_client/types/score_response.py new file mode 100644 index 00000000..967cc623 --- /dev/null +++ b/src/llama_stack_client/types/score_response.py @@ -0,0 +1,17 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union + +from .._models import BaseModel + +__all__ = ["ScoreResponse", "Results"] + + +class Results(BaseModel): + aggregated_results: Dict[str, Union[bool, float, str, List[object], object, None]] + + score_rows: List[Dict[str, Union[bool, float, str, List[object], object, None]]] + + +class ScoreResponse(BaseModel): + results: Dict[str, Results] diff --git a/src/llama_stack_client/types/scoring_fn_def_with_provider.py b/src/llama_stack_client/types/scoring_fn_def_with_provider.py new file mode 100644 index 00000000..dd8b1c1e --- /dev/null +++ b/src/llama_stack_client/types/scoring_fn_def_with_provider.py @@ -0,0 +1,84 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
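RunShieldResponse (and ShieldCallStep later in this patch) now share one SafetyViolation model instead of declaring private Violation classes. A sketch of inspecting a response, with the violation built by hand for illustration rather than returned by an actual safety call:

from llama_stack_client.types.run_shield_response import RunShieldResponse
from llama_stack_client.types.shared.safety_violation import SafetyViolation

# Hedged sketch: hand-built response standing in for a real run_shield result.
response = RunShieldResponse(
    violation=SafetyViolation(
        metadata={"shield": "illustrative"},
        violation_level="warn",                 # "info" | "warn" | "error"
        user_message="Content flagged for review.",
    )
)

if response.violation is not None and response.violation.violation_level == "error":
    raise RuntimeError(response.violation.user_message or "blocked by shield")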
+ +from typing import Dict, List, Union, Optional +from typing_extensions import Literal, TypeAlias + +from .._models import BaseModel + +__all__ = [ + "ScoringFnDefWithProvider", + "Parameter", + "ParameterType", + "ParameterTypeType", + "ReturnType", + "ReturnTypeType", + "Context", +] + + +class ParameterTypeType(BaseModel): + type: Literal["string"] + + +ParameterType: TypeAlias = Union[ + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, +] + + +class Parameter(BaseModel): + name: str + + type: ParameterType + + description: Optional[str] = None + + +class ReturnTypeType(BaseModel): + type: Literal["string"] + + +ReturnType: TypeAlias = Union[ + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, +] + + +class Context(BaseModel): + judge_model: str + + judge_score_regex: Optional[List[str]] = None + + prompt_template: Optional[str] = None + + +class ScoringFnDefWithProvider(BaseModel): + identifier: str + + metadata: Dict[str, Union[bool, float, str, List[object], object, None]] + + parameters: List[Parameter] + + provider_id: str + + return_type: ReturnType + + context: Optional[Context] = None + + description: Optional[str] = None diff --git a/src/llama_stack_client/types/scoring_fn_def_with_provider_param.py b/src/llama_stack_client/types/scoring_fn_def_with_provider_param.py new file mode 100644 index 00000000..e6f8d4fa --- /dev/null +++ b/src/llama_stack_client/types/scoring_fn_def_with_provider_param.py @@ -0,0 +1,84 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from __future__ import annotations + +from typing import Dict, List, Union, Iterable +from typing_extensions import Literal, Required, TypeAlias, TypedDict + +__all__ = [ + "ScoringFnDefWithProviderParam", + "Parameter", + "ParameterType", + "ParameterTypeType", + "ReturnType", + "ReturnTypeType", + "Context", +] + + +class ParameterTypeType(TypedDict, total=False): + type: Required[Literal["string"]] + + +ParameterType: TypeAlias = Union[ + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, +] + + +class Parameter(TypedDict, total=False): + name: Required[str] + + type: Required[ParameterType] + + description: str + + +class ReturnTypeType(TypedDict, total=False): + type: Required[Literal["string"]] + + +ReturnType: TypeAlias = Union[ + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, +] + + +class Context(TypedDict, total=False): + judge_model: Required[str] + + judge_score_regex: List[str] + + prompt_template: str + + +class ScoringFnDefWithProviderParam(TypedDict, total=False): + identifier: Required[str] + + metadata: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] + + parameters: Required[Iterable[Parameter]] + + provider_id: Required[str] + + return_type: Required[ReturnType] + + context: Context + + description: str diff --git a/src/llama_stack_client/types/scoring_function_def_with_provider.py b/src/llama_stack_client/types/scoring_function_def_with_provider.py new file mode 100644 index 00000000..42d1d2be --- /dev/null +++ b/src/llama_stack_client/types/scoring_function_def_with_provider.py @@ -0,0 +1,98 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from typing import Dict, List, Union, Optional +from typing_extensions import Literal, TypeAlias + +from .._models import BaseModel + +__all__ = [ + "ScoringFunctionDefWithProvider", + "Parameter", + "ParameterType", + "ParameterTypeType", + "ParameterTypeUnionMember7", + "ReturnType", + "ReturnTypeType", + "ReturnTypeUnionMember7", + "Context", +] + + +class ParameterTypeType(BaseModel): + type: Literal["string"] + + +class ParameterTypeUnionMember7(BaseModel): + type: Literal["custom"] + + validator_class: str + + +ParameterType: TypeAlias = Union[ + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeUnionMember7, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, +] + + +class Parameter(BaseModel): + name: str + + type: ParameterType + + description: Optional[str] = None + + +class ReturnTypeType(BaseModel): + type: Literal["string"] + + +class ReturnTypeUnionMember7(BaseModel): + type: Literal["custom"] + + validator_class: str + + +ReturnType: TypeAlias = Union[ + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeUnionMember7, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, +] + + +class Context(BaseModel): + judge_model: str + + prompt_template: Optional[str] = None + + +class ScoringFunctionDefWithProvider(BaseModel): + identifier: str + + metadata: Dict[str, Union[bool, float, str, List[object], object, None]] + + parameters: List[Parameter] + + provider_id: str + + return_type: ReturnType + + context: Optional[Context] = None + + description: Optional[str] = None diff --git a/src/llama_stack_client/types/scoring_function_def_with_provider_param.py b/src/llama_stack_client/types/scoring_function_def_with_provider_param.py new file mode 100644 index 00000000..93bdee51 --- /dev/null +++ b/src/llama_stack_client/types/scoring_function_def_with_provider_param.py @@ -0,0 +1,98 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from __future__ import annotations + +from typing import Dict, Union, Iterable +from typing_extensions import Literal, Required, TypeAlias, TypedDict + +__all__ = [ + "ScoringFunctionDefWithProviderParam", + "Parameter", + "ParameterType", + "ParameterTypeType", + "ParameterTypeUnionMember7", + "ReturnType", + "ReturnTypeType", + "ReturnTypeUnionMember7", + "Context", +] + + +class ParameterTypeType(TypedDict, total=False): + type: Required[Literal["string"]] + + +class ParameterTypeUnionMember7(TypedDict, total=False): + type: Required[Literal["custom"]] + + validator_class: Required[str] + + +ParameterType: TypeAlias = Union[ + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeUnionMember7, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, +] + + +class Parameter(TypedDict, total=False): + name: Required[str] + + type: Required[ParameterType] + + description: str + + +class ReturnTypeType(TypedDict, total=False): + type: Required[Literal["string"]] + + +class ReturnTypeUnionMember7(TypedDict, total=False): + type: Required[Literal["custom"]] + + validator_class: Required[str] + + +ReturnType: TypeAlias = Union[ + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeUnionMember7, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, +] + + +class Context(TypedDict, total=False): + judge_model: Required[str] + + prompt_template: str + + +class ScoringFunctionDefWithProviderParam(TypedDict, total=False): + identifier: Required[str] + + metadata: Required[Dict[str, Union[bool, float, str, Iterable[object], object, None]]] + + parameters: Required[Iterable[Parameter]] + + provider_id: Required[str] + + return_type: Required[ReturnType] + + context: Context + + description: str diff --git a/src/llama_stack_client/types/scoring_function_list_response.py b/src/llama_stack_client/types/scoring_function_list_response.py new file mode 100644 index 00000000..c1dd9d4f --- /dev/null +++ b/src/llama_stack_client/types/scoring_function_list_response.py @@ -0,0 +1,84 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from typing import Dict, List, Union, Optional +from typing_extensions import Literal, TypeAlias + +from .._models import BaseModel + +__all__ = [ + "ScoringFunctionListResponse", + "Parameter", + "ParameterType", + "ParameterTypeType", + "ReturnType", + "ReturnTypeType", + "Context", +] + + +class ParameterTypeType(BaseModel): + type: Literal["string"] + + +ParameterType: TypeAlias = Union[ + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, +] + + +class Parameter(BaseModel): + name: str + + type: ParameterType + + description: Optional[str] = None + + +class ReturnTypeType(BaseModel): + type: Literal["string"] + + +ReturnType: TypeAlias = Union[ + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, +] + + +class Context(BaseModel): + judge_model: str + + judge_score_regex: Optional[List[str]] = None + + prompt_template: Optional[str] = None + + +class ScoringFunctionListResponse(BaseModel): + identifier: str + + metadata: Dict[str, Union[bool, float, str, List[object], object, None]] + + parameters: List[Parameter] + + provider_id: str + + return_type: ReturnType + + context: Optional[Context] = None + + description: Optional[str] = None diff --git a/src/llama_stack_client/types/scoring_function_register_params.py b/src/llama_stack_client/types/scoring_function_register_params.py new file mode 100644 index 00000000..36e93e48 --- /dev/null +++ b/src/llama_stack_client/types/scoring_function_register_params.py @@ -0,0 +1,16 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo +from .scoring_fn_def_with_provider_param import ScoringFnDefWithProviderParam + +__all__ = ["ScoringFunctionRegisterParams"] + + +class ScoringFunctionRegisterParams(TypedDict, total=False): + function_def: Required[ScoringFnDefWithProviderParam] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/scoring_function_retrieve_params.py b/src/llama_stack_client/types/scoring_function_retrieve_params.py new file mode 100644 index 00000000..8ad1029c --- /dev/null +++ b/src/llama_stack_client/types/scoring_function_retrieve_params.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo + +__all__ = ["ScoringFunctionRetrieveParams"] + + +class ScoringFunctionRetrieveParams(TypedDict, total=False): + name: Required[str] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/scoring_function_retrieve_response.py b/src/llama_stack_client/types/scoring_function_retrieve_response.py new file mode 100644 index 00000000..a99c1fb7 --- /dev/null +++ b/src/llama_stack_client/types/scoring_function_retrieve_response.py @@ -0,0 +1,84 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
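ScoringFunctionRegisterParams wraps a single function_def of type ScoringFnDefWithProviderParam. A sketch of the dict that TypedDict describes; the identifier, provider_id, judge model, and regex are made up for illustration:

# Hedged sketch: keys mirror ScoringFnDefWithProviderParam above.
function_def = {
    "identifier": "example::exact_match",
    "metadata": {},
    "parameters": [
        {"name": "input_query", "type": {"type": "string"}, "description": "query column"},
    ],
    "provider_id": "example-provider",
    "return_type": {"type": "string"},
    "context": {
        "judge_model": "example-judge-model",
        "judge_score_regex": [r"Score:\s*(\d+)"],
        "prompt_template": "Rate the answer from 0 to 10.",
    },
    "description": "Illustrative scoring function definition.",
}

# e.g. client.scoring_functions.register(function_def=function_def)  # method name assumed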
+ +from typing import Dict, List, Union, Optional +from typing_extensions import Literal, TypeAlias + +from .._models import BaseModel + +__all__ = [ + "ScoringFunctionRetrieveResponse", + "Parameter", + "ParameterType", + "ParameterTypeType", + "ReturnType", + "ReturnTypeType", + "Context", +] + + +class ParameterTypeType(BaseModel): + type: Literal["string"] + + +ParameterType: TypeAlias = Union[ + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, + ParameterTypeType, +] + + +class Parameter(BaseModel): + name: str + + type: ParameterType + + description: Optional[str] = None + + +class ReturnTypeType(BaseModel): + type: Literal["string"] + + +ReturnType: TypeAlias = Union[ + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, + ReturnTypeType, +] + + +class Context(BaseModel): + judge_model: str + + judge_score_regex: Optional[List[str]] = None + + prompt_template: Optional[str] = None + + +class ScoringFunctionRetrieveResponse(BaseModel): + identifier: str + + metadata: Dict[str, Union[bool, float, str, List[object], object, None]] + + parameters: List[Parameter] + + provider_id: str + + return_type: ReturnType + + context: Optional[Context] = None + + description: Optional[str] = None diff --git a/src/llama_stack_client/types/scoring_score_batch_params.py b/src/llama_stack_client/types/scoring_score_batch_params.py new file mode 100644 index 00000000..f23c23a4 --- /dev/null +++ b/src/llama_stack_client/types/scoring_score_batch_params.py @@ -0,0 +1,20 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing import List +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo + +__all__ = ["ScoringScoreBatchParams"] + + +class ScoringScoreBatchParams(TypedDict, total=False): + dataset_id: Required[str] + + save_results_dataset: Required[bool] + + scoring_functions: Required[List[str]] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/scoring_score_params.py b/src/llama_stack_client/types/scoring_score_params.py new file mode 100644 index 00000000..8419b096 --- /dev/null +++ b/src/llama_stack_client/types/scoring_score_params.py @@ -0,0 +1,18 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from __future__ import annotations + +from typing import Dict, List, Union, Iterable +from typing_extensions import Required, Annotated, TypedDict + +from .._utils import PropertyInfo + +__all__ = ["ScoringScoreParams"] + + +class ScoringScoreParams(TypedDict, total=False): + input_rows: Required[Iterable[Dict[str, Union[bool, float, str, Iterable[object], object, None]]]] + + scoring_functions: Required[List[str]] + + x_llama_stack_provider_data: Annotated[str, PropertyInfo(alias="X-LlamaStack-ProviderData")] diff --git a/src/llama_stack_client/types/shared/__init__.py b/src/llama_stack_client/types/shared/__init__.py index b4ab76d2..21cc1d96 100644 --- a/src/llama_stack_client/types/shared/__init__.py +++ b/src/llama_stack_client/types/shared/__init__.py @@ -7,5 +7,10 @@ from .system_message import SystemMessage as SystemMessage from .sampling_params import SamplingParams as SamplingParams from .batch_completion import BatchCompletion as BatchCompletion +from .safety_violation import SafetyViolation as SafetyViolation from .completion_message import CompletionMessage as CompletionMessage +from .graph_memory_bank_def import GraphMemoryBankDef as GraphMemoryBankDef from .tool_response_message import ToolResponseMessage as ToolResponseMessage +from .vector_memory_bank_def import VectorMemoryBankDef as VectorMemoryBankDef +from .keyword_memory_bank_def import KeywordMemoryBankDef as KeywordMemoryBankDef +from .key_value_memory_bank_def import KeyValueMemoryBankDef as KeyValueMemoryBankDef diff --git a/src/llama_stack_client/types/shared/graph_memory_bank_def.py b/src/llama_stack_client/types/shared/graph_memory_bank_def.py new file mode 100644 index 00000000..22353e6c --- /dev/null +++ b/src/llama_stack_client/types/shared/graph_memory_bank_def.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing_extensions import Literal + +from ..._models import BaseModel + +__all__ = ["GraphMemoryBankDef"] + + +class GraphMemoryBankDef(BaseModel): + identifier: str + + provider_id: str + + type: Literal["graph"] diff --git a/src/llama_stack_client/types/shared/key_value_memory_bank_def.py b/src/llama_stack_client/types/shared/key_value_memory_bank_def.py new file mode 100644 index 00000000..2a328d38 --- /dev/null +++ b/src/llama_stack_client/types/shared/key_value_memory_bank_def.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing_extensions import Literal + +from ..._models import BaseModel + +__all__ = ["KeyValueMemoryBankDef"] + + +class KeyValueMemoryBankDef(BaseModel): + identifier: str + + provider_id: str + + type: Literal["keyvalue"] diff --git a/src/llama_stack_client/types/shared/keyword_memory_bank_def.py b/src/llama_stack_client/types/shared/keyword_memory_bank_def.py new file mode 100644 index 00000000..e1637af5 --- /dev/null +++ b/src/llama_stack_client/types/shared/keyword_memory_bank_def.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
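ScoringScoreParams pairs a batch of JSON-like input rows with the names of the scoring functions to run, and ScoreResponse keys its Results by scoring-function name. A sketch of the request payload and of reading results; the client call itself is assumed and left commented:

# Hedged sketch: request shape per ScoringScoreParams; names and columns are illustrative.
score_request = {
    "input_rows": [
        {"input_query": "What is 2 + 2?", "generated_answer": "4", "expected_answer": "4"},
    ],
    "scoring_functions": ["example::exact_match"],
}

# response = client.scoring.score(**score_request)   # method name assumed
# for fn_name, result in response.results.items():   # ScoreResponse.results: Dict[str, Results]
#     print(fn_name, result.aggregated_results, result.score_rows)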
+ +from typing_extensions import Literal + +from ..._models import BaseModel + +__all__ = ["KeywordMemoryBankDef"] + + +class KeywordMemoryBankDef(BaseModel): + identifier: str + + provider_id: str + + type: Literal["keyword"] diff --git a/src/llama_stack_client/types/shared/safety_violation.py b/src/llama_stack_client/types/shared/safety_violation.py new file mode 100644 index 00000000..e3c94312 --- /dev/null +++ b/src/llama_stack_client/types/shared/safety_violation.py @@ -0,0 +1,16 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union, Optional +from typing_extensions import Literal + +from ..._models import BaseModel + +__all__ = ["SafetyViolation"] + + +class SafetyViolation(BaseModel): + metadata: Dict[str, Union[bool, float, str, List[object], object, None]] + + violation_level: Literal["info", "warn", "error"] + + user_message: Optional[str] = None diff --git a/src/llama_stack_client/types/shared/vector_memory_bank_def.py b/src/llama_stack_client/types/shared/vector_memory_bank_def.py new file mode 100644 index 00000000..04297526 --- /dev/null +++ b/src/llama_stack_client/types/shared/vector_memory_bank_def.py @@ -0,0 +1,22 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Optional +from typing_extensions import Literal + +from ..._models import BaseModel + +__all__ = ["VectorMemoryBankDef"] + + +class VectorMemoryBankDef(BaseModel): + chunk_size_in_tokens: int + + embedding_model: str + + identifier: str + + provider_id: str + + type: Literal["vector"] + + overlap_size_in_tokens: Optional[int] = None diff --git a/src/llama_stack_client/types/shared_params/__init__.py b/src/llama_stack_client/types/shared_params/__init__.py index 5ebba3b4..157c2d45 100644 --- a/src/llama_stack_client/types/shared_params/__init__.py +++ b/src/llama_stack_client/types/shared_params/__init__.py @@ -7,4 +7,8 @@ from .system_message import SystemMessage as SystemMessage from .sampling_params import SamplingParams as SamplingParams from .completion_message import CompletionMessage as CompletionMessage +from .graph_memory_bank_def import GraphMemoryBankDef as GraphMemoryBankDef from .tool_response_message import ToolResponseMessage as ToolResponseMessage +from .vector_memory_bank_def import VectorMemoryBankDef as VectorMemoryBankDef +from .keyword_memory_bank_def import KeywordMemoryBankDef as KeywordMemoryBankDef +from .key_value_memory_bank_def import KeyValueMemoryBankDef as KeyValueMemoryBankDef diff --git a/src/llama_stack_client/types/shared_params/graph_memory_bank_def.py b/src/llama_stack_client/types/shared_params/graph_memory_bank_def.py new file mode 100644 index 00000000..d9858622 --- /dev/null +++ b/src/llama_stack_client/types/shared_params/graph_memory_bank_def.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["GraphMemoryBankDef"] + + +class GraphMemoryBankDef(TypedDict, total=False): + identifier: Required[str] + + provider_id: Required[str] + + type: Required[Literal["graph"]] diff --git a/src/llama_stack_client/types/shared_params/key_value_memory_bank_def.py b/src/llama_stack_client/types/shared_params/key_value_memory_bank_def.py new file mode 100644 index 00000000..c6e2999b --- /dev/null +++ b/src/llama_stack_client/types/shared_params/key_value_memory_bank_def.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["KeyValueMemoryBankDef"] + + +class KeyValueMemoryBankDef(TypedDict, total=False): + identifier: Required[str] + + provider_id: Required[str] + + type: Required[Literal["keyvalue"]] diff --git a/src/llama_stack_client/types/shared_params/keyword_memory_bank_def.py b/src/llama_stack_client/types/shared_params/keyword_memory_bank_def.py new file mode 100644 index 00000000..d71ca72d --- /dev/null +++ b/src/llama_stack_client/types/shared_params/keyword_memory_bank_def.py @@ -0,0 +1,15 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["KeywordMemoryBankDef"] + + +class KeywordMemoryBankDef(TypedDict, total=False): + identifier: Required[str] + + provider_id: Required[str] + + type: Required[Literal["keyword"]] diff --git a/src/llama_stack_client/types/shared_params/vector_memory_bank_def.py b/src/llama_stack_client/types/shared_params/vector_memory_bank_def.py new file mode 100644 index 00000000..50428659 --- /dev/null +++ b/src/llama_stack_client/types/shared_params/vector_memory_bank_def.py @@ -0,0 +1,21 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +from typing_extensions import Literal, Required, TypedDict + +__all__ = ["VectorMemoryBankDef"] + + +class VectorMemoryBankDef(TypedDict, total=False): + chunk_size_in_tokens: Required[int] + + embedding_model: Required[str] + + identifier: Required[str] + + provider_id: Required[str] + + type: Required[Literal["vector"]] + + overlap_size_in_tokens: int diff --git a/src/llama_stack_client/types/shield_call_step.py b/src/llama_stack_client/types/shield_call_step.py index d4b90d8b..8c539d0b 100644 --- a/src/llama_stack_client/types/shield_call_step.py +++ b/src/llama_stack_client/types/shield_call_step.py @@ -1,20 +1,13 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
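The four memory-bank definitions are now shared models and TypedDicts, and the register params take one of them as the memory_bank union. A sketch of a vector bank in that shape; the identifier, provider, and embedding model names are made up:

# Hedged sketch: a VectorMemoryBankDef-shaped dict for memory-bank registration.
vector_bank = {
    "type": "vector",
    "identifier": "docs-bank",
    "provider_id": "example-provider",
    "embedding_model": "example-embedding-model",
    "chunk_size_in_tokens": 512,
    "overlap_size_in_tokens": 64,  # optional
}

# e.g. client.memory_banks.register(memory_bank=vector_bank)  # keyword assumed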
-from typing import Dict, List, Union, Optional +from typing import Optional from datetime import datetime from typing_extensions import Literal from .._models import BaseModel +from .shared.safety_violation import SafetyViolation -__all__ = ["ShieldCallStep", "Violation"] - - -class Violation(BaseModel): - metadata: Dict[str, Union[bool, float, str, List[object], object, None]] - - violation_level: Literal["info", "warn", "error"] - - user_message: Optional[str] = None +__all__ = ["ShieldCallStep"] class ShieldCallStep(BaseModel): @@ -28,4 +21,4 @@ class ShieldCallStep(BaseModel): started_at: Optional[datetime] = None - violation: Optional[Violation] = None + violation: Optional[SafetyViolation] = None diff --git a/src/llama_stack_client/types/shield_list_response.py b/src/llama_stack_client/types/shield_list_response.py new file mode 100644 index 00000000..9f375cbf --- /dev/null +++ b/src/llama_stack_client/types/shield_list_response.py @@ -0,0 +1,19 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union + +from .._models import BaseModel + +__all__ = ["ShieldListResponse", "ProviderConfig"] + + +class ProviderConfig(BaseModel): + config: Dict[str, Union[bool, float, str, List[object], object, None]] + + provider_type: str + + +class ShieldListResponse(BaseModel): + provider_config: ProviderConfig + + shield_type: str diff --git a/src/llama_stack_client/types/shield_retrieve_response.py b/src/llama_stack_client/types/shield_retrieve_response.py new file mode 100644 index 00000000..d552342f --- /dev/null +++ b/src/llama_stack_client/types/shield_retrieve_response.py @@ -0,0 +1,19 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from typing import Dict, List, Union + +from .._models import BaseModel + +__all__ = ["ShieldRetrieveResponse", "ProviderConfig"] + + +class ProviderConfig(BaseModel): + config: Dict[str, Union[bool, float, str, List[object], object, None]] + + provider_type: str + + +class ShieldRetrieveResponse(BaseModel): + provider_config: ProviderConfig + + shield_type: str diff --git a/src/llama_stack_client/types/synthetic_data_generation_response.py b/src/llama_stack_client/types/synthetic_data_generation_response.py index 2fda8e3e..a2ee11e6 100644 --- a/src/llama_stack_client/types/synthetic_data_generation_response.py +++ b/src/llama_stack_client/types/synthetic_data_generation_response.py @@ -1,42 +1,13 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
from typing import Dict, List, Union, Optional -from typing_extensions import TypeAlias from .._models import BaseModel -from .shared.user_message import UserMessage -from .shared.system_message import SystemMessage -from .shared.completion_message import CompletionMessage -from .shared.tool_response_message import ToolResponseMessage -__all__ = [ - "SyntheticDataGenerationResponse", - "SyntheticData", - "SyntheticDataDialog", - "SyntheticDataScoredGeneration", - "SyntheticDataScoredGenerationMessage", -] - -SyntheticDataDialog: TypeAlias = Union[UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage] - -SyntheticDataScoredGenerationMessage: TypeAlias = Union[ - UserMessage, SystemMessage, ToolResponseMessage, CompletionMessage -] - - -class SyntheticDataScoredGeneration(BaseModel): - message: SyntheticDataScoredGenerationMessage - - score: float - - -class SyntheticData(BaseModel): - dialog: List[SyntheticDataDialog] - - scored_generations: List[SyntheticDataScoredGeneration] +__all__ = ["SyntheticDataGenerationResponse"] class SyntheticDataGenerationResponse(BaseModel): - synthetic_data: List[SyntheticData] + synthetic_data: List[Dict[str, Union[bool, float, str, List[object], object, None]]] statistics: Optional[Dict[str, Union[bool, float, str, List[object], object, None]]] = None diff --git a/src/llama_stack_client/types/tool_execution_step.py b/src/llama_stack_client/types/tool_execution_step.py index 4b5a98a9..c4ad5d68 100644 --- a/src/llama_stack_client/types/tool_execution_step.py +++ b/src/llama_stack_client/types/tool_execution_step.py @@ -1,25 +1,14 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. -from typing import List, Union, Optional +from typing import List, Optional from datetime import datetime -from typing_extensions import Literal, TypeAlias +from typing_extensions import Literal from .._models import BaseModel +from .tool_response import ToolResponse from .shared.tool_call import ToolCall -from .shared.image_media import ImageMedia -from .shared.content_array import ContentArray -__all__ = ["ToolExecutionStep", "ToolResponse", "ToolResponseContent"] - -ToolResponseContent: TypeAlias = Union[str, ImageMedia, ContentArray] - - -class ToolResponse(BaseModel): - call_id: str - - content: ToolResponseContent - - tool_name: Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str] +__all__ = ["ToolExecutionStep"] class ToolExecutionStep(BaseModel): diff --git a/src/llama_stack_client/types/tool_response.py b/src/llama_stack_client/types/tool_response.py new file mode 100644 index 00000000..aad08716 --- /dev/null +++ b/src/llama_stack_client/types/tool_response.py @@ -0,0 +1,21 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
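SyntheticDataGenerationResponse is flattened here: synthetic_data becomes a plain list of JSON-like dicts instead of nested dialog and scored-generation models, with statistics kept as an optional dict. A short sketch of reading it:

# Hedged sketch: iterate the flattened synthetic_data payload.
def summarize_synthetic_data(resp) -> None:  # resp: SyntheticDataGenerationResponse
    for item in resp.synthetic_data:         # List[Dict[str, ...]]
        print(sorted(item.keys()))
    if resp.statistics:                      # Optional[Dict[str, ...]]
        print("stats:", resp.statistics)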
+ +from typing import List, Union +from typing_extensions import Literal, TypeAlias + +from .._models import BaseModel +from .shared.image_media import ImageMedia + +__all__ = ["ToolResponse", "Content", "ContentUnionMember2"] + +ContentUnionMember2: TypeAlias = Union[str, ImageMedia] + +Content: TypeAlias = Union[str, ImageMedia, List[ContentUnionMember2]] + + +class ToolResponse(BaseModel): + call_id: str + + content: Content + + tool_name: Union[Literal["brave_search", "wolfram_alpha", "photogen", "code_interpreter"], str] diff --git a/tests/api_resources/agents/test_session.py b/tests/api_resources/agents/test_session.py new file mode 100644 index 00000000..9c3a0364 --- /dev/null +++ b/tests/api_resources/agents/test_session.py @@ -0,0 +1,285 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from tests.utils import assert_matches_type +from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient +from llama_stack_client.types.agents import ( + Session, + SessionCreateResponse, +) + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestSession: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_create(self, client: LlamaStackClient) -> None: + session = client.agents.session.create( + agent_id="agent_id", + session_name="session_name", + ) + assert_matches_type(SessionCreateResponse, session, path=["response"]) + + @parametrize + def test_method_create_with_all_params(self, client: LlamaStackClient) -> None: + session = client.agents.session.create( + agent_id="agent_id", + session_name="session_name", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(SessionCreateResponse, session, path=["response"]) + + @parametrize + def test_raw_response_create(self, client: LlamaStackClient) -> None: + response = client.agents.session.with_raw_response.create( + agent_id="agent_id", + session_name="session_name", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + session = response.parse() + assert_matches_type(SessionCreateResponse, session, path=["response"]) + + @parametrize + def test_streaming_response_create(self, client: LlamaStackClient) -> None: + with client.agents.session.with_streaming_response.create( + agent_id="agent_id", + session_name="session_name", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + session = response.parse() + assert_matches_type(SessionCreateResponse, session, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_retrieve(self, client: LlamaStackClient) -> None: + session = client.agents.session.retrieve( + agent_id="agent_id", + session_id="session_id", + ) + assert_matches_type(Session, session, path=["response"]) + + @parametrize + def test_method_retrieve_with_all_params(self, client: LlamaStackClient) -> None: + session = client.agents.session.retrieve( + agent_id="agent_id", + session_id="session_id", + turn_ids=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Session, session, path=["response"]) + + @parametrize + def test_raw_response_retrieve(self, client: LlamaStackClient) -> None: + response 
= client.agents.session.with_raw_response.retrieve( + agent_id="agent_id", + session_id="session_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + session = response.parse() + assert_matches_type(Session, session, path=["response"]) + + @parametrize + def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None: + with client.agents.session.with_streaming_response.retrieve( + agent_id="agent_id", + session_id="session_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + session = response.parse() + assert_matches_type(Session, session, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_delete(self, client: LlamaStackClient) -> None: + session = client.agents.session.delete( + agent_id="agent_id", + session_id="session_id", + ) + assert session is None + + @parametrize + def test_method_delete_with_all_params(self, client: LlamaStackClient) -> None: + session = client.agents.session.delete( + agent_id="agent_id", + session_id="session_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert session is None + + @parametrize + def test_raw_response_delete(self, client: LlamaStackClient) -> None: + response = client.agents.session.with_raw_response.delete( + agent_id="agent_id", + session_id="session_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + session = response.parse() + assert session is None + + @parametrize + def test_streaming_response_delete(self, client: LlamaStackClient) -> None: + with client.agents.session.with_streaming_response.delete( + agent_id="agent_id", + session_id="session_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + session = response.parse() + assert session is None + + assert cast(Any, response.is_closed) is True + + +class TestAsyncSession: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_create(self, async_client: AsyncLlamaStackClient) -> None: + session = await async_client.agents.session.create( + agent_id="agent_id", + session_name="session_name", + ) + assert_matches_type(SessionCreateResponse, session, path=["response"]) + + @parametrize + async def test_method_create_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + session = await async_client.agents.session.create( + agent_id="agent_id", + session_name="session_name", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(SessionCreateResponse, session, path=["response"]) + + @parametrize + async def test_raw_response_create(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.agents.session.with_raw_response.create( + agent_id="agent_id", + session_name="session_name", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + session = await response.parse() + assert_matches_type(SessionCreateResponse, session, path=["response"]) + + @parametrize + async def test_streaming_response_create(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.agents.session.with_streaming_response.create( + agent_id="agent_id", + session_name="session_name", + ) 
as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + session = await response.parse() + assert_matches_type(SessionCreateResponse, session, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + session = await async_client.agents.session.retrieve( + agent_id="agent_id", + session_id="session_id", + ) + assert_matches_type(Session, session, path=["response"]) + + @parametrize + async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + session = await async_client.agents.session.retrieve( + agent_id="agent_id", + session_id="session_id", + turn_ids=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Session, session, path=["response"]) + + @parametrize + async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.agents.session.with_raw_response.retrieve( + agent_id="agent_id", + session_id="session_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + session = await response.parse() + assert_matches_type(Session, session, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.agents.session.with_streaming_response.retrieve( + agent_id="agent_id", + session_id="session_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + session = await response.parse() + assert_matches_type(Session, session, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_delete(self, async_client: AsyncLlamaStackClient) -> None: + session = await async_client.agents.session.delete( + agent_id="agent_id", + session_id="session_id", + ) + assert session is None + + @parametrize + async def test_method_delete_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + session = await async_client.agents.session.delete( + agent_id="agent_id", + session_id="session_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert session is None + + @parametrize + async def test_raw_response_delete(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.agents.session.with_raw_response.delete( + agent_id="agent_id", + session_id="session_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + session = await response.parse() + assert session is None + + @parametrize + async def test_streaming_response_delete(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.agents.session.with_streaming_response.delete( + agent_id="agent_id", + session_id="session_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + session = await response.parse() + assert session is None + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/agents/test_steps.py b/tests/api_resources/agents/test_steps.py index 5c61a819..92b0b1f3 100644 --- a/tests/api_resources/agents/test_steps.py +++ b/tests/api_resources/agents/test_steps.py @@ -9,7 +9,7 @@ from 
tests.utils import assert_matches_type from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types.agents import AgentsStep +from llama_stack_client.types.agents import StepRetrieveResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -21,25 +21,28 @@ class TestSteps: def test_method_retrieve(self, client: LlamaStackClient) -> None: step = client.agents.steps.retrieve( agent_id="agent_id", + session_id="session_id", step_id="step_id", turn_id="turn_id", ) - assert_matches_type(AgentsStep, step, path=["response"]) + assert_matches_type(StepRetrieveResponse, step, path=["response"]) @parametrize def test_method_retrieve_with_all_params(self, client: LlamaStackClient) -> None: step = client.agents.steps.retrieve( agent_id="agent_id", + session_id="session_id", step_id="step_id", turn_id="turn_id", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(AgentsStep, step, path=["response"]) + assert_matches_type(StepRetrieveResponse, step, path=["response"]) @parametrize def test_raw_response_retrieve(self, client: LlamaStackClient) -> None: response = client.agents.steps.with_raw_response.retrieve( agent_id="agent_id", + session_id="session_id", step_id="step_id", turn_id="turn_id", ) @@ -47,12 +50,13 @@ def test_raw_response_retrieve(self, client: LlamaStackClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" step = response.parse() - assert_matches_type(AgentsStep, step, path=["response"]) + assert_matches_type(StepRetrieveResponse, step, path=["response"]) @parametrize def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None: with client.agents.steps.with_streaming_response.retrieve( agent_id="agent_id", + session_id="session_id", step_id="step_id", turn_id="turn_id", ) as response: @@ -60,7 +64,7 @@ def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" step = response.parse() - assert_matches_type(AgentsStep, step, path=["response"]) + assert_matches_type(StepRetrieveResponse, step, path=["response"]) assert cast(Any, response.is_closed) is True @@ -72,25 +76,28 @@ class TestAsyncSteps: async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> None: step = await async_client.agents.steps.retrieve( agent_id="agent_id", + session_id="session_id", step_id="step_id", turn_id="turn_id", ) - assert_matches_type(AgentsStep, step, path=["response"]) + assert_matches_type(StepRetrieveResponse, step, path=["response"]) @parametrize async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: step = await async_client.agents.steps.retrieve( agent_id="agent_id", + session_id="session_id", step_id="step_id", turn_id="turn_id", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(AgentsStep, step, path=["response"]) + assert_matches_type(StepRetrieveResponse, step, path=["response"]) @parametrize async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: response = await async_client.agents.steps.with_raw_response.retrieve( agent_id="agent_id", + session_id="session_id", step_id="step_id", turn_id="turn_id", ) @@ -98,12 +105,13 @@ async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == 
"python" step = await response.parse() - assert_matches_type(AgentsStep, step, path=["response"]) + assert_matches_type(StepRetrieveResponse, step, path=["response"]) @parametrize async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: async with async_client.agents.steps.with_streaming_response.retrieve( agent_id="agent_id", + session_id="session_id", step_id="step_id", turn_id="turn_id", ) as response: @@ -111,6 +119,6 @@ async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackCl assert response.http_request.headers.get("X-Stainless-Lang") == "python" step = await response.parse() - assert_matches_type(AgentsStep, step, path=["response"]) + assert_matches_type(StepRetrieveResponse, step, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/agents/test_turn.py b/tests/api_resources/agents/test_turn.py new file mode 100644 index 00000000..5e51e7fe --- /dev/null +++ b/tests/api_resources/agents/test_turn.py @@ -0,0 +1,636 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from tests.utils import assert_matches_type +from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient +from llama_stack_client.types.agents import Turn, TurnCreateResponse + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestTurn: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_method_create_overload_1(self, client: LlamaStackClient) -> None: + turn = client.agents.turn.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + ) + assert_matches_type(TurnCreateResponse, turn, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_method_create_with_all_params_overload_1(self, client: LlamaStackClient) -> None: + turn = client.agents.turn.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + session_id="session_id", + attachments=[ + { + "content": "string", + "mime_type": "mime_type", + }, + { + "content": "string", + "mime_type": "mime_type", + }, + { + "content": "string", + "mime_type": "mime_type", + }, + ], + stream=False, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(TurnCreateResponse, turn, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_raw_response_create_overload_1(self, client: LlamaStackClient) -> None: + response = client.agents.turn.with_raw_response.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": 
"string", + "role": "user", + }, + ], + session_id="session_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + turn = response.parse() + assert_matches_type(TurnCreateResponse, turn, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_streaming_response_create_overload_1(self, client: LlamaStackClient) -> None: + with client.agents.turn.with_streaming_response.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + turn = response.parse() + assert_matches_type(TurnCreateResponse, turn, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_method_create_overload_2(self, client: LlamaStackClient) -> None: + turn_stream = client.agents.turn.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + stream=True, + ) + turn_stream.response.close() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_method_create_with_all_params_overload_2(self, client: LlamaStackClient) -> None: + turn_stream = client.agents.turn.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + session_id="session_id", + stream=True, + attachments=[ + { + "content": "string", + "mime_type": "mime_type", + }, + { + "content": "string", + "mime_type": "mime_type", + }, + { + "content": "string", + "mime_type": "mime_type", + }, + ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + turn_stream.response.close() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_raw_response_create_overload_2(self, client: LlamaStackClient) -> None: + response = client.agents.turn.with_raw_response.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = response.parse() + stream.close() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + def test_streaming_response_create_overload_2(self, client: LlamaStackClient) -> None: + with client.agents.turn.with_streaming_response.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + 
{ + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_retrieve(self, client: LlamaStackClient) -> None: + turn = client.agents.turn.retrieve( + agent_id="agent_id", + session_id="session_id", + turn_id="turn_id", + ) + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + def test_method_retrieve_with_all_params(self, client: LlamaStackClient) -> None: + turn = client.agents.turn.retrieve( + agent_id="agent_id", + session_id="session_id", + turn_id="turn_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + def test_raw_response_retrieve(self, client: LlamaStackClient) -> None: + response = client.agents.turn.with_raw_response.retrieve( + agent_id="agent_id", + session_id="session_id", + turn_id="turn_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + turn = response.parse() + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None: + with client.agents.turn.with_streaming_response.retrieve( + agent_id="agent_id", + session_id="session_id", + turn_id="turn_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + turn = response.parse() + assert_matches_type(Turn, turn, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncTurn: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_method_create_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + turn = await async_client.agents.turn.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + ) + assert_matches_type(TurnCreateResponse, turn, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_method_create_with_all_params_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + turn = await async_client.agents.turn.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + session_id="session_id", + attachments=[ + { + "content": "string", + "mime_type": "mime_type", + }, + { + "content": "string", + "mime_type": "mime_type", + }, + { + "content": "string", + "mime_type": "mime_type", + }, + ], + stream=False, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(TurnCreateResponse, turn, path=["response"]) + + @pytest.mark.skip( + 
reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_raw_response_create_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.agents.turn.with_raw_response.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + turn = await response.parse() + assert_matches_type(TurnCreateResponse, turn, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_streaming_response_create_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.agents.turn.with_streaming_response.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + turn = await response.parse() + assert_matches_type(TurnCreateResponse, turn, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_method_create_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + turn_stream = await async_client.agents.turn.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + stream=True, + ) + await turn_stream.response.aclose() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_method_create_with_all_params_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + turn_stream = await async_client.agents.turn.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + session_id="session_id", + stream=True, + attachments=[ + { + "content": "string", + "mime_type": "mime_type", + }, + { + "content": "string", + "mime_type": "mime_type", + }, + { + "content": "string", + "mime_type": "mime_type", + }, + ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + await turn_stream.response.aclose() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_raw_response_create_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.agents.turn.with_raw_response.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": 
"string", + "role": "user", + }, + ], + session_id="session_id", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = await response.parse() + await stream.close() + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) + @parametrize + async def test_streaming_response_create_overload_2(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.agents.turn.with_streaming_response.create( + agent_id="agent_id", + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + session_id="session_id", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + turn = await async_client.agents.turn.retrieve( + agent_id="agent_id", + session_id="session_id", + turn_id="turn_id", + ) + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + turn = await async_client.agents.turn.retrieve( + agent_id="agent_id", + session_id="session_id", + turn_id="turn_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.agents.turn.with_raw_response.retrieve( + agent_id="agent_id", + session_id="session_id", + turn_id="turn_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + turn = await response.parse() + assert_matches_type(Turn, turn, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.agents.turn.with_streaming_response.retrieve( + agent_id="agent_id", + session_id="session_id", + turn_id="turn_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + turn = await response.parse() + assert_matches_type(Turn, turn, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/agents/test_turns.py b/tests/api_resources/agents/test_turns.py index e4066f9e..30510a23 100644 --- a/tests/api_resources/agents/test_turns.py +++ b/tests/api_resources/agents/test_turns.py @@ -9,20 +9,17 @@ from tests.utils import assert_matches_type from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types.agents import Turn, TurnCreateResponse +from llama_stack_client.types.agents import Turn, AgentsTurnStreamChunk base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") -class TestTurn: +class TestTurns: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize def test_method_create_overload_1(self, client: LlamaStackClient) -> 
None: - turn = client.agents.turn.create( + turn = client.agents.turns.create( agent_id="agent_id", messages=[ { @@ -40,14 +37,11 @@ def test_method_create_overload_1(self, client: LlamaStackClient) -> None: ], session_id="session_id", ) - assert_matches_type(TurnCreateResponse, turn, path=["response"]) + assert_matches_type(AgentsTurnStreamChunk, turn, path=["response"]) - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize def test_method_create_with_all_params_overload_1(self, client: LlamaStackClient) -> None: - turn = client.agents.turn.create( + turn = client.agents.turns.create( agent_id="agent_id", messages=[ { @@ -84,14 +78,11 @@ def test_method_create_with_all_params_overload_1(self, client: LlamaStackClient stream=False, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(TurnCreateResponse, turn, path=["response"]) + assert_matches_type(AgentsTurnStreamChunk, turn, path=["response"]) - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize def test_raw_response_create_overload_1(self, client: LlamaStackClient) -> None: - response = client.agents.turn.with_raw_response.create( + response = client.agents.turns.with_raw_response.create( agent_id="agent_id", messages=[ { @@ -113,14 +104,11 @@ def test_raw_response_create_overload_1(self, client: LlamaStackClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" turn = response.parse() - assert_matches_type(TurnCreateResponse, turn, path=["response"]) + assert_matches_type(AgentsTurnStreamChunk, turn, path=["response"]) - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize def test_streaming_response_create_overload_1(self, client: LlamaStackClient) -> None: - with client.agents.turn.with_streaming_response.create( + with client.agents.turns.with_streaming_response.create( agent_id="agent_id", messages=[ { @@ -142,16 +130,13 @@ def test_streaming_response_create_overload_1(self, client: LlamaStackClient) -> assert response.http_request.headers.get("X-Stainless-Lang") == "python" turn = response.parse() - assert_matches_type(TurnCreateResponse, turn, path=["response"]) + assert_matches_type(AgentsTurnStreamChunk, turn, path=["response"]) assert cast(Any, response.is_closed) is True - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize def test_method_create_overload_2(self, client: LlamaStackClient) -> None: - turn_stream = client.agents.turn.create( + turn_stream = client.agents.turns.create( agent_id="agent_id", messages=[ { @@ -172,12 +157,9 @@ def test_method_create_overload_2(self, client: LlamaStackClient) -> None: ) turn_stream.response.close() - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize def test_method_create_with_all_params_overload_2(self, client: LlamaStackClient) -> None: - turn_stream = client.agents.turn.create( + turn_stream = client.agents.turns.create( agent_id="agent_id", messages=[ { @@ -216,12 +198,9 @@ def test_method_create_with_all_params_overload_2(self, client: LlamaStackClient ) turn_stream.response.close() - 
@pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize def test_raw_response_create_overload_2(self, client: LlamaStackClient) -> None: - response = client.agents.turn.with_raw_response.create( + response = client.agents.turns.with_raw_response.create( agent_id="agent_id", messages=[ { @@ -245,12 +224,9 @@ def test_raw_response_create_overload_2(self, client: LlamaStackClient) -> None: stream = response.parse() stream.close() - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize def test_streaming_response_create_overload_2(self, client: LlamaStackClient) -> None: - with client.agents.turn.with_streaming_response.create( + with client.agents.turns.with_streaming_response.create( agent_id="agent_id", messages=[ { @@ -279,7 +255,7 @@ def test_streaming_response_create_overload_2(self, client: LlamaStackClient) -> @parametrize def test_method_retrieve(self, client: LlamaStackClient) -> None: - turn = client.agents.turn.retrieve( + turn = client.agents.turns.retrieve( agent_id="agent_id", turn_id="turn_id", ) @@ -287,7 +263,7 @@ def test_method_retrieve(self, client: LlamaStackClient) -> None: @parametrize def test_method_retrieve_with_all_params(self, client: LlamaStackClient) -> None: - turn = client.agents.turn.retrieve( + turn = client.agents.turns.retrieve( agent_id="agent_id", turn_id="turn_id", x_llama_stack_provider_data="X-LlamaStack-ProviderData", @@ -296,7 +272,7 @@ def test_method_retrieve_with_all_params(self, client: LlamaStackClient) -> None @parametrize def test_raw_response_retrieve(self, client: LlamaStackClient) -> None: - response = client.agents.turn.with_raw_response.retrieve( + response = client.agents.turns.with_raw_response.retrieve( agent_id="agent_id", turn_id="turn_id", ) @@ -308,7 +284,7 @@ def test_raw_response_retrieve(self, client: LlamaStackClient) -> None: @parametrize def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None: - with client.agents.turn.with_streaming_response.retrieve( + with client.agents.turns.with_streaming_response.retrieve( agent_id="agent_id", turn_id="turn_id", ) as response: @@ -321,15 +297,12 @@ def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None: assert cast(Any, response.is_closed) is True -class TestAsyncTurn: +class TestAsyncTurns: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize async def test_method_create_overload_1(self, async_client: AsyncLlamaStackClient) -> None: - turn = await async_client.agents.turn.create( + turn = await async_client.agents.turns.create( agent_id="agent_id", messages=[ { @@ -347,14 +320,11 @@ async def test_method_create_overload_1(self, async_client: AsyncLlamaStackClien ], session_id="session_id", ) - assert_matches_type(TurnCreateResponse, turn, path=["response"]) + assert_matches_type(AgentsTurnStreamChunk, turn, path=["response"]) - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize async def test_method_create_with_all_params_overload_1(self, async_client: AsyncLlamaStackClient) -> None: - turn = await async_client.agents.turn.create( + turn = await 
async_client.agents.turns.create( agent_id="agent_id", messages=[ { @@ -391,14 +361,11 @@ async def test_method_create_with_all_params_overload_1(self, async_client: Asyn stream=False, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(TurnCreateResponse, turn, path=["response"]) + assert_matches_type(AgentsTurnStreamChunk, turn, path=["response"]) - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize async def test_raw_response_create_overload_1(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.agents.turn.with_raw_response.create( + response = await async_client.agents.turns.with_raw_response.create( agent_id="agent_id", messages=[ { @@ -420,14 +387,11 @@ async def test_raw_response_create_overload_1(self, async_client: AsyncLlamaStac assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" turn = await response.parse() - assert_matches_type(TurnCreateResponse, turn, path=["response"]) + assert_matches_type(AgentsTurnStreamChunk, turn, path=["response"]) - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize async def test_streaming_response_create_overload_1(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.agents.turn.with_streaming_response.create( + async with async_client.agents.turns.with_streaming_response.create( agent_id="agent_id", messages=[ { @@ -449,16 +413,13 @@ async def test_streaming_response_create_overload_1(self, async_client: AsyncLla assert response.http_request.headers.get("X-Stainless-Lang") == "python" turn = await response.parse() - assert_matches_type(TurnCreateResponse, turn, path=["response"]) + assert_matches_type(AgentsTurnStreamChunk, turn, path=["response"]) assert cast(Any, response.is_closed) is True - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize async def test_method_create_overload_2(self, async_client: AsyncLlamaStackClient) -> None: - turn_stream = await async_client.agents.turn.create( + turn_stream = await async_client.agents.turns.create( agent_id="agent_id", messages=[ { @@ -479,12 +440,9 @@ async def test_method_create_overload_2(self, async_client: AsyncLlamaStackClien ) await turn_stream.response.aclose() - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize async def test_method_create_with_all_params_overload_2(self, async_client: AsyncLlamaStackClient) -> None: - turn_stream = await async_client.agents.turn.create( + turn_stream = await async_client.agents.turns.create( agent_id="agent_id", messages=[ { @@ -523,12 +481,9 @@ async def test_method_create_with_all_params_overload_2(self, async_client: Asyn ) await turn_stream.response.aclose() - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize async def test_raw_response_create_overload_2(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.agents.turn.with_raw_response.create( + response = await async_client.agents.turns.with_raw_response.create( agent_id="agent_id", messages=[ { @@ -552,12 +507,9 @@ async def 
test_raw_response_create_overload_2(self, async_client: AsyncLlamaStac stream = await response.parse() await stream.close() - @pytest.mark.skip( - reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" - ) @parametrize async def test_streaming_response_create_overload_2(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.agents.turn.with_streaming_response.create( + async with async_client.agents.turns.with_streaming_response.create( agent_id="agent_id", messages=[ { @@ -586,7 +538,7 @@ async def test_streaming_response_create_overload_2(self, async_client: AsyncLla @parametrize async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> None: - turn = await async_client.agents.turn.retrieve( + turn = await async_client.agents.turns.retrieve( agent_id="agent_id", turn_id="turn_id", ) @@ -594,7 +546,7 @@ async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> Non @parametrize async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - turn = await async_client.agents.turn.retrieve( + turn = await async_client.agents.turns.retrieve( agent_id="agent_id", turn_id="turn_id", x_llama_stack_provider_data="X-LlamaStack-ProviderData", @@ -603,7 +555,7 @@ async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaSta @parametrize async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.agents.turn.with_raw_response.retrieve( + response = await async_client.agents.turns.with_raw_response.retrieve( agent_id="agent_id", turn_id="turn_id", ) @@ -615,7 +567,7 @@ async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) @parametrize async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.agents.turn.with_streaming_response.retrieve( + async with async_client.agents.turns.with_streaming_response.retrieve( agent_id="agent_id", turn_id="turn_id", ) as response: diff --git a/tests/api_resources/eval/__init__.py b/tests/api_resources/eval/__init__.py new file mode 100644 index 00000000..fd8019a9 --- /dev/null +++ b/tests/api_resources/eval/__init__.py @@ -0,0 +1 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. diff --git a/tests/api_resources/eval/test_job.py b/tests/api_resources/eval/test_job.py new file mode 100644 index 00000000..ffffca71 --- /dev/null +++ b/tests/api_resources/eval/test_job.py @@ -0,0 +1,259 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from __future__ import annotations + +import os +from typing import Any, Optional, cast + +import pytest + +from tests.utils import assert_matches_type +from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient +from llama_stack_client.types.eval import ( + JobStatus, + JobResultResponse, +) + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestJob: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_cancel(self, client: LlamaStackClient) -> None: + job = client.eval.job.cancel( + job_id="job_id", + ) + assert job is None + + @parametrize + def test_method_cancel_with_all_params(self, client: LlamaStackClient) -> None: + job = client.eval.job.cancel( + job_id="job_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert job is None + + @parametrize + def test_raw_response_cancel(self, client: LlamaStackClient) -> None: + response = client.eval.job.with_raw_response.cancel( + job_id="job_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = response.parse() + assert job is None + + @parametrize + def test_streaming_response_cancel(self, client: LlamaStackClient) -> None: + with client.eval.job.with_streaming_response.cancel( + job_id="job_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = response.parse() + assert job is None + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_result(self, client: LlamaStackClient) -> None: + job = client.eval.job.result( + job_id="job_id", + ) + assert_matches_type(JobResultResponse, job, path=["response"]) + + @parametrize + def test_method_result_with_all_params(self, client: LlamaStackClient) -> None: + job = client.eval.job.result( + job_id="job_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(JobResultResponse, job, path=["response"]) + + @parametrize + def test_raw_response_result(self, client: LlamaStackClient) -> None: + response = client.eval.job.with_raw_response.result( + job_id="job_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = response.parse() + assert_matches_type(JobResultResponse, job, path=["response"]) + + @parametrize + def test_streaming_response_result(self, client: LlamaStackClient) -> None: + with client.eval.job.with_streaming_response.result( + job_id="job_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = response.parse() + assert_matches_type(JobResultResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_status(self, client: LlamaStackClient) -> None: + job = client.eval.job.status( + job_id="job_id", + ) + assert_matches_type(Optional[JobStatus], job, path=["response"]) + + @parametrize + def test_method_status_with_all_params(self, client: LlamaStackClient) -> None: + job = client.eval.job.status( + job_id="job_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[JobStatus], job, path=["response"]) + + @parametrize + def test_raw_response_status(self, client: LlamaStackClient) -> None: + response = client.eval.job.with_raw_response.status( + job_id="job_id", + ) 
+ + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = response.parse() + assert_matches_type(Optional[JobStatus], job, path=["response"]) + + @parametrize + def test_streaming_response_status(self, client: LlamaStackClient) -> None: + with client.eval.job.with_streaming_response.status( + job_id="job_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = response.parse() + assert_matches_type(Optional[JobStatus], job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncJob: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_cancel(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.eval.job.cancel( + job_id="job_id", + ) + assert job is None + + @parametrize + async def test_method_cancel_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.eval.job.cancel( + job_id="job_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert job is None + + @parametrize + async def test_raw_response_cancel(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.eval.job.with_raw_response.cancel( + job_id="job_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = await response.parse() + assert job is None + + @parametrize + async def test_streaming_response_cancel(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.eval.job.with_streaming_response.cancel( + job_id="job_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = await response.parse() + assert job is None + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_result(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.eval.job.result( + job_id="job_id", + ) + assert_matches_type(JobResultResponse, job, path=["response"]) + + @parametrize + async def test_method_result_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.eval.job.result( + job_id="job_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(JobResultResponse, job, path=["response"]) + + @parametrize + async def test_raw_response_result(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.eval.job.with_raw_response.result( + job_id="job_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = await response.parse() + assert_matches_type(JobResultResponse, job, path=["response"]) + + @parametrize + async def test_streaming_response_result(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.eval.job.with_streaming_response.result( + job_id="job_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = await response.parse() + assert_matches_type(JobResultResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_status(self, async_client: AsyncLlamaStackClient) -> None: + job = await 
async_client.eval.job.status( + job_id="job_id", + ) + assert_matches_type(Optional[JobStatus], job, path=["response"]) + + @parametrize + async def test_method_status_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.eval.job.status( + job_id="job_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[JobStatus], job, path=["response"]) + + @parametrize + async def test_raw_response_status(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.eval.job.with_raw_response.status( + job_id="job_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = await response.parse() + assert_matches_type(Optional[JobStatus], job, path=["response"]) + + @parametrize + async def test_streaming_response_status(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.eval.job.with_streaming_response.status( + job_id="job_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = await response.parse() + assert_matches_type(Optional[JobStatus], job, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/evaluate/test_jobs.py b/tests/api_resources/evaluate/test_jobs.py index 8a5a35b9..6a60f944 100644 --- a/tests/api_resources/evaluate/test_jobs.py +++ b/tests/api_resources/evaluate/test_jobs.py @@ -3,13 +3,16 @@ from __future__ import annotations import os -from typing import Any, cast +from typing import Any, Optional, cast import pytest from tests.utils import assert_matches_type from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types import EvaluationJob +from llama_stack_client.types.evaluate import ( + JobStatus, + EvaluateResponse, +) base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -18,123 +21,137 @@ class TestJobs: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - def test_method_list(self, client: LlamaStackClient) -> None: - job = client.evaluate.jobs.list() - assert_matches_type(EvaluationJob, job, path=["response"]) + def test_method_cancel(self, client: LlamaStackClient) -> None: + job = client.evaluate.jobs.cancel( + job_id="job_id", + ) + assert job is None @parametrize - def test_method_list_with_all_params(self, client: LlamaStackClient) -> None: - job = client.evaluate.jobs.list( + def test_method_cancel_with_all_params(self, client: LlamaStackClient) -> None: + job = client.evaluate.jobs.cancel( + job_id="job_id", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(EvaluationJob, job, path=["response"]) + assert job is None @parametrize - def test_raw_response_list(self, client: LlamaStackClient) -> None: - response = client.evaluate.jobs.with_raw_response.list() + def test_raw_response_cancel(self, client: LlamaStackClient) -> None: + response = client.evaluate.jobs.with_raw_response.cancel( + job_id="job_id", + ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" job = response.parse() - assert_matches_type(EvaluationJob, job, path=["response"]) + assert job is None @parametrize - def test_streaming_response_list(self, client: LlamaStackClient) -> None: - with client.evaluate.jobs.with_streaming_response.list() as response: + def 
test_streaming_response_cancel(self, client: LlamaStackClient) -> None: + with client.evaluate.jobs.with_streaming_response.cancel( + job_id="job_id", + ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" job = response.parse() - assert_matches_type(EvaluationJob, job, path=["response"]) + assert job is None assert cast(Any, response.is_closed) is True @parametrize - def test_method_cancel(self, client: LlamaStackClient) -> None: - job = client.evaluate.jobs.cancel( - job_uuid="job_uuid", + def test_method_result(self, client: LlamaStackClient) -> None: + job = client.evaluate.jobs.result( + job_id="job_id", ) - assert job is None + assert_matches_type(EvaluateResponse, job, path=["response"]) @parametrize - def test_method_cancel_with_all_params(self, client: LlamaStackClient) -> None: - job = client.evaluate.jobs.cancel( - job_uuid="job_uuid", + def test_method_result_with_all_params(self, client: LlamaStackClient) -> None: + job = client.evaluate.jobs.result( + job_id="job_id", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert job is None + assert_matches_type(EvaluateResponse, job, path=["response"]) @parametrize - def test_raw_response_cancel(self, client: LlamaStackClient) -> None: - response = client.evaluate.jobs.with_raw_response.cancel( - job_uuid="job_uuid", + def test_raw_response_result(self, client: LlamaStackClient) -> None: + response = client.evaluate.jobs.with_raw_response.result( + job_id="job_id", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" job = response.parse() - assert job is None + assert_matches_type(EvaluateResponse, job, path=["response"]) @parametrize - def test_streaming_response_cancel(self, client: LlamaStackClient) -> None: - with client.evaluate.jobs.with_streaming_response.cancel( - job_uuid="job_uuid", + def test_streaming_response_result(self, client: LlamaStackClient) -> None: + with client.evaluate.jobs.with_streaming_response.result( + job_id="job_id", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" job = response.parse() - assert job is None + assert_matches_type(EvaluateResponse, job, path=["response"]) assert cast(Any, response.is_closed) is True - -class TestAsyncJobs: - parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) - @parametrize - async def test_method_list(self, async_client: AsyncLlamaStackClient) -> None: - job = await async_client.evaluate.jobs.list() - assert_matches_type(EvaluationJob, job, path=["response"]) + def test_method_status(self, client: LlamaStackClient) -> None: + job = client.evaluate.jobs.status( + job_id="job_id", + ) + assert_matches_type(Optional[JobStatus], job, path=["response"]) @parametrize - async def test_method_list_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - job = await async_client.evaluate.jobs.list( + def test_method_status_with_all_params(self, client: LlamaStackClient) -> None: + job = client.evaluate.jobs.status( + job_id="job_id", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(EvaluationJob, job, path=["response"]) + assert_matches_type(Optional[JobStatus], job, path=["response"]) @parametrize - async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.evaluate.jobs.with_raw_response.list() + def 
test_raw_response_status(self, client: LlamaStackClient) -> None: + response = client.evaluate.jobs.with_raw_response.status( + job_id="job_id", + ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - job = await response.parse() - assert_matches_type(EvaluationJob, job, path=["response"]) + job = response.parse() + assert_matches_type(Optional[JobStatus], job, path=["response"]) @parametrize - async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.evaluate.jobs.with_streaming_response.list() as response: + def test_streaming_response_status(self, client: LlamaStackClient) -> None: + with client.evaluate.jobs.with_streaming_response.status( + job_id="job_id", + ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" - job = await response.parse() - assert_matches_type(EvaluationJob, job, path=["response"]) + job = response.parse() + assert_matches_type(Optional[JobStatus], job, path=["response"]) assert cast(Any, response.is_closed) is True + +class TestAsyncJobs: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + @parametrize async def test_method_cancel(self, async_client: AsyncLlamaStackClient) -> None: job = await async_client.evaluate.jobs.cancel( - job_uuid="job_uuid", + job_id="job_id", ) assert job is None @parametrize async def test_method_cancel_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: job = await async_client.evaluate.jobs.cancel( - job_uuid="job_uuid", + job_id="job_id", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert job is None @@ -142,7 +159,7 @@ async def test_method_cancel_with_all_params(self, async_client: AsyncLlamaStack @parametrize async def test_raw_response_cancel(self, async_client: AsyncLlamaStackClient) -> None: response = await async_client.evaluate.jobs.with_raw_response.cancel( - job_uuid="job_uuid", + job_id="job_id", ) assert response.is_closed is True @@ -153,7 +170,7 @@ async def test_raw_response_cancel(self, async_client: AsyncLlamaStackClient) -> @parametrize async def test_streaming_response_cancel(self, async_client: AsyncLlamaStackClient) -> None: async with async_client.evaluate.jobs.with_streaming_response.cancel( - job_uuid="job_uuid", + job_id="job_id", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -162,3 +179,81 @@ async def test_streaming_response_cancel(self, async_client: AsyncLlamaStackClie assert job is None assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_result(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.evaluate.jobs.result( + job_id="job_id", + ) + assert_matches_type(EvaluateResponse, job, path=["response"]) + + @parametrize + async def test_method_result_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.evaluate.jobs.result( + job_id="job_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluateResponse, job, path=["response"]) + + @parametrize + async def test_raw_response_result(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.evaluate.jobs.with_raw_response.result( + job_id="job_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") 
== "python" + job = await response.parse() + assert_matches_type(EvaluateResponse, job, path=["response"]) + + @parametrize + async def test_streaming_response_result(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.evaluate.jobs.with_streaming_response.result( + job_id="job_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = await response.parse() + assert_matches_type(EvaluateResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_status(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.evaluate.jobs.status( + job_id="job_id", + ) + assert_matches_type(Optional[JobStatus], job, path=["response"]) + + @parametrize + async def test_method_status_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.evaluate.jobs.status( + job_id="job_id", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[JobStatus], job, path=["response"]) + + @parametrize + async def test_raw_response_status(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.evaluate.jobs.with_raw_response.status( + job_id="job_id", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = await response.parse() + assert_matches_type(Optional[JobStatus], job, path=["response"]) + + @parametrize + async def test_streaming_response_status(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.evaluate.jobs.with_streaming_response.status( + job_id="job_id", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = await response.parse() + assert_matches_type(Optional[JobStatus], job, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/memory/test_documents.py b/tests/api_resources/memory/test_documents.py index d4041354..765aa1a2 100644 --- a/tests/api_resources/memory/test_documents.py +++ b/tests/api_resources/memory/test_documents.py @@ -17,6 +17,9 @@ class TestDocuments: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_method_retrieve(self, client: LlamaStackClient) -> None: document = client.memory.documents.retrieve( @@ -25,6 +28,9 @@ def test_method_retrieve(self, client: LlamaStackClient) -> None: ) assert_matches_type(DocumentRetrieveResponse, document, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_method_retrieve_with_all_params(self, client: LlamaStackClient) -> None: document = client.memory.documents.retrieve( @@ -34,6 +40,9 @@ def test_method_retrieve_with_all_params(self, client: LlamaStackClient) -> None ) assert_matches_type(DocumentRetrieveResponse, document, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_raw_response_retrieve(self, client: LlamaStackClient) -> None: response = 
client.memory.documents.with_raw_response.retrieve( @@ -46,6 +55,9 @@ def test_raw_response_retrieve(self, client: LlamaStackClient) -> None: document = response.parse() assert_matches_type(DocumentRetrieveResponse, document, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None: with client.memory.documents.with_streaming_response.retrieve( @@ -107,6 +119,9 @@ def test_streaming_response_delete(self, client: LlamaStackClient) -> None: class TestAsyncDocuments: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> None: document = await async_client.memory.documents.retrieve( @@ -115,6 +130,9 @@ async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> Non ) assert_matches_type(DocumentRetrieveResponse, document, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: document = await async_client.memory.documents.retrieve( @@ -124,6 +142,9 @@ async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaSta ) assert_matches_type(DocumentRetrieveResponse, document, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: response = await async_client.memory.documents.with_raw_response.retrieve( @@ -136,6 +157,9 @@ async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) document = await response.parse() assert_matches_type(DocumentRetrieveResponse, document, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: async with async_client.memory.documents.with_streaming_response.retrieve( diff --git a/tests/api_resources/post_training/test_job.py b/tests/api_resources/post_training/test_job.py new file mode 100644 index 00000000..db6e5e54 --- /dev/null +++ b/tests/api_resources/post_training/test_job.py @@ -0,0 +1,427 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from tests.utils import assert_matches_type +from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient +from llama_stack_client.types import PostTrainingJob +from llama_stack_client.types.post_training import ( + JobLogsResponse, + JobStatusResponse, + JobArtifactsResponse, +) + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestJob: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + def test_method_list(self, client: LlamaStackClient) -> None: + job = client.post_training.job.list() + assert_matches_type(PostTrainingJob, job, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + def test_method_list_with_all_params(self, client: LlamaStackClient) -> None: + job = client.post_training.job.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(PostTrainingJob, job, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + def test_raw_response_list(self, client: LlamaStackClient) -> None: + response = client.post_training.job.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = response.parse() + assert_matches_type(PostTrainingJob, job, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + def test_streaming_response_list(self, client: LlamaStackClient) -> None: + with client.post_training.job.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = response.parse() + assert_matches_type(PostTrainingJob, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_artifacts(self, client: LlamaStackClient) -> None: + job = client.post_training.job.artifacts( + job_uuid="job_uuid", + ) + assert_matches_type(JobArtifactsResponse, job, path=["response"]) + + @parametrize + def test_method_artifacts_with_all_params(self, client: LlamaStackClient) -> None: + job = client.post_training.job.artifacts( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(JobArtifactsResponse, job, path=["response"]) + + @parametrize + def test_raw_response_artifacts(self, client: LlamaStackClient) -> None: + response = client.post_training.job.with_raw_response.artifacts( + job_uuid="job_uuid", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = response.parse() + assert_matches_type(JobArtifactsResponse, job, path=["response"]) + + @parametrize + def test_streaming_response_artifacts(self, client: LlamaStackClient) -> None: + with client.post_training.job.with_streaming_response.artifacts( + job_uuid="job_uuid", + ) as response: + assert not response.is_closed + assert 
response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = response.parse() + assert_matches_type(JobArtifactsResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_cancel(self, client: LlamaStackClient) -> None: + job = client.post_training.job.cancel( + job_uuid="job_uuid", + ) + assert job is None + + @parametrize + def test_method_cancel_with_all_params(self, client: LlamaStackClient) -> None: + job = client.post_training.job.cancel( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert job is None + + @parametrize + def test_raw_response_cancel(self, client: LlamaStackClient) -> None: + response = client.post_training.job.with_raw_response.cancel( + job_uuid="job_uuid", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = response.parse() + assert job is None + + @parametrize + def test_streaming_response_cancel(self, client: LlamaStackClient) -> None: + with client.post_training.job.with_streaming_response.cancel( + job_uuid="job_uuid", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = response.parse() + assert job is None + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_logs(self, client: LlamaStackClient) -> None: + job = client.post_training.job.logs( + job_uuid="job_uuid", + ) + assert_matches_type(JobLogsResponse, job, path=["response"]) + + @parametrize + def test_method_logs_with_all_params(self, client: LlamaStackClient) -> None: + job = client.post_training.job.logs( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(JobLogsResponse, job, path=["response"]) + + @parametrize + def test_raw_response_logs(self, client: LlamaStackClient) -> None: + response = client.post_training.job.with_raw_response.logs( + job_uuid="job_uuid", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = response.parse() + assert_matches_type(JobLogsResponse, job, path=["response"]) + + @parametrize + def test_streaming_response_logs(self, client: LlamaStackClient) -> None: + with client.post_training.job.with_streaming_response.logs( + job_uuid="job_uuid", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = response.parse() + assert_matches_type(JobLogsResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_status(self, client: LlamaStackClient) -> None: + job = client.post_training.job.status( + job_uuid="job_uuid", + ) + assert_matches_type(JobStatusResponse, job, path=["response"]) + + @parametrize + def test_method_status_with_all_params(self, client: LlamaStackClient) -> None: + job = client.post_training.job.status( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(JobStatusResponse, job, path=["response"]) + + @parametrize + def test_raw_response_status(self, client: LlamaStackClient) -> None: + response = client.post_training.job.with_raw_response.status( + job_uuid="job_uuid", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = response.parse() + assert_matches_type(JobStatusResponse, job, 
path=["response"]) + + @parametrize + def test_streaming_response_status(self, client: LlamaStackClient) -> None: + with client.post_training.job.with_streaming_response.status( + job_uuid="job_uuid", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = response.parse() + assert_matches_type(JobStatusResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncJob: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + async def test_method_list(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.post_training.job.list() + assert_matches_type(PostTrainingJob, job, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.post_training.job.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(PostTrainingJob, job, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.post_training.job.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = await response.parse() + assert_matches_type(PostTrainingJob, job, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.post_training.job.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = await response.parse() + assert_matches_type(PostTrainingJob, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_artifacts(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.post_training.job.artifacts( + job_uuid="job_uuid", + ) + assert_matches_type(JobArtifactsResponse, job, path=["response"]) + + @parametrize + async def test_method_artifacts_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.post_training.job.artifacts( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(JobArtifactsResponse, job, path=["response"]) + + @parametrize + async def test_raw_response_artifacts(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.post_training.job.with_raw_response.artifacts( + job_uuid="job_uuid", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = await response.parse() + assert_matches_type(JobArtifactsResponse, job, path=["response"]) + + @parametrize + async def 
test_streaming_response_artifacts(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.post_training.job.with_streaming_response.artifacts( + job_uuid="job_uuid", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = await response.parse() + assert_matches_type(JobArtifactsResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_cancel(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.post_training.job.cancel( + job_uuid="job_uuid", + ) + assert job is None + + @parametrize + async def test_method_cancel_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.post_training.job.cancel( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert job is None + + @parametrize + async def test_raw_response_cancel(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.post_training.job.with_raw_response.cancel( + job_uuid="job_uuid", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = await response.parse() + assert job is None + + @parametrize + async def test_streaming_response_cancel(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.post_training.job.with_streaming_response.cancel( + job_uuid="job_uuid", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = await response.parse() + assert job is None + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_logs(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.post_training.job.logs( + job_uuid="job_uuid", + ) + assert_matches_type(JobLogsResponse, job, path=["response"]) + + @parametrize + async def test_method_logs_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.post_training.job.logs( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(JobLogsResponse, job, path=["response"]) + + @parametrize + async def test_raw_response_logs(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.post_training.job.with_raw_response.logs( + job_uuid="job_uuid", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = await response.parse() + assert_matches_type(JobLogsResponse, job, path=["response"]) + + @parametrize + async def test_streaming_response_logs(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.post_training.job.with_streaming_response.logs( + job_uuid="job_uuid", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = await response.parse() + assert_matches_type(JobLogsResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_status(self, async_client: AsyncLlamaStackClient) -> None: + job = await async_client.post_training.job.status( + job_uuid="job_uuid", + ) + assert_matches_type(JobStatusResponse, job, path=["response"]) + + @parametrize + async def test_method_status_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + 
job = await async_client.post_training.job.status( + job_uuid="job_uuid", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(JobStatusResponse, job, path=["response"]) + + @parametrize + async def test_raw_response_status(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.post_training.job.with_raw_response.status( + job_uuid="job_uuid", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + job = await response.parse() + assert_matches_type(JobStatusResponse, job, path=["response"]) + + @parametrize + async def test_streaming_response_status(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.post_training.job.with_streaming_response.status( + job_uuid="job_uuid", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + job = await response.parse() + assert_matches_type(JobStatusResponse, job, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/post_training/test_jobs.py b/tests/api_resources/post_training/test_jobs.py index 2c41d2df..400f5aa4 100644 --- a/tests/api_resources/post_training/test_jobs.py +++ b/tests/api_resources/post_training/test_jobs.py @@ -22,11 +22,17 @@ class TestJobs: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_method_list(self, client: LlamaStackClient) -> None: job = client.post_training.jobs.list() assert_matches_type(PostTrainingJob, job, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_method_list_with_all_params(self, client: LlamaStackClient) -> None: job = client.post_training.jobs.list( @@ -34,6 +40,9 @@ def test_method_list_with_all_params(self, client: LlamaStackClient) -> None: ) assert_matches_type(PostTrainingJob, job, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_raw_response_list(self, client: LlamaStackClient) -> None: response = client.post_training.jobs.with_raw_response.list() @@ -43,6 +52,9 @@ def test_raw_response_list(self, client: LlamaStackClient) -> None: job = response.parse() assert_matches_type(PostTrainingJob, job, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_streaming_response_list(self, client: LlamaStackClient) -> None: with client.post_training.jobs.with_streaming_response.list() as response: @@ -214,11 +226,17 @@ def test_streaming_response_status(self, client: LlamaStackClient) -> None: class TestAsyncJobs: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_method_list(self, async_client: AsyncLlamaStackClient) -> None: job = await async_client.post_training.jobs.list() assert_matches_type(PostTrainingJob, job, 
path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_method_list_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: job = await async_client.post_training.jobs.list( @@ -226,6 +244,9 @@ async def test_method_list_with_all_params(self, async_client: AsyncLlamaStackCl ) assert_matches_type(PostTrainingJob, job, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> None: response = await async_client.post_training.jobs.with_raw_response.list() @@ -235,6 +256,9 @@ async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> N job = await response.parse() assert_matches_type(PostTrainingJob, job, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient) -> None: async with async_client.post_training.jobs.with_streaming_response.list() as response: diff --git a/tests/api_resources/test_batch_inferences.py b/tests/api_resources/test_batch_inferences.py new file mode 100644 index 00000000..d323e62a --- /dev/null +++ b/tests/api_resources/test_batch_inferences.py @@ -0,0 +1,675 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from tests.utils import assert_matches_type +from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient +from llama_stack_client.types import ( + BatchInferenceChatCompletionResponse, +) +from llama_stack_client.types.shared import BatchCompletion + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestBatchInferences: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_chat_completion(self, client: LlamaStackClient) -> None: + batch_inference = client.batch_inferences.chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + ], + model="model", + ) + assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) + + @parametrize + def test_method_chat_completion_with_all_params(self, client: LlamaStackClient) -> None: + batch_inference = client.batch_inferences.chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + 
"content": "string", + "role": "user", + "context": "string", + }, + ], + [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + ], + model="model", + logprobs={"top_k": 0}, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + tool_choice="auto", + tool_prompt_format="json", + tools=[ + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "default": True, + "description": "description", + "required": True, + } + }, + }, + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "default": True, + "description": "description", + "required": True, + } + }, + }, + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "default": True, + "description": "description", + "required": True, + } + }, + }, + ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) + + @parametrize + def test_raw_response_chat_completion(self, client: LlamaStackClient) -> None: + response = client.batch_inferences.with_raw_response.chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + ], + model="model", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + batch_inference = response.parse() + assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) + + @parametrize + def test_streaming_response_chat_completion(self, client: LlamaStackClient) -> None: + with client.batch_inferences.with_streaming_response.chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + ], + model="model", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + batch_inference = response.parse() + assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_completion(self, client: LlamaStackClient) -> None: + batch_inference = client.batch_inferences.completion( + content_batch=["string", "string", "string"], + model="model", + ) + assert_matches_type(BatchCompletion, 
batch_inference, path=["response"]) + + @parametrize + def test_method_completion_with_all_params(self, client: LlamaStackClient) -> None: + batch_inference = client.batch_inferences.completion( + content_batch=["string", "string", "string"], + model="model", + logprobs={"top_k": 0}, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(BatchCompletion, batch_inference, path=["response"]) + + @parametrize + def test_raw_response_completion(self, client: LlamaStackClient) -> None: + response = client.batch_inferences.with_raw_response.completion( + content_batch=["string", "string", "string"], + model="model", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + batch_inference = response.parse() + assert_matches_type(BatchCompletion, batch_inference, path=["response"]) + + @parametrize + def test_streaming_response_completion(self, client: LlamaStackClient) -> None: + with client.batch_inferences.with_streaming_response.completion( + content_batch=["string", "string", "string"], + model="model", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + batch_inference = response.parse() + assert_matches_type(BatchCompletion, batch_inference, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncBatchInferences: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: + batch_inference = await async_client.batch_inferences.chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + ], + model="model", + ) + assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) + + @parametrize + async def test_method_chat_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + batch_inference = await async_client.batch_inferences.chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + [ + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + { + "content": "string", + "role": "user", + "context": "string", + }, + ], + ], + model="model", + logprobs={"top_k": 0}, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 
0, + "top_p": 0, + }, + tool_choice="auto", + tool_prompt_format="json", + tools=[ + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "default": True, + "description": "description", + "required": True, + } + }, + }, + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "default": True, + "description": "description", + "required": True, + } + }, + }, + { + "tool_name": "brave_search", + "description": "description", + "parameters": { + "foo": { + "param_type": "param_type", + "default": True, + "description": "description", + "required": True, + } + }, + }, + ], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) + + @parametrize + async def test_raw_response_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.batch_inferences.with_raw_response.chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + ], + model="model", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + batch_inference = await response.parse() + assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) + + @parametrize + async def test_streaming_response_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.batch_inferences.with_streaming_response.chat_completion( + messages_batch=[ + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + [ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + ], + model="model", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + batch_inference = await response.parse() + assert_matches_type(BatchInferenceChatCompletionResponse, batch_inference, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_completion(self, async_client: AsyncLlamaStackClient) -> None: + batch_inference = await async_client.batch_inferences.completion( + content_batch=["string", "string", "string"], + model="model", + ) + assert_matches_type(BatchCompletion, batch_inference, path=["response"]) + + @parametrize + async def test_method_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + batch_inference = await async_client.batch_inferences.completion( + content_batch=["string", "string", "string"], + model="model", + logprobs={"top_k": 0}, + sampling_params={ + "strategy": "greedy", + "max_tokens": 0, + 
"repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(BatchCompletion, batch_inference, path=["response"]) + + @parametrize + async def test_raw_response_completion(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.batch_inferences.with_raw_response.completion( + content_batch=["string", "string", "string"], + model="model", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + batch_inference = await response.parse() + assert_matches_type(BatchCompletion, batch_inference, path=["response"]) + + @parametrize + async def test_streaming_response_completion(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.batch_inferences.with_streaming_response.completion( + content_batch=["string", "string", "string"], + model="model", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + batch_inference = await response.parse() + assert_matches_type(BatchCompletion, batch_inference, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_datasetio.py b/tests/api_resources/test_datasetio.py new file mode 100644 index 00000000..80053733 --- /dev/null +++ b/tests/api_resources/test_datasetio.py @@ -0,0 +1,112 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from tests.utils import assert_matches_type +from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient +from llama_stack_client.types import PaginatedRowsResult + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestDatasetio: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_get_rows_paginated(self, client: LlamaStackClient) -> None: + datasetio = client.datasetio.get_rows_paginated( + dataset_id="dataset_id", + rows_in_page=0, + ) + assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) + + @parametrize + def test_method_get_rows_paginated_with_all_params(self, client: LlamaStackClient) -> None: + datasetio = client.datasetio.get_rows_paginated( + dataset_id="dataset_id", + rows_in_page=0, + filter_condition="filter_condition", + page_token="page_token", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) + + @parametrize + def test_raw_response_get_rows_paginated(self, client: LlamaStackClient) -> None: + response = client.datasetio.with_raw_response.get_rows_paginated( + dataset_id="dataset_id", + rows_in_page=0, + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + datasetio = response.parse() + assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) + + @parametrize + def test_streaming_response_get_rows_paginated(self, client: LlamaStackClient) -> None: + with client.datasetio.with_streaming_response.get_rows_paginated( + dataset_id="dataset_id", + rows_in_page=0, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + datasetio = response.parse() + 
assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncDatasetio: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_get_rows_paginated(self, async_client: AsyncLlamaStackClient) -> None: + datasetio = await async_client.datasetio.get_rows_paginated( + dataset_id="dataset_id", + rows_in_page=0, + ) + assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) + + @parametrize + async def test_method_get_rows_paginated_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + datasetio = await async_client.datasetio.get_rows_paginated( + dataset_id="dataset_id", + rows_in_page=0, + filter_condition="filter_condition", + page_token="page_token", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) + + @parametrize + async def test_raw_response_get_rows_paginated(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.datasetio.with_raw_response.get_rows_paginated( + dataset_id="dataset_id", + rows_in_page=0, + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + datasetio = await response.parse() + assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) + + @parametrize + async def test_streaming_response_get_rows_paginated(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.datasetio.with_streaming_response.get_rows_paginated( + dataset_id="dataset_id", + rows_in_page=0, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + datasetio = await response.parse() + assert_matches_type(PaginatedRowsResult, datasetio, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_datasets.py b/tests/api_resources/test_datasets.py index e6a3f571..c85ce04b 100644 --- a/tests/api_resources/test_datasets.py +++ b/tests/api_resources/test_datasets.py @@ -3,13 +3,16 @@ from __future__ import annotations import os -from typing import Any, cast +from typing import Any, Optional, cast import pytest from tests.utils import assert_matches_type from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types import TrainEvalDataset +from llama_stack_client.types import ( + DatasetListResponse, + DatasetRetrieveResponse, +) base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -18,136 +21,148 @@ class TestDatasets: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - def test_method_create(self, client: LlamaStackClient) -> None: - dataset = client.datasets.create( - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - uuid="uuid", + def test_method_retrieve(self, client: LlamaStackClient) -> None: + dataset = client.datasets.retrieve( + dataset_identifier="dataset_identifier", ) - assert dataset is None + assert_matches_type(Optional[DatasetRetrieveResponse], dataset, path=["response"]) @parametrize - def test_method_create_with_all_params(self, client: LlamaStackClient) -> None: - dataset = client.datasets.create( - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - "metadata": {"foo": 
True}, - }, - uuid="uuid", + def test_method_retrieve_with_all_params(self, client: LlamaStackClient) -> None: + dataset = client.datasets.retrieve( + dataset_identifier="dataset_identifier", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert dataset is None + assert_matches_type(Optional[DatasetRetrieveResponse], dataset, path=["response"]) @parametrize - def test_raw_response_create(self, client: LlamaStackClient) -> None: - response = client.datasets.with_raw_response.create( - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - uuid="uuid", + def test_raw_response_retrieve(self, client: LlamaStackClient) -> None: + response = client.datasets.with_raw_response.retrieve( + dataset_identifier="dataset_identifier", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = response.parse() - assert dataset is None + assert_matches_type(Optional[DatasetRetrieveResponse], dataset, path=["response"]) @parametrize - def test_streaming_response_create(self, client: LlamaStackClient) -> None: - with client.datasets.with_streaming_response.create( - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - uuid="uuid", + def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None: + with client.datasets.with_streaming_response.retrieve( + dataset_identifier="dataset_identifier", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = response.parse() - assert dataset is None + assert_matches_type(Optional[DatasetRetrieveResponse], dataset, path=["response"]) assert cast(Any, response.is_closed) is True + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize - def test_method_delete(self, client: LlamaStackClient) -> None: - dataset = client.datasets.delete( - dataset_uuid="dataset_uuid", - ) - assert dataset is None + def test_method_list(self, client: LlamaStackClient) -> None: + dataset = client.datasets.list() + assert_matches_type(DatasetListResponse, dataset, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize - def test_method_delete_with_all_params(self, client: LlamaStackClient) -> None: - dataset = client.datasets.delete( - dataset_uuid="dataset_uuid", + def test_method_list_with_all_params(self, client: LlamaStackClient) -> None: + dataset = client.datasets.list( x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert dataset is None + assert_matches_type(DatasetListResponse, dataset, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize - def test_raw_response_delete(self, client: LlamaStackClient) -> None: - response = client.datasets.with_raw_response.delete( - dataset_uuid="dataset_uuid", - ) + def test_raw_response_list(self, client: LlamaStackClient) -> None: + response = client.datasets.with_raw_response.list() assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = response.parse() - assert dataset is None + assert_matches_type(DatasetListResponse, dataset, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints 
with content type application/jsonl, Prism mock server will fail" + ) @parametrize - def test_streaming_response_delete(self, client: LlamaStackClient) -> None: - with client.datasets.with_streaming_response.delete( - dataset_uuid="dataset_uuid", - ) as response: + def test_streaming_response_list(self, client: LlamaStackClient) -> None: + with client.datasets.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = response.parse() - assert dataset is None + assert_matches_type(DatasetListResponse, dataset, path=["response"]) assert cast(Any, response.is_closed) is True @parametrize - def test_method_get(self, client: LlamaStackClient) -> None: - dataset = client.datasets.get( - dataset_uuid="dataset_uuid", + def test_method_register(self, client: LlamaStackClient) -> None: + dataset = client.datasets.register( + dataset_def={ + "dataset_schema": {"foo": {"type": "string"}}, + "identifier": "identifier", + "metadata": {"foo": True}, + "provider_id": "provider_id", + "url": "https://example.com", + }, ) - assert_matches_type(TrainEvalDataset, dataset, path=["response"]) + assert dataset is None @parametrize - def test_method_get_with_all_params(self, client: LlamaStackClient) -> None: - dataset = client.datasets.get( - dataset_uuid="dataset_uuid", + def test_method_register_with_all_params(self, client: LlamaStackClient) -> None: + dataset = client.datasets.register( + dataset_def={ + "dataset_schema": {"foo": {"type": "string"}}, + "identifier": "identifier", + "metadata": {"foo": True}, + "provider_id": "provider_id", + "url": "https://example.com", + }, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(TrainEvalDataset, dataset, path=["response"]) + assert dataset is None @parametrize - def test_raw_response_get(self, client: LlamaStackClient) -> None: - response = client.datasets.with_raw_response.get( - dataset_uuid="dataset_uuid", + def test_raw_response_register(self, client: LlamaStackClient) -> None: + response = client.datasets.with_raw_response.register( + dataset_def={ + "dataset_schema": {"foo": {"type": "string"}}, + "identifier": "identifier", + "metadata": {"foo": True}, + "provider_id": "provider_id", + "url": "https://example.com", + }, ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = response.parse() - assert_matches_type(TrainEvalDataset, dataset, path=["response"]) + assert dataset is None @parametrize - def test_streaming_response_get(self, client: LlamaStackClient) -> None: - with client.datasets.with_streaming_response.get( - dataset_uuid="dataset_uuid", + def test_streaming_response_register(self, client: LlamaStackClient) -> None: + with client.datasets.with_streaming_response.register( + dataset_def={ + "dataset_schema": {"foo": {"type": "string"}}, + "identifier": "identifier", + "metadata": {"foo": True}, + "provider_id": "provider_id", + "url": "https://example.com", + }, ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = response.parse() - assert_matches_type(TrainEvalDataset, dataset, path=["response"]) + assert dataset is None assert cast(Any, response.is_closed) is True @@ -156,135 +171,147 @@ class TestAsyncDatasets: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) @parametrize - async def test_method_create(self, 
async_client: AsyncLlamaStackClient) -> None: - dataset = await async_client.datasets.create( - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - uuid="uuid", + async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + dataset = await async_client.datasets.retrieve( + dataset_identifier="dataset_identifier", ) - assert dataset is None + assert_matches_type(Optional[DatasetRetrieveResponse], dataset, path=["response"]) @parametrize - async def test_method_create_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - dataset = await async_client.datasets.create( - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - "metadata": {"foo": True}, - }, - uuid="uuid", + async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + dataset = await async_client.datasets.retrieve( + dataset_identifier="dataset_identifier", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert dataset is None + assert_matches_type(Optional[DatasetRetrieveResponse], dataset, path=["response"]) @parametrize - async def test_raw_response_create(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.datasets.with_raw_response.create( - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - uuid="uuid", + async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.datasets.with_raw_response.retrieve( + dataset_identifier="dataset_identifier", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = await response.parse() - assert dataset is None + assert_matches_type(Optional[DatasetRetrieveResponse], dataset, path=["response"]) @parametrize - async def test_streaming_response_create(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.datasets.with_streaming_response.create( - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, - uuid="uuid", + async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.datasets.with_streaming_response.retrieve( + dataset_identifier="dataset_identifier", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = await response.parse() - assert dataset is None + assert_matches_type(Optional[DatasetRetrieveResponse], dataset, path=["response"]) assert cast(Any, response.is_closed) is True + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize - async def test_method_delete(self, async_client: AsyncLlamaStackClient) -> None: - dataset = await async_client.datasets.delete( - dataset_uuid="dataset_uuid", - ) - assert dataset is None + async def test_method_list(self, async_client: AsyncLlamaStackClient) -> None: + dataset = await async_client.datasets.list() + assert_matches_type(DatasetListResponse, dataset, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize - async def test_method_delete_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - dataset = await async_client.datasets.delete( - dataset_uuid="dataset_uuid", + async def 
test_method_list_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + dataset = await async_client.datasets.list( x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert dataset is None + assert_matches_type(DatasetListResponse, dataset, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize - async def test_raw_response_delete(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.datasets.with_raw_response.delete( - dataset_uuid="dataset_uuid", - ) + async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.datasets.with_raw_response.list() assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = await response.parse() - assert dataset is None + assert_matches_type(DatasetListResponse, dataset, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize - async def test_streaming_response_delete(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.datasets.with_streaming_response.delete( - dataset_uuid="dataset_uuid", - ) as response: + async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.datasets.with_streaming_response.list() as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = await response.parse() - assert dataset is None + assert_matches_type(DatasetListResponse, dataset, path=["response"]) assert cast(Any, response.is_closed) is True @parametrize - async def test_method_get(self, async_client: AsyncLlamaStackClient) -> None: - dataset = await async_client.datasets.get( - dataset_uuid="dataset_uuid", + async def test_method_register(self, async_client: AsyncLlamaStackClient) -> None: + dataset = await async_client.datasets.register( + dataset_def={ + "dataset_schema": {"foo": {"type": "string"}}, + "identifier": "identifier", + "metadata": {"foo": True}, + "provider_id": "provider_id", + "url": "https://example.com", + }, ) - assert_matches_type(TrainEvalDataset, dataset, path=["response"]) + assert dataset is None @parametrize - async def test_method_get_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - dataset = await async_client.datasets.get( - dataset_uuid="dataset_uuid", + async def test_method_register_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + dataset = await async_client.datasets.register( + dataset_def={ + "dataset_schema": {"foo": {"type": "string"}}, + "identifier": "identifier", + "metadata": {"foo": True}, + "provider_id": "provider_id", + "url": "https://example.com", + }, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(TrainEvalDataset, dataset, path=["response"]) + assert dataset is None @parametrize - async def test_raw_response_get(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.datasets.with_raw_response.get( - dataset_uuid="dataset_uuid", + async def test_raw_response_register(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.datasets.with_raw_response.register( + dataset_def={ + "dataset_schema": {"foo": {"type": "string"}}, + "identifier": "identifier", 
+ "metadata": {"foo": True}, + "provider_id": "provider_id", + "url": "https://example.com", + }, ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = await response.parse() - assert_matches_type(TrainEvalDataset, dataset, path=["response"]) + assert dataset is None @parametrize - async def test_streaming_response_get(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.datasets.with_streaming_response.get( - dataset_uuid="dataset_uuid", + async def test_streaming_response_register(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.datasets.with_streaming_response.register( + dataset_def={ + "dataset_schema": {"foo": {"type": "string"}}, + "identifier": "identifier", + "metadata": {"foo": True}, + "provider_id": "provider_id", + "url": "https://example.com", + }, ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" dataset = await response.parse() - assert_matches_type(TrainEvalDataset, dataset, path=["response"]) + assert dataset is None assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_eval.py b/tests/api_resources/test_eval.py new file mode 100644 index 00000000..9a117c17 --- /dev/null +++ b/tests/api_resources/test_eval.py @@ -0,0 +1,318 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from tests.utils import assert_matches_type +from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient +from llama_stack_client.types import Job, EvalEvaluateResponse + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestEval: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_evaluate(self, client: LlamaStackClient) -> None: + eval = client.eval.evaluate( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) + assert_matches_type(EvalEvaluateResponse, eval, path=["response"]) + + @parametrize + def test_method_evaluate_with_all_params(self, client: LlamaStackClient) -> None: + eval = client.eval.evaluate( + candidate={ + "model": "model", + "sampling_params": { + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + "type": "model", + "system_message": { + "content": "string", + "role": "system", + }, + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvalEvaluateResponse, eval, path=["response"]) + + @parametrize + def test_raw_response_evaluate(self, client: LlamaStackClient) -> None: + response = client.eval.with_raw_response.evaluate( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + eval = response.parse() + assert_matches_type(EvalEvaluateResponse, eval, 
path=["response"]) + + @parametrize + def test_streaming_response_evaluate(self, client: LlamaStackClient) -> None: + with client.eval.with_streaming_response.evaluate( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + eval = response.parse() + assert_matches_type(EvalEvaluateResponse, eval, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_evaluate_batch(self, client: LlamaStackClient) -> None: + eval = client.eval.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + ) + assert_matches_type(Job, eval, path=["response"]) + + @parametrize + def test_method_evaluate_batch_with_all_params(self, client: LlamaStackClient) -> None: + eval = client.eval.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": { + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + "type": "model", + "system_message": { + "content": "string", + "role": "system", + }, + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Job, eval, path=["response"]) + + @parametrize + def test_raw_response_evaluate_batch(self, client: LlamaStackClient) -> None: + response = client.eval.with_raw_response.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + eval = response.parse() + assert_matches_type(Job, eval, path=["response"]) + + @parametrize + def test_streaming_response_evaluate_batch(self, client: LlamaStackClient) -> None: + with client.eval.with_streaming_response.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + eval = response.parse() + assert_matches_type(Job, eval, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncEval: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_evaluate(self, async_client: AsyncLlamaStackClient) -> None: + eval = await async_client.eval.evaluate( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) + assert_matches_type(EvalEvaluateResponse, eval, path=["response"]) + + @parametrize + async def test_method_evaluate_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + eval = await async_client.eval.evaluate( + candidate={ + "model": "model", + "sampling_params": { + 
"strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + "type": "model", + "system_message": { + "content": "string", + "role": "system", + }, + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvalEvaluateResponse, eval, path=["response"]) + + @parametrize + async def test_raw_response_evaluate(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.eval.with_raw_response.evaluate( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + eval = await response.parse() + assert_matches_type(EvalEvaluateResponse, eval, path=["response"]) + + @parametrize + async def test_streaming_response_evaluate(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.eval.with_streaming_response.evaluate( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + eval = await response.parse() + assert_matches_type(EvalEvaluateResponse, eval, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_evaluate_batch(self, async_client: AsyncLlamaStackClient) -> None: + eval = await async_client.eval.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + ) + assert_matches_type(Job, eval, path=["response"]) + + @parametrize + async def test_method_evaluate_batch_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + eval = await async_client.eval.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": { + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + "type": "model", + "system_message": { + "content": "string", + "role": "system", + }, + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Job, eval, path=["response"]) + + @parametrize + async def test_raw_response_evaluate_batch(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.eval.with_raw_response.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + eval = await response.parse() + assert_matches_type(Job, eval, path=["response"]) + + @parametrize + async def test_streaming_response_evaluate_batch(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.eval.with_streaming_response.evaluate_batch( + candidate={ + "model": 
"model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + eval = await response.parse() + assert_matches_type(Job, eval, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_evaluate.py b/tests/api_resources/test_evaluate.py new file mode 100644 index 00000000..0c8283f5 --- /dev/null +++ b/tests/api_resources/test_evaluate.py @@ -0,0 +1,319 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from tests.utils import assert_matches_type +from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient +from llama_stack_client.types import Job +from llama_stack_client.types.evaluate import EvaluateResponse + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestEvaluate: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_evaluate(self, client: LlamaStackClient) -> None: + evaluate = client.evaluate.evaluate( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) + assert_matches_type(EvaluateResponse, evaluate, path=["response"]) + + @parametrize + def test_method_evaluate_with_all_params(self, client: LlamaStackClient) -> None: + evaluate = client.evaluate.evaluate( + candidate={ + "model": "model", + "sampling_params": { + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + "type": "model", + "system_message": { + "content": "string", + "role": "system", + }, + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluateResponse, evaluate, path=["response"]) + + @parametrize + def test_raw_response_evaluate(self, client: LlamaStackClient) -> None: + response = client.evaluate.with_raw_response.evaluate( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + evaluate = response.parse() + assert_matches_type(EvaluateResponse, evaluate, path=["response"]) + + @parametrize + def test_streaming_response_evaluate(self, client: LlamaStackClient) -> None: + with client.evaluate.with_streaming_response.evaluate( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + evaluate = response.parse() + assert_matches_type(EvaluateResponse, evaluate, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + 
def test_method_evaluate_batch(self, client: LlamaStackClient) -> None: + evaluate = client.evaluate.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + ) + assert_matches_type(Job, evaluate, path=["response"]) + + @parametrize + def test_method_evaluate_batch_with_all_params(self, client: LlamaStackClient) -> None: + evaluate = client.evaluate.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": { + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + "type": "model", + "system_message": { + "content": "string", + "role": "system", + }, + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Job, evaluate, path=["response"]) + + @parametrize + def test_raw_response_evaluate_batch(self, client: LlamaStackClient) -> None: + response = client.evaluate.with_raw_response.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + evaluate = response.parse() + assert_matches_type(Job, evaluate, path=["response"]) + + @parametrize + def test_streaming_response_evaluate_batch(self, client: LlamaStackClient) -> None: + with client.evaluate.with_streaming_response.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + evaluate = response.parse() + assert_matches_type(Job, evaluate, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncEvaluate: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_evaluate(self, async_client: AsyncLlamaStackClient) -> None: + evaluate = await async_client.evaluate.evaluate( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) + assert_matches_type(EvaluateResponse, evaluate, path=["response"]) + + @parametrize + async def test_method_evaluate_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + evaluate = await async_client.evaluate.evaluate( + candidate={ + "model": "model", + "sampling_params": { + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + "type": "model", + "system_message": { + "content": "string", + "role": "system", + }, + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(EvaluateResponse, evaluate, path=["response"]) + + @parametrize + async def test_raw_response_evaluate(self, async_client: AsyncLlamaStackClient) -> None: + response = await 
async_client.evaluate.with_raw_response.evaluate( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + evaluate = await response.parse() + assert_matches_type(EvaluateResponse, evaluate, path=["response"]) + + @parametrize + async def test_streaming_response_evaluate(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.evaluate.with_streaming_response.evaluate( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + evaluate = await response.parse() + assert_matches_type(EvaluateResponse, evaluate, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_evaluate_batch(self, async_client: AsyncLlamaStackClient) -> None: + evaluate = await async_client.evaluate.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + ) + assert_matches_type(Job, evaluate, path=["response"]) + + @parametrize + async def test_method_evaluate_batch_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + evaluate = await async_client.evaluate.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": { + "strategy": "greedy", + "max_tokens": 0, + "repetition_penalty": 0, + "temperature": 0, + "top_k": 0, + "top_p": 0, + }, + "type": "model", + "system_message": { + "content": "string", + "role": "system", + }, + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Job, evaluate, path=["response"]) + + @parametrize + async def test_raw_response_evaluate_batch(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.evaluate.with_raw_response.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + evaluate = await response.parse() + assert_matches_type(Job, evaluate, path=["response"]) + + @parametrize + async def test_streaming_response_evaluate_batch(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.evaluate.with_streaming_response.evaluate_batch( + candidate={ + "model": "model", + "sampling_params": {"strategy": "greedy"}, + "type": "model", + }, + dataset_id="dataset_id", + scoring_functions=["string", "string", "string"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + evaluate = await response.parse() + assert_matches_type(Job, evaluate, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_evaluations.py b/tests/api_resources/test_evaluations.py index 
760ac40d..8291396c 100644 --- a/tests/api_resources/test_evaluations.py +++ b/tests/api_resources/test_evaluations.py @@ -17,45 +17,6 @@ class TestEvaluations: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - @parametrize - def test_method_summarization(self, client: LlamaStackClient) -> None: - evaluation = client.evaluations.summarization( - metrics=["rouge", "bleu"], - ) - assert_matches_type(EvaluationJob, evaluation, path=["response"]) - - @parametrize - def test_method_summarization_with_all_params(self, client: LlamaStackClient) -> None: - evaluation = client.evaluations.summarization( - metrics=["rouge", "bleu"], - x_llama_stack_provider_data="X-LlamaStack-ProviderData", - ) - assert_matches_type(EvaluationJob, evaluation, path=["response"]) - - @parametrize - def test_raw_response_summarization(self, client: LlamaStackClient) -> None: - response = client.evaluations.with_raw_response.summarization( - metrics=["rouge", "bleu"], - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - evaluation = response.parse() - assert_matches_type(EvaluationJob, evaluation, path=["response"]) - - @parametrize - def test_streaming_response_summarization(self, client: LlamaStackClient) -> None: - with client.evaluations.with_streaming_response.summarization( - metrics=["rouge", "bleu"], - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - evaluation = response.parse() - assert_matches_type(EvaluationJob, evaluation, path=["response"]) - - assert cast(Any, response.is_closed) is True - @parametrize def test_method_text_generation(self, client: LlamaStackClient) -> None: evaluation = client.evaluations.text_generation( @@ -99,45 +60,6 @@ def test_streaming_response_text_generation(self, client: LlamaStackClient) -> N class TestAsyncEvaluations: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) - @parametrize - async def test_method_summarization(self, async_client: AsyncLlamaStackClient) -> None: - evaluation = await async_client.evaluations.summarization( - metrics=["rouge", "bleu"], - ) - assert_matches_type(EvaluationJob, evaluation, path=["response"]) - - @parametrize - async def test_method_summarization_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - evaluation = await async_client.evaluations.summarization( - metrics=["rouge", "bleu"], - x_llama_stack_provider_data="X-LlamaStack-ProviderData", - ) - assert_matches_type(EvaluationJob, evaluation, path=["response"]) - - @parametrize - async def test_raw_response_summarization(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.evaluations.with_raw_response.summarization( - metrics=["rouge", "bleu"], - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - evaluation = await response.parse() - assert_matches_type(EvaluationJob, evaluation, path=["response"]) - - @parametrize - async def test_streaming_response_summarization(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.evaluations.with_streaming_response.summarization( - metrics=["rouge", "bleu"], - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - evaluation = await response.parse() - assert_matches_type(EvaluationJob, 
evaluation, path=["response"]) - - assert cast(Any, response.is_closed) is True - @parametrize async def test_method_text_generation(self, async_client: AsyncLlamaStackClient) -> None: evaluation = await async_client.evaluations.text_generation( diff --git a/tests/api_resources/test_inference.py b/tests/api_resources/test_inference.py index e4b5b6fe..870aecf4 100644 --- a/tests/api_resources/test_inference.py +++ b/tests/api_resources/test_inference.py @@ -10,6 +10,7 @@ from tests.utils import assert_matches_type from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient from llama_stack_client.types import ( + EmbeddingsResponse, InferenceCompletionResponse, InferenceChatCompletionResponse, ) @@ -20,8 +21,11 @@ class TestInference: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) @parametrize - def test_method_chat_completion_overload_1(self, client: LlamaStackClient) -> None: + def test_method_chat_completion(self, client: LlamaStackClient) -> None: inference = client.inference.chat_completion( messages=[ { @@ -41,8 +45,11 @@ def test_method_chat_completion_overload_1(self, client: LlamaStackClient) -> No ) assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) @parametrize - def test_method_chat_completion_with_all_params_overload_1(self, client: LlamaStackClient) -> None: + def test_method_chat_completion_with_all_params(self, client: LlamaStackClient) -> None: inference = client.inference.chat_completion( messages=[ { @@ -63,6 +70,10 @@ def test_method_chat_completion_with_all_params_overload_1(self, client: LlamaSt ], model="model", logprobs={"top_k": 0}, + response_format={ + "json_schema": {"foo": True}, + "type": "json_schema", + }, sampling_params={ "strategy": "greedy", "max_tokens": 0, @@ -71,7 +82,7 @@ def test_method_chat_completion_with_all_params_overload_1(self, client: LlamaSt "top_k": 0, "top_p": 0, }, - stream=False, + stream=True, tool_choice="auto", tool_prompt_format="json", tools=[ @@ -116,8 +127,11 @@ def test_method_chat_completion_with_all_params_overload_1(self, client: LlamaSt ) assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) @parametrize - def test_raw_response_chat_completion_overload_1(self, client: LlamaStackClient) -> None: + def test_raw_response_chat_completion(self, client: LlamaStackClient) -> None: response = client.inference.with_raw_response.chat_completion( messages=[ { @@ -141,8 +155,11 @@ def test_raw_response_chat_completion_overload_1(self, client: LlamaStackClient) inference = response.parse() assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) @parametrize - def test_streaming_response_chat_completion_overload_1(self, client: LlamaStackClient) -> None: + def test_streaming_response_chat_completion(self, client: LlamaStackClient) -> None: with client.inference.with_streaming_response.chat_completion( messages=[ { @@ -169,50 +186,23 
@@ def test_streaming_response_chat_completion_overload_1(self, client: LlamaStackC assert cast(Any, response.is_closed) is True @parametrize - def test_method_chat_completion_overload_2(self, client: LlamaStackClient) -> None: - inference_stream = client.inference.chat_completion( - messages=[ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], + def test_method_completion(self, client: LlamaStackClient) -> None: + inference = client.inference.completion( + content="string", model="model", - stream=True, ) - inference_stream.response.close() + assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) @parametrize - def test_method_chat_completion_with_all_params_overload_2(self, client: LlamaStackClient) -> None: - inference_stream = client.inference.chat_completion( - messages=[ - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - ], + def test_method_completion_with_all_params(self, client: LlamaStackClient) -> None: + inference = client.inference.completion( + content="string", model="model", - stream=True, logprobs={"top_k": 0}, + response_format={ + "json_schema": {"foo": True}, + "type": "json_schema", + }, sampling_params={ "strategy": "greedy", "max_tokens": 0, @@ -221,153 +211,77 @@ def test_method_chat_completion_with_all_params_overload_2(self, client: LlamaSt "top_k": 0, "top_p": 0, }, - tool_choice="auto", - tool_prompt_format="json", - tools=[ - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "default": True, - "description": "description", - "required": True, - } - }, - }, - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "default": True, - "description": "description", - "required": True, - } - }, - }, - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "default": True, - "description": "description", - "required": True, - } - }, - }, - ], + stream=True, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - inference_stream.response.close() + assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) @parametrize - def test_raw_response_chat_completion_overload_2(self, client: LlamaStackClient) -> None: - response = client.inference.with_raw_response.chat_completion( - messages=[ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], + def test_raw_response_completion(self, client: LlamaStackClient) -> None: + response = client.inference.with_raw_response.completion( + content="string", model="model", - stream=True, ) + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - stream = response.parse() - stream.close() + inference = response.parse() + assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) @parametrize - def test_streaming_response_chat_completion_overload_2(self, client: LlamaStackClient) -> None: - with client.inference.with_streaming_response.chat_completion( - messages=[ - { - "content": "string", - "role": "user", - }, - { - "content": 
"string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], + def test_streaming_response_completion(self, client: LlamaStackClient) -> None: + with client.inference.with_streaming_response.completion( + content="string", model="model", - stream=True, ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" - stream = response.parse() - stream.close() + inference = response.parse() + assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) assert cast(Any, response.is_closed) is True @parametrize - def test_method_completion(self, client: LlamaStackClient) -> None: - inference = client.inference.completion( - content="string", + def test_method_embeddings(self, client: LlamaStackClient) -> None: + inference = client.inference.embeddings( + contents=["string", "string", "string"], model="model", ) - assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) + assert_matches_type(EmbeddingsResponse, inference, path=["response"]) @parametrize - def test_method_completion_with_all_params(self, client: LlamaStackClient) -> None: - inference = client.inference.completion( - content="string", + def test_method_embeddings_with_all_params(self, client: LlamaStackClient) -> None: + inference = client.inference.embeddings( + contents=["string", "string", "string"], model="model", - logprobs={"top_k": 0}, - sampling_params={ - "strategy": "greedy", - "max_tokens": 0, - "repetition_penalty": 0, - "temperature": 0, - "top_k": 0, - "top_p": 0, - }, - stream=True, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) + assert_matches_type(EmbeddingsResponse, inference, path=["response"]) @parametrize - def test_raw_response_completion(self, client: LlamaStackClient) -> None: - response = client.inference.with_raw_response.completion( - content="string", + def test_raw_response_embeddings(self, client: LlamaStackClient) -> None: + response = client.inference.with_raw_response.embeddings( + contents=["string", "string", "string"], model="model", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" inference = response.parse() - assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) + assert_matches_type(EmbeddingsResponse, inference, path=["response"]) @parametrize - def test_streaming_response_completion(self, client: LlamaStackClient) -> None: - with client.inference.with_streaming_response.completion( - content="string", + def test_streaming_response_embeddings(self, client: LlamaStackClient) -> None: + with client.inference.with_streaming_response.embeddings( + contents=["string", "string", "string"], model="model", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" inference = response.parse() - assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) + assert_matches_type(EmbeddingsResponse, inference, path=["response"]) assert cast(Any, response.is_closed) is True @@ -375,8 +289,11 @@ def test_streaming_response_completion(self, client: LlamaStackClient) -> None: class TestAsyncInference: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server 
will fail" + ) @parametrize - async def test_method_chat_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + async def test_method_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: inference = await async_client.inference.chat_completion( messages=[ { @@ -396,8 +313,11 @@ async def test_method_chat_completion_overload_1(self, async_client: AsyncLlamaS ) assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) @parametrize - async def test_method_chat_completion_with_all_params_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + async def test_method_chat_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: inference = await async_client.inference.chat_completion( messages=[ { @@ -418,6 +338,10 @@ async def test_method_chat_completion_with_all_params_overload_1(self, async_cli ], model="model", logprobs={"top_k": 0}, + response_format={ + "json_schema": {"foo": True}, + "type": "json_schema", + }, sampling_params={ "strategy": "greedy", "max_tokens": 0, @@ -426,7 +350,7 @@ async def test_method_chat_completion_with_all_params_overload_1(self, async_cli "top_k": 0, "top_p": 0, }, - stream=False, + stream=True, tool_choice="auto", tool_prompt_format="json", tools=[ @@ -471,8 +395,11 @@ async def test_method_chat_completion_with_all_params_overload_1(self, async_cli ) assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) @parametrize - async def test_raw_response_chat_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + async def test_raw_response_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: response = await async_client.inference.with_raw_response.chat_completion( messages=[ { @@ -496,8 +423,11 @@ async def test_raw_response_chat_completion_overload_1(self, async_client: Async inference = await response.parse() assert_matches_type(InferenceChatCompletionResponse, inference, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type text/event-stream, Prism mock server will fail" + ) @parametrize - async def test_streaming_response_chat_completion_overload_1(self, async_client: AsyncLlamaStackClient) -> None: + async def test_streaming_response_chat_completion(self, async_client: AsyncLlamaStackClient) -> None: async with async_client.inference.with_streaming_response.chat_completion( messages=[ { @@ -524,50 +454,23 @@ async def test_streaming_response_chat_completion_overload_1(self, async_client: assert cast(Any, response.is_closed) is True @parametrize - async def test_method_chat_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None: - inference_stream = await async_client.inference.chat_completion( - messages=[ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], + async def test_method_completion(self, async_client: AsyncLlamaStackClient) -> None: + inference = await async_client.inference.completion( + content="string", model="model", - stream=True, ) - await inference_stream.response.aclose() + assert_matches_type(InferenceCompletionResponse, inference, 
path=["response"]) @parametrize - async def test_method_chat_completion_with_all_params_overload_2(self, async_client: AsyncLlamaStackClient) -> None: - inference_stream = await async_client.inference.chat_completion( - messages=[ - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - { - "content": "string", - "role": "user", - "context": "string", - }, - ], + async def test_method_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + inference = await async_client.inference.completion( + content="string", model="model", - stream=True, logprobs={"top_k": 0}, + response_format={ + "json_schema": {"foo": True}, + "type": "json_schema", + }, sampling_params={ "strategy": "greedy", "max_tokens": 0, @@ -576,152 +479,76 @@ async def test_method_chat_completion_with_all_params_overload_2(self, async_cli "top_k": 0, "top_p": 0, }, - tool_choice="auto", - tool_prompt_format="json", - tools=[ - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "default": True, - "description": "description", - "required": True, - } - }, - }, - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "default": True, - "description": "description", - "required": True, - } - }, - }, - { - "tool_name": "brave_search", - "description": "description", - "parameters": { - "foo": { - "param_type": "param_type", - "default": True, - "description": "description", - "required": True, - } - }, - }, - ], + stream=True, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - await inference_stream.response.aclose() + assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) @parametrize - async def test_raw_response_chat_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.inference.with_raw_response.chat_completion( - messages=[ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], + async def test_raw_response_completion(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.inference.with_raw_response.completion( + content="string", model="model", - stream=True, ) + assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" - stream = await response.parse() - await stream.close() + inference = await response.parse() + assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) @parametrize - async def test_streaming_response_chat_completion_overload_2(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.inference.with_streaming_response.chat_completion( - messages=[ - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - { - "content": "string", - "role": "user", - }, - ], + async def test_streaming_response_completion(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.inference.with_streaming_response.completion( + content="string", model="model", - stream=True, ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" - stream = await response.parse() - await stream.close() + inference = await response.parse() + 
assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) assert cast(Any, response.is_closed) is True @parametrize - async def test_method_completion(self, async_client: AsyncLlamaStackClient) -> None: - inference = await async_client.inference.completion( - content="string", + async def test_method_embeddings(self, async_client: AsyncLlamaStackClient) -> None: + inference = await async_client.inference.embeddings( + contents=["string", "string", "string"], model="model", ) - assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) + assert_matches_type(EmbeddingsResponse, inference, path=["response"]) @parametrize - async def test_method_completion_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - inference = await async_client.inference.completion( - content="string", + async def test_method_embeddings_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + inference = await async_client.inference.embeddings( + contents=["string", "string", "string"], model="model", - logprobs={"top_k": 0}, - sampling_params={ - "strategy": "greedy", - "max_tokens": 0, - "repetition_penalty": 0, - "temperature": 0, - "top_k": 0, - "top_p": 0, - }, - stream=True, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) + assert_matches_type(EmbeddingsResponse, inference, path=["response"]) @parametrize - async def test_raw_response_completion(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.inference.with_raw_response.completion( - content="string", + async def test_raw_response_embeddings(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.inference.with_raw_response.embeddings( + contents=["string", "string", "string"], model="model", ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" inference = await response.parse() - assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) + assert_matches_type(EmbeddingsResponse, inference, path=["response"]) @parametrize - async def test_streaming_response_completion(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.inference.with_streaming_response.completion( - content="string", + async def test_streaming_response_embeddings(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.inference.with_streaming_response.embeddings( + contents=["string", "string", "string"], model="model", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" inference = await response.parse() - assert_matches_type(InferenceCompletionResponse, inference, path=["response"]) + assert_matches_type(EmbeddingsResponse, inference, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_inspect.py b/tests/api_resources/test_inspect.py new file mode 100644 index 00000000..66c8aa9c --- /dev/null +++ b/tests/api_resources/test_inspect.py @@ -0,0 +1,86 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
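# The new tests in this file exercise client.inspect.health(), which returns a
# HealthInfo model and optionally accepts the X-LlamaStack-ProviderData header.
# A minimal usage sketch, assuming a running server at the tests' default base
# URL; printing the result is illustrative only.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://127.0.0.1:4010")
health = client.inspect.health()  # -> HealthInfo
print(health)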
+ +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from tests.utils import assert_matches_type +from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient +from llama_stack_client.types import HealthInfo + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestInspect: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_health(self, client: LlamaStackClient) -> None: + inspect = client.inspect.health() + assert_matches_type(HealthInfo, inspect, path=["response"]) + + @parametrize + def test_method_health_with_all_params(self, client: LlamaStackClient) -> None: + inspect = client.inspect.health( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(HealthInfo, inspect, path=["response"]) + + @parametrize + def test_raw_response_health(self, client: LlamaStackClient) -> None: + response = client.inspect.with_raw_response.health() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + inspect = response.parse() + assert_matches_type(HealthInfo, inspect, path=["response"]) + + @parametrize + def test_streaming_response_health(self, client: LlamaStackClient) -> None: + with client.inspect.with_streaming_response.health() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + inspect = response.parse() + assert_matches_type(HealthInfo, inspect, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncInspect: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_health(self, async_client: AsyncLlamaStackClient) -> None: + inspect = await async_client.inspect.health() + assert_matches_type(HealthInfo, inspect, path=["response"]) + + @parametrize + async def test_method_health_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + inspect = await async_client.inspect.health( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(HealthInfo, inspect, path=["response"]) + + @parametrize + async def test_raw_response_health(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.inspect.with_raw_response.health() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + inspect = await response.parse() + assert_matches_type(HealthInfo, inspect, path=["response"]) + + @parametrize + async def test_streaming_response_health(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.inspect.with_streaming_response.health() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + inspect = await response.parse() + assert_matches_type(HealthInfo, inspect, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_memory.py b/tests/api_resources/test_memory.py index 4af1bdf3..cbe16a14 100644 --- a/tests/api_resources/test_memory.py +++ b/tests/api_resources/test_memory.py @@ -9,9 +9,7 @@ from tests.utils import assert_matches_type from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types import ( - QueryDocuments, -) +from 
llama_stack_client.types import QueryDocumentsResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -19,265 +17,6 @@ class TestMemory: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) - @parametrize - def test_method_create(self, client: LlamaStackClient) -> None: - memory = client.memory.create( - body={}, - ) - assert_matches_type(object, memory, path=["response"]) - - @parametrize - def test_method_create_with_all_params(self, client: LlamaStackClient) -> None: - memory = client.memory.create( - body={}, - x_llama_stack_provider_data="X-LlamaStack-ProviderData", - ) - assert_matches_type(object, memory, path=["response"]) - - @parametrize - def test_raw_response_create(self, client: LlamaStackClient) -> None: - response = client.memory.with_raw_response.create( - body={}, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory = response.parse() - assert_matches_type(object, memory, path=["response"]) - - @parametrize - def test_streaming_response_create(self, client: LlamaStackClient) -> None: - with client.memory.with_streaming_response.create( - body={}, - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory = response.parse() - assert_matches_type(object, memory, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_retrieve(self, client: LlamaStackClient) -> None: - memory = client.memory.retrieve( - bank_id="bank_id", - ) - assert_matches_type(object, memory, path=["response"]) - - @parametrize - def test_method_retrieve_with_all_params(self, client: LlamaStackClient) -> None: - memory = client.memory.retrieve( - bank_id="bank_id", - x_llama_stack_provider_data="X-LlamaStack-ProviderData", - ) - assert_matches_type(object, memory, path=["response"]) - - @parametrize - def test_raw_response_retrieve(self, client: LlamaStackClient) -> None: - response = client.memory.with_raw_response.retrieve( - bank_id="bank_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory = response.parse() - assert_matches_type(object, memory, path=["response"]) - - @parametrize - def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None: - with client.memory.with_streaming_response.retrieve( - bank_id="bank_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory = response.parse() - assert_matches_type(object, memory, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_update(self, client: LlamaStackClient) -> None: - memory = client.memory.update( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], - ) - assert memory is None - - @parametrize - def test_method_update_with_all_params(self, client: LlamaStackClient) -> None: - memory = client.memory.update( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - "mime_type": "mime_type", - }, - { - "content": "string", - 
"document_id": "document_id", - "metadata": {"foo": True}, - "mime_type": "mime_type", - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - "mime_type": "mime_type", - }, - ], - x_llama_stack_provider_data="X-LlamaStack-ProviderData", - ) - assert memory is None - - @parametrize - def test_raw_response_update(self, client: LlamaStackClient) -> None: - response = client.memory.with_raw_response.update( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory = response.parse() - assert memory is None - - @parametrize - def test_streaming_response_update(self, client: LlamaStackClient) -> None: - with client.memory.with_streaming_response.update( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory = response.parse() - assert memory is None - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_list(self, client: LlamaStackClient) -> None: - memory = client.memory.list() - assert_matches_type(object, memory, path=["response"]) - - @parametrize - def test_method_list_with_all_params(self, client: LlamaStackClient) -> None: - memory = client.memory.list( - x_llama_stack_provider_data="X-LlamaStack-ProviderData", - ) - assert_matches_type(object, memory, path=["response"]) - - @parametrize - def test_raw_response_list(self, client: LlamaStackClient) -> None: - response = client.memory.with_raw_response.list() - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory = response.parse() - assert_matches_type(object, memory, path=["response"]) - - @parametrize - def test_streaming_response_list(self, client: LlamaStackClient) -> None: - with client.memory.with_streaming_response.list() as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory = response.parse() - assert_matches_type(object, memory, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - def test_method_drop(self, client: LlamaStackClient) -> None: - memory = client.memory.drop( - bank_id="bank_id", - ) - assert_matches_type(str, memory, path=["response"]) - - @parametrize - def test_method_drop_with_all_params(self, client: LlamaStackClient) -> None: - memory = client.memory.drop( - bank_id="bank_id", - x_llama_stack_provider_data="X-LlamaStack-ProviderData", - ) - assert_matches_type(str, memory, path=["response"]) - - @parametrize - def test_raw_response_drop(self, client: LlamaStackClient) -> None: - response = client.memory.with_raw_response.drop( - bank_id="bank_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory = response.parse() - 
assert_matches_type(str, memory, path=["response"]) - - @parametrize - def test_streaming_response_drop(self, client: LlamaStackClient) -> None: - with client.memory.with_streaming_response.drop( - bank_id="bank_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory = response.parse() - assert_matches_type(str, memory, path=["response"]) - - assert cast(Any, response.is_closed) is True - @parametrize def test_method_insert(self, client: LlamaStackClient) -> None: memory = client.memory.insert( @@ -395,7 +134,7 @@ def test_method_query(self, client: LlamaStackClient) -> None: bank_id="bank_id", query="string", ) - assert_matches_type(QueryDocuments, memory, path=["response"]) + assert_matches_type(QueryDocumentsResponse, memory, path=["response"]) @parametrize def test_method_query_with_all_params(self, client: LlamaStackClient) -> None: @@ -405,7 +144,7 @@ def test_method_query_with_all_params(self, client: LlamaStackClient) -> None: params={"foo": True}, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(QueryDocuments, memory, path=["response"]) + assert_matches_type(QueryDocumentsResponse, memory, path=["response"]) @parametrize def test_raw_response_query(self, client: LlamaStackClient) -> None: @@ -417,7 +156,7 @@ def test_raw_response_query(self, client: LlamaStackClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory = response.parse() - assert_matches_type(QueryDocuments, memory, path=["response"]) + assert_matches_type(QueryDocumentsResponse, memory, path=["response"]) @parametrize def test_streaming_response_query(self, client: LlamaStackClient) -> None: @@ -429,7 +168,7 @@ def test_streaming_response_query(self, client: LlamaStackClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory = response.parse() - assert_matches_type(QueryDocuments, memory, path=["response"]) + assert_matches_type(QueryDocumentsResponse, memory, path=["response"]) assert cast(Any, response.is_closed) is True @@ -437,265 +176,6 @@ def test_streaming_response_query(self, client: LlamaStackClient) -> None: class TestAsyncMemory: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) - @parametrize - async def test_method_create(self, async_client: AsyncLlamaStackClient) -> None: - memory = await async_client.memory.create( - body={}, - ) - assert_matches_type(object, memory, path=["response"]) - - @parametrize - async def test_method_create_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - memory = await async_client.memory.create( - body={}, - x_llama_stack_provider_data="X-LlamaStack-ProviderData", - ) - assert_matches_type(object, memory, path=["response"]) - - @parametrize - async def test_raw_response_create(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.memory.with_raw_response.create( - body={}, - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory = await response.parse() - assert_matches_type(object, memory, path=["response"]) - - @parametrize - async def test_streaming_response_create(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.memory.with_streaming_response.create( - body={}, - ) as response: - assert not response.is_closed - assert 
response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory = await response.parse() - assert_matches_type(object, memory, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> None: - memory = await async_client.memory.retrieve( - bank_id="bank_id", - ) - assert_matches_type(object, memory, path=["response"]) - - @parametrize - async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - memory = await async_client.memory.retrieve( - bank_id="bank_id", - x_llama_stack_provider_data="X-LlamaStack-ProviderData", - ) - assert_matches_type(object, memory, path=["response"]) - - @parametrize - async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.memory.with_raw_response.retrieve( - bank_id="bank_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory = await response.parse() - assert_matches_type(object, memory, path=["response"]) - - @parametrize - async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.memory.with_streaming_response.retrieve( - bank_id="bank_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory = await response.parse() - assert_matches_type(object, memory, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_update(self, async_client: AsyncLlamaStackClient) -> None: - memory = await async_client.memory.update( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], - ) - assert memory is None - - @parametrize - async def test_method_update_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - memory = await async_client.memory.update( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - "mime_type": "mime_type", - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - "mime_type": "mime_type", - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - "mime_type": "mime_type", - }, - ], - x_llama_stack_provider_data="X-LlamaStack-ProviderData", - ) - assert memory is None - - @parametrize - async def test_raw_response_update(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.memory.with_raw_response.update( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory = await response.parse() - assert memory is None - - @parametrize - async def test_streaming_response_update(self, async_client: AsyncLlamaStackClient) -> None: - async with 
async_client.memory.with_streaming_response.update( - bank_id="bank_id", - documents=[ - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - { - "content": "string", - "document_id": "document_id", - "metadata": {"foo": True}, - }, - ], - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory = await response.parse() - assert memory is None - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_list(self, async_client: AsyncLlamaStackClient) -> None: - memory = await async_client.memory.list() - assert_matches_type(object, memory, path=["response"]) - - @parametrize - async def test_method_list_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - memory = await async_client.memory.list( - x_llama_stack_provider_data="X-LlamaStack-ProviderData", - ) - assert_matches_type(object, memory, path=["response"]) - - @parametrize - async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.memory.with_raw_response.list() - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory = await response.parse() - assert_matches_type(object, memory, path=["response"]) - - @parametrize - async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.memory.with_streaming_response.list() as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory = await response.parse() - assert_matches_type(object, memory, path=["response"]) - - assert cast(Any, response.is_closed) is True - - @parametrize - async def test_method_drop(self, async_client: AsyncLlamaStackClient) -> None: - memory = await async_client.memory.drop( - bank_id="bank_id", - ) - assert_matches_type(str, memory, path=["response"]) - - @parametrize - async def test_method_drop_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - memory = await async_client.memory.drop( - bank_id="bank_id", - x_llama_stack_provider_data="X-LlamaStack-ProviderData", - ) - assert_matches_type(str, memory, path=["response"]) - - @parametrize - async def test_raw_response_drop(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.memory.with_raw_response.drop( - bank_id="bank_id", - ) - - assert response.is_closed is True - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - memory = await response.parse() - assert_matches_type(str, memory, path=["response"]) - - @parametrize - async def test_streaming_response_drop(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.memory.with_streaming_response.drop( - bank_id="bank_id", - ) as response: - assert not response.is_closed - assert response.http_request.headers.get("X-Stainless-Lang") == "python" - - memory = await response.parse() - assert_matches_type(str, memory, path=["response"]) - - assert cast(Any, response.is_closed) is True - @parametrize async def test_method_insert(self, async_client: AsyncLlamaStackClient) -> None: memory = await async_client.memory.insert( @@ -813,7 +293,7 @@ async def test_method_query(self, async_client: AsyncLlamaStackClient) -> None: bank_id="bank_id", query="string", ) - 
assert_matches_type(QueryDocuments, memory, path=["response"]) + assert_matches_type(QueryDocumentsResponse, memory, path=["response"]) @parametrize async def test_method_query_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: @@ -823,7 +303,7 @@ async def test_method_query_with_all_params(self, async_client: AsyncLlamaStackC params={"foo": True}, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(QueryDocuments, memory, path=["response"]) + assert_matches_type(QueryDocumentsResponse, memory, path=["response"]) @parametrize async def test_raw_response_query(self, async_client: AsyncLlamaStackClient) -> None: @@ -835,7 +315,7 @@ async def test_raw_response_query(self, async_client: AsyncLlamaStackClient) -> assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory = await response.parse() - assert_matches_type(QueryDocuments, memory, path=["response"]) + assert_matches_type(QueryDocumentsResponse, memory, path=["response"]) @parametrize async def test_streaming_response_query(self, async_client: AsyncLlamaStackClient) -> None: @@ -847,6 +327,6 @@ async def test_streaming_response_query(self, async_client: AsyncLlamaStackClien assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory = await response.parse() - assert_matches_type(QueryDocuments, memory, path=["response"]) + assert_matches_type(QueryDocumentsResponse, memory, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_memory_banks.py b/tests/api_resources/test_memory_banks.py index 764787b3..ba73846a 100644 --- a/tests/api_resources/test_memory_banks.py +++ b/tests/api_resources/test_memory_banks.py @@ -9,7 +9,10 @@ from tests.utils import assert_matches_type from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types import MemoryBankSpec +from llama_stack_client.types import ( + MemoryBankListResponse, + MemoryBankRetrieveResponse, +) base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -17,18 +20,66 @@ class TestMemoryBanks: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + @parametrize + def test_method_retrieve(self, client: LlamaStackClient) -> None: + memory_bank = client.memory_banks.retrieve( + identifier="identifier", + ) + assert_matches_type(Optional[MemoryBankRetrieveResponse], memory_bank, path=["response"]) + + @parametrize + def test_method_retrieve_with_all_params(self, client: LlamaStackClient) -> None: + memory_bank = client.memory_banks.retrieve( + identifier="identifier", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[MemoryBankRetrieveResponse], memory_bank, path=["response"]) + + @parametrize + def test_raw_response_retrieve(self, client: LlamaStackClient) -> None: + response = client.memory_banks.with_raw_response.retrieve( + identifier="identifier", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory_bank = response.parse() + assert_matches_type(Optional[MemoryBankRetrieveResponse], memory_bank, path=["response"]) + + @parametrize + def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None: + with client.memory_banks.with_streaming_response.retrieve( + identifier="identifier", + ) as response: + assert not response.is_closed + assert 
response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory_bank = response.parse() + assert_matches_type(Optional[MemoryBankRetrieveResponse], memory_bank, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_method_list(self, client: LlamaStackClient) -> None: memory_bank = client.memory_banks.list() - assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) + assert_matches_type(MemoryBankListResponse, memory_bank, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_method_list_with_all_params(self, client: LlamaStackClient) -> None: memory_bank = client.memory_banks.list( x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) + assert_matches_type(MemoryBankListResponse, memory_bank, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_raw_response_list(self, client: LlamaStackClient) -> None: response = client.memory_banks.with_raw_response.list() @@ -36,8 +87,11 @@ def test_raw_response_list(self, client: LlamaStackClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = response.parse() - assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) + assert_matches_type(MemoryBankListResponse, memory_bank, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_streaming_response_list(self, client: LlamaStackClient) -> None: with client.memory_banks.with_streaming_response.list() as response: @@ -45,46 +99,71 @@ def test_streaming_response_list(self, client: LlamaStackClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = response.parse() - assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) + assert_matches_type(MemoryBankListResponse, memory_bank, path=["response"]) assert cast(Any, response.is_closed) is True @parametrize - def test_method_get(self, client: LlamaStackClient) -> None: - memory_bank = client.memory_banks.get( - bank_type="vector", + def test_method_register(self, client: LlamaStackClient) -> None: + memory_bank = client.memory_banks.register( + memory_bank={ + "chunk_size_in_tokens": 0, + "embedding_model": "embedding_model", + "identifier": "identifier", + "provider_id": "provider_id", + "type": "vector", + }, ) - assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) + assert memory_bank is None @parametrize - def test_method_get_with_all_params(self, client: LlamaStackClient) -> None: - memory_bank = client.memory_banks.get( - bank_type="vector", + def test_method_register_with_all_params(self, client: LlamaStackClient) -> None: + memory_bank = client.memory_banks.register( + memory_bank={ + "chunk_size_in_tokens": 0, + "embedding_model": "embedding_model", + "identifier": "identifier", + "provider_id": "provider_id", + "type": "vector", + "overlap_size_in_tokens": 0, + }, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - 
assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) + assert memory_bank is None @parametrize - def test_raw_response_get(self, client: LlamaStackClient) -> None: - response = client.memory_banks.with_raw_response.get( - bank_type="vector", + def test_raw_response_register(self, client: LlamaStackClient) -> None: + response = client.memory_banks.with_raw_response.register( + memory_bank={ + "chunk_size_in_tokens": 0, + "embedding_model": "embedding_model", + "identifier": "identifier", + "provider_id": "provider_id", + "type": "vector", + }, ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = response.parse() - assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) + assert memory_bank is None @parametrize - def test_streaming_response_get(self, client: LlamaStackClient) -> None: - with client.memory_banks.with_streaming_response.get( - bank_type="vector", + def test_streaming_response_register(self, client: LlamaStackClient) -> None: + with client.memory_banks.with_streaming_response.register( + memory_bank={ + "chunk_size_in_tokens": 0, + "embedding_model": "embedding_model", + "identifier": "identifier", + "provider_id": "provider_id", + "type": "vector", + }, ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = response.parse() - assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) + assert memory_bank is None assert cast(Any, response.is_closed) is True @@ -92,18 +171,66 @@ def test_streaming_response_get(self, client: LlamaStackClient) -> None: class TestAsyncMemoryBanks: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + @parametrize + async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + memory_bank = await async_client.memory_banks.retrieve( + identifier="identifier", + ) + assert_matches_type(Optional[MemoryBankRetrieveResponse], memory_bank, path=["response"]) + + @parametrize + async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + memory_bank = await async_client.memory_banks.retrieve( + identifier="identifier", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[MemoryBankRetrieveResponse], memory_bank, path=["response"]) + + @parametrize + async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.memory_banks.with_raw_response.retrieve( + identifier="identifier", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + memory_bank = await response.parse() + assert_matches_type(Optional[MemoryBankRetrieveResponse], memory_bank, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.memory_banks.with_streaming_response.retrieve( + identifier="identifier", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + memory_bank = await response.parse() + assert_matches_type(Optional[MemoryBankRetrieveResponse], memory_bank, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type 
application/jsonl, Prism mock server will fail" + ) @parametrize async def test_method_list(self, async_client: AsyncLlamaStackClient) -> None: memory_bank = await async_client.memory_banks.list() - assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) + assert_matches_type(MemoryBankListResponse, memory_bank, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_method_list_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: memory_bank = await async_client.memory_banks.list( x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) + assert_matches_type(MemoryBankListResponse, memory_bank, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> None: response = await async_client.memory_banks.with_raw_response.list() @@ -111,8 +238,11 @@ async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> N assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = await response.parse() - assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) + assert_matches_type(MemoryBankListResponse, memory_bank, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient) -> None: async with async_client.memory_banks.with_streaming_response.list() as response: @@ -120,45 +250,70 @@ async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = await response.parse() - assert_matches_type(MemoryBankSpec, memory_bank, path=["response"]) + assert_matches_type(MemoryBankListResponse, memory_bank, path=["response"]) assert cast(Any, response.is_closed) is True @parametrize - async def test_method_get(self, async_client: AsyncLlamaStackClient) -> None: - memory_bank = await async_client.memory_banks.get( - bank_type="vector", + async def test_method_register(self, async_client: AsyncLlamaStackClient) -> None: + memory_bank = await async_client.memory_banks.register( + memory_bank={ + "chunk_size_in_tokens": 0, + "embedding_model": "embedding_model", + "identifier": "identifier", + "provider_id": "provider_id", + "type": "vector", + }, ) - assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) + assert memory_bank is None @parametrize - async def test_method_get_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - memory_bank = await async_client.memory_banks.get( - bank_type="vector", + async def test_method_register_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + memory_bank = await async_client.memory_banks.register( + memory_bank={ + "chunk_size_in_tokens": 0, + "embedding_model": "embedding_model", + "identifier": "identifier", + "provider_id": "provider_id", + "type": "vector", + "overlap_size_in_tokens": 0, + }, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(Optional[MemoryBankSpec], memory_bank, 
path=["response"]) + assert memory_bank is None @parametrize - async def test_raw_response_get(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.memory_banks.with_raw_response.get( - bank_type="vector", + async def test_raw_response_register(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.memory_banks.with_raw_response.register( + memory_bank={ + "chunk_size_in_tokens": 0, + "embedding_model": "embedding_model", + "identifier": "identifier", + "provider_id": "provider_id", + "type": "vector", + }, ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = await response.parse() - assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) + assert memory_bank is None @parametrize - async def test_streaming_response_get(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.memory_banks.with_streaming_response.get( - bank_type="vector", + async def test_streaming_response_register(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.memory_banks.with_streaming_response.register( + memory_bank={ + "chunk_size_in_tokens": 0, + "embedding_model": "embedding_model", + "identifier": "identifier", + "provider_id": "provider_id", + "type": "vector", + }, ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" memory_bank = await response.parse() - assert_matches_type(Optional[MemoryBankSpec], memory_bank, path=["response"]) + assert memory_bank is None assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_models.py b/tests/api_resources/test_models.py index 83cfa612..144c967c 100644 --- a/tests/api_resources/test_models.py +++ b/tests/api_resources/test_models.py @@ -9,7 +9,7 @@ from tests.utils import assert_matches_type from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types import ModelServingSpec +from llama_stack_client.types import ModelDefWithProvider base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -17,18 +17,66 @@ class TestModels: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + @parametrize + def test_method_retrieve(self, client: LlamaStackClient) -> None: + model = client.models.retrieve( + identifier="identifier", + ) + assert_matches_type(Optional[ModelDefWithProvider], model, path=["response"]) + + @parametrize + def test_method_retrieve_with_all_params(self, client: LlamaStackClient) -> None: + model = client.models.retrieve( + identifier="identifier", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[ModelDefWithProvider], model, path=["response"]) + + @parametrize + def test_raw_response_retrieve(self, client: LlamaStackClient) -> None: + response = client.models.with_raw_response.retrieve( + identifier="identifier", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + model = response.parse() + assert_matches_type(Optional[ModelDefWithProvider], model, path=["response"]) + + @parametrize + def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None: + with client.models.with_streaming_response.retrieve( + identifier="identifier", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == 
"python" + + model = response.parse() + assert_matches_type(Optional[ModelDefWithProvider], model, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_method_list(self, client: LlamaStackClient) -> None: model = client.models.list() - assert_matches_type(ModelServingSpec, model, path=["response"]) + assert_matches_type(ModelDefWithProvider, model, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_method_list_with_all_params(self, client: LlamaStackClient) -> None: model = client.models.list( x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(ModelServingSpec, model, path=["response"]) + assert_matches_type(ModelDefWithProvider, model, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_raw_response_list(self, client: LlamaStackClient) -> None: response = client.models.with_raw_response.list() @@ -36,8 +84,11 @@ def test_raw_response_list(self, client: LlamaStackClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = response.parse() - assert_matches_type(ModelServingSpec, model, path=["response"]) + assert_matches_type(ModelDefWithProvider, model, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_streaming_response_list(self, client: LlamaStackClient) -> None: with client.models.with_streaming_response.list() as response: @@ -45,46 +96,66 @@ def test_streaming_response_list(self, client: LlamaStackClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = response.parse() - assert_matches_type(ModelServingSpec, model, path=["response"]) + assert_matches_type(ModelDefWithProvider, model, path=["response"]) assert cast(Any, response.is_closed) is True @parametrize - def test_method_get(self, client: LlamaStackClient) -> None: - model = client.models.get( - core_model_id="core_model_id", + def test_method_register(self, client: LlamaStackClient) -> None: + model = client.models.register( + model={ + "identifier": "identifier", + "llama_model": "llama_model", + "metadata": {"foo": True}, + "provider_id": "provider_id", + }, ) - assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + assert model is None @parametrize - def test_method_get_with_all_params(self, client: LlamaStackClient) -> None: - model = client.models.get( - core_model_id="core_model_id", + def test_method_register_with_all_params(self, client: LlamaStackClient) -> None: + model = client.models.register( + model={ + "identifier": "identifier", + "llama_model": "llama_model", + "metadata": {"foo": True}, + "provider_id": "provider_id", + }, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + assert model is None @parametrize - def test_raw_response_get(self, client: LlamaStackClient) -> None: - response = client.models.with_raw_response.get( - core_model_id="core_model_id", + def test_raw_response_register(self, client: 
LlamaStackClient) -> None: + response = client.models.with_raw_response.register( + model={ + "identifier": "identifier", + "llama_model": "llama_model", + "metadata": {"foo": True}, + "provider_id": "provider_id", + }, ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = response.parse() - assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + assert model is None @parametrize - def test_streaming_response_get(self, client: LlamaStackClient) -> None: - with client.models.with_streaming_response.get( - core_model_id="core_model_id", + def test_streaming_response_register(self, client: LlamaStackClient) -> None: + with client.models.with_streaming_response.register( + model={ + "identifier": "identifier", + "llama_model": "llama_model", + "metadata": {"foo": True}, + "provider_id": "provider_id", + }, ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = response.parse() - assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + assert model is None assert cast(Any, response.is_closed) is True @@ -92,18 +163,66 @@ def test_streaming_response_get(self, client: LlamaStackClient) -> None: class TestAsyncModels: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + @parametrize + async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + model = await async_client.models.retrieve( + identifier="identifier", + ) + assert_matches_type(Optional[ModelDefWithProvider], model, path=["response"]) + + @parametrize + async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + model = await async_client.models.retrieve( + identifier="identifier", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[ModelDefWithProvider], model, path=["response"]) + + @parametrize + async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.models.with_raw_response.retrieve( + identifier="identifier", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + model = await response.parse() + assert_matches_type(Optional[ModelDefWithProvider], model, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.models.with_streaming_response.retrieve( + identifier="identifier", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + model = await response.parse() + assert_matches_type(Optional[ModelDefWithProvider], model, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_method_list(self, async_client: AsyncLlamaStackClient) -> None: model = await async_client.models.list() - assert_matches_type(ModelServingSpec, model, path=["response"]) + assert_matches_type(ModelDefWithProvider, model, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_method_list_with_all_params(self, 
async_client: AsyncLlamaStackClient) -> None: model = await async_client.models.list( x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(ModelServingSpec, model, path=["response"]) + assert_matches_type(ModelDefWithProvider, model, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> None: response = await async_client.models.with_raw_response.list() @@ -111,8 +230,11 @@ async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> N assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = await response.parse() - assert_matches_type(ModelServingSpec, model, path=["response"]) + assert_matches_type(ModelDefWithProvider, model, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient) -> None: async with async_client.models.with_streaming_response.list() as response: @@ -120,45 +242,65 @@ async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = await response.parse() - assert_matches_type(ModelServingSpec, model, path=["response"]) + assert_matches_type(ModelDefWithProvider, model, path=["response"]) assert cast(Any, response.is_closed) is True @parametrize - async def test_method_get(self, async_client: AsyncLlamaStackClient) -> None: - model = await async_client.models.get( - core_model_id="core_model_id", + async def test_method_register(self, async_client: AsyncLlamaStackClient) -> None: + model = await async_client.models.register( + model={ + "identifier": "identifier", + "llama_model": "llama_model", + "metadata": {"foo": True}, + "provider_id": "provider_id", + }, ) - assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + assert model is None @parametrize - async def test_method_get_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - model = await async_client.models.get( - core_model_id="core_model_id", + async def test_method_register_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + model = await async_client.models.register( + model={ + "identifier": "identifier", + "llama_model": "llama_model", + "metadata": {"foo": True}, + "provider_id": "provider_id", + }, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + assert model is None @parametrize - async def test_raw_response_get(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.models.with_raw_response.get( - core_model_id="core_model_id", + async def test_raw_response_register(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.models.with_raw_response.register( + model={ + "identifier": "identifier", + "llama_model": "llama_model", + "metadata": {"foo": True}, + "provider_id": "provider_id", + }, ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = await response.parse() - assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + assert model is 
None @parametrize - async def test_streaming_response_get(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.models.with_streaming_response.get( - core_model_id="core_model_id", + async def test_streaming_response_register(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.models.with_streaming_response.register( + model={ + "identifier": "identifier", + "llama_model": "llama_model", + "metadata": {"foo": True}, + "provider_id": "provider_id", + }, ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" model = await response.parse() - assert_matches_type(Optional[ModelServingSpec], model, path=["response"]) + assert model is None assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_post_training.py b/tests/api_resources/test_post_training.py index 5b1db4ad..ab79e75d 100644 --- a/tests/api_resources/test_post_training.py +++ b/tests/api_resources/test_post_training.py @@ -29,10 +29,7 @@ def test_method_preference_optimize(self, client: LlamaStackClient) -> None: "reward_clip": 0, "reward_scale": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + dataset_id="dataset_id", finetuned_model="https://example.com", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", @@ -52,10 +49,7 @@ def test_method_preference_optimize(self, client: LlamaStackClient) -> None: "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + validation_dataset_id="validation_dataset_id", ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -69,11 +63,7 @@ def test_method_preference_optimize_with_all_params(self, client: LlamaStackClie "reward_clip": 0, "reward_scale": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - "metadata": {"foo": True}, - }, + dataset_id="dataset_id", finetuned_model="https://example.com", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", @@ -93,11 +83,7 @@ def test_method_preference_optimize_with_all_params(self, client: LlamaStackClie "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - "metadata": {"foo": True}, - }, + validation_dataset_id="validation_dataset_id", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -112,10 +98,7 @@ def test_raw_response_preference_optimize(self, client: LlamaStackClient) -> Non "reward_clip": 0, "reward_scale": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + dataset_id="dataset_id", finetuned_model="https://example.com", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", @@ -135,10 +118,7 @@ def test_raw_response_preference_optimize(self, client: LlamaStackClient) -> Non "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + validation_dataset_id="validation_dataset_id", ) assert response.is_closed is True @@ -156,10 +136,7 @@ def test_streaming_response_preference_optimize(self, client: LlamaStackClient) "reward_clip": 0, "reward_scale": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + dataset_id="dataset_id", finetuned_model="https://example.com", 
hyperparam_search_config={"foo": True}, job_uuid="job_uuid", @@ -179,10 +156,7 @@ def test_streaming_response_preference_optimize(self, client: LlamaStackClient) "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + validation_dataset_id="validation_dataset_id", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -203,10 +177,7 @@ def test_method_supervised_fine_tune(self, client: LlamaStackClient) -> None: "lora_attn_modules": ["string", "string", "string"], "rank": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + dataset_id="dataset_id", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", logger_config={"foo": True}, @@ -226,10 +197,7 @@ def test_method_supervised_fine_tune(self, client: LlamaStackClient) -> None: "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + validation_dataset_id="validation_dataset_id", ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -244,11 +212,7 @@ def test_method_supervised_fine_tune_with_all_params(self, client: LlamaStackCli "lora_attn_modules": ["string", "string", "string"], "rank": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - "metadata": {"foo": True}, - }, + dataset_id="dataset_id", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", logger_config={"foo": True}, @@ -268,11 +232,7 @@ def test_method_supervised_fine_tune_with_all_params(self, client: LlamaStackCli "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - "metadata": {"foo": True}, - }, + validation_dataset_id="validation_dataset_id", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -288,10 +248,7 @@ def test_raw_response_supervised_fine_tune(self, client: LlamaStackClient) -> No "lora_attn_modules": ["string", "string", "string"], "rank": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + dataset_id="dataset_id", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", logger_config={"foo": True}, @@ -311,10 +268,7 @@ def test_raw_response_supervised_fine_tune(self, client: LlamaStackClient) -> No "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + validation_dataset_id="validation_dataset_id", ) assert response.is_closed is True @@ -333,10 +287,7 @@ def test_streaming_response_supervised_fine_tune(self, client: LlamaStackClient) "lora_attn_modules": ["string", "string", "string"], "rank": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + dataset_id="dataset_id", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", logger_config={"foo": True}, @@ -356,10 +307,7 @@ def test_streaming_response_supervised_fine_tune(self, client: LlamaStackClient) "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + validation_dataset_id="validation_dataset_id", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -383,10 +331,7 @@ async def 
test_method_preference_optimize(self, async_client: AsyncLlamaStackCli "reward_clip": 0, "reward_scale": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + dataset_id="dataset_id", finetuned_model="https://example.com", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", @@ -406,10 +351,7 @@ async def test_method_preference_optimize(self, async_client: AsyncLlamaStackCli "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + validation_dataset_id="validation_dataset_id", ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -423,11 +365,7 @@ async def test_method_preference_optimize_with_all_params(self, async_client: As "reward_clip": 0, "reward_scale": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - "metadata": {"foo": True}, - }, + dataset_id="dataset_id", finetuned_model="https://example.com", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", @@ -447,11 +385,7 @@ async def test_method_preference_optimize_with_all_params(self, async_client: As "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - "metadata": {"foo": True}, - }, + validation_dataset_id="validation_dataset_id", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -466,10 +400,7 @@ async def test_raw_response_preference_optimize(self, async_client: AsyncLlamaSt "reward_clip": 0, "reward_scale": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + dataset_id="dataset_id", finetuned_model="https://example.com", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", @@ -489,10 +420,7 @@ async def test_raw_response_preference_optimize(self, async_client: AsyncLlamaSt "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + validation_dataset_id="validation_dataset_id", ) assert response.is_closed is True @@ -510,10 +438,7 @@ async def test_streaming_response_preference_optimize(self, async_client: AsyncL "reward_clip": 0, "reward_scale": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + dataset_id="dataset_id", finetuned_model="https://example.com", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", @@ -533,10 +458,7 @@ async def test_streaming_response_preference_optimize(self, async_client: AsyncL "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + validation_dataset_id="validation_dataset_id", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" @@ -557,10 +479,7 @@ async def test_method_supervised_fine_tune(self, async_client: AsyncLlamaStackCl "lora_attn_modules": ["string", "string", "string"], "rank": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + dataset_id="dataset_id", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", logger_config={"foo": True}, @@ -580,10 +499,7 @@ async def test_method_supervised_fine_tune(self, async_client: AsyncLlamaStackCl "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": 
"https://example.com", - }, + validation_dataset_id="validation_dataset_id", ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -598,11 +514,7 @@ async def test_method_supervised_fine_tune_with_all_params(self, async_client: A "lora_attn_modules": ["string", "string", "string"], "rank": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - "metadata": {"foo": True}, - }, + dataset_id="dataset_id", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", logger_config={"foo": True}, @@ -622,11 +534,7 @@ async def test_method_supervised_fine_tune_with_all_params(self, async_client: A "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - "metadata": {"foo": True}, - }, + validation_dataset_id="validation_dataset_id", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) assert_matches_type(PostTrainingJob, post_training, path=["response"]) @@ -642,10 +550,7 @@ async def test_raw_response_supervised_fine_tune(self, async_client: AsyncLlamaS "lora_attn_modules": ["string", "string", "string"], "rank": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + dataset_id="dataset_id", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", logger_config={"foo": True}, @@ -665,10 +570,7 @@ async def test_raw_response_supervised_fine_tune(self, async_client: AsyncLlamaS "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + validation_dataset_id="validation_dataset_id", ) assert response.is_closed is True @@ -687,10 +589,7 @@ async def test_streaming_response_supervised_fine_tune(self, async_client: Async "lora_attn_modules": ["string", "string", "string"], "rank": 0, }, - dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + dataset_id="dataset_id", hyperparam_search_config={"foo": True}, job_uuid="job_uuid", logger_config={"foo": True}, @@ -710,10 +609,7 @@ async def test_streaming_response_supervised_fine_tune(self, async_client: Async "n_iters": 0, "shuffle": True, }, - validation_dataset={ - "columns": {"foo": "dialog"}, - "content_url": "https://example.com", - }, + validation_dataset_id="validation_dataset_id", ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" diff --git a/tests/api_resources/test_providers.py b/tests/api_resources/test_providers.py new file mode 100644 index 00000000..98e14e04 --- /dev/null +++ b/tests/api_resources/test_providers.py @@ -0,0 +1,86 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from tests.utils import assert_matches_type +from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient +from llama_stack_client.types import ProviderListResponse + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestProviders: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_list(self, client: LlamaStackClient) -> None: + provider = client.providers.list() + assert_matches_type(ProviderListResponse, provider, path=["response"]) + + @parametrize + def test_method_list_with_all_params(self, client: LlamaStackClient) -> None: + provider = client.providers.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(ProviderListResponse, provider, path=["response"]) + + @parametrize + def test_raw_response_list(self, client: LlamaStackClient) -> None: + response = client.providers.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + provider = response.parse() + assert_matches_type(ProviderListResponse, provider, path=["response"]) + + @parametrize + def test_streaming_response_list(self, client: LlamaStackClient) -> None: + with client.providers.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + provider = response.parse() + assert_matches_type(ProviderListResponse, provider, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncProviders: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_list(self, async_client: AsyncLlamaStackClient) -> None: + provider = await async_client.providers.list() + assert_matches_type(ProviderListResponse, provider, path=["response"]) + + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + provider = await async_client.providers.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(ProviderListResponse, provider, path=["response"]) + + @parametrize + async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.providers.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + provider = await response.parse() + assert_matches_type(ProviderListResponse, provider, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.providers.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + provider = await response.parse() + assert_matches_type(ProviderListResponse, provider, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_reward_scoring.py b/tests/api_resources/test_reward_scoring.py index 7d78fd07..12823fbd 100644 --- a/tests/api_resources/test_reward_scoring.py +++ b/tests/api_resources/test_reward_scoring.py @@ -9,7 +9,7 @@ from tests.utils import assert_matches_type from llama_stack_client import 
LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types import RewardScoring +from llama_stack_client.types import RewardScoringResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -114,7 +114,7 @@ def test_method_score(self, client: LlamaStackClient) -> None: ], model="model", ) - assert_matches_type(RewardScoring, reward_scoring, path=["response"]) + assert_matches_type(RewardScoringResponse, reward_scoring, path=["response"]) @parametrize def test_method_score_with_all_params(self, client: LlamaStackClient) -> None: @@ -232,7 +232,7 @@ def test_method_score_with_all_params(self, client: LlamaStackClient) -> None: model="model", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(RewardScoring, reward_scoring, path=["response"]) + assert_matches_type(RewardScoringResponse, reward_scoring, path=["response"]) @parametrize def test_raw_response_score(self, client: LlamaStackClient) -> None: @@ -335,7 +335,7 @@ def test_raw_response_score(self, client: LlamaStackClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" reward_scoring = response.parse() - assert_matches_type(RewardScoring, reward_scoring, path=["response"]) + assert_matches_type(RewardScoringResponse, reward_scoring, path=["response"]) @parametrize def test_streaming_response_score(self, client: LlamaStackClient) -> None: @@ -438,7 +438,7 @@ def test_streaming_response_score(self, client: LlamaStackClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" reward_scoring = response.parse() - assert_matches_type(RewardScoring, reward_scoring, path=["response"]) + assert_matches_type(RewardScoringResponse, reward_scoring, path=["response"]) assert cast(Any, response.is_closed) is True @@ -543,7 +543,7 @@ async def test_method_score(self, async_client: AsyncLlamaStackClient) -> None: ], model="model", ) - assert_matches_type(RewardScoring, reward_scoring, path=["response"]) + assert_matches_type(RewardScoringResponse, reward_scoring, path=["response"]) @parametrize async def test_method_score_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: @@ -661,7 +661,7 @@ async def test_method_score_with_all_params(self, async_client: AsyncLlamaStackC model="model", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(RewardScoring, reward_scoring, path=["response"]) + assert_matches_type(RewardScoringResponse, reward_scoring, path=["response"]) @parametrize async def test_raw_response_score(self, async_client: AsyncLlamaStackClient) -> None: @@ -764,7 +764,7 @@ async def test_raw_response_score(self, async_client: AsyncLlamaStackClient) -> assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" reward_scoring = await response.parse() - assert_matches_type(RewardScoring, reward_scoring, path=["response"]) + assert_matches_type(RewardScoringResponse, reward_scoring, path=["response"]) @parametrize async def test_streaming_response_score(self, async_client: AsyncLlamaStackClient) -> None: @@ -867,6 +867,6 @@ async def test_streaming_response_score(self, async_client: AsyncLlamaStackClien assert response.http_request.headers.get("X-Stainless-Lang") == "python" reward_scoring = await response.parse() - assert_matches_type(RewardScoring, reward_scoring, path=["response"]) + assert_matches_type(RewardScoringResponse, reward_scoring, path=["response"]) assert cast(Any, 
response.is_closed) is True diff --git a/tests/api_resources/test_routes.py b/tests/api_resources/test_routes.py new file mode 100644 index 00000000..d7fb5f71 --- /dev/null +++ b/tests/api_resources/test_routes.py @@ -0,0 +1,86 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from tests.utils import assert_matches_type +from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient +from llama_stack_client.types import RouteListResponse + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestRoutes: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_list(self, client: LlamaStackClient) -> None: + route = client.routes.list() + assert_matches_type(RouteListResponse, route, path=["response"]) + + @parametrize + def test_method_list_with_all_params(self, client: LlamaStackClient) -> None: + route = client.routes.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(RouteListResponse, route, path=["response"]) + + @parametrize + def test_raw_response_list(self, client: LlamaStackClient) -> None: + response = client.routes.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + route = response.parse() + assert_matches_type(RouteListResponse, route, path=["response"]) + + @parametrize + def test_streaming_response_list(self, client: LlamaStackClient) -> None: + with client.routes.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + route = response.parse() + assert_matches_type(RouteListResponse, route, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncRoutes: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_list(self, async_client: AsyncLlamaStackClient) -> None: + route = await async_client.routes.list() + assert_matches_type(RouteListResponse, route, path=["response"]) + + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + route = await async_client.routes.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(RouteListResponse, route, path=["response"]) + + @parametrize + async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.routes.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + route = await response.parse() + assert_matches_type(RouteListResponse, route, path=["response"]) + + @parametrize + async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.routes.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + route = await response.parse() + assert_matches_type(RouteListResponse, route, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_safety.py b/tests/api_resources/test_safety.py index 29761006..d1ffba98 
100644 --- a/tests/api_resources/test_safety.py +++ b/tests/api_resources/test_safety.py @@ -9,7 +9,7 @@ from tests.utils import assert_matches_type from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types import RunSheidResponse +from llama_stack_client.types import RunShieldResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -37,7 +37,7 @@ def test_method_run_shield(self, client: LlamaStackClient) -> None: params={"foo": True}, shield_type="shield_type", ) - assert_matches_type(RunSheidResponse, safety, path=["response"]) + assert_matches_type(RunShieldResponse, safety, path=["response"]) @parametrize def test_method_run_shield_with_all_params(self, client: LlamaStackClient) -> None: @@ -63,7 +63,7 @@ def test_method_run_shield_with_all_params(self, client: LlamaStackClient) -> No shield_type="shield_type", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(RunSheidResponse, safety, path=["response"]) + assert_matches_type(RunShieldResponse, safety, path=["response"]) @parametrize def test_raw_response_run_shield(self, client: LlamaStackClient) -> None: @@ -89,7 +89,7 @@ def test_raw_response_run_shield(self, client: LlamaStackClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" safety = response.parse() - assert_matches_type(RunSheidResponse, safety, path=["response"]) + assert_matches_type(RunShieldResponse, safety, path=["response"]) @parametrize def test_streaming_response_run_shield(self, client: LlamaStackClient) -> None: @@ -115,7 +115,7 @@ def test_streaming_response_run_shield(self, client: LlamaStackClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" safety = response.parse() - assert_matches_type(RunSheidResponse, safety, path=["response"]) + assert_matches_type(RunShieldResponse, safety, path=["response"]) assert cast(Any, response.is_closed) is True @@ -143,7 +143,7 @@ async def test_method_run_shield(self, async_client: AsyncLlamaStackClient) -> N params={"foo": True}, shield_type="shield_type", ) - assert_matches_type(RunSheidResponse, safety, path=["response"]) + assert_matches_type(RunShieldResponse, safety, path=["response"]) @parametrize async def test_method_run_shield_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: @@ -169,7 +169,7 @@ async def test_method_run_shield_with_all_params(self, async_client: AsyncLlamaS shield_type="shield_type", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(RunSheidResponse, safety, path=["response"]) + assert_matches_type(RunShieldResponse, safety, path=["response"]) @parametrize async def test_raw_response_run_shield(self, async_client: AsyncLlamaStackClient) -> None: @@ -195,7 +195,7 @@ async def test_raw_response_run_shield(self, async_client: AsyncLlamaStackClient assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" safety = await response.parse() - assert_matches_type(RunSheidResponse, safety, path=["response"]) + assert_matches_type(RunShieldResponse, safety, path=["response"]) @parametrize async def test_streaming_response_run_shield(self, async_client: AsyncLlamaStackClient) -> None: @@ -221,6 +221,6 @@ async def test_streaming_response_run_shield(self, async_client: AsyncLlamaStack assert response.http_request.headers.get("X-Stainless-Lang") == "python" safety = await response.parse() - 
assert_matches_type(RunSheidResponse, safety, path=["response"]) + assert_matches_type(RunShieldResponse, safety, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_scoring.py b/tests/api_resources/test_scoring.py new file mode 100644 index 00000000..d209b4db --- /dev/null +++ b/tests/api_resources/test_scoring.py @@ -0,0 +1,202 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, cast + +import pytest + +from tests.utils import assert_matches_type +from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient +from llama_stack_client.types import ScoreResponse, ScoreBatchResponse + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestScoring: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_score(self, client: LlamaStackClient) -> None: + scoring = client.scoring.score( + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) + assert_matches_type(ScoreResponse, scoring, path=["response"]) + + @parametrize + def test_method_score_with_all_params(self, client: LlamaStackClient) -> None: + scoring = client.scoring.score( + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(ScoreResponse, scoring, path=["response"]) + + @parametrize + def test_raw_response_score(self, client: LlamaStackClient) -> None: + response = client.scoring.with_raw_response.score( + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + scoring = response.parse() + assert_matches_type(ScoreResponse, scoring, path=["response"]) + + @parametrize + def test_streaming_response_score(self, client: LlamaStackClient) -> None: + with client.scoring.with_streaming_response.score( + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + scoring = response.parse() + assert_matches_type(ScoreResponse, scoring, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_score_batch(self, client: LlamaStackClient) -> None: + scoring = client.scoring.score_batch( + dataset_id="dataset_id", + save_results_dataset=True, + scoring_functions=["string", "string", "string"], + ) + assert_matches_type(ScoreBatchResponse, scoring, path=["response"]) + + @parametrize + def test_method_score_batch_with_all_params(self, client: LlamaStackClient) -> None: + scoring = client.scoring.score_batch( + dataset_id="dataset_id", + save_results_dataset=True, + scoring_functions=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(ScoreBatchResponse, scoring, path=["response"]) + + @parametrize + def test_raw_response_score_batch(self, client: LlamaStackClient) -> None: + response = client.scoring.with_raw_response.score_batch( + dataset_id="dataset_id", + save_results_dataset=True, + 
scoring_functions=["string", "string", "string"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + scoring = response.parse() + assert_matches_type(ScoreBatchResponse, scoring, path=["response"]) + + @parametrize + def test_streaming_response_score_batch(self, client: LlamaStackClient) -> None: + with client.scoring.with_streaming_response.score_batch( + dataset_id="dataset_id", + save_results_dataset=True, + scoring_functions=["string", "string", "string"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + scoring = response.parse() + assert_matches_type(ScoreBatchResponse, scoring, path=["response"]) + + assert cast(Any, response.is_closed) is True + + +class TestAsyncScoring: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_score(self, async_client: AsyncLlamaStackClient) -> None: + scoring = await async_client.scoring.score( + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) + assert_matches_type(ScoreResponse, scoring, path=["response"]) + + @parametrize + async def test_method_score_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + scoring = await async_client.scoring.score( + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(ScoreResponse, scoring, path=["response"]) + + @parametrize + async def test_raw_response_score(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.scoring.with_raw_response.score( + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + scoring = await response.parse() + assert_matches_type(ScoreResponse, scoring, path=["response"]) + + @parametrize + async def test_streaming_response_score(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.scoring.with_streaming_response.score( + input_rows=[{"foo": True}, {"foo": True}, {"foo": True}], + scoring_functions=["string", "string", "string"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + scoring = await response.parse() + assert_matches_type(ScoreResponse, scoring, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_score_batch(self, async_client: AsyncLlamaStackClient) -> None: + scoring = await async_client.scoring.score_batch( + dataset_id="dataset_id", + save_results_dataset=True, + scoring_functions=["string", "string", "string"], + ) + assert_matches_type(ScoreBatchResponse, scoring, path=["response"]) + + @parametrize + async def test_method_score_batch_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + scoring = await async_client.scoring.score_batch( + dataset_id="dataset_id", + save_results_dataset=True, + scoring_functions=["string", "string", "string"], + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(ScoreBatchResponse, scoring, path=["response"]) + + @parametrize + async def test_raw_response_score_batch(self, 
async_client: AsyncLlamaStackClient) -> None: + response = await async_client.scoring.with_raw_response.score_batch( + dataset_id="dataset_id", + save_results_dataset=True, + scoring_functions=["string", "string", "string"], + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + scoring = await response.parse() + assert_matches_type(ScoreBatchResponse, scoring, path=["response"]) + + @parametrize + async def test_streaming_response_score_batch(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.scoring.with_streaming_response.score_batch( + dataset_id="dataset_id", + save_results_dataset=True, + scoring_functions=["string", "string", "string"], + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + scoring = await response.parse() + assert_matches_type(ScoreBatchResponse, scoring, path=["response"]) + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_scoring_functions.py b/tests/api_resources/test_scoring_functions.py new file mode 100644 index 00000000..9de579f8 --- /dev/null +++ b/tests/api_resources/test_scoring_functions.py @@ -0,0 +1,438 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. + +from __future__ import annotations + +import os +from typing import Any, Optional, cast + +import pytest + +from tests.utils import assert_matches_type +from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient +from llama_stack_client.types import ( + ScoringFnDefWithProvider, +) + +base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") + + +class TestScoringFunctions: + parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + def test_method_retrieve(self, client: LlamaStackClient) -> None: + scoring_function = client.scoring_functions.retrieve( + name="name", + ) + assert_matches_type(Optional[ScoringFnDefWithProvider], scoring_function, path=["response"]) + + @parametrize + def test_method_retrieve_with_all_params(self, client: LlamaStackClient) -> None: + scoring_function = client.scoring_functions.retrieve( + name="name", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[ScoringFnDefWithProvider], scoring_function, path=["response"]) + + @parametrize + def test_raw_response_retrieve(self, client: LlamaStackClient) -> None: + response = client.scoring_functions.with_raw_response.retrieve( + name="name", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + scoring_function = response.parse() + assert_matches_type(Optional[ScoringFnDefWithProvider], scoring_function, path=["response"]) + + @parametrize + def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None: + with client.scoring_functions.with_streaming_response.retrieve( + name="name", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + scoring_function = response.parse() + assert_matches_type(Optional[ScoringFnDefWithProvider], scoring_function, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + def test_method_list(self, client: 
LlamaStackClient) -> None: + scoring_function = client.scoring_functions.list() + assert_matches_type(ScoringFnDefWithProvider, scoring_function, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + def test_method_list_with_all_params(self, client: LlamaStackClient) -> None: + scoring_function = client.scoring_functions.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(ScoringFnDefWithProvider, scoring_function, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + def test_raw_response_list(self, client: LlamaStackClient) -> None: + response = client.scoring_functions.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + scoring_function = response.parse() + assert_matches_type(ScoringFnDefWithProvider, scoring_function, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + def test_streaming_response_list(self, client: LlamaStackClient) -> None: + with client.scoring_functions.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + scoring_function = response.parse() + assert_matches_type(ScoringFnDefWithProvider, scoring_function, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + def test_method_register(self, client: LlamaStackClient) -> None: + scoring_function = client.scoring_functions.register( + function_def={ + "identifier": "identifier", + "metadata": {"foo": True}, + "parameters": [ + { + "name": "name", + "type": {"type": "string"}, + }, + { + "name": "name", + "type": {"type": "string"}, + }, + { + "name": "name", + "type": {"type": "string"}, + }, + ], + "provider_id": "provider_id", + "return_type": {"type": "string"}, + }, + ) + assert scoring_function is None + + @parametrize + def test_method_register_with_all_params(self, client: LlamaStackClient) -> None: + scoring_function = client.scoring_functions.register( + function_def={ + "identifier": "identifier", + "metadata": {"foo": True}, + "parameters": [ + { + "name": "name", + "type": {"type": "string"}, + "description": "description", + }, + { + "name": "name", + "type": {"type": "string"}, + "description": "description", + }, + { + "name": "name", + "type": {"type": "string"}, + "description": "description", + }, + ], + "provider_id": "provider_id", + "return_type": {"type": "string"}, + "context": { + "judge_model": "judge_model", + "judge_score_regex": ["string", "string", "string"], + "prompt_template": "prompt_template", + }, + "description": "description", + }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert scoring_function is None + + @parametrize + def test_raw_response_register(self, client: LlamaStackClient) -> None: + response = client.scoring_functions.with_raw_response.register( + function_def={ + "identifier": "identifier", + "metadata": {"foo": True}, + "parameters": [ + { + "name": "name", + "type": {"type": "string"}, + }, + { + "name": "name", + "type": {"type": "string"}, + }, + { + "name": "name", + "type": {"type": "string"}, + }, + ], + 
"provider_id": "provider_id", + "return_type": {"type": "string"}, + }, + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + scoring_function = response.parse() + assert scoring_function is None + + @parametrize + def test_streaming_response_register(self, client: LlamaStackClient) -> None: + with client.scoring_functions.with_streaming_response.register( + function_def={ + "identifier": "identifier", + "metadata": {"foo": True}, + "parameters": [ + { + "name": "name", + "type": {"type": "string"}, + }, + { + "name": "name", + "type": {"type": "string"}, + }, + { + "name": "name", + "type": {"type": "string"}, + }, + ], + "provider_id": "provider_id", + "return_type": {"type": "string"}, + }, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + scoring_function = response.parse() + assert scoring_function is None + + assert cast(Any, response.is_closed) is True + + +class TestAsyncScoringFunctions: + parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + + @parametrize + async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + scoring_function = await async_client.scoring_functions.retrieve( + name="name", + ) + assert_matches_type(Optional[ScoringFnDefWithProvider], scoring_function, path=["response"]) + + @parametrize + async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + scoring_function = await async_client.scoring_functions.retrieve( + name="name", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[ScoringFnDefWithProvider], scoring_function, path=["response"]) + + @parametrize + async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.scoring_functions.with_raw_response.retrieve( + name="name", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + scoring_function = await response.parse() + assert_matches_type(Optional[ScoringFnDefWithProvider], scoring_function, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.scoring_functions.with_streaming_response.retrieve( + name="name", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + scoring_function = await response.parse() + assert_matches_type(Optional[ScoringFnDefWithProvider], scoring_function, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + async def test_method_list(self, async_client: AsyncLlamaStackClient) -> None: + scoring_function = await async_client.scoring_functions.list() + assert_matches_type(ScoringFnDefWithProvider, scoring_function, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + async def test_method_list_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + scoring_function = await async_client.scoring_functions.list( + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + 
) + assert_matches_type(ScoringFnDefWithProvider, scoring_function, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.scoring_functions.with_raw_response.list() + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + scoring_function = await response.parse() + assert_matches_type(ScoringFnDefWithProvider, scoring_function, path=["response"]) + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) + @parametrize + async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.scoring_functions.with_streaming_response.list() as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + scoring_function = await response.parse() + assert_matches_type(ScoringFnDefWithProvider, scoring_function, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @parametrize + async def test_method_register(self, async_client: AsyncLlamaStackClient) -> None: + scoring_function = await async_client.scoring_functions.register( + function_def={ + "identifier": "identifier", + "metadata": {"foo": True}, + "parameters": [ + { + "name": "name", + "type": {"type": "string"}, + }, + { + "name": "name", + "type": {"type": "string"}, + }, + { + "name": "name", + "type": {"type": "string"}, + }, + ], + "provider_id": "provider_id", + "return_type": {"type": "string"}, + }, + ) + assert scoring_function is None + + @parametrize + async def test_method_register_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + scoring_function = await async_client.scoring_functions.register( + function_def={ + "identifier": "identifier", + "metadata": {"foo": True}, + "parameters": [ + { + "name": "name", + "type": {"type": "string"}, + "description": "description", + }, + { + "name": "name", + "type": {"type": "string"}, + "description": "description", + }, + { + "name": "name", + "type": {"type": "string"}, + "description": "description", + }, + ], + "provider_id": "provider_id", + "return_type": {"type": "string"}, + "context": { + "judge_model": "judge_model", + "judge_score_regex": ["string", "string", "string"], + "prompt_template": "prompt_template", + }, + "description": "description", + }, + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert scoring_function is None + + @parametrize + async def test_raw_response_register(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.scoring_functions.with_raw_response.register( + function_def={ + "identifier": "identifier", + "metadata": {"foo": True}, + "parameters": [ + { + "name": "name", + "type": {"type": "string"}, + }, + { + "name": "name", + "type": {"type": "string"}, + }, + { + "name": "name", + "type": {"type": "string"}, + }, + ], + "provider_id": "provider_id", + "return_type": {"type": "string"}, + }, + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + scoring_function = await response.parse() + assert scoring_function is None + + @parametrize + async def test_streaming_response_register(self, async_client: 
AsyncLlamaStackClient) -> None: + async with async_client.scoring_functions.with_streaming_response.register( + function_def={ + "identifier": "identifier", + "metadata": {"foo": True}, + "parameters": [ + { + "name": "name", + "type": {"type": "string"}, + }, + { + "name": "name", + "type": {"type": "string"}, + }, + { + "name": "name", + "type": {"type": "string"}, + }, + ], + "provider_id": "provider_id", + "return_type": {"type": "string"}, + }, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + scoring_function = await response.parse() + assert scoring_function is None + + assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_shields.py b/tests/api_resources/test_shields.py index d8f75254..2ef574a4 100644 --- a/tests/api_resources/test_shields.py +++ b/tests/api_resources/test_shields.py @@ -9,7 +9,7 @@ from tests.utils import assert_matches_type from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types import ShieldSpec +from llama_stack_client.types import ShieldDefWithProvider base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -17,18 +17,66 @@ class TestShields: parametrize = pytest.mark.parametrize("client", [False, True], indirect=True, ids=["loose", "strict"]) + @parametrize + def test_method_retrieve(self, client: LlamaStackClient) -> None: + shield = client.shields.retrieve( + shield_type="shield_type", + ) + assert_matches_type(Optional[ShieldDefWithProvider], shield, path=["response"]) + + @parametrize + def test_method_retrieve_with_all_params(self, client: LlamaStackClient) -> None: + shield = client.shields.retrieve( + shield_type="shield_type", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[ShieldDefWithProvider], shield, path=["response"]) + + @parametrize + def test_raw_response_retrieve(self, client: LlamaStackClient) -> None: + response = client.shields.with_raw_response.retrieve( + shield_type="shield_type", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + shield = response.parse() + assert_matches_type(Optional[ShieldDefWithProvider], shield, path=["response"]) + + @parametrize + def test_streaming_response_retrieve(self, client: LlamaStackClient) -> None: + with client.shields.with_streaming_response.retrieve( + shield_type="shield_type", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + shield = response.parse() + assert_matches_type(Optional[ShieldDefWithProvider], shield, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_method_list(self, client: LlamaStackClient) -> None: shield = client.shields.list() - assert_matches_type(ShieldSpec, shield, path=["response"]) + assert_matches_type(ShieldDefWithProvider, shield, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_method_list_with_all_params(self, client: LlamaStackClient) -> None: shield = client.shields.list( x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(ShieldSpec, shield, path=["response"]) + 
assert_matches_type(ShieldDefWithProvider, shield, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_raw_response_list(self, client: LlamaStackClient) -> None: response = client.shields.with_raw_response.list() @@ -36,8 +84,11 @@ def test_raw_response_list(self, client: LlamaStackClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" shield = response.parse() - assert_matches_type(ShieldSpec, shield, path=["response"]) + assert_matches_type(ShieldDefWithProvider, shield, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize def test_streaming_response_list(self, client: LlamaStackClient) -> None: with client.shields.with_streaming_response.list() as response: @@ -45,46 +96,66 @@ def test_streaming_response_list(self, client: LlamaStackClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" shield = response.parse() - assert_matches_type(ShieldSpec, shield, path=["response"]) + assert_matches_type(ShieldDefWithProvider, shield, path=["response"]) assert cast(Any, response.is_closed) is True @parametrize - def test_method_get(self, client: LlamaStackClient) -> None: - shield = client.shields.get( - shield_type="shield_type", + def test_method_register(self, client: LlamaStackClient) -> None: + shield = client.shields.register( + shield={ + "identifier": "identifier", + "params": {"foo": True}, + "provider_id": "provider_id", + "type": "type", + }, ) - assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + assert shield is None @parametrize - def test_method_get_with_all_params(self, client: LlamaStackClient) -> None: - shield = client.shields.get( - shield_type="shield_type", + def test_method_register_with_all_params(self, client: LlamaStackClient) -> None: + shield = client.shields.register( + shield={ + "identifier": "identifier", + "params": {"foo": True}, + "provider_id": "provider_id", + "type": "type", + }, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + assert shield is None @parametrize - def test_raw_response_get(self, client: LlamaStackClient) -> None: - response = client.shields.with_raw_response.get( - shield_type="shield_type", + def test_raw_response_register(self, client: LlamaStackClient) -> None: + response = client.shields.with_raw_response.register( + shield={ + "identifier": "identifier", + "params": {"foo": True}, + "provider_id": "provider_id", + "type": "type", + }, ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" shield = response.parse() - assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + assert shield is None @parametrize - def test_streaming_response_get(self, client: LlamaStackClient) -> None: - with client.shields.with_streaming_response.get( - shield_type="shield_type", + def test_streaming_response_register(self, client: LlamaStackClient) -> None: + with client.shields.with_streaming_response.register( + shield={ + "identifier": "identifier", + "params": {"foo": True}, + "provider_id": "provider_id", + "type": "type", + }, ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == 
"python" shield = response.parse() - assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + assert shield is None assert cast(Any, response.is_closed) is True @@ -92,18 +163,66 @@ def test_streaming_response_get(self, client: LlamaStackClient) -> None: class TestAsyncShields: parametrize = pytest.mark.parametrize("async_client", [False, True], indirect=True, ids=["loose", "strict"]) + @parametrize + async def test_method_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + shield = await async_client.shields.retrieve( + shield_type="shield_type", + ) + assert_matches_type(Optional[ShieldDefWithProvider], shield, path=["response"]) + + @parametrize + async def test_method_retrieve_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + shield = await async_client.shields.retrieve( + shield_type="shield_type", + x_llama_stack_provider_data="X-LlamaStack-ProviderData", + ) + assert_matches_type(Optional[ShieldDefWithProvider], shield, path=["response"]) + + @parametrize + async def test_raw_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.shields.with_raw_response.retrieve( + shield_type="shield_type", + ) + + assert response.is_closed is True + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + shield = await response.parse() + assert_matches_type(Optional[ShieldDefWithProvider], shield, path=["response"]) + + @parametrize + async def test_streaming_response_retrieve(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.shields.with_streaming_response.retrieve( + shield_type="shield_type", + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + shield = await response.parse() + assert_matches_type(Optional[ShieldDefWithProvider], shield, path=["response"]) + + assert cast(Any, response.is_closed) is True + + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_method_list(self, async_client: AsyncLlamaStackClient) -> None: shield = await async_client.shields.list() - assert_matches_type(ShieldSpec, shield, path=["response"]) + assert_matches_type(ShieldDefWithProvider, shield, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_method_list_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: shield = await async_client.shields.list( x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(ShieldSpec, shield, path=["response"]) + assert_matches_type(ShieldDefWithProvider, shield, path=["response"]) + @pytest.mark.skip( + reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> None: response = await async_client.shields.with_raw_response.list() @@ -111,8 +230,11 @@ async def test_raw_response_list(self, async_client: AsyncLlamaStackClient) -> N assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" shield = await response.parse() - assert_matches_type(ShieldSpec, shield, path=["response"]) + assert_matches_type(ShieldDefWithProvider, shield, path=["response"]) + @pytest.mark.skip( + 
reason="currently no good way to test endpoints with content type application/jsonl, Prism mock server will fail" + ) @parametrize async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient) -> None: async with async_client.shields.with_streaming_response.list() as response: @@ -120,45 +242,65 @@ async def test_streaming_response_list(self, async_client: AsyncLlamaStackClient assert response.http_request.headers.get("X-Stainless-Lang") == "python" shield = await response.parse() - assert_matches_type(ShieldSpec, shield, path=["response"]) + assert_matches_type(ShieldDefWithProvider, shield, path=["response"]) assert cast(Any, response.is_closed) is True @parametrize - async def test_method_get(self, async_client: AsyncLlamaStackClient) -> None: - shield = await async_client.shields.get( - shield_type="shield_type", + async def test_method_register(self, async_client: AsyncLlamaStackClient) -> None: + shield = await async_client.shields.register( + shield={ + "identifier": "identifier", + "params": {"foo": True}, + "provider_id": "provider_id", + "type": "type", + }, ) - assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + assert shield is None @parametrize - async def test_method_get_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - shield = await async_client.shields.get( - shield_type="shield_type", + async def test_method_register_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + shield = await async_client.shields.register( + shield={ + "identifier": "identifier", + "params": {"foo": True}, + "provider_id": "provider_id", + "type": "type", + }, x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + assert shield is None @parametrize - async def test_raw_response_get(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.shields.with_raw_response.get( - shield_type="shield_type", + async def test_raw_response_register(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.shields.with_raw_response.register( + shield={ + "identifier": "identifier", + "params": {"foo": True}, + "provider_id": "provider_id", + "type": "type", + }, ) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" shield = await response.parse() - assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + assert shield is None @parametrize - async def test_streaming_response_get(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.shields.with_streaming_response.get( - shield_type="shield_type", + async def test_streaming_response_register(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.shields.with_streaming_response.register( + shield={ + "identifier": "identifier", + "params": {"foo": True}, + "provider_id": "provider_id", + "type": "type", + }, ) as response: assert not response.is_closed assert response.http_request.headers.get("X-Stainless-Lang") == "python" shield = await response.parse() - assert_matches_type(Optional[ShieldSpec], shield, path=["response"]) + assert shield is None assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_synthetic_data_generation.py b/tests/api_resources/test_synthetic_data_generation.py index 04cc5326..1343e693 100644 --- a/tests/api_resources/test_synthetic_data_generation.py +++ 
b/tests/api_resources/test_synthetic_data_generation.py @@ -9,7 +9,7 @@ from tests.utils import assert_matches_type from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types import SyntheticDataGeneration +from llama_stack_client.types import SyntheticDataGenerationResponse base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -36,7 +36,7 @@ def test_method_generate(self, client: LlamaStackClient) -> None: ], filtering_function="none", ) - assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"]) + assert_matches_type(SyntheticDataGenerationResponse, synthetic_data_generation, path=["response"]) @parametrize def test_method_generate_with_all_params(self, client: LlamaStackClient) -> None: @@ -62,7 +62,7 @@ def test_method_generate_with_all_params(self, client: LlamaStackClient) -> None model="model", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"]) + assert_matches_type(SyntheticDataGenerationResponse, synthetic_data_generation, path=["response"]) @parametrize def test_raw_response_generate(self, client: LlamaStackClient) -> None: @@ -87,7 +87,7 @@ def test_raw_response_generate(self, client: LlamaStackClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" synthetic_data_generation = response.parse() - assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"]) + assert_matches_type(SyntheticDataGenerationResponse, synthetic_data_generation, path=["response"]) @parametrize def test_streaming_response_generate(self, client: LlamaStackClient) -> None: @@ -112,7 +112,7 @@ def test_streaming_response_generate(self, client: LlamaStackClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" synthetic_data_generation = response.parse() - assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"]) + assert_matches_type(SyntheticDataGenerationResponse, synthetic_data_generation, path=["response"]) assert cast(Any, response.is_closed) is True @@ -139,7 +139,7 @@ async def test_method_generate(self, async_client: AsyncLlamaStackClient) -> Non ], filtering_function="none", ) - assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"]) + assert_matches_type(SyntheticDataGenerationResponse, synthetic_data_generation, path=["response"]) @parametrize async def test_method_generate_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: @@ -165,7 +165,7 @@ async def test_method_generate_with_all_params(self, async_client: AsyncLlamaSta model="model", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"]) + assert_matches_type(SyntheticDataGenerationResponse, synthetic_data_generation, path=["response"]) @parametrize async def test_raw_response_generate(self, async_client: AsyncLlamaStackClient) -> None: @@ -190,7 +190,7 @@ async def test_raw_response_generate(self, async_client: AsyncLlamaStackClient) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" synthetic_data_generation = await response.parse() - assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"]) + assert_matches_type(SyntheticDataGenerationResponse, 
synthetic_data_generation, path=["response"]) @parametrize async def test_streaming_response_generate(self, async_client: AsyncLlamaStackClient) -> None: @@ -215,6 +215,6 @@ async def test_streaming_response_generate(self, async_client: AsyncLlamaStackCl assert response.http_request.headers.get("X-Stainless-Lang") == "python" synthetic_data_generation = await response.parse() - assert_matches_type(SyntheticDataGeneration, synthetic_data_generation, path=["response"]) + assert_matches_type(SyntheticDataGenerationResponse, synthetic_data_generation, path=["response"]) assert cast(Any, response.is_closed) is True diff --git a/tests/api_resources/test_telemetry.py b/tests/api_resources/test_telemetry.py index 83c4608d..f581e066 100644 --- a/tests/api_resources/test_telemetry.py +++ b/tests/api_resources/test_telemetry.py @@ -9,7 +9,7 @@ from tests.utils import assert_matches_type from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient -from llama_stack_client.types import TelemetryGetTraceResponse +from llama_stack_client.types import Trace from llama_stack_client._utils import parse_datetime base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") @@ -23,7 +23,7 @@ def test_method_get_trace(self, client: LlamaStackClient) -> None: telemetry = client.telemetry.get_trace( trace_id="trace_id", ) - assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + assert_matches_type(Trace, telemetry, path=["response"]) @parametrize def test_method_get_trace_with_all_params(self, client: LlamaStackClient) -> None: @@ -31,7 +31,7 @@ def test_method_get_trace_with_all_params(self, client: LlamaStackClient) -> Non trace_id="trace_id", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + assert_matches_type(Trace, telemetry, path=["response"]) @parametrize def test_raw_response_get_trace(self, client: LlamaStackClient) -> None: @@ -42,7 +42,7 @@ def test_raw_response_get_trace(self, client: LlamaStackClient) -> None: assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" telemetry = response.parse() - assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + assert_matches_type(Trace, telemetry, path=["response"]) @parametrize def test_streaming_response_get_trace(self, client: LlamaStackClient) -> None: @@ -53,13 +53,13 @@ def test_streaming_response_get_trace(self, client: LlamaStackClient) -> None: assert response.http_request.headers.get("X-Stainless-Lang") == "python" telemetry = response.parse() - assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + assert_matches_type(Trace, telemetry, path=["response"]) assert cast(Any, response.is_closed) is True @parametrize - def test_method_log(self, client: LlamaStackClient) -> None: - telemetry = client.telemetry.log( + def test_method_log_event(self, client: LlamaStackClient) -> None: + telemetry = client.telemetry.log_event( event={ "message": "message", "severity": "verbose", @@ -72,8 +72,8 @@ def test_method_log(self, client: LlamaStackClient) -> None: assert telemetry is None @parametrize - def test_method_log_with_all_params(self, client: LlamaStackClient) -> None: - telemetry = client.telemetry.log( + def test_method_log_event_with_all_params(self, client: LlamaStackClient) -> None: + telemetry = client.telemetry.log_event( event={ "message": "message", "severity": "verbose", @@ -88,8 +88,8 @@ def 
test_method_log_with_all_params(self, client: LlamaStackClient) -> None: assert telemetry is None @parametrize - def test_raw_response_log(self, client: LlamaStackClient) -> None: - response = client.telemetry.with_raw_response.log( + def test_raw_response_log_event(self, client: LlamaStackClient) -> None: + response = client.telemetry.with_raw_response.log_event( event={ "message": "message", "severity": "verbose", @@ -106,8 +106,8 @@ def test_raw_response_log(self, client: LlamaStackClient) -> None: assert telemetry is None @parametrize - def test_streaming_response_log(self, client: LlamaStackClient) -> None: - with client.telemetry.with_streaming_response.log( + def test_streaming_response_log_event(self, client: LlamaStackClient) -> None: + with client.telemetry.with_streaming_response.log_event( event={ "message": "message", "severity": "verbose", @@ -134,7 +134,7 @@ async def test_method_get_trace(self, async_client: AsyncLlamaStackClient) -> No telemetry = await async_client.telemetry.get_trace( trace_id="trace_id", ) - assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + assert_matches_type(Trace, telemetry, path=["response"]) @parametrize async def test_method_get_trace_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: @@ -142,7 +142,7 @@ async def test_method_get_trace_with_all_params(self, async_client: AsyncLlamaSt trace_id="trace_id", x_llama_stack_provider_data="X-LlamaStack-ProviderData", ) - assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + assert_matches_type(Trace, telemetry, path=["response"]) @parametrize async def test_raw_response_get_trace(self, async_client: AsyncLlamaStackClient) -> None: @@ -153,7 +153,7 @@ async def test_raw_response_get_trace(self, async_client: AsyncLlamaStackClient) assert response.is_closed is True assert response.http_request.headers.get("X-Stainless-Lang") == "python" telemetry = await response.parse() - assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + assert_matches_type(Trace, telemetry, path=["response"]) @parametrize async def test_streaming_response_get_trace(self, async_client: AsyncLlamaStackClient) -> None: @@ -164,13 +164,13 @@ async def test_streaming_response_get_trace(self, async_client: AsyncLlamaStackC assert response.http_request.headers.get("X-Stainless-Lang") == "python" telemetry = await response.parse() - assert_matches_type(TelemetryGetTraceResponse, telemetry, path=["response"]) + assert_matches_type(Trace, telemetry, path=["response"]) assert cast(Any, response.is_closed) is True @parametrize - async def test_method_log(self, async_client: AsyncLlamaStackClient) -> None: - telemetry = await async_client.telemetry.log( + async def test_method_log_event(self, async_client: AsyncLlamaStackClient) -> None: + telemetry = await async_client.telemetry.log_event( event={ "message": "message", "severity": "verbose", @@ -183,8 +183,8 @@ async def test_method_log(self, async_client: AsyncLlamaStackClient) -> None: assert telemetry is None @parametrize - async def test_method_log_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: - telemetry = await async_client.telemetry.log( + async def test_method_log_event_with_all_params(self, async_client: AsyncLlamaStackClient) -> None: + telemetry = await async_client.telemetry.log_event( event={ "message": "message", "severity": "verbose", @@ -199,8 +199,8 @@ async def test_method_log_with_all_params(self, async_client: AsyncLlamaStackCli assert telemetry is None 
@parametrize - async def test_raw_response_log(self, async_client: AsyncLlamaStackClient) -> None: - response = await async_client.telemetry.with_raw_response.log( + async def test_raw_response_log_event(self, async_client: AsyncLlamaStackClient) -> None: + response = await async_client.telemetry.with_raw_response.log_event( event={ "message": "message", "severity": "verbose", @@ -217,8 +217,8 @@ async def test_raw_response_log(self, async_client: AsyncLlamaStackClient) -> No assert telemetry is None @parametrize - async def test_streaming_response_log(self, async_client: AsyncLlamaStackClient) -> None: - async with async_client.telemetry.with_streaming_response.log( + async def test_streaming_response_log_event(self, async_client: AsyncLlamaStackClient) -> None: + async with async_client.telemetry.with_streaming_response.log_event( event={ "message": "message", "severity": "verbose", diff --git a/tests/conftest.py b/tests/conftest.py index 61d94b66..645cbf63 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,11 +1,11 @@ from __future__ import annotations import os -import asyncio import logging from typing import TYPE_CHECKING, Iterator, AsyncIterator import pytest +from pytest_asyncio import is_async_test from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient @@ -17,11 +17,13 @@ logging.getLogger("llama_stack_client").setLevel(logging.DEBUG) -@pytest.fixture(scope="session") -def event_loop() -> Iterator[asyncio.AbstractEventLoop]: - loop = asyncio.new_event_loop() - yield loop - loop.close() +# automatically add `pytest.mark.asyncio()` to all of our async tests +# so we don't have to add that boilerplate everywhere +def pytest_collection_modifyitems(items: list[pytest.Function]) -> None: + pytest_asyncio_tests = (item for item in items if is_async_test(item)) + session_scope_marker = pytest.mark.asyncio(loop_scope="session") + for async_test in pytest_asyncio_tests: + async_test.add_marker(session_scope_marker, append=False) base_url = os.environ.get("TEST_API_BASE_URL", "http://127.0.0.1:4010") diff --git a/tests/test_client.py b/tests/test_client.py index a080020b..ae930b0e 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -10,6 +10,7 @@ import tracemalloc from typing import Any, Union, cast from unittest import mock +from typing_extensions import Literal import httpx import pytest @@ -17,6 +18,7 @@ from pydantic import ValidationError from llama_stack_client import LlamaStackClient, AsyncLlamaStackClient, APIResponseValidationError +from llama_stack_client._types import Omit from llama_stack_client._models import BaseModel, FinalRequestOptions from llama_stack_client._constants import RAW_RESPONSE_HEADER from llama_stack_client._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError @@ -519,14 +521,6 @@ def test_base_url_env(self) -> None: client = LlamaStackClient(_strict_response_validation=True) assert client.base_url == "http://localhost:5000/from/env/" - # explicit environment arg requires explicitness - with update_env(LLAMA_STACK_CLIENT_BASE_URL="http://localhost:5000/from/env"): - with pytest.raises(ValueError, match=r"you must pass base_url=None"): - LlamaStackClient(_strict_response_validation=True, environment="production") - - client = LlamaStackClient(base_url=None, _strict_response_validation=True, environment="production") - assert str(client.base_url).startswith("http://any-hosted-llama-stack-client.com") - @pytest.mark.parametrize( "client", [ @@ -663,6 +657,7 @@ class Model(BaseModel): [3, "", 0.5], [2, 
"", 0.5 * 2.0], [1, "", 0.5 * 4.0], + [-1100, "", 7.8], # test large number potentially overflowing ], ) @mock.patch("time.time", mock.MagicMock(return_value=1696004797)) @@ -677,12 +672,31 @@ def test_parse_retry_after_header(self, remaining_retries: int, retry_after: str @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: - respx_mock.post("/agents/session/create").mock(side_effect=httpx.TimeoutException("Test timeout error")) + respx_mock.post("/inference/chat_completion").mock(side_effect=httpx.TimeoutException("Test timeout error")) with pytest.raises(APITimeoutError): self.client.post( - "/agents/session/create", - body=cast(object, dict(agent_id="agent_id", session_name="session_name")), + "/inference/chat_completion", + body=cast( + object, + dict( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + ), + ), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -692,12 +706,31 @@ def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> No @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: - respx_mock.post("/agents/session/create").mock(return_value=httpx.Response(500)) + respx_mock.post("/inference/chat_completion").mock(return_value=httpx.Response(500)) with pytest.raises(APIStatusError): self.client.post( - "/agents/session/create", - body=cast(object, dict(agent_id="agent_id", session_name="session_name")), + "/inference/chat_completion", + body=cast( + object, + dict( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + ), + ), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -707,8 +740,13 @@ def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> Non @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) + @pytest.mark.parametrize("failure_mode", ["status", "exception"]) def test_retries_taken( - self, client: LlamaStackClient, failures_before_success: int, respx_mock: MockRouter + self, + client: LlamaStackClient, + failures_before_success: int, + failure_mode: Literal["status", "exception"], + respx_mock: MockRouter, ) -> None: client = client.with_options(max_retries=4) @@ -718,16 +756,114 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: nonlocal nb_retries if nb_retries < failures_before_success: nb_retries += 1 + if failure_mode == "exception": + raise RuntimeError("oops") return httpx.Response(500) return httpx.Response(200) - respx_mock.post("/agents/session/create").mock(side_effect=retry_handler) - - response = client.agents.sessions.with_raw_response.create(agent_id="agent_id", session_name="session_name") + respx_mock.post("/inference/chat_completion").mock(side_effect=retry_handler) + + response = client.inference.with_raw_response.chat_completion( + messages=[ + { + "content": "string", + 
"role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + ) assert response.retries_taken == failures_before_success assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_omit_retry_count_header( + self, client: LlamaStackClient, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/inference/chat_completion").mock(side_effect=retry_handler) + + response = client.inference.with_raw_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + extra_headers={"x-stainless-retry-count": Omit()}, + ) + + assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + def test_overwrite_retry_count_header( + self, client: LlamaStackClient, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/inference/chat_completion").mock(side_effect=retry_handler) + + response = client.inference.with_raw_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + extra_headers={"x-stainless-retry-count": "42"}, + ) + + assert response.http_request.headers.get("x-stainless-retry-count") == "42" + class TestAsyncLlamaStackClient: client = AsyncLlamaStackClient(base_url=base_url, _strict_response_validation=True) @@ -1206,14 +1342,6 @@ def test_base_url_env(self) -> None: client = AsyncLlamaStackClient(_strict_response_validation=True) assert client.base_url == "http://localhost:5000/from/env/" - # explicit environment arg requires explicitness - with update_env(LLAMA_STACK_CLIENT_BASE_URL="http://localhost:5000/from/env"): - with pytest.raises(ValueError, match=r"you must pass base_url=None"): - AsyncLlamaStackClient(_strict_response_validation=True, environment="production") - - client = AsyncLlamaStackClient(base_url=None, _strict_response_validation=True, environment="production") - assert str(client.base_url).startswith("http://any-hosted-llama-stack-client.com") - @pytest.mark.parametrize( "client", [ @@ -1353,6 +1481,7 @@ class Model(BaseModel): [3, "", 0.5], [2, "", 0.5 * 2.0], [1, "", 0.5 * 4.0], + [-1100, "", 7.8], # test large number potentially overflowing ], ) @mock.patch("time.time", 
mock.MagicMock(return_value=1696004797)) @@ -1368,12 +1497,31 @@ async def test_parse_retry_after_header(self, remaining_retries: int, retry_afte @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: - respx_mock.post("/agents/session/create").mock(side_effect=httpx.TimeoutException("Test timeout error")) + respx_mock.post("/inference/chat_completion").mock(side_effect=httpx.TimeoutException("Test timeout error")) with pytest.raises(APITimeoutError): await self.client.post( - "/agents/session/create", - body=cast(object, dict(agent_id="agent_id", session_name="session_name")), + "/inference/chat_completion", + body=cast( + object, + dict( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + ), + ), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -1383,12 +1531,31 @@ async def test_retrying_timeout_errors_doesnt_leak(self, respx_mock: MockRouter) @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) -> None: - respx_mock.post("/agents/session/create").mock(return_value=httpx.Response(500)) + respx_mock.post("/inference/chat_completion").mock(return_value=httpx.Response(500)) with pytest.raises(APIStatusError): await self.client.post( - "/agents/session/create", - body=cast(object, dict(agent_id="agent_id", session_name="session_name")), + "/inference/chat_completion", + body=cast( + object, + dict( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + ), + ), cast_to=httpx.Response, options={"headers": {RAW_RESPONSE_HEADER: "stream"}}, ) @@ -1399,8 +1566,13 @@ async def test_retrying_status_errors_doesnt_leak(self, respx_mock: MockRouter) @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio + @pytest.mark.parametrize("failure_mode", ["status", "exception"]) async def test_retries_taken( - self, async_client: AsyncLlamaStackClient, failures_before_success: int, respx_mock: MockRouter + self, + async_client: AsyncLlamaStackClient, + failures_before_success: int, + failure_mode: Literal["status", "exception"], + respx_mock: MockRouter, ) -> None: client = async_client.with_options(max_retries=4) @@ -1410,14 +1582,112 @@ def retry_handler(_request: httpx.Request) -> httpx.Response: nonlocal nb_retries if nb_retries < failures_before_success: nb_retries += 1 + if failure_mode == "exception": + raise RuntimeError("oops") return httpx.Response(500) return httpx.Response(200) - respx_mock.post("/agents/session/create").mock(side_effect=retry_handler) - - response = await client.agents.sessions.with_raw_response.create( - agent_id="agent_id", session_name="session_name" + respx_mock.post("/inference/chat_completion").mock(side_effect=retry_handler) + + response = await client.inference.with_raw_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + 
"role": "user", + }, + ], + model="model", ) assert response.retries_taken == failures_before_success assert int(response.http_request.headers.get("x-stainless-retry-count")) == failures_before_success + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + async def test_omit_retry_count_header( + self, async_client: AsyncLlamaStackClient, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = async_client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/inference/chat_completion").mock(side_effect=retry_handler) + + response = await client.inference.with_raw_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + extra_headers={"x-stainless-retry-count": Omit()}, + ) + + assert len(response.http_request.headers.get_list("x-stainless-retry-count")) == 0 + + @pytest.mark.parametrize("failures_before_success", [0, 2, 4]) + @mock.patch("llama_stack_client._base_client.BaseClient._calculate_retry_timeout", _low_retry_timeout) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + async def test_overwrite_retry_count_header( + self, async_client: AsyncLlamaStackClient, failures_before_success: int, respx_mock: MockRouter + ) -> None: + client = async_client.with_options(max_retries=4) + + nb_retries = 0 + + def retry_handler(_request: httpx.Request) -> httpx.Response: + nonlocal nb_retries + if nb_retries < failures_before_success: + nb_retries += 1 + return httpx.Response(500) + return httpx.Response(200) + + respx_mock.post("/inference/chat_completion").mock(side_effect=retry_handler) + + response = await client.inference.with_raw_response.chat_completion( + messages=[ + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + { + "content": "string", + "role": "user", + }, + ], + model="model", + extra_headers={"x-stainless-retry-count": "42"}, + ) + + assert response.http_request.headers.get("x-stainless-retry-count") == "42" diff --git a/tests/test_models.py b/tests/test_models.py index 6fea1a32..bf3f0e20 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -245,7 +245,7 @@ class Model(BaseModel): assert m.foo is True m = Model.construct(foo="CARD_HOLDER") - assert m.foo is "CARD_HOLDER" + assert m.foo == "CARD_HOLDER" m = Model.construct(foo={"bar": False}) assert isinstance(m.foo, Submodel1) diff --git a/tests/test_response.py b/tests/test_response.py index 4130175d..8c803a25 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -190,6 +190,56 @@ async def test_async_response_parse_annotated_type(async_client: AsyncLlamaStack assert obj.bar == 2 +@pytest.mark.parametrize( + "content, expected", + [ + ("false", False), + ("true", True), + ("False", False), + ("True", True), + ("TrUe", True), + ("FalSe", False), + ], +) +def test_response_parse_bool(client: LlamaStackClient, content: str, expected: bool) -> None: + response = APIResponse( + raw=httpx.Response(200, content=content), + client=client, + stream=False, + stream_cls=None, + 
cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + result = response.parse(to=bool) + assert result is expected + + +@pytest.mark.parametrize( + "content, expected", + [ + ("false", False), + ("true", True), + ("False", False), + ("True", True), + ("TrUe", True), + ("FalSe", False), + ], +) +async def test_async_response_parse_bool(async_client: AsyncLlamaStackClient, content: str, expected: bool) -> None: + response = AsyncAPIResponse( + raw=httpx.Response(200, content=content), + client=async_client, + stream=False, + stream_cls=None, + cast_to=str, + options=FinalRequestOptions.construct(method="get", url="/foo"), + ) + + result = await response.parse(to=bool) + assert result is expected + + + class OtherModel(BaseModel): a: str From 2ce0804f71654f37ca69e98d175c1d281381968f Mon Sep 17 00:00:00 2001 From: Xi Yan Date: Wed, 30 Oct 2024 16:46:10 -0700 Subject: [PATCH 2/2] bump version --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index b9852c4b..8b438158 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "llama_stack_client" -version = "0.0.47" +version = "0.0.48" description = "The official Python library for the llama-stack-client API" dynamic = ["readme"] license = "Apache-2.0"
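
Note: a minimal usage sketch of the new scoring surface exercised by tests/api_resources/test_scoring.py above, assuming a locally running endpoint (the test suite targets a Prism mock at the same address); the row payloads, dataset id, and scoring function ids are the placeholder fixture values from those tests, not real identifiers.

from llama_stack_client import LlamaStackClient

# Placeholder endpoint: the tests above point at a Prism mock on this port.
client = LlamaStackClient(base_url="http://127.0.0.1:4010")

# Score a handful of rows in-line against one or more registered scoring functions.
score_response = client.scoring.score(
    input_rows=[{"foo": True}],
    scoring_functions=["string"],
)

# Score an entire registered dataset, optionally saving the results as a new dataset.
batch_response = client.scoring.score_batch(
    dataset_id="dataset_id",
    save_results_dataset=True,
    scoring_functions=["string"],
)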