Skip to content

Commit

Permalink
refactor FMM and GMM (tested)
Browse files Browse the repository at this point in the history
  • Loading branch information
lindahua committed Nov 18, 2012
1 parent 764c0e9 commit 92102bc
Show file tree
Hide file tree
Showing 7 changed files with 192 additions and 106 deletions.
51 changes: 48 additions & 3 deletions bases/pli_pmodel_base.m
Expand Up @@ -6,15 +6,60 @@

n = check_observations(self, obs)
% Verify the validity of input observations
%
% It returns the number of samples contained in obs when
% obs is a valid sample set. Otherwise, it raises an error.
%

params = estimate_param(self, obs, weights, hints)
% Estimate model parameters from observations

params = update_params(self, obs, weights, sidx, params)
% Update model parameters from re-weighted observations
%
% It optimizes the parameters based on MAP criterion.
%
% It should support the following usages:
%
% params = self.update_params(obs, [], [], params);
%
% updates parameters based on unweighted observations.
%
% params = self.update_params(obs, weights, [], params);
%
% updates parameters based on a weighted set of
% observations. weights can be an m-by-n matrix,
% for estimating m sets of parameters.
%
% param = self.update_params(obs, [], sidx, param);
%
% updates parameters based on a subset of non-weighted
% observations. sidx is a vector of indices.
%
% param = self.update_params(obs, weights, sidx, param);
%
% updates parameters based on a subset of weighted
% observations. weights and sidx should be vectors
% of the same length.
%

L = evaluate_loglik(self, params, obs)
% Evaluate log-likelihoods w.r.t. given parameters
%
% It evaluates the log-likelihood w.r.t a given set
% of parameters.
%
% The resultant matrix L should be an m-by-n matrix,
% where m is the number of component parameters, and
% n is the number of observed samples.
%

L = evaluate_logpri(self, params)
% Evaluate log prior values of the input parameters
%
% It evaluates the log-prior over a set of parameters
%
% If params contains m components, it should return
% a vector of length m in general. If no prior is associated,
% this function can simply return 0.
%

end

Expand Down
23 changes: 20 additions & 3 deletions distributions/pli_gauss_model.m
Expand Up @@ -3,7 +3,8 @@

properties
cform;
dim;
dim;
tie_cov;
end

methods
Expand All @@ -21,7 +22,14 @@
if ischar(cf) && isscalar(cf)
if cf == 's' || cf == 'd' || cf == 'f'
cform_ok = 1;
tiec = false;
end
elseif strcmp(cf(2:end), '-tied')
cf = cf(1);
if cf == 's' || cf == 'd' || cf == 'f'
cform_ok = 1;
tiec = true;
end
end

if ~cform_ok
Expand All @@ -33,6 +41,7 @@

obj.cform = cf;
obj.dim = d;
obj.tie_cov = tiec;
end

end
Expand All @@ -56,9 +65,17 @@
end


function G = estimate_param(self, X, weights, ~)
function G = update_params(self, X, weights, sidx, ~)

G = pli_gauss_mle(X, weights, self.cform);
if ~isempty(sidx)
X = X(:, sidx);
end

if self.tie_cov
G = pli_gauss_mle(X, weights, self.cform, 'tie-cov');
else
G = pli_gauss_mle(X, weights, self.cform);
end
end

function L = evaluate_loglik(~, G, X)
Expand Down
112 changes: 16 additions & 96 deletions mixmodels/pli_fmm_em_problem.m
Expand Up @@ -104,69 +104,7 @@ function set_obs(self, obs, w)
%% Methods to support EM estimation

methods

function v = eval_objv(self, sol)
% Evaluates the objective at a solution
%
% v = self.eval_objv(sol);
%
% The input solution should be completed.
%
% Note: the struct produced by the init_solution method
% is not completed, as components are initially empty.
% It would become completed after one update iteration.
%
% The returned value is the sum of five terms computed below:
% Q-weighted observation log-likelihoods, Q-weighted log(pi),
% component log-priors, the log-prior term on pi, and the
% entropy of Q.
%

mdl = self.model;
w = self.weights;
pric = self.pricount;

