Skip to content

Commit

Permalink
[src] nnet3: Add the "per-frame" option to DropoutComponent (#1324)
Browse files Browse the repository at this point in the history
  • Loading branch information
GaofengCheng authored and danpovey committed Jan 25, 2017
1 parent 82167f9 commit 9208165
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 31 deletions.
72 changes: 54 additions & 18 deletions src/nnet3/nnet-simple-component.cc
Original file line number Diff line number Diff line change
Expand Up @@ -87,27 +87,34 @@ void PnormComponent::Write(std::ostream &os, bool binary) const {
}


// Stores the component's configuration: copies `dim`, `dropout_proportion`
// and `dropout_per_frame` into dim_, dropout_proportion_ and
// dropout_per_frame_ respectively.  No validation is done here.
// NOTE(review): this text is a rendered diff without +/- markers.  The
// one-line signature immediately below looks like the pre-commit version and
// the two-line signature after it like the post-commit version -- confirm
// against the repository before treating either as live code.
void DropoutComponent::Init(int32 dim, BaseFloat dropout_proportion) {
void DropoutComponent::Init(int32 dim, BaseFloat dropout_proportion,
bool dropout_per_frame) {
dropout_proportion_ = dropout_proportion;
dropout_per_frame_ = dropout_per_frame;
dim_ = dim;
}

// Initializes the component from one config line.
//   "dim"                - required; must be > 0.
//   "dropout-proportion" - required; must lie in [0.0, 1.0].
//   "dropout-per-frame"  - optional; keeps its initial value (false) if the
//                          lookup does not set it.
// Reports an error via KALDI_ERR if a required lookup fails, if the config
// line has unused values, or if a value is out of range.
// NOTE(review): this text is a rendered diff without +/- markers, so the
// KALDI_ERR statement and the Init() call each appear twice below (presumably
// the pre-commit and post-commit versions) -- confirm against the repository.
void DropoutComponent::InitFromConfig(ConfigLine *cfl) {
int32 dim = 0;
BaseFloat dropout_proportion = 0.0;
bool dropout_per_frame = false;
// Only these two lookups contribute to `ok`.
bool ok = cfl->GetValue("dim", &dim) &&
cfl->GetValue("dropout-proportion", &dropout_proportion);
cfl->GetValue("dropout-per-frame", &dropout_per_frame);
// The result of the "dropout-per-frame" lookup is deliberately not folded
// into `ok`: when the option is not declared in the config,
// dropout_per_frame stays false, i.e. dropout runs in its normal mode.
if (!ok || cfl->HasUnusedValues() || dim <= 0 ||
dropout_proportion < 0.0 || dropout_proportion > 1.0)
KALDI_ERR << "Invalid initializer for layer of type "
<< Type() << ": \"" << cfl->WholeLine() << "\"";
Init(dim, dropout_proportion);
KALDI_ERR << "Invalid initializer for layer of type "
<< Type() << ": \"" << cfl->WholeLine() << "\"";
Init(dim, dropout_proportion, dropout_per_frame);
}

// Returns a one-line description of the component of the form
//   <Type()>, dim=<dim_>, dropout-proportion=<p>, dropout-per-frame=<true|false>
// where the per-frame flag is spelled out as the text "true" or "false".
// NOTE(review): this text is a rendered diff without +/- markers; the first
// "dropout-proportion" line below (the one ending in ';') looks like the
// pre-commit version that the following two lines replace -- confirm.
std::string DropoutComponent::Info() const {
std::ostringstream stream;
stream << Type() << ", dim=" << dim_
<< ", dropout-proportion=" << dropout_proportion_;
<< ", dropout-proportion=" << dropout_proportion_
<< ", dropout-per-frame=" << (dropout_per_frame_ ? "true" : "false");
return stream.str();
}

Expand All @@ -119,16 +126,29 @@ void DropoutComponent::Propagate(const ComponentPrecomputedIndexes *indexes,

BaseFloat dropout = dropout_proportion_;
KALDI_ASSERT(dropout >= 0.0 && dropout <= 1.0);
if (!dropout_per_frame_) {
// This const_cast is only safe assuming you don't attempt
// to use multi-threaded code with the GPU.
const_cast<CuRand<BaseFloat>&>(random_generator_).RandUniform(out);

// This const_cast is only safe assuming you don't attempt
// to use multi-threaded code with the GPU.
const_cast<CuRand<BaseFloat>&>(random_generator_).RandUniform(out);
out->Add(-dropout); // now, a proportion "dropout" will be <0.0
// apply the function (x>0?1:0). Now, a proportion
// "dropout" will be zero and (1 - dropout) will be 1.0.
out->ApplyHeaviside();

out->Add(-dropout); // now, a proportion "dropout" will be <0.0
out->ApplyHeaviside(); // apply the function (x>0?1:0). Now, a proportion "dropout" will
// be zero and (1 - dropout) will be 1.0.

out->MulElements(in);
out->MulElements(in);
} else {
// randomize the dropout matrix by row,
// i.e. [[1,1,1,1],[0,0,0,0],[0,0,0,0],[1,1,1,1],[0,0,0,0]]
CuMatrix<BaseFloat> tmp(1, out->NumRows(), kUndefined);
// This const_cast is only safe assuming you don't attempt
// to use multi-threaded code with the GPU.
const_cast<CuRand<BaseFloat>&>(random_generator_).RandUniform(&tmp);
tmp.Add(-dropout);
tmp.ApplyHeaviside();
out->CopyColsFromVec(tmp.Row(0));
out->MulElements(in);
}
}


Expand All @@ -150,11 +170,25 @@ void DropoutComponent::Backprop(const std::string &debug_info,


// Deserializes the component from `is` (binary or text mode per `binary`).
// Token order handled by the ReadToken-based code:
//   [<DropoutComponent>] <Dim> dim_ <DropoutProportion> dropout_proportion_
//   [<DropoutPerFrame> dropout_per_frame_] </DropoutComponent>
// The opening <DropoutComponent> token and the <DropoutPerFrame> field are
// both optional; when <DropoutPerFrame> is absent dropout_per_frame_ is set
// to false.
// NOTE(review): this text is a rendered diff without +/- markers.  The five
// ExpectOneOrTwoTokens/ExpectToken-based lines at the top of the body look
// like the pre-commit implementation and everything from `std::string token;`
// onward like its replacement -- confirm against the repository.
// NOTE(review): the replacement validates tokens with KALDI_ASSERT rather
// than ExpectToken.  If KALDI_ASSERT can be compiled out in some build
// configurations, a malformed stream would go undetected there -- TODO
// confirm, and consider keeping ExpectToken for the fixed tokens.
void DropoutComponent::Read(std::istream &is, bool binary) {
ExpectOneOrTwoTokens(is, binary, "<DropoutComponent>", "<Dim>");
ReadBasicType(is, binary, &dim_);
ExpectToken(is, binary, "<DropoutProportion>");
ReadBasicType(is, binary, &dropout_proportion_);
ExpectToken(is, binary, "</DropoutComponent>");
std::string token;
ReadToken(is, binary, &token);
// The opening tag may or may not be present; if it is, skip past it so
// that `token` holds what should be "<Dim>".
if (token == "<DropoutComponent>") {
ReadToken(is, binary, &token);
}
KALDI_ASSERT(token == "<Dim>");
ReadBasicType(is, binary, &dim_); // read dimension.
ReadToken(is, binary, &token);
KALDI_ASSERT(token == "<DropoutProportion>");
ReadBasicType(is, binary, &dropout_proportion_); // read dropout rate
ReadToken(is, binary, &token);
// The next token is either the optional per-frame field or the closing tag.
if (token == "<DropoutPerFrame>") {
ReadBasicType(is, binary, &dropout_per_frame_); // read dropout mode
ReadToken(is, binary, &token);
KALDI_ASSERT(token == "</DropoutComponent>");
} else {
// No <DropoutPerFrame> field in the stream: fall back to false.
dropout_per_frame_ = false;
KALDI_ASSERT(token == "</DropoutComponent>");
}
}

void DropoutComponent::Write(std::ostream &os, bool binary) const {
Expand All @@ -163,6 +197,8 @@ void DropoutComponent::Write(std::ostream &os, bool binary) const {
WriteBasicType(os, binary, dim_);
WriteToken(os, binary, "<DropoutProportion>");
WriteBasicType(os, binary, dropout_proportion_);
WriteToken(os, binary, "<DropoutPerFrame>");
WriteBasicType(os, binary, dropout_per_frame_);
WriteToken(os, binary, "</DropoutComponent>");
}

Expand Down
20 changes: 14 additions & 6 deletions src/nnet3/nnet-simple-component.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,16 @@ class PnormComponent: public Component {
// "Dropout: A Simple Way to Prevent Neural Networks from Overfitting".
// NOTE(review): this class is shown as a rendered diff without +/- markers:
// several declarations appear twice (presumably pre-commit then post-commit)
// and part of the class body is hidden behind a collapsed-hunk marker.
class DropoutComponent : public RandomComponent {
public:
void Init(int32 dim, BaseFloat dropout_proportion = 0.0);
void Init(int32 dim, BaseFloat dropout_proportion = 0.0,
bool dropout_per_frame = false);

DropoutComponent(int32 dim, BaseFloat dropout = 0.0) { Init(dim, dropout); }
DropoutComponent(int32 dim, BaseFloat dropout = 0.0,
bool dropout_per_frame = false) {
Init(dim, dropout, dropout_per_frame);
}

DropoutComponent(): dim_(0), dropout_proportion_(0.0) { }
DropoutComponent(): dim_(0), dropout_proportion_(0.0),
dropout_per_frame_(false) { }

virtual int32 Properties() const {
return kLinearInInput|kBackpropInPlace|kSimpleComponent|kBackpropNeedsInput|kBackpropNeedsOutput;
Expand Down Expand Up @@ -120,17 +125,20 @@ class DropoutComponent : public RandomComponent {
Component *to_update,
CuMatrixBase<BaseFloat> *in_deriv) const;
virtual Component* Copy() const { return new DropoutComponent(dim_,
dropout_proportion_); }
dropout_proportion_,
dropout_per_frame_); }
virtual std::string Info() const;

void SetDropoutProportion(BaseFloat dropout_proportion) { dropout_proportion_ = dropout_proportion; }
void SetDropoutProportion(BaseFloat dropout_proportion) {
dropout_proportion_ = dropout_proportion;
}

private:
int32 dim_;
/// dropout-proportion is the proportion that is dropped out,
/// e.g. if 0.1, we set 10% to zero value.
BaseFloat dropout_proportion_;

/// dropout-per-frame selects how the dropout mask is drawn in Propagate():
/// if false, a random value is drawn for every element of the output
/// matrix; if true, one random value is drawn per row and used for the
/// whole row, so each row is either kept or zeroed as a unit.
bool dropout_per_frame_;
};

class ElementwiseProductComponent: public Component {
Expand Down
12 changes: 7 additions & 5 deletions src/nnet3/nnet-utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -625,19 +625,21 @@ void ReadEditConfig(std::istream &edit_config_is, Nnet *nnet) {
KALDI_ERR << "In edits-config, expected proportion to be set in line: "
<< config_line.WholeLine();
}
DropoutComponent *component = NULL;
DropoutComponent *dropout_component = NULL;
int32 num_dropout_proportions_set = 0;
for (int32 c = 0; c < nnet->NumComponents(); c++) {
if (NameMatchesPattern(nnet->GetComponentName(c).c_str(),
name_pattern.c_str()) &&
(component =
(dropout_component =
dynamic_cast<DropoutComponent*>(nnet->GetComponent(c)))) {
component->SetDropoutProportion(proportion);
num_dropout_proportions_set++;
if (dropout_component != NULL) {
dropout_component->SetDropoutProportion(proportion);
num_dropout_proportions_set++;
}
}
}
KALDI_LOG << "Set dropout proportions for "
<< num_dropout_proportions_set << " nodes.";
<< num_dropout_proportions_set << " components.";
} else {
KALDI_ERR << "Directive '" << directive << "' is not currently "
"supported (reading edit-config).";
Expand Down
4 changes: 2 additions & 2 deletions src/nnet3bin/nnet3-copy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,10 @@ int main(int argc, char *argv[]) {

if (learning_rate >= 0)
SetLearningRate(learning_rate, &nnet);

if (scale != 1.0)
ScaleNnet(scale, &nnet);

if (!edits_config.empty()) {
Input ki(edits_config);
ReadEditConfig(ki.Stream(), &nnet);
Expand Down

0 comments on commit 9208165

Please sign in to comment.