Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse files

Merge branch 'balancefit' into develop

  • Loading branch information...
commit 97f9a77356c277a6232358f4b9d618b379206fd3 2 parents 829d019 + b1e552d
@nzhiltsov authored
Showing with 16 additions and 14 deletions.
  1. +16 −14 extrescal.py
View
30 extrescal.py
@@ -12,7 +12,7 @@
__version__ = "0.1"
-__DEF_MAXITER = 100
+__DEF_MAXITER = 50
__DEF_PREHEATNUM = 1
__DEF_INIT = 'nvecs'
__DEF_PROJ = True
@@ -110,7 +110,9 @@ def rescal(X, D, rank, **kwargs):
# precompute norms of X
normX = [squareFrobeniusNormOfSparse(M) for M in X]
sumNormX = sum(normX)
+ normD = squareFrobeniusNormOfSparse(D)
_log.debug('[Algorithm] The tensor norm: %.5f' % sumNormX)
+ _log.debug('[Algorithm] The extended matrix norm: %.5f' % normD)
# initialize A
A = zeros((n,rank), dtype=np.float64)
if ainit == 'random':
@@ -149,15 +151,18 @@ def rescal(X, D, rank, **kwargs):
V = updateV(A, D, lmbda)
# compute fit values
fit = 0
+ tensorFit = 0
regularizedFit = 0
+ extRegularizedFit = 0
regRFit = 0
fitDAV = 0
if iter > preheatnum:
if lmbda != 0:
for i in range(len(R)):
regRFit += norm(R[i])**2
- regularizedFit = lmbda*(norm(A)**2) + lmbda*regRFit + lmbda*(norm(V)**2)
-
+ regularizedFit = lmbda*(norm(A)**2) + lmbda*regRFit
+ if lmbda != 0:
+ extRegularizedFit = lmbda*(norm(V)**2)
if exactfit:
fitDAV = norm(D - dot(A,V))**2
else :
@@ -167,7 +172,7 @@ def rescal(X, D, rank, **kwargs):
if exactfit:
for i in range(len(R)):
- fit = norm(X[i] - dot(A,dot(R[i], A.T)))**2
+ tensorFit = norm(X[i] - dot(A,dot(R[i], A.T)))**2
else :
for i in range(len(R)):
ARk = dot(A, R[i])
@@ -175,24 +180,21 @@ def rescal(X, D, rank, **kwargs):
fits = []
for rr in range(len(Xrow)):
fits.append(fitNorm(Xrow[rr], Xcol[rr], X[i], ARk, A))
- fit = sum(fits)
+ tensorFit = sum(fits)
- fit *= 0.5
+ fit = 0.5*tensorFit
fit += regularizedFit
- fit += fitDAV
- fit /= sumNormX
+ fit /= sumNormX
+ fit += (0.5*fitDAV + extRegularizedFit)/normD
+
else :
_log.debug('[Algorithm] Preheating is going on.')
toc = time.clock()
exectimes.append( toc - tic )
fitchange = abs(fitold - fit)
- if lmbda != 0:
- _log.debug('[%3d] totalFit: %.20f | regularized fit: %.20f | extended fit: %.20f | delta: %.20f | secs: %.5f' % (iter,
- fit, regularizedFit, fitDAV, fitchange, exectimes[-1]))
- else :
- _log.debug('[%3d] totalFit: %.20f | extended fit: %.20f | delta: %.20f | secs: %.5f' % (iter,
- fit, fitDAV, fitchange, exectimes[-1]))
+ _log.debug('[%3d] total fit: %.10f | tensor fit: %.10f | matrix fit: %.10f | delta: %.10f | secs: %.5f' % (iter,
+ fit, tensorFit, fitDAV, fitchange, exectimes[-1]))
fitold = fit
if iter > preheatnum and fitchange < conv:
Please sign in to comment.
Something went wrong with that request. Please try again.