Skip to content

Commit

Permalink
Merge pull request #13 from demml/updates/metadata-and-sql
Browse files Browse the repository at this point in the history
Updated UI #1
  • Loading branch information
thorrester authored May 6, 2024
2 parents dd52d36 + c819046 commit fddc717
Show file tree
Hide file tree
Showing 15 changed files with 438 additions and 60 deletions.
80 changes: 79 additions & 1 deletion opsml/app/routes/cards.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# Copyright (c) Shipt, Inc.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Optional, Union
from typing import Dict, Optional, Union

from fastapi import APIRouter, Body, Depends, HTTPException, Request, status

Expand All @@ -15,6 +15,7 @@
ListCardRequest,
ListCardResponse,
NamesResponse,
RegistryQuery,
RepositoriesResponse,
UidExistsRequest,
UidExistsResponse,
Expand Down Expand Up @@ -78,6 +79,39 @@ def card_repositories(
return RepositoriesResponse(repositories=repositories)


@router.get("/cards/registry/stats", name="registry_stats")
def query_registry_stats(
request: Request,
registry_type: str,
search_term: Optional[str] = None,
) -> Dict[str, int]:
"""Get card information from a registry
Args:
request:
FastAPI request object
registry_type:
Type of registry
search_term:
search term to filter by. This term can be a repository or a name
Returns:
`dict`
"""

try:
registry: CardRegistry = getattr(request.app.state.registries, registry_type)
stats: Dict[str, int] = registry._registry.query_stats(search_term)

return stats

except Exception as error:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to query registry. {error}",
) from error


