From 9b0763930ea9d773c14462992157b75878d0f187 Mon Sep 17 00:00:00 2001 From: Fangjun Kuang Date: Fri, 4 Nov 2022 11:03:17 +0800 Subject: [PATCH] Merge the branch v2.0-pre to master (#1085) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * [WIP]: Move k2.Fsa to C++ (#814) * Make k2 ragged tensor more PyTorch-y like. * Refactoring: Start to add the wrapper class AnyTensor. * Refactoring. * initial attempt to support autograd. * First working version with autograd for Sum(). * Fix comments. * Support __getitem__ and pickling. * Add more docs for k2.ragged.Tensor * Put documentation in header files. * Minor fixes. * Fix a typo. * Fix an error. * Add more doc. * Wrap RaggedShape. * [Not for Merge]: Move k2.Fsa related code to C++. * Remove extra files. * Update doc URL. (#821) * Support manipulating attributes of k2.ragged.Fsa. * Support indexing 2-axes RaggedTensor, Support slicing for RaggedTensor (#825) * Support index 2-axes RaggedTensor, Support slicing for RaggedTensor * Fix compiling errors * Fix unit test * Change RaggedTensor.data to RaggedTensor.values * Fix style * Add docs * Run nightly-cpu when pushing code to nightly-cpu branch * Prune with max_arcs in IntersectDense (#820) * Add checking for array constructor * Prune with max arcs * Minor fix * Fix typo * Fix review comments * Fix typo * Release v1.8 * Create a ragged tensor from a regular tensor. (#827) * Create a ragged tensor from a regular tensor. * Add tests for creating ragged tensors from regular tensors. * Add more tests. * Print ragged tensors in a way like what PyTorch is doing. * Fix test cases. * Trigger GitHub actions manually. (#829) * Run GitHub actions on merging. (#830) * Support printing ragged tensors in a more compact way. (#831) * Support printing ragged tensors in a more compact way. * Disable support for torch 1.3.1 * Fix test failures. * Add levenshtein alignment (#828) * Add levenshtein graph * Contruct k2.RaggedTensor in python part * Fix review comments, return aux_labels in ctc_graph * Fix tests * Fix bug of accessing symbols * Fix bug of accessing symbols * Change argument name, add levenshtein_distance interface * Fix test error, add tests for levenshtein_distance * Fix review comments and add unit test for c++ side * update the interface of levenshtein alignment * Fix review comments * Release v1.9 * Add Fsa.get_forward_scores. * Implement backprop for Fsa.get_forward_scores() * Construct RaggedArc from unary function tensor (#30) * Construct RaggedArc from unary function tensor * Move fsa_from_unary_ragged and fsa_from_binary_tensor to C++ * add unit test to from unary function; add more functions to fsa * Remove some rabbish code * Add more unit tests and docs * Remove the unused code * Fix review comments, propagate attributes in To() * Change the argument type from RaggedAny to Ragged in autograd function * Delete declaration for template function * Apply suggestions from code review Co-authored-by: Fangjun Kuang * Fix documentation errors Co-authored-by: Fangjun Kuang Co-authored-by: Wei Kang * Remove pybind dependencies from RaggedArc. (#842) * Convert py::object and torch::IValue to each other * Remove py::object from RaggedAny * Remove py::object from RaggedArc * Move files to torch directory * remove unused files * Add unit tests * Remove v2 folder * Remove unused code * Remove unused files * Fix review comments & fix github actions * Check Ivalue contains RaggedAny * Minor fixes * Add attributes related unit test for FsaClass * Fix mutable_grad in older pytorch version * Fix github actions * Fix github action PYTHONPATH * Fix github action PYTHONPATH * Link pybind11::embed * import torch first (to fix macos github actions) * try to fix macos ci * Revert "Remove pybind dependencies from RaggedArc. (#842)" (#855) This reverts commit daa98e7504669a57495bb003f884051a9d7792be. * Support torchscript. (#839) * WIP: Support torchscript. * Test jit module with faked data. I have compared the output from C++ with that from Python. The sums of the tensors are equal. * Use precomputed features to test the correctness. * Build DenseFsaVec from a torch tensor. * Get lattice for CTC decoding. * Support CTC decoding. * Link sentencepiece statically. Link sentencepiece dynamically causes segmentation fault at the end of the process. * Support loading HLG.pt * Refactoring. * Implement HLG decoding. * Add WaveReader to read wave sound files. * Take soundfiles as inputs. * Refactoring. * Support GPU. * Minor fixes. * Fix typos. * Use kaldifeat v1.7 * Add copyright info. * Fix compilation for torch >= 1.9.0 * Minor fixes. * Fix comments. * Fix style issues. * Fix compiler warnings. * Use `torch::class_` to register custom classes. (#856) * Remove unused code (#857) * Update doc URL. (#821) * Support indexing 2-axes RaggedTensor, Support slicing for RaggedTensor (#825) * Support index 2-axes RaggedTensor, Support slicing for RaggedTensor * Fix compiling errors * Fix unit test * Change RaggedTensor.data to RaggedTensor.values * Fix style * Add docs * Run nightly-cpu when pushing code to nightly-cpu branch * Prune with max_arcs in IntersectDense (#820) * Add checking for array constructor * Prune with max arcs * Minor fix * Fix typo * Fix review comments * Fix typo * Release v1.8 * Create a ragged tensor from a regular tensor. (#827) * Create a ragged tensor from a regular tensor. * Add tests for creating ragged tensors from regular tensors. * Add more tests. * Print ragged tensors in a way like what PyTorch is doing. * Fix test cases. * Trigger GitHub actions manually. (#829) * Run GitHub actions on merging. (#830) * Support printing ragged tensors in a more compact way. (#831) * Support printing ragged tensors in a more compact way. * Disable support for torch 1.3.1 * Fix test failures. * Add levenshtein alignment (#828) * Add levenshtein graph * Contruct k2.RaggedTensor in python part * Fix review comments, return aux_labels in ctc_graph * Fix tests * Fix bug of accessing symbols * Fix bug of accessing symbols * Change argument name, add levenshtein_distance interface * Fix test error, add tests for levenshtein_distance * Fix review comments and add unit test for c++ side * update the interface of levenshtein alignment * Fix review comments * Release v1.9 * Support a[b[i]] where both a and b are ragged tensors. (#833) * Display import error solution message on MacOS (#837) * Fix installation doc. (#841) * Fix installation doc. Remove Windows support. Will fix it later. * Fix style issues. * fix typos in the install instructions (#844) * make cmake adhere to the modernized way of finding packages outside default dirs (#845) * import torch first in the smoke tests to preven SEGFAULT (#846) * Add doc about how to install a CPU version of k2. (#850) * Add doc about how to install a CPU version of k2. * Remove property setter of Fsa.labels * Update Ubuntu version in GitHub CI since 16.04 reaches end-of-life. * Support PyTorch 1.10. (#851) * Fix test cases for k2.union() (#853) * Revert "Construct RaggedArc from unary function tensor (#30)" (#31) This reverts commit cca7a540334fae0424f238e879c6b731bbf97ba0. * Remove unused code. * Fix github actions. Avoid downloading all git LFS files. * Enable github actions for v2.0-pre branch. Co-authored-by: Wei Kang Co-authored-by: Piotr Żelasko Co-authored-by: Jan "yenda" Trmal * Implements Cpp version FsaClass (#858) * Add C++ version FsaClass * Propagates attributes for CreateFsaVec * Add more docs * Remove the code that unnecessary needed currently * Remove the code unnecessary for ctc decoding & HLG decoding * Update k2/torch/csrc/deserialization.h Co-authored-by: Fangjun Kuang * Fix Comments * Fix code style Co-authored-by: Fangjun Kuang * Using FsaClass for ctc decoding & HLG decoding (#862) * Using FsaClass for ctc decoding & HLG decoding * Update docs * fix evaluating kFsaPropertiesValid (#866) * Refactor deserialization code (#863) * Fix compiler warnings about the usage of `tmpnam`. * Refactor deserialization code. * Minor fixes. * Support rescoring with an n-gram LM during decoding (#867) * Fix compiler warnings about the usage of `tmpnam`. * Refactor deserialization code. * Minor fixes. * Add n-gram LM rescoring. * Minor fixes. * Clear cached FSA properties when its labels are changed. * Fix typos. * Refactor FsaClass. (#868) Since FSAs in decoding contain only one or two attributes, we don't need to use an IValue to add one more indirection. Just check the type of the attribute and process it correspondingly. * Refactor bin/decode.cu (#869) * Add CTC decode. * Add HLG decoding. * Add n-gram LM rescoring. * Remove unused files. * Fix style issues. * Add missing files. * Add attention rescoring. (#870) * WIP: Add attention rescoring. * Finish attention rescoring. * Fix style issues. * Resolve comments. (#871) * Resolve comments. * Minor fixes. * update v2.0-pre (#922) * Update doc URL. (#821) * Support indexing 2-axes RaggedTensor, Support slicing for RaggedTensor (#825) * Support index 2-axes RaggedTensor, Support slicing for RaggedTensor * Fix compiling errors * Fix unit test * Change RaggedTensor.data to RaggedTensor.values * Fix style * Add docs * Run nightly-cpu when pushing code to nightly-cpu branch * Prune with max_arcs in IntersectDense (#820) * Add checking for array constructor * Prune with max arcs * Minor fix * Fix typo * Fix review comments * Fix typo * Release v1.8 * Create a ragged tensor from a regular tensor. (#827) * Create a ragged tensor from a regular tensor. * Add tests for creating ragged tensors from regular tensors. * Add more tests. * Print ragged tensors in a way like what PyTorch is doing. * Fix test cases. * Trigger GitHub actions manually. (#829) * Run GitHub actions on merging. (#830) * Support printing ragged tensors in a more compact way. (#831) * Support printing ragged tensors in a more compact way. * Disable support for torch 1.3.1 * Fix test failures. * Add levenshtein alignment (#828) * Add levenshtein graph * Contruct k2.RaggedTensor in python part * Fix review comments, return aux_labels in ctc_graph * Fix tests * Fix bug of accessing symbols * Fix bug of accessing symbols * Change argument name, add levenshtein_distance interface * Fix test error, add tests for levenshtein_distance * Fix review comments and add unit test for c++ side * update the interface of levenshtein alignment * Fix review comments * Release v1.9 * Support a[b[i]] where both a and b are ragged tensors. (#833) * Display import error solution message on MacOS (#837) * Fix installation doc. (#841) * Fix installation doc. Remove Windows support. Will fix it later. * Fix style issues. * fix typos in the install instructions (#844) * make cmake adhere to the modernized way of finding packages outside default dirs (#845) * import torch first in the smoke tests to preven SEGFAULT (#846) * Add doc about how to install a CPU version of k2. (#850) * Add doc about how to install a CPU version of k2. * Remove property setter of Fsa.labels * Update Ubuntu version in GitHub CI since 16.04 reaches end-of-life. * Support PyTorch 1.10. (#851) * Fix test cases for k2.union() (#853) * Fix out-of-boundary access (read). (#859) * Update all the example codes in the docs (#861) * Update all the example codes in the docs I have run all the modified codes with the newest version k2. * do some changes * Fix compilation errors with CUB 1.15. (#865) * Update README. (#873) * Update README. * Fix typos. * Fix ctc graph (make aux_labels of final arcs -1) (#877) * Fix LICENSE location to k2 folder (#880) * Release v1.11. (#881) It contains bugfixes. * Update documentation for hash.h (#887) * Update documentation for hash.h * Typo fix * Wrap MonotonicLowerBound (#883) * Wrap MonotonicLowerBound * Add unit tests * Support int64; update documents * Remove extra commas after 'TOPSORTED' properity and fix RaggedTensor constructer parameter 'byte_offset' out-of-range bug. (#892) Co-authored-by: gzchenduisheng * Fix small typos (#896) * Fix k2.ragged.create_ragged_shape2 (#901) Before the fix, we have to specify both `row_splits` and `row_ids` while calling `k2.create_ragged_shape2` even if one of them is `None`. After this fix, we only need to specify one of them. * Add rnnt loss (#891) * Add cpp code of mutual information * mutual information working * Add rnnt loss * Add pruned rnnt loss * Minor Fixes * Minor fixes & fix code style * Fix cpp style * Fix code style * Fix s_begin values in padding positions * Fix bugs related to boundary; Fix s_begin padding value; Add more tests * Minor fixes * Fix comments * Add boundary to pruned loss tests * Use more efficient way to fix boundaries (#906) * Release v1.12 (#907) * Change the sign of the rnnt_loss and add reduction argument (#911) * Add right boundary constrains for s_begin * Minor fixes to the interface of rnnt_loss to make it return positive value * Fix comments * Release a new version * Minor fixes * Minor fixes to the docs * Fix building doc. (#908) * Fix building doc. * Minor fixes. * Minor fixes. * Fix building doc (#912) * Fix building doc * Fix flake8 * Support torch 1.10.x (#914) * Support torch 1.10.x * Fix installing PyTorch. * Update INSTALL.rst (#915) * Update INSTALL.rst Setting a few additional env variables to enable compilation from source *with CUDA GPU computation support enabled* * Fix torch/cuda/python versions in the doc. (#918) * Fix torch/cuda/python versions in the doc. * Minor fixes. * Fix building for CUDA 11.6 (#917) * Fix building for CUDA 11.6 * Minor fixes. * Implement Unstack (#920) * Implement unstack * Remove code does not relate to this PR * Remove for loop on output dim; add Unstack ragged * Add more docs * Fix comments * Fix docs & unit tests * SubsetRagged & PruneRagged (#919) * Extend interface of SubsampleRagged. * Add interface for pruning ragged tensor. * Draft of new RNN-T decoding method * Implements SubsampleRaggedShape * Implements PruneRagged * Rename subsample-> subset * Minor fixes * Fix comments Co-authored-by: Daniel Povey Co-authored-by: Fangjun Kuang Co-authored-by: Piotr Żelasko Co-authored-by: Jan "yenda" Trmal Co-authored-by: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Co-authored-by: Ludwig Kürzinger Co-authored-by: Daniel Povey Co-authored-by: drawfish Co-authored-by: gzchenduisheng Co-authored-by: alexei-v-ivanov * Online decoding (#876) * Add OnlineIntersectDensePruned * Fix get partial results * Support online decoding on intersect_dense_pruned * Update documents * Update v2.0-pre (#942) * Update doc URL. (#821) * Support indexing 2-axes RaggedTensor, Support slicing for RaggedTensor (#825) * Support index 2-axes RaggedTensor, Support slicing for RaggedTensor * Fix compiling errors * Fix unit test * Change RaggedTensor.data to RaggedTensor.values * Fix style * Add docs * Run nightly-cpu when pushing code to nightly-cpu branch * Prune with max_arcs in IntersectDense (#820) * Add checking for array constructor * Prune with max arcs * Minor fix * Fix typo * Fix review comments * Fix typo * Release v1.8 * Create a ragged tensor from a regular tensor. (#827) * Create a ragged tensor from a regular tensor. * Add tests for creating ragged tensors from regular tensors. * Add more tests. * Print ragged tensors in a way like what PyTorch is doing. * Fix test cases. * Trigger GitHub actions manually. (#829) * Run GitHub actions on merging. (#830) * Support printing ragged tensors in a more compact way. (#831) * Support printing ragged tensors in a more compact way. * Disable support for torch 1.3.1 * Fix test failures. * Add levenshtein alignment (#828) * Add levenshtein graph * Contruct k2.RaggedTensor in python part * Fix review comments, return aux_labels in ctc_graph * Fix tests * Fix bug of accessing symbols * Fix bug of accessing symbols * Change argument name, add levenshtein_distance interface * Fix test error, add tests for levenshtein_distance * Fix review comments and add unit test for c++ side * update the interface of levenshtein alignment * Fix review comments * Release v1.9 * Support a[b[i]] where both a and b are ragged tensors. (#833) * Display import error solution message on MacOS (#837) * Fix installation doc. (#841) * Fix installation doc. Remove Windows support. Will fix it later. * Fix style issues. * fix typos in the install instructions (#844) * make cmake adhere to the modernized way of finding packages outside default dirs (#845) * import torch first in the smoke tests to preven SEGFAULT (#846) * Add doc about how to install a CPU version of k2. (#850) * Add doc about how to install a CPU version of k2. * Remove property setter of Fsa.labels * Update Ubuntu version in GitHub CI since 16.04 reaches end-of-life. * Support PyTorch 1.10. (#851) * Fix test cases for k2.union() (#853) * Fix out-of-boundary access (read). (#859) * Update all the example codes in the docs (#861) * Update all the example codes in the docs I have run all the modified codes with the newest version k2. * do some changes * Fix compilation errors with CUB 1.15. (#865) * Update README. (#873) * Update README. * Fix typos. * Fix ctc graph (make aux_labels of final arcs -1) (#877) * Fix LICENSE location to k2 folder (#880) * Release v1.11. (#881) It contains bugfixes. * Update documentation for hash.h (#887) * Update documentation for hash.h * Typo fix * Wrap MonotonicLowerBound (#883) * Wrap MonotonicLowerBound * Add unit tests * Support int64; update documents * Remove extra commas after 'TOPSORTED' properity and fix RaggedTensor constructer parameter 'byte_offset' out-of-range bug. (#892) Co-authored-by: gzchenduisheng * Fix small typos (#896) * Fix k2.ragged.create_ragged_shape2 (#901) Before the fix, we have to specify both `row_splits` and `row_ids` while calling `k2.create_ragged_shape2` even if one of them is `None`. After this fix, we only need to specify one of them. * Add rnnt loss (#891) * Add cpp code of mutual information * mutual information working * Add rnnt loss * Add pruned rnnt loss * Minor Fixes * Minor fixes & fix code style * Fix cpp style * Fix code style * Fix s_begin values in padding positions * Fix bugs related to boundary; Fix s_begin padding value; Add more tests * Minor fixes * Fix comments * Add boundary to pruned loss tests * Use more efficient way to fix boundaries (#906) * Release v1.12 (#907) * Change the sign of the rnnt_loss and add reduction argument (#911) * Add right boundary constrains for s_begin * Minor fixes to the interface of rnnt_loss to make it return positive value * Fix comments * Release a new version * Minor fixes * Minor fixes to the docs * Fix building doc. (#908) * Fix building doc. * Minor fixes. * Minor fixes. * Fix building doc (#912) * Fix building doc * Fix flake8 * Support torch 1.10.x (#914) * Support torch 1.10.x * Fix installing PyTorch. * Update INSTALL.rst (#915) * Update INSTALL.rst Setting a few additional env variables to enable compilation from source *with CUDA GPU computation support enabled* * Fix torch/cuda/python versions in the doc. (#918) * Fix torch/cuda/python versions in the doc. * Minor fixes. * Fix building for CUDA 11.6 (#917) * Fix building for CUDA 11.6 * Minor fixes. * Implement Unstack (#920) * Implement unstack * Remove code does not relate to this PR * Remove for loop on output dim; add Unstack ragged * Add more docs * Fix comments * Fix docs & unit tests * SubsetRagged & PruneRagged (#919) * Extend interface of SubsampleRagged. * Add interface for pruning ragged tensor. * Draft of new RNN-T decoding method * Implements SubsampleRaggedShape * Implements PruneRagged * Rename subsample-> subset * Minor fixes * Fix comments Co-authored-by: Daniel Povey * Add Hash64 (#895) * Add hash64 * Fix tests * Resize hash64 * Fix comments * fix typo * Modified rnnt (#902) * Add modified mutual_information_recursion * Add modified rnnt loss * Using more efficient way to fix boundaries * Fix modified pruned rnnt loss * Fix the s_begin constrains of pruned loss for modified version transducer * Fix Stack (#925) * return the correct layer * unskip the test * Fix 'TypeError' of rnnt_loss_pruned function. (#924) * Fix 'TypeError' of rnnt_loss_simple function. Fix 'TypeError' exception when calling rnnt_loss_simple(..., return_grad=False) at validation steps. * Fix 'MutualInformationRecursionFunction.forward()' return type check error for pytorch < 1.10.x * Modify return type. * Add documents about class MutualInformationRecursionFunction. * Formated code style. * Fix rnnt_loss_smoothed return type. Co-authored-by: gzchenduisheng * Support torch 1.11.0 and CUDA 11.5 (#931) * Support torch 1.11.0 and CUDA 11.5 * Implement Rnnt decoding (#926) * first working draft of rnnt decoding * FormatOutput works... * Different num frames for FormatOutput works * Update docs * Fix comments, break advance into several stages, add more docs * Add python wrapper * Add more docs * Minor fixes * Fix comments * fix building docs (#933) * Release v1.14 * Remove unused DiscountedCumSum. (#936) * Fix compiler warnings. (#937) * Fix compiler warnings. * Minor fixes for RNN-T decoding. (#938) * Minor fixes for RNN-T decoding. * Removes arcs with label 0 from the TrivialGraph. (#939) * Implement linear_fsa_with_self_loops. (#940) * Implement linear_fsa_with_self_loops. * Fix the pruning with max-states (#941) Co-authored-by: Fangjun Kuang Co-authored-by: Piotr Żelasko Co-authored-by: Jan "yenda" Trmal Co-authored-by: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Co-authored-by: Ludwig Kürzinger Co-authored-by: Daniel Povey Co-authored-by: drawfish Co-authored-by: gzchenduisheng Co-authored-by: alexei-v-ivanov Co-authored-by: Wang, Guanbo * update v2.0-pre (#953) * Update doc URL. (#821) * Support indexing 2-axes RaggedTensor, Support slicing for RaggedTensor (#825) * Support index 2-axes RaggedTensor, Support slicing for RaggedTensor * Fix compiling errors * Fix unit test * Change RaggedTensor.data to RaggedTensor.values * Fix style * Add docs * Run nightly-cpu when pushing code to nightly-cpu branch * Prune with max_arcs in IntersectDense (#820) * Add checking for array constructor * Prune with max arcs * Minor fix * Fix typo * Fix review comments * Fix typo * Release v1.8 * Create a ragged tensor from a regular tensor. (#827) * Create a ragged tensor from a regular tensor. * Add tests for creating ragged tensors from regular tensors. * Add more tests. * Print ragged tensors in a way like what PyTorch is doing. * Fix test cases. * Trigger GitHub actions manually. (#829) * Run GitHub actions on merging. (#830) * Support printing ragged tensors in a more compact way. (#831) * Support printing ragged tensors in a more compact way. * Disable support for torch 1.3.1 * Fix test failures. * Add levenshtein alignment (#828) * Add levenshtein graph * Contruct k2.RaggedTensor in python part * Fix review comments, return aux_labels in ctc_graph * Fix tests * Fix bug of accessing symbols * Fix bug of accessing symbols * Change argument name, add levenshtein_distance interface * Fix test error, add tests for levenshtein_distance * Fix review comments and add unit test for c++ side * update the interface of levenshtein alignment * Fix review comments * Release v1.9 * Support a[b[i]] where both a and b are ragged tensors. (#833) * Display import error solution message on MacOS (#837) * Fix installation doc. (#841) * Fix installation doc. Remove Windows support. Will fix it later. * Fix style issues. * fix typos in the install instructions (#844) * make cmake adhere to the modernized way of finding packages outside default dirs (#845) * import torch first in the smoke tests to preven SEGFAULT (#846) * Add doc about how to install a CPU version of k2. (#850) * Add doc about how to install a CPU version of k2. * Remove property setter of Fsa.labels * Update Ubuntu version in GitHub CI since 16.04 reaches end-of-life. * Support PyTorch 1.10. (#851) * Fix test cases for k2.union() (#853) * Fix out-of-boundary access (read). (#859) * Update all the example codes in the docs (#861) * Update all the example codes in the docs I have run all the modified codes with the newest version k2. * do some changes * Fix compilation errors with CUB 1.15. (#865) * Update README. (#873) * Update README. * Fix typos. * Fix ctc graph (make aux_labels of final arcs -1) (#877) * Fix LICENSE location to k2 folder (#880) * Release v1.11. (#881) It contains bugfixes. * Update documentation for hash.h (#887) * Update documentation for hash.h * Typo fix * Wrap MonotonicLowerBound (#883) * Wrap MonotonicLowerBound * Add unit tests * Support int64; update documents * Remove extra commas after 'TOPSORTED' properity and fix RaggedTensor constructer parameter 'byte_offset' out-of-range bug. (#892) Co-authored-by: gzchenduisheng * Fix small typos (#896) * Fix k2.ragged.create_ragged_shape2 (#901) Before the fix, we have to specify both `row_splits` and `row_ids` while calling `k2.create_ragged_shape2` even if one of them is `None`. After this fix, we only need to specify one of them. * Add rnnt loss (#891) * Add cpp code of mutual information * mutual information working * Add rnnt loss * Add pruned rnnt loss * Minor Fixes * Minor fixes & fix code style * Fix cpp style * Fix code style * Fix s_begin values in padding positions * Fix bugs related to boundary; Fix s_begin padding value; Add more tests * Minor fixes * Fix comments * Add boundary to pruned loss tests * Use more efficient way to fix boundaries (#906) * Release v1.12 (#907) * Change the sign of the rnnt_loss and add reduction argument (#911) * Add right boundary constrains for s_begin * Minor fixes to the interface of rnnt_loss to make it return positive value * Fix comments * Release a new version * Minor fixes * Minor fixes to the docs * Fix building doc. (#908) * Fix building doc. * Minor fixes. * Minor fixes. * Fix building doc (#912) * Fix building doc * Fix flake8 * Support torch 1.10.x (#914) * Support torch 1.10.x * Fix installing PyTorch. * Update INSTALL.rst (#915) * Update INSTALL.rst Setting a few additional env variables to enable compilation from source *with CUDA GPU computation support enabled* * Fix torch/cuda/python versions in the doc. (#918) * Fix torch/cuda/python versions in the doc. * Minor fixes. * Fix building for CUDA 11.6 (#917) * Fix building for CUDA 11.6 * Minor fixes. * Implement Unstack (#920) * Implement unstack * Remove code does not relate to this PR * Remove for loop on output dim; add Unstack ragged * Add more docs * Fix comments * Fix docs & unit tests * SubsetRagged & PruneRagged (#919) * Extend interface of SubsampleRagged. * Add interface for pruning ragged tensor. * Draft of new RNN-T decoding method * Implements SubsampleRaggedShape * Implements PruneRagged * Rename subsample-> subset * Minor fixes * Fix comments Co-authored-by: Daniel Povey * Add Hash64 (#895) * Add hash64 * Fix tests * Resize hash64 * Fix comments * fix typo * Modified rnnt (#902) * Add modified mutual_information_recursion * Add modified rnnt loss * Using more efficient way to fix boundaries * Fix modified pruned rnnt loss * Fix the s_begin constrains of pruned loss for modified version transducer * Fix Stack (#925) * return the correct layer * unskip the test * Fix 'TypeError' of rnnt_loss_pruned function. (#924) * Fix 'TypeError' of rnnt_loss_simple function. Fix 'TypeError' exception when calling rnnt_loss_simple(..., return_grad=False) at validation steps. * Fix 'MutualInformationRecursionFunction.forward()' return type check error for pytorch < 1.10.x * Modify return type. * Add documents about class MutualInformationRecursionFunction. * Formated code style. * Fix rnnt_loss_smoothed return type. Co-authored-by: gzchenduisheng * Support torch 1.11.0 and CUDA 11.5 (#931) * Support torch 1.11.0 and CUDA 11.5 * Implement Rnnt decoding (#926) * first working draft of rnnt decoding * FormatOutput works... * Different num frames for FormatOutput works * Update docs * Fix comments, break advance into several stages, add more docs * Add python wrapper * Add more docs * Minor fixes * Fix comments * fix building docs (#933) * Release v1.14 * Remove unused DiscountedCumSum. (#936) * Fix compiler warnings. (#937) * Fix compiler warnings. * Minor fixes for RNN-T decoding. (#938) * Minor fixes for RNN-T decoding. * Removes arcs with label 0 from the TrivialGraph. (#939) * Implement linear_fsa_with_self_loops. (#940) * Implement linear_fsa_with_self_loops. * Fix the pruning with max-states (#941) * Rnnt allow different encoder/decoder dims (#945) * Allow different encoder and decoder dim in rnnt_pruning * Bug fixes * Supporting building k2 on Windows (#946) * Fix nightly windows CPU build (#948) * Fix nightly building k2 for windows. * Run nightly build only if there are new commits. * Check the versions of PyTorch and CUDA at the import time. (#949) * Check the versions of PyTorch and CUDA at the import time. * More straightforward message when CUDA support is missing (#950) * Implement ArrayOfRagged (#927) * Implement ArrayOfRagged * Fix issues and pass tests * fix style * change few statements of functions and move the definiation of template Array1OfRagged to header file * add offsets test code * Fix precision (#951) * Fix precision * Using different pow version for windows and *nix * Use int64_t pow * Minor fixes Co-authored-by: Fangjun Kuang Co-authored-by: Piotr Żelasko Co-authored-by: Jan "yenda" Trmal Co-authored-by: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Co-authored-by: Ludwig Kürzinger Co-authored-by: Daniel Povey Co-authored-by: drawfish Co-authored-by: gzchenduisheng Co-authored-by: alexei-v-ivanov Co-authored-by: Wang, Guanbo Co-authored-by: Nickolay V. Shmyrev Co-authored-by: LvHang * Add C++ Rnnt demo (#947) * rnnt_demo compiles * Change graph in RnntDecodingStream from shared_ptr to const reference * Change out_map from Array1 to Ragged * Add rnnt demo * Minor fixes * Add more docs * Support log_add when getting best path * Port kaldi::ParseOptions for parsing commandline options. (#974) * Port kaldi::ParseOptions for parsing commandline options. * Add more tests. * More tests. * Greedy search and modified beam search for pruned stateless RNN-T. (#975) * First version of greedy search. * WIP: Implement modified beam search and greedy search for pruned RNN-T. * Implement modified beam search. * Fix compiler warnings * Fix style issues * Update torch_api.h to include APIs for CTC decoding Co-authored-by: Wei Kang Co-authored-by: Piotr Żelasko Co-authored-by: Jan "yenda" Trmal Co-authored-by: pingfengluo Co-authored-by: Mingshuang Luo <37799481+luomingshuang@users.noreply.github.com> Co-authored-by: Ludwig Kürzinger Co-authored-by: Daniel Povey Co-authored-by: drawfish Co-authored-by: gzchenduisheng Co-authored-by: alexei-v-ivanov Co-authored-by: Wang, Guanbo Co-authored-by: Nickolay V. Shmyrev Co-authored-by: LvHang --- .flake8 | 2 + .github/scripts/run-test.sh | 144 ++++ .github/workflows/build-cpu-macos.yml | 34 +- .github/workflows/build-cpu-ubuntu.yml | 35 +- .github/workflows/build-cpu-windows.yml | 34 +- .github/workflows/build-cuda-ubuntu.yml | 3 + .github/workflows/run-tests-cpu.yml | 12 +- .github/workflows/run-tests.yml | 3 + .gitignore | 2 + CMakeLists.txt | 4 +- cmake/googletest.cmake | 2 +- cmake/kaldifeat.cmake | 52 ++ cmake/moderngpu.cmake | 2 +- cmake/pybind11.cmake | 2 +- cmake/torch.cmake | 3 + .../images/torch_ge_1.6.0-green.svg | 1 + docs/source/installation/pip_pypi.rst | 4 +- k2/CMakeLists.txt | 13 + k2/csrc/fake_cuda.h | 4 +- k2/csrc/fsa_algo.h | 4 +- k2/csrc/fsa_utils.cu | 12 +- k2/csrc/intersect_dense_pruned.cu | 374 ++++++--- k2/csrc/log.cu | 23 + k2/csrc/math.h | 3 +- k2/csrc/online_dense_intersector.h | 91 ++ k2/csrc/pytorch_context.cu | 3 +- k2/csrc/ragged.h | 22 + k2/csrc/ragged_ops.cu | 2 +- k2/csrc/ragged_ops.h | 2 +- k2/csrc/rm_epsilon.cu | 4 +- k2/csrc/torch_api.h | 60 ++ k2/python/csrc/CMakeLists.txt | 9 + k2/python/csrc/torch/v2/k2.cu | 2 +- k2/python/k2/rnnt_loss.py | 1 - k2/torch/CMakeLists.txt | 3 + k2/torch/README.md | 10 + k2/torch/bin/CMakeLists.txt | 89 ++ k2/torch/bin/attention_rescore.cu | 372 +++++++++ k2/torch/bin/ctc_decode.cu | 208 +++++ k2/torch/bin/hlg_decode.cu | 219 +++++ k2/torch/bin/ngram_lm_rescore.cu | 245 ++++++ k2/torch/bin/online_decode.cu | 322 +++++++ k2/torch/bin/pruned_stateless_transducer.cu | 193 +++++ k2/torch/bin/rnnt_demo.cu | 355 ++++++++ k2/torch/csrc/CMakeLists.txt | 67 ++ k2/torch/csrc/CPPLINT.cfg | 3 + k2/torch/csrc/beam_search.cu | 393 +++++++++ k2/torch/csrc/beam_search.h | 56 ++ k2/torch/csrc/decode.cu | 207 +++++ k2/torch/csrc/decode.h | 118 +++ k2/torch/csrc/dense_fsa_vec.cu | 173 ++++ k2/torch/csrc/dense_fsa_vec.h | 86 ++ k2/torch/csrc/dense_fsa_vec_test.cu | 161 ++++ k2/torch/csrc/deserialization.cu | 455 ++++++++++ k2/torch/csrc/deserialization.h | 68 ++ k2/torch/csrc/deserialization_test.cu | 230 +++++ k2/torch/csrc/features.cc | 56 ++ k2/torch/csrc/features.h | 49 ++ k2/torch/csrc/fsa_algo.cu | 297 +++++++ k2/torch/csrc/fsa_algo.h | 181 ++++ k2/torch/csrc/fsa_algo_inl.h | 46 + k2/torch/csrc/fsa_class.cu | 207 +++++ k2/torch/csrc/fsa_class.h | 268 ++++++ k2/torch/csrc/fsa_class_test.cu | 130 +++ k2/torch/csrc/hypothesis.cu | 56 ++ k2/torch/csrc/hypothesis.h | 112 +++ k2/torch/csrc/hypothesis_test.cu | 43 + k2/torch/csrc/nbest.cu | 116 +++ k2/torch/csrc/nbest.h | 98 +++ k2/torch/csrc/parse_options.cu | 783 ++++++++++++++++++ k2/torch/csrc/parse_options.h | 268 ++++++ k2/torch/csrc/parse_options_test.cu | 302 +++++++ k2/torch/csrc/symbol_table.cu | 78 ++ k2/torch/csrc/symbol_table.h | 61 ++ k2/torch/csrc/test_deserialization_data.h | 451 ++++++++++ k2/torch/csrc/test_wave_data.h | 23 + k2/torch/csrc/utils.cu | 162 ++++ k2/torch/csrc/utils.h | 231 ++++++ k2/torch/csrc/wave_reader.cu | 140 ++++ k2/torch/csrc/wave_reader.h | 81 ++ k2/torch/csrc/wave_reader_test.cu | 37 + scripts/github_actions/fix_torch.py | 46 + scripts/github_actions/fix_torch.sh | 25 + .../k2-torch-api-test/cmake/k2.cmake | 18 +- setup.py | 194 +++-- 85 files changed, 9302 insertions(+), 258 deletions(-) create mode 100755 .github/scripts/run-test.sh create mode 100644 cmake/kaldifeat.cmake create mode 100644 docs/source/installation/images/torch_ge_1.6.0-green.svg create mode 100644 k2/csrc/online_dense_intersector.h create mode 100644 k2/torch/CMakeLists.txt create mode 100644 k2/torch/README.md create mode 100644 k2/torch/bin/CMakeLists.txt create mode 100644 k2/torch/bin/attention_rescore.cu create mode 100644 k2/torch/bin/ctc_decode.cu create mode 100644 k2/torch/bin/hlg_decode.cu create mode 100644 k2/torch/bin/ngram_lm_rescore.cu create mode 100644 k2/torch/bin/online_decode.cu create mode 100644 k2/torch/bin/pruned_stateless_transducer.cu create mode 100644 k2/torch/bin/rnnt_demo.cu create mode 100644 k2/torch/csrc/CMakeLists.txt create mode 100644 k2/torch/csrc/CPPLINT.cfg create mode 100644 k2/torch/csrc/beam_search.cu create mode 100644 k2/torch/csrc/beam_search.h create mode 100644 k2/torch/csrc/decode.cu create mode 100644 k2/torch/csrc/decode.h create mode 100644 k2/torch/csrc/dense_fsa_vec.cu create mode 100644 k2/torch/csrc/dense_fsa_vec.h create mode 100644 k2/torch/csrc/dense_fsa_vec_test.cu create mode 100644 k2/torch/csrc/deserialization.cu create mode 100644 k2/torch/csrc/deserialization.h create mode 100644 k2/torch/csrc/deserialization_test.cu create mode 100644 k2/torch/csrc/features.cc create mode 100644 k2/torch/csrc/features.h create mode 100644 k2/torch/csrc/fsa_algo.cu create mode 100644 k2/torch/csrc/fsa_algo.h create mode 100644 k2/torch/csrc/fsa_algo_inl.h create mode 100644 k2/torch/csrc/fsa_class.cu create mode 100644 k2/torch/csrc/fsa_class.h create mode 100644 k2/torch/csrc/fsa_class_test.cu create mode 100644 k2/torch/csrc/hypothesis.cu create mode 100644 k2/torch/csrc/hypothesis.h create mode 100644 k2/torch/csrc/hypothesis_test.cu create mode 100644 k2/torch/csrc/nbest.cu create mode 100644 k2/torch/csrc/nbest.h create mode 100644 k2/torch/csrc/parse_options.cu create mode 100644 k2/torch/csrc/parse_options.h create mode 100644 k2/torch/csrc/parse_options_test.cu create mode 100644 k2/torch/csrc/symbol_table.cu create mode 100644 k2/torch/csrc/symbol_table.h create mode 100644 k2/torch/csrc/test_deserialization_data.h create mode 100644 k2/torch/csrc/test_wave_data.h create mode 100644 k2/torch/csrc/utils.cu create mode 100644 k2/torch/csrc/utils.h create mode 100644 k2/torch/csrc/wave_reader.cu create mode 100644 k2/torch/csrc/wave_reader.h create mode 100644 k2/torch/csrc/wave_reader_test.cu create mode 100755 scripts/github_actions/fix_torch.py create mode 100755 scripts/github_actions/fix_torch.sh diff --git a/.flake8 b/.flake8 index e938fd70f..1c225303b 100644 --- a/.flake8 +++ b/.flake8 @@ -26,3 +26,5 @@ ignore = F401, # W504, line break after binary operator W504, + # W503, line break before binary operator + W503, diff --git a/.github/scripts/run-test.sh b/.github/scripts/run-test.sh new file mode 100755 index 000000000..8b975588f --- /dev/null +++ b/.github/scripts/run-test.sh @@ -0,0 +1,144 @@ +#!/usr/bin/env bash + +set -ex + +log() { + # This function is from espnet + local fname=${BASH_SOURCE[1]##*/} + echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*" +} + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-conformer-ctc-jit-bpe-500-2021-11-09 +repo=$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" + +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +pushd $repo +git lfs pull --include "exp/cpu_jit.pt" + +git lfs pull --include "data/lang_bpe_500/tokens.txt" +git lfs pull --include "data/lang_bpe_500/HLG.pt" +git lfs pull --include "data/lang_bpe_500/words.txt" + +git lfs pull --include "data/lm/G_4_gram.pt" +popd + +log "Test CTC decode (librispeech)" + +./build/bin/ctc_decode \ + --use_gpu false \ + --nn_model $repo/exp/cpu_jit.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + +log "Test HLG decode (librispeech)" + +./build/bin/hlg_decode \ + --use_gpu false \ + --nn_model $repo/exp/cpu_jit.pt \ + --hlg $repo/data/lang_bpe_500/HLG.pt \ + --word_table $repo/data/lang_bpe_500/words.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +if [ $(uname) == "Darwin" ]; then + # GitHub only provides 7 GB RAM for Linux/Windows + # It has 14 GB RAM for macOS. This test requires a lot of RAM. + log "Test n-gram LM rescore (librispeech)" + ./build/bin/ngram_lm_rescore \ + --use_gpu false \ + --nn_model $repo/exp/cpu_jit.pt \ + --hlg $repo/data/lang_bpe_500/HLG.pt \ + --g $repo/data/lm/G_4_gram.pt \ + --ngram_lm_scale 1.0 \ + --word_table $repo/data/lang_bpe_500/words.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + + log "Test n-gram LM rescore + attention rescore (librispeech)" + ./build/bin/attention_rescore \ + --use_gpu false \ + --nn_model $repo/exp/cpu_jit.pt \ + --hlg $repo/data/lang_bpe_500/HLG.pt \ + --g $repo/data/lm/G_4_gram.pt \ + --ngram_lm_scale 1.0 \ + --attention_scale 1.0 \ + --num_paths 100 \ + --nbest_scale 0.5 \ + --word_table $repo/data/lang_bpe_500/words.txt \ + --sos_id 1 \ + --eos_id 1 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav +fi + +log "Streaming CTC decoding" + +./build/bin/online_decode \ + --use_ctc_decoding true \ + --jit_pt $repo/exp/cpu_jit.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +log "Streaming HLG decoding" + +./build/bin/online_decode \ + --use_ctc_decoding false \ + --jit_pt $repo/exp/cpu_jit.pt \ + --hlg $repo/data/lang_bpe_500/HLG.pt \ + --word_table $repo/data/lang_bpe_500/words.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +rm -rf repo + +# Now for RNN-T + +repo_url=https://huggingface.co/csukuangfj/icefall-asr-librispeech-pruned-transducer-stateless3-2022-05-13 +repo=$(basename $repo_url) +log "Download pretrained model and test-data from $repo_url" + +GIT_LFS_SKIP_SMUDGE=1 git clone $repo_url +pushd $repo +git lfs pull --include "exp/cpu_jit.pt" +git lfs pull --include "data/lang_bpe_500/LG.pt" +popd + +log "Test RNN-T decoding" + +./build/bin/pruned_stateless_transducer \ + --use-gpu=false \ + --nn-model=$repo/exp/cpu_jit.pt \ + --tokens=$repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +./build/bin/rnnt_demo \ + --use_lg false \ + --jit_pt $repo/exp/cpu_jit.pt \ + --tokens $repo/data/lang_bpe_500/tokens.txt \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav + +./build/bin/rnnt_demo \ + --use_lg true \ + --jit_pt $repo/exp/cpu_jit.pt \ + --lg $repo/data/lang_bpe_500/LG.pt \ + --word_table $repo/data/lang_bpe_500/words.txt \ + --beam 8 \ + --max_contexts 8 \ + --max_states 64 \ + $repo/test_wavs/1089-134686-0001.wav \ + $repo/test_wavs/1221-135766-0001.wav \ + $repo/test_wavs/1221-135766-0002.wav diff --git a/.github/workflows/build-cpu-macos.yml b/.github/workflows/build-cpu-macos.yml index c4cc7de7d..5706c8168 100644 --- a/.github/workflows/build-cpu-macos.yml +++ b/.github/workflows/build-cpu-macos.yml @@ -105,11 +105,29 @@ jobs: python3 -c "import torch; print('torch version:', torch.__version__)" + - name: Build wheel + shell: bash + run: | + export K2_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF" + export K2_MAKE_ARGS="-j2" + python3 setup.py bdist_wheel + ls -lh dist/ + ls -lh build/* + + - name: Upload Wheel + uses: actions/upload-artifact@v2 + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-macos-10.15-cpu + path: dist/*.whl + - name: Build k2 shell: bash + env: + torch: ${{ matrix.torch }} run: | pwd - mkdir build + ./scripts/github_actions/fix_torch.sh + mkdir -p build cd build cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF .. cat k2/csrc/version.h @@ -123,17 +141,3 @@ jobs: cd build ctest --output-on-failure - - name: Build wheel - shell: bash - run: | - export K2_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF" - export K2_MAKE_ARGS="-j2" - python3 setup.py bdist_wheel - ls -lh dist/ - ls -lh build/* - - - name: Upload Wheel - uses: actions/upload-artifact@v2 - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-macos-10.15-cpu - path: dist/*.whl diff --git a/.github/workflows/build-cpu-ubuntu.yml b/.github/workflows/build-cpu-ubuntu.yml index 4ca13431c..fe55d01a9 100644 --- a/.github/workflows/build-cpu-ubuntu.yml +++ b/.github/workflows/build-cpu-ubuntu.yml @@ -107,11 +107,29 @@ jobs: python3 -c "import torch; print('torch version:', torch.__version__)" + - name: Build wheel + shell: bash + run: | + export K2_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF" + export K2_MAKE_ARGS="-j2" + python3 setup.py bdist_wheel + ls -lh dist/ + ls -lh build/* + + - name: Upload Wheel + uses: actions/upload-artifact@v2 + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu + path: dist/*.whl + - name: Build k2 shell: bash + env: + torch: ${{ matrix.torch }} run: | pwd - mkdir build + ./scripts/github_actions/fix_torch.sh + mkdir -p build cd build cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF .. cat k2/csrc/version.h @@ -124,18 +142,3 @@ jobs: run: | cd build ctest --output-on-failure - - - name: Build wheel - shell: bash - run: | - export K2_CMAKE_ARGS="-DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF" - export K2_MAKE_ARGS="-j2" - python3 setup.py bdist_wheel - ls -lh dist/ - ls -lh build/* - - - name: Upload Wheel - uses: actions/upload-artifact@v2 - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-ubuntu-18.04-cpu - path: dist/*.whl diff --git a/.github/workflows/build-cpu-windows.yml b/.github/workflows/build-cpu-windows.yml index 94334a7b0..93ec5d372 100644 --- a/.github/workflows/build-cpu-windows.yml +++ b/.github/workflows/build-cpu-windows.yml @@ -100,10 +100,29 @@ jobs: cmake --version cmake --help + + + - name: Build wheel + shell: bash + run: | + export K2_CMAKE_ARGS="-DK2_WITH_CUDA=OFF -DCMAKE_BUILD_TYPE=Release" + python3 setup.py bdist_wheel + ls -lh dist/ + pip install ./dist/*.whl + + - name: Upload Wheel + uses: actions/upload-artifact@v2 + with: + name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-windows-2019-cpu + path: dist/*.whl + - name: Configure CMake shell: bash + env: + torch: ${{ matrix.torch }} run: | - mkdir build_release + python3 ./scripts/github_actions/fix_torch.py + mkdir -p build_release cd build_release cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE -DK2_WITH_CUDA=OFF .. ls -lh @@ -119,14 +138,6 @@ jobs: ls -lh lib/*/* ls -lh bin/*/* - - name: Build wheel - shell: bash - run: | - export K2_CMAKE_ARGS="-DK2_WITH_CUDA=OFF -DCMAKE_BUILD_TYPE=Release" - python3 setup.py bdist_wheel - ls -lh dist/ - pip install ./dist/*.whl - - name: Run tests shell: bash run: | @@ -134,8 +145,3 @@ jobs: # disable python tests for k2host ctest -C Release --output-on-failure -E host - - name: Upload Wheel - uses: actions/upload-artifact@v2 - with: - name: torch-${{ matrix.torch }}-python-${{ matrix.python-version }}-windows-2019-cpu - path: dist/*.whl diff --git a/.github/workflows/build-cuda-ubuntu.yml b/.github/workflows/build-cuda-ubuntu.yml index 59926f465..a8ab9ee7f 100644 --- a/.github/workflows/build-cuda-ubuntu.yml +++ b/.github/workflows/build-cuda-ubuntu.yml @@ -139,8 +139,11 @@ jobs: - name: Configure CMake shell: bash + env: + torch: ${{ matrix.torch }} run: | pwd + ./scripts/github_actions/fix_torch.sh mkdir build cd build cmake -DCMAKE_BUILD_TYPE=$BUILD_TYPE .. diff --git a/.github/workflows/run-tests-cpu.yml b/.github/workflows/run-tests-cpu.yml index 092f1f566..63fc08db2 100644 --- a/.github/workflows/run-tests-cpu.yml +++ b/.github/workflows/run-tests-cpu.yml @@ -24,6 +24,7 @@ on: - master paths: - '.github/workflows/run-tests-cpu.yml' + - '.github/scripts/run-test.sh' - 'CMakeLists.txt' - 'cmake/**' - 'k2/csrc/**' @@ -32,6 +33,7 @@ on: types: [labeled] paths: - '.github/workflows/run-tests-cpu.yml' + - '.github/scripts/run-test.sh' - 'CMakeLists.txt' - 'cmake/**' - 'k2/csrc/**' @@ -43,7 +45,7 @@ concurrency: jobs: run-tests-cpu: - if: github.event.label.name == 'ready' || github.event_name == 'push' + if: github.event.label.name == 'ready' || github.event.label.name == 'cpp-test' || github.event_name == 'push' runs-on: ${{ matrix.os }} strategy: fail-fast: false @@ -109,8 +111,11 @@ jobs: - name: Configure CMake shell: bash + env: + torch: ${{ matrix.torch }} run: | pwd + ./scripts/github_actions/fix_torch.sh mkdir build cd build cmake -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} -DK2_WITH_CUDA=OFF .. @@ -141,3 +146,8 @@ jobs: K2_LOG_LEVEL=WARNING ./bin/cu_log_test --gtest_filter="Log.Cpu" K2_LOG_LEVEL=ERROR ./bin/cu_log_test --gtest_filter="Log.Cpu" K2_LOG_LEVEL=FATAL ./bin/cu_log_test --gtest_filter="Log.Cpu" + + - name: Run C++ API tests + shell: bash + run: | + .github/scripts/run-test.sh diff --git a/.github/workflows/run-tests.yml b/.github/workflows/run-tests.yml index 4bad9994c..b5ffe1420 100644 --- a/.github/workflows/run-tests.yml +++ b/.github/workflows/run-tests.yml @@ -119,8 +119,11 @@ jobs: - name: Configure CMake shell: bash + env: + torch: ${{ matrix.torch }} run: | pwd + ./scripts/github_actions/fix_torch.sh mkdir build cd build cmake -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} .. diff --git a/.gitignore b/.gitignore index 5639d1e6e..c630f1318 100644 --- a/.gitignore +++ b/.gitignore @@ -586,3 +586,5 @@ Mkfile.old dkms.conf !.github/** +!k2/torch/bin +*-bak diff --git a/CMakeLists.txt b/CMakeLists.txt index 3c28ccb94..9125b3050 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -354,6 +354,7 @@ if(WIN32) # 4624: destructor was implicitly defined as deleted # 4700: uninitialized local variable 'device' used # 4722: destructor never returns + # 4805: '|': unsafe mix of type 'uintptr_t' and type 'bool' in operation # 4819: The file contains a character that cannot be presented in the current code page. # 4838: conversion from 'type_1' to 'type_2' requires a narrowing conversion # 4996: "getenv": This function is unsafe @@ -366,8 +367,8 @@ if(WIN32) /wd4101 /wd4190 /wd4224 - /wd4251 /wd4244 + /wd4251 /wd4267 /wd4275 /wd4305 @@ -376,6 +377,7 @@ if(WIN32) /wd4624 /wd4700 /wd4722 + /wd4805 /wd4819 /wd4838 /wd4996 diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake index 0252268cd..d598d05de 100644 --- a/cmake/googletest.cmake +++ b/cmake/googletest.cmake @@ -39,7 +39,7 @@ function(download_googltest) FetchContent_GetProperties(googletest) if(NOT googletest_POPULATED) - message(STATUS "Downloading googletest") + message(STATUS "Downloading googletest from ${googletest_URL}") FetchContent_Populate(googletest) endif() message(STATUS "googletest is downloaded to ${googletest_SOURCE_DIR}") diff --git a/cmake/kaldifeat.cmake b/cmake/kaldifeat.cmake new file mode 100644 index 000000000..3ba1a3ec6 --- /dev/null +++ b/cmake/kaldifeat.cmake @@ -0,0 +1,52 @@ +# Copyright 2021 Fangjun Kuang (csukuangfj@gmail.com) +# See ../LICENSE for clarification regarding multiple authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +function(download_kaldifeat) + if(CMAKE_VERSION VERSION_LESS 3.11) + # FetchContent is available since 3.11, + # we've copied it to ${CMAKE_SOURCE_DIR}/cmake/Modules + # so that it can be used in lower CMake versions. + message(STATUS "Use FetchContent provided by k2") + list(APPEND CMAKE_MODULE_PATH ${CMAKE_SOURCE_DIR}/cmake/Modules) + endif() + + include(FetchContent) + + set(kaldifeat_URL "https://github.com/csukuangfj/kaldifeat/archive/refs/tags/v1.20.tar.gz") + set(kaldifeat_HASH "SHA256=509110abbb4bf510831a9abbf1f3e7a0768f9e505d7f25defeaf6545566e1aaf") + + set(kaldifeat_BUILD_TESTS OFF CACHE BOOL "" FORCE) + + FetchContent_Declare(kaldifeat + URL ${kaldifeat_URL} + URL_HASH ${kaldifeat_HASH} + ) + + FetchContent_GetProperties(kaldifeat) + if(NOT kaldifeat_POPULATED) + message(STATUS "Downloading kaldifeat from ${kaldifeat_URL}") + FetchContent_Populate(kaldifeat) + endif() + message(STATUS "kaldifeat is downloaded to ${kaldifeat_SOURCE_DIR}") + message(STATUS "kaldifeat's binary dir is ${kaldifeat_BINARY_DIR}") + + set(KALDIFEAT_TORCH_VERSION_MAJOR ${K2_TORCH_VERSION_MAJOR}) + set(KALDIFEAT_TORCH_VERSION_MINOR ${K2_TORCH_VERSION_MINOR}) + add_subdirectory(${kaldifeat_SOURCE_DIR} ${kaldifeat_BINARY_DIR} EXCLUDE_FROM_ALL) + + target_include_directories(kaldifeat_core PUBLIC ${kaldifeat_SOURCE_DIR}) +endfunction() + +download_kaldifeat() diff --git a/cmake/moderngpu.cmake b/cmake/moderngpu.cmake index 030853fa0..e94dfb961 100644 --- a/cmake/moderngpu.cmake +++ b/cmake/moderngpu.cmake @@ -31,7 +31,7 @@ function(download_moderngpu) FetchContent_GetProperties(moderngpu) if(NOT moderngpu) - message(STATUS "Downloading moderngpu") + message(STATUS "Downloading moderngpu from ${moderngpu_URL}") FetchContent_Populate(moderngpu) endif() message(STATUS "moderngpu is downloaded to ${moderngpu_SOURCE_DIR}") diff --git a/cmake/pybind11.cmake b/cmake/pybind11.cmake index 314fdc252..bc04774fd 100644 --- a/cmake/pybind11.cmake +++ b/cmake/pybind11.cmake @@ -44,7 +44,7 @@ function(download_pybind11) FetchContent_GetProperties(pybind11) if(NOT pybind11_POPULATED) - message(STATUS "Downloading pybind11") + message(STATUS "Downloading pybind11 from ${pybind11_URL}") FetchContent_Populate(pybind11) endif() message(STATUS "pybind11 is downloaded to ${pybind11_SOURCE_DIR}") diff --git a/cmake/torch.cmake b/cmake/torch.cmake index f617a77c9..69fdf0c13 100644 --- a/cmake/torch.cmake +++ b/cmake/torch.cmake @@ -30,6 +30,9 @@ execute_process( OUTPUT_VARIABLE K2_TORCH_VERSION_MINOR ) +set(K2_TORCH_VERSION "${K2_TORCH_VERSION_MAJOR}.${K2_TORCH_VERSION_MINOR}") +message(STATUS "K2_TORCH_VERSION: ${K2_TORCH_VERSION}") + execute_process( COMMAND "${PYTHON_EXECUTABLE}" -c "import torch; print(torch.__version__)" OUTPUT_STRIP_TRAILING_WHITESPACE diff --git a/docs/source/installation/images/torch_ge_1.6.0-green.svg b/docs/source/installation/images/torch_ge_1.6.0-green.svg new file mode 100644 index 000000000..d3ece9a17 --- /dev/null +++ b/docs/source/installation/images/torch_ge_1.6.0-green.svg @@ -0,0 +1 @@ +torch: >= 1.6.0torch>= 1.6.0 \ No newline at end of file diff --git a/docs/source/installation/pip_pypi.rst b/docs/source/installation/pip_pypi.rst index 6d33e1e3b..662e5fa07 100644 --- a/docs/source/installation/pip_pypi.rst +++ b/docs/source/installation/pip_pypi.rst @@ -65,8 +65,8 @@ You should see something like below: OS used to build k2: Ubuntu 18.04.5 LTS CMake version: 3.21.6 GCC version: 7.5.0 - CMAKE_CUDA_FLAGS: -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_50,code=sm_50 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_60,code=sm_60 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_61,code=sm_61 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_70,code=sm_70 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_75,code=sm_75 -DONNX_NAMESPACE=onnx_c2 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_75,code=compute_75 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=integer_sign_change,--diag_suppress=useless_using_declaration,--diag_suppress=set_but_not_used,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=implicit_return_from_non_void_function,--diag_suppress=unsigned_compare_with_zero,--diag_suppress=declared_but_not_referenced,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -D_GLIBCXX_USE_CXX11_ABI=0 --compiler-options -Wall --compiler-options -Wno-strict-overflow --compiler-options -Wno-unknown-pragmas - CMAKE_CXX_FLAGS: -D_GLIBCXX_USE_CXX11_ABI=0 -Wno-unused-variable -Wno-strict-overflow + CMAKE_CUDA_FLAGS: -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_35,code=sm_35 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_50,code=sm_50 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_60,code=sm_60 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_61,code=sm_61 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_70,code=sm_70 -lineinfo --expt-extended-lambda -use_fast_math -Xptxas=-w --expt-extended-lambda -gencode arch=compute_75,code=sm_75 -DONNX_NAMESPACE=onnx_c2 -gencode arch=compute_35,code=sm_35 -gencode arch=compute_50,code=sm_50 -gencode arch=compute_52,code=sm_52 -gencode arch=compute_60,code=sm_60 -gencode arch=compute_61,code=sm_61 -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75 -gencode arch=compute_75,code=compute_75 -Xcudafe --diag_suppress=cc_clobber_ignored,--diag_suppress=integer_sign_change,--diag_suppress=useless_using_declaration,--diag_suppress=set_but_not_used,--diag_suppress=field_without_dll_interface,--diag_suppress=base_class_has_different_dll_interface,--diag_suppress=dll_interface_conflict_none_assumed,--diag_suppress=dll_interface_conflict_dllexport_assumed,--diag_suppress=implicit_return_from_non_void_function,--diag_suppress=unsigned_compare_with_zero,--diag_suppress=declared_but_not_referenced,--diag_suppress=bad_friend_decl --expt-relaxed-constexpr --expt-extended-lambda -D_GLIBCXX_USE_CXX11_ABI=0 --compiler-options -Wall --compiler-options -Wno-strict-overflow --compiler-options -Wno-unknown-pragmas + CMAKE_CXX_FLAGS: -D_GLIBCXX_USE_CXX11_ABI=0 -Wno-unused-variable -Wno-strict-overflow PyTorch version used to build k2: 1.12.1+cu102 PyTorch is using Cuda: 10.2 NVTX enabled: True diff --git a/k2/CMakeLists.txt b/k2/CMakeLists.txt index c70d00c60..6a7839d0e 100644 --- a/k2/CMakeLists.txt +++ b/k2/CMakeLists.txt @@ -1,2 +1,15 @@ add_subdirectory(csrc) add_subdirectory(python) + +if(K2_USE_PYTORCH) + # We use K2_TORCH_VERSION instead of TORCH_VERSION + # since TORCH_VERSION may contain something like "+cpu", "+cu113" + if(K2_TORCH_VERSION VERSION_GREATER_EQUAL 1.8 OR NOT K2_WITH_CUDA) + message(STATUS "Including k2/torch. K2_TORCH_VERSION is ${K2_TORCH_VERSION}") + include(kaldifeat) + add_subdirectory(torch) + else() + message(WARNING "Please use at least torch 1.8.0 when CUDA \ + is enabled - skipping compiling k2/torch. Current torch version: ${TORCH_VERSION}") + endif() +endif() diff --git a/k2/csrc/fake_cuda.h b/k2/csrc/fake_cuda.h index 9acdb2ae7..29915248e 100644 --- a/k2/csrc/fake_cuda.h +++ b/k2/csrc/fake_cuda.h @@ -43,7 +43,9 @@ #define __forceinline__ __forceinline #endif -#define K2_NIY K2_LOG(FATAL) << "Not implemented yet. Don't call me!" +#define K2_NIY \ + K2_LOG(FATAL) \ + << "Not implemented yet. Don't call me! (Not Compiled with CUDA ?)" using cudaError_t = int32_t; using cudaStream_t = int32_t *; diff --git a/k2/csrc/fsa_algo.h b/k2/csrc/fsa_algo.h index 92dde5dfe..cecf6940a 100644 --- a/k2/csrc/fsa_algo.h +++ b/k2/csrc/fsa_algo.h @@ -669,7 +669,7 @@ Fsa Closure(Fsa &fsa, Array1 *arc_map = nullptr); @param [in] labels_shape This might correspond to the shape of the `aux_labels`; it is a shape with `labels_shape.NumAxes() == 2` and - `arcs.shape.Dim0() == fsas.NumElements()`. + `labels_shape.Dim0() == fsas.NumElements()`. The i'th arc of the FsaVec will be expanded to a sequence of `max(1, l)` arcs, where l is the length of the i'th list in `labels_shape` @@ -724,7 +724,7 @@ FsaOrVec ExpandArcs(FsaOrVec &fsas, RaggedShape &labels_shape, to n (which also implies that aux_labels for final-arc must at least contain -1). For other arcs that are not final-arcs, - the corresponding aux_labels must contain no + the corresponding aux_labels must not contain -1. @param [out] dest Output Fsa or FsaVec, it's the inverted Fsa. At exit dest.NumAxes() == src.NumAxes() and num-states of it diff --git a/k2/csrc/fsa_utils.cu b/k2/csrc/fsa_utils.cu index 297ac6fea..5e3e3ba16 100644 --- a/k2/csrc/fsa_utils.cu +++ b/k2/csrc/fsa_utils.cu @@ -114,9 +114,13 @@ static Fsa K2FsaFromStream(std::istringstream &is, << ": src-state < 0 or dest-state < 0."; } for (int32_t i = 0; i < num_aux_labels; i++) { - int32_t aux; + float aux; line_is >> aux; - aux_labels.push_back(aux); + if ((int32_t)aux != aux) { + K2_LOG(FATAL) << "Invalid line " << line + << ": Expected an integer for aux_labels"; + } + aux_labels.push_back((int32_t)aux); } for (int32_t i = 0; i < num_ragged_labels; i++) { line_is >> std::ws; @@ -616,7 +620,7 @@ std::string FsaToString(const Fsa &fsa, bool openfst, /*= false*/ char line_sep = '\n'; for (int32_t a = 0; a != num_arcs; ++a) { const auto &arc = arcs[a]; - if (openfst & arc.label == -1) { + if (openfst && arc.label == -1) { os << arc.src_state << sep; } else { os << arc.src_state << sep << arc.dest_state << sep << arc.label << sep; @@ -633,7 +637,7 @@ std::string FsaToString(const Fsa &fsa, bool openfst, /*= false*/ os << (scale * arc.score) << line_sep; } - if (num_arcs > 0 & !openfst) { + if (num_arcs > 0 && !openfst) { int32_t final_state = fsa.shape.Dim0() - 1; os << final_state << line_sep; } else { diff --git a/k2/csrc/intersect_dense_pruned.cu b/k2/csrc/intersect_dense_pruned.cu index 6ef8f4f1b..3e498121e 100644 --- a/k2/csrc/intersect_dense_pruned.cu +++ b/k2/csrc/intersect_dense_pruned.cu @@ -1,5 +1,6 @@ /** - * Copyright 2020 Xiaomi Corporation (authors: Daniel Povey) + * Copyright 2020 Xiaomi Corporation (authors: Daniel Povey, + * Wei Kang) * * See LICENSE for clarification regarding multiple authors * @@ -25,6 +26,7 @@ #include "k2/csrc/fsa_algo.h" #include "k2/csrc/fsa_utils.h" #include "k2/csrc/hash.h" +#include "k2/csrc/online_dense_intersector.h" #include "k2/csrc/ragged_ops.h" #include "k2/csrc/thread_pool.h" @@ -115,14 +117,15 @@ class MultiGraphDenseIntersectPruned { something more complicated. Must have either the same Dim0() as b_fsas, or Dim0()==1 in which case the graph is shared. - @param [in] b_fsas The neural-net output, with each frame containing the - log-likes of each phone. A series of sequences of - (in general) different length. - @param [in] search_beam "Default" search/decoding beam. The actual + @param [in] num_seqs The number of sequences to do intersection at a + time, i.e. batch size. The input DenseFsaVec in + `Intersect` function MUST have `Dim0()` equals to + this. + @param [in] search_beam "Default" search/decoding beam. The actual beam is dynamic and also depends on max_active and min_active. - @param [in] output_beam Beam for pruning the output FSA, will - typically be smaller than search_beam. + @param [in] output_beam Beam for pruning the output FSA, will + typically be smaller than search_beam. @param [in] min_active Minimum number of FSA states that are allowed to be active on any given frame for any given intersection/composition task. This is advisory, @@ -133,29 +136,34 @@ class MultiGraphDenseIntersectPruned { intersection/composition task. This is advisory, in that it will try not to exceed that but may not always succeed. This determines the hash size. + @param [in] online_decoding True for online decoding (i.e. chunk by + chunk decoding), false for running in batch + mode. */ - MultiGraphDenseIntersectPruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, + MultiGraphDenseIntersectPruned(FsaVec &a_fsas, int32_t num_seqs, float search_beam, float output_beam, - int32_t min_active, int32_t max_active) + int32_t min_active, int32_t max_active, + bool online_decoding) : a_fsas_(a_fsas), - b_fsas_(b_fsas), + num_seqs_(num_seqs), search_beam_(search_beam), output_beam_(output_beam), min_active_(min_active), max_active_(max_active), - dynamic_beams_(a_fsas.Context(), b_fsas.shape.Dim0(), search_beam), - forward_semaphore_(1) { + online_decoding_(online_decoding), + dynamic_beams_(a_fsas.Context(), num_seqs, search_beam), + forward_semaphore_(1), + final_t_(a_fsas.Context(), num_seqs, 0), + reach_final_(0) { NVTX_RANGE(K2_FUNC); - c_ = GetContext(a_fsas.shape, b_fsas.shape); - T_ = b_fsas_.shape.MaxSize(1); + c_ = GetContext(a_fsas.shape); + T_ = 0; K2_CHECK_GT(search_beam, 0); K2_CHECK_GT(output_beam, 0); K2_CHECK_GE(min_active, 0); K2_CHECK_GT(max_active, min_active); - K2_CHECK(a_fsas.shape.Dim0() == b_fsas.shape.Dim0() || - a_fsas.shape.Dim0() == 1); - K2_CHECK_GE(b_fsas.shape.Dim0(), 1); - int32_t num_seqs = b_fsas.shape.Dim0(); + K2_CHECK(a_fsas.shape.Dim0() == num_seqs || a_fsas.shape.Dim0() == 1); + K2_CHECK_GE(num_seqs, 1); int32_t num_buckets = RoundUpToNearestPowerOfTwo(num_seqs * 4 * max_active); @@ -165,9 +173,9 @@ class MultiGraphDenseIntersectPruned { if (a_fsas.shape.Dim0() == 1) { a_fsas_stride_ = 0; state_map_fsa_stride_ = a_fsas.TotSize(1); - num_a_copies = b_fsas.shape.Dim0(); + num_a_copies = num_seqs; } else { - K2_CHECK_EQ(a_fsas.shape.Dim0(), b_fsas.shape.Dim0()); + K2_CHECK_EQ(a_fsas.shape.Dim0(), num_seqs); a_fsas_stride_ = 1; state_map_fsa_stride_ = 0; num_a_copies = 1; @@ -192,7 +200,53 @@ class MultiGraphDenseIntersectPruned { } } state_map_ = Hash(c_, num_buckets, num_key_bits); + } + + // The information we have for each frame of the pruned-intersection (really: + // decoding) algorithm. We keep an array of these, one for each frame, up to + // the length of the longest sequence we're decoding plus one. + struct FrameInfo { + // States that are active at the beginning of this frame. Indexed + // [fsa_idx][state_idx], where fsa_idx indexes b_fsas_ (and a_fsas_, if + // a_fsas_stride_ != 0); and state_idx just enumerates the active states + // on this frame (as state_idx01's in a_fsas_). + Ragged states; // 2 axes: fsa, state + + // Indexed [fsa_idx][state_idx][arc_idx].. the first 2 indexes are + // the same as those into 'states' (the first 2 levels of the structure + // are shared), and the last one enumerates the arcs leaving each of those + // states. + // + // Note: there may be indexes [fsa_idx] that have no states (because that + // FSA had fewer frames than the max), and indexes [fsa_idx][state_idx] that + // have no arcs due to pruning. + Ragged arcs; // 3 axes: fsa, state, arc + }; + /* Does the main work of intersection/composition, but doesn't produce any + output; the output is provided when you call FormatOutput(). + + @param [in] b_fsas The neural-net output, with each frame containing the + log-likes of each phone. A series of sequences of + (in general) different length. + */ + void Intersect(std::shared_ptr &b_fsas) { + /* + T is the largest number of (frames+1) of neural net output, or the largest + number of frames of log-likelihoods we count the final frame with (0, + -inf, -inf..) that is used for the final-arc. The largest number of + states in the fsas represented by b_fsas equals T+1 (e.g. 1 frame would + require 2 states, because that 1 frame is the arc from state 0 to state + 1). So the #states is 2 greater than the actual number of frames in the + neural-net output. + */ + + K2_CHECK(!online_decoding_); + K2_CHECK(c_->IsCompatible(*b_fsas->Context())); + + b_fsas_ = b_fsas; + K2_CHECK_EQ(num_seqs_, b_fsas_->shape.Dim0()); + T_ = T_ + b_fsas_->shape.MaxSize(1); { // set up do_pruning_after_ and prune_t_begin_end_. @@ -225,46 +279,12 @@ class MultiGraphDenseIntersectPruned { break; } } - } - - // The information we have for each frame of the pruned-intersection (really: - // decoding) algorithm. We keep an array of these, one for each frame, up to - // the length of the longest sequence we're decoding plus one. - struct FrameInfo { - // States that are active at the beginning of this frame. Indexed - // [fsa_idx][state_idx], where fsa_idx indexes b_fsas_ (and a_fsas_, if - // a_fsas_stride_ != 0); and state_idx just enumerates the active states - // on this frame (as state_idx01's in a_fsas_). - Ragged states; // 2 axes: fsa, state - // Indexed [fsa_idx][state_idx][arc_idx].. the first 2 indexes are - // the same as those into 'states' (the first 2 levels of the structure - // are shared), and the last one enumerates the arcs leaving each of those - // states. - // - // Note: there may be indexes [fsa_idx] that have no states (because that - // FSA had fewer frames than the max), and indexes [fsa_idx][state_idx] that - // have no arcs due to pruning. - Ragged arcs; // 3 axes: fsa, state, arc - }; - - /* Does the main work of intersection/composition, but doesn't produce any - output; the output is provided when you call FormatOutput(). */ - void Intersect() { - /* - T is the largest number of (frames+1) of neural net output, or the largest - number of frames of log-likelihoods we count the final frame with (0, - -inf, -inf..) that is used for the final-arc. The largest number of - states in the fsas represented by b_fsas equals T+1 (e.g. 1 frame would - require 2 states, because that 1 frame is the arc from state 0 to state - 1). So the #states is 2 greater than the actual number of frames in the - neural-net output. - */ - int32_t num_fsas = b_fsas_.shape.Dim0(), T = T_; + int32_t T = T_; std::ostringstream os; - os << "Intersect:T=" << T << ",num_fsas=" << num_fsas - << ",TotSize(1)=" << b_fsas_.shape.TotSize(1); + os << "Intersect:T=" << T << ",num_fsas=" << num_seqs_ + << ",TotSize(1)=" << b_fsas_->shape.TotSize(1); NVTX_RANGE(os.str().c_str()); ThreadPool* pool = GetThreadPool(); @@ -302,8 +322,96 @@ class MultiGraphDenseIntersectPruned { pool->WaitAllTasksFinished(); } + /* Does the main work of intersection/composition, but doesn't produce any + output; the output is provided when you call FormatOutput(). + Does almost the same work as `Intersect`, except that this would be call + serveral times for chunk by chunk decoding. + + @param [in] b_fsas The neural-net output, with each frame containing the + log-likes of each phone. A series of sequences of + (in general) different length. + @param [in] is_final Whether this is the final chunk of the nnet_output, + After calling this function with is_final is true, + means decoding finished. + */ + void OnlineIntersect(std::shared_ptr &b_fsas, bool is_final) { + /* + T is the largest number of (frames+1) of neural net output currently + received, or the largest number of frames of log-likelihoods we count the + final frame with (0, -inf, -inf..) that is used for the final-arc. + The largest number of states in the fsas represented by b_fsas equals + T+1 (e.g. 1 frame would require 2 states, because that 1 frame is the arc + from state 0 to state 1). So the #states is 2 greater than the actual + number of frames in the neural-net output. + */ + K2_CHECK(online_decoding_); + K2_CHECK(c_->IsCompatible(*b_fsas->Context())); + + K2_CHECK_EQ(reach_final_, 0) << "You can't continue decoding after " + << "reaching final."; + if (is_final) { + reach_final_ = 1; + } + + b_fsas_ = b_fsas; + K2_CHECK_EQ(num_seqs_, b_fsas_->shape.Dim0()); + int32_t T = T_ + b_fsas_->shape.MaxSize(1); + + // we'll initially populate frames_[0.. T+1], but discard the one at T+1, + // which has no arcs or states, the ones we use are from 0 to T. + frames_.reserve(T + 2); + + if (T_ == 0) frames_.push_back(InitialFrameInfo()); + int32_t prune_num_frames = 15, prune_shift = 10; + + for (int32_t t = 0; t <= b_fsas_->shape.MaxSize(1); t++) { + if (state_map_.NumKeyBits() == 32) { + frames_.push_back(PropagateForward<32>(t, frames_.back().get())); + } else if (state_map_.NumKeyBits() == 36) { + frames_.push_back(PropagateForward<36>(t, frames_.back().get())); + } else { + K2_CHECK_EQ(state_map_.NumKeyBits(), 40); + frames_.push_back(PropagateForward<40>(t, frames_.back().get())); + } + if (t != 0 && (T_ + t) % prune_shift == 0 || + t == b_fsas_->shape.MaxSize(1)) { + int32_t prune_t_begin = + (T_ + t - prune_num_frames) > 0 ? (T_ + t - prune_num_frames) : 0; + int32_t prune_t_end = T_ + t; + PruneTimeRange(prune_t_begin, prune_t_end); + } + } + // The FrameInfo for time T+1 will have no states. We did that + // last PropagateForward so that the 'arcs' member of frames_[T] + // is set up (it has no arcs but we need the shape). + frames_.pop_back(); + + if (is_final) { + T_ = T; + } else { + T_ = T - 1; + // partial_final_frame_ is the last frame to generate partial result, + // but it should not be the start frame of next chunk decoding. + partial_final_frame_ = std::move(frames_.back()); + frames_.pop_back(); + } + const int32_t *b_fsas_row_splits1 = b_fsas_->shape.RowSplits(1).Data(); + int32_t *final_t_data = final_t_.Data(); + + // Get final frame of each sequences. + K2_EVAL( + c_, num_seqs_, lambda_set_final_and_final_t, (int32_t i)->void { + int32_t b_chunk_size = + b_fsas_row_splits1[i + 1] - b_fsas_row_splits1[i]; + int32_t final_t = final_t_data[i]; + final_t = + is_final ? final_t + b_chunk_size : final_t + b_chunk_size - 1; + final_t_data[i] = final_t; + }); + } + void BackwardPass() { - int32_t num_fsas = b_fsas_.shape.Dim0(), + int32_t num_fsas = b_fsas_->shape.Dim0(), num_work_items = max_active_ * num_fsas * T_; ParallelRunner pr(c_); // if num_work_items is big enough, it will actually create a new stream. @@ -332,7 +440,7 @@ class MultiGraphDenseIntersectPruned { // Return FrameInfo for 1st frame, with `states` set but `arcs` not set. std::unique_ptr InitialFrameInfo() { NVTX_RANGE("InitialFrameInfo"); - int32_t num_fsas = b_fsas_.shape.Dim0(); + int32_t num_fsas = b_fsas_->shape.Dim0(); std::unique_ptr ans = std::make_unique(); if (a_fsas_.Dim0() == 1) { @@ -370,29 +478,45 @@ class MultiGraphDenseIntersectPruned { } void FormatOutput(FsaVec *ofsa, Array1 *arc_map_a, - Array1 *arc_map_b) { + Array1 *arc_map_b, bool is_final) { NVTX_RANGE("FormatOutput"); - int32_t T = T_; - + bool online_decoding = online_decoding_; + if (online_decoding) { + K2_CHECK_EQ(is_final, reach_final_); + K2_CHECK(arc_map_a); + K2_CHECK_EQ(arc_map_b, nullptr); + } else { + K2_CHECK(is_final); + K2_CHECK(arc_map_a && arc_map_b); + } + int32_t T = is_final ? T_ : T_ + 1; ContextPtr c_cpu = GetCpuContext(); Array1 arcs_data_ptrs(c_cpu, T + 1); Array1 arcs_row_splits1_ptrs(c_cpu, T + 1); - for (int32_t t = 0; t <= T; t++) { + for (int32_t t = 0; t < T; t++) { arcs_data_ptrs.Data()[t] = frames_[t]->arcs.values.Data(); arcs_row_splits1_ptrs.Data()[t] = frames_[t]->arcs.RowSplits(1).Data(); } + arcs_data_ptrs.Data()[T] = is_final + ? frames_[T]->arcs.values.Data() + : partial_final_frame_->arcs.values.Data(); + arcs_row_splits1_ptrs.Data()[T] = + is_final ? frames_[T]->arcs.RowSplits(1).Data() + : partial_final_frame_->arcs.RowSplits(1).Data(); + // transfer to GPU if we're using a GPU arcs_data_ptrs = arcs_data_ptrs.To(c_); ArcInfo **arcs_data_ptrs_data = arcs_data_ptrs.Data(); arcs_row_splits1_ptrs = arcs_row_splits1_ptrs.To(c_); int32_t **arcs_row_splits1_ptrs_data = arcs_row_splits1_ptrs.Data(); - const int32_t *b_fsas_row_splits1 = b_fsas_.shape.RowSplits(1).Data(); + const int32_t *b_fsas_row_splits1 = b_fsas_->shape.RowSplits(1).Data(); const int32_t *a_fsas_row_splits1 = a_fsas_.RowSplits(1).Data(); int32_t a_fsas_stride = a_fsas_stride_; // 0 or 1 depending if the decoding // graph is shared. - int32_t num_fsas = b_fsas_.shape.Dim0(); + int32_t *final_t_data = final_t_.Data(); + int32_t num_fsas = b_fsas_->shape.Dim0(); RaggedShape final_arcs_shape; { /* This block populates `final_arcs_shape`. It is the shape of a ragged @@ -407,7 +531,12 @@ class MultiGraphDenseIntersectPruned { Array1 num_extra_states(c_, num_fsas + 1); int32_t *num_extra_states_data = num_extra_states.Data(); K2_EVAL(c_, num_fsas, lambda_set_num_extra_states, (int32_t i) -> void { - int32_t final_t = b_fsas_row_splits1[i+1] - b_fsas_row_splits1[i]; + int32_t final_t; + if (online_decoding) + final_t = is_final ? final_t_data[i] : final_t_data[i] + 1; + else + final_t = b_fsas_row_splits1[i+1] - b_fsas_row_splits1[i]; + int32_t *arcs_row_splits1_data = arcs_row_splits1_ptrs_data[final_t]; int32_t num_states_final_t = arcs_row_splits1_data[i + 1] - arcs_row_splits1_data[i]; @@ -438,8 +567,12 @@ class MultiGraphDenseIntersectPruned { NVTX_RANGE("InitOshape"); // each of these have 3 axes. std::vector arcs_shapes(T + 2); - for (int32_t t = 0; t <= T; t++) + for (int32_t t = 0; t < T; t++) arcs_shapes[t] = &(frames_[t]->arcs.shape); + + arcs_shapes[T] = is_final ? &(frames_[T]->arcs.shape) + : &(partial_final_frame_->arcs.shape); + arcs_shapes[T + 1] = &final_arcs_shape; // oshape is a 4-axis ragged tensor which is indexed: @@ -448,7 +581,6 @@ class MultiGraphDenseIntersectPruned { oshape = Stack(axis, T + 2, arcs_shapes.data(), &oshape_merge_map); } - int32_t *oshape_row_ids3 = oshape.RowIds(3).Data(), *oshape_row_ids2 = oshape.RowIds(2).Data(), *oshape_row_ids1 = oshape.RowIds(1).Data(), @@ -456,17 +588,17 @@ class MultiGraphDenseIntersectPruned { *oshape_row_splits2 = oshape.RowSplits(2).Data(), *oshape_row_splits1 = oshape.RowSplits(1).Data(); - int32_t num_arcs = oshape.NumElements(); *arc_map_a = Array1(c_, num_arcs); - *arc_map_b = Array1(c_, num_arcs); + if (!online_decoding) + *arc_map_b = Array1(c_, num_arcs); int32_t *arc_map_a_data = arc_map_a->Data(), - *arc_map_b_data = arc_map_b->Data(); + *arc_map_b_data = online_decoding ? nullptr : arc_map_b->Data(); Array1 arcs_out(c_, num_arcs); Arc *arcs_out_data = arcs_out.Data(); const Arc *a_fsas_arcs = a_fsas_.values.Data(); - int32_t b_fsas_num_cols = b_fsas_.scores.Dim1(); - const int32_t *b_fsas_row_ids1 = b_fsas_.shape.RowIds(1).Data(); + int32_t b_fsas_num_cols = b_fsas_->scores.Dim1(); + const int32_t *b_fsas_row_ids1 = b_fsas_->shape.RowIds(1).Data(); const uint32_t *oshape_merge_map_data = oshape_merge_map.Data(); @@ -499,17 +631,20 @@ class MultiGraphDenseIntersectPruned { arc_info.u.dest_info_state_idx1; arc.dest_state = dest_state_idx012 - oarc_idx0xx; arc.label = a_fsas_arcs[arc_info.a_fsas_arc_idx012].label; - - int32_t fsa_id = oarc_idx0, - b_fsas_idx0x = b_fsas_row_splits1[fsa_id], - b_fsas_idx01 = b_fsas_idx0x + t, - b_fsas_idx2 = (arc.label + 1), - b_fsas_arc_idx012 = b_fsas_idx01 * b_fsas_num_cols + b_fsas_idx2; - arc.score = arc_info.arc_loglike; - arc_map_a_data[oarc_idx0123] = arc_info.a_fsas_arc_idx012; - arc_map_b_data[oarc_idx0123] = b_fsas_arc_idx012; arcs_out_data[oarc_idx0123] = arc; + + // We won't preduce arc_map_b (for nnet_output) for online_decoding, + // in this case, b_fsas_ is only a part of the whole sequence. + if (!online_decoding) { + int32_t fsa_id = oarc_idx0, + b_fsas_idx0x = b_fsas_row_splits1[fsa_id], + b_fsas_idx01 = b_fsas_idx0x + t, + b_fsas_idx2 = (arc.label + 1), + b_fsas_arc_idx012 = b_fsas_idx01 * b_fsas_num_cols + b_fsas_idx2; + arc_map_b_data[oarc_idx0123] = b_fsas_arc_idx012; + } + arc_map_a_data[oarc_idx0123] = arc_info.a_fsas_arc_idx012; }); // Remove axis 1, which corresponds to time. @@ -563,7 +698,7 @@ class MultiGraphDenseIntersectPruned { min_active = min_active_; K2_CHECK_LT(min_active, max_active); - const int32_t *b_fsas_row_splits1 = b_fsas_.shape.RowSplits(1).Data(); + const int32_t *b_fsas_row_splits1 = b_fsas_->shape.RowSplits(1).Data(); Array1 cutoffs(c_, num_fsas); float *cutoffs_data = cutoffs.Data(); @@ -670,11 +805,11 @@ class MultiGraphDenseIntersectPruned { const Arc *arcs = a_fsas_.values.Data(); // fsa_idx0 to idx0x (into b_fsas_), which gives the 1st row for this // sequence. - const int32_t *b_fsas_row_ids1 = b_fsas_.shape.RowIds(1).Data(); - const int32_t *b_fsas_row_splits1 = b_fsas_.shape.RowSplits(1).Data(); - const float *score_data = b_fsas_.scores.Data(); - int32_t scores_num_cols = b_fsas_.scores.Dim1(); - auto scores_acc = b_fsas_.scores.Accessor(); + const int32_t *b_fsas_row_ids1 = b_fsas_->shape.RowIds(1).Data(); + const int32_t *b_fsas_row_splits1 = b_fsas_->shape.RowSplits(1).Data(); + const float *score_data = b_fsas_->scores.Data(); + int32_t scores_num_cols = b_fsas_->scores.Dim1(); + auto scores_acc = b_fsas_->scores.Accessor(); Ragged ai(ai_shape); ArcInfo *ai_data = ai.values.Data(); // uninitialized @@ -715,9 +850,9 @@ class MultiGraphDenseIntersectPruned { return ai; } - // Later we may choose to support b_fsas_.Dim0() == 1 and a_fsas_.Dim0() > 1, + // Later we may choose to support b_fsas_->Dim0() == 1 and a_fsas_.Dim0() > 1, // and we'll have to change various bits of code for that to work. - inline int32_t NumFsas() const { return b_fsas_.shape.Dim0(); } + inline int32_t NumFsas() const { return b_fsas_->shape.Dim0(); } /* Does the forward-propagation (basically: the decoding step) and @@ -1189,7 +1324,7 @@ class MultiGraphDenseIntersectPruned { NVTX_RANGE(K2_FUNC); SetBackwardProbsFinal(frames_[end_t].get()); ContextPtr cpu = GetCpuContext(); - int32_t num_fsas = b_fsas_.shape.Dim0(), + int32_t num_fsas = b_fsas_->shape.Dim0(), num_t = end_t - begin_t; Array1 old_states_offsets(cpu, num_t + 1), old_arcs_offsets(cpu, num_t + 1); @@ -1210,7 +1345,7 @@ class MultiGraphDenseIntersectPruned { // contains respectively: row_splits1_ptrs, row_ids1_ptrs, - // row_splits1_ptrs, row_splits2_ptrs, + // row_splits2_ptrs, row_ids2_ptrs, // old_arcs_ptrs (really type ArcInfo*), // old_states_ptrs (really type StateInfo*). Array1 old_all_ptrs(cpu, num_t * 6); @@ -1453,8 +1588,12 @@ class MultiGraphDenseIntersectPruned { int32_t a_fsas_stride_; // 1 if we use a different FSA per sequence // (a_fsas_.Dim0() > 1), 0 if the decoding graph is // shared (a_fsas_.Dim0() == 1). - DenseFsaVec &b_fsas_; - int32_t T_; // == b_fsas_.shape.MaxSize(1). + std::shared_ptr b_fsas_; // nnet_output to be decoded. + int32_t num_seqs_; // the number of sequences to decode at a time, + // i.e. batch size for decoding. + int32_t T_; // equals to b_fsas_->shape.MaxSize(1), for + // batch intersection. + // means the number of frames decoded currently. float search_beam_; float output_beam_; int32_t min_active_; @@ -1463,6 +1602,14 @@ class MultiGraphDenseIntersectPruned { // but change due to max_active/min_active // constraints). + bool online_decoding_; // true for online decoding. + Array1 final_t_; // record the final frame id of each DenseFsa. + int32_t reach_final_; // only for online decoding, indicating whether + // the last chunk of audio received. + + std::unique_ptr partial_final_frame_; // store the final frame for + // partial results + int32_t state_map_fsa_stride_; // state_map_fsa_stride_ is a_fsas_.TotSize(1) // if a_fsas_.Dim0() == 1, else 0. @@ -1525,11 +1672,40 @@ void IntersectDensePruned(FsaVec &a_fsas, DenseFsaVec &b_fsas, Array1 *arc_map_b) { NVTX_RANGE("IntersectDensePruned"); FsaVec a_vec = FsaToFsaVec(a_fsas); - MultiGraphDenseIntersectPruned intersector(a_vec, b_fsas, search_beam, - output_beam, min_active_states, - max_active_states); + bool online_decoding = false; + MultiGraphDenseIntersectPruned intersector(a_vec, b_fsas.shape.Dim0(), + search_beam, output_beam, + min_active_states, + max_active_states, + online_decoding); + + auto b_fsas_p = std::make_shared(b_fsas); + intersector.Intersect(b_fsas_p); + intersector.FormatOutput(out, arc_map_a, arc_map_b, true); +} + +OnlineDenseIntersecter::OnlineDenseIntersecter(FsaVec &a_fsas, + int32_t num_seqs, float search_beam, float output_beam, + int32_t min_active_states, int32_t max_active_states) { + bool online_decoding = true; + K2_CHECK_EQ(a_fsas.NumAxes(), 3); + impl_ = new MultiGraphDenseIntersectPruned(a_fsas, num_seqs, search_beam, + output_beam, min_active_states, max_active_states, online_decoding); +} + +OnlineDenseIntersecter::~OnlineDenseIntersecter(){ + delete impl_; +} + +void OnlineDenseIntersecter::Intersect(DenseFsaVec &b_fsas, bool is_final) { + auto b_fsas_p = std::make_shared(b_fsas); + impl_->OnlineIntersect(b_fsas_p, is_final); +} - intersector.Intersect(); - intersector.FormatOutput(out, arc_map_a, arc_map_b); +void OnlineDenseIntersecter::FormatOutput(FsaVec *out, + Array1 *arc_map_a, + bool is_final) { + impl_->FormatOutput(out, arc_map_a, nullptr, is_final); } + } // namespace k2 diff --git a/k2/csrc/log.cu b/k2/csrc/log.cu index 3777622e5..e1733a0c6 100644 --- a/k2/csrc/log.cu +++ b/k2/csrc/log.cu @@ -22,6 +22,8 @@ #include "k2/csrc/log.h" +#include + #ifdef K2_HAVE_EXECINFO_H #include // To get stack trace in error messages. #ifdef K2_HAVE_CXXABI_H @@ -32,12 +34,33 @@ #include +#include // NOLINT +#include #include namespace k2 { namespace internal { +std::string GetTimeStamp() { + using namespace std::chrono; // NOLINT + auto now = system_clock::now(); + std::time_t time = system_clock::to_time_t(now); + std::tm tm; +#ifndef _MSC_VER + localtime_r(&time, &tm); +#else + localtime_s(&tm, &time); +#endif + char s[128]; + std::strftime(s, sizeof(s), "%F %T", &tm); + int32_t ms = + duration_cast(now.time_since_epoch()).count() % 1000; + std::ostringstream os; + os << s << "." << ms; + return os.str(); +} + static bool LocateSymbolRange(const std::string &trace_name, std::size_t *begin, std::size_t *end) { // Find the first '_' with leading ' ' or '('. diff --git a/k2/csrc/math.h b/k2/csrc/math.h index 250cbd6ca..bfe7d64a2 100644 --- a/k2/csrc/math.h +++ b/k2/csrc/math.h @@ -29,8 +29,7 @@ namespace k2 { // Currently, only used in k2/csrc/rnnt_decode.cu // See https://github.com/k2-fsa/k2/pull/951#issuecomment-1096650842 -K2_CUDA_HOSTDEV __forceinline__ int64_t Pow(int64_t base, - int64_t exponent) { +K2_CUDA_HOSTDEV __forceinline__ int64_t Pow(int64_t base, int64_t exponent) { K2_CHECK_GE(exponent, 0); int64_t exp = 0; int64_t result = 1; diff --git a/k2/csrc/online_dense_intersector.h b/k2/csrc/online_dense_intersector.h new file mode 100644 index 000000000..44583c129 --- /dev/null +++ b/k2/csrc/online_dense_intersector.h @@ -0,0 +1,91 @@ +/** + * Copyright (c) 2021 Xiaomi Corporation (authors: Wei Kang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_CSRC_ONLINE_DENSE_INTERSECTOR_H_ +#define K2_CSRC_ONLINE_DENSE_INTERSECTOR_H_ + +#include "k2/csrc/fsa.h" + +namespace k2 { +class MultiGraphDenseIntersectPruned; +/** + Pruned intersection (a.k.a. composition) that corresponds to decoding for + speech recognition-type tasks for online fashion. + + @param [in] a_fsas The decoding graphs, one per sequence. E.g. might + just be a linear sequence of phones, or might be + something more complicated. Must have either the + same Dim0() as b_fsas, or Dim0()==1 in which + case the graph is shared. + @param [in] num_seqs The number of sequences to do intersection at a + time, i.e. batch size. The input DenseFsaVec in + `Intersect` function MUST have `Dim0()` equals to + this. + @param [in] search_beam "Default" search/decoding beam. The actual + beam is dynamic and also depends on max_active and + min_active. + @param [in] output_beam Beam for pruning the output FSA, will + typically be smaller than search_beam. + @param [in] min_active Minimum number of FSA states that are allowed to + be active on any given frame for any given + intersection/composition task. This is advisory, + in that it will try not to have fewer than this + number active. + @param [in] max_active Maximum number of FSA states that are allowed to + be active on any given frame for any given + intersection/composition task. This is advisory, + in that it will try not to exceed that but may not + always succeed. This determines the hash size. +*/ +class OnlineDenseIntersecter { + public: + OnlineDenseIntersecter(FsaVec &a_fsas, int32_t num_seqs, float search_beam, + float output_beam, int32_t min_states, + int32_t max_states); + + /* Does intersection/composition for current chunk of nnet_output(given + by a DenseFsaVec), but doesn't produce any output; the output is + provided when you call FormatOutput(). + + @param [in] b_fsas The neural-net output, with each frame containing + the log-likes of each phone. + @param [in] is_final Whether this is the final chunk of the nnet_output, + After calling this function with is_final is true, + means decoding finished. + */ + void Intersect(DenseFsaVec &b_fsas, bool is_final); + + /* Format partial/final result of the intersection. + + @param [out] out The FsaVec to contain the output lattice of the + intersection result. + @param[out] arc_map_a Will be set to a vector with Dim() equal to + the number of arcs in `out`, whose elements + contain the corresponding arc_idx012 in decoding + graph (i.e. a_fsas). + @param [in] is_final True for final result, false for partial resutl. + */ + void FormatOutput(FsaVec *out, Array1 *arc_map_a, bool is_final); + ~OnlineDenseIntersecter(); + + private: + MultiGraphDenseIntersectPruned *impl_; +}; +}; // namespace k2 + +#endif // K2_CSRC_ONLINE_DENSE_INTERSECTOR_H_ diff --git a/k2/csrc/pytorch_context.cu b/k2/csrc/pytorch_context.cu index 86b88e4bb..14cbdc6d9 100644 --- a/k2/csrc/pytorch_context.cu +++ b/k2/csrc/pytorch_context.cu @@ -22,6 +22,7 @@ #ifdef K2_WITH_CUDA #include "c10/cuda/CUDACachingAllocator.h" #include "c10/cuda/CUDAFunctions.h" +#include "torch/cuda.h" #endif #include "k2/csrc/context.h" @@ -77,7 +78,7 @@ static void InitHasCuda() { else K2_LOG(WARNING) << "CUDA is not available. Return a CPU context."; #else - K2_LOG(WARNING) << "k2 was not compiled with CUDA. Return a CPU context."; + K2_LOG(WARNING) << "k2 was not compiled with CUDA. Return a CPU context."; #endif } diff --git a/k2/csrc/ragged.h b/k2/csrc/ragged.h index 91ae33bfb..d4bde0bfa 100644 --- a/k2/csrc/ragged.h +++ b/k2/csrc/ragged.h @@ -20,6 +20,7 @@ #ifndef K2_CSRC_RAGGED_H_ #define K2_CSRC_RAGGED_H_ +#include #include #include #include @@ -506,6 +507,27 @@ ToType(int64_t, Long) // that Array1's that are the row_ids or row_splits of a Ragged object are // not mutable so they can be re-used. Ragged Clone() const { return Ragged(shape, values.Clone()); } + + // Convert a ragged tensor with 2 axes into a vector of vector. + // + // CAUTION: this->NumAxes() must be 2. + std::vector> ToVecVec() const { + K2_CHECK_EQ(NumAxes(), 2); + if (Context()->GetDeviceType() == kCuda) { + return this->To(GetCpuContext()).ToVecVec(); + } + int32_t dim0 = this->Dim0(); + std::vector> ans(dim0); + const int32_t *row_splits_data = RowSplits(1).Data(); + const T *values_data = values.Data(); + for (int32_t i = 0; i != dim0; ++i) { + int32_t len = row_splits_data[i + 1] - row_splits_data[i]; + ans[i].resize(len); + std::copy(values_data + row_splits_data[i], + values_data + row_splits_data[i + 1], ans[i].begin()); + } + return ans; + } }; // e.g. will produce something like "[ [ 3 4 ] [ 1 ] ]". diff --git a/k2/csrc/ragged_ops.cu b/k2/csrc/ragged_ops.cu index 1a919a02a..071e41459 100644 --- a/k2/csrc/ragged_ops.cu +++ b/k2/csrc/ragged_ops.cu @@ -1204,7 +1204,7 @@ RaggedShape Stack(int32_t axis, int32_t num_srcs, RaggedShape **src, // Contains the pointers for split_map Array1 split_map_ptr; - int32_t **split_map_ptr_data; + int32_t **split_map_ptr_data = nullptr; if (axis == num_axes - 1 && split_map != nullptr) { split_map_ptr = Array1(GetCpuContext(), out_size); diff --git a/k2/csrc/ragged_ops.h b/k2/csrc/ragged_ops.h index 53ba84ec3..7b11bd89c 100644 --- a/k2/csrc/ragged_ops.h +++ b/k2/csrc/ragged_ops.h @@ -1202,7 +1202,7 @@ RaggedShape ComposeRaggedShapes3(const RaggedShape &a, const RaggedShape &b, If cached_tot_sizeN is not -1, it must equal the total size on that axis which will equal the last element of row_splitsN (if provided) and must equal the row_idsN.Dim(), if provided. See - documentation above for RagggedShape2 for details. + documentation above for RaggedShape2 for details. We also require that (supposing both row_splitsN and row_idsN are non-NULL): row_splits1[row_splits1.Dim() - 1] == row_ids1.Dim() diff --git a/k2/csrc/rm_epsilon.cu b/k2/csrc/rm_epsilon.cu index e6cd5fe28..8451806ed 100644 --- a/k2/csrc/rm_epsilon.cu +++ b/k2/csrc/rm_epsilon.cu @@ -1016,9 +1016,9 @@ void ComputeEpsilonClosureOneIter(FsaVec &epsilon_fsa, FsaVec *closure_fsa, } void RemoveEpsilonDevice(FsaOrVec &src_fsa, FsaOrVec *dest_fsa, - Ragged *arc_map_out) { + Ragged *arc_map_out /*=nullptr*/) { NVTX_RANGE(K2_FUNC); - K2_CHECK(dest_fsa != nullptr && arc_map_out != nullptr); + K2_CHECK_NE(dest_fsa, nullptr); K2_CHECK_GE(src_fsa.NumAxes(), 2); K2_CHECK_LE(src_fsa.NumAxes(), 3); if (src_fsa.NumAxes() == 2) { diff --git a/k2/csrc/torch_api.h b/k2/csrc/torch_api.h index 471b10e46..76292549b 100644 --- a/k2/csrc/torch_api.h +++ b/k2/csrc/torch_api.h @@ -20,6 +20,8 @@ #define K2_CSRC_TORCH_API_H_ #include +#include +#include #include "torch/script.h" @@ -70,6 +72,64 @@ torch::Tensor RowIds(RaggedShapePtr shape, int32_t axis); */ torch::Tensor RowSplits(RaggedShapePtr shape, int32_t axis); +class FsaClass; +using FsaClassPtr = std::shared_ptr; + +/* Create a CTC topology. + + Note: + A standard CTC topology is the conventional one, where there + is a mandatory blank between two repeated neighboring symbols. + A non-standard, i.e., modified CTC topology, imposes no such constraint. + + @param max_token The maximum token ID (inclusive). We assume that token IDs + are contiguous (from 1 to `max_token`). 0 represents blank. + @param modified If False, create a standard CTC topology. Otherwise, create + a modified CTC topology. + @param device A torch.device indicating what device the returned Fsa will + be. Default torch::CPU. + @return Return either a standard or a modified CTC topology as an FSA + depending on whether `modified` is false or true. + */ +FsaClassPtr GetCtcTopo(int32_t max_token, bool modified = false, + torch::Device device = torch::kCPU); + +/** + Load a file saved in Python by + + torch.save(fsa.as_dict(), filename, _use_new_zipfile_serialization=True) + + Note: `_use_new_zipfile_serialization` is True by default + + @param filename Path to the filename produced in Python by `torch.save()`. + @param map_location It has the same meaning as the one in `torch.load()`. + The loaded FSA is moved to this device + before returning. + @return Return the FSA contained in the filename. + */ +FsaClassPtr LoadFsa(const std::string &filename, + torch::Device map_location = torch::kCPU); + +/** Run CTC decode. + * @param log_softmax_out A tensor of shape (N, T, C) containing the output + * from a log_softmax layer. + * @param log_softmax_out_lens A tensor of shape (N,) containing the number + * of valid frames in log_softmax_out before + * padding. + * @param decoding_graph Can be either the return value of CtcTopo() or + * an HLG returned from LoadFsa() + * + * @return Return the decoding results of size `N`. ans[i] is the result + * for the i-th utterance. If the decoding_graph is a CtcTopo, + * then the decoding result contains token IDs; if the decoding_graph + * is an HLG, then the decoding result contains word IDs. + * Note: The decoding result does not contain repeats and does not + * contain blanks. + */ +std::vector> Decode(torch::Tensor log_softmax_out, + torch::Tensor log_softmax_out_lens, + FsaClassPtr decoding_graph); + } // namespace k2 #endif // K2_CSRC_TORCH_API_H_ diff --git a/k2/python/csrc/CMakeLists.txt b/k2/python/csrc/CMakeLists.txt index c76975eb9..18341ab69 100644 --- a/k2/python/csrc/CMakeLists.txt +++ b/k2/python/csrc/CMakeLists.txt @@ -32,6 +32,15 @@ endif() pybind11_add_module(_k2 ${k2_srcs}) target_link_libraries(_k2 PRIVATE context) target_link_libraries(_k2 PRIVATE fsa) + +if(APPLE) + # To fix the following error: + # ImportError: /xxx/lib/_k2.cpython-38-x86_64-linux-gnu.so: undefined symbol: THPDtypeType + target_link_libraries(_k2 PRIVATE ${TORCH_DIR}/lib/libtorch_python.dylib) +elseif(UNIX) + target_link_libraries(_k2 PRIVATE ${TORCH_DIR}/lib/libtorch_python.so) +endif() + target_include_directories(_k2 PRIVATE ${CMAKE_SOURCE_DIR}) target_include_directories(_k2 PRIVATE ${CMAKE_BINARY_DIR}) set_target_properties(_k2 PROPERTIES CUDA_SEPARABLE_COMPILATION ON) diff --git a/k2/python/csrc/torch/v2/k2.cu b/k2/python/csrc/torch/v2/k2.cu index b8e05b96e..7ba2a1c21 100644 --- a/k2/python/csrc/torch/v2/k2.cu +++ b/k2/python/csrc/torch/v2/k2.cu @@ -32,7 +32,7 @@ void PybindV2(py::module &m) { PybindRaggedShape(ragged); - m.attr("RaggedShape") = ragged.attr("RaggedShape"); // TODO: remove it + m.attr("RaggedShape") = ragged.attr("RaggedShape"); PybindRaggedAny(ragged); } diff --git a/k2/python/k2/rnnt_loss.py b/k2/python/k2/rnnt_loss.py index 323f161f3..b9a130d68 100644 --- a/k2/python/k2/rnnt_loss.py +++ b/k2/python/k2/rnnt_loss.py @@ -271,7 +271,6 @@ def rnnt_loss_simple( get if you did `torch.autograd.grad((-loss.sum()), [px, py])`, note, the loss here is the loss with reduction "none". This is useful to implement the pruned version of rnnt loss. - Returns: If return_grad is False, returns a tensor of shape (B,), containing the total RNN-T loss values for each element of the batch if reduction equals diff --git a/k2/torch/CMakeLists.txt b/k2/torch/CMakeLists.txt new file mode 100644 index 000000000..dfdeebc30 --- /dev/null +++ b/k2/torch/CMakeLists.txt @@ -0,0 +1,3 @@ +add_subdirectory(csrc) + +add_subdirectory(bin) diff --git a/k2/torch/README.md b/k2/torch/README.md new file mode 100644 index 000000000..5182a5650 --- /dev/null +++ b/k2/torch/README.md @@ -0,0 +1,10 @@ +## Introduction + +This directory contains code for deployment using PyTorch C++ APIs, +without Python dependencies. + +If CUDA is enabled, you have to use PyTorch >= 1.8.0 to compile it. +(You can see a warning saying this when running `cmake`) + +We suggest that you use +if you want to use k2's C++ APIs for speech recognition. diff --git a/k2/torch/bin/CMakeLists.txt b/k2/torch/bin/CMakeLists.txt new file mode 100644 index 000000000..54399ecdb --- /dev/null +++ b/k2/torch/bin/CMakeLists.txt @@ -0,0 +1,89 @@ +# it is located in k2/csrc/cmake/transform.cmake +include(transform) + +set(bin_dep_libs + ${TORCH_LIBRARIES} + k2_torch + k2_fbank +) + +#---------------------------------------- +# CTC decoding +#---------------------------------------- +set(ctc_decode_srcs ctc_decode.cu) +if(NOT K2_WITH_CUDA) + transform(OUTPUT_VARIABLE ctc_decode_srcs SRCS ${ctc_decode_srcs}) +endif() +add_executable(ctc_decode ${ctc_decode_srcs}) +set_property(TARGET ctc_decode PROPERTY CXX_STANDARD 14) +target_link_libraries(ctc_decode ${bin_dep_libs}) + +#---------------------------------------- +# HLG decoding +#---------------------------------------- +set(hlg_decode_srcs hlg_decode.cu) +if(NOT K2_WITH_CUDA) + transform(OUTPUT_VARIABLE hlg_decode_srcs SRCS ${hlg_decode_srcs}) +endif() +add_executable(hlg_decode ${hlg_decode_srcs}) +set_property(TARGET hlg_decode PROPERTY CXX_STANDARD 14) +target_link_libraries(hlg_decode ${bin_dep_libs}) + +#------------------------------------------- +# HLG decoding + n-gram LM rescoring +#------------------------------------------- +set(ngram_lm_rescore_srcs ngram_lm_rescore.cu) +if(NOT K2_WITH_CUDA) + transform(OUTPUT_VARIABLE ngram_lm_rescore_srcs SRCS ${ngram_lm_rescore_srcs}) +endif() +add_executable(ngram_lm_rescore ${ngram_lm_rescore_srcs}) +set_property(TARGET ngram_lm_rescore PROPERTY CXX_STANDARD 14) +target_link_libraries(ngram_lm_rescore ${bin_dep_libs}) + +#--------------------------------------------------------------- +# HLG decoding + n-gram LM rescoring + attenion rescoring +#--------------------------------------------------------------- +set(attention_rescore_srcs attention_rescore.cu) +if(NOT K2_WITH_CUDA) + transform(OUTPUT_VARIABLE attention_rescore_srcs SRCS ${attention_rescore_srcs}) +endif() +add_executable(attention_rescore ${attention_rescore_srcs}) +set_property(TARGET attention_rescore PROPERTY CXX_STANDARD 14) +target_link_libraries(attention_rescore ${bin_dep_libs}) + + +#------------------------------------------- +# online decoding +#------------------------------------------- +set(online_decode_srcs online_decode.cu) +if(NOT K2_WITH_CUDA) + transform(OUTPUT_VARIABLE online_decode_srcs SRCS ${online_decode_srcs}) +endif() + +add_executable(online_decode ${online_decode_srcs}) +set_property(TARGET online_decode PROPERTY CXX_STANDARD 14) +target_link_libraries(online_decode ${bin_dep_libs}) + +#------------------------------------------- +# rnnt demo +#------------------------------------------- +set(rnnt_demo_srcs rnnt_demo.cu) +if(NOT K2_WITH_CUDA) + transform(OUTPUT_VARIABLE rnnt_demo_srcs SRCS ${rnnt_demo_srcs}) +endif() + +add_executable(rnnt_demo ${rnnt_demo_srcs}) +set_property(TARGET rnnt_demo PROPERTY CXX_STANDARD 14) +target_link_libraries(rnnt_demo ${bin_dep_libs}) + +#------------------------------------------- +# pruned stateless transducer +#------------------------------------------- +set(pruned_stateless_transducer_srcs pruned_stateless_transducer.cu) +if(NOT K2_WITH_CUDA) + transform(OUTPUT_VARIABLE pruned_stateless_transducer_srcs SRCS ${pruned_stateless_transducer_srcs}) +endif() + +add_executable(pruned_stateless_transducer ${pruned_stateless_transducer_srcs}) +set_property(TARGET pruned_stateless_transducer PROPERTY CXX_STANDARD 14) +target_link_libraries(pruned_stateless_transducer ${bin_dep_libs}) diff --git a/k2/torch/bin/attention_rescore.cu b/k2/torch/bin/attention_rescore.cu new file mode 100644 index 000000000..6dbc7dfdd --- /dev/null +++ b/k2/torch/bin/attention_rescore.cu @@ -0,0 +1,372 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "k2/csrc/array_ops.h" +#include "k2/csrc/fsa_algo.h" +#include "k2/csrc/ragged_ops.h" +#include "k2/torch/csrc/decode.h" +#include "k2/torch/csrc/dense_fsa_vec.h" +#include "k2/torch/csrc/deserialization.h" +#include "k2/torch/csrc/features.h" +#include "k2/torch/csrc/fsa_algo.h" +#include "k2/torch/csrc/nbest.h" +#include "k2/torch/csrc/symbol_table.h" +#include "k2/torch/csrc/wave_reader.h" +#include "torch/all.h" +#include "torch/script.h" + +static constexpr const char *kUsageMessage = R"( +This file implements decoding with an HLG decoding graph, using +an n-gram LM and an attention decoder for rescoring. + +Usage: + ./bin/attention_rescore \ + --use_gpu true \ + --nn_model \ + --hlg \ + --g \ + --ngram_lm_scale 1.0 \ + --attention_scale 1.0 \ + --num_paths 100 \ + --nbest_scale 0.5 \ + --word_table \ + --sos_id \ + --eos_id \ + \ + \ + + +To see all possible options, use + ./bin/attention_rescore --help + +Caution: + - Only sound files (*.wav) with single channel are supported. + - It assumes the model is conformer_ctc/transformer.py from icefall. + If you use a different model, you have to change the code + related to `model.forward` in this file. +)"; + +C10_DEFINE_bool(use_gpu, false, "true to use GPU; false to use CPU"); +C10_DEFINE_string(nn_model, "", "Path to the model exported by torch script."); +C10_DEFINE_string(hlg, "", "Path to HLG.pt."); +C10_DEFINE_string(g, "", "Path to an ngram LM, e.g, G_4gram.pt"); +C10_DEFINE_double(ngram_lm_scale, 1.0, "Scale for ngram LM scores"); +C10_DEFINE_double(attention_scale, 1.0, "Scale for attention scores"); +C10_DEFINE_int(num_paths, -1, "Number of paths to sample for rescoring"); +C10_DEFINE_double(nbest_scale, 0.5, + "Scale for lattice.scores by this value before sampling."); +C10_DEFINE_string(word_table, "", "Path to words.txt."); +C10_DEFINE_int(sos_id, -1, "ID of start of sentence symbol."); +C10_DEFINE_int(eos_id, -1, "ID of end of sentence symbol."); + +// Fsa decoding related +C10_DEFINE_double(search_beam, 20, "search_beam in IntersectDensePruned"); +C10_DEFINE_double(output_beam, 8, "output_beam in IntersectDensePruned"); +C10_DEFINE_int(min_activate_states, 30, + "min_activate_states in IntersectDensePruned"); +C10_DEFINE_int(max_activate_states, 10000, + "max_activate_states in IntersectDensePruned"); +// Fbank related +// NOTE: These parameters must match those used in training +C10_DEFINE_int(sample_rate, 16000, "Expected sample rate of wave files"); +C10_DEFINE_double(frame_shift_ms, 10.0, + "Frame shift in ms for computing Fbank"); +C10_DEFINE_double(frame_length_ms, 25.0, + "Frame length in ms for computing Fbank"); +C10_DEFINE_int(num_bins, 80, "Number of triangular bins for computing Fbank"); + +static void CheckArgs() { +#if !defined(K2_WITH_CUDA) + if (FLAGS_use_gpu) { + std::cerr << "k2 was not compiled with CUDA. " + "Please use --use_gpu false"; + exit(EXIT_FAILURE); + } +#endif + + if (FLAGS_nn_model.empty()) { + std::cerr << "Please provide --nn_model\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + if (FLAGS_hlg.empty()) { + std::cerr << "Please provide --hlg\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + if (FLAGS_g.empty()) { + std::cerr << "Please provide --g\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + if (FLAGS_word_table.empty()) { + std::cerr << "Please provide --word_table\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + if (FLAGS_sos_id == -1) { + std::cerr << "Please provide --sos_id\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + if (FLAGS_eos_id == -1) { + std::cerr << "Please provide --eos_id\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + K2_CHECK_GT(FLAGS_num_paths, 0); + K2_CHECK_GT(FLAGS_nbest_scale, 0); +} + +int main(int argc, char *argv[]) { + // see + // https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html + torch::set_num_threads(1); + torch::set_num_interop_threads(1); + torch::NoGradGuard no_grad; + + torch::SetUsageMessage(kUsageMessage); + torch::ParseCommandLineFlags(&argc, &argv); + CheckArgs(); + + torch::Device device(torch::kCPU); + if (FLAGS_use_gpu) { + K2_LOG(INFO) << "Use GPU"; + device = torch::Device(torch::kCUDA, 0); + } + + K2_LOG(INFO) << "Device: " << device; + + int32_t num_waves = argc - 1; + K2_CHECK_GE(num_waves, 1) << "You have to provide at least one wave file"; + std::vector wave_filenames(num_waves); + for (int32_t i = 0; i != num_waves; ++i) { + wave_filenames[i] = argv[i + 1]; + } + + K2_LOG(INFO) << "Load wave files"; + auto wave_data = k2::ReadWave(wave_filenames, FLAGS_sample_rate); + + for (auto &w : wave_data) { + w = w.to(device); + } + + K2_LOG(INFO) << "Build Fbank computer"; + kaldifeat::FbankOptions fbank_opts; + fbank_opts.frame_opts.samp_freq = FLAGS_sample_rate; + fbank_opts.frame_opts.dither = 0; + fbank_opts.frame_opts.frame_shift_ms = FLAGS_frame_shift_ms; + fbank_opts.frame_opts.frame_length_ms = FLAGS_frame_length_ms; + fbank_opts.mel_opts.num_bins = FLAGS_num_bins; + fbank_opts.device = device; + + kaldifeat::Fbank fbank(fbank_opts); + + K2_LOG(INFO) << "Compute features"; + std::vector num_frames; + auto features_vec = k2::ComputeFeatures(fbank, wave_data, &num_frames); + + // Note: math.log(1e-10) is -23.025850929940457 + auto features = torch::nn::utils::rnn::pad_sequence(features_vec, true, + -23.025850929940457f); + + K2_LOG(INFO) << "Load neural network model"; + torch::jit::script::Module module = torch::jit::load(FLAGS_nn_model); + module.eval(); + module.to(device); + + int32_t subsampling_factor = module.attr("subsampling_factor").toInt(); + torch::Dict sup; + sup.insert("sequence_idx", torch::arange(num_waves, torch::kInt)); + sup.insert("start_frame", torch::zeros({num_waves}, torch::kInt)); + sup.insert("num_frames", + torch::from_blob(num_frames.data(), {num_waves}, torch::kLong) + .to(torch::kInt)); + + torch::IValue supervisions(sup); + + K2_LOG(INFO) << "Compute nnet_output"; + // the output for module.forward() is a tuple of 3 tensors + // See the definition of the model in conformer_ctc/transformer.py + // from icefall. + // If you use a model that has a different signature for `forward`, + // you can change the following line. + auto outputs = module.run_method("forward", features, supervisions).toTuple(); + assert(outputs->elements().size() == 3u); + + auto nnet_output = outputs->elements()[0].toTensor(); // shape (N, T, C) + auto memory = outputs->elements()[1].toTensor(); // shape (T, N, C) + auto memory_key_padding_mask = + outputs->elements()[2].toTensor(); // shape (N, T) + + torch::Tensor supervision_segments = + k2::GetSupervisionSegments(supervisions, subsampling_factor); + + K2_LOG(INFO) << "Load " << FLAGS_hlg; + k2::FsaClass decoding_graph = k2::LoadFsa(FLAGS_hlg, device); + K2_CHECK(decoding_graph.HasTensorAttr("aux_labels") || + decoding_graph.HasRaggedTensorAttr("aux_labels")); + // Add `lm_scores` so that we can separate acoustic scores and lm scores + // later in the rescoring stage. + decoding_graph.SetTensorAttr("lm_scores", decoding_graph.Scores().clone()); + + K2_LOG(INFO) << "Decoding"; + k2::FsaClass lattice = k2::GetLattice( + nnet_output, decoding_graph, supervision_segments, FLAGS_search_beam, + FLAGS_output_beam, FLAGS_min_activate_states, FLAGS_max_activate_states, + subsampling_factor); + + K2_LOG(INFO) << "Load n-gram LM: " << FLAGS_g; + k2::FsaClass G = k2::LoadFsa(FLAGS_g, device); + G.fsa = k2::FsaToFsaVec(G.fsa); + + K2_CHECK_EQ(G.NumAttrs(), 0) << "G is expected to be an acceptor."; + k2::AddEpsilonSelfLoops(G.fsa, &G.fsa); + k2::ArcSort(&G.fsa); + G.SetTensorAttr("lm_scores", G.Scores().clone()); + + K2_LOG(INFO) << "Rescore with an n-gram LM"; + WholeLatticeRescoring(G, /*ngram_lm_scale*/ 1, &lattice); + + K2_LOG(INFO) << "Sample " << FLAGS_num_paths << " paths"; + k2::Nbest nbest = + k2::Nbest::FromLattice(lattice, FLAGS_num_paths, FLAGS_nbest_scale); + // nbest.fsa.Scores() are all 0s at this point + + nbest.Intersect(&lattice); + // Caution: lattice is changed inside nbest, we don't need it after + // this line + // + // Now nbest.fsa has its scores set. + // Also, nbest.fsa inherits the attributes from `lattice`. + K2_CHECK(nbest.fsa.HasTensorAttr("lm_scores")); + torch::Tensor am_scores = nbest.ComputeAmScores(); + torch::Tensor ngram_lm_scores = nbest.ComputeLmScores(); + + K2_CHECK(nbest.fsa.HasTensorAttr("tokens")); + + auto &path_to_utt_map_array = nbest.shape.RowIds(1); + torch::Tensor path_to_utt_map = + Array1ToTorch(path_to_utt_map_array).to(torch::kLong); + + // the shape of memory is (T, N, C), so we use axis=1 here + torch::Tensor expanded_memory = memory.index_select(1, path_to_utt_map); + + // the shape of memory_key_padding_mask is (N, T), so we use axis=0 here + torch::Tensor expanded_memory_key_padding_mask = + memory_key_padding_mask.index_select(0, path_to_utt_map); + + k2::RaggedShape tokens_shape = k2::RemoveAxis(nbest.fsa.fsa.shape, 1); + + torch::Tensor tokens_value = nbest.fsa.GetTensorAttr("tokens"); + k2::Ragged tokens{tokens_shape, + k2::Array1FromTorch(tokens_value)}; + tokens = k2::RemoveValuesLeq(tokens, 0); + + std::vector> token_ids = tokens.ToVecVec(); + // convert std::vector> + // to + // torch::List where torch::IValue is torch::Tensor + torch::List token_ids_list(torch::TensorType::get()); + + token_ids_list.reserve(token_ids.size()); + for (const auto tids : token_ids) { + torch::Tensor tids_tensor = torch::tensor(tids); + token_ids_list.emplace_back(tids_tensor); + } + + K2_LOG(INFO) << "Run attention decoder"; + torch::Tensor nll = + module + .run_method("decoder_nll", expanded_memory, + expanded_memory_key_padding_mask, token_ids_list, + FLAGS_sos_id, FLAGS_eos_id) + .toTensor(); + K2_CHECK_EQ(nll.dim(), 2); + K2_CHECK_EQ(nll.size(0), nbest.shape.TotSize(1)); + + K2_LOG(INFO) << "Rescoring"; + + torch::Tensor attention_scores = -1 * nll.sum(1); + + torch::Tensor tot_scores = am_scores + + FLAGS_ngram_lm_scale * ngram_lm_scores + + FLAGS_attention_scale * attention_scores; + k2::Array1 tot_scores_array = k2::Array1FromTorch(tot_scores); + k2::Ragged ragged_tot_scores(nbest.shape, tot_scores_array); + k2::Array1 argmax(ragged_tot_scores.Context(), + ragged_tot_scores.Dim0()); + + k2::ArgMaxPerSublist(ragged_tot_scores, std::numeric_limits::lowest(), + &argmax); + k2::Array1 value_indexes_out; + k2::Fsa best_paths = + k2::Index(nbest.fsa.fsa, 0, argmax, &value_indexes_out); + + lattice = k2::FsaClass(best_paths); + + if (nbest.fsa.HasTensorAttr("aux_labels")) { + torch::Tensor in_aux_labels_tensor = nbest.fsa.GetTensorAttr("aux_labels"); + + k2::Array1 in_aux_labels = + k2::Array1FromTorch(in_aux_labels_tensor); + + k2::Array1 out_aux_labels = + k2::Index(in_aux_labels, value_indexes_out, + false, // allow_minus_one + 0); // default_value + + lattice.SetTensorAttr("aux_labels", k2::Array1ToTorch(out_aux_labels)); + } else { + K2_CHECK(nbest.fsa.HasRaggedTensorAttr("aux_labels")); + k2::Ragged in_aux_labels = + nbest.fsa.GetRaggedTensorAttr("aux_labels"); + + k2::Ragged out_aux_labels = + k2::Index(in_aux_labels, 0, value_indexes_out); + + lattice.SetRaggedTensorAttr("aux_labels", out_aux_labels); + } + + auto ragged_aux_labels = k2::GetTexts(lattice); + auto aux_labels_vec = ragged_aux_labels.ToVecVec(); + + std::vector texts; + k2::SymbolTable symbol_table(FLAGS_word_table); + for (const auto &ids : aux_labels_vec) { + std::string text; + std::string sep = ""; + for (auto id : ids) { + text.append(sep); + text.append(symbol_table[id]); + sep = " "; + } + texts.emplace_back(std::move(text)); + } + + std::ostringstream os; + os << "\nDecoding result:\n\n"; + for (int32_t i = 0; i != num_waves; ++i) { + os << wave_filenames[i] << "\n"; + os << texts[i]; + os << "\n\n"; + } + K2_LOG(INFO) << os.str(); + + return 0; +} diff --git a/k2/torch/bin/ctc_decode.cu b/k2/torch/bin/ctc_decode.cu new file mode 100644 index 000000000..a0110be6a --- /dev/null +++ b/k2/torch/bin/ctc_decode.cu @@ -0,0 +1,208 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "k2/torch/csrc/decode.h" +#include "k2/torch/csrc/dense_fsa_vec.h" +#include "k2/torch/csrc/deserialization.h" +#include "k2/torch/csrc/features.h" +#include "k2/torch/csrc/fsa_algo.h" +#include "k2/torch/csrc/symbol_table.h" +#include "k2/torch/csrc/wave_reader.h" +#include "torch/all.h" +#include "torch/script.h" + +static constexpr const char *kUsageMessage = R"( +This file implements decoding with a CTC topology, without any +kinds of LM or lexicons. + +Usage: + ./bin/ctc_decode \ + --use_gpu true \ + --nn_model \ + --tokens \ + \ + \ + + +To see all possible options, use + ./bin/ctc_decode --help + +Caution: + - Only sound files (*.wav) with single channel are supported. + - It assumes the model is conformer_ctc/transformer.py from icefall. + If you use a different model, you have to change the code + related to `model.forward` in this file. +)"; + +C10_DEFINE_bool(use_gpu, false, "true to use GPU; false to use CPU"); +C10_DEFINE_string(nn_model, "", "Path to the model exported by torch script."); +C10_DEFINE_string(tokens, "", "Path to the tokens.txt"); + +// Fsa decoding related +C10_DEFINE_double(search_beam, 20, "search_beam in IntersectDensePruned"); +C10_DEFINE_double(output_beam, 8, "output_beam in IntersectDensePruned"); +C10_DEFINE_int(min_activate_states, 30, + "min_activate_states in IntersectDensePruned"); +C10_DEFINE_int(max_activate_states, 10000, + "max_activate_states in IntersectDensePruned"); +// Fbank related +// NOTE: These parameters must match those used in training +C10_DEFINE_int(sample_rate, 16000, "Expected sample rate of wave files"); +C10_DEFINE_double(frame_shift_ms, 10.0, + "Frame shift in ms for computing Fbank"); +C10_DEFINE_double(frame_length_ms, 25.0, + "Frame length in ms for computing Fbank"); +C10_DEFINE_int(num_bins, 80, "Number of triangular bins for computing Fbank"); + +static void CheckArgs() { +#if !defined(K2_WITH_CUDA) + if (FLAGS_use_gpu) { + std::cerr << "k2 was not compiled with CUDA. " + "Please use --use_gpu false"; + exit(EXIT_FAILURE); + } +#endif + + if (FLAGS_nn_model.empty()) { + std::cerr << "Please provide --nn_model\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + if (FLAGS_tokens.empty()) { + std::cerr << "Please provide --tokens\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } +} + +int main(int argc, char *argv[]) { + // see + // https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html + torch::set_num_threads(1); + torch::set_num_interop_threads(1); + torch::NoGradGuard no_grad; + + torch::SetUsageMessage(kUsageMessage); + torch::ParseCommandLineFlags(&argc, &argv); + CheckArgs(); + + torch::Device device(torch::kCPU); + if (FLAGS_use_gpu) { + K2_LOG(INFO) << "Use GPU"; + device = torch::Device(torch::kCUDA, 0); + } + + K2_LOG(INFO) << "Device: " << device; + + int32_t num_waves = argc - 1; + K2_CHECK_GE(num_waves, 1) << "You have to provide at least one wave file"; + std::vector wave_filenames(num_waves); + for (int32_t i = 0; i != num_waves; ++i) { + wave_filenames[i] = argv[i + 1]; + } + + K2_LOG(INFO) << "Load wave files"; + auto wave_data = k2::ReadWave(wave_filenames, FLAGS_sample_rate); + + for (auto &w : wave_data) { + w = w.to(device); + } + + K2_LOG(INFO) << "Build Fbank computer"; + kaldifeat::FbankOptions fbank_opts; + fbank_opts.frame_opts.samp_freq = FLAGS_sample_rate; + fbank_opts.frame_opts.dither = 0; + fbank_opts.frame_opts.frame_shift_ms = FLAGS_frame_shift_ms; + fbank_opts.frame_opts.frame_length_ms = FLAGS_frame_length_ms; + fbank_opts.mel_opts.num_bins = FLAGS_num_bins; + fbank_opts.device = device; + + kaldifeat::Fbank fbank(fbank_opts); + + K2_LOG(INFO) << "Compute features"; + std::vector num_frames; + auto features_vec = k2::ComputeFeatures(fbank, wave_data, &num_frames); + + // Note: math.log(1e-10) is -23.025850929940457 + auto features = torch::nn::utils::rnn::pad_sequence(features_vec, true, + -23.025850929940457f); + + K2_LOG(INFO) << "Load neural network model"; + torch::jit::script::Module module = torch::jit::load(FLAGS_nn_model); + module.eval(); + module.to(device); + + int32_t subsampling_factor = module.attr("subsampling_factor").toInt(); + torch::Dict sup; + sup.insert("sequence_idx", torch::arange(num_waves, torch::kInt)); + sup.insert("start_frame", torch::zeros({num_waves}, torch::kInt)); + sup.insert("num_frames", + torch::from_blob(num_frames.data(), {num_waves}, torch::kLong) + .to(torch::kInt)); + + torch::IValue supervisions(sup); + + K2_LOG(INFO) << "Compute nnet_output"; + // the output for module.forward() is a tuple of 3 tensors + // See the definition of the model in conformer_ctc/transformer.py + // from icefall. + // If you use a model that has a different signature for `forward`, + // you can change the following line. + auto outputs = module.run_method("forward", features, supervisions).toTuple(); + assert(outputs->elements().size() == 3u); + + auto nnet_output = outputs->elements()[0].toTensor(); + + torch::Tensor supervision_segments = + k2::GetSupervisionSegments(supervisions, subsampling_factor); + + K2_LOG(INFO) << "Build CTC topo"; + auto decoding_graph = k2::CtcTopo(nnet_output.size(2) - 1, false, device); + + K2_LOG(INFO) << "Decoding"; + k2::FsaClass lattice = k2::GetLattice( + nnet_output, decoding_graph, supervision_segments, FLAGS_search_beam, + FLAGS_output_beam, FLAGS_min_activate_states, FLAGS_max_activate_states, + subsampling_factor); + + lattice = k2::ShortestPath(lattice); + + auto ragged_aux_labels = k2::GetTexts(lattice); + auto aux_labels_vec = ragged_aux_labels.ToVecVec(); + + k2::SymbolTable symbol_table(FLAGS_tokens); + + std::vector texts; + for (const auto &ids : aux_labels_vec) { + std::string text; + for (auto id : ids) { + text.append(symbol_table[id]); + } + texts.emplace_back(std::move(text)); + } + + std::ostringstream os; + os << "\nDecoding result:\n\n"; + for (int32_t i = 0; i != num_waves; ++i) { + os << wave_filenames[i] << "\n"; + os << texts[i]; + os << "\n\n"; + } + K2_LOG(INFO) << os.str(); + + return 0; +} diff --git a/k2/torch/bin/hlg_decode.cu b/k2/torch/bin/hlg_decode.cu new file mode 100644 index 000000000..7551881c3 --- /dev/null +++ b/k2/torch/bin/hlg_decode.cu @@ -0,0 +1,219 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "k2/torch/csrc/decode.h" +#include "k2/torch/csrc/dense_fsa_vec.h" +#include "k2/torch/csrc/deserialization.h" +#include "k2/torch/csrc/features.h" +#include "k2/torch/csrc/fsa_algo.h" +#include "k2/torch/csrc/symbol_table.h" +#include "k2/torch/csrc/wave_reader.h" +#include "torch/all.h" +#include "torch/script.h" + +static constexpr const char *kUsageMessage = R"( +This file implements decoding with an HLG decoding graph. + +Usage: + ./bin/hlg_decode \ + --use_gpu true \ + --nn_model \ + --hlg \ + --word_table \ + \ + \ + + +To see all possible options, use + ./bin/hlg_decode --help + +Caution: + - Only sound files (*.wav) with single channel are supported. + - It assumes the model is conformer_ctc/transformer.py from icefall. + If you use a different model, you have to change the code + related to `model.forward` in this file. +)"; + +C10_DEFINE_bool(use_gpu, false, "true to use GPU; false to use CPU"); +C10_DEFINE_string(nn_model, "", "Path to the model exported by torch script."); +C10_DEFINE_string(hlg, "", "Path to HLG.pt."); +C10_DEFINE_string(word_table, "", "Path to words.txt."); + +// Fsa decoding related +C10_DEFINE_double(search_beam, 20, "search_beam in IntersectDensePruned"); +C10_DEFINE_double(output_beam, 8, "output_beam in IntersectDensePruned"); +C10_DEFINE_int(min_activate_states, 30, + "min_activate_states in IntersectDensePruned"); +C10_DEFINE_int(max_activate_states, 10000, + "max_activate_states in IntersectDensePruned"); +// Fbank related +// NOTE: These parameters must match those used in training +C10_DEFINE_int(sample_rate, 16000, "Expected sample rate of wave files"); +C10_DEFINE_double(frame_shift_ms, 10.0, + "Frame shift in ms for computing Fbank"); +C10_DEFINE_double(frame_length_ms, 25.0, + "Frame length in ms for computing Fbank"); +C10_DEFINE_int(num_bins, 80, "Number of triangular bins for computing Fbank"); + +static void CheckArgs() { +#if !defined(K2_WITH_CUDA) + if (FLAGS_use_gpu) { + std::cerr << "k2 was not compiled with CUDA. " + "Please use --use_gpu false"; + exit(EXIT_FAILURE); + } +#endif + + if (FLAGS_nn_model.empty()) { + std::cerr << "Please provide --nn_model\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + if (FLAGS_hlg.empty()) { + std::cerr << "Please provide --hlg\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + if (FLAGS_word_table.empty()) { + std::cerr << "Please provide --word_table\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } +} + +int main(int argc, char *argv[]) { + // see + // https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html + torch::set_num_threads(1); + torch::set_num_interop_threads(1); + torch::NoGradGuard no_grad; + + torch::SetUsageMessage(kUsageMessage); + torch::ParseCommandLineFlags(&argc, &argv); + CheckArgs(); + + torch::Device device(torch::kCPU); + if (FLAGS_use_gpu) { + K2_LOG(INFO) << "Use GPU"; + device = torch::Device(torch::kCUDA, 0); + } + + K2_LOG(INFO) << "Device: " << device; + + int32_t num_waves = argc - 1; + K2_CHECK_GE(num_waves, 1) << "You have to provide at least one wave file"; + std::vector wave_filenames(num_waves); + for (int32_t i = 0; i != num_waves; ++i) { + wave_filenames[i] = argv[i + 1]; + } + + K2_LOG(INFO) << "Load wave files"; + auto wave_data = k2::ReadWave(wave_filenames, FLAGS_sample_rate); + + for (auto &w : wave_data) { + w = w.to(device); + } + + K2_LOG(INFO) << "Build Fbank computer"; + kaldifeat::FbankOptions fbank_opts; + fbank_opts.frame_opts.samp_freq = FLAGS_sample_rate; + fbank_opts.frame_opts.dither = 0; + fbank_opts.frame_opts.frame_shift_ms = FLAGS_frame_shift_ms; + fbank_opts.frame_opts.frame_length_ms = FLAGS_frame_length_ms; + fbank_opts.mel_opts.num_bins = FLAGS_num_bins; + fbank_opts.device = device; + + kaldifeat::Fbank fbank(fbank_opts); + + K2_LOG(INFO) << "Compute features"; + std::vector num_frames; + auto features_vec = k2::ComputeFeatures(fbank, wave_data, &num_frames); + + // Note: math.log(1e-10) is -23.025850929940457 + auto features = torch::nn::utils::rnn::pad_sequence(features_vec, true, + -23.025850929940457f); + + K2_LOG(INFO) << "Load neural network model"; + torch::jit::script::Module module = torch::jit::load(FLAGS_nn_model); + module.eval(); + module.to(device); + + int32_t subsampling_factor = module.attr("subsampling_factor").toInt(); + torch::Dict sup; + sup.insert("sequence_idx", torch::arange(num_waves, torch::kInt)); + sup.insert("start_frame", torch::zeros({num_waves}, torch::kInt)); + sup.insert("num_frames", + torch::from_blob(num_frames.data(), {num_waves}, torch::kLong) + .to(torch::kInt)); + + torch::IValue supervisions(sup); + + K2_LOG(INFO) << "Compute nnet_output"; + // the output for module.forward() is a tuple of 3 tensors + // See the definition of the model in conformer_ctc/transformer.py + // from icefall. + // If you use a model that has a different signature for `forward`, + // you can change the following line. + auto outputs = module.run_method("forward", features, supervisions).toTuple(); + assert(outputs->elements().size() == 3u); + + auto nnet_output = outputs->elements()[0].toTensor(); + auto memory = outputs->elements()[1].toTensor(); + + torch::Tensor supervision_segments = + k2::GetSupervisionSegments(supervisions, subsampling_factor); + + K2_LOG(INFO) << "Load " << FLAGS_hlg; + k2::FsaClass decoding_graph = k2::LoadFsa(FLAGS_hlg, device); + K2_CHECK(decoding_graph.HasTensorAttr("aux_labels") || + decoding_graph.HasRaggedTensorAttr("aux_labels")); + + K2_LOG(INFO) << "Decoding"; + k2::FsaClass lattice = k2::GetLattice( + nnet_output, decoding_graph, supervision_segments, FLAGS_search_beam, + FLAGS_output_beam, FLAGS_min_activate_states, FLAGS_max_activate_states, + subsampling_factor); + + lattice = k2::ShortestPath(lattice); + + auto ragged_aux_labels = k2::GetTexts(lattice); + auto aux_labels_vec = ragged_aux_labels.ToVecVec(); + + std::vector texts; + k2::SymbolTable symbol_table(FLAGS_word_table); + for (const auto &ids : aux_labels_vec) { + std::string text; + std::string sep = ""; + for (auto id : ids) { + text.append(sep); + text.append(symbol_table[id]); + sep = " "; + } + texts.emplace_back(std::move(text)); + } + + std::ostringstream os; + os << "\nDecoding result:\n\n"; + for (int32_t i = 0; i != num_waves; ++i) { + os << wave_filenames[i] << "\n"; + os << texts[i]; + os << "\n\n"; + } + K2_LOG(INFO) << os.str(); + + return 0; +} diff --git a/k2/torch/bin/ngram_lm_rescore.cu b/k2/torch/bin/ngram_lm_rescore.cu new file mode 100644 index 000000000..3ea8c31d6 --- /dev/null +++ b/k2/torch/bin/ngram_lm_rescore.cu @@ -0,0 +1,245 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "k2/csrc/fsa_algo.h" +#include "k2/torch/csrc/decode.h" +#include "k2/torch/csrc/dense_fsa_vec.h" +#include "k2/torch/csrc/deserialization.h" +#include "k2/torch/csrc/features.h" +#include "k2/torch/csrc/fsa_algo.h" +#include "k2/torch/csrc/symbol_table.h" +#include "k2/torch/csrc/wave_reader.h" +#include "torch/all.h" +#include "torch/script.h" + +static constexpr const char *kUsageMessage = R"( +This file implements decoding with an HLG decoding graph, using +an n-gram LM for rescoring. + +Usage: + ./bin/ngram_lm_rescore \ + --use_gpu true \ + --nn_model \ + --hlg \ + --g \ + --ngram_lm_scale 1.0 \ + --word_table \ + \ + \ + + +To see all possible options, use + ./bin/ngram_lm_rescore --help + +Caution: + - Only sound files (*.wav) with single channel are supported. + - It assumes the model is conformer_ctc/transformer.py from icefall. + If you use a different model, you have to change the code + related to `model.forward` in this file. +)"; + +C10_DEFINE_bool(use_gpu, false, "true to use GPU; false to use CPU"); +C10_DEFINE_string(nn_model, "", "Path to the model exported by torch script."); +C10_DEFINE_string(hlg, "", "Path to HLG.pt."); +C10_DEFINE_string(g, "", "Path to an ngram LM, e.g, G_4gram.pt"); +C10_DEFINE_double(ngram_lm_scale, 1.0, "Scale for ngram LM scores"); +C10_DEFINE_string(word_table, "", "Path to words.txt."); + +// Fsa decoding related +C10_DEFINE_double(search_beam, 20, "search_beam in IntersectDensePruned"); +C10_DEFINE_double(output_beam, 8, "output_beam in IntersectDensePruned"); +C10_DEFINE_int(min_activate_states, 30, + "min_activate_states in IntersectDensePruned"); +C10_DEFINE_int(max_activate_states, 10000, + "max_activate_states in IntersectDensePruned"); +// Fbank related +// NOTE: These parameters must match those used in training +C10_DEFINE_int(sample_rate, 16000, "Expected sample rate of wave files"); +C10_DEFINE_double(frame_shift_ms, 10.0, + "Frame shift in ms for computing Fbank"); +C10_DEFINE_double(frame_length_ms, 25.0, + "Frame length in ms for computing Fbank"); +C10_DEFINE_int(num_bins, 80, "Number of triangular bins for computing Fbank"); + +static void CheckArgs() { +#if !defined(K2_WITH_CUDA) + if (FLAGS_use_gpu) { + std::cerr << "k2 was not compiled with CUDA. " + "Please use --use_gpu false"; + exit(EXIT_FAILURE); + } +#endif + + if (FLAGS_nn_model.empty()) { + std::cerr << "Please provide --nn_model\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + if (FLAGS_hlg.empty()) { + std::cerr << "Please provide --hlg\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + if (FLAGS_g.empty()) { + std::cerr << "Please provide --g\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } + + if (FLAGS_word_table.empty()) { + std::cerr << "Please provide --word_table\n" << torch::UsageMessage(); + exit(EXIT_FAILURE); + } +} + +int main(int argc, char *argv[]) { + // see + // https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html + torch::set_num_threads(1); + torch::set_num_interop_threads(1); + torch::NoGradGuard no_grad; + + torch::SetUsageMessage(kUsageMessage); + torch::ParseCommandLineFlags(&argc, &argv); + CheckArgs(); + + torch::Device device(torch::kCPU); + if (FLAGS_use_gpu) { + K2_LOG(INFO) << "Use GPU"; + device = torch::Device(torch::kCUDA, 0); + } + + K2_LOG(INFO) << "Device: " << device; + + int32_t num_waves = argc - 1; + K2_CHECK_GE(num_waves, 1) << "You have to provide at least one wave file"; + std::vector wave_filenames(num_waves); + for (int32_t i = 0; i != num_waves; ++i) { + wave_filenames[i] = argv[i + 1]; + } + + K2_LOG(INFO) << "Load wave files"; + auto wave_data = k2::ReadWave(wave_filenames, FLAGS_sample_rate); + + for (auto &w : wave_data) { + w = w.to(device); + } + + K2_LOG(INFO) << "Build Fbank computer"; + kaldifeat::FbankOptions fbank_opts; + fbank_opts.frame_opts.samp_freq = FLAGS_sample_rate; + fbank_opts.frame_opts.dither = 0; + fbank_opts.frame_opts.frame_shift_ms = FLAGS_frame_shift_ms; + fbank_opts.frame_opts.frame_length_ms = FLAGS_frame_length_ms; + fbank_opts.mel_opts.num_bins = FLAGS_num_bins; + fbank_opts.device = device; + + kaldifeat::Fbank fbank(fbank_opts); + + K2_LOG(INFO) << "Compute features"; + std::vector num_frames; + auto features_vec = k2::ComputeFeatures(fbank, wave_data, &num_frames); + + // Note: math.log(1e-10) is -23.025850929940457 + auto features = torch::nn::utils::rnn::pad_sequence(features_vec, true, + -23.025850929940457f); + + K2_LOG(INFO) << "Load neural network model"; + torch::jit::script::Module module = torch::jit::load(FLAGS_nn_model); + module.eval(); + module.to(device); + + int32_t subsampling_factor = module.attr("subsampling_factor").toInt(); + torch::Dict sup; + sup.insert("sequence_idx", torch::arange(num_waves, torch::kInt)); + sup.insert("start_frame", torch::zeros({num_waves}, torch::kInt)); + sup.insert("num_frames", + torch::from_blob(num_frames.data(), {num_waves}, torch::kLong) + .to(torch::kInt)); + + torch::IValue supervisions(sup); + + K2_LOG(INFO) << "Compute nnet_output"; + // the output for module.forward() is a tuple of 3 tensors + // See the definition of the model in conformer_ctc/transformer.py + // from icefall. + // If you use a model that has a different signature for `forward`, + // you can change the following line. + auto outputs = module.run_method("forward", features, supervisions).toTuple(); + assert(outputs->elements().size() == 3u); + + auto nnet_output = outputs->elements()[0].toTensor(); + auto memory = outputs->elements()[1].toTensor(); + + torch::Tensor supervision_segments = + k2::GetSupervisionSegments(supervisions, subsampling_factor); + + K2_LOG(INFO) << "Load " << FLAGS_hlg; + k2::FsaClass decoding_graph = k2::LoadFsa(FLAGS_hlg, device); + K2_CHECK(decoding_graph.HasTensorAttr("aux_labels") || + decoding_graph.HasRaggedTensorAttr("aux_labels")); + // Add `lm_scores` so that we can separate acoustic scores and lm scores + // later in the rescoring stage. + decoding_graph.SetTensorAttr("lm_scores", decoding_graph.Scores().clone()); + + K2_LOG(INFO) << "Decoding"; + k2::FsaClass lattice = k2::GetLattice( + nnet_output, decoding_graph, supervision_segments, FLAGS_search_beam, + FLAGS_output_beam, FLAGS_min_activate_states, FLAGS_max_activate_states, + subsampling_factor); + + K2_LOG(INFO) << "Load n-gram LM: " << FLAGS_g; + k2::FsaClass G = k2::LoadFsa(FLAGS_g, device); + G.fsa = k2::FsaToFsaVec(G.fsa); + + K2_CHECK_EQ(G.NumAttrs(), 0) << "G is expected to be an acceptor."; + k2::AddEpsilonSelfLoops(G.fsa, &G.fsa); + k2::ArcSort(&G.fsa); + G.SetTensorAttr("lm_scores", G.Scores().clone()); + + K2_LOG(INFO) << "Rescore with an n-gram LM"; + WholeLatticeRescoring(G, FLAGS_ngram_lm_scale, &lattice); + + lattice = k2::ShortestPath(lattice); + + auto ragged_aux_labels = k2::GetTexts(lattice); + auto aux_labels_vec = ragged_aux_labels.ToVecVec(); + + std::vector texts; + k2::SymbolTable symbol_table(FLAGS_word_table); + for (const auto &ids : aux_labels_vec) { + std::string text; + std::string sep = ""; + for (auto id : ids) { + text.append(sep); + text.append(symbol_table[id]); + sep = " "; + } + texts.emplace_back(std::move(text)); + } + + std::ostringstream os; + os << "\nDecoding result:\n\n"; + for (int32_t i = 0; i != num_waves; ++i) { + os << wave_filenames[i] << "\n"; + os << texts[i]; + os << "\n\n"; + } + K2_LOG(INFO) << os.str(); + + return 0; +} diff --git a/k2/torch/bin/online_decode.cu b/k2/torch/bin/online_decode.cu new file mode 100644 index 000000000..d10629747 --- /dev/null +++ b/k2/torch/bin/online_decode.cu @@ -0,0 +1,322 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +#include "k2/csrc/online_dense_intersector.h" +#include "k2/torch/csrc/decode.h" +#include "k2/torch/csrc/dense_fsa_vec.h" +#include "k2/torch/csrc/deserialization.h" +#include "k2/torch/csrc/features.h" +#include "k2/torch/csrc/fsa_algo.h" +#include "k2/torch/csrc/symbol_table.h" +#include "k2/torch/csrc/utils.h" +#include "k2/torch/csrc/wave_reader.h" +#include "kaldifeat/csrc/feature-fbank.h" +#include "torch/all.h" +#include "torch/script.h" +#include "torch/utils.h" + +C10_DEFINE_bool(use_gpu, false, "True to use GPU. False to use CPU"); +C10_DEFINE_string(jit_pt, "", "Path to exported jit file."); +C10_DEFINE_string(tokens, "", + "Path to tokens.txt. Needed if --use_ctc_decoding is true"); +C10_DEFINE_bool(use_ctc_decoding, true, "True to use CTC decoding"); +C10_DEFINE_string(hlg, "", + "Path to HLG.pt. Needed if --use_ctc_decoding is false"); +C10_DEFINE_string(word_table, "", + "Path to words.txt. Needed if --use_ctc_decoding is false"); +// Fsa decoding related +C10_DEFINE_double(search_beam, 20, "search_beam in IntersectDensePruned"); +C10_DEFINE_double(output_beam, 8, "output_beam in IntersectDensePruned"); +C10_DEFINE_int(min_activate_states, 30, + "min_activate_states in IntersectDensePruned"); +C10_DEFINE_int(max_activate_states, 10000, + "max_activate_states in IntersectDensePruned"); +// fbank related +C10_DEFINE_int(sample_rate, 16000, "Expected sample rate of wave files"); +C10_DEFINE_double(frame_shift_ms, 10.0, + "Frame shift in ms for computing Fbank"); +C10_DEFINE_double(frame_length_ms, 25.0, + "Frame length in ms for computing Fbank"); +C10_DEFINE_int(num_bins, 80, "Number of triangular bins for computing Fbank"); + +static void CheckArgs() { +#if !defined(K2_WITH_CUDA) + if (FLAGS_use_gpu) { + std::cerr << "k2 was not compiled with CUDA" + << "\n"; + std::cerr << "Please use --use_gpu 0" + << "\n"; + exit(EXIT_FAILURE); + } +#endif + + if (FLAGS_jit_pt.empty()) { + std::cerr << "Please provide --jit_pt" + << "\n"; + std::cerr << torch::UsageMessage() << "\n"; + exit(EXIT_FAILURE); + } + + if (FLAGS_use_ctc_decoding && FLAGS_tokens.empty()) { + std::cout << "Please provide --tokens" + << "\n"; + std::cout << torch::UsageMessage() << "\n"; + exit(EXIT_FAILURE); + } + + if (FLAGS_use_ctc_decoding == false && FLAGS_hlg.empty()) { + std::cerr << "Please provide --hlg" + << "\n"; + std::cerr << torch::UsageMessage() << "\n"; + exit(EXIT_FAILURE); + } + + if (FLAGS_use_ctc_decoding == false && FLAGS_word_table.empty()) { + std::cerr << "Please provide --word_table" + << "\n"; + std::cerr << torch::UsageMessage() << "\n"; + exit(EXIT_FAILURE); + } +} + +int main(int argc, char *argv[]) { + // see + // https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html + torch::set_num_threads(1); + torch::set_num_interop_threads(1); + + std::string usage = R"( + (1) CTC decoding + ./bin/online_decode \ + --use_ctc_decoding true \ + --jit_pt \ + --tokens \ + /path/to/foo.wav \ + /path/to/bar.wav \ + + (2) HLG decoding + ./bin/online_decode \ + --use_ctc_decoding false \ + --jit_pt \ + --hlg \ + --word_table \ + /path/to/foo.wav \ + /path/to/bar.wav \ + + --use_gpu false to use CPU + --use_gpu true to use GPU + )"; + torch::SetUsageMessage(usage); + + torch::ParseCommandLineFlags(&argc, &argv); + CheckArgs(); + + torch::Device device(torch::kCPU); + if (FLAGS_use_gpu) { + device = torch::Device(torch::kCUDA, 0); + } + + K2_LOG(INFO) << "Device: " << device; + + int32_t num_waves = argc - 1; + K2_CHECK_GE(num_waves, 1) << "You have to provide at least one wave file"; + std::vector wave_filenames(num_waves); + for (int32_t i = 0; i != num_waves; ++i) { + wave_filenames[i] = argv[i + 1]; + } + + kaldifeat::FbankOptions fbank_opts; + fbank_opts.frame_opts.samp_freq = FLAGS_sample_rate; + fbank_opts.frame_opts.dither = 0; + fbank_opts.frame_opts.frame_shift_ms = FLAGS_frame_shift_ms; + fbank_opts.frame_opts.frame_length_ms = FLAGS_frame_length_ms; + fbank_opts.mel_opts.num_bins = FLAGS_num_bins; + fbank_opts.device = device; + + kaldifeat::Fbank fbank(fbank_opts); + + K2_LOG(INFO) << "Load wave files"; + auto wave_data = k2::ReadWave(wave_filenames, FLAGS_sample_rate); + + for (auto &w : wave_data) { + w = w.to(device); + } + + K2_LOG(INFO) << "Compute features"; + std::vector num_frames; + auto features_vec = k2::ComputeFeatures(fbank, wave_data, &num_frames); + + // Note: math.log(1e-10) is -23.025850929940457 + auto features = torch::nn::utils::rnn::pad_sequence(features_vec, true, + -23.025850929940457f); + + K2_LOG(INFO) << "Load neural network model"; + torch::jit::script::Module module = torch::jit::load(FLAGS_jit_pt); + module.eval(); + module.to(device); + + int32_t subsampling_factor = module.attr("subsampling_factor").toInt(); + torch::Dict sup; + sup.insert("sequence_idx", torch::arange(num_waves, torch::kInt)); + sup.insert("start_frame", torch::zeros({num_waves}, torch::kInt)); + sup.insert("num_frames", + torch::from_blob(num_frames.data(), {num_waves}, torch::kLong) + .to(torch::kInt)); + + torch::IValue supervisions(sup); + + std::vector inputs; + inputs.emplace_back(std::move(features)); + inputs.emplace_back(supervisions); + + K2_LOG(INFO) << "Compute nnet_output"; + // the output for module.forward() is a tuple of 3 tensors + auto outputs = module.forward(inputs).toTuple(); + assert(outputs->elements().size() == 3u); + + auto nnet_output = outputs->elements()[0].toTensor(); + auto memory = outputs->elements()[1].toTensor(); + + // memory_key_padding_mask is used in attention decoder rescoring + // auto memory_key_padding_mask = outputs->elements()[2].toTensor(); + + k2::FsaClass decoding_graph; + + if (FLAGS_use_ctc_decoding) { + K2_LOG(INFO) << "Build CTC topo"; + decoding_graph = + k2::CtcTopo(nnet_output.size(2) - 1, /*modified*/ false, device); + } else { + K2_LOG(INFO) << "Load HLG.pt"; + decoding_graph = k2::LoadFsa(FLAGS_hlg, device); + K2_CHECK(decoding_graph.HasTensorAttr("aux_labels") || + decoding_graph.HasRaggedTensorAttr("aux_labels")); + } + + K2_LOG(INFO) << "Decoding"; + + auto decoding_fsa = k2::FsaToFsaVec(decoding_graph.fsa); + k2::OnlineDenseIntersecter decoder( + decoding_fsa, num_waves, FLAGS_search_beam, FLAGS_output_beam, + FLAGS_min_activate_states, FLAGS_max_activate_states); + + std::vector texts(num_waves, ""); + + int32_t T = nnet_output.size(1); + int32_t chunk_size = 20; // 20 frames per chunk + int32_t chunk_num = (T / chunk_size) + ((T % chunk_size) ? 1 : 0); + + for (int32_t c = 0; c < chunk_num; ++c) { + int32_t start = c * chunk_size; + int32_t end = (c + 1) * chunk_size >= T ? T : (c + 1) * chunk_size; + std::vector num_frame; + for (auto &frame : num_frames) { + if (frame < chunk_size * subsampling_factor) { + num_frame.push_back(frame); + frame = 0; + } else { + num_frame.push_back(chunk_size * subsampling_factor); + frame -= chunk_size * subsampling_factor; + } + } + torch::Dict sup; + sup.insert("sequence_idx", torch::arange(num_waves, torch::kInt)); + sup.insert("start_frame", torch::zeros({num_waves}, torch::kInt)); + sup.insert("num_frames", + torch::from_blob(num_frame.data(), {num_waves}, torch::kLong) + .to(torch::kInt)); + torch::IValue supervision(sup); + + // cut nnet_output into chunks + using namespace torch::indexing; // NOLINT + auto sub_nnet_output = + nnet_output.index({Slice(), Slice(start, end), Slice()}); + + torch::Tensor supervision_segments = + k2::GetSupervisionSegments(supervision, subsampling_factor); + + k2::DenseFsaVec dense_fsa_vec = k2::CreateDenseFsaVec( + sub_nnet_output, supervision_segments, subsampling_factor - 1); + + bool is_final = c == chunk_num - 1 ? true : false; + decoder.Intersect(dense_fsa_vec, is_final); + + k2::FsaVec fsa; + k2::Array1 graph_arc_map; + decoder.FormatOutput(&fsa, &graph_arc_map, is_final); + + k2::FsaClass lattice(fsa); + lattice.CopyAttrs(decoding_graph, + k2::Array1ToTorch(graph_arc_map)); + + lattice = k2::ShortestPath(lattice); + + auto ragged_aux_labels = k2::GetTexts(lattice); + + auto aux_labels_vec = ragged_aux_labels.ToVecVec(); + + if (FLAGS_use_ctc_decoding) { + k2::SymbolTable symbol_table(FLAGS_tokens); + for (size_t i = 0; i < aux_labels_vec.size(); ++i) { + std::string text; + + for (auto id : aux_labels_vec[i]) { + text.append(symbol_table[id]); + } + + texts[i] = std::move(text); + } + } else { + k2::SymbolTable symbol_table(FLAGS_word_table); + for (size_t i = 0; i < aux_labels_vec.size(); ++i) { + std::string text; + std::string sep = ""; + for (auto id : aux_labels_vec[i]) { + text.append(sep); + text.append(symbol_table[id]); + sep = " "; + } + texts[i] = text; + } + } + std::ostringstream os; + os << "\nPartial result:\n"; + for (int32_t i = 0; i != num_waves; ++i) { + os << wave_filenames[i] << "\n"; + os << texts[i]; + os << "\n\n"; + } + K2_LOG(INFO) << os.str(); + } + + std::ostringstream os; + os << "\nDecoding result:\n"; + for (int32_t i = 0; i != num_waves; ++i) { + os << wave_filenames[i] << "\n"; + os << texts[i]; + os << "\n\n"; + } + K2_LOG(INFO) << os.str(); + + return 0; +} diff --git a/k2/torch/bin/pruned_stateless_transducer.cu b/k2/torch/bin/pruned_stateless_transducer.cu new file mode 100644 index 000000000..18f6a8ea8 --- /dev/null +++ b/k2/torch/bin/pruned_stateless_transducer.cu @@ -0,0 +1,193 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "k2/csrc/log.h" +#include "k2/torch/csrc/beam_search.h" +#include "k2/torch/csrc/features.h" +#include "k2/torch/csrc/parse_options.h" +#include "k2/torch/csrc/symbol_table.h" +#include "k2/torch/csrc/wave_reader.h" +#include "torch/all.h" + +static constexpr const char *kUsageMessage = R"( +This file implements RNN-T decoding for pruned stateless transducer models +that are trained using pruned_transducer_statelessX (X>=2) from icefall. + +Usage: + ./bin/pruned_stateless_transducer --help + + ./bin/pruned_stateless_transducer \ + --nn-model=/path/to/cpu_jit.pt \ + --tokens=/path/to/tokens.txt \ + --use-gpu=true \ + --decoding-method=modified_beam_search \ + /path/to/foo.wav \ + /path/to/bar.wav +)"; + +static void RegisterFrameExtractionOptions( + k2::ParseOptions *po, kaldifeat::FrameExtractionOptions *opts) { + po->Register("sample-frequency", &opts->samp_freq, + "Waveform data sample frequency (must match the waveform file, " + "if specified there)"); + + po->Register("frame-length", &opts->frame_length_ms, + "Frame length in milliseconds"); + + po->Register("frame-shift", &opts->frame_shift_ms, + "Frame shift in milliseconds"); + + po->Register("dither", &opts->dither, + "Dithering constant (0.0 means no dither)."); +} + +static void RegisterMelBanksOptions(k2::ParseOptions *po, + kaldifeat::MelBanksOptions *opts) { + po->Register("num-mel-bins", &opts->num_bins, + "Number of triangular mel-frequency bins"); +} + +int main(int argc, char *argv[]) { + // see + // https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html + torch::set_num_threads(1); + torch::set_num_interop_threads(1); + torch::NoGradGuard no_grad; + + k2::ParseOptions po(kUsageMessage); + + std::string nn_model; // path to the torch jit model file + std::string tokens; // path to tokens.txt + bool use_gpu = false; // true to use GPU for decoding; false to use CPU. + std::string decoding_method = "greedy_search"; // Supported methods are: + // greedy_search, + // modified_beam_search + + kaldifeat::FbankOptions fbank_opts; + fbank_opts.frame_opts.dither = 0; + RegisterFrameExtractionOptions(&po, &fbank_opts.frame_opts); + fbank_opts.mel_opts.num_bins = 80; + RegisterMelBanksOptions(&po, &fbank_opts.mel_opts); + + po.Register("nn-model", &nn_model, "Path to the torch jit model file"); + + po.Register("tokens", &tokens, "Path to the tokens.txt"); + + po.Register("use-gpu", &use_gpu, + "true to use GPU for decoding; false to use CPU. " + "If GPU is enabled, it always uses GPU 0. You can use " + "the environment variable CUDA_VISIBLE_DEVICES to control " + "which GPU device to use."); + + po.Register( + "decoding-method", &decoding_method, + "Decoding method to use." + "Currently implemented methods are: greedy_search, modified_beam_search"); + + po.Read(argc, argv); + + K2_CHECK(decoding_method == "greedy_search" || + decoding_method == "modified_beam_search") + << "Currently supported decoding methods are: " + "greedy_search, modified_beam_search. " + << "Given: " << decoding_method; + + torch::Device device(torch::kCPU); + if (use_gpu) { + K2_LOG(INFO) << "Use GPU"; + device = torch::Device(torch::kCUDA, 0); + } + + K2_LOG(INFO) << "Device: " << device; + + int32_t num_waves = po.NumArgs(); + K2_CHECK_GT(num_waves, 0) << "Please provide at least one wave file"; + + std::vector wave_filenames(num_waves); + for (int32_t i = 0; i < num_waves; ++i) { + wave_filenames[i] = po.GetArg(i + 1); + } + + K2_LOG(INFO) << "Loading wave files"; + std::vector wave_data = + k2::ReadWave(wave_filenames, fbank_opts.frame_opts.samp_freq); + for (auto &w : wave_data) { + w = w.to(device); + } + + fbank_opts.device = device; + + kaldifeat::Fbank fbank(fbank_opts); + + K2_LOG(INFO) << "Computing features"; + std::vector num_frames; + std::vector features_vec = + k2::ComputeFeatures(fbank, wave_data, &num_frames); + + // Note: math.log(1e-10) is -23.025850929940457 + torch::Tensor features = torch::nn::utils::rnn::pad_sequence( + features_vec, /*batch_first*/ true, + /*padding_value*/ -23.025850929940457f); + torch::Tensor feature_lens = torch::tensor(num_frames, device); + + K2_LOG(INFO) << "Loading neural network model from " << nn_model; + torch::jit::Module module = torch::jit::load(nn_model); + module.eval(); + module.to(device); + + K2_LOG(INFO) << "Computing output of the encoder network"; + + auto outputs = module.attr("encoder") + .toModule() + .run_method("forward", features, feature_lens) + .toTuple(); + assert(outputs->elements().size() == 2u); + + auto encoder_out = outputs->elements()[0].toTensor(); + auto encoder_out_lens = outputs->elements()[1].toTensor(); + + K2_LOG(INFO) << "Using " << decoding_method; + + std::vector> hyp_tokens; + if (decoding_method == "greedy_search") { + hyp_tokens = k2::GreedySearch(module, encoder_out, encoder_out_lens.cpu()); + } else { + hyp_tokens = + k2::ModifiedBeamSearch(module, encoder_out, encoder_out_lens.cpu()); + } + + k2::SymbolTable symbol_table(tokens); + + std::vector texts; + for (const auto &ids : hyp_tokens) { + std::string text; + for (auto id : ids) { + text.append(symbol_table[id]); + } + texts.emplace_back(std::move(text)); + } + + std::ostringstream os; + os << "\nDecoding result:\n\n"; + for (int32_t i = 0; i != num_waves; ++i) { + os << wave_filenames[i] << "\n"; + os << texts[i]; + os << "\n\n"; + } + K2_LOG(INFO) << os.str(); +}; diff --git a/k2/torch/bin/rnnt_demo.cu b/k2/torch/bin/rnnt_demo.cu new file mode 100644 index 000000000..462db70fc --- /dev/null +++ b/k2/torch/bin/rnnt_demo.cu @@ -0,0 +1,355 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Wei Kang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include + +#include "k2/torch/csrc/decode.h" +#include "k2/torch/csrc/deserialization.h" +#include "k2/torch/csrc/features.h" +#include "k2/torch/csrc/fsa_algo.h" +#include "k2/torch/csrc/symbol_table.h" +#include "k2/torch/csrc/utils.h" +#include "k2/torch/csrc/wave_reader.h" +#include "kaldifeat/csrc/feature-fbank.h" +#include "torch/all.h" +#include "torch/script.h" +#include "torch/utils.h" + +C10_DEFINE_bool(use_gpu, false, "True to use GPU. False to use CPU"); +C10_DEFINE_string(jit_pt, "", "Path to exported jit file."); +C10_DEFINE_bool( + use_lg, false, + "True to use an LG decoding graph. False to use a trivial decoding graph"); +C10_DEFINE_string(tokens, "", + "Path to a tokens.txt. Needed if --use_lg is false"); +C10_DEFINE_string(lg, "", "Path to LG.pt. Needed if --use_lg is true"); +C10_DEFINE_string(word_table, "", + "Path to words.txt. Needed if --use_lg is true"); +// Rnnt decoding related +C10_DEFINE_double(beam, 8.0, "beam in RnntDecodingStreams"); +C10_DEFINE_int(max_states, 64, "max_states in RnntDecodingStreams"); +C10_DEFINE_int(max_contexts, 8, "max_contexts in RnntDecodingStreams"); +// fbank related +C10_DEFINE_int(sample_rate, 16000, "Expected sample rate of wave files"); +C10_DEFINE_double(frame_shift_ms, 10.0, + "Frame shift in ms for computing Fbank"); +C10_DEFINE_double(frame_length_ms, 25.0, + "Frame length in ms for computing Fbank"); +C10_DEFINE_int(num_bins, 80, "Number of triangular bins for computing Fbank"); +C10_DEFINE_int(max_num_streams, 2, "Max number of decoding streams"); +C10_DEFINE_bool( + use_max, true, + "True to use max operation to select the hypothesis with the largest " + "log_prob when there are duplicate hypotheses; False to use log-add."); +C10_DEFINE_int(num_paths, 200, + "Number of paths to sample when generating Nbest"); +C10_DEFINE_double(nbest_scale, 0.8, + "The scale value applying to lattice.score before sampling"); + +static void CheckArgs() { +#if !defined(K2_WITH_CUDA) + if (FLAGS_use_gpu) { + std::cerr << "k2 was not compiled with CUDA" + << "\n"; + std::cerr << "Please use --use_gpu 0" + << "\n"; + exit(EXIT_FAILURE); + } +#endif + + if (FLAGS_jit_pt.empty()) { + std::cerr << "Please provide --jit_pt" + << "\n"; + std::cerr << torch::UsageMessage() << "\n"; + exit(EXIT_FAILURE); + } + + if (FLAGS_use_lg == false && FLAGS_tokens.empty()) { + std::cout << "Please provide --tokens" + << "\n"; + std::cout << torch::UsageMessage() << "\n"; + exit(EXIT_FAILURE); + } + + if (FLAGS_use_lg && FLAGS_lg.empty()) { + std::cout << "Please provide --lg" + << "\n"; + std::cout << torch::UsageMessage() << "\n"; + exit(EXIT_FAILURE); + } + + if (FLAGS_use_lg && FLAGS_word_table.empty()) { + std::cerr << "Please provide --word_table" + << "\n"; + std::cerr << torch::UsageMessage() << "\n"; + exit(EXIT_FAILURE); + } +} + +int main(int argc, char *argv[]) { + // see + // https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html + torch::set_num_threads(1); + torch::set_num_interop_threads(1); + + std::string usage = R"( + (1) Decoding without LG graph + ./bin/rnnt_demo \ + --use_lg false \ + --jit_pt \ + --tokens \ + /path/to/foo.wav \ + /path/to/bar.wav \ + + (2) Decoding with LG graph + ./bin/rnnt_demo \ + --use_lg true \ + --jit_pt \ + --lg \ + --word_table \ + --beam 8 \ + --max_contexts 8 \ + --max_states 64 \ + /path/to/foo.wav \ + /path/to/bar.wav \ + + --use_gpu false to use CPU + --use_gpu true to use GPU + )"; + torch::SetUsageMessage(usage); + + torch::ParseCommandLineFlags(&argc, &argv); + CheckArgs(); + + torch::Device device(torch::kCPU); + if (FLAGS_use_gpu) { + device = torch::Device(torch::kCUDA, 0); + } + + K2_LOG(INFO) << "Device: " << device; + + int32_t num_waves = argc - 1; + K2_CHECK_GE(num_waves, 1) << "You have to provide at least one wave file"; + std::vector wave_filenames(num_waves); + for (int32_t i = 0; i != num_waves; ++i) { + wave_filenames[i] = argv[i + 1]; + } + + kaldifeat::FbankOptions fbank_opts; + fbank_opts.frame_opts.samp_freq = FLAGS_sample_rate; + fbank_opts.frame_opts.dither = 0; + fbank_opts.frame_opts.frame_shift_ms = FLAGS_frame_shift_ms; + fbank_opts.frame_opts.frame_length_ms = FLAGS_frame_length_ms; + fbank_opts.mel_opts.num_bins = FLAGS_num_bins; + fbank_opts.device = device; + + kaldifeat::Fbank fbank(fbank_opts); + + K2_LOG(INFO) << "Load wave files"; + auto wave_data = k2::ReadWave(wave_filenames, FLAGS_sample_rate); + + for (auto &w : wave_data) { + w = w.to(device); + } + + K2_LOG(INFO) << "Compute features"; + std::vector num_frames; + auto features_vec = k2::ComputeFeatures(fbank, wave_data, &num_frames); + + // Note: math.log(1e-10) is -23.025850929940457 + auto features = torch::nn::utils::rnn::pad_sequence(features_vec, true, + -23.025850929940457f); + + K2_LOG(INFO) << "Load neural network model"; + torch::jit::script::Module module = torch::jit::load(FLAGS_jit_pt); + module.eval(); + module.to(device); + + int32_t vocab_size = 500; + int32_t context_size = 2; + int32_t subsampling_factor = 4; + + k2::FsaClass decoding_graph; + if (FLAGS_use_lg) { + K2_LOG(INFO) << "Load LG.pt"; + decoding_graph = k2::LoadFsa(FLAGS_lg, device); + K2_CHECK(decoding_graph.HasTensorAttr("aux_labels") || + decoding_graph.HasRaggedTensorAttr("aux_labels")); + } else { + K2_LOG(INFO) << "Build Trivial graph"; + decoding_graph = k2::TrivialGraph(vocab_size - 1, device); + } + + K2_LOG(INFO) << "Decoding"; + + k2::rnnt_decoding::RnntDecodingConfig config(vocab_size, context_size, + FLAGS_beam, FLAGS_max_states, + FLAGS_max_contexts); + + std::vector> + individual_streams(num_waves); + std::vector individual_graphs(num_waves); + // suppose we are using same graph for all waves. + for (int32_t i = 0; i < num_waves; ++i) { + individual_graphs[i] = decoding_graph; + individual_streams[i] = k2::rnnt_decoding::CreateStream( + std::make_shared(individual_graphs[i].fsa)); + } + + // we are not using a streaming model currently, so calculate encoder_outs + // at a time. + auto input_lengths = + torch::from_blob(num_frames.data(), {num_waves}, torch::kLong) + .to(torch::kInt) + .to(device); + + K2_LOG(INFO) << "Compute encoder outs"; + // the output for module.encoder.forward() is a tuple of 2 tensors + auto outputs = module.attr("encoder") + .toModule() + .run_method("forward", features, input_lengths) + .toTuple(); + assert(outputs->elements().size() == 2u); + + auto encoder_outs = outputs->elements()[0].toTensor(); + auto encoder_outs_lengths = outputs->elements()[1].toTensor(); + + int32_t T = encoder_outs.size(1); + int32_t chunk_size = 10; // 10 frames per chunk + std::vector decoded_frames(num_waves, 0); + std::vector positions(num_waves, 0); + + // decocding results for each waves + std::vector texts(num_waves, ""); + + // simulate asynchronous decoding + while (true) { + std::vector> + current_streams; + std::vector current_encoder_outs; + // which waves we are decoding now + std::vector current_wave_ids; + + std::vector current_num_frames; + std::vector current_graphs; + + for (int32_t i = 0; i < num_waves; ++i) { + // this wave is done + if (decoded_frames[i] * subsampling_factor >= num_frames[i]) continue; + + current_streams.emplace_back(individual_streams[i]); + current_graphs.emplace_back(individual_graphs[i]); + current_wave_ids.push_back(i); + + if ((num_frames[i] - decoded_frames[i]) <= + chunk_size * subsampling_factor) { + decoded_frames[i] = num_frames[i] / subsampling_factor; + } else { + decoded_frames[i] += chunk_size; + } + + current_num_frames.emplace_back(decoded_frames[i]); + + int32_t start = positions[i], + end = start + chunk_size >= T ? T : start + chunk_size; + positions[i] = end; + auto sub_output = encoder_outs.index( + {i, torch::indexing::Slice(start, end), torch::indexing::Slice()}); + + // padding T axis to chunk_size if needed + namespace F = torch::nn::functional; + sub_output = F::pad(sub_output, + F::PadFuncOptions({0, 0, 0, chunk_size - end + start}) + .mode(torch::kConstant)); + + current_encoder_outs.push_back(sub_output); + + // we can decode at most `FLAGS_max_num_streams` waves at a time + if (static_cast(current_wave_ids.size()) >= + FLAGS_max_num_streams) + break; + } + if (current_wave_ids.size() == 0) break; // finished + + auto sub_encoder_outs = torch::stack(current_encoder_outs); + + auto streams = + k2::rnnt_decoding::RnntDecodingStreams(current_streams, config); + k2::DecodeOneChunk(streams, module, sub_encoder_outs); + + k2::FsaVec ofsa; + k2::Array1 out_map; + bool allow_partial = true; + streams.FormatOutput(current_num_frames, allow_partial, &ofsa, &out_map); + + auto arc_map = k2::Ragged(ofsa.shape, out_map).RemoveAxis(1); + k2::FsaClass lattice(ofsa); + lattice.CopyAttrs(current_graphs, arc_map); + + lattice = k2::GetBestPaths(lattice, FLAGS_use_max, FLAGS_num_paths, + FLAGS_nbest_scale); + + auto ragged_aux_labels = k2::GetTexts(lattice); + + auto aux_labels_vec = ragged_aux_labels.ToVecVec(); + + if (!FLAGS_use_lg) { + k2::SymbolTable symbol_table(FLAGS_tokens); + for (size_t i = 0; i < current_wave_ids.size(); ++i) { + std::string text; + for (auto id : aux_labels_vec[i]) { + text.append(symbol_table[id]); + } + texts[current_wave_ids[i]] = std::move(text); + } + } else { + k2::SymbolTable symbol_table(FLAGS_word_table); + for (size_t i = 0; i < current_wave_ids.size(); ++i) { + std::string text; + std::string sep = ""; + for (auto id : aux_labels_vec[i]) { + text.append(sep); + text.append(symbol_table[id]); + sep = " "; + } + texts[current_wave_ids[i]] = text; + } + } + std::ostringstream os; + os << "\nPartial result:\n"; + for (size_t i = 0; i != current_wave_ids.size(); ++i) { + os << wave_filenames[current_wave_ids[i]] << "\n"; + os << texts[current_wave_ids[i]]; + os << "\n\n"; + } + K2_LOG(INFO) << os.str(); + } + + std::ostringstream os; + os << "\nDecoding result:\n"; + for (int32_t i = 0; i != num_waves; ++i) { + os << wave_filenames[i] << "\n"; + os << texts[i]; + os << "\n\n"; + } + K2_LOG(INFO) << os.str(); + return 0; +} diff --git a/k2/torch/csrc/CMakeLists.txt b/k2/torch/csrc/CMakeLists.txt new file mode 100644 index 000000000..748b1ee1c --- /dev/null +++ b/k2/torch/csrc/CMakeLists.txt @@ -0,0 +1,67 @@ +include_directories(${CMAKE_SOURCE_DIR}) + +# it is located in k2/csrc/cmake/transform.cmake +include(transform) +set(k2_torch_srcs + beam_search.cu + decode.cu + dense_fsa_vec.cu + deserialization.cu + fsa_algo.cu + fsa_class.cu + hypothesis.cu + nbest.cu + parse_options.cu + symbol_table.cu + utils.cu + wave_reader.cu +) + +if(NOT K2_WITH_CUDA) + transform(OUTPUT_VARIABLE k2_torch_srcs SRCS ${k2_torch_srcs}) +endif() + +add_library(k2_torch ${k2_torch_srcs}) +target_link_libraries(k2_torch PUBLIC ${TORCH_LIBRARIES} context) + +add_library(k2_fbank features.cc) +target_link_libraries(k2_fbank PUBLIC ${TORCH_LIBRARIES} kaldifeat_core) + +if(K2_ENABLE_TESTS) + # Please sort files alphabetically + set(k2_torch_test_srcs + dense_fsa_vec_test.cu + deserialization_test.cu + fsa_class_test.cu + hypothesis_test.cu + parse_options_test.cu + wave_reader_test.cu + ) + + if(NOT K2_WITH_CUDA) + transform(OUTPUT_VARIABLE k2_torch_test_srcs SRCS ${k2_torch_test_srcs}) + endif() + + function(k2_add_torch_test source) + get_filename_component(name ${source} NAME_WE) + set(target_name "cu_k2_torch_${name}") + add_executable(${target_name} "${source}") + set_target_properties(${target_name} PROPERTIES CUDA_SEPARABLE_COMPILATION ON) + target_link_libraries(${target_name} k2_torch gtest gtest_main) + + # NOTE: We set the working directory here so that + # it works also on windows. The reason is that + # the required DLLs are inside ${TORCH_DIR}/lib + # and they can be found by the exe if the current + # working directory is ${TORCH_DIR}\lib + add_test(NAME "Test.Cuda.${target_name}" + COMMAND + $ + WORKING_DIRECTORY ${TORCH_DIR}/lib + ) + endfunction() + + foreach(source IN LISTS k2_torch_test_srcs) + k2_add_torch_test(${source}) + endforeach() +endif() diff --git a/k2/torch/csrc/CPPLINT.cfg b/k2/torch/csrc/CPPLINT.cfg new file mode 100644 index 000000000..ce4942ccf --- /dev/null +++ b/k2/torch/csrc/CPPLINT.cfg @@ -0,0 +1,3 @@ +exclude_files=custom_class.h +exclude_files=test_wave_data.h +exclude_files=test_deserialization_data.h diff --git a/k2/torch/csrc/beam_search.cu b/k2/torch/csrc/beam_search.cu new file mode 100644 index 000000000..549d40d8c --- /dev/null +++ b/k2/torch/csrc/beam_search.cu @@ -0,0 +1,393 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_TORCH_CSRC_BEAM_SEARCH_H_ +#define K2_TORCH_CSRC_BEAM_SEARCH_H_ + +#include +#include +#include +#include + +#include "k2/csrc/array.h" +#include "k2/csrc/log.h" +#include "k2/csrc/ragged_ops.h" +#include "k2/torch/csrc/beam_search.h" +#include "k2/torch/csrc/hypothesis.h" +#include "torch/all.h" + +namespace k2 { + +static inline torch::Tensor FloorDivide(torch::Tensor a, int32_t b) { +#if K2_TORCH_VERSION_MAJOR > 1 || \ + (K2_TORCH_VERSION_MAJOR == 1 && K2_TORCH_VERSION_MINOR > 7) + return torch::div(a, b, /*rounding_mode*/ "trunc"); +#else + return torch::floor_divide(a, b); +#endif +} + +/** + * Construct the decoder input from the current hypothesis. + * + * @param hyps A list-of-list of token IDs containing the current decoding + * results. Its length is `batch_size` + * @param decoder_input A 2-D tensor of shape (batch_size, context_size). + */ +static void BuildDecoderInput(const std::vector> &hyps, + torch::Tensor *decoder_input) { + int32_t batch_size = decoder_input->size(0); + int32_t context_size = decoder_input->size(1); + int64_t *p = decoder_input->data_ptr(); + for (int32_t i = 0; i != batch_size; ++i) { + auto start = hyps[i].end() - context_size; + auto end = hyps[i].end(); + std::copy(start, end, p); + p += context_size; + } +} + +static torch::Tensor BuildDecoderInput(const std::vector hyps, + int32_t context_size) { + int32_t num_hyps = hyps.size(); + torch::Tensor decoder_input = + torch::empty({num_hyps, context_size}, + torch::dtype(torch::kLong) + .memory_format(torch::MemoryFormat::Contiguous)); + + int64_t *p = decoder_input.data_ptr(); + for (const auto &h : hyps) { + auto start = h.ys.end() - context_size; + auto end = h.ys.end(); + + std::copy(start, end, p); + p += context_size; + } + + return decoder_input; +} + +/** Return a ragged shape with axes [utt][num_hyps]. + * + * @param hyps hyps.size() == batch_size. Each entry contains the active + * hypotheses of an utterance. + * @return Return a ragged shape with 2 axes [utt][num_hyps]. Note that the + * shape is on CPU. + */ +static RaggedShape GetHypsShape(const std::vector &hyps) { + int32_t num_utt = hyps.size(); + Array1 row_splits(GetCpuContext(), num_utt + 1); + int32_t *row_splits_data = row_splits.Data(); + + for (int32_t i = 0; i != num_utt; ++i) { + row_splits_data[i] = hyps[i].Size(); + } + + ExclusiveSum(row_splits, &row_splits); + + return RaggedShape2(&row_splits, nullptr, row_splits.Back()); +} + +std::vector> GreedySearch( + const torch::jit::Module &model, const torch::Tensor &encoder_out, + const torch::Tensor &encoder_out_lens) { + K2_CHECK_EQ(encoder_out.dim(), 3); + K2_CHECK_EQ(encoder_out.scalar_type(), torch::kFloat); + + K2_CHECK_EQ(encoder_out_lens.dim(), 1); + K2_CHECK_EQ(encoder_out_lens.scalar_type(), torch::kLong); + K2_CHECK(encoder_out_lens.device().is_cpu()); + + torch::nn::utils::rnn::PackedSequence packed_seq = + torch::nn::utils::rnn::pack_padded_sequence(encoder_out, encoder_out_lens, + /*batch_first*/ true, + /*enforce_sorted*/ false); + torch::jit::Module decoder = model.attr("decoder").toModule(); + torch::jit::Module joiner = model.attr("joiner").toModule(); + torch::jit::Module decoder_proj = joiner.attr("decoder_proj").toModule(); + + auto projected_encoder_out = joiner.attr("encoder_proj") + .toModule() + .run_method("forward", packed_seq.data()) + .toTensor(); + + int32_t blank_id = decoder.attr("blank_id").toInt(); + + int32_t unk_id = blank_id; + if (decoder.hasattr("unk_id")) { + unk_id = decoder.attr("unk_id").toInt(); + } + + int32_t context_size = decoder.attr("context_size").toInt(); + int32_t batch_size = encoder_out_lens.size(0); + + torch::Device device = encoder_out.device(); + + std::vector blanks(context_size, blank_id); + std::vector> hyps(batch_size, blanks); + + auto decoder_input = + torch::full({batch_size, context_size}, blank_id, + torch::dtype(torch::kLong) + .memory_format(torch::MemoryFormat::Contiguous)); + auto decoder_out = + decoder + .run_method("forward", decoder_input.to(device), /*need_pad*/ false) + .toTensor(); + decoder_out = decoder_proj.run_method("forward", decoder_out).toTensor(); + // decoder_out's shape is (batch_size, 1, joiner_dim) + + using torch::indexing::Slice; + auto batch_sizes_acc = packed_seq.batch_sizes().accessor(); + int32_t num_batches = packed_seq.batch_sizes().numel(); + int32_t offset = 0; + for (int32_t i = 0; i != num_batches; ++i) { + int32_t cur_batch_size = batch_sizes_acc[i]; + int32_t start = offset; + int32_t end = start + cur_batch_size; + auto cur_encoder_out = projected_encoder_out.index({Slice(start, end)}); + offset = end; + + cur_encoder_out = cur_encoder_out.unsqueeze(1).unsqueeze(1); + // Now cur_encoder_out's shape is (cur_batch_size, 1, 1, joiner_dim) + if (cur_batch_size < decoder_out.size(0)) { + decoder_out = decoder_out.index({Slice(0, cur_batch_size)}); + } + + auto logits = + joiner + .run_method("forward", cur_encoder_out, decoder_out.unsqueeze(1), + /*project_input*/ false) + .toTensor(); + // logits' shape is (cur_batch_size, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1); + auto max_indices = logits.argmax(/*dim*/ -1).cpu(); + auto max_indices_acc = max_indices.accessor(); + bool emitted = false; + for (int32_t k = 0; k != cur_batch_size; ++k) { + auto index = max_indices_acc[k]; + if (index != blank_id && index != unk_id) { + emitted = true; + hyps[k].push_back(index); + } + } + + if (emitted) { + if (cur_batch_size < decoder_input.size(0)) { + decoder_input = decoder_input.index({Slice(0, cur_batch_size)}); + } + BuildDecoderInput(hyps, &decoder_input); + decoder_out = decoder + .run_method("forward", decoder_input.to(device), + /*need_pad*/ false) + .toTensor(); + decoder_out = decoder_proj.run_method("forward", decoder_out).toTensor(); + } + } + + auto unsorted_indices = packed_seq.unsorted_indices().cpu(); + auto unsorted_indices_accessor = unsorted_indices.accessor(); + + std::vector> ans(batch_size); + + for (int32_t i = 0; i != batch_size; ++i) { + torch::ArrayRef arr(hyps[unsorted_indices_accessor[i]]); + ans[i] = arr.slice(context_size).vec(); + } + + return ans; +} + +std::vector> ModifiedBeamSearch( + const torch::jit::Module &model, const torch::Tensor &encoder_out, + const torch::Tensor &encoder_out_lens, int32_t num_acitve_paths /*=4*/) { + K2_CHECK_EQ(encoder_out.dim(), 3); + K2_CHECK_EQ(encoder_out.scalar_type(), torch::kFloat); + + K2_CHECK_EQ(encoder_out_lens.dim(), 1); + K2_CHECK_EQ(encoder_out_lens.scalar_type(), torch::kLong); + K2_CHECK(encoder_out_lens.device().is_cpu()); + + torch::nn::utils::rnn::PackedSequence packed_seq = + torch::nn::utils::rnn::pack_padded_sequence(encoder_out, encoder_out_lens, + /*batch_first*/ true, + /*enforce_sorted*/ false); + torch::jit::Module decoder = model.attr("decoder").toModule(); + torch::jit::Module joiner = model.attr("joiner").toModule(); + torch::jit::Module decoder_proj = joiner.attr("decoder_proj").toModule(); + + auto projected_encoder_out = joiner.attr("encoder_proj") + .toModule() + .run_method("forward", packed_seq.data()) + .toTensor(); + + int32_t blank_id = decoder.attr("blank_id").toInt(); + + int32_t unk_id = blank_id; + if (decoder.hasattr("unk_id")) { + unk_id = decoder.attr("unk_id").toInt(); + } + + int32_t context_size = decoder.attr("context_size").toInt(); + int32_t batch_size = encoder_out_lens.size(0); + + torch::Device device = encoder_out.device(); + + std::vector blanks(context_size, blank_id); + Hypotheses blank_hyp({{blanks, 0}}); + + std::deque finalized; + std::vector cur(batch_size, blank_hyp); + std::vector prev; + + using torch::indexing::Slice; + auto batch_sizes_acc = packed_seq.batch_sizes().accessor(); + int32_t num_batches = packed_seq.batch_sizes().numel(); + int32_t offset = 0; + for (int32_t i = 0; i != num_batches; ++i) { + int32_t cur_batch_size = batch_sizes_acc[i]; + int32_t start = offset; + int32_t end = start + cur_batch_size; + auto cur_encoder_out = projected_encoder_out.index({Slice(start, end)}); + offset = end; + + cur_encoder_out = cur_encoder_out.unsqueeze(1).unsqueeze(1); + // Now cur_encoder_out's shape is (cur_batch_size, 1, 1, joiner_dim) + + if (cur_batch_size < cur.size()) { + for (int32_t k = static_cast(cur.size()) - 1; + k >= cur_batch_size; --k) { + finalized.push_front(std::move(cur[k])); + } + cur.erase(cur.begin() + cur_batch_size, cur.end()); + } + + // Due to merging paths with identical token sequences, + // not all utterances have "num_acitve_paths" paths. + auto hyps_shape = GetHypsShape(cur); + + prev.clear(); + prev.reserve(hyps_shape.TotSize(1)); + for (auto &hyps : cur) { + for (auto &h : hyps) { + prev.push_back(std::move(h.second)); + } + } + cur.clear(); + cur.reserve(cur_batch_size); + + torch::Tensor ys_log_probs = + torch::empty({hyps_shape.TotSize(1), 1}, torch::dtype(torch::kFloat)); + + auto ys_log_probs_acc = ys_log_probs.accessor(); + for (int32_t k = 0; k != prev.size(); ++k) { + ys_log_probs_acc[k][0] = prev[k].log_prob; + } + + auto decoder_input = BuildDecoderInput(prev, context_size).to(device); + + auto decoder_out = + decoder.run_method("forward", decoder_input, /*need_pad*/ false) + .toTensor(); + + decoder_out = decoder_proj.run_method("forward", decoder_out).toTensor(); + // decoder_out is of shape (num_hyps, 1, joiner_dim) + + auto row_ids = hyps_shape.RowIds(1); + + auto index = + torch::from_blob(row_ids.Data(), {row_ids.Dim()}, torch::kInt32) + .to(torch::device(device).dtype(torch::kLong)); + + cur_encoder_out = cur_encoder_out.index_select(/*dim*/ 0, /*index*/ index); + // cur_encoder_out is of shape (num_hyps, 1, 1, joiner_dim) + + auto logits = + joiner + .run_method("forward", cur_encoder_out, decoder_out.unsqueeze(1), + /*project_input*/ false) + .toTensor(); + // logits' shape is (cur_batch_size, 1, 1, vocab_size) + logits = logits.squeeze(1).squeeze(1); + // now logits' shape is (cur_batch_size, vocab_size) + + auto log_probs = logits.log_softmax(-1).cpu(); + + log_probs.add_(ys_log_probs); + + int32_t vocab_size = log_probs.size(1); + log_probs = log_probs.reshape(-1); + auto row_splits = hyps_shape.RowSplits(1); + const int32_t *row_splits_data = row_splits.Data(); + + for (int32_t k = 0; k != cur_batch_size; ++k) { + int32_t start = row_splits_data[k]; + int32_t end = row_splits_data[k + 1]; + + torch::Tensor values, indexes; + std::tie(values, indexes) = + log_probs.slice(/*dim*/ 0, start * vocab_size, end * vocab_size) + .topk(/*k*/ num_acitve_paths, /*dim*/ 0, + /*largest*/ true, /*sorted*/ true); + + auto topk_hyp_indexes = FloorDivide(indexes, vocab_size); + auto topk_token_indexes = torch::remainder(indexes, vocab_size); + + auto values_acc = values.accessor(); + auto topk_hyp_indexes_acc = topk_hyp_indexes.accessor(); + auto topk_token_indexes_acc = topk_token_indexes.accessor(); + + Hypotheses hyps; + for (int32_t j = 0; j != values.numel(); ++j) { + int32_t hyp_idx = topk_hyp_indexes_acc[j]; + Hypothesis new_hyp = prev[start + hyp_idx]; // note: hyp_idx is 0 based + + int32_t new_token = topk_token_indexes_acc[j]; + if (new_token != blank_id && new_token != unk_id) { + new_hyp.ys.push_back(new_token); + } + + // We already added log_prob of the path to log_probs before, so + // we use values_acc[j] here directly. + new_hyp.log_prob = values_acc[j]; + hyps.Add(std::move(new_hyp)); + } + cur.push_back(std::move(hyps)); + } + } + + for (auto &h : finalized) { + cur.push_back(std::move(h)); + } + + auto unsorted_indices = packed_seq.unsorted_indices().cpu(); + auto unsorted_indices_accessor = unsorted_indices.accessor(); + + std::vector> ans(batch_size); + for (int32_t i = 0; i != batch_size; ++i) { + Hypothesis hyp = cur[unsorted_indices_accessor[i]].GetMostProbable(true); + torch::ArrayRef arr(hyp.ys); + ans[i] = arr.slice(context_size).vec(); + } + + return ans; +} + +} // namespace k2 + +#endif // K2_TORCH_CSRC_BEAM_SEARCH_H_ diff --git a/k2/torch/csrc/beam_search.h b/k2/torch/csrc/beam_search.h new file mode 100644 index 000000000..031a6380c --- /dev/null +++ b/k2/torch/csrc/beam_search.h @@ -0,0 +1,56 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_TORCH_CSRC_BEAM_SEARCH_H_ +#define K2_TORCH_CSRC_BEAM_SEARCH_H_ + +#include + +#include "torch/all.h" + +namespace k2 { + +/** RNN-T Greedy search decoding by limiting the max symol per frame to one. + * + * @param model The transducer model. See pruned_transducer_stateless2/model.py + * for the methods and properties it has. + * + * @param encoder_out Output from the encoder network. Its shape is + * (batch_size, T, encoder_out_dim) and its dtype is + * torch::kFloat. + * + * @param encoder_out_lens A 1-D tensor containing the valid frames before + * padding in `encoder_out`. Its dtype is torch.kLong + * and its shape is (batch_size,). Also, it must be + * on CPU. + * + * @return Return A list-of-list of token IDs containing the decoding results. + * The returned vector has size `batch_size` and each entry contains the + * decoding results for the corresponding input in encoder_out. + */ +std::vector> GreedySearch( + const torch::jit::Module &model, const torch::Tensor &encoder_out, + const torch::Tensor &encoder_out_lens); + +std::vector> ModifiedBeamSearch( + const torch::jit::Module &model, const torch::Tensor &encoder_out, + const torch::Tensor &encoder_out_lens, int32_t num_acitve_paths = 4); + +} // namespace k2 + +#endif // K2_TORCH_CSRC_BEAM_SEARCH_H_ diff --git a/k2/torch/csrc/decode.cu b/k2/torch/csrc/decode.cu new file mode 100644 index 000000000..b710c4f6f --- /dev/null +++ b/k2/torch/csrc/decode.cu @@ -0,0 +1,207 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "k2/csrc/fsa_algo.h" +#include "k2/csrc/fsa_utils.h" +#include "k2/csrc/ragged_ops.h" +#include "k2/torch/csrc/decode.h" +#include "k2/torch/csrc/dense_fsa_vec.h" +#include "k2/torch/csrc/fsa_algo.h" +#include "k2/torch/csrc/utils.h" + +namespace k2 { + +FsaClass GetLattice(torch::Tensor nnet_output, FsaClass &decoding_graph, + torch::Tensor supervision_segments, float search_beam, + float output_beam, int32_t min_activate_states, + int32_t max_activate_states, int32_t subsampling_factor) { + DenseFsaVec dense_fsa_vec = CreateDenseFsaVec( + nnet_output, supervision_segments, subsampling_factor - 1); + return IntersectDensePruned(decoding_graph, dense_fsa_vec, search_beam, + output_beam, min_activate_states, + max_activate_states); +} + +Ragged GetTexts(FsaClass &lattice) { + if (lattice.HasTensorAttr("aux_labels")) { + torch::Tensor aux_labels = lattice.GetTensorAttr("aux_labels"); + Array1 aux_labels_array = Array1FromTorch(aux_labels); + RaggedShape aux_labels_shape = RemoveAxis(lattice.fsa.shape, 1); + auto ragged_aux_labels = + Ragged(aux_labels_shape, aux_labels_array); + return RemoveValuesLeq(ragged_aux_labels, 0); + } else { + K2_CHECK(lattice.HasRaggedTensorAttr("aux_labels")); + + auto aux_labels = lattice.GetRaggedTensorAttr("aux_labels"); + RaggedShape aux_labels_shape = + ComposeRaggedShapes(lattice.fsa.shape, aux_labels.shape); + aux_labels_shape = RemoveAxis(aux_labels_shape, 1); + aux_labels_shape = RemoveAxis(aux_labels_shape, 1); + auto ragged_aux_labels = + Ragged(aux_labels_shape, aux_labels.values); + return RemoveValuesLeq(ragged_aux_labels, 0); + } +} + +void WholeLatticeRescoring(FsaClass &G, float ngram_lm_scale, + FsaClass *lattice) { + K2_CHECK(lattice->HasTensorAttr("lm_scores")); + + torch::Tensor am_scores = + lattice->Scores() - lattice->GetTensorAttr("lm_scores"); + lattice->SetScores(am_scores); + + // Now, lattice contains only acoustic scores, we will attach LM scores + // from the given n-gram LM + lattice->DeleteTensorAttr("lm_scores"); + + K2_CHECK_EQ(G.NumAttrs(), 1) + << "G is expected to contain only 1 attribute: lm_scores."; + K2_CHECK_EQ(G.fsa.NumAxes(), 3); + K2_CHECK_EQ(G.fsa.Dim0(), 1); + + k2::Invert(lattice); + // Now lattice has word IDs as labels and token IDs as aux_labels. + + // TODO(fangjun): Use Intersect() when device is CPU + auto b_to_a_map = + k2::Array1(G.fsa.Context(), lattice->fsa.Dim0(), 0); + k2::Array1 arc_map_a, arc_map_b; + + k2::Fsa dest = k2::IntersectDevice(G.fsa, G.Properties(), lattice->fsa, + lattice->Properties(), b_to_a_map, + &arc_map_a, &arc_map_b, true); + + lattice->properties = 0; + lattice->fsa = dest; + lattice->CopyAttrs(*lattice, k2::Array1ToTorch(arc_map_b)); + lattice->CopyAttrs(G, k2::Array1ToTorch(arc_map_a)); + k2::Connect(lattice); + k2::TopSort(lattice); + k2::Invert(lattice); + // Now lattice has token IDs as labels and word IDs as aux_labels + + if (ngram_lm_scale != 1) { + torch::Tensor lm_scores = lattice->GetTensorAttr("lm_scores"); + am_scores = lattice->Scores() - lm_scores; + torch::Tensor scores = am_scores / ngram_lm_scale + lm_scores; + lattice->SetScores(scores); + } +} + +FsaClass GetBestPaths(FsaClass &lattice, bool use_max, int32_t num_paths, + float nbest_scale) { + if (use_max) { + return ShortestPath(lattice); + } else { + K2_CHECK(lattice.HasTensorAttr("aux_labels") || + lattice.HasRaggedTensorAttr("aux_labels")); + Nbest nbest = Nbest::FromLattice(lattice, num_paths, nbest_scale); + + auto word_fsa = nbest.fsa; + Invert(&word_fsa); + + // delete token IDs, as it is not needed. + if (word_fsa.HasTensorAttr("aux_labels")) + word_fsa.DeleteTensorAttr("aux_labels"); + if (word_fsa.HasRaggedTensorAttr("aux_labels")) + word_fsa.DeleteRaggedTensorAttr("aux_labels"); + word_fsa.Scores().zero_(); + + auto word_fsa_with_self_loops = LinearFsaWithSelfLoops(word_fsa); + + auto inv_lattice = lattice; + Invert(&inv_lattice); + ArcSort(&inv_lattice); + + Array1 path_to_utt_map; + if (inv_lattice.fsa.Dim0() == 1) { + path_to_utt_map = + Array1(nbest.shape.Context(), nbest.shape.TotSize(1), 0); + } else { + path_to_utt_map = nbest.shape.RowIds(1); + } + + auto path_lattice = IntersectDevice(inv_lattice, word_fsa_with_self_loops, + path_to_utt_map, true); + Connect(&path_lattice); + TopSort(&path_lattice); + + using FloatType = double; + Array1 tot_scores = + GetTotScores(path_lattice, true /*log_semiring*/); + auto ragged_tot_scores = Ragged(nbest.shape, tot_scores); + + Array1 best_hyp_indexes(ragged_tot_scores.Context(), + ragged_tot_scores.Dim0()); + ArgMaxPerSublist(ragged_tot_scores, + -std::numeric_limits::infinity(), + &best_hyp_indexes); + + Array1 indexes_map; + auto raw_fsa = + Index(nbest.fsa.fsa, 0 /*axis*/, best_hyp_indexes, &indexes_map); + + FsaClass best_path = FsaClass(raw_fsa); + best_path.CopyAttrs(nbest.fsa, Array1ToTorch(indexes_map)); + return best_path; + } +} + +void DecodeOneChunk(rnnt_decoding::RnntDecodingStreams &streams, + torch::jit::script::Module module, + torch::Tensor encoder_outs) { + K2_CHECK_EQ(encoder_outs.dim(), 3); + K2_CHECK_EQ(streams.NumStreams(), encoder_outs.size(0)); + int32_t T = encoder_outs.size(1); + for (int32_t t = 0; t < T; ++t) { + RaggedShape shape; + Array2 contexts; + streams.GetContexts(&shape, &contexts); + auto contexts_tensor = Array2ToTorch(contexts); + // `nn.Embedding()` in torch below v1.7.1 supports only torch.int64 + contexts_tensor = contexts_tensor.to(torch::kInt64); + auto decoder_outs = module.attr("decoder") + .toModule() + .run_method("forward", contexts_tensor, false) + .toTensor(); + auto current_encoder_outs = encoder_outs.index( + {torch::indexing::Slice(), torch::indexing::Slice(t, t + 1), + torch::indexing::Slice()}); + auto row_ids = Array1ToTorch(shape.RowIds(1)); + current_encoder_outs = + torch::index_select(current_encoder_outs, 0, row_ids); + + auto logits = module.attr("joiner") + .toModule() + .run_method("forward", current_encoder_outs.unsqueeze(1), + decoder_outs.unsqueeze(1)) + .toTensor() + .squeeze(1) + .squeeze(1); + auto logprobs = logits.log_softmax(-1); + auto logprobs_array = Array2FromTorch(logprobs); + streams.Advance(logprobs_array); + } + streams.TerminateAndFlushToStreams(); +} + +} // namespace k2 diff --git a/k2/torch/csrc/decode.h b/k2/torch/csrc/decode.h new file mode 100644 index 000000000..cb032ec82 --- /dev/null +++ b/k2/torch/csrc/decode.h @@ -0,0 +1,118 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_TORCH_CSRC_DECODE_H_ +#define K2_TORCH_CSRC_DECODE_H_ + +#include "k2/csrc/array.h" +#include "k2/csrc/fsa.h" +#include "k2/csrc/ragged.h" +#include "k2/csrc/rnnt_decode.h" +#include "k2/torch/csrc/fsa_class.h" +#include "torch/script.h" + +namespace k2 { + +/** Get decoding lattice from a neural network output and a decoding graph. + + @param nnet_output A 3-D tensor with dtype torch.float32. It is usally + the last layer of the neural network model, e.g., + the output of `log-softmax` layer. It has shape + `(N, T, C)`. + @param decoding_graph It is an FsaClass. It usually contains only one + graph. For instance, when using CTC decoding, + it contains a single CTC topo graph; when using + HLG decoding, it contains a single HLG graph. + + @param supervision_segments A 2-D tensor with dtype torch.int32. + Please refer to `k2::CreateDenseFsaVec()` + for its format. + @param search_beam See `k2::IntersectDensePruned()` for its meaning. + @param output_beam See `k2::IntersectDensePruned()` for its meaning. + @param min_activate_states See `k2::IntersectDensePruned()` for its meaning. + @param max_activate_states See `k2::IntersectDensePruned()` for its meaning. + @param subsampling_factor The subsampling factor of the model. + + @return Return an FsaClass, which contains the intersection of decoding graph + and the FSA constructed from `nnet_output`. All the attributes of the + decoding_graph are propagated the returned FsaClass as well. + */ +FsaClass GetLattice(torch::Tensor nnet_output, FsaClass &decoding_graph, + torch::Tensor supervision_segments, float search_beam, + float output_beam, int32_t min_activate_states, + int32_t max_activate_states, int32_t subsampling_factor); + +/** Get aux labels of each FSA contained in the lattice. + + @param lattice An FsaVec containing linear FSAs. It can be the return + value of `OneBestDecoding()`. + + @return Return a ragged array with two axes [utt][aux_label]. + */ +Ragged GetTexts(FsaClass &lattice); + +/** Rescore a lattice with an n-gram LM. + + @param G An acceptor. It MUST be an FsaVec containing only one + arc-sorted FSA. Also, it contains epsilon self loops + (see AddEpsilonSelfLoops()). It contains only one tensor + attribute: "lm_scores". + @param ngram_lm_scale The scale value for ngram LM scores. + @param lattice The input/output lattice. It can be the + return value of `GetLattice()`. + */ +void WholeLatticeRescoring(FsaClass &G, float ngram_lm_scale, + FsaClass *lattice); + +/** Get the best path of a given lattice. + + @param lattice The given lattice. + @param use_max True to use max operation to select the hypothesis with the + largest log_prob when there are duplicate hypotheses; False + to use log-add. + @param num_paths Number of paths to sample when generating Nbest. Only used + when use_max equals to false. + @param nbest_scale The scale value applying to lattice.score before + sampling. Only used when use_max equals to false. + + @return Return the lattice containing the best paths for each Fsa. + */ +FsaClass GetBestPaths(FsaClass &lattice, bool use_max, int32_t num_paths, + float nbest_scale); + +/** Advance a chunk of frames for rnnt decoding. + + @param streams The rnnt decoding streams. + @param module Jit script module containing "decoder_forword" and + "joiner_forward" methods. + @param encoder_outs The output of rnnt encoder which has a shape of + (B, T, C), B (i.e. the batch size) equals to + streams.NumStreams(). T is the chunk size. C is the + embedding dimension. + + Note: streams.TerminateAndFlushToStreams() will be invoked in this function, + so all the decoding results will be flushed back to the individual + streams belonging to the corresponding sequences. + */ +void DecodeOneChunk(rnnt_decoding::RnntDecodingStreams &streams, + torch::jit::script::Module module, + torch::Tensor encoder_outs); + +} // namespace k2 + +#endif // K2_TORCH_CSRC_DECODE_H_ diff --git a/k2/torch/csrc/dense_fsa_vec.cu b/k2/torch/csrc/dense_fsa_vec.cu new file mode 100644 index 000000000..12140def4 --- /dev/null +++ b/k2/torch/csrc/dense_fsa_vec.cu @@ -0,0 +1,173 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "k2/csrc/log.h" +#include "k2/torch/csrc/dense_fsa_vec.h" +#include "k2/torch/csrc/utils.h" +#include "torch/script.h" + +namespace k2 { + +DenseFsaVec CreateDenseFsaVec(torch::Tensor log_probs, + torch::Tensor supervision_segments, + int32_t allow_truncate /*=0*/) { + K2_CHECK_EQ(log_probs.dtype(), torch::kFloat32); + K2_CHECK_EQ(log_probs.dim(), 3); + + K2_CHECK_EQ(supervision_segments.dtype(), torch::kInt); + K2_CHECK_EQ(supervision_segments.dim(), 2); + K2_CHECK_EQ(supervision_segments.size(1), 3); + K2_CHECK_EQ(supervision_segments.device().type(), torch::kCPU); + K2_CHECK_GE(allow_truncate, 0); + + int32_t N = log_probs.size(0); + int32_t T = log_probs.size(1); + int32_t C = log_probs.size(2); + + // iterate the supervision_segments to get number of frames for each segment + int32_t num_utt = supervision_segments.size(0); + int32_t stride = supervision_segments.stride(0); + const int32_t *p_sup = supervision_segments.data_ptr(); + + // linear_indexes contains indexes along axis 0 + // for the tensor obtained from log_probs.view(-1, C) + std::vector linear_indexes; + linear_indexes.reserve(num_utt * (T + 1)); // the worse case + + // It contains the index of the extra frame of each utterance + // that will be set to [0, -inf, -inf, -inf, ... ] + std::vector extra_frame_indexes; + extra_frame_indexes.reserve(num_utt); + + int32_t duration_in_total = 0; + + for (int32_t i = 0; i != num_utt; ++i) { + const int32_t *this_row = p_sup + i * stride; + int32_t utt_index = this_row[0]; + int32_t start_frame = this_row[1]; + int32_t duration = this_row[2]; + + K2_CHECK_GE(utt_index, 0); + K2_CHECK_LT(utt_index, N); + + K2_CHECK_GE(start_frame, 0); + K2_CHECK_LT(start_frame, T); + + K2_CHECK_GE(duration, 0); + K2_CHECK_LE(start_frame + duration, T + allow_truncate); + + int32_t end_frame = std::min(start_frame + duration, T); // exclusive + duration = end_frame - start_frame; + + int32_t offset = utt_index * T; + std::vector this_utt_frames(duration); + std::iota(this_utt_frames.begin(), this_utt_frames.end(), + start_frame + offset); + linear_indexes.insert(linear_indexes.end(), this_utt_frames.begin(), + this_utt_frames.end()); + + // a placeholder for the extra frame that will be set to + // [ 0, -inf, -inf, ...] + linear_indexes.push_back(0); + + duration_in_total += duration; + extra_frame_indexes.push_back(duration_in_total); + duration_in_total += 1; // plus one for the extra frame + } + + torch::Tensor indexes = + torch::from_blob( + linear_indexes.data(), /*sizes*/ {int64_t(linear_indexes.size())}, + /*options*/ torch::device(torch::kCPU).dtype(torch::kLong)) + .to(log_probs.device()); + + torch::Tensor extra_frame_indexes_tensor = + torch::from_blob( + extra_frame_indexes.data(), + /*sizes*/ {int64_t(extra_frame_indexes.size())}, + /*options*/ torch::device(torch::kCPU).dtype(torch::kLong)) + .to(log_probs.device()); + + torch::Tensor scores = + torch::empty({duration_in_total, C + 1}, log_probs.options()); + + using namespace torch::indexing; // NOLINT + // scores[:, 1:] = log_probs.reshape(-1, C).index_select(0, indexes) + scores.index({"...", Slice(1, None, None)}) = + log_probs.reshape({-1, C}).index_select(0, indexes); + + // now set extra frames to [0, -inf, -inf. -inf, ... ] + // + // `scores` contains -infinity in certain locations: in scores[j,0] where + // j is not the last row-index for a given FSA-index, and scores[j,k] + // where j is the last row-index for a given FSA-index and k>0. + // The remaining locations contain the neural net output, except + // scores[j,0] where j is the last row-index for a given FSA-index; + // this contains zero. + // + // scores[:, 0] = float('-inf'); + // scores[last_frame_indexes] = torch.tensor([0] + [float('-inf')] * C, + // device=device); + constexpr float kNegInf = -1.0f * std::numeric_limits::infinity(); + scores.index({"...", 0}) = kNegInf; + std::vector tmp(C + 1); + tmp[0] = 0; + std::fill_n(tmp.begin() + 1, C, kNegInf); + torch::Tensor extra_frame = + torch::from_blob( + tmp.data(), /*sizes*/ {int64_t(tmp.size())}, + /*options*/ torch::device(torch::kCPU).dtype(torch::kFloat32)) + .to(log_probs.device()); + + scores.index_put_({extra_frame_indexes_tensor}, extra_frame); + + // Now compute row splits so that we can create a ragged shape + std::vector row_splits(num_utt + 1); + row_splits[0] = 0; + std::transform(extra_frame_indexes.begin(), extra_frame_indexes.end(), + &row_splits[1], [](int64_t i) { return i + 1; }); + + ContextPtr ctx = ContextFromTensor(log_probs); + Array1 row_splits_array(ctx, row_splits); + Array2 scores_array = Array2FromTorch(scores); + RaggedShape shape = + RaggedShape2(&row_splits_array, nullptr, row_splits.back()); + + return {shape, scores_array}; +} + +torch::Tensor GetSupervisionSegments(torch::IValue supervisions, + int32_t subsampling_factor) { + torch::Dict dict = supervisions.toGenericDict(); + torch::Tensor sequence_idx = dict.at("sequence_idx").toTensor(); + torch::Tensor start_frame = torch::floor_divide( + dict.at("start_frame").toTensor(), subsampling_factor); + + torch::Tensor num_frames = + torch::floor_divide(dict.at("num_frames").toTensor(), subsampling_factor); + + torch::Tensor supervision_segments = + torch::stack({sequence_idx, start_frame, num_frames}, 1).to(torch::kCPU); + return supervision_segments; +} + +} // namespace k2 diff --git a/k2/torch/csrc/dense_fsa_vec.h b/k2/torch/csrc/dense_fsa_vec.h new file mode 100644 index 000000000..5b78d98bd --- /dev/null +++ b/k2/torch/csrc/dense_fsa_vec.h @@ -0,0 +1,86 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_TORCH_CSRC_DENSE_FSA_VEC_H_ +#define K2_TORCH_CSRC_DENSE_FSA_VEC_H_ + +#include "k2/csrc/fsa.h" +#include "torch/script.h" + +namespace k2 { + +/** Construct a DenseFsaVec from neural net log-softmax outputs. + + @params log_probs A 3-D tensor of dtype torch.float32. It has shape + (N, T, C), where `N` is the number of utterances, + `T` the maximum input length, and `C` the number of + output classes. This is usually the output of the + log-softmax layer of a neural network. + + @param supervision_segments A 2-D tensor of dtype torch.int32 with 3 columns. + It has be to on CPU. + Each row contains information for a supervision segment. Column 0 + is the `utterance_index` indicating which utterance this segment + comes from; column 1 specifies the `start_frame` of this segment + within the utterance; column 2 contains the `duration` of this + segment (in number of frames). + + Note: + - `0 < start_frame + duration <= T + allow_truncate` + - `0 <= start_frame < T` + - `duration > 0` + + Caution: + If the resulting dense fsa vec is used as an input to + `k2::IntersectDense`, then the last column, i.e., the duration + column, has to be sorted in **decreasing** order. + That is, the first supervision_segment (the first row) has the + largest duration. + Otherwise, you don't need to sort the last column. + + `k2::IntersectDense` is often used in the training stage, so + you should usually sort dense fsa vecs by its duration + in training. `k2::IntersectDensePruned` is usually used in the + decoding stage, so you don't need to sort dense fsa vecs in + decoding. + + @param allow_truncate If not zero, it truncates at most this number of frames + from `duration` in case `start_frame + duration > T`. + + @param Return a DenseFsaVec. + */ +DenseFsaVec CreateDenseFsaVec(torch::Tensor log_probs, + torch::Tensor supervision_segments, + int32_t allow_truncate = 0); + +// See +// https://github.com/lhotse-speech/lhotse/blob/master/lhotse/dataset/speech_recognition.py#L32 +// for the format of "supervisions" +// +// @param supervisions A dict containing keys and values shown in the following: +// - sequence_idx: torch.Tensor +// - start_frame: torch.Tensor +// - num_frames: torch.Tensor +// @return Return a 2-D torch.int32 tensor that can be used to construct a +// DenseFsaVec. See `k2::CreateDenseFsaVec()` +torch::Tensor GetSupervisionSegments(torch::IValue supervisions, + int32_t subsampling_factor); + +} // namespace k2 + +#endif // K2_TORCH_CSRC_DENSE_FSA_VEC_H_ diff --git a/k2/torch/csrc/dense_fsa_vec_test.cu b/k2/torch/csrc/dense_fsa_vec_test.cu new file mode 100644 index 000000000..e3fac014c --- /dev/null +++ b/k2/torch/csrc/dense_fsa_vec_test.cu @@ -0,0 +1,161 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#ifdef K2_WITH_CUDA +#include "c10/cuda/CUDAFunctions.h" +#endif +#include "gtest/gtest.h" +#include "k2/torch/csrc/dense_fsa_vec.h" +#include "k2/torch/csrc/utils.h" + +namespace k2 { + +TEST(CreateDenseFsaVec, AllowTruncate_0) { + std::vector device_types = {torch::kCPU}; +#ifdef K2_WITH_CUDA + if (torch::cuda::device_count()) { + device_types.push_back(torch::kCUDA); + } +#endif + // clang-format off + std::vector v = { + // utterance 0, 3 frames + 0, 1, 2, 3.5, 5, + 0.5, 2, 3, 10, -1, + 3, 6, 9, 1, 8, + + // utterance 1, 2 frames + 1, 3, 5, -2, 0, + 0, 2, 1, 3, 10, + 0, 1, 3, 8, 7, + + // utterance 2, 1 frame + 1, 5, 9, 10, 12, + 13, 6, 8, 7, 9, + 0, -1, 3, 8, 7, + }; + + std::vector sup = { + // utterance 0 + 0, 0, 3, + // utterance 2 + 2, 0, 1, + // utterance 1 + 1, 1, 2, + }; + // clang-format on + for (auto device_type : device_types) { + torch::Device device(device_type, 0); + torch::Tensor log_probs = + torch::from_blob(v.data(), {3, 3, 5}, + torch::device(torch::kCPU).dtype(torch::kFloat32)) + .to(device); + + torch::Tensor supervision_segments = torch::from_blob( + sup.data(), {3, 3}, torch::device(torch::kCPU).dtype(torch::kInt)); + + DenseFsaVec dense_fsa_vec = + CreateDenseFsaVec(log_probs, supervision_segments, 0); + + ContextPtr ctx = ContextFromDevice(device); + RaggedShape expected_shape = + RaggedShape("[[x x x x] [x x] [x x x]]").To(ctx); + EXPECT_TRUE(Equal(dense_fsa_vec.shape, expected_shape)); + + Array2 expected_scores(R"( + [[-inf 0 1 2 3.5 5] + [-inf 0.5 2 3 10 -1] + [-inf 3 6 9 1 8] + [0 -inf -inf -inf -inf -inf] + + [-inf 1 5 9 10 12] + [0 -inf -inf -inf -inf -inf] + + [-inf 0 2 1 3 10] + [-inf 0 1 3 8 7] + [0 -inf -inf -inf -inf -inf] + ])"); + expected_scores = expected_scores.To(ctx); + + EXPECT_TRUE(Equal(expected_scores, dense_fsa_vec.scores)); + } +} + +TEST(CreateDenseFsaVec, AllowTruncate_1) { + std::vector device_types = {torch::kCPU}; +#ifdef K2_WITH_CUDA + if (torch::cuda::device_count()) { + device_types.push_back(torch::kCUDA); + } +#endif + // clang-format off + std::vector v = { + // utterance 0, 3 frames + -1, 2, 3, 4, + 8, 9, 6, 5.5, + 2, 3, 4, 5, + // utterance 1, 1 frame + -2, -1, 3, 4, + 2, 3, 0, 8, + 8, 9, 0, 9.8, + }; + + std::vector sup = { + // utterance 1 + 1, 2, 2, + // utterance 0 + 0, 0, 5, + }; + + // clang-format on + for (auto device_type : device_types) { + torch::Device device(device_type, 0); + torch::Tensor log_probs = + torch::from_blob(v.data(), {2, 3, 4}, + torch::device(torch::kCPU).dtype(torch::kFloat32)) + .to(device); + + torch::Tensor supervision_segments = torch::from_blob( + sup.data(), {2, 3}, torch::device(torch::kCPU).dtype(torch::kInt)); + + DenseFsaVec dense_fsa_vec = + CreateDenseFsaVec(log_probs, supervision_segments, 2); + + ContextPtr ctx = ContextFromDevice(device); + RaggedShape expected_shape = RaggedShape("[[x x] [x x x x]]").To(ctx); + EXPECT_TRUE(Equal(dense_fsa_vec.shape, expected_shape)); + + Array2 expected_scores(R"( + [[-inf 8 9 0 9.8] + [0 -inf -inf -inf -inf] + + [-inf -1 2 3 4] + [-inf 8 9 6 5.5] + [-inf 2 3 4 5] + [0 -inf -inf -inf -inf] + ])"); + + expected_scores = expected_scores.To(ctx); + + EXPECT_TRUE(Equal(expected_scores, dense_fsa_vec.scores)); + } +} + +} // namespace k2 diff --git a/k2/torch/csrc/deserialization.cu b/k2/torch/csrc/deserialization.cu new file mode 100644 index 000000000..d571083a9 --- /dev/null +++ b/k2/torch/csrc/deserialization.cu @@ -0,0 +1,455 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // NOLINT +#include +#include +#include + +#include "caffe2/serialize/file_adapter.h" +#include "caffe2/serialize/inline_container.h" +#include "k2/csrc/fsa.h" +#include "k2/csrc/ragged.h" +#include "k2/torch/csrc/deserialization.h" +#include "k2/torch/csrc/utils.h" +#include "torch/csrc/jit/serialization/import_source.h" +#if K2_TORCH_VERSION_MAJOR > 1 || \ + (K2_TORCH_VERSION_MAJOR == 1 && K2_TORCH_VERSION_MINOR >= 9) +// for torch::jit::readArchiveAndTensors +#include "torch/csrc/jit/serialization/import_read.h" +#endif + +namespace k2 { + +// A helper class to construct a Ragged from an archive +struct RaggedIntHelper : public Ragged, + public torch::CustomClassHolder { + using k2::Ragged::Ragged; + explicit RaggedIntHelper(const Ragged &ragged) + : Ragged(ragged) {} +}; + +/** Whether the torch IValue contains a Ragged instance. + + @param value The given torch IValue. + @return Return true if the given value contains a Ragged instance, + otherwise false. + */ +static bool IsRaggedInt(torch::IValue value) { + return value.type() == + torch::getCustomClassType>(); +} + +/// Convert an IValue to a Ragged +/// It is not static as it's used in deserialization_test.cu +/*static*/ Ragged ToRaggedInt(torch::IValue value) { + auto ragged_int_holder = value.toCustomClass(); + return *ragged_int_holder; +} + +static void RegisterRaggedInt(); + +struct RaggedRegister { + RaggedRegister() { RegisterRaggedInt(); } +}; + +// Register Ragged as a custom class of torch, so that we can wrap +// it to torch IValue and do serialization & deserialization thing. +static RaggedRegister ragged_register; + +namespace { + +// copied & modified from torch/csrc/jit/serialization/unpickler.cpp +void restoreAccurateTypeTags(const torch::IValue &root, + const torch::jit::TypePtr &type_tag) { + struct Work { + torch::jit::TypePtr static_type; + torch::IValue value; + }; + std::vector to_process = {{type_tag, root}}; + std::unordered_set scanned; + while (!to_process.empty()) { + Work w = std::move(to_process.back()); + to_process.pop_back(); + // ensure we only scan each pointer value once, otherwise this + // can become exponential (and if we allow recursive data in the future, + // it would not terminiate). + if (w.value.isPtrType()) { + const void *key = w.value.internalToPointer(); + auto it = scanned.find(key); + if (it != scanned.end()) { + continue; + } + scanned.emplace_hint(it, key); + } + switch (w.static_type->kind()) { + case torch::jit::TensorType::Kind: + case torch::jit::NumberType::Kind: + case torch::jit::FloatType::Kind: + case torch::jit::IntType::Kind: + case torch::jit::NoneType::Kind: + case torch::jit::GeneratorType::Kind: + case torch::jit::BoolType::Kind: + case torch::jit::VarType::Kind: + case torch::jit::CapsuleType::Kind: + case torch::jit::PyObjectType::Kind: + case torch::jit::StringType::Kind: + case torch::jit::FunctionType::Kind: + case torch::jit::DeviceObjType::Kind: + case torch::jit::QSchemeType::Kind: + case torch::jit::LayoutType::Kind: + case torch::jit::ScalarTypeType::Kind: + case torch::jit::RRefType::Kind: + case torch::jit::AnyType::Kind: + case torch::jit::AnyListType::Kind: + case torch::jit::AnyTupleType::Kind: + case torch::jit::AnyClassType::Kind: +#if K2_TORCH_VERSION_MAJOR > 1 || \ + (K2_TORCH_VERSION_MAJOR == 1 && K2_TORCH_VERSION_MINOR >= 7) + case torch::jit::AnyEnumType::Kind: + case torch::jit::QuantizerType::Kind: +#endif + // no op, there is nothing to tag + break; +#if K2_TORCH_VERSION_MAJOR > 1 || \ + (K2_TORCH_VERSION_MAJOR == 1 && K2_TORCH_VERSION_MINOR >= 7) + case torch::jit::EnumType::Kind: + // TODO(gmagogsfm): Implement serialization/deserialization of Enum. + AT_ASSERT(false); +#endif + case torch::jit::TupleType::Kind: { + auto t = w.value.toTuple(); + auto ttype = w.static_type->expect(); + for (size_t i = 0; i < ttype->containedTypes().size(); ++i) { + Work elem = {ttype->containedTypes().at(i), t->elements().at(i)}; + to_process.emplace_back(std::move(elem)); + } + } break; + case torch::jit::FutureType::Kind: { + auto f = w.value.toFuture(); + auto t = w.static_type->expect(); + if (f->completed()) { + Work elem = {t->getElementType(), f->value()}; + to_process.emplace_back(std::move(elem)); + } + } break; + case torch::jit::OptionalType::Kind: { + if (!w.value.isNone()) { + auto t = w.static_type->expect(); + Work elem = {t->getElementType(), w.value}; + to_process.emplace_back(std::move(elem)); + } + } break; + case torch::jit::ListType::Kind: { + // specialized lists do not need their type refined, so we can exit + // early here + if (!w.value.isList()) { + break; + } + auto elem_type = + w.static_type->cast()->getElementType(); + auto lst = w.value.toList(); + lst.unsafeSetElementType(elem_type); + for (const torch::IValue &item : lst) { + Work elem = {elem_type, item}; + to_process.emplace_back(std::move(elem)); + } + } break; + case torch::jit::DictType::Kind: { + auto dt = w.static_type->cast(); + auto d = w.value.toGenericDict(); + d.unsafeSetKeyType(dt->getKeyType()); + d.unsafeSetValueType(dt->getValueType()); + for (const auto &item : d) { + Work kelem = {dt->getKeyType(), item.key()}; + Work velem = {dt->getValueType(), item.value()}; + to_process.emplace_back(std::move(kelem)); + to_process.emplace_back(std::move(velem)); + } + } break; + // in both cases the dynamic type is a class, and we are going to tag with + // the dynamic type + case torch::jit::InterfaceType::Kind: + case torch::jit::ClassType::Kind: { + auto obj = w.value.toObject(); + auto typ = obj->type(); // note: intentionally using the dynamic type, + // the static type is potentially less accurate + for (size_t i = 0; i < typ->numAttributes(); ++i) { + Work elem = {typ->getAttribute(i), obj->getSlot(i)}; + to_process.emplace_back(std::move(elem)); + } + } + } + } +} + +// modified from torch/csrc/jit/serialization/pickler.cpp +bool checkHasValidSetGetState(const std::shared_ptr &cls) { + // Check that the schemas for __getstate__ and __setstate__ are correct + auto getstate = cls->findMethod("__getstate__"); + if (getstate == nullptr) { + return false; + } + auto get_schema = getstate->getSchema(); + + // Check __getstate__ + // __getstate__ is expected to be (self) -> T + K2_CHECK_EQ(get_schema.arguments().size(), 1) + << "'__getstate__' must have 'self' as its only argument, but found " + << get_schema.arguments().size() << " arguments"; + + K2_CHECK_EQ(get_schema.returns().size(), 1) + << "'__getstate__' must return 1 value, but found " + << get_schema.returns().size(); + + // Check __setstate__ if the method exists + // __setstate__ is expected to be (self, T) -> None + auto setstate = cls->findMethod("__setstate__"); + if (!setstate) { + return false; + } + auto set_schema = setstate->getSchema(); + + K2_CHECK_EQ(set_schema.arguments().size(), 2) + << "'__setstate__' must have 'self' and the state as its " + "only arguments, but found " + << set_schema.arguments().size() << " arguments"; + + K2_CHECK_EQ(set_schema.returns().size(), 1) + << "'__setstate__' must return None, but found " + << set_schema.returns().size() << " return values"; + + K2_CHECK(set_schema.returns().at(0).type()->isSubtypeOf( + torch::jit::NoneType::get())) + << "'__setstate__' must return None, but found value of type " + << set_schema.returns().at(0).type()->annotation_str(); + + // Check that the return type of __getstate__ matches the input to + // __setstate__ + auto get_type = get_schema.returns().at(0).type(); + auto set_type = set_schema.arguments().at(1).type(); + + K2_CHECK(get_type->isSubtypeOf(set_type)) + << "'__getstate__'s return type (" << get_type->annotation_str() + << ") does not match '__setstate__'s argument type (" + << set_type->annotation_str() << ")"; + + return true; +} + +// modified from torch/csrc/jit/serialization/import.cpp +// The code style in this function is also kept. +void postSetStateValidate(const torch::IValue &v) { + auto obj = v.toObject(); + const auto &objType = obj->type(); + for (size_t i = 0; i < objType->numAttributes(); i++) { + const auto &attrType = objType->getAttribute(i); + const auto &attrName = objType->getAttributeName(i); + const auto &slot = obj->getSlot(i); + // const auto attrType = objType->getAttribute(i); + // Verify that all the non-optional attributes have been initialized + // TODO: Issue #20497 + if (attrType->kind() != torch::jit::TypeKind::OptionalType) { + K2_CHECK(!slot.isNone()) + << "The field '" << attrName + << "' was left uninitialized after '__setstate__'," + "but expected a value of type '" + << attrType->repr_str() << "'"; + } + } +} + +} // namespace + +static void RegisterRaggedInt() { + // Register a custom class so that PyTorch knows how to parse + // the value from the archive. + // + // TODO: to support other types other than Ragged + torch::class_("_k2", "RaggedTensor") + .def_pickle( + // __getstate__ + [](const c10::intrusive_ptr &self) { + std::vector v; + c10::intrusive_ptr ans = + torch::ivalue::Tuple::create(v); + return torch::IValue(ans); + }, + // __setstate__ + [](torch::IValue states) { + K2_CHECK(states.isTuple()); + auto tuple = states.toTuple(); + auto &elements = tuple->elements(); + K2_CHECK(elements.size() == 3u || elements.size() == 5u) + << "actual size: " << elements.size(); + + // TODO: handle the case when size is 5 + K2_CHECK_EQ(elements.size(), 3u); + + k2::Array1 row_splits = + k2::Array1FromTorch(elements[0].toTensor()); + k2::Array1 values = + k2::Array1FromTorch(elements[2].toTensor()); + K2_CHECK_EQ(elements[1].toStringRef(), "row_ids1"); + + k2::RaggedShape shape = + k2::RaggedShape2(&row_splits, nullptr, values.Dim()); + + return c10::make_intrusive(shape, values); + }); + + // the default namespace for custom classes is __torch__.torch.classes + // but `RaggedTensor` is serialized to the namespace _k2.ragged, + // so we need to change it to `_k2.ragged` + torch::ClassTypePtr p = + torch::getCustomClassType>(); + const_cast(p->name().value()) = + torch::QualifiedName("_k2.ragged.RaggedTensor"); + // We need to re-register the class type since we changed its name. + torch::registerCustomClass(p); +} + +// This function is modified from torch::jit::load() +// See torch/csrc/jit/serialization/import.cpp +// +torch::IValue Load( + const std::string &filename, + torch::optional map_location /*= torch::nullopt*/) { + auto rai = std::make_unique(filename); + + // Verify that we're loading a zip archive and not a torch.save pickle archive + // (marked by the 0x80 0x02 bytes at the start) + // i.e., _use_new_zipfile_serialization is False when torch.save was invoked + uint8_t first_short[2]; + rai->read( + /*pos=*/0, + /*buf=*/&first_short, + /*n=*/2, + /*what=*/"checking archive"); + if (first_short[0] == 0x80 && first_short[1] == 0x02) { + // NB: zip files by spec can start with any data, so technically they might + // start with 0x80 0x02, but in practice zip files start with a file entry + // which begins with 0x04034b50. Furthermore, PyTorch will never produce zip + // files that do not start with the file entry, so it is relatively safe to + // perform this check. + K2_LOG(FATAL) << "Please set _use_new_zipfile_serialization to True " + "when invoking torch.save()"; + } + + auto reader = torch::make_unique( + std::move(rai)); + + auto cu = std::make_shared(); + torch::jit::SourceImporter source_importer(cu, nullptr, nullptr, + reader->version()); + + auto type_resolver = [&](const c10::QualifiedName &qn) { + auto cls = source_importer.loadType(qn); + return c10::StrongTypePtr(cu, std::move(cls)); + }; + + // Decouple how to get obj from type. + // For bytecode import we need to decouple these dependencies. + auto obj_loader = [&](at::StrongTypePtr type, torch::IValue input) { + auto cls = type.type_->expect(); + auto qn = cls->name(); + size_t n = cls->numAttributes(); + if (checkHasValidSetGetState(cls)) { + auto obj = c10::ivalue::Object::create(type, n); + // XXX: Do not optimize __setstate__, so that we don't try to + // specialize the class before it is initialized. + torch::jit::GraphOptimizerEnabledGuard guard(false); + torch::jit::Function &set_state = cls->getMethod("__setstate__"); + // since we are in the middle of unpickling we might still have lists and + // dicts that do not have accurate tags (e.g. they report they are + // List[Any]). But we need to run __setstate__ which will check the input + // type and may access the tags. Since setstate has a known input type, we + // can correctly restore the tags now by apply the input type of set_state + // to the state object being passed. + // TODO: Remove once [serialization type tags] is landed + restoreAccurateTypeTags(input, + set_state.getSchema().arguments().at(1).type()); + set_state({obj, input}); + postSetStateValidate(obj); + return obj; + } else { + auto dict = std::move(input).toGenericDict(); + auto obj = c10::ivalue::Object::create(type, n); + for (size_t i = 0; i < n; ++i) { + obj->setSlot(i, dict.at(cls->getAttributeName(i))); + } + return obj; + } + }; + +#if K2_TORCH_VERSION_MAJOR > 1 || \ + (K2_TORCH_VERSION_MAJOR == 1 && K2_TORCH_VERSION_MINOR >= 9) + torch::IValue ivalue = torch::jit::readArchiveAndTensors( + "data", "", "", type_resolver, obj_loader, + /*device=*/map_location, *reader); + +#else + torch::IValue ivalue = + torch::jit::readArchiveAndTensors("data", type_resolver, obj_loader, + /*device=*/map_location, *reader); +#endif + return ivalue; +} + +k2::FsaClass LoadFsa( + const std::string &filename, + torch::optional map_location /*= torch::nullopt*/) { + auto ivalue = Load(filename, map_location); + K2_CHECK(ivalue.isGenericDict()) + << "Expect a dict. Given: " << ivalue.tagKind(); + + torch::Dict dict = ivalue.toGenericDict(); + K2_CHECK(dict.contains("arcs")) << "Expect to contain 'arcs' in the dict"; + + Tensor arcs = TensorFromTorch(dict.at("arcs").toTensor()); + + bool error = false; + Fsa fsa; + if (arcs.NumAxes() == 2) { + fsa = FsaFromTensor(arcs, &error); + } else if (arcs.NumAxes() == 1) { + fsa = FsaVecFromTensor(arcs, &error); + } + K2_CHECK_EQ(error, false); + + FsaClass ans(fsa); + + (void)dict.erase(torch::IValue("arcs")); + for (const auto &p : dict) { + const auto &name = p.key().toStringRef(); + auto v = p.value(); + if (v.isTensor()) { + ans.SetTensorAttr(name, v.toTensor()); + } else if (IsRaggedInt(v)) { + ans.SetRaggedTensorAttr(name, ToRaggedInt(v)); + } else { + K2_LOG(WARNING) << "Ignore non tensor attribute: '" << name + << "' of type: " << v.tagKind(); + } + } + + return ans; +} + +} // namespace k2 diff --git a/k2/torch/csrc/deserialization.h b/k2/torch/csrc/deserialization.h new file mode 100644 index 000000000..eb1d8a853 --- /dev/null +++ b/k2/torch/csrc/deserialization.h @@ -0,0 +1,68 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_TORCH_CSRC_DESERIALIZATION_H_ +#define K2_TORCH_CSRC_DESERIALIZATION_H_ + +#include + +#include "k2/csrc/fsa.h" +#include "k2/torch/csrc/fsa_class.h" +#include "torch/script.h" + +namespace k2 { + +/** Read a file saved in Python by `torch.save()`. + + Unlike torch::jit::pickle_load(), this function can also handle + k2.ragged.RaggedTensor. + + Caution: If you save a dict of tensors in `filename`, the dict MUST + have at least two items. Otherwise, it will throw. See + https://github.com/pytorch/pytorch/issues/67902 for more details. + + @param filename Path to the file to be loaded. + @param map_location It has the same meaning as the one in `torch.load()`. + The loaded IValue is moved to this device + before returning. + @return Return an IValue containing the content in the given file. + */ +torch::IValue Load( + const std::string &filename, + torch::optional map_location = torch::nullopt); + +/** + Load a file saved in Python by + + torch.save(fsa.as_dict(), filename, _use_new_zipfile_serialization=True) + + Note: `_use_new_zipfile_serialization` is True by default + + @param filename Path to the filename produced in Python by `torch.save()`. + @param map_location It has the same meaning as the one in `torch.load()`. + The loaded FSA is moved to this device + before returning. + @return Return the FSA contained in the filename. + */ +k2::FsaClass LoadFsa( + const std::string &filename, + torch::optional map_location = torch::nullopt); + +} // namespace k2 + +#endif // K2_TORCH_CSRC_DESERIALIZATION_H_ diff --git a/k2/torch/csrc/deserialization_test.cu b/k2/torch/csrc/deserialization_test.cu new file mode 100644 index 000000000..5e05eff92 --- /dev/null +++ b/k2/torch/csrc/deserialization_test.cu @@ -0,0 +1,230 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gtest/gtest.h" +#include "k2/torch/csrc/deserialization.h" +#include "k2/torch/csrc/test_deserialization_data.h" + +#ifdef K2_WITH_CUDA +#include "torch/cuda.h" +#endif + +namespace k2 { + +// defined in k2/torch/csrc/deserialization.cu +Ragged ToRaggedInt(torch::IValue value); + +static void TestDictOfTensorIntStr(const std::string &dir_name) { + std::string filename = dir_name + "/d1.pt"; + { + std::ofstream os(filename, std::ofstream::binary); + os.write(reinterpret_cast(kTestLoadData1), + sizeof(kTestLoadData1)); + } + // d1.pt contains + // {"a": torch.tensor([1., 2.]), "b": 10, "c": "k2"} + torch::IValue ivalue = Load(filename); + EXPECT_TRUE(ivalue.isGenericDict()); + torch::Dict dict = ivalue.toGenericDict(); + + EXPECT_TRUE(dict.contains("a")); + EXPECT_TRUE(dict.contains("b")); + EXPECT_TRUE(dict.contains("c")); + + torch::Tensor a = dict.at("a").toTensor(); + int32_t b = dict.at("b").toInt(); + const std::string &c = dict.at("c").toStringRef(); + + EXPECT_TRUE(a.allclose(torch::tensor({1, 2}, a.options()))); + EXPECT_EQ(b, 10); + EXPECT_EQ(c, "k2"); + + int32_t ret = remove(filename.c_str()); + assert(ret == 0); +} + +static void TestDictOfTensorAndRaggedTensor(const std::string &dir_name) { + std::string filename = dir_name + "/d2.pt"; + { + std::ofstream os(filename, std::ofstream::binary); + os.write(reinterpret_cast(kTestLoadData2), + sizeof(kTestLoadData2)); + } + // d2.pt contains + // {"a": torch.tensor([1.0, 2.0]), "b": k2.RaggedTensor([[1.5, 2], [3], []])} + torch::IValue ivalue = Load(filename); + EXPECT_TRUE(ivalue.isGenericDict()); + torch::Dict dict = ivalue.toGenericDict(); + + EXPECT_TRUE(dict.contains("a")); + EXPECT_TRUE(dict.contains("b")); + + torch::Tensor a = dict.at("a").toTensor(); + Ragged b = ToRaggedInt(dict.at("b")); + + EXPECT_TRUE(a.allclose(torch::tensor({1, 2}, a.options()))); + EXPECT_TRUE(Equal(b, Ragged("[[15 2] [3] []]"))); + + int32_t ret = remove(filename.c_str()); + assert(ret == 0); +} + +#ifdef K2_WITH_CUDA +static void TestDictOfCudaTensorAndCudaRaggedTensor( + const std::string &dir_name) { + std::string filename = dir_name + "/d3.pt"; + { + std::ofstream os(filename, std::ofstream::binary); + os.write(reinterpret_cast(kTestLoadData3), + sizeof(kTestLoadData3)); + } + // d3.pt contains: + // { + // "a": torch.tensor([1, 2], device=torch.device("cuda:0")), + // "b": k2.RaggedTensor([[15, 2], [3], []], device="cuda:0"), + // } + torch::IValue ivalue = Load(filename); + EXPECT_TRUE(ivalue.isGenericDict()); + torch::Dict dict = ivalue.toGenericDict(); + + EXPECT_TRUE(dict.contains("a")); + EXPECT_TRUE(dict.contains("b")); + + torch::Tensor a = dict.at("a").toTensor(); + Ragged b = ToRaggedInt(dict.at("b")); + + EXPECT_TRUE(a.is_cuda()); + + EXPECT_TRUE(a.allclose(torch::tensor({1, 2}, a.options()))); + EXPECT_TRUE(Equal(b, Ragged("[[15 2] [3] []]").To(b.Context()))); + + int32_t ret = remove(filename.c_str()); + assert(ret == 0); +} + +static void TestLoadFsaCuda(const std::string &dir_name) { + std::string filename = dir_name + "/d4.pt"; + { + std::ofstream os(filename, std::ofstream::binary); + os.write(reinterpret_cast(kTestLoadData4), + sizeof(kTestLoadData4)); + } + // d4.pt contains: + // { 'arcs': tensor([[0,1,-1,1036831949]],device='cuda:0',dtype=torch.int32), + // 'aux_labels': RaggedTensor([[1, 2]], device='cuda:0', dtype=torch.int32), + // 'attr': tensor([1.5000], device='cuda:0') + // } + FsaClass fsa = LoadFsa(filename); + auto device = DeviceFromContext(fsa.fsa.Context()); + EXPECT_EQ(device, torch::Device("cuda:0")); + + auto attr = fsa.GetTensorAttr("attr"); + EXPECT_TRUE(attr.allclose(torch::tensor({1.5}, attr.options()))); + EXPECT_TRUE(Equal(fsa.GetRaggedTensorAttr("aux_labels"), + Ragged("[[1 2]]").To(fsa.fsa.Context()))); + + int32_t ret = remove(filename.c_str()); + assert(ret == 0); +} + +#endif + +static void TestDictOfTensorAndRaggedTensorMapToCpu( + const std::string &dir_name) { + std::string filename = dir_name + "/d3.pt"; + { + std::ofstream os(filename, std::ofstream::binary); + os.write(reinterpret_cast(kTestLoadData3), + sizeof(kTestLoadData3)); + } + // d3.pt contains: + // { + // "a": torch.tensor([1, 2], device=torch.device("cuda:0")), + // "b": k2.RaggedTensor([[15, 2], [3], []], device="cuda:0"), + // } + torch::IValue ivalue = Load(filename, /*map_location*/ torch::kCPU); + EXPECT_TRUE(ivalue.isGenericDict()); + torch::Dict dict = ivalue.toGenericDict(); + + EXPECT_TRUE(dict.contains("a")); + EXPECT_TRUE(dict.contains("b")); + + torch::Tensor a = dict.at("a").toTensor(); + Ragged b = ToRaggedInt(dict.at("b")); + EXPECT_FALSE(a.is_cuda()); + + EXPECT_TRUE(a.allclose(torch::tensor({1, 2}, a.options()))); + EXPECT_TRUE(Equal(b, Ragged("[[15 2] [3] []]"))); + + int32_t ret = remove(filename.c_str()); + assert(ret == 0); +} + +static void TestLoadFsaMapToCpu(const std::string &dir_name) { + std::string filename = dir_name + "/d4.pt"; + { + std::ofstream os(filename, std::ofstream::binary); + os.write(reinterpret_cast(kTestLoadData4), + sizeof(kTestLoadData4)); + } + // d4.pt contains: + // { 'arcs': tensor([[0,1,-1,1036831949]],device='cuda:0',dtype=torch.int32), + // 'aux_labels': RaggedTensor([[1, 2]], device='cuda:0', dtype=torch.int32), + // 'attr': tensor([1.5000], device='cuda:0') + // } + FsaClass fsa = LoadFsa(filename, torch::kCPU); + auto device = DeviceFromContext(fsa.fsa.Context()); + EXPECT_EQ(device, torch::Device(torch::kCPU)); + + auto attr = fsa.GetTensorAttr("attr"); + EXPECT_TRUE(attr.allclose(torch::tensor({1.5}, attr.options()))); + EXPECT_TRUE(Equal(fsa.GetRaggedTensorAttr("aux_labels"), + Ragged("[[1 2]]").To(fsa.fsa.Context()))); + + int32_t ret = remove(filename.c_str()); + assert(ret == 0); +} + +TEST(Deserialization, Test) { + char pattern[] = "/tmp/k2_test.XXXXXX"; +#ifndef _MSC_VER + char *dir_name = mkdtemp(pattern); +#else + char *dir_name = "./"; +#endif + assert(dir_name != nullptr); + + TestDictOfTensorIntStr(dir_name); + TestDictOfTensorAndRaggedTensor(dir_name); + TestDictOfTensorAndRaggedTensorMapToCpu(dir_name); + TestLoadFsaMapToCpu(dir_name); + +#ifdef K2_WITH_CUDA + if (torch::cuda::is_available()) { + TestDictOfCudaTensorAndCudaRaggedTensor(dir_name); + TestLoadFsaCuda(dir_name); + } +#endif + +#ifndef _MSC_VER + int ret = rmdir(dir_name); + assert(ret == 0); +#endif +} + +} // namespace k2 diff --git a/k2/torch/csrc/features.cc b/k2/torch/csrc/features.cc new file mode 100644 index 000000000..06cec96f1 --- /dev/null +++ b/k2/torch/csrc/features.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "k2/torch/csrc/features.h" +#include "kaldifeat/csrc/feature-fbank.h" + +namespace k2 { + +torch::Tensor ComputeFeatures(kaldifeat::Fbank &fbank, + torch::Tensor wave_data) { + return fbank.ComputeFeatures(wave_data, /*vtln_warp*/ 1.0f); +} + +std::vector ComputeFeatures( + kaldifeat::Fbank &fbank, const std::vector &wave_data, + std::vector *num_frames /*=nullptr*/) { + const auto &frame_opts = fbank.GetOptions().frame_opts; + + std::vector num_frames_vec; + num_frames_vec.reserve(wave_data.size()); + + std::vector strided_vec; + strided_vec.reserve(wave_data.size()); + + for (const auto &t : wave_data) { + torch::Tensor strided = kaldifeat::GetStrided(t, frame_opts); + num_frames_vec.push_back(strided.size(0)); + strided_vec.emplace_back(std::move(strided)); + } + + torch::Tensor strided = torch::cat(strided_vec, 0); + torch::Tensor features = fbank.ComputeFeatures(strided, /*vtln_warp*/ 1.0f); + + auto ans = features.split_with_sizes(num_frames_vec, /*dim*/ 0); + if (num_frames) *num_frames = std::move(num_frames_vec); + return ans; +} + +} // namespace k2 diff --git a/k2/torch/csrc/features.h b/k2/torch/csrc/features.h new file mode 100644 index 000000000..714603d7d --- /dev/null +++ b/k2/torch/csrc/features.h @@ -0,0 +1,49 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_TORCH_CSRC_FEATURES_H_ +#define K2_TORCH_CSRC_FEATURES_H_ + +#include + +#include "kaldifeat/csrc/feature-fbank.h" + +namespace k2 { + +/** Compute fbank features of a 1-D tensor containing audio samples. + + @param fbank The Fbank computer. + @param wave_data A 1-D tensor with dtype torch.float32. Its elements + are expected in the range [-1, 1). + @return Return a 2-D tensor containing the features. Number of + rows equals to the number of frames. + */ +torch::Tensor ComputeFeatures(kaldifeat::Fbank &fbank, torch::Tensor wave_data); + +/// See `ComputeFeatures` above. It computes fbank features for a list +/// of audio samples, in parallel. +/// +/// @params num_frames If not null, it contains the number of feature frames of +/// each wave. +std::vector ComputeFeatures( + kaldifeat::Fbank &fbank, const std::vector &wave_data, + std::vector *num_frames = nullptr); + +} // namespace k2 + +#endif // K2_TORCH_CSRC_FEATURES_H_ diff --git a/k2/torch/csrc/fsa_algo.cu b/k2/torch/csrc/fsa_algo.cu new file mode 100644 index 000000000..f6583c291 --- /dev/null +++ b/k2/torch/csrc/fsa_algo.cu @@ -0,0 +1,297 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Wei Kang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "k2/csrc/fsa_algo.h" +#include "k2/csrc/fsa_utils.h" +#include "k2/csrc/ragged_ops.h" +#include "k2/torch/csrc/fsa_algo.h" +#include "k2/torch/csrc/utils.h" + +namespace k2 { + +FsaClass CtcTopo(int32_t max_token, bool modified /*= false*/, + torch::Device device /*=torch::kCPU*/) { + Array1 aux_labels; + auto ctx = ContextFromDevice(device); + Fsa fsa = CtcTopo(ctx, max_token, modified, &aux_labels); + FsaClass dest(fsa); + dest.SetTensorAttr("aux_labels", Array1ToTorch(aux_labels)); + return dest; +} + +FsaClass TrivialGraph(int32_t max_token, + torch::Device device /*=torch::kCPU*/) { + Array1 aux_labels; + auto ctx = ContextFromDevice(device); + Fsa fsa = TrivialGraph(ctx, max_token, &aux_labels); + FsaClass dest(fsa); + dest.SetTensorAttr("aux_labels", Array1ToTorch(aux_labels)); + return dest; +} + +FsaClass IntersectDensePruned(FsaClass &graph, DenseFsaVec &dense, + float search_beam, float output_beam, + int32_t min_activate_states, + int32_t max_activate_states) { + Array1 graph_arc_map; + Array1 dense_arc_map; + FsaVec fsa; + IntersectDensePruned(graph.fsa, dense, search_beam, output_beam, + min_activate_states, max_activate_states, &fsa, + &graph_arc_map, &dense_arc_map); + FsaClass dest(fsa); + dest.CopyAttrs(graph, Array1ToTorch(graph_arc_map)); + return dest; +} + +FsaClass ShortestPath(FsaClass &lattice) { + Ragged state_batches = GetStateBatches(lattice.fsa, true); + Array1 dest_states = GetDestStates(lattice.fsa, true); + Ragged incoming_arcs = GetIncomingArcs(lattice.fsa, dest_states); + Ragged entering_arc_batches = + GetEnteringArcIndexBatches(lattice.fsa, incoming_arcs, state_batches); + + bool log_semiring = false; + Array1 entering_arcs; + GetForwardScores(lattice.fsa, state_batches, entering_arc_batches, + log_semiring, &entering_arcs); + + Ragged best_path_arc_indexes = + ShortestPath(lattice.fsa, entering_arcs); + + FsaVec out = FsaVecFromArcIndexes(lattice.fsa, best_path_arc_indexes); + torch::Tensor arc_map = Array1ToTorch(best_path_arc_indexes.values); + return FsaClass::FromUnaryFunctionTensor(lattice, out, arc_map); +} + +void Invert(FsaClass *lattice) { + K2_CHECK_NE(lattice, nullptr); + + if (lattice->HasTensorAttr("aux_labels")) { + // The invert is trivial, just swap the labels and aux_labels. + // No new arcs are added. + auto aux_labels = lattice->GetTensorAttr("aux_labels").clone(); + auto labels = lattice->Labels().clone(); + + // FixFinalLabels + auto minus_one = + torch::tensor(-1, torch::device(labels.device()).dtype(labels.dtype())); + aux_labels = torch::where(labels == -1, minus_one, aux_labels); + + lattice->SetTensorAttr("aux_labels", labels); + lattice->SetLabels(aux_labels); + } else { + K2_CHECK(lattice->HasRaggedTensorAttr("aux_labels")); + Ragged src_aux_labels = lattice->GetRaggedTensorAttr("aux_labels"); + + Fsa dest; + Ragged dest_aux_labels; + Array1 arc_map; + Invert(lattice->fsa, src_aux_labels, &dest, &dest_aux_labels, &arc_map); + + // `label` is the 3rd field of struct Arc. + FixFinalLabels(dest, reinterpret_cast(dest.values.Data()) + 2, + 4); + + lattice->DeleteRaggedTensorAttr("aux_labels"); + lattice->properties = 0; + lattice->fsa = dest; + lattice->CopyAttrs(*lattice, Array1ToTorch(arc_map)); + lattice->SetRaggedTensorAttr("aux_labels", dest_aux_labels); + } +} + +void ArcSort(FsaClass *lattice) { + Fsa dest; + Array1 arc_map; + ArcSort(lattice->fsa, &dest, &arc_map); + lattice->properties = 0; + lattice->fsa = dest; + lattice->CopyAttrs(*lattice, Array1ToTorch(arc_map)); +} + +void Connect(FsaClass *lattice) { + Fsa dest; + Array1 arc_map; + Connect(lattice->fsa, &dest, &arc_map); + lattice->properties = 0; + lattice->fsa = dest; + lattice->CopyAttrs(*lattice, Array1ToTorch(arc_map)); +} + +void TopSort(FsaClass *lattice) { + Fsa dest; + Array1 arc_map; + TopSort(lattice->fsa, &dest, &arc_map); + lattice->properties = 0; + lattice->fsa = dest; + lattice->CopyAttrs(*lattice, Array1ToTorch(arc_map)); +} + +Nbest RandomPaths(FsaClass &lattice, int32_t num_paths) { + auto &fsas = lattice.fsa; + Ragged state_batches = GetStateBatches(fsas, /*transpose*/ true); + Array1 dest_states = GetDestStates(fsas, /*as_idx01*/ true); + + Ragged incoming_arcs = GetIncomingArcs(fsas, dest_states); + + Ragged entering_arc_batches = + GetEnteringArcIndexBatches(fsas, incoming_arcs, state_batches); + + Ragged leaving_arc_batches = + GetLeavingArcIndexBatches(fsas, state_batches); + bool log_semiring = true; + + using FloatType = float; + Array1 forward_scores = GetForwardScores( + fsas, state_batches, entering_arc_batches, log_semiring, nullptr); + + Array1 backward_scores = GetBackwardScores( + fsas, state_batches, leaving_arc_batches, log_semiring); + + Array1 arc_post = + GetArcPost(fsas, forward_scores, backward_scores); + + Array1 arc_cdf = GetArcCdf(fsas, arc_post); + + Array1 tot_scores = GetTotScores(fsas, forward_scores); + + // paths has three axes [utt][path][arc_pos] + Ragged paths = + RandomPaths(fsas, arc_cdf, num_paths, tot_scores, state_batches); + + bool has_ragged_aux_labels = true; + + // word_seqs has three axes [utt][path[word_id] + Ragged word_seqs; + if (lattice.HasTensorAttr("aux_labels")) { + has_ragged_aux_labels = false; + // Index a tensor with a ragged index + // see Index() in k2/csrc/ragged_ops.h + auto &aux_labels = lattice.GetTensorAttr("aux_labels"); + Array1 aux_labels_array = Array1FromTorch(aux_labels); + word_seqs = Index(aux_labels_array, paths); + } else { + K2_CHECK(lattice.HasRaggedTensorAttr("aux_labels")); + auto &aux_labels = lattice.GetRaggedTensorAttr("aux_labels"); + // Index a ragged tensor with a ragged index + // see Index() in k2/csrc/ragged_ops.h + bool remove_axis = true; + word_seqs = Index(aux_labels, paths, remove_axis); + } + + word_seqs = RemoveValuesLeq(word_seqs, 0); + + // Each utterance has `num_paths` paths but some of them transduces + // to the same word sequence, so we need to remove repeated word + // sequences within an utterance. After removing repeats, each utterance + // contains different number of paths + // + // `new2old` maps from the output path index to the input path index. + Array1 new2old_indexes; + (void)UniqueSequences(word_seqs, nullptr, &new2old_indexes); + + // Index a ragged tensor with a tensor + // See Index() in k2/csrc/ragged_ops.h + // + // kept_paths has axes [utt][path][arc_pos] + Ragged kept_paths = Index(paths, /*axis*/ 1, new2old_indexes); + + // utt_to_path_shape has axes [utt][path] + RaggedShape utt_to_path_shape = GetLayer(kept_paths.shape, 0); + + // Remove the utterance axis. + kept_paths = kept_paths.RemoveAxis(0); + // Now kept_paths has only two axes [path][arc_pos] + + // labels has 2 axes [path][token_id] + // Note that it contains -1s. + // + // Index a tensor with a ragged index + // see Index() in k2/csrc/ragged_ops.h + auto lattice_labels = lattice.Labels(); + auto lattice_labels_array = + Array1FromTorch(lattice_labels.contiguous()); + Ragged labels = Index(lattice_labels_array, kept_paths); + + // Remove -1 from labels as we will use it to construct a linear FSA + labels = RemoveValuesEq(labels, -1); + Fsa dest = LinearFsas(labels); + FsaClass ans_lattice(dest); + if (has_ragged_aux_labels) { + auto &aux_labels = lattice.GetRaggedTensorAttr("aux_labels"); + // Index a ragged tensor with a tensor + // See Index() in k2/csrc/ragged_ops.h + Ragged ans_aux_labels = + Index(aux_labels, /*axis*/ 0, kept_paths.values); + ans_lattice.SetRaggedTensorAttr("aux_labels", ans_aux_labels); + } else { + auto &aux_labels = lattice.GetTensorAttr("aux_labels"); + Array1 aux_labels_array = Array1FromTorch(aux_labels); + // Index a tensor with a tensor index + // See Index() in k2/csrc/array_ops.h + Array1 ans_aux_labels = Index(aux_labels_array, kept_paths.values, + false, // allow_minus_one + 0); // default value + ans_lattice.SetTensorAttr("aux_labels", Array1ToTorch(ans_aux_labels)); + } + + return {ans_lattice, utt_to_path_shape}; +} + +FsaClass IntersectDevice(FsaClass &a_fsas, FsaClass &b_fsas, + const Array1 &b_to_a_map, + bool sorted_match_a) { + Array1 arc_map_a, arc_map_b; + + Fsa c_fsas = IntersectDevice(a_fsas.fsa, a_fsas.Properties(), b_fsas.fsa, + b_fsas.Properties(), b_to_a_map, &arc_map_a, + &arc_map_b, sorted_match_a); + + FsaClass ans(c_fsas); + ans.CopyAttrs(a_fsas, Array1ToTorch(arc_map_a)); + ans.CopyAttrs(b_fsas, Array1ToTorch(arc_map_b)); + return ans; +} + +FsaClass LinearFsaWithSelfLoops(FsaClass &fsas) { + RaggedShape shape; + if (fsas.fsa.NumAxes() == 2) { + // A single Fsa + auto shape0 = + RegularRaggedShape(fsas.fsa.Context(), 1, fsas.fsa.TotSize(0)); + shape = ComposeRaggedShapes(shape0, fsas.fsa.shape); + } else { + shape = fsas.fsa.shape; + } + + shape = RemoveAxis(shape, 1); // remove the state axis + + auto labels = Ragged( + shape, Array1FromTorch(fsas.Labels().contiguous())); + labels = RemoveValuesLeq(labels, 0); + + auto linear_fsa = LinearFsas(labels); + FsaVec ans; + AddEpsilonSelfLoops(linear_fsa, &ans); + + if (fsas.fsa.NumAxes() == 2) ans = ans.RemoveAxis(0); + return FsaClass(ans); +} + +} // namespace k2 diff --git a/k2/torch/csrc/fsa_algo.h b/k2/torch/csrc/fsa_algo.h new file mode 100644 index 000000000..04087e880 --- /dev/null +++ b/k2/torch/csrc/fsa_algo.h @@ -0,0 +1,181 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Wei Kang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_TORCH_CSRC_FSA_ALGO_H_ +#define K2_TORCH_CSRC_FSA_ALGO_H_ + +#include "k2/csrc/fsa.h" +#include "k2/torch/csrc/fsa_class.h" +#include "k2/torch/csrc/nbest.h" + +namespace k2 { + +/* Create a CTC topology. + + Note: + A standard CTC topology is the conventional one, where there + is a mandatory blank between two repeated neighboring symbols. + A non-standard, i.e., modified CTC topology, imposes no such constraint. + + @param max_token The maximum token ID (inclusive). We assume that token IDs + are contiguous (from 1 to `max_token`). 0 represents blank. + @param modified If False, create a standard CTC topology. Otherwise, create + a modified CTC topology. + @param device A torch.device indicating what device the returned Fsa will + be. Default torch::CPU. + @return Return either a standard or a modified CTC topology as an FSA + depending on whether `modified` is false or true. + */ +FsaClass CtcTopo(int32_t max_token, bool modified = false, + torch::Device device = torch::kCPU); + +/* + Create a trivial graph which has only two states. On state 0, there are + `max_token` self loops(i.e. a loop for each symbol from 1 to max_token), and + state 1 is the final state. + + @param [in] max_token The maximum token ID (inclusive). We assume that + token IDs are contiguous (from 1 to `max_token`). + @param device A torch.device indicating what device the returned Fsa will + be. Default torch::CPU. + @return Returns the expected trivial graph on the given device. + Note the returned graph does not contain arcs with label being 0. + */ +FsaClass TrivialGraph(int32_t max_token, torch::Device device = torch::kCPU); + +/* Intersect a DenseFsaVec constructed from nnet_output with an FsaClass, i.e., + decoding graphs. + + @param graphs Input FsaClass containing decoding graphs and the associated + attributes. The decoding graph might just be a linear + sequence of phones, or might be something more complicated. + Must have either `graph.fsa.shape[0] == dense.dim0()`, or + `graphs.fsa.shape[0] == 1` in which case the graph is shared. + @param dense Input FSAs that correspond to neural network output. + @param search_beam Decoding beam, e.g. 20. Smaller is faster, larger is + more exact (less pruning). This is the default value; it + may be modified by `min_active_states` and + `max_active_states`. + @param output_beam Pruning beam for the output of intersection (vs. best + path); equivalent to kaldi's lattice-beam. E.g. 8. + @param max_active_states Maximum number of FSA states that are allowed to + be active on any given frame for any given + intersection/composition task. This is advisory, + in that it will try not to exceed that but may not + always succeed. You can use a very large number if + no constraint is needed. + @param min_active_states Minimum number of FSA states that are allowed to + be active on any given frame for any given + intersection/composition task. This is advisory, + in that it will try not to have fewer than this + number active. Set it to zero if there is no + constraint. + @return Returns an FsaClass containing the intersection of DenseFsaVec and + decoding graphs with the attributes propagated. + */ +FsaClass IntersectDensePruned(FsaClass &graphs, DenseFsaVec &dense, + float search_beam, float output_beam, + int32_t min_activate_states, + int32_t max_activate_states); + +/* Return the shortest paths as linear FSAs from the start state + to the final state in the tropical semiring. + + Note: + It uses the opposite sign. That is, It uses `max` instead of `min`. + + @param lattice The input FsaClass. + @return An FsaClass containing the best paths as linear FSAs with the + attributes propagated. + */ +FsaClass ShortestPath(FsaClass &lattice); + + +/* + Return array of total scores (one per FSA) + @param [in] fsas Input FsaVec (must have 3 axes) + @param [in] log_semiring If true, combine path with LogAdd + (i.e., mathematically, `log(exp(a)+exp(b))`); if false, + combine as `max(a,b)`. + @return Returns array of total scores, of dimension fsas.fsa.Dim0(), + which will contain the scores in the final-states of + `forward_scores`, or -infinity for FSAs that had no + states. +*/ +template +Array1 GetTotScores(FsaClass &fsa, bool log_semiring = true); + + +/** Swap the labels and aux labels of a lattice. + + Caution: This is an in-place operation. + + @param lattice The input/output lattice. It has to have + an attribute "aux_labels". + */ +void Invert(FsaClass *lattice); + +/** Arc sort an FSA in place. + + @param lattice The input/output lattice. + */ +void ArcSort(FsaClass *lattice); + +/** Trim an FSA in place. + + @param lattice The input/output lattice. + */ +void Connect(FsaClass *lattice); + +/** TopSort an FSA in place. + + @param lattice The input/output lattice. + */ +void TopSort(FsaClass *lattice); + +/** Sample num_paths from the given lattice. + @param lattice The input lattice to be sampled from. + @param num_paths Number of paths to sample + + @return Return a nbest object containing the sampled paths. + */ +Nbest RandomPaths(FsaClass &lattice, int32_t num_paths); + +/// Wrapper for k2::IntersectDevice() in k2/csrc/fsa_algo.h +/// to support attribute propagation. +FsaClass IntersectDevice(FsaClass &a_fsas, FsaClass &b_fsas, + const Array1 &b_to_a_map, + bool sorted_match_a); + +/** Create a linear FSA with epsilon self-loops by first removing epsilon + transitions from the input linear FSA. + + @param [in] fsas An FSA or an FsaVec. It MUST be a linear FSA or a vector + of linear FSAs. + @return Return an FSA or FsaVec, where each FSA contains epsilon + self-loops but contains no epsilon transitions for arcs that are + not self-loops. + */ +FsaClass LinearFsaWithSelfLoops(FsaClass &fsas); +} // namespace k2 + +#define IS_IN_K2_TORCH_CSRC_FSA_ALGO_H_ +#include "k2/torch/csrc/fsa_algo_inl.h" +#undef IS_IN_K2_TORCH_CSRC_FSA_ALGO_H_ + +#endif // K2_TORCH_CSRC_FSA_ALGO_H_ diff --git a/k2/torch/csrc/fsa_algo_inl.h b/k2/torch/csrc/fsa_algo_inl.h new file mode 100644 index 000000000..c7c163c08 --- /dev/null +++ b/k2/torch/csrc/fsa_algo_inl.h @@ -0,0 +1,46 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Wei Kang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_TORCH_CSRC_FSA_ALGO_INL_H_ +#define K2_TORCH_CSRC_FSA_ALGO_INL_H_ + +#ifndef IS_IN_K2_TORCH_CSRC_FSA_ALGO_H_ +#error "this file is supposed to be included only by fsa_algo.h" +#endif + +#include "k2/csrc/fsa.h" +#include "k2/csrc/fsa_utils.h" +#include "k2/torch/csrc/fsa_class.h" + +namespace k2 { + +template +Array1 GetTotScores(FsaClass &fsa, bool log_semiring /* = true*/) { + Ragged state_batches = GetStateBatches(fsa.fsa, true); + Array1 dest_states = GetDestStates(fsa.fsa, true); + Ragged incoming_arcs = GetIncomingArcs(fsa.fsa, dest_states); + Ragged entering_arc_batches = + GetEnteringArcIndexBatches(fsa.fsa, incoming_arcs, state_batches); + + auto forward_scores = GetForwardScores( + fsa.fsa, state_batches, entering_arc_batches, log_semiring, nullptr); + return GetTotScores(fsa.fsa, forward_scores); +} + +} // namespace k2 +#endif // K2_TORCH_CSRC_FSA_ALGO_INL_H_ diff --git a/k2/torch/csrc/fsa_class.cu b/k2/torch/csrc/fsa_class.cu new file mode 100644 index 000000000..90c3de397 --- /dev/null +++ b/k2/torch/csrc/fsa_class.cu @@ -0,0 +1,207 @@ +/** + * @brief A wrapper around FsaOrVec + * + * @copyright + * Copyright 2021 Xiaomi Corp. (authors: Wei Kang, Fangjun Kuang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "k2/csrc/device_guard.h" +#include "k2/csrc/fsa_algo.h" +#include "k2/csrc/fsa_utils.h" +#include "k2/csrc/ragged_ops.h" +#include "k2/torch/csrc/deserialization.h" +#include "k2/torch/csrc/fsa_class.h" +#include "k2/torch/csrc/utils.h" + +namespace k2 { + +FsaClass FsaClass::FromUnaryFunctionTensor(FsaClass &src, const FsaOrVec &arcs, + torch::Tensor arc_map) { + FsaClass dest(arcs); + dest.CopyAttrs(src, arc_map); + return dest; +} + +void FsaClass::CopyAttrs(FsaClass &src, torch::Tensor arc_map) { + CopyTensorAttrs(src, arc_map); + CopyRaggedTensorAttrs(src, arc_map); +} + +void FsaClass::CopyAttrs(std::vector &srcs, + Ragged &arc_map) { + K2_CHECK_EQ(fsa.NumAxes(), 3); + CopyTensorAttrs(srcs, arc_map); + CopyRaggedTensorAttrs(srcs, arc_map); +} + +void FsaClass::CopyTensorAttrs(FsaClass &src, torch::Tensor arc_map) { + for (const auto &iter : src.tensor_attrs) { + Dtype dtype = ConvertDtype(iter.second.scalar_type()); + FOR_REAL_AND_INT32_TYPES(dtype, T, { + auto value = IndexSelect(iter.second, arc_map, 0); + SetTensorAttr(iter.first, value); + }); + } +} + +void FsaClass::CopyTensorAttrs(std::vector &srcs, + Ragged &arc_map) { + K2_CHECK_EQ(arc_map.NumAxes(), 2); + K2_CHECK_EQ(arc_map.Dim0(), static_cast(srcs.size())); + // Gather attributes info of all source fsas. + std::unordered_map attrs_info; + for (const auto &fsa : srcs) { + for (const auto &iter : fsa.tensor_attrs) { + Dtype dtype = ConvertDtype(iter.second.scalar_type()); + attrs_info.insert(std::make_pair(iter.first, dtype)); + } + } + std::vector values; + auto row_splits = arc_map.RowSplits(1).To(GetCpuContext()); + for (const auto &iter : attrs_info) { + for (int32_t i = 0; i < static_cast(srcs.size()); ++i) { + auto this_arc_map_array = + arc_map.values.Arange(row_splits[i], row_splits[i + 1]); + auto this_arc_map = Array1ToTorch(this_arc_map_array); + if (srcs[i].HasTensorAttr(iter.first)) { + auto attr = srcs[i].GetTensorAttr(iter.first); + FOR_REAL_AND_INT32_TYPES(iter.second, T, { + auto value = IndexSelect(attr, this_arc_map, 0); + values.emplace_back(value); + }); + } else { + FOR_REAL_AND_INT32_TYPES(iter.second, T, { + auto opts = torch::dtype(ConvertDtype(iter.second)) + .device(this_arc_map.device()); + auto value = torch::zeros(this_arc_map.numel(), opts); + values.emplace_back(value); + }); + } + } + SetTensorAttr(iter.first, torch::cat(values)); + } +} + +void FsaClass::CopyRaggedTensorAttrs(FsaClass &src, torch::Tensor arc_map) { + Array1 indexes_array = Array1FromTorch(arc_map); + for (auto &iter : src.ragged_tensor_attrs) { + auto value = Index(iter.second, 0, indexes_array, nullptr); + SetRaggedTensorAttr(iter.first, value); + } +} + +void FsaClass::CopyRaggedTensorAttrs(std::vector &srcs, + Ragged &arc_map) { + K2_CHECK_EQ(arc_map.NumAxes(), 2); + K2_CHECK_EQ(arc_map.Dim0(), static_cast(srcs.size())); + std::unordered_set attrs_name; + for (const auto &fsa : srcs) { + for (const auto &iter : fsa.ragged_tensor_attrs) { + attrs_name.insert(iter.first); + } + } + std::vector> values; + auto row_splits = arc_map.RowSplits(1).To(GetCpuContext()); + for (const auto &name : attrs_name) { + for (int32_t i = 0; i < static_cast(srcs.size()); ++i) { + auto this_arc_map = + arc_map.values.Arange(row_splits[i], row_splits[i + 1]); + if (srcs[i].HasRaggedTensorAttr(name)) { + auto attr = srcs[i].GetRaggedTensorAttr(name); + auto value = Index(attr, 0 /*axis*/, this_arc_map); + values.emplace_back(value); + } else { + auto empty_shape = + RegularRaggedShape(this_arc_map.Context(), this_arc_map.Dim(), 0); + auto value = Ragged(empty_shape); + values.emplace_back(value); + } + } + SetRaggedTensorAttr(name, Cat(0 /*axis*/, values.size(), values.data())); + } +} + +void FsaClass::SetScores(torch::Tensor scores) { + K2_CHECK_EQ(scores.numel(), fsa.NumElements()); + K2_CHECK_EQ(scores.scalar_type(), torch::kFloat32); + K2_CHECK(ContextFromTensor(scores)->IsCompatible(*fsa.Context())); + Scores().copy_(scores); +} + +torch::Tensor FsaClass::Scores() { + auto device = DeviceFromContext(fsa.Context()); + auto scalar_type = caffe2::TypeMeta::Make(); + + // an Arc has 4 members + static_assert(sizeof(Arc) == 4 * sizeof(int32_t), ""); + + std::vector sizes = {fsa.values.Dim(), 4}; // [num_rows, num_cols] + std::vector strides = {4, 1}; // in number of elements + auto options = torch::device(device).dtype(scalar_type); + + auto tmp_scores = torch::from_blob( + fsa.values.Data(), sizes, strides, + [saved_region = fsa.values.GetRegion()](void *) {}, options); + return tmp_scores.index({"...", -1}); +} + +int32_t FsaClass::Properties() { + if (properties == 0) { + if (fsa.NumAxes() == 2) { + properties = GetFsaBasicProperties(fsa); + } else { + GetFsaVecBasicProperties(fsa, nullptr, &properties); + } + if ((properties & kFsaPropertiesValid) != kFsaPropertiesValid) { + K2_LOG(FATAL) << "Fsa is not valid, properties are : " << properties + << " = " << FsaPropertiesAsString(properties); + } + } + return properties; +} + +torch::Tensor FsaClass::Labels() { + auto device = DeviceFromContext(fsa.Context()); + auto scalar_type = caffe2::TypeMeta::Make(); + // an Arc has 4 members + static_assert(sizeof(Arc) == 4 * sizeof(int32_t), ""); + + std::vector sizes = {fsa.values.Dim(), 4}; // [num_rows, num_cols] + std::vector strides = {4, 1}; // in number of elements + auto options = torch::device(device).dtype(scalar_type); + + torch::Tensor arcs = torch::from_blob( + fsa.values.Data(), sizes, strides, + [saved_region = fsa.values.GetRegion()](void *) {}, options); + + return arcs.index({"...", 2}); +} + +void FsaClass::SetLabels(torch::Tensor labels) { + K2_CHECK_EQ(labels.numel(), fsa.NumElements()); + K2_CHECK_EQ(labels.scalar_type(), torch::kInt32); + K2_CHECK(ContextFromTensor(labels)->IsCompatible(*fsa.Context())); + Labels().copy_(labels); + properties = 0; // Clear cached properties as we changed the labels +} + +} // namespace k2 diff --git a/k2/torch/csrc/fsa_class.h b/k2/torch/csrc/fsa_class.h new file mode 100644 index 000000000..9aa704b0e --- /dev/null +++ b/k2/torch/csrc/fsa_class.h @@ -0,0 +1,268 @@ +/** + * @brief Wrapper for k2::Fsa to support attribute propagation. + * + * @copyright + * Copyright 2021 Xiaomi Corp. (authors: Wei Kang, Fangjun Kuang) + * + * @copyright + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_TORCH_CSRC_FSA_CLASS_H_ +#define K2_TORCH_CSRC_FSA_CLASS_H_ + +#include +#include +#include +#include +#include + +#include "k2/csrc/fsa.h" +#include "k2/csrc/ragged.h" +#include "k2/torch/csrc/utils.h" +#include "torch/script.h" + +namespace k2 { + +// It is a wrapper of FsaOrVec to support attributes propagation +struct FsaClass { + // TODO(fangjun): Make it a class and set its data members to private + FsaOrVec fsa; + int32_t properties = 0; + + // TODO(fangjun): Use two arrays to represent tensor_attrs + // as there are usually only one or two attributes associated + // with an FSA in decoding. + // + /// It contains all tensor attributes of this FSA + std::unordered_map tensor_attrs; + + /// It contains all ragged tensor attributes of this FSA + std::unordered_map> ragged_tensor_attrs; + + // The default constructor initializes an invalid FSA. + FsaClass() = default; + + explicit FsaClass(const FsaOrVec &fsa) : fsa(fsa) { + // Check the validation of the fsa, will trigger a fatal error if the fsa + // is not valid. + Properties(); + } + + FsaClass(const FsaClass &) = default; + FsaClass &operator=(const FsaClass &) = default; + FsaClass(FsaClass &&) = default; + FsaClass &operator=(FsaClass &&) = default; + + /// Returns the number of attributes contained in this FSA + int32_t NumAttrs() const { + return tensor_attrs.size() + ragged_tensor_attrs.size(); + } + + /** + Create an Fsa object, including propagating properties from the source FSA. + This is intended to be called from unary functions on FSAs where the arc_map + is a Tensor of int32 (i.e. not ragged). + @param src The source Fsa, i.e. the arg to the unary function. + @param arcs The raw output of the unary function, as output by whatever C++ + algorithm we used. + @param arc_map A map from arcs in `arcs` to the corresponding arc-index in + `src`, or -1 if the arc had no source arc + (e.g. added epsilon self-loops). + */ + static FsaClass FromUnaryFunctionTensor(FsaClass &src, const FsaOrVec &arcs, + torch::Tensor arc_map); + + /* Return a 1-D torch.float32 torch tensor. + + @caution It shares the underlying memory with this FSA. + */ + torch::Tensor Scores(); + + /** Set scores, will modify scores in fsa.arcs + + @param scores A 1-D tensor of dtype torch.float32. + */ + void SetScores(torch::Tensor scores); + + /** Return a 1-D int32 torch tensor. + @caution It shares the underlying memory with this FSA. + */ + torch::Tensor Labels(); + + /** Set labels, will modify labels in fsa.arcs + + @param labels A 1-D tensor of dtype torch.int32. + */ + void SetLabels(torch::Tensor labels); + + // Get fsa properties. + int32_t Properties(); + + /// Return the given tensor attribute by its name + const torch::Tensor &GetTensorAttr(const std::string &name) const { + return tensor_attrs.at(name); + } + + /// Return the given tensor attribute by its name + torch::Tensor &GetTensorAttr(const std::string &name) { + return tensor_attrs.at(name); + } + + /// Return the given ragged tensor attribute by its name + const Ragged &GetRaggedTensorAttr(const std::string &name) const { + return ragged_tensor_attrs.at(name); + } + + /// Return the given ragged tensor attribute by its name + Ragged &GetRaggedTensorAttr(const std::string &name) { + return ragged_tensor_attrs.at(name); + } + + /// Return true if this FSA has a tensor attribute with such a name. + /// Return false otherwise. + bool HasTensorAttr(const std::string &name) const { + return tensor_attrs.count(name) > 0; + } + + /// Return true if this FSA has a ragged tensor attribute with such a name. + /// Return false otherwise. + bool HasRaggedTensorAttr(const std::string &name) const { + return ragged_tensor_attrs.count(name) > 0; + } + + /** Delete a tensor attribute by its name. + * + Raise a RuntimeError exception if there is no such attribute. + + @param name The attribute name. + */ + void DeleteTensorAttr(const std::string &name) { + auto it = tensor_attrs.find(name); + if (it == tensor_attrs.end()) { + K2_LOG(FATAL) << "No such tensor attribute: " << name; + } + tensor_attrs.erase(it); + } + + /** Delete a ragged attribute by its name. + + Raise a RuntimeError exception if there is no such attribute. + + @param name The attribute name. + */ + void DeleteRaggedTensorAttr(const std::string &name) { + auto it = ragged_tensor_attrs.find(name); + if (it == ragged_tensor_attrs.end()) { + K2_LOG(FATAL) << "No such ragged tensor attribute: " << name; + } + ragged_tensor_attrs.erase(it); + } + + /** Propagate attributes from source FsaClass via tensor arc_map. + + @param src The source FsaClass. + @param arc_map The arc_map (as idx012) to select items in attributes. + */ + void CopyAttrs(FsaClass &src, torch::Tensor arc_map); + + /** Propagate attributes from a list of source FsaClasses via ragged tensor + arc_map. We assume that each sublist in arc_map contains the indexes into + arcs (as idx01) of corresponding Fsa in the list of source FsaClasses. + And we propagate the attributes from the source FsaClass to the + corresponding Fsa(i.e. sub Fsa of raw FsaVec in current FsaClass object) + via the indexes in the corresponding sublist of arc_map. + + Caution: The raw fsa in current object MUST be an 3 axes FsaVec, and it + MUST satisfy `fsa.Numelements() == arc_map.Numelements()` and + `fsa.Dim0() == arc_map.Dim0()`. + + Note: The attributes of current object is a union of the attributes + of all the source FsaClasses. For example, srcs[0] has attributes + attr1, attr2; srcs[1] has attributes attr1, attr3; srcs[2] has + attributes attr3, attr4; then current FsaClass object has attributes + attr1, attr2, attr3, attr4 after propagation. + + @param srcs A vector of the source FsaClasses. + @param arc_map The arc_map (as idx01) to select items in attributes. + */ + void CopyAttrs(std::vector &srcs, Ragged &arc_map); + + /** Associate an tensor attribute with a value directly. + + @param name The attribute name. + @param value The attribute value. + */ + void SetTensorAttr(const std::string &name, torch::Tensor value) { + K2_CHECK_EQ(value.size(0), fsa.NumElements()) + << "'" << name + << "': shape[0] of the tensor MUST be equal to number of arcs"; + K2_CHECK(ContextFromTensor(value)->IsCompatible(*fsa.Context())); + tensor_attrs[name] = value; + } + + /** Associate a ragged tensor attribute with a value directly. + + @param name The attribute name. + @param value The attribute value. + */ + void SetRaggedTensorAttr(const std::string &name, + const Ragged &value) { + K2_CHECK_EQ(value.Dim0(), fsa.NumElements()) + << "'" << name + << "': dim0 of the tensor MUST be equal to number of arcs"; + K2_CHECK(value.Context()->IsCompatible(*fsa.Context())); + ragged_tensor_attrs[name] = value; + } + + private: + /** Propagate tensor attributes from source FsaClass via tensor arc_map. + + @param src The source FsaClass. + @param arc_map The arc_map (as idx012) to select items in attributes. + */ + void CopyTensorAttrs(FsaClass &src, torch::Tensor arc_map); + + + /** Propagate tensor attributes from a list of source FsaClasses via ragged + tensor arc_map. + See docs in CopyAttrs that has same arguments for more details. + + @param srcs A vector of the source FsaClasses. + @param arc_map The arc_map (as idx01) to select items in attributes. + */ + void CopyTensorAttrs(std::vector &srcs, Ragged &arc_map); + + /** Propagate ragged tensor attributes from source FsaClass via tensor + arc_map. + + @param src The source FsaClass. + @param arc_map The arc_map (as idx012) to select items in attributes. + */ + void CopyRaggedTensorAttrs(FsaClass &src, torch::Tensor arc_map); + + /** Propagate ragged tensor attributes from a list of source FsaClasses via + ragged tensor arc_map. + See docs in CopyAttrs that has same arguments for more details. + + @param srcs A vector of the source FsaClasses. + @param arc_map The arc_map (as idx01) to select items in attributes. + */ + void CopyRaggedTensorAttrs(std::vector &srcs, + Ragged &arc_map); +}; + +} // namespace k2 +#endif // K2_TORCH_CSRC_FSA_CLASS_H_ diff --git a/k2/torch/csrc/fsa_class_test.cu b/k2/torch/csrc/fsa_class_test.cu new file mode 100644 index 000000000..854ddfd7b --- /dev/null +++ b/k2/torch/csrc/fsa_class_test.cu @@ -0,0 +1,130 @@ +/** + * Copyright (c) 2021 Xiaomi Corporation (authors: Wei Kang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "gtest/gtest.h" +#include "k2/csrc/fsa_algo.h" +#include "k2/csrc/fsa_utils.h" +#include "k2/csrc/ragged_ops.h" +#include "k2/torch/csrc/fsa_class.h" +#include "k2/torch/csrc/utils.h" + +namespace k2 { + +TEST(FsaClassTest, FromUnaryFunctionTensor) { + for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { + std::string s = R"(0 1 2 10 + 0 1 1 20 + 1 2 -1 30 + 2)"; + + auto device = DeviceFromContext(c); + Fsa fsa = FsaFromString(s).To(c); + FsaClass src = FsaClass(fsa); + + auto float32_opts = torch::dtype(torch::kFloat32).device(device); + auto int32_opts = torch::dtype(torch::kInt32).device(device); + src.SetTensorAttr("float_attr", + torch::tensor({0.1, 0.2, 0.3}, float32_opts)); + + src.SetTensorAttr("int_attr", torch::tensor({1, 2, 3}, int32_opts)); + + Ragged ragged_attr(c, "[[1 2 3] [5 6] []]"); + + src.SetRaggedTensorAttr("ragged_attr", ragged_attr); + + Array1 arc_map; + Ragged arcs; + ArcSort(src.fsa, &arcs, &arc_map); + auto dest = FsaClass::FromUnaryFunctionTensor( + src, arcs, Array1ToTorch(arc_map)); + + EXPECT_TRUE(torch::allclose(dest.GetTensorAttr("float_attr"), + torch::tensor({0.2, 0.1, 0.3}, float32_opts))); + + EXPECT_TRUE(torch::allclose(dest.Scores(), + torch::tensor({20, 10, 30}, float32_opts))); + + EXPECT_TRUE(torch::equal(dest.GetTensorAttr("int_attr"), + torch::tensor({2, 1, 3}, int32_opts))); + + Ragged expected_ragged_attr = + Ragged(c, "[[5 6] [1 2 3] []]"); + + EXPECT_TRUE( + Equal(dest.GetRaggedTensorAttr("ragged_attr"), expected_ragged_attr)); + } +} + +TEST(FsaClassTest, Attributes) { + for (const ContextPtr &c : {GetCpuContext(), GetCudaContext()}) { + auto device = DeviceFromContext(c); + std::string s = R"(0 1 2 10 + 0 1 1 20 + 1 2 -1 30 + 2)"; + Fsa fsa = FsaFromString(s).To(c); + FsaClass src = FsaClass(fsa); + + auto float32_opts = torch::dtype(torch::kFloat32).device(device); + auto int32_opts = torch::dtype(torch::kInt32).device(device); + + // test scores + EXPECT_TRUE( + torch::equal(src.Scores(), torch::tensor({10, 20, 30}, float32_opts))); + + torch::Tensor scores = torch::tensor({1, 2, 3}, float32_opts); + src.SetScores(scores); + EXPECT_TRUE(torch::equal(src.Scores(), scores)); + + // test labels + EXPECT_TRUE( + torch::equal(src.Labels(), torch::tensor({2, 1, -1}, int32_opts))); + + torch::Tensor labels = torch::tensor({20, 10, -1}, int32_opts); + src.SetLabels(labels); + EXPECT_TRUE(torch::equal(src.Labels(), labels)); + + // test tensor attribute + torch::Tensor tensor_int = torch::tensor({1, 2, 3}, int32_opts); + src.SetTensorAttr("tensor_int", tensor_int); + + torch::Tensor tensor_float = torch::tensor({1, 2, 3}, float32_opts); + src.SetTensorAttr("tensor_float", tensor_float); + + EXPECT_TRUE(torch::equal(src.GetTensorAttr("tensor_int"), tensor_int)); + EXPECT_TRUE( + torch::allclose(src.GetTensorAttr("tensor_float"), tensor_float)); + + src.DeleteTensorAttr("tensor_int"); + EXPECT_FALSE(src.HasTensorAttr("tensor_int")); + + // test ragged attribute + auto ragged_int = Ragged(c, "[[1, 2], [3], [4]]"); + src.SetRaggedTensorAttr("ragged_int", ragged_int); + + EXPECT_TRUE(Equal(src.GetRaggedTensorAttr("ragged_int"), ragged_int)); + src.DeleteRaggedTensorAttr("ragged_int"); + EXPECT_FALSE(src.HasRaggedTensorAttr("ragged_int")); + } +} + +} // namespace k2 diff --git a/k2/torch/csrc/hypothesis.cu b/k2/torch/csrc/hypothesis.cu new file mode 100644 index 000000000..47a9991c1 --- /dev/null +++ b/k2/torch/csrc/hypothesis.cu @@ -0,0 +1,56 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "k2/csrc/utils.h" +#include "k2/torch/csrc/hypothesis.h" +namespace k2 { + +void Hypotheses::Add(Hypothesis hyp) { + auto key = hyp.Key(); + auto it = hyps_dict_.find(key); + if (it == hyps_dict_.end()) { + hyps_dict_[key] = std::move(hyp); + } else { + it->second.log_prob = LogAdd()(it->second.log_prob, hyp.log_prob); + } +} + +Hypothesis Hypotheses::GetMostProbable(bool length_norm) const { + if (length_norm == false) { + return std::max_element(hyps_dict_.begin(), hyps_dict_.end(), + [](const auto &left, auto &right) -> bool { + return left.second.log_prob < + right.second.log_prob; + }) + ->second; + } else { + // for length_norm is true + return std::max_element( + hyps_dict_.begin(), hyps_dict_.end(), + [](const auto &left, const auto &right) -> bool { + return left.second.log_prob / left.second.ys.size() < + right.second.log_prob / right.second.ys.size(); + }) + ->second; + } +} + +} // namespace k2 diff --git a/k2/torch/csrc/hypothesis.h b/k2/torch/csrc/hypothesis.h new file mode 100644 index 000000000..731dc43e8 --- /dev/null +++ b/k2/torch/csrc/hypothesis.h @@ -0,0 +1,112 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_TORCH_CSRC_HYPOTHESIS_H_ +#define K2_TORCH_CSRC_HYPOTHESIS_H_ + +#include +#include +#include +#include + +#include "torch/all.h" + +namespace k2 { + +struct Hypothesis { + // The predicted tokens so far. Newly predicated tokens are appended. + std::vector ys; + + // The total score of ys in log space. + double log_prob = 0; + + Hypothesis() = default; + Hypothesis(const std::vector &ys, double log_prob) + : ys(ys), log_prob(log_prob) {} + + // If two Hypotheses have the same `Key`, then they contain + // the same token sequence. + std::string Key() const { return torch::Join("-", ys); } + + // For debugging + std::string ToString() const { + std::ostringstream os; + os << "(" << Key() << ", " << log_prob << ")"; + return os.str(); + } +}; + +class Hypotheses { + public: + Hypotheses() = default; + + explicit Hypotheses(std::vector hyps) { + for (auto &h : hyps) { + hyps_dict_[h.Key()] = std::move(h); + } + } + + explicit Hypotheses(std::unordered_map hyps_dict) + : hyps_dict_(std::move(hyps_dict)) {} + + // Add hyp to this object. If it already exists, its log_prob + // is updated with the given hyp using log-sum-exp. + void Add(Hypothesis hyp); + + // Get the hyp that has the largest log_prob. + // If length_norm is true, hyp's log_prob are divided by + // len(hyp.ys) before comparison. + Hypothesis GetMostProbable(bool length_norm) const; + + // Remove the given hyp from this object. + // It is *NOT* an error if hyp does not exist in this object. + void Remove(const Hypothesis &hyp) { hyps_dict_.erase(hyp.Key()); } + + // Return a list of hyps contained in this object. + std::vector Vec() const { + std::vector ans; + ans.reserve(hyps_dict_.size()); + for (const auto &p : hyps_dict_) { + ans.push_back(p.second); + } + return ans; + } + + int32_t Size() const { return hyps_dict_.size(); } + + std::string ToString() const { + std::ostringstream os; + for (const auto &p : hyps_dict_) { + os << p.second.ToString() << "\n"; + } + return os.str(); + } + + auto begin() { return hyps_dict_.begin(); } + auto end() { return hyps_dict_.end(); } + + const auto begin() const { return hyps_dict_.begin(); } + const auto end() const { return hyps_dict_.end(); } + + private: + using Map = std ::unordered_map; + Map hyps_dict_; +}; + +} // namespace k2 +#endif // K2_TORCH_CSRC_HYPOTHESIS_H_ diff --git a/k2/torch/csrc/hypothesis_test.cu b/k2/torch/csrc/hypothesis_test.cu new file mode 100644 index 000000000..c06c46f4d --- /dev/null +++ b/k2/torch/csrc/hypothesis_test.cu @@ -0,0 +1,43 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gtest/gtest.h" +#include "k2/torch/csrc/hypothesis.h" + +namespace k2 { + +TEST(Hypothesis, Key) { + Hypothesis hyp; + hyp.ys = {1, 2, 3}; + EXPECT_EQ(hyp.Key(), "1-2-3"); +} + +TEST(Hypotheses, ConstructorFromVector) { + std::vector hyp_vec; + hyp_vec.emplace_back(Hypothesis({1, 2, 3}, -1.5)); + hyp_vec.emplace_back(Hypothesis({30}, -2.5)); + + EXPECT_EQ(hyp_vec[0].ys.size(), 3); + EXPECT_EQ(hyp_vec[1].ys.size(), 1); + + Hypotheses hyps(std::move(hyp_vec)); + + EXPECT_TRUE(hyp_vec.empty()); +} + +} // namespace k2 diff --git a/k2/torch/csrc/nbest.cu b/k2/torch/csrc/nbest.cu new file mode 100644 index 000000000..19ebb1fc5 --- /dev/null +++ b/k2/torch/csrc/nbest.cu @@ -0,0 +1,116 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "k2/csrc/fsa_algo.h" +#include "k2/torch/csrc/fsa_algo.h" +#include "k2/torch/csrc/nbest.h" +namespace k2 { + +Nbest::Nbest(const FsaClass &fsa, const RaggedShape &shape) + : fsa(fsa), shape(shape) { + K2_CHECK_EQ(fsa.fsa.NumAxes(), 3) << "Expect an FsaVec"; + K2_CHECK_EQ(shape.NumAxes(), 2) << "Expect a shape with axes [utt][path]"; + K2_CHECK_EQ(fsa.fsa.Dim0(), shape.NumElements()); +} + +Nbest Nbest::FromLattice(FsaClass &lattice, int32_t num_paths, + float nbest_scale /*= 0.5*/) { + K2_CHECK_EQ(lattice.fsa.NumAxes(), 3); + K2_CHECK_GT(num_paths, 1); + + torch::Tensor scores = lattice.Scores(); + torch::Tensor saved_scores = scores.clone(); + + scores = scores * nbest_scale; + lattice.SetScores(scores); + Nbest ans = RandomPaths(lattice, num_paths); + lattice.SetScores(saved_scores); + return ans; +} + +void Nbest::Intersect(FsaClass *lattice) { + K2_CHECK_EQ(lattice->fsa.NumAxes(), 3); + Invert(&fsa); + // Now fsa contains word IDs as labels and aux_labels as token IDs. + + fsa.Scores().zero_(); // Just in case it has scores set + + K2_CHECK(lattice->HasTensorAttr("aux_labels") || + lattice->HasRaggedTensorAttr("aux_labels")); + + // We don't need the aux labels for this->fsa, + // as we are going to use the one from lattice. + Fsa word_fsa_with_epsilon_self_loops; + RemoveEpsilonAndAddSelfLoops(fsa.fsa, fsa.Properties(), + &word_fsa_with_epsilon_self_loops); + + auto &path_to_utt_map = shape.RowIds(1); + + // The following Invert() and ArcSort() change lattice in-place + Invert(lattice); + // Now lattice has word IDs as labels and token IDs as aux_labels + ArcSort(lattice); + + FsaClass word_fsa_with_epsilon_self_loops_wrapper( + word_fsa_with_epsilon_self_loops); + + FsaClass ans = + IntersectDevice(*lattice, word_fsa_with_epsilon_self_loops_wrapper, + path_to_utt_map, true); + + Connect(&ans); + TopSort(&ans); + ans = ShortestPath(ans); + Invert(&ans); + // now ans.fsa has token IDs as labels and word IDs as aux_labels. + + this->fsa = ans; +} + +torch::Tensor Nbest::ComputeAmScores() { + K2_CHECK(fsa.HasTensorAttr("lm_scores")); + torch::Tensor am_scores = + (fsa.Scores() - fsa.GetTensorAttr("lm_scores")).contiguous(); + + // fsa.shape has axes [fsa][state][arc] + RaggedShape scores_shape = RemoveAxis(fsa.fsa.shape, 1); + // scores_shape has axes [fsa][arc] + + Ragged ragged_am_scores{scores_shape, + Array1FromTorch(am_scores)}; + Array1 tot_scores(fsa.fsa.Context(), fsa.fsa.Dim0()); + SumPerSublist(ragged_am_scores, 0, &tot_scores); + return Array1ToTorch(tot_scores); +} + +torch::Tensor Nbest::ComputeLmScores() { + K2_CHECK(fsa.HasTensorAttr("lm_scores")); + torch::Tensor lm_scores = fsa.GetTensorAttr("lm_scores"); + + // fsa.shape has axes [fsa][state][arc] + RaggedShape scores_shape = RemoveAxis(fsa.fsa.shape, 1); + // scores_shape has axes [fsa][arc] + + Ragged ragged_lm_scores{scores_shape, + Array1FromTorch(lm_scores)}; + Array1 tot_scores(fsa.fsa.Context(), fsa.fsa.Dim0()); + SumPerSublist(ragged_lm_scores, 0, &tot_scores); + return Array1ToTorch(tot_scores); +} + +} // namespace k2 diff --git a/k2/torch/csrc/nbest.h b/k2/torch/csrc/nbest.h new file mode 100644 index 000000000..3c9af1736 --- /dev/null +++ b/k2/torch/csrc/nbest.h @@ -0,0 +1,98 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_TORCH_CSRC_NBEST_H_ +#define K2_TORCH_CSRC_NBEST_H_ + +#include +#include +#include + +#include "k2/csrc/fsa.h" +#include "k2/torch/csrc/fsa_class.h" + +namespace k2 { + +/* + +An Nbest object contains two fields: + + (1) fsa. It is an FsaVec containing a vector of **linear** FSAs. + Its axes are [path][state][arc] + (2) shape. Its type is :class:`k2::RaggedShape`. + Its axes are [utt][path] + +The field `shape` has two axes [utt][path]. `shape.Dim0` contains +the number of utterances, which is also the number of rows in the +supervision_segments. `shape.tot_size(1)` contains the number +of paths, which is also the number of FSAs in `fsa`. + +Caution: + Don't be confused by the name `Nbest`. The best in the name `Nbest` + has nothing to do with `best scores`. The important part is + `N` in `Nbest`, not `best`. + */ +struct Nbest { + FsaClass fsa; + RaggedShape shape; + + Nbest(const FsaClass &fsa, const RaggedShape &shape); + + // Return a string representation of this object + // in the form + // Nbest(num_utteraces=xxx, num_paths=xxx) + std::string ToString() const { + std::ostringstream os; + os << "Nbest(num_utterances=" << shape.Dim0() + << ", num_paths=" << shape.NumElements() << ")"; + return os.str(); + } + /** Construct an Nbest object by sampling num_paths from a lattice. + + @param lattice The input/output lattice to be sampled. + @param num_paths Number of paths to sample. + @param nbest_scale Scale lattice.scores by this value before + sampling. + @return Return an Nbest object containing the sampled paths, with + duplicated paths being removed. + */ + static Nbest FromLattice(FsaClass &lattice, int32_t num_paths, + float nbest_scale = 0.5); + + /// Intersect this object with a lattice to assign scores + /// `this` nbest. + /// + /// @param lattice The lattice to intersect. Note it is modified in-place. + /// You should not use it after invoking this function. + /// + /// Note: The scores for the return value of FromLattice() are + /// all 0s. + void Intersect(FsaClass *lattice); + + /// Compute the AM scores of each path + /// Return a 1-D torch.float32 tensor with dim equal to fsa.Dim0() + torch::Tensor ComputeAmScores() /*const*/; + + /// Compute the LM scores of each path + /// Return a 1-D torch.float32 tensor with dim equal to fsa.Dim0() + torch::Tensor ComputeLmScores() /*const*/; +}; + +} // namespace k2 + +#endif // K2_TORCH_CSRC_NBEST_H_ diff --git a/k2/torch/csrc/parse_options.cu b/k2/torch/csrc/parse_options.cu new file mode 100644 index 000000000..19e4a3c7a --- /dev/null +++ b/k2/torch/csrc/parse_options.cu @@ -0,0 +1,783 @@ +/** + * Copyright 2009-2011 Karel Vesely; Microsoft Corporation; + * Saarland University (Author: Arnab Ghoshal); + * Copyright 2012-2013 Johns Hopkins University (Author: Daniel Povey); + * Frantisek Skala; Arnab Ghoshal + * Copyright 2013 Tanel Alumae + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is copied and modified from kaldi/src/util/parse-options.cu + +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "k2/csrc/log.h" +#include "k2/torch/csrc/parse_options.h" + +#ifdef _MSC_VER +#define K2_STRTOLL(cur_cstr, end_cstr) _strtoi64(cur_cstr, end_cstr, 10); +#else +#define K2_STRTOLL(cur_cstr, end_cstr) strtoll(cur_cstr, end_cstr, 10); +#endif + +namespace k2 { + +/// Converts a string into an integer via strtoll and returns false if there was +/// any kind of problem (i.e. the string was not an integer or contained extra +/// non-whitespace junk, or the integer was too large to fit into the type it is +/// being converted into). Only sets *out if everything was OK and it returns +/// true. +template +bool ConvertStringToInteger(const std::string &str, Int *out) { + // copied from kaldi/src/util/text-util.h + static_assert(std::is_integral::value, ""); + const char *this_str = str.c_str(); + char *end = nullptr; + errno = 0; + int64_t i = K2_STRTOLL(this_str, &end); + if (end != this_str) { + while (isspace(*end)) ++end; + } + if (end == this_str || *end != '\0' || errno != 0) return false; + Int iInt = static_cast(i); + if (static_cast(iInt) != i || + (i < 0 && !std::numeric_limits::is_signed)) { + return false; + } + *out = iInt; + return true; +} + +// copied from kaldi/src/util/text-util.cc +template +class NumberIstream { + public: + explicit NumberIstream(std::istream &i) : in_(i) {} + + NumberIstream &operator>>(T &x) { + if (!in_.good()) return *this; + in_ >> x; + if (!in_.fail() && RemainderIsOnlySpaces()) return *this; + return ParseOnFail(&x); + } + + private: + std::istream &in_; + + bool RemainderIsOnlySpaces() { + if (in_.tellg() != std::istream::pos_type(-1)) { + std::string rem; + in_ >> rem; + + if (rem.find_first_not_of(' ') != std::string::npos) { + // there is not only spaces + return false; + } + } + + in_.clear(); + return true; + } + + NumberIstream &ParseOnFail(T *x) { + std::string str; + in_.clear(); + in_.seekg(0); + // If the stream is broken even before trying + // to read from it or if there are many tokens, + // it's pointless to try. + if (!(in_ >> str) || !RemainderIsOnlySpaces()) { + in_.setstate(std::ios_base::failbit); + return *this; + } + + std::unordered_map inf_nan_map; + // we'll keep just uppercase values. + inf_nan_map["INF"] = std::numeric_limits::infinity(); + inf_nan_map["+INF"] = std::numeric_limits::infinity(); + inf_nan_map["-INF"] = -std::numeric_limits::infinity(); + inf_nan_map["INFINITY"] = std::numeric_limits::infinity(); + inf_nan_map["+INFINITY"] = std::numeric_limits::infinity(); + inf_nan_map["-INFINITY"] = -std::numeric_limits::infinity(); + inf_nan_map["NAN"] = std::numeric_limits::quiet_NaN(); + inf_nan_map["+NAN"] = std::numeric_limits::quiet_NaN(); + inf_nan_map["-NAN"] = -std::numeric_limits::quiet_NaN(); + // MSVC + inf_nan_map["1.#INF"] = std::numeric_limits::infinity(); + inf_nan_map["-1.#INF"] = -std::numeric_limits::infinity(); + inf_nan_map["1.#QNAN"] = std::numeric_limits::quiet_NaN(); + inf_nan_map["-1.#QNAN"] = -std::numeric_limits::quiet_NaN(); + + std::transform(str.begin(), str.end(), str.begin(), ::toupper); + + if (inf_nan_map.find(str) != inf_nan_map.end()) { + *x = inf_nan_map[str]; + } else { + in_.setstate(std::ios_base::failbit); + } + + return *this; + } +}; + +/// ConvertStringToReal converts a string into either float or double +/// and returns false if there was any kind of problem (i.e. the string +/// was not a floating point number or contained extra non-whitespace junk). +/// Be careful- this function will successfully read inf's or nan's. +template +bool ConvertStringToReal(const std::string &str, T *out) { + std::istringstream iss(str); + + NumberIstream i(iss); + + i >> *out; + + if (iss.fail()) { + // Number conversion failed. + return false; + } + + return true; +} + +ParseOptions::ParseOptions(const std::string &prefix, ParseOptions *po) + : print_args_(false), help_(false), usage_(""), argc_(0), argv_(nullptr) { + if (po != nullptr && po->other_parser_ != nullptr) { + // we get here if this constructor is used twice, recursively. + other_parser_ = po->other_parser_; + } else { + other_parser_ = po; + } + if (po != nullptr && po->prefix_ != "") { + prefix_ = po->prefix_ + std::string(".") + prefix; + } else { + prefix_ = prefix; + } +} + +void ParseOptions::Register(const std::string &name, bool *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, int32_t *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, uint32_t *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, float *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, double *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +void ParseOptions::Register(const std::string &name, std::string *ptr, + const std::string &doc) { + RegisterTmpl(name, ptr, doc); +} + +// old-style, used for registering application-specific parameters +template +void ParseOptions::RegisterTmpl(const std::string &name, T *ptr, + const std::string &doc) { + if (other_parser_ == nullptr) { + this->RegisterCommon(name, ptr, doc, false); + } else { + K2_CHECK(prefix_ != "") + << "prefix: " << prefix_ << "\n" + << "Cannot use empty prefix when registering with prefix."; + std::string new_name = prefix_ + '.' + name; // name becomes prefix.name + other_parser_->Register(new_name, ptr, doc); + } +} + +// does the common part of the job of registering a parameter +template +void ParseOptions::RegisterCommon(const std::string &name, T *ptr, + const std::string &doc, bool is_standard) { + K2_CHECK(ptr != nullptr); + std::string idx = name; + NormalizeArgName(&idx); + if (doc_map_.find(idx) != doc_map_.end()) { + K2_LOG(WARNING) << "Registering option twice, ignoring second time: " + << name; + } else { + this->RegisterSpecific(name, idx, ptr, doc, is_standard); + } +} + +// used to register standard parameters (those that are present in all of the +// applications) +template +void ParseOptions::RegisterStandard(const std::string &name, T *ptr, + const std::string &doc) { + this->RegisterCommon(name, ptr, doc, true); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, bool *b, + const std::string &doc, bool is_standard) { + bool_map_[idx] = b; + doc_map_[idx] = + DocInfo(name, doc + " (bool, default = " + ((*b) ? "true)" : "false)"), + is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, int32_t *i, + const std::string &doc, bool is_standard) { + int_map_[idx] = i; + std::ostringstream ss; + ss << doc << " (int, default = " << *i << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, uint32_t *u, + const std::string &doc, bool is_standard) { + uint_map_[idx] = u; + std::ostringstream ss; + ss << doc << " (uint, default = " << *u << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, float *f, + const std::string &doc, bool is_standard) { + float_map_[idx] = f; + std::ostringstream ss; + ss << doc << " (float, default = " << *f << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, double *f, + const std::string &doc, bool is_standard) { + double_map_[idx] = f; + std::ostringstream ss; + ss << doc << " (double, default = " << *f << ")"; + doc_map_[idx] = DocInfo(name, ss.str(), is_standard); +} + +void ParseOptions::RegisterSpecific(const std::string &name, + const std::string &idx, std::string *s, + const std::string &doc, bool is_standard) { + string_map_[idx] = s; + doc_map_[idx] = + DocInfo(name, doc + " (string, default = \"" + *s + "\")", is_standard); +} + +void ParseOptions::DisableOption(const std::string &name) { + if (argv_ != nullptr) { + K2_LOG(FATAL) << "DisableOption must not be called after calling Read()."; + } + if (doc_map_.erase(name) == 0) { + K2_LOG(FATAL) << "Option " << name + << " was not registered so cannot be disabled: "; + } + bool_map_.erase(name); + int_map_.erase(name); + uint_map_.erase(name); + float_map_.erase(name); + double_map_.erase(name); + string_map_.erase(name); +} + +int ParseOptions::NumArgs() const { return positional_args_.size(); } + +std::string ParseOptions::GetArg(int i) const { + if (i < 1 || i > static_cast(positional_args_.size())) { + K2_LOG(FATAL) << "ParseOptions::GetArg, invalid index " << i; + } + + return positional_args_[i - 1]; +} + +// We currently do not support any other options. +enum ShellType { kBash = 0 }; + +// This can be changed in the code if it ever does need to be changed (as it's +// unlikely that one compilation of this tool-set would use both shells). +static ShellType kShellType = kBash; + +// Returns true if we need to escape a string before putting it into +// a shell (mainly thinking of bash shell, but should work for others) +// This is for the convenience of the user so command-lines that are +// printed out by ParseOptions::Read (with --print-args=true) are +// paste-able into the shell and will run. If you use a different type of +// shell, it might be necessary to change this function. +// But it's mostly a cosmetic issue as it basically affects how +// the program echoes its command-line arguments to the screen. +static bool MustBeQuoted(const std::string &str, ShellType st) { + // Only Bash is supported (for the moment). + K2_CHECK_EQ(st, kBash) << "Invalid shell type."; + + const char *c = str.c_str(); + if (*c == '\0') { + return true; // Must quote empty string + } else { + const char *ok_chars[2]; + + // These seem not to be interpreted as long as there are no other "bad" + // characters involved (e.g. "," would be interpreted as part of something + // like a{b,c}, but not on its own. + ok_chars[kBash] = "[]~#^_-+=:.,/"; + + // Just want to make sure that a space character doesn't get automatically + // inserted here via an automated style-checking script, like it did before. + K2_CHECK(!strchr(ok_chars[kBash], ' ')); + + for (; *c != '\0'; ++c) { + // For non-alphanumeric characters we have a list of characters which + // are OK. All others are forbidden (this is easier since the shell + // interprets most non-alphanumeric characters). + if (!isalnum(*c)) { + const char *d; + for (d = ok_chars[st]; *d != '\0'; ++d) { + if (*c == *d) break; + } + // If not alphanumeric or one of the "ok_chars", it must be escaped. + if (*d == '\0') return true; + } + } + return false; // The string was OK. No quoting or escaping. + } +} + +// Returns a quoted and escaped version of "str" +// which has previously been determined to need escaping. +// Our aim is to print out the command line in such a way that if it's +// pasted into a shell of ShellType "st" (only bash for now), it +// will get passed to the program in the same way. +static std::string QuoteAndEscape(const std::string &str, ShellType st) { + // Only Bash is supported (for the moment). + K2_CHECK_EQ(st, kBash) << "Invalid shell type."; + + // For now we use the following rules: + // In the normal case, we quote with single-quote "'", and to escape + // a single-quote we use the string: '\'' (interpreted as closing the + // single-quote, putting an escaped single-quote from the shell, and + // then reopening the single quote). + char quote_char = '\''; + const char *escape_str = "'\\''"; // e.g. echo 'a'\''b' returns a'b + + // If the string contains single-quotes that would need escaping this + // way, and we determine that the string could be safely double-quoted + // without requiring any escaping, then we double-quote the string. + // This is the case if the characters "`$\ do not appear in the string. + // e.g. see http://www.redhat.com/mirrors/LDP/LDP/abs/html/quotingvar.html + const char *c_str = str.c_str(); + if (strchr(c_str, '\'') && !strpbrk(c_str, "\"`$\\")) { + quote_char = '"'; + escape_str = "\\\""; // should never be accessed. + } + + char buf[2]; + buf[1] = '\0'; + + buf[0] = quote_char; + std::string ans = buf; + const char *c = str.c_str(); + for (; *c != '\0'; ++c) { + if (*c == quote_char) { + ans += escape_str; + } else { + buf[0] = *c; + ans += buf; + } + } + buf[0] = quote_char; + ans += buf; + return ans; +} + +// static function +std::string ParseOptions::Escape(const std::string &str) { + return MustBeQuoted(str, kShellType) ? QuoteAndEscape(str, kShellType) : str; +} + +int ParseOptions::Read(int argc, const char *const argv[]) { + argc_ = argc; + argv_ = argv; + std::string key, value; + int i; + + // first pass: look for config parameter, look for priority + for (i = 1; i < argc; ++i) { + if (std::strncmp(argv[i], "--", 2) == 0) { + if (std::strcmp(argv[i], "--") == 0) { + // a lone "--" marks the end of named options + break; + } + bool has_equal_sign; + SplitLongArg(argv[i], &key, &value, &has_equal_sign); + NormalizeArgName(&key); + Trim(&value); + if (key.compare("config") == 0) { + ReadConfigFile(value); + } else if (key.compare("help") == 0) { + PrintUsage(); + exit(0); + } + } + } + + bool double_dash_seen = false; + // second pass: add the command line options + for (i = 1; i < argc; ++i) { + if (std::strncmp(argv[i], "--", 2) == 0) { + if (std::strcmp(argv[i], "--") == 0) { + // A lone "--" marks the end of named options. + // Skip that option and break the processing of named options + i += 1; + double_dash_seen = true; + break; + } + bool has_equal_sign; + SplitLongArg(argv[i], &key, &value, &has_equal_sign); + NormalizeArgName(&key); + Trim(&value); + if (!SetOption(key, value, has_equal_sign)) { + PrintUsage(true); + K2_LOG(FATAL) << "Invalid option " << argv[i]; + } + } else { + break; + } + } + + // process remaining arguments as positional + for (; i < argc; ++i) { + if ((std::strcmp(argv[i], "--") == 0) && !double_dash_seen) { + double_dash_seen = true; + } else { + positional_args_.push_back(std::string(argv[i])); + } + } + + // if the user did not suppress this with --print-args = false.... + if (print_args_) { + std::ostringstream strm; + for (int j = 0; j < argc; ++j) strm << Escape(argv[j]) << " "; + strm << '\n'; + K2_LOG(INFO) << strm.str(); + } + return i; +} + +void ParseOptions::PrintUsage(bool print_command_line /*=false*/) const { + std::ostringstream os; + os << '\n' << usage_ << '\n'; + // first we print application-specific options + bool app_specific_header_printed = false; + for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) { + if (it->second.is_standard_ == false) { // application-specific option + if (app_specific_header_printed == false) { // header was not yet printed + os << "Options:" << '\n'; + app_specific_header_printed = true; + } + os << " --" << std::setw(25) << std::left << it->second.name_ << " : " + << it->second.use_msg_ << '\n'; + } + } + if (app_specific_header_printed == true) { + os << '\n'; + } + + // then the standard options + os << "Standard options:" << '\n'; + for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) { + if (it->second.is_standard_ == true) { // we have standard option + os << " --" << std::setw(25) << std::left << it->second.name_ << " : " + << it->second.use_msg_ << '\n'; + } + } + os << '\n'; + if (print_command_line) { + std::ostringstream strm; + strm << "Command line was: "; + for (int j = 0; j < argc_; ++j) strm << Escape(argv_[j]) << " "; + strm << '\n'; + os << strm.str(); + } + + K2_LOG(INFO) << os.str(); +} + +void ParseOptions::PrintConfig(std::ostream &os) const { + os << '\n' << "[[ Configuration of UI-Registered options ]]" << '\n'; + std::string key; + for (auto it = doc_map_.begin(); it != doc_map_.end(); ++it) { + key = it->first; + os << it->second.name_ << " = "; + if (bool_map_.end() != bool_map_.find(key)) { + os << (*bool_map_.at(key) ? "true" : "false"); + } else if (int_map_.end() != int_map_.find(key)) { + os << (*int_map_.at(key)); + } else if (uint_map_.end() != uint_map_.find(key)) { + os << (*uint_map_.at(key)); + } else if (float_map_.end() != float_map_.find(key)) { + os << (*float_map_.at(key)); + } else if (double_map_.end() != double_map_.find(key)) { + os << (*double_map_.at(key)); + } else if (string_map_.end() != string_map_.find(key)) { + os << "'" << *string_map_.at(key) << "'"; + } else { + K2_LOG(FATAL) << "PrintConfig: unrecognized option " << key + << "[code error]"; + } + os << '\n'; + } + os << '\n'; +} + +void ParseOptions::ReadConfigFile(const std::string &filename) { + std::ifstream is(filename.c_str(), std::ifstream::in); + if (!is.good()) { + K2_LOG(FATAL) << "Cannot open config file: " << filename; + } + + std::string line, key, value; + int32_t line_number = 0; + while (std::getline(is, line)) { + ++line_number; + // trim out the comments + size_t pos; + if ((pos = line.find_first_of('#')) != std::string::npos) { + line.erase(pos); + } + // skip empty lines + Trim(&line); + if (line.length() == 0) continue; + + if (line.substr(0, 2) != "--") { + K2_LOG(FATAL) + << "Reading config file " << filename << ": line " << line_number + << " does not look like a line " + << "from a Kaldi command-line program's config file: should " + << "be of the form --x=y. Note: config files intended to " + << "be sourced by shell scripts lack the '--'."; + } + + // parse option + bool has_equal_sign; + SplitLongArg(line, &key, &value, &has_equal_sign); + NormalizeArgName(&key); + Trim(&value); + if (!SetOption(key, value, has_equal_sign)) { + PrintUsage(true); + K2_LOG(FATAL) << "Invalid option " << line << " in config file " + << filename << ": line " << line_number; + } + } +} + +void ParseOptions::SplitLongArg(const std::string &in, std::string *key, + std::string *value, + bool *has_equal_sign) const { + K2_CHECK(in.substr(0, 2) == "--") << in; // precondition. + size_t pos = in.find_first_of('=', 0); + if (pos == std::string::npos) { // we allow --option for bools + // defaults to empty. We handle this differently in different cases. + *key = in.substr(2, in.size() - 2); // 2 because starts with --. + *value = ""; + *has_equal_sign = false; + } else if (pos == 2) { // we also don't allow empty keys: --=value + PrintUsage(true); + K2_LOG(FATAL) << "Invalid option (no key): " << in; + } else { // normal case: --option=value + *key = in.substr(2, pos - 2); // 2 because starts with --. + *value = in.substr(pos + 1); + *has_equal_sign = true; + } +} + +void ParseOptions::NormalizeArgName(std::string *str) const { + std::string out; + std::string::iterator it; + + for (it = str->begin(); it != str->end(); ++it) { + if (*it == '_') { + out += '-'; // convert _ to - + } else { + out += std::tolower(*it); + } + } + *str = out; + + K2_CHECK_GT(str->length(), 0); +} + +void ParseOptions::Trim(std::string *str) const { + const char *white_chars = " \t\n\r\f\v"; + + std::string::size_type pos = str->find_last_not_of(white_chars); + if (pos != std::string::npos) { + str->erase(pos + 1); + pos = str->find_first_not_of(white_chars); + if (pos != std::string::npos) str->erase(0, pos); + } else { + str->erase(str->begin(), str->end()); + } +} + +bool ParseOptions::SetOption(const std::string &key, const std::string &value, + bool has_equal_sign) { + if (bool_map_.end() != bool_map_.find(key)) { + if (has_equal_sign && value == "") { + K2_LOG(FATAL) << "Invalid option --" << key << "="; + } + *(bool_map_[key]) = ToBool(value); + } else if (int_map_.end() != int_map_.find(key)) { + *(int_map_[key]) = ToInt(value); + } else if (uint_map_.end() != uint_map_.find(key)) { + *(uint_map_[key]) = ToUint(value); + } else if (float_map_.end() != float_map_.find(key)) { + *(float_map_[key]) = ToFloat(value); + } else if (double_map_.end() != double_map_.find(key)) { + *(double_map_[key]) = ToDouble(value); + } else if (string_map_.end() != string_map_.find(key)) { + if (!has_equal_sign) { + K2_LOG(FATAL) << "Invalid option --" << key + << " (option format is --x=y)."; + } + *(string_map_[key]) = value; + } else { + return false; + } + return true; +} + +bool ParseOptions::ToBool(std::string str) const { + std::transform(str.begin(), str.end(), str.begin(), ::tolower); + + // allow "" as a valid option for "true", so that --x is the same as --x=true + if ((str.compare("true") == 0) || (str.compare("t") == 0) || + (str.compare("1") == 0) || (str.compare("") == 0)) { + return true; + } + if ((str.compare("false") == 0) || (str.compare("f") == 0) || + (str.compare("0") == 0)) { + return false; + } + // if it is neither true nor false: + PrintUsage(true); + K2_LOG(FATAL) + << "Invalid format for boolean argument [expected true or false]: " + << str; + return false; // never reached +} + +int32_t ParseOptions::ToInt(const std::string &str) const { + int32_t ret; + if (!ConvertStringToInteger(str, &ret)) + K2_LOG(FATAL) << "Invalid integer option \"" << str << "\""; + return ret; +} + +uint32_t ParseOptions::ToUint(const std::string &str) const { + uint32_t ret; + if (!ConvertStringToInteger(str, &ret)) + K2_LOG(FATAL) << "Invalid integer option \"" << str << "\""; + return ret; +} + +float ParseOptions::ToFloat(const std::string &str) const { + float ret; + if (!ConvertStringToReal(str, &ret)) + K2_LOG(FATAL) << "Invalid floating-point option \"" << str << "\""; + return ret; +} + +double ParseOptions::ToDouble(const std::string &str) const { + double ret; + if (!ConvertStringToReal(str, &ret)) + K2_LOG(FATAL) << "Invalid floating-point option \"" << str << "\""; + return ret; +} + +// instantiate templates +template void ParseOptions::RegisterTmpl(const std::string &name, bool *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, int32_t *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, uint32_t *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, float *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, double *ptr, + const std::string &doc); +template void ParseOptions::RegisterTmpl(const std::string &name, + std::string *ptr, + const std::string &doc); + +template void ParseOptions::RegisterStandard(const std::string &name, bool *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + int32_t *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + uint32_t *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + float *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + double *ptr, + const std::string &doc); +template void ParseOptions::RegisterStandard(const std::string &name, + std::string *ptr, + const std::string &doc); + +template void ParseOptions::RegisterCommon(const std::string &name, bool *ptr, + const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + int32_t *ptr, const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + uint32_t *ptr, + const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, float *ptr, + const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, double *ptr, + const std::string &doc, + bool is_standard); +template void ParseOptions::RegisterCommon(const std::string &name, + std::string *ptr, + const std::string &doc, + bool is_standard); + +} // namespace k2 diff --git a/k2/torch/csrc/parse_options.h b/k2/torch/csrc/parse_options.h new file mode 100644 index 000000000..b15bbc870 --- /dev/null +++ b/k2/torch/csrc/parse_options.h @@ -0,0 +1,268 @@ +/** + * Copyright 2009-2011 Karel Vesely; Microsoft Corporation; + * Saarland University (Author: Arnab Ghoshal); + * Copyright 2012-2013 Frantisek Skala; Arnab Ghoshal + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// This file is copied and modified from kaldi/src/util/parse-options.h + +#ifndef K2_TORCH_CSRC_PARSE_OPTIONS_H_ +#define K2_TORCH_CSRC_PARSE_OPTIONS_H_ + +#include +#include +#include +#include + +namespace k2 { + +class ParseOptions { + public: + explicit ParseOptions(const char *usage) + : print_args_(true), + help_(false), + usage_(usage), + argc_(0), + argv_(nullptr), + prefix_(""), + other_parser_(nullptr) { +#if !defined(_MSC_VER) && !defined(__CYGWIN__) + // This is just a convenient place to set the stderr to line + // buffering mode, since it's called at program start. + // This helps ensure different programs' output is not mixed up. + setlinebuf(stderr); +#endif + RegisterStandard("config", &config_, + "Configuration file to read (this " + "option may be repeated)"); + RegisterStandard("print-args", &print_args_, + "Print the command line arguments (to stderr)"); + RegisterStandard("help", &help_, "Print out usage message"); + } + + /** + This is a constructor for the special case where some options are + registered with a prefix to avoid conflicts. The object thus created will + only be used temporarily to register an options class with the original + options parser (which is passed as the *other pointer) using the given + prefix. It should not be used for any other purpose, and the prefix must + not be the empty string. It seems to be the least bad way of implementing + options with prefixes at this point. + Example of usage is: + ParseOptions po; // original ParseOptions object + ParseOptions po_mfcc("mfcc", &po); // object with prefix. + MfccOptions mfcc_opts; + mfcc_opts.Register(&po_mfcc); + The options will now get registered as, e.g., --mfcc.frame-shift=10.0 + instead of just --frame-shift=10.0 + */ + ParseOptions(const std::string &prefix, ParseOptions *other); + + ParseOptions(const ParseOptions &) = delete; + ParseOptions &operator=(const ParseOptions &) = delete; + ~ParseOptions() = default; + + void Register(const std::string &name, bool *ptr, const std::string &doc); + void Register(const std::string &name, int32_t *ptr, const std::string &doc); + void Register(const std::string &name, uint32_t *ptr, const std::string &doc); + void Register(const std::string &name, float *ptr, const std::string &doc); + void Register(const std::string &name, double *ptr, const std::string &doc); + void Register(const std::string &name, std::string *ptr, + const std::string &doc); + + /// If called after registering an option and before calling + /// Read(), disables that option from being used. Will crash + /// at runtime if that option had not been registered. + void DisableOption(const std::string &name); + + /// This one is used for registering standard parameters of all the programs + template + void RegisterStandard(const std::string &name, T *ptr, + const std::string &doc); + + /** + Parses the command line options and fills the ParseOptions-registered + variables. This must be called after all the variables were registered!!! + + Initially the variables have implicit values, + then the config file values are set-up, + finally the command line values given. + Returns the first position in argv that was not used. + [typically not useful: use NumParams() and GetParam(). ] + */ + int Read(int argc, const char *const *argv); + + /// Prints the usage documentation [provided in the constructor]. + void PrintUsage(bool print_command_line = false) const; + + /// Prints the actual configuration of all the registered variables + void PrintConfig(std::ostream &os) const; + + /// Reads the options values from a config file. Must be called after + /// registering all options. This is usually used internally after the + /// standard --config option is used, but it may also be called from a + /// program. + void ReadConfigFile(const std::string &filename); + + /// Number of positional parameters (c.f. argc-1). + int NumArgs() const; + + /// Returns one of the positional parameters; 1-based indexing for argc/argv + /// compatibility. Will crash if param is not >=1 and <=NumArgs(). + /// + /// Note: Index is 1 based. + std::string GetArg(int param) const; + + std::string GetOptArg(int param) const { + return (param <= NumArgs() ? GetArg(param) : ""); + } + + /// The following function will return a possibly quoted and escaped + /// version of "str", according to the current shell. Currently + /// this is just hardwired to bash. It's useful for debug output. + static std::string Escape(const std::string &str); + + private: + /// Template to register various variable types, + /// used for program-specific parameters + template + void RegisterTmpl(const std::string &name, T *ptr, const std::string &doc); + + // Following functions do just the datatype-specific part of the job + /// Register boolean variable + void RegisterSpecific(const std::string &name, const std::string &idx, + bool *b, const std::string &doc, bool is_standard); + /// Register int32_t variable + void RegisterSpecific(const std::string &name, const std::string &idx, + int32_t *i, const std::string &doc, bool is_standard); + /// Register unsigned int32_t variable + void RegisterSpecific(const std::string &name, const std::string &idx, + uint32_t *u, const std::string &doc, bool is_standard); + /// Register float variable + void RegisterSpecific(const std::string &name, const std::string &idx, + float *f, const std::string &doc, bool is_standard); + /// Register double variable [useful as we change BaseFloat type]. + void RegisterSpecific(const std::string &name, const std::string &idx, + double *f, const std::string &doc, bool is_standard); + /// Register string variable + void RegisterSpecific(const std::string &name, const std::string &idx, + std::string *s, const std::string &doc, + bool is_standard); + + /// Does the actual job for both kinds of parameters + /// Does the common part of the job for all datatypes, + /// then calls RegisterSpecific + template + void RegisterCommon(const std::string &name, T *ptr, const std::string &doc, + bool is_standard); + + /// Set option with name "key" to "value"; will crash if can't do it. + /// "has_equal_sign" is used to allow --x for a boolean option x, + /// and --y=, for a string option y. + bool SetOption(const std::string &key, const std::string &value, + bool has_equal_sign); + + bool ToBool(std::string str) const; + int32_t ToInt(const std::string &str) const; + uint32_t ToUint(const std::string &str) const; + float ToFloat(const std::string &str) const; + double ToDouble(const std::string &str) const; + + // maps for option variables + std::unordered_map bool_map_; + std::unordered_map int_map_; + std::unordered_map uint_map_; + std::unordered_map float_map_; + std::unordered_map double_map_; + std::unordered_map string_map_; + + /** + Structure for options' documentation + */ + struct DocInfo { + DocInfo() = default; + DocInfo(const std::string &name, const std::string &usemsg) + : name_(name), use_msg_(usemsg), is_standard_(false) {} + DocInfo(const std::string &name, const std::string &usemsg, + bool is_standard) + : name_(name), use_msg_(usemsg), is_standard_(is_standard) {} + + std::string name_; + std::string use_msg_; + bool is_standard_; + }; + using DocMapType = std::unordered_map; + DocMapType doc_map_; ///< map for the documentation + + bool print_args_; ///< variable for the implicit --print-args parameter + bool help_; ///< variable for the implicit --help parameter + std::string config_; ///< variable for the implicit --config parameter + std::vector positional_args_; + const char *usage_; + int argc_; + const char *const *argv_; + + /// These members are not normally used. They are only used when the object + /// is constructed with a prefix + std::string prefix_; + ParseOptions *other_parser_; + + protected: + /// SplitLongArg parses an argument of the form --a=b, --a=, or --a, + /// and sets "has_equal_sign" to true if an equals-sign was parsed.. + /// this is needed in order to correctly allow --x for a boolean option + /// x, and --y= for a string option y, and to disallow --x= and --y. + void SplitLongArg(const std::string &in, std::string *key, std::string *value, + bool *has_equal_sign) const; + + void NormalizeArgName(std::string *str) const; + + /// Removes the beginning and trailing whitespaces from a string + void Trim(std::string *str) const; +}; + +/// This template is provided for convenience in reading config classes from +/// files; this is not the standard way to read configuration options, but may +/// occasionally be needed. This function assumes the config has a function +/// "void Register(ParseOptions *opts)" which it can call to register the +/// ParseOptions object. +template +void ReadConfigFromFile(const std::string &config_filename, C *c) { + std::ostringstream usage_str; + usage_str << "Parsing config from " + << "from '" << config_filename << "'"; + ParseOptions po(usage_str.str().c_str()); + c->Register(&po); + po.ReadConfigFile(config_filename); +} + +/// This variant of the template ReadConfigFromFile is for if you need to read +/// two config classes from the same file. +template +void ReadConfigsFromFile(const std::string &conf, C1 *c1, C2 *c2) { + std::ostringstream usage_str; + usage_str << "Parsing config from " + << "from '" << conf << "'"; + ParseOptions po(usage_str.str().c_str()); + c1->Register(&po); + c2->Register(&po); + po.ReadConfigFile(conf); +} + +} // namespace k2 + +#endif // K2_TORCH_CSRC_PARSE_OPTIONS_H_ diff --git a/k2/torch/csrc/parse_options_test.cu b/k2/torch/csrc/parse_options_test.cu new file mode 100644 index 000000000..db1c2fac8 --- /dev/null +++ b/k2/torch/csrc/parse_options_test.cu @@ -0,0 +1,302 @@ +/** + * Copyright 2022 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#include "gtest/gtest.h" +#include "k2/torch/csrc/parse_options.h" + +namespace k2 { + +struct MyOptions { + bool b = false; + int32_t i32 = 1; + uint32_t u32 = 2; + float f = 3; + double d = 4; + std::string s; + + void Register(ParseOptions *po) { + po->Register("my-bool", &b, "A bool variable in MyOptions."); + po->Register("my-i32", &i32, "An int32 variable in MyOptions."); + + po->Register("my-u32", &u32, "An uint32 variable in MyOptions."); + + po->Register("my-f", &f, "A float variable in MyOptions."); + + po->Register("my-d", &d, "A double variable in MyOptions."); + + po->Register("my-s", &s, "A string variable in MyOptions."); + } +}; + +TEST(ParseOptions, FromCommandline) { + int32_t a; + double d; + const char *const argv[] = {"./a.out", "--my-bool=1", "--my-i32=100", + "--my-u32=8", "--my-f=0.5", "--my-d=1.5", + "--my-s=hello", "--a=3", "--d=-1.25", + "--print-args", "foo", "bar"}; + int32_t argc = sizeof(argv) / sizeof(argv[0]); + ParseOptions po("Test parsing from the commandline"); + MyOptions opts; + opts.Register(&po); + po.Register("a", &a, "An integer variable"); + po.Register("d", &d, "A double variable"); + po.Read(argc, argv); + + EXPECT_EQ(a, 3); + EXPECT_EQ(d, -1.25); + EXPECT_EQ(opts.b, true); + EXPECT_EQ(opts.i32, 100); + EXPECT_EQ(opts.u32, 8); + EXPECT_EQ(opts.f, 0.5); + EXPECT_EQ(opts.d, 1.5); + EXPECT_EQ(opts.s, "hello"); + + EXPECT_EQ(po.NumArgs(), 2); + EXPECT_EQ(po.GetArg(1), "foo"); + EXPECT_EQ(po.GetArg(2), "bar"); +} + +TEST(ParseOptions, FromCommandlineWithPrefix) { + int32_t a; + double d; + const char *const argv[] = {"./a.out", + "--print-args", + "--k2.my-bool=1", + "--k2.my-i32=100", + "--k2.my-u32=8", + "--k2.my-f=0.5", + "--k2.my-d=1.5", + "--k2.my-s=hello", + "--a=3", + "--d=-1.25", + "foo", + "bar"}; + int32_t argc = sizeof(argv) / sizeof(argv[0]); + ParseOptions po("Test parsing from the commandline with prefix"); + ParseOptions po2("k2", &po); + MyOptions opts; + opts.Register(&po2); + po.Register("a", &a, "An integer variable"); + po.Register("d", &d, "A double variable"); + po.Read(argc, argv); + + EXPECT_EQ(a, 3); + EXPECT_EQ(d, -1.25); + EXPECT_EQ(opts.b, true); + EXPECT_EQ(opts.i32, 100); + EXPECT_EQ(opts.u32, 8); + EXPECT_EQ(opts.f, 0.5); + EXPECT_EQ(opts.d, 1.5); + EXPECT_EQ(opts.s, "hello"); + + EXPECT_EQ(po.NumArgs(), 2); + EXPECT_EQ(po.GetArg(1), "foo"); + EXPECT_EQ(po.GetArg(2), "bar"); +} + +TEST(ParseOptions, FromCommandlineWithTwoPrefixes) { + int32_t a; + double d; + const char *const argv[] = {"./a.out", + "--print-args", + "--k2.torch.my-bool=1", + "--k2.torch.my-i32=100", + "--k2.torch.my-u32=8", + "--k2.torch.my-f=0.5", + "--k2.torch.my-d=1.5", + "--k2.torch.my-s=hello", + "--a=3", + "--d=-1.25", + "foo", + "bar"}; + int32_t argc = sizeof(argv) / sizeof(argv[0]); + ParseOptions po("Test parsing from the commandline with two prefixes"); + ParseOptions po2("k2", &po); + ParseOptions po3("torch", &po2); + MyOptions opts; + opts.Register(&po3); + po.Register("a", &a, "An integer variable"); + po.Register("d", &d, "A double variable"); + po.Read(argc, argv); + + EXPECT_EQ(a, 3); + EXPECT_EQ(d, -1.25); + EXPECT_EQ(opts.b, true); + EXPECT_EQ(opts.i32, 100); + EXPECT_EQ(opts.u32, 8); + EXPECT_EQ(opts.f, 0.5); + EXPECT_EQ(opts.d, 1.5); + EXPECT_EQ(opts.s, "hello"); + + EXPECT_EQ(po.NumArgs(), 2); + EXPECT_EQ(po.GetArg(1), "foo"); + EXPECT_EQ(po.GetArg(2), "bar"); +} + +TEST(ParseOptions, ParseHelp) { + const char *const argv[] = {"./a.out", "--help"}; + int32_t argc = sizeof(argv) / sizeof(argv[0]); + + ParseOptions po("Parse help"); + MyOptions opts; + opts.Register(&po); + + EXPECT_EXIT(po.Read(argc, argv), testing::ExitedWithCode(0), ""); +} + +TEST(ParseOptions, ParseFromFile) { + std::string filename = "my-options-for-parse-options.txt"; + { + std::ofstream of(filename); + + of << "--my-bool=1\n"; + of << "--my-i32=-100\n"; + of << "--my-s=hello\n"; + } + + const char *const argv[] = { + "./a.out", "--config=my-options-for-parse-options.txt", + "--my-u32=8", "--my-f=0.5", + "--my-d=1.5", "--my-s=world", + "--print-args", "foo", + "bar"}; + + int32_t argc = sizeof(argv) / sizeof(argv[0]); + ParseOptions po("Test parsing from the commandline and config file"); + + MyOptions opts; + opts.Register(&po); + + po.Read(argc, argv); + + EXPECT_EQ(opts.b, true); + EXPECT_EQ(opts.i32, -100); + EXPECT_EQ(opts.u32, 8); + EXPECT_EQ(opts.f, 0.5); + EXPECT_EQ(opts.d, 1.5); + EXPECT_EQ(opts.s, "world"); // commandline options have a higher priority + + remove(filename.c_str()); +} + +TEST(ParseOptions, ParseFromMultipleFiles) { + std::string filename1 = "my-options-for-parse-options1.txt"; + std::string filename2 = "my-options-for-parse-options2.txt"; + { + std::ofstream of(filename1); + + of << "--my-bool=1\n"; + of << "--my-i32=-100\n"; + } + + { + std::ofstream of(filename2); + + of << "--my-s=hello\n"; + } + + const char *const argv[] = {"./a.out", + "--config=my-options-for-parse-options1.txt", + "--config=my-options-for-parse-options2.txt", + "--my-u32=8", + "--my-f=0.5", + "--my-d=1.5", + "--print-args", + "foo", + "bar"}; + + int32_t argc = sizeof(argv) / sizeof(argv[0]); + ParseOptions po("Test parsing from the commandline and config files"); + + MyOptions opts; + opts.Register(&po); + + po.Read(argc, argv); + + EXPECT_EQ(opts.b, true); + EXPECT_EQ(opts.i32, -100); + EXPECT_EQ(opts.u32, 8); + EXPECT_EQ(opts.f, 0.5); + EXPECT_EQ(opts.d, 1.5); + EXPECT_EQ(opts.s, "hello"); + + remove(filename1.c_str()); + remove(filename2.c_str()); +} + +TEST(ParseOptions, Duplicates) { + int32_t a = 10; + int32_t b = 20; + ParseOptions po("Test duplicates"); + po.Register("i", &a, "My integer option"); + po.Register("i", &b, "My integer option"); + // The second one is ignored + const char *const argv[] = {"./a.out", "--i=3"}; + int32_t argc = sizeof(argv) / sizeof(argv[0]); + po.Read(argc, argv); + + EXPECT_EQ(a, 3); + EXPECT_EQ(b, 20); + EXPECT_EQ(po.NumArgs(), 0); +} + +TEST(ParseOptions, DoubleDash) { + int32_t a = 10; + + const char *const argv[] = {"./a.out", "--i=3", "--", "--foo=bar", "baz"}; + int32_t argc = sizeof(argv) / sizeof(argv[0]); + + ParseOptions po("Test double dash"); + po.Register("i", &a, "My integer option"); + po.Read(argc, argv); + + EXPECT_EQ(a, 3); + EXPECT_EQ(po.NumArgs(), 2); + EXPECT_EQ(po.GetArg(1), "--foo=bar"); + EXPECT_EQ(po.GetArg(2), "baz"); +} + +TEST(ReadConfigFromFile, OneOption) { + std::string filename = "my-options-for-parse-options.txt"; + { + std::ofstream of(filename); + + of << "--my-bool=1\n"; + of << "--my-i32=-100\n"; + of << "--my-u32=1000\n"; + of << "--my-f=-0.5\n"; + of << "--my-d=3.5\n"; + of << "--my-s=hello world\n"; + } + MyOptions opts; + ReadConfigFromFile(filename, &opts); + + EXPECT_EQ(opts.b, true); + EXPECT_EQ(opts.i32, -100); + EXPECT_EQ(opts.u32, 1000); + EXPECT_EQ(opts.f, -0.5); + EXPECT_EQ(opts.d, 3.5); + EXPECT_EQ(opts.s, "hello world"); + + remove(filename.c_str()); +} + +} // namespace k2 diff --git a/k2/torch/csrc/symbol_table.cu b/k2/torch/csrc/symbol_table.cu new file mode 100644 index 000000000..015479490 --- /dev/null +++ b/k2/torch/csrc/symbol_table.cu @@ -0,0 +1,78 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "k2/csrc/log.h" +#include "k2/torch/csrc/symbol_table.h" + +namespace k2 { + +SymbolTable::SymbolTable(const std::string &filename) { + std::ifstream is(filename); + std::string sym; + int32_t id; + while (is >> sym >> id) { + if (sym.size() >= 3) { + // For BPE-based models, we replace ▁ with a space + // Unicode 9601, hex 0x2581, utf8 0xe29681 + const uint8_t *p = reinterpret_cast(sym.c_str()); + if (p[0] == 0xe2 && p[1] == 0x96 && p[2] == 0x81) { + sym = sym.replace(0, 3, " "); + } + } + + K2_CHECK(!sym.empty()); + K2_CHECK_EQ(sym2id_.count(sym), 0) << "Duplicated symbol: " << sym; + K2_CHECK_EQ(id2sym_.count(id), 0) << "Duplicated ID: " << id; + + sym2id_.insert({sym, id}); + id2sym_.insert({id, sym}); + } + K2_CHECK(is.eof()); +} + +std::string SymbolTable::ToString() const { + std::ostringstream os; + char sep = ' '; + for (const auto &p : sym2id_) { + os << p.first << sep << p.second << "\n"; + } + return os.str(); +} + +const std::string &SymbolTable::operator[](int32_t id) const { + return id2sym_.at(id); +} + +int32_t SymbolTable::operator[](const std::string &sym) const { + return sym2id_.at(sym); +} + +bool SymbolTable::contains(int32_t id) const { return id2sym_.count(id) != 0; } + +bool SymbolTable::contains(const std::string &sym) const { + return sym2id_.count(sym) != 0; +} + +std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table) { + return os << symbol_table.ToString(); +} + +} // namespace k2 diff --git a/k2/torch/csrc/symbol_table.h b/k2/torch/csrc/symbol_table.h new file mode 100644 index 000000000..662754f78 --- /dev/null +++ b/k2/torch/csrc/symbol_table.h @@ -0,0 +1,61 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_TORCH_CSRC_SYMBOL_TABLE_H_ +#define K2_TORCH_CSRC_SYMBOL_TABLE_H_ + +#include +#include + +namespace k2 { + +/// It manages mapping between symbols and integer IDs. +class SymbolTable { + public: + /// Construct a symbol table from a file. + /// Each line in the file contains two fields: + /// + /// sym ID + /// + /// Fields are separated by space(s). + explicit SymbolTable(const std::string &filename); + + /// Return a string representation of this symbol table + std::string ToString() const; + + /// Return the symbol corresponding to the given ID. + const std::string &operator[](int32_t id) const; + /// Return the ID corresponding to the given symbol. + int32_t operator[](const std::string &sym) const; + + /// Return true if there is a symbol with the given ID. + bool contains(int32_t id) const; + + /// Return true if there is a given symbol in the symbol table. + bool contains(const std::string &sym) const; + + private: + std::unordered_map sym2id_; + std::unordered_map id2sym_; +}; + +std::ostream &operator<<(std::ostream &os, const SymbolTable &symbol_table); + +} // namespace k2 + +#endif // K2_TORCH_CSRC_SYMBOL_TABLE_H_ diff --git a/k2/torch/csrc/test_deserialization_data.h b/k2/torch/csrc/test_deserialization_data.h new file mode 100644 index 000000000..861d0bbe5 --- /dev/null +++ b/k2/torch/csrc/test_deserialization_data.h @@ -0,0 +1,451 @@ +// This file contains pre-generated test data to test +// deserialization. It is used only in deserialization_test.cu +// +// clang-format off + +/* The following array is generated using the following steps: +(1) Python code +``` +import torch +d1 = {"a": torch.Tensor([1, 2]), "b": 10, "c": "k2"} +torch.save(d1, "d1.pt") +``` + +(2) Bash command +``` +bin2c --name kTestLoadData1 d1.pt > xxx.h +``` + +(3) Copy the content in xxx.h to this file + +So kTestLoadData1 contains a dict containing: +- key "a", value: torch.tensor([1, 2], dtype=torch.float32) +- key "b", value: 10 +- key "c", value: "k2" +*/ +static const uint8_t kTestLoadData1[] = { +0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x10,0x00,0x12,0x00,0x61,0x72, +0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2e,0x70,0x6b,0x6c,0x46,0x42, +0x0e,0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x80,0x02,0x7d,0x71,0x00,0x28,0x58,0x01,0x00,0x00,0x00,0x61,0x71,0x01,0x63,0x74, +0x6f,0x72,0x63,0x68,0x2e,0x5f,0x75,0x74,0x69,0x6c,0x73,0x0a,0x5f,0x72,0x65,0x62, +0x75,0x69,0x6c,0x64,0x5f,0x74,0x65,0x6e,0x73,0x6f,0x72,0x5f,0x76,0x32,0x0a,0x71, +0x02,0x28,0x28,0x58,0x07,0x00,0x00,0x00,0x73,0x74,0x6f,0x72,0x61,0x67,0x65,0x71, +0x03,0x63,0x74,0x6f,0x72,0x63,0x68,0x0a,0x46,0x6c,0x6f,0x61,0x74,0x53,0x74,0x6f, +0x72,0x61,0x67,0x65,0x0a,0x71,0x04,0x58,0x0e,0x00,0x00,0x00,0x39,0x34,0x37,0x31, +0x32,0x37,0x31,0x37,0x37,0x32,0x31,0x32,0x39,0x36,0x71,0x05,0x58,0x03,0x00,0x00, +0x00,0x63,0x70,0x75,0x71,0x06,0x4b,0x02,0x74,0x71,0x07,0x51,0x4b,0x00,0x4b,0x02, +0x85,0x71,0x08,0x4b,0x01,0x85,0x71,0x09,0x89,0x63,0x63,0x6f,0x6c,0x6c,0x65,0x63, +0x74,0x69,0x6f,0x6e,0x73,0x0a,0x4f,0x72,0x64,0x65,0x72,0x65,0x64,0x44,0x69,0x63, +0x74,0x0a,0x71,0x0a,0x29,0x52,0x71,0x0b,0x74,0x71,0x0c,0x52,0x71,0x0d,0x58,0x01, +0x00,0x00,0x00,0x62,0x71,0x0e,0x4b,0x0a,0x58,0x01,0x00,0x00,0x00,0x63,0x71,0x0f, +0x58,0x02,0x00,0x00,0x00,0x6b,0x32,0x71,0x10,0x75,0x2e,0x50,0x4b,0x07,0x08,0x22, +0x1f,0x1b,0x8d,0xcb,0x00,0x00,0x00,0xcb,0x00,0x00,0x00,0x50,0x4b,0x03,0x04,0x00, +0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x1b,0x00,0x2c,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65, +0x2f,0x64,0x61,0x74,0x61,0x2f,0x39,0x34,0x37,0x31,0x32,0x37,0x31,0x37,0x37,0x32, +0x31,0x32,0x39,0x36,0x46,0x42,0x28,0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x00,0x00,0x80,0x3f,0x00,0x00,0x00,0x40,0x50,0x4b,0x07,0x08,0x76,0xa5,0x3f,0x2e, +0x08,0x00,0x00,0x00,0x08,0x00,0x00,0x00,0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x0f,0x00,0x3b,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x76,0x65, +0x72,0x73,0x69,0x6f,0x6e,0x46,0x42,0x37,0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x33,0x0a,0x50,0x4b,0x07,0x08,0xd1,0x9e,0x67,0x55,0x02,0x00,0x00,0x00,0x02,0x00, +0x00,0x00,0x50,0x4b,0x01,0x02,0x00,0x00,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00, +0x00,0x00,0x22,0x1f,0x1b,0x8d,0xcb,0x00,0x00,0x00,0xcb,0x00,0x00,0x00,0x10,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2e,0x70,0x6b,0x6c, +0x50,0x4b,0x01,0x02,0x00,0x00,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00, +0x76,0xa5,0x3f,0x2e,0x08,0x00,0x00,0x00,0x08,0x00,0x00,0x00,0x1b,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x1b,0x01,0x00,0x00,0x61,0x72, +0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2f,0x39,0x34,0x37,0x31,0x32, +0x37,0x31,0x37,0x37,0x32,0x31,0x32,0x39,0x36,0x50,0x4b,0x01,0x02,0x00,0x00,0x00, +0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0xd1,0x9e,0x67,0x55,0x02,0x00,0x00, +0x00,0x02,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x98,0x01,0x00,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x76, +0x65,0x72,0x73,0x69,0x6f,0x6e,0x50,0x4b,0x06,0x06,0x2c,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x1e,0x03,0x2d,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x03,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x03,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xc4,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x12,0x02,0x00,0x00,0x00,0x00,0x00,0x00,0x50,0x4b, +0x06,0x07,0x00,0x00,0x00,0x00,0xd6,0x02,0x00,0x00,0x00,0x00,0x00,0x00,0x01,0x00, +0x00,0x00,0x50,0x4b,0x05,0x06,0x00,0x00,0x00,0x00,0x03,0x00,0x03,0x00,0xc4,0x00, +0x00,0x00,0x12,0x02,0x00,0x00,0x00,0x00 +}; + +/* The following array is generated using the following steps: +(1) Python code +``` +import torch +import k2 + +d2 = {"a": torch.Tensor([1, 2]), "b": k2.RaggedTensor([[1.5, 2], [3], []])} +torch.save(d2, "d2.pt") +``` + +(2) Bash command +``` +bin2c --name kTestLoadData2 d2.pt > xxx.h +``` + +(3) Copy the content in xxx.h to this file + +So kTestLoadData2 contains a dict containing: +- key "a", value: torch.tensor([1, 2]) +- key "b", value: k2.RaggedTensor([[15, 2], [3], []]) +*/ +static const uint8_t kTestLoadData2[] = { +0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x10,0x00,0x12,0x00,0x61,0x72, +0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2e,0x70,0x6b,0x6c,0x46,0x42, +0x0e,0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x80,0x02,0x7d,0x71,0x00,0x28,0x58,0x01,0x00,0x00,0x00,0x61,0x71,0x01,0x63,0x74, +0x6f,0x72,0x63,0x68,0x2e,0x5f,0x75,0x74,0x69,0x6c,0x73,0x0a,0x5f,0x72,0x65,0x62, +0x75,0x69,0x6c,0x64,0x5f,0x74,0x65,0x6e,0x73,0x6f,0x72,0x5f,0x76,0x32,0x0a,0x71, +0x02,0x28,0x28,0x58,0x07,0x00,0x00,0x00,0x73,0x74,0x6f,0x72,0x61,0x67,0x65,0x71, +0x03,0x63,0x74,0x6f,0x72,0x63,0x68,0x0a,0x46,0x6c,0x6f,0x61,0x74,0x53,0x74,0x6f, +0x72,0x61,0x67,0x65,0x0a,0x71,0x04,0x58,0x0e,0x00,0x00,0x00,0x39,0x34,0x37,0x33, +0x34,0x31,0x30,0x39,0x37,0x30,0x32,0x38,0x39,0x36,0x71,0x05,0x58,0x03,0x00,0x00, +0x00,0x63,0x70,0x75,0x71,0x06,0x4b,0x02,0x74,0x71,0x07,0x51,0x4b,0x00,0x4b,0x02, +0x85,0x71,0x08,0x4b,0x01,0x85,0x71,0x09,0x89,0x63,0x63,0x6f,0x6c,0x6c,0x65,0x63, +0x74,0x69,0x6f,0x6e,0x73,0x0a,0x4f,0x72,0x64,0x65,0x72,0x65,0x64,0x44,0x69,0x63, +0x74,0x0a,0x71,0x0a,0x29,0x52,0x71,0x0b,0x74,0x71,0x0c,0x52,0x71,0x0d,0x58,0x01, +0x00,0x00,0x00,0x62,0x71,0x0e,0x63,0x5f,0x6b,0x32,0x2e,0x72,0x61,0x67,0x67,0x65, +0x64,0x0a,0x52,0x61,0x67,0x67,0x65,0x64,0x54,0x65,0x6e,0x73,0x6f,0x72,0x0a,0x71, +0x0f,0x29,0x81,0x71,0x10,0x68,0x02,0x28,0x28,0x68,0x03,0x63,0x74,0x6f,0x72,0x63, +0x68,0x0a,0x49,0x6e,0x74,0x53,0x74,0x6f,0x72,0x61,0x67,0x65,0x0a,0x71,0x11,0x58, +0x0e,0x00,0x00,0x00,0x39,0x34,0x37,0x33,0x34,0x31,0x30,0x39,0x37,0x36,0x37,0x38, +0x34,0x30,0x71,0x12,0x68,0x06,0x4b,0x04,0x74,0x71,0x13,0x51,0x4b,0x00,0x4b,0x04, +0x85,0x71,0x14,0x4b,0x01,0x85,0x71,0x15,0x89,0x68,0x0a,0x29,0x52,0x71,0x16,0x74, +0x71,0x17,0x52,0x71,0x18,0x58,0x08,0x00,0x00,0x00,0x72,0x6f,0x77,0x5f,0x69,0x64, +0x73,0x31,0x71,0x19,0x68,0x02,0x28,0x28,0x68,0x03,0x68,0x11,0x58,0x0e,0x00,0x00, +0x00,0x39,0x34,0x37,0x33,0x34,0x31,0x30,0x39,0x37,0x36,0x35,0x36,0x33,0x32,0x71, +0x1a,0x68,0x06,0x4b,0x03,0x74,0x71,0x1b,0x51,0x4b,0x00,0x4b,0x03,0x85,0x71,0x1c, +0x4b,0x01,0x85,0x71,0x1d,0x89,0x68,0x0a,0x29,0x52,0x71,0x1e,0x74,0x71,0x1f,0x52, +0x71,0x20,0x87,0x71,0x21,0x62,0x75,0x2e,0x50,0x4b,0x07,0x08,0x65,0x4b,0xbb,0x5c, +0x78,0x01,0x00,0x00,0x78,0x01,0x00,0x00,0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x1b,0x00,0x3f,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61, +0x74,0x61,0x2f,0x39,0x34,0x37,0x33,0x34,0x31,0x30,0x39,0x37,0x30,0x32,0x38,0x39, +0x36,0x46,0x42,0x3b,0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x00,0x00,0x80,0x3f,0x00,0x00,0x00,0x40,0x50,0x4b,0x07,0x08,0x76,0xa5,0x3f,0x2e, +0x08,0x00,0x00,0x00,0x08,0x00,0x00,0x00,0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x1b,0x00,0x2f,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61, +0x74,0x61,0x2f,0x39,0x34,0x37,0x33,0x34,0x31,0x30,0x39,0x37,0x36,0x35,0x36,0x33, +0x32,0x46,0x42,0x2b,0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x0f,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x03,0x00,0x00,0x00,0x50,0x4b,0x07,0x08, +0x8d,0xf1,0xd1,0x59,0x0c,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x50,0x4b,0x03,0x04, +0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x1b,0x00,0x2b,0x00,0x61,0x72,0x63,0x68,0x69,0x76, +0x65,0x2f,0x64,0x61,0x74,0x61,0x2f,0x39,0x34,0x37,0x33,0x34,0x31,0x30,0x39,0x37, +0x36,0x37,0x38,0x34,0x30,0x46,0x42,0x27,0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x00,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x03,0x00,0x00,0x00,0x03,0x00,0x00,0x00, +0x50,0x4b,0x07,0x08,0xc7,0x7d,0xba,0x9c,0x10,0x00,0x00,0x00,0x10,0x00,0x00,0x00, +0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x0f,0x00,0x33,0x00,0x61,0x72, +0x63,0x68,0x69,0x76,0x65,0x2f,0x76,0x65,0x72,0x73,0x69,0x6f,0x6e,0x46,0x42,0x2f, +0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x33,0x0a,0x50,0x4b,0x07,0x08,0xd1,0x9e,0x67,0x55,0x02,0x00,0x00,0x00,0x02,0x00, +0x00,0x00,0x50,0x4b,0x01,0x02,0x00,0x00,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00, +0x00,0x00,0x65,0x4b,0xbb,0x5c,0x78,0x01,0x00,0x00,0x78,0x01,0x00,0x00,0x10,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2e,0x70,0x6b,0x6c, +0x50,0x4b,0x01,0x02,0x00,0x00,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00, +0x76,0xa5,0x3f,0x2e,0x08,0x00,0x00,0x00,0x08,0x00,0x00,0x00,0x1b,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xc8,0x01,0x00,0x00,0x61,0x72, +0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2f,0x39,0x34,0x37,0x33,0x34, +0x31,0x30,0x39,0x37,0x30,0x32,0x38,0x39,0x36,0x50,0x4b,0x01,0x02,0x00,0x00,0x00, +0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x8d,0xf1,0xd1,0x59,0x0c,0x00,0x00, +0x00,0x0c,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x58,0x02,0x00,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64, +0x61,0x74,0x61,0x2f,0x39,0x34,0x37,0x33,0x34,0x31,0x30,0x39,0x37,0x36,0x35,0x36, +0x33,0x32,0x50,0x4b,0x01,0x02,0x00,0x00,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00, +0x00,0x00,0xc7,0x7d,0xba,0x9c,0x10,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x1b,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xdc,0x02,0x00,0x00, +0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2f,0x39,0x34,0x37, +0x33,0x34,0x31,0x30,0x39,0x37,0x36,0x37,0x38,0x34,0x30,0x50,0x4b,0x01,0x02,0x00, +0x00,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0xd1,0x9e,0x67,0x55,0x02, +0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x60,0x03,0x00,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65, +0x2f,0x76,0x65,0x72,0x73,0x69,0x6f,0x6e,0x50,0x4b,0x06,0x06,0x2c,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x1e,0x03,0x2d,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x05,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x05,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x56,0x01,0x00,0x00,0x00,0x00,0x00,0x00,0xd2,0x03,0x00,0x00,0x00,0x00,0x00,0x00, +0x50,0x4b,0x06,0x07,0x00,0x00,0x00,0x00,0x28,0x05,0x00,0x00,0x00,0x00,0x00,0x00, +0x01,0x00,0x00,0x00,0x50,0x4b,0x05,0x06,0x00,0x00,0x00,0x00,0x05,0x00,0x05,0x00, +0x56,0x01,0x00,0x00,0xd2,0x03,0x00,0x00,0x00,0x00 +}; + +/* The following array is generated using the following steps: +(1) Python code +``` +import torch +import k2 + +d3 = { + "a": torch.tensor([1, 2], device=torch.device("cuda:0")), + "b": k2.RaggedTensor([[15, 2], [3], []], device="cuda:0"), +} +torch.save(d3, "d3.pt") +``` + +(2) Bash command +``` +bin2c --name kTestLoadData3 d3.pt > xxx.h +``` + +(3) Copy the content in xxx.h to this file + +So kTestLoadData3 contains a dict containing: +- key "a", value: torch.tensor([1, 2], device="cuda:0") +- key "b", value: k2.RaggedTensor([[15, 2], [3], []], device="cuda:0") +*/ +static const uint8_t kTestLoadData3[] = { +0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x10,0x00,0x12,0x00,0x61,0x72, +0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2e,0x70,0x6b,0x6c,0x46,0x42, +0x0e,0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x80,0x02,0x7d,0x71,0x00,0x28,0x58,0x01,0x00,0x00,0x00,0x61,0x71,0x01,0x63,0x74, +0x6f,0x72,0x63,0x68,0x2e,0x5f,0x75,0x74,0x69,0x6c,0x73,0x0a,0x5f,0x72,0x65,0x62, +0x75,0x69,0x6c,0x64,0x5f,0x74,0x65,0x6e,0x73,0x6f,0x72,0x5f,0x76,0x32,0x0a,0x71, +0x02,0x28,0x28,0x58,0x07,0x00,0x00,0x00,0x73,0x74,0x6f,0x72,0x61,0x67,0x65,0x71, +0x03,0x63,0x74,0x6f,0x72,0x63,0x68,0x0a,0x4c,0x6f,0x6e,0x67,0x53,0x74,0x6f,0x72, +0x61,0x67,0x65,0x0a,0x71,0x04,0x58,0x0e,0x00,0x00,0x00,0x39,0x34,0x32,0x39,0x31, +0x36,0x39,0x36,0x35,0x32,0x32,0x37,0x35,0x32,0x71,0x05,0x58,0x06,0x00,0x00,0x00, +0x63,0x75,0x64,0x61,0x3a,0x30,0x71,0x06,0x4b,0x02,0x74,0x71,0x07,0x51,0x4b,0x00, +0x4b,0x02,0x85,0x71,0x08,0x4b,0x01,0x85,0x71,0x09,0x89,0x63,0x63,0x6f,0x6c,0x6c, +0x65,0x63,0x74,0x69,0x6f,0x6e,0x73,0x0a,0x4f,0x72,0x64,0x65,0x72,0x65,0x64,0x44, +0x69,0x63,0x74,0x0a,0x71,0x0a,0x29,0x52,0x71,0x0b,0x74,0x71,0x0c,0x52,0x71,0x0d, +0x58,0x01,0x00,0x00,0x00,0x62,0x71,0x0e,0x63,0x5f,0x6b,0x32,0x2e,0x72,0x61,0x67, +0x67,0x65,0x64,0x0a,0x52,0x61,0x67,0x67,0x65,0x64,0x54,0x65,0x6e,0x73,0x6f,0x72, +0x0a,0x71,0x0f,0x29,0x81,0x71,0x10,0x68,0x02,0x28,0x28,0x68,0x03,0x63,0x74,0x6f, +0x72,0x63,0x68,0x0a,0x49,0x6e,0x74,0x53,0x74,0x6f,0x72,0x61,0x67,0x65,0x0a,0x71, +0x11,0x58,0x0e,0x00,0x00,0x00,0x39,0x34,0x32,0x39,0x34,0x30,0x36,0x37,0x34,0x36, +0x33,0x39,0x32,0x30,0x71,0x12,0x58,0x06,0x00,0x00,0x00,0x63,0x75,0x64,0x61,0x3a, +0x30,0x71,0x13,0x4b,0x04,0x74,0x71,0x14,0x51,0x4b,0x00,0x4b,0x04,0x85,0x71,0x15, +0x4b,0x01,0x85,0x71,0x16,0x89,0x68,0x0a,0x29,0x52,0x71,0x17,0x74,0x71,0x18,0x52, +0x71,0x19,0x58,0x08,0x00,0x00,0x00,0x72,0x6f,0x77,0x5f,0x69,0x64,0x73,0x31,0x71, +0x1a,0x68,0x02,0x28,0x28,0x68,0x03,0x68,0x11,0x58,0x0e,0x00,0x00,0x00,0x39,0x34, +0x32,0x39,0x31,0x36,0x39,0x36,0x30,0x35,0x34,0x34,0x38,0x30,0x71,0x1b,0x58,0x06, +0x00,0x00,0x00,0x63,0x75,0x64,0x61,0x3a,0x30,0x71,0x1c,0x4b,0x03,0x74,0x71,0x1d, +0x51,0x4b,0x00,0x4b,0x03,0x85,0x71,0x1e,0x4b,0x01,0x85,0x71,0x1f,0x89,0x68,0x0a, +0x29,0x52,0x71,0x20,0x74,0x71,0x21,0x52,0x71,0x22,0x87,0x71,0x23,0x62,0x75,0x2e, +0x50,0x4b,0x07,0x08,0xb2,0x47,0xda,0xbc,0x90,0x01,0x00,0x00,0x90,0x01,0x00,0x00, +0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x1b,0x00,0x27,0x00,0x61,0x72, +0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2f,0x39,0x34,0x32,0x39,0x31, +0x36,0x39,0x36,0x30,0x35,0x34,0x34,0x38,0x30,0x46,0x42,0x23,0x00,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x0f,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x03,0x00,0x00,0x00,0x50,0x4b,0x07,0x08, +0x8d,0xf1,0xd1,0x59,0x0c,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x50,0x4b,0x03,0x04, +0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x1b,0x00,0x2b,0x00,0x61,0x72,0x63,0x68,0x69,0x76, +0x65,0x2f,0x64,0x61,0x74,0x61,0x2f,0x39,0x34,0x32,0x39,0x31,0x36,0x39,0x36,0x35, +0x32,0x32,0x37,0x35,0x32,0x46,0x42,0x27,0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x01,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x50,0x4b,0x07,0x08,0xb9,0xdd,0xf6,0x00,0x10,0x00,0x00,0x00,0x10,0x00,0x00,0x00, +0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x1b,0x00,0x27,0x00,0x61,0x72, +0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2f,0x39,0x34,0x32,0x39,0x34, +0x30,0x36,0x37,0x34,0x36,0x33,0x39,0x32,0x30,0x46,0x42,0x23,0x00,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x00,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x03,0x00,0x00,0x00,0x03,0x00,0x00,0x00, +0x50,0x4b,0x07,0x08,0xc7,0x7d,0xba,0x9c,0x10,0x00,0x00,0x00,0x10,0x00,0x00,0x00, +0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x0f,0x00,0x33,0x00,0x61,0x72, +0x63,0x68,0x69,0x76,0x65,0x2f,0x76,0x65,0x72,0x73,0x69,0x6f,0x6e,0x46,0x42,0x2f, +0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x33,0x0a,0x50,0x4b,0x07,0x08,0xd1,0x9e,0x67,0x55,0x02,0x00,0x00,0x00,0x02,0x00, +0x00,0x00,0x50,0x4b,0x01,0x02,0x00,0x00,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00, +0x00,0x00,0xb2,0x47,0xda,0xbc,0x90,0x01,0x00,0x00,0x90,0x01,0x00,0x00,0x10,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2e,0x70,0x6b,0x6c, +0x50,0x4b,0x01,0x02,0x00,0x00,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00, +0x8d,0xf1,0xd1,0x59,0x0c,0x00,0x00,0x00,0x0c,0x00,0x00,0x00,0x1b,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xe0,0x01,0x00,0x00,0x61,0x72, +0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2f,0x39,0x34,0x32,0x39,0x31, +0x36,0x39,0x36,0x30,0x35,0x34,0x34,0x38,0x30,0x50,0x4b,0x01,0x02,0x00,0x00,0x00, +0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0xb9,0xdd,0xf6,0x00,0x10,0x00,0x00, +0x00,0x10,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x5c,0x02,0x00,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64, +0x61,0x74,0x61,0x2f,0x39,0x34,0x32,0x39,0x31,0x36,0x39,0x36,0x35,0x32,0x32,0x37, +0x35,0x32,0x50,0x4b,0x01,0x02,0x00,0x00,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00, +0x00,0x00,0xc7,0x7d,0xba,0x9c,0x10,0x00,0x00,0x00,0x10,0x00,0x00,0x00,0x1b,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xe0,0x02,0x00,0x00, +0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2f,0x39,0x34,0x32, +0x39,0x34,0x30,0x36,0x37,0x34,0x36,0x33,0x39,0x32,0x30,0x50,0x4b,0x01,0x02,0x00, +0x00,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0xd1,0x9e,0x67,0x55,0x02, +0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x0f,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x60,0x03,0x00,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65, +0x2f,0x76,0x65,0x72,0x73,0x69,0x6f,0x6e,0x50,0x4b,0x06,0x06,0x2c,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x1e,0x03,0x2d,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x05,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x05,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x56,0x01,0x00,0x00,0x00,0x00,0x00,0x00,0xd2,0x03,0x00,0x00,0x00,0x00,0x00,0x00, +0x50,0x4b,0x06,0x07,0x00,0x00,0x00,0x00,0x28,0x05,0x00,0x00,0x00,0x00,0x00,0x00, +0x01,0x00,0x00,0x00,0x50,0x4b,0x05,0x06,0x00,0x00,0x00,0x00,0x05,0x00,0x05,0x00, +0x56,0x01,0x00,0x00,0xd2,0x03,0x00,0x00,0x00,0x00 +}; + +/* The following array is generated using the following steps: +(1) Python code +``` +import torch +import k2 + +fsa = k2.Fsa.from_str( +""" +0 1 -1 0.1 +1 +""" +) +fsa.aux_labels = k2.RaggedTensor([[1, 2]]) +fsa.attr = torch.tensor([1.5]) +fsa = fsa.to("cuda:0") +torch.save(fsa.as_dict(), "fsa.pt") +``` + +(2) Bash command +``` +bin2c --name kTestLoadData4 fsa.pt > xxx.h +``` + +(3) Copy the content in xxx.h to this file + +So kTestLoadData4 contains a dict containing: +- key "arcs", value: torch.tensor([0, 1, -1, 1036831949], dtype=torch.int32) +- key "aux_labels", value: k2.RaggedTensor([[1, 2]], device='cuda:0', dtype=torch.int32) // NOLINT +- key "attr", value: torch.tensor([1.5], device='cuda:0') +*/ +static const uint8_t kTestLoadData4[] = { +0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x10,0x00,0x12,0x00,0x61,0x72, +0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2e,0x70,0x6b,0x6c,0x46,0x42, +0x0e,0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x80,0x02,0x7d,0x71,0x00,0x28,0x58,0x04,0x00,0x00,0x00,0x61,0x72,0x63,0x73,0x71, +0x01,0x63,0x74,0x6f,0x72,0x63,0x68,0x2e,0x5f,0x75,0x74,0x69,0x6c,0x73,0x0a,0x5f, +0x72,0x65,0x62,0x75,0x69,0x6c,0x64,0x5f,0x74,0x65,0x6e,0x73,0x6f,0x72,0x5f,0x76, +0x32,0x0a,0x71,0x02,0x28,0x28,0x58,0x07,0x00,0x00,0x00,0x73,0x74,0x6f,0x72,0x61, +0x67,0x65,0x71,0x03,0x63,0x74,0x6f,0x72,0x63,0x68,0x0a,0x49,0x6e,0x74,0x53,0x74, +0x6f,0x72,0x61,0x67,0x65,0x0a,0x71,0x04,0x58,0x0e,0x00,0x00,0x00,0x39,0x34,0x36, +0x37,0x37,0x32,0x30,0x33,0x35,0x31,0x34,0x38,0x36,0x34,0x71,0x05,0x58,0x06,0x00, +0x00,0x00,0x63,0x75,0x64,0x61,0x3a,0x30,0x71,0x06,0x4b,0x04,0x74,0x71,0x07,0x51, +0x4b,0x00,0x4b,0x01,0x4b,0x04,0x86,0x71,0x08,0x4b,0x04,0x4b,0x01,0x86,0x71,0x09, +0x89,0x63,0x63,0x6f,0x6c,0x6c,0x65,0x63,0x74,0x69,0x6f,0x6e,0x73,0x0a,0x4f,0x72, +0x64,0x65,0x72,0x65,0x64,0x44,0x69,0x63,0x74,0x0a,0x71,0x0a,0x29,0x52,0x71,0x0b, +0x74,0x71,0x0c,0x52,0x71,0x0d,0x58,0x0a,0x00,0x00,0x00,0x61,0x75,0x78,0x5f,0x6c, +0x61,0x62,0x65,0x6c,0x73,0x71,0x0e,0x63,0x5f,0x6b,0x32,0x2e,0x72,0x61,0x67,0x67, +0x65,0x64,0x0a,0x52,0x61,0x67,0x67,0x65,0x64,0x54,0x65,0x6e,0x73,0x6f,0x72,0x0a, +0x71,0x0f,0x29,0x81,0x71,0x10,0x68,0x02,0x28,0x28,0x68,0x03,0x68,0x04,0x58,0x0e, +0x00,0x00,0x00,0x39,0x34,0x36,0x37,0x37,0x32,0x30,0x33,0x36,0x30,0x39,0x39,0x35, +0x32,0x71,0x11,0x58,0x06,0x00,0x00,0x00,0x63,0x75,0x64,0x61,0x3a,0x30,0x71,0x12, +0x4b,0x02,0x74,0x71,0x13,0x51,0x4b,0x00,0x4b,0x02,0x85,0x71,0x14,0x4b,0x01,0x85, +0x71,0x15,0x89,0x68,0x0a,0x29,0x52,0x71,0x16,0x74,0x71,0x17,0x52,0x71,0x18,0x58, +0x08,0x00,0x00,0x00,0x72,0x6f,0x77,0x5f,0x69,0x64,0x73,0x31,0x71,0x19,0x68,0x02, +0x28,0x28,0x68,0x03,0x68,0x04,0x58,0x0e,0x00,0x00,0x00,0x39,0x34,0x36,0x37,0x37, +0x32,0x30,0x33,0x35,0x31,0x34,0x37,0x36,0x38,0x71,0x1a,0x58,0x06,0x00,0x00,0x00, +0x63,0x75,0x64,0x61,0x3a,0x30,0x71,0x1b,0x4b,0x02,0x74,0x71,0x1c,0x51,0x4b,0x00, +0x4b,0x02,0x85,0x71,0x1d,0x4b,0x01,0x85,0x71,0x1e,0x89,0x68,0x0a,0x29,0x52,0x71, +0x1f,0x74,0x71,0x20,0x52,0x71,0x21,0x87,0x71,0x22,0x62,0x58,0x04,0x00,0x00,0x00, +0x61,0x74,0x74,0x72,0x71,0x23,0x68,0x02,0x28,0x28,0x68,0x03,0x63,0x74,0x6f,0x72, +0x63,0x68,0x0a,0x46,0x6c,0x6f,0x61,0x74,0x53,0x74,0x6f,0x72,0x61,0x67,0x65,0x0a, +0x71,0x24,0x58,0x0e,0x00,0x00,0x00,0x39,0x34,0x36,0x37,0x37,0x32,0x31,0x33,0x34, +0x32,0x36,0x33,0x36,0x38,0x71,0x25,0x58,0x06,0x00,0x00,0x00,0x63,0x75,0x64,0x61, +0x3a,0x30,0x71,0x26,0x4b,0x01,0x74,0x71,0x27,0x51,0x4b,0x00,0x4b,0x01,0x85,0x71, +0x28,0x4b,0x01,0x85,0x71,0x29,0x89,0x68,0x0a,0x29,0x52,0x71,0x2a,0x74,0x71,0x2b, +0x52,0x71,0x2c,0x75,0x2e,0x50,0x4b,0x07,0x08,0x1d,0xd8,0x24,0x72,0xf5,0x01,0x00, +0x00,0xf5,0x01,0x00,0x00,0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x1b, +0x00,0x42,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2f, +0x39,0x34,0x36,0x37,0x37,0x32,0x30,0x33,0x35,0x31,0x34,0x37,0x36,0x38,0x46,0x42, +0x3e,0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x01,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x50,0x4b,0x07,0x08,0x7c,0x17,0x81,0x03, +0x08,0x00,0x00,0x00,0x08,0x00,0x00,0x00,0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x1b,0x00,0x2f,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61, +0x74,0x61,0x2f,0x39,0x34,0x36,0x37,0x37,0x32,0x30,0x33,0x35,0x31,0x34,0x38,0x36, +0x34,0x46,0x42,0x2b,0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x00,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0xff,0xff,0xff,0xff,0xcd,0xcc,0xcc,0x3d, +0x50,0x4b,0x07,0x08,0x00,0xaa,0x76,0xce,0x10,0x00,0x00,0x00,0x10,0x00,0x00,0x00, +0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x1b,0x00,0x27,0x00,0x61,0x72, +0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2f,0x39,0x34,0x36,0x37,0x37, +0x32,0x30,0x33,0x36,0x30,0x39,0x39,0x35,0x32,0x46,0x42,0x23,0x00,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x00,0x00,0x00,0x00,0x02,0x00,0x00,0x00,0x50,0x4b,0x07,0x08,0xe2,0x17,0x2b,0xcf, +0x08,0x00,0x00,0x00,0x08,0x00,0x00,0x00,0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x1b,0x00,0x2f,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61, +0x74,0x61,0x2f,0x39,0x34,0x36,0x37,0x37,0x32,0x31,0x33,0x34,0x32,0x36,0x33,0x36, +0x38,0x46,0x42,0x2b,0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x00,0x00,0xc0,0x3f,0x50,0x4b,0x07,0x08,0x6f,0x25,0xd8,0x5c,0x04,0x00,0x00,0x00, +0x04,0x00,0x00,0x00,0x50,0x4b,0x03,0x04,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x0f,0x00, +0x3f,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x76,0x65,0x72,0x73,0x69,0x6f, +0x6e,0x46,0x42,0x3b,0x00,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a,0x5a, +0x33,0x0a,0x50,0x4b,0x07,0x08,0xd1,0x9e,0x67,0x55,0x02,0x00,0x00,0x00,0x02,0x00, +0x00,0x00,0x50,0x4b,0x01,0x02,0x00,0x00,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00, +0x00,0x00,0x1d,0xd8,0x24,0x72,0xf5,0x01,0x00,0x00,0xf5,0x01,0x00,0x00,0x10,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2e,0x70,0x6b,0x6c, +0x50,0x4b,0x01,0x02,0x00,0x00,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00, +0x7c,0x17,0x81,0x03,0x08,0x00,0x00,0x00,0x08,0x00,0x00,0x00,0x1b,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x45,0x02,0x00,0x00,0x61,0x72, +0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2f,0x39,0x34,0x36,0x37,0x37, +0x32,0x30,0x33,0x35,0x31,0x34,0x37,0x36,0x38,0x50,0x4b,0x01,0x02,0x00,0x00,0x00, +0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0xaa,0x76,0xce,0x10,0x00,0x00, +0x00,0x10,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0xd8,0x02,0x00,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64, +0x61,0x74,0x61,0x2f,0x39,0x34,0x36,0x37,0x37,0x32,0x30,0x33,0x35,0x31,0x34,0x38, +0x36,0x34,0x50,0x4b,0x01,0x02,0x00,0x00,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00, +0x00,0x00,0xe2,0x17,0x2b,0xcf,0x08,0x00,0x00,0x00,0x08,0x00,0x00,0x00,0x1b,0x00, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x60,0x03,0x00,0x00, +0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x64,0x61,0x74,0x61,0x2f,0x39,0x34,0x36, +0x37,0x37,0x32,0x30,0x33,0x36,0x30,0x39,0x39,0x35,0x32,0x50,0x4b,0x01,0x02,0x00, +0x00,0x00,0x00,0x08,0x08,0x00,0x00,0x00,0x00,0x00,0x00,0x6f,0x25,0xd8,0x5c,0x04, +0x00,0x00,0x00,0x04,0x00,0x00,0x00,0x1b,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x00,0x00,0x00,0x00,0xd8,0x03,0x00,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65, +0x2f,0x64,0x61,0x74,0x61,0x2f,0x39,0x34,0x36,0x37,0x37,0x32,0x31,0x33,0x34,0x32, +0x36,0x33,0x36,0x38,0x50,0x4b,0x01,0x02,0x00,0x00,0x00,0x00,0x08,0x08,0x00,0x00, +0x00,0x00,0x00,0x00,0xd1,0x9e,0x67,0x55,0x02,0x00,0x00,0x00,0x02,0x00,0x00,0x00, +0x0f,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x54,0x04, +0x00,0x00,0x61,0x72,0x63,0x68,0x69,0x76,0x65,0x2f,0x76,0x65,0x72,0x73,0x69,0x6f, +0x6e,0x50,0x4b,0x06,0x06,0x2c,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x1e,0x03,0x2d, +0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x06,0x00,0x00,0x00,0x00,0x00,0x00, +0x00,0x06,0x00,0x00,0x00,0x00,0x00,0x00,0x00,0x9f,0x01,0x00,0x00,0x00,0x00,0x00, +0x00,0xd2,0x04,0x00,0x00,0x00,0x00,0x00,0x00,0x50,0x4b,0x06,0x07,0x00,0x00,0x00, +0x00,0x71,0x06,0x00,0x00,0x00,0x00,0x00,0x00,0x01,0x00,0x00,0x00,0x50,0x4b,0x05, +0x06,0x00,0x00,0x00,0x00,0x06,0x00,0x06,0x00,0x9f,0x01,0x00,0x00,0xd2,0x04,0x00, +0x00,0x00,0x00 +}; diff --git a/k2/torch/csrc/test_wave_data.h b/k2/torch/csrc/test_wave_data.h new file mode 100644 index 000000000..06e26cd3c --- /dev/null +++ b/k2/torch/csrc/test_wave_data.h @@ -0,0 +1,23 @@ +// This file is generated by +// +// bin2c --name kTestWav a.wav > test_wave_data.h +// +// where a.wav is generated by the following python code: +// +// import soundfile +// import numpy as np +// a = np.arange(16, dtype=np.int16) +// a[0] = 32767 +// soundfile.write("a.wav", a, 16000, subtype="PCM_16") +// +// It is included in k2/torch/csrc/wave_reader_test.cu +// +// clang-format off + +static const uint8_t kTestWav[] = { +0x52,0x49,0x46,0x46,0x44,0x00,0x00,0x00,0x57,0x41,0x56,0x45,0x66,0x6d,0x74,0x20, +0x10,0x00,0x00,0x00,0x01,0x00,0x01,0x00,0x80,0x3e,0x00,0x00,0x00,0x7d,0x00,0x00, +0x02,0x00,0x10,0x00,0x64,0x61,0x74,0x61,0x20,0x00,0x00,0x00,0xff,0x7f,0x01,0x00, +0x02,0x00,0x03,0x00,0x04,0x00,0x05,0x00,0x06,0x00,0x07,0x00,0x08,0x00,0x09,0x00, +0x0a,0x00,0x0b,0x00,0x0c,0x00,0x0d,0x00,0x0e,0x00,0x0f,0x00 +}; diff --git a/k2/torch/csrc/utils.cu b/k2/torch/csrc/utils.cu new file mode 100644 index 000000000..324ae5f70 --- /dev/null +++ b/k2/torch/csrc/utils.cu @@ -0,0 +1,162 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include "caffe2/serialize/file_adapter.h" +#include "caffe2/serialize/inline_container.h" +#include "k2/csrc/array.h" +#include "k2/torch/csrc/deserialization.h" +#include "k2/torch/csrc/utils.h" + +#if K2_TORCH_VERSION_MAJOR > 1 || \ + (K2_TORCH_VERSION_MAJOR == 1 && K2_TORCH_VERSION_MINOR >= 9) +// for torch::jit::readArchiveAndTensors +#include "torch/csrc/jit/serialization/import_read.h" +#endif + +namespace k2 { + +torch::DeviceType ConvertDeviceType(DeviceType device_type) { + switch (device_type) { + case kCpu: + return torch::kCPU; + case kCuda: + return torch::kCUDA; + default: + K2_LOG(FATAL) << "Unknown device type: " << device_type; + } + // Unreachable code + return torch::kCPU; +} + +DeviceType ConvertDeviceType(torch::DeviceType device_type) { + switch (device_type) { + case torch::kCPU: + return kCpu; + case torch::kCUDA: + return kCuda; + default: + K2_LOG(FATAL) << "Unknown device type: " << device_type; + } + // Unreachable code + return kCpu; +} + +Dtype ConvertDtype(torch::ScalarType scalar_type) { + switch (scalar_type) { + case torch::kFloat: + return kFloatDtype; + case torch::kDouble: + return kDoubleDtype; + case torch::kInt: + return kInt32Dtype; + case torch::kLong: + return kInt64Dtype; + default: + // TODO(fangjun): add other types when needed + K2_LOG(FATAL) << "Unsupported scalar_type: " << scalar_type; + return kInt32Dtype; // unreachable code + } +} + +torch::ScalarType ConvertDtype(Dtype dtype) { + switch (dtype) { + case kFloatDtype: + return torch::kFloat; + case kDoubleDtype: + return torch::kDouble; + case kInt32Dtype: + return torch::kInt; + case kInt64Dtype: + return torch::kLong; + default: + // TODO(fangjun): add other types when needed + K2_LOG(FATAL) << "Unsupported dtype: " << TraitsOf(dtype).Name(); + return torch::ScalarType::Undefined; // unreachable code + } +} + +torch::Device DeviceFromContext(ContextPtr context) { + auto device_type = ConvertDeviceType(context->GetDeviceType()); + int32_t device_id = context->GetDeviceId(); + return torch::Device(device_type, device_id); +} + +ContextPtr ContextFromDevice(torch::Device device) { + torch::DeviceType device_type = device.type(); + + if (device_type == torch::kCPU) return GetCpuContext(); + + K2_CHECK_EQ(device.type(), torch::kCUDA); + return GetCudaContext(device.index()); +} + +template <> +Array1 Array1FromTorch(torch::Tensor tensor) { + K2_CHECK_EQ(tensor.dim(), 2) << "Expected dim: 2. Given: " << tensor.dim(); + K2_CHECK(tensor.dtype().Match()) + << "Expected dtype type: " << caffe2::TypeMeta::Make() + << ". Given: " << tensor.scalar_type(); + + K2_CHECK_EQ(tensor.stride(0), 4) << "Expected stride: 4. " + << "Given: " << tensor.stride(0); + + K2_CHECK_EQ(tensor.stride(1), 1) << "Expected stride: 1. " + << "Given: " << tensor.stride(1); + + K2_CHECK_EQ(tensor.numel() % 4, 0); + + auto region = NewRegion(tensor); + Array1 ans(tensor.numel() / 4, region, 0); + return ans; +} + +Tensor TensorFromTorch(torch::Tensor tensor) { + Dtype dtype = ConvertDtype(tensor.scalar_type()); + torch::IntArrayRef sizes = tensor.sizes(); + torch::IntArrayRef strides = tensor.strides(); + Shape shape({sizes.begin(), sizes.end()}, {strides.begin(), strides.end()}); + + auto region = NewRegion(tensor); + return Tensor(dtype, shape, region, 0); +} + +torch::Tensor TensorToTorch(Tensor &tensor) { + auto device = DeviceFromContext(tensor.Context()); + auto scalar_type = ConvertDtype(tensor.GetDtype()); + auto options = torch::device(device).dtype(scalar_type); + + auto dims_int32 = tensor.Dims(); + auto strides_int32 = tensor.Strides(); + std::vector sizes(dims_int32.begin(), dims_int32.end()); + std::vector strides(strides_int32.begin(), strides_int32.end()); + + // NOTE: we keep a copy of `Region` inside the lambda + // so that `torch::Tensor` always accesses valid memory. + // This prevent the memory managed by k2::Tensor from being freed + // as long as torch::Tensor is alive. + return torch::from_blob( + tensor.Data(), sizes, strides, + [saved_region = tensor.GetRegion()](void *) {}, options); +} + +} // namespace k2 diff --git a/k2/torch/csrc/utils.h b/k2/torch/csrc/utils.h new file mode 100644 index 000000000..8c6125ebe --- /dev/null +++ b/k2/torch/csrc/utils.h @@ -0,0 +1,231 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_TORCH_CSRC_UTILS_H_ +#define K2_TORCH_CSRC_UTILS_H_ + +#include + +#include "k2/csrc/array.h" +#include "k2/csrc/array_ops.h" +#include "k2/csrc/context.h" +#include "k2/csrc/fsa.h" +#include "k2/csrc/pytorch_context.h" +#include "k2/csrc/tensor.h" +#include "k2/csrc/tensor_ops.h" +#include "torch/script.h" + +namespace k2 { + +/** Convert a device type in k2 to torch. + + @param device_type A k2 device type. + @return Return a torch device type. + */ +torch::DeviceType ConvertDeviceType(DeviceType device_type); + +/** Convert a device type in torch to k2. + + @param device_type A torch device type. + @return Return a k2 device type. + */ +DeviceType ConvertDeviceType(torch::DeviceType device_type); + +/** Construct a k2 context from a torch device. + + @param device A torch device. It can be either a CPU device or + a CUDA device. + @return Return a k2 context. + */ +ContextPtr ContextFromDevice(torch::Device device); + +/** Create a torch device from a k2 context. + @param [in] context It must be a CPU or a CUDA context. + @return Return a CPU or a GPU device depending on the given context. + */ +torch::Device DeviceFromContext(ContextPtr context); + +/** Convert torch ScalarType to k2 Dtype + + @param scalar_type A torch ScalarType. + @return Return a k2 Dtype. + */ +Dtype ConvertDtype(torch::ScalarType scalar_type); + +/** Conver k2 Dtype to torch ScalarType. + + @param dtype A k2 Dtype. + @return Return a torch ScalarType + */ +torch::ScalarType ConvertDtype(Dtype dtype); + +inline ContextPtr ContextFromTensor(torch::Tensor tensor) { + return ContextFromDevice(tensor.device()); +} + +template +Array1 Array1FromTorch(torch::Tensor tensor) { + K2_CHECK_EQ(tensor.dim(), 1) << "Expected dim: 1. Given: " << tensor.dim(); + K2_CHECK(tensor.dtype().Match()) + << "Expected dtype type: " << caffe2::TypeMeta::Make() + << ". Given: " << tensor.scalar_type(); + // Some empty tensor may have stride not equal to 1, e.g., tensor returned by + // clone() method, it is valid here, so we won't check its strides. + if (tensor.numel() > 0) + K2_CHECK_EQ(tensor.stride(0), 1) + << "Expected stride: 1. Given: " << tensor.stride(0); + + auto region = NewRegion(tensor); + Array1 ans(tensor.numel(), region, 0); + return ans; +} + +template <> +Array1 Array1FromTorch(torch::Tensor tensor); + +template +Array2 Array2FromTorch(torch::Tensor tensor) { + K2_CHECK_EQ(tensor.dim(), 2) << "Expected dim: 2. Given: " << tensor.dim(); + K2_CHECK(tensor.dtype().Match()) + << "Expected dtype type: " << caffe2::TypeMeta::Make() + << ". Given: " << tensor.scalar_type(); + + K2_CHECK_EQ(tensor.stride(1), 1) + << "Expected stride: 1. Given: " << tensor.stride(1); + + auto region = NewRegion(tensor); + Array2 ans(tensor.size(0), // dim0 + tensor.size(1), // dim1 + tensor.stride(0), // elem_stride0 + 0, // byte_offset + region); // region + return ans; +} + +template +torch::Tensor Array2ToTorch(Array2 &array) { + auto device = DeviceFromContext(array.Context()); + auto scalar_type = caffe2::TypeMeta::Make(); + auto options = torch::device(device).dtype(scalar_type); + + // NOTE: we keep a copy of `Region` inside the lambda + // so that `torch::Tensor` always accesses valid memory. + auto tensor = torch::from_blob( + array.Data(), {array.Dim0(), array.Dim1()}, {array.ElemStride0(), 1}, + [saved_region = array.GetRegion()](void *) {}, options); + return tensor; +} + +/* Convert an Array1 to torch::Tensor. + + @tparam T A primitive type, e.g., int32_t, which has + the corresponding `ToScalarType::value`. + + @param [in] array The input array. + + @return a 1-D torch::Tensor which shares the underlying memory + with the input array. + */ +template +torch::Tensor Array1ToTorch(Array1 &array) { + auto device = DeviceFromContext(array.Context()); + auto scalar_type = caffe2::TypeMeta::Make(); + auto options = torch::device(device).dtype(scalar_type); + // We will call torch::from_blob below. However, if we + // call it with an empty Array1, we'll get error: + // RuntimeError: CUDA error: invalid argument Exception raised from + // getDeviceFromPtr at /pytorch/aten/src/ATen/cuda/CUDADevice.h + // Definitely we need look into this, but let's just return an empty tensor + // when the input Array1 is empty for now. + if (array.Dim() == 0) return torch::empty(0, options); + + // NOTE: we keep a copy of `Region` inside the lambda + // so that `torch::Tensor` always accesses valid memory. + return torch::from_blob( + array.Data(), array.Dim(), [saved_region = array.GetRegion()](void *) {}, + options); +} + +/** Convert torch Tensor to k2 Tensor + + @param tensor A torch Tensor. + @return Return a k2 Tensor. + */ +Tensor TensorFromTorch(torch::Tensor tensor); + +/** Convert k2 Tensor to torch Tensor + + @param tensor A k2 Tensor. + @return Return a torch Tensor. + */ +torch::Tensor TensorToTorch(Tensor &tensor); + +/* Returns a 1-D tensor which indexes the src tensor using entries + from `index`. + + @param [in] src A 1-D tensor. + @param [in] index A 1-D tensor with dtype torch.int32. + It has to satisfy: + -1 <= index[i] < src.numel() + for i in [0, index.numel()) + CAUTION: We require that index.is_contiguous() is true. + @param [in] default_value The value for ans[i] when index[i] is -1. + @return + Returns a 1-D contiguous tensor such that: + ans[i] = src[index[i]] if index[i] > 0 + ans[i] = default_value if index[i] is -1 + */ +template +torch::Tensor IndexSelect(torch::Tensor src, torch::Tensor index, + T default_value) { + K2_CHECK_EQ(src.dim(), 1) << "Expected dim: 1. Given: " << src.dim(); + K2_CHECK(src.dtype().Match()) + << "Expected dtype type: " << caffe2::TypeMeta::Make() + << ". Given: " << src.scalar_type(); + K2_CHECK_EQ(index.dim(), 1) + << "Expected index dim: 1. Given : " << index.dim(); + K2_CHECK(index.dtype().Match()) + << "Expected dtype type: " << caffe2::TypeMeta::Make() + << ". Given: " << index.scalar_type(); + K2_CHECK(index.is_contiguous()) << "Expected contiguous"; + K2_CHECK_EQ(src.device(), index.device()) + << "Expected in the same device" + << " Given : " << src.device() << ", " << index.device(); + + bool allow_minus_one = true; + Array1 index_array = Array1FromTorch(index); + // If index_array.Dim() equals to zero, the `Index` below would produce an + // ans with `ans.Data()` be a nullptr, which will cause crash when calling + // `torch::from_blob`. Just return an empty tensor here. + // If src is an empty tensor, we should return an empty torch. + if (index_array.Dim() == 0 || src.numel() == 0) + return torch::empty({0}, src.options()); + if (src.is_contiguous()) { + Array1 src_array = Array1FromTorch(src); + Array1 ans_array = + Index(src_array, index_array, allow_minus_one, default_value); + return Array1ToTorch(ans_array); + } + Tensor tensor = TensorFromTorch(src); + Tensor ans = Index(tensor, index_array, allow_minus_one, default_value); + return TensorToTorch(ans); +} + +} // namespace k2 + +#endif // K2_TORCH_CSRC_UTILS_H_ diff --git a/k2/torch/csrc/wave_reader.cu b/k2/torch/csrc/wave_reader.cu new file mode 100644 index 000000000..814ca5781 --- /dev/null +++ b/k2/torch/csrc/wave_reader.cu @@ -0,0 +1,140 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include + +#include "k2/csrc/log.h" +#include "k2/torch/csrc/wave_reader.h" + +namespace k2 { +namespace { +// see http://soundfile.sapp.org/doc/WaveFormat/ +// +// Note: We assume little endian here +// TODO(fangjun): Support big endian +struct WaveHeader { + void Validate() const { + // F F I R + K2_CHECK_EQ(chunk_id, 0x46464952); + // E V A W + K2_CHECK_EQ(format, 0x45564157); + K2_CHECK_EQ(subchunk1_id, 0x20746d66); + K2_CHECK_EQ(subchunk1_size, 16); // 16 for PCM + K2_CHECK_EQ(audio_format, 1); // 1 for PCM + K2_CHECK_EQ(num_channels, 1); // we support only single channel for now + K2_CHECK_EQ(byte_rate, sample_rate * num_channels * bits_per_sample / 8) + << "byte_rate: " << byte_rate << ", " + << "sample_rate: " << sample_rate << ", " + << "num_channels: " << num_channels << ", " + << "bits_per_sample: " << bits_per_sample; + K2_CHECK_EQ(block_align, num_channels * bits_per_sample / 8) + << "block_align: " << block_align << ", " + << "num_channels: " << num_channels << ", " + << "bits_per_sample: " << bits_per_sample << ", "; + K2_CHECK_EQ(bits_per_sample, 16); // we support only 16 bits per sample + } + + // See + // https://en.wikipedia.org/wiki/WAV#Metadata + // and + // https://www.robotplanet.dk/audio/wav_meta_data/riff_mci.pdf + void SeekToDataChunk(std::istream &is) { + // a t a d + while (is && subchunk2_id != 0x61746164) { + const char *p = reinterpret_cast(&subchunk2_id); + is.seekg(subchunk2_size, std::istream::cur); + is.read(reinterpret_cast(&subchunk2_id), sizeof(int32_t)); + is.read(reinterpret_cast(&subchunk2_size), sizeof(int32_t)); + } + } + + int32_t chunk_id; + int32_t chunk_size; + int32_t format; + int32_t subchunk1_id; + int32_t subchunk1_size; + int16_t audio_format; + int16_t num_channels; + int32_t sample_rate; + int32_t byte_rate; + int16_t block_align; + int16_t bits_per_sample; + int32_t subchunk2_id; + int32_t subchunk2_size; +}; +static_assert(sizeof(WaveHeader) == 44, ""); + +// Read a wave file of mono-channel. +// Return its samples in a 1-D torch.float32 tensor, normalized +// by dividing 32768. +// Also, it returns the sample rate. +std::pair ReadWaveImpl(std::istream &is) { + WaveHeader header; + is.read(reinterpret_cast(&header), sizeof(header)); + K2_CHECK((bool)is) << "Failed to read wave header"; + + header.Validate(); + + header.SeekToDataChunk(is); + K2_CHECK((bool)is) << "Failed to locate the data chunk"; + + float sample_rate = header.sample_rate; + + // header.subchunk2_size contains the number of bytes in the data. + // As we assume each sample contains two bytes, so it is divided by 2 here + torch::Tensor data = torch::empty({header.subchunk2_size / 2}, torch::kShort); + + is.read(reinterpret_cast(data.data_ptr()), + header.subchunk2_size); + + K2_CHECK((bool)is) << "Failed to read wave samples"; + data = (data / 32768.).to(torch::kFloat32); + return {data, sample_rate}; +} + +} // namespace + +WaveReader::WaveReader(const std::string &filename) { + std::ifstream is(filename, std::ifstream::binary); + std::tie(data_, sample_rate_) = ReadWaveImpl(is); +} + +WaveReader::WaveReader(std::istream &is) { + std::tie(data_, sample_rate_) = ReadWaveImpl(is); +} + +torch::Tensor ReadWave(const std::string &filename, + float expected_sample_rate) { + WaveReader reader(filename); + K2_CHECK_EQ(reader.SampleRate(), expected_sample_rate); + return reader.Data(); +} + +std::vector ReadWave(const std::vector &filenames, + float expected_sample_rate) { + std::vector ans; + ans.reserve(filenames.size()); + for (const auto &path : filenames) { + ans.emplace_back(ReadWave(path, expected_sample_rate)); + } + return ans; +} + +} // namespace k2 diff --git a/k2/torch/csrc/wave_reader.h b/k2/torch/csrc/wave_reader.h new file mode 100644 index 000000000..8e2a6ad61 --- /dev/null +++ b/k2/torch/csrc/wave_reader.h @@ -0,0 +1,81 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef K2_TORCH_CSRC_WAVE_READER_H_ +#define K2_TORCH_CSRC_WAVE_READER_H_ + +#include +#include +#include + +#include "torch/script.h" + +namespace k2 { + +// It supports only mono, i.e., single channel, wave files, encoded +// in PCM format, i.e., raw format, without compression. +// Each sound sample shall be two bytes. +// +// If the above constraints are not satisfied, it throws an exception +// and shows you which constraint was violated. +class WaveReader { + public: + /** Construct a wave reader from a wave filename, encoded in PCM format. + + @param filename Path to a wave file. Must be mono and PCM encoded. + Note: Samples are divided by 32768 so that they are + in the range [-1, 1) + */ + explicit WaveReader(const std::string &filename); + + /** Construct a wave reader from a input stream. + See the help in the above function. You can open a file + with a std::ifstream and pass it to this function. + */ + explicit WaveReader(std::istream &is); + + /// Return a 1-D tensor with dtype torch.float32 + const torch::Tensor &Data() const { return data_; } + + float SampleRate() const { return sample_rate_; } + + private: + /// A 1-D tensor with dtype torch.float32 + torch::Tensor data_; + + float sample_rate_; +}; + +/** Read a wave file with expected sample rate. + + @param filename Path to a wave file. It MUST be single channel, PCM encoded. + @param expected_sample_rate Expected sample rate of the wave file. If the + sample rate don't match, it throws an exception. + + @return Return a 1-D torch tensor with dtype torch.float32. Samples are + normalized to the range [-1, 1). + */ +torch::Tensor ReadWave(const std::string &filename, float expected_sample_rate); + +/// Same `ReadWave` above. It supports reading a list of wave files. +std::vector ReadWave(const std::vector &filenames, + float expected_sample_rate); + +} // namespace k2 + +#endif // K2_TORCH_CSRC_WAVE_READER_H_ diff --git a/k2/torch/csrc/wave_reader_test.cu b/k2/torch/csrc/wave_reader_test.cu new file mode 100644 index 000000000..862e6cfe7 --- /dev/null +++ b/k2/torch/csrc/wave_reader_test.cu @@ -0,0 +1,37 @@ +/** + * Copyright 2021 Xiaomi Corporation (authors: Fangjun Kuang) + * + * See LICENSE for clarification regarding multiple authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "gtest/gtest.h" +#include "k2/csrc/log.h" +#include "k2/torch/csrc/test_wave_data.h" +#include "k2/torch/csrc/wave_reader.h" + +namespace k2 { + +TEST(WaveReader, Mono) { + std::stringstream ss; + ss.write(reinterpret_cast(kTestWav), sizeof(kTestWav)); + WaveReader reader(ss); + torch::Tensor expected = torch::arange(16, torch::kShort); + expected.data_ptr()[0] = 32767; + expected = (expected / 32768.).to(torch::kFloat32); + EXPECT_TRUE(reader.Data().allclose(expected, 1e-6)); + EXPECT_EQ(reader.SampleRate(), 16000); +} + +} // namespace k2 diff --git a/scripts/github_actions/fix_torch.py b/scripts/github_actions/fix_torch.py new file mode 100755 index 000000000..52adb037c --- /dev/null +++ b/scripts/github_actions/fix_torch.py @@ -0,0 +1,46 @@ +#!/usr/bin/env python3 + +import torch +from pathlib import Path +import urllib.request + + +def get_pytorch_version(): + # if it is 1.7.1+cuda101, then strip +cuda101 + return torch.__version__.split("+")[0] + + +def fix_pytorch_1_12(): + print("Fix https://github.com/pytorch/pytorch/issues/88290") + + torch_dir = Path(torch.__file__).parent + print("torch_dir", torch_dir) + mobile_dir = torch_dir / "include" / "torch" / "csrc" / "jit" / "mobile" + if mobile_dir.is_dir(): + print("Skip") + return + mobile_dir.mkdir() + files = ( + "code.h", + "debug_info.h", + "function.h", + "method.h", + "module.h", + ) + base_url = "https://raw.githubusercontent.com/pytorch/pytorch/v1.12.1/torch/csrc/jit/mobile/" # noqa + for f in files: + path = mobile_dir / f + url = base_url + f + print(f"Donwloading {url} to {path}") + urllib.request.urlretrieve(url, path) + + +def main(): + if "1.12" in get_pytorch_version(): + fix_pytorch_1_12() + else: + print(f"Skip since version is {get_pytorch_version()}") + + +if __name__ == "__main__": + main() diff --git a/scripts/github_actions/fix_torch.sh b/scripts/github_actions/fix_torch.sh new file mode 100755 index 000000000..a1ae8b6cc --- /dev/null +++ b/scripts/github_actions/fix_torch.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +# To fix the following issue for torch 1.12.x +# https://github.com/pytorch/pytorch/issues/88290 +if [[ "${torch}" == "1.12.0" || ${torch} == "1.12.1" ]]; then + torch_dir=$(python3 -c "from pathlib import Path; import torch; print(Path(torch.__file__).parent)") + echo "torch_dir: ${torch_dir}" + cd $torch_dir/include/torch/csrc/jit + if [ ! -d mobile ]; then + mkdir mobile + cd mobile + files=( + code.h + debug_info.h + function.h + method.h + module.h + ) + for f in ${files[@]}; do + url=https://raw.githubusercontent.com/pytorch/pytorch/v1.12.1/torch/csrc/jit/mobile/$f + echo "Downloading $url" + wget $url + done + fi +fi diff --git a/scripts/github_actions/k2-torch-api-test/cmake/k2.cmake b/scripts/github_actions/k2-torch-api-test/cmake/k2.cmake index f11534a0b..ece61994b 100644 --- a/scripts/github_actions/k2-torch-api-test/cmake/k2.cmake +++ b/scripts/github_actions/k2-torch-api-test/cmake/k2.cmake @@ -1,11 +1,17 @@ # PYTHON_EXECUTABLE is set by cmake/pybind11.cmake -message(STATUS "Python executable: ${PYTHON_EXECUTABLE}") +if(DEFINED ENV{K2_INSTALL_PREFIX}) + message(STATUS "Using environment variable K2_INSTALL_PREFIX: $ENV{K2_INSTALL_PREFIX}") + set(K2_CMAKE_PREFIX_PATH $ENV{K2_INSTALL_PREFIX}) +else() + # PYTHON_EXECUTABLE is set by cmake/pybind11.cmake + message(STATUS "Python executable: ${PYTHON_EXECUTABLE}") -execute_process( - COMMAND "${PYTHON_EXECUTABLE}" -c "import k2; print(k2.cmake_prefix_path)" - OUTPUT_STRIP_TRAILING_WHITESPACE - OUTPUT_VARIABLE K2_CMAKE_PREFIX_PATH -) + execute_process( + COMMAND "${PYTHON_EXECUTABLE}" -c "import k2; print(k2.cmake_prefix_path)" + OUTPUT_STRIP_TRAILING_WHITESPACE + OUTPUT_VARIABLE K2_CMAKE_PREFIX_PATH + ) +endif() message(STATUS "K2_CMAKE_PREFIX_PATH: ${K2_CMAKE_PREFIX_PATH}") list(APPEND CMAKE_PREFIX_PATH "${K2_CMAKE_PREFIX_PATH}") diff --git a/setup.py b/setup.py index 42c109f24..b1129f827 100644 --- a/setup.py +++ b/setup.py @@ -34,12 +34,13 @@ import glob import os -import setuptools import shutil -from subprocess import DEVNULL, check_call import sys - +import urllib.request from pathlib import Path +from subprocess import DEVNULL, check_call + +import setuptools from setuptools.command.build_ext import build_ext import get_version @@ -51,29 +52,60 @@ is_windows = get_version.is_windows if sys.version_info < (3,): - print('Python 2 has reached end-of-life and is no longer supported by k2.') + print("Python 2 has reached end-of-life and is no longer supported by k2.") sys.exit(-1) if sys.version_info < (3, 6): - print('Python 3.5 has reached end-of-life on September 13th, 2020 ' - 'and is no longer supported by k2.') + print( + "Python 3.5 has reached end-of-life on September 13th, 2020 " + "and is no longer supported by k2." + ) sys.exit(-1) -cmake_path = shutil.which('cmake') -if cmake_path is None: - raise Exception('Please install CMake before you proceed.') +def fix_pytorch_1_12(): + print("Fix https://github.com/pytorch/pytorch/issues/88290") + + import torch + + torch_dir = Path(torch.__file__).parent + print(torch_dir) + mobile_dir = torch_dir / "include" / "torch" / "csrc" / "jit" / "mobile" + if mobile_dir.is_dir(): + print("Skip") + return + mobile_dir.mkdir() + files = ( + "code.h", + "debug_info.h", + "function.h", + "method.h", + "module.h", + ) + base_url = "https://raw.githubusercontent.com/pytorch/pytorch/v1.12.1/torch/csrc/jit/mobile/" # noqa + for f in files: + path = mobile_dir / f + url = base_url + f + print(f"Donwloading {url} to {path}") + urllib.request.urlretrieve(url, path) + + +if "1.12" in get_pytorch_version(): + fix_pytorch_1_12() -ret = check_call(['cmake', '--version'], stdout=DEVNULL, stderr=DEVNULL) +cmake_path = shutil.which("cmake") +if cmake_path is None: + raise Exception("Please install CMake before you proceed.") + +ret = check_call(["cmake", "--version"], stdout=DEVNULL, stderr=DEVNULL) if ret != 0: - raise Exception('Failed to get CMake version') + raise Exception("Failed to get CMake version") try: from wheel.bdist_wheel import bdist_wheel as _bdist_wheel class bdist_wheel(_bdist_wheel): - def finalize_options(self): _bdist_wheel.finalize_options(self) if is_for_pypi() and not is_macos(): @@ -84,20 +116,21 @@ def finalize_options(self): # The generated wheel has a name ending with # -linux_x86_64.whl self.root_is_pure = False + + except ImportError: bdist_wheel = None def cmake_extension(name, *args, **kwargs) -> setuptools.Extension: - kwargs['language'] = 'c++' + kwargs["language"] = "c++" sources = [] return setuptools.Extension(name, sources, *args, **kwargs) class BuildExtension(build_ext): - def build_extension(self, ext: setuptools.extension.Extension): - print(f'cmake_path: {cmake_path}') + print(f"cmake_path: {cmake_path}") # build/temp.linux-x86_64-3.8 os.makedirs(self.build_temp, exist_ok=True) @@ -107,16 +140,18 @@ def build_extension(self, ext: setuptools.extension.Extension): k2_dir = os.path.dirname(os.path.abspath(__file__)) - cmake_args = os.environ.get('K2_CMAKE_ARGS', '') - make_args = os.environ.get('K2_MAKE_ARGS', '') - system_make_args = os.environ.get('MAKEFLAGS', '') + cmake_args = os.environ.get("K2_CMAKE_ARGS", "") + make_args = os.environ.get("K2_MAKE_ARGS", "") + system_make_args = os.environ.get("MAKEFLAGS", "") - extra_cmake_args = ' -DK2_ENABLE_BENCHMARK=OFF ' - extra_cmake_args += ' -DK2_ENABLE_TESTS=OFF ' - extra_cmake_args += f' -DCMAKE_INSTALL_PREFIX={Path(self.build_lib).resolve()}/k2 ' # noqa + extra_cmake_args = " -DK2_ENABLE_BENCHMARK=OFF " + extra_cmake_args += " -DK2_ENABLE_TESTS=OFF " + extra_cmake_args += ( + f" -DCMAKE_INSTALL_PREFIX={Path(self.build_lib).resolve()}/k2 " # noqa + ) - if cmake_args == '': - cmake_args = '-DCMAKE_BUILD_TYPE=Release' + if cmake_args == "": + cmake_args = "-DCMAKE_BUILD_TYPE=Release" if ( make_args == "" @@ -125,45 +160,37 @@ def build_extension(self, ext: setuptools.extension.Extension): ): print("For fast compilation, run:") print('export K2_MAKE_ARGS="-j"; python setup.py install') - make_args = ' -j4 ' + make_args = " -j4 " print("Setting make_args to '-j4'") if is_macos(): - if not 'K2_WITH_CUDA=OFF' in cmake_args: - print('Disable CUDA for macOS') - cmake_args += ' -DK2_WITH_CUDA=OFF' + if not "K2_WITH_CUDA=OFF" in cmake_args: + print("Disable CUDA for macOS") + cmake_args += " -DK2_WITH_CUDA=OFF" - if 'PYTHON_EXECUTABLE' not in cmake_args: - print(f'Setting PYTHON_EXECUTABLE to {sys.executable}') - cmake_args += f' -DPYTHON_EXECUTABLE={sys.executable}' + if "PYTHON_EXECUTABLE" not in cmake_args: + print(f"Setting PYTHON_EXECUTABLE to {sys.executable}") + cmake_args += f" -DPYTHON_EXECUTABLE={sys.executable}" cmake_args += extra_cmake_args if is_windows(): - build_cmd = f''' + build_cmd = f""" cmake {cmake_args} -B {self.build_temp} -S {k2_dir} - cmake --build {self.build_temp} --target _k2 --config Release -- -m - cmake --build {self.build_temp} --target k2_torch_api --config Release -- -m cmake --build {self.build_temp} --target install --config Release -- -m - ''' - print(f'build command is:\n{build_cmd}') - ret = os.system(f'cmake {cmake_args} -B {self.build_temp} -S {k2_dir}') + """ + print(f"build command is:\n{build_cmd}") + ret = os.system(f"cmake {cmake_args} -B {self.build_temp} -S {k2_dir}") if ret != 0: - raise Exception('Failed to build k2') + raise Exception("Failed to build k2") - ret = os.system(f'cmake --build {self.build_temp} --target _k2 --config Release -- -m') + ret = os.system( + f"cmake --build {self.build_temp} --target install --config Release -- -m" + ) if ret != 0: - raise Exception('Failed to build k2') - - ret = os.system(f'cmake --build {self.build_temp} --target k2_torch_api --config Release -- -m') - if ret != 0: - raise Exception('Failed to build k2_torch_api') - - ret = os.system(f'cmake --build {self.build_temp} --target install --config Release -- -m') - if ret != 0: - raise Exception('Failed to build k2') + raise Exception("Failed to build k2") else: - build_cmd = f''' + build_cmd = f""" cd {self.build_temp} cmake {cmake_args} {k2_dir} @@ -171,78 +198,75 @@ def build_extension(self, ext: setuptools.extension.Extension): cat k2/csrc/version.h make {make_args} _k2 k2_torch_api install - ''' - print(f'build command is:\n{build_cmd}') + """ + print(f"build command is:\n{build_cmd}") ret = os.system(build_cmd) if ret != 0: - raise Exception('Failed to build k2') + raise Exception("Failed to build k2") def get_long_description(): - with open('README.md', 'r') as f: + with open("README.md", "r") as f: long_description = f.read() return long_description def get_short_description(): - return 'FSA/FST algorithms, intended to (eventually) be interoperable with PyTorch and similar' + return "FSA/FST algorithms, intended to (eventually) be interoperable with PyTorch and similar" -with open('k2/python/k2/__init__.py', 'a') as f: +with open("k2/python/k2/__init__.py", "a") as f: f.write(f"__dev_version__ = '{get_package_version()}'\n") dev_requirements = [ - 'clang-format==9.0.0', - 'flake8==3.8.3', - 'yapf==0.27.0', + "clang-format==9.0.0", + "flake8==3.8.3", + "yapf==0.27.0", ] install_requires = [ - f'torch=={get_pytorch_version()}', - 'graphviz', + f"torch=={get_pytorch_version()}", + "graphviz", ] setuptools.setup( - python_requires='>=3.6', - name='k2', + python_requires=">=3.6", + name="k2", version=get_package_version(), - author='Daniel Povey', - author_email='dpovey@gmail.com', - keywords='k2, FSA, FST', + author="Daniel Povey", + author_email="dpovey@gmail.com", + keywords="k2, FSA, FST", description=get_short_description(), long_description=get_long_description(), - long_description_content_type='text/markdown', - url='https://github.com/k2-fsa/k2', + long_description_content_type="text/markdown", + url="https://github.com/k2-fsa/k2", package_dir={ - 'k2': 'k2/python/k2', - 'k2.ragged': 'k2/python/k2/ragged', - 'k2.sparse': 'k2/python/k2/sparse', - 'k2.version': 'k2/python/k2/version', + "k2": "k2/python/k2", + "k2.ragged": "k2/python/k2/ragged", + "k2.sparse": "k2/python/k2/sparse", + "k2.version": "k2/python/k2/version", }, - packages=['k2', 'k2.ragged', 'k2.sparse', 'k2.version'], + packages=["k2", "k2.ragged", "k2.sparse", "k2.version"], install_requires=install_requires, - extras_require={'dev': dev_requirements}, - ext_modules=[cmake_extension('_k2')], - cmdclass={ - 'build_ext': BuildExtension, - 'bdist_wheel': bdist_wheel - }, + extras_require={"dev": dev_requirements}, + ext_modules=[cmake_extension("_k2")], + cmdclass={"build_ext": BuildExtension, "bdist_wheel": bdist_wheel}, zip_safe=False, classifiers=[ - 'Development Status :: 3 - Alpha', - 'Programming Language :: Python :: 3', - 'Programming Language :: C++', - 'Programming Language :: Python :: Implementation :: CPython', - 'Topic :: Scientific/Engineering :: Artificial Intelligence', + "Development Status :: 3 - Alpha", + "Programming Language :: Python :: 3", + "Programming Language :: C++", + "Programming Language :: Python :: Implementation :: CPython", + "Topic :: Scientific/Engineering :: Artificial Intelligence", ], ) # remove the line __dev_version__ from k2/python/k2/__init__.py -with open('k2/python/k2/__init__.py', 'r') as f: +with open("k2/python/k2/__init__.py", "r") as f: lines = f.readlines() -with open('k2/python/k2/__init__.py', 'w') as f: +with open("k2/python/k2/__init__.py", "w") as f: for line in lines: - if '__dev_version__' not in line: + if "__dev_version__" not in line: f.write(line)