Skip to content

Commit

Permalink
More fixes to ensure Data Modeling Client forward-compatibility (#1731)
Browse files Browse the repository at this point in the history
  • Loading branch information
erlendvollset committed Apr 22, 2024
1 parent e21dbdb commit c5f3703
Show file tree
Hide file tree
Showing 13 changed files with 162 additions and 106 deletions.
14 changes: 13 additions & 1 deletion cognite/client/data_classes/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
convert_timestamp_columns_to_datetime,
notebook_display_with_fallback,
)
from cognite.client.utils._text import convert_all_keys_to_camel_case, to_camel_case
from cognite.client.utils._text import convert_all_keys_recursive, convert_all_keys_to_camel_case, to_camel_case
from cognite.client.utils._time import TIME_ATTRIBUTES, convert_and_isoformat_time_attrs

if TYPE_CHECKING:
Expand Down Expand Up @@ -182,6 +182,18 @@ def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None =
return fast_dict_load(cls, resource, cognite_client=cognite_client)


class UnknownCogniteObject(CogniteObject):
def __init__(self, data: dict[str, Any]) -> None:
self.__data = data

@classmethod
def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None) -> Self:
return cls(resource)

def dump(self, camel_case: bool = True) -> dict[str, Any]:
return convert_all_keys_recursive(self.__data, camel_case=camel_case)


T_CogniteObject = TypeVar("T_CogniteObject", bound=CogniteObject)


Expand Down
83 changes: 41 additions & 42 deletions cognite/client/data_classes/aggregations.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Any,
ClassVar,
Iterator,
MutableSequence,
Sequence,
SupportsIndex,
TypeVar,
Expand All @@ -18,12 +19,11 @@
overload,
)

from typing_extensions import TypeAlias
from typing_extensions import Self, TypeAlias

from cognite.client.data_classes._base import CogniteObject, CogniteResourceList
from cognite.client.data_classes._base import CogniteObject, CogniteResourceList, UnknownCogniteObject
from cognite.client.data_classes.labels import Label
from cognite.client.utils._auxiliary import rename_and_exclude_keys
from cognite.client.utils._text import convert_all_keys_recursive, convert_all_keys_to_snake_case
from cognite.client.utils._text import convert_all_keys_recursive

if TYPE_CHECKING:
from cognite.client import CogniteClient
Expand All @@ -46,28 +46,26 @@ def _load(cls, resource: dict, cognite_client: CogniteClient | None = None) -> C


@dataclass
class Aggregation(ABC):
class Aggregation(CogniteObject, ABC):
_aggregation_name: ClassVar[str]

property: str

@classmethod
def load(cls, aggregation: dict[str, Any]) -> Aggregation:
(aggregation_name,) = aggregation
body = convert_all_keys_to_snake_case(aggregation[aggregation_name])
if aggregation_name == "avg":
return Avg(**body)
elif aggregation_name == "count":
return Count(**body)
elif aggregation_name == "max":
return Max(**body)
elif aggregation_name == "min":
return Min(**body)
elif aggregation_name == "sum":
return Sum(**body)
elif aggregation_name == "histogram":
return Histogram(**body)
raise ValueError(f"Unknown aggregation: {aggregation_name}")
def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None) -> Aggregation:
if "avg" in resource:
return Avg(property=resource["avg"]["property"])
elif "count" in resource:
return Count(property=resource["count"]["property"])
elif "max" in resource:
return Max(property=resource["max"]["property"])
elif "min" in resource:
return Min(property=resource["min"]["property"])
elif "sum" in resource:
return Sum(property=resource["sum"]["property"])
elif "histogram" in resource:
return Histogram(property=resource["histogram"]["property"], interval=resource["histogram"]["interval"])
return cast(Aggregation, UnknownCogniteObject(resource))

def dump(self, camel_case: bool = True) -> dict[str, Any]:
output = {self._aggregation_name: {"property": self.property}}
Expand All @@ -77,7 +75,7 @@ def dump(self, camel_case: bool = True) -> dict[str, Any]:


@dataclass
class MetricAggregation(Aggregation): ...
class MetricAggregation(Aggregation, ABC): ...


@final
Expand Down Expand Up @@ -123,38 +121,35 @@ def dump(self, camel_case: bool = True) -> dict[str, Any]:
return output


T_AggregatedValue = TypeVar("T_AggregatedValue", bound="AggregatedValue")


@dataclass
class AggregatedValue(ABC):
class AggregatedValue(CogniteObject, ABC):
_aggregate: ClassVar[str] = field(init=False)

property: str

@classmethod
def load(cls: type[T_AggregatedValue], aggregated_value: dict[str, Any]) -> T_AggregatedValue:
if "aggregate" not in aggregated_value:
def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None) -> Self:
if "aggregate" not in resource:
raise ValueError("Missing aggregate, this is required")
aggregate = aggregated_value["aggregate"]
aggregated_value = rename_and_exclude_keys(aggregated_value, exclude={"aggregate"})
body = convert_all_keys_to_snake_case(aggregated_value)
aggregate = resource["aggregate"]

