Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[not-for-merge] [src] Adding binary lattice-to-ngram-counts #2778

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
132 changes: 132 additions & 0 deletions src/lat/lattice-functions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
// 2013 Cisco Systems (author: Neha Agrawal) [code modified
// from original code in ../gmmbin/gmm-rescore-lattice.cc]
// 2014 Guoguo Chen
// 2018 David Snyder

// See ../../COPYING for clarification regarding multiple authors
//
Expand Down Expand Up @@ -1755,4 +1756,135 @@ void ReplaceAcousticScoresFromMap(
}
}

void ComputeSoftNgramCounts(const CompactLattice &lat, int32 n,
CompactLattice::Arc::Label eos_symbol,
std::vector<std::pair<std::vector<CompactLattice::Arc::Label>,
double> > *soft_counts ) {
kaldi::uint64 props = lat.Properties(fst::kFstProperties, false);
if (!(props & fst::kTopSorted))
KALDI_ERR << "Input lattice must be topologically sorted.";
typedef CompactLattice::Arc Arc;
typedef CompactLattice::StateId StateId;
typedef CompactLattice::Weight Weight;
unordered_map<StateId, std::vector<Arc> > ngram_history;
std::vector<StateId> discovered(lat.NumStates(), 0);

StateId start_state = lat.Start(),
super_final_offset = lat.NumStates();
std::vector<Arc> start_history;
std::pair<const StateId, std::vector<Arc> > new_pair(start_state,
start_history);
ngram_history.insert(new_pair);

std::vector<double> alpha,
beta;

double tot_like = kaldi::ComputeLatticeAlphasAndBetas(lat, false,
&alpha, &beta);

for (StateId state = 0; state < lat.NumStates(); state++) {
for (fst::ArcIterator<CompactLattice > aiter(lat, state);
!aiter.Done(); aiter.Next()) {
const Arc &arc = aiter.Value();
StateId next_state = arc.nextstate;
// When the ngram history reaches n-1, drop the oldest
// arc in the history when copying. However, we don't want to do this
// if the current transition is an epsilon.
std::vector<Arc> arc_history;

// General case where the state is not a final state.
if (n > 1) {
// If the state has full n-gram history.
if (ngram_history[state].size() == n-1 && arc.olabel != 0)
arc_history = std::vector<Arc>(ngram_history[state].begin() + 1,
ngram_history[state].end());
// Else the n-gram history is truncated.
else
arc_history = std::vector<Arc>(ngram_history[state]);
// Epsilons aren't part of the ngram history.
if (arc.olabel != 0)
arc_history.push_back(arc);
}

std::pair<const StateId, std::vector<Arc> > new_pair(next_state,
arc_history);
// We need unique histories. So if we've already added a history
// for this state, it needs to be the same history that we're about
// to add.
if (ngram_history.find(next_state) != ngram_history.end()) {
KALDI_ASSERT(ngram_history[next_state].size()
== arc_history.size());
for (int32 i = 0; i < arc_history.size(); i++)
KALDI_ASSERT(arc_history[i].olabel
== ngram_history[next_state][i].olabel);
} else {
ngram_history.insert(new_pair);
}

// Retrieve probability for this instance of the n-gram.
if (arc.olabel != 0) {
std::vector<Arc::Label> ngram;
for (int32 i = 0; i < ngram_history[state].size(); i++)
ngram.push_back(ngram_history[state][i].olabel);
ngram.push_back(arc.olabel);
double prob = exp(alpha[state] -
ConvertToCost(arc.weight) + beta[next_state] - tot_like);
std::pair<std::vector<Arc::Label>,
double> ngram_and_prob(ngram, prob);
soft_counts->push_back(ngram_and_prob);
}
}

// Handle the case where the state is a final state.
Weight final_weight = lat.Final(state);
if (final_weight != Weight::Zero()) {
// Each final state gets its own super_final state. This is needed to
// statisfy the requirement that every state has a unique history up
// to the order of n because we associate a EOS symbol with the final
// state. This super_final state is not added to the lattice.
StateId super_final = state + super_final_offset;
// This arc points to an imaginary "super-final" state.
Arc arc;
arc.weight = final_weight;
arc.ilabel = eos_symbol;
arc.olabel = eos_symbol;
arc.nextstate = super_final;
std::vector<Arc> arc_history;
if (n > 1) {
if (ngram_history[state].size() == n-1)
arc_history = std::vector<Arc>(ngram_history[state].begin() + 1,
ngram_history[state].end());
else
arc_history = std::vector<Arc>(ngram_history[state]);
// If n == 1 then there is no transition history.
arc_history.push_back(arc);
}

std::pair<const StateId, std::vector<Arc> > new_pair(super_final,
arc_history);
if (ngram_history.find(super_final) != ngram_history.end()) {
KALDI_ASSERT(ngram_history[super_final].size()
== arc_history.size());
for (int32 i = 0; i < arc_history.size(); i++)
KALDI_ASSERT(arc_history[i].olabel
== ngram_history[super_final][i].olabel);
} else {
ngram_history.insert(new_pair);
}

std::vector<Arc::Label> ngram;
for (int32 i = 0; i < ngram_history[state].size(); i++)
ngram.push_back(ngram_history[state][i].olabel);
ngram.push_back(arc.olabel);
// Retrieve probability of this instance of the n-gram.
// Note that beta[super_final] == 0.
double prob = exp(alpha[state] - ConvertToCost(final_weight)
- tot_like);
std::pair<std::vector<Arc::Label>,
double> ngram_and_prob(ngram, prob);
soft_counts->push_back(ngram_and_prob);
}
}
}

} // namespace kaldi
43 changes: 29 additions & 14 deletions src/lat/lattice-functions.h
Original file line number Diff line number Diff line change
Expand Up @@ -377,26 +377,26 @@ void ComposeCompactLatticeDeterministic(
fst::DeterministicOnDemandFst<fst::StdArc>* det_fst,
CompactLattice* composed_clat);

