In [2]:
with open('./data/input_20.txt') as fh:
    file_input = fh.read()

In [14]:
import networkx as nx
from networkx.algorithms.shortest_paths.generic import has_path
from networkx import NetworkXNoPath
from functools import lru_cache
from collections import OrderedDict, defaultdict
from itertools import chain
import heapq

In [4]:
test = """         A           
         A           
  #######.#########  
  #######.........#  
  #######.#######.#  
  #######.#######.#  
  #######.#######.#  
  #####  B    ###.#  
BC...##  C    ###.#  
  ##.##       ###.#  
  ##...DE  F  ###.#  
  #####    G  ###.#  
  #########.#####.#  
DE..#######...###.#  
  #.#########.###.#  
FG..#########.....#  
  ###########.#####  
             Z       
             Z       """

In [5]:
test2 = """                   A               
                   A               
  #################.#############  
  #.#...#...................#.#.#  
  #.#.#.###.###.###.#########.#.#  
  #.#.#.......#...#.....#.#.#...#  
  #.#########.###.#####.#.#.###.#  
  #.............#.#.....#.......#  
  ###.###########.###.#####.#.#.#  
  #.....#        A   C    #.#.#.#  
  #######        S   P    #####.#  
  #.#...#                 #......VT
  #.#.#.#                 #.#####  
  #...#.#               YN....#.#  
  #.###.#                 #####.#  
DI....#.#                 #.....#  
  #####.#                 #.###.#  
ZZ......#               QG....#..AS
  ###.###                 #######  
JO..#.#.#                 #.....#  
  #.#.#.#                 ###.#.#  
  #...#..DI             BU....#..LF
  #####.#                 #.#####  
YN......#               VT..#....QG
  #.###.#                 #.###.#  
  #.#...#                 #.....#  
  ###.###    J L     J    #.#.###  
  #.....#    O F     P    #.#...#  
  #.###.#####.#.#####.#####.###.#  
  #...#.#.#...#.....#.....#.#...#  
  #.#####.###.###.#.#.#########.#  
  #...#.#.....#...#.#.#.#.....#.#  
  #.###.#####.###.###.#.#.#######  
  #.#.........#...#.............#  
  #########.###.###.#############  
           B   J   C               
           U   P   P               """

In [6]:
class L(list):
    def __new__(self, *args, **kwargs):
        return super(L, self).__new__(self, args, kwargs)

    def __init__(self, *args, **kwargs):
        if len(args) == 1 and hasattr(args[0], '__iter__'):
            list.__init__(self, args[0])
        else:
            list.__init__(self, args)
        self.__dict__.update(kwargs)

    def __call__(self, **kwargs):
        self.__dict__.update(kwargs)
        return self

In [7]:
def parse(txt):
    lines = L(txt.splitlines())
    lines.x = len(lines[0])
    lines.y = len(lines)
    return lines

def neighbors_short(pos, field):
    for x, y in ((pos[0]+1, pos[1]), (pos[0], pos[1]+1)):
        if x < field.x and x >= 0 and y < field.y and y >=0:
            yield (x, y), field[y][x] 

def neighbors_all(pos, field):
    for x, y in ((pos[0]+1, pos[1]), (pos[0], pos[1]+1), (pos[0]-1, pos[1]), (pos[0], pos[1]-1)):
        if x < field.x and x >= 0 and y < field.y and y >=0:
            yield (x, y), field[y][x]

wall = ord('#')
free = ord('.')
valid_targets = (free, wall)
    
def build_graph(field):
    G = nx.Graph()
    portals = defaultdict(list)
    for y, line in enumerate(field):
        for x, char in enumerate(line):
            o = ord(char)
            if o == free:
                G.add_node((x, y))                
                for npos, vngb in neighbors_short((x, y), field):
                    if ord(vngb) in valid_targets:
                        G.add_edge((x, y), npos)
            if o >= ord('A') and o <= ord('Z'):
                for npos, vngb in neighbors_short((x, y), field):
                    on = ord(vngb)
                    if on >= ord('A') and on <= ord('Z'):
                        for npos2, vngb2 in chain(neighbors_all((x, y), field), neighbors_all(npos, field)):
                            if vngb2 == '.':
                                portals[char + vngb].append(npos2)
                                break
    for p in portals:
        if len(portals[p]) == 2:
#             print(p)
            G.add_edge(portals[p][0], portals[p][1])
    return G, portals['AA'][0], portals['ZZ'][0]

In [8]:
field = parse(test)
G, start, end = build_graph(field)
path = nx.shortest_path(G, start, end)
print(len(path)-1)

23


