-
Notifications
You must be signed in to change notification settings - Fork 27
/
computeBetaSparseVariational.m
33 lines (31 loc) · 1.06 KB
/
computeBetaSparseVariational.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
function eta = computeBetaSparseVariational(ecounts,eq_m,varargin)
%COMPUTEBETASPARSEVARIATIONAL  Newton optimization of eta, variational EM for tau.
%
%   eta = computeBetaSparseVariational(ecounts, eq_m, 'name', value, ...)
%
%   Alternates two steps until the delta iterator reports it is done:
%     1. eta is optimized with newtonArmijo on the evalLogNormal objective,
%        given ecounts, exp(eq_m) and the current eq_inv_tau;
%     2. eq_inv_tau is reset to 1 ./ eta.^2, capped at 'max-inv-tau'.
%
%   Inputs
%     ecounts  - counts array; its first dimension W sets the length of the
%                default eta (zeros(W,1)).
%     eq_m     - exponentiated once and passed to the objective.
%
%   Options (defaults as passed to process_options)
%     'max-its'     (1)     iteration limit handed to newDeltaIterator
%     'verbose'     (0)     1 prints a '.', >1 turns on newtonArmijo debug
%     'init-eta'    ([])    starting eta; when empty, eta starts at zeros
%                           and eq_inv_tau at ones
%     'min-eta'     (1e-20) minimum magnitude enforced on 'init-eta'
%     'max-inv-tau' (1e5)   cap applied to eq_inv_tau after each Newton step
%
%   Output
%     eta      - optimized eta, or zeros(W,1) when the counts sum to zero
%                (in which case 'init-eta' is ignored).

    [max_its, verbose, init_eta, min_eta, max_inv_tau] = ...
        process_options(varargin, 'max-its',1, ...
                        'verbose',0, 'init-eta',[], 'min-eta',1e-20, ...
                        'max-inv-tau',1e5);
    W = size(ecounts,1);

    if isempty(init_eta)
        eta = zeros(W,1);
        eq_inv_tau = ones(size(eta));
    else
        eta = init_eta;
        % Push entries with |eta| < min_eta out to magnitude min_eta, keeping
        % their sign. sign(0) is 0, so exact zeros are treated as positive;
        % otherwise they would stay at 0 and 1./(eta.^2) below would be Inf.
        too_small = eta.^2 < min_eta.^2;
        floor_sign = sign(eta(too_small));
        floor_sign(floor_sign == 0) = 1;
        eta(too_small) = floor_sign * min_eta;
        % NOTE(review): unlike the update inside the loop, this initial value
        % is not capped by max_inv_tau -- left unchanged to preserve results.
        eq_inv_tau = 1./(eta.^2);
    end

    if verbose==1, fprintf('.'); end

    % sum() works column-wise and `if` requires every element of the result
    % to be true, so this branch is taken only when every column sums to zero.
    if sum(ecounts)==0
        eta = zeros(W,1);
    else
        em_iter = newDeltaIterator(max_its,'debug',verbose,'thresh',1e-4);
        exp_eq_m = exp(eq_m);  % loop-invariant, computed once
        while ~(em_iter.done)
            % Optimize eta for the current eq_inv_tau.
            eta = newtonArmijo(@evalLogNormal,eta,{ecounts,exp_eq_m,eq_inv_tau}, ...
                               'debug',verbose>1,'init-alpha',.1,'max-its',10000);
            % Refresh eq_inv_tau from the new eta and cap it.
            eq_inv_tau = 1./(eta.^2);
            eq_inv_tau(eq_inv_tau >= max_inv_tau) = max_inv_tau;
            em_iter = updateDeltaIterator(em_iter,eta);
        end
    end
end