From 8619c0c174bf845eef7bcb3bcba6eb76ee987767 Mon Sep 17 00:00:00 2001 From: Manuel Guenther Date: Fri, 6 Mar 2015 12:06:21 +0100 Subject: [PATCH] Added rng to train_jfa method. --- bob/learn/em/train.py | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/bob/learn/em/train.py b/bob/learn/em/train.py index 1e0fe5f..517e941 100644 --- a/bob/learn/em/train.py +++ b/bob/learn/em/train.py @@ -57,25 +57,30 @@ def train(trainer, machine, data, max_iterations = 50, convergence_threshold=Non trainer.finalize(machine, data) -def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True): +def train_jfa(trainer, jfa_base, data, max_iterations=10, initialize=True, rng=None): """ Trains a :py:class`bob.learn.em.JFABase` given a :py:class`bob.learn.em.JFATrainer` and the proper data **Parameters**: - trainer - A trainer mechanism (:py:class`bob.learn.em.JFATrainer`) - machine - A container machine (:py:class`bob.learn.em.JFABase`) - data - The data to be trained list(list(:py:class`bob.learn.em.GMMStats`)) - max_iterations + trainer : :py:class`bob.learn.em.JFATrainer` + A JFA trainer mechanism + jfa_base : :py:class`bob.learn.em.JFABase` + A container machine + data : [[:py:class`bob.learn.em.GMMStats`]] + The data to be trained + max_iterations : int The maximum number of iterations to train a machine - initialize + initialize : bool If True, runs the initialization procedure + rng : :py:class:`bob.core.random.mt19937` + The Mersenne Twister mt19937 random generator used for the initialization of subspaces/arrays before the EM loops """ if initialize: - trainer.initialize(jfa_base, data) + if rng is not None: + trainer.initialize(jfa_base, data, rng) + else: + trainer.initialize(jfa_base, data) #V Subspace for i in range(max_iterations):