In [None]:
from pathlib import Path
from collections import defaultdict

In [None]:
test_input_1 = """NNCB

CH -> B
HH -> N
CB -> H
NH -> C
HB -> C
HC -> B
HN -> C
NN -> C
BH -> H
NC -> B
NB -> B
BN -> B
BB -> N
BC -> B
CC -> N
CN -> C
"""

input_1 = Path("input_1.txt").read_text()

In [None]:
def parse_input(input_string):
    """Parse puzzle text into the polymer template and the pair-insertion rules.

    The expected layout is a template line, a blank line, then one
    ``XY -> Z`` rule per line.

    Args:
        input_string: Raw puzzle text. Windows (``\\r\\n``) line endings,
            surrounding whitespace and extra blank lines are tolerated.

    Returns:
        A ``(polymer_template, insertion_rules)`` tuple, where
        ``insertion_rules`` maps each two-element pair (e.g. ``"CH"``) to the
        element inserted between them (e.g. ``"B"``).

    Raises:
        ValueError: if a non-blank rule line does not contain exactly one ``->``.
    """
    # Normalise line endings so the blank-line separator is always "\n\n".
    text = input_string.replace("\r\n", "\n").strip()
    # partition splits at the FIRST blank line only, so a trailing "\n\n" or
    # extra blank lines in the rules section cannot break the unpacking.
    template_section, _, rules_section = text.partition("\n\n")

    insertion_rules = {}
    for row in rules_section.splitlines():
        if not row.strip():
            continue  # ignore stray blank lines between/after rules
        pair, element = (part.strip() for part in row.split("->"))
        insertion_rules[pair] = element
    return template_section.strip(), insertion_rules

def insertion_rules_to_bigrams_rules(insertion_rules):
    """Expand each insertion rule into the two bigrams it produces.

    Inserting ``Z`` into the pair ``XY`` yields ``XZY``, i.e. the bigrams
    ``XZ`` and ``ZY``.

    Args:
        insertion_rules: Mapping from pair (e.g. ``"CH"``) to the inserted element.

    Returns:
        Dict mapping each pair to a ``(left_bigram, right_bigram)`` tuple.
    """
    bigram_rules = {}
    for pair, inserted in insertion_rules.items():
        first, last = pair[0], pair[1]
        bigram_rules[pair] = (first + inserted, inserted + last)
    return bigram_rules
    
def next_step(polymer, insertion_rules):
    """Apply one round of pair insertion to ``polymer``.

    For every adjacent pair of elements, the element given by
    ``insertion_rules`` (if any) is inserted between them. Pairs without a
    rule are left untouched.

    Args:
        polymer: Current polymer string.
        insertion_rules: Mapping from two-element pair to the element to insert.

    Returns:
        The new polymer string (``""`` for an empty polymer).
    """
    # Collect pieces and join once: repeated ``str +=`` can be quadratic in the
    # polymer length, which grows with every step.
    pieces = []
    for left, right in zip(polymer, polymer[1:]):
        pieces.append(left)
        pieces.append(insertion_rules.get(left + right, ""))
    # zip never yields the final element as ``left`` (it has no right-hand
    # neighbour), so append it explicitly when there is one.
    if polymer:
        pieces.append(polymer[-1])
    return "".join(pieces)

def step_n(steps, polymer, insertion_rules):
    """Run ``next_step`` repeatedly, materialising the full polymer string.

    Because the whole string is built on every round, this is only practical
    for small ``steps``; ``score_after_step`` works on bigram counts instead.

    Args:
        steps: Number of insertion rounds to apply.
        polymer: Starting polymer template.
        insertion_rules: Mapping from pair to the inserted element.

    Returns:
        The polymer string after ``steps`` rounds.
    """
    current_polymer = polymer
    for _ in range(steps):
        current_polymer = next_step(current_polymer, insertion_rules)
    return current_polymer

