Skip to content

Commit

Permalink
ENH ColumnTransformer.get_feature_names() handles passthrough (scikit…
Browse files Browse the repository at this point in the history
  • Loading branch information
lrjball authored and gio8tisu committed May 15, 2020
1 parent 77a4c97 commit f9bbf80
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 22 deletions.
5 changes: 5 additions & 0 deletions doc/whats_new/v0.23.rst
Expand Up @@ -105,6 +105,11 @@ Changelog
a column name that is not unique in the dataframe. :pr:`16431` by
`Thomas Fan`_.

- |Enhancement| :class:`compose.ColumnTransformer` method ``get_feature_names``
now supports ``'passthrough'`` columns, with the feature name being either
the column name for a dataframe, or ``'xi'`` for column index ``i``.
:pr:`14048` by :user:`Lewis Ball <lrjball>`.

:mod:`sklearn.datasets`
.......................

Expand Down
29 changes: 18 additions & 11 deletions sklearn/compose/_column_transformer.py
Expand Up @@ -315,19 +315,18 @@ def _validate_remainder(self, X):
self.remainder)

# Make it possible to check for reordered named columns on transform
if (hasattr(X, 'columns') and
any(_determine_key_type(cols) == 'str'
for cols in self._columns)):
self._has_str_cols = any(_determine_key_type(cols) == 'str'
for cols in self._columns)
if hasattr(X, 'columns'):
self._df_columns = X.columns

self._n_features = X.shape[1]
cols = []
for columns in self._columns:
cols.extend(_get_column_indices(X, columns))
remaining_idx = list(set(range(self._n_features)) - set(cols))
remaining_idx = sorted(remaining_idx) or None

self._remainder = ('remainder', self.remainder, remaining_idx)
remaining_idx = sorted(set(range(self._n_features)) - set(cols))
self._remainder = ('remainder', self.remainder, remaining_idx or None)

@property
def named_transformers_(self):
Expand Down Expand Up @@ -356,11 +355,18 @@ def get_feature_names(self):
if trans == 'drop' or (
hasattr(column, '__len__') and not len(column)):
continue
elif trans == 'passthrough':
raise NotImplementedError(
"get_feature_names is not yet supported when using "
"a 'passthrough' transformer.")
elif not hasattr(trans, 'get_feature_names'):
if trans == 'passthrough':
if hasattr(self, '_df_columns'):
if ((not isinstance(column, slice))
and all(isinstance(col, str) for col in column)):
feature_names.extend(column)
else:
feature_names.extend(self._df_columns[column])
else:
indices = np.arange(self._n_features)
feature_names.extend(['x%d' % i for i in indices[column]])
continue
if not hasattr(trans, 'get_feature_names'):
raise AttributeError("Transformer %s (type %s) does not "
"provide get_feature_names."
% (str(name), type(trans).__name__))
Expand Down Expand Up @@ -582,6 +588,7 @@ def transform(self, X):
# name order and count. See #14237 for details.
if (self._remainder[2] is not None and
hasattr(self, '_df_columns') and
self._has_str_cols and
hasattr(X, 'columns')):
n_cols_fit = len(self._df_columns)
n_cols_transform = len(X.columns)
Expand Down
85 changes: 74 additions & 11 deletions sklearn/compose/tests/test_column_transformer.py
Expand Up @@ -668,25 +668,88 @@ def test_column_transformer_get_feature_names():
ct.fit(X)
assert ct.get_feature_names() == ['col0__a', 'col0__b', 'col1__c']

# passthrough transformers not supported
# drop transformer
ct = ColumnTransformer(
[('col0', DictVectorizer(), 0), ('col1', 'drop', 1)])
ct.fit(X)
assert ct.get_feature_names() == ['col0__a', 'col0__b']

# passthrough transformer
ct = ColumnTransformer([('trans', 'passthrough', [0, 1])])
ct.fit(X)
assert_raise_message(
NotImplementedError, 'get_feature_names is not yet supported',
ct.get_feature_names)
assert ct.get_feature_names() == ['x0', 'x1']

ct = ColumnTransformer([('trans', DictVectorizer(), 0)],
remainder='passthrough')
ct.fit(X)
assert_raise_message(
NotImplementedError, 'get_feature_names is not yet supported',
ct.get_feature_names)
assert ct.get_feature_names() == ['trans__a', 'trans__b', 'x1']

# drop transformer
ct = ColumnTransformer(
[('col0', DictVectorizer(), 0), ('col1', 'drop', 1)])
ct = ColumnTransformer([('trans', 'passthrough', [1])],
remainder='passthrough')
ct.fit(X)
assert ct.get_feature_names() == ['col0__a', 'col0__b']
assert ct.get_feature_names() == ['x1', 'x0']

ct = ColumnTransformer([('trans', 'passthrough', lambda x: [1])],
remainder='passthrough')
ct.fit(X)
assert ct.get_feature_names() == ['x1', 'x0']

ct = ColumnTransformer([('trans', 'passthrough', np.array([False, True]))],
remainder='passthrough')
ct.fit(X)
assert ct.get_feature_names() == ['x1', 'x0']

ct = ColumnTransformer([('trans', 'passthrough', slice(1, 2))],
remainder='passthrough')
ct.fit(X)
assert ct.get_feature_names() == ['x1', 'x0']


def test_column_transformer_get_feature_names_dataframe():
    """Check ``get_feature_names`` for passthrough columns of a DataFrame.

    When the transformer is fitted on a pandas DataFrame, passed-through
    columns are expected to be reported under their dataframe column
    names, whichever way the columns were selected (names, positions,
    callable, boolean mask or slice).
    """
    pd = pytest.importorskip('pandas')
    X = np.array([[{'a': 1, 'b': 2}, {'a': 3, 'b': 4}],
                  [{'c': 5}, {'c': 6}]], dtype=object).T
    X_df = pd.DataFrame(X, columns=['col0', 'col1'])

    def fitted_names(transformers, **ct_params):
        # Build, fit on the dataframe and return the reported names.
        ct = ColumnTransformer(transformers, **ct_params)
        ct.fit(X_df)
        return ct.get_feature_names()

    # Explicit passthrough of every column, selected by name ...
    names = fitted_names([('trans', 'passthrough', ['col0', 'col1'])])
    assert names == ['col0', 'col1']

    # ... and selected by position.
    names = fitted_names([('trans', 'passthrough', [0, 1])])
    assert names == ['col0', 'col1']

    # A real transformer on the first column, the rest passed through
    # by the remainder.
    names = fitted_names([('col0', DictVectorizer(), 0)],
                         remainder='passthrough')
    assert names == ['col0__a', 'col0__b', 'col1']

    # Every selector below picks only 'col1'; the remainder then appends
    # 'col0' after it.
    col1_selectors = [
        ['col1'],
        lambda x: x[['col1']].columns,
        np.array([False, True]),
        slice(1, 2),
        [1],
    ]
    for selector in col1_selectors:
        names = fitted_names([('trans', 'passthrough', selector)],
                             remainder='passthrough')
        assert names == ['col1', 'col0']


def test_column_transformer_special_strings():
Expand Down

0 comments on commit f9bbf80

Please sign in to comment.