Construct RaggedArc from unary function tensor (#30)
* Construct RaggedArc from unary function tensor

* Move fsa_from_unary_ragged and fsa_from_binary_tensor to C++

* Add a unit test for the unary-function constructor; add more functions to fsa

* Remove some rubbish code

* Add more unit tests and docs

* Remove the unused code

* Address review comments, propagate attributes in To()

* Change the argument type from RaggedAny to Ragged<int32_t> in autograd function

* Delete declaration for template function

* Apply suggestions from code review

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>

* Fix documentation errors

Co-authored-by: Fangjun Kuang <csukuangfj@gmail.com>
pkufool and csukuangfj committed Sep 30, 2021
1 parent cbff6a1 commit cca7a54
Showing 29 changed files with 1,704 additions and 433 deletions.
16 changes: 7 additions & 9 deletions k2/csrc/fsa.h
@@ -187,12 +187,12 @@ struct DenseFsaVec {
std::ostream &operator<<(std::ostream &os, const DenseFsaVec &dfsavec);

/*
-  Create an FSA from a Tensor. The Tensor t is expected to be an N by 4 tensor of
-  int32_t, where N is the number of arcs (the format is src_state, dest_state,
-  symbol, cost). The cost is not really an int32_t, it is a float. This code
-  will print an error message and output 'true' to 'error', and return an empty
-  FSA (with no states or arcs) if t was not interpretable as a valid FSA.
-  These requirements for a valid FSA are:
+  Create an FSA from a Tensor. The Tensor t is expected to be an N by 4 tensor
+  of int32_t, where N is the number of arcs (the format is src_state,
+  dest_state, symbol, cost). The cost is not really an int32_t, it is a float.
+  This code will print an error message and output 'true' to 'error', and return
+  an empty FSA (with no states or arcs) if t was not interpretable as a valid
+  FSA. These requirements for a valid FSA are:
- src_state values on the arcs must be non-decreasing
- all arcs with -1 as the label must be to a single state (call this
@@ -333,9 +333,7 @@ FsaVec FsaVecFromTensor(Tensor &t, bool *error);
refer to a part of the `values` array of
the input `vec`.
*/
-inline Fsa GetFsaVecElement(FsaVec &vec, int32_t i) {
-  return vec.Index(0, i);
-}
+inline Fsa GetFsaVecElement(FsaVec &vec, int32_t i) { return vec.Index(0, i); }
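
To make the arc layout described in the comment above concrete, here is a minimal sketch of one row of the N by 4 int32_t tensor. ArcRow and MakeArcRow are illustrative names, not k2 APIs; the key point is that the float cost is stored by bit-cast, not by numeric conversion.

#include <cstdint>
#include <cstring>

// One row of the N x 4 tensor: (src_state, dest_state, symbol, cost).
struct ArcRow {
  int32_t src_state;
  int32_t dest_state;
  int32_t symbol;
  int32_t cost_bits;  // float cost reinterpreted as int32_t
};

ArcRow MakeArcRow(int32_t src, int32_t dest, int32_t symbol, float cost) {
  ArcRow row{src, dest, symbol, 0};
  // Copy the raw float bits; a static_cast would change the value.
  std::memcpy(&row.cost_bits, &cost, sizeof(float));
  return row;
}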

/*
Create an FsaVec from a list of Fsas. Caution: Fsa and FsaVec are really
41 changes: 41 additions & 0 deletions k2/python/csrc/CMakeLists.txt
@@ -24,3 +24,44 @@ target_link_libraries(_k2 PRIVATE context)
target_link_libraries(_k2 PRIVATE fsa)
target_include_directories(_k2 PRIVATE ${CMAKE_SOURCE_DIR})
target_include_directories(_k2 PRIVATE ${CMAKE_BINARY_DIR})
set_property(TARGET _k2 PROPERTY CXX_VISIBILITY_PRESET "default")

#---------------------------- Test torch CUDA sources ----------------------------

# please sort the source files alphabetically
set(torch_cuda_test_srcs
  torch/v2/ragged_arc_test.cu
)
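
# Assumption (not shown in this diff): when K2_WITH_CUDA is off, transform()
# rewrites the listed .cu sources (e.g., copies them to .cc) so that they can
# be compiled as plain C++ without nvcc.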
if(NOT K2_WITH_CUDA)
  transform(OUTPUT_VARIABLE torch_cuda_test_srcs SRCS ${torch_cuda_test_srcs})
endif()

# utility function to add a gtest-based test
function(torch_add_cuda_test source)
  get_filename_component(target_name ${source} NAME_WE)
  add_executable(${target_name} "${source}")
  set_target_properties(${target_name} PROPERTIES CUDA_SEPARABLE_COMPILATION ON)
  target_link_libraries(${target_name}
    PRIVATE
      _k2
      context
      fsa
      gtest
  )

  # NOTE: We set the working directory so that the tests also work on
  # Windows: the required DLLs live in ${TORCH_DIR}/lib, and the executable
  # can find them when the current working directory is ${TORCH_DIR}/lib.
  add_test(NAME "Test.Cuda.${target_name}"
    COMMAND
      $<TARGET_FILE:${target_name}>
    WORKING_DIRECTORY ${TORCH_DIR}/lib
  )
endfunction()

foreach(source IN LISTS torch_cuda_test_srcs)
  torch_add_cuda_test(${source})
endforeach()

2 changes: 2 additions & 0 deletions k2/python/csrc/k2.cu
@@ -26,6 +26,8 @@

PYBIND11_MODULE(_k2, m) {
  m.doc() = "pybind11 binding of k2";
  // _k2 depends on torch, so we should import torch before importing _k2.
  py::module_::import("torch");
  PybindVersion(m);
  PybindTorch(m);
}
4 changes: 0 additions & 4 deletions k2/python/csrc/torch.cu
@@ -28,8 +28,6 @@
#include "k2/python/csrc/torch/discounted_cum_sum.h"
#include "k2/python/csrc/torch/fsa.h"
#include "k2/python/csrc/torch/fsa_algo.h"
#include "k2/python/csrc/torch/index_add.h"
#include "k2/python/csrc/torch/index_select.h"
#include "k2/python/csrc/torch/nbest.h"
#include "k2/python/csrc/torch/ragged.h"
#include "k2/python/csrc/torch/ragged_ops.h"
@@ -40,8 +38,6 @@ void PybindTorch(py::module &m) {
PybindDiscountedCumSum(m);
PybindFsa(m);
PybindFsaAlgo(m);
-PybindIndexAdd(m);
-PybindIndexSelect(m);
PybindNbest(m);
PybindRagged(m);
PybindRaggedOps(m);
1 change: 0 additions & 1 deletion k2/python/csrc/torch.h
@@ -24,7 +24,6 @@
#define K2_PYTHON_CSRC_TORCH_H_

#include "k2/csrc/log.h"
#include "k2/python/csrc/torch.h"
#include "torch/extension.h"

namespace pybind11 {
4 changes: 2 additions & 2 deletions k2/python/csrc/torch/CMakeLists.txt
@@ -4,8 +4,6 @@ set(torch_srcs
discounted_cum_sum.cu
fsa.cu
fsa_algo.cu
-index_add.cu
-index_select.cu
nbest.cu
ragged.cu
ragged_ops.cu
@@ -15,6 +13,8 @@
v2/doc/doc.cu
v2/fsa.cu
v2/k2.cu
+v2/k2_ops.cu
+v2/ops.cu
v2/ragged_any.cu
v2/ragged_arc.cu
v2/ragged_shape.cu
71 changes: 0 additions & 71 deletions k2/python/csrc/torch/index_add.cu

This file was deleted.

35 changes: 0 additions & 35 deletions k2/python/csrc/torch/index_add.h

This file was deleted.

21 changes: 2 additions & 19 deletions k2/python/csrc/torch/torch_util.cu
@@ -24,19 +24,6 @@

namespace k2 {

-torch::DeviceType ToTorchDeviceType(DeviceType type) {
-  switch (type) {
-    case kCuda:
-      return torch::kCUDA;
-    case kCpu:
-      return torch::kCPU;
-    case kUnk:  // fall-through
-    default:
-      K2_LOG(FATAL) << "kUnk is not supported!";
-      return torch::kCPU;  // unreachable code
-  }
-}

DeviceType FromTorchDeviceType(const torch::DeviceType &type) {
switch (type) {
case torch::kCUDA:
@@ -86,9 +73,7 @@ torch::ScalarType ScalarTypeFromDtype(Dtype dtype) {

template <>
torch::Tensor ToTorch(Array1<Arc> &array) {
-  auto device_type = ToTorchDeviceType(array.Context()->GetDeviceType());
-  int32_t device_id = array.Context()->GetDeviceId();
-  auto device = torch::Device(device_type, device_id);
+  auto device = GetDevice(array.Context());
auto scalar_type = ToScalarType<int32_t>::value;
// an Arc has 4 members
K2_STATIC_ASSERT(sizeof(Arc) == 4 * sizeof(int32_t));
@@ -134,9 +119,7 @@ Tensor FromTorch(torch::Tensor tensor, TensorTag) {
return Tensor(dtype, shape, region, 0);
}
torch::Tensor ToTorch(Tensor &tensor) {
-  auto device_type = ToTorchDeviceType(tensor.Context()->GetDeviceType());
-  int32_t device_id = tensor.Context()->GetDeviceId();
-  auto device = torch::Device(device_type, device_id);
+  auto device = GetDevice(tensor.Context());
auto scalar_type = ScalarTypeFromDtype(tensor.GetDtype());
auto options = torch::device(device).dtype(scalar_type);
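
For context, a minimal usage sketch of the Array1<Arc> specialization above. It assumes k2 provides GetCpuContext() and an Array1(ContextPtr, std::vector) constructor, and that the include paths match the repo layout; treat those names as assumptions, not part of this diff.

#include <vector>
#include "k2/csrc/array.h"
#include "k2/csrc/context.h"
#include "k2/csrc/fsa.h"
#include "k2/python/csrc/torch/torch_util.h"

void DemoToTorch() {
  k2::ContextPtr c = k2::GetCpuContext();
  std::vector<k2::Arc> arcs = {{0, 1, 2, 0.5f}, {1, 2, -1, 0.25f}};
  k2::Array1<k2::Arc> array(c, arcs);
  // ToTorch views the arcs as an N x 4 int32 tensor sharing the same memory;
  // the float scores in column 3 are bit-cast, not numerically converted.
  torch::Tensor t = k2::ToTorch(array);  // shape {2, 4}, dtype torch::kInt32
}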

25 changes: 24 additions & 1 deletion k2/python/csrc/torch/torch_util.h
@@ -38,7 +38,18 @@ namespace k2 {
@return torch::kCUDA or torch::kCPU.
*/
-torch::DeviceType ToTorchDeviceType(DeviceType type);
+inline torch::DeviceType ToTorchDeviceType(DeviceType type) {
+  switch (type) {
+    case kCuda:
+      return torch::kCUDA;
+    case kCpu:
+      return torch::kCPU;
+    case kUnk:  // fall-through
+    default:
+      K2_LOG(FATAL) << "kUnk is not supported!";
+      return torch::kCPU;  // unreachable code
+  }
+}

/* Convert torch::DeviceType to k2::DeviceType.
Abort on failure.
@@ -252,6 +263,18 @@ PyClass To(PyClass &pyclass, py::object device) {
*/
ContextPtr GetContext(torch::Device device);

/** Create a torch device from a k2 context.
@param [in] context It must be a CPU or a CUDA context.
@return Return a CPU or a GPU device depending on the given context.
*/
inline torch::Device GetDevice(ContextPtr context) {
  auto device_type = ToTorchDeviceType(context->GetDeviceType());
  int32_t device_id = context->GetDeviceId();
  return torch::Device(device_type, device_id);
}

inline ContextPtr GetContext(torch::Tensor tensor) {
  return GetContext(tensor.device());
}
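
A small sketch of how the two helpers compose: create a torch tensor on the device a k2 context lives on, then map the tensor back to a k2 context. GetCpuContext() is assumed to come from k2/csrc/context.h; everything else is defined above.

#include "k2/csrc/context.h"
#include "k2/python/csrc/torch/torch_util.h"

void DemoDeviceRoundTrip() {
  k2::ContextPtr ctx = k2::GetCpuContext();
  torch::Device device = k2::GetDevice(ctx);  // torch::kCPU in this case
  auto options = torch::device(device).dtype(torch::kFloat32);
  torch::Tensor t = torch::zeros({3, 4}, options);
  k2::ContextPtr ctx2 = k2::GetContext(t);  // back to a (CPU) k2 context
}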
