From df375d84afd5c205dd4d11a97aa3545c8ffed7e3 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Mon, 8 Nov 2021 17:21:18 +0800
Subject: [PATCH] Refactor bin/decode.cu (#869)

* Add CTC decode.

* Add HLG decoding.

* Add n-gram LM rescoring.

* Remove unused files.

* Fix style issues.

* Add missing files.
---
 .gitignore                       |   2 +
 cmake/sentencepiece.cmake        |   2 +-
 k2/torch/bin/CMakeLists.txt      |  44 ++++--
 k2/torch/bin/ctc_decode.cu       | 210 ++++++++++++++++++++++++++
 k2/torch/bin/hlg_decode.cu       | 219 +++++++++++++++++++++++++++
 k2/torch/bin/ngram_lm_rescore.cu | 245 +++++++++++++++++++++++++++++++
 k2/torch/csrc/CPPLINT.cfg        |   1 +
 k2/torch/csrc/utils.cu           |   1 +
 8 files changed, 713 insertions(+), 11 deletions(-)
 create mode 100644 k2/torch/bin/ctc_decode.cu
 create mode 100644 k2/torch/bin/hlg_decode.cu
 create mode 100644 k2/torch/bin/ngram_lm_rescore.cu

diff --git a/.gitignore b/.gitignore
index bfdf9eb31..39db3723a 100644
--- a/.gitignore
+++ b/.gitignore
@@ -585,3 +585,5 @@ Mkfile.old
 dkms.conf
 
 !.github/**
+!k2/torch/bin
+*-bak
diff --git a/cmake/sentencepiece.cmake b/cmake/sentencepiece.cmake
index 0ff608165..3055a3d28 100644
--- a/cmake/sentencepiece.cmake
+++ b/cmake/sentencepiece.cmake
@@ -44,7 +44,7 @@ function(download_sentencepiece)
   add_subdirectory(${sentencepiece_SOURCE_DIR} ${sentencepiece_BINARY_DIR}
                    EXCLUDE_FROM_ALL)
 
-  # we will link to the static library of sentencepiece
+  # Link to sentencepiece statically
   target_include_directories(sentencepiece-static INTERFACE
     ${sentencepiece_SOURCE_DIR}
 
diff --git a/k2/torch/bin/CMakeLists.txt b/k2/torch/bin/CMakeLists.txt
index 0d95edc12..531d119f3 100644
--- a/k2/torch/bin/CMakeLists.txt
+++ b/k2/torch/bin/CMakeLists.txt
@@ -1,17 +1,41 @@
 # it is located in k2/csrc/cmake/transform.cmake
 include(transform)
 
-set(decode_srcs decode.cu)
+set(bin_dep_libs
+  ${TORCH_LIBRARIES}
+  k2_torch
+  sentencepiece-static # see cmake/sentencepiece.cmake
+)
+
+#----------------------------------------
+# CTC decoding
+#----------------------------------------
+set(ctc_decode_srcs ctc_decode.cu)
 if(NOT K2_WITH_CUDA)
-  transform(OUTPUT_VARIABLE decode_srcs SRCS ${decode_srcs})
+  transform(OUTPUT_VARIABLE ctc_decode_srcs SRCS ${ctc_decode_srcs})
 endif()
+add_executable(ctc_decode ${ctc_decode_srcs})
+set_property(TARGET ctc_decode PROPERTY CXX_STANDARD 14)
+target_link_libraries(ctc_decode ${bin_dep_libs})
 
-add_executable(decode ${decode_srcs})
+#----------------------------------------
+# HLG decoding
+#----------------------------------------
+set(hlg_decode_srcs hlg_decode.cu)
+if(NOT K2_WITH_CUDA)
+  transform(OUTPUT_VARIABLE hlg_decode_srcs SRCS ${hlg_decode_srcs})
+endif()
+add_executable(hlg_decode ${hlg_decode_srcs})
+set_property(TARGET hlg_decode PROPERTY CXX_STANDARD 14)
+target_link_libraries(hlg_decode ${bin_dep_libs})
 
-set_property(TARGET decode PROPERTY CXX_STANDARD 14)
-target_link_libraries(decode
-  ${TORCH_LIBRARIES} # see cmake/torch.cmake
-  context
-  k2_torch
-  sentencepiece-static # see cmake/sentencepiece.cmake
-)
+#-------------------------------------------
+# HLG decoding + n-gram LM rescoring
+#-------------------------------------------
+set(ngram_lm_rescore_srcs ngram_lm_rescore.cu)
+if(NOT K2_WITH_CUDA)
+  transform(OUTPUT_VARIABLE ngram_lm_rescore_srcs SRCS ${ngram_lm_rescore_srcs})
+endif()
+add_executable(ngram_lm_rescore ${ngram_lm_rescore_srcs})
+set_property(TARGET ngram_lm_rescore PROPERTY CXX_STANDARD 14)
+target_link_libraries(ngram_lm_rescore ${bin_dep_libs})
diff --git a/k2/torch/bin/ctc_decode.cu b/k2/torch/bin/ctc_decode.cu
new file mode 100644
index 000000000..25b43b553
--- /dev/null
+++ b/k2/torch/bin/ctc_decode.cu
@@ -0,0 +1,210 @@
+/**
+ * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "k2/torch/csrc/decode.h"
+#include "k2/torch/csrc/dense_fsa_vec.h"
+#include "k2/torch/csrc/deserialization.h"
+#include "k2/torch/csrc/features.h"
+#include "k2/torch/csrc/fsa_algo.h"
+#include "k2/torch/csrc/wave_reader.h"
+#include "sentencepiece_processor.h"  // NOLINT
+#include "torch/all.h"
+#include "torch/script.h"
+
+static constexpr const char *kUsageMessage = R"(
+This file implements decoding with a CTC topology, without any
+kind of LM or lexicon.
+
+Usage:
+  ./bin/ctc_decode \
+    --use_gpu true \
+    --nn_model <path to torch scripted pt file> \
+    --bpe_model <path to pre-trained BPE model> \
+    <path to foo.wav> \
+    <path to bar.wav> \
+    <more waves if any>
+
+To see all possible options, use
+  ./bin/ctc_decode --help
+
+Caution:
+ - Only single-channel sound files (*.wav) are supported.
+ - It assumes the model is the one defined in conformer_ctc/transformer.py
+   from icefall. If you use a different model, you have to change the code
+   related to `model.forward` in this file.
+)";
+
+C10_DEFINE_bool(use_gpu, false, "true to use GPU; false to use CPU");
+C10_DEFINE_string(nn_model, "", "Path to the model exported by torch script.");
+C10_DEFINE_string(bpe_model, "", "Path to the pretrained BPE model.");
+
+// Fsa decoding related
+C10_DEFINE_double(search_beam, 20, "search_beam in IntersectDensePruned");
+C10_DEFINE_double(output_beam, 8, "output_beam in IntersectDensePruned");
+C10_DEFINE_int(min_activate_states, 30,
+               "min_activate_states in IntersectDensePruned");
+C10_DEFINE_int(max_activate_states, 10000,
+               "max_activate_states in IntersectDensePruned");
+// Fbank related
+// NOTE: These parameters must match those used in training
+C10_DEFINE_int(sample_rate, 16000, "Expected sample rate of wave files");
+C10_DEFINE_double(frame_shift_ms, 10.0,
+                  "Frame shift in ms for computing Fbank");
+C10_DEFINE_double(frame_length_ms, 25.0,
+                  "Frame length in ms for computing Fbank");
+C10_DEFINE_int(num_bins, 80, "Number of triangular bins for computing Fbank");
+
+static void CheckArgs() {
+#if !defined(K2_WITH_CUDA)
+  if (FLAGS_use_gpu) {
+    std::cerr << "k2 was not compiled with CUDA. "
" + "Please use --use_gpu false"; + exit(EXIT_FAILURE); + } +#endif + + if (FLAGS_nn_model.empty()) { + std::cerr << "Please provide --nn_model\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + if (FLAGS_bpe_model.empty()) { + std::cerr << "Please provide --bpe_model\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } +} + +int main(int argc, char *argv[]) { + // see + // https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html + torch::set_num_threads(1); + torch::set_num_interop_threads(1); + torch::NoGradGuard no_grad; + + torch::SetUsageMessage(kUsageMessage); + torch::ParseCommandLineFlags(&argc, &argv); + CheckArgs(); + + torch::Device device(torch::kCPU); + if (FLAGS_use_gpu) { + K2_LOG(INFO) << "Use GPU"; + device = torch::Device(torch::kCUDA, 0); + } + + K2_LOG(INFO) << "Device: " << device; + + int32_t num_waves = argc - 1; + K2_CHECK_GE(num_waves, 1) << "You have to provide at least one wave file"; + std::vector wave_filenames(num_waves); + for (int32_t i = 0; i != num_waves; ++i) { + wave_filenames[i] = argv[i + 1]; + } + + K2_LOG(INFO) << "Load wave files"; + auto wave_data = k2::ReadWave(wave_filenames, FLAGS_sample_rate); + + for (auto &w : wave_data) { + w = w.to(device); + } + + K2_LOG(INFO) << "Build Fbank computer"; + kaldifeat::FbankOptions fbank_opts; + fbank_opts.frame_opts.samp_freq = FLAGS_sample_rate; + fbank_opts.frame_opts.dither = 0; + fbank_opts.frame_opts.frame_shift_ms = FLAGS_frame_shift_ms; + fbank_opts.frame_opts.frame_length_ms = FLAGS_frame_length_ms; + fbank_opts.mel_opts.num_bins = FLAGS_num_bins; + fbank_opts.device = device; + + kaldifeat::Fbank fbank(fbank_opts); + + K2_LOG(INFO) << "Compute features"; + std::vector num_frames; + auto features_vec = k2::ComputeFeatures(fbank, wave_data, &num_frames); + + // Note: math.log(1e-10) is -23.025850929940457 + auto features = torch::nn::utils::rnn::pad_sequence(features_vec, true, + -23.025850929940457f); + + K2_LOG(INFO) << "Load neural network model"; + torch::jit::script::Module module = torch::jit::load(FLAGS_nn_model); + module.eval(); + module.to(device); + + int32_t subsampling_factor = module.attr("subsampling_factor").toInt(); + torch::Dict sup; + sup.insert("sequence_idx", torch::arange(num_waves, torch::kInt)); + sup.insert("start_frame", torch::zeros({num_waves}, torch::kInt)); + sup.insert("num_frames", + torch::from_blob(num_frames.data(), {num_waves}, torch::kLong) + .to(torch::kInt)); + + torch::IValue supervisions(sup); + + K2_LOG(INFO) << "Compute nnet_output"; + // the output for module.forward() is a tuple of 3 tensors + // See the definition of the model in conformer_ctc/transformer.py + // from icefall. + // If you use a model that has a different signature for `forward`, + // you can change the following line. 
+  auto outputs = module.run_method("forward", features, supervisions).toTuple();
+  assert(outputs->elements().size() == 3u);
+
+  auto nnet_output = outputs->elements()[0].toTensor();
+  auto memory = outputs->elements()[1].toTensor();
+
+  torch::Tensor supervision_segments =
+      k2::GetSupervisionSegments(supervisions, subsampling_factor);
+
+  K2_LOG(INFO) << "Build CTC topo";
+  auto decoding_graph = k2::CtcTopo(nnet_output.size(2) - 1, false, device);
+
+  K2_LOG(INFO) << "Decoding";
+  k2::FsaClass lattice = k2::GetLattice(
+      nnet_output, decoding_graph, supervision_segments, FLAGS_search_beam,
+      FLAGS_output_beam, FLAGS_min_activate_states, FLAGS_max_activate_states,
+      subsampling_factor);
+
+  lattice = k2::ShortestPath(lattice);
+
+  auto ragged_aux_labels = k2::GetTexts(lattice);
+  auto aux_labels_vec = ragged_aux_labels.ToVecVec();
+
+  sentencepiece::SentencePieceProcessor processor;
+  auto status = processor.Load(FLAGS_bpe_model);
+  K2_CHECK(status.ok()) << status.ToString();
+
+  std::vector<std::string> texts;
+  for (const auto &ids : aux_labels_vec) {
+    std::string text;
+    status = processor.Decode(ids, &text);
+    K2_CHECK(status.ok()) << status.ToString();
+    texts.emplace_back(std::move(text));
+  }
+
+  std::ostringstream os;
+  os << "\nDecoding result:\n\n";
+  for (int32_t i = 0; i != num_waves; ++i) {
+    os << wave_filenames[i] << "\n";
+    os << texts[i];
+    os << "\n\n";
+  }
+  K2_LOG(INFO) << os.str();
+
+  return 0;
+}
diff --git a/k2/torch/bin/hlg_decode.cu b/k2/torch/bin/hlg_decode.cu
new file mode 100644
index 000000000..7551881c3
--- /dev/null
+++ b/k2/torch/bin/hlg_decode.cu
@@ -0,0 +1,219 @@
+/**
+ * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "k2/torch/csrc/decode.h"
+#include "k2/torch/csrc/dense_fsa_vec.h"
+#include "k2/torch/csrc/deserialization.h"
+#include "k2/torch/csrc/features.h"
+#include "k2/torch/csrc/fsa_algo.h"
+#include "k2/torch/csrc/symbol_table.h"
+#include "k2/torch/csrc/wave_reader.h"
+#include "torch/all.h"
+#include "torch/script.h"
+
+static constexpr const char *kUsageMessage = R"(
+This file implements decoding with an HLG decoding graph.
+
+Usage:
+  ./bin/hlg_decode \
+    --use_gpu true \
+    --nn_model <path to torch scripted pt file> \
+    --hlg <path to HLG.pt> \
+    --word_table <path to words.txt> \
+    <path to foo.wav> \
+    <path to bar.wav> \
+    <more waves if any>
+
+To see all possible options, use
+  ./bin/hlg_decode --help
+
+Caution:
+ - Only single-channel sound files (*.wav) are supported.
+ - It assumes the model is the one defined in conformer_ctc/transformer.py
+   from icefall. If you use a different model, you have to change the code
+   related to `model.forward` in this file.
+)"; + +C10_DEFINE_bool(use_gpu, false, "true to use GPU; false to use CPU"); +C10_DEFINE_string(nn_model, "", "Path to the model exported by torch script."); +C10_DEFINE_string(hlg, "", "Path to HLG.pt."); +C10_DEFINE_string(word_table, "", "Path to words.txt."); + +// Fsa decoding related +C10_DEFINE_double(search_beam, 20, "search_beam in IntersectDensePruned"); +C10_DEFINE_double(output_beam, 8, "output_beam in IntersectDensePruned"); +C10_DEFINE_int(min_activate_states, 30, + "min_activate_states in IntersectDensePruned"); +C10_DEFINE_int(max_activate_states, 10000, + "max_activate_states in IntersectDensePruned"); +// Fbank related +// NOTE: These parameters must match those used in training +C10_DEFINE_int(sample_rate, 16000, "Expected sample rate of wave files"); +C10_DEFINE_double(frame_shift_ms, 10.0, + "Frame shift in ms for computing Fbank"); +C10_DEFINE_double(frame_length_ms, 25.0, + "Frame length in ms for computing Fbank"); +C10_DEFINE_int(num_bins, 80, "Number of triangular bins for computing Fbank"); + +static void CheckArgs() { +#if !defined(K2_WITH_CUDA) + if (FLAGS_use_gpu) { + std::cerr << "k2 was not compiled with CUDA. " + "Please use --use_gpu false"; + exit(EXIT_FAILURE); + } +#endif + + if (FLAGS_nn_model.empty()) { + std::cerr << "Please provide --nn_model\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + if (FLAGS_hlg.empty()) { + std::cerr << "Please provide --hlg\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + if (FLAGS_word_table.empty()) { + std::cerr << "Please provide --word_table\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } +} + +int main(int argc, char *argv[]) { + // see + // https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html + torch::set_num_threads(1); + torch::set_num_interop_threads(1); + torch::NoGradGuard no_grad; + + torch::SetUsageMessage(kUsageMessage); + torch::ParseCommandLineFlags(&argc, &argv); + CheckArgs(); + + torch::Device device(torch::kCPU); + if (FLAGS_use_gpu) { + K2_LOG(INFO) << "Use GPU"; + device = torch::Device(torch::kCUDA, 0); + } + + K2_LOG(INFO) << "Device: " << device; + + int32_t num_waves = argc - 1; + K2_CHECK_GE(num_waves, 1) << "You have to provide at least one wave file"; + std::vector wave_filenames(num_waves); + for (int32_t i = 0; i != num_waves; ++i) { + wave_filenames[i] = argv[i + 1]; + } + + K2_LOG(INFO) << "Load wave files"; + auto wave_data = k2::ReadWave(wave_filenames, FLAGS_sample_rate); + + for (auto &w : wave_data) { + w = w.to(device); + } + + K2_LOG(INFO) << "Build Fbank computer"; + kaldifeat::FbankOptions fbank_opts; + fbank_opts.frame_opts.samp_freq = FLAGS_sample_rate; + fbank_opts.frame_opts.dither = 0; + fbank_opts.frame_opts.frame_shift_ms = FLAGS_frame_shift_ms; + fbank_opts.frame_opts.frame_length_ms = FLAGS_frame_length_ms; + fbank_opts.mel_opts.num_bins = FLAGS_num_bins; + fbank_opts.device = device; + + kaldifeat::Fbank fbank(fbank_opts); + + K2_LOG(INFO) << "Compute features"; + std::vector num_frames; + auto features_vec = k2::ComputeFeatures(fbank, wave_data, &num_frames); + + // Note: math.log(1e-10) is -23.025850929940457 + auto features = torch::nn::utils::rnn::pad_sequence(features_vec, true, + -23.025850929940457f); + + K2_LOG(INFO) << "Load neural network model"; + torch::jit::script::Module module = torch::jit::load(FLAGS_nn_model); + module.eval(); + module.to(device); + + int32_t subsampling_factor = module.attr("subsampling_factor").toInt(); + torch::Dict sup; + sup.insert("sequence_idx", torch::arange(num_waves, 
+  sup.insert("start_frame", torch::zeros({num_waves}, torch::kInt));
+  sup.insert("num_frames",
+             torch::from_blob(num_frames.data(), {num_waves}, torch::kLong)
+                 .to(torch::kInt));
+
+  torch::IValue supervisions(sup);
+
+  K2_LOG(INFO) << "Compute nnet_output";
+  // The output of module.forward() is a tuple of 3 tensors.
+  // See the definition of the model in conformer_ctc/transformer.py
+  // from icefall.
+  // If you use a model with a different signature for `forward`,
+  // you have to change the following line accordingly.
+  auto outputs = module.run_method("forward", features, supervisions).toTuple();
+  assert(outputs->elements().size() == 3u);
+
+  auto nnet_output = outputs->elements()[0].toTensor();
+  auto memory = outputs->elements()[1].toTensor();
+
+  torch::Tensor supervision_segments =
+      k2::GetSupervisionSegments(supervisions, subsampling_factor);
+
+  K2_LOG(INFO) << "Load " << FLAGS_hlg;
+  k2::FsaClass decoding_graph = k2::LoadFsa(FLAGS_hlg, device);
+  K2_CHECK(decoding_graph.HasTensorAttr("aux_labels") ||
+           decoding_graph.HasRaggedTensorAttr("aux_labels"));
+
+  K2_LOG(INFO) << "Decoding";
+  k2::FsaClass lattice = k2::GetLattice(
+      nnet_output, decoding_graph, supervision_segments, FLAGS_search_beam,
+      FLAGS_output_beam, FLAGS_min_activate_states, FLAGS_max_activate_states,
+      subsampling_factor);
+
+  lattice = k2::ShortestPath(lattice);
+
+  auto ragged_aux_labels = k2::GetTexts(lattice);
+  auto aux_labels_vec = ragged_aux_labels.ToVecVec();
+
+  std::vector<std::string> texts;
+  k2::SymbolTable symbol_table(FLAGS_word_table);
+  for (const auto &ids : aux_labels_vec) {
+    std::string text;
+    std::string sep = "";
+    for (auto id : ids) {
+      text.append(sep);
+      text.append(symbol_table[id]);
+      sep = " ";
+    }
+    texts.emplace_back(std::move(text));
+  }
+
+  std::ostringstream os;
+  os << "\nDecoding result:\n\n";
+  for (int32_t i = 0; i != num_waves; ++i) {
+    os << wave_filenames[i] << "\n";
+    os << texts[i];
+    os << "\n\n";
+  }
+  K2_LOG(INFO) << os.str();
+
+  return 0;
+}
diff --git a/k2/torch/bin/ngram_lm_rescore.cu b/k2/torch/bin/ngram_lm_rescore.cu
new file mode 100644
index 000000000..3ea8c31d6
--- /dev/null
+++ b/k2/torch/bin/ngram_lm_rescore.cu
@@ -0,0 +1,245 @@
+/**
+ * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang)
+ *
+ * See LICENSE for clarification regarding multiple authors
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "k2/csrc/fsa_algo.h"
+#include "k2/torch/csrc/decode.h"
+#include "k2/torch/csrc/dense_fsa_vec.h"
+#include "k2/torch/csrc/deserialization.h"
+#include "k2/torch/csrc/features.h"
+#include "k2/torch/csrc/fsa_algo.h"
+#include "k2/torch/csrc/symbol_table.h"
+#include "k2/torch/csrc/wave_reader.h"
+#include "torch/all.h"
+#include "torch/script.h"
+
+static constexpr const char *kUsageMessage = R"(
+This file implements decoding with an HLG decoding graph, using
+an n-gram LM for rescoring.
+
+Usage:
+  ./bin/ngram_lm_rescore \
+    --use_gpu true \
+    --nn_model <path to torch scripted pt file> \
+    --hlg <path to HLG.pt> \
+    --g <path to G.pt> \
+    --ngram_lm_scale 1.0 \
+    --word_table <path to words.txt> \
+    <path to foo.wav> \
+    <path to bar.wav> \
+    <more waves if any>
+
+To see all possible options, use
+  ./bin/ngram_lm_rescore --help
+
+Caution:
+ - Only single-channel sound files (*.wav) are supported.
+ - It assumes the model is the one defined in conformer_ctc/transformer.py
+   from icefall. If you use a different model, you have to change the code
+   related to `model.forward` in this file.
+)";
+
+C10_DEFINE_bool(use_gpu, false, "true to use GPU; false to use CPU");
+C10_DEFINE_string(nn_model, "", "Path to the model exported by torch script.");
+C10_DEFINE_string(hlg, "", "Path to HLG.pt.");
+C10_DEFINE_string(g, "", "Path to an n-gram LM, e.g., G_4gram.pt");
+C10_DEFINE_double(ngram_lm_scale, 1.0, "Scale for n-gram LM scores");
+C10_DEFINE_string(word_table, "", "Path to words.txt.");
+
+// Fsa decoding related
+C10_DEFINE_double(search_beam, 20, "search_beam in IntersectDensePruned");
+C10_DEFINE_double(output_beam, 8, "output_beam in IntersectDensePruned");
+C10_DEFINE_int(min_activate_states, 30,
+               "min_activate_states in IntersectDensePruned");
+C10_DEFINE_int(max_activate_states, 10000,
+               "max_activate_states in IntersectDensePruned");
+// Fbank related
+// NOTE: These parameters must match those used in training
+C10_DEFINE_int(sample_rate, 16000, "Expected sample rate of wave files");
+C10_DEFINE_double(frame_shift_ms, 10.0,
+                  "Frame shift in ms for computing Fbank");
+C10_DEFINE_double(frame_length_ms, 25.0,
+                  "Frame length in ms for computing Fbank");
+C10_DEFINE_int(num_bins, 80, "Number of triangular bins for computing Fbank");
+
+static void CheckArgs() {
+#if !defined(K2_WITH_CUDA)
+  if (FLAGS_use_gpu) {
+    std::cerr << "k2 was not compiled with CUDA. "
+                 "Please use --use_gpu false";
+    exit(EXIT_FAILURE);
+  }
+#endif
+
+  if (FLAGS_nn_model.empty()) {
+    std::cerr << "Please provide --nn_model\n" << torch::UsageMessage();
+    exit(EXIT_FAILURE);
+  }
+
+  if (FLAGS_hlg.empty()) {
+    std::cerr << "Please provide --hlg\n" << torch::UsageMessage();
+    exit(EXIT_FAILURE);
+  }
+
+  if (FLAGS_g.empty()) {
+    std::cerr << "Please provide --g\n" << torch::UsageMessage();
+    exit(EXIT_FAILURE);
+  }
+
+  if (FLAGS_word_table.empty()) {
+    std::cerr << "Please provide --word_table\n" << torch::UsageMessage();
+    exit(EXIT_FAILURE);
+  }
+}
+
+int main(int argc, char *argv[]) {
+  // See
+  // https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html
+  torch::set_num_threads(1);
+  torch::set_num_interop_threads(1);
+  torch::NoGradGuard no_grad;
+
+  torch::SetUsageMessage(kUsageMessage);
+  torch::ParseCommandLineFlags(&argc, &argv);
+  CheckArgs();
+
+  torch::Device device(torch::kCPU);
+  if (FLAGS_use_gpu) {
+    K2_LOG(INFO) << "Use GPU";
+    device = torch::Device(torch::kCUDA, 0);
+  }
+
+  K2_LOG(INFO) << "Device: " << device;
+
+  int32_t num_waves = argc - 1;
+  K2_CHECK_GE(num_waves, 1) << "You have to provide at least one wave file";
+  std::vector<std::string> wave_filenames(num_waves);
+  for (int32_t i = 0; i != num_waves; ++i) {
+    wave_filenames[i] = argv[i + 1];
+  }
+
+  K2_LOG(INFO) << "Load wave files";
+  auto wave_data = k2::ReadWave(wave_filenames, FLAGS_sample_rate);
+
+  for (auto &w : wave_data) {
+    w = w.to(device);
+  }
+
+  K2_LOG(INFO) << "Build Fbank computer";
+  kaldifeat::FbankOptions fbank_opts;
+  fbank_opts.frame_opts.samp_freq = FLAGS_sample_rate;
+  fbank_opts.frame_opts.dither = 0;
+  fbank_opts.frame_opts.frame_shift_ms = FLAGS_frame_shift_ms;
+  fbank_opts.frame_opts.frame_length_ms = FLAGS_frame_length_ms;
+  fbank_opts.mel_opts.num_bins = FLAGS_num_bins;
+  fbank_opts.device = device;
+
+  kaldifeat::Fbank fbank(fbank_opts);
+
+  K2_LOG(INFO) << "Compute features";
+  std::vector<int64_t> num_frames;
+  auto features_vec = k2::ComputeFeatures(fbank, wave_data, &num_frames);
+
+  // Note: math.log(1e-10) is -23.025850929940457
+  auto features = torch::nn::utils::rnn::pad_sequence(features_vec, true,
+                                                      -23.025850929940457f);
+
+  K2_LOG(INFO) << "Load neural network model";
+  torch::jit::script::Module module = torch::jit::load(FLAGS_nn_model);
+  module.eval();
+  module.to(device);
+
+  int32_t subsampling_factor = module.attr("subsampling_factor").toInt();
+  torch::Dict<std::string, torch::Tensor> sup;
+  sup.insert("sequence_idx", torch::arange(num_waves, torch::kInt));
+  sup.insert("start_frame", torch::zeros({num_waves}, torch::kInt));
+  sup.insert("num_frames",
+             torch::from_blob(num_frames.data(), {num_waves}, torch::kLong)
+                 .to(torch::kInt));
+
+  torch::IValue supervisions(sup);
+
+  K2_LOG(INFO) << "Compute nnet_output";
+  // The output of module.forward() is a tuple of 3 tensors.
+  // See the definition of the model in conformer_ctc/transformer.py
+  // from icefall.
+  // If you use a model with a different signature for `forward`,
+  // you have to change the following line accordingly.
+  auto outputs = module.run_method("forward", features, supervisions).toTuple();
+  assert(outputs->elements().size() == 3u);
+
+  auto nnet_output = outputs->elements()[0].toTensor();
+  auto memory = outputs->elements()[1].toTensor();
+
+  torch::Tensor supervision_segments =
+      k2::GetSupervisionSegments(supervisions, subsampling_factor);
+
+  K2_LOG(INFO) << "Load " << FLAGS_hlg;
+  k2::FsaClass decoding_graph = k2::LoadFsa(FLAGS_hlg, device);
+  K2_CHECK(decoding_graph.HasTensorAttr("aux_labels") ||
+           decoding_graph.HasRaggedTensorAttr("aux_labels"));
+  // Add `lm_scores` so that we can separate acoustic scores and LM scores
+  // later in the rescoring stage.
+  decoding_graph.SetTensorAttr("lm_scores", decoding_graph.Scores().clone());
+
+  K2_LOG(INFO) << "Decoding";
+  k2::FsaClass lattice = k2::GetLattice(
+      nnet_output, decoding_graph, supervision_segments, FLAGS_search_beam,
+      FLAGS_output_beam, FLAGS_min_activate_states, FLAGS_max_activate_states,
+      subsampling_factor);
+
+  K2_LOG(INFO) << "Load n-gram LM: " << FLAGS_g;
+  k2::FsaClass G = k2::LoadFsa(FLAGS_g, device);
+  G.fsa = k2::FsaToFsaVec(G.fsa);
+
+  K2_CHECK_EQ(G.NumAttrs(), 0) << "G is expected to be an acceptor.";
+  k2::AddEpsilonSelfLoops(G.fsa, &G.fsa);
+  k2::ArcSort(&G.fsa);
+  G.SetTensorAttr("lm_scores", G.Scores().clone());
+
+  K2_LOG(INFO) << "Rescore with an n-gram LM";
+  k2::WholeLatticeRescoring(G, FLAGS_ngram_lm_scale, &lattice);
+
+  lattice = k2::ShortestPath(lattice);
+
+  auto ragged_aux_labels = k2::GetTexts(lattice);
+  auto aux_labels_vec = ragged_aux_labels.ToVecVec();
+
+  std::vector<std::string> texts;
+  k2::SymbolTable symbol_table(FLAGS_word_table);
+  for (const auto &ids : aux_labels_vec) {
+    std::string text;
+    std::string sep = "";
+    for (auto id : ids) {
+      text.append(sep);
+      text.append(symbol_table[id]);
+      sep = " ";
+    }
+    texts.emplace_back(std::move(text));
+  }
+
+  std::ostringstream os;
+  os << "\nDecoding result:\n\n";
+  for (int32_t i = 0; i != num_waves; ++i) {
+    os << wave_filenames[i] << "\n";
+    os << texts[i];
+    os << "\n\n";
+  }
+  K2_LOG(INFO) << os.str();
+
+  return 0;
+}
diff --git a/k2/torch/csrc/CPPLINT.cfg b/k2/torch/csrc/CPPLINT.cfg
index 360dcce12..ce4942ccf 100644
--- a/k2/torch/csrc/CPPLINT.cfg
+++ b/k2/torch/csrc/CPPLINT.cfg
@@ -1,2 +1,3 @@
 exclude_files=custom_class.h
 exclude_files=test_wave_data.h
+exclude_files=test_deserialization_data.h
diff --git a/k2/torch/csrc/utils.cu b/k2/torch/csrc/utils.cu
index 4eb0da8c0..324ae5f70 100644
--- a/k2/torch/csrc/utils.cu
+++ b/k2/torch/csrc/utils.cu
@@ -19,6 +19,7 @@
 #include
 #include
 #include
+#include
 
 #include "caffe2/serialize/file_adapter.h"
 #include "caffe2/serialize/inline_container.h"
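
A note on the `lm_scores` attribute used in ngram_lm_rescore.cu: this patch
does not include WholeLatticeRescoring itself, so the snippet below is only a
sketch of the idea suggested by the comments in that file, not the actual k2
implementation; the function name and signature are hypothetical. Because the
total arc scores on the lattice are the sum of acoustic and LM contributions,
tracking the LM part in a separate `lm_scores` tensor lets the rescoring stage
strip the old LM scores and substitute the scaled n-gram LM scores without
re-running the network.

    // Hypothetical sketch (not code from this patch): why `lm_scores` is
    // cloned from Scores() before intersection.
    #include "torch/all.h"

    torch::Tensor RescoreArcScores(
        const torch::Tensor &tot_scores,     // acoustic + old LM, per arc
        const torch::Tensor &old_lm_scores,  // tracked via "lm_scores"
        const torch::Tensor &new_lm_scores,  // from the n-gram G
        float ngram_lm_scale) {
      // Remove the old LM contribution to recover the acoustic part.
      torch::Tensor am_scores = tot_scores - old_lm_scores;
      // Add back the new LM contribution, scaled by --ngram_lm_scale.
      return am_scores + ngram_lm_scale * new_lm_scores;
    }

With this separation, --ngram_lm_scale trades the acoustic model off against
the n-gram LM at decode time.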