Skip to content

Commit

Permalink
Support __iter__ and __getitem__ for dps classes with status codes
Browse files Browse the repository at this point in the history
  • Loading branch information
haakonvt committed Apr 11, 2024
1 parent 9333494 commit e6cc816
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 18 deletions.
54 changes: 40 additions & 14 deletions cognite/client/data_classes/datapoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,10 @@ def to_pandas(self, camel_case: bool = False) -> pandas.DataFrame: # type: igno

return pd.DataFrame(dumped, index=[pd.Timestamp(timestamp, unit="ms")])

def dump(self, camel_case: bool = True) -> dict[str, Any]:
    """Dump this datapoint to a dict, always including the ``value`` key.

    Args:
        camel_case (bool): Forwarded unchanged to the parent class' ``dump``.

    Returns:
        dict[str, Any]: The parent's dump, with ``value`` guaranteed to be present.
    """
    # Seed the result with value so it is kept even when it is None (datapoints
    # with a bad status code may be missing a value). Any keys produced by the
    # parent dump are applied afterwards and therefore take precedence.
    dumped: dict[str, Any] = {"value": self.value}
    dumped.update(super().dump(camel_case=camel_case))
    return dumped


class DatapointsArray(CogniteResource):
"""An object representing datapoints using numpy arrays."""
Expand Down Expand Up @@ -433,13 +437,23 @@ def __getitem__(self, item: int | slice) -> Datapoint | DatapointsArray:
if isinstance(item, slice):
return self._slice(item)
attrs, arrays = self._data_fields()
return Datapoint(
timestamp=arrays[0][item].item() // 1_000_000,
**{attr: numpy_dtype_fix(arr[item]) for attr, arr in zip(attrs[1:], arrays[1:])}, # type: ignore [arg-type]
)
timestamp = arrays[0][item].item() // 1_000_000
data = {attr: numpy_dtype_fix(arr[item]) for attr, arr in zip(attrs[1:], arrays[1:])}

if self.status_code is not None:
data.update(status_code=self.status_code[item], status_symbol=self.status_symbol[item]) # type: ignore [index]
if self.null_timestamps and timestamp in self.null_timestamps:
data["value"] = None # type: ignore [assignment]
return Datapoint(timestamp=timestamp, **data) # type: ignore [arg-type]

def _slice(self, part: slice) -> DatapointsArray:
    """Build a new DatapointsArray holding only the rows selected by ``part``.

    Args:
        part (slice): The slice to apply to every data array.

    Returns:
        DatapointsArray: A new object with the same time series info and sliced data.
    """
    attrs, arrays = self._data_fields()
    data: dict[str, Any] = {}
    for attr, arr in zip(attrs, arrays):
        data[attr] = arr[part]

    # Status-related fields are not returned by _data_fields(), so slice them separately:
    if self.status_code is not None:
        data["status_code"] = self.status_code[part]
        data["status_symbol"] = self.status_symbol[part]  # type: ignore [index]

    if self.null_timestamps is not None:
        # Only keep the null timestamps that are still present after slicing
        # (the sliced timestamps are converted to integer milliseconds for the comparison):
        kept_ms = data["timestamp"].astype("datetime64[ms]").astype(np.int64).tolist()
        data["null_timestamps"] = self.null_timestamps.intersection(kept_ms)

    return DatapointsArray(**self._ts_info, **data)

def __iter__(self) -> Iterator[Datapoint]:
Expand All @@ -461,13 +475,15 @@ def __iter__(self) -> Iterator[Datapoint]:
)
attrs, arrays = self._data_fields()
# Let's not create a single Datapoint more than we have to:
yield from (
Datapoint(
timestamp=row[0].item() // 1_000_000,
**dict(zip(attrs[1:], map(numpy_dtype_fix, row[1:]))), # type: ignore [arg-type]
)
for row in zip(*arrays)
)
for i, row in enumerate(zip(*arrays)):
timestamp = row[0].item() // 1_000_000
data = dict(zip(attrs[1:], map(numpy_dtype_fix, row[1:])))
if self.status_code is not None:
data.update(status_code=self.status_code[i], status_symbol=self.status_symbol[i]) # type: ignore [index]
if self.null_timestamps and timestamp in self.null_timestamps:
data["value"] = None # type: ignore [assignment]

yield Datapoint(timestamp=timestamp, **data) # type: ignore [arg-type]

def _data_fields(self) -> tuple[list[str], list[npt.NDArray]]:
# Note: Does not return status-related fields
Expand Down Expand Up @@ -707,6 +723,10 @@ def __getitem__(self, item: int | slice) -> Datapoint | Datapoints:
dp_args = {}
for attr, values in self._get_non_empty_data_fields():
dp_args[attr] = values[item]

if self.status_code is not None:
dp_args.update(status_code=self.status_code[item], status_symbol=self.status_symbol[item]) # type: ignore [index]

return Datapoint(**dp_args)

def __iter__(self) -> Iterator[Datapoint]:
Expand Down Expand Up @@ -740,9 +760,6 @@ def dump(self, camel_case: bool = True) -> dict[str, Any]:

for dp, code, symbol in zip(datapoints, self.status_code, self.status_symbol):
dp["status"] = {"code": code, "symbol": symbol}
# When we're dealing with status codes, bad can have missing values:
if "value" not in dp:
dp["value"] = None
dumped["datapoints"] = datapoints

