Skip to content

Commit

Permalink
Let InProcessRepositoryLocationOrigin take in multiple repos (#7386)
Browse files Browse the repository at this point in the history
Summary:
As part of making the gRPC server startup time configurable, I took a look at places where we were spinning up local gRPC servers and would need to add an instance check. I found one place where code previews was using a local gRPC server subprocss when it could be using an in process repo location. This PR gives in process repo locations the ability to load multiple repositories rather than spin up a gRPC server, which will let us use that in code previews and avoid adding a gross hook to set a timeout there that we don't really need.
  • Loading branch information
gibsondan committed Apr 12, 2022
1 parent 00bde2b commit 38625ae
Show file tree
Hide file tree
Showing 13 changed files with 152 additions and 76 deletions.
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import sys

import pendulum
import pytest
Expand All @@ -10,7 +11,6 @@
main_repo_name,
)

from dagster.core.definitions.reconstruct import ReconstructableRepository
from dagster.core.host_representation import (
ExternalRepositoryOrigin,
InProcessRepositoryLocationOrigin,
Expand All @@ -21,6 +21,7 @@
InstigatorType,
ScheduleInstigatorData,
)
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
from dagster.seven.compat.pendulum import create_pendulum_time
from dagster.utils import Counter, traced_counter

Expand Down Expand Up @@ -231,9 +232,13 @@ def default_execution_params():

def _get_unloadable_schedule_origin(name):
working_directory = os.path.dirname(__file__)
recon_repo = ReconstructableRepository.for_file(__file__, "doesnt_exist", working_directory)
loadable_target_origin = LoadableTargetOrigin(
executable_path=sys.executable,
python_file=__file__,
working_directory=working_directory,
)
return ExternalRepositoryOrigin(
InProcessRepositoryLocationOrigin(recon_repo), "fake_repository"
InProcessRepositoryLocationOrigin(loadable_target_origin), "fake_repository"
).get_instigator_origin(name)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from dagster import check
from dagster.core.definitions.reconstruct import ReconstructableRepository
from dagster.core.host_representation import InProcessRepositoryLocationOrigin
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
from dagster.utils import file_relative_path, git_repository_root

