In [124]:
"""
https://leetcode.com/problems/design-an-expression-tree-with-evaluate-function/
https://leetcode.ca/2020-05-15-1628-Design-an-Expression-Tree-With-Evaluate-Function/


Given the postfix tokens of an arithmetic expression, build and return the binary expression tree that represents this expression.

Postfix notation is a notation for writing arithmetic expressions in which the operands (numbers) appear before their operators. 
For example, the postfix tokens of the expression 4*(5-(7+2)) are represented in the array postfix = ["4","5","7","2","+","-","*"].

The class Node is an interface you should use to implement the binary expression tree. 
The returned tree will be tested using the evaluate function, which is supposed to evaluate the tree's value. 
You should not remove the Node class; 
however, you can modify it as you wish, and you can define other classes to implement it if needed.

A binary expression tree is a kind of binary tree used to represent arithmetic expressions. 
Each node of a binary expression tree has either zero or two children. 
Leaf nodes (nodes with 0 children) correspond to operands (numbers), 
and internal nodes (nodes with two children) correspond to 
the operators '+' (addition), '-' (subtraction), '*' (multiplication), and '/' (division).

It's guaranteed that no subtree will yield a value that exceeds 109 in absolute value, 
and all the operations are valid (i.e., no division by zero).

Follow up: Could you design the expression tree such that it is more modular? 
For example, is your design able to support additional operators without making changes to your existing evaluate implementation?

 
Constraints:
1 <= s.length < 100
s.length is odd.
s consists of numbers and the characters '+', '-', '*', and '/'.
If s[i] is a number, its integer representation is no more than 105.
It is guaranteed that s is a valid expression.
The absolute value of the result and intermediate values will not exceed 109.
It is guaranteed that no expression will include division by zero.
"""

class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


def readTree(postfix):
    stack = []
    for s in postfix:
        n = Node(s)
        if not s.isdigit():
            n.right = stack.pop()
            n.left = stack.pop()
        stack.append(n)
    return stack[-1]
    

def evalExpTree(root):
    def _visit(n):
        assert(n is not None)
        if n.val.isdigit():
            return int(n.val)
        lval = _visit(n.left)
        rval = _visit(n.right)
        if n.val == "+":
            return lval + rval
        elif n.val == "-":
            return lval - rval
        elif n.val == "*":
            return lval * rval
        elif n.val == "/":
            return lval // rval

    return _visit(root)

tests = [
(["3","4","+","2","*","7","/"], 2),
#           /
#       *       7
#    +    2
#  3  4
# ((3+4)*2)/7) = 14/7 = 2
(["4","5","2","7","+","-","*"], -16)
#            *
#       4        -
#              5   +
#                 2  7
# 4*(5-(2+7)) = 4*(-4) = -16
]

for t in tests:
    retVal = evalExpTree(readTree(t[0]))
    print(t, retVal)
    assert(retVal == t[1])


(['3', '4', '+', '2', '*', '7', '/'], 2) 2
(['4', '5', '2', '7', '+', '-', '*'], -16) -16


In [123]:
"""
https://leetcode.com/problems/add-two-polynomials-represented-as-linked-lists/
https://leetcode.ca/2020-05-21-1634-Add-Two-Polynomials-Represented-as-Linked-Lists/


A polynomial linked list is a special type of linked list where every node represents a term in a polynomial expression.

Each node has three attributes:

coefficient: an integer representing the number multiplier of the term. The coefficient of the term 9x^4 is 9.
power: an integer representing the exponent. The power of the term 9x^4 is 4.
next: a pointer to the next node in the list, or null if it is the last node of the list.
For example, the polynomial 5x^3 + 4x - 7 is represented by the polynomial linked list illustrated below:

[co:5, po:3] -> [co:4, po:1] -> [co:-7, po:0] -> None



The polynomial linked list must be in its standard form: the polynomial must be in strictly descending order by its power value. 
Also, terms with a coefficient of 0 are omitted.

Given two polynomial linked list heads, poly1 and poly2, add the polynomials together and return the head of the sum of the polynomials.

PolyNode format:
The input/output format is as a list of n nodes, where each node is represented as its [coefficient, power]. 
For example, the polynomial 5x3 + 4x - 7 would be represented as: [[5,3],[4,1],[-7,0]].

 
Constraints:
0 <= n <= 10^4
-10^9 <= PolyNode.coefficient <= 10^9
PolyNode.coefficient != 0
0 <= PolyNode.power <= 10^9
PolyNode.power > PolyNode.next.power
"""

class PolyNode:
    def __init__(self, x=0, y=0, next=None):
        self.coefficient = x
        self.power = y
        self.next = next

