Skip to content

Commit

Permalink
[NEAT-210] Explicit View Filter ✄ (#425)
Browse files Browse the repository at this point in the history
* tests: Added failing tests

* refactor: added DMS Node Entitiy

* refactor: setup shell for new wrapped entities

* feat: Implemented wrapped filter

* refactor: Switched to wrapped entity

* refactor: introduce DMSFilder base class

* refactor: moved out creation of filter method

* refactor: moved logic

* refactor: reorg

* refactor: support setting explicit NodeType and HasData Filter

* build: changelog
  • Loading branch information
doctrino committed May 3, 2024
1 parent d30432e commit 3cac7ab
Show file tree
Hide file tree
Showing 7 changed files with 256 additions and 43 deletions.
2 changes: 1 addition & 1 deletion cognite/neat/rules/issues/dms.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def dump(self) -> dict[str, Any]:
@dataclass(frozen=True)
class NodeTypeFilterOnParentViewWarning(DMSSchemaWarning):
description = (
"Setting a node type filter on a parent view. This is no "
"Setting a node type filter on a parent view. This is not "
"recommended as parent views are typically used for multiple type of nodes."
)
fix = "Use a HasData filter instead"
Expand Down
16 changes: 14 additions & 2 deletions cognite/neat/rules/models/entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from functools import total_ordering
from typing import Annotated, Any, ClassVar, Generic, TypeVar, cast

from cognite.client.data_classes.data_modeling.ids import ContainerId, DataModelId, PropertyId, ViewId
from cognite.client.data_classes.data_modeling.ids import ContainerId, DataModelId, NodeId, PropertyId, ViewId
from pydantic import AnyHttpUrl, BaseModel, BeforeValidator, Field, PlainSerializer, model_serializer, model_validator

if sys.version_info >= (3, 11):
Expand All @@ -30,6 +30,7 @@ class EntityTypes(StrEnum):
data_value_type = "data_value_type" # these are strings, floats, ...
xsd_value_type = "xsd_value_type"
dms_value_type = "dms_value_type"
dms_node = "dms_node"
view = "view"
reference_entity = "reference_entity"
container = "container"
Expand Down Expand Up @@ -248,7 +249,7 @@ def id(self) -> str:
return str(Unknown)


T_ID = TypeVar("T_ID", bound=ContainerId | ViewId | DataModelId | PropertyId | None)
T_ID = TypeVar("T_ID", bound=ContainerId | ViewId | DataModelId | PropertyId | NodeId | None)


class DMSEntity(Entity, Generic[T_ID], ABC):
Expand Down Expand Up @@ -376,6 +377,17 @@ def from_id(cls, id: DataModelId) -> "DataModelEntity":
return cls(space=id.space, externalId=id.external_id, version=id.version)


class DMSNodeEntity(DMSEntity[NodeId]):
type_: ClassVar[EntityTypes] = EntityTypes.dms_node

def as_id(self) -> NodeId:
return NodeId(space=self.space, external_id=self.external_id)

@classmethod
def from_id(cls, id: NodeId) -> "DMSNodeEntity":
return cls(space=id.space, externalId=id.external_id)


class ReferenceEntity(ClassEntity):
type_: ClassVar[EntityTypes] = EntityTypes.reference_entity
prefix: str
Expand Down
84 changes: 51 additions & 33 deletions cognite/neat/rules/models/rules/_dms_architect_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
ViewPropertyEntity,
)
from cognite.neat.rules.models.rules._domain_rules import DomainRules
from cognite.neat.rules.models.wrapped_entities import HasDataFilter, NodeTypeFilter

from ._base import BaseMetadata, BaseRules, ExtensionCategory, RoleTypes, SchemaCompleteness, SheetEntity, SheetList
from ._dms_schema import DMSSchema, PipelineSchema
Expand Down Expand Up @@ -253,7 +254,7 @@ class DMSView(SheetEntity):
view: ViewEntity = Field(alias="View")
implements: ViewEntityList | None = Field(None, alias="Implements")
reference: URLEntity | ReferenceEntity | None = Field(alias="Reference", default=None, union_mode="left_to_right")
filter_: Literal["hasData", "nodeType"] | None = Field(None, alias="Filter")
filter_: HasDataFilter | NodeTypeFilter | None = Field(None, alias="Filter")
in_model: bool = Field(True, alias="InModel")

