Skip to content

Commit

Permalink
added dump option for NN output
Browse files Browse the repository at this point in the history
  • Loading branch information
githubharald committed May 25, 2019
1 parent f2f606d commit e8249f7
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 5 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,5 @@ model/snapshot-*
notes/
*.so
*.pyc
.idea/
.idea/
dump/
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ Tested with:
* `--validate`: validate the NN; for details, see below.
* `--beamsearch`: use vanilla beam search decoding (better, but slower) instead of best path decoding.
* `--wordbeamsearch`: use word beam search decoding (only outputs words contained in a dictionary) instead of best path decoding. This is a custom TF operation and must be compiled from source; for more information, see the corresponding section below. It should **not** be used when training the NN.
* `--dump`: dumps the output of the NN to CSV file(s) saved in the `dump/` folder. Can be used as input for the [CTCDecoder](https://github.com/githubharald/CTCDecoder).

If neither `--train` nor `--validate` is specified, the NN infers the text from the test image (`data/test.png`).
Two examples: if you want to infer using beam search, execute `python main.py --beamsearch`, while you have to execute `python main.py --train --beamsearch` if you want to train the NN and do the validation using beam search.
Expand Down
34 changes: 31 additions & 3 deletions src/Model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import numpy as np
import tensorflow as tf
import os


class DecoderType:
Expand All @@ -20,8 +21,9 @@ class Model:
imgSize = (128, 32)
maxTextLen = 32

def __init__(self, charList, decoderType=DecoderType.BestPath, mustRestore=False):
def __init__(self, charList, decoderType=DecoderType.BestPath, mustRestore=False, dump=False):
"init model: add CNN, RNN and CTC and initialize TF"
self.dump = dump
self.charList = charList
self.decoderType = decoderType
self.mustRestore = mustRestore
Expand Down Expand Up @@ -217,14 +219,35 @@ def trainBatch(self, batch):
return lossVal


def dumpNNOutput(self, rnnOutput, dumpDir='../dump/'):
    """Dump the output of the NN to CSV file(s), one file per batch element.

    Args:
        rnnOutput: 3-dimensional array-like exposing ``shape`` and supporting
            ``rnnOutput[t, b, c]`` indexing. The second axis selects the
            batch element; each first-axis index becomes one CSV row and
            each third-axis index one ``;``-terminated column.
        dumpDir: directory the files are written to. It is created
            (including missing parents) if it does not exist. Defaults to
            the original hard-coded location ``'../dump/'``.

    Side effects:
        Writes ``rnnOutput_<b>.csv`` into ``dumpDir`` for every batch index
        ``b`` and prints the name of each file written.
    """
    # makedirs(exist_ok=True) replaces the isdir()/mkdir() pair: it cannot
    # fail when the directory appears between check and creation, and it
    # also creates missing parent directories.
    os.makedirs(dumpDir, exist_ok=True)

    # iterate over all batch elements and create a CSV file for each one
    maxT, maxB, maxC = rnnOutput.shape
    for b in range(maxB):
        # Build each row with join instead of repeated '+=' on one string.
        # Every value keeps its trailing ';' so the file format is unchanged.
        rows = [
            ''.join(str(rnnOutput[t, b, c]) + ';' for c in range(maxC)) + '\n'
            for t in range(maxT)
        ]
        fn = os.path.join(dumpDir, 'rnnOutput_' + str(b) + '.csv')
        print('Write dump of NN to file: ' + fn)
        with open(fn, 'w') as f:
            f.writelines(rows)


def inferBatch(self, batch, calcProbability=False, probabilityOfGT=False):
"feed a batch into the NN to recognize the texts"

# decode, optionally save RNN output
numBatchElements = len(batch.imgs)
evalList = [self.decoder] + ([self.ctcIn3dTBC] if calcProbability else [])
evalRnnOutput = self.dump or calcProbability
evalList = [self.decoder] + ([self.ctcIn3dTBC] if evalRnnOutput else [])
feedDict = {self.inputImgs : batch.imgs, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False}
evalRes = self.sess.run([self.decoder, self.ctcIn3dTBC], feedDict)
evalRes = self.sess.run(evalList, feedDict)
decoded = evalRes[0]
texts = self.decoderOutputToText(decoded, numBatchElements)

Expand All @@ -237,6 +260,11 @@ def inferBatch(self, batch, calcProbability=False, probabilityOfGT=False):
feedDict = {self.savedCtcInput : ctcInput, self.gtTexts : sparse, self.seqLen : [Model.maxTextLen] * numBatchElements, self.is_train: False}
lossVals = self.sess.run(evalList, feedDict)
probs = np.exp(-lossVals)

# dump the output of the NN to CSV file(s)
if self.dump:
self.dumpNNOutput(evalRes[1])

return (texts, probs)


Expand Down
4 changes: 3 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,8 @@ def main():
parser.add_argument('--validate', help='validate the NN', action='store_true')
parser.add_argument('--beamsearch', help='use beam search instead of best path decoding', action='store_true')
parser.add_argument('--wordbeamsearch', help='use word beam search instead of best path decoding', action='store_true')
parser.add_argument('--dump', help='dump output of NN to CSV file(s)', action='store_true')

args = parser.parse_args()

decoderType = DecoderType.BestPath
Expand Down Expand Up @@ -135,7 +137,7 @@ def main():
# infer text on test image
else:
print(open(FilePaths.fnAccuracy).read())
model = Model(open(FilePaths.fnCharList).read(), decoderType, mustRestore=True)
model = Model(open(FilePaths.fnCharList).read(), decoderType, mustRestore=True, dump=args.dump)
infer(model, FilePaths.fnInfer)


Expand Down

0 comments on commit e8249f7

Please sign in to comment.