Fix race condition in writing lattices out. (#36)
* Fix race condition in writing lattices out.

This could become a bottleneck, since the lock serializes the writes, but the
current behavior is incorrect, so let's fix correctness first.

* Update cibuildwheel config.

Update version.

Warning: No tests are running right now, because unfortunately no Python
version currently builds correctly at the same time as nemo.
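
As an illustration of the pattern this commit applies inside the callbacks passed to DecodeWithCallback, here is a minimal, self-contained C++ sketch. Everything in it (FakeLatticeWriter, the utterance keys, the integer "lattices", the explicit threads) is a stand-in and not the riva-asrlib code; the point is only that a writer which is not thread safe must have every Write call guarded by one shared mutex when results arrive from concurrent callbacks.

#include <iostream>
#include <mutex>
#include <string>
#include <thread>
#include <vector>

// Stand-in for kaldi::CompactLatticeWriter: assume Write() must never be
// called from two threads at once.
struct FakeLatticeWriter {
  void Write(const std::string& key, int lattice_id) {
    std::cout << key << " -> lattice " << lattice_id << "\n";
  }
};

int main() {
  FakeLatticeWriter clat_writer;
  std::mutex write_mutex;  // protects clat_writer, as in the commit
  std::vector<std::string> keys = {"utt0", "utt1", "utt2", "utt3"};
  std::vector<std::thread> workers;

  for (int i = 0; i < static_cast<int>(keys.size()); ++i) {
    // Each "decode callback" runs on its own thread; the lock serializes the
    // writes so the non-thread-safe writer is never entered concurrently.
    workers.emplace_back([i, &clat_writer, &keys, &write_mutex]() {
      int lattice_id = i * 10;  // pretend this is the decoded lattice
      std::lock_guard<std::mutex> lock(write_mutex);
      clat_writer.Write(keys[i], lattice_id);
    });
  }
  for (auto& t : workers) t.join();
  return 0;
}

The lock covers only the Write call, so decoding itself still runs concurrently; only the output step is serialized.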
galv committed Jan 30, 2024
1 parent 8282368 commit c94c84e
Showing 3 changed files with 15 additions and 6 deletions.
8 changes: 6 additions & 2 deletions pyproject.toml
@@ -10,7 +10,7 @@ requires = [
build-backend = "setuptools.build_meta"

[tool.cibuildwheel]
build = "{cp38-manylinux_x86_64,cp39-manylinux_x86_64,cp310-manylinux_x86_64,cp311-manylinux_x86_64}"
build = "{cp38-manylinux_x86_64,cp39-manylinux_x86_64,cp310-manylinux_x86_64,cp311-manylinux_x86_64,cp312-manylinux_x86_64}"
# Need to set the pythonLocation environment variable manually because CMake picks up /usr/bin/python3.6 no matter what otherwise
# https://cmake.org/cmake/help/latest/prop_tgt/CUDA_ARCHITECTURES.html#prop_tgt:CUDA_ARCHITECTURES
environment = 'RIVA_ASRLIB_DECODER_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=Release" CUDAARCHS="52-real;60-real;61-real;70-real;75-real;80-real;86" pythonLocation="$(readlink -f $(which python))/.."'
@@ -21,7 +21,7 @@ test-command = "pytest -m ci {project}/src/riva/asrlib/decoder/"
test-requires = ["pytest", "kaldi-io", "more-itertools", "nemo_toolkit[asr]", "torchaudio"]
# 3.11 fails because of numba: Cannot install on Python version 3.11.1; only versions >=3.7,<3.11 are supported
# 3.10 fails because of some kind of sqlite issue (see below)
test-skip = "{cp310-manylinux_x86_64,cp311-manylinux_x86_64}"
# 3.8 fails because of TypeError: unsupported operand type(s) for |: 'type' and '_GenericAlias in nemo/collections/asr/data/huggingface/hf_audio_to_text.py:55
# 3.9 fails because of TypeError: unsupported operand type(s) for |: 'type' and '_GenericAlias in nemo/collections/asr/data/huggingface/hf_audio_to_text.py:55
# NeMo supports only python3.10 and up now.
# 3.12 is not supported by numba right now: https://github.com/numba/numba/issues/9197#issuecomment-1865218765. numba is a transitive dependency of nemo
test-skip = "{cp38-manylinux_x86_64,cp39-manylinux_x86_64,cp310-manylinux_x86_64,cp311-manylinux_x86_64,cp312-manylinux_x86_64}"

# _____ ERROR collecting src/riva/asrlib/decoder/test_graph_construction.py ______
# ImportError while importing test module '/project/src/riva/asrlib/decoder/test_graph_construction.py'.
2 changes: 1 addition & 1 deletion setup.py
@@ -96,7 +96,7 @@ def build_extension(self, ext: setuptools.extension.Extension):
setuptools.setup(
python_requires='>=3.7',
name='riva-asrlib-decoder',
-version='0.4.3',
+version='0.4.4',
author='NVIDIA',
author_email='dgalvez@nvidia.com',
keywords='ASR, CUDA, WFST, Decoder',
11 changes: 8 additions & 3 deletions src/riva/asrlib/decoder/python_decoder.cc
@@ -289,8 +289,10 @@ NanobindBatchedMappedDecoderCuda(nb::module_& m)
"decode_write_lattice",
[](PyClass& cuda_pipeline, LogitsArray& logits, LogitsLengthsArray& logits_lengths,
const std::vector<std::string>& keys, const std::string& output_wspecifier) {
-int64_t batch_size = logits_lengths.shape(0);

+int64_t batch_size = logits_lengths.shape(0);
+// protects clat_writer, which is not thread safe
+std::mutex write_mutex;
kaldi::CompactLatticeWriter clat_writer(output_wspecifier);
for (int64_t i = 0; i < batch_size; ++i) {
int64_t valid_time_steps = logits_lengths(i);
@@ -301,9 +303,12 @@ NanobindBatchedMappedDecoderCuda(nb::module_& m)
// stride of each row is stride. Always greater than number of cols
auto write_results =
[i, &clat_writer,
-&keys](riva::asrlib::BatchedMappedOnlineDecoderCuda::ReturnType& asr_results) {
+&keys, &write_mutex](riva::asrlib::BatchedMappedOnlineDecoderCuda::ReturnType& asr_results) {
const kaldi::CompactLattice& lattice = std::get<0>(asr_results).value();
-clat_writer.Write(keys[i], lattice);
+{
+std::lock_guard<std::mutex> lock(write_mutex);
+clat_writer.Write(keys[i], lattice);
+}
};
cuda_pipeline.DecodeWithCallback(
single_sample_logits_start, logits.stride(1), valid_time_steps, write_results);
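
The commit message acknowledges that this lock could become a bottleneck, because every lattice write is now serialized across callbacks. Purely as a hedged sketch of one way that cost could later be avoided (this is not what the commit does, and it uses none of the riva-asrlib API; all names here are stand-ins), each worker could deposit its result into its own pre-sized slot, and a single thread could do all the writing afterwards:

#include <iostream>
#include <string>
#include <thread>
#include <vector>

struct FakeLatticeWriter {  // stand-in for the non-thread-safe writer
  void Write(const std::string& key, int lattice_id) {
    std::cout << key << " -> lattice " << lattice_id << "\n";
  }
};

int main() {
  std::vector<std::string> keys = {"utt0", "utt1", "utt2", "utt3"};
  std::vector<int> results(keys.size());  // one slot per utterance, no sharing
  std::vector<std::thread> workers;

  for (int i = 0; i < static_cast<int>(keys.size()); ++i) {
    // Each worker only touches results[i], a distinct element, so no mutex is
    // needed while decoding.
    workers.emplace_back([i, &results]() { results[i] = i * 10; });
  }
  for (auto& t : workers) t.join();

  // All writes happen on one thread afterwards, so the writer's lack of
  // thread safety no longer matters and nothing is serialized during decoding.
  FakeLatticeWriter clat_writer;
  for (int i = 0; i < static_cast<int>(keys.size()); ++i) {
    clat_writer.Write(keys[i], results[i]);
  }
  return 0;
}

The trade-off is memory: every result has to be buffered until decoding finishes, whereas the committed fix writes each lattice as soon as it is ready.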
