In [1]:
# Imports & read file
import time
import itertools
import re

def read_file(filename):
    """Read a puzzle input consisting of a rules section and a messages section.

    Every line is whitespace-stripped and the file is split into runs of
    non-blank lines (blank lines act as separators). Exactly two runs are
    expected: the first holds rules written as ``<int>: <body>``, the second
    holds the messages.

    Args:
        filename: Path of the text file to read.

    Returns:
        A ``(rules, messages)`` tuple. ``rules`` maps each integer rule id to
        its body string (everything after the first ``": "``); ``messages`` is
        the list of stripped message lines.

    Raises:
        ValueError: If the file does not contain exactly two non-blank
            sections, if a rule line has no ``": "`` separator, or if a rule
            id is not an integer.
    """
    with open(filename) as infile:
        stripped_lines = (line.strip() for line in infile)
        # Materialize each group inside the ``with`` so the file is fully
        # consumed before it is closed.
        sections = [
            list(group)
            for has_text, group in itertools.groupby(stripped_lines, bool)
            if has_text
        ]

    if len(sections) != 2:
        raise ValueError(
            "expected 2 blank-line-separated sections (rules, messages), "
            "found {}".format(len(sections))
        )

    raw_rules, messages = sections
    rules = {}
    for rule_line in raw_rules:
        # Split on the first separator only so the body may itself contain ": ".
        number, body = rule_line.split(": ", 1)
        rules[int(number)] = body
    return rules, messages

In [2]:
# Part 1
def compile_rule(rule, rules, compiled):
    """Translate rule number ``rule`` into a regular-expression fragment.

    A rule body is either a double-quoted literal (e.g. ``"a"``) or one or
    more alternatives separated by ``" | "``, each alternative being a
    space-separated sequence of rule numbers. Referenced rules are compiled
    recursively and memoized in ``compiled``.

    Args:
        rule: Integer id of the rule to compile.
        rules: Mapping from integer rule id to its body string.
        compiled: Dict used as a memo of already compiled fragments; it is
            updated in place.

    Returns:
        The regex fragment as a string. Literals become their escaped text;
        every non-literal rule is wrapped in a parenthesized group.

    Raises:
        KeyError: If ``rule`` (or a rule it references) is missing from ``rules``.
        ValueError: If a non-literal body contains a token that is not an integer.
    """
    if rule in compiled:
        return compiled[rule]

    body = rules[rule]
    if len(body) >= 2 and body[0] == '"' and body[-1] == '"':
        # Terminal rule: accept any quoted literal, escaped so that regex
        # metacharacters match themselves. Plain letters are left unchanged.
        pattern = re.escape(body[1:-1])
    else:
        alternatives = [
            "".join(compile_rule(int(ref), rules, compiled) for ref in option.split(' '))
            for option in body.split(' | ')
        ]
        pattern = "(" + "|".join(alternatives) + ")"

    compiled[rule] = pattern
    return pattern

def follows(rule, messages):
    """Report, for each message, whether it matches ``rule`` in its entirety.

    Args:
        rule: Regular-expression source string.
        messages: Iterable of strings to test.

    Returns:
        A list of booleans, one per message, in input order.
    """
    pattern = re.compile(rule)
    # fullmatch requires the whole string to match: unlike "^...$" with
    # match(), it does not accept a trailing newline and is not affected by a
    # top-level "|" in the rule.
    return [pattern.fullmatch(message) is not None for message in messages]

In [3]:
# Test part 1
# Compiles rule 0 from "test01.txt" and compares the per-message results with
# the expected list [True, False, True, False, False]; prints True on success.
start = time.time()
r, m = read_file("test01.txt")
print(follows(compile_rule(0, r, {}), m) == [True, False, True, False, False])
# Final expression of the cell: elapsed wall-clock seconds since `start`.
time.time() - start

True


0.0

