Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-25222: Fix handling of skip-existing in PreExecInit #55

Merged
merged 2 commits into from
Jun 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
60 changes: 40 additions & 20 deletions python/lsst/ctrl/mpexec/preExecInit.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,16 @@ def saveInitOutputs(self, graph):
# check if it is there already
_LOG.debug("Retrieving InitOutputs for task=%s key=%s dsTypeName=%s",
task, name, attribute.name)
objFromStore = self.butler.get(attribute.name, {})
if objFromStore is not None:
try:
objFromStore = self.butler.get(attribute.name, {})
# Types are supposed to be identical.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Datastore.get does check that the returned python object is an instance of the StorageClass python type. It does not require that types match directly. Your test here is more stringent than butler -- does it ever trigger?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have no idea; probably not, because we have not seen any questions or complaints about this. What we want to do is make sure that the objects are identical, but there is currently no way to do that for arbitrary types; checking that the types are identical is the minimum that we can do here.

# TODO: Check that object contents are identical too.
if type(objFromStore) is not type(initOutputVar):
raise TypeError(f"Stored initOutput object type {type(objFromStore)} "
f"is different from task-generated type "
f"{type(initOutputVar)} for task {taskDef}")
except LookupError:
pass
if objFromStore is None:
# butler will raise exception if dataset is already there
_LOG.debug("Saving InitOutputs for task=%s key=%s", task, name)
Expand Down Expand Up @@ -225,12 +227,14 @@ def logConfigMismatch(msg):

