In [3]:
# default_exp datastructure.generator

%reload_ext autoreload
%autoreload 2

# generator
https://www.liaoxuefeng.com/wiki/1016959663602400/1017318207388128

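A list comprehension lets us create a list directly. However, memory is limited, so a list's capacity is necessarily finite. Creating a list of 1 million elements also takes a lot of storage, and if we only need the first few elements, the space used by nearly all the rest is wasted.

So if the elements of a list can be computed by some algorithm, could we compute each next element on the fly while looping? Then we would never need to build the complete list, which saves a great deal of space. In Python, this mechanism of computing values while iterating is called a generator.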

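## Creating a generator, method 1: generator expressions
There are many ways to create a generator. The first is simple: replace the `[]` of a list comprehension with `()`, and you get a generator: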

In [1]:
L = [x * x for x in range(10)]
L

[0, 1, 4, 9, 16, 25, 36, 49, 64, 81]

In [2]:
g = (x * x for x in range(10))
g

<generator object <genexpr> at 0x10a774318>

The only difference between creating L and g is the outer `[]` versus `()`: L is a list, while g is a generator.
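To make the memory argument concrete, the sketch below compares the size of a million-element list with an equivalent generator expression using `sys.getsizeof`. The exact byte counts depend on the Python version and platform, so no output is shown; note that `getsizeof` on a list only counts the list's internal pointer array, not the int objects it refers to, so the real gap is even larger.

```python
import sys

# A list comprehension materialises every element up front...
big_list = [x * x for x in range(1_000_000)]
# ...while a generator expression only stores its current iteration state.
big_gen = (x * x for x in range(1_000_000))

list_bytes = sys.getsizeof(big_list)  # grows with the number of elements
gen_bytes = sys.getsizeof(big_gen)    # small, independent of the range size
print(list_bytes, gen_bytes)
```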

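We can print every element of a list directly, but how do we print the elements of a generator?

To get them one at a time, call the built-in `next()` function, which returns the generator's next value: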

In [3]:
next(g)

0

In [4]:
next(g)

1

In [5]:
next(g)

4

In [6]:
next(g), next(g), next(g), next(g), next(g), next(g), next(g)

(9, 16, 25, 36, 49, 64, 81)

In [7]:
next(g)

StopIteration: 

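As noted above, a generator stores the algorithm rather than the values: each call to `next(g)` computes the next element of g, and once the last element has been produced and nothing remains, it raises a `StopIteration` error.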
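If you do call `next()` by hand, the built-in also accepts an optional second argument: a default returned instead of raising `StopIteration` once the generator is exhausted. The small example below shows this, and also that an exhausted generator stays exhausted.

```python
g = (x * x for x in range(3))

first = next(g)            # 0
rest = [next(g), next(g)]  # [1, 4]
# With a default, next() returns it instead of raising StopIteration
done = next(g, None)       # None: the generator is exhausted
print(first, rest, done)   # 0 [1, 4] None
```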

Of course, calling `next(g)` over and over like this is impractical. The proper way is a `for` loop, because a generator is also iterable:

In [8]:
g = (x * x for x in range(3))
for n in g:
    print(n)

0
1
4


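So once we have created a generator, we almost never call `next()` on it directly. Instead we iterate over it with a `for` loop, which also takes care of `StopIteration` for us.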
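Roughly speaking, a `for` loop calls `iter()` on the object, then calls `next()` repeatedly and stops quietly when `StopIteration` is raised. The function below is a simplified model of that behaviour written for illustration (`manual_for` is a hypothetical name, and this is not CPython's actual implementation):

```python
def manual_for(iterable):
    """Collect items the way `for n in iterable` would visit them."""
    collected = []
    it = iter(iterable)        # a for-loop first calls iter() on the object
    while True:
        try:
            n = next(it)       # ...then calls next() repeatedly
        except StopIteration:  # ...and stops silently when it is raised
            break
        collected.append(n)
    return collected

print(manual_for(x * x for x in range(3)))  # [0, 1, 4]
```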

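Generators are very powerful. When the computation is too complex to express with a comprehension-style `for` expression, a generator can also be written as a function.

For example, in the well-known Fibonacci sequence, every number except the first two is the sum of the two numbers before it: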

1, 1, 2, 3, 5, 8, 13, 21, 34, ...

The Fibonacci sequence cannot easily be written as a list comprehension, but printing it with a function is easy:

In [9]:
def fib(max):
    n, a, b = 0, 0, 1
    while n < max:
        print(b)
        a, b = b, a + b  # tuple assignment: equivalent to t = (b, a + b); a = t[0]; b = t[1]
        n = n + 1
    return 'done'

In [10]:
fib(6)

1
1
2
3
5
8


'done'

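Looking closely, `fib` actually defines the rule for computing the Fibonacci sequence: starting from the first element, it can derive any later element. This logic is very similar to a generator.

## Creating a generator, method 2: generator functions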
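Any function whose body contains `yield` is a generator function. Calling it returns a generator object, which is an iterator: it implements both `__iter__` and `__next__`, so its elements can be retrieved by iterating over it.
Python requires every iterator to be iterable as well, so an iterator must also implement `__iter__`. Since `__iter__` must return an iterator, and the iterator already is one, an iterator's `__iter__` can simply return `self`.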
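To see the protocol without any generator magic, here is a hand-written iterator class equivalent to `(x * x for x in range(n))`. `Squares` is a hypothetical name for illustration; the point is that `__iter__` returns `self` and `__next__` raises `StopIteration` when done, which is exactly what a generator object does for you automatically.

```python
class Squares:
    """Hand-written iterator equivalent to (x * x for x in range(n))."""

    def __init__(self, n):
        self.i = 0
        self.n = n

    def __iter__(self):
        # An iterator must itself be iterable, so __iter__ returns self
        return self

    def __next__(self):
        if self.i >= self.n:
            raise StopIteration  # signal exhaustion, just like a generator
        value = self.i * self.i
        self.i += 1
        return value


sq = Squares(4)
print(iter(sq) is sq)  # True
print(list(sq))        # [0, 1, 4, 9]

gen = (x * x for x in range(4))
print(iter(gen) is gen, hasattr(gen, "__next__"))  # True True
```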

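In other words, the `fib` function above is only one step away from a generator: to turn it into one, just change `print(b)` to `yield b`.

This is the second way to define a generator. If a function definition contains the `yield` keyword, it is no longer an ordinary function but a generator function: calling it returns a generator.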
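The standard-library `inspect` module can confirm this distinction. In the sketch below (`plain` and `gen_func` are illustrative names), the only difference between the two functions is `yield`, yet one is reported as a generator function, and calling it runs none of its body until `next()` is called.

```python
import inspect

def plain():
    return 1

def gen_func():
    yield 1

# The mere presence of `yield` changes what kind of function this is
print(inspect.isgeneratorfunction(plain), inspect.isgeneratorfunction(gen_func))  # False True

obj = gen_func()  # calling it runs no body code; it just returns a generator object
print(inspect.isgenerator(obj), next(obj))  # True 1
```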



In [11]:
def fib(max):
    n, a, b = 0, 0, 1
    while n < max:
        yield b
        a, b = b, a + b
        n = n + 1
    return 'done'


In [12]:
f = fib(6)
f

<generator object fib at 0x10a774228>

In [13]:
next(f)

1
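Note that a `for` loop over `fib(6)` never sees the `'done'` returned at the end. A generator's `return` value is attached to the `StopIteration` it raises, so to get it you have to catch that exception yourself. The sketch below mirrors `fib` (with the parameter renamed to `limit` to avoid shadowing the built-in `max`):

```python
def fib(limit):
    n, a, b = 0, 0, 1
    while n < limit:
        yield b
        a, b = b, a + b
        n = n + 1
    return 'done'

g = fib(6)
values = []
while True:
    try:
        values.append(next(g))
    except StopIteration as e:
        # The generator's return value travels in StopIteration.value
        result = e.value
        break

print(values)  # [1, 1, 2, 3, 5, 8]
print(result)  # done
```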

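The hardest part to grasp here is that a generator runs differently from a function. A regular function executes sequentially and returns when it reaches a `return` statement or its last line. A generator function's body runs only when `next()` is called: it pauses and hands back a value at each `yield`, and on the next call it resumes right after the `yield` where it stopped.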
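A small trace makes the pause/resume behaviour visible. In this sketch (`step_gen` and `trace` are illustrative names), each step records itself in a list, so you can see that creating the generator runs nothing and each `next()` runs only up to the following `yield`:

```python
trace = []

def step_gen():
    trace.append("step 1")
    yield 1
    trace.append("step 2")
    yield 2
    trace.append("step 3")
    yield 3

o = step_gen()
print(trace)           # [] : creating the generator runs no body code yet
print(next(o), trace)  # 1 ['step 1']
print(next(o), trace)  # 2 ['step 1', 'step 2']
```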

# examples

In [14]:
import numpy as np


class DataGenerator(object):
    """
    Base class: data generator template; subclasses must implement __iter__().
    https://github.com/bojone/bert4keras/blob/master/bert4keras/snippets.py
    """
    def __init__(self, data, batch_size=32, buffer_size=None):
        self.data = data
        self.batch_size = batch_size
        if hasattr(self.data, '__len__'):
            self.steps = len(self.data) // self.batch_size
            if len(self.data) % self.batch_size != 0:
                self.steps += 1
        else:
            self.steps = None
        self.buffer_size = buffer_size or batch_size * 1000

    def __len__(self):
        return self.steps

    def sample(self, random=False):
        """
        Sampling function: yields each sample together with an is_end flag.
        :random:
            False: yield the data in order
            True: shuffle first, then yield the data
        """
        if random:
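            if self.steps is None:
                # Length unknown (e.g. a streaming iterable): approximate a shuffle
                # with a buffer that holds up to buffer_size samples.
                def generator():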
                    caches, isfull = [], False
                    for d in self.data:
                        caches.append(d)
                        if isfull:
                            i = np.random.randint(len(caches))
                            yield caches.pop(i)
                        elif len(caches) == self.buffer_size:
                            isfull = True
                    while caches:
                        i = np.random.randint(len(caches))
                        yield caches.pop(i)

            else:
                # Length known: shuffle the index list and yield samples in that order.
                def generator():
                    indices = list(range(len(self.data)))
                    np.random.shuffle(indices)
                    for i in indices:
                        yield self.data[i]

            data = generator()
        else:
            data = iter(self.data)

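        # One-item lookahead: hold the current sample until we know whether
        # another follows, so the last one can be flagged with is_end=True.
        try:
            d_current = next(data)
        except StopIteration:
            return  # empty data; a bare next() here raises RuntimeError (PEP 479)
        for d_next in data:
            yield False, d_current
            d_current = d_next

        yield True, d_current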

    def __iter__(self, random=False):
        raise NotImplementedError

    def forfit(self):
        """Yield shuffled batches forever; the trainer ends each epoch via steps_per_epoch."""
        while True:
            for d in self.__iter__(True):
                yield d
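The `steps is None` branch of `sample()` cannot shuffle data of unknown length fully, so it keeps a bounded buffer and, once the buffer is full, emits a random buffered item for every new one. The standalone sketch below reproduces that idea with the standard `random` module instead of NumPy; `buffer_shuffle` and its `seed` parameter are hypothetical names added for illustration and reproducibility, not part of bert4keras.

```python
import random

def buffer_shuffle(iterable, buffer_size, seed=None):
    """Approximate shuffle of an iterable of unknown length using a bounded buffer."""
    rng = random.Random(seed)
    buffer = []
    for item in iterable:
        buffer.append(item)
        if len(buffer) > buffer_size:
            # Over capacity: emit one random buffered item to keep memory bounded
            yield buffer.pop(rng.randrange(len(buffer)))
    # Input exhausted: drain the rest of the buffer in random order
    while buffer:
        yield buffer.pop(rng.randrange(len(buffer)))

out = list(buffer_shuffle(range(10), buffer_size=3, seed=0))
print(sorted(out) == list(range(10)))  # True: every item appears exactly once
```

The trade-off is that an item can only move a limited distance: with a buffer of 3, the item at output position `i` must come from input position `i + 3` or earlier, so a larger `buffer_size` gives a better shuffle at the cost of memory.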

In [None]:
class data_generator(DataGenerator):
    """
    子类实例: 数据生成器
    
    return:
    [
        batch_token_ids, 
        batch_segment_ids,
        batch_subject_labels, 
        batch_subject_ids,
        batch_object_labels
    ], None
    """
    def __iter__(self, random=False):
        batch_token_ids, batch_segment_ids = [], []
        batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], []
        for is_end, d in self.sample(random):
            token_ids, segment_ids = tokenizer.encode(
                d['text'], max_length=maxlen
            )
            # Group triples as {s: [(o, p)]}, i.e. {(s_start, s_end): [(o_start, o_end, p_id), ...]}
            spoes = {}
            for s, p, o in d['spo_list']:
                s = tokenizer.encode(s)[0][1:-1]
                p = predicate2id[p]
                o = tokenizer.encode(o)[0][1:-1]
                s_idx = search(s, token_ids)
                o_idx = search(o, token_ids)
                if s_idx != -1 and o_idx != -1:
                    s = (s_idx, s_idx + len(s) - 1)
                    o = (o_idx, o_idx + len(o) - 1, p)
                    if s not in spoes:
                        spoes[s] = []
                    spoes[s].append(o)
            if spoes:
                # subject labels: column 0 marks start tokens, column 1 end tokens
                subject_labels = np.zeros((len(token_ids), 2))
                for s in spoes:
                    subject_labels[s[0], 0] = 1
                    subject_labels[s[1], 1] = 1
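                # Randomly pick one subject.
                # mayi: this way of sampling a subject is unusual: it also produces negative
                # samples, because when the chosen (start, end) pair is not a real subject,
                # all object_labels stay 0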
                start, end = np.array(list(spoes.keys())).T
                start = np.random.choice(start)
                end = np.random.choice(end[end >= start])
                # token positions (start, end) of the chosen subject
                subject_ids = (start, end)
                # object labels for that subject
                object_labels = np.zeros((len(token_ids), len(predicate2id), 2))
                for o in spoes.get(subject_ids, []):
                    object_labels[o[0], o[2], 0] = 1
                    object_labels[o[1], o[2], 1] = 1
                # build the batch
                batch_token_ids.append(token_ids)
                batch_segment_ids.append(segment_ids)
                batch_subject_labels.append(subject_labels)
                batch_subject_ids.append(subject_ids)
                batch_object_labels.append(object_labels)
                if len(batch_token_ids) == self.batch_size or is_end:
                    batch_token_ids = sequence_padding(batch_token_ids)
                    batch_segment_ids = sequence_padding(batch_segment_ids)
                    batch_subject_labels = sequence_padding(
                        batch_subject_labels, padding=np.zeros(2)
                    )
                    batch_subject_ids = np.array(batch_subject_ids)
                    batch_object_labels = sequence_padding(
                        batch_object_labels,
                        padding=np.zeros((len(predicate2id), 2))
                    )
                    yield [
                        batch_token_ids, batch_segment_ids,
                        batch_subject_labels, batch_subject_ids,
                        batch_object_labels
                    ], None
                    batch_token_ids, batch_segment_ids = [], []
                    batch_subject_labels, batch_subject_ids, batch_object_labels = [], [], []
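The `is_end` flag from `sample()` is what lets `__iter__` flush the final, possibly partial, batch when the data size is not a multiple of `batch_size`. The sketch below isolates that pattern with plain lists; `with_is_end` and `batches` are hypothetical helper names for illustration, not bert4keras APIs.

```python
def with_is_end(iterable):
    """Yield (is_end, item) pairs; is_end is True only for the last item."""
    it = iter(iterable)
    try:
        current = next(it)
    except StopIteration:
        return  # empty input: yield nothing
    for nxt in it:
        yield False, current
        current = nxt
    yield True, current


def batches(iterable, batch_size):
    """Group items into lists of batch_size, flushing the last partial batch on is_end."""
    batch = []
    for is_end, item in with_is_end(iterable):
        batch.append(item)
        if len(batch) == batch_size or is_end:
            yield batch
            batch = []

print(list(with_is_end("abc")))    # [(False, 'a'), (False, 'b'), (True, 'c')]
print(list(batches(range(7), 3)))  # [[0, 1, 2], [3, 4, 5], [6]]
```

One thing worth checking in `data_generator` as written: its flush test sits inside `if spoes:`, so if the very last sample has no matching triple, the pending partial batch is never yielded.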


In [None]:
train_generator = data_generator(train_data, batch_size)

In [None]:
# In tf.keras (TF >= 2.1) fit_generator is deprecated; Model.fit accepts generators directly.
train_model.fit_generator(
    train_generator.forfit(),
    steps_per_epoch=len(train_generator),
    epochs=epochs,
    callbacks=[evaluator]
)
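Because `forfit()` never ends, the trainer decides where an epoch stops by drawing exactly `steps_per_epoch` batches from it; with `steps_per_epoch=len(train_generator)` that is one full pass over the data. The sketch below imitates this consumption with `itertools.islice` on a toy epoch of three "batches"; it is a simplified model for illustration, not how Keras is implemented internally.

```python
from itertools import islice

def forfit(make_epoch):
    """Repeat an epoch generator forever, like DataGenerator.forfit()."""
    while True:
        yield from make_epoch()

steps_per_epoch = 3
stream = forfit(lambda: iter([[0, 1], [2, 3], [4]]))  # toy batches for one epoch

epoch1 = list(islice(stream, steps_per_epoch))  # the trainer takes one epoch's worth
epoch2 = list(islice(stream, steps_per_epoch))  # the same stream continues into epoch 2
print(epoch1 == epoch2)  # True
```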

# nb_export

In [2]:
from nbdev.export import *
notebook2script()

Converted 00_core.ipynb.
Converted 00_template.ipynb.
Converted active_learning.ipynb.
Converted algo_dl_keras.ipynb.
Converted algo_ml_eda.ipynb.
Converted algo_ml_tree_catboost.ipynb.
Converted algo_ml_tree_lgb.ipynb.
Converted algo_rs_associated_rules.ipynb.
Converted algo_rs_match_deepmatch.ipynb.
Converted algo_rs_matrix.ipynb.
Converted algo_rs_search_vector_faiss.ipynb.
Converted algo_seq_embeding.ipynb.
Converted algo_seq_features_extraction_text.ipynb.
Converted datastructure_dict_list_set.ipynb.
Converted datastructure_matrix_sparse.ipynb.
Converted engineering_concurrency.ipynb.
Converted engineering_nbdev.ipynb.
Converted engineering_panel.ipynb.
Converted engineering_snorkel.ipynb.
Converted index.ipynb.
Converted math_func_basic.ipynb.
Converted math_func_loss.ipynb.
Converted operating_system_command.ipynb.
Converted plot.ipynb.
Converted utils_functools.ipynb.
Converted utils_json.ipynb.
Converted utils_pickle.ipynb.
Converted utils_time.ipynb.


In [7]:
!nbdev_build_docs

No notebooks were modified
converting /Users/luoyonggui/PycharmProjects/nbdevlib/index.ipynb to README.md
