## Skorch RHC (randomized hill climbing) network classification example

In [None]:
import subprocess
import sys

# Installation on Google Colab
try:
    # Imported only to detect the Colab runtime; the module itself is unused.
    import google.colab  # noqa: F401
except ImportError:
    # Not running on Colab: the packages are expected to be installed already.
    pass
else:
    # sys.executable makes pip install into the interpreter this kernel runs
    # on (a bare 'python' may resolve to a different environment).
    # pyperch is installed as well because the import cell below needs it.
    # check=True surfaces a failed install here instead of as a later ImportError.
    subprocess.run(
        [sys.executable, '-m', 'pip', 'install', 'skorch', 'torch', 'pyperch'],
        check=True,
    )

In [1]:
import numpy as np
from sklearn.datasets import make_classification
import torch
from torch import nn
from skorch import NeuralNetClassifier
from skorch import NeuralNet
from pyperch.neural.rhc_nn import RHCModule  
from pyperch.utils.decorators import add_to
from skorch.dataset import unpack_data
import copy

In [2]:
# Synthetic classification problem: 1000 samples, 12 features of which 10 are
# informative. random_state pins the generated data so re-runs see the same set.
X, y = make_classification(1000, 12, n_informative=10, random_state=0)
# Cast to float32 features and int64 class labels before handing them to the net.
X = X.astype(np.float32)
y = y.astype(np.int64)
print(X.shape, y.shape)

# Network dimensions are derived from the data rather than repeated as
# literals, so they cannot drift apart from the dataset generated above.
input_dim = X.shape[1]           # number of feature columns
output_dim = len(np.unique(y))   # number of distinct class labels
hidden_units = 20

(1000, 12) (1000,)


In [3]:
# Pass the module *class* plus module__* constructor arguments instead of a
# pre-built instance (the form skorch's documentation recommends). skorch then
# constructs a fresh RHCModule whenever the net is (re-)initialized, and a
# later set_params(module__hidden_units=...) -- e.g. from a grid search --
# overrides only that argument while input/output dimensions are kept.
net = NeuralNetClassifier(
    RHCModule,
    module__input_dim=input_dim,
    module__output_dim=output_dim,
    module__hidden_units=hidden_units,
    max_epochs=500,
    lr=0.1,
    # Shuffle training data on each epoch
    iterator_train__shuffle=True,
)

## Override `train_step_single`: add the RHC training step and disable backprop

In [4]:
# Register pyperch's RHC training step so that fitting uses it in place of
# skorch's default per-batch train_step_single.
# NOTE(review): presumably this patches skorch at class level and therefore
# affects every net created in this kernel -- confirm in pyperch.neural.rhc_nn.
RHCModule.register_rhc_training_step()

