In [4]:
import numpy as np
from sklearn.datasets import make_classification
from torch import nn
import torch.nn.functional as F

from skorch import NeuralNetClassifier
from sklearn.model_selection import cross_val_score


# Synthetic, reproducible binary classification data:
# 1000 samples x 20 features, 10 of which are informative.
X, y = make_classification(
    n_samples=1000,
    n_features=20,
    n_informative=10,
    random_state=0,
)
# Cast to float32 features / int64 labels, the dtypes torch's default
# float layers and class-index losses operate on.
X, y = X.astype(np.float32), y.astype(np.int64)

class MyModule(nn.Module):
    """Small feed-forward classifier meant to be wrapped by skorch.

    Layout: Linear(input_dim, num_units) -> nonlin -> Dropout(dropout)
    -> Linear(num_units, 10) -> ReLU -> Linear(10, num_classes) -> softmax.

    Input width, number of classes and dropout rate are constructor
    arguments (defaulting to 20, 2 and 0.5), so they can be set or tuned
    through skorch's ``module__<name>`` parameters just like ``num_units``.

    Args:
        num_units: width of the first hidden layer.
        nonlin: activation applied to the first hidden layer's output.
        input_dim: number of input features per sample.
        num_classes: number of output classes (width of the last layer).
        dropout: drop probability of the dropout after the first layer.
    """

    def __init__(self, num_units=10, nonlin=F.relu, input_dim=20,
                 num_classes=2, dropout=0.5):
        super().__init__()

        self.dense0 = nn.Linear(input_dim, num_units)
        self.nonlin = nonlin
        self.dropout = nn.Dropout(dropout)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, num_classes)

    def forward(self, X, **kwargs):
        """Return per-class probabilities for the batch ``X``.

        The last dimension of ``X`` must equal ``input_dim`` (required by
        ``dense0``). Extra keyword arguments are accepted and ignored.

        The output is softmax probabilities, not logits: skorch's
        NeuralNetClassifier with its default NLLLoss criterion applies the
        log itself, so returning log-probabilities here would be wrong.
        """
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = F.relu(self.dense1(X))
        return F.softmax(self.output(X), dim=-1)


# Wrap the torch module in skorch's sklearn-compatible classifier.
# `iterator_train__shuffle` is routed to the training DataLoader so the
# training data is reshuffled every epoch.
net = NeuralNetClassifier(
    module=MyModule,
    max_epochs=10,
    lr=0.1,
    iterator_train__shuffle=True,
)

# fit() returns the estimator itself, so prediction can be chained on.
y_proba = net.fit(X, y).predict_proba(X)

# In an sklearn Pipeline: standardize the features, then feed them to the net.

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler


# sklearn's Pipeline fits the given step objects in place, so this re-trains
# the same `net` instance created above (on scaled features this time).
pipe = Pipeline(steps=[
    ('scale', StandardScaler()),
    ('net', net),
])

y_proba = pipe.fit(X, y).predict_proba(X)

# With grid search: 3-fold CV over learning rate, epoch count and hidden width.

from sklearn.model_selection import GridSearchCV

# Plain keys set NeuralNetClassifier attributes; `module__` keys are
# forwarded by skorch to MyModule's constructor.
params = dict(
    lr=[0.01, 0.02],
    max_epochs=[10, 20],
    module__num_units=[10, 20],
)
gs = GridSearchCV(
    net,
    params,
    scoring='accuracy',
    cv=3,
    refit=False,  # only the scores are reported; skip retraining on all of X
)

gs.fit(X, y)
print(gs.best_score_, gs.best_params_)