@router.get("/cards/names", response_model=NamesResponse, name="names")
def card_names(
request: Request,
Expand All @@ -104,6 +138,47 @@ def card_names(
return NamesResponse(names=names)


@router.get("/cards/registry/query/page", response_model=RegistryQuery, name="registry_page")
def query_registry_page(
request: Request,
registry_type: str,
sort_by: str = "updated_at",
repository: Optional[str] = None,
search_term: Optional[str] = None,
page: int = 0,
) -> RegistryQuery:
"""Get card information from a registry
Args:
request:
FastAPI request object
registry_type:
Type of registry
sort_by:
Field to sort by
repository:
repository to filter by
search_term:
search term to filter by. This term can be a repository or a name
page:
page number
Returns:
`dict`
"""

try:
registry: CardRegistry = getattr(request.app.state.registries, registry_type)
page = registry._registry.query_page(sort_by, page, repository, search_term)
return RegistryQuery(page=page)

except Exception as error:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to query registry. {error}",
) from error


@router.post(
"/cards/version",
response_model=Union[VersionResponse, UidExistsResponse],
Expand Down Expand Up @@ -169,6 +244,9 @@ def list_cards(
sort_by_timestamp=payload.sort_by_timestamp,
)

if payload.page:
cards = cards[payload.page * 30 : payload.page + 30]

return ListCardResponse(cards=cards)

except Exception as error:
Expand Down
7 changes: 6 additions & 1 deletion opsml/app/routes/pydantic_models.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Copyright (c) Shipt, Inc.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Tuple, Union

from fastapi import File, Form, UploadFile
from pydantic import BaseModel, Field, model_validator
Expand Down Expand Up @@ -75,6 +75,10 @@ class PutFileRequest(BaseModel):
write_path: str


class RegistryQuery(BaseModel):
page: List[Tuple[Union[str, int], ...]]


class ListCardRequest(BaseModel):
name: Optional[str] = None
repository: Optional[str] = None
Expand All @@ -89,6 +93,7 @@ class ListCardRequest(BaseModel):
table_name: Optional[str] = None
query_terms: Optional[Dict[str, Any]] = None
sort_by_timestamp: bool = False
page: Optional[int] = None

@model_validator(mode="before")
@classmethod
Expand Down
4 changes: 3 additions & 1 deletion opsml/cards/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,9 @@ def create_registry_record(self) -> Dict[str, Any]:
Registry metadata
"""
exclude_attr = {"data"}
return self.model_dump(exclude=exclude_attr)
dumped_model = self.model_dump(exclude=exclude_attr)
dumped_model["interface_type"] = self.interface.name()
return dumped_model

def add_info(self, info: Dict[str, Union[float, int, str]]) -> None:
"""
Expand Down
2 changes: 2 additions & 0 deletions opsml/cards/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def create_registry_record(self) -> Dict[str, Any]:

exclude_vars = {"interface": {"model", "preprocessor", "sample_data", "onnx_model"}}
dumped_model = self.model_dump(exclude=exclude_vars)
dumped_model["interface_type"] = self.interface.name()
dumped_model["task_type"] = self.interface.task_type

return dumped_model

Expand Down
4 changes: 4 additions & 0 deletions opsml/registry/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class DataRegistryRecord(SaveRecord):
runcard_uid: Optional[str] = None
pipelinecard_uid: Optional[str] = None
auditcard_uid: Optional[str] = None
interface_type: str = CommonKwargs.UNDEFINED.value

@model_validator(mode="before")
@classmethod
Expand All @@ -53,12 +54,15 @@ class ModelRegistryRecord(SaveRecord):
runcard_uid: Optional[str] = None
pipelinecard_uid: Optional[str] = None
auditcard_uid: Optional[str] = None
interface_type: str = CommonKwargs.UNDEFINED.value
task_type: str = CommonKwargs.UNDEFINED.value

model_config = ConfigDict(protected_namespaces=("protect_",))

@model_validator(mode="before")
@classmethod
def set_metadata(cls, values: Dict[str, Any]) -> Dict[str, Any]:
print(values["datacard_uid"])
metadata: Dict[str, Any] = values["metadata"]
values["sample_data_type"] = metadata["data_schema"]["data_type"]
values["model_type"] = values["interface"]["model_type"]
Expand Down
103 changes: 103 additions & 0 deletions opsml/registry/sql/base/query_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# mypy: disable-error-code="call-overload"
# pylint: disable=not-callable

# Copyright (c) Shipt, Inc.
# This source code is licensed under the MIT license found in the
Expand Down Expand Up @@ -435,6 +436,108 @@ def get_unique_card_names(self, repository: Optional[str], table: CardSQLTable)
with self.session() as sess:
return sess.scalars(query).all()

def query_stats(
self,
table: CardSQLTable,
search_term: Optional[str],
) -> Dict[str, int]:
"""Query stats for a card registry
Args:
table:
Registry table to query
search_term:
Search term
"""

query = select(
sqa_func.count(distinct(table.name)).label("nbr_names"), # type: ignore
sqa_func.count(table.version).label("nbr_versions"), # type: ignore
sqa_func.count(distinct(table.repository)).label("nbr_repos"), # type: ignore
)

if search_term:
query = query.filter(
or_(
table.name.like(f"%{search_term}%"), # type: ignore
table.repository.like(f"%{search_term}%"), # type: ignore
),
)

with self.session() as sess:
results = sess.execute(query).first()

if not results:
return {
"nbr_names": 0,
"nbr_versions": 0,
"nbr_repos": 0,
}
return {
"nbr_names": results[0],
"nbr_versions": results[1],
"nbr_repos": results[2],
}

def query_page(
self,
sort_by: str,
page: int,
search_term: Optional[str],
repository: Optional[str],
table: CardSQLTable,
) -> Sequence[Row[Any]]:
"""Returns a page result from card registry
Args:
sort_by:
Field to sort by
page:
Page number
search_term:
Search term
repository:
Repository name
table:
Registry table to query
Returns:
Tuple of card summary
"""

subquery = select(
table.repository,
table.name,
sqa_func.count(distinct(table.version)).label("versions"), # type:ignore
sqa_func.max(table.timestamp).label("updated_at"),
sqa_func.min(table.timestamp).label("created_at"),
).group_by(table.repository, table.name)

if repository is not None:
subquery = subquery.filter(table.repository == repository)

if search_term:
subquery = subquery.filter(
or_(
table.name.like(f"%{search_term}%"), # type: ignore
table.repository.like(f"%{search_term}%"), # type: ignore
),
)
subquery = subquery.subquery()

subquery2 = select(subquery, (sqa_func.row_number().over(order_by=sort_by)).label("row_number")).subquery()

lower_bound = page * 30
upper_bound = lower_bound + 30

query = (
select(subquery2)
.filter(subquery2.c.row_number.between(lower_bound, upper_bound))
.order_by(text("updated_at desc"))
)

with self.session() as sess:
records = sess.execute(query).all()

return records

def delete_card_record(
self,
table: CardSQLTable,
Expand Down
44 changes: 43 additions & 1 deletion opsml/registry/sql/base/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import textwrap
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, List, Optional, Sequence, Tuple, cast
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

import bcrypt
import jwt
Expand Down Expand Up @@ -72,6 +72,48 @@ def unique_repositories(self) -> Sequence[str]:
"""Returns a list of unique repositories"""
return self.engine.get_unique_repositories(table=self._table)

def query_stats(self, search_term: Optional[str] = None) -> Dict[str, int]:
"""Query stats from Card Database
Args:
search_term:
Search term to filter by
Returns:
Dictionary of stats
"""
return self.engine.query_stats(table=self._table, search_term=search_term)

def query_page(
self,
sort_by: str,
page: int,
repository: Optional[str] = None,
search_term: Optional[str] = None,
) -> List[Tuple[Union[str, int], ...]]:
"""Query page from Card Database
Args:
sort_by:
Field to sort by
page:
Page number
repository:
Repository to filter by
search_term:
Search term to filter by
Returns:
List of tuples
"""
return cast(
List[Tuple[Union[str, int], ...]],
self.engine.query_page(
table=self._table,
repository=repository,
search_term=search_term,
sort_by=sort_by,
page=page,
),
)

def get_unique_card_names(self, repository: Optional[str] = None) -> Sequence[str]:
"""Returns a list of unique card names
Args:
Expand Down
5 changes: 4 additions & 1 deletion opsml/registry/sql/base/sql_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from sqlalchemy.orm import declarative_base, declarative_mixin, validates

from opsml.helpers.logging import ArtifactLogger
from opsml.types import RegistryTableNames
from opsml.types import CommonKwargs, RegistryTableNames

logger = ArtifactLogger.get_logger()

Expand Down Expand Up @@ -56,6 +56,7 @@ class DataMixin:
runcard_uid = Column("runcard_uid", String(64))
pipelinecard_uid = Column("pipelinecard_uid", String(64))
auditcard_uid = Column("auditcard_uid", String(64))
interface_type = Column("interface_type", String(64), nullable=False, default=CommonKwargs.UNDEFINED.value)


class DataSchema(Base, BaseMixin, DataMixin):
Expand All @@ -73,6 +74,8 @@ class ModelMixin:
runcard_uid = Column("runcard_uid", String(64))
pipelinecard_uid = Column("pipelinecard_uid", String(64))
auditcard_uid = Column("auditcard_uid", String(64))
interface_type = Column("interface_type", String(64), nullable=False, default=CommonKwargs.UNDEFINED.value)
task_type = Column("task_type", String(64), nullable=False, default=CommonKwargs.UNDEFINED.value)


class ModelSchema(Base, BaseMixin, ModelMixin):
Expand Down
Loading

0 comments on commit fddc717

Please sign in to comment.