Skip to content

Commit

Permalink
Bound function for resetting accumulators in IVectorTrainer
Browse files Browse the repository at this point in the history
  • Loading branch information
Manuel Guenther committed May 12, 2015
1 parent 65ddebe commit 4c34bb8
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 2 deletions.
29 changes: 29 additions & 0 deletions bob/learn/em/ivector_trainer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,29 @@ static PyObject* PyBobLearnEMIVectorTrainer_m_step(PyBobLearnEMIVectorTrainerObj
Py_RETURN_NONE;
}

/*** reset_accumulators ***/
static auto reset_accumulators = bob::extension::FunctionDoc(
"reset_accumulators",
"Reset the statistics accumulators to the correct size and a value of zero.",
0,
true
)
.add_prototype("ivector_machine")
.add_parameter("ivector_machine", ":py:class:`bob.learn.em.IVectorMachine`", "The IVector machine containing the right dimensions");
static PyObject* PyBobLearnEMIVectorTrainer_reset_accumulators(PyBobLearnEMIVectorTrainerObject* self, PyObject* args, PyObject* kwargs) {
BOB_TRY

/* Parses input arguments in a single shot */
char** kwlist = reset_accumulators.kwlist(0);

PyBobLearnEMIVectorMachineObject* machine;
if (!PyArg_ParseTupleAndKeywords(args, kwargs, "O!", kwlist, &PyBobLearnEMIVectorMachine_Type, &machine)) return 0;

self->cxx->resetAccumulators(*machine->cxx);
Py_RETURN_NONE;

BOB_CATCH_MEMBER("cannot perform the reset_accumulators method", 0)
}


static PyMethodDef PyBobLearnEMIVectorTrainer_methods[] = {
Expand All @@ -427,6 +450,12 @@ static PyMethodDef PyBobLearnEMIVectorTrainer_methods[] = {
METH_VARARGS|METH_KEYWORDS,
m_step.doc()
},
{
reset_accumulators.name(),
(PyCFunction)PyBobLearnEMIVectorTrainer_reset_accumulators,
METH_VARARGS|METH_KEYWORDS,
reset_accumulators.doc()
},
{0} /* Sentinel */
};

Expand Down
4 changes: 2 additions & 2 deletions bob/learn/em/test/test_ivector_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,11 +244,11 @@ def test_trainer_nosigma():
# M-Step
trainer.m_step(m)
assert numpy.allclose(t_ref[it], m.t, 1e-5)


#testing exceptions
nose.tools.assert_raises(RuntimeError, trainer.e_step, m, [1,2,2])


def test_trainer_update_sigma():
# Ubm
Expand Down

0 comments on commit 4c34bb8

Please sign in to comment.