-
Notifications
You must be signed in to change notification settings - Fork 211
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Add CTC decode. * Add HLG decoding. * Add n-gram LM rescoring. * Remove unused files. * Fix style issues. * Add missing files.
- Loading branch information
1 parent
c9ca90c
commit df375d8
Showing
8 changed files
with
713 additions
and
11 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -585,3 +585,5 @@ Mkfile.old | |
dkms.conf | ||
|
||
!.github/** | ||
!k2/torch/bin | ||
*-bak |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,17 +1,41 @@ | ||
# it is located in k2/csrc/cmake/transform.cmake | ||
include(transform) | ||
|
||
set(decode_srcs decode.cu) | ||
set(bin_dep_libs | ||
${TORCH_LIBRARIES} | ||
k2_torch | ||
sentencepiece-static # see cmake/sentencepiece.cmake | ||
) | ||
|
||
#---------------------------------------- | ||
# CTC decoding | ||
#---------------------------------------- | ||
set(ctc_decode_srcs ctc_decode.cu) | ||
if(NOT K2_WITH_CUDA) | ||
transform(OUTPUT_VARIABLE decode_srcs SRCS ${decode_srcs}) | ||
transform(OUTPUT_VARIABLE ctc_decode_srcs SRCS ${ctc_decode_srcs}) | ||
endif() | ||
add_executable(ctc_decode ${ctc_decode_srcs}) | ||
set_property(TARGET ctc_decode PROPERTY CXX_STANDARD 14) | ||
target_link_libraries(ctc_decode ${bin_dep_libs}) | ||
|
||
add_executable(decode ${decode_srcs}) | ||
#---------------------------------------- | ||
# HLG decoding | ||
#---------------------------------------- | ||
set(hlg_decode_srcs hlg_decode.cu) | ||
if(NOT K2_WITH_CUDA) | ||
transform(OUTPUT_VARIABLE hlg_decode_srcs SRCS ${hlg_decode_srcs}) | ||
endif() | ||
add_executable(hlg_decode ${hlg_decode_srcs}) | ||
set_property(TARGET hlg_decode PROPERTY CXX_STANDARD 14) | ||
target_link_libraries(hlg_decode ${bin_dep_libs}) | ||
|
||
set_property(TARGET decode PROPERTY CXX_STANDARD 14) | ||
target_link_libraries(decode | ||
${TORCH_LIBRARIES} # see cmake/torch.cmake | ||
context | ||
k2_torch | ||
sentencepiece-static # see cmake/sentencepiece.cmake | ||
) | ||
#------------------------------------------- | ||
# HLG decoding + n-gram LM rescoring | ||
#------------------------------------------- | ||
set(ngram_lm_rescore_srcs ngram_lm_rescore.cu) | ||
if(NOT K2_WITH_CUDA) | ||
transform(OUTPUT_VARIABLE ngram_lm_rescore_srcs SRCS ${ngram_lm_rescore_srcs}) | ||
endif() | ||
add_executable(ngram_lm_rescore ${ngram_lm_rescore_srcs}) | ||
set_property(TARGET ngram_lm_rescore PROPERTY CXX_STANDARD 14) | ||
target_link_libraries(ngram_lm_rescore ${bin_dep_libs}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,210 @@ | ||
/** | ||
* Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) | ||
* | ||
* See LICENSE for clarification regarding multiple authors | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "k2/torch/csrc/decode.h" | ||
#include "k2/torch/csrc/dense_fsa_vec.h" | ||
#include "k2/torch/csrc/deserialization.h" | ||
#include "k2/torch/csrc/features.h" | ||
#include "k2/torch/csrc/fsa_algo.h" | ||
#include "k2/torch/csrc/wave_reader.h" | ||
#include "sentencepiece_processor.h" // NOLINT | ||
#include "torch/all.h" | ||
#include "torch/script.h" | ||
|
||
static constexpr const char *kUsageMessage = R"( | ||
This file implements decoding with a CTC topology, without any | ||
kinds of LM or lexicons. | ||
Usage: | ||
./bin/ctc_decode \ | ||
--use_gpu true \ | ||
--nn_model <path to torch scripted pt file> \ | ||
--bpe_model <path to pre-trained BPE model> \ | ||
<path to foo.wav> \ | ||
<path to bar.wav> \ | ||
<more waves if any> | ||
To see all possible options, use | ||
./bin/ctc_decode --help | ||
Caution: | ||
- Only sound files (*.wav) with single channel are supported. | ||
- It assumes the model is conformer_ctc/transformer.py from icefall. | ||
If you use a different model, you have to change the code | ||
related to `model.forward` in this file. | ||
)"; | ||
|
||
C10_DEFINE_bool(use_gpu, false, "true to use GPU; false to use CPU"); | ||
C10_DEFINE_string(nn_model, "", "Path to the model exported by torch script."); | ||
C10_DEFINE_string(bpe_model, "", "Path to the pretrained BPE model."); | ||
|
||
// Fsa decoding related | ||
C10_DEFINE_double(search_beam, 20, "search_beam in IntersectDensePruned"); | ||
C10_DEFINE_double(output_beam, 8, "output_beam in IntersectDensePruned"); | ||
C10_DEFINE_int(min_activate_states, 30, | ||
"min_activate_states in IntersectDensePruned"); | ||
C10_DEFINE_int(max_activate_states, 10000, | ||
"max_activate_states in IntersectDensePruned"); | ||
// Fbank related | ||
// NOTE: These parameters must match those used in training | ||
C10_DEFINE_int(sample_rate, 16000, "Expected sample rate of wave files"); | ||
C10_DEFINE_double(frame_shift_ms, 10.0, | ||
"Frame shift in ms for computing Fbank"); | ||
C10_DEFINE_double(frame_length_ms, 25.0, | ||
"Frame length in ms for computing Fbank"); | ||
C10_DEFINE_int(num_bins, 80, "Number of triangular bins for computing Fbank"); | ||
|
||
static void CheckArgs() { | ||
#if !defined(K2_WITH_CUDA) | ||
if (FLAGS_use_gpu) { | ||
std::cerr << "k2 was not compiled with CUDA. " | ||
"Please use --use_gpu false"; | ||
exit(EXIT_FAILURE); | ||
} | ||
#endif | ||
|
||
if (FLAGS_nn_model.empty()) { | ||
std::cerr << "Please provide --nn_model\n" << torch::UsageMessage(); | ||
exit(EXIT_FAILURE); | ||
} | ||
|
||
if (FLAGS_bpe_model.empty()) { | ||
std::cerr << "Please provide --bpe_model\n" << torch::UsageMessage(); | ||
exit(EXIT_FAILURE); | ||
} | ||
} | ||
|
||
int main(int argc, char *argv[]) { | ||
// see | ||
// https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html | ||
torch::set_num_threads(1); | ||
torch::set_num_interop_threads(1); | ||
torch::NoGradGuard no_grad; | ||
|
||
torch::SetUsageMessage(kUsageMessage); | ||
torch::ParseCommandLineFlags(&argc, &argv); | ||
CheckArgs(); | ||
|
||
torch::Device device(torch::kCPU); | ||
if (FLAGS_use_gpu) { | ||
K2_LOG(INFO) << "Use GPU"; | ||
device = torch::Device(torch::kCUDA, 0); | ||
} | ||
|
||
K2_LOG(INFO) << "Device: " << device; | ||
|
||
int32_t num_waves = argc - 1; | ||
K2_CHECK_GE(num_waves, 1) << "You have to provide at least one wave file"; | ||
std::vector<std::string> wave_filenames(num_waves); | ||
for (int32_t i = 0; i != num_waves; ++i) { | ||
wave_filenames[i] = argv[i + 1]; | ||
} | ||
|
||
K2_LOG(INFO) << "Load wave files"; | ||
auto wave_data = k2::ReadWave(wave_filenames, FLAGS_sample_rate); | ||
|
||
for (auto &w : wave_data) { | ||
w = w.to(device); | ||
} | ||
|
||
K2_LOG(INFO) << "Build Fbank computer"; | ||
kaldifeat::FbankOptions fbank_opts; | ||
fbank_opts.frame_opts.samp_freq = FLAGS_sample_rate; | ||
fbank_opts.frame_opts.dither = 0; | ||
fbank_opts.frame_opts.frame_shift_ms = FLAGS_frame_shift_ms; | ||
fbank_opts.frame_opts.frame_length_ms = FLAGS_frame_length_ms; | ||
fbank_opts.mel_opts.num_bins = FLAGS_num_bins; | ||
fbank_opts.device = device; | ||
|
||
kaldifeat::Fbank fbank(fbank_opts); | ||
|
||
K2_LOG(INFO) << "Compute features"; | ||
std::vector<int64_t> num_frames; | ||
auto features_vec = k2::ComputeFeatures(fbank, wave_data, &num_frames); | ||
|
||
// Note: math.log(1e-10) is -23.025850929940457 | ||
auto features = torch::nn::utils::rnn::pad_sequence(features_vec, true, | ||
-23.025850929940457f); | ||
|
||
K2_LOG(INFO) << "Load neural network model"; | ||
torch::jit::script::Module module = torch::jit::load(FLAGS_nn_model); | ||
module.eval(); | ||
module.to(device); | ||
|
||
int32_t subsampling_factor = module.attr("subsampling_factor").toInt(); | ||
torch::Dict<std::string, torch::Tensor> sup; | ||
sup.insert("sequence_idx", torch::arange(num_waves, torch::kInt)); | ||
sup.insert("start_frame", torch::zeros({num_waves}, torch::kInt)); | ||
sup.insert("num_frames", | ||
torch::from_blob(num_frames.data(), {num_waves}, torch::kLong) | ||
.to(torch::kInt)); | ||
|
||
torch::IValue supervisions(sup); | ||
|
||
K2_LOG(INFO) << "Compute nnet_output"; | ||
// the output for module.forward() is a tuple of 3 tensors | ||
// See the definition of the model in conformer_ctc/transformer.py | ||
// from icefall. | ||
// If you use a model that has a different signature for `forward`, | ||
// you can change the following line. | ||
auto outputs = module.run_method("forward", features, supervisions).toTuple(); | ||
assert(outputs->elements().size() == 3u); | ||
|
||
auto nnet_output = outputs->elements()[0].toTensor(); | ||
auto memory = outputs->elements()[1].toTensor(); | ||
|
||
torch::Tensor supervision_segments = | ||
k2::GetSupervisionSegments(supervisions, subsampling_factor); | ||
|
||
K2_LOG(INFO) << "Build CTC topo"; | ||
auto decoding_graph = k2::CtcTopo(nnet_output.size(2) - 1, false, device); | ||
|
||
K2_LOG(INFO) << "Decoding"; | ||
k2::FsaClass lattice = k2::GetLattice( | ||
nnet_output, decoding_graph, supervision_segments, FLAGS_search_beam, | ||
FLAGS_output_beam, FLAGS_min_activate_states, FLAGS_max_activate_states, | ||
subsampling_factor); | ||
|
||
lattice = k2::ShortestPath(lattice); | ||
|
||
auto ragged_aux_labels = k2::GetTexts(lattice); | ||
auto aux_labels_vec = ragged_aux_labels.ToVecVec(); | ||
|
||
sentencepiece::SentencePieceProcessor processor; | ||
auto status = processor.Load(FLAGS_bpe_model); | ||
K2_CHECK(status.ok()) << status.ToString(); | ||
|
||
std::vector<std::string> texts; | ||
for (const auto &ids : aux_labels_vec) { | ||
std::string text; | ||
status = processor.Decode(ids, &text); | ||
K2_CHECK(status.ok()) << status.ToString(); | ||
texts.emplace_back(std::move(text)); | ||
} | ||
|
||
std::ostringstream os; | ||
os << "\nDecoding result:\n\n"; | ||
for (int32_t i = 0; i != num_waves; ++i) { | ||
os << wave_filenames[i] << "\n"; | ||
os << texts[i]; | ||
os << "\n\n"; | ||
} | ||
K2_LOG(INFO) << os.str(); | ||
|
||
return 0; | ||
} |
Oops, something went wrong.