Operator overloading
====================

:hourglass: 3h

**Outline**:
1. First example
2. Comparison overloading
3. Arithmetic overloading
4. Shift operators
5. Unary operators
6. Exercises
7. Closing words

## 1. First example

:hourglass: 40 min

The example below overloads two operators, `|` and `+`.

In [15]:
from __future__ import annotations

from typing import Union
import random


class Gram:
    def create_choice(self, other: Union[Gram, str]) -> OrRule:
        return OrRule(self, other).flatten()
    
    def create_rule(self, other: Union[Gram, str]) -> CFRule:
        return CFRule(self, other)
    
    def __or__(self, other: Union[Gram, str]) -> OrRule:
        return self.create_choice(other)
    
    def __add__(self, other: Union[Gram, str]) -> CFRule:
        return self.create_rule(other)
        

class Rule(Gram):
    def __init__(self, *operands: Union[Gram, str]) -> None:
        super().__init__()
        def cast(x: Union[Gram, str]) -> Gram:
            if isinstance(x, str):
                x = Terminal(x)
            return x
        
        self._operands = tuple(cast(x) for x in operands)

class OrRule(Rule):
    def __str__(self) -> str:
        return str(random.choice(self._operands))
    
    def flatten(self) -> OrRule:
        ops = []
        for operand in self._operands:
            if isinstance(operand, OrRule):
                operand = operand.flatten()
                ops.extend(operand._operands)
            else:
                ops.append(operand)
        
        return OrRule(*ops)
    
class CFRule(Rule):

    def __str__(self) -> str:
        return " ".join(str(x) for x in self._operands)


class Terminal(Gram):
    def __init__(self, symbol: str) -> None:
        self._symbol = symbol

    def __str__(self) -> str:
        return self._symbol


In [17]:
protagonist = (
    Terminal("Harry Potter") | "Luke Skywalker" | "Frodo Baggins"
)


actions = Terminal("must find") | "must destroy"

macguffin = (
    Terminal("the horcruxes") | "the death star" | "the one true ring"
)

victory = Terminal("to defeat") | "to beat" | "to vanquish" | "to rid the world of"

enemy = Terminal("Voldemort") | "the empire" | "Sauron"


story = protagonist + actions + macguffin + victory + enemy

for _ in range(5):
    print(story)


Harry Potter must destroy the death star to defeat Voldemort
Frodo Baggins must find the death star to beat Sauron
Luke Skywalker must destroy the one true ring to defeat Voldemort
Harry Potter must destroy the one true ring to beat Voldemort
Luke Skywalker must find the horcruxes to beat the empire


Overloading `|` and `+` does not let us do anything that plain method calls (`create_choice`, `create_rule`) could not, so:

:question: Why use operator overloading?

:question: What should you be careful with when dealing with operator overloading?

:question: What are good use cases for operator overloading?

:skull: The example above is about context-free grammars. The formal notion of grammar is used extensively in CS, from computation theory to compilers, by way of procedural generation, fractal drawing and natural language processing.

## 2. Comparison overloading

:hourglass: 30 min

You can overload the following comparison operators in Python:

| Operator         | Symbol | Dunder   |
|------------------|--------|----------|
| equal            | ==     | `__eq__` |
| not equal        | !=     | `__ne__` |
| less than        | <      | `__lt__` |
| less or equal    | <=     | `__le__` |
| greater than     | >      | `__gt__` |
| greater or equal | >=     | `__ge__` |

All those relationships are binary, so the methods always take another element as input.

Related to comparison, you can also override the `__bool__` type-cast method of an object so that it can be used as a boolean.


In [4]:
from abc import ABCMeta, abstractmethod
from typing import Any

import pandas as pd


class Test(metaclass=ABCMeta):
    def __init__(self, column_name: str, reference: float) -> None:
        self._column_name = column_name
        self._reference = reference

    @abstractmethod
    def __call__(self, df: pd.DataFrame) -> "TestResult":
        raise NotImplementedError()
    
    def __repr__(self) -> str:
        return f"{self.__class__.__qualname__}({self._column_name!r}, {self._reference!r})"
    
class TestResult:
    def __init__(self, test: Test, result: bool) -> None:
        self._test = test
        self._result = result

    def __repr__(self) -> str:
        return f"{self.__class__.__qualname__}({self._test!r}, {self._result!r})"

    def __bool__(self) -> bool:
        return self._result
    
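class LeTest(Test):
    def __call__(self, df: pd.DataFrame) -> TestResult:
        # Wrap the outcome in a TestResult, as the signature promises
        return TestResult(self, bool((df[self._column_name] <= self._reference).all()))

class EqTest(Test):
    def __call__(self, df: pd.DataFrame) -> TestResult:
        # Wrap the outcome in a TestResult, as the signature promises
        return TestResult(self, bool((df[self._column_name] == self._reference).all()))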
    
class PlaceHolder:
    def __init__(self, column_name: str) -> None:
        self._column_name = column_name

    def __eq__(self, other: float) -> EqTest:
        return EqTest(self._column_name, other)

    def __le__(self, other: float) -> LeTest:
        return LeTest(self._column_name, other)


df = pd.DataFrame({"col": [1, 2, 3, 4]})

test1 = PlaceHolder("col") <= 5
test2 = PlaceHolder("col") == 2

if test1(df):
    print("This should be false:", bool(test2(df)))


This should be false: False


> Type-casting dunders are also available for `int`, `float`, and of course `str`.
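
As a minimal sketch of these conversion dunders, the hypothetical `Temperature` class below (not part of the course material) implements `__int__`, `__float__`, `__str__` and `__bool__`:

```python
class Temperature:
    """Hypothetical value object, used only to illustrate the conversion dunders."""

    def __init__(self, celsius: float) -> None:
        self._celsius = celsius

    def __int__(self) -> int:
        # Called by int(t)
        return int(self._celsius)

    def __float__(self) -> float:
        # Called by float(t)
        return float(self._celsius)

    def __str__(self) -> str:
        # Called by str(t), print(t) and f"{t}"
        return f"{self._celsius:.1f} C"

    def __bool__(self) -> bool:
        # Called by bool(t) and by `if t:`
        return self._celsius != 0


t = Temperature(21.7)
print(int(t))    # 21
print(float(t))  # 21.7
print(str(t))    # 21.7 C
print(bool(t))   # True
```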

> If you want to provide all the comparison methods but only want to implement the equality and one inequality comparison, see [`functools.total_ordering`](https://docs.python.org/3/library/functools.html#functools.total_ordering).
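
A short sketch of `functools.total_ordering` in use, assuming a hypothetical `Version` class introduced only for illustration. Only `__eq__` and `__lt__` are written by hand; the decorator derives the other comparisons.

```python
from functools import total_ordering


@total_ordering
class Version:
    """Hypothetical (major, minor) version number, used for illustration only."""

    def __init__(self, major: int, minor: int) -> None:
        self._key = (major, minor)

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Version):
            return NotImplemented
        return self._key == other._key

    def __lt__(self, other: object) -> bool:
        if not isinstance(other, Version):
            return NotImplemented
        return self._key < other._key

    def __hash__(self) -> int:
        # Defining __eq__ disables the default hash, so restore one explicitly.
        return hash(self._key)


# __le__, __gt__ and __ge__ are generated by the decorator.
print(Version(1, 2) <= Version(1, 3))  # True
print(Version(2, 0) > Version(1, 9))   # True
print(Version(1, 0) >= Version(1, 1))  # False
```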

## 3. Arithmetic overloading

### Exposition

:hourglass: 20 min

You can overload the following arithmetic and bitwise operators in Python:

| Operator         | Symbol | Dunder         |
|------------------|--------|----------------|
| add              | +      | `__add__`      |
| subtract         | -      | `__sub__`      |
| multiply         | *      | `__mul__`      |
| divide           | /      | `__truediv__`  |
| divide (integer) | //     | `__floordiv__` |
| power            | **     | `__pow__`      |
| remainder        | %      | `__mod__`      |
| matrix mult.     | @      | `__matmul__`   |
| and              | &      | `__and__`      |
| or               | \|     | `__or__`       |
| xor              | ^      | `__xor__`      |

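Each operator comes with two variants: the reflected `r`-operators (e.g. `__radd__`), used when the object is the right-hand operand, and the in-place `i`-operators (e.g. `__iadd__`), used by augmented assignments such as `+=`.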

In [11]:
class MyFloat:
    def __init__(self, x):
        self.x = x

    def __add__(self, other):
        return self.x + other

class MyRFloat:
    def __init__(self, x):
        self.x = x

    def __add__(self, other):
        return self.x + other

    def __radd__(self, other):
        return other + self.x


print("Case 1:", MyRFloat(2) + 3)
print("Case 2:", 3 + MyRFloat(2))
print("Case 3:", MyFloat(2) + 3)
print("Case 4:", 3 + MyFloat(2))

Case 1: 5
Case 2: 5
Case 3: 5


TypeError: unsupported operand type(s) for +: 'int' and 'MyFloat'
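
Case 4 fails because `int.__add__` does not know what to do with a `MyFloat`, and `MyFloat` offers no `__radd__` for Python to fall back on. The sketch below uses a hypothetical `Meters` class (an assumption for illustration) to show the usual pattern: return `NotImplemented` for an operand you cannot handle, so that Python tries the reflected method of the other operand and raises `TypeError` only if that fails too.

```python
class Meters:
    """Hypothetical length type, used only to illustrate operand dispatch."""

    def __init__(self, value: float) -> None:
        self.value = value

    def __add__(self, other):
        if isinstance(other, Meters):
            return Meters(self.value + other.value)
        if isinstance(other, (int, float)):
            return Meters(self.value + other)
        # Unsuitable operand: let Python try the reflected method of `other`.
        return NotImplemented

    def __radd__(self, other):
        # Called for `other + self` when type(other).__add__ returned NotImplemented.
        # Addition is commutative here, so delegate to __add__.
        return self.__add__(other)


print((Meters(2) + 3).value)  # 5
print((3 + Meters(2)).value)  # 5

try:
    Meters(2) + "three"
except TypeError as error:
    # Neither Meters.__add__ nor str can handle this pair of operands.
    print("TypeError:", error)
```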

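The `i`-variant implements the in-place operators. Note that the contract is different: the method mutates the object and must return it (usually `self`), because the returned value is assigned back to the name: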

In [16]:
from __future__ import annotations

class MyNewFloat:
    def __init__(self, x):
        self.x = x

    def __add__(self, other: float) -> float:
        return self.x + other

    def __iadd__(self, other: float) -> MyNewFloat:
        self.x += other
        return self  # Must return self!

mnf = MyNewFloat(2)
mnf += 3
mnf.x

5
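
When a class does not define `__iadd__`, `x += y` falls back on `__add__` and rebinds the name to the result. The sketch below contrasts the two behaviours; the `Immutable` and `Mutable` classes are hypothetical names chosen for illustration.

```python
class Immutable:
    """Hypothetical wrapper that only defines __add__."""

    def __init__(self, x: float) -> None:
        self.x = x

    def __add__(self, other: float) -> "Immutable":
        return Immutable(self.x + other)


class Mutable(Immutable):
    """Same wrapper, with an in-place variant."""

    def __iadd__(self, other: float) -> "Mutable":
        self.x += other
        return self


a = Immutable(2)
alias_a = a
a += 3  # no __iadd__: evaluated as a = a + 3, so `a` is rebound to a new object
print(a.x, alias_a.x, a is alias_a)  # 5 2 False

b = Mutable(2)
alias_b = b
b += 3  # __iadd__ mutates the object, so every alias sees the change
print(b.x, alias_b.x, b is alias_b)  # 5 5 True
```

This is the mutability aspect to keep in mind: with `__iadd__`, any other reference to the same object observes the update.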

In most cases, operator overloading is essentially syntactic sugar. Here is an example that goes beyond cosmetic aspects:

In [7]:
from __future__ import  annotations

from typing import Union

import pandas as pd

class Ratio:
    def __init__(self, numerator: int, denominator: int) -> None:
        self._numerator = numerator
        self._denominator = denominator

    def simplify(self) -> Ratio:
        return self  # TODO

    def __add__(self, other: Union[Ratio, int]) -> Ratio:
        if isinstance(other, int):
            other = Ratio(other, 1)
        
        return Ratio(
            # a/b + c/d = (a*d + c*b) / (b*d)
            self._numerator * other._denominator + other._numerator * self._denominator,
            self._denominator * other._denominator
        ).simplify()

    def __repr__(self) -> str:
        return f"{self.__class__.__qualname__}({self._numerator!r}, {self._denominator!r})"


pd.DataFrame([Ratio(1, 3), 2, Ratio(4, 5)]).sum()

0    Ratio(47, 15)
dtype: object

> :skull: The `operator` module from the Python standard library exposes the operators as regular functions (e.g. `operator.add`).
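
A few illustrative calls to the `operator` module; the `OPS` lookup table is a hypothetical helper introduced only for this sketch.

```python
import operator
from functools import reduce

print(operator.add(2, 3))     # 5, same as 2 + 3
print(operator.lt(2, 3))      # True, same as 2 < 3
print(operator.neg(4))        # -4, same as -(4)
print(operator.lshift(1, 4))  # 16, same as 1 << 4

# Functions are handy wherever a callable is expected.
print(reduce(operator.mul, [1, 2, 3, 4]))  # 24

# A small lookup table from symbol to function (illustrative only).
OPS = {"+": operator.add, "-": operator.sub, "*": operator.mul, "/": operator.truediv}
print(OPS["/"](7, 2))  # 3.5
```

These functions call the same dunder methods as the symbols, so they also work with classes that overload the operators.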

## 4. Shift operators

:hourglass: 10 min

You can overload the following bit shift operators in Python (+ the `r` and `i` variants):


| Operator     | Symbol | Dunder         |
|--------------|--------|----------------|
| right shift  | >>     | `__rshift__`   |
| left shift   | <<     | `__lshift__`   |

Technically, the bit shift operators are used to manipulate the bits in binary data. When they are overloaded, it is usually to take advantage of the directionality of the symbols. 

It tends to be used in the following cases:
- directed graphs (e.g. Airflow);
- move files around.
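
As a rough sketch of the directionality idea, the hypothetical `Task` class below uses `>>` to declare that one step runs before another. It only mimics the look of such pipelines and is not Airflow's API.

```python
from __future__ import annotations


class Task:
    """Hypothetical pipeline step; this is not Airflow's API."""

    def __init__(self, name: str) -> None:
        self.name = name
        self.downstream: list[Task] = []

    def __rshift__(self, other: Task) -> Task:
        # `a >> b` reads "a, then b": record b as a successor of a.
        self.downstream.append(other)
        # Return the right operand so that `a >> b >> c` chains left to right.
        return other

    def __lshift__(self, other: Task) -> Task:
        # `a << b` reads "a, after b": record a as a successor of b.
        other.downstream.append(self)
        return other


extract, transform, load = Task("extract"), Task("transform"), Task("load")
extract >> transform >> load

print([t.name for t in extract.downstream])    # ['transform']
print([t.name for t in transform.downstream])  # ['load']
```

Returning the right-hand operand from `__rshift__` is a design choice that makes chains read naturally; returning `self` would instead attach every task to the first one.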

> The matrix multiplication operator tends to be abused in a similar way, as `@` can express the idea of a context.

## 5. Unary operators

:hourglass: 10 min

You can overload the following unary operators in Python:

| Operator  | Symbol | Dunder       |
|-----------|--------|--------------|
| negative  | -      | `__neg__`    |
| positive  | +      | `__pos__`    |
| inversion | ~      | `__invert__` |

Here is an example:

In [18]:
from __future__ import annotations

class MyStr(str):
    def __invert__(self) -> MyStr:
        return MyStr(self[::-1])

s = MyStr("this is a string")
print(s)
print(~s)

this is a string
gnirts a si siht


## 6. Exercises

:hourglass: 30 min + :coffee: 15 min

- Implement a dice class. A dice has $k$ sides and each side is rolled with probability $\frac{1}{k}$. Using `int(d)` where `d` is a dice should return the result of a dice roll. A dice pool is a collection of dice; using `int(p)` where `p` is a pool should return the sum of the dice rolls. Dice should be summable, and multiplying a dice by a scalar $n$ should yield $n$ such dice.
- Implement graphs. A directed graph is a collection of nodes and edges. Edges between nodes can be specified via the shift operators. You can implement a simple graph algorithm, such as checking whether there is a path between two nodes.
- Implement symbolic computation. You can add/subtract variables, and/or add/multiply/divide/subtract constants. You can use this to create functions or to solve simple sets of equations (e.g. one quadratic function, or two linear functions of two variables).

In [12]:
import random


class Dice:
    def __init__(self, n_sides: int, n_dice: int = 1) -> None:
        self._n_sides = n_sides
        self._n_dice = n_dice

    @property
    def n_sides(self) -> int:
        return self._n_sides
    
    def __str__(self) -> str:
        return f"{self._n_dice}d{self._n_sides}"
    
    def __repr__(self) -> str:
        return f"{self.__class__.__qualname__}({self._n_sides}, {self._n_dice})"

    def __int__(self) -> int:
        return sum(random.randint(1, self._n_sides) for _ in range(self._n_dice))
    
    # TODO implement addition and scalar multiplication

class Pool:
    def __init__(self, *dices: Dice) -> None:
        self._dices = tuple(sorted(dices, key=lambda d: d.n_sides))

    def __str__(self) -> str:
        return " ".join(str(x) for x in self._dices)
    
    def __repr__(self) -> str:
        dice_str = ", ".join(str(x) for x in self._dices)
        return f"{self.__class__.__qualname__}({dice_str})"
    
    def __int__(self) -> int:
        raise NotImplementedError()
    
    # TODO implement addition and scalar multiplication

    

    

# What is the probability of beating (i.e. reaching or surpassing) a 10 with 2d6 + 1d4?
t = 10
successes = 0
N_TRIALS = 10000
_d = Dice  # Shortcut for style
for _ in range(N_TRIALS):
    outcome = int(2*_d(6) + _d(4))
    if outcome >= t:  # "reaching or surpassing" means >=
        successes += 1

successes / N_TRIALS

1

## 7. Closing words

:hourglass: 10 min

:question: How should you handle an unsuitable operand?

In this module, we discussed operator overloading. In Python you can override the behavior of many operators: comparison, arithmetic (including bit shifting) and unary operators. 

The most important point to remember about operator overloading is that users should know what to expect from each operation. To that end, be mindful of the semantics of the operations, the in-place/mutability aspect, and operator precedence.
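
On that last point, overloading an operator changes what it does but not how tightly it binds. The sketch below uses a hypothetical tracing class, `Expr` (introduced only for illustration), whose operators return a parenthesised string showing how Python grouped the expression.

```python
class Expr:
    """Hypothetical tracing type: each operator returns a parenthesised string."""

    def __init__(self, text: str) -> None:
        self.text = text

    def __add__(self, other: "Expr") -> "Expr":
        return Expr(f"({self.text} + {other.text})")

    def __or__(self, other: "Expr") -> "Expr":
        return Expr(f"({self.text} | {other.text})")

    def __rshift__(self, other: "Expr") -> "Expr":
        return Expr(f"({self.text} >> {other.text})")


a, b, c = Expr("a"), Expr("b"), Expr("c")

# `+` binds tighter than `>>`, which binds tighter than `|`.
print((a | b + c).text)   # (a | (b + c))
print((a >> b + c).text)  # (a >> (b + c))
print((a | b >> c).text)  # (a | (b >> c))
```

In the grammar example of section 1, this means `x | y + z` builds a choice between `x` and the sequence `y + z`; explicit parentheses are needed for the other reading.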


**Dunderscore**: too many to list.