In [45]:
from sklearn.base import TransformerMixin
from sklearn.datasets import make_regression
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.linear_model import LinearRegression, Ridge

In [46]:
class RidgeTransformer(Ridge, TransformerMixin):

    def transform(self, X, *_):
        return self.predict(X)


class RandomForestTransformer(RandomForestRegressor, TransformerMixin):

    def transform(self, X, *_):
        return self.predict(X)


class KNeighborsTransformer(KNeighborsRegressor, TransformerMixin):

    def transform(self, X, *_):
        return self.predict(X)

In [47]:
def build_model():
    ridge_transformer = Pipeline(steps=[
        ('scaler', StandardScaler()),
        ('poly_feats', PolynomialFeatures()),
        ('ridge', RidgeTransformer())
    ])

    pred_union = FeatureUnion(
        transformer_list=[
            ('ridge', ridge_transformer),
            ('rand_forest', RandomForestTransformer()),
            ('knn', KNeighborsTransformer())
        ],
        n_jobs=2
    )

    model = Pipeline(steps=[
        ('pred_union', pred_union),
        ('lin_regr', LinearRegression())
    ])

    return model


print('Build and fit a model..')

model = build_model()

X, y = make_regression(n_features=10, n_targets=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2)

model.fit(X_train, y_train)
score = model.score(X_test, y_test)

print('Done. Score:', score)

Build and fit a model..
('Done. Score:', 0.9922337477290193)


In [49]:
y

array([[-1.99857274e+02, -1.40616147e+01],
       [ 3.56071970e+02,  4.52120385e+02],
       [ 1.73885610e+02,  7.19889570e+01],
       [ 1.28439181e+02,  5.85479762e+01],
       [-1.98138790e+01,  5.80955014e+01],
       [ 7.92860603e+01,  1.64696000e+02],
       [-8.10836579e+00,  1.81605655e+01],
       [ 2.12041694e+01, -1.93101393e+02],
       [-1.52105203e+01, -3.46386247e+00],
       [ 2.60942111e+02,  7.49434819e+01],
       [-9.11116110e+01, -1.94332404e+02],
       [-2.03678790e+02,  7.48390045e+01],
       [ 3.36138310e+00, -4.79744171e+01],
       [ 1.35926657e+02,  1.68881388e+02],
       [-4.77140877e+01, -7.47980252e+01],
       [-1.15766387e+02,  3.67749584e+01],
       [ 4.07259424e+01, -2.42788936e+01],
       [ 2.28060633e+02,  1.81447095e+02],
       [-5.08407894e+01, -1.00667253e+02],
       [ 1.09773049e+01,  2.41419573e+01],
       [-3.75688302e+02, -1.09459138e+02],
       [ 1.69361796e+02,  1.45937593e+02],
       [ 7.53070610e+01,  1.99659584e+02],
       [ 1.