Skip to content

Commit

Permalink
Fix pickling GPU models (#632)
Browse files Browse the repository at this point in the history
GPU models weren't default initializing members like _knn appropriately when
being loaded from pickle appropriately. Fix and update the unittest to test
this out.
  • Loading branch information
benfred committed Dec 8, 2022
1 parent 44daadc commit ec734fb
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 5 deletions.
11 changes: 9 additions & 2 deletions implicit/gpu/matrix_factorization_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,8 +221,15 @@ def __getstate__(self):
}

def __setstate__(self, state):
self.item_factors = implicit.gpu.Matrix(state["item_factors"])
self.user_factors = implicit.gpu.Matrix(state["user_factors"])
# default initialize members
self.__init__()
item_factors = state["item_factors"]
if item_factors is not None:
self.item_factors = implicit.gpu.Matrix(item_factors)

user_factors = state["user_factors"]
if user_factors is not None:
self.user_factors = implicit.gpu.Matrix(user_factors)


def check_random_state(random_state):
Expand Down
20 changes: 17 additions & 3 deletions tests/recommender_base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,12 +405,26 @@ def test_rank_items_batch(self):
self.assertEqual(set(current_ids), set(selected_items))

def test_pickle(self):
item_users = get_checker_board(50)
user_items = get_checker_board(50)
model = self._get_model()
model.fit(item_users, show_progress=False)
model.fit(user_items, show_progress=False)

pickled = pickle.dumps(model)
pickle.loads(pickled)
reloaded = pickle.loads(pickled)

# make sure we can call methods on the reloaded index, and get the same results back
# (https://github.com/benfred/implicit/issues/631)
ids, _ = model.recommend(0, user_items[0])
reloaded_ids, _ = reloaded.recommend(0, user_items[0])
assert_array_equal(ids, reloaded_ids)

ids, _ = model.similar_items(0)
reloaded_ids, _ = reloaded.similar_items(0)
assert_array_equal(ids, reloaded_ids)

def test_pickle_unfitted_model(self):
model = self._get_model()
pickle.loads(pickle.dumps(model))

def test_invalid_user_items(self):

Expand Down

0 comments on commit ec734fb

Please sign in to comment.