# On `itertools.groupby`

Presentation to the San Diego Python User Group on 2019-09-26

## Why `groupby`?

- Sufficiently narrow subject for a lightning talk
- Every Python developer should be familiar with the `collections` and `itertools` modules
- It’s a pretty awesome function!
- Yet it’s underappreciated :(

## What does `groupby` do?

- Breaks an iterable of objects into chunks of consecutive elements
- A chunk is a maximal run of consecutive elements that share the same key, as computed by a key function
- The key function is an optional argument of `groupby` and defaults to the identity function
- `groupby` returns an iterator of pairs `(k, g)`, one such pair for each chunk, where `k` is the key value of the chunk and `g` is an iterator over the elements of the chunk.
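
Two consequences of this design are easy to trip over: only *consecutive* elements with equal keys end up in the same chunk, and each `g` shares the underlying input with `groupby` itself, so it is only usable until `groupby` moves on to the next chunk. A minimal sketch on an arbitrary sample string:

```python
from itertools import groupby

data = 'aaabbaac'

# Only consecutive equal keys form a chunk, so 'a' shows up twice.
runs = [(k, ''.join(g)) for k, g in groupby(data)]
print(runs)  # [('a', 'aaa'), ('b', 'bb'), ('a', 'aa'), ('c', 'c')]

# Sorting by the same key first yields exactly one chunk per key.
merged = [(k, ''.join(g)) for k, g in groupby(sorted(data))]
print(merged)  # [('a', 'aaaaa'), ('b', 'bb'), ('c', 'c')]

# Each g shares the source iterator with groupby: once groupby advances,
# the previous g is no longer usable. Collecting all the pairs first and
# reading the groups afterwards therefore yields empty groups.
stale = [(k, list(g)) for k, g in list(groupby(data))]
print(stale)  # [('a', []), ('b', []), ('a', []), ('c', [])]
```

In practice this means consuming each `g` inside the loop body, before asking `groupby` for the next pair.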


## Example

```python
from itertools import groupby

words = 'the quick brown fox jumps over the lazy dog'.split(' ')
words.sort(key=len)  # groupby only merges consecutive items, so sort by the same key first
words
```

```
['the', 'fox', 'the', 'dog', 'over', 'lazy', 'quick', 'brown', 'jumps']
```

```python
{k: list(g) for k, g in groupby(words, key=len)}
```

```
{3: ['the', 'fox', 'the', 'dog'],
 4: ['over', 'lazy'],
 5: ['quick', 'brown', 'jumps']}
```

```python
[k for k, _ in groupby(words, key=len)]
```

```
[3, 4, 5]
```
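
The `words.sort(key=len)` call above is essential. Because chunks are formed only from consecutive elements, the same dict comprehension applied to the *unsorted* words silently keeps only the last chunk for each length. A short sketch contrasting that with `collections.defaultdict`, which needs no sorting:

```python
from collections import defaultdict
from itertools import groupby

words = 'the quick brown fox jumps over the lazy dog'.split(' ')  # unsorted

# Later chunks with the same key overwrite earlier ones in the dict.
lossy = {k: list(g) for k, g in groupby(words, key=len)}
print(lossy)  # {3: ['dog'], 5: ['jumps'], 4: ['lazy']}

# A defaultdict accumulates every word regardless of input order.
by_len = defaultdict(list)
for word in words:
    by_len[len(word)].append(word)
print(dict(by_len))
# {3: ['the', 'fox', 'the', 'dog'], 5: ['quick', 'brown', 'jumps'], 4: ['over', 'lazy']}
```

The trade-off: `groupby` streams and never holds more than one chunk, but relies on the input order; the `defaultdict` works on any order but keeps everything in memory.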

## Example: Sales report

**Task:** Create a nicely formatted sales report. 

**Input:** An iterable of records, where each record consists of a product ID, product group ID, and value. The input is sorted by product group ID then by product ID. The input may contain multiple records pertaining to the same product.

**Output:** A nicely formatted report. The report should show a total for each product, a total for each product group, and a total for all records. 


```python
from collections import namedtuple

Record = namedtuple('Record', 'product product_group value')
```

### Flow chart

A rough flow chart for the task is shown below. Typical implementations keep track of the previous record and compare it with the current record to check if the previous record was the last one of a product or product group. 

<img src="assets/flow_chart.png" alt="Flow Chart" width="500"/>

`groupby` implements this compare-the-current-record-with-the-previous-one logic, including the special cases for the first and the last chunk, so you don't have to write it yourself. We can use `groupby` to deliver the records in chunks where all records in each chunk belong to the same product group.
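
To make the comparison concrete, here is a minimal sketch of the "remember the previous key" pattern from the flow chart, computing one total per group, next to the `groupby` equivalent. The `(group, value)` rows and the function names are hypothetical, chosen only for illustration:

```python
from itertools import groupby

# Hypothetical (group, value) rows, already sorted by group.
rows = [('001', 12), ('001', 1032), ('007', 207), ('007', 12319)]

_MISSING = object()  # sentinel meaning "no previous row seen yet"


def totals_by_hand(rows):
    """Flow-chart style: remember the previous key, flush when it changes."""
    result = []
    prev_key, running = _MISSING, 0
    for key, value in rows:
        if prev_key is not _MISSING and key != prev_key:
            result.append((prev_key, running))  # previous group just ended
            running = 0
        prev_key = key
        running += value
    if prev_key is not _MISSING:
        result.append((prev_key, running))  # easy to forget: the last group
    return result


def totals_with_groupby(rows):
    """groupby detects the key changes, including the end of the last group."""
    return [(key, sum(value for _, value in group))
            for key, group in groupby(rows, key=lambda row: row[0])]


print(totals_by_hand(rows))       # [('001', 1044), ('007', 12526)]
print(totals_with_groupby(rows))  # [('001', 1044), ('007', 12526)]
```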

```python
from itertools import groupby

def generate_report(all_records):
    total = 0
    for prod_group_id, prod_group_records in groupby(all_records, key=lambda record: record.product_group):
        total += yield from process_product_group(prod_group_id, prod_group_records)
    yield 'Total:                          %6d' % total
```

The function `generate_report` divides the input stream into chunks, one per product group, and passes each product group ID together with that group's records to `process_product_group`. Whatever `process_product_group` yields, `generate_report` passes on unchanged. The value that `process_product_group` returns (the group total) becomes the value of the `yield from` expression and is added to the grand total, which is yielded as the last line of the report.
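
The lambda key functions can equally be written with `operator.attrgetter`, which also accepts several attribute names and then returns a tuple. A small sketch, with `Record` redefined and a shortened record list so the snippet runs on its own:

```python
from collections import namedtuple
from itertools import groupby
from operator import attrgetter

Record = namedtuple('Record', 'product product_group value')

records = [Record('0001', '001', '12'),
           Record('0012', '001', '1000'),
           Record('0012', '001', '32'),
           Record('0009', '007', '207')]

# attrgetter('product_group') behaves like lambda record: record.product_group
groups = [k for k, _ in groupby(records, key=attrgetter('product_group'))]
print(groups)  # ['001', '007']

# With two names the key is a (product_group, product) tuple, so a single
# groupby pass yields one chunk per product within each group.
pairs = [k for k, _ in groupby(records, key=attrgetter('product_group', 'product'))]
print(pairs)  # [('001', '0001'), ('001', '0012'), ('007', '0009')]
```

The nested design in the talk is still preferable for the report, since it needs a header and a subtotal per group; the tuple key is handy when only the innermost chunks matter.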

```python
def process_product_group(group_id, group_records):
    group_total = 0
    yield 'Group: %s' % group_id
    for prod_id, prod_records in groupby(group_records, key=lambda record: record.product):
        group_total += yield from process_product(prod_id, prod_records)
    yield '    Group total:                %6d' % group_total
    yield ''
    return group_total
```

The function `process_product_group` divides the stream of all the records that belong to one product group into chunks, where each chunk corresponds to the records of one product, and passes the product id and the product records to `process_product`. The function `process_product_group` passes on whatever the function `process_product` yields. The sum for each product is added to the product group total, which is returned at the end.

Note that a generator function can both yield values and return a value. When a generator executes a `return` statement, it raises `StopIteration`, and the returned value is stored in the exception's `value` attribute, where a client that catches the exception can read it. None of this needs to be coded explicitly: `yield from` catches the `StopIteration` and evaluates to its `value`, which is how `group_total += yield from ...` picks up each product total.
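
A minimal sketch of both mechanisms, using a toy generator (`child` and `parent` are hypothetical names): first the return value is fished out of `StopIteration` by hand, then `yield from` does the same work implicitly.

```python
def child():
    yield 'line 1'
    yield 'line 2'
    return 42  # delivered as StopIteration.value


# Driving the generator by hand: the return value rides on StopIteration.
gen = child()
print(next(gen), next(gen))  # line 1 line 2
try:
    next(gen)
except StopIteration as exc:
    returned = exc.value
print(returned)  # 42


# yield from forwards every yielded value and evaluates to the return value.
def parent(log):
    result = yield from child()
    log.append(result)


log = []
lines = list(parent(log))
print(lines, log)  # ['line 1', 'line 2'] [42]
```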

```python
def process_product(prod_id, prod_records):
    prod_total = sum(int(record.value) for record in prod_records)
    yield '    Product: %s Value: %6d' % (prod_id, prod_total)
    return prod_total
```

The function `process_product` yields one line of the report that pertains to one product and returns the total value for that product.

### Verification

```python
records = [Record(product='0001', product_group='001', value='12'),
           Record(product='0012', product_group='001', value='1000'),
           Record(product='0012', product_group='001', value='32'),
           Record(product='0009', product_group='007', value='207'),
           Record(product='0112', product_group='007', value='12119'),
           Record(product='1009', product_group='007', value='200')]

for line in generate_report(records):
    print(line)
```

```
Group: 001
    Product: 0001 Value:     12
    Product: 0012 Value:   1032
    Group total:                  1044

Group: 007
    Product: 0009 Value:    207
    Product: 0112 Value:  12119
    Product: 1009 Value:    200
    Group total:                 12526

Total:                           13570
```
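
As an independent sanity check, the same totals can be recomputed without `groupby` by accumulating into a `collections.Counter`; the numbers should match the report above.

```python
from collections import Counter, namedtuple

Record = namedtuple('Record', 'product product_group value')
records = [Record(product='0001', product_group='001', value='12'),
           Record(product='0012', product_group='001', value='1000'),
           Record(product='0012', product_group='001', value='32'),
           Record(product='0009', product_group='007', value='207'),
           Record(product='0112', product_group='007', value='12119'),
           Record(product='1009', product_group='007', value='200')]

# A Counter starts every missing key at 0, so += works as a running sum per
# key; unlike groupby, this does not depend on the input being sorted.
group_totals = Counter()
product_totals = Counter()
for r in records:
    group_totals[r.product_group] += int(r.value)
    product_totals[(r.product_group, r.product)] += int(r.value)

print(dict(group_totals))          # {'001': 1044, '007': 12526}
print(sum(group_totals.values()))  # 13570
```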


## Example: Prime factors

Consider the following function:

```python
def gen_prime_factors(n):
    """Generate all the prime factors of n in ascending order"""
    factor = 2
    while factor * factor <= n:
        if n % factor == 0:
            yield factor
            n //= factor
        else:
            factor = factor + (1 if factor == 2 else 2)
    if n > 1:
        yield n
```

```python
list(gen_prime_factors(72))
```

```
[2, 2, 2, 3, 3]
```
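
Since the generator yields each prime as many times as it divides `n`, multiplying the factors back together should reproduce `n`, and the docstring promises ascending order. A quick property check along those lines, using `math.prod` (Python 3.8+):

```python
from math import prod


def gen_prime_factors(n):
    """Generate all the prime factors of n in ascending order"""
    factor = 2
    while factor * factor <= n:
        if n % factor == 0:
            yield factor
            n //= factor
        else:
            factor = factor + (1 if factor == 2 else 2)
    if n > 1:
        yield n


for n in range(2, 1000):
    factors = list(gen_prime_factors(n))
    assert prod(factors) == n          # the factors multiply back to n
    assert factors == sorted(factors)  # ascending, as the docstring promises
print('ok')
```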

Sometimes it's useful to group equal factors together:

```python
def prime_factors(n):
    """Returns a list of pairs in ascending order where the first
    element of each pair is a prime, and the second element is the
    number of times the prime divides n. For example:
    prime_factors(360) returns [(2, 3), (3, 2), (5, 1)]"""

    return [(p, sum(1 for _ in g)) for p, g in groupby(gen_prime_factors(n))]
```

```python
print(360, ' * '.join(f'{x}^{y}' for x, y in prime_factors(360)), sep=' = ')
```

```
360 = 2^3 * 3^2 * 5^1
```
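
Here `groupby` needs no preceding sort, because `gen_prime_factors` already yields equal primes next to each other. For comparison, `collections.Counter` produces the same pairs without relying on adjacency; a sketch with the generator repeated so it runs standalone (`prime_factors_counter` is a hypothetical name):

```python
from collections import Counter
from itertools import groupby


def gen_prime_factors(n):
    """Generate all the prime factors of n in ascending order"""
    factor = 2
    while factor * factor <= n:
        if n % factor == 0:
            yield factor
            n //= factor
        else:
            factor = factor + (1 if factor == 2 else 2)
    if n > 1:
        yield n


def prime_factors(n):
    # groupby version: the length of each run of equal primes is its exponent
    return [(p, sum(1 for _ in g)) for p, g in groupby(gen_prime_factors(n))]


def prime_factors_counter(n):
    # Counter keeps first-seen order (it is a dict subclass), and the primes
    # arrive in ascending order, so the pairs come out sorted as well.
    return list(Counter(gen_prime_factors(n)).items())


print(prime_factors(360))          # [(2, 3), (3, 2), (5, 1)]
print(prime_factors_counter(360))  # [(2, 3), (3, 2), (5, 1)]
```

The `groupby` version only ever tracks the current run, which is why it can exploit the ordering the generator guarantees.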


Sometimes we need distinct factors:

```python
def totient(n):
    """Euler's totient function, a.k.a. Euler's phi function.
    Returns the number of integers between 1 and n that are coprime
    with n. For example: totient(60) = 16. The 16 numbers that are
    coprime with 60 are: 1, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43,
    47, 49, 53, 59"""

    for p, _ in groupby(gen_prime_factors(n)):  # one iteration per distinct prime
        n -= n // p
    return n
```
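
Why `n -= n // p` computes the totient: Euler's product formula runs over the distinct primes $p$ dividing $n$, and each loop iteration multiplies the running value by one factor $(1 - 1/p)$. The integer division is exact, because dividing out the other primes never removes $p$ from the running value. (Rebinding `n` inside the loop is safe: `gen_prime_factors(n)` was called with the original value and keeps its own local copy.) Worked through for $n = 60$:

$$
\varphi(n) = n \prod_{p \mid n} \left(1 - \frac{1}{p}\right),
\qquad
n - \frac{n}{p} = n \left(1 - \frac{1}{p}\right)
$$

$$
60 \;\xrightarrow{\,p=2\,}\; 60 - 30 = 30 \;\xrightarrow{\,p=3\,}\; 30 - 10 = 20 \;\xrightarrow{\,p=5\,}\; 20 - 4 = 16 = \varphi(60)
$$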

```python
totient(60)
```

```
16
```

```python
from math import gcd

[n for n in range(1, 60) if gcd(60, n) == 1]
```

```
[1, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47, 49, 53, 59]
```

```python
len(_)  # in IPython/Jupyter, _ holds the previous cell's output
```

```
16
```

The functions above could have been written more directly, without calling `gen_prime_factors`, but each of them would then repeat the same trial-division loop. The function `gen_prime_factors` captures the part common to many prime-factor-related functions, as the functions above and the ones below illustrate:

```python
def is_prime(n):
    """:returns: True if n is a prime number, False otherwise."""
    return 1 < n == next(gen_prime_factors(n))
```
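
`1 < n == next(...)` is a chained comparison, equivalent to `1 < n and n == next(...)`, so for `n < 2` the generator is never advanced and `next` never hits an empty generator. A sketch that checks `is_prime` against naive trial division (`is_prime_naive` is a hypothetical reference helper):

```python
def gen_prime_factors(n):
    """Generate all the prime factors of n in ascending order"""
    factor = 2
    while factor * factor <= n:
        if n % factor == 0:
            yield factor
            n //= factor
        else:
            factor = factor + (1 if factor == 2 else 2)
    if n > 1:
        yield n


def is_prime(n):
    """:returns: True if n is a prime number, False otherwise."""
    return 1 < n == next(gen_prime_factors(n))


def is_prime_naive(n):
    """Reference check: n >= 2 and no divisor in 2..n-1."""
    return n >= 2 and all(n % d for d in range(2, n))


# Negative numbers, 0 and 1 are covered too: the short-circuit keeps next() away.
assert all(is_prime(n) == is_prime_naive(n) for n in range(-10, 500))
print([n for n in range(30) if is_prime(n)])  # [2, 3, 5, 7, 11, 13, 17, 19, 23, 29]
```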

```python
is_prime(1021)
```

```
True
```

```python
is_prime(57)
```

```
False
```

```python
prime_factors(57)
```

```
[(3, 1), (19, 1)]
```

```python
def smallest_prime_factor(n):
    """Returns the smallest prime factor of n
    n: An int >= 2
    Raises a StopIteration if n < 2
    """
    return next(gen_prime_factors(n))
```
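
One caveat with letting `StopIteration` escape: if `smallest_prime_factor` is called from inside a generator, Python 3.7+ (PEP 479) turns the stray `StopIteration` into a `RuntimeError`. A possible alternative, sketched here under the hypothetical name `smallest_prime_factor_checked`, uses the default argument of `next` and raises `ValueError` instead, matching `largest_prime_factor`:

```python
def gen_prime_factors(n):
    """Generate all the prime factors of n in ascending order"""
    factor = 2
    while factor * factor <= n:
        if n % factor == 0:
            yield factor
            n //= factor
        else:
            factor = factor + (1 if factor == 2 else 2)
    if n > 1:
        yield n


def smallest_prime_factor(n):
    # Version from the talk: lets StopIteration escape when n < 2.
    return next(gen_prime_factors(n))


def first_factors(numbers):
    # Hypothetical caller that happens to be a generator itself.
    for n in numbers:
        yield smallest_prime_factor(n)


caught = None
try:
    list(first_factors([6, 1]))
except RuntimeError as exc:  # PEP 479: StopIteration leaked out of a generator
    caught = exc
print(type(caught).__name__)  # RuntimeError


def smallest_prime_factor_checked(n):
    """Like smallest_prime_factor, but raises ValueError if n < 2."""
    factor = next(gen_prime_factors(n), None)  # None if nothing is yielded
    if factor is None:
        raise ValueError(f'expected an int >= 2, got {n!r}')
    return factor


print(smallest_prime_factor_checked(2019))  # 3
```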

```python
def largest_prime_factor(n):
    """Returns the largest prime factor of n
    n: An int >= 2
    Raises a ValueError if n < 2
    """
    return max(gen_prime_factors(n))
```

```python
prime_factors(2019)
```

```
[(3, 1), (673, 1)]
```

```python
smallest_prime_factor(2019)
```

```
3
```

```python
largest_prime_factor(2019)
```

```
673
```