Skip to content

Commit

Permalink
simple coding style fix (#1)
Browse files Browse the repository at this point in the history
  • Loading branch information
LvHang authored and keli78 committed Sep 7, 2018
1 parent 17d86e6 commit b2bba43
Show file tree
Hide file tree
Showing 7 changed files with 177 additions and 132 deletions.
70 changes: 40 additions & 30 deletions src/latbin/lattice-lmrescore-kaldi-rnnlm-adaptation.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ class NnetComputerFromEg {
nnet_(nnet), compiler_(nnet) { }

// Compute the output (which will have the same number of rows as the number
// of Indexes in the output with the name 'output_name' of the eg),
// of Indexes in the output with the name 'output_name' of the eg),
// and put it in "*output".
// An output with the name 'output_name' is expected to exist in the network.
void Compute(const NnetExample &eg, const std::string &output_name,
void Compute(const NnetExample &eg, const std::string &output_name,
Matrix<BaseFloat> *output) {
ComputationRequest request;
bool need_backprop = false, store_stats = false;
Expand All @@ -61,7 +61,8 @@ class NnetComputerFromEg {
NnetComputer computer(options, computation, nnet_, NULL);
computer.AcceptInputs(nnet_, eg.io);
computer.Run();
const CuMatrixBase<BaseFloat> &nnet_output = computer.GetOutput(output_name);
const CuMatrixBase<BaseFloat> &nnet_output =
computer.GetOutput(output_name);
output->Resize(nnet_output.NumRows(), nnet_output.NumCols());
nnet_output.CopyToMat(output);
}
Expand All @@ -70,7 +71,7 @@ class NnetComputerFromEg {
CachingOptimizingCompiler compiler_;
};

} // namespace nnet3
} // namespace nnet3

// This class computes and outputs
// the information about arc posteriors.
Expand All @@ -82,7 +83,7 @@ class ArcPosteriorComputer {
BaseFloat min_post,
const TransitionModel *trans_model = NULL):
clat_(clat), min_post_(min_post) {}


// returns the number of arc posteriors that it output.
void OutputPosteriors(
Expand Down Expand Up @@ -133,7 +134,7 @@ class ArcPosteriorComputer {

BaseFloat min_post_;
};
} // namespace kaldi
} // namespace kaldi

void ReadUttToConvo(string filename, map<string, string> &m) {
KALDI_ASSERT(m.size() == 0);
Expand Down Expand Up @@ -186,7 +187,8 @@ int main(int argc, char *argv[]) {
kaldi::BaseFloat min_post = 0.0001;
int32 max_ngram_order = 3;
int32 weight_range = 10;
kaldi::BaseFloat acoustic_scale = 0.1, lm_scale = 1.0, weight = 5.0, correction_weight = 0.1;
kaldi::BaseFloat acoustic_scale = 0.1, lm_scale = 1.0,
weight = 5.0, correction_weight = 0.1;

po.Register("acoustic-scale", &acoustic_scale,
"Scaling factor for acoustic likelihoods");
Expand Down Expand Up @@ -241,7 +243,9 @@ int main(int argc, char *argv[]) {

NnetComputerFromEg nnet_computer(dnn);

std::vector<double> original_unigram(word_embedding_mat.NumRows(), 0.0); // number of words
// number of words
std::vector<double> original_unigram(word_embedding_mat.NumRows(), 0.0);

ReadUnigram(unigram_file, &original_unigram);

const rnnlm::RnnlmComputeStateInfo info(opts, rnnlm, word_embedding_mat);
Expand All @@ -253,7 +257,7 @@ int main(int argc, char *argv[]) {
std::map<string, map<int, double> > per_utt_counts;
std::map<string, double> per_convo_sums;
std::map<string, double> per_utt_sums;

std::vector<string> utt_ids;

{
Expand All @@ -263,13 +267,14 @@ int main(int argc, char *argv[]) {
std::string utt_id = clat_reader.Key();
utt_ids.push_back(utt_id);
kaldi::CompactLattice &clat = clat_reader.Value();

fst::ScaleLattice(fst::LatticeScale(lm_scale, acoustic_scale), &clat);
kaldi::TopSortCompactLatticeIfNeeded(&clat);

string convo_id = utt2convo[utt_id];
// Use convs of both speakers
std::string convo_id_twoSides = std::string(convo_id.begin(), convo_id.end() - 2);
std::string convo_id_twoSides = std::string(convo_id.begin(),
convo_id.end() - 2);
convo_id = convo_id_twoSides;
kaldi::ArcPosteriorComputer computer(clat, min_post);

Expand All @@ -282,16 +287,16 @@ int main(int argc, char *argv[]) {

clat_reader.Close();
}

// Collect stats of nearby sentences by looping over the per_utt_counts
int32 range = weight_range;
std::map<string, map<int, double> > per_utt_nearby_stats;
std::map<string, double> per_utt_nearby_sums;
std::vector<string>::iterator it = utt_ids.begin();
for(int32 i = 0; i < utt_ids.size(); i++) {
for (int32 i = 0; i < utt_ids.size(); i++) {
// current utterance id
std::string utt_id = *(it + i);
for(int32 j = 0; j <= range; j++) {
for (int32 j = 0; j <= range; j++) {
// get the correct idx of the nearby sentence
int32 idx = j - range / 2;
if (idx == 0)
Expand All @@ -312,47 +317,50 @@ int main(int argc, char *argv[]) {
std::cout << "Nearby utt id " << utt_nearby_id << std::endl;
*/
std::map<int, double> per_utt_stats = per_utt_counts[utt_nearby_id];
for(std::map<int, double>::iterator utt_it = per_utt_stats.begin();
for (std::map<int, double>::iterator utt_it = per_utt_stats.begin();
utt_it != per_utt_stats.end(); ++utt_it) {
int32 word = utt_it->first;
BaseFloat soft_count = utt_it->second;
per_utt_nearby_stats[utt_id][word] += soft_count * weight;
per_utt_nearby_sums[utt_id] += soft_count * weight;
}
}
}
}

SequentialCompactLatticeReader compact_lattice_reader(lats_rspecifier);
CompactLatticeWriter compact_lattice_writer(lats_wspecifier);

std::string output_name = "output";

for (; !compact_lattice_reader.Done(); compact_lattice_reader.Next()) {

for (; !compact_lattice_reader.Done(); compact_lattice_reader.Next()) {
std::string key = compact_lattice_reader.Key();
// std::cout << "key is " << key << std::endl;
std::string convo_id = utt2convo[key];
// Use convs of both speakers
std::string convo_id_twoSides = std::string(convo_id.begin(), convo_id.end()-2);
std::string convo_id_twoSides = std::string(convo_id.begin(),
convo_id.end() - 2);
// std::cout << "convo_id is " << convo_id << std::endl;
// std::cout << "convo_id_twoSides is " << convo_id_twoSides << std::endl;
convo_id = convo_id_twoSides;
KALDI_ASSERT(convo_id != "");

map<int, double> unigram = per_convo_counts[convo_id];
for (map<int, double>::iterator iter = per_utt_counts[key].begin();
iter != per_utt_counts[key].end(); iter++) {
iter != per_utt_counts[key].end();
iter++) {
unigram[iter->first] =
(unigram[iter->first] - iter->second); // / per_utt_sums[key];
(unigram[iter->first] - iter->second); // per_utt_sums[key];
// debug_sum += unigram[iter->first];
}

// adjust weights of nearby sentences
// std::cout << "weighted count "<< per_utt_nearby_stats[key][iter->first] << std::endl;
// std::cout << "weighted count "<< per_utt_nearby_stats[key][iter->first]
// << std::endl;
// unigram[iter->first] += per_utt_nearby_stats[key][iter->first];
double weighted_sum = 0.0;
for (map<int, double>::iterator iter = per_utt_nearby_stats[key].begin();
iter != per_utt_nearby_stats[key].end(); ++iter) {
iter != per_utt_nearby_stats[key].end();
++iter) {
int32 word = iter->first;
unigram[word] += iter->second;
weighted_sum += iter->second;
Expand All @@ -374,7 +382,8 @@ int main(int argc, char *argv[]) {
iter != unigram.end(); iter++) {
input.push_back(std::make_pair(iter->first, iter->second));
}
// KALDI_LOG << "Input info: " << input.size() << " , " << input[0].first << " , " << input[0].second;
// KALDI_LOG << "Input info: " << input.size() << " , "
// << input[0].first << " , " << input[0].second;
std::vector<std::vector<std::pair<int32, BaseFloat> > > feat;
feat.push_back(input);
const SparseMatrix<BaseFloat> feat_sp(word_embedding_mat.NumRows(), feat);
Expand All @@ -383,7 +392,7 @@ int main(int argc, char *argv[]) {
eg.io.push_back(NnetIo("input", 0, eg_input));
// const Posterior post;
eg.io.push_back(NnetIo("output", word_embedding_mat.NumRows(), 0, feat));
// Second compute output given 1) eg input and 2) trained nnet
// Second compute output given 1) eg input and 2) trained nnet
Matrix<BaseFloat> output;
nnet_computer.Compute(eg, output_name, &output);
KALDI_ASSERT(output.NumRows() != 0);
Expand All @@ -400,11 +409,12 @@ int main(int argc, char *argv[]) {
}
}
KALDI_ASSERT(ApproxEqual(uni_sum, 1.0));

rnnlm::KaldiRnnlmDeterministicFst
rnnlm_fst(max_ngram_order, info, correction_weight, unigram_dnn, original_unigram);
rnnlm_fst(max_ngram_order, info, correction_weight, unigram_dnn,
original_unigram);
CompactLattice &clat = compact_lattice_reader.Value();

if (lm_scale != 0.0) {
// Before composing with the LM FST, we scale the lattice weights
// by the inverse of "lm_scale". We'll later scale by "lm_scale".
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ class ArcPosteriorComputer {
BaseFloat min_post,
const TransitionModel *trans_model = NULL):
clat_(clat), min_post_(min_post) {}


// returns the number of arc posteriors that it output.
void OutputPosteriors(
Expand Down Expand Up @@ -97,7 +97,7 @@ class ArcPosteriorComputer {

BaseFloat min_post_;
};
} // namespace kaldi
} // namespace kaldi

void ReadUttToConvo(string filename, map<string, string> &m) {
KALDI_ASSERT(m.size() == 0);
Expand Down Expand Up @@ -160,27 +160,28 @@ int main(int argc, char *argv[]) {
bool use_carpa = false;
bool two_speaker_mode = false, one_best_mode = false;

po.Register("lm-scale", &lm_scale, "Scaling factor for <lm-to-add>; its negative "
"will be applied to <lm-to-subtract>.");
po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for acoustic "
"probabilities (e.g. 0.1 for non-chain systems); important because "
"of its effect on pruning.");
po.Register("lm-scale", &lm_scale, "Scaling factor for <lm-to-add>; its "
"negative will be applied to <lm-to-subtract>.");
po.Register("acoustic-scale", &acoustic_scale, "Scaling factor for "
"acoustic probabilities (e.g. 0.1 for non-chain systems); "
"important because of its effect on pruning.");
po.Register("use-const-arpa", &use_carpa, "If true, read the old-LM file "
"as a const-arpa file as opposed to an FST file");
po.Register("correction_weight", &correction_weight, "The weight on the "
"correction term of the RNNLM scores.");
po.Register("max-ngram-order", &max_ngram_order,
"If positive, allow RNNLM histories longer than this to be identified "
"with each other for rescoring purposes (an approximation that "
"saves time and reduces output lattice size).");
"If positive, allow RNNLM histories longer than this to be "
"identified with each other for rescoring purposes (an "
"approximation that saves time and reduces output lattice "
"size).");
po.Register("min-post", &min_post,
"Arc posteriors below this value will be pruned away");
po.Register("two_speaker_mode", &two_speaker_mode, "If true, use two "
"speaker's utterances to estimate cache models or "
"speaker's utterances to estimate cache models or "
"as the input of DNN models.");
po.Register("one_best_mode", &one_best_mode, "If true, use 1 best decoding "
"results instead of lattice posteriors to estimate cache models "
"or as the input of DNN models.");
"results instead of lattice posteriors to estimate cache "
"models or as the input of DNN models.");

opts.Register(&po);
compose_opts.Register(&po);
Expand All @@ -206,7 +207,8 @@ int main(int argc, char *argv[]) {

// for G.fst
fst::ScaleDeterministicOnDemandFst *lm_to_subtract_det_scale = NULL;
fst::BackoffDeterministicOnDemandFst<StdArc> *lm_to_subtract_det_backoff = NULL;
fst::BackoffDeterministicOnDemandFst<StdArc>
*lm_to_subtract_det_backoff = NULL;
VectorFst<StdArc> *lm_to_subtract_fst = NULL;

// for G.carpa
Expand Down Expand Up @@ -241,9 +243,10 @@ int main(int argc, char *argv[]) {
CuMatrix<BaseFloat> word_embedding_mat;
ReadKaldiObject(word_embedding_rxfilename, &word_embedding_mat);

std::vector<double> original_unigram(word_embedding_mat.NumRows(), 0.0); // number of words
// number of words
std::vector<double> original_unigram(word_embedding_mat.NumRows(), 0.0);
ReadUnigram(unigram_file, &original_unigram);

const rnnlm::RnnlmComputeStateInfo info(opts, rnnlm, word_embedding_mat);

// Reads and writes as compact lattice.
Expand All @@ -256,7 +259,7 @@ int main(int argc, char *argv[]) {
std::map<string, map<int, double> > per_utt_counts;
std::map<string, double> per_convo_sums;
std::map<string, double> per_utt_sums;

std::vector<string> utt_ids;
{
SequentialCompactLatticeReader clat_reader(lats_rspecifier);
Expand All @@ -271,11 +274,12 @@ int main(int argc, char *argv[]) {

string convo_id = utt2convo[utt_id];
if (two_speaker_mode) {
std::string convo_id_2spk = std::string(convo_id.begin(), convo_id.end() - 2);
std::string convo_id_2spk = std::string(convo_id.begin(),
convo_id.end() - 2);
convo_id = convo_id_2spk;
}

// Estimate cache models from 1-best hypotheses instead of
// Estimate cache models from 1-best hypotheses instead of
// word-posteriors from first-pass decoded lattices
if (one_best_mode) {
kaldi::CompactLattice best_path;
Expand All @@ -290,11 +294,10 @@ int main(int argc, char *argv[]) {
&(per_utt_counts[utt_id]),
&(per_convo_sums[convo_id]),
&(per_utt_sums[utt_id]));

}
clat_reader.Close();
}

std::map<string, map<int, double> > per_utt_hists;
// std::map<string, double> per_utt_hists_sums;
std::vector<string>::iterator it = utt_ids.begin();
Expand All @@ -307,8 +310,9 @@ int main(int argc, char *argv[]) {
// copy the previous speaker's utts to the current one
per_utt_hists[utt_id] = per_utt_hists[utt_id_prev];
// add the current utt to the counts for the current speaker
for (std::map<int, double>::iterator cur_utt = per_utt_counts[utt_id].begin();
cur_utt != per_utt_counts[utt_id].end(); ++cur_utt) {
for (std::map<int, double>::iterator cur_utt =
per_utt_counts[utt_id].begin();
cur_utt != per_utt_counts[utt_id].end(); ++cur_utt) {
int32 word = cur_utt->first;
BaseFloat count = cur_utt->second;
per_utt_hists[utt_id][word] += count;
Expand All @@ -320,14 +324,16 @@ int main(int argc, char *argv[]) {
std::string key = compact_lattice_reader.Key();
std::string convo_id = utt2convo[key];
if (two_speaker_mode) {
std::string convo_id_2spks = std::string(convo_id.begin(), convo_id.end() - 2);
std::string convo_id_2spks = std::string(convo_id.begin(),
convo_id.end() - 2);
convo_id = convo_id_2spks;
}
KALDI_ASSERT(convo_id != "");

map<int, double> unigram = per_utt_hists[key];
for (map<int, double>::iterator iter = per_utt_counts[key].begin();
iter != per_utt_counts[key].end(); ++iter) {
iter != per_utt_counts[key].end();
++iter) {
unigram[iter->first] = (unigram[iter->first] - iter->second);
}
double sum = 0.0;
Expand All @@ -346,8 +352,10 @@ int main(int argc, char *argv[]) {
KALDI_ASSERT(ApproxEqual(debug_sum, 1.0));

// Rescoring and pruning happens below.
rnnlm::KaldiRnnlmDeterministicFst* lm_to_add_orig =
new rnnlm::KaldiRnnlmDeterministicFst(max_ngram_order, info, correction_weight, unigram, original_unigram);
rnnlm::KaldiRnnlmDeterministicFst* lm_to_add_orig =
new rnnlm::KaldiRnnlmDeterministicFst(max_ngram_order, info,
correction_weight,
unigram, original_unigram);
fst::DeterministicOnDemandFst<StdArc> *lm_to_add =
new fst::ScaleDeterministicOnDemandFst(lm_scale, lm_to_add_orig);

Expand Down

0 comments on commit b2bba43

Please sign in to comment.