In [1]:
from itertools import combinations, pairwise, permutations,product
from typing import NamedTuple
import re
from collections import Counter

class Point(NamedTuple):
    """An immutable ``(row, col)`` grid coordinate supporting vector arithmetic."""

    row: int
    col: int

    def __add__(self, other):
        """Component-wise sum ``self + other``."""
        row_sum = self.row + other.row
        col_sum = self.col + other.col
        return Point(row_sum, col_sum)

    def __sub__(self, other):
        """Component-wise difference ``self - other``."""
        return Point(row=self.row - other.row, col=self.col - other.col)

    def vectors(self, other):
        """Return the non-zero axis-aligned unit steps leading from ``self`` toward ``other``.

        The set holds at most one vertical step (``Point(±1, 0)``) and at most
        one horizontal step (``Point(0, ±1)``); it is empty when the points coincide.
        """
        step = (other - self).clamp()
        candidates = (Point(0, step.col), Point(step.row, 0))
        return {candidate for candidate in candidates if candidate != Point(0, 0)}

    def clamp(self, min_v=-1, max_v=1):
        """Clamp each coordinate into ``[min_v, max_v]`` (the unit range by default).

        Assumes ``min_v <= max_v``; otherwise every coordinate becomes ``min_v``.
        """
        clamped_row = max(min_v, min(self.row, max_v))
        clamped_col = max(min_v, min(self.col, max_v))
        return Point(clamped_row, clamped_col)
                  
class Numpad():
    """A keypad laid out on a grid that can plan the arrow moves between two keys.

    ``pad`` is a list of strings, one per row; each character is a key and
    ``'#'`` marks the empty gap the robot arm must never pass over.
    ``controller=True`` selects the move-ordering rules for the directional
    pad; otherwise the numeric-keypad rules are used.  Both rule sets hard-code
    the gap position of the corresponding layout used in this notebook
    (numeric gap at ``Point(3, 0)``, directional gap at ``Point(0, 0)``).
    """

    # Unit step -> arrow key that produces it.
    directions = {
        Point(1, 0): 'v',
        Point(-1, 0): '^',
        Point(0, 1): '>',
        Point(0, -1): '<',
    }

    # Directional pad: position of `<`, the key directly below the gap.
    _LEFT_KEY = Point(1, 0)

    # Numeric keypad: the left column (7, 4, 1) and the bottom row (0, A).
    # Moving between these two groups can pass over the gap at Point(3, 0).
    _NUMPAD_EDGE_KEYS = frozenset({
        Point(0, 0), Point(1, 0), Point(2, 0), Point(3, 1), Point(3, 2),
    })

    def __init__(self, pad, controller=False):
        self.controller = controller
        self.numpad = pad
        # Map each key character (including the '#' gap) to its grid position.
        self.key_lookup = {
            char: Point(row, col)
            for row, line in enumerate(pad)
            for col, char in enumerate(line)
        }

    def best(self, start, end):
        """Return the arrow keys (without the final ``'A'``) that move from key ``start`` to key ``end``.

        All vertical moves are grouped together and all horizontal moves are
        grouped together, so the result is either ``ns + ew`` or ``ew + ns``.
        Which one is chosen depends on the pad type, the direction of travel
        and whether the other order would pass over the gap.  Returns ``''``
        when ``start == end``.
        """
        start = self.key_lookup[start]
        end = self.key_lookup[end]

        delta = end - start
        # Direction of travel on each axis: -1, 0 or +1.  (Named `heading`
        # rather than `dir` so the builtin is not shadowed.)
        heading = delta.clamp()

        # Group all vertical and all horizontal moves, so e.g. `<<^` is used
        # rather than `<^<`, which needs more back-and-forth on the next pad.
        # A zero-length axis yields '' regardless of the chosen symbol.
        ns = ('v' if heading.row == 1 else '^') * abs(delta.row)
        ew = ('>' if heading.col == 1 else '<') * abs(delta.col)

        if self.controller:
            if end == self._LEFT_KEY:
                # `<` sits directly below the gap, so it can only be entered
                # from the side: move vertically first, then horizontally.
                return ns + ew
            if start == self._LEFT_KEY:
                # Likewise `<` can only be left sideways: horizontal first.
                return ew + ns
            if heading.col == -1:
                # Found by guess-and-check: when travelling west, put the `<`
                # moves first (`<^` rather than `^<`); otherwise vertical first.
                # Per the original author, every ordering gives the same result
                # for a few layers, but this one gives the lowest for many.
                # (For a purely horizontal move the order is irrelevant.)
                return ew + ns
            return ns + ew

        if start in self._NUMPAD_EDGE_KEYS and end in self._NUMPAD_EDGE_KEYS:
            # Travelling between the left column and the bottom row.
            if heading == Point(1, 1):
                # Heading south-east (down and right): going down first would
                # land on the gap in the bottom-left corner, so go right first.
                # NOTE: the original author observed this did not change their answer.
                return ew + ns
            # Heading north-west (or along one axis): going up first avoids the gap.
            return ns + ew
        if heading == Point(1, 1):
            return ns + ew
        return ew + ns


def iterate(n, s, debug=False, control_only=False):
    if control_only:
        controls = s
    else:
        controls =  'A'.join([numpad.best(a, b) for a, b in pairwise(s)])+'A'
    if debug:
        print(controls, len(controls))
    for i in range(n):
        controls = 'A'.join([controlpad.best(a, b) for a, b in pairwise('A'+controls)])+'A'
        if debug:
            print(controls, len(controls))
    return controls

def get_nums(s):
    """Return the first run of decimal digits in ``s`` as an int (e.g. ``'029A'`` -> 29).

    Raises:
        ValueError: if ``s`` contains no digits.
    """
    # Raw string: '\d' in a normal string literal is an invalid escape sequence.
    match = re.search(r'\d+', s)
    if match is None:
        raise ValueError(f'no digits in {s!r}')
    return int(match.group(0))


### Tools:

Make a numpad controller and a directional controller. Then make a lookup dict that maps every ordered pair of directional-pad keys to the moves between them, which is helpful in part two.


In [2]:
# Numeric keypad; '#' marks the gap the robot arm must avoid.
nums = [
    '789',
    '456',
    '123',
    '#0A',
]
numpad = Numpad(nums)

# Directional keypad, again with '#' marking the gap.
controls = [
    '#^A',
    '<v>',
]
controlpad = Numpad(controls, True)

# Directional-pad moves for every ordered pair of distinct keys, e.g. 'A<' -> 'v<<'.
chars = 'v^<>A'
lookup = {a + b: controlpad.best(a, b) for a, b in permutations(chars, r=2)}


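The primary insight is that a tempting intuition is wrong. The intuition says the order of moves doesn't matter because every path with the same Manhattan distance takes the same number of steps.

This is a consequence of the `<` key being farther from `A` than the other keys on the directional pad. Every move sequence ends with a press of `A`. So grouping identical moves (e.g. `<<^` rather than `<^<`) and choosing the order of the groups carefully changes how much travel the next robot up the chain needs.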


In [3]:
with open('input_files/21.txt') as f:
    raw = f.read().splitlines()

# Expand each code once through 2 directional robots, then reuse the lengths
# for both the per-code listing and the complexity sum.
part1_lengths = [len(iterate(2, 'A' + s)) for s in raw]
print(part1_lengths)
print(sum(length * get_nums(s) for length, s in zip(part1_lengths, raw)))


[72, 68, 72, 70, 70]
125742


In [4]:
from itertools import groupby

def build_dict(s):
    a_count = 0
    counts = Counter()

    for (a, la), (b, lb) in pairwise((k, len(list(g))) for k, g in groupby(s)):
        if la > 1:
            a_count += la - 1
        counts[a+b] += 1
    return counts, a_count


def next_dict(d, a_count):
    counts = Counter()
    for k, count in d.items():   
        for a, b in pairwise(k):
            if a+b == 'AA':
                a_count += 1
                continue
            next_k = lookup[a+b]
            for a, b in pairwise('A' + next_k +'A'):
                if a == b:
                    a_count += count
                    continue
                counts[a+b] += count
    return counts, a_count

def count_iterate(s, i):
    """Return the final keypress count for ``s`` after ``i`` directional layers.

    Computes the same length as ``len(iterate(i, s))`` but tracks counts of
    key pairs instead of building the string, so large ``i`` stays cheap.
    ``s`` is the code prefixed with the arm's starting key (``'A' + code``).
    """
    presses = 'A'.join(numpad.best(a, b) for a, b in pairwise(s)) + 'A'
    pair_counts, repeats = build_dict('A' + presses)
    for _ in range(i):
        pair_counts, repeats = next_dict(pair_counts, repeats)
    return sum(pair_counts.values()) + repeats


In [5]:
with open('input_files/21.txt') as f:
    raw = f.read().splitlines()

n = 2
# Cross-check the pair-counting version against the explicit string expansion;
# they must agree before the fast version is trusted for large n.
slow_lengths = [len(iterate(n, 'A' + code)) for code in raw]
fast_lengths = [count_iterate('A' + code, n) for code in raw]
print(slow_lengths)
print(fast_lengths)
assert fast_lengths == slow_lengths, 'count_iterate disagrees with iterate'


[72, 68, 72, 70, 70]
[72, 68, 72, 70, 70]


In [6]:
n = 25
with open('input_files/21.txt') as f:
    raw = f.read().splitlines()

# Part 2: 25 directional robots; only the pair-counting version is fast enough.
print(sum(
    count_iterate('A' + code, n) * get_nums(code)
    for code in raw
))

157055032722640
