Skip to content

Commit

Permalink
RFC: add partition tags to partitioned config (#7111)
Browse files Browse the repository at this point in the history
* add partition tags to partitioned config

* mypy

* rip out partition_key-based config/tag functions from partitioned config
  • Loading branch information
prha committed Mar 24, 2022
1 parent e8f1745 commit 0fc9ba4
Show file tree
Hide file tree
Showing 4 changed files with 63 additions and 20 deletions.
13 changes: 10 additions & 3 deletions python_modules/dagster/dagster/core/definitions/job_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@
parse_op_selection,
)
from dagster.core.storage.fs_asset_io_manager import fs_asset_io_manager
from dagster.core.storage.tags import PARTITION_NAME_TAG
from dagster.core.utils import str_format_set

from .executor_definition import ExecutorDefinition
Expand Down Expand Up @@ -176,6 +175,7 @@ def execute_in_process(
version_strategy=self.version_strategy,
).get_job_def_for_op_selection(op_selection)

tags = None
if partition_key:
if not base_mode.partitioned_config:
check.failed(
Expand All @@ -185,7 +185,10 @@ def execute_in_process(
not run_config,
"Cannot provide both run_config and partition_key arguments to `execute_in_process`",
)
run_config = base_mode.partitioned_config.get_run_config(partition_key)
partition_set = self.get_partition_set_def()
partition = partition_set.get_partition(partition_key)
run_config = partition_set.run_config_for_partition(partition)
tags = partition_set.tags_for_partition(partition)

return core_execute_in_process(
node=self._graph_def,
Expand All @@ -194,7 +197,7 @@ def execute_in_process(
instance=instance,
output_capturing_enabled=True,
raise_on_error=raise_on_error,
run_tags={PARTITION_NAME_TAG: partition_key} if partition_key else None,
run_tags=tags,
)

@property
Expand Down Expand Up @@ -243,11 +246,15 @@ def get_partition_set_def(self) -> Optional["PartitionSetDefinition"]:

if not self._cached_partition_set:

tags_fn = mode.partitioned_config.tags_for_partition_fn
if not tags_fn:
tags_fn = lambda _: {}
self._cached_partition_set = PartitionSetDefinition(
job_name=self.name,
name=f"{self.name}_partition_set",
partitions_def=mode.partitioned_config.partitions_def,
run_config_fn_for_partition=mode.partitioned_config.run_config_for_partition_fn,
tags_fn_for_partition=tags_fn,
mode=mode.name,
)

Expand Down
33 changes: 20 additions & 13 deletions python_modules/dagster/dagster/core/definitions/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dateutil.relativedelta import relativedelta

from dagster import check
from dagster.core.storage.tags import PARTITION_NAME_TAG
from dagster.serdes import whitelist_for_serdes

from ...seven.compat.pendulum import PendulumDateTime, to_timezone
Expand Down Expand Up @@ -514,7 +515,7 @@ def get_partition(self, name: str) -> Partition[T]:
if partition.name == name:
return partition

check.failed("Partition name {} not found!".format(name))
raise DagsterUnknownPartitionError(f"Could not find a partition with key `{name}`")

def get_partition_names(self, current_time: Optional[datetime] = None) -> List[str]:
return [part.name for part in self.get_partitions(current_time)]
Expand Down Expand Up @@ -744,12 +745,16 @@ def __init__(
partitions_def: PartitionsDefinition[T], # pylint: disable=unsubscriptable-object
run_config_for_partition_fn: Callable[[Partition[T]], Dict[str, Any]],
decorated_fn: Optional[Callable[..., Dict[str, Any]]] = None,
tags_for_partition_fn: Optional[Callable[[Partition[T]], Dict[str, str]]] = None,
):
self._partitions = check.inst_param(partitions_def, "partitions_def", PartitionsDefinition)
self._run_config_for_partition_fn = check.callable_param(
run_config_for_partition_fn, "run_config_for_partition_fn"
)
self._decorated_fn = decorated_fn
self._tags_for_partition_fn = check.opt_callable_param(
tags_for_partition_fn, "tags_for_partition_fn"
)

@property
def partitions_def(self) -> PartitionsDefinition[T]: # pylint: disable=unsubscriptable-object
Expand All @@ -759,21 +764,13 @@ def partitions_def(self) -> PartitionsDefinition[T]: # pylint: disable=unsubscr
def run_config_for_partition_fn(self) -> Callable[[Partition[T]], Dict[str, Any]]:
return self._run_config_for_partition_fn

@property
def tags_for_partition_fn(self) -> Optional[Callable[[Partition[T]], Dict[str, str]]]:
    """Return the tags callable held in ``self._tags_for_partition_fn``.

    The annotated type is optional, so the result may be ``None``; callers
    should check before invoking it.
    """
    return self._tags_for_partition_fn

def get_partition_keys(self, current_time: Optional[datetime] = None) -> List[str]:
    """Return the ``name`` of every partition in ``self.partitions_def``.

    Args:
        current_time: Passed through unchanged to
            ``partitions_def.get_partitions``; defaults to ``None``.

    Returns:
        The partition names, in the order ``get_partitions`` yields them.
    """
    partitions = self.partitions_def.get_partitions(current_time)
    return [each.name for each in partitions]

def get_run_config(self, partition_key: str) -> Dict[str, Any]:
    """Build the run config for the partition named ``partition_key``.

    Args:
        partition_key: Name to look up among the partitions returned by
            ``self.partitions_def.get_partitions()``.

    Returns:
        The result of ``self.run_config_for_partition_fn`` applied to the
        first partition whose ``name`` equals ``partition_key``.

    Raises:
        DagsterUnknownPartitionError: If no partition has that name.
    """
    # Only the first match is ever used, so stop scanning as soon as it is
    # found instead of collecting every match into a list.
    match = next(
        (
            partition
            for partition in self.partitions_def.get_partitions()
            if partition.name == partition_key
        ),
        None,
    )
    if match is None:
        raise DagsterUnknownPartitionError(
            f"Could not find a partition with key `{partition_key}`"
        )
    return self.run_config_for_partition_fn(match)

def __call__(self, *args, **kwargs):
if self._decorated_fn is None:
raise DagsterInvalidInvocationError(
Expand All @@ -786,6 +783,7 @@ def __call__(self, *args, **kwargs):

def static_partitioned_config(
partition_keys: List[str],
tags_for_partition_fn: Optional[Callable[[str], Dict[str, str]]] = None,
) -> Callable[[Callable[[str], Dict[str, Any]]], PartitionedConfig]:
"""Creates a static partitioned config for a job.
Expand Down Expand Up @@ -815,17 +813,22 @@ def inner(fn: Callable[[str], Dict[str, Any]]) -> PartitionedConfig:
def _run_config_wrapper(partition: Partition[T]) -> Dict[str, Any]:
return fn(partition.name)

def _tag_wrapper(partition: Partition[T]) -> Dict[str, str]:
return tags_for_partition_fn(partition.name) if tags_for_partition_fn else {}

return PartitionedConfig(
partitions_def=StaticPartitionsDefinition(partition_keys),
run_config_for_partition_fn=_run_config_wrapper,
decorated_fn=fn,
tags_for_partition_fn=_tag_wrapper,
)

return inner


def dynamic_partitioned_config(
partition_fn: Callable[[Optional[datetime]], List[str]],
tags_for_partition_fn: Optional[Callable[[str], Dict[str, str]]] = None,
) -> Callable[[Callable[[str], Dict[str, Any]]], PartitionedConfig]:
"""Creates a dynamic partitioned config for a job.
Expand All @@ -850,10 +853,14 @@ def inner(fn: Callable[[str], Dict[str, Any]]) -> PartitionedConfig:
def _run_config_wrapper(partition: Partition[T]) -> Dict[str, Any]:
return fn(partition.name)

def _tag_wrapper(partition: Partition[T]) -> Dict[str, str]:
return tags_for_partition_fn(partition.name) if tags_for_partition_fn else {}

return PartitionedConfig(
partitions_def=DynamicPartitionsDefinition(partition_fn),
run_config_for_partition_fn=_run_config_wrapper,
decorated_fn=fn,
tags_for_partition_fn=_tag_wrapper,
)

return inner
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union
from typing import Any, Callable, Dict, List, NamedTuple, Optional, Union, cast

import pendulum

Expand Down Expand Up @@ -164,11 +164,23 @@ def __new__(
)


def wrap_time_window_tags_fn(
    tags_fn: Optional[Callable[[datetime, datetime], Dict[str, str]]]
) -> Callable[[Partition], Dict[str, str]]:
    """Adapt a two-datetime tags function so it can be called with a partition.

    Args:
        tags_fn: Callable taking two datetimes and returning a tag dict, or
            ``None``.

    Returns:
        A callable taking a single partition. When ``tags_fn`` is falsy it
        returns a new empty dict; otherwise it returns
        ``tags_fn(partition.value[0], partition.value[1])``.
    """

    def _tag_wrapper(partition: Partition) -> Dict[str, str]:
        if tags_fn:
            # presumably the start and end of the partition's time window --
            # confirm against the partitions definition that builds `value`
            first = cast(datetime, partition.value[0])
            second = cast(datetime, partition.value[1])
            return tags_fn(first, second)
        return {}

    return _tag_wrapper


def daily_partitioned_config(
start_date: Union[datetime, str],
timezone: Optional[str] = None,
fmt: Optional[str] = None,
end_offset: int = 0,
tags_for_partition_fn: Optional[Callable[[datetime, datetime], Dict[str, str]]] = None,
) -> Callable[[Callable[[datetime, datetime], Dict[str, Any]]], PartitionedConfig]:
"""Defines run config over a set of daily partitions.
Expand Down Expand Up @@ -205,6 +217,7 @@ def inner(fn: Callable[[datetime, datetime], Dict[str, Any]]) -> PartitionedConf
start_date=start_date, timezone=timezone, fmt=fmt, end_offset=end_offset
),
decorated_fn=fn,
tags_for_partition_fn=wrap_time_window_tags_fn(tags_for_partition_fn),
)

return inner
Expand Down Expand Up @@ -252,6 +265,7 @@ def hourly_partitioned_config(
timezone: Optional[str] = None,
fmt: Optional[str] = None,
end_offset: int = 0,
tags_for_partition_fn: Optional[Callable[[datetime, datetime], Dict[str, str]]] = None,
) -> Callable[[Callable[[datetime, datetime], Dict[str, Any]]], PartitionedConfig]:
"""Defines run config over a set of hourly partitions.
Expand Down Expand Up @@ -288,6 +302,7 @@ def inner(fn: Callable[[datetime, datetime], Dict[str, Any]]) -> PartitionedConf
start_date=start_date, timezone=timezone, fmt=fmt, end_offset=end_offset
),
decorated_fn=fn,
tags_for_partition_fn=wrap_time_window_tags_fn(tags_for_partition_fn),
)

return inner
Expand Down Expand Up @@ -335,6 +350,7 @@ def monthly_partitioned_config(
timezone: Optional[str] = None,
fmt: Optional[str] = None,
end_offset: int = 0,
tags_for_partition_fn: Optional[Callable[[datetime, datetime], Dict[str, str]]] = None,
) -> Callable[[Callable[[datetime, datetime], Dict[str, Any]]], PartitionedConfig]:
"""Defines run config over a set of monthly partitions.
Expand Down Expand Up @@ -374,6 +390,7 @@ def inner(fn: Callable[[datetime, datetime], Dict[str, Any]]) -> PartitionedConf
end_offset=end_offset,
),
decorated_fn=fn,
tags_for_partition_fn=wrap_time_window_tags_fn(tags_for_partition_fn),
)

return inner
Expand Down Expand Up @@ -421,6 +438,7 @@ def weekly_partitioned_config(
timezone: Optional[str] = None,
fmt: Optional[str] = None,
end_offset: int = 0,
tags_for_partition_fn: Optional[Callable[[datetime, datetime], Dict[str, str]]] = None,
) -> Callable[[Callable[[datetime, datetime], Dict[str, Any]]], PartitionedConfig]:
"""Defines run config over a set of weekly partitions.
Expand Down Expand Up @@ -460,6 +478,7 @@ def inner(fn: Callable[[datetime, datetime], Dict[str, Any]]) -> PartitionedConf
end_offset=end_offset,
),
decorated_fn=fn,
tags_for_partition_fn=wrap_time_window_tags_fn(tags_for_partition_fn),
)

