In [4]:
from lsst.pipe.base import PipelineTaskConnections, PipelineTaskConfig, PipelineTask

# Predicate for a connection's ``checkFunction``: it accepts a quantum only
# when at least ``n`` datasets are bound to that connection as an input.
class RequireAtLeast:
    """Minimum-count check for the datasets attached to one connection.

    Parameters
    ----------
    n : `int`
        Smallest acceptable number of input datasets.
    """

    def __init__(self, n):
        self.n = n

    def __call__(self, config, quantumInputRefs, connection):
        """Report whether ``quantumInputRefs[connection]`` has at least ``n`` entries.

        ``config`` is part of the call signature but is not consulted here.
        """
        refs = quantumInputRefs[connection]
        count = len(refs)
        return count >= self.n

class AssembleCoaddConnections(PipelineTaskConnections):
    """Butler connections for assembling a coadd from per-visit warps.

    The ``brightObjectMask`` prerequisite input and the ``nImage`` output are
    removed from the connections when the supplied config explicitly turns
    the corresponding feature off.
    """
    # Per-visit warps for this patch. Loading is deferred so the task fetches
    # each warp itself; RequireAtLeast(2) is supplied as the checkFunction.
    inputWarps = Input(name="{inputCoaddName}Coadd_{warpType}Warp",
                       storageClass="ExposureF",
                       dimensions=("tract", "patch", "skymap", "visit", "instrument"),
                       deferLoad=True,
                       multiple=True,
                       checkFunction=RequireAtLeast(2))
    skyMap = PrerequisiteInput(name="{inputCoaddName}Coadd_skyMap",
                               storageClass="SkyMap",
                               dimensions=("skymap",))
    brightObjectMask = PrerequisiteInput(name="brightObjectMask",
                                         storageClass="ObjectMaskCatalog",
                                         dimensions=("tract", "patch", "skymap", "abstract_filter"))
    coaddExposure = Output(name="{outputCoaddName}Coadd",
                           storageClass="ExposureF",
                           dimensions=("tract", "patch", "skymap", "abstract_filter"))
    nImage = Output(name="{outputCoaddName}Coadd_nImage",
                    storageClass="ImageU",
                    dimensions=("tract", "patch", "skymap", "abstract_filter"))
    dimensions = ("tract", "patch", "abstract_filter", "skymap")
    defaultTemplates = {"inputCoaddName": "deep", "outputCoaddName": "deep", "warpType": "direct"}

    def __init__(self, *, config=None):
        # Forward the caller's config; previously None was passed here, so the
        # base class never received the actual config.
        super().__init__(config=config)
        # This connection class may be paired with config classes that lack
        # these fields, so only prune when the field exists and is disabled.
        if hasattr(config, "doMaskBrightObjects") and not config.doMaskBrightObjects:
            # Drop the bright-object mask input when masking is disabled.
            self.prerequisiteInputs -= {"brightObjectMask"}
        if hasattr(config, "doNImage") and not config.doNImage:
            # Drop the nImage output when it will not be produced. Must be a
            # one-element set: set(("nImage")) iterated the string's characters
            # and so never removed "nImage".
            self.outputs -= {"nImage"}

In [None]:
class AssembleCoaddConfig(PipelineTaskConfig, pipelineConnections=AssembleCoaddConnections):
    """Configuration for `AssembleCoaddTask`.

    ``doMaskBrightObjects`` and ``doNImage`` are also read by
    `AssembleCoaddConnections` to decide whether the ``brightObjectMask``
    input and the ``nImage`` output stay in the connections.
    """

    doMaskBrightObjects = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Set mask and flag bits for bright objects",
    )
    # Placeholder: presumably further fields are elided from this excerpt.
    ...
    doNImage = pexConfig.Field(
        dtype=bool,
        default=False,
        doc="Create image of number of contributing exposures for each pixel",
    )

In [None]:
class AssembleCoaddTask(PipelineTask):
    def runQuantum(self, butlerQC, inputRefs, outputRefs):
        """Fetch this quantum's inputs, assemble the coadd, and write outputs.

        Parameters
        ----------
        butlerQC
            Butler-like handle; ``get`` loads the inputs and ``put`` writes
            the results for this quantum.
        inputRefs
            References to the quantum's inputs, passed to ``butlerQC.get``.
        outputRefs
            References to the quantum's outputs. The first reference's data ID
            supplies the tract and patch, and all of them are passed to
            ``butlerQC.put``.
        """
        inputs = butlerQC.get(inputRefs)
        skyMap = inputs["skyMap"]
        # Every output connection has tract/patch dimensions, so the first
        # output reference identifies the patch being built.
        # NOTE(review): assumes outputRefs supports .values(); confirm.
        outputDataRef = next(iter(outputRefs.values()))
        skyInfo = makeSkyInfo(skyMap,
                              tractId=outputDataRef.dataId['tract'],
                              patchId=outputDataRef.dataId['patch'])
        warps = inputs['inputWarps']
        # The helper defined on this class is spelled ``prepairInputs``;
        # calling ``prepareInputs`` referred to a method this class lacks.
        preparedInputs = self.prepairInputs(warps)
        # Arguments elided (placeholder) in this sketch.
        supplementaryData = self.makeSupplementaryDataGen3(...)
        retStruct = self.run(skyInfo, warps, preparedInputs.imageScalerList, preparedInputs.weightList,
                             supplementaryData=supplementaryData)
        self.processResults(retStruct.coaddExposure, inputs)
        # Use the ``butlerQC`` parameter; ``butlerQc`` was an undefined name
        # and raised NameError.
        butlerQC.put(retStruct, outputRefs)
    
    def prepairInputs(self, warps):
        """Begin preparing the per-warp inputs for coaddition.

        Configures a clipping ``afwMath.StatisticsControl`` from
        ``self.config.sigmaClip`` and ``self.config.clipIter``, then loops over
        the warps to build ``weightList`` and ``imageScalerList``. The rest of
        the body is elided (``...``) in this excerpt. ``runQuantum`` reads
        ``imageScalerList`` and ``weightList`` attributes from the result.

        NOTE(review): "prepair" looks like a misspelling of "prepare"; keep
        every caller in sync if this method is renamed.

        Parameters
        ----------
        warps
            Iterable of deferred-load handles for the input warps; each one
            is materialized with ``.get()`` inside the loop.
        """
        statsCtrl = afwMath.StatisticsControl()
        statsCtrl.setNumSigmaClip(self.config.sigmaClip)
        statsCtrl.setNumIter(self.config.clipIter)
        statsCtrl.setAndMask(self.getBadPixelMask())
        statsCtrl.setNanSafe(True)
        # weightList: the weight of each associated coadd tempExp (warp).
        # imageScalerList: the scale factor for each associated coadd tempExp.
        weightList = []
        imageScalerList = []
        tempExpName = self.getTempExpDatasetName(self.warpType)
        for warp in warps:
            tempExp = warp.get() # <- get from a deferred-load handle; allows sub-queries (such as subsetting) on the dataRef
            ...