
Change the objective to an approximate version (computed over non-zero values only)
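
In effect, the per-slice fit term is no longer the full dense residual norm(Xi - A * Rk * A.T)**2 but is evaluated only at the non-zero entries of the sparse slice. A minimal sketch of the idea (the names approx_fit, Xi and Rk are illustrative, not taken from the code):

from numpy import dot

def approx_fit(Xi, A, Rk):
    """Squared residual of one slice, summed over the non-zero
    entries of the sparse matrix Xi only, instead of the full
    dense norm(Xi - A * Rk * A.T)**2."""
    ARk = dot(A, Rk)
    total = 0.0
    for row, col in zip(*Xi.nonzero()):
        # (A * Rk * A.T)[row, col] == dot(ARk[row, :], A[col, :])
        total += (Xi[row, col] - dot(ARk[row, :], A[col, :])) ** 2
    return total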

1 parent a5c2781 commit 800b29f701d9b365b8276b8879c9f543395a1587 @nzhiltsov committed Jan 3, 2013
Showing with 19 additions and 86 deletions.
  1. +7 −72 handythread.py
  2. +12 −14 rescal.py
handythread.py
@@ -3,7 +3,7 @@
import threading
from itertools import izip, count
-def foreach(f,l,ARk, X, A, threads=4,return_=False):
+def foreach(f, l, threads=4, return_=False):
"""
Apply f to each element of l, in parallel
"""
@@ -34,9 +34,9 @@ def runall():
try:
if return_:
n,x = v
- d[n] = f(x, ARk, X, A)
+ d[n] = f(x)
else:
- f(v, ARk, X[i], A)
+ f(v)
except:
e = sys.exc_info()
iteratorlock.acquire()
@@ -59,76 +59,11 @@ def runall():
return [v for (n,v) in r]
else:
if return_:
- return [f(v, ARk, X, A) for v in l]
+ return [f(v) for v in l]
else:
for v in l:
- f(v, ARk, X, A)
+ f(v)
return
-def parallel_map(f,l,ARk, X, A,threads=4):
- return foreach(f,l,ARk, X, A, threads=threads,return_=True)
-
-def parallel_map2(f,l,ARki, A,threads=4):
- return foreach2(f,l,ARki, A, threads=threads,return_=True)
-
-def foreach2(f,l,ARki, A, threads=4,return_=False):
- """
- Apply f to each element of l, in parallel
- """
-
- if threads>1:
- iteratorlock = threading.Lock()
- exceptions = []
- if return_:
- n = 0
- d = {}
- i = izip(count(),l.__iter__())
- else:
- i = l.__iter__()
-
-
- def runall():
- while True:
- iteratorlock.acquire()
- try:
- try:
- if exceptions:
- return
- v = i.next()
- finally:
- iteratorlock.release()
- except StopIteration:
- return
- try:
- if return_:
- n,x = v
- d[n] = f(x, ARki, A)
- else:
- f(v, ARki, A)
- except:
- e = sys.exc_info()
- iteratorlock.acquire()
- try:
- exceptions.append(e)
- finally:
- iteratorlock.release()
-
- threadlist = [threading.Thread(target=runall) for j in xrange(threads)]
- for t in threadlist:
- t.start()
- for t in threadlist:
- t.join()
- if exceptions:
- a, b, c = exceptions[0]
- raise a, b, c
- if return_:
- r = d.items()
- r.sort()
- return [v for (n,v) in r]
- else:
- if return_:
- return [f(v, ARki, A) for v in l]
- else:
- for v in l:
- f(v, ARki, A)
- return
+def parallel_map(f, l, threads=4):
+ return foreach(f, l, threads=threads,return_=True)
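
With the extra ARk/X/A plumbing removed, foreach and parallel_map take a plain one-argument callable again; any extra state has to be bound by the caller. One way to do that is functools.partial (a hedged sketch with a made-up worker; rescal.py itself binds its state through module-level globals instead):

import functools
import handythread

def scaled(factor, x):
    # made-up two-argument worker; partial() binds `factor` so that
    # parallel_map sees the single-argument callable it now expects
    return factor * x

doubled = handythread.parallel_map(functools.partial(scaled, 2), range(10))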
rescal.py
@@ -13,15 +13,14 @@
import handythread
import operator
import itertools
-from multiprocessing import Pool
__version__ = "0.1"
__all__ = ['rescal', 'rescal_with_random_restarts']
__DEF_MAXITER = 500
__DEF_INIT = 'nvecs'
__DEF_PROJ = True
-__DEF_CONV = 1e-2
+__DEF_CONV = 1e-5
__DEF_LMBDA = 0
logging.basicConfig(filename='rescal.log',filemode='w', level=logging.DEBUG)
@@ -61,19 +60,15 @@ def squareOfMatrix(M):
matrix[i,j] = dot(M[:,i], M[:,j])
return matrix
-def ARAtFunc(j, ARki, A):
- """
- Computes the j-th row of the matrix ARk * A^T
- """
- return dot(ARki, A[j,:])
-def fitNorm(i):
+def fitNorm(t):
"""
- Computes the squared Frobenius norm of the i-th fitting matrix row
+ Computes the (i, j) element's contribution to the squared Frobenius norm of the fitting matrix
"""
+ row, col = t
n, r = Aglobal.shape
- ARAtValues = handythread.parallel_map2(ARAtFunc, range(n), ARk[i,:], Aglobal, threads=2)
- return norm(Xiglobal.getrow(i).todense() - ARAtValues)**2
+ ARAtValue = dot(ARk[row,:], Aglobal[col,:])
+ return (Xiglobal[row, col] - ARAtValue)**2
def rescal(X, rank, **kwargs):
"""
@@ -202,8 +197,11 @@ def rescal(X, rank, **kwargs):
ARk = dot(A, R[i])
global Xiglobal
Xiglobal = X[i]
- p = Pool(4)
- fits = p.map(fitNorm, range(n))
+ Xrow, Xcol = Xiglobal.nonzero()
+ nonzeroElems = []
+ for rr in range(len(Xrow)):
+ nonzeroElems.append((Xrow[rr], Xcol[rr]))
+ fits = handythread.parallel_map(fitNorm, nonzeroElems)
fit += sum(fits)
fit *= 0.5
fit += regularizedFit
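
The coordinate-pair loop above can also be written with zip; in Python 2, zip already returns the list of (row, col) tuples that the explicit loop builds. A toy illustration (not part of the commit):

from scipy.sparse import lil_matrix

Xi = lil_matrix((3, 3))          # toy sparse slice
Xi[0, 2] = 1.0
Xi[1, 1] = 2.0
Xrow, Xcol = Xi.nonzero()
nonzeroElems = zip(Xrow, Xcol)   # [(0, 2), (1, 1)], same pairs as the loop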
@@ -289,7 +287,7 @@ def __projectSlices(X, Q):
print 'The number of slices: %d' % numSlices
-result = rescal(X, numLatentComponents, init='random')
+result = rescal(X, numLatentComponents, init='random', lmbda=0.1)
print 'Objective function value: %.5f' % result[2]
print '# of iterations: %d' % result[3]
#print the matrix of latent embeddings
