From ec734fbe8bc0cb5489d18300ad20f5b6e46d49e0 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 7 Dec 2022 22:39:32 -0800 Subject: [PATCH] Fix pickling GPU models (#632) GPU models weren't default initializing members like _knn appropriately when being loaded from pickle appropriately. Fix and update the unittest to test this out. --- implicit/gpu/matrix_factorization_base.py | 11 +++++++++-- tests/recommender_base_test.py | 20 +++++++++++++++++--- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/implicit/gpu/matrix_factorization_base.py b/implicit/gpu/matrix_factorization_base.py index 10b047e0..8491acdb 100644 --- a/implicit/gpu/matrix_factorization_base.py +++ b/implicit/gpu/matrix_factorization_base.py @@ -221,8 +221,15 @@ def __getstate__(self): } def __setstate__(self, state): - self.item_factors = implicit.gpu.Matrix(state["item_factors"]) - self.user_factors = implicit.gpu.Matrix(state["user_factors"]) + # default initialize members + self.__init__() + item_factors = state["item_factors"] + if item_factors is not None: + self.item_factors = implicit.gpu.Matrix(item_factors) + + user_factors = state["user_factors"] + if user_factors is not None: + self.user_factors = implicit.gpu.Matrix(user_factors) def check_random_state(random_state): diff --git a/tests/recommender_base_test.py b/tests/recommender_base_test.py index 4f6ed36f..8d829d5b 100644 --- a/tests/recommender_base_test.py +++ b/tests/recommender_base_test.py @@ -405,12 +405,26 @@ def test_rank_items_batch(self): self.assertEqual(set(current_ids), set(selected_items)) def test_pickle(self): - item_users = get_checker_board(50) + user_items = get_checker_board(50) model = self._get_model() - model.fit(item_users, show_progress=False) + model.fit(user_items, show_progress=False) pickled = pickle.dumps(model) - pickle.loads(pickled) + reloaded = pickle.loads(pickled) + + # make sure we can call methods on the reloaded index, and get the same results back + # (https://github.com/benfred/implicit/issues/631) + ids, _ = model.recommend(0, user_items[0]) + reloaded_ids, _ = reloaded.recommend(0, user_items[0]) + assert_array_equal(ids, reloaded_ids) + + ids, _ = model.similar_items(0) + reloaded_ids, _ = reloaded.similar_items(0) + assert_array_equal(ids, reloaded_ids) + + def test_pickle_unfitted_model(self): + model = self._get_model() + pickle.loads(pickle.dumps(model)) def test_invalid_user_items(self):