In [1]:
# Small worked example used to check each implementation below.
# Layout: one line holding the starting sequence of letters, a blank line,
# then one "AB -> C" rule per line (pair of letters -> letter to insert).
SAMPLE_TEXT = """
NNCB

CH -> B
HH -> N
CB -> H
NH -> C
HB -> C
HC -> B
HN -> C
NN -> C
BH -> H
NC -> B
NB -> B
BN -> B
BB -> N
BC -> B
CC -> N
CN -> C
"""

In [2]:
def tokenize_line(line):
    """Classify a single input line.

    Returns:
        ("mapping", pair, insert) when the line contains " -> ", where
        ``pair`` and ``insert`` are the text left and right of the arrow;
        otherwise ("chars", [...]) with the line broken into characters.
    """
    if " -> " not in line:
        return "chars", list(line)
    pair, insert = line.split(" -> ")
    return "mapping", pair, insert


def parse_text(raw_text):
    """Tokenize every non-empty line of ``raw_text``, preserving line order."""
    tokens = []
    for row in raw_text.split("\n"):
        if not row:
            # Blank separator lines carry no information.
            continue
        tokens.append(tokenize_line(row))
    return tokens


def read_input(path="input.txt"):
    """Return the full text of the puzzle input file.

    Args:
        path: File to read. Defaults to "input.txt" in the current working
            directory, so existing ``read_input()`` calls are unaffected.

    Returns:
        The file contents as one string.

    Raises:
        OSError: If the file cannot be opened.
    """
    # Explicit encoding so the result does not depend on the platform locale.
    with open(path, "rt", encoding="utf-8") as f:
        return f.read()

def split_by_type(lines):
    """Separate tokenized lines into the starting sequence and the rules.

    Args:
        lines: Tuples as produced by ``tokenize_line``.

    Returns:
        A pair ``(chars, rules)``: the character list of the first "chars"
        entry, and a dict keyed by the tuple of characters of each rule's
        left-hand side whose value is the rule's right-hand side.

    Raises:
        IndexError: If ``lines`` contains no "chars" entry.
    """
    templates = [entry[1] for entry in lines if entry[0] == 'chars']
    first_template = templates[0]
    rules = {}
    for entry in lines:
        if entry[0] == 'mapping':
            rules[tuple(entry[1])] = entry[2]
    return first_template, rules

In [3]:
from collections import Counter, defaultdict
from functools import cache

In [4]:
# Naive approach that expands an actual list of characters
def expand(chars, mapping):
    """Perform one insertion step and return the result as a new list.

    For every adjacent pair ``(a, b)`` in ``chars`` the element
    ``mapping[a, b]`` is placed between them; the original elements keep
    their relative order. ``chars`` itself is not modified.
    """
    expanded = [chars[0]]
    for left, right in zip(chars, chars[1:]):
        expanded += [mapping[left, right], right]
    return expanded

def expand_n_times(chars, mapping, n):
    """Apply ``expand`` ``n`` times in a row, starting from a copy of ``chars``.

    Returns the fully materialised sequence after the last step; with
    ``n == 0`` that is just the copy.
    """
    current = chars[:]
    for _ in range(n):
        current = expand(current, mapping)
    return current

In [6]:
def expand_n_times_v2(chars, mapping, n):
    """Count the elements present after ``n`` insertion steps.

    A slightly better approach than building the whole list: each adjacent
    pair is followed by recursive descent and only the tallies are kept.

    Args:
        chars: Starting sequence of elements (may be empty).
        mapping: Dict keyed by ``(a, b)`` giving the element inserted
            between ``a`` and ``b``.
        n: Number of insertion steps. Zero or less means no step is applied,
            matching ``expand_n_times``.

    Returns:
        A ``defaultdict(int)`` mapping each element to its number of
        occurrences.

    Raises:
        KeyError: If a pair reached during the descent has no rule.
    """
    count = defaultdict(int)

    def descend(a, b, steps_left):
        # With no steps left the pair stays as it is and adds nothing new.
        # Checking this first is what makes n == 0 a true no-op.
        if steps_left <= 0:
            return
        c = mapping[a, b]
        count[c] += 1
        descend(a, c, steps_left - 1)
        descend(c, b, steps_left - 1)

    for i, ch in enumerate(chars):
        count[ch] += 1
        if i > 0:
            descend(chars[i - 1], ch, n)
    return count


