Skip to content

Commit

Permalink
Add run tags for repository/location names (#6893)
Browse files Browse the repository at this point in the history
* repo run tags

* change instance tag test

* consolidate into single repository tag

* rename tag

* move storage helper function onto dagster run class to share logic with internal
  • Loading branch information
prha committed May 4, 2022
1 parent 84cc1a4 commit e90d8d9
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ def test_session_header_decode_failure(

assert returned_run.run_id == run_id
assert returned_run.status == PipelineRunStatus.QUEUED
tags = instance.get_run_tags()
assert len(tags) == 0
mock_warnings.warn.assert_called_once()
assert mock_warnings.warn.call_args.args[0].startswith("Couldn't decode JWT header")

Expand Down Expand Up @@ -65,8 +63,7 @@ def test_session_header_decode_success(

assert returned_run.run_id == run_id
assert returned_run.status == PipelineRunStatus.QUEUED
tags = instance.get_run_tags()
assert len(tags) == 1
(tag_name, set_of_tag_values) = tags[0]
assert tag_name == "user"
assert set_of_tag_values == {expected_email}

fetched_run = instance.get_run_by_id(run_id)
assert len(fetched_run.tags) == 1
assert fetched_run.tags["user"] == expected_email
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,9 @@ def get_selector_id(self) -> str:
RepositorySelector(self.repository_location_origin.location_name, self.repository_name)
)

def get_label(self) -> str:
    """Return a human-readable label of the form ``<repository name>@<location name>``."""
    repo_name = self.repository_name
    location_name = self.repository_location_origin.location_name
    return f"{repo_name}@{location_name}"

def get_pipeline_origin(self, pipeline_name: str) -> "ExternalPipelineOrigin":
    """Return an ExternalPipelineOrigin built from this origin and ``pipeline_name``."""
    pipeline_origin = ExternalPipelineOrigin(self, pipeline_name)
    return pipeline_origin

Expand Down
15 changes: 15 additions & 0 deletions python_modules/dagster/dagster/core/storage/pipeline_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
BACKFILL_ID_TAG,
PARTITION_NAME_TAG,
PARTITION_SET_TAG,
REPOSITORY_LABEL_TAG,
RESUME_RETRY_TAG,
SCHEDULE_NAME_TAG,
SENSOR_NAME_TAG,
Expand Down Expand Up @@ -398,6 +399,20 @@ def get_root_run_id(self):
def get_parent_run_id(self):
    """Return the value stored under PARENT_RUN_ID_TAG in this run's tags, or None if absent."""
    run_tags = self.tags
    return run_tags.get(PARENT_RUN_ID_TAG)

def tags_for_storage(self):
    """Return the tags to persist for this run.

    When the run has an external pipeline origin, the result includes a
    REPOSITORY_LABEL_TAG entry holding the repository origin's label, to allow
    per-repository filtering of runs from dagit. The run's own tags are layered
    on top, so a run tag with the same key takes precedence.

    Returns:
        A new dict; the run's own ``tags`` mapping is not modified.
    """
    origin = self.external_pipeline_origin
    if origin:
        storage_tags = {
            REPOSITORY_LABEL_TAG: origin.external_repository_origin.get_label()
        }
    else:
        storage_tags = {}

    if self.tags:
        storage_tags.update(self.tags)

    return storage_tags

@property
def is_finished(self):
return (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -122,12 +122,13 @@ def add_run(self, pipeline_run: PipelineRun) -> PipelineRun:
except db.exc.IntegrityError as exc:
raise DagsterRunAlreadyExists from exc

if pipeline_run.tags and len(pipeline_run.tags) > 0:
tags_to_insert = pipeline_run.tags_for_storage()
if tags_to_insert:
conn.execute(
RunTagsTable.insert(), # pylint: disable=no-value-for-parameter
[
dict(run_id=pipeline_run.run_id, key=k, value=v)
for k, v in pipeline_run.tags.items()
for k, v in tags_to_insert.items()
],
)

Expand Down
2 changes: 2 additions & 0 deletions python_modules/dagster/dagster/core/storage/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
# Prefixes for tag keys defined by the system.
SYSTEM_TAG_PREFIX = "dagster/"
HIDDEN_TAG_PREFIX = ".dagster/"

# Tag key holding a run's repository label.
REPOSITORY_LABEL_TAG = f"{HIDDEN_TAG_PREFIX}repository"

SCHEDULE_NAME_TAG = f"{SYSTEM_TAG_PREFIX}schedule_name"

SENSOR_NAME_TAG = f"{SYSTEM_TAG_PREFIX}sensor_name"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
PARENT_RUN_ID_TAG,
PARTITION_NAME_TAG,
PARTITION_SET_TAG,
REPOSITORY_LABEL_TAG,
ROOT_RUN_ID_TAG,
)
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
Expand Down Expand Up @@ -80,16 +81,21 @@ def can_delete_runs(self):
return True

@staticmethod
def fake_repo_target(repo_name=None):
    """Build an ExternalRepositoryOrigin for use in tests.

    Args:
        repo_name: Repository name for the origin. Any falsy value (including
            the default ``None``) falls back to ``"fake_repo_name"``.

    Returns:
        An ExternalRepositoryOrigin wrapping a
        ManagedGrpcPythonEnvRepositoryLocationOrigin whose loadable target is
        module ``"fake"``, attribute ``"fake"``, run with the current
        interpreter (``sys.executable``).
    """
    name = repo_name or "fake_repo_name"
    return ExternalRepositoryOrigin(
        ManagedGrpcPythonEnvRepositoryLocationOrigin(
            LoadableTargetOrigin(
                executable_path=sys.executable, module_name="fake", attribute="fake"
            ),
        ),
        name,
    )

@classmethod
def fake_job_origin(cls, job_name, repo_name=None):
    """Return the pipeline origin for ``job_name`` on the fake repository target.

    ``repo_name`` is forwarded unchanged to ``fake_repo_target``.
    """
    repo_origin = cls.fake_repo_target(repo_name)
    return repo_origin.get_pipeline_origin(job_name)

@classmethod
def fake_partition_set_origin(cls, partition_set_name):
    """Return the partition-set origin for ``partition_set_name`` on the default fake repository target."""
    repo_origin = cls.fake_repo_target()
    return repo_origin.get_partition_set_origin(partition_set_name)
Expand All @@ -104,6 +110,7 @@ def build_run(
parent_run_id=None,
root_run_id=None,
pipeline_snapshot_id=None,
external_pipeline_origin=None,
):
return DagsterRun(
pipeline_name=pipeline_name,
Expand All @@ -115,6 +122,7 @@ def build_run(
root_run_id=root_run_id,
parent_run_id=parent_run_id,
pipeline_snapshot_id=pipeline_snapshot_id,
external_pipeline_origin=external_pipeline_origin,
)

def test_basic_storage(self, storage):
Expand Down Expand Up @@ -167,6 +175,35 @@ def test_fetch_by_pipeline(self, storage):
assert len(some_runs) == 1
assert some_runs[0].run_id == one

def test_fetch_by_repo(self, storage):
    """Runs of the same job in different repositories can be filtered by repository label.

    Two runs of one job name are stored, each with an origin from a differently
    named fake repository. Filtering on REPOSITORY_LABEL_TAG must return exactly
    the run that belongs to the requested repository.
    """
    assert storage
    # NOTE(review): presumably skips this test for in-memory storage -- confirm
    # against the helper's definition.
    self._skip_in_memory(storage)

    one = make_new_run_id()
    two = make_new_run_id()
    job_name = "some_job"

    origin_one = self.fake_job_origin(job_name, "fake_repo_one")
    origin_two = self.fake_job_origin(job_name, "fake_repo_two")
    storage.add_run(
        TestRunStorage.build_run(
            run_id=one, pipeline_name=job_name, external_pipeline_origin=origin_one
        )
    )
    storage.add_run(
        TestRunStorage.build_run(
            run_id=two, pipeline_name=job_name, external_pipeline_origin=origin_two
        )
    )

    # Labels have the form "<repository name>@<location name>"; the fake target's
    # location name is expected to be "fake:fake".
    one_runs = storage.get_runs(
        RunsFilter(tags={REPOSITORY_LABEL_TAG: "fake_repo_one@fake:fake"})
    )
    assert len(one_runs) == 1
    # Check identity, not just the count, so a filter that returns the wrong
    # repository's run cannot pass.
    assert one_runs[0].run_id == one

    two_runs = storage.get_runs(
        RunsFilter(tags={REPOSITORY_LABEL_TAG: "fake_repo_two@fake:fake"})
    )
    assert len(two_runs) == 1
    assert two_runs[0].run_id == two

def test_fetch_by_snapshot_id(self, storage):
assert storage
pipeline_def_a = PipelineDefinition(name="some_pipeline", solid_defs=[])
Expand Down

0 comments on commit e90d8d9

Please sign in to comment.