|
| 1 | +""" test script to verify the CG method works, and time it versus cholesky """ |
| 2 | + |
| 3 | +from __future__ import print_function |
| 4 | + |
| 5 | +import argparse |
| 6 | +import functools |
| 7 | +import json |
| 8 | +import logging |
| 9 | +import time |
| 10 | +from collections import defaultdict |
| 11 | + |
| 12 | +import numpy |
| 13 | +from implicit._implicit import calculate_loss, least_squares, least_squares_cg |
| 14 | + |
| 15 | +from lastfm import bm25_weight, read_data |
| 16 | + |
| 17 | + |
def benchmark_solver(Cui, factors, solver, callback, iterations=7, dtype=numpy.float64,
                     regularization=0.00, num_threads=0):
    """Run alternating solver sweeps over ``Cui`` and report each iteration.

    Args:
        Cui: sparse matrix exposing ``shape``, ``tocsr()`` and ``T``; its
            shape is read as (users, items).
        factors: number of latent factors (columns of ``X`` and ``Y``).
        solver: callable invoked as
            ``solver(matrix, A, B, regularization, num_threads=...)``.
            Its return value is ignored, so it presumably updates ``A`` in
            place -- confirm against the solver implementation.
        callback: callable invoked once per iteration as
            ``callback(elapsed_seconds, X, Y)``.
        iterations: number of full sweeps (one user pass plus one item pass).
        dtype: numpy dtype of the factor matrices.
        regularization: forwarded unchanged to ``solver``.
        num_threads: forwarded unchanged to ``solver``.

    Returns:
        The ``(X, Y)`` factor matrices after the last iteration.
    """
    users, items = Cui.shape

    # have to explode out most of the alternating_least_squares call here
    X = numpy.random.rand(users, factors).astype(dtype) * 0.01
    Y = numpy.random.rand(items, factors).astype(dtype) * 0.01

    Cui, Ciu = Cui.tocsr(), Cui.T.tocsr()

    for iteration in range(iterations):
        start = time.time()
        solver(Cui, X, Y, regularization, num_threads=num_threads)
        solver(Ciu, Y, X, regularization, num_threads=num_threads)

        # Measure once, before the callback runs: the callback may be
        # expensive, and its cost must not leak into the logged solver time.
        elapsed = time.time() - start
        callback(elapsed, X, Y)
        logging.debug("finished iteration %i in %s", iteration, elapsed)

    return X, Y
| 36 | + |
| 37 | + |
def benchmark_accuracy(plays):
    """Record the training loss after every iteration for each solver.

    Runs 25 iterations with 100 factors using the Cholesky solver
    (``least_squares``) and then the conjugate gradient solver
    (``least_squares_cg``) with 2, 3 and 4 CG steps.

    Returns:
        A ``defaultdict(list)`` keyed by ``'cholesky'``, ``'cg2'``, ``'cg3'``
        and ``'cg4'``; each value holds one ``calculate_loss`` result per
        iteration.
    """
    losses = defaultdict(list)

    def loss_recorder(label):
        # Build a per-iteration callback that files the loss under `label`.
        def record(_, X, Y):
            losses[label].append(calculate_loss(plays, X, Y, 0))
        return record

    benchmark_solver(plays, 100, least_squares, loss_recorder('cholesky'),
                     iterations=25)

    for steps in (2, 3, 4):
        cg_solver = functools.partial(least_squares_cg, cg_steps=steps)
        benchmark_solver(plays, 100, cg_solver, loss_recorder('cg%i' % steps),
                         iterations=25)

    return losses
| 53 | + |
| 54 | + |
def benchmark_times(plays):
    """Time one training iteration of each solver across several factor counts.

    For 50, 100, 150, 200 and 250 factors, each solver (CG with 2, 3 and 4
    steps, then Cholesky) is run for 3 iterations and the fastest iteration is
    kept. Each result is also printed as it is measured.

    Returns:
        A ``defaultdict(list)`` with the factor counts under ``'factors'`` and
        the matching best times under ``'cg2'``, ``'cg3'``, ``'cg4'`` and
        ``'cholesky'``.
    """
    results = defaultdict(list)

    def fastest_iteration(factors, solver):
        # Best-of-three timing for a single solver configuration.
        timings = []
        benchmark_solver(plays, factors, solver,
                         lambda elapsed, X, Y: timings.append(elapsed),
                         iterations=3)
        return min(timings)

    for factors in (50, 100, 150, 200, 250):
        results['factors'].append(factors)

        for steps in (2, 3, 4):
            cg_solver = functools.partial(least_squares_cg, cg_steps=steps)
            best = fastest_iteration(factors, cg_solver)
            print("cg%i: %i factors : %ss" % (steps, factors, best))
            results['cg%i' % steps].append(best)

        best = fastest_iteration(factors, least_squares)
        results['cholesky'].append(best)
        print("cholesky: %i factors : %ss" % (factors, best))

    return results
