In [None]:
import heapq
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

def update_plot(frame, grid, paths):
    """Redraw the grid and label every agent at its position for one frame.

    Args:
        frame: Index of the time step to draw.
        grid: 2-D list of cell values, rendered with the 'binary' colormap.
        paths: One sequence of ``(row, col)`` positions per agent. An agent
            whose path has fewer than ``frame + 1`` entries is not drawn.
    """
    plt.cla()
    plt.imshow(grid, cmap='binary')

    # Agents are labelled 1..N in the order their paths are given.
    for label, route in enumerate(paths, start=1):
        if frame >= len(route):
            continue
        row, col = route[frame]
        # imshow puts columns on the x axis and rows on the y axis.
        plt.text(col, row, str(label), ha='center',
                 va='center', fontsize=10, color='red')

    # One tick per grid column / row.
    plt.xticks(range(len(grid[0])))
    plt.yticks(range(len(grid)))
    
import heapq
from collections import defaultdict


class AStarPlanner:
    """Single-agent space-time A* planner on a 4-connected grid.

    Grid cells equal to 0 are traversable. Positions are ``(row, col)``
    tuples, and every action (a unit move or a wait) takes one time step.
    """

    # Actions available at each time step: four unit moves plus waiting.
    MOVES = ((-1, 0), (1, 0), (0, -1), (0, 1), (0, 0))

    def __init__(self, grid, start, goal):
        """Store the grid and the agent's start and goal positions."""
        self.grid = grid
        self.start = start
        self.goal = goal
        self.rows = len(grid)
        self.cols = len(grid[0])

    def is_valid(self, pos):
        """Return True if *pos* lies inside the grid on a free (0) cell."""
        x, y = pos
        return 0 <= x < self.rows and 0 <= y < self.cols and self.grid[x][y] == 0

    def heuristic(self, pos):
        """Return the Manhattan distance from *pos* to the goal."""
        return abs(pos[0] - self.goal[0]) + abs(pos[1] - self.goal[1])

    def find_path(self, occupied, max_time=None):
        """Find a path using space-time A* with 'wait' as a valid action.

        Args:
            occupied: Collection of ``(key, time)`` constraints.
                ``(pos, t)`` forbids being at ``pos`` at time ``t``;
                ``((src, dst), t)`` forbids starting the move ``src -> dst``
                at time ``t``.
            max_time: Largest time step the search may reach. When omitted
                it defaults to the latest constraint time plus one plus the
                number of grid cells: after the last constraint the world
                is static, so any reachable goal is at most one visit per
                cell away.

        Returns:
            The list of positions from start to goal (one per time step),
            or None when the goal cannot be reached within ``max_time``.
        """
        if max_time is None:
            latest_constraint = max((entry[1] for entry in occupied), default=-1)
            max_time = latest_constraint + 1 + self.rows * self.cols

        # Heap entries are (f, g, position, time).
        open_set = [(self.heuristic(self.start), 0, self.start, 0)]
        came_from = {}
        g_score = defaultdict(lambda: float('inf'))
        g_score[(self.start, 0)] = 0

        while open_set:
            _, _, current, current_time = heapq.heappop(open_set)

            if current == self.goal:
                return self.reconstruct_path(came_from, current, current_time)

            neighbor_time = current_time + 1
            # Without this bound the wait action keeps generating new
            # (position, time) states forever when the goal is unreachable.
            if neighbor_time > max_time:
                continue

            for dx, dy in self.MOVES:
                neighbor = (current[0] + dx, current[1] + dy)

                if not self.is_valid(neighbor):
                    continue
                if (neighbor, neighbor_time) in occupied:
                    continue  # vertex constraint on the arrival state
                if ((current, neighbor), current_time) in occupied:
                    continue  # edge constraint on this move

                tentative_g_score = g_score[(current, current_time)] + 1
                if tentative_g_score < g_score[(neighbor, neighbor_time)]:
                    g_score[(neighbor, neighbor_time)] = tentative_g_score
                    priority = tentative_g_score + self.heuristic(neighbor)
                    heapq.heappush(open_set, (priority, tentative_g_score, neighbor, neighbor_time))
                    came_from[(neighbor, neighbor_time)] = (current, current_time)

        return None  # No path found

    def reconstruct_path(self, came_from, current, current_time):
        """Walk the predecessor links back from the goal state.

        Returns the positions in start-to-goal order, including the start.
        """
        path = []
        while (current, current_time) in came_from:
            path.append(current)
            current, current_time = came_from[(current, current_time)]
        path.append(self.start)
        return path[::-1]


