Given a binary search tree, write a function kthSmallest to find the kth smallest element in it.

Note:
You may assume k is always valid: 1 ≤ k ≤ the number of nodes in the BST.

Example 1:
```
Input: root = [3,1,4,null,2], k = 1
   3
  / \
 1   4
  \
   2
Output: 1
```
Example 2:
```
Input: root = [5,3,6,2,4,null,null,1], k = 3
       5
      / \
     3   6
    / \
   2   4
  /
 1
Output: 3
```
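The examples write the tree as a level-order array, with `null` marking a missing child. This is the usual LeetCode convention, assumed here rather than stated by the notebook. The sketch below decodes such an array into a tree. `LevelOrderBuilder` and `fromLevelOrder` are hypothetical names. The nested `TreeNode` mirrors the notebook's class only so the snippet compiles on its own; inside the notebook you would build the top-level `TreeNode` instead.

```java
import java.util.ArrayDeque;
import java.util.Queue;

// Hypothetical helper: decodes a level-order array such as [3,1,4,null,2]
// (null = missing child) into a binary tree.
class LevelOrderBuilder {
    // Mirrors the notebook's TreeNode so this snippet is self-contained.
    static class TreeNode {
        int val;
        TreeNode left, right;
        TreeNode(int val) { this.val = val; }
    }

    static TreeNode fromLevelOrder(Integer[] values) {
        if (values.length == 0 || values[0] == null) return null;
        TreeNode root = new TreeNode(values[0]);
        Queue<TreeNode> queue = new ArrayDeque<>();  // nodes still waiting for children
        queue.add(root);
        int i = 1;
        while (!queue.isEmpty() && i < values.length) {
            TreeNode parent = queue.poll();
            // The next entry is the parent's left child...
            if (values[i] != null) {
                parent.left = new TreeNode(values[i]);
                queue.add(parent.left);
            }
            i++;
            // ...and the one after it is the right child.
            if (i < values.length && values[i] != null) {
                parent.right = new TreeNode(values[i]);
                queue.add(parent.right);
            }
            i++;
        }
        return root;
    }
}
```

Children are only read for nodes that actually exist, which is why the `null` entries under a missing node are not written out in the examples.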
Follow-up:
What if the BST is modified (insert/delete operations) often and you need to find the kth smallest frequently? How would you optimize the kthSmallest routine?
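The notebook does not implement the follow-up. One common answer is to store in every node the size of its subtree, an "order-statistic tree". With a node's left subtree holding `L` nodes, the kth smallest value is in the left subtree if `k <= L`, is the node itself if `k == L + 1`, and otherwise is the `(k - L - 1)`-th smallest in the right subtree. That makes `kthSmallest` O(H), and insert/delete only need to refresh sizes along the path they touch. The following is a minimal sketch under stated assumptions. `OrderStatisticBst` is a hypothetical name. The tree is not self-balancing, so H can degrade to N; a balanced variant such as a red-black or AVL tree would keep H at O(log N). Duplicates are sent to the right.

```java
// Hypothetical sketch of an unbalanced BST augmented with subtree sizes.
class OrderStatisticBst {
    private static class Node {
        int val;
        int size = 1;          // number of nodes in this subtree
        Node left, right;
        Node(int val) { this.val = val; }
    }

    private Node root;

    private static int size(Node n) { return n == null ? 0 : n.size; }

    public void insert(int val) { root = insert(root, val); }

    private Node insert(Node node, int val) {
        if (node == null) return new Node(val);
        if (val < node.val) node.left = insert(node.left, val);
        else node.right = insert(node.right, val);   // duplicates go right (assumption)
        node.size = 1 + size(node.left) + size(node.right);
        return node;
    }

    public void delete(int val) { root = delete(root, val); }

    private Node delete(Node node, int val) {
        if (node == null) return null;               // value not present
        if (val < node.val) node.left = delete(node.left, val);
        else if (val > node.val) node.right = delete(node.right, val);
        else {
            if (node.left == null) return node.right;
            if (node.right == null) return node.left;
            // Two children: copy the in-order successor's value, then delete the successor.
            Node succ = node.right;
            while (succ.left != null) succ = succ.left;
            node.val = succ.val;
            node.right = delete(node.right, succ.val);
        }
        node.size = 1 + size(node.left) + size(node.right);
        return node;
    }

    // k is 1-based; walks a single root-to-node path, so it runs in O(H).
    public int kthSmallest(int k) {
        Node node = root;
        while (node != null) {
            int leftSize = size(node.left);
            if (k <= leftSize) {
                node = node.left;
            } else if (k == leftSize + 1) {
                return node.val;
            } else {
                k -= leftSize + 1;               // skip the left subtree and this node
                node = node.right;
            }
        }
        throw new IllegalArgumentException("k is out of range");
    }
}
```

The trade-off is one extra `int` per node and a size update on every insert and delete, in exchange for answering each query without re-traversing the tree.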

```java
// Binary tree node used by all solutions below.
public class TreeNode {
    int val;
    TreeNode left;
    TreeNode right;
    TreeNode() {}
    TreeNode(int val) { this.val = val; }
    TreeNode(int val, TreeNode left, TreeNode right) {
        this.val = val;
        this.left = left;
        this.right = right;
    }
}
```

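Iterative in-order traversal with an explicit stack. Time complexity $O(H + k)$, where $H$ is the tree height: each node is pushed and popped at most once, and the traversal stops at the kth pop after first descending at most $H$ nodes to reach the minimum. Extra space is $O(H)$ for the stack.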

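```java
import java.util.ArrayDeque;
import java.util.Deque;

public class Solution1 {
    // In-order traversal visits BST values in ascending order,
    // so the kth node popped is the kth smallest.
    public int kthSmallest(TreeNode root, int k) {
        Deque<TreeNode> stack = new ArrayDeque<>();
        TreeNode cur = root;
        while (cur != null || !stack.isEmpty()) {
            // Descend as far left as possible, remembering the path.
            while (cur != null) {
                stack.push(cur);
                cur = cur.left;
            }
            cur = stack.pop();          // next smallest value
            if (--k == 0) return cur.val;
            cur = cur.right;            // continue with the right subtree
        }
        throw new IllegalArgumentException("k exceeds the number of nodes");
    }
}
```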

```java
// Example 1: [3,1,4,null,2]
TreeNode root1 = new TreeNode(3);
root1.left = new TreeNode(1);
root1.left.right = new TreeNode(2);
root1.right = new TreeNode(4);
```

```java
new Solution1().kthSmallest(root1, 1); // returns 1
```

```java
// Example 2: [5,3,6,2,4,null,null,1]
TreeNode root2 = new TreeNode(5);
root2.left = new TreeNode(3);
root2.left.right = new TreeNode(4);
root2.left.left = new TreeNode(2);
root2.left.left.left = new TreeNode(1);
root2.right = new TreeNode(6);
```

```java
new Solution1().kthSmallest(root2, 3); // returns 3
```

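Recursive in-order traversal that records every value in a list and then returns entry `k - 1`. Time complexity $O(N)$, where $N$ is the number of nodes, because the whole tree is traversed even when k is small; extra space is also $O(N)$.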

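```java
import java.util.ArrayList;
import java.util.List;

public class Solution2 {
    public int kthSmallest(TreeNode root, int k) {
        List<Integer> values = new ArrayList<>();  // local, so repeated calls start fresh
        inorder(root, values);
        return values.get(k - 1);                  // in-order list is sorted ascending
    }

    // Recursive in-order traversal: left subtree, node, right subtree.
    private void inorder(TreeNode node, List<Integer> values) {
        if (node == null) return;
        inorder(node.left, values);
        values.add(node.val);
        inorder(node.right, values);
    }
}
```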

```java
new Solution2().kthSmallest(root1, 1); // returns 1
```

```java
new Solution2().kthSmallest(root2, 3); // returns 3
```

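Recursive in-order traversal with a counter that stops as soon as the kth node is visited. Time complexity $O(H + k)$, where $H$ is the tree height; the recursion stack uses $O(H)$ space.

```java
public class Solution3 {
    private int count;   // nodes visited so far, in in-order
    private int result;  // value of the kth node once found

    public int kthSmallest(TreeNode root, int k) {
        count = 0;       // reset so the same instance can be reused
        inorder(root, k);
        return result;
    }

    // Returns true once the kth node has been found, so callers stop early
    // instead of also traversing the remaining right subtrees.
    private boolean inorder(TreeNode node, int k) {
        if (node == null) return false;
        if (inorder(node.left, k)) return true;
        if (++count == k) {
            result = node.val;
            return true;
        }
        return inorder(node.right, k);
    }
}
```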

```java
new Solution3().kthSmallest(root1, 1); // returns 1
```

```java
new Solution3().kthSmallest(root2, 3); // returns 3
```