In [658]:
import numpy as np
from functools import cache
from typing import Tuple
from collections import Counter, defaultdict
import heapq
from dataclasses import dataclass, field
from typing import Any

In [771]:
WALL = '#'
EMPTY = '.'
COSTS = {
    'A': 1,
    'B': 10,
    'C': 100,
    'D': 1000
}

def is_amphipod(c):
    return c[0] in 'ABCD'

def is_wall(c):
    return c == WALL

def is_empty(c):
    return c == EMPTY

class Map:
    def __init__(self, map_str, rooms, doorways):
        self.rooms = rooms
        self.doorways = doorways
        self.parse_map(map_str)
        
    def parse_map(self, map_str):
        map_str = map_str.strip()
        w, h = map_str.index('\n'), map_str.count('\n')+1
        self.map = np.full((h, w), WALL)
        initial_positions = []
        for i, line in enumerate(map_str.splitlines()):
            for j, c in enumerate(line.rstrip()):
                if is_amphipod(c):
                    self.map[i, j] = EMPTY
                    initial_positions.append((c, (i, j)))
                else:
                    self.map[i, j] = c.strip() or WALL
        self.initial_state = MapState.from_positions(initial_positions)
              
    @cache
    def neighbors(self, pos):
        i, j = pos
        ns = []
        for di, dj in ((-1, 0), (0, -1), (1, 0), (0, 1)):
            ni, nj = i+di, j+dj
            if (0 <= ni < self.map.shape[0] and 0 <= nj < self.map.shape[1] 
                and is_empty(self.map[ni, nj])):
                ns.append((ni, nj))
        return tuple(ns)
    
    @cache
    def hallway_positions(self):
        return set((1, j) for j in range(1, self.map.shape[-1] - 1)) - self.doorways
    
    @cache
    def is_room(self, pos):
        for amp, rooms in self.rooms.items():
            if pos in rooms: return amp
        return False
                
    @cache
    def in_own_room(self, amp, pos):
        return pos in self.rooms[amp[0]]
    
    @cache
    def in_wrong_room(self, amp, pos):
        return any(a != amp[0] and pos in rooms for (a, rooms) in self.rooms.items())
    
    @cache
    def can_stop_at(self, amp, pos):
        return pos not in self.doorways and not self.in_wrong_room(amp, pos)
    
    @cache
    def heuristic_cost(self, amp, p1, p2):
        h1 = p1[0] - 1
        h2 = p2[0] - 1
        w = abs(p2[1]-p1[1])
        return COSTS[amp[0]]*(h1+h2+w)
    
    def next_room_position(self, amp, occupied):
        for room in reversed(self.rooms[amp[0]]):
            if room not in occupied:
                return room
    
def is_room_open(amp, m, assignments):
    amp = amp[0]
    return all(a[0] == amp for (a, pos) in assignments if pos in m.rooms[amp[0]])

@cache
def heuristic_cost(m, amps, positions):
    assignments = list(zip(amps, positions))
    occupied = {pos: amp for amp, pos in assignments
                if m.in_own_room(amp, pos) and is_room_open(amp, m, assignments)}
    cost = 0
    for amp, pos in assignments:
        if pos in occupied and occupied[pos] == amp: continue
        new = m.next_room_position(amp, occupied.keys())
        cost += m.heuristic_cost(amp, pos, new)
        occupied[new] = amp
    return cost

def show_map(m, state):
    map = m.map.copy()
    for c, (i, j) in zip(state.amps, state.positions):
        map[i, j] = c[0]
    print('\n'.join(''.join(r) for r in map))
    print('Cost:', state.cost, '| Correct:', state.correct(m))

