# --- `Day 14`: Extended Polymerization ---

In [1]:
import aocd
import re
import operator
from collections import Counter, defaultdict, deque
from itertools import combinations
from functools import reduce, lru_cache

def prod(iterable):
    """Return the product of every item in *iterable*; an empty iterable gives 1."""
    result = 1
    for item in iterable:
        result = result * item
    return result

def count(iterable, predicate = bool):
    """Return how many items of *iterable* satisfy *predicate*.

    Args:
        iterable: any iterable (consumed once).
        predicate: callable applied to each item; defaults to ``bool``,
            i.e. count the truthy items.

    Returns:
        The number of items for which ``predicate(item)`` is truthy.
    """
    # Generator expression: no intermediate list is materialized just to be summed.
    return sum(1 for item in iterable if predicate(item))

def first(iterable, default = None):
    """Return the first element of *iterable*, or *default* when it yields nothing."""
    for element in iterable:
        return element
    return default

def lmap(func, *iterables):
    """Eager ``map``: apply *func* element-wise across *iterables* and return a list."""
    return [*map(func, *iterables)]

def ints(s):
    """Return every integer (with an optional leading minus sign) found in *s*, in order."""
    tokens = re.findall(r"-?\d+", s)
    return [int(token) for token in tokens]

def words(s):
    """Return the maximal runs of ASCII letters in *s*, in order of appearance."""
    letter_run = re.compile(r"[a-zA-Z]+")
    return letter_run.findall(s)

def list_diff(x):
    """Return the successive differences ``x[i + 1] - x[i]`` of sequence *x*."""
    shifted = x[1:]
    diffs = []
    for prev, curr in zip(x, shifted):
        diffs.append(curr - prev)
    return diffs

def binary_to_int(lst):
    """Interpret the items of *lst* (each rendered with ``str``) as base-2 digits."""
    digits = map(str, lst)
    return int("".join(digits), 2)

def get_column(lst, index):
    """Return ``row[index]`` for each row of *lst*, i.e. one column of a 2-D list."""
    column = []
    for row in lst:
        column.append(row[index])
    return column

In [2]:
def parse_line(line):
    """Convert one raw input line into its stored form; for this puzzle, plain text."""
    text = str(line)
    return text
    
def parse_input(input):
    """Split raw puzzle text into lines and run each one through ``parse_line``."""
    return [parse_line(line) for line in input.splitlines()]

In [3]:
# Fetch the real puzzle input (year 2021, day 14) through aocd, split it into
# lines, and preview the first five.
final_input = parse_input(aocd.get_data(day=14, year=2021))
print(final_input[:5])

['BSONBHNSSCFPSFOPHKPK', '', 'PF -> P', 'KO -> H', 'CH -> K']


In [5]:
# Small example input used to check solve_1 / solve_2 below: the template on
# the first line, a blank separator, then one "XY -> Z" insertion rule per line.
# The backslash after the opening quotes keeps the template as the first line.
test_input = parse_input('''\
NNCB

CH -> B
HH -> N
CB -> H
NH -> C
HB -> C
HC -> B
HN -> C
NN -> C
BH -> H
NC -> B
NB -> B
BN -> B
BB -> N
BC -> B
CC -> N
CN -> C
''')

print(test_input)

['NNCB', '', 'CH -> B', 'HH -> N', 'CB -> H', 'NH -> C', 'HB -> C', 'HC -> B', 'HN -> C', 'NN -> C', 'BH -> H', 'NC -> B', 'NB -> B', 'BN -> B', 'BB -> N', 'BC -> B', 'CC -> N', 'CN -> C']


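## Helpers

The generic utilities (`prod`, `count`, `first`, `ints`, `words`, …) are defined in the first code cell.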

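## Solution 1

Rather than building the polymer string, count how often each adjacent pair occurs. A rule `XY -> Z` turns every `XY` pair into one `XZ` and one `ZY`. Every element is the first character of exactly one pair, except the last template character, which stays last because insertions only happen between neighbours. Element counts are therefore the first-character totals plus one for that last character.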

In [67]:
def _parse_polymer(input):
    """Split puzzle lines into ``(template, rules)``.

    The first line is the polymer template; every later non-blank line is an
    insertion rule of the form ``"XY -> Z"``, stored as ``rules["XY"] = "Z"``.

    Raises:
        ValueError: if there is no first line or the template is blank.
    """
    lines = iter(input)
    template = next(lines, None)
    if template is None or not template.strip():
        raise ValueError("expected a non-empty polymer template on the first line")
    rules = {}
    for line in lines:
        line = line.strip()
        if not line:
            continue  # blank separator between the template and the rules
        pair, inserted = (part.strip() for part in line.split("->"))
        rules[pair] = inserted
    return template.strip(), rules


def _step_pairs(pair_counts, rules):
    """Apply one round of insertions to a ``{pair: occurrences}`` mapping.

    A pair ``XY`` with rule ``XY -> Z`` becomes one ``XZ`` and one ``ZY``; a
    pair with no rule has nothing inserted, so it is carried over unchanged.
    """
    next_counts = defaultdict(int)
    for pair, occurrences in pair_counts.items():
        inserted = rules.get(pair)
        if inserted is None:
            next_counts[pair] += occurrences
        else:
            next_counts[pair[0] + inserted] += occurrences
            next_counts[inserted + pair[1]] += occurrences
    return next_counts


def solve_1(input, times):
    """Return (most common - least common) element count after *times* insertion steps.

    The polymer itself is never built (with a rule for every pair, its length
    nearly doubles each step). Instead, the number of occurrences of each
    adjacent pair is tracked. Every element is the first character of exactly
    one pair, except the final template character. Insertions only happen
    between existing neighbours, so that character stays last. Element counts
    are therefore the pair first-character totals plus one for it.

    Args:
        input: puzzle lines. The template comes first, then blank lines and
            ``"XY -> Z"`` rules.
        times: number of insertion steps to apply.

    Returns:
        The difference between the largest and smallest element counts.

    Raises:
        ValueError: if *input* is empty or its template line is blank.
    """
    template, rules = _parse_polymer(input)

    pair_counts = defaultdict(int)
    for left, right in zip(template, template[1:]):
        pair_counts[left + right] += 1

    for _ in range(times):
        pair_counts = _step_pairs(pair_counts, rules)

    element_counts = defaultdict(int)
    for pair, occurrences in pair_counts.items():
        element_counts[pair[0]] += occurrences
    # The last template character is never the first character of a pair.
    element_counts[template[-1]] += 1

    return max(element_counts.values()) - min(element_counts.values())

# Check part 1 against the example's known answer, then display it.
example_part_1 = solve_1(test_input, 10)
assert example_part_1 == 1588, example_part_1
example_part_1

1588

In [68]:
# Part 1 answer on the real input: 10 insertion steps.
f"Solution 1: {solve_1(final_input, 10)}"

'Solution 1: 2740'

## Solution 2

Part 2 reuses the same pair-counting solution with 40 insertion steps instead of 10.

In [69]:
def solve_2(input):
    """Part 2: the same pair-counting solution as part 1, run for 40 insertion steps."""
    steps = 40
    return solve_1(input, steps)
    
# Check part 2 against the example's known answer, then display it.
example_part_2 = solve_2(test_input)
assert example_part_2 == 2188189693529, example_part_2
example_part_2

2188189693529

In [70]:
# Part 2 answer on the real input (solve_2 runs 40 insertion steps).
f"Solution 2: {solve_2(final_input)}"

'Solution 2: 2959788056211'