In [61]:
def read_file(input_file_path):
    """Return the lines of the text file at ``input_file_path`` as a list.

    Each returned string keeps its trailing newline, if the file had one.
    """
    with open(input_file_path, 'r') as handle:
        # Iterating a text file yields the same lines as readlines().
        return list(handle)


def preprocess(s):
    """Split the raw input lines into ordering rules and updates.

    Args:
        s: list of text lines. Rule lines of the form ``"A|B"`` come first,
            then a blank line, then update lines of comma-separated integers.

    Returns:
        A ``(rules, updates)`` tuple: ``rules`` is a list of ``(int, int)``
        tuples and ``updates`` is a list of lists of ints.

    Raises:
        ValueError: if there is no blank separator line, or a value is not
            an integer.
        IndexError: if a rule line does not contain ``'|'``.
    """
    # separate two parts: the first blank (whitespace-only) line is the boundary
    separater = next(
        (i for i, line in enumerate(s) if not line.strip()),
        None,
    )
    if separater is None:
        raise ValueError("no blank line separating rules from updates")

    # clean up format
    rules = []
    for line in s[:separater]:
        parts = line.strip().split('|')
        rules.append((int(parts[0]), int(parts[1])))

    # Skip blank lines (e.g. trailing newlines at the end of the file) so
    # they are not fed to int().
    updates = [
        [int(val) for val in line.strip().split(',')]
        for line in s[separater + 1:]
        if line.strip()
    ]

    return rules, updates


def get_allowed_after_vals(rules):
    """Group the rules by their first value.

    Returns a dict mapping each rule's first value to the list of second
    values from every rule that starts with it, in rule order.
    """
    after_by_first = {}
    for rule in rules:
        before, after = rule[0], rule[1]
        # setdefault creates the list the first time a value is seen.
        after_by_first.setdefault(before, []).append(after)
    return after_by_first
        
def check_update(update, rules):
    """Return True if ``update`` respects every rule that applies to it.

    A rule ``(first, later)`` applies only when both values occur in
    ``update``; it is respected when ``first`` is not positioned after
    ``later``.
    """
    applicable = [r for r in rules if r[0] in update and r[1] in update]
    return all(update.index(r[0]) <= update.index(r[1]) for r in applicable)

def check_udpate_v2(update:list, allowed_after_vals:dict):
    """Return True if no value in ``update`` is preceded by one of its after-values.

    Gives the same verdict as check_update, but reads the rules from the
    ``{first: [later, ...]}`` mapping that part 2 also uses.
    """
    seen_so_far = set()
    for value in update:
        must_follow = allowed_after_vals.get(value)
        # A violation is any already-seen value that the rules place after `value`.
        if must_follow and not seen_so_far.isdisjoint(must_follow):
            return False
        seen_so_far.add(value)
    return True

def find_mid_num(lst):
    """Return the element at index ``len(lst) // 2`` (the middle for odd lengths)."""
    middle_index = len(lst) // 2
    return lst[middle_index]
    
def find_corrected_mid_num(update:list, allowed_after_vals:dict):
    """Find the value that would be in the middle of ``update`` had it been ordered.

    The update is not actually sorted. Instead, for each value the number
    of other values in ``update`` that the rules place after it is counted;
    the first value whose count matches the number of positions following
    index ``len(update) // 2`` is returned.

    Args:
        update: list of distinct values.
        allowed_after_vals: dict mapping a value to the list of values that
            must come after it. Values missing from the dict are treated as
            having no required followers.

    Returns:
        The matching value, or None when no value has the required count
        (presumably only when the rules do not fully order the update --
        TODO confirm against the input).

    Raises:
        ValueError: if ``update`` contains duplicate values.
    """
    members = set(update)
    if len(members) != len(update):
        # Explicit check instead of assert, which is stripped under -O.
        raise ValueError("update must not contain duplicate values")

    # Number of values positioned after index len(update)//2, the index
    # find_mid_num reads. Equal to len(update)//2 for odd lengths.
    correct_after_value_count = len(update) - 1 - len(update) // 2
    for x in update:
        # A value without rules has no followers; it can still be the
        # middle of a single-element update.
        allowed_vals_after_x = allowed_after_vals.get(x) or ()
        followers_in_update = members.intersection(allowed_vals_after_x)
        followers_in_update.discard(x)  # only count the *other* values
        if len(followers_in_update) == correct_after_value_count:
            return x
    return None


