diff --git a/kenjutsu/format.py b/kenjutsu/format.py index 9847d2e..431b637 100644 --- a/kenjutsu/format.py +++ b/kenjutsu/format.py @@ -222,8 +222,13 @@ def reformat_slices(slices, lengths=None): pass if new_lengths is not None and el_idx is None: - if len(new_slices) != len(new_lengths): - raise ValueError("Shape must be the same as the number of slices.") + if len(new_slices) < len(new_lengths): + new_slices += (Ellipsis,) + el_idx = new_slices.index(Ellipsis) + elif len(new_slices) > len(new_lengths): + raise ValueError( + "Shape must be as large or larger than the number of slices." + ) elif new_lengths is not None: if (len(new_slices) - 1) > len(new_lengths): raise ValueError( diff --git a/tests/test_format.py b/tests/test_format.py index 88c856b..ca53aee 100644 --- a/tests/test_format.py +++ b/tests/test_format.py @@ -270,11 +270,13 @@ def test_reformat_slice(self): def test_reformat_slices(self): with self.assertRaises(ValueError) as e: - format.reformat_slices((slice(None),), (1, 2)) + format.reformat_slices( + (slice(None), slice(None)), (1,) + ) self.assertEqual( str(e.exception), - "Shape must be the same as the number of slices." + "Shape must be as large or larger than the number of slices." ) with self.assertRaises(ValueError) as e: @@ -350,12 +352,24 @@ def test_reformat_slices(self): (slice(0, 10, 1),) ) + rf_slice = format.reformat_slices(slice(None), (1, 2)) + self.assertEqual( + rf_slice, + (slice(0, 1, 1), slice(0, 2, 1)) + ) + rf_slice = format.reformat_slices((slice(None),), 10) self.assertEqual( rf_slice, (slice(0, 10, 1),) ) + rf_slice = format.reformat_slices((slice(None),), (1, 2)) + self.assertEqual( + rf_slice, + (slice(0, 1, 1), slice(0, 2, 1)) + ) + rf_slice = format.reformat_slices(( -1, slice(None), diff --git a/tests/test_kenjutsu.py b/tests/test_kenjutsu.py index 48e8b33..cda11f1 100644 --- a/tests/test_kenjutsu.py +++ b/tests/test_kenjutsu.py @@ -195,14 +195,6 @@ def test_reformat_slice(self): def test_reformat_slices(self): - with self.assertRaises(ValueError) as e: - kenjutsu.reformat_slices((slice(None),), (1, 2)) - - self.assertEqual( - str(e.exception), - "Shape must be the same as the number of slices." - ) - with self.assertRaises(ValueError) as e: kenjutsu.reformat_slices( (slice(None), slice(None), Ellipsis), (1,)