Skip to content

Commit

Permalink
Merge branch 'hotfix-0.1.3' into develop
Browse files Browse the repository at this point in the history
  • Loading branch information
nzhiltsov committed Feb 24, 2013
2 parents 89bafe4 + 3ecf88d commit ac9aec1
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 22 deletions.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -8,7 +8,7 @@ Ext-RESCAL is a memory efficient implementation of the [RESCAL algorithm](http:/

Current Version
------------
[0.1.2](https://github.com/nzhiltsov/Ext-RESCAL/archive/0.1.2.zip)
[0.1.3](https://github.com/nzhiltsov/Ext-RESCAL/archive/0.1.3.zip)

Features
------------
Expand Down
36 changes: 15 additions & 21 deletions rescal.py
Expand Up @@ -11,11 +11,11 @@

__version__ = "0.1"

__DEF_MAXITER = 100
__DEF_MAXITER = 50
__DEF_PREHEATNUM = 1
__DEF_INIT = 'nvecs'
__DEF_PROJ = True
__DEF_CONV = 1e-6
__DEF_CONV = 1e-5
__DEF_LMBDA = 0
__DEF_EXACT_FIT = False

Expand Down Expand Up @@ -138,35 +138,29 @@ def rescal(X, rank, **kwargs):
regRFit = 0
if iter > preheatnum:
if lmbda != 0:
for i in range(len(R)):
for i in xrange(len(R)):
regRFit += norm(R[i])**2
regularizedFit = lmbda*(norm(A)**2) + lmbda*regRFit

if exactfit:
for i in range(len(R)):
for i in xrange(len(R)):
fit = norm(X[i] - dot(A,dot(R[i], A.T)))**2
else :
for i in range(len(R)):
for i in xrange(len(R)):
ARk = dot(A, R[i])
Xrow, Xcol = X[i].nonzero()
fits = []
for rr in range(len(Xrow)):
fits.append(fitNorm(Xrow[rr], Xcol[rr], X[i], ARk, A))
fit = sum(fits)
fit *= 0.5
fit += regularizedFit
fit /= sumNormX
for rr in xrange(len(Xrow)):
fit += fitNorm(Xrow[rr], Xcol[rr], X[i], ARk, A)
fit *= 0.5
fit += regularizedFit
fit /= sumNormX
else :
_log.debug('[Algorithm] Preheating is going on.')

toc = time.clock()
exectimes.append( toc - tic )
fitchange = abs(fitold - fit)
if lmbda != 0:
_log.debug('[%3d] totalFit: %.20f | regularized fit: %.20f | delta: %.20f | secs: %.5f' % (iter,
fit, regularizedFit, fitchange, exectimes[-1]))
else :
_log.debug('[%3d] totalFit: %.20f | delta: %.20f | secs: %.5f' % (iter,
_log.debug('[%3d] total fit: %.10f | delta: %.10f | secs: %.5f' % (iter,
fit, fitchange, exectimes[-1]))

fitold = fit
Expand All @@ -180,7 +174,7 @@ def updateA(X, A, R, lmbda):
E = zeros((rank, rank), dtype=np.float64)

AtA = dot(A.T, A)
for i in range(len(X)):
for i in xrange(len(X)):
ar = dot(A, R[i])
art = dot(A, R[i].T)
F += X[i].dot(art) + X[i].T.dot(ar)
Expand All @@ -194,20 +188,20 @@ def __updateR(X, A, lmbda):
At = A.T
if lmbda == 0:
ainv = dot(pinv(dot(At, A)), At)
for i in range(len(X)):
for i in xrange(len(X)):
R.append( dot(ainv, X[i].dot(ainv.T)) )
else :
AtA = dot(At, A)
tmp = inv(kron(AtA, AtA) + lmbda * eye(r**2))
for i in range(len(X)):
for i in xrange(len(X)):
AtXA = dot(At, X[i].dot(A))
R.append( dot(AtXA.flatten(), tmp).reshape(r, r) )
return R

def __projectSlices(X, Q):
q = Q.shape[1]
X2 = []
for i in range(len(X)):
for i in xrange(len(X)):
X2.append( dot(Q.T, X[i].dot(Q)) )
return X2

Expand Down

0 comments on commit ac9aec1

Please sign in to comment.