In [1]:
from typing import NamedTuple
from collections import deque, Counter
from itertools import combinations

### Part One

The plan is to perform a breadth-first search from the end and note the distance of every reachable point on the map. Points that are not reachable from the end, or are blocks, are simply left out of the distance map, which keeps things simple. Also save a set of the block locations to make it easy to iterate over them.

In part one, for each block, look at its neighboring track points and their distances to the end. If the difference between two of these distances is greater than two, you can save time by removing the block: stepping through it costs two moves, so the saving is the difference minus two. This should run in linear time in the number of blocks (not counting the BFS).

### Part Two
This worked, but is slow.
For each path point A, compare it with all path points B that are within a Manhattan distance of 20. If A's distance to the end is longer than walking the Manhattan distance plus B's distance, then it's a better path and saves the difference. It could be sped up a bit by caching the comparison, since we compare A with B and later compare B with A.

It seems like the input allows for some sloppiness. There are no edge cases where there is more than one path between start and end going through the same point. IOW - no branching. 🤷 I guess we should take it!



In [2]:
class Point(NamedTuple):
    '''A (row, col) grid coordinate.

    ``+`` is overridden to add component-wise (instead of the tuple default of
    concatenation), so a Point can be shifted by a direction Point.
    '''
    row: int
    col: int

    def __add__(self, other):
        new_row = self.row + other.row
        new_col = self.col + other.col
        return Point(new_row, new_col)

    def m_dist(self, other):
        '''Manhattan (taxicab) distance between this point and ``other``.'''
        row_gap = abs(self.row - other.row)
        col_gap = abs(self.col - other.col)
        return row_gap + col_gap
    

class Map:
    '''Race-track grid parsed from a list of strings.

    '#' cells are walls ("blocks"), 'S' marks the start and 'E' the end; any
    other character is treated as open track. After ``bfs`` has run,
    ``self.distances`` maps every reachable track Point to its step count
    from the BFS origin.
    '''
    directions = [Point(0, 1), Point(0, -1), Point(-1, 0), Point(1, 0)]

    def __init__(self, s):
        '''Parse grid ``s``; width is taken from the first row (assumed rectangular).'''
        self.map = s

        self.start = None
        self.end = None
        self.w = len(s[0])
        self.h = len(s)
        self.blocks = set()
        # Populated by bfs(): Point -> distance from the BFS origin.
        self.distances = {}

        for row, line in enumerate(s):
            for col, c in enumerate(line):
                if c == '#':
                    self.blocks.add(Point(row, col))
                if c == 'S':
                    self.start = Point(row, col)
                if c == 'E':
                    self.end = Point(row, col)

    def neighbors(self, point):
        '''Yield the in-bounds, non-wall points orthogonally adjacent to ``point``.'''
        for d in Map.directions:
            p = point + d
            if 0 <= p.row < self.h and 0 <= p.col < self.w and self.map[p.row][p.col] != '#':
                yield p

    def bfs(self, start, stop):
        '''Breadth-first search over open track starting at ``start``.

        Records the distance of every reachable point in ``self.distances``
        (replacing the results of any previous call) and returns the distance
        to ``stop``, or None if ``stop`` is unreachable. The search does not
        stop early at ``stop`` because cheat()/cheat2() need distances for
        the whole track.
        '''
        # Start fresh so a second call cannot leave stale distances behind.
        self.distances = {}
        seen = {start}
        queue = deque([(start, 0)])
        dist = None
        while queue:
            loc, distance = queue.popleft()

            # Keep track of distances of every findable point
            self.distances[loc] = distance

            if loc == stop:
                dist = distance

            # neighbors() already excludes walls, so only `seen` needs checking.
            for n in self.neighbors(loc):
                if n in seen:
                    continue
                queue.append((n, distance + 1))
                seen.add(n)

        return dist

    def cheat(self, block):
        '''Return the steps saved by removing the wall at ``block``, or 0.

        Going a -> block -> b costs 2 steps, so a pair (a, b) of the wall's
        on-track neighbors saves ``|d(a) - d(b)| - 2``. Only savings of at
        least 2 are reported. Requires bfs() to have been run.
        '''
        on_track = [p for p in self.neighbors(block) if p in self.distances]
        for a, b in combinations(on_track, r=2):
            savings = abs(self.distances[a] - self.distances[b]) - 2
            if savings >= 2:
                # NOTE(review): this is the first qualifying pair in
                # `directions` order, not necessarily the largest; a wall
                # touching 3+ track cells could hide a bigger or second cheat.
                return savings
        return 0

    def points_at_distance(self, point, d):
        '''Yield every track point within Manhattan distance ``d`` of ``point``.

        Only points present in ``self.distances`` (reached by bfs) are
        yielded, ``point`` itself is excluded, and the diamond is clipped to
        the grid bounds.
        '''
        for row in range(max(0, point.row - d), min(self.h, point.row + d + 1)):
            # Remaining horizontal budget once the vertical offset is spent.
            offset = d - abs(row - point.row)
            for col in range(max(0, point.col - offset), min(self.w, point.col + offset + 1)):
                dest = Point(row, col)
                if dest != point and dest in self.distances:
                    yield dest

    def cheat2(self, max_cheat=20):
        '''Count cheats of up to ``max_cheat`` steps, keyed by steps saved.

        For every track point ``source`` and every track point ``dest`` within
        Manhattan distance ``max_cheat``, jumping straight to ``dest`` costs
        the Manhattan distance; if that plus ``dest``'s distance is smaller
        than ``source``'s distance, the difference is the saving.

        Assumes bfs() was run from the goal (as the notebook does with
        ``m.bfs(m.end, m.start)``), so ``self.distances`` is distance-to-goal.

        Returns a Counter mapping saving -> number of (source, dest) cheats.
        '''
        counts = Counter()
        for source, dist in self.distances.items():
            for dest in self.points_at_distance(source, max_cheat):
                via_cheat = self.distances[dest] + source.m_dist(dest)
                if dist > via_cheat:
                    counts[dist - via_cheat] += 1

        return counts