class CBS:
    """Conflict-based search for multiple agents on a shared grid.

    Each agent is planned individually with ``AStarPlanner``; conflicts
    between the individual paths are resolved by branching on per-agent
    constraints. Nodes are expanded in FIFO order.
    """

    def __init__(self, grid, starts, goals):
        """Store the grid and the per-agent start and goal positions."""
        self.grid = grid
        self.starts = starts
        self.goals = goals
        self.num_agents = len(starts)
        self.paths = [None] * self.num_agents
        self.queue = []

    def detect_conflicts(self, paths):
        """Detect conflicts between agents.

        Args:
            paths: One path (list of positions, index = time step) per
                agent. Missing (None or empty) paths are ignored.

        Returns:
            A list of conflict dicts sorted by time. Vertex conflicts carry
            ``pos`` and every agent at that cell at that time; edge (swap)
            conflicts carry ``edge`` as ``(forward, reverse)`` and
            ``agents`` as one agent per direction, in the same order.
        """
        conflicts = []
        occupied = defaultdict(list)
        edges = defaultdict(list)

        for agent, path in enumerate(paths):
            if not path:
                continue
            for t, pos in enumerate(path):
                occupied[(pos, t)].append(agent)
            for t, (start, end) in enumerate(zip(path, path[1:])):
                if start != end:  # waiting in place is not an edge traversal
                    edges[(start, end, t)].append(agent)

        for (pos, t), agents in occupied.items():
            if len(agents) > 1:
                conflicts.append({"type": "vertex", "time": t, "pos": pos, "agents": agents})

        # Report each swap once: after pairing an edge with its reverse,
        # the reverse edge is skipped when the loop reaches it.
        paired = set()
        for edge, forward_agents in edges.items():
            start, end, t = edge
            reverse_edge = (end, start, t)
            if reverse_edge not in edges or edge in paired:
                continue
            paired.add(reverse_edge)
            conflicts.append({
                "type": "edge",
                "time": t,
                "edge": ((start, end), (end, start)),
                # agents[0] travels the forward edge, agents[1] the reverse,
                # matching the edge assignment made in resolve_conflict.
                "agents": [forward_agents[0], edges[reverse_edge][0]],
            })

        # Earliest conflicts first; the sort is stable, so vertex conflicts
        # precede edge conflicts that share the same time.
        conflicts.sort(key=lambda conflict: conflict["time"])

        return conflicts

    def resolve_conflict(self, conflict):
        """Add constraints to resolve conflicts.

        Returns one constraint dict per involved agent: a vertex constraint
        for every agent of a vertex conflict, or one edge constraint for
        each of the two agents of an edge conflict.
        """
        constraints = []

        if conflict["type"] == "vertex":
            pos, time = conflict["pos"], conflict["time"]
            for a in conflict["agents"]:
                constraints.append({"agent": a, "type": "vertex", "pos": pos, "time": time})

        elif conflict["type"] == "edge":
            a1, a2 = conflict["agents"][0], conflict["agents"][1]
            edge, time = conflict["edge"], conflict["time"]
            constraints.append({"agent": a1, "type": "edge", "edge": edge[0], "time": time})
            constraints.append({"agent": a2, "type": "edge", "edge": edge[1], "time": time})

        return constraints

    def solve(self):
        """Search for a conflict-free set of paths.

        Returns:
            A list with one path per agent, or None when no solution is
            found (including when some agent has no path at all).
        """
        root_paths = [self.find_path(i, []) for i in range(self.num_agents)]
        if any(path is None for path in root_paths):
            return None  # an agent that cannot reach its goal alone never will

        # Start from a fresh queue so a previous solve() call cannot leak
        # stale nodes into this search.
        self.queue = [{"paths": root_paths, "constraints": []}]

        while self.queue:
            node = self.queue.pop(0)

            conflicts = self.detect_conflicts(node["paths"])
            if not conflicts:
                return node["paths"]

            # Branch on the earliest conflict: one child per constraint.
            for constraint in self.resolve_conflict(conflicts[0]):
                new_node = {
                    "paths": node["paths"][:],
                    "constraints": node["constraints"] + [constraint],
                }
                agent = constraint["agent"]
                new_node["paths"][agent] = self.find_path(agent, new_node["constraints"])

                if new_node["paths"][agent] is not None:
                    self.queue.append(new_node)

        return None  # No solution found

    def find_path(self, agent, constraints):
        """Find a path for an agent considering the constraints.

        Only the constraints addressed to *agent* are applied; they are
        passed to the planner as ``(pos, time)`` and ``(edge, time)`` keys.
        """
        occupied = set()
        edges = set()

        for c in constraints:
            if c["agent"] == agent:
                if c["type"] == "vertex":
                    occupied.add((c["pos"], c["time"]))
                elif c["type"] == "edge":
                    edges.add((c["edge"], c["time"]))

        planner = AStarPlanner(self.grid, self.starts[agent], self.goals[agent])
        return planner.find_path(occupied.union(edges))

# Demo: plan routes for two agents on a 5x5 grid (1 = obstacle cell).
if __name__ == "__main__":
    grid = [
        [0, 0, 0, 0, 0],
        [0, 1, 1, 1, 0],
        [0, 1, 0, 0, 0],
        [1, 0, 1, 1, 0],
        [0, 0, 0, 0, 0],
    ]

    # Agent k travels from starts[k] to goals[k]; positions are (row, col).
    starts = [(4, 0), (4, 1)]
    goals = [(0, 4), (4, 0)]

    cbs = CBS(grid, starts, goals)
    paths = cbs.solve()

    if not paths:
        print("No solution found")
    else:
        for agent_id, route in enumerate(paths):
            print(f"Agent {agent_id}: {route}")


Agent 0: [(4, 0), (4, 1), (4, 2), (4, 3), (4, 4), (3, 4), (2, 4), (1, 4), (0, 4)]
Agent 1: [(4, 1), (3, 1), (4, 1), (4, 0)]


In [None]:
# Animate the planned paths and save them as a GIF. The solver returns None
# when it finds no solution, so skip the animation instead of crashing on
# max() over a non-iterable.
if paths:
    fig = plt.figure()
    # One frame per time step of the longest path.
    frame_count = max(len(path) for path in paths)
    ani = FuncAnimation(fig, update_plot, frames=frame_count,
                        fargs=(grid, paths), interval=500)
    # NOTE(review): the 'imagemagick' writer needs ImageMagick installed.
    ani.save('paths.gif', writer='imagemagick')
    plt.show()
else:
    print("No paths to animate")