forked from scikit-learn/scikit-learn
-
Notifications
You must be signed in to change notification settings - Fork 0
/
_cdnmf_fast.pyx
38 lines (28 loc) · 1.11 KB
/
_cdnmf_fast.pyx
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Author: Mathieu Blondel, Tom Dupre la Tour
# License: BSD 3 clause
from cython cimport floating
from libc.math cimport fabs
def _update_cdnmf_fast(floating[:, ::1] W, floating[:, :] HHt,
                       floating[:, :] XHt, Py_ssize_t[::1] permutation):
    """Run one sweep of projected coordinate descent on ``W``, in place.

    For each column ``t = permutation[s]`` of ``W`` and each row ``i``, this
    computes the partial gradient

        grad = sum_r HHt[t, r] * W[i, r] - XHt[i, t]

    using the current (already partially updated) contents of ``W``. It adds
    the absolute projected gradient to ``violation``. When ``HHt[t, t] != 0``
    it then replaces ``W[i, t]`` with ``max(W[i, t] - grad / HHt[t, t], 0)``,
    a Newton step clipped to stay non-negative. All loops run without the GIL.

    Parameters
    ----------
    W : floating[:, ::1]
        C-contiguous array of shape (n_samples, n_components), updated in
        place. Per the inline note, the first axis is n_features when this
        routine is used for the H update (presumably called with H.T --
        confirm at call site).
    HHt : floating[:, :]
        Array indexed as ``HHt[t, r]`` with ``t, r < n_components``.
    XHt : floating[:, :]
        Array indexed as ``XHt[i, t]`` with ``i < n_samples``,
        ``t < n_components``.
    permutation : Py_ssize_t[::1]
        Order in which the columns of ``W`` are visited; its first
        ``n_components`` entries are read and used as column indices.

    Returns
    -------
    violation : float
        Sum over all visited entries of ``|pg|``. The projected gradient
        ``pg`` is ``min(0, grad)`` when ``W[i, t]`` is 0 before its update,
        and ``grad`` otherwise.
    """
    cdef:
        floating violation = 0
        Py_ssize_t n_components = W.shape[1]
        Py_ssize_t n_samples = W.shape[0]  # n_features for H update
        floating grad, pg, hess
        Py_ssize_t i, r, s, t

    with nogil:
        for s in range(n_components):
            t = permutation[s]

            # Diagonal Hessian entry for coordinate t. HHt is never written
            # here, so it is constant across the loop over rows i.
            hess = HHt[t, t]

            for i in range(n_samples):
                # gradient = GW[i, t] where GW = np.dot(W, HHt) - XHt.
                # The code reads row t of HHt (HHt[t, r]), which equals
                # column t when HHt is symmetric (presumably HHt = H @ H.T).
                grad = -XHt[i, t]

                for r in range(n_components):
                    grad += HHt[t, r] * W[i, r]

                # Projected gradient: at the lower bound W[i, t] == 0, only
                # a negative gradient (which would increase W) counts.
                pg = min(0., grad) if W[i, t] == 0 else grad
                violation += fabs(pg)

                # Newton step on this coordinate, projected onto W >= 0.
                # A zero Hessian entry leaves W[i, t] unchanged.
                if hess != 0:
                    W[i, t] = max(W[i, t] - grad / hess, 0.)

    # The return converts the C scalar to a Python float, which needs the GIL.
    return violation