Skip to content

Commit

Permalink
Expand short slices
Browse files Browse the repository at this point in the history
If a slice is shorter than the shape provided to `reformat_slices`,
assume that the user wanted to include everything in those remaining
dimensions. So simply tack on an `Ellipsis` to the slices. This way it
automatically includes everything else after the given slice(s).
  • Loading branch information
jakirkham committed Feb 20, 2017
1 parent c568f33 commit c12091d
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
9 changes: 7 additions & 2 deletions kenjutsu/format.py
Expand Up @@ -222,8 +222,13 @@ def reformat_slices(slices, lengths=None):
pass

if new_lengths is not None and el_idx is None:
if len(new_slices) != len(new_lengths):
raise ValueError("Shape must be the same as the number of slices.")
if len(new_slices) < len(new_lengths):
new_slices += (Ellipsis,)
el_idx = new_slices.index(Ellipsis)
elif len(new_slices) > len(new_lengths):
raise ValueError(
"Shape must be as large or larger than the number of slices."
)
elif new_lengths is not None:
if (len(new_slices) - 1) > len(new_lengths):
raise ValueError(
Expand Down
18 changes: 16 additions & 2 deletions tests/test_format.py
Expand Up @@ -270,11 +270,13 @@ def test_reformat_slice(self):

def test_reformat_slices(self):
with self.assertRaises(ValueError) as e:
format.reformat_slices((slice(None),), (1, 2))
format.reformat_slices(
(slice(None), slice(None)), (1,)
)

self.assertEqual(
str(e.exception),
"Shape must be the same as the number of slices."
"Shape must be as large or larger than the number of slices."
)

with self.assertRaises(ValueError) as e:
Expand Down Expand Up @@ -350,12 +352,24 @@ def test_reformat_slices(self):
(slice(0, 10, 1),)
)

rf_slice = format.reformat_slices(slice(None), (1, 2))
self.assertEqual(
rf_slice,
(slice(0, 1, 1), slice(0, 2, 1))
)

rf_slice = format.reformat_slices((slice(None),), 10)
self.assertEqual(
rf_slice,
(slice(0, 10, 1),)
)

rf_slice = format.reformat_slices((slice(None),), (1, 2))
self.assertEqual(
rf_slice,
(slice(0, 1, 1), slice(0, 2, 1))
)

rf_slice = format.reformat_slices((
-1,
slice(None),
Expand Down
8 changes: 0 additions & 8 deletions tests/test_kenjutsu.py
Expand Up @@ -195,14 +195,6 @@ def test_reformat_slice(self):


def test_reformat_slices(self):
with self.assertRaises(ValueError) as e:
kenjutsu.reformat_slices((slice(None),), (1, 2))

self.assertEqual(
str(e.exception),
"Shape must be the same as the number of slices."
)

with self.assertRaises(ValueError) as e:
kenjutsu.reformat_slices(
(slice(None), slice(None), Ellipsis), (1,)
Expand Down

0 comments on commit c12091d

Please sign in to comment.