Skip to content

Commit

Permalink
[NEAT-183] Refactor DataTypes (#419)
Browse files Browse the repository at this point in the history
* refactor: first pass

* refactor: better implementation

* refactor; handle dict

* refactor; handle dict

* refactor; serializable entity

* refactor; handle default space and version

* refactor: backwards compatible

* refactor: backwards compatible

* refactor: fix issue with id

* feat: ParentEntityList

* entities: backwards compatibility

* refactor: New literals

* style: happy mypy

* refactor; rename Literal to data types

* tests; updated test after renaming

* refactor; renaming
  • Loading branch information
doctrino committed Apr 29, 2024
1 parent d874ebf commit adca5dc
Show file tree
Hide file tree
Showing 2 changed files with 324 additions and 0 deletions.
259 changes: 259 additions & 0 deletions cognite/neat/rules/models/data_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
import sys
from datetime import date, datetime
from typing import Any, ClassVar

from cognite.client.data_classes import data_modeling as dms
from pydantic import BaseModel, model_serializer, model_validator
from pydantic.functional_validators import ModelWrapValidatorHandler

if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self


class DataType(BaseModel):
    """Base class for the value types a property can hold.

    Every concrete subclass describes one data type purely through class-level
    constants: its ``name`` plus the matching representation in Python, DMS,
    GraphQL, XSD and SQL. The model declares no fields, so all instances of
    one subclass are interchangeable; they compare equal and serialize to the
    type's ``name``.
    """

    name: ClassVar[str]
    python: ClassVar[type]
    dms: ClassVar[type[dms.PropertyType]]
    graphql: ClassVar[str]
    xsd: ClassVar[str]
    sql: ClassVar[str]

    @classmethod
    def load(cls, data: Any) -> Self:
        """Validate ``data`` into a data type instance.

        Args:
            data: An existing instance of ``cls``, a dict, or the
                (case-insensitive) name of a registered data type.

        Raises:
            ValueError: Surfaced through pydantic's validation error when the
                name is unknown or ``data`` has an unsupported type.
        """
        return cls.model_validate(data)

    def dump(self) -> str:
        """Return the serialized form of this data type, i.e. its name.

        The model serializer below always serializes the model to
        ``str(self)``, so the result is a string rather than a dict.
        """
        return str(self)

    @model_validator(mode="wrap")
    def _load(cls, value: Any, handler: ModelWrapValidatorHandler["DataType"]) -> Any:
        """Resolve the raw input to a data type instance.

        Strings are looked up case-insensitively among the registered data
        types, independent of which class the validation was started on.
        """
        if isinstance(value, cls):
            return value
        if isinstance(value, dict):
            # Delegate to pydantic's own validation. Returning the dict as-is
            # would hand the caller a plain dict instead of a model and, when
            # triggered from ``__init__``, skip the regular model set-up.
            return handler(value)
        if isinstance(value, str):
            try:
                return _DATA_TYPE_BY_NAME[value.casefold()]()
            except KeyError:
                raise ValueError(f"Unknown literal type: {value}") from None
        raise ValueError(f"Cannot load {cls.__name__} from {value}")

    @model_serializer(when_used="unless-none", return_type=str)
    def as_str(self) -> str:
        """Serialize the data type as its name."""
        return str(self)

    def __str__(self) -> str:
        return self.name

    def __eq__(self, other: Any) -> bool:
        # Instances are stateless: belonging to the same type means equal.
        return isinstance(other, type(self))

    def __hash__(self) -> int:
        # Defining __eq__ alone leaves the class unhashable. Instances carry
        # no state, so the concrete class is a sufficient hash key and equal
        # instances of the same concrete type hash equally.
        return hash(type(self))


class Boolean(DataType):
    """Boolean data type (Python ``bool``, DMS ``Boolean``)."""

    name = "boolean"
    python = bool
    dms = dms.Boolean
    graphql = "Boolean"
    xsd = "xsd:boolean"
    sql = "BOOLEAN"


class Float(DataType):
    """Floating point data type mapped to DMS ``Float32``."""

    name = "float"
    python = float
    dms = dms.Float32
    graphql = "Float"
    xsd = "xsd:float"
    sql = "FLOAT"


class Double(DataType):
    """Floating point data type mapped to DMS ``Float64``."""

    name = "double"
    python = float
    dms = dms.Float64
    graphql = "Float"
    xsd = "xsd:double"
    sql = "FLOAT"


class Integer(DataType):
    """Integer data type mapped to DMS ``Int32``."""

    name = "integer"
    python = int
    dms = dms.Int32
    graphql = "Int"
    xsd = "xsd:integer"
    sql = "INTEGER"


class NonPositiveInteger(DataType):
    """``xsd:nonPositiveInteger`` data type, mapped to DMS ``Int32``.

    Only the names differ from ``Integer``; this class holds constants and
    performs no range checking itself.
    """

    name = "nonPositiveInteger"
    python = int
    dms = dms.Int32
    graphql = "Int"
    xsd = "xsd:nonPositiveInteger"
    sql = "INTEGER"


class NonNegativeInteger(DataType):
    """``xsd:nonNegativeInteger`` data type, mapped to DMS ``Int32``.

    Only the names differ from ``Integer``; this class holds constants and
    performs no range checking itself.
    """

    name = "nonNegativeInteger"
    python = int
    dms = dms.Int32
    graphql = "Int"
    xsd = "xsd:nonNegativeInteger"
    sql = "INTEGER"


class NegativeInteger(DataType):
    """``xsd:negativeInteger`` data type, mapped to DMS ``Int32``.

    Only the names differ from ``Integer``; this class holds constants and
    performs no range checking itself.
    """

    name = "negativeInteger"
    python = int
    dms = dms.Int32
    graphql = "Int"
    xsd = "xsd:negativeInteger"
    sql = "INTEGER"


class Long(DataType):
    """Integer data type mapped to DMS ``Int64`` and SQL ``BIGINT``."""

    name = "long"
    python = int
    dms = dms.Int64
    graphql = "Int"
    xsd = "xsd:long"
    sql = "BIGINT"


class String(DataType):
    """Plain string data type (Python ``str``, DMS ``Text``)."""

    name = "string"
    python = str
    dms = dms.Text
    graphql = "String"
    xsd = "xsd:string"
    sql = "STRING"


class LangString(DataType):
    """``langString`` data type; represented as a plain string in every target."""

    name = "langString"
    python = str
    dms = dms.Text
    graphql = "String"
    xsd = "xsd:string"
    sql = "STRING"


class AnyURI(DataType):
    """``xsd:anyURI`` data type; stored as text (Python ``str``, DMS ``Text``)."""

    name = "anyURI"
    python = str
    dms = dms.Text
    graphql = "String"
    xsd = "xsd:anyURI"
    sql = "STRING"


class NormalizedString(DataType):
    """``xsd:normalizedString`` data type; stored as text (DMS ``Text``)."""

    name = "normalizedString"
    python = str
    dms = dms.Text
    graphql = "String"
    xsd = "xsd:normalizedString"
    sql = "STRING"


class Token(DataType):
    """``token`` data type; represented as a plain string in every target."""

    name = "token"
    python = str
    dms = dms.Text
    graphql = "String"
    xsd = "xsd:string"
    sql = "STRING"


class DateTime(DataType):
    """``dateTime`` data type (Python ``datetime``, DMS ``Timestamp``)."""

    name = "dateTime"
    python = datetime
    dms = dms.Timestamp
    graphql = "Timestamp"
    xsd = "xsd:dateTimeStamp"
    sql = "TIMESTAMP"


class DateTimeStamp(DataType):
    """``dateTimeStamp`` data type (Python ``datetime``, DMS ``Timestamp``)."""

    name = "dateTimeStamp"
    python = datetime
    dms = dms.Timestamp
    graphql = "Timestamp"
    xsd = "xsd:dateTimeStamp"
    sql = "TIMESTAMP"


class Timestamp(DataType):
    """``timestamp`` data type (Python ``datetime``, DMS ``Timestamp``)."""

    name = "timestamp"
    python = datetime
    dms = dms.Timestamp
    graphql = "Timestamp"
    xsd = "xsd:dateTimeStamp"
    sql = "TIMESTAMP"


class Date(DataType):
    """Calendar date data type (Python ``date``, DMS ``Date``).

    GraphQL has no dedicated mapping here; the date is exposed as ``String``.
    """

    name = "date"
    python = date
    dms = dms.Date
    graphql = "String"
    xsd = "xsd:date"
    sql = "DATE"


class PlainLiteral(DataType):
    """``PlainLiteral`` data type; represented as a plain string in every target."""

    name = "PlainLiteral"
    python = str
    dms = dms.Text
    graphql = "String"
    xsd = "xsd:string"
    sql = "STRING"


class Literal(DataType):
    """``Literal`` data type; represented as a plain string in every target.

    Note that this is one concrete data type, not the base class of the
    others -- all data types derive directly from ``DataType``.
    """

    name = "Literal"
    python = str
    dms = dms.Text
    graphql = "String"
    xsd = "xsd:string"
    sql = "STRING"


class Timeseries(DataType):
    """Reference to a time series (DMS ``TimeSeriesReference``)."""

    name = "timeseries"
    # Keep ``python`` above ``dms``: once ``dms`` is rebound in this class
    # body, the name no longer refers to the imported module.
    python = dms.TimeSeriesReference
    dms = dms.TimeSeriesReference
    graphql = "TimeSeries"
    xsd = "xsd:string"
    sql = "STRING"


class File(DataType):
    """Reference to a file (DMS ``FileReference``)."""

    name = "file"
    # Keep ``python`` above ``dms``: once ``dms`` is rebound in this class
    # body, the name no longer refers to the imported module.
    python = dms.FileReference
    dms = dms.FileReference
    graphql = "File"
    xsd = "xsd:string"
    sql = "STRING"


class Sequence(DataType):
    """Reference to a sequence (DMS ``SequenceReference``)."""

    name = "sequence"
    # Keep ``python`` above ``dms``: once ``dms`` is rebound in this class
    # body, the name no longer refers to the imported module.
    python = dms.SequenceReference
    dms = dms.SequenceReference
    graphql = "Sequence"
    xsd = "xsd:string"
    sql = "STRING"


class Json(DataType):
    """JSON data type (DMS ``Json``)."""

    name = "json"
    # Keep ``python`` above ``dms``: once ``dms`` is rebound in this class
    # body, the name no longer refers to the imported module.
    python = dms.Json
    dms = dms.Json
    graphql = "Json"
    xsd = "xsd:string"
    sql = "STRING"


# Lookup table from case-folded type name to data type class. It is built from
# the direct subclasses of DataType, so it must stay below every class
# definition in this module.
_DATA_TYPE_BY_NAME: dict[str, type[DataType]] = {
    data_type.name.casefold(): data_type for data_type in DataType.__subclasses__()
}
65 changes: 65 additions & 0 deletions tests/tests_unit/rules/test_models/test_data_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import Any

import pytest
from pydantic import BaseModel, Field

from cognite.neat.rules.models.data_types import (
Boolean,
DataType,
Double,
Float,
Integer,
Literal,
NonNegativeInteger,
NonPositiveInteger,
)
from cognite.neat.rules.models.entities import ClassEntity


class DemoProperty(BaseModel):
    """Minimal model for exercising ``DataType`` as part of a field type."""

    property_: str = Field(alias="property")
    # The value type is either a data type or a reference to a class.
    value_type: DataType | ClassEntity = Field(alias="valueType")

    def dump(self) -> dict[str, Any]:
        """Serialize the model, keyed by the field aliases."""
        serialized = self.model_dump(by_alias=True)
        return serialized


class TestLiterals:
    """Loading data types by name and round-tripping them through a model."""

    @pytest.mark.parametrize(
        "raw, expected",
        [
            ("boolean", Boolean()),
            ("float", Float()),
            ("double", Double()),
            ("integer", Integer()),
            ("nonPositiveInteger", NonPositiveInteger()),
            ("nonNegativeInteger", NonNegativeInteger()),
            ("Literal", Literal()),
            # Names are matched case-insensitively.
            ("BOOLEAN", Boolean()),
        ],
    )
    def test_load(self, raw: str, expected: DataType):
        # Load through the base class: ``Literal`` is just one concrete data
        # type, and loading other types through it only works by accident.
        loaded = DataType.load(raw)

        assert loaded == expected

    def test_load_unknown_name_raises(self):
        # pydantic's ValidationError is a ValueError subclass.
        with pytest.raises(ValueError, match="Unknown literal type"):
            DataType.load("not_a_data_type")

    @pytest.mark.parametrize(
        "raw, expected",
        [
            (
                {"property": "a_boolean", "valueType": "boolean"},
                DemoProperty(property="a_boolean", valueType=Boolean()),
            ),
            (
                {"property": "a_float", "valueType": "float"},
                DemoProperty(property="a_float", valueType=Float()),
            ),
            (
                {"property": "a_class", "valueType": "my_namespace:person"},
                DemoProperty(property="a_class", valueType=ClassEntity(prefix="my_namespace", suffix="person")),
            ),
        ],
    )
    def test_create_property(self, raw: dict[str, Any], expected: DemoProperty):
        loaded = DemoProperty.model_validate(raw)

        assert loaded == expected
        # Dumping must reproduce the raw input, i.e. value types serialize
        # back to their string form.
        assert loaded.dump() == raw

0 comments on commit adca5dc

Please sign in to comment.