def next_bigram_counts(bigram_counts, bigram_rules):
    """Advance bigram counts by one insertion round.

    Each bigram that has a rule splits into its two child bigrams, and each
    child inherits the parent's count. A bigram with no rule gets no
    insertion and is carried over unchanged. This matches ``next_step``,
    which inserts ``""`` for pairs missing from the rules.

    Args:
        bigram_counts: Mapping from bigram to number of occurrences.
        bigram_rules: Mapping from bigram to its ``(left, right)`` child
            bigrams, as built by ``insertion_rules_to_bigrams_rules``.

    Returns:
        A ``defaultdict(int)`` of bigram counts after one round.
    """
    new_counts = defaultdict(int)
    for existing_bigram, count in bigram_counts.items():
        # No rule -> no insertion: the pair survives as-is instead of raising KeyError.
        children = bigram_rules.get(existing_bigram, (existing_bigram,))
        for new_bigram in children:
            new_counts[new_bigram] += count
    return new_counts

def bigram_counts_affter_step(steps, polymer, insertion_rules):
    """Count adjacent element pairs (bigrams) after ``steps`` insertion rounds.

    Unlike ``step_n``, this never builds the polymer string. It only tracks
    how often each bigram occurs, so each round loops over the distinct
    bigrams rather than over the whole polymer.

    NOTE: the "affter" misspelling is kept because callers use this name.

    Args:
        steps: Number of insertion rounds to apply.
        polymer: Starting polymer template.
        insertion_rules: Mapping from pair to the inserted element.

    Returns:
        A ``defaultdict(int)`` mapping each bigram to its count.
    """
    bigram_rules = insertion_rules_to_bigrams_rules(insertion_rules)

    # Pair every element with its right-hand neighbour to seed the counts.
    bigram_counts = defaultdict(int)
    for left, right in zip(polymer, polymer[1:]):
        bigram_counts[left + right] += 1

    for _ in range(steps):
        bigram_counts = next_bigram_counts(bigram_counts, bigram_rules)
    return bigram_counts
    
def score_after_step(steps, polymer, insertion_rules):
    """Return the most-common minus least-common element count after ``steps`` rounds.

    Element totals are recovered from bigram counts. Every element of the
    final polymer is the first character of exactly one bigram, except the
    last element, which has no right-hand neighbour. Insertions only happen
    between existing elements, so the template's last element is also the
    final polymer's last element; it is added once by hand.

    Args:
        steps: Number of insertion rounds to apply.
        polymer: Starting polymer template.
        insertion_rules: Mapping from pair to the inserted element.

    Returns:
        ``max(count) - min(count)`` over the element counts.
    """
    pair_counts = bigram_counts_affter_step(steps, polymer, insertion_rules)

    totals = defaultdict(int)
    totals[polymer[-1]] = 1
    for bigram, occurrences in pair_counts.items():
        totals[bigram[0]] += occurrences

    frequencies = list(totals.values())
    return max(frequencies) - min(frequencies)

In [None]:
# Part 1 - Test: check the string-building simulation against known polymers
# for steps 1-4, then check the bigram-count scorer after 10 steps.
polymer_template, insertion_rules = parse_input(test_input_1)
assert next_step(polymer_template, insertion_rules) == "NCNBCHB"
assert step_n(2, polymer_template, insertion_rules) == "NBCCNBBBCBHCB"
assert step_n(3, polymer_template, insertion_rules) == "NBBBCNCCNBBNBNBBCHBHHBCHB"
assert step_n(4, polymer_template, insertion_rules) == "NBBNBNBBCCNBCNCCNBBNBBNBBBNBBNBBCBHCBHHNHCBBCBHCB"

# Score = most-common minus least-common element count after 10 steps.
assert score_after_step(10, polymer_template, insertion_rules) == 1588

In [None]:
# Part 1: score of the real puzzle input after 10 insertion steps
# (the last expression is displayed as the cell output).
polymer_template, insertion_rules = parse_input(input_1)
score_after_step(10, polymer_template, insertion_rules)

In [None]:
# Part 2 - test: 40 steps on the example input. The example is re-parsed
# because the Part 1 cell rebound polymer_template / insertion_rules to the real input.
polymer_template, insertion_rules = parse_input(test_input_1)
assert score_after_step(40, polymer_template, insertion_rules) == 2188189693529

In [None]:
# Part 2: score of the real input after 40 steps. Only the bigram-count scorer
# is used here; building the 40-step polymer string with step_n would be far too large.
polymer_template, insertion_rules = parse_input(input_1)
score_after_step(40, polymer_template, insertion_rules)