In [1]:
import joblib
import numpy as np
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.preprocessing import OneHotEncoder

In [2]:
class MyTransformer(TransformerMixin, BaseEstimator):
    """One-hot encode categorical columns, keeping only their most frequent values.

    For every ``object``/``category`` column seen in ``fit``, each of the
    ``top_cats_n`` most frequent values gets its own indicator column. Every
    other value is bucketed into a shared ``<col>_other`` indicator. This
    includes values first seen in ``transform`` and missing values. Columns
    not selected in ``fit`` pass through unchanged.

    Parameters
    ----------
    top_cats_n : int
        Number of most-frequent categories to keep per categorical column.

    Attributes
    ----------
    cols_ : list
        Names of the categorical columns found in ``fit``.
    <col>_top_cats_, <col>_ohe_, <col>_feature_names_
        Per column: the kept categories (including the "other" bucket), the
        fitted ``OneHotEncoder`` and its output column names. These attribute
        names are kept unchanged so that previously pickled instances still
        work with this class.
    """

    # Catch-all label for values outside a column's kept top categories.
    _OTHER = "other"

    def __init__(self, *, top_cats_n):
        self.top_cats_n = top_cats_n

    def fit(self, X, y=None):
        """Learn the top categories and a one-hot encoder for each categorical column.

        Parameters
        ----------
        X : pandas.DataFrame
            Training data; columns of dtype ``object`` or ``category`` are encoded.
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        self
        """
        self.cols_ = X.select_dtypes(include=["object", "category"]).columns.tolist()
        for col in self.cols_:
            top_cats = list(X[col].value_counts().nlargest(self.top_cats_n).index)
            # Append the catch-all bucket only once. If a real value is literally
            # "other" and ranks among the top categories, appending it again would
            # create a duplicate category: a never-set indicator column plus a
            # clashing "<col>_other" feature name.
            if self._OTHER not in top_cats:
                top_cats.append(self._OTHER)
            ohe = OneHotEncoder(
                categories=[top_cats],
                # The raw training column still contains values outside `top_cats`.
                # 'ignore' lets fit accept them; transform maps them to "other"
                # before encoding anyway.
                handle_unknown='ignore',
                sparse_output=False,
            )
            ohe.fit(X[[col]])
            setattr(self, f"{col}_top_cats_", top_cats)
            setattr(self, f"{col}_ohe_", ohe)
            setattr(self, f"{col}_feature_names_", ohe.get_feature_names_out([col]))
        return self

    def transform(self, X):
        """Replace each fitted categorical column by its one-hot indicator columns.

        Parameters
        ----------
        X : pandas.DataFrame
            Must contain every column listed in ``cols_``.

        Returns
        -------
        pandas.DataFrame
            A new frame; ``X`` itself is not modified. The non-encoded columns
            come first in their original order. The indicator blocks follow in
            ``cols_`` order. All columns share ``X``'s index.
        """
        encoded_blocks = []
        for col in self.cols_:
            top_cats = getattr(self, f"{col}_top_cats_")
            ohe = getattr(self, f"{col}_ohe_")
            # Cast to object first so that a `category`-dtype column can take the
            # "other" label even when it is not one of its declared categories.
            values = X[col].astype(object)
            bucketed = values.where(values.isin(top_cats), other=self._OTHER)
            encoded_blocks.append(
                pd.DataFrame(
                    ohe.transform(bucketed.to_frame(name=col)),
                    columns=getattr(self, f"{col}_feature_names_"),
                    index=X.index,
                )
            )
        # One concat at the end instead of growing the frame once per column.
        return pd.concat([X.drop(columns=self.cols_)] + encoded_blocks, axis=1)

In [3]:
##### DF1 #####
# Synthetic demo data: 10 Gaussian numeric columns plus 2 skewed categorical columns.

n_rows = 100
# Legacy global-state seeding. Both the seed and the order of the draws below
# must stay fixed for the displayed outputs to be reproducible.
np.random.seed(42)

# num_k ~ Normal(mean=5 * (k - 1), std=k) for k = 1..10, drawn in column order.
num_cols = {
    f"num_{k}": np.random.normal(loc=(k - 1) * 5, scale=k, size=n_rows)
    for k in range(1, 11)
}

# Both categoricals share the same 40/30/20/10 % frequency profile.
cat_1 = np.random.choice(['A', 'B', 'C', 'G'], size=n_rows, p=[0.4, 0.3, 0.2, 0.1])
cat_2 = np.random.choice(['X', 'Y', 'Z', 'W'], size=n_rows, p=[0.4, 0.3, 0.2, 0.1])

