-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Added 100x faster vectorized version #39
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -60,6 +60,10 @@ def selectSources(self, sourceCat, matches=None): | |
""" | ||
!Return a catalog of sources: a subset of sourceCat. | ||
|
||
If sourceCat is contiguous in memory, will use vectorized tests for ~100x | ||
execution speed advantage over non-contiguous catalogs. This would be | ||
even faster if we didn't have to check footprints for multiple peaks. | ||
|
||
@param[in] sourceCat catalog of sources that may be sources | ||
(an lsst.afw.table.SourceCatalog) | ||
|
||
|
@@ -68,16 +72,24 @@ def selectSources(self, sourceCat, matches=None): | |
""" | ||
self._getSchemaKeys(sourceCat.schema) | ||
|
||
result = table.SourceCatalog(sourceCat.table) | ||
for source in sourceCat: | ||
if self._isGood(source) and not self._isBad(source): | ||
result.append(source) | ||
if sourceCat.isContiguous(): | ||
bad = reduce(lambda x, y: np.logical_or(x, sourceCat.get(y)), self.config.badFlags, False) | ||
good = self._isGood_vector(sourceCat) | ||
result = sourceCat[good & ~bad] | ||
else: | ||
result = table.SourceCatalog(sourceCat.table) | ||
for i, source in enumerate(sourceCat): | ||
if self._isGood(source) and not self._isBad(source): | ||
result.append(source) | ||
return Struct(sourceCat=result) | ||
|
||
def _getSchemaKeys(self, schema): | ||
"""Extract and save the necessary keys from schema with asKey.""" | ||
self.parentKey = schema["parent"].asKey() | ||
self.nChildKey = schema["deblend_nChild"].asKey() | ||
self.centroidXKey = schema["slot_Centroid_x"].asKey() | ||
self.centroidYKey = schema["slot_Centroid_y"].asKey() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I thought the field was "slot_Centroid" and it returned a tuple of the (y, x) centroid location. Does this also work? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. That works for individual sources, but I need the x/y centroids for the vectorized np.isfinite() test: there isn't an equivalent SourceCatalog thing to get a vector of tuples. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Small correction: when used as a single key |
||
self.centroidFlagKey = schema["slot_Centroid_flag"].asKey() | ||
|
||
self.edgeKey = schema["base_PixelFlags_flag_edge"].asKey() | ||
self.interpolatedCenterKey = schema["base_PixelFlags_flag_interpolatedCenter"].asKey() | ||
|
@@ -88,26 +100,66 @@ def _getSchemaKeys(self, schema): | |
self.fluxFlagKey = schema[fluxPrefix + "flag"].asKey() | ||
self.fluxSigmaKey = schema[fluxPrefix + "fluxSigma"].asKey() | ||
|
||
def _isMultiple_vector(self, sourceCat):
    """Return True for each source that is likely multiple sources.

    A source is flagged when the column read with self.parentKey is non-zero,
    the column read with self.nChildKey is non-zero, or its footprint has
    more than one peak.

    @param[in] sourceCat  catalog whose columns are read with get(key) and
                          whose records provide getFootprint()
    @return boolean array with one entry per source in sourceCat
    """
    multiple = (sourceCat.get(self.parentKey) != 0) | (sourceCat.get(self.nChildKey) != 0)
    # Footprints currently have to be examined on a source-by-source basis.
    for i, source in enumerate(sourceCat):
        if multiple[i]:
            # Already flagged by the column tests; skip the footprint lookup.
            continue
        footprint = source.getFootprint()
        multiple[i] = footprint is not None and len(footprint.getPeaks()) > 1
    return multiple
|
||
def _isMultiple(self, source):
    """Return True if source is likely multiple sources.

    The source counts as multiple when the value read with self.parentKey or
    self.nChildKey is non-zero, or when it has a footprint with more than
    one peak.
    """
    if source.get(self.parentKey) != 0:
        return True
    if source.get(self.nChildKey) != 0:
        return True
    footprint = source.getFootprint()
    if footprint is None:
        return False
    return len(footprint.getPeaks()) > 1
