Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Rnnt decoding #926

Merged
merged 10 commits into from
Mar 16, 2022
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions k2/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ add_subdirectory(host)
set(context_srcs
algorithms.cu
array_ops.cu
array_of_ragged.cu
pkufool marked this conversation as resolved.
Show resolved Hide resolved
connect.cu
context.cu
dtype.cu
Expand All @@ -65,6 +66,7 @@ set(context_srcs
ragged_utils.cu
rand.cu
rm_epsilon.cu
rnnt_decode.cu
tensor.cu
tensor_ops.cu
thread_pool.cu
Expand Down Expand Up @@ -142,6 +144,7 @@ target_link_libraries(test_utils PUBLIC context gtest)
# please sort the source files alphabetically
set(cuda_test_srcs
algorithms_test.cu
array_of_ragged_test.cu
array_ops_test.cu
array_test.cu
connect_test.cu
Expand All @@ -163,6 +166,7 @@ set(cuda_test_srcs
ragged_utils_test.cu
rand_test.cu
rm_epsilon_test.cu
rnnt_decode_test.cu
tensor_ops_test.cu
tensor_test.cu
thread_pool_test.cu
Expand Down
9 changes: 3 additions & 6 deletions k2/csrc/algorithms.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,12 +119,9 @@ class Renumbering {
return new2old_;
}

/* Return a mapping from new index to old index, with one extra element
containing the total number of kept elements if extra_element == true.
If Keep() can be interpreted as a tails vector, i.e. with 1 at the end
of sub-lists of elements, then New2Old(true) would corresponds to a
row-splits array and Old2New(false) would correspond to a row-ids
array.
/*
Return a mapping from new index to old index, with one extra element
containing the total number of kept elements if extra_element == true.
*/
Array1<int32_t> New2Old(bool extra_element) {
Array1<int32_t> &new2old_part = New2Old();
Expand Down
53 changes: 53 additions & 0 deletions k2/csrc/array_of_ragged.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "k2/csrc/array_of_ragged.h"

namespace k2 {

// Builds the pointer tables for the `num_srcs` ragged shapes stored
// contiguously at `src`.  The tables are filled in on the CPU and then moved
// to the context of the sources; the per-axis total sizes remain on the CPU.
Array1OfRaggedShape::Array1OfRaggedShape(RaggedShape *src, int32_t num_srcs)
    : num_srcs_(num_srcs) {
  K2_CHECK_GE(num_srcs, 1);
  K2_CHECK(src);

  // The first source fixes the axis count and the context that every other
  // source is checked against below.
  num_axes_ = src[0].NumAxes();
  c_ = src[0].Context();

  const int32_t num_layers = num_axes_ - 1;
  row_splits_ = Array2<int32_t *>(GetCpuContext(), num_layers, num_srcs_);
  row_ids_ = Array2<int32_t *>(GetCpuContext(), num_layers, num_srcs_);
  tot_sizes_ = Array1<int32_t>(GetCpuContext(), num_axes_, 0);

  int32_t *axis_totals = tot_sizes_.Data();
  auto splits_table = row_splits_.Accessor();
  auto ids_table = row_ids_.Accessor();

  for (int32_t i = 0; i != num_srcs_; ++i) {
    K2_CHECK_EQ(src[i].NumAxes(), num_axes_);
    K2_CHECK(c_->IsCompatible(*(src[i].Context())));

    RaggedShape &cur = src[i];
    for (int32_t axis = 1; axis < num_axes_; ++axis) {
      // Axis `axis` of a shape is stored in row `axis - 1` of the tables.
      const int32_t layer = axis - 1;
      splits_table(layer, i) = cur.RowSplits(axis).Data();
      ids_table(layer, i) = cur.RowIds(axis).Data();
      axis_totals[axis] += cur.TotSize(axis);
    }
    axis_totals[0] += cur.TotSize(0);
  }

  // Move the finished pointer tables to the context of the sources.
  row_splits_ = row_splits_.To(c_);
  row_ids_ = row_ids_.To(c_);
}

} // namespace k2
191 changes: 191 additions & 0 deletions k2/csrc/array_of_ragged.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
/**
* Copyright 2022 Xiaomi Corporation (authors: Daniel Povey, Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef K2_CSRC_ARRAY_OF_RAGGED_H_
#define K2_CSRC_ARRAY_OF_RAGGED_H_

#include <string>
#include <utility>
#include <vector>

#include "k2/csrc/array.h"
#include "k2/csrc/context.h"
#include "k2/csrc/log.h"
#include "k2/csrc/ragged_ops.h"

namespace k2 {
/*
Array1OfRaggedShape is a convenience class that gives you easy access
pkufool marked this conversation as resolved.
Show resolved Hide resolved
to pointers-of-pointers for an array of ragged shapes.
*/
class Array1OfRaggedShape {
public:
/*
Constructor.
Args:
srcs: pointers to the source shapes, a CPU pointer
num_srcs: the number of source shapes. All shapes must have the
same NumAxes() and must be on the same device.

TODO: we'll likely, later, add optional args which dictate which of
the MetaRowSplits() and MetaRowIds() are to be pre-populated; this should
enable us to save kernels by combining certain operations across the
axes.
*/
Array1OfRaggedShape(RaggedShape *src, int32_t num_srcs);
pkufool marked this conversation as resolved.
Show resolved Hide resolved
Array1OfRaggedShape() = default;

int32_t NumSrcs() const { return num_srcs_; }
int32_t NumAxes() const { return num_axes_; }

ContextPtr &Context() { return c_; }

// Returns device-accessible array of row-splits for the individual shapes,
// indexed [axis-1][src], with 0 <= src < num_srcs. The shape of this
// Array2 is [NumAxes() - 1][NumSrcs()].
Array2<int32_t *> *RowSplits() { return &row_splits_; }
pkufool marked this conversation as resolved.
Show resolved Hide resolved

// Returns device-accessible vector of row-splits for a particular
// axis, indexed by 0 <= src < num_srcs.
int32_t **RowSplits(int32_t axis) { return row_splits_.Row(axis - 1).Data(); }

// Returns device-accessible array of row-ids for the individual shapes
// indexed [axis-1][src], with 0 <= src < num_srcs. The shape of this
// Array2 is [NumAxes() - 1][NumSrcs()].
Array2<int32_t *> *RowIds() { return &row_ids_; }

// Returns device-accessible vector of row-splits for a particular
// axis, indexed by 0 <= src < num_srcs.
int32_t **RowIds(int32_t axis) { return row_ids_.Row(axis - 1).Data(); }

/* Return the total size on this axis, which is the sum of the TotSize() of
the individual shapes. Requires 0 <= axis < NumAxes() and
for axis=0 the returned value is the same as Dim0().
*/
int32_t TotSize(int32_t axis) const { return tot_sizes_[axis]; }

// equivalent to TotSize(0).
int32_t Dim0() const { return TotSize(0); }

/* Return the device-accessible meta-row-splits, which is the cumulative sum,
along the src axis, of the tot-sizes of the individual arrays.
This Array2 is of shape [NumAxes()][NumSrcs() + 1], indexed [axis][src];
caution, the indexing is different from RowSplits(), there is no offset.
Also, the meta_row_splits0 is a thing, unlike with regular row-splits
which start from 1.

Caution: the lengths of the arrays pointed to by the elements of this
Array2 (which contains pointers!) are of course all different, and
these lengths are currently only available

Implementation note: we can probably just populate this on CPU and transfer
to GPU, this will be faster than invoking an extra kernel in normal cases
when the NumSrcs() is small. [Also: see GetRowInfoMulti()].
*/
// TODO: implement it...
Array2<int32_t> MetaRowSplits();

// could POSSIBLY add this so this code could be used in functions like
// Stack(). would be like MetaRowSplits but with an extra 1st row containing
// 0,1,2,... We could perhaps create it with 1 extra initial row so this is
// always convenient to output.
// TODO: implement it...
Array2<int32_t> Offsets();

/*
Returns the meta-row-splits for a particular axis, with 0 <= axis <
NumAxes(); this is the cumulative sum of the TotSize(axis) for all of the
sources, with MetaRowSplits(axis).Dim() == NumSrcs() + 1.

Note: in ragged_ops.cu we refer to this as composed_row_splits
*/
// TODO: implement it...
Array1<int32_t> MetaRowSplits(int32_t axis);

/* Return the device-accessible meta-row-ids, which are the row-ids
corresponding to MetaRowSplits(); this tells us, for indexes into the
appended/concatenated array, which source array they belong to, i.e.
elements are in [0,NumSrcs()-1].

This cannot be an Array2 because unlike the MetaRowSplits(), all the
row-ids arrays are of different lengths.

Note: in ragged_ops.cu we refer to this as composed_row_ids.
*/
// TODO: implement it...
Array1<int32_t *> MetaRowIds();

/*
Returns the meta-row-ids for a particular axis, with 0 <= axis < NumAxes();
this is the row-ids corresponding to MetaRowSplits(axis), and its elements
gives, for indexes into the concatentated shape (concatenated on axis 0),m
which source they come from. E.g. element 100 of MetaRowIds(2)
would tell us which source an idx012 with value 100 into axis 2 of
concatenated array would come from.
*/
// TODO: implement it...
Array1<int32_t> MetaRowIds(int32_t axis);

private:
ContextPtr c_;
int32_t num_srcs_;
int32_t num_axes_;
Array2<int32_t *> row_splits_; // shape [num_axes_ - 1][num_srcs_]
Array2<int32_t *> row_ids_; // shape [num_axes_ - 1][num_srcs_]
Array1<int32_t> tot_sizes_; // dim num_axes_, this is on CPU
};

/*
Array1OfRagged<T> is a 1-dimensional array of Ragged<T>.
pkufool marked this conversation as resolved.
Show resolved Hide resolved
It is intended for situations where you want to do some operations on
arrays of ragged arrays, without explicitly concatenating them (e.g. to
save time). This is a fairly low-level interface, intended to
be used mostly by CUDA/C++ implementation code. It is a convenience
wrapper that saves you the trouble of creating arrays of pointers.
*/
template <typename T>
struct Array1OfRagged {
Array1OfRaggedShape shape;

// Array of the individual values pointers of the source arrays, indexed by
// shape
Array1<T *> values;

int32_t NumSrcs() { return values.Dim(); }
pkufool marked this conversation as resolved.
Show resolved Hide resolved
ContextPtr &Context() { return shape.Context(); }

Array1OfRagged() = default;

Array1OfRagged(Ragged<T> *srcs, int32_t num_srcs) {
pkufool marked this conversation as resolved.
Show resolved Hide resolved
K2_CHECK_GE(num_srcs, 1);
K2_CHECK(srcs);
values = Array1<T *>(GetCpuContext(), num_srcs);
T **values_data = values.Data();
std::vector<RaggedShape> shapes(num_srcs);
for (int32_t i = 0; i < num_srcs; ++i) {
shapes[i] = srcs[i].shape;
values_data[i] = srcs[i].values.Data();
}
shape = Array1OfRaggedShape(shapes.data(), num_srcs);
values = values.To(shape.Context());
}
};

} // namespace k2

#endif // K2_CSRC_ARRAY_OF_RAGGED_H_
79 changes: 79 additions & 0 deletions k2/csrc/array_of_ragged_test.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/**
* Copyright 2022 Xiaomi Corporation (authors: Wei Kang)
*
* See LICENSE for clarification regarding multiple authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <gtest/gtest.h>
pkufool marked this conversation as resolved.
Show resolved Hide resolved

#include "k2/csrc/array_of_ragged.h"
#include "k2/csrc/ragged.h"
#include "k2/csrc/ragged_ops.h"
#include "k2/csrc/ragged_utils.h"
#include "k2/csrc/test_utils.h"

namespace k2 {

template <typename T>
void TestArray1OfRaggedConstruct() {
int32_t num_srcs = 5;
int32_t num_axes = 4;

for (auto &c : {GetCpuContext(), GetCudaContext()}) {
std::vector<Ragged<T>> raggeds;
for (int32_t i = 0; i < num_srcs; ++i) {
raggeds.emplace_back(
RandomRagged<T>(0 /*min_value*/, 100 /*max_value*/,
num_axes /*min_num_axes*/, num_axes /*max_num_axes*/,
0 /*min_num_elements*/, 100 /*max_num_elements*/)
.To(c, true /*copy_all*/));
}
auto array_of_ragged = Array1OfRagged<T>(raggeds.data(), num_srcs);
for (int32_t j = 1; j < num_axes; ++j) {
int32_t **row_splits = array_of_ragged.shape.RowSplits(j);
int32_t **row_ids = array_of_ragged.shape.RowIds(j);
Array1<int32_t *> except_row_splits(GetCpuContext(), num_srcs);
Array1<int32_t *> except_row_ids(GetCpuContext(), num_srcs);
pkufool marked this conversation as resolved.
Show resolved Hide resolved
int32_t **except_row_splits_data = except_row_splits.Data();
int32_t **except_row_ids_data = except_row_ids.Data();
for (int32_t i = 0; i < num_srcs; ++i) {
except_row_splits_data[i] = raggeds[i].RowSplits(j).Data();
except_row_ids_data[i] = raggeds[i].RowIds(j).Data();
}
except_row_splits = except_row_splits.To(c);
except_row_ids = except_row_ids.To(c);
except_row_splits_data = except_row_splits.Data();
except_row_ids_data = except_row_ids.Data();
Array1<int32_t> flags(c, 2, 1);
int32_t *flags_data = flags.Data();
K2_EVAL(
c, num_srcs, lambda_check_pointer, (int32_t i) {
if (row_splits[i] != except_row_splits_data[i]) flags_data[0] = 0;
if (row_ids[i] != except_row_ids_data[i]) flags_data[1] = 0;
});
K2_CHECK(Equal(flags, Array1<int32_t>(c, std::vector<int32_t>{1, 1})));
}
for (int32_t i = 0; i < num_srcs; ++i) {
K2_CHECK_EQ(array_of_ragged.values[i], raggeds[i].values.Data());
}
}
}

// Runs the construction check once with an integer element type and once
// with a floating-point element type.
TEST(Array1OfRagged, Construct) {
TestArray1OfRaggedConstruct<int32_t>();
TestArray1OfRaggedConstruct<float>();
}

} // namespace k2
Loading