In [None]:
from aoc.utils import download_input
from functools import cmp_to_key

In [None]:
# Fetch the Advent of Code 2024 day 5 puzzle input; the returned path is read below.
file_path = download_input(year=2024, day=5, output_dir="input_files")

In [None]:
def parse_input(input_text):
    """Parse the puzzle text into ordering rules and page updates.

    The input has two sections separated by one or more blank lines:
    first ``X|Y`` rule lines, then comma-separated update lines.

    Args:
        input_text: Raw puzzle text. LF and CRLF line endings are both accepted,
            as is a separator made of several (or whitespace-only) blank lines.

    Returns:
        A tuple ``(rules, updates)`` where ``rules`` is a list of ``(x, y)``
        int tuples and ``updates`` is a list of lists of ints.

    Raises:
        ValueError: If a rule or update line is not made of integers in the
            expected ``X|Y`` / ``a,b,c`` format.
    """
    rules = []
    updates = []
    in_updates_section = False

    # splitlines() handles both "\n" and "\r\n", unlike split("\n\n").
    for raw_line in input_text.strip().splitlines():
        line = raw_line.strip()
        if not line:
            # The first blank line ends the rules section; extra blank
            # lines anywhere are simply ignored.
            in_updates_section = True
            continue

        if in_updates_section:
            updates.append([int(page) for page in line.split(",")])
        else:
            x, y = map(int, line.split("|"))
            rules.append((x, y))

    return rules, updates


def is_update_valid(update, rules):
    """Return True if no applicable rule is violated by the update's order.

    A rule ``(x, y)`` applies only when both pages appear in ``update``; it is
    violated when ``x`` sits after ``y``. If a page repeats, its last position
    is the one that counts.
    """
    position_of = {}
    for index, page in enumerate(update):
        position_of[page] = index

    return not any(
        x in position_of and y in position_of and position_of[x] > position_of[y]
        for x, y in rules
    )


def find_middle_page(update):
    """Return the element at index ``len(update) // 2``.

    For odd-length updates this is the exact middle page; for even-length
    ones it is the upper of the two central pages. An empty update raises
    IndexError.
    """
    return update[len(update) // 2]


def part1(input_text):
    """Return the sum of the middle pages of all correctly ordered updates.

    Args:
        input_text: Raw puzzle text (rules section, blank line, updates).

    Returns:
        The integer sum; 0 if no update is valid.
    """
    rules, updates = parse_input(input_text)
    return sum(
        find_middle_page(update)
        for update in updates
        if is_update_valid(update, rules)
    )

In [None]:
# Example input from the puzzle statement: a rules section, a blank line,
# then one update per line.
example_input = """
47|53
97|13
97|61
97|47
75|29
61|13
75|53
29|13
97|29
53|29
61|53
97|53
61|29
47|13
75|47
97|75
47|61
75|61
47|29
75|13
53|13

75,47,61,53,29
97,61,53,29,13
75,29,13
75,97,47,61,53
61,13,29
97,13,75,29,47
"""

# Solve the example and check it against the answer given in the puzzle
# (valid updates 1-3 have middle pages 61 + 53 + 29 = 143).
result = part1(example_input)
print(f"Sum of middle pages for valid updates: {result}")
assert result == 143, f"part1 example: expected 143, got {result}"

In [None]:
# Read the downloaded puzzle input with an explicit encoding so the result
# does not depend on the platform's default locale.
with open(file_path, "r", encoding="utf-8") as file:
    input_string = file.read()

# Solve part 1 on the real puzzle input (not the example).
result = part1(input_string)
print(f"Sum of middle pages for valid updates: {result}")

In [None]:
def sort_update(update, rules):
    """Return a new list with the pages of ``update`` reordered to satisfy ``rules``.

    Only rules whose two pages both occur in the update are considered. The
    ordering is a topological sort (Kahn's algorithm) over the update's
    positions. A comparator passed to ``sorted`` is not enough here: pairs
    with no direct rule would compare as "equal", which is not a consistent
    ordering when constraints are only implied transitively.

    When several orders are valid, pages with no constraint between them keep
    their original relative order (the smallest original position is emitted
    first). If the applicable rules are contradictory (a cycle), the pages
    that cannot be ordered are appended in their original order rather than
    raising.

    Args:
        update: List of page numbers; duplicates are kept.
        rules: Iterable of ``(x, y)`` pairs meaning "x must come before y".

    Returns:
        A new list containing exactly the elements of ``update``.
    """
    import heapq

    # Build the lookup once instead of rescanning every rule per comparison.
    rule_set = set(rules)
    n = len(update)

    # Graph over positions (not page values) so duplicate pages are handled.
    successors = [[] for _ in range(n)]
    indegree = [0] * n
    for i, before in enumerate(update):
        for j, after in enumerate(update):
            if i != j and (before, after) in rule_set:
                successors[i].append(j)
                indegree[j] += 1

    # Min-heap of ready positions: ties resolve toward the original order.
    ready = [i for i in range(n) if indegree[i] == 0]
    heapq.heapify(ready)
    order = []
    while ready:
        i = heapq.heappop(ready)
        order.append(i)
        for j in successors[i]:
            indegree[j] -= 1
            if indegree[j] == 0:
                heapq.heappush(ready, j)

    if len(order) < n:
        # Cycle among the applicable rules: best effort, keep leftovers in place order.
        placed = set(order)
        order.extend(i for i in range(n) if i not in placed)

    return [update[i] for i in order]

def part2(input_text):
    """Return the sum of middle pages of the incorrectly ordered updates after fixing them.

    Updates that already satisfy the rules are skipped; each remaining one is
    reordered with ``sort_update`` and its middle page is added to the total.

    Args:
        input_text: Raw puzzle text (rules section, blank line, updates).

    Returns:
        The integer sum; 0 if every update is already valid.
    """
    rules, updates = parse_input(input_text)

    total = 0
    for update in updates:
        if is_update_valid(update, rules):
            continue  # already correctly ordered; part 2 only counts fixed ones
        total += find_middle_page(sort_update(update, rules))
    return total

In [None]:
# Solve the example and check it against the answer given in the puzzle
# (corrected middle pages 47 + 29 + 47 = 123).
result = part2(example_input)
print(f"Sum of middle pages for corrected updates: {result}")
assert result == 123, f"part2 example: expected 123, got {result}"

# Solve part 2 on the real puzzle input.
result = part2(input_string)
print(f"Sum of middle pages for corrected updates: {result}")