# INM427 Neural Computing Final Coursework
## By Ho Yin Tam

Import the library.

In [36]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.model_selection import GridSearchCV
from skorch import NeuralNetClassifier 
import time
from joblib import dump, load

Import the training data and convert it into tensor datatype.

In [37]:
X_train = pd.read_csv('X_train.csv')
y_train = pd.read_csv('y_train.csv')


X_train_tensor = torch.tensor(X_train.to_numpy()).float()

# Convert the dataframe of the target label into series
y_train = y_train.squeeze()
y_train_tensor = torch.tensor(y_train.to_numpy()).long()

# Display the shape and data type of the feature and target label
print('Shape of training data (features):', X_train_tensor.shape)
print('Shape of training data (target label):', y_train_tensor.shape)
print('Data type:', X_train_tensor.dtype)
print('Data type:', y_train_tensor.dtype)

Shape of training data (features): torch.Size([596, 11])
Shape of training data (target label): torch.Size([596])
Data type: torch.float32
Data type: torch.int64


Define a class for the multilayer perceptron with one hidden layer.

In [38]:
class MLP_onehiddenlayer(nn.Module):
    def __init__(self, input_size, output_size, hidden_size):
        super(MLP_onehiddenlayer, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim = 1)
     
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu1(x)
        x = self.fc2(x)
        x = self.softmax(x)
        return x
    
MLP_clf = NeuralNetClassifier(MLP_onehiddenlayer, module__input_size = 11, module__output_size = 2)

Set up the grid search parameters of hidden size, momentum and learning rate.

In [39]:
# Number of neurons in hidden layer: 10, 100, 1000
# Momentum: 0.7, 0.8, 0.9 
# Learning rate: 0.001, 0.01, 0.1 

gs_parameters = {'module__hidden_size': [10, 100, 1000],
                 'optimizer__momentum': [0.7, 0.8, 0.9],
                 'lr': [0.001, 0.01, 0.1]}
gs_parameters

{'module__hidden_size': [10, 100, 1000],
 'optimizer__momentum': [0.7, 0.8, 0.9],
 'lr': [0.001, 0.01, 0.1]}

Initiate the multilayer perceptron and perform grid search with 10-fold cross validation.

In [40]:
# Perform grid search with 10-fold cross validation
grid_search = GridSearchCV(estimator = MLP_clf, param_grid = gs_parameters, cv = 10, refit = True)

# Train the model and count the training time
start_time = time.time()
grid_search.fit(X_train_tensor, y_train_tensor)
end_time = time.time()

