In [1]:
class Pattern:
    """One grammar rule: a quoted literal or a set of rule-id alternatives.

    Exactly one of the two attributes is populated; the other is None.
    Passing a non-None ``value`` makes a literal rule and ignores ``subpatterns``.
    """

    value: str  # literal text to match, or None for a compound rule
    subpatterns: list  # alternatives, each a list of rule ids; None for a literal

    def __init__(self, value=None, subpatterns=None):
        self.value = value  # None exactly when this is a compound rule
        self.subpatterns = subpatterns if value is None else None

    def __repr__(self):
        # Render in the puzzle's own notation, e.g. '"a"' or '2 3 | 3 2'.
        if self.value is None:
            return " | ".join(
                " ".join(map(str, sequence)) for sequence in self.subpatterns
            )
        return f'"{self.value}"'

    def __str__(self):
        return repr(self)

In [2]:
def get_input(fname="input.txt"):
    """Parse a puzzle file into grammar rules and the messages to check.

    The file is a block of rule lines such as ``0: 4 1 5``, ``1: 2 3 | 3 2``
    or ``4: "a"``, then a blank line, then one message per line.

    Args:
        fname: path of the file to read.

    Returns:
        ``(rules, messages)`` where ``rules`` maps each integer rule id to a
        ``Pattern`` (a literal for a quoted body, otherwise a list of
        alternatives, each a list of rule ids) and ``messages`` is the list
        of non-empty, whitespace-stripped message lines.

    Raises:
        ValueError: if no blank line separates the rules from the messages.
    """
    with open(fname) as f:
        lines = [line.strip() for line in f]

    try:
        split = lines.index("")
    except ValueError:
        raise ValueError(
            f"{fname}: expected a blank line between rules and messages"
        ) from None

    rules = {}
    for rule_line in lines[:split]:
        # maxsplit=1 so a ": " inside a quoted literal cannot break unpacking.
        rule_id, body = rule_line.split(": ", 1)
        literal = None
        subpatterns = []
        for alternative in body.split(" | "):
            if alternative.startswith('"'):
                literal = alternative.strip('"')
            else:
                # split() rather than split(" ") tolerates repeated spaces.
                subpatterns.append([int(r) for r in alternative.split()])
        # Rule-id alternatives take precedence over a literal if a rule ever
        # mixes both (same precedence as the previous implementation).
        if subpatterns:
            rules[int(rule_id)] = Pattern(subpatterns=subpatterns)
        else:
            rules[int(rule_id)] = Pattern(literal)

    # Drop trailing blank lines so they are not treated as (empty) messages.
    messages = [message for message in lines[split + 1:] if message]
    return rules, messages

In [3]:
# Small worked example grammar used to sanity-check the matcher.
test_rules, test_messages = get_input(fname="test.txt")

In [4]:
test_rules  # show the parsed example grammar in the puzzle's notation (Pattern.__repr__)

{0: 4 1 5, 1: 2 3 | 3 2, 2: 4 4 | 5 5, 3: 4 5 | 5 4, 4: "a", 5: "b"}

In [5]:
# The real puzzle input: grammar rules plus the messages to validate.
rules, messages = get_input(fname="input.txt")

In [6]:
def matches(message, rules, rule=0, start=0):
    """Find every way ``rule`` can match ``message`` beginning at ``start``.

    A literal rule matches when ``message`` continues with its text at
    ``start``. A compound rule tries each alternative in turn, threading the
    reachable end positions through its sequence of sub-rule ids.

    Args:
        message: the string being matched.
        rules: mapping from rule id to a Pattern-like object exposing
            ``value`` (literal text or None) and ``subpatterns``.
        rule: id of the rule to match; defaults to 0, the start rule.
        start: index in ``message`` where matching begins.

    Returns:
        A list of ``(rule, end)`` tuples, one per *distinct* end index, in
        first-found order. ``end`` is the index just past the matched text;
        the whole message matches when ``end == len(message)``.
    """
    # Nothing can match past the end of the message. This also bounds
    # right-recursive rules such as ``8: 42 | 42 8`` as long as the rules
    # they recurse through consume input.
    if start >= len(message):
        return []

    pattern = rules[rule]
    if pattern.value is not None:
        # Empty literals never match (as before); non-empty literals may now
        # be longer than a single character.
        if pattern.value and message.startswith(pattern.value, start):
            return [(rule, start + len(pattern.value))]
        return []

    # dict used as an insertion-ordered set: overlapping alternatives would
    # otherwise report the same end position repeatedly, inflating callers'
    # counts and multiplying work at every level of recursion.
    ends = {}
    for alternative in pattern.subpatterns:
        positions = [start]
        for sub_rule in alternative:
            next_positions = {}
            for pos in positions:
                for _, end in matches(message, rules, sub_rule, pos):
                    next_positions[end] = None
            positions = list(next_positions)
            if not positions:
                break
        ends.update(dict.fromkeys(positions))
    return [(rule, end) for end in ends]

In [7]:
test_matches = [matches(message, test_rules) for message in test_messages]
# A message is valid when some parse of rule 0 consumes all of it. Count
# messages, not parses: an ambiguous grammar can produce several full-length
# parses of the same message, which must still count once.
matching = sum(
    any(end == len(message) for _rule, end in message_matches)
    for message, message_matches in zip(test_messages, test_matches)
)
print(matching)

2


In [8]:
input_matches = [matches(message, rules) for message in messages]
# Count each message at most once, even if several parses of rule 0 consume
# the whole message.
matching = sum(
    any(end == len(message) for _rule, end in message_matches)
    for message, message_matches in zip(messages, input_matches)
)
print(matching)

285


In [9]:
# Second worked example, used for both the original and the looped grammar.
test2_rules, test2_messages = get_input(fname="test2.txt")

In [10]:
test2_matches = [matches(message, test2_rules) for message in test2_messages]
# Count each message at most once, even if several parses of rule 0 consume
# the whole message.
matching = sum(
    any(end == len(message) for _rule, end in message_matches)
    for message, message_matches in zip(test2_messages, test2_matches)
)
print(matching)

3


In [11]:
# Part 2 replaces rules 8 and 11 with looping versions. Build a new mapping
# instead of mutating test2_rules, so re-running the part-1 cell above still
# reports the part-1 answer.
test2_rules_part2 = {
    **test2_rules,
    8: Pattern(subpatterns=[[42], [42, 8]]),
    11: Pattern(subpatterns=[[42, 31], [42, 11, 31]]),
}
test2_matches = [matches(message, test2_rules_part2) for message in test2_messages]
# Count each message at most once, even if several parses of rule 0 consume
# the whole message.
matching = sum(
    any(end == len(message) for _rule, end in message_matches)
    for message, message_matches in zip(test2_messages, test2_matches)
)
print(matching)

12


In [12]:
# Part 2: make rules 8 and 11 self-referential.
# NOTE: this mutates `rules` in place, so re-running the part-1 cell after
# this one evaluates the looped grammar instead of the original.
rules.update({
    8: Pattern(subpatterns=[[42], [42, 8]]),
    11: Pattern(subpatterns=[[42, 31], [42, 11, 31]]),
})

In [13]:
%%time
input_matches = [matches(message, rules) for message in messages]
matching = 0
for i, m in enumerate(input_matches):
    for rule, end in m:
        if end == len(messages[i]):
            matching += 1
print(matching)

412
CPU times: user 752 ms, sys: 4.9 ms, total: 757 ms
Wall time: 762 ms
