Skip to content

Commit

Permalink
Merge branch 'N8Brooks-master'
Browse files Browse the repository at this point in the history
  • Loading branch information
bbayles committed Jan 17, 2021
2 parents 3c6408e + 68f9c19 commit 1576344
Show file tree
Hide file tree
Showing 4 changed files with 300 additions and 0 deletions.
144 changes: 144 additions & 0 deletions more_itertools/more.py
Expand Up @@ -115,6 +115,9 @@
'windowed_complete',
'all_unique',
'value_chain',
'product_index',
'combination_index',
'permutation_index',
]

_marker = object()
Expand Down Expand Up @@ -3597,6 +3600,18 @@ def nth_product(index, *args):
The products of *args* can be ordered lexicographically.
:func:`nth_product` computes the product at sort position *index* without
computing the previous products.
>>> nth_product(8, range(2), range(2), range(2), range(2))
(1, 0, 0, 0)
The equivalent being:
>>> from itertools import product
>>> list(product(range(2), range(2), range(2), range(2)))[8]
(1, 0, 0, 0)
Calling :func:`nth_product` with an index that does not exist when taking
the product of *args* raises an ``IndexError``.
"""
pools = list(map(tuple, reversed(args)))
ns = list(map(len, pools))
Expand Down Expand Up @@ -3625,6 +3640,20 @@ def nth_permutation(iterable, r, index):
computes the subsequence at sort position *index* directly, without
computing the previous subsequences.
>>> nth_permutation('ghijk', 2, 5)
('h', 'i')
The equivalent being:
>>> from itertools import permutations
>>> list(permutations('ghijk', 2))[5]
('h', 'i')
Calling :func:`nth_permutation` with an index that does not exist when
choosing *r* from an *iterable* of the given length raises an
``IndexError``. Calling :func:`nth_permutation` where *r* is negative or
greater than the length of the *iterable* raises a ``ValueError``.
"""
pool = list(iterable)
n = len(pool)
Expand Down Expand Up @@ -3683,3 +3712,118 @@ def value_chain(*args):
yield from value
except TypeError:
yield value


def product_index(element, *args):
"""Equivalent to ``list(product(*args)).index(element)``
The products of *args* can be ordered lexicographically.
:func:`product_index` computes the first index of *element* without
computing the previous products.
>>> product_index([8, 2], range(10), range(5))
42
The equivalent being:
>>> from itertools import product
>>> list(product(range(10), range(5))).index((8, 2))
42
Indexing an *element* that does not exist as a product of *args* raises a
``ValueError``.
"""
index = 0

for x, pool in zip_longest(element, args, fillvalue=_marker):
if x is _marker or pool is _marker:
raise ValueError('element is not a product of args')

pool = tuple(pool)
index = index * len(pool) + pool.index(x)

return index


def combination_index(element, iterable):
"""Equivalent to ``list(combinations(iterable, r)).index(element)``
The subsequences of *iterable* that are of length *r* can be ordered
lexicographically. :func:`combination_index` computes the index of the
first *element*, without computing the previous combinations.
>>> combination_index('adf', 'abcdefg')
10
The equivalent being:
>>> from itertools import combinations
>>> list(combinations('abcdefg', 3)).index(('a', 'd', 'f'))
10
Indexing an *element* that does not exist as a combination of *iterable*
raises a ``ValueError``. The length of the combination is given implicitly
by the length of the *element*.
"""
element = enumerate(element)
k, y = next(element, (None, None))
if k is None:
return 0

indexes = []
pool = enumerate(iterable)
for n, x in pool:
if x == y:
indexes.append(n)
tmp, y = next(element, (None, None))
if tmp is None:
break
else:
k = tmp
else:
raise ValueError('element is not a combination of iterable')

n, _ = last(pool, default=(n, None))

# TODO: replace factorials with math.comb when 3.8 is the minimum version
index = 1
for i, j in enumerate(reversed(indexes), start=1):
j = n - j
if i <= j:
index += factorial(j) // (factorial(i) * factorial(j - i))

return factorial(n + 1) // (factorial(k + 1) * factorial(n - k)) - index


def permutation_index(element, iterable):
"""Equivalent to ``list(permutations(iterable, r)).index(element)```
The subsequences of *iterable* that are of length *r* where order is
important can be ordered lexicographically. :func:`permutation_index`
computes the index of the first *element* directly, without computing
the previous permutations.
>>> permutation_index([1, 3, 2], range(5))
19
The equivalent being:
>>> from itertools import permutations
>>> list(permutations(range(5), 3)).index((1, 3, 2))
19
Indexing an *element* that does not exist as a permutation of *iterable*
raises a ``ValueError``. The length of the permutation is given implicitly
by the length of the *element*.
"""
index = 0
pool = list(iterable)
for i, x in zip(range(len(pool), -1, -1), element):
r = pool.index(x)
index = index * i + r
del pool[r]