def readPolyNodes(arr):
    dummy=PolyNode()
    p = dummy
    for [co,po] in arr:
        p.next = PolyNode(x=co,y=po)
        p = p.next
    return dummy.next

def writePolyNodes(root):
    retVal = []
    p = root
    while p is not None:
        retVal.append([p.coefficient, p.power])
        p = p.next
    return retVal

def mergeTwoPolyNodes(n1, n2):
    dummy = PolyNode()
    p = dummy    
    p1 = n1
    p2 = n2

    while p1 or p2:
        # combine p1 and p2
        if p1 and p2 and p1.power == p2.power:
            if p1.coefficient + p2.coefficient != 0:
                p.next = PolyNode(x=p1.coefficient + p2.coefficient,
                                  y=p1.power)
                p = p.next
            p1 = p1.next
            p2 = p2.next
        # take p1
        elif (p1 and p2 and p1.power > p2.power) or not p2:
            p.next = PolyNode(x=p1.coefficient,
                              y=p1.power)
            p1 = p1.next
            p = p.next
            
        # take p2
        else:
            p.next = PolyNode(x=p2.coefficient,
                              y=p2.power)
            p2 = p2.next
            p = p.next

    return dummy.next


tests = [
    ([[1,1]], [[1,0]], [[1,1],[1,0]]),
    ([[2,2],[4,1],[3,0]], [[3,2],[-4,1],[-1,0]], [[5,2],[2,0]]),
    ([[1,2]], [[-1,2]], [])
]
for t in tests:
    n1 = readPolyNodes(t[0])
    n2 = readPolyNodes(t[1])
    retVal = writePolyNodes(mergeTwoPolyNodes(n1, n2))
    print(t, retVal)
    assert(retVal == t[2])


([[1, 1]], [[1, 0]], [[1, 1], [1, 0]]) [[1, 1], [1, 0]]
([[2, 2], [4, 1], [3, 0]], [[3, 2], [-4, 1], [-1, 0]], [[5, 2], [2, 0]]) [[5, 2], [2, 0]]
([[1, 2]], [[-1, 2]], []) []


In [115]:
"""
https://leetcode.com/problems/lowest-common-ancestor-of-a-binary-tree-iv/
https://leetcode.ca/2020-07-02-1676-Lowest-Common-Ancestor-of-a-Binary-Tree-IV/

Given the root of a binary tree and an array of TreeNode objects nodes, 
return the lowest common ancestor (LCA) of all the nodes in nodes. 

All the nodes will exist in the tree, and all values of the tree's nodes are unique.

Extending the definition of LCA on Wikipedia: 
"The lowest common ancestor of n nodes p1, p2, ..., pn in a binary tree T is 
the lowest node that has every pi as a descendant (where we allow a node to be a descendant of itself) for every valid i". 
A descendant of a node x is a node y that is on the path from node x to some leaf node.
 
Constraints:

The number of nodes in the tree is in the range [1, 104].
-109 <= Node.val <= 109
All Node.val are unique.
All nodes[i] will exist in the tree.
All nodes[i] are distinct.
"""