/// This function computes the mapping from the pair
/// (frame-index, transition-id) to the pair
/// (sum-of-acoustic-scores, num-of-occurences) over all occurences of the
/// This function computes the mapping from the pair
/// (frame-index, transition-id) to the pair
/// (sum-of-acoustic-scores, num-of-occurences) over all occurences of the
/// transition-id in that frame.
/// frame-index in the lattice.
/// This function is useful for retaining the acoustic scores in a
/// non-compact lattice after a process like determinization where the
/// frame-index in the lattice.
/// This function is useful for retaining the acoustic scores in a
/// non-compact lattice after a process like determinization where the
/// frame-level acoustic scores are typically lost.
/// The function ReplaceAcousticScoresFromMap is used to restore the
/// The function ReplaceAcousticScoresFromMap is used to restore the
/// acoustic scores computed by this function.
///
/// @param [in] lat Input lattice. Expected to be top-sorted. Otherwise the
/// function will crash.
/// @param [out] acoustic_scores
/// @param [in] lat Input lattice. Expected to be top-sorted. Otherwise the
/// function will crash.
/// @param [out] acoustic_scores
/// Pointer to a map from the pair (frame-index,
/// transition-id) to a pair (sum-of-acoustic-scores,
/// num-of-occurences).
/// Usually the acoustic scores for a pdf-id (and hence
/// transition-id) on a frame will be the same for all the
/// occurences of the pdf-id in that frame.
/// occurences of the pdf-id in that frame.
/// But if not, we will take the average of the acoustic
/// scores. Hence, we store both the sum-of-acoustic-scores
/// and the num-of-occurences of the transition-id in that
Expand All @@ -409,18 +409,33 @@ void ComputeAcousticScoresMap(
/// This function restores acoustic scores computed using the function
/// ComputeAcousticScoresMap into the lattice.
///
/// @param [in] acoustic_scores
/// @param [in] acoustic_scores
/// A map from the pair (frame-index, transition-id) to a
/// pair (sum-of-acoustic-scores, num-of-occurences) of
/// pair (sum-of-acoustic-scores, num-of-occurences) of
/// the occurences of the transition-id in that frame.
/// See the comments for ComputeAcousticScoresMap for
/// See the comments for ComputeAcousticScoresMap for
/// details.
/// @param [out] lat Pointer to the output lattice.
void ReplaceAcousticScoresFromMap(
const unordered_map<std::pair<int32, int32>, std::pair<BaseFloat, int32>,
PairHasher<int32> > &acoustic_scores,
Lattice *lat);


/// For a specified n, this function computes soft n-gram counts from
/// the arc labels of the input lattice lat. The input lattice needs to be
/// expanded for order n. For any state in lat all paths to that state must
/// have the same last n labels preceding the state. The binary
/// lattice-expand-ngram should be used to expand lattices for a given n.
/// The soft n-gram counts are returned in soft_counts as n-gram, probability
/// pairs and instances of the same n-gram probability are stored as separate
/// entries; these need to be summed over to get the final soft-count for a
/// given n-gram.
void ComputeSoftNgramCounts(const CompactLattice &lat, int32 n,
CompactLattice::Arc::Label eos_symbol,
std::vector<std::pair<std::vector<CompactLattice::Arc::Label>,
double> > *soft_counts);

} // namespace kaldi

