In [None]:
import os
import sys

sys.path.insert(0, os.path.abspath("../utils"))
from aoc_utils import load_data, check

In [None]:
import heapq
from collections import defaultdict

In [None]:
# Raw puzzle input as a string; the arguments are presumably (year, day) —
# see load_data in ../utils/aoc_utils to confirm.
data = load_data(2024, 5)

In [None]:
# Each test case is a tuple: (puzzle input, expected part 1 answer, expected part 2 answer).
tests = [
    # Larger example: 143 for the correctly ordered updates, 123 after fixing the others.
    (
        """47|53
97|13
97|61
97|47
75|29
61|13
75|53
29|13
97|29
53|29
61|53
97|53
61|29
47|13
75|47
97|75
47|61
75|61
47|29
75|13
53|13

75,47,61,53,29
97,61,53,29,13
75,29,13
75,97,47,61,53
61,13,29
97,13,75,29,47
""",
        143,
        123,
    ),
    # There is no direct 1|3 rule, so 1 < 3 only follows transitively from 1|2 and 2|3.
    # The update 3,1,2 is therefore invalid (part 1 sums to 0), and its corrected
    # order 1,2,3 has middle page 2. A comparator-based `sorted` misses this case.
    (
        """1|2
2|3

3,1,2
""",
        0,
        2,
    ),
]

# Part 1

In [None]:
def parse_updates(data):
    """Parse the puzzle input into ordering rules and page updates.

    Lines containing ``|`` are ``before|after`` rules; every other non-blank
    line is a comma-separated update. Blank lines (the section separator, a
    trailing empty line at the end of the file) and Windows line endings are
    ignored. The original ``data.split("\\n\\n")`` failed on those inputs.

    Parameters
    ----------
    data : str
        Raw puzzle input.

    Returns
    -------
    precedence : defaultdict(list)
        ``precedence[a]`` lists every page that must appear after page ``a``.
    updates : list of list of int
        The updates, in input order.

    Raises
    ------
    ValueError
        If a non-blank line is not a valid rule or update.
    """
    precedence = defaultdict(list)
    updates = []
    # splitlines() also handles "\r\n"; strip() tolerates stray spaces.
    for raw_line in data.splitlines():
        line = raw_line.strip()
        if not line:
            continue  # section separator or trailing blank line
        if "|" in line:
            before, after = (int(v) for v in line.split("|"))
            precedence[before].append(after)
        else:
            updates.append([int(v) for v in line.split(",")])
    return precedence, updates

In [None]:
def sort(values, precedence):
    """Topologically sort ``values`` according to the ``precedence`` rules.

    Only rules whose two ends both occur in ``values`` are used. Whenever
    several elements are ready to be emitted, the one that appears earliest
    in ``values`` is chosen. This makes the sort stable and deterministic: a
    sequence that already satisfies every applicable rule is returned
    unchanged, even when the rules define only a partial order. (Popping from
    a set, as before, could reorder a valid sequence arbitrarily.)

    Parameters
    ----------
    values : list
        Sequence to sort; duplicate values are collapsed into one.
    precedence : mapping
        ``precedence[a]`` is an iterable of values that must come after ``a``.

    Returns
    -------
    list
        A new list with the values in a rule-consistent order.

    Raises
    ------
    ValueError
        If the applicable rules contain a cycle.
    """
    # First index of each value: the tie-breaker that keeps the sort stable.
    position = {}
    for index, value in enumerate(values):
        position.setdefault(value, index)

    forward = defaultdict(list)  # a -> [b1, b2, ...] means a < bi
    in_degree = dict.fromkeys(position, 0)  # number of unmet predecessors
    for a, dests in precedence.items():
        if a not in position:
            continue
        for dest in dests:
            if dest in position:
                forward[a].append(dest)
                in_degree[dest] += 1

    # Min-heap of (position, value) for every value with no unmet predecessor.
    ready = [(position[v], v) for v, degree in in_degree.items() if degree == 0]
    heapq.heapify(ready)

    ret = []
    while ready:
        _, least = heapq.heappop(ready)
        ret.append(least)
        for candidate in forward.get(least, ()):
            in_degree[candidate] -= 1
            if in_degree[candidate] == 0:
                heapq.heappush(ready, (position[candidate], candidate))

    # Values left over are part of (or blocked by) a cycle.
    if len(ret) != len(position):
        raise ValueError("Non sortable sequence")
    return ret

In [None]:
def check_updates(data, *, valid=True):
    """Sum the middle page of every update whose ordering status matches ``valid``.

    Every update is sorted with the precedence rules parsed from ``data``.
    With ``valid=True`` (part 1), updates that were already in order are
    counted. With ``valid=False`` (part 2), the out-of-order updates are
    counted, using the middle page of their corrected order.
    """
    precedence, updates = parse_updates(data)
    middles = []
    for update in updates:
        corrected = sort(update, precedence)
        was_in_order = corrected == update
        if valid == was_in_order:
            middles.append(corrected[len(corrected) // 2])
    return sum(middles)

In [None]:
# Part 1: validate on the examples in `tests` (check comes from aoc_utils and presumably
# compares against the expected part-1 answers — TODO confirm), then solve the real input.
check(check_updates, tests)
check_updates(data)

# Part 2

In [None]:
# Part 2: same function with valid=False (sum over the corrected, previously invalid updates).
# NOTE(review): `2` presumably selects the part-2 expected column of `tests`, and valid=False is
# presumably forwarded to check_updates by aoc_utils.check — confirm against its signature.
check(check_updates, tests, 2, valid=False)
check_updates(data, valid=False)

# Lessons learned

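The built-in function `sorted` (combined with `cmp_to_key`) is not suitable for partial orders: it assumes the comparison function describes a consistent total order, so any pair for which the comparator returns 0 is treated as equal. For the same reason, it also fails on total orders that are only partially visible through the comparison function, i.e. when some pairs are ordered only transitively (1 < 2 and 2 < 3 imply 1 < 3, yet the comparator returns 0 for 1 and 3).  
Even so, a `sorted`-based solution did give the expected answers on my AoC input, presumably because the rules there relate every pair of pages within each update.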

In [None]:
from functools import cmp_to_key, partial

In [None]:
def comparator(a, b, precedence):
    """Three-way comparison based only on the explicit rules in ``precedence``.

    Meant to be wrapped with ``functools.cmp_to_key``. Returns -1 when a rule
    puts ``a`` before ``b``, 1 when a rule puts ``b`` before ``a``, and 0 when
    no rule relates them, so ``sorted`` treats the pair as equal. Every call
    is printed so the demonstrations show exactly which pairs get compared.
    """
    def precedes(first, second):
        # Check membership first so a defaultdict is never grown by the lookup.
        return first in precedence and second in precedence[first]

    cmp = -1 if precedes(a, b) else (1 if precedes(b, a) else 0)
    print(f"comparing {a} and {b} -> {cmp}")
    return cmp

## Partial order

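In the example below, 1 must come before 3, and 2 can go anywhere.  
`sorted` only compares neighbouring elements here (see the printed comparisons). No rule covers those pairs, so it never learns that 1 must precede 3 and returns the incorrect `[3, 2, 1]`.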

In [None]:
# Only rule: 1 before 3. The printed comparisons show which pairs `sorted` actually inspects.
sorted([3, 2, 1], key=cmp_to_key(partial(comparator, precedence={1: {3}})))

## Total order

In the example below, the precedence states that 1 < 2 and 2 < 3, so the only correct order is `[1, 2, 3]`.  
However, `sorted` only compares 1 and 3 (no rule relates them, so they are treated as equal) and then 1 and 2 (already in the right order). It never compares 2 with 3, so it returns the incorrect `[3, 1, 2]`.

In [None]:
# Rules 1 before 2 and 2 before 3: a total order, but 1 vs 3 is only implied transitively.
sorted([3, 1, 2], key=cmp_to_key(partial(comparator, precedence={1: {2}, 2: {3}})))