In [293]:
class TreeNode(object):
    """Binary tree node holding a key/value pair plus a back-link to its parent.

    Children live in ``left``/``right``; each child is expected to point back
    at its owner through ``parent``.
    """

    def __init__(self, key=None, value=None):
        self.key, self.value = key, value
        self.left = self.right = None
        self.parent = None

    def display(self):
        """Return a nested-list snapshot of this subtree.

        The first entry is the ``(key, value)`` pair. When at least one child
        exists, two more entries follow -- the left and the right child
        snapshots -- with ``None`` standing in for a missing child. A leaf is
        therefore just ``[(key, value)]``.
        """
        snapshot = [(self.key, self.value)]
        if not (self.left or self.right):
            return snapshot
        for child in (self.left, self.right):
            snapshot.append(child.display() if child else None)
        return snapshot

    def set_father(self, node):
        """Make ``node`` point back at this node as its parent (no-op for None)."""
        if not node:
            return
        node.parent = self

    def update_children(self):
        """Re-point the parent link of each existing child at this node."""
        for child in (self.left, self.right):
            self.set_father(child)

    def replace_by(self, node):
        """Repair parent links after ``node`` has taken this node's place.

        The child links are expected to be rewired already (e.g. by a
        rotation). This refreshes the parent pointers of this node's current
        children, hands this node's parent over to ``node``, and finally lets
        ``node`` claim its own children -- possibly including this node.
        """
        self.update_children()
        node.parent = self.parent
        node.update_children()
        

In [294]:
class BSTNode(TreeNode):
    """Unbalanced binary-search-tree node; keys must be mutually orderable.

    Smaller keys go to ``left``, larger keys to ``right``, and parent links
    are kept in sync whenever a child link is (re)assigned.
    """

    def insert(self, key, value=None):
        """Store ``value`` under ``key`` in this subtree and return this node.

        An existing key has its value overwritten; otherwise a new leaf is
        attached at the correct position.
        """
        if key == self.key:
            self.value = value
            return self
        side = 'right' if key > self.key else 'left'
        child = getattr(self, side)
        if child:
            child.insert(key, value)
        else:
            child = BSTNode(key, value)
            setattr(self, side, child)
        self.set_father(child)
        return self

    def delete(self, key):
        """Remove ``key`` from this subtree and return the subtree's root.

        A missing key leaves the subtree untouched. A matching leaf is dropped
        (``None`` is returned). A matching inner node instead copies the
        key/value of an in-order neighbour -- the minimum of the right subtree
        when there is no left child, otherwise the maximum of the left
        subtree -- and that neighbour's key is then deleted from the subtree
        it came from.
        """
        if key != self.key:
            side = 'right' if key > self.key else 'left'
            child = getattr(self, side)
            if child:
                setattr(self, side, child.delete(key))
                self.set_father(getattr(self, side))
            return self

        if not self.left and not self.right:
            return None

        # Borrow from the right subtree only when there is no left subtree.
        if self.left is None:
            side, inward = 'right', 'left'
        else:
            side, inward = 'left', 'right'
        donor = getattr(self, side)
        while getattr(donor, inward):
            donor = getattr(donor, inward)
        self.key, self.value = donor.key, donor.value
        setattr(self, side, getattr(self, side).delete(self.key))
        self.set_father(getattr(self, side))
        return self

    def get(self, key):
        """Return the value stored under ``key``, or ``None`` if it is absent."""
        node = self
        while node:
            if key == node.key:
                return node.value
            node = node.right if key > node.key else node.left
        return None

    def get_successor(self):
        """Return the node holding the next larger key, or ``None`` if last."""
        if self.right:
            node = self.right
            while node.left:
                node = node.left
            return node
        # Climb until we arrive from a left child; that ancestor is next.
        child, ancestor = self, self.parent
        while ancestor and child != ancestor.left:
            child, ancestor = ancestor, ancestor.parent
        return ancestor
    
          

