-
Notifications
You must be signed in to change notification settings - Fork 0
/
als.py
83 lines (60 loc) · 2.17 KB
/
als.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import sys
import numpy as np
from numpy.random import rand
from numpy import matrix
from pyspark.sql import SparkSession
# Regularization strength. update() multiplies it by the row count of the
# fixed factor matrix and adds the result to the normal-equation diagonal.
LAMBDA = 0.01
# Seed NumPy's global RNG so the rand() draws made by the script are repeatable.
np.random.seed(42)
def rmse(R, ms, us):
    """Return the root-mean-square error of approximating ``R`` by ``ms * us.T``.

    Args:
        R: Matrix being approximated.
        ms: Factor matrix with one row per row of ``R``.
        us: Factor matrix with one row per column of ``R``.

    Returns:
        The square root of the mean squared entry of ``R - ms * us.T``.

    Note:
        ``*`` is relied on as the matrix product, so the factors are expected
        to be ``numpy.matrix`` instances, as the script in this file builds.
    """
    diff = R - ms * us.T
    # Normalize by the number of entries actually compared rather than by the
    # globals M and U, which only exist when this file is run as a script.
    return np.sqrt(np.sum(np.power(diff, 2)) / diff.size)
def update(i, mat, ratings):
    """Solve the regularized least-squares system for row ``i`` of ``ratings``.

    With ``mat`` held fixed, this solves
    ``(mat.T * mat + LAMBDA * n_rows * I) x = mat.T * ratings[i, :].T``.

    Args:
        i: Index of the row of ``ratings`` to fit.
        mat: Fixed factor matrix; ``*`` is used as the matrix product, so a
            ``numpy.matrix`` is expected.
        ratings: Matrix whose ``i``-th row holds the target values, one per
            row of ``mat``.

    Returns:
        The solution of the linear system as returned by ``np.linalg.solve``.
    """
    n_rows, n_factors = mat.shape
    # The ridge term is the same for every diagonal entry, so compute it once.
    ridge = LAMBDA * n_rows

    gram = mat.T * mat
    rhs = mat.T * ratings[i, :].T

    # Add the ridge term to the diagonal of the Gram matrix in place.
    np.fill_diagonal(gram, np.asarray(gram).diagonal() + ridge)

    return np.linalg.solve(gram, rhs)
if __name__ == "__main__":
    # Usage: als [M] [U] [F] [iterations] [partitions]
    print(
        "WARN: This is a naive implementation of ALS and is given as an\n"
        "example. Please use pyspark.ml.recommendation.ALS for more\n"
        "conventional use.",
        file=sys.stderr,
    )

    spark = SparkSession \
        .builder \
        .appName("PythonALS") \
        .getOrCreate()

    # try/finally so the session is stopped even if argument parsing or one of
    # the Spark jobs raises.
    try:
        sc = spark.sparkContext

        # Every positional argument is optional and falls back to a default.
        M = int(sys.argv[1]) if len(sys.argv) > 1 else 100
        U = int(sys.argv[2]) if len(sys.argv) > 2 else 500
        F = int(sys.argv[3]) if len(sys.argv) > 3 else 10
        ITERATIONS = int(sys.argv[4]) if len(sys.argv) > 4 else 5
        partitions = int(sys.argv[5]) if len(sys.argv) > 5 else 2

        print("Running ALS with M=%d, U=%d, F=%d, iters=%d, partitions=%d\n" %
              (M, U, F, ITERATIONS, partitions))

        # R is an M x U matrix built as the product of two random factors with
        # F columns each.
        R = matrix(rand(M, F)) * matrix(rand(U, F).T)
        # This initial ms is overwritten by the first half-step before it is
        # read. The draw is kept so the seeded random stream, and therefore the
        # initial us, stays the same.
        ms = matrix(rand(M, F))
        us = matrix(rand(U, F))

        Rb = sc.broadcast(R)
        # ms is not broadcast here: the first job only reads usb, and msb is
        # created from the freshly solved ms before anything reads it.
        usb = sc.broadcast(us)

        for i in range(ITERATIONS):
            # Half-step 1: with us fixed, solve for each of the M rows of ms.
            ms = sc.parallelize(range(M), partitions) \
                   .map(lambda x: update(x, usb.value, Rb.value)) \
                   .collect()
            # collect() returns a list of F x 1 solutions, so np.array(...) is
            # 3-d; dropping the trailing axis gives the M x F matrix.
            ms = matrix(np.array(ms)[:, :, 0])
            msb = sc.broadcast(ms)

            # Half-step 2: with ms fixed, solve for each of the U rows of us.
            us = sc.parallelize(range(U), partitions) \
                   .map(lambda x: update(x, msb.value, Rb.value.T)) \
                   .collect()
            us = matrix(np.array(us)[:, :, 0])

            # Both jobs that read these broadcasts have finished and each is
            # replaced before its next use, so drop the executor-side copies
            # instead of letting them pile up across iterations.
            msb.unpersist()
            usb.unpersist()
            usb = sc.broadcast(us)

            error = rmse(R, ms, us)
            print("Iteration %d:" % i)
            print("\nRMSE: %5.4f\n" % error)
    finally:
        spark.stop()