# 类与接口

Python是面向对象的语言，完整支持继承、多态、封装等机制
- 有复杂需求多使用类，而不是嵌套字典、元组、集合、列表等内置的类型
- 让简单的接口接受函数，而不是类的实例
    - python中有许多内置的API，都允许传入某个函数（hook函数/挂钩函数）来定制它的行为。
    - API执行过程中会回调（call back）这些挂钩函数，例如list中的sort方法就带key参数，可以传入挂钩函数
    - 某个类如果定义了__call__方法，那么它的实例可以像普通python函数一样调用
    - 如果想用函数来维护状态（带状态的闭包函数），可以考虑定义一个带__call__方法的类，而不要使用有状态的闭包去实现

## 1 继承
- 子类应通过super()调用超类的__init__来完成初始化，这样可以正确处理菱形继承结构（但仍不建议使用多继承）。Python 3中直接调用无参数的super()即可，编译器会自动传入__class__和self，无需写成super(类名, self)的双参数形式
- 用mix-in类来表示可组合的功能（可在一定程度上缓解多继承带来的问题）
## 2 多态
- 类的多态，通过@classmethod来构造同一体系中的各类对象
## 3 封装

In [3]:
# Simple interfaces should accept hook functions.

# list.sort() takes the builtin len as its `key` hook.
names = 'Socrates Archimedes Plato Aristotle'.split()
names.sort(key=len)  # stable sort, shortest name first
print(names)

# A custom hook handed to defaultdict as its default factory.
def log_missing():
    """Announce that a key is being created and supply its starting value, 0."""
    notice = 'Key added'
    print(notice)
    return 0

from collections import defaultdict

# defaultdict calls log_missing (with no arguments) for each missing key,
# so 'Key added' is printed once per new key.
current = {'green': 12, 'blue': 3}
increments = [('red', 5), ('blue', 17), ('orange', 9)]
result = defaultdict(log_missing, current)
print('Before:', dict(result))
for key, amount in increments:
    result[key] = result[key] + amount
print('After: ', dict(result))

# Anti-pattern kept for comparison: the hook keeps its state in a closure
# (via nonlocal), which is harder to follow than a small class.
def increment_with_report(current, increments):
    """Apply (key, amount) pairs on top of a copy of ``current``.

    Returns a ``(defaultdict, int)`` pair: the updated mapping and the number
    of keys that had to be created because they were missing.
    """
    created = 0

    def missing():
        # Stateful closure: bump the counter living in the enclosing scope.
        nonlocal created
        created += 1
        return 0

    totals = defaultdict(missing, current)  # dict() copy; `current` untouched
    for name, delta in increments:
        totals[name] += delta

    return totals, created

# Run the closure-based version: 'red' and 'orange' are missing from
# `current`, so the closure counts 2 created keys.
result, count = increment_with_report(current, increments)
assert count == 2
print(result)  # the repr shows the nested function as default factory

# Replace the stateful closure with a small class; its bound method is the hook.
class CountMissing:
    """Counter whose ``missing`` method can serve as a defaultdict factory."""

    def __init__(self):
        self.added = 0  # how many times missing() has been called

    def missing(self):
        """Record one more call and return the default value 0."""
        self.added = self.added + 1
        return 0

# Pass the bound method as the hook; the counter object carries the state.
counter = CountMissing()
result = defaultdict(counter.missing, current)  # Method ref
for key, amount in increments:
    result[key] += amount
assert counter.added == 2  # 'red' and 'orange' were missing from `current`
print(result)

# Better: define __call__ so the instance itself is the hook (a callable object).
class BetterCountMissing:
    """Callable counter: each call records itself and returns the default 0."""

    def __init__(self):
        self.added = 0  # number of times the instance has been called

    def __call__(self):
        self.added = self.added + 1
        return 0

# Sanity check: instances are callable and return the default value 0.
counter = BetterCountMissing()
assert counter() == 0
assert callable(counter)

# Use a fresh instance: the call above already bumped `added` to 1.
counter = BetterCountMissing()
result = defaultdict(counter, current)  # Relies on __call__
for key, amount in increments:
    result[key] += amount
