Skip to content

Commit

Permalink
feat: add FeatureGroup create function
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 631879143
  • Loading branch information
vertex-sdk-bot authored and Copybara-Service committed May 8, 2024
1 parent cd85d8f commit 3938107
Show file tree
Hide file tree
Showing 4 changed files with 280 additions and 8 deletions.
138 changes: 135 additions & 3 deletions tests/unit/vertexai/test_feature_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,32 @@
# limitations under the License.
#

import re
from typing import Dict, List
from unittest.mock import patch
from unittest import mock
from unittest.mock import call, patch

from google.api_core import operation as ga_operation
from google.cloud import aiplatform
from google.cloud.aiplatform import base
from vertexai.resources.preview.feature_store import (
feature_group,
)
from vertexai.resources.preview import (
FeatureGroup,
)
import vertexai.resources.preview.feature_store.utils as fs_utils
from vertexai.resources.preview.feature_store import (
FeatureGroupBigQuerySource,
)
import pytest
from google.cloud.aiplatform.compat.services import (
feature_registry_service_client,
)
from google.cloud.aiplatform.compat import types


from feature_store_constants import (
_TEST_PARENT,
_TEST_PROJECT,
_TEST_LOCATION,
_TEST_FG1,
Expand All @@ -45,6 +55,16 @@
pytestmark = pytest.mark.usefixtures("google_auth_mock")


@pytest.fixture
def fg_logger_mock():
with patch.object(
feature_group._LOGGER,
"info",
wraps=feature_group._LOGGER.info,
) as logger_mock:
yield logger_mock


@pytest.fixture
def get_fg_mock():
with patch.object(
Expand All @@ -55,6 +75,18 @@ def get_fg_mock():
yield get_fg_mock


@pytest.fixture
def create_fg_mock():
with patch.object(
feature_registry_service_client.FeatureRegistryServiceClient,
"create_feature_group",
) as create_fg_mock:
create_fg_lro_mock = mock.Mock(ga_operation.Operation)
create_fg_lro_mock.result.return_value = _TEST_FG1
create_fg_mock.return_value = create_fg_lro_mock
yield create_fg_mock


def fg_eq(
fg_to_check: FeatureGroup,
name: str,
Expand All @@ -68,7 +100,7 @@ def fg_eq(
"""Check if a FeatureGroup has the appropriate values set."""
assert fg_to_check.name == name
assert fg_to_check.resource_name == resource_name
assert fg_to_check.source == fs_utils.FeatureGroupBigQuerySource(
assert fg_to_check.source == FeatureGroupBigQuerySource(
uri=source_uri,
entity_id_columns=entity_id_columns,
)
Expand Down Expand Up @@ -101,3 +133,103 @@ def test_init(feature_group_name, get_fg_mock):
location=_TEST_LOCATION,
labels=_TEST_FG1_LABELS,
)


def test_create_fg_no_source_raises_error():
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

with pytest.raises(
ValueError,
match=re.escape("Please specify a valid source."),
):
FeatureGroup.create("fg")


def test_create_fg_bad_source_raises_error():
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

with pytest.raises(
ValueError,
match=re.escape("Only FeatureGroupBigQuerySource is a supported source."),
):
FeatureGroup.create("fg", source=int(1))


def test_create_fg_no_source_bq_uri_raises_error():
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

with pytest.raises(
ValueError,
match=re.escape("Please specify URI in BigQuery source."),
):
FeatureGroup.create(
"fg", source=FeatureGroupBigQuerySource(uri=None, entity_id_columns=None)
)


@pytest.mark.parametrize("create_request_timeout", [None, 1.0])
@pytest.mark.parametrize("sync", [True, False])
def test_create_fg(
create_fg_mock, get_fg_mock, fg_logger_mock, create_request_timeout, sync
):
aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION)

fg = FeatureGroup.create(
_TEST_FG1_ID,
source=FeatureGroupBigQuerySource(
uri=_TEST_FG1_BQ_URI,
entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS,
),
labels=_TEST_FG1_LABELS,
create_request_timeout=create_request_timeout,
sync=sync,
)

if not sync:
fg.wait()

# When creating, the FeatureOnlineStore object doesn't have the path set.
expected_fg = types.feature_group.FeatureGroup(
name=_TEST_FG1_ID,
big_query=types.feature_group.FeatureGroup.BigQuery(
big_query_source=types.io.BigQuerySource(
input_uri=_TEST_FG1_BQ_URI,
),
entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS,
),
labels=_TEST_FG1_LABELS,
)
create_fg_mock.assert_called_once_with(
parent=_TEST_PARENT,
feature_group=expected_fg,
feature_group_id=_TEST_FG1_ID,
metadata=(),
timeout=create_request_timeout,
)

fg_logger_mock.assert_has_calls(
[
call("Creating FeatureGroup"),
call(
f"Create FeatureGroup backing LRO: {create_fg_mock.return_value.operation.name}"
),
call(
"FeatureGroup created. Resource name: projects/test-project/locations/us-central1/featureGroups/my_fg1"
),
call("To use this FeatureGroup in another session:"),
call(
"feature_group = aiplatform.FeatureGroup('projects/test-project/locations/us-central1/featureGroups/my_fg1')"
),
]
)

fg_eq(
fg,
name=_TEST_FG1_ID,
resource_name=_TEST_FG1_PATH,
source_uri=_TEST_FG1_BQ_URI,
entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS,
project=_TEST_PROJECT,
location=_TEST_LOCATION,
labels=_TEST_FG1_LABELS,
)
1 change: 1 addition & 0 deletions vertexai/resources/preview/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
"EntityType",
"PipelineJobSchedule",
"FeatureGroup",
"FeatureGroupBigQuerySource",
"FeatureOnlineStoreType",
"FeatureOnlineStore",
"FeatureView",
Expand Down
2 changes: 2 additions & 0 deletions vertexai/resources/preview/feature_store/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
)

from vertexai.resources.preview.feature_store.utils import (
FeatureGroupBigQuerySource,
FeatureViewBigQuerySource,
FeatureViewReadResponse,
IndexConfig,
Expand All @@ -41,6 +42,7 @@

__all__ = (
FeatureGroup,
FeatureGroupBigQuerySource,
FeatureOnlineStoreType,
FeatureOnlineStore,
FeatureView,
Expand Down
147 changes: 142 additions & 5 deletions vertexai/resources/preview/feature_store/feature_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,26 @@
# limitations under the License.
#

from typing import Optional
from typing import (
Sequence,
Tuple,
Dict,
List,
Optional,
)
from google.auth import credentials as auth_credentials
from google.cloud.aiplatform import base
from google.cloud.aiplatform import base, initializer
from google.cloud.aiplatform import utils
from google.cloud.aiplatform.compat.types import (
feature_group as gca_feature_group,
io as gca_io,
)
from vertexai.resources.preview.feature_store.utils import (
FeatureGroupBigQuerySource,
)
import vertexai.resources.preview.feature_store.utils as fs_utils


_LOGGER = base.Logger(__name__)


class FeatureGroup(base.VertexAiResourceNounWithFutureManager):
Expand Down Expand Up @@ -71,9 +83,134 @@ def __init__(

self._gca_resource = self._get_gca_resource(resource_name=name)

@classmethod
def create(
cls,
name: str,
source: FeatureGroupBigQuerySource = None,
entity_id_columns: Optional[List[str]] = None,
labels: Optional[Dict[str, str]] = None,
description: Optional[str] = None,
project: Optional[str] = None,
location: Optional[str] = None,
credentials: Optional[auth_credentials.Credentials] = None,
request_metadata: Optional[Sequence[Tuple[str, str]]] = None,
create_request_timeout: Optional[float] = None,
sync: bool = True,
) -> "FeatureGroup":
"""Creates a new feature group.
Args:
name: The name of the feature group.
source: The BigQuery source of the feature group.
entity_id_columns:
The entity ID columns. If not specified, defaults to
['entity_id'].
labels:
The labels with user-defined metadata to organize your
FeatureGroup.
Label keys and values can be no longer than 64
characters (Unicode codepoints), can only
contain lowercase letters, numeric characters,
underscores and dashes. International characters
are allowed.
See https://goo.gl/xmQnxf for more information
on and examples of labels. No more than 64 user
labels can be associated with one
FeatureGroup(System labels are excluded)."
System reserved label keys are prefixed with
"aiplatform.googleapis.com/" and are immutable.
description: Description of the FeatureGroup.
project:
Project to create feature group in. If unset, the project set in
aiplatform.init will be used.
location:
Location to create feature group in. If not set, location set in
aiplatform.init will be used.
credentials:
Custom credentials to use to create this feature group.
Overrides credentials set in aiplatform.init.
request_metadata:
Strings which should be sent along with the request as metadata.
create_request_timeout:
The timeout for the create request in seconds.
sync:
Whether to execute this creation synchronously. If False, this
method will be executed in concurrent Future and any downstream
object will be immediately returned and synced when the Future
has completed.
Returns:
FeatureGroup - the FeatureGroup resource object.
"""

if not source:
raise ValueError("Please specify a valid source.")

# Only BigQuery source is supported right now.
if not isinstance(source, FeatureGroupBigQuerySource):
raise ValueError("Only FeatureGroupBigQuerySource is a supported source.")

# BigQuery source validation.
if not source.uri:
raise ValueError("Please specify URI in BigQuery source.")

if not source.entity_id_columns:
_LOGGER.info(
"No entity ID columns specified in BigQuery source. Defaulting to ['entity_id']."
)
entity_id_columns = ["entity_id"]
else:
entity_id_columns = source.entity_id_columns

gapic_feature_group = gca_feature_group.FeatureGroup(
big_query=gca_feature_group.FeatureGroup.BigQuery(
big_query_source=gca_io.BigQuerySource(input_uri=source.uri),
entity_id_columns=entity_id_columns,
),
name=name,
description=description,
)

if labels:
utils.validate_labels(labels)
gapic_feature_group.labels = labels

if request_metadata is None:
request_metadata = ()

api_client = cls._instantiate_client(location=location, credentials=credentials)

create_feature_group_lro = api_client.create_feature_group(
parent=initializer.global_config.common_location_path(
project=project, location=location
),
feature_group=gapic_feature_group,
feature_group_id=name,
metadata=request_metadata,
timeout=create_request_timeout,
)

_LOGGER.log_create_with_lro(cls, create_feature_group_lro)

created_feature_group = create_feature_group_lro.result()

_LOGGER.log_create_complete(cls, created_feature_group, "feature_group")

feature_group_obj = cls(
name=created_feature_group.name,
project=project,
location=location,
credentials=credentials,
)

return feature_group_obj

@property
def source(self) -> fs_utils.FeatureGroupBigQuerySource:
return fs_utils.FeatureGroupBigQuerySource(
def source(self) -> FeatureGroupBigQuerySource:
return FeatureGroupBigQuerySource(
uri=self._gca_resource.big_query.big_query_source.input_uri,
entity_id_columns=self._gca_resource.big_query.entity_id_columns,
)

0 comments on commit 3938107

Please sign in to comment.