Using tf.data.Dataset API #10110

Closed
AakashKumarNain opened this issue May 3, 2018 · 13 comments

Comments

@AakashKumarNain
Contributor

I was trying to use the TF Dataset API with Keras, but I am getting strange errors. Here is my code:

import tensorflow as tf
from keras import backend as K
from keras.utils import to_categorical

def data_gen(X=None, y=None, batch_size=32, nb_epochs=1, sess=None):
    def _parse_function(filename, label):
        image_string = tf.read_file(filename)
        image_decoded = tf.cast(tf.image.decode_jpeg(image_string), tf.float32)
        image_decoded = (image_decoded / tf.constant(127.5)) - tf.constant(1.)
        image_resized = tf.image.resize_images(image_decoded, [224, 224])
        
        return image_resized, label
    
    dataset = tf.data.Dataset.from_tensor_slices((X,y))
    dataset = dataset.map(_parse_function)
    dataset = dataset.batch(batch_size).repeat(nb_epochs)
    iterator = dataset.make_initializable_iterator()
    next_element = iterator.get_next()
    
    for i in range(nb_epochs):
        sess.run(iterator.initializer)
        while True:
            try:
                nxb, nxl = sess.run(next_element)
                nxl = to_categorical(nxl, num_classes=10)
                yield nxb, nxl
            except tf.errors.OutOfRangeError:
                break


train_images = tf.constant(train_df['image'].values)  
train_labels = tf.constant([labels_dict[l] for l in train_df['label'].values])

valid_images = tf.constant(valid_df['image'].values)
valid_labels = tf.constant([labels_dict[l] for l in valid_df['label'].values])

sess = K.get_session()
model = get_model()

train_gen = data_gen(X=train_images, y=train_labels, nb_epochs=10, sess=sess)
valid_gen = data_gen(X=valid_images, y=valid_labels, nb_epochs=10, sess=sess)

batch_size=32
nb_train_steps = train_images.shape.num_elements() // batch_size
nb_valid_steps = valid_images.shape.num_elements() // batch_size

# Fit the model
model.fit_generator(train_gen, steps_per_epoch=nb_train_steps,validation_data=valid_gen, validation_steps=nb_valid_steps)

The last line throws this error:

model.fit_generator(train_gen, steps_per_epoch=nb_train_steps,validation_data=valid_gen, validation_steps=nb_valid_steps)

Epoch 1/1

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-32-be01c033aa94> in <module>()
----> 1 model.fit_generator(train_gen, steps_per_epoch=nb_train_steps,validation_data=valid_gen, validation_steps=nb_valid_steps)

/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/legacy/interfaces.py in wrapper(*args, **kwargs)
     89                 warnings.warn('Update your `' + object_name +
     90                               '` call to the Keras 2 API: ' + signature, stacklevel=2)
---> 91             return func(*args, **kwargs)
     92         wrapper._original_function = func
     93         return wrapper

/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/engine/training.py in fit_generator(self, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
   1415             use_multiprocessing=use_multiprocessing,
   1416             shuffle=shuffle,
-> 1417             initial_epoch=initial_epoch)
   1418 
   1419     @interfaces.legacy_generator_methods_support

/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/engine/training_generator.py in fit_generator(model, generator, steps_per_epoch, epochs, verbose, callbacks, validation_data, validation_steps, class_weight, max_queue_size, workers, use_multiprocessing, shuffle, initial_epoch)
    153             batch_index = 0
    154             while steps_done < steps_per_epoch:
--> 155                 generator_output = next(output_generator)
    156 
    157                 if not hasattr(generator_output, '__len__'):

/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/utils/data_utils.py in get(self)
    791             success, value = self.queue.get()
    792             if not success:
--> 793                 six.reraise(value.__class__, value, value.__traceback__)

/opt/conda/lib/python3.6/site-packages/six.py in reraise(tp, value, tb)
    691             if value.__traceback__ is not tb:
    692                 raise value.with_traceback(tb)
--> 693             raise value
    694         finally:
    695             value = None

/opt/conda/lib/python3.6/site-packages/Keras-2.1.6-py3.6.egg/keras/utils/data_utils.py in _data_generator_task(self)
    656                             # => Serialize calls to
    657                             # infinite iterator/generator's next() function
--> 658                             generator_output = next(self._generator)
    659                             self.queue.put((True, generator_output))
    660                         else:

<ipython-input-18-a1077a5b59f8> in data_gen(X, y, batch_size, nb_epochs, sess)
     11     dataset = dataset.map(_parse_function)
     12     dataset = dataset.batch(batch_size).repeat(nb_epochs)
---> 13     iterator = dataset.make_initializable_iterator()
     14     next_element = iterator.get_next()
     15 

/opt/conda/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in make_initializable_iterator(self, shared_name)
    106             sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
    107     with ops.colocate_with(iterator_resource):
--> 108       initializer = gen_dataset_ops.make_iterator(self._as_variant_tensor(),
    109                                                   iterator_resource)
    110     return iterator_ops.Iterator(iterator_resource, initializer,

/opt/conda/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in _as_variant_tensor(self)
   1402   def _as_variant_tensor(self):
   1403     return gen_dataset_ops.repeat_dataset(
-> 1404         self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
   1405         count=self._count,
   1406         output_shapes=nest.flatten(

/opt/conda/lib/python3.6/site-packages/tensorflow/python/data/ops/dataset_ops.py in _as_variant_tensor(self)
   1646             sparse.as_dense_shapes(self.output_shapes, self.output_classes)),
   1647         output_types=nest.flatten(
-> 1648             sparse.as_dense_types(self.output_types, self.output_classes)))
   1649 
   1650   @property

/opt/conda/lib/python3.6/site-packages/tensorflow/python/ops/gen_dataset_ops.py in batch_dataset(input_dataset, batch_size, output_types, output_shapes, name)
     54     _, _, _op = _op_def_lib._apply_op_helper(
     55         "BatchDataset", input_dataset=input_dataset, batch_size=batch_size,
---> 56         output_types=output_types, output_shapes=output_shapes, name=name)
     57     _result = _op.outputs[:]
     58     _inputs_flat = _op.inputs

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(self, op_type_name, name, **keywords)
    348       # Need to flatten all the arguments into a list.
    349       # pylint: disable=protected-access
--> 350       g = ops._get_graph_from_inputs(_Flatten(keywords.values()))
    351       # pylint: enable=protected-access
    352     except AssertionError as e:

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _get_graph_from_inputs(op_input_list, graph)
   5651         graph = graph_element.graph
   5652       elif original_graph_element is not None:
-> 5653         _assert_same_graph(original_graph_element, graph_element)
   5654       elif graph_element.graph is not graph:
   5655         raise ValueError("%s is not from the passed-in graph." % graph_element)

/opt/conda/lib/python3.6/site-packages/tensorflow/python/framework/ops.py in _assert_same_graph(original_item, item)
   5587   if original_item.graph is not item.graph:
   5588     raise ValueError("%s must be from the same graph as %s." % (item,
-> 5589                                                                 original_item))
   5590 
   5591 

ValueError: Tensor("batch_size:0", shape=(), dtype=int64) must be from the same graph as Tensor("MapDataset_3:0", shape=(), dtype=variant).
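One plausible reading of this traceback (an assumption, not confirmed in this thread): TF 1.x keeps its default-graph stack in thread-local storage, and `fit_generator` with its default `workers=1` runs the generator body in a background enqueuer thread (the `_data_generator_task` frame above). Ops created inside `data_gen` can therefore end up in a different graph than tensors created in the main thread. The pure-Python sketch below mimics that behavior with a hypothetical `DefaultStack` class; it does not use TensorFlow. In real TF 1.x code the analogous workaround would be to wrap the dataset construction inside the generator in `with sess.graph.as_default():`.

```python
import threading


class DefaultStack(threading.local):
    """Hypothetical stand-in for a thread-local "default graph" stack."""

    def __init__(self):
        # threading.local re-runs __init__ in every thread that touches the object,
        # so each thread starts with its own empty stack.
        self.stack = []

    def get_default(self, global_default):
        # Use the innermost pushed graph, else fall back to the process-wide default.
        return self.stack[-1] if self.stack else global_default


GLOBAL_DEFAULT = "global_graph"
_stack = DefaultStack()
results = {}


def build_ops(key):
    # Record which "graph" new ops would be created in, as seen from the calling thread.
    results[key] = _stack.get_default(GLOBAL_DEFAULT)


def build_ops_pinned(key, graph):
    # Mimics wrapping the dataset code in `with graph.as_default():` inside the worker.
    _stack.stack.append(graph)
    try:
        build_ops(key)
    finally:
        _stack.stack.pop()


def run_in_thread(fn, *args):
    t = threading.Thread(target=fn, args=args)
    t.start()
    t.join()


_stack.stack.append("session_graph")  # main thread: like entering sess.graph.as_default()
build_ops("main")
run_in_thread(build_ops, "worker")  # the worker does not see the main thread's push
run_in_thread(build_ops_pinned, "pinned", "session_graph")

print(results)
# {'main': 'session_graph', 'worker': 'global_graph', 'pinned': 'session_graph'}
```

The point of the sketch is only that a "default" held in thread-local storage does not follow the work into another thread, so the worker has to re-enter the intended graph explicitly.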


@fchollet
Member

fchollet commented May 3, 2018

This error seems to occur inside your data generator. Are you able to do something like:

for x, y in data_gen(...):
   continue
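A minimal way to run that check without looping through every epoch is to pull only a few batches with `itertools.islice`. The sketch below uses a hypothetical `fake_data_gen` that yields shape tuples instead of real image arrays, so it runs without TensorFlow; with the real code you would pass `data_gen(...)` and inspect `x.shape` and `y.shape`.

```python
import itertools


def fake_data_gen(batch_size=32, num_classes=10):
    # Hypothetical stand-in for data_gen(): yields (images_shape, labels_shape)
    # forever, the way a Keras training generator is expected to.
    while True:
        yield (batch_size, 224, 224, 3), (batch_size, num_classes)


def check_generator(gen, n_batches=3):
    # Pull only the first few batches; the generator itself may be infinite.
    return [(x, y) for x, y in itertools.islice(gen, n_batches)]


for x_shape, y_shape in check_generator(fake_data_gen()):
    print(x_shape, y_shape)  # (32, 224, 224, 3) (32, 10), printed three times
```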

@fchollet
Member

fchollet commented May 3, 2018

Also, note that with Keras at HEAD and TF 1.8, you can fit from data tensors using:

x, y = iterator.get_next()
model.fit(x, y, steps_per_epoch=steps_per_epoch, epochs=epochs)

@AakashKumarNain
Contributor Author

AakashKumarNain commented May 4, 2018

Yes, I am able to do that. For example:

train_gen = data_gen(X=train_images, y=train_labels, nb_epochs=10, sess=sess)

for i,(x,y) in enumerate(train_gen):
    print(x.shape, y.shape)

The above loop outputs:

(32, 224, 224, 3) (32, 10)
(32, 224, 224, 3) (32, 10)
(32, 224, 224, 3) (32, 10)
(32, 224, 224, 3) (32, 10)
...
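As an aside on the `data_gen` in the original report: it calls `.repeat(nb_epochs)` on the dataset and also re-initializes the iterator inside an outer `for i in range(nb_epochs)` loop, so one generator yields `nb_epochs * nb_epochs` passes over the data rather than `nb_epochs`. The pure-Python sketch below counts batches for both patterns. `count_batches` and the 1000-example dataset size are hypothetical; the per-pass count uses a ceiling because `batch()` keeps a final partial batch by default, and batching before `repeat()` gives every pass its own partial batch.

```python
import math


def count_batches(n_examples, batch_size, repeat_count, outer_loops):
    # batch() before repeat(): every pass ends with its own (possibly partial) batch.
    per_pass = math.ceil(n_examples / batch_size)
    # Each iterator initialization walks through `repeat_count` passes,
    # and the outer Python loop re-initializes the iterator `outer_loops` times.
    return per_pass * repeat_count * outer_loops


n, bs, epochs = 1000, 32, 10
as_written = count_batches(n, bs, repeat_count=epochs, outer_loops=epochs)
repeat_only = count_batches(n, bs, repeat_count=epochs, outer_loops=1)
print(as_written, repeat_only)  # 3200 320
```

Keeping either the `repeat(nb_epochs)` or the outer loop, not both, gives the intended number of passes.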

Also, three more things:

  1. I used TF 1.8 yesterday. As soon as I start TensorBoard, everything blows up and my main TF script gets killed. When I switched back to TF 1.7, everything worked fine; when I switched to TF 1.8 again, everything broke again. I hope this bug gets resolved soon.

  2. x, y = iterator.get_next()
    model.fit(x, y, steps_per_epoch=steps_per_epoch, epochs=epochs)

    Will this work with fit_generator() too?

  3. I was using Eager execution with Keras, and many things break there. Should I open another issue or post it in this thread?

@mattdornfeld

@fchollet Maybe you can clarify this for me? Why would I need to use the Dataset API with Keras? Does it provide any functionality that fit_generator on its own does not? Thank you!

@fchollet
Member

Will this work with fit_generator() too?

What do you mean? fit_generator is no longer necessary if you are fitting from a TF dataset.

I was using Eager and with Keras and so many things break in that. Should I open another issue or post the issue in this thread only?

Please open a new issue.

Why would I need to use the Dataset API with Keras? Does it provide any functionality that fit_generator on its own does not? Thank you!

Use it if your data is already in Dataset format. One reason to use Dataset is that it may offer better performance than multi-process Python generators in some cases.
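The general idea behind that performance point is overlapping data preparation with consumption, so the training loop is not left waiting for the next batch. The pure-Python sketch below only illustrates prefetching with a background thread and a bounded `queue.Queue`; it is not how `tf.data` is implemented (tf.data executes its pipeline inside the TensorFlow runtime rather than in Python threads). The `prefetch` helper is hypothetical.

```python
import queue
import threading

_SENTINEL = object()  # marks the end of the stream


def prefetch(iterable, buffer_size=4):
    """Yield items from `iterable` while a background thread produces ahead."""
    buf = queue.Queue(maxsize=buffer_size)  # bounded: the producer blocks when full

    def producer():
        try:
            for item in iterable:
                buf.put(item)
        finally:
            # Always signal the end, even if the source stops early.
            # (An exception in the source is not propagated in this sketch.)
            buf.put(_SENTINEL)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is _SENTINEL:
            return
        yield item


batches = ([i] * 3 for i in range(5))  # hypothetical "batches"
print(list(prefetch(batches, buffer_size=2)))
# [[0, 0, 0], [1, 1, 1], [2, 2, 2], [3, 3, 3], [4, 4, 4]]
```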

@mattdornfeld

@fchollet Is this because Datasets load data from storage into RAM more efficiently, or because they move data from RAM into GPU memory more efficiently?

@knathanieltucker

I was trying to use the TensorBoard callback with histogram_freq=1 while fitting on tensors from iterator.get_next(), and got the following error:

TypeError: The value of a feed cannot be a tf.Tensor object. Acceptable feed values include Python scalars, strings, lists, numpy ndarrays, or TensorHandles.For reference, the tensor object was Tensor("strided_slice_4:0", shape=(2, 256, 256, 3), dtype=float32) which was passed to the feed with key Tensor("input_1:0", shape=(?, 256, 256, 3), dtype=float32).

It's thrown on this line here: https://github.com/keras-team/keras/blob/master/keras/callbacks.py#L867

I'd be happy to open another issue; it just seemed related to the new TF 1.8 feature.

@was84san

@fchollet What about model.predict and model.evaluate? Do they support tensors as input data?

@was84san

@knathanieltucker I noticed from your error message that TensorHandles are accepted as feed values. If we convert the input tensors to tensor handles via tf.get_session_handle(input_tensor), would that solve the problem?

@knathanieltucker

Yeah, I think that would work. A good notebook to test this on would be this one:

https://github.com/tensorflow/workshops/blob/master/notebooks/2-mnist-with-keras-eager-and-tf-data.ipynb

It already has the dataset object set up. Let me know if y'all don't have time; otherwise I'll test it in a couple of weeks.

@was84san

@aakashkumar Now I have the same issue as you. I solved it by setting workers=0 in fit_generator, though I still don't know why. With workers=0, the generator runs in the main thread.
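A generator's body executes in whichever thread calls `next()` on it, so with `workers=0` the code inside `data_gen` runs in the same thread as the rest of the script, while with `workers >= 1` it runs in a background thread. The pure-Python sketch below records the thread name from inside a hypothetical `tracing_gen`; the thread name `"enqueuer"` is likewise made up for illustration.

```python
import threading


def tracing_gen():
    # Hypothetical generator: reports which thread is executing its body.
    while True:
        yield threading.current_thread().name


def next_in_thread(gen):
    # Mimics a background enqueuer thread pulling one item from the generator.
    out = []
    t = threading.Thread(target=lambda: out.append(next(gen)), name="enqueuer")
    t.start()
    t.join()
    return out[0]


gen = tracing_gen()
print(next(gen))            # MainThread when run as a script (like workers=0)
print(next_in_thread(gen))  # enqueuer (like workers >= 1)
```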

@ghulammustufa31

Did you find a solution to this issue? I get the same error when using an initializable iterator with Keras.

@zhangyi02

zhangyi02 commented Jan 17, 2019

@fchollet

x, y = iterator.get_next()
model.fit(x, y, steps_per_epoch=steps_per_epoch, epochs=epochs)

When I try this, I get the following error:

When feeding symbolic tensors to a model, we expect the tensors to have a static batch size. Got tensor with shape: (None, None)

This is caused by iterator.get_next() returning tensors with shape (?, ?).
But I don't understand why the batch size is None, even with dataset.batch(BATCH_SIZE).
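A likely explanation: by default `batch()` keeps the final, smaller batch when the number of examples is not a multiple of `BATCH_SIZE`, so the batch dimension cannot be statically known and is reported as `?`. Later TF 1.x releases add a `drop_remainder` argument, as in `dataset.batch(BATCH_SIZE, drop_remainder=True)`, which drops that last batch and yields a static batch dimension. The pure-Python sketch below (hypothetical `batch_sizes` helper) shows the batch sizes with and without dropping the remainder.

```python
def batch_sizes(n_examples, batch_size, drop_remainder=False):
    # Sizes of the batches produced by slicing n_examples into chunks of batch_size.
    full, rest = divmod(n_examples, batch_size)
    sizes = [batch_size] * full
    if rest and not drop_remainder:
        sizes.append(rest)  # the short final batch is why the static batch dim is None
    return sizes


print(batch_sizes(100, 32))                       # [32, 32, 32, 4]
print(batch_sizes(100, 32, drop_remainder=True))  # [32, 32, 32]
```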


7 participants