if aggregate == "avg":
deserialized: AggregatedValue = AvgValue(**body)
deserialized: Any = AvgValue(property=resource["property"], value=resource["value"])
elif aggregate == "count":
deserialized = CountValue(**body)
deserialized = CountValue(property=resource["property"], value=resource["value"])
elif aggregate == "max":
deserialized = MaxValue(**body)
deserialized = MaxValue(property=resource["property"], value=resource["value"])
elif aggregate == "min":
deserialized = MinValue(**body)
deserialized = MinValue(property=resource["property"], value=resource["value"])
elif aggregate == "sum":
deserialized = SumValue(**body)
deserialized = SumValue(property=resource["property"], value=resource["value"])
elif aggregate == "histogram":
deserialized = HistogramValue(**body)
deserialized = HistogramValue(
property=resource["property"], interval=resource["interval"], buckets=resource["buckets"]
)
else:
raise ValueError(f"Unknown aggregation: {aggregate}")
return cast(T_AggregatedValue, deserialized)
deserialized = UnknownCogniteObject(resource)
return cast(Self, deserialized)

def dump(self, camel_case: bool = True) -> dict[str, Any]:
output = {"aggregate": self._aggregate, "property": self.property}
Expand Down Expand Up @@ -215,9 +210,13 @@ def dump(self, camel_case: bool = True) -> dict[str, Any]:
return {"start": self.start, "count": self.count}


class Buckets(UserList):
def __init__(self, items: Collection[Any]) -> None:
super().__init__([Bucket(**bucket) if isinstance(bucket, dict) else bucket for bucket in items])
class Buckets(UserList, MutableSequence[Bucket]):
def __init__(self, items: Collection[dict | Bucket]) -> None:
buckets = [
Bucket(start=bucket["start"], count=bucket["count"]) if isinstance(bucket, dict) else bucket
for bucket in items
]
super().__init__(buckets)

def dump(self, camel_case: bool = True) -> list[dict[str, Any]]:
return [bucket.dump(camel_case) for bucket in self.data]
Expand Down
5 changes: 3 additions & 2 deletions cognite/client/data_classes/data_modeling/containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
CogniteFilter,
CogniteObject,
CogniteResourceList,
UnknownCogniteObject,
WriteableCogniteResourceList,
)
from cognite.client.data_classes.data_modeling._validation import validate_data_modeling_identifier
Expand Down Expand Up @@ -288,7 +289,7 @@ def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None =
return cast(Self, RequiresConstraint.load(resource))
elif resource["constraintType"] == "uniqueness":
return cast(Self, UniquenessConstraint.load(resource))
raise ValueError(f"Invalid constraint type {resource['constraintType']}")
return cast(Self, UnknownCogniteObject(resource))

@abstractmethod
def dump(self, camel_case: bool = True) -> dict[str, str | dict]:
Expand Down Expand Up @@ -337,7 +338,7 @@ def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None =
return cast(Self, BTreeIndex.load(resource))
if resource["indexType"] == "inverted":
return cast(Self, InvertedIndex.load(resource))
raise ValueError(f"Invalid index type {resource['indexType']}")
return cast(Self, UnknownCogniteObject(resource))

@abstractmethod
def dump(self, camel_case: bool = True) -> dict[str, str | dict]:
Expand Down
54 changes: 27 additions & 27 deletions cognite/client/data_classes/data_modeling/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@

import logging
from abc import ABC
from dataclasses import asdict, dataclass, field
from typing import Any, ClassVar
from dataclasses import asdict, dataclass
from typing import TYPE_CHECKING, Any, ClassVar

from typing_extensions import Self
from typing_extensions import Self, TypeAlias

from cognite.client.data_classes._base import CogniteObject, UnknownCogniteObject
from cognite.client.data_classes.data_modeling.ids import ContainerId
from cognite.client.utils._auxiliary import rename_and_exclude_keys
from cognite.client.utils._text import convert_all_keys_recursive

if TYPE_CHECKING:
from cognite.client import CogniteClient

logger = logging.getLogger(__name__)

_PROPERTY_ALIAS = {"isList": "list", "is_list": "list"}
Expand Down Expand Up @@ -44,7 +48,7 @@ def as_tuple(self) -> tuple[str, str]:


@dataclass
class PropertyType(ABC):
class PropertyType(CogniteObject, ABC):
_type: ClassVar[str]
is_list: bool = False

Expand All @@ -65,50 +69,46 @@ def __load_unit_ref(data: dict) -> UnitReference | None:
return unit

