Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH Warn when expanding columns with many categories #26

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/)
and this project adheres to [Semantic Versioning](http://semver.org/).

## Unreleased

### Added
- Emit a warning if the user attempts to expand a column with
too many categories (#25, #26)

## [0.1.6] - 2018-1-12

### Fixed
Expand Down
13 changes: 13 additions & 0 deletions civismlext/preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class DataFrameETL(BaseEstimator, TransformerMixin):
columns_, list[str]
List of final column names in order
"""
# Category-count threshold for the "too many categories" RuntimeWarning:
# a column being expanded triggers the warning when its number of
# categories is greater than or equal to this value. A falsy value
# (e.g. None or 0) disables the check.
expansion_warn_threshold = 500 # Warn when expanding this many categories

def __init__(self,
cols_to_drop=None,
cols_to_expand='auto',
Expand Down Expand Up @@ -98,17 +100,28 @@ def _check_sentinels(self, X):
def _create_levels(self, X):
    """Create levels for each column in cols_to_expand.

    Parameters
    ----------
    X : pd.DataFrame
        Data holding every column listed in ``self._cols_to_expand``.

    Returns
    -------
    dict
        Maps each column name to the list of its category levels. When
        ``self.dummy_na`` is set or the column contains nulls,
        ``self._nan_sentinel`` is appended as the last level.

    Warns
    -----
    RuntimeWarning
        Emitted once, listing every column whose number of categories
        (not counting the sentinel) is at least
        ``self.expansion_warn_threshold``. A falsy threshold disables
        the warning.
    """
    levels = {}
    # column name -> category count, for columns at/over the threshold
    large_columns = {}
    threshold = self.expansion_warn_threshold
    for col in self._cols_to_expand:
        # get a list of categories when the column is cast to
        # dtype category
        # levels are sorted by default
        col_levels = X[col].astype('category').cat.categories.tolist()
        n_levels = len(col_levels)
        if threshold and n_levels >= threshold:
            large_columns[col] = n_levels
        # if there are nans, we will be replacing them with a sentinel,
        # so add the sentinel as a level explicitly
        # Note that even if we don't include a dummy_na column, we still
        # need to keep track of missing values internally for fill_value
        # Use the vectorised reduction instead of iterating the boolean
        # Series element by element with the builtin any().
        if self.dummy_na or X[col].isnull().any():
            col_levels.append(self._nan_sentinel)
        levels[col] = col_levels
    if large_columns:
        # A single combined warning keeps the output readable when
        # several columns are affected.
        summary = "; ".join('"%s": %d categories' % (name, count)
                            for name, count in large_columns.items())
        warnings.warn("The following categorical column(s) have a large "
                      "number of categories. Are you sure you wish to "
                      "convert them to binary indicators?\n%s" % summary,
                      RuntimeWarning)
    return levels

def _create_col_names(self, X):
Expand Down
10 changes: 10 additions & 0 deletions civismlext/test/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,16 @@ def test_create_levels_no_dummy(data_raw, levels_dict_numeric):
assert actual_levels == levels_dict_numeric


def test_warn_too_many_categories():
    # 'cat' holds 2000 distinct values and 'bird' holds 1000 (each value
    # repeated twice); fitting should report both in one RuntimeWarning.
    frame = pd.DataFrame({'cat': list(range(2000)),
                          'bird': 2 * list(range(1000))})
    with pytest.warns(RuntimeWarning) as record:
        DataFrameETL(cols_to_expand=['cat', 'bird']).fit(frame)
    assert len(record) == 1, "Should only raise one warning"
    message = record[0].message.args[0]
    assert '"cat": 2000 categories' in message
    assert '"bird": 1000 categories' in message


def test_create_col_names(data_raw):
expander = DataFrameETL(cols_to_expand=['pid', 'djinn_type', 'animal'],
cols_to_drop=['fruits'],
Expand Down