In [3]:
# Small sample track used to sanity-check part one before the real input.
s='''###############
#...#...#.....#
#.#.#.#.#.###.#
#S#...#.#.#...#
#######.#.#.###
#######.#.#...#
#######.#.###.#
###..E#...#...#
###.#######.###
#...###...#...#
#.#####.#.###.#
#.#...#.#.#...#
#.#.#.#.#.#.###
#...#...#...###
###############'''.split('\n')


m = Map(s)
m.bfs(m.end, m.start)

# Counter comes from the imports in the first cell; tally savings per wall.
counts = Counter()
for b in m.blocks:
    savings = m.cheat(b)
    if savings:
        counts[savings] += 1

# Regression check: this is the tally recorded in this cell's output.
assert dict(counts) == {2: 14, 4: 14, 6: 2, 8: 4, 10: 2, 12: 3,
                        20: 1, 36: 1, 38: 1, 40: 1, 64: 1}

for d, count in sorted(counts.items()):
    print(f"{count} that save {d}")


14 that save 2
14 that save 4
2 that save 6
4 that save 8
2 that save 10
3 that save 12
1 that save 20
1 that save 36
1 that save 38
1 that save 40
1 that save 64


In [4]:
counts2 = m.cheat2()

# Report only the cheats that save at least 50 steps, smallest saving first.
for d, count in sorted(counts2.items()):
    if d < 50:
        continue
    print(f"{count} that save {d}")

32 that save 50
31 that save 52
29 that save 54
39 that save 56
25 that save 58
23 that save 60
20 that save 62
19 that save 64
12 that save 66
14 that save 68
12 that save 70
22 that save 72
4 that save 74
3 that save 76


In [5]:
with open('input_files/20.txt') as f:
    raw = f.read().splitlines()

m = Map(raw)
m.bfs(m.end, m.start)

# cheat() yields 0 for walls that give no shortcut, so a plain threshold
# comparison picks out the cheats saving at least 100 steps.
count = sum(1 for b in m.blocks if m.cheat(b) >= 100)

print("part one:", count)


part one: 1426


In [6]:
# Tally every cheat of up to 20 steps on the real input, keyed by steps saved.
counts2 = Map.cheat2(m)

In [7]:
# Part two answer: how many cheats save at least 100 steps.
total = sum(count for d, count in counts2.items() if d >= 100)
print("part two:", total)

part two: 1000697