assert counter.added == 2
print(result)

['Plato', 'Socrates', 'Aristotle', 'Archimedes']
Before: {'green': 12, 'blue': 3}
Key added
Key added
After:  {'green': 12, 'blue': 20, 'red': 5, 'orange': 9}
defaultdict(<function increment_with_report.<locals>.missing at 0x0000019D8EC34160>, {'green': 12, 'blue': 20, 'red': 5, 'orange': 9})
defaultdict(<bound method CountMissing.missing of <__main__.CountMissing object at 0x0000019D8F4C8100>>, {'green': 12, 'blue': 20, 'red': 5, 'orange': 9})
defaultdict(<__main__.BetterCountMissing object at 0x0000019D8EC48C40>, {'green': 12, 'blue': 20, 'red': 5, 'orange': 9})


In [2]:
# 避免生成临时文件
# Write all output to a temporary directory
import atexit
import gc
import io
import os
import tempfile

# Create a throwaway directory and make it the working directory, so relative
# paths used later in the notebook resolve inside it. atexit removes it on exit.
TEST_DIR = tempfile.TemporaryDirectory()
atexit.register(TEST_DIR.cleanup)

# Make sure Windows processes exit cleanly
# atexit runs handlers in reverse registration order, so this chdir back to
# the original directory runs *before* TEST_DIR.cleanup. The directory being
# removed is then no longer the current working directory (presumably the
# Windows problem referred to above).
OLD_CWD = os.getcwd()
atexit.register(lambda: os.chdir(OLD_CWD))
os.chdir(TEST_DIR.name)

def close_open_files():
    """Close every io.IOBase object the garbage collector currently tracks.

    NOTE(review): this may include the interpreter's own standard streams
    (sys.stdout etc.) if they are GC-tracked — confirm that is acceptable
    when this runs at exit.
    """
    # The genexp's source list is built once, up front, by gc.get_objects().
    open_streams = (obj for obj in gc.get_objects() if isinstance(obj, io.IOBase))
    for stream in open_streams:
        stream.close()

# Registered last, so atexit (last-in, first-out) closes lingering file
# objects before the chdir-back and directory-cleanup hooks run.
atexit.register(close_open_files)

<function __main__.close_open_files()>

In [3]:
# 使用@classmethod实现类的多态

import os
import random

# 用多线程统计一个文件夹下所有文件的行数总和
# 思路类似Spark等框架中的MapReduce：多个worker并行执行map，这里最后由第一个worker对其余结果做reduce

# Generic input-data base class.
class GenericInputData:
    """Abstract input source for a worker; subclasses implement both hooks."""

    def read(self):
        """Return this input's contents. Subclasses must override."""
        raise NotImplementedError

    @classmethod
    def generate_inputs(cls, config):
        """Generic construction interface: build ``cls`` instances from ``config``.

        Subclasses must override.
        """
        raise NotImplementedError

# Subclass that reads a single file; generate_inputs yields one instance per
# file found in a directory.
class PathInputData(GenericInputData):
    """Input data backed by one file on disk."""

    def __init__(self, path):
        super().__init__()
        self.path = path  # file that read() loads

    def read(self):
        """Return the whole file as text (platform-default encoding)."""
        with open(self.path) as f:
            return f.read()

    @classmethod
    def generate_inputs(cls, config):
        """Yield a ``cls`` instance for each regular file in ``config['data_dir']``.

        This implements the parent's generic interface. Entries that are not
        regular files are skipped, because ``read()`` cannot open a directory.
        Sub-directories such as Jupyter's ``.ipynb_checkpoints`` are examples.
        """
        data_dir = config['data_dir']
        for name in os.listdir(data_dir):
            path = os.path.join(data_dir, name)
            if os.path.isfile(path):
                yield cls(path)

