In [14]:
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd
import warnings
import matplotlib.pyplot as plt
import matplotlib as mpl
import seaborn as sns

from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.preprocessing import PolynomialFeatures, StandardScaler, OneHotEncoder
from sklearn.compose import (
    make_column_transformer,
    TransformedTargetRegressor,
    make_column_selector,
)
from sklearn.utils import Bunch
from sklearn.inspection import permutation_importance
from sklearn.feature_selection import SequentialFeatureSelector
from sklearn.datasets import fetch_openml
from sklearn.metrics import mean_squared_error

from IPython.display import Image

In [15]:
# Notebook-wide display settings.
pd.set_option("display.max_columns", None)  # render every column of wide frames
mpl.rcParams["axes.grid"] = True  # draw grid lines on all matplotlib axes
# NOTE(review): this silences *every* warning, including library deprecation
# notices; consider narrowing it to specific categories/modules.
warnings.filterwarnings("ignore")

In [16]:
# Fetch OpenML dataset 534 as a pandas-backed Bunch and keep only its combined
# features + target DataFrame (the wage data previewed below).
df = fetch_openml(as_frame=True, data_id=534)["frame"]

In [17]:
# Preview the first five rows of the raw frame.
df.head(n=5)

Unnamed: 0,EDUCATION,SOUTH,SEX,EXPERIENCE,UNION,WAGE,AGE,RACE,OCCUPATION,SECTOR,MARR
0,8,no,female,21,not_member,5.1,35,Hispanic,Other,Manufacturing,Married
1,9,no,female,42,not_member,4.95,57,White,Other,Manufacturing,Married
2,12,no,male,1,not_member,6.67,19,White,Other,Manufacturing,Unmarried
3,12,no,male,4,not_member,4.0,22,White,Other,Other,Unmarried
4,12,no,male,17,not_member,7.5,35,White,Other,Other,Married


In [18]:
target_feature = "WAGE"
# Partition columns by dtype: pandas "category" columns will be one-hot encoded;
# every other column except the target is treated as a numeric input.
numeric_features = df.select_dtypes(exclude="category").columns.to_list()
numeric_features.remove(target_feature)  # ValueError if the target is not among them
one_hot_features = df.select_dtypes(include="category").columns.to_list()
display([numeric_features, one_hot_features])

[['EDUCATION', 'EXPERIENCE', 'AGE'],
 ['SOUTH', 'SEX', 'UNION', 'RACE', 'OCCUPATION', 'SECTOR', 'MARR']]

In [19]:
# Prefix each input column with its role ("numeric_" / "category_") so the role
# stays visible in the transformer-generated feature names further down.
# A prefix is only added when it is missing. Re-running this cell is therefore
# a no-op instead of producing "numeric_numeric_..." lists that no longer match
# the columns of `df`.
def _with_prefix(names, prefix):
    """Return ``names`` with ``prefix`` prepended to every name that lacks it."""
    return [name if name.startswith(prefix) else prefix + name for name in names]


numeric_features_prefix = _with_prefix(numeric_features, "numeric_")
one_hot_features_prefix = _with_prefix(one_hot_features, "category_")
# The two name groups are disjoint, so a single merged mapping renames both at once.
df = df.rename(
    columns={
        **dict(zip(one_hot_features, one_hot_features_prefix)),
        **dict(zip(numeric_features, numeric_features_prefix)),
    }
)

numeric_features = numeric_features_prefix
one_hot_features = one_hot_features_prefix

In [20]:
# Separate the target column from the model inputs.
y = df[target_feature]
X = df.drop(columns=[target_feature])

In [23]:
# Categorical branch: one-hot encode the "category_" columns. drop="if_binary"
# keeps a single indicator for two-level columns (e.g. SEX -> SEX_male below).
# sparse_threshold=0 forces a dense result. The DataFrame built below and the
# later FeatureUnion table both need a dense array, and without this setting
# that would depend on the encoded matrix not being sparse enough to fall
# below ColumnTransformer's default threshold.
ohe_pipe = Pipeline(
    [
        (
            "ohe",
            make_column_transformer(
                (
                    OneHotEncoder(drop="if_binary"),
                    one_hot_features,
                ),
                sparse_threshold=0,
            ),
        ),
    ]
).fit(X, y)

display(ohe_pipe)
pd.DataFrame(ohe_pipe.transform(X), columns=ohe_pipe.get_feature_names_out())

Unnamed: 0,onehotencoder__category_SOUTH_yes,onehotencoder__category_SEX_male,onehotencoder__category_UNION_not_member,onehotencoder__category_RACE_Hispanic,onehotencoder__category_RACE_Other,onehotencoder__category_RACE_White,onehotencoder__category_OCCUPATION_Clerical,onehotencoder__category_OCCUPATION_Management,onehotencoder__category_OCCUPATION_Other,onehotencoder__category_OCCUPATION_Professional,onehotencoder__category_OCCUPATION_Sales,onehotencoder__category_OCCUPATION_Service,onehotencoder__category_SECTOR_Construction,onehotencoder__category_SECTOR_Manufacturing,onehotencoder__category_SECTOR_Other,onehotencoder__category_MARR_Unmarried
0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
1,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
2,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0
3,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0
4,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
529,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,1.0
530,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0
531,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0
532,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0


