# Day 14: Extended Polymerization

In [1]:
from pathlib import Path
from collections import Counter, defaultdict
from more_itertools import take, pairwise, iterate, nth

from aoc2021.util import read_as_list

## Puzzle input data

In [2]:
def split_input(lines: list[str]) -> tuple[str, dict[str, str]]:
    """Split raw puzzle lines into the polymer template and the insertion rules.

    The first line is the template. Every later non-blank line of the form
    ``'AB -> C'`` becomes the rule ``{'AB': 'C'}``. Blank lines (the separator
    after the template, or a stray empty line at the end of the file) are
    skipped, and surrounding whitespace on each line is ignored.

    Args:
        lines: the puzzle input, one entry per line.

    Returns:
        A ``(template, rules)`` tuple.

    Raises:
        IndexError: if ``lines`` is empty.
        ValueError: if a non-blank rule line is not of the form ``'AB -> C'``.
    """
    template = lines[0].strip()
    rules: dict[str, str] = {}
    for raw in lines[1:]:
        line = raw.strip()
        if not line:
            # Separator or trailing blank line: carries no rule.
            continue
        # Unpacking fails with ValueError on a malformed line instead of
        # silently storing a broken rule.
        pair, element = line.split(' -> ')
        rules[pair] = element
    return template, rules

# Test data: a small example input (template 'NNCB', a blank separator line and
# 16 insertion rules) that the assertions in the Part 1 and Part 2 cells check against.
tdata = split_input([
    'NNCB',
    '',
    'CH -> B',  
    'HH -> N',  
    'CB -> H',  
    'NH -> C',  
    'HB -> C',  
    'HC -> B',  
    'HN -> C',  
    'NN -> C',  
    'BH -> H',  
    'NC -> B',  
    'NB -> B',  
    'BN -> B',  
    'BB -> N',  
    'BC -> B',  
    'CC -> N',  
    'CN -> C',
])

# Input data: parsed from 'day14-input.txt', resolved relative to the current
# working directory.
# NOTE(review): read_as_list presumably returns the file's lines with `func`
# applied to each one - confirm in aoc2021.util.
data = split_input(read_as_list(Path('./day14-input.txt'), func=str.rstrip))
# Preview only: the template and the first five rules. This must stay the last
# expression of the cell so that the notebook displays it.
data[0], dict(take(5, data[1].items()))

('SCSCSKKVVBKVFKSCCSOV',
 {'CP': 'C', 'SF': 'S', 'BH': 'F', 'SS': 'N', 'KB': 'N'})

## Puzzle answers
### Part 1

In [3]:
# Insertion rules: maps a two-element pair such as 'CH' to the element that is
# inserted between the two.
Rules = dict[str,str]
# Parsed puzzle input: (polymer template, insertion rules), the shape returned
# by split_input.
Input = tuple[str,Rules]


def step(formula: str, rules: Rules) -> str:
    return ''.join(a + rules.get(a+b, '') for a,b in pairwise(formula)) + formula[-1]


def formula_after(data: Input, nsteps: int) -> str:
    """Return the full polymer string after ``nsteps`` insertion steps.

    ``data`` is the ``(template, rules)`` pair; ``nsteps == 0`` gives back the
    template itself.
    """
    template, rules = data

    def grow(formula: str) -> str:
        # One insertion round using the rules captured from `data`.
        return step(formula, rules)

    # iterate() yields template, grow(template), grow(grow(template)), ...
    generations = iterate(grow, template)
    return nth(generations, nsteps)


def solution(formula: str) -> int:
    """Return the most common element's count minus the least common one's.

    Raises IndexError when ``formula`` is empty.
    """
    # most_common() lists (element, count) pairs from most to least frequent.
    ranked = Counter(formula).most_common()
    _, highest = ranked[0]
    _, lowest = ranked[-1]
    return highest - lowest


