In [None]:
from itertools import product, chain
import operator

class FeatureImputer(BaseEstimator, TransformerMixin):
    """Impute missing values of one column from per-group normal distributions.

    Rows are grouped by the values of ``features``.  During ``fit`` the mean
    and (population) standard deviation of the observed ``column`` values are
    recorded for every combination of feature values.  During ``transform``
    each missing ``column`` value is replaced by a draw from
    ``np.random.normal(mean, stdev)`` of the group the row belongs to.
    Combinations that have no observed values at fit time use the statistics
    of the whole column instead.

    Parameters
    ----------
    column : str
        Name of the column whose missing values are imputed.
    features : sequence of str
        Names of the columns that define the groups.

    Attributes
    ----------
    global_mid, global_mean
        Mean of all observed ``column`` values (set by ``fit``).  Both names
        hold the same value.
    global_stdev
        Standard deviation of all observed ``column`` values (set by ``fit``).
    prod_mids : dict
        Maps each tuple of feature values to its ``(mean, stdev)`` pair
        (set by ``fit``).
    """

    def __init__(self, column, features):
        self.column = column
        self.features = features
        # Kept so code reading this attribute keeps working; the masks are now
        # built from direct comparisons rather than from this query string.
        self.query_template = ' & '.join(['{} == "{}"'] * len(features))

    def fit(self, X, y=None):
        """Learn the global and per-group mean / standard deviation.

        Parameters
        ----------
        X : DataFrame
            Must contain ``self.column`` and every column in ``self.features``.
        y : ignored
            Present for pipeline compatibility.

        Returns
        -------
        self
        """
        column = self.column
        observed = X.loc[X[column].notnull(), column]
        self.global_mid = np.mean(observed)
        # The fallback used to read ``global_mean`` although only
        # ``global_mid`` was ever assigned; expose both names.
        self.global_mean = self.global_mid
        self.global_stdev = np.std(observed)

        # Distinct non-null values per feature, in order of first appearance
        # so the iteration order (and hence the order of random draws in
        # ``transform``) does not depend on set hashing.  NaN feature values
        # are dropped because an equality comparison can never match them.
        feature_values = [list(X[feat].dropna().unique()) for feat in self.features]

        prod_mids = {}
        for prod in product(*feature_values):
            selection = self.select_query(X, prod)
            if len(selection) == 0:
                # No observed value for this combination: use global stats.
                stats = (self.global_mid, self.global_stdev)
            else:
                stats = (np.mean(selection), np.std(selection))
            # Store inside the loop so every combination keeps its own stats.
            prod_mids[prod] = stats

        self.prod_mids = prod_mids
        return self

    def transform(self, X):
        """Fill missing ``column`` values with draws from the fitted groups.

        ``X`` is modified in place and also returned.  Rows whose feature
        combination is not a key of ``prod_mids`` are left unchanged.

        Parameters
        ----------
        X : DataFrame

        Returns
        -------
        DataFrame
            The same object that was passed in.
        """
        for prod, (prod_mean, prod_stdev) in self.prod_mids.items():
            self.apply_query(X, prod, prod_mean, prod_stdev)
        return X

    def make_query_mask(self, X, prod, filter_nans=True):
        """Return a boolean Series selecting the rows of one feature group.

        Parameters
        ----------
        X : DataFrame
        prod : tuple
            One value per entry of ``self.features``, in the same order.
        filter_nans : bool, default True
            When true, rows whose ``column`` value is missing are excluded.

        Returns
        -------
        Series of bool, indexed like ``X``.
        """
        observed = X[self.column].notnull()
        # ``observed | ~observed`` is an all-True Series sharing X's index.
        mask = observed if filter_nans else (observed | ~observed)
        # Compare values directly instead of formatting them into a query
        # string, so numeric values and arbitrary column names work.
        for feat, value in zip(self.features, prod):
            mask = mask & (X[feat] == value)
        return mask

    def select_query(self, X, prod):
        """Return the observed (non-null) ``column`` values of one group."""
        mask = self.make_query_mask(X, prod)
        return X.loc[mask, self.column]

    def apply_query(self, X, prod, prod_mean, prod_stdev):
        """Fill, in place, the missing ``column`` values of one group.

        Each missing value receives an independent draw from
        ``np.random.normal(prod_mean, prod_stdev)``.
        """
        group = self.make_query_mask(X, prod, filter_nans=False)
        missing = group & X[self.column].isnull()
        n_missing = int(missing.sum())
        if n_missing:
            # One vectorised draw instead of one call per row.
            X.loc[missing, self.column] = np.random.normal(
                prod_mean, prod_stdev, size=n_missing
            )
    