Skip to content

Commit

Permalink
Merge 6ed4950 into b4f5a51
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham committed Dec 5, 2016
2 parents b4f5a51 + 6ed4950 commit 0cfac3e
Show file tree
Hide file tree
Showing 2 changed files with 284 additions and 12 deletions.
78 changes: 67 additions & 11 deletions kenjutsu/kenjutsu.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,9 @@ def reformat_slice(a_slice, a_length=None):
slice(2, 9, 1)
"""

if not isinstance(a_slice, slice):
if a_slice is Ellipsis:
a_slice = slice(None)
elif not isinstance(a_slice, slice):
raise ValueError(
"Expected a `slice` type. Instead got `%s`." % str(a_slice)
)
Expand Down Expand Up @@ -158,22 +160,76 @@ def reformat_slices(slices, lengths=None):
new_slices = (new_slices,)

new_lengths = lengths
if new_lengths is None:
new_lengths = [None] * len(new_slices)

try:
len(new_lengths)
if new_lengths is not None:
len(new_lengths)
except TypeError:
new_lengths = (new_lengths,)

if len(new_slices) != len(new_lengths):
raise ValueError("There must be an equal number of slices to lengths.")
el_idx = None
try:
el_idx = new_slices.index(Ellipsis)
except ValueError:
pass

if new_lengths is not None and el_idx is None:
if len(new_slices) != len(new_lengths):
raise ValueError("Shape must be the same as the number of slices.")
elif new_lengths is not None:
if (len(new_slices) - 1) > len(new_lengths):
raise ValueError(
"Shape must be as large or larger than the number of slices"
" without the Ellipsis."
)

if el_idx is not None:
# Break into three cases.
#
# 1. Before the Ellipsis
# 2. The Ellipsis
# 3. After the Ellipsis
#
# Cases 1 and 3 are trivially solved as before.
# Case 2 is either a no-op or a bunch of `slice(None)`s.
#
# The result is a combination of all of these.

slices_before = new_slices[:el_idx]
slices_after = new_slices[el_idx+1:]

if Ellipsis in slices_before or Ellipsis in slices_after:
raise ValueError("Only one Ellipsis is permitted. Found multiple.")

new_lengths_before = None
new_lengths_after = None
slice_el = (Ellipsis,)
if new_lengths is not None:
pos_before = len(slices_before)
pos_after = len(new_lengths) - len(slices_after)

new_lengths_before = new_lengths[:pos_before]
new_lengths_after = new_lengths[pos_after:]

new_lengths_el = new_lengths[pos_before:pos_after]
slice_el = reformat_slices(
len(new_lengths_el) * (slice(None),),
new_lengths_el
)

new_slices = (
reformat_slices(slices_before, new_lengths_before) +
slice_el +
reformat_slices(slices_after, new_lengths_after)
)
else:
if new_lengths is None:
new_lengths = [None] * len(new_slices)

new_slices = list(new_slices)
for i, each_length in enumerate(new_lengths):
new_slices[i] = reformat_slice(new_slices[i], each_length)
new_slices = list(new_slices)
for i, each_length in enumerate(new_lengths):
new_slices[i] = reformat_slice(new_slices[i], each_length)

new_slices = tuple(new_slices)
new_slices = tuple(new_slices)

return(new_slices)

Expand Down
218 changes: 217 additions & 1 deletion tests/test_kenjutsu.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,60 @@ def test_reformat_slice(self):
len(range(size)[a_slice])
)

rf_slice = kenjutsu.reformat_slice(Ellipsis)
self.assertEqual(
range(size)[:],
range(size)[rf_slice]
)

rf_slice = kenjutsu.reformat_slice(Ellipsis, size)
self.assertEqual(
range(size)[:],
range(size)[rf_slice]
)

start = rf_slice.start
stop = rf_slice.stop
step = rf_slice.step

if step is not None and step < 0 and stop is None:
stop = -1

l = float(stop - start)/float(step)
self.assertEqual(
int(math.ceil(l)),
len(range(size)[:])
)


def test_reformat_slices(self):
with self.assertRaises(ValueError) as e:
kenjutsu.reformat_slices((slice(None),), (1, 2))

self.assertEqual(
str(e.exception),
"There must be an equal number of slices to lengths."
"Shape must be the same as the number of slices."
)

with self.assertRaises(ValueError) as e:
kenjutsu.reformat_slices(
(slice(None), slice(None), Ellipsis), (1,)
)

self.assertEqual(
str(e.exception),
"Shape must be as large or larger than the number of slices"
" without the Ellipsis."
)

with self.assertRaises(ValueError) as e:
kenjutsu.reformat_slices(
(Ellipsis, Ellipsis), (1,)
)

self.assertEqual(
str(e.exception),
"Only one Ellipsis is permitted. Found multiple."
)

rf_slice = kenjutsu.reformat_slices(slice(None))
Expand All @@ -102,6 +148,18 @@ def test_reformat_slices(self):
(slice(0, None, 1),)
)

rf_slice = kenjutsu.reformat_slices(Ellipsis)
self.assertEqual(
rf_slice,
(Ellipsis,)
)

rf_slice = kenjutsu.reformat_slices(Ellipsis, 10)
self.assertEqual(
rf_slice,
(slice(0, 10, 1),)
)

rf_slice = kenjutsu.reformat_slices(slice(None), 10)
self.assertEqual(
rf_slice,
Expand Down Expand Up @@ -150,6 +208,92 @@ def test_reformat_slices(self):
)
)

rf_slice = kenjutsu.reformat_slices(
Ellipsis,
(2, 3, 4, 5)
)
self.assertEqual(
rf_slice,
(
slice(0, 2, 1),
slice(0, 3, 1),
slice(0, 4, 1),
slice(0, 5, 1)
)
)

rf_slice = kenjutsu.reformat_slices(
(
Ellipsis,
slice(0, 1)
),
(2, 3, 4, 5)
)
self.assertEqual(
rf_slice,
(
slice(0, 2, 1),
slice(0, 3, 1),
slice(0, 4, 1),
slice(0, 1, 1)
)
)

rf_slice = kenjutsu.reformat_slices(
(
slice(0, 1),
Ellipsis
),
(2, 3, 4, 5)
)
self.assertEqual(
rf_slice,
(
slice(0, 1, 1),
slice(0, 3, 1),
slice(0, 4, 1),
slice(0, 5, 1)
)
)

rf_slice = kenjutsu.reformat_slices(
(
slice(0, 1),
Ellipsis,
slice(0, 1)
),
(2, 3, 4, 5)
)
self.assertEqual(
rf_slice,
(
slice(0, 1, 1),
slice(0, 3, 1),
slice(0, 4, 1),
slice(0, 1, 1)
)
)

rf_slice = kenjutsu.reformat_slices(
(
slice(0, 1),
Ellipsis,
slice(0, 1),
slice(0, 1),
slice(0, 1)
),
(2, 3, 4, 5)
)
self.assertEqual(
rf_slice,
(
slice(0, 1, 1),
slice(0, 1, 1),
slice(0, 1, 1),
slice(0, 1, 1)
)
)


def test_len_slice(self):
with self.assertRaises(kenjutsu.UnknownSliceLengthException):
Expand All @@ -170,6 +314,11 @@ def test_len_slice(self):
len(range(size)[a_slice])
)

self.assertEqual(
kenjutsu.len_slice(Ellipsis, size),
len(range(size)[:])
)


def test_len_slices(self):
with self.assertRaises(kenjutsu.UnknownSliceLengthException):
Expand All @@ -180,6 +329,12 @@ def test_len_slices(self):
slice(None, None, 2)
))

l = kenjutsu.len_slices(Ellipsis, 10)
self.assertEqual(
l,
(10,)
)

l = kenjutsu.len_slices(slice(None), 10)
self.assertEqual(
l,
Expand All @@ -206,6 +361,67 @@ def test_len_slices(self):
(10, 10, 5, 10)
)

l = kenjutsu.len_slices(
Ellipsis,
(2, 3, 4, 5)
)
self.assertEqual(
l,
(2, 3, 4, 5)
)

l = kenjutsu.len_slices(
(
Ellipsis,
slice(0, 1)
),
(2, 3, 4, 5)
)
self.assertEqual(
l,
(2, 3, 4, 1)
)

l = kenjutsu.len_slices(
(
slice(0, 1),
Ellipsis
),
(2, 3, 4, 5)
)
self.assertEqual(
l,
(1, 3, 4, 5)
)

l = kenjutsu.len_slices(
(
slice(0, 1),
Ellipsis,
slice(0, 1)
),
(2, 3, 4, 5)
)
self.assertEqual(
l,
(1, 3, 4, 1)
)

l = kenjutsu.len_slices(
(
slice(0, 1),
Ellipsis,
slice(0, 1),
slice(0, 1),
slice(0, 1)
),
(2, 3, 4, 5)
)
self.assertEqual(
l,
(1, 1, 1, 1)
)


def test_split_blocks(self):
with self.assertRaises(ValueError) as e:
Expand Down

0 comments on commit 0cfac3e

Please sign in to comment.