| 76 | + |
| 77 | + |
def generate_speed_graph(data, filename="cg_training_speed.html"):
    """Plot time per iteration against factor count and save it with bokeh.

    Args:
        data: mapping with a ``'factors'`` list (x values) and ``'cg2'``,
            ``'cg3'`` and ``'cholesky'`` lists (y values), as produced by
            ``benchmark_times``.
        filename: path handed to bokeh's ``save`` for the output file.
    """
    # bokeh is only needed when graphs are requested, so import it lazily
    from bokeh.plotting import figure, save

    to_plot = [(data['cg2'], "CG (2 Steps/Iteration)", "#2ca02c"),
               (data['cg3'], "CG (3 Steps/Iteration)", "#ff7f0e"),
               # (data['cg4'], "CG (4 Steps/Iteration)", "#d62728"),
               (data['cholesky'], "Cholesky", "#1f77b4")]

    # A single figure: an earlier "Training Time" figure used to be created
    # here and was overwritten without ever being drawn on.
    p = figure(title="Training Speed", x_axis_label='Factors', y_axis_label='Time / Iteration (s)')

    # NOTE(review): newer bokeh releases appear to replace the `legend=`
    # keyword with `legend_label=` -- confirm against the installed version.
    for current, label, colour in to_plot:
        p.line(data['factors'], current, legend=label, line_color=colour, line_width=1)
        p.circle(data['factors'], current, legend=label, line_color=colour, size=6,
                 fill_color="white")
    p.legend.location = "top_left"
    save(p, filename, title="CG ALS Training Speed")
| 94 | + |
| 95 | + |
def generate_loss_graph(data, filename):
    """Plot the per-iteration training loss of each solver and save it with bokeh.

    Args:
        data: mapping with ``'cg2'``, ``'cg3'`` and ``'cholesky'`` lists of
            loss values, one per iteration (see ``benchmark_accuracy``).
        filename: path handed to bokeh's ``save`` for the output file.
    """
    # bokeh is only needed when graphs are requested, so import it lazily
    from bokeh.plotting import figure, save

    # The x axis is the 1-based iteration number, sized from the cholesky run.
    xs = range(1, len(data['cholesky']) + 1)

    series = [
        (data['cg2'], "CG (2 Steps/Iteration)", "#2ca02c"),
        (data['cg3'], "CG (3 Steps/Iteration)", "#ff7f0e"),
        # (data['cg4'], "CG (4 Steps/Iteration)", "#d62728"),
        (data['cholesky'], "Cholesky", "#1f77b4"),
    ]

    plot = figure(title="Training Loss", x_axis_label='Iteration', y_axis_label='MSE')
    for values, name, colour in series:
        # Draw the connecting line, then a hollow marker on every point.
        plot.line(xs, values, legend=name, line_color=colour, line_width=1)
        plot.circle(xs, values, legend=name, line_color=colour, size=6,
                    fill_color="white")

    save(plot, filename, title="CG ALS Training Loss")
| 111 | + |
| 112 | + |
if __name__ == "__main__":
    # Command-line entry point: run the loss and/or speed benchmarks on a
    # last.fm dataset, dump the results as JSON and optionally draw graphs.
    parser = argparse.ArgumentParser(description="Benchmark CG version against Cholesky",
                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)

    parser.add_argument('--input', type=str,
                        dest='inputfile', help='last.fm dataset file', required=True)
    parser.add_argument('--graph', help='generates graphs (requires bokeh)',
                        action="store_true")
    parser.add_argument('--loss', help='test training loss',
                        action="store_true")
    parser.add_argument('--speed', help='test training speed',
                        action="store_true")

    args = parser.parse_args()
    if not (args.speed or args.loss):
        print("must specify at least one of --speed or --loss")
        parser.print_help()

    else:
        # read_data's second return value is the matrix that gets weighted
        plays = bm25_weight(read_data(args.inputfile)[1]).tocsr()
        logging.basicConfig(level=logging.DEBUG)

        if args.loss:
            acc = benchmark_accuracy(plays)
            # `with` closes the file, so the JSON is flushed to disk
            with open("cg_accuracy.json", "w") as output_file:
                json.dump(acc, output_file)
            if args.graph:
                generate_loss_graph(acc, "cg_accuracy.html")

        if args.speed:
            speed = benchmark_times(plays)
            with open("cg_speed.json", "w") as output_file:
                json.dump(speed, output_file)
            if args.graph:
                generate_speed_graph(speed, "cg_speed.html")
0 commit comments