# 5-fold cross-validated accuracy of the scaler + net pipeline. The scoring
# metric is spelled out to match the grid search above; cross_val_score
# clones `pipe` per fold, so the already-fitted pipeline is left as is.
cv_score = cross_val_score(pipe, X, y, cv=5, scoring='accuracy')
# Surface the result: an assignment alone displays nothing in a notebook.
print(f"CV accuracy: {cv_score.mean():.4f} +/- {cv_score.std():.4f}")

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7089[0m       [32m0.5400[0m        [35m0.6743[0m  0.0281
      2        [36m0.6776[0m       [32m0.6150[0m        [35m0.6567[0m  0.0144
      3        [36m0.6506[0m       [32m0.6600[0m        [35m0.6386[0m  0.0137
      4        [36m0.6387[0m       [32m0.7100[0m        [35m0.6226[0m  0.0122
      5        [36m0.6149[0m       0.7100        [35m0.6023[0m  0.0138
      6        [36m0.5957[0m       [32m0.7200[0m        [35m0.5893[0m  0.0129
      7        [36m0.5844[0m       [32m0.7250[0m        [35m0.5762[0m  0.0118
      8        0.5860       [32m0.7350[0m        [35m0.5631[0m  0.0129
      9        [36m0.5718[0m       0.7300        [35m0.5554[0m  0.0220
     10        [36m0.5594[0m       0.7350        [35m0.5440[0m  0.0182
Re-initializing module!
  epoch    train_loss    valid_acc    valid_loss     dur
-

     11        0.7176       0.4328        [35m0.7086[0m  0.0084
     12        [36m0.7052[0m       0.4179        [35m0.7078[0m  0.0092
     13        [36m0.7050[0m       0.4179        [35m0.7072[0m  0.0090
     14        0.7080       0.4104        [35m0.7063[0m  0.0088
     15        0.7062       0.4104        [35m0.7053[0m  0.0082
     16        0.7106       0.4104        [35m0.7048[0m  0.0083
     17        [36m0.7016[0m       0.4179        [35m0.7043[0m  0.0100
     18        0.7144       0.4104        [35m0.7032[0m  0.0117
     19        [36m0.7000[0m       0.4179        [35m0.7027[0m  0.0146
     20        0.7019       0.4254        [35m0.7020[0m  0.0111
Re-initializing module!
Re-initializing module!
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7703[0m       [32m0.5149[0m        [35m0.7337[0m  0.0115
      2        [36m0.7520[0m       0.5149        [35m0.

     12        [36m0.6893[0m       0.5373        [35m0.6880[0m  0.0106
     13        [36m0.6832[0m       0.5448        [35m0.6866[0m  0.0093
     14        0.6863       0.5672        [35m0.6853[0m  0.0103
     15        [36m0.6811[0m       0.5448        [35m0.6837[0m  0.0096
     16        0.6837       0.5299        [35m0.6828[0m  0.0096
     17        [36m0.6767[0m       0.5299        [35m0.6817[0m  0.0097
     18        0.6813       0.5597        [35m0.6805[0m  0.0096
     19        [36m0.6763[0m       0.5597        [35m0.6796[0m  0.0098
     20        0.6813       0.5522        [35m0.6788[0m  0.0086
Re-initializing module!
Re-initializing module!
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7116[0m       [32m0.5224[0m        [35m0.6984[0m  0.0090
      2        0.7212       0.5224        [35m0.6953[0m  0.0093
      3        [36m0.7038[0m       [32m0.5299

      3        0.6960       0.5075        [35m0.7036[0m  0.0082
      4        0.6971       0.5075        [35m0.7022[0m  0.0087
      5        0.6955       0.4925        [35m0.7008[0m  0.0087
      6        0.6995       0.4925        [35m0.6997[0m  0.0088
      7        [36m0.6904[0m       0.4925        [35m0.6984[0m  0.0082
      8        0.6914       0.4925        [35m0.6976[0m  0.0084
      9        0.6919       0.4925        [35m0.6966[0m  0.0087
     10        [36m0.6858[0m       0.4851        [35m0.6957[0m  0.0085
     11        0.6880       0.4925        [35m0.6949[0m  0.0083
     12        [36m0.6852[0m       0.5000        [35m0.6940[0m  0.0089
     13        [36m0.6839[0m       0.5000        [35m0.6930[0m  0.0079
     14        [36m0.6828[0m       0.4851        [35m0.6921[0m  0.0081
     15        [36m0.6823[0m       0.4925        [35m0.6913[0m  0.0089
     16        [36m0.6768[0m       0.5000        [35m0.6905[0m  0.0079
     17      