In [5]:
# Train the net on the full dataset. skorch's internal train/valid split is
# still active at this point, which is where the valid_acc / valid_loss
# columns in the training log come from.
net.fit(X, y)

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6891[0m       [32m0.5300[0m        [35m0.6982[0m  0.0224
      2        [36m0.6877[0m       [32m0.5350[0m        [35m0.6969[0m  0.0255
      3        [36m0.6858[0m       [32m0.5400[0m        [35m0.6948[0m  0.0228
      4        [36m0.6851[0m       0.5350        [35m0.6948[0m  0.0339
      5        [36m0.6841[0m       0.5350        [35m0.6919[0m  0.0277
      6        [36m0.6821[0m       [32m0.5500[0m        [35m0.6909[0m  0.0321
      7        [36m0.6800[0m       0.5500        [35m0.6882[0m  0.0257
      8        [36m0.6794[0m       [32m0.5650[0m        [35m0.6870[0m  0.0312
      9        [36m0.6774[0m       [32m0.5850[0m        [35m0.6837[0m  0.0230
     10        [36m0.6753[0m       0.5850        [35m0.6813[0m  0.0267
     11        [36m0.6740[0m       [32m0.5900[0m        [35m0.6792[0m  0.030

<class 'skorch.classifier.NeuralNetClassifier'>[initialized](
  module_=RHCModule(
    (dense0): Linear(in_features=12, out_features=20, bias=True)
    (nonlin): ReLU()
    (dropout): Dropout(p=0.5, inplace=False)
    (dense1): Linear(in_features=20, out_features=20, bias=True)
    (output): Linear(in_features=20, out_features=2, bias=True)
    (softmax): Softmax(dim=-1)
  ),
)

## Using an sklearn pipeline with randomized optimization (RO)

In [6]:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

# Chain feature standardisation in front of the net and train the whole
# pipeline in one go; Pipeline.fit returns the fitted pipeline itself.
pipe = Pipeline(steps=[
    ('scale', StandardScaler()),
    ('net', net),
]).fit(X, y)

# Class probabilities for the training samples (one column per class).
y_proba = pipe.predict_proba(X)

Re-initializing module.
Re-initializing criterion.
Re-initializing optimizer.
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6898[0m       [32m0.5050[0m        [35m0.6940[0m  0.0260
      2        0.6899       0.5000        [35m0.6937[0m  0.0257
      3        [36m0.6897[0m       0.4950        0.6938  0.0255
      4        [36m0.6892[0m       0.5050        [35m0.6931[0m  0.0244
      5        [36m0.6887[0m       [32m0.5150[0m        [35m0.6917[0m  0.0200
      6        [36m0.6878[0m       [32m0.5250[0m        [35m0.6913[0m  0.0247
      7        [36m0.6871[0m       [32m0.5300[0m        [35m0.6901[0m  0.0199
      8        [36m0.6867[0m       [32m0.5350[0m        [35m0.6891[0m  0.0265
      9        [36m0.6861[0m       0.5350        0.6891  0.0258
     10        [36m0.6854[0m       0.5250        [35m0.6882[0m  0.0237
     11        [36m0.6848[0m       [3

## Using sklearn grid search with randomized optimization (RO)

In [8]:
from sklearn.model_selection import GridSearchCV

# deactivate skorch-internal train-valid split and verbose logging
# (GridSearchCV performs its own cross-validation splits and reporting)
net.set_params(train_split=False, verbose=0)

# Module constructor arguments that stay fixed for every candidate. They reuse
# input_dim / output_dim from the data cell instead of repeating literals: a
# value that disagrees with the data's feature count makes every fit fail with
# a matrix-shape error in the first Linear layer.
default_params = {
    'module__input_dim': [input_dim],
    'module__output_dim': [output_dim],
}

# Hyperparameters actually searched: 2 x 2 x 2 = 8 candidates.
grid_search_params = {
    'lr': [0.01, 0.02],
    'max_epochs': [10, 20],
    'module__hidden_units': [10, 20],
    **default_params,
}

# refit=False: only the cross-validated scores are wanted here, so no final
# model is retrained on the full data (gs.best_estimator_ is not available).
gs = GridSearchCV(net, grid_search_params, refit=False, cv=3, scoring='accuracy', verbose=2)

gs.fit(X, y)
print("best score: {:.3f}, best params: {}".format(gs.best_score_, gs.best_params_))

Fitting 3 folds for each of 8 candidates, totalling 24 fits
[CV] END lr=0.01, max_epochs=10, module__hidden_units=10, module__input_dim=20, module__output_dim=2; total time=   0.0s
[CV] END lr=0.01, max_epochs=10, module__hidden_units=10, module__input_dim=20, module__output_dim=2; total time=   0.0s
[CV] END lr=0.01, max_epochs=10, module__hidden_units=10, module__input_dim=20, module__output_dim=2; total time=   0.0s
[CV] END lr=0.01, max_epochs=10, module__hidden_units=20, module__input_dim=20, module__output_dim=2; total time=   0.0s
[CV] END lr=0.01, max_epochs=10, module__hidden_units=20, module__input_dim=20, module__output_dim=2; total time=   0.0s
[CV] END lr=0.01, max_epochs=10, module__hidden_units=20, module__input_dim=20, module__output_dim=2; total time=   0.0s
[CV] END lr=0.01, max_epochs=20, module__hidden_units=10, module__input_dim=20, module__output_dim=2; total time=   0.0s
[CV] END lr=0.01, max_epochs=20, module__hidden_units=10, module__input_dim=20, module__outpu

ValueError: 
All the 24 fits failed.
It is very likely that your model is misconfigured.
You can try to debug the error by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
12 fits failed with the following error:
Traceback (most recent call last):
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/sklearn/model_selection/_validation.py", line 895, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/classifier.py", line 165, in fit
    return super(NeuralNetClassifier, self).fit(X, y, **fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1319, in fit
    self.partial_fit(X, y, **fit_params)
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1278, in partial_fit
    self.fit_loop(X, y, **fit_params)
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1190, in fit_loop
    self.run_single_epoch(iterator_train, training=True, prefix="train",
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1226, in run_single_epoch
    step = step_fn(batch, **fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1105, in train_step
    self._step_optimizer(step_fn)
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1060, in _step_optimizer
    optimizer.step(step_fn)
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/optim/optimizer.py", line 385, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/optim/sgd.py", line 66, in step
    loss = closure()
           ^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1094, in step_fn
    step = self.train_step_single(batch, **fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/pyperch/neural/rhc_nn.py", line 73, in train_step_single
    loss, y_pred = self.module_.run_rhc_single_step(self, Xi, yi, **fit_params)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/pyperch/neural/rhc_nn.py", line 40, in run_rhc_single_step
    y_pred = net.infer(X_train, **fit_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1521, in infer
    return self.module_(x, **fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/pyperch/neural/rhc_nn.py", line 30, in forward
    X = self.nonlin(self.dense0(X))
                    ^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x12 and 20x10)

--------------------------------------------------------------------------------
12 fits failed with the following error:
Traceback (most recent call last):
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/sklearn/model_selection/_validation.py", line 895, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/classifier.py", line 165, in fit
    return super(NeuralNetClassifier, self).fit(X, y, **fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1319, in fit
    self.partial_fit(X, y, **fit_params)
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1278, in partial_fit
    self.fit_loop(X, y, **fit_params)
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1190, in fit_loop
    self.run_single_epoch(iterator_train, training=True, prefix="train",
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1226, in run_single_epoch
    step = step_fn(batch, **fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1105, in train_step
    self._step_optimizer(step_fn)
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1060, in _step_optimizer
    optimizer.step(step_fn)
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/optim/optimizer.py", line 385, in wrapper
    out = func(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/optim/optimizer.py", line 76, in _use_grad
    ret = func(self, *args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/optim/sgd.py", line 66, in step
    loss = closure()
           ^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1094, in step_fn
    step = self.train_step_single(batch, **fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/pyperch/neural/rhc_nn.py", line 73, in train_step_single
    loss, y_pred = self.module_.run_rhc_single_step(self, Xi, yi, **fit_params)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/pyperch/neural/rhc_nn.py", line 40, in run_rhc_single_step
    y_pred = net.infer(X_train, **fit_params)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/skorch/net.py", line 1521, in infer
    return self.module_(x, **fit_params)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/pyperch/neural/rhc_nn.py", line 30, in forward
    X = self.nonlin(self.dense0(X))
                    ^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/john/anaconda3/envs/pyperch/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x12 and 20x20)
