Use new predict method consistently
mdekstrand committed Nov 29, 2018
1 parent ffdbe83 · commit 96a6acb
Showing 7 changed files with 14 additions and 38 deletions.
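The change is mechanical but consistent: every call site now passes the trained model as the second positional argument, `batch.predict(algo, model, pairs)`, replacing both the old `model=` keyword form and the deprecated trick of passing a bare `lambda u, xs: ...` prediction function. A minimal sketch of the new convention (the `basic.Bias` algorithm and the toy ratings frame here are illustrative assumptions, not part of this commit; any algorithm implementing `Predictor` is called the same way):

    import pandas as pd
    from lenskit import batch
    from lenskit.algorithms import basic

    # toy ratings frame with the user/item/rating columns LensKit expects
    ratings = pd.DataFrame({
        'user': [1, 1, 2, 2],
        'item': [10, 20, 10, 30],
        'rating': [4.0, 3.0, 5.0, 2.0],
    })

    algo = basic.Bias()
    model = algo.train(ratings)

    # new convention: algorithm first, trained model second, then the pairs
    pairs = pd.DataFrame({'user': [1, 2], 'item': [30, 20]})
    preds = batch.predict(algo, model, pairs)  # columns: user, item, prediction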
lenskit/batch/_predict.py (1 addition, 11 deletions)
@@ -68,12 +68,6 @@ def predict(algo, model, pairs, nprocs=None):
     result will also contain a `rating` column.
     """

-    pfun = None
-    if not isinstance(algo, Predictor):
-        warnings.warn('non-Perdictor argument to predict deprecated', DeprecationWarning)
-        nprocs = None
-        pfun = algo
-
     if nprocs and nprocs > 1 and mp.get_start_method() == 'fork':
         _logger.info('starting predict process with %d workers', nprocs)
         with MPRecContext(algo, model), Pool(nprocs) as pool:
@@ -83,11 +77,7 @@
     else:
         results = []
         for user, udf in pairs.groupby('user'):
-            if pfun:
-                res = pfun(user, udf['item'])
-                res = pd.DataFrame({'user': user, 'item': res.index, 'prediction': res.values})
-            else:
-                res = _predict_user(algo, model, user, udf)
+            res = _predict_user(algo, model, user, udf)
             results.append(res)

     results = pd.concat(results)
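The deleted `pfun` branch also documents the contract the remaining helper fulfills: score one user's items through the `Predictor` interface and shape the result into the frame that `batch.predict` concatenates. A plausible reconstruction of `_predict_user` (hypothetical; the real helper is defined earlier in this file and, per the docstring above, also carries an input `rating` column through):

    import pandas as pd

    def _predict_user(algo, model, user, udf):
        # hypothetical sketch: score this user's items via the Predictor API,
        # the same algo.predict(model, user, items) call the old lambdas wrapped
        res = algo.predict(model, user, udf['item'])
        # shape the Series (index = item, values = scores) into the output frame,
        # mirroring what the removed pfun branch did for bare prediction functions
        return pd.DataFrame({'user': user, 'item': res.index, 'prediction': res.values})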
tests/test_als_explicit.py (1 addition, 1 deletion)
@@ -123,7 +123,7 @@ def eval(train, test):
         _log.info('running training')
         model = algo.train(train)
         _log.info('testing %d users', test.user.nunique())
-        return batch.predict(algo, test, model=model)
+        return batch.predict(algo, model, test)

     folds = xf.partition_users(ratings, 5, xf.SampleFrac(0.2))
     preds = pd.concat(eval(train, test) for (train, test) in folds)
tests/test_batch_predict.py (5 additions, 19 deletions)
@@ -14,7 +14,6 @@
 _log = logging.getLogger(__name__)

 MLB = namedtuple('MLB', ['ratings', 'algo', 'model'])
-MLB.predictor = property(lambda mlb: partial(mlb.algo.predict, mlb.model))


 @pytest.fixture
@@ -27,20 +26,7 @@ def mlb():

 def test_predict_single(mlb):
     tf = pd.DataFrame({'user': [1], 'item': [31]})
-    res = lkb.predict(mlb.predictor, tf)
-
-    assert len(res) == 1
-    assert all(res.user == 1)
-    assert set(res.columns) == set(['user', 'item', 'prediction'])
-    assert all(res.item == 31)
-
-    expected = mlb.model.mean + mlb.model.items.loc[31] + mlb.model.users.loc[1]
-    assert res.prediction.iloc[0] == pytest.approx(expected)
-
-
-def test_predict_single_model(mlb):
-    tf = pd.DataFrame({'user': [1], 'item': [31]})
-    res = lkb.predict(mlb.algo, tf, mlb.model)
+    res = lkb.predict(mlb.algo, mlb.model, tf)

     assert len(res) == 1
     assert all(res.user == 1)
@@ -61,7 +47,7 @@ def test_predict_user(mlb):
     test_items = pd.concat([test_rated, pd.Series(test_unrated)])

     tf = pd.DataFrame({'user': uid, 'item': test_items})
-    res = lkb.predict(mlb.predictor, tf)
+    res = lkb.predict(mlb.algo, mlb.model, tf)

     assert len(res) == 15
     assert set(res.columns) == set(['user', 'item', 'prediction'])
@@ -83,7 +69,7 @@ def test_predict_two_users(mlb):
     while tf is None or len(set(tf.user)) < 2:
         tf = mlb.ratings[mlb.ratings.user.isin(uids)].loc[:, ('user', 'item')].sample(10)

-    res = lkb.predict(mlb.predictor, tf)
+    res = lkb.predict(mlb.algo, mlb.model, tf)

     assert len(res) == 10
     assert set(res.user) == set(uids)
@@ -102,7 +88,7 @@ def test_predict_include_rating(mlb):
     while tf is None or len(set(tf.user)) < 2:
         tf = mlb.ratings[mlb.ratings.user.isin(uids)].loc[:, ('user', 'item', 'rating')].sample(10)

-    res = lkb.predict(mlb.predictor, tf)
+    res = lkb.predict(mlb.algo, mlb.model, tf)

     assert len(res) == 10
     assert set(res.user) == set(uids)
@@ -133,7 +119,7 @@ def eval(train, test):
         _log.info('running training')
         model = algo.train(train)
         _log.info('testing %d users', test.user.nunique())
-        recs = batch.predict(algo, test, model=model, nprocs=ncpus)
+        recs = batch.predict(algo, model, test, nprocs=ncpus)
         return recs

     preds = pd.concat((eval(train, test)
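The `nprocs=ncpus` call in the last hunk exercises the parallel path in `_predict.py`: per the first diff, `batch.predict` only spawns workers when `nprocs > 1` and the platform's multiprocessing start method is `fork`; otherwise it runs the sequential per-user loop. A hedged usage sketch (`algo`, `model`, and `test` stand for a trained algorithm and a test frame as in the `eval` helpers above):

    import multiprocessing as mp
    from lenskit import batch

    # fork is the default start method on Linux (and macOS at this point in
    # history); on other platforms batch.predict uses the sequential loop
    ncpus = mp.cpu_count() if mp.get_start_method() == 'fork' else None
    preds = batch.predict(algo, model, test, nprocs=ncpus)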
tests/test_funksvd.py (2 additions, 2 deletions)
@@ -140,7 +140,7 @@ def test_fsvd_known_preds():
     known_preds = pd.read_csv(str(pred_file))
     pairs = known_preds.loc[:, ['user', 'item']]

-    preds = batch.predict(algo, pairs, model=model)
+    preds = batch.predict(algo, model, pairs)
     merged = pd.merge(known_preds.rename(columns={'prediction': 'expected'}), preds)
     assert len(merged) == len(preds)
     merged['error'] = merged.expected - merged.prediction
@@ -173,7 +173,7 @@ def eval(train, test):
         _log.info('running training')
         model = algo.train(train)
         _log.info('testing %d users', test.user.nunique())
-        return batch.predict(algo, test, model=model)
+        return batch.predict(algo, model, test)

     folds = xf.partition_users(ratings, 5, xf.SampleFrac(0.2))
     preds = pd.concat(eval(train, test) for (train, test) in folds)
tests/test_knn_item_item.py (2 additions, 2 deletions)
@@ -448,7 +448,7 @@ def eval(train, test):
         _log.info('running training')
         model = algo.train(train)
         _log.info('testing %d users', test.user.nunique())
-        return batch.predict(lambda u, xs: algo.predict(model, u, xs), test)
+        return batch.predict(algo, model, test)

     preds = pd.concat((eval(train, test)
                        for (train, test)
@@ -477,7 +477,7 @@ def test_ii_known_preds():
     known_preds = pd.read_csv(str(pred_file))
     pairs = known_preds.loc[:, ['user', 'item']]

-    preds = batch.predict(algo, pairs, model=model)
+    preds = batch.predict(algo, model, pairs)
     merged = pd.merge(known_preds.rename(columns={'prediction': 'expected'}), preds)
     assert len(merged) == len(preds)
     merged['error'] = merged.expected - merged.prediction
tests/test_knn_user_user.py (2 additions, 2 deletions)
@@ -184,7 +184,7 @@ def test_uu_known_preds():
     known_preds = pd.read_csv(str(pred_file))
     pairs = known_preds.loc[:, ['user', 'item']]

-    preds = batch.predict(algo, pairs, model=model)
+    preds = batch.predict(algo, model, pairs)
     merged = pd.merge(known_preds.rename(columns={'prediction': 'expected'}), preds)
     assert len(merged) == len(preds)
     merged['error'] = merged.expected - merged.prediction
@@ -205,7 +205,7 @@ def __batch_eval(job):
     _log.info('running training')
     model = algo.train(train)
     _log.info('testing %d users', test.user.nunique())
-    return batch.predict(lambda u, xs: algo.predict(model, u, xs), test)
+    return batch.predict(algo, model, test)


 @mark.slow
tests/test_predict_metrics.py (1 addition, 1 deletion)
@@ -170,7 +170,7 @@ def test_batch_rmse():

     def eval(train, test):
         model = algo.train(train)
-        preds = batch.predict(lambda u, xs: algo.predict(model, u, xs), test)
+        preds = batch.predict(algo, model, test)
         return preds.set_index(['user', 'item'])

     results = pd.concat((eval(train, test)
