## Skorch is a scikit-learn-compatible neural network library that wraps PyTorch. It lets you treat PyTorch modules as ordinary scikit-learn estimators. In particular, I find it very useful for running a grid search over PyTorch models.
## In this notebook, I show how to easily run a grid search for a PyTorch model wrapped in Skorch, initializing the weights from a pre-trained model at each fit.
### Although initializing from pre-defined weights is not recommended in general, it can be useful for, e.g., exhaustive fine-tuning.


### First, let's import the libraries and download the MNIST dataset.

In [83]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV

import skorch
from skorch import NeuralNet
from skorch import NeuralNetClassifier

import pickle

import numpy as np

### The dataset is not that large, so let's store it in memory.

In [85]:
# Device configuration: use the GPU when PyTorch can see one, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Device: ', device)


# Download MNIST to the local drive. A new folder "data" is created in the
# current directory to store it. download=True is passed for both splits so
# each line works on its own (it is a no-op when the files already exist).
train_dataset = torchvision.datasets.MNIST(root='./data',
                                           train=True,
                                           transform=transforms.ToTensor(),
                                           download=True)

test_dataset = torchvision.datasets.MNIST(root='./data',
                                          train=False,
                                          transform=transforms.ToTensor(),
                                          download=True)

# batch_size equals the dataset length, so each loader yields exactly one
# batch containing the whole split.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=len(train_dataset),
                                           shuffle=True)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=len(test_dataset),
                                          shuffle=False)

# Pull that single batch instead of looping over the loader, which would
# also leak the loop variables into the notebook namespace.
X_train, y_train = next(iter(train_loader))  # shuffled training set
X_test, y_test = next(iter(test_loader))     # test set in original order

Device:  cuda


### Let's build a core NN module with PyTorch
#### Two hidden layers, with ReLU activations in the hidden layers and a log-softmax activation in the output layer. Let's also use dropout.

In [38]:
class MyModule(nn.Module):
    """Two-hidden-layer MLP that returns per-class log-probabilities.

    Layer order: Linear -> Dropout -> ReLU -> Linear -> Dropout -> ReLU
    -> Linear -> LogSoftmax.

    Parameters
    ----------
    input_size : int
        Number of input features; every sample is flattened to this length.
    hidden_size1, hidden_size2 : int
        Widths of the first and second hidden layers.
    num_classes : int
        Number of output classes.
    dropout : float
        Dropout probability applied after each hidden linear layer.
    """

    def __init__(self, input_size, hidden_size1, hidden_size2, num_classes, dropout):
        super(MyModule, self).__init__()

        # Kept so forward() flattens to the configured size instead of a
        # hard-coded 784. A plain int attribute is not part of state_dict,
        # so previously saved weights still load.
        self.input_size = input_size

        self.fc1 = nn.Linear(input_size, hidden_size1)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size1, hidden_size2)
        self.fc3 = nn.Linear(hidden_size2, num_classes)
        self.logsoftmax = nn.LogSoftmax(dim=-1)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return log-probabilities of shape (num_samples, num_classes)."""
        # Flatten each sample, e.g. (N, 1, 28, 28) images -> (N, 784) when
        # input_size == 784. reshape (unlike view) also accepts
        # non-contiguous inputs.
        x = x.reshape(-1, self.input_size)
        out = self.dropout(self.fc1(x))
        out = self.relu(out)
        out = self.dropout(self.fc2(out))
        out = self.relu(out)
        out = self.fc3(out)
        out = self.logsoftmax(out)

        return out

### Here we subclass Skorch's NeuralNetClassifier with two modifications.
#### 1) We slightly modify get_loss to add L1 regularization.
#### 2) We modify initialize_module so we can use pre-trained initialization in GridSearchCV if needed.
####    *reinit_from_pretrain*   indicates if the model should be initialized from a pre-trained one.
####    *pretrained_nn*   contains the path to the file with the pre-trained model's saved parameters.

In [97]:
class RegularizedNet(NeuralNetClassifier):
    """NeuralNetClassifier with an L1 penalty and optional pre-trained initialization.

    Parameters
    ----------
    lambda1 : float, default=0.01
        Weight of the L1 penalty added to the criterion loss: the sum of the
        absolute values of all module parameters (biases included).
        ``0`` disables the penalty.
    reinit_from_pretrain : bool, default=False
        If True, the module is initialized from the weights stored at
        ``pretrained_nn`` every time it is (re-)initialized. Otherwise the
        model is randomly re-initialized at each fit.
    pretrained_nn : str or None, default=None
        Path to a file holding the module's ``state_dict`` (e.g. written by
        ``save_params(f_params=...)``). Required when
        ``reinit_from_pretrain`` is True.
    *args, **kwargs
        Forwarded unchanged to ``NeuralNetClassifier``.
    """

    def __init__(self, *args, lambda1=0.01, reinit_from_pretrain=False, pretrained_nn=None, **kwargs):
        super().__init__(*args, **kwargs)
        self.lambda1 = lambda1
        self.reinit_from_pretrain = reinit_from_pretrain
        self.pretrained_nn = pretrained_nn

    def get_loss(self, y_pred, y_true, X=None, training=False):
        """Return the criterion loss plus ``lambda1`` times the L1 norm of all parameters."""
        loss = super().get_loss(y_pred, y_true, X=X, training=training)
        # Skip building the penalty graph entirely when regularization is off.
        if self.lambda1:
            l1_norm = sum(w.abs().sum() for w in self.module_.parameters())
            loss = loss + self.lambda1 * l1_norm
        return loss

    def initialize_module(self):
        """Build the module, then optionally overwrite its weights from ``pretrained_nn``.

        Raises
        ------
        ValueError
            If ``reinit_from_pretrain`` is True but ``pretrained_nn`` is None.
        """
        super().initialize_module()

        # Initialization from a pre-trained model.
        if self.reinit_from_pretrain:
            if self.pretrained_nn is None:
                raise ValueError(
                    "reinit_from_pretrain=True requires pretrained_nn to be the "
                    "path of a saved state_dict."
                )
            # map_location lets weights saved on a GPU be restored on a
            # CPU-only machine (and vice versa). load_state_dict then copies
            # them into the freshly built module.
            state_dict = torch.load(self.pretrained_nn, map_location=self.device)
            self.module_.load_state_dict(state_dict)
        return self

#### Here we wrap Skorch around the PyTorch module.

In [88]:
##### Below I define callbacks for printing train accuracy and for the early-stopping criterion.
# train_acc: scores accuracy on the training data after every epoch.
train_acc = skorch.callbacks.EpochScoring(scoring=accuracy_score, on_train=True,
                                          name='train_acc', lower_is_better=False)

# Stop training once validation accuracy stops improving.
early_stop = skorch.callbacks.EarlyStopping(monitor='valid_acc', lower_is_better=False)

callbacks = [train_acc, early_stop]


# MyModule already outputs log-probabilities. CrossEntropyLoss applies
# log_softmax again, which is idempotent, so the loss matches NLLLoss on
# these outputs.
# The device detected at the top of the notebook is used instead of a
# hard-coded 'cuda', so this also runs on CPU-only machines.
new_net = RegularizedNet(module=MyModule, criterion=torch.nn.CrossEntropyLoss, device=str(device),
                         optimizer=torch.optim.SGD, lr=0.2, lambda1=0, module__dropout=0.2,
                         optimizer__weight_decay=0.0, max_epochs=45, callbacks=callbacks, batch_size=256,
                         module__input_size=784, module__hidden_size1=128, module__hidden_size2=64,
                         module__num_classes=10)

#### Let's train the NN.

In [89]:
# Train on the in-memory training set. skorch holds part of it out internally
# to report valid_acc / valid_loss.
new_net.fit(X_train, y=y_train)
# Per-class scores for the test images (predict() would give hard labels).
# NOTE(review): recent skorch versions apply softmax here because the
# criterion is CrossEntropyLoss; older versions return the raw module output
# (log-probabilities) - confirm for the installed version.
y_pred_probs = new_net.predict_proba(X_test)

  epoch    train_acc    train_loss    valid_acc    valid_loss     dur
-------  -----------  ------------  -----------  ------------  ------
      1       [36m0.7355[0m        [32m0.9199[0m       [35m0.8952[0m        [31m0.3593[0m  0.6321
      2       [36m0.8920[0m        [32m0.3687[0m       [35m0.9203[0m        [31m0.2652[0m  0.6481
      3       [36m0.9165[0m        [32m0.2868[0m       [35m0.9334[0m        [31m0.2170[0m  0.6802
      4       [36m0.9305[0m        [32m0.2378[0m       [35m0.9419[0m        [31m0.1878[0m  0.6561
      5       [36m0.9399[0m        [32m0.2070[0m       [35m0.9538[0m        [31m0.1579[0m  0.6371
      6       [36m0.9460[0m        [32m0.1831[0m       [35m0.9585[0m        [31m0.1391[0m  0.6752
      7       [36m0.9522[0m        [32m0.1634[0m       [35m0.9610[0m        [31m0.1282[0m  0.6351
      8       [36m0.9551[0m        [32m0.1492[0m       [35m0.9648[0m        [31m0.1192[0m  0.6792
      9     

#### The test accuracy is 97.95%. We will use this pre-trained model for weights initialization during GridSearchCV.

In [100]:
y_pred = new_net.predict(X_test)
# Test-set accuracy in percent: share of predicted labels equal to the true ones.
100 * accuracy_score(y_test.cpu().numpy(), y_pred)

97.95

#### Save the model

In [92]:
# Persist the trained module weights; the grid-search net below reloads them via pretrained_nn.
pretrained_path = 'pre-trained-NN.pkl'
new_net.save_params(f_params=pretrained_path)

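### Now, let's run a grid search with scikit-learn, initializing the weights from the pre-trained model at each fit.
### Keep in mind that we can no longer change the number of neurons in each layer, because the model is initialized from a pre-trained one. Note also that the pre-trained model has already seen most of the training data, so the cross-validation scores below are optimistic; the held-out test accuracy is the honest estimate.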

In [93]:
# Same architecture as the pre-trained net: the layer sizes must match the
# saved weights. reinit_from_pretrain makes every fit start from them.
# The device detected at the top of the notebook is used instead of a
# hard-coded 'cuda', so the search also runs on CPU-only machines.
new_net_gs = RegularizedNet(module=MyModule, criterion=torch.nn.CrossEntropyLoss, device=str(device),
                            optimizer=torch.optim.SGD, lr=0.2, lambda1=0.0, module__dropout=0.0,
                            optimizer__weight_decay=0, max_epochs=45, callbacks=callbacks, batch_size=256,
                            module__input_size=784, module__hidden_size1=128, module__hidden_size2=64,
                            module__num_classes=10,
                            reinit_from_pretrain=True, pretrained_nn='pre-trained-NN.pkl')


# Every candidate overrides the lr / lambda1 values given above.
grid = {
    'lambda1': [0.0, 0.0001],
    'lr': [0.01, 0.05],
}


# NOTE: the pre-trained weights were fit on X_train (the same data split into
# CV folds here), so most validation-fold samples were seen during
# pre-training. best_score_ is therefore optimistic; rely on the test set.
gs = GridSearchCV(new_net_gs, grid, refit=True, cv=3, verbose=3)

gs.fit(X_train, y_train)

# Report best parameters
print(gs.best_score_, gs.best_params_)

Fitting 3 folds for each of 4 candidates, totalling 12 fits
  epoch    train_acc    train_loss    valid_acc    valid_loss     dur
-------  -----------  ------------  -----------  ------------  ------
      1       [36m0.9961[0m        [32m0.0152[0m       [35m0.9971[0m        [31m0.0150[0m  0.4121
      2       [36m0.9965[0m        [32m0.0141[0m       [35m0.9975[0m        [31m0.0148[0m  0.4521
      3       [36m0.9967[0m        [32m0.0135[0m       0.9975        [31m0.0146[0m  0.4121
      4       [36m0.9968[0m        [32m0.0130[0m       0.9975        [31m0.0145[0m  0.4191
      5       [36m0.9970[0m        [32m0.0127[0m       0.9975        [31m0.0144[0m  0.4091
      6       [36m0.9972[0m        [32m0.0124[0m       [35m0.9976[0m        [31m0.0143[0m  0.4521
      7       [36m0.9974[0m        [32m0.0121[0m       0.9976        [31m0.0142[0m  0.4111
      8       [36m0.9975[0m        [32m0.0119[0m       0.9975        [31m0.0141[0m  0.41

      3       [36m0.9950[0m        [32m0.4314[0m       0.9791        [31m0.4770[0m  0.4551
      4       [36m0.9952[0m        [32m0.4295[0m       0.9789        [31m0.4756[0m  0.4151
      5       [36m0.9954[0m        [32m0.4278[0m       0.9792        [31m0.4742[0m  0.4121
      6       [36m0.9956[0m        [32m0.4261[0m       0.9791        [31m0.4728[0m  0.4131
Stopping since valid_acc has not improved in the last 5 epochs.
[CV 2/3] END ........................lambda1=0.0001, lr=0.01; total time=   3.2s
  epoch    train_acc    train_loss    valid_acc    valid_loss     dur
-------  -----------  ------------  -----------  ------------  ------
      1       [36m0.9940[0m        [32m0.4371[0m       [35m0.9792[0m        [31m0.4806[0m  0.4261
      2       [36m0.9943[0m        [32m0.4345[0m       0.9792        [31m0.4787[0m  0.4161
      3       [36m0.9947[0m        [32m0.4323[0m       0.9791        [31m0.4771[0m  0.4601
      4       [36m0.9949[

### Final test accuracy.

In [99]:
best_net = gs.best_estimator_
y_pred = best_net.predict(X_test)
# Test-set accuracy (in percent) of the refit best estimator.
100 * accuracy_score(y_test.cpu().numpy(), y_pred)

98.03

In [82]:
#gs.cv_results_