In [None]:
%load_ext autoreload
%autoreload 2
from aoc.lib import load, timing

YEAR = 2024
DAY = 5
# When True, prepare_data() passes TESTDATA to load() instead of None.
TEST = False
# Small sample input: 'a|b' rule lines, a blank line, then comma-separated updates.
TESTDATA = '47|53\n97|13\n97|61\n97|47\n75|29\n61|13\n75|53\n29|13\n97|29\n53|29\n61|53\n97|53\n61|29\n47|13\n75|47\n97|75\n47|61\n75|61\n47|29\n75|13\n53|13\n\n75,47,61,53,29\n97,61,53,29,13\n75,29,13\n75,97,47,61,53\n61,13,29\n97,13,75,29,47\n'

In [None]:
from collections import defaultdict

@timing
def prepare_data():
    """Load the puzzle input and parse it into ordering rules and updates.

    The input (TESTDATA when TEST is set) is expected to hold ``a|b`` rule
    lines, a blank separator line, then comma-separated update lines.

    Returns:
        The dict returned by ``load`` with two keys added:
        - 'rules': defaultdict mapping page -> {'before': set, 'after': set};
          for each rule ``a|b``, b is added to rules[a]['after'] and a is
          added to rules[b]['before'].
        - 'updates': list of updates, each a list of ints in input order.
    """
    data = load(YEAR, DAY, split_lines=True, test=TESTDATA if TEST else None)

    # parse rules and updates
    data['rules'] = defaultdict(lambda: {'before': set(), 'after': set()})
    data['updates'] = []

    # NOTE(review): assumes data['split'] is an iterable of input lines
    # (str), as the original parsing did.
    it = iter(data['split'])
    # The rules section ends at the first blank (or whitespace-only) line.
    # next(it, '') also stops cleanly if the input has no separator at all,
    # instead of leaking StopIteration.
    while row := next(it, '').strip():
        a, b = (int(i) for i in row.split('|'))
        data['rules'][a]['after'].add(b)
        data['rules'][b]['before'].add(a)

    for row in it:
        # Skip blank lines (e.g. a trailing newline at the end of the input);
        # int('') would otherwise raise ValueError.
        if not row.strip():
            continue
        data['updates'].append([int(a) for a in row.split(',')])

    return data

In [None]:
# Level 1: find all updates where all rules are fulfilled and sum their middle page numbers
def is_good(data, update):
    """Return True if the pages of *update* satisfy every applicable rule.

    Args:
        data: dict with a 'rules' mapping page -> {'before': set, 'after': set}
            as built by prepare_data: for a rule ``a|b``, b is in
            rules[a]['after'] and a is in rules[b]['before'].
        update: sequence of page numbers in their printed order.

    Returns:
        False as soon as some page has an 'after' page already printed or a
        'before' page still to be printed; True otherwise (including for an
        empty update).
    """
    seen = set()
    pending = set(update)
    for num in update:
        rules = data['rules'][num]
        if not rules['after'].isdisjoint(seen) or not rules['before'].isdisjoint(pending):
            return False
        seen.add(num)
        # discard, not remove: a page repeated in the update was already taken
        # out of `pending` at its first occurrence (remove would raise KeyError).
        pending.discard(num)
    return True

@timing
def level1(data):
    """Sum the middle page of every update that already satisfies all rules."""
    in_order = list(filter(lambda pages: is_good(data, pages), data['updates']))
    total = 0
    for pages in in_order:
        # Odd lengths have a unique middle; for even lengths this index is the
        # lower of the two middle positions.
        total += pages[(len(pages) - 1) // 2]
    return total


data = prepare_data()
print(level1(data))

In [None]:
# Level 2: Correct the incorrect updates and sum their middle page numbers
def fix_update(data, update):
    """Return the pages of *update* reordered so that every applicable rule holds.

    Repeatedly places a page none of whose 'before' pages are still waiting
    to be placed. This is a topological sort restricted to the pages present
    in this update.

    Args:
        data: dict with a 'rules' mapping page -> {'before': set, 'after': set}
            as built by prepare_data.
        update: iterable of page numbers (set, list, tuple, ...). It is not
            modified.

    Returns:
        A new list containing the pages in a rule-respecting order. If several
        pages are eligible at once, the smallest is placed first so the
        result is deterministic.

    Raises:
        ValueError: if the rules restricted to these pages form a cycle, so
            no valid ordering exists.
    """
    rules = data['rules']
    remaining = set(update)
    ordered = []
    # Iterative rather than recursive, so long updates cannot hit the
    # recursion limit.
    while remaining:
        # Pages with no still-unplaced predecessor may go next.
        ready = [num for num in remaining if rules[num]['before'].isdisjoint(remaining)]
        if not ready:
            raise ValueError(f"cyclic ordering rules among pages {sorted(remaining)}")
        first = min(ready)
        ordered.append(first)
        remaining.remove(first)
    return ordered

@timing
def level2(data):
    """Reorder every update that breaks a rule and sum the resulting middle pages."""
    # Classify all updates first, then repair them all, then sum (same phase
    # order as before).
    broken = list(filter(lambda pages: not is_good(data, pages), data['updates']))
    repaired = list(map(lambda pages: fix_update(data, set(pages)), broken))
    return sum(pages[(len(pages) - 1) // 2] for pages in repaired)

data = prepare_data()
print(level2(data))