In [24]:
# Numeric branch: standardize the raw numeric columns, expand them into
# polynomial terms (squares and pairwise products, no bias column), then
# re-standardize the expanded terms.
poly_pipe = Pipeline(
    steps=[
        ("prescale", make_column_transformer((StandardScaler(), numeric_features))),
        ("poly", PolynomialFeatures(include_bias=False)),
        ("postscale", StandardScaler()),
    ]
)
poly_pipe.fit(X, y)

display(poly_pipe)
pd.DataFrame(poly_pipe.transform(X), columns=poly_pipe.get_feature_names_out())

Unnamed: 0,standardscaler__numeric_EDUCATION,standardscaler__numeric_EXPERIENCE,standardscaler__numeric_AGE,standardscaler__numeric_EDUCATION^2,standardscaler__numeric_EDUCATION standardscaler__numeric_EXPERIENCE,standardscaler__numeric_EDUCATION standardscaler__numeric_AGE,standardscaler__numeric_EXPERIENCE^2,standardscaler__numeric_EXPERIENCE standardscaler__numeric_AGE,standardscaler__numeric_AGE^2
0,-1.920733,0.256943,-0.156487,1.600919,-0.114953,0.399838,-0.735761,-0.846160,-0.820536
1,-1.538020,1.954858,1.721353,0.812900,-2.166087,-2.216164,2.222667,1.983774,1.651194
2,-0.389880,-1.360118,-1.522189,-0.504820,0.720654,0.659750,0.669542,0.907846,1.107822
3,-0.389880,-1.117559,-1.266119,-0.504820,0.643469,0.571159,0.196106,0.363175,0.507253
4,-0.389880,-0.066469,-0.156487,-0.504820,0.308999,0.187262,-0.784289,-0.804100,-0.820536
...,...,...,...,...,...,...,...,...,...
529,1.906399,-1.036706,-0.668625,1.568262,-1.325231,-0.997976,0.058893,-0.236681,-0.465097
530,-0.389880,1.227180,1.209215,-0.504820,-0.102656,-0.285226,0.398589,0.420486,0.388773
531,1.523686,0.580356,0.953146,0.786774,1.009579,1.421844,-0.522439,-0.353032,-0.076974
532,-0.389880,-0.389881,-0.497912,-0.504820,0.411913,0.305384,-0.668023,-0.651414,-0.632603


In [31]:
# Stack the numeric polynomial branch and the one-hot branch side by side.
# Output feature names are prefixed with the branch names "poly__" / "ohe__".
estimator = FeatureUnion(
    transformer_list=[("poly", poly_pipe), ("ohe", ohe_pipe)],
)
estimator.fit(X, y)

display(estimator)

In [33]:
# Tabulate the combined feature matrix under its generated column names.
pd.DataFrame(
    data=estimator.transform(X),
    columns=estimator.get_feature_names_out(),
)

Unnamed: 0,poly__standardscaler__numeric_EDUCATION,poly__standardscaler__numeric_EXPERIENCE,poly__standardscaler__numeric_AGE,poly__standardscaler__numeric_EDUCATION^2,poly__standardscaler__numeric_EDUCATION standardscaler__numeric_EXPERIENCE,poly__standardscaler__numeric_EDUCATION standardscaler__numeric_AGE,poly__standardscaler__numeric_EXPERIENCE^2,poly__standardscaler__numeric_EXPERIENCE standardscaler__numeric_AGE,poly__standardscaler__numeric_AGE^2,ohe__onehotencoder__category_SOUTH_yes,ohe__onehotencoder__category_SEX_male,ohe__onehotencoder__category_UNION_not_member,ohe__onehotencoder__category_RACE_Hispanic,ohe__onehotencoder__category_RACE_Other,ohe__onehotencoder__category_RACE_White,ohe__onehotencoder__category_OCCUPATION_Clerical,ohe__onehotencoder__category_OCCUPATION_Management,ohe__onehotencoder__category_OCCUPATION_Other,ohe__onehotencoder__category_OCCUPATION_Professional,ohe__onehotencoder__category_OCCUPATION_Sales,ohe__onehotencoder__category_OCCUPATION_Service,ohe__onehotencoder__category_SECTOR_Construction,ohe__onehotencoder__category_SECTOR_Manufacturing,ohe__onehotencoder__category_SECTOR_Other,ohe__onehotencoder__category_MARR_Unmarried
0,-1.920733,0.256943,-0.156487,1.600919,-0.114953,0.399838,-0.735761,-0.846160,-0.820536,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
1,-1.538020,1.954858,1.721353,0.812900,-2.166087,-2.216164,2.222667,1.983774,1.651194,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
2,-0.389880,-1.360118,-1.522189,-0.504820,0.720654,0.659750,0.669542,0.907846,1.107822,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,1.0
3,-0.389880,-1.117559,-1.266119,-0.504820,0.643469,0.571159,0.196106,0.363175,0.507253,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0
4,-0.389880,-0.066469,-0.156487,-0.504820,0.308999,0.187262,-0.784289,-0.804100,-0.820536,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,1.0,0.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
529,1.906399,-1.036706,-0.668625,1.568262,-1.325231,-0.997976,0.058893,-0.236681,-0.465097,0.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,1.0
530,-0.389880,1.227180,1.209215,-0.504820,-0.102656,-0.285226,0.398589,0.420486,0.388773,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0
531,1.523686,0.580356,0.953146,0.786774,1.009579,1.421844,-0.522439,-0.353032,-0.076974,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0
532,-0.389880,-0.389881,-0.497912,-0.504820,0.411913,0.305384,-0.668023,-0.651414,-0.632603,1.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0
