<img src="images/logo.svg" width="1000" height="200">

# Training PyTorch models with scikit-learn

### Collin Wilson - University of Guelph

## Outline
* [What is Skorch?](#What-is-Skorch?)
* [Who should use Skorch?](#Who-should-use-Skorch?)
* [The Basics - Learning by example](#The-Basics)
* [Benchmarks](#Benchmarks:-Skorch-vs-pure-Pytorch)
* [More of the basics](#More-of-the-basics)
* [Multi GPU accelerated grid search](#Multi-GPU-accelerated-grid-search)
* [Other features](#Other-features)


## What is Skorch?

* Skorch is a wrapper for PyTorch `nn.Module`s that allows models written in PyTorch to be trained in `scikit-learn` workflows

### What is `scikit-learn`?

* one of the most popular general-purpose machine learning Python packages, with tools for splitting data into train/test sets, cross-validation, hyperparameter optimization, building training pipelines and much more
* great for learning the fundamentals of feature engineering, dataset management, model creation and model selection
* limited support for deep learning (only simple multilayer perceptrons, trained on the CPU)



### What is PyTorch?
* PyTorch is an extremely popular deep learning Python library
* supports GPU training


## What is Skorch?


* Allowing a PyTorch model to be used in the `scikit-learn` workflow reduces the need for boilerplate code:
    * training a model is as simple as calling `net.fit(X, y)`; there is no need to write code for the training loop, validation, reporting, etc.
    * **The only thing you need is the model definition** <br></br>
* supports accelerated hyperparameter optimization using multiple GPUs<br></br>
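For comparison, here is a minimal sketch of the kind of plain-PyTorch training loop that `net.fit(X, y)` replaces. The toy data, the `nn.Sequential` model and the hyperparameters are hypothetical, and this loop covers only training; skorch's own loop also handles the validation split, callbacks and the training history.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy data: 200 samples, 20 features, 2 classes
X_demo = torch.randn(200, 20)
y_demo = (X_demo[:, 0] > 0).long()

model = nn.Sequential(nn.Linear(20, 10), nn.ReLU(), nn.Linear(10, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()  # expects raw logits


def train(model, optimizer, criterion, X, y, max_epochs=20, batch_size=32):
    """Plain mini-batch training loop; returns the mean train loss per epoch."""
    epoch_losses = []
    for epoch in range(max_epochs):
        model.train()
        perm = torch.randperm(len(X))  # shuffle the samples every epoch
        total = 0.0
        for start in range(0, len(X), batch_size):
            idx = perm[start:start + batch_size]
            optimizer.zero_grad()
            loss = criterion(model(X[idx]), y[idx])
            loss.backward()
            optimizer.step()
            total += loss.item() * len(idx)
        epoch_losses.append(total / len(X))
        print(f'epoch {epoch + 1}: train loss {epoch_losses[-1]:.4f}')
    return epoch_losses


losses = train(model, optimizer, criterion, X_demo, y_demo)
```

Everything in `train`, plus validation and reporting, is what skorch takes off your hands.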

## Who should use Skorch?


* If you already use `scikit-learn` for machine learning workflows and want to incorporate more complex deep learning models<br></br>
* If you are just starting with machine learning and also want to learn how to write deep models in PyTorch <br></br>
* If you want to do relatively pain-free, multi-GPU accelerated hyperparameter optimization


# The Basics

## Learning by example

Modified from [skorch tutorials](https://github.com/skorch-dev/skorch/tree/master/notebooks)

### Preliminaries

In [None]:
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F

torch.manual_seed(0)
torch.cuda.manual_seed(0)

### Create a dataset

In [None]:
from sklearn.datasets import make_classification

# This is a toy dataset for binary classification, 1000 data points with 20 features each
X, y = make_classification(1000, 20, n_informative=10, random_state=0)
X, y = X.astype(np.float32), y.astype(np.int64)
X.shape, y.shape, y.mean()

### Visualizing the dataset after some dimensionality reduction

In [None]:
# Visualize the dataset after projecting it onto its first 3 principal components
from sklearn.decomposition import PCA
import plotly.express as px

pca = PCA(n_components=3)
pca_X = pca.fit_transform(X)
y_str = ['Class ' + str(i) for i in y]
fig = px.scatter_3d(x=pca_X[:, 0], y=pca_X[:, 1], z=pca_X[:, 2], color=y_str)
fig.update_traces(marker={'size': 3})
fig.show()

### Definition of the `pytorch` classification `module`

* Vanilla neural network with two hidden layers. <br></br>
* The output layer should have 2 output units since there are two classes. <br></br>
* In addition, it should end with a softmax nonlinearity, because `predict_proba` returns the output of the `forward` call, which therefore has to contain class probabilities.
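As a sketch of why probabilities work here: as far as we know, `NeuralNetClassifier`'s default criterion is `torch.nn.NLLLoss`, and skorch takes the log of the module's output before passing it to that loss. The snippet below, with hypothetical logits and targets, checks that `NLLLoss` on `log(softmax(z))` matches `CrossEntropyLoss` on the raw logits `z`, so a softmax output trains the same way as logits with cross-entropy.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
logits = torch.randn(5, 2)               # hypothetical raw scores: 5 samples, 2 classes
targets = torch.tensor([0, 1, 1, 0, 1])  # hypothetical class labels

probs = F.softmax(logits, dim=-1)        # what ClassifierModule.forward returns
print(probs.sum(dim=-1))                 # each row sums to 1

nll_on_log_probs = nn.NLLLoss()(torch.log(probs), targets)
cross_entropy = nn.CrossEntropyLoss()(logits, targets)
print(nll_on_log_probs.item(), cross_entropy.item())  # equal up to float rounding
```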


In [None]:
class ClassifierModule(nn.Module):
    def __init__(
            self,
            num_units=10,
            nonlin=F.relu,
            dropout=0.5,
    ):
        super().__init__()
        self.num_units = num_units
        self.nonlin = nonlin

        self.dense0 = nn.Linear(20, num_units)  # 20 input features
        self.dropout = nn.Dropout(dropout)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 2)  # one output unit per class

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X) 
        X = F.relu(self.dense1(X))
        X = F.softmax(self.output(X), dim=-1)
        return X

### Training

We use `NeuralNetClassifier` because we're dealing with a classification task. The first argument is the PyTorch module; here we pass the class itself and skorch instantiates it. As additional arguments, we pass the number of epochs and the learning rate (`lr`), but those are optional.

*Note*: To use the CUDA backend, pass `device='cuda'` as an additional argument.

In [None]:
from skorch import NeuralNetClassifier

net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=20,
    lr=0.1,
#     device='cuda',  # uncomment this to train with CUDA
)

As in `sklearn`, we call `fit`, passing the input data `X` and the targets `y`. By default, `NeuralNetClassifier` holds out part of the data using a `StratifiedKFold` split (80% train, 20% validation) to track the validation loss. The printed table shows the train loss, the validation loss and the validation accuracy for each epoch.
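A minimal sketch of that default split using scikit-learn directly, assuming skorch keeps the first fold of a 5-fold `StratifiedKFold` as its validation set. New variable names are used so the notebook's `X` and `y` are not overwritten.

```python
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold

# same toy data as above, under new names so X and y are left untouched
X_split, y_split = make_classification(1000, 20, n_informative=10, random_state=0)

# first fold of a 5-fold stratified split: 800 training / 200 validation samples
cv = StratifiedKFold(n_splits=5)
idx_train, idx_valid = next(iter(cv.split(X_split, y_split)))

print(len(idx_train), len(idx_valid))  # 800 200
# stratification keeps the class balance of both parts close to the full data
print(y_split.mean(), y_split[idx_train].mean(), y_split[idx_valid].mean())
```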

In [None]:
# Training the network 
net.fit(X, y)

### Inference

Also, as in `sklearn`, you may call `predict` or `predict_proba` on the fitted model.

In [None]:
# Predict the class of the first 5 data points of X
y_pred = net.predict(X[:5])
y_pred

In [None]:
# Predicted probability of each class for the first 5 data points of X
y_proba = net.predict_proba(X[:5])
y_proba
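For `NeuralNetClassifier`, `predict` should correspond to picking the most probable class from `predict_proba`. A minimal NumPy sketch of that step, with a hypothetical probability array standing in for `y_proba`:

```python
import numpy as np

# hypothetical output of predict_proba for 5 samples and 2 classes
proba = np.array([
    [0.9, 0.1],
    [0.2, 0.8],
    [0.6, 0.4],
    [0.3, 0.7],
    [0.5, 0.5],
])

# index of the most probable class per row (ties go to the first class)
labels = proba.argmax(axis=1)
print(labels)  # [0 1 0 1 0]
```

On the fitted net, comparing `net.predict(X[:5])` with `net.predict_proba(X[:5]).argmax(axis=1)` is a quick way to check this relationship.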

## Training other types of models:
- `NeuralNetRegressor` for regression
- `NeuralNetBinaryClassifier` for binary classification
- `NeuralNet` base class, for more generality and customization
- Utilities for Gaussian process models (built on GPyTorch)

# Benchmarks: Skorch vs pure Pytorch

## 1. Simple Convolutional Neural Network (CNN) on MNIST


<br><br>
<center><img src="images/MnistExamplesModified.png" style></center>
<br><br>


By Suvanjanprasai - Own work, CC BY-SA 4.0, https://commons.wikimedia.org/w/index.php?curid=132282871

In [None]:
import pandas as pd
import plotly.express as px
durations = pd.read_csv('benchmarks/mnist_benchmark.csv')
durations['skorch'] = durations['skorch_t']
durations['torch'] = durations['torch_t']

mnist_fig = px.line(durations, x="epoch", y=["skorch", "torch"], 
              title="Training duration for Skorch and pure PyTorch - CNN on MNIST")
mnist_fig.update_yaxes(title='Wall time (s)')
mnist_fig.update_layout(width=1600, height=800)


In [None]:
mnist_fig.show()

## 2. ResNet32 on CIFAR-10

<table>
    <tbody><tr>
        <td><img src="images/resnet.png" width="600"></td>
<td><table>
    <tbody><tr>
        <td class="cifar-class-name">airplane</td>
        <td><img src="cifar-10-sample/airplane1.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/airplane2.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/airplane3.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/airplane4.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/airplane5.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/airplane6.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/airplane7.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/airplane8.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/airplane9.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/airplane10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">automobile</td>
        <td><img src="cifar-10-sample/automobile1.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/automobile2.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/automobile3.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/automobile4.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/automobile5.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/automobile6.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/automobile7.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/automobile8.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/automobile9.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/automobile10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">bird</td>
        <td><img src="cifar-10-sample/bird1.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/bird2.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/bird3.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/bird4.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/bird5.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/bird6.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/bird7.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/bird8.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/bird9.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/bird10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">cat</td>
        <td><img src="cifar-10-sample/cat1.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/cat2.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/cat3.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/cat4.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/cat5.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/cat6.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/cat7.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/cat8.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/cat9.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/cat10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">deer</td>
        <td><img src="cifar-10-sample/deer1.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/deer2.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/deer3.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/deer4.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/deer5.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/deer6.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/deer7.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/deer8.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/deer9.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/deer10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">dog</td>
        <td><img src="cifar-10-sample/dog1.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/dog2.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/dog3.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/dog4.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/dog5.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/dog6.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/dog7.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/dog8.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/dog9.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/dog10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">frog</td>
        <td><img src="cifar-10-sample/frog1.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/frog2.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/frog3.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/frog4.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/frog5.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/frog6.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/frog7.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/frog8.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/frog9.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/frog10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">horse</td>
        <td><img src="cifar-10-sample/horse1.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/horse2.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/horse3.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/horse4.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/horse5.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/horse6.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/horse7.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/horse8.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/horse9.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/horse10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">ship</td>
        <td><img src="cifar-10-sample/ship1.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/ship2.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/ship3.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/ship4.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/ship5.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/ship6.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/ship7.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/ship8.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/ship9.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/ship10.png" class="cifar-sample"></td>
    </tr>
    <tr>
        <td class="cifar-class-name">truck</td>
        <td><img src="cifar-10-sample/truck1.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/truck2.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/truck3.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/truck4.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/truck5.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/truck6.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/truck7.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/truck8.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/truck9.png" class="cifar-sample"></td>
        <td><img src="cifar-10-sample/truck10.png" class="cifar-sample"></td>
    </tr></tbody>
</table></td></tr>
</table>


Sources: [*Deep Residual Learning for Image Recognition*](https://arxiv.org/pdf/1512.03385.pdf), He et al. (2015); [*The CIFAR-10 dataset*](https://www.cs.toronto.edu/~kriz/cifar.html), Alex Krizhevsky

In [None]:
durations = pd.read_csv('benchmarks/cifar10_benchmark.csv')
cifar_fig = px.line(durations, x="epoch", y=["skorch", "torch"], 
              title="Training duration for Skorch and pure PyTorch - ResNet32 on CIFAR-10")
cifar_fig.update_yaxes(title='Wall time (s)')
cifar_fig.update_layout(width=1200, height=700)


In [None]:
cifar_fig.show()


## Callbacks

Adding a new callback to the model is straightforward. Below we show how to add a callback that computes the area under the ROC curve (AUC).

In [None]:
from skorch.callbacks import EpochScoring

There is a scoring callback in skorch, `EpochScoring`, which we use for this. We have to specify which score to calculate:

* Passing a string: this should be a valid `sklearn` scorer name. For a list of all existing scores, see the [scikit-learn metrics documentation](http://scikit-learn.org/stable/modules/classes.html#sklearn-metrics-metrics).
* Passing `None`: If you implement your own `.score` method on your neural net, passing `scoring=None` will tell `skorch` to use that.
* Passing a function or callable: If we want to define our own scoring function, we pass a function with the signature `func(model, X, y) -> score`, which is then used.
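A minimal sketch of the callable option: a hypothetical `auc_from_proba` function with the `func(model, X, y) -> score` signature. Because scikit-learn scorers use the same signature, the sanity check below runs it on a plain `LogisticRegression`; the toy data is hypothetical.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score


def auc_from_proba(model, X, y):
    """Scoring callable with the func(model, X, y) -> score signature."""
    y_proba = model.predict_proba(X)[:, 1]  # probability of the positive class
    return roc_auc_score(y, y_proba)


# sanity check on a plain sklearn estimator (hypothetical data)
X_chk, y_chk = make_classification(200, 20, random_state=0)
clf = LogisticRegression(max_iter=1000).fit(X_chk, y_chk)
score = auc_from_proba(clf, X_chk, y_chk)
print(round(score, 3))
```

With skorch, the same function could then be passed as `EpochScoring(scoring=auc_from_proba, lower_is_better=False)`; during training, `model` would be the net and `X`, `y` the validation data.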

`sklearn` already implements AUC, so we just pass the string `'roc_auc'`. Since a higher AUC is better, we also set `lower_is_better=False`.

In [None]:
auc = EpochScoring(scoring='roc_auc', lower_is_better=False)

Finally, we pass the scoring callback to the `callbacks` parameter as a list...

In [None]:
net = NeuralNetClassifier(
    ClassifierModule,  
    max_epochs=20,
    lr=0.1,
    callbacks=[auc],
)

...and then call `fit`.

In [None]:
net.fit(X, y)

## Saving and loading a model

You can save and load either the whole model, using `pickle`, or just the learned module parameters, using `save_params` and `load_params`.

### Saving the whole model

```python
import pickle

file_name = '/tmp/mymodel.pkl'

with open(file_name, 'wb') as f:
    pickle.dump(net, f)
```

### Loading the whole model

```python
with open(file_name, 'rb') as f:
    new_net = pickle.load(f)
```

### Saving only the model parameters

This saves and loads only the learned parameters of the `module`; hyperparameters such as `lr` and `max_epochs` are not saved. Therefore, before loading, we have to create and initialize a new net with the same configuration.

```python
net.save_params(f_params=file_name)  # a file handle also works
```

### Loading the model using saved parameters

```python
# first initialize the model
new_net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=20,
    lr=0.1,
).initialize()

# load the parameters into the model
new_net.load_params(f_params=file_name)
```

## Usage with an `sklearn Pipeline`

It is possible to put the `NeuralNetClassifier` inside an `sklearn Pipeline`, as you would with any `sklearn` classifier.

In [None]:
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler

pipe = Pipeline([
    ('scale', StandardScaler()),
    ('net', net),
])

pipe.fit(X, y)

In [None]:
y_proba = pipe.predict_proba(X[:5])
y_proba

To save the whole pipeline, including the PyTorch module, use `pickle`.

## Usage with sklearn `GridSearchCV`

### Special prefixes

- The `NeuralNet` class lets you directly access the parameters of the PyTorch `module` by using the `module__` prefix. 
    - e.g. if you defined the `module` to have a `num_units` parameter, you can set it via the `module__num_units` argument. 
    - **modifiable parameters must be arguments of your module's `__init__` method**

```python
class ClassifierModule(nn.Module):
    def __init__(
            self,
            num_units=10, 
            nonlin=F.relu,
            dropout=0.5,
    ):
        super().__init__()
        self.num_units = num_units
        self.nonlin = nonlin

        self.dense0 = nn.Linear(20, num_units)  # 20 input features
        self.dropout = nn.Dropout(dropout)
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 2)  # one output unit per class
        self.dense1 = nn.Linear(num_units, 10)
        self.output = nn.Linear(10, 2)

    def forward(self, X, **kwargs):
        X = self.nonlin(self.dense0(X))
        X = self.dropout(X)
        X = F.relu(self.dense1(X))
        X = F.softmax(self.output(X), dim=-1)
        return X
```

- This allows you to set parameters in an `sklearn GridSearchCV` as shown below.

- In addition to the parameters prefixed by `module__`, you may access a couple of other attributes, such as those of the optimizer by using the `optimizer__` prefix (again, see below). All those special prefixes are stored in the `prefixes_` attribute:

In [None]:
print(', '.join(net.prefixes_))

### Performing a grid search

Below we show how to perform a grid search over the learning rate (`lr`), the module's number of hidden units (`module__num_units`), and the module's dropout rate (`module__dropout`).

The basic steps are:
1. Define the model
2. Define the parameter set to search
3. Create the `GridSearchCV` object and perform the search

#### 1. Define the Model

In [None]:
from sklearn.model_selection import GridSearchCV

In [None]:
net = NeuralNetClassifier(
    ClassifierModule,
    max_epochs=20,
    lr=0.1,
    optimizer__momentum=0.9,
    verbose=0, 
    train_split=False,
)

- Set the verbosity level to zero (`verbose=0`) to avoid printing the training progress of every fit.
- Disable skorch's internal train/validation split (`train_split=False`), because `GridSearchCV` already splits the training data for us.

#### 2. Define the parameter set to search

In [None]:
params = {
    'lr': [0.05, 0.1],
    'module__num_units': [10, 20],
    'module__dropout': [0, 0.5]
}
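Each combination in the grid is trained once per cross-validation fold, so the number of fits grows multiplicatively. A quick count using scikit-learn's `ParameterGrid`, with the same grid as above and `cv=3` as in the next step:

```python
from sklearn.model_selection import ParameterGrid

# same grid as above, under a new name
grid = {
    'lr': [0.05, 0.1],
    'module__num_units': [10, 20],
    'module__dropout': [0, 0.5],
}

n_candidates = len(ParameterGrid(grid))  # 2 * 2 * 2 combinations
n_fits = n_candidates * 3                # one fit per candidate and fold (cv=3)
print(n_candidates, n_fits)  # 8 24
```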

#### 3. Create the `GridSearchCV` object...

In [None]:
gs = GridSearchCV(net, params, refit=False, cv=3, scoring='accuracy', verbose=2)

#### ... and perform the search

In [None]:
gs.fit(X, y)

In [None]:
print(gs.best_score_, gs.best_params_)

**Important note: you could further nest the `NeuralNetClassifier` within an `sklearn Pipeline`, in which case you prefix each parameter with the name of the pipeline step that holds the net (e.g. `net__module__num_units`).**

## Multi GPU accelerated grid search

* Hyperparameter tuning can take a long time because the search space can be very large <br></br>

* If we can train multiple models simultaneously, we can cut that search time down considerably. <br></br>

* Skorch supports training multiple models in parallel across GPUs in `GridSearchCV`, using Dask as the parallel backend <br></br>

* We'll use a two-GPU setup running ResNet on CIFAR-10 for the following example

## Initial setup with Dask

* Dask is a Python package that can be used to parallelize and scale Python libraries such as NumPy, pandas and `scikit-learn`; we also need the `distributed` package
    * install them in your virtual environment with `pip install --no-index dask distributed` <br></br>

* We use Dask to create a parallel backend for `GridSearchCV`, allowing us to use multiple GPUs<br></br>

* Before running our Python code, we need to start a `dask scheduler` and one `dask worker` per GPU; the key is to set `CUDA_VISIBLE_DEVICES` for each worker so that each worker sees only its own GPU:

```bash
echo 'Starting scheduler'
dask scheduler &
sleep 10  # give a buffer to make sure the scheduler is started before starting workers

echo 'Scheduler booted, launching workers'
CUDA_VISIBLE_DEVICES=0 dask worker 127.0.0.1:8786 --nthreads 1 &
sleep 10 # This is just to prevent the outputs from being tangled together
CUDA_VISIBLE_DEVICES=1 dask worker 127.0.0.1:8786 --nthreads 1 &
```



### Quick note on job submission 

* You need to ask for enough cores in your job script:
    * One for the scheduler
    * One for each worker (two in our example)
    * One for the python process

## Modifying your python code

* We need to make some simple modifications to our code to use the parallel backend:

```python
# import required parallelization modules
from dask.distributed import Client
from joblib import parallel_backend

# other imports used below
import time

import torch
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV
from skorch import NeuralNetClassifier

...

def main(device, batch_size, lr, max_epochs):

    # connect to the scheduler started in the job script
    client = Client('127.0.0.1:8786')

    X_train, X_test, y_train, y_test = get_data()

    print("\nTesting skorch performance")
    tic = time.time()
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    net = NeuralNetClassifier(
        ResNet,
        batch_size=batch_size,
        optimizer=torch.optim.Adadelta,
        criterion=torch.nn.CrossEntropyLoss,
        lr=lr,
        device=device,  # each worker only sees the GPU set by CUDA_VISIBLE_DEVICES
        max_epochs=max_epochs
    )

    params = {
        'module__num_blocks': [[3, 3, 3], [5, 5, 5], [7, 7, 7]]
    }

    gs = GridSearchCV(net, params, scoring='accuracy', cv=5, verbose=3, refit=True)

    # inside this block, joblib sends each (candidate, fold) fit to a Dask worker
    with parallel_backend('dask'):
        gs.fit(X_train, y_train)
    print(gs.cv_results_)

    y_pred = gs.best_estimator_.predict(X_test)
    score = accuracy_score(y_test, y_pred)
    time_skorch = time.time() - tic

    print(f'Grid search found model with validation score: {gs.best_score_}')
    print(f'with parameters: {gs.best_params_}')
    print(f'Test score: {score} after {max_epochs} epochs in {time_skorch:.1f}s.')
```
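To try the `parallel_backend('dask')` pattern on a single machine without GPUs or a separately started scheduler, here is a sketch that assumes `dask` and `distributed` are installed and uses an in-process `LocalCluster` with two single-threaded workers. A plain `LogisticRegression` and a hypothetical grid stand in for the skorch ResNet; only the backend wiring matches `main` above.

```python
from dask.distributed import Client, LocalCluster
from joblib import parallel_backend
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# two single-threaded workers, standing in for "one worker per GPU"
cluster = LocalCluster(n_workers=2, threads_per_worker=1, processes=False)
client = Client(cluster)

X_dask, y_dask = make_classification(500, 20, random_state=0)
search = GridSearchCV(
    LogisticRegression(max_iter=1000),
    {'C': [0.01, 0.1, 1.0, 10.0]},  # hypothetical grid
    cv=5,
)

# joblib dispatches each (candidate, fold) fit to the Dask workers
with parallel_backend('dask'):
    search.fit(X_dask, y_dask)

print(search.best_params_)

client.close()
cluster.close()
```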

## Running the parallel GridSearchCV

```
[c7wilson@gra847 ~]$ nvidia-smi
Sat Nov 25 13:43:51 2023
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  Tesla P100-PCIE...  On   | 00000000:04:00.0 Off |                    0 |
| N/A   61C    P0   129W / 250W |   8910MiB / 12288MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  On   | 00000000:83:00.0 Off |                    0 |
| N/A   60C    P0   130W / 250W |   8620MiB / 12288MiB |    100%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+

+-----------------------------------------------------------------------------+
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|=============================================================================|
|    0   N/A  N/A     32553      C   ...lson/skorchEnv/bin/python     8620MiB |
|    0   N/A  N/A     32600      C   python                            288MiB |
|    1   N/A  N/A     32615      C   ...lson/skorchEnv/bin/python     8618MiB |
+-----------------------------------------------------------------------------+
```

### Looking at the output

In [None]:
! cat docs/parallel-multigpu-13158959.out

In [None]:
! cat docs/parallel-multigpu-13158959.out | grep -v distributed.worker

In [None]:
! cat docs/parallel-multigpu-13158959.out | grep CV

## Results

### Mean fit time

|Model   |Serial (s)|Multi-GPU (s)|
|--------|------|---------|
|ResNet20|319.4 ± 0.9|340 ± 2|
|ResNet32|502.2 ± 0.1|524.8 ± 0.8|
|ResNet44|686.33 ± 0.03|711 ± 2|

### Total Wall-time

1.28 hours for our multi-GPU grid search vs 2.21 hours for serial, a **42% reduction in running time**.
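The headline figure follows directly from the two wall times, which also give the overall speedup:

```python
serial_h = 2.21     # serial grid search, hours
multi_gpu_h = 1.28  # two-GPU grid search, hours

reduction = (serial_h - multi_gpu_h) / serial_h
speedup = serial_h / multi_gpu_h
print(f'{reduction:.0%} reduction, {speedup:.2f}x speedup')  # 42% reduction, 1.73x speedup
```

With two GPUs the ideal would be a 2x speedup (a 50% reduction). Plausible contributors to the gap include the slightly longer per-fit times in the multi-GPU column, the 15 fits (3 candidates x 5 folds) not dividing evenly across two workers, and the final refit of the best model.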

## Other features

* You can also use [Palladium](https://palladium.readthedocs.io/en/latest/) for parallelism <br></br>
* Integrations with [Hugging Face](https://huggingface.co/) (`Accelerate`, `Tokenizers` and `Transformers`) <br></br>
* [Support for large language models](https://skorch.readthedocs.io/en/stable/llm.html), namely few-shot and zero-shot classifiers


<center><h1>Q&A</h1></center>

<center>Please feel free to reach out to me at collin.wilson@sharcnet.ca</center>