Skip to content

Commit

Permalink
Add min_training_interactions option to stackexchange datasets.
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejkula committed May 2, 2016
1 parent d7c0abb commit 43e5202
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 7 deletions.
14 changes: 11 additions & 3 deletions lightfm/datasets/stackexchange/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

import numpy as np

import requests

import scipy.sparse as sp

from lightfm.datasets import _common


def fetch_stackexchange(dataset, test_set_fraction=0.2, data_home=None,
def fetch_stackexchange(dataset, test_set_fraction=0.2,
min_training_interactions=1,
data_home=None,
indicator_features=True, tag_features=False,
download_if_missing=True):
"""
Expand All @@ -34,6 +34,8 @@ def fetch_stackexchange(dataset, test_set_fraction=0.2, data_home=None,
The fraction of the dataset used for testing. Splitting into the train and test set is done
in a time-based fashion: all interactions before a certain time are in the train set and
all interactions after that time are in the test set.
min_training_interactions: int, optional
Only include users with more than this number of interactions in the training set. A value of 0 disables this filter.
data_home: path, optional
Path to the directory in which the downloaded data should be placed.
Defaults to ``~/lightfm_data/``.
Expand Down Expand Up @@ -109,6 +111,12 @@ def fetch_stackexchange(dataset, test_set_fraction=0.2, data_home=None,
interactions.col[in_test])),
shape=interactions.shape)

if min_training_interactions > 0:
include = np.squeeze(np.array(train.getnnz(axis=1))) > min_training_interactions

train = train.tocsr()[include].tocoo()
test = test.tocsr()[include].tocoo()

if indicator_features and not tag_features:
features = sp.identity(train.shape[1],
format='csr',
Expand Down
16 changes: 12 additions & 4 deletions tests/test_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,9 @@ def test_basic_fetching_stackexchange():
test_fractions = (0.2, 0.5, 0.6)

for test_fraction in test_fractions:
data = fetch_stackexchange('crossvalidated', test_set_fraction=test_fraction)
data = fetch_stackexchange('crossvalidated',
min_training_interactions=0,
test_set_fraction=test_fraction)

train = data['train']
test = data['test']
Expand All @@ -51,16 +53,22 @@ def test_basic_fetching_stackexchange():

for dataset in ('crossvalidated', 'stackoverflow'):

data = fetch_stackexchange(dataset, indicator_features=True, tag_features=False)
data = fetch_stackexchange(dataset,
min_training_interactions=0,
indicator_features=True, tag_features=False)
assert isinstance(data['item_features'], sp.csr_matrix)
assert (data['item_features'].shape[0] == data['item_features'].shape[1]
== data['train'].shape[1])

data = fetch_stackexchange(dataset, indicator_features=False, tag_features=True)
data = fetch_stackexchange(dataset,
min_training_interactions=0,
indicator_features=False, tag_features=True)
assert isinstance(data['item_features'], sp.csr_matrix)
assert data['item_features'].shape[0] > data['item_features'].shape[1]

data = fetch_stackexchange(dataset, indicator_features=True, tag_features=True)
data = fetch_stackexchange(dataset,
min_training_interactions=0,
indicator_features=True, tag_features=True)
assert isinstance(data['item_features'], sp.csr_matrix)
assert data['item_features'].shape[0] < data['item_features'].shape[1]

Expand Down

0 comments on commit 43e5202

Please sign in to comment.