In [1]:
import tensorflow as tf
import io
import pathlib
import os
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

# Print NumPy floats with 4 decimal places to keep cell outputs compact.
np.set_printoptions(precision=4)

# The notebook draws random numbers (np.random in gen_series, Dataset.shuffle in
# the padded-batch cell); fix both global seeds so Restart & Run All reproduces
# the same outputs.
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

In [2]:
# Load Fashion-MNIST through Keras (the log below shows the four archives it fetches).
# Returns (train, test); each is an (images, labels) pair, as unpacked in the next cell.
train, test = tf.keras.datasets.fashion_mnist.load_data()

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz


In [3]:
images, labels = train
# Scale the raw pixel values (assumed 0-255 integers) into [0, 1]; true division
# promotes to float64, the image dtype the dataset reports below.
images = np.true_divide(images, 255)
dataset = tf.data.Dataset.from_tensor_slices((images, labels))

In [4]:
# Show the element structure: each element is a (28, 28) float64 image and a scalar uint8 label.
dataset

<TensorSliceDataset shapes: ((28, 28), ()), types: (tf.float64, tf.uint8)>

* 当数据集不大(可以一次性放入内存)的时候, 推荐直接从内存中加载数据, 即使用 `tf.data.Dataset.from_tensor_slices()`

* 当数据集较大(无法一次性放入内存)的时候, 推荐使用 Python 中的生成器(generator)按需逐个产生元素, 再交给 `tf.data.Dataset.from_generator()` 包装. 生成器的写法如下:

``` python
def count(stop):
  i = 0
  while i<stop:
    yield i
    i += 1
for i in count(5):
    print(i)
```

* `yield` 关键字会让生成器每次执行到 `yield` 处时暂停并返回当前的值; 直到下一次调用 `next()` 时, 才会从暂停处继续运行.

In [None]:
def count(stop):
    """Yield 0, 1, 2, ... for as long as the value is still below ``stop``."""
    n = 0
    while n < stop:
        yield n
        n += 1
for i in count(5):
    print(i)

# Pass the generator *function* plus its `args`; tf.data calls count(25) itself,
# giving a stream of the integers 0..24.
ds_counter = tf.data.Dataset.from_generator(
    count,
    args=[25],
    output_types=tf.int32,
    output_shapes=(),
)

In [5]:
class FileReader(object):
    """Stream whitespace-separated integer rows from text files into tf.data batches.

    Every non-blank line of every file becomes one int32 vector. Rows may differ
    in length, which is why ``generator_batch`` pads them.
    """

    def __init__(self, filenames: "list[str] | str"):
        """
        Args:
            filenames: iterable of file paths, or a single path (str / PathLike).
        """
        # A lone path used to be iterated character by character; wrap it instead.
        if isinstance(filenames, (str, os.PathLike)):
            filenames = [filenames]
        # Materialize so that iterators are not exhausted after the first pass.
        self.filenames = list(filenames)

    def file_reader(self):
        """Yield one int32 numpy vector per non-blank line of each file, in order.

        Lines are split on any run of whitespace, so repeated spaces, tabs and the
        trailing newline are ignored; blank lines are skipped.

        Raises:
            ValueError: if a token cannot be parsed as an integer.
        """
        for filename in self.filenames:
            # Open read-only; the `with` block closes the file when the generator
            # finishes or is closed early (previously opened "r+" and never closed).
            with open(filename, "r", encoding="utf-8") as fr:
                for data_line in fr:
                    tokens = data_line.split()
                    if not tokens:
                        continue
                    yield np.asarray(tokens, dtype="int32")

    def generator_batch(self, batch_size, num_epochs=None, padded_length=3):
        """Return the first padded batch of rows produced by ``file_reader``.

        Args:
            batch_size: maximum number of rows in the batch.
            num_epochs: passed to ``Dataset.repeat``; ``None`` repeats forever.
            padded_length: every row is padded with ``-1`` to this length. Defaults
                to 3 (the previous fixed width); ``None`` pads to the longest row
                in the batch. A row longer than this makes TF raise at runtime.

        Returns:
            An int32 ``tf.Tensor`` of shape ``(<= batch_size, padded_length)``.

        Raises:
            StopIteration: if the files contain no rows.
        """
        rows = tf.data.Dataset.from_generator(
            self.file_reader,
            output_types=tf.int32,
            # Rank-1 with unknown length; `tf.TensorShape[None]` subscripted the
            # class itself and raised a TypeError.
            output_shapes=tf.TensorShape([None]),
        )
        batches = rows.repeat(num_epochs).padded_batch(
            batch_size,
            padded_shapes=tf.TensorShape([padded_length]),
            padding_values=-1,
        )
        # TF2 datasets are iterable in eager mode; `make_one_shot_iterator` is a
        # TF1-only API that TF2 Dataset objects do not have.
        return next(iter(batches))

