diff --git a/tests/unit/vertexai/test_feature_group.py b/tests/unit/vertexai/test_feature_group.py index 3120818609..ba986c3ef1 100644 --- a/tests/unit/vertexai/test_feature_group.py +++ b/tests/unit/vertexai/test_feature_group.py @@ -15,22 +15,32 @@ # limitations under the License. # +import re from typing import Dict, List -from unittest.mock import patch +from unittest import mock +from unittest.mock import call, patch +from google.api_core import operation as ga_operation from google.cloud import aiplatform from google.cloud.aiplatform import base +from vertexai.resources.preview.feature_store import ( + feature_group, +) from vertexai.resources.preview import ( FeatureGroup, ) -import vertexai.resources.preview.feature_store.utils as fs_utils +from vertexai.resources.preview.feature_store import ( + FeatureGroupBigQuerySource, +) import pytest from google.cloud.aiplatform.compat.services import ( feature_registry_service_client, ) +from google.cloud.aiplatform.compat import types from feature_store_constants import ( + _TEST_PARENT, _TEST_PROJECT, _TEST_LOCATION, _TEST_FG1, @@ -45,6 +55,16 @@ pytestmark = pytest.mark.usefixtures("google_auth_mock") +@pytest.fixture +def fg_logger_mock(): + with patch.object( + feature_group._LOGGER, + "info", + wraps=feature_group._LOGGER.info, + ) as logger_mock: + yield logger_mock + + @pytest.fixture def get_fg_mock(): with patch.object( @@ -55,6 +75,18 @@ def get_fg_mock(): yield get_fg_mock +@pytest.fixture +def create_fg_mock(): + with patch.object( + feature_registry_service_client.FeatureRegistryServiceClient, + "create_feature_group", + ) as create_fg_mock: + create_fg_lro_mock = mock.Mock(ga_operation.Operation) + create_fg_lro_mock.result.return_value = _TEST_FG1 + create_fg_mock.return_value = create_fg_lro_mock + yield create_fg_mock + + def fg_eq( fg_to_check: FeatureGroup, name: str, @@ -68,7 +100,7 @@ def fg_eq( """Check if a FeatureGroup has the appropriate values set.""" assert fg_to_check.name == name assert 
fg_to_check.resource_name == resource_name - assert fg_to_check.source == fs_utils.FeatureGroupBigQuerySource( + assert fg_to_check.source == FeatureGroupBigQuerySource( uri=source_uri, entity_id_columns=entity_id_columns, ) @@ -101,3 +133,103 @@ def test_init(feature_group_name, get_fg_mock): location=_TEST_LOCATION, labels=_TEST_FG1_LABELS, ) + + +def test_create_fg_no_source_raises_error(): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape("Please specify a valid source."), + ): + FeatureGroup.create("fg") + + +def test_create_fg_bad_source_raises_error(): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape("Only FeatureGroupBigQuerySource is a supported source."), + ): + FeatureGroup.create("fg", source=int(1)) + + +def test_create_fg_no_source_bq_uri_raises_error(): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + with pytest.raises( + ValueError, + match=re.escape("Please specify URI in BigQuery source."), + ): + FeatureGroup.create( + "fg", source=FeatureGroupBigQuerySource(uri=None, entity_id_columns=None) + ) + + +@pytest.mark.parametrize("create_request_timeout", [None, 1.0]) +@pytest.mark.parametrize("sync", [True, False]) +def test_create_fg( + create_fg_mock, get_fg_mock, fg_logger_mock, create_request_timeout, sync +): + aiplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + fg = FeatureGroup.create( + _TEST_FG1_ID, + source=FeatureGroupBigQuerySource( + uri=_TEST_FG1_BQ_URI, + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FG1_LABELS, + create_request_timeout=create_request_timeout, + sync=sync, + ) + + if not sync: + fg.wait() + + # When creating, the FeatureGroup object doesn't have the path set. 
+ expected_fg = types.feature_group.FeatureGroup( + name=_TEST_FG1_ID, + big_query=types.feature_group.FeatureGroup.BigQuery( + big_query_source=types.io.BigQuerySource( + input_uri=_TEST_FG1_BQ_URI, + ), + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + ), + labels=_TEST_FG1_LABELS, + ) + create_fg_mock.assert_called_once_with( + parent=_TEST_PARENT, + feature_group=expected_fg, + feature_group_id=_TEST_FG1_ID, + metadata=(), + timeout=create_request_timeout, + ) + + fg_logger_mock.assert_has_calls( + [ + call("Creating FeatureGroup"), + call( + f"Create FeatureGroup backing LRO: {create_fg_mock.return_value.operation.name}" + ), + call( + "FeatureGroup created. Resource name: projects/test-project/locations/us-central1/featureGroups/my_fg1" + ), + call("To use this FeatureGroup in another session:"), + call( + "feature_group = aiplatform.FeatureGroup('projects/test-project/locations/us-central1/featureGroups/my_fg1')" + ), + ] + ) + + fg_eq( + fg, + name=_TEST_FG1_ID, + resource_name=_TEST_FG1_PATH, + source_uri=_TEST_FG1_BQ_URI, + entity_id_columns=_TEST_FG1_ENTITY_ID_COLUMNS, + project=_TEST_PROJECT, + location=_TEST_LOCATION, + labels=_TEST_FG1_LABELS, + ) diff --git a/vertexai/resources/preview/__init__.py b/vertexai/resources/preview/__init__.py index aef94c6099..e0ea4632c8 100644 --- a/vertexai/resources/preview/__init__.py +++ b/vertexai/resources/preview/__init__.py @@ -64,6 +64,7 @@ "EntityType", "PipelineJobSchedule", "FeatureGroup", + "FeatureGroupBigQuerySource", "FeatureOnlineStoreType", "FeatureOnlineStore", "FeatureView", diff --git a/vertexai/resources/preview/feature_store/__init__.py b/vertexai/resources/preview/feature_store/__init__.py index 07855d5088..bc7d1b0373 100644 --- a/vertexai/resources/preview/feature_store/__init__.py +++ b/vertexai/resources/preview/feature_store/__init__.py @@ -30,6 +30,7 @@ ) from vertexai.resources.preview.feature_store.utils import ( + FeatureGroupBigQuerySource, FeatureViewBigQuerySource, 
FeatureViewReadResponse, IndexConfig, @@ -41,6 +42,7 @@ __all__ = ( FeatureGroup, + FeatureGroupBigQuerySource, FeatureOnlineStoreType, FeatureOnlineStore, FeatureView, diff --git a/vertexai/resources/preview/feature_store/feature_group.py b/vertexai/resources/preview/feature_store/feature_group.py index be437389cd..90198a4e01 100644 --- a/vertexai/resources/preview/feature_store/feature_group.py +++ b/vertexai/resources/preview/feature_store/feature_group.py @@ -15,14 +15,26 @@ # limitations under the License. # -from typing import Optional +from typing import ( + Sequence, + Tuple, + Dict, + List, + Optional, +) from google.auth import credentials as auth_credentials -from google.cloud.aiplatform import base +from google.cloud.aiplatform import base, initializer from google.cloud.aiplatform import utils from google.cloud.aiplatform.compat.types import ( feature_group as gca_feature_group, + io as gca_io, +) +from vertexai.resources.preview.feature_store.utils import ( + FeatureGroupBigQuerySource, ) -import vertexai.resources.preview.feature_store.utils as fs_utils + + +_LOGGER = base.Logger(__name__) class FeatureGroup(base.VertexAiResourceNounWithFutureManager): @@ -71,9 +83,134 @@ def __init__( self._gca_resource = self._get_gca_resource(resource_name=name) + @classmethod + def create( + cls, + name: str, + source: FeatureGroupBigQuerySource = None, + entity_id_columns: Optional[List[str]] = None, + labels: Optional[Dict[str, str]] = None, + description: Optional[str] = None, + project: Optional[str] = None, + location: Optional[str] = None, + credentials: Optional[auth_credentials.Credentials] = None, + request_metadata: Optional[Sequence[Tuple[str, str]]] = None, + create_request_timeout: Optional[float] = None, + sync: bool = True, + ) -> "FeatureGroup": + """Creates a new feature group. + + Args: + name: The name of the feature group. + source: The BigQuery source of the feature group. + entity_id_columns: + The entity ID columns. 
If not specified, defaults to + ['entity_id']. NOTE: this argument is currently ignored; the value of `source.entity_id_columns` is always used. + labels: + The labels with user-defined metadata to organize your + FeatureGroup. + + Label keys and values can be no longer than 64 + characters (Unicode codepoints), can only + contain lowercase letters, numeric characters, + underscores and dashes. International characters + are allowed. + + See https://goo.gl/xmQnxf for more information + on and examples of labels. No more than 64 user + labels can be associated with one + FeatureGroup (System labels are excluded). + System reserved label keys are prefixed with + "aiplatform.googleapis.com/" and are immutable. + description: Description of the FeatureGroup. + project: + Project to create feature group in. If unset, the project set in + aiplatform.init will be used. + location: + Location to create feature group in. If not set, location set in + aiplatform.init will be used. + credentials: + Custom credentials to use to create this feature group. + Overrides credentials set in aiplatform.init. + request_metadata: + Strings which should be sent along with the request as metadata. + create_request_timeout: + The timeout for the create request in seconds. + sync: + Whether to execute this creation synchronously. If False, this + method will be executed in a concurrent Future and any downstream + object will be immediately returned and synced when the Future + has completed. + + Returns: + FeatureGroup - the FeatureGroup resource object. + """ + + if not source: + raise ValueError("Please specify a valid source.") + + # Only BigQuery source is supported right now. + if not isinstance(source, FeatureGroupBigQuerySource): + raise ValueError("Only FeatureGroupBigQuerySource is a supported source.") + + # BigQuery source validation. + if not source.uri: + raise ValueError("Please specify URI in BigQuery source.") + + if not source.entity_id_columns: + _LOGGER.info( + "No entity ID columns specified in BigQuery source. Defaulting to ['entity_id']." 
+ ) + entity_id_columns = ["entity_id"] + else: + entity_id_columns = source.entity_id_columns + + gapic_feature_group = gca_feature_group.FeatureGroup( + big_query=gca_feature_group.FeatureGroup.BigQuery( + big_query_source=gca_io.BigQuerySource(input_uri=source.uri), + entity_id_columns=entity_id_columns, + ), + name=name, + description=description, + ) + + if labels: + utils.validate_labels(labels) + gapic_feature_group.labels = labels + + if request_metadata is None: + request_metadata = () + + api_client = cls._instantiate_client(location=location, credentials=credentials) + + create_feature_group_lro = api_client.create_feature_group( + parent=initializer.global_config.common_location_path( + project=project, location=location + ), + feature_group=gapic_feature_group, + feature_group_id=name, + metadata=request_metadata, + timeout=create_request_timeout, + ) + + _LOGGER.log_create_with_lro(cls, create_feature_group_lro) + + created_feature_group = create_feature_group_lro.result() + + _LOGGER.log_create_complete(cls, created_feature_group, "feature_group") + + feature_group_obj = cls( + name=created_feature_group.name, + project=project, + location=location, + credentials=credentials, + ) + + return feature_group_obj + @property - def source(self) -> fs_utils.FeatureGroupBigQuerySource: - return fs_utils.FeatureGroupBigQuerySource( + def source(self) -> FeatureGroupBigQuerySource: + return FeatureGroupBigQuerySource( uri=self._gca_resource.big_query.big_query_source.input_uri, entity_id_columns=self._gca_resource.big_query.entity_id_columns, )