# Generic worker base class.
class GenericWorker:
    """Base worker: wraps one input object and holds its partial result."""

    def __init__(self, input_data):
        self.input_data = input_data  # the input this worker processes
        self.result = None  # set by map() in subclasses

    def map(self):
        """Compute this worker's partial result. Subclasses must override."""
        raise NotImplementedError

    def reduce(self, other):
        """Merge ``other``'s partial result into this one. Subclasses must override."""
        raise NotImplementedError

    @classmethod
    def create_workers(cls, input_class, config):
        """Return a list holding one ``cls`` worker per generated input.

        ``input_class`` is meant to be a GenericInputData subclass. Only its
        ``generate_inputs(config)`` classmethod is used here.
        """
        return [cls(item) for item in input_class.generate_inputs(config)]
    
class LineCountWorker(GenericWorker):
    """Worker that counts newline characters in its input."""

    def map(self):
        """Read the input and store its newline count as this worker's result."""
        text = self.input_data.read()
        self.result = text.count('\n')

    def reduce(self, other):
        """Add another worker's line count into this one."""
        self.result = self.result + other.result

from threading import Thread

def execute(workers):
    """Run every worker's map() concurrently, then fold all results into the first.

    ``workers`` may be any iterable; it is materialized into a list first.
    Each map() runs on a thread pool with one slot per worker. Afterwards
    ``reduce`` is called on the first worker with each remaining worker, in
    order, and the first worker's ``result`` is returned.

    Raises:
        ValueError: if there are no workers, so nothing can be reduced into.
        Any exception raised by a worker's ``map()`` is re-raised here. With
        bare Thread objects it would only be printed, and would later show up
        as a confusing TypeError from ``reduce`` on a ``None`` result.
    """
    from concurrent.futures import ThreadPoolExecutor

    workers = list(workers)  # a generator would otherwise be consumed twice
    if not workers:
        raise ValueError('execute() needs at least one worker')

    with ThreadPoolExecutor(max_workers=len(workers)) as pool:
        futures = [pool.submit(w.map) for w in workers]
    # Leaving the with-block waits for every map() to finish (like join()).
    for future in futures:
        future.result()  # re-raise the first failure, if any

    first, *rest = workers
    for worker in rest:
        first.reduce(worker)  # merge remaining results into the first worker
    return first.result

def mapreduce(worker_class, input_class, config):
    """Build one ``worker_class`` worker per input and run the map/reduce pipeline.

    Returns whatever ``execute`` returns for those workers.
    """
    return execute(worker_class.create_workers(input_class, config))

def write_test_files(tmpdir, num_files=100, max_lines=100):
    """Create ``tmpdir`` and fill it with ``num_files`` files named '0', '1', ...

    Each file contains only newline characters. The count per file is chosen
    with ``random.randint(0, max_lines)``, i.e. 0 to ``max_lines`` inclusive.

    The directory may already exist, so the notebook cell can be re-run
    without FileExistsError. Files with the same names are overwritten.
    Extra files left by an earlier run with a larger ``num_files`` are kept.
    """
    os.makedirs(tmpdir, exist_ok=True)
    for i in range(num_files):
        with open(os.path.join(tmpdir, str(i)), 'w') as f:
            f.write('\n' * random.randint(0, max_lines))
            
# Generate sample inputs under a relative path (resolved against the current
# working directory — the temp dir, if the setup cell ran), then count their
# lines with the generic map/reduce pipeline.
tmpdir = 'test_inputs'
write_test_files(tmpdir)

config = {'data_dir': tmpdir}
# Any GenericWorker / GenericInputData subclasses could be passed here. Only
# one of each exists so far, but the pipeline itself needs no changes.
result = mapreduce(LineCountWorker, PathInputData, config)
print(f'There are {result} lines')

There are 5154 lines


In [8]:
# 调用super的__init__来初始化超类

class MyBaseClass:
    """Base class that stores the initial ``value``."""

    def __init__(self, value):
        self.value = value

class TimesTwo:
    """Doubles an existing ``value`` attribute; it never sets ``value`` itself.

    Its __init__ does not call super(), so a subclass must call it explicitly
    after ``value`` has been assigned.
    """

    def __init__(self):
        self.value = self.value * 2

class PlusFive:
    """Adds 5 to an existing ``value`` attribute; it never sets ``value`` itself.

    Its __init__ does not call super(), so a subclass must call it explicitly
    after ``value`` has been assigned.
    """

    def __init__(self):
        self.value = self.value + 5