return inner
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ def my_op(context):


def test_static_partitioned_job():
@static_partitioned_config(["blah"])
@static_partitioned_config(
["blah"], tags_for_partition_fn=lambda partition_key: {"foo": partition_key}
)
def my_static_partitioned_config(_partition_key: str):
return RUN_CONFIG

Expand All @@ -35,6 +37,7 @@ def my_job():

result = my_job.execute_in_process(partition_key="blah")
assert result.success
assert result.dagster_run.tags["foo"] == "blah"

with pytest.raises(
DagsterUnknownPartitionError, match="Could not find a partition with key `doesnotexist`"
Expand All @@ -43,7 +46,10 @@ def my_job():


def test_time_based_partitioned_job():
@daily_partitioned_config(start_date="2021-05-05")
@daily_partitioned_config(
start_date="2021-05-05",
tags_for_partition_fn=lambda start, end: {"foo": start.strftime("%Y-%m-%d")},
)
def my_daily_partitioned_config(_start, _end):
return RUN_CONFIG

Expand All @@ -63,6 +69,7 @@ def my_job():

result = my_job.execute_in_process(partition_key=partition_key)
assert result.success
assert result.dagster_run.tags["foo"] == "2021-05-05"

with pytest.raises(
DagsterUnknownPartitionError, match="Could not find a partition with key `doesnotexist`"
Expand All @@ -74,7 +81,9 @@ def test_dynamic_partitioned_config():
def partition_fn(_current_time=None):
return ["blah"]

@dynamic_partitioned_config(partition_fn)
@dynamic_partitioned_config(
partition_fn, tags_for_partition_fn=lambda partition_key: {"foo": partition_key}
)
def my_dynamic_partitioned_config(_partition_key):
return RUN_CONFIG

Expand All @@ -89,6 +98,7 @@ def my_job():

result = my_job.execute_in_process(partition_key="blah")
assert result.success
assert result.dagster_run.tags["foo"] == "blah"

with pytest.raises(
DagsterUnknownPartitionError, match="Could not find a partition with key `doesnotexist`"
Expand Down

0 comments on commit 0fc9ba4

Please sign in to comment.