@classmethod
def load(cls, data: dict) -> Self:
type_ = data["type"]
def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None) -> Self:
type_ = resource["type"]
obj: Any
if type_ == "text":
obj = Text(is_list=data["list"], collation=data["collation"])
obj = Text(is_list=resource["list"], collation=resource["collation"])
elif type_ == "boolean":
obj = Boolean(is_list=data["list"])
obj = Boolean(is_list=resource["list"])
elif type_ == "float32":
obj = Float32(is_list=data["list"], unit=cls.__load_unit_ref(data))
obj = Float32(is_list=resource["list"], unit=cls.__load_unit_ref(resource))
elif type_ == "float64":
obj = Float64(is_list=data["list"], unit=cls.__load_unit_ref(data))
obj = Float64(is_list=resource["list"], unit=cls.__load_unit_ref(resource))
elif type_ == "int32":
obj = Int32(is_list=data["list"], unit=cls.__load_unit_ref(data))
obj = Int32(is_list=resource["list"], unit=cls.__load_unit_ref(resource))
elif type_ == "int64":
obj = Int64(is_list=data["list"], unit=cls.__load_unit_ref(data))
obj = Int64(is_list=resource["list"], unit=cls.__load_unit_ref(resource))
elif type_ == "timestamp":
obj = Timestamp(is_list=data["list"])
obj = Timestamp(is_list=resource["list"])
elif type_ == "date":
obj = Date(is_list=data["list"])
obj = Date(is_list=resource["list"])
elif type_ == "json":
obj = Json(is_list=data["list"])
obj = Json(is_list=resource["list"])
elif type_ == "timeseries":
obj = TimeSeriesReference(is_list=data["list"])
obj = TimeSeriesReference(is_list=resource["list"])
elif type_ == "file":
obj = FileReference(is_list=data["list"])
obj = FileReference(is_list=resource["list"])
elif type_ == "sequence":
obj = SequenceReference(is_list=data["list"])
obj = SequenceReference(is_list=resource["list"])
elif type_ == "direct":
obj = DirectRelation(
container=ContainerId.load(container) if (container := data.get("container")) else None,
is_list=data["list"],
container=ContainerId.load(container) if (container := resource.get("container")) else None,
is_list=resource["list"],
)
else:
logger.warning(f"Unknown property type: {type_}")
obj = UnknownPropertyType(_data=data)
obj = UnknownCogniteObject(resource)
return obj


@dataclass
class UnknownPropertyType:
_data: dict[str, Any] = field(default_factory=dict)

def dump(self, camel_case: bool = True) -> dict[str, Any]:
return convert_all_keys_recursive(self._data, camel_case=camel_case)
# Kept around for backwards compatibility
UnknownPropertyType: TypeAlias = UnknownCogniteObject


@dataclass
Expand Down
16 changes: 6 additions & 10 deletions cognite/client/data_classes/data_modeling/instances.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@


@dataclass
class NodeOrEdgeData:
class NodeOrEdgeData(CogniteObject):
"""This represents the data values of a node or edge.
Args:
Expand All @@ -89,21 +89,21 @@ class NodeOrEdgeData:
properties: Mapping[str, PropertyValue]

@classmethod
def load(cls, data: dict) -> NodeOrEdgeData:
def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None) -> Self:
try:
source_type = data["source"]["type"]
source_type = resource["source"]["type"]
except KeyError as e:
raise ValueError("source must be a dict with a type key") from e
source: ContainerId | ViewId
if source_type == "container":
source = ContainerId.load(data["source"])
source = ContainerId.load(resource["source"])
elif source_type == "view":
source = ViewId.load(data["source"])
source = ViewId.load(resource["source"])
else:
raise ValueError(f"source type must be container or view, but was {source_type}")
return cls(
source=source,
properties=data["properties"],
properties=resource["properties"],
)

def dump(self, camel_case: bool = True) -> dict:
Expand Down Expand Up @@ -997,10 +997,6 @@ class InstancesResult:
nodes: NodeList
edges: EdgeList

@classmethod
def load(cls, data: str | dict) -> InstancesResult:
raise NotImplementedError


@dataclass
class InstancesApplyResult:
Expand Down
21 changes: 8 additions & 13 deletions cognite/client/data_classes/data_modeling/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,19 +156,14 @@ def load_yaml(cls, data: str) -> Query:

@classmethod
def _load(cls, resource: dict[str, Any], cognite_client: CogniteClient | None = None) -> Self:
if not (with_ := resource.get("with")):
raise ValueError("The query must contain a with key")

loaded: dict[str, Any] = {"with_": {k: ResultSetExpression.load(v) for k, v in with_.items()}}
if not (select := resource.get("select")):
raise ValueError("The query must contain a select key")
loaded["select"] = {k: Select.load(v) for k, v in select.items()}

if parameters := resource.get("parameters"):
loaded["parameters"] = dict(parameters.items())
if cursors := resource.get("cursors"):
loaded["cursors"] = dict(cursors.items())
return cls(**loaded)
parameters = dict(resource["parameters"].items()) if "parameters" in resource else None
cursors = dict(resource["cursors"].items()) if "cursors" in resource else None
return cls(
with_={k: ResultSetExpression.load(v) for k, v in resource["with"].items()},
select={k: Select.load(v) for k, v in resource["select"].items()},
parameters=parameters,
cursors=cursors,
)

def __eq__(self, other: Any) -> bool:
return type(other) is type(self) and self.dump() == other.dump()
Expand Down
Loading

0 comments on commit c5f3703

Please sign in to comment.