In [845]:
@dataclass(frozen=True)
class MapState:
    amps: Tuple
    positions: Tuple
    moves: Tuple
    cost: int = 0
        
    @staticmethod
    def from_positions(initial_positions):
        c = Counter()
        amps = []
        positions = []
        moves = []
        for amp, pos in initial_positions:
            c[amp] += 1
            amp = f'{amp}{c[amp]}'
            amps.append(amp)
            positions.append(pos)
            moves.append(2)
        return MapState(tuple(amps), tuple(positions), tuple(moves))
    
    def assignments(self):
        return zip(self.amps, self.positions)

    def assignments_for(self, amp):
        return sorted(pos for (a, pos) in zip(self.amps, self.positions) if a[0] == amp[0])
    
    def is_room_full(self, amp, m):
        amp = amp[0]
        return sorted(m.rooms[amp]) == self.assignments_for(amp)
    
    def is_room_open(self, amp, m):
        amp = amp[0]
        return all(a[0] == amp for (a, pos) in self.assignments() if pos in m.rooms[amp[0]])
    
    def all_rooms_full(self, m):
        return self.correct(m) == len(self.amps)
    
    def correct(self, m):
        correct = 0
        for amp, pos in self.assignments():
            if pos in m.rooms[amp[0]]:
                correct += 1
        return correct
        
    def is_impossible(self, m):
        for amp, pos, remaining in zip(self.amps, self.positions, self.moves):
            if pos not in m.rooms[amp[0]] and remaining <= 0:
                return True
        return False
    
    def can_stop_at(self, amp, pos, m):
        if not m.can_stop_at(amp, pos):
            return False
        if a := m.is_room(pos):
            return self.is_room_open(a, m)
        return True

    def move(self, amp, new, cost):
        idx = self.amps.index(amp)
        new_positions = tuple(new if i == idx else old for (i, old) in enumerate(self.positions))
        new_moves = tuple(old-1 if i == idx else old for (i, old) in enumerate(self.moves))
        new_cost = self.cost + cost
        return MapState(self.amps, new_positions, new_moves, new_cost)
    
    def moves_from(self, amp, pos, m):
        occupied = set(self.positions)
        ok_room = m.next_room_position(amp, occupied)
        
        to_visit = [(pos, 0)]
        visited = set()
        moves = set()
        
        while to_visit:
            pos, cost = to_visit.pop()
            visited.add(pos)
            for new in m.neighbors(pos):
                if new in visited or new in occupied:
                    continue
                new_cost = cost + COSTS[amp[0]]
                if (not m.is_room(new) or new == ok_room) and self.can_stop_at(amp, new, m):
                    yield new_cost, amp, new
                to_visit.append((new, new_cost))   

    def moves_within(self, m):
        for amp, pos, moves_remaining in zip(self.amps, self.positions, self.moves):
            if not moves_remaining: continue
            #if self.is_room_full(amp[0], m): continue
            if m.in_own_room(amp, pos) and self.is_room_open(amp, m): continue
            yield from self.moves_from(amp, pos, m)


@dataclass(order=True)
class PrioritizedState:
    cost: int
    n_correct: int
    state: Any=field(compare=False)
        
def find_minimal_cost_solution(m, limit=200000):
    best = float('inf')
    states = [PrioritizedState(0, 0, m.initial_state)]
    visited = set()
    iterations = 0
    while states:
        item = heapq.heappop(states)
        positions = item.state.positions
        if item.cost >= best or item.state in visited:
            continue
        else:
            visited.add(item.state)
        if item.state.all_rooms_full(m) and item.cost < best:
            best = item.cost
            print('YAY!!')
            show_map(m, item.state)
            print()
            break
        for cost, amp, new in item.state.moves_within(m):
            new_state = item.state.move(amp, new, cost)
            distance = heuristic_cost(m, new_state.amps, new_state.positions)
            total_cost = new_state.cost + distance
#             show_map(m, new_state)
#             print(best)
#             print()
            if (total_cost >= best or new_state in visited):
                continue
            if new_state.is_impossible(m): continue
            heapq.heappush(states, 
                           PrioritizedState(total_cost, -new_state.correct(m), new_state))
        iterations += 1
        if iterations > limit:
            print('early stopping')
            break
    return best

In [846]:
map_str = """#############
#...........#
###C#A#B#D###
  #B#A#D#C#
  #########"""
rooms = {
    'A': ((2, 3), (3, 3)), 
    'B': ((2, 5), (3, 5)),
    'C': ((2, 7), (3, 7)),
    'D': ((2, 9), (3, 9)),
}
doorways = {
    (1, 3), (1, 5), (1, 7), (1, 9)
}

m = Map(map_str, rooms, doorways)

In [847]:
heuristic_cost(m, m.initial_state.amps, m.initial_state.positions)

9310

In [849]:
%%time
find_minimal_cost_solution(m, 100000)

YAY!!
#############
#...........#
###A#B#C#D###
###A#B#C#D###
#############
Cost: 11516 | Correct: 8

CPU times: user 12.7 s, sys: 3.93 ms, total: 12.7 s
Wall time: 12.7 s


11516

In [850]:
map_str = """#############
#...........#
###C#A#B#D###
  #D#C#B#A#
  #D#B#A#C#
  #B#A#D#C#
  #########"""
rooms = {
    'A': ((2, 3), (3, 3), (4, 3), (5, 3)), 
    'B': ((2, 5), (3, 5), (4, 5), (5, 5)),
    'C': ((2, 7), (3, 7), (4, 7), (5, 7)),
    'D': ((2, 9), (3, 9), (4, 9), (5, 9)),
}
doorways = {
    (1, 3), (1, 5), (1, 7), (1, 9)
}

m = Map(map_str, rooms, doorways)

In [851]:
%%time
find_minimal_cost_solution(m, limit=500000)

YAY!!
#############
#...........#
###A#B#C#D###
###A#B#C#D###
###A#B#C#D###
###A#B#C#D###
#############
Cost: 40272 | Correct: 16

CPU times: user 45.6 s, sys: 3.83 ms, total: 45.6 s
Wall time: 45.7 s


40272