In [29]:
from time import perf_counter
from time import time

import numpy as np
import scipy.io
from scipy.sparse import csr_matrix

from projL1 import projL1

In [30]:
def projNuc(Z, kappa):
    """Project a matrix onto the nuclear-norm ball of radius ``kappa``.

    The Frobenius-norm projection of ``Z`` onto ``{X : ||X||_* <= kappa}``
    keeps the singular vectors of ``Z`` and replaces its singular values
    with their projection onto the L1 ball of radius ``kappa``.

    Parameters
    ----------
    Z : ndarray, shape (m, n)
        Dense matrix to project.
    kappa : float
        Radius of the nuclear-norm ball.

    Returns
    -------
    ndarray, shape (m, n)
        The projected matrix.
    """
    # Thin SVD: u is (m, k), s is (k,), vh is (k, n) with k = min(m, n).
    u, s, vh = np.linalg.svd(Z, full_matrices=False)
    # Singular values are non-negative, so their L1-ball projection gives the
    # singular values of the projected matrix (assumes projL1 returns the
    # Euclidean projection onto {x : ||x||_1 <= kappa} -- TODO confirm).
    s_proj = projL1(s, kappa)
    # Scale the columns of u by broadcasting rather than multiplying by
    # np.diag(s_proj): avoids allocating a dense k x k matrix and an extra
    # O(m * k^2) matrix product.
    return (u * s_proj) @ vh

In [47]:
data = scipy.io.loadmat('./dataset/ml-100k/ub_base')  # load 100k dataset

Rating = data['Rating'].flatten()
UserID = data['UserID'].flatten() - 1  # Python indexing starts from 0 whereas Matlab from 1
MovID = data['MovID'].flatten() - 1    # Python indexing starts from 0 whereas Matlab from 1

# The largest 1-based ID gives the matrix dimension along each axis.
nM = np.amax(data['MovID'])
nU = np.amax(data['UserID'])

# Dense movies x users rating matrix; unobserved entries are 0.
Z = csr_matrix((Rating, (MovID, UserID)), shape=(nM, nU), dtype=float).toarray()
# csr_matrix silently sums duplicate (row, col) pairs; make sure every
# rating landed in its own cell.
assert np.count_nonzero(Z) == Rating.size, 'duplicate (movie, user) pairs or zero-valued ratings'
kappa = 5000  # nuclear-norm ball radius

# perf_counter is monotonic and high-resolution; time() can jump with clock changes.
tstart = perf_counter()
Z_proj = projNuc(Z, kappa)
elapsed = perf_counter() - tstart
print('proj for 100k data takes {} sec'.format(elapsed))

proj for 100k data takes 0.4889242649078369 sec


In [48]:
# Projected values of the first 20 movies (rows) for user 0 (column 0).
Z_proj[:20, 0]

array([4.57896269, 2.28692188, 1.95500422, 2.83015962, 1.45639257,
       1.46208011, 4.06488714, 2.1232902 , 3.4456202 , 1.93636815,
       2.29905439, 4.19155123, 3.61270109, 3.261568  , 3.17096125,
       2.03031605, 1.00282556, 0.9777876 , 2.68664031, 2.3901721 ])

In [5]:
# NOTE: This one can take a few minutes! (the projection alone took ~42 s in the recorded run)
data = scipy.io.loadmat('./dataset/ml-1m/ml1m_base')  # load 1M dataset

Rating = data['Rating'].flatten()
UserID = data['UserID'].flatten() - 1  # Python indexing starts from 0 whereas Matlab from 1
MovID = data['MovID'].flatten() - 1    # Python indexing starts from 0 whereas Matlab from 1

# The largest 1-based ID gives the matrix dimension along each axis.
nM = np.amax(data['MovID'])
nU = np.amax(data['UserID'])

# Dense movies x users rating matrix; unobserved entries are 0.
Z = csr_matrix((Rating, (MovID, UserID)), shape=(nM, nU), dtype=float).toarray()
# csr_matrix silently sums duplicate (row, col) pairs; make sure every
# rating landed in its own cell.
assert np.count_nonzero(Z) == Rating.size, 'duplicate (movie, user) pairs or zero-valued ratings'
kappa = 5000  # nuclear-norm ball radius

# perf_counter is monotonic and high-resolution; time() can jump with clock changes.
tstart = perf_counter()
Z_proj = projNuc(Z, kappa)
elapsed = perf_counter() - tstart
print('proj for 1M data takes {} sec'.format(elapsed))

proj for 1M data takes 41.82280373573303 sec


In [6]:
# Display the projected 1M rating matrix (movies x users); numpy abbreviates large arrays with '...'.
Z_proj

array([[ 1.82324537e+00,  9.65689972e-01,  9.70444810e-01, ...,
         3.50175334e-01,  1.07139049e+00,  1.68496930e+00],
       [ 3.08464114e-01,  2.59729966e-01,  2.67246064e-01, ...,
         2.67766304e-02,  2.10915611e-01, -2.70079909e-02],
       [ 7.07435878e-02,  1.53593983e-01,  1.00178301e-01, ...,
         4.46633966e-03,  1.04055868e-01, -1.45571074e-01],
       ...,
       [ 3.02602494e-03,  1.30904611e-02, -1.27481384e-02, ...,
        -5.55886832e-03,  2.03766891e-02,  1.04262425e-01],
       [ 8.29825343e-03,  4.65417082e-03, -6.81981330e-03, ...,
        -1.48917089e-03,  5.08354843e-04,  9.27585745e-02],
       [ 4.78845686e-02,  1.37568060e-01, -3.14723762e-02, ...,
        -2.17695947e-02, -4.55672749e-02,  4.26821198e-01]])