#endif // KALDI_LAT_LATTICE_FUNCTIONS_H_
2 changes: 1 addition & 1 deletion src/latbin/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ BINFILES = lattice-best-path lattice-prune lattice-equivalent lattice-to-nbest \
nbest-to-linear nbest-to-lattice lattice-1best linear-to-nbest \
lattice-mbr-decode lattice-align-words lattice-to-mpe-post \
lattice-copy-backoff nbest-to-ctm lattice-determinize-pruned \
lattice-to-ctm-conf lattice-combine \
lattice-to-ctm-conf lattice-combine lattice-to-ngram-counts \
lattice-rescore-mapped lattice-depth lattice-align-phones \
lattice-to-smbr-post lattice-determinize-pruned-parallel \
lattice-add-penalty lattice-align-words-lexicon lattice-push \
Expand Down
107 changes: 107 additions & 0 deletions src/latbin/lattice-to-ngram-counts.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
// latbin/lattice-to-ngram-counts.cc

// Copyright 2014 Telepoint Global Hosting Service, LLC. (Author: David Snyder)
// 2018 David Snyder
// See ../../COPYING for clarification regarding multiple authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// THIS CODE IS PROVIDED *AS IS* BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WITHOUT LIMITATION ANY IMPLIED
// WARRANTIES OR CONDITIONS OF TITLE, FITNESS FOR A PARTICULAR PURPOSE,
// MERCHANTABLITY OR NON-INFRINGEMENT.
// See the Apache 2 License for the specific language governing permissions and
// limitations under the License.

#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "fstext/fstext-lib.h"
#include "lat/kaldi-lattice.h"
#include "lat/lattice-functions.h"
#include <climits>

int main(int argc, char *argv[]) {
try {
using namespace kaldi;
typedef kaldi::int32 int32;
typedef kaldi::int64 int64;
using fst::SymbolTable;
using fst::VectorFst;

const char *usage =
"Retrieve n-label soft-counts for each input lattice. Each line of\n"
"the output is of the form <uttid> <ngram_1>:<ngram_1-prob> ... "
"<ngram_k>:<ngram_k-prob>.\n"
"<ngram_k> is of the form <sym_1>,<sym_2>,...,<sym_n>.\n"
"Note that <ngram_k> is an instance of that n-gram. The actual soft-\n"
"counts are the consolidation of all instances of the same n-gram.\n"
"Usage: lattice-to-ngram-counts [options] <lattice-rspecifier> "
"<softcount-output-file>\n"
" e.g.: lattice-to-ngram-counts --n=3 --eos-symbol=100 ark:lats "
"counts.txt\n";

ParseOptions po(usage);
int32 n = 3;
CompactLatticeArc::Label eos_symbol = INT_MAX;
BaseFloat acoustic_scale = 0.075;

std::string word_syms_filename;
po.Register("n", &n, "n-gram context size for computing soft-counts");
po.Register("eos-symbol", &eos_symbol,
"Integer label for the end of sentence character");
po.Register("acoustic-scale", &acoustic_scale,
"Scaling factor for acoustic likelihoods");

po.Read(argc, argv);

if (po.NumArgs() != 2) {
po.PrintUsage();
exit(1);
}
KALDI_ASSERT(n > 0);

std::string lats_rspecifier = po.GetArg(1),
softcount_wspecifier = po.GetOptArg(2);

SequentialCompactLatticeReader clat_reader(lats_rspecifier);
std::ofstream softcount_file;
softcount_file.open(softcount_wspecifier.c_str());
softcount_file.flush();

int32 n_done = 0;

for (; !clat_reader.Done(); clat_reader.Next()) {
std::string key = clat_reader.Key();
KALDI_LOG << "Processing lattice for key " << key;
CompactLattice lat = clat_reader.Value();
fst::ScaleLattice(fst::AcousticLatticeScale(acoustic_scale), &lat);
std::vector<std::pair<std::vector<CompactLattice::Arc::Label>,
double> > soft_counts;
TopSortCompactLatticeIfNeeded(&lat);
kaldi::ComputeSoftNgramCounts(lat, n, eos_symbol, &soft_counts);
softcount_file << key << " ";
for (int i = 0; i < soft_counts.size(); i++) {
int32 size = soft_counts[i].first.size();
for (int j = 0; j < size-1; j++) {
softcount_file << soft_counts[i].first[j] << ",";
}
softcount_file << soft_counts[i].first[size-1] << ":"
<< soft_counts[i].second << " ";
}
softcount_file << std::endl;
clat_reader.FreeCurrent();
n_done++;
}
KALDI_LOG << "Computed ngram soft counts for " << n_done
<< " utterances.";
softcount_file.close();
return 0;
} catch(const std::exception &e) {
std::cerr << e.what();
return -1;
}
}