In [19]:
import random
from mxnet import gluon
from mxnet import ndarray as nd

def transform(data, label):
    """Convert one (image, label) sample to float32.

    The image is divided by 255 (assumes 8-bit pixel values, so the result
    lands in [0, 1] -- TODO confirm for other datasets); the label is only
    cast to float32.
    """
    scaled_image = data.astype('float32') / 255
    return scaled_image, label.astype('float32')
# Build the Fashion-MNIST training and test splits; `transform` is handed to
# each dataset so every (data, label) sample is converted by it.
mnist_train = gluon.data.vision.FashionMNIST(train=True, transform=transform)
mnist_test = gluon.data.vision.FashionMNIST(train=False, transform=transform)

In [12]:
print(isinstance(mnist_test[0][0], nd.NDArray))  # sanity check: transformed image of the first test sample is an NDArray

True


In [14]:
class Sampler(object):
    """Abstract base class for index samplers.

    Subclasses must implement ``__iter__`` (produce indices) and ``__len__``
    (how many items one pass over the sampler produces).
    """
    # Inheriting from `object` makes this a new-style class under Python 2,
    # where `class Sampler():` would be an old-style (classic) class. It has
    # no effect under Python 3.

    def __iter__(self):
        raise NotImplementedError("Sampler subclasses must implement __iter__")

    def __len__(self):
        raise NotImplementedError("Sampler subclasses must implement __len__")

In [23]:
class SequentialSampler(Sampler):
    """Sampler producing the indices ``0, 1, ..., length - 1`` in order.

    Parameters
    ----------
    length : int
        How many indices one pass yields.
    """
    def __init__(self, length):
        self._length = length

    def __len__(self):
        return self._length

    def __iter__(self):
        # range(...) is evaluated immediately by the generator expression, so
        # a bad length fails at iter() time just like a plain iter(range(...)).
        return (position for position in range(self._length))
# Demo: SequentialSampler yields 0..length-1 in order, like list indices.
# The global `sampler` is reused by the BatchSampler demo further down.
sampler = SequentialSampler(10)
print list(sampler)
# Python 2 trailing-comma print keeps all indices on one line.
for i in sampler:
    print i,

