Joint training with new l2reg technique #4

Open
wants to merge 37 commits into
base: leaky-hmm-merge-xent
37 commits
c32b235
Tabs to spaces
Jan 11, 2016
cc50db7
Add note about patch failing due to line endings
Jan 11, 2016
1aac4b2
Don't convert patch file line endings to LF
Jan 11, 2016
40136da
Removed note about git archive mangling patch line endings
Jan 11, 2016
50d8431
Don't mangle patch file line endings in all directories
Jan 11, 2016
57e7f78
Change line endings in Windows patch file to CRLF
Jan 11, 2016
155dbc0
fix to how CUDA block and grid sizes are computed for common operatio…
danpovey Jan 12, 2016
1017986
Merge pull request #444 from Timmmm/windows_line_endings
jtrmal Jan 13, 2016
ad004c4
Merge pull request #445 from Timmmm/windows_docs
jtrmal Jan 13, 2016
138a9c8
Clarify the generate_solution.pl command in the Windows INSTALL
Jan 13, 2016
98e45e2
Formatting clean up and convert INSTALL to markdown
Jan 13, 2016
1ee5616
Rename INSTALL to INSTALL.md so github renders it.
Jan 13, 2016
4456ab0
Github markdown requires more spaces
Timmmm Jan 13, 2016
1698b39
minor documentation change
danpovey Jan 14, 2016
83b94ae
bug fix to a rather old script: get_lda_block.sh (in fact this revert…
danpovey Jan 15, 2016
e0cbd32
Separated ARPA parsing from const LM construction
Jan 19, 2016
9abd21c
Clarify that MLK and OpenBLAS are alternatives.
Jan 19, 2016
6265183
Improved error messages for ARPA file parsing
Jan 22, 2016
1aab0b6
Changes per @danpovey's review in #458
Jan 22, 2016
61a551d
Merge pull request #458 from kkm000/arpa-1
danpovey Jan 22, 2016
549af84
Merge pull request #448 from Timmmm/windows_docs
jtrmal Jan 23, 2016
c3fedd1
chain branch: Adding results for 4v (regarding cross-entropy regulari…
danpovey Jan 24, 2016
ffcb552
Code simplification and cleanup that was enabled by the implementatio…
danpovey Jan 25, 2016
4d42ea2
chain branch: add sorting on num-transitions (for very tiny speedup).
danpovey Jan 25, 2016
ca2772b
chain branch: adding results to a script; a couple of new scripts.
danpovey Jan 25, 2016
3342530
chain branch: script edits with new results shown.
danpovey Jan 26, 2016
2f585de
Merge branch 'master' into chain
danpovey Jan 26, 2016
e46406b
change to arpa-file-parser.cc to suppress spurious compiler warning
danpovey Jan 26, 2016
3431b74
chain branch: various new tuning-scripts for chain model.
danpovey Jan 27, 2016
8676636
chain branch: more Switchboard tuning scripts, with results.
danpovey Jan 27, 2016
37261b5
chain models: add results from tuning scripts
danpovey Jan 28, 2016
eb6d9de
addded new l2-regularization method, which regress chain output to be…
pegahgh Jan 28, 2016
af380cd
Merge branch 'leaky-hmm-merge-xent' of https://github.com/danpovey/ka…
pegahgh Jan 28, 2016
188824f
some modification to new l2_regularization method
pegahgh Jan 29, 2016
8269e43
small fix to chain-training.cc
pegahgh Jan 29, 2016
c427693
Merge branch 'chain' of https://github.com/kaldi-asr/kaldi into joint…
pegahgh Jan 29, 2016
a4a0cfb
fixed scale equation
pegahgh Jan 29, 2016
2 changes: 2 additions & 0 deletions .gitattributes
@@ -15,4 +15,6 @@ windows/INSTALL* eol=native
windows/NewGuidCmd.exe.config text eol=crlf
windows/NewGuidCmd.exe binary

+# Prevent git changing CR-LF to LF when archiving (patch requires CR-LF on Windows).
+**/*.patch -text

2 changes: 2 additions & 0 deletions egs/swbd/s5c/local/chain/README.txt
@@ -6,5 +6,7 @@ ones to look at right now:
4f is a good jesus-layer system
4q is an improved TDNN with various bells and whistles from Vijay.
4r is a slightly-better jesus-layer system than 4f, with one more layer.
+5e is the best configuration run so far.
+


9 changes: 9 additions & 0 deletions egs/swbd/s5c/local/chain/run_tdnn_4v.sh
@@ -4,6 +4,15 @@
# from 1.0 to 2.0 because there is a lot of parameter change in the final xent
# layer, and this limits the rate of change of the other layers.

+#./compare_wer.sh 4r 4v
+#System 4r 4v
+#WER on train_dev(tg) 16.50 15.95
+#WER on train_dev(fg) 15.45 14.69
+#WER on eval2000(tg) 18.3 17.7
+#WER on eval2000(fg) 16.7 16.0
+#Final train prob -0.103652 -0.106646 -1.60775
+#Final valid prob -0.121105 -0.118631 -1.62832
+
# _4r is as _4f, but one more hidden layer, and reducing context of existing
# layers so we can re-use the egs. Reducing jesus-forward-output-dim slightly
# from 1500 to 1400.
12 changes: 11 additions & 1 deletion egs/swbd/s5c/local/chain/run_tdnn_4w.sh
@@ -1,6 +1,16 @@
#!/bin/bash

-# _4w is as _4v, but doubling --xent-regularize to 0.2
+# _4w is as _4v, but doubling --xent-regularize to 0.2 WER seems consistently a
+# bit worse, although final valid prob is very slightly better.
+
+#./compare_wer.sh 4v 4w
+#System 4v 4w
+#WER on train_dev(tg) 15.95 16.05
+#WER on train_dev(fg) 14.69 14.92
+#WER on eval2000(tg) 17.7 18.0
+#WER on eval2000(fg) 16.0 16.2
+#Final train prob -0.106646 -0.108816
+#Final valid prob -0.118631 -0.118254

# _4v is as _4r, but with --xent-regularize 0.1. Increasing max_param_change
# from 1.0 to 2.0 because there is a lot of parameter change in the final xent
12 changes: 11 additions & 1 deletion egs/swbd/s5c/local/chain/run_tdnn_4x.sh
@@ -1,7 +1,17 @@
#!/bin/bash

# _4x is as _4u, but with --leaky-hmm-coefficient 0.2. Note: the
-# ultimate baseline is 4f.
+# ultimate baseline is 4f. It seems a little bit worse than 4u on average: (+0.2, +0.2, 0.0, -0.1).
+# So I'm guessing the best value is around --leaky-hmm-coefficient 0.1.
+#
+# ./compare_wer.sh 4f 4u 4x
+# System 4f 4u 4x
+# WER on train_dev(tg) 16.83 16.47 16.63
+# WER on train_dev(fg) 15.73 15.23 15.42
+# WER on eval2000(tg) 18.4 18.4 18.4
+# WER on eval2000(fg) 16.6 16.7 16.6
+# Final train prob -0.105832 -0.118911 -0.130674
+# Final valid prob -0.123021 -0.135768 -0.146351

# _4u is as _4t, but with --leaky-hmm-coefficient 0.08. Note: the
# ultimate baseline is 4f.
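
Note on the "(+0.2, +0.2, 0.0, -0.1)" figures in the 4x comment: they appear to be the 4x minus 4u WER differences from the compare_wer.sh table, rounded to one decimal place. The sketch below reproduces that arithmetic. The array and function names are made up for this illustration and are not part of the PR.

```cpp
#include <array>
#include <cassert>
#include <cmath>
#include <cstddef>

// WER figures copied from the compare_wer.sh table, in the order
// train_dev(tg), train_dev(fg), eval2000(tg), eval2000(fg).
const std::array<double, 4> kWer4u = {16.47, 15.23, 18.4, 16.7};
const std::array<double, 4> kWer4x = {16.63, 15.42, 18.4, 16.6};

// Returns (b - a) for each test set, rounded to one decimal place.
// A positive value means system b has the higher (worse) WER.
std::array<double, 4> RoundedWerDelta(const std::array<double, 4> &a,
                                      const std::array<double, 4> &b) {
  std::array<double, 4> delta{};
  for (std::size_t i = 0; i < a.size(); ++i)
    delta[i] = std::round((b[i] - a[i]) * 10.0) / 10.0;
  return delta;
}
```

Calling `RoundedWerDelta(kWer4u, kWer4x)` gives the four per-test-set differences in the same order as the table rows.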
401 changes: 401 additions & 0 deletions egs/swbd/s5c/local/chain/run_tdnn_5a.sh

Large diffs are not rendered by default.

404 changes: 404 additions & 0 deletions egs/swbd/s5c/local/chain/run_tdnn_5b.sh

Large diffs are not rendered by default.

409 changes: 409 additions & 0 deletions egs/swbd/s5c/local/chain/run_tdnn_5c.sh

Large diffs are not rendered by default.

407 changes: 407 additions & 0 deletions egs/swbd/s5c/local/chain/run_tdnn_5d.sh

Large diffs are not rendered by default.

417 changes: 417 additions & 0 deletions egs/swbd/s5c/local/chain/run_tdnn_5e.sh

Large diffs are not rendered by default.

423 changes: 423 additions & 0 deletions egs/swbd/s5c/local/chain/run_tdnn_5f.sh

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion egs/wsj/s5/steps/nnet2/get_lda_block.sh
@@ -104,7 +104,7 @@ while [ $[$cur_index+$block_size] -le $feat_dim ]; do
echo >> $dir/indexes
num_blocks=$[$num_blocks+1]
cur_index=$[$cur_index+$block_shift]
-if [ $[$cur_index+$block_size-1] -gt $feat_dim ]; then
+if [ $[$cur_index+$block_size] -gt $feat_dim ]; then
cur_index=$[$feat_dim-$block_size];
fi
done
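
Note on the changed test: the enclosing loop (visible in the hunk header) runs while `cur_index + block_size <= feat_dim`. The revised `if` condition is the exact negation of that, so the clamp to `feat_dim - block_size` fires whenever the shifted block no longer fits. The earlier condition did not fire when the block overshot by exactly one. The sketch below only compares the two predicates; the function names are hypothetical and the surrounding loop is not reproduced.

```cpp
#include <cassert>

// Condition after the change: true when a block of block_size features
// starting at cur_index no longer satisfies the loop's
// "cur_index + block_size <= feat_dim" requirement.
bool BlockOverruns(int cur_index, int block_size, int feat_dim) {
  return cur_index + block_size > feat_dim;
}

// Condition before the change, kept only for comparison. It is false when
// cur_index + block_size == feat_dim + 1, i.e. an overshoot of exactly one.
bool BlockOverrunsOld(int cur_index, int block_size, int feat_dim) {
  return cur_index + block_size - 1 > feat_dim;
}
```

With `feat_dim = 40` and `block_size = 10`, the two predicates disagree only at `cur_index = 31`.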
6 changes: 4 additions & 2 deletions egs/wsj/s5/steps/nnet3/chain/train_tdnn.sh
@@ -101,7 +101,7 @@ right_deriv_truncate= # number of time-steps to avoid using the deriv of, on th

# End configuration section.

-trap 'for pid in $(jobs -pr); do kill -KILL $pid; done' INT QUIT TERM
+trap 'for pid in $(jobs -pr); do kill -TERM $pid; done' INT QUIT TERM

echo "$0 $@" # Print the command line for logging

@@ -497,7 +497,9 @@ while [ $x -lt $num_iters ]; do
rm $dir/.error 2>/dev/null


-( # this sub-shell is so that when we "wait" below,
+(
+trap 'for pid in $(jobs -pr); do kill -TERM $pid; done' INT QUIT TERM
+# this sub-shell is so that when we "wait" below,
# we only wait for the training jobs that we just spawned,
# not the diagnostic jobs that we spawned above.
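
Note on the switch from `kill -KILL` to `kill -TERM`: the diff does not state the motivation. One relevant difference, which may or may not be the reason here, is that a process can install a handler for SIGTERM and clean up or forward the signal, whereas SIGKILL can never be caught. The sketch below demonstrates that difference on a POSIX system. All names in it are hypothetical and it is not taken from Kaldi.

```cpp
#include <cassert>
#include <csignal>

// Flag written by the handler; sig_atomic_t is safe to set from a handler.
static volatile std::sig_atomic_t g_got_term = 0;

extern "C" void OnTerm(int) { g_got_term = 1; }

// Installs a SIGTERM handler, sends SIGTERM to this process, and reports
// whether the handler ran. The default disposition is restored afterwards.
bool TermIsCatchable() {
  g_got_term = 0;
  if (std::signal(SIGTERM, OnTerm) == SIG_ERR) return false;
  std::raise(SIGTERM);  // returns only after the handler has run
  std::signal(SIGTERM, SIG_DFL);
  return g_got_term == 1;
}

// Tries to install a handler for SIGKILL. POSIX does not allow this, so
// signal() is expected to fail and return SIG_ERR.
bool KillIsCatchable() {
  return std::signal(SIGKILL, OnTerm) != SIG_ERR;
}
```

On a POSIX system `TermIsCatchable()` should report true and `KillIsCatchable()` false, which is why a job killed with `-TERM` has a chance to run its own `trap` while one killed with `-KILL` does not.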

14 changes: 9 additions & 5 deletions src/base/kaldi-math.h
@@ -41,20 +41,19 @@
#endif

#ifndef M_PI
-# define M_PI 3.1415926535897932384626433832795
+#define M_PI 3.1415926535897932384626433832795
#endif

#ifndef M_SQRT2
-# define M_SQRT2 1.4142135623730950488016887
+#define M_SQRT2 1.4142135623730950488016887
#endif

-
#ifndef M_2PI
-# define M_2PI 6.283185307179586476925286766559005
+#define M_2PI 6.283185307179586476925286766559005
#endif

#ifndef M_SQRT1_2
-# define M_SQRT1_2 0.7071067811865475244008443621048490
+#define M_SQRT1_2 0.7071067811865475244008443621048490
#endif

#ifndef M_LOG_2PI
@@ -65,6 +64,11 @@
#define M_LN2 0.693147180559945309417232121458
#endif

+#ifndef M_LN10
+#define M_LN10 2.302585092994045684017991454684
+#endif
+
+
#define KALDI_ISNAN std::isnan
#define KALDI_ISINF std::isinf
#define KALDI_ISFINITE(x) std::isfinite(x)
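
Note on `M_LN10`: this hunk does not show where the new constant is used. A typical use of ln(10) is converting between base-10 and natural logarithms, as in the minimal sketch below. The helper names are hypothetical and do not come from Kaldi; the fallback define mirrors the one added in the diff.

```cpp
#include <cassert>
#include <cmath>

// Same fallback as in the diff, for platforms whose <cmath> lacks M_LN10.
#ifndef M_LN10
#define M_LN10 2.302585092994045684017991454684
#endif

// Converts a natural-log value to base 10: log10(x) = ln(x) / ln(10).
inline double LnToLog10(double ln_value) { return ln_value / M_LN10; }

// Converts a base-10 log value to a natural log: ln(x) = log10(x) * ln(10).
inline double Log10ToLn(double log10_value) { return log10_value * M_LN10; }
```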
2 changes: 1 addition & 1 deletion src/chain/chain-datastruct.h
@@ -46,7 +46,7 @@ extern "C" {


// Search for this in chain-kernels.cu for an explanation.
-enum { kOccupationRescalingPowerOfTwo = 20, kThresholdingPowerOfTwo = 14 };
+enum { kThresholdingPowerOfTwo = 14 };

}

111 changes: 30 additions & 81 deletions src/chain/chain-den-graph.cc
@@ -139,87 +139,6 @@ void DenominatorGraph::SetInitialProbs(const fst::StdVectorFst &fst) {

Vector<BaseFloat> avg_prob_float(avg_prob);
initial_probs_ = avg_prob_float;
-special_hmm_state_ = ComputeSpecialState(fst, avg_prob_float);
-}
-
-int32 NumStatesThatCanReach(const fst::StdVectorFst &fst,
-int32 dest_state) {
-int32 num_states = fst.NumStates(),
-num_states_can_reach = 0;
-KALDI_ASSERT(dest_state >= 0 && dest_state < num_states);
-std::vector<bool> can_reach(num_states, false);
-std::vector<std::vector<int32> > reverse_transitions(num_states);
-for (int32 s = 0; s < num_states; s++)
-for (fst::ArcIterator<fst::StdVectorFst> aiter(fst, s); !aiter.Done();
-aiter.Next())
-reverse_transitions[aiter.Value().nextstate].push_back(s);
-std::vector<int32> queue;
-can_reach[dest_state] = true;
-queue.push_back(dest_state);
-num_states_can_reach++;
-while (!queue.empty()) {
-int32 state = queue.back();
-queue.pop_back();
-std::vector<int32>::const_iterator iter = reverse_transitions[state].begin(),
-end = reverse_transitions[state].end();
-for (; iter != end; ++iter) {
-int32 prev_state = *iter;
-if (!can_reach[prev_state]) {
-can_reach[prev_state] = true;
-queue.push_back(prev_state);
-num_states_can_reach++;
-}
-}
-}
-KALDI_ASSERT(num_states_can_reach >= 1 &&
-num_states_can_reach <= num_states);
-return num_states_can_reach;
-}
-
-
-int32 DenominatorGraph::ComputeSpecialState(
-const fst::StdVectorFst &fst,
-const Vector<BaseFloat> &initial_probs) {
-int32 num_states = initial_probs.Dim();
-std::vector<int32> num_transitions_into(num_states, 0);
-for (int32 s = 0; s < fst.NumStates(); s++) {
-for (fst::ArcIterator<fst::StdVectorFst> aiter(fst, s); !aiter.Done();
-aiter.Next())
-num_transitions_into[aiter.Value().nextstate]++;
-}
-// this vector 'pairs' is a vector of pairs (-num-transitions-into-state, state).
-std::vector<std::pair<int32, int32> > pairs(num_states);
-for (int32 i = 0; i < num_states; i++) {
-pairs[i].first = -num_transitions_into[i];
-pairs[i].second = i;
-}
-// the first element of each pair is the negative of the num-transitions, so
-// when we sort, the highest num-transitions will be first.
-std::sort(pairs.begin(), pairs.end());
-
-// this threshold of 0.75 is pretty arbitrary. We reject any
-// state if it can't be reached by 75% of all other states.
-// In practice we think that states will either be reachable by
-// almost-all states, or almost-none (e.g. states that are active
-// only at utterance-beginning), so this threshold shouldn't
-// be too critical.
-int32 min_states_can_reach = 0.75 * num_states;
-for (int32 i = 0; i < num_states; i++) {
-int32 state = pairs[i].second;
-int32 n = NumStatesThatCanReach(fst, state);
-if (n < min_states_can_reach) {
-KALDI_WARN << "Rejecting state " << state << " as a 'special' HMM state "
-<< "(for renormalization in fwd-bkwd), because it's only "
-<< "reachable by " << n << " out of " << num_states
-<< " states.";
-} else {
-return state;
-}
-}
-KALDI_ERR << "Found no states that are reachable by at least "
-<< min_states_can_reach << " out of " << num_states
-<< " states. This is unexpected. Change the threshold";
-return -1;
}

void DenominatorGraph::GetNormalizationFst(const fst::StdVectorFst &ifst,
@@ -271,6 +190,34 @@ void MinimizeAcceptorNoPush(fst::StdVectorFst *fst) {
fst::Decode(fst, encoder);
}

+// This static function, used in CreateDenominatorFst, sorts an
+// fst's states in decreasing order of number of transitions (into + out of)
+// the state. The aim is to have states that have a lot of transitions
+// either into them or out of them, be numbered earlier, so hopefully
+// they will be scheduled first and won't delay the computation
+static void SortOnTransitionCount(fst::StdVectorFst *fst) {
+// negative_num_transitions[i] will contain (before sorting), the pair
+// ( -(num-transitions-into(i) + num-transition-out-of(i)), i)
+int32 num_states = fst->NumStates();
+std::vector<std::pair<int32, int32> > negative_num_transitions(num_states);
+for (int32 i = 0; i < num_states; i++) {
+negative_num_transitions[i].first = 0;
+negative_num_transitions[i].second = i;
+}
+for (int32 i = 0; i < num_states; i++) {
+for (fst::ArcIterator<fst::StdVectorFst> aiter(*fst, i); !aiter.Done();
+aiter.Next()) {
+negative_num_transitions[i].first--;
+negative_num_transitions[aiter.Value().nextstate].first--;
+}
+}
+std::sort(negative_num_transitions.begin(), negative_num_transitions.end());
+std::vector<fst::StdArc::StateId> order(num_states);
+for (int32 i = 0; i < num_states; i++)
+order[negative_num_transitions[i].second] = i;
+fst::StateSort(fst, order);
+}
+
void DenGraphMinimizeWrapper(fst::StdVectorFst *fst) {
for (int32 i = 1; i <= 3; i++) {
fst::PushSpecial(fst, fst::kDelta * 0.01);
@@ -424,6 +371,8 @@ void CreateDenominatorFst(const ContextDependency &ctx_dep,

DenGraphMinimizeWrapper(&transition_id_fst);

+SortOnTransitionCount(&transition_id_fst);
+
*den_fst = transition_id_fst;
CheckDenominatorFst(trans_model.NumPdfs(), *den_fst);
PrintDenGraphStats(*den_fst);
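
Note on `SortOnTransitionCount`: the ordering it computes can be tried without OpenFst. The sketch below is a simplified stand-in that represents the graph as plain adjacency lists (an assumption made for this illustration) and returns the `order` vector that the PR's function passes to `fst::StateSort`. The name `TransitionCountOrder` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <utility>
#include <vector>

// arcs[s] lists the destination states of the arcs leaving state s.
// Returns `order`, where order[s] is the new index assigned to old state s,
// so that states with the most (incoming + outgoing) arcs come first.
std::vector<int> TransitionCountOrder(
    const std::vector<std::vector<int> > &arcs) {
  const int num_states = static_cast<int>(arcs.size());
  // Pairs of (negated arc count, state). Negating the count means an
  // ascending sort puts the busiest states first; ties keep the lower
  // state index first.
  std::vector<std::pair<int, int> > neg_count(num_states);
  for (int s = 0; s < num_states; s++) neg_count[s] = std::make_pair(0, s);
  for (int s = 0; s < num_states; s++) {
    for (int dest : arcs[s]) {
      neg_count[s].first--;     // arc out of s
      neg_count[dest].first--;  // arc into dest (a self-loop counts twice)
    }
  }
  std::sort(neg_count.begin(), neg_count.end());
  std::vector<int> order(num_states);
  for (int i = 0; i < num_states; i++) order[neg_count[i].second] = i;
  return order;
}
```

For a three-state graph with arcs 0→1, 1→1, 1→2 and 2→1, state 1 touches the most arcs and is renumbered to 0, state 2 becomes 1, and state 0 becomes 2.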
43 changes: 7 additions & 36 deletions src/chain/chain-den-graph.h
@@ -88,13 +88,6 @@ class DenominatorGraph {
// Note: we renormalize each HMM-state to sum to one before doing this.
const CuVector<BaseFloat> &InitialProbs() const;

-// returns the index of the HMM-state that has the highest value in
-// InitialProbs (and which we believe will always be reachable from most other
-// states... later on we may check this more carefully [TODO]).
-// It's used in getting the 'arbitrary_scale' value to keep the alphas
-// in a good dynamic range.
-int32 SpecialHmmState() const { return special_hmm_state_; }
-
// This function outputs a modifified version of the FST that was used to
// build this object, that has an initial-state with epsilon transitions to
// each state, with weight determined by initial_probs_; and has each original
@@ -116,23 +109,15 @@ class DenominatorGraph {
// functions called from the constructor
void SetTransitions(const fst::StdVectorFst &fst, int32 num_pfds);

-// work out the initial-probs and the 'special state'
-// Note, there are no final-probs; we treat all states as final
-// with probability one [we have a justification for this..
-// assuming it's roughly a well-normalized HMM, this makes sense;
-// note that we train on chunks, so the beginning and end of a chunk
-// appear at arbitrary points in the sequence.
-// At both beginning and end of the chunk, we limit ourselves to
-// only those pdf-ids that were allowed in the numerator sequence.
+// work out the initial-probs. Note, there are no final-probs; we treat all
+// states as final with probability one [we have a justification for this..
+// assuming it's roughly a well-normalized HMM, this makes sense; note that we
+// train on chunks, so the beginning and end of a chunk appear at arbitrary
+// points in the sequence. At both beginning and end of the chunk, we limit
+// ourselves to only those pdf-ids that were allowed in the numerator
+// sequence.
void SetInitialProbs(const fst::StdVectorFst &fst);

-// return a suitable 'special' HMM-state used for normalizing probabilities in
-// the forward-backward. It has to have a reasonably high probability and be
-// reachable from most of the graph. returns a suitable state-index
-// that we can set special_hmm_state_ to.
-int32 ComputeSpecialState(const fst::StdVectorFst &fst,
-const Vector<BaseFloat> &initial_probs);
-
// forward_transitions_ is an array, indexed by hmm-state index,
// of start and end indexes into the transition_ array, which
// give us the set of transitions out of this state.
@@ -152,23 +137,9 @@ class DenominatorGraph {
// distribution of the HMM. This isn't too critical.
CuVector<BaseFloat> initial_probs_;

-// The index of a somewhat arbitrarily chosen HMM-state that we
-// use for adjusting the alpha probabilities. It needs to be
-// one that is reachable from all states (i.e. not a special
-// state that's only reachable at sentence-start). We choose
-// whichever one has the greatest initial-prob. It's set
-// in SetInitialProbs().
-int32 special_hmm_state_;
-
int32 num_pdfs_;
};

-// returns the number of states from which there is a path to
-// 'dest_state'. Utility function used in selecting 'special' state
-// for normalization of probabilities.
-int32 NumStatesThatCanReach(const fst::StdVectorFst &fst,
-int32 dest_state);
-

// Function that does acceptor minimization without weight pushing...
// this is useful when constructing the denominator graph.