In [295]:
class AVLTreeNode(BSTNode):
    """A simple AVL Tree implementation in Python.

    Every node caches ``depth``: the height of its subtree, where a leaf has
    depth 1. ``insert``, ``delete``, ``maintain`` and ``rotate`` all return
    the root of the (possibly restructured) subtree. Callers must store that
    result back into the link they called through.
    """

    def __init__(self, key=None, value=None):
        super(AVLTreeNode, self).__init__(key, value)
        self.depth = 1

    def insert(self, key, value=None):
        """Insert or overwrite ``key`` and return the rebalanced subtree root."""
        if key == self.key:
            self.value = value
        elif key > self.key:
            if self.right:
                self.right = self.right.insert(key, value)
            else:
                self.right = AVLTreeNode(key, value)
            self.set_father(self.right)
        else:
            if self.left:
                self.left = self.left.insert(key, value)
            else:
                self.left = AVLTreeNode(key, value)
            self.set_father(self.left)
        # maintain() refreshes this node's depth and rotates only when needed.
        return self.maintain()

    def maintain(self):
        """Restore the AVL balance at this node and return the subtree's root.

        Refreshes ``depth``. When the child heights differ by more than one,
        the taller child is first straightened by ``rotate`` (turning a
        zig-zag into a straight line). That child is then lifted above this
        node with a single rotation.
        """
        l_depth, r_depth = self.update_depth()
        if abs(l_depth - r_depth) <= 1:
            return self
        left = l_depth > r_depth
        attr = 'left' if left else 'right'
        heavy = getattr(self, attr).rotate(left)
        setattr(self, attr, heavy)
        self.set_father(heavy)
        return self._lift(attr)

    def delete(self, key):
        """Remove ``key`` from this subtree; return the new root or ``None``.

        The base-class delete recurses through the children's own
        ``delete``, so every subtree on the path is rebalanced on the way
        back up. This node is rebalanced last.
        """
        root = super(AVLTreeNode, self).delete(key)
        if root is None:
            return None
        return root.maintain()

    def update_depth(self):
        """Recompute ``depth`` from the children and return their depths.

        Returns ``(left_depth, right_depth)``, where a missing child counts
        as 0.
        """
        l_depth, r_depth = 0, 0
        if self.left:
            l_depth = self.left.depth
        if self.right:
            r_depth = self.right.depth
        self.depth = 1 + max(l_depth, r_depth)
        return l_depth, r_depth

    def rotate(self, left):
        """Straighten this node for a parent that is too tall on side ``left``.

        ``left`` is True when the parent's left subtree is the tall one. If
        this node leans the opposite way (the LR / RL case), its taller child
        is lifted so the heavy path no longer zig-zags. Otherwise nothing
        changes. Returns the root of this subtree afterwards.
        """
        l_depth, r_depth = self.update_depth()
        if left and r_depth > l_depth:
            return self._lift('right')
        if not left and l_depth > r_depth:
            return self._lift('left')
        return self

    def _lift(self, attr):
        """Single rotation: make the child on side ``attr`` this subtree's root."""
        other = 'right' if attr == 'left' else 'left'
        pivot = getattr(self, attr)
        setattr(self, attr, getattr(pivot, other))
        setattr(pivot, other, self)
        self.replace_by(pivot)  # repairs every parent link touched above
        # self is now pivot's child, so its depth must be refreshed first.
        self.update_depth()
        pivot.update_depth()
        return pivot
 

