In [119]:
import numpy as np
from collections import defaultdict

In [120]:
def read_input(infile):
    """Parse a race-track grid file.

    Each character is one cell: '.' is open track, 'S' the start cell,
    'E' the end cell; any other character (e.g. walls) is ignored.

    Parameters
    ----------
    infile : str or path-like
        Path to the grid text file.

    Returns
    -------
    track : list of (y, x) tuples
        Every '.' cell in row-major order, followed by the end cell.
        The start cell is not included.
    start, end : (y, x) tuples
        Positions of 'S' and 'E'.
    ny, nx : int
        Grid height (rows up to the last non-blank line) and width
        (length of the longest row).

    Raises
    ------
    ValueError
        If the grid contains no 'S' or no 'E' cell.
    """
    with open(infile, 'r') as inf:
        rows = [line.strip() for line in inf]
    # Ignore trailing blank lines (e.g. an extra newline at EOF); otherwise
    # they inflate the height and, being last, make the width 0.
    while rows and not rows[-1]:
        rows.pop()

    track = []
    start, end = None, None
    for y, row in enumerate(rows):
        for x, v in enumerate(row):
            if v == 'S':
                start = (y, x)
            elif v == 'E':
                end = (y, x)
            elif v == '.':
                track.append((y, x))
    # Without both endpoints the path walkers would never terminate.
    if start is None or end is None:
        raise ValueError(f"{infile}: grid must contain both 'S' and 'E'")

    track.append(end)
    ny = len(rows)
    nx = max((len(row) for row in rows), default=0)
    return track, start, end, ny, nx
                    

In [151]:
def find_path(track, start, end, ny, nx):
    """Walk the single-lane track from start to end, recording step counts.

    From each cell the walk moves to the first unvisited neighbour that is
    in ``track`` (neighbours tried right, down, left, up).

    Parameters
    ----------
    track : iterable of (y, x)
        Open cells the walk may enter (the start cell need not be included).
    start, end : (y, x) tuples
        First and last cell of the walk.
    ny, nx : int
        Grid height and width.

    Returns
    -------
    numpy.ndarray of int, shape (ny, nx)
        Number of steps from start for every visited cell; every other
        cell holds 10000. A track of 10000+ steps produces distances that
        collide with this fill value.

    Raises
    ------
    ValueError
        If the walk gets stuck (no unvisited track neighbour) before
        reaching end.
    """
    unvisited = int(1e4)
    path = np.full((ny, nx), unvisited, dtype=int)
    # Track visits separately: using the fill value as the visited test
    # breaks once a real distance equals it (tracks of 10000+ cells).
    visited = np.zeros((ny, nx), dtype=bool)
    open_cells = set(track)  # O(1) membership instead of a list scan

    y, x = start
    path[y, x] = 0
    visited[y, x] = True
    d = 1
    while (y, x) != end:
        for dy, dx in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
            cy, cx = y + dy, x + dx
            # Membership is checked first so out-of-grid coords never index.
            if (cy, cx) in open_cells and not visited[cy, cx]:
                path[cy, cx] = d
                visited[cy, cx] = True
                y, x = cy, cx
                d += 1
                break
        else:
            raise ValueError(f'dead end at {(y, x)} before reaching {end}')
    return path

def find_path_alt(track, start, end, ny, nx):
    """Walk the single-lane track from start to end, returning the cell order.

    From each cell the walk moves to the first not-yet-visited neighbour
    that is in ``track`` (neighbours tried right, down, left, up).

    Parameters
    ----------
    track : iterable of (y, x)
        Open cells the walk may enter (the start cell need not be included).
    start, end : (y, x) tuples
        First and last cell of the walk.
    ny, nx : int
        Unused; kept so the signature matches ``find_path``.

    Returns
    -------
    list of (y, x)
        Visited cells in walk order, from start to end inclusive; the list
        index of a cell is its step distance from start.

    Raises
    ------
    ValueError
        If the walk gets stuck before reaching end.
    """
    open_cells = set(track)  # O(1) membership instead of a list scan
    y, x = start
    path = [(y, x)]
    seen = {(y, x)}  # mirrors `path` for O(1) "already visited" checks
    while (y, x) != end:
        for dy, dx in [(0, 1), (1, 0), (0, -1), (-1, 0)]:
            step = (y + dy, x + dx)
            if step in open_cells and step not in seen:
                path.append(step)
                seen.add(step)
                y, x = step
                break
        else:
            raise ValueError(f'dead end at {(y, x)} before reaching {end}')
    return path

