Variables init and decl in one line; check attr in constructor; check bounds before converting int64 to int; and other minor changes
kaixih committed Dec 5, 2019
1 parent 4a89f04 commit cbf169c
Showing 3 changed files with 29 additions and 133 deletions.
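
The main technical change is guarding every int64-to-int narrowing with a bounds check before the cast (the kernel uses TensorFlow's FastBoundsCheck inside OP_REQUIRES). Below is a minimal standalone sketch of the same pattern, assuming a hypothetical checked_narrow helper rather than any TensorFlow API:

#include <cstdint>
#include <iostream>
#include <limits>
#include <optional>

// Narrow an int64 to int only when the value is in range; callers must
// handle the std::nullopt (out-of-range) case explicitly.
std::optional<int> checked_narrow(int64_t v) {
  if (v < 0 || v > std::numeric_limits<int>::max()) return std::nullopt;
  return static_cast<int>(v);
}

int main() {
  const int64_t batch_size_raw = 1LL << 33;  // does not fit in a 32-bit int
  if (auto batch_size = checked_narrow(batch_size_raw)) {
    std::cout << "batch_size = " << *batch_size << "\n";
  } else {
    // Analogous to the kernel's errors::InvalidArgument path.
    std::cerr << "batch_size cannot exceed max int\n";
  }
  return 0;
}

Moving the attribute checks into the constructor (the other change in this commit) makes the op fail fast: an unsupported attribute combination is rejected once when the op is built, rather than being re-validated on every Compute call.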
55 changes: 29 additions & 26 deletions tensorflow/core/kernels/ctc_loss_op.cc
@@ -222,15 +222,30 @@ REGISTER_CPU(double);

#if GOOGLE_CUDA
class CTCLossOpGPU : public OpKernel {

public:
explicit CTCLossOpGPU(OpKernelConstruction* ctx) : OpKernel(ctx) {
bool preprocess_collapse_repeated_;
bool ctc_merge_repeated_;
bool ignore_longer_outputs_than_inputs_;
OP_REQUIRES_OK(ctx, ctx->GetAttr("preprocess_collapse_repeated",
&preprocess_collapse_repeated_));
OP_REQUIRES_OK(ctx,
ctx->GetAttr("ctc_merge_repeated", &ctc_merge_repeated_));
OP_REQUIRES_OK(ctx, ctx->GetAttr("ignore_longer_outputs_than_inputs",
&ignore_longer_outputs_than_inputs_));

OP_REQUIRES(ctx, !preprocess_collapse_repeated_,
errors::InvalidArgument("GPU CTCLossOp requires "
"preprocess_collapse_repeated to be "
"false"));
OP_REQUIRES(ctx, ctc_merge_repeated_,
errors::InvalidArgument("GPU CTCLossOp requires "
"ctc_merge_repeated_ to be "
"true"));
OP_REQUIRES(ctx, !ignore_longer_outputs_than_inputs_,
errors::InvalidArgument("GPU CTCLossOp requires "
"ignore_longer_outputs_than_inputs_ "
"to be false"));
}

void Compute(OpKernelContext* ctx) override {
@@ -256,6 +271,12 @@ class CTCLossOpGPU : public OpKernel {
const int64 max_time_raw = inputs_shape.dim_size(0);
const int64 batch_size_raw = inputs_shape.dim_size(1);
const int64 num_classes_raw = inputs_shape.dim_size(2);
OP_REQUIRES(
ctx, FastBoundsCheck(max_time_raw, std::numeric_limits<int>::max()),
errors::InvalidArgument("max_time_ cannot exceed max int"));
OP_REQUIRES(
ctx, FastBoundsCheck(batch_size_raw, std::numeric_limits<int>::max()),
errors::InvalidArgument("batch_size cannot exceed max int"));
OP_REQUIRES(
ctx, FastBoundsCheck(num_classes_raw, std::numeric_limits<int>::max()),
errors::InvalidArgument("num_classes cannot exceed max int"));
@@ -279,7 +300,6 @@ class CTCLossOpGPU : public OpKernel {

OP_REQUIRES(ctx, batch_size != 0,
errors::InvalidArgument("batch_size must not be 0"));


Tensor* loss = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output("loss", seq_len->shape(), &loss));
@@ -288,40 +308,26 @@ class CTCLossOpGPU : public OpKernel {
OP_REQUIRES_OK(ctx,
ctx->allocate_output("gradient", inputs_shape, &gradient));

OP_REQUIRES(ctx, preprocess_collapse_repeated_ == false,
errors::InvalidArgument("GPU CTCLossOp requires "
"preprocess_collapse_repeated to be "
"false"));
OP_REQUIRES(ctx, ctc_merge_repeated_ == true,
errors::InvalidArgument("GPU CTCLossOp requires "
"ctc_merge_repeated_ to be "
"true"));
OP_REQUIRES(ctx, ignore_longer_outputs_than_inputs_ == false,
errors::InvalidArgument("GPU CTCLossOp requires "
"ignore_longer_outputs_than_inputs_ to"
"be false"));

// Convert the labels_indices to labels_lengths

// Convert the labels_indices to labels_lengths.
std::vector<int> labels_lengths(batch_size, 0);
DoHistogram<int64>(ctx, labels_indices, num_indices, batch_size,
&labels_lengths);

StreamExecutor* executor = ctx->op_device_context()->stream()->parent();
se::dnn::DataType data_type = ToDataType<float>::value;

se::dnn::CtcLossDescriptor ctc_loss_desc;
std::unique_ptr<RnnStateTensorDescriptor> probs_desc;
std::unique_ptr<RnnStateTensorDescriptor> grads_desc;

auto probs_desc_s = executor->createRnnStateTensorDescriptor(
max_time, batch_size, num_classes, data_type);
OP_REQUIRES_OK(ctx, probs_desc_s.status());
probs_desc = probs_desc_s.ConsumeValueOrDie();
std::unique_ptr<RnnStateTensorDescriptor> probs_desc =
probs_desc_s.ConsumeValueOrDie();

auto grads_desc_s = executor->createRnnStateTensorDescriptor(
max_time, batch_size, num_classes, data_type);
OP_REQUIRES_OK(ctx, grads_desc_s.status());
grads_desc = grads_desc_s.ConsumeValueOrDie();
std::unique_ptr<RnnStateTensorDescriptor> grads_desc =
grads_desc_s.ConsumeValueOrDie();

absl::Span<const int32> labels_data(labels_values->flat<int32>().data(),
num_indices);
@@ -338,6 +344,7 @@ class CTCLossOpGPU : public OpKernel {
DnnScratchAllocator workspace_allocator(1LL << 32, ctx);

Stream* stream = ctx->op_device_context()->stream();
se::dnn::CtcLossDescriptor ctc_loss_desc;
bool cudnn_launch_status =
stream
->ThenCtcLoss(
@@ -353,10 +360,6 @@ class CTCLossOpGPU : public OpKernel {
}

private:
bool preprocess_collapse_repeated_;
bool ctc_merge_repeated_;
bool ignore_longer_outputs_than_inputs_;

TF_DISALLOW_COPY_AND_ASSIGN(CTCLossOpGPU);
};

57 changes: 0 additions & 57 deletions tensorflow/core/util/cudnn_scratch_allocator.cc

This file was deleted.

50 changes: 0 additions & 50 deletions tensorflow/core/util/cudnn_scratch_allocator.h

This file was deleted.