[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
0 1 2 3 4 5 6 7 8 9


In [61]:
class RandomSampler(Sampler):
    """Sampler producing the indices ``0 .. length - 1`` in random order.

    Every call to ``__iter__`` draws a new permutation from the global
    `random` module state.

    Parameters
    ----------
    length : int
        How many indices one pass yields.
    """
    def __init__(self, length):
        self._length = length

    def __len__(self):
        return self._length

    def __iter__(self):
        order = [k for k in range(self._length)]
        random.shuffle(order)  # in-place permutation
        return iter(order)
# Demo: each pass over a RandomSampler gives a fresh permutation of 0..7.
r = RandomSampler(8)
# Python 2 trailing-comma print keeps all indices on one line.
for  i in r:
    print i,

2 7 4 5 3 1 0 6


In [62]:
class BatchSampler(Sampler):
    """Wrap another sampler and yield its indices in lists of `batch_size`.

    Parameters
    ----------
    sampler : Sampler
        Source of individual indices; iterated once per pass.
    batch_size : int
        Number of indices in each full batch.
    last_batch : {'keep', 'discard', 'rollover'}, default 'keep'
        Handling of a final batch with fewer than `batch_size` indices:
        'keep' yields it, 'discard' drops it, and 'rollover' (also accepted
        under its older spelling 'rollout') holds it back and prepends it to
        the first batch of the next pass.

    Raises
    ------
    ValueError
        If `last_batch` is not one of the values above.
    """
    # Both spellings are accepted: 'rollout' is what this class originally
    # used, 'rollover' is what the DataLoader docstring advertises.
    _ROLLOVER = ('rollover', 'rollout')

    def __init__(self, sampler, batch_size, last_batch="keep"):
        # The parameter used to be misspelled `smapler`, which made the body
        # silently read a *global* `sampler` instead of the argument.
        self._sampler = sampler
        self._batch_size = batch_size
        self._last_batch = last_batch
        # Indices held back by a previous 'rollover' pass.
        self._prev = []
        # Fail fast instead of only when a short final batch shows up.
        if last_batch not in ('keep', 'discard') + self._ROLLOVER:
            raise self._bad_last_batch()

    def _bad_last_batch(self):
        # Build (not raise) the error for an unsupported `last_batch` value.
        return ValueError(
            "last_batch must be 'keep', 'discard' or 'rollover', but got %r"
            % (self._last_batch,))

    def __iter__(self):
        # Start from the indices rolled over by the previous pass, if any.
        batch, self._prev = self._prev, []
        for i in self._sampler:
            batch.append(i)
            if len(batch) == self._batch_size:
                # Each yield is a list of exactly batch_size indices.
                yield batch
                batch = []
        # Whatever is left is a short final batch.
        if batch:
            if self._last_batch == 'keep':
                yield batch
            elif self._last_batch == 'discard':
                return
            elif self._last_batch in self._ROLLOVER:
                # Carry the leftovers over to the next pass.
                self._prev = batch
            else:
                raise self._bad_last_batch()

    def __len__(self):
        num_indices = len(self._sampler)
        if self._last_batch == 'keep':
            # Ceiling division: the short final batch counts as a batch.
            return (num_indices + self._batch_size - 1) // self._batch_size
        if self._last_batch == 'discard':
            return num_indices // self._batch_size
        if self._last_batch in self._ROLLOVER:
            # Leftovers from the previous pass join this pass's indices.
            # (Previously read the nonexistent attribute `self.batch_size`.)
            return (len(self._prev) + num_indices) // self._batch_size
        raise self._bad_last_batch()
                                 
batch_sampler = BatchSampler(sampler, 3, 'keep')
list(batch_sampler)
for i in batch_sampler:
    print i
print "长度为: ", len(batch_sampler)

[0, 1, 2]
[3, 4, 5]
[6, 7, 8]
[9]
长度为:  4


In [114]:
import numpy as np
# print (mnist_test[0][0]).shape, (mnist_test[3][1])
# a = [(3,45),(5,6),(5,6)]
# print zip(*a)
# for x in zip(*a):
#     print x


def _batchify(data):
    """Collate a non-empty list of samples into batched NDArrays.

    - NDArray samples are stacked along a new leading axis
      (``[x, y] -> [[x...], [y...]]``).
    - Tuple samples such as ``(image, label)`` are transposed into per-field
      groups (``[(a0, b0), (a1, b1)] -> [(a0, a1), (b0, b1)]``) and each group
      is batched recursively; a list with one entry per field is returned.
    - Anything else (presumably scalar labels -- verify against the dataset)
      goes through ``np.asarray`` and becomes an NDArray of the same dtype.
    """
    head = data[0]
    if isinstance(head, nd.NDArray):
        return nd.stack(*data)
    if isinstance(head, tuple):
        return [_batchify(field) for field in zip(*data)]
    as_numpy = np.asarray(data)
    return nd.array(as_numpy, dtype=as_numpy.dtype)
# mnist_test_sam = [mnist_test[i] for i in range(2)]

# _batchify(mnist_test_sam)


In [115]:

    

class DataLoader(object):
    """Iterate over `dataset` in mini-batches collated by ``_batchify``.

    Parameters
    ----------
    dataset : indexable object supporting len()
        ``dataset[idx]`` must return one sample.
    batch_size : int, optional
        Number of samples per batch. Required unless `batch_sampler` is given.
    shuffle : bool, default False
        Visit samples in random order (via RandomSampler). Either specify
        `sampler` or `shuffle`, not both.
    sampler : Sampler, optional
        Custom source of sample indices.
    last_batch : str, optional
        Passed to BatchSampler to decide what happens to a final incomplete
        batch; 'keep' is used when it is not given.
    batch_sampler : iterable of index lists, optional
        Yields lists of indices directly. When given, `batch_size`, `shuffle`,
        `sampler` and `last_batch` must be left unset.

    Raises
    ------
    ValueError
        If neither `batch_size` nor `batch_sampler` is given, if `shuffle` is
        combined with `sampler`, or if `batch_sampler` is combined with any
        of the other batching options.
    """
    def __init__(self, dataset, batch_size=None, shuffle=False, sampler=None,
                 last_batch=None, batch_sampler=None):
        # `batch_sampler` used to be misspelled `bach_sampler`, so the body
        # read a *global* `batch_sampler` and ignored every other argument.
        self._dataset = dataset
        if batch_sampler is None:
            # No batch sampler given: build one from the individual options.
            if batch_size is None:
                raise ValueError("batch size must be defined")
            if sampler is None:
                if shuffle:
                    sampler = RandomSampler(len(dataset))
                else:
                    sampler = SequentialSampler(len(dataset))
            elif shuffle:
                raise ValueError("shuffle must be not specified if sampler is specified")

            batch_sampler = BatchSampler(sampler, batch_size, last_batch if last_batch else 'keep')
        elif (batch_size is not None or shuffle or sampler is not None
              or last_batch is not None):
            # These options would be silently ignored; reject them instead.
            raise ValueError("batch_size, shuffle, sampler and last_batch must "
                             "not be specified if batch_sampler is specified.")
        self._batch_sampler = batch_sampler

    def __iter__(self):
        # Each batch is a list of indices: fetch those samples, then collate.
        for batch in self._batch_sampler:
            yield _batchify([self._dataset[idx] for idx in batch])

    def __len__(self):
        # Number of batches per pass, as reported by the batch sampler.
        return len(self._batch_sampler)