Skip to content

Commit

Permalink
[Feature store] Added list entities endpoint (#594)
Browse files Browse the repository at this point in the history
  • Loading branch information
theSaarco committed Dec 15, 2020
1 parent 7a38bc3 commit 929e943
Show file tree
Hide file tree
Showing 13 changed files with 227 additions and 17 deletions.
12 changes: 12 additions & 0 deletions mlrun/api/api/endpoints/feature_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,18 @@ def list_features(
return features


@router.get("/projects/{project}/entities", response_model=schemas.EntitiesOutput)
def list_entities(
    project: str,
    name: str = None,
    tag: str = None,
    labels: List[str] = Query(None, alias="label"),
    db_session: Session = Depends(deps.get_db_session),
):
    """List the entities stored for a project.

    :param project: project name, taken from the URL path.
    :param name: optional entity-name filter, forwarded to the DB layer.
    :param tag: optional tag filter, forwarded to the DB layer.
    :param labels: optional label filters; read from the ``label`` query
        parameter (note the alias), which may be repeated.
    :param db_session: DB session injected through ``deps.get_db_session``.
    :return: whatever the DB layer's ``list_entities`` returns; the route
        declares ``schemas.EntitiesOutput`` as its response model.
    """
    # All filtering is delegated to the DB layer - this endpoint only forwards.
    return get_db().list_entities(db_session, project, name, tag, labels)


@router.post(
"/projects/{project}/feature-vectors", response_model=schemas.FeatureVector
)
Expand Down
11 changes: 11 additions & 0 deletions mlrun/api/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,17 @@ def list_features(
) -> schemas.FeaturesOutput:
pass

@abstractmethod
def list_entities(
    self,
    session,
    project: str,
    name: str = None,
    tag: str = None,
    labels: List[str] = None,
) -> schemas.EntitiesOutput:
    """List the entities of a project - to be implemented by concrete DB classes.

    :param session: DB session handed to the implementation.
    :param project: project whose entities are listed.
    :param name: optional entity-name filter.
    :param tag: optional tag filter.
    :param labels: optional list of label filters.
    :return: declared as ``schemas.EntitiesOutput``; this abstract placeholder
        itself has no body and returns ``None``.
    """

@abstractmethod
def list_feature_sets(
self,
Expand Down
10 changes: 10 additions & 0 deletions mlrun/api/db/filedb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,16 @@ def list_features(
) -> schemas.FeaturesOutput:
raise NotImplementedError()

def list_entities(
    self,
    session,
    project: str,
    name: str = None,
    tag: str = None,
    labels: List[str] = None,
) -> schemas.EntitiesOutput:
    """Listing entities is not supported by this DB implementation.

    The parameters mirror the abstract ``list_entities`` interface and are
    all ignored.

    :raises NotImplementedError: always. This matches the sibling
        ``list_features`` method, instead of silently returning ``None``
        where the signature promises a ``schemas.EntitiesOutput``.
    """
    raise NotImplementedError()

def list_feature_sets(
self,
session,
Expand Down
97 changes: 85 additions & 12 deletions mlrun/api/db/sqldb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,9 +873,37 @@ def _generate_records_with_tags_assigned(
def _generate_feature_set_digest(feature_set: schemas.FeatureSet):
return schemas.FeatureSetDigestOutput(
metadata=feature_set.metadata,
spec=schemas.FeatureSetDigestSpec(entities=feature_set.spec.entities),
spec=schemas.FeatureSetDigestSpec(
entities=feature_set.spec.entities, features=feature_set.spec.features,
),
)

def _generate_feature_or_entity_list_query(
    self,
    session,
    query_class,
    project: str,
    feature_set_keys,
    name: str = None,
    tag: str = None,
    labels: List[str] = None,
):
    """Build the query shared by the feature and entity listing methods.

    :param session: DB session used to create the query.
    :param query_class: the model joined to ``FeatureSet`` and returned next to
        it in every row (callers in this file pass ``Feature`` or ``Entity``).
    :param project: project name the query is filtered by.
    :param feature_set_keys: feature-set ids allowed in the result; consulted
        only when ``tag`` is given.
    :param name: optional substring to look for in ``query_class.name``
        (matched with ``ilike``).
    :param tag: when truthy, rows are restricted to ``feature_set_keys``.
    :param labels: optional label filters, applied via ``_add_labels_filter``.
    :return: the (not yet executed) query of ``(FeatureSet, query_class)`` rows.
    """
    # Base query: feature-set records joined with the requested child records
    base_query = session.query(FeatureSet, query_class)
    query = base_query.filter_by(project=project).join(query_class)

    if name:
        name_pattern = f"%{name}%"
        query = query.filter(query_class.name.ilike(name_pattern))
    if labels:
        query = self._add_labels_filter(session, query, query_class, labels)
    if tag:
        # The tag itself is not part of the query - the caller resolves it to
        # feature-set ids beforehand and passes them as feature_set_keys.
        query = query.filter(FeatureSet.id.in_(feature_set_keys))

    return query

def list_features(
self,
session,
Expand All @@ -890,19 +918,10 @@ def list_features(
session, FeatureSet, project, tag, name=None
)

# Query the actual objects to be returned
query = (
session.query(FeatureSet, Feature)
.filter_by(project=project)
.join(FeatureSet.features)
query = self._generate_feature_or_entity_list_query(
session, Feature, project, feature_set_id_tags.keys(), name, tag, labels
)

if name:
query = query.filter(Feature.name.ilike(f"%{name}%"))
if labels:
query = self._add_labels_filter(session, query, Feature, labels)
if tag:
query = query.filter(FeatureSet.id.in_(feature_set_id_tags.keys()))
if entities:
query = query.join(FeatureSet.entities).filter(Entity.name.in_(entities))

Expand Down Expand Up @@ -944,6 +963,60 @@ def list_features(
)
return schemas.FeaturesOutput(features=features_results)

def list_entities(
    self,
    session,
    project: str,
    name: str = None,
    tag: str = None,
    labels: List[str] = None,
) -> schemas.EntitiesOutput:
    """List entities of a project together with a digest of their feature set.

    :param session: DB session.
    :param project: project to look in.
    :param name: optional entity-name filter (applied by the shared query builder).
    :param tag: optional feature-set tag filter.
    :param labels: optional entity label filters.
    :return: ``schemas.EntitiesOutput`` holding one ``EntityListOutput`` per
        (entity, tagged feature-set) combination produced below.
    :raises DBError: when an entity row has no same-named entity in its
        feature-set document.
    """
    # NOTE(review): presumably maps feature-set ids to their tags - confirm
    # against _get_records_to_tags_map.
    feature_set_id_tags = self._get_records_to_tags_map(
        session, FeatureSet, project, tag, name=None
    )
    query = self._generate_feature_or_entity_list_query(
        session, Entity, project, feature_set_id_tags.keys(), name, tag, labels
    )

    collected = []
    for row in query:
        wanted_name = schemas.FeatureRecord.from_orm(row.Entity).name

        tagged_feature_sets = self._generate_records_with_tags_assigned(
            row.FeatureSet,
            self._transform_feature_set_model_to_schema,
            feature_set_id_tags,
            tag,
        )

        for feature_set in tagged_feature_sets:
            # Take the entity from the feature-set full structure rather than
            # from the DB row, as it may contain extra fields (which are not
            # in the DB). The first entity with a matching name wins.
            entity = None
            for candidate in feature_set.spec.entities:
                if candidate.name == wanted_name:
                    entity = candidate
                    break

            if not entity:
                raise DBError(
                    "Inconsistent data in DB - entities in DB not in feature-set document"
                )

            digest = self._generate_feature_set_digest(feature_set)
            collected.append(
                schemas.EntityListOutput(entity=entity, feature_set_digest=digest)
            )

    return schemas.EntitiesOutput(entities=collected)

def list_feature_sets(
self,
session,
Expand Down
2 changes: 2 additions & 0 deletions mlrun/api/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
FeatureSetDigestSpec,
FeatureListOutput,
FeaturesOutput,
EntityListOutput,
EntitiesOutput,
FeatureVector,
FeatureVectorRecord,
FeatureVectorsOutput,
Expand Down
10 changes: 10 additions & 0 deletions mlrun/api/schemas/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class FeatureSetsOutput(BaseModel):

class FeatureSetDigestSpec(BaseModel):
    """Spec part of a feature-set digest: only its entities and features."""

    # Entities of the feature set
    entities: List[Entity]
    # Features of the feature set
    features: List[Feature]


class FeatureSetDigestOutput(BaseModel):
Expand All @@ -89,6 +90,15 @@ class FeaturesOutput(BaseModel):
features: List[FeatureListOutput]


class EntityListOutput(BaseModel):
    """One item of an entity listing: an entity plus a feature-set digest."""

    # The entity itself
    entity: Entity
    # Digest of the feature set this entity is reported with
    feature_set_digest: FeatureSetDigestOutput


class EntitiesOutput(BaseModel):
    """Response body of an entity listing."""

    # One item per listed entity
    entities: List[EntityListOutput]


class FeatureVector(BaseModel):
kind: ObjectKind = Field(ObjectKind.feature_vector, const=True)
metadata: ObjectMetadata
Expand Down
6 changes: 6 additions & 0 deletions mlrun/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,12 @@ def list_features(
) -> schemas.FeaturesOutput:
pass

@abstractmethod
def list_entities(
    self,
    project: str,
    name: str = None,
    tag: str = None,
    labels: List[str] = None,
) -> schemas.EntitiesOutput:
    """List the entities of a project - to be implemented by concrete run DBs.

    :param project: project whose entities are listed.
    :param name: optional entity-name filter.
    :param tag: optional tag filter.
    :param labels: optional list of label filters.
    :return: declared as ``schemas.EntitiesOutput``; this abstract placeholder
        itself has no body and returns ``None``.
    """

@abstractmethod
def list_feature_sets(
self,
Expand Down
5 changes: 5 additions & 0 deletions mlrun/db/filedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,11 @@ def list_features(
):
raise NotImplementedError()

def list_entities(
    self,
    project: str,
    name: str = None,
    tag: str = None,
    labels: List[str] = None,
):
    """Listing entities is not supported by this DB implementation.

    All parameters are accepted for interface compatibility and ignored.

    :raises NotImplementedError: always.
    """
    raise NotImplementedError()

def list_feature_sets(
self,
project: str = "",
Expand Down
20 changes: 18 additions & 2 deletions mlrun/db/httpdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,7 @@ def list_features(
tag: str = None,
entities: List[str] = None,
labels: List[str] = None,
) -> schemas.FeaturesOutput:
) -> List[dict]:
project = project or default_project
params = {
"name": name,
Expand All @@ -751,7 +751,23 @@ def list_features(

error_message = f"Failed listing features, project: {project}, query: {params}"
resp = self.api_call("GET", path, error_message, params=params)
return schemas.FeaturesOutput(**resp.json())
return resp.json()["features"]

def list_entities(
    self,
    project: str,
    name: str = None,
    tag: str = None,
    labels: List[str] = None,
) -> List[dict]:
    """Retrieve a project's entities from the API service.

    :param project: project name; a falsy value falls back to ``default_project``.
    :param name: optional entity-name filter, sent as the ``name`` query parameter.
    :param tag: optional tag filter, sent as the ``tag`` query parameter.
    :param labels: optional label filters, sent as the ``label`` query
        parameter (an empty list when not given).
    :return: the ``entities`` item of the JSON response.
    """
    project = project or default_project
    path = f"projects/{project}/entities"
    params = {"name": name, "tag": tag, "label": labels or []}

    # The message is handed to api_call together with the request
    error_message = f"Failed listing entities, project: {project}, query: {params}"
    response = self.api_call("GET", path, error_message, params=params)
    return response.json()["entities"]

def list_feature_sets(
self,
Expand Down
7 changes: 7 additions & 0 deletions mlrun/db/sqldb.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,13 @@ def list_features(
self.db.list_features, self.session, project, name, tag, entities, labels,
)

def list_entities(
    self,
    project: str,
    name: str = None,
    tag: str = None,
    labels: List[str] = None,
):
    """List entities by calling the wrapped DB object's ``list_entities``.

    The call is routed through ``_transform_db_error`` with this object's
    session placed in front of the filter arguments.

    :param project: project whose entities are listed.
    :param name: optional entity-name filter.
    :param tag: optional tag filter.
    :param labels: optional list of label filters.
    :return: whatever ``_transform_db_error`` returns for the call.
    """
    filters = (project, name, tag, labels)
    return self._transform_db_error(self.db.list_entities, self.session, *filters)

def list_feature_sets(
self,
project: str = "",
Expand Down
5 changes: 3 additions & 2 deletions tests/api/api/feature_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,10 @@ def _list_and_assert_objects(
assert response.status_code == HTTPStatus.OK.value
response_body = response.json()
assert entity_name in response_body
number_of_entities = len(response_body[entity_name])
assert (
len(response_body[entity_name]) == expected_number_of_entities
), f"wrong number of {entity_name} entities in response"
number_of_entities == expected_number_of_entities
), f"wrong number of {entity_name} in response - {number_of_entities} instead of {expected_number_of_entities}"
return response_body


Expand Down
45 changes: 45 additions & 0 deletions tests/api/api/feature_store/test_feature_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,51 @@ def test_feature_set_wrong_kind_failure(db: Session, client: TestClient) -> None
assert response.status_code != HTTPStatus.OK.value


def test_entities_list(db: Session, client: TestClient) -> None:
    """Exercise the entities listing endpoint with name, label and tag scenarios."""
    project_name = f"prj-{uuid4().hex}"

    name = "feature_set"
    count = 5
    colors = ["red", "blue"]

    # One feature set per index, each carrying a single, uniquely named entity
    for index in range(count):
        entity = {
            "name": f"entity_{index}",
            "value_type": "str",
            "labels": {"color": colors[index % 2], "id": f"id_{index}"},
        }
        feature_set = _generate_feature_set(f"{name}_{index}")
        feature_set["spec"]["entities"] = [entity]
        _feature_set_create_and_assert(client, project_name, feature_set)

    # (query string, expected number of returned entities)
    expectations = [
        ("name=entity_0", 1),
        ("name=entity", count),
        ("label=color", count),
        # odd indexes got colors[1], hence count // 2 of them
        (f"label=color={colors[1]}", count // 2),
        ("name=entity&label=id=id_0", 1),
    ]
    for query_string, expected in expectations:
        _list_and_assert_objects(
            client, "entities", project_name, query_string, expected
        )

    # set a new tag on the feature set created last
    last_index = count - 1
    tag = "my-new-tag"
    tag_request = {"feature_sets": {"name": f"{name}_{last_index}"}}
    resp = client.post(f"/api/{project_name}/tag/{tag}", json=tag_request)
    assert resp.status_code == HTTPStatus.OK.value

    # Now expecting to get 2 objects, one with "latest" tag and one with "my-new-tag"
    entities_response = _list_and_assert_objects(
        client, "entities", project_name, f"name=entity_{last_index}", 2
    )
    listed = entities_response["entities"]
    assert listed[0]["feature_set_digest"]["metadata"]["tag"] == "latest"
    assert listed[1]["feature_set_digest"]["metadata"]["tag"] == tag


def test_features_list(db: Session, client: TestClient) -> None:
project_name = f"prj-{uuid4().hex}"

Expand Down
14 changes: 13 additions & 1 deletion tests/rundb/test_httpdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,12 +433,24 @@ def test_feature_sets(create_server):
feature_set_without_labels["metadata"]["project"] = project
db.store_feature_set(feature_set_without_labels)
feature_set_update = {
"metadata": {"labels": {"label1": "value1", "label2": "value2"}}
"spec": {"entities": [{"name": "nothing", "value_type": "bool"}]},
"metadata": {"labels": {"label1": "value1", "label2": "value2"}},
}
db.patch_feature_set(name, feature_set_update, project)
feature_set = db.get_feature_set(name, project)
assert len(feature_set["metadata"]["labels"]) == 2, "Labels didn't get updated"

features = db.list_features(project, "time")
# The feature-set with different labels also counts here
assert len(features) == count + 1
# Only count, since we modified the entity of the last feature-set - other name, no labels
entities = db.list_entities(project, "ticker")
assert len(entities) == count
entities = db.list_entities(project, labels=["type"])
assert len(entities) == count
entities = db.list_entities(project, labels=["type=prod"])
assert len(entities) == count


def _create_feature_vector(name):
return {
Expand Down

0 comments on commit 929e943

Please sign in to comment.