In [145]:
from collections import defaultdict, deque
from functools import cmp_to_key
from math import floor


def read_rules():
    """Load the ordering rules from ``day05_rules.txt``.

    Every non-blank line is split on ``|``; the first field becomes a key
    and the second field is appended to that key's list, preserving file
    order.  Fields are kept as strings (no numeric conversion).

    Blank lines are skipped and leading/trailing whitespace on a line is
    ignored, so a trailing empty line at the end of the file no longer
    raises ``IndexError``.

    Returns:
        defaultdict(list): first field -> list of second fields.

    Raises:
        IndexError: if a non-blank line contains no ``|`` separator.
    """
    rules_dict = defaultdict(list)
    with open('day05_rules.txt', 'r') as file:
        for raw in file:
            entry = raw.strip()
            if not entry:
                # ''.split('|') is [''], which has no second field.
                continue
            fields = entry.split('|')
            rules_dict[fields[0]].append(fields[1])

    return rules_dict

def check_line_order(line, rules_dict):
    """Return True when ``line`` does not violate any rule in ``rules_dict``.

    Args:
        line: sequence of items in the order they appear.
        rules_dict: mapping from an item to the collection of items that
            are required to come after it.

    Returns:
        bool: False as soon as some item appears before another item that
        the rules say must precede it; True otherwise (including for empty
        and single-item lines).

    Notes:
        A pair of items with no rule between them is unconstrained and does
        not make the line invalid.  When every pair in ``line`` is covered
        by exactly one rule, this gives the same answer as requiring each
        later item to be listed under the earlier one.
        Lookups go through ``.get`` so a plain ``dict`` works and a
        ``defaultdict`` is not grown as a side effect of checking.
    """
    for index, item in enumerate(line):
        for later in line[index + 1:]:
            # A rule "later -> item" means `item` had to come after `later`,
            # but here it comes first: that is a violation.
            if item in rules_dict.get(later, ()):
                return False
    return True

def build_sort_dicts():
    """Build two integer lookup tables from ``day05_rules.txt``.

    Every non-blank line is split on ``|`` and its first two fields are
    converted to ``int`` as ``before`` and ``after``.

    Blank lines are skipped, so a trailing empty line at the end of the
    file no longer raises ``IndexError``.

    Returns:
        tuple: ``(smaller, larger)``, both ``defaultdict(set)``.
        ``smaller[before]`` holds every ``after`` paired with it and
        ``larger[after]`` holds every ``before`` paired with it.

    Raises:
        IndexError: if a non-blank line contains no ``|`` separator.
        ValueError: if a field is not a valid integer.
    """
    smaller = defaultdict(set)
    larger = defaultdict(set)

    with open('day05_rules.txt', 'r') as file:
        for raw in file:
            entry = raw.strip()
            if not entry:
                # ''.split('|') is [''], which has no second field.
                continue
            fields = entry.split('|')
            before, after = int(fields[0]), int(fields[1])
            smaller[before].add(after)
            larger[after].add(before)

    return smaller, larger

def cmp(before, after, smaller, larger):
    """Three-way comparison of ``before`` and ``after`` using two tables.

    Checks are applied in this order and the first match wins:

    1. ``after`` is in ``smaller[before]``  -> -1
    2. ``before`` is in ``smaller[after]``  -> 1
    3. ``after`` is in ``larger[before]``   -> 1
    4. ``before`` is in ``larger[after]``   -> -1

    Returns 0 when none of them match.  Missing keys are treated as having
    no entries and are never inserted into either mapping.
    """
    nothing = ()

    if after in smaller.get(before, nothing):
        return -1

    # Checks 2 and 3 both yield 1 and are adjacent, so they can share a branch.
    if before in smaller.get(after, nothing) or after in larger.get(before, nothing):
        return 1

    if before in larger.get(after, nothing):
        return -1

    return 0
            
def part1():
    """Print the sum of the middle elements of all correctly ordered lines.

    Reads the rules via ``read_rules()``, then walks ``day05_updates.txt``
    line by line.  Each non-blank line is split on ``,``; when
    ``check_line_order`` accepts it, the element at index
    ``(len - 1) // 2`` is converted to ``int`` and added to the total.
    The total is printed; nothing is returned.

    Blank lines are skipped: such a line would split into ``['']``, be
    accepted as a single-item line, and then crash in ``int('')``.
    """
    rules_dict = read_rules()
    middle_sum = 0

    with open('day05_updates.txt', 'r') as file:
        for raw in file:
            entry = raw.strip()
            if not entry:
                continue
            line = entry.split(',')
            if check_line_order(line, rules_dict):
                # Integer division gives the same index as int((len - 1) / 2)
                # without going through a float.
                middle_sum += int(line[(len(line) - 1) // 2])

    print(middle_sum)

def part2():
    """Print the sum of the middle elements of all lines that needed sorting.

    Builds the lookup tables via ``build_sort_dicts()``, then walks
    ``day05_updates.txt`` line by line.  Each non-blank line is parsed into
    a list of ints and sorted with ``cmp`` as the comparison function.
    When the sorted list differs from the original, the element at index
    ``len // 2`` of the sorted list is added to the total.  The total is
    printed; nothing is returned.

    Blank lines are skipped: such a line would otherwise crash in
    ``int('')``.
    """
    smaller, larger = build_sort_dicts()
    middle_sum = 0

    def compare(before, after):
        return cmp(before, after, smaller, larger)

    # The key wrapper does not depend on the line, so build it once.
    sort_key = cmp_to_key(compare)

    with open('day05_updates.txt', 'r') as file:
        for raw in file:
            entry = raw.strip()
            if not entry:
                continue
            line = [int(numb) for numb in entry.split(',')]
            sorted_list = sorted(line, key=sort_key)
            if line != sorted_list:
                middle_sum += sorted_list[len(sorted_list) // 2]

    print(middle_sum)

            

In [146]:
# Run both parts in order; each one prints its own total.
part1()
part2()

5391
6142
