Skip to content

Commit

Permalink
Make sure indices are sorted in the positives lookup matrix.
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejkula committed Aug 26, 2015
1 parent c16f399 commit f62ca69
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions lightfm/lightfm.py
Expand Up @@ -171,6 +171,15 @@ def _construct_feature_matrices(self, n_users, n_items, user_features,

return user_features, item_features

def _get_positives_lookup_matrix(self, interactions):

mat = interactions.tocsr()

if not mat.has_sorted_indices:
return mat.sorted_indices()
else:
return mat

def fit(self, interactions, user_features=None, item_features=None,
epochs=1, num_threads=1, verbose=False):

Expand Down Expand Up @@ -277,7 +286,7 @@ def _run_epoch(self, item_features, user_features, interactions, num_threads, lo
if loss == 'warp':
fit_warp(CSRMatrix(item_features),
CSRMatrix(user_features),
CSRMatrix(interactions.tocsr()),
CSRMatrix(self._get_positives_lookup_matrix(interactions)),
interactions.row,
interactions.col,
interactions.data,
Expand All @@ -290,7 +299,7 @@ def _run_epoch(self, item_features, user_features, interactions, num_threads, lo
elif loss == 'bpr':
fit_bpr(CSRMatrix(item_features),
CSRMatrix(user_features),
CSRMatrix(interactions.tocsr()),
CSRMatrix(self._get_positives_lookup_matrix(interactions)),
interactions.row,
interactions.col,
interactions.data,
Expand All @@ -303,7 +312,7 @@ def _run_epoch(self, item_features, user_features, interactions, num_threads, lo
elif loss == 'warp-kos':
fit_warp_kos(CSRMatrix(item_features),
CSRMatrix(user_features),
CSRMatrix(interactions.tocsr()),
CSRMatrix(self._get_positives_lookup_matrix(interactions)),
interactions.row,
shuffle_indices,
lightfm_data,
Expand Down

0 comments on commit f62ca69

Please sign in to comment.