In [7]:
def expand_n_times_v3(chars, mapping, n):
    """Count the elements present after ``n`` insertion steps, memoized.

    Same recursive descent as v2, but each ``(a, b, steps_left)`` call is
    cached: there is only a relatively small number of distinct
    (letter 1, letter 2, n) combinations.

    Args:
        chars: Starting sequence of hashable elements.
        mapping: Dict keyed by ``(a, b)`` giving the element inserted
            between ``a`` and ``b``.
        n: Number of insertion steps. Zero or less means no step is applied,
            matching ``expand_n_times``.

    Returns:
        A ``Counter`` mapping each element to its number of occurrences.

    Raises:
        KeyError: If a pair reached during the descent has no rule.
    """
    count = Counter(chars)

    # The cache belongs to this closure, so it is rebuilt on every call.
    @cache
    def descend(a, b, steps_left):
        # With no steps left the pair adds nothing new. Checking this first
        # is what makes n == 0 a true no-op.
        if steps_left <= 0:
            return Counter()
        c = mapping[a, b]
        # Start from a fresh Counter: the `+=` below mutates only `result`,
        # never the cached Counters returned by the recursive calls.
        result = Counter(c)
        result += descend(a, c, steps_left - 1)
        result += descend(c, b, steps_left - 1)
        return result

    for i in range(1, len(chars)):
        count += descend(chars[i - 1], chars[i], n)
    return count

In [8]:
# Sample data, naive approach: build the full list after 10 steps, then tally it.
chars, mappings = split_by_type(parse_text(SAMPLE_TEXT))
# Note: `chars` is rebound here and now holds the expanded list, not the template.
chars = expand_n_times(chars, mappings, 10)
count = Counter(chars)
# Cell output: (element, count) for the most and the least common element.
count.most_common()[0], count.most_common()[-1]

(('B', 1749), ('H', 161))

In [9]:
# Sample data, recursive-descent counter (v2), 10 steps.
chars, mappings = split_by_type(parse_text(SAMPLE_TEXT))
count = expand_n_times_v2(chars, mappings, 10)
# v2 returns a plain defaultdict (no most_common), so compare raw max/min counts.
max(count.values()), min(count.values())

(1749, 161)

In [10]:
# Sample data, memoized counter (v3), 10 steps.
chars, mappings = split_by_type(parse_text(SAMPLE_TEXT))
count = expand_n_times_v3(chars, mappings, 10)
# Cell output: (element, count) for the most and the least common element.
count.most_common()[0], count.most_common()[-1]

(('B', 1749), ('H', 161))

In [11]:
# Comparing performance of different approaches
# `chars` and `mappings` are whatever the most recently executed cell left in
# the kernel; run top to bottom, that is the sample template and rules.
%timeit expand_n_times(chars, mappings, 10)

913 µs ± 9.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [12]:
# Same measurement for the recursive, un-memoized counter (v2).
%timeit expand_n_times_v2(chars, mappings, 10)

973 µs ± 2.91 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [13]:
# The @cache decorates the inner `descend`, which is re-created on every call
# to expand_n_times_v3, so the cache is reset and each timed run starts cold.
%timeit expand_n_times_v3(chars, mappings, 10)

612 µs ± 77.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)


In [14]:
# Puzzle input (input.txt via read_input), naive approach, 10 steps.
chars, mappings = split_by_type(parse_text(read_input()))
# Note: `chars` is rebound here and now holds the expanded list, not the template.
chars = expand_n_times(chars, mappings, 10)
count = Counter(chars)
# Cell output: (element, count) for the most and the least common element.
count.most_common()[0], count.most_common()[-1]

(('S', 2979), ('P', 976))

In [15]:
# Puzzle input, recursive-descent counter (v2), 10 steps.
chars, mappings = split_by_type(parse_text(read_input()))
count = expand_n_times_v2(chars, mappings, 10)
# v2 returns a plain defaultdict (no most_common), so compare raw max/min counts.
max(count.values()), min(count.values())

(2979, 976)

In [16]:
# Puzzle input, memoized counter (v3), 10 steps.
chars, mappings = split_by_type(parse_text(read_input()))
count = expand_n_times_v3(chars, mappings, 10)
# Cell output: (element, count) for the most and the least common element.
count.most_common()[0], count.most_common()[-1]

(('S', 2979), ('P', 976))

In [20]:
# Largest minus smallest count after 10 steps on input.txt; both literals are
# copied by hand from the output of the cell above.
2979 - 976

2003

In [17]:
# part 2 - with sample text
# 40 steps instead of 10; only the memoized counter (v3) is used here.
chars, mappings = split_by_type(parse_text(SAMPLE_TEXT))
count = expand_n_times_v3(chars, mappings, 40)
# Cell output: (element, count) for the most and the least common element.
count.most_common()[0], count.most_common()[-1]

(('B', 2192039569602), ('H', 3849876073))

In [46]:
# Largest minus smallest count after 40 steps on the sample; both literals are
# copied by hand from the output of the cell above.
2192039569602 - 3849876073

2188189693529

In [18]:
# part 2 - problem text
# 40 steps on input.txt with the memoized counter (v3).
chars, mappings = split_by_type(parse_text(read_input()))
count = expand_n_times_v3(chars, mappings, 40)
# Cell output: (element, count) for the most and the least common element.
count.most_common()[0], count.most_common()[-1]

(('B', 3225985458057), ('P', 949341457946))

In [19]:
# Largest minus smallest count after 40 steps on input.txt; both literals are
# copied by hand from the output of the cell above.
3225985458057 - 949341457946

2276644000111