Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-27682: Add ability to save/load QuantumGraph with URI #161

Merged
merged 2 commits into from
Nov 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
57 changes: 56 additions & 1 deletion python/lsst/pipe/base/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from ..connections import iterConnections
from ..pipeline import TaskDef
from lsst.daf.butler import Quantum, DatasetRef
from lsst.daf.butler import Quantum, DatasetRef, ButlerURI

from ._implDetails import _DatasetTracker, DatasetTypeName
from .quantumNode import QuantumNode, NodeId, BuildId
Expand Down Expand Up @@ -549,8 +549,63 @@ def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
except nx.NetworkXNoCycle:
return []

def saveUri(self, uri):
    """Write this `QuantumGraph` to a URI as a pickle.

    Parameters
    ----------
    uri : `ButlerURI` or `str`
        Destination of the serialized graph. Its extension must be
        ``.pickle`` or ``.pkl``.

    Raises
    ------
    TypeError
        Raised if the extension of ``uri`` is not one of the accepted
        pickle extensions.
    """
    target = ButlerURI(uri)
    extension = target.getExtension()
    # Pickle is the only format handled here, so refuse anything else
    # before serializing.
    if extension != ".pickle" and extension != ".pkl":
        raise TypeError(f"Can currently only save a graph in pickle format not {target}")
    payload = pickle.dumps(self)
    target.write(payload)

@classmethod
def loadUri(cls, uri, universe):
    """Read `QuantumGraph` from a URI.

    Parameters
    ----------
    uri : `ButlerURI` or `str`
        URI from where to load the graph.
    universe : `~lsst.daf.butler.DimensionUniverse`
        DimensionUniverse instance, not used by the method itself but
        needed to ensure that registry data structures are initialized.

    Returns
    -------
    graph : `QuantumGraph`
        Resulting QuantumGraph instance.

    Raises
    ------
    TypeError
        Raised if pickle contains instance of a type other than
        QuantumGraph.

    Notes
    -----
    Reading Quanta from pickle requires existence of singleton
    DimensionUniverse which is usually instantiated during Registry
    initialization. To make sure that DimensionUniverse exists this method
    accepts dummy DimensionUniverse argument.

    The resource is deserialized with `pickle`, which can execute
    arbitrary code while loading. Only load graphs from trusted URIs.
    """
    uri = ButlerURI(uri)
    # With ButlerURI we have the choice of always using a local file
    # or reading in the bytes directly. Reading in bytes can be more
    # efficient for reasonably-sized pickle files when the resource
    # is remote. For now use the local file variant. For a local file
    # as_local() does nothing.
    with uri.as_local() as local, open(local.ospath, "rb") as fd:
        # NOTE(security): unpickling runs before the type check below, so
        # the check cannot protect against a malicious pickle.
        qgraph = pickle.load(fd)
    if not isinstance(qgraph, QuantumGraph):
        raise TypeError(f"QuantumGraph pickle file contains unexpected object type: {type(qgraph)}")
    return qgraph

def save(self, file):
"""Save QuantumGraph to a file.

Presently we store QuantumGraph in pickle format, this could
potentially change in the future if better format is found.

Expand Down
45 changes: 31 additions & 14 deletions tests/test_quantumGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from itertools import chain
import os
import pickle
import tempfile
import unittest
Expand Down Expand Up @@ -190,6 +191,16 @@ def setUp(self):
self.qGraph = QuantumGraph(quantumMap)
self.universe = universe

def _cleanGraphs(self, graph1, graph2):
# This is a hack for the unit test since the qualified name will be
# different as it will be __main__ here, but qualified to the
# unittest module name when restored
# Updates in place
for saved, loaded in zip(graph1._quanta.keys(),
graph2._quanta.keys()):
saved.taskName = saved.taskName.split('.')[-1]
loaded.taskName = loaded.taskName.split('.')[-1]

def testTaskGraph(self):
for taskDef in self.quantumMap.keys():
self.assertIn(taskDef, self.qGraph.taskGraph)
Expand All @@ -210,13 +221,7 @@ def testGetQuantumNodeByNodeId(self):
def testPickle(self):
stringify = pickle.dumps(self.qGraph)
restore: QuantumGraph = pickle.loads(stringify)
# This is a hack for the unit test since the qualified name will be
# different as it will be __main__ here, but qualified to the
# unittest module name when restored
for saved, loaded in zip(self.qGraph._quanta.keys(),
restore._quanta.keys()):
saved.taskName = saved.taskName.split('.')[-1]
loaded.taskName = loaded.taskName.split('.')[-1]
self._cleanGraphs(self.qGraph, restore)
self.assertEqual(self.qGraph, restore)

def testInputQuanta(self):
Expand Down Expand Up @@ -333,15 +338,27 @@ def testSaveLoad(self):
self.qGraph.save(tmpFile)
tmpFile.seek(0)
restore = QuantumGraph.load(tmpFile, self.universe)
# This is a hack for the unit test since the qualified name will be
# different as it will be __main__ here, but qualified to the
# unittest module name when restored
for saved, loaded in zip(self.qGraph._quanta.keys(),
restore._quanta.keys()):
saved.taskName = saved.taskName.split('.')[-1]
loaded.taskName = loaded.taskName.split('.')[-1]
self._cleanGraphs(self.qGraph, restore)
self.assertEqual(self.qGraph, restore)

def testSaveLoadUri(self):
    """Round-trip the graph through ``saveUri``/``loadUri`` and check that
    an unsupported extension is rejected.
    """
    uri = None
    try:
        # delete=False keeps the file on disk after the context manager
        # closes it; the finally clause below removes it explicitly.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pickle") as tmpFile:
            uri = tmpFile.name
            self.qGraph.saveUri(uri)
            restore = QuantumGraph.loadUri(uri, self.universe)
            self._cleanGraphs(self.qGraph, restore)
            self.assertEqual(self.qGraph, restore)
    finally:
        # uri stays None if the temporary file could not be created.
        if uri is not None:
            os.remove(uri)

    with self.assertRaises(TypeError):
        self.qGraph.saveUri("test.notpickle")

def testContains(self):
firstNode = next(iter(self.qGraph))
self.assertIn(firstNode, self.qGraph)
Expand Down