Skip to content

Commit

Permalink
Added group_name to asset (#8110)
Browse files Browse the repository at this point in the history
Added a `group_name` parameter to the `@asset` decorator and made it available in the GraphQL API.
  • Loading branch information
shalabhc committed May 31, 2022
1 parent cd3f69e commit dc922f9
Show file tree
Hide file tree
Showing 7 changed files with 53 additions and 0 deletions.
1 change: 1 addition & 0 deletions js_modules/dagit/packages/core/src/graphql/schema.graphql

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ class GrapheneAssetNode(graphene.ObjectType):
partitionKeys = non_null_list(graphene.String)
partitionDefinition = graphene.String()
repository = graphene.NonNull(lambda: external.GrapheneRepository)
groupName = graphene.String()

class Meta:
name = "AssetNode"
Expand Down Expand Up @@ -154,6 +155,7 @@ def __init__(
assetKey=external_asset_node.asset_key,
description=external_asset_node.op_description,
opName=external_asset_node.op_name,
groupName=external_asset_node.group_name,
)

@property
Expand Down
6 changes: 6 additions & 0 deletions python_modules/dagster/dagster/core/asset_defs/assets.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def __init__(
selected_asset_keys: Optional[AbstractSet[AssetKey]] = None,
can_subset: bool = False,
resource_defs: Optional[Mapping[str, ResourceDefinition]] = None,
group_names: Optional[Mapping[AssetKey, str]] = None,
# if adding new fields, make sure to handle them in the with_replaced_asset_keys method
):
self._node_def = node_def
Expand Down Expand Up @@ -62,6 +63,7 @@ def __init__(
f"expected keys: {all_asset_keys}",
)
self._resource_defs = check.opt_mapping_param(resource_defs, "resource_defs")
self._group_names = check.opt_mapping_param(group_names, "group_names")

if selected_asset_keys is not None:
self._selected_asset_keys = selected_asset_keys
Expand Down Expand Up @@ -142,6 +144,10 @@ def from_graph(
def can_subset(self) -> bool:
return self._can_subset

@property
def group_names(self) -> Mapping[AssetKey, str]:
    """Group name recorded for each asset key.

    Returns:
        The mapping stored on this instance as ``_group_names``; the same
        object is handed back, not a copy.
    """
    return self._group_names

@property
def op(self) -> OpDefinition:
check.invariant(
Expand Down
7 changes: 7 additions & 0 deletions python_modules/dagster/dagster/core/asset_defs/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ def asset(
partitions_def: Optional[PartitionsDefinition] = ...,
partition_mappings: Optional[Mapping[str, PartitionMapping]] = ...,
op_tags: Optional[Dict[str, Any]] = ...,
group_name: Optional[str] = ...,
) -> Callable[[Callable[..., Any]], AssetsDefinition]:
...

Expand All @@ -81,6 +82,7 @@ def asset(
partitions_def: Optional[PartitionsDefinition] = None,
partition_mappings: Optional[Mapping[str, PartitionMapping]] = None,
op_tags: Optional[Dict[str, Any]] = None,
group_name: Optional[str] = None,
) -> Union[AssetsDefinition, Callable[[Callable[..., Any]], AssetsDefinition]]:
"""Create a definition for how to compute an asset.
Expand Down Expand Up @@ -125,6 +127,7 @@ def asset(
Frameworks may expect and require certain metadata to be attached to a op. Values that
are not strings will be json encoded and must meet the criteria that
`json.loads(json.dumps(value)) == value`.
group_name (Optional[str]): A string name used to organize multiple assets into groups.
Examples:
Expand Down Expand Up @@ -157,6 +160,7 @@ def inner(fn: Callable[..., Any]) -> AssetsDefinition:
partitions_def=partitions_def,
partition_mappings=partition_mappings,
op_tags=op_tags,
group_name=group_name,
)(fn)

return inner
Expand All @@ -179,6 +183,7 @@ def __init__(
partitions_def: Optional[PartitionsDefinition] = None,
partition_mappings: Optional[Mapping[str, PartitionMapping]] = None,
op_tags: Optional[Dict[str, Any]] = None,
group_name: Optional[str] = None,
):
self.name = name
# if user inputs a single string, coerce to list
Expand All @@ -197,6 +202,7 @@ def __init__(
self.partition_mappings = partition_mappings
self.op_tags = op_tags
self.resource_defs = dict(check.opt_mapping_param(resource_defs, "resource_defs"))
self.group_name = group_name

def __call__(self, fn: Callable) -> AssetsDefinition:
asset_name = self.name or fn.__name__
Expand Down Expand Up @@ -265,6 +271,7 @@ def __call__(self, fn: Callable) -> AssetsDefinition:
if self.partition_mappings
else None,
resource_defs=self.resource_defs,
group_names={out_asset_key: self.group_name} if self.group_name else None,
)


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,12 @@ def asset_info_for_output(
NodeOutputHandle(node_handle, output_name)
)

def group_names_by_assets(self) -> Mapping[AssetKey, str]:
    """Flatten the ``group_names`` of every assets definition into one dict.

    Definitions are visited in ``self._assets_defs`` order, so if the same
    key appears in more than one of them the later value replaces the
    earlier one.

    Returns:
        A newly built dict combining all per-definition ``group_names``
        entries.
    """
    return {
        asset_key: group_name
        for definition in self._assets_defs
        for asset_key, group_name in definition.group_names.items()
    }


def build_asset_selection_job(
name: str,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ class ExternalAssetNode(
("output_name", Optional[str]),
("output_description", Optional[str]),
("metadata_entries", Sequence[MetadataEntry]),
("group_name", Optional[str]),
],
)
):
Expand All @@ -719,6 +720,7 @@ def __new__(
output_name: Optional[str] = None,
output_description: Optional[str] = None,
metadata_entries: Optional[Sequence[MetadataEntry]] = None,
group_name: Optional[str] = None,
):
# backcompat logic to handle ExternalAssetNodes serialized without op_names/graph_name
if not op_names:
Expand Down Expand Up @@ -750,6 +752,7 @@ def __new__(
metadata_entries=check.opt_sequence_param(
metadata_entries, "metadata_entries", of_type=MetadataEntry
),
group_name=check.opt_str_param(group_name, "group_name"),
)


Expand Down Expand Up @@ -795,6 +798,7 @@ def external_asset_graph_from_defs(
dep_by: Dict[AssetKey, Dict[AssetKey, ExternalAssetDependedBy]] = defaultdict(dict)
all_upstream_asset_keys: Set[AssetKey] = set()
op_names_by_asset_key: Dict[AssetKey, Sequence[str]] = {}
group_names: Dict[AssetKey, str] = {}

for pipeline_def in pipelines:
asset_info_by_node_output = pipeline_def.asset_layer.asset_info_by_node_output_handle
Expand All @@ -819,6 +823,8 @@ def external_asset_graph_from_defs(
downstream_asset_key=output_key
)

group_names.update(pipeline_def.asset_layer.group_names_by_assets())

asset_keys_without_definitions = all_upstream_asset_keys.difference(
node_defs_by_asset_key.keys()
).difference(source_assets_by_key.keys())
Expand All @@ -829,6 +835,7 @@ def external_asset_graph_from_defs(
dependencies=list(deps[asset_key].values()),
depended_by=list(dep_by[asset_key].values()),
job_names=[],
group_name=group_names.get(asset_key),
)
for asset_key in asset_keys_without_definitions
]
Expand All @@ -847,6 +854,7 @@ def external_asset_graph_from_defs(
job_names=[],
op_description=source_asset.description,
metadata_entries=metadata_entries,
group_name=group_names.get(source_asset.key),
)
)

Expand Down Expand Up @@ -906,6 +914,7 @@ def external_asset_graph_from_defs(
output_name=output_def.name,
output_description=output_def.description,
metadata_entries=output_def.metadata_entries,
group_name=group_names.get(asset_key),
)
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,28 @@ def asset1():
]


def test_asset_with_group_name():
    """A ``group_name`` given to ``@asset`` shows up on the external asset node."""

    @asset(group_name="group1")
    def asset1():
        return 1

    job_def = build_assets_job("assets_job", [asset1])
    nodes = external_asset_graph_from_defs([job_def], source_assets_by_key={})
    first_node = nodes[0]

    assert first_node.group_name == "group1"


def test_asset_missing_group_name():
    """Without a ``group_name`` on ``@asset`` the external asset node reports ``None``."""

    @asset
    def asset1():
        return 1

    job_def = build_assets_job("assets_job", [asset1])
    nodes = external_asset_graph_from_defs([job_def], source_assets_by_key={})
    first_node = nodes[0]

    assert first_node.group_name is None


def test_two_asset_job():
@asset
def asset1():
Expand Down

0 comments on commit dc922f9

Please sign in to comment.