In [4]:
# Solve part 1
# Prints how many messages in "input.txt" match rule 0 (sum of the booleans
# returned by follows).
start = time.time()
r, m = read_file("input.txt")
print(sum(follows(compile_rule(0, r, {}), m)))
# Final expression of the cell: elapsed wall-clock seconds since `start`.
time.time() - start

235


0.018999099731445312

In [5]:
# Part 2
def compile_rule2(rule, rules, compiled):
    """Build the part-2 regex template for rule number ``rule``.

    Works like a recursive, memoized rule-to-regex translation with two
    special cases:

    * rule 8: a ``+`` is appended to its group.
    * rule 11: its sub-rules are joined with ``)11(`` and the whole group is
      surrounded by ``)11`` ... ``11(``. The fragment is therefore not
      balanced on its own; the literal text ``11`` acts as a marker that
      ``follows2`` splits on to insert repetition counts.

    Args:
        rule: Integer id of the rule to compile.
        rules: Mapping from integer rule id to its body string.
        compiled: Memo dict of already compiled fragments, updated in place.

    Returns:
        The (template) regex fragment for ``rule`` as a string.
    """
    if rule in compiled:
        return compiled[rule]

    body = rules[rule]
    if body in ('"a"', '"b"'):
        # Terminal rule: keep just the letter between the quotes.
        pattern = body[1]
    else:
        separator = ")11(" if rule == 11 else ""
        branches = []
        for alternative in body.split(' | '):
            pieces = [
                compile_rule2(int(ref), rules, compiled)
                for ref in alternative.split(' ')
            ]
            branches.append(separator.join(pieces))
        pattern = "(" + "|".join(branches) + ")"
        if rule == 8:
            pattern += '+'
        if rule == 11:
            pattern = ")11" + pattern + "11("

    compiled[rule] = pattern
    return pattern

def follows2(rule, messages, max11=5):
    """Match messages against a part-2 template with a bounded repetition depth.

    ``rule`` is a template containing the literal marker ``11`` (as produced
    by ``compile_rule2``). It is split on that marker; the second and third
    pieces each receive a ``{k}`` quantifier, and one regex is built for every
    depth ``k`` in ``range(1, max11)``. A message is accepted if any of those
    regexes matches it in full.

    Args:
        rule: Template regex string containing at least two ``11`` markers.
        messages: Iterable of strings to test.
        max11: Exclusive upper bound on the depth; depths ``1 .. max11 - 1``
            are tried. With ``max11 <= 1`` no depth is tried and every message
            is reported as not matching.

    Returns:
        A list of booleans, one per message, in input order.

    Raises:
        IndexError: If ``rule`` contains fewer than two ``11`` markers.
    """
    pieces = rule.split('11')
    head, first_repeated, second_repeated = pieces[0], pieces[1], pieces[2]
    tail = "".join(pieces[3:])

    patterns = []
    for depth in range(1, max11):
        # Build the quantifier by concatenation instead of str.format on the
        # whole regex, so braces elsewhere in the pattern are left untouched.
        quantifier = "{%d}" % depth
        patterns.append(re.compile(
            head + first_repeated + quantifier + second_repeated + quantifier + tail
        ))

    # fullmatch: the entire message must match (no trailing newline allowed).
    return [
        any(pattern.fullmatch(message) is not None for pattern in patterns)
        for message in messages
    ]

In [6]:
# Test part 2
# Counts the messages in "test02.txt" accepted by the part-2 matcher and
# prints whether that count equals the expected 12.
start = time.time()
r, m = read_file("test02.txt")
print(sum(follows2(compile_rule2(0, r, {}), m)) == 12)
# Final expression of the cell: elapsed wall-clock seconds since `start`.
time.time() - start

True


0.005999326705932617

In [7]:
# Solve part 2
# Prints how many messages in "input.txt" are accepted by the part-2 matcher
# (follows2 with its default max11).
start = time.time()
r, m = read_file("input.txt")
print(sum(follows2(compile_rule2(0, r, {}), m)))
# Final expression of the cell: elapsed wall-clock seconds since `start`.
time.time() - start

379


0.08100676536560059