class AnotherWay(MyBaseClass, PlusFive, TimesTwo): # inheritance order
    """Initializes by calling each base's __init__ explicitly."""

    def __init__(self, value): # initialization order
        # The explicit calls run in exactly this order, whatever the base-class
        # order above: set value, then double it, then add 5.
        MyBaseClass.__init__(self, value)
        TimesTwo.__init__(self)
        PlusFive.__init__(self)

# The result is (5 * 2) + 5 = 15. It follows the order of the explicit
# __init__ calls, not the order of the base classes in the class statement.
bar = AnotherWay(5)
print('Second ordering value is', bar.value)

# The diamond-inheritance initialization problem
class TimesSeven(MyBaseClass):
    """Sets value, then multiplies it by 7 (calls MyBaseClass.__init__ directly)."""

    def __init__(self, value):
        # Direct base call instead of super(): in a diamond (see ThisWay) a
        # sibling class calling MyBaseClass.__init__ again resets value.
        MyBaseClass.__init__(self, value)
        self.value *= 7

class PlusNine(MyBaseClass):
    """Sets value, then adds 9 (calls MyBaseClass.__init__ directly)."""

    def __init__(self, value):
        # Direct base call: this unconditionally re-assigns self.value, which
        # discards whatever an earlier initializer already computed.
        MyBaseClass.__init__(self, value)
        self.value += 9

class ThisWay(TimesSeven, PlusNine):
    """Diamond: both bases derive from MyBaseClass and call it directly."""

    def __init__(self, value):
        TimesSeven.__init__(self, value)  # value becomes value * 7
        # PlusNine calls MyBaseClass.__init__ again, resetting value to the
        # original argument before adding 9, so TimesSeven's work is lost.
        PlusNine.__init__(self, value)

foo = ThisWay(5)
# PlusNine's call to MyBaseClass.__init__ resets value to 5, giving 5 + 9.
print('Should be (5 * 7) + 9 = 44 but is', foo.value)
# Should be (5 * 7) + 9 = 44 but is 14

# Redefinition of MyBaseClass. The super()-based classes below inherit from
# this new class object; classes defined earlier keep the previous one.
class MyBaseClass:
    def __init__(self, value):
        self.value = value

class TimesSevenCorrect(MyBaseClass):
    """Multiplies value by 7 after the rest of the MRO has initialized it."""

    def __init__(self, value):
        # Zero-argument super() follows the instance's MRO, so the next
        # __init__ may be a sibling (PlusNineCorrect in GoodWay), not MyBaseClass.
        super().__init__(value)
        self.value *= 7

class PlusNineCorrect(MyBaseClass):
    """Adds 9 after the rest of the MRO has initialized value."""

    def __init__(self, value):
        # Cooperative call: delegates to the next class in the instance's MRO.
        super().__init__(value)
        self.value += 9

class GoodWay(TimesSevenCorrect, PlusNineCorrect):
    """Diamond built with super(): MyBaseClass.__init__ runs only once."""

    def __init__(self, value):
        # MRO: GoodWay -> TimesSevenCorrect -> PlusNineCorrect -> MyBaseClass.
        # The chain sets value, adds 9 on the way back, then multiplies by 7.
        super().__init__(value)

foo = GoodWay(5)
print('Should be 7 * (5 + 9) = 98 and is', foo.value)

# Print the MRO, one class per line, to show the initialization order.
mro_str = '\n'.join(map(repr, GoodWay.mro()))
print(mro_str)

Second ordering value is 15
Should be (5 * 7) + 9 = 44 but is 14
<class '__main__.ThisWay'>
<class '__main__.TimesSeven'>
<class '__main__.PlusNine'>
<class '__main__.MyBaseClass'>
<class 'object'>
Should be 7 * (5 + 9) = 98 and is 98
<class '__main__.GoodWay'>
<class '__main__.TimesSevenCorrect'>
<class '__main__.PlusNineCorrect'>
<class '__main__.MyBaseClass'>
<class 'object'>


In [None]:
# mix-in类

# 