# 基于二叉树实现的映射

- 参考资料
  - 《算法》(第四版) 3.2 小节

In [1]:
%run search_base.ipynb

In [4]:
class _Node(object):
    """One node of the binary search tree: key, value, child links, subtree size."""

    # Child links default to "no child" at class level; an instance only gets
    # its own ``l``/``r`` attribute once a child is attached to it.
    l = r = None
    # Number of nodes in the subtree rooted at this node, counting itself.
    n = 0

    def __init__(self, k, v, n):
        self.k, self.v, self.n = k, v, n


class BinaryTreeMapping(MapBase):
    """Ordered mapping (symbol table) backed by an unbalanced binary search tree.

    Every node caches the size of its subtree in ``n``, which lets ``size``,
    ``rank`` and ``select`` run in time proportional to the tree height.
    No rebalancing is done: inserting keys in sorted order yields a tree of
    height O(N), and since every operation below is recursive, very deep
    trees can exceed Python's recursion limit.

    Keys must be mutually comparable with ``<``, ``>`` and ``==``.
    """

    root = None  # root _Node of the tree; None while the mapping is empty

    def size(self):
        """Return the number of key/value pairs stored."""
        return self._size(self.root)

    @classmethod
    def _size(cls, node):
        """Return the cached subtree size of ``node`` (0 for an empty link)."""
        return 0 if node is None else node.n

    def get(self, k):
        """Return the value associated with ``k``, or None if ``k`` is absent."""
        node = self._get(self.root, k)
        return node.v if node is not None else None

    @classmethod
    def _get(cls, node, k):
        """Return the node holding ``k`` in the subtree at ``node``, or None."""
        if node is None:
            return None
        if k < node.k:
            return cls._get(node.l, k)
        elif k > node.k:
            return cls._get(node.r, k)
        else:
            return node

    def put(self, k, v):
        """Associate ``v`` with ``k``: update the value if ``k`` exists, else insert a new node."""
        self.root = self._put(self.root, k, v)

    @classmethod
    def _put(cls, node, k, v):
        """Insert/update ``k`` in the subtree at ``node``; return the (possibly new) subtree root."""
        if node is None:
            # ``k`` is not in the tree: a new leaf takes this empty link.
            return _Node(k, v, 1)

        if k < node.k:
            node.l = cls._put(node.l, k, v)  # relink node.l to a subtree that contains k
        elif k > node.k:
            node.r = cls._put(node.r, k, v)  # relink node.r to a subtree that contains k
        else:
            node.v = v

        # The call may have updated an existing node or added a new one,
        # so recompute the cached subtree size either way.
        node.n = cls._size(node.l) + cls._size(node.r) + 1
        return node

    def delete_min(self):
        """Remove the smallest key and its value; do nothing if the mapping is empty."""
        if self.is_empty():
            return
        self.root = self._delete_min(self.root)

    @classmethod
    def _delete_min(cls, node):
        """Remove the minimum node of the subtree at ``node``; return the new subtree root.

        ``node`` must not be None.
        """
        # A node with no left child is the minimum. Returning its right
        # subtree lets the caller link that subtree to the parent, which
        # detaches the minimum node from the tree.
        if node.l is None:
            return node.r

        node.l = cls._delete_min(node.l)
        node.n = cls._size(node.l) + cls._size(node.r) + 1
        return node

    def delete_max(self):
        """Remove the largest key and its value; do nothing if the mapping is empty."""
        if self.is_empty():
            return
        self.root = self._delete_max(self.root)

    @classmethod
    def _delete_max(cls, node):
        """Remove the maximum node of the subtree at ``node``; return the new subtree root.

        ``node`` must not be None.
        """
        # Mirror image of ``_delete_min``: a node with no right child is the maximum.
        if node.r is None:
            return node.l
        node.r = cls._delete_max(node.r)
        node.n = cls._size(node.l) + cls._size(node.r) + 1
        return node

    def delete(self, k):
        """Remove ``k`` and its value; do nothing if ``k`` is absent.

        Works regardless of the stored value, including keys mapped to None.
        """
        self.root = self._delete(self.root, k)

    @classmethod
    def _delete(cls, node, k):
        """Remove ``k`` from the subtree at ``node``; return the new subtree root.

        If ``k`` is not present the subtree is returned structurally unchanged.
        """
        if node is None:
            # Fell off the tree: ``k`` is absent, nothing to remove on this path.
            return None

        if k < node.k:
            node.l = cls._delete(node.l, k)
        elif k > node.k:
            node.r = cls._delete(node.r, k)
        else:
            # ``node`` is being removed. With at most one child, that child
            # simply takes its place under the parent.
            if node.r is None:
                return node.l
            if node.l is None:
                return node.r

            # Two children: the successor min(node.r) replaces ``node``.
            # The successor must be detached from the right subtree BEFORE
            # it adopts ``old.l``. Otherwise its left link would no longer be
            # None, and ``_delete_min`` would descend into ``old.l`` and remove
            # the wrong node.
            old = node
            node = cls._min(old.r)
            node.r = cls._delete_min(old.r)
            node.l = old.l

        node.n = cls._size(node.l) + cls._size(node.r) + 1
        return node

    def min(self):
        """Return the smallest key, or None if the mapping is empty."""
        node = self._min(self.root)
        return node.k if node is not None else None

    @classmethod
    def _min(cls, node):
        """Return the leftmost (minimum-key) node of the subtree at ``node``, or None."""
        if node is None:
            return None
        if node.l is None:
            return node
        return cls._min(node.l)

    def max(self):
        """Return the largest key, or None if the mapping is empty."""
        node = self._max(self.root)
        return node.k if node is not None else None

    @classmethod
    def _max(cls, node):
        """Return the rightmost (maximum-key) node of the subtree at ``node``, or None."""
        if node is None:
            return None
        if node.r is None:
            return node
        return cls._max(node.r)

    def floor(self, k):
        """Return the largest key ``<= k``, or None if there is none."""
        node = self._floor(self.root, k)
        return node.k if node is not None else None

    @classmethod
    def _floor(cls, node, k):
        """Return the node with the largest key ``<= k`` in the subtree at ``node``, or None."""
        if node is None:
            return None
        if k == node.k:
            return node
        if k < node.k:
            # Every key in the right subtree is > node.k > k, so the floor
            # (if any) must be in the left subtree.
            return cls._floor(node.l, k)
        found = cls._floor(node.r, k)
        # If the right subtree holds no key <= k, ``node`` itself is the floor.
        return node if found is None else found

    def ceiling(self, k):
        """Return the smallest key ``>= k``, or None if there is none."""
        node = self._ceiling(self.root, k)
        return node.k if node is not None else None

    @classmethod
    def _ceiling(cls, node, k):
        """Return the node with the smallest key ``>= k`` in the subtree at ``node``, or None."""
        if node is None:
            return None
        if k == node.k:
            return node
        if k > node.k:
            # Every key in the left subtree is < node.k < k.
            return cls._ceiling(node.r, k)
        found = cls._ceiling(node.l, k)
        # If the left subtree holds no key >= k, ``node`` itself is the ceiling.
        return node if found is None else found

    def select(self, n):
        """Return the key of rank ``n`` (exactly ``n`` keys are smaller), or None if out of range."""
        node = self._select(self.root, n)
        return node.k if node is not None else None

    @classmethod
    def _select(cls, node, n):
        """Return the node whose key has exactly ``n`` smaller keys in the subtree, or None.

        Let ``n_left`` be the size of the left subtree::

                node.l        node        node.r
            0-------------49   50  51---------------100

        - ``n < n_left``: the left subtree has more than ``n`` keys, so the
          answer is the key of rank ``n`` inside ``node.l``.
        - ``n > n_left``: skip the ``n_left`` keys on the left plus ``node``
          itself, then look for rank ``n - n_left - 1`` inside ``node.r``::

              |      n_left     |1|  n-n_left-1  |

        - ``n == n_left``: exactly ``n`` keys are smaller than ``node.k``,
          so ``node`` is the answer.
        """
        if node is None:
            return None

        n_left = cls._size(node.l)
        if n < n_left:
            return cls._select(node.l, n)
        elif n > n_left:
            return cls._select(node.r, n - n_left - 1)
        else:
            return node

    def rank(self, k):
        """Return the number of keys strictly less than ``k``."""
        return self._rank(self.root, k)

    @classmethod
    def _rank(cls, node, k):
        """Return the number of keys ``< k`` in the subtree at ``node``."""
        if node is None:
            return 0
        if k == node.k:
            return cls._size(node.l)
        elif k > node.k:
            # The whole left subtree, plus ``node``, plus the smaller keys on the right.
            return cls._size(node.l) + 1 + cls._rank(node.r, k)
        else:
            return cls._rank(node.l, k)

    def keys(self):
        """Return all keys in ascending order (empty list if the mapping is empty)."""
        return self.keys_between(self.min(), self.max())

    def keys_between(self, k1, k2):
        """Return, in ascending order, every key ``key`` with ``k1 <= key <= k2``."""
        lst = []
        self._keys(self.root, k1, k2, lst)
        return lst

    @classmethod
    def _keys(cls, node, k1, k2, lst):
        """Append the in-range keys of the subtree at ``node`` to ``lst`` (in-order traversal)."""
        if node is None:
            return
        # Only descend into a side that can still contain keys within [k1, k2].
        if k1 < node.k:
            cls._keys(node.l, k1, k2, lst)
        if k1 <= node.k <= k2:
            lst.append(node.k)
        if k2 > node.k:
            cls._keys(node.r, k1, k2, lst)