# Projeto 3: Classificação binária brest cancer com tuning dos parâmetros

## Etapa 1: Importação das bibliotecas

In [1]:
!pip install skorch

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting skorch
  Downloading skorch-0.12.1-py3-none-any.whl (193 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.7/193.7 KB[0m [31m2.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: skorch
Successfully installed skorch-0.12.1


In [2]:
import pandas as pd
import numpy as np
import sklearn
import skorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.model_selection import GridSearchCV
from skorch import NeuralNetBinaryClassifier

In [3]:
torch.__version__, skorch.__version__, sklearn.__version__

('1.13.1+cu116', '0.12.1', '1.2.1')

## Etapa 2: Base de dados

In [4]:
np.random.seed(123)
torch.manual_seed(123)

<torch._C.Generator at 0x7f21789f1b50>

In [5]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [6]:
previsores = pd.read_csv('/content/drive/MyDrive/Deep Learing de A à Z com PyTorch/Bases/entradas_breast.csv')
classe = pd.read_csv('/content/drive/MyDrive/Deep Learing de A à Z com PyTorch/Bases/saidas_breast.csv')

In [7]:
previsores = np.array(previsores, dtype = 'float32')
classe = np.array(classe, dtype = 'float32').squeeze(1)

In [8]:
previsores.shape

(569, 30)

In [9]:
classe.shape

(569,)

## Etapa 3: Classe para estrutura da rede neural

Na versão atual do Skorch, os resultados da rede neural devem ser retornados sem ativação, ou seja, sem a camada sigmoide no final. Com isto, a função de custo deve ser `BCEWithLogitsLoss`.

In [10]:
class classificador_torch(nn.Module):
  def __init__(self, activation, neurons, initializer):
    super().__init__()
    # 30 -> 16 -> 16 -> 1
    self.dense0 = nn.Linear(30, neurons)
    initializer(self.dense0.weight)
    self.activation0 = activation
    self.dense1 = nn.Linear(neurons, neurons)
    initializer(self.dense1.weight)
    self.activation1 = activation
    self.dense2 = nn.Linear(neurons, 1)
    initializer(self.dense2.weight)
    # self.output = nn.Sigmoid() 

  def forward(self, X):
    X = self.dense0(X)
    X = self.activation0(X)
    X = self.dense1(X)
    X = self.activation1(X)
    X = self.dense2(X)
    # X = self.output(X) 
    return X

## Etapa 4: Skorch

In [11]:
classificador_sklearn = NeuralNetBinaryClassifier(module=classificador_torch,
                                                  lr = 0.001,
                                                  optimizer__weight_decay = 0.0001,
                                                  train_split=False)

## Etapa 5: Tuning dos parâmetros

In [12]:
params = {'batch_size': [10],
          'max_epochs': [100],
          'optimizer': [torch.optim.Adam, torch.optim.SGD],
          'criterion': [torch.nn.BCEWithLogitsLoss], #, torch.nn.HingeEmbeddingLoss], 
          'module__activation': [F.relu, F.tanh],
          'module__neurons': [8, 16], 
          'module__initializer': [torch.nn.init.uniform]} # _, torch.nn.init.normal_]}

In [13]:
params

{'batch_size': [10],
 'max_epochs': [100],
 'optimizer': [torch.optim.adam.Adam, torch.optim.sgd.SGD],
 'criterion': [torch.nn.modules.loss.BCEWithLogitsLoss],
 'module__activation': [<function torch.nn.functional.relu(input: torch.Tensor, inplace: bool = False) -> torch.Tensor>,
  <function torch.nn.functional.tanh(input)>],
 'module__neurons': [8, 16],
 'module__initializer': [<function torch.nn.init._make_deprecate.<locals>.deprecated_init(*args, **kwargs)>]}

In [14]:
grid_search = GridSearchCV(estimator=classificador_sklearn, param_grid=params,
                           scoring = 'accuracy', cv = 2)
grid_search = grid_search.fit(previsores, classe)

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


  epoch    train_loss     dur
-------  ------------  ------
      1    [36m27299.8736[0m  0.7114
      2    [36m24343.2036[0m  0.0356
      3    [36m21607.9807[0m  0.0371
      4    [36m19135.1212[0m  0.0366
      5    [36m16931.4985[0m  0.0383
      6    [36m14986.2365[0m  0.0369
      7    [36m13263.7471[0m  0.0351
      8    [36m11731.9747[0m  0.0399
      9    [36m10367.9509[0m  0.0370
     10     [36m9150.4847[0m  0.0376
     11     [36m8060.7277[0m  0.0377
     12     [36m7080.4200[0m  0.0519
     13     [36m6192.2920[0m  0.0367
     14     [36m5382.8836[0m  0.0391
     15     [36m4640.1404[0m  0.0353
     16     [36m3952.5611[0m  0.0351
     17     [36m3309.9055[0m  0.0368
     18     [36m2702.8676[0m  0.0357
     19     [36m2122.0805[0m  0.0337
     20     [36m1558.6191[0m  0.0359
     21     [36m1003.4047[0m  0.0351
     22      [36m478.4396[0m  0.0396
     23      [36m138.4233[0m  0.0369
     24       [36m90.7293[0m  0.0378
    

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      6    [36m11149.3161[0m  0.0396
      7     [36m9765.4888[0m  0.0371
      8     [36m8523.4363[0m  0.0403
      9     [36m7408.3096[0m  0.0383
     10     [36m6405.5098[0m  0.0372
     11     [36m5502.9312[0m  0.0362
     12     [36m4687.3060[0m  0.0383
     13     [36m3944.7730[0m  0.0373
     14     [36m3263.5492[0m  0.0356
     15     [36m2631.7705[0m  0.0367
     16     [36m2037.1366[0m  0.0378
     17     [36m1466.6371[0m  0.0490
     18      [36m907.2859[0m  0.0431
     19      [36m368.1897[0m  0.0426
     20       [36m70.2290[0m  0.0427
     21       [36m64.8931[0m  0.0444
     22       [36m41.8883[0m  0.0487
     23       [36m38.2925[0m  0.0461
     24       [36m34.4358[0m  0.0405
     25       [36m30.3832[0m  0.0363
     26       [36m28.4991[0m  0.0365
     27       [36m26.7977[0m  0.0374
     28       27.0994  0.0429
     29       [36m26.0353[0m  0.0441
     30       [36m24.2245[0m  0.0390
     31       [36m23.3312[0m  0.

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      5        [36m0.6960[0m  0.0416
      6        [36m0.6956[0m  0.0383
      7        [36m0.6951[0m  0.0402
      8        [36m0.6946[0m  0.0379
      9        [36m0.6942[0m  0.0426
     10        [36m0.6937[0m  0.0422
     11        [36m0.6933[0m  0.0376
     12        [36m0.6928[0m  0.0376
     13        [36m0.6924[0m  0.0397
     14        [36m0.6920[0m  0.0401
     15        [36m0.6916[0m  0.0358
     16        [36m0.6912[0m  0.0370
     17        [36m0.6907[0m  0.0409
     18        [36m0.6903[0m  0.0362
     19        [36m0.6900[0m  0.0383
     20        [36m0.6896[0m  0.0363
     21        [36m0.6892[0m  0.0394
     22        [36m0.6888[0m  0.0369
     23        [36m0.6884[0m  0.0421
     24        [36m0.6881[0m  0.0441
     25        [36m0.6877[0m  0.0366
     26        [36m0.6873[0m  0.0306
     27        [36m0.6870[0m  0.0298
     28        [36m0.6866[0m  0.0283
     29        [36m0.6863[0m  0.0264
     30        [36m0.685

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      8        [36m0.7197[0m  0.0302
      9        [36m0.7188[0m  0.0280
     10        [36m0.7179[0m  0.0272
     11        [36m0.7171[0m  0.0279
     12        [36m0.7162[0m  0.0278
     13        [36m0.7154[0m  0.0270
     14        [36m0.7146[0m  0.0275
     15        [36m0.7137[0m  0.0275
     16        [36m0.7129[0m  0.0271
     17        [36m0.7122[0m  0.0280
     18        [36m0.7114[0m  0.0256
     19        [36m0.7106[0m  0.0258
     20        [36m0.7099[0m  0.0274
     21        [36m0.7091[0m  0.0324
     22        [36m0.7084[0m  0.0266
     23        [36m0.7077[0m  0.0305
     24        [36m0.7070[0m  0.0281
     25        [36m0.7063[0m  0.0250
     26        [36m0.7056[0m  0.0279
     27        [36m0.7049[0m  0.0340
     28        [36m0.7042[0m  0.0312
     29        [36m0.7036[0m  0.0278
     30        [36m0.7029[0m  0.0266
     31        [36m0.7023[0m  0.0262
     32        [36m0.7016[0m  0.0286
     33        [36m0.701

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      6    [36m34179.8138[0m  0.0372
      7    [36m28505.4649[0m  0.0356
      8    [36m23266.5548[0m  0.0374
      9    [36m18374.2070[0m  0.0405
     10    [36m13746.7662[0m  0.0372
     11     [36m9300.1456[0m  0.0367
     12     [36m4939.4478[0m  0.0365
     13     [36m1247.1756[0m  0.0365
     14      [36m178.5841[0m  0.0368
     15      191.0518  0.0360
     16      261.4455  0.0357
     17      239.5685  0.0411
     18      236.6770  0.0396
     19      238.6056  0.0378
     20      230.9263  0.0394
     21      223.7866  0.0465
     22      220.1303  0.0363
     23      216.8515  0.0379
     24      204.7685  0.0346
     25      190.7831  0.0347
     26      207.1947  0.0368
     27      196.0289  0.0359
     28      195.7725  0.0359
     29      193.2576  0.0360
     30      182.9705  0.0356
     31      187.2862  0.0363
     32      188.7603  0.0367
     33      179.0847  0.0368
     34      [36m166.9800[0m  0.0367
     35      [36m164.1641[0m  0.0379
 

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      6    [36m44267.9956[0m  0.0361
      7    [36m38596.0506[0m  0.0366
      8    [36m33501.1268[0m  0.0370
      9    [36m28921.3993[0m  0.0381
     10    [36m24799.2329[0m  0.0448
     11    [36m21079.1079[0m  0.0399
     12    [36m17714.4093[0m  0.0411
     13    [36m14659.1855[0m  0.0402
     14    [36m11855.1438[0m  0.0396
     15     [36m9243.7294[0m  0.0403
     16     [36m6774.6104[0m  0.0382
     17     [36m4394.2049[0m  0.0412
     18     [36m2069.4893[0m  0.0392
     19      [36m403.7687[0m  0.0367
     20      [36m196.6151[0m  0.0462
     21      207.4296  0.0531
     22      [36m195.5400[0m  0.0518
     23      [36m190.5075[0m  0.0531
     24      [36m182.9866[0m  0.0598
     25      [36m175.7260[0m  0.0498
     26      [36m167.0982[0m  0.0510
     27      [36m157.2590[0m  0.0507
     28      [36m146.7474[0m  0.0498
     29      [36m142.1947[0m  0.0514
     30      [36m141.3887[0m  0.0492
     31      [36m138.5335[0m  0.

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      7        [36m0.6716[0m  0.0281
      8        [36m0.6714[0m  0.0275
      9        [36m0.6713[0m  0.0271
     10        [36m0.6711[0m  0.0276
     11        [36m0.6710[0m  0.0277
     12        [36m0.6709[0m  0.0277
     13        [36m0.6708[0m  0.0265
     14        [36m0.6706[0m  0.0335
     15        [36m0.6705[0m  0.0303
     16        [36m0.6704[0m  0.0287
     17        [36m0.6703[0m  0.0282
     18        [36m0.6701[0m  0.0288
     19        [36m0.6700[0m  0.0280
     20        [36m0.6699[0m  0.0278
     21        [36m0.6698[0m  0.0274
     22        [36m0.6697[0m  0.0258
     23        [36m0.6696[0m  0.0274
     24        [36m0.6695[0m  0.0274
     25        [36m0.6694[0m  0.0267
     26        [36m0.6692[0m  0.0287
     27        [36m0.6691[0m  0.0284
     28        [36m0.6690[0m  0.0288
     29        [36m0.6689[0m  0.0280
     30        [36m0.6688[0m  0.0295
     31        [36m0.6687[0m  0.0269
     32        [36m0.668

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      7        [36m0.7165[0m  0.0276
      8        [36m0.7157[0m  0.0319
      9        [36m0.7149[0m  0.0268
     10        [36m0.7140[0m  0.0263
     11        [36m0.7132[0m  0.0268
     12        [36m0.7124[0m  0.0262
     13        [36m0.7117[0m  0.0304
     14        [36m0.7109[0m  0.0281
     15        [36m0.7101[0m  0.0267
     16        [36m0.7094[0m  0.0263
     17        [36m0.7086[0m  0.0269
     18        [36m0.7079[0m  0.0264
     19        [36m0.7072[0m  0.0253
     20        [36m0.7065[0m  0.0269
     21        [36m0.7058[0m  0.0252
     22        [36m0.7051[0m  0.0281
     23        [36m0.7044[0m  0.0270
     24        [36m0.7038[0m  0.0268
     25        [36m0.7031[0m  0.0270
     26        [36m0.7025[0m  0.0280
     27        [36m0.7019[0m  0.0287
     28        [36m0.7012[0m  0.0297
     29        [36m0.7006[0m  0.0307
     30        [36m0.7000[0m  0.0288
     31        [36m0.6994[0m  0.0270
     32        [36m0.698

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      5        [36m1.6472[0m  0.0392
      6        [36m1.5726[0m  0.0431
      7        [36m1.4901[0m  0.0359
      8        [36m1.3849[0m  0.0365
      9        [36m1.2230[0m  0.0415
     10        [36m0.9781[0m  0.0394
     11        [36m0.7669[0m  0.0397
     12        [36m0.6936[0m  0.0368
     13        [36m0.6770[0m  0.0352
     14        [36m0.6739[0m  0.0375
     15        [36m0.6731[0m  0.0377
     16        [36m0.6724[0m  0.0376
     17        [36m0.6719[0m  0.0370
     18        [36m0.6715[0m  0.0360
     19        [36m0.6712[0m  0.0366
     20        [36m0.6709[0m  0.0381
     21        [36m0.6707[0m  0.0370
     22        [36m0.6705[0m  0.0377
     23        [36m0.6704[0m  0.0363
     24        [36m0.6702[0m  0.0374
     25        [36m0.6701[0m  0.0352
     26        [36m0.6700[0m  0.0360
     27        [36m0.6699[0m  0.0420
     28        [36m0.6698[0m  0.0369
     29        [36m0.6697[0m  0.0362
     30        [36m0.669

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      4        [36m1.3549[0m  0.0556
      5        [36m1.2859[0m  0.0513
      6        [36m1.2117[0m  0.0500
      7        [36m1.1254[0m  0.0523
      8        [36m1.0169[0m  0.0518
      9        [36m0.8868[0m  0.0492
     10        [36m0.7715[0m  0.0480
     11        [36m0.7064[0m  0.0479
     12        [36m0.6806[0m  0.0497
     13        [36m0.6718[0m  0.0534
     14        [36m0.6686[0m  0.0522
     15        [36m0.6674[0m  0.0551
     16        [36m0.6667[0m  0.0531
     17        [36m0.6664[0m  0.0505
     18        [36m0.6661[0m  0.0504
     19        [36m0.6660[0m  0.0533
     20        [36m0.6658[0m  0.0531
     21        [36m0.6657[0m  0.0520
     22        [36m0.6656[0m  0.0510
     23        [36m0.6656[0m  0.0493
     24        [36m0.6655[0m  0.0541
     25        [36m0.6654[0m  0.0530
     26        [36m0.6654[0m  0.0537
     27        [36m0.6653[0m  0.0520
     28        [36m0.6653[0m  0.0504
     29        [36m0.665

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      7        [36m1.3770[0m  0.0296
      8        [36m1.3452[0m  0.0314
      9        [36m1.3138[0m  0.0304
     10        [36m1.2829[0m  0.0292
     11        [36m1.2525[0m  0.0287
     12        [36m1.2227[0m  0.0351
     13        [36m1.1935[0m  0.0288
     14        [36m1.1650[0m  0.0319
     15        [36m1.1372[0m  0.0306
     16        [36m1.1100[0m  0.0304
     17        [36m1.0837[0m  0.0293
     18        [36m1.0581[0m  0.0287
     19        [36m1.0333[0m  0.0293
     20        [36m1.0094[0m  0.0290
     21        [36m0.9864[0m  0.0314
     22        [36m0.9642[0m  0.0310
     23        [36m0.9430[0m  0.0300
     24        [36m0.9228[0m  0.0294
     25        [36m0.9035[0m  0.0288
     26        [36m0.8851[0m  0.0301
     27        [36m0.8677[0m  0.0278
     28        [36m0.8513[0m  0.0306
     29        [36m0.8358[0m  0.0297
     30        [36m0.8213[0m  0.0289
     31        [36m0.8077[0m  0.0284
     32        [36m0.794

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      7        [36m1.3257[0m  0.0280
      8        [36m1.2962[0m  0.0317
      9        [36m1.2672[0m  0.0278
     10        [36m1.2387[0m  0.0319
     11        [36m1.2107[0m  0.0349
     12        [36m1.1833[0m  0.0300
     13        [36m1.1565[0m  0.0292
     14        [36m1.1303[0m  0.0280
     15        [36m1.1048[0m  0.0294
     16        [36m1.0800[0m  0.0292
     17        [36m1.0559[0m  0.0290
     18        [36m1.0326[0m  0.0287
     19        [36m1.0100[0m  0.0290
     20        [36m0.9883[0m  0.0275
     21        [36m0.9673[0m  0.0299
     22        [36m0.9472[0m  0.0292
     23        [36m0.9280[0m  0.0295
     24        [36m0.9096[0m  0.0288
     25        [36m0.8920[0m  0.0281
     26        [36m0.8754[0m  0.0290
     27        [36m0.8596[0m  0.0323
     28        [36m0.8446[0m  0.0282
     29        [36m0.8305[0m  0.0274
     30        [36m0.8173[0m  0.0286
     31        [36m0.8048[0m  0.0291
     32        [36m0.793

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      6        [36m2.0287[0m  0.0473
      7        [36m1.8666[0m  0.0357
      8        [36m1.5984[0m  0.0361
      9        [36m1.1698[0m  0.0365
     10        [36m0.8075[0m  0.0358
     11        [36m0.6951[0m  0.0349
     12        [36m0.6748[0m  0.0359
     13        [36m0.6708[0m  0.0378
     14        [36m0.6696[0m  0.0369
     15        [36m0.6690[0m  0.0381
     16        [36m0.6688[0m  0.0384
     17        [36m0.6686[0m  0.0421
     18        [36m0.6684[0m  0.0376
     19        [36m0.6684[0m  0.0352
     20        [36m0.6683[0m  0.0351
     21        [36m0.6682[0m  0.0356
     22        [36m0.6682[0m  0.0373
     23        [36m0.6682[0m  0.0387
     24        [36m0.6681[0m  0.0366
     25        [36m0.6681[0m  0.0368
     26        [36m0.6680[0m  0.0354
     27        [36m0.6680[0m  0.0354
     28        [36m0.6680[0m  0.0363
     29        [36m0.6679[0m  0.0410
     30        [36m0.6679[0m  0.0363
     31        [36m0.667

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      4        [36m1.8723[0m  0.0551
      5        [36m1.7417[0m  0.0513
      6        [36m1.6123[0m  0.0514
      7        [36m1.4851[0m  0.0534
      8        [36m1.3601[0m  0.0547
      9        [36m1.2281[0m  0.0614
     10        [36m1.0634[0m  0.0586
     11        [36m0.8767[0m  0.0571
     12        [36m0.7321[0m  0.0578
     13        [36m0.6801[0m  0.0538
     14        [36m0.6709[0m  0.0587
     15        [36m0.6692[0m  0.0513
     16        [36m0.6686[0m  0.0486
     17        [36m0.6683[0m  0.0507
     18        [36m0.6680[0m  0.0554
     19        [36m0.6678[0m  0.0520
     20        [36m0.6677[0m  0.0534
     21        [36m0.6676[0m  0.0534
     22        [36m0.6672[0m  0.0491
     23        [36m0.6499[0m  0.0510
     24        0.6893  0.0636
     25        0.6823  0.0567
     26        0.6777  0.0561
     27        0.6753  0.0626
     28        0.6739  0.0633
     29        0.6729  0.0584
     30        0.6721  0.0514
     31   

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      7        [36m2.6482[0m  0.0315
      8        [36m2.5785[0m  0.0305
      9        [36m2.5088[0m  0.0328
     10        [36m2.4391[0m  0.0316
     11        [36m2.3696[0m  0.0277
     12        [36m2.3002[0m  0.0269
     13        [36m2.2310[0m  0.0310
     14        [36m2.1619[0m  0.0369
     15        [36m2.0931[0m  0.0300
     16        [36m2.0245[0m  0.0306
     17        [36m1.9562[0m  0.0305
     18        [36m1.8883[0m  0.0292
     19        [36m1.8208[0m  0.0275
     20        [36m1.7538[0m  0.0302
     21        [36m1.6875[0m  0.0292
     22        [36m1.6219[0m  0.0303
     23        [36m1.5571[0m  0.0289
     24        [36m1.4934[0m  0.0288
     25        [36m1.4308[0m  0.0273
     26        [36m1.3697[0m  0.0314
     27        [36m1.3101[0m  0.0271
     28        [36m1.2523[0m  0.0344
     29        [36m1.1966[0m  0.0300
     30        [36m1.1432[0m  0.0286
     31        [36m1.0923[0m  0.0291
     32        [36m1.044

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      7        [36m3.7531[0m  0.0308
      8        [36m3.6861[0m  0.0320
      9        [36m3.6191[0m  0.0426
     10        [36m3.5520[0m  0.0290
     11        [36m3.4850[0m  0.0276
     12        [36m3.4180[0m  0.0317
     13        [36m3.3511[0m  0.0312
     14        [36m3.2841[0m  0.0272
     15        [36m3.2171[0m  0.0286
     16        [36m3.1501[0m  0.0301
     17        [36m3.0832[0m  0.0296
     18        [36m3.0163[0m  0.0291
     19        [36m2.9494[0m  0.0276
     20        [36m2.8825[0m  0.0317
     21        [36m2.8156[0m  0.0277
     22        [36m2.7488[0m  0.0312
     23        [36m2.6820[0m  0.0287
     24        [36m2.6153[0m  0.0295
     25        [36m2.5486[0m  0.0294
     26        [36m2.4820[0m  0.0293
     27        [36m2.4155[0m  0.0278
     28        [36m2.3491[0m  0.0301
     29        [36m2.2828[0m  0.0284
     30        [36m2.2167[0m  0.0297
     31        [36m2.1507[0m  0.0290
     32        [36m2.084

  initializer(self.dense0.weight)
  initializer(self.dense1.weight)
  initializer(self.dense2.weight)


      2    [36m67578.9301[0m  0.1254
      3    [36m51267.8874[0m  0.1034
      4    [36m38467.3813[0m  0.0987
      5    [36m28422.5022[0m  0.1020
      6    [36m20474.0240[0m  0.0978
      7    [36m14073.4598[0m  0.0994
      8     [36m8759.0278[0m  0.0991
      9     [36m4118.9133[0m  0.1014
     10      [36m574.8676[0m  0.0999
     11      [36m151.9111[0m  0.1043
     12      [36m122.1277[0m  0.1131
     13      [36m103.8000[0m  0.1046
     14       [36m87.6229[0m  0.1005
     15       [36m72.9789[0m  0.0991
     16       [36m64.6352[0m  0.1101
     17       [36m63.1611[0m  0.1017
     18       [36m58.8906[0m  0.0996
     19       [36m57.9270[0m  0.1050
     20       [36m53.4348[0m  0.1021
     21       [36m51.8312[0m  0.1105
     22       [36m46.2688[0m  0.1138
     23       [36m45.3725[0m  0.1095
     24       [36m43.1940[0m  0.1065
     25       47.5450  0.0988
     26       54.5306  0.0993
     27       52.1836  0.0994
     28      

In [15]:
melhores_parametros = grid_search.best_params_
melhor_precisao = grid_search.best_score_

In [16]:
melhores_parametros

{'batch_size': 10,
 'criterion': torch.nn.modules.loss.BCEWithLogitsLoss,
 'max_epochs': 100,
 'module__activation': <function torch.nn.functional.relu(input: torch.Tensor, inplace: bool = False) -> torch.Tensor>,
 'module__initializer': <function torch.nn.init._make_deprecate.<locals>.deprecated_init(*args, **kwargs)>,
 'module__neurons': 16,
 'optimizer': torch.optim.adam.Adam}

In [17]:
melhor_precisao

0.8295218680504077