Skip to content

Commit

Permalink
[src] Fix bug in model-update consolidation code (thanks: sriram gana…
Browse files Browse the repository at this point in the history
…pathy).
  • Loading branch information
danpovey committed Aug 9, 2017
1 parent 53e5e12 commit 4d27deb
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 6 deletions.
8 changes: 4 additions & 4 deletions src/gmm/diag-gmm-test.cc
Expand Up @@ -26,7 +26,7 @@ namespace kaldi {

void InitRandomGmm(DiagGmm *gmm_in) {
int32 num_gauss = 10 + Rand() % 5;
int32 dim = 20 + Rand() % 15;
int32 dim = 10 + Rand() % 10;
DiagGmm &gmm(*gmm_in);
gmm.Resize(num_gauss, dim);
Matrix<BaseFloat> inv_vars(num_gauss, dim),
Expand Down Expand Up @@ -85,9 +85,9 @@ void UnitTestDiagGmmGenerate() {

void UnitTestDiagGmm() {
// random dimension of the gmm
size_t dim = 1 + kaldi::RandInt(0, 9);
size_t dim = 1 + kaldi::RandInt(0, 5);
// random number of mixtures
size_t nMix = 1 + kaldi::RandInt(0, 9);
size_t nMix = 1 + kaldi::RandInt(0, 5);

std::cout << "Testing NumGauss: " << nMix << ", " << "Dim: " << dim
<< '\n';
Expand Down Expand Up @@ -284,7 +284,7 @@ void UnitTestDiagGmm() {
std::vector<std::pair<BaseFloat, const DiagGmm*> > vec;
vec.push_back(std::make_pair(static_cast<BaseFloat>(0.4), (const DiagGmm*)(&gmm1)));
vec.push_back(std::make_pair(static_cast<BaseFloat>(0.6), (const DiagGmm*)(&gmm1)));

DiagGmm gmm2(vec);

float loglike1 = gmm1.LogLikelihood(feat);
Expand Down
7 changes: 5 additions & 2 deletions src/nnet3/nnet-optimize-utils.cc
Expand Up @@ -1235,7 +1235,8 @@ void ModelUpdateConsolidator::ConsolidateModelUpdate() {
int32 num_components = nnet_.NumComponents(),
num_commands = computation_->commands.size();
// 'backprop_commands' is a list, for each component (but nonempty only for
// updatable components), of the command indexes for the backprop commands.
// updatable simple components), of the command indexes for the backprop
// commands.
std::vector<std::vector<int32> > backprop_commands(num_components);
for (int32 command_index = 0;
command_index < num_commands; command_index++) {
Expand All @@ -1244,7 +1245,9 @@ void ModelUpdateConsolidator::ConsolidateModelUpdate() {
int32 component_index = c.arg1;
const Component *component = nnet_.GetComponent(component_index);
int32 properties = component->Properties();
if ((properties & kUpdatableComponent) && !(properties & kUsesMemo))
if ((properties & kUpdatableComponent) &&
(properties & kSimpleComponent) &&
!(properties & kUsesMemo))
backprop_commands[component_index].push_back(command_index);
}
}
Expand Down

0 comments on commit 4d27deb

Please sign in to comment.