In [11]:
from typing import Dict, List, Set, Tuple

In [12]:
# Worked example: ordering rules "X|Y" (page X must appear before page Y in an update),
# then a blank line, then one comma-separated update per line.
s_example = """47|53
97|13
97|61
97|47
75|29
61|13
75|53
29|13
97|29
53|29
61|53
97|53
61|29
47|13
75|47
97|75
47|61
75|61
47|29
75|13
53|13

75,47,61,53,29
97,61,53,29,13
75,29,13
75,97,47,61,53
61,13,29
97,13,75,29,47"""

with open('input.txt', 'r') as f:
    # Strip only trailing whitespace/newlines: slicing off the last character
    # would cut a digit when the file does not end with a newline.
    s_input = f.read().rstrip()

def process_string(s: str) -> Tuple[List[Tuple[int, int]], List[List[int]]]:
    """Parse puzzle text into ordering rules and updates.

    Lines containing ``|`` are rules ``"X|Y"`` and become ``(X, Y)`` tuples.
    Every other non-blank line is an update: a comma-separated list of page
    numbers (a single page with no comma is also accepted). Blank lines are
    skipped, and ``\\r\\n`` line endings are handled.

    Returns:
        ``(rules, updates)`` in the order they appear in ``s``.

    Raises:
        ValueError: if a rule does not have exactly two parts or a field is
        not an integer.
    """
    rules: List[Tuple[int, int]] = []
    updates: List[List[int]] = []
    for raw_line in s.splitlines():
        line = raw_line.strip()
        if not line:
            continue  # separator between the rules section and the updates section
        if '|' in line:
            before, after = line.split('|')
            rules.append((int(before), int(after)))
        else:
            updates.append([int(page) for page in line.split(',')])
    return rules, updates

# Part 1

In [17]:
def make_rules_dict(rules: List[Tuple[int, int]]) -> Dict[int, Set[int]]:
    """Group ordering rules by their left-hand page.

    Each rule ``(before, after)`` requires ``before`` to appear earlier than
    ``after`` in an update. The result maps every ``before`` page to the set
    of pages that must follow it; pages that never occur on the left side of a
    rule have no key.
    """
    successors: Dict[int, Set[int]] = {}
    for before, after in rules:
        if before not in successors:
            successors[before] = set()
        successors[before].add(after)
    return successors

def is_valid(rules_dict: Dict[int, Set[int]], update: List[int]) -> bool:
    """Return True when ``update`` violates none of the ordering rules.

    ``rules_dict`` maps a page to the set of pages that must come after it.
    An update is invalid if some page appears after a page it is required to
    precede. Every earlier page is checked, not only the immediate neighbour,
    so the result stays correct when the rules do not cover every adjacent
    pair or are not transitive.
    """
    seen: Set[int] = set()
    for page in update:
        # Any already-placed page that ``page`` must precede is a violation.
        if not rules_dict.get(page, set()).isdisjoint(seen):
            return False
        seen.add(page)
    return True

def compute(s: str) -> int:
    """Part 1: sum the middle page of every update that is already correctly ordered.

    ``s`` is the raw puzzle text, parsed with ``process_string``.
    """
    rules, updates = process_string(s)
    rules_dict = make_rules_dict(rules)
    # (len - 1) // 2 is the exact middle index for odd-length updates.
    return sum(update[(len(update) - 1) // 2] for update in updates if is_valid(rules_dict, update))

In [18]:
part1_example = compute(s_example)
# Regression guard: fails loudly if a change breaks the worked example (recorded answer 143).
assert part1_example == 143, part1_example
part1_example

143

In [19]:
# Part 1 on the real puzzle input loaded from input.txt.
compute(s_input)

5391

Time complexity: O(n * L) where n is the number of updates and L is the maximum length of an update

# Part 2

In [20]:
def sort(rules_dict: Dict[int, Set[int]], update: List[int]) -> List[int]:
    """Return a new list with the pages of ``update`` reordered to satisfy the rules.

    ``rules_dict`` maps a page to the set of pages that must come after it.
    The result is in first-to-last order. Pages without any entry in
    ``rules_dict`` are handled (no ``KeyError``), and an empty update yields
    an empty list.

    Two pages with no rule between them compare as equal, so their relative
    order from ``update`` is kept (``sorted`` is stable). This assumes the
    rules restricted to the pages of ``update`` form a consistent order; with
    contradictory rules the output is a permutation of ``update`` that cannot
    satisfy every rule.
    """
    from functools import cmp_to_key

    def compare(a: int, b: int) -> int:
        if b in rules_dict.get(a, set()):  # rule a|b: a goes first
            return -1
        if a in rules_dict.get(b, set()):  # rule b|a: b goes first
            return 1
        return 0

    return sorted(update, key=cmp_to_key(compare))

def compute_incorrect(s: str) -> int:
    """Part 2: reorder every incorrectly ordered update and sum their middle pages.

    ``s`` is the raw puzzle text, parsed with ``process_string``. Updates that
    already satisfy the rules are ignored.
    """
    rules, updates = process_string(s)
    rules_dict = make_rules_dict(rules)
    # Lazily fix only the updates that break a rule, keeping the original order of work.
    repaired = (sort(rules_dict, upd) for upd in updates if not is_valid(rules_dict, upd))
    return sum(pages[(len(pages) - 1) // 2] for pages in repaired)

In [21]:
part2_example = compute_incorrect(s_example)
# Regression guard: fails loudly if a change breaks the worked example (recorded answer 123).
assert part2_example == 123, part2_example
part2_example

123

In [22]:
# Part 2 on the real puzzle input loaded from input.txt.
compute_incorrect(s_input)

6142

Time complexity: O(n * L^2) where n is the number of updates and L is the maximum length of an update