Skip to content

Commit

Permalink
Merge pull request #161 from lsst/tickets/DM-27682
Browse files Browse the repository at this point in the history
DM-27682: Add ability to save/load QuantumGraph with URI
  • Loading branch information
timj committed Nov 23, 2020
2 parents 71224ee + 130e945 commit 73474ed
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 15 deletions.
57 changes: 56 additions & 1 deletion python/lsst/pipe/base/graph/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@

from ..connections import iterConnections
from ..pipeline import TaskDef
from lsst.daf.butler import Quantum, DatasetRef
from lsst.daf.butler import Quantum, DatasetRef, ButlerURI

from ._implDetails import _DatasetTracker, DatasetTypeName
from .quantumNode import QuantumNode, NodeId, BuildId
Expand Down Expand Up @@ -549,8 +549,63 @@ def findCycle(self) -> List[Tuple[QuantumNode, QuantumNode]]:
except nx.NetworkXNoCycle:
return []

def saveUri(self, uri):
    """Serialize this `QuantumGraph` and write it to the given URI.

    Parameters
    ----------
    uri : `ButlerURI` or `str`
        Destination to which the pickled graph is written.

    Raises
    ------
    TypeError
        Raised if the extension of ``uri`` is neither ``.pickle`` nor
        ``.pkl``; pickle is the only format handled here.
    """
    target = ButlerURI(uri)
    extension = target.getExtension()
    # Reject unsupported extensions before doing any serialization work.
    if extension != ".pickle" and extension != ".pkl":
        raise TypeError(f"Can currently only save a graph in pickle format not {target}")
    payload = pickle.dumps(self)
    target.write(payload)

@classmethod
def loadUri(cls, uri, universe):
    """Read `QuantumGraph` from a URI.

    Parameters
    ----------
    uri : `ButlerURI` or `str`
        URI from where to load the graph.
    universe : `~lsst.daf.butler.DimensionUniverse`
        DimensionUniverse instance, not used by the method itself but
        needed to ensure that registry data structures are initialized.

    Returns
    -------
    graph : `QuantumGraph`
        Resulting QuantumGraph instance.

    Raises
    ------
    TypeError
        Raised if pickle contains instance of a type other than
        QuantumGraph.

    Notes
    -----
    Reading Quanta from pickle requires existence of singleton
    DimensionUniverse which is usually instantiated during Registry
    initialization. To make sure that DimensionUniverse exists this method
    accepts dummy DimensionUniverse argument.

    The resource is unpickled, and unpickling can execute arbitrary code.
    Only load graphs from locations that are trusted.
    """
    uri = ButlerURI(uri)
    # With ButlerURI we have the choice of always using a local file
    # or reading in the bytes directly. Reading in bytes can be more
    # efficient for reasonably-sized pickle files when the resource
    # is remote. For now use the local file variant. For a local file
    # as_local() does nothing.
    with uri.as_local() as local, open(local.ospath, "rb") as fd:
        # SECURITY: pickle.load on data from an arbitrary URI is unsafe if
        # the source is not trusted; the type check below happens only
        # after the payload has already been unpickled.
        qgraph = pickle.load(fd)
    if not isinstance(qgraph, QuantumGraph):
        raise TypeError(f"QuantumGraph pickle file contains unexpected object type: {type(qgraph)}")
    return qgraph

def save(self, file):
"""Save QuantumGraph to a file.
Presently we store QuantumGraph in pickle format, this could
potentially change in the future if better format is found.
Expand Down
45 changes: 31 additions & 14 deletions tests/test_quantumGraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
# along with this program. If not, see <http://www.gnu.org/licenses/>.

from itertools import chain
import os
import pickle
import tempfile
import unittest
Expand Down Expand Up @@ -190,6 +191,16 @@ def setUp(self):
self.qGraph = QuantumGraph(quantumMap)
self.universe = universe

def _cleanGraphs(self, graph1, graph2):
    """Strip module qualification from the task names of two graphs.

    This is a unit-test hack: the qualified task name is ``__main__``
    based in the running test but qualified with the unittest module name
    once a graph has been restored, so only the last dotted component is
    kept on both sides. The keys of each graph's ``_quanta`` mapping are
    walked pairwise and updated in place; nothing is returned.
    """
    keyPairs = zip(graph1._quanta.keys(), graph2._quanta.keys())
    for first, second in keyPairs:
        first.taskName = first.taskName.rsplit('.', 1)[-1]
        second.taskName = second.taskName.rsplit('.', 1)[-1]

def testTaskGraph(self):
for taskDef in self.quantumMap.keys():
self.assertIn(taskDef, self.qGraph.taskGraph)
Expand All @@ -210,13 +221,7 @@ def testGetQuantumNodeByNodeId(self):
def testPickle(self):
    # Round-trip the graph through an in-memory pickle and check that the
    # restored graph compares equal to the one built in setUp.
    stringify = pickle.dumps(self.qGraph)
    restore: QuantumGraph = pickle.loads(stringify)
    # Task names are qualified differently before and after restoring
    # (__main__ here versus the unittest module name), so normalise both
    # graphs in place with the shared helper instead of repeating the
    # same loop inline.
    self._cleanGraphs(self.qGraph, restore)
    self.assertEqual(self.qGraph, restore)

def testInputQuanta(self):
Expand Down Expand Up @@ -333,15 +338,27 @@ def testSaveLoad(self):
self.qGraph.save(tmpFile)
tmpFile.seek(0)
restore = QuantumGraph.load(tmpFile, self.universe)
# This is a hack for the unit test since the qualified name will be
# different as it will be __main__ here, but qualified to the
# unittest module name when restored
for saved, loaded in zip(self.qGraph._quanta.keys(),
restore._quanta.keys()):
saved.taskName = saved.taskName.split('.')[-1]
loaded.taskName = loaded.taskName.split('.')[-1]
self._cleanGraphs(self.qGraph, restore)
self.assertEqual(self.qGraph, restore)

def testSaveLoadUri(self):
    # Save the graph to a temporary ".pickle" path via saveUri, read it
    # back with loadUri and check the round trip preserves equality.
    uri = None
    try:
        # delete=False keeps the file on disk after the handle is closed
        # so that saveUri/loadUri can reopen it by name.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pickle") as tmpFile:
            uri = tmpFile.name
        self.qGraph.saveUri(uri)
        restore = QuantumGraph.loadUri(uri, self.universe)
        self._cleanGraphs(self.qGraph, restore)
        self.assertEqual(self.qGraph, restore)
    finally:
        # Failures propagate unchanged; only the temporary file needs
        # cleaning up, and only if it was actually created.
        if uri is not None:
            os.remove(uri)

    # An extension other than .pickle/.pkl must be rejected.
    with self.assertRaises(TypeError):
        self.qGraph.saveUri("test.notpickle")

def testContains(self):
firstNode = next(iter(self.qGraph))
self.assertIn(firstNode, self.qGraph)
Expand Down

0 comments on commit 73474ed

Please sign in to comment.