Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

One-hot-encode categorical column #22

Merged
merged 3 commits into from
Nov 9, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions sparsity/dask/reshape.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@


def one_hot_encode(ddf, column=None, categories=None, index_col=None,
order=None, prefixes=False):
order=None, prefixes=False,
ignore_cat_order_mismatch=False):
"""
Sparse one hot encoding of dask.DataFrame.

Expand All @@ -21,8 +22,11 @@ def one_hot_encode(ddf, column=None, categories=None, index_col=None,
ddf: dask.DataFrame
e.g. the clickstream
categories: dict
Maps column name -> iterable of possible category values.
See description of `order`.
Maps ``column name`` -> ``iterable of possible category values``.
Can be also ``column name`` -> ``None`` if this column is already
of categorical dtype.
This argument decides which column(s) will be encoded.
See description of `order` and `ignore_cat_order_mismatch`.
index_col: str | iterable
which columns to use as index
order: iterable
Expand All @@ -46,6 +50,16 @@ def one_hot_encode(ddf, column=None, categories=None, index_col=None,
[col1_cat11, col1_cat12, col2_cat21, col2_cat22, ...].
column: DEPRECATED
Kept only for backward compatibility.
ignore_cat_order_mismatch: bool
If a column being one-hot encoded is of categorical dtype, it has
its categories already predefined, so we don't need to explicitly pass
them in `categories` argument (see this argument's description).
However, if we pass them, they may be different than ones defined in
column.cat.categories. In such a situation, a ValueError will be
raised. However, if only orders of categories are different (but sets
of elements are same), you may specify ignore_cat_order_mismatch=True
to suppress this error. In such a situation, column's predefined
categories will be used.

Returns
-------
Expand All @@ -71,14 +85,17 @@ def one_hot_encode(ddf, column=None, categories=None, index_col=None,
columns = sparse_one_hot(ddf._meta,
categories=categories,
index_col=index_col,
prefixes=prefixes).columns
prefixes=prefixes,
ignore_cat_order_mismatch=ignore_cat_order_mismatch
).columns
meta = sp.SparseFrame(np.array([]), columns=columns,
index=idx_meta)

dsf = ddf.map_partitions(sparse_one_hot,
categories=categories,
index_col=index_col,
prefixes=prefixes,
ignore_cat_order_mismatch=ignore_cat_order_mismatch,
meta=object)

return SparseFrame(dsf.dask, dsf._name, meta, dsf.divisions)
41 changes: 37 additions & 4 deletions sparsity/sparse_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,7 +638,8 @@ def _create_group_matrix(group_idx, dtype='f8'):


def sparse_one_hot(df, column=None, categories=None, dtype='f8',
index_col=None, order=None, prefixes=False):
index_col=None, order=None, prefixes=False,
ignore_cat_order_mismatch=False):
"""
One-hot encode specified columns of a pandas.DataFrame.
Returns a SparseFrame.
Expand All @@ -664,7 +665,10 @@ def sparse_one_hot(df, column=None, categories=None, dtype='f8',
for column, column_cat in categories.items():
if isinstance(column_cat, str):
column_cat = _just_read_array(column_cat)
cols, csr = _one_hot_series_csr(column_cat, dtype, df[column])
cols, csr = _one_hot_series_csr(
column_cat, dtype, df[column],
ignore_cat_order_mismatch=ignore_cat_order_mismatch
)
if prefixes:
cols = list(map(lambda x: '{}_{}'.format(column, x), cols))
new_cols.extend(cols)
Expand All @@ -683,9 +687,13 @@ def sparse_one_hot(df, column=None, categories=None, dtype='f8',
return SparseFrame(new_data, index=new_index, columns=new_cols)


def _one_hot_series_csr(categories, dtype, oh_col):
def _one_hot_series_csr(categories, dtype, oh_col,
ignore_cat_order_mismatch=False):
if types.is_categorical_dtype(oh_col):
cat = oh_col
cat = oh_col.cat
_check_categories_order(cat.categories, categories, oh_col.name,
ignore_cat_order_mismatch)

else:
s = oh_col
cat = pd.Categorical(s, np.asarray(categories))
Expand All @@ -703,3 +711,28 @@ def _one_hot_series_csr(categories, dtype, oh_col):
shape=(n_samples, n_features),
dtype=dtype).tocsr()
return cat.categories.values, data


def _check_categories_order(categories1, categories2, categorical_column_name,
ignore_cat_order_mismatch):
"""Check if two lists of categories differ. If they have different
elements, raise an exception. If they differ only by order of elements,
raise an issue unless ignore_cat_order_mismatch is set."""

if categories2 is None or list(categories2) == list(categories1):
return

if set(categories2) == set(categories1):
mismatch_type = 'order'
else:
mismatch_type = 'set'

if mismatch_type == 'set' or not ignore_cat_order_mismatch:
raise ValueError(
"Got categorical column {column_name} whose categories "
"{mismatch_type} doesn't match categories {mismatch_type} "
"given as argument to this function.".format(
column_name=categorical_column_name,
mismatch_type=mismatch_type
)
)
16 changes: 16 additions & 0 deletions sparsity/test/test_dask_sparse_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,22 @@ def test_one_hot_no_order(clickstream):
assert sorted(sf.columns) == list('ABCDEFGHIJ')


def test_one_hot_no_order_categorical(clickstream):
    """One-hot encode a dask frame in which one of the encoded columns has
    been converted to categorical dtype beforehand."""
    as_category = clickstream['other_categorical'].astype('category')
    clickstream['other_categorical'] = as_category
    ddf = dd.from_pandas(clickstream, npartitions=10)

    expected_columns = list('ABCDEFGHIJ')
    categories = {'page_id': list('ABCDE'),
                  'other_categorical': list('FGHIJ')}
    dsf = one_hot_encode(ddf,
                         categories=categories,
                         index_col=['index', 'id'])

    # Checks on the lazy collection, before anything is computed.
    assert dsf._meta.empty
    assert sorted(dsf.columns) == expected_columns

    # Checks on the computed result.
    sf = dsf.compute()
    assert sf.shape == (100, 10)
    assert isinstance(sf.index, pd.MultiIndex)
    assert sorted(sf.columns) == expected_columns


def test_one_hot_prefixes(clickstream):
ddf = dd.from_pandas(clickstream, npartitions=10)
dsf = one_hot_encode(ddf,
Expand Down
96 changes: 95 additions & 1 deletion sparsity/test/test_sparse_frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,18 @@ def mock_s3_fs(bucket, data=None):
# 2017 starts with a sunday
@pytest.fixture()
def sampledata():
def gendata(n):
def gendata(n, categorical=False):
sample_data = pd.DataFrame(
dict(date=pd.date_range("2017-01-01", periods=n)))
sample_data["weekday"] = sample_data.date.dt.weekday_name
sample_data["weekday_abbr"] = sample_data.weekday.apply(
lambda x: x[:3])

if categorical:
sample_data['weekday'] = sample_data['weekday'].astype('category')
sample_data['weekday_abbr'] = sample_data['weekday_abbr'] \
.astype('category')

sample_data["id"] = np.tile(np.arange(7), len(sample_data) // 7 + 1)[
:len(sample_data)]
return sample_data
Expand Down Expand Up @@ -468,6 +474,94 @@ def test_csr_one_hot_series(sampledata, weekdays, weekdays_abbr):
assert all(sparse_frame.columns == (weekdays + weekdays_abbr))


def test_csr_one_hot_series_categorical_same_order(sampledata, weekdays,
                                                   weekdays_abbr):
    """Categorical columns are encoded without error when the categories
    passed in are exactly the columns' own categories, in the same order."""
    data = sampledata(49, categorical=True)

    # Pass each column's own categories back, order untouched.
    categories = {
        name: data[name].cat.categories.tolist()
        for name in ('weekday', 'weekday_abbr')
    }

    sparse_frame = sparse_one_hot(data,
                                  categories=categories,
                                  order=['weekday', 'weekday_abbr'],
                                  ignore_cat_order_mismatch=False)

    expected_columns = weekdays + weekdays_abbr
    group_ids = np.tile(np.arange(7), 7)
    dense = sparse_frame.groupby_sum(group_ids).todense()
    res = dense[expected_columns].values

    block = np.identity(7) * 7
    correct = np.hstack((block, block))

    assert np.all(res == correct)
    assert set(sparse_frame.columns) == set(expected_columns)


def test_csr_one_hot_series_categorical_different_order(sampledata, weekdays,
                                                        weekdays_abbr):
    """Passing the columns' categories in reversed order must raise
    ValueError when ignore_cat_order_mismatch is False."""
    data = sampledata(49, categorical=True)

    # Same elements as the columns' own categories, but reversed.
    categories = {
        'weekday': data['weekday'].cat.categories.tolist()[::-1],
        'weekday_abbr': data['weekday_abbr'].cat.categories.tolist()[::-1]
    }

    # The call is expected to raise, so its result is not bound.
    with pytest.raises(ValueError):
        sparse_one_hot(data,
                       categories=categories,
                       order=['weekday', 'weekday_abbr'],
                       ignore_cat_order_mismatch=False)


def test_csr_one_hot_series_categorical_different_order_ignore(
        sampledata, weekdays, weekdays_abbr):
    """With ignore_cat_order_mismatch=True, passing the columns' categories
    in reversed order is accepted and the encoding result is still correct."""
    data = sampledata(49, categorical=True)

    # Same elements as the columns' own categories, but reversed.
    categories = {
        name: data[name].cat.categories.tolist()[::-1]
        for name in ('weekday', 'weekday_abbr')
    }

    sparse_frame = sparse_one_hot(data,
                                  categories=categories,
                                  order=['weekday', 'weekday_abbr'],
                                  ignore_cat_order_mismatch=True)

    expected_columns = weekdays + weekdays_abbr
    group_ids = np.tile(np.arange(7), 7)
    dense = sparse_frame.groupby_sum(group_ids).todense()
    res = dense[expected_columns].values

    block = np.identity(7) * 7
    correct = np.hstack((block, block))

    assert np.all(res == correct)
    assert set(sparse_frame.columns) == set(expected_columns)


def test_csr_one_hot_series_categorical_no_categories(
        sampledata, weekdays, weekdays_abbr):
    """Categorical columns can be encoded with ``None`` given as their
    categories in the `categories` mapping."""
    data = sampledata(49, categorical=True)

    # Every encoded column maps to None: no explicit categories are passed.
    categories = dict.fromkeys(['weekday', 'weekday_abbr'])

    sparse_frame = sparse_one_hot(data,
                                  categories=categories,
                                  order=['weekday', 'weekday_abbr'],
                                  ignore_cat_order_mismatch=True)

    expected_columns = weekdays + weekdays_abbr
    group_ids = np.tile(np.arange(7), 7)
    dense = sparse_frame.groupby_sum(group_ids).todense()
    res = dense[expected_columns].values

    block = np.identity(7) * 7
    correct = np.hstack((block, block))

    assert np.all(res == correct)
    assert set(sparse_frame.columns) == set(expected_columns)


def test_csr_one_hot_series_other_order(sampledata, weekdays, weekdays_abbr):

categories = {'weekday': weekdays,
Expand Down