return index
7 changes: 7 additions & 0 deletions more_itertools/more.pyi
Expand Up @@ -465,3 +465,10 @@ def nth_permutation(
iterable: Iterable[_T], r: int, index: int
) -> Tuple[_T, ...]: ...
def value_chain(*args: Iterable[Any]) -> Iterable[Any]: ...
def product_index(element: Iterable[_T], *args: Iterable[_T]) -> int: ...
def combination_index(
element: Iterable[_T], iterable: Iterable[_T]
) -> int: ...
def permutation_index(
element: Iterable[_T], iterable: Iterable[_T]
) -> int: ...
13 changes: 13 additions & 0 deletions more_itertools/recipes.py
Expand Up @@ -536,6 +536,19 @@ def nth_combination(iterable, r, index):
sort position *index* directly, without computing the previous
subsequences.
>>> nth_combination(range(5), 3, 5)
(0, 3, 4)
The equivalent being:
>>> list(combinations(range(5), 3))[5]
(0, 3, 4)
Calling :func:`nth_combination` with an index that does not exist when
choosing *r* from an *iterable* of the given length raises an
``IndexError``. Calling :func:`nth_combination` where *r* is negative or
greater than the length of the *iterable* raises a ``ValueError``.
"""
pool = tuple(iterable)
n = len(pool)
Expand Down
136 changes: 136 additions & 0 deletions tests/test_more.py
Expand Up @@ -4227,3 +4227,139 @@ def test_complex(self):
)
expected = [1, (2, (3,)), 'foo', ['bar', ['baz']], 'tic', 'key', obj]
self.assertEqual(actual, expected)


class ProductIndexTests(TestCase):
def test_basic(self):
iterables = ['ab', 'cdef', 'ghi']
first_index = {}
for index, element in enumerate(product(*iterables)):
actual = mi.product_index(element, *iterables)
expected = first_index.setdefault(element, index)
self.assertEqual(actual, expected)

def test_multiplicity(self):
iterables = ['ab', 'bab', 'cab']
first_index = {}
for index, element in enumerate(product(*iterables)):
actual = mi.product_index(element, *iterables)
expected = first_index.setdefault(element, index)
self.assertEqual(actual, expected)

def test_long(self):
actual = mi.product_index((1, 3, 12), range(101), range(22), range(53))
expected = 1337
self.assertEqual(actual, expected)

def test_invalid_empty(self):
with self.assertRaises(ValueError):
mi.product_index('', 'ab', 'cde', 'fghi')

def test_invalid_small(self):
with self.assertRaises(ValueError):
mi.product_index('ac', 'ab', 'cde', 'fghi')

def test_invalid_large(self):
with self.assertRaises(ValueError):
mi.product_index('achi', 'ab', 'cde', 'fghi')

def test_invalid_match(self):
with self.assertRaises(ValueError):
mi.product_index('axf', 'ab', 'cde', 'fghi')


class CombinationIndexTests(TestCase):
def test_r_less_than_n(self):
iterable = 'abcdefg'
r = 4
first_index = {}
for index, element in enumerate(combinations(iterable, r)):
actual = mi.combination_index(element, iterable)
expected = first_index.setdefault(element, index)
self.assertEqual(actual, expected)

def test_r_equal_to_n(self):
iterable = 'abcd'
r = len(iterable)
first_index = {}
for index, element in enumerate(combinations(iterable, r=r)):
actual = mi.combination_index(element, iterable)
expected = first_index.setdefault(element, index)
self.assertEqual(actual, expected)

def test_multiplicity(self):
iterable = 'abacba'
r = 3
first_index = {}
for index, element in enumerate(combinations(iterable, r)):
actual = mi.combination_index(element, iterable)
expected = first_index.setdefault(element, index)
self.assertEqual(actual, expected)

def test_null(self):
actual = mi.combination_index(tuple(), [])
expected = 0
self.assertEqual(actual, expected)

def test_long(self):
actual = mi.combination_index((2, 12, 35, 126), range(180))
expected = 2000000
self.assertEqual(actual, expected)

def test_invalid_order(self):
with self.assertRaises(ValueError):
mi.combination_index(tuple('acb'), 'abcde')

def test_invalid_large(self):
with self.assertRaises(ValueError):
mi.combination_index(tuple('abcdefg'), 'abcdef')

def test_invalid_match(self):
with self.assertRaises(ValueError):
mi.combination_index(tuple('axe'), 'abcde')


class PermutationIndexTests(TestCase):
def test_r_less_than_n(self):
iterable = 'abcdefg'
r = 4
first_index = {}
for index, element in enumerate(permutations(iterable, r)):
actual = mi.permutation_index(element, iterable)
expected = first_index.setdefault(element, index)
self.assertEqual(actual, expected)

def test_r_equal_to_n(self):
iterable = 'abcd'
first_index = {}
for index, element in enumerate(permutations(iterable)):
actual = mi.permutation_index(element, iterable)
expected = first_index.setdefault(element, index)
self.assertEqual(actual, expected)

def test_multiplicity(self):
iterable = 'abacba'
r = 3
first_index = {}
for index, element in enumerate(permutations(iterable, r)):
actual = mi.permutation_index(element, iterable)
expected = first_index.setdefault(element, index)
self.assertEqual(actual, expected)

def test_null(self):
actual = mi.permutation_index(tuple(), [])
expected = 0
self.assertEqual(actual, expected)

def test_long(self):
actual = mi.permutation_index((2, 12, 35, 126), range(180))
expected = 11631678
self.assertEqual(actual, expected)

def test_invalid_large(self):
with self.assertRaises(ValueError):
mi.permutation_index(tuple('abcdefg'), 'abcdef')

def test_invalid_match(self):
with self.assertRaises(ValueError):
mi.permutation_index(tuple('axe'), 'abcde')

0 comments on commit 1576344

Please sign in to comment.