In [1]:
%%HTML
<style>.container{width:100%;}</style>

# Cython Directives

To optimize performance we use [*Cython*](https://cython.org/ "S. Behnel et al. (2020): Cython: C-Extensions for Python"). Cython is a compiler that exploits special directives about data types in Python code to reduce the overhead of Python's type system. We load the Cython extension into our notebook, but not the style checker, because it is not compatible with Cython. We also postpone the imports from the standard library, because with Cython they have to be made in the same cell in which they are used.

In [2]:
%load_ext cython

There are two reasons why the implementation with Cython annotations is kept separate and in one piece.

- In the notebooks on splay trees and on ordered sets, we discussed the functions one at a time and then added them to the respective classes, a style also known as *monkey patching*. This mattered for readability, but it is incompatible with Cython's compiled nature.
- We want to compare the implementations with and without Cython against each other and against other implementations. However, Cython has no mode in which its optimizations can be switched off.
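The incompatibility with monkey patching can be reproduced in plain Python. A regular class accepts functions that are attached after its definition. Built-in types reject new attributes at runtime, and so do compiled Cython extension types (`cdef class`). The minimal sketch below shows this; the class `Tree` and the helper `try_patch` are hypothetical names used only for illustration.

```python
class Tree:
    """Hypothetical regular Python class standing in for a tree node."""

    def __init__(self, payload):
        self.payload = payload


def describe(self):
    """Defined separately, as in the step-by-step notebooks."""
    return f"Tree({self.payload!r})"


def try_patch(cls, name, func):
    """Attach func to cls as attribute `name`; return False if the type refuses."""
    try:
        setattr(cls, name, func)
    except TypeError:  # raised for built-in and extension types
        return False
    return True


# Monkey patching a regular class works ...
assert try_patch(Tree, "describe", describe)
# ... but a built-in type such as int is immutable, like a compiled cdef class.
assert not try_patch(int, "describe", describe)
```

After the successful patch, `Tree(1).describe()` returns `'Tree(1)'`.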

The Cython implementation also differs from the pure Python implementation in a few respects.

- `_splay` is defined iteratively, so it is not limited by the call stack the way recursive implementations are. The individual steps (such as *zig-zag*) are carried out inline, so no separate function calls are needed for them.
- We cannot monkey-patch functions that we use in several classes. We therefore define the class `_ArbComparison`, marked as private, for comparisons between arbitrary objects; both `Node` and `_GenericOrderedSet` inherit from it. `_GenericOrderedSet`, also marked as private, serves as the base class of `OrderedSet` and `OrderedFrozenset`, and the functions they share are defined in it. The rather trivial, purely internal methods of `_ArbComparison` are also marked `inline`. This means that, where possible, their bodies should be substituted at the call site rather than called as functions.
- `_GenericOrderedSet.__init__` does not check whether `self` is an already existing `OrderedFrozenset`, because the attribute `_tree` is always set already (Cython initializes it with `None`).
- `__hash__` is not defined separately for `OrderedSet` and `OrderedFrozenset`, because Python would otherwise decide whether two sets are equal by their hash values instead of by `_GenericOrderedSet.__eq__`. Instead, we check the type of `self`.
- For pickling, we have to implement [`__reduce__`](https://docs.python.org/3.7/library/pickle.html#object.__reduce__ "Python Software Foundation (2020): The Python Standard Library/Data Persistence/object.__reduce__(), Python Documentation") instead of `__getnewargs__`; for our purposes it works quite similarly.
- Whenever we check whether the `other` object is an ordered set, we also have to check `self`. With Cython the operands of binary operators can be swapped, so `self` is not necessarily an instance of our class.
- Here we use the uncurrying with which we defined operators such as `&` or `|` only for the in-place operations, because in Cython this style is not compatible with abstract classes. In addition, `_uncurry_inp_op` cannot be defined as a class.
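Three of the mechanisms in this list can be sketched in plain Python:

- the fallback ordering of `_ArbComparison`, which uses the natural order where possible and otherwise compares type names;
- an in-place operator built from an update method in the style of `_uncurry_inp_op`;
- pickling via `__reduce__`.

The list-backed class `ListOrderedSet` and the helpers `arb_lt`, `arb_cmp` and `uncurry_inp_op` are hypothetical stand-ins chosen for illustration, not the splay-tree classes of this notebook.

```python
import functools


def arb_lt(x, y):
    """Like _ArbComparison._arb_lt: natural order, else compare type names."""
    try:
        return x < y
    except TypeError:
        return type(x).__name__ < type(y).__name__


def arb_cmp(x, y):
    """Three-way comparison built from arb_lt, usable with cmp_to_key."""
    if arb_lt(x, y):
        return -1
    if arb_lt(y, x):
        return 1
    return 0


def uncurry_inp_op(func):
    """Turn update(self, *others) into an in-place operator returning self."""
    def inp_op(self, other):
        if isinstance(other, ListOrderedSet):
            func(self, other)
            return self
        return NotImplemented  # Python falls back to | and then raises TypeError
    return inp_op


class ListOrderedSet:
    """Hypothetical list-backed ordered set, for illustration only."""

    def __init__(self, iterable=()):
        self._items = []
        self.update(iterable)

    def __iter__(self):
        return iter(self._items)

    def update(self, *others):
        for other in others:
            for el in other:
                if el not in self._items:
                    self._items.append(el)
        self._items.sort(key=functools.cmp_to_key(arb_cmp))

    __ior__ = uncurry_inp_op(update)

    def __reduce__(self):
        # Same shape as in the notebook: (callable, args) to rebuild the object.
        return (type(self), (list(self),))
```

For example, `ListOrderedSet([3, "b", 1])` iterates as `1, 3, "b"`, because integers sort before strings by type name. `s |= ListOrderedSet(["a", 2])` modifies `s` in place, and `pickle.loads(pickle.dumps(s))` rebuilds an equal copy through `__reduce__`.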

As a rule, class attributes, local variables, and the argument and return types of functions are given type annotations that are as narrow as possible. A few exceptions have to be made, however.

- If a function has to work with Python's exception mechanism, we do without an explicit return type. Otherwise exceptions would have to be signalled purely through return values, which does not match what potential users expect. This affects the methods of `_ArbComparison` as well as `isdisjoint`, `add`, `remove`, `discard` and `pop`.
- For `_GenericOrderedSet` we do not define the `__cinit__` method customary in Cython but the regular `__init__`, because the tests require that this method can be called explicitly.
- *Special methods* such as `__len__`, `__eq__` or `__and__` can only be defined as regular (`def`) functions.
- *Variable argument lists* (e.g. `*others` in `union` and related functions) can only be used in regular functions.
- Cython annotations cannot be used together with *closures*, i.e. definitions inside functions. This affects `copy`, `intersection`, `symmetric_difference` and the in-place variants of the latter two.
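The last item refers to function-internal definitions like the one in `intersection`. There, an inner helper merges two ascending iterables in a single pass, and `functools.reduce` folds that helper over all arguments. The plain-Python sketch below follows the same pattern. It assumes that all inputs are already sorted and free of duplicates, and it returns a list instead of an `OrderedSet`. The function name `intersect_sorted` is hypothetical.

```python
import functools


def intersect_sorted(*sorted_iterables):
    """Intersect ascending, duplicate-free iterables (hypothetical helper)."""

    def intersect(x, y):
        # Nested helper: advance whichever iterator is behind, keep matches.
        result = []
        x_iter, y_iter = iter(x), iter(y)
        try:
            x_item, y_item = next(x_iter), next(y_iter)
            while True:
                if x_item < y_item:
                    x_item = next(x_iter)
                elif y_item < x_item:
                    y_item = next(y_iter)
                else:
                    result.append(x_item)
                    x_item, y_item = next(x_iter), next(y_iter)
        except StopIteration:  # one side is exhausted: no further matches
            return result

    return functools.reduce(intersect, sorted_iterables)
```

For example, `intersect_sorted([1, 2, 3, 5, 8], [2, 3, 5, 7], [3, 5, 9])` yields `[3, 5]`.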

The class attributes are also marked `public`, because we want to access them and would otherwise need getters and setters. We use this access in the implementation of `intersection_update` and `symmetric_difference_update`, where we exploit implementation details of splay trees.
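The idea behind `intersection_update` can also be shown without a tree. The set is traversed once in ascending order while the other argument is advanced in step, and elements without a partner are dropped in place. In the notebook, already processed elements live in the root's left subtree and the successor is splayed to the root. The sketch below instead uses a plain sorted Python list, where everything left of a write index is kept. The function name and the list representation are assumptions made for illustration.

```python
def intersection_update_sorted(items, other):
    """Keep only elements of the ascending list `items` found in ascending `other`."""
    write = 0  # items[:write] holds the elements kept so far
    other_iter = iter(other)
    try:
        other_item = next(other_iter)
        for read in range(len(items)):
            current = items[read]
            while other_item < current:  # skip smaller elements of other
                other_item = next(other_iter)
            if current == other_item:
                items[write] = current
                write += 1
    except StopIteration:
        pass  # other is exhausted: nothing after this point can match
    del items[write:]
```

For example, after `a = [1, 2, 3, 5, 8]` and `intersection_update_sorted(a, [2, 3, 5, 7])`, the list `a` is `[2, 3, 5]`.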

In [3]:
%%cython

import collections
import functools

cdef class _ArbComparison:
    cdef inline _arb_gt(self, object x, object y):
        try:
            return y < x
        except TypeError:
            return type(x).__name__ > type(y).__name__

    cdef inline _arb_lt(self, object x, object y):
        try:
            return x < y
        except TypeError:
            return type(x).__name__ < type(y).__name__

    cdef inline _arb_eq(self, object x, object y):
        try:
            return x == y
        except TypeError:
            return False


cdef class Node(_ArbComparison):
    cdef public object payload
    cdef public Node left, right

    def __cinit__(self, object payload, Node left, Node right):
        self.payload = payload
        self.left    = left
        self.right   = right

    cpdef Node _splay(self, object payload):
        cdef Node max_less, min_greater, set_aside
        max_less = min_greater = set_aside = Node(None, None, None)
        while True:
            if self._arb_lt(payload, self.payload):
                if self.left is None:
                    break
                if self._arb_lt(payload, self.left.payload) \
                        and self.left.left is not None:
                    tmp              = self.left
                    self.left        = tmp.right
                    tmp.right        = self
                    min_greater.left = tmp
                    min_greater      = tmp
                    self             = tmp.left
                    continue
                if self._arb_gt(payload, self.left.payload) \
                        and self.left.right is not None:
                    max_less.right   = self.left
                    max_less         = self.left
                    min_greater.left = self
                    min_greater      = self
                    self             = self.left.right
                    continue
                min_greater.left     = self
                min_greater          = self
                self                 = self.left
                break
            if self._arb_gt(payload, self.payload):
                if self.right is None:
                    break
                if self._arb_gt(payload, self.right.payload) \
                        and self.right.right is not None:
                    tmp              = self.right
                    self.right       = tmp.left
                    tmp.left         = self
                    max_less.right   = tmp
                    max_less         = tmp
                    self             = tmp.right
                    continue
                if self._arb_lt(payload, self.right.payload) \
                        and self.right.left is not None:
                    max_less.right   = self
                    max_less         = self
                    min_greater.left = self.right
                    min_greater      = self.right
                    self             = self.right.left
                    continue
                max_less.right       = self
                max_less             = self
                self                 = self.right
                break
            break
        max_less.right   = self.left
        min_greater.left = self.right
        self.left        = set_aside.right
        self.right       = set_aside.left
        return self

    cpdef Node insert(self, object payload):
        self = self._splay(payload)
        if self._arb_eq(payload, self.payload):
            return self
        cdef Node tmp
        if self._arb_lt(payload, self.payload):
            tmp       = self.left
            self.left = None
            return Node(payload, tmp, self)
        tmp        = self.right
        self.right = None
        return Node(payload, self, tmp)

    cpdef object remove(self, object payload):
        self = self._splay(payload)
        if not self._arb_eq(payload, self.payload):
            return False, self
        if self.left is None:
            return True, self.right
        if self.right is None:
            return True, self.left
        cdef Node tmp = self.left
        self          = self.right._splay(payload)
        self.left     = tmp
        return True, self

    cpdef object contains(self, object payload):
        self = self._splay(payload)
        return self._arb_eq(payload, self.payload), self

    cpdef object minimum(self):
        while self.left is not None:
            self = self.left
        return self.payload

    cpdef object maximum(self):
        while self.right is not None:
            self = self.right
        return self.payload


cdef class _GenericOrderedSet(_ArbComparison):
    cdef public Node _tree
    cdef object __weakref__  # necessary for weakref, not used directly

    def __init__(self, iterable=()):
        self._tree           = None
        cdef object iterator = iter(iterable)
        cdef object element
        try:
            element = next(iterator)
            if not element.__hash__:
                raise TypeError(f"unhashable type: '{type(element).__name__}'")
            self._tree = Node(element, None, None)
            while True:
                element = next(iterator)
                if not element.__hash__:
                    raise TypeError("unhashable type: " +
                                    f"'{type(element).__name__}'")
                self._tree = self._tree.insert(element)
        except StopIteration:
            pass

    def __iter__(self):
        stack          = collections.deque()
        cdef Node tree = self._tree
        while stack or tree is not None:
            if tree is not None:
                stack.append(tree)
                tree = tree.left
                continue
            tree = stack.pop()
            yield tree.payload
            tree = tree.right

    def __lt__(self, object other):
        if not (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            return NotImplemented
        cdef object x_iter = iter(self)
        cdef object y_iter = iter(other)
        cdef object x_item, y_item
        while True:
            try:
                y_item = next(y_iter)
            except StopIteration:
                return False  # x is longer or equal
            try:
                x_item = next(x_iter)
            except StopIteration:
                return True  # x is shorter
            if self._arb_lt(x_item, y_item):
                return True
            if self._arb_gt(x_item, y_item):
                return False

    def __gt__(self, object other):
        if not (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            return NotImplemented
        cdef object x_iter = iter(self)
        cdef object y_iter = iter(other)
        cdef object x_item, y_item
        while True:
            try:
                x_item = next(x_iter)
            except StopIteration:
                return False  # x is shorter or equal
            try:
                y_item = next(y_iter)
            except StopIteration:
                return True  # x is longer
            if self._arb_gt(x_item, y_item):
                return True
            if self._arb_lt(x_item, y_item):
                return False

    def __eq__(self, object other):
        if not (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            return False
        cdef object x_iter = iter(self)
        cdef object y_iter = iter(other)
        cdef object x_item, y_item
        while True:
            try:
                x_item = next(x_iter)
            except StopIteration:
                try:  # assert y is also exhausted
                    next(y_iter)
                except StopIteration:
                    return True
                return False
            try:
                y_item = next(y_iter)
            except StopIteration:
                return False  # x was not exhausted
            if not self._arb_eq(x_item, y_item):
                return False

    def __hash__(self):
        if isinstance(self, OrderedSet):
            raise TypeError(f"unhashable type: '{type(self).__name__}'")
        else:
            return sum(hash(element) for element in self)

    def __contains__(self, object element):
        if not element.__hash__:
            raise TypeError(f"unhashable type: '{type(element).__name__}'")
        if self._tree is None:
            return False
        cdef bint contains
        contains, self._tree = self._tree.contains(element)
        return contains

    def __len__(self):
        return sum(1 for i in self)

    def __repr__(self):
        if self._tree is None:
            return f"{type(self).__name__}()"
        return f"{type(self).__name__}({list(self)})"

    def __reduce__(self):
        return (type(self), (list(self),))

    cpdef object minimum(self):
        if self._tree is None:
            raise ValueError("Set is empty")
        return self._tree.minimum()

    cpdef object maximum(self):
        if self._tree is None:
            raise ValueError("Set is empty")
        return self._tree.maximum()

    cdef bint _subseteq(self, _GenericOrderedSet x, _GenericOrderedSet y):
        cdef object x_iter = iter(x)
        cdef object y_iter = iter(y)
        cdef object x_item, y_item
        while True:
            try:
                x_item = next(x_iter)
            except StopIteration:
                return True
            while True:
                try:
                    y_item = next(y_iter)
                except StopIteration:
                    return False
                if self._arb_lt(x_item, y_item):
                    return False
                if self._arb_eq(x_item, y_item):
                    break

    cpdef bint issubset(self, object other):
        if not (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            other = _GenericOrderedSet(other)
        return self._subseteq(self, other)

    def __le__(self, object other):
        if not (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            return NotImplemented
        return self._subseteq(self, other)

    cdef bint _subsetneq(self, _GenericOrderedSet x, _GenericOrderedSet y):
        cdef object x_iter = iter(x)
        cdef object y_iter = iter(y)
        cdef object x_item, y_item
        cdef bint proper_subset = False
        while True:
            try:
                x_item = next(x_iter)
            except StopIteration:
                if not proper_subset:
                    try:  # assert y is not exhausted
                        next(y_iter)
                    except StopIteration:
                        return False
                return True
            while True:
                try:
                    y_item = next(y_iter)
                except StopIteration:
                    return False
                if self._arb_lt(x_item, y_item):
                    return False
                if self._arb_gt(x_item, y_item):
                    proper_subset = True
                else:
                    break

    cpdef bint is_proper_subset(self, object other):
        if not (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            other = _GenericOrderedSet(other)
        return self._subsetneq(self, other)

    cpdef bint issuperset(self, object other):
        if not (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            other = _GenericOrderedSet(other)
        return self._subseteq(other, self)

    def __ge__(self, object other):
        if not (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            return NotImplemented
        return self._subseteq(other, self)

    cpdef bint is_proper_superset(self, object other):
        if not (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            other = _GenericOrderedSet(other)
        return self._subsetneq(other, self)

    cpdef isdisjoint(self, object other):
        if not (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            other = _GenericOrderedSet(other)
        cdef object x_iter = iter(self)
        cdef object y_iter = iter(other)
        cdef object x_item, y_item
        try:
            y_item = next(y_iter)  # we need a y_item for first comparison
        except StopIteration:
            return True
        while True:
            try:
                x_item = next(x_iter)
            except StopIteration:
                return True
            while True:
                if self._arb_lt(x_item, y_item):
                    break
                if self._arb_gt(x_item, y_item):
                    try:
                        y_item = next(y_iter)
                    except StopIteration:
                        return True
                    continue
                return False

    def union(self, *others):
        union = OrderedSet(self)
        for other in others:
            for el in other:
                union.add(el)
        if isinstance(self, OrderedFrozenset):
            frozen       = OrderedFrozenset()
            frozen._tree = union._tree
            return frozen
        return union

    def __or__(self, other):
        if (isinstance(self, OrderedSet)
                or isinstance(self, OrderedFrozenset)) and \
                (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            return self.union(other)
        else:
            return NotImplemented

    __ror__ = __or__  # commutativity

    def difference(self, *others):
        difference = OrderedSet(self)
        for other in others:
            for el in other:
                if not el.__hash__:
                    raise TypeError(f"unhashable type: '{type(el).__name__}'")
                difference.discard(el)
        if isinstance(self, OrderedFrozenset):
            frozen       = OrderedFrozenset()
            frozen._tree = difference._tree
            return frozen
        return difference

    def __sub__(self, other):
        if (isinstance(self, OrderedSet)
                or isinstance(self, OrderedFrozenset)) and \
                (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            return self.difference(other)
        else:
            return NotImplemented

    def __rsub__(self, other):
        if (isinstance(self, OrderedSet)
                or isinstance(self, OrderedFrozenset)) and \
                (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            return other.difference(self)
        else:
            return NotImplemented

    def intersection(self, *others):
        def intersect(x, y):
            intersection   = OrderedSet()
            x_iter, y_iter = iter(x), iter(y)
            try:
                x_item, y_item = next(x_iter), next(y_iter)
                while True:
                    if self._arb_lt(x_item, y_item):
                        x_item = next(x_iter)
                        continue
                    if self._arb_gt(x_item, y_item):
                        y_item = next(y_iter)
                        continue
                    intersection.add(x_item)
                    x_item, y_item = next(x_iter), next(y_iter)
            except StopIteration:
                return intersection
        if not others:
            return self.copy()  # otherwise we're returning self
        sets = [self] + [other if isinstance(other, OrderedSet)
                         or isinstance(other, OrderedFrozenset)
                         else OrderedSet(other) for other in others]
        intersection = functools.reduce(intersect, sets)
        if isinstance(self, OrderedFrozenset):
            frozen       = OrderedFrozenset()
            frozen._tree = intersection._tree
            return frozen
        return intersection

    def __and__(self, other):
        if (isinstance(self, OrderedSet)
                or isinstance(self, OrderedFrozenset)) and \
                (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            return self.intersection(other)
        else:
            return NotImplemented

    __rand__ = __and__

    def symmetric_difference(self, other):
        class OtherStopIteration(StopIteration):
            pass
        if self._tree is None:
            return type(self)(other)
        if not (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            other = OrderedSet(other)
        x_iter, y_iter = iter(self), iter(other)
        x_item         = next(x_iter)
        try:
            y_item = next(y_iter)
        except StopIteration:
            return self.copy()
        symmetric_difference = OrderedSet()
        try:
            try:
                while True:
                    if self._arb_lt(x_item, y_item):
                        symmetric_difference.add(x_item)
                        try:
                            x_item = next(x_iter)
                        except StopIteration:
                            symmetric_difference.add(y_item)
                            raise StopIteration
                        continue
                    if self._arb_gt(x_item, y_item):
                        symmetric_difference.add(y_item)
                        try:
                            y_item = next(y_iter)
                        except StopIteration:
                            symmetric_difference.add(x_item)
                            raise OtherStopIteration
                        continue
                    x_item = next(x_iter)
                    try:
                        y_item = next(y_iter)
                    except StopIteration:
                        symmetric_difference.add(x_item)  # already updated
                        raise OtherStopIteration
            except OtherStopIteration:
                while True:
                    symmetric_difference.add(next(x_iter))
        except StopIteration:
            try:
                while True:
                    symmetric_difference.add(next(y_iter))
            except StopIteration:
                pass
        if isinstance(self, OrderedFrozenset):
            frozen       = OrderedFrozenset()
            frozen._tree = symmetric_difference._tree
            return frozen
        return symmetric_difference

    def __xor__(self, other):
        if (isinstance(self, OrderedSet)
                or isinstance(self, OrderedFrozenset)) and \
                (isinstance(other, OrderedSet)
                or isinstance(other, OrderedFrozenset)):
            return self.symmetric_difference(other)
        else:
            return NotImplemented

    __rxor__ = __xor__

def _uncurry_inp_op(func):
    def inp_op(self, other):
        if isinstance(other, OrderedSet) \
                or isinstance(other, OrderedFrozenset):
            func(self, other)
            return self
        else:
            return NotImplemented
    return inp_op

cdef class OrderedSet(_GenericOrderedSet):
    cpdef add(self, object element):
        if not element.__hash__:
            raise TypeError(f"unhashable type: '{type(element).__name__}'")
        if self._tree is None:
            self._tree = Node(element, None, None)
        else:
            self._tree = self._tree.insert(element)

    cpdef remove(self, object element):
        if not element.__hash__:
            raise TypeError(f"unhashable type: '{type(element).__name__}'")
        if self._tree is None:
            raise KeyError(element)
        rc, self._tree = self._tree.remove(element)
        if not rc:
            raise KeyError(element)

    cpdef discard(self, object element):
        if not element.__hash__:
            raise TypeError(f"unhashable type: '{type(element).__name__}'")
        if self._tree is not None:
            _, self._tree = self._tree.remove(element)
 
    def copy(self):
        return OrderedSet(el for el in self)

    cpdef pop(self):
        if self._tree is None:
            raise KeyError("pop from an empty set")
        cdef object popped = self.minimum()
        self.remove(popped)
        return popped

    cpdef void clear(self):
        self._tree = None

    def update(self, *others):
        for other in others:
            for el in other:
                self.add(el)

    __ior__ = _uncurry_inp_op(update)

    def difference_update(self, *others):
        for other in others:
            for el in other:
                if not el.__hash__:
                    raise TypeError(f"unhashable type: '{type(el).__name__}'")
                self.discard(el)

    __isub__ = _uncurry_inp_op(difference_update)

    def intersection_update(self, *others):
        def intersect_update(self, other):
            self._tree = self._tree._splay(self.minimum())
            other_iter = iter(other)
            try:
                other_item        = next(other_iter)
                comparison_helper = OrderedSet()
                while True:
                    if comparison_helper._arb_lt(
                            self._tree.payload, other_item):
                        if self._tree.right is None:
                            raise StopIteration
                        minimum = self._tree.left is None
                        self.remove(self._tree.payload)  # like moving on
                        if minimum and self._tree is not None:
                            # removal leads to using right subtree
                            # instead of next element
                            self._tree = self._tree._splay(self.minimum())
                        continue
                    if comparison_helper._arb_gt(
                            self._tree.payload, other_item):
                        other_item = next(other_iter)
                        continue
                    if self._tree.right is None:
                        return
                    # equivalent to inserting and moving on
                    self._tree = self._tree._splay(self._tree.right.minimum())
                    other_item = next(other_iter)
            except StopIteration:
                # leave the rest
                self._tree = self._tree.left
        for other in others:
            if self._tree is None:
                return
            if not (isinstance(other, OrderedSet)
                    or isinstance(other, OrderedFrozenset)) \
                    or id(self) == id(other):  # don't iterate through self
                other = OrderedSet(other)
            intersect_update(self, other)

    __iand__ = _uncurry_inp_op(intersection_update)

    def symmetric_difference_update(self, other):
        class SelfStopIteration(StopIteration):
            pass
        if self._tree is None:
            self.update(other)  # symmetric difference with an empty set
            return
        if id(self) == id(other):
            self._tree = None
            return  # iteration through self is unstable
        if type(other) not in (OrderedSet, OrderedFrozenset):
            other = OrderedSet(other)
        self._tree = self._tree._splay(self.minimum())
        other_iter = iter(other)
        try:
            other_item = next(other_iter)
            try:
                while True:
                    if self._arb_lt(self._tree.payload, other_item):
                        if self._tree.right is None:
                            self.add(other_item)
                            raise SelfStopIteration
                        self._tree = self._tree._splay(
                            self._tree.right.minimum())
                        continue
                    if self._arb_gt(self._tree.payload, other_item):
                        root = self._tree.payload
                        self.add(other_item)
                        self._tree = self._tree._splay(root)
                        other_item = next(other_iter)
                        continue
                    minimum = self._tree.left is None
                    maximum = self._tree.right is None
                    self.remove(self._tree.payload)
                    if self._tree is None or maximum:
                        raise SelfStopIteration
                    if minimum:
                        self._tree = self._tree._splay(self.minimum())
                    other_item = next(other_iter)
            except SelfStopIteration:
                while True:
                    self.add(next(other_iter))
        except StopIteration:
            return
        
    __ixor__ = _uncurry_inp_op(symmetric_difference_update)


cdef class OrderedFrozenset(_GenericOrderedSet):
    def copy(self):
        return OrderedFrozenset(el for el in self)
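To experiment with the splaying logic without compiling the cell above, `Node._splay` and `Node.insert` can be transcribed line by line into plain Python. In this sketch the class `PyNode` and the function `inorder` are hypothetical names. All payloads are assumed to be mutually comparable, so the plain `<` operator replaces the `_arb_lt`/`_arb_gt` fallback.

```python
class PyNode:
    """Hypothetical pure-Python transcription of Node._splay and Node.insert."""

    __slots__ = ("payload", "left", "right")

    def __init__(self, payload, left=None, right=None):
        self.payload = payload
        self.left = left
        self.right = right

    def splay(self, payload):
        # One dummy header: header.right collects the tree of smaller keys,
        # header.left the tree of greater keys (max_less / min_greater above).
        header = PyNode(None)
        max_less = min_greater = header
        node = self
        while True:
            if payload < node.payload:
                if node.left is None:
                    break
                if payload < node.left.payload and node.left.left is not None:
                    tmp = node.left  # zig-zig: rotate right, then link right
                    node.left = tmp.right
                    tmp.right = node
                    min_greater.left = tmp
                    min_greater = tmp
                    node = tmp.left
                    continue
                if node.left.payload < payload and node.left.right is not None:
                    max_less.right = node.left  # zig-zag: link left and right
                    max_less = node.left
                    min_greater.left = node
                    min_greater = node
                    node = node.left.right
                    continue
                min_greater.left = node  # zig: link right and stop
                min_greater = node
                node = node.left
                break
            if node.payload < payload:
                if node.right is None:
                    break
                if node.right.payload < payload and node.right.right is not None:
                    tmp = node.right  # zag-zag: rotate left, then link left
                    node.right = tmp.left
                    tmp.left = node
                    max_less.right = tmp
                    max_less = tmp
                    node = tmp.right
                    continue
                if payload < node.right.payload and node.right.left is not None:
                    max_less.right = node  # zag-zig: link left and right
                    max_less = node
                    min_greater.left = node.right
                    min_greater = node.right
                    node = node.right.left
                    continue
                max_less.right = node  # zag: link left and stop
                max_less = node
                node = node.right
                break
            break  # payload found
        # Reassemble: attach the remaining subtrees and the collected trees.
        max_less.right = node.left
        min_greater.left = node.right
        node.left = header.right
        node.right = header.left
        return node

    def insert(self, payload):
        node = self.splay(payload)
        if payload == node.payload:
            return node
        if payload < node.payload:
            tmp, node.left = node.left, None
            return PyNode(payload, tmp, node)
        tmp, node.right = node.right, None
        return PyNode(payload, node, tmp)


def inorder(node):
    """Iterative in-order traversal, as in _GenericOrderedSet.__iter__."""
    stack = []
    while stack or node is not None:
        if node is not None:
            stack.append(node)
            node = node.left
            continue
        node = stack.pop()
        yield node.payload
        node = node.right
```

Splaying a payload that is present always brings it to the root. Splaying a value larger than every payload brings the maximum to the root, which is the behaviour `intersection_update` and `symmetric_difference_update` rely on when they splay `minimum()` results.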