Skip to content

Commit

Permalink
[Feature store] ParquetTarget should be in a single file when partitioned is false (#1009)
Browse files Browse the repository at this point in the history
  • Loading branch information
katyakats committed Jun 14, 2021
1 parent d6c0f7b commit 2721a67
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 14 deletions.
36 changes: 22 additions & 14 deletions mlrun/datastore/targets.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,7 @@ def __init__(
attributes: typing.Dict[str, str] = None,
after_step=None,
columns=None,
partitioned: bool = False,
partitioned: bool = None,
key_bucketing_number: typing.Optional[int] = None,
partition_cols: typing.Optional[typing.List[str]] = None,
time_partitioning_granularity: typing.Optional[str] = None,
Expand All @@ -496,6 +496,19 @@ def __init__(
)
after_step = after_step or after_state

if partitioned is None:
if all(
value is None
for value in [
key_bucketing_number,
partition_cols,
time_partitioning_granularity,
]
):
partitioned = False
else:
partitioned = True

super().__init__(
name,
path,
Expand All @@ -517,19 +530,7 @@ def __init__(
f"not {time_partitioning_granularity}."
)

self.suffix = (
".parquet"
if not partitioned
and all(
value is None
for value in [
key_bucketing_number,
partition_cols,
time_partitioning_granularity,
]
)
else ""
)
self.suffix = ".parquet" if not partitioned else ""

_legal_time_units = ["year", "month", "day", "hour", "minute", "second"]

Expand Down Expand Up @@ -584,6 +585,13 @@ def add_writer_step(
if time_unit == time_partitioning_granularity:
break

if (
not self.partitioned
and not self._target_path.endswith(".parquet")
and not self._target_path.endswith(".pq")
):
partition_cols = []

graph.add_step(
name=self.name or "ParquetTarget",
after=after,
Expand Down
14 changes: 14 additions & 0 deletions tests/system/feature_store/test_feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,20 @@ def test_serverless_ingest(self):
stats.remove("timestamp")
assert features == stats, "didnt infer stats for all features"

def test_non_partitioned_target_in_dir(self):
    """Check that a path-only ParquetTarget writes a single file.

    A CSV source is ingested into a ParquetTarget built from a path alone
    (no partitioning arguments). After ingestion the target path is listed
    as a directory and must hold exactly one entry, which must not itself
    be a directory. That entry is removed at the end of the test.
    """
    csv_path = os.path.relpath(str(self.assets_path / "testdata.csv"))
    source = CSVSource("mycsv", path=csv_path)

    path = str(self.results_path / _generate_random_name())
    target = ParquetTarget(path=path)

    fset = fs.FeatureSet(name="test", entities=[Entity("patient_id")])
    fs.ingest(fset, source, targets=[target])

    written = os.listdir(path)
    # Check the count first so indexing below is only reached when safe.
    assert len(written) == 1
    only_entry = f"{path}/{written[0]}"
    assert not os.path.isdir(only_entry)
    os.remove(only_entry)

def test_ingest_with_timestamp(self):
key = "patient_id"
measurements = fs.FeatureSet(
Expand Down

0 comments on commit 2721a67

Please sign in to comment.