This repository has been archived by the owner on Dec 29, 2022. It is now read-only.

Various InvalidArgumentError in Evaluation #103

Closed
dennybritz opened this issue Mar 23, 2017 · 74 comments

Comments

@dennybritz
Contributor

dennybritz commented Mar 23, 2017

The error manifests in multiple ways, but it always occurs during evaluation and always involves some kind of shape error:

Possible causes:

  • Some kind of bug in tf.learn's evaluate that causes issues with the GPU during evaluation (this doesn't seem to be a problem with CPU training)
  • Mismatch of parameters or training data in training/eval
  • Issue with the data reading / input pipeline, e.g. mismatched sequence lengths.

It could be related to these Tensorflow issues:

For everyone having this issue, please answer the following to help debug (see the snippet after this list for one way to collect this information):

  • What versions of TensorFlow are you using?
  • Can you run on the CPU without error?
  • What versions of CUDA are you using?
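
One way to collect most of this information from Python (a minimal sketch; it only uses standard TF 1.x introspection calls, and the CUDA toolkit version itself usually comes from nvcc --version):

import tensorflow as tf
from tensorflow.python.client import device_lib

# TensorFlow version and whether this build was compiled with CUDA support.
print("TensorFlow: " + tf.__version__)
print("Built with CUDA: " + str(tf.test.is_built_with_cuda()))

# List CPU and GPU devices; GPU entries include the CUDA device description.
for device in device_lib.list_local_devices():
    print(device.name + " " + device.physical_device_desc)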
@dennybritz
Contributor Author

It would be great if someone running into this can provide full log files including both training and evaluation.

@DevSinghSachan

DevSinghSachan commented Mar 23, 2017

Hi Denny Britz,

Here is one example log file for both training and evaluation

seq2seq_bug.txt

The dataset I used was the WMT En->Fr translation task. I tokenized it using BPE code as mentioned in the seq2seq tutorial.

Also, attached is my run script.

sample_seq2seq_run_script_sh.txt

Thanks,
Devendra

@dennybritz
Contributor Author

Thanks for the log. Everything in the log seems fine to me, unfortunately. Could you try reducing the batch size and see if the error still happens?

@DevSinghSachan

I just tried smaller batch sizes of 16 and 8, but I get the same error about the first dimension of the logits and the label shape. The log files are nearly identical to the one attached above.

@leezhao9

@dennybritz, thanks for seq2seq and for looking at this.

fr-en.log2.txt
fr-en.log.txt
I had three failed runs, all started clean. Although the runs are identical, the reported logits shapes are different for each run. The logs are attached.
It is not directly tied to evaluation itself, since I can successfully run prediction on the latest saved checkpoint when training is not running. I think it is related to loading the model for validation while the training model is still loaded; some memory-sharing issue?

As you can see below, the reported logits shape keeps increasing across subsequent failing runs. It resets again when running on a different GPU.

InvalidArgumentError (see above for traceback): logits and labels must have the same first dimension, got logits shape [1280,36240] and labels shape [6272]

InvalidArgumentError (see above for traceback): logits and labels must have the same first dimension, got logits shape [2304,36240] and labels shape [6272]

InvalidArgumentError (see above for traceback): logits and labels must have the same first dimension, got logits shape [4480,36240] and labels shape [6272]
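
For context on what these messages mean (a minimal sketch, not this project's code): the failing check appears to come from tf.nn.sparse_softmax_cross_entropy_with_logits, which requires the flattened logits and labels to agree in their first dimension (batch * time). Logits of shape [1280, 36240] paired with 6272 labels suggests the two tensors came from different batches or different unroll lengths.

import tensorflow as tf

vocab_size = 36240  # vocabulary size taken from the logs above

logits = tf.placeholder(tf.float32, [None, None, vocab_size])  # [batch, time, vocab]
labels = tf.placeholder(tf.int32, [None, None])                # [batch, time]

# Both tensors are flattened over (batch, time); the op checks at runtime that
# flat_logits.shape[0] == flat_labels.shape[0], which is the check failing here.
flat_logits = tf.reshape(logits, [-1, vocab_size])  # [batch * time, vocab]
flat_labels = tf.reshape(labels, [-1])              # [batch * time]

loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=flat_labels, logits=flat_logits)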

@dennybritz
Contributor Author

dennybritz commented Mar 24, 2017

I think it is related to loading the model for validation while the training model is still loaded; some memory-sharing issue?

Yes, I think it's a GPU memory-sharing issue. It seems likely that it's a bug in TensorFlow. This code doesn't do anything special; it just uses the tf.learn Estimator, which handles all the model construction.

Thanks for running these experiments. The fact that the shapes are increasing is very suspicious. I'll create a Tensorflow issue.

@dennybritz
Contributor Author

dennybritz commented Mar 24, 2017

Opened a TF issue: tensorflow/tensorflow#8701

@amirj
Contributor

amirj commented Mar 24, 2017

Reducing the number of validation records (50 for example) solved the problem for me.

@skyw

skyw commented Mar 24, 2017

@amirj
How do you do that?
"Reducing the number of validation records (50 for example)"

@amirj
Contributor

amirj commented Mar 24, 2017

Just decrease the number of samples in the validation set (for example, keep only the first 50 lines), as in the sketch below.
It may help others narrow down the problem.
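
A minimal sketch of the trimming step (the file names are placeholders for your own parallel dev files, not names from this repo):

# Keep only the first 50 parallel lines of the dev set (file names are placeholders).
num_lines = 50
for path in ["dev.sources.txt", "dev.targets.txt"]:
    with open(path) as src:
        head = src.readlines()[:num_lines]
    with open(path + ".small", "w") as out:
        out.writelines(head)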

@dummyindex

dummyindex commented Mar 26, 2017

I also got a similar problem when following the documentation tutorial exactly.

GPU: Titan X Pascal
TensorFlow version: 1.0.0 and 1.0.1 (tried both)
CUDA: 8

BTW, amirj's suggestion temporarily solves the problem, thanks (it still outputs the error sometimes).

See rihardsk's suggestion below, which works in my case.

@rihardsk

rihardsk commented Mar 27, 2017

Using a smaller validation set didn't help for me. I tried decreasing it to just 10 sentences with no luck. The only way to get training going was to set the --schedule parameter to train, disabling evaluation altogether.

@chenb67

chenb67 commented Mar 29, 2017

Removing the buckets line (buckets: 10,20,30,40) from example_configs/train_seq2seq.yml seems to solve the InvalidArgumentError for me.

@dennybritz
Contributor Author

@chenb67 That's an interesting finding. It's strange because bucketing is disabled during eval anyway and only enabled during training. It could be an issue with bucket_by_sequence_length. Maybe it somehow confuses train and dev input queues.

I added an extra scope for the input function (#126). Can you check whether that fixes it with bucketing enabled?
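
For reference, the idea behind the extra scope is roughly the following sketch (illustrative only, not the code from #126): build the train and dev input pipelines under distinct name scopes so their queue ops are clearly separated in the graph.

import tensorflow as tf

def make_input_pipeline(scope_name, capacity=64):
    """Toy stand-in for an input pipeline; the point is the dedicated name scope."""
    with tf.name_scope(scope_name):
        # Ops created here get names like "<scope_name>/input_queue", keeping the
        # train and dev pipelines visibly separate in the graph.
        return tf.FIFOQueue(capacity, dtypes=[tf.string], name="input_queue")

train_queue = make_input_pipeline("train_input_fn")
dev_queue = make_input_pipeline("dev_input_fn")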

@chenb67

chenb67 commented Mar 29, 2017

Hi @dennybritz,
I tested it with bucketing enabled using the code from #126; it doesn't work.
I'm getting the same InvalidArgumentError as before.

@M4t1ss

M4t1ss commented Mar 29, 2017

Removing buckets didn't work for me. I'm currently running with evaluation disabled and performing a manual run of translation and evaluation every once in a while 😆

@dennybritz
Contributor Author

dennybritz commented Mar 31, 2017

I just added a few GPU memory options to the training script in #137. Could you try the following:

  • Pass gpu_allow_growth as a flag to the training script.
  • OR pass gpu_memory_fraction as a flag to the training script and set it to something lower than 1.0.

See https://www.tensorflow.org/tutorials/using_gpu for what these options do. Does this solve the issue?
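
For anyone curious, those flags correspond to the standard TF 1.x session options (a sketch of the mapping, not the training script's own flag handling):

import tensorflow as tf

# gpu_allow_growth: start with a small GPU allocation and grow it as needed.
# gpu_memory_fraction: cap the fraction of total GPU memory this process may use.
session_config = tf.ConfigProto(
    gpu_options=tf.GPUOptions(
        allow_growth=True,
        per_process_gpu_memory_fraction=0.8,
    )
)

with tf.Session(config=session_config) as sess:
    sess.run(tf.global_variables_initializer())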

@dennybritz
Contributor Author

For everyone having this issue, please answer the following to help debug:

  • What versions of TensorFlow are you using?
  • Can you run on the CPU without error?
  • What versions of CUDA are you using?

@aselle

aselle commented Mar 31, 2017

And operating system as well, please.

@chenb67

chenb67 commented Apr 1, 2017

I tested it with the flags from #137; it didn't help.
I'm getting the same results with TensorFlow 1.0.1 and 1.1.0 on Ubuntu 16.04 with CUDA 8.0.

@M4t1ss

M4t1ss commented Apr 1, 2017

@dennybritz @aselle

  • Ubuntu 16.04.2 LTS
  • TensorFlow 1.0.1
  • Cuda compilation tools, release 8.0, V8.0.44

And it runs fine on a CPU 😮

@electroducer

One thing I noticed was that the evaluation phase will crash in GPU mode if the BLEU scorer throws an error due to a complete mismatch (see #106). I was able to at least stop it from crashing by setting eval_every_n_steps to a number high enough (in my case, 5000) that the BLEU scorer would stop throwing errors.

Still get a bunch of warnings at evaluation time though:

W tensorflow/core/framework/op_kernel.cc:993] Out of range: Reached limit of 1
	 [[Node: dev_input_fn/parallel_read_1/filenames/limit_epochs/CountUpTo = CountUpTo[T=DT_INT64, _class=["loc:@dev_input_fn/parallel_read_1/filenames/limit_epochs/epochs"], limit=1, _device="/job:localhost/replica:0/task:0/cpu:0"](dev_input_fn/parallel_read_1/filenames/limit_epochs/epochs)]]
W tensorflow/core/framework/op_kernel.cc:993] Out of range: FIFOQueue '_29_dev_input_fn/parallel_read_1/common_queue' is closed and has insufficient elements (requested 1, current size 0)
	 [[Node: dev_input_fn/parallel_read_1/common_queue_Dequeue = QueueDequeueV2[component_types=[DT_STRING, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](dev_input_fn/parallel_read_1/common_queue)]]
W tensorflow/core/framework/op_kernel.cc:993] Out of range: FIFOQueue '_29_dev_input_fn/parallel_read_1/common_queue' is closed and has insufficient elements (requested 1, current size 0)
	 [[Node: dev_input_fn/parallel_read_1/common_queue_Dequeue = QueueDequeueV2[component_types=[DT_STRING, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](dev_input_fn/parallel_read_1/common_queue)]]
W tensorflow/core/framework/op_kernel.cc:993] Out of range: PaddingFIFOQueue '_28_dev_input_fn/batch_queue/padding_fifo_queue' is closed and has insufficient elements (requested 32, current size 0)
	 [[Node: dev_input_fn/batch_queue = QueueDequeueUpToV2[component_types=[DT_INT32, DT_STRING, DT_INT32, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](dev_input_fn/batch_queue/padding_fifo_queue, dev_input_fn/batch_queue/n)]]
W tensorflow/core/framework/op_kernel.cc:993] Out of range: PaddingFIFOQueue '_28_dev_input_fn/batch_queue/padding_fifo_queue' is closed and has insufficient elements (requested 32, current size 0)
	 [[Node: dev_input_fn/batch_queue = QueueDequeueUpToV2[component_types=[DT_INT32, DT_STRING, DT_INT32, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](dev_input_fn/batch_queue/padding_fifo_queue, dev_input_fn/batch_queue/n)]]

@dennybritz
Contributor Author

dennybritz commented Apr 3, 2017

Still get a bunch of warnings at evaluation time though

These are normal TensorFlow/tf.learn warnings and completely OK. It just means that the input data queue is exhausted and you've iterated over all the data.

One thing I noticed was that the evaluation phase will crash in GPU mode if the BLEU scorer throws an error due to a complete mismatch

@electroducer That's very strange. Just to make sure, you were seeing the same error as the other people (a shape mismatch) when the BLEU script didn't exit normally? I have a hard time imagining how these could be related. Maybe it is something related to subprocess and GPUs.

@ghost

ghost commented Apr 14, 2017

@kadir-gunel

This is the train.py file I used:
train.txt

@ghost

ghost commented Apr 15, 2017

@SwordYork

Unfortunately my model crashed again after 62,000 training steps with the error message below. Do you happen to know what caused it? Is it perhaps due to continuous_train_and_eval?

Caused by op 'bleu/value', defined at:
File "/opt/rh/rh-python35/root/usr/lib64/python3.5/runpy.py", line 170, in _run_module_as_main
"__main__", mod_spec)
File "/opt/rh/rh-python35/root/usr/lib64/python3.5/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/seq2seq/bin/train.py", line 351, in
tf.app.run()
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/lib/python3.5/site-packages/tensorflow/python/platform/app.py", line 44, in run
_sys.exit(main(_sys.argv[:1] + flags_passthrough))
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/seq2seq/bin/train.py", line 344, in main
schedule=FLAGS.schedule)
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/learn_runner.py", line 106, in run
return task()
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/seq2seq/bin/train.py", line 105, in continuous_train_and_eval
hooks=self._eval_hooks)
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/lib/python3.5/site-packages/tensorflow/python/util/deprecation.py", line 280, in new_func
return func(*args, **kwargs)
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 514, in evaluate
log_progress=log_progress)
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 810, in _evaluate_model
eval_ops = self._get_eval_ops(features, labels, metrics)
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 1195, in _get_eval_ops
metrics, features, labels, model_fn_ops.predictions))
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/lib/python3.5/site-packages/tensorflow/contrib/learn/python/learn/estimators/estimator.py", line 258, in _make_metrics_ops
result[name] = metric.create_metric_ops(features, labels, predictions)
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/seq2seq/seq2seq/metrics/metric_specs.py", line 124, in create_metric_ops
name="value")
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/lib/python3.5/site-packages/tensorflow/python/ops/script_ops.py", line 189, in py_func
input=inp, token=token, Tout=Tout, name=name)
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/lib/python3.5/site-packages/tensorflow/python/ops/gen_script_ops.py", line 40, in _py_func
name=name)
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/lib/python3.5/site-packages/tensorflow/python/framework/op_def_library.py", line 763, in apply_op
op_def=op_def)
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 2327, in create_op
original_op=self._default_original_op, op_def=op_def)
File "/home/pricie/milan/VirtualEnvironment/GPU/seq2seq/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 1226, in __init__
self._traceback = _extract_stack()

InternalError (see above for traceback): Failed to run py callback pyfunc_488: see error log.
[[Node: bleu/value = PyFunc[Tin=[DT_STRING, DT_STRING], Tout=[DT_FLOAT], token="pyfunc_488", _device="/job:localhost/replica:0/task:0/cpu:0"](bleu/Identity, bleu/Identity_1)]]
[[Node: rouge_2/p_score/value/_329 = _Recv[client_terminated=false, recv_device="/job:localhost/replica:0/task:0/gpu:0", send_device="/job:localhost/replica:0/task:0/cpu:0", send_device_incarnation=1, tensor_name="edge_1011_rouge_2/p_score/value", tensor_type=DT_FLOAT, _device="/job:localhost/replica:0/task:0/gpu:0"]()]]

@SwordYork

@milanv1 I don't know why. You may check issue #160; it may be related to bleu.py. I have uncommented these lines to use the local version of multi-bleu.perl.

@coventry
Contributor

coventry commented Apr 16, 2017

Here is a reliable reproduction of this bug. @SwordYork's branch seems to fix it. This is the output from the failure. The last few lines are

  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 2336, in create_op
    original_op=self._default_original_op, op_def=op_def)
  File "/usr/local/lib/python2.7/dist-packages/tensorflow/python/framework/ops.py", line 1228, in __init__
    self._traceback = _extract_stack()

InvalidArgumentError (see above for traceback): Input to reshape is a tensor with 303104 values, but the requested shape has 0
	 [[Node: model/att_seq2seq/decode/attention_decoder/decoder/while/attention/att_keys/Tensordot/Reshape = Reshape[T=DT_FLOAT, Tshape=DT_INT32, _device="/job:localhost/replica:0/task:0/gpu:0"](model/att_seq2seq/decode/attention_decoder/decoder/while/attention/att_keys/Tensordot/transpose, model/att_seq2seq/decode/attention_decoder/decoder/while/attention/att_keys/Tensordot/stack)]]

Steps to reproduce:

  1. Start and log in to an AWS g2.2xlarge instance. You'll probably need to give it more space than the default 8GB; I gave it 30GB.
  2. Download the pre-processed WMT'16 EN-DE data file wmt16_en_de.tar.gz and untar it into a directory ~/wmt.
  3. Run the following commands
git clone https://github.com/coventry/seq2seq-replication
cd seq2seq-replication
sudo bash install-nvidia-docker-build-and-run.sh

After much downloading and building, this will start a tensorflow/tensorflow:latest-devel-gpu docker container, and run the file process.sh, cobbled together from the seq2seq NMT tutorial. The attached log comes from running that.

This branch runs @SwordYork's fix, which at least runs the evaluation step for several minutes (whereas the default seq2seq code falls over immediately).

@amirj
Contributor

amirj commented Apr 16, 2017

While this bug is being handled, what's the correct way to train a model without a validation set? I just want to track the loss on the training set and stop the process after a limited number of iterations.
I just set the parameter eval_every_n_steps to a large value. Is that OK?

@dennybritz
Contributor Author

dennybritz commented Apr 17, 2017

Thanks everyone for helping debug this. As this seems to be an issue with TensorFlow/tf.learn, I'm also not sure what the best way to fix it is. I may override the Estimator class as per @SwordYork's suggestion and just put it in this code for now until tensorflow/tensorflow#8701 is resolved.

I just set the parameter eval_every_n_steps to a large value. Is that OK?

Seems ok to me.

@dennybritz
Contributor Author

#173 has the patch.

@ghost

ghost commented May 11, 2017

@SwordYork

Do you happen to know whether this training schedule could affect the performance of the system?

When I used the previous repo, before this fix, I managed to get a BLEU score around 15 with a small model (200,000 training sentences, 2,000-sentence dev set, 2,000-sentence test set, batch size 32, 178,500 training steps). Whenever I have tried to replicate this experiment with the exact same configuration since, I have only managed to get around 5 BLEU. The only difference in configuration is this new training schedule.

@SwordYork

@milanv1 I have not run into this issue. Could you please try the new repo on CPU?

@ghost

ghost commented May 12, 2017

@SwordYork

Thanks for your quick reply. The low BLEU scores were achieved after using the new repo :/

@kyleyeung

kyleyeung commented May 16, 2017

@milanv1 I tried using tf 1.0 and tf 1.1 to reproduce the training procedure in the NMT tutorial, but the model always started to overfit at a BLEU score around ~7.5. I also tried different versions of train.py (before and after the PatchedExperiment commit); nothing changed. (I didn't train on CPU because that would be way too slow.) So it seems the low BLEU score did not result from the PatchedExperiment commit.

@SwordYork Could you please tell me if you have successfully reproduced the tutorial result?

@SwordYork

@milanv1, @kyleyeung I think it may be related to the GPU, because @milanv1 could train properly using the previous repo on CPU but fails to train with the new repo on GPU.
I have successfully reproduced the tutorial result, so I don't know what is happening.

@j-min

j-min commented May 22, 2017

I have faced a similar problem while implementing image-captioning code with the tf 1.1+ TrainingHelper and dynamic_decode. It always caused that cross-entropy shape mismatch when trained with multiple GPUs.
In my case the bug came from the input pipeline, which was similar to https://github.com/tensorflow/models/blob/master/im2txt/im2txt/ops/inputs.py .
I implemented a CPU-based input pipeline and split each batch across 4 GPUs with a queue. Padding is applied up to the maximum length of the whole batch, while the training helper takes a sequence length calculated as self.sequence_length = tf.reduce_sum(input_mask, axis=1, name='sentence_length'), so the maximum unrolling for each tower is determined by the longest sequence in that tower's portion of the batch, not by the maximum length of the batch across all towers. I worked around this bug with tf.slice, as shown below:

# max_len = maximum length of the sentences passed to the current tower
# stack all logits for sparse_softmax
# logits:       [batch_size,  max_len, vocab_size]
# logits_stack: [batch_size x max_len, vocab_size]
self.logits_stack = tf.reshape(
    logits,
    shape=[-1, self.config.vocab_size],
    name='logits_stack')

##################################
# TRIM TARGET SEQS AND MASKS
##################################
# Multi GPU training cause length mismatch!
# dynamic_decode will unroll RNNs up to longest input of current tower.
# However, batches are already padded up to maximum sequence across all towers.

# target_seqs: [batch_size,  padded_len]
# => [batch_size, max_len] if padded_len > max_len 
# target_seqs: [batch_size,  max_len]
batch_size = tf.shape(target_seqs)[0]
padded_len = tf.shape(target_seqs)[1]
current_batch_max_len = tf.shape(logits)[1]

if config.num_gpus > 1:

    # target_seqs: [batch_size, max_len] after trimming
    target_seqs = tf.slice(
            target_seqs,
            [0, 0],
            [batch_size, current_batch_max_len],
            name='trim_target_seqs')

    # input_mask:   [batch_size,  max_len]
    input_mask = tf.slice(
            input_mask,
            [0, 0],
            [batch_size, current_batch_max_len],
            name='trim_mask')

self.targets = tf.reshape(
    target_seqs,
    shape=[-1],
    name='targets')

cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
    labels=self.targets,
    logits=self.logits_stack,
    name='cross_entropy_per_example')

@liyi193328

liyi193328 commented May 22, 2017

@SwordYork Thanks for your contributions.
I run in distributed TensorFlow mode with the Estimator's train_and_evaluate, but it results in InvalidArgumentError (see above for traceback): Tried to read from index 45 but array size is: 45 during evaluation. I can run fine with the continuous_train_and_eval setting on a single server. So what can I do when I want to do distributed training?
Another problem is multi-GPU training. When training with multiple GPUs (or one worker with multiple GPUs), only one GPU is actually used, although all GPUs have some process on them. How can I use them all? Someone said that string types can't be handled on GPUs, so only one GPU is used? How can I use all of them for training? Thanks.

@SwordYork

@liyi193328 You may refer to https://www.tensorflow.org/deploy/distributed; I have replied to you by email.

@liyi193328

liyi193328 commented May 23, 2017

@SwordYork I may not have expressed myself clearly.
I have already changed the code for a distributed environment (ps_hosts, ps, job_name, etc.) by setting TF_CONFIG and so on (https://github.com/liyi193328/seq2seq.git). But I got errors like "InvalidArgumentError (see above for traceback): Tried to read from index 45 but array size is: 45" during evaluation on the chief worker.
Today I got it to run normally by copying the tf.contrib Experiment code to seq2seq.contrib.experiment and changing the continuous_train_and_eval function a little:

    while (not continuous_eval_predicate_fn or
           continuous_eval_predicate_fn(eval_result)):

      if self._has_training_stopped(eval_result):
        # Exits once max steps of training is satisfied.
        logging.info("Stop training model as max steps reached")
        break

      config = self._estimator.config
      if (config.environment != run_config.Environment.LOCAL and
          config.environment != run_config.Environment.GOOGLE and
          config.cluster_spec and config.master):
        self._start_server()

      logging.info("Training model for %s steps", train_steps_per_iteration)
      self._call_train(input_fn=self._train_input_fn,
                       steps=train_steps_per_iteration,
                       hooks=self._train_monitors)

Now it runs normally, whether in a truly distributed environment or a fake distributed environment on a single machine.

Last, I misunderstood how data-batch parallelism works across multiple GPUs on one machine. I thought it would be handled automatically by tf.contrib.learn's Estimator, but after reading the source code it's clear that we need to do the batch parallelism across multiple GPUs ourselves, averaging gradients and losses manually, as in the sketch below.

Anyway, it runs now, thanks.
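
For reference, the manual multi-tower pattern mentioned above looks roughly like this sketch (a toy loss stands in for the real model_fn; NUM_GPUS and the placeholder shapes are assumptions, not project code):

import tensorflow as tf

NUM_GPUS = 2  # assumption: set to the number of GPUs on the machine

def tower_loss(x, y):
    """Toy model standing in for the real model_fn; returns a scalar loss."""
    w = tf.get_variable("w", shape=[4, 1])
    return tf.reduce_mean(tf.square(tf.matmul(x, w) - y))

def average_gradients(tower_grads):
    """Average the (gradient, variable) lists produced by each tower."""
    averaged = []
    for grads_and_vars in zip(*tower_grads):
        grads = tf.stack([g for g, _ in grads_and_vars], axis=0)
        averaged.append((tf.reduce_mean(grads, axis=0), grads_and_vars[0][1]))
    return averaged

x = tf.placeholder(tf.float32, [None, 4])
y = tf.placeholder(tf.float32, [None, 1])
x_splits = tf.split(x, NUM_GPUS, axis=0)  # one slice of the batch per GPU
y_splits = tf.split(y, NUM_GPUS, axis=0)

optimizer = tf.train.AdamOptimizer(1e-3)
tower_grads = []
for i in range(NUM_GPUS):
    with tf.device("/gpu:%d" % i):
        # Reuse the same variables on every tower after the first one.
        with tf.variable_scope("model", reuse=(i > 0)):
            loss = tower_loss(x_splits[i], y_splits[i])
            tower_grads.append(optimizer.compute_gradients(loss))

train_op = optimizer.apply_gradients(average_gradients(tower_grads))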

@tobyyouup

My error is as follows, and I solved this problem by reducing the evaluation set to fewer than 100 examples.

W tensorflow/core/framework/op_kernel.cc:993] Out of range: Reached limit of 1
[[Node: dev_input_fn/parallel_read_1/filenames/limit_epochs/CountUpTo = CountUpTo[T=DT_INT64, _class=["loc:@dev_input_fn/parallel_read_1/filenames/limit_epochs/epochs"], limit=1, _device="/job:localhost/replica:0/task:0/cpu:0"](dev_input_fn/parallel_read_1/filenames/limit_epochs/epochs)]]
W tensorflow/core/framework/op_kernel.cc:993] Out of range: FIFOQueue '_29_dev_input_fn/parallel_read_1/common_queue' is closed and has insufficient elements (requested 1, current size 0)
[[Node: dev_input_fn/parallel_read_1/common_queue_Dequeue = QueueDequeueV2[component_types=[DT_STRING, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](dev_input_fn/parallel_read_1/common_queue)]]
W tensorflow/core/framework/op_kernel.cc:993] Out of range: FIFOQueue '_29_dev_input_fn/parallel_read_1/common_queue' is closed and has insufficient elements (requested 1, current size 0)
[[Node: dev_input_fn/parallel_read_1/common_queue_Dequeue = QueueDequeueV2[component_types=[DT_STRING, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](dev_input_fn/parallel_read_1/common_queue)]]
W tensorflow/core/framework/op_kernel.cc:993] Out of range: PaddingFIFOQueue '_28_dev_input_fn/batch_queue/padding_fifo_queue' is closed and has insufficient elements (requested 32, current size 0)
[[Node: dev_input_fn/batch_queue = QueueDequeueUpToV2[component_types=[DT_INT32, DT_STRING, DT_INT32, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](dev_input_fn/batch_queue/padding_fifo_queue, dev_input_fn/batch_queue/n)]]
W tensorflow/core/framework/op_kernel.cc:993] Out of range: PaddingFIFOQueue '_28_dev_input_fn/batch_queue/padding_fifo_queue' is closed and has insufficient elements (requested 32, current size 0)
[[Node: dev_input_fn/batch_queue = QueueDequeueUpToV2[component_types=[DT_INT32, DT_STRING, DT_INT32, DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](dev_input_fn/batch_queue/padding_fifo_queue, dev_input_fn/batch_queue/n)]]

@RylanSchaeffer

RylanSchaeffer commented Jun 20, 2017

@dennybritz @SwordYork I'm receiving the following error, which seems related, but distinct:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Tried to read from index 8 but array size is: 8

However, I'm not using tf.learn, a GPU, any distributed training or any validation. The error occurs while attempting to train locally on a single CPU. Additional information is included:

[[Node: define_model/define_decoder/decoder/while/BasicDecoderStep/TrainingHelperNextInputs/cond/TensorArrayReadV3 = TensorArrayReadV3[dtype=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](define_model/define_decoder/decoder/while/BasicDecoderStep/TrainingHelperNextInputs/cond/TensorArrayReadV3/Switch, define_model/define_decoder/decoder/while/BasicDecoderStep/TrainingHelperNextInputs/cond/TensorArrayReadV3/Switch_1, define_model/define_decoder/decoder/while/BasicDecoderStep/TrainingHelperNextInputs/cond/TensorArrayReadV3/Switch_2)]]

define_model and define_decoder are functions I wrote for my sequence to sequence model, but I'm not sure what the rest of the information means. If either of you could help, I'd appreciate it!

Machine Specifications: macOS Sierra Version 10.12.5
TensorFlow Version: v1.2.0-rc2-21-g12f033d 1.2.0

@SwordYork

@RylanSchaeffer Did you modify the input pipeline? There may be some problem with the sequence_length, e.g., sequence_length is set to 9 but there are only 8 elements in the input array.
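
To illustrate the kind of mismatch described above (a minimal sketch, not @RylanSchaeffer's code): tf.contrib.seq2seq.TrainingHelper reads decoder inputs from a TensorArray whose size is the time dimension of the inputs, so a sequence_length entry larger than that time dimension can produce exactly a "Tried to read from index N but array size is: N" error. One defensive fix is to clamp the lengths to the real time dimension:

import tensorflow as tf

decoder_inputs = tf.placeholder(tf.float32, [None, None, 128])  # [batch, time, dim]
raw_lengths = tf.placeholder(tf.int32, [None])                  # per-example lengths

# Guard against lengths that exceed the unrolled time dimension of the inputs.
max_time = tf.shape(decoder_inputs)[1]
sequence_length = tf.minimum(raw_lengths, max_time)

helper = tf.contrib.seq2seq.TrainingHelper(
    inputs=decoder_inputs, sequence_length=sequence_length)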

@RylanSchaeffer

@SwordYork I found my bug. Thank you though :)

@mcemilg

mcemilg commented Dec 26, 2017

I encountered the same issue when I tried to build the inference graph for an encoder-decoder model. I solved it by padding the logits to sequence_length. Maybe it will help with similar issues.

# Pad the inference logits along the time axis so they match the target sequence length.
pad_size = dec_inp_seq_length - tf.shape(infer_logits)[1]
infer_logits = tf.pad(infer_logits, [[0, 0], [0, pad_size], [0, 0]])

@sysuzyx

sysuzyx commented Jan 15, 2019

@RylanSchaeffer I have run into the same problem as you. Would you mind telling me what the cause of this issue was and how you solved it?

@RylanSchaeffer

Sadly I don't remember. It was probably some small, trivial mistake.

@TanyaChowdhury

TanyaChowdhury commented Jan 29, 2019

@sysuzyx @RylanSchaeffer I'm facing a similar issue on CPU as well as GPU. Would you be able to say how you solved it?

@ruohoruotsi

I think this codebase is busted (i.e., it has not been working correctly since the edits of April 17th). I'd recommend using Tensor2Tensor or OpenNMT-py. Cheers.
