Skip to content

Commit

Permalink
[Projects] Add project labels (#607)
Browse files Browse the repository at this point in the history
  • Loading branch information
Hedingber committed Dec 20, 2020
1 parent 869f202 commit de956e3
Show file tree
Hide file tree
Showing 22 changed files with 290 additions and 56 deletions.
4 changes: 3 additions & 1 deletion mlrun/api/api/endpoints/projects.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import typing
from http import HTTPStatus

from fastapi import APIRouter, Depends, Response, Header, Query
Expand Down Expand Up @@ -59,6 +60,7 @@ def delete_project(
def list_projects(
format_: schemas.Format = Query(schemas.Format.full, alias="format"),
owner: str = None,
labels: typing.List[str] = Query(None, alias="label"),
db_session: Session = Depends(deps.get_db_session),
):
return get_project_member().list_projects(db_session, owner, format_)
return get_project_member().list_projects(db_session, owner, format_, labels)
6 changes: 5 additions & 1 deletion mlrun/api/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,11 @@ def delete_schedule(self, session, project: str, name: str):

@abstractmethod
def list_projects(
self, session, owner: str = None, format_: schemas.Format = schemas.Format.full,
self,
session,
owner: str = None,
format_: schemas.Format = schemas.Format.full,
labels: List[str] = None,
) -> schemas.ProjectsOutput:
pass

Expand Down
10 changes: 8 additions & 2 deletions mlrun/api/db/filedb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,9 +135,15 @@ def store_schedule(self, session, data):
return self._transform_run_db_error(self.db.store_schedule, data)

def list_projects(
self, session, owner: str = None, format_: schemas.Format = schemas.Format.full,
self,
session,
owner: str = None,
format_: schemas.Format = schemas.Format.full,
labels: List[str] = None,
) -> schemas.ProjectsOutput:
return self._transform_run_db_error(self.db.list_projects, owner, format_)
return self._transform_run_db_error(
self.db.list_projects, owner, format_, labels
)

def store_project(self, session, name: str, project: schemas.Project):
raise NotImplementedError()
Expand Down
11 changes: 9 additions & 2 deletions mlrun/api/db/sqldb/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,8 @@ def create_project(self, session: Session, project: schemas.Project):
created=created,
full_object=project.dict(),
)
labels = project.metadata.labels or {}
update_labels(project_record, labels)
self._upsert(session, project_record)

def store_project(self, session: Session, name: str, project: schemas.Project):
Expand Down Expand Up @@ -715,10 +717,13 @@ def list_projects(
session: Session,
owner: str = None,
format_: mlrun.api.schemas.Format = mlrun.api.schemas.Format.full,
labels: List[str] = None,
) -> schemas.ProjectsOutput:
project_records = self._query(session, Project, owner=owner)
query = self._query(session, Project, owner=owner)
if labels:
query = self._add_labels_filter(session, query, Project, labels)
projects = []
for project_record in project_records:
for project_record in query:
if format_ == mlrun.api.schemas.Format.name_only:
projects.append(project_record.name)
elif format_ == mlrun.api.schemas.Format.full:
Expand All @@ -741,6 +746,8 @@ def _update_project_record_from_project(
project_record.description = project.spec.description
project_record.source = project.spec.source
project_record.state = project.status.state
labels = project.metadata.labels or {}
update_labels(project_record, labels)
self._upsert(session, project_record)

def _patch_project_record_from_project(
Expand Down
4 changes: 4 additions & 0 deletions mlrun/api/db/sqldb/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,10 @@ class Project(Base):
state = Column(String)
users = relationship(User, secondary=project_users)

Label = make_label(__tablename__)

labels = relationship(Label, cascade="all, delete-orphan")

@property
def full_object(self):
if self._full_object:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Adding project labels
Revision ID: bcd0c1f9720c
Revises: f4249b4ba6fa
Create Date: 2020-12-20 03:42:02.763802
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "bcd0c1f9720c"
down_revision = "f4249b4ba6fa"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"projects_labels",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(), nullable=True),
sa.Column("value", sa.String(), nullable=True),
sa.Column("parent", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(["parent"], ["projects.id"],),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("name", "parent", name="_projects_labels_uc"),
)
# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("projects_labels")
# ### end Alembic commands ###
1 change: 0 additions & 1 deletion mlrun/api/schemas/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
ProjectMetadata,
ProjectSpec,
ProjectsOutput,
ProjectRecord,
)
from .schedule import (
SchedulesOutput,
Expand Down
8 changes: 1 addition & 7 deletions mlrun/api/schemas/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
class ProjectMetadata(pydantic.BaseModel):
name: str
created: typing.Optional[datetime.datetime] = None
labels: typing.Optional[dict]

class Config:
extra = pydantic.Extra.allow
Expand Down Expand Up @@ -40,13 +41,6 @@ class Project(pydantic.BaseModel):
status: ObjectStatus = ObjectStatus()


class ProjectRecord(Project):
id: int = None

class Config:
orm_mode = True


class ProjectsOutput(pydantic.BaseModel):
# use the format query param to control whether the full object will be returned or only the names
projects: typing.List[typing.Union[Project, str]]
35 changes: 25 additions & 10 deletions mlrun/api/utils/clients/nuclio.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import copy
import http
import typing

import requests.adapters
import sqlalchemy.orm
Expand Down Expand Up @@ -28,9 +29,7 @@ def create_project(
self, session: sqlalchemy.orm.Session, project: mlrun.api.schemas.Project
):
logger.debug("Creating project in Nuclio", project=project)
body = self._generate_request_body(
project.metadata.name, project.spec.description
)
body = self._generate_request_body(project)
self._post_project_to_nuclio(body)

def store_project(
Expand All @@ -40,7 +39,7 @@ def store_project(
project: mlrun.api.schemas.Project,
):
logger.debug("Storing project in Nuclio", name=name, project=project)
body = self._generate_request_body(name, project.spec.description)
body = self._generate_request_body(project)
try:
self._get_project_from_nuclio(name)
except requests.HTTPError as exc:
Expand All @@ -59,6 +58,10 @@ def patch_project(
):
response = self._get_project_from_nuclio(name)
response_body = response.json()
if project.get("metadata").get("labels") is not None:
response_body.setdefault("metadata", {}).setdefault("labels", {}).update(
project["metadata"]["labels"]
)
if project.get("spec").get("description") is not None:
response_body.setdefault("spec", {})["description"] = project["spec"][
"description"
Expand All @@ -67,7 +70,11 @@ def patch_project(

def delete_project(self, session: sqlalchemy.orm.Session, name: str):
logger.debug("Deleting project in Nuclio", name=name)
body = self._generate_request_body(name)
body = self._generate_request_body(
mlrun.api.schemas.Project(
metadata=mlrun.api.schemas.ProjectMetadata(name=name)
)
)
self._send_request_to_api("DELETE", "projects", json=body)

def get_project(
Expand All @@ -82,11 +89,16 @@ def list_projects(
session: sqlalchemy.orm.Session,
owner: str = None,
format_: mlrun.api.schemas.Format = mlrun.api.schemas.Format.full,
labels: typing.List[str] = None,
) -> mlrun.api.schemas.ProjectsOutput:
if owner:
raise NotImplementedError(
"Listing nuclio projects by owner is currently not supported"
)
if labels:
raise NotImplementedError(
"Filtering nuclio projects by labels is currently not supported"
)
response = self._send_request_to_api("GET", "projects")
response_body = response.json()
projects = []
Expand Down Expand Up @@ -136,19 +148,22 @@ def _send_request_to_api(self, method, path, **kwargs):
return response

@staticmethod
def _generate_request_body(name, description=None):
def _generate_request_body(project: mlrun.api.schemas.Project):
body = {
"metadata": {"name": name},
"metadata": {"name": project.metadata.name},
}
if description:
body["spec"] = {"description": description}
if project.metadata.labels:
body["metadata"]["labels"] = project.metadata.labels
if project.spec.description:
body["spec"] = {"description": project.spec.description}
return body

@staticmethod
def _transform_nuclio_project_to_schema(nuclio_project):
return mlrun.api.schemas.Project(
metadata=mlrun.api.schemas.ProjectMetadata(
name=nuclio_project["metadata"]["name"]
name=nuclio_project["metadata"]["name"],
labels=nuclio_project["metadata"].get("labels"),
),
spec=mlrun.api.schemas.ProjectSpec(
description=nuclio_project["spec"].get("description")
Expand Down
3 changes: 2 additions & 1 deletion mlrun/api/utils/projects/leader.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,9 @@ def list_projects(
session: sqlalchemy.orm.Session,
owner: str = None,
format_: mlrun.api.schemas.Format = mlrun.api.schemas.Format.full,
labels: typing.List[str] = None,
) -> mlrun.api.schemas.ProjectsOutput:
return self._leader_follower.list_projects(session, owner, format_)
return self._leader_follower.list_projects(session, owner, format_, labels)

def _start_periodic_sync(self):
# if no followers no need for sync
Expand Down
2 changes: 2 additions & 0 deletions mlrun/api/utils/projects/member.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import typing

import sqlalchemy.orm

Expand Down Expand Up @@ -61,5 +62,6 @@ def list_projects(
session: sqlalchemy.orm.Session,
owner: str = None,
format_: mlrun.api.schemas.Format = mlrun.api.schemas.Format.full,
labels: typing.List[str] = None,
) -> mlrun.api.schemas.ProjectsOutput:
pass
2 changes: 2 additions & 0 deletions mlrun/api/utils/projects/remotes/member.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import abc
import typing

import sqlalchemy.orm

Expand Down Expand Up @@ -47,5 +48,6 @@ def list_projects(
session: sqlalchemy.orm.Session,
owner: str = None,
format_: mlrun.api.schemas.Format = mlrun.api.schemas.Format.full,
labels: typing.List[str] = None,
) -> mlrun.api.schemas.ProjectsOutput:
pass
5 changes: 3 additions & 2 deletions mlrun/api/utils/projects/remotes/nop.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,10 @@ def list_projects(
session: sqlalchemy.orm.Session,
owner: str = None,
format_: mlrun.api.schemas.Format = mlrun.api.schemas.Format.full,
labels: typing.List[str] = None,
) -> mlrun.api.schemas.ProjectsOutput:
if owner:
raise NotImplementedError()
if owner or labels:
raise NotImplementedError("Filtering by owner or labels is not supported")
if format_ == mlrun.api.schemas.Format.full:
return mlrun.api.schemas.ProjectsOutput(
projects=list(self._projects.values())
Expand Down
5 changes: 4 additions & 1 deletion mlrun/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,10 @@ def create_project(self, project: schemas.Project) -> schemas.Project:

@abstractmethod
def list_projects(
self, owner: str = None, format_: schemas.Format = schemas.Format.full,
self,
owner: str = None,
format_: schemas.Format = schemas.Format.full,
labels: List[str] = None,
) -> schemas.ProjectsOutput:
pass

Expand Down
3 changes: 2 additions & 1 deletion mlrun/db/filedb.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,9 @@ def list_projects(
self,
owner: str = None,
format_: mlrun.api.schemas.Format = mlrun.api.schemas.Format.full,
labels: List[str] = None,
) -> mlrun.api.schemas.ProjectsOutput:
if owner or format_ == mlrun.api.schemas.Format.full:
if owner or format_ == mlrun.api.schemas.Format.full or labels:
raise NotImplementedError()
run_dir = path.join(self.dirpath, run_logs)
if not path.isdir(run_dir):
Expand Down
2 changes: 2 additions & 0 deletions mlrun/db/httpdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -970,10 +970,12 @@ def list_projects(
self,
owner: str = None,
format_: mlrun.api.schemas.Format = mlrun.api.schemas.Format.full,
labels: List[str] = None,
) -> List[Union[mlrun.projects.MlrunProject, str]]:
params = {
"owner": owner,
"format": format_,
"label": labels or [],
}

error_message = f"Failed listing projects, query: {params}"
Expand Down
1 change: 1 addition & 0 deletions mlrun/db/sqldb.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def list_projects(
self,
owner: str = None,
format_: mlrun.api.schemas.Format = mlrun.api.schemas.Format.full,
labels: List[str] = None,
) -> mlrun.api.schemas.ProjectsOutput:
raise NotImplementedError()

Expand Down
3 changes: 2 additions & 1 deletion mlrun/projects/project.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,9 +204,10 @@ def _project_instance_from_struct(struct, name):


class ProjectMetadata(ModelObj):
def __init__(self, name=None, created=None):
def __init__(self, name=None, created=None, labels=None):
self.name = name
self.created = created
self.labels = labels or {}


class ProjectSpec(ModelObj):
Expand Down
25 changes: 24 additions & 1 deletion tests/api/api/test_projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,9 @@ def test_projects_crud(db: Session, client: TestClient) -> None:
)

name2 = f"prj-{uuid4().hex}"
labels_2 = {"key": "value"}
project_2 = mlrun.api.schemas.Project(
metadata=mlrun.api.schemas.ProjectMetadata(name=name2),
metadata=mlrun.api.schemas.ProjectMetadata(name=name2, labels=labels_2),
spec=mlrun.api.schemas.ProjectSpec(description="banana", source="source"),
)

Expand All @@ -53,6 +54,28 @@ def test_projects_crud(db: Session, client: TestClient) -> None:
expected = [name1, name2]
assert expected == response.json()["projects"]

# list - names only - filter by label existence
response = client.get(
"/api/projects",
params={
"format": mlrun.api.schemas.Format.name_only,
"label": list(labels_2.keys())[0],
},
)
expected = [name2]
assert expected == response.json()["projects"]

# list - names only - filter by label match
response = client.get(
"/api/projects",
params={
"format": mlrun.api.schemas.Format.name_only,
"label": f"{list(labels_2.keys())[0]}={list(labels_2.values())[0]}",
},
)
expected = [name2]
assert expected == response.json()["projects"]

# list - full
response = client.get(
"/api/projects", params={"format": mlrun.api.schemas.Format.full}
Expand Down

0 comments on commit de956e3

Please sign in to comment.