Skip to content

Commit

Permalink
Simplify IWorkspace API (#6821)
Browse files Browse the repository at this point in the history
Summary:
Now that the daemon manages the workspace and always knows the set of locations, this API can become less weird.

Test Plan: BK
  • Loading branch information
gibsondan committed Mar 1, 2022
1 parent f330e26 commit 3408263
Show file tree
Hide file tree
Showing 9 changed files with 12 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -407,7 +407,7 @@ def resolve_asset_nodes(self, graphene_info):
else:
repo_handle = self._represented_pipeline.repository_handle
origin = repo_handle.repository_location_origin
location = graphene_info.context.get_location(origin)
location = graphene_info.context.get_location(origin.location_name)
ext_repo = location.get_repository(repo_handle.repository_name)
nodes = [
node
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def launch_run(self, context: LaunchRunContext) -> None:
)

repository_location = context.workspace.get_location(
run.external_pipeline_origin.external_repository_origin.repository_location_origin
run.external_pipeline_origin.external_repository_origin.repository_location_origin.location_name
)

check.inst(
Expand Down
5 changes: 2 additions & 3 deletions python_modules/dagster/dagster/core/workspace/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,8 +101,7 @@ def has_permission(self, permission: str) -> bool:
def show_instance_config(self) -> bool:
return True

def get_location(self, origin):
location_name = origin.location_name
def get_location(self, location_name: str):
location_entry = self.get_location_entry(location_name)
if not location_entry:
raise DagsterInvariantViolationError(
Expand All @@ -112,7 +111,7 @@ def get_location(self, origin):
if location_entry.repository_location:
return location_entry.repository_location

error_info = location_entry.load_error
error_info = cast(SerializableErrorInfo, location_entry.load_error)
raise DagsterRepositoryLocationLoadError(
f"Failure loading {location_name}: {error_info.to_string()}",
load_error_infos=[error_info],
Expand Down
12 changes: 3 additions & 9 deletions python_modules/dagster/dagster/core/workspace/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,12 @@ class WorkspaceLocationEntry(NamedTuple):

class IWorkspace(ABC):
"""
A class that knows how to get a RepositoryLocation from a RepositoryLocationOrigin,
possibly creating it lazily.
Used both by Dagit (where this is a fixed WorkspaceRequestContext that manages a static
list of RepositoryLocations) and the dagster-daemon process in, which lazily creates and reloads
repository locations in a DynamicWorkspace based on the running schedules, sensors, and queued
runs in the database.
Manages a set of RepositoryLocations.
"""

@abstractmethod
def get_location(self, origin: RepositoryLocationOrigin):
"""Return the RepositoryLocation for the given RepositoryLocationOrigin, or raise an error if there is an error loading it."""
def get_location(self, location_name: str):
"""Return the RepositoryLocation for the given location name, or raise an error if there is an error loading it."""

@abstractmethod
def get_workspace_snapshot(self) -> Dict[str, WorkspaceLocationEntry]:
Expand Down
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/daemon/backfill.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def execute_backfill_iteration(instance, workspace, logger, debug_crash_flags=No
)

try:
repo_location = workspace.get_location(origin)
repo_location = workspace.get_location(origin.location_name)
repo_name = backfill_job.partition_set_origin.external_repository_origin.repository_name
partition_set_name = backfill_job.partition_set_origin.partition_set_name
if not repo_location.has_repository(repo_name):
Expand Down
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/daemon/sensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,7 +340,7 @@ def _evaluate_sensor(
sensor_origin = external_sensor.get_external_origin()
repository_handle = external_sensor.handle.repository_handle
repo_location = workspace.get_location(
sensor_origin.external_repository_origin.repository_location_origin
sensor_origin.external_repository_origin.repository_location_origin.location_name
)

sensor_runtime_data = repo_location.get_external_sensor_execution_data(
Expand Down
4 changes: 1 addition & 3 deletions python_modules/dagster/dagster/daemon/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,10 @@ def get_workspace_snapshot(self) -> Dict[str, WorkspaceLocationEntry]:
def _load_workspace(self) -> Dict[str, WorkspaceLocationEntry]:
pass

def get_location(self, origin) -> RepositoryLocation:
def get_location(self, location_name: str) -> RepositoryLocation:
if self._location_entries == None:
self._location_entries = self._load_workspace()

location_name = origin.location_name

if location_name not in self._location_entries:
raise DagsterRepositoryLocationLoadError(
f"Location {location_name} does not exist in workspace",
Expand Down
2 changes: 1 addition & 1 deletion python_modules/dagster/dagster/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,7 @@ def _schedule_runs_at_time(
)

repo_location = workspace.get_location(
schedule_origin.external_repository_origin.repository_location_origin
schedule_origin.external_repository_origin.repository_location_origin.location_name
)

external_pipeline = repo_location.get_external_pipeline(pipeline_selector)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1798,7 +1798,7 @@ def test_grpc_server_down(instance):
with _grpc_server_external_repo(port) as external_repo:
external_schedule = external_repo.get_external_schedule("simple_schedule")
instance.start_schedule(external_schedule)
workspace.get_location(location_origin)
workspace.get_location(location_origin.location_name)

# Server is no longer running, ticks fail but indicate it will resume once it is reachable
for _trial in range(3):
Expand Down

0 comments on commit 3408263

Please sign in to comment.