# Solution
* We first need to find the node. If the node doesn't exist, the search ends at an empty subtree and we can safely return the root node unchanged
* Once the node is found we have two cases: it has at most one child (a leaf counts here), or it has two children
* If it has at most one child, this is straightforward: the child (or `None` for a leaf) takes the current spot
  * The BST property still holds because the child's subtree was already on the correct side of the parent, and every subtree is itself a valid BST
* If it has two children, then we need two steps
  * First, find the smallest value in the node's right subtree (its in-order successor)
  * Then replace the current node's value with this minimum value
  * This way everything left in its right subtree is greater than this value and everything in its left subtree is smaller, and all the existing links to parent and children still hold
  * Second, traverse the right subtree and delete the node that held the minimum, which we can do by calling our current function recursively on that subtree
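
The cases above can be traced on a small tree. This is a minimal, self-contained sketch, not the notebook's own cell: `Node`, `insert`, `inorder`, `min_value` and `delete` are hypothetical helper names introduced only for illustration, and the keys are assumed to be distinct.

```python
from typing import List, Optional


class Node:
    """Minimal BST node (hypothetical stand-in for TreeNode)."""
    def __init__(self, val: int):
        self.val = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def insert(root: Optional[Node], val: int) -> Node:
    """Plain BST insert, used only to build the example tree."""
    if root is None:
        return Node(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def inorder(root: Optional[Node]) -> List[int]:
    """In-order traversal; sorted output means the BST property holds."""
    if root is None:
        return []
    return inorder(root.left) + [root.val] + inorder(root.right)


def min_value(node: Node) -> int:
    """Smallest value in a non-empty subtree: follow left links."""
    while node.left:
        node = node.left
    return node.val


def delete(root: Optional[Node], key: int) -> Optional[Node]:
    if root is None:
        return None                      # key not found
    if key < root.val:
        root.left = delete(root.left, key)
    elif key > root.val:
        root.right = delete(root.right, key)
    elif root.left is None:
        return root.right                # zero or one child
    elif root.right is None:
        return root.left                 # exactly one (left) child
    else:
        root.val = min_value(root.right)             # copy the successor's value
        root.right = delete(root.right, root.val)    # remove the successor node
    return root


root = None
for v in [5, 3, 8, 2, 4, 7, 9]:
    root = insert(root, v)

root = delete(root, 5)       # two children: 7 is the smallest value on the right
after_two_children = inorder(root)

root = delete(root, 8)       # one child left (9), which takes its place
after_one_child = inorder(root)

root = delete(root, 42)      # absent key: the tree comes back unchanged
after_missing = inorder(root)

print(after_two_children, after_one_child, after_missing)
```

Checking the in-order traversal after each deletion is a cheap way to confirm that the result is still a valid BST.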

## Time Complexity
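* $O(h)$, where $h$ is the height of the tree. More precisely the work is about $2h$ steps: the first pass finds the value, which could be anywhere from the root to the leaf level, and once it is matched we still have to traverse down to the bottom of its right subtree to find the minimum, which together covers at most the full height $h$. We then traverse down again to delete this minimum value, which can take another $h$. Dropping the constant factor, $O(2h) = O(h)$. For a balanced tree this is $O(\log n)$; for a skewed tree $h$ can be as large as $n$.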
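
To make the two passes concrete, here is a sketch that counts node visits while deleting a root whose successor sits at the bottom of the tree. The `VisitCounter` class and the instrumented `delete` are assumptions added for illustration; here $h$ is taken as the number of nodes on the longest root-to-leaf path.

```python
from typing import Optional


class Node:
    def __init__(self, val: int):
        self.val = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


class VisitCounter:
    """Hypothetical helper: tallies every node the algorithm looks at."""
    def __init__(self) -> None:
        self.visits = 0


def insert(root: Optional[Node], val: int) -> Node:
    if root is None:
        return Node(val)
    if val < root.val:
        root.left = insert(root.left, val)
    else:
        root.right = insert(root.right, val)
    return root


def height(root: Optional[Node]) -> int:
    """Number of nodes on the longest root-to-leaf path."""
    if root is None:
        return 0
    return 1 + max(height(root.left), height(root.right))


def min_value(node: Node, counter: VisitCounter) -> int:
    counter.visits += 1              # the subtree root itself
    while node.left:
        node = node.left
        counter.visits += 1          # each step down the left spine
    return node.val


def delete(root: Optional[Node], key: int, counter: VisitCounter) -> Optional[Node]:
    if root is None:
        return None
    counter.visits += 1              # one visit per non-empty recursive call
    if key < root.val:
        root.left = delete(root.left, key, counter)
    elif key > root.val:
        root.right = delete(root.right, key, counter)
    elif root.left is None:
        return root.right
    elif root.right is None:
        return root.left
    else:
        root.val = min_value(root.right, counter)
        root.right = delete(root.right, root.val, counter)
    return root


# Root 5 has two children; its successor 16 is at the end of a long left spine.
root = None
for v in [5, 1, 20, 19, 18, 17, 16]:
    root = insert(root, v)

h = height(root)
counter = VisitCounter()
root = delete(root, 5, counter)
print(h, counter.visits)
```

In this tree the successor search and the second deletion each walk the same five nodes, so the visit count stays below $2h$, in line with the bound above.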

## Space Complexity
* $O(h)$ for the recursion stack, because the depth of nested calls never exceeds the height of the tree. Finding the minimum is iterative and uses no extra stack.
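
As a rough check of the stack bound, the sketch below records the deepest recursion level reached when deleting the bottom node of a right-skewed tree, where the height equals the number of nodes. The `depth` and `stats` parameters are hypothetical additions used only for measurement.

```python
from typing import Dict, Optional


class Node:
    def __init__(self, val: int):
        self.val = val
        self.left: Optional["Node"] = None
        self.right: Optional["Node"] = None


def min_value(node: Node) -> int:
    # Iterative, so it adds nothing to the recursion depth.
    while node.left:
        node = node.left
    return node.val


def delete(root: Optional[Node], key: int, depth: int, stats: Dict[str, int]) -> Optional[Node]:
    if root is None:
        return None
    stats["max_depth"] = max(stats["max_depth"], depth)   # deepest live call so far
    if key < root.val:
        root.left = delete(root.left, key, depth + 1, stats)
    elif key > root.val:
        root.right = delete(root.right, key, depth + 1, stats)
    elif root.left is None:
        return root.right
    elif root.right is None:
        return root.left
    else:
        root.val = min_value(root.right)
        root.right = delete(root.right, root.val, depth + 1, stats)
    return root


# Build the chain 1 -> 2 -> ... -> 10 using only right links (height = 10).
n = 10
root = Node(1)
tail = root
for v in range(2, n + 1):
    tail.right = Node(v)
    tail = tail.right

stats = {"max_depth": 0}
root = delete(root, n, 1, stats)    # delete the deepest node

# Collect what is left of the chain.
remaining = []
node = root
while node:
    remaining.append(node.val)
    node = node.right

print(stats["max_depth"], remaining)
```

The skewed tree is the worst case for this bound; on a balanced tree the same measurement would grow with $\log n$ instead.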

In [None]:
from typing import Optional

class TreeNode:
    def __init__(self, val=0, left=None, right=None):
        self.val = val
        self.left = left
        self.right = right


class Solution:
    def deleteNode(self, root: Optional[TreeNode], key: int) -> Optional[TreeNode]:
        if not root:
            return None

        if root.val != key:
            if key > root.val:
                root.right = self.deleteNode(root.right, key)
            else:
                root.left = self.deleteNode(root.left, key)
        else:
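            if not root.right:
                # no right child: the left child (or None) takes this spot
                return root.left
            elif not root.left:
                # no left child: the right child takes this spot
                return root.right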
            else:
                # need to replace with the smallest value
                # in the right subtree
                min_val = self.findMinVal(root.right)
                root.val = min_val
                root.right = self.deleteNode(root.right, min_val)
        return root
    
    def findMinVal(self, root: TreeNode) -> int:
        # the minimum of a non-empty BST is its leftmost node
        while root.left:
            root = root.left
        return root.val