Skip to content

Commit

Permalink
Merge pull request #62 from mdekstrand/fix/pickling
Browse files Browse the repository at this point in the history
Pickling bug fixes and tests
  • Loading branch information
mdekstrand committed Jan 19, 2019
2 parents 8846075 + 80afa19 commit 145ddb8
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 0 deletions.
2 changes: 2 additions & 0 deletions lenskit/algorithms/implicit.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,8 @@ def recommend(self, user, n=None, candidates=None, ratings=None):
return rec_df.loc[:, ['item', 'score']]

def __getattr__(self, name):
if 'delegate' not in self.__dict__:
raise AttributeError()
dd = self.delegate.__dict__
if name in dd:
return dd[name]
Expand Down
13 changes: 13 additions & 0 deletions lenskit/batch/_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pathlib
import collections
import json
from copy import copy

import pandas as pd

Expand Down Expand Up @@ -366,3 +367,15 @@ def _read_json(self, name, *args):

with fn.open('r') as f:
return json.load(f)

def __getstate__(self):
if not self._is_flat:
_logger.warning('attempting to pickle non-flattened experiment')
state = copy(self.__dict__)
# clone the algorithms to only pickle their parameters
state['algorithms'] = [a._replace(algorithm=util.clone(a.algorithm))
for a in self.algorithms]
return state

def __setstate__(self, state):
self.__dict__.update(state)
35 changes: 35 additions & 0 deletions tests/test_batch_sweep.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import pathlib
import json
import pickle

import pandas as pd
import numpy as np
Expand Down Expand Up @@ -160,6 +161,40 @@ def test_sweep_oneshot(tmp_path):
assert run['RunId'] == 3


def test_sweep_save(tmp_path):
tmp_path = norm_path(tmp_path)
work = pathlib.Path(tmp_path)
sweep = batch.MultiEval(tmp_path)

ratings = ml_pandas.renamed.ratings
sweep.add_datasets(lambda: xf.partition_users(ratings, 5, xf.SampleN(5)), name='ml-small')
sweep.add_algorithms(Bias(damping=5))

sweep.persist_data()
pf = work / 'sweep.dat'
with pf.open('wb') as f:
pickle.dump(sweep, f)

with pf.open('rb') as f:
sweep = pickle.load(f)

try:
sweep.run()
finally:
if (work / 'runs.csv').exists():
runs = pd.read_csv(work / 'runs.csv')
print(runs)

assert (work / 'runs.csv').exists()
assert (work / 'runs.parquet').exists()
assert (work / 'predictions.parquet').exists()
assert (work / 'recommendations.parquet').exists()

runs = pd.read_parquet(work / 'runs.parquet')
# 1 algorithms by 5 partitions
assert len(runs) == 5


def test_sweep_combine(tmp_path):
tmp_path = norm_path(tmp_path)
work = pathlib.Path(tmp_path)
Expand Down
18 changes: 18 additions & 0 deletions tests/test_implicit.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
import pickle

import pandas as pd
import numpy as np
Expand Down Expand Up @@ -26,6 +27,7 @@
@mark.skipif(not have_implicit, reason='implicit not installed')
def test_implicit_als_train_rec():
algo = ALS(25)
assert algo.factors == 25
ratings = lktu.ml_pandas.renamed.ratings

ret = algo.fit(ratings)
Expand Down Expand Up @@ -72,9 +74,25 @@ def eval(train, test):
@mark.skipif(not have_implicit, reason='implicit not installed')
def test_implicit_bpr_train_rec():
algo = BPR(25)
assert algo.factors == 25
ratings = lktu.ml_pandas.renamed.ratings

algo.fit(ratings)

recs = algo.recommend(100, n=20)
assert len(recs) == 20


@mark.skipif(not have_implicit, reason='implicit not installed')
def test_implicit_pickle_untrained(tmp_path):
mf = tmp_path / 'bpr.dat'
algo = BPR(25)

with mf.open('wb') as f:
pickle.dump(algo, f)

with mf.open('rb') as f:
a2 = pickle.load(f)

assert a2 is not algo
assert a2.factors == 25

0 comments on commit 145ddb8

Please sign in to comment.