From 2be2f5497420dac26253ab37ff4473d678c6125e Mon Sep 17 00:00:00 2001 From: Nat Wilson Date: Sat, 8 Oct 2016 12:57:00 -0700 Subject: [PATCH] add bandindexer shape tests backport to 0.7.x --- karta/raster/band.py | 2 +- tests/band_tests.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 1 deletion(-) diff --git a/karta/raster/band.py b/karta/raster/band.py index c304f9e..f2cd297 100644 --- a/karta/raster/band.py +++ b/karta/raster/band.py @@ -113,7 +113,7 @@ def shape(self): elif len(self.bands) == 1: return self.bands[0].size else: - return (len(self.bands), self.bands[0].size[0], self.bands.size[1]) + return (len(self.bands), self.bands[0].size[0], self.bands[0].size[1]) @property def dtype(self): diff --git a/tests/band_tests.py b/tests/band_tests.py index 4cb15ba..9336d05 100644 --- a/tests/band_tests.py +++ b/tests/band_tests.py @@ -155,5 +155,33 @@ def test_set_multibanded_masked(self): self.assertEqual(np.sum(indexer[:,:]), 336) return + def test_set_multibanded_masked_array(self): + values = np.ones([16, 16]) + bands = [CompressedBand((16, 16), np.float32), + CompressedBand((16, 16), np.float32), + CompressedBand((16, 16), np.float32)] + + mask = np.zeros([16, 16], dtype=np.bool) + mask[8:, 2:] = True + + indexer = BandIndexer(bands) + indexer[:,:] = np.zeros([16, 16]) + indexer[mask] = np.ones(8*14) + + self.assertEqual(np.sum(indexer[:,:]), 336) + return + + def test_shape(self): + bands = [CompressedBand((16, 16), np.float32), + CompressedBand((16, 16), np.float32), + CompressedBand((16, 16), np.float32)] + + indexer1 = BandIndexer([bands[0]]) + self.assertEqual(indexer1.shape, (16, 16)) + + indexer3 = BandIndexer(bands) + self.assertEqual(indexer3.shape, (3, 16, 16)) + return + if __name__ == "__main__": unittest.main()