Skip to content

Commit

Permalink
gracefully handle nsubbundles>bundlesize
Browse files Browse the repository at this point in the history
  • Loading branch information
sbailey committed Oct 25, 2017
1 parent 625b596 commit 2d20748
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 6 deletions.
2 changes: 1 addition & 1 deletion doc/release_notes.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ specter Release Notes
0.8.1 (unreleased)
------------------

* No changes yet.
* Robust even if nsubbundles>bundlesize.

0.8.0 (2017-09-29)
------------------
Expand Down
9 changes: 7 additions & 2 deletions py/specter/extract/ex2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,10 @@ def ex2d(image, imageivar, psf, specmin, nspec, wavelengths, xyrange=None,
#- index of last spectrum, non-inclusive, i.e. python-style indexing
bundlehi = min(bundlelo+bundlesize, specmin+nspec)

iibundle, iiextract = split_bundle(bundlehi-bundlelo, nsubbundles)
nsub = min(bundlehi-bundlelo, nsubbundles)
iibundle, iiextract = split_bundle(bundlehi-bundlelo, nsub)

for subbundle_index in range(nsubbundles):
for subbundle_index in range(len(iiextract)):
speclo = bundlelo + iiextract[subbundle_index][0]
spechi = bundlelo + iiextract[subbundle_index][-1]+1
keep = np.in1d(iiextract[subbundle_index], iibundle[subbundle_index])
Expand Down Expand Up @@ -499,6 +500,10 @@ def split_bundle(bundlesize, n):
([array([0, 1, 2]), array([3, 4, 5]), array([6, 7, 8, 9])],
[array([0, 1, 2, 3]), array([2, 3, 4, 5, 6]), array([5, 6, 7, 8, 9])])
'''
if n > bundlesize:
raise ValueError('n={} should be less or equal to bundlesize={}'.format(
n, bundlesize))

#- initial partition into subbundles
n_per_subbundle = [len(x) for x in np.array_split(np.arange(bundlesize), n)]

Expand Down
12 changes: 9 additions & 3 deletions py/specter/test/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,9 +268,11 @@ def test_wave_off_image(self):
self.assertTrue( np.all(flux == flux) )

def test_subbundles(self):
for nsubbundles in (2,3):
flux, ivar, Rdata = ex2d(self.image, self.ivar, self.psf, 0, self.nspec,
self.ww, wavesize=len(self.ww)//5, nsubbundles=nsubbundles)
#- should work even if nsubbundles > bundlesize
for nsubbundles in (2,3, 2*self.nspec):
flux, ivar, Rdata = ex2d(self.image, self.ivar, self.psf, 0,
self.nspec, self.ww, wavesize=len(self.ww)//5,
bundlesize=self.nspec, nsubbundles=nsubbundles)

self.assertEqual(flux.shape, (self.nspec, len(self.ww)))
self.assertEqual(ivar.shape, (self.nspec, len(self.ww)))
Expand All @@ -297,5 +299,9 @@ def test_split_bundle(self):
self.assertEqual(len(iisub[0]), 25)
self.assertTrue(np.all(iisub[0] == iiextract[0]))

#- n>bundlesize isn't allowed
with self.assertRaises(ValueError):
iisub, iiextract = split_bundle(3, 7)

if __name__ == '__main__':
unittest.main()

0 comments on commit 2d20748

Please sign in to comment.