# Sklearn Pipeline: Custom Transformer

In [11]:
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted
import pickle
import pandas as pd

## Helper functions

In [12]:
def load_data(source='https://raw.githubusercontent.com/jmquintana79/Datasets/master/iris.csv'):
    """Load the headerless Iris CSV into a DataFrame with columns x1..x4 and y.

    Parameters
    ----------
    source : str, path or file-like, optional
        Anything accepted by ``pd.read_csv``. Defaults to the Iris CSV hosted
        on GitHub, so existing ``load_data()`` calls behave as before.

    Returns
    -------
    pd.DataFrame
        Five columns renamed to ``x1``-``x4`` (features) and ``y`` (label).
        Raises ValueError from pandas if the file does not have exactly five columns.
    """
    # The file has no header row, so every line is data.
    df = pd.read_csv(source, header=None)
    df.columns = ['x1', 'x2', 'x3', 'x4', 'y']
    return df

## Custom transformer

In [13]:
class CTransformer(BaseEstimator, TransformerMixin):
    """Replace each label in column ``y`` with the mean of ``x1`` for that label.

    ``fit`` learns a ``{label: mean(x1)}`` mapping from a DataFrame that has
    columns ``x1`` and ``y``. ``transform`` returns a copy of its input in
    which every ``y`` value has been replaced by the learned mean.

    Attributes
    ----------
    dimputer : dict
        Mapping label -> mean of ``x1`` over the rows with that label
        (empty until ``fit`` is called).
    is_fitted : bool
        True once ``fit`` has been called.
    """

    def __init__(self):
        self.dimputer = {}
        self.is_fitted = False

    def fit(self, df: pd.DataFrame, y=None):
        """Learn the per-label mean of ``x1``.

        Parameters
        ----------
        df : pd.DataFrame
            Must contain columns ``x1`` and ``y``.
        y : ignored
            Accepted (and unused) because ``TransformerMixin.fit_transform``
            and ``Pipeline`` pass a target positionally as ``fit(X, y)``.

        Returns
        -------
        CTransformer
            ``self``, to allow chaining.
        """
        assert isinstance(df, pd.DataFrame)
        self.dimputer = df.groupby('y')['x1'].mean().to_dict()
        self.is_fitted = True
        return self

    def transform(self, df: pd.DataFrame):
        """Return a copy of ``df`` with ``y`` replaced by the learned means.

        Raises
        ------
        AssertionError
            If ``df`` is not a DataFrame or ``fit`` has not been called.
        KeyError
            If ``df['y']`` contains a label that was not seen during ``fit``.
        """
        assert isinstance(df, pd.DataFrame)
        assert self.is_fitted
        # Work on a copy: overwriting the caller's 'y' column would make a
        # second transform of the same frame fail, because its labels would
        # already have been replaced by floats that are not keys of dimputer.
        out = df.copy()
        out['y'] = out['y'].apply(lambda label: self.dimputer[label])
        return out

In [14]:
# load data
df = load_data()
# initialise and fit the transformer
ci = CTransformer()
ci.fit(df)
# transform a defensive copy so `df` keeps its original 'y' labels
# whether or not `transform` modifies the frame it receives
ci.transform(df.copy())

Unnamed: 0,x1,x2,x3,x4,y
0,5.1,3.5,1.4,0.2,5.006
1,4.9,3.0,1.4,0.2,5.006
2,4.7,3.2,1.3,0.2,5.006
3,4.6,3.1,1.5,0.2,5.006
4,5.0,3.6,1.4,0.2,5.006
...,...,...,...,...,...
145,6.7,3.0,5.2,2.3,6.588
146,6.3,2.5,5.0,1.9,6.588
147,6.5,3.0,5.2,2.0,6.588
148,6.2,3.4,5.4,2.3,6.588


## Export / import the fitted transformer

In [15]:
# save the fitted transformer to disk
filename = 'transformer.sav'
# `with` guarantees the file is flushed and closed once the object is written
with open(filename, 'wb') as fh:
    pickle.dump(ci, fh)

In [16]:
# IPython magic: list the working directory to confirm transformer.sav was written
%ls

notebook-sklearn_pipeline_variables_selection-cat_num.ipynb
notetebook-sklearn_pipeline_custom_transformer.ipynb
transformer.sav


In [17]:
# load the fitted transformer back from disk
# NOTE: pickle.load can execute arbitrary code; only unpickle files you
# trust, such as the one written by the export cell above.
with open(filename, 'rb') as fh:
    loaded_model = pickle.load(fh)
# load a fresh copy of the data
df = load_data()
# transform
loaded_model.transform(df)

Unnamed: 0,x1,x2,x3,x4,y
0,5.1,3.5,1.4,0.2,5.006
1,4.9,3.0,1.4,0.2,5.006
2,4.7,3.2,1.3,0.2,5.006
3,4.6,3.1,1.5,0.2,5.006
4,5.0,3.6,1.4,0.2,5.006
...,...,...,...,...,...
145,6.7,3.0,5.2,2.3,6.588
146,6.3,2.5,5.0,1.9,6.588
147,6.5,3.0,5.2,2.0,6.588
148,6.2,3.4,5.4,2.3,6.588
