Skip to content

Commit

Permalink
Merge branch 'tickets/DM-26553'
Browse files Browse the repository at this point in the history
  • Loading branch information
morriscb committed Sep 28, 2020
2 parents f4d06a5 + 6f6a268 commit 8e9662f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 6 deletions.
8 changes: 3 additions & 5 deletions python/lsst/ap/pipe/createApFakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from lsst.pipe.base import PipelineTask, PipelineTaskConnections, Struct
import lsst.pipe.base.connectionTypes as connTypes
from lsst.pipe.tasks.insertFakes import InsertFakesConfig
from lsst.pipe.tasks.parquetTable import ParquetTable

__all__ = ["CreateRandomApFakesTask",
"CreateRandomApFakesConfig",
Expand All @@ -46,7 +45,7 @@ class CreateRandomApFakesConnections(PipelineTaskConnections,
fakeCat = connTypes.Output(
doc="Catalog of fake sources to draw inputs from.",
name="{CoaddName}Coadd_fakeSourceCat",
storageClass="Parquet",
storageClass="DataFrame",
dimensions=("tract", "skymap")
)

Expand Down Expand Up @@ -99,7 +98,7 @@ class CreateRandomApFakesConfig(
randomSeed = pexConfig.Field(
doc="Random seed to set for reproducible datasets",
dtype=int,
default=None,
default=1234,
)
visitSourceFlagCol = pexConfig.Field(
doc="Name of the column flagging objects for insertion into the visit "
Expand Down Expand Up @@ -176,8 +175,7 @@ def run(self, tractId, skyMap):
self.config.paBulge: np.ones(nFakes, dtype="float"),
self.config.sourceType: nFakes * ["star"]}

return Struct(
fakeCat=ParquetTable(dataFrame=pd.DataFrame(data=randData)))
return Struct(fakeCat=pd.DataFrame(data=randData))

def createRandomPositions(self, nFakes, boundingCircle, rng):
"""Create a set of spatially uniform randoms over the tract bounding
Expand Down
41 changes: 40 additions & 1 deletion tests/test_createApFakes.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@
#

import numpy as np
import shutil
import tempfile
import unittest

import lsst.daf.butler.tests as butlerTests
import lsst.geom as geom
from lsst.pipe.base import testUtils
import lsst.skymap as skyMap
import lsst.utils.tests

Expand Down Expand Up @@ -54,6 +58,41 @@ def setUp(self):
+ int(self.nSources * self.fraction))
self.rng = np.random.default_rng(1234)

def testRunQuantum(self):
"""Test the run quantum method with a gen3 butler.
"""
root = tempfile.mkdtemp()
dimensions = {"instrument": ["notACam"],
"skymap": ["deepCoadd_skyMap"],
"tract": [0, 42],
}
testRepo = butlerTests.makeTestRepo(root, dimensions)
fakesTask = CreateRandomApFakesTask()
connections = fakesTask.config.ConnectionsClass(
config=fakesTask.config)
butlerTests.addDatasetType(
testRepo,
connections.skyMap.name,
connections.skyMap.dimensions,
connections.skyMap.storageClass)
butlerTests.addDatasetType(
testRepo,
connections.fakeCat.name,
connections.fakeCat.dimensions,
connections.fakeCat.storageClass)

dataId = {"skymap": "deepCoadd_skyMap", "tract": 0}
butler = butlerTests.makeTestCollection(testRepo)
butler.put(self.simpleMap, "deepCoadd_skyMap", {"skymap": "deepCoadd_skyMap"})

quantum = testUtils.makeQuantum(
fakesTask, butler, dataId,
{key: dataId for key in {"skyMap", "fakeCat"}})
run = testUtils.runTestQuantum(fakesTask, butler, quantum, True)
# Actual input dataset omitted for simplicity
run.assert_called_once_with(tractId=dataId["tract"], skyMap=self.simpleMap)
shutil.rmtree(root, ignore_errors=True)

def testRun(self):
"""Test the run method.
"""
Expand All @@ -63,7 +102,7 @@ def testRun(self):
fakesTask = CreateRandomApFakesTask(config=fakesConfig)
bCircle = self.simpleMap.generateTract(self.tractId).getInnerSkyPolygon().getBoundingCircle()
result = fakesTask.run(self.tractId, self.simpleMap)
fakeCat = result.fakeCat.toDataFrame()
fakeCat = result.fakeCat
self.assertEqual(len(fakeCat), self.nSources)
for idx, row in fakeCat.iterrows():
self.assertTrue(
Expand Down

0 comments on commit 8e9662f

Please sign in to comment.