def solution_pt1(input_file_path):
    """Sum the middle values of the updates that already satisfy every rule.

    Reads and parses the file at ``input_file_path``, keeps the updates for
    which check_update returns True, and adds up their find_mid_num values.
    """
    rules, updates = preprocess(read_file(input_file_path))
    return sum(
        find_mid_num(pages)
        for pages in updates
        if check_update(pages, rules)
    )


def solution_pt2(input_file_path):
    """Sum the would-be middle values of the updates that break a rule.

    Reads and parses the file at ``input_file_path``, selects the updates
    rejected by check_udpate_v2, and adds up their find_corrected_mid_num
    values.
    """
    rules, updates = preprocess(read_file(input_file_path))
    allowed_after_vals = get_allowed_after_vals(rules)

    # check_update(update, rules) could be used for this filter as well.
    incorrect_updates = [
        pages for pages in updates
        if not check_udpate_v2(pages, allowed_after_vals)
    ]
    return sum(
        find_corrected_mid_num(pages, allowed_after_vals)
        for pages in incorrect_updates
    )

In [97]:
# Run part 1 on input_5.txt; the returned sum is shown as the cell output.
solution_pt1('input_5.txt')

5129

In [62]:
# Run part 2 on input_5.txt; the returned sum is shown as the cell output.
solution_pt2('input_5.txt')

4077

scratchpad

In [3]:
def read_file(input_file_path):
    """Read the text file at ``input_file_path`` and return its lines as a list."""
    with open(input_file_path, 'r') as source:
        return source.readlines()

# Load the lines of test.txt.
s = read_file('test.txt')

# The first line that is exactly '\n' separates the rule lines from the update lines.
separater = s.index('\n')
rules, updates = s[:separater], s[separater+1:]

In [4]:
# Turn each "A|B" rule line into an (A, B) tuple of ints, then display the result.
rules = [
    (int(parts[0]), int(parts[1]))
    for parts in (line.rstrip().split('|') for line in rules)
]
rules

[(47, 53),
 (97, 13),
 (97, 61),
 (97, 47),
 (75, 29),
 (61, 13),
 (75, 53),
 (29, 13),
 (97, 29),
 (53, 29),
 (61, 53),
 (97, 53),
 (61, 29),
 (47, 13),
 (75, 47),
 (97, 75),
 (47, 61),
 (75, 61),
 (47, 29),
 (75, 13),
 (53, 13)]

In [5]:
# Turn each comma-separated update line into a list of ints, then display the result.
updates = [
    [int(val) for val in line.rstrip().split(',')]
    for line in updates
]
updates

[[75, 47, 61, 53, 29],
 [97, 61, 53, 29, 13],
 [75, 29, 13],
 [75, 97, 47, 61, 53],
 [61, 13, 29],
 [97, 13, 75, 29, 47]]

In [6]:
def find_mid_num(lst):
    """Return the item stored at position ``len(lst) // 2`` of ``lst``."""
    half, _ = divmod(len(lst), 2)
    return lst[half]

part 1

In [92]:
def check_update(update, rules):
    """Tell whether ``update`` obeys all rules whose two values both occur in it.

    A rule ``(first, later)`` is broken when ``first`` is found at a later
    index than ``later``.
    """
    applicable = [rule for rule in rules if rule[0] in update and rule[1] in update]
    broken = (update.index(rule[0]) > update.index(rule[1]) for rule in applicable)
    return not any(broken)

In [95]:
# One flag per update: does it satisfy the rules?
checks = [check_update(pages, rules) for pages in updates]
print(checks)
# Add up the middle value of every update whose flag is True.
sum(find_mid_num(pages) for pages, ok in zip(updates, checks) if ok)

