# 正确重载运算符

## 运算符重载基础
Python施加了一些限制，做好了灵活性、可用性和安全性方面的平衡:

+ 不能重载内置类型的运算符
+ 不能新建运算符，只能重载现有的
+ 某些运算符不能重载—— is,and,or,not（不过位运算符 &,|,~可以）

## 一元运算符

+ `- (__neg__)`：负
+ `+ (__pos__)`:正
+ `~ (__invert__)`: 对整数位按位取反

支持一元运算符很简单，只需要实现相应的特殊方法。这些特殊方法只有一个参数，self。然后，使用符合所在类的逻辑实现。
不过，要遵守运算符的一个基本规则：始终返回一个新对象。也就是说，不能修改self，要创建并返回合适类型的新实例。


## 重载向量加法运算符+



In [5]:
import numbers
import itertools
from array import array
import reprlib

class Vector:
    typecode = 'd'
    shortcut_names = 'xyzt'
    
    def __init__(self, components):
        self._components = array(self.typecode, components)
        
    def __iter__(self):
        return iter(self._components)
    
    def __repr__(self):
        components = reprlib.repr(self._components)
        components = components[components.find('['):-1]
        return 'Vector({})'.format(components)
    
    def __str__(self):
        return str(tuple(self))
    
    def __bytes__(self):
        return (bytes([ord(self.typecode)])+
               bytes(self._components))
    
    def __eq__(self, other):
        return tuple(self) == tuple(other)
    
    def __abs__(self):
        return math.sqrt(sum(x * x for x in self))
    
    def __bool__(self):
        return bool(abs(self))
    
    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv)
    
    def __len__(self):
        return len(self._components)
    
    def __getitem__(self, index):
        cls = type(self)
        if isinstance(index, slice):
            # 如果是切片,构造一个对象
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral):
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            raise TypeError(msg.format(cls=cls))
            
    def __getattr__(self, name):
        cls = type(self)
        if len(name) == 1:
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = '{.__name__!r} object has no attribute {!r}'
        raise AttributeError(msg.format(cls, name))
        
    def __add__(self, other):
        pairs = itertools.zip_longest(self, other, fillvalue=0.0)
        return Vector(a+b for a,b in pairs)

In [8]:
v1 = Vector([3,4,5,6])
v2 = Vector([1,2])
v1+v2

Vector([4.0, 6.0, 5.0, 6.0])

In [9]:
v1 + (10,20,30)

Vector([13.0, 24.0, 35.0, 6.0])

之所以适用于其他对象，因为我们使用了zip_longest,它能处理任何可迭代对象。不过对调对象就会报错。

为了支持涉及不同类型的运算，Python为中缀方法提供了特殊的分派机制。对`a+b`来说，解释器会执行以下几个操作:

1.如果a有`__add__`方法，而且返回值不是`NotImplemented`，调用`a.__add__(b)`,然后返回结果
2.如果a没有`__add__`方法,或者返回值是`NotImplemented`，检查b有没有`__radd__`方法，如果有，且没有返回`NotImplemented`,则调用`b.__radd__(a)`
3.如果b没有`__radd__`方法,或者调用`__radd__`方法返回`NotImplemented`，抛出`TypeError`，并在错误消息中致命操作数类型不支持。

In [13]:
import numbers
import itertools
from array import array
import reprlib

class Vector:
    typecode = 'd'
    shortcut_names = 'xyzt'
    
    def __init__(self, components):
        self._components = array(self.typecode, components)
        
    def __iter__(self):
        return iter(self._components)
    
    def __repr__(self):
        components = reprlib.repr(self._components)
        components = components[components.find('['):-1]
        return 'Vector({})'.format(components)
    
    def __str__(self):
        return str(tuple(self))
    
    def __bytes__(self):
        return (bytes([ord(self.typecode)])+
               bytes(self._components))
    
    def __eq__(self, other):
        return tuple(self) == tuple(other)
    
    def __abs__(self):
        return math.sqrt(sum(x * x for x in self))
    
    def __bool__(self):
        return bool(abs(self))
    
    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv)
    
    def __len__(self):
        return len(self._components)
    
    def __getitem__(self, index):
        cls = type(self)
        if isinstance(index, slice):
            # 如果是切片,构造一个对象
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral):
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            raise TypeError(msg.format(cls=cls))
            
    def __getattr__(self, name):
        cls = type(self)
        if len(name) == 1:
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = '{.__name__!r} object has no attribute {!r}'
        raise AttributeError(msg.format(cls, name))
        
    def __add__(self, other):
        pairs = itertools.zip_longest(self, other, fillvalue=0.0)
        return Vector(a+b for a,b in pairs)
    
    def __radd__(self, other):
        return self + other

