In [45]:
from sklearn.base import TransformerMixin
from sklearn.datasets import make_regression
from sklearn.pipeline import Pipeline, FeatureUnion
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestRegressor
from sklearn.neighbors import KNeighborsRegressor
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.linear_model import LinearRegression, Ridge

In [46]:
class RidgeTransformer(Ridge, TransformerMixin):
    """Ridge regressor that can also be used as a transformer.

    ``transform`` exposes the fitted model's predictions as output features,
    which lets the estimator sit inside a ``FeatureUnion`` or act as an
    intermediate ``Pipeline`` step.
    """

    def transform(self, X, *_):
        """Return the predictions for ``X`` as a 2-D feature array.

        Parameters
        ----------
        X : array-like
            Samples to predict on; forwarded unchanged to ``predict``.
        *_ : ignored
            Any extra positional arguments are accepted and discarded.

        Returns
        -------
        The result of ``predict(X)``. A 1-D result (single-target fit) is
        reshaped into a single column; a 2-D result is returned as is.
        """
        predictions = self.predict(X)
        # A single-target fit predicts a flat vector. Transformer outputs are
        # stacked horizontally by FeatureUnion, so a flat vector would be
        # concatenated end to end instead of contributing one column.
        if predictions.ndim == 1:
            predictions = predictions.reshape(-1, 1)
        return predictions


class RandomForestTransformer(RandomForestRegressor, TransformerMixin):
    """Random-forest regressor that can also be used as a transformer.

    ``transform`` exposes the fitted model's predictions as output features,
    which lets the estimator sit inside a ``FeatureUnion`` or act as an
    intermediate ``Pipeline`` step.
    """

    def transform(self, X, *_):
        """Return the predictions for ``X`` as a 2-D feature array.

        Parameters
        ----------
        X : array-like
            Samples to predict on; forwarded unchanged to ``predict``.
        *_ : ignored
            Any extra positional arguments are accepted and discarded.

        Returns
        -------
        The result of ``predict(X)``. A 1-D result (single-target fit) is
        reshaped into a single column; a 2-D result is returned as is.
        """
        predictions = self.predict(X)
        # A single-target fit predicts a flat vector. Transformer outputs are
        # stacked horizontally by FeatureUnion, so a flat vector would be
        # concatenated end to end instead of contributing one column.
        if predictions.ndim == 1:
            predictions = predictions.reshape(-1, 1)
        return predictions


class KNeighborsTransformer(KNeighborsRegressor, TransformerMixin):
    """k-nearest-neighbours regressor that can also be used as a transformer.

    ``transform`` exposes the fitted model's predictions as output features,
    which lets the estimator sit inside a ``FeatureUnion`` or act as an
    intermediate ``Pipeline`` step.
    """

    def transform(self, X, *_):
        """Return the predictions for ``X`` as a 2-D feature array.

        Parameters
        ----------
        X : array-like
            Samples to predict on; forwarded unchanged to ``predict``.
        *_ : ignored
            Any extra positional arguments are accepted and discarded.

        Returns
        -------
        The result of ``predict(X)``. A 1-D result (single-target fit) is
        reshaped into a single column; a 2-D result is returned as is.
        """
        predictions = self.predict(X)
        # A single-target fit predicts a flat vector. Transformer outputs are
        # stacked horizontally by FeatureUnion, so a flat vector would be
        # concatenated end to end instead of contributing one column.
        if predictions.ndim == 1:
            predictions = predictions.reshape(-1, 1)
        return predictions

In [52]:
def build_model():
    """Assemble the two-level regression pipeline used in this notebook.

    The first level is a ``FeatureUnion`` (run with ``n_jobs=2``) of three
    prediction-emitting transformers: a scaled, polynomial-expanded ridge
    pipeline, a random forest, and a k-nearest-neighbours model. The second
    level is another k-nearest-neighbours model fitted on the union's output.

    Returns
    -------
    Pipeline
        The unfitted two-step pipeline.
    """
    # First-level learners; each entry's transform() yields its predictions.
    base_learners = [
        (
            'ridge',
            Pipeline(steps=[
                ('scaler', StandardScaler()),
                ('poly_feats', PolynomialFeatures()),
                ('ridge', RidgeTransformer()),
            ]),
        ),
        ('rand_forest', RandomForestTransformer()),
        ('knn', KNeighborsTransformer()),
    ]

    stacked_predictions = FeatureUnion(
        transformer_list=base_learners,
        n_jobs=2,
    )

    # The last step is only ever asked to fit/predict/score, so the
    # transformer subclass serves here as a plain regressor.
    return Pipeline(steps=[
        ('pred_union', stacked_predictions),
        ('knn', KNeighborsTransformer()),
    ])


print('Build and fit a model..')

# One seed for every source of randomness in this cell, so that the printed
# score is the same after Restart Kernel -> Run All.
SEED = 42

model = build_model()
# The random forest inside the union is randomised; seed it through the
# nested parameter name (<pipeline step>__<union member>__<parameter>).
model.set_params(pred_union__rand_forest__random_state=SEED)

# Synthetic regression problem: 10 input features, 2 target columns.
X, y = make_regression(n_features=10, n_targets=2, random_state=SEED)
# Hold out 20% of the samples for scoring.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=SEED
)

model.fit(X_train, y_train)
score = model.score(X_test, y_test)

print('Done. Score:', score)

Build and fit a model..
('Done. Score:', 0.816525405217137)


In [53]:
# Display the target array generated by make_regression above (n_targets=2,
# hence two columns per row).
# NOTE(review): this prints the whole array; a slice such as y[:5] would keep
# the saved output short.
y

array([[ 1.89756358e+02,  1.51034002e+02],
       [-7.62879399e+01, -1.97749633e+02],
       [-1.87399635e+02, -3.26042487e+01],
       [-9.35430163e+01, -1.15038565e+02],
       [-4.79834227e+01, -1.32823947e+01],
       [ 4.33868314e+01,  2.16404541e+02],
       [-2.63370029e+02, -3.62552405e+02],
       [ 1.71247355e+02,  2.26043024e+02],
       [ 3.89866599e+02,  2.22846152e+02],
       [ 2.17906364e+01, -8.21886953e+01],
       [ 3.60128120e+01, -2.17951173e+00],
       [-1.79950165e+02, -5.53686049e+01],
       [ 1.98120494e+01, -8.82283654e+01],
       [-1.36562654e+02, -7.46304357e+01],
       [ 6.12061739e+01,  7.02662447e+01],
       [-5.86913454e+00, -3.88550894e+01],
       [ 3.78168658e+02,  2.19341592e+02],
       [ 1.38914271e+01, -9.06557629e+01],
       [-9.09631688e+01, -2.02837758e+02],
       [-2.10928620e+02, -2.40654746e+02],
       [-6.71845612e+01, -2.60853693e+01],
       [ 7.65604262e+02,  3.83247775e+02],
       [ 4.08259691e+00,  1.01341135e+02],
       [-1.