Skip to content

Commit 4139e0a

Browse files
committed
Implement Conjugate Gradient ALS
Implement the algorithm described in the paper "Applications of the Conjugate Gradient Method for Implicit Feedback Collaborative Filtering". More details are in the blog post at http://www.benfrederickson.com/fast-implicit-matrix-factorization/ . This leads to a 3x to 19x speed increase in training, depending on the number of factors in the model, with identical results.
1 parent da1a7fa commit 4139e0a

File tree

7 files changed

+304
-32
lines changed

7 files changed

+304
-32
lines changed

README.md

+11-4
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ Implicit
66

77
Fast Python Collaborative Filtering for Implicit Datasets.
88

9-
This project provides a fast Python implementation of the algorithm described in the paper [Collaborative Filtering for Implicit Feedback Datasets](
10-
http://yifanhu.net/PUB/cf.pdf).
9+
This project provides fast Python implementations of the algorithms described in the paper [Collaborative Filtering for Implicit Feedback Datasets](
10+
http://yifanhu.net/PUB/cf.pdf) and in [Applications of the Conjugate Gradient Method for Implicit
11+
Feedback Collaborative
12+
Filtering](https://pdfs.semanticscholar.org/bfdf/7af6cf7fd7bb5e6b6db5bbd91be11597eaf0.pdf).
1113

1214

1315
To install:
@@ -29,7 +31,7 @@ last.fm dataset](https://github.com/benfred/implicit/blob/master/examples/lastfm
2931
#### Requirements
3032

3133
This library requires SciPy version 0.16 or later. Running on OSX requires an OpenMP compiler,
32-
which can be installed with homebrew: ```brew install gcc```.
34+
which can be installed with homebrew: ```brew install gcc```.
3335

3436
#### Why Use This?
3537

@@ -44,7 +46,12 @@ libraries distributed with SciPy. This leads to extremely fast matrix factorizat
4446

4547
On a simple [benchmark](https://github.com/benfred/implicit/blob/master/examples/benchmark.py), this
4648
library is about 1.8 times faster than the multithreaded C++ implementation provided by Quora's
47-
[QMF Library](https://github.com/quora/qmf) and at least 60,000 times faster than [implicit-mf](https://github.com/MrChrisJohnson/implicit-mf).
49+
[QMF Library](https://github.com/quora/qmf) and at least 60,000 times faster than
50+
[implicit-mf](https://github.com/MrChrisJohnson/implicit-mf).
51+
52+
A [follow up post](http://www.benfrederickson.com/fast-implicit-matrix-factorization/) describes
53+
further performance improvements based on the Conjugate Gradient method, which speed up training
54+
by 3x to over 19x depending on the number of factors used.
4855

4956
This library has been tested with Python 2.7 and 3.5. Running 'tox' will
5057
run unittests on both versions, and verify that all python files pass flake8.

examples/benchmark_cg.py

+146
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
""" test script to verify the CG method works, and time it versus cholesky """
2+
3+
from __future__ import print_function
4+
5+
import argparse
6+
import functools
7+
import json
8+
import logging
9+
import time
10+
from collections import defaultdict
11+
12+
import numpy
13+
from implicit._implicit import calculate_loss, least_squares, least_squares_cg
14+
15+
from lastfm import bm25_weight, read_data
16+
17+
18+
def benchmark_solver(Cui, factors, solver, callback, iterations=7, dtype=numpy.float64,
                     regularization=0.00, num_threads=0):
    """Run alternating solver passes on ``Cui`` and report each iteration.

    This unrolls the alternating-least-squares driver so that an arbitrary
    ``solver`` can be plugged in and timed.

    Args:
        Cui: sparse matrix exposing ``shape``, ``tocsr()`` and ``T``; rows are
            treated as users and columns as items.
        factors: number of latent factors (columns of the factor matrices).
        solver: callable invoked as
            ``solver(matrix, A, B, regularization, num_threads=...)``.
        callback: callable invoked once per iteration as
            ``callback(elapsed_seconds, X, Y)``.
        iterations: number of alternating iterations to run.
        dtype: numpy dtype of the factor matrices.
        regularization: value forwarded unchanged to ``solver``.
        num_threads: value forwarded unchanged to ``solver``.

    Returns:
        The ``(X, Y)`` tuple of user and item factor arrays.
    """
    n_users, n_items = Cui.shape

    # small random starting factors, drawn for users first and then items
    user_factors = numpy.random.rand(n_users, factors).astype(dtype) * 0.01
    item_factors = numpy.random.rand(n_items, factors).astype(dtype) * 0.01

    # one CSR matrix per direction: user->items and item->users
    user_items = Cui.tocsr()
    item_users = Cui.T.tocsr()

    for iteration in range(iterations):
        started = time.time()
        solver(user_items, user_factors, item_factors, regularization,
               num_threads=num_threads)
        solver(item_users, item_factors, user_factors, regularization,
               num_threads=num_threads)
        callback(time.time() - started, user_factors, item_factors)
        logging.debug("finished iteration %i in %s", iteration, time.time() - started)

    return user_factors, item_factors
36+
37+
38+
def benchmark_accuracy(plays):
    """Record the training loss per iteration for Cholesky and CG solvers.

    Each solver is run for 25 iterations with 100 factors; after every
    iteration ``calculate_loss(plays, X, Y, 0)`` is appended to that solver's
    list.

    Returns:
        A ``defaultdict(list)`` keyed by ``'cholesky'``, ``'cg2'``, ``'cg3'``
        and ``'cg4'``.
    """
    losses = defaultdict(list)

    def loss_recorder(label):
        # build a per-iteration callback that stores the loss under `label`
        def record(_, X, Y):
            losses[label].append(calculate_loss(plays, X, Y, 0))
        return record

    solvers = [('cholesky', least_squares)]
    for steps in [2, 3, 4]:
        solvers.append(('cg%i' % steps,
                        functools.partial(least_squares_cg, cg_steps=steps)))

    for label, solver in solvers:
        benchmark_solver(plays, 100, solver, loss_recorder(label), iterations=25)

    return losses
53+
54+
55+
def benchmark_times(plays):
    """Time one training iteration for CG and Cholesky at several factor counts.

    For each factor count the solver is run for 3 iterations and the fastest
    iteration is kept (and printed).

    Returns:
        A ``defaultdict(list)`` with parallel lists under ``'factors'``,
        ``'cg2'``, ``'cg3'``, ``'cg4'`` and ``'cholesky'``.
    """
    timings = defaultdict(list)

    def fastest_iteration(solver, factors):
        # run the solver and keep the quickest of its per-iteration timings
        elapsed_times = []
        benchmark_solver(plays, factors, solver,
                         lambda elapsed, X, Y: elapsed_times.append(elapsed),
                         iterations=3)
        return min(elapsed_times)

    for factors in [50, 100, 150, 200, 250]:
        timings['factors'].append(factors)

        for steps in [2, 3, 4]:
            cg_solver = functools.partial(least_squares_cg, cg_steps=steps)
            best = fastest_iteration(cg_solver, factors)
            print("cg%i: %i factors : %ss" % (steps, factors, best))
            timings['cg%i' % steps].append(best)

        best = fastest_iteration(least_squares, factors)
        timings['cholesky'].append(best)
        print("cholesky: %i factors : %ss" % (factors, best))

    return timings
76+
77+
78+
def generate_speed_graph(data, filename="cg_training_speed.html"):
    """Plot time per iteration against factor count and save it as HTML.

    Args:
        data: mapping with a ``'factors'`` list and matching timing lists under
            ``'cg2'``, ``'cg3'`` and ``'cholesky'``.
        filename: path of the HTML file to write.
    """
    # imported lazily so bokeh is only needed when a graph is requested
    from bokeh.plotting import figure, save

    # the 'cg4' series is left out of the plot
    to_plot = [(data['cg2'], "CG (2 Steps/Iteration)", "#2ca02c"),
               (data['cg3'], "CG (3 Steps/Iteration)", "#ff7f0e"),
               (data['cholesky'], "Cholesky", "#1f77b4")]

    # a single figure is built; no throwaway figure is created beforehand
    p = figure(title="Training Speed", x_axis_label='Factors',
               y_axis_label='Time / Iteration (s)')
    for current, label, colour in to_plot:
        p.line(data['factors'], current, legend=label, line_color=colour, line_width=1)
        p.circle(data['factors'], current, legend=label, line_color=colour, size=6,
                 fill_color="white")
    p.legend.location = "top_left"
    save(p, filename, title="CG ALS Training Speed")
94+
95+
96+
def generate_loss_graph(data, filename):
    """Plot training loss against iteration number and save it as HTML.

    Args:
        data: mapping with per-iteration loss lists under ``'cholesky'``,
            ``'cg2'`` and ``'cg3'``.
        filename: path of the HTML file to write.
    """
    # imported lazily so bokeh is only needed when a graph is requested
    from bokeh.plotting import figure, save

    # x axis is 1-based and sized from the cholesky series
    iterations = range(1, len(data['cholesky']) + 1)

    # the 'cg4' series is left out of the plot
    series_spec = (('cg2', "CG (2 Steps/Iteration)", "#2ca02c"),
                   ('cg3', "CG (3 Steps/Iteration)", "#ff7f0e"),
                   ('cholesky', "Cholesky", "#1f77b4"))
    series = [(data[key], label, colour) for key, label, colour in series_spec]

    plot = figure(title="Training Loss", x_axis_label='Iteration', y_axis_label='MSE')
    for loss, label, colour in series:
        plot.line(iterations, loss, legend=label, line_color=colour, line_width=1)
        plot.circle(iterations, loss, legend=label, line_color=colour, size=6,
                    fill_color="white")

    save(plot, filename, title="CG ALS Training Loss")
111+
112+
113+
if __name__ == "__main__":
    # Command-line entry point: load the last.fm data, then run the loss and/or
    # speed benchmarks, writing JSON results (and optionally bokeh graphs).
    parser = argparse.ArgumentParser(description="Benchmark CG version against Cholesky",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--input', type=str,
                        dest='inputfile', help='last.fm dataset file', required=True)
    parser.add_argument('--graph', help='generates graphs (requires bokeh)',
                        action="store_true")
    parser.add_argument('--loss', help='test training loss',
                        action="store_true")
    parser.add_argument('--speed', help='test training speed',
                        action="store_true")

    args = parser.parse_args()
    if not (args.speed or args.loss):
        print("must specify at least one of --speed or --loss")
        parser.print_help()

    else:

        plays = bm25_weight(read_data(args.inputfile)[1]).tocsr()
        logging.basicConfig(level=logging.DEBUG)

        if args.loss:
            acc = benchmark_accuracy(plays)
            # use a context manager so the results file is flushed and closed
            with open("cg_accuracy.json", "w") as accuracy_file:
                json.dump(acc, accuracy_file)
            if args.graph:
                generate_loss_graph(acc, "cg_accuracy.html")

        if args.speed:
            speed = benchmark_times(plays)
            # use a context manager so the results file is flushed and closed
            with open("cg_speed.json", "w") as speed_file:
                json.dump(speed, speed_file)
            if args.graph:
                generate_speed_graph(speed, "cg_speed.html")

examples/lastfm.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,8 @@ def calculate_similar_artists(input_filename, output_filename,
9292
iterations=15,
9393
exact=False, trees=20,
9494
use_native=True,
95-
dtype=numpy.float64):
95+
dtype=numpy.float64,
96+
cg=False):
9697
logging.debug("Calculating similar artists. This might take a while")
9798
logging.debug("reading data from %s", input_filename)
9899
start = time.time()
@@ -109,7 +110,8 @@ def calculate_similar_artists(input_filename, output_filename,
109110
regularization=regularization,
110111
iterations=iterations,
111112
use_native=use_native,
112-
dtype=dtype)
113+
dtype=dtype,
114+
use_cg=cg)
113115
logging.debug("calculated factors in %s", time.time() - start)
114116

115117
# write out artists by popularity
@@ -154,6 +156,9 @@ def calculate_similar_artists(input_filename, output_filename,
154156
parser.add_argument('--float32',
155157
help='use 32 bit floating point numbers',
156158
action="store_true")
159+
parser.add_argument('--cg',
160+
help='use CG optimizer',
161+
action="store_true")
157162
args = parser.parse_args()
158163

159164
logging.basicConfig(level=logging.DEBUG)
@@ -163,5 +168,6 @@ def calculate_similar_artists(input_filename, output_filename,
163168
exact=args.exact, trees=args.treecount,
164169
iterations=args.iterations,
165170
use_native=not args.purepython,
166-
dtype=numpy.float32 if args.float32 else numpy.float64)
171+
dtype=numpy.float32 if args.float32 else numpy.float64,
172+
cg=args.cg)
167173

implicit/_implicit.pyx

+85-1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ cdef inline floating dot(int *n, floating *sx, int *incx, floating *sy, int *inc
2828
else:
2929
return cython_blas.sdot(n, sx, incx, sy, incy)
3030

31+
cdef inline void scal(int *n, floating *sa, floating *sx, int *incx) nogil:
# Fused-type wrapper around the BLAS "scal" routine: forwards the arguments
# unchanged to dscal when `floating` is double and to sscal otherwise.
# BLAS scal scales the n-element vector sx (stride *incx) in place by *sa.
32+
if floating is double:
33+
cython_blas.dscal(n, sa, sx, incx)
34+
else:
# keep this as an if/else: each fused-type specialization should only
# contain the BLAS call whose pointer types match
35+
cython_blas.sscal(n, sa, sx, incx)
36+
3137
cdef inline void posv(char * u, int * n, int * nrhs, floating * a, int * lda, floating * b, int * ldb, int * info) nogil:
3238
if floating is double:
3339
cython_lapack.dposv(u, n, nrhs, a, lda, b, ldb, info)
@@ -42,7 +48,7 @@ cdef inline void gesv(int * n, int * nrhs, floating * a, int * lda, int * piv, f
4248

4349

4450
@cython.boundscheck(False)
45-
def least_squares(Cui, floating [:, :] X, floating [:, :] Y, double regularization, int num_threads):
51+
def least_squares(Cui, floating [:, :] X, floating [:, :] Y, double regularization, int num_threads=0):
4652
dtype = numpy.float64 if floating is double else numpy.float32
4753

4854
cdef int [:] indptr = Cui.indptr, indices = Cui.indices
@@ -107,6 +113,84 @@ def least_squares(Cui, floating [:, :] X, floating [:, :] Y, double regularizati
107113
free(pivot)
108114

109115

116+
@cython.cdivision(True)
@cython.boundscheck(False)
def least_squares_cg(Cui, floating [:, :] X, floating [:, :] Y, double regularization, int num_threads=0, int cg_steps=3):
    """ Updates every row of X in place with a few conjugate gradient steps.

    For each row u of Cui, this runs `cg_steps` conjugate gradient iterations
    on the system (YtCuY + regularization * I) x = YtCuPu, starting from the
    current value of X[u], without ever forming YtCuY explicitly.

    Parameters:
        Cui: sparse matrix exposing CSR-style `indptr`, `indices` and `data`
            attributes; `data` is read through a double memoryview.
        X: factors to update, one row per row of Cui (modified in place).
        Y: factors of the other side (only read).
        regularization: value added to the diagonal of YtY.
        num_threads: passed to cython.parallel.parallel.
        cg_steps: maximum number of conjugate gradient iterations per row.

    Rows whose residual is already (numerically) zero are left untouched, and
    the iteration stops early once the residual vanishes, so that no 0/0
    division can write NaN values into X.
    """
    dtype = numpy.float64 if floating is double else numpy.float32
    cdef int [:] indptr = Cui.indptr, indices = Cui.indices
    cdef double [:] data = Cui.data

    cdef int users = X.shape[0], N = X.shape[1], u, i, index, one = 1, it
    cdef floating confidence, temp, alpha, rsnew, rsold
    cdef floating zero = 0.

    # the regularization term is folded into YtY once, up front
    cdef floating[:, :] YtY = numpy.dot(numpy.transpose(Y), Y) + regularization * numpy.eye(N, dtype=dtype)

    cdef floating * x
    cdef floating * p
    cdef floating * r
    cdef floating * Ap

    with nogil, parallel(num_threads = num_threads):

        # allocate temp memory for each thread
        Ap = <floating *> malloc(sizeof(floating) * N)
        p = <floating *> malloc(sizeof(floating) * N)
        r = <floating *> malloc(sizeof(floating) * N)
        try:
            for u in prange(users, schedule='guided'):
                # start from previous iteration
                x = &X[u, 0]

                # calculate residual r = YtCuPu - YtCuY.dot(Xu)
                # first the dense part: r = -YtY.dot(x)
                temp = -1.0
                symv("U", &N, &temp, &YtY[0, 0], &N, x, &one, &zero, r, &one)

                # then the sparse part: r += (c - (c - 1) * y_i.dot(x)) * y_i
                for index in range(indptr[u], indptr[u + 1]):
                    i = indices[index]
                    confidence = data[index]
                    temp = confidence - (confidence - 1) * dot(&N, &Y[i, 0], &one, x, &one)
                    axpy(&N, &temp, &Y[i, 0], &one, r, &one)

                memcpy(p, r, sizeof(floating) * N)
                rsold = dot(&N, r, &one, r, &one)

                # x already solves the system (e.g. an empty row with zero
                # factors): skip it, since alpha below would be 0 / 0 = NaN
                if rsold < 1e-20:
                    continue

                for it in range(cg_steps):
                    # calculate Ap = YtCuYp - without actually calculating YtCuY
                    temp = 1.0
                    symv("U", &N, &temp, &YtY[0, 0], &N, p, &one, &zero, Ap, &one)

                    for index in range(indptr[u], indptr[u + 1]):
                        i = indices[index]
                        confidence = data[index]
                        temp = (confidence - 1) * dot(&N, &Y[i, 0], &one, p, &one)
                        axpy(&N, &temp, &Y[i, 0], &one, Ap, &one)

                    # alpha = rsold / p.dot(Ap);
                    alpha = rsold / dot(&N, p, &one, Ap, &one)

                    # x += alpha * p
                    axpy(&N, &alpha, p, &one, x, &one)

                    # r -= alpha * Ap
                    temp = alpha * -1
                    axpy(&N, &temp, Ap, &one, r, &one)

                    rsnew = dot(&N, r, &one, r, &one)

                    # converged: stop before the next step divides by a zero
                    # residual norm
                    if rsnew < 1e-20:
                        break

                    # p = r + (rsnew/rsold) * p
                    temp = rsnew / rsold
                    scal(&N, &temp, p, &one)
                    temp = 1.0
                    axpy(&N, &temp, r, &one, p, &one)

                    rsold = rsnew
        finally:
            free(p)
            free(r)
            free(Ap)
192+
193+
110194
@cython.cdivision(True)
111195
@cython.boundscheck(False)
112196
def calculate_loss(Cui, floating [:, :] X, floating [:, :] Y, float regularization, int num_threads=0):

0 commit comments

Comments
 (0)