Skip to content

Commit

Permalink
fix: add ctc loss configuration prams.
Browse files Browse the repository at this point in the history
  • Loading branch information
kerlomz committed Dec 29, 2018
1 parent 3680d9a commit 1e1b9a3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 4 additions & 0 deletions config.py
Expand Up @@ -158,6 +158,10 @@ def char_set(_type):
TEST_BATCH_SIZE = cf_system['Trains'].get('TestBatchSize')
TEST_BATCH_SIZE = TEST_BATCH_SIZE if TEST_BATCH_SIZE else 200
MOMENTUM = 0.9
PREPROCESS_COLLAPSE_REPEATED = cf_system['Trains'].get('PreprocessCollapseRepeated')
PREPROCESS_COLLAPSE_REPEATED = PREPROCESS_COLLAPSE_REPEATED if PREPROCESS_COLLAPSE_REPEATED is not None else False
CTC_MERGE_REPEATED = cf_system['Trains'].get('CTCMergeRepeated')
CTC_MERGE_REPEATED = CTC_MERGE_REPEATED if CTC_MERGE_REPEATED is not None else True

"""PRETREATMENT"""
BINARYZATION = cf_model['Pretreatment'].get('Binaryzation')
Expand Down
4 changes: 2 additions & 2 deletions framework.py
Expand Up @@ -149,8 +149,8 @@ def _build_train_op(self):
labels=self.labels,
inputs=self.predict,
sequence_length=self.seq_len,
ctc_merge_repeated=True,
preprocess_collapse_repeated=False
ctc_merge_repeated=CTC_MERGE_REPEATED,
preprocess_collapse_repeated=PREPROCESS_COLLAPSE_REPEATED
)
self.cost = tf.reduce_mean(self.loss)
tf.summary.scalar('cost', self.cost)
Expand Down

0 comments on commit 1e1b9a3

Please sign in to comment.