In [1]:
import tensorflow as tf
assert tf.__version__ >= '1.5', ('This code requires TensorFlow v1.5, '
                                   'You have:%s' % tf.__version__)
tf.logging.set_verbosity(tf.logging.DEBUG)

In [15]:
sparse_matrix = [[0, 7, 0], 
                 [0, 0, 0], 
                 [0, 0, 8]]
sparse_matrx_tensor = tf.constant(sparse_matrix)

In [20]:
IXS = [[0, 1], [2, 2]]
VALS = [7, 8]
MATRIX_SHAPE = [3, 3]

In [21]:
sparse_tensor = tf.SparseTensor(indices=IXS, values=VALS, dense_shape=MATRIX_SHAPE)

In [22]:
sess = tf.Session()

In [23]:
spare_tensor_value = sess.run(sparse_tensor)
print(spare_tensor_value)
print(spare_tensor_value.dense_shape.tolist())
print('convert spare tensor to dense tensor....')
print(sess.run(tf.sparse_tensor_to_dense(sparse_tensor)))

SparseTensorValue(indices=array([[0, 1],
       [2, 2]]), values=array([7, 8], dtype=int32), dense_shape=array([3, 3]))
[3, 3]
convert spare tensor to dense tensor....
[[0 7 0]
 [0 0 0]
 [0 0 8]]


In [24]:
# ndarray to list and cmp
assert all([spare_tensor_value.indices.tolist() == IXS, 
            spare_tensor_value.values.tolist() == VALS,
#             spare_tensor_value.dense_shape == sparse_matrx_tensor.shape,
            spare_tensor_value.dense_shape.tolist() == MATRIX_SHAPE,
           ]), 'something wrong.'

In [33]:
# write sparse tensor to tfrecord
def _intlist_feature(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))
def _floatlist_feature(values):
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))
def write_to_tfrecord(indices, values, dense_shape, tfrecord_file=None):
    inversed_indices = zip(*indices)
    features = {}
    for i, idx in enumerate(inversed_indices):
        features['idx_%s' % i] = _intlist_feature(idx)
    features['values'] = _floatlist_feature(values) 
    print(features)
    writer = tf.python_io.TFRecordWriter(tfrecord_file)
    example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(example.SerializeToString())
    writer.close()

In [34]:
write_to_tfrecord(IXS, VALS, MATRIX_SHAPE, './sparse_tensor.tfrecords')

{'idx_0': int64_list {
  value: 0
  value: 2
}
, 'idx_1': int64_list {
  value: 1
  value: 2
}
, 'values': float_list {
  value: 7
  value: 8
}
}


In [74]:
# read sparse tensor from tfrecord
def parser(serialized_example):
    features = tf.parse_single_example(
        serialized_example,
        features={'sparse_value': tf.SparseFeature(index_key=['idx_0', 'idx_1'],
                                                   value_key='values',
                                                   dtype=tf.float32,
                                                   size=[3, 3])})
    return {'sparse_tensor': features['sparse_value']}, tf.constant('label...')
#     return features


def create_input_fun(file_path):
    def input_fun():
        dataset = tf.data.TFRecordDataset([file_path])
        dataset = dataset.map(parser)
        # write before batch
        dataset = dataset.repeat(100)
        dataset = dataset.batch(3)

        iterator = dataset.make_one_shot_iterator()
        value_batch = iterator.get_next()
        return value_batch
    return input_fun

In [75]:

sparse_input_fun = create_input_fun('./sparse_tensor.tfrecords')
values, labels  = sparse_input_fun()
with tf.train.MonitoredTrainingSession() as sess:
    vals, labs = sess.run([values, labels])
    
print('result:')
print(vals, labs)

result:
{'sparse_tensor': SparseTensorValue(indices=array([[0, 0, 1],
       [0, 1, 1],
       [1, 0, 1],
       [1, 1, 1],
       [2, 0, 1],
       [2, 1, 1]]), values=array([ 7.,  8.,  7.,  8.,  7.,  8.], dtype=float32), dense_shape=array([3, 3, 3]))} [b'label...' b'label...' b'label...']


In [77]:

val = vals['sparse_tensor']
tf.Session().run(tf.sparse_tensor_to_dense(tf.SparseTensor(val.indices, val.values, val.dense_shape)))

array([[[ 0.,  7.,  0.],
        [ 0.,  8.,  0.],
        [ 0.,  0.,  0.]],

       [[ 0.,  7.,  0.],
        [ 0.,  8.,  0.],
        [ 0.,  0.,  0.]],

       [[ 0.,  7.,  0.],
        [ 0.,  8.,  0.],
        [ 0.,  0.,  0.]]], dtype=float32)