Merge pull request #844 from lsst/tickets/DM-39434
DM-39434: Avoid defaultdict with lambda in pickled dataclasses
andy-slac committed May 31, 2023
2 parents ec51e16 + 719fcf9 commit 15ae3e5
Showing 4 changed files with 45 additions and 28 deletions.
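
For context on the commit title: pickle serializes a function by reference to an importable name, so a defaultdict whose default factory is a lambda cannot be pickled, and a dataclass field built from one inherits the problem. Below is a minimal sketch of the failure and of the plain-dict + setdefault replacement this diff applies throughout; the variable names are illustrative only.

import pickle
from collections import defaultdict

nested = defaultdict(lambda: defaultdict(list))
nested["id1"]["table"].append("record")
try:
    pickle.dumps(nested)
except (pickle.PicklingError, AttributeError) as exc:
    # The lambda has no importable name, so pickle raises here.
    print(f"pickling failed: {exc}")

# A plain dict with setdefault() autovivifies at the call site
# and pickles cleanly.
plain: dict[str, dict[str, list[str]]] = {}
plain.setdefault("id1", {}).setdefault("table", []).append("record")
assert pickle.loads(pickle.dumps(plain)) == {"id1": {"table": ["record"]}}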
48 changes: 24 additions & 24 deletions python/lsst/daf/butler/core/datastoreRecordData.py
@@ -27,8 +27,8 @@
 
 import dataclasses
 import uuid
-from collections import defaultdict
-from typing import TYPE_CHECKING, AbstractSet, Any, Dict, List, Optional, Union
+from collections.abc import Mapping
+from typing import TYPE_CHECKING, Any
 
 from lsst.utils import doImportType
 from lsst.utils.introspection import get_full_type_name
@@ -41,24 +41,24 @@
 if TYPE_CHECKING:
     from ..registry import Registry
 
-_Record = Dict[str, Any]
+_Record = dict[str, Any]
 
 
 class SerializedDatastoreRecordData(BaseModel):
     """Representation of a `DatastoreRecordData` suitable for serialization."""
 
-    dataset_ids: List[uuid.UUID]
+    dataset_ids: list[uuid.UUID]
     """List of dataset IDs"""
 
-    records: Dict[str, Dict[str, List[_Record]]]
+    records: Mapping[str, Mapping[str, list[_Record]]]
     """List of records indexed by record class name and table name."""
 
     @classmethod
     def direct(
         cls,
         *,
-        dataset_ids: List[Union[str, uuid.UUID]],
-        records: Dict[str, Dict[str, List[_Record]]],
+        dataset_ids: list[str | uuid.UUID],
+        records: dict[str, dict[str, list[_Record]]],
     ) -> SerializedDatastoreRecordData:
         """Construct a `SerializedDatastoreRecordData` directly without
         validators.
@@ -92,8 +92,8 @@ class DatastoreRecordData:
     datastore.
     """
 
-    records: defaultdict[DatasetId, defaultdict[str, List[StoredDatastoreItemInfo]]] = dataclasses.field(
-        default_factory=lambda: defaultdict(lambda: defaultdict(list))
+    records: dict[DatasetId, dict[str, list[StoredDatastoreItemInfo]]] = dataclasses.field(
+        default_factory=dict
     )
     """Opaque table data, indexed by dataset ID and grouped by opaque table
     name."""
@@ -111,11 +111,11 @@ def update(self, other: DatastoreRecordData) -> None:
         Merged instances can not have identical records.
         """
         for dataset_id, table_records in other.records.items():
-            this_table_records = self.records[dataset_id]
+            this_table_records = self.records.setdefault(dataset_id, {})
             for table_name, records in table_records.items():
-                this_table_records[table_name].extend(records)
+                this_table_records.setdefault(table_name, []).extend(records)
 
-    def subset(self, dataset_ids: AbstractSet[DatasetId]) -> Optional[DatastoreRecordData]:
+    def subset(self, dataset_ids: set[DatasetId]) -> DatastoreRecordData | None:
         """Extract a subset of the records that match given dataset IDs.
 
         Parameters
@@ -133,9 +133,7 @@ def subset(self, dataset_ids: AbstractSet[DatasetId]) -> Optional[DatastoreRecordData]:
         Records in the returned instance are shared with this instance, clients
         should not update or extend records in the returned instance.
         """
-        matching_records: defaultdict[
-            DatasetId, defaultdict[str, List[StoredDatastoreItemInfo]]
-        ] = defaultdict(lambda: defaultdict(list))
+        matching_records: dict[DatasetId, dict[str, list[StoredDatastoreItemInfo]]] = {}
         for dataset_id in dataset_ids:
             if (id_records := self.records.get(dataset_id)) is not None:
                 matching_records[dataset_id] = id_records
@@ -171,19 +169,22 @@ def _class_name(records: list[StoredDatastoreItemInfo]) -> str:
             assert len(classes) == 1, f"Records have to be of the same class: {classes}"
             return get_full_type_name(classes.pop())
 
-        records: defaultdict[str, defaultdict[str, List[_Record]]] = defaultdict(lambda: defaultdict(list))
+        records: dict[str, dict[str, list[_Record]]] = {}
         for table_data in self.records.values():
             for table_name, table_records in table_data.items():
                 class_name = _class_name(table_records)
-                records[class_name][table_name].extend([record.to_record() for record in table_records])
+                class_records = records.setdefault(class_name, {})
+                class_records.setdefault(table_name, []).extend(
+                    [record.to_record() for record in table_records]
+                )
         return SerializedDatastoreRecordData(dataset_ids=list(self.records.keys()), records=records)
 
     @classmethod
     def from_simple(
         cls,
         simple: SerializedDatastoreRecordData,
-        universe: Optional[DimensionUniverse] = None,
-        registry: Optional[Registry] = None,
+        universe: DimensionUniverse | None = None,
+        registry: Registry | None = None,
     ) -> DatastoreRecordData:
         """Make an instance of this class from serialized data.
 
@@ -203,17 +204,16 @@ def from_simple(
         item_info : `StoredDatastoreItemInfo`
             De-serialized instance of `StoredDatastoreItemInfo`.
         """
-        records: defaultdict[DatasetId, defaultdict[str, List[StoredDatastoreItemInfo]]] = defaultdict(
-            lambda: defaultdict(list)
-        )
+        records: dict[DatasetId, dict[str, list[StoredDatastoreItemInfo]]] = {}
         # make sure that all dataset IDs appear in the dict even if they don't
         # have records.
         for dataset_id in simple.dataset_ids:
-            records[dataset_id] = defaultdict(list)
+            records[dataset_id] = {}
         for class_name, table_data in simple.records.items():
             klass = doImportType(class_name)
             for table_name, table_records in table_data.items():
                 for record in table_records:
                     info = klass.from_record(record)
-                    records[info.dataset_id][table_name].append(info)
+                    dataset_type_records = records.setdefault(info.dataset_id, {})
+                    dataset_type_records.setdefault(table_name, []).append(info)
         return cls(records=records)
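
The paired setdefault calls above are the drop-in replacement for indexing into a defaultdict(lambda: defaultdict(list)): the first call creates the per-dataset dict on demand, the second creates the per-table list. A short, self-contained sketch of the pattern (names are illustrative):

records: dict[str, dict[str, list[int]]] = {}


def add_record(dataset_id: str, table_name: str, value: int) -> None:
    # Create intermediate containers only when first touched,
    # exactly as the nested defaultdict did implicitly.
    table_records = records.setdefault(dataset_id, {})
    table_records.setdefault(table_name, []).append(value)


add_record("id1", "t1", 1)
add_record("id1", "t1", 2)
add_record("id1", "t2", 3)
assert records == {"id1": {"t1": [1, 2], "t2": [3]}}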
3 changes: 3 additions & 0 deletions python/lsst/daf/butler/core/storedFileInfo.py
@@ -252,3 +252,6 @@ def update(self, **kwargs: Any) -> StoredFileInfo:
         if kwargs:
             raise ValueError(f"Unexpected keyword arguments for update: {', '.join(kwargs)}")
         return type(self)(**new_args)
+
+    def __reduce__(self) -> str | tuple[Any, ...]:
+        return (self.from_record, (self.to_record(),))
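
This __reduce__ tells pickle to rebuild the object by calling from_record on the output of to_record, so the pickle stream carries only a plain record dict rather than the instance's attributes. A minimal sketch of the same protocol on a hypothetical stand-in class (Info here is illustrative, not the real StoredFileInfo):

from __future__ import annotations

import pickle
from dataclasses import dataclass
from typing import Any


@dataclass(frozen=True)
class Info:
    path: str
    size: int

    def to_record(self) -> dict[str, Any]:
        return {"path": self.path, "size": self.size}

    @classmethod
    def from_record(cls, record: dict[str, Any]) -> Info:
        return cls(**record)

    def __reduce__(self) -> str | tuple[Any, ...]:
        # Unpickling calls Info.from_record({"path": ..., "size": ...}).
        return (self.from_record, (self.to_record(),))


info = Info("a/b.fits", 42)
assert pickle.loads(pickle.dumps(info)) == info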
7 changes: 3 additions & 4 deletions python/lsst/daf/butler/datastores/fileDatastore.py
@@ -2934,12 +2934,11 @@ def export_records(self, refs: Iterable[DatasetIdRef]) -> Mapping[str, DatastoreRecordData]:
         # Docstring inherited from the base class.
         exported_refs = list(self._bridge.check(refs))
         ids = {ref.id for ref in exported_refs}
-        records: defaultdict[DatasetId, defaultdict[str, List[StoredDatastoreItemInfo]]] = defaultdict(
-            lambda: defaultdict(list), {id: defaultdict(list) for id in ids}
-        )
+        records: dict[DatasetId, dict[str, list[StoredDatastoreItemInfo]]] = {id: {} for id in ids}
         for row in self._table.fetch(dataset_id=ids):
             info: StoredDatastoreItemInfo = StoredFileInfo.from_record(row)
-            records[info.dataset_id][self._table.name].append(info)
+            dataset_records = records.setdefault(info.dataset_id, {})
+            dataset_records.setdefault(self._table.name, []).append(info)
 
         record_data = DatastoreRecordData(records=records)
         return {self.name: record_data}
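
Note the comprehension {id: {} for id in ids} pre-seeds an empty entry for every exported ref, so a ref with no stored rows still appears in the result, mirroring the "all dataset IDs appear even if they don't have records" comment in from_simple. A quick sketch of that behavior (names and the table name are illustrative):

ids = {"ref-1", "ref-2"}
rows = [("ref-1", "some-info")]  # ref-2 has no stored rows

records: dict[str, dict[str, list[str]]] = {id: {} for id in ids}
for dataset_id, info in rows:
    records.setdefault(dataset_id, {}).setdefault("file_datastore_records", []).append(info)

# ref-2 is still present, with an empty per-table mapping.
assert records["ref-2"] == {}
assert records["ref-1"]["file_datastore_records"] == ["some-info"]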
15 changes: 15 additions & 0 deletions tests/test_datastore.py
@@ -22,12 +22,14 @@
 from __future__ import annotations
 
 import os
+import pickle
 import shutil
 import sys
 import tempfile
 import time
 import unittest
 import unittest.mock
+import uuid
 from collections import UserDict
 from dataclasses import dataclass
 
@@ -904,6 +906,14 @@ def testExportImportRecords(self):
         record_data = records[datastore_name]
         self.assertEqual(len(record_data.records), n_refs)
 
+        # Check that subsetting works, include non-existing dataset ID.
+        dataset_ids = {exported_refs[0].id, uuid.uuid4()}
+        subset = record_data.subset(dataset_ids)
+        assert subset is not None
+        self.assertEqual(len(subset.records), 1)
+        subset = record_data.subset({uuid.uuid4()})
+        self.assertIsNone(subset)
+
         # Use the same datastore name to import relative path.
         datastore2 = self.makeDatastore("test_datastore")
 
@@ -1931,6 +1941,11 @@ def test_StoredFileInfo(self):
         with self.assertRaises(ValueError):
             rebased.update(something=42, new="42")
 
+        # Check that pickle works on StoredFileInfo.
+        pickled_info = pickle.dumps(info)
+        unpickled_info = pickle.loads(pickled_info)
+        self.assertEqual(unpickled_info, info)
+
 
 if __name__ == "__main__":
     unittest.main()
