In [1]:
import numpy as np
from collections import Counter
from functools import cache

In [2]:
# Worked example: polymer template on the first line, a blank line, then one
# `XY -> Z` pair-insertion rule per line (the layout the p1/p2 parsers expect).
test_input = '''NNCB

CH -> B
HH -> N
CB -> H
NH -> C
HB -> C
HC -> B
HN -> C
NN -> C
BH -> H
NC -> B
NB -> B
BN -> B
BB -> N
BC -> B
CC -> N
CN -> C'''

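Worked example: repeatedly applying the rules to the single pair `NN` (`NN -> C`, then `NC -> B` and `CN -> C`, and so on):

- NN
- NCN
- NBCCN
- NBBBCNCCN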

In [3]:
# Read the real puzzle input; the `with` block closes the file handle
# deterministically instead of leaving it for the garbage collector.
with open('inputs/14') as input_file:
    puzzle_input = input_file.read().strip()

In [4]:
def pair(s):
    '''
    Overlapping two-character windows of `s`, left to right.

    e.g. 'NNCB' -> ['NN', 'NC', 'CB']; inputs shorter than two characters
    give an empty list.
    '''
    return [''.join(s[i:i + 2]) for i in range(len(s) - 1)]

In [5]:
def collapse(strings):
    '''
    Join overlapping fragments, where each fragment's last character is
    assumed to be the first character of the next one.

    Every fragment except the last drops its final character before joining,
    so each shared boundary character appears once:
    ['NCN', 'NBC', 'CHB'] -> 'NCNBCHB'. An empty list collapses to ''.
    '''
    if not strings:
        return ''

    # str.join builds the result in a single pass; repeated `+=` on a str is
    # not guaranteed to be linear and can degrade to quadratic time.
    return ''.join(s[:-1] for s in strings[:-1]) + strings[-1]

In [6]:
def p1(puzzle_input, steps=10):
    '''
    Part 1: grow the polymer by brute-force string expansion.

    Applies the pair-insertion rules `steps` times to the template and returns
    the count of the most common letter minus the count of the least common
    one (0 for an empty template). Pairs with no rule are left unchanged,
    i.e. nothing is inserted between them.
    '''
    start, transition_lines = puzzle_input.split('\n\n')

    transition_map = dict(line.split(' -> ') for line in transition_lines.splitlines())
    # e.g. 'CH' -> 'CBH': each pair expands to itself with its insertion in the middle.
    full_transition_map = {a: a[0] + b + a[1] for a, b in transition_map.items()}

    working_string = start

    # A template with fewer than two letters has no pairs and never changes;
    # skipping the loop also keeps collapse() away from an empty list.
    if len(working_string) >= 2:
        for _ in range(steps):
            expanded = [full_transition_map.get(p, p) for p in pair(working_string)]
            working_string = collapse(expanded)

    mc = Counter(working_string).most_common()

    if not mc:
        return 0

    return mc[0][1] - mc[-1][1]

In [7]:
# Brute-force part 1 must reproduce the known answer for the worked example.
assert 1588 == p1(test_input)

In [8]:
# Part 1 answer for the real puzzle input (10 insertion steps).
p1(puzzle_input, steps=10)

2851

In [9]:
def p2(puzzle_input, steps=40):
    start, transition_lines = puzzle_input.split('\n\n')

    transition_map = {a: b for (a, b) in map(lambda s: s.split(' -> '), transition_lines.split('\n'))}
    
    letters = set(transition_map.values())
    letters_to_index = {l: i for i, l in enumerate(letters)}
    
    def fresh_counts():
        return np.zeros(len(letters))
    
    def letter_counts(s):
        counts = fresh_counts()

        for c in s:
            counts[letters_to_index[c]] += 1

        return counts
    
    @cache
    def evolve(pair, step):
        '''
        Letter counts for pair after n steps. 
        '''
        
        if step == 0:
            return letter_counts(pair)

        new_letter = transition_map[pair]
        new_letter_count = letter_counts(new_letter)

        return evolve(pair[0] + new_letter, step-1) + evolve(new_letter + pair[1], step-1) - new_letter_count
    
    counts = fresh_counts()

    for p in pair(start):
        counts += evolve(p, steps)

    # subtract one from the middle_letters: they appear twice
    for m in start[1:-1]:
        counts -= letter_counts(m)
        
    return int(counts.max() - counts.min())

In [10]:
def p2_2(puzzle_input, steps=40):
    '''
    Part 2, alternative: memoised recursion on whole strings.

    evolve(s, n) gives the letter counts of string `s` after `n` steps. One
    step is applied to `s` directly, then the expanded string is split into
    overlapping pairs that are evolved independently for the remaining n - 1
    steps. Counts are exact Python ints held in Counters (float64 arrays lose
    integer precision past 2**53), and pairs with no rule are left unchanged.

    Returns the count of the most common letter minus that of the least
    common letter actually present in the final polymer (0 for an empty
    template).
    '''
    start, transition_lines = puzzle_input.split('\n\n')

    transition_map = dict(line.split(' -> ') for line in transition_lines.splitlines())
    # e.g. 'CH' -> 'CBH': each pair expands to itself with its insertion in the middle.
    full_transition_map = {a: a[0] + b + a[1] for a, b in transition_map.items()}

    @cache
    def evolve(s, step):
        '''
        Letter counts for string `s` after `step` steps.
        '''
        # Fewer than two letters means there are no pairs, so the string never
        # changes (and collapse() is never handed an empty list).
        if step == 0 or len(s) < 2:
            return Counter(s)

        new_s = collapse([full_transition_map.get(p, p) for p in pair(s)])

        # Adjacent pairs of new_s share a letter, so the sum counts every
        # interior letter twice; subtract one copy of each. `+`/`-` build new
        # Counters, so cached results are never mutated.
        total = sum((evolve(p, step - 1) for p in pair(new_s)), Counter())

        return total - Counter(new_s[1:-1])

    counts = evolve(start, steps)

    if not counts:
        return 0

    # Counter only holds letters that occur, so absent letters (count 0) can
    # never be picked as the minimum.
    return max(counts.values()) - min(counts.values())

In [11]:
# The memoised solver must agree with brute-force part 1 where both are
# feasible, and reproduce the known 40-step answer for the worked example.
assert p2(test_input, steps=10) == p1(test_input)
assert p2(test_input) == 2188189693529

In [12]:
# The whole-string variant must also agree with brute-force part 1 where both
# are feasible, and reproduce the known 40-step answer for the worked example.
assert p2_2(test_input, steps=10) == p1(test_input)
assert p2_2(test_input) == 2188189693529

In [13]:
# Part 2 answer for the real puzzle input (40 insertion steps).
p2(puzzle_input, steps=40)

10002813279337

In [14]:
# Same part 2 answer, computed by the whole-string variant as a cross-check.
p2_2(puzzle_input, steps=40)

10002813279337