In [3]:
import numpy as np
import pandas as pd

In [2]:
# Rating matrix R: one row per user, one column per item.
# Entries that are not greater than 0 are treated as "no rating" by the
# factorisation code further down.
R = np.array(
    [
        [4, 0, 2, 0, 1],
        [0, 2, 3, 0, 0],
        [1, 0, 2, 4, 0],
        [5, 0, 0, 3, 1],
        [0, 0, 1, 5, 1],
        [0, 3, 2, 4, 1],
    ]
)
# Number of columns (items) in R.
R.shape[1]

5

In [4]:
# Latent factor model (LFM) trained by gradient descent.
#
# Inputs:
#   R        : M*N rating matrix
#   K        : dimension of the latent feature vectors
#   max_iter : maximum number of iterations
#   alpha    : step size
#   lamda    : regularisation coefficient
#
# Outputs (the factors of R):
#   P : user feature matrix, M*K
#   Q : item feature matrix, N*K

# Hyperparameters used for the run in this notebook.
K, max_iter = 5, 5000
alpha, lamda = 0.0002, 0.004

def _lfm_cost(R, P, Q, rows, cols, lamda):
    """Regularised squared error summed over the ratings at (rows, cols)."""
    user_vecs = P[rows]
    item_vecs = Q[cols]
    errors = np.sum(user_vecs * item_vecs, axis=1) - R[rows, cols]
    return np.sum(errors ** 2) + lamda * (np.sum(user_vecs ** 2) + np.sum(item_vecs ** 2))


def LFM_grad_desc(R, K=2, max_iter=1000, alpha=0.0001, lamda=0.002, seed=None):
    """Factorise a rating matrix with a latent factor model (stochastic gradient descent).

    Only entries of ``R`` that are greater than 0 are treated as observed
    ratings; every other entry is ignored during training.

    Parameters
    ----------
    R : array-like, shape (M, N)
        Rating matrix (rows are users, columns are items).
    K : int
        Dimension of the latent feature vectors.
    max_iter : int
        Maximum number of passes over the observed ratings.
    alpha : float
        Step size of the gradient update.
    lamda : float
        L2 regularisation coefficient.
    seed : int or None
        When given, the factors are initialised from a private
        ``numpy.random.RandomState(seed)`` so the run is reproducible.
        When ``None`` (default) the global ``numpy.random`` state is used,
        exactly as before.

    Returns
    -------
    P : ndarray, shape (M, K)
        User feature matrix.
    Q : ndarray, shape (N, K)
        Item feature matrix, so that ``P.dot(Q.T)`` approximates ``R``.
    cost : float
        Regularised squared error over the observed ratings for the returned
        factors (for ``max_iter == 0`` this is the cost of the random
        initialisation).

    Raises
    ------
    ValueError
        If ``R`` is not two-dimensional.
    """
    R = np.asarray(R)
    if R.ndim != 2:
        raise ValueError("R must be a 2-D rating matrix, got ndim=%d" % R.ndim)
    M, N = R.shape

    rng = np.random if seed is None else np.random.RandomState(seed)
    P = rng.rand(M, K)
    Q = rng.rand(N, K)

    # Coordinates of the observed ratings, in row-major (user, then item) order.
    rows, cols = np.nonzero(R > 0)

    # Defined up front so the function still returns a cost when max_iter == 0.
    cost = _lfm_cost(R, P, Q, rows, cols, lamda)

    for _ in range(max_iter):
        for u, i in zip(rows, cols):
            err = np.dot(P[u], Q[i]) - R[u, i]
            # Both gradients are evaluated at the *current* P[u] and Q[i]
            # before either vector is modified (simultaneous update).
            grad_p = 2 * err * Q[i] + 2 * lamda * P[u]
            grad_q = 2 * err * P[u] + 2 * lamda * Q[i]
            P[u] -= alpha * grad_p
            Q[i] -= alpha * grad_q

        # Stopping criterion: regularised loss over the observed ratings.
        cost = _lfm_cost(R, P, Q, rows, cols, lamda)
        if cost < 0.0001:
            break

    return P, Q, cost

In [6]:
# Test run: factorise R with the hyperparameters defined above.
P, Q, cost = LFM_grad_desc(R, K=K, max_iter=max_iter, alpha=alpha, lamda=lamda)
print(P, Q, cost, sep="\n")
# Reconstructed rating matrix from the learned factors.
predR = np.dot(P, Q.T)
predR

[[ 1.06775611  0.49029882  0.89776846  0.77621858  0.83832292]
 [ 0.18631968  0.71949575  0.01231543  1.41760591  0.43035172]
 [ 0.01725054 -0.14442411  0.66590499  1.01162477  1.16302205]
 [ 1.35275788  1.28289837  0.4642324   0.4951377  -0.01490516]
 [ 1.21720531  0.44572357  1.08540188  0.13765771  1.09430577]
 [ 0.64252188  1.06699904  0.54505902  0.49672473  1.05101248]]
[[ 2.01765714  1.34498725  0.72143371  0.42247709  0.23052808]
 [ 0.53157674  0.5643566   1.01382178  0.74153258  1.0461196 ]
 [ 0.0669754   0.67582204 -0.03354961  1.62459102  0.41504238]
 [ 1.15344913  0.21279378  1.37810503  1.10782617  1.67402973]
 [ 0.06859985  0.36155169  0.59207311  0.30378059  0.05423066]]
0.554460644140371


array([[3.98268337, 3.2070523 , 1.98172563, 4.83644685, 1.06332379],
       [2.05064078, 2.01898136, 2.97996076, 2.67586894, 0.73418746],
       [1.01646093, 2.56958516, 2.00738946, 3.97448996, 0.71361447],
       [4.99554472, 2.26532385, 1.74024765, 2.996666  , 0.9810975 ],
       [4.14886644, 3.24584153, 1.02415857, 4.97903059, 0.98845215],
       [3.57685562, 2.96413266, 1.98903736, 4.02902641, 0.96045944]])