|
||
def _hasCentroid_vector(self, sourceCat):
    """Return True for each source that has a valid centroid.

    A centroid is valid when both its x and y columns are finite and the
    centroid flag column is not set.
    """
    finiteX = np.isfinite(sourceCat.get(self.centroidXKey))
    finiteY = np.isfinite(sourceCat.get(self.centroidYKey))
    unflagged = ~sourceCat.get(self.centroidFlagKey)
    return finiteX & finiteY & unflagged
|
||
def _hasCentroid(self, source):
    """Return True if the source has a valid centroid.

    The centroid is valid when every component returned by getCentroid() is
    finite and getCentroidFlag() is not set.
    """
    allFinite = np.all(np.isfinite(source.getCentroid()))
    return allFinite and not source.getCentroidFlag()
|
||
def _goodSN_vector(self, sourceCat):
    """Return True for each source that has Signal/Noise > config.minSnr.

    When config.minSnr <= 0 the test is disabled and every source passes.

    @param[in] sourceCat  catalog whose columns are read with get(key); it
                          must also support len() so an all-True result can
                          be sized
    @return boolean array with one entry per source in sourceCat
    """
    if self.config.minSnr <= 0:
        # Return a real boolean array rather than a bare True so the result
        # has the same type in both branches and array operators such as ~
        # behave as expected (on a Python bool, ~True is -2).
        return np.ones(len(sourceCat), dtype=bool)
    snr = sourceCat.get(self.fluxKey)/sourceCat.get(self.fluxSigmaKey)
    return snr > self.config.minSnr
|
||
def _goodSN(self, source):
    """Return True if source has Signal/Noise > config.minSnr.

    A non-positive config.minSnr disables the test, so every source passes.
    """
    if self.config.minSnr <= 0:
        return True
    snr = source.get(self.fluxKey)/source.get(self.fluxSigmaKey)
    return snr > self.config.minSnr
|
||
def _isUsable_vector(self, sourceCat):
    """
    Return True for each source that is usable for matching, even if it may
    have a poor centroid.

    For a source to be usable it must:
    - have a valid centroid
    - not be deblended
    - have a valid flux (of the type specified in this object's constructor)
    - have adequate signal-to-noise
    """
    centroidOk = self._hasCentroid_vector(sourceCat)
    notMultiple = ~self._isMultiple_vector(sourceCat)
    snrOk = self._goodSN_vector(sourceCat)
    fluxOk = ~sourceCat.get(self.fluxFlagKey)
    return centroidOk & notMultiple & snrOk & fluxOk
|
||
def _isUsable(self, source): | ||
""" | ||
Return True if the source is usable for matching, even if it may have a poor centroid. | ||
Return True if the source is usable for matching, even if it may have a | ||
poor centroid. | ||
|
||
For a source to be usable it must: | ||
- have a valid centroid | ||
|
@@ -120,6 +172,22 @@ def _isUsable(self, source): | |
and not source.get(self.fluxFlagKey) \ | ||
and self._goodSN(source) | ||
|
||
def _isGood_vector(self, sourceCat):
    """
    Return True for each source that is usable for matching and likely has a
    good centroid.

    The additional tests for a good centroid, beyond isUsable, are:
    - not interpolated in the center
    - not saturated
    - not near the edge
    """
    usable = self._isUsable_vector(sourceCat)
    unsaturated = ~sourceCat.get(self.saturatedKey)
    cleanCenter = ~sourceCat.get(self.interpolatedCenterKey)
    offEdge = ~sourceCat.get(self.edgeKey)
    return usable & unsaturated & cleanCenter & offEdge
|
||
def _isGood(self, source): | ||
""" | ||
Return True if source is usable for matching and likely has a good centroid. | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't like the code duplication of having vectorized and non-vectorized versions of every function, but I suppose that can't be helped.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I agree, but until we have a generic solution, I'd rather not break non-contiguous functionality.