# A* Algorithm
[link](https://www.algoexpert.io/questions/A*%20Algorithm)

## My Solution

In [1]:
class GraphNode:
    """Search-tree node for the A* grid search.

    Only the coordinate is set here; the caller fills in ``distanceFromSource``
    (g), ``heuristic`` (h), ``f`` (g + h) and ``prevNode`` (the node this one
    was reached from) after construction.
    """

    def __init__(self, coordinate):
        self.coordinate = coordinate
        self.distanceFromSource = self.heuristic = self.f = None
        self.prevNode = None

def aStarAlgorithm(startRow, startCol, endRow, endCol, graph):
    """Find a shortest 4-directional path through *graph* with A*.

    Cells equal to 0 are walkable (that test lives in ``checkNeighbor``).
    Returns the path as a list of ``[row, col]`` lists from start to end, or
    ``[]`` when the end cell cannot be reached.
    """
    target = (endRow, endCol)
    origin = (startRow, startCol)

    root = GraphNode(origin)
    root.distanceFromSource = 0
    root.heuristic = manhattanDistance(origin, target)
    root.f = root.distanceFromSource + root.heuristic

    openSet = minArray()
    openSet.add(root)
    closed = {}  # coordinate -> node already expanded

    goal = None
    while openSet.nodes:
        current = openSet.popMin()
        if current.coordinate == target:
            goal = current
            break
        closed[current.coordinate] = current
        checkNeighbor(current, graph, openSet, closed, target)

    if goal is None:
        return []

    # Walk the predecessor chain back to the start, then flip it.
    path = []
    walker = goal
    while walker is not None:
        path.append(list(walker.coordinate))
        walker = walker.prevNode
    path.reverse()
    return path

def manhattanDistance(coordA, coordB):
    """Return the taxicab (L1) distance between two (row, col) coordinates."""
    rowGap = abs(coordA[0] - coordB[0])
    colGap = abs(coordA[1] - coordB[1])
    return rowGap + colGap

def checkNeighbor(node, graph, frontiers, visited, endCoord):
    """Offer each eligible 4-neighbour of *node* to ``process``.

    Neighbours are tried in the order up, down, left, right. A neighbour is
    eligible when it lies inside *graph*, is not already in *visited*, and its
    grid cell equals 0.
    """
    row, col = node.coordinate[0], node.coordinate[1]
    candidates = (
        (row - 1, col, row - 1 >= 0),             # up
        (row + 1, col, row + 1 < len(graph)),     # down
        (row, col - 1, col - 1 >= 0),             # left
        (row, col + 1, col + 1 < len(graph[0])),  # right
    )
    for nRow, nCol, inBounds in candidates:
        neighbor = (nRow, nCol)
        if inBounds and neighbor not in visited and graph[nRow][nCol] == 0:
            process(node, neighbor, endCoord, frontiers)
        
def process(node, curCoord, endCoord, frontiers):
    """Offer *curCoord*, reached from *node* in one step, to the frontier.

    A coordinate not yet queued is added. An already-queued coordinate is
    replaced only when the new route gives a strictly smaller ``f``.
    """
    candidate = GraphNode(curCoord)
    candidate.prevNode = node
    candidate.distanceFromSource = node.distanceFromSource + 1  # unit step cost
    candidate.heuristic = manhattanDistance(curCoord, endCoord)
    candidate.f = candidate.distanceFromSource + candidate.heuristic

    if curCoord in frontiers.nodes:
        if candidate.f < frontiers.nodes[curCoord].f:
            frontiers.nodes[curCoord] = candidate
    else:
        frontiers.add(candidate)

class minArray:
    """Unsorted priority queue: nodes keyed by coordinate, minimum found by scan.

    ``nodes`` maps coordinate -> node. ``popMin`` is O(n). Among nodes with
    equal ``f`` it returns the one that comes first in the dict's insertion
    order.
    """

    def __init__(self):
        self.nodes = {}

    def add(self, node):
        """Store *node* under its coordinate, overwriting any existing entry."""
        self.nodes[node.coordinate] = node

    def remove(self, coord):
        """Remove and return the node stored under *coord* (KeyError if absent)."""
        return self.nodes.pop(coord)

    def popMin(self):
        """Remove and return the node with the smallest ``f``, or None if empty."""
        if not self.nodes:
            return None
        # min() keeps the first of several equal minima, matching a strict-< scan.
        bestCoord = min(self.nodes, key=lambda coord: self.nodes[coord].f)
        return self.nodes.pop(bestCoord)

In [1]:
class GraphNode:
    """Search-tree node for A* on a grid.

    Args:
        coordinate: the (row, col) cell this node represents.
        distanceFromSource: g, the path cost from the start cell.
        heuristic: h, the estimated remaining cost to the end cell.
        prevNode: the node this one was reached from (None for the start).

    ``f`` is g + h when both are supplied, otherwise None.
    """

    def __init__(self, coordinate, distanceFromSource=None, heuristic=None, prevNode=None):
        self.coordinate = coordinate
        self.distanceFromSource = distanceFromSource
        self.heuristic = heuristic
        # The parameters default to None, so only add when both are known;
        # otherwise ``GraphNode(coord)`` would raise TypeError on None + None.
        if distanceFromSource is None or heuristic is None:
            self.f = None
        else:
            self.f = distanceFromSource + heuristic  # f = g + h
        self.prevNode = prevNode

def aStarAlgorithm(startRow, startCol, endRow, endCol, graph):
    """A* shortest path on a grid with 4-directional unit moves.

    Only cells equal to 0 can be entered (see ``checkNeighbor``). Returns the
    path as ``[row, col]`` lists from start to end, or ``[]`` when the end is
    unreachable.
    """
    goalCoord = (endRow, endCol)
    sourceCoord = (startRow, startCol)
    openNodes = minArray()
    openNodes.add(GraphNode(sourceCoord, 0, manhattanDistance(sourceCoord, goalCoord)))

    expanded = {}
    reached = None
    while openNodes.nodes:
        best = openNodes.popMin()
        if best.coordinate == goalCoord:
            reached = best
            break
        expanded[best.coordinate] = best
        checkNeighbor(best, graph, openNodes, expanded, goalCoord)

    # If the goal was never reached this loop never runs and [] is returned.
    path = []
    while reached is not None:
        path.append(list(reached.coordinate))
        reached = reached.prevNode
    return path[::-1]

def manhattanDistance(coordA, coordB):
    """Taxicab distance |row difference| + |column difference| between two cells."""
    rowA, colA = coordA[0], coordA[1]
    rowB, colB = coordB[0], coordB[1]
    return abs(rowA - rowB) + abs(colA - colB)

def checkNeighbor(node, graph, frontiers, visited, endCoord):
    """Relax the in-bounds, unvisited, open (== 0) 4-neighbours of *node*.

    Neighbours are handled in the order up, down, left, right. One not yet on
    the frontier is added. One already queued is replaced only when the new
    route gives a strictly smaller f = g + h.
    """
    row, col = node.coordinate[0], node.coordinate[1]
    candidates = []
    if row > 0:
        candidates.append((row - 1, col))
    if row < len(graph) - 1:
        candidates.append((row + 1, col))
    if col > 0:
        candidates.append((row, col - 1))
    if col < len(graph[0]) - 1:
        candidates.append((row, col + 1))

    stepCost = node.distanceFromSource + 1  # every move costs 1
    for coord in candidates:
        r, c = coord
        if coord not in visited and graph[r][c] == 0:
            h = manhattanDistance(coord, endCoord)
            candidate = GraphNode(coord, stepCost, h, node)
            if coord not in frontiers.nodes:
                frontiers.add(candidate)
            elif stepCost + h < frontiers.nodes[coord].f:
                frontiers.nodes[coord] = candidate

class minArray:
    """Coordinate-keyed bag of nodes whose ``popMin`` is a linear O(n) scan.

    On ties in ``f`` the entry that comes first in the dict's insertion order
    wins.
    """

    def __init__(self):
        self.nodes = {}  # coordinate -> node

    def add(self, node):
        """Insert *node* under ``node.coordinate``; an existing entry is overwritten."""
        self.nodes[node.coordinate] = node

    def remove(self, coord):
        """Pop and return the node at *coord*; raises KeyError if absent."""
        return self.nodes.pop(coord)

    def popMin(self):
        """Pop and return the node with the lowest ``f``; None when empty."""
        if not self.nodes:
            return None
        entries = iter(self.nodes.items())
        bestCoord, bestNode = next(entries)
        for coord, node in entries:
            if node.f < bestNode.f:
                bestCoord, bestNode = coord, node
        del self.nodes[bestCoord]
        return bestNode

In [None]:
class GraphNode:
    """Search-tree node for A*, ordered by its ``f`` score.

    Args:
        coordinate: the (row, col) cell this node represents.
        distanceFromSource: g, the path cost from the start cell.
        heuristic: h, the estimated remaining cost to the end cell.
        prevNode: the node this one was reached from (None for the start).

    ``f`` is g + h when both are supplied, otherwise None.
    """

    def __init__(self, coordinate, distanceFromSource=None, heuristic=None, prevNode=None):
        self.coordinate = coordinate
        self.distanceFromSource = distanceFromSource # g
        self.heuristic = heuristic # h
        # The parameters default to None, so only add when both are known;
        # otherwise ``GraphNode(coord)`` would raise TypeError on None + None.
        if distanceFromSource is None or heuristic is None:
            self.f = None
        else:
            self.f = distanceFromSource + heuristic # f = g + h
        self.prevNode = prevNode

    def __lt__(self, other):
        """Compare by f so a min-heap yields the most promising node first."""
        return self.f < other.f

    def __le__(self, other):
        return self.f <= other.f

def aStarAlgorithm(startRow, startCol, endRow, endCol, graph):
    """A* shortest path on a grid using a binary min-heap as the open set.

    Cells equal to 0 are enterable (see ``checkNeighbor``). Returns the path as
    ``[row, col]`` lists from start to end, or ``[]`` if the end is unreachable.
    """
    endCoord = (endRow, endCol)
    startCoord = (startRow, startCol)
    openHeap = minHeap([GraphNode(startCoord, 0, manhattanDistance(startCoord, endCoord))])
    closedSet = {}

    while openHeap.isNotEmpty():
        node = openHeap.remove()
        if node.coordinate == endCoord:
            return constructPath(node)
        closedSet[node.coordinate] = True
        checkNeighbor(node, graph, openHeap, closedSet, endCoord)

    # The open set ran dry without reaching the end.
    return []

def manhattanDistance(coordA, coordB):
    """Sum of the absolute row and column differences between two grid cells."""
    return sum(abs(coordA[axis] - coordB[axis]) for axis in (0, 1))

def checkNeighbor(node, graph, frontiers, visited, endCoord):
    """Push or decrease-key each open, unvisited 4-neighbour of *node* in the heap.

    Neighbours are tried in the order up, down, left, right. A new coordinate is
    inserted. A coordinate already in the heap is updated only when the new node
    compares strictly smaller (lower f).
    """
    row, col = node.coordinate[0], node.coordinate[1]
    offsets = ((-1, 0), (1, 0), (0, -1), (0, 1))
    inGrid = (
        row - 1 >= 0,
        row + 1 < len(graph),
        col - 1 >= 0,
        col + 1 < len(graph[0]),
    )
    g = node.distanceFromSource + 1  # unit step cost
    for (dRow, dCol), allowed in zip(offsets, inGrid):
        if not allowed:
            continue
        coord = (row + dRow, col + dCol)
        if coord not in visited and graph[coord[0]][coord[1]] == 0:
            candidate = GraphNode(coord, g, manhattanDistance(coord, endCoord), node)
            heapIdx = frontiers.nodes.get(coord)
            if heapIdx is None:
                frontiers.insert(candidate) # O(log(n)) time | O(1) space
            elif candidate < frontiers.heap[heapIdx]:
                frontiers.update(coord, candidate) # O(log(n)) time | O(1) space
                    
def constructPath(lastNode):
    """Follow ``prevNode`` links back from *lastNode*; return ``[row, col]`` lists start to end."""
    backwards = []
    node = lastNode
    while node is not None:
        backwards.append(list(node.coordinate))
        node = node.prevNode
    return backwards[::-1]

class minHeap:
    """Array-backed binary min-heap of search nodes with a coordinate index.

    ``heap`` stores the nodes in heap order, using the nodes' own comparison
    operators. ``nodes`` maps each stored node's ``coordinate`` to its current
    index in ``heap``, so callers can test membership and locate a node for
    ``update`` in O(1). Every swap goes through ``switch`` so the two stay in sync.
    """

    def __init__(self, array):
        self.heap = self.buildHeap(array)

    def buildHeap(self, array):
        """Load *array* into the heap, restore heap order and return the heap list. O(n)."""
        self.heap = list(array)
        # The index must exist before heapifying because ``switch`` updates it
        # on every swap.
        self.nodes = {node.coordinate: idx for idx, node in enumerate(self.heap)}
        lastParentIdx = (len(self.heap) - 2) // 2
        for idx in reversed(range(lastParentIdx + 1)):
            self.heapifyDown(idx)
        return self.heap

    def heapifyDown(self, idx):
        """Move the node at *idx* down until neither child is smaller. O(log n)."""
        size = len(self.heap)
        while True:
            leftIdx = 2 * idx + 1
            if leftIdx >= size:
                return
            rightIdx = leftIdx + 1
            # Prefer the left child when the two children tie.
            if rightIdx < size and not self.heap[leftIdx] <= self.heap[rightIdx]:
                smallerIdx = rightIdx
            else:
                smallerIdx = leftIdx
            if self.heap[idx] > self.heap[smallerIdx]:
                self.switch(idx, smallerIdx)
                idx = smallerIdx
            else:
                return

    def switch(self, i, j):
        """Swap heap slots *i* and *j* and record both nodes' new positions."""
        nodeACoord = self.heap[i].coordinate
        nodeBCoord = self.heap[j].coordinate
        self.nodes[nodeACoord], self.nodes[nodeBCoord] = j, i
        self.heap[i], self.heap[j] = self.heap[j], self.heap[i]

    def siftDown(self):
        """Restore heap order from the root downwards."""
        self.heapifyDown(0)

    def heapifyUp(self, idx):
        """Move the node at *idx* up while it is smaller than its parent. O(log n)."""
        while idx > 0:
            parentIdx = (idx - 1) // 2
            if self.heap[parentIdx] > self.heap[idx]:
                self.switch(parentIdx, idx)
                idx = parentIdx
            else:
                break

    def siftUp(self):
        """Restore heap order from the last slot upwards (after an append)."""
        self.heapifyUp(len(self.heap) - 1)

    def remove(self):
        """Pop and return the smallest node, or None if the heap is empty. O(log n)."""
        if not self.heap:
            return None
        self.switch(0, len(self.heap) - 1)
        topNode = self.heap.pop()
        del self.nodes[topNode.coordinate]
        self.siftDown()
        return topNode

    def insert(self, node):
        """Append *node* and sift it into place. O(log n)."""
        self.heap.append(node)
        self.nodes[node.coordinate] = len(self.heap) - 1
        self.siftUp()

    def update(self, coord, newNode):
        """Replace the node stored at *coord* with a smaller *newNode*.

        *newNode* must carry the same coordinate and must not compare larger than
        the node it replaces, because only an upward sift is performed.
        O(log(n)) time | O(1) space.
        """
        idx = self.nodes[coord]
        self.heap[idx] = newNode
        self.heapifyUp(idx)

    def isNotEmpty(self):
        """Return True while the heap holds at least one node."""
        return len(self.heap) > 0

## Expert Solution

In [None]:
class Node:
    """A grid cell used by the expert A* solution.

    ``id`` is the string "row-col" and is used as the heap's position key.
    ``value`` is the raw grid entry. Both distances start at infinity and
    ``cameFrom`` starts at None until the search assigns them.
    """

    def __init__(self, row, col, value):
        self.id = "-".join((str(row), str(col)))
        self.row, self.col, self.value = row, col, value
        self.distanceFromStart = self.estimateDistanceToEnd = float("inf")
        self.cameFrom = None


# O(w * h * log(w * h)) time | O(w * h) space - where
# w is the width of the graph and h is the height
def aStarAlgorithm(startRow, startCol, endRow, endCol, graph):
    """Return the shortest 4-directional path from start to end as ``[row, col]`` pairs.

    Cells whose value is 1 are obstacles. Each node's ``estimateDistanceToEnd``
    holds f = g + h and serves as the heap key. The path is rebuilt by
    ``reconstructPath``.
    """
    grid = initializeNodes(graph)
    start, end = grid[startRow][startCol], grid[endRow][endCol]

    start.distanceFromStart = 0
    start.estimateDistanceToEnd = calculateManhattanDistance(start, end)

    frontier = MinHeap([start])
    while not frontier.isEmpty():
        current = frontier.remove()
        if current == end:
            break

        for neighbor in getNeighboringNodes(current, grid):
            if neighbor.value == 1:
                continue  # obstacle

            candidateDistance = current.distanceFromStart + 1
            if candidateDistance >= neighbor.distanceFromStart:
                continue  # no shorter route through current

            neighbor.cameFrom = current
            neighbor.distanceFromStart = candidateDistance
            neighbor.estimateDistanceToEnd = candidateDistance + calculateManhattanDistance(neighbor, end)

            if frontier.containsNode(neighbor):
                frontier.update(neighbor)
            else:
                frontier.insert(neighbor)

    return reconstructPath(end)

def initializeNodes(graph):
    """Wrap every grid value in a ``Node``, returning a 2-D list shaped like *graph*."""
    return [
        [Node(rowIdx, colIdx, value) for colIdx, value in enumerate(row)]
        for rowIdx, row in enumerate(graph)
    ]

def calculateManhattanDistance(currentNode, endNode):
    """Return |row difference| + |column difference| between two nodes."""
    return abs(currentNode.row - endNode.row) + abs(currentNode.col - endNode.col)

def getNeighboringNodes(node, nodes):
    """Return the in-bounds 4-neighbours of *node*, ordered down, up, right, left."""
    row, col = node.row, node.col
    lastRow, lastCol = len(nodes) - 1, len(nodes[0]) - 1
    candidates = (
        (row < lastRow, row + 1, col),  # DOWN
        (row > 0, row - 1, col),        # UP
        (col < lastCol, row, col + 1),  # RIGHT
        (col > 0, row, col - 1),        # LEFT
    )
    return [nodes[r][c] for inBounds, r, c in candidates if inBounds]

def reconstructPath(endNode):
    """Return the ``[row, col]`` path from the start node to *endNode*.

    Walks the ``cameFrom`` links back from *endNode* and reverses them. Returns
    ``[]`` when *endNode* was never reached, i.e. its ``distanceFromStart`` is
    still infinite.
    """
    # Test reachability via the distance, not ``cameFrom``: when the end IS the
    # start it has distance 0 and no predecessor, and the correct answer is the
    # one-cell path rather than [].
    if endNode.distanceFromStart == float("inf"):
        return []

    path = []
    currentNode = endNode
    while currentNode is not None:
        path.append([currentNode.row, currentNode.col])
        currentNode = currentNode.cameFrom

    return path[::-1] # reverse path so it goes from start to end

class MinHeap:
    """Array-backed binary min-heap of nodes keyed by ``estimateDistanceToEnd``.

    ``nodePositionsInHeap`` maps each stored node's ``id`` to its index in
    ``heap``, so membership tests and ``update`` lookups are O(1). ``swap``
    keeps that map in sync with every move.
    """

    def __init__(self, array):
        # Record positions first: buildHeap heapifies *array* in place and
        # every swap it performs updates this map.
        self.nodePositionsInHeap = {}
        for position, node in enumerate(array):
            self.nodePositionsInHeap[node.id] = position
        self.heap = self.buildHeap(array)

    def isEmpty(self):
        """Return True when no nodes remain."""
        return not self.heap

    # O(n) time | O(1) space
    def buildHeap(self, array):
        """Heapify *array* in place and return that same list."""
        lastIdx = len(array) - 1
        parentIdx = (lastIdx - 1) // 2
        while parentIdx >= 0:
            self.siftDown(parentIdx, lastIdx, array)
            parentIdx -= 1
        return array

    # O(log(n)) time | O(1) space
    def siftDown(self, currentIdx, endIdx, heap):
        """Move heap[currentIdx] down until no child within endIdx is strictly smaller."""
        while True:
            leftIdx = 2 * currentIdx + 1
            if leftIdx > endIdx:
                return
            rightIdx = leftIdx + 1
            smallestChildIdx = leftIdx
            # The right child wins only if strictly smaller than the left one.
            if (
                rightIdx <= endIdx
                and heap[rightIdx].estimateDistanceToEnd < heap[leftIdx].estimateDistanceToEnd
            ):
                smallestChildIdx = rightIdx
            if not heap[smallestChildIdx].estimateDistanceToEnd < heap[currentIdx].estimateDistanceToEnd:
                return
            self.swap(currentIdx, smallestChildIdx, heap)
            currentIdx = smallestChildIdx

    # O(log(n)) time | O(1) space
    def siftUp(self, currentIdx, heap):
        """Move heap[currentIdx] up while it is strictly smaller than its parent."""
        while currentIdx > 0:
            parentIdx = (currentIdx - 1) // 2
            if not heap[currentIdx].estimateDistanceToEnd < heap[parentIdx].estimateDistanceToEnd:
                break
            self.swap(currentIdx, parentIdx, heap)
            currentIdx = parentIdx

    # O(log(n)) time | O(1) space
    def remove(self):
        """Pop and return the node with the smallest estimate; None when empty."""
        if self.isEmpty():
            return None
        lastIdx = len(self.heap) - 1
        self.swap(0, lastIdx, self.heap)
        node = self.heap.pop()
        self.nodePositionsInHeap.pop(node.id)
        self.siftDown(0, lastIdx - 1, self.heap)
        return node

    # O(log(n)) time | O(1) space
    def insert(self, node):
        """Append *node*, record its position and sift it into place."""
        newIdx = len(self.heap)
        self.heap.append(node)
        self.nodePositionsInHeap[node.id] = newIdx
        self.siftUp(newIdx, self.heap)

    def swap(self, i, j, heap):
        """Exchange heap[i] and heap[j], updating both recorded positions."""
        nodeI, nodeJ = heap[i], heap[j]
        self.nodePositionsInHeap[nodeI.id] = j
        self.nodePositionsInHeap[nodeJ.id] = i
        heap[i], heap[j] = nodeJ, nodeI

    def containsNode(self, node):
        """Return True if a node with the same ``id`` is currently stored."""
        return node.id in self.nodePositionsInHeap

    def update(self, node):
        """Re-sift *node* upward after its estimate was lowered in place."""
        self.siftUp(self.nodePositionsInHeap[node.id], self.heap)


## Thoughts
[algorithm video](https://www.youtube.com/watch?v=eSOJ3ARN5FM)