/
sklearn.py
58 lines (46 loc) · 1.85 KB
/
sklearn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# -*- coding: utf-8 -*-
from typing import List
import pandas as pd
from pandas_select.base import PrettyPrinter
from pandas_select.label import LabelSelector
class ColumnSelector(PrettyPrinter):
    """Create a callable compatible with :class:`sklearn.compose.ColumnTransformer`.

    Parameters
    ----------
    selector:
        A label selector, i.e. a :func:`callable` that returns a list of strings.
        If the selector exposes an ``axis`` attribute, it must be ``1`` or
        ``"columns"``; callables without an ``axis`` attribute (e.g. plain
        functions) are accepted without that check.

    Raises
    ------
    ValueError:
        If `selector` is not a callable or doesn't target the "columns" axis.

    Examples
    --------
    >>> from pandas_select import AnyOf, AllBool, AllNominal, AllNumeric, ColumnSelector
    >>> from sklearn.compose import make_column_transformer
    >>> from sklearn.preprocessing import OneHotEncoder, StandardScaler
    >>> make_column_transformer(
    ...     (StandardScaler(), ColumnSelector(AllNumeric() & ~AnyOf("Generation"))),
    ...     (OneHotEncoder(), ColumnSelector(AllNominal() | AllBool() | "Generation"))
    ... )
    """

    def __init__(self, selector: LabelSelector):
        self.selector = selector
        if not callable(selector):
            raise ValueError(f"{selector} is not a callable.")
        # Guard only the attribute lookup: a missing ``axis`` means the
        # selector carries no axis information, so the axis check is skipped.
        try:
            axis = selector.axis
        except AttributeError:
            return
        if axis not in {1, "columns"}:
            raise ValueError(
                f"Cannot make a ColumnSelector from {selector}"
                + ", which does not target the column axis."
            )

    def __call__(self, df: pd.DataFrame) -> List[str]:
        """Return the column labels of `df` selected by the wrapped selector.

        Parameters
        ----------
        df:
            The DataFrame whose columns are selected.

        Returns
        -------
        list
            The selected labels as a plain list. A result with a ``tolist``
            method (e.g. a ``pandas.Index``) is converted with ``tolist()``,
            a single string label is wrapped in a one-element list, and any
            other iterable is materialized with ``list()``.

        Raises
        ------
        ValueError:
            If `df` is not a :class:`pandas.DataFrame`.
        """
        if not isinstance(df, pd.DataFrame):
            raise ValueError("ColumnSelector can only be applied to a DataFrame.")
        cols = self.selector(df)
        if isinstance(cols, str):
            # list("abc") would split a single label into its characters.
            return [cols]
        # LabelSelector may return a pandas.Index. Only the attribute lookup is
        # guarded, so an AttributeError raised *inside* a tolist() implementation
        # propagates instead of being mistaken for "no tolist method".
        try:
            to_list = cols.tolist  # type: ignore
        except AttributeError:
            return list(cols)
        return to_list()