Skip to content

Commit 13a626b

Browse files
sararobcopybara-github
authored andcommitted
feat: Add experimental async list_prompts and list_version methods to prompt management
PiperOrigin-RevId: 809132260
1 parent 0d600fd commit 13a626b

File tree

2 files changed

+136
-2
lines changed

2 files changed

+136
-2
lines changed

tests/unit/vertexai/genai/replays/test_list_prompts.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
from tests.unit.vertexai.genai.replays import pytest_helper
1717
from vertexai._genai import types
1818

19+
import pytest
20+
1921

2022
def test_list_returns_prompts(client):
2123
prompt_refs = client.prompt_management.list_prompts()
@@ -71,3 +73,26 @@ def test_list_versions(client):
7173
globals_for_file=globals(),
7274
test_method="prompt_management.list_prompts",
7375
)
76+
77+
pytest_plugins = ("pytest_asyncio",)
78+
79+
80+
@pytest.mark.asyncio
81+
async def test_list_returns_prompts_async(client):
82+
prompt_refs = client.aio.prompt_management.list_prompts()
83+
async for prompt in prompt_refs:
84+
assert isinstance(prompt, types.PromptRef)
85+
assert prompt.prompt_id is not None
86+
assert prompt.model is not None
87+
88+
89+
@pytest.mark.asyncio
90+
async def test_list_versions_async(client):
91+
prompt_version_refs = client.aio.prompt_management.list_versions(
92+
prompt_id="3331020504126455808"
93+
)
94+
async for prompt_version in prompt_version_refs:
95+
assert isinstance(prompt_version, types.PromptVersionRef)
96+
assert prompt_version.prompt_id is not None
97+
assert prompt_version.version_id is not None
98+
assert prompt_version.model is not None

vertexai/_genai/prompt_management.py

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import json
2020
import logging
2121
import time
22-
from typing import Any, Iterator, Optional, Union
22+
from typing import Any, AsyncIterator, Iterator, Optional, Union
2323
from urllib.parse import urlencode
2424

2525
from google.genai import _api_module
@@ -28,7 +28,7 @@
2828
from google.genai import types as genai_types
2929
from google.genai._common import get_value_by_path as getv
3030
from google.genai._common import set_value_by_path as setv
31-
from google.genai.pagers import Pager
31+
from google.genai.pagers import AsyncPager, Pager
3232

3333
from . import _prompt_management_utils
3434
from . import types
@@ -2495,3 +2495,112 @@ async def delete_version(
24952495
logger.info(
24962496
f"Deleted prompt version {version_id} from prompt with id: {prompt_id}"
24972497
)
2498+
2499+
async def _list_prompts_pager(
2500+
self,
2501+
*,
2502+
config: Optional[types.ListPromptsConfigOrDict] = None,
2503+
) -> AsyncPager[types.Dataset]:
2504+
return AsyncPager(
2505+
"datasets",
2506+
self._list_prompts,
2507+
await self._list_prompts(config=config),
2508+
config,
2509+
)
2510+
2511+
async def _list_versions_pager(
2512+
self,
2513+
*,
2514+
prompt_id: str,
2515+
config: Optional[types.ListPromptsConfigOrDict] = None,
2516+
) -> AsyncPager[types.DatasetVersion]:
2517+
return AsyncPager(
2518+
"dataset_versions",
2519+
self._list_versions,
2520+
await self._list_versions(config=config, dataset_id=prompt_id),
2521+
config,
2522+
)
2523+
2524+
async def list_prompts(
2525+
self,
2526+
*,
2527+
config: Optional[types.ListPromptsConfigOrDict] = None,
2528+
) -> AsyncIterator[types.PromptRef]:
2529+
"""Lists prompt resources in a project.
2530+
2531+
This method retrieves all the prompts from the project provided in the
2532+
vertexai.Client constructor and returns a list of prompt references containing the prompt_id and model for the prompt.
2533+
2534+
To get the full types.Prompt resource for a PromptRef after calling this method, use the get() method with the prompt_id as the prompt_id argument.
2535+
Example usage:
2536+
2537+
```
2538+
prompt_refs = client.aio.prompt_management.list_prompts()
2539+
async for prompt_ref in prompt_refs:
2540+
await client.prompt_management.get(prompt_id=prompt_ref.prompt_id)
2541+
```
2542+
2543+
Args:
2544+
config: Optional configuration for listing prompts.
2545+
2546+
Returns:
2547+
An async iterator of types.PromptRef objects.
2548+
"""
2549+
if isinstance(config, dict):
2550+
config = types.ListPromptsConfig(**config)
2551+
elif not config:
2552+
config = types.ListPromptsConfig()
2553+
async for dataset in await self._list_prompts_pager(config=config):
2554+
if not dataset or not dataset.model_reference or not dataset.name:
2555+
continue
2556+
prompt_ref = types.PromptRef(
2557+
model=dataset.model_reference, prompt_id=dataset.name.split("/")[-1]
2558+
)
2559+
yield prompt_ref
2560+
2561+
async def list_versions(
2562+
self,
2563+
*,
2564+
prompt_id: str,
2565+
config: Optional[types.ListPromptsConfigOrDict] = None,
2566+
) -> AsyncIterator[types.PromptVersionRef]:
2567+
"""Lists prompt version resources for a provided prompt_id.
2568+
2569+
This method retrieves all the prompt versions for a provided prompt_id.
2570+
2571+
To get the full types.Prompt resource for a PromptVersionRef after calling this method, use the get() method with the returned prompt_id and version_id.
2572+
Example usage:
2573+
2574+
```
2575+
prompt_version_refs = await client.prompt_management.list_versions(prompt_id="123")
2576+
async for version_ref in prompt_version_refs:
2577+
await client.aio.prompt_management.get(prompt_id=version_ref.prompt_id, version_id=version_ref.version_id)
2578+
```
2579+
2580+
Args:
2581+
prompt_id: The id of the Vertex Dataset resource containing the prompt. For example, if the prompt resource name is "projects/123/locations/us-central1/datasets/456", then the prompt_id is "456".
2582+
config: Optional configuration for listing prompts.
2583+
2584+
Returns:
2585+
An async iterator of types.PromptVersionRef objects representing the prompt version resources for the provided prompt_id.
2586+
2587+
"""
2588+
if isinstance(config, dict):
2589+
config = types.ListPromptsConfig(**config)
2590+
elif not config:
2591+
config = types.ListPromptsConfig()
2592+
async for dataset_version in await self._list_versions_pager(
2593+
config=config, prompt_id=prompt_id
2594+
):
2595+
if (
2596+
not dataset_version
2597+
or not dataset_version.model_reference
2598+
or not dataset_version.name
2599+
):
2600+
continue
2601+
prompt_version_ref = types.PromptVersionRef(
2602+
model=dataset_version.model_reference,
2603+
version_id=dataset_version.name.split("/")[-1],
2604+
prompt_id=prompt_id,
2605+
)
2606+
yield prompt_version_ref

0 commit comments

Comments
 (0)