In [58]:
#-------------------------------------------------------------------------------
# treeset.py
#
#
# Copyright (C) 2016, Ryosuke Fukatani
# License: Apache 2.0
#-------------------------------------------------------------------------------

import bisect


class TreeSet(object):
    """
    Sorted set of unique elements, similar to Java's TreeSet.

    Elements are stored in ascending order in a plain list and located with
    the ``bisect`` module, so all elements must be mutually orderable.
    Duplicate elements are not added.
    """
    def __init__(self, elements=None):
        """Create a set from an optional iterable of elements."""
        self._treeset = []
        if elements is not None:
            self.addAll(elements)

    def __find(self, element):
        """Return (index, found) for element using a single bisect.

        ``index`` is the leftmost insertion point that keeps the list sorted;
        ``found`` tells whether an equal element already sits at that index.
        """
        index = bisect.bisect_left(self._treeset, element)
        found = (index < len(self._treeset)
                 and element == self._treeset[index])
        return index, found

    def addAll(self, elements):
        """Add every element of the iterable, skipping duplicates."""
        for element in elements:
            self.add(element)

    def add(self, element):
        """Insert element at its sorted position unless already present.

        Raises TypeError if element cannot be ordered against the stored
        elements.
        """
        index, found = self.__find(element)
        if not found:
            self._treeset.insert(index, element)

    def ceiling(self, e):
        """Return the smallest element >= e, or None if there is none."""
        index = bisect.bisect_left(self._treeset, e)
        if index == len(self._treeset):
            # Every stored element is smaller than e (or the set is empty).
            return None
        return self._treeset[index]

    def floor(self, e):
        """Return the greatest element <= e, or None if there is none."""
        index = bisect.bisect_right(self._treeset, e)
        if index == 0:
            # Every stored element is greater than e (or the set is empty).
            return None
        return self._treeset[index - 1]

    def __getitem__(self, num):
        """Return the element at position num in ascending order."""
        return self._treeset[num]

    def __len__(self):
        return len(self._treeset)

    def clear(self):
        """Delete all elements in TreeSet."""
        self._treeset = []

    def clone(self):
        """Return shallow copy of self."""
        return TreeSet(self._treeset)

    def remove(self, element):
        """Remove element if present; return True if it was removed."""
        try:
            self._treeset.remove(element)
        except ValueError:
            return False
        return True

    def __iter__(self):
        """Do ascending iteration for TreeSet"""
        return iter(self._treeset)

    def pop(self, index):
        """Remove and return the element at the given position."""
        return self._treeset.pop(index)

    def __str__(self):
        return str(self._treeset)

    def __repr__(self):
        return '{}({!r})'.format(type(self).__name__, self._treeset)

    def __eq__(self, target):
        """Compare against another TreeSet or against a (sorted) list."""
        if isinstance(target, TreeSet):
            return self._treeset == target._treeset
        if isinstance(target, list):
            return self._treeset == target
        # Let Python try the reflected comparison / fall back to identity.
        return NotImplemented

    # Mutable container with value-based equality: explicitly unhashable.
    __hash__ = None

    def __contains__(self, e):
        """Fast membership test by bisect.

        An element that cannot be ordered against the stored elements is
        reported as absent instead of raising.
        """
        try:
            return self.__find(e)[1]
        except TypeError:
            return False

if __name__ == '__main__':
    # Small smoke demo: duplicates collapse and the elements come out sorted.
    ts = TreeSet([3, 7, 7, 1, 3])
    # Probe a value that is absent (4) and one that is present (3).
    for probe in (4, 3):
        print(ts.floor(probe))
        print(ts.ceiling(probe))
    print(ts)

3
7
3
3
[1, 3, 7]