IS_BUILDKITE = os.getenv("BUILDKITE") is not None
Expand Down Expand Up @@ -43,9 +44,10 @@ def build_and_tag_test_image(tag):
@contextmanager
def get_test_project_external_pipeline(pipeline_name):
with InProcessRepositoryLocationOrigin(
ReconstructableRepository.for_file(
file_relative_path(__file__, "test_pipelines/repo.py"),
"define_demo_execution_repo",
LoadableTargetOrigin(
executable_path=sys.executable,
python_file=file_relative_path(__file__, "test_pipelines/repo.py"),
attribute="define_demo_execution_repo",
)
).create_location() as location:
yield location.get_repository("demo_execution_repo").get_full_external_pipeline(
Expand Down
24 changes: 12 additions & 12 deletions python_modules/dagster-test/dagster_test/test_project/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
RepositoryPythonOrigin,
)
from dagster.core.test_utils import in_process_test_workspace
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
from dagster.serdes import create_snapshot_id, whitelist_for_serdes
from dagster.utils import file_relative_path, git_repository_root

Expand Down Expand Up @@ -201,15 +202,13 @@ def get_external_origin(self):
return ExternalPipelineOrigin(
external_repository_origin=ExternalRepositoryOrigin(
repository_location_origin=InProcessRepositoryLocationOrigin(
recon_repo=ReconstructableRepository(
pointer=FileCodePointer(
python_file="/dagster_test/test_project/test_pipelines/repo.py",
fn_name="define_demo_execution_repo",
),
container_image=self._container_image,
loadable_target_origin=LoadableTargetOrigin(
executable_path="python",
entry_point=DEFAULT_DAGSTER_ENTRY_POINT,
)
python_file="/dagster_test/test_project/test_pipelines/repo.py",
attribute="define_demo_execution_repo",
),
container_image=self._container_image,
entry_point=DEFAULT_DAGSTER_ENTRY_POINT,
),
repository_name="demo_execution_repo",
),
Expand Down Expand Up @@ -266,11 +265,12 @@ def selector_id(self):
def get_test_project_workspace(instance, container_image=None):
with in_process_test_workspace(
instance,
recon_repo=ReconstructableRepository.for_file(
file_relative_path(__file__, "test_pipelines/repo.py"),
"define_demo_execution_repo",
container_image=container_image,
loadable_target_origin=LoadableTargetOrigin(
executable_path=sys.executable,
python_file=file_relative_path(__file__, "test_pipelines/repo.py"),
attribute="define_demo_execution_repo",
),
container_image=container_image,
) as workspace:
yield workspace

Expand Down
39 changes: 30 additions & 9 deletions python_modules/dagster/dagster/core/host_representation/origin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Any,
Dict,
Generator,
List,
Mapping,
NamedTuple,
NoReturn,
Expand All @@ -18,8 +19,8 @@
)

from dagster import check
from dagster.core.definitions.reconstruct import ReconstructableRepository
from dagster.core.errors import DagsterInvariantViolationError, DagsterUserCodeUnreachableError
from dagster.core.origin import DEFAULT_DAGSTER_ENTRY_POINT
from dagster.core.types.loadable_target_origin import LoadableTargetOrigin
from dagster.serdes import (
DefaultNamedTupleSerializer,
Expand Down Expand Up @@ -143,16 +144,38 @@ def create_location(self) -> NoReturn:

@whitelist_for_serdes
class InProcessRepositoryLocationOrigin(
NamedTuple("_InProcessRepositoryLocationOrigin", [("recon_repo", ReconstructableRepository)]),
NamedTuple(
"_InProcessRepositoryLocationOrigin",
[
("loadable_target_origin", LoadableTargetOrigin),
("container_image", Optional[str]),
("entry_point", List[str]),
],
),
RepositoryLocationOrigin,
):
"""Identifies a repository location constructed in the host process. Should only be
used in tests.
"""Identifies a repository location constructed in the same process. Primarily
used in tests, since Dagster system processes like Dagit and the daemon do not
load user code in the same process.
"""

def __new__(cls, recon_repo: ReconstructableRepository):
def __new__(
cls,
loadable_target_origin: LoadableTargetOrigin,
container_image: Optional[str] = None,
entry_point: Optional[List[str]] = None,
):
return super(InProcessRepositoryLocationOrigin, cls).__new__(
cls, check.inst_param(recon_repo, "recon_repo", ReconstructableRepository)
cls,
check.inst_param(
loadable_target_origin, "loadable_target_origin", LoadableTargetOrigin
),
container_image=check.opt_str_param(container_image, "container_image"),
entry_point=(
check.opt_list_param(entry_point, "entry_point")
if entry_point
else DEFAULT_DAGSTER_ENTRY_POINT
),
)

@property
Expand All @@ -164,9 +187,7 @@ def is_reload_supported(self) -> bool:
return False

def get_display_metadata(self) -> Dict[str, Any]:
return {
"in_process_code_pointer": self.recon_repo.pointer.describe(),
}
return {}

def create_location(self) -> "InProcessRepositoryLocation":
from dagster.core.host_representation.repository_location import InProcessRepositoryLocation
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import datetime
import sys
import threading
from abc import abstractmethod
from contextlib import AbstractContextManager
Expand All @@ -21,7 +20,7 @@
from dagster.api.snapshot_schedule import sync_get_external_schedule_execution_data_grpc
from dagster.api.snapshot_sensor import sync_get_external_sensor_execution_data_grpc
from dagster.core.code_pointer import CodePointer
from dagster.core.definitions.reconstruct import ReconstructablePipeline, ReconstructableRepository
from dagster.core.definitions.reconstruct import ReconstructablePipeline
from dagster.core.errors import DagsterInvariantViolationError
from dagster.core.execution.api import create_execution_plan
from dagster.core.execution.plan.state import KnownExecutionState
Expand Down Expand Up @@ -262,22 +261,30 @@ def get_repository_python_origin(self, repository_name: str) -> "RepositoryPytho

class InProcessRepositoryLocation(RepositoryLocation):
def __init__(self, origin: InProcessRepositoryLocationOrigin):
self._origin = check.inst_param(origin, "origin", InProcessRepositoryLocationOrigin)
from dagster.grpc.server import LoadedRepositories

self._recon_repo = self._origin.recon_repo
self._origin = check.inst_param(origin, "origin", InProcessRepositoryLocationOrigin)

repo_def = self._recon_repo.get_definition()
pointer = self._recon_repo.pointer
loadable_target_origin = self._origin.loadable_target_origin
self._loaded_repositories = LoadedRepositories(
loadable_target_origin, self._origin.entry_point
)

self._repository_code_pointer_dict = {repo_def.name: pointer}
self._repository_code_pointer_dict = self._loaded_repositories.code_pointers_by_repo_name

def_name = repo_def.name
self._recon_repos = {
repo_name: self._loaded_repositories.get_recon_repo(repo_name)
for repo_name in self._repository_code_pointer_dict
}

self._external_repo = external_repo_from_def(
repo_def,
RepositoryHandle(repository_name=def_name, repository_location=self),
)
self._repositories = {self._external_repo.name: self._external_repo}
self._repositories = {}
for repo_name in self._repository_code_pointer_dict:
recon_repo = self._loaded_repositories.get_recon_repo(repo_name)
repo_def = recon_repo.get_definition()
self._repositories[repo_name] = external_repo_from_def(
repo_def,
RepositoryHandle(repository_name=repo_name, repository_location=self),
)

@property
def is_reload_supported(self) -> bool:
Expand All @@ -289,27 +296,22 @@ def origin(self) -> InProcessRepositoryLocationOrigin:

@property
def executable_path(self) -> Optional[str]:
return (
self._recon_repo.executable_path if self._recon_repo.executable_path else sys.executable
)
return self._origin.loadable_target_origin.executable_path

@property
def container_image(self) -> Optional[str]:
return self._recon_repo.container_image
return self._origin.container_image

@property
def entry_point(self) -> Optional[List[str]]:
return self._recon_repo.entry_point
return self._origin.entry_point

@property
def repository_code_pointer_dict(self) -> Dict[str, CodePointer]:
return self._repository_code_pointer_dict

def get_reconstructable_pipeline(self, name: str) -> ReconstructablePipeline:
return self.get_reconstructable_repository().get_reconstructable_pipeline(name)

def get_reconstructable_repository(self) -> ReconstructableRepository:
return self.origin.recon_repo
return self._recon_repos[name].get_reconstructable_pipeline(name)

def get_repository(self, name: str) -> ExternalRepository:
return self._repositories[name]
Expand Down Expand Up @@ -378,7 +380,7 @@ def get_external_partition_config(
check.str_param(partition_name, "partition_name")

return get_partition_config(
recon_repo=self._recon_repo,
recon_repo=self._recon_repos[repository_handle.repository_name],
partition_set_name=partition_set_name,
partition_name=partition_name,
)
Expand All @@ -391,7 +393,7 @@ def get_external_partition_tags(
check.str_param(partition_name, "partition_name")

return get_partition_tags(
recon_repo=self._recon_repo,
recon_repo=self._recon_repos[repository_handle.repository_name],
partition_set_name=partition_set_name,
partition_name=partition_name,
)
Expand All @@ -403,7 +405,7 @@ def get_external_partition_names(
check.str_param(partition_set_name, "partition_set_name")

return get_partition_names(
recon_repo=self._recon_repo,
recon_repo=self._recon_repos[repository_handle.repository_name],
partition_set_name=partition_set_name,
)

Expand All @@ -420,7 +422,7 @@ def get_external_schedule_execution_data(
check.opt_inst_param(scheduled_execution_time, "scheduled_execution_time", PendulumDateTime)

return get_external_schedule_execution(
self._recon_repo,
recon_repo=self._recon_repos[repository_handle.repository_name],
instance_ref=instance.get_ref(),
schedule_name=schedule_name,
scheduled_execution_timestamp=scheduled_execution_time.timestamp()
Expand All @@ -441,7 +443,12 @@ def get_external_sensor_execution_data(
cursor: Optional[str],
) -> Union["SensorExecutionData", "ExternalSensorExecutionErrorData"]:
return get_external_sensor_execution(
self._recon_repo, instance.get_ref(), name, last_completion_time, last_run_key, cursor
self._recon_repos[repository_handle.repository_name],
instance.get_ref(),
name,
last_completion_time,
last_run_key,
cursor,
)

def get_external_partition_set_execution_param_data(
Expand All @@ -455,7 +462,7 @@ def get_external_partition_set_execution_param_data(
check.list_param(partition_names, "partition_names", of_type=str)

return get_partition_set_execution_param_data(
self._recon_repo,
recon_repo=self._recon_repos[repository_handle.repository_name],
partition_set_name=partition_set_name,
partition_names=partition_names,
)
Expand Down
10 changes: 8 additions & 2 deletions python_modules/dagster/dagster/core/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,9 +439,15 @@ def create_origins(self):


@contextmanager
def in_process_test_workspace(instance, recon_repo):
def in_process_test_workspace(instance, loadable_target_origin, container_image=None):
with WorkspaceProcessContext(
instance, InProcessTestWorkspaceLoadTarget(InProcessRepositoryLocationOrigin(recon_repo))
instance,
InProcessTestWorkspaceLoadTarget(
InProcessRepositoryLocationOrigin(
loadable_target_origin,
container_image=container_image,
),
),
) as workspace_process_context:
yield workspace_process_context.create_request_context()

Expand Down
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import sys
from typing import List, NamedTuple, Optional

from dagster import check
Expand All @@ -20,7 +21,7 @@ class LoadableTargetOrigin(
):
def __new__(
cls,
executable_path,
executable_path=None,
python_file=None,
module_name=None,
working_directory=None,
Expand All @@ -29,7 +30,7 @@ def __new__(
):
return super(LoadableTargetOrigin, cls).__new__(
cls,
executable_path=check.str_param(executable_path, "executable_path"),
executable_path=check.opt_str_param(executable_path, "executable_path", sys.executable),
python_file=check.opt_str_param(python_file, "python_file"),
module_name=check.opt_str_param(module_name, "module_name"),
working_directory=check.opt_str_param(working_directory, "working_directory"),
Expand Down
6 changes: 5 additions & 1 deletion python_modules/dagster/dagster/core/workspace/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,12 @@ def repository_location_errors(self) -> List[SerializableErrorInfo]:

def get_repository_location(self, name: str) -> RepositoryLocation:
location_entry = self.get_location_entry(name)
if not location_entry or not location_entry.repository_location:

if not location_entry:
raise Exception(f"Location {name} not in workspace")
if location_entry.load_error:
raise Exception(f"Error loading location {name}: {location_entry.load_error}")

return cast(RepositoryLocation, location_entry.repository_location)

def has_repository_location_error(self, name: str) -> bool:
Expand Down

0 comments on commit 38625ae

Please sign in to comment.