def find_cheats(track, start, end, ny, nx, clen=100):
    """Count straight 2-step cheats that save at least ``clen`` steps.

    A cheat jumps from a track cell to the track cell two positions away
    in the same row or column (through whatever lies between). It costs 2
    steps and saves ``|d_b - d_a| - 2`` steps, where d is the distance from
    start along the normal route (see ``find_path``). Each unordered pair
    is counted once.

    Parameters
    ----------
    track, start, end, ny, nx :
        As returned by ``read_input``; forwarded to ``find_path``.
    clen : int, optional
        Minimum saving for a cheat to be counted (default 100, the value
        previously hard-coded).

    Returns
    -------
    int
        Number of qualifying cheats.
    """
    path = find_path(track, start, end, ny, nx)

    # Explicit on-track mask: non-track cells hold a fill value in `path`,
    # and a track/wall pair would otherwise look like a large saving.
    on_track = np.zeros((ny, nx), dtype=bool)
    for y, x in [start, *track]:
        on_track[y, x] = True

    # Vertical pairs (row r, row r + 2) and horizontal pairs (col c, col c + 2).
    save_y = np.abs(path[:-2, :] - path[2:, :]) - 2
    pair_y = on_track[:-2, :] & on_track[2:, :]
    save_x = np.abs(path[:, :-2] - path[:, 2:]) - 2
    pair_x = on_track[:, :-2] & on_track[:, 2:]

    n = np.count_nonzero(pair_y & (save_y >= clen))
    n += np.count_nonzero(pair_x & (save_x >= clen))
    return int(n)

def find_all_cheats(clen, track, start, end, ny, nx, min_save=100):
    """Count cheats of Manhattan length <= ``clen`` saving >= ``min_save`` steps.

    A cheat jumps from route cell i to a later route cell k (k > i) ignoring
    walls. It costs the Manhattan distance between the two cells and skips
    k - i steps of the normal route, so it saves (k - i) - distance steps.

    Parameters
    ----------
    clen : int
        Maximum cheat length (Manhattan distance).
    track, start, end, ny, nx :
        As returned by ``read_input``; forwarded to ``find_path_alt``.
    min_save : int, optional
        Minimum saving for a cheat to be counted (default 100, the value
        previously hard-coded).

    Returns
    -------
    int
        Number of distinct (from, to) cell pairs forming a qualifying cheat.
    """
    cells = np.asarray(find_path_alt(track, start, end, ny, nx)).reshape(-1, 2)
    n = len(cells)
    # steps[m] = route steps from cell i to cell i + 1 + m.
    steps = np.arange(1, n)

    count = 0
    for i in range(n - 1):
        later = cells[i + 1:]
        dist = np.abs(later - cells[i]).sum(axis=1)
        saving = steps[:n - 1 - i] - dist
        count += int(np.count_nonzero((dist <= clen) & (saving >= min_save)))
    return count

In [152]:
# Each puzzle part: (banner, maximum cheat length, known answer).
_PARTS = [
    ('*******\nPuzzle1\n*******\n', 2, 1406),
    ('\n*******\nPuzzle2\n*******\n', 20, 1006101),
]

for banner, max_cheat, expected in _PARTS:
    print(banner)

    print('Puzzle case\n-----------\n')

    res = find_all_cheats(max_cheat, *read_input('input20.txt'))

    print(f'Number of cheats is {res}')

    assert res == expected


*******
Puzzle1
*******

Puzzle case
-----------

Number of cheats is 1406

*******
Puzzle2
*******

Puzzle case
-----------

Number of cheats is 1006101