# Part 1 checks against the example data: the parsed template, the exact
# formula after steps 1-4, the formula length after 5 and 10 steps, and the
# final most-common minus least-common answer.
assert tdata[0] == 'NNCB'
assert step(tdata[0], tdata[1]) == 'NCNBCHB'
assert formula_after(tdata, 1) == 'NCNBCHB'
assert formula_after(tdata, 2) == 'NBCCNBBBCBHCB'
assert formula_after(tdata, 3) == 'NBBBCNCCNBBNBNBBCHBHHBCHB'
assert formula_after(tdata, 4) == 'NBBNBNBBCCNBCNCCNBBNBBNBBBNBBNBBCBHCBHHNHCBBCBHCB'
assert len(formula_after(tdata, 5)) == 97
assert len(formula_after(tdata, 10)) == 3073
# 'NBCCNBBBCBHCB' holds six B and one H.
assert solution('NBCCNBBBCBHCB') == 6-1
assert solution(formula_after(tdata, nsteps=10)) == 1588

In [4]:
# Part 1 answer for the real input, built from the explicit formula string.
# NOTE: this uses the one-argument `solution(formula)` from the Part 1 cell.
# The Part 2 cell redefines `solution(data, nsteps)`, so re-running this cell
# after Part 2 has executed raises a TypeError.
n = solution(formula_after(data, nsteps=10))
print(f'Subtracting the frequency of the least common element from that of the most common after 10 steps: {n}')

Subtracting the frequency of the least common element from that of the most common after 10 steps: 2112


### Part 2

In [5]:
def pair_step(freqs: dict[str, int], rules: dict[str, str]) -> dict[str, int]:
    """Apply one insertion step to a table of pair counts.

    A pair ``ab`` with rule ``ab -> x`` occurring ``n`` times turns into ``n``
    occurrences of ``ax`` and ``n`` occurrences of ``xb``. A pair without a
    rule receives no insertion and is carried over unchanged, matching what
    ``step`` does on the explicit formula string.

    Args:
        freqs: maps each two-character pair to how often it occurs.
        rules: maps a two-character pair to the element to insert.

    Returns:
        A new ``defaultdict(int)`` with the pair counts after the step;
        ``freqs`` is not modified.
    """
    fs = defaultdict(int)
    for pair, count in freqs.items():
        inserted = rules.get(pair)
        if inserted is None:
            # No rule for this pair: it survives the step as it is.
            fs[pair] += count
            continue
        fs[pair[0] + inserted] += count
        fs[inserted + pair[1]] += count
    return fs


def pair_freqs(formula: str) -> dict[str,int]:
    return Counter(map(''.join, pairwise(formula)))


def pair_freqs_after(data: Input, nsteps: int) -> dict[str,int]:
    """Return the pair-count table of the polymer after ``nsteps`` steps.

    Works on pair counts only, so the polymer string itself is never built.
    ``nsteps == 0`` gives the pair counts of the template.
    """
    template, rules = data

    def advance(table: dict[str,int]) -> dict[str,int]:
        # One insertion round on the pair counts, using the captured rules.
        return pair_step(table, rules)

    initial = pair_freqs(template)
    return nth(iterate(advance, initial), nsteps)


def first_pair_after(data: Input, nsteps: int) -> str:
    """Return the first two elements of the polymer after ``nsteps`` steps.

    Only the leading pair is expanded at each step, since nothing further
    right can influence the front of the polymer.
    """
    template, rules = data

    def grow_head(formula: str) -> str:
        # Expand just the leading pair; the caller re-slices on the next round.
        return step(formula[:2], rules)

    head = nth(iterate(grow_head, template), nsteps)
    return head[:2]


def last_pair_after(data: Input, nsteps: int) -> str:
    """Return the last two elements of the polymer after ``nsteps`` steps.

    Only the trailing pair is expanded at each step, since nothing further
    left can influence the end of the polymer.
    """
    template, rules = data

    def grow_tail(formula: str) -> str:
        # Expand just the trailing pair; the caller re-slices on the next round.
        return step(formula[-2:], rules)

    tail = nth(iterate(grow_tail, template), nsteps)
    return tail[-2:]


