Skip to content

Commit

Permalink
Separated IVector machine initialization from IVector trainer initial…
Browse files Browse the repository at this point in the history
…ization
  • Loading branch information
Manuel Guenther committed May 12, 2015
1 parent f2534aa commit 65ddebe
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 15 deletions.
37 changes: 22 additions & 15 deletions bob/learn/em/cpp/IVectorTrainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,17 @@ bob::learn::em::IVectorTrainer::~IVectorTrainer()
void bob::learn::em::IVectorTrainer::initialize(
bob::learn::em::IVectorMachine& machine)
{
// Initializes \f$T\f$ and \f$\Sigma\f$ of the machine
blitz::Array<double,2>& T = machine.updateT();
bob::core::array::randn(*m_rng, T);
blitz::Array<double,1>& sigma = machine.updateSigma();
sigma = machine.getUbm()->getVarianceSupervector();
machine.precompute();
}

void bob::learn::em::IVectorTrainer::resetAccumulators(const bob::learn::em::IVectorMachine& machine)
{
// Resize the accumulator
const int C = machine.getNGaussians();
const int D = machine.getNInputs();
const int Rt = machine.getDimRt();
Expand All @@ -75,14 +85,17 @@ void bob::learn::em::IVectorTrainer::initialize(
if (m_update_sigma)
m_tmp_dd1.resize(D,D);

// Initializes \f$T\f$ and \f$\Sigma\f$ of the machine
blitz::Array<double,2>& T = machine.updateT();
bob::core::array::randn(*m_rng, T);
blitz::Array<double,1>& sigma = machine.updateSigma();
sigma = machine.getUbm()->getVarianceSupervector();
machine.precompute();
// initialize with 0
m_acc_Nij_wij2 = 0.;
m_acc_Fnormij_wij = 0.;
if (m_update_sigma)
{
m_acc_Nij = 0.;
m_acc_Snormij = 0.;
}
}


void bob::learn::em::IVectorTrainer::eStep(
bob::learn::em::IVectorMachine& machine,
const std::vector<bob::learn::em::GMMStats>& data)
Expand All @@ -91,13 +104,8 @@ void bob::learn::em::IVectorTrainer::eStep(
const int C = machine.getNGaussians();

// Reinitializes accumulators to 0
m_acc_Nij_wij2 = 0.;
m_acc_Fnormij_wij = 0.;
if (m_update_sigma)
{
m_acc_Nij = 0.;
m_acc_Snormij = 0.;
}
resetAccumulators(machine);

for (std::vector<bob::learn::em::GMMStats>::const_iterator it = data.begin();
it != data.end(); ++it)
{
Expand Down Expand Up @@ -179,7 +187,7 @@ bob::learn::em::IVectorTrainer& bob::learn::em::IVectorTrainer::operator=
(const bob::learn::em::IVectorTrainer &other)
{
if (this != &other)
{
{
m_update_sigma = other.m_update_sigma;

m_acc_Nij_wij2.reference(bob::core::array::ccopy(other.m_acc_Nij_wij2));
Expand Down Expand Up @@ -225,4 +233,3 @@ bool bob::learn::em::IVectorTrainer::is_similar_to
bob::core::array::isClose(m_acc_Nij, other.m_acc_Nij, r_epsilon, a_epsilon) &&
bob::core::array::isClose(m_acc_Snormij, other.m_acc_Snormij, r_epsilon, a_epsilon);
}

6 changes: 6 additions & 0 deletions bob/learn/em/include/bob.learn.em/IVectorTrainer.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,12 @@ class IVectorTrainer
*/
virtual void initialize(bob::learn::em::IVectorMachine& ivector);

/**
* @brief Reset the statistics accumulators
* to the correct size and a value of zero.
*/
void resetAccumulators(const bob::learn::em::IVectorMachine& ivector);

/**
* @brief Calculates statistics across the dataset,
* and saves these as:
Expand Down

0 comments on commit 65ddebe

Please sign in to comment.