Added support for CTC in both Theano and Tensorflow along with image OCR example. #3436

Merged (9 commits) on Aug 16, 2016

Conversation

mbhenry
Contributor

@mbhenry mbhenry commented Aug 10, 2016

This commit adds support for training RNNs with Connectionist Temporal Classification (CTC), which is a popular loss function for streams where the temporal or translational alignment between the input data and labels is unknown. An example would be raw speech spectrograms as input data and phonemes as labels. Another example is an input image that includes rendered text with an unknown translational location, word/character spacing, or rotation.

For Tensorflow, a wrapper was created for the built-in CTC code and put in tensorflow_backend.py. This wrapper is fairly complex as it has to transform a dense tensor into a sparse tensor. Note that in the bleeding-edge Tensorflow, they moved the location of CTC from contrib to util.
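To illustrate the dense-to-sparse step mentioned above, here is a minimal NumPy sketch, not the wrapper from this PR. It assumes labels arrive as a padded integer matrix plus a vector of true lengths; the function name `dense_labels_to_sparse` is hypothetical. The three returned arrays are the pieces a sparse tensor is typically built from.

```python
import numpy as np

def dense_labels_to_sparse(labels, label_lengths):
    """Convert padded dense labels to (indices, values, dense_shape).

    labels: int array of shape (batch, max_len), padded past each true length.
    label_lengths: int array of shape (batch,) holding the true label lengths.
    """
    labels = np.asarray(labels)
    label_lengths = np.asarray(label_lengths)
    # mask[i, j] is True where position j lies inside sequence i
    mask = np.arange(labels.shape[1])[None, :] < label_lengths[:, None]
    indices = np.argwhere(mask)   # (num_valid, 2) pairs of (batch, position), row-major
    values = labels[mask]         # labels at the valid positions, in the same order
    dense_shape = np.array(labels.shape)
    return indices, values, dense_shape

# Two label sequences of lengths 5 and 3; the trailing zeros in row 1 are padding
indices, values, dense_shape = dense_labels_to_sparse(
    [[7, 4, 11, 11, 14], [2, 0, 19, 0, 0]], [5, 3])
```

Padding entries are dropped rather than encoded, which is why the true lengths have to be passed alongside the label matrix.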

For Theano, an implementation was included courtesy of and used with permission from @shawntan. Because it was not written for batch processing, it's quite a bit slower than Tensorflow, but it does work.

This commit includes an example that performs OCR on an image. The example works with both Theano and Tensorflow. The text-based image is generated using a list of single words (wordlist_mono_clean.txt) and double words (wordlist_bi_clean.txt). I did my best to make sure no profanity ended up in these lists, but apologies in advance if I missed something. Here is an example output after 40 epochs:

[Image: e39]

The text is printed onto a 512 x 64 image using a variety of fonts (note the font list works in CentOS 7, but not sure what will happen on other OSes). This is done on the fly for all training images using generators. A random font and random amounts of speckle noise, rotation, and translation are applied. These images are then fed into a network consisting of two convolutional layers, a fully connected layer, two bidirectional recurrent layers, and finally a fully connected layer with 28 outputs (26 letters, space, and CTC blank). After about 10 epochs it does pretty well with 5-letter words, so harder words are introduced. After 20 epochs, phrases with spaces are introduced.
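As a sketch of how the 28 outputs map back to text, the snippet below does greedy (best-path) CTC decoding in NumPy: take the most likely class per time step, merge repeats, then drop blanks. The class ordering (letters, then space, then blank as the last index) is an assumption for illustration, and `ctc_greedy_decode` is a hypothetical helper, not a function from this PR.

```python
import numpy as np

ALPHABET = "abcdefghijklmnopqrstuvwxyz "  # 26 letters + space
BLANK = len(ALPHABET)                     # index 27: assumed position of the CTC blank

def ctc_greedy_decode(probs):
    """Best-path decode for one sequence; probs has shape (time_steps, 28)."""
    best = np.argmax(probs, axis=1).tolist()
    out = []
    prev = None
    for k in best:
        # Emit a symbol only when it differs from the previous step and is not blank
        if k != prev and k != BLANK:
            out.append(ALPHABET[k])
        prev = k
    return "".join(out)

# One-hot path "h h _ e l _ l o o", where "_" is the blank
path = [7, 7, BLANK, 4, 11, BLANK, 11, 14, 14]
decoded = ctc_greedy_decode(np.eye(BLANK + 1)[path])
```

The blank between the two "l" steps is what lets a doubled letter survive the merge-repeats rule.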

Additional notes:

  • I have no idea if it's learning general text shapes or just remembering fonts. This is more to demonstrate CTC with Keras and actually uses quite a few useful Keras features, including callbacks, functional layers, and dataset generators.
  • A wrapper is also included for Tensorflow's constrained dictionary decoding, but is not used in any tests or examples yet. This functionality would be useful for OCR or speech recognition.
  • For Theano, the following option is required: on_unused_input='ignore'. This is because the loss function has extra parameters which Keras doesn't support, so a dummy loss function was required to make Keras happy.
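One way to read the dummy-loss note above: the CTC cost needs labels and sequence lengths in addition to the predictions, so it can be computed inside the model, and the function handed to Keras as the loss simply passes that value through. The sketch below shows only that idea with NumPy stand-ins; `passthrough_loss` and the sample numbers are hypothetical.

```python
import numpy as np

def passthrough_loss(y_true, y_pred):
    """Dummy loss: y_pred is assumed to already be the per-sample CTC cost."""
    return y_pred

ctc_cost = np.array([[12.5], [3.1]])     # stand-in for a model output of shape (batch, 1)
dummy_targets = np.zeros_like(ctc_cost)  # ignored; present only to fit the (y_true, y_pred) signature
loss = passthrough_loss(dummy_targets, ctc_cost)
```

Because the targets are never read, any array of the right shape can be supplied during training.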

This is my first Github commit ever so please go easy on me if I mucked something up :)

@mbhenry mbhenry changed the title Added support support for CTC in both Theano and Tensorflow along with image OCR example. Added support for CTC in both Theano and Tensorflow along with image OCR example. Aug 10, 2016
@fchollet
Member

Sounds great! I'll review it tomorrow. CTC is definitely a much needed addition.

For the time being, one immediate comment: please do not commit data files into the git tree, rather put them online and have your script fetch them like so: https://github.com/fchollet/keras/blob/master/examples/lstm_text_generation.py#L23
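The fetch-instead-of-commit pattern suggested here can be sketched with the standard library alone. This is an illustration of the idea, not the Keras helper used in the linked script; `fetch_once` and its parameters are hypothetical names.

```python
import os
import urllib.request

def fetch_once(url, cache_dir, filename):
    """Download url into cache_dir/filename unless it is already there."""
    os.makedirs(cache_dir, exist_ok=True)
    path = os.path.join(cache_dir, filename)
    if not os.path.exists(path):
        # Only hit the network (or file system, for file:// URLs) on a cache miss
        urllib.request.urlretrieve(url, path)
    return path
```

A script would call it once at start-up with the hosted data URL and a cache directory, then read the returned path as if the file had been in the repository.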

# for the particular OS in use.
#
# This starts off with easy 5 letter words. After 10 or so epochs, CTC
# learn translataional invariance, so longer words and groups of words
Member

Typo: "translation invariance"

@mbhenry
Contributor Author

mbhenry commented Aug 15, 2016

Here's what's in the latest PR:

  • I'm pretty sure I knocked out all of the style and idiomatic issues you brought up.
  • I got rid of the line that shuffles the bigram file. This should eliminate the issue where validation words previously appear in training sets. Now the file is always read in order, but as the difficulty is ratcheted up, more words are pulled from the source files.
  • I made the learning rates for Tensorflow and Theano the same. This is more because I gave up trying to get the Theano results to match the Tensorflow results. I think more than just the LR needs to be tweaked to achieve that.

@fchollet
Member

Thanks! Style-wise: still lots of unused imports. Otherwise LGTM.

@mbhenry
Contributor Author

mbhenry commented Aug 16, 2016

Latest commit fixes those unused imports. This has been a learning experience....

@fchollet
Member

One last thing before I merge. Your commits are not associated with your Github email address. That means that your account won't be linked to the PR and you won't appear in the list of contributors. You may want to fix that (add your git email to your Github account, for instance).

self.X_text = []
self.Y_len = [0] * self.num_words

#monogram file is sorted by frequency in english speech
Member

In-line comments require one space after #

@fchollet
Member

LGTM

@fchollet fchollet merged commit e8190a8 into keras-team:master Aug 16, 2016
@fchollet
Member

Thanks for the great OCR example. Very valuable imo.

@fchollet
Member

We'll probably have to update the way to import CTC in TF. The current code appears to work with TF 0.9 but breaks in 0.10rc.

@mbhenry
Contributor Author

mbhenry commented Aug 16, 2016

Yea, they moved CTC from the experimental contrib area to core.util.ctc. I wasn't sure of the best "pythonic" way of checking multiple locations for an import.

@kanzure

kanzure commented Aug 20, 2016

I wasn't sure of the best "pythonic" way of checking multiple locations for an import.

Here's one way to do this:

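try:
    # Older location (per this thread, works with TF 0.9)
    import tensorflow.contrib.ctc as tfctc
except ImportError:
    # Newer location, as described earlier in this thread
    import tensorflow.core.util.ctc as tfctc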

Unfortunately this breaks symmetry with how tf.* is used elsewhere in the source code. A somewhat hacky way to overcome this would be:

import tensorflow as tf

try:
    import tensorflow.contrib.ctc as tfctc
except ImportError:
    import tensorflow.core.util.ctc as tfctc
finally:
    tf.ctc = tfctc
    del tfctc
# ... or place the assignment outside of the try/except block

Finally, another way to do this would be to make a Tensorflow compatibility layer elsewhere in the source code, and only access Tensorflow through that compatibility layer.
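A compatibility layer of that kind could be built around a small helper that tries several module paths in order. The sketch below uses `importlib` from the standard library; `import_first` is a hypothetical name, and the demo uses standard-library module names so that it runs without Tensorflow installed.

```python
import importlib

def import_first(candidates):
    """Return the first module in candidates that imports successfully."""
    errors = []
    for name in candidates:
        try:
            return importlib.import_module(name)
        except ImportError as exc:
            # Remember why this candidate failed, then try the next one
            errors.append("%s: %s" % (name, exc))
    raise ImportError("none of the candidates could be imported: " + "; ".join(errors))

# The first name does not exist, so the helper falls back to the second
mod = import_first(["no_such_module_for_demo", "json"])
```

In a compat module, the candidate list would hold the CTC module paths discussed in this thread, and the rest of the code base would import the resolved module from that one place.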

junwei-pan pushed a commit to junwei-pan/keras that referenced this pull request Aug 21, 2016
…OCR example. (keras-team#3436)

* Added CTC to Theano and Tensorflow backend along with image OCR example

* Fixed python style issues, made data files remote, and made code more idiomatic to Keras

* Fixed a couple more style issues brought up in the original PR

* Reverted wrappers.py

* Fixed potential training-on-validation issue and removed unused imports

* Fixed PEP8 issue

* Remaining PEP8 issues fixed
@BackT0TheFuture

BackT0TheFuture commented Dec 1, 2016

@mbhenry
image_ocr.py doesn't work on Windows due to cairo.
Maybe it would be better to generate the images with PIL (from PIL import Image, ImageFont, ImageDraw).
BTW, how can the width be made variable?

@mbhenry
Contributor Author

mbhenry commented Dec 2, 2016

It's on my list to look into both of those. Variable width would probably have to be with Dynamic RNNs. Right now I'm also working on improving the convergence stability. Slight disturbances seem to have big impacts on convergence.

@AvenSun

AvenSun commented Dec 5, 2016

@mbhenry
Great to hear that!
I found one CTC implementation based on Keras v1.0.6.
Maybe it's a good start!

@milani
Contributor

milani commented Jul 8, 2017

Is this loss function going to be documented? Since it is undocumented I assume it is experimental, yet it has been merged. Is it released at all?

@kushalkafle

kushalkafle commented Jul 12, 2017

@mbhenry, Thanks for a great example. I noticed that you 'skip' first couple letters generated in the code.
Could this be due to the fact that the early time-steps do not have enough information from the image? Right now, features in dense1 represent features for a "vertical slice" of image along the width of the image, and they are being used as timesteps input for RNN.
So it looks like, if the text does not reach the edges of the image on either side (the RNN is bidirectional), the earliest timesteps will have no information from which to start producing letters.

Is there a reason for doing this? Is traversing the image in order important for generating the letters left-to-right? Would it make more sense to have a single image representation (perhaps at the end of LSTM that has seen all the "slices" of image features) and use RepeatVector to feed this image information to each timestep in the RNN? (e.g., something a simple captioning model would do)
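The "vertical slice" reading described above can be made concrete with a shape-only NumPy sketch. The 512 x 64 image size comes from the PR description; the width-first layout, the two 2x pooling steps, and the filter count are assumptions chosen for illustration.

```python
import numpy as np

batch, img_w, img_h, filters = 2, 512, 64, 16  # filters is an assumed value
pool = 2 * 2                                   # two pooling layers of size 2, assumed

# Stand-in for the conv output, laid out as (batch, width, height, channels)
rng = np.random.default_rng(0)
feat = rng.random((batch, img_w // pool, img_h // pool, filters))

# Flatten height and channels: each remaining width position becomes one time step
seq = feat.reshape(batch, img_w // pool, (img_h // pool) * filters)
time_steps, features_per_step = seq.shape[1], seq.shape[2]
```

Under these assumptions every time step sees one narrow column of the image, so a step covering blank margin carries no character evidence, which is the situation the question above is asking about.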

time_steps = img_w / (pool_size_1 * pool_size_2)

fdir = os.path.dirname(get_file('wordlists.tgz',
origin='http://www.isosemi.com/datasets/wordlists.tgz', untar=True))
Contributor

@milani milani Jul 12, 2017

Seems the http server tries to redirect http to https, but in the process, it removes a slash so it becomes: 'http://www.isosemi.comdatasets/wordlists.tgz' which is an invalid address.


# transforms RNN output to character activations:
inner = TimeDistributed(Dense(img_gen.get_output_size(), name='dense2'))(merge([gru_2, gru_2b], mode='concat'))
y_pred = Activation('softmax', name='softmax')(inner)
Contributor

Should the activation be set according to the backend in use? Tensorflow's documentation reads[1]:

This class performs the softmax operation for you, so inputs should be e.g. linear projections of outputs by an LSTM

[1] https://www.tensorflow.org/api_docs/python/tf/nn/ctc_loss
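The concern about a doubled softmax can be checked numerically. The NumPy sketch below is independent of either backend and makes no claim about what the Keras wrapper does internally; it only shows that applying softmax to probabilities flattens them, while applying softmax to log-probabilities gives the original probabilities back.

```python
import numpy as np

def softmax(x):
    """Numerically stable softmax over the last axis."""
    e = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

logits = np.array([4.0, 1.0, 0.0])
p = softmax(logits)
p_twice = softmax(p)            # what a second, internal softmax would produce
p_via_log = softmax(np.log(p))  # softmax of log-probabilities recovers p
```

So if a loss applies softmax internally, it should be fed either raw linear outputs or the log of an existing softmax output, not the softmax output itself.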

@nouiz
Contributor

nouiz commented Aug 10, 2017

Note: we merged an op for the CTC from Baidu into Theano master:

http://deeplearning.net/software/theano_versions/dev/library/gpuarray/ctc.html


9 participants