Q = sol.Q;
n = size(Q, 2);

% no sample weights given: use a unit weight per column of Q
if isempty(w)
w = ones(n, 1);
end

% sum of observation log-likelihoods
% (sol.logliks is reused when non-empty; otherwise it is recomputed)

L = sol.logliks;
if isempty(L)
L = mdl.evaluate_loglik(sol.components, self.obs);
end

sll_obs = sum(L .* Q, 1) * w;

% sum of assignment log-likelihoods

log_pi = log(sol.pi);
sll_q = log_pi' * (Q * w);

% sum of component log-priors

lpri_comp = mdl.evaluate_logpri(sol.components);
lpri_comp = sum(lpri_comp);

% the log-prior of pi (only contributes when pricount is positive)

if pric > 0
lpri_pi = pric * sum(log_pi);
else
lpri_pi = 0;
end

% entropy of Q
% (nansum ignores NaN entries, e.g. those produced by 0 * log(0))

ent_q = - (nansum(Q .* log(Q), 1) * w);

% overall

v = sll_obs + sll_q + lpri_comp + lpri_pi + ent_q;

end



function sol = init_solution(self, K)
% Initializes a solution for EM estimation
%
Expand Down Expand Up @@ -210,6 +148,21 @@ function set_obs(self, obs, w)
sol.Q = Q;

end

function v = eval_objv(self, sol)
% Evaluates the objective at a solution
%
% v = self.eval_objv(sol);
%
% The input solution should be completed.
%
% Note: the struct produced by the init_solution method
% is not completed, as components are initially empty.
% It would become completed after one update iteration.
%
% The computation itself is delegated to the private function
% fmm_em_evalobjv, which receives this problem object and sol.
%

v = fmm_em_evalobjv(self, sol);
end


function sol = update(self, sol)
Expand All @@ -222,40 +175,7 @@ function set_obs(self, obs, w)
% an E-step.
%

mdl = self.model;
n = self.nobs;
O = self.obs;
w = self.weights;
pric = self.pricount;

% M-step

