Skip to content

Commit

Permalink
feat: GenAI - Context Caching - add get() classmethod and refresh() i…
Browse files Browse the repository at this point in the history
…nstance method

PiperOrigin-RevId: 644141561
  • Loading branch information
ZhenyiQ authored and Copybara-Service committed Jun 17, 2024
1 parent 62f7af5 commit 6be874a
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 18 deletions.
45 changes: 45 additions & 0 deletions tests/unit/vertexai/test_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,32 @@ def get_cached_content(self, name, retry=None):
yield get_cached_content


@pytest.fixture
def mock_list_cached_contents():
"""Mocks GenAiCacheServiceClient.get_cached_content()."""

def list_cached_contents(self, request):
del self, request
response = [
GapicCachedContent(
name="cached_content1_from_list_request",
model="model-name1",
),
GapicCachedContent(
name="cached_content2_from_list_request",
model="model-name2",
),
]
return response

with mock.patch.object(
gen_ai_cache_service.client.GenAiCacheServiceClient,
"list_cached_contents",
new=list_cached_contents,
) as list_cached_contents:
yield list_cached_contents


@pytest.mark.usefixtures("google_auth_mock")
class TestCaching:
"""Unit tests for caching.CachedContent."""
Expand Down Expand Up @@ -118,6 +144,19 @@ def test_constructor_with_only_content_id(self, mock_get_cached_content):
)
assert cache.model_name == "model-name"

def test_get_with_content_id(self, mock_get_cached_content):
partial_resource_name = "contents-id"

cache = caching.CachedContent.get(
cached_content_name=partial_resource_name,
)

assert cache.name == "contents-id"
assert cache.resource_name == (
f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/cachedContents/contents-id"
)
assert cache.model_name == "model-name"

def test_create_with_real_payload(
self, mock_create_cached_content, mock_get_cached_content
):
Expand Down Expand Up @@ -162,3 +201,9 @@ def test_create_with_real_payload_and_wrapped_type(
== f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}/publishers/google/models/model-name"
)
assert cache.name == _CREATED_CONTENT_ID

def test_list(self, mock_list_cached_contents):
cached_contents = caching.CachedContent.list()
for i, cached_content in enumerate(cached_contents):
assert cached_content.name == f"cached_content{i + 1}_from_list_request"
assert cached_content.model_name == f"model-name{i + 1}"
40 changes: 22 additions & 18 deletions vertexai/caching/_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,28 +135,15 @@ def __init__(self, cached_content_name: str):
"456".
"""
super().__init__(resource_name=cached_content_name)

resource_name = aiplatform_utils.full_resource_name(
resource_name=cached_content_name,
resource_noun=self._resource_noun,
parse_resource_name_method=self._parse_resource_name,
format_resource_name_method=self._format_resource_name,
project=self.project,
location=self.location,
parent_resource_name_fields=None,
resource_id_validator=self._resource_id_validator,
)
self._gca_resource = gca_cached_content.CachedContent(name=resource_name)
self._gca_resource = self._get_gca_resource(cached_content_name)

@property
def _raw_cached_content(self) -> gca_cached_content.CachedContent:
return self._gca_resource

@property
def model_name(self) -> str:
if not self._raw_cached_content.model:
self._sync_gca_resource()
return self._raw_cached_content.model
return self._gca_resource.model

@classmethod
def create(
Expand Down Expand Up @@ -235,6 +222,10 @@ def create(
obj._gca_resource = cached_content_resource
return obj

def refresh(self):
"""Syncs the local cached content with the remote resource."""
self._sync_gca_resource()

def update(
self,
*,
Expand Down Expand Up @@ -265,15 +256,28 @@ def update(

@property
def expire_time(self) -> datetime.datetime:
"""Time this resource was last updated."""
self._sync_gca_resource()
"""Time this resource is considered expired.
The returned value may be stale. Use refresh() to get the latest value.
Returns:
The expiration time of the cached content resource.
"""
return self._gca_resource.expire_time

def delete(self):
"""Deletes the current cached content resource."""
self._delete()

@classmethod
def list(cls):
def list(cls) -> List["CachedContent"]:
"""Lists the active cached content resources."""
# TODO(b/345326114): Make list() interface richer after aligning with
# Google AI SDK
return cls._list()

@classmethod
def get(cls, cached_content_name: str) -> "CachedContent":
"""Retrieves an existing cached content resource."""
cache = cls(cached_content_name)
return cache

0 comments on commit 6be874a

Please sign in to comment.