In [1]:
from collections import defaultdict

In [2]:
# Sample puzzle input: `before|after` ordering rules, then a blank line, then
# one comma-separated list of page numbers (an "update") per line.
test_input = '''47|53
97|13
97|61
97|47
75|29
61|13
75|53
29|13
97|29
53|29
61|53
97|53
61|29
47|13
75|47
97|75
47|61
75|61
47|29
75|13
53|13

75,47,61,53,29
97,61,53,29,13
75,29,13
75,97,47,61,53
61,13,29
97,13,75,29,47'''

In [3]:
def parse_input(s):
    """Parse puzzle text into ordering rules and page updates.

    The text consists of two sections separated by a blank line: first the
    `before|after` rule lines, then one comma-separated page list per line.

    Args:
        s: Raw puzzle text. Surrounding whitespace (such as the final newline
            most editors append) and Windows line endings are tolerated.

    Returns:
        A `(rule_lookup, updates)` tuple. `rule_lookup` is a
        `defaultdict(set)` mapping each page to the set of pages that must
        come after it; `updates` is a list of lists of ints.

    Raises:
        ValueError: If the text does not have exactly two sections or a line
            is malformed.
    """
    # Normalise line endings and trim the text so that a trailing newline
    # does not leave an empty line behind for int() to choke on.
    rule_text, update_text = s.replace('\r\n', '\n').strip().split('\n\n')

    rule_lookup = defaultdict(set)
    for line in rule_text.splitlines():
        before, after = line.split('|')
        rule_lookup[int(before)].add(int(after))

    updates = [[int(n) for n in line.split(',')] for line in update_text.splitlines()]
    return rule_lookup, updates

## Part One

Loop through each update. For every page, check whether the rules say it must come before a page that has already been seen; if so, the update is out of order.

In [4]:
def test_update(update, rules):
    seen = set()
    for page in update:
        must_be_before = rules[page]
        if not seen.isdisjoint(must_be_before):
            return False
        seen.add(page)
    return True

### Test input

In [5]:
# Parse the sample input, then sum the middle page of every update that is
# already in a valid order. The sum is the cell's displayed value.
rules, updates = parse_input(test_input)
sum(update[len(update)//2] for update in updates if test_update(update, rules))

143

In [6]:
# Read the full puzzle input from disk; `raw_input` is reused by a later cell.
with open('input_files/05.txt') as f:
    raw_input = f.read()
    
# NOTE: this rebinds `rules` and `updates`, which held the sample data above.
rules, updates = parse_input(raw_input)

# Sum the middle page of every update that is already in a valid order.
answer = sum(update[len(update) // 2] for update in updates if test_update(update, rules))
print("Part one:", answer)

Part one: 5064


## Part Two

The full rule set seems to contain a cycle, so we can't topologically sort the entire graph. Instead, for each update, sort only the subgraph containing the pages in that update. These subgraphs are well-behaved.

Let's try implementing [Kahn's algorithm](https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm).

In [7]:
def top_sort(rules, update):
    """Topologically sort the pages of `update` with Kahn's algorithm.

    Only rules whose two endpoints both occur in `update` are considered, so
    cycles elsewhere in the full rule set do not matter.

    Args:
        rules: Mapping from a page to the collection of pages that must come
            after it. Pages without an entry have no outgoing constraints.
            The mapping is not modified.
        update: Re-iterable collection of page numbers to order.

    Returns:
        A list of the distinct pages of `update`, ordered so that every
        applicable rule is satisfied. Where the rules leave two pages
        unordered relative to each other, their relative order is
        unspecified.

    Raises:
        ValueError: If the applicable rules contain a cycle, in which case no
            valid ordering exists.
    """
    # Build the page set once instead of once per page.
    pages = set(update)

    # Subgraph restricted to the pages present. .get() avoids inserting keys
    # into a defaultdict and lets a plain dict be used as well.
    subgraph = {n: pages.intersection(rules.get(n, ())) for n in update}

    # Track how many inbound edges each node has to make it cheap to test
    # "m has no other incoming edges".
    inbound_counts = dict.fromkeys(subgraph, 0)
    for successors in subgraph.values():
        for dest in successors:
            inbound_counts[dest] += 1

    # Nodes with no inbound edges are ready to be emitted.
    ready = [n for n, count in inbound_counts.items() if count == 0]
    ordered = []
    while ready:
        n = ready.pop()
        ordered.append(n)
        for m in subgraph[n]:
            inbound_counts[m] -= 1
            if inbound_counts[m] == 0:
                ready.append(m)

    # Any node never emitted still has an inbound edge, i.e. sits on a cycle.
    if len(ordered) != len(subgraph):
        raise ValueError("rules for this update contain a cycle; no valid ordering exists")

    return ordered

def solve(rules, updates):
    """Sum the middle pages of the wrongly ordered updates once re-sorted.

    Args:
        rules: Ordering rules, passed through to `test_update` and `top_sort`.
        updates: Iterable of updates (each a list of page numbers).

    Returns:
        The sum, over every update that fails `test_update`, of the element
        at index `len // 2` of its `top_sort` result; 0 if none fail.
    """
    # First pick out every update that violates the rules...
    needs_reordering = list(filter(lambda pages: not test_update(pages, rules), updates))

    # ...then sort each one and add up the middle elements.
    reordered = map(lambda pages: top_sort(rules, pages), needs_reordering)
    return sum(fixed[len(fixed) // 2] for fixed in reordered)

### Test input

In [8]:
# Re-parse the sample input so this cell does not depend on whichever data
# `rules` and `updates` were last bound to.
rules, updates = parse_input(test_input)

# Displayed value: the part-two total for the sample input.
solve(rules, updates)

123

In [9]:
# Relies on `raw_input` having been read from disk by the part-one cell.
rules, updates = parse_input(raw_input)
print("Part two:", solve(rules, updates))

Part two: 5152
