gradws.m
% Objective function for the 'minFunc' optimiser.
% fv = function value (negative log-likelihood plus regularisation)
% dv = gradient of the function w.r.t. w
function [fv, dv] = gradws(w, gamma, x, y, mask, lambda, options)
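% Argument shapes, as inferred from the code below (the exact semantics of
% posteriorYHat, mask, and options.sn live elsewhere in this repository):
%   w      - DIM*CLS x 1 flattened weight vector (reshaped to DIM x CLS)
%   gamma  - CLS x CLS matrix mixing the softmax outputs into sk
%   x      - N x DIM design matrix
%   y      - N x 1 class labels in 1..CLS
%   mask   - logical index selecting the entries of sk that enter the loss
%   lambda - per-weight regularisation strengths (reshaped to DIM x CLS)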
CLS = size(gamma, 2);
DIM = size(x, 2);
% compute the regularisation term and its gradient
[regV, regDV] = feval(options.regFunc, w, options.sn);
reg = sum(sum(lambda .* regV));
w = reshape(w, DIM, CLS);
lambda = reshape(lambda, DIM, CLS);
regDV = reshape(regDV, DIM, CLS);
% calculate the function value: negative log-likelihood plus regularisation
sk = posteriorYHat(x, w, gamma);
fv = -sum(log(sk(mask)), 1) + reg;
% gradient: accumulate the derivative w.r.t. each column w_c in turn
tmpw = zeros(DIM, CLS);
logit = softmax(x * w);
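% Derivation note (assuming posteriorYHat computes sk = softmax(x*w) * gamma,
% which the algebra below is consistent with): writing p = softmax(x*w),
%   d log sk(i,k) / d w_c = p(i,c) * (gamma(c,k) - sk(i,k)) / sk(i,k) * x(i,:)'
% and the loop assembles exactly this per-class factor as 'multi'.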
for c = 1:CLS
    % logit w.r.t. w_c, repeated column-wise to match the shape of sk
    logit_c = repmat(logit(:,c), 1, CLS);
    % calculate the term in parentheses:
    % chunk(i,:) = gamma(c,:) - logit(i,:)*gamma  (softmax rows sum to 1)
    gc = repmat(gamma(c,:), CLS, 1);
    gc = gc - gamma;
    chunk = logit * gc;
    % calculate the multiplier applied to x
    multi = (logit_c .* chunk) ./ sk;
    % multiply with x, accumulating over the samples of each observed label
    llhw = zeros(1, DIM);
    for k = 1:CLS
        % equivalent to: res = repmat(multi(:,k), 1, DIM) .* x;
        res = bsxfun(@times, multi(:,k), x);
        llhw = llhw - sum(res(y==k,:), 1);
    end
    % save the derivative of w_c
    tmpw(:,c) = llhw';
end
% compute the derivative of the objective
tmpw = tmpw + (lambda .* regDV);
% flatten back into a single column vector, as minFunc expects
dv = tmpw(:);
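For context, a minimal sketch of how a function with this signature is typically handed to minFunc (Mark Schmidt's optimiser): minFunc passes any trailing arguments through to the objective, so gamma, x, y, mask, lambda, and options ride along unchanged. All data values below are placeholders, not values from this repository.

% minimal usage sketch; DIM, CLS, and the data arguments are placeholders
opts = struct('Method', 'lbfgs', 'MaxIter', 200, 'Display', 'off');
w0 = zeros(DIM * CLS, 1);      % flat initial weight vector, as gradws expects
wOpt = minFunc(@gradws, w0, opts, gamma, x, y, mask, lambda, options);
w = reshape(wOpt, DIM, CLS);   % recover the DIM x CLS weight matrix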