# 二叉树

## 示例: *Python 面向对象编程指南* page 132

In [1]:
import weakref

class TreeNode(object):
    """A node of an (unbalanced) binary search tree.

    Every item in the ``less`` subtree compares smaller than ``item`` and
    every item in the ``more`` subtree compares larger; equal items are
    never stored twice. A node whose ``item`` is ``None`` is treated as a
    sentinel root: ``add``/``find``/``remove`` always route through its
    ``more`` link, where the real tree lives (so ``None`` itself cannot be
    stored as an item).

    The link back to the parent is held through a weak reference, so
    parents and children do not form strong reference cycles.

    All walks are iterative, so degenerate (list-shaped) trees -- e.g. built
    from already-sorted input -- do not hit the interpreter recursion limit.
    """

    def __init__(self, item, less=None, more=None, parent=None):
        self.item = item
        self.less = less
        self.more = more
        # Always go through the setter (even for None) so that ``parent`` is
        # readable on every node, including the parentless sentinel root.
        self.parent = parent

    def __iter__(self):
        """Yield the items of this subtree in ascending (in-order) order."""
        stack = []
        node = self
        while stack or node is not None:
            # Walk down the ``less`` spine, remembering the path back up.
            while node is not None:
                stack.append(node)
                node = node.less
            node = stack.pop()
            yield node.item
            node = node.more

    def __repr__(self):
        parent = self.parent
        return '<TreeNode item=%s less=%s more=%s parent=%s>' % (
            self.item,
            self.less.item if self.less is not None else None,
            self.more.item if self.more is not None else None,
            parent.item if parent is not None else None,
        )

    @property
    def parent(self):
        """The parent node, or None if there is none (or it was collected)."""
        wref = getattr(self, 'parent_wref', None)
        return wref() if wref is not None else None

    @parent.setter
    def parent(self, value):
        # weakref.ref(None) raises TypeError, so a missing parent is stored
        # as a plain None instead of a weak reference.
        self.parent_wref = weakref.ref(value) if value is not None else None

    def find(self, item):
        """Return the node of this subtree whose item equals ``item``.

        Raises:
            KeyError: if ``item`` is not stored in this subtree.
        """
        # The sentinel root holds no item of its own; start below it.
        node = self.more if self.item is None else self
        while node is not None:
            if node.item == item:
                return node
            if node.item > item:
                node = node.less
            elif node.item < item:
                node = node.more
            else:
                break
        raise KeyError(item)

    def add(self, item):
        """Insert ``item`` into this subtree; an equal item is ignored."""
        node = self
        while True:
            if node.item is not None and node.item > item:
                if node.less is None:
                    node.less = TreeNode(item, parent=node)
                    return
                node = node.less
            elif node.item is None or node.item < item:
                # The sentinel root always routes to its ``more`` link.
                if node.more is None:
                    node.more = TreeNode(item, parent=node)
                    return
                node = node.more
            else:
                return  # equal item already present: duplicates are ignored

    def remove(self, item):
        """Remove ``item`` from this subtree.

        Raises:
            KeyError: if ``item`` is not stored in this subtree.
        """
        node = self
        while True:
            if node.item is None or item > node.item:
                child = node.more
            elif item < node.item:
                child = node.less
            else:
                break  # node.item == item
            if child is None:
                raise KeyError(item)
            node = child

        if node.less is not None and node.more is not None:
            # Two children: copy in the in-order successor (the smallest
            # item of the ``more`` subtree), then unlink the successor, which
            # by construction has no ``less`` child.
            successor = node.more._least()
            node.item = successor.item
            successor._replace(successor.more)
        elif node.less is not None:
            node._replace(node.less)
        else:
            # Only a ``more`` child, or a leaf (then ``more`` is None and the
            # node is simply cut off).
            node._replace(node.more)

    def _least(self):
        """Return the node holding the smallest item of this subtree."""
        node = self
        while node.less is not None:
            node = node.less
        return node

    def _replace(self, new=None):
        """Splice ``new`` (a node or None) into this node's place.

        A node without a parent has no incoming link to rewrite, so only
        ``new``'s back-link is updated in that case.
        """
        parent = self.parent
        if parent is not None:
            # Identity, not equality: we need to know which link points here.
            if parent.less is self:
                parent.less = new
            else:
                parent.more = new
        if new is not None:
            new.parent = parent