oldConfig = None
if self.skipExisting:
oldConfig = self.butler.get(configName, {})
if oldConfig is not None:
try:
oldConfig = self.butler.get(configName, {})
if not taskDef.config.compare(oldConfig, shortcut=False, output=logConfigMismatch):
raise TypeError(
f"Config does not match existing task config {configName!r} in butler; "
"tasks configurations must be consistent within the same run collection")
except LookupError:
pass
if oldConfig is None:
# butler will raise exception if dataset is already there
_LOG.debug("Saving Config for task=%s dataset type=%s", taskDef.label, configName)
Expand All @@ -251,20 +255,36 @@ def savePackageVersions(self, graph):
compatible.
"""
packages = Packages.fromSystem()
_LOG.debug("want to save packages: %s", packages)
datasetType = "packages"
oldPackages = self.butler.get(datasetType, {}) if self.skipExisting else None
if oldPackages is not None:
# Note that because we can only detect python modules that have been imported, the stored
# list of products may be more or less complete than what we have now. What's important is
# that the products that are in common have the same version.
diff = packages.difference(oldPackages)
if diff:
versions_str = "; ".join(f"{pkg}: {diff[pkg][1]} vs {diff[pkg][0]}" for pkg in diff)
raise TypeError(f"Package versions mismatch: ({versions_str})")
# Update the old set of packages in case we have more packages that haven't been persisted.
extra = packages.extra(oldPackages)
if extra:
oldPackages.update(packages)
self.butler.put(oldPackages, datasetType, {})
else:
self.butler.put(packages, datasetType, {})
dataId = {}
oldPackages = None
# start transaction to rollback any changes on exceptions
with self.butler.transaction():
if self.skipExisting:
try:
oldPackages = self.butler.get(datasetType, dataId, collections=[self.butler.run])
_LOG.debug("old packages: %s", oldPackages)
except LookupError:
pass
if oldPackages is not None:
# Note that because we can only detect python modules that have been imported, the stored
# list of products may be more or less complete than what we have now. What's important is
# that the products that are in common have the same version.
diff = packages.difference(oldPackages)
if diff:
versions_str = "; ".join(f"{pkg}: {diff[pkg][1]} vs {diff[pkg][0]}" for pkg in diff)
raise TypeError(f"Package versions mismatch: ({versions_str})")
else:
_LOG.debug("new packages are consistent with old")
# Update the old set of packages in case we have more packages that haven't been persisted.
extra = packages.extra(oldPackages)
if extra:
_LOG.debug("extra packages: %s", extra)
oldPackages.update(packages)
# have to remove existing dataset first, butler has no replace option
ref = self.butler.registry.findDataset(datasetType, dataId, collections=[self.butler.run])
self.butler.pruneDatasets([ref], unstore=True, purge=True)
self.butler.put(oldPackages, datasetType, dataId)
else:
self.butler.put(packages, datasetType, dataId)
2 changes: 1 addition & 1 deletion tests/testUtil.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def get(self, datasetRefOrType, dataId=None, parameters=None, **kwds):
dsdata = self.datasets.get(dsTypeName)
if dsdata:
return dsdata.get(key)
return None
raise LookupError

def put(self, obj, datasetRefOrType, dataId=None, producer=None, **kwds):
datasetType, dataId = self._standardizeArgs(datasetRefOrType, dataId, **kwds)
Expand Down
102 changes: 102 additions & 0 deletions tests/test_preExecInit.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
# This file is part of ctrl_mpexec.
#
# Developed for the LSST Data Management System.
# This product includes software developed by the LSST Project
# (https://www.lsst.org).
# See the COPYRIGHT file at the top-level directory of this distribution
# for details of code ownership.
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

"""Simple unit test for PreExecInit class.
"""

import unittest

from lsst.ctrl.mpexec import PreExecInit
from testUtil import makeSimpleQGraph, AddTaskFactoryMock


class PreExecInitTestCase(unittest.TestCase):
    """Unit tests exercising the save* methods of PreExecInit.

    Every test runs once with ``skipExisting=False`` and once with
    ``skipExisting=True``, each inside its own sub-test, using a fresh
    butler and quantum graph from ``makeSimpleQGraph``.
    """

    def test_saveInitOutputs(self):
        """A single saveInitOutputs call completes for both modes."""
        factory = AddTaskFactoryMock()
        for skip in (False, True):
            with self.subTest(skipExisting=skip):
                butler, graph = makeSimpleQGraph()
                init = PreExecInit(butler=butler, taskFactory=factory, skipExisting=skip)
                init.saveInitOutputs(graph)

    def test_saveInitOutputs_twice(self):
        """A repeated saveInitOutputs call raises unless skipExisting is set."""
        factory = AddTaskFactoryMock()
        for skip in (False, True):
            with self.subTest(skipExisting=skip):
                butler, graph = makeSimpleQGraph()
                init = PreExecInit(butler=butler, taskFactory=factory, skipExisting=skip)
                init.saveInitOutputs(graph)
                if not skip:
                    # without skipExisting the repeated call has to fail
                    with self.assertRaises(Exception):
                        init.saveInitOutputs(graph)
                else:
                    # with skipExisting the repeated call must not raise
                    init.saveInitOutputs(graph)

    def test_saveConfigs(self):
        """A single saveConfigs call completes for both modes."""
        for skip in (False, True):
            with self.subTest(skipExisting=skip):
                butler, graph = makeSimpleQGraph()
                init = PreExecInit(butler=butler, taskFactory=None, skipExisting=skip)
                init.saveConfigs(graph)

    def test_saveConfigs_twice(self):
        """A repeated saveConfigs call raises unless skipExisting is set."""
        for skip in (False, True):
            with self.subTest(skipExisting=skip):
                butler, graph = makeSimpleQGraph()
                init = PreExecInit(butler=butler, taskFactory=None, skipExisting=skip)
                init.saveConfigs(graph)
                if not skip:
                    # without skipExisting the repeated call has to fail
                    with self.assertRaises(Exception):
                        init.saveConfigs(graph)
                else:
                    # with skipExisting the repeated call must not raise
                    init.saveConfigs(graph)

    def test_savePackageVersions(self):
        """A single savePackageVersions call completes for both modes."""
        for skip in (False, True):
            with self.subTest(skipExisting=skip):
                butler, graph = makeSimpleQGraph()
                init = PreExecInit(butler=butler, taskFactory=None, skipExisting=skip)
                init.savePackageVersions(graph)

    def test_savePackageVersions_twice(self):
        """A repeated savePackageVersions call raises unless skipExisting is set."""
        for skip in (False, True):
            with self.subTest(skipExisting=skip):
                butler, graph = makeSimpleQGraph()
                init = PreExecInit(butler=butler, taskFactory=None, skipExisting=skip)
                init.savePackageVersions(graph)
                if not skip:
                    # without skipExisting the repeated call has to fail
                    with self.assertRaises(Exception):
                        init.savePackageVersions(graph)
                else:
                    # same packages both times, so the repeated call must
                    # complete without raising
                    init.savePackageVersions(graph)


# Run the tests in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()