Skip to content

Commit

Permalink
Merge pull request #11 from demml/feature/sort-by-timestamp
Browse files Browse the repository at this point in the history
Feature/sort by timestamp
  • Loading branch information
thorrester committed May 3, 2024
2 parents 8248d52 + 086e292 commit 0d7eb3f
Show file tree
Hide file tree
Showing 10 changed files with 70 additions and 10 deletions.
1 change: 1 addition & 0 deletions opsml/app/routes/cards.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,7 @@ def list_cards(
tags=payload.tags,
ignore_release_candidates=payload.ignore_release_candidates,
query_terms=payload.query_terms,
sort_by_timestamp=payload.sort_by_timestamp,
)

return ListCardResponse(cards=cards)
Expand Down
1 change: 1 addition & 0 deletions opsml/app/routes/pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class ListCardRequest(BaseModel):
registry_type: Optional[str] = None
table_name: Optional[str] = None
query_terms: Optional[Dict[str, Any]] = None
sort_by_timestamp: bool = False

@model_validator(mode="before")
@classmethod
Expand Down
14 changes: 7 additions & 7 deletions opsml/registry/records.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import time
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, ConfigDict, model_validator
from pydantic import BaseModel, ConfigDict, Field, model_validator

from opsml.types import CardVersion, CommonKwargs, RegistryType

Expand All @@ -28,7 +28,7 @@ class SaveRecord(BaseModel):

class DataRegistryRecord(SaveRecord):
data_type: Optional[str] = None
timestamp: int = get_timestamp()
timestamp: int = Field(default_factory=get_timestamp)
runcard_uid: Optional[str] = None
pipelinecard_uid: Optional[str] = None
auditcard_uid: Optional[str] = None
Expand All @@ -49,7 +49,7 @@ class ModelRegistryRecord(SaveRecord):
datacard_uid: str
sample_data_type: str
model_type: str
timestamp: int = get_timestamp()
timestamp: int = Field(default_factory=get_timestamp)
runcard_uid: Optional[str] = None
pipelinecard_uid: Optional[str] = None
auditcard_uid: Optional[str] = None
Expand Down Expand Up @@ -77,15 +77,15 @@ class RunRegistryRecord(SaveRecord):
project: Optional[str] = None
artifact_uris: Optional[Dict[str, Dict[str, str]]] = None
tags: Dict[str, Union[str, int]]
timestamp: int = get_timestamp()
timestamp: int = Field(default_factory=get_timestamp)


class PipelineRegistryRecord(SaveRecord):
pipeline_code_uri: Optional[str] = None
datacard_uids: List[str]
modelcard_uids: List[str]
runcard_uids: List[str]
timestamp: int = get_timestamp()
timestamp: int = Field(default_factory=get_timestamp)


class ProjectRegistryRecord(BaseModel):
Expand All @@ -94,15 +94,15 @@ class ProjectRegistryRecord(BaseModel):
repository: str
project_id: int
version: Optional[str] = None
timestamp: int = get_timestamp()
timestamp: int = Field(default_factory=get_timestamp)


class AuditRegistryRecord(SaveRecord):
approved: bool
datacards: List[CardVersion]
modelcards: List[CardVersion]
runcards: List[CardVersion]
timestamp: int = get_timestamp()
timestamp: int = Field(default_factory=get_timestamp)

@model_validator(mode="before")
@classmethod
Expand Down
4 changes: 4 additions & 0 deletions opsml/registry/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def list_cards(
max_date: Optional[str] = None,
limit: Optional[int] = None,
ignore_release_candidates: bool = False,
sort_by_timestamp: bool = False,
) -> List[Dict[str, Any]]:
"""Retrieves records from registry
Expand All @@ -87,6 +88,8 @@ def list_cards(
CardInfo object. If present, the info object takes precedence
ignore_release_candidates:
If True, ignores release candidates
sort_by_timestamp:
If True, sorts by timestamp descending
Returns:
pandas dataframe of records or list of dictionaries
Expand Down Expand Up @@ -117,6 +120,7 @@ def list_cards(
limit=limit,
tags=tags,
ignore_release_candidates=ignore_release_candidates,
sort_by_timestamp=sort_by_timestamp,
)

return card_list
Expand Down
4 changes: 4 additions & 0 deletions opsml/registry/sql/base/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def list_cards(
limit: Optional[int] = None,
ignore_release_candidates: bool = False,
query_terms: Optional[Dict[str, Any]] = None,
sort_by_timestamp: bool = False,
) -> pd.DataFrame:
"""
Retrieves records from registry
Expand All @@ -161,6 +162,8 @@ def list_cards(
If True, release candidates will be ignored
query_terms:
Dictionary of query terms to filter by
sort_by_timestamp:
If True, sorts by timestamp descending
Returns:
Dictionary of card records
Expand All @@ -179,6 +182,7 @@ def list_cards(
"registry_type": self.registry_type.value,
"ignore_release_candidates": ignore_release_candidates,
"query_terms": query_terms,
"sort_by_timestamp": sort_by_timestamp,
},
)