[True, True, True, False, False, False]


143

part 2

In [7]:
# Display the parsed rules again for reference.
rules

[(47, 53),
 (97, 13),
 (97, 61),
 (97, 47),
 (75, 29),
 (61, 13),
 (75, 53),
 (29, 13),
 (97, 29),
 (53, 29),
 (61, 53),
 (97, 53),
 (61, 29),
 (47, 13),
 (75, 47),
 (97, 75),
 (47, 61),
 (75, 61),
 (47, 29),
 (75, 13),
 (53, 13)]

In [8]:
def get_allowed_after_vals(rules):
    """Build a ``{first: [later, ...]}`` mapping from ``(first, later)`` rules.

    The later values are listed in the order their rules appear.
    """
    grouped = {}
    for rule in rules:
        key, follower = rule[0], rule[1]
        if key in grouped:
            grouped[key].append(follower)
        else:
            grouped[key] = [follower]
    return grouped
        

In [12]:
# Build the {first: [later, ...]} lookup from the parsed rules and display it.
allowed_after_vals = get_allowed_after_vals(rules)
allowed_after_vals

{47: [53, 13, 61, 29],
 97: [13, 61, 47, 29, 53, 75],
 75: [29, 53, 47, 61, 13],
 61: [13, 53, 29],
 29: [13],
 53: [29, 13]}

In [17]:
# Display the parsed updates again for reference.
updates

[[75, 47, 61, 53, 29],
 [97, 61, 53, 29, 13],
 [75, 29, 13],
 [75, 97, 47, 61, 53],
 [61, 13, 29],
 [97, 13, 75, 29, 47]]

In [24]:
def check_udpate_v2(update:list, allowed_after_vals:dict):
    """Return False as soon as a value is preceded by one of its after-values.

    ``allowed_after_vals`` maps a value to the values that must come after
    it; values without an entry are never the cause of a rejection.
    """
    for position, value in enumerate(update):
        must_follow = allowed_after_vals.get(value)
        if not must_follow:
            continue
        earlier_values = update[:position]
        # Any earlier value that belongs after `value` breaks a rule.
        if any(earlier in must_follow for earlier in earlier_values):
            return False
    return True
# Flag each update, keep the ones that fail the check, and display them.
checks = [check_udpate_v2(pages, allowed_after_vals) for pages in updates]
incorrect_updates = [pages for pages, ok in zip(updates, checks) if not ok]
incorrect_updates


[[75, 97, 47, 61, 53], [61, 13, 29], [97, 13, 75, 29, 47]]

In [25]:
def find_corrected_mid_num(update:list, allowed_after_vals:dict):
    """Find the value that would be in the middle of ``update`` had it been ordered.

    The update is not sorted. For each value, the other values of ``update``
    that the rules place after it are counted; the first value whose count
    equals the number of positions following index ``len(update) // 2`` is
    returned.

    Args:
        update: list of values, assumed to contain no duplicates.
        allowed_after_vals: dict mapping a value to the list of values that
            must come after it. Values missing from the dict are treated as
            having no required followers.

    Returns:
        The matching value, or None when no value has the required count.
    """
    members = set(update)  # assume no duplicates in update
    # Number of values positioned after index len(update)//2, the index
    # find_mid_num reads. Equal to len(update)//2 for odd lengths.
    correct_after_value_count = len(update) - 1 - len(update) // 2
    for x in update:
        # A value without rules has no followers; it can still be the
        # middle of a single-element update.
        allowed_vals_after_x = allowed_after_vals.get(x) or ()
        followers_in_update = members.intersection(allowed_vals_after_x)
        followers_in_update.discard(x)  # only count the *other* values
        if len(followers_in_update) == correct_after_value_count:
            return x
    return None

In [27]:
# Add up the would-be middle value of every incorrect update.
sum(find_corrected_mid_num(pages, allowed_after_vals) for pages in incorrect_updates)

123