In [1]:
import pathlib
import collections

## part 1 ##

In [2]:
testlines = '''NNCB

CH -> B
HH -> N
CB -> H
NH -> C
HB -> C
HC -> B
HN -> C
NN -> C
BH -> H
NC -> B
NB -> B
BN -> B
BB -> N
BC -> B
CC -> N
CN -> C'''.splitlines()

In [3]:
puzzlelines = pathlib.Path('day14.txt').read_text().splitlines()

In [4]:
def parse(lines):
    template = lines[0]
    rules = {}
    for line in lines[2:]:
        lhs, rhs = line.split('->')
        rules[lhs.strip()] = rhs.strip()
    return template, rules

In [5]:
def get_pairs(s):
    return [''.join(p) for p in zip(s, s[1:])]

In [6]:
def insert(s, rules):
    pairs = get_pairs(s)
    res = []
    for pair in pairs:
        # Note the 2nd part of the pair isn't appended, because it will be the first part of the next pair
        res.append(pair[0])
        res.append(rules.get(pair, ''))
    # Do need to append the second part of the last pair, though, since no next pair to be had
    res.append(s[-1])
    return ''.join(res)


In [7]:
def solve1(lines, iterations):
    tmpl, rules = parse(lines)
    s = tmpl
    for i in range(iterations):
        s = insert(s, rules)
    c = collections.Counter(s)
    common = c.most_common()
    return common[0][1]-common[-1][1]

In [8]:
solve1(testlines, 10)

1588

In [9]:
solve1(puzzlelines, 10)

2112

## part 2 ##

Brute force takes too long. We don't need the actual strings, though, just the elemental formulae, so work w/
pairs instead of strings. Note that each insertion creates 2 pairs, but counting pairs doubles the characters 
except for the first and last.

In [10]:
pd = collections.defaultdict(int)

In [11]:
t, r = parse(testlines)
for pair in get_pairs(t):
    pd[pair] += 1
pd

defaultdict(int, {'NN': 1, 'NC': 1, 'CB': 1})

In [12]:
def insert2(pairdict, rules):
    newpd = collections.defaultdict(int)
    for rule in rules:
        if rule in pairdict:
            numpairs = pairdict[rule]
            insert = rules[rule]
            p1 = ''.join([rule[0], insert])
            p2 = ''.join([insert, rule[1]])
            newpd[p1] += numpairs
            newpd[p2] += numpairs
    return newpd

In [13]:
insert2(pd, r)

defaultdict(int, {'CH': 1, 'HB': 1, 'NC': 1, 'CN': 1, 'NB': 1, 'BC': 1})

In [14]:
def pairs_to_elements(pairdict, first, last):
    c = collections.defaultdict(int)
    for pair in pairdict:
        num = pairdict[pair]
        c1, c2 = pair[0], pair[1]
        c[c1] += num
        c[c2] += num
    c[first] += 1
    c[last] += 1
    for ele in c:
        c[ele] /= 2
    return c

In [15]:
pairs_to_elements(insert2(pd, r), t[0], t[-1])

defaultdict(int, {'C': 2.0, 'H': 1.0, 'B': 2.0, 'N': 2.0})

In [18]:
def solve2(lines, iterations):
    tmpl, rules = parse(lines)
    pd = collections.defaultdict(int)
    for pair in get_pairs(tmpl):
        pd[pair] += 1
    newpd = pd.copy()
    for i in range(iterations):
        newpd = insert2(newpd, rules)
    elements = pairs_to_elements(newpd, tmpl[0], tmpl[-1])
    sortedkeys = sorted(elements, key=elements.get)
    return elements[sortedkeys[-1]] - elements[sortedkeys[0]]
    

In [16]:
tp = pairs_to_elements(insert2(pd, r), t[0], t[-1])
tp

defaultdict(int, {'C': 2.0, 'H': 1.0, 'B': 2.0, 'N': 2.0})

In [17]:
sorted(tp, key=tp.get)

['H', 'C', 'B', 'N']

In [19]:
solve2(testlines, 10)

1588.0

In [20]:
solve2(puzzlelines, 10)

2112.0

In [21]:
solve2(puzzlelines, 40)

3243771149914.0