[Datastore] Support writing partitioned parquet data (#898)
Gal Topper committed May 20, 2021
1 parent e0c1129 commit 8d1ea44
Showing 5 changed files with 228 additions and 10 deletions.
25 changes: 23 additions & 2 deletions mlrun/datastore/base.py
@@ -130,7 +130,12 @@ def as_df(self, url, subpath, columns=None, df_module=None, format="", **kwargs)
elif url.endswith(".parquet") or url.endswith(".pq") or format == "parquet":
if columns:
kwargs["columns"] = columns
reader = df_module.read_parquet

def reader(*args, **kwargs):
df_from_pq = df_module.read_parquet(*args, **kwargs)
_drop_reserved_columns(df_from_pq)
return df_from_pq

elif url.endswith(".json") or format == "json":
reader = df_module.read_json

@@ -139,7 +144,15 @@ def as_df(self, url, subpath, columns=None, df_module=None, format="", **kwargs)

fs = self.get_filesystem()
if fs:
return reader(fs.open(url), **kwargs)
if fs.isdir(url):
storage_options = self.get_storage_options()
if storage_options:
kwargs["storage_options"] = storage_options
return reader(url, **kwargs)
else:
# If not a directory, use fs.open() to avoid a regression with pandas < 1.2,
# which does not support the storage_options parameter.
return reader(fs.open(url), **kwargs)

tmp = mktemp()
self.download(self._join(subpath), tmp)
@@ -156,6 +169,14 @@ def to_dict(self):
}


def _drop_reserved_columns(df):
cols_to_drop = []
for col in df.columns:
if col.startswith("igzpart_"):
cols_to_drop.append(col)
df.drop(labels=cols_to_drop, axis=1, inplace=True, errors="ignore")


class DataItem:
"""Data input/output class abstracting access to various local/remote data sources"""

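A minimal standalone sketch of the new read path in base.py above, using plain pandas rather than the datastore classes (the read_parquet_clean helper is illustrative, not part of the commit): as_df now reads a partitioned parquet directory in place, passing storage options through to pandas, and the wrapped reader drops the reserved igzpart_ partition columns before returning the DataFrame.

import pandas as pd

def _drop_reserved_columns(df):
    # Reserved Iguazio partition columns ("igzpart_*") show up as regular
    # columns when a partitioned dataset is read back; drop them in place.
    cols_to_drop = [col for col in df.columns if col.startswith("igzpart_")]
    df.drop(labels=cols_to_drop, axis=1, inplace=True, errors="ignore")

def read_parquet_clean(path, **kwargs):
    # `path` may point at a single .parquet file or at a partitioned directory;
    # pandas (with a parquet engine such as pyarrow) handles both.
    df = pd.read_parquet(path, **kwargs)
    _drop_reserved_columns(df)
    return df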
3 changes: 3 additions & 0 deletions mlrun/datastore/sources.py
@@ -153,6 +153,9 @@ def get_spark_options(self):
"format": "parquet",
}

def to_dataframe(self):
return mlrun.store_manager.object(url=self.path).as_df(format="parquet")


class CustomSource(BaseSourceDriver):
kind = "custom"
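A hypothetical usage sketch for the new ParquetSource.to_dataframe above (the source name and path are placeholders, not taken from the commit): the method resolves the path through the store manager and reads it back as parquet, so a partitioned directory is handled the same way as a single file.

from mlrun.datastore.sources import ParquetSource

# Placeholder source; any URL the datastore layer understands should work here.
source = ParquetSource(name="events", path="v3io:///projects/demo/events/")
df = source.to_dataframe()  # delegates to DataItem.as_df(format="parquet")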
117 changes: 113 additions & 4 deletions mlrun/datastore/targets.py
@@ -24,6 +24,7 @@
from mlrun.utils import now_date
from mlrun.utils.v3io_clients import get_frames_client

from .. import errors
from ..platforms.iguazio import parse_v3io_path, split_path
from .utils import store_path_to_spark

@@ -258,12 +259,20 @@ def __init__(
attributes: typing.Dict[str, str] = None,
after_state=None,
columns=None,
partitioned: bool = False,
key_bucketing_number: typing.Optional[int] = None,
partition_cols: typing.Optional[typing.List[str]] = None,
time_partitioning_granularity: typing.Optional[str] = None,
):
self.name = name
self.path = str(path) if path is not None else None
self.after_state = after_state
self.attributes = attributes or {}
self.columns = columns or []
self.partitioned = partitioned
self.key_bucketing_number = key_bucketing_number
self.partition_cols = partition_cols
self.time_partitioning_granularity = time_partitioning_granularity

self._target = None
self._resource = None
@@ -324,8 +333,31 @@ def from_spec(cls, spec: DataTargetBase, resource=None):
driver.name = spec.name
driver.path = spec.path
driver.attributes = spec.attributes

if hasattr(spec, "columns"):
driver.columns = spec.columns

driver.partitioned = spec.partitioned

driver.key_bucketing_number = spec.key_bucketing_number
driver.partition_cols = spec.partition_cols

driver.time_partitioning_granularity = spec.time_partitioning_granularity
if spec.kind == "parquet":
driver.suffix = (
".parquet"
if not spec.partitioned
and all(
value is None
for value in [
spec.key_bucketing_number,
spec.partition_cols,
spec.time_partitioning_granularity,
]
)
else ""
)

driver._resource = resource
return driver

@@ -338,13 +370,12 @@ def _target_path(self):
"""return the actual/computed target path"""
return self.path or _get_target_path(self, self._resource)

def update_resource_status(self, status="", producer=None, is_dir=None, size=None):
def update_resource_status(self, status="", producer=None, size=None):
"""update the data target status"""
self._target = self._target or DataTarget(
self.kind, self.name, self._target_path
)
target = self._target
target.is_dir = is_dir
target.status = status or target.status or "created"
target.updated = now_date().isoformat()
target.size = size
@@ -371,11 +402,59 @@ def get_spark_options(self, key_column=None, timestamp_key=None):

class ParquetTarget(BaseStoreTarget):
kind = TargetTypes.parquet
suffix = ".parquet"
is_offline = True
support_spark = True
support_storey = True

def __init__(
self,
name: str = "",
path=None,
attributes: typing.Dict[str, str] = None,
after_state=None,
columns=None,
partitioned: bool = False,
key_bucketing_number: typing.Optional[int] = None,
partition_cols: typing.Optional[typing.List[str]] = None,
time_partitioning_granularity: typing.Optional[str] = None,
):
super().__init__(
name,
path,
attributes,
after_state,
columns,
partitioned,
key_bucketing_number,
partition_cols,
time_partitioning_granularity,
)

if (
time_partitioning_granularity is not None
and time_partitioning_granularity not in self._legal_time_units
):
raise errors.MLRunInvalidArgumentError(
f"time_partitioning_granularity parameter must be one of {','.join(self._legal_time_units)}, "
f"not {time_partitioning_granularity}."
)

self.suffix = (
".parquet"
if not partitioned
and all(
value is None
for value in [
key_bucketing_number,
partition_cols,
time_partitioning_granularity,
]
)
else ""
)

_legal_time_units = ["year", "month", "day", "hour", "minute", "second"]

@staticmethod
def _write_dataframe(df, fs, target_path, **kwargs):
with fs.open(target_path, "wb") as fp:
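A hedged configuration sketch for the ParquetTarget constructor above: the .parquet suffix (a single-file target) is kept only when no partitioning option is set; partitioned=True or any explicit partitioning option turns the target into a directory with an empty suffix, and an unsupported time granularity is rejected up front. Only the constructor shown in this diff is exercised; paths are omitted so the default target-path logic applies.

from mlrun.datastore.targets import ParquetTarget

single_file = ParquetTarget(name="pq")
assert single_file.suffix == ".parquet"

daily_dirs = ParquetTarget(
    name="pq", partitioned=True, time_partitioning_granularity="day"
)
assert daily_dirs.suffix == ""

# An unsupported granularity fails fast:
# ParquetTarget(name="bad", time_partitioning_granularity="week")
# -> mlrun.errors.MLRunInvalidArgumentError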
@@ -388,6 +467,29 @@ def add_writer_state(
features=features, timestamp_key=timestamp_key, key_columns=None
)

partition_cols = None
if self.key_bucketing_number is not None:
partition_cols = [("$key", self.key_bucketing_number)]
if self.partition_cols:
partition_cols = partition_cols or []
partition_cols.extend(self.partition_cols)
time_partitioning_granularity = self.time_partitioning_granularity
if self.partitioned and all(
value is None
for value in [
time_partitioning_granularity,
self.key_bucketing_number,
self.partition_cols,
]
):
time_partitioning_granularity = "hour"
if time_partitioning_granularity is not None:
partition_cols = partition_cols or []
for time_unit in self._legal_time_units:
partition_cols.append(f"${time_unit}")
if time_unit == time_partitioning_granularity:
break

graph.add_step(
name=self.name or "ParquetTarget",
after=after,
@@ -396,6 +498,7 @@
path=self._target_path,
columns=column_list,
index_cols=key_columns,
partition_cols=partition_cols,
storage_options=self._get_store().get_storage_options(),
**self.attributes,
)
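The partition-column expansion in add_writer_state above is the core of this change: key bucketing becomes a ("$key", n) entry, explicit partition_cols are appended, and a bare partitioned=True defaults to hourly time partitioning, expanded into every time unit down to the requested granularity. A standalone sketch that mirrors that logic (expand_partition_cols is an illustrative helper, not an mlrun API):

_legal_time_units = ["year", "month", "day", "hour", "minute", "second"]

def expand_partition_cols(
    partitioned,
    key_bucketing_number=None,
    partition_cols=None,
    time_partitioning_granularity=None,
):
    result = None
    if key_bucketing_number is not None:
        result = [("$key", key_bucketing_number)]
    if partition_cols:
        result = result or []
        result.extend(partition_cols)
    if partitioned and all(
        value is None
        for value in [time_partitioning_granularity, key_bucketing_number, partition_cols]
    ):
        # partitioned=True with no explicit scheme defaults to hourly partitioning
        time_partitioning_granularity = "hour"
    if time_partitioning_granularity is not None:
        result = result or []
        for unit in _legal_time_units:
            result.append(f"${unit}")
            if unit == time_partitioning_granularity:
                break
    return result

assert expand_partition_cols(True) == ["$year", "$month", "$day", "$hour"]
assert expand_partition_cols(True, partition_cols=["department"]) == ["department"]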
@@ -406,6 +509,12 @@ def get_spark_options(self, key_column=None, timestamp_key=None):
"format": "parquet",
}

def as_df(self, columns=None, df_module=None, entities=None):
"""return the target data as dataframe"""
return mlrun.get_dataitem(self._target_path).as_df(
columns=columns, df_module=df_module, format="parquet"
)


class CSVTarget(BaseStoreTarget):
kind = TargetTypes.csv
@@ -659,7 +768,7 @@ def __init__(self):
def set_df(self, df):
self._df = df

def update_resource_status(self, status="", producer=None, is_dir=None):
def update_resource_status(self, status="", producer=None):
pass

def add_writer_state(
24 changes: 20 additions & 4 deletions mlrun/model.py
@@ -19,7 +19,7 @@
from collections import OrderedDict
from copy import deepcopy
from os import environ
from typing import Dict, Tuple, Union
from typing import Dict, List, Optional, Tuple, Union

import mlrun

@@ -979,7 +979,17 @@ def set_secrets(self, secrets):
class DataTargetBase(ModelObj):
"""data target spec, specify a destination for the feature set data"""

_dict_fields = ["name", "kind", "path", "after_state", "attributes"]
_dict_fields = [
"name",
"kind",
"path",
"after_state",
"attributes",
"partitioned",
"key_bucketing_number",
"partition_cols",
"time_partitioning_granularity",
]

def __init__(
self,
@@ -988,12 +998,20 @@ def __init__(
path=None,
attributes: Dict[str, str] = None,
after_state=None,
partitioned: bool = False,
key_bucketing_number: Optional[int] = None,
partition_cols: Optional[List[str]] = None,
time_partitioning_granularity: Optional[str] = None,
):
self.name = name
self.kind: str = kind
self.path = path
self.after_state = after_state
self.attributes = attributes or {}
self.partitioned = partitioned
self.key_bucketing_number = key_bucketing_number
self.partition_cols = partition_cols
self.time_partitioning_granularity = time_partitioning_granularity


class FeatureSetProducer(ModelObj):
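The partitioning fields added to DataTargetBase above are also listed in _dict_fields, so they survive a round trip through the target spec's dict form. A hedged sketch, assuming ModelObj.to_dict serializes the attributes named in _dict_fields (as it does elsewhere in mlrun.model) and that kind and name remain the leading constructor arguments; the path is a placeholder:

from mlrun.model import DataTargetBase

spec = DataTargetBase(
    kind="parquet",
    name="pq",
    path="v3io:///projects/demo/pq",  # placeholder path
    partitioned=True,
    partition_cols=["department"],
)
assert spec.to_dict()["partition_cols"] == ["department"]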
@@ -1017,7 +1035,6 @@ class DataTarget(DataTargetBase):
"start_time",
"online",
"status",
"is_dir",
"updated",
"size",
]
@@ -1032,7 +1049,6 @@ def __init__(
self.online = online
self.max_age = None
self.start_time = None
self.is_dir = None
self._producer = None
self.producer = {}

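Putting the pieces together, a hedged end-to-end sketch of what the commit enables. The feature set name, entity key, and DataFrame below are placeholders, and the mlrun.feature_store calls (FeatureSet, Entity, ingest) are assumed to match that release's API rather than being shown in this diff:

import pandas as pd
import mlrun.feature_store as fstore
from mlrun.datastore.targets import ParquetTarget

df = pd.DataFrame(
    {
        "key": ["a", "b"],
        "department": ["sales", "ops"],
        "time": pd.to_datetime(["2021-05-20 10:00", "2021-05-20 11:00"]),
    }
)

fset = fstore.FeatureSet("events", entities=[fstore.Entity("key")], timestamp_key="time")
target = ParquetTarget(partitioned=True, partition_cols=["department"])

# With partitioning enabled, the target writes a directory tree
# (e.g. department=sales/...) instead of a single .parquet file.
fstore.ingest(fset, df, targets=[target])

# Reading the data back through as_df() drops the reserved igzpart_ columns,
# as shown in the base.py changes above.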
