Skip to content

Commit

Permalink
combine validate_tags and check_tags to JSONify nested tags (#7720)
Browse files Browse the repository at this point in the history
Summary:
This provides a consistent Dict => str mapping everywhere that users input tags.

Test Plan:
BK
  • Loading branch information
gibsondan committed May 4, 2022
1 parent 504aa08 commit 948499e
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
create_offset_partition_selector,
)

from ...storage.tags import check_tags
from ..graph_definition import GraphDefinition
from ..mode import DEFAULT_MODE_NAME
from ..pipeline_definition import PipelineDefinition
Expand All @@ -39,6 +38,7 @@
ScheduleEvaluationContext,
is_context_provided,
)
from ..utils import validate_tags

if TYPE_CHECKING:
from dagster import Partition
Expand Down Expand Up @@ -114,14 +114,16 @@ def inner(fn: RawScheduleEvaluationFunction) -> ScheduleDefinition:

schedule_name = name or fn.__name__

validated_tags = None

# perform upfront validation of schedule tags
if tags_fn and tags:
raise DagsterInvalidDefinitionError(
"Attempted to provide both tags_fn and tags as arguments"
" to ScheduleDefinition. Must provide only one of the two."
)
elif tags:
check_tags(tags, "tags")
validated_tags = validate_tags(tags, allow_reserved_tags=False)