def as_view(self) -> dm.ViewApply:
Expand Down Expand Up @@ -849,40 +850,18 @@ def _create_views_with_node_types(
node_types = dm.NodeApplyList([])
parent_views = {parent for view in views for parent in view.implements or []}
for view in views:
ref_containers = sorted(view.referenced_containers(), key=lambda c: c.as_tuple())
dms_view = dms_view_by_id.get(view.as_id())
has_data = dm.filters.HasData(containers=list(ref_containers)) if ref_containers else None
if dms_view and isinstance(dms_view.reference, ReferenceEntity):
# If the view is a reference, we implement the reference view,
# and need the filter to match the reference
ref_view = dms_view.reference.as_view_id()
node_type = dm.filters.Equals(
["node", "type"], {"space": ref_view.space, "externalId": ref_view.external_id}
)
else:
node_type = dm.filters.Equals(["node", "type"], {"space": view.space, "externalId": view.external_id})
if view.as_id() in parent_views:
if dms_view and dms_view.filter_ == "nodeType":
view.filter = self._create_view_filter(view, dms_view, parent_views)
if (
isinstance(view.filter, dm.filters.Equals)
and isinstance(view.filter._value, dict)
and (node_space := view.filter._value.get("space"))
and (node_ext_id := view.filter._value.get("externalId"))
):
node_types.append(dm.NodeApply(space=node_space, external_id=node_ext_id))
if view.as_id() in parent_views:
warnings.warn(issues.dms.NodeTypeFilterOnParentViewWarning(view.as_id()), stacklevel=2)
view.filter = node_type
node_types.append(dm.NodeApply(space=view.space, external_id=view.external_id, sources=[]))
else:
view.filter = has_data
elif has_data is None:
# Child filter without container properties
if dms_view and dms_view.filter_ == "hasData":
warnings.warn(issues.dms.HasDataFilterOnNoPropertiesViewWarning(view.as_id()), stacklevel=2)
view.filter = node_type
node_types.append(dm.NodeApply(space=view.space, external_id=view.external_id, sources=[]))
else:
if dms_view and (dms_view.filter_ == "hasData" or dms_view.filter_ is None):
# Default option
view.filter = has_data
elif dms_view and dms_view.filter_ == "nodeType":
view.filter = node_type
node_types.append(dm.NodeApply(space=view.space, external_id=view.external_id, sources=[]))
else:
view.filter = has_data

return views, node_types

def _create_containers(
Expand Down Expand Up @@ -983,6 +962,45 @@ def _gather_properties(

return container_properties_by_id, view_properties_by_id

def _create_view_filter(
self, view: dm.ViewApply, dms_view: DMSView | None, parent_views: set[dm.ViewId]
) -> dm.Filter:
selected_filter_name = (dms_view and dms_view.filter_ and dms_view.filter_.name) or ""
if dms_view and dms_view.filter_ and not dms_view.filter_.is_empty:
# Has Explicit Filter
return dms_view.filter_.as_filter()

ref_containers = view.referenced_containers()
has_data = dm.filters.HasData(containers=list(ref_containers)) if ref_containers else None
if dms_view and isinstance(dms_view.reference, ReferenceEntity):
# If the view is a reference, we implement the reference view,
# and need the filter to match the reference
ref_view = dms_view.reference.as_view_id()
node_type = dm.filters.Equals(
["node", "type"], {"space": ref_view.space, "externalId": ref_view.external_id}
)
else:
node_type = dm.filters.Equals(["node", "type"], {"space": view.space, "externalId": view.external_id})

if view.as_id() in parent_views:
if selected_filter_name == "nodeType":
return node_type
else:
return cast(dm.Filter, has_data)
elif has_data is None:
# Child filter without container properties
if selected_filter_name == "hasData":
warnings.warn(issues.dms.HasDataFilterOnNoPropertiesViewWarning(view.as_id()), stacklevel=2)
return node_type
else:
if dms_view and ((selected_filter_name == "hasData") or dms_view.filter_ is None):
# Default option
return has_data
elif selected_filter_name == "nodeType":
return node_type
else:
return has_data


class _DMSRulesConverter:
def __init__(self, dms: DMSRules):
Expand Down
133 changes: 133 additions & 0 deletions cognite/neat/rules/models/wrapped_entities.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
from abc import ABC, abstractmethod
from functools import total_ordering
from typing import Any, ClassVar, TypeVar

from cognite.client import data_modeling as dm
from cognite.client.data_classes.data_modeling import ContainerId, NodeId
from pydantic import BaseModel, model_serializer, model_validator

from cognite.neat.rules.models.entities import ContainerEntity, DMSNodeEntity, Entity


@total_ordering
class WrappedEntity(BaseModel, ABC):
name: ClassVar[str]
_inner_cls: ClassVar[type[Entity]]
_support_list: ClassVar[bool] = False
inner: Entity | list[Entity] | None

@classmethod
def load(cls: "type[T_WrappedEntity]", data: Any) -> "T_WrappedEntity":
if isinstance(data, cls):
return data
return cls.model_validate(data)

@model_validator(mode="before")
def _load(cls, data: Any) -> dict:
if isinstance(data, dict):
return data
elif not isinstance(data, str):
raise ValueError(f"Cannot load {cls.__name__} from {data}")
elif not data.casefold().startswith(cls.name.casefold()):
raise ValueError(f"Expected {cls.name} but got {data}")
result = cls._parse(data)
return result

@classmethod
def _parse(cls, data: str) -> dict:
if data.casefold() == cls.name.casefold():
return {"inner": None}
inner = data[len(cls.name) :].removeprefix("(").removesuffix(")")
if cls._support_list:
return {"inner": [cls._inner_cls.load(entry.strip()) for entry in inner.split(",")]}
return {"inner": cls._inner_cls.load(inner)}

@model_serializer(when_used="unless-none", return_type=str)
def as_str(self) -> str:
return str(self)

def __str__(self):
return self.id

@property
def id(self) -> str:
inner = self.as_tuple()[1:]
return f"{self.name}({','.join(inner)})"

@property
def is_empty(self) -> bool:
return self.inner is None or (isinstance(self.inner, list) and not self.inner)

def dump(self) -> str:
return str(self)

def as_tuple(self) -> tuple[str, ...]:
entities: list[str] = []
if isinstance(self.inner, Entity):
entities.append(str(self.inner))
elif isinstance(self.inner, list):
entities.extend(map(str, self.inner))
return self.name, *entities

def __lt__(self, other: object) -> bool:
if not isinstance(other, WrappedEntity):
return NotImplemented
return self.as_tuple() < other.as_tuple()

def __eq__(self, other: object) -> bool:
if not isinstance(other, WrappedEntity):
return NotImplemented
return self.as_tuple() == other.as_tuple()

def __hash__(self) -> int:
return hash(str(self))

def __repr__(self) -> str:
return self.id


T_WrappedEntity = TypeVar("T_WrappedEntity", bound=WrappedEntity)


class DMSFilter(WrappedEntity):
@abstractmethod
def as_filter(self, default: Any | None = None) -> dm.filters.Filter:
raise NotImplementedError


class NodeTypeFilter(DMSFilter):
name: ClassVar[str] = "nodeType"
_inner_cls: ClassVar[type[DMSNodeEntity]] = DMSNodeEntity
inner: DMSNodeEntity | None = None

def as_filter(self, default: NodeId | None = None) -> dm.Filter:
if self.inner is not None:
space = self.inner.space
external_id = self.inner.external_id
elif default is not None:
space = default.space
external_id = default.external_id
else:
raise ValueError("Empty nodeType filter, please provide a default node.")
return dm.filters.Equals(["node", "type"], {"space": space, "externalId": external_id})


class HasDataFilter(DMSFilter):
name: ClassVar[str] = "hasData"
_inner_cls: ClassVar[type[ContainerEntity]] = ContainerEntity
_support_list: ClassVar[bool] = True
inner: list[ContainerEntity] | None = None # type: ignore[assignment]

def as_filter(self, default: list[ContainerId] | None = None) -> dm.Filter:
containers: list[ContainerId]
if self.inner:
containers = [container.as_id() for container in self.inner]
elif default:
containers = default
else:
raise ValueError("Empty hasData filter, please provide a default containers.")

return dm.filters.HasData(
# Sorting to ensure deterministic order
containers=sorted(containers, key=lambda container: container.as_tuple()) # type: ignore[union-attr]
)
8 changes: 8 additions & 0 deletions docs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,14 @@ Changes are grouped as follows:
- `Fixed` for any bug fixes.
- `Security` in case of vulnerabilities.


## TBD
### Added
- In `DMSRules`, added support for setting containerId and nodeId in `View.Filter`. Earlier, only `nodeType` and
`hasData` were supported which always used an implicit `containerId` and `nodeId` respectively. Now, the user can
specify the node type and container id(s) by setting `nodeType(my_space:my_node_type)` and
`hasData(my_space:my_container_id, my_space:my_other_container_id)`.

## [0.75.9] - 04-05-24
### Improved
- Steps are now categorized as `current`, `legacy`, and `io` steps
Expand Down
14 changes: 7 additions & 7 deletions tests/tests_unit/rules/test_models/test_dms_architect_rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def rules_schema_tests_cases() -> Iterable[ParameterSet]:
),
]
),
node_types=dm.NodeApplyList([dm.NodeApply(space="my_space", external_id="WindFarm", sources=[])]),
node_types=dm.NodeApplyList([dm.NodeApply(space="my_space", external_id="WindFarm")]),
),
id="Two properties, one container, one view",
)
Expand Down Expand Up @@ -605,7 +605,7 @@ def rules_schema_tests_cases() -> Iterable[ParameterSet]:
),
node_types=dm.NodeApplyList(
[
dm.NodeApply(space="my_space", external_id="Activity", sources=[]),
dm.NodeApply(space="my_space", external_id="Activity"),
]
),
)
Expand Down Expand Up @@ -638,7 +638,7 @@ def rules_schema_tests_cases() -> Iterable[ParameterSet]:
)
],
views=[
DMSViewWrite(view="generating_unit", class_="generating_unit"),
DMSViewWrite(view="generating_unit", class_="generating_unit", filter_="NodeType(sp_other:wind_turbine)"),
],
containers=[
DMSContainerWrite(container="generating_unit", class_="generating_unit"),
Expand Down Expand Up @@ -672,7 +672,7 @@ def rules_schema_tests_cases() -> Iterable[ParameterSet]:
container_property_identifier="display_name",
),
},
filter=dm.filters.HasData(containers=[dm.ContainerId("my_space", "generating_unit")]),
filter=dm.filters.Equals(["node", "type"], {"space": "sp_other", "externalId": "wind_turbine"}),
),
]
),
Expand All @@ -685,12 +685,12 @@ def rules_schema_tests_cases() -> Iterable[ParameterSet]:
),
]
),
node_types=dm.NodeApplyList([]),
node_types=dm.NodeApplyList([dm.NodeApply(space="sp_other", external_id="wind_turbine")]),
)
yield pytest.param(
dms_rules,
expected_schema,
id="No casing standardization",
id="Explict set NodeType Filter",
)

dms_rules = DMSRulesWrite(
Expand Down Expand Up @@ -758,7 +758,7 @@ def rules_schema_tests_cases() -> Iterable[ParameterSet]:
)
]
),
node_types=dm.NodeApplyList([dm.NodeApply(space="sp_solution", external_id="Asset", sources=[])]),
node_types=dm.NodeApplyList([dm.NodeApply(space="sp_solution", external_id="Asset")]),
)

yield pytest.param(
Expand Down
Loading

0 comments on commit 3cac7ab

Please sign in to comment.