Given a complete binary tree, count the number of nodes.

__Note:__

Definition of a complete binary tree from Wikipedia:
In a complete binary tree every level, except possibly the last, is completely filled, and all nodes in the last level are as far left as possible. With the root at level 0, the last level $h$ can hold between $1$ and $2^h$ nodes inclusive.

__Example:__
```
Input: 
    1
   / \
  2   3
 / \  /
4  5 6

Output: 6
```

In [7]:
public class TreeNode {
    public int val;
    public TreeNode left, right;

    /** Creates a node holding {@code val} with no children. */
    public TreeNode(int val) {
        this.val = val;
    }

    /**
     * Serializes this subtree as a sequence of colon-terminated tokens in depth-first order.
     *
     * <p>Every node emits its value. A node with at least one child then emits its right
     * subtree before its left subtree, and a missing child is written as {@code #}. A leaf
     * emits only its value, without markers for its two absent children.
     *
     * @return the serialized subtree, e.g. {@code "1:3:#:6:2:5:4:"}
     */
    @Override
    public String toString() {
        StringBuilder out = new StringBuilder();
        Stack<TreeNode> pending = new Stack<>();
        pending.push(this);
        while (!pending.isEmpty()) {
            TreeNode node = pending.pop();
            if (node == null) {
                out.append("#").append(":");
            } else {
                out.append(node.val).append(":");
                boolean isLeaf = node.left == null && node.right == null;
                if (!isLeaf) {
                    // Left goes on first, so the right child is popped (emitted) first.
                    pending.push(node.left);
                    pending.push(node.right);
                }
            }
        }
        return out.toString();
    }
}

In [8]:
// Build the example tree from the problem statement:
//     1
//    / \
//   2   3
//  / \  /
// 4  5 6
TreeNode root = new TreeNode(1);
root.left = new TreeNode(2);
root.right = new TreeNode(3);
root.left.left = new TreeNode(4);
root.left.right = new TreeNode(5);
root.right.left = new TreeNode(6);
// Evaluating `root` as the cell's last expression displays its toString() serialization.
root;

1:3:#:6:2:5:4:

In [11]:
public class Solution {
    /**
     * Counts the nodes of a binary tree by plain recursion.
     *
     * <p>An empty subtree contributes 0. Otherwise the count is the node itself plus the sizes
     * of both subtrees. This visits every node (O(n)) and does not exploit completeness.
     *
     * @param root subtree root, possibly {@code null}
     * @return the number of nodes in the subtree
     */
    public int solve(TreeNode root) {
        return root == null
                ? 0
                : solve(root.left) + solve(root.right) + 1;
    }
}

In [12]:
// Recursive O(n) count on the example tree; the problem statement expects 6.
new Solution().solve(root);

6

In [70]:
public class Solution {
    /**
     * Counts the nodes of a complete binary tree in O(log^2 n) time.
     *
     * <p>Let {@code height} be the number of edges on the leftmost root-to-leaf path. Every
     * level above the last is full, which gives {@code 2^height - 1} nodes there. The last
     * level has slots {@code 0 .. 2^height - 1}, filled from the left. A binary search for
     * the first empty slot therefore yields the number of last-level nodes; each probe walks
     * one root-to-leaf path.
     *
     * @param root root of a complete binary tree; may be {@code null}
     * @return the number of nodes, or 0 for an empty tree
     */
    public int solve(TreeNode root) {
        if (root == null) return 0;
        int height = calcHeight(root);
        int lastLevel = 1 << height; // number of slots on the last level
        // Find the first empty slot in [0, lastLevel); slot 0 is always occupied.
        int left = 0, right = lastLevel - 1;
        while (left <= right) {
            int mid = left + (right - left) / 2;
            if (exists(root, mid, height)) {
                left = mid + 1;
            } else {
                right = mid - 1;
            }
        }
        // (lastLevel - 1) nodes in the full levels above, plus `left` occupied last-level slots.
        return lastLevel - 1 + left;
    }

    /**
     * Returns whether slot {@code pos} on the last level is occupied.
     *
     * <p>The low {@code height} bits of {@code pos}, read from most to least significant,
     * spell the path from the root: 0 means go left, 1 means go right.
     *
     * @param node   the tree root
     * @param pos    slot index on the last level, in {@code [0, 2^height)}
     * @param height number of edges from the root to the last level
     * @return {@code true} if a node exists at that slot
     */
    private boolean exists(TreeNode node, int pos, int height) {
        // The null check guards against malformed (non-complete) input instead of throwing an NPE.
        for (int bit = height - 1; bit >= 0 && node != null; --bit) {
            node = ((pos >> bit) & 1) == 0 ? node.left : node.right;
        }
        return node != null;
    }

    /**
     * Returns the number of edges on the leftmost path from {@code node} down to a leaf.
     * In a complete tree this is the tree's height. {@code node} must be non-null.
     */
    private int calcHeight(TreeNode node) {
        int height = 0;
        while (node.left != null) {
            node = node.left;
            height++;
        }
        return height;
    }
}

In [71]:
// Small check: a root with two children, so the last level is completely full (expected 3).
TreeNode root = new TreeNode(2);
root.left = new TreeNode(1);
root.right = new TreeNode(3);
new Solution().solve(root);

3

In [72]:
// Rebuild the problem's example tree, whose last level is partially filled (4, 5, 6).
// Count it with the most recently defined Solution; the expected answer is 6.
TreeNode root = new TreeNode(1);
root.left = new TreeNode(2);
root.right = new TreeNode(3);
root.left.left = new TreeNode(4);
root.left.right = new TreeNode(5);
root.right.left = new TreeNode(6);
new Solution().solve(root);

6