[Artifacts] Limit datasets preview size [Backport 0.5.x] (#670)
Hedingber committed Jan 22, 2021
1 parent a11e421 commit 617947a
Showing 5 changed files with 294 additions and 32 deletions.
18 changes: 16 additions & 2 deletions mlrun/api/db/sqldb/db.py
@@ -196,6 +196,19 @@ def del_runs(

def store_artifact(
self, session, key, artifact, uid, iter=None, tag="", project=""
):
self._store_artifact(session, key, artifact, uid, iter, tag, project)

def _store_artifact(
self,
session,
key,
artifact,
uid,
iter=None,
tag="",
project="",
tag_artifact=True,
):
project = project or config.default_project
self._ensure_project(session, project)
@@ -212,8 +225,9 @@ def store_artifact(
update_labels(art, labels)
art.struct = artifact
self._upsert(session, art)
tag = tag or "latest"
self.tag_objects(session, [art], project, tag)
if tag_artifact:
tag = tag or "latest"
self.tag_objects(session, [art], project, tag)

def read_artifact(self, session, key, tag="", iter=None, project=""):
project = project or config.default_project
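For context, a minimal sketch (not part of this commit) of the two storing paths the split enables; the key, uid, project, and artifact body below are illustrative:

from mlrun.api.db.session import close_session, create_session
from mlrun.api.db.sqldb.db import SQLDB

db = SQLDB("")
session = create_session()
try:
    artifact_dict = {"kind": "dataset", "metadata": {"name": "my-dataset"}}  # illustrative body
    # public path: stores the artifact and tags it, defaulting the tag to "latest"
    db.store_artifact(session, "my-dataset", artifact_dict, "my-uid", project="my-project")
    # internal path used by the data migration below: stores without re-tagging
    db._store_artifact(
        session,
        "my-dataset",
        artifact_dict,
        "my-uid",
        project="my-project",
        tag_artifact=False,
    )
finally:
    close_session(session)
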
91 changes: 91 additions & 0 deletions mlrun/api/initial_data.py
@@ -1,3 +1,7 @@
import mlrun.api.db.sqldb.db
import sqlalchemy.orm
import mlrun.artifacts
import mlrun.artifacts.dataset
from mlrun.api.db.init_db import init_db
from mlrun.api.db.session import create_session, close_session
from mlrun.utils import logger
@@ -8,11 +12,98 @@ def init_data() -> None:
db_session = create_session()
try:
init_db(db_session)
_perform_data_migrations(db_session)
finally:
close_session(db_session)
logger.info("Initial data created")


def _perform_data_migrations(db_session: sqlalchemy.orm.Session):
# FileDB is not really a thing anymore, so using SQLDB directly
db = mlrun.api.db.sqldb.db.SQLDB("")
logger.info("Performing data migrations")
_fix_datasets_large_previews(db, db_session)


def _fix_datasets_large_previews(
db: mlrun.api.db.sqldb.db.SQLDB, db_session: sqlalchemy.orm.Session
):
# get all artifacts
artifacts = db._find_artifacts(db_session, None, "*")
for artifact in artifacts:
try:
artifact_dict = artifact.struct
if (
artifact_dict
and artifact_dict.get("kind") == mlrun.artifacts.DatasetArtifact.kind
):
header = artifact_dict.get("header", [])
if header and len(header) > mlrun.artifacts.dataset.max_preview_columns:
logger.debug(
"Found dataset artifact with more than allowed columns in preview fields. Fixing",
artifact=artifact_dict,
)
columns_to_remove = header[
mlrun.artifacts.dataset.max_preview_columns :
]

# align preview
if artifact_dict.get("preview"):
new_preview = []
for preview_row in artifact_dict["preview"]:
# sanity
if (
len(preview_row)
< mlrun.artifacts.dataset.max_preview_columns
):
logger.warning(
"Found artifact with more than allowed columns in header definition, "
"but preview data is valid. Leaving preview as is",
artifact=artifact_dict,
)
new_preview.append(
preview_row[
: mlrun.artifacts.dataset.max_preview_columns
]
)

artifact_dict["preview"] = new_preview

# align stats
for column_to_remove in columns_to_remove:
if column_to_remove in artifact_dict.get("stats", {}):
del artifact_dict["stats"][column_to_remove]

# align schema
if artifact_dict.get("schema", {}).get("fields"):
new_schema_fields = []
for field in artifact_dict["schema"]["fields"]:
if field.get("name") not in columns_to_remove:
new_schema_fields.append(field)
artifact_dict["schema"]["fields"] = new_schema_fields

# lastly, align headers
artifact_dict["header"] = header[
: mlrun.artifacts.dataset.max_preview_columns
]
logger.debug(
"Fixed dataset artifact preview fields. Storing",
artifact=artifact_dict,
)
db._store_artifact(
db_session,
artifact.key,
artifact_dict,
artifact.uid,
project=artifact.project,
tag_artifact=False,
)
except Exception as exc:
logger.warning(
"Failed fixing dataset artifact large preview. Continuing", exc=exc,
)


def main() -> None:
init_data()

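This migration runs automatically on API startup via init_data(). A minimal sketch of invoking it directly against an already-initialized database, mirroring what _perform_data_migrations() does:

import mlrun.api.db.sqldb.db
import mlrun.api.initial_data
from mlrun.api.db.session import close_session, create_session

# FileDB is deprecated, so the migration talks to SQLDB directly
db = mlrun.api.db.sqldb.db.SQLDB("")
db_session = create_session()
try:
    mlrun.api.initial_data._fix_datasets_large_previews(db, db_session)
finally:
    close_session(db_session)
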
72 changes: 42 additions & 30 deletions mlrun/artifacts/dataset.py
@@ -24,7 +24,8 @@
from ..datastore import store_manager
from ..utils import DB_SCHEMA

preview_lines = 20
default_preview_rows_length = 20
max_preview_columns = 100
max_csv = 10000

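Together these constants cap both dimensions of a stored preview. A hedged sketch of the resulting behavior (column and row counts are illustrative), assuming pandas and numpy are installed:

import numpy
import pandas
from mlrun.artifacts.dataset import DatasetArtifact, max_preview_columns

# 150 columns exceed max_preview_columns (100), so header/preview/schema
# are truncated after reset_index() adds the index column
df = pandas.DataFrame(numpy.random.randint(0, 10, size=(5, 150)))
capped = DatasetArtifact("wide-dataset", df=df)
assert len(capped.header) == max_preview_columns

# opting out keeps the full width, as the migration test below relies on
full = DatasetArtifact("wide-dataset", df=df, ignore_preview_limits=True)
assert len(full.header) == 151  # 150 data columns + the reset index column
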

@@ -100,6 +101,7 @@ def __init__(
target_path=None,
extra_data=None,
column_metadata=None,
ignore_preview_limits=False,
**kwargs,
):

@@ -120,17 +122,9 @@
self.column_metadata = column_metadata or {}

if df is not None:
self.length = df.shape[0]
preview = preview or preview_lines
shortdf = df
if self.length > preview:
shortdf = df.head(preview)
shortdf = shortdf.reset_index()
self.header = shortdf.columns.values.tolist()
self.preview = shortdf.values.tolist()
self.schema = build_table_schema(df)
if stats or self.length < max_csv:
self.stats = get_df_stats(df)
self.update_preview_fields_from_df(
self, df, stats, preview, ignore_preview_limits
)

self._df = df
self._kw = kwargs
@@ -181,6 +175,28 @@ def upload(self, data_stores):

raise ValueError(f"format {self.format} not implemented yes")

@staticmethod
def update_preview_fields_from_df(
artifact, df, stats=None, preview_rows_length=None, ignore_preview_limits=False
):
preview_rows_length = preview_rows_length or default_preview_rows_length
artifact.length = df.shape[0]
preview_df = df
if artifact.length > preview_rows_length and not ignore_preview_limits:
preview_df = df.head(preview_rows_length)
preview_df = preview_df.reset_index()
if len(preview_df.columns) > max_preview_columns and not ignore_preview_limits:
preview_df = preview_df.iloc[:, :max_preview_columns]
artifact.header = preview_df.columns.values.tolist()
artifact.preview = preview_df.values.tolist()
artifact.schema = build_table_schema(preview_df)
if (
stats
or (artifact.length < max_csv and len(df.columns) < max_preview_columns)
or ignore_preview_limits
):
artifact.stats = get_df_stats(df)


def get_df_stats(df):
d = {}
@@ -216,6 +232,7 @@ def update_dataset_meta(
extra_data: dict = None,
column_metadata: dict = None,
labels: dict = None,
ignore_preview_limits: bool = False,
):
"""Update dataset object attributes/metadata
@@ -225,15 +242,16 @@
update_dataset_meta(dataset, from_df=df,
extra_data={'histogram': 's3://mybucket/..'})
:param from_df: read metadata (schema, preview, ..) from provided df
:param artifact: dataset artifact object or path (store://..) or DataItem
:param schema: dataset schema, see pandas build_table_schema
:param header: column headers
:param preview: list of rows and row values (from df.values.tolist())
:param stats: dict of column names and their stats (cleaned df.describe(include='all'))
:param extra_data: extra data items (key: path string | artifact)
:param column_metadata: dict of metadata per column
:param labels: metadata labels
:param from_df: read metadata (schema, preview, ..) from provided df
:param artifact: dataset artifact object or path (store://..) or DataItem
:param schema: dataset schema, see pandas build_table_schema
:param header: column headers
:param preview: list of rows and row values (from df.values.tolist())
:param stats: dict of column names and their stats (cleaned df.describe(include='all'))
:param extra_data: extra data items (key: path string | artifact)
:param column_metadata: dict of metadata per column
:param labels: metadata labels
:param ignore_preview_limits: whether to ignore the preview size limits
"""

if hasattr(artifact, "artifact_url"):
@@ -251,15 +269,9 @@
raise ValueError("store artifact ({}) is not dataset kind".format(artifact))

if from_df is not None:
shortdf = from_df
length = from_df.shape[0]
if length > preview_lines:
shortdf = from_df.head(preview_lines)
artifact_spec.header = shortdf.reset_index().columns.values.tolist()
artifact_spec.preview = shortdf.reset_index().values.tolist()
artifact_spec.schema = build_table_schema(from_df)
if stats is None and length < max_csv:
artifact_spec.stats = get_df_stats(from_df)
DatasetArtifact.update_preview_fields_from_df(
    artifact_spec, from_df, stats, ignore_preview_limits=ignore_preview_limits
)

if header:
artifact_spec.header = header
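Because update_dataset_meta() now delegates to update_preview_fields_from_df(), callers can refresh a stored dataset's preview fields and opt out of the caps explicitly. A hedged usage sketch; the store path is hypothetical (a dataset artifact object or DataItem also works, per the docstring):

import pandas
from mlrun.artifacts.dataset import update_dataset_meta

df = pandas.DataFrame({"A": [1, 2, 3], "B": [10, 20, 30]})
update_dataset_meta(
    "store://my-project/my-dataset",  # hypothetical store path
    from_df=df,
    labels={"refreshed": "true"},
    ignore_preview_limits=False,  # set True to keep previews beyond the caps
)
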
96 changes: 96 additions & 0 deletions tests/api/db/test_artifacts.py
@@ -1,5 +1,9 @@
import deepdiff
import numpy
import pandas
import pytest
from sqlalchemy.orm import Session
import mlrun.api.initial_data
import mlrun.artifacts.dataset
from mlrun.artifacts.plots import ChartArtifact, PlotArtifact
from mlrun.artifacts.dataset import DatasetArtifact
from mlrun.artifacts.model import ModelArtifact
@@ -122,6 +126,98 @@ def test_list_artifact_category_filter(db: DBInterface, db_session: Session):
assert artifacts[1]["metadata"]["name"] == artifact_name_2


# running only on sqldb because filedb is not really a thing anymore and will be removed soon
@pytest.mark.parametrize(
"db,db_session", [(dbs[0], dbs[0])], indirect=["db", "db_session"]
)
def test_data_migration_fix_datasets_large_previews(
db: DBInterface, db_session: Session,
):
artifact_with_valid_preview_key = "artifact-with-valid-preview-key"
artifact_with_valid_preview_uid = "artifact-with-valid-preview-uid"
artifact_with_valid_preview = mlrun.artifacts.DatasetArtifact(
artifact_with_valid_preview_key,
df=pandas.DataFrame(
[{"A": 10, "B": 100}, {"A": 11, "B": 110}, {"A": 12, "B": 120}]
),
)
db.store_artifact(
db_session,
artifact_with_valid_preview_key,
artifact_with_valid_preview.to_dict(),
artifact_with_valid_preview_uid,
)

artifact_with_invalid_preview_key = "artifact-with-invalid-preview-key"
artifact_with_invalid_preview_uid = "artifact-with-invalid-preview-uid"
artifact_with_invalid_preview = mlrun.artifacts.DatasetArtifact(
artifact_with_invalid_preview_key,
df=pandas.DataFrame(
numpy.random.randint(
0, 10, size=(10, mlrun.artifacts.dataset.max_preview_columns * 3)
)
),
ignore_preview_limits=True,
)
db.store_artifact(
db_session,
artifact_with_invalid_preview_key,
artifact_with_invalid_preview.to_dict(),
artifact_with_invalid_preview_uid,
)

# perform the migration
mlrun.api.initial_data._fix_datasets_large_previews(db, db_session)

artifact_with_valid_preview_after_migration = db.read_artifact(
db_session, artifact_with_valid_preview_key, artifact_with_valid_preview_uid
)
assert (
deepdiff.DeepDiff(
artifact_with_valid_preview_after_migration,
artifact_with_valid_preview.to_dict(),
ignore_order=True,
exclude_paths=["root['updated']"],
)
== {}
)

artifact_with_invalid_preview_after_migration = db.read_artifact(
db_session, artifact_with_invalid_preview_key, artifact_with_invalid_preview_uid
)
assert (
deepdiff.DeepDiff(
artifact_with_invalid_preview_after_migration,
artifact_with_invalid_preview.to_dict(),
ignore_order=True,
exclude_paths=[
"root['updated']",
"root['header']",
"root['stats']",
"root['schema']",
"root['preview']",
],
)
== {}
)
assert (
len(artifact_with_invalid_preview_after_migration["header"])
== mlrun.artifacts.dataset.max_preview_columns
)
assert (
len(artifact_with_invalid_preview_after_migration["stats"])
== mlrun.artifacts.dataset.max_preview_columns - 1
)
assert (
len(artifact_with_invalid_preview_after_migration["preview"][0])
== mlrun.artifacts.dataset.max_preview_columns
)
assert (
len(artifact_with_invalid_preview_after_migration["schema"]["fields"])
== mlrun.artifacts.dataset.max_preview_columns + 1
)


def _generate_artifact(name, kind=None):
artifact = {
"metadata": {"name": name},
