Permalink
Browse files

Refactor to using numpy arrays for relatively small matrices

  • Loading branch information...
1 parent 90df3b9 commit a5c2781bff0904b66cf9cf9afe5fc9b310780fc9 @nzhiltsov committed Jan 2, 2013
Showing with 19 additions and 49 deletions.
  1. +19 −49 rescal.py
View
@@ -10,11 +10,10 @@
import numpy as np
import os
import fnmatch
-import carray as ca
import handythread
import operator
import itertools
-from multiprocessing import Pool, Process, Manager, Array
+from multiprocessing import Pool
__version__ = "0.1"
__all__ = ['rescal', 'rescal_with_random_restarts']
@@ -27,9 +26,9 @@
logging.basicConfig(filename='rescal.log',filemode='w', level=logging.DEBUG)
_log = logging.getLogger('RESCAL')
-ARk = ca.zeros((1,1))
-Aglobal = ca.zeros((1,1))
-Xiglobal = ca.zeros((1,1))
+ARk = zeros((1,1))
+Aglobal = zeros((1,1))
+Xiglobal = zeros((1,1))
def rescal_with_random_restarts(X, rank, restarts=10, **kwargs):
"""
@@ -51,39 +50,12 @@ def squareFrobeniusNormOfSparse(M):
norm = sum(M.dot(M.transpose()).diagonal())
return norm
-def minus(L, R):
- """
- Compute L - R for cArray matrices
- """
- l1, l2 = L.shape
- r1, r2 = R.shape
- if l1 != r1 or l2 != r2:
- raise 'Both the matrices must have the same shape.'
- matrix = L.copy()
- for i in range(l1):
- for j in range(l2):
- matrix[i,j] = matrix[i,j] - R[i,j]
- return matrix
-
-def dotAsCArray(L, R):
- """
- Computes the dot product as a cArray
- """
- l1, l2 = L.shape
- r1, r2 = R.shape
-
- matrix = ca.zeros((l1, r2))
- for i in range(l1):
- for j in range(r2):
- matrix[i,j] = dot(L[i,:],R[:,j])
- return matrix
-
def squareOfMatrix(M):
"""
Computes A^T * A, i.e., the square of a given matrix
"""
n,r = M.shape
- matrix = ca.zeros((r, r))
+ matrix = zeros((r, r))
for i in range(r):
for j in range(r):
matrix[i,j] = dot(M[:,i], M[:,j])
@@ -99,8 +71,8 @@ def fitNorm(i):
"""
Computes the squared Frobenius norm of the i-th fitting matrix row
"""
- n, r = A.shape
- ARAtValues = handythread.parallel_map2(ARAtFunc, range(n), ARk[i,:], Aglobal, threads=7)
+ n, r = Aglobal.shape
+ ARAtValues = handythread.parallel_map2(ARAtFunc, range(n), ARk[i,:], Aglobal, threads=2)
return norm(Xiglobal.getrow(i).todense() - ARAtValues)**2
def rescal(X, rank, **kwargs):
@@ -176,11 +148,9 @@ def rescal(X, rank, **kwargs):
sumNormX = sum(normX)
# initialize A
- A = ca.zeros((n,rank), dtype=np.float64)
+ A = zeros((n,rank), dtype=np.float64)
if ainit == 'random':
-# A = array(rand(n, rank), dtype=np.float64)
- for k in range(n/1000):
- A[k*1000:(k+1)*1000,0:rank] = rand(1000, rank)
+ A = array(rand(n, rank), dtype=np.float64)
elif ainit == 'nvecs':
S = coo_matrix((n, n), dtype=np.float64)
T = coo_matrix((n, n), dtype=dtype)
@@ -229,7 +199,7 @@ def rescal(X, rank, **kwargs):
fit = 0
for i in range(len(R)):
global ARk
- ARk = dotAsCArray(A, R[i])
+ ARk = dot(A, R[i])
global Xiglobal
Xiglobal = X[i]
p = Pool(4)
@@ -251,22 +221,22 @@ def rescal(X, rank, **kwargs):
fit, fitchange, exectimes[-1]))
fitold = fit
-# if iter > 1 and fitchange < conv:
-# break
+ if iter > 1 and fitchange < conv:
+ break
return A, R, fit, iter+1, array(exectimes)
def __updateA(X, A, R, lmbda):
n, rank = A.shape
- F = ca.zeros((n,rank))
+ F = zeros((n,rank))
E = zeros((rank, rank), dtype=np.float64)
AtA = squareOfMatrix(A)
for i in range(len(X)):
- ar = dotAsCArray(A, R[i])
- art = dotAsCArray(A, R[i].T)
+ ar = dot(A, R[i])
+ art = dot(A, R[i].T)
F = F + X[i].dot(art) + X[i].T.dot(ar)
- E = E + dotAsCArray(R[i], dotAsCArray(AtA, R[i].T)) + dotAsCArray(R[i].T, dotAsCArray(AtA, R[i]))
- A = dotAsCArray(F, inv(lmbda * eye(rank) + E))
+ E = E + dot(R[i], dot(AtA, R[i].T)) + dot(R[i].T, dot(AtA, R[i]))
+ A = dot(F, inv(lmbda * eye(rank) + E))
return A
def __updateR(X, A, lmbda):
@@ -314,8 +284,8 @@ def __projectSlices(X, Q):
col = loadtxt('./data2/' + file.replace("rows", "cols"), dtype=np.int32)
if col.size == 1:
col = np.atleast_1d(col)
- A = coo_matrix((ones(row.size),(row,col)), shape=(dim,dim), dtype=np.uint8)
- X.append(A)
+ Xi = coo_matrix((ones(row.size),(row,col)), shape=(dim,dim), dtype=np.uint8).tolil()
+ X.append(Xi)
print 'The number of slices: %d' % numSlices

0 comments on commit a5c2781

Please sign in to comment.