Skip to content

Commit

Permalink
Override resource defs when invoking assets (#8217)
Browse files Browse the repository at this point in the history
* When invoking asset defs, resource defs are used (but overridden if overrides are provided)

* Handle mocked OpExecutionContext case
  • Loading branch information
dpeng817 committed Jun 8, 2022
1 parent 1804af5 commit e9ebf8f
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 4 deletions.
61 changes: 58 additions & 3 deletions python_modules/dagster/dagster/core/asset_defs/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import AbstractSet, Dict, Iterable, Iterator, Mapping, Optional, Sequence, Set, cast

import dagster._check as check
from dagster.core.decorator_utils import get_function_params
from dagster.core.definitions import (
GraphDefinition,
NodeDefinition,
Expand All @@ -12,6 +13,7 @@
from dagster.core.definitions.events import AssetKey
from dagster.core.definitions.partition import PartitionsDefinition
from dagster.core.definitions.utils import validate_group_name
from dagster.core.execution.context.compute import OpExecutionContext
from dagster.utils import merge_dicts
from dagster.utils.backcompat import ExperimentalWarning, experimental

Expand Down Expand Up @@ -90,7 +92,32 @@ def __init__(
}

def __call__(self, *args, **kwargs):
return self._node_def(*args, **kwargs)
from dagster.core.definitions.decorators.solid_decorator import DecoratedSolidFunction

if isinstance(self.node_def, GraphDefinition):
return self._node_def(*args, **kwargs)
solid_def = self.op
provided_context: Optional[OpExecutionContext] = None
if len(args) > 0 and isinstance(args[0], OpExecutionContext):
provided_context = _build_invocation_context_with_included_resources(
self.resource_defs, args[0]
)
new_args = [provided_context, *args[1:]]
return solid_def(*new_args, **kwargs)
elif (
isinstance(solid_def.compute_fn.decorated_fn, DecoratedSolidFunction)
and solid_def.compute_fn.has_context_arg()
):
context_param_name = get_function_params(solid_def.compute_fn.decorated_fn)[0].name
if context_param_name in kwargs:
provided_context = _build_invocation_context_with_included_resources(
self.resource_defs, kwargs[context_param_name]
)
new_kwargs = dict(kwargs)
new_kwargs[context_param_name] = provided_context
return solid_def(*args, **new_kwargs)

return solid_def(*args, **kwargs)

@staticmethod
@experimental
Expand Down Expand Up @@ -193,8 +220,8 @@ def asset_key(self) -> AssetKey:
return next(iter(self.asset_keys))

@property
def resource_defs(self) -> Mapping[str, ResourceDefinition]:
return self._resource_defs
def resource_defs(self) -> Dict[str, ResourceDefinition]:
return dict(self._resource_defs)

@property
def asset_keys(self) -> AbstractSet[AssetKey]:
Expand Down Expand Up @@ -416,3 +443,31 @@ def _infer_asset_keys_by_output_names(
if output_name not in inferred_asset_keys_by_output_names:
inferred_asset_keys_by_output_names[output_name] = AssetKey([output_name])
return inferred_asset_keys_by_output_names


def _build_invocation_context_with_included_resources(
resource_defs: Dict[str, ResourceDefinition], context: OpExecutionContext
) -> OpExecutionContext:
from dagster.core.execution.context.invocation import (
UnboundSolidExecutionContext,
build_op_context,
)

override_resources = context.resources._asdict()
all_resources = merge_dicts(resource_defs, override_resources)

if isinstance(context, UnboundSolidExecutionContext):
context = cast(UnboundSolidExecutionContext, context)
# pylint: disable=protected-access
return build_op_context(
resources=all_resources,
config=context.solid_config,
resources_config=context._resources_config,
instance=context._instance,
partition_key=context._partition_key,
mapping_key=context._mapping_key,
)
else:
# If user is mocking OpExecutionContext, send it through (we don't know
# what modifications they might be making, and we don't want to override)
return context
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
import pytest

from dagster import AssetKey, IOManager, Out, Output, io_manager
from dagster import (
AssetKey,
IOManager,
Out,
Output,
ResourceDefinition,
build_op_context,
io_manager,
)
from dagster._check import CheckError
from dagster.core.asset_defs import AssetGroup, AssetIn, SourceAsset, asset, multi_asset
from dagster.core.errors import DagsterInvalidDefinitionError, DagsterInvalidInvocationError
from dagster.core.storage.mem_io_manager import InMemoryIOManager


Expand Down Expand Up @@ -294,3 +303,74 @@ def the_io_manager():
@asset(io_manager_key="the_key", io_manager_def=the_io_manager)
def the_asset():
pass


def test_asset_invocation():
@asset
def the_asset():
return 6

assert the_asset() == 6


def test_asset_invocation_input():
@asset
def input_asset(x):
return x

assert input_asset(5) == 5


def test_asset_invocation_resource_overrides():
@asset(required_resource_keys={"foo", "bar"})
def asset_reqs_resources(context):
assert context.resources.foo == "foo_resource"
assert context.resources.bar == "bar_resource"

asset_reqs_resources(build_op_context(resources={"foo": "foo_resource", "bar": "bar_resource"}))

@asset(
resource_defs={
"foo": ResourceDefinition.hardcoded_resource("orig_foo"),
"bar": ResourceDefinition.hardcoded_resource("orig_bar"),
}
)
def asset_resource_overrides(context):
assert context.resources.foo == "override_foo"
assert context.resources.bar == "orig_bar"

asset_resource_overrides(build_op_context(resources={"foo": "override_foo"}))


def test_asset_invocation_resource_errors():
@asset(resource_defs={"ignored": ResourceDefinition.hardcoded_resource("not_used")})
def asset_doesnt_use_resources():
pass

with pytest.raises(
DagsterInvalidInvocationError,
match='op "asset_doesnt_use_resources" has required resources, but no context was provided.',
):
asset_doesnt_use_resources()

@asset(resource_defs={"used": ResourceDefinition.hardcoded_resource("foo")})
def asset_uses_resources(context):
assert context.resources.used == "foo"

with pytest.raises(
DagsterInvalidInvocationError,
match='op "asset_uses_resources" has required resources, but no context was provided',
):
asset_uses_resources(None)

asset_uses_resources(build_op_context())

@asset(required_resource_keys={"foo"})
def required_key_not_provided(_):
pass

with pytest.raises(
DagsterInvalidDefinitionError,
match="resource with key 'foo' required by op 'required_key_not_provided' was not provided.",
):
required_key_not_provided(build_op_context())

0 comments on commit e9ebf8f

Please sign in to comment.