Expand Down
8 changes: 7 additions & 1 deletion opsml/registry/sql/base/query_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def _records_from_table_query(
tags: Optional[Dict[str, str]] = None,
limit: Optional[int] = None,
query_terms: Optional[Dict[str, Any]] = None,
sort_by_timestamp: bool = False,
) -> Select[Any]:
"""
Creates a sql query based on table, uid, name, repository and version
Expand Down Expand Up @@ -274,7 +275,10 @@ def _records_from_table_query(
if bool(filters):
query = query.filter(*filters)

query = query.order_by(text("major desc"), text("minor desc"), text("patch desc"))
if not sort_by_timestamp:
query = query.order_by(text("major desc"), text("minor desc"), text("patch desc"))
else:
query = query.order_by(table.timestamp.desc()) # type: ignore

if limit is not None:
query = query.limit(limit)
Expand Down Expand Up @@ -311,6 +315,7 @@ def get_records_from_table(
tags: Optional[Dict[str, str]] = None,
limit: Optional[int] = None,
query_terms: Optional[Dict[str, Any]] = None,
sort_by_timestamp: bool = False,
) -> List[Dict[str, Any]]:
query = self._records_from_table_query(
table=table,
Expand All @@ -322,6 +327,7 @@ def get_records_from_table(
tags=tags,
limit=limit,
query_terms=query_terms,
sort_by_timestamp=sort_by_timestamp,
)

with self.session() as sess:
Expand Down
1 change: 1 addition & 0 deletions opsml/registry/sql/base/registry_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def list_cards(
limit: Optional[int] = None,
ignore_release_candidates: bool = False,
query_terms: Optional[Dict[str, Any]] = None,
sort_by_timestamp: bool = False,
) -> List[Dict[str, Any]]:
raise NotImplementedError

Expand Down
9 changes: 7 additions & 2 deletions opsml/registry/sql/base/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,7 @@ def list_cards(
limit: Optional[int] = None,
ignore_release_candidates: bool = False,
query_terms: Optional[Dict[str, Any]] = None,
sort_by_timestamp: bool = False,
) -> List[Dict[str, Any]]:
"""
Retrieves records from registry
Expand All @@ -191,6 +192,8 @@ def list_cards(
If True, will ignore release candidates when searching for versions
query_terms:
Dictionary of query terms to filter by
sort_by_timestamp:
If True, sorts by timestamp descending
Returns:
Expand All @@ -210,10 +213,12 @@ def list_cards(
tags=tags,
limit=limit,
query_terms=query_terms,
sort_by_timestamp=sort_by_timestamp,
)

if cleaned_name is not None:
records = self._sort_by_version(records=records)
# may not need
# if cleaned_name is not None:
# records = self._sort_by_version(records=records)

if version is not None:
if ignore_release_candidates:
Expand Down
2 changes: 2 additions & 0 deletions tests/test_app/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@ def test_register_data(

_ = registry.list_cards()

_ = registry.list_cards(sort_by_timestamp=True)

# Verify repositories / names
repositories = registry._registry.unique_repositories
assert "mlops" in repositories
Expand Down
36 changes: 36 additions & 0 deletions tests/test_registry/test_registry.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sys
import time
import uuid
from pathlib import Path
from typing import Tuple
Expand Down Expand Up @@ -95,6 +96,9 @@ def test_register_data(
cards = registry.list_cards(name=data_card.name, repository=data_card.repository, version="1.0.0")
assert bool(cards)

cards = registry.list_cards(name=data_card.name, sort_by_timestamp=True)
assert bool(cards)

data_card = DataCard(
interface=test_interface,
name="test_df",
Expand Down Expand Up @@ -850,3 +854,35 @@ def test_register_data_timestamp(

assert isinstance(loaded.interface.data_splits[0].column_value, pd.Timestamp)
assert splits["train"].X.shape[0] == 8


def test_sort_timestamp(sql_data: SqlData, db_registries: CardRegistries) -> None:
# create data card
registry = db_registries.data
data_card = DataCard(
interface=sql_data,
name="test1",
repository="mlops",
contact="mlops.com",
version="1.0.0",
)

registry.register_card(card=data_card)

time.sleep(2)

data_card = DataCard(
interface=sql_data,
name="test2",
repository="mlops",
contact="mlops.com",
version="1.0.0",
)

registry.register_card(card=data_card)

### test sort by timestamp
cards = registry.list_cards(sort_by_timestamp=True)
print(cards)
assert cards[0]["name"] == "test2"
assert cards[1]["name"] == "test1"

0 comments on commit 0d7eb3f

Please sign in to comment.