Skip to content

Commit

Permalink
Merge pull request #272 from keener101/adding-non-unique-index-check-…
Browse files Browse the repository at this point in the history
…to-partition-users

Raise an error when attempting to split a data frame with a non-unique index
  • Loading branch information
mdekstrand committed Jan 19, 2022
2 parents d790b76 + 55ff084 commit b28a14f
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 1 deletion.
16 changes: 16 additions & 0 deletions lenskit/crossfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def partition_rows(data, partitions, *, rng_spec=None):
Returns:
iterator: an iterator of train-test pairs
"""

confirm_unique_index(data)
_logger.info('partitioning %d ratings into %d partitions', len(data), partitions)

# create an array of indexes
Expand Down Expand Up @@ -94,6 +96,7 @@ def sample_rows(data, partitions, size, disjoint=True, *, rng_spec=None):
iterator: An iterator of train-test pairs.
"""

confirm_unique_index(data)
if partitions is None:
test = data.sample(n=size)
tr_mask = pd.Series(True, index=data.index)
Expand Down Expand Up @@ -244,6 +247,7 @@ def partition_users(data, partitions: int, method: PartitionMethod, *, rng_spec=
iterator: an iterator of train-test pairs
"""

confirm_unique_index(data)
user_col = data['user']
users = user_col.unique()
_logger.info('partitioning %d rows for %d users into %d partitions',
Expand Down Expand Up @@ -297,6 +301,7 @@ def sample_users(data, partitions: int, size: int, method: PartitionMethod, disj
iterator: An iterator of train-test pairs (as :class:`TTPair` objects).
"""

confirm_unique_index(data)
rng = util.rng(rng_spec, legacy=True)

user_col = data['user']
Expand Down Expand Up @@ -345,3 +350,14 @@ def simple_test_pair(ratings, n_users=1000, n_rates=5, f_rates=None):
train, test = next(sample_users(ratings, 1, n_users, samp))

return train, test


def confirm_unique_index(data):
"""Confirms dataframe has unique index values, and if not,
throws ValueError with helpful log message"""

if not data.index.is_unique:
_logger.error("Index has duplicate values")
_logger.info("If index values do not matter, consider running " +
".reset_index() on the dataframe before partitioning")
raise ValueError('Index is not uniquely valued')
38 changes: 37 additions & 1 deletion tests/test_crossfold.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ def test_partition_users():


def test_partition_may_skip_train():
"Partitioning when users may not have enough ratings to be in the train set and test set."
"""Partitioning when users may not have enough ratings to be in the train set and test set."""
ratings = lktu.ml_test.ratings
# make a data set where some users only have 1 rating
ratings = ratings.sample(frac=0.1)
Expand Down Expand Up @@ -364,3 +364,39 @@ def test_sample_users_frac_oversize_ndj():
assert all(ucounts == 5)
assert all(s.test.index.union(s.train.index) == ratings.index)
assert len(s.test) + len(s.train) == len(ratings)


def test_non_unique_index_partition_users():
"""Partitioning users when dataframe has non-unique indices"""
ratings = lktu.ml_test.ratings
ratings = ratings.set_index('user') ##forces non-unique index
with pytest.raises(ValueError):
for split in xf.partition_users(ratings, 5, xf.SampleN(5)):
pass


def test_sample_users():
"""Sampling users when dataframe has non-unique indices"""
ratings = lktu.ml_test.ratings
ratings = ratings.set_index('user') ##forces non-unique index
with pytest.raises(ValueError):
for split in xf.sample_users(ratings, 5, 100, xf.SampleN(5)):
pass


def test_sample_rows():
"""Sampling ratings when dataframe has non-unique indices"""
ratings = lktu.ml_test.ratings
ratings = ratings.set_index('user') ##forces non-unique index
with pytest.raises(ValueError):
for split in xf.sample_rows(ratings, partitions=5, size=1000):
pass


def test_partition_users():
"""Partitioning ratings when dataframe has non-unique indices"""
ratings = lktu.ml_test.ratings
ratings = ratings.set_index('user') ##forces non-unique index
with pytest.raises(ValueError):
for split in xf.partition_users(ratings, 5, xf.SampleN(5)):
pass

0 comments on commit b28a14f

Please sign in to comment.