if camel_case:
Expand Down Expand Up @@ -920,6 +937,11 @@ def __get_datapoint_objects(self) -> list[Datapoint]:
dp_args = {}
for attr, value in fields:
dp_args[attr] = value[i]
if self.status_code is not None:
dp_args.update(
status_code=self.status_code[i],
status_symbol=self.status_symbol[i], # type: ignore [index]
)
new_dps_objects.append(Datapoint(**dp_args))
self.__datapoint_objects = new_dps_objects
return self.__datapoint_objects
Expand All @@ -932,9 +954,13 @@ def _slice(self, slice: slice) -> Datapoints:
is_step=self.is_step,
unit=self.unit,
unit_external_id=self.unit_external_id,
granularity=self.granularity,
)
for attr, value in self._get_non_empty_data_fields():
setattr(truncated_datapoints, attr, value[slice])
if self.status_code is not None:
truncated_datapoints.status_code = self.status_code[slice]
truncated_datapoints.status_symbol = self.status_symbol[slice] # type: ignore [index]
return truncated_datapoints

def _repr_html_(self) -> str:
Expand Down
35 changes: 35 additions & 0 deletions tests/tests_integration/test_api/test_datapoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from cognite.client import CogniteClient
from cognite.client.data_classes import (
Datapoint,
Datapoints,
DatapointsArray,
DatapointsArrayList,
Expand Down Expand Up @@ -698,6 +699,40 @@ def test_numpy_dtypes_conversions_for_string_and_numeric(self, cognite_client, a
assert type(dp_dumped["timestamp"]) is int # noqa: E721
assert type(dp_dumped["value"]) is str # noqa: E721

def test_getitem_and_iter_preserves_status_codes(self, cognite_client, ts_status_codes, retrieve_endpoints):
    """Verify that iterating, indexing and slicing datapoints keep status codes and symbols.

    For every retrieve endpoint, five datapoints are fetched (bad datapoints included) and
    both the result and a slice of it are checked: each yielded Datapoint must carry the
    same status code/symbol as the parent container, and single-item access must agree
    with the container's value data.
    """
    mixed_ts, *_ = ts_status_codes
    for endpoint in retrieve_endpoints:
        dps_res = endpoint(
            id=mixed_ts.id, include_status=True, ignore_bad_datapoints=False, start=ts_to_ms("2023-02-11"), limit=5
        )
        # Test object itself, plus slice of object:
        for dps in [dps_res, dps_res[:5]]:
            for dp, code, symbol in zip(dps, dps.status_code, dps.status_symbol):
                assert isinstance(dp, Datapoint)
                assert code is not None and code == dp.status_code
                assert symbol is not None and symbol == dp.status_symbol

            assert math.isclose(dps.value[0], dps[0].value)
            assert math.isclose(dps.value[4], dps[4].value)
            assert math.isclose(dps.value[0], 432.9514228031592)
            assert math.isclose(dps.value[4], 143.05065712951188)

            assert dps.value[1] == dps[1].value == math.inf
            assert math.isnan(dps.value[2]) and math.isnan(dps[2].value)

            if isinstance(dps, Datapoints):
                assert dps.value[3] is None
            elif isinstance(dps, DatapointsArray):
                # The array stores NaN for the missing value and records its timestamp separately:
                assert math.isnan(dps.value[3])
                bad_ts = dps.timestamp[3].item() // 1_000_000
                assert dps.null_timestamps == {bad_ts}

                # Test slicing a part without a missing value:
                dps_slice = dps[:3]
                assert not dps_slice.null_timestamps
            else:
                # Raise explicitly: `assert False` is stripped when Python runs with -O
                # and gives no hint about which type was encountered.
                raise AssertionError(f"Unexpected datapoints type: {type(dps)}")

@pytest.mark.parametrize("test_is_string", (True, False))
def test_n_dps_retrieved_with_without_uncertain_and_bad(self, retrieve_endpoints, ts_status_codes, test_is_string):
if test_is_string:
Expand Down
8 changes: 7 additions & 1 deletion tests/tests_integration/test_cognite_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,11 @@ def test_delete(self, cognite_client):
def test_cognite_client_is_picklable(cognite_client):
if isinstance(cognite_client.config.credentials, (Token, OAuthClientCertificate)):
pytest.skip()
roundtrip_client = pickle.loads(pickle.dumps(cognite_client))
try:
roundtrip_client = pickle.loads(pickle.dumps(cognite_client))
except TypeError:
print(cognite_client) # noqa T201
print(type(cognite_client)) # noqa T201
print(vars(cognite_client)) # noqa T201
raise
assert roundtrip_client.iam.token.inspect().projects
6 changes: 3 additions & 3 deletions tests/tests_unit/test_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,14 +510,14 @@ def request_callback(request):
resource_cls=SomeResource,
resource_path=URL_PATH,
method="POST",
partitions=10,
partitions=15,
limit=None,
)
assert 503 == exc.value.code
assert exc.value.unknown == [("3/10",)]
assert exc.value.unknown == [("3/15",)]
assert exc.value.skipped
assert exc.value.successful
assert 9 == len(exc.value.successful) + len(exc.value.skipped)
assert 14 == len(exc.value.successful) + len(exc.value.skipped)
assert 1 < len(rsps.calls)

@pytest.fixture
Expand Down

0 comments on commit e6cc816

Please sign in to comment.