# Creating a custom imputer in scikit-learn

## Setup

In [1]:
import pandas as pd 
import numpy as np
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.utils.validation import check_is_fitted

In [2]:
import sklearn

# report the installed scikit-learn version (results may differ across versions)
print('The scikit-learn version is {}.'.format(sklearn.__version__))

The scikit-learn version is 0.23.2.


## Preparing data

In [3]:
# fix the global NumPy seed so the generated data is reproducible
np.random.seed(42)

# sample A: both variants, heights drawn around 170 (scale 5)
# (draw order matters: it decides which random numbers each column receives)
variants_a = np.random.choice(['a', 'b'], size=100)
heights_a = np.random.normal(loc=170, scale=5, size=100)
sample_a = pd.DataFrame({'sample_name': 'A',
                         'variant': variants_a,
                         'height': heights_a})

# sample B: only variant 'a', heights drawn around 165 (scale 7)
heights_b = np.random.normal(loc=165, scale=7, size=100)
sample_b = pd.DataFrame({'sample_name': 'B',
                         'variant': 'a',
                         'height': heights_b})

# stack both samples, shuffle the rows, and renumber the index from 0
df = (
    pd.concat([sample_a, sample_b], axis=0, ignore_index=True)
    .sample(frac=1)
    .reset_index(drop=True)
)

# preview the data
df.head()

Unnamed: 0,sample_name,variant,height
0,B,a,171.678012
1,B,a,166.292437
2,A,a,168.286427
3,B,a,158.773399
4,B,a,155.756804


In [4]:
# blank out the height of 10 distinct, randomly chosen rows; df carries a
# 0..n-1 index after reset_index(drop=True), so positions double as labels
ind_to_replace = np.random.choice(np.arange(len(df)), size=10, replace=False)
df.loc[ind_to_replace, 'height'] = np.nan

In [5]:
# display the data after 10 heights were replaced with NaN
df

Unnamed: 0,sample_name,variant,height
0,B,a,171.678012
1,B,a,166.292437
2,A,a,168.286427
3,B,a,158.773399
4,B,a,155.756804
...,...,...,...
195,B,a,178.060422
196,A,b,167.482622
197,B,a,169.803821
198,A,a,


## Creating the custom imputer

In [6]:
class GroupImputer(BaseEstimator, TransformerMixin):
    '''
    Impute missing values of one column of a pd.DataFrame with the mean or
    median of that column computed within groups.

    Parameters
    ----------
    group_cols : list
        List of columns whose value combinations define the groups used for
        calculating the aggregated value.
    target : str
        The name of the column to impute.
    metric : str, default='mean'
        The metric to be used for replacement, one of ['mean', 'median'].

    Attributes
    ----------
    impute_map_ : pd.DataFrame
        Set by ``fit``: one row per group, holding the ``group_cols`` values
        and the aggregated ``target`` value of that group.

    Notes
    -----
    Following the scikit-learn estimator convention, ``__init__`` only
    stores the parameters; they are validated in ``fit``. This way values set
    later through ``set_params`` (e.g. during a grid search) are checked too.
    '''
    def __init__(self, group_cols, target, metric='mean'):
        # store parameters untouched; scikit-learn's get_params/clone rely on it
        self.group_cols = group_cols
        self.target = target
        self.metric = metric

    def _check_params(self):
        '''Raise if any constructor parameter has an invalid value.'''
        # explicit raises instead of assert: asserts vanish under `python -O`
        if self.metric not in ['mean', 'median']:
            raise ValueError('Unrecognized value for metric, should be mean/median')
        if not isinstance(self.group_cols, list):
            raise TypeError('group_cols should be a list of columns')
        if not isinstance(self.target, str):
            raise TypeError('target should be a string')

    def fit(self, X, y=None):
        '''
        Learn the aggregated ``target`` value of every group in ``X``.

        Parameters
        ----------
        X : pd.DataFrame
            Must contain ``group_cols`` and ``target``.
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        self : GroupImputer

        Raises
        ------
        ValueError
            If ``metric`` is not 'mean'/'median', or ``group_cols`` contains
            missing values in ``X``.
        TypeError
            If ``group_cols`` is not a list or ``target`` is not a string.
        '''
        self._check_params()

        if X[self.group_cols].isnull().values.any():
            raise ValueError('There are missing values in group_cols')

        self.impute_map_ = (
            X.groupby(self.group_cols)[self.target]
            .agg(self.metric)
            .reset_index(drop=False)
        )

        return self

    def transform(self, X, y=None):
        '''
        Fill missing ``target`` values of ``X`` with the learned group values.

        Parameters
        ----------
        X : pd.DataFrame
            Must contain ``group_cols`` and ``target``. It is not modified.
        y : ignored
            Present for scikit-learn API compatibility.

        Returns
        -------
        X : np.ndarray
            ``X.values`` of a copy of the input with imputed ``target``
            values (object dtype when the columns have mixed types).
        '''
        # make sure that the imputer was fitted
        check_is_fitted(self, 'impute_map_')

        X = X.copy()

        # For every group learned in fit, fill only the missing target values
        # of rows whose group_cols match that group; rows belonging to groups
        # not seen in fit are left as they are.
        for _, row in self.impute_map_.iterrows():
            ind = (X[self.group_cols] == row[self.group_cols]).all(axis=1)
            X.loc[ind, self.target] = X.loc[ind, self.target].fillna(row[self.target])

        return X.values

In [7]:
# impute missing heights with the mean height of each (sample_name, variant) group
imp = GroupImputer(group_cols=['sample_name', 'variant'],
                   target='height',
                   metric='mean')

# the imputer returns a NumPy array (X.values), so rebuild a DataFrame
# with the original column names
df_imp = pd.DataFrame(imp.fit_transform(df),
                      columns=df.columns)

# count missing heights before and after imputation (vectorised pandas count)
print(f"df contains {df['height'].isnull().sum()} missing values.")
print(f"df_imp contains {df_imp['height'].isnull().sum()} missing values.")

df contains 10 missing values.
df_imp contains 0 missing values.


In [8]:
# display the imputed data; previously missing heights now hold group means
# NOTE(review): df_imp was built from a NumPy array of mixed columns, so
# height is presumably object dtype here — confirm if numeric ops are needed
df_imp

Unnamed: 0,sample_name,variant,height
0,B,a,171.678
1,B,a,166.292
2,A,a,168.286
3,B,a,158.773
4,B,a,155.757
...,...,...,...
195,B,a,178.06
196,A,b,167.483
197,B,a,169.804
198,A,a,170.362