In [15]:
v1 = Vector([3,4,5,6])
v1 + 1

TypeError: zip_longest argument #2 must support iteration

In [16]:
(2,3,4)+v1

Vector([5.0, 7.0, 9.0, 6.0])

In [17]:
v1 + 'ABC'

TypeError: unsupported operand type(s) for +: 'float' and 'str'

上面返回的报错不是很清晰。我们应该让他返回`NotImplemented`

```python
def __add__(self,other):
    try:
        pairs = itertools.zip_longest(self,other,fillvalue=0)
        return Vector(a+b for a,b in pairs)
    except TypeError:
        return NotImplemented

def __radd__(self,other):
    return self + other

```

## 重载乘法运算符*

In [18]:
import numbers
import itertools
from array import array
import reprlib
import math

class Vector:
    typecode = 'd'
    shortcut_names = 'xyzt'
    
    def __init__(self, components):
        self._components = array(self.typecode, components)
        
    def __iter__(self):
        return iter(self._components)
    
    def __repr__(self):
        components = reprlib.repr(self._components)
        components = components[components.find('['):-1]
        return 'Vector({})'.format(components)
    
    def __str__(self):
        return str(tuple(self))
    
    def __bytes__(self):
        return (bytes([ord(self.typecode)])+
               bytes(self._components))
    
    def __eq__(self, other):
        return tuple(self) == tuple(other)
    
    def __abs__(self):
        return math.sqrt(sum(x * x for x in self))
    
    def __bool__(self):
        return bool(abs(self))
    
    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv)
    
    def __len__(self):
        return len(self._components)
    
    def __getitem__(self, index):
        cls = type(self)
        if isinstance(index, slice):
            # 如果是切片,构造一个对象
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral):
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            raise TypeError(msg.format(cls=cls))
            
    def __getattr__(self, name):
        cls = type(self)
        if len(name) == 1:
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = '{.__name__!r} object has no attribute {!r}'
        raise AttributeError(msg.format(cls, name))
        
    def __add__(self,other):
        try:
            pairs = itertools.zip_longest(self,other,fillvalue=0)
            return Vector(a+b for a,b in pairs)
        except TypeError:
            return NotImplemented   
    
    def __radd__(self, other):
        return self + other
    
    def __mul__(self, scalar):
        if isinstance(scalar, numbers.Real):
            return Vector(n*scalar for n in self)
        else:
            return NotImplementedError
        
    def __rmul__(self, scalar):
        return self * scalar

In [19]:
v1 = Vector([1.0, 2.0, 3.0])
14 * v1

Vector([14.0, 28.0, 42.0])

In [20]:
v1 * True

Vector([1.0, 2.0, 3.0])

In [21]:
from fractions import Fraction
v1 * Fraction(1, 3)

Vector([0.3333333333333333, 0.6666666666666666, 1.0])

## 众多比较运算符

Python对比较运算符的处理与前文类似，不过在两个方面有重大区别。

+ 正向和反向调用使用的是同一系列方法。正向的`__gt__`调用的是反向的`__lt__`
+ 对`==`和`!=`来说，如果反向调用失败，Python会比较对象的id，而不抛出`TypeError`


In [24]:
import numbers
import itertools
from array import array
import reprlib
import math

