In [111]:
import numpy as np
class LowRankFactorization():
    
    def __init__(self, n_R, n_C, K, alpha, reg):
        '''
        n_R: number of rows
        n_C: number of cols
        K: latent factors dimensionality
        reg: regularization strength
        alpha: learning rate
        '''
        self.R = np.random.randn(n_R, K) / K
        self.C = np.random.randn(n_C, K) / K
        self.R_bias = np.zeros(n_R)
        self.C_bias = np.zeros(n_C)
        self.reg = reg
        self.alpha = alpha
        self.F = None
    
    def forward(self):
        self.F = self.R.dot(self.C.T) + (self.R_bias[:,np.newaxis] + self.C_bias[np.newaxis:,])
        return np.copy(self.F)
    
    def mse_loss(self, y, y_hat):
        return np.sqrt((y - y_hat)**2)
    
    def backward(self, M):
        Z = np.argwhere(M > 0)
        
        # sgd
        for i, j in Z:
            loss = M[i, j] - self.F[i, j]
            self.R_bias[i] += self.alpha * (loss - self.reg * self.R_bias[i])
            self.C_bias[j] += self.alpha * (loss - self.reg * self.C_bias[j])
            
            self.R[i, :] += self.alpha * (loss * self.C[j, :] - self.reg * self.R[i, :])
            self.C[j, :] += self.alpha * (loss * self.R[i, :] - self.reg * self.C[j, :])

In [112]:
R = np.random.randint(5, size=(10,10))

In [114]:
model = LowRankFactorization(R.shape[0], R.shape[1], 3, .01, .001)
epochs = 420

for e in range(epochs):
    model.forward()
    model.backward(R)

print("real: ")
#print_matrix(R)
print(R)
print("pred: ")
print(np.around(model.forward(), decimals=3))

real: 
[[3 1 3 4 3 3 4 3 1 0]
 [4 1 4 1 1 2 4 0 3 0]
 [1 2 4 3 4 2 4 3 0 0]
 [4 3 2 0 1 0 1 0 4 4]
 [3 3 3 4 3 0 2 4 0 1]
 [1 3 3 2 3 4 2 3 0 2]
 [0 0 1 4 1 4 0 4 3 1]
 [3 0 0 4 0 4 0 4 2 2]
 [4 4 3 3 2 1 2 3 4 2]
 [1 2 3 1 4 4 3 4 4 4]]
pred: 
[[ 3.017  0.749  3.159  3.724  2.189  3.368  4.074  3.239  1.457  3.513]
 [ 3.736  1.312  3.683  1.626  2.049  1.057  3.998  1.883  2.508  5.258]
 [ 1.748  1.807  3.815  1.89   3.62   3.131  3.835  3.138  3.497  4.707]
 [ 3.868  2.964  2.395 -1.07   0.692 -2.172  1.099  0.076  4.121  3.844]
 [ 2.816  3.291  2.509  4.03   2.654  4.044  2.159  4.208  3.719  1.346]
 [ 0.997  2.834  2.62   2.12   3.284  3.53   1.941  3.539  4.14   2.142]
 [ 2.624  2.499  1.759  4.05   1.717  3.62   1.644  3.738  2.633  0.39 ]
 [ 2.872  1.639  2.648  4.255  2.199  3.939  3.162  3.797  2.057  2.093]
 [ 4.317  3.817  2.609  2.696  1.76   1.63   1.832  2.874  4.197  2.261]
 [ 0.576  2.172  3.457  1.615  4.037  3.683  3.094  3.444  4.049  3.855]]