def serialise_pairs(ps: dict[str,int], fst: str, lst: str) -> dict[str,int]:
    """Convert pair counts into per-element counts using exact integers.

    Every element of a polymer belongs to two adjacent pairs, except the very
    first and the very last element, which belong to one each. Summing the
    element over all pairs therefore gives twice its true count, short by one
    for each end of the polymer it sits at. Adding those two missing
    half-counts back and halving gives the exact count without any float
    arithmetic, so the result stays exact for arbitrarily large counts.

    Args:
        ps: maps each two-character pair to its number of occurrences.
        fst: the first pair of the polymer that ``ps`` describes.
        lst: the last pair of the polymer that ``ps`` describes.

    Returns:
        A ``defaultdict(int)`` mapping each element to its integer count.
    """
    doubled = defaultdict(int)
    for pair, count in ps.items():
        a, b = pair
        doubled[a] += count
        doubled[b] += count
    # The two polymer ends are the only elements counted once instead of twice.
    if fst:
        doubled[fst[0]] += 1
    if lst:
        doubled[lst[-1]] += 1
    ss = defaultdict(int)
    for element, twice in doubled.items():
        ss[element] = twice // 2
    return ss


def freqs_after(data: Input, nsteps: int) -> dict[str,int]:
    """Return the per-element counts of the polymer after ``nsteps`` steps.

    Combines the pair-count table with the polymer's first and last pair,
    which ``serialise_pairs`` needs to account for the two end elements.
    """
    return serialise_pairs(
        pair_freqs_after(data, nsteps),
        first_pair_after(data, nsteps),
        last_pair_after(data, nsteps),
    )


def solution(data: Input, nsteps: int) -> int:
    """Return most-common minus least-common element count after ``nsteps`` steps.

    Note: this definition replaces the one-argument ``solution(formula)`` from
    the Part 1 cell once this cell has run.
    """
    counts = sorted(freqs_after(data, nsteps).values())
    spread = counts[-1] - counts[0]
    return int(spread)


# Part 2 checks against the example data.
# The pair counts must agree with the explicit formulas from Part 1.
assert pair_freqs('NCNBCHB') == dict(NC=1,CN=1,NB=1,BC=1,CH=1,HB=1)
assert pair_freqs_after(tdata, 1) == dict(NC=1,CN=1,NB=1,BC=1,CH=1,HB=1)
assert pair_freqs_after(tdata, 2) == dict(NB=2,BC=2,CC=1,CN=1,BB=2,CB=2,BH=1,HC=1)
# A formula of n elements has n-1 adjacent pairs, so the pair total plus one is
# the formula length (97 after 5 steps).
assert sum(pair_freqs_after(tdata, 5).values()) + 1 == 97
assert [first_pair_after(tdata, n) for n in range(1,5)] == ['NC','NB','NB','NB']
assert [last_pair_after(tdata, n) for n in range(1,5)] == ['HB','CB','HB','CB']
# Pair counts of the step-2 formula must convert back to its element counts.
assert serialise_pairs(dict(NB=2,BC=2,CC=1,CN=1,BB=2,CB=2,BH=1,HC=1),'NB','CB') == Counter('NBCCNBBBCBHCB')
assert freqs_after(tdata, 10) == dict(B=1749, C=298, H=161, N=865)
# 10 steps must reproduce the Part 1 answer; 40 steps is the Part 2 example answer.
assert solution(tdata, nsteps=10) == 1588
assert solution(tdata, nsteps=40) == 2188189693529

In [6]:
# Part 2 answer for the real input, computed from pair counts with the
# two-argument `solution(data, nsteps)` defined in the Part 2 cell.
n = solution(data, nsteps=40)
print(f'Subtracting the frequency of the least common element from that of the most common after 40 steps: {n}')

Subtracting the frequency of the least common element from that of the most common after 40 steps: 3243771149914
