In [2]:
"""
https://leetcode.com/problems/cut-off-trees-for-golf-event/

You are asked to cut off all the trees in a forest for a golf event. The forest is represented as an m x n matrix. In this matrix:

0 means the cell cannot be walked through.
1 represents an empty cell that can be walked through.
A number greater than 1 represents a tree in a cell that can be walked through, and this number is the tree's height.
In one step, you can walk in any of the four directions: north, east, south, and west. 
If you are standing in a cell with a tree, you can choose whether to cut it off.

You must cut off the trees in order from shortest to tallest. When you cut off a tree, the value at its cell becomes 1 (an empty cell).

Starting from the point (0, 0), return the minimum steps you need to walk to cut off all the trees. 
If you cannot cut off all the trees, return -1.

Note: The input is generated such that no two trees have the same height, and there is at least one tree needs to be cut off.


Constraints:

m == forest.length
n == forest[i].length
1 <= m, n <= 50
0 <= forest[i][j] <= 10^9
Heights of all trees are distinct.
"""

from collections import deque
def cutTrees(grid):
    if grid == None or len(grid)==0 or len(grid[0])==0:
        return 0
    M = len(grid)
    N = len(grid[0])

    # sort tree heights by height
    # 2->3->4->5->6->7->8
    #  numbers are unique
    #  at least one tree to cut

    treeLocs = {} # height -> (i,j); there is a guarantee that no two trees have the same height
    treeHeights = []
    for i in range(M):
        for j in range(N):
            if grid[i][j] == 0:
                continue
            treeHeights.append(grid[i][j])
            treeLocs[ grid[i][j] ] = (i, j)
    treeHeights.sort()
    numTrees = len(treeHeights)

    # print(f"{numTrees=}")
    # print(f"{treeHeights=}")
    # print(f"{treeLocs=}")

    def _closestPath( si, sj, ei, ej ):
        """ BFS """
        if si==ei and sj==ej:
            return 0
        visited = set()
        toVisit = [ (si, sj, []) ]
        while len(toVisit):
            toVisitNext = []
            for i,j,hist in toVisit:
                for di, dj in [ (0,-1), (-1,0), (0,+1), (+1,0) ]:
                    if not (0<=i+di<M and 0<=j+dj<N) or grid[i+di][j+dj] < 1 or (i+di, j+dj) in visited:
                        continue
                    if i+di == ei and j+dj == ej:
                        return len(hist) + 1
                    toVisitNext.append( (i+di, j+dj, hist+[(i+di,j+dj)]) )
                    visited.add( (i+di, j+dj) )
            toVisit = toVisitNext
        return -1

    retVal = 0
    si, sj = 0, 0
    for height in treeHeights:
        ei, ej = treeLocs[height]
        path = _closestPath( si, sj, ei, ej )
        # print(f"({si},{sj}) -> ({ei},{ej}): {path}") 
        if path < 0:
            return -1
        retVal += path
        grid[ei][ej] = 1
        si, sj = ei, ej

    return retVal

tests = [
    ([[1,2,3],
      [0,0,4],
      [7,6,5]], 6),
    #Following the path above allows you to cut off the trees from shortest to tallest in 6 steps.
    ([[1,2,3],
      [0,0,0],
      [7,6,5]], -1),
    # The trees in the bottom row cannot be accessed as the middle row is blocked.
    ([[2,3,4],
      [0,0,5],
      [8,7,6]], 6),
    ([[3,2,4],
      [0,0,5],
      [8,7,6]], 8),  # we need to cut 2 first, then 3, then 4 ...
]
for t in tests:
    retVal = cutTrees(t[0])
    print(t, retVal)
    assert(retVal == t[1])


([[1, 1, 1], [0, 0, 1], [1, 1, 1]], 6) 6
([[1, 1, 1], [0, 0, 0], [7, 6, 5]], -1) -1
([[1, 1, 1], [0, 0, 1], [1, 1, 1]], 6) 6
([[1, 1, 1], [0, 0, 1], [1, 1, 1]], 8) 8
