In [None]:
from collections import defaultdict
from functools import cmp_to_key, partial

from aocd import get_data, submit

# Puzzle day and year, passed to aocd's get_data() and submit() in later cells.
DAY = 5
YEAR = 2024

In [None]:
# Test data: "before|after" page-ordering rules, a blank line, then one
# comma-separated update per line.
raw_test = """47|53
97|13
97|61
97|47
75|29
61|13
75|53
29|13
97|29
53|29
61|53
97|53
61|29
47|13
75|47
97|75
47|61
75|61
47|29
75|13
53|13

75,47,61,53,29
97,61,53,29,13
75,29,13
75,97,47,61,53
61,13,29
97,13,75,29,47"""


# Real data for DAY/YEAR, fetched through aocd.
raw = get_data(day=DAY, year=YEAR)

# print() rather than a bare expression so the newlines are rendered.
print(raw_test)

In [None]:
def parse_rules(rules):
    """Build a lookup of the pages that must come after each page.

    Args:
        rules: Iterable of rule strings of the form ``"X|Y"``, where both
            ``X`` and ``Y`` are integers.

    Returns:
        A plain ``dict`` mapping each left-hand page ``X`` to the sorted list
        of right-hand pages ``Y`` it was paired with. Keys appear in order of
        first occurrence; repeated rules yield repeated entries.
    """
    followers = defaultdict(list)
    for line in rules:
        before, after = (int(page) for page in line.split("|"))
        followers[before].append(after)

    # The comprehension already yields a plain dict, so no extra copy is needed.
    return {page: sorted(later) for page, later in followers.items()}


def parse_updates(updates):
    """Parse the update section into a list of page-number lists.

    Args:
        updates: The update section of the input as a single string, with one
            comma-separated update per line.

    Returns:
        A list containing one ``list[int]`` per non-blank line, in input order.
        An empty or all-blank string yields an empty list.

    Raises:
        ValueError: If a field on a non-blank line is not an integer.
    """
    # splitlines() copes with a trailing newline and "\r\n" line endings, and
    # blank lines are skipped so they never reach int("") and raise ValueError.
    return [
        [int(page) for page in line.split(",")]
        for line in updates.splitlines()
        if line.strip()
    ]


def parse_data(data):
    """Split the raw input into parsed rules and parsed updates.

    Args:
        data: Full input text; the rule section and the update section are
            separated by exactly one blank line.

    Returns:
        A ``(rules, updates)`` tuple: the result of ``parse_rules`` on the
        rule lines and the result of ``parse_updates`` on the update section.
    """
    rules_section, updates_section = data.split("\n\n")
    rule_lines = rules_section.split("\n")
    return parse_rules(rule_lines), parse_updates(updates_section)


# Parse both inputs into (rules, updates) pairs: `dummy` comes from the test
# data, `real` from the data fetched through aocd.
dummy = parse_data(raw_test)
real = parse_data(raw)

# Part 1


In [None]:
def update_ordering_key(u1, u2, rules):
    """Three-way comparison of two pages under the ordering rules.

    Args:
        u1: First page.
        u2: Second page.
        rules: Mapping of a page to the collection of pages listed after it.

    Returns:
        ``-1`` if ``rules`` lists ``u2`` after ``u1``, ``1`` if it lists ``u1``
        after ``u2``, and ``0`` if neither rule exists. When both rules exist
        the first check wins and ``-1`` is returned.
    """
    no_followers = ()
    if u2 in rules.get(u1, no_followers):
        return -1
    if u1 in rules.get(u2, no_followers):
        return 1
    return 0


def order_update(update, rules):
    """Return a new list holding the pages of ``update`` sorted by ``rules``.

    Args:
        update: Iterable of pages to sort; it is not modified.
        rules: Rule mapping forwarded to ``update_ordering_key``.

    Returns:
        A new list of the same pages, ordered with ``update_ordering_key`` as
        the pairwise comparison.
    """

    def compare(left, right):
        return update_ordering_key(left, right, rules=rules)

    ordered = list(update)
    ordered.sort(key=cmp_to_key(compare))
    return ordered


def is_ordered(update, rules):
    """Tell whether ``update`` already equals its rule-sorted form.

    Args:
        update: List of pages in their given order.
        rules: Rule mapping forwarded to ``order_update``.

    Returns:
        ``True`` if sorting ``update`` with ``order_update`` leaves it
        unchanged, ``False`` otherwise.
    """
    expected = order_update(update, rules)
    return update == expected


def get_middle(update):
    """Return the element at index ``len(update) // 2``.

    For an odd-length sequence this is the middle element; for an even-length
    one it is the first element of the second half. An empty sequence raises
    ``IndexError``.
    """
    midpoint = len(update) // 2
    return update[midpoint]

In [None]:
# Part 1: add up the middle page of every update that is already in rule order.
rules, updates = real

result = sum(get_middle(pages) for pages in updates if is_ordered(pages, rules))
result

In [None]:
# Submit the value currently held in `result` as the answer to part "a" through aocd.
submit(result, part="a", day=DAY, year=YEAR)

# Part 2


In [None]:
# Part 2: keep only the updates that break the rules, re-sort each of them,
# and add up the middle pages of the corrected versions.
rules, updates = real

incorrect = [pages for pages in updates if not is_ordered(pages, rules)]
result = sum(get_middle(order_update(pages, rules)) for pages in incorrect)
result

In [None]:
# Submit the value currently held in `result` as the answer to part "b" through aocd.
submit(result, part="b", day=DAY, year=YEAR)