Skip to content

Commit

Permalink
Allow excluding train set positives from test set evaluation.
Browse files: browse the repository at this point in the history
  • Loading branch information
maciejkula committed May 2, 2016
1 parent da512a1 commit 023245e
Show file tree
Hide file tree
Showing 7 changed files with 724 additions and 485 deletions.
25 changes: 15 additions & 10 deletions lightfm/_lightfm_fast.pyx.template
Original file line number Diff line number Diff line change
Expand Up @@ -1188,7 +1188,8 @@ def predict_lightfm(CSRMatrix item_features,

def predict_ranks(CSRMatrix item_features,
CSRMatrix user_features,
CSRMatrix interactions,
CSRMatrix test_interactions,
CSRMatrix train_interactions,
flt[::1] ranks,
FastLightFM lightfm,
int num_threads):
Expand All @@ -1205,10 +1206,10 @@ def predict_ranks(CSRMatrix item_features,

# Figure out the max size of the predictions
# buffer.
for user_id in range(interactions.rows):
for user_id in range(test_interactions.rows):
predictions_size = int_max(predictions_size,
interactions.get_row_end(user_id)
- interactions.get_row_start(user_id))
test_interactions.get_row_end(user_id)
- test_interactions.get_row_start(user_id))

{nogil_block}

Expand All @@ -1217,7 +1218,7 @@ def predict_ranks(CSRMatrix item_features,
item_ids = <int *>malloc(sizeof(int) * predictions_size)
predictions = <flt *>malloc(sizeof(flt) * predictions_size)

for user_id in {range_block}(interactions.rows):
for user_id in {range_block}(test_interactions.rows):

compute_representation(user_features,
lightfm.user_features,
Expand All @@ -1227,14 +1228,14 @@ def predict_ranks(CSRMatrix item_features,
lightfm.user_scale,
user_repr)

row_start = interactions.get_row_start(user_id)
row_stop = interactions.get_row_end(user_id)
row_start = test_interactions.get_row_start(user_id)
row_stop = test_interactions.get_row_end(user_id)

# Compute predictions for the items whose
# ranks we want to know
for i in range(row_stop - row_start):

item_id = interactions.indices[row_start + i]
item_id = test_interactions.indices[row_start + i]

compute_representation(item_features,
lightfm.item_features,
Expand All @@ -1250,7 +1251,10 @@ def predict_ranks(CSRMatrix item_features,
lightfm.no_components)

# Now we can zip through all the other items and compute ranks
for item_id in range(interactions.cols):
for item_id in range(test_interactions.cols):

if in_positives(item_id, user_id, train_interactions):
continue

compute_representation(item_features,
lightfm.item_features,
Expand All @@ -1273,6 +1277,7 @@ def predict_ranks(CSRMatrix item_features,


def calculate_auc_from_rank(CSRMatrix ranks,
int[::1] num_train_positives,
flt[::1] rank_data,
flt[::1] auc,
int num_threads):
Expand All @@ -1287,7 +1292,7 @@ def calculate_auc_from_rank(CSRMatrix ranks,
row_stop = ranks.get_row_end(user_id)

num_positives = row_stop - row_start
num_negatives = ranks.cols - (row_stop - row_start)
num_negatives = ranks.cols - ((row_stop - row_start) + num_train_positives[user_id])

# If there is only one class present,
# return 0.5.
Expand Down

0 comments on commit 023245e

Please sign in to comment.