# Column order: num_1..num_10, then cat_1, cat_2.
df1 = pd.DataFrame(num_cols).assign(cat_1=cat_1, cat_2=cat_2)

In [4]:
# Preview the first 20 rows of the raw synthetic frame.
df1.iloc[:20]

Unnamed: 0,num_1,num_2,num_3,num_4,num_5,num_6,num_7,num_8,num_9,num_10,cat_1,cat_2
0,0.496714,2.169259,11.073362,11.68402,12.027862,30.557065,35.29892,30.818216,48.444554,48.686733,A,Y
1,-0.138264,4.158709,11.682354,12.759276,17.003125,36.4565,23.544843,43.392074,35.355597,41.066612,A,X
2,0.647689,4.314571,13.249154,17.989174,20.026218,16.608595,36.087241,29.36525,40.865087,45.287448,B,Y
3,1.52303,3.395445,13.161406,17.441481,20.234903,28.377815,39.489465,23.73231,35.839522,57.784519,C,Z
4,-0.234153,4.677429,5.866992,14.916394,17.749673,21.096145,32.894044,22.546967,36.089534,46.910991,A,X
5,-0.234137,5.808102,7.186525,15.46931,23.11425,22.077248,43.137571,39.84808,37.217451,45.464365,G,X
6,1.579213,8.772372,11.545106,20.11066,14.661898,21.445636,24.583476,24.756565,41.999204,31.401439,A,X
7,0.767435,5.349156,11.541358,12.633714,19.288103,19.816055,21.287417,49.038353,35.691262,52.462536,B,Z
8,-0.469474,5.515101,11.545143,17.18839,20.601478,25.29113,17.548958,18.344565,51.301805,51.454842,C,Y
9,0.54256,4.851108,21.558194,14.191229,22.572194,20.014299,40.47231,48.571651,31.948534,66.632547,C,Y


In [5]:
# Apply a transformer that was fitted outside this notebook to df1.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code.
# Only load 'my_pipeline.joblib' from a trusted source. No cell in this notebook
# creates the artifact, so Restart & Run All needs it to already exist in the
# working directory. TODO: document how and on what data it was fitted.
transformer = joblib.load('my_pipeline.joblib')
df1_transformed = transformer.transform(df1)

# Fail loudly on a stale or mismatched artifact. The output must keep df1's rows
# and index, and no object/category columns may survive the encoding.
assert df1_transformed.index.equals(df1.index), "transform changed the row index"
assert df1_transformed.select_dtypes(include=["object", "category"]).empty, \
    "unencoded categorical columns remain after transform"


In [6]:
# Preview the first 20 encoded rows.
df1_transformed.iloc[:20]

Unnamed: 0,num_1,num_2,num_3,num_4,num_5,num_6,num_7,num_8,num_9,num_10,cat_1_A,cat_1_B,cat_1_C,cat_1_other,cat_2_Y,cat_2_X,cat_2_Z,cat_2_other
0,0.496714,2.169259,11.073362,11.68402,12.027862,30.557065,35.29892,30.818216,48.444554,48.686733,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0
1,-0.138264,4.158709,11.682354,12.759276,17.003125,36.4565,23.544843,43.392074,35.355597,41.066612,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
2,0.647689,4.314571,13.249154,17.989174,20.026218,16.608595,36.087241,29.36525,40.865087,45.287448,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0
3,1.52303,3.395445,13.161406,17.441481,20.234903,28.377815,39.489465,23.73231,35.839522,57.784519,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0
4,-0.234153,4.677429,5.866992,14.916394,17.749673,21.096145,32.894044,22.546967,36.089534,46.910991,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
5,-0.234137,5.808102,7.186525,15.46931,23.11425,22.077248,43.137571,39.84808,37.217451,45.464365,0.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0
6,1.579213,8.772372,11.545106,20.11066,14.661898,21.445636,24.583476,24.756565,41.999204,31.401439,1.0,0.0,0.0,0.0,0.0,1.0,0.0,0.0
7,0.767435,5.349156,11.541358,12.633714,19.288103,19.816055,21.287417,49.038353,35.691262,52.462536,0.0,1.0,0.0,0.0,0.0,0.0,1.0,0.0
8,-0.469474,5.515101,11.545143,17.18839,20.601478,25.29113,17.548958,18.344565,51.301805,51.454842,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
9,0.54256,4.851108,21.558194,14.191229,22.572194,20.014299,40.47231,48.571651,31.948534,66.632547,0.0,0.0,1.0,0.0,1.0,0.0,0.0,0.0
