From 3f0444a741509e84148d7fbd4fb160f613b95bcc Mon Sep 17 00:00:00 2001 From: Tom Ko Date: Wed, 30 Nov 2016 22:08:51 -0500 Subject: [PATCH] Fix bug discovered by TDNN decoding script --- src/nnet3/decodable-simple-looped.cc | 25 +++++++++++++++++++------ src/nnet3/decodable-simple-looped.h | 5 +++++ src/nnet3/nnet-compile-looped.cc | 14 ++++++++------ 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/src/nnet3/decodable-simple-looped.cc b/src/nnet3/decodable-simple-looped.cc index 0df4c3b6c31..bb9a38632a1 100644 --- a/src/nnet3/decodable-simple-looped.cc +++ b/src/nnet3/decodable-simple-looped.cc @@ -72,7 +72,6 @@ void DecodableNnetSimpleLoopedInfo::Init( if (has_ivectors_) ModifyNnetIvectorPeriod(ivector_period, nnet); - ComputationRequest request1, request2, request3; int32 num_sequences = 1; // we're processing one utterance at a time. int32 extra_right_context = 0; CreateLoopedComputationRequestSimple(*nnet, frames_per_chunk_, @@ -80,9 +79,9 @@ void DecodableNnetSimpleLoopedInfo::Init( ivector_period, opts.extra_left_context_initial, extra_right_context, num_sequences, - &request1, &request2, &request3); + &request1_, &request2_, &request3_); - CompileLooped(*nnet, opts_.optimize_config, request1, request2, request3, + CompileLooped(*nnet, opts_.optimize_config, request1_, request2_, request3_, &computation_); computation_.ComputeCudaIndexes(); if (GetVerboseLevel() >= 3) { @@ -172,11 +171,25 @@ void DecodableNnetSimpleLooped::AdvanceChunk() { computer_.AcceptInput("input", &feats_chunk); if (info_.has_ivectors_) { + KALDI_ASSERT(info_.request1_.inputs.size() == 2); + // all but the 1st chunk should have 1 iVector, but no need + // to assume this. + int32 num_ivectors = (num_chunks_computed_ == 0 ? + info_.request1_.inputs[1].indexes.size() : + info_.request2_.inputs[1].indexes.size()); + KALDI_ASSERT(num_ivectors > 0); + Vector ivector; + // we just get the iVector from the last input frame we needed... + // we don't bother trying to be 'accurate' in getting the iVectors + // for their 'correct' frames, because in general using the + // iVector from as large 't' as possible will be better. GetCurrentIvector(end_input_frame, &ivector); - CuMatrix cu_ivector(1, ivector.Dim()); - cu_ivector.Row(0).CopyFromVec(ivector); - computer_.AcceptInput("ivector", &cu_ivector); + Matrix ivectors(num_ivectors, + ivector.Dim()); + ivectors.CopyRowsFromVec(ivector); + CuMatrix cu_ivectors(ivectors); + computer_.AcceptInput("ivector", &cu_ivectors); } computer_.Run(); diff --git a/src/nnet3/decodable-simple-looped.h b/src/nnet3/decodable-simple-looped.h index fe40c220f8f..5aba5b10505 100644 --- a/src/nnet3/decodable-simple-looped.h +++ b/src/nnet3/decodable-simple-looped.h @@ -148,6 +148,11 @@ class DecodableNnetSimpleLoopedInfo { // to accept the iVectors bool has_ivectors_; + // The 3 computation requests that are used to create the looped + // computation are stored in the class, as we need them to work out + // exactly shich iVectors are needed. + ComputationRequest request1_, request2_, request3_; + // The compiled, 'looped' computation. NnetComputation computation_; }; diff --git a/src/nnet3/nnet-compile-looped.cc b/src/nnet3/nnet-compile-looped.cc index 71329d2e8fe..d77f19ef13c 100644 --- a/src/nnet3/nnet-compile-looped.cc +++ b/src/nnet3/nnet-compile-looped.cc @@ -80,8 +80,9 @@ int32 GetChunkSize(const Nnet &nnet, /// for negative a is not specified (except by relation with the division '/' /// operator), but in practice it would be <= 0 for almost all implementations. template I Mod(I m, I n) { - if (m >= 0) return m % n; - else return -((-m) % n); + I ans = m % n; + if (ans < 0) ans += n; + return ans; } @@ -171,15 +172,16 @@ void CreateLoopedComputationRequestSimple(const Nnet &nnet, } for (int32 t = chunk2_input_begin_t; t < chunk2_input_end_t; t++) { int32 ivector_t = t - Mod(t, ivector_period); - if (ivector_times1.count(ivector_t) == 0) + if (ivector_times2.count(ivector_t) == 0 && + ivector_times1.count(ivector_t) == 0) ivector_times2.insert(ivector_t); } for (int32 t = chunk3_input_begin_t; t < chunk3_input_end_t; t++) { int32 ivector_t = t - Mod(t, ivector_period); - if (ivector_times1.count(ivector_t) == 0 && - ivector_times2.count(ivector_t) == 0) { + if (ivector_times3.count(ivector_t) == 0 && + ivector_times2.count(ivector_t) == 0 && + ivector_times1.count(ivector_t) == 0) ivector_times3.insert(ivector_t); - } } }