In [9]:
field = parse(test2)
G, start, end = build_graph(field)
path = nx.shortest_path(G, start, end)
print(len(path)-1)

58


In [10]:
# part 1
field = parse(file_input)
G, start, end = build_graph(field)
path = nx.shortest_path(G, start, end)
print(len(path)-1)

626


In [11]:
# Part 2

In [12]:
test3 = """             Z L X W       C                 
             Z P Q B       K                 
  ###########.#.#.#.#######.###############  
  #...#.......#.#.......#.#.......#.#.#...#  
  ###.#.#.#.#.#.#.#.###.#.#.#######.#.#.###  
  #.#...#.#.#...#.#.#...#...#...#.#.......#  
  #.###.#######.###.###.#.###.###.#.#######  
  #...#.......#.#...#...#.............#...#  
  #.#########.#######.#.#######.#######.###  
  #...#.#    F       R I       Z    #.#.#.#  
  #.###.#    D       E C       H    #.#.#.#  
  #.#...#                           #...#.#  
  #.###.#                           #.###.#  
  #.#....OA                       WB..#.#..ZH
  #.###.#                           #.#.#.#  
CJ......#                           #.....#  
  #######                           #######  
  #.#....CK                         #......IC
  #.###.#                           #.###.#  
  #.....#                           #...#.#  
  ###.###                           #.#.#.#  
XF....#.#                         RF..#.#.#  
  #####.#                           #######  
  #......CJ                       NM..#...#  
  ###.#.#                           #.###.#  
RE....#.#                           #......RF
  ###.###        X   X       L      #.#.#.#  
  #.....#        F   Q       P      #.#.#.#  
  ###.###########.###.#######.#########.###  
  #.....#...#.....#.......#...#.....#.#...#  
  #####.#.###.#######.#######.###.###.#.#.#  
  #.......#.......#.#.#.#.#...#...#...#.#.#  
  #####.###.#####.#.#.#.#.###.###.#.###.###  
  #.......#.....#.#...#...............#...#  
  #############.#.#.###.###################  
               A O F   N                     
               A A D   M                     """

In [126]:
def build_graph_p2(field):
    G = nx.Graph()
    portals = defaultdict(list)
    for y, line in enumerate(field):
        for x, char in enumerate(line):
            o = ord(char)
            if o == free:
                G.add_node((x, y))                
                for npos, vngb in neighbors_short((x, y), field):
                    if ord(vngb) in valid_targets:
                        G.add_edge((x, y), npos)
            if o >= ord('A') and o <= ord('Z'):
                for npos, vngb in neighbors_short((x, y), field):
                    on = ord(vngb)
                    if on >= ord('A') and on <= ord('Z'):
                        for npos2, vngb2 in chain(neighbors_all((x, y), field), neighbors_all(npos, field)):
                            if vngb2 == '.':
                                portals[char + vngb].append(npos2)
                                break
    inner = {}
    outer = {}
    for p in portals:
        x = portals[p]
        if len(x) == 2:
            if 2 in x[0] or (x[0][0] == (field.x - 3)) or (x[0][1] == (field.y - 3)):
                inner[x[1]] = x[0]
                outer[x[0]] = x[1]
            else:
                inner[x[0]] = x[1]
                outer[x[1]] = x[0]
    return G, portals['AA'][0], portals['ZZ'][0], inner, outer, portals


@lru_cache(maxsize=None)
def ngbrs(pos):
    res = list(G[pos])
#     print(res, pos)
    if pos in inner:
        return (res, inner[pos], None)
    if pos in outer:
        return (res, None, outer[pos])
    return (res, None, None)

@lru_cache(maxsize=None)
def reachable_portals(pos):
    _in = []
    _out = []
    for i in inner:
        try:
            l = nx.shortest_path_length(G, pos, i)
            _in.append((i, l+1))
        except NetworkXNoPath:
            pass
    for o in outer:
        try:
            l = nx.shortest_path_length(G, pos, o)
            _out.append((o, l+1))
        except NetworkXNoPath:
            pass
    return (_in, _out)

def BFS2(G, start, end):
    start = start + (0,)
    end = end
    queue = [(start, 0, [])]
    visited = {}
    while queue:
        pos, length, path = queue.pop(0)
        x, y, z = pos
        if z == 0:
            try:
                l2goal = nx.shortest_path_length(G, (x, y), end)
                return length + l2goal, path
            except NetworkXNoPath:
                pass
        if pos in visited:
#             if length > visited[pos]:
#                 print("!", pos, length, visited[pos])
            continue
        visited[pos] = length
#         print(x, y)

        _in, _out = reachable_portals((x, y))
        if z > 0:
            for o in _out:
                ok, ol = o
                new_pos = outer[ok] + (z-1,)
                if new_pos not in visited:
