From c12091de43b207f404ce366c8c2eb4e2dbe80556 Mon Sep 17 00:00:00 2001 From: John Kirkham Date: Sun, 19 Feb 2017 20:10:48 -0500 Subject: [PATCH] Expand short slices If a slice is shorter than the shape provided to `reformat_slices`, assume that the user wanted to include everything in those remaining dimensions. So simply tack on an `Ellipsis` to the slices. This way it automatically includes everything else after the given slice(s). --- kenjutsu/format.py | 9 +++++++-- tests/test_format.py | 18 ++++++++++++++++-- tests/test_kenjutsu.py | 8 -------- 3 files changed, 23 insertions(+), 12 deletions(-) diff --git a/kenjutsu/format.py b/kenjutsu/format.py index 9847d2e..431b637 100644 --- a/kenjutsu/format.py +++ b/kenjutsu/format.py @@ -222,8 +222,13 @@ def reformat_slices(slices, lengths=None): pass if new_lengths is not None and el_idx is None: - if len(new_slices) != len(new_lengths): - raise ValueError("Shape must be the same as the number of slices.") + if len(new_slices) < len(new_lengths): + new_slices += (Ellipsis,) + el_idx = new_slices.index(Ellipsis) + elif len(new_slices) > len(new_lengths): + raise ValueError( + "Shape must be as large or larger than the number of slices." + ) elif new_lengths is not None: if (len(new_slices) - 1) > len(new_lengths): raise ValueError( diff --git a/tests/test_format.py b/tests/test_format.py index 88c856b..ca53aee 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -270,11 +270,13 @@ def test_reformat_slice(self): def test_reformat_slices(self): with self.assertRaises(ValueError) as e: - format.reformat_slices((slice(None),), (1, 2)) + format.reformat_slices( + (slice(None), slice(None)), (1,) + ) self.assertEqual( str(e.exception), - "Shape must be the same as the number of slices." + "Shape must be as large or larger than the number of slices." ) with self.assertRaises(ValueError) as e: @@ -350,12 +352,24 @@ def test_reformat_slices(self): (slice(0, 10, 1),) ) + rf_slice = format.reformat_slices(slice(None), (1, 2)) + self.assertEqual( + rf_slice, + (slice(0, 1, 1), slice(0, 2, 1)) + ) + rf_slice = format.reformat_slices((slice(None),), 10) self.assertEqual( rf_slice, (slice(0, 10, 1),) ) + rf_slice = format.reformat_slices((slice(None),), (1, 2)) + self.assertEqual( + rf_slice, + (slice(0, 1, 1), slice(0, 2, 1)) + ) + rf_slice = format.reformat_slices(( -1, slice(None), diff --git a/tests/test_kenjutsu.py b/tests/test_kenjutsu.py index 48e8b33..cda11f1 100644 --- a/tests/test_kenjutsu.py +++ b/tests/test_kenjutsu.py @@ -195,14 +195,6 @@ def test_reformat_slice(self): def test_reformat_slices(self): - with self.assertRaises(ValueError) as e: - kenjutsu.reformat_slices((slice(None),), (1, 2)) - - self.assertEqual( - str(e.exception), - "Shape must be the same as the number of slices." - ) - with self.assertRaises(ValueError) as e: kenjutsu.reformat_slices( (slice(None), slice(None), Ellipsis), (1,)