class Vector:
    typecode = 'd'
    shortcut_names = 'xyzt'
    
    def __init__(self, components):
        self._components = array(self.typecode, components)
        
    def __iter__(self):
        return iter(self._components)
    
    def __repr__(self):
        components = reprlib.repr(self._components)
        components = components[components.find('['):-1]
        return 'Vector({})'.format(components)
    
    def __str__(self):
        return str(tuple(self))
    
    def __bytes__(self):
        return (bytes([ord(self.typecode)])+
               bytes(self._components))
    
    def __eq__(self, other):
        if isinstance(other, Vector):
            return (len(self) == len(other) and
                    all(a == b for a, b in zip(self, other)))
        else:
            return NotImplementedError
    
    def __abs__(self):
        return math.sqrt(sum(x * x for x in self))
    
    def __bool__(self):
        return bool(abs(self))
    
    @classmethod
    def frombytes(cls, octets):
        typecode = chr(octets[0])
        memv = memoryview(octets[1:]).cast(typecode)
        return cls(memv)
    
    def __len__(self):
        return len(self._components)
    
    def __getitem__(self, index):
        cls = type(self)
        if isinstance(index, slice):
            # 如果是切片,构造一个对象
            return cls(self._components[index])
        elif isinstance(index, numbers.Integral):
            return self._components[index]
        else:
            msg = '{cls.__name__} indices must be integers'
            raise TypeError(msg.format(cls=cls))
            
    def __getattr__(self, name):
        cls = type(self)
        if len(name) == 1:
            pos = cls.shortcut_names.find(name)
            if 0 <= pos < len(self._components):
                return self._components[pos]
        msg = '{.__name__!r} object has no attribute {!r}'
        raise AttributeError(msg.format(cls, name))
        
    def __add__(self,other):
        try:
            pairs = itertools.zip_longest(self,other,fillvalue=0)
            return Vector(a+b for a,b in pairs)
        except TypeError:
            return NotImplemented   
    
    def __radd__(self, other):
        return self + other
    
    def __mul__(self, scalar):
        if isinstance(scalar, numbers.Real):
            return Vector(n*scalar for n in self)
        else:
            return NotImplementedError
        
    def __rmul__(self, scalar):
        return self * scalar

In [25]:
va = Vector([1.0, 2.0, 3.0])
vb = Vector(range(1,4))
va == vb

True

## 增量赋值运算符



In [26]:
v1 = Vector([1,2,3])
v1_alias = v1

In [27]:
id(v1)

4383413416

In [28]:
v1 += Vector([4,5,6])
v1

Vector([5.0, 7.0, 9.0])

In [29]:
id(v1)

4383412912

In [30]:
v1_alias

Vector([1.0, 2.0, 3.0])

In [31]:
id(v1_alias)

4383413416

如果一个类没有实现就地运算符，那么增量运算符就是语法糖，作用与 a = a+b一样。对于不可变类型，这是预期的行为，所以会产生新对象。

但是假如实现了就地运算符，例如`__iadd__`，计算a+=b时会调用就地运算符。

In [66]:
import abc

class Tombola(abc.ABC):
    
    @abc.abstractmethod
    def load(self, iterable):
        """从可迭代对象中添加元素"""
        
    @abc.abstractmethod
    def pick(self):
        """随机删除元素，然后将其返回
        
        如果实例为空，这个方法应该抛出`LookupError`
        """
        
    def loaded(self):
        """如果至少有一个元素，返回True,否则返回False"""
        return bool(self.inspect())
    
    def inspect(self):
        """返回一个有序元组，由当前元素构成"""
        items = []
        while True:
            try:
                items.append(self.pick())
            except LookupError:
                break
        self.load(items)
        return tuple(sorted(items))
    
import random

class BingoCage(Tombola):
    
    def __init__(self, items):
        self._randomizer = random.SystemRandom()
        self._items = []
        self.load(items)
        
    def load(self, items):
        self._items.extend(items)
        self._randomizer.shuffle(self._items)
    
    def pick(self):
        try:
            return self._items.pop()
        except IndexError:
            raise LookupError('pick from empty BingoCage')
            
    def __call__(self):
        self.pick()
        
# 可以不实现一些方法，直接去继承原先的方法。

In [67]:
class AddableBingoCage(BingoCage):
    
    def __add__(self, other):
        if isinstance(other, Tombola):
            return AddableBingoCage(self.inspect() + other.inspect())
        else:
            return NotImplemented
        
    def __iadd__(self, other):
        if isinstance(other, Tombola):
            other_iterable = other.inspect()
        else:
            try:
                other_iterable = iter(other)
            except TypeError:
                self_cls = type(self).__name__
                msg = "right operand in += must be {!r} or an iterable"
                raise TypeError(msg.format(self_cls))
            self.load(other_iterable)
            return self

In [76]:
vowel = "AEIOU"
globe = AddableBingoCage(vowel)
globe.inspect()

('A', 'E', 'I', 'O', 'U')

In [77]:
globe.pick() in vowel

True

In [78]:
len(globe.inspect())

4

In [79]:
globe2 = AddableBingoCage('XYZ')

In [80]:
globe3 = globe + globe2
len(globe3.inspect())

7

In [81]:
# 保存一下
globe_origin = globe
len(globe.inspect())

4

In [82]:
globe += globe2
len(globe.inspect())

AttributeError: 'NoneType' object has no attribute 'inspect'