Skip to content

Commit

Permalink
Getting shortcut compilation to the point where it's testable (test f…
Browse files Browse the repository at this point in the history
…ailing thouth)
  • Loading branch information
danpovey committed Dec 15, 2016
1 parent 420b2cb commit 83f205a
Show file tree
Hide file tree
Showing 6 changed files with 316 additions and 71 deletions.
6 changes: 3 additions & 3 deletions src/nnet3/nnet-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,9 @@ struct Index {
bool operator < (const Index &a) const {
if (t < a.t) { return true; }
else if (t > a.t) { return false; }
else if (n < a.n) { return true; }
else if (n > a.n) { return false; }
else return (x < a.x);
else if (x < a.x) { return true; }
else if (x > a.x) { return false; }
else return (n < a.n);
}
Index operator + (const Index &other) const {
return Index(n+other.n, t+other.t, x+other.x);
Expand Down
78 changes: 61 additions & 17 deletions src/nnet3/nnet-optimize-test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,12 @@
namespace kaldi {
namespace nnet3 {

// Run the test wothout optimizations and with optimizations specified by the
// parameter. Only print warnings; we'll fail the whole test later.
static bool UnitTestNnetOptimizeWithOptions(NnetOptimizeOptions opt_config) {
// Run the test without optimizations and with optimizations specified by the
// configs (the optimized version is done with class CachingOptimizingCompiler).
// Only print warnings; we'll fail the whole test later.
static bool UnitTestNnetOptimizeWithOptions(NnetOptimizeOptions opt_config,
CachingOptimizingCompilerOptions compiler_config) {

//opt_config.convert_addition = false;
//opt_config.remove_assignments = false;
//opt_config.move_sizing_commands = false;
Expand Down Expand Up @@ -60,20 +63,19 @@ static bool UnitTestNnetOptimizeWithOptions(NnetOptimizeOptions opt_config) {
{
std::ostringstream os;
computation.Print(os, nnet);
KALDI_LOG << "Generated computation is: " << os.str();
KALDI_LOG << "Generated computation with no optimization or shortcut is: " << os.str();
}
CheckComputationOptions check_config;
// we can do the rewrite check since it's before optimization.
check_config.check_rewrite = true;
ComputationChecker checker(check_config, nnet, computation);
checker.Check();

NnetComputation computation_opt(computation);
CachingOptimizingCompiler opt_compiler(nnet, opt_config, compiler_config);

const NnetComputation &computation_opt = *opt_compiler.Compile(request);

{
Optimize(opt_config, nnet,
MaxOutputTimeInRequest(request),
&computation_opt);
std::ostringstream os;
computation_opt.Print(os, nnet);
KALDI_LOG << "Optimized computation is: " << os.str();
Expand All @@ -84,7 +86,8 @@ static bool UnitTestNnetOptimizeWithOptions(NnetOptimizeOptions opt_config) {
compute_opts.debug = true;

computation.ComputeCudaIndexes();
computation_opt.ComputeCudaIndexes();
// computation_opt has already had this function called.

Nnet nnet_to_update(nnet); // copy of the nnet that we update... needed to
// test the consolidation of backprop commands,
// otherwise the optimized and non-optimized
Expand Down Expand Up @@ -179,6 +182,8 @@ static bool UnitTestNnetOptimizeWithOptions(NnetOptimizeOptions opt_config) {
// the outputs are the same.
static void UnitTestNnetOptimize() {
NnetOptimizeOptions optimize_all;
CachingOptimizingCompilerOptions compiler_all;

// randomly sometimes set min_deriv and max_deriv to small/large values,
// which will cause some of the LimitDerivativeTimes() code to be called
// (without really changing anything).
Expand All @@ -187,44 +192,83 @@ static void UnitTestNnetOptimize() {

// this is useful for debugging as it removes nans:
// optimize_all.initialize_undefined = false;
bool success = UnitTestNnetOptimizeWithOptions(optimize_all);
bool success = UnitTestNnetOptimizeWithOptions(optimize_all,
compiler_all);
if (success)
return;

// Test failed with full optimization. Slowly retry with various
// optimizations switched off.
NnetOptimizeOptions optimize = optimize_all;
optimize.propagate_in_place = false;
bool succ_no_propagate_in_place = UnitTestNnetOptimizeWithOptions(optimize);
CachingOptimizingCompilerOptions compiler = compiler_all;


compiler.use_shortcut = false;
bool succ_no_shortcut = UnitTestNnetOptimizeWithOptions(optimize,
compiler);
compiler = compiler_all;


optimize.propagate_in_place = false;
bool succ_no_propagate_in_place = UnitTestNnetOptimizeWithOptions(optimize,
compiler);
optimize = optimize_all;

optimize.backprop_in_place = false;
bool succ_no_backprop_in_place = UnitTestNnetOptimizeWithOptions(optimize);
bool succ_no_backprop_in_place = UnitTestNnetOptimizeWithOptions(optimize,
compiler);
optimize = optimize_all;

optimize.optimize_row_ops = false;
bool succ_no_row_ops = UnitTestNnetOptimizeWithOptions(optimize,
compiler);
optimize = optimize_all;
optimize.remove_assignments = false;
bool succ_no_remove_assignments = UnitTestNnetOptimizeWithOptions(optimize);

optimize.convert_addition = false;
bool succ_no_convert_addition = UnitTestNnetOptimizeWithOptions(optimize,
compiler);
optimize = optimize_all;

optimize.remove_assignments = false;
bool succ_no_remove_assignments = UnitTestNnetOptimizeWithOptions(optimize,
compiler);
optimize = optimize_all;

optimize.initialize_undefined = false;
bool succ_no_initialize_undefined = UnitTestNnetOptimizeWithOptions(optimize);
bool succ_no_initialize_undefined = UnitTestNnetOptimizeWithOptions(optimize,
compiler);
optimize = optimize_all;

optimize.allocate_from_other = false;
bool succ_no_allocate_from_other = UnitTestNnetOptimizeWithOptions(optimize,
compiler);
optimize = optimize_all;

optimize.move_sizing_commands = false;
bool succ_no_move_sizing_commands = UnitTestNnetOptimizeWithOptions(optimize);
bool succ_no_move_sizing_commands = UnitTestNnetOptimizeWithOptions(optimize,
compiler);
optimize = optimize_all;

#define KALDI_SUCCFAIL(b) ((b) ? "SUCCESS" : "FAILURE")
KALDI_ERR
<< "Test failed with all optimizations enabled. Retried test with the "
<< "following optimizations turned off:"
<< "\n use_shortcut ... " << KALDI_SUCCFAIL(succ_no_shortcut)
<< "\n propagate_in_place ... " << KALDI_SUCCFAIL(succ_no_propagate_in_place)
<< "\n backprop_in_place ... " << KALDI_SUCCFAIL(succ_no_backprop_in_place)
<< "\n optimize_row_ops ... " << KALDI_SUCCFAIL(succ_no_row_ops)
<< "\n convert_addition ... " << KALDI_SUCCFAIL(succ_no_convert_addition)
<< "\n remove_assignments ... " << KALDI_SUCCFAIL(succ_no_remove_assignments)
<< "\n initialize_undefined ... " << KALDI_SUCCFAIL(succ_no_initialize_undefined)
<< "\n allocate_from_other ... " << KALDI_SUCCFAIL(succ_no_allocate_from_other)
<< "\n move_sizing_commands ... " << KALDI_SUCCFAIL(succ_no_move_sizing_commands);
#undef KALDI_SUCCFAIL
}





} // namespace nnet3
} // namespace kaldi

Expand Down
124 changes: 124 additions & 0 deletions src/nnet3/nnet-optimize-utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2635,6 +2635,7 @@ void ComputationExpander::InitFastInfo() {
// 'n' value be zero.
KALDI_ASSERT(debug_info.cindexes[0].second.n == 0);
bool is_fast = (debug_info.cindexes[1].second.n == 1);
n_fast_[m] = is_fast;

bool do_check = (RandInt(0, 2) == 0);
if (do_check) {
Expand Down Expand Up @@ -2983,6 +2984,129 @@ void ComputationExpander::GetNewLocationInfo(
}
}


void ExpandComputation(const Nnet &nnet,
const MiscComputationInfo &misc_info,
const NnetComputation &computation,
bool need_debug_info,
int32 num_n_values,
NnetComputation *expanded_computation) {
ComputationExpander expander(nnet, misc_info, computation,
need_debug_info, num_n_values,
expanded_computation);
expander.Expand();
}



// This helper function is used in RequestIsDecomposable(); you can work out
// what it does, and why, from the documentation of RequestIsDecomposable() in
// the header.
static bool IoSpecificationIsDecomposable(const IoSpecification &io_spec,
IoSpecification *mini_io_spec,
int32 *num_n_values_out) {
mini_io_spec->name = io_spec.name;
mini_io_spec->has_deriv = io_spec.has_deriv;
const std::vector<Index> &indexes = io_spec.indexes;
KALDI_ASSERT(!indexes.empty() && "Empty Indexes in computation request");
// For a computation to be decomposable, the 'n' values need to vary from 0 to
// N-1 for some N > 2, and they need to be in some kind of regular order with
// suitable repetition-- either with the 'n' values varying the 'fastest', or
// the 'slowest' of all the indexes.
if (indexes[0].n != 0 || indexes.back().n < 2) {
return false;
}
int32 num_n_values = indexes.back().n + 1,
size = indexes.size();
*num_n_values_out = num_n_values;
if (size % num_n_values != 0)
return false;
bool n_fast = (indexes[1].n == 1);
// if 'n_fast' is true, then the n index varies the fastest (stride == 1),
// otherwise it varies the slowest of any index. We require that it be one of
// these two options, otherwise we declare the computation to be
// non-decomposable.

mini_io_spec->indexes.resize((size / num_n_values) * 2);
if (n_fast) {
// 'block_size' is the size of blocks with the same x,t values, which are
// expected to have n values 0, 1, ... num_n_values - 1.
// of course each block is of size num_n_values.
int32 num_blocks = size / num_n_values;
const Index *indexes_ptr = &(indexes[0]);
Index *indexes_out = &(mini_io_spec->indexes[0]);
for (int32 block = 0; block < num_blocks; block++) {
*(indexes_out++) = indexes_ptr[0]; // for n == 0
*(indexes_out++) = indexes_ptr[1]; // for n == 1.

// we expect all the indexes in this block to have the same x and t
// values, but n values increasing from 0 to num_n_values - 1.
int32 t = indexes_ptr->t, x = indexes_ptr->x;

for (int32 n = 0; n < num_n_values; n++, indexes_ptr++) {
if (indexes_ptr->n != n || indexes_ptr->t != t || indexes_ptr->x != x)
return false;
}
}
} else {
// 'n' varies the slowest.
int32 block_size = size / num_n_values;
mini_io_spec->indexes.clear();
mini_io_spec->indexes.insert(mini_io_spec->indexes.end(),
indexes.begin(),
indexes.begin() + 2 * block_size);

// now verify that it has the expected structure...
for (int32 i = 0; i < block_size; i++) {
const Index *indexes_ptr = &(indexes[i]);
int32 t = indexes_ptr->t, x = indexes_ptr->x;
for (int32 n = 0; n < num_n_values; n++, indexes_ptr += block_size) {
if (indexes_ptr->n != n || indexes_ptr->t != t || indexes_ptr->x != x)
return false;
}
}
}
return true;
}

bool RequestIsDecomposable(const ComputationRequest &request,
ComputationRequest *mini_request,
int32 *num_n_values) {
size_t num_inputs = request.inputs.size(),
num_outputs = request.outputs.size();
mini_request->inputs.resize(num_inputs);
mini_request->outputs.resize(num_outputs);
mini_request->need_model_derivative = request.need_model_derivative;
mini_request->store_component_stats = request.store_component_stats;
mini_request->misc_info = request.misc_info;

KALDI_ASSERT(num_inputs != 0 && num_outputs != 0);
for (size_t i = 0; i < num_inputs; i++) {
int32 this_num_n_values = 0;
if (!IoSpecificationIsDecomposable(request.inputs[i],
&(mini_request->inputs[i]),
&this_num_n_values))
return false;
if (i == 0) {
*num_n_values = this_num_n_values;
} else {
if (this_num_n_values != *num_n_values)
return false; // .. which would be odd.
}
}
for (size_t i = 0; i < num_outputs; i++) {
int32 this_num_n_values = 0;
if (!IoSpecificationIsDecomposable(request.outputs[i],
&(mini_request->outputs[i]),
&this_num_n_values))
return false;
if (this_num_n_values != *num_n_values)
return false; // .. which would be odd.
}
return true;
}


class ComputationLoopedOptimizer {
public:
ComputationLoopedOptimizer(const Nnet &nnet,
Expand Down
6 changes: 3 additions & 3 deletions src/nnet3/nnet-optimize-utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,9 +379,9 @@ void LimitDerivativeTimes(const Nnet &nnet,
reason to the order of the t and x values; the regularity on 'n' is
all that we care about.
*/
bool ComputationIsDecomposable(const ComputationRequest &request,
ComputationRequest *mini_request,
int32 *num_n_values); // TODO: implement this.
bool RequestIsDecomposable(const ComputationRequest &request,
ComputationRequest *mini_request,
int32 *num_n_values);


/**
Expand Down
Loading

0 comments on commit 83f205a

Please sign in to comment.