In [1]:
import tensorflow as tf

DATA_PATH = '../data/heart.csv'  # CSV file handed to the input pipeline in main()
BATCH_SIZE = 3  # number of examples in each shuffled batch
N_FEATURES = 9  # feature columns per row; the label column comes after them

  from ._conv import register_converters as _register_converters


In [2]:
def batch_generator(filenames):
    """Build a queue-based pipeline that yields shuffled (features, label) batches.

    `filenames` is the list of CSV files to read from; in this notebook it
    contains only heart.csv.

    Returns a `(data_batch, label_batch)` pair of tensors, each holding
    BATCH_SIZE examples drawn from a shuffling queue.
    """
    # Step 1: queue up the file names and read them one text line at a time.
    file_queue = tf.train.string_input_producer(filenames)
    line_reader = tf.TextLineReader(skip_header_lines=1)  # drop the header row
    _, csv_line = line_reader.read(file_queue)

    # Step 2: describe the columns to the CSV decoder.
    # Each entry is the fallback value for an empty column and also fixes the
    # decoded dtype of that column. Of the 9 feature columns, 8 are treated as
    # floats (some are integers in the file, but floats keep the features
    # homogeneous) and the 5th one (index 4) is a string. The final column is
    # the integer label.
    float_default, label_default = 1.0, 1
    column_defaults = [[float_default] for _ in range(N_FEATURES)]
    column_defaults[4] = ['']
    column_defaults.append([label_default])

    # Decode one line into its 10 column values; `column_defaults` defines both
    # the decoding layout and the dtype of every column.
    print('record_defaults:', column_defaults)
    columns = tf.decode_csv(csv_line, record_defaults=column_defaults)

    # Turn the 5th column (Present/Absent) into a binary 1.0 / 0.0 value.
    # tf.where acts as an element-wise "1.0 if column == 'Present' else 0.0".
    # refs: https://blog.csdn.net/a_a_ron/article/details/79048446
    is_present = tf.equal(columns[4], tf.constant('Present'))
    columns[4] = tf.where(is_present, tf.constant(1.0), tf.constant(0.0))

    # Stack the 9 feature columns into one tensor; the last column is the label.
    feature_vector = tf.stack(columns[:N_FEATURES])
    target = columns[-1]

    # Shuffle-queue sizing:
    #   * queue_floor    - elements that must remain in the queue after a
    #                      dequeue so samples stay well mixed; when fewer are
    #                      available the dequeue waits for new data
    #                      (10x the batch size is deemed sufficient).
    #   * queue_capacity - the maximum number of elements the queue may hold.
    queue_floor = 10 * BATCH_SIZE
    queue_capacity = 20 * BATCH_SIZE

    # Draw BATCH_SIZE shuffled (features, label) pairs from the queue.
    shuffled = tf.train.shuffle_batch(
        [feature_vector, target],
        batch_size=BATCH_SIZE,
        capacity=queue_capacity,
        min_after_dequeue=queue_floor,
    )
    data_batch, label_batch = shuffled

    return data_batch, label_batch

In [3]:
def generate_batches(data_batch, label_batch, n_batches=10):
    """Run the input pipeline and print a number of batches.

    Args:
        data_batch: tensor of batched features to fetch with `sess.run`.
        label_batch: tensor of batched labels to fetch with `sess.run`.
        n_batches: how many batches to fetch and print (defaults to 10).

    Any exception raised while fetching a batch propagates to the caller,
    but only after the queue-runner threads have been asked to stop and
    have been joined.
    """
    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        # Pass the session explicitly instead of relying on the default one.
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        try:
            for _ in range(n_batches):
                features, labels = sess.run([data_batch, label_batch])
                print("features:", features)
                print('labels:', labels)
        finally:
            # Always shut the queue runners down, even when sess.run fails;
            # otherwise the background threads are never told to stop.
            coord.request_stop()
            coord.join(threads)

In [4]:
def main():
    """Build the batch tensors for DATA_PATH and print batches drawn from them."""
    batched_features, batched_labels = batch_generator([DATA_PATH])
    generate_batches(batched_features, batched_labels)

In [5]:
# Build the input pipeline and print the shuffled batches it produces.
main()

record_defaults: [[1.0], [1.0], [1.0], [1.0], [''], [1.0], [1.0], [1.0], [1.0], [1]]
features: [[120.     7.5   15.33  22.     0.    60.    25.31  34.49  49.  ]
 [118.     6.     9.65  33.91   0.    60.    38.8    0.    48.  ]
 [117.     1.53   2.44  28.95   1.    35.    25.89  30.03  46.  ]]
features: [[144.     4.09   5.55  31.4    1.    60.    29.43   5.55  56.  ]
 [134.    13.6    3.5   27.78   1.    60.    25.99  57.34  49.  ]
 [132.     6.2    6.47  36.21   1.    62.    30.77  14.14  45.  ]]
features: [[132.     7.9    2.85  26.5    1.    51.    26.16  25.71  44.  ]
 [160.    12.     5.73  23.11   1.    49.    25.3   97.2   52.  ]
 [126.     8.75   6.53  34.02   0.    49.    30.25   0.    41.  ]]
features: [[138.     0.6    3.81  28.66   0.    54.    28.7    1.46  58.  ]
 [146.    10.5    8.29  35.36   1.    78.    32.73  13.89  53.  ]
 [118.     0.28   5.8   33.7    1.    60.    30.98   0.    41.  ]]
features: [[114.     4.08   4.59  14.6    1.    62.    23.11   6.72  58.  ]
 [1