# Day 5
## Part 1

In [4]:
from collections import defaultdict

def read_data(s):
    """Parse puzzle text into ``(rules, updates)``.

    The text has two sections separated by one blank line:

    * rule lines of the form ``a|b``
    * update lines of comma-separated integers

    Returns:
        rules: ``defaultdict(set)`` mapping each ``a`` to the set of every
            ``b`` it is paired with in a rule line.
        updates: list of lists of ints, one inner list per update line.
    """
    rule_block, update_block = s.strip().split("\n\n")

    rules = defaultdict(set)
    for before, after in (row.split("|") for row in rule_block.splitlines()):
        rules[int(before)].add(int(after))

    updates = []
    for row in update_block.splitlines():
        updates.append(list(map(int, row.split(","))))

    return rules, updates

def middle_value(l):
    """Return the element of ``l`` at index ``len(l) // 2``."""
    midpoint = len(l) // 2
    return l[midpoint]

def valid_update(rules, update):
    """Return True if ``update`` violates none of the ordering ``rules``.

    Args:
        rules: mapping ``page -> set of pages that must come after it``.
        update: sequence of pages in the order being checked.

    An update is rejected only when some page appears *after* a page that
    the rules say it must precede. Pairs with no rule between them are
    unconstrained, so a missing rule between neighbours does not make the
    update invalid. Rules between non-adjacent pages are enforced too.

    ``rules`` is read with ``.get`` so that a ``defaultdict`` is not grown
    with empty entries as a side effect of validation.
    """
    seen = set()
    for page in update:
        # Any already-seen page listed as a successor of `page` is out of order.
        if not seen.isdisjoint(rules.get(page, ())):
            return False
        seen.add(page)
    return True

def part_1(data):
    """Sum the middle page of every update that ``valid_update`` accepts.

    ``data`` is the ``(rules, updates)`` pair produced by ``read_data``.
    """
    rules, updates = data

    total = 0
    for update in updates:
        if valid_update(rules, update):
            total += middle_value(update)
    return total

# Sample input: `a|b` rule lines, a blank line, then comma-separated updates.
test_data = read_data("""47|53
97|13
97|61
97|47
75|29
61|13
75|53
29|13
97|29
53|29
61|53
97|53
61|29
47|13
75|47
97|75
47|61
75|61
47|29
75|13
53|13

75,47,61,53,29
97,61,53,29,13
75,29,13
75,97,47,61,53
61,13,29
97,13,75,29,47""")

# Regression check: part 1 on the sample input must give 143.
assert part_1(test_data) == 143

In [6]:
# Read the puzzle input inside a context manager so the file handle is
# closed as soon as its contents have been parsed.
with open("input") as f:
    data = read_data(f.read())

part_1(data)

6951

## Part 2
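Fix each incorrectly ordered update recursively: pick the page that has no remaining page required before it, put it first, then order the remaining pages the same way.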

In [7]:
def correct_order(rules, update):
    """Return the pages of ``update`` reordered to satisfy ``rules``.

    Args:
        rules: mapping ``page -> set of pages that must come after it``.
        update: list of pages; it is not modified.

    Returns:
        A new list with the same pages, built recursively: the first page is
        one that no page in ``update`` is required to precede, followed by
        the corrected order of the remaining pages.

    Raises:
        ValueError: if every remaining page has some other remaining page
            that must precede it (the rules are cyclic for this update), so
            no valid ordering exists.
    """
    # Zero or one page is already ordered; return a copy so the result never
    # aliases the caller's list.
    if len(update) <= 1:
        return list(update)

    # `.get` keeps a defaultdict from gaining empty entries on lookup.
    for first in update:
        if not any(first in rules.get(x, ()) for x in update):
            break
    else:
        raise ValueError(
            "no valid ordering: rules are cyclic for pages {!r}".format(update)
        )

    remaining = list(update)
    remaining.remove(first)
    return [first] + correct_order(rules, remaining)

def part_2(data):
    """Sum the middle page of every invalid update after reordering it.

    ``data`` is the ``(rules, updates)`` pair produced by ``read_data``.
    Updates that ``valid_update`` accepts are skipped; the rest are fixed
    with ``correct_order`` before their middle page is taken.
    """
    rules, updates = data

    total = 0
    for update in updates:
        if valid_update(rules, update):
            continue
        total += middle_value(correct_order(rules, update))
    return total

# Regression check: part 2 on the sample input must give 123.
assert part_2(test_data) == 123

In [8]:
# Part 2 result for the parsed contents of the "input" file.
part_2(data)

4121

### Part 2 revisited

I'm a bit uncomfortable with how the ordered list is constructed above, because repeatedly concatenating lists is inefficient. Here's a non-recursive rewrite using set operations.

In [21]:
def correct_order_v2(rules, update):
    """Iteratively reorder the pages of ``update`` to satisfy ``rules``.

    The distinct pages are held in a set. On each pass, every page listed as
    a successor of some pending page is subtracted, and one of the pages
    left over is emitted next and removed from the pending set.

    Returns a new list; ``update`` itself is not modified.
    """
    pending = set(update)
    result = []

    while len(pending) > 0:
        # Successor sets of all pending pages; anything in them cannot go next.
        successor_sets = [rules[page] for page in pending]
        candidates = pending.difference(*successor_sets)
        chosen = candidates.pop()
        result.append(chosen)
        pending.discard(chosen)

    return result

def part_2_v2(data):
    """Same computation as ``part_2`` but reordering with ``correct_order_v2``.

    ``data`` is the ``(rules, updates)`` pair produced by ``read_data``.
    """
    rules, updates = data

    total = 0
    for update in updates:
        if valid_update(rules, update):
            continue
        total += middle_value(correct_order_v2(rules, update))
    return total

# Regression check: the set-based version must also give 123 on the sample input.
assert part_2_v2(test_data) == 123

In [22]:
# Part 2 result from the set-based implementation, for comparison with part_2(data).
part_2_v2(data)

4121

In [23]:
%%timeit

part_2(data)

6.48 ms ± 140 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [24]:
%%timeit

part_2_v2(data)

4.47 ms ± 54.1 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)


The set-based version is only about 30% faster (4.47 ms vs 6.48 ms), so it was probably not worth spending time on.