def _wrapped_fn(context: ScheduleEvaluationContext) -> RunRequestIterator:
if should_execute:
Expand All @@ -148,7 +150,11 @@ def _wrapped_fn(context: ScheduleEvaluationContext) -> RunRequestIterator:
# this is the run-config based decorated function, wrap the evaluated run config
# and tags in a RunRequest
evaluated_run_config = copy.deepcopy(result)
evaluated_tags = tags or (tags_fn and tags_fn(context)) or None
evaluated_tags = (
validated_tags
or (tags_fn and validate_tags(tags_fn(context), allow_reserved_tags=False))
or None
)
yield RunRequest(
run_key=None,
run_config=evaluated_run_config,
Expand Down
8 changes: 3 additions & 5 deletions python_modules/dagster/dagster/core/definitions/partition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from dagster import check
from dagster.serdes import whitelist_for_serdes

from ...core.definitions.utils import validate_tags
from ...seven.compat.pendulum import PendulumDateTime, to_timezone
from ...utils import frozenlist, merge_dicts
from ...utils.schedules import schedule_execution_time_iterator
Expand All @@ -24,7 +25,6 @@
user_code_error_boundary,
)
from ..storage.pipeline_run import PipelineRun
from ..storage.tags import check_tags
from .mode import DEFAULT_MODE_NAME
from .run_request import RunRequest, SkipReason
from .schedule_definition import (
Expand Down Expand Up @@ -495,11 +495,9 @@ def run_config_for_partition(self, partition: Partition[T]) -> Dict[str, Any]:
return copy.deepcopy(self._user_defined_run_config_fn_for_partition(partition))

def tags_for_partition(self, partition: Partition[T]) -> Dict[str, str]:
user_tags = copy.deepcopy(
validate_tags(self._user_defined_tags_fn_for_partition(partition))
user_tags = validate_tags(
self._user_defined_tags_fn_for_partition(partition), allow_reserved_tags=False
)
check_tags(user_tags, "user_tags")

tags = merge_dicts(user_tags, PipelineRun.tags_for_partition_set(self, partition))

return tags
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
from ..instance import DagsterInstance
from ..instance.ref import InstanceRef
from ..storage.pipeline_run import PipelineRun
from ..storage.tags import check_tags
from .graph_definition import GraphDefinition
from .mode import DEFAULT_MODE_NAME
from .pipeline_definition import PipelineDefinition
from .run_request import RunRequest, SkipReason
from .target import DirectTarget, RepoRelativeTarget
from .utils import check_valid_name
from .utils import check_valid_name, validate_tags

T = TypeVar("T")

Expand Down Expand Up @@ -302,7 +301,7 @@ def _default_run_config_fn(context: ScheduleEvaluationContext) -> RunConfig:
" to ScheduleDefinition. Must provide only one of the two."
)
elif tags:
check_tags(tags, "tags")
tags = validate_tags(tags, allow_reserved_tags=False)
tags_fn = lambda _context: tags
else:
tags_fn = check.opt_callable_param(
Expand Down Expand Up @@ -341,7 +340,7 @@ def _execution_fn(context):
ScheduleExecutionError,
lambda: f"Error occurred during the execution of tags_fn for schedule {name}",
):
evaluated_tags = tags_fn(context)
evaluated_tags = validate_tags(tags_fn(context), allow_reserved_tags=False)

yield RunRequest(
run_key=None,
Expand Down
6 changes: 5 additions & 1 deletion python_modules/dagster/dagster/core/definitions/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from dagster import check, seven
from dagster.core.errors import DagsterInvalidDefinitionError, DagsterInvariantViolationError
from dagster.core.storage.tags import check_reserved_tags
from dagster.utils import frozentags
from dagster.utils.yaml_utils import merge_yaml_strings, merge_yamls

Expand Down Expand Up @@ -79,7 +80,7 @@ def struct_to_string(name, **kwargs):
return "{name}({props_str})".format(name=name, props_str=props_str)


def validate_tags(tags: Optional[Dict[str, Any]]) -> Dict[str, Any]:
def validate_tags(tags: Optional[Dict[str, Any]], allow_reserved_tags=True) -> Dict[str, str]:
valid_tags = {}
for key, value in check.opt_dict_param(tags, "tags", key_type=str).items():
if not isinstance(value, str):
Expand Down Expand Up @@ -107,6 +108,9 @@ def validate_tags(tags: Optional[Dict[str, Any]]) -> Dict[str, Any]:
else:
valid_tags[key] = value

if not allow_reserved_tags:
check_reserved_tags(valid_tags)

return frozentags(valid_tags)


Expand Down
8 changes: 4 additions & 4 deletions python_modules/dagster/dagster/core/storage/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,12 @@ def get_tag_type(tag):
return TagType.USER_PROVIDED


def check_tags(obj, name):
check.opt_dict_param(obj, name, key_type=str, value_type=str)
def check_reserved_tags(tags):
check.opt_dict_param(tags, "tags", key_type=str, value_type=str)

for tag in obj.keys():
for tag in tags.keys():
if not tag in USER_EDITABLE_SYSTEM_TAGS:
check.invariant(
not tag.startswith(SYSTEM_TAG_PREFIX),
desc="User attempted to set tag with reserved system prefix: {tag}".format(tag=tag),
desc="Attempted to set tag with reserved system prefix: {tag}".format(tag=tag),
)
6 changes: 3 additions & 3 deletions python_modules/dagster/dagster/daemon/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from dagster import check, seven
from dagster.core.definitions.run_request import InstigatorType
from dagster.core.definitions.sensor_definition import DefaultSensorStatus, SensorExecutionData
from dagster.core.definitions.utils import validate_tags
from dagster.core.errors import DagsterError
from dagster.core.host_representation import PipelineSelector
from dagster.core.instance import DagsterInstance
Expand All @@ -19,7 +20,7 @@
TickStatus,
)
from dagster.core.storage.pipeline_run import PipelineRun, PipelineRunStatus, RunsFilter, TagBucket
from dagster.core.storage.tags import RUN_KEY_TAG, check_tags
from dagster.core.storage.tags import RUN_KEY_TAG
from dagster.core.telemetry import SENSOR_RUN_CREATED, hash_name, log_action
from dagster.core.workspace import IWorkspace
from dagster.utils import merge_dicts
Expand Down Expand Up @@ -596,8 +597,7 @@ def _create_sensor_run(
)
execution_plan_snapshot = external_execution_plan.execution_plan_snapshot

pipeline_tags = external_pipeline.tags or {}
check_tags(pipeline_tags, "pipeline_tags")
pipeline_tags = validate_tags(external_pipeline.tags or {}, allow_reserved_tags=False)
tags = merge_dicts(
merge_dicts(pipeline_tags, run_request.tags),
PipelineRun.tags_for_sensor(external_sensor),
Expand Down
6 changes: 3 additions & 3 deletions python_modules/dagster/dagster/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from dagster import check
from dagster.core.definitions.schedule_definition import DefaultScheduleStatus
from dagster.core.definitions.utils import validate_tags
from dagster.core.errors import DagsterUserCodeUnreachableError
from dagster.core.host_representation import ExternalSchedule, PipelineSelector
from dagster.core.instance import DagsterInstance
Expand All @@ -21,7 +22,7 @@
)
from dagster.core.scheduler.scheduler import DEFAULT_MAX_CATCHUP_RUNS, DagsterSchedulerError
from dagster.core.storage.pipeline_run import PipelineRun, PipelineRunStatus, RunsFilter
from dagster.core.storage.tags import RUN_KEY_TAG, SCHEDULED_EXECUTION_TIME_TAG, check_tags
from dagster.core.storage.tags import RUN_KEY_TAG, SCHEDULED_EXECUTION_TIME_TAG
from dagster.core.telemetry import SCHEDULED_RUN_CREATED, hash_name, log_action
from dagster.core.workspace import IWorkspace
from dagster.seven.compat.pendulum import to_timezone
Expand Down Expand Up @@ -536,8 +537,7 @@ def _create_scheduler_run(
)
execution_plan_snapshot = external_execution_plan.execution_plan_snapshot

pipeline_tags = external_pipeline.tags or {}
check_tags(pipeline_tags, "pipeline_tags")
pipeline_tags = validate_tags(external_pipeline.tags, allow_reserved_tags=False) or {}
tags = merge_dicts(pipeline_tags, schedule_tags)

tags[SCHEDULED_EXECUTION_TIME_TAG] = to_timezone(schedule_time, "UTC").isoformat()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import inspect
import json
import re
from datetime import datetime, time

Expand All @@ -23,6 +24,7 @@
weekly_schedule,
)
from dagster.seven.compat.pendulum import create_pendulum_time, to_timezone
from dagster.utils import merge_dicts
from dagster.utils.partitions import (
DEFAULT_DATE_FORMAT,
DEFAULT_HOURLY_FORMAT_WITHOUT_TIMEZONE,
Expand Down Expand Up @@ -740,6 +742,22 @@ def bad_cron_string_three(context):
return {}


def test_schedule_with_nested_tags():
    """Check that a dict-valued tag passed to ``@schedule`` comes back JSON-encoded.

    The tags found on the first item of the first tick result must equal the
    input tags with every value run through ``json.dumps``, merged with the
    ``dagster/schedule_name`` tag.
    """
    raw_tags = {"foo": {"bar": "baz"}}

    # The decorated function must keep the name ``my_tag_schedule``: the
    # expected ``dagster/schedule_name`` tag below is that exact string.
    @schedule(cron_schedule="* * * * *", pipeline_name="foo_pipeline", tags=raw_tags)
    def my_tag_schedule():
        return {}

    tick_context = build_schedule_context(scheduled_execution_time=pendulum.now())
    first_result = my_tag_schedule.evaluate_tick(tick_context)[0][0]
    actual_tags = first_result.tags

    expected_tags = merge_dicts(
        {key: json.dumps(val) for key, val in raw_tags.items()},
        {"dagster/schedule_name": "my_tag_schedule"},
    )
    assert actual_tags == expected_tags


def test_scheduled_jobs():
from dagster import Field, String

Expand Down

0 comments on commit 948499e

Please sign in to comment.