if isempty(w)
pi = sol.Q * ones(n, 1);
W = sol.Q.';
else
pi = sol.Q * w;
W = bsxfun(@times, sol.Q.', w);
end

if pric > 0
pi = pi + pric;
end
pi = pi / sum(pi);

params = mdl.estimate_param(O, W);
L = mdl.evaluate_loglik(params, O);

% E-step

Q = pli_nrmexp(bsxfun(@plus, L, log(pi)), 1);

% write update to sol

sol.pi = pi;
sol.components = params;
sol.logliks = L;
sol.Q = Q;
sol = fmm_em_update(self, sol);
end

end
Expand Down
6 changes: 4 additions & 2 deletions mixmodels/pli_gmm_em.m
Expand Up @@ -22,7 +22,7 @@
%
% Options
% -------
% - covform : The form of covariance:
% - covform : The form of covariance:
% 's' | 'd' | {'f'} | 's-tied' | 'd-tied' | 'f-tied'
%
% - weights : The sample weights, which should be an n x 1 vector.
Expand Down Expand Up @@ -84,9 +84,11 @@
% problem construction

d = size(X, 1);
model = pli_gauss_model(d, S.covform);
model = pli_gauss_model(d, S.covform);

problem = pli_fmm_em_problem(model, S.pricount);


problem.set_obs(X, S.weights);

% initialize solution
Expand Down
13 changes: 11 additions & 2 deletions mixmodels/plidemo_gmm_em.m
@@ -1,19 +1,28 @@
function plidemo_gmm_em()
function plidemo_gmm_em(cform)
%PLIDEMO_GMM_EM Demonstrates GMM estimation using EM
%
% PLIDEMO_GMM_EM();
% PLIDEMO_GMM_EM(cform);
%
% Here, cform is the form of covariance, which can be either of
% the following values: 's'|'d'|'f'|'s-tied'|'d-tied'|'f-tied'.
%
% Default value of cform is 'f';
%

%% Experiment configuration

if nargin == 0
cform = 'f';
end

d = 2;
K = 3;
n = 1000; % # samples / cluster

inter_sig = 6;
intra_sig = 1;

cform = 'f';

%% Generate data

Expand Down
52 changes: 52 additions & 0 deletions mixmodels/private/fmm_em_evalobjv.m
@@ -0,0 +1,52 @@
function objv = fmm_em_evalobjv(pb, sol)
%FMM_EM_EVALOBJV Objective value of an EM solution for a finite mixture
%
%   objv = FMM_EM_EVALOBJV(pb, sol);
%
%   Inputs
%     pb  : problem object; this function reads pb.model, pb.weights,
%           pb.pricount, and (only when sol.logliks is empty) pb.obs.
%     sol : solution struct; this function reads sol.Q, sol.pi,
%           sol.components and sol.logliks.
%
%   Output
%     objv : sum of five terms -- the Q-weighted observation
%            log-likelihoods, the Q-weighted log(pi), the component
%            log-priors, the log-prior term on pi, and the entropy of Q.
%

model = pb.model;
sw = pb.weights;
prior_count = pb.pricount;

post = sol.Q;

% Without sample weights, use a unit weight for each column of Q.
if isempty(sw)
    sw = ones(size(post, 2), 1);
end

% Observation term: reuse cached log-likelihoods when the solution
% carries them, otherwise evaluate them from the current components.
loglik = sol.logliks;
if isempty(loglik)
    loglik = evaluate_loglik(model, sol.components, pb.obs);
end
term_obs = sum(loglik .* post, 1) * sw;

% Assignment term: log mixing proportions weighted by Q and sample weights.
logp = log(sol.pi);
term_assign = logp' * (post * sw);

% Component prior term: total of the log-prior values of all components.
term_cpri = sum(evaluate_logpri(model, sol.components));

% Prior term on pi: contributes only for a positive prior count.
term_pipri = 0;
if prior_count > 0
    term_pipri = prior_count * sum(logp);
end

% Entropy term: nansum ignores NaN entries, such as those produced by
% 0 * log(0) where Q has exact zeros.
term_ent = - (nansum(post .* log(post), 1) * sw);

% Total objective (terms added in the same order as before).
objv = term_obs + term_assign + term_cpri + term_pipri + term_ent;

41 changes: 41 additions & 0 deletions mixmodels/private/fmm_em_update.m
@@ -0,0 +1,41 @@
function sol = fmm_em_update(pb, sol)
%FMM_EM_UPDATE Performs one EM update of a finite mixture solution
%
%   sol = FMM_EM_UPDATE(pb, sol);
%
%   Inputs
%     pb  : problem object; this function reads pb.model, pb.nobs,
%           pb.obs, pb.weights and pb.pricount.
%     sol : solution struct; sol.Q and sol.components are read.
%
%   Output
%     sol : the same struct with sol.pi, sol.components, sol.logliks
%           and sol.Q overwritten by the results of an M-step followed
%           by an E-step.
%

model = pb.model;
nsamples = pb.nobs;
X = pb.obs;
sw = pb.weights;
prior_count = pb.pricount;

post = sol.Q;

% ---- M-step ----

% Accumulate (optionally sample-weighted) mass per component and build
% the per-sample weight matrix handed to the model.
% The local is called mixp rather than pi so that the built-in constant
% pi is not shadowed inside this function.
if isempty(sw)
    mixp = post * ones(nsamples, 1);
    W = post.';
else
    mixp = post * sw;
    W = bsxfun(@times, post.', sw);
end

% Add the prior count (when positive), then normalize to sum to one.
if prior_count > 0
    mixp = mixp + prior_count;
end
mixp = mixp / sum(mixp);

% Re-estimate component parameters, then evaluate their log-likelihoods.
comps = update_params(model, X, W, [], sol.components);
loglik = evaluate_loglik(model, comps, X);

% ---- E-step ----

% NOTE(review): pli_nrmexp is defined elsewhere; presumably it returns
% exp(.) normalized along dimension 1 -- confirm against its source.
post = pli_nrmexp(bsxfun(@plus, loglik, log(mixp)), 1);

% ---- store results ----

sol.pi = mixp;
sol.components = comps;
sol.logliks = loglik;
sol.Q = post;

0 comments on commit 92102bc

Please sign in to comment.