In [59]:
class TreeMap(dict):
    """
    "TreeMap" is a dictionary with sorted keys similar to java TreeMap.
    Keys, iteration, items, values will all return values ordered by key.
    Otherwise it should behave just like the builtin dict.

    The ordering is kept in ``self.sorted_keys`` (a TreeSet), which every
    mutating method below keeps in sync with the underlying dict. Keys must
    therefore be mutually orderable as well as hashable.
    """

    def __init__(self, seq=None, **kwargs):
        """Build the map from an optional mapping/iterable of pairs and kwargs."""
        if seq is None:
            super().__init__(**kwargs)
        else:
            super().__init__(seq, **kwargs)
        # dict.__init__ does not go through __setitem__, so index the keys here.
        self.sorted_keys = TreeSet(super().keys())

    def __setitem__(self, key, value):
        super().__setitem__(key, value)
        self.sorted_keys.add(key)

    def __delitem__(self, key):
        super().__delitem__(key)
        self.sorted_keys.remove(key)

    def keys(self):
        """Return the TreeSet of keys (ascending order)."""
        return self.sorted_keys

    def items(self):
        """Return a list of (key, value) pairs ordered by key."""
        return [(k, self[k]) for k in self.sorted_keys]

    def __iter__(self):
        for k in self.sorted_keys:
            yield k

    def values(self):
        """Yield the values ordered by their key."""
        for k in self.sorted_keys:
            yield self[k]

    def clear(self):
        super().clear()
        self.sorted_keys.clear()

    # The inherited dict versions of the methods below mutate the mapping
    # without calling __setitem__/__delitem__, which would leave sorted_keys
    # stale. They are overridden to keep the key index consistent.

    def update(self, *args, **kwargs):
        """Insert/overwrite entries like dict.update, keeping keys sorted."""
        for key, value in dict(*args, **kwargs).items():
            self[key] = value

    def setdefault(self, key, default=None):
        """Return self[key], inserting default first if key is missing."""
        if key not in self:
            self[key] = default
        return self[key]

    def pop(self, key, *default):
        """Remove key and return its value (or default if given and missing).

        Raises KeyError when key is missing and no default is supplied.
        """
        present = key in self
        value = super().pop(key, *default)
        if present:
            self.sorted_keys.remove(key)
        return value

    def popitem(self):
        """Remove and return a (key, value) pair as dict.popitem does.

        Raises KeyError when the map is empty.
        """
        key, value = super().popitem()
        self.sorted_keys.remove(key)
        return key, value

In [60]:
# Build a set from unsorted input that contains duplicates.
ts = TreeSet([3, 7, 2, 7, 1, 3])
print(ts)   # prints [1, 2, 3, 7]

# A new element is placed at its sorted position.
ts.add(4)
print(ts)   # prints [1, 2, 3, 4, 7]

# Removing an element that is present shrinks the set.
ts.remove(7)
print(ts)   # prints [1, 2, 3, 4]

# Removing an element that is absent leaves the set unchanged.
ts.remove(5)
print(ts)   # prints [1, 2, 3, 4]

# Bulk insertion skips the values that are already stored.
ts.addAll([3, 4, 5, 6])
print(ts)   # prints [1, 2, 3, 4, 5, 6]

# Positional access: smallest element first, largest last.
print(ts[0], ts[-1], sep='\n')   # prints 1, then 6

# Membership tests.
print(1 in ts, 100 in ts, sep='\n')   # prints True, then False

# Iteration is ascending and free of duplicates: prints 1, then 3.
for i in TreeSet([1, 3, 1]):
    print(i)
# >>> 1
# >>> 3


[1, 2, 3, 7]
[1, 2, 3, 4, 7]
[1, 2, 3, 4]
[1, 2, 3, 4]
[1, 2, 3, 4, 5, 6]
1
6
True
False
1
3


In [61]:
# Keys are given out of order on purpose; the map displays them sorted.
tm = TreeMap({
    'y': 1,
    'a': 100,
    'm': -4,
    'Z': 25,
    'M': -34,
})
tm

{'M': -34, 'Z': 25, 'a': 100, 'm': -4, 'y': 1}

In [62]:
# keys() returns the map's TreeSet of keys; iterating it yields them in ascending order.
list(tm.keys())

['M', 'Z', 'a', 'm', 'y']

In [63]:
# Every key is greater than 'A' (the smallest key is 'M'), so floor returns None.
tm.keys().floor('A')

In [64]:
# Every key is smaller than 'z' (the largest key is 'y'), so ceiling returns None.
tm.keys().ceiling('z')

In [65]:
# Greatest key that is <= 'z'.
tm.keys().floor('z')

'y'

In [66]:
# Smallest key that is >= 'A'.
tm.keys().ceiling('A')

'M'