# Display the best score and best parameters
print('Best score is:', grid_search.best_score_)
print('Best parameters is:', grid_search.best_params_)
print('Training time:', end_time - start_time)

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7479[0m       [32m0.5370[0m        [35m0.7502[0m  0.0289
      2        [36m0.7472[0m       0.5370        [35m0.7492[0m  0.0302
      3        [36m0.7463[0m       0.5370        [35m0.7481[0m  0.0282
      4        [36m0.7453[0m       0.5370        [35m0.7469[0m  0.0277
      5        [36m0.7444[0m       0.5370        [35m0.7458[0m  0.0279
      6        [36m0.7434[0m       0.5370        [35m0.7447[0m  0.0249
      7        [36m0.7425[0m       0.5370        [35m0.7436[0m  0.0311
      8        [36m0.7416[0m       0.5370        [35m0.7426[0m  0.0291
      9        [36m0.7407[0m       0.5370        [35m0.7415[0m  0.0248
     10        [36m0.7398[0m       0.5370        [35m0.7405[0m  0.0280
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1      

      3        [36m0.6676[0m       [32m0.5741[0m        [35m0.6530[0m  0.0319
      4        [36m0.6656[0m       0.5741        [35m0.6512[0m  0.0295
      5        [36m0.6636[0m       0.5741        [35m0.6495[0m  0.0313
      6        [36m0.6617[0m       [32m0.5833[0m        [35m0.6478[0m  0.0809
      7        [36m0.6599[0m       0.5833        [35m0.6462[0m  0.0340
      8        [36m0.6581[0m       0.5833        [35m0.6445[0m  0.0299
      9        [36m0.6563[0m       0.5833        [35m0.6430[0m  0.0249
     10        [36m0.6547[0m       0.5833        [35m0.6415[0m  0.0292
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7345[0m       [32m0.5370[0m        [35m0.7354[0m  0.0256
      2        [36m0.7325[0m       0.5370        [35m0.7323[0m  0.0237
      3        [36m0.7296[0m       0.5370        [35m0.7287[0m  0.0247
      4        [36m0.7263[0m    

      6        [36m0.7496[0m       [32m0.3056[0m        [35m0.7636[0m  0.0720
      7        [36m0.7467[0m       [32m0.3241[0m        [35m0.7603[0m  0.0346
      8        [36m0.7438[0m       0.3241        [35m0.7571[0m  0.0309
      9        [36m0.7411[0m       0.3241        [35m0.7541[0m  0.0279
     10        [36m0.7385[0m       [32m0.3519[0m        [35m0.7511[0m  0.0314
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6770[0m       [32m0.5370[0m        [35m0.7001[0m  0.0267
      2        [36m0.6758[0m       0.5370        [35m0.6983[0m  0.0269
      3        [36m0.6740[0m       [32m0.5648[0m        [35m0.6961[0m  0.0335
      4        [36m0.6721[0m       [32m0.5833[0m        [35m0.6939[0m  0.0318
      5        [36m0.6701[0m       0.5833        [35m0.6917[0m  0.0302
      6        [36m0.6681[0m       0.5833        [35m0.6895[0m  0.0284
      7

      6        [36m0.6421[0m       [32m0.7593[0m        [35m0.6324[0m  0.0317
      7        [36m0.6400[0m       0.7593        [35m0.6300[0m  0.0331
      8        [36m0.6379[0m       [32m0.7685[0m        [35m0.6276[0m  0.0346
      9        [36m0.6358[0m       [32m0.7778[0m        [35m0.6252[0m  0.0287
     10        [36m0.6336[0m       [32m0.7870[0m        [35m0.6229[0m  0.0303
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6778[0m       [32m0.5185[0m        [35m0.6793[0m  0.0315
      2        [36m0.6769[0m       [32m0.5278[0m        [35m0.6780[0m  0.0432
      3        [36m0.6754[0m       [32m0.5833[0m        [35m0.6763[0m  0.0327
      4        [36m0.6735[0m       [32m0.6019[0m        [35m0.6742[0m  0.0250
      5        [36m0.6713[0m       [32m0.6389[0m        [35m0.6720[0m  0.0269
      6        [36m0.6690[0m       0.6389        [35

      5        [36m0.6593[0m       0.7685        [35m0.6359[0m  0.0321
      6        [36m0.6530[0m       0.7685        [35m0.6298[0m  0.0704
      7        [36m0.6472[0m       0.7593        [35m0.6241[0m  0.0349
      8        [36m0.6416[0m       0.7593        [35m0.6186[0m  0.0320
      9        [36m0.6364[0m       0.7500        [35m0.6135[0m  0.0331
     10        [36m0.6315[0m       0.7593        [35m0.6086[0m  0.0370
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7011[0m       [32m0.4815[0m        [35m0.7061[0m  0.0313
      2        [36m0.6913[0m       [32m0.4907[0m        [35m0.6935[0m  0.0388
      3        [36m0.6793[0m       [32m0.5556[0m        [35m0.6810[0m  0.0406
      4        [36m0.6679[0m       [32m0.5926[0m        [35m0.6696[0m  0.0385
      5        [36m0.6576[0m       [32m0.6481[0m        [35m0.6593[0m  0.0310
      6        

      4        [36m0.6773[0m       [32m0.5370[0m        [35m0.6786[0m  0.0388
      5        [36m0.6675[0m       0.5370        [35m0.6687[0m  0.0383
      6        [36m0.6587[0m       [32m0.5926[0m        [35m0.6599[0m  0.0424
      7        [36m0.6508[0m       [32m0.6389[0m        [35m0.6519[0m  0.0375
      8        [36m0.6437[0m       [32m0.6667[0m        [35m0.6447[0m  0.0285
      9        [36m0.6374[0m       [32m0.6852[0m        [35m0.6382[0m  0.0290
     10        [36m0.6316[0m       [32m0.7037[0m        [35m0.6323[0m  0.0285
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7077[0m       [32m0.3611[0m        [35m0.7140[0m  0.0340
      2        [36m0.7017[0m       [32m0.3889[0m        [35m0.7053[0m  0.0382
      3        [36m0.6932[0m       [32m0.4444[0m        [35m0.6955[0m  0.0330
      4        [36m0.6839[0m       [32m0.4907[0m   

      3        [36m0.6415[0m       0.7130        [35m0.6197[0m  0.0404
      4        [36m0.6327[0m       [32m0.7315[0m        [35m0.6092[0m  0.0420
      5        [36m0.6230[0m       [32m0.7593[0m        [35m0.5982[0m  0.0369
      6        [36m0.6130[0m       [32m0.7685[0m        [35m0.5872[0m  0.0320
      7        [36m0.6031[0m       [32m0.7870[0m        [35m0.5764[0m  0.0306
      8        [36m0.5934[0m       [32m0.8056[0m        [35m0.5660[0m  0.0318
      9        [36m0.5840[0m       0.8056        [35m0.5561[0m  0.0309
     10        [36m0.5751[0m       0.7963        [35m0.5467[0m  0.0300
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6943[0m       [32m0.5185[0m        [35m0.7045[0m  0.0354
      2        [36m0.6877[0m       [32m0.5648[0m        [35m0.6939[0m  0.0352
      3        [36m0.6769[0m       [32m0.6111[0m        [35m0.6805[

      2        [36m0.6312[0m       [32m0.7130[0m        [35m0.6080[0m  0.0610
      3        [36m0.6032[0m       [32m0.7315[0m        [35m0.5789[0m  0.0590
      4        [36m0.5791[0m       [32m0.7593[0m        [35m0.5545[0m  0.0650
      5        [36m0.5588[0m       [32m0.7685[0m        [35m0.5343[0m  0.0500
      6        [36m0.5418[0m       [32m0.7778[0m        [35m0.5176[0m  0.0385
      7        [36m0.5276[0m       [32m0.7870[0m        [35m0.5036[0m  0.0360
      8        [36m0.5156[0m       0.7870        [35m0.4918[0m  0.0319
      9        [36m0.5053[0m       [32m0.7963[0m        [35m0.4818[0m  0.0385
     10        [36m0.4965[0m       0.7870        [35m0.4733[0m  0.0330
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7474[0m       [32m0.4537[0m        [35m0.7150[0m  0.0484
      2        [36m0.6996[0m       [32m0.6019[0m        [35

      2        [36m0.5902[0m       [32m0.7593[0m        [35m0.5570[0m  0.0484
      3        [36m0.5617[0m       0.7593        [35m0.5267[0m  0.0501
      4        [36m0.5360[0m       0.7593        [35m0.5020[0m  0.0534
      5        [36m0.5147[0m       [32m0.7778[0m        [35m0.4829[0m  0.0381
      6        [36m0.4976[0m       [32m0.7963[0m        [35m0.4682[0m  0.0372
      7        [36m0.4840[0m       0.7963        [35m0.4570[0m  0.0429
      8        [36m0.4730[0m       [32m0.8056[0m        [35m0.4482[0m  0.0499
      9        [36m0.4641[0m       [32m0.8241[0m        [35m0.4413[0m  0.0440
     10        [36m0.4567[0m       0.8241        [35m0.4357[0m  0.0417
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7439[0m       [32m0.4630[0m        [35m0.7196[0m  0.0357
      2        [36m0.6956[0m       [32m0.6204[0m        [35m0.6526[0m  0.035

      3        [36m0.5621[0m       [32m0.8148[0m        [35m0.5029[0m  0.0421
      4        [36m0.5244[0m       [32m0.8241[0m        [35m0.4618[0m  0.0406
      5        [36m0.4927[0m       [32m0.8519[0m        [35m0.4329[0m  0.0333
      6        [36m0.4695[0m       [32m0.8611[0m        [35m0.4145[0m  0.0333
      7        [36m0.4536[0m       [32m0.8796[0m        [35m0.4022[0m  0.0383
      8        [36m0.4426[0m       0.8796        [35m0.3936[0m  0.0430
      9        [36m0.4345[0m       0.8704        [35m0.3878[0m  0.0390
     10        [36m0.4284[0m       0.8704        [35m0.3841[0m  0.0339
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6844[0m       [32m0.7315[0m        [35m0.6389[0m  0.0293
      2        [36m0.6075[0m       0.7130        [35m0.5663[0m  0.0358
      3        [36m0.5563[0m       [32m0.7407[0m        [35m0.5210[0m  0.038

      3        [36m0.7454[0m       [32m0.4537[0m        [35m0.7326[0m  0.0325
      4        [36m0.7205[0m       [32m0.5000[0m        [35m0.7091[0m  0.0334
      5        [36m0.6996[0m       [32m0.5741[0m        [35m0.6889[0m  0.0300
      6        [36m0.6816[0m       [32m0.6389[0m        [35m0.6705[0m  0.0295
      7        [36m0.6656[0m       0.6389        [35m0.6535[0m  0.0314
      8        [36m0.6508[0m       [32m0.7130[0m        [35m0.6374[0m  0.0305
      9        [36m0.6367[0m       [32m0.7593[0m        [35m0.6217[0m  0.0284
     10        [36m0.6234[0m       0.7500        [35m0.6065[0m  0.0270
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7421[0m       [32m0.4630[0m        [35m0.7141[0m  0.0235
      2        [36m0.7267[0m       [32m0.5000[0m        [35m0.6978[0m  0.0264
      3        [36m0.7102[0m       [32m0.5463[0m        [35

      2        [36m0.7103[0m       [32m0.4630[0m        [35m0.7164[0m  0.0292
      3        [36m0.7018[0m       [32m0.4722[0m        [35m0.7062[0m  0.0307
      4        [36m0.6929[0m       [32m0.5463[0m        [35m0.6959[0m  0.0291
      5        [36m0.6841[0m       [32m0.5926[0m        [35m0.6857[0m  0.0303
      6        [36m0.6753[0m       [32m0.6111[0m        [35m0.6755[0m  0.0591
      7        [36m0.6663[0m       [32m0.6574[0m        [35m0.6652[0m  0.0561
      8        [36m0.6573[0m       [32m0.6667[0m        [35m0.6548[0m  0.0421
      9        [36m0.6480[0m       [32m0.6944[0m        [35m0.6442[0m  0.0470
     10        [36m0.6385[0m       [32m0.7130[0m        [35m0.6336[0m  0.0480
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7225[0m       [32m0.5370[0m        [35m0.6956[0m  0.0436
      2        [36m0.6999[0m       [32m0.54

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6347[0m       [32m0.7778[0m        [35m0.6170[0m  0.0350
      2        [36m0.6264[0m       0.7778        [35m0.6041[0m  0.0286
      3        [36m0.6138[0m       [32m0.7870[0m        [35m0.5888[0m  0.0276
      4        [36m0.5994[0m       [32m0.8056[0m        [35m0.5727[0m  0.0299
      5        [36m0.5845[0m       0.8056        [35m0.5566[0m  0.0250
      6        [36m0.5697[0m       0.8056        [35m0.5408[0m  0.0273
      7        [36m0.5551[0m       0.8056        [35m0.5256[0m  0.0339
      8        [36m0.5409[0m       0.7963        [35m0.5112[0m  0.0309
      9        [36m0.5277[0m       0.7778        [35m0.4980[0m  0.0305
     10        [36m0.5154[0m       0.7778        [35m0.4860[0m  0.0235
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  --

     10        [36m0.5249[0m       0.8148        [35m0.4723[0m  0.0522
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7196[0m       [32m0.5833[0m        [35m0.6866[0m  0.0357
      2        [36m0.7012[0m       [32m0.6111[0m        [35m0.6606[0m  0.0364
      3        [36m0.6730[0m       [32m0.6574[0m        [35m0.6302[0m  0.0371
      4        [36m0.6427[0m       [32m0.7222[0m        [35m0.6008[0m  0.0268
      5        [36m0.6146[0m       [32m0.7407[0m        [35m0.5740[0m  0.0291
      6        [36m0.5900[0m       [32m0.7593[0m        [35m0.5511[0m  0.0267
      7        [36m0.5689[0m       [32m0.7685[0m        [35m0.5305[0m  0.0282
      8        [36m0.5500[0m       [32m0.7778[0m        [35m0.5116[0m  0.0330
      9        [36m0.5328[0m       [32m0.7870[0m        [35m0.4945[0m  0.0315
     10        [36m0.5166[0m       [32m0.8056[0m   

      8        [36m0.4697[0m       0.7963        [35m0.4442[0m  0.0321
      9        [36m0.4592[0m       0.7963        [35m0.4354[0m  0.0623
     10        [36m0.4505[0m       0.8056        [35m0.4284[0m  0.0355
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6627[0m       [32m0.8056[0m        [35m0.6195[0m  0.0271
      2        [36m0.6279[0m       0.7963        [35m0.5832[0m  0.0303
      3        [36m0.5976[0m       0.8056        [35m0.5556[0m  0.0295
      4        [36m0.5745[0m       [32m0.8241[0m        [35m0.5334[0m  0.0327
      5        [36m0.5550[0m       0.8241        [35m0.5149[0m  0.0324
      6        [36m0.5380[0m       0.8148        [35m0.4992[0m  0.0303
      7        [36m0.5234[0m       0.8148        [35m0.4859[0m  0.0285
      8        [36m0.5108[0m       0.8148        [35m0.4745[0m  0.0296
      9        [36m0.4998[0m       0.8148

      8        [36m0.4768[0m       0.7778        [35m0.4638[0m  0.0404
      9        [36m0.4657[0m       0.7685        [35m0.4547[0m  0.0393
     10        [36m0.4563[0m       0.7685        [35m0.4473[0m  0.0414
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6842[0m       [32m0.7037[0m        [35m0.6197[0m  0.0376
      2        [36m0.6071[0m       [32m0.7963[0m        [35m0.5544[0m  0.0396
      3        [36m0.5598[0m       0.7685        [35m0.5161[0m  0.0479
      4        [36m0.5292[0m       0.7778        [35m0.4839[0m  0.0660
      5        [36m0.5005[0m       [32m0.8241[0m        [35m0.4598[0m  0.0319
      6        [36m0.4781[0m       0.8241        [35m0.4435[0m  0.0285
      7        [36m0.4619[0m       0.8241        [35m0.4316[0m  0.0365
      8        [36m0.4498[0m       0.8241        [35m0.4226[0m  0.0319
      9        [36m0.4406[0m    

      8        [36m0.4265[0m       0.8148        [35m0.3933[0m  0.0323
      9        [36m0.4176[0m       [32m0.8333[0m        0.3937  0.0354
     10        [36m0.4106[0m       0.8241        0.3948  0.0364
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7315[0m       [32m0.6667[0m        [35m0.6694[0m  0.0329
      2        [36m0.6322[0m       [32m0.7037[0m        [35m0.5856[0m  0.0349
      3        [36m0.5782[0m       [32m0.7315[0m        [35m0.5306[0m  0.0302
      4        [36m0.5296[0m       [32m0.7778[0m        [35m0.4713[0m  0.0279
      5        [36m0.4752[0m       [32m0.8519[0m        [35m0.4376[0m  0.0343
      6        [36m0.4439[0m       0.8333        [35m0.4228[0m  0.0315
      7        [36m0.4272[0m       0.8426        [35m0.4113[0m  0.0315
      8        [36m0.4152[0m       0.8333        [35m0.4043[0m  0.0280
      9        [36m0.409

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.5927[0m       [32m0.7963[0m        [35m0.4827[0m  0.0720
      2        [36m0.4770[0m       [32m0.8241[0m        [35m0.4197[0m  0.0438
      3        [36m0.4345[0m       0.8148        [35m0.4061[0m  0.0536
      4        [36m0.4194[0m       0.8241        [35m0.4031[0m  0.0376
      5        [36m0.4073[0m       0.8148        0.4032  0.0394
      6        [36m0.3986[0m       0.8056        0.4040  0.0342
      7        [36m0.3918[0m       0.8056        0.4048  0.0329
      8        [36m0.3860[0m       0.8056        0.4053  0.0403
      9        [36m0.3811[0m       0.8056        0.4057  0.0320
     10        [36m0.3768[0m       0.8056        0.4060  0.0313
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6735[0m       [32m0.8241[0m      

     10        [36m0.3642[0m       0.7963        0.4058  0.0522
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6293[0m       [32m0.8056[0m        [35m0.4898[0m  0.0530
      2        [36m0.4751[0m       [32m0.8241[0m        [35m0.4058[0m  0.0431
      3        [36m0.4274[0m       [32m0.8333[0m        [35m0.3939[0m  0.0460
      4        [36m0.4083[0m       0.8241        0.3968  0.0490
      5        [36m0.3950[0m       0.8241        0.4030  0.0469
      6        [36m0.3857[0m       0.8241        0.4094  0.0462
      7        [36m0.3785[0m       0.8241        0.4141  0.0483
      8        [36m0.3738[0m       0.8241        0.4163  0.0380
      9        [36m0.3699[0m       0.8241        0.4165  0.0313
     10        [36m0.3660[0m       0.8148        0.4157  0.0370
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------ 

     10        [36m0.3481[0m       0.8241        0.4268  0.0466
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6452[0m       [32m0.7778[0m        [35m0.4762[0m  0.0422
      2        [36m0.4754[0m       [32m0.8241[0m        [35m0.4050[0m  0.0361
      3        [36m0.4372[0m       0.8241        0.4150  0.0314
      4        [36m0.4274[0m       0.8241        0.4291  0.0330
      5        [36m0.4044[0m       0.8056        0.4394  0.0449
      6        [36m0.3774[0m       0.8148        0.4509  0.0410
      7        [36m0.3644[0m       0.8148        0.4603  0.0364
      8        [36m0.3618[0m       0.8148        0.4587  0.0370
      9        [36m0.3567[0m       0.8148        0.4465  0.0339
     10        [36m0.3489[0m       0.8148        0.4325  0.0363
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1   

      9        [36m0.3803[0m       0.7963        0.4035  0.0373
     10        [36m0.3764[0m       0.7963        0.4051  0.0337
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7082[0m       [32m0.7130[0m        [35m0.6637[0m  0.0249
      2        [36m0.6524[0m       0.6667        [35m0.6241[0m  0.0230
      3        [36m0.6108[0m       [32m0.7778[0m        [35m0.5658[0m  0.0279
      4        [36m0.5549[0m       [32m0.7870[0m        [35m0.4977[0m  0.0604
      5        [36m0.4981[0m       [32m0.8056[0m        [35m0.4480[0m  0.0339
      6        [36m0.4590[0m       0.7870        [35m0.4244[0m  0.0283
      7        [36m0.4356[0m       0.8056        [35m0.4173[0m  0.0328
      8        [36m0.4195[0m       [32m0.8148[0m        0.4177  0.0256
      9        [36m0.4075[0m       0.8056        0.4204  0.0291
     10        [36m0.3982[0m       0.8056        0

      2        [36m0.6607[0m       [32m0.7130[0m        [35m0.6232[0m  0.0311
      3        [36m0.5977[0m       [32m0.7685[0m        [35m0.5576[0m  0.0343
      4        [36m0.5312[0m       [32m0.7963[0m        [35m0.4911[0m  0.0293
      5        [36m0.4710[0m       [32m0.8148[0m        [35m0.4477[0m  0.0300
      6        [36m0.4311[0m       0.8148        [35m0.4305[0m  0.0357
      7        [36m0.4093[0m       0.8148        0.4310  0.0338
      8        [36m0.3983[0m       [32m0.8333[0m        0.4379  0.0342
      9        [36m0.3910[0m       [32m0.8426[0m        0.4438  0.0375
     10        [36m0.3850[0m       0.8426        0.4471  0.0315
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7432[0m       [32m0.5463[0m        [35m0.6752[0m  0.0269
      2        [36m0.6611[0m       [32m0.6204[0m        [35m0.6221[0m  0.0303
      3        [36m0.607

      7        [36m0.3923[0m       0.8519        0.4080  0.0302
      8        [36m0.3727[0m       0.8241        0.4175  0.0279
      9        [36m0.3672[0m       0.8241        0.4166  0.0300
     10        [36m0.3614[0m       0.8426        0.3951  0.0319
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6818[0m       [32m0.6852[0m        [35m0.6440[0m  0.0331
      2        [36m0.6148[0m       [32m0.7222[0m        [35m0.5396[0m  0.0330
      3        [36m0.5019[0m       [32m0.7870[0m        [35m0.4401[0m  0.0263
      4        [36m0.4386[0m       [32m0.8056[0m        [35m0.4279[0m  0.0295
      5        [36m0.4190[0m       [32m0.8241[0m        0.4405  0.0261
      6        [36m0.3873[0m       0.7963        0.4525  0.0253
      7        [36m0.3771[0m       0.8056        0.4543  0.0335
      8        [36m0.3736[0m       0.8148        0.4419  0.0343
      9      

      5        [36m0.3853[0m       0.8611        0.3830  0.0900
      6        [36m0.3771[0m       0.8519        0.3855  0.0361
      7        [36m0.3701[0m       0.8519        0.3869  0.0311
      8        [36m0.3639[0m       0.8519        0.3883  0.0293
      9        [36m0.3586[0m       0.8519        0.3893  0.0295
     10        [36m0.3538[0m       0.8426        0.3903  0.0316
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6392[0m       [32m0.8241[0m        [35m0.4849[0m  0.0307
      2        [36m0.4801[0m       [32m0.8333[0m        [35m0.4047[0m  0.0300
      3        [36m0.4274[0m       0.8241        [35m0.3976[0m  0.0302
      4        [36m0.4061[0m       0.8241        0.4044  0.0301
      5        [36m0.3925[0m       0.8241        0.4115  0.0300
      6        [36m0.3829[0m       0.8148        0.4152  0.0303
      7        [36m0.3755[0m       0.8148      

      4        [36m0.4137[0m       0.8426        0.4099  0.0285
      5        [36m0.3936[0m       0.8333        0.4126  0.0289
      6        [36m0.3817[0m       0.8241        [35m0.4044[0m  0.0309
      7        [36m0.3699[0m       0.8333        [35m0.3954[0m  0.0361
      8        [36m0.3607[0m       0.8333        [35m0.3911[0m  0.0280
      9        [36m0.3531[0m       0.8333        [35m0.3910[0m  0.0306
     10        [36m0.3468[0m       0.8241        0.3919  0.0274
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6431[0m       [32m0.7685[0m        [35m0.4895[0m  0.0229
      2        [36m0.4851[0m       [32m0.8241[0m        [35m0.3939[0m  0.0241
      3        [36m0.4226[0m       0.8148        0.4003  0.0193
      4        [36m0.3942[0m       0.8056        0.4193  0.0229
      5        [36m0.3720[0m       0.8056        0.4374  0.0275
      6        [36m0.

      3        [36m0.4152[0m       0.8426        0.4444  0.0306
      4        [36m0.4151[0m       0.8241        0.4594  0.0602
      5        [36m0.3751[0m       0.8333        0.4489  0.0291
      6        [36m0.3527[0m       0.8241        [35m0.4163[0m  0.0279
      7        [36m0.3412[0m       0.8333        [35m0.3943[0m  0.0303
      8        [36m0.3313[0m       0.8056        0.3997  0.0320
      9        [36m0.3224[0m       0.8148        0.4159  0.0290
     10        [36m0.3101[0m       0.8148        0.4292  0.0264
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.6583[0m       [32m0.7963[0m        [35m0.4545[0m  0.0270
      2        [36m0.4802[0m       [32m0.8241[0m        [35m0.3702[0m  0.0270
      3        [36m0.4269[0m       0.8148        0.3914  0.0269
      4        [36m0.3995[0m       0.7963        0.4300  0.0269
      5        [36m0.3617[0m       0.7

      2        0.5590       [32m0.7778[0m        0.7214  0.0414
      3        0.5343       0.7500        0.8264  0.0343
      4        0.5762       0.7593        0.7726  0.0335
      5        0.5376       0.7500        [35m0.6111[0m  0.0319
      6        [36m0.4834[0m       0.7778        [35m0.5140[0m  0.0309
      7        [36m0.4436[0m       [32m0.8148[0m        [35m0.4653[0m  0.0420
      8        [36m0.4143[0m       [32m0.8519[0m        [35m0.4442[0m  0.0449
      9        [36m0.3917[0m       [32m0.8611[0m        [35m0.4338[0m  0.0303
     10        [36m0.3727[0m       0.8519        [35m0.4286[0m  0.0347
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.5808[0m       [32m0.7500[0m        [35m0.7512[0m  0.0290
      2        0.6822       0.7315        0.9563  0.0291
      3        0.7305       [32m0.7778[0m        [35m0.7398[0m  0.0301
      4        [36m0.

  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.9493[0m       [32m0.7778[0m        [35m0.7449[0m  0.0273
      2        [36m0.9159[0m       [32m0.8519[0m        [35m0.5402[0m  0.0633
      3        [36m0.5443[0m       0.7870        0.6503  0.0432
      4        [36m0.5198[0m       0.8333        [35m0.4743[0m  0.0410
      5        [36m0.4121[0m       0.8333        [35m0.4213[0m  0.0304
      6        [36m0.3460[0m       0.8148        0.4484  0.0326
      7        [36m0.3253[0m       0.8426        0.4310  0.0304
      8        [36m0.3110[0m       0.8333        0.4495  0.0345
      9        0.3115       0.8148        0.4582  0.0322
     10        0.3119       0.8056        0.4575  0.0349
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.7094[0m       [32m0.7685[0m        [35m1.1143[0m 

      9        0.3754       0.8241        0.4036  0.0320
     10        [36m0.3125[0m       0.8611        0.4138  0.0326
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  ------------  ------
      1        [36m0.5542[0m       [32m0.7500[0m        [35m0.7831[0m  0.0329
      2        0.7476       [32m0.7963[0m        [35m0.6924[0m  0.0428
      3        0.7498       0.7037        1.0222  0.0368
      4        0.8132       [32m0.8889[0m        [35m0.5018[0m  0.0317
      5        0.6402       0.8426        0.6005  0.0328
      6        0.5932       0.7222        0.9012  0.0334
      7        [36m0.5535[0m       0.7963        0.6396  0.0289
      8        [36m0.4754[0m       0.8611        0.5076  0.0282
      9        [36m0.4024[0m       0.7778        0.5728  0.0279
     10        [36m0.3323[0m       0.7685        0.5692  0.0299
  epoch    train_loss    valid_acc    valid_loss     dur
-------  ------------  -----------  -

Extract the mean score of all the grid search parameters. 

In [41]:
grid_search_result = pd.DataFrame(grid_search.cv_results_)
grid_search_result = grid_search_result[['param_module__hidden_size', 'param_lr', 'param_optimizer__momentum', 'mean_test_score']].sort_values(by = 'mean_test_score', ascending = False)
grid_search_result.columns = ['Number of neurons in hidden layer', 'Learning rate', 'Momentum', 'Score']
grid_search_result

Unnamed: 0,Number of neurons in hidden layer,Learning rate,Momentum,Score
22,100,0.1,0.8,0.837232
20,10,0.1,0.9,0.832175
21,100,0.1,0.7,0.830508
24,1000,0.1,0.7,0.830452
23,100,0.1,0.9,0.828814
15,1000,0.01,0.7,0.828785
16,1000,0.01,0.8,0.828729
17,1000,0.01,0.9,0.827147
25,1000,0.1,0.8,0.827119
18,10,0.1,0.7,0.82709


Save the best trained model with the best score for further testing.

In [42]:
best_mlp_model = grid_search.best_estimator_
dump(best_mlp_model, 'best_mlp_model.joblib')

['best_mlp_model.joblib']