In [2]:
try:  # Python 3.3+; the plain ``collections`` alias was removed in 3.10
    from collections.abc import MutableSet
except ImportError:  # Python 2
    from collections import MutableSet

class Tree(MutableSet):
    """A mutable set backed by an (unbalanced) binary search tree.

    Items must be mutually orderable. ``self.root`` is a sentinel node whose
    item is ``None``; the real tree hangs off ``self.root.more``, which is
    ``None`` while the set is empty. ``self.size`` counts the distinct items
    stored. Operators such as ``|`` and methods such as ``remove``/``pop``/
    ``clear`` are supplied by the ``MutableSet`` mixin on top of the five
    methods defined here.
    """

    def __init__(self, iterable=None):
        self.root = TreeNode(None)
        self.size = 0
        if iterable is not None:
            for item in iterable:
                self.add(item)

    def add(self, item):
        """Add ``item``; adding an item that is already present is a no-op."""
        # The tree silently ignores duplicates, so only count real inserts.
        if item in self:
            return
        self.root.add(item)
        self.size += 1

    def discard(self, item):
        """Remove ``item`` if present; do nothing if it is absent."""
        top = self.root.more
        if top is None:  # empty set
            return
        try:
            top.remove(item)
        except KeyError:
            return
        self.size -= 1

    def __contains__(self, item):
        top = self.root.more
        if top is None:  # empty set
            return False
        try:
            top.find(item)
        except KeyError:
            return False
        return True

    def __iter__(self):
        """Yield the items in ascending order (nothing for an empty set)."""
        top = self.root.more
        if top is not None:
            for item in top:
                yield item

    def __len__(self):
        return self.size

In [3]:
t = Tree([6, 3, 2, 7, 9, 1, 0])  # 6 is inserted first, so it becomes the top real node under the sentinel root

In [4]:
list(t)  # in-order traversal yields the items in ascending order

[0, 1, 2, 3, 6, 7, 9]

In [5]:
len(t)  # Tree.__len__ returns the running size counter

7

In [6]:
6 in t  # membership test via Tree.__contains__ (a search down from root.more)

True

In [7]:
t.discard(7)  # 7 is present, so it is unlinked; discard ignores items that are missing

In [8]:
list(t)  # 7 no longer appears

[0, 1, 2, 3, 6, 9]

In [9]:
t.add(8)  # 8 > 6 and 8 < 9, so it lands as the less child of 9

In [10]:
list(t)  # 8 is listed in its sorted position

[0, 1, 2, 3, 6, 8, 9]

In [11]:
# Inspect the raw nodes. t.root is the sentinel (item None); the real tree
# starts at t.root.more. Its repr shows "parent=None" because the parent is
# the sentinel, whose item is None -- not because the parent is missing.
node = t.root
print(node.more)

more = node.more.more
print(more)

<TreeNode item=6 less=3 more=9 parent=None>
<TreeNode item=9 less=8 more=None parent=6>


In [18]:
# Exercise Tree's set behaviour: `|` is not defined on Tree itself, it is
# inherited from the MutableSet base class.

t1 = Tree(['a', 'c', 'y'])
t2 = Tree(['b', 'x', 'e'])
union = t1 | t2

In [19]:
list(union)  # items of both trees, in ascending order

['a', 'b', 'c', 'e', 'x', 'y']

In [20]:
'e' in union  # 'e' came from t2

True

In [21]:
union.remove('e')  # remove() is not defined on Tree; it is inherited from MutableSet
list(union)

['a', 'b', 'c', 'x', 'y']