#                     queue.insert(0, (new_pos, length+ol, path+[(ok, ol)]))
                    queue.append((new_pos, length+ol, path+[(ok, ol)]))
        for i in _in:
            ik, il = i
            new_pos = inner[ik] + (z+1,)
            if new_pos not in visited:
                queue.append((new_pos, length+il, path+[(ik, il)]))


def BFS(G, start, end):
    start = start + (0,)
    end = end + (0,)
    queue = [(start, 0)]
    visited = {}
    while queue:
        pos, length = queue.pop(0)
        x, y, z = pos
        if pos in visited:
#             if length < visited[pos]:
#                 print("!", pos, length, visited[pos])
            continue
        visited[pos] = length
#         print(x, y)
        ns, jmp_in, jmp_out = ngbrs((x, y))
        for n in ns:
            nf = n + (z,)
            if nf == end:
                print('reached end: ', length+1)
                return length+1
            if nf not in visited:
                queue.append((nf, length+1))
        if jmp_out and z > 0:
            queue.append((jmp_out + (z-1,), length+1))
        if jmp_in:
            queue.append((jmp_in + (z+1,), length+1))

In [123]:
field = parse(test3)
G, start, end, inner, outer, portals = build_graph_p2(field)

In [124]:
BFS(G, start, end)

reached end:  396


396

In [125]:
BFS2(G, start, end)[0]

396

In [128]:
field = parse(file_input)
G, start, end, inner, outer, portals = build_graph_p2(field)
BFS2(G, start, end)[0]

6912

In [96]:
portals

defaultdict(list,
            {'ZZ': [(13, 2)],
             'LP': [(15, 2), (29, 28)],
             'XQ': [(17, 2), (21, 28)],
             'WB': [(19, 2), (36, 13)],
             'CK': [(27, 2), (8, 17)],
             'FD': [(13, 8), (19, 34)],
             'RE': [(21, 8), (2, 25)],
             'IC': [(23, 8), (42, 17)],
             'ZH': [(31, 8), (42, 13)],
             'OA': [(8, 13), (17, 34)],
             'CJ': [(2, 15), (8, 23)],
             'XF': [(2, 21), (17, 28)],
             'RF': [(36, 21), (42, 25)],
             'NM': [(36, 23), (23, 34)],
             'AA': [(15, 34)]})

In [58]:
astar(G, start, end)

(0, 396, (13, 2, 0))

In [81]:
start

(15, 34)

In [68]:
reachable_portals(start)

([((17, 28), 16)], [((17, 34), 16)])

In [62]:
for i in inner:
    print(i)

(29, 28)
(21, 28)
(36, 13)
(8, 17)
(13, 8)
(21, 8)
(23, 8)
(31, 8)
(8, 13)
(8, 23)
(17, 28)
(36, 21)
(36, 23)


In [59]:
field = parse(file_input)
G, start, end, inner, outer, portals = build_graph_p2(field)
# BFS(G, start, end)
astar(G, start, end)

MemoryError: 

In [60]:
print(1)

1


In [87]:
def H(pos, end):
    x, y, z = pos
    ex, ey, ez = end
    res = abs(x - ex) + abs(y - ey)
    res += abs(z - ez) * 15
    return res
            
def astar(G, start, end):
    start = start + (0,)
    end = end + (0,)
    queue = []
    heapq.heappush(queue, (H(start, end), 0, start))
    queue_set = set([start])
    Gscore = {}
    Fscore = {}
#     visited = {}
    while queue:
        fscore, gscore, pos  = heapq.heappop(queue)
        if pos == end:
            return fscore, gscore, pos
        queue_set.remove(pos)
        x, y, z = pos
#         if pos in visited:
#             if length < visited[pos]:
#                 print("!", pos, length, visited[pos])
#             continue
#         visited[pos] = length
#         print(x, y)
        ns, jmp_in, jmp_out = ngbrs((x, y))
        nlist = [n + (z,) for n in ns]
        if jmp_out and z > 0:
            nlist += [(jmp_out + (z-1,))]
        if jmp_in:
            nlist += [(jmp_in + (z+1,))]
        for nf in nlist:
            tmp_gscore = gscore + 1

            if tmp_gscore < Gscore.get(nf, float("inf")):
                Gscore[nf] = tmp_gscore
                new_fscore = H(nf, end)
                Fscore[nf] = Gscore[nf] + new_fscore

                if nf not in queue_set:
                    queue_set.add(nf)
                    heapq.heappush(queue, (new_fscore, tmp_gscore, nf))
    return 'failure', Gscore, Fscore    