Skip to content

Commit

Permalink
[API] Add list feature sets/vectors tags endpoints (#1260)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hedingber committed Aug 29, 2021
1 parent 98978ee commit bebb828
Show file tree
Hide file tree
Showing 10 changed files with 226 additions and 30 deletions.
73 changes: 73 additions & 0 deletions mlrun/api/api/endpoints/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import mlrun.api.crud
import mlrun.api.utils.clients.opa
import mlrun.api.utils.singletons.project_member
import mlrun.errors
import mlrun.feature_store
from mlrun import v3io_cred
from mlrun.api import schemas
Expand Down Expand Up @@ -205,6 +206,41 @@ def list_feature_sets(
return mlrun.api.schemas.FeatureSetsOutput(feature_sets=feature_sets)


@router.get(
"/projects/{project}/feature-sets/{name}/tags",
response_model=schemas.FeatureSetsTagsOutput,
)
def list_feature_sets_tags(
project: str,
name: str,
auth_verifier: deps.AuthVerifierDep = Depends(deps.AuthVerifierDep),
db_session: Session = Depends(deps.get_db_session),
):
if name != "*":
raise mlrun.errors.MLRunInvalidArgumentError(
"Listing a specific feature set tags is not supported, set name to *"
)
mlrun.api.utils.clients.opa.Client().query_project_permissions(
project, mlrun.api.schemas.AuthorizationAction.read, auth_verifier.auth_info,
)
tag_tuples = mlrun.api.crud.FeatureStore().list_feature_sets_tags(
db_session, project,
)
feature_set_name_to_tag = {tag_tuple[1]: tag_tuple[2] for tag_tuple in tag_tuples}
allowed_feature_set_names = mlrun.api.utils.clients.opa.Client().filter_project_resources_by_permissions(
mlrun.api.schemas.AuthorizationResourceTypes.feature_set,
list(feature_set_name_to_tag.keys()),
lambda feature_set_name: (project, feature_set_name,),
auth_verifier.auth_info,
)
tags = [
tag_tuple[2]
for tag_tuple in tag_tuples
if tag_tuple[1] in allowed_feature_set_names
]
return mlrun.api.schemas.FeatureSetsTagsOutput(tags=tags)


def _has_v3io_path(data_source, data_targets, feature_set):
paths = []

Expand Down Expand Up @@ -494,6 +530,43 @@ def list_feature_vectors(
return mlrun.api.schemas.FeatureVectorsOutput(feature_vectors=feature_vectors)


@router.get(
"/projects/{project}/feature-vectors/{name}/tags",
response_model=schemas.FeatureVectorsTagsOutput,
)
def list_feature_vectors_tags(
project: str,
name: str,
auth_verifier: deps.AuthVerifierDep = Depends(deps.AuthVerifierDep),
db_session: Session = Depends(deps.get_db_session),
):
if name != "*":
raise mlrun.errors.MLRunInvalidArgumentError(
"Listing a specific feature vector tags is not supported, set name to *"
)
mlrun.api.utils.clients.opa.Client().query_project_permissions(
project, mlrun.api.schemas.AuthorizationAction.read, auth_verifier.auth_info,
)
tag_tuples = mlrun.api.crud.FeatureStore().list_feature_vectors_tags(
db_session, project,
)
feature_vector_name_to_tag = {
tag_tuple[1]: tag_tuple[2] for tag_tuple in tag_tuples
}
allowed_feature_vector_names = mlrun.api.utils.clients.opa.Client().filter_project_resources_by_permissions(
mlrun.api.schemas.AuthorizationResourceTypes.feature_vector,
list(feature_vector_name_to_tag.keys()),
lambda feature_vector_name: (project, feature_vector_name,),
auth_verifier.auth_info,
)
tags = [
tag_tuple[2]
for tag_tuple in tag_tuples
if tag_tuple[1] in allowed_feature_vector_names
]
return mlrun.api.schemas.FeatureVectorsTagsOutput(tags=tags)


@router.put(
"/projects/{project}/feature-vectors/{name}/references/{reference}",
response_model=schemas.FeatureVector,
Expand Down
40 changes: 40 additions & 0 deletions mlrun/api/crud/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,16 @@ def get_feature_set(
db_session, mlrun.api.schemas.FeatureSet, project, name, tag, uid
)

def list_feature_sets_tags(
self, db_session: sqlalchemy.orm.Session, project: str,
) -> typing.List[typing.Tuple[str, str, str]]:
"""
:return: a list of Tuple of (project, feature_set.name, tag)
"""
return self._list_object_type_tags(
db_session, mlrun.api.schemas.FeatureSet, project
)

def list_feature_sets(
self,
db_session: sqlalchemy.orm.Session,
Expand Down Expand Up @@ -195,6 +205,16 @@ def get_feature_vector(
db_session, mlrun.api.schemas.FeatureVector, project, name, tag, uid,
)

def list_feature_vectors_tags(
self, db_session: sqlalchemy.orm.Session, project: str,
) -> typing.List[typing.Tuple[str, str, str]]:
"""
:return: a list of Tuple of (project, feature_vector.name, tag)
"""
return self._list_object_type_tags(
db_session, mlrun.api.schemas.FeatureVector, project
)

def list_feature_vectors(
self,
db_session: sqlalchemy.orm.Session,
Expand Down Expand Up @@ -338,6 +358,26 @@ def _get_object(
f"Provided object type is not supported. object_type={object_schema.__class__.__name__}"
)

def _list_object_type_tags(
self,
db_session: sqlalchemy.orm.Session,
object_schema: typing.ClassVar,
project: str,
) -> typing.List[typing.Tuple[str, str, str]]:
project = project or mlrun.mlconf.default_project
if object_schema.__name__ == mlrun.api.schemas.FeatureSet.__name__:
return mlrun.api.utils.singletons.db.get_db().list_feature_sets_tags(
db_session, project
)
elif object_schema.__name__ == mlrun.api.schemas.FeatureVector.__name__:
return mlrun.api.utils.singletons.db.get_db().list_feature_vectors_tags(
db_session, project
)
else:
raise NotImplementedError(
f"Provided object type is not supported. object_type={object_schema.__class__.__name__}"
)

def _delete_object(
self,
db_session: sqlalchemy.orm.Session,
Expand Down
20 changes: 19 additions & 1 deletion mlrun/api/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Tuple

from mlrun.api import schemas

Expand Down Expand Up @@ -324,6 +324,15 @@ def list_feature_sets(
) -> schemas.FeatureSetsOutput:
pass

@abstractmethod
def list_feature_sets_tags(
self, session, project: str,
) -> List[Tuple[str, str, str]]:
"""
:return: a list of Tuple of (project, feature_set.name, tag)
"""
pass

@abstractmethod
def patch_feature_set(
self,
Expand Down Expand Up @@ -369,6 +378,15 @@ def list_feature_vectors(
) -> schemas.FeatureVectorsOutput:
pass

@abstractmethod
def list_feature_vectors_tags(
self, session, project: str,
) -> List[Tuple[str, str, str]]:
"""
:return: a list of Tuple of (project, feature_vector.name, tag)
"""
pass

@abstractmethod
def store_feature_vector(
self,
Expand Down
10 changes: 10 additions & 0 deletions mlrun/api/db/filedb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,11 @@ def list_feature_sets(
) -> schemas.FeatureSetsOutput:
raise NotImplementedError()

def list_feature_sets_tags(
self, session, project: str,
):
raise NotImplementedError()

def patch_feature_set(
self,
session,
Expand Down Expand Up @@ -295,6 +300,11 @@ def list_feature_vectors(
) -> schemas.FeatureVectorsOutput:
raise NotImplementedError()

def list_feature_vectors_tags(
self, session, project: str,
):
raise NotImplementedError()

def store_feature_vector(
self,
session,
Expand Down
23 changes: 22 additions & 1 deletion mlrun/api/db/sqldb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,6 @@ def list_artifact_tags(
self, session, project
) -> typing.List[typing.Tuple[str, str, str]]:
"""
:return: a list of Tuple of (project, artifact.key, tag)
"""
query = (
Expand Down Expand Up @@ -1423,6 +1422,17 @@ def list_feature_sets(
)
return schemas.FeatureSetsOutput(feature_sets=feature_sets)

def list_feature_sets_tags(
self, session, project: str,
):
query = (
session.query(FeatureSet.name, FeatureSet.Tag.name)
.filter(FeatureSet.Tag.project == project)
.join(FeatureSet, FeatureSet.Tag.obj_id == FeatureSet.id)
.distinct()
)
return [(project, row[0], row[1]) for row in query]

@staticmethod
def _update_feature_set_features(
feature_set: FeatureSet, feature_dicts: List[dict], replace=False
Expand Down Expand Up @@ -1781,6 +1791,17 @@ def list_feature_vectors(
)
return schemas.FeatureVectorsOutput(feature_vectors=feature_vectors)

def list_feature_vectors_tags(
self, session, project: str,
):
query = (
session.query(FeatureVector.name, FeatureVector.Tag.name)
.filter(FeatureVector.Tag.project == project)
.join(FeatureVector, FeatureVector.Tag.obj_id == FeatureVector.id)
.distinct()
)
return [(project, row[0], row[1]) for row in query]

def store_feature_vector(
self,
session,
Expand Down
2 changes: 2 additions & 0 deletions mlrun/api/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,12 @@
FeatureSetRecord,
FeatureSetsOutput,
FeatureSetSpec,
FeatureSetsTagsOutput,
FeaturesOutput,
FeatureVector,
FeatureVectorRecord,
FeatureVectorsOutput,
FeatureVectorsTagsOutput,
)
from .frontend_spec import FeatureFlags, FrontendSpec, ProjectMembershipFeatureFlag
from .function import FunctionState
Expand Down
8 changes: 8 additions & 0 deletions mlrun/api/schemas/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ class FeatureSetsOutput(BaseModel):
feature_sets: List[FeatureSet]


class FeatureSetsTagsOutput(BaseModel):
tags: List[str] = []


class FeatureSetDigestSpec(BaseModel):
entities: List[Entity]
features: List[Feature]
Expand Down Expand Up @@ -124,6 +128,10 @@ class FeatureVectorsOutput(BaseModel):
feature_vectors: List[FeatureVector]


class FeatureVectorsTagsOutput(BaseModel):
tags: List[str] = []


class DataSource(BaseModel):
kind: str
name: str
Expand Down
10 changes: 10 additions & 0 deletions tests/api/api/feature_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ def _list_and_assert_objects(
return response_body


def _list_tags_and_assert(client: TestClient, entity_name, project, expected_tags):
entity_url_name = entity_name.replace("_", "-")
url = f"/api/projects/{project}/{entity_url_name}/*/tags"
response = client.get(url)
assert response.status_code == HTTPStatus.OK.value
response_body = response.json()

assert DeepDiff(response_body["tags"], expected_tags, ignore_order=True,) == {}


def _patch_object(
client: TestClient,
project_name,
Expand Down
14 changes: 14 additions & 0 deletions tests/api/api/feature_store/test_feature_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .base import (
_assert_diff_as_expected_except_for_specific_metadata,
_list_and_assert_objects,
_list_tags_and_assert,
_patch_object,
_test_partition_by_for_feature_store_objects,
)
Expand Down Expand Up @@ -510,6 +511,19 @@ def test_feature_set_tagging_with_re_store(db: Session, client: TestClient) -> N
assert response[0]["metadata"]["extra_metadata"] == 200


def test_list_feature_sets_tags(db: Session, client: TestClient) -> None:
project_name = "some-project"
name = "feature_set1"
feature_set = _generate_feature_set(name)

tags = ["tag-1", "tag-2", "tag-3", "tag-4"]
for tag in tags:
_store_and_assert_feature_set(client, project_name, name, tag, feature_set)
_list_tags_and_assert(
client, "feature_sets", project_name, tags,
)


def test_feature_set_create_without_labels(db: Session, client: TestClient) -> None:
project_name = f"prj-{uuid4().hex}"
name = "feature_set1"
Expand Down

0 comments on commit bebb828

Please sign in to comment.