In [40]:
import numpy as np

from toolz import curry

from sklearn.pipeline      import Pipeline
from sklearn.pipeline      import FeatureUnion
from sklearn.preprocessing import FunctionTransformer

In [21]:
def all_but_first_column(X):
    """Return the first column of ``X`` as a single-column 2-D slice.

    NOTE: despite the function's name, the slice taken is ``0:1`` -- i.e.
    *only* the first column is kept, not "everything except the first".
    The name is kept as-is because other cells refer to it.

    Args:
        X: Object supporting two-axis slicing (``X[:, a:b]``).

    Returns:
        ``X[:, 0:1]``.
    """
    leading_column = slice(0, 1)
    return X[:, leading_column]

In [22]:
# One-step pipeline whose only stage wraps the column-slicing helper in a
# FunctionTransformer; the bare name on the last line displays its repr.
pipeline = Pipeline(
    steps=[("get_first_column", FunctionTransformer(all_but_first_column))]
)
pipeline

Pipeline(memory=None,
     steps=[('get_first_column', FunctionTransformer(accept_sparse=False,
          func=<function all_but_first_column at 0x7f9fb55f4e18>,
          inv_kw_args=None, inverse_func=None, kw_args=None,
          pass_y='deprecated', validate=True))])

In [25]:
# Demo input: 5 rows x 3 columns of standard-normal draws.
# A locally seeded RandomState keeps the values identical on every
# Restart & Run All without touching NumPy's global random state.
X = np.random.RandomState(seed=0).randn(5, 3)

In [32]:
# Show the input matrix, a divider line, then the pipeline's result for it.
print(X, end="\n=======================\n")
print(pipeline.fit_transform(X))

[[-0.91118516 -0.69948391 -0.06834985]
 [-0.71172352 -0.69820786  2.11599225]
 [ 0.77292234 -2.02302618 -0.53144427]
 [ 0.1472152   0.94562377 -0.08982214]
 [ 0.64061492  0.81935769  1.85425789]]
[[-0.91118516]
 [-0.71172352]
 [ 0.77292234]
 [ 0.1472152 ]
 [ 0.64061492]]


-----

# With FeatureUnion

In [41]:
@curry
def get_k_column(X, k=0):
    """Return column ``k`` of ``X`` as a single-column 2-D slice.

    The function is curried, so it can be partially applied with the column
    index first and the data later, e.g. ``get_k_column(k=1)(X)``.

    A slice (``k:k+1``) is used instead of plain integer indexing so the
    column axis is kept in the result.

    Args:
        X: Object supporting two-axis slicing (``X[:, a:b]``).
        k: Column position. Negative values count from the last column.
            Defaults to 0.

    Returns:
        ``X[:, k:k+1]``, except that for ``k == -1`` the slice end is left
        open so the last column is returned (``-1:0`` would be empty).
    """
    stop = k + 1
    if stop == 0:
        # k == -1: the slice -1:0 selects nothing, so run to the end instead.
        stop = None
    return X[:, k:stop]

In [45]:
# Show the full matrix, then columns 0 and 1 obtained through the curried
# form (index bound first, data supplied afterwards), with divider lines.
print(X, end="\n==========\n")
print(get_k_column(k=0)(X), end="\n==========\n")
print(get_k_column(k=1)(X))

[[-0.32954384  0.52458624 -0.06271435]
 [ 0.10698009  0.54341036  0.32944349]
 [ 2.0085786   1.96876629 -2.14392048]
 [-1.16446276 -0.29182078 -0.48332359]
 [ 1.39127359  1.76340138  0.0135359 ]]
[[-0.32954384]
 [ 0.10698009]
 [ 2.0085786 ]
 [-1.16446276]
 [ 1.39127359]]
[[ 0.52458624]
 [ 0.54341036]
 [ 1.96876629]
 [-0.29182078]
 [ 1.76340138]]


In [46]:
# Union of two single-column extractors. The first passes the curried
# function as-is (so its default k = 0 applies); the second binds k = 3.
pipe = FeatureUnion(
    transformer_list=[
        ("column 0", FunctionTransformer(get_k_column)),
        ("column 3", FunctionTransformer(get_k_column(k=3))),
    ]
)

In [48]:
# Wider demo input (5 rows x 5 columns) so that column index 3 exists.
# A locally seeded RandomState keeps the values identical on every
# Restart & Run All without touching NumPy's global random state.
X = np.random.RandomState(seed=1).randn(5, 5)

In [49]:
# Input matrix on top, the FeatureUnion's output for it underneath.
print(X, pipe.fit_transform(X), sep="\n")

[[-2.26550082 -0.76296209 -0.02095459 -0.4521277   0.37517583]
 [ 2.09686853  1.15261556 -1.60330963 -0.74167683  0.29010948]
 [ 0.57424158 -0.16985586 -0.48617983 -0.14570154  0.60481298]
 [-0.50096337  2.52452305  1.8845453   0.34690285 -0.7279353 ]
 [ 0.74923332  1.90750803 -0.64769566  0.38852121 -0.85298245]]
[[-2.26550082 -0.4521277 ]
 [ 2.09686853 -0.74167683]
 [ 0.57424158 -0.14570154]
 [-0.50096337  0.34690285]
 [ 0.74923332  0.38852121]]
