# Sklearn Pipeline: Custom Transformer

In [35]:
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted
import pickle
import pandas as pd

### functions

In [36]:
def load_data():
    url = 'https://raw.githubusercontent.com/jmquintana79/Datasets/master/iris.csv'
    df = pd.read_csv(url, header = None)
    df.columns = ['x1','x2','x3','x4','y']
    return df

## Custom transformer

https://scikit-learn.org/stable/modules/generated/sklearn.utils.validation.check_is_fitted.html

In [37]:
class CTransformer(BaseEstimator, TransformerMixin):

    def __init__(self):
        self.dimputer = {}

    def fit(self, df:pd.DataFrame, y:pd.DataFrame=None):
        assert isinstance(df, pd.DataFrame)
        self.dimputer = df.groupby('y')[['x1']].mean().to_dict()['x1']
        return self

    def transform(self, df:pd.DataFrame):
        assert isinstance(df, pd.DataFrame)
        check_is_fitted(self)
        df['y'] = df['y'].apply(lambda x: self.dimputer[x])
        return df

In [38]:
# load data
df = load_data()
# initialice
ci = CTransformer()
# fit
ci.fit(df)
# transform
ci.transform(df)

NotFittedError: This CTransformer instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.

## Export / Import fitted transformer

In [39]:
# save the transformer to disk
filename = 'transformer.sav'
pickle.dump(ci, open(filename, 'wb'))

In [33]:
%ls

notebook-sklearn_pipeline_variables_selection-cat_num.ipynb
notetebook-sklearn_pipeline_custom_transformer.ipynb
transformer.sav


In [34]:
# load the model from disk
loaded_model = pickle.load(open(filename, 'rb'))
# load data
df = load_data()
# tranform
loaded_model.transform(df)

NameError: name 'check_is_fitted' is not defined

In [41]:
dir(ci)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_check_n_features',
 '_get_param_names',
 '_get_tags',
 '_more_tags',
 '_repr_html_',
 '_repr_html_inner',
 '_repr_mimebundle_',
 '_validate_data',
 'dimputer',
 'fit',
 'fit_transform',
 'get_params',
 'set_params',
 'transform']