In [296]:
class Dictionary(object):
    """A dictionary implementation based on AVL Tree.

    Unlike the built-in ``dict``, reading a missing key returns ``None``.
    ``del`` on a missing key raises ``KeyError``. Iteration yields the keys
    in ascending order.
    """

    def __init__(self):
        self.__tree = None
        self.iter = None  # cursor for direct __next__/next callers

    def __setitem__(self, key, item):
        if self.__tree:
            self.__tree = self.__tree.insert(key, item)
        else:
            self.__tree = AVLTreeNode(key, item)

    def __getitem__(self, key):
        """Return the value for ``key``, or ``None`` when it is absent."""
        if self.__tree:
            return self.__tree.get(key)
        return None

    def __contains__(self, key):
        """O(log n) membership test (avoids a linear scan via ``__iter__``)."""
        return self._find_node(key) is not None

    def __delitem__(self, key):
        """Delete ``key``; raise ``KeyError`` if it is not present.

        Presence is decided by finding the node, not by the truthiness of the
        stored value. Keys mapped to 0, "" or None are therefore deletable.
        """
        if self._find_node(key) is None:
            # Wrap in a tuple so tuple keys do not break %-formatting.
            raise KeyError("No Key named %s" % (key,))
        self.__tree = self.__tree.delete(key)

    def __iter__(self):
        """Return an independent in-order iterator over the keys.

        Each call gets its own generator, so nested loops over the same
        Dictionary no longer interfere. An empty Dictionary yields nothing
        instead of failing. The legacy ``self.iter`` cursor is also reset, so
        code that calls ``next()`` on the Dictionary itself keeps working.
        """
        first = self._leftmost()
        self.iter = first
        return self._walk(first)

    def __next__(self):
        """Advance the legacy shared cursor set up by ``__iter__``."""
        if not self.iter:
            raise StopIteration
        key = self.iter.key
        self.iter = self.iter.get_successor()
        return key

    next = __next__

    def tree(self):
        """Return the root tree node (``None`` when empty)."""
        return self.__tree

    def _find_node(self, key):
        """Return the node holding ``key``, or ``None`` if absent."""
        node = self.__tree
        while node is not None and node.key != key:
            node = node.right if key > node.key else node.left
        return node

    def _leftmost(self):
        """Return the node with the smallest key, or ``None`` if empty."""
        node = self.__tree
        while node is not None and node.left is not None:
            node = node.left
        return node

    @staticmethod
    def _walk(node):
        """Yield keys from ``node`` onward in ascending order."""
        while node is not None:
            yield node.key
            node = node.get_successor()
    
    

In [297]:
import time

In [303]:
def test1(n):
    """Time ``n`` inserts plus one keyed traversal, custom Dictionary vs dict.

    Each mapping is created, filled with ``i -> str(i)`` for ``i`` in
    ``range(n)``, and then iterated while reading every value. The custom
    Dictionary is measured first, then the built-in dict. Each elapsed time
    is printed on its own line.
    """
    def fill_and_scan(factory):
        started = time.time()
        mapping = factory()
        for i in range(n):
            mapping[i] = str(i)
        for key in mapping:
            key, mapping[key]
        return time.time() - started

    mine = fill_and_scan(Dictionary)
    print("My Dict insert %d items use time %f second" % (n, mine))

    builtin = fill_and_scan(dict)
    print("Built-in Dict insert %d items use time %f second" % (n, builtin))

import numpy as np
def test2(n):
    """Cross-check ``n`` Dictionary lookups against the built-in dict.

    Keys are a shuffled ``0..n-1`` inserted in random order. Inserting them
    in ascending order only ever produces right-leaning imbalances, which
    leaves the tree's zig-zag (double-rotation) path untested. Prints how
    many lookups agree.
    """
    keys = np.arange(n)
    np.random.shuffle(keys)

    m_dic, dic = Dictionary(), dict()
    for value, key in enumerate(keys):
        key = int(key)  # plain int keys, independent of numpy scalar types
        m_dic[key] = value
        dic[key] = value

    correct = sum(1 for key in dic if m_dic[key] == dic[key])
    print("%d/%d in correctness" % (correct, n))
        
    

In [304]:
# Benchmark: 5000 inserts plus a keyed traversal, AVL Dictionary vs built-in dict.
test1(5000)

My Dict insert 5000 items use time 0.374641 second
Built-in Dict insert 5000 items use time 0.002712 second


In [305]:
# Cross-check 5000 Dictionary lookups against the built-in dict.
test2(5000)

5000/5000 in correctness