# In a notebook kernel `__name__` is always "__main__", so the old guard skipped
# nothing; run the demo directly and end the cell with the batch so it is displayed
# (the old call discarded its result).
file_list = ["../data/input/1.txt", "../data/input/2.txt", "../data/input/3.txt"]

# Errors raised inside from_generator surface as opaque TF runtime errors,
# so report missing inputs up front with the offending paths.
missing_files = [path for path in file_list if not os.path.exists(path)]
if missing_files:
    raise FileNotFoundError(f"Input files not found: {missing_files}")

MyFileReader = FileReader(file_list)
MyFileReader.generator_batch(2, 1)


In [6]:
def count(stop):
    """Generate consecutive integers starting at 0 and stopping before ``stop``.

    Args:
        stop: exclusive upper bound; anything comparable with an int works.
    """
    current = 0
    while current < stop:
        yield current
        current += 1

dc_counts = tf.data.Dataset.from_generator(
    count, args=[25], output_types=tf.int32, output_shapes=()
)

# repeat() makes the 25-element stream endless, batch(10) groups it ten at a time
# (so batches wrap around from 24 back to 0), and take(10) stops after ten batches.
for dc_count in dc_counts.repeat().batch(10).take(10):
    print(dc_count.numpy())

[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]
[0 1 2 3 4 5 6 7 8 9]
[10 11 12 13 14 15 16 17 18 19]
[20 21 22 23 24  0  1  2  3  4]
[ 5  6  7  8  9 10 11 12 13 14]
[15 16 17 18 19 20 21 22 23 24]


In [7]:
def gen_series():
    """Endlessly yield ``(index, series)`` pairs.

    ``index`` counts up from 0; ``series`` is a 1-D array of standard-normal
    samples whose length is drawn uniformly from 0..9 (so it can be empty).
    Uses NumPy's global random state.
    """
    index = 0
    while True:
        length = np.random.randint(0, 10)
        series = np.random.normal(size=(length,))
        yield index, series
        index += 1

# gen_series never ends on its own: print (index, series) pairs and stop once the
# index passes 5, i.e. after seven items.
for i, series in gen_series():
    print(i, ":", str(series))
    if i > 5:
        break

0 : []
1 : [-0.2952  0.2235  1.0418 -0.652 ]
2 : [ 0.6015  0.9682  1.3831 -0.086 ]
3 : [-0.8109  0.4219 -0.7097  0.4963 -0.5641]
4 : []
5 : [-1.39    2.2241 -0.741   0.3348 -0.5044  0.7392]
6 : [ 0.8258  1.1383  0.1615 -0.1248  1.1196 -0.9071 -0.0051  1.0094]


In [8]:
# Each element is a scalar int32 index plus a float64 series of unknown length,
# matching the (index, series) pairs yielded by gen_series.
ds_counter = tf.data.Dataset.from_generator(
    gen_series,
    output_types=(tf.int32, tf.float64),
    output_shapes=((), (None,)),
)
ds_counter

<FlatMapDataset shapes: ((), (None,)), types: (tf.int32, tf.float64)>

In [9]:
# Note that when batching a dataset with a variable shape, you need to use Dataset.padded_batch.
# Without explicit padded_shapes, each series is padded with 0.0 up to the longest
# series in its batch.
# shuffle(20) samples from a 20-element buffer that is refilled as items are drawn,
# so ids of 20 and above can already appear in the first batch.
ds_series_batch = ds_counter.shuffle(20).padded_batch(10)
ids, sequence_batch = next(iter(ds_series_batch))
print(ids.numpy())
print()
print(sequence_batch.numpy())

[ 4 20 16  8 11 24  6 25 18 15]

[[ 1.0737  0.2608  2.2549 -0.1077 -0.6878 -0.2789  0.    ]
 [ 0.8669  1.5525  0.4499  0.      0.      0.      0.    ]
 [ 0.0556 -0.5853  0.      0.      0.      0.      0.    ]
 [ 0.8558 -1.5641  1.5252  1.9275  0.2557  0.      0.    ]
 [-1.7397  0.8214 -0.2925 -1.1368 -0.964   0.4297  0.6739]
 [ 0.0648  1.632   0.5551 -0.6848  0.      0.      0.    ]
 [-0.799  -0.0554 -0.6191  2.7992  0.      0.      0.    ]
 [ 0.      0.      0.      0.      0.      0.      0.    ]
 [ 0.5642 -0.9947 -0.5092  0.      0.      0.      0.    ]
 [-0.195  -0.3346  1.9114 -0.1246  0.1853  0.2497 -0.7068]]
