In [62]:
# Approach: convert the ordering rules into "before" and "after" dictionaries,
# and convert each update into a list of page numbers.
# Then, for each page in an update, check the pages on either side of it
# against the rule dictionaries.

def parse_input(input_str: str):
    """Parse the puzzle text into updates and ordering-rule lookup tables.

    The text has two sections separated by a blank line: ``X|Y`` rule lines
    (page X must appear before page Y), then comma-separated update lines.

    Args:
        input_str: Raw puzzle text. Surrounding whitespace, a trailing
            newline, extra blank lines and Windows ``\\r\\n`` line endings
            are tolerated.

    Returns:
        A tuple ``(update_list, before_dict, after_dict)``:
        ``update_list`` is a list of updates, each a list of ints;
        ``before_dict[x]`` lists the pages that must come after ``x``;
        ``after_dict[y]`` lists the pages that must come before ``y``.
    """
    # Normalise line endings and strip the ends so that a trailing newline does
    # not yield an empty update line (int('') would raise ValueError).
    text = input_str.replace('\r\n', '\n').strip()
    rules_str, _, updates_str = text.partition('\n\n')

    # before_dict: keys must come before values
    before_dict = dict()
    # after_dict: keys come after values
    after_dict = dict()
    for line in rules_str.splitlines():
        if not line.strip():
            continue
        before_str, after_str = line.split('|')
        before, after = int(before_str), int(after_str)
        before_dict.setdefault(before, []).append(after)
        after_dict.setdefault(after, []).append(before)

    # Each non-blank update line becomes a list of page numbers.
    update_list = [
        [int(page) for page in line.split(',')]
        for line in updates_str.splitlines()
        if line.strip()
    ]

    return update_list, before_dict, after_dict

class Manual:
    """Check print-queue updates against page-ordering rules (part 1).

    Attributes:
        update_list: the updates to evaluate, each a list of page numbers.
        before_rules: maps a page to the pages that must come after it.
        after_rules: maps a page to the pages that must come before it.
    """

    def __init__(self, update_list: list, before_rules: dict, after_rules: dict):
        self.update_list = update_list
        self.before_rules = before_rules
        self.after_rules = after_rules

    def checkUpdate(self, update: list) -> bool:
        """Return True if ``update`` breaks none of the ordering rules.

        For every page, no earlier page may be one the rules say must come
        after it, and no later page may be one that must come before it.
        """
        for index, page in enumerate(update):
            # Slices are simply empty at either end, so no index guards needed.
            must_follow = self.before_rules.get(page, [])
            if any(earlier in must_follow for earlier in update[:index]):
                return False
            must_precede = self.after_rules.get(page, [])
            if any(later in must_precede for later in update[index + 1:]):
                return False
        return True

    def getMiddleNumber(self, update: list):
        """Return the middle element of ``update``.

        Puzzle updates have odd length; for an even length this returns the
        later of the two middle elements. Raises IndexError if ``update`` is
        empty.
        """
        return update[len(update) // 2]

    def part1(self) -> int:
        """Sum the middle page numbers of the updates that are already valid."""
        return sum(
            self.getMiddleNumber(update)
            for update in self.update_list
            if self.checkUpdate(update)
        )

In [63]:
with open(file='data/test/5.txt', mode='r', encoding='utf-8') as f:
    input_str = f.read()

# Part 1 on the worked example; the cell displays the returned sum.
update_list, before_dict, after_dict = parse_input(input_str)
manual = Manual(
    update_list=update_list,
    before_rules=before_dict,
    after_rules=after_dict,
)
manual.part1()


143

In [64]:
with open(file='data/input/5.txt', mode='r', encoding='utf-8') as f:
    input_str = f.read()

# Part 1 on the full puzzle input; the cell displays the returned sum.
update_list, before_dict, after_dict = parse_input(input_str)
manual = Manual(
    update_list=update_list,
    before_rules=before_dict,
    after_rules=after_dict,
)
manual.part1()

5091

In [65]:
#part2
from collections import deque
from tqdm.notebook import tqdm

class Manual2(Manual):
    """Extend Manual with reordering of invalid updates (part 2)."""

    def fixUpdate(self, update: list):
        """Return a copy of ``update`` reordered so that it obeys the rules.

        Pages are popped from the right end of a deque. A page is placed once
        no page still waiting in the deque is required (via ``before_rules``)
        to come after it, i.e. it may go last among the remaining pages.
        Otherwise it is rotated to the left end and retried later. Placed
        pages are therefore collected back-to-front and reversed at the end.

        Raises:
            ValueError: if the rules restricted to these pages are cyclic, so
                no valid ordering exists (instead of looping forever).
        """
        pending = deque(update)
        ordered = []
        # Consecutive rotations without placing a page. Once every pending
        # page has been retried with no progress, none can ever be placed.
        stalled = 0
        while pending:
            page = pending.pop()
            must_follow = self.before_rules.get(page, [])
            if any(other in pending for other in must_follow):
                pending.appendleft(page)
                stalled += 1
                if stalled >= len(pending):
                    raise ValueError(
                        f'ordering rules are cyclic for pages {sorted(pending)}'
                    )
                continue
            ordered.append(page)
            stalled = 0

        # Pages were chosen from the back of the final ordering; flip them.
        ordered.reverse()
        return ordered

    def part2(self) -> int:
        """Sum the middle page numbers of invalid updates after fixing them."""
        return sum(
            self.getMiddleNumber(self.fixUpdate(update))
            for update in self.update_list
            if not self.checkUpdate(update)
        )



In [66]:
with open(file='data/test/5.txt', mode='r', encoding='utf-8') as f:
    input_str = f.read()

# Part 2 on the worked example; the cell displays the returned sum.
update_list, before_dict, after_dict = parse_input(input_str)
manual = Manual2(
    update_list=update_list,
    before_rules=before_dict,
    after_rules=after_dict,
)
manual.part2()

123

In [67]:
with open(file='data/input/5.txt', mode='r', encoding='utf-8') as f:
    input_str = f.read()

# Part 2 on the full puzzle input; the cell displays the returned sum.
update_list, before_dict, after_dict = parse_input(input_str)
manual = Manual2(
    update_list=update_list,
    before_rules=before_dict,
    after_rules=after_dict,
)
manual.part2()

4681