class Node:
    def __init__(self, val, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right
        

def readTree(arr):
    nodeMap={}
    for i in range(len(arr)-1, -1, -1):
        nodeMap[i] = Node(arr[i])
        nodeMap[i].left = nodeMap[i*2+1] if i*2+1 in nodeMap else None
        nodeMap[i].right = nodeMap[i*2+2] if i*2+2 in nodeMap else None
    return nodeMap[0]


def lowestCommonAncestor(root, nodes):
    def _getPath(n, val, paths):
        if n == None:
            return None
        if n.val == val:
            return paths + [val]
        leftPath = _getPath(n.left, val, paths + [n.val])
        rightPath = _getPath(n.right, val, paths + [n.val])
        if leftPath:
            return leftPath
        if rightPath:
            return rightPath

    minPathLen = float('inf')
    paths = []
    for n in nodes:
        paths.append(_getPath(root, n, []))
        if len(paths[-1]) < minPathLen:
            minPathLen = len(paths[-1])
    print(paths)

    for i in range(minPathLen):
        allSame = True
        for p in paths:
            if p[i] != paths[0][i]:
                return paths[0][i-1]
    return paths[0][-1]
        
    

tests = [
    ([3,5,1,6,2,0,8,None,None,7,4], [4,7], 2),
    #Explanation: The lowest common ancestor of nodes 4 and 7 is node 2.
    ([3,5,1,6,2,0,8,None,None,7,4], [1], 1),
    ([3,5,1,6,2,0,8,None,None,7,4], [7,6,2,4], 5)
]
#         3
#    5        1
#  6   2    0   8
#     7 4
for t in tests:
    retVal = lowestCommonAncestor(readTree(t[0]), t[1])
    print(t, retVal)
    assert(retVal == t[2])


[[3, 5, 2, 4], [3, 5, 2, 7]]
([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4], [4, 7], 2) 2
[[3, 1]]
([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4], [1], 1) 1
[[3, 5, 2, 7], [3, 5, 6], [3, 5, 2], [3, 5, 2, 4]]
([3, 5, 1, 6, 2, 0, 8, None, None, 7, 4], [7, 6, 2, 4], 5) 5


In [109]:
"""
https://leetcode.com/problems/maximum-units-on-a-truck/

You are assigned to put some amount of boxes onto one truck. You are given a 2D array boxTypes, 
where boxTypes[i] = [numberOfBoxesi, numberOfUnitsPerBoxi]:

numberOfBoxesi is the number of boxes of type i.
numberOfUnitsPerBoxi is the number of units in each box of the type i.
You are also given an integer truckSize, which is the maximum number of boxes that can be put on the truck. 
You can choose any boxes to put on the truck as long as the number of boxes does not exceed truckSize.

Return the maximum total number of units that can be put on the truck.

 

Constraints:

1 <= boxTypes.length <= 1000
1 <= numberOfBoxesi, numberOfUnitsPerBoxi <= 1000
1 <= truckSize <= 10^6
"""

def maxUnitsOnTruck(boxTypes, truckSize):
    """ priotize boxes with more units """
    import heapq
    maxheap = []
    for numBox, numUnitsPerBox in boxTypes:
        heapq.heappush(maxheap, (-numUnitsPerBox, (numBox, numUnitsPerBox)))

    numUnits = 0
    numLoadedBoxes = 0
    while numLoadedBoxes < truckSize and len(maxheap):
        _, (numBoxes, numUnitsPerBox) = heapq.heappop(maxheap)
        numBoxesToLoad = min( numBoxes, truckSize - numLoadedBoxes )
        numUnits += numBoxesToLoad * numUnitsPerBox
        numLoadedBoxes += numBoxesToLoad

    return numUnits

tests = [
    ([[1,3],[2,2],[3,1]], 4, 8),
    # - 1 box of the first type that contains 3 units.
    # - 2 boxes of the second type that contain 2 units each.
    # - 3 boxes of the third type that contain 1 unit each.
    # You can take all the boxes of the first and second types, and one box of the third type.
    # The total number of units will be = (1 * 3) + (2 * 2) + (1 * 1) = 8.
    ([[5,10],[2,5],[4,7],[3,9]], 10, 91)
]
for t in tests:
    retVal = maxUnitsOnTruck(t[0], t[1])
    print(t, retVal)
    assert(retVal == t[2])


([[1, 3], [2, 2], [3, 1]], 4, 8) 8
([[5, 10], [2, 5], [4, 7], [3, 9]], 10, 91) 91


In [107]:
"""
https://leetcode.com/problems/maximum-number-of-events-that-can-be-attended-ii/


You are given an array of events where events[i] = [startDayi, endDayi, valuei].
The ith event starts at startDayi and ends at endDayi, and if you attend this event, you will receive a value of valuei.
You are also given an integer k which represents the maximum number of events you can attend.

You can only attend one event at a time.
If you choose to attend an event, you must attend the entire event.
Note that the end day is inclusive: that is, you cannot attend two events where one of them starts and the other ends on the same day.

Return the maximum sum of values that you can receive by attending events. 
"""


def getMaxValue(events, k):
    """
    generate all possible combinations
    """
    maxProfit = 0
    _events = sorted(events)
    profit = 0
    attended = set()
    def _attend(eIdx, startTime):
        nonlocal profit, maxProfit
        if eIdx >= len(_events):
            return
        if len(attended) == k:
            return
        # attend eIdx
        if startTime <= _events[eIdx][0]:
            attended.add(eIdx)
            profit += _events[eIdx][2]
            if profit > maxProfit:
                maxProfit = profit
            _attend(eIdx+1, _events[eIdx][1]+1)
            attended.remove(eIdx)
            profit -= _events[eIdx][2]
        # do not attend eIdx
        _attend(eIdx+1, startTime)
    
    _attend(0, _events[0][0])
    return maxProfit


tests = [
([[1,2,4],[3,4,3],[2,3,1]], 2, 7),
([[1,2,4],[3,4,3],[2,3,10]], 2, 10),
([[1,1,1],[2,2,2],[3,3,3],[4,4,4]], 3, 9)
]
for t in tests:
    retVal = getMaxValue(t[0], t[1])
    print(t, retVal)
    assert(retVal == t[2])

([[1, 2, 4], [3, 4, 3], [2, 3, 1]], 2, 7) 7
([[1, 2, 4], [3, 4, 3], [2, 3, 10]], 2, 10) 10
([[1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]], 3, 9) 9


In [101]:
"""
https://leetcode.com/problems/design-a-text-editor/

Design a text editor with a cursor that can do the following:

Add text to where the cursor is.
Delete text from where the cursor is (simulating the backspace key).
Move the cursor either left or right.
When deleting text, only characters to the left of the cursor will be deleted. 
The cursor will also remain within the actual text and cannot be moved beyond it. 
More formally, we have that 0 <= cursor.position <= currentText.length always holds.

Implement the TextEditor class:
TextEditor() Initializes the object with empty text.
void addText(string text) Appends text to where the cursor is. The cursor ends to the right of text.
int deleteText(int k) Deletes k characters to the left of the cursor. 
Returns the number of characters actually deleted.
string cursorLeft(int k) Moves the cursor to the left k times. 
Returns the last min(10, len) characters to the left of the cursor, where len is the number of characters to the left of the cursor.
string cursorRight(int k) Moves the cursor to the right k times. 
Returns the last min(10, len) characters to the left of the cursor, where len is the number of characters to the left of the cursor.
 

Constraints:
1 <= text.length, k <= 40
text consists of lowercase English letters.
At most 2 * 10^4 calls in total will be made to addText, deleteText, cursorLeft and cursorRight.
 

Follow-up: Could you find a solution with time complexity of O(k) per call?

deque()

               cursorP
                |
                V
H - a - b - c - | - d - e - f - g - E

H, E: sentinel nodes
cursorP

"""

class Node:
    def __init__(self, val, prev=None, next=None):
        self.val = val
        self.prev = prev
        self.next = next

class TextEditor:
    def __init__(self):
        self.head = Node(-1)
        self.tail = Node(-1)
        self.head.next = self.tail
        self.tail.prev = self.head
        self.p = self.tail
        # self.print()
    
    def print(self):
        p = self.head
        while p:
            print(p.val, "Pointer" if p == self.p else "")
            p = p.next
        print()

    def addText(self, text: str) -> None:
        prev = self.p.prev
        for c in text:
            prev.next = Node(c, prev=prev)
            prev = prev.next
        prev.next = self.p
        self.p.prev = prev 
        # self.print()
        # print(f"addText({text})")


    def deleteText(self, k: int) -> int:
        # self.print()
        retVal = 0
        for i in range(k):
            # head -> a b c -> tail
            #                  p
            if self.p.prev == self.head:
                break
            toDel = self.p.prev
            pprev = toDel.prev
            pprev.next = self.p
            self.p.prev = pprev
            retVal += 1
        # self.print()
        # print(f"deleteText({k}): {retVal}")
        return retVal

    def _getLast10Chars(self):
        if self.p == self.head:
            return ""
        retVal = []
        p = self.p
        for _ in range(10):
            p = p.prev
            if p == self.head:
                break
            retVal.append(p.val)
        return "".join(reversed(retVal))
    
    def cursorLeft(self, k: int) -> str:
        retVal = 0
        for i in range(k):
            if self.p.prev == self.head:
                break
            self.p = self.p.prev
            retVal += 1
        # self.print()
        retVal = self._getLast10Chars()
        # print(f"cursorLeft({k}): {retVal}")
        return retVal

    def cursorRight(self, k: int) -> str:
        for i in range(k):
            if self.p is self.tail:
                break
            self.p = self.p.next
        # self.print()
        retVal = self._getLast10Chars()
        # print(f"cursorRight({k}): {retVal}")
        return retVal


#Input
#["TextEditor", "addText", "deleteText", "addText", "cursorRight", "cursorLeft", "deleteText", "cursorLeft", "cursorRight"]
#[[], ["leetcode"], [4], ["practice"], [3], [8], [10], [2], [6]]
#Output
#[null, null, 4, null, "etpractice", "leet", 4, "", "practi"]

textEditor = TextEditor()      #; // The current text is "|". (The '|' character represents the cursor)
textEditor.addText("leetcode") #; // The current text is "leetcode|".
assert(textEditor.deleteText(4) == 4)       #; // return 4
                               # // The current text is "leet|". 
                               # // 4 characters were deleted.
textEditor.addText("practice") #; // The current text is "leetpractice|". 
assert(textEditor.cursorRight(3) == "etpractice")     #; // return "etpractice"
                               #// The current text is "leetpractice|". 
                               #// The cursor cannot be moved beyond the actual text and thus did not move.
                               #// "etpractice" is the last 10 characters to the left of the cursor.
assert(textEditor.cursorLeft(8) == "leet")       #; // return "leet"
                               #// The current text is "leet|practice".
                               #// "leet" is the last min(10, 4) = 4 characters to the left of the cursor.
assert(textEditor.deleteText(10) == 4)      #; // return 4
                               #// The current text is "|practice".
                               #// Only 4 characters were deleted.
assert(textEditor.cursorLeft(2) == "")       #; // return ""
                               #// The current text is "|practice".
                               #// The cursor cannot be moved beyond the actual text and thus did not move. 
                               #// "" is the last min(10, 0) = 0 characters to the left of the cursor.
assert(textEditor.cursorRight(6) == "practi")      #; // return "practi"
                               #// The current text is "practi|ce".
                               #// "practi" is the last min(10, 6) = 6 characters to the left of the cursor.


In [71]:
"""
https://leetcode.com/problems/median-of-two-sorted-arrays/

Given two sorted arrays nums1 and nums2 of size m and n respectively, return the median of the two sorted arrays.

The overall run time complexity should be O(log (m+n)).

Constraints:

nums1.length == m
nums2.length == n
0 <= m <= 1000
0 <= n <= 1000
1 <= m + n <= 2000
-10^6 <= nums1[i], nums2[i] <= 10^6
"""

def findMedianSortedArrays(nums1, nums2):
    m = len(nums1)
    n = len(nums2)

    def findKth(i, j, k):
        if i >= m:
            return nums2[j + k - 1]
        if j >= n:
            return nums1[i + k - 1]
        if k == 1:
            return min(nums1[i], nums2[j])

        midVal1 = nums1[i + k//2 - 1] if i + k//2 - 1 < m else float('inf')
        midVal2 = nums2[j + k//2 - 1] if j + k//2 - 1 < n else float('inf')

        if midVal1 < midVal2:
            return findKth(i + k//2, j,        k - k//2) # '+' or '-' k//2
        else:
            return findKth(i,        j + k//2, k - k//2)

    leftIdx = (m + n + 1) // 2 
    rightIdx = (m + n + 2) // 2
    return (findKth(0, 0, leftIdx) + findKth(0, 0, rightIdx)) / 2.0

tests = [
    ([1,3], [2], 2.00000),
    ([1,2], [3,4], 2.50000)
]
for t in tests:
    retVal = findMedianSortedArrays(t[0], t[1])
    print(t, retVal)
    assert(retVal - t[2] < pow(1,-5))

([1, 3], [2], 2.0) 2.0
([1, 2], [3, 4], 2.5) 2.5


In [66]:
"""
https://leetcode.com/problems/word-search-ii/


Given an m x n board of characters and a list of strings words, return all words on the board.

Each word must be constructed from letters of sequentially adjacent cells, 
where adjacent cells are horizontally or vertically neighboring. 
The same letter cell may not be used more than once in a word.


Constraints:

m == board.length
n == board[i].length
1 <= m, n <= 12
board[i][j] is a lowercase English letter.
1 <= words.length <= 3 * 10^4
1 <= words[i].length <= 10
words[i] consists of lowercase English letters.
All the strings of words are unique.
"""


def wordSearchII(grid, words):
    M = len(grid)
    N = len(grid[0])

    def _wordExists(w):
        visited=set()
        def _visit(i,j,wIdx):
            if grid[i][j] != w[wIdx]:
                return False
            visited.add((i,j))
            if len(visited) == len(w):
                return True
            for di,dj in [ (+1,0), (0,+1), (-1,0), (0,-1) ]:
                if not (0<=i+di<M and 0<=j+dj<N):
                    continue
                if (i+di,j+dj) in visited:
                    continue
                if _visit(i+di, j+dj, wIdx+1):
                    return True
            visited.remove((i,j))
            return False
            
        for i in range(M):
            for j in range(N):
                if _visit(i,j,0):
                    return True
        return False

    retVal = []
    for w in words:
        if _wordExists(w):
            retVal.append(w)
    return retVal

tests = [
    ([["o","a","a","n"],
      ["e","t","a","e"],
      ["i","h","k","r"],
      ["i","f","l","v"]], 
      ["oath","pea","eat","rain"],
      ["eat","oath"]),
    ([["a","b"],
     ["c","d"]],
     ["abcb"],
     [])
]

for t in tests:
    retVal = wordSearchII(t[0], t[1])
    print(t, retVal)
    assert(sorted(retVal) == sorted(t[2]))


([['o', 'a', 'a', 'n'], ['e', 't', 'a', 'e'], ['i', 'h', 'k', 'r'], ['i', 'f', 'l', 'v']], ['oath', 'pea', 'eat', 'rain'], ['eat', 'oath']) ['oath', 'eat']
([['a', 'b'], ['c', 'd']], ['abcb'], []) []


In [60]:
"""
https://leetcode.com/problems/cut-off-trees-for-golf-event/

You are asked to cut off all the trees in a forest for a golf event. The forest is represented as an m x n matrix. In this matrix:

0 means the cell cannot be walked through.
1 represents an empty cell that can be walked through.
A number greater than 1 represents a tree in a cell that can be walked through, and this number is the tree's height.
In one step, you can walk in any of the four directions: north, east, south, and west. 
If you are standing in a cell with a tree, you can choose whether to cut it off.

You must cut off the trees in order from shortest to tallest. When you cut off a tree, the value at its cell becomes 1 (an empty cell).

Starting from the point (0, 0), return the minimum steps you need to walk to cut off all the trees. 
If you cannot cut off all the trees, return -1.

Note: The input is generated such that no two trees have the same height, and there is at least one tree needs to be cut off.


Constraints:

m == forest.length
n == forest[i].length
1 <= m, n <= 50
0 <= forest[i][j] <= 10^9
Heights of all trees are distinct.
"""

from collections import deque
def cutTrees(grid):
    if grid == None or len(grid)==0 or len(grid[0])==0:
        return 0
    M = len(grid)
    N = len(grid[0])

    # sort tree heights by height
    # 2->3->4->5->6->7->8
    #  numbers are unique
    #  at least one tree to cut

    treeLocs = {} # height -> (i,j); there is a guarantee that no two trees have the same height
    treeHeights = []
    for i in range(M):
        for j in range(N):
            if grid[i][j] == 0:
                continue
            treeHeights.append(grid[i][j])
            treeLocs[ grid[i][j] ] = (i, j)
    treeHeights.sort()
    numTrees = len(treeHeights)

    # print(f"{numTrees=}")
    # print(f"{treeHeights=}")
    # print(f"{treeLocs=}")

    def _closestPath( si, sj, ei, ej ):
        """ BFS """
        if si==ei and sj==ej:
            return 0
        visited = set()
        toVisit = [ (si, sj, []) ]
        while len(toVisit):
            toVisitNext = []
            for i,j,hist in toVisit:
                for di, dj in [ (0,-1), (-1,0), (0,+1), (+1,0) ]:
                    if not (0<=i+di<M and 0<=j+dj<N) or grid[i+di][j+dj] < 1 or (i+di, j+dj) in visited:
                        continue
                    if i+di == ei and j+dj == ej:
                        return len(hist) + 1
                    toVisitNext.append( (i+di, j+dj, hist+[(i+di,j+dj)]) )
                    visited.add( (i+di, j+dj) )
            toVisit = toVisitNext
        return -1

    retVal = 0
    si, sj = 0, 0
    for height in treeHeights:
        ei, ej = treeLocs[height]
        path = _closestPath( si, sj, ei, ej )
        # print(f"({si},{sj}) -> ({ei},{ej}): {path}") 
        if path < 0:
            return -1
        retVal += path
        grid[ei][ej] = 1
        si, sj = ei, ej

    return retVal

tests = [
    ([[1,2,3],
      [0,0,4],
      [7,6,5]], 6),
    #Following the path above allows you to cut off the trees from shortest to tallest in 6 steps.
    ([[1,2,3],
      [0,0,0],
      [7,6,5]], -1),
    # The trees in the bottom row cannot be accessed as the middle row is blocked.
    ([[2,3,4],
      [0,0,5],
      [8,7,6]], 6),
    ([[3,2,4],
      [0,0,5],
      [8,7,6]], 8),  # we need to cut 2 first, then 3, then 4 ...
]
for t in tests:
    retVal = cutTrees(t[0])
    print(t, retVal)
    assert(retVal == t[1])


([[1, 1, 1], [0, 0, 1], [1, 1, 1]], 6) 6
([[1, 1, 1], [0, 0, 0], [7, 6, 5]], -1) -1
([[1, 1, 1], [0, 0, 1], [1, 1, 1]], 6) 6
([[1, 1, 1], [0, 0, 1], [1, 1, 1]], 8) 8


In [50]:
"""
https://leetcode.com/problems/basic-calculator-iv/

Given an expression such as expression = "e + 8 - a + 5" and an evaluation map such as {"e": 1} (given in terms of evalvars = ["e"] and evalints = [1]), 
return a list of tokens representing the simplified expression, such as ["-1*a","14"]

An expression alternates chunks and symbols, with a space separating each chunk and symbol.
A chunk is either an expression in parentheses, a variable, or a non-negative integer.
A variable is a string of lowercase letters (not including digits.) Note that variables can be multiple letters, and note that variables never have a leading coefficient or unary operator like "2x" or "-x".
Expressions are evaluated in the usual order: brackets first, then multiplication, then addition and subtraction.

For example, expression = "1 + 2 * 3" has an answer of ["7"].
The format of the output is as follows:

For each term of free variables with a non-zero coefficient, we write the free variables within a term in sorted order lexicographically.
For example, we would never write a term like "b*a*c", only "a*b*c".
Terms have degrees equal to the number of free variables being multiplied, counting multiplicity. We write the largest degree terms of our answer first, breaking ties by lexicographic order ignoring the leading coefficient of the term.
For example, "a*a*b*c" has degree 4.
The leading coefficient of the term is placed directly to the left with an asterisk separating it from the variables (if they exist.) A leading coefficient of 1 is still printed.
An example of a well-formatted answer is ["-2*a*a*a", "3*a*a*b", "3*b*b", "4*a", "5*c", "-6"].
Terms (including constant terms) with coefficient 0 are not included.
For example, an expression of "0" has an output of [].
Note: You may assume that the given expression is always valid. All intermediate results will be in the range of [-231, 231 - 1].


Constraints:
1 <= expression.length <= 250
expression consists of lowercase English letters, digits, '+', '-', '*', '(', ')', ' '.
expression does not contain any leading or trailing spaces.
All the tokens in expression are separated by a single space.
0 <= evalvars.length <= 100
1 <= evalvars[i].length <= 20
evalvars[i] consists of lowercase English letters.
evalints.length == evalvars.length
-100 <= evalints[i] <= 100
"""

from collections import defaultdict
def calc(s, evalvars, evalints):
    def _isNum(t):
        for c in t:
            if not c.isdigit():
                return False
        return True

    def _multNums(nums):
        retVal = 1
        for n in nums:
            retVal *= n
        return retVal

    def _getDegrees(s):
        numMult = 0
        for c in s:
            if c == "*":
                numMult
        return numMult + 1

    def _addTerms(terms):
        termMap = {} # defaultdict(int)
        termMap["NUM"] =  [1, 0]

        for t in terms:
            if len(t["VAR"]) + len(t["NUM"]) == 0:
                continue
            sign = t["SIGN"]
            val =  _multNums(t["NUM"])
            t["VAR"].sort()
            if len(t["VAR"])>0:
                key = "*".join(t["VAR"])
            else:
                key = "NUM"
            if key in termMap:
                oldsign, oldval = termMap[key]
                newval = oldsign * oldval + sign * val
                sign = 1 if newval > 0 else -1
                val = newval * sign
            termMap[key] = (sign, val)

        keys = sorted([(-_getDegrees(k), k) for k,_ in termMap.items() if k != "NUM" ])

        varTerms = [ {"VAR": [k], "NUM": [termMap[k][1]], "SIGN": termMap[k][0],} for _, k in keys if termMap[k][1] != 0]
        numTerms = [ {"VAR": [], "NUM": [termMap["NUM"][1]], "SIGN": termMap["NUM"][0],} ]
        return varTerms + numTerms

    def _multTerms(terms1, terms2):
        """
         terms1: a b c
         terms2: 1 2 3
            1 * a + 2 * a + 3 * a + 
            1 * b + 2 * b + 3 * b + 
            1 * c + 2 * c + 3 * c
        """
        allterms = []

        # print(f"{terms1=}")
        # print(f"{terms2=}")        
        for t1 in terms1:
            for t2 in terms2:
                #4 cases: var * var, var * num, num * var, num * num
                allvars = t1["VAR"] + t2["VAR"]
                allnums = [_multNums(t1["NUM"]) * _multNums(t2["NUM"])]
                sign = t1["SIGN"] * t2["SIGN"]
                if allnums[0] == 0:
                    continue
                allterms.append({"VAR": allvars, "NUM": allnums, "SIGN": sign,})

                # print("\t", t1)
                # print("\t", t2)
                # print("\t", allterms)
    
        return _addTerms(allterms)

    def _readOut(terms):
        """ assume sorted in the right order """

        retVal = []
        for t in terms:
            if len(t["VAR"]):
                retVal.append( str(t["SIGN"] * _multNums(t["NUM"])) + "*" + "*".join(t["VAR"]) )
            else:
                retVal.append( str(t["SIGN"] * _multNums(t["NUM"])) )
        return retVal

    varMap = {}
    for var, val in zip(evalvars, evalints):
        varMap[var] = val

    stack = [] #  (prevTerms, prevOp)
    terms = []
    cache = {"VAR": [], "NUM": [], "SIGN": 1,}
    prevOp = ""
    for t in s.replace("(", " ( ").replace(")", " ) ").split():
        if t == "+" or t == "-":
            # push cache to terms
            terms.append(cache)
            cache = {"VAR": [], "NUM": [], "SIGN": 1,}
            prevOp = t
        elif t == "*":
            prevOp = t
        elif t == "(":
            # use stack to save/load state
            # print("Saving ")
            # print(f"\t{terms=}")
            # print(f"\t{cache=}")
            # print(f"\t{prevOp=}")
            stack.append(
                ( _addTerms(terms), cache, prevOp )
            )
            terms = []
            cache = {"VAR": [], "NUM": [], "SIGN": 1,}
            prevOp = ""

        elif t == ")":

            if len(cache["VAR"]) + len(cache["NUM"]) > 0:
                terms.append(cache)
            curTerms = _addTerms(terms)

            # print("CurTerms ")
            # print(f"\t{curTerms=}")            
            # pop prevState
            # prevState prevOp curState -> add back to terms
            prevTerms, prevCache, op = stack.pop()
            # print("prevTerms ")
            # print(f"\t{prevTerms=}")
            # print(f"\t{prevCache=}")
            # print(f"\t{op=}")            
            if len(prevCache["VAR"]) + len(prevCache["NUM"]) > 0 and op == "*":
                # print("PrevTerms + prevCache * curTerms")
                terms = _addTerms(  prevTerms,
                                    _multTerms([prevCache], curTerms) )
                cache = {"VAR": [], "NUM": [], "SIGN": 1,}
                prevOp = ""            

            elif op == "*":
                # print("PrevTerms * curTerms")
                terms = _multTerms(prevTerms, curTerms)
                cache = {"VAR": [], "NUM": [], "SIGN": 1,}
                prevOp = ""            
            else:
                # print("PrevTerms + prevCache + curTerms")
                if len(prevCache["VAR"]) + len(prevCache["NUM"]) > 0:
                    prevTerms.append(prevCache)

                if op == "-":
                    for t in curTerms:
                        t["SIGN"] *= -1

                terms = _addTerms(prevTerms + curTerms)
                cache = {"VAR": [], "NUM": [], "SIGN": 1,}
                prevOp = ""            
        
        elif _isNum(t):
            _num = int(t)
            cache["NUM"].append(_num)
            if prevOp == "-":
                cache["SIGN"] *= -1
            prevOp = ""

        elif t in varMap:
            _num = varMap[t]
            if prevOp == "-":
                _num *= -1
            cache["NUM"].append(_num)
            prevOp = ""

        else:
            cache["VAR"].append(t)
            if prevOp == "-":
                cache["SIGN"] *= -1
            prevOp = ""

    # push cache to terms
    if len(cache["VAR"]) + len(cache["NUM"]) > 0:
        terms.append(cache)
    # for t in terms:
    #     print(t)
    # print(_addTerms(terms))


    return _readOut(_addTerms(terms))


tests = [
(                 "e + 8 - a + 5",                ["e"],     [1],        ["-1*a",  "14"]),
(                 "e * 8 - a + 5",                ["e"],     [1],        ["-1*a",  "13"]),
(                 "e + 8 - a - 5",                ["e"],     [1],        ["-1*a",  "4"]),
(                 "e * 8 - a - 5",                ["e"],     [1],        ["-1*a",  "3"]),
("e - 8 + temperature - pressure", ["e", "temperature"], [1, 12], ["-1*pressure",   "5"]),
(             "(e + 8) * (e - 8)",                   [],      [],       ["1*e*e", "-64"])
]
for t in tests:
    retVal = calc(t[0], t[1], t[2])
    print(t, retVal)
    assert(retVal == t[3])

('e + 8 - a + 5', ['e'], [1], ['-1*a', '14']) ['-1*a', '14']
('e * 8 - a + 5', ['e'], [1], ['-1*a', '13']) ['-1*a', '13']
('e + 8 - a - 5', ['e'], [1], ['-1*a', '4']) ['-1*a', '4']
('e * 8 - a - 5', ['e'], [1], ['-1*a', '3']) ['-1*a', '3']
('e - 8 + temperature - pressure', ['e', 'temperature'], [1, 12], ['-1*pressure', '5']) ['-1*pressure', '5']
('(e + 8) * (e - 8)', [], [], ['1*e*e', '-64']) ['1*e*e', '-64']
