Skip to content
This repository

HTTPS clone URL

Subversion checkout URL

You can clone with HTTPS or Subversion.

Download ZIP
Browse code

Merge branch 'balancefit' into develop

  • Loading branch information...
commit 97f9a77356c277a6232358f4b9d618b379206fd3 2 parents 829d019 + b1e552d
Nikita Zhiltsov authored

Showing 1 changed file with 16 additions and 14 deletions. Show diff stats Hide diff stats

  1. +16 14 extrescal.py
30 extrescal.py
@@ -12,7 +12,7 @@
12 12
13 13 __version__ = "0.1"
14 14
15   -__DEF_MAXITER = 100
  15 +__DEF_MAXITER = 50
16 16 __DEF_PREHEATNUM = 1
17 17 __DEF_INIT = 'nvecs'
18 18 __DEF_PROJ = True
@@ -110,7 +110,9 @@ def rescal(X, D, rank, **kwargs):
110 110 # precompute norms of X
111 111 normX = [squareFrobeniusNormOfSparse(M) for M in X]
112 112 sumNormX = sum(normX)
  113 + normD = squareFrobeniusNormOfSparse(D)
113 114 _log.debug('[Algorithm] The tensor norm: %.5f' % sumNormX)
  115 + _log.debug('[Algorithm] The extended matrix norm: %.5f' % normD)
114 116 # initialize A
115 117 A = zeros((n,rank), dtype=np.float64)
116 118 if ainit == 'random':
@@ -149,15 +151,18 @@ def rescal(X, D, rank, **kwargs):
149 151 V = updateV(A, D, lmbda)
150 152 # compute fit values
151 153 fit = 0
  154 + tensorFit = 0
152 155 regularizedFit = 0
  156 + extRegularizedFit = 0
153 157 regRFit = 0
154 158 fitDAV = 0
155 159 if iter > preheatnum:
156 160 if lmbda != 0:
157 161 for i in range(len(R)):
158 162 regRFit += norm(R[i])**2
159   - regularizedFit = lmbda*(norm(A)**2) + lmbda*regRFit + lmbda*(norm(V)**2)
160   -
  163 + regularizedFit = lmbda*(norm(A)**2) + lmbda*regRFit
  164 + if lmbda != 0:
  165 + extRegularizedFit = lmbda*(norm(V)**2)
161 166 if exactfit:
162 167 fitDAV = norm(D - dot(A,V))**2
163 168 else :
@@ -167,7 +172,7 @@ def rescal(X, D, rank, **kwargs):
167 172
168 173 if exactfit:
169 174 for i in range(len(R)):
170   - fit = norm(X[i] - dot(A,dot(R[i], A.T)))**2
  175 + tensorFit = norm(X[i] - dot(A,dot(R[i], A.T)))**2
171 176 else :
172 177 for i in range(len(R)):
173 178 ARk = dot(A, R[i])
@@ -175,24 +180,21 @@ def rescal(X, D, rank, **kwargs):
175 180 fits = []
176 181 for rr in range(len(Xrow)):
177 182 fits.append(fitNorm(Xrow[rr], Xcol[rr], X[i], ARk, A))
178   - fit = sum(fits)
  183 + tensorFit = sum(fits)
179 184
180   - fit *= 0.5
  185 + fit = 0.5*tensorFit
181 186 fit += regularizedFit
182   - fit += fitDAV
183   - fit /= sumNormX
  187 + fit /= sumNormX
  188 + fit += (0.5*fitDAV + extRegularizedFit)/normD
  189 +
184 190 else :
185 191 _log.debug('[Algorithm] Preheating is going on.')
186 192
187 193 toc = time.clock()
188 194 exectimes.append( toc - tic )
189 195 fitchange = abs(fitold - fit)
190   - if lmbda != 0:
191   - _log.debug('[%3d] totalFit: %.20f | regularized fit: %.20f | extended fit: %.20f | delta: %.20f | secs: %.5f' % (iter,
192   - fit, regularizedFit, fitDAV, fitchange, exectimes[-1]))
193   - else :
194   - _log.debug('[%3d] totalFit: %.20f | extended fit: %.20f | delta: %.20f | secs: %.5f' % (iter,
195   - fit, fitDAV, fitchange, exectimes[-1]))
  196 + _log.debug('[%3d] total fit: %.10f | tensor fit: %.10f | matrix fit: %.10f | delta: %.10f | secs: %.5f' % (iter,
  197 + fit, tensorFit, fitDAV, fitchange, exectimes[-1]))
196 198
197 199 fitold = fit
198 200 if iter > preheatnum and fitchange < conv:

0 comments on commit 97f9a77

Please sign in to comment.
Something went wrong with that request. Please try again.