From 51cac1031b69bf07193821fe8b6fbbc97bb9ef8a Mon Sep 17 00:00:00 2001 From: Reuben Morais Date: Thu, 2 Aug 2018 23:50:35 -0300 Subject: [PATCH] Address review comments --- DeepSpeech.py | 10 -- bin/ops_in_graph.py | 2 +- evaluate.py | 30 ++--- native_client/args.h | 22 +--- native_client/beam_search.h | 2 +- native_client/client.cc | 24 ++-- native_client/deepspeech.cc | 157 +++++++++++++------------- native_client/deepspeech.h | 73 +++++++----- native_client/javascript/client.js | 2 +- native_client/javascript/deepspeech.i | 12 +- native_client/javascript/index.js | 7 +- native_client/python/__init__.py | 4 +- native_client/python/impl.i | 2 +- tc-tests-utils.sh | 2 +- util/audio.py | 1 - 15 files changed, 177 insertions(+), 173 deletions(-) diff --git a/DeepSpeech.py b/DeepSpeech.py index 5219c6a60a..223f207871 100755 --- a/DeepSpeech.py +++ b/DeepSpeech.py @@ -131,7 +131,6 @@ def create_flags(): # Initialization tf.app.flags.DEFINE_integer ('random_seed', 4567, 'default random seed that is used to initialize variables') - tf.app.flags.DEFINE_float ('default_stddev', 0.046875, 'default standard deviation to use when initialising weights and biases') # Early Stopping @@ -163,9 +162,6 @@ def create_flags(): tf.app.flags.DEFINE_string ('initialize_from_frozen_model', '', 'path to frozen model to initialize from. This behaves like a checkpoint, loading the weights from the frozen model and starting training with those weights. The optimizer parameters aren\'t restored, so remember to adjust the learning rate.') - for var in ['b1', 'h1', 'b2', 'h2', 'b3', 'h3', 'b5', 'h5', 'b6', 'h6']: - tf.app.flags.DEFINE_float('%s_stddev' % var, None, 'standard deviation to use when initialising %s' % var) - FLAGS = tf.app.flags.FLAGS def initialize_globals(): @@ -285,12 +281,6 @@ def initialize_globals(): global n_hidden_6 n_hidden_6 = n_character - # Assign default values for standard deviation - for var in ['b1', 'h1', 'b2', 'h2', 'b3', 'h3', 'b5', 'h5', 'b6', 'h6']: - val = getattr(FLAGS, '%s_stddev' % var) - if val is None: - setattr(FLAGS, '%s_stddev' % var, FLAGS.default_stddev) - # Queues that are used to gracefully stop parameter servers. # Each queue stands for one ps. A finishing worker sends a token to each queue before joining/quitting. # Each ps will dequeue as many tokens as there are workers before joining/quitting. 
diff --git a/bin/ops_in_graph.py b/bin/ops_in_graph.py index ee0f2ff54d..9078a91d2e 100755 --- a/bin/ops_in_graph.py +++ b/bin/ops_in_graph.py @@ -6,6 +6,6 @@ with tf.gfile.FastGFile(sys.argv[1], 'rb') as fin: graph_def = tf.GraphDef() - graph_def.MergeFromString(fin.read()) + graph_def.ParseFromString(fin.read()) print('\n'.join(sorted(set(n.op for n in graph_def.node)))) diff --git a/evaluate.py b/evaluate.py index 6115a3e5c6..5725b23c64 100755 --- a/evaluate.py +++ b/evaluate.py @@ -84,26 +84,28 @@ def preprocess(dataset_files, batch_size, hdf5_dest_path=None): features, features_len, transcript, transcript_len = zip(*out_data) with tables.open_file(hdf5_dest_path, 'w') as file: - features_dset = file.create_vlarray(file.root, 'features', - tables.Float32Atom(shape=()), filters=tables.Filters(complevel=1)) + features_dset = file.create_vlarray(file.root, + 'features', + tables.Float32Atom(), + filters=tables.Filters(complevel=1)) # VLArray atoms need to be 1D, so flatten feature array for f in features: features_dset.append(np.reshape(f, -1)) - features_len_dset = file.create_array( - file.root, 'features_len', features_len) + features_len_dset = file.create_array(file.root, + 'features_len', + features_len) - transcript_dset = file.create_vlarray( - file.root, - 'transcript', - tables.Int32Atom(), - filters=tables.Filters( - complevel=1)) + transcript_dset = file.create_vlarray(file.root, + 'transcript', + tables.Int32Atom(), + filters=tables.Filters(complevel=1)) for t in transcript: transcript_dset.append(t) - transcript_len_dset = file.create_array( - file.root, 'transcript_len', transcript_len) + transcript_len_dset = file.create_array(file.root, + 'transcript_len', + transcript_len) return pandas.DataFrame(data=out_data, columns=COLUMNS) @@ -159,8 +161,8 @@ def calculate_report(labels, decodings, distances, losses): # Order the remaining items by their loss (lowest loss on top) samples.sort(key=lambda s: s.loss) - # Then order by WER (lowest WER on top) - samples.sort(key=lambda s: s.wer) + # Then order by WER (highest WER on top) + samples.sort(key=lambda s: s.wer, reverse=True) return samples_wer, samples diff --git a/native_client/args.h b/native_client/args.h index c4b1438b9d..27899c182a 100644 --- a/native_client/args.h +++ b/native_client/args.h @@ -6,20 +6,15 @@ #include "deepspeech.h" -bool has_model = false; -char* model; +char* model = NULL; -bool has_alphabet = false; -char* alphabet; +char* alphabet = NULL; -bool has_lm = false; -char* lm; +char* lm = NULL; -bool has_trie = false; -char* trie; +char* trie = NULL; -bool has_audio = false; -char* audio; +char* audio = NULL; bool show_times = false; @@ -70,27 +65,22 @@ bool ProcessArgs(int argc, char** argv) { case 'm': model = optarg; - has_model = true; break; case 'a': alphabet = optarg; - has_alphabet = true; break; case 'l': lm = optarg; - has_lm = true; break; case 'r': trie = optarg; - has_trie = true; break; case 'w': audio = optarg; - has_audio = true; break; case 't': @@ -114,7 +104,7 @@ bool ProcessArgs(int argc, char** argv) return false; } - if (!has_model || !has_alphabet || !has_audio || strlen(alphabet) == 0 || strlen(audio) == 0) { + if (!model || !alphabet || !audio) { PrintHelp(argv[0]); return false; } diff --git a/native_client/beam_search.h b/native_client/beam_search.h index c72e51625a..dab58db0ea 100644 --- a/native_client/beam_search.h +++ b/native_client/beam_search.h @@ -119,7 +119,7 @@ class KenLMBeamScorer : public tensorflow::ctc::BaseBeamScorer { // score to this beam's score. 
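
Note on the evaluate.py report ordering changed above: samples are first sorted by loss and then stable-sorted by WER in descending order, so the report leads with the highest-WER samples and breaks ties with the lowest loss. The real code is the Python shown in the hunk; the following is only a minimal C++ sketch of the same two-key stable ordering, with Sample as a stand-in type.

#include <algorithm>
#include <vector>

// Stand-in for the per-sample entries built in calculate_report().
struct Sample {
  float wer;   // word error rate of this sample
  float loss;  // CTC loss of this sample
};

void order_report(std::vector<Sample>& samples) {
  // First pass: lowest loss on top.
  std::stable_sort(samples.begin(), samples.end(),
                   [](const Sample& a, const Sample& b) { return a.loss < b.loss; });
  // Second pass: highest WER on top. Because the sort is stable, samples with
  // equal WER keep the lowest-loss-first order from the first pass.
  std::stable_sort(samples.begin(), samples.end(),
                   [](const Sample& a, const Sample& b) { return a.wer > b.wer; });
}
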
state->score += lm_weight_ * state->delta_score; if (state->num_words > 0) { - float normalized_score = state->score /= (float)state->num_words; + float normalized_score = state->score / (float)state->num_words; state->delta_score = normalized_score - state->score; } } diff --git a/native_client/client.cc b/native_client/client.cc index a7e947bf3d..d477284259 100644 --- a/native_client/client.cc +++ b/native_client/client.cc @@ -32,7 +32,7 @@ typedef struct { } ds_result; ds_result -LocalDsSTT(ModelState* aCtx, short* aBuffer, size_t aBufferSize, +LocalDsSTT(ModelState* aCtx, const short* aBuffer, size_t aBufferSize, int aSampleRate) { ds_result res = {0}; @@ -170,7 +170,7 @@ ProcessFile(ModelState* context, const char* path, bool show_times) // We take half of buffer_size because buffer is a char* while // LocalDsSTT() expected a short* ds_result result = LocalDsSTT(context, - (short*)audio.buffer, + (const short*)audio.buffer, audio.buffer_size / 2, audio.sample_rate); free(audio.buffer); @@ -197,17 +197,21 @@ main(int argc, char **argv) ModelState* ctx; int status = DS_CreateModel(model, N_CEP, N_CONTEXT, alphabet, BEAM_WIDTH, &ctx); if (status != 0) { + fprintf(stderr, "Could not create model.\n"); return 1; } - if (has_lm && has_trie) { - DS_EnableDecoderWithLM( - ctx, - alphabet, - lm, - trie, - LM_WEIGHT, - VALID_WORD_COUNT_WEIGHT); + if (lm && trie) { + int status = DS_EnableDecoderWithLM(ctx, + alphabet, + lm, + trie, + LM_WEIGHT, + VALID_WORD_COUNT_WEIGHT); + if (status != 0) { + fprintf(stderr, "Could not enable CTC decoder with LM.\n"); + return 1; + } } // Initialise SOX diff --git a/native_client/deepspeech.cc b/native_client/deepspeech.cc index b4486a53af..731c75da61 100644 --- a/native_client/deepspeech.cc +++ b/native_client/deepspeech.cc @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -26,23 +27,23 @@ #include "c_speech_features.h" //TODO: infer batch size from model/use dynamic batch size -const int BATCH_SIZE = 1; +const unsigned int BATCH_SIZE = 1; //TODO: use dynamic sample rate -const int SAMPLE_RATE = 16000; +const unsigned int SAMPLE_RATE = 16000; const float AUDIO_WIN_LEN = 0.025f; const float AUDIO_WIN_STEP = 0.01f; -const int AUDIO_WIN_LEN_SAMPLES = (int)(AUDIO_WIN_LEN * SAMPLE_RATE); -const int AUDIO_WIN_STEP_SAMPLES = (int)(AUDIO_WIN_STEP * SAMPLE_RATE); +const unsigned int AUDIO_WIN_LEN_SAMPLES = (unsigned int)(AUDIO_WIN_LEN * SAMPLE_RATE); +const unsigned int AUDIO_WIN_STEP_SAMPLES = (unsigned int)(AUDIO_WIN_STEP * SAMPLE_RATE); -const int MFCC_FEATURES = 26; +const unsigned int MFCC_FEATURES = 26; const float PREEMPHASIS_COEFF = 0.97f; -const int N_FFT = 512; -const int N_FILTERS = 26; -const int LOWFREQ = 0; -const int CEP_LIFTER = 22; +const unsigned int N_FFT = 512; +const unsigned int N_FILTERS = 26; +const unsigned int LOWFREQ = 0; +const unsigned int CEP_LIFTER = 22; using namespace tensorflow; using tensorflow::ctc::CTCBeamSearchDecoder; @@ -50,36 +51,36 @@ using tensorflow::ctc::CTCDecoder; using std::vector; +/* This is the actual implementation of the streaming inference API, with the + Model class just forwarding the calls to this class. + + The streaming process uses three buffers that are fed eagerly as audio data + is fed in. The buffers only hold the minimum amount of data needed to do a + step in the acoustic model. The three buffers which live in StreamingContext + are: + + - audio_buffer, used to buffer audio samples until there's enough data to + compute input features for a single window. 
+ + - mfcc_buffer, used to buffer input features until there's enough data for + a single timestep. Remember there's overlap in the features, each timestep + contains n_context past feature frames, the current feature frame, and + n_context future feature frames, for a total of 2*n_context + 1 feature + frames per timestep. + + - batch_buffer, used to buffer timesteps until there's enough data to compute + a batch of n_steps. + + Data flows through all three buffers as audio samples are fed via the public + API. When audio_buffer is full, features are computed from it and pushed to + mfcc_buffer. When mfcc_buffer is full, the timestep is copied to batch_buffer. + When batch_buffer is full, we do a single step through the acoustic model + and accumulate results in StreamingState::accumulated_logits. + + When fininshStream() is called, we decode the accumulated logits and return + the corresponding transcription. +*/ struct StreamingState { - /* This is the actual implementation of the streaming inference API, with the - Model class just forwarding the calls to this class. - - The streaming process uses three buffers that are fed eagerly as audio data - is fed in. The buffers only hold the minimum amount of data needed to do a - step in the acoustic model. The three buffers which live in StreamingContext - are: - - - audio_buffer, used to buffer audio samples until there's enough data to - compute input features for a single window. - - - mfcc_buffer, used to buffer input features until there's enough data for - a single timestep. Remember there's overlap in the features, each timestep - contains n_context past feature frames, the current feature frame, and - n_context future feature frames, for a total of 2*n_context + 1 feature - frames per timestep. - - - batch_buffer, used to buffer timesteps until there's enough data to compute - a batch of n_steps. - - Data flows through all three buffers as audio samples are fed via the public - API. When audio_buffer is full, features are computed from it and pushed to - mfcc_buffer. When mfcc_buffer is full, the timestep is copied to batch_buffer. - When batch_buffer is full, we do a single step through the acoustic model - and accumulate results in StreamingState::accumulated_logits. - - When fininshStream() is called, we decode the accumulated logits and return - the corresponding transcription. - */ vector accumulated_logits; vector audio_buffer; float last_sample; // used for preemphasis @@ -103,15 +104,15 @@ struct ModelState { MemmappedEnv* mmap_env; Session* session; GraphDef graph_def; - int ncep; - int ncontext; + unsigned int ncep; + unsigned int ncontext; Alphabet* alphabet; KenLMBeamScorer* scorer; - int beam_width; + unsigned int beam_width; bool run_aot; - int n_steps; - int mfcc_feats_per_timestep; - int n_context; + unsigned int n_steps; + unsigned int mfcc_feats_per_timestep; + unsigned int n_context; ModelState(); ~ModelState(); @@ -120,7 +121,6 @@ struct ModelState { * @brief Perform decoding of the logits, using basic CTC decoder or * CTC decoder with KenLM enabled * - * @param n_frames Number of timesteps to deal with * @param logits Flat matrix of logits, of size: * n_frames * batch_size * num_classes * @@ -136,10 +136,9 @@ struct ModelState { * @param mfcc batch input data * @param n_frames number of timesteps in the data * - * @param[out] output_logits Should be large enough to fit - * aNFrames * alphabet_size floats. + * @param[out] output_logits Where to store computed logits. 
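
To put numbers on the three buffers described in the comment above: with the constants from this file (16kHz audio, 25ms windows, 10ms steps, 26 MFCC features), a window is 400 samples and one timestep is 26 * (2*n_context + 1) floats. This is only a sketch; n_context and n_steps are read from the model graph at load time, so the 9 and 16 below are assumed example values, not constants from this patch.

#include <cstdio>

int main() {
  const unsigned int SAMPLE_RATE = 16000;
  const float AUDIO_WIN_LEN = 0.025f, AUDIO_WIN_STEP = 0.01f;
  const unsigned int MFCC_FEATURES = 26;
  const unsigned int n_context = 9;   // assumed example value, comes from the graph
  const unsigned int n_steps = 16;    // assumed example value, comes from the graph

  // audio_buffer: one analysis window of samples.
  unsigned int win_len_samples = (unsigned int)(AUDIO_WIN_LEN * SAMPLE_RATE);    // 400
  unsigned int win_step_samples = (unsigned int)(AUDIO_WIN_STEP * SAMPLE_RATE);  // 160

  // mfcc_buffer: one timestep = 2*n_context + 1 overlapping feature frames.
  unsigned int mfcc_feats_per_timestep = MFCC_FEATURES * (2 * n_context + 1);    // 494

  // batch_buffer: n_steps timesteps make one step through the acoustic model.
  unsigned int batch_floats = n_steps * mfcc_feats_per_timestep;                 // 7904

  std::printf("window: %u samples (step %u), timestep: %u floats, batch: %u floats\n",
              win_len_samples, win_step_samples, mfcc_feats_per_timestep, batch_floats);
  return 0;
}
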
*/ - void infer(const float* mfcc, int n_frames, vector& output_logits); + void infer(const float* mfcc, unsigned int n_frames, vector& output_logits); }; ModelState::ModelState() @@ -294,7 +293,7 @@ StreamingState::processBatch(const vector& buf, unsigned int n_steps) } void -ModelState::infer(const float* aMfcc, int n_frames, vector& logits_output) +ModelState::infer(const float* aMfcc, unsigned int n_frames, vector& logits_output) { const size_t num_classes = alphabet->GetSize() + 1; // +1 for blank @@ -410,14 +409,14 @@ ModelState::decode(vector& logits) } int -DS_CreateModel(char* aModelPath, - int aNCep, - int aNContext, - char* aAlphabetConfigPath, - int aBeamWidth, +DS_CreateModel(const char* aModelPath, + unsigned int aNCep, + unsigned int aNContext, + const char* aAlphabetConfigPath, + unsigned int aBeamWidth, ModelState** retval) { - ModelState* model = new ModelState(); + std::unique_ptr model(new ModelState()); model->mmap_env = new MemmappedEnv(Env::Default()); model->ncep = aNCep; model->ncontext = aNContext; @@ -469,14 +468,12 @@ DS_CreateModel(char* aModelPath, } if (!status.ok()) { std::cerr << status << std::endl; - delete model; return status.code(); } status = model->session->Create(model->graph_def); if (!status.ok()) { std::cerr << status << std::endl; - delete model; return status.code(); } @@ -502,7 +499,6 @@ DS_CreateModel(char* aModelPath, << " classes in its output. Make sure you're passing an alphabet " << "file with the same size as the one used for training." << std::endl; - delete model; return error::INVALID_ARGUMENT; } } @@ -515,11 +511,10 @@ DS_CreateModel(char* aModelPath, << "changed the number of features in the input, adjust the " << "MFCC_FEATURES constant in " __FILE__ << std::endl; - delete model; return error::INVALID_ARGUMENT; } - *retval = model; + *retval = model.release(); return tensorflow::error::OK; } @@ -529,26 +524,31 @@ DS_DestroyModel(ModelState* ctx) delete ctx; } -void +int DS_EnableDecoderWithLM(ModelState* aCtx, - char* aAlphabetConfigPath, - char* aLMPath, - char* aTriePath, + const char* aAlphabetConfigPath, + const char* aLMPath, + const char* aTriePath, float aLMWeight, float aValidWordCountWeight) { - aCtx->scorer = new KenLMBeamScorer(aLMPath, aTriePath, aAlphabetConfigPath, - aLMWeight, aValidWordCountWeight); + try { + aCtx->scorer = new KenLMBeamScorer(aLMPath, aTriePath, aAlphabetConfigPath, + aLMWeight, aValidWordCountWeight); + return 0; + } catch (...) { + return 1; + } } char* DS_SpeechToText(ModelState* aCtx, - short* aBuffer, + const short* aBuffer, unsigned int aBufferSize, - int aSampleRate) + unsigned int aSampleRate) { StreamingState* ctx; - int status = DS_SetupStream(aCtx, 150, aSampleRate, &ctx); + int status = DS_SetupStream(aCtx, 0, aSampleRate, &ctx); if (status != tensorflow::error::OK) { return nullptr; } @@ -570,7 +570,7 @@ DS_SetupStream(ModelState* aCtx, return status.code(); } - StreamingState* ctx = new StreamingState; + std::unique_ptr ctx(new StreamingState()); if (!ctx) { std::cerr << "Could not allocate streaming state." << std::endl; return status.code(); @@ -578,6 +578,11 @@ DS_SetupStream(ModelState* aCtx, const size_t num_classes = aCtx->alphabet->GetSize() + 1; // +1 for blank + // Default initial allocation = 3 seconds. 
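
With DS_EnableDecoderWithLM now reporting failure through its return value (and DS_CreateModel already doing so), a caller can bail out cleanly instead of continuing with a half-initialized decoder. The following is a sketch of the caller side mirroring the client.cc change earlier in this patch; the numeric constants are example values, not taken from this diff.

#include <cstdio>

#include "deepspeech.h"

int setup_model(const char* model, const char* alphabet,
                const char* lm, const char* trie, ModelState** ctx) {
  const unsigned int N_CEP = 26, N_CONTEXT = 9, BEAM_WIDTH = 500;   // example values
  const float LM_WEIGHT = 1.75f, VALID_WORD_COUNT_WEIGHT = 1.00f;   // example values

  int status = DS_CreateModel(model, N_CEP, N_CONTEXT, alphabet, BEAM_WIDTH, ctx);
  if (status != 0) {
    std::fprintf(stderr, "Could not create model.\n");
    return status;
  }

  // Only wire up the language model when both files were given; a non-zero
  // return now signals that the scorer could not be constructed.
  if (lm && trie) {
    status = DS_EnableDecoderWithLM(*ctx, alphabet, lm, trie,
                                    LM_WEIGHT, VALID_WORD_COUNT_WEIGHT);
    if (status != 0) {
      std::fprintf(stderr, "Could not enable CTC decoder with LM.\n");
      DS_DestroyModel(*ctx);
      return status;
    }
  }
  return 0;
}
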
+ if (aPreAllocFrames == 0) { + aPreAllocFrames = 150; + } + ctx->accumulated_logits.reserve(aPreAllocFrames * BATCH_SIZE * num_classes); ctx->audio_buffer.reserve(AUDIO_WIN_LEN_SAMPLES); @@ -590,13 +595,13 @@ DS_SetupStream(ModelState* aCtx, ctx->model = aCtx; - *retval = ctx; + *retval = ctx.release(); return tensorflow::error::OK; } void DS_FeedAudioContent(StreamingState* aSctx, - short* aBuffer, + const short* aBuffer, unsigned int aBufferSize) { aSctx->feedAudioContent(aBuffer, aBufferSize); @@ -617,11 +622,11 @@ DS_FinishStream(StreamingState* aSctx) } void -DS_AudioToInputVector(short* aBuffer, +DS_AudioToInputVector(const short* aBuffer, unsigned int aBufferSize, - int aSampleRate, - int aNCep, - int aNContext, + unsigned int aSampleRate, + unsigned int aNCep, + unsigned int aNContext, float** aMfcc, int* aNFrames, int* aFrameLen) diff --git a/native_client/deepspeech.h b/native_client/deepspeech.h index 4eea35bb15..5f9d603a43 100644 --- a/native_client/deepspeech.h +++ b/native_client/deepspeech.h @@ -27,11 +27,11 @@ struct StreamingState; * @return Zero on success, non-zero on failure. */ DEEPSPEECH_EXPORT -int DS_CreateModel(char* aModelPath, - int aNCep, - int aNContext, - char* aAlphabetConfigPath, - int aBeamWidth, +int DS_CreateModel(const char* aModelPath, + unsigned int aNCep, + unsigned int aNContext, + const char* aAlphabetConfigPath, + unsigned int aBeamWidth, ModelState** retval); /** @@ -53,14 +53,16 @@ void DS_DestroyModel(ModelState* ctx); * ring. * @param aValidWordCountWeight The weight (bonus) to give to beams when * adding a new valid word to the decoding. + * + * @return Zero on success, non-zero on failure (invalid arguments). */ DEEPSPEECH_EXPORT -void DS_EnableDecoderWithLM(ModelState* aCtx, - char* aAlphabetConfigPath, - char* aLMPath, - char* aTriePath, - float aLMWeight, - float aValidWordCountWeight); +int DS_EnableDecoderWithLM(ModelState* aCtx, + const char* aAlphabetConfigPath, + const char* aLMPath, + const char* aTriePath, + float aLMWeight, + float aValidWordCountWeight); /** * @brief Use the DeepSpeech model to perform Speech-To-Text. @@ -72,25 +74,26 @@ void DS_EnableDecoderWithLM(ModelState* aCtx, * @param aSampleRate The sample-rate of the audio signal. * * @return The STT result. The user is responsible for freeing the string. + * Returns NULL on error. */ DEEPSPEECH_EXPORT char* DS_SpeechToText(ModelState* aCtx, - short* aBuffer, + const short* aBuffer, unsigned int aBufferSize, - int aSampleRate); + unsigned int aSampleRate); /** - * @brief Setup a context used for performing streaming inference. - * the context pointer returned by this function can then be passed - * to {@link DS_FeedAudioContent()} and {@link DS_FinishStream()}. + * @brief Create a new streaming inference state. The streaming state returned + * by this function can then be passed to {@link DS_FeedAudioContent()} + * and {@link DS_FinishStream()}. * + * @param aCtx The ModelState pointer for the model to use. * @param aPreAllocFrames Number of timestep frames to reserve. One timestep - * is equivalent to two window lengths (50ms), so - * by default we reserve enough frames for 3 seconds - * of audio. + * is equivalent to two window lengths (20ms). If set to + * 0 we reserve enough frames for 3 seconds of audio (150). * @param aSampleRate The sample-rate of the audio signal. - * @param[out] retval a context pointer that represents the streaming state. Can - * be null if an error occurs. + * @param[out] retval an opaque pointer that represents the streaming state. 
Can + * be NULL if an error occurs. * * @return Zero for success, non-zero on failure. */ @@ -103,17 +106,27 @@ int DS_SetupStream(ModelState* aCtx, /** * @brief Feed audio samples to an ongoing streaming inference. * - * @param aCtx A streaming context pointer returned by {@link DS_SetupStream()}. + * @param aSctx A streaming state pointer returned by {@link DS_SetupStream()}. * @param aBuffer An array of 16-bit, mono raw audio samples at the * appropriate sample rate. * @param aBufferSize The number of samples in @p aBuffer. */ DEEPSPEECH_EXPORT void DS_FeedAudioContent(StreamingState* aSctx, - short* aBuffer, + const short* aBuffer, unsigned int aBufferSize); - +/** + * @brief Compute the intermediate decoding of an ongoing streaming inference. + * This is an expensive process as the decoder implementation is isn't + * currently capable of streaming, so it always starts from the beginning + * of the audio. + * + * @param aSctx A streaming state pointer returned by {@link DS_SetupStream()}. + * + * @return The STT intermediate result. The user is responsible for freeing the + * string. + */ DEEPSPEECH_EXPORT char* DS_IntermediateDecode(StreamingState* aSctx); @@ -121,11 +134,11 @@ char* DS_IntermediateDecode(StreamingState* aSctx); * @brief Signal the end of an audio signal to an ongoing streaming * inference, returns the STT result over the whole audio signal. * - * @param aSctx A streaming context pointer returned by {@link DS_SetupStream()}. + * @param aSctx A streaming state pointer returned by {@link DS_SetupStream()}. * * @return The STT result. The user is responsible for freeing the string. * - * @note This method will free the context pointer (@p aCtx). + * @note This method will free the state pointer (@p aSctx). */ DEEPSPEECH_EXPORT char* DS_FinishStream(StreamingState* aSctx); @@ -152,11 +165,11 @@ char* DS_FinishStream(StreamingState* aSctx); * (ncep * ncontext) in @p aMfcc. 
*/ DEEPSPEECH_EXPORT -void DS_AudioToInputVector(short* aBuffer, +void DS_AudioToInputVector(const short* aBuffer, unsigned int aBufferSize, - int aSampleRate, - int aNCep, - int aNContext, + unsigned int aSampleRate, + unsigned int aNCep, + unsigned int aNContext, float** aMfcc, int* aNFrames = NULL, int* aFrameLen = NULL); diff --git a/native_client/javascript/client.js b/native_client/javascript/client.js index 7d2ac0e94c..86e06c4d0d 100644 --- a/native_client/javascript/client.js +++ b/native_client/javascript/client.js @@ -41,7 +41,7 @@ parser.addArgument(['--version'], {help: 'Print version and exits'}) var args = parser.parseArgs(); if (args['version']) { - Ds.print_versions(); + Ds.printVersions(); return 0; } diff --git a/native_client/javascript/deepspeech.i b/native_client/javascript/deepspeech.i index b355b11d32..a71cf20c05 100644 --- a/native_client/javascript/deepspeech.i +++ b/native_client/javascript/deepspeech.i @@ -23,23 +23,23 @@ using namespace node; } // apply to DS_FeedAudioContent and DS_SpeechToText -%apply (short* IN_ARRAY1, int DIM1) {(short* aBuffer, unsigned int aBufferSize)}; +%apply (short* IN_ARRAY1, int DIM1) {(const short* aBuffer, unsigned int aBufferSize)}; // convert DS_AudioToInputVector return values to a Node Buffer %typemap(in,numinputs=0) - (float** ARGOUTVIEWM_ARRAY2, int* DIM1, int* DIM2) - (float* data_temp, int dim1_temp, int dim2_temp) + (float** ARGOUTVIEWM_ARRAY2, unsigned int* DIM1, unsigned int* DIM2) + (float* data_temp, unsigned int dim1_temp, unsigned int dim2_temp) { $1 = &data_temp; $2 = &dim1_temp; $3 = &dim2_temp; } %typemap(argout) - (float** ARGOUTVIEWM_ARRAY2, int* DIM1, int* DIM2) + (float** ARGOUTVIEWM_ARRAY2, unsigned int* DIM1, unsigned int* DIM2) { Handle array = Array::New(Isolate::GetCurrent(), *$2); - for (int i = 0, idx = 0; i < *$2; i++) { + for (unsigned int i = 0, idx = 0; i < *$2; i++) { Handle buffer = ArrayBuffer::New(Isolate::GetCurrent(), *$1, *$3 * sizeof(float)); memcpy(buffer->GetContents().Data(), @@ -51,7 +51,7 @@ using namespace node; $result = array; } -%apply (float** ARGOUTVIEWM_ARRAY2, int* DIM1, int* DIM2) {(float** aMfcc, int* aNFrames, int* aFrameLen)}; +%apply (float** ARGOUTVIEWM_ARRAY2, unsigned int* DIM1, unsigned int* DIM2) {(float** aMfcc, unsigned int* aNFrames, unsigned int* aFrameLen)}; // make sure the string returned by SpeechToText is freed %typemap(newfree) char* "free($1);"; diff --git a/native_client/javascript/index.js b/native_client/javascript/index.js index 855224ebd2..798d4f11d0 100644 --- a/native_client/javascript/index.js +++ b/native_client/javascript/index.js @@ -19,7 +19,7 @@ function Model() { Model.prototype.enableDecoderWithLM = function() { const args = [this._impl].concat(Array.prototype.slice.call(arguments)); - binding.EnableDecoderWithLM.apply(null, args); + return binding.EnableDecoderWithLM.apply(null, args); } Model.prototype.stt = function() { @@ -43,7 +43,7 @@ Model.prototype.feedAudioContent = function() { } Model.prototype.intermediateDecode = function() { - binding.IntermediateDecode.apply(null, arguments); + return binding.IntermediateDecode.apply(null, arguments); } Model.prototype.finishStream = function() { @@ -52,5 +52,6 @@ Model.prototype.finishStream = function() { module.exports = { Model: Model, - audioToInputVector: binding.AudioToInputVector + audioToInputVector: binding.AudioToInputVector, + printVersions: binding.PrintVersions }; diff --git a/native_client/python/__init__.py b/native_client/python/__init__.py index 417a3172da..59912d2c2e 100644 
--- a/native_client/python/__init__.py +++ b/native_client/python/__init__.py @@ -20,7 +20,7 @@ def __del__(self): self._impl = None def enableDecoderWithLM(self, *args, **kwargs): - deepspeech.impl.EnableDecoderWithLM(self._impl, *args, **kwargs) + return deepspeech.impl.EnableDecoderWithLM(self._impl, *args, **kwargs) def stt(self, *args, **kwargs): return deepspeech.impl.SpeechToText(self._impl, *args, **kwargs) @@ -37,7 +37,7 @@ def feedAudioContent(self, *args, **kwargs): deepspeech.impl.FeedAudioContent(*args, **kwargs) def intermediateDecode(self, *args, **kwargs): - deepspeech.impl.IntermediateDecode(*args, **kwargs) + return deepspeech.impl.IntermediateDecode(*args, **kwargs) def finishStream(self, *args, **kwargs): return deepspeech.impl.FinishStream(*args, **kwargs) diff --git a/native_client/python/impl.i b/native_client/python/impl.i index 80d8004efb..6564442f7e 100644 --- a/native_client/python/impl.i +++ b/native_client/python/impl.i @@ -11,7 +11,7 @@ import_array(); %} // apply NumPy conversion typemap to DS_FeedAudioContent and DS_SpeechToText -%apply (short* IN_ARRAY1, int DIM1) {(short* aBuffer, unsigned int aBufferSize)}; +%apply (short* IN_ARRAY1, int DIM1) {(const short* aBuffer, unsigned int aBufferSize)}; // apply NumPy conversion typemap to DS_AudioToInputVector %apply (float** ARGOUTVIEWM_ARRAY2, int* DIM1, int* DIM2) {(float** aMfcc, int* aNFrames, int* aFrameLen)}; diff --git a/tc-tests-utils.sh b/tc-tests-utils.sh index 42fa454bef..b0ef70f4be 100755 --- a/tc-tests-utils.sh +++ b/tc-tests-utils.sh @@ -256,7 +256,7 @@ run_prod_inference_tests() assert_correct_ldc93s1_prodmodel "${phrase_pbmodel_withlm}" phrase_pbmodel_withlm_stereo_44k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_2_44100.wav) - assert_correct_inference "${phrase_pbmodel_withlm_stereo_44k}" "she had to ducksoan greasy wash water all year" + assert_correct_inference "${phrase_pbmodel_withlm_stereo_44k}" "she had to ducksoan greasy wash water all earl" phrase_pbmodel_withlm_mono_8k=$(deepspeech --model ${TASKCLUSTER_TMP_DIR}/${model_name_mmap} --alphabet ${TASKCLUSTER_TMP_DIR}/alphabet.txt --lm ${TASKCLUSTER_TMP_DIR}/lm.binary --trie ${TASKCLUSTER_TMP_DIR}/trie --audio ${TASKCLUSTER_TMP_DIR}/LDC93S1_pcms16le_1_8000.wav 2>&1 1>/dev/null) assert_correct_warning_upsampling "${phrase_pbmodel_withlm_mono_8k}" diff --git a/util/audio.py b/util/audio.py index eb0ce3b9a0..f6f7cc2005 100644 --- a/util/audio.py +++ b/util/audio.py @@ -2,7 +2,6 @@ import scipy.io.wavfile as wav import sys -import math import warnings class DeepSpeechDeprecationWarning(DeprecationWarning):
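
The feature-extraction entry point also moved to const/unsigned types, and the Python and JavaScript typemaps above were updated to match. For reference, a sketch of calling DS_AudioToInputVector directly from C++; fill_audio() is a hypothetical capture helper, 26 cepstra and 9 context frames are example values matching the defaults used elsewhere in the tree, and the returned feature buffer is assumed to follow the same caller-frees convention as the STT strings.

#include <cstdio>
#include <cstdlib>

#include "deepspeech.h"

// Hypothetical helper: fills `buf` with up to `max_samples` 16-bit mono samples
// and returns how many were written.
unsigned int fill_audio(short* buf, unsigned int max_samples);

int main() {
  const unsigned int sample_rate = 16000;
  short audio[16000];  // room for one second of 16kHz mono audio
  unsigned int n_samples = fill_audio(audio, 16000);

  float* mfcc = nullptr;
  int n_frames = 0, frame_len = 0;
  DS_AudioToInputVector(audio, n_samples, sample_rate,
                        26 /* aNCep, example value */, 9 /* aNContext, example value */,
                        &mfcc, &n_frames, &frame_len);

  std::printf("%d frames, %d floats per frame\n", n_frames, frame_len);
  std::free(mfcc);  // caller owns the returned feature buffer
  return 0;
}
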