Adopt a new QuantumGraph save format
This ticket modifies QuantumGraph to support using UUIDs to
identify QuantumNodes and to serialize the graph without using
pickle.
natelust committed Dec 12, 2021
1 parent c527dea commit b2964b5
Showing 5 changed files with 997 additions and 276 deletions.
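The rest of this page shows only python/lsst/pipe/base/graph/_loadHelpers.py. The sketch below is illustrative only and is not part of the commit: it shows, assuming LoadHelper is driven directly as the dataclass defined in that file, how the reworked loader is intended to be used, with nodes identified by UUID (or their string form) rather than by integer node number. The uri, universe, and UUID values are placeholders.

# Illustrative sketch only -- not part of this commit.
from uuid import UUID

from lsst.daf.butler import DimensionUniverse
from lsst.resources import ResourcePath

from lsst.pipe.base.graph._loadHelpers import LoadHelper

uri = ResourcePath("file:///some/path/graph.qgraph")  # placeholder location
universe = DimensionUniverse()  # placeholder; typically taken from a butler registry

# minimumVersion=3 refuses saves older than version 3; -1 accepts any version.
helper = LoadHelper(uri, minimumVersion=3)

with helper as loaded:
    # nodes may be UUID objects or their string form; omit it to load all quanta
    qgraph = loaded.load(
        universe,
        nodes=[UUID("00000000-0000-0000-0000-000000000000")],  # placeholder id
    )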
244 changes: 93 additions & 151 deletions python/lsst/pipe/base/graph/_loadHelpers.py
@@ -19,27 +19,23 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
from __future__ import annotations
from uuid import UUID

__all__ = ("LoadHelper", )

from lsst.resources import ResourcePath
from lsst.daf.butler import Quantum
from lsst.resources.s3 import S3ResourcePath
from lsst.resources.file import FileResourcePath
from lsst.daf.butler import DimensionUniverse

from ..pipeline import TaskDef
from .quantumNode import NodeId

from dataclasses import dataclass
import functools
import io
import json
import lzma
import pickle
import struct

from collections import defaultdict, UserDict
from typing import (Optional, Iterable, DefaultDict, Set, Dict, TYPE_CHECKING, Tuple, Type, Union)
from collections import UserDict
from typing import Optional, Iterable, TYPE_CHECKING, Type, Union, IO

if TYPE_CHECKING:
from . import QuantumGraph
@@ -99,46 +95,38 @@ class DefaultLoadHelper:
uriObject : `~lsst.resources.ResourcePath` or `io.IO` of bytes
This is the object that will be used to retrieve the raw bytes of the
save.
minimumVersion : `int`
Minimum version of a save file to load. Set to -1 to load all
versions. Older versions may need to be loaded, and re-saved
to upgrade them to the latest format. This upgrade may not happen
deterministically each time an older graph format is loaded. Because
of this behavior, the minimumVersion parameter forces a user to
interact manually and take this into account before older graphs can
be used in production.
Raises
------
ValueError
Raised if the specified file contains the wrong file signature and is
not a `QuantumGraph` save
not a `QuantumGraph` save, or if the graph save version is below the
minimum specified version.
"""
def __init__(self, uriObject: Union[ResourcePath, io.IO[bytes]]):
self.uriObject = uriObject

# infoSize will either be a tuple of length 2 (version 1), which
# contains the lengths of 2 independent pickles, or a tuple of length 1
# which contains the total length of the entire header information
# (minus the magic bytes and version bytes)
preambleSize, infoSize = self._readSizes()

# Record the total header size
if self.save_version == 1:
self.headerSize = preambleSize + infoSize[0] + infoSize[1]
elif self.save_version == 2:
self.headerSize = preambleSize + infoSize[0]
else:
raise ValueError(f"Unable to load QuantumGraph with version {self.save_version}, "
"please try a newer version of the code.")

self._readByteMappings(preambleSize, self.headerSize, infoSize)
def __init__(self, uriObject: Union[ResourcePath, io.IO[bytes]], minimumVersion: int):
headerBytes = self.__setup_impl(uriObject, minimumVersion)
self.headerInfo = self.deserializer.readHeaderInfo(headerBytes)

def _readSizes(self) -> Tuple[int, Tuple[int, ...]]:
def __setup_impl(self, uriObject: Union[ResourcePath, io.IO[bytes]], minimumVersion: int) -> bytes:
self.uriObject = uriObject
# need to import here to avoid cyclic imports
from .graph import STRUCT_FMT_BASE, MAGIC_BYTES, STRUCT_FMT_STRING, SAVE_VERSION
# Read the first few bytes which correspond to the lengths of the
# magic identifier bytes, 2 byte version
# number and the two 8 bytes numbers that are the sizes of the byte
# maps
magicSize = len(MAGIC_BYTES)
from .graph import STRUCT_FMT_BASE, MAGIC_BYTES, SAVE_VERSION
from ._versionDeserializers import DESERIALIZER_MAP

# Read the first few bytes which correspond to the magic identifier
# bytes, and save version
magicSize = len(MAGIC_BYTES)
# read in just the fmt base to determine the save version
fmtSize = struct.calcsize(STRUCT_FMT_BASE)
preambleSize = magicSize + fmtSize

headerBytes = self._readBytes(0, preambleSize)
magic = headerBytes[:magicSize]
versionBytes = headerBytes[magicSize:]
@@ -147,61 +135,49 @@ def _readSizes(self) -> Tuple[int, Tuple[int, ...]]:
raise ValueError("This file does not appear to be a quantum graph save got magic bytes "
f"{magic}, expected {MAGIC_BYTES}")

# Turn the encoded bytes back into a python int object
# unpack the save version bytes and verify it is a version that this
# code can understand
save_version, = struct.unpack(STRUCT_FMT_BASE, versionBytes)
# loads can sometimes trigger upgrades in format to the latest version,
# in which case accessory code might not match the upgraded graph,
# e.g. when switching from old node numbers to UUIDs. This clause requires
# that users specifically interact with older graph versions and verify
# everything happens appropriately.
if save_version < minimumVersion:
raise ValueError(f"The loaded QuantumGraph is version {save_version}, and the minimum "
f"version specified is {minimumVersion}. Please re-run this method "
"with a lower minimum version, then re-save the graph to automatically upgrade"
"to the newest version. Older versions may not work correctly with newer code")

if save_version > SAVE_VERSION:
raise RuntimeError(f"The version of this save file is {save_version}, but this version of"
f"Quantum Graph software only knows how to read up to version {SAVE_VERSION}")

# read in the next bits
fmtString = STRUCT_FMT_STRING[save_version]
infoSize = struct.calcsize(fmtString)
infoBytes = self._readBytes(preambleSize, preambleSize+infoSize)
infoUnpack = struct.unpack(fmtString, infoBytes)

preambleSize += infoSize

# Store the save version, so future read codes can make use of any
# format changes to the save protocol
self.save_version = save_version

return preambleSize, infoUnpack

def _readByteMappings(self, preambleSize: int, headerSize: int, infoSize: Tuple[int, ...]) -> None:
# Take the header size explicitly so subclasses can modify before
# This task is called

# read the bytes of taskDef bytes and nodes skipping the size bytes
headerMaps = self._readBytes(preambleSize, headerSize)

if self.save_version == 1:
taskDefSize, _ = infoSize

# read the map of taskDef bytes back in skipping the size bytes
self.taskDefMap = pickle.loads(headerMaps[:taskDefSize])

# read back in the graph id
self._buildId = self.taskDefMap['__GraphBuildID']

# read the map of the node objects back in skipping bytes
# corresponding to the taskDef byte map
self.map = pickle.loads(headerMaps[taskDefSize:])

# There is no metadata for old versions
self.metadata = None
elif self.save_version == 2:
uncompressedHeaderMap = lzma.decompress(headerMaps)
header = json.loads(uncompressedHeaderMap)
self.taskDefMap = header['TaskDefs']
self._buildId = header['GraphBuildID']
self.map = dict(header['Nodes'])
self.metadata = header['Metadata']
else:
raise ValueError(f"Unable to load QuantumGraph with version {self.save_version}, "
"please try a newer version of the code.")

def load(self, nodes: Optional[Iterable[int]] = None, graphID: Optional[str] = None) -> QuantumGraph:
# select the appropriate deserializer for this save version
deserializerClass = DESERIALIZER_MAP[save_version]

# read in the bytes corresponding to the mappings and initialize the
# deserializer. These are the bytes that describe the byte boundaries
# of the header info that follows
sizeBytes = self._readBytes(preambleSize, preambleSize+deserializerClass.structSize)
self.deserializer = deserializerClass(preambleSize, sizeBytes)

# get the header info
headerBytes = self._readBytes(preambleSize+deserializerClass.structSize,
self.deserializer.headerSize)
return headerBytes
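# Illustrative note, not part of this commit: taken together, the reads in
# __setup_impl above assume an on-disk layout of roughly
#
#     MAGIC_BYTES | save version (STRUCT_FMT_BASE) | deserializer size struct
#     | header info (task defs, node byte map, metadata) | per-node data
#
# so a minimal check of just the save version could look like the following
# (placeholder path, for illustration only):
#
#     with open("graph.qgraph", "rb") as f:
#         magic = f.read(len(MAGIC_BYTES))
#         version, = struct.unpack(STRUCT_FMT_BASE,
#                                  f.read(struct.calcsize(STRUCT_FMT_BASE)))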

@classmethod
def dumpHeader(cls, uriObject: Union[ResourcePath, io.IO[bytes]], minimumVersion: int = 3
) -> Optional[str]:
instance = cls.__new__(cls)
headerBytes = instance.__setup_impl(uriObject, minimumVersion)
header = instance.deserializer.unpackHeader(headerBytes)
instance.close()
return header

def load(self, universe: DimensionUniverse, nodes: Optional[Iterable[Union[str, UUID]]] = None,
graphID: Optional[str] = None) -> QuantumGraph:
"""Loads in the specified nodes from the graph
Load in the `QuantumGraph` containing only the nodes specified in the
Expand All @@ -210,6 +186,9 @@ def load(self, nodes: Optional[Iterable[int]] = None, graphID: Optional[str] = N
Parameters
----------
universe : `~lsst.daf.butler.DimensionUniverse`
DimensionUniverse instance, not used by the method itself but
needed to ensure that registry data structures are initialized.
nodes : `Iterable` of `UUID` or `str`, or `None`
The nodes to load from the graph, loads all if value is None
(the default)
@@ -230,79 +209,27 @@ def load(self, nodes: Optional[Iterable[int]] = None, graphID: Optional[str] = N
`QuantumGraph` or if graphID parameter does not match the graph
being loaded.
"""
# need to import here to avoid cyclic imports
from . import QuantumGraph
if graphID is not None and self._buildId != graphID:
# verify this is the expected graph
if graphID is not None and self.headerInfo._buildId != graphID:
raise ValueError('graphID does not match that of the graph being loaded')
# Read in specified nodes, or all the nodes
if nodes is None:
nodes = list(self.map.keys())
nodes = list(self.headerInfo.map.keys())
# if all nodes are to be read, force the reader from the base class
# that will read all the bytes in one go
_readBytes = functools.partial(DefaultLoadHelper._readBytes, self)
else:
# only some bytes are being read using the reader specialized for
# this class
# create a set to ensure nodes are only loaded once
nodes = set(nodes)
nodes = {UUID(n) if isinstance(n, str) else n for n in nodes}
# verify that all nodes requested are in the graph
remainder = nodes - self.map.keys()
remainder = nodes - self.headerInfo.map.keys()
if remainder:
raise ValueError("Nodes {remainder} were requested, but could not be found in the input "
raise ValueError(f"Nodes {remainder} were requested, but could not be found in the input "
"graph")
_readBytes = self._readBytes
# create a container for loaded data
quanta: DefaultDict[TaskDef, Set[Quantum]] = defaultdict(set)
quantumToNodeId: Dict[Quantum, NodeId] = {}
loadedTaskDef = {}
# loop over the nodes specified above
for node in nodes:
# Get the bytes to read from the map
if self.save_version == 1:
start, stop = self.map[node]
else:
start, stop = self.map[node]['bytes']
start += self.headerSize
stop += self.headerSize

# read the specified bytes, will be overloaded by subclasses
# bytes are compressed, so decompress them
dump = lzma.decompress(_readBytes(start, stop))

# reconstruct node
qNode = pickle.loads(dump)

# read the saved node's task name. If it has already been loaded,
# attach it; if not, read in the taskDef first, and then load it
nodeTask = qNode.taskDef
if nodeTask not in loadedTaskDef:
# Get the byte ranges corresponding to this taskDef
if self.save_version == 1:
start, stop = self.taskDefMap[nodeTask]
else:
start, stop = self.taskDefMap[nodeTask]['bytes']
start += self.headerSize
stop += self.headerSize

# load the taskDef, this method call will be overloaded by
# subclasses.
# bytes are compressed, so decompress them
taskDef = pickle.loads(lzma.decompress(_readBytes(start, stop)))
loadedTaskDef[nodeTask] = taskDef
# Explicitly overload the "frozen-ness" of nodes to attach the
# taskDef back into the un-persisted node
object.__setattr__(qNode, 'taskDef', loadedTaskDef[nodeTask])
quanta[qNode.taskDef].add(qNode.quantum)

# record the node for later processing
quantumToNodeId[qNode.quantum] = qNode.nodeId

# construct an empty new QuantumGraph object, and run the associated
# creation method with the un-persisted data
qGraph = object.__new__(QuantumGraph)
qGraph._buildGraphs(quanta, _quantumToNodeId=quantumToNodeId, _buildId=self._buildId,
metadata=self.metadata)
return qGraph
return self.deserializer.constructGraph(nodes, _readBytes, universe)

def _readBytes(self, start: int, stop: int) -> bytes:
"""Loads the specified byte range from the ResourcePath object
@@ -369,10 +296,10 @@ class OpenFileHandleHelper(DefaultLoadHelper):

# This helper does support partial loading

def __init__(self, uriObject: io.IO[bytes]):
def __init__(self, uriObject: io.IO[bytes], minimumVersion: int):
# Explicitly annotate type and not infer from super
self.uriObject: io.IO[bytes]
super().__init__(uriObject)
super().__init__(uriObject, minimumVersion=minimumVersion)
# This differs from the default __init__ in that it forces the io
# object back to the beginning, so that if the entire file is to be
# read in, the file is not already in a partially read state.
@@ -398,19 +325,34 @@ class LoadHelper:
This helper will raise a `ValueError` if the specified file does not appear
to be a valid `QuantumGraph` save file.
"""
uri: ResourcePath
uri: Union[ResourcePath, IO[bytes]]
"""ResourcePath object from which the `QuantumGraph` is to be loaded
"""
minimumVersion: int
"""
Minimum version of a save file to load. Set to -1 to load all
versions. Older versions may need to be loaded, and re-saved
to upgrade them to the latest format before they can be used in
production.
"""

def __enter__(self):
# Only one handler is registered for anything that is an instance of
# IOBase, so if any type is a subtype of that, set the key explicitly
# so the correct loader is found, otherwise index by the type
self._loaded = self._determineLoader()(self.uri, self.minimumVersion)
return self._loaded

def __exit__(self, type, value, traceback):
self._loaded.close()

def _determineLoader(self) -> Type[DefaultLoadHelper]:
if isinstance(self.uri, io.IOBase):
key = io.IOBase
else:
key = type(self.uri)
self._loaded = HELPER_REGISTRY[key](self.uri)
return self._loaded
return HELPER_REGISTRY[key]

def __exit__(self, type, value, traceback):
self._loaded.close()
def readHeader(self) -> Optional[str]:
type_ = self._determineLoader()
return type_.dumpHeader(self.uri, self.minimumVersion)
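The sketch below is illustrative only and not part of the commit; uri is a placeholder. It shows how readHeader, defined above, might be used to inspect a save without loading any quanta; the return type is Optional[str], and what the text contains is left to the save version's deserializer.

helper = LoadHelper(uri, minimumVersion=3)

# Dump the saved header without reading any node data.
header_text = helper.readHeader()
if header_text is not None:
    print(header_text)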
