|
19 | 19 | import json
|
20 | 20 | import logging
|
21 | 21 | import time
|
22 |
| -from typing import Any, Iterator, Optional, Union |
| 22 | +from typing import Any, AsyncIterator, Iterator, Optional, Union |
23 | 23 | from urllib.parse import urlencode
|
24 | 24 |
|
25 | 25 | from google.genai import _api_module
|
|
28 | 28 | from google.genai import types as genai_types
|
29 | 29 | from google.genai._common import get_value_by_path as getv
|
30 | 30 | from google.genai._common import set_value_by_path as setv
|
31 |
| -from google.genai.pagers import Pager |
| 31 | +from google.genai.pagers import AsyncPager, Pager |
32 | 32 |
|
33 | 33 | from . import _prompt_management_utils
|
34 | 34 | from . import types
|
@@ -2495,3 +2495,112 @@ async def delete_version(
|
2495 | 2495 | logger.info(
|
2496 | 2496 | f"Deleted prompt version {version_id} from prompt with id: {prompt_id}"
|
2497 | 2497 | )
|
| 2498 | + |
| 2499 | + async def _list_prompts_pager( |
| 2500 | + self, |
| 2501 | + *, |
| 2502 | + config: Optional[types.ListPromptsConfigOrDict] = None, |
| 2503 | + ) -> AsyncPager[types.Dataset]: |
| 2504 | + return AsyncPager( |
| 2505 | + "datasets", |
| 2506 | + self._list_prompts, |
| 2507 | + await self._list_prompts(config=config), |
| 2508 | + config, |
| 2509 | + ) |
| 2510 | + |
| 2511 | + async def _list_versions_pager( |
| 2512 | + self, |
| 2513 | + *, |
| 2514 | + prompt_id: str, |
| 2515 | + config: Optional[types.ListPromptsConfigOrDict] = None, |
| 2516 | + ) -> AsyncPager[types.DatasetVersion]: |
| 2517 | + return AsyncPager( |
| 2518 | + "dataset_versions", |
| 2519 | + self._list_versions, |
| 2520 | + await self._list_versions(config=config, dataset_id=prompt_id), |
| 2521 | + config, |
| 2522 | + ) |
| 2523 | + |
| 2524 | + async def list_prompts( |
| 2525 | + self, |
| 2526 | + *, |
| 2527 | + config: Optional[types.ListPromptsConfigOrDict] = None, |
| 2528 | + ) -> AsyncIterator[types.PromptRef]: |
| 2529 | + """Lists prompt resources in a project. |
| 2530 | +
|
| 2531 | + This method retrieves all the prompts from the project provided in the |
| 2532 | + vertexai.Client constructor and returns a list of prompt references containing the prompt_id and model for the prompt. |
| 2533 | +
|
| 2534 | + To get the full types.Prompt resource for a PromptRef after calling this method, use the get() method with the prompt_id as the prompt_id argument. |
| 2535 | + Example usage: |
| 2536 | +
|
| 2537 | + ``` |
| 2538 | + prompt_refs = client.aio.prompt_management.list_prompts() |
| 2539 | + async for prompt_ref in prompt_refs: |
| 2540 | + await client.prompt_management.get(prompt_id=prompt_ref.prompt_id) |
| 2541 | + ``` |
| 2542 | +
|
| 2543 | + Args: |
| 2544 | + config: Optional configuration for listing prompts. |
| 2545 | +
|
| 2546 | + Returns: |
| 2547 | + An async iterator of types.PromptRef objects. |
| 2548 | + """ |
| 2549 | + if isinstance(config, dict): |
| 2550 | + config = types.ListPromptsConfig(**config) |
| 2551 | + elif not config: |
| 2552 | + config = types.ListPromptsConfig() |
| 2553 | + async for dataset in await self._list_prompts_pager(config=config): |
| 2554 | + if not dataset or not dataset.model_reference or not dataset.name: |
| 2555 | + continue |
| 2556 | + prompt_ref = types.PromptRef( |
| 2557 | + model=dataset.model_reference, prompt_id=dataset.name.split("/")[-1] |
| 2558 | + ) |
| 2559 | + yield prompt_ref |
| 2560 | + |
| 2561 | + async def list_versions( |
| 2562 | + self, |
| 2563 | + *, |
| 2564 | + prompt_id: str, |
| 2565 | + config: Optional[types.ListPromptsConfigOrDict] = None, |
| 2566 | + ) -> AsyncIterator[types.PromptVersionRef]: |
| 2567 | + """Lists prompt version resources for a provided prompt_id. |
| 2568 | +
|
| 2569 | + This method retrieves all the prompt versions for a provided prompt_id. |
| 2570 | +
|
| 2571 | + To get the full types.Prompt resource for a PromptVersionRef after calling this method, use the get() method with the returned prompt_id and version_id. |
| 2572 | + Example usage: |
| 2573 | +
|
| 2574 | + ``` |
| 2575 | + prompt_version_refs = await client.prompt_management.list_versions(prompt_id="123") |
| 2576 | + async for version_ref in prompt_version_refs: |
| 2577 | + await client.aio.prompt_management.get(prompt_id=version_ref.prompt_id, version_id=version_ref.version_id) |
| 2578 | + ``` |
| 2579 | +
|
| 2580 | + Args: |
| 2581 | + prompt_id: The id of the Vertex Dataset resource containing the prompt. For example, if the prompt resource name is "projects/123/locations/us-central1/datasets/456", then the prompt_id is "456". |
| 2582 | + config: Optional configuration for listing prompts. |
| 2583 | +
|
| 2584 | + Returns: |
| 2585 | + An async iterator of types.PromptVersionRef objects representing the prompt version resources for the provided prompt_id. |
| 2586 | +
|
| 2587 | + """ |
| 2588 | + if isinstance(config, dict): |
| 2589 | + config = types.ListPromptsConfig(**config) |
| 2590 | + elif not config: |
| 2591 | + config = types.ListPromptsConfig() |
| 2592 | + async for dataset_version in await self._list_versions_pager( |
| 2593 | + config=config, prompt_id=prompt_id |
| 2594 | + ): |
| 2595 | + if ( |
| 2596 | + not dataset_version |
| 2597 | + or not dataset_version.model_reference |
| 2598 | + or not dataset_version.name |
| 2599 | + ): |
| 2600 | + continue |
| 2601 | + prompt_version_ref = types.PromptVersionRef( |
| 2602 | + model=dataset_version.model_reference, |
| 2603 | + version_id=dataset_version.name.split("/")[-1], |
| 2604 | + prompt_id=prompt_id, |
| 2